{-# LANGUAGE Rank2Types #-}
{-# LANGUAGE TypeFamilies #-}
module Data.Array.Knead.Simple.Private where

import qualified Data.Array.Knead.Shape as Shape
import qualified Data.Array.Knead.Expression as Expr
import Data.Array.Knead.Expression (Exp(Exp), )

import qualified LLVM.Extra.Multi.Value as MultiValue
import qualified LLVM.Extra.Iterator as Iter
import qualified LLVM.Extra.Maybe as Maybe
import qualified LLVM.Core as LLVM

import qualified Control.Category as Cat
import qualified Control.Monad.HT as Monad
import Control.Monad ((<=<), )

import Prelude hiding (id, map, zipWith, replicate, )


{- |
Shorthand for 'MultiValue.T',
the value representation imported from "LLVM.Extra.Multi.Value".
-}
type Val = MultiValue.T
{- |
An 'LLVM.CodeGenFunction' action, parameterized by @r@,
whose result is a 'Val' carrying an @a@.
-}
type Code r a = LLVM.CodeGenFunction r (Val a)

{- |
An array given by two components:
an expression for its shape
and a code generator that maps an index (of type @Shape.Index sh@)
to an element.
The generator is quantified over the @r@ parameter of 'Code',
so it has to work for every choice of @r@.
-}
data Array sh a =
   Array (Exp sh) (forall r. Val (Shape.Index sh) -> Code r a)

{- |
Return the shape expression stored in the array.
The element generator is ignored.
-}
shape :: Array sh a -> Exp sh
shape arr =
   case arr of
      Array extent _elemAt -> extent

{- |
Element access:
run the code of the index expression
and pass the resulting index to the array's element generator.
The array's shape is not consulted.
-}
(!) ::
   (Shape.C sh, Shape.Index sh ~ ix) =>
   Array sh a -> Exp ix -> Exp a
(!) (Array _ elemAt) (Exp ixCode) = Exp (ixCode >>= elemAt)

{- |
Read the element that the array's generator yields
for the index @Shape.zeroIndex@ of the array's shape.
-}
the :: (Shape.Scalar sh) => Array sh a -> Exp a
the (Array sh elemAt) = Exp (elemAt (Shape.zeroIndex sh))

{- |
Wrap an expression as an array of shape @Shape.scalar@.
Defined in terms of 'fill'.
-}
fromScalar :: (Shape.Scalar sh) => Exp a -> Array sh a
fromScalar x = fill Shape.scalar x


{- |
Array of the given shape
whose element generator ignores the index
and always runs the code of the given expression.
-}
fill :: Exp sh -> Exp a -> Array sh a
fill sh (Exp code) = Array sh (const code)


{- |
This class makes it possible to implement functions without parameters
once and use them for both simple and parameterized arrays.
-}
class C array where
   -- | Turn a plain 'Array' into the @array@ type.
   lift0 :: Array sh a -> array sh a
   -- | Turn a function on one plain 'Array' into a function on @array@.
   lift1 :: (Array sha a -> Array shb b) -> array sha a -> array shb b
   -- | Turn a function on two plain 'Array's into a function on @array@.
   lift2 ::
      (Array sha a -> Array shb b -> Array shc c) ->
      array sha a -> array shb b -> array shc c

{- |
For plain arrays there is nothing to lift:
every method is the identity on its first argument.
-}
instance C Array where
   lift0 arr = arr
   lift1 f = f
   lift2 f = f


{- |
Build an array that has the shape of the first argument (the index array).
For every position, the index array's generator computes an index
and that index is passed to the generator of the second argument
(the source array).
The shape of the source array is not consulted.
-}
gather ::
   (C array,
    Shape.C sh0, Shape.Index sh0 ~ ix0,
    Shape.C sh1, Shape.Index sh1 ~ ix1,
    MultiValue.C a) =>
   array sh1 ix0 ->
   array sh0 a ->
   array sh1 a
gather =
   lift2 $ \(Array shIndices indexAt) (Array _shSource elemAt) ->
      Array shIndices (\ix -> indexAt ix >>= elemAt)

{- |
Combine two arrays into an array of the explicitly given shape @sh@.
For every index of the result the generated code

1. maps the index with the first projection and fetches from the first array,
2. maps the index with the second projection and fetches from the second array,
3. combines both elements with the given function.

The shapes of the two source arrays are not consulted.
-}
backpermute2 ::
   (C array,
    Shape.C sh0, Shape.Index sh0 ~ ix0,
    Shape.C sh1, Shape.Index sh1 ~ ix1,
    Shape.C sh,  Shape.Index sh  ~ ix) =>
   Exp sh ->
   (Exp ix -> Exp ix0) ->
   (Exp ix -> Exp ix1) ->
   (Exp a -> Exp b -> Exp c) ->
   array sh0 a -> array sh1 b -> array sh c
backpermute2 sh projectIndex0 projectIndex1 f =
   lift2 $ \(Array _sha codeA) (Array _shb codeB) ->
      Array sh $ \ix -> do
         a <- codeA =<< Expr.unliftM1 projectIndex0 ix
         b <- codeB =<< Expr.unliftM1 projectIndex1 ix
         Expr.unliftM2 f a b


{- |
Array of the given shape whose element generator returns the index itself.
-}
id ::
   (C array, Shape.C sh, Shape.Index sh ~ ix) =>
   Exp sh -> array sh ix
id sh = lift0 (Array sh (\ix -> return ix))

{- |
Apply a function to every element.
The shape is kept;
the new generator first produces the original element
and then runs the function on it.
-}
map ::
   (C array, Shape.C sh) =>
   (Exp a -> Exp b) ->
   array sh a -> array sh b
map f =
   lift1 $ \(Array sh elemAt) ->
      Array sh (\ix -> elemAt ix >>= Expr.unliftM1 f)

{- |
Like 'map', but the function additionally receives
the index of the element as its first argument.
-}
mapWithIndex ::
   (C array, Shape.C sh, Shape.Index sh ~ ix) =>
   (Exp ix -> Exp a -> Exp b) ->
   array sh a -> array sh b
mapWithIndex f =
   lift1 $ \(Array sh elemAt) ->
      Array sh $ \ix -> do
         x <- elemAt ix
         Expr.unliftM2 f ix x


{- |
Generate code that reduces all elements addressed by a shape
with a binary function.

The accumulator is threaded through 'Shape.loop' as an optional value
that starts out as @Maybe.nothing@:
if no accumulator is present yet, the current element becomes the accumulator,
otherwise the function is applied to the accumulator (left argument)
and the current element (right argument).

NOTE(review): the final accumulator is unwrapped with @Maybe.fromJust@.
If the loop body is never run, this unwraps the initial @Maybe.nothing@;
presumably the result is then undefined - confirm against the callers.
-}
fold1Code ::
   (Shape.C sh, Shape.Index sh ~ ix, MultiValue.C a) =>
   (Exp a -> Exp a -> Exp a) ->
   Exp sh ->
   (Val ix -> Code r a) ->
   Code r a
fold1Code f (Exp shapeCode) elemAt = do
   extent <- shapeCode
   finalAcc <-
      Shape.loop
         (\ix macc ->
            fmap Maybe.just $ do
               x <- elemAt ix
               Maybe.run macc (return x) (\acc -> Expr.unliftM2 f acc x))
         extent Maybe.nothing
   return (Maybe.fromJust finalAcc)

{- |
Reduce an array of a pair shape along the second shape component.
The pair shape is split with 'Expr.unzip';
the result has the first component as its shape
and every element is computed by 'fold1Code' over the second component.
-}
fold1 ::
   (C array, Shape.C sh0, Shape.C sh1, MultiValue.C a) =>
   (Exp a -> Exp a -> Exp a) ->
   array (sh0, sh1) a -> array sh0 a
fold1 f =
   lift1 $ \(Array shs code) ->
      case Expr.unzip shs of
         (outer, inner) ->
            Array outer (\ix -> fold1Code f inner (MultiValue.curry code ix))


{- |
Reduce all elements of an array to a single expression
by running 'fold1Code' over the array's whole shape.
-}
fold1All ::
   (Shape.C sh, MultiValue.C a) =>
   (Exp a -> Exp a -> Exp a) ->
   Array sh a -> Exp a
fold1All f arr =
   case arr of
      Array extent elemAt -> Exp (fold1Code f extent elemAt)


{- |
Generate code that searches the elements addressed by a shape
for one that satisfies the predicate.

The elements are produced by mapping the element generator
over @Shape.iterator@ of the shape.
For every element the step function returns the negated predicate result
together with an optional value that holds the element
if the predicate was true.
Presumably the first component tells 'Iter.mapWhileState_'
whether to continue, which would stop the scan at the first match
- confirm against the documentation of "LLVM.Extra.Iterator".
The final state is converted to a 'MultiValue' optional value.
-}
findAllCode ::
   (Shape.C sh, Shape.Index sh ~ ix, MultiValue.C a) =>
   (Exp a -> Exp Bool) ->
   Exp sh ->
   (Val ix -> Code r a) ->
   Code r (Maybe a)
findAllCode p (Exp shapeCode) elemAt = do
   extent <- shapeCode
   let elements = Iter.mapM elemAt (Shape.iterator extent)
   let visit x _previous = do
          MultiValue.Cons matched <- Expr.unliftM1 p x
          continue <- LLVM.inv matched
          return (continue, Maybe.fromBool matched x)
   result <- Iter.mapWhileState_ visit elements Maybe.nothing
   Maybe.run result
      (return MultiValue.nothing)
      (\x -> return (MultiValue.just x))

{- |
Search the array for an element that satisfies the predicate.
In principle this could be implemented using 'fold1All',
but this function has short-circuit semantics.
@All@ means that it scans all dimensions;
it does not mean that it finds all occurrences.
If you want to get the index of the found element,
decorate the array elements with their indices before calling 'findAll'.
-}
findAll ::
   (Shape.C sh, MultiValue.C a) =>
   (Exp a -> Exp Bool) ->
   Array sh a -> Exp (Maybe a)
findAll p arr =
   case arr of
      Array extent elemAt -> Exp (findAllCode p extent elemAt)


{- |
Class without methods.
Within this module it only serves as a constraint
on the operands of '$:.'.
-}
class Process proc where


infixl 3 $:.

{- |
Reverse function application,
intended for chaining several dimension manipulators from left to right.
E.g.

> apply (passAny $:. pick 3 $:. pass $:. replicate 10) array

The constraint @(Process proc0, Process proc1)@ is a bit weak.
We would like to enforce that the type constructor, like @Slice.T@,
is the same in @proc0@ and @proc1@ and that only the parameters differ.
Currently this coherence is achieved
because we only provide functions of type @proc0 -> proc1@
that satisfy this condition.
-}
($:.) :: (Process proc0, Process proc1) => proc0 -> (proc0 -> proc1) -> proc1
start $:. manipulate = manipulate start