{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE Rank2Types #-}
{-# LANGUAGE ForeignFunctionInterface #-}
{- |
Arrays whose elements are stored in a 'ForeignPtr' buffer,
together with operations that compile a symbolic array description
and write the resulting elements into a freshly allocated buffer.
-}
module Data.Array.Knead.Simple.Physical (
   Array(Array, shape, buffer), -- data constructor intended for PhysicalParameterized
   toList,
   fromList,
   vectorFromList,
   with,
   render,
   scanl1,
   mapAccumL,
   scatter,
   scatterMaybe,
   permute,
   ) where

import qualified Data.Array.Knead.Simple.PhysicalPrivate as Priv
import qualified Data.Array.Knead.Simple.Private as Sym
import qualified Data.Array.Knead.Index.Nested.Shape as Shape
import qualified Data.Array.Knead.Expression as Expr
import qualified Data.Array.Knead.Code as Code
import Data.Array.Knead.Expression (Exp, unExp, )
import Data.Array.Knead.Code (getElementPtr, compile, )

import qualified LLVM.Extra.Multi.Value.Memory as MultiValueMemory
import qualified LLVM.Extra.Multi.Value as MultiValue
import qualified LLVM.Extra.Arithmetic as A
import qualified LLVM.Extra.Memory as Memory
import qualified LLVM.Extra.Maybe as Maybe

import qualified LLVM.Core as LLVM

import Foreign.Marshal.Array (pokeArray, peekArray, )
import Foreign.Marshal.Alloc (alloca, )
import Foreign.Storable (Storable, peek, )
import Foreign.ForeignPtr (ForeignPtr, withForeignPtr, mallocForeignPtrArray, )
import Foreign.Ptr (FunPtr, Ptr, )

import Control.Monad.HT (void, )
import Control.Applicative (liftA2, (<$>), )

import Data.Word (Word32, )

import Prelude hiding (scanl1, )


{- |
A shape paired with the buffer that holds the array elements.
The functions in this module read and write @Shape.size shape@
elements of the buffer.
-}
data Array sh a =
   Array {
      shape :: sh,
      buffer :: ForeignPtr a
   }


{- |
Read @Shape.size sh@ elements from the buffer into a list.
-}
toList ::
   (Shape.C sh, Storable a) =>
   Array sh a -> IO [a]
toList (Array sh fptr) =
   let count = Shape.size sh
   in  withForeignPtr fptr $ \ptr -> peekArray count ptr

{- |
Allocate a buffer of @Shape.size sh@ elements
and fill it with the leading elements of the list.
Surplus list elements are dropped.
A list that is too short is padded with 'error' values,
so the error is raised once such a padding element is evaluated.
-}
fromList ::
   (Shape.C sh, Storable a) =>
   sh -> [a] -> IO (Array sh a)
fromList sh xs = do
   let count = Shape.size sh
       padded =
          take count $
          xs ++
          repeat
             (error "Array.Knead.Physical.fromList: list too short for shape")
   fptr <- mallocForeignPtrArray count
   withForeignPtr fptr $ \ptr -> pokeArray ptr padded
   return (Array sh fptr)

{- |
Store all list elements in a new buffer.
The shape is the list length converted with 'fromIntegral'.
-}
vectorFromList ::
   (Shape.C sh, Num sh, Storable a) =>
   [a] -> IO (Array sh a)
vectorFromList xs = do
   let count = length xs
   fptr <- mallocForeignPtrArray count
   withForeignPtr fptr $ \ptr -> pokeArray ptr xs
   return (Array (fromIntegral count) fptr)

{- |
Run an action on a symbolic view of a physical array.
Element access of the symbolic array computes the element pointer
from the index and loads from it.
The symbolic array must not be used outside the enclosed action,
because the buffer is only kept accessible by 'withForeignPtr'.
-}
with ::
   (Shape.C sh, MultiValueMemory.C a) =>
   (Sym.Array sh a -> IO b) ->
   Array sh a -> IO b
with f (Array sh fptr) =
   withForeignPtr fptr $ \ptr ->
      f $
      Sym.Array
         (Shape.value sh)
         (\ix ->
            getElementPtr
               (Shape.value sh)
               (LLVM.valueOf (MultiValueMemory.castStructPtr ptr))
               ix
            >>= Memory.load)


{- |
Type of a @foreign import \"dynamic\"@ stub
that turns a function pointer into a callable Haskell function.
-}
type Importer f = FunPtr f -> f

-- | Call a compiled function that receives a shape pointer and returns a 'Word32'.
foreign import ccall safe "dynamic" callShaper ::
   Importer (Ptr sh -> IO Word32)

-- | Call a compiled function that receives a shape pointer and an array pointer.
foreign import ccall safe "dynamic" callRenderer ::
   Importer (Ptr sh -> Ptr am -> IO ())


{- |
Compile and run the code that produces a physical array.

Two functions are compiled under the given name:
@shape@ evaluates the shape expression,
stores the shape at the passed pointer and returns 'Shape.sizeCode' of it;
@fill@ runs the supplied code on the shape pointer and the array pointer.
The returned size determines the number of allocated elements.
-}
materialize ::
   (Shape.C sh, Storable sh, MultiValueMemory.C sh,
    Storable a, MultiValueMemory.C a, MultiValueMemory.Struct a ~ am) =>
   String ->
   Exp sh ->
   (LLVM.Value (Ptr (MultiValueMemory.Struct sh)) ->
    LLVM.Value (Ptr am) -> LLVM.CodeGenFunction () ()) ->
   IO (Array sh a)
materialize name esh fill =
   alloca $ \shapePtr -> do
      (runShape, runFill) <-
         compile name $
         liftA2 (,)
            (Code.createFunction callShaper "shape" $ \dst -> do
               sh <- unExp esh
               MultiValueMemory.store sh dst
               LLVM.ret =<< Shape.sizeCode sh)
            (Code.createFunction callRenderer "fill"
               (\paramPtr arrayPtr -> fill paramPtr arrayPtr >> LLVM.ret ()))

      let structShapePtr = MultiValueMemory.castStructPtr shapePtr
      numElems <- runShape structShapePtr
      fptr <- mallocForeignPtrArray (fromIntegral numElems)
      withForeignPtr fptr $ \arrPtr ->
         runFill structShapePtr (MultiValueMemory.castStructPtr arrPtr)
      -- the shape memory is only valid within 'alloca', thus read it here
      sh <- peek shapePtr
      return (Array sh fptr)

{- |
Evaluate every element of a symbolic array
and store the elements consecutively in a new physical array.
-}
render ::
   (Shape.C sh, Storable sh, MultiValueMemory.C sh,
    Storable a, MultiValueMemory.C a) =>
   Sym.Array sh a -> IO (Array sh a)
render (Sym.Array esh code) =
   materialize "render" esh $ \shapePtr arrayPtr -> do
      let writeElem ix cursor = do
             x <- code ix
             Memory.store x cursor
             A.advanceArrayElementPtr cursor
      sh <- Shape.load esh shapePtr
      void $ Shape.loop writeElem sh arrayPtr

{- |
'scanl1Gen' with the inner dimension fixed to 'Word32'.
-}
scanl1 ::
   (Shape.C sh, Storable sh, MultiValueMemory.C sh,
    Storable a, MultiValueMemory.C a) =>
   (Exp a -> Exp a -> Exp a) ->
   Sym.Array (sh, Word32) a -> IO (Array (sh, Word32) a)
scanl1 f arr = scanl1Gen f arr

{- |
For every index of the outer shape, loop over the inner shape
and store the running combination of the elements seen so far.
The first element of each inner loop is stored unchanged,
every later element is combined with the previous result using @f@
(previous result as first argument).
-}
scanl1Gen ::
   (Shape.C sh, Storable sh, MultiValueMemory.C sh,
    Shape.C n, Storable n, MultiValueMemory.C n,
    Storable a, MultiValueMemory.C a) =>
   (Exp a -> Exp a -> Exp a) ->
   Sym.Array (sh, n) a -> IO (Array (sh, n) a)
scanl1Gen f (Sym.Array esh code) =
   materialize "scanl1" esh $ \shapePtr arrayPtr -> do
      (sh, n) <- fmap MultiValue.unzip $ Shape.load esh shapePtr
      let scanInner ix rowStart =
             fmap fst $
             Shape.loop
                (\k (dst, mprev) -> do
                   x <- code $ MultiValue.zip ix k
                   -- no previous result means that x starts the scan
                   acc <-
                      Maybe.run mprev
                         (return x)
                         (\prev -> Expr.unliftM2 f prev x)
                   Memory.store acc dst
                   next <- A.advanceArrayElementPtr dst
                   return (next, Maybe.just acc))
                n
                (rowStart, Maybe.nothing)
      void $ Shape.loop scanInner sh arrayPtr

{- |
Materialize the result of 'Priv.mapAccumL'.
The result has the shape of the data array.
-}
mapAccumL ::
   (Shape.C sh, Storable sh, MultiValueMemory.C sh,
    Shape.C n, Storable n, MultiValueMemory.C n,
    MultiValue.C acc,
    Storable x, MultiValueMemory.C x,
    Storable y, MultiValueMemory.C y) =>
   (Exp acc -> Exp x -> Exp (acc,y)) ->
   Sym.Array sh acc -> Sym.Array (sh, n) x -> IO (Array (sh, n) y)
mapAccumL f arrInit arrData =
   materialize "mapAccumL"
      (Sym.shape arrData)
      (Priv.mapAccumL f arrInit arrData)

{- |
Materialize the result of 'Priv.scatterMaybe'.
The result has the shape of the initial array.
-}
scatterMaybe ::
   (Shape.C sh0, Shape.Index sh0 ~ ix0,
    Shape.C sh1, Shape.Index sh1 ~ ix1,
    Storable sh1, MultiValueMemory.C sh1,
    Storable a, MultiValueMemory.C a) =>
   (Exp a -> Exp a -> Exp a) ->
   Sym.Array sh1 a ->
   Sym.Array sh0 (Maybe (ix1, a)) -> IO (Array sh1 a)
scatterMaybe accum arrInit arrMap =
   materialize "scatterMaybe"
      (Sym.shape arrInit)
      (Priv.scatterMaybe accum arrInit arrMap)

{- |
Materialize the result of 'Priv.scatter'.
The result has the shape of the initial array.
-}
scatter ::
   (Shape.C sh0, Shape.Index sh0 ~ ix0,
    Shape.C sh1, Shape.Index sh1 ~ ix1,
    Storable sh1, MultiValueMemory.C sh1,
    Storable a, MultiValueMemory.C a) =>
   (Exp a -> Exp a -> Exp a) ->
   Sym.Array sh1 a ->
   Sym.Array sh0 (ix1, a) -> IO (Array sh1 a)
scatter accum arrInit arrMap =
   materialize "scatter"
      (Sym.shape arrInit)
      (Priv.scatter accum arrInit arrMap)

{- |
'scatter' where the target index of every input element
is computed from its source index by @ixmap@.
-}
permute ::
   (Shape.C sh0, Shape.Index sh0 ~ ix0,
    Shape.C sh1, Shape.Index sh1 ~ ix1,
    Storable sh1, MultiValueMemory.C sh1,
    Storable a, MultiValueMemory.C a) =>
   (Exp a -> Exp a -> Exp a) ->
   Sym.Array sh1 a ->
   (Exp ix0 -> Exp ix1) ->
   Sym.Array sh0 a ->
   IO (Array sh1 a)
permute accum deflt ixmap input =
   scatter accum deflt $
   Sym.mapWithIndex
      (\ix -> Expr.lift2 MultiValue.zip (ixmap ix))
      input