{- |
Reduce selected dimensions.
Alternatively you may reorder dimensions with 'ShapeDep.backpermute'
and fold once along multiple dimensions.
-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE Rank2Types #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
module Data.Array.Knead.Simple.Fold (
   T,
   Cubic,
   apply,
   passAny,
   pass,
   fold,
   (Core.$:.),
   ) where

import qualified Data.Array.Knead.Simple.Private as Core
import Data.Array.Knead.Simple.Private (Array(Array), Code, Val, )

import qualified Data.Array.Knead.Shape.Cubic.Int as Index
import qualified Data.Array.Knead.Shape.Cubic as Cubic
import qualified Data.Array.Knead.Shape as Shape
import qualified Data.Array.Knead.Expression as Expr
import Data.Array.Knead.Shape.Cubic ((#:.), (:.)((:.)), )

import LLVM.DSL.Expression (Exp, unExp)

import qualified LLVM.Extra.Multi.Value as MultiValue
import LLVM.Extra.Multi.Value (atom, )

import qualified Type.Data.Num.Unary as Unary

import Prelude hiding (zipWith, zipWith3, zip, zip3, replicate, )


{- |
Description of a reduction that turns an array of shape @sh0@
into an array of shape @sh1@ with elements of type @a@.
It is turned into an array transformation by 'apply'.

The index types @ix0@ and @ix1@ are hidden by the constructor
and are fixed to @Shape.Index sh0@ and @Shape.Index sh1@, respectively.
-}
data T sh0 sh1 a =
   forall ix0 ix1.
   (Shape.Index sh0 ~ ix0, Shape.Index sh1 ~ ix1) =>
   Cons
      -- Computes the shape of the result from the shape of the source.
      (Exp sh0 -> Exp sh1)
      -- Given the source shape and the code that generates a source element
      -- for a source index, yields the code that generates a result element
      -- for a result index.
      (forall r. Val sh0 -> (Val ix0 -> Code r a) -> (Val ix1 -> Code r a))


{- |
Run a reduction on an array.
The shape of the result is obtained by applying the shape function
of the reduction to the source shape.
Every element of the result is generated by the element transformer
of the reduction, which receives the source shape
and the element generator of the source array.
-}
apply ::
   (Core.C array, Shape.C sh0, Shape.C sh1, MultiValue.C a) =>
   T sh0 sh1 a ->
   array sh0 a ->
   array sh1 a
apply (Cons shapeMap reducer) =
   Core.lift1
      (\(Array srcShape element) ->
         Array
            (shapeMap srcShape)
            (\ix -> unExp srcShape >>= \shapeVal -> reducer shapeVal element ix))


{- |
'T' specialised to cubic shapes:
@rank0@ is the rank of the source shape, @rank1@ the rank of the result shape.
-}
type Cubic rank0 rank1 = T (Cubic.Shape rank0) (Cubic.Shape rank1)

{- |
The trivial reduction:
It leaves the shape as it is, ignores the shape value
and returns the element generator unchanged.
-}
passAny :: Cubic rank rank a
passAny = Cons (\shape -> shape) (\_shape element -> element)

{- |
Extend a reduction by one dimension that is kept.
The shape component to the right of ':.' is copied to the result shape,
the remaining shape is processed by the given reduction.
Accordingly, the result index is split by 'Cubic.switchR'
and its right component is appended unchanged
to every source index requested by the given reduction.
-}
pass ::
   (Unary.Natural rank0, Unary.Natural rank1, MultiValue.C a) =>
   Cubic rank0 rank1 a ->
   Cubic (Unary.Succ rank0) (Unary.Succ rank1) a
pass (Cons shapeMap reducer) =
   Cons
      (Expr.modify (atom:.atom)
         (\(restShape:.keptSize) -> shapeMap restShape :. keptSize))
      (\shape element ->
         Cubic.switchR
            (\restIx keptIx ->
               reducer
                  (Cubic.tail shape)
                  (\srcIx -> element (srcIx #:. keptIx))
                  restIx))


{-
Reduce along one dimension of extent @nc@ using 'Core.fold1Code'.
The extent is wrapped into a 'Shape.ZeroBased' shape
and every position @j@ of that shape is appended
to the given index @ix@ in order to address the source element.
-}
fold1CodeLinear ::
   (Unary.Natural rank, MultiValue.C a) =>
   (Exp a -> Exp a -> Exp a) ->
   Exp Index.Int ->
   (Val (Cubic.Index (Unary.Succ rank)) -> Code r a) ->
   (Val (Cubic.Index rank) -> Code r a)
fold1CodeLinear f nc code ix =
   Core.fold1Code f extent elementAt
   where
      extent =
         Expr.lift1 (MultiValue.compose . Shape.ZeroBased) (Index.decons nc)
      elementAt j = code (ix #:. Index.cons j)

{- |
Extend a reduction by one dimension that is folded away
using the given binary operation.
The size of the folded dimension is the one returned by 'Cubic.head',
the remaining shape ('Cubic.tail') is processed by the given reduction.
-}
fold ::
   (Unary.Natural rank0, Unary.Natural rank1, MultiValue.C a) =>
   (Exp a -> Exp a -> Exp a) ->
   Cubic rank0 rank1 a ->
   Cubic (Unary.Succ rank0) rank1 a
fold f (Cons shapeMap reducer) =
   Cons
      (\shape -> shapeMap (Cubic.tail shape))
      (\shape element ->
         reducer
            (Cubic.tail shape)
            (fold1CodeLinear f (Expr.lift0 (Cubic.head shape)) element))


instance Core.Process (T sh0 sh1 a) where