module Futhark.AD.Rev.Scan (diffScan) where

import Control.Monad
import Futhark.AD.Rev.Monad
import Futhark.Analysis.PrimExp.Convert
import Futhark.Builder
import Futhark.IR.SOACS
import Futhark.Tools
import Futhark.Transform.Rename
import Futhark.Util (pairs, unpairs)

-- | Selects which group of parameters of the binary scan operator is
-- differentiated: 'WrtFirst' picks the leading parameters (the first
-- operand), 'WrtSecond' the remaining ones (the second operand).
data FirstOrSecond = WrtFirst | WrtSecond

-- | Build a lambda computing the derivative of the scan operator
-- @x op y@ with respect to one of its operands: @d(x op y)/dx@ for
-- 'WrtFirst', @d(x op y)/dy@ for 'WrtSecond'.
--
-- The operator lambda is renamed first so that the derived lambda does
-- not share bound names with the original.  The parameters are split by
-- the number of results of the operator: the first @len@ parameters form
-- the first operand and the rest form the second.  The derivative is
-- obtained by running 'vjpLambda' with one unit adjoint per result.
mkScanAdjointLam :: VjpOps -> Lambda SOACS -> FirstOrSecond -> ADM (Lambda SOACS)
mkScanAdjointLam ops lam0 which = do
  -- Number of values making up one operand of the operator.
  let len = length $ lambdaReturnType lam0
  lam <- renameLambda lam0
  -- Parameters we differentiate with respect to.
  let p2diff =
        case which of
          WrtFirst -> take len $ lambdaParams lam
          WrtSecond -> drop len $ lambdaParams lam
  p_adjs <- mapM unitAdjOfType (lambdaReturnType lam)
  vjpLambda ops p_adjs (map paramName p2diff) lam

-- | Build the map lambda that is fused with the linear-function scan.
--
-- With @i@ drawn from @iota w@, the generated lambda is roughly:
--
-- > \i -> if i == 0
-- >       then (0, 1)
-- >       else let j  = w - (i + 1)
-- >                j1 = w - i          -- i.e. j + 1
-- >            in (ys_adj[j], df2dx ys[j] xs[j1])
--
-- where
--
--   * @xs@ is the input of the scan,
--   * @ys@ is the result of the scan,
--   * @ys_adj@ is the known adjoint of @ys@,
--   * @df2dx@ is the lambda from 'mkScanAdjointLam' with 'WrtFirst'.
--
-- For a scan over several arrays the results are flattened with
-- 'unpairs', so each element contributes two consecutive results: the
-- adjoint (or zero) followed by the derivative (or one).
mkScanFusedMapLam :: VjpOps -> SubExp -> Lambda SOACS -> [VName] -> [VName] -> [VName] -> ADM (Lambda SOACS)
mkScanFusedMapLam ops w scn_lam xs ys ys_adj = do
  lam <- mkScanAdjointLam ops scn_lam WrtFirst
  ys_ts <- mapM lookupType ys
  par_i <- newParam "i" $ Prim int64
  let i = paramName par_i
  mkLambda [par_i] $
    fmap varsRes . letTupExp "x"
      =<< eIf
        (toExp $ le64 i .==. 0)
        -- First iteration: emit the pair (zero, one) for every element.
        ( buildBody_ $ do
            zs <- mapM (letSubExp "ct_zero" . zeroExp . rowType) ys_ts
            os <- mapM (letSubExp "ct_one" . oneExp . rowType) ys_ts
            pure $ subExpsRes $ unpairs $ zip zs os
        )
        -- Remaining iterations walk the arrays back to front.
        ( buildBody_ $ do
            j <- letSubExp "j" =<< toExp (pe64 w - (le64 i + 1))
            j1 <- letSubExp "j1" =<< toExp (pe64 w - le64 i)
            -- Index the outermost dimension of 'arr' (of type 't') at 'idx'.
            let index idx arr t = BasicOp $ Index arr $ fullSlice t [DimFix idx]
            y_s <- forM (zip ys_adj ys_ts) $ \(y_, t) ->
              letSubExp (baseString y_ ++ "_j") $ index j y_ t
            -- The types of 'ys' are also used to slice 'xs'.
            lam_rs <-
              eLambda lam . map pure $
                zipWith (index j) ys ys_ts ++ zipWith (index j1) xs ys_ts
            pure $ unpairs $ zip (subExpsRes y_s) lam_rs
        )

-- | Build the scan operator for composing linear functions,
--
-- > \(a1, b1) (a2, b2) -> (a2 + b2 * a1, b1 * b2)
--
-- together with its neutral element @(zero, one)@, for elements of type
-- @t@.  The lambda takes the four parameters in the order
-- @a1, b1, a2, b2@, all of type @t@; the arithmetic is built inside
-- 'tabNest' over @arrayRank t@ dimensions and uses the element type of
-- @t@.
mkScanLinFunO :: Type -> ADM (Scan SOACS)
mkScanLinFunO t = do
  let pt = elemType t
  zero <- letSubExp "zeros" $ zeroExp t
  one <- letSubExp "ones" $ oneExp t
  -- Bind the four names directly rather than destructuring a list with
  -- a partial pattern; the allocation order is unchanged.
  a1 <- newVName "a1"
  b1 <- newVName "b1"
  a2 <- newVName "a2"
  b2 <- newVName "b2"
  let pet = primExpFromSubExp pt . Var
  lam <- mkLambda (map (\v -> Param mempty v t) [a1, b1, a2, b2]) . fmap varsRes $
    tabNest (arrayRank t) [a1, b1, a2, b2] $ \_ [a1', b1', a2', b2'] -> do
      x <- letExp "x" <=< toExp $ pet a2' ~+~ pet b2' ~*~ pet a1'
      y <- letExp "y" <=< toExp $ pet b1' ~*~ pet b2'
      pure [x, y]
  pure $ Scan lam [zero, one]

-- | Build the map that follows the linear-function-composition scan.
--
-- For one pair @(ds, cs)@ of length-@w@ results of that scan, combine
-- them as
--
-- > let rs = map2 (\d_i c_i -> d_i + c_i * y_adj[w-1]) ds cs |> reverse
--
-- The reversal is not a separate operation: the map runs over @iota w@
-- and reads @ds@ and @cs@ at index @w - (i + 1)@.
--
-- @arr_tp@ is the array type used to build the slices into @y_adj@,
-- @ds@ and @cs@; 'diffScan' passes the type of the corresponding scan
-- input.
--
-- NOTE(review): @y_adj[w-1]@ is read unconditionally before the map, so
-- this presumably relies on @w > 0@ (or on the index being guarded
-- elsewhere) -- confirm for empty inputs.
mkScan2ndMaps :: SubExp -> (Type, VName, (VName, VName)) -> ADM VName
mkScan2ndMaps w (arr_tp, y_adj, (ds, cs)) = do
  nm1 <- letSubExp "nm1" =<< toExp (pe64 w - 1)
  y_adj_last <-
    letExp (baseString y_adj ++ "_last") $
      BasicOp $
        Index y_adj $
          fullSlice arr_tp [DimFix nm1]

  par_i <- newParam "i" $ Prim int64
  lam <- mkLambda [par_i] $ do
    let i = paramName par_i
    -- Reversed index into the scan results.
    j <- letSubExp "j" =<< toExp (pe64 w - (le64 i + 1))
    dj <- letExp (baseString ds ++ "_dj") $ BasicOp $ Index ds $ fullSlice arr_tp [DimFix j]
    cj <- letExp (baseString cs ++ "_cj") $ BasicOp $ Index cs $ fullSlice arr_tp [DimFix j]

    let pet = primExpFromSubExp (elemType arr_tp) . Var
    -- dj + cj * y_adj_last, built inside 'tabNest' over the dimensions
    -- of a row.
    fmap varsRes . tabNest (arrayRank (rowType arr_tp)) [y_adj_last, dj, cj] $ \_ [y_adj_last', dj', cj'] ->
      letTupExp "res" <=< toExp $ pet dj' ~+~ pet cj' ~*~ pet y_adj_last'

  iota <- letExp "iota" $ BasicOp $ Iota w (intConst Int64 0) (intConst Int64 1) Int64
  letExp "after_scan" $ Op (Screma w [iota] (ScremaForm [] [] lam))

-- | Perform the final map, which is fusable with the maps obtained from
-- 'mkScan2ndMaps':
--
-- > let xs_contribs =
-- >   map3 (\i a r -> if i == 0 then r else (df2dy (ys[i-1]) a) * r)
-- >        (iota w) xs rs
--
-- where @df2dy@ is the lambda from 'mkScanAdjointLam' with 'WrtSecond',
-- @xs@ is the input of the scan, @ys@ its result and @rs@ the arrays
-- produced by 'mkScan2ndMaps'.  The multiplication is built inside
-- 'tabNest' over @arrayRank@ of each element type.  Returns the names of
-- the resulting contribution arrays, one per scan input.
mkScanFinalMap :: VjpOps -> SubExp -> Lambda SOACS -> [VName] -> [VName] -> [VName] -> ADM [VName]
mkScanFinalMap ops w scan_lam xs ys rs = do
  let eltps = lambdaReturnType scan_lam
  lam <- mkScanAdjointLam ops scan_lam WrtSecond
  par_i <- newParam "i" $ Prim int64
  let i = paramName par_i
  par_x <- mapM (\(x, t) -> newParam (baseString x ++ "_par_x") t) $ zip xs eltps
  par_r <- mapM (\(r, t) -> newParam (baseString r ++ "_par_r") t) $ zip rs eltps

  map_lam <-
    mkLambda (par_i : par_x ++ par_r) $
      fmap varsRes . letTupExp "scan_contribs"
        =<< eIf
          (toExp $ le64 i .==. 0)
          -- The first element has no predecessor in 'ys': pass 'r' through.
          (resultBodyM $ map (Var . paramName) par_r)
          ( buildBody_ $ do
              im1 <- letSubExp "im1" =<< toExp (le64 i - 1)
              -- ys[i-1] for every scan result array.
              ys_im1 <- forM ys $ \y -> do
                y_t <- lookupType y
                letSubExp (baseString y ++ "_last") $ BasicOp $ Index y $ fullSlice y_t [DimFix im1]

              -- Bind each result of the adjoint lambda to a variable so
              -- it can be passed to 'tabNest' below.
              lam_res <-
                mapM (letExp "const" . BasicOp . SubExp . resSubExp)
                  =<< eLambda lam (map eSubExp $ ys_im1 ++ map (Var . paramName) par_x)

              fmap (varsRes . mconcat) . forM (zip3 lam_res (map paramName par_r) eltps) $
                \(lam_r, r, eltp) -> do
                  let pet = primExpFromSubExp (elemType eltp) . Var

                  tabNest (arrayRank eltp) [lam_r, r] $ \_ [lam_r', r'] ->
                    letTupExp "res" <=< toExp $ pet lam_r' ~*~ pet r'
          )

  iota <- letExp "iota" $ BasicOp $ Iota w (intConst Int64 0) (intConst Int64 1) Int64
  letTupExp "scan_contribs" $ Op (Screma w (iota : xs ++ rs) (ScremaForm [] [] map_lam))

-- | Reverse-mode differentiation of a scan.
--
-- Arguments: the VJP operations, the names @ys@ of the scan results, the
-- width @w@, the names @as@ of the scan inputs, and the scan itself.
--
-- The generated code proceeds in three steps:
--
--   1. a scan with the linear-function-composition operator
--      ('mkScanLinFunO') over the lambda from 'mkScanFusedMapLam',
--      driven by @iota w@;
--   2. one map per pair of scan results ('mkScan2ndMaps');
--   3. a final map ('mkScanFinalMap') producing the contributions to the
--      inputs.
--
-- Each contribution is then added to the adjoint of the corresponding
-- input with 'updateAdj'.
diffScan :: VjpOps -> [VName] -> SubExp -> [VName] -> Scan SOACS -> ADM ()
diffScan ops ys w as scan = do
  ys_adj <- mapM lookupAdjVal ys
  as_ts <- mapM lookupType as
  map1_lam <- mkScanFusedMapLam ops w (scanLambda scan) as ys ys_adj
  scans_lin_fun_o <- mapM mkScanLinFunO $ lambdaReturnType $ scanLambda scan
  iota <-
    letExp "iota" $ BasicOp $ Iota w (intConst Int64 0) (intConst Int64 1) Int64
  r_scan <-
    letTupExp "adj_ctrb_scan" $
      Op (Screma w [iota] (ScremaForm scans_lin_fun_o [] map1_lam))
  -- 'r_scan' holds two results per scan element; regroup them into pairs.
  red_nms <- mapM (mkScan2ndMaps w) (zip3 as_ts ys_adj (pairs r_scan))
  as_contribs <- mkScanFinalMap ops w (scanLambda scan) as ys red_nms
  zipWithM_ updateAdj as as_contribs