module Data.Array.Knead.Simple.PhysicalPrivate where
import qualified Data.Array.Knead.Simple.Private as Sym
import qualified Data.Array.Knead.Shape as Shape
import qualified Data.Array.Knead.Expression as Expr
import Data.Array.Knead.Expression (Exp, unExp)
import Data.Array.Knead.Code (getElementPtr)
import qualified LLVM.Extra.Multi.Value.Memory as MultiValueMemory
import qualified LLVM.Extra.Multi.Value as MultiValue
import qualified LLVM.Extra.Arithmetic as A
import qualified LLVM.Extra.Control as C
import qualified LLVM.Extra.Memory as Memory
import qualified LLVM.Core as LLVM
import Foreign.Ptr (Ptr, )
import qualified Control.Applicative.HT as App
import Control.Monad.HT (void, )
import Control.Applicative ((<$>), )
import Data.Tuple.HT (mapSnd, )
-- | Write one element per index of @sh@ into consecutive array slots,
-- beginning at @ptr@.
--
-- Each element is produced by @code@ applied to the current index.
-- Indices are visited in the order chosen by 'Shape.loop'.
-- The result is the pointer just past the last element written.
writeArray ::
   (Shape.C sh, Memory.C a) =>
   MultiValue.T sh ->
   (MultiValue.T (Shape.Index sh) -> LLVM.CodeGenFunction r a) ->
   LLVM.Value (Ptr (Memory.Struct a)) ->
   LLVM.CodeGenFunction r (LLVM.Value (Ptr (Memory.Struct a)))
writeArray sh code ptr = Shape.loop storeAndAdvance sh ptr
   where
      -- Compute the element, store it at the current slot,
      -- then hand the next slot's pointer to the following iteration.
      storeAndAdvance ix dst = do
         value <- code ix
         Memory.store value dst
         A.advanceArrayElementPtr dst
-- | Accumulating map over all indices of @n@, in 'Shape.loop' order.
--
-- For every index @k@, the input element is obtained from @code k@ and
-- combined with the running accumulator by @f@. @f@ yields the next
-- accumulator and an output element. The output element is stored at the
-- current output slot, starting at @yPtr@ and advancing one element per
-- index.
--
-- Returns the output pointer after the last stored element together with
-- the final accumulator.
mapAccumLLoop ::
   (MultiValue.C acc,
    MultiValueMemory.C b, MultiValueMemory.Struct b ~ bm,
    Shape.C sh, Shape.Index sh ~ ix) =>
   (MultiValue.T ix -> LLVM.CodeGenFunction r (MultiValue.T a)) ->
   (Exp acc -> Exp a -> Exp (acc, b)) ->
   MultiValue.T sh ->
   LLVM.Value (Ptr bm) -> MultiValue.T acc ->
   LLVM.CodeGenFunction r (LLVM.Value (Ptr bm), MultiValue.T acc)
mapAccumLLoop code f n yPtr accInit =
   Shape.loop visit n (yPtr, accInit)
   where
      visit k (dst, acc) = do
         input <- code k
         (accNext, out) <- MultiValue.unzip <$> Expr.unliftM2 f acc input
         Memory.store out dst
         -- Pair the advanced pointer with the updated accumulator.
         flip (,) accNext <$> A.advanceArrayElementPtr dst
-- | Row-wise accumulating map that discards the final accumulators.
--
-- The @(sh, n)@ shape is loaded from @sptr@. For every outer index @ix@ of
-- @sh@:
--
-- * the accumulator is initialised from the first array at @ix@;
-- * 'mapAccumLLoop' runs over the inner dimension @n@;
-- * the output elements are written consecutively into the buffer
--   starting at @ptr@.
--
-- The final accumulator of each row is dropped.
mapAccumLSimple ::
   (Shape.C sh, MultiValueMemory.C sh,
    Shape.C n, MultiValueMemory.C n,
    MultiValue.C acc,
    MultiValueMemory.C x,
    MultiValueMemory.C y) =>
   (Exp acc -> Exp x -> Exp (acc,y)) ->
   Sym.Array sh acc -> Sym.Array (sh, n) x ->
   LLVM.Value (Ptr (MultiValueMemory.Struct (sh,n))) ->
   LLVM.Value (Ptr (MultiValueMemory.Struct y)) ->
   LLVM.CodeGenFunction r ()
mapAccumLSimple f (Sym.Array _ initCode) (Sym.Array esh code) sptr ptr = do
   shapes <- Shape.load esh sptr
   let (sh, n) = MultiValue.unzip shapes
       -- Process one row; the returned pointer is where the next row starts.
       row ix rowStart = do
          accInit <- initCode ix
          (rowEnd, _accExit) <-
             mapAccumLLoop (code . MultiValue.zip ix) f n rowStart accInit
          return rowEnd
   void $ Shape.loop row sh ptr
-- | One-dimensional accumulating map with a single final result.
--
-- The shape @n@ is loaded from @sptr@. The accumulator starts at
-- @initExp@, and 'mapAccumLLoop' runs over @n@, writing the mapped elements
-- from @yPtr@ onwards. Afterwards @final@ is applied to the last
-- accumulator and the result is stored at @accPtr@.
mapAccumLSequence ::
   (Shape.C n, MultiValueMemory.C n,
    MultiValue.C acc, MultiValueMemory.C final,
    MultiValueMemory.C x,
    MultiValueMemory.C y) =>
   (Exp acc -> Exp x -> Exp (acc,y)) ->
   (Exp acc -> Exp final) ->
   Exp acc -> Sym.Array n x ->
   LLVM.Value (Ptr (MultiValueMemory.Struct final)) ->
   LLVM.Value (Ptr (MultiValueMemory.Struct n)) ->
   LLVM.Value (Ptr (MultiValueMemory.Struct y)) ->
   LLVM.CodeGenFunction r ()
mapAccumLSequence f final initExp (Sym.Array esh code) accPtr sptr yPtr = do
   n <- Shape.load esh sptr
   start <- unExp initExp
   (_yEnd, accExit) <- mapAccumLLoop code f n yPtr start
   result <- Expr.unliftM1 final accExit
   Memory.store result accPtr
-- | Row-wise accumulating map that also keeps one final value per row.
--
-- The @(sh, n)@ shape is loaded from the first component of the second
-- pointer pair. For every outer index @ix@ of @sh@:
--
-- * the accumulator is initialised from the first array at @ix@;
-- * 'mapAccumLLoop' runs over @n@, writing mapped elements consecutively
--   from the @y@ buffer pointer;
-- * @final@ of the row's last accumulator is stored in the next slot of the
--   @final@ buffer.
--
-- The shape pointer in the first pair is not read by this function.
mapAccumL ::
   (Shape.C sh, MultiValueMemory.C sh,
    Shape.C n, MultiValueMemory.C n,
    MultiValue.C acc, MultiValueMemory.C final,
    MultiValueMemory.C x,
    MultiValueMemory.C y) =>
   (Exp acc -> Exp x -> Exp (acc,y)) ->
   (Exp acc -> Exp final) ->
   Sym.Array sh acc -> Sym.Array (sh, n) x ->
   (LLVM.Value (Ptr (MultiValueMemory.Struct sh)),
    LLVM.Value (Ptr (MultiValueMemory.Struct final))) ->
   (LLVM.Value (Ptr (MultiValueMemory.Struct (sh,n))),
    LLVM.Value (Ptr (MultiValueMemory.Struct y))) ->
   LLVM.CodeGenFunction r ()
mapAccumL f final (Sym.Array _ initCode) (Sym.Array esh code)
      (_, accPtr) (sptr, yPtr) = do
   shapes <- Shape.load esh sptr
   let (sh, n) = MultiValue.unzip shapes
       -- State: (next slot for a final value, next slot for a mapped element).
       row ix (finalDst, rowStart) = do
          accInit <- initCode ix
          (rowEnd, accExit) <-
             mapAccumLLoop (code . MultiValue.zip ix) f n rowStart accInit
          finalValue <- Expr.unliftM1 final accExit
          Memory.store finalValue finalDst
          finalDstNext <- A.advanceArrayElementPtr finalDst
          return (finalDstNext, rowEnd)
   void $ Shape.loop row sh (accPtr, yPtr)
-- | Left fold along the outer dimension @n@ of an @(n, sh)@ array, in place.
--
-- The result array at @ptr@ is first filled from the initial array. Then,
-- for every outer index @k@, each result element is replaced by @f@ applied
-- to its current value and the input element at @(k, ix)@.
--
-- The @sh@ part of the shape is loaded from @sptr@. The @n@ part is taken
-- from the first component of the input array's shape expression.
foldOuterL ::
   (Shape.C sh, MultiValueMemory.C sh,
    Shape.C n, MultiValueMemory.C n,
    MultiValueMemory.C a) =>
   (Exp a -> Exp b -> Exp a) ->
   Sym.Array sh a -> Sym.Array (n,sh) b ->
   LLVM.Value (Ptr (MultiValueMemory.Struct sh)) ->
   LLVM.Value (Ptr (MultiValueMemory.Struct a)) ->
   LLVM.CodeGenFunction r ()
foldOuterL f (Sym.Array _ initCode) (Sym.Array esh code) sptr ptr = do
   sh <- Shape.load (Expr.snd esh) sptr
   n <- MultiValue.fst <$> unExp esh
   void $ writeArray sh initCode ptr
   let -- Update one result element in place and move to the next slot.
       combine k ix dst = do
          b <- code (MultiValue.zip k ix)
          old <- Memory.load dst
          new <- Expr.unliftM2 f old b
          Memory.store new dst
          A.advanceArrayElementPtr dst
       -- One full pass over the result array for outer index k,
       -- always restarting at ptr.
       pass k () = void $ Shape.loop (combine k) sh ptr
   void $ Shape.loop pass n ()
-- | Map and filter a one-dimensional array into a densely packed buffer.
--
-- For every index of the loaded shape, the element is tested with @p@.
-- Only elements for which @p@ holds are mapped with @f@ and written to the
-- next free slot starting at @ptr@; @f@ is not evaluated for rejected
-- elements.
--
-- The result is the shape built from the number of elements kept.
mapFilter ::
   (Shape.Sequence n, MultiValueMemory.C n,
    MultiValueMemory.C b) =>
   (Exp a -> Exp b) ->
   (Exp a -> Exp Bool) ->
   Sym.Array n a ->
   LLVM.Value (Ptr (MultiValueMemory.Struct n)) ->
   LLVM.Value (Ptr (MultiValueMemory.Struct b)) ->
   LLVM.CodeGenFunction r (MultiValue.T n)
mapFilter f p (Sym.Array esh code) sptr ptr = do
   n <- Shape.load esh sptr
   let -- State: (next free output slot, number of elements written so far).
       visit ix state@(dst, count) = do
          a <- code ix
          MultiValue.Cons keep <- Expr.unliftM1 p a
          C.ifThen keep state $ do
             b <- Expr.unliftM1 f a
             Memory.store b dst
             App.lift2 (,)
                (A.advanceArrayElementPtr dst)
                (MultiValue.inc count)
   (_, kept) <- Shape.loop visit n (ptr, MultiValue.zero)
   Shape.sequenceShapeFromIndex kept
-- | Keep the outer slices of an @(n, sh)@ array whose selector is true.
--
-- For every outer index @k@, the selector array decides whether slice @k@
-- is kept. A kept slice is copied in full, via 'writeArray', to the next
-- free position starting at @ptr@.
--
-- The result shape pairs the number of kept slices with the unchanged @sh@.
filterOuter ::
   (Shape.Sequence n, MultiValueMemory.C n,
    Shape.C sh, MultiValueMemory.C sh,
    MultiValueMemory.C a) =>
   Sym.Array n Bool ->
   Sym.Array (n,sh) a ->
   LLVM.Value (Ptr (MultiValueMemory.Struct (n,sh))) ->
   LLVM.Value (Ptr (MultiValueMemory.Struct a)) ->
   LLVM.CodeGenFunction r (MultiValue.T (n,sh))
filterOuter (Sym.Array _ selectCode) (Sym.Array esh code) sptr ptr = do
   (n, sh) <- MultiValue.unzip <$> Shape.load esh sptr
   let -- State: (next free output position, number of slices kept so far).
       visit k state@(dst, count) = do
          MultiValue.Cons selected <- selectCode k
          C.ifThen selected state $ do
             dstNext <- writeArray sh (code . MultiValue.zip k) dst
             countNext <- MultiValue.inc count
             return (dstNext, countNext)
   (_, kept) <- Shape.loop visit n (ptr, MultiValue.zero)
   keptShape <- Shape.sequenceShapeFromIndex kept
   return (MultiValue.zip keptShape sh)
-- | Scatter optional @(target index, value)@ pairs into an initialised array.
--
-- The target array at @ptr@ is first filled from the initial array. Then,
-- for every index of the map array whose entry is present, the value @a@
-- is combined into the target element @old@ at index @jx@ as
-- @accum old a@. Entries that are absent are skipped.
--
-- NOTE(review): no bounds check on @jx@ is visible here; presumably the
-- caller must guarantee that every target index lies within @sh@.
scatterMaybe ::
   (Shape.C sh0, Shape.Index sh0 ~ ix0,
    Shape.C sh1, Shape.Index sh1 ~ ix1,
    MultiValueMemory.C sh1,
    MultiValueMemory.C a) =>
   (Exp a -> Exp a -> Exp a) ->
   Sym.Array sh1 a -> Sym.Array sh0 (Maybe (ix1, a)) ->
   LLVM.Value (Ptr (MultiValueMemory.Struct sh1)) ->
   LLVM.Value (Ptr (MultiValueMemory.Struct a)) ->
   LLVM.CodeGenFunction r ()
scatterMaybe accum (Sym.Array esh codeInit) (Sym.Array eish codeMap)
      sptr ptr = do
   sh <- Shape.load esh sptr
   void $ writeArray sh codeInit ptr
   ish <- unExp eish
   let visit ix () = do
          (MultiValue.Cons present, (jx, a)) <-
             mapSnd MultiValue.unzip . MultiValue.splitMaybe <$> codeMap ix
          C.ifThen present () $ do
             target <- getElementPtr sh ptr jx
             old <- Memory.load target
             new <- Expr.unliftM2 (flip accum) a old
             Memory.store new target
   Shape.loop visit ish ()
-- | Scatter @(target index, value)@ pairs into an initialised array.
--
-- The target array at @ptr@ is first filled from the initial array. Then,
-- for every index of the map array, the value @a@ is combined into the
-- target element @old@ at index @jx@ as @accum old a@.
--
-- NOTE(review): no bounds check on @jx@ is visible here; presumably the
-- caller must guarantee that every target index lies within @sh@.
scatter ::
   (Shape.C sh0, Shape.Index sh0 ~ ix0,
    Shape.C sh1, Shape.Index sh1 ~ ix1,
    MultiValueMemory.C sh1,
    MultiValueMemory.C a) =>
   (Exp a -> Exp a -> Exp a) ->
   Sym.Array sh1 a ->
   Sym.Array sh0 (Shape.Index sh1, a) ->
   LLVM.Value (Ptr (MultiValueMemory.Struct sh1)) ->
   LLVM.Value (Ptr (MultiValueMemory.Struct a)) ->
   LLVM.CodeGenFunction r ()
scatter accum (Sym.Array esh codeInit) (Sym.Array eish codeMap) sptr ptr = do
   sh <- Shape.load esh sptr
   void $ writeArray sh codeInit ptr
   ish <- unExp eish
   let visit ix () = do
          entry <- codeMap ix
          let (jx, a) = MultiValue.unzip entry
          target <- getElementPtr sh ptr jx
          old <- Memory.load target
          new <- Expr.unliftM2 (flip accum) a old
          Memory.store new target
   Shape.loop visit ish ()
-- | Extend an array with a new inner dimension @n@.
--
-- The output element at @(ix, k)@ is @select k a@, where @a@ is the source
-- element at @ix@. Each source element is fetched once per @ix@, and the
-- corresponding row of @n@ outputs is written consecutively via
-- 'writeArray', starting at @ptr@.
addDimension ::
   (Shape.C n, MultiValueMemory.C n, Shape.Index n ~ k,
    Shape.C sh, MultiValueMemory.C sh,
    MultiValueMemory.C b) =>
   Exp n ->
   (Exp k -> Exp a -> Exp b) ->
   Sym.Array sh a ->
   LLVM.Value (Ptr (MultiValueMemory.Struct (sh,n))) ->
   LLVM.Value (Ptr (MultiValueMemory.Struct b)) ->
   LLVM.CodeGenFunction r ()
addDimension en select (Sym.Array esh code) sptr ptr = do
   shapes <- Shape.load (Expr.zip esh en) sptr
   let (sh, n) = MultiValue.unzip shapes
       -- Write the n outputs derived from the source element at ix;
       -- the returned pointer is where the next row begins.
       row ix rowStart = do
          source <- code ix
          writeArray n (\k -> Expr.unliftM2 select k source) rowStart
   void $ Shape.loop row sh ptr