module Futhark.AD.Rev.Scan (diffScan) where
import Control.Monad
import Futhark.AD.Rev.Monad
import Futhark.Analysis.PrimExp.Convert
import Futhark.Builder
import Futhark.IR.SOACS
import Futhark.Tools
import Futhark.Transform.Rename
import Futhark.Util (pairs, unpairs)
-- | Selects which argument group of a binary scan operator
-- 'mkScanAdjointLam' differentiates with respect to.
data FirstOrSecond
  = -- | The leading parameters of the operator.
    WrtFirst
  | -- | The trailing parameters of the operator.
    WrtSecond
-- | Build the vector-Jacobian product of a binary scan operator with respect
-- to one of its two argument groups.
--
-- With @n@ the number of results of @lam0@, 'WrtFirst' selects the first @n@
-- parameters and 'WrtSecond' the remaining ones.  A renamed copy of the
-- operator is passed to 'vjpLambda', with every result seeded by
-- 'unitAdjOfType'.  As used by the callers in this module, the returned lambda
-- is applied to the operator's full argument list and yields one adjoint per
-- selected parameter.
--
-- NOTE(review): presumably 'unitAdjOfType' yields a ones-valued adjoint.
-- Seeding every result at once folds all output contributions into each
-- selected parameter, which matches the per-component treatment elsewhere in
-- this module only for operators with a diagonal Jacobian.  Confirm before
-- relying on this for general tuple operators.
mkScanAdjointLam :: VjpOps -> Lambda SOACS -> FirstOrSecond -> ADM (Lambda SOACS)
mkScanAdjointLam ops lam0 which = do
  lam <- renameLambda lam0
  let n_res = length (lambdaReturnType lam0)
      (leading, trailing) = splitAt n_res (lambdaParams lam)
      wrt = case which of
        WrtFirst -> leading
        WrtSecond -> trailing
  seeds <- mapM unitAdjOfType (lambdaReturnType lam)
  vjpLambda ops seeds (map paramName wrt) lam
-- | Build the map lambda whose results seed the linear-function scan in
-- 'diffScan'.
--
-- Arguments: the VJP operations, the scan width @w@, the scan operator, the
-- scan inputs @xs@, the scan results @ys@ and the adjoints @ys_adj@ of @ys@.
--
-- The lambda takes one @i64@ index @i@ and returns, per scan result, an
-- interleaved pair @(a, b)@:
--
-- * for @i == 0@: a zero and a one of the row type of that result;
--
-- * otherwise, with @j = w - (i + 1)@ and @j1 = w - i@: @a = ys_adj[j]@, and
--   @b@ is the matching result of the operator's adjoint with respect to its
--   first argument group, evaluated at @(ys[j], xs[j1])@.
--
-- The types of @ys@ are also used to slice @ys_adj@ and @xs@, so those arrays
-- are expected to share them.
mkScanFusedMapLam :: VjpOps -> SubExp -> Lambda SOACS -> [VName] -> [VName] -> [VName] -> ADM (Lambda SOACS)
mkScanFusedMapLam ops w scn_lam xs ys ys_adj = do
  dfdy_lam <- mkScanAdjointLam ops scn_lam WrtFirst
  ys_ts <- mapM lookupType ys
  par_i <- newParam "i" $ Prim int64
  let i = paramName par_i
      fetchRow idx arr t = BasicOp $ Index arr $ fullSlice t [DimFix idx]
      -- First iteration: the identity linear function (0, 1) for each result.
      firstIter = buildBody_ $ do
        zs <- mapM (letSubExp "ct_zero" . zeroExp . rowType) ys_ts
        os <- mapM (letSubExp "ct_one" . oneExp . rowType) ys_ts
        pure . subExpsRes . unpairs $ zip zs os
      -- Later iterations walk the arrays back to front.
      laterIter = buildBody_ $ do
        j <- letSubExp "j" =<< toExp (pe64 w - (le64 i + 1))
        j1 <- letSubExp "j1" =<< toExp (pe64 w - le64 i)
        adj_js <- forM (zip ys_adj ys_ts) $ \(y_adj, t) ->
          letSubExp (baseString y_adj ++ "_j") $ fetchRow j y_adj t
        let args = zipWith (fetchRow j) ys ys_ts ++ zipWith (fetchRow j1) xs ys_ts
        partials <- eLambda dfdy_lam $ map pure args
        pure . unpairs $ zip (subExpsRes adj_js) partials
  mkLambda [par_i] $ do
    res <- letTupExp "x" =<< eIf (toExp $ le64 i .==. 0) firstIter laterIter
    pure $ varsRes res
-- | Build the scan that composes the linear-function pairs produced by
-- 'mkScanFusedMapLam'.
--
-- The operator's four parameters @(a1, b1, a2, b2)@ all have type @t@.  It
-- returns @(a2 + b2 * a1, b1 * b2)@ (assuming the usual fixities, with '~*~'
-- binding tighter than '~+~'), computed elementwise via 'tabNest' over the
-- rank of @t@.  The neutral element is @(zeros, ones)@ of type @t@.
mkScanLinFunO :: Type -> ADM (Scan SOACS)
mkScanLinFunO t = do
  zero <- letSubExp "zeros" $ zeroExp t
  one <- letSubExp "ones" $ oneExp t
  ab_vs <- mapM newVName ["a1", "b1", "a2", "b2"]
  let pet = primExpFromSubExp (elemType t) . Var
  lam <-
    mkLambda [Param mempty v t | v <- ab_vs] . fmap varsRes $
      -- 'tabNest' hands back one element name per parameter, so the
      -- four-element pattern always matches.
      tabNest (arrayRank t) ab_vs $ \_ [a1, b1, a2, b2] -> do
        x <- letExp "x" =<< toExp (pet a2 ~+~ pet b2 ~*~ pet a1)
        y <- letExp "y" =<< toExp (pet b1 ~*~ pet b2)
        pure [x, y]
  pure $ Scan lam [zero, one]
-- | Resolve the composed linear functions against the last adjoint.
--
-- Arguments: the scan width @w@, and a tuple of the array type @arr_tp@, the
-- adjoint @y_adj@ of one scan result, and the scanned pair arrays @(ds, cs)@.
-- The result is an array (bound as @after_scan@) whose element @i@ is
-- @ds[j] + cs[j] * y_adj[w - 1]@ with @j = w - (i + 1)@.  It is computed
-- elementwise over the row type of @arr_tp@.
--
-- The last adjoint @y_adj[w - 1]@ is read inside the map lambda, not before
-- it.  The lambda only runs when @w > 0@, so @w - 1@ is always in bounds
-- there.  Reading it up front would index at @-1@ for an empty scan, and
-- 'Index' carries no bounds check of its own.
mkScan2ndMaps :: SubExp -> (Type, VName, (VName, VName)) -> ADM VName
mkScan2ndMaps w (arr_tp, y_adj, (ds, cs)) = do
  par_i <- newParam "i" $ Prim int64
  let i = paramName par_i
      fetchRow arr idx = BasicOp $ Index arr $ fullSlice arr_tp [DimFix idx]
      pet = primExpFromSubExp (elemType arr_tp) . Var
  lam <- mkLambda [par_i] $ do
    -- Safe here: this body is only executed for i in [0, w), hence w >= 1.
    nm1 <- letSubExp "nm1" =<< toExp (pe64 w - 1)
    y_adj_last <- letExp (baseString y_adj ++ "_last") $ fetchRow y_adj nm1
    j <- letSubExp "j" =<< toExp (pe64 w - (le64 i + 1))
    dj <- letExp (baseString ds ++ "_dj") $ fetchRow ds j
    cj <- letExp (baseString cs ++ "_cj") $ fetchRow cs j
    fmap varsRes . tabNest (arrayRank (rowType arr_tp)) [y_adj_last, dj, cj] $
      \_ [y_adj_last', dj', cj'] ->
        letTupExp "res" <=< toExp $ pet dj' ~+~ pet cj' ~*~ pet y_adj_last'
  iota <- letExp "iota" $ BasicOp $ Iota w (intConst Int64 0) (intConst Int64 1) Int64
  letExp "after_scan" $ Op $ Screma w [iota] $ ScremaForm [] [] lam
-- | Compute the contributions of a scan to the adjoints of its inputs.
--
-- Arguments: the VJP operations, the scan width @w@, the scan operator, the
-- scan inputs @xs@, the scan results @ys@, and the resolved adjoints @rs@
-- (as produced by 'mkScan2ndMaps').
--
-- Maps over @i@ in @[0, w)@ together with @xs@ and @rs@.  For @i == 0@ the
-- result is @rs[0]@.  Otherwise it is @rs[i]@ multiplied elementwise by the
-- adjoint of @scan_lam@ with respect to its second argument group, evaluated
-- at @(ys[i - 1], xs[i])@.  Returns one array per result of @scan_lam@.
mkScanFinalMap :: VjpOps -> SubExp -> Lambda SOACS -> [VName] -> [VName] -> [VName] -> ADM [VName]
mkScanFinalMap ops w scan_lam xs ys rs = do
  let eltps = lambdaReturnType scan_lam
      paramsFor suffix vs =
        zipWithM (\v t -> newParam (baseString v ++ suffix) t) vs eltps
      pet t = primExpFromSubExp (elemType t) . Var
  dfdx_lam <- mkScanAdjointLam ops scan_lam WrtSecond
  par_i <- newParam "i" $ Prim int64
  par_x <- paramsFor "_par_x" xs
  par_r <- paramsFor "_par_r" rs
  let i = paramName par_i
      x_args = map (Var . paramName) par_x
      r_vs = map paramName par_r
      -- The first element's adjoint is just its resolved adjoint.
      firstIter = resultBodyM $ map Var r_vs
      laterIter = buildBody_ $ do
        im1 <- letSubExp "im1" =<< toExp (le64 i - 1)
        ys_im1 <- forM ys $ \y -> do
          y_t <- lookupType y
          letSubExp (baseString y ++ "_last") $
            BasicOp $ Index y $ fullSlice y_t [DimFix im1]
        partials <- eLambda dfdx_lam $ map eSubExp $ ys_im1 ++ x_args
        -- Bind each result to a name so 'tabNest' can iterate over it.
        partial_vs <- mapM (letExp "const" . BasicOp . SubExp . resSubExp) partials
        contribs <- forM (zip3 partial_vs r_vs eltps) $ \(p, r, eltp) ->
          tabNest (arrayRank eltp) [p, r] $ \_ [p', r'] ->
            letTupExp "res" =<< toExp (pet eltp p' ~*~ pet eltp r')
        pure $ varsRes $ concat contribs
  map_lam <- mkLambda (par_i : par_x ++ par_r) $ do
    res <- letTupExp "scan_contribs" =<< eIf (toExp $ le64 i .==. 0) firstIter laterIter
    pure $ varsRes res
  iota <- letExp "iota" $ BasicOp $ Iota w (intConst Int64 0) (intConst Int64 1) Int64
  letTupExp "scan_contribs" $ Op $ Screma w (iota : xs ++ rs) $ ScremaForm [] [] map_lam
-- | Reverse-mode differentiation of a scan.
--
-- @ys@ are the results of scanning @as@ (width @w@) with @scan@.  The steps
-- are:
--
-- 1. Look up the adjoints of @ys@.
-- 2. Build per-element linear-function pairs with 'mkScanFusedMapLam'.
-- 3. Scan those pairs with the operators from 'mkScanLinFunO'.
-- 4. Resolve each scanned pair against the last adjoint with 'mkScan2ndMaps'.
-- 5. Pass the per-input contributions from 'mkScanFinalMap' to 'updateAdj'
--    for each array in @as@.
diffScan :: VjpOps -> [VName] -> SubExp -> [VName] -> Scan SOACS -> ADM ()
diffScan ops ys w as scan = do
  let scan_lam = scanLambda scan
  ys_adj <- mapM lookupAdjVal ys
  as_ts <- mapM lookupType as
  map1_lam <- mkScanFusedMapLam ops w scan_lam as ys ys_adj
  lin_scans <- mapM mkScanLinFunO $ lambdaReturnType scan_lam
  iota <- letExp "iota" $ BasicOp $ Iota w (intConst Int64 0) (intConst Int64 1) Int64
  r_scan <-
    letTupExp "adj_ctrb_scan" $
      Op $ Screma w [iota] $ ScremaForm lin_scans [] map1_lam
  -- The scan yields interleaved (d, c) arrays, one pair per scan result.
  red_nms <-
    sequence $
      zipWith3 (\t y_adj dc -> mkScan2ndMaps w (t, y_adj, dc)) as_ts ys_adj (pairs r_scan)
  as_contribs <- mkScanFinalMap ops w scan_lam as ys red_nms
  mapM_ (uncurry updateAdj) $ zip as as_contribs