{-# LANGUAGE GADTs, Rank2Types, StandaloneDeriving, ImplicitParams #-}
-- | An implementation of nested data parallelism.
--
-- A computation is written as an arrow of type 'A', may be rewritten with
-- 'optimize' (which replaces nested arrays by a packed, segmented
-- representation), and is run with 'eval'.
module Control.CUtils.DataParallel
  ( ArrC
  , inject
  , project
  , newArray
  , A(Count, Index, Zip, Unzip, Concat, Map, Comp, Arr, Prod, Sum)
  , optimize
  , eval
  ) where

import Data.Array
import Data.Tree
import Data.Monoid (Any(Any))
import Control.Category
import Control.Arrow
import Control.Monad.Writer (Writer, tell, runWriter)
import Control.Monad
import Control.CUtils.Conc
import System.IO.Unsafe
import Prelude hiding (id, (.))

-- | An array together with a segment forest.
--
-- Each 'Node' holds an offset into the array.  Two adjacent top-level nodes
-- delimit one segment, and the children of the first of them describe the
-- segmentation inside that segment.  'Pack' builds such forests and 'Unpack'
-- consumes them.
data ArrC t = ArrC !(Array Int t) !(Forest Int)

-- | Wraps a flat array as a single segment spanning the whole array.  The
-- segment starts at offset 0, so the array is expected to be 0-indexed.
inject :: Array Int t -> ArrC t
inject ar = ArrC ar [Node 0 [], Node (uncurry subtract (bounds ar) + 1) []]

-- | Extracts the underlying flat array, discarding the segment forest.
project :: ArrC t -> Array Int t
project (ArrC ar _) = ar

-- | Maps over the elements; the segment forest is left unchanged.
instance Functor ArrC where
  fmap f (ArrC ar ls) = ArrC (fmap f ar) ls

-- | Builds a 0-indexed array from a list.
newArray :: [e] -> Array Int e
newArray ls = listArray (0, length ls - 1) ls

-- Pairs each element with its successor:
-- @pairUp [a, b, c] == [(a, b), (b, c)]@.  Total on the empty list.
pairUp :: [a] -> [(a, a)]
pairUp ls = zip ls (drop 1 ls)

-- | Placeholder so that the derived 'Show' instance for 'A' can render 'Arr'
-- nodes: every function is shown as the empty string.
--
-- NOTE(review): this is an orphan instance and is visible to every importer
-- of this module.
instance Show (t -> u) where
  showsPrec _ _ = ("" ++)

-- | Arrows describing data-parallel computations over 'ArrC' arrays.
data A t u where
  -- Constructors for caller's use.

  -- | @Count n@ is an array of indices built with 'concF'; the optimizer's
  -- bounds check treats it as holding @0 .. n-1@.
  Count :: A Int (ArrC Int)
  -- | Looks up the element at the given position.
  Index :: A (ArrC t, Int) t
  -- | Pairs elements position-wise, truncating to the shorter array.
  Zip :: A (ArrC t, ArrC u) (ArrC (t, u))
  -- | Splits an array of pairs; both halves keep the input's segment forest.
  Unzip :: A (ArrC (t, u)) (ArrC t, ArrC u)
  -- | Flattens one level of nesting.
  Concat :: A (ArrC (ArrC t)) (ArrC t)
  -- | Applies an arrow to every element (evaluated with 'conc').
  Map :: A t u -> A (ArrC t) (ArrC u)
  -- | Composition: the right-hand arrow runs first.
  Comp :: A u v -> A t u -> A t v
  -- | Lifts a pure function.
  Arr :: (t -> u) -> A t u
  -- | Runs two arrows on the components of a pair; both results are forced
  -- to weak head normal form.
  Prod :: A t u -> A v w -> A (t, v) (u, w)
  -- | Runs the arrow matching the side of an 'Either'.
  Sum :: A t u -> A v w -> A (Either t v) (Either u w)

  -- Internal constructors (not exported).

  -- Identity.  Nothing in this module constructs it ('id' is @arr id@).
  Id :: A t t
  -- Flattens an array of arrays into one array, recording each inner
  -- array's offset and forest in the result's forest.
  Pack :: A (ArrC (ArrC t)) (ArrC t)
  -- Splits an array into the segments described by its forest.
  Unpack :: A (ArrC t) (ArrC (ArrC t))
  -- Stores a pair as the two-element array @[Left x, Right y]@.
  PackProd :: A (t, u) (ArrC (Either t u))
  -- Reads a pair back from the 'PackProd' layout.
  UnpackProd :: A (ArrC (Either t u)) (t, u)
  -- @Left x@ becomes @[Left x]@; @Right ar@ becomes @ar@ with every element
  -- wrapped in 'Right'.
  PackSum1 :: A (Either t (ArrC u)) (ArrC (Either t u))
  -- Inverse of 'PackSum1'.
  UnpackSum1 :: A (ArrC (Either t u)) (Either t (ArrC u))
  -- 'PackSum1' with the sides of the 'Either' swapped.
  PackSum2 :: A (Either (ArrC t) u) (ArrC (Either t u))
  -- 'UnpackSum1' with the sides of the 'Either' swapped.
  UnpackSum2 :: A (ArrC (Either t u)) (Either (ArrC t) u)

-- Swaps the two sides of an 'Either'.
mirror :: Either a b -> Either b a
mirror ei = either Right Left ei

deriving instance Show (A t u)

instance Category A where
  id = arr id
  (.) = Comp

instance Arrow A where
  arr = Arr
  (***) = Prod
  first a = a *** arr id
  second a = arr id *** a

instance ArrowChoice A where
  (+++) = Sum
  left a = a +++ arr id
  right a = arr id +++ a

-- Right-associates the top-level composition chain of the first argument and
-- composes it in front of the second, so that no 'Comp' on the resulting
-- spine has a 'Comp' as its left operand.
reassociate :: A u v -> A t u -> A t v
reassociate (Comp a a2) = reassociate a . reassociate a2
reassociate x = (x .)

-- Optimizer step 1.  Pushes indexes and concats to the right and separates
-- maps/products/sums.  Once this is done, the result should consist of
-- internal layers of only Maps.
step :: A t u -> A t u
step (Comp (Map (Comp a a2)) a3) = step (Map (step a)) . (Map a2 . a3)
step (Comp (Map (Prod a a2)) a3) = Zip . ((Map a *** Map a2) . (Unzip . a3))
step (Comp (Map a) a2) = step (Map (step a)) . a2
step (Comp Index (Prod (Map a) a2)) = step a . (Index . second a2)
step (Comp Index (Prod Count a)) =
  arr (\(i, j) ->
         if inRange (0, i - 1) j
           then j
           else error $ "DataParallel.eval: bad index: " ++ show j)
    . second a
step (Comp Concat (Map (Map a))) = step (Map (step a)) . Concat
step (Comp Concat (Map Concat)) = Concat . Concat
step (Comp (Prod (Comp a a2) a3) a4) = step (Prod (step a) id) . (Prod a2 a3 . a4)
step (Comp (Prod a (Comp a2 a3)) a4) = step (Prod id (step a2)) . (Prod a a3 . a4)
step (Comp (Sum (Comp a a2) a3) a4) = step (Sum (step a) id) . (Sum a2 a3 . a4)
step (Comp (Sum a (Comp a2 a3)) a4) = step (Sum id (step a2)) . (Sum a a3 . a4)
step (Comp a (Comp a2 a3)) =
  -- Every equation of 'step' returns a 'Comp' when given one, so this match
  -- cannot fail.
  case step (a . a2) of
    Comp a4 a5 -> a4 . step (a5 . a3)
step a = a

-- Optimizer step 2.  Replaces nested arrays with the packed representation.
-- The first two steps are repeated until a pass makes no rewrite (recorded
-- with @tell (Any True)@), leaving only one layer of Maps.
step2 :: A t u -> Writer Any (A t u)
step2 (Map (Map a)) = do
  tell (Any True)
  a' <- step2 a
  return (Unpack . Map a' . Pack)
step2 (Prod a a2) = do
  tell (Any True)
  packed <- step2 (Map (Sum a a2))
  return (UnpackProd . packed . PackProd)
-- Sums create the possibility of recursion trees with variable depth.
step2 (Sum a (Map a2)) = do
  tell (Any True)
  x <- step2 a
  y <- step2 a2
  return (UnpackSum1 . Map (Sum x y) . PackSum1)
step2 (Sum (Map a) a2) = do
  tell (Any True)
  x <- step2 a
  y <- step2 a2
  return (arr mirror . UnpackSum1 . Map (Sum y x) . PackSum1 . arr mirror)
step2 (Sum a a2) = liftM2 (+++) (step2 a) (step2 a2)
step2 (Map a) = liftM Map (step2 a)
step2 (Comp a a2) = liftM2 (.) (step2 a) (step2 a2)
step2 a = return a

-- Optimizer step 3.  Removes redundant packs and zips, combines
-- maps/products/sums, and pushes zips right.  Returns 'Nothing' when no rule
-- applies.  Currently not used by 'optimize'.
step3 :: A t u -> Maybe (A t u)
step3 (Comp (Map a) (Comp (Map a2) a3)) = Just $ Map (repetition step3 (a . a2)) . a3
step3 (Comp Zip (Prod (Map a) (Map a2))) = Just $ Map (repetition step3 (a *** a2)) . Zip
step3 (Comp Zip (Prod Count Count)) = Just $ Map (arr (\x -> (x, x))) . (Count . arr (uncurry min))
step3 (Comp Zip (Comp Unzip a)) = Just a
step3 (Comp Pack (Comp Unpack a)) = Just a
step3 (Comp PackProd (Comp UnpackProd a)) = Just a
step3 (Comp PackSum1 (Comp UnpackSum1 a)) = Just a
step3 (Comp PackSum2 (Comp UnpackSum2 a)) = Just a
step3 (Comp (Sum a a2) (Sum a3 a4)) = Just $ repetition step3 (a . a3) +++ repetition step3 (a2 . a4)
step3 (Comp a (Comp a2 a3)) = liftM (a .) (step3 (a2 . a3))
step3 _ = Nothing

-- Applies a rewrite until it reports no further change ('Nothing').
repetition :: (a -> Maybe a) -> a -> a
repetition f x = maybe x (repetition f) (f x)

-- Applies a logged rewrite until a pass records no change (@Any False@),
-- returning the output of that final pass.
repetition2 :: (a -> Writer Any a) -> a -> a
repetition2 f x = if b then repetition2 f y else y
  where
    (y, Any b) = runWriter (f x)

-- | Optimizes an arrow for parallel execution.  The arrow can be optimized
-- once, and the result saved for multiple computations.  (The exact output
-- of the optimizer is subject to change.)
--
-- The arrow must be finitely examinable.
optimize :: A t u -> A t u
-- Step 3 ('step3') is currently disabled; it would be applied as
-- @repetition step3 .@ on the left of this pipeline.
optimize = repetition2 (liftM normalise . step2 . step) . normalise
  where
    -- Right-associate the top-level composition chain.
    normalise a = reassociate a (arr id)

-- Evaluates an arrow.  The implicit @?seq@ parameter is not read here
-- directly; presumably it is required by 'conc' / 'concF'.
--
-- 'unsafePerformIO' runs the parallel array builders from
-- "Control.CUtils.Conc"; the actions handed to them only return pure values.
eval0 :: (?seq :: Bool) => A t u -> t -> u
eval0 Count n = inject $ unsafePerformIO $ concF n (return $!)
eval0 Index (ArrC ar _, i) = ar ! i
-- The action array is built explicitly and run with 'conc' so the result has
-- exactly as many elements as the shorter input: the length is the element
-- count ('rangeSize'), not the upper index, so the last pair is not dropped.
eval0 Zip (ArrC ar _, ArrC ar2 _) =
  inject $ unsafePerformIO $ conc $ newArray
    [ x `seq` y `seq` return $! (x, y)
    | i <- [0 .. len - 1]
    , let x = ar ! i
    , let y = ar2 ! i
    ]
  where
    len = rangeSize (bounds ar) `min` rangeSize (bounds ar2)
eval0 Unzip ar = (fmap fst ar, fmap snd ar)
eval0 Concat ar0 = ArrC ar [Node (i + j) ls3 | Node i ls2 <- ls, Node j ls3 <- ls2]
  where
    ArrC ar ls = eval0 Pack ar0
eval0 (Map a) (ArrC ar ls) = ArrC (unsafePerformIO $ conc $ fmap ((return $!) . eval0 a) ar) ls
eval0 Pack (ArrC ar _) = ArrC (newArray (concatMap (elems . project) inners)) (zipWith Node offsets forests)
  where
    inners = elems ar
    -- Start offset of every inner array, followed by the total length.
    offsets = scanl (\i (ArrC inner _) -> i + rangeSize (bounds inner)) 0 inners
    -- Each inner array's own forest; the final (total-length) node has none.
    forests = map (\(ArrC _ f) -> f) inners ++ [[]]
eval0 Unpack (ArrC ar ls) = inject $ newArray $ map segment (pairUp ls)
  where
    -- Offsets [i, j) delimit one segment, re-indexed from 0.
    segment (Node i inner, Node j _) = ArrC (ixmap (0, j - i - 1) (+ i) ar) inner
eval0 PackProd (x, y) = inject $ newArray [Left x, Right y]
eval0 UnpackProd ar = (let Left x = project ar ! 0 in x, let Right x = project ar ! 1 in x)
eval0 PackSum1 (Left x) = inject (newArray [Left x])
eval0 PackSum1 (Right ar) = fmap Right ar
-- A 'Left' packing is always the one-element array @[Left x]@; anything
-- else, including the empty array produced by @PackSum1 (Right empty)@, is
-- a 'Right' packing.
eval0 UnpackSum1 ar = case elems (project ar) of
  Left x : _ -> Left x
  _ -> Right (fmap (\(Right x) -> x) ar)
eval0 PackSum2 ei = fmap mirror $ eval0 PackSum1 $ mirror ei
eval0 UnpackSum2 ar = mirror $ eval0 UnpackSum1 $ fmap mirror ar
eval0 (Comp a a2) x = eval0 a $ eval0 a2 x
eval0 (Arr f) x = f x
eval0 Id x = x
eval0 (Prod a a2) (x, y) = b `seq` c `seq` (b, c)
  where
    b = eval0 a x
    c = eval0 a2 y
eval0 (Sum a a2) ei = either (Left . eval0 a) (Right . eval0 a2) ei

-- | Evaluates arrows.
eval :: A t u -> t -> u
eval a = let ?seq = True in eval0 a