{-# LANGUAGE TypeFamilies #-}

-- Naming scheme:
--
-- An adjoint-related object for "x" is named "x_adj".  This means
-- both actual adjoints and statements.
--
-- Do not assume "x'" means anything related to derivatives.
module Futhark.AD.Rev (revVJP) where

import Control.Monad
import Data.List ((\\))
import Data.List.NonEmpty (NonEmpty (..))
import Data.Map qualified as M
import Futhark.AD.Derivatives
import Futhark.AD.Rev.Loop
import Futhark.AD.Rev.Monad
import Futhark.AD.Rev.SOAC
import Futhark.Analysis.PrimExp.Convert
import Futhark.Builder
import Futhark.IR.SOACS
import Futhark.Tools
import Futhark.Transform.Rename
import Futhark.Transform.Substitute
import Futhark.Util (takeLast)

-- | Extract the single name bound by a pattern.
--
-- Calls 'error' (with the pretty-printed pattern) if the pattern does
-- not consist of exactly one element.
patName :: Pat Type -> ADM VName
patName (Pat [pe]) = pure $ patElemName pe
patName pat = error $ "Expected single-element pattern: " ++ prettyString pat

-- | The vast majority of 'BasicOp's require no special treatment in the
-- forward pass and produce one value (and hence one adjoint).  This
-- helper deals with that case:
--
-- 1. re-emits the statement unchanged;
-- 2. runs the continuation @m@;
-- 3. returns the single bound name together with its adjoint, as
--    obtained from 'lookupAdjVal'.
--
-- Fails via 'patName' if the pattern does not bind exactly one name.
commonBasicOp :: Pat Type -> StmAux () -> BasicOp -> ADM () -> ADM (VName, VName)
commonBasicOp pat aux op m = do
  addStm $ Let pat aux $ BasicOp op
  m
  pat_v <- patName pat
  pat_adj <- lookupAdjVal pat_v
  pure (pat_v, pat_adj)

-- | Differentiate a single 'BasicOp' statement.
--
-- The forward-pass statement is emitted and the continuation @m@ is run;
-- for most operations this is done through 'commonBasicOp'.  Then
-- return-sweep code is registered (via 'returnSweepCode') that
-- propagates the adjoint of the result back to the operands.
--
-- 'UpdateAcc' may bind several names and therefore handles its pattern
-- directly.  'FlatIndex' and 'FlatUpdate' are not supported and call
-- 'error'.
diffBasicOp :: Pat Type -> StmAux () -> BasicOp -> ADM () -> ADM ()
diffBasicOp pat aux e m =
  case e of
    -- Comparisons: the Boolean result adjoint is converted to the
    -- operand type and added to the adjoints of both operands.
    CmpOp cmp x y -> do
      (_pat_v, pat_adj) <- commonBasicOp pat aux e m
      returnSweepCode $ do
        let t = cmpOpType cmp
            update contrib = do
              void $ updateSubExpAdj x contrib
              void $ updateSubExpAdj y contrib

        case t of
          FloatType ft ->
            -- True maps to 1 and False to 0 in the float type.
            update <=< letExp "contrib" $
              Match
                [Var pat_adj]
                [Case [Just $ BoolValue True] $ resultBody [constant (floatValue ft (1 :: Int))]]
                (resultBody [constant (floatValue ft (0 :: Int))])
                (MatchDec [Prim (FloatType ft)] MatchNormal)
          IntType it ->
            update <=< letExp "contrib" $ BasicOp $ ConvOp (BToI it) (Var pat_adj)
          Bool ->
            update pat_adj
          Unit ->
            pure ()
    --
    -- Conversions: the result adjoint is converted back using the
    -- flipped conversion.
    ConvOp op x -> do
      (_pat_v, pat_adj) <- commonBasicOp pat aux e m
      returnSweepCode $ do
        contrib <-
          letExp "contrib" $ BasicOp $ ConvOp (flipConvOp op) $ Var pat_adj
        updateSubExpAdj x contrib
    --
    -- Unary operators: contribution is result adjoint times the partial
    -- derivative given by 'pdUnOp'.
    UnOp op x -> do
      (_pat_v, pat_adj) <- commonBasicOp pat aux e m

      returnSweepCode $ do
        let t = unOpType op
        contrib <- do
          let x_pe = primExpFromSubExp t x
              pat_adj' = primExpFromSubExp t (Var pat_adj)
              dx = pdUnOp op x_pe
          letExp "contrib" <=< toExp $ pat_adj' ~*~ dx

        updateSubExpAdj x contrib
    --
    -- Binary operators: each operand receives the result adjoint times
    -- its partial derivative given by 'pdBinOp'.
    BinOp op x y -> do
      (_pat_v, pat_adj) <- commonBasicOp pat aux e m

      returnSweepCode $ do
        let t = binOpType op
            (wrt_x, wrt_y) =
              pdBinOp op (primExpFromSubExp t x) (primExpFromSubExp t y)

            pat_adj' = primExpFromSubExp t $ Var pat_adj

        adj_x <- letExp "binop_x_adj" <=< toExp $ pat_adj' ~*~ wrt_x
        adj_y <- letExp "binop_y_adj" <=< toExp $ pat_adj' ~*~ wrt_y
        updateSubExpAdj x adj_x
        updateSubExpAdj y adj_y
    --
    -- Plain copies of a subexpression pass the adjoint straight through.
    SubExp se -> do
      (_pat_v, pat_adj) <- commonBasicOp pat aux e m
      returnSweepCode $ updateSubExpAdj se pat_adj
    --
    -- Assertions only need their forward statement; no return sweep.
    Assert {} ->
      void $ commonBasicOp pat aux e m
    --
    -- Array literals: element i receives row i of the result adjoint.
    ArrayLit elems _ -> do
      (_pat_v, pat_adj) <- commonBasicOp pat aux e m
      t <- lookupType pat_adj
      returnSweepCode $ do
        forM_ (zip [(0 :: Int64) ..] elems) $ \(i, se) -> do
          let slice = fullSlice t [DimFix (constant i)]
          updateSubExpAdj se <=< letExp "elem_adj" $ BasicOp $ Index pat_adj slice
    --
    -- Indexing: the result adjoint is added into the indexed slice of
    -- the array's adjoint.
    Index arr slice -> do
      (_pat_v, pat_adj) <- commonBasicOp pat aux e m
      returnSweepCode $ do
        void $ updateAdjSlice slice arr pat_adj
    FlatIndex {} -> error "FlatIndex not handled by AD yet."
    FlatUpdate {} -> error "FlatUpdate not handled by AD yet."
    --
    -- Opaque operations are treated like identity for the adjoint.
    Opaque _ se -> do
      (_pat_v, pat_adj) <- commonBasicOp pat aux e m
      returnSweepCode $ updateSubExpAdj se pat_adj
    --
    -- Reshapes: the result adjoint is reshaped back to the source shape.
    Reshape k _ arr -> do
      (_pat_v, pat_adj) <- commonBasicOp pat aux e m
      returnSweepCode $ do
        arr_shape <- arrayShape <$> lookupType arr
        void $
          updateAdj arr <=< letExp "adj_reshape" . BasicOp $
            Reshape k arr_shape pat_adj
    --
    -- Transpositions: the result adjoint is rearranged by the inverse
    -- permutation.
    Rearrange perm arr -> do
      (_pat_v, pat_adj) <- commonBasicOp pat aux e m
      returnSweepCode $
        void $
          updateAdj arr <=< letExp "adj_rearrange" . BasicOp $
            Rearrange (rearrangeInverse perm) pat_adj
    --
    -- Rotations: the result adjoint is rotated by the negated amounts.
    Rotate rots arr -> do
      (_pat_v, pat_adj) <- commonBasicOp pat aux e m
      returnSweepCode $ do
        let neg = BasicOp . BinOp (Sub Int64 OverflowWrap) (intConst Int64 0)
        rots' <- mapM (letSubExp "rot_neg" . neg) rots
        void $
          updateAdj arr <=< letExp "adj_rotate" . BasicOp $
            Rotate rots' pat_adj
    --
    -- Replication: the replicated value receives the sum of all copies
    -- of the result adjoint.  The adjoint is flattened over the
    -- replicated dimensions and reduced with 'addLambda'.
    Replicate (Shape ns) x -> do
      (_pat_v, pat_adj) <- commonBasicOp pat aux e m
      returnSweepCode $ do
        x_t <- subExpType x
        lam <- addLambda x_t
        ne <- letSubExp "zero" $ zeroExp x_t
        n <- letSubExp "rep_size" =<< foldBinOp (Mul Int64 OverflowUndef) (intConst Int64 1) ns
        pat_adj_flat <-
          letExp (baseString pat_adj <> "_flat") . BasicOp $
            Reshape ReshapeArbitrary (Shape $ n : arrayDims x_t) pat_adj
        reduce <- reduceSOAC [Reduce Commutative lam [ne]]
        updateSubExpAdj x
          =<< letExp "rep_contrib" (Op $ Screma n [pat_adj_flat] reduce)
    --
    -- Concatenation along dimension d: each input receives the slice of
    -- the result adjoint that it occupies along dimension d.  The slice
    -- width must be the input's extent along d (not along dimension 0),
    -- since 'sliceAt' applies the 'DimSlice' at dimension d.
    Concat d (arr :| arrs) _ -> do
      (_pat_v, pat_adj) <- commonBasicOp pat aux e m
      returnSweepCode $ do
        let sliceAdj _ [] = pure []
            sliceAdj start (v : vs) = do
              v_t <- lookupType v
              let w = arraySize d v_t
                  slice = DimSlice start w (intConst Int64 1)
              pat_adj_slice <-
                letExp (baseString pat_adj <> "_slice") $
                  BasicOp $
                    Index pat_adj (sliceAt v_t d [slice])
              start' <- letSubExp "start" $ BasicOp $ BinOp (Add Int64 OverflowUndef) start w
              slices <- sliceAdj start' vs
              pure $ pat_adj_slice : slices

        slices <- sliceAdj (intConst Int64 0) $ arr : arrs

        zipWithM_ updateAdj (arr : arrs) slices
    --
    -- Copies pass the adjoint straight through.
    Copy se -> do
      (_pat_v, pat_adj) <- commonBasicOp pat aux e m
      returnSweepCode $ void $ updateAdj se pat_adj
    --
    -- Manifest only changes layout, so the adjoint passes through.
    Manifest _ se -> do
      (_pat_v, pat_adj) <- commonBasicOp pat aux e m
      returnSweepCode $ void $ updateAdj se pat_adj
    --
    -- Scratch arrays have no operands to propagate to.
    Scratch {} ->
      void $ commonBasicOp pat aux e m
    --
    -- Iota: the summed result adjoint is added to the adjoint of n.
    -- NOTE(review): the sum has element type `t`, whereas `n` is
    -- presumably an i64 size; confirm this is well-typed when t /= Int64.
    Iota n _ _ t -> do
      (_pat_v, pat_adj) <- commonBasicOp pat aux e m
      returnSweepCode $ do
        ne <- letSubExp "zero" $ zeroExp $ Prim $ IntType t
        lam <- addLambda $ Prim $ IntType t
        reduce <- reduceSOAC [Reduce Commutative lam [ne]]
        updateSubExpAdj n
          =<< letExp "iota_contrib" (Op $ Screma n [pat_adj] reduce)
    --
    -- In-place update: the written value receives the updated slice of
    -- the result adjoint (copied if it is an array).  The source array
    -- receives the result adjoint with that slice zeroed out.
    Update safety arr slice v -> do
      (_pat_v, pat_adj) <- commonBasicOp pat aux e m
      returnSweepCode $ do
        v_adj <- letExp "update_val_adj" $ BasicOp $ Index pat_adj slice
        t <- lookupType v_adj
        v_adj_copy <-
          case t of
            Array {} -> letExp "update_val_adj_copy" $ BasicOp $ Copy v_adj
            _ -> pure v_adj
        updateSubExpAdj v v_adj_copy
        zeroes <- letSubExp "update_zero" . zeroExp =<< subExpType v
        void $
          updateAdj arr
            =<< letExp "update_src_adj" (BasicOp $ Update safety pat_adj slice zeroes)
    -- See Note [Adjoints of accumulators]
    UpdateAcc _ is vs -> do
      addStm $ Let pat aux $ BasicOp e
      m
      pat_adjs <- mapM lookupAdjVal (patNames pat)
      returnSweepCode $ do
        forM_ (zip pat_adjs vs) $ \(adj, v) -> do
          adj_i <- letExp "updateacc_val_adj" $ BasicOp $ Index adj $ Slice $ map DimFix is
          updateSubExpAdj v adj_i

-- | The recursive differentiation operations (for lambdas and for
-- statements) bundled for 'vjpSOAC', so that SOAC differentiation can
-- call back into this module.
vjpOps :: VjpOps
vjpOps = VjpOps diffLambda diffStm

diffStm :: Stm SOACS -> ADM () -> ADM ()
diffStm :: Stm SOACS -> ADM () -> ADM ()
diffStm (Let Pat (LetDec SOACS)
pat StmAux (ExpDec SOACS)
aux (BasicOp BasicOp
e)) ADM ()
m =
  Pat Type -> StmAux () -> BasicOp -> ADM () -> ADM ()
diffBasicOp Pat (LetDec SOACS)
pat StmAux (ExpDec SOACS)
aux BasicOp
e ADM ()
m
diffStm stm :: Stm SOACS
stm@(Let Pat (LetDec SOACS)
pat StmAux (ExpDec SOACS)
_ (Apply Name
f [(SubExp, Diet)]
args [RetType SOACS]
_ (Safety, SrcLoc, [SrcLoc])
_)) ADM ()
m
  | Just (PrimType
ret, [PrimType]
argts) <- forall k a. Ord k => k -> Map k a -> Maybe a
M.lookup Name
f Map Name (PrimType, [PrimType])
builtInFunctions = do
      forall (m :: * -> *). MonadBuilder m => Stm (Rep m) -> m ()
addStm Stm SOACS
stm
      ADM ()
m

      VName
pat_adj <- VName -> ADM VName
lookupAdjVal forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< Pat Type -> ADM VName
patName Pat (LetDec SOACS)
pat
      let arg_pes :: [PrimExp VName]
arg_pes = forall a b c. (a -> b -> c) -> [a] -> [b] -> [c]
zipWith PrimType -> SubExp -> PrimExp VName
primExpFromSubExp [PrimType]
argts (forall a b. (a -> b) -> [a] -> [b]
map forall a b. (a, b) -> a
fst [(SubExp, Diet)]
args)
          pat_adj' :: PrimExp VName
pat_adj' = PrimType -> SubExp -> PrimExp VName
primExpFromSubExp PrimType
ret (VName -> SubExp
Var VName
pat_adj)
          convert :: PrimType -> PrimType -> PrimExp VName -> PrimExp VName
convert PrimType
ft PrimType
tt
            | PrimType
ft forall a. Eq a => a -> a -> Bool
== PrimType
tt = forall a. a -> a
id
          convert (IntType IntType
ft) (IntType IntType
tt) = forall v. ConvOp -> PrimExp v -> PrimExp v
ConvOpExp (IntType -> IntType -> ConvOp
SExt IntType
ft IntType
tt)
          convert (FloatType FloatType
ft) (FloatType FloatType
tt) = forall v. ConvOp -> PrimExp v -> PrimExp v
ConvOpExp (FloatType -> FloatType -> ConvOp
FPConv FloatType
ft FloatType
tt)
          convert PrimType
Bool (FloatType FloatType
tt) = forall v. ConvOp -> PrimExp v -> PrimExp v
ConvOpExp (FloatType -> ConvOp
BToF FloatType
tt)
          convert (FloatType FloatType
ft) PrimType
Bool = forall v. ConvOp -> PrimExp v -> PrimExp v
ConvOpExp (FloatType -> ConvOp
FToB FloatType
ft)
          convert PrimType
ft PrimType
tt = forall a. HasCallStack => [Char] -> a
error forall a b. (a -> b) -> a -> b
$ [Char]
"diffStm.convert: " forall a. [a] -> [a] -> [a]
++ forall a. Pretty a => a -> [Char]
prettyString (Name
f, PrimType
ft, PrimType
tt)

      [VName]
contribs <-
        case Name -> [PrimExp VName] -> Maybe [PrimExp VName]
pdBuiltin Name
f [PrimExp VName]
arg_pes of
          Maybe [PrimExp VName]
Nothing ->
            forall a. HasCallStack => [Char] -> a
error forall a b. (a -> b) -> a -> b
$ [Char]
"No partial derivative defined for builtin function: " forall a. [a] -> [a] -> [a]
++ forall a. Pretty a => a -> [Char]
prettyString Name
f
          Just [PrimExp VName]
derivs ->
            forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
t a -> (a -> m b) -> m (t b)
forM (forall a b. [a] -> [b] -> [(a, b)]
zip [PrimExp VName]
derivs [PrimType]
argts) forall a b. (a -> b) -> a -> b
$ \(PrimExp VName
deriv, PrimType
argt) ->
              forall (m :: * -> *).
MonadBuilder m =>
[Char] -> Exp (Rep m) -> m VName
letExp [Char]
"contrib" forall (m :: * -> *) b c a.
Monad m =>
(b -> m c) -> (a -> m b) -> a -> m c
<=< forall a (m :: * -> *).
(ToExp a, MonadBuilder m) =>
a -> m (Exp (Rep m))
toExp forall b c a. (b -> c) -> (a -> b) -> a -> c
. PrimType -> PrimType -> PrimExp VName -> PrimExp VName
convert PrimType
ret PrimType
argt forall a b. (a -> b) -> a -> b
$ PrimExp VName
pat_adj' forall v. PrimExp v -> PrimExp v -> PrimExp v
~*~ PrimExp VName
deriv

      forall (m :: * -> *) a b c.
Applicative m =>
(a -> b -> m c) -> [a] -> [b] -> m ()
zipWithM_ SubExp -> VName -> ADM ()
updateSubExpAdj (forall a b. (a -> b) -> [a] -> [b]
map forall a b. (a, b) -> a
fst [(SubExp, Diet)]
args) [VName]
contribs
-- Differentiate a 'Match'.  The forward sweep re-emits the statement
-- unchanged.  In the return sweep every branch is differentiated with
-- the adjoints of the pattern as seeds, and the differentiated match
-- also returns an adjoint for every variable free in any branch; those
-- adjoints are then installed with 'insAdj'.
diffStm stm@(Let pat _ (Match ses cases defbody _)) m = do
  addStm stm
  m
  returnSweepCode $ do
    let free_in_branches =
          namesToList . mconcat $ freeIn defbody : map freeIn cases
    pat_adjs <- traverse lookupAdj (patNames pat)
    let diffBranch = diffBody pat_adjs free_in_branches
    rev_match <- eMatch ses (map (fmap diffBranch) cases) (diffBranch defbody)
    rev_res <- letTupExp "branch_adj" =<< renameExp rev_match
    -- Only the trailing results, one per free variable, are adjoints.
    zipWithM_ insAdj free_in_branches $
      takeLast (length free_in_branches) rev_res
-- SOACs are handled entirely by 'vjpSOAC'.
diffStm (Let soac_pat soac_aux (Op soac)) sweep =
  vjpSOAC vjpOps soac_pat soac_aux soac sweep
-- Loops are delegated to 'diffLoop', which is handed 'diffStms' as its
-- statement differentiator.
diffStm (Let loop_pat loop_aux loop_e@DoLoop {}) sweep =
  diffLoop diffStms loop_pat loop_aux loop_e sweep
-- See Note [Adjoints of accumulators]
-- Differentiate a 'WithAcc'.  The forward sweep re-emits the statement
-- unchanged.  The return sweep binds a new 'WithAcc' over freshly
-- renamed inputs, whose lambda is the differentiated original lambda.
-- Its results are installed as the adjoints of the input arrays,
-- followed by the adjoints of the active, non-accumulator free
-- variables of the lambda.
diffStm stm@(Let pat _aux (WithAcc inputs lam)) m = do
  addStm stm
  m
  returnSweepCode $ do
    pat_adjs <- traverse lookupAdj (patNames pat)
    lam_fresh <- renameLambda lam
    active_free <- filterM isActive . namesToList $ freeIn lam_fresh
    -- Free accumulators are not among the variables whose adjoints the
    -- differentiated lambda returns.
    free_accs <- filterM (\v -> isAcc <$> lookupType v) active_free
    let adj_targets = active_free \\ free_accs
    lam_rev <- diffWithAccLambda pat_adjs adj_targets lam_fresh
    inputs_fresh <- traverse renameInputOp inputs
    contribs <- letTupExp "with_acc_contrib" $ WithAcc inputs_fresh lam_rev
    zipWithM_ insAdj (input_arrs <> adj_targets) contribs
  where
    -- All arrays underlying the accumulator inputs, in input order.
    input_arrs = concat [arrs | (_, arrs, _) <- inputs]

    -- Give the operator lambda of an input (if it has one) fresh names.
    renameInputOp (shape, arrs, op) = do
      op' <- traverse renameOp op
      pure (shape, arrs, op')

    renameOp (op_lam, nes) = do
      op_lam' <- renameLambda op_lam
      pure (op_lam', nes)

    -- Differentiate the accumulator lambda, keeping its first
    -- @length inputs@ results and appending the adjoints of
    -- @get_adjs_for@ (typed like the variables themselves).
    diffWithAccLambda :: [Adj] -> [VName] -> Lambda SOACS -> ADM (Lambda SOACS)
    diffWithAccLambda res_adjs get_adjs_for (Lambda params body ts) =
      localScope (scopeOfLParams params) $ do
        Body () stms res <- diffBody res_adjs get_adjs_for body
        adj_ts <- traverse lookupType get_adjs_for
        let n_inputs = length inputs
            res' = take n_inputs res <> takeLast (length get_adjs_for) res
        pure $ Lambda params (Body () stms res') (take n_inputs ts <> adj_ts)
-- Any other statement form is not supported.
diffStm unhandled _ =
  error $ "diffStm unhandled:\n" <> prettyString unhandled

-- | Differentiate a sequence of statements, one at a time.
--
-- Each statement is first passed through 'copyConsumedArrsInStm'.  The
-- resulting substitution is applied to the statement and to the
-- remaining statements, the copy statements are differentiated, and
-- then the statement itself is differentiated with the rest of the
-- sequence as its continuation.  Afterwards every substituted name
-- takes over the adjoint of the name it was replaced by.
diffStms :: Stms SOACS -> ADM ()
diffStms all_stms =
  case stmsHead all_stms of
    Nothing -> pure ()
    Just (stm, rest) -> do
      (subst, copy_stms) <- copyConsumedArrsInStm stm
      let (stm', rest') = substituteNames subst (stm, rest)
      diffStms copy_stms
      diffStm stm' $ diffStms rest'
      mapM_ (\(from, to) -> setAdj from =<< lookupAdj to) $ M.toList subst

-- | Preprocess statements before differentiating them.  Currently the
-- only preprocessing step is stripmining, via 'stripmineStms'.
preprocess :: Stms SOACS -> ADM (Stms SOACS)
preprocess stms = stripmineStms stms

-- | Differentiate a body.
--
-- The last @length res_adjs@ results of the body are seeded with the
-- corresponding adjoints in @res_adjs@ (constant results are ignored).
-- The statements are preprocessed and differentiated, and the returned
-- body yields the original results followed by the adjoints of the
-- variables in @get_adjs_for@.  Runs under 'subAD' and 'subSubsts'.
diffBody :: [Adj] -> [VName] -> Body SOACS -> ADM (Body SOACS)
diffBody res_adjs get_adjs_for (Body () stms res) =
  subAD . subSubsts $ do
    (adjs, stms') <- collectStms $ do
      zipWithM_ seedResult (takeLast (length res_adjs) res) res_adjs
      diffStms =<< preprocess stms
      mapM lookupAdjVal get_adjs_for
    pure $ Body () stms' $ res <> varsRes adjs
  where
    seedResult :: SubExpRes -> Adj -> ADM ()
    seedResult (SubExpRes _ (Var v)) v_adj = void $ updateAdj v =<< adjVal v_adj
    seedResult (SubExpRes _ (Constant _)) _ = pure ()

-- | Differentiate a lambda.
--
-- The resulting lambda keeps the original parameters, but returns only
-- the adjoints of @get_adjs_for@ (the primal results are dropped).  Its
-- return types are the types of the @get_adjs_for@ variables.
diffLambda :: [Adj] -> [VName] -> Lambda SOACS -> ADM (Lambda SOACS)
diffLambda res_adjs get_adjs_for (Lambda params body _) =
  localScope (scopeOfLParams params) $ do
    Body () stms res <- diffBody res_adjs get_adjs_for body
    ret_ts <- traverse lookupType get_adjs_for
    let adj_res = takeLast (length get_adjs_for) res
    pure $ Lambda params (Body () stms adj_res) ret_ts

-- | Construct the reverse-mode vector-Jacobian product of a lambda.
--
-- The resulting lambda takes the original parameters followed by one
-- adjoint parameter per original result.  It returns the original
-- results followed by the adjoints of the original parameters.
revVJP :: (MonadFreshNames m) => Scope SOACS -> Lambda SOACS -> m (Lambda SOACS)
revVJP scope (Lambda params body ts) =
  runADM . localScope (scope <> scopeOfLParams params) $ do
    params_adj <- zipWithM mkResultAdjParam (map resSubExp $ bodyResult body) ts
    body' <-
      localScope (scopeOfLParams params_adj) $
        diffBody (map adjFromParam params_adj) (map paramName params) body
    pure $ Lambda (params ++ params_adj) body' (ts <> map paramType params)
  where
    -- The adjoint parameter of a result is named via 'adjVName' when the
    -- result is a variable, and gets a fresh "const_adj" name otherwise.
    mkResultAdjParam se t = do
      v <- maybe (newVName "const_adj") adjVName (subExpVar se)
      pure $ Param mempty v t

-- Note [Adjoints of accumulators]
--
-- The general case of taking adjoints of WithAcc is tricky.  We make
-- some assumptions and lay down a basic design.
--
-- First, we assume that any WithAccs that occur in the program are
-- the result of previous invocations of VJP.  This means we can rely
-- on the operator having a constant adjoint (it's some kind of
-- addition).
--
-- Second, the adjoint of an accumulator is an array of the same type
-- as the underlying array.  For example, the adjoint type of the
-- primal type 'acc(c, [n], {f64})' is '[n]f64'.  In principle the
-- adjoint of 'acc(c, [n], {f64,f32})' should be two arrays, of types
-- '[n]f64' and '[n]f32'.  Our current design assumes that adjoints are
-- single variables.  This is fixable.
--
-- # Adjoint of UpdateAcc
--
--   Consider primal code
--
--     update_acc(acc, i, v)
--
--   Interpreted as an imperative statement, this means
--
--     acc[i] ⊕= v
--
--   for some '⊕'.  Normally all the compiler knows of '⊕' is that it
--   is associative and commutative, but because we assume that all
--   accumulators are the result of previous AD transformations, we
--   can assume that '⊕' actually behaves like addition - that is, has
--   unit partial derivatives.  So the return sweep is
--
--     v += acc_adj[i]
--
-- # Adjoint of Map
--
-- Suppose we have primal code
--
--   let acc' =
--     map (...) acc
--
-- where "acc : acc(c, [n], {f64})" and the width of the Map is "w".
-- Our normal transformation for Map input arrays is to similarly map
-- their adjoint, but clearly this doesn't work here because the
-- semantics of mapping an adjoint is an "implicit replicate".  So
-- when generating the return sweep we actually perform that
-- replication:
--
--   map (...) (replicate w acc_adj)
--
-- But what about the contributions to "acc'"?  Those we also have to
-- take special care of.  The result of the map itself is actually a
-- multidimensional array:
--
--   let acc_contribs =
--     map (...) (replicate w acc'_adj)
--
-- which we must then sum and add to the adjoint of 'acc':
--
--   acc_adj += sum(acc_contribs)
--
-- I'm slightly worried about the asymptotics of this, since my
-- intuition of this is that the contributions might be rather sparse.
-- (Maybe completely zero?  If so it will be simplified away
-- entirely.)  Perhaps a better solution is to treat
-- accumulator-inputs in the primal code as we do free variables, and
-- create accumulators for them in the return sweep.
--
-- # Consumption
--
-- A minor problem is that our usual way of handling consumption (Note
-- [Consumption]) is not viable, because accumulators are not
-- copyable.  Fortunately, while the accumulators that are consumed in
-- the forward sweep will also be present in the return sweep given
-- our current translation rules, they will be dead code.  As long as
-- we are careful to run dead code elimination after revVJP, we should
-- be good.