{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
module Data.Array.Knead.Symbolic.ShapeDependent where
import qualified Data.Array.Knead.Symbolic.Private as Core
import Data.Array.Knead.Symbolic.Private (Array(Array), )
import qualified Data.Array.Knead.Shape as Shape
import qualified Data.Array.Knead.Expression as Expr
import Data.Array.Knead.Expression (Exp, )
import qualified Control.Monad.HT as Monad
import Control.Monad ((<=<), )
-- | Turn the shape of an array into an array of its own.
--
-- The shape expression of the input (read with 'Core.shape') is wrapped
-- with 'Core.fromScalar' into an array whose shape type is the scalar
-- shape @z@.  Only the shape of the input is consulted; its element
-- generator is not used.
shape ::
   (Core.C array, Shape.C sh, Shape.Scalar z) =>
   array sh a -> array z sh
shape = Core.lift1 $ Core.fromScalar . Core.shape
-- | Re-index an array by mapping every target index back to a source
-- index.
--
-- * @createShape@ derives the shape of the result from the shape of the
--   source array.
--
-- * @projectIndex@ maps an index of the result to the index of the source
--   element that appears there.
--
-- NOTE(review): no check is visible here that @projectIndex@ yields an
-- index inside the source shape; callers must guarantee that.
backpermute ::
   (Core.C array,
    Shape.C sh0, Shape.Index sh0 ~ ix0,
    Shape.C sh1, Shape.Index sh1 ~ ix1) =>
   (Exp sh0 -> Exp sh1) ->
   (Exp ix1 -> Exp ix0) ->
   array sh0 a ->
   array sh1 a
backpermute createShape projectIndex =
   Core.lift1 $ \(Array sh code) ->
      -- First translate the requested index, then fetch from the source.
      Array (createShape sh) (code <=< Expr.unliftM1 projectIndex)
-- | Like 'backpermute', but the shape of the result may additionally
-- depend on the shape of a second array.
--
-- * @newShape@ combines the shapes of both arrays into the result shape.
--
-- * @projectIndex@ maps an index of the result to an index of the
--   /first/ array.
--
-- Only the shape of the second array is used; its element generator is
-- ignored.  Every element of the result is fetched from the first array.
backpermuteExtra ::
   (Core.C array,
    Shape.C sh0, Shape.Index sh0 ~ ix0,
    Shape.C sh1, Shape.Index sh1 ~ ix1,
    Shape.C sh, Shape.Index sh ~ ix) =>
   (Exp sh0 -> Exp sh1 -> Exp sh) ->
   (Exp ix -> Exp ix0) ->
   array sh0 a -> array sh1 b -> array sh a
backpermuteExtra newShape projectIndex =
   Core.lift2 $ \(Array sh0 code) (Array sh1 _) ->
      Array (newShape sh0 sh1)
         (\ix -> code =<< Expr.unliftM1 projectIndex ix)
-- | Combine two arrays element-wise after re-indexing each independently.
--
-- * @combineShape@ derives the result shape from the two source shapes.
--
-- * @projectIndex0@ and @projectIndex1@ map a result index to an index
--   of the first and second source array respectively.
--
-- * @f@ combines the two fetched elements into the result element.
--
-- The fetch from the first array is sequenced before the fetch from the
-- second one (order of the arguments to 'Monad.liftJoin2').
backpermute2 ::
   (Core.C array,
    Shape.C sh0, Shape.Index sh0 ~ ix0,
    Shape.C sh1, Shape.Index sh1 ~ ix1,
    Shape.C sh, Shape.Index sh ~ ix) =>
   (Exp sh0 -> Exp sh1 -> Exp sh) ->
   (Exp ix -> Exp ix0) ->
   (Exp ix -> Exp ix1) ->
   (Exp a -> Exp b -> Exp c) ->
   array sh0 a -> array sh1 b -> array sh c
backpermute2 combineShape projectIndex0 projectIndex1 f =
   Core.lift2 $ \(Array sha codeA) (Array shb codeB) ->
      Array (combineShape sha shb)
         (\ix ->
            Monad.liftJoin2 (Expr.unliftM2 f)
               (codeA =<< Expr.unliftM1 projectIndex0 ix)
               (codeB =<< Expr.unliftM1 projectIndex1 ix))
-- | Build an array from the constant @a@ using 'Core.fill', with a shape
-- derived from the shape of an existing array.
--
-- @fsh@ maps the shape of the input array to the shape of the result.
-- Only the shape of the input array is read; its element generator is
-- not used.
fill ::
   (Core.C array) =>
   (Exp sh0 -> Exp sh1) -> Exp b ->
   array sh0 a -> array sh1 b
fill fsh a =
   Core.lift1 $ \arr -> Core.fill (fsh (Core.shape arr)) a