{-# LANGUAGE ExistentialQuantification #-}
{-# LANGUAGE GADTs #-}
module Data.Array.Knead.Parameterized.Private where

import qualified Data.Array.Knead.Simple.Symbolic as Core
import qualified Data.Array.Knead.Shape as Shape
import qualified Data.Array.Knead.Expression as Expr
import Data.Array.Knead.Expression (Exp, )

import qualified LLVM.DSL.Parameter as Param
import qualified LLVM.Extra.Multi.Value as MultiValue
import qualified LLVM.Extra.Marshal as Marshal

import Control.Monad (liftM2)
import Control.Applicative (Applicative (pure, (<*>)), )
import Data.Tuple.Strict (zipPair)

import Prelude2010 hiding (id, map, zipWith, replicate)
import Prelude ()


-- in principle we could define Array in terms of Hull and Core.Array
-- | A symbolic array of shape @sh@ with elements of type @a@ whose
-- definition depends on a parameter of type @p@.
--
-- The marshalled parameter representation (@parameter@) and the context
-- type (@context@) are existentially hidden, so arrays that obtain their
-- parameters in different ways still share one type.
data Array p sh a =
   forall parameter context.
   (Marshal.MV parameter) =>
   Array {
      core :: MultiValue.T parameter -> Core.Array sh a,
         -- ^ Build the core array from the marshalled parameter value.
      createContext :: p -> IO (context, parameter),
         -- ^ Derive a context and the marshalled parameter from @p@.
      deleteContext :: context -> IO ()
         -- ^ Consume a context produced by 'createContext'
         -- (presumably to release it -- confirm with callers).
   }

-- | Lifting of core-array transformations to parameterized arrays.
instance Core.C (Array p) where
   -- A parameter-independent array needs neither context nor parameter.
   lift0 arr = Array (const arr) (createPlain (const ())) deletePlain
   -- Unary transformations only touch the core; context handling is kept.
   lift1 f (Array arr create delete) = Array (f . arr) create delete
   -- Binary transformations pair the parameters and the contexts of
   -- both operands.
   lift2 f (Array arrA createA deleteA) (Array arrB createB deleteB) =
      Array
         (MultiValue.uncurry $ \paramA paramB -> f (arrA paramA) (arrB paramB))
         (combineCreate createA createB)
         (combineDelete deleteA deleteB)


-- | Select the element at an index supplied by the parameter and present
-- it as a scalar-shaped array (via 'Core.fromScalar').
(!) ::
   (Shape.C sh, Shape.Index sh ~ ix, Marshal.MV ix, Shape.Scalar z) =>
   Array p sh a -> Param.T p ix -> Array p z a
(!) arr pix =
   runHull $
   mapHullWithExp
      (\ix carr -> Core.fromScalar $ carr Core.! ix)
      (expParam pix)
      (arrayHull arr)

-- | An array whose shape and fill value are both read from the parameter.
-- No context is allocated.
fill ::
   (Shape.C sh, Marshal.MV sh, Marshal.MV a) =>
   Param.T p sh -> Param.T p a -> Array p sh a
fill sh a =
   Shape.paramWith sh $ \getSh valueSh ->
   Param.withMulti a $ \getA valueA ->
   Array
      (MultiValue.uncurry $ \vsh va ->
         Core.fill (valueSh vsh) (Expr.lift0 $ valueA va))
      (createPlain $ \p -> (getSh p, getA p))
      deletePlain

-- | 'Core.gather' specialised to parameterized arrays.
gather ::
   (Shape.C sh0, Shape.Index sh0 ~ ix0,
    Shape.C sh1, MultiValue.C a) =>
   Array p sh1 ix0 ->
   Array p sh0 a ->
   Array p sh1 a
gather = Core.gather

-- | The array built by 'Core.id' for a shape read from the parameter.
-- Its element type is the index type of that shape.
id ::
   (Shape.C sh, Marshal.MV sh, Shape.Index sh ~ ix) =>
   Param.T p sh -> Array p sh ix
id sh =
   Shape.paramWith sh $ \getSh valueSh ->
   Array (Core.id . valueSh) (createPlain getSh) deletePlain

-- | Like 'Core.map', but the element function additionally receives an
-- expression for the parameter value @c@ as its first argument.
map ::
   (Shape.C sh, Marshal.MV c) =>
   (Exp c -> Exp a -> Exp b) ->
   Param.T p c -> Array p sh a -> Array p sh b
map = lift Core.map

-- | Like 'Core.mapWithIndex', but the element function additionally
-- receives an expression for the parameter value @c@ as its first argument.
mapWithIndex ::
   (Shape.C sh, Marshal.MV c, Shape.Index sh ~ ix) =>
   (Exp c -> Exp ix -> Exp a -> Exp b) ->
   Param.T p c -> Array p sh a -> Array p sh b
mapWithIndex = lift Core.mapWithIndex

-- | Like 'Core.fold1': reduce an @(sh0, sh1)@-shaped array to shape @sh0@.
-- The combining function additionally receives an expression for the
-- parameter value @c@.
fold1 ::
   (Shape.C sh0, Shape.C sh1, Marshal.MV c, MultiValue.C a) =>
   (Exp c -> Exp a -> Exp a -> Exp a) ->
   Param.T p c -> Array p (sh0, sh1) a -> Array p sh0 a
fold1 = lift Core.fold1

-- | Fold all elements with 'Core.fold1All' and wrap the result into a
-- scalar-shaped array using @'Core.fill' 'Shape.scalar'@.
-- The combining function additionally receives an expression for the
-- parameter value @c@.
fold1All ::
   (Shape.C sh, Shape.Scalar z, Marshal.MV c, MultiValue.C a) =>
   (Exp c -> Exp a -> Exp a -> Exp a) ->
   Param.T p c -> Array p sh a -> Array p z a
fold1All = lift (\p -> Core.fill Shape.scalar . Core.fold1All p)

-- | Shared implementation of the parameterized combinators.
--
-- The parameter @c@ is turned into an expression and passed to @f@. The
-- result of @f@ is then given to the core combinator @g@, together with
-- the core of @arr@.
lift ::
   (Shape.C sh0, Shape.C sh1, Marshal.MV c) =>
   (f -> Core.Array sh0 a -> Core.Array sh1 b) ->
   (Exp c -> f) -> Param.T p c -> Array p sh0 a -> Array p sh1 b
lift g f c arr =
   runHull $ mapHullWithExp (\cexp -> g (f cexp)) (expParam c) (arrayHull arr)



-- | Like 'Array', but the parameter-dependent payload may be any type @a@,
-- not only a 'Core.Array'.
-- The 'Functor' and 'Applicative' instances combine several such values,
-- pairing their parameters and contexts.
data Hull p a =
   forall parameter context.
   (Marshal.MV parameter) =>
   Hull {
      hullCore :: MultiValue.T parameter -> a,
         -- ^ Build the payload from the marshalled parameter value.
      hullCreateContext :: p -> IO (context, parameter),
         -- ^ Derive a context and the marshalled parameter from @p@.
      hullDeleteContext :: context -> IO ()
         -- ^ Consume a context produced by 'hullCreateContext'.
   }

instance Functor (Hull p) where
   fmap f (Hull arr create delete) = Hull (f . arr) create delete

instance Applicative (Hull p) where
   -- A constant payload needs neither context nor parameter; reuse the
   -- same plain helpers as 'Core.lift0' for 'Array'.
   pure a = Hull (const a) (createPlain (const ())) deletePlain
   Hull arrA createA deleteA <*> Hull arrB createB deleteB =
      Hull
         (MultiValue.uncurry $ \a b -> arrA a $ arrB b)
         (combineCreate createA createB)
         (combineDelete deleteA deleteB)

{- |
Equivalent to @liftA2 f (expHull tunnel)@ but saves us an empty context.
The tunnelled value is appended to the hull's parameter instead of
creating a separate @()@ context for it.
-}
mapHullWithExp ::
   (Exp sl -> a -> b) -> Param.Tunnel p sl -> Hull p a -> Hull p b
mapHullWithExp f tunnel (Hull arr create delete) =
   case tunnel of
      Param.Tunnel getSl valueSl ->
         Hull
            (MultiValue.uncurry $ \arrp sl ->
               f (Expr.lift0 $ valueSl sl) $ arr arrp)
            (\p -> do
               (ctx, param) <- create p
               return (ctx, (param, getSl p)))
            delete

-- | Expose the value transported by a tunnel as an expression.
-- No context is allocated.
expHull :: Param.Tunnel p sl -> Hull p (Exp sl)
expHull tunnel =
   case tunnel of
      Param.Tunnel getSl valueSl ->
         Hull (Expr.lift0 . valueSl) (createPlain getSl) deletePlain

-- | View an 'Array' as a 'Hull' whose payload is the core array.
arrayHull :: Array p sh a -> Hull p (Core.Array sh a)
arrayHull (Array arr create delete) = Hull arr create delete

-- | Inverse of 'arrayHull': turn a hull carrying a core array into an 'Array'.
runHull :: Hull p (Core.Array sh a) -> Array p sh a
runHull (Hull arr create delete) = Array arr create delete

-- | Adapt a hull to another parameter type by precomposing @f@ on context
-- creation.
extendHull :: (q -> p) -> Hull p a -> Hull q a
extendHull f (Hull arr create delete) = Hull arr (create . f) delete

-- | Build a 'Param.Tunnel' for a marshallable parameter, using
-- 'MultiValue.cons' as the conversion.
expParam :: (Marshal.MV a) => Param.T p a -> Param.Tunnel p a
expParam = Param.tunnel MultiValue.cons


-- | Creation function for values that need no context: the context is @()@
-- and the parameter is computed purely from the input.
createPlain :: (Monad m) => (p -> pl) -> p -> m ((), pl)
createPlain f p = return ((), f p)

-- | Deletion function matching 'createPlain'; there is nothing to release.
deletePlain :: (Monad m) => () -> m ()
deletePlain () = return ()


-- | Run two creation functions on the same input, first @createA@ and then
-- @createB@, and pair their contexts and their parameters.
--
-- NOTE(review): if @createB@ fails after @createA@ succeeded, the context
-- of @createA@ is not deleted here. Fixing that needs @m ~ IO@ (e.g.
-- 'Control.Exception.bracketOnError'), which would narrow this signature.
{-# INLINE combineCreate #-}
combineCreate ::
   Monad m =>
   (p -> m (ctxA, paramA)) -> (p -> m (ctxB, paramB)) ->
   p -> m ((ctxA, ctxB), (paramA, paramB))
combineCreate createA createB p =
   liftM2 zipPair (createA p) (createB p)

-- | Delete a pair of contexts, first @ctxA@ and then @ctxB@.
--
-- NOTE(review): @deleteB@ is skipped if @deleteA@ fails; see the note on
-- 'combineCreate'.
{-# INLINE combineDelete #-}
combineDelete ::
   Monad m => (ctxA -> m ()) -> (ctxB -> m ()) -> (ctxA, ctxB) -> m ()
combineDelete deleteA deleteB (ctxA, ctxB) = do
   deleteA ctxA
   deleteB ctxB


-- | Adapt an array to another parameter type by precomposing @f@ on context
-- creation.
extendParameter :: (q -> p) -> Array p sh a -> Array q sh a
extendParameter f (Array arr create delete) = Array arr (create . f) delete