{-# LANGUAGE GADTs #-}
{-# LANGUAGE ForeignFunctionInterface #-}
-- | Turn parameterized symbolic arrays into physical (in-memory) arrays.
--
-- Every exported function takes a 'Sym.Hull', which bundles:
--
-- * a function from a parameter value to the symbolic description;
--
-- * a @create@ action that produces that parameter (plus a context) from a
--   user value @p@;
--
-- * a @delete@ action for the context.
--
-- Each function compiles LLVM code once via 'compile'. It returns a function
-- that, for each @p@, runs the compiled code through FFI function pointers and
-- yields a 'Phys.Array'.
module Data.Array.Knead.Parameterized.PhysicalHull (
   render,
   Scatter(..), scatter,
   ScatterMaybe(..), scatterMaybe,
   MapAccumL(..), mapAccumL,
   FoldOuterL(..), foldOuterL,
   ) where

import qualified Data.Array.Knead.Parameterized.Private as Sym
import qualified Data.Array.Knead.Simple.PhysicalPrivate as Priv
import qualified Data.Array.Knead.Simple.Physical as Phys
import qualified Data.Array.Knead.Simple.Private as Core
import qualified Data.Array.Knead.Index.Nested.Shape as Shape
import qualified Data.Array.Knead.Code as Code
import Data.Array.Knead.Expression (Exp, unExp, )
import Data.Array.Knead.Code (compile, )

import qualified LLVM.Extra.Multi.Value.Memory as MultiValueMemory
import qualified LLVM.Extra.Multi.Value as MultiValue
import qualified LLVM.Extra.Memory as Memory
import qualified LLVM.Extra.Arithmetic as A
import qualified LLVM.Core as LLVM

import Foreign.Marshal.Utils (with, )
import Foreign.Marshal.Alloc (alloca, )
import Foreign.Storable (Storable, peek, )
import Foreign.ForeignPtr (withForeignPtr, mallocForeignPtrArray, )
import Foreign.Ptr (FunPtr, Ptr, )

import Control.Exception (bracket, )
import Control.Monad.HT (void, )
import Control.Applicative (liftA2, )

import Data.Word (Word32, )


-- | Type of a @dynamic@ foreign import: it turns a 'FunPtr' into a callable
-- Haskell function of the same type.
type Importer f = FunPtr f -> f

-- | Import used by 'materialize' for the compiled @shape@ function.
-- Its arguments are a pointer to the marshalled parameter and a pointer to
-- the shape buffer. 'materialize' uses the 'Word32' result as the element
-- count of the array it allocates.
foreign import ccall safe "dynamic" callShaper ::
   Importer (Ptr param -> Ptr shape -> IO Word32)

-- | Import used by 'materialize' for the compiled @fill@ function.
-- Its arguments are pointers to the marshalled parameter, the shape buffer
-- and the destination array buffer.
foreign import ccall safe "dynamic" callFill ::
   Importer (Ptr param -> Ptr shape -> Ptr am -> IO ())


-- | Shared driver behind all exported functions.
--
-- @materialize name shape fill hull@ compiles two functions once, via
-- @'compile' name@:
--
-- * the @shape@ function loads the parameter and evaluates
--   @shape (core param)@. It stores that shape through its result pointer
--   and returns @Shape.sizeCode@ of it.
--
-- * the @fill@ function loads the parameter and runs @fill (core param)@
--   on the shape pointer and the destination buffer pointer.
--
-- The returned function may be applied repeatedly and reuses the compiled
-- code. For each @p@ it:
--
-- 1. runs the hull's @create@ action; 'bracket' guarantees that @delete@
--    later receives the context, even on exceptions;
--
-- 2. calls the shape function to obtain the element count @n@;
--
-- 3. allocates a @ForeignPtr@ array of @n@ elements and calls the fill
--    function on it;
--
-- 4. returns the array together with the shape read back from the shape
--    buffer.
materialize ::
   (Shape.C sh, Storable sh, MultiValueMemory.C sh,
    Storable a, MultiValueMemory.C a) =>
   String ->
   (core -> Exp sh) ->
   (core ->
    LLVM.Value (Ptr (MultiValueMemory.Struct sh)) ->
    LLVM.Value (Ptr (MultiValueMemory.Struct a)) ->
    LLVM.CodeGenFunction () ()) ->
   Sym.Hull p core -> IO (p -> IO (Phys.Array sh a))
materialize name shape fill (Sym.Hull core create delete) = do
   -- Both functions are compiled here, once; the closure returned below
   -- only calls them.
   (fsh, farr) <-
      compile name $
      liftA2 (,)
         (Code.createFunction callShaper "shape" $ \paramPtr resultPtr -> do
            param <- Memory.load paramPtr
            sh <- unExp $ shape $ core param
            MultiValueMemory.store sh resultPtr
            -- The returned size becomes the element count that is passed
            -- to mallocForeignPtrArray below.
            Shape.sizeCode sh >>= LLVM.ret)
         (Code.createFunction callFill "fill" $ \paramPtr shapePtr bufferPtr -> do
            param <- Memory.load paramPtr
            -- shapePtr is the same buffer the shape function wrote to, so
            -- 'fill' can read the shape from it (render uses Shape.load).
            fill (core param) shapePtr bufferPtr
            LLVM.ret ())
   return $ \p ->
      -- 'delete' receives the first component of create's result, also
      -- when anything below throws.
      bracket (create p) (delete . fst) $ \(_ctx, param) ->
      -- shptr is temporary storage for the shape; it is only valid inside
      -- this alloca scope, hence the peek before returning.
      alloca $ \shptr ->
      with param $ \paramPtr -> do
         -- Reinterpret the Haskell-side pointers as pointers to the
         -- MultiValueMemory struct layouts used by the compiled code.
         let paramMVPtr = MultiValueMemory.castStructPtr paramPtr
         let shapeMVPtr = MultiValueMemory.castStructPtr shptr
         n <- fsh paramMVPtr shapeMVPtr
         -- n arrives as Word32 (see callShaper).
         -- NOTE(review): element counts above maxBound :: Word32 cannot be
         -- represented here; confirm how Shape.sizeCode behaves for them.
         fptr <- mallocForeignPtrArray (fromIntegral n)
         withForeignPtr fptr $
            farr paramMVPtr shapeMVPtr . MultiValueMemory.castStructPtr
         sh <- peek shptr
         return (Phys.Array sh fptr)


-- | Materialize a symbolic array ('Core.Array') into a physical array.
--
-- The generated fill code loads the shape with 'Shape.load'. It then runs
-- 'Shape.loop' with the destination buffer pointer as the initial loop
-- state. For each index @ix@, @step@ stores @code ix@ at the current pointer
-- and yields @A.advanceArrayElementPtr p@ as the state for the next step.
--
-- NOTE(review): the type variable @ix@ is only bound by the constraint
-- @Shape.Index sh ~ ix@ and is not used elsewhere in the signature. The
-- constraint looks redundant; confirm before removing it.
render ::
   (Shape.C sh, Shape.Index sh ~ ix, Storable sh, MultiValueMemory.C sh,
    Storable a, MultiValueMemory.C a) =>
   Sym.Hull p (Core.Array sh a) -> IO (p -> IO (Phys.Array sh a))
render =
   materialize "render" Core.shape
      (\(Core.Array esh code) shapePtr bufferPtr -> do
         -- step writes the element for index ix at p and returns the
         -- pointer produced by A.advanceArrayElementPtr.
         let step ix p = do
               flip Memory.store p =<< code ix
               A.advanceArrayElementPtr p
         sh <- Shape.load esh shapePtr
         void $ Shape.loop step sh bufferPtr)


-- | Inputs for 'scatter'.
data Scatter sh0 sh1 a = Scatter
   { scatterAccum :: Exp a -> Exp a -> Exp a
      -- ^ combining function, passed unchanged to 'Priv.scatter'
   , scatterInit :: Core.Array sh1 a
      -- ^ initial array; its shape is the shape of the result
   , scatterMap :: Core.Array sh0 (Shape.Index sh1, a)
      -- ^ pairs of an index of the result shape and a value
   }

-- | Materialize a 'Scatter'.
--
-- The result has the shape of 'scatterInit'. The fill code is produced by
-- 'Priv.scatter' from the accumulator, the initial array and the
-- index-value array. Presumably each value is combined into the result at
-- its index; confirm the exact semantics in 'Priv.scatter'.
--
-- NOTE(review): @ix0@ and @ix1@ are only bound by the @Shape.Index ... ~@
-- constraints. These look redundant; confirm before removing them.
scatter ::
   (Shape.C sh0, Shape.Index sh0 ~ ix0,
    Shape.C sh1, Shape.Index sh1 ~ ix1, Storable sh1, MultiValueMemory.C sh1,
    Storable a, MultiValueMemory.C a) =>
   Sym.Hull p (Scatter sh0 sh1 a) -> IO (p -> IO (Phys.Array sh1 a))
scatter =
   materialize "scatter" (Core.shape . scatterInit)
      (\(Scatter accum arrInit arrMap) -> Priv.scatter accum arrInit arrMap)


-- | Inputs for 'scatterMaybe'.
-- Same as 'Scatter', except that each source element is optional.
data ScatterMaybe sh0 sh1 a = ScatterMaybe
   { scatterMaybeAccum :: Exp a -> Exp a -> Exp a
      -- ^ combining function, passed unchanged to 'Priv.scatterMaybe'
   , scatterMaybeInit :: Core.Array sh1 a
      -- ^ initial array; its shape is the shape of the result
   , scatterMaybeMap :: Core.Array sh0 (Maybe (Shape.Index sh1, a))
      -- ^ optional pairs of an index of the result shape and a value
   }

-- | Materialize a 'ScatterMaybe'.
--
-- The result has the shape of 'scatterMaybeInit'. The fill code is produced
-- by 'Priv.scatterMaybe'. Presumably 'Nothing' elements leave the result
-- untouched; confirm in 'Priv.scatterMaybe'.
--
-- NOTE(review): @ix0@ and @ix1@ are only bound by the @Shape.Index ... ~@
-- constraints. These look redundant; confirm before removing them.
scatterMaybe ::
   (Shape.C sh0, Shape.Index sh0 ~ ix0,
    Shape.C sh1, Shape.Index sh1 ~ ix1, Storable sh1, MultiValueMemory.C sh1,
    Storable a, MultiValueMemory.C a) =>
   Sym.Hull p (ScatterMaybe sh0 sh1 a) -> IO (p -> IO (Phys.Array sh1 a))
scatterMaybe =
   materialize "scatterMaybe" (Core.shape . scatterMaybeInit)
      (\(ScatterMaybe accum arrInit arrMap) ->
         Priv.scatterMaybe accum arrInit arrMap)


-- | Inputs for 'mapAccumL'.
data MapAccumL sh n acc a b = MapAccumL
   { mapAccumLAccum :: Exp acc -> Exp a -> Exp (acc,b)
      -- ^ maps an accumulator and an input element to a new accumulator
      -- and an output element
   , mapAccumLInit :: Core.Array sh acc
      -- ^ initial accumulator values, an array of shape @sh@
   , mapAccumLMap :: Core.Array (sh, n) a
      -- ^ input data; its shape @(sh, n)@ is the shape of the result
   }

-- | Materialize a 'MapAccumL'.
--
-- The result has the shape of 'mapAccumLMap'. The fill code is produced by
-- 'Priv.mapAccumL'. Presumably the accumulation runs along the @n@ part of
-- the shape, once per @sh@ index; confirm in 'Priv.mapAccumL'.
--
-- NOTE(review): nothing in this module checks that 'mapAccumLInit' has the
-- same shape as the @sh@ part of 'mapAccumLMap'.
mapAccumL ::
   (Shape.C sh, Storable sh, MultiValueMemory.C sh,
    Shape.C n, Storable n, MultiValueMemory.C n,
    MultiValue.C acc,
    Storable a, MultiValueMemory.C a,
    Storable b, MultiValueMemory.C b) =>
   Sym.Hull p (MapAccumL sh n acc a b) ->
   IO (p -> IO (Phys.Array (sh,n) b))
mapAccumL =
   materialize "mapAccumL" (Core.shape . mapAccumLMap)
      (\(MapAccumL f arrInit arrData) -> Priv.mapAccumL f arrInit arrData)


-- | Inputs for 'foldOuterL'.
data FoldOuterL n sh a b = FoldOuterL
   { foldOuterLAccum :: Exp a -> Exp b -> Exp a
      -- ^ combines an accumulator with an input element
   , foldOuterLInit :: Core.Array sh a
      -- ^ initial accumulator values; their shape is the shape of the result
   , foldOuterLMap :: Core.Array (n,sh) b
      -- ^ input data of shape @(n, sh)@
   }

-- FIXME: check correct size of array of initial values
-- NOTE(review): the result shape is taken from foldOuterLInit alone. Nothing
-- in this module compares it with the sh part of foldOuterLMap's shape.
-- | Materialize a 'FoldOuterL'.
--
-- The result has the shape of 'foldOuterLInit'. The fill code is produced by
-- 'Priv.foldOuterL'. Presumably it folds along the outer @n@ dimension;
-- confirm in 'Priv.foldOuterL'.
foldOuterL ::
   (Shape.C sh, Storable sh, MultiValueMemory.C sh,
    Shape.C n, Storable n, MultiValueMemory.C n,
    Storable a, MultiValueMemory.C a) =>
   Sym.Hull p (FoldOuterL n sh a b) -> IO (p -> IO (Phys.Array sh a))
foldOuterL =
   materialize "foldOuterL" (Core.shape . foldOuterLInit)
      (\(FoldOuterL f arrInit arrData) -> Priv.foldOuterL f arrInit arrData)