{-# LANGUAGE ImplicitParams #-}

-- | An implementation of nested data parallelism.
--
-- Programs involving these array operations are optimized by a set of
-- rewrite rules (the @RULES@ pragma at the end of this module) when GHC's
-- @-O@ option is set. Use @+RTS -N@ to get parallelism.
module Control.CUtils.DataParallel where

import Data.Array hiding (index)
import Data.Tree
import Control.CUtils.Conc
import Control.Monad
import System.IO.Unsafe
import Prelude hiding (zip, concat, and)
import qualified Prelude as P

-- | The array interface: a flat 'Int'-indexed array paired with a forest of
-- segment boundaries. Consecutive top-level nodes give the start and end
-- offsets of a segment (see '__unpack'); a node's children describe the
-- nesting inside that segment (see '__pack').
data ArrC t = ArrC !(Array Int t) !(Forest Int)

-- | Inject a basic array into the 'ArrC' type. The boundary forest
-- @[Node 0 [], Node n []]@ (with @n@ the number of elements) describes a
-- single segment spanning the whole array; offsets assume a 0-based array.
inject :: Array Int t -> ArrC t
inject ar = ArrC ar [Node 0 [], Node (rangeSize (bounds ar)) []]

-- | Get the underlying flat array out, discarding the boundary forest.
project :: ArrC t -> Array Int t
project (ArrC ar _) = ar

-- | Convenience for making a 0-based array from a list.
newArray :: [t] -> Array Int t
newArray ls = listArray (0, length ls - 1) ls

-- | Swap the two alternatives of an 'Either'.
mirror :: Either a b -> Either b a
mirror = either Right Left

-- | Pair every element with its successor:
-- @pairUp [a, b, c] == [(a, b), (b, c)]@; lists shorter than two give @[]@.
pairUp :: [a] -> [(a, a)]
pairUp ls = P.zip ls (drop 1 ls)

-- Public operations.
--
-- Each is marked INLINE [0] so it is not inlined before phase 0, leaving the
-- rewrite rules at the end of the module (activated in phases 2 and 1) free
-- to match on the calls.
--
-- 'unsafePerformIO' presents the IO-based evaluation helpers 'conc', 'concF'
-- and 'assocFold' as pure values. NOTE(review): this is only sound if those
-- helpers have no observable side effects beyond evaluation — confirm in
-- "Control.CUtils.Conc".

{-# INLINE [0] mp #-}
-- | Map a function over every element, keeping the boundary forest. Each
-- result is forced to weak head normal form before being stored; the
-- per-element actions are run by 'conc' with @?seq = True@.
mp f (ArrC ar ls) =
  ArrC (unsafePerformIO $ let ?seq = True in conc $ fmap ((return $!) . f) ar) ls

{-# INLINE [0] count #-}
-- | @count n@ is the array built by 'concF' for @n@, storing each value
-- 'concF' supplies (presumably the indices @0 .. n-1@ — confirm against
-- "Control.CUtils.Conc"), wrapped with 'inject'.
count n = inject $ unsafePerformIO $ let ?seq = True in concF n (return $!)

{-# INLINE [0] index #-}
-- | Read the element at a position of the flat array.
index :: ArrC t -> Int -> t
index (ArrC ar _) i = ar ! i

{-# INLINE [0] zip #-}
-- | Pair up corresponding elements of two (0-based) arrays. The result has
-- as many elements as the shorter input, and both components of every pair
-- are forced to weak head normal form. Boundary forests are not kept; the
-- result is re-wrapped with 'inject'.
--
-- The first argument of 'concF' is an element count (as in 'count' and
-- '__unpack'), so it is given the 'rangeSize' of each input — not the upper
-- bound, which would drop the final pair and make rules such as
-- @combMapProd@ change the array length.
zip (ArrC ar _) (ArrC ar2 _) = inject $ unsafePerformIO $
  let ?seq = True
  in concF (rangeSize (bounds ar) `min` rangeSize (bounds ar2)) $ \i ->
       let x = ar ! i
           y = ar2 ! i
       in x `seq` y `seq` return $! (x, y)

{-# INLINE [0] concat #-}
-- | Flatten one level of nesting. The flat data comes from '__pack'; the two
-- outermost levels of the packed boundary forest are merged by adding each
-- outer offset to every offset directly beneath it.
--
-- NOTE(review): when the inner arrays carry the 'inject' forest shape
-- @[Node 0 [], Node s []]@, every inner array contributes boundaries @o@ and
-- @o + s@, so the end of one group equals the start of the next. Pairing
-- adjacent boundaries (as '__unpack' does) then yields an empty segment
-- between groups — confirm whether that is intended.
concat :: ArrC (ArrC t) -> ArrC t
concat ar0 =
  ArrC ar [Node (i + j) ls3 | Node i ls2 <- ls, Node j ls3 <- ls2]
  where
    ArrC ar ls = __pack ar0

{-# INLINE [0] fold #-}
-- | Associative fold over the flat array, delegated to 'assocFold'. The
-- combining function @f@ is applied strictly (each result forced to weak
-- head normal form). @g@ and @seed@ are passed straight through; judging by
-- the @fold@ rewrite rule, @g@ is applied to each element before combining
-- (NOTE(review): confirm against 'assocFold').
fold f g seed ar =
  unsafePerformIO $ assocFold (\y z -> return $! f y z) g seed $ project ar

-- Control.Arrow substitutes. They are defined locally (rather than imported)
-- so the rewrite rules can match on these exact names.

{-# INLINE [0] first #-}
-- | Apply a function to the first component of a pair.
first :: (a -> b) -> (a, c) -> (b, c)
first f (x, y) = (f x, y)

{-# INLINE [0] second #-}
-- | Apply a function to the second component of a pair.
second :: (b -> c) -> (a, b) -> (a, c)
second f (x, y) = (x, f y)

{-# INLINE [0] left #-}
-- | Apply a function to a 'Left' value; 'Right' values pass through.
left :: (a -> b) -> Either a c -> Either b c
left f = either (Left . f) Right

{-# INLINE [0] right #-}
-- | Apply a function to a 'Right' value; 'Left' values pass through.
right :: (b -> c) -> Either a b -> Either a c
right f = either Left (Right . f)

{-# INLINE [0] and #-}
-- | Fan-out: apply two functions to the same argument and pair the results.
and :: (a -> b) -> (a -> c) -> a -> (b, c)
and f g x = (f x, g x)

-- Internals: packing and unpacking used by the rewrite rules.

{-# INLINE [0] __pack #-}
-- | Flatten an array of arrays into one flat array. The new boundary forest
-- has one node per inner array, labelled with that array's starting offset
-- and holding the inner array's own forest as children, followed by one
-- childless node labelled with the total length. The outer forest is
-- discarded.
__pack :: ArrC (ArrC t) -> ArrC t
__pack (ArrC ar _) =
  ArrC (newArray (concatMap (elems . project) inner))
       (zipWith Node offsets (map forestOf inner ++ [[]]))
  where
    inner = elems ar
    offsets = scanl (\acc (ArrC a _) -> acc + rangeSize (bounds a)) 0 inner
    forestOf (ArrC _ f) = f

{-# INLINE [0] __unpack #-}
-- | Split a packed array back into an array of arrays. Each pair of
-- consecutive top-level boundary nodes @(Node i sub, Node j _)@ selects the
-- slice @[i, j)@ of the flat array, which becomes an inner 'ArrC' carrying
-- @sub@ as its boundary forest. Slices are copied element-wise via 'concF'.
__unpack (ArrC ar ls) = inject $ unsafePerformIO $
  let ?seq = True
  in conc $ fmap
       (\(Node i sub, Node j _) ->
          liftM (\seg -> ArrC seg sub) $
            concF (j - i) $ \k -> return $! ar ! (k + i))
       (newArray (pairUp ls))

{-# INLINE [0] __packProd #-}
-- | Encode a pair as the two-element array @[Left x, Right y]@.
__packProd :: (a, b) -> ArrC (Either a b)
__packProd (x, y) = inject $ newArray [Left x, Right y]

{-# INLINE [0] __unpackProd #-}
-- | Inverse of '__packProd': reads a 'Left' at index 0 and a 'Right' at
-- index 1. Each component is only inspected when demanded; any other layout
-- is reported with 'error'.
__unpackProd :: ArrC (Either a b) -> (a, b)
__unpackProd ar = (fromL (arr ! 0), fromR (arr ! 1))
  where
    arr = project ar
    fromL (Left x) = x
    fromL (Right _) = error "__unpackProd: expected Left at index 0"
    fromR (Right y) = y
    fromR (Left _) = error "__unpackProd: expected Right at index 1"

{-# INLINE [0] __packSum1 #-}
-- | Encode @Either a (ArrC b)@ as one array: @Left x@ becomes the one-element
-- array @[Left x]@, and @Right arr@ becomes @arr@ with every element wrapped
-- in 'Right'.
__packSum1 (Left x) = inject (newArray [Left x])
__packSum1 (Right ar) = mp Right ar

{-# INLINE [0] __unpackSum1 #-}
-- | Inverse of '__packSum1'. A 'Left' at index 0 decodes to 'Left';
-- otherwise every element is unwrapped from 'Right'. An empty array (which
-- '__packSum1' produces from @Right@ of an empty array) decodes to @Right@ of
-- an empty array instead of failing on a missing index 0.
__unpackSum1 ar
  | inRange (bounds arr) 0, Left x <- arr ! 0 = Left x
  | otherwise = Right (mp (\(Right y) -> y) ar)
  where
    arr = project ar

{-# INLINE [0] __packSum2 #-}
-- | Mirror image of '__packSum1' for @Either (ArrC a) b@: swap the sides,
-- pack with '__packSum1', then swap every element back.
__packSum2 x = mp mirror (__packSum1 (mirror x))

{-# INLINE [0] __unpackSum2 #-}
-- | Inverse of '__packSum2', built from '__unpackSum1' the same way.
__unpackSum2 x = mirror (__unpackSum1 (mp mirror x))

-- Rewrite rules. Phase-[2] rules separate composed operations and introduce
-- packing; phase-[1] rules recombine operations and cancel pack/unpack pairs.
-- Rules must start in column 0 (GHC applies layout inside RULES pragmas).
--
-- NOTE(review): a [2] rule stays active in phases 1 and 0, where the inverse
-- [1] rule is also active (e.g. "sepMapComp" vs "combMapComp"). If this makes
-- the simplifier ping-pong, the "sep*" rules may want the [~1] annotation —
-- confirm.
{-# RULES
"packMap" [2] forall f x. mp (mp f) x = __unpack (mp f (__pack x))
"packProd" [2] forall f x. first f x = __unpackProd (mp (left f) (__packProd x))
"packProd2" [2] forall f x. second f x = __unpackProd (mp (right f) (__packProd x))
"packSum" [2] forall f x. right (mp f) x = __unpackSum1 (mp (right f) (__packSum1 x))
"packSum2" [2] forall f x. left (mp f) x = __unpackSum2 (mp (left f) (__packSum2 x))
"sepMapComp" [2] forall f g x. mp (f . g) x = mp f (mp g x)
"sepMapProd" [2] forall f ar. mp (and f id) ar = zip (mp f ar) ar
"sepMapProd2" [2] forall f ar. mp (and id f) ar = zip ar (mp f ar)
"sepSum" [2] forall f g x. left (f . g) x = left f (left g x)
"sepSum2" [2] forall f g x. right (f . g) x = right f (right g x)
"combMapComp" [1] forall f g x. mp f (mp g x) = mp (f . g) x
"combMapProd" [1] forall f ar. zip (mp f ar) ar = mp (and f id) ar
"combMapProd2" [1] forall f ar. zip ar (mp f ar) = mp (and id f) ar
"combSum" [1] forall f g x. left f (left g x) = left (f . g) x
"combSum2" [1] forall f g x. right f (right g x) = right (f . g) x
"unpackPack" [1] forall x. __pack (__unpack x) = x
"unpackProd" [1] forall x. __packProd (__unpackProd x) = x
"unpackSum" [1] forall x. __packSum1 (__unpackSum1 x) = x
"unpackSum2" [1] forall x. __packSum2 (__unpackSum2 x) = x
"zip" [1] forall f x y. mp (\y -> f (fst y)) (zip x y) = mp f x
"zip2" [1] forall f x y. mp (\y -> f (snd y)) (zip x y) = mp f y
"index" [1] forall f ar i. index (mp f ar) i = f (index ar i)
"concatConcat" [1] forall x. concat (mp concat x) = concat (concat x)
"fold" [1] forall f g h x y. fold f g x (mp h y) = fold f (g . h) x y
  #-}