{-# LANGUAGE TypeFamilies #-}
{- |
Simplify running the @render@ function by handling passing of parameters.

'run' accepts a function whose arguments are symbolic scalars ('Exp')
and symbolic arrays ('Core.Array'), or pairs and triples of those,
and whose result is an instance of the class 'C'.
It returns, in 'IO', a function that takes the corresponding
plain Haskell values instead.
-}
module Data.Array.Knead.Parameterized.Render (
   run,
   Scatter(..),
   ScatterMaybe(..),
   MapAccumL(..),
   FoldOuterL(..),
   ) where

import qualified Data.Array.Knead.Parameterized.PhysicalHull as PhysHullP
import qualified Data.Array.Knead.Parameterized.Physical as PhysP
import qualified Data.Array.Knead.Parameterized.Private as Sym
import qualified Data.Array.Knead.Simple.Physical as Phys
import qualified Data.Array.Knead.Simple.Private as Core
import qualified Data.Array.Knead.Parameter as Param
import qualified Data.Array.Knead.Index.Nested.Shape as Shape
-- Import with (..) so that the T(..) items in the export list above
-- re-export whatever constructors/fields PhysicalHull exposes.
-- Importing only the bare type names would make those exports
-- type-only, and users of this module could not build these values.
import Data.Array.Knead.Parameterized.PhysicalHull
          (Scatter(..), ScatterMaybe(..), MapAccumL(..), FoldOuterL(..))
import Data.Array.Knead.Expression (Exp, )

import qualified LLVM.Extra.Multi.Value.Memory as MultiValueMemory
import qualified LLVM.Extra.Multi.Value as MultiValue

import Foreign.Storable (Storable, )

import Control.Arrow (arr, )
import Control.Applicative (liftA2, liftA3, pure, (<*>), )

import Data.Tuple.HT (fst3, snd3, thd3, )


{- |
Types that 'run' can turn into a plain Haskell function.

@Plain f@ is the type of the built Haskell-side value.
For example, it is @IO (Phys.Array sh a)@ for a symbolic array result.
For a function type, it is a function of the plain argument values.
-}
class C f where
   type Plain f
   -- | Turn a hull whose parameter type is @p@
   -- into a function from @p@ to the plain result.
   build :: Sym.Hull p f -> IO (p -> Plain f)

-- A symbolic array is rendered into a physical array via 'PhysHullP.render'.
instance
   (MultiValueMemory.C sh, Storable sh, Shape.C sh,
    MultiValueMemory.C a, Storable a) =>
      C (Core.Array sh a) where
   type Plain (Core.Array sh a) = IO (Phys.Array sh a)
   build = PhysHullP.render

-- Scatter description, built via 'PhysHullP.scatter';
-- the result array has the target shape @sh1@.
instance
   (MultiValueMemory.C sh0, Storable sh0, Shape.C sh0,
    MultiValueMemory.C sh1, Storable sh1, Shape.C sh1,
    MultiValueMemory.C a, Storable a) =>
      C (Scatter sh0 sh1 a) where
   type Plain (Scatter sh0 sh1 a) = IO (Phys.Array sh1 a)
   build = PhysHullP.scatter

-- Like the 'Scatter' instance, but built via 'PhysHullP.scatterMaybe'.
instance
   (MultiValueMemory.C sh0, Storable sh0, Shape.C sh0,
    MultiValueMemory.C sh1, Storable sh1, Shape.C sh1,
    MultiValueMemory.C a, Storable a) =>
      C (ScatterMaybe sh0 sh1 a) where
   type Plain (ScatterMaybe sh0 sh1 a) = IO (Phys.Array sh1 a)
   build = PhysHullP.scatterMaybe

-- Built via 'PhysHullP.mapAccumL'; the result array has shape @(sh,n)@.
-- The accumulator type @acc@ only needs 'MultiValue.C',
-- because it does not appear in the plain result type.
instance
   (Shape.C sh, Storable sh, MultiValueMemory.C sh,
    Shape.C n, Storable n, MultiValueMemory.C n,
    MultiValue.C acc,
    Storable a, MultiValueMemory.C a,
    Storable b, MultiValueMemory.C b) =>
      C (MapAccumL sh n acc a b) where
   type Plain (MapAccumL sh n acc a b) = IO (Phys.Array (sh,n) b)
   build = PhysHullP.mapAccumL

-- Built via 'PhysHullP.foldOuterL'; the result array has shape @sh@
-- (the @n@ dimension does not occur in the result type).
instance
   (Shape.C sh, Storable sh, MultiValueMemory.C sh,
    Shape.C n, Storable n, MultiValueMemory.C n,
    Storable a, MultiValueMemory.C a,
    Storable b, MultiValueMemory.C b) =>
      C (FoldOuterL n sh a b) where
   type Plain (FoldOuterL n sh a b) = IO (Phys.Array sh a)
   build = PhysHullP.foldOuterL


-- | Wrap a scalar expression as an array of shape @()@.
singleton :: Exp a -> Core.Array () a
singleton = Core.fromScalar

-- A scalar result is rendered as an array of shape @()@ via 'singleton'.
-- Its value is read back with 'PhysP.the'.
instance (MultiValueMemory.C a, Storable a) => C (Exp a) where
   type Plain (Exp a) = IO a
   build = PhysP.the . Sym.runHull . fmap singleton

-- Each symbolic argument extends the parameter from @p@ to
-- @(p, PlainArg arg)@:
--   * the existing hull is moved to the pair parameter via 'fst';
--   * the new argument is taken from 'snd';
--   * 'curry' turns the resulting function of a pair back into
--     a function of two arguments.
instance (Argument arg, C func) => C (arg -> func) where
   type Plain (arg -> func) = PlainArg arg -> Plain func
   build f = fmap curry $ build $ Sym.extendHull fst f <*> buildArg (arr snd)


{- |
Types that may appear as arguments of the function passed to 'run'.

@PlainArg arg@ is the plain Haskell value that the caller
supplies in place of the symbolic argument @arg@.
-}
class Argument arg where
   type PlainArg arg
   -- | Given a 'Param.T' that selects the plain value from the
   -- parameter @p@, produce the corresponding symbolic argument.
   buildArg :: Param.T p (PlainArg arg) -> Sym.Hull p arg

-- A physical array is fed in as a symbolic array.
instance
   (MultiValueMemory.C sh, Storable sh, Shape.C sh, MultiValueMemory.C a) =>
      Argument (Core.Array sh a) where
   type PlainArg (Core.Array sh a) = Phys.Array sh a
   buildArg = Sym.arrayHull . PhysP.feed

-- A plain scalar is passed in as a symbolic expression parameter.
instance (MultiValueMemory.C a, Storable a) => Argument (Exp a) where
   type PlainArg (Exp a) = a
   buildArg = Sym.expHull . Sym.expParam

-- Pairs: split the parameter with 'fst' / 'snd' and combine the two hulls.
instance (Argument a, Argument b) => Argument (a,b) where
   type PlainArg (a,b) = (PlainArg a, PlainArg b)
   buildArg p = liftA2 (,) (buildArg $ fmap fst p) (buildArg $ fmap snd p)

-- Triples: split the parameter with 'fst3' / 'snd3' / 'thd3'
-- and combine the three hulls.
instance (Argument a, Argument b, Argument c) => Argument (a,b,c) where
   type PlainArg (a,b,c) = (PlainArg a, PlainArg b, PlainArg c)
   buildArg p =
      liftA3 (,,)
         (buildArg $ fmap fst3 p)
         (buildArg $ fmap snd3 p)
         (buildArg $ fmap thd3 p)


{- |
Turn a function on symbolic arguments into a function on plain values.

The function is wrapped in a hull whose parameter is @()@,
turned into a function of that parameter with 'build',
and finally applied to @()@.
-}
run :: (C f) => f -> IO (Plain f)
run f = fmap ($()) $ build $ pure f


-- Hand-written version of the parameter plumbing that 'run' performs
-- generically, for a function of one scalar and two arrays.
-- Structure:
--   * the hull parameter is the nested pair @((x, arrayA), arrayB)@;
--   * each argument is selected from it with 'arr';
--   * 'curry' is applied twice to expose the three arguments separately.
-- The leading underscore keeps GHC from warning that it is unused.
_example ::
   (Storable x, MultiValueMemory.C x,
    Shape.C sha, Storable sha, MultiValueMemory.C sha, MultiValueMemory.C a,
    Shape.C shb, Storable shb, MultiValueMemory.C shb, MultiValueMemory.C b,
    Shape.C shc, Storable shc, MultiValueMemory.C shc,
    MultiValueMemory.C c, Storable c) =>
   (Exp x -> Core.Array sha a -> Core.Array shb b -> Core.Array shc c) ->
   IO (x -> Phys.Array sha a -> Phys.Array shb b -> IO (Phys.Array shc c))
_example f =
   fmap (\g -> curry $ curry g) $
   PhysP.render $
   Sym.runHull $
   pure f
      <*> Sym.expHull (Sym.expParam $ arr (fst.fst))
      <*> Sym.arrayHull (PhysP.feed $ arr (snd.fst))
      <*> Sym.arrayHull (PhysP.feed $ arr snd)