{-# LANGUAGE GADTs #-}
{-# LANGUAGE ForeignFunctionInterface #-}
-- |
-- Entry points that turn parameterized symbolic arrays ('Sym.Array')
-- into compiled functions.  Each function here compiles once (in 'IO')
-- and hands back a closure @p -> IO result@ that can be run for
-- different parameter values.
module Data.Array.Knead.Parameterized.Physical (
  Phys.Array,
  Array.shape,
  Phys.fromList,
  feed,
  the,
  theMarshal,
  render,
  renderShape,
  mapAccumLSimple,
  foldOuterL,
  scatter,
  scatterMaybe,
  permute,
  ) where

import qualified Data.Array.Knead.Parameterized.PhysicalHull as PhysHull
import qualified Data.Array.Knead.Parameterized.Private as Sym
import qualified Data.Array.Knead.Simple.Physical as Phys
import qualified Data.Array.Knead.Simple.Private as Core
import qualified Data.Array.Knead.Shape as Shape
import qualified Data.Array.Knead.Expression as Expr
import Data.Array.Knead.Code (getElementPtr)

import qualified Data.Array.Comfort.Storable.Unchecked as Array

import qualified LLVM.DSL.Parameter as Param
import qualified LLVM.DSL.Execution as Code
import LLVM.DSL.Expression (Exp, unExp)

import qualified LLVM.Extra.Multi.Value as MultiValue
import qualified LLVM.Extra.Storable as Storable
import qualified LLVM.Extra.Marshal as Marshal
import qualified LLVM.Extra.Memory as Memory

import qualified LLVM.Core as LLVM

import Foreign.Marshal.Alloc (alloca, )
import Foreign.Storable (peek, )
import Foreign.ForeignPtr (withForeignPtr, touchForeignPtr, )
import Foreign.Ptr (FunPtr, Ptr, )

import Control.Exception (bracket, )
import Control.Monad.HT ((<=<), )
import Control.Applicative (liftA2, )

import Data.Tuple.HT (mapFst, )


-- | Make a physical array, selected from the parameter @p@, available
-- as a symbolic array.
--
-- The shape is passed through 'Param.withMulti'; the element buffer is
-- passed as a raw pointer next to it.  The 'Foreign.ForeignPtr.ForeignPtr'
-- of the buffer is returned as the context value and the cleanup action
-- is 'touchForeignPtr' (presumably so that the buffer stays alive while
-- the raw pointer is in use).
{-# INLINE feed #-}
feed ::
  (Shape.C sh, Marshal.MV sh, Storable.C a) =>
  Param.T p (Phys.Array sh a) -> Sym.Array p sh a
feed arr =
  Param.withMulti (fmap Array.shape arr) $ \shapeParam shapeValue ->
    Sym.Array
      -- code generation side: unpack (shape, pointer) and load by index
      (\packed ->
        case mapFst shapeValue (MultiValue.unzip packed) of
          (sh, MultiValue.Cons basePtr) ->
            Core.Array (Expr.lift0 sh)
              (Storable.loadMultiValue <=< getElementPtr sh basePtr))
      -- run-time side: build (context, (shape parameter, raw pointer))
      (\p ->
        let fptr = Array.buffer (Param.get arr p)
        in withForeignPtr fptr $ \rawPtr ->
             return (fptr, (shapeParam p, rawPtr)))
      touchForeignPtr


-- | Shape of a @foreign import ... "dynamic"@ stub: it converts a
-- function pointer into a callable Haskell function of the same type.
type Importer f = FunPtr f -> f

foreign import ccall safe "dynamic" callThe ::
  Importer (LLVM.Ptr param -> Ptr a -> IO ())

-- | Compile an evaluator for the single element of a scalar-shaped
-- array.
--
-- The generated function (named @"eval"@) loads the parameter, computes
-- the element at 'Shape.zeroIndex' and stores it with
-- 'Storable.storeMultiValue'.  The returned closure allocates the
-- result cell with 'alloca', runs the compiled function and reads the
-- result with 'peek'.  Parameter creation and deletion are paired by
-- 'bracket'.
the ::
  (Shape.Scalar z, Storable.C a, MultiValue.C a) =>
  Sym.Array p z a -> IO (p -> IO a)
the (Sym.Array arr create delete) = do
  run <-
    Code.compile "the" $
      Code.createFunction callThe "eval" $ \paramPtr resultPtr -> do
        param <- Memory.load paramPtr
        case arr param of
          Core.Array z code -> do
            element <- code (Shape.zeroIndex z)
            Storable.storeMultiValue element resultPtr
        LLVM.ret ()
  return $ \p ->
    bracket (create p) (delete . fst) $ \(_ctx, param) ->
      Marshal.with param $ \paramPtr ->
        alloca $ \resultPtr -> do
          run paramPtr resultPtr
          peek resultPtr


foreign import ccall safe "dynamic" callTheMarshal ::
  Importer (LLVM.Ptr param -> LLVM.Ptr a -> IO ())

-- | Like 'the', but the result is transferred via the 'Marshal.C'
-- class: the generated code writes it with 'Memory.store' and the
-- closure reads it back with 'Marshal.peek' from a cell obtained by
-- 'Marshal.alloca'.
theMarshal ::
  (Shape.Scalar z, Marshal.C a, MultiValue.C a) =>
  Sym.Array p z a -> IO (p -> IO a)
theMarshal (Sym.Array arr create delete) = do
  run <-
    Code.compile "the-marshal" $
      Code.createFunction callTheMarshal "eval" $ \paramPtr resultPtr -> do
        param <- Memory.load paramPtr
        case arr param of
          Core.Array z code -> do
            element <- code (Shape.zeroIndex z)
            Memory.store element resultPtr
        LLVM.ret ()
  return $ \p ->
    bracket (create p) (delete . fst) $ \(_ctx, param) ->
      Marshal.with param $ \paramPtr ->
        Marshal.alloca $ \resultPtr -> do
          run paramPtr resultPtr
          Marshal.peek resultPtr


foreign import ccall safe "dynamic" callShaper ::
  Importer (LLVM.Ptr param -> LLVM.Ptr shape -> IO Shape.Size)

-- | Compile a function that determines only the shape of an array.
--
-- The generated function (named @"shape"@) stores the shape into the
-- result cell and returns 'Shape.size' of it; the element code of the
-- array is not used.  The closure yields the shape together with that
-- size.
renderShape ::
  (Shape.C sh, Marshal.MV sh, Storable.C a, MultiValue.C a) =>
  Sym.Array p sh a -> IO (p -> IO (sh, Shape.Size))
renderShape (Sym.Array arr create delete) = do
  computeShape <-
    Code.compile "renderShape" $
      Code.createFunction callShaper "shape" $ \paramPtr resultPtr -> do
        param <- Memory.load paramPtr
        case arr param of
          Core.Array shapeExp _code -> do
            sh <- unExp shapeExp
            Memory.store sh resultPtr
            size <- Shape.size sh
            LLVM.ret size
  return $ \p ->
    bracket (create p) (delete . fst) $ \(_ctx, param) ->
      Marshal.alloca $ \shapePtr ->
        Marshal.with param $ \paramPtr -> do
          size <- computeShape paramPtr shapePtr
          sh <- Marshal.peek shapePtr
          return (sh, size)


-- | Compile a symbolic array into a function producing a physical
-- array.  Delegates to 'PhysHull.render' on the array's hull.
render ::
  (Shape.C sh, Marshal.MV sh, Storable.C a) =>
  Sym.Array p sh a -> IO (p -> IO (Phys.Array sh a))
render arr = PhysHull.render (Sym.arrayHull arr)


-- | Wrapper around 'PhysHull.mapAccumLSimple': combines the hulls of
-- the initial-accumulator array and the input array with the step
-- function @f@.
mapAccumLSimple ::
  (Shape.C sh, Marshal.MV sh,
   Shape.C n, Marshal.MV n,
   MultiValue.C acc,
   Storable.C a, MultiValue.C a,
   Storable.C b, MultiValue.C b) =>
  (Exp acc -> Exp a -> Exp (acc,b)) ->
  Sym.Array p sh acc -> Sym.Array p (sh, n) a ->
  IO (p -> IO (Phys.Array (sh,n) b))
mapAccumLSimple f arrInit arrMap =
  PhysHull.mapAccumLSimple
    (liftA2 (PhysHull.MapAccumLSimple f)
      (Sym.arrayHull arrInit)
      (Sym.arrayHull arrMap))


-- | Wrapper around 'PhysHull.foldOuterL': combines the hulls of the
-- initial array and the input array with the folding function @f@.
foldOuterL ::
  (Shape.C sh, Marshal.MV sh,
   Shape.C n, Marshal.MV n,
   Storable.C a, MultiValue.C a) =>
  (Exp a -> Exp b -> Exp a) ->
  Sym.Array p sh a -> Sym.Array p (n,sh) b ->
  IO (p -> IO (Phys.Array sh a))
foldOuterL f arrInit arrMap =
  PhysHull.foldOuterL
    (liftA2 (PhysHull.FoldOuterL f)
      (Sym.arrayHull arrInit)
      (Sym.arrayHull arrMap))


-- | Wrapper around 'PhysHull.scatter': the second array carries
-- (target index, value) pairs, the first one is the base array, and
-- @accum@ is the combining function handed to 'PhysHull.Scatter'.
scatter ::
  (Shape.C sh0, Shape.Index sh0 ~ ix0,
   Shape.C sh1, Shape.Index sh1 ~ ix1, Marshal.MV sh1,
   Storable.C a, MultiValue.C a) =>
  (Exp a -> Exp a -> Exp a) ->
  Sym.Array p sh1 a ->
  Sym.Array p sh0 (ix1, a) ->
  IO (p -> IO (Phys.Array sh1 a))
scatter accum arrBase arrMap =
  PhysHull.scatter
    (liftA2 (PhysHull.Scatter accum)
      (Sym.arrayHull arrBase)
      (Sym.arrayHull arrMap))


-- | Variant of 'scatter' whose input elements are wrapped in 'Maybe';
-- delegates to 'PhysHull.scatterMaybe'.
scatterMaybe ::
  (Shape.C sh0, Shape.Index sh0 ~ ix0,
   Shape.C sh1, Shape.Index sh1 ~ ix1, Marshal.MV sh1,
   Storable.C a, MultiValue.C a) =>
  (Exp a -> Exp a -> Exp a) ->
  Sym.Array p sh1 a ->
  Sym.Array p sh0 (Maybe (ix1, a)) ->
  IO (p -> IO (Phys.Array sh1 a))
scatterMaybe accum arrBase arrMap =
  PhysHull.scatterMaybe
    (liftA2 (PhysHull.ScatterMaybe accum)
      (Sym.arrayHull arrBase)
      (Sym.arrayHull arrMap))


-- | 'scatter' where the target index of every input element is
-- computed from its source index by @ixmap@: each element is paired
-- with @ixmap@ of its index before being scattered onto @deflt@.
permute ::
  (Shape.C sh0, Shape.Index sh0 ~ ix0,
   Shape.C sh1, Shape.Index sh1 ~ ix1, Marshal.MV sh1,
   Storable.C a, MultiValue.C a) =>
  (Exp a -> Exp a -> Exp a) ->
  Sym.Array p sh1 a ->
  (Exp ix0 -> Exp ix1) ->
  Sym.Array p sh0 a ->
  IO (p -> IO (Phys.Array sh1 a))
permute accum deflt ixmap input =
  scatter accum deflt
    (Core.mapWithIndex
      (\ix -> Expr.lift2 MultiValue.zip (ixmap ix))
      input)