module Data.Array.Knead.Simple.PhysicalPrivate where
import qualified Data.Array.Knead.Simple.Private as Sym
import qualified Data.Array.Knead.Index.Nested.Shape as Shape
import qualified Data.Array.Knead.Expression as Expr
import Data.Array.Knead.Expression (Exp, unExp)
import Data.Array.Knead.Code (getElementPtr)
import qualified LLVM.Extra.Multi.Value.Memory as MultiValueMemory
import qualified LLVM.Extra.Multi.Value as MultiValue
import qualified LLVM.Extra.Arithmetic as A
import qualified LLVM.Extra.Control as C
import qualified LLVM.Extra.Memory as Memory
import qualified LLVM.Core as LLVM
import Foreign.Storable (Storable, )
import Foreign.Ptr (Ptr, )
import Control.Monad.HT (void, )
import Control.Applicative ((<$>), )
import Data.Tuple.HT (mapSnd, )
-- | Emit code for an accumulating map along the second shape component.
--
-- The shape pair @(sh, n)@ is obtained with 'Shape.load' from @sptr@.
-- For every index of @sh@ the accumulator starts from the element that
-- the first array yields at that index.  An inner 'Shape.loop' over @n@
-- then applies @f@ to the accumulator and the element of the second
-- array, stores the second component of the result at the current
-- output position and moves that position on with
-- 'A.advanceArrayElementPtr'.  The output position is threaded through
-- both loops, so it is never reset between outer iterations.  The
-- accumulator left after the last inner step is discarded.
mapAccumL ::
   (Shape.C sh, Storable sh, MultiValueMemory.C sh,
    Shape.C n, Storable n, MultiValueMemory.C n,
    MultiValue.C acc,
    Storable x, MultiValueMemory.C x,
    Storable y, MultiValueMemory.C y) =>
   (Exp acc -> Exp x -> Exp (acc,y)) ->
   Sym.Array sh acc -> Sym.Array (sh, n) x ->
   LLVM.Value (Ptr (MultiValueMemory.Struct (sh,n))) ->
   LLVM.Value (Ptr (MultiValueMemory.Struct y)) ->
   LLVM.CodeGenFunction r ()
mapAccumL f (Sym.Array _ initCode) (Sym.Array esh code) sptr ptr = do
   extent <- Shape.load esh sptr
   let (sh, n) = MultiValue.unzip extent

       -- One step of the inner loop: consume one input element,
       -- write one output element, hand on pointer and accumulator.
       inner ix k (dst, acc) = do
          x <- code (MultiValue.zip ix k)
          stepped <- Expr.unliftM2 f acc x
          let (accNext, y) = MultiValue.unzip stepped
          Memory.store y dst
          dstNext <- A.advanceArrayElementPtr dst
          return (dstNext, accNext)

       -- One step of the outer loop: run a complete inner loop and
       -- keep only the advanced output pointer.
       outer ix dst = do
          acc <- initCode ix
          fst <$> Shape.loop (inner ix) n (dst, acc)

   void (Shape.loop outer sh ptr)
-- | Emit code for a left fold along the outer shape component.
--
-- The extent @(n, sh)@ is taken from the shape expression of the second
-- array via 'unExp'; the shape pointer argument is not used.
--
-- The code runs in two phases over the buffer starting at @ptr@:
--
-- 1. One 'Shape.loop' over @sh@ stores the elements of the first array.
--
-- 2. A 'Shape.loop' over @n@ runs, for every outer index, another
--    'Shape.loop' over @sh@ starting at @ptr@ again.  Each step loads
--    the stored value, combines it with the element of the second array
--    using @f@ and stores the result back in place.
foldOuterL ::
   (Shape.C sh, Storable sh, MultiValueMemory.C sh,
    Shape.C n, Storable n, MultiValueMemory.C n,
    MultiValueMemory.C a) =>
   (Exp a -> Exp b -> Exp a) ->
   Sym.Array sh a -> Sym.Array (n,sh) b ->
   LLVM.Value (Ptr (MultiValueMemory.Struct sh)) ->
   LLVM.Value (Ptr (MultiValueMemory.Struct a)) ->
   LLVM.CodeGenFunction r ()
foldOuterL f (Sym.Array _ initCode) (Sym.Array esh code) _sptr ptr = do
   extent <- unExp esh
   let (n, sh) = MultiValue.unzip extent

   -- Phase 1: seed the buffer with the initial values.
   let writeInit ix dst = do
          flip Memory.store dst =<< initCode ix
          A.advanceArrayElementPtr dst
   void (Shape.loop writeInit sh ptr)

   -- Phase 2: fold one outer slice after the other into the buffer.
   let accumulate k ix dst = do
          b <- code (MultiValue.zip k ix)
          old <- Memory.load dst
          new <- Expr.unliftM2 f old b
          Memory.store new dst
          A.advanceArrayElementPtr dst
       -- Every slice restarts at the beginning of the buffer.
       slice k () = void (Shape.loop (accumulate k) sh ptr)
   void (Shape.loop slice n ())
-- | Emit code that scatters optional @(index, value)@ pairs into a
-- buffer.
--
-- The target shape is obtained with 'Shape.load' from @sptr@ and the
-- buffer at @ptr@ is first filled with the elements of the first array.
-- Then a 'Shape.loop' over the shape of the second array inspects every
-- element: the result of 'MultiValue.splitMaybe' is a flag and a
-- payload, and only when the flag is set (checked with 'C.ifThen') the
-- slot addressed by the payload index is updated.  The update loads the
-- slot, combines it with the payload value through
-- @Expr.unliftM2 (flip accum)@ (payload value passed first, loaded value
-- second) and stores the result back.
scatterMaybe ::
   (Shape.C sh0, Shape.Index sh0 ~ ix0,
    Shape.C sh1, Shape.Index sh1 ~ ix1,
    Storable sh1, MultiValueMemory.C sh1,
    Storable a, MultiValueMemory.C a) =>
   (Exp a -> Exp a -> Exp a) ->
   Sym.Array sh1 a -> Sym.Array sh0 (Maybe (ix1, a)) ->
   LLVM.Value (Ptr (MultiValueMemory.Struct sh1)) ->
   LLVM.Value (Ptr (MultiValueMemory.Struct a)) ->
   LLVM.CodeGenFunction r ()
scatterMaybe accum (Sym.Array esh codeInit) (Sym.Array eish codeMap)
      sptr ptr = do
   sh <- Shape.load esh sptr

   -- Fill the whole target buffer with the initial elements.
   let initialise ix dst = do
          a <- codeInit ix
          Memory.store a dst
          A.advanceArrayElementPtr dst
   void (Shape.loop initialise sh ptr)

   ish <- unExp eish
   -- Merge one source element into its target slot, if it is present.
   let deposit ix () = do
          (MultiValue.Cons present, payload) <-
             MultiValue.splitMaybe <$> codeMap ix
          let (jx, a) = MultiValue.unzip payload
          C.ifThen present () $ do
             slot <- getElementPtr sh ptr jx
             old <- Memory.load slot
             new <- Expr.unliftM2 (flip accum) a old
             Memory.store new slot
   Shape.loop deposit ish ()
-- | Emit code that scatters @(index, value)@ pairs into a buffer.
--
-- The target shape is obtained with 'Shape.load' from @sptr@ and the
-- buffer at @ptr@ is first filled with the elements of the first array.
-- Then a 'Shape.loop' over the shape of the second array visits every
-- pair, addresses the target slot with 'getElementPtr', loads it,
-- combines it with the pair's value through
-- @Expr.unliftM2 (flip accum)@ (pair value passed first, loaded value
-- second) and stores the result back.
scatter ::
   (Shape.C sh0, Shape.Index sh0 ~ ix0,
    Shape.C sh1, Shape.Index sh1 ~ ix1,
    Storable sh1, MultiValueMemory.C sh1,
    Storable a, MultiValueMemory.C a) =>
   (Exp a -> Exp a -> Exp a) ->
   Sym.Array sh1 a ->
   Sym.Array sh0 (Shape.Index sh1, a) ->
   LLVM.Value (Ptr (MultiValueMemory.Struct sh1)) ->
   LLVM.Value (Ptr (MultiValueMemory.Struct a)) ->
   LLVM.CodeGenFunction r ()
scatter accum (Sym.Array esh codeInit) (Sym.Array eish codeMap) sptr ptr = do
   sh <- Shape.load esh sptr

   -- Fill the whole target buffer with the initial elements.
   let initialise ix dst = do
          a <- codeInit ix
          Memory.store a dst
          A.advanceArrayElementPtr dst
   void (Shape.loop initialise sh ptr)

   ish <- unExp eish
   -- Merge one source element into the slot it points at.
   let deposit ix () = do
          (jx, a) <- MultiValue.unzip <$> codeMap ix
          slot <- getElementPtr sh ptr jx
          old <- Memory.load slot
          new <- Expr.unliftM2 (flip accum) a old
          Memory.store new slot
   Shape.loop deposit ish ()