module Data.Array.Knead.Simple.Physical (
Array(Array, shape, buffer),
toList,
fromList,
vectorFromList,
with,
render,
scanl1,
scatter,
permute,
) where
import qualified Data.Array.Knead.Simple.Private as Sym
import qualified Data.Array.Knead.Index.Nested.Shape as Shape
import qualified Data.Array.Knead.Expression as Expr
import qualified Data.Array.Knead.Code as Code
import Data.Array.Knead.Expression (Exp, unExp, )
import Data.Array.Knead.Code (getElementPtr, compile, )
import qualified LLVM.Extra.Multi.Value.Memory as MultiValueMemory
import qualified LLVM.Extra.Multi.Value as MultiValue
import qualified LLVM.Extra.Arithmetic as A
import qualified LLVM.Extra.Control as C
import qualified LLVM.Extra.Memory as Memory
import qualified LLVM.Extra.Maybe as Maybe
import qualified LLVM.Core as LLVM
import Foreign.Marshal.Array (pokeArray, peekArray, )
import Foreign.Marshal.Alloc (alloca, )
import Foreign.Storable (Storable, peek, )
import Foreign.ForeignPtr (ForeignPtr, withForeignPtr, mallocForeignPtrArray, )
import Foreign.Ptr (FunPtr, Ptr, )
import Control.Monad.HT (void, )
import Control.Applicative (liftA2, )
import Data.Word (Word32, )
import Prelude hiding (scanl1, )
-- | An array whose elements live in a block of memory
-- referenced by a 'ForeignPtr', together with its shape.
data Array sh a =
   Array {
      -- | The shape of the array.
      -- 'toList' and 'fromList' derive the element count from it
      -- via @Shape.size@.
      shape :: sh,
      -- | Pointer to the memory block that holds the elements.
      buffer :: ForeignPtr a
   }
-- | Read all elements of a physical array into a list.
--
-- The number of elements read is @Shape.size@ of the array's shape;
-- they are read consecutively starting at the beginning of the buffer.
toList ::
   (Shape.C sh, Storable a) =>
   Array sh a -> IO [a]
toList (Array extent storage) =
   withForeignPtr storage $ \ptr ->
      peekArray (Shape.size extent) ptr
-- | Create a physical array of the given shape from a list.
--
-- Exactly @Shape.size sh@ elements are written to a freshly
-- allocated buffer. Surplus list elements are ignored.
-- If the list is shorter than the shape requires,
-- the missing positions are filled with an 'error' value,
-- so writing them raises that error.
fromList ::
   (Shape.C sh, Storable a) =>
   sh -> [a] -> IO (Array sh a)
fromList sh xs = do
   fptr <- mallocForeignPtrArray count
   withForeignPtr fptr $ \ptr -> pokeArray ptr padded
   return (Array sh fptr)
   where
      count = Shape.size sh
      -- Pad with error values so that a too short list is detected
      -- as soon as a missing element is about to be stored.
      padded = take count (xs ++ repeat tooShort)
      tooShort =
         error "Array.Knead.Physical.fromList: list too short for shape"
-- | Create a physical array from a list,
-- using the list length (converted with 'fromIntegral') as shape.
--
-- The buffer holds exactly as many elements as the list has.
vectorFromList ::
   (Shape.C sh, Num sh, Storable a) =>
   [a] -> IO (Array sh a)
vectorFromList xs = do
   fptr <- mallocForeignPtrArray count
   withForeignPtr fptr $ \ptr -> pokeArray ptr xs
   return $ Array (fromIntegral count) fptr
   where
      count = length xs
-- | Run an action on a symbolic view of a physical array.
--
-- The symbolic array carries the shape of the physical array
-- and an element accessor that computes the element address
-- with 'getElementPtr' and loads the value from the buffer.
-- The action runs inside 'withForeignPtr' on the array's buffer.
with ::
   (Shape.C sh, MultiValueMemory.C a) =>
   (Sym.Array sh a -> IO b) ->
   Array sh a -> IO b
with f (Array sh fptr) =
   withForeignPtr fptr $ \ptr ->
      f (Sym.Array
            (Shape.value sh)
            (\ix -> do
               elemPtr <-
                  getElementPtr
                     (Shape.value sh)
                     (LLVM.valueOf (MultiValueMemory.castStructPtr ptr))
                     ix
               Memory.load elemPtr))
-- | Type of a dynamic foreign import:
-- it turns a function pointer of type @f@ into a callable function.
type Importer f = FunPtr f -> f

-- | Call a compiled function that receives a pointer
-- and returns a 'Word32'.
-- 'materialize' uses it for the generated function
-- that stores the shape and returns the array size.
foreign import ccall safe "dynamic" callShaper ::
   Importer (Ptr sh -> IO Word32)

-- | Call a compiled function that receives two pointers
-- and returns nothing.
-- 'materialize' uses it for the generated function
-- that gets the shape pointer and the element buffer pointer.
foreign import ccall safe "dynamic" callRenderer ::
   Importer (Ptr sh -> Ptr am -> IO ())
-- | Common driver for all functions that produce a physical array.
--
-- Two functions are compiled under the given name:
--
-- * @shape@ evaluates the shape expression, stores the shape
--   at the passed pointer and returns the size computed by
--   @Shape.sizeCode@.
--
-- * @fill@ runs the supplied code generator on the shape pointer
--   and the pointer to the element buffer.
--
-- The shape function is run first, then a buffer with the returned
-- number of elements is allocated and passed to the fill function.
-- Finally the shape is read back from the temporary shape storage.
materialize ::
   (Shape.C sh, Storable sh, MultiValueMemory.C sh,
    Storable a, MultiValueMemory.C a, MultiValueMemory.Struct a ~ am) =>
   String ->
   Exp sh ->
   (LLVM.Value (Ptr (MultiValueMemory.Struct sh)) ->
    LLVM.Value (Ptr am) -> LLVM.CodeGenFunction () ()) ->
   IO (Array sh a)
materialize name esh code =
   alloca $ \shapePtr -> do
      (computeShape, fillBuffer) <-
         compile name $
         liftA2 (,)
            (Code.createFunction callShaper "shape" $ \dst -> do
               shVal <- unExp esh
               MultiValueMemory.store shVal dst
               size <- Shape.sizeCode shVal
               LLVM.ret size)
            (Code.createFunction callRenderer "fill" $
               \paramPtr arrayPtr -> do
                  code paramPtr arrayPtr
                  LLVM.ret ())
      let structShapePtr = MultiValueMemory.castStructPtr shapePtr
      count <- computeShape structShapePtr
      fptr <- mallocForeignPtrArray (fromIntegral count)
      withForeignPtr fptr $ \elemPtr ->
         fillBuffer structShapePtr (MultiValueMemory.castStructPtr elemPtr)
      -- The shape function has written the shape to shapePtr.
      finalShape <- peek shapePtr
      return (Array finalShape fptr)
-- | Turn a symbolic array into a physical one.
--
-- The generated code loops over the shape with @Shape.loop@,
-- evaluates the element code at every index,
-- stores the result at the current pointer
-- and advances the pointer by one element.
render ::
   (Shape.C sh, Storable sh, MultiValueMemory.C sh,
    Storable a, MultiValueMemory.C a) =>
   Sym.Array sh a -> IO (Array sh a)
render (Sym.Array esh code) =
   materialize "render" esh $ \sptr ptr -> do
      sh <- Shape.load esh sptr
      let writeElement ix dst = do
             x <- code ix
             Memory.store x dst
             A.advanceArrayElementPtr dst
      void (Shape.loop writeElement sh ptr)
-- | Running accumulation along the last ('Word32') dimension.
--
-- For every index of the leading shape the generated code
-- runs a loop over the last dimension.
-- The first element is stored unchanged,
-- every further element is combined with the previous result
-- as @f previous element@ and the combined value is stored.
scanl1 ::
   (Shape.C sh, Storable sh, MultiValueMemory.C sh,
    Storable a, MultiValueMemory.C a) =>
   (Exp a -> Exp a -> Exp a) ->
   Sym.Array (sh, Word32) a -> IO (Array (sh, Word32) a)
scanl1 f (Sym.Array esh code) =
   materialize "scanl1" esh $ \sptr ptr -> do
      (outerShape, MultiValue.Cons rowLength) <-
         fmap MultiValue.unzip (Shape.load esh sptr)
      let scanRow ix rowStart =
             -- Only the advanced pointer is passed on to the next row.
             fmap (\((rowEnd, _), _) -> rowEnd) $
             C.fixedLengthLoop rowLength
                ((rowStart, A.zero), Maybe.nothing) $
                \((dst, k), previous) -> do
                   x <- code (MultiValue.zip ix (MultiValue.Cons k))
                   -- No previous value means this is the first element.
                   acc <-
                      Maybe.run previous
                         (return x)
                         (\prev -> Expr.unliftM2 f prev x)
                   Memory.store acc dst
                   nextDst <- A.advanceArrayElementPtr dst
                   nextK <- A.inc k
                   return ((nextDst, nextK), Maybe.just acc)
      void (Shape.loop scanRow outerShape ptr)
-- | Scatter the elements of a source array into a destination array.
--
-- The generated code works in two passes:
--
-- 1. Every destination element is initialized
--    with the corresponding element of the default array.
--
-- 2. For every index of the source array the pair
--    of destination index and value is computed,
--    and the destination element is replaced by
--    @accum old value@.
--
-- The destination array has the shape of the default array.
scatter ::
   (Shape.C sh0, Shape.Index sh0 ~ ix0,
    Shape.C sh1, Shape.Index sh1 ~ ix1,
    Storable sh1, MultiValueMemory.C sh1,
    Storable a, MultiValueMemory.C a) =>
   (Exp a -> Exp a -> Exp a) ->
   Sym.Array sh1 a ->
   Sym.Array sh0 (ix1, a) -> IO (Array sh1 a)
scatter accum (Sym.Array esh defltCode) (Sym.Array eish code) =
   materialize "scatter" esh $ \sptr ptr -> do
      sh <- Shape.load esh sptr
      let initialize ix dst = do
             deflt <- defltCode ix
             Memory.store deflt dst
             A.advanceArrayElementPtr dst
      void (Shape.loop initialize sh ptr)

      ish <- unExp eish
      let accumulate ix () = do
             (jx, new) <- fmap MultiValue.unzip (code ix)
             slot <- getElementPtr sh ptr jx
             old <- Memory.load slot
             combined <- Expr.unliftM2 (flip accum) new old
             Memory.store combined slot
      void (Shape.loop accumulate ish ())
-- | Scatter with destination indices computed from source indices.
--
-- Every element of the input array is paired with the destination
-- index obtained by applying the index mapping to its own index.
-- The resulting array of pairs is passed to 'scatter'
-- together with the accumulation function and the default array.
permute ::
   (Shape.C sh0, Shape.Index sh0 ~ ix0,
    Shape.C sh1, Shape.Index sh1 ~ ix1,
    Storable sh1, MultiValueMemory.C sh1,
    Storable a, MultiValueMemory.C a) =>
   (Exp a -> Exp a -> Exp a) ->
   Sym.Array sh1 a ->
   (Exp ix0 -> Exp ix1) ->
   Sym.Array sh0 a ->
   IO (Array sh1 a)
permute accum deflt ixmap input =
   scatter accum deflt $
      Sym.mapWithIndex
         (\ix x -> Expr.lift2 MultiValue.zip (ixmap ix) x)
         input