{-# LANGUAGE ImplicitParams #-}
-- | An implementation of nested data parallelism
module Control.CUtils.DataParallel where

import Data.Array hiding (index)
import Data.Tree
import Control.CUtils.Conc
import Control.Monad
import System.IO.Unsafe
import Prelude hiding (zip, concat, and)
import qualified Prelude as P

-- | The array interface: a flat element array paired with a segment
--   descriptor. Both fields are strict.
--
--   Within this module the descriptor is a list of nodes labelled with
--   element offsets: 'inject' stores @[Node 0 [], Node size []]@, and
--   '__unpack' reads each pair of consecutive nodes as the start and end
--   offset of one segment.
data ArrC t
  = ArrC
      !(Array Int t) -- element storage
      !(Forest Int) -- segment descriptor (offset-labelled nodes)

-- | Inject a basic array into the 'ArrC' type.
--
--   The descriptor consists of two leaf nodes: offset 0 and the element
--   count, computed from the bounds as @upper - lower + 1@.
inject ar = ArrC ar [Node 0 [], Node size []]
  where
    (lo, hi) = bounds ar
    size = hi - lo + 1

-- | Get the basic array back out of an 'ArrC', discarding the segment
--   descriptor.
project arrC = case arrC of
  ArrC xs _ -> xs

-- | Convenience for making an array from a list. The result is indexed from
--   0 up to @length xs - 1@, so an empty list gives the bounds @(0, -1)@.
newArray :: [a] -> Array Int a
newArray xs = listArray (0, lastIx) xs
  where
    lastIx = length xs - 1

-- | Swap the two sides of an 'Either'.
mirror :: Either a b -> Either b a
mirror (Left a) = Right a
mirror (Right b) = Left b

-- | Pair every element with its successor:
--   @pairUp [a, b, c] == [(a, b), (b, c)]@. Lists with fewer than two
--   elements give the empty list.
pairUp :: [a] -> [(a, a)]
pairUp xs = P.zip xs (drop 1 xs)

{-# INLINE [0] mp #-}
-- | Map a function over every element of an 'ArrC', leaving the segment
--   descriptor untouched.
--
--   Programs involving these array operations are optimized by the RULES
--   pragma in this module when GHC's -O option is set. Use +RTS -N to get
--   parallelism.
--
--   Each application of @f@ is wrapped in an IO action that forces its result
--   to weak head normal form; the array of actions is then run by 'conc'.
--   NOTE(review): 'unsafePerformIO' relies on 'conc' having no observable
--   effect beyond evaluating the elements -- confirm against
--   Control.CUtils.Conc, together with the meaning of the @?seq@ parameter.
mp f (ArrC ar ls) = ArrC mapped ls
  where
    mapped = unsafePerformIO $
      let ?seq = True
      in conc (fmap (\x -> return $! f x) ar)

{-# INLINE [0] count #-}
-- | Build the array that 'concF' produces for size @n@ when every position
--   simply returns its own (forced) index, wrapped with 'inject'.
--   NOTE(review): presumably this is the array @[0 .. n - 1]@ -- confirm
--   against 'concF' in Control.CUtils.Conc.
count n = inject ixs
  where
    ixs = unsafePerformIO $
      let ?seq = True
      in concF n (\i -> return $! i)

{-# INLINE [0] index #-}
-- | Look up the element stored at position @i@ of the underlying array.
--   The segment descriptor is ignored; an out-of-range @i@ fails exactly as
--   '!' does.
index arrC i = case arrC of
  ArrC xs _ -> xs ! i

{-# INLINE [0] zip #-}
-- | Pair up the elements of two arrays position by position. The result is
--   as long as the shorter input and is re-wrapped with 'inject', so the
--   segment descriptors of both inputs are discarded.
--
--   'concF' is given an element count (compare 'count' and '__unpack'), so
--   the length is computed with 'rangeSize'. Passing the upper bound instead,
--   as this function previously did, dropped the last pair.
--
--   Elements are read at indices @0 .. len - 1@, i.e. both arrays are assumed
--   to be 0-based, as 'index' and '__unpack' also assume.
--   NOTE(review): 'unsafePerformIO' relies on 'concF' having no observable
--   effect beyond evaluating the elements -- confirm against
--   Control.CUtils.Conc.
zip (ArrC ar _) (ArrC ar2 _) = inject $ unsafePerformIO $
  let ?seq = True
  in concF len pairAt
  where
    len = rangeSize (bounds ar) `min` rangeSize (bounds ar2)
    -- Force both components before the pair itself is returned, so the
    -- element lookups happen inside the action run by 'concF'.
    pairAt i =
      let x = ar ! i
          y = ar2 ! i
      in x `seq` y `seq` (return $! (x, y))

{-# INLINE [0] concat #-}
-- | Remove one level of nesting from an array of arrays.
--
--   The elements are flattened by '__pack'. The descriptor that '__pack'
--   returns is then collapsed by one level: every child node replaces its
--   parent, with the parent's offset added to the child's own offset.
concat ar0 = ArrC flat (concatMap shift outer)
  where
    ArrC flat outer = __pack ar0
    shift (Node base inner) = [ Node (base + off) sub | Node off sub <- inner ]

{-# INLINE [0] fold #-}
-- | Associative fold over the elements of an 'ArrC'.
--
--   The arguments are forwarded to 'assocFold': @f@ is the combining
--   function (wrapped so that each combined value is forced before it is
--   returned), @g@ and @z@ are passed through unchanged, and the array is the
--   one stored in @arr@ (its segment descriptor is ignored).
--   NOTE(review): presumably @g@ maps each element into the folded type and
--   @z@ is the initial value -- confirm against 'assocFold' in
--   Control.CUtils.Conc.
fold f g z arr = unsafePerformIO (assocFold combine g z (project arr))
  where
    combine a b = return $! f a b

-- Control.Arrow substitutes, specialised to plain functions. The RULES
-- pragma in this module rewrites calls to them.

{-# INLINE [0] first #-}
-- | Apply a function to the first component of a pair, leaving the second
--   component untouched.
first :: (a -> b) -> (a, c) -> (b, c)
first f pair = case pair of
  (x, y) -> (f x, y)

{-# INLINE [0] second #-}
-- | Apply a function to the second component of a pair, leaving the first
--   component untouched.
second :: (b -> c) -> (a, b) -> (a, c)
second f pair = case pair of
  (x, y) -> (x, f y)

{-# INLINE [0] left #-}
-- | Apply a function to a 'Left' value; 'Right' values pass through
--   unchanged.
left :: (a -> b) -> Either a c -> Either b c
left f = go
  where
    go (Left a) = Left (f a)
    go (Right c) = Right c

{-# INLINE [0] right #-}
-- | Apply a function to a 'Right' value; 'Left' values pass through
--   unchanged.
right :: (b -> c) -> Either a b -> Either a c
right f = go
  where
    go (Left a) = Left a
    go (Right b) = Right (f b)

{-# INLINE [0] and #-}
-- | Feed one input to two functions and pair up their results.
--   This shadows the Prelude's @and@, which the module hides on import.
and :: (a -> b) -> (a -> c) -> a -> (b, c)
and f g x = (viaF, viaG)
  where
    viaF = f x
    viaG = g x

-- Internals: packing and unpacking helpers used by the RULES pragma.

{-# INLINE [0] __pack #-}
-- | Flatten an array of arrays into one array plus a descriptor.
--
--   The elements of all inner arrays are concatenated, in order, into a
--   fresh 0-based array. The descriptor gets one node per inner array,
--   labelled with the running total of the element counts before it and
--   carrying that inner array's own descriptor as children. A final
--   childless node is labelled with the overall element count. The outer
--   array's own descriptor is ignored.
__pack (ArrC outer _) = ArrC flat descr
  where
    inners = elems outer
    flat = newArray (concatMap (elems . project) inners)
    offsets = scanl (\acc (ArrC xs _) -> acc + rangeSize (bounds xs)) 0 inners
    forests = map (\(ArrC _ fs) -> fs) inners ++ [[]]
    descr = zipWith Node offsets forests

{-# INLINE [0] __unpack #-}
-- | Split a packed array back into an array of arrays.
--
--   Every pair of consecutive descriptor nodes @(Node start sub, Node end _)@
--   describes one segment. Its elements are copied into a new array built by
--   @concF (end - start)@, where position @k@ reads element @start + k@ of
--   the flat array. That array is paired with @sub@ as its descriptor. The
--   segments themselves are collected by 'conc' and wrapped with 'inject'.
--   NOTE(review): 'unsafePerformIO' relies on 'conc' and 'concF' having no
--   observable effect beyond evaluating the elements -- confirm against
--   Control.CUtils.Conc.
__unpack (ArrC flat descr) = inject $ unsafePerformIO $
  let ?seq = True
  in conc $ fmap
       (\(Node start sub, Node end _) -> do
          seg <- concF (end - start) (\k -> return $! flat ! (start + k))
          return (ArrC seg sub))
       (newArray (pairUp descr))

{-# INLINE [0] __packProd #-}
-- | Represent a pair as a two-element array: the first component tagged
--   'Left' at index 0, the second tagged 'Right' at index 1.
__packProd (a, b) = inject slots
  where
    slots = newArray [Left a, Right b]

{-# INLINE [0] __unpackProd #-}
-- | Inverse of '__packProd': read a pair back out of a two-element array.
--
--   Index 0 must hold a 'Left' and index 1 a 'Right'. Each component is
--   extracted lazily, so a wrong tag only fails (with a pattern-match error)
--   when that component is demanded.
__unpackProd ar = (firstPart, secondPart)
  where
    slots = project ar
    firstPart = case slots ! 0 of Left a -> a
    secondPart = case slots ! 1 of Right b -> b

{-# INLINE [0] __packSum1 #-}
-- | Push a sum whose 'Right' side is an array into a single array of sums.
--
--   A 'Left' value becomes a one-element array holding that 'Left'. A
--   'Right' array has every element tagged 'Right' and keeps its descriptor.
__packSum1 s = case s of
  Left x -> inject (newArray [Left x])
  Right ar -> mp Right ar

{-# INLINE [0] __unpackSum1 #-}
-- | Inverse of '__packSum1': recover the sum from an array of sums.
--
--   The tag of the element at index 0 decides the result. A 'Left' there is
--   returned as is; otherwise every element is stripped of its 'Right' tag
--   and the array is returned on the 'Right' side.
--
--   An empty array is handled explicitly. '__packSum1' only produces one
--   from @Right emptyArray@ (its 'Left' case always yields one element), so
--   the result is 'Right' of the empty array. Indexing position 0
--   unconditionally, as before, raised an out-of-bounds error in that case.
__unpackSum1 ar
  | rangeSize (bounds (project ar)) == 0 = Right unwrapped
  | otherwise = either Left (\_ -> Right unwrapped) (project ar ! 0)
  where
    -- Only demanded when the array is empty or starts with a 'Right'; the
    -- partial lambda relies on '__packSum1' never mixing the two tags.
    unwrapped = mp (\(Right x) -> x) ar

{-# INLINE [0] __packSum2 #-}
-- | Like '__packSum1', but for sums whose 'Left' side is the array.
--   The sum is mirrored, packed with '__packSum1', and every element of the
--   result is mirrored back.
__packSum2 s = mp mirror packed
  where
    packed = __packSum1 (mirror s)

{-# INLINE [0] __unpackSum2 #-}
-- | Inverse of '__packSum2'. Every element is mirrored, the array is
--   unpacked with '__unpackSum1', and the resulting sum is mirrored back.
__unpackSum2 packed = mirror unpacked
  where
    unpacked = __unpackSum1 (mp mirror packed)

-- Rewrite rules for the array operations in this module.
--
-- Simplifier phases count down (2, 1, 0). A rule marked [k] is active from
-- phase k onwards; a rule marked [~k] is active only before phase k.
--
-- * "pack*" rules ([2]) turn nested maps, and maps under products or sums,
--   into a single flat 'mp' sandwiched between a pack and an unpack step.
-- * "sep*" rules ([~1]) split composite maps apart. They are the exact
--   inverses of the "comb*" rules, so they have to be switched off before
--   the "comb*" rules switch on in phase 1. With both sets active at once
--   (as happened when the "sep*" rules were marked [2]) the simplifier can
--   rewrite an expression back and forth without terminating.
-- * "comb*" and "unpack*" rules ([1]) recombine maps and cancel a pack
--   applied directly to the matching unpack.
-- * The remaining rules ([1]) simplify 'mp' under 'zip', 'index', 'concat'
--   and 'fold'.
{-# RULES

"packMap" [2] forall f x. mp (mp f) x = __unpack (mp f (__pack x))
"packProd" [2] forall f x. first f x = __unpackProd (mp (left f) (__packProd x))
"packProd2" [2] forall f x. second f x = __unpackProd (mp (right f) (__packProd x))
"packSum" [2] forall f x. right (mp f) x = __unpackSum1 (mp (right f) (__packSum1 x))
"packSum2" [2] forall f x. left (mp f) x = __unpackSum2 (mp (left f) (__packSum2 x))

"sepMapComp" [~1] forall f g x. mp (f . g) x = mp f (mp g x)
"sepMapProd" [~1] forall f ar. mp (and f id) ar = zip (mp f ar) ar
"sepMapProd2" [~1] forall f ar. mp (and id f) ar = zip ar (mp f ar)
"sepSum" [~1] forall f g x. left (f . g) x = left f (left g x)
"sepSum2" [~1] forall f g x. right (f . g) x = right f (right g x)

"combMapComp" [1] forall f g x. mp f (mp g x) = mp (f . g) x
"combMapProd" [1] forall f ar. zip (mp f ar) ar = mp (and f id) ar
"combMapProd2" [1] forall f ar. zip ar (mp f ar) = mp (and id f) ar
"combSum" [1] forall f g x. left f (left g x) = left (f . g) x
"combSum2" [1] forall f g x. right f (right g x) = right (f . g) x
"unpackPack" [1] forall x. __pack (__unpack x) = x
"unpackProd" [1] forall x. __packProd (__unpackProd x) = x
"unpackSum" [1] forall x. __packSum1 (__unpackSum1 x) = x
"unpackSum2" [1] forall x. __packSum2 (__unpackSum2 x) = x

"zip" [1] forall f x y. mp (\p -> f (fst p)) (zip x y) = mp f x
"zip2" [1] forall f x y. mp (\p -> f (snd p)) (zip x y) = mp f y
"index" [1] forall f ar i. index (mp f ar) i = f (index ar i)
"concatConcat" [1] forall x. concat (mp concat x) = concat (concat x)
"fold" [1] forall f g h x y. fold f g x (mp h y) = fold f (g . h) x y
  #-}