module Control.CUtils.DataParallel (Equal(Equal),
ArrC, newArray, inject, project,
Structural, A, unA, mapA', liftA, countA, countA', splitOff, assoc, indexA, zipA, unzipA, concatA, dupA, fstA, sndA, eval,
nQueens, sorting, permute, dotProduct, transpose') where
import qualified Data.Sequence as S
import Data.Array
import Data.List
import Data.Monoid (Any(Any))
import Data.Foldable (toList)
import Control.Parallel
import Control.Parallel.Strategies
import Control.Category
import Control.Arrow
import Control.Monad.Writer (Writer, tell, runWriter)
import Control.Monad.Identity
import Control.Monad
import Control.CUtils.Conc
import Control.CUtils.StrictArrow
import Prelude hiding (id, (.))
-- | A rose tree: a strict label together with a strict sequence of child trees.
data Tree t = Node !t !(S.Seq (Tree t))
-- | Maps a function over every label of a 'Tree'.
instance Functor Tree where
	-- The children are rebuilt with 'fastConcat' (one singleton sequence per
	-- mapped child); the mapped children are then evaluated with
	-- 'parList rseq' before the new node is returned.
	fmap f (Node label children) = sparked `pseq` Node (f label) mapped
		where
			mapped = fastConcat (\child -> S.singleton (fmap f child)) children
			sparked = toList mapped `using` parList rseq
-- | An 'Int'-indexed array bundled with a sequence of 'Int'-labelled 'Tree's.
-- NOTE(review): the trees appear to hold boundary offsets into the array
-- (see 'inject', 'packImpl' and 'unpackImpl') -- confirm.
data ArrC t = ArrC !(Array Int t) !(S.Seq (Tree Int)) deriving Functor
-- | Build a zero-based array holding the elements of a list, in order.
-- An empty list gives an empty array with bounds @(0, -1)@.
newArray :: [t] -> Array Int t
newArray ls = listArray (0, length ls - 1) ls
-- | Wrap an array as an 'ArrC', re-basing its indices so that they start at 0.
-- The attached marks are two leaf nodes labelled 0 and the element count.
inject :: Array Int t -> ArrC t
inject ar = ArrC rebased marks
	where
		(lo, hi) = bounds ar
		extent = hi - lo
		-- 'ixmap' takes a function from NEW indices to OLD indices, so the
		-- original lower bound has to be added back, not subtracted.
		rebased = ixmap (0, extent) (+ lo) ar
		marks = S.fromList [Node 0 S.empty, Node (extent + 1) S.empty]
-- | Extract the underlying array from an 'ArrC', discarding its marks.
project :: ArrC t -> Array Int t
project (ArrC elements _) = elements
-- | Placeholder 'Show' instance for functions: every function is rendered as
-- the fixed text @<FUNCTION>@, whatever the precedence.
instance Show (t -> u) where
	showsPrec _ _ rest = "<FUNCTION>" ++ rest
-- | Syntax tree of a program from @t@ to @u@ over an underlying arrow type @a@.
-- Terms are built and rewritten through 'A' and interpreted by 'eval0'.
data Structural a t u where
	-- Apply a program to every element of an array.
	Map :: Structural a t u -> Structural a (ArrC t) (ArrC u)
	-- Sequential composition: the second argument runs first.
	Comp :: Structural a u v -> Structural a t u -> Structural a t v
	-- The identity program.
	Id :: Structural a t t
	-- Run two programs on the two components of a pair.
	Product :: Structural a t u -> Structural a v w -> Structural a (t, v) (u, w)
	-- Embed an arrow of the underlying arrow type.
	Lift :: a t u -> Structural a t u
	-- Build an array with one element per index below the product of the list;
	-- each element pairs the input value with the successive remainders of its
	-- index by the list entries (see 'eval0').
	Count :: Structural a (t, [Int]) (ArrC (t, [Int]))
	-- Look up one element of an array.
	Index :: Structural a (ArrC t, Int) t
	-- NOTE(review): 'eval0' leaves this constructor undefined; its intended
	-- meaning is not visible in this file.
	Split :: Structural a (ArrC t, Array Int Int) (ArrC t)
	
	-- Replace the marks of an array by the concatenated children of those marks.
	ClearMarks :: Structural a (ArrC t) (ArrC t)
	-- Turn an 'Either' into a pair of arrays: a singleton on the side that is
	-- present and an empty array on the other side.
	Separate :: Structural a (Either t u) (ArrC t, ArrC u)
	-- Turn a pair of arrays back into an 'Either' (see 'eval0' for the choice).
	Combine :: Structural a (ArrC t, ArrC u) (Either t u)
	-- Flatten an array of arrays into a single array (see 'packImpl').
	Pack :: Structural a (ArrC (ArrC t)) (ArrC t)
	-- Cut an array into an array of sub-arrays at its marks (see 'unpackImpl').
	Unpack :: Structural a (ArrC t) (ArrC (ArrC t))
	-- Duplicate a value into a pair.
	Dup :: Structural a t (t, t)
	-- First component of a pair.
	Fst :: Structural a (t, u) t
	-- Second component of a pair.
	Snd :: Structural a (t, u) u
-- | A program from @t@ to @u@, represented as a function that extends any
-- upstream term ending in @t@ into a term ending in @u@. Because the upstream
-- term is passed in, a combinator can inspect it and rewrite instead of
-- simply appending a node.
data A a t u = A (forall v. Structural a v t -> Structural a v u)
-- | First element of a sequence.
-- The sequence must be non-empty; an empty sequence raises an 'error' with a
-- descriptive message instead of an anonymous pattern-match failure.
sHead :: S.Seq t -> t
sHead sq = case S.viewl sq of
	x S.:< _ -> x
	S.EmptyL -> error "sHead: empty sequence"
-- | Everything but the first element of a sequence; the empty sequence is
-- returned unchanged.
sTail :: S.Seq t -> S.Seq t
sTail = S.drop 1
-- | Last element of a sequence.
-- The sequence must be non-empty; an empty sequence raises an 'error' with a
-- descriptive message instead of an anonymous pattern-match failure.
sLast :: S.Seq t -> t
sLast sq = case S.viewr sq of
	_ S.:> x -> x
	S.EmptyR -> error "sLast: empty sequence"
-- | Slice of a sequence: the elements at positions @n1@ (inclusive) up to
-- @n2@ (exclusive).
fromTo :: Int -> Int -> S.Seq t -> S.Seq t
fromTo n1 n2 sq = S.drop n1 (S.take n2 sq)
-- | Pair every element of a sequence with the element that follows it.
-- The result is one element shorter than the input (empty for an empty input).
pairUp :: S.Seq t -> S.Seq (t, t)
pairUp sq = sq `S.zip` successors
	where successors = sTail sq
-- | Map every element of a sequence to a sequence and concatenate the results.
-- The input is split in half recursively; the left half is sparked with 'par'
-- while the right half is evaluated, and the two results are then appended.
fastConcat :: (t -> S.Seq u) -> S.Seq t -> S.Seq u
fastConcat f sq
	| n == 0 = S.empty
	| n == 1 = f (sHead sq)
	| otherwise = (leftPart `par` rightPart) `pseq` (leftPart S.>< rightPart)
	where
		n = S.length sq
		(front, back) = S.splitAt (n `div` 2) sq
		leftPart = fastConcat f front
		rightPart = fastConcat f back
-- | Evidence that two types are the same: the only constructor has type
-- @Equal t t@, so matching on it tells the type checker that @t@ and @u@ coincide.
data Equal t u where
	Equal :: Equal t t
-- | Compose a term with an optional upstream term while re-nesting
-- compositions to the right and dropping a trailing 'Id'.
-- The second argument is either @Left Equal@ (nothing upstream) or
-- @Right term@ (the upstream term to put after the first argument).
reassociate :: (Category a) => Structural a u v -> Either (Equal t u) (Structural a t u) -> Structural a t v
reassociate (Comp outer Id) rest = reassociate outer rest
reassociate (Comp outer inner) rest = reassociate outer (Right (reassociate inner rest))
reassociate single rest = case rest of
	Left Equal -> single
	Right upstream -> single . upstream
-- | Turn an 'A' program into a stand-alone 'Structural' term by applying its
-- rewriter to the identity term.
unA :: (Category a) => A a t u -> Structural a t u
unA (A rewrite) = rewrite id
-- | Apply the rewriter held by an 'A' program to a given upstream term.
unA' :: A a u v -> Structural a t u -> Structural a t v
unA' (A rewrite) upstream = rewrite upstream
-- | Map an 'A' program over every element of an array: the program is first
-- turned into a term with 'unA' and then handed to 'mapA'.
mapA' :: (ArrowChoice a) => A a t u -> A a (ArrC t) (ArrC u)
mapA' program = mapA (unA program)
-- | Embed an arrow of the underlying arrow type as an 'A' program.
-- When the upstream term already ends in a 'Lift', the two arrows are
-- composed into a single 'Lift' node; otherwise a new 'Lift' node is added.
liftA :: (Category a) => a t u -> A a t u
liftA arrow = A (\upstream -> case upstream of
	Comp (Lift previous) rest -> Comp (Lift (arrow . previous)) rest
	_ -> Comp (Lift arrow) upstream)
-- | Rewriter that puts a 'Pack' on top of an upstream term, simplifying
-- where the head of the upstream term allows it. The alternatives are tried
-- in order.
pack :: (Category a) => A a (ArrC (ArrC t)) (ArrC t)
pack = A (\a -> case a of
	-- Upstream maps (Map a . a2) over the outer array: pack after mapping
	-- only (Map a2), then run a on the packed elements.
	Comp (Map (Comp (Map a) a2)) a3 -> Map a . unA' pack (Map a2 . a3)
	-- Same as above with nothing in front of the inner Map.
	Comp (Map (Map a)) a2 -> Map a . unA' pack a2
	-- Upstream packs every element first: rewritten into two outer packs.
	Comp (Map (Comp Pack a)) a2 -> unA' pack (unA' pack (Map a . a2))
	Comp (Map Pack) a2 -> unA' pack (unA' pack a2)
	-- A Pack directly after an Unpack cancels out.
	Comp Unpack a2 -> a2
	-- No rule applies: emit a literal Pack node.
	_ -> Pack . a)
-- | Predicate used by 'mapA' to decide whether a mapped 'Map' is handled by
-- packing, mapping and unpacking. It is 'False' for 'Id', 'Unpack', 'Pack',
-- 'Separate' and 'Combine', the disjunction of both parts for a 'Comp', and
-- 'True' for every other constructor.
flatten :: Structural a t u -> Bool
flatten term = case term of
	Comp outer inner -> flatten outer || flatten inner
	Id -> False
	Unpack -> False
	Pack -> False
	Separate -> False
	Combine -> False
	_ -> True
-- | Build the rewriter that maps a 'Structural' program over every element of
-- an array. The equations are tried top to bottom; each returned rewriter
-- inspects the head of the upstream term and fuses with it where a rule applies.
mapA :: (ArrowChoice a) => Structural a t u -> A a (ArrC t) (ArrC u)
-- Mapping a 'Map' whose body satisfies 'flatten': work on the packed array
-- and restore the nesting with 'Unpack' afterwards.
mapA (Map a) | flatten a = A (\a2 -> case a2 of
	-- Upstream already ends in Unpack: map over its input directly.
	Comp Unpack a3 -> Unpack . unA' (mapA a) a3
	-- Push the map underneath a Split / ClearMarks node.
	Comp Split a3 -> Split . unA' (first (mapA' (mapA a))) a3
	Comp ClearMarks a3 -> ClearMarks . unA' (mapA' (mapA a)) a3
	-- Otherwise pack the upstream result, map over it, and unpack again.
	_ -> Unpack . unA' (mapA a) (unA' pack a2))
-- Mapping a product of two programs.
mapA (Product a a2) = A (\a3 -> case a3 of
	-- Upstream is a Count: apply a to the counted value before counting and
	-- map only a2 over the second components.
	Comp Count a4 -> Comp (Map (Product Id a2)) (unA' (countA . first (A (Comp a))) a4)
	Comp ClearMarks a4 -> ClearMarks . unA' (mapA (Product a a2)) a4
	-- Upstream maps a product too: fuse the two products componentwise.
	Comp (Map (Product a4 a5)) a6 -> unA' (mapA (Product (a . a4) (a2 . a5))) a6
	Comp (Map (Comp (Product a4 a5) a6)) a7 -> unA' (mapA (Product (a . a4) (a2 . a5) . a6)) a7
	_ -> Map (Product a a2) . a3)
-- Mapping a composition is the composition of the two maps.
mapA (Comp a a2) = mapA a . mapA a2
-- Mapping the identity is the identity.
mapA Id = id
-- Mapping 'Unpack' over every element.
mapA Unpack = A (\a -> case a of
	Comp Unpack a -> Unpack . (Unpack . a)
	-- Unpacking every element right after packing every element cancels out.
	Comp (Map (Comp Pack a)) a2 -> Map a . a2
	Comp (Map Pack) a -> a
	Comp ClearMarks a3 -> ClearMarks . unA' (mapA Unpack) a3
	_ -> Map Unpack . a)
-- Every other program.
mapA a = A (\a2 -> case a2 of
	-- Two consecutive maps are fused into a single Map node.
	Comp (Map a2) a3 -> Comp (Map (reassociate a (Right a2))) a3
	Comp ClearMarks a3 -> ClearMarks . unA' (mapA a) a3
	_ -> Comp (Map a) a2)
-- | Strip leading @Comp Id@ wrappers from a term; any other term is returned
-- unchanged.
scrubIds :: Structural a t u -> Structural a t u
scrubIds term = case term of
	Comp Id rest -> scrubIds rest
	_ -> term
-- | Programs compose by composing their rewriters; the result of the inner
-- rewriter is cleaned with 'scrubIds' before the outer one sees it.
instance (Category a) => Category (A a) where
	id = A (\upstream -> upstream)
	A outer . A inner = A (\upstream -> outer (scrubIds (inner upstream)))
-- | Arrow structure on 'A' programs.
instance (ArrowChoice a) => Arrow (A a) where
	-- A pure function is lifted through the underlying arrow's 'arr'.
	arr f = liftA (arr f)
	-- When the upstream term ends in a 'Product', each side is pushed into
	-- the matching component; otherwise a fresh 'Product' node is added.
	A f *** A g = A (\upstream -> case upstream of
		Comp (Product l r) rest -> Comp (Product (f l) (g r)) rest
		_ -> Comp (Product (f id) (g id)) upstream)
	first a = a *** id
	second a = id *** a
	-- Fan-out duplicates the input and then runs both programs side by side.
	a &&& a2 = dupA >>> (a *** a2)
-- | Choice on 'A' programs: an 'Either' is turned into a pair of arrays with
-- 'Separate', both programs are mapped over their side, and the pair is
-- turned back into an 'Either' with 'Combine'.
instance (ArrowChoice a) => ArrowChoice (A a) where
	a +++ a2 = A (\upstream -> case upstream of
		-- Upstream already ends in Combine: reuse its pair of arrays
		-- instead of separating again.
		Comp Combine arrays -> Comp Combine (unA' both arrays)
		_ -> Comp Combine (unA' both (Comp Separate upstream)))
		where
			both = mapA (unA a) *** mapA (unA a2)
	left a = a +++ id
	right a = id +++ a
-- | Renders the shape of a term. Constructors without a case of their own
-- (i.e. 'Lift') are shown as @_@.
instance Show (Structural a t u) where
	showsPrec prec term = case term of
		Map body -> showString "Map " . showParen (prec == 11) (showsPrec 11 body)
		Comp outer inner -> showsPrec 11 outer . showString " . " . showsPrec 11 inner
		Product l r -> showParen (prec >= 3) (showsPrec 3 l . showString " *** " . showsPrec 3 r)
		Count -> showString "Count"
		Index -> showString "Index"
		Split -> showString "Split"
		ClearMarks -> showString "Clr"
		Pack -> showString "Pk"
		Unpack -> showString "Unpk"
		Separate -> showString "Sep"
		Combine -> showString "Comb"
		Dup -> showString "Dup"
		Fst -> showString "Fst"
		Snd -> showString "Snd"
		Id -> showString "Id"
		_ -> showString "_"
-- | Terms form a category purely syntactically: identity is 'Id' and
-- composition builds a 'Comp' node.
instance (Category a) => Category (Structural a) where
	id = Id
	outer . inner = Comp outer inner
-- | Swap the two sides of an 'Either'.
mirror :: Either t u -> Either u t
mirror (Left x) = Right x
mirror (Right y) = Left y
-- | Distribute the second component of a pair over both halves of the first
-- component: @((x, y), z)@ becomes @((x, z), (y, z))@.
splitOff :: (ArrowChoice a) => A a ((t1, t2), u) ((t1, u), (t2, u))
splitOff = dupA >>> (first fstA *** first sndA)
-- | Re-associate a nested pair to the right: @((x, y), z)@ becomes
-- @(x, (y, z))@.
assoc :: (ArrowChoice a) => A a ((t, u), v) (t, (u, v))
assoc = (fstA >>> fstA) &&& ((fstA >>> sndA) &&& sndA)
-- | Look up one element of an array. The rewriter avoids building the array
-- when the upstream term shows how its elements are produced.
indexA :: (ArrowChoice a) => A a (ArrC t, Int) t
indexA = A (\a -> case a of
	-- Indexing a mapped array: index the source array first, then apply the
	-- mapped program to that single element.
	Comp (Product (Map a) a2) a3 -> a . unA' indexA (Product Id a2 . a3)
	
	-- Indexing the result of a Count: build the requested element directly
	-- from the counted value and the remainders of the index.
	Comp (Product Count a) a2 -> unA' (fstA . fstA &&& arr (\(ns, i) -> snd (mapAccumL divMod i ns)) . (sndA . fstA &&& sndA)) (Product Id a . a2)
	-- No rule applies: emit a literal Index node.
	_ -> Index . a)
-- | Zip two arrays element by element. The result has as many elements as
-- the shorter input; element @i@ pairs the @i@-th elements of both inputs.
zipA :: (ArrowChoice a) => A a (ArrC t, ArrC u) (ArrC (t, u))
zipA = id &&& arr commonLength
	>>> countA'
	>>> mapA' (splitOff >>> indexA *** indexA)
	where
		-- Element count of the shorter of the two arrays.
		commonLength (ar, ar2) = (extent ar `min` extent ar2) + 1
		-- Upper bound minus lower bound of the underlying array.
		extent xs = uncurry subtract (bounds (project xs))
-- | Split an array of pairs into the array of first components and the array
-- of second components.
unzipA :: (ArrowChoice a) => A a (ArrC (t, u)) (ArrC t, ArrC u)
unzipA = dupA >>> (mapA' fstA *** mapA' sndA)
-- | Concatenate an array of arrays: the input is packed with 'pack' and the
-- marks are then cleared. If the packed term ends in a 'Split', the first
-- component of the split's input is used instead.
concatA :: (Category a) => A a (ArrC (ArrC t)) (ArrC t)
concatA = pack >>> A (\upstream -> case upstream of
	Comp Split source -> unA' fstA source
	_ -> Comp ClearMarks upstream)
-- | Evaluate both components of a pair (first, then second) to weak head
-- normal form before returning the pair.
forcePair :: (t, u) -> (t, u)
forcePair pair@(l, r) = l `seq` (r `seq` pair)
-- | Program form of 'Count': always puts a 'Count' node on top of the
-- upstream term.
countA :: (ArrowChoice a) => A a (t, [Int]) (ArrC (t, [Int]))
countA = A (\upstream -> Comp Count upstream)
-- | One-dimensional variant of 'countA': the single bound is wrapped in a
-- one-element list for 'countA', and each resulting one-element index list
-- is unwrapped again with 'head'.
countA' :: (ArrowChoice a) => A a (t, Int) (ArrC (t, Int))
countA' = second (arr wrap) >>> countA >>> mapA' (second (arr head))
	where
		wrap n = [n]
-- | Program form of 'Dup': always puts a 'Dup' node on top of the upstream
-- term.
dupA :: (Category a) => A a t (t, t)
dupA = A (\upstream -> Comp Dup upstream)
-- | First projection of a pair, simplifying against the head of the upstream
-- term. The alternatives are tried in order.
fstA :: (Category a) => A a (t, u) t
fstA = A (\a -> case a of
	-- Projecting out of a duplicated value gives back the value itself.
	Comp Dup a -> a
	-- The first component is untouched by the product: keep the term as it
	-- is and put Fst on top.
	Comp (Product Id a) a2 -> Fst . (Product Id a . a2)
	-- The second component is untouched: drop the product and apply a after
	-- projecting.
	Comp (Product a Id) a2 -> a . unA' fstA a2
	-- General product: move a after the projection, leaving a2 in place.
	Comp (Product a a2) a3 -> a . unA' fstA (Product Id a2 . a3) 
	-- No rule applies: emit a literal Fst node.
	_ -> Fst . a)
-- | Second projection of a pair, simplifying against the head of the upstream
-- term. The alternatives are tried in order.
sndA :: (Category a) => A a (t, u) u
sndA = A (\a -> case a of
	-- Projecting out of a duplicated value gives back the value itself.
	Comp Dup a -> a
	-- The second component is untouched by the product: keep the term as it
	-- is and put Snd on top.
	Comp (Product a Id) a2 -> Snd . (Product a Id . a2)
	-- The first component is untouched: drop the product and apply a after
	-- projecting.
	Comp (Product Id a) a2 -> a . unA' sndA a2
	-- General product: move a2 after the projection, leaving a in place.
	Comp (Product a a2) a3 -> a2 . unA' sndA (Product a Id . a3)
	-- No rule applies: emit a literal Snd node.
	_ -> Snd . a)
-- | Binary search in an ascending sequence.
-- Returns the position of the last element that is not greater than @x@.
-- When every element is greater than @x@, or the sequence is empty, the
-- result is 0.
binarySearch :: (Ord t) => t -> S.Seq t -> Int
binarySearch x sq = recurse 0 (S.length sq) sq where
	-- off: position of the current part within the whole sequence;
	-- sz: number of elements in the current part.
	recurse off sz part
		| sz <= 1 = off
		| x < pivot = recurse off half lower
		-- The upper half holds the remaining (sz - half) elements.
		| otherwise = recurse (off + half) (sz - half) upper
		where
			half = sz `div` 2
			(lower, upper) = S.splitAt half part
			-- Only forced when sz >= 2, in which case upper is non-empty.
			pivot S.:< _ = S.viewl upper
-- | Flatten an array of arrays into one array (implementation of 'Pack').
-- Element @i@ of the result is found by locating, with 'binarySearch', the
-- sub-array whose start offset is the last one not greater than @i@ and
-- indexing into it at @i@ minus that offset. The marks of the result hold
-- one node per sub-array, labelled with its start offset, plus a final node
-- labelled with the total element count.
-- NOTE(review): 'arr_concF' comes from Control.CUtils.Conc; it is presumably
-- called with @(input, size)@ and builds one element per index -- confirm.
packImpl (ArrC ar fr) = ArrC
	(arr_concF (\(_, i) -> let
		j = binarySearch i fr''
		i2 = S.index fr'' j
		ArrC ar' _ = ar ! j in
		-- Position of element i inside sub-array j.
		ar' ! (i - i2))
		((), sz))
	fr'
	where
	-- Running start offsets of the sub-arrays; the trailing empty ArrC adds
	-- a final node whose label is the total size. Each sub-array's own marks
	-- are shifted by its start offset and kept as children.
	fr' = S.fromList $ snd $ mapAccumL (\i (ArrC ar fr) -> let j = i + rangeSize (bounds ar) in (j, Node i (fastConcat ((return $!) . fmap (+i)) fr))) 0 $ elems ar ++ [ArrC (newArray []) S.empty]
	-- Just the start offsets.
	fr'' = fmap (\(Node i _) -> i) fr'
	-- Last offset, i.e. the total number of elements.
	_ S.:> sz = S.viewr fr''
-- | Cut an array into sub-arrays at its marks (implementation of 'Unpack').
-- Every pair of adjacent marks labelled @j@ and @k@ yields one zero-based
-- sub-array holding elements @j@ up to @k - 1@; the children of the first
-- mark, shifted down by @j@, become the marks of that sub-array.
unpackImpl (ArrC ar fr) = fastConcat
	(\(Node j fr2, Node k _) -> return $! ArrC (ixmap (0, k - j - 1) (+j) ar) (fastConcat ((return $!) . fmap (subtract j)) fr2))
	(pairUp fr)
-- | Interpret a 'Structural' term as an arrow of the underlying arrow type.
-- The implicit parameters @?seq@ and @?pool@ are required by the signature
-- but not referenced directly in this function.
-- NOTE(review): they are presumably consumed by 'arr_concF' / 'force' from
-- the Control.CUtils modules -- confirm.
eval0 :: (Concurrent a, Strict a, ArrowChoice a, ?seq :: Bool, ?pool :: BoxedThreadPool) => Structural a t u -> a t u
-- One element per index below the product of the list; each element pairs
-- the input value with the successive remainders of its index.
eval0 Count = id &&& arr(snd>>>product) >>> arr_concF (arr (\((x, ns), i) -> (x, snd (mapAccumL divMod i ns)))) >>> arr inject
eval0 Index = arr (\(ArrC ar _, i) -> ar ! i)
-- Replace the marks by the concatenation of their children.
eval0 ClearMarks =
	arr (\(ArrC ar fr) ->
		ArrC ar (fastConcat id (fmap (\(Node _ fr) -> fr) fr)))
-- Run the body on every element; the marks are carried over unchanged.
eval0 (Map a) = (arr (\(ArrC ar _) -> (ar, uncurry subtract (bounds ar) + 1)) >>> arr_concF (arr (uncurry (!)) >>> eval0 a)) &&& arr (\(ArrC _ fr) -> fr) >>> arr (uncurry ArrC)
-- Split has no evaluation rule: evaluating it fails.
eval0 Split = undefined
eval0 Pack = arr packImpl
eval0 Unpack = arr (inject . newArray . toList . unpackImpl)
-- Left x becomes (singleton x, empty); Right y becomes (empty, singleton y).
eval0 Separate = arr (\ei -> ((,) $! either (\x -> inject $ newArray [x]) (\_ -> inject $ newArray []) ei) $! either (\_ -> inject $ newArray []) (\x -> inject $ newArray [x]) ei)
-- Left when the first array's upper and lower bounds coincide (a single
-- element), otherwise Right of the second array's first element.
eval0 Combine = arr (\(ar, ar2) -> let
	a1 = project ar
	a2 = project ar2 in
	if uncurry subtract (bounds (project ar)) == 0 then Left $! a1 ! 0 else Right $! a2 ! 0)
eval0 (Comp a a2) = force (eval0 a) . eval0 a2
eval0 Id = id
eval0 (Lift a) = a
-- First component, then second; both components of the pair are forced
-- after each step.
eval0 (Product a a2) = arr forcePair . force (second (eval0 a2)) . arr forcePair . first (eval0 a)
eval0 Dup = arr (\x -> forcePair (x, x))
eval0 Fst = arr fst
eval0 Snd = arr snd
-- | Interpret a 'Structural' term with the implicit parameter @?seq@ bound to
-- 'True'. The @?pool@ parameter required by 'eval0' is left to the caller.
eval term = eval0 term
	where ?seq = True
-- | Application of a program carried as a value: the carried program is
-- turned into a term, evaluated into the underlying arrow (with @?pool@ bound
-- to @BoxedThreadPool NoPool@) and applied with the underlying arrow's 'app'.
instance (Concurrent a, Strict a, ArrowChoice a, ArrowApply a) => ArrowApply (A a) where
	app = liftA app <<< first (arr (eval . unA)) where
		?pool = BoxedThreadPool NoPool
-- | Does a queen in column @n@ conflict with any of the queens in
-- @positions@? The queen at list position @d@ (counting from 1) is taken to
-- be @d@ rows away, so it conflicts when its column equals @n@ (same column)
-- or differs from @n@ by exactly @d@ (same diagonal).
checkThreats :: (Eq a, Num a, Enum a) => a -> [a] -> Bool
checkThreats n positions = n `elem` positions
	|| n `elem` zipWith (-) positions [1..]
	|| n `elem` zipWith (+) positions [1..]
-- | Does any queen in the list conflict with one of the queens after it?
-- Every non-empty suffix is checked with 'checkThreats' (head against tail).
checkThreats2 positions = any threatened (tails positions)
	where
		threatened (queen : others) = checkThreats queen others
		threatened [] = False
-- | Enumerate all index lists produced by 'countA' for the given bounds and
-- keep those for which 'checkThreats2' is 'False'. Each candidate becomes an
-- empty or one-element array, and the arrays are concatenated.
nQueensImpl :: A (->) ((), [Int]) (ArrC [Int])
nQueensImpl = countA >>> mapA' (arr keepSafe) >>> concatA
	where
		keepSafe (_, soln)
			| checkThreats2 soln = inject (newArray [])
			| otherwise = inject (newArray [soln])
-- | Program that computes the placements kept by 'nQueensImpl' for @n@
-- queens, each queen choosing among @n@ columns.
nQueens :: Int -> A (->) () (ArrC [Int])
nQueens size = arr seed >>> nQueensImpl
	where seed () = ((), replicate size size)
-- | Sort an array. For a positive @depth@ the array is partitioned around
-- its first element and both parts are sorted by the program for
-- @depth - 1@; at depth 0 or below the elements are sorted with 'sort'.
-- Arrays whose upper and lower bounds differ by at most 0 pass through
-- unchanged.
sorting :: (Ord t) => Int -> A (->) (ArrC t) (ArrC t)
sorting depth
	| depth <= 0 = arr (inject . newArray . sort . elems . project)
	| otherwise = arr classify >>> (id ||| (arr divide >>> first (recur *** recur) >>> arr merge))
	where
		recur = sorting (pred depth)
		-- Left: nothing to sort; Right: needs partitioning.
		classify x
			| uncurry subtract (bounds (project x)) <= 0 = Left x
			| otherwise = Right x
		-- Split around the first element: smaller elements before, the
		-- rest after.
		divide ar =
			let
				pivot : rest = elems (project ar)
				(bef, aft) = partition (< pivot) rest
			in ((inject (newArray bef), inject (newArray aft)), pivot)
		-- Put the sorted parts back around the pivot.
		merge ((bef, aft), pivot) = inject (newArray (elems (project bef) ++ pivot : elems (project aft)))
-- | Index an array of 'Int's by its own contents: element @i@ of the result
-- is the input element at position @input[i]@.
permute :: A (->) (ArrC Int) (ArrC Int)
permute = arr withLength >>> countA >>> mapA' (second (arr head) >>> indexA)
	where
		-- Pair the array with its element count as the single bound.
		withLength ar = (ar, [uncurry subtract (bounds (project ar)) + 1])
-- | Dot product of two arrays: the arrays are zipped, the pairs multiplied,
-- and the products summed.
dotProduct :: (Num t) => A (->) (ArrC t, ArrC t) t
dotProduct = proc (xs, ys) -> do
	paired <- zipA -< (xs, ys)
	products <- mapA' (arr (uncurry (*))) -< paired
	returnA -< sum (elems (project products))
-- | Transpose an array of arrays: element @jj@ of result row @ii@ is element
-- @ii@ of input row @jj@. The number of result rows is taken from the length
-- of the first input row.
transpose' :: A (->) (ArrC (ArrC t)) (ArrC (ArrC t))
transpose' = proc m -> do
	firstrow <- indexA -< (m, 0)
	-- One entry per element of the first row.
	rows <- countA -< (m, [uncurry subtract (bounds (project firstrow)) + 1])
	-- For every such entry, one item per input row, tagged with both indices.
	rowcols <- mapA' (proc (m, [ii]) -> do
		v <- countA -< ((m, ii), [uncurry subtract (bounds (project m)) + 1])
		mapA' (proc ((m, ii), [jj]) -> returnA -< (m, (ii, jj))) -< v) -< rows
	-- Look up input row jj, then its element ii.
	mapA' (mapA' (proc (m, (ii, jj)) -> do
		v <- indexA -< (m, jj)
		indexA -< (v, ii)))
		-< rowcols