{-# LANGUAGE Safe, GADTs, Rank2Types, ImplicitParams, Arrows #-}
-- | An implementation of nested data parallelism (due to Simon Peyton Jones et al.)
module Control.CUtils.DataParallel (
-- * Flattenable arrays
ArrC, newArray, inject, project,
-- * The arrows and associated operations
Structural, A, unA, mapA', liftA, countA, indexA, zipA, unzipA, concatA, eval,
-- * Examples
nQueens, sorting, permute) where

import Data.Array
import Data.List
import Data.Monoid (Any(Any))
import Control.Category
import Control.Arrow
import Control.Monad.Writer (Writer, tell, runWriter)
import Control.Monad.Identity
import Control.Monad
import Control.CUtils.Conc
import Control.CUtils.StrictArrow
import Prelude hiding (id, (.))

-- | A segment-marker tree. @Node i sub@ marks a segment starting at flat
--   offset @i@ (as read by 'Unpack', which slices from @i@ to the next
--   marker's offset); @sub@ holds the markers for the next nesting level.
data Tree t = Node !t !(Array Int (Tree t))

-- | A flattenable array: a flat element array plus an array of segment
--   markers (its "frame"). Frames built by 'inject' and by 'Pack' end with a
--   sentinel marker whose offset is the total element count.
data ArrC t = ArrC !(Array Int t) !(Array Int (Tree Int))

-- | Build a zero-based array holding the elements of a list, in order.
newArray :: [e] -> Array Int e
newArray xs = listArray (0, pred (length xs)) xs

-- | Convert an ordinary array, with any bounds, into an 'ArrC'.
--   The elements are re-indexed from zero. The result gets a single-segment
--   frame: a start marker at offset 0 and a sentinel marker at the element
--   count.
inject :: Array Int t -> ArrC t
inject ar = ArrC (ixmap (0, hi - lo) (+ lo) ar)
	(newArray [Node 0 (newArray []), Node (hi - lo + 1) (newArray [])])
	where
	-- ixmap's function maps each result index to a source index, so result
	-- index i must read source index (i + lo); this matters whenever lo /= 0.
	(lo, hi) = bounds ar

-- | Extract the flat element array of an 'ArrC', discarding its frame.
project :: ArrC t -> Array Int t
project (ArrC elements _frame) = elements

-- | Mapping over an 'ArrC' transforms every element and leaves the frame untouched.
instance Functor ArrC where
	fmap g (ArrC elements frame) = ArrC (g <$> elements) frame

-- | Placeholder rendering for functions: every function shows as @<FUNCTION>@.
--   NOTE(review): this is an orphan instance and will clash with any other
--   @Show (t -> u)@ instance a client imports.
instance Show (t -> u) where
	showsPrec _ _ = showString "<FUNCTION>"

-- | The primitive program syntax that 'A' builds and 'eval' interprets.
--   @a@ is the base arrow type embedded by 'Lift'.
data Structural a t u where
	-- | Apply a program to every element of an array; the frame is kept.
	Map :: Structural a t u -> Structural a (ArrC t) (ArrC u)
	-- | Sequential composition; the second program runs first.
	Comp :: Structural a u v -> Structural a t u -> Structural a t v
	-- | The identity program.
	Id :: Structural a t t
	-- | Apply two programs to the two components of a pair.
	Product :: Structural a t u -> Structural a v w -> Structural a (t, v) (u, w)
	-- | Embed an arrow of the base arrow type.
	Lift :: a t u -> Structural a t u
	-- | From @(x, n)@, build an array pairing @x@ with each index
	--   (presumably @[0 .. n-1]@, as in 'mapA''s handling of 'Count').
	Count :: Structural a (t, Int) (ArrC (t, Int))
	-- | Look up the element at an index of the element array.
	Index :: Structural a (ArrC t, Int) t
	-- | Pair up elements, up to the length of the shorter array.
	Zip :: Structural a (ArrC t, ArrC u) (ArrC (t, u))
	-- | Split an array of pairs into two arrays that share the input's frame.
	Unzip :: Structural a (ArrC (t, u)) (ArrC t, ArrC u)
	-- | Merge the top two levels of the frame, shifting each second-level
	--   marker by its parent's offset.
	ClearMarks :: Structural a (ArrC t) (ArrC t)
	-- | @Left x@ becomes @([x], [])@ and @Right y@ becomes @([], [y])@.
	Separate :: Structural a (Either t u) (ArrC t, ArrC u)
	-- | Inverse of 'Separate'. If the first array has exactly one element,
	--   the result is @Left@ of that element; otherwise it is @Right@ of the
	--   second array's first element.
	Combine :: Structural a (ArrC t, ArrC u) (Either t u)
	-- | Concatenate an array of arrays. The frame records each sub-array's
	--   start offset and its own frame, followed by a sentinel.
	Pack :: Structural a (ArrC (ArrC t)) (ArrC t)
	-- | Split an array into the sub-arrays delimited by consecutive top-level
	--   frame markers; each piece receives its marker's sub-frame.
	Unpack :: Structural a (ArrC t) (ArrC (ArrC t))

-- | The 'A' arrow includes a set of primitives that may be executed concurrently.
--   Programs are incrementally optimized as they are put together. A program may be
--   optimized once, and the result saved for repeated use.
--
-- Notes:
--
--   * The exact output of the optimizer is subject to change.
--
--   * The program must be a finite data structure, or optimization will diverge.
--
--   * Internally, an 'A' value is a rewriting function. Given the 'Structural'
--   program built so far (which produces @t@), it returns a program producing
--   @u@. Combinators such as 'liftA' pattern-match on that incoming program to
--   fuse or cancel adjacent primitives.
data A a t u = A (forall v. Structural a v t -> Structural a v u)

-- | Evidence that two types are equal: matching on 'Equal' brings @t ~ u@
--   into scope.
data Equal t u where
	Equal :: Equal t t

-- | Rebuild a composition chain in right-nested form.
--
--   @reassociate p k@ walks the 'Comp' structure of @p@ and re-composes it
--   onto @k@.
--
--   * With @Right q@, the result is a right-nested chain semantically equal
--     to @p . q@.
--   * With @Left Equal@ (nothing to append), the result is @p@'s chain alone.
--   * An 'Id' in the inner position of a 'Comp' is dropped along the way.
--
--   'mapA' uses this to fuse two adjacent 'Map's.
reassociate :: (Category a) => Structural a u v -> Either (Equal t u) (Structural a t u) -> Structural a t v
reassociate (Comp a Id) = reassociate a
reassociate (Comp a a2) = reassociate a . Right . reassociate a2
-- Base case: a non-composition program is either returned as-is (nothing to
-- append) or composed directly in front of the accumulated chain.
reassociate a = either (\Equal -> a) (a.)

-- | Obtain a 'Structural' program from an 'A' program by applying its
--   rewriting function to the identity program.
unA :: (Category a) => A a t u -> Structural a t u
unA (A build) = build id

-- | Append an 'A' program after an existing 'Structural' program. The 'A'
--   program's rewriting function is run on it, so its simplifications apply.
unA' :: A a u v -> Structural a t u -> Structural a t v
unA' (A rewrite) program = rewrite program

-- | Apply an 'A' program to every element of an array. This is 'mapA' on the
--   program's 'Structural' form.
mapA' :: (ArrowChoice a) => A a t u -> A a (ArrC t) (ArrC u)
mapA' program = mapA (unA program)

-- | Embed a base arrow as a program. If the program built so far starts with
--   a lifted arrow, the two are fused into one 'Lift' by composing them in
--   the base arrow.
liftA :: (Category a) => a t u -> A a t u
liftA base = A (\program -> case program of
	Comp (Lift previous) rest -> Comp (Lift (base . previous)) rest
	_ -> Comp (Lift base) program)

-- | Flatten one level of nesting ('Pack'), simplifying the program it is
--   appended to:
--
--   * @Pack . Map (Map a . b) . c@ becomes @Map a . Pack . Map b . c@;
--
--   * @Pack . Map (Map a) . b@ becomes @Map a . Pack . b@;
--
--   * @Pack . Map (Pack . a) . b@ becomes @Pack . Pack . Map a . b@;
--
--   * @Pack . Map Pack . b@ becomes @Pack . Pack . b@;
--
--   * @Pack . Unpack . b@ becomes @b@;
--
--   * otherwise, 'Pack' is simply post-composed.
--
--   Each right-hand side is itself built through 'pack', so further rules
--   may fire.
pack :: (Category a) => A a (ArrC (ArrC t)) (ArrC t)
pack = A (\a -> case a of
	Comp (Map (Comp (Map a) a2)) a3 -> Map a . unA' pack (Map a2 . a3)
	Comp (Map (Map a)) a2 -> Map a . unA' pack a2
	Comp (Map (Comp Pack a)) a2 -> unA' pack (unA' pack (Map a . a2))
	Comp (Map Pack) a2 -> unA' pack (unA' pack a2)
	Comp Unpack a2 -> a2
	_ -> Pack . a)

-- | Report whether a program does per-element work, i.e. whether it contains
--   a constructor other than the purely structural ones ('Id', 'Pack',
--   'Unpack', 'Zip', 'Unzip', 'Separate', 'Combine'). Compositions are
--   searched recursively. 'mapA' only flattens a nested 'Map' whose inner
--   program passes this test.
flatten :: Structural a t u -> Bool
flatten program = case program of
	Comp later earlier -> flatten later || flatten earlier
	Id -> False
	Pack -> False
	Unpack -> False
	Zip -> False
	Unzip -> False
	Separate -> False
	Combine -> False
	_ -> True

-- | Mapping is the primary way of constructing nested data parallel programs.
--   It applies an (arrow) transformation to each element of an array
--   uniformly. A form of flattening transformation is applied to nested
--   maps (following the NESL paper). The flattening transformation converts
--   two levels of 'Map' into one level.
--
--   Each clause either rewrites the mapped program into an equivalent,
--   cheaper form or, as a last resort, post-composes a plain 'Map'.
mapA :: (ArrowChoice a) => Structural a t u -> A a (ArrC t) (ArrC u)
-- Nested map over a program that does per-element work: pack the nested
-- array, map over the flat result, and unpack again. A preceding 'Unpack'
-- is reused instead of being packed.
mapA (Map a) | flatten a = A (\a2 -> case a2 of
	Comp Unpack a3 -> Unpack . unA' (mapA a) a3
	_ -> Unpack . unA' (mapA a) (unA' pack a2))
-- Map distributes over composition and preserves the identity.
mapA (Comp a a2) = mapA a . mapA a2
mapA Id = id
-- Mapping over pairs: unzip, map each side, and zip back up.
mapA (Product a a2) = zipA . (mapA a *** mapA a2) . unzipA
-- Mapping 'Unpack' cancels a preceding @Map Pack@.
-- @Map Unpack . Unpack@ is deliberately not rewritten to @Unpack . Unpack@.
-- 'eval0' gives the outer array produced by 'Unpack' a single-segment
-- ('inject') frame, so a second 'Unpack' would return one group holding every
-- piece rather than splitting each piece separately.
mapA Unpack = A (\a -> case a of
	Comp (Map (Comp Pack a2)) a3 -> Map a2 . a3
	Comp (Map Pack) a2 -> a2
	_ -> Map Unpack . a)
-- Mapping 'Count' is computed directly. Every @(x, n)@ is expanded into the
-- pairs @(x, 0) .. (x, n-1)@ in one flat array, whose frame marks one segment
-- per input element; 'Unpack' then splits those segments apart.
mapA Count = A (Comp Unpack) . arr (\(ArrC ar fr) -> ArrC
	(newArray $ concatMap (\(x, n) -> map ((,) x) [0..n-1]) $ elems ar)
	(newArray $ zipWith Node (scanl (\n (_, m) -> n + m) 0 $ elems ar)
		(map (\(_, n) -> newArray [Node 0 (newArray []), Node n (newArray [])]) (elems ar) ++ [newArray []])))
-- Anything else: post-compose a 'Map', fusing it with an immediately
-- preceding 'Map' when there is one.
mapA a = A (\a2 -> case a2 of
	Comp (Map a2) a3 -> Comp (Map (reassociate a (Right a2))) a3
	_ -> Comp (Map a) a2)

-- | 'A' programs compose by composing their rewriting functions. The
--   identity leaves the incoming program untouched.
instance (Category a) => Category (A a) where
	id = A (\program -> program)
	A outer . A inner = A (\program -> outer (inner program))

-- | Arrow structure on 'A'. 'arr' lifts a pure function into the base arrow.
--   '***' merges into an existing 'Product', so each half is rewritten
--   independently.
instance (ArrowChoice a) => Arrow (A a) where
	arr f = liftA (arr f)
	A onFst *** A onSnd = A (\program -> case program of
		Comp (Product p q) rest -> Product (onFst p) (onSnd q) . rest
		_ -> Product (onFst id) (onSnd id) . program)
	first f = f *** id
	second f = id *** f

-- | Choice on 'A'. The incoming 'Either' is separated into two arrays, each
--   side is mapped, and the results are recombined. A trailing 'Combine' in
--   the incoming program is reused instead of separating again.
instance (ArrowChoice a) => ArrowChoice (A a) where
	onLeft +++ onRight = A (\program ->
		let both = mapA (unA onLeft) *** mapA (unA onRight)
		in case program of
			Comp Combine rest -> Combine . unA' both rest
			_ -> Combine . unA' both (Separate . program))
	left f = f +++ id
	right f = id +++ f

-- | Render a 'Structural' program with Haskell-like syntax and the usual
--   precedences: application binds tightest, then @.@ (infixr 9), then
--   @***@ (infixr 3). Lifted arrows and 'Id' are shown as @_@.
instance Show (Structural a t u) where
	showsPrec d (Map a) = showParen (d > 10) (showString "Map " . showsPrec 11 a)
	showsPrec d (Comp a a2) = showParen (d > 9) (showsPrec 10 a . showString " . " . showsPrec 9 a2)
	showsPrec d (Product a a2) = showParen (d > 3) (showsPrec 4 a . showString " *** " . showsPrec 3 a2)
	showsPrec _ Count = showString "Count"
	showsPrec _ Index = showString "Index"
	showsPrec _ Zip = showString "Zip"
	showsPrec _ Unzip = showString "Unzip"
	showsPrec _ ClearMarks = showString "Clr"
	showsPrec _ Pack = showString "Pk"
	showsPrec _ Unpack = showString "Unpk"
	showsPrec _ Separate = showString "Sep"
	showsPrec _ Combine = showString "Comb"
	showsPrec _ _ = showString "_"

-- | 'Structural' programs form a category directly through the 'Id' and 'Comp'
--   constructors. Neither needs anything from the base arrow, so no
--   @Category a@ context is required.
instance Category (Structural a) where
	id = Id
	(.) = Comp

-- | Swap the two alternatives of an 'Either'.
mirror :: Either a b -> Either b a
mirror = either Right Left

-- | From @(x, n)@, supply an array of @n@ copies of @x@, each paired with its
--   own index.
countA :: A a (t, Int) (ArrC (t, Int))
countA = A (\program -> Comp Count program)

-- | Read the element stored at the given index of an array.
indexA :: A a (ArrC t, Int) t
indexA = A (\program -> Comp Index program)

-- | An operation analogous to 'zip'.
--
--   Rewrites applied to the program it is appended to:
--
--   * @Zip . Unzip . b@ becomes @b@;
--
--   * @Zip . (Map a *** Map a2) . b@ becomes @Map (a *** a2) . Zip . b@
--     (likewise when one side is 'Id');
--
--   * otherwise, 'Zip' is simply post-composed.
zipA :: (Category a) => A a (ArrC t, ArrC u) (ArrC (t, u))
zipA = A (\a -> case a of
	Comp Unzip a2 -> a2
	Comp (Product (Map a) (Map a2)) a3 -> Map (Product a a2) . unA' zipA a3
	Comp (Product (Map a) Id) a3 -> Map (Product a Id) . unA' zipA a3
	Comp (Product Id (Map a2)) a3 -> Map (Product Id a2) . unA' zipA a3
	_ -> Zip . a)

-- | 'unzipA' and 'zipA' are inverses.
--
--   Rewrites applied to the program it is appended to:
--
--   * @Unzip . Zip . b@ becomes @b@;
--
--   * @Unzip . Map (a *** a2) . b@ becomes @(Map a *** Map a2) . Unzip . b@;
--
--   * otherwise, 'Unzip' is simply post-composed.
--
--   NOTE(review): 'eval0' truncates 'Zip' to the shorter array, so the
--   @Unzip . Zip@ cancellation assumes both arrays have equal length.
unzipA :: (Category a) => A a (ArrC (t, u)) (ArrC t, ArrC u)
unzipA = A (\a -> case a of
	Comp Zip a2 -> a2
	Comp (Map (Product a2 a3)) a4 -> Product (Map a2) (Map a3) . unA' unzipA a4
	_ -> Unzip . a)

-- | An operation analogous to 'concat'. It packs the nested array into one
--   flat array, then merges the top two levels of its frame with 'ClearMarks'.
concatA :: (Category a) => A a (ArrC (ArrC t)) (ArrC t)
concatA = pack >>> A (Comp ClearMarks)

-- | Force both components of a pair to weak head normal form, then return the pair.
forcePair :: (a, b) -> (a, b)
forcePair (x, y) = seq x (seq y (x, y))

-- | An evaluator for 'Structural' arrows.
--
--   @?seq@ is not read directly by any clause here. Presumably it is consumed
--   by 'arr_concF' (from "Control.CUtils.Conc") — TODO confirm.
eval0 :: (Concurrent a, Strict a, ArrowChoice a, ?seq :: Bool) => Structural a t u -> a t u
-- Count: evaluate 'id' once per index via arr_concF, then wrap the result
-- with a single-segment frame.
eval0 Count = arr_concF id >>> arr inject
-- Index: plain lookup in the element array.
eval0 Index = arr (\(ArrC ar _, i) -> ar ! i)
-- Zip: pair elements up to the shorter length (assumes zero-based arrays);
-- forcePair forces both components of each pair.
eval0 Zip = arr (\pr@(ArrC ar _, ArrC ar2 _) -> (pr, (snd (bounds ar) `min` snd (bounds ar2)) + 1))
	>>> arr_concF (arr (\((ArrC ar _, ArrC ar2 _), i) ->
	forcePair (ar ! i, ar2 ! i)))
	>>> arr inject
-- Unzip: both halves keep the input frame (fmap preserves it).
eval0 Unzip = arr (\ar -> (fmap fst ar, fmap snd ar))
-- ClearMarks: replace the frame by its second level, with each marker shifted
-- by its parent's offset.
-- NOTE(review): inner frames built by 'inject'/'Pack' end in a sentinel. That
-- sentinel and the next inner frame's first marker land on the same offset,
-- so adjacent segments produce duplicate markers (empty segments after a
-- later 'Unpack'). The outer sentinel, whose sub-frame is empty, contributes
-- nothing. Confirm this is intended.
eval0 ClearMarks =
	arr (\(ArrC ar fr) ->
		ArrC ar (newArray [ Node (i + j) fr3 | Node i fr2 <- elems fr, Node j fr3 <- elems fr2 ]))
-- Map: evaluate the inner program at every element; the frame is carried over.
eval0 (Map a) = (arr (\(ArrC ar _) -> (ar, uncurry subtract (bounds ar) + 1)) >>> arr_concF (arr (uncurry (!)) >>> eval0 a)) &&& arr (\(ArrC _ fr) -> fr) >>> arr (uncurry ArrC)
-- Pack: concatenate the element arrays. The new frame holds each sub-array's
-- start offset (running sum of lengths) with that sub-array's own frame,
-- plus a final sentinel. NOTE(review): the outer array's own frame is
-- discarded (`\(ArrC ar _)`).
eval0 Pack = arr (\(ArrC ar _) -> ArrC (newArray $ concatMap (elems . project) $ elems ar)
	(newArray $ zipWith Node (scanl (\i (ArrC ar _) -> i + rangeSize (bounds ar)) 0 $ elems ar)
		(map (\(ArrC _ fr) -> fr) (elems ar) ++ [newArray []])))
-- Unpack: for each pair of consecutive top-level markers (offsets i and j),
-- slice elements [i, j) and give the slice the first marker's sub-frame.
-- The outer array gets a fresh single-segment frame via 'inject'.
eval0 Unpack = arr (\ arc@(ArrC _ fr) -> (arc, uncurry subtract (bounds fr))) >>> arr_concF (arr (\(ArrC ar fr, index) ->
	let
		Node i fr2 = fr ! index
		Node j _ = fr ! (index + 1) in
	ArrC (ixmap (0, j-i-1) (+i) ar) fr2))
	>>> arr inject
-- Separate: Left x -> ([x], []), Right y -> ([], [y]); both arrays are forced.
eval0 Separate = arr (\ei -> ((,) $! either (\x -> inject $ newArray [x]) (\_ -> inject $ newArray []) ei) $! either (\_ -> inject $ newArray []) (\x -> inject $ newArray [x]) ei)
-- Combine: Left of the first array's element when that array has exactly one
-- element (hi - lo == 0), otherwise Right of the second array's first element.
eval0 Combine = arr (\(ar, ar2) -> let
	a1 = project ar
	a2 = project ar2 in
	if uncurry subtract (bounds (project ar)) == 0 then Left $! a1 ! 0 else Right $! a2 ! 0)
-- Comp: run a2, then a, with 'force' (from "Control.CUtils.StrictArrow")
-- applied to the second stage.
eval0 (Comp a a2) = force (eval0 a) . eval0 a2
eval0 Id = id
eval0 (Lift a) = a
-- Product: as written, the first component is evaluated before the second,
-- and the pair is forced between and after the two stages.
eval0 (Product a a2) = arr forcePair . force (second (eval0 a2)) . arr forcePair . first (eval0 a)

-- | Evaluates arrows: turns a 'Structural' program into an arrow of the base
--   type, running the evaluator with @?seq@ bound to 'True'.
--
-- Notes:
--
--   * Effects are supported, but with much weaker semantics than those of the
--   Kleisli arrows of the monad. In particular, the 'Map' and '***' operations
--   may be parallelized, although parallelism is not guaranteed.
eval :: (Concurrent a, Strict a, ArrowChoice a) => Structural a t u -> a t u
eval program = let ?seq = True in eval0 program

-- | Application on 'A'. The supplied 'A' program is evaluated into a base
--   arrow, and the base arrow's 'app' runs it on the argument.
instance (Concurrent a, Strict a, ArrowChoice a, ArrowApply a) => ArrowApply (A a) where
	app = liftA app . first (arr (\program -> eval (unA program)))

-------------------------------

-- | Is column @n@ attacked by the queens in @positions@? The k-th entry of
--   @positions@ (counting from 1) is treated as being k rows away.
checkThreats :: (Eq a, Num a, Enum a) => a -> [a] -> Bool
checkThreats n positions = any (n `elem`)
	[ positions -- same column
	, zipWith (-) positions [1..] -- one diagonal
	, zipWith (+) positions [1..] -- the other diagonal
	]

-- | True if some queen in the list is attacked by a queen later in the list.
checkThreats2 :: (Eq a, Num a, Enum a) => [a] -> Bool
checkThreats2 positions = any threatened (tails positions)
	where
	threatened (n:tl) = checkThreats n tl
	threatened [] = False

-- | @nQueensImpl m n@ extends a partial placement (a list of columns, most
--   recently placed first) by @n@ more rows. Each new row tries every index
--   supplied by 'countA' for @m@. Only complete placements in which no queen
--   threatens another are kept.
nQueensImpl :: Int -> Int -> A (->) [Int] (ArrC [Int])
nQueensImpl _ remaining
	| remaining <= 0 = arr (\soln -> inject (newArray (if checkThreats2 soln then [] else [soln])))
nQueensImpl size remaining =
	arr (\partial -> (partial, size))
	>>> countA
	>>> mapA' extend
	>>> concatA
	where extend = arr (\ ~(partial, col) -> col : partial) >>> nQueensImpl size (pred remaining)

-- | Example: every non-attacking placement of @n@ queens on an @n@ by @n@
--   board, starting from an empty placement.
nQueens :: Int -> A (->) () (ArrC [Int])
nQueens n = nQueensImpl n n <<< arr (\() -> [])

-------------------------------

-- | Example: quicksort written with the 'A' combinators. Arrays with at most
--   one element are returned unchanged. Otherwise the first element is the
--   pivot, and the two partitions are sorted recursively.
--
--   To keep this recursive definition a finite structure, recursion stops
--   after @depth@ levels; beyond that, the standard 'sort' takes over.
sorting :: (Ord t) => Int -> A (->) (ArrC t) (ArrC t)
sorting depth
	| depth <= 0 = arr (inject . newArray . sort . elems . project)
	| otherwise = arr classify >>> (id ||| (arr split >>> first (sub *** sub) >>> arr rejoin))
	where
	classify x = if uncurry subtract (bounds (project x)) <= 0 then Left x else Right x
	split ar = let
		x:xs = elems (project ar)
		(bef, aft) = partition (<x) xs in
		((inject (newArray bef), inject (newArray aft)), x)
	rejoin ((bef, aft), x) = inject (newArray (elems (project bef) ++ x : elems (project aft)))
	-- Bound once and used for both partitions, so the sub-program is shared.
	sub = sorting (pred depth)

-------------------------------

-- | Example: pair the array with each of its indices via 'countA', then look
--   each index up with 'indexA'. With the index sequence supplied by
--   'countA', this gathers the input's elements by index.
permute :: A (->) (ArrC Int) (ArrC Int)
permute = arr withCount >>> countA >>> mapA' indexA
	where withCount ar = (ar, uncurry subtract (bounds (project ar)) + 1)