{-# LANGUAGE GADTs #-}
{-# LANGUAGE ForeignFunctionInterface #-}
{- |
Turn symbolic array computations that are wrapped in a 'Sym.Hull'
into compiled functions that produce physical (storable) arrays.
-}
module Data.Array.Knead.Parameterized.PhysicalHull (
   render,
   MapFilter(..),
   mapFilter,
   FilterOuter(..),
   filterOuter,
   Scatter(..),
   scatter,
   ScatterMaybe(..),
   scatterMaybe,
   MapAccumLSimple(..),
   mapAccumLSimple,
   MapAccumLSequence(..),
   mapAccumLSequence,
   MapAccumL(..),
   mapAccumL,
   FoldOuterL(..),
   foldOuterL,
   AddDimension(..),
   addDimension,
   ) where

import qualified Data.Array.Knead.Parameterized.Private as Sym
import qualified Data.Array.Knead.Simple.PhysicalPrivate as Priv
import qualified Data.Array.Knead.Simple.Private as Core
import qualified Data.Array.Knead.Shape as Shape
import qualified Data.Array.Knead.Expression as Expr
import Data.Array.Knead.Simple.PhysicalPrivate (MarshalPtr)

import Data.Array.Comfort.Storable.Unchecked (Array(Array))

import qualified LLVM.DSL.Execution as Code
import LLVM.DSL.Expression (Exp, unExp)

import qualified LLVM.Extra.Multi.Value as MultiValue
import qualified LLVM.Extra.Storable as Storable
import qualified LLVM.Extra.Marshal as Marshal
import qualified LLVM.Extra.Memory as Memory
import qualified LLVM.Extra.Arithmetic as A

import qualified LLVM.Core as LLVM

import Foreign.Marshal.Array (allocaArray, )
import Foreign.Marshal.Alloc (alloca, )
import Foreign.Storable (Storable, peek, peekElemOff, )
import Foreign.ForeignPtr (ForeignPtr, withForeignPtr, mallocForeignPtrArray, )
import Foreign.Ptr (FunPtr, Ptr, )

import Control.Exception (bracket, )
import Control.Monad.HT (void, )
import Control.Applicative (liftA2, )


-- | Allocate a buffer that has room for the given number of elements.
mallocArray :: (Storable a) => Shape.Size -> IO (ForeignPtr a)
mallocArray count = mallocForeignPtrArray (fromIntegral count)


-- | Shape of a @foreign import "dynamic"@ stub:
-- it converts a function pointer into a callable Haskell function.
type Importer f = FunPtr f -> f

-- | Calls a compiled \"shape\" function.
-- The compiled code stores the shape behind the second pointer
-- and returns the number of elements.
foreign import ccall safe "dynamic" callShaper ::
   Importer (LLVM.Ptr param -> LLVM.Ptr shape -> IO Shape.Size)

-- | Calls a compiled \"fill\" function that writes into the element buffer.
foreign import ccall safe "dynamic" callFill ::
   Importer (LLVM.Ptr param -> LLVM.Ptr shape -> Ptr a -> IO ())


{- |
Compile a pair of functions named \"shape\" and \"fill\"
and return an action that runs them for a concrete parameter.

The returned action proceeds as follows:

* create the parameter via the hull and release its context afterwards,

* call the compiled shape function,
  which stores the shape and returns the element count,

* allocate a buffer of that element count,

* call the compiled fill function on shape and buffer,

* read the shape back and pair it with the buffer.

Attention:
The @fill@ function may alter the shape.
An example is 'mapFilter'.
That is why the shape is read only after filling.
-}
materialize ::
   (Shape.C sh, Marshal.MV sh, Storable.C a) =>
   String ->
   (core -> Exp sh) ->
   (core ->
    LLVM.Value (MarshalPtr sh) -> LLVM.Value (Ptr a) ->
    LLVM.CodeGenFunction () ()) ->
   Sym.Hull p core -> IO (p -> IO (Array sh a))
materialize name shape fill (Sym.Hull core create delete) = do
   (computeShape, fillBuffer) <-
      Code.compile name $
      liftA2 (,)
         (Code.createFunction callShaper "shape" $
          \paramPtr resultPtr -> do
            param <- Memory.load paramPtr
            sh <- unExp (shape (core param))
            Memory.store sh resultPtr
            LLVM.ret =<< Shape.size sh)
         (Code.createFunction callFill "fill" $
          \paramPtr shapePtr bufferPtr -> do
            param <- Memory.load paramPtr
            fill (core param) shapePtr bufferPtr
            LLVM.ret ())

   return $ \p ->
      bracket (create p) (\ctxParam -> delete (fst ctxParam)) $
            \(_ctx, param) ->
      Marshal.alloca $ \shapePtr ->
      Marshal.with param $ \paramPtr -> do
         size <- computeShape paramPtr shapePtr
         buffer <- mallocArray size
         withForeignPtr buffer $ \bufferPtr ->
            fillBuffer paramPtr shapePtr bufferPtr
         -- read the shape late, because filling may have modified it
         sh <- Marshal.peek shapePtr
         return (Array sh buffer)


-- | Calls a compiled \"fill\" function
-- that additionally gets a pointer for a single extra result value.
foreign import ccall safe "dynamic" callFillExpArray ::
   Importer
      (LLVM.Ptr param -> Ptr final -> LLVM.Ptr shape -> Ptr a -> IO ())

{- |
Variant of 'materialize' where the fill function
receives an additional pointer to a single value of type @b@.
That value is read after filling
and returned together with the array.
-}
materializeExpArray ::
   (Shape.C sh, Marshal.MV sh, Storable.C a, Storable.C b) =>
   String ->
   (core -> Exp sh) ->
   (core ->
    LLVM.Value (Ptr b) ->
    LLVM.Value (MarshalPtr sh) -> LLVM.Value (Ptr a) ->
    LLVM.CodeGenFunction () ()) ->
   Sym.Hull p core -> IO (p -> IO (b, Array sh a))
materializeExpArray name shape fill (Sym.Hull core create delete) = do
   (computeShape, fillBuffer) <-
      Code.compile name $
      liftA2 (,)
         (Code.createFunction callShaper "shape" $
          \paramPtr resultPtr -> do
            param <- Memory.load paramPtr
            sh <- unExp (shape (core param))
            Memory.store sh resultPtr
            LLVM.ret =<< Shape.size sh)
         (Code.createFunction callFillExpArray "fill" $
          \paramPtr finalPtr shapePtr bufferPtr -> do
            param <- Memory.load paramPtr
            fill (core param) finalPtr shapePtr bufferPtr
            LLVM.ret ())

   return $ \p ->
      bracket (create p) (\ctxParam -> delete (fst ctxParam)) $
            \(_ctx, param) ->
      Marshal.alloca $ \shapePtr ->
      alloca $ \finalPtr ->
      Marshal.with param $ \paramPtr -> do
         size <- computeShape paramPtr shapePtr
         buffer <- mallocArray size
         withForeignPtr buffer $ \bufferPtr ->
            fillBuffer paramPtr finalPtr shapePtr bufferPtr
         sh <- Marshal.peek shapePtr
         finalValue <- peek finalPtr
         return (finalValue, Array sh buffer)


-- | Calls a compiled \"shape\" function for two result arrays.
-- The last pointer refers to room for two element counts.
foreign import ccall safe "dynamic" callShaper2 ::
   Importer
      (LLVM.Ptr param -> LLVM.Ptr shapeA -> LLVM.Ptr shapeB ->
       Ptr Shape.Size -> IO ())

-- | Calls a compiled \"fill\" function that writes two element buffers.
foreign import ccall safe "dynamic" callFill2 ::
   Importer
      (LLVM.Ptr param ->
       LLVM.Ptr shapeA -> Ptr a -> LLVM.Ptr shapeB -> Ptr b -> IO ())

{- |
Variant of 'materialize' that produces two arrays at once.

The compiled shape function stores both shapes
and writes both element counts into a two-element array:
index 0 holds the count for the first array,
index 1 the count for the second one.
-}
materialize2 ::
   (Shape.C sha, Marshal.MV sha,
    Shape.C shb, Marshal.MV shb,
    Storable.C a, Storable.C b) =>
   String ->
   (core -> Exp (sha,shb)) ->
   (core ->
    (LLVM.Value (MarshalPtr sha), LLVM.Value (Ptr a)) ->
    (LLVM.Value (MarshalPtr shb), LLVM.Value (Ptr b)) ->
    LLVM.CodeGenFunction () ()) ->
   Sym.Hull p core -> IO (p -> IO (Array sha a, Array shb b))
materialize2 name shape fill (Sym.Hull core create delete) = do
   (computeShapes, fillBuffers) <-
      Code.compile name $
      liftA2 (,)
         (Code.createFunction callShaper2 "shape" $
          \paramPtr shapeAPtr shapeBPtr sizesPtr -> do
            param <- Memory.load paramPtr
            (sha,shb) <- fmap MultiValue.unzip (unExp (shape (core param)))
            Memory.store sha shapeAPtr
            Memory.store shb shapeBPtr
            sizeAPtr <- LLVM.bitcast sizesPtr
            sizeA <- Shape.size sha
            LLVM.store sizeA sizeAPtr
            sizeBPtr <- A.advanceArrayElementPtr sizeAPtr
            sizeB <- Shape.size shb
            LLVM.store sizeB sizeBPtr
            LLVM.ret ())
         (Code.createFunction callFill2 "fill" $
          \paramPtr shapeAPtr bufferAPtr shapeBPtr bufferBPtr -> do
            param <- Memory.load paramPtr
            fill (core param) (shapeAPtr, bufferAPtr) (shapeBPtr, bufferBPtr)
            LLVM.ret ())

   return $ \p ->
      bracket (create p) (\ctxParam -> delete (fst ctxParam)) $
            \(_ctx, param) ->
      Marshal.alloca $ \shapeAPtr ->
      Marshal.alloca $ \shapeBPtr ->
      allocaArray 2 $ \sizesPtr ->
      Marshal.with param $ \paramPtr -> do
         computeShapes paramPtr shapeAPtr shapeBPtr sizesPtr
         sizeA <- peekElemOff sizesPtr 0
         bufferA <- mallocArray sizeA
         sizeB <- peekElemOff sizesPtr 1
         bufferB <- mallocArray sizeB
         withForeignPtr bufferA $ \ptrA ->
            withForeignPtr bufferB $ \ptrB ->
               fillBuffers paramPtr shapeAPtr ptrA shapeBPtr ptrB
         sha <- Marshal.peek shapeAPtr
         shb <- Marshal.peek shapeBPtr
         return (Array sha bufferA, Array shb bufferB)


{- |
Compile a symbolic array into a function that computes a physical array.
The generated code loops over the shape
and stores one element per index into consecutive buffer positions.
-}
render ::
   (Shape.C sh, Shape.Index sh ~ ix, Marshal.MV sh, Storable.C a) =>
   Sym.Hull p (Core.Array sh a) -> IO (p -> IO (Array sh a))
render =
   materialize "render" Core.shape $
   \(Core.Array esh code) shapePtr bufferPtr -> do
      sh <- Shape.load esh shapePtr
      let writeElement ix ptr = do
            x <- code ix
            Storable.storeNextMultiValue x ptr
      void $ Shape.loop writeElement sh bufferPtr


-- | Arguments of 'scatter'.
-- They are handed to 'Priv.scatter' in the order of the fields.
data Scatter sh0 sh1 a =
   Scatter {
      scatterAccum :: Exp a -> Exp a -> Exp a,
      scatterInit :: Core.Array sh1 a,
      scatterMap :: Core.Array sh0 (Shape.Index sh1, a)
   }

-- | Compile a scatter operation.
-- The result has the shape of 'scatterInit'.
scatter ::
   (Shape.C sh0, Shape.Index sh0 ~ ix0,
    Shape.C sh1, Shape.Index sh1 ~ ix1, Marshal.MV sh1,
    Storable.C a) =>
   Sym.Hull p (Scatter sh0 sh1 a) -> IO (p -> IO (Array sh1 a))
scatter =
   materialize "scatter" (\core -> Core.shape (scatterInit core)) $
   \(Scatter combine initial updates) ->
      Priv.scatter combine initial updates


-- | Arguments of 'scatterMaybe'.
-- They are handed to 'Priv.scatterMaybe' in the order of the fields.
data ScatterMaybe sh0 sh1 a =
   ScatterMaybe {
      scatterMaybeAccum :: Exp a -> Exp a -> Exp a,
      scatterMaybeInit :: Core.Array sh1 a,
      scatterMaybeMap :: Core.Array sh0 (Maybe (Shape.Index sh1, a))
   }

-- | Compile a scatter operation with optional updates.
-- The result has the shape of 'scatterMaybeInit'.
scatterMaybe ::
   (Shape.C sh0, Shape.Index sh0 ~ ix0,
    Shape.C sh1, Shape.Index sh1 ~ ix1, Marshal.MV sh1,
    Storable.C a) =>
   Sym.Hull p (ScatterMaybe sh0 sh1 a) -> IO (p -> IO (Array sh1 a))
scatterMaybe =
   materialize "scatterMaybe"
      (\core -> Core.shape (scatterMaybeInit core)) $
   \(ScatterMaybe combine initial updates) ->
      Priv.scatterMaybe combine initial updates


-- | Arguments of 'mapAccumLSimple'.
-- They are handed to 'Priv.mapAccumLSimple' in the order of the fields.
data MapAccumLSimple sh n acc a b =
   MapAccumLSimple {
      mapAccumLSimpleAccum :: Exp acc -> Exp a -> Exp (acc,b),
      mapAccumLSimpleInit :: Core.Array sh acc,
      mapAccumLSimpleArray :: Core.Array (sh, n) a
   }

-- | Compile an accumulating map.
-- The result has the shape of 'mapAccumLSimpleArray'.
mapAccumLSimple ::
   (Shape.C sh, Marshal.MV sh,
    Shape.C n, Marshal.MV n,
    MultiValue.C acc,
    Storable.C a,
    Storable.C b) =>
   Sym.Hull p (MapAccumLSimple sh n acc a b) ->
   IO (p -> IO (Array (sh,n) b))
mapAccumLSimple =
   materialize "mapAccumLSimple"
      (\core -> Core.shape (mapAccumLSimpleArray core)) $
   \(MapAccumLSimple step initial input) ->
      Priv.mapAccumLSimple step initial input


-- | Arguments of 'mapAccumLSequence'.
-- They are handed to 'Priv.mapAccumLSequence' in the order of the fields.
data MapAccumLSequence n acc final a b =
   MapAccumLSequence {
      mapAccumLSequenceAccum :: Exp acc -> Exp a -> Exp (acc,b),
      mapAccumLSequenceFinal :: Exp acc -> Exp final,
      mapAccumLSequenceInit :: Exp acc,
      mapAccumLSequenceArray :: Core.Array n a
   }

-- FIXME: check correct size of array of initial values

-- | Compile an accumulating map over a single sequence.
-- Returns the extra final value together with the mapped array,
-- which has the shape of 'mapAccumLSequenceArray'.
mapAccumLSequence ::
   (Shape.C n, Marshal.MV n,
    MultiValue.C acc,
    Storable.C final, MultiValue.C final,
    Storable.C a,
    Storable.C b) =>
   Sym.Hull p (MapAccumLSequence n acc final a b) ->
   IO (p -> IO (final, Array n b))
mapAccumLSequence =
   materializeExpArray "mapAccumLSequence"
      (\core -> Core.shape (mapAccumLSequenceArray core)) $
   \(MapAccumLSequence step finish initial input) ->
      Priv.mapAccumLSequence step finish initial input


-- | Arguments of 'mapAccumL'.
-- They are handed to 'Priv.mapAccumL' in the order of the fields.
data MapAccumL sh n acc final a b =
   MapAccumL {
      mapAccumLAccum :: Exp acc -> Exp a -> Exp (acc,b),
      mapAccumLFinal :: Exp acc -> Exp final,
      mapAccumLInit :: Core.Array sh acc,
      mapAccumLArray :: Core.Array (sh, n) a
   }

-- FIXME: check correct size of array of initial values

-- | Compile an accumulating map that yields two arrays:
-- one with the shape of 'mapAccumLInit'
-- and one with the shape of 'mapAccumLArray'.
mapAccumL ::
   (Shape.C sh, Marshal.MV sh,
    Shape.C n, Marshal.MV n,
    MultiValue.C acc,
    Storable.C final, MultiValue.C final,
    Storable.C a,
    Storable.C b) =>
   Sym.Hull p (MapAccumL sh n acc final a b) ->
   IO (p -> IO (Array sh final, Array (sh,n) b))
mapAccumL =
   materialize2 "mapAccumL"
      (\core ->
         Expr.zip
            (Core.shape (mapAccumLInit core))
            (Core.shape (mapAccumLArray core))) $
   \(MapAccumL step finish initial input) ->
      Priv.mapAccumL step finish initial input


-- | Arguments of 'foldOuterL'.
-- They are handed to 'Priv.foldOuterL' in the order of the fields.
data FoldOuterL n sh a b =
   FoldOuterL {
      foldOuterLAccum :: Exp a -> Exp b -> Exp a,
      foldOuterLInit :: Core.Array sh a,
      foldOuterLArray :: Core.Array (n,sh) b
   }

-- FIXME: check correct size of array of initial values

-- | Compile a fold along the outer dimension.
-- The result has the shape of 'foldOuterLInit'.
foldOuterL ::
   (Shape.C n, Marshal.MV n,
    Shape.C sh, Marshal.MV sh,
    Storable.C a) =>
   Sym.Hull p (FoldOuterL n sh a b) -> IO (p -> IO (Array sh a))
foldOuterL =
   materialize "foldOuterL" (\core -> Core.shape (foldOuterLInit core)) $
   \(FoldOuterL step initial input) ->
      Priv.foldOuterL step initial input


-- | Arguments of 'mapFilter'.
-- They are handed to 'Priv.mapFilter' in the order of the fields.
data MapFilter n a b =
   MapFilter {
      mapFilterMap :: Exp a -> Exp b,
      mapFilterPredicate :: Exp a -> Exp Bool,
      mapFilterArray :: Core.Array n a
   }

-- | Compile a combined map and filter.
-- The buffer is allocated for the shape of 'mapFilterArray';
-- the shape returned by 'Priv.mapFilter'
-- then overwrites the stored shape.
mapFilter ::
   (Shape.Sequence n, Marshal.MV n,
    Storable.C b) =>
   Sym.Hull p (MapFilter n a b) -> IO (p -> IO (Array n b))
mapFilter =
   materialize "mapFilter" (\core -> Core.shape (mapFilterArray core)) $
   \(MapFilter mapElem keep input) shapePtr bufferPtr -> do
      newShape <- Priv.mapFilter mapElem keep input shapePtr bufferPtr
      Memory.store newShape shapePtr


-- | Arguments of 'filterOuter'.
-- They are handed to 'Priv.filterOuter' in the order of the fields.
data FilterOuter n sh a =
   FilterOuter {
      filterOuterPredicate :: Core.Array n Bool,
      filterOuterArray :: Core.Array (n,sh) a
   }

-- FIXME: check correct size of row selection array

-- | Compile a filter along the outer dimension.
-- The buffer is allocated for the shape of 'filterOuterArray';
-- the shape returned by 'Priv.filterOuter'
-- then overwrites the stored shape.
filterOuter ::
   (Shape.Sequence n, Marshal.MV n,
    Shape.C sh, Marshal.MV sh,
    Storable.C a) =>
   Sym.Hull p (FilterOuter n sh a) -> IO (p -> IO (Array (n,sh) a))
filterOuter =
   materialize "filterOuter"
      (\core -> Core.shape (filterOuterArray core)) $
   \(FilterOuter keep input) shapePtr bufferPtr -> do
      newShape <- Priv.filterOuter keep input shapePtr bufferPtr
      Memory.store newShape shapePtr


-- | Arguments of 'addDimension'.
-- They are handed to 'Priv.addDimension' in the order of the fields.
data AddDimension sh n a b =
   AddDimension {
      addDimensionSize :: Exp n,
      addDimensionSelect :: Exp (Shape.Index n) -> Exp a -> Exp b,
      addDimensionArray :: Core.Array sh a
   }

-- | Compile an operation that extends an array by an inner dimension.
-- The result shape pairs the shape of 'addDimensionArray'
-- with 'addDimensionSize'.
addDimension ::
   (Shape.C sh, Marshal.MV sh,
    Shape.C n, Marshal.MV n,
    Storable.C b) =>
   Sym.Hull p (AddDimension sh n a b) -> IO (p -> IO (Array (sh,n) b))
addDimension =
   materialize "addDimension"
      (\core ->
         Expr.zip
            (Core.shape (addDimensionArray core))
            (addDimensionSize core)) $
   \(AddDimension size select input) ->
      Priv.addDimension size select input