{-# LANGUAGE ForeignFunctionInterface #-}
{-# LANGUAGE GADTs #-}
-- | Bridge between parameterized symbolic arrays and physical,
-- memory-backed arrays: 'feed' turns a physical array into a symbolic one,
-- while 'the', 'render', 'renderShape', 'scatter' and 'permute' compile a
-- symbolic array to LLVM code and run it to obtain concrete results.
module Data.Array.Knead.Parameterized.Physical (
   Phys.Array,
   Phys.shape,
   Phys.fromList,
   feed,
   the,
   render,
   renderShape,
   scatter,
   permute,
   ) where

import qualified Data.Array.Knead.Parameterized.Private as Sym
import qualified Data.Array.Knead.Simple.Physical as Phys
import qualified Data.Array.Knead.Simple.Private as Core
import qualified Data.Array.Knead.Parameter as Param
import qualified Data.Array.Knead.Index.Nested.Shape as Shape
import qualified Data.Array.Knead.Expression as Expr
import Data.Array.Knead.Expression (Exp, unExp, )
import Data.Array.Knead.Code (getElementPtr, compile, )

import qualified LLVM.Extra.Multi.Value.Memory as MultiValueMemory
import qualified LLVM.Extra.Multi.Value as MultiValue
import qualified LLVM.Extra.Arithmetic as A
import qualified LLVM.Extra.Memory as Memory
import qualified LLVM.Core as LLVM

import Foreign.Marshal.Utils (with, )
import Foreign.Marshal.Alloc (alloca, )
import Foreign.Storable (Storable, peek, )
import Foreign.ForeignPtr (withForeignPtr, mallocForeignPtrArray, touchForeignPtr, )
import Foreign.Ptr (FunPtr, Ptr, )

import Control.Exception (bracket, )
import Control.Monad.HT (void, (<=<), )
import Control.Monad (liftM2, )
import Data.Tuple.HT (mapFst, )
import Data.Word (Word32, )

import Prelude hiding (scanl1, )


-- | Use a physical array, selected by the parameter, as a symbolic array.
--
-- The shape travels through 'Param.withMulti'; the element buffer is read
-- with 'getElementPtr' followed by 'Memory.load'.
--
-- At run time the acquire step extracts the raw buffer pointer inside
-- 'withForeignPtr' and returns it alongside the 'ForeignPtr' itself.
-- The 'ForeignPtr' becomes the context, and 'touchForeignPtr' is the
-- release action. This keeps the buffer alive while the raw pointer is
-- in use, even though the pointer leaves the 'withForeignPtr' scope.
feed ::
   (Shape.C sh, Storable sh, MultiValueMemory.C sh, MultiValueMemory.C a) =>
   Param.T p (Phys.Array sh a) -> Sym.Array p sh a
feed arr =
   Param.withMulti (fmap Phys.shape arr) $ \getShape valueShape ->
      let -- Build the symbolic array from the (shape, buffer) parameter pair.
          toCore p =
             case mapFst valueShape $ MultiValue.unzip p of
                (sh, MultiValue.Cons ptr) ->
                   Core.Array (Expr.lift0 sh) $
                      Memory.load <=< getElementPtr sh ptr
          -- Produce the context (the ForeignPtr) and the parameter
          -- (shape plus raw buffer pointer) for a concrete @p@.
          acquire p =
             let fptr = Phys.buffer $ Param.get arr p
             in  withForeignPtr fptr $ \ptr ->
                    return
                       (fptr, (getShape p, MultiValueMemory.castStructPtr ptr))
      in  Sym.Array toCore acquire touchForeignPtr

-- | Converts a C function pointer into a callable Haskell function
-- (used with @foreign import ... "dynamic"@).
type Importer f = FunPtr f -> f

foreign import ccall safe "dynamic"
   callThe :: Importer (Ptr param -> Ptr am -> IO ())
-- | Compile an evaluator for a scalar-shaped symbolic array and return a
-- function that reads back its single element for a given parameter.
--
-- LLVM compilation happens once, when the outer 'IO' action runs. Each call
-- of the returned function then does the following:
--
-- * creates the parameter context (released via 'bracket');
-- * runs the compiled @eval@ function;
-- * 'peek's the result.
the ::
   (Shape.Scalar z, MultiValueMemory.C a, Storable a) =>
   Sym.Array p z a -> IO (p -> IO a)
the (Sym.Array arr create delete) = do
   evalFn <-
      compile "the" $
      LLVM.createNamedFunction LLVM.ExternalLinkage "eval" $
      \paramPtr resultPtr -> do
         param <- Memory.load paramPtr
         case arr param of
            Core.Array z code -> do
               -- A scalar shape has exactly one index: the zero index.
               x <- code (Shape.zeroIndex z)
               Memory.store x resultPtr
   return $ \p ->
      bracket (create p) (delete . fst) $ \(_ctx, param) ->
      with param $ \pptr ->
      alloca $ \aptr -> do
         callThe evalFn
            (MultiValueMemory.castStructPtr pptr)
            (MultiValueMemory.castStructPtr aptr)
         peek aptr

foreign import ccall safe "dynamic"
   callShaper :: Importer (Ptr param -> Ptr shape -> IO Word32)

foreign import ccall safe "dynamic"
   callRenderer :: Importer (Ptr param -> Ptr shape -> Ptr am -> IO ())

-- | Compile only the shape computation of a symbolic array.
--
-- The returned function yields the shape together with the element count
-- computed by 'Shape.sizeCode'. No element data is computed.
renderShape ::
   (Shape.C sh, Storable sh, MultiValueMemory.C sh,
    Storable a, MultiValueMemory.C a) =>
   Sym.Array p sh a -> IO (p -> IO (sh, Word32))
renderShape (Sym.Array arr create delete) = do
   shapeFn <-
      compile "renderShape" $
      LLVM.createNamedFunction LLVM.ExternalLinkage "shape" $
      \paramPtr resultPtr -> do
         param <- Memory.load paramPtr
         case arr param of
            Core.Array esh _code -> do
               sh <- unExp esh
               MultiValueMemory.store sh resultPtr
               LLVM.ret =<< Shape.sizeCode sh
   return $ \p ->
      bracket (create p) (delete . fst) $ \(_ctx, param) ->
      alloca $ \shptr ->
      with param $ \pptr -> do
         count <-
            callShaper shapeFn
               (MultiValueMemory.castStructPtr pptr)
               (MultiValueMemory.castStructPtr shptr)
         sh <- peek shptr
         return (sh, count)

-- | Compile a symbolic array into code that materializes it as a physical
-- array.
--
-- Two LLVM functions are generated:
--
-- * @shape@ stores the shape and returns the element count;
-- * @fill@ walks the shape with 'Shape.loop', storing each element and
--   advancing the output pointer.
--
-- The output buffer is allocated with 'mallocForeignPtrArray', sized by the
-- count that @shape@ returns.
render ::
   (Shape.C sh, Storable sh, MultiValueMemory.C sh,
    Storable a, MultiValueMemory.C a) =>
   Sym.Array p sh a -> IO (p -> IO (Phys.Array sh a))
render (Sym.Array arr create delete) = do
   (shapeFn, fillFn) <-
      compile "render" $
      liftM2 (,)
         (LLVM.createNamedFunction LLVM.ExternalLinkage "shape" $
          \paramPtr resultPtr -> do
            param <- Memory.load paramPtr
            case arr param of
               Core.Array esh _code -> do
                  sh <- unExp esh
                  MultiValueMemory.store sh resultPtr
                  LLVM.ret =<< Shape.sizeCode sh)
         (LLVM.createNamedFunction LLVM.ExternalLinkage "fill" $
          \paramPtr shapePtr bufferPtr -> do
            param <- Memory.load paramPtr
            case arr param of
               Core.Array esh code -> do
                  -- Write one element and return the next output slot.
                  let writeElem ix ptr = do
                         x <- code ix
                         Memory.store x ptr
                         A.advanceArrayElementPtr ptr
                  sh <- Shape.load esh shapePtr
                  void $ Shape.loop writeElem sh bufferPtr)
   return $ \p ->
      bracket (create p) (delete . fst) $ \(_ctx, param) ->
      alloca $ \shptr ->
      with param $ \pptr -> do
         let paramArg = MultiValueMemory.castStructPtr pptr
             shapeArg = MultiValueMemory.castStructPtr shptr
         count <- callShaper shapeFn paramArg shapeArg
         buffer <- mallocForeignPtrArray (fromIntegral count)
         withForeignPtr buffer $ \bufPtr ->
            callRenderer fillFn paramArg shapeArg
               (MultiValueMemory.castStructPtr bufPtr)
         sh <- peek shptr
         return (Phys.Array sh buffer)

foreign import ccall safe "dynamic"
   callScatterer ::
      Importer (Ptr paramBase -> Ptr paramMap -> Ptr shape -> Ptr am -> IO ())

-- | Scatter values into a copy of a base array.
--
-- The result has the shape of the base array and is first initialized with
-- its elements. Then, for every index of the map array, the map array
-- supplies a pair @(target index, value)@. That value is merged into the
-- buffer slot at the target index. The combining function is applied as
-- @accum old new@ (via @flip accum@), where @old@ is the value currently
-- stored in the buffer.
--
-- NOTE(review): no bounds check on the target index is visible here; confirm
-- that 'getElementPtr' or the caller guarantees it lies inside the base
-- shape.
scatter ::
   (Shape.C sh0, Shape.Index sh0 ~ ix0,
    Shape.C sh1, Shape.Index sh1 ~ ix1,
    Storable sh1, MultiValueMemory.C sh1,
    Storable a, MultiValueMemory.C a) =>
   (Exp a -> Exp a -> Exp a) ->
   Sym.Array p sh1 a ->
   Sym.Array p sh0 (ix1, a) ->
   IO (p -> IO (Phys.Array sh1 a))
scatter accum
      (Sym.Array arrBase createBase deleteBase)
      (Sym.Array arrMap createMap deleteMap) = do
   (shapeFn, fillFn) <-
      compile "scatter" $
      liftM2 (,)
         (LLVM.createNamedFunction LLVM.ExternalLinkage "shape" $
          \paramPtr resultPtr -> do
            param <- Memory.load paramPtr
            case arrBase param of
               Core.Array esh _code -> do
                  sh <- unExp esh
                  MultiValueMemory.store sh resultPtr
                  LLVM.ret =<< Shape.sizeCode sh)
         (LLVM.createNamedFunction LLVM.ExternalLinkage "fill" $
          \paramBasePtr paramMapPtr shapePtr bufferPtr -> do
            paramBase <- Memory.load paramBasePtr
            paramMap <- Memory.load paramMapPtr
            case (arrBase paramBase, arrMap paramMap) of
               (Core.Array esh codeBase, Core.Array eish codeMap) -> do
                  -- Pass 1: copy the base array into the output buffer.
                  let initElem ix ptr = do
                         x <- codeBase ix
                         Memory.store x ptr
                         A.advanceArrayElementPtr ptr
                  sh <- Shape.load esh shapePtr
                  void $ Shape.loop initElem sh bufferPtr
                  -- Pass 2: merge every mapped value into its target slot.
                  ish <- unExp eish
                  let combineElem ix () = do
                         (jx, new) <- fmap MultiValue.unzip (codeMap ix)
                         ptr <- getElementPtr sh bufferPtr jx
                         old <- Memory.load ptr
                         merged <- Expr.unliftM2 (flip accum) new old
                         Memory.store merged ptr
                  Shape.loop combineElem ish ())
   return $ \p ->
      bracket (createBase p) (deleteBase . fst) $ \(_ctxBase, paramBase) ->
      bracket (createMap p) (deleteMap . fst) $ \(_ctxMap, paramMap) ->
      alloca $ \shptr ->
      with paramBase $ \paramBasePtr ->
      with paramMap $ \paramMapPtr -> do
         let baseArg = MultiValueMemory.castStructPtr paramBasePtr
             mapArg = MultiValueMemory.castStructPtr paramMapPtr
             shapeArg = MultiValueMemory.castStructPtr shptr
         count <- callShaper shapeFn baseArg shapeArg
         buffer <- mallocForeignPtrArray (fromIntegral count)
         withForeignPtr buffer $ \bufPtr ->
            callScatterer fillFn baseArg mapArg shapeArg
               (MultiValueMemory.castStructPtr bufPtr)
         sh <- peek shptr
         return (Phys.Array sh buffer)

-- | Permute an input array into a copy of a default array.
--
-- Each input element at index @ix@ is sent to @ixmap ix@ and combined there
-- with 'scatter' semantics.
permute ::
   (Shape.C sh0, Shape.Index sh0 ~ ix0,
    Shape.C sh1, Shape.Index sh1 ~ ix1,
    Storable sh1, MultiValueMemory.C sh1,
    Storable a, MultiValueMemory.C a) =>
   (Exp a -> Exp a -> Exp a) ->
   Sym.Array p sh1 a ->
   (Exp ix0 -> Exp ix1) ->
   Sym.Array p sh0 a ->
   IO (p -> IO (Phys.Array sh1 a))
permute accum deflt ixmap input =
   scatter accum deflt (Core.mapWithIndex attachTarget input)
   where
      -- Pair each element with the target index computed from its source index.
      attachTarget ix = Expr.lift2 MultiValue.zip (ixmap ix)