{- |
Reduce selected dimensions.
Alternatively you may reorder dimensions with 'ShapeDep.backpermute'
and fold once along multiple dimensions.
-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE Rank2Types #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
module Data.Array.Knead.Symbolic.Fold (
   T,
   Cubic,
   apply,
   passAny,
   pass,
   fold,
   (Core.$:.),
   ) where

import qualified Data.Array.Knead.Symbolic.Private as Core
import Data.Array.Knead.Symbolic.Private (Array(Array), Code, Val, )

import qualified Data.Array.Knead.Shape.Cubic.Int as Index
import qualified Data.Array.Knead.Shape.Cubic as Cubic
import qualified Data.Array.Knead.Shape as Shape
import qualified Data.Array.Knead.Expression as Expr
import Data.Array.Knead.Shape.Cubic ((#:.), (:.)((:.)), )

import LLVM.DSL.Expression (Exp, unExp)

import qualified LLVM.Extra.Multi.Value as MultiValue
import LLVM.Extra.Multi.Value (atom, )

import qualified Type.Data.Num.Unary as Unary

import Prelude hiding (zipWith, zipWith3, zip, zip3, replicate, )


{- |
A reduction from arrays of shape @sh0@ to arrays of shape @sh1@
with elements of type @a@.

The constructor carries two pieces:

* a function that computes the target shape from the source shape,

* a code transformer that, given the value of the source shape
  and a generator of source elements addressed by a source index,
  yields a generator of target elements addressed by a target index.
-}
data T sh0 sh1 a where
   Cons ::
      forall sh0 sh1 a ix0 ix1.
      (Shape.Index sh0 ~ ix0, Shape.Index sh1 ~ ix1) =>
      (Exp sh0 -> Exp sh1) ->
      (forall r. Val sh0 -> (Val ix0 -> Code r a) -> (Val ix1 -> Code r a)) ->
      T sh0 sh1 a


{- |
Run a reduction on an array.

The shape of the result is obtained by applying the shape function
stored in the reduction to the shape of the input array.
The element generator of the result materialises the input shape
with 'unExp' and hands it, together with the element generator
of the input array and the requested target index,
to the code transformer stored in the reduction.
-}
apply ::
   (Core.C array, Shape.C sh0, Shape.C sh1, MultiValue.C a) =>
   T sh0 sh1 a ->
   array sh0 a ->
   array sh1 a
apply (Cons fsh reduce) =
   Core.lift1 $ \(Array sh code) ->
      Array (fsh sh) (\ix -> do sh0 <- unExp sh; reduce sh0 code ix)


{- |
A reduction 'T' between cubic shapes,
from a 'Cubic.Shape' of rank @rank0@ to a 'Cubic.Shape' of rank @rank1@.
The element type remains the last, unapplied parameter.
-}
type Cubic rank0 rank1 = T (Cubic.Shape rank0) (Cubic.Shape rank1)

{- |
The trivial reduction: it leaves the shape unchanged ('id')
and forwards every index unmodified to the source element generator
(the shape value is ignored via 'const').
It serves as the starting point for chains of 'pass' and 'fold'.
-}
passAny :: Cubic rank rank a
passAny = Cons id (const id)

{- |
Keep one dimension and apply the given reduction to the remaining ones.

The kept dimension is the component on the right of ':.':
the shape @sh:.s@ is mapped to @fsh sh :. s@,
i.e. @s@ is carried over unchanged
while the given reduction determines the rest of the target shape.

For the elements, the target index is split by 'Cubic.switchR'
into the remaining index @jx@ and the component @j@.
The given reduction is run on the shape without that dimension
('Cubic.tail') and every source index @kx@ it requests
is extended by @j@ before it reaches the source element generator.
-}
pass ::
   (Unary.Natural rank0, Unary.Natural rank1, MultiValue.C a) =>
   Cubic rank0 rank1 a ->
   Cubic (Unary.Succ rank0) (Unary.Succ rank1) a
pass (Cons fsh reduce) =
   Cons
      (Expr.modify (atom:.atom) $ \(sh:.s) -> fsh sh :. s)
      (\sh code ->
       Cubic.switchR $ \jx j ->
          reduce (Cubic.tail sh) (\kx -> code (kx #:. j)) jx)


{-
Generate code that folds along one dimension.

Arguments:

* @f@ - the binary operation used for combining elements,

* @nc@ - the extent of the dimension to be folded;
  it is unwrapped with 'Index.decons' and turned
  into a 'Shape.ZeroBased' shape for 'Core.fold1Code',

* @code@ - element generator of the source,
  addressed by an index that has one more component than @ix@,

* @ix@ - the index of the remaining dimensions.

For every position @j@ that 'Core.fold1Code' visits,
the source is read at @ix #:. Index.cons j@.

NOTE(review): the name 'Core.fold1Code' suggests a fold
without initial value, which would require @nc@ to be positive -
confirm against the definition in "Data.Array.Knead.Symbolic.Private".
-}
fold1CodeLinear ::
   (Unary.Natural rank, MultiValue.C a) =>
   (Exp a -> Exp a -> Exp a) ->
   Exp Index.Int ->
   (Val (Cubic.Index (Unary.Succ rank)) -> Code r a) ->
   (Val (Cubic.Index rank) -> Code r a)
fold1CodeLinear f nc code ix =
   Core.fold1Code f
      (Expr.lift1 (MultiValue.compose . Shape.ZeroBased) $ Index.decons nc)
      (\j -> code (ix #:. Index.cons j))

{- |
Fold along one dimension with the given binary operation
and apply the given reduction to the remaining dimensions.

The folded dimension is the one whose extent 'Cubic.head' returns.
It is removed from the shape with 'Cubic.tail'
before the shape function of the given reduction is applied,
so the rank of the source grows by one
whereas the rank of the target stays the same.

For the elements, the given reduction no longer reads
the source directly but reads values
that 'fold1CodeLinear' has already combined along the folded dimension.

NOTE(review): whether the operation must be associative
and whether the folded dimension must be non-empty
depends on 'Core.fold1Code', which is not visible here - confirm.
-}
fold ::
   (Unary.Natural rank0, Unary.Natural rank1, MultiValue.C a) =>
   (Exp a -> Exp a -> Exp a) ->
   Cubic rank0 rank1 a ->
   Cubic (Unary.Succ rank0) rank1 a
fold f (Cons fsh reduce) =
   Cons
      (fsh . Cubic.tail)
      (\sh code jx ->
          reduce (Cubic.tail sh)
             (fold1CodeLinear f (Expr.lift0 (Cubic.head sh)) code) jx)


instance Core.Process (T sh0 sh1 a) where