{-# LANGUAGE TypeFamilies #-}
module Futhark.AD.Rev.Scatter (vjpScatter) where
import Control.Monad
import Futhark.AD.Rev.Monad
import Futhark.Analysis.PrimExp.Convert
import Futhark.Builder
import Futhark.IR.SOACS
import Futhark.Tools
import Futhark.Util (chunk)
-- | Build the conjunction of the bounds checks @0 <= i < q@ for every
-- @(q, i)@ pair in the list.
--
-- An empty list yields the literal @True@.  Otherwise the individual
-- checks are combined right-associatively, so @[a, b, c]@ becomes
-- @a && (b && c)@.
withinBounds :: [(SubExp, VName)] -> TPrimExp Bool VName
withinBounds bounds =
  case map inBound bounds of
    [] -> TPrimExp $ ValueExp $ BoolValue True
    checks -> foldr1 (.&&.) checks
  where
    -- The lower bound is expressed as @-1 < i@.
    inBound :: (SubExp, VName) -> TPrimExp Bool VName
    inBound (q, i) =
      (le64 i .<. pe64 q) .&&. (pe64 (intConst Int64 (-1)) .<. le64 i)
-- | Generate the body of a lambda that reads @arr@ at the indices bound by
-- the given parameters, guarded by a bounds check.
--
-- Each @(w, p)@ pair supplies an outer index parameter @p@ and the bound
-- @w@ it is checked against; the last argument is the type of the indexed
-- element.  When that element type is an array, the read is wrapped in a
-- nest of @map@s over @iota@ (one per remaining dimension), so the @if@
-- only ever guards a scalar read.  Out-of-bounds positions produce
-- 'blankPrimValue' of the element's primitive type.
genIdxLamBody :: VName -> [(SubExp, Param Type)] -> Type -> ADM (Body SOACS)
genIdxLamBody arr w_pis = go []
  where
    (ws, orig_pis) = unzip w_pis

    -- @go nest_pis elem_t@: @nest_pis@ are the index parameters introduced
    -- by the enclosing maps built so far (outermost first).
    go :: [Param Type] -> Type -> ADM (Body SOACS)
    go nest_pis elem_t =
      case elem_t of
        Array t (Shape []) _ ->
          go nest_pis (Prim t)
        Array t (Shape (s : ss)) _ -> do
          new_ip <- newParam "i" (Prim int64)
          inner_lam <-
            mkLambda [new_ip] $
              bodyBind =<< go (nest_pis ++ [new_ip]) (Prim t `arrayOfShape` Shape ss)
          buildBody_ . localScope (scopeOfLParams (orig_pis ++ nest_pis)) $ do
            iota_v <-
              letExp "iota" . BasicOp $
                Iota s (intConst Int64 0) (intConst Int64 1) Int64
            elem_se <-
              letSubExp (baseString arr ++ "_elem") . Op $
                Screma s [iota_v] (mapSOAC inner_lam)
            pure [subExpRes elem_se]
        Prim ptp ->
          localScope (scopeOfLParams (orig_pis ++ nest_pis)) $
            eBody
              [ eIf
                  inBounds
                  (readElem nest_pis)
                  (resultBodyM [Constant $ blankPrimValue ptp])
              ]
        _ ->
          error "In Rev.hs, helper function genRecLamBody, unreachable case reached!"

    -- Only the outer parameters are bounds-checked; the nested map indices
    -- range over iota and are in bounds by construction.
    inBounds :: ADM (Exp SOACS)
    inBounds = toExp $ withinBounds $ zip ws $ map paramName orig_pis

    -- Read @arr@ at the outer indices followed by the nested map indices.
    readElem :: [Param Type] -> ADM (Body SOACS)
    readElem nest_pis = do
      let slice = Slice $ map (DimFix . Var . paramName) (orig_pis ++ nest_pis)
      r <- letSubExp "r" $ BasicOp $ Index arr slice
      resultBodyM [r]
-- | Reverse-mode rule for a scatter with an identity lambda, a single
-- destination array @xs@ and a single result @pys@.
--
-- Forward sweep: gather the elements of @xs@ located at the scattered
-- indices (@xs_saves@), then re-emit the scatter itself and run the
-- continuation @m@.
--
-- Return sweep: the adjoints of the written values are gathered from the
-- adjoint of the result; the adjoint of @xs@ is the result adjoint with
-- the written positions overwritten by zeros; finally @xs@ is restored by
-- scattering the saved elements back into @ys@.
vjpScatter1 ::
  PatElem Type ->
  StmAux () ->
  (SubExp, [VName], (Shape, Int, VName)) ->
  ADM () ->
  ADM ()
vjpScatter1 pys aux (w, ass, (shp, num_vals, xs)) m = do
  let rank = length $ shapeDims shp
      (all_inds, val_as) = splitAt (rank * num_vals) ass
      inds_as = chunk rank all_inds
  xs_t <- lookupType xs
  let val_t = stripArray rank xs_t
      -- Element types of the scatter inputs: @rank@ index arrays per
      -- written value, followed by the @num_vals@ value arrays.  Used for
      -- every identity lambda below (previously the forward lambda used
      -- @rank@ values, which is wrong whenever rank /= num_vals).
      f_tps = replicate (rank * num_vals) (Prim int64) ++ replicate num_vals val_t
  -- Save the elements of xs at the positions that will be written.
  xs_saves <- mkGather inds_as xs xs_t
  -- Re-emit the primal scatter.
  id_lam <- mkIdentityLambda f_tps
  addStm $ Let (Pat [pys]) aux $ Op $ Scatter w ass id_lam [(shp, num_vals, xs)]
  m
  let ys = patElemName pys
  -- The restoring scatter below writes into ys, so a copy of ys is taken
  -- here and later substituted for ys.
  ys_copy <- letExp (baseString ys <> "_copy") $ BasicOp $ Copy ys
  returnSweepCode $ do
    ys_adj <- lookupAdjVal ys
    -- Adjoint contributions of the written values.
    vs_ctrbs <- mkGather inds_as ys_adj xs_t
    zipWithM_ updateAdj val_as vs_ctrbs
    -- One @[w]val_t@ array of zeros per written value; scattering them into
    -- ys_adj zeroes the adjoint of the overwritten xs elements.  (The type
    -- was previously derived from xs_t, which is only correct for rank 1.)
    zeros <-
      replicateM (length val_as) . letExp "zeros" $
        zeroExp $ val_t `arrayOfShape` Shape [w]
    f <- mkIdentityLambda f_tps
    xs_adj <-
      letExp (baseString xs ++ "_adj") . Op $
        Scatter w (all_inds ++ zeros) f [(shp, num_vals, ys_adj)]
    insAdj xs xs_adj
    -- Restore xs by writing the saved elements back into ys.
    f' <- mkIdentityLambda f_tps
    xs_rc <-
      auxing aux . letExp (baseString xs <> "_rc") . Op $
        Scatter w (all_inds ++ xs_saves) f' [(shp, num_vals, ys)]
    addSubstitution xs xs_rc
    addSubstitution ys ys_copy
  where
    -- | Gather @arr@ (of type @arr_t@) at each group of index arrays,
    -- producing one result array per group.  Out-of-bounds indices yield
    -- blank values (see 'genIdxLamBody').
    mkGather :: [[VName]] -> VName -> Type -> ADM [VName]
    mkGather idx_groups arr arr_t = do
      ips <- forM idx_groups $ \idxs ->
        mapM (\idx -> newParam (baseString idx ++ "_elem") (Prim int64)) idxs
      gather_lam <- mkLambda (concat ips) . fmap mconcat . forM ips $ \idxs -> do
        let q = length idxs
            (ws, eltp) = (take q $ arrayDims arr_t, stripArray q arr_t)
        bodyBind =<< genIdxLamBody arr (zip ws idxs) eltp
      let soac = Screma w (concat idx_groups) (mapSOAC gather_lam)
      letTupExp (baseString arr ++ "_gather") $ Op soac
-- | Reverse-mode rule for 'Scatter'.
--
-- Only identity lambdas are supported; anything else is rejected with an
-- error.  A scatter with exactly one destination and one result is handled
-- directly by 'vjpScatter1'.  A multi-destination scatter is first split
-- into one single-destination scatter per result, and those statements are
-- then differentiated in sequence via 'vjpStm'.
vjpScatter ::
  VjpOps ->
  Pat Type ->
  StmAux () ->
  (SubExp, [VName], Lambda SOACS, [(Shape, Int, VName)]) ->
  ADM () ->
  ADM ()
vjpScatter ops (Pat pes) aux (w, ass, lam, written_info) m
  | isIdentityLambda lam,
    [(shp, num_vals, xs)] <- written_info,
    [pys] <- pes =
      vjpScatter1 pys aux (w, ass, (shp, num_vals, xs)) m
  | isIdentityLambda lam = do
      let sind = splitInd written_info
          (inds, vals) = splitAt sind ass
      lst_stms <- chunkScatterInps (inds, vals) (zip pes written_info)
      diffScatters (stmsFromList lst_stms)
  | otherwise =
      error "vjpScatter: cannot handle"
  where
    -- Total number of index arrays across all destinations: each
    -- destination takes @num_res@ values, each needing one index per
    -- dimension of its shape.
    splitInd :: [(ShapeBase a, Int, c)] -> Int
    splitInd = sum . map (\(shp, num_res, _) -> num_res * length (shapeDims shp))

    -- Split the flat index/value inputs into one single-destination
    -- scatter statement per (result, destination) pair.  All inputs must
    -- be consumed exactly.
    chunkScatterInps ::
      ([VName], [VName]) ->
      [(PatElem Type, (Shape, Int, VName))] ->
      ADM [Stm SOACS]
    chunkScatterInps ([], []) [] = pure []
    chunkScatterInps _ [] = error "chunkScatterInps: cannot handle"
    chunkScatterInps (acc_inds, acc_vals) ((pe, info@(shp, num_vals, _)) : rest) = do
      let num_inds = num_vals * length (shapeDims shp)
          (curr_inds, other_inds) = splitAt num_inds acc_inds
          (curr_vals, other_vals) = splitAt num_vals acc_vals
      -- Lambda parameters are the element types of the value arrays, so
      -- strip the outer (size @w@) dimension from each looked-up type.
      vtps <- mapM (fmap (stripArray 1) . lookupType) curr_vals
      f <- mkIdentityLambda (replicate num_inds (Prim int64) ++ vtps)
      let stm =
            Let (Pat [pe]) aux . Op $
              Scatter w (curr_inds ++ curr_vals) f [info]
      stms_rest <- chunkScatterInps (other_inds, other_vals) rest
      pure $ stm : stms_rest

    -- Differentiate the statements front to back, passing the rest as the
    -- final argument of 'vjpStm'; the original continuation @m@ is used
    -- once every statement has been consumed.
    diffScatters :: Stms SOACS -> ADM ()
    diffScatters all_stms
      | Just (stm, stms) <- stmsHead all_stms =
          vjpStm ops stm $ diffScatters stms
      | otherwise = m