{-# LANGUAGE TypeFamilies #-}
{- |
Simplify running the @render@ function by handling passing of parameters.

'run' takes a symbolic value, possibly a function of symbolic arguments.
It builds, in 'IO', a plain Haskell value whose arguments and results use
physical arrays ('Phys.Array') and plain values instead of symbolic
arrays ('Core.Array') and expressions ('Exp').
-}
module Data.Array.Knead.Parameterized.Render (
   run,
   MapFilter(..),
   FilterOuter(..),
   Scatter(..),
   ScatterMaybe(..),
   MapAccumLSimple(..),
   MapAccumLSequence(..),
   MapAccumL(..),
   FoldOuterL(..),
   AddDimension(..),
   ) where

import qualified Data.Array.Knead.Parameterized.PhysicalHull as PhysHullP
import qualified Data.Array.Knead.Parameterized.Physical as PhysP
import qualified Data.Array.Knead.Parameterized.Private as Sym
import qualified Data.Array.Knead.Simple.Physical as Phys
import qualified Data.Array.Knead.Simple.Private as Core
import qualified Data.Array.Knead.Parameter as Param
import qualified Data.Array.Knead.Shape as Shape
-- Import the constructors/fields as well, not only the type names:
-- the export list re-exports these types with @(..)@, and a re-export
-- can only pass on what is in scope here.  Without @(..)@ on the import,
-- clients of this module could not construct the values 'run' consumes.
import Data.Array.Knead.Parameterized.PhysicalHull
          (MapFilter(..), FilterOuter(..),
           MapAccumLSimple(..), MapAccumLSequence(..), MapAccumL(..),
           FoldOuterL(..), Scatter(..), ScatterMaybe(..),
           AddDimension(..))
import Data.Array.Knead.Expression (Exp, )

import qualified LLVM.Extra.Multi.Value.Memory as MultiValueMemory
import qualified LLVM.Extra.Multi.Value as MultiValue

import Foreign.Storable (Storable, )

import Control.Arrow (arr, )
import Control.Applicative (liftA2, liftA3, pure, (<*>), )

import Data.Tuple.HT (fst3, snd3, thd3, )


{- |
Symbolic values that 'run' can turn into plain Haskell values.

'Plain' is the resulting plain type.
'build' renders a parameterized hull into a function from the parameter.
-}
class C f where
   type Plain f
   build :: Sym.Hull p f -> IO (p -> Plain f)

instance
   (MultiValueMemory.C sh, Storable sh, Shape.C sh,
    MultiValueMemory.C a, Storable a) =>
      C (Core.Array sh a) where
   type Plain (Core.Array sh a) = IO (Phys.Array sh a)
   build = PhysHullP.render

instance
   (Shape.Sequence n, Storable n, MultiValueMemory.C n,
    Storable b, MultiValueMemory.C b) =>
      C (MapFilter n a b) where
   type Plain (MapFilter n a b) = IO (Phys.Array n b)
   build = PhysHullP.mapFilter

instance
   (Shape.Sequence n, Storable n, MultiValueMemory.C n,
    Shape.C sh, Storable sh, MultiValueMemory.C sh,
    Storable a, MultiValueMemory.C a) =>
      C (FilterOuter n sh a) where
   type Plain (FilterOuter n sh a) = IO (Phys.Array (n,sh) a)
   build = PhysHullP.filterOuter

instance
   (MultiValueMemory.C sh0, Storable sh0, Shape.C sh0,
    MultiValueMemory.C sh1, Storable sh1, Shape.C sh1,
    MultiValueMemory.C a, Storable a) =>
      C (Scatter sh0 sh1 a) where
   type Plain (Scatter sh0 sh1 a) = IO (Phys.Array sh1 a)
   build = PhysHullP.scatter

instance
   (MultiValueMemory.C sh0, Storable sh0, Shape.C sh0,
    MultiValueMemory.C sh1, Storable sh1, Shape.C sh1,
    MultiValueMemory.C a, Storable a) =>
      C (ScatterMaybe sh0 sh1 a) where
   type Plain (ScatterMaybe sh0 sh1 a) = IO (Phys.Array sh1 a)
   build = PhysHullP.scatterMaybe

instance
   (Shape.C sh, Storable sh, MultiValueMemory.C sh,
    Shape.C n, Storable n, MultiValueMemory.C n,
    MultiValue.C acc,
    Storable a, MultiValueMemory.C a,
    Storable b, MultiValueMemory.C b) =>
      C (MapAccumLSimple sh n acc a b) where
   type Plain (MapAccumLSimple sh n acc a b) = IO (Phys.Array (sh,n) b)
   build = PhysHullP.mapAccumLSimple

instance
   (Shape.C n, Storable n, MultiValueMemory.C n,
    MultiValue.C acc,
    Storable final, MultiValueMemory.C final,
    Storable a, MultiValueMemory.C a,
    Storable b, MultiValueMemory.C b) =>
      C (MapAccumLSequence n acc final a b) where
   type Plain (MapAccumLSequence n acc final a b) =
           IO (final, Phys.Array n b)
   build = PhysHullP.mapAccumLSequence

instance
   (Shape.C sh, Storable sh, MultiValueMemory.C sh,
    Shape.C n, Storable n, MultiValueMemory.C n,
    MultiValue.C acc,
    Storable final, MultiValueMemory.C final,
    Storable a, MultiValueMemory.C a,
    Storable b, MultiValueMemory.C b) =>
      C (MapAccumL sh n acc final a b) where
   type Plain (MapAccumL sh n acc final a b) =
           IO (Phys.Array sh final, Phys.Array (sh,n) b)
   build = PhysHullP.mapAccumL

instance
   (Shape.C n, Storable n, MultiValueMemory.C n,
    Shape.C sh, Storable sh, MultiValueMemory.C sh,
    Storable a, MultiValueMemory.C a,
    Storable b, MultiValueMemory.C b) =>
      C (FoldOuterL n sh a b) where
   type Plain (FoldOuterL n sh a b) = IO (Phys.Array sh a)
   build = PhysHullP.foldOuterL

instance
   (Shape.C sh, Storable sh, MultiValueMemory.C sh,
    Shape.C n, Storable n, MultiValueMemory.C n,
    Storable b, MultiValueMemory.C b) =>
      C (AddDimension sh n a b) where
   type Plain (AddDimension sh n a b) = IO (Phys.Array (sh,n) b)
   build = PhysHullP.addDimension


-- | Wrap a scalar expression as an array with the unit shape @()@.
singleton :: Exp a -> Core.Array () a
singleton = Core.fromScalar

-- | A scalar result is rendered as a unit-shaped array.
-- 'PhysP.the' then reads back its single element.
instance (MultiValueMemory.C a, Storable a) => C (Exp a) where
   type Plain (Exp a) = IO a
   build = PhysP.the . Sym.runHull . fmap singleton

{- |
Functions are handled one argument at a time.

The hull's parameter type @p@ is extended to @(p, PlainArg arg)@.
The already-built part reads its parameter from the first component, via
'Sym.extendHull' 'fst'. The new argument reads its plain value from the
second component, via 'buildArg' ('arr' 'snd'). Finally 'curry' turns the
rendered @(p, PlainArg arg) -> Plain func@ into
@p -> PlainArg arg -> Plain func@.
-}
instance (Argument arg, C func) => C (arg -> func) where
   type Plain (arg -> func) = PlainArg arg -> Plain func
   build f =
      fmap curry $ build $
      Sym.extendHull fst f <*> buildArg (arr snd)


{- |
Symbolic arguments that a function passed to 'run' may take.

'PlainArg' is the plain type the caller supplies instead.
'buildArg' makes the symbolic argument read its value from the parameter.
-}
class Argument arg where
   type PlainArg arg
   buildArg :: Param.T p (PlainArg arg) -> Sym.Hull p arg

instance
   (MultiValueMemory.C sh, Storable sh, Shape.C sh,
    MultiValueMemory.C a) =>
      Argument (Core.Array sh a) where
   type PlainArg (Core.Array sh a) = Phys.Array sh a
   buildArg = Sym.arrayHull . PhysP.feed

instance (MultiValueMemory.C a, Storable a) => Argument (Exp a) where
   type PlainArg (Exp a) = a
   buildArg = Sym.expHull . Sym.expParam

-- | A pair argument is passed as a pair of plain arguments.
instance (Argument a, Argument b) => Argument (a,b) where
   type PlainArg (a,b) = (PlainArg a, PlainArg b)
   buildArg p =
      liftA2 (,)
         (buildArg $ fmap fst p)
         (buildArg $ fmap snd p)

-- | A triple argument is passed as a triple of plain arguments.
instance (Argument a, Argument b, Argument c) => Argument (a,b,c) where
   type PlainArg (a,b,c) = (PlainArg a, PlainArg b, PlainArg c)
   buildArg p =
      liftA3 (,,)
         (buildArg $ fmap fst3 p)
         (buildArg $ fmap snd3 p)
         (buildArg $ fmap thd3 p)


{- |
Turn a symbolic value into its plain counterpart.

The symbolic value may also be a function of symbolic arguments.
Arguments are translated according to 'PlainArg':

* 'Core.Array' becomes 'Phys.Array'.
* 'Exp' becomes the plain value.
* Pairs and triples become tuples of the translated components.

The result is translated according to 'Plain'.
-}
run :: (C f) => f -> IO (Plain f)
run f =
   -- 'pure' f needs no parameter, so the parameter type is @()@;
   -- apply the rendered function to @()@ to drop it.
   fmap ($()) $ build $ pure f


{- |
Not exported.

Spells out by hand essentially what 'run' does for a function of one
expression and two arrays. It is a reference for the parameter plumbing.
Parameters are packed as @((x, arrayA), arrayB)@, and 'curry' is applied
twice to unpack them.
-}
_example ::
   (Storable x, MultiValueMemory.C x,
    Shape.C sha, Storable sha, MultiValueMemory.C sha,
    MultiValueMemory.C a,
    Shape.C shb, Storable shb, MultiValueMemory.C shb,
    MultiValueMemory.C b,
    Shape.C shc, Storable shc, MultiValueMemory.C shc,
    MultiValueMemory.C c, Storable c) =>
   (Exp x -> Core.Array sha a -> Core.Array shb b -> Core.Array shc c) ->
   IO (x -> Phys.Array sha a -> Phys.Array shb b -> IO (Phys.Array shc c))
_example f =
   fmap (\g -> curry $ curry g) $
   PhysP.render $
   Sym.runHull $
   pure f
      <*> Sym.expHull (Sym.expParam $ arr (fst.fst))
      <*> Sym.arrayHull (PhysP.feed $ arr (snd.fst))
      <*> Sym.arrayHull (PhysP.feed $ arr snd)