module Futhark.AD.Rev.Scan (diffScan, diffScanVec, diffScanAdd) where

import Control.Monad
import Data.List (transpose)
import Futhark.AD.Rev.Monad
import Futhark.Analysis.PrimExp.Convert
import Futhark.Builder
import Futhark.IR.SOACS
import Futhark.IR.SOACS.Simplify (simplifyLambda)
import Futhark.Tools
import Futhark.Transform.Rename
import Futhark.Util (chunk)

-- | Which operand of a binary scan operator @x `op` y@ to differentiate
-- with respect to (see 'mkScanAdjointLam').
data FirstOrSecond
  = -- | With respect to the first operand, @x@.
    WrtFirst
  | -- | With respect to the second operand, @y@.
    WrtSecond

-- | Bind an @n × n@ identity matrix whose entries have type @t@.  Every
-- entry is let-bound separately (base name @id@), row by row, using
-- 'oneExp' on the diagonal and 'zeroExp' everywhere else.
identityM :: Int -> Type -> ADM [[SubExp]]
identityM n t =
  forM [1 .. n] $ \row ->
    forM [1 .. n] $ \col ->
      letSubExp "id" $ if row == col then oneExp t else zeroExp t

-- | Symbolic matrix product @m1 * m2@ over 'PrimExp's of type @t@.  Both
-- operands and the result are lists of rows.  Each entry is a dot product
-- folded from the left, starting at the zero value of @t@.
matrixMul :: [[PrimExp VName]] -> [[PrimExp VName]] -> PrimType -> [[PrimExp VName]]
matrixMul m1 m2 t = map (\row -> matrixVecMul cols row t) m1
  where
    -- Entry (r, c) of the result is row r of m1 dotted with column c of m2.
    cols = transpose m2

-- | Symbolic matrix–vector product @m * v@ over 'PrimExp's of type @t@,
-- where @m@ is a list of rows.  Each entry is a dot product folded from the
-- left, starting at the zero value of @t@.
matrixVecMul :: [[PrimExp VName]] -> [PrimExp VName] -> PrimType -> [PrimExp VName]
matrixVecMul m v t = map dot m
  where
    zero = primExpFromSubExp t $ Constant $ blankPrimValue t
    dot row = foldl (~+~) zero $ zipWith (~*~) v row

-- | Element-wise symbolic sum of two vectors.  The result is as long as the
-- shorter input.
vectorAdd :: [PrimExp VName] -> [PrimExp VName] -> [PrimExp VName]
vectorAdd xs ys = [x ~+~ y | (x, y) <- zip xs ys]

-- | Split a flat list into consecutive groups of @length lst `div`
-- specialScans s@ elements each, i.e. one group per scan.  Assumes the
-- length is a multiple of the number of scans; this is not checked here.
orderArgs :: Special -> [a] -> [[a]]
orderArgs s lst = chunk groupSize lst
  where
    groupSize = length lst `div` specialScans s

-- | Computes `d(x op y)/dx` or `d(x op y)/dy`, seeded with @adjs@.
--
-- @lam0@ takes the operands @x@ and @y@ back to back, each with as many
-- parameters as the lambda has results.  The lambda is renamed and then
-- passed to 'vjpLambda' with the parameters of the chosen operand as the
-- variables to differentiate with respect to.
mkScanAdjointLam :: VjpOps -> Lambda SOACS -> FirstOrSecond -> [SubExp] -> ADM (Lambda SOACS)
mkScanAdjointLam ops lam0 which adjs = do
  lam <- renameLambda lam0
  let arity = length $ lambdaReturnType lam0
      wrt = case which of
        WrtFirst -> take arity $ lambdaParams lam
        WrtSecond -> drop arity $ lambdaParams lam
  vjpLambda ops (map AdjVal adjs) (map paramName wrt) lam

-- | Build the lambda of the map that is fused into the adjoint scan.  It is
-- mapped over @iota w@; lambda index @i@ reads position @j = w - 1 - i@, so
-- the forward scan is visited back to front.  For each scan (grouped by
-- 'orderArgs') it returns the adjoints @ys_adj[j]@ followed by the
-- flattened Jacobian block:
--
-- > \i -> let j = w - 1 - i
-- >       in if i == 0
-- >            then (ys_adj[j], I)
-- >            else (ys_adj[j], d(ys[j] `op` xs[j+1]) / d(ys[j]))
--
-- where @ys@ is the result of the scan, @xs@ its input, @ys_adj@ the known
-- adjoint of @ys@, @op@ is @scn_lam@, and @I@ is an identity matrix.  The
-- Jacobian is cut down to what @specialCase s@ requires.  @d@ bounds the
-- number of diagonal blocks taken in the 'ZeroQuadrant' case (presumably
-- the number of scan results; confirm at the call site).
mkScanFusedMapLam ::
  VjpOps ->
  SubExp ->
  Lambda SOACS ->
  [VName] ->
  [VName] ->
  [VName] ->
  Special ->
  Int ->
  ADM (Lambda SOACS)
mkScanFusedMapLam ops w scn_lam xs ys ys_adj s d = do
  ys_ts <- traverse lookupType ys
  idmat <- identityM (length ys) $ rowType $ head ys_ts
  -- One VJP lambda per identity row; their results are transposed below
  -- to obtain the Jacobian with respect to the first operand.
  jac_lams <- mapM (mkScanAdjointLam ops scn_lam WrtFirst) idmat
  par_i <- newParam "i" $ Prim int64
  let i = paramName par_i
      k = specialSubSize s
      sc = specialCase s
      -- j = w - (i + 1): walk the scan result from the end.
      mkPos = letSubExp "j" =<< toExp (pe64 w - (le64 i + 1))
      adjointsAt pos =
        forM ys_adj $ \y_adj ->
          letSubExp (baseString y_adj ++ "_j") =<< eIndex y_adj (eSubExp pos)
      -- i == 0, the last scan element: the Jacobian is the identity.
      identityBody = buildBody_ $ do
        pos <- mkPos
        adjs <- adjointsAt pos
        pure . subExpsRes $
          interleave (orderArgs s adjs) (orderArgs s $ selectJac k sc idmat)
      -- Any other element: differentiate ys[j] `op` xs[j+1] w.r.t. ys[j].
      jacobianBody = buildBody_ $ do
        pos <- mkPos
        pos1 <- letSubExp "j1" =<< toExp (pe64 w - le64 i)
        adjs <- adjointsAt pos
        let args = map (`eIndex` eSubExp pos) ys ++ map (`eIndex` eSubExp pos1) xs
        jac_rows <- mapM (`eLambda` args) jac_lams
        pure $
          interleave
            (orderArgs s $ subExpsRes adjs)
            (orderArgs s $ selectJac k sc $ transpose jac_rows)
  mkLambda [par_i] $ do
    res <- eIf (toExp $ le64 i .==. 0) identityBody jacobianBody
    varsRes <$> letTupExp "x" res
  where
    -- Keep only the parts of the Jacobian the special case needs.
    selectJac :: Int -> SpecialCase -> [[a]] -> [[a]]
    selectJac _ Generic jac = jac
    selectJac sz ZeroQuadrant jac =
      -- The sz × sz diagonal blocks, stacked vertically.
      concat
        [ map (take sz . drop (blk * sz)) rows
          | (blk, rows) <- zip [0 .. d `div` sz] (chunk sz jac)
        ]
    selectJac sz MatrixMul jac =
      -- Only the top-left sz × sz block.
      map (take sz) (take sz jac)

    -- Per scan: its adjoints followed by its flattened Jacobian block.
    interleave :: [[e]] -> [[[e]]] -> [e]
    interleave adjs jacs = concat $ zipWith (\adj jac -> adj ++ concat jac) adjs jacs

-- | The vector part of the linear scan operator: @a1 a2 b -> a2 + b * a1@.
-- In the 'MatrixMul' case @a1@ is cut into vectors of 'specialSubSize'
-- elements, and each of them is multiplied by @b@.  Otherwise @b@
-- multiplies all of @a1@.
linFunT0 :: [PrimExp VName] -> [PrimExp VName] -> [[PrimExp VName]] -> Special -> PrimType -> [PrimExp VName]
linFunT0 a1 a2 b s pt = a2 `vectorAdd` ba1
  where
    mulB v = matrixVecMul b v pt
    ba1
      | MatrixMul <- specialCase s = concatMap mulB $ chunk (specialSubSize s) a1
      | otherwise = mulB a1

-- | Build the scan operator of the linear recurrence used by the adjoint
-- scan.  An element is a pair @(a, b)@ of a vector @a@ and a matrix @b@
-- (a flat list, chunked into rows), combined as
--
-- > \(a1, b1) (a2, b2) -> (a2 + b2 * a1, b2 * b1)
--
-- with neutral element @(0, I)@.
--
-- * The number of @a@ and @b@ entries comes from 'specialParams'.
-- * The neutral sizes and the matrix side length come from
--   'specialNeutral'.
-- * Every lambda parameter has type @rowType t@.
-- * The arithmetic is done in @elemType t@.
mkScanLinFunO :: Type -> Special -> ADM (Scan SOACS)
mkScanLinFunO t s = do
  let pt = elemType t
  neu_elm <- mkNeutral $ specialNeutral s
  (a1s, b1s, a2s, b2s) <- mkParams $ specialParams s
  let pet = primExpFromSubExp pt . Var
      -- Side length of the b matrices.
      (_, n) = specialNeutral s
      params = map (\v -> Param mempty v (rowType t)) (a1s ++ b1s ++ a2s ++ b2s)

  lam <- mkLambda params . fmap subExpsRes $ do
    -- Tuple binding: total, so there is no incomplete-pattern warning.
    let (a1s', b1s', a2s', b2s') = (map pet a1s, map pet b1s, map pet a2s, map pet b2s)
        b1sm = chunk n b1s'
        b2sm = chunk n b2s'
        t0 = linFunT0 a1s' a2s' b2sm s pt -- a2 + b2 * a1
        t1 = concat $ matrixMul b2sm b1sm pt -- b2 * b1
    traverse (letSubExp "r" <=< toExp) $ t0 ++ t1

  pure $ Scan lam neu_elm
  where
    -- @a@ zeros of type @rowType t@, followed by a flattened @b × b@
    -- identity matrix.
    mkNeutral :: (Int, Int) -> ADM [SubExp]
    mkNeutral (a, b) = do
      zeros <- replicateM a $ letSubExp "zeros" $ zeroExp $ rowType t
      idmat <- identityM b $ Prim $ elemType t
      pure $ zeros ++ concat idmat

    -- Fresh names for the a/b parameters of both operands.
    mkParams :: (Int, Int) -> ADM ([VName], [VName], [VName], [VName])
    mkParams (a, b) = do
      a1s <- replicateM a $ newVName "a1"
      b1s <- replicateM b $ newVName "b1"
      a2s <- replicateM a $ newVName "a2"
      b2s <- replicateM b $ newVName "b2"
      pure (a1s, b1s, a2s, b2s)

-- | Perform the final map, which computes the contribution of each scan
-- element to the adjoint of the scan input @xs@:
--
-- > let xs_contribs =
-- >   map3 (\i a r -> if i == 0 then r else (df2dy (ys[i-1]) a) \bar{*} r)
-- >        (iota n) xs (reverse ds)
--
-- Here @n@ is @w@ and @ys@ is the scan result.  @ds@ is read at
-- @w - 1 - i@, i.e. back to front (presumably it is the result of the
-- adjoint scan; confirm at the call site).  The product with
-- @d(ys[i-1] `op` x)/dx@ is done by the lambda from 'mkScanAdjointLam'
-- with respect to the second operand of @scan_lam@.
mkScanFinalMap :: VjpOps -> SubExp -> Lambda SOACS -> [VName] -> [VName] -> [VName] -> ADM [VName]
mkScanFinalMap ops w scan_lam xs ys ds = do
  par_i <- newParam "i" $ Prim int64
  par_x <- zipWithM (\x -> newParam (baseString x ++ "_par_x")) xs (lambdaReturnType scan_lam)
  let i = paramName par_i
      -- i == 0: nothing precedes the first element, so pass the adjoint on.
      passThrough dj = resultBodyM $ map Var dj
      -- Otherwise: push the adjoint through d(ys[i-1] `op` x)/dx.
      propagate dj = buildBody_ $ do
        vjp <- mkScanAdjointLam ops scan_lam WrtSecond $ map Var dj
        im1 <- letSubExp "im1" =<< toExp (le64 i - 1)
        ys_im1 <- forM ys $ \y -> do
          y_t <- lookupType y
          letSubExp (baseString y ++ "_last") $
            BasicOp . Index y $ fullSlice y_t [DimFix im1]
        eLambda vjp $ map eSubExp ys_im1 ++ map (eSubExp . Var . paramName) par_x

  map_lam <- mkLambda (par_i : par_x) $ do
    j <- letSubExp "j" =<< toExp (pe64 w - (le64 i + 1))
    dj <- forM ds $ \dd ->
      letExp (baseString dd ++ "_dj") =<< eIndex dd (eSubExp j)
    contribs <- eIf (toExp $ le64 i .==. 0) (passThrough dj) (propagate dj)
    varsRes <$> letTupExp "scan_contribs" contribs

  iota <- letExp "iota" $ BasicOp $ Iota w (intConst Int64 0) (intConst Int64 1) Int64
  letTupExp "scan_contribs" $ Op $ Screma w (iota : xs) $ mapSOAC map_lam

-- | The structure detected by 'cases' in the Jacobian of a scan operator.
data SpecialCase
  = -- | No exploitable structure: the full dense matrix is used.
    Generic
  | -- | Block-diagonal, but the diagonal blocks are not all identical;
    -- each block gets its own scan.
    ZeroQuadrant
  | -- | Block-diagonal with identical blocks; a single block is applied
    -- to every sub-vector.
    MatrixMul
  deriving (Show)

-- | Size information for the linear scan built by 'mkScanLinFunO'.
data Special = Special
  { -- | @(a, b)@: the neutral element is @a@ zeros followed by a flattened
    -- @b × b@ identity; @b@ is also the side length of the matrices.
    specialNeutral :: (Int, Int),
    -- | Number of vector and matrix parameters per operand.
    specialParams :: (Int, Int),
    -- | Number of groups the results are split into by 'orderArgs'.
    specialScans :: Int,
    -- | Side length of the Jacobian blocks that are kept.
    specialSubSize :: Int,
    -- | The detected structure.
    specialCase :: SpecialCase
  }
  deriving (Show)

-- | Find the smallest block size @m@ (a divisor of @d@ no larger than
-- @d / 2@) for which the @d × d@ matrix @mat@ is block-diagonal with
-- @m × m@ blocks.  That is, every entry outside the diagonal blocks must be
-- syntactically equal to @zero@.  Returns 'Nothing' if there is no such
-- @m@.
subMats :: Int -> [[Exp SOACS]] -> Exp SOACS -> Maybe Int
subMats d mat zero =
  -- First candidate (in increasing order) that passes; total, unlike the
  -- former `head` on a filtered list.
  lookup True [(blockDiagonal m, m) | m <- candidates]
  where
    candidates = filter (\m -> d `mod` m == 0) [1 .. d `div` 2]
    blockDiagonal m = all (rowOk m) $ zip mat [0 .. d - 1]
    -- Entries of row i outside its diagonal block must be zero.
    rowOk m (row, i) =
      all (\(v, j) -> v == zero || i `div` m == j `div` m) $ zip row [0 .. d - 1]

-- | Classify the @d × d@ Jacobian @mat@, whose zero is @zeroExp t@.
--
-- * Block-diagonal with identical @k × k@ blocks: 'MatrixMul'.  This uses
--   one scan with a single @k × k@ matrix.
-- * Block-diagonal with blocks that differ: 'ZeroQuadrant'.  This uses
--   @d / k@ scans with @k × k@ matrices.
-- * Anything else: 'Generic'.  This uses one scan with a full @d × d@
--   matrix.
cases :: Int -> Type -> [[Exp SOACS]] -> Special
cases d t mat
  | Just k <- subMats d mat $ zeroExp t =
      let diagBlocks =
            zipWith (\i -> map (take k . drop (i * k))) [0 .. d `div` k] $ chunk k mat
       in if allSame diagBlocks
            then Special (d, k) (d, k * k) 1 k MatrixMul
            else Special (k, k) (k, k * k) (d `div` k) k ZeroQuadrant
  | otherwise = Special (d, d) (d, d * d) 1 d Generic
  where
    -- Total: an empty list counts as "all the same" instead of crashing
    -- on `head []`.
    allSame [] = True
    allSame (x : rest) = all (== x) rest

-- | Classify a scan operator by inspecting its Jacobian.
--
-- With @d@ the number of values returned by @lam@, the adjoint of @lam@ with
-- respect to its first argument is instantiated once per row of a @d x d@
-- identity matrix. These adjoint lambdas are all applied inside a single
-- lambda taking two fresh parameter tuples. That lambda is simplified, and its
-- results are chunked into a @d x d@ matrix of expressions, which 'cases'
-- turns into a 'Special' description.
--
-- NOTE(review): uses 'head' on the lambda's return types, so this assumes the
-- operator returns at least one value.
identifyCase :: VjpOps -> Lambda SOACS -> ADM Special
identifyCase ops lam = do
  idmat <- identityM d elemType
  adjLams <- mapM (mkScanAdjointLam ops lam WrtFirst) idmat

  xsParams <- mapM (newParam "tmp1") retTypes
  ysParams <- mapM (newParam "tmp2") retTypes
  jacLam <- mkLambda (xsParams ++ ysParams) $ do
    let actuals = map eParam (xsParams ++ ysParams)
    perSeed <- forM adjLams $ \adjLam -> eLambda adjLam actuals
    -- Transpose before flattening: the flat result list is grouped by the
    -- adjoint lambdas' result index, so after 'chunk d' below, entry (j, i)
    -- is result j of the adjoint seeded with identity row i.
    pure $ concat $ transpose perSeed

  simplified <- simplifyLambda jacLam
  let jacobian =
        chunk d
          . map (BasicOp . SubExp . resSubExp)
          . bodyResult
          . lambdaBody
          $ simplified
  pure $ cases d elemType jacobian
  where
    retTypes = lambdaReturnType lam
    d = length retTypes
    elemType = head retTypes

-- | Reverse-mode differentiation of a scan.
--
-- Arguments:
--
-- * @ys@ are the scan results.
-- * @w@ is the scan width.
-- * @as@ are the scanned input arrays.
--
-- The operator is first classified with 'identifyCase'. A fused map-scan over
-- @iota w@ then computes adjoint contributions, which 'mkScanFinalMap' turns
-- into contributions for @as@. These are accumulated with 'updateAdj'.
diffScan :: VjpOps -> [VName] -> SubExp -> [VName] -> Scan SOACS -> ADM ()
diffScan ops ys w as scan = do
  let op = scanLambda scan
      d = length as
  special <- identifyCase ops op
  ysAdj <- mapM lookupAdjVal ys
  asTypes <- mapM lookupType as
  fusedMapLam <- mkScanFusedMapLam ops w op as ys ysAdj special d
  linFunScan <- mkScanLinFunO (head asTypes) special
  scanOps <- mkScans (specialScans special) linFunScan
  indices <-
    letExp "iota" . BasicOp $
      Iota w (intConst Int64 0) (intConst Int64 1) Int64
  scanned <-
    letTupExp "adj_ctrb_scan" . Op $
      Screma w [indices] (scanomapSOAC scanOps fusedMapLam)
  contribs <-
    mkScanFinalMap ops w op as ys (splitScanRes special scanned d)
  zipWithM_ updateAdj as contribs
  where
    -- Produce @k@ copies of a scan, each with a freshly renamed lambda so
    -- that no two copies share binding names.
    mkScans :: Int -> Scan SOACS -> ADM [Scan SOACS]
    mkScans k s =
      replicateM k $
        (`Scan` scanNeutral s) <$> renameLambda (scanLambda s)

    -- Regroup the flat scan results with 'orderArgs', keeping the first
    -- @k `div` specialScans sc@ entries of each group.
    splitScanRes :: Special -> [b] -> Int -> [b]
    splitScanRes sc res k =
      concatMap (take (k `div` specialScans sc)) (orderArgs sc res)

-- | Reverse-mode differentiation of a scan by rewriting it into a map of
-- scans over transposed inputs.
--
-- The rewrite proceeds in four steps:
--
-- 1. The two outermost dimensions of every input in @as@ are swapped.
-- 2. The result is mapped over the new outer dimension (size @n@) together
--    with the neutral-element arrays @ne@.
-- 3. Each map iteration performs a scan of width @w@ with @lam@.
-- 4. The map results are transposed back and bound to the original result
--    names @ys@ under the original @aux@ attributes.
--
-- The generated statements are then differentiated with 'vjpStm', threading
-- the continuation @m@.
--
-- NOTE(review): 'head as' assumes at least one input array.
diffScanVec ::
  VjpOps ->
  [VName] ->
  StmAux () ->
  SubExp ->
  Lambda SOACS ->
  [SubExp] ->
  [VName] ->
  ADM () ->
  ADM ()
diffScanVec ops ys aux w lam ne as m = do
  stmts <- collectStms_ $ do
    rank <- arrayRank <$> lookupType (head as)
    -- Swap the two outermost dimensions, leaving the rest untouched.
    let rear = [1, 0] ++ drop 2 [0 .. rank - 1]

    transp_as <-
      forM as $ \a ->
        letExp (baseString a ++ "_transp") $ BasicOp $ Rearrange rear a

    ts <- traverse lookupType transp_as
    -- Outer size of the transposed inputs: the dimension we map over.
    let n = arraysSize 0 ts

    as_par <- traverse (newParam "as_par" . rowType) ts
    ne_par <- traverse (newParam "ne_par") $ lambdaReturnType lam

    scan_form <- scanSOAC [Scan lam (map (Var . paramName) ne_par)]

    -- One map iteration: scan a row of the transposed inputs, using the
    -- neutral element supplied by the corresponding row of 'ne'.
    map_lam <-
      mkLambda (as_par ++ ne_par) . fmap varsRes . letTupExp "map_res" . Op $
        Screma w (map paramName as_par) scan_form

    transp_ys <-
      letTupExp "trans_ys" . Op $
        Screma n (transp_as ++ subExpVars ne) (mapSOAC map_lam)

    -- Transpose back into the original result names. Only the bindings
    -- matter, so discard the per-element unit results.
    zipWithM_
      (\y x -> auxing aux $ letBindNames [y] $ BasicOp $ Rearrange rear x)
      ys
      transp_ys

  foldr (vjpStm ops) m stmts

-- | Reverse-mode differentiation of a single-array scan of length @n@ with
-- operator @lam'@ and neutral element @ne@. The operator is presumably
-- addition-like, as the name suggests; confirm against callers.
--
-- The adjoint of @ys@ is read back to front and scanned with (a renamed copy
-- of) the operator. The scan result is read back to front again, and the
-- outcome is accumulated into the adjoint of @as@ with 'updateAdj'. Hence
-- element @i@ of the contribution combines the adjoints of @ys@ from index
-- @n-1@ down to @i@.
diffScanAdd :: VjpOps -> VName -> SubExp -> Lambda SOACS -> SubExp -> VName -> ADM ()
diffScanAdd _ops ys n lam' ne as = do
  op <- renameLambda lam'
  ysBar <- lookupAdjVal ys
  readYsBarReversed <- reversedReader ysBar
  indices <-
    letExp "iota" . BasicOp $
      Iota n (intConst Int64 0) (intConst Int64 1) Int64
  suffixScan <-
    letExp "res_rev" . Op . Screma n [indices] $
      scanomapSOAC [Scan op [ne]] readYsBarReversed
  readScanReversed <- reversedReader suffixScan
  contribution <-
    letExp "contrb" . Op . Screma n [indices] $ mapSOAC readScanReversed
  updateAdj as contribution
  where
    -- A lambda mapping index @i@ to @arr[n - i - 1]@, i.e. reading @arr@
    -- back to front.
    reversedReader :: VName -> ADM (Lambda SOACS)
    reversedReader arr = do
      idxParam <- newParam "i" $ Prim int64
      mkLambda [idxParam] $ do
        let revIdx = pe64 n - le64 (paramName idxParam) - 1
        elemName <- letExp "ys_bar_rev" =<< eIndex arr (toExp revIdx)
        pure [varRes elemName]