{-# LANGUAGE ForeignFunctionInterface #-}
{-# LANGUAGE GADTs #-}
{- |
Physical materialization of parameterized array hulls.

Every exported operation goes through one of the @materialize*@ drivers:
they compile LLVM code for a @shape@ function and a @fill@ function with
'compile', and return an action that, for a given parameter, allocates
'Storable' buffers and runs the compiled code to populate them.
-}
module Data.Array.Knead.Parameterized.PhysicalHull
   ( render
   , MapFilter(..), mapFilter
   , FilterOuter(..), filterOuter
   , Scatter(..), scatter
   , ScatterMaybe(..), scatterMaybe
   , MapAccumLSimple(..), mapAccumLSimple
   , MapAccumLSequence(..), mapAccumLSequence
   , MapAccumL(..), mapAccumL
   , FoldOuterL(..), foldOuterL
   , AddDimension(..), addDimension
   ) where

import Control.Applicative (liftA2)
import Control.Exception (bracket)
import Foreign.ForeignPtr (ForeignPtr, withForeignPtr, mallocForeignPtrArray)
import Foreign.Marshal.Alloc (alloca)
import Foreign.Marshal.Array (allocaArray)
import Foreign.Marshal.Utils (with)
import Foreign.Ptr (FunPtr, Ptr)
import Foreign.Storable (Storable, peek, peekElemOff)

import Control.Monad.HT (void)
import Data.Array.Comfort.Storable.Unchecked (Array(Array))
import qualified LLVM.Core as LLVM
import qualified LLVM.Extra.Arithmetic as A
import qualified LLVM.Extra.Memory as Memory
import qualified LLVM.Extra.Multi.Value as MultiValue
import qualified LLVM.Extra.Multi.Value.Memory as MultiValueMemory

import Data.Array.Knead.Code (compile)
import qualified Data.Array.Knead.Code as Code
import Data.Array.Knead.Expression (Exp, unExp)
import qualified Data.Array.Knead.Expression as Expr
import qualified Data.Array.Knead.Parameterized.Private as Sym
import qualified Data.Array.Knead.Shape as Shape
import qualified Data.Array.Knead.Simple.PhysicalPrivate as Priv
import qualified Data.Array.Knead.Simple.Private as Core


-- | Allocate a garbage-collected buffer for @n@ elements.
-- The 'Shape.Size' count is converted with 'fromIntegral'.
mallocArray :: (Storable a) => Shape.Size -> IO (ForeignPtr a)
mallocArray n = mallocForeignPtrArray (fromIntegral n)

-- | Like 'withForeignPtr', but the action receives the pointer retyped
-- to the LLVM memory representation of the element type.
withForeignMemPtr ::
   ForeignPtr a -> (Ptr (MultiValueMemory.Struct a) -> IO b) -> IO b
withForeignMemPtr fptr act =
   withForeignPtr fptr $ \ptr -> act (MultiValueMemory.castStructPtr ptr)

-- | Turns a foreign function pointer into a callable Haskell function.
type Importer f = FunPtr f -> f

-- Shape function: (parameter, shape out-buffer) -> number of elements.
foreign import ccall safe "dynamic"
   callShaper :: Importer (Ptr param -> Ptr shape -> IO Shape.Size)

-- Fill function: (parameter, shape buffer, element buffer).
foreign import ccall safe "dynamic"
   callFill :: Importer (Ptr param -> Ptr shape -> Ptr am -> IO ())

{- |
Compile a hull into an action that produces one physical array.

Two LLVM functions are generated:

* @shape@ stores the shape computed from the core into a caller-supplied
  buffer and returns its 'Shape.size'.

* @fill@ writes the elements into the freshly allocated buffer.

The returned action does the following, in order:

1. creates the parameter context with the hull's @create@ (released by
   @delete@ via 'bracket');
2. calls @shape@;
3. allocates that many elements;
4. calls @fill@;
5. reads the shape back from memory.

Attention: @fill@ may overwrite the stored shape, so the shape is only read
after @fill@ has run. 'mapFilter' is an example that does this.
-}
materialize ::
   (Shape.C sh, Storable sh, MultiValueMemory.C sh,
    Storable a, MultiValueMemory.C a) =>
   String ->
   (core -> Exp sh) ->
   (core ->
    LLVM.Value (Ptr (MultiValueMemory.Struct sh)) ->
    LLVM.Value (Ptr (MultiValueMemory.Struct a)) ->
    LLVM.CodeGenFunction () ()) ->
   Sym.Hull p core -> IO (p -> IO (Array sh a))
materialize name shape fill (Sym.Hull core create delete) = do
   (shapeFn, fillFn) <-
      compile name $
      liftA2 (,)
         (Code.createFunction callShaper "shape" $
          \paramPtr resultPtr -> do
            param <- Memory.load paramPtr
            sh <- unExp (shape (core param))
            MultiValueMemory.store sh resultPtr
            size <- Shape.size sh
            LLVM.ret size)
         (Code.createFunction callFill "fill" $
          \paramPtr shapePtr bufferPtr -> do
            param <- Memory.load paramPtr
            fill (core param) shapePtr bufferPtr
            LLVM.ret ())
   return $ \p ->
      bracket (create p) (delete . fst) $ \(_ctx, param) ->
      alloca $ \shapeBuf ->
      with param $ \paramPtr -> do
         let paramMVPtr = MultiValueMemory.castStructPtr paramPtr
             shapeMVPtr = MultiValueMemory.castStructPtr shapeBuf
         size <- shapeFn paramMVPtr shapeMVPtr
         fptr <- mallocArray size
         withForeignMemPtr fptr (fillFn paramMVPtr shapeMVPtr)
         sh <- peek shapeBuf
         return (Array sh fptr)

-- Fill function with an extra out-parameter for a final scalar result.
foreign import ccall safe "dynamic"
   callFillExpArray ::
      Importer (Ptr param -> Ptr final -> Ptr shape -> Ptr am -> IO ())

{- |
Like 'materialize', except that @fill@ additionally receives a pointer to a
one-element buffer for a final value. That value is read back after @fill@
and returned next to the array.
-}
materializeExpArray ::
   (Shape.C sh, Storable sh, MultiValueMemory.C sh,
    Storable a, MultiValueMemory.C a,
    Storable b, MultiValueMemory.C b) =>
   String ->
   (core -> Exp sh) ->
   (core ->
    LLVM.Value (Ptr (MultiValueMemory.Struct b)) ->
    LLVM.Value (Ptr (MultiValueMemory.Struct sh)) ->
    LLVM.Value (Ptr (MultiValueMemory.Struct a)) ->
    LLVM.CodeGenFunction () ()) ->
   Sym.Hull p core -> IO (p -> IO (b, Array sh a))
materializeExpArray name shape fill (Sym.Hull core create delete) = do
   (shapeFn, fillFn) <-
      compile name $
      liftA2 (,)
         (Code.createFunction callShaper "shape" $
          \paramPtr resultPtr -> do
            param <- Memory.load paramPtr
            sh <- unExp (shape (core param))
            MultiValueMemory.store sh resultPtr
            size <- Shape.size sh
            LLVM.ret size)
         (Code.createFunction callFillExpArray "fill" $
          \paramPtr finalPtr shapePtr bufferPtr -> do
            param <- Memory.load paramPtr
            fill (core param) finalPtr shapePtr bufferPtr
            LLVM.ret ())
   return $ \p ->
      bracket (create p) (delete . fst) $ \(_ctx, param) ->
      alloca $ \shapeBuf ->
      alloca $ \finalBuf ->
      with param $ \paramPtr -> do
         let paramMVPtr = MultiValueMemory.castStructPtr paramPtr
             finalMVPtr = MultiValueMemory.castStructPtr finalBuf
             shapeMVPtr = MultiValueMemory.castStructPtr shapeBuf
         size <- shapeFn paramMVPtr shapeMVPtr
         fptr <- mallocArray size
         withForeignMemPtr fptr (fillFn paramMVPtr finalMVPtr shapeMVPtr)
         sh <- peek shapeBuf
         final <- peek finalBuf
         return (final, Array sh fptr)

-- Shape function for two arrays: writes both shapes and a two-element
-- array of sizes.
foreign import ccall safe "dynamic"
   callShaper2 ::
      Importer (Ptr param -> Ptr shapeA -> Ptr shapeB -> Ptr Shape.Size -> IO ())

-- Fill function for two arrays: (parameter, shape A, buffer A, shape B,
-- buffer B).
foreign import ccall safe "dynamic"
   callFill2 ::
      Importer (Ptr param -> Ptr shapeA -> Ptr am -> Ptr shapeB -> Ptr bm -> IO ())

{- |
Two-array variant of 'materialize'.

The generated @shape@ function stores both shapes and writes their sizes
into a two-element 'Shape.Size' array: index 0 for the first array, index 1
for the second. Both buffers are allocated from those sizes and filled by a
single @fill@ call. Both shapes are read back afterwards.
-}
materialize2 ::
   (Shape.C sha, Storable sha, MultiValueMemory.C sha,
    Shape.C shb, Storable shb, MultiValueMemory.C shb,
    Storable a, MultiValueMemory.C a,
    Storable b, MultiValueMemory.C b) =>
   String ->
   (core -> Exp (sha,shb)) ->
   (core ->
    (LLVM.Value (Ptr (MultiValueMemory.Struct sha)),
     LLVM.Value (Ptr (MultiValueMemory.Struct a))) ->
    (LLVM.Value (Ptr (MultiValueMemory.Struct shb)),
     LLVM.Value (Ptr (MultiValueMemory.Struct b))) ->
    LLVM.CodeGenFunction () ()) ->
   Sym.Hull p core -> IO (p -> IO (Array sha a, Array shb b))
materialize2 name shape fill (Sym.Hull core create delete) = do
   (shapeFn, fillFn) <-
      compile name $
      liftA2 (,)
         (Code.createFunction callShaper2 "shape" $
          \paramPtr shapeAPtr shapeBPtr sizesPtr -> do
            param <- Memory.load paramPtr
            (sha, shb) <- MultiValue.unzip <$> unExp (shape (core param))
            MultiValueMemory.store sha shapeAPtr
            MultiValueMemory.store shb shapeBPtr
            sizeAPtr <- LLVM.bitcast sizesPtr
            sizeA <- Shape.size sha
            LLVM.store sizeA sizeAPtr
            sizeBPtr <- A.advanceArrayElementPtr sizeAPtr
            sizeB <- Shape.size shb
            LLVM.store sizeB sizeBPtr
            LLVM.ret ())
         (Code.createFunction callFill2 "fill" $
          \paramPtr shapeAPtr bufferAPtr shapeBPtr bufferBPtr -> do
            param <- Memory.load paramPtr
            fill (core param) (shapeAPtr, bufferAPtr) (shapeBPtr, bufferBPtr)
            LLVM.ret ())
   return $ \p ->
      bracket (create p) (delete . fst) $ \(_ctx, param) ->
      alloca $ \shaBuf ->
      alloca $ \shbBuf ->
      allocaArray 2 $ \sizesBuf ->
      with param $ \paramPtr -> do
         let paramMVPtr = MultiValueMemory.castStructPtr paramPtr
             shapeAMVPtr = MultiValueMemory.castStructPtr shaBuf
             shapeBMVPtr = MultiValueMemory.castStructPtr shbBuf
         shapeFn paramMVPtr shapeAMVPtr shapeBMVPtr sizesBuf
         sizeA <- peekElemOff sizesBuf 0
         afptr <- mallocArray sizeA
         sizeB <- peekElemOff sizesBuf 1
         bfptr <- mallocArray sizeB
         withForeignMemPtr afptr $ \aptr ->
            withForeignMemPtr bfptr $ \bptr ->
            fillFn paramMVPtr shapeAMVPtr aptr shapeBMVPtr bptr
         sha <- peek shaBuf
         shb <- peek shbBuf
         return (Array sha afptr, Array shb bfptr)


-- | Materialize a symbolic array.
-- The fill function walks the shape with 'Shape.loop'. At each index it
-- stores the computed element and advances the buffer pointer by one
-- element.
render ::
   (Shape.C sh, Shape.Index sh ~ ix, Storable sh, MultiValueMemory.C sh,
    Storable a, MultiValueMemory.C a) =>
   Sym.Hull p (Core.Array sh a) -> IO (p -> IO (Array sh a))
render =
   materialize "render" Core.shape $
   \(Core.Array esh code) shapePtr bufferPtr -> do
      sh <- Shape.load esh shapePtr
      let writeAndAdvance ix ptr = do
             x <- code ix
             Memory.store x ptr
             A.advanceArrayElementPtr ptr
      void $ Shape.loop writeAndAdvance sh bufferPtr


-- | Inputs of 'scatter'. The fields are handed to @Priv.scatter@ in
-- declaration order.
data Scatter sh0 sh1 a = Scatter
   { scatterAccum :: Exp a -> Exp a -> Exp a
     -- ^ binary combining function
   , scatterInit :: Core.Array sh1 a
     -- ^ initial array; its shape is the shape of the result
   , scatterMap :: Core.Array sh0 (Shape.Index sh1, a)
     -- ^ pairs of an index into the result shape and a value
   }

-- | Materialize a 'Scatter' via @Priv.scatter@.
-- The result has the shape of 'scatterInit'.
scatter ::
   (Shape.C sh0, Shape.Index sh0 ~ ix0,
    Shape.C sh1, Shape.Index sh1 ~ ix1, Storable sh1, MultiValueMemory.C sh1,
    Storable a, MultiValueMemory.C a) =>
   Sym.Hull p (Scatter sh0 sh1 a) -> IO (p -> IO (Array sh1 a))
scatter =
   materialize "scatter"
      (\s -> Core.shape (scatterInit s))
      (\(Scatter accum arrInit arrMap) -> Priv.scatter accum arrInit arrMap)

-- | Inputs of 'scatterMaybe'. Like 'Scatter', but each source element may
-- be 'Nothing'.
data ScatterMaybe sh0 sh1 a = ScatterMaybe
   { scatterMaybeAccum :: Exp a -> Exp a -> Exp a
     -- ^ binary combining function
   , scatterMaybeInit :: Core.Array sh1 a
     -- ^ initial array; its shape is the shape of the result
   , scatterMaybeMap :: Core.Array sh0 (Maybe (Shape.Index sh1, a))
     -- ^ optional pairs of an index into the result shape and a value
   }

-- | Materialize a 'ScatterMaybe' via @Priv.scatterMaybe@.
-- The result has the shape of 'scatterMaybeInit'.
scatterMaybe ::
   (Shape.C sh0, Shape.Index sh0 ~ ix0,
    Shape.C sh1, Shape.Index sh1 ~ ix1, Storable sh1, MultiValueMemory.C sh1,
    Storable a, MultiValueMemory.C a) =>
   Sym.Hull p (ScatterMaybe sh0 sh1 a) -> IO (p -> IO (Array sh1 a))
scatterMaybe =
   materialize "scatterMaybe"
      (\s -> Core.shape (scatterMaybeInit s))
      (\(ScatterMaybe accum arrInit arrMap) ->
         Priv.scatterMaybe accum arrInit arrMap)

-- | Inputs of 'mapAccumLSimple'.
data MapAccumLSimple sh n acc a b = MapAccumLSimple
   { mapAccumLSimpleAccum :: Exp acc -> Exp a -> Exp (acc,b)
     -- ^ step function returning the next accumulator and an output element
   , mapAccumLSimpleInit :: Core.Array sh acc
     -- ^ initial accumulators
   , mapAccumLSimpleArray :: Core.Array (sh, n) a
     -- ^ input array; its shape is the shape of the result
   }

-- | Materialize a 'MapAccumLSimple' via @Priv.mapAccumLSimple@.
mapAccumLSimple ::
   (Shape.C sh, Storable sh, MultiValueMemory.C sh,
    Shape.C n, Storable n, MultiValueMemory.C n,
    MultiValue.C acc,
    Storable a, MultiValueMemory.C a,
    Storable b, MultiValueMemory.C b) =>
   Sym.Hull p (MapAccumLSimple sh n acc a b) ->
   IO (p -> IO (Array (sh,n) b))
mapAccumLSimple =
   materialize "mapAccumLSimple"
      (\m -> Core.shape (mapAccumLSimpleArray m))
      (\(MapAccumLSimple f arrInit arrData) ->
         Priv.mapAccumLSimple f arrInit arrData)

-- | Inputs of 'mapAccumLSequence'.
data MapAccumLSequence n acc final a b = MapAccumLSequence
   { mapAccumLSequenceAccum :: Exp acc -> Exp a -> Exp (acc,b)
     -- ^ step function returning the next accumulator and an output element
   , mapAccumLSequenceFinal :: Exp acc -> Exp final
     -- ^ maps an accumulator to the returned final value
   , mapAccumLSequenceInit :: Exp acc
     -- ^ initial accumulator
   , mapAccumLSequenceArray :: Core.Array n a
     -- ^ input array; its shape is the shape of the result
   }

-- FIXME: check correct size of array of initial values
-- | Materialize a 'MapAccumLSequence' via @Priv.mapAccumLSequence@.
-- The final value is returned next to the output array.
mapAccumLSequence ::
   (Shape.C n, Storable n, MultiValueMemory.C n,
    MultiValue.C acc,
    Storable final, MultiValueMemory.C final,
    Storable a, MultiValueMemory.C a,
    Storable b, MultiValueMemory.C b) =>
   Sym.Hull p (MapAccumLSequence n acc final a b) ->
   IO (p -> IO (final, Array n b))
mapAccumLSequence =
   materializeExpArray "mapAccumLSequence"
      (\m -> Core.shape (mapAccumLSequenceArray m))
      (\(MapAccumLSequence f final expInit arr) ->
         Priv.mapAccumLSequence f final expInit arr)

-- | Inputs of 'mapAccumL'.
data MapAccumL sh n acc final a b = MapAccumL
   { mapAccumLAccum :: Exp acc -> Exp a -> Exp (acc,b)
     -- ^ step function returning the next accumulator and an output element
   , mapAccumLFinal :: Exp acc -> Exp final
     -- ^ maps an accumulator to a final value
   , mapAccumLInit :: Core.Array sh acc
     -- ^ initial accumulators; their shape is the shape of the final array
   , mapAccumLArray :: Core.Array (sh, n) a
     -- ^ input array; its shape is the shape of the output array
   }

-- FIXME: check correct size of array of initial values
-- | Materialize a 'MapAccumL' via @Priv.mapAccumL@.
-- The result is a pair of the final-value array and the output array.
mapAccumL ::
   (Shape.C sh, Storable sh, MultiValueMemory.C sh,
    Shape.C n, Storable n, MultiValueMemory.C n,
    MultiValue.C acc,
    Storable final, MultiValueMemory.C final,
    Storable a, MultiValueMemory.C a,
    Storable b, MultiValueMemory.C b) =>
   Sym.Hull p (MapAccumL sh n acc final a b) ->
   IO (p -> IO (Array sh final, Array (sh,n) b))
mapAccumL =
   materialize2 "mapAccumL"
      (\m ->
         Expr.zip
            (Core.shape (mapAccumLInit m))
            (Core.shape (mapAccumLArray m)))
      (\(MapAccumL f final arrInit arrData) ->
         Priv.mapAccumL f final arrInit arrData)

-- | Inputs of 'foldOuterL'.
data FoldOuterL n sh a b = FoldOuterL
   { foldOuterLAccum :: Exp a -> Exp b -> Exp a
     -- ^ step function
   , foldOuterLInit :: Core.Array sh a
     -- ^ initial values; their shape is the shape of the result
   , foldOuterLArray :: Core.Array (n,sh) b
     -- ^ input array with the folded dimension first
   }

-- FIXME: check correct size of array of initial values
-- | Materialize a 'FoldOuterL' via @Priv.foldOuterL@.
foldOuterL ::
   (Shape.C n, Storable n, MultiValueMemory.C n,
    Shape.C sh, Storable sh, MultiValueMemory.C sh,
    Storable a, MultiValueMemory.C a) =>
   Sym.Hull p (FoldOuterL n sh a b) -> IO (p -> IO (Array sh a))
foldOuterL =
   materialize "foldOuterL"
      (\fo -> Core.shape (foldOuterLInit fo))
      (\(FoldOuterL f arrInit arrData) -> Priv.foldOuterL f arrInit arrData)

-- | Inputs of 'mapFilter'.
data MapFilter n a b = MapFilter
   { mapFilterMap :: Exp a -> Exp b
     -- ^ element transformation
   , mapFilterPredicate :: Exp a -> Exp Bool
     -- ^ element predicate
   , mapFilterArray :: Core.Array n a
     -- ^ input array
   }

-- | Materialize a 'MapFilter' via @Priv.mapFilter@.
-- The buffer is allocated for the input shape. The shape returned by
-- @Priv.mapFilter@ then overwrites the stored shape. Presumably this is the
-- count of kept elements; confirm in @Priv.mapFilter@.
mapFilter ::
   (Shape.Sequence n, Storable n, MultiValueMemory.C n,
    Storable b, MultiValueMemory.C b) =>
   Sym.Hull p (MapFilter n a b) -> IO (p -> IO (Array n b))
mapFilter =
   materialize "mapFilter"
      (\mf -> Core.shape (mapFilterArray mf))
      (\(MapFilter f p arr) shapePtr bufferPtr -> do
         resultShape <- Priv.mapFilter f p arr shapePtr bufferPtr
         MultiValueMemory.store resultShape shapePtr)

-- | Inputs of 'filterOuter'.
data FilterOuter n sh a = FilterOuter
   { filterOuterPredicate :: Core.Array n Bool
     -- ^ one selection flag per outer index
   , filterOuterArray :: Core.Array (n,sh) a
     -- ^ input array with the filtered dimension first
   }

-- FIXME: check correct size of row selection array
-- | Materialize a 'FilterOuter' via @Priv.filterOuter@.
-- As in 'mapFilter', the shape returned by the fill code replaces the
-- stored input shape.
filterOuter ::
   (Shape.Sequence n, Storable n, MultiValueMemory.C n,
    Shape.C sh, Storable sh, MultiValueMemory.C sh,
    Storable a, MultiValueMemory.C a) =>
   Sym.Hull p (FilterOuter n sh a) -> IO (p -> IO (Array (n,sh) a))
filterOuter =
   materialize "filterOuter"
      (\fo -> Core.shape (filterOuterArray fo))
      (\(FilterOuter p arr) shapePtr bufferPtr -> do
         resultShape <- Priv.filterOuter p arr shapePtr bufferPtr
         MultiValueMemory.store resultShape shapePtr)

-- | Inputs of 'addDimension'.
data AddDimension sh n a b = AddDimension
   { addDimensionSize :: Exp n
     -- ^ shape of the new trailing dimension
   , addDimensionSelect :: Exp (Shape.Index n) -> Exp a -> Exp b
     -- ^ computes an output element from the new index and an input element
   , addDimensionArray :: Core.Array sh a
     -- ^ input array
   }

-- | Materialize an 'AddDimension' via @Priv.addDimension@.
-- The result shape pairs the input shape with 'addDimensionSize'.
addDimension ::
   (Shape.C sh, Storable sh, MultiValueMemory.C sh,
    Shape.C n, Storable n, MultiValueMemory.C n,
    Storable b, MultiValueMemory.C b) =>
   Sym.Hull p (AddDimension sh n a b) -> IO (p -> IO (Array (sh,n) b))
addDimension =
   materialize "addDimension"
      (\r -> Expr.zip (Core.shape (addDimensionArray r)) (addDimensionSize r))
      (\(AddDimension n select arr) -> Priv.addDimension n select arr)