{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
module Data.Array.Knead.Symbolic.ShapeDependent where

import qualified Data.Array.Knead.Symbolic.Private as Core
import Data.Array.Knead.Symbolic.Private (Array(Array), )

import qualified Data.Array.Knead.Shape as Shape
import qualified Data.Array.Knead.Expression as Expr
import Data.Array.Knead.Expression (Exp, )

import qualified Control.Monad.HT as Monad
import Control.Monad ((<=<), )


{- |
Turn the shape of an array into the single element of a scalar-shaped array.
The content of the input array is not accessed; only its shape expression
(obtained via 'Core.shape') is wrapped with 'Core.fromScalar'.
-}
shape ::
   (Core.C array, Shape.C sh, Shape.Scalar z) =>
   array sh a -> array z sh
shape = Core.lift1 $ Core.fromScalar . Core.shape

{- |
Rearrange an array by index arithmetic.

* @createShape@ computes the result shape from the source shape.

* @projectIndex@ maps every index of the result to the index of the
  source element that shall be placed there.
-}
backpermute ::
   (Core.C array,
    Shape.C sh0, Shape.Index sh0 ~ ix0,
    Shape.C sh1, Shape.Index sh1 ~ ix1) =>
   (Exp sh0 -> Exp sh1) ->
   (Exp ix1 -> Exp ix0) ->
   array sh0 a ->
   array sh1 a
backpermute createShape projectIndex =
   Core.lift1 $ \(Array sh code) ->
      Array (createShape sh)
         -- first translate the target index into a source index,
         -- then read the source element at that index
         (code <=< Expr.unliftM1 projectIndex)

{- |
This function sits between 'backpermute' and 'backpermute2'.
The result shape may depend on the shapes of both arrays,
but elements are read only from the first array;
the content of the second array is never accessed.
This is needed if the second array contributes only a virtual dimension.

* @newShape@ computes the result shape from both source shapes.

* @projectIndex@ maps every result index to an index of the first array.
-}
backpermuteExtra ::
   (Core.C array,
    Shape.C sh0, Shape.Index sh0 ~ ix0,
    Shape.C sh1, Shape.Index sh1 ~ ix1,
    Shape.C sh,  Shape.Index sh  ~ ix) =>
   (Exp sh0 -> Exp sh1 -> Exp sh) ->
   (Exp ix -> Exp ix0) ->
   array sh0 a -> array sh1 b -> array sh a
backpermuteExtra newShape projectIndex =
   Core.lift2 $ \(Array sh0 code) (Array sh1 _code) ->
      Array (newShape sh0 sh1)
         -- only the shape of the second array is used, its code is ignored
         (\ix -> code =<< Expr.unliftM1 projectIndex ix)

{- |
Combine two arrays elementwise after rearranging both of them.

* @combineShape@ computes the result shape from both source shapes.

* @projectIndex0@ and @projectIndex1@ map a result index
  to an index into the first and the second array, respectively.

* @f@ combines the two elements fetched for a result index.
-}
backpermute2 ::
   (Core.C array,
    Shape.C sh0, Shape.Index sh0 ~ ix0,
    Shape.C sh1, Shape.Index sh1 ~ ix1,
    Shape.C sh,  Shape.Index sh  ~ ix) =>
   (Exp sh0 -> Exp sh1 -> Exp sh) ->
   (Exp ix -> Exp ix0) ->
   (Exp ix -> Exp ix1) ->
   (Exp a -> Exp b -> Exp c) ->
   array sh0 a -> array sh1 b -> array sh c
backpermute2 combineShape projectIndex0 projectIndex1 f =
   Core.lift2 $ \(Array sha codeA) (Array shb codeB) ->
      Array (combineShape sha shb)
         (\ix ->
            -- fetch one element from each array and merge them with f
            Monad.liftJoin2 (Expr.unliftM2 f)
               (codeA =<< Expr.unliftM1 projectIndex0 ix)
               (codeB =<< Expr.unliftM1 projectIndex1 ix))

{- |
Create an array where every element equals @a@.
The result shape is obtained by applying @fsh@ to the shape of the input array;
the content of the input array is not accessed.
-}
fill ::
   (Core.C array) =>
   (Exp sh0 -> Exp sh1) -> Exp b ->
   array sh0 a -> array sh1 b
fill fsh a =
   Core.lift1 $ \arr -> Core.fill (fsh $ Core.shape arr) a