module Data.Array.Knead.Parameterized.Private where
import qualified Data.Array.Knead.Simple.Symbolic as Core
import qualified Data.Array.Knead.Shape as Shape
import qualified Data.Array.Knead.Expression as Expr
import Data.Array.Knead.Expression (Exp, )
import qualified LLVM.DSL.Parameter as Param
import qualified LLVM.Extra.Multi.Value as MultiValue
import qualified LLVM.Extra.Marshal as Marshal
import Control.Monad (liftM2)
import Control.Applicative (Applicative (pure, (<*>)), )
import Data.Tuple.Strict (zipPair)
import Prelude2010 hiding (id, map, zipWith, replicate)
import Prelude ()
-- | A symbolic array whose construction depends on a parameter of type @p@.
--
-- The concrete parameter representation (@parameter@) and the auxiliary
-- @context@ are existentially hidden. As a result, arrays built from
-- differently structured parameters all share the single type
-- @Array p sh a@.
data Array p sh a =
  forall parameter context.
  (Marshal.MV parameter) =>
  Array {
    -- | Builds the symbolic core array from the parameter value.
    core :: MultiValue.T parameter -> Core.Array sh a,
    -- | Derives a context and the (marshallable) parameter value from @p@.
    createContext :: p -> IO (context, parameter),
    -- | Cleanup action for a context produced by 'createContext'.
    -- NOTE(review): presumably releases resources held by the context;
    -- confirm with the callers that run it.
    deleteContext :: context -> IO ()
  }
-- | Parameterized arrays support every generic symbolic array operation.
--
-- A lifted operation is applied to the symbolic core. For binary
-- operations, the parameters and the create/delete actions of both
-- operands are paired, using the 'Hull' plumbing.
instance Core.C (Array p) where
  -- A plain symbolic array needs no parameter: context and parameter are @()@.
  lift0 symArr = Array (\_ -> symArr) (createPlain (\_ -> ())) deletePlain
  lift1 f = runHull . fmap f . arrayHull
  lift2 f arrA arrB = runHull (fmap f (arrayHull arrA) <*> arrayHull arrB)
-- | Select one element of a parameterized array, producing a
-- scalar-shaped array.
--
-- The index is itself a parameter. It is read from @p@ when the context
-- is created and reaches the core as an expression through 'expParam'.
(!) ::
  (Shape.C sh, Shape.Index sh ~ ix, Marshal.MV ix,
   Shape.Scalar z) =>
  Array p sh a -> Param.T p ix -> Array p z a
(!) arr pix =
  runHull
    (mapHullWithExp
      (\ix -> Core.fromScalar . (Core.! ix))
      (expParam pix)
      (arrayHull arr))
-- | An array of the parameterized shape @sh@, filled via 'Core.fill' with
-- the parameterized value @a@.
--
-- Both the shape and the value are extracted from @p@ when the context is
-- created. The context itself is @()@, so deletion is a no-op.
fill ::
  (Shape.C sh, Marshal.MV sh, Marshal.MV a) =>
  Param.T p sh -> Param.T p a -> Array p sh a
fill sh a =
  Shape.paramWith sh $ \getShape shapeValue ->
  Param.withMulti a $ \getElem elemValue ->
    let symbolic =
          MultiValue.uncurry $ \shapeParam elemParam ->
            Core.fill (shapeValue shapeParam) (Expr.lift0 (elemValue elemParam))
        extract p = (getShape p, getElem p)
    in  Array symbolic (createPlain extract) deletePlain
-- | Build an array of shape @sh1@ by looking up, in @source@, the @sh0@
-- indices stored in @indices@.
--
-- This is 'Core.gather', used at the parameterized 'Array' instance.
gather ::
  (Shape.C sh0, Shape.Index sh0 ~ ix0,
   Shape.C sh1, MultiValue.C a) =>
  Array p sh1 ix0 ->
  Array p sh0 a ->
  Array p sh1 a
gather indices source = Core.gather indices source
-- | An array of the parameterized shape @sh@ whose elements have the
-- shape's index type, built with 'Core.id'.
-- NOTE(review): presumably each element is its own index; confirm in Core.
id ::
  (Shape.C sh, Marshal.MV sh, Shape.Index sh ~ ix) =>
  Param.T p sh -> Array p sh ix
id sh =
  Shape.paramWith sh $ \extractShape shapeValue ->
    Array (\param -> Core.id (shapeValue param)) (createPlain extractShape) deletePlain
-- | Element-wise map whose function additionally receives the parameter
-- @c@ as an expression.
map ::
  (Shape.C sh, Marshal.MV c) =>
  (Exp c -> Exp a -> Exp b) ->
  Param.T p c -> Array p sh a -> Array p sh b
map op param arr = lift Core.map op param arr
-- | Like 'map', but the function also receives an expression of the
-- shape's index type (presumably the element's position).
mapWithIndex ::
  (Shape.C sh, Marshal.MV c, Shape.Index sh ~ ix) =>
  (Exp c -> Exp ix -> Exp a -> Exp b) ->
  Param.T p c -> Array p sh a -> Array p sh b
mapWithIndex op param arr = lift Core.mapWithIndex op param arr
-- | Reduce away the @sh1@ component of a @(sh0, sh1)@-shaped array with a
-- binary operator. The operator takes no initial value and receives the
-- parameter @c@ as an expression.
fold1 ::
  (Shape.C sh0, Shape.C sh1, Marshal.MV c, MultiValue.C a) =>
  (Exp c -> Exp a -> Exp a -> Exp a) ->
  Param.T p c -> Array p (sh0, sh1) a -> Array p sh0 a
fold1 op param arr = lift Core.fold1 op param arr
-- | Reduce all elements with a binary operator that receives the parameter
-- @c@. The result is wrapped in a scalar-shaped array via
-- @Core.fill Shape.scalar@.
fold1All ::
  (Shape.C sh, Shape.Scalar z, Marshal.MV c, MultiValue.C a) =>
  (Exp c -> Exp a -> Exp a -> Exp a) ->
  Param.T p c -> Array p sh a -> Array p z a
fold1All =
  lift (\op symArr -> Core.fill Shape.scalar (Core.fold1All op symArr))
-- | Shared helper for parameterized array operations.
--
-- The parameter @c@ is tunnelled into the core as an expression. That
-- expression specializes @f@, and the symbolic transformation @g@ is then
-- applied with the specialized @f@ to the core array.
lift ::
  (Shape.C sh0, Shape.C sh1, Marshal.MV c) =>
  (f -> Core.Array sh0 a -> Core.Array sh1 b) ->
  (Exp c -> f) ->
  Param.T p c -> Array p sh0 a -> Array p sh1 b
lift g f c arr =
  runHull (mapHullWithExp (g . f) (expParam c) (arrayHull arr))
-- | The parameter plumbing of 'Array', generalized to an arbitrary result.
--
-- 'hullCore' may compute any value @a@ from the parameter. This lets the
-- plumbing be combined through the 'Functor' and 'Applicative' instances
-- before it is turned back into an 'Array' with 'runHull'.
data Hull p a =
  forall parameter context.
  (Marshal.MV parameter) =>
  Hull {
    -- | Computes the result from the parameter value.
    hullCore :: MultiValue.T parameter -> a,
    -- | Derives a context and the (marshallable) parameter value from @p@.
    hullCreateContext :: p -> IO (context, parameter),
    -- | Cleanup action for a context produced by 'hullCreateContext'.
    hullDeleteContext :: context -> IO ()
  }
-- | Mapping over a 'Hull' transforms only its core result. Parameter and
-- context handling are left unchanged.
instance Functor (Hull p) where
  fmap f (Hull build create delete) = Hull (\param -> f (build param)) create delete
-- | 'pure' carries neither a parameter nor a context (both are @()@).
-- '<*>' pairs the parameters and the contexts of its two operands, creating
-- and deleting the left operand's context first.
instance Applicative (Hull p) where
  pure x = Hull (\_ -> x) (createPlain (\_ -> ())) return
  Hull buildF createF deleteF <*> Hull buildX createX deleteX =
    Hull
      (MultiValue.uncurry (\paramF paramX -> buildF paramF (buildX paramX)))
      (combineCreate createF createX)
      (combineDelete deleteF deleteX)
-- | Combine a hull with a tunnelled parameter value.
--
-- During context creation, the tunnel extracts its value from @p@ and
-- stores it next to the hull's own parameter. The core receives that value
-- as an expression (via 'Expr.lift0') together with the hull's original
-- result. The context, and therefore its deletion, is unchanged.
mapHullWithExp ::
  (Exp sl -> a -> b) ->
  Param.Tunnel p sl -> Hull p a -> Hull p b
mapHullWithExp f (Param.Tunnel extract toValue) (Hull build create delete) =
  Hull combinedBuild combinedCreate delete
  where
    combinedBuild =
      MultiValue.uncurry $ \param sl ->
        f (Expr.lift0 (toValue sl)) (build param)
    combinedCreate p = do
      (ctx, param) <- create p
      return (ctx, (param, extract p))
-- | A hull whose result is the tunnelled value as an expression. The value
-- is extracted from @p@ at context creation, and the context is @()@.
expHull :: Param.Tunnel p sl -> Hull p (Exp sl)
expHull (Param.Tunnel extract toValue) =
  Hull (\param -> Expr.lift0 (toValue param)) (createPlain extract) return
-- | View an 'Array' as a 'Hull' producing its symbolic core array.
arrayHull :: Array p sh a -> Hull p (Core.Array sh a)
arrayHull Array {core = build, createContext = create, deleteContext = delete} =
  Hull {hullCore = build, hullCreateContext = create, hullDeleteContext = delete}
-- | Turn a 'Hull' that produces a symbolic array back into an 'Array'.
runHull :: Hull p (Core.Array sh a) -> Array p sh a
runHull Hull {hullCore = build, hullCreateContext = create, hullDeleteContext = delete} =
  Array {core = build, createContext = create, deleteContext = delete}
-- | Adapt a hull to a larger parameter type @q@. The given projection runs
-- before the original context creation.
extendHull :: (q -> p) -> Hull p a -> Hull q a
extendHull f (Hull build create delete) = Hull build (\q -> create (f q)) delete
-- | Build a parameter tunnel for a marshallable value with 'Param.tunnel',
-- using 'MultiValue.cons' for the value side.
expParam :: (Marshal.MV a) => Param.T p a -> Param.Tunnel p a
expParam param = Param.tunnel MultiValue.cons param
-- | Context creation for values that need no context. It returns @()@ as
-- the context, paired with the extracted parameter.
createPlain :: (Monad m) => (p -> pl) -> p -> m ((), pl)
createPlain extract = \p -> return ((), extract p)
-- | Counterpart of 'createPlain'. The unit context needs no cleanup, so
-- this only forces the @()@ and returns.
deletePlain :: (Monad m) => () -> m ()
deletePlain unit = case unit of () -> return ()
-- | Run two context creators on the same @p@, first @createA@ and then
-- @createB@. The results are paired as (contexts, parameters) with the
-- strict 'zipPair'.
--
-- NOTE(review): if @createB@ fails after @createA@ has succeeded, the
-- context from @createA@ is never handed to a delete action. Guarding
-- against that would require exception handling in 'IO' rather than an
-- arbitrary 'Monad'.
combineCreate ::
  Monad m =>
  (p -> m (ctxA, paramA)) -> (p -> m (ctxB, paramB)) ->
  p -> m ((ctxA, ctxB), (paramA, paramB))
combineCreate createA createB p = do
  resultA <- createA p
  resultB <- createB p
  return (zipPair resultA resultB)
-- | Delete a pair of contexts, first with @deleteA@ and then with @deleteB@.
-- This is the same order in which 'combineCreate' creates them. If
-- @deleteA@ short-circuits the monad (e.g. throws in 'IO'), @deleteB@ does
-- not run.
combineDelete ::
  Monad m =>
  (ctxA -> m ()) -> (ctxB -> m ()) -> (ctxA, ctxB) -> m ()
combineDelete deleteA deleteB (ctxA, ctxB) = deleteA ctxA >> deleteB ctxB
-- | Adapt an array to a larger parameter type @q@. The given projection runs
-- before the original context creation; core and deletion are unchanged.
extendParameter ::
  (q -> p) -> Array p sh a -> Array q sh a
extendParameter f (Array build create delete) =
  Array build (\q -> create (f q)) delete