{-# LANGUAGE TypeFamilies #-}

module Futhark.AD.Rev.Scatter (vjpScatter) where

import Control.Monad
import Futhark.AD.Rev.Monad
import Futhark.Analysis.PrimExp.Convert
import Futhark.Builder
import Futhark.IR.SOACS
import Futhark.Tools
import Futhark.Util (chunk)

-- | Conjunction of bounds checks for every @(q, i)@ pair: each index
-- @i@ must satisfy @-1 < i && i < q@ (i.e. @0 <= i < q@).  An empty
-- list yields the constant 'True'.
--
-- The expression is built right-nested: @b1 .&&. (b2 .&&. ...)@.
withinBounds :: [(SubExp, VName)] -> TPrimExp Bool VName
withinBounds [] = TPrimExp $ ValueExp (BoolValue True)
withinBounds [(q, i)] =
  (le64 i .<. pe64 q) .&&. (pe64 (intConst Int64 (-1)) .<. le64 i)
withinBounds (qi : qis) = withinBounds [qi] .&&. withinBounds qis

-- | Generates a potential tower-of-maps lambda body for an indexing
-- operation.
--
-- Assuming parameters:
--
--   * @arr@ — the array that is indexed;
--   * @[(w_1, i_1), (w_2, i_2), ..., (w_k, i_k)]@ — outer lambda formal
--     parameters paired with their bounds;
--   * @[n_1,n_2,...]ptp@ — the type of the index expression
--     @arr[i_1,i_2,...,i_k]@.
--
-- Generates something like:
--
-- > (\ i_1 i_2 ->
-- >    map (\j_1 -> ... if (i_1 >= 0 && i_1 < w_1) &&
-- >                        (i_2 >= 0 && i_2 < w_2) && ...
-- >                     then arr[i_1, i_2, ... j_1, ...]
-- >                     else 0
-- >        ) (iota n_1)
-- > )
--
-- The idea is that you do not want to put something that is an array
-- under the @if@, because it would not flatten well.  Hence the bounds
-- check is pushed down to the innermost, scalar level of the nest.
genIdxLamBody :: VName -> [(SubExp, Param Type)] -> Type -> ADM (Body SOACS)
genIdxLamBody as wpis = genRecLamBody as wpis []
  where
    -- @nest_pis@ accumulates the index parameters of the enclosing
    -- @map (\j -> ...) (iota n)@ levels, outermost first.
    genRecLamBody ::
      VName ->
      [(SubExp, Param Type)] ->
      [Param Type] ->
      Type ->
      ADM (Body SOACS)
    -- A zero-rank array is treated as its scalar element type.
    genRecLamBody arr w_pis nest_pis (Array t (Shape []) _) =
      genRecLamBody arr w_pis nest_pis (Prim t)
    -- Array case: peel off the outermost dimension @s@ by mapping over
    -- @iota s@, recursing on the remaining dimensions @ss@.
    genRecLamBody arr w_pis nest_pis (Array t (Shape (s : ss)) _) = do
      new_ip <- newParam "i" (Prim int64)
      let t' = Prim t `arrayOfShape` Shape ss
      inner_lam <-
        mkLambda [new_ip] $
          bodyBind =<< genRecLamBody arr w_pis (nest_pis ++ [new_ip]) t'
      let (_, orig_pis) = unzip w_pis
      buildBody_ . localScope (scopeOfLParams (orig_pis ++ nest_pis)) $ do
        iota_v <-
          letExp "iota" . BasicOp $
            Iota s (intConst Int64 0) (intConst Int64 1) Int64
        r <-
          letSubExp (baseString arr ++ "_elem") . Op $
            Screma s [iota_v] (mapSOAC inner_lam)
        pure [subExpRes r]
    -- Scalar case: index fully with the outer and nest indices, guarded
    -- by the bounds check on the outer indices only; out-of-bounds
    -- yields the blank value of the primitive type.
    genRecLamBody arr w_pis nest_pis (Prim ptp) = do
      let (ws, orig_pis) = unzip w_pis
          inds = map paramName (orig_pis ++ nest_pis)
      localScope (scopeOfLParams (orig_pis ++ nest_pis)) $
        eBody
          [ eIf
              (toExp $ withinBounds $ zip ws $ map paramName orig_pis)
              ( do
                  r <-
                    letSubExp "r" . BasicOp . Index arr . Slice $
                      map (DimFix . Var) inds
                  resultBodyM [r]
              )
              (resultBodyM [Constant $ blankPrimValue ptp])
          ]
    genRecLamBody _ _ _ _ =
      error "In Rev.hs, helper function genRecLamBody, unreachable case reached!"

-- | Reverse-mode differentiation of a single-destination scatter whose
-- lambda is the identity.
--
-- Original:
--   let ys = scatter xs is vs
-- Assumes no duplicate indices in `is`.
-- Forward Sweep:
--   let xs_save = gather xs is
--   let ys = scatter xs is vs
-- Return Sweep:
--   let vs_ctrbs = gather is ys_adj
--   let vs_adj \overline{+}= vs_ctrbs -- by map or generalized reduction
--   let xs_adj = scatter ys_adj is \overline{0}
--   let xs = scatter ys is xs_save
vjpScatter1 ::
  PatElem Type ->
  StmAux () ->
  (SubExp, [VName], (Shape, Int, VName)) ->
  ADM () ->
  ADM ()
vjpScatter1 pys aux (w, ass, (shp, num_vals, xs)) m = do
  let rank = length $ shapeDims shp
      -- The inputs are laid out as all index arrays (rank per value)
      -- followed by the value arrays.
      (all_inds, val_as) = splitAt (rank * num_vals) ass
      inds_as = chunk rank all_inds
  xs_t <- lookupType xs
  let val_t = stripArray (shapeRank shp) xs_t
      -- Parameter/result types of an identity scatter lambda: @rank@
      -- indices for each of the @num_vals@ values, then the values.
      -- Used for every scatter built below, so they all agree with the
      -- layout of @ass@ (previously the forward scatter used @rank@
      -- value params, which is wrong unless rank == num_vals).
      f_tps = replicate (rank * num_vals) (Prim int64) ++ replicate num_vals val_t
  -- computing xs_save
  xs_saves <- mkGather inds_as xs xs_t
  -- performing the scatter
  id_lam <- mkIdentityLambda f_tps
  addStm $ Let (Pat [pys]) aux $ Op $ Scatter w ass id_lam [(shp, num_vals, xs)]
  m
  let ys = patElemName pys
  -- XXX: Since our restoration of xs will consume ys, we have to
  -- make a copy of ys in the chance that it is actually the result
  -- of the program.  In that case the asymptotics will not be
  -- (locally) preserved, but since ys must necessarily have been
  -- constructed somewhere close, they are probably globally OK.
  ys_copy <- letExp (baseString ys <> "_copy") $ BasicOp $ Copy ys
  returnSweepCode $ do
    ys_adj <- lookupAdjVal ys
    -- computing vs_ctrbs and updating vs_adj
    vs_ctrbs <- mkGather inds_as ys_adj xs_t
    zipWithM_ updateAdj val_as vs_ctrbs -- use Slice?
    -- creating xs_adj: each zeros array matches a value input, i.e. has
    -- type [w]val_t (for rank 1 this equals xs_t with outer size w).
    zeros <-
      replicateM (length val_as) . letExp "zeros" $
        zeroExp (val_t `arrayOfShape` Shape [w])
    f <- mkIdentityLambda f_tps
    xs_adj <-
      letExp (baseString xs ++ "_adj") . Op $
        Scatter w (all_inds ++ zeros) f [(shp, num_vals, ys_adj)]
    insAdj xs xs_adj -- reusing the ys_adj for xs_adj!
    -- Restore xs by writing the saved elements back into ys (this
    -- consumes ys, hence ys_copy above).
    f' <- mkIdentityLambda f_tps
    xs_rc <-
      auxing aux . letExp (baseString xs <> "_rc") . Op $
        Scatter w (all_inds ++ xs_saves) f' [(shp, num_vals, ys)]
    addSubstitution xs xs_rc
    addSubstitution ys ys_copy
  where
    -- Creates a potential map-nest that indexes in full the array,
    --   and applies the condition of indices within bounds at the
    --   deepest level in the nest so that everything can be parallel.
    -- Produces one result array per index group in @idx_groups@.
    mkGather :: [[VName]] -> VName -> Type -> ADM [VName]
    mkGather idx_groups arr arr_t = do
      ips <- forM idx_groups $ \idxs ->
        forM idxs $ \idx -> newParam (baseString idx ++ "_elem") (Prim int64)

      gather_lam <- mkLambda (concat ips) . fmap mconcat . forM ips $ \idxs -> do
        let q = length idxs
            ws = take q $ arrayDims arr_t
            eltp = stripArray q arr_t
        bodyBind =<< genIdxLamBody arr (zip ws idxs) eltp
      let soac = Screma w (concat idx_groups) (mapSOAC gather_lam)
      letTupExp (baseString arr ++ "_gather") $ Op soac

-- | Reverse-mode differentiation of a scatter.
--
-- Only identity lambdas are supported.  A single-destination scatter
-- is handled by 'vjpScatter1'.  A multi-destination scatter is split
-- into one scatter statement per destination, and each of those is
-- differentiated in turn via 'vjpStm'.  Anything else is an error.
vjpScatter ::
  VjpOps ->
  Pat Type ->
  StmAux () ->
  (SubExp, [VName], Lambda SOACS, [(Shape, Int, VName)]) ->
  ADM () ->
  ADM ()
vjpScatter ops (Pat pes) aux (w, ass, lam, written_info) m
  | isIdentityLambda lam,
    [(shp, num_vals, xs)] <- written_info,
    [pys] <- pes =
      vjpScatter1 pys aux (w, ass, (shp, num_vals, xs)) m
  | isIdentityLambda lam = do
      let sind = splitInd written_info
          (inds, vals) = splitAt sind ass
      lst_stms <- chunkScatterInps (inds, vals) (zip pes written_info)
      diffScatters (stmsFromList lst_stms)
  | otherwise =
      error "vjpScatter: cannot handle"
  where
    -- Total number of index arrays: all destinations' indices come
    -- before any value arrays in @ass@.
    splitInd :: [(Shape, Int, VName)] -> Int
    splitInd [] = 0
    splitInd ((shp, num_res, _) : rest) =
      num_res * length (shapeDims shp) + splitInd rest

    -- Peel off the index and value arrays belonging to each destination
    -- and build a single-destination identity scatter for it.
    chunkScatterInps ::
      ([VName], [VName]) ->
      [(PatElem Type, (Shape, Int, VName))] ->
      ADM [Stm SOACS]
    chunkScatterInps (acc_inds, acc_vals) [] =
      case (acc_inds, acc_vals) of
        ([], []) -> pure []
        _ -> error "chunkScatterInps: cannot handle"
    chunkScatterInps
      (acc_inds, acc_vals)
      ((pe, info@(shp, num_vals, _)) : rest) = do
        let num_inds = num_vals * length (shapeDims shp)
            (curr_inds, other_inds) = splitAt num_inds acc_inds
            (curr_vals, other_vals) = splitAt num_vals acc_vals
        vtps <- mapM lookupType curr_vals
        -- Lambda parameters take elements of the input arrays, so the
        -- value parameters use the row types of the value arrays.
        f <- mkIdentityLambda (replicate num_inds (Prim int64) ++ map rowType vtps)
        let stm =
              Let (Pat [pe]) aux . Op $
                Scatter w (curr_inds ++ curr_vals) f [info]
        stms_rest <- chunkScatterInps (other_inds, other_vals) rest
        pure $ stm : stms_rest

    -- Differentiate the split scatters one after another, with the
    -- continuation @m@ innermost.
    diffScatters :: Stms SOACS -> ADM ()
    diffScatters all_stms
      | Just (stm, stms) <- stmsHead all_stms =
          vjpStm ops stm $ diffScatters stms
      | otherwise = m