{-# LANGUAGE Rank2Types #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE TypeOperators #-}
{- |
Symbolic operations on parameterized arrays.

Most of the exported functions are re-exported unchanged from
"Data.Array.Knead.Parameterized.Private".
'backpermute' and 'zipWith' are defined in this module
on top of those primitives.
-}
module Data.Array.Knead.Parameterized.Symbolic (
   Array,
   Exp,
   Sym.extendParameter,
   (Sym.!),
   Sym.fill,
   gather,
   backpermute,
   Sym.id,
   Sym.map,
   zipWith,
   Sym.fold1,
   Sym.fold1All,
   ) where

import qualified Data.Array.Knead.Parameterized.Private as Sym
import qualified Data.Array.Knead.Simple.Symbolic as Core
import Data.Array.Knead.Parameterized.Private (Array, gather, )

import qualified Data.Array.Knead.Index.Nested.Shape as Shape
import qualified Data.Array.Knead.Parameter as Param
import qualified Data.Array.Knead.Expression as Expr
import Data.Array.Knead.Expression (Exp, )

import qualified LLVM.Extra.Multi.Value.Memory as MultiValueMemory
import qualified LLVM.Extra.Multi.Value as MultiValue

import Foreign.Storable (Storable, )

import Prelude (uncurry, ($), )


{-
Disabled definition, kept for reference (not compiled).

fromScalar ::
   (Storable a, MultiValueMemory.C a, MultiValue.C a) =>
   Param.T p a -> Array p Z a
fromScalar = Sym.fill (return Z)
-}


{- |
Reorder the elements of an array by an index mapping.

The shape of the result is given by the parameter @sh1@.
An index array is built by mapping @f@ over @Sym.id sh1@.
The source array is then read with 'gather' at those indices.

Assuming @Sym.id sh1@ holds every index at its own position
(TODO confirm against "Data.Array.Knead.Parameterized.Private"),
the result element at index @ix@ is the source element at index @f ix@.
-}
backpermute ::
   (Shape.C sh0, Shape.Index sh0 ~ ix0,
    Shape.C sh1, Shape.Index sh1 ~ ix1,
    Storable sh1, MultiValueMemory.C sh1,
    MultiValue.C a) =>
   Param.T p sh1 ->
   (Exp ix1 -> Exp ix0) ->
   Array p sh0 a ->
   Array p sh1 a
backpermute sh1 f =
   gather sourceIndices
   where
      -- For every target position, the source index it reads from.
      sourceIndices = Core.map f (Sym.id sh1)

{-
Alternative implementation, disabled (not compiled).

_backpermute sh1 f =
   paramArray (flip Core.backpermute f) (Shape.tunnel sh1)
-}


{- |
Combine two arrays of the same shape element by element.

The combining function takes three arguments:

* an expression of the parameter type @d@, presumably the value of the
  parameter supplied to 'Sym.map';

* the corresponding element of the first array;

* the corresponding element of the second array.

Internally, both arrays are zipped into one array of pairs with
'Core.zip'. That array is traversed with 'Sym.map', and each pair is
split again with 'Expr.unzip'.
-}
zipWith ::
   (Shape.C sh, MultiValueMemory.C d, Storable d) =>
   (Exp d -> Exp a -> Exp b -> Exp c) ->
   Param.T p d ->
   Array p sh a ->
   Array p sh b ->
   Array p sh c
zipWith f param arrA arrB =
   Sym.map combine param $ Core.zip arrA arrB
   where
      combine paramExp pairExp = uncurry (f paramExp) (Expr.unzip pairExp)