{-# LANGUAGE GADTs #-}
{-# LANGUAGE Rank2Types #-}
{- |
Combinators on parameterized symbolic arrays.

Part of the interface is re-exported from
"Data.Array.Knead.Parameterized.Private";
the remaining functions are defined here in terms of that module
and of "Data.Array.Knead.Simple.Symbolic".
-}
module Data.Array.Knead.Parameterized.Symbolic
   ( Array
   , Exp
   , Sym.extendParameter
   , withExp
   , withExp2
   , withExp3
   , (Sym.!)
   , Sym.fill
   , gather
   , backpermute
   , Sym.id
   , Sym.map
   , zipWith
   , Sym.fold1
   , Sym.fold1All
   ) where

import qualified Data.Array.Knead.Parameterized.Private as Sym
import qualified Data.Array.Knead.Simple.Symbolic as Core
import Data.Array.Knead.Parameterized.Private (Array, gather, )

import qualified Data.Array.Knead.Shape as Shape
import qualified Data.Array.Knead.Parameter as Param
import qualified Data.Array.Knead.Expression as Expr
import Data.Array.Knead.Expression (Exp, )

import qualified LLVM.Extra.Multi.Value.Memory as MultiValueMemory
import qualified LLVM.Extra.Multi.Value as MultiValue

import Foreign.Storable (Storable, )

import Control.Applicative ((<*>), )

import Prelude (uncurry, ($), (.), )


{-
fromScalar ::
   (Storable a, MultiValueMemory.C a, MultiValue.C a) =>
   Param.T p a -> Array p Z a
fromScalar = Sym.fill (return Z)
-}

{- |
@backpermute sh1 f arr@ produces an array of shape @sh1@
by 'gather'ing from @arr@.
The array of source indices that is handed to 'gather'
is obtained by mapping @f@ over @Sym.id sh1@.

NOTE(review): @Sym.id sh1@ presumably is the array of shape @sh1@
that stores at every position the index of that position
- confirm against "Data.Array.Knead.Parameterized.Private".
-}
backpermute ::
   (Shape.C sh0, Shape.Index sh0 ~ ix0,
    Shape.C sh1, Shape.Index sh1 ~ ix1,
    Storable sh1, MultiValueMemory.C sh1,
    MultiValue.C a) =>
   Param.T p sh1 ->
   (Exp ix1 -> Exp ix0) ->
   Array p sh0 a ->
   Array p sh1 a
backpermute sh1 f = gather $ Core.map f $ Sym.id sh1

{- |
Combine two arrays of the same shape element by element.
In addition to the two elements,
the combining function receives the parameter @d@ as an expression
in its first argument.

The two arrays are zipped to an array of pairs first;
every pair is taken apart again with 'Expr.unzip'
before it is passed to the combining function.
-}
zipWith ::
   (Shape.C sh, MultiValueMemory.C d, Storable d) =>
   (Exp d -> Exp a -> Exp b -> Exp c) ->
   Param.T p d ->
   Array p sh a -> Array p sh b -> Array p sh c
zipWith f d a b =
   Sym.map (\di -> uncurry (f di) . Expr.unzip) d (Core.zip a b)


{- |
Lift a transformation of "Data.Array.Knead.Simple.Symbolic" arrays
that depends on an additional expression
to a transformation of parameterized arrays.
The expression is supplied by the parameter @x@.
-}
withExp ::
   (Storable x, MultiValueMemory.C x) =>
   (Exp x -> Core.Array shb b -> Core.Array sha a) ->
   Param.T p x ->
   Array p shb b -> Array p sha a
withExp f x b =
   Sym.runHull $
   Sym.mapHullWithExp f (Sym.expParam x) (Sym.arrayHull b)

{- |
Like 'withExp' but for a function of two arrays.
The hull of the second array is applied with '<*>'.
-}
withExp2 ::
   (Storable x, MultiValueMemory.C x) =>
   (Exp x -> Core.Array sha a -> Core.Array shb b -> Core.Array shc c) ->
   Param.T p x ->
   Array p sha a -> Array p shb b -> Array p shc c
withExp2 f x a b =
   Sym.runHull
      (Sym.mapHullWithExp f (Sym.expParam x) (Sym.arrayHull a)
         <*> Sym.arrayHull b)

{- |
Like 'withExp' but for a function of three arrays.
The hulls of the second and third array are applied with '<*>'.
-}
withExp3 ::
   (Storable x, MultiValueMemory.C x) =>
   (Exp x ->
    Core.Array sha a -> Core.Array shb b ->
    Core.Array shc c -> Core.Array shd d) ->
   Param.T p x ->
   Array p sha a -> Array p shb b -> Array p shc c -> Array p shd d
withExp3 f x a b c =
   Sym.runHull
      (Sym.mapHullWithExp f (Sym.expParam x) (Sym.arrayHull a)
         <*> Sym.arrayHull b
         <*> Sym.arrayHull c)