{-# LANGUAGE GADTs #-}
{-# LANGUAGE ForeignFunctionInterface #-}
{- |
Connects parameterized symbolic arrays ('Sym.Array') with physical,
memory-backed arrays ('Phys.Array').

'feed' turns a physical array into a symbolic one.
'the', 'renderShape', 'render' and the remaining functions turn symbolic
arrays into functions of the parameter, @p -> IO ...@.
-}
module Data.Array.Knead.Parameterized.Physical (
   Phys.Array,
   Array.shape,
   Phys.fromList,
   feed,
   the,
   render,
   renderShape,
   mapAccumLSimple,
   foldOuterL,
   scatter,
   scatterMaybe,
   permute,
   ) where

import qualified Data.Array.Knead.Parameterized.PhysicalHull as PhysHull
import qualified Data.Array.Knead.Parameterized.Private as Sym
import qualified Data.Array.Knead.Simple.Physical as Phys
import qualified Data.Array.Knead.Simple.Private as Core
import qualified Data.Array.Knead.Parameter as Param
import qualified Data.Array.Knead.Shape as Shape
import qualified Data.Array.Knead.Expression as Expr
import qualified Data.Array.Knead.Code as Code
import Data.Array.Knead.Expression (Exp, unExp, )
import Data.Array.Knead.Code (getElementPtr, compile, )

import qualified Data.Array.Comfort.Storable.Unchecked as Array

import qualified LLVM.Extra.Multi.Value.Memory as MultiValueMemory
import qualified LLVM.Extra.Multi.Value as MultiValue
import qualified LLVM.Extra.Memory as Memory
import qualified LLVM.Core as LLVM

import Foreign.Marshal.Utils (with, )
import Foreign.Marshal.Alloc (alloca, )
import Foreign.Storable (Storable, peek, )
import Foreign.ForeignPtr (withForeignPtr, touchForeignPtr, )
import Foreign.Ptr (FunPtr, Ptr, )

import Control.Exception (bracket, )
import Control.Monad.HT ((<=<), )
import Control.Applicative (liftA2, )

import Data.Tuple.HT (mapFst, )


{-# INLINE feed #-}
-- | Use a physical array, taken from the parameter, as a symbolic array.
--
-- The array shape is routed through 'Param.withMulti'.
-- Element reads in the generated code go through 'getElementPtr' into the
-- array buffer, followed by 'Memory.load'.
--
-- The creation action extracts the buffer of the physical array.
-- It hands the raw pointer to the generated code and returns the
-- 'ForeignPtr' itself as context.
-- The finalizing action is 'touchForeignPtr' on that context.
-- Consumers such as 'the' and 'renderShape' run it (via 'bracket') after
-- the compiled code has finished, which keeps the buffer alive meanwhile.
feed ::
   (Shape.C sh, Storable sh, MultiValueMemory.C sh,
    MultiValueMemory.C a) =>
   Param.T p (Phys.Array sh a) -> Sym.Array p sh a
feed arr =
   Param.withMulti (fmap Array.shape arr) $ \getShape valueShape ->
   Sym.Array
      (\packed ->
         case mapFst valueShape $ MultiValue.unzip packed of
            (shapeVal, MultiValue.Cons bufPtr) ->
               Core.Array (Expr.lift0 shapeVal)
                  (Memory.load <=< getElementPtr shapeVal bufPtr))
      (\p ->
         let buf = Array.buffer (Param.get arr p)
         in  withForeignPtr buf $ \rawPtr ->
                return
                   (buf,
                    (getShape p, MultiValueMemory.castStructPtr rawPtr)))
      touchForeignPtr


-- | Type of a @dynamic@ foreign import.
-- It converts a function pointer into a callable Haskell function.
type Importer f = FunPtr f -> f

-- | Invokes compiled code taking a pointer to the parameters and a pointer
-- to the result slot.
foreign import ccall safe "dynamic"
   callThe :: Importer (Ptr param -> Ptr am -> IO ())

-- | Read the only element of a scalar-shaped symbolic array.
--
-- The LLVM function is compiled once, when the outer 'IO' action runs.
-- The returned function does the following for every parameter value:
--
-- * It runs the array's creation action.
-- * It passes the parameter record by pointer to the compiled code.
--   That code evaluates the element at 'Shape.zeroIndex' and stores it into
--   a temporary buffer.
-- * It reads the element back with 'peek'.
--
-- The finalizing action is run through 'bracket', so it also runs when an
-- exception is raised.
the ::
   (Shape.Scalar z, MultiValueMemory.C a, Storable a) =>
   Sym.Array p z a -> IO (p -> IO a)
the (Sym.Array arr create delete) = do
   evalElement <-
      compile "the" $
      Code.createFunction callThe "eval" $ \paramPtr resultPtr -> do
         params <- Memory.load paramPtr
         case arr params of
            Core.Array z code -> do
               element <- code (Shape.zeroIndex z)
               Memory.store element resultPtr
         LLVM.ret ()
   return $ \p ->
      bracket (create p) (delete . fst) $ \(_context, paramRecord) ->
      with paramRecord $ \paramBuf ->
      alloca $ \resultBuf -> do
         evalElement
            (MultiValueMemory.castStructPtr paramBuf)
            (MultiValueMemory.castStructPtr resultBuf)
         peek resultBuf


-- | Invokes compiled code taking a pointer to the parameters and a pointer
-- to the shape slot.
-- The code returns a 'Shape.Size'.
foreign import ccall safe "dynamic"
   callShaper :: Importer (Ptr param -> Ptr shape -> IO Shape.Size)

-- | Compute the shape of a symbolic array without computing its elements.
--
-- The compiled code does two things:
--
-- * It evaluates the shape expression and writes the shape into a
--   caller-allocated buffer.
-- * It returns 'Shape.size' of that shape.
--
-- The returned function yields the pair (shape, size).
-- The array's finalizing action is run through 'bracket'.
renderShape ::
   (Shape.C sh, Storable sh, MultiValueMemory.C sh,
    Storable a, MultiValueMemory.C a) =>
   Sym.Array p sh a -> IO (p -> IO (sh, Shape.Size))
renderShape (Sym.Array arr create delete) = do
   computeShape <-
      compile "renderShape" $
      Code.createFunction callShaper "shape" $ \paramPtr resultPtr -> do
         params <- Memory.load paramPtr
         case arr params of
            Core.Array esh _code -> do
               shapeVal <- unExp esh
               MultiValueMemory.store shapeVal resultPtr
               sizeVal <- Shape.size shapeVal
               LLVM.ret sizeVal
   return $ \p ->
      bracket (create p) (delete . fst) $ \(_context, paramRecord) ->
      alloca $ \shapeBuf ->
      with paramRecord $ \paramBuf -> do
         sizeResult <-
            computeShape
               (MultiValueMemory.castStructPtr paramBuf)
               (MultiValueMemory.castStructPtr shapeBuf)
         fmap (\shapeResult -> (shapeResult, sizeResult)) (peek shapeBuf)

-- | Materialize a symbolic array as a physical array.
-- This delegates to 'PhysHull.render' on the array's hull.
render ::
   (Shape.C sh, Storable sh, MultiValueMemory.C sh,
    Storable a, MultiValueMemory.C a) =>
   Sym.Array p sh a -> IO (p -> IO (Phys.Array sh a))
render arr = PhysHull.render (Sym.arrayHull arr)


-- | Combine the accumulator function @f@ with the hulls of two arrays:
--
-- * the initial-accumulator array, of shape @sh@;
-- * the input array, of shape @(sh, n)@.
--
-- The combination is compiled via 'PhysHull.mapAccumLSimple'.
-- The result has shape @(sh, n)@.
-- Presumably the accumulator is threaded along the @n@ dimension; see
-- 'PhysHull.mapAccumLSimple' for the exact semantics.
mapAccumLSimple ::
   (Shape.C sh, Storable sh, MultiValueMemory.C sh,
    Shape.C n, Storable n, MultiValueMemory.C n,
    MultiValue.C acc,
    Storable a, MultiValueMemory.C a,
    Storable b, MultiValueMemory.C b) =>
   (Exp acc -> Exp a -> Exp (acc,b)) ->
   Sym.Array p sh acc ->
   Sym.Array p (sh, n) a ->
   IO (p -> IO (Phys.Array (sh,n) b))
mapAccumLSimple f arrInit arrMap =
   PhysHull.mapAccumLSimple hulls
   where
      hulls =
         liftA2 (PhysHull.MapAccumLSimple f)
            (Sym.arrayHull arrInit)
            (Sym.arrayHull arrMap)

-- | Combine the folding function @f@ with the hulls of two arrays:
--
-- * the initial array, of shape @sh@;
-- * the input array, of shape @(n, sh)@.
--
-- The combination is compiled via 'PhysHull.foldOuterL'.
-- The result has shape @sh@.
-- Presumably this is a left fold over the outer @n@ dimension; see
-- 'PhysHull.foldOuterL' to confirm.
foldOuterL ::
   (Shape.C sh, Storable sh, MultiValueMemory.C sh,
    Shape.C n, Storable n, MultiValueMemory.C n,
    Storable a, MultiValueMemory.C a) =>
   (Exp a -> Exp b -> Exp a) ->
   Sym.Array p sh a ->
   Sym.Array p (n,sh) b ->
   IO (p -> IO (Phys.Array sh a))
foldOuterL f arrInit arrMap =
   PhysHull.foldOuterL hulls
   where
      hulls =
         liftA2 (PhysHull.FoldOuterL f)
            (Sym.arrayHull arrInit)
            (Sym.arrayHull arrMap)

-- | Combine the function @accum@ with the hulls of two arrays:
--
-- * the base array, of shape @sh1@;
-- * an array of (target index, value) pairs.
--
-- The combination is compiled via 'PhysHull.scatter'.
-- Presumably each value is merged into the base at its target index using
-- @accum@; see 'PhysHull.scatter' for the exact semantics.
scatter ::
   (Shape.C sh0, Shape.Index sh0 ~ ix0,
    Shape.C sh1, Shape.Index sh1 ~ ix1,
    Storable sh1, MultiValueMemory.C sh1,
    Storable a, MultiValueMemory.C a) =>
   (Exp a -> Exp a -> Exp a) ->
   Sym.Array p sh1 a ->
   Sym.Array p sh0 (ix1, a) ->
   IO (p -> IO (Phys.Array sh1 a))
scatter accum arrBase arrMap =
   PhysHull.scatter hulls
   where
      hulls =
         liftA2 (PhysHull.Scatter accum)
            (Sym.arrayHull arrBase)
            (Sym.arrayHull arrMap)

-- | Variant of 'scatter' whose entries are wrapped in 'Maybe'.
-- It is compiled via 'PhysHull.scatterMaybe'.
-- Presumably 'Nothing' entries leave the base array untouched; confirm
-- in 'PhysHull.scatterMaybe'.
scatterMaybe ::
   (Shape.C sh0, Shape.Index sh0 ~ ix0,
    Shape.C sh1, Shape.Index sh1 ~ ix1,
    Storable sh1, MultiValueMemory.C sh1,
    Storable a, MultiValueMemory.C a) =>
   (Exp a -> Exp a -> Exp a) ->
   Sym.Array p sh1 a ->
   Sym.Array p sh0 (Maybe (ix1, a)) ->
   IO (p -> IO (Phys.Array sh1 a))
scatterMaybe accum arrBase arrMap =
   PhysHull.scatterMaybe hulls
   where
      hulls =
         liftA2 (PhysHull.ScatterMaybe accum)
            (Sym.arrayHull arrBase)
            (Sym.arrayHull arrMap)

-- | Move the elements of @input@ into the base array @deflt@.
--
-- Every element at index @ix@ is tagged with the target index @ixmap ix@.
-- The tagged array is then handed to 'scatter', together with @accum@ and
-- @deflt@.
permute ::
   (Shape.C sh0, Shape.Index sh0 ~ ix0,
    Shape.C sh1, Shape.Index sh1 ~ ix1,
    Storable sh1, MultiValueMemory.C sh1,
    Storable a, MultiValueMemory.C a) =>
   (Exp a -> Exp a -> Exp a) ->
   Sym.Array p sh1 a ->
   (Exp ix0 -> Exp ix1) ->
   Sym.Array p sh0 a ->
   IO (p -> IO (Phys.Array sh1 a))
permute accum deflt ixmap input =
   scatter accum deflt (Core.mapWithIndex tagWithTarget input)
   where
      tagWithTarget ix = Expr.lift2 MultiValue.zip (ixmap ix)