module Data.Array.Knead.Parameterized.Render (run) where
import qualified Data.Array.Knead.Parameterized.Physical as PhysP
import qualified Data.Array.Knead.Parameterized.Private as Sym
import qualified Data.Array.Knead.Simple.Physical as Phys
import qualified Data.Array.Knead.Simple.Private as Core
import qualified Data.Array.Knead.Parameter as Param
import qualified Data.Array.Knead.Index.Nested.Shape as Shape
import Data.Array.Knead.Expression (Exp, )
import qualified LLVM.Extra.Multi.Value.Memory as MultiValueMemory
import Foreign.Storable (Storable, )
import Control.Arrow (arr, )
import Control.Applicative (liftA2, liftA3, pure, (<*>), )
import Data.Tuple.HT (fst3, snd3, thd3, )
-- | Symbolic types that 'run' can turn into plain Haskell values.
--
-- For example: an array, a scalar expression, or a function whose
-- argument is an 'Argument' and whose result is again an instance of 'C'.
class C f where
   -- | The plain Haskell counterpart of the symbolic type @f@.
   type Plain f
   -- | Given a hull depending on a parameter of type @p@, produce (in 'IO')
   -- a function from that parameter to the plain value.
   build :: Sym.Hull p f -> IO (p -> Plain f)
-- | Array results: the plain form is an 'IO' action that yields a
-- physical array.  The hull is unwrapped with 'Sym.runHull' and the
-- outcome is handed to 'PhysP.render'.
instance
   (MultiValueMemory.C sh, Storable sh, Shape.C sh,
    MultiValueMemory.C a, Storable a) =>
      C (Core.Array sh a) where
   type Plain (Core.Array sh a) = IO (Phys.Array sh a)
   build hull = PhysP.render (Sym.runHull hull)
-- | Wrap a scalar expression as an array whose shape type is @()@,
-- via 'Core.fromScalar'.  Used by the 'C' instance for 'Exp' so that
-- scalar results can go through the array machinery.
singleton :: Exp a -> Core.Array () a
singleton = Core.fromScalar
-- | Scalar results: the plain form is an 'IO' action yielding the value.
-- The expression inside the hull is first wrapped by 'singleton' into a
-- @()@-shaped array; the unwrapped hull is then passed to 'PhysP.the'.
instance (MultiValueMemory.C a, Storable a) => C (Exp a) where
   type Plain (Exp a) = IO a
   build hull = PhysP.the (Sym.runHull (fmap singleton hull))
-- | Functions: each symbolic argument becomes a plain argument of type
-- @PlainArg arg@.  The hull's parameter is widened to the pair
-- @(p, PlainArg arg)@: the existing hull reads the first component and
-- the new argument reads the second.  The built function, which takes
-- that pair, is then curried.
instance (Argument arg, C func) => C (arg -> func) where
   type Plain (arg -> func) = PlainArg arg -> Plain func
   build hull = fmap curry (build (extended <*> argument))
      where
         -- original hull, now taking its parameter from the first component
         extended = Sym.extendHull fst hull
         -- the newly bound argument, taken from the second component
         argument = buildArg (arr snd)
-- | Symbolic types that may occur as arguments of functions handled by
-- the 'C' instance for @arg -> func@.
class Argument arg where
   -- | The plain Haskell type supplied for an argument of type @arg@.
   type PlainArg arg
   -- | Build a hull for the argument from a parameter accessor of type
   -- @Param.T p (PlainArg arg)@.
   buildArg :: Param.T p (PlainArg arg) -> Sym.Hull p arg
-- | Array arguments are supplied as physical arrays.  The parameter
-- accessor is passed to 'PhysP.feed' and the outcome is wrapped with
-- 'Sym.arrayHull'.  Unlike the array instance of 'C', no @Storable a@
-- constraint is demanded here.
instance
   (MultiValueMemory.C sh, Storable sh, Shape.C sh, MultiValueMemory.C a) =>
      Argument (Core.Array sh a) where
   type PlainArg (Core.Array sh a) = Phys.Array sh a
   buildArg param = Sym.arrayHull (PhysP.feed param)
-- | Scalar arguments are supplied as plain values of type @a@.  The
-- parameter accessor goes through 'Sym.expParam' and then 'Sym.expHull'.
instance (MultiValueMemory.C a, Storable a) => Argument (Exp a) where
   type PlainArg (Exp a) = a
   buildArg param = Sym.expHull (Sym.expParam param)
-- | Pairs of arguments are supplied as pairs of plain values; each
-- component's hull reads its own projection of the parameter, and the
-- two hulls are combined into a pair.
instance (Argument a, Argument b) => Argument (a,b) where
   type PlainArg (a,b) = (PlainArg a, PlainArg b)
   buildArg param = liftA2 (,) (buildArg firsts) (buildArg seconds)
      where
         firsts = fmap fst param
         seconds = fmap snd param
-- | Triples of arguments are supplied as triples of plain values; each
-- component's hull reads its own projection of the parameter, and the
-- three hulls are combined into a triple.
instance (Argument a, Argument b, Argument c) => Argument (a,b,c) where
   type PlainArg (a,b,c) = (PlainArg a, PlainArg b, PlainArg c)
   buildArg param =
      liftA3 (,,) (buildArg firsts) (buildArg seconds) (buildArg thirds)
      where
         firsts = fmap fst3 param
         seconds = fmap snd3 param
         thirds = fmap thd3 param
-- | Turn a symbolic value into its plain counterpart.
--
-- For instance, given the required class constraints,
--
-- > run :: (Exp x -> Core.Array sh a -> Core.Array sh b)
-- >     -> IO (x -> Phys.Array sh a -> IO (Phys.Array sh b))
--
-- The value is placed in a hull with unit parameter, 'build' is invoked,
-- and the resulting function is applied to @()@.
run :: (C f) => f -> IO (Plain f)
run f = do
   fromUnit <- build (pure f)
   return (fromUnit ())
-- | Hand-assembled version of what 'run' yields for a function taking one
-- scalar and two arrays (not exported).  The hull is built directly instead
-- of through the 'C' and 'Argument' instances: the parameter is the nested
-- tuple @((x, arrayA), arrayB)@, each argument is read from it by
-- projection, and the rendered pair-consuming function is curried twice.
_example ::
   (Storable x, MultiValueMemory.C x,
    Shape.C sha, Storable sha, MultiValueMemory.C sha, MultiValueMemory.C a,
    Shape.C shb, Storable shb, MultiValueMemory.C shb, MultiValueMemory.C b,
    Shape.C shc, Storable shc, MultiValueMemory.C shc, MultiValueMemory.C c,
    Storable c) =>
   (Exp x -> Core.Array sha a -> Core.Array shb b -> Core.Array shc c) ->
   IO (x -> Phys.Array sha a -> Phys.Array shb b -> IO (Phys.Array shc c))
_example f =
   fmap (curry . curry) (PhysP.render (Sym.runHull hull))
   where
      hull =
         pure f
            <*> Sym.expHull (Sym.expParam (arr (fst . fst)))
            <*> Sym.arrayHull (PhysP.feed (arr (snd . fst)))
            <*> Sym.arrayHull (PhysP.feed (arr snd))