{-# LANGUAGE GADTs #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE StandaloneDeriving #-}

-- | An implementation of nested data parallelism.
--
-- Computations are written as arrows of type 'A'. 'optimize' rewrites an
-- arrow so that nested arrays are handled through a flattened ("packed")
-- representation, and 'eval' runs an arrow. Element-wise work in 'Map',
-- 'Count' and 'Zip' is handed to @conc@ / @concF@ from
-- "Control.Concurrent.Conc" (presumably evaluating elements in parallel;
-- see that module).
module Control.Concurrent.DataParallel
  ( ArrC
  , inject
  , project
  , newArray
  , A (Count, Index, Zip, Unzip, Concat, Map, Comp, Arr, Prod, Sum)
  , optimize
  , eval
  ) where

import Data.Array
import Data.Tree
import Control.Category
import Control.Arrow
import Control.Monad
import Control.Concurrent.Conc
import System.IO.Unsafe
import Prelude hiding (id, (.))

-- | An array together with a 'Forest' of segment offsets.
--
-- For an array built by 'inject' the forest is @[Node 0 [], Node n []]@,
-- where @n@ is the number of elements. In the packed form of a nested array
-- (built internally by the evaluator for @Pack@), each top-level node holds
-- the offset at which a sub-array starts, its children hold that sub-array's
-- own forest, and a final childless node holds the total length; consecutive
-- top-level nodes therefore delimit the sub-arrays.
data ArrC t = ArrC !(Array Int t) !(Forest Int)

-- | Wraps a plain array as a flat 'ArrC'.
inject :: Array Int t -> ArrC t
inject ar = ArrC ar [Node 0 [], Node (uncurry subtract (bounds ar) + 1) []]

-- | Extracts the underlying array, discarding the segment forest.
project :: ArrC t -> Array Int t
project (ArrC ar _) = ar

-- | Maps over the elements; the segment forest is kept unchanged.
instance Functor ArrC where
  fmap f (ArrC ar ls) = ArrC (fmap f ar) ls

-- | Builds a zero-based array holding the elements of a list, in order.
newArray :: [e] -> Array Int e
newArray ls = listArray (0, length ls - 1) ls

-- | Pairs every list element with its successor (empty for lists shorter
-- than two elements).
pairUp :: [a] -> [(a, a)]
pairUp ls = zip ls (drop 1 ls)

-- | Renders every function as the empty string. Needed so that the derived
-- 'Show' instance for 'A' can show the function stored in 'Arr'.
--
-- NOTE(review): this is an orphan instance and will overlap with any other
-- @Show (a -> b)@ instance in scope (for example the one in
-- "Text.Show.Functions").
instance Show (t -> u) where
  showsPrec _ _ = ("" ++)

-- | A data-parallel computation from @t@ to @u@.
data A t u where
  -- Constructors exported for the caller's use.

  -- | Builds an 'ArrC' of 'Int's from @concF n (return $!)@; presumably the
  -- indices @0 .. n-1@ (the @Index@ bounds check inserted by 'optimize'
  -- assumes this) -- TODO confirm against "Control.Concurrent.Conc".
  Count :: A Int (ArrC Int)
  -- | Looks up the element at the given index.
  Index :: A (ArrC t, Int) t
  -- | Pairs elements positionally; the result is as long as the shorter input.
  Zip :: A (ArrC t, ArrC u) (ArrC (t, u))
  -- | Splits an array of pairs into two arrays.
  Unzip :: A (ArrC (t, u)) (ArrC t, ArrC u)
  -- | Concatenates an array of arrays.
  Concat :: A (ArrC (ArrC t)) (ArrC t)
  -- | Applies an arrow to every element.
  Map :: A t u -> A (ArrC t) (ArrC u)
  -- | Composition: @Comp f g@ runs @g@, then @f@.
  Comp :: A u v -> A t u -> A t v
  -- | Lifts a pure function.
  Arr :: (t -> u) -> A t u
  -- | Runs two arrows on the components of a pair, forcing both results.
  Prod :: A t u -> A v w -> A (t, v) (u, w)
  -- | Runs the left or the right arrow depending on the 'Either' input.
  Sum :: A t u -> A v w -> A (Either t v) (Either u w)

  -- Internal constructors, introduced only by 'optimize'.

  -- | Flattens a nested array into its packed form.
  Pack :: A (ArrC (ArrC t)) (ArrC t)
  -- | Re-slices a packed array into sub-arrays using its forest.
  Unpack :: A (ArrC t) (ArrC (ArrC t))
  -- | Turns @Left x@ into a singleton array and @Right xs@ into an array of
  -- 'Right's.
  PackSum :: A (Either t (ArrC u)) (ArrC (Either t u))
  -- | Inverse of 'PackSum'.
  UnpackSum :: A (ArrC (Either t u)) (Either t (ArrC u))

-- | Swaps the two sides of an 'Either'.
mirror :: Either a b -> Either b a
mirror = either Right Left

deriving instance Show (A t u)

instance Category A where
  id = arr id
  (.) = Comp

instance Arrow A where
  arr = Arr
  (***) = Prod
  first a = a *** arr id
  second a = arr id *** a

instance ArrowChoice A where
  (+++) = Sum
  left a = a +++ arr id
  right a = arr id +++ a

-- | Right-associates a chain of 'Comp's and attaches the second argument at
-- the far right: @reassociate ((f . g) . h) k == f . (g . (h . k))@.
reassociate :: A u v -> A t u -> A t v
reassociate (Comp a a2) = reassociate a . reassociate a2
reassociate x = (x .)

-- | Optimizer step 1. Pushes indexes and concats to the right and separates
-- maps/products/sums.
step :: A t u -> A t u
step (Comp (Map (Comp a a2)) a3) = step (Map (step a)) . (Map a2 . a3)
step (Comp (Map (Prod a a2)) a3) = Zip . ((Map a *** Map a2) . (Unzip . a3))
step (Comp Index (Prod (Map a) a2)) = step a . (Index . second a2)
step (Comp Index (Prod Count a)) = arr checked . second a
  where
    -- Indexing into @Count i@ is only valid for @0 <= j <= i - 1@.
    checked (i, j)
      | inRange (0, i - 1) j = j
      | otherwise = error "DataParallel.eval: bad index"
step (Comp Concat (Map (Map a))) = Map (step a) . Concat
step (Comp Concat (Map Concat)) = Concat . Concat
step (Comp (Prod (Comp a a2) a3) a4) = step (Prod (step a) id) . (Prod a2 a3 . a4)
step (Comp (Prod a (Comp a2 a3)) a4) = step (Prod id (step a2)) . (Prod a a3 . a4)
step (Comp (Sum (Comp a a2) a3) a4) = step (Sum (step a) id) . (Sum a2 a3 . a4)
step (Comp (Sum a (Comp a2 a3)) a4) = step (Sum id (step a2)) . (Sum a a3 . a4)
step (Comp a (Comp a2 a3)) =
  case step (a . a2) of
    Comp a4 a5 -> a4 . step (a5 . a3)
    -- Every equation of 'step' maps a 'Comp' to a 'Comp', so this
    -- alternative is only a totality guard.
    a4 -> a4 . a3
step a = a

-- | Optimizer step 2. Replaces nested arrays with the packed representation.
step2 :: A t u -> A t u
step2 (Map (Map a)) = Unpack . step2 (Map (step2 a)) . Pack
step2 (Map a) =
  case step2 a of
    Map a' -> Unpack . Map a' . Pack
    a' -> Map a'
step2 (Prod a a2) = Prod (step2 a) (step2 a2)
-- Sums create the possibility of recursion trees w/ variable depth.
step2 (Sum a (Map a2)) = UnpackSum . Map (Sum (step2 a) (step2 a2)) . PackSum
step2 (Sum (Map a) a2) = arr mirror . step2 (Sum a2 (Map a)) . arr mirror
step2 (Sum a a2) = Sum (step2 a) (step2 a2)
step2 (Comp a a2) = step2 a . step2 a2
step2 a = a

-- | Optimizer step 3. Removes redundant packs and zips, combines
-- maps/products/sums, pushes zips right. Returns 'Nothing' when no rule
-- applies.
step3 :: A t u -> Maybe (A t u)
step3 (Comp (Map a) (Comp (Map a2) a3)) = Just $ Map (repetition step3 (a . a2)) . a3
step3 (Comp Zip (Prod (Map a) (Map a2))) = Just $ Map (repetition step3 (a *** a2)) . Zip
step3 (Comp Zip (Prod Count Count)) =
  Just $ Map (arr (\x -> (x, x))) . (Count . arr (uncurry min))
step3 (Comp Zip (Comp Unzip a)) = Just a
step3 (Comp Pack (Comp Unpack a)) = Just a
step3 (Comp PackSum (Comp UnpackSum a)) = Just a
step3 (Comp (Prod a a2) (Prod a3 a4)) =
  Just $ repetition step3 (a . a3) *** repetition step3 (a2 . a4)
step3 (Comp (Sum a a2) (Sum a3 a4)) =
  Just $ repetition step3 (a . a3) +++ repetition step3 (a2 . a4)
step3 (Comp a (Comp a2 a3)) = liftM (a .) (step3 (a2 . a3))
step3 _ = Nothing

-- | Applies a rewrite function repeatedly until it returns 'Nothing'.
repetition :: (a -> Maybe a) -> a -> a
repetition f x = maybe x (repetition f) (f x)

-- | Optimizes an arrow for parallel execution. The arrow can be optimized
-- once, and the result saved for multiple computations.
-- (The exact output of the optimizer is subject to change.)
optimize :: A t u -> A t u
optimize a = repetition step3 $ reassociate (step2 $ step $ reassociate a (arr id)) (arr id)

-- | Evaluates arrows.
eval :: A t u -> t -> u
eval Count n = inject $ unsafePerformIO $ concF n (return $!)
eval Index (ArrC ar _, i) = ar ! i
eval Zip (ArrC ar _, ArrC ar2 _) =
  -- Zip positionally so the result has exactly as many elements as the
  -- shorter input (the previous 'snd . bounds' length dropped the last one).
  inject $ unsafePerformIO $ conc $ newArray $ zipWith pairStrict (elems ar) (elems ar2)
  where
    -- Force both components inside each action, as 'Map' does with 'return $!'.
    pairStrict x y = x `seq` y `seq` return $! (x, y)
eval Unzip ar = (fmap fst ar, fmap snd ar)
eval Concat ar0 = ArrC ar [Node (i + j) ls3 | Node i ls2 <- ls, Node j ls3 <- ls2]
  where
    -- NOTE(review): when each sub-array forest ends with a terminating node
    -- (as 'inject' forests do), the joined forest repeats the boundary offset
    -- between consecutive sub-arrays; confirm this is intended if the result
    -- can reach @Unpack@.
    ArrC ar ls = eval Pack ar0
eval (Map a) (ArrC ar ls) = ArrC (unsafePerformIO $ conc $ fmap ((return $!) . eval a) ar) ls
eval Pack (ArrC ar _) =
  ArrC (newArray (concatMap (elems . project) inner)) (zipWith Node offsets forests)
  where
    inner = elems ar
    -- Start offset of every sub-array, followed by the total length.
    offsets = scanl (\off (ArrC sub _) -> off + rangeSize (bounds sub)) 0 inner
    -- Each sub-array's own forest; the final (total-length) node has none.
    forests = map (\(ArrC _ f) -> f) inner ++ [[]]
eval Unpack (ArrC ar ls) = inject $ newArray $ map slice (pairUp ls)
  where
    -- Consecutive offsets delimit one sub-array; its forest rides along.
    slice (Node i sub, Node j _) = ArrC (ixmap (0, j - i - 1) (+ i) ar) sub
eval PackSum (Left x) = inject (newArray [Left x])
eval PackSum (Right ar) = fmap Right ar
eval UnpackSum ar =
  -- 'PackSum' only produces a singleton @[Left x]@ or an all-'Right' array,
  -- so the first element decides the side. An empty array can only come
  -- from @Right@ (previously this indexed element 0 and crashed).
  case elems (project ar) of
    Left x : _ -> Left x
    _ -> Right (fmap fromRightElem ar)
  where
    fromRightElem = either (const (error "DataParallel.eval: UnpackSum: unexpected Left element")) id
eval (Comp a a2) x = eval a $ eval a2 x
eval (Arr f) x = f x
eval (Prod a a2) (x, y) = b `seq` c `seq` (b, c)
  where
    b = eval a x
    c = eval a2 y
eval (Sum a a2) ei = either (Left . eval a) (Right . eval a2) ei