module Control.CUtils.DataParallel (ArrC, inject, project, newArray, A(Count, Index, Zip, Unzip, Concat, Map, Comp, Arr, Prod, Sum), optimize, eval) where
import Data.Array
import Data.Tree
import Data.Monoid (Any(Any))
import Control.Category
import Control.Arrow
import Control.Monad.Writer (Writer, tell, runWriter)
import Control.Monad
import Control.CUtils.Conc
import System.IO.Unsafe
import Prelude hiding (id, (.))
-- | A flattened, possibly nested, array.
--
-- The first field is the element storage ('inject', 'newArray' and the
-- 'Unpack' case of 'eval0' all build zero-based arrays; other producers are
-- presumably zero-based too -- TODO confirm for 'concF').
--
-- The second field describes segment boundaries. Each node's label is an
-- offset into the storage, and consecutive labels delimit one segment (see
-- the 'Unpack' case of 'eval0'). A node's children describe the nesting
-- inside that segment. A flat array built by 'inject' carries the two
-- boundaries 0 and its length.
data ArrC t = ArrC !(Array Int t) !(Forest Int)
-- | Wrap a plain array as a single flat segment: the forest records the
-- boundaries 0 and @hi - lo + 1@ (the element count for ordinary bounds).
inject :: Array Int t -> ArrC t
inject ar = ArrC ar [Node 0 [], Node (hi - lo + 1) []]
  where
    (lo, hi) = bounds ar
-- | Drop the segment structure and return the underlying storage array.
project :: ArrC t -> Array Int t
project arrc = case arrc of
  ArrC storage _ -> storage
-- | Map over the stored elements; the segment forest is kept unchanged.
instance Functor ArrC where
  fmap g arrc = case arrc of
    ArrC storage segments -> ArrC (fmap g storage) segments
-- | Build a zero-based array holding the list's elements in order.
--
-- The upper bound is @length ls - 1@, so an empty list yields an empty
-- array with bounds @(0, -1)@.
newArray :: [t] -> Array Int t
newArray ls = listArray (0, length ls - 1) ls
-- | Pair every element with its successor: @[a, b, c]@ becomes
-- @[(a, b), (b, c)]@. Lists shorter than two elements give @[]@.
pairUp :: [a] -> [(a, a)]
pairUp (x : rest@(y : _)) = (x, y) : pairUp rest
pairUp _ = []
-- | Placeholder rendering for functions, so that 'Show' can be derived for
-- the 'A' constructors that hold plain functions.
instance Show (t -> u) where
  showsPrec _ _ = showString "<FUNCTION>"
-- | Reified arrow computations over 'ArrC' arrays.
--
-- @A t u@ describes a computation from @t@ to @u@. Expressions are built
-- with the exported constructors and the 'Category' / 'Arrow' /
-- 'ArrowChoice' instances, rewritten by 'optimize', and run by 'eval'.
-- The descriptions below follow the interpreter 'eval0'.
data A t u where
  -- | An array built by @concF n@, each element produced from the value
  -- 'concF' supplies (presumably its index); 'step' treats the valid
  -- positions as @0 .. n - 1@.
  Count :: A Int (ArrC Int)
  -- | The element at the given position.
  Index :: A (ArrC t, Int) t
  -- | Pair elements position-wise; the result length is governed by the
  -- smaller input.
  Zip :: A (ArrC t, ArrC u) (ArrC (t, u))
  -- | Split an array of pairs into two arrays (segment forest preserved).
  Unzip :: A (ArrC (t, u)) (ArrC t, ArrC u)
  -- | Flatten one level of nesting.
  Concat :: A (ArrC (ArrC t)) (ArrC t)
  -- | Apply an arrow to every element; the segment forest is preserved.
  Map :: A t u -> A (ArrC t) (ArrC u)
  -- | Composition: @Comp f g@ runs @g@ first, then @f@.
  Comp :: A u v -> A t u -> A t v
  -- | Lift an ordinary function.
  Arr :: (t -> u) -> A t u
  -- | Apply two arrows to the two halves of a pair ('eval0' forces both
  -- results to WHNF before building the pair).
  Prod :: A t u -> A v w -> A (t, v) (u, w)
  -- | Apply one of two arrows according to the 'Either' tag.
  Sum :: A t u -> A v w -> A (Either t v) (Either u w)
  -- | Identity. Not exported; the 'Category' instance uses @arr id@.
  Id :: A t t
  -- | Internal (not exported): concatenate the inner arrays' storage and
  -- record each inner array's start offset, with its own forest as the
  -- node's children, in the result's forest.
  Pack :: A (ArrC (ArrC t)) (ArrC t)
  -- | Internal: split an array into the segments its forest delimits; the
  -- counterpart of 'Pack' used by 'optimize'.
  Unpack :: A (ArrC t) (ArrC (ArrC t))
  -- | Internal: store a pair as the two-element array @[Left x, Right y]@.
  PackProd :: A (t, u) (ArrC (Either t u))
  -- | Internal: read a pair back out of a 'PackProd' array.
  UnpackProd :: A (ArrC (Either t u)) (t, u)
  -- | Internal: @Left x@ becomes a one-element array; for @Right ar@ every
  -- element of @ar@ is wrapped in 'Right'.
  PackSum1 :: A (Either t (ArrC u)) (ArrC (Either t u))
  -- | Internal: undo 'PackSum1', deciding Left/Right from the first element.
  UnpackSum1 :: A (ArrC (Either t u)) (Either t (ArrC u))
  -- | Internal: 'PackSum1' with the roles of 'Left' and 'Right' swapped.
  PackSum2 :: A (Either (ArrC t) u) (ArrC (Either t u))
  -- | Internal: 'UnpackSum1' with the roles of 'Left' and 'Right' swapped.
  UnpackSum2 :: A (ArrC (Either t u)) (Either (ArrC t) u)
-- | Swap the two alternatives of an 'Either'.
mirror :: Either a b -> Either b a
mirror (Left l) = Right l
mirror (Right r) = Left r
-- Show is derived for inspecting expressions. The 'Arr' constructor holds
-- a plain function, which is rendered through the @Show (t -> u)@
-- instance declared in this module.
deriving instance Show (A t u)
-- | Identity is a lifted identity function; composition is the 'Comp'
-- constructor.
instance Category A where
  id = Arr (\x -> x)
  g . f = Comp g f
-- | 'arr' and '***' map straight onto constructors; 'first' and 'second'
-- pad the untouched component with @arr id@.
instance Arrow A where
  arr = Arr
  f *** g = Prod f g
  first = (*** arr id)
  second = (arr id ***)
-- | '+++' is the 'Sum' constructor; 'left' and 'right' pad the other
-- alternative with @arr id@.
instance ArrowChoice A where
  f +++ g = Sum f g
  left = (+++ arr id)
  right = (arr id +++)
-- | Flatten a composition tree into a right-nested chain, attaching the
-- second argument at the innermost end.
--
-- For example @reassociate ((f . g) . h) k@ becomes @f . (g . (h . k))@.
reassociate :: A u v -> A t u -> A t v
reassociate (Comp outer inner) = \rest -> reassociate outer (reassociate inner rest)
reassociate leaf = \rest -> Comp leaf rest
-- | One pass of algebraic rewrites over a composition chain.
--
-- The rules:
--
-- * split mapped compositions and products into separate 'Map's;
-- * push 'Index' through 'Map' and 'Count';
-- * fuse 'Concat' with nested 'Map' / 'Concat';
-- * distribute composition out of 'Prod' / 'Sum' components.
--
-- Expressions that match no rule are returned unchanged.
step :: A t u -> A t u
-- Map (f . g) . h  ==>  Map f . (Map g . h)
step (Comp (Map (Comp a a2)) a3) = step (Map (step a)) . (Map a2 . a3)
-- Map (f *** g) . h  ==>  Zip . ((Map f *** Map g) . (Unzip . h))
step (Comp (Map (Prod a a2)) a3) = Zip . ((Map a *** Map a2) . (Unzip . a3))
-- Otherwise simplify the mapped function in place.
step (Comp (Map a) a2) = step (Map (step a)) . a2
-- Index . (Map f *** g)  ==>  f . (Index . (id *** g))
step (Comp Index (Prod (Map a) a2)) = step a . (Index . second a2)
-- Element j of Count i is j itself, provided 0 <= j <= i - 1; any other
-- position is reported as a bad index.
step (Comp Index (Prod Count a)) =
  arr (\(i, j) ->
         if inRange (0, i - 1) j
           then j
           else error $ "DataParallel.eval: bad index: " ++ show j)
    . second a
-- Concat . Map (Map f)  ==>  Map f . Concat
step (Comp Concat (Map (Map a))) = step (Map (step a)) . Concat
-- Concat . Map Concat  ==>  Concat . Concat
step (Comp Concat (Map Concat)) = Concat . Concat
-- ((f . g) *** h) . k  ==>  (f *** id) . ((g *** h) . k), and the mirror image.
step (Comp (Prod (Comp a a2) a3) a4) = step (Prod (step a) id) . (Prod a2 a3 . a4)
step (Comp (Prod a (Comp a2 a3)) a4) = step (Prod id (step a2)) . (Prod a a3 . a4)
-- The same distribution for Sum.
step (Comp (Sum (Comp a a2) a3) a4) = step (Sum (step a) id) . (Sum a2 a3 . a4)
step (Comp (Sum a (Comp a2 a3)) a4) = step (Sum id (step a2)) . (Sum a a3 . a4)
-- f . (g . h): rewrite the head pair, then continue down the chain.
step (Comp a (Comp a2 a3)) =
  case step (a . a2) of
    Comp a4 a5 -> a4 . step (a5 . a3)
    -- Every rule above returns a 'Comp' when given one, so this branch is
    -- not expected to fire; it only keeps the match total.
    other -> other . a3
step a = a
-- | One rewrite pass that flattens nested data parallelism.
--
-- The rewrites:
--
-- * 'Map' of 'Map' becomes 'Pack' / 'Map' / 'Unpack';
-- * a pair ('Prod') becomes a mapped 'Sum' over a 'PackProd' array;
-- * a 'Sum' with a mapped side becomes a single mapped 'Sum' between
--   'PackSum1' / 'UnpackSum1'.
--
-- The 'Writer' output is @Any True@ whenever one of these rewrites fired,
-- which tells the caller to run another pass.
step2 :: A t u -> Writer Any (A t u)
step2 (Map (Map a)) = do
  tell (Any True)
  inner <- step2 a
  return (Unpack . Map inner . Pack)
step2 (Prod a a2) = do
  tell (Any True)
  body <- step2 (Map (Sum a a2))
  return (UnpackProd . body . PackProd)
step2 (Sum a (Map a2)) = do
  tell (Any True)
  onLeft <- step2 a
  onRight <- step2 a2
  return (UnpackSum1 . Map (Sum onLeft onRight) . PackSum1)
step2 (Sum (Map a) a2) = do
  tell (Any True)
  onLeft <- step2 a
  onRight <- step2 a2
  return (arr mirror . UnpackSum1 . Map (Sum onRight onLeft) . PackSum1 . arr mirror)
step2 (Sum a a2) = do
  onLeft <- step2 a
  onRight <- step2 a2
  return (onLeft +++ onRight)
step2 (Map a) = fmap Map (step2 a)
step2 (Comp a a2) = do
  outer <- step2 a
  inner <- step2 a2
  return (outer . inner)
step2 a = return a
-- | Peephole fusion rules; returns 'Nothing' when no rule fires at the head
-- of the chain or anywhere further down it.
--
-- NOTE(review): 'optimize' does not call this function; it is only used
-- by its own recursive calls here.
step3 :: A t u -> Maybe (A t u)
-- Two consecutive maps fuse into one.
step3 (Comp (Map f) (Comp (Map g) rest)) =
  Just (Map (repetition step3 (f . g)) . rest)
-- Zip after a pair of maps becomes a map after Zip.
step3 (Comp Zip (Prod (Map f) (Map g))) =
  Just (Map (repetition step3 (f *** g)) . Zip)
-- Zipping two counts is a count of the smaller size, duplicated.
step3 (Comp Zip (Prod Count Count)) =
  Just (Map (arr (\x -> (x, x))) . Count . arr (uncurry min))
-- Pack/unpack style pairs cancel out.
step3 (Comp Zip (Comp Unzip rest)) = Just rest
step3 (Comp Pack (Comp Unpack rest)) = Just rest
step3 (Comp PackProd (Comp UnpackProd rest)) = Just rest
step3 (Comp PackSum1 (Comp UnpackSum1 rest)) = Just rest
step3 (Comp PackSum2 (Comp UnpackSum2 rest)) = Just rest
-- Two consecutive sums fuse branch-wise.
step3 (Comp (Sum f g) (Sum f' g')) =
  Just (repetition step3 (f . f') +++ repetition step3 (g . g'))
-- Otherwise look further down the chain.
step3 (Comp outer (Comp inner rest)) = fmap (Comp outer) (step3 (inner . rest))
step3 _ = Nothing
-- | Apply a partial rewrite repeatedly until it returns 'Nothing', then
-- return the last value produced.
repetition :: (a -> Maybe a) -> a -> a
repetition f x = case f x of
  Nothing -> x
  Just next -> repetition f next
-- | Apply a rewrite repeatedly while it reports a change through its
-- @Any@ log; the result of the first pass that logs no change is returned.
repetition2 :: (a -> Writer Any a) -> a -> a
repetition2 f x
  | changed = repetition2 f next
  | otherwise = next
  where
    (next, Any changed) = runWriter (f x)
-- | Rewrite an expression to a fixpoint.
--
-- The composition chain is first right-associated. Each pass then runs
-- 'step' followed by 'step2' and right-associates the result again.
-- Passes repeat until 'step2' reports that none of its rewrites fired.
optimize :: A t u -> A t u
optimize expr = repetition2 onePass (rightAssoc expr)
  where
    rightAssoc e = reassociate e (arr id)
    onePass e = fmap rightAssoc (step2 (step e))
-- | Interpreter for 'A' expressions.
--
-- Array-building primitives run their per-element work through 'concF' /
-- 'conc' under 'unsafePerformIO'; everything else is interpreted directly.
-- The implicit parameter @?seq@ is required by the signature but is not
-- read by any equation below.
--
-- Arrays are treated as zero-based throughout, matching what 'inject',
-- 'newArray' and the @ixmap (0, ...)@ in the 'Unpack' case produce.
eval0 :: (?seq :: Bool) => A t u -> t -> u
-- Each element is @return $!@ applied to what concF supplies (presumably the
-- index, so Count n is [0 .. n - 1] -- TODO confirm against Control.CUtils.Conc).
eval0 Count n = inject $ unsafePerformIO $ concF n (return $!)
eval0 Index (ArrC ar _, i) = ar ! i
-- Zip truncates to the shorter input, so concF receives the smaller element
-- count. Passing the smaller upper bound instead is one short and would drop
-- the final pair (assuming concF n covers indices 0 .. n - 1, as Count does).
eval0 Zip (ArrC ar _, ArrC ar2 _) =
  inject $ unsafePerformIO $
    concF (rangeSize (bounds ar) `min` rangeSize (bounds ar2)) $ \i ->
      let x = ar ! i
          y = ar2 ! i
      in x `seq` y `seq` return $! (x, y)
eval0 Unzip ar = (fmap fst ar, fmap snd ar)
-- Pack the outer level, then lift each segment's own boundary nodes,
-- offset by that segment's start, into the result's forest.
eval0 Concat ar0 =
  ArrC ar [Node (i + j) ls3 | Node i ls2 <- ls, Node j ls3 <- ls2]
  where
    ArrC ar ls = eval0 Pack ar0
eval0 (Map a) (ArrC ar ls) =
  ArrC (unsafePerformIO $ conc $ fmap ((return $!) . eval0 a) ar) ls
-- Concatenate the inner arrays' storage. The forest gets one node per inner
-- array, labelled with its start offset and carrying that array's own forest
-- as children, followed by a final node at the total length.
eval0 Pack (ArrC ar _) =
  ArrC
    (newArray $ concatMap (elems . project) $ elems ar)
    (zipWith Node
       (scanl (\i (ArrC inner _) -> i + rangeSize (bounds inner)) 0 $ elems ar)
       (map (\(ArrC _ forest) -> forest) (elems ar) ++ [[]]))
-- Each pair of consecutive forest nodes (i, j) delimits one segment of j - i
-- elements, re-based to start at index 0. The left node's children become
-- that segment's forest.
eval0 Unpack (ArrC ar ls) =
  inject $ newArray $
    map (\(Node i segForest, Node j _) -> ArrC (ixmap (0, j - i - 1) (+ i) ar) segForest)
        (pairUp ls)
eval0 PackProd (x, y) = inject $ newArray [Left x, Right y]
eval0 UnpackProd ar = (let Left x = project ar ! 0 in x, let Right x = project ar ! 1 in x)
eval0 PackSum1 (Left x) = inject (newArray [Left x])
eval0 PackSum1 (Right ar) = fmap Right ar
-- A packed Left is a one-element array holding that Left; otherwise every
-- element is a Right. Looking at the first element through 'elems' (rather
-- than indexing 0) keeps the empty array that PackSum1 builds from an empty
-- Right array from crashing; it unpacks to an empty Right array.
eval0 UnpackSum1 ar =
  case elems (project ar) of
    Left x : _ -> Left x
    _ -> Right (fmap (\(Right x) -> x) ar)
eval0 PackSum2 ei = fmap mirror $ eval0 PackSum1 $ mirror ei
eval0 UnpackSum2 ar = mirror $ eval0 UnpackSum1 $ fmap mirror ar
eval0 (Comp a a2) x = eval0 a $ eval0 a2 x
eval0 (Arr f) x = f x
-- Both halves are forced to WHNF before the pair is built.
eval0 (Prod a a2) (x, y) = b `seq` c `seq` (b, c)
  where
    b = eval0 a x
    c = eval0 a2 y
eval0 (Sum a a2) ei = either (Left . eval0 a) (Right . eval0 a2) ei
-- Id is internal and currently never built, but handling it keeps the
-- interpreter total.
eval0 Id x = x
-- | Run an arrow expression on an input, binding the interpreter's @?seq@
-- parameter to 'True'.
eval :: A t u -> t -> u
eval expr input =
  let ?seq = True
  in eval0 expr input