{- |
Reduce selected dimensions of an array.
Alternatively, you can reorder the dimensions with 'ShapeDep.backpermute'
and then fold along multiple dimensions at once.
-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE Rank2Types #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
module Data.Array.Knead.Simple.Fold (
   T,
   Linear,
   apply,
   passAny,
   pass,
   fold,
   (Core.$:.),
   ) where

import qualified Data.Array.Knead.Simple.Private as Core
import Data.Array.Knead.Simple.Private (Array(Array), Code, Val, )

import qualified Data.Array.Knead.Index.Linear as Linear
import qualified Data.Array.Knead.Index.Linear.Int as IndexInt
import qualified Data.Array.Knead.Index.Nested.Shape as Shape
import qualified Data.Array.Knead.Expression as Expr
import Data.Array.Knead.Expression (Exp, unExp, )
import Data.Array.Knead.Index.Linear ((#:.), (:.)((:.)), )

import qualified LLVM.Extra.Multi.Value as MultiValue
import LLVM.Extra.Multi.Value (atom, )

import Prelude hiding (zipWith, zipWith3, zip, zip3, replicate, )


{- |
A reduction plan that turns an array of shape @sh0@
into an array of shape @sh1@ with element type @a@.
Plans are built with 'passAny', 'pass' and 'fold'
and are run with 'apply'.
-}
data T sh0 sh1 a =
   forall ix0 ix1.
   (Shape.Index sh0 ~ ix0, Shape.Index sh1 ~ ix1) =>
   Cons
      -- Maps the source shape expression to the target shape expression.
      (Exp sh0 -> Exp sh1)
      -- Takes the source shape value and a code generator that reads a
      -- source element by source index. Returns a code generator that
      -- produces a target element for a target index (see 'apply').
      -- Polymorphic in @r@ so it can be used in any code context.
      (forall r. Val sh0 -> (Val ix0 -> Code r a) -> (Val ix1 -> Code r a))


{- |
Run a reduction plan on an array.
The shape of the result comes from the plan's shape function.
Each result element is generated by the plan's element function,
which reads elements of the input array.
The input shape expression is evaluated inside every element computation.
-}
apply ::
   (Core.C array, Shape.C sh0, Shape.C sh1, MultiValue.C a) =>
   T sh0 sh1 a ->
   array sh0 a ->
   array sh1 a
apply (Cons mapShape reduceCode) =
   Core.lift1
      (\(Array srcShape srcCode) ->
         Array (mapShape srcShape)
            (\dstIx ->
               unExp srcShape >>= \shapeVal ->
                  reduceCode shapeVal srcCode dstIx))


{- |
A reduction plan on linear shapes.
Both the source and the target shape are wrapped in 'Linear.Shape'.
-}
type Linear sh0 sh1 = T (Linear.Shape sh0) (Linear.Shape sh1)

{- |
Leave all dimensions unchanged.
The shape passes through as is, and every target element
is read from the source at the very same index.
-}
passAny :: Linear sh sh a
passAny = Cons (\shape -> shape) (\_shape element -> element)

{- |
Keep the right-most dimension (the one added by @:.@) as it is.
Apply the given plan to the remaining leading dimensions.
-}
pass ::
   Linear sh0 sh1 a ->
   Linear (sh0:.i) (sh1:.i) a
pass (Cons mapShape reduceInit) =
   Cons
      (Expr.modify (Linear.shape (atom:.atom)) $
         \(initShape :. lastSize) -> mapShape initShape :. lastSize)
      (\shape element ->
         Linear.switchR $ \dstInit lastIx ->
            reduceInit (Linear.tail shape)
               -- re-attach the untouched right-most coordinate
               (\srcInit -> element (srcInit #:. lastIx))
               dstInit)


{- |
Turn a reader of full-index elements into a reader over the leading
dimensions. The result combines, with the given operation, all elements
along a trailing 'IndexInt.Int' dimension of length @nc@.
The combining is delegated to 'Core.fold1Code'.
-}
fold1CodeLinear ::
   (MultiValue.C a) =>
   (Exp a -> Exp a -> Exp a) ->
   Exp IndexInt.Int ->
   (Val (Linear.Index (sh :. IndexInt.Int)) -> Code r a) ->
   (Val (Linear.Index sh) -> Code r a)
fold1CodeLinear f nc code ix =
   Core.fold1Code f (IndexInt.decons nc) readAt ix
   where
      -- Append the trailing coordinate to the outer index,
      -- then read the full element.
      readAt outerIx j = code (outerIx #:. IndexInt.cons j)

{- |
Fold away the right-most 'IndexInt.Int' dimension with the given
binary operation. Apply the given plan to the remaining leading
dimensions.

NOTE(review): the folding is done by 'Core.fold1Code'. Judging by the
name, it expects a non-empty dimension. Confirm the behaviour for
zero-length dimensions.
-}
fold ::
   (MultiValue.C a) =>
   (Exp a -> Exp a -> Exp a) ->
   Linear sh0 sh1 a ->
   Linear (sh0:.IndexInt.Int) sh1 a
fold f (Cons mapShape reduceInit) =
   Cons
      (\shape -> mapShape (Linear.tail shape))
      (\shape element dstIx ->
         let -- Reads an element over the leading dimensions
             -- by folding the right-most dimension.
             foldLast =
                fold1CodeLinear f (Expr.lift0 (Linear.head shape)) element
         in  reduceInit (Linear.tail shape) foldLast dstIx)


instance Core.Process (T sh0 sh1 a) where