{-# LANGUAGE ExistentialQuantification #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE TypeOperators #-}
-- | Internal representation of parameterized symbolic arrays.
--
-- A parameterized array combines a symbolic core array with the IO actions
-- that turn a user-level parameter @p@ into a runtime parameter value (and a
-- context that is later handed back to a deleter).
module Data.Array.Knead.Parameterized.Private where

import qualified Data.Array.Knead.Simple.Symbolic as Core
import qualified Data.Array.Knead.Parameter as Param
import qualified Data.Array.Knead.Index.Nested.Shape as Shape
import qualified Data.Array.Knead.Expression as Expr
import Data.Array.Knead.Expression (Exp, )

import qualified LLVM.Extra.Multi.Value.Memory as MultiValueMemory
import qualified LLVM.Extra.Multi.Value as MultiValue

import Foreign.Storable (Storable, )

import Prelude hiding (id, map, zipWith, replicate, )


-- | A symbolic array of shape @sh@ and element type @a@ whose definition
-- depends on a parameter of type @p@.
--
-- The runtime parameter type @parameter@ and the context type @context@ are
-- existentially hidden; the only requirement is that @parameter@ can be
-- stored in memory ('Storable', 'MultiValueMemory.C').
data Array p sh a =
   forall parameter context.
   (Storable parameter, MultiValueMemory.C parameter) =>
   Array {
      -- | Builds the core array from the symbolic runtime parameter.
      core :: MultiValue.T parameter -> Core.Array sh a,
      -- | Computes a context and the runtime parameter from @p@.
      createContext :: p -> IO (context, parameter),
      -- | Consumes the context produced by 'createContext'
      -- (presumably to release it; the actions are supplied by the builders).
      deleteContext :: context -> IO ()
   }

instance Core.C (Array p) where
   -- A plain core array needs no parameter: both the runtime parameter and
   -- the context are @()@.
   lift0 arr = Array (const arr) (createPlain (const ())) deletePlain
   -- Unary operations only transform the core array; the parameter handling
   -- of the operand is reused unchanged.
   lift1 f (Array arr create delete) = Array (f . arr) create delete
   -- Binary operations pair the runtime parameters of both operands.
   -- 'MultiValue.unzip' splits the pair again for the two core arrays,
   -- and context creation/deletion is delegated to both operands.
   lift2 f (Array arrA createA deleteA) (Array arrB createB deleteB) =
      Array
         (\p ->
            case MultiValue.unzip p of
               (paramA, paramB) -> f (arrA paramA) (arrB paramB))
         (combineCreate createA createB)
         (combineDelete deleteA deleteB)

-- Disabled alternative of '(!)' that tunnels a whole shape parameter
-- via 'Shape.tunnel'; kept for reference.
{-
(!) ::
   (Shape.C sh) =>
   Array p sh a -> Param.T p sh -> Array p z a
(!) arr pix =
   paramArray
      (\ix carr -> Core.fromScalar $ carr Core.! ix)
      (Shape.tunnel pix)
      arr
-}

-- | Select the element at the index given by the parameter @pix@ and return
-- it as a scalar-shaped array (built with 'Core.fromScalar').
--
-- The index is tunnelled into the array via @'Param.tunnel' 'MultiValue.cons'@,
-- so it becomes part of the hidden runtime parameter.
(!) ::
   (Shape.C sh, Shape.Index sh ~ ix,
    MultiValue.C ix, Storable ix, MultiValueMemory.C ix,
    Shape.Scalar z) =>
   Array p sh a -> Param.T p ix -> Array p z a
(!) arr pix =
   paramArray
      (\ix carr -> Core.fromScalar $ carr Core.! ix)
      (Param.tunnel MultiValue.cons pix)
      arr

-- | An array whose shape and fill value are both taken from the parameter.
--
-- At context creation the runtime parameter is the pair
-- @(getSh p, getA p)@; the core array splits it again and passes the shape
-- and the lifted value to 'Core.fill'. No context is needed
-- ('createPlain' / 'deletePlain').
fill ::
   (Shape.C sh, Storable sh, MultiValueMemory.C sh,
    MultiValue.C a, Storable a, MultiValueMemory.C a) =>
   Param.T p sh -> Param.T p a -> Array p sh a
fill sh a =
   Shape.paramWith sh $ \getSh valueSh ->
   Param.withMulti a $ \getA valueA ->
   Array
      (\p ->
         case MultiValue.unzip p of
            (vsh, va) -> Core.fill (valueSh vsh) (Expr.lift0 $ valueA va))
      (createPlain $ \p -> (getSh p, getA p))
      deletePlain

-- | Specialisation of the class method 'Core.gather' to parameterized
-- arrays: the first array supplies, for each target position, an index
-- into the second array.
gather ::
   (Shape.C sh0, Shape.Index sh0 ~ ix0,
    Shape.C sh1, MultiValue.C a) =>
   Array p sh1 ix0 ->
   Array p sh0 a ->
   Array p sh1 a
gather = Core.gather

-- | Array of the parameter-given shape built with 'Core.id'; its element
-- type is the index type of the shape (presumably each element is its own
-- index — 'Core.id' is defined elsewhere).
id ::
   (Shape.C sh, Storable sh, MultiValueMemory.C sh, Shape.Index sh ~ ix) =>
   Param.T p sh -> Array p sh ix
id sh =
   Shape.paramWith sh $ \getSh valueSh ->
   Array (Core.id . valueSh) (createPlain getSh) deletePlain

-- | 'Core.map' where the element function additionally receives the value
-- of the parameter @c@ as its first argument.
map ::
   (Shape.C sh, MultiValueMemory.C c, Storable c) =>
   (Exp c -> Exp a -> Exp b) ->
   Param.T p c -> Array p sh a -> Array p sh b
map = lift Core.map

-- | 'Core.mapWithIndex' where the function additionally receives the value
-- of the parameter @c@ as its first argument.
mapWithIndex ::
   (Shape.C sh, MultiValueMemory.C c, Storable c, Shape.Index sh ~ ix) =>
   (Exp c -> Exp ix -> Exp a -> Exp b) ->
   Param.T p c -> Array p sh a -> Array p sh b
mapWithIndex = lift Core.mapWithIndex

-- | 'Core.fold1' over the @sh1@ part of a @(sh0, sh1)@ shape, with the
-- combining function receiving the value of the parameter @c@ first.
fold1 ::
   (Shape.C sh0, Shape.C sh1, MultiValueMemory.C c, Storable c,
    MultiValue.C a) =>
   (Exp c -> Exp a -> Exp a -> Exp a) ->
   Param.T p c -> Array p (sh0, sh1) a -> Array p sh0 a
fold1 = lift Core.fold1

-- | 'Core.fold1All' to a scalar-shaped array, with the combining function
-- receiving the value of the parameter @c@ first.
fold1All ::
   (Shape.C sh, Shape.Scalar z, MultiValueMemory.C c, Storable c,
    MultiValue.C a) =>
   (Exp c -> Exp a -> Exp a -> Exp a) ->
   Param.T p c -> Array p sh a -> Array p z a
fold1All = lift Core.fold1All

-- | Turn a core combinator @g@ into a parameterized one.
--
-- The parameter @c@ is tunnelled into the array; its symbolic value is fed
-- to @f@, whose result becomes the first argument of @g@.
lift ::
   (Shape.C sh0, Shape.C sh1, MultiValueMemory.C c, Storable c) =>
   (f -> Core.Array sh0 a -> Core.Array sh1 b) ->
   (Exp c -> f) ->
   Param.T p c -> Array p sh0 a -> Array p sh1 b
lift g f c arr =
   paramArray (\cexp -> g (f cexp)) (Param.tunnel MultiValue.cons c) arr

{- Could be generalized to nested indices.
foldSelected1 ::
   (Fold.C sl, MultiValue.C a) =>
   (Exp a -> Exp a -> Exp a) ->
   Param.T p (Linear.Shape sl) ->
   Array p (Linear.Shape (Fold.FullShape sl)) a ->
   Array p (Linear.Shape (Fold.FoldShape sl)) a
foldSelected1 f esl arr =
   paramArray (Core.foldSelected1 f) (Fold.tunnel esl) arr
-}

-- | Extend a parameterized array by one more value derived from the
-- parameter @p@, and rebuild its core with access to that value.
--
-- The tunnel's first component runs at context-creation time on @p@; its
-- result is stored next to the inner array's runtime parameter. In the core,
-- the tunnel's second component converts the stored value back, and it is
-- handed to @f@ as an expression together with the inner core array.
-- The deleter of the inner array is reused unchanged.
paramArray ::
   (Exp sl -> Core.Array shb b -> Core.Array sha a) ->
   Param.Tunnel p sl ->
   Array p shb b -> Array p sha a
paramArray f tunnel (Array arr create delete) =
   case tunnel of
      Param.Tunnel getSl valueSl ->
         let buildCore p =
                case MultiValue.unzip p of
                   (arrp, sl) -> f (Expr.lift0 (valueSl sl)) (arr arrp)
             createWith p =
                create p >>= \(ctx, param) -> return (ctx, (param, getSl p))
         in  Array buildCore createWith delete

-- | Context creator for arrays without a real context: the context is @()@
-- and the runtime parameter is computed purely from @p@.
createPlain :: (Monad m) => (p -> pl) -> p -> m ((), pl)
createPlain toParam = \p -> return ((), toParam p)

-- | Counterpart to 'createPlain': accepts the @()@ context and does nothing.
deletePlain :: (Monad m) => () -> m ()
deletePlain unit =
   case unit of
      () -> return ()

-- | Run two context creators on the same parameter, first @createA@ and
-- then @createB@, and pair their contexts and runtime parameters.
--
-- NOTE(review): there is no exception handling here — if @createB@ fails
-- in 'IO', the context already produced by @createA@ is not passed to any
-- deleter by this function. Making this exception-safe would require
-- specialising to 'IO' (e.g. 'Control.Exception.onException').
combineCreate ::
   Monad m =>
   (p -> m (ctxA, paramA)) ->
   (p -> m (ctxB, paramB)) ->
   p -> m ((ctxA, ctxB), (paramA, paramB))
combineCreate createA createB p =
   createA p >>= \(ctxA, paramA) ->
   createB p >>= \(ctxB, paramB) ->
   return ((ctxA, ctxB), (paramA, paramB))

-- | Delete a pair of contexts, running @deleteA@ before @deleteB@.
--
-- NOTE(review): the two actions are merely sequenced, so in 'IO' an
-- exception from @deleteA@ skips @deleteB@.
combineDelete ::
   Monad m =>
   (ctxA -> m ()) ->
   (ctxB -> m ()) ->
   (ctxA, ctxB) -> m ()
combineDelete deleteA deleteB (ctxA, ctxB) =
   deleteA ctxA >> deleteB ctxB

-- | Adapt a parameterized array to a different outer parameter type @q@:
-- the given function maps @q@ to the original parameter before the
-- original context creator runs. Core and deleter stay the same.
extendParameter :: (q -> p) -> Array p sh a -> Array q sh a
extendParameter narrow arrP =
   case arrP of
      Array build create delete ->
         Array build (\q -> create (narrow q)) delete