{-# LANGUAGE Rank2Types #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE TypeOperators #-}
module Data.Array.Knead.Parameterized.Symbolic (
   Array,
   Exp,
   Sym.extendParameter,
   withExp,
   withExp2,
   withExp3,
   (Sym.!),
   Sym.fill,
   gather,
   backpermute,
   Sym.id,
   Sym.map,
   zipWith,
   Sym.fold1,
   Sym.fold1All,
   ) where

import qualified Data.Array.Knead.Parameterized.Private as Sym
import qualified Data.Array.Knead.Symbolic as Core
import Data.Array.Knead.Parameterized.Private (Array, gather, )

import qualified Data.Array.Knead.Shape as Shape
import qualified Data.Array.Knead.Expression as Expr
import Data.Array.Knead.Expression (Exp, )

import qualified LLVM.DSL.Parameter as Param

import qualified LLVM.Extra.Multi.Value.Marshal as Marshal
import qualified LLVM.Extra.Multi.Value as MultiValue

import Control.Applicative ((<*>), )

import Prelude (uncurry, ($), (.), )


{-
fromScalar ::
   (Storable a, MultiValueMemory.C a, MultiValue.C a) =>
   Param.T p a -> Array p Z a
fromScalar = Sym.fill (return Z)
-}


-- | Build an array of shape @sh1@ whose elements are read from a source
-- array through an index mapping.
--
-- An index array over @sh1@ is formed by applying @f@ to every element of
-- @'Sym.id' sh1@ (an array over @sh1@ whose element type is the index type
-- @ix1@). This yields one source index @ix0@ per target position. 'gather'
-- then combines that index array with the source array.
--
-- NOTE(review): presumably each target element is the source element at
-- @f ix1@, and @f@ is expected to yield indices inside the source shape
-- @sh0@; confirm against 'gather'.
backpermute ::
   (Shape.C sh0, Shape.Index sh0 ~ ix0,
    Shape.C sh1, Shape.Index sh1 ~ ix1, Marshal.C sh1,
    MultiValue.C a) =>
   Param.T p sh1 ->
   (Exp ix1 -> Exp ix0) ->
   Array p sh0 a ->
   Array p sh1 a
backpermute sh1 f =
   gather (Core.map f (Sym.id sh1))


-- | Combine two arrays of shape @sh@ element by element, with access to a
-- parameter value.
--
-- The inputs are paired with 'Core.zip'. 'Sym.map' then calls @f@ for
-- every pair, passing:
--
-- * the parameter @d@ as an expression;
-- * the first component of the pair;
-- * the second component of the pair.
--
-- NOTE(review): how differing runtime extents of the two inputs are handled
-- is decided by 'Core.zip', which is not visible here.
zipWith ::
   (Shape.C sh, Marshal.C d) =>
   (Exp d -> Exp a -> Exp b -> Exp c) ->
   Param.T p d -> Array p sh a -> Array p sh b -> Array p sh c
zipWith f d a b =
   Sym.map (\di ab -> uncurry (f di) $ Expr.unzip ab) d $
   Core.zip a b


-- | Lift a function on core (non-parameterized) arrays, which also takes an
-- expression argument, to one on parameterized arrays.
--
-- The steps are:
--
-- * the input array is wrapped with 'Sym.arrayHull';
-- * @x@ is turned into a tunnel with 'Sym.expParam', so @f@ receives it as
--   an 'Exp';
-- * 'Sym.mapHullWithExp' applies @f@ inside the hull;
-- * 'Sym.runHull' turns the result back into a parameterized array.
withExp ::
   (Marshal.C x) =>
   (Exp x -> Core.Array shb b -> Core.Array sha a) ->
   Param.T p x -> Array p shb b -> Array p sha a
withExp f x b =
   Sym.runHull $
      Sym.mapHullWithExp f (Sym.expParam x) (Sym.arrayHull b)

-- | Two-array variant of 'withExp'.
--
-- @f@ is applied inside the hull of the first array via
-- 'Sym.mapHullWithExp', which yields a hull holding a function of the
-- second array. The second array's hull is then supplied with the
-- 'Applicative' operator '<*>', and 'Sym.runHull' produces the
-- parameterized result.
withExp2 ::
   (Marshal.C x) =>
   (Exp x -> Core.Array sha a -> Core.Array shb b -> Core.Array shc c) ->
   Param.T p x -> Array p sha a -> Array p shb b -> Array p shc c
withExp2 f x a b =
   Sym.runHull $
      Sym.mapHullWithExp f (Sym.expParam x) (Sym.arrayHull a)
      <*>
      Sym.arrayHull b

-- | Three-array variant of 'withExp'.
--
-- As in 'withExp2', @f@ is first applied inside the hull of the first array
-- via 'Sym.mapHullWithExp'. The hulls of the second and third arrays are
-- then supplied in order with '<*>' (left-associative), and 'Sym.runHull'
-- produces the parameterized result.
withExp3 ::
   (Marshal.C x) =>
   (Exp x -> Core.Array sha a ->
    Core.Array shb b -> Core.Array shc c -> Core.Array shd d) ->
   Param.T p x -> Array p sha a ->
   Array p shb b -> Array p shc c -> Array p shd d
withExp3 f x a b c =
   Sym.runHull $
      Sym.mapHullWithExp f (Sym.expParam x) (Sym.arrayHull a)
      <*>
      Sym.arrayHull b
      <*>
      Sym.arrayHull c