{-# LANGUAGE Rank2Types #-}
{-# LANGUAGE TypeFamilies #-}
module Data.Array.Knead.Simple.Private where

import qualified Data.Array.Knead.Index.Nested.Shape as Shape
import qualified Data.Array.Knead.Expression as Expr
import Data.Array.Knead.Expression (Exp(Exp), )

import qualified LLVM.Extra.Multi.Value as MultiValue
import qualified LLVM.Extra.Monad as Monad
import qualified LLVM.Extra.Maybe as Maybe
import qualified LLVM.Core as LLVM

import qualified Control.Category as Cat
import Control.Monad ((<=<), )

import Prelude hiding (id, map, zipWith, replicate, )


-- | LLVM-side representation of a value of Haskell type @a@.
type Val = MultiValue.T

-- | A code generation action that yields a 'Val' of type @a@.
type Code r a = LLVM.CodeGenFunction r (MultiValue.T a)

{- |
An array given by its shape expression
and by a code generator that computes the element at a given index.
The element generator is polymorphic in the code-generation parameter @r@,
so it can be used inside any 'LLVM.CodeGenFunction'.
-}
data Array sh a =
   Array
      (Exp sh)
      (forall r. Val (Shape.Index sh) -> Code r a)

-- | The shape expression stored in an array.
shape :: Array sh a -> Exp sh
shape arr =
   case arr of
      Array extent _ -> extent

-- | Element lookup: compute the index expression, then run the
-- array's element generator on the resulting index value.
(!) ::
   (Shape.C sh,  Shape.Index sh  ~ ix) =>
   Array sh a -> Exp ix -> Exp a
(!) (Array _ fetch) (Exp ixCode) = Exp (ixCode >>= fetch)

-- | The single element of a scalar-shaped array, read at the zero index.
the :: (Shape.Scalar sh) => Array sh a -> Exp a
the (Array z code) = Exp (code (Shape.zeroIndex z))

-- | Wrap a single value as an array of scalar shape.
fromScalar :: (Shape.Scalar sh) => Exp a -> Array sh a
fromScalar x = fill Shape.scalar x


-- | An array of the given shape whose every element is the given value;
-- the index is ignored.
fill :: Exp sh -> Exp a -> Array sh a
fill sh (Exp code) = Array sh (const code)


{- |
Lifting of operations on plain 'Array's to another array type.

Writing a function once in terms of 'lift0', 'lift1' or 'lift2'
makes it available for both simple and parameterized arrays,
provided the function itself takes no parameters.
-}
class C array where
   lift0 ::
      Array sh a ->
      array sh a
   lift1 ::
      (Array sha a -> Array shb b) ->
      array sha a ->
      array shb b
   lift2 ::
      (Array sha a -> Array shb b -> Array shc c) ->
      array sha a ->
      array shb b ->
      array shc c

-- | Plain arrays need no lifting: every method returns its argument unchanged.
instance C Array where
   lift0 arr = arr
   lift1 op = op
   lift2 op = op


{- |
Index-driven lookup: the first array supplies, for each of its indices,
an index into the second array, and the result holds the element found there.
The shape of the result is the shape of the index array;
the shape of the source array is not consulted here.
-}
gather ::
   (C array,
    Shape.C sh0, Shape.Index sh0 ~ ix0,
    Shape.C sh1, Shape.Index sh1 ~ ix1,
    MultiValue.C a) =>
   array sh1 ix0 ->
   array sh0 a ->
   array sh1 a
gather =
   lift2 $ \(Array sh1 indexOf) (Array _sh0 code) ->
      Array sh1 (\ix -> indexOf ix >>= code)

{- |
Build an array of shape @sh@ from two source arrays.
The element at index @ix@ is @f@ applied to the first array's element at
@projectIndex0 ix@ and the second array's element at @projectIndex1 ix@.
The shapes of the source arrays are not consulted here.
-}
backpermute2 ::
   (C array,
    Shape.C sh0, Shape.Index sh0 ~ ix0,
    Shape.C sh1, Shape.Index sh1 ~ ix1,
    Shape.C sh,  Shape.Index sh  ~ ix) =>
   Exp sh ->
   (Exp ix -> Exp ix0) ->
   (Exp ix -> Exp ix1) ->
   (Exp a -> Exp b -> Exp c) ->
   array sh0 a -> array sh1 b -> array sh c
backpermute2 sh projectIndex0 projectIndex1 f =
   lift2 $ \(Array _sha codeA) (Array _shb codeB) ->
      Array sh
         (\ix ->
            -- The lookups are bound inside the lambda so they share
            -- the element generator's code-generation parameter.
            let fetchA = codeA =<< Expr.unliftM1 projectIndex0 ix
                fetchB = codeB =<< Expr.unliftM1 projectIndex1 ix
            in  Monad.liftR2 (Expr.unliftM2 f) fetchA fetchB)


-- | The array of the given shape whose element at each index is that index.
id ::
   (Shape.C sh, Shape.Index sh ~ ix) =>
   Exp sh -> Array sh ix
id sh = Array sh (\ix -> return ix)

-- | Apply a function to every element, keeping the shape.
map ::
   (C array, Shape.C sh) =>
   (Exp a -> Exp b) ->
   array sh a -> array sh b
map f =
   lift1 $ \(Array sh code) ->
      Array sh (\ix -> code ix >>= Expr.unliftM1 f)

-- | Like 'map', but the function also receives the element's index.
mapWithIndex ::
   (C array, Shape.C sh, Shape.Index sh ~ ix) =>
   (Exp ix -> Exp a -> Exp b) ->
   array sh a -> array sh b
mapWithIndex f =
   lift1 $ \(Array sh code) ->
      Array sh
         (\ix -> do
            element <- code ix
            Expr.unliftM2 f ix element)


{- |
Generate code that folds @f@ over the inner shape for a fixed outer index.

The inner shape is visited with 'Shape.loop'. The first element seeds the
accumulator, so no neutral element is needed. Each later element @a@ is
combined as @f acc a@, which gives a left fold.

NOTE(review): if the inner shape has no indices, the accumulator stays
'Maybe.nothing' and 'Maybe.fromJust' extracts a value that is presumably
undefined. Callers should ensure the inner shape is non-empty.
-}
fold1Code ::
   (Shape.C sh1, Shape.Index sh1 ~ ix1, MultiValue.C a) =>
   (Exp a -> Exp a -> Exp a) ->
   Exp sh1 ->
   (Val ix0 -> Val ix1 -> Code r a) ->
   (Val ix0 -> Code r a)
fold1Code f (Exp nc) code ix = do
   n <- nc
   accumulated <- Shape.loop step n Maybe.nothing
   return (Maybe.fromJust accumulated)
   where
      -- Fetch the element at inner index i0 and merge it into the
      -- accumulator, or use it as the seed if nothing is accumulated yet.
      step i0 macc0 = do
         a <- code ix i0
         fmap Maybe.just $
            Maybe.run macc0 (return a) (\acc0 -> Expr.unliftM2 f acc0 a)

{- |
Fold the second component of a pair shape with @f@, without a neutral element.
The result has the first component of the shape.
See 'fold1Code' for the behaviour on an empty inner dimension.
-}
fold1 ::
   (C array, Shape.C sh0, Shape.C sh1, MultiValue.C a) =>
   (Exp a -> Exp a -> Exp a) ->
   array (sh0, sh1) a -> array sh0 a
fold1 f = lift1 foldInner
   where
      foldInner (Array shs code) =
         case Expr.unzip shs of
            (sh, s) -> Array sh (fold1Code f s (MultiValue.curry code))


{- |
Fold all elements of an array into a single value.
The shape @sh@ is viewed as @((), sh)@ and the second component is folded
with 'fold1', which yields an array of shape @()@.
-}
fold1All ::
   (Shape.C sh, MultiValue.C a) =>
   (Exp a -> Exp a -> Exp a) ->
   Array sh a -> Array () a
fold1All f (Array esh code) =
   let withUnitDim = Expr.lift1 (MultiValue.zip (MultiValue.Cons ())) esh
   in  fold1 f (Array withUnitDim (\ix -> code (MultiValue.snd ix)))



{- |
Constraint for dimension manipulators that are chained with '$:.'.
The class declares no methods.
NOTE(review): instances are presumably provided by the slicing/processing
modules (e.g. for @Slice.T@); they are not visible here.
-}
class Process proc where


infixl 3 $:.

{- |
Combine several dimension manipulators, applied from left to right.
Each step receives the result of the previous one.
E.g.

> apply (passAny $:. pick 3 $:. pass $:. replicate 10) array

The constraint @(Process proc0, Process proc1)@ is rather weak.
We would like to enforce that the type constructor, like @Slice.T@,
is the same in @proc0@ and @proc1@, and that only its parameters differ.
Currently this coherence holds only because
we only provide functions of type @proc0 -> proc1@ that satisfy this condition.
-}
($:.) :: (Process proc0, Process proc1) => proc0 -> (proc0 -> proc1) -> proc1
proc0 $:. step = step proc0