{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE TypeFamilies #-}

-- Naming scheme:
--
-- An adjoint-related object for "x" is named "x_adj".  This means
-- both actual adjoints and statements.
--
-- Do not assume "x'" means anything related to derivatives.
module Futhark.AD.Rev (revVJP) where

import Control.Monad
import Data.List ((\\))
import Data.List.NonEmpty (NonEmpty (..))
import qualified Data.Map as M
import Futhark.AD.Derivatives
import Futhark.AD.Rev.Loop
import Futhark.AD.Rev.Monad
import Futhark.AD.Rev.SOAC
import Futhark.Analysis.PrimExp.Convert
import Futhark.Builder
import Futhark.IR.SOACS
import Futhark.Tools
import Futhark.Transform.Rename
import Futhark.Transform.Substitute
import Futhark.Util (takeLast)

-- | Extract the name bound by a pattern that has exactly one element.
--
-- Any pattern that does not consist of exactly one element is treated
-- as a programming error: 'error' is called with a message containing
-- the pretty-printed pattern.
patName :: Pat Type -> ADM VName
patName (Pat [pe]) = pure $ patElemName pe
patName pat = error $ "Expected single-element pattern: " ++ pretty pat

-- | The vast majority of BasicOps require no special treatment in the
-- forward pass and produce one value (and hence one adjoint).  We
-- deal with that case here.
--
-- The steps, in order:
--
--   1. Re-emit the original statement @let pat = op@ (with its @aux@)
--      unchanged.
--   2. Run the supplied action @m@.
--   3. Look up the adjoint of the single name bound by @pat@.
--
-- Returns the pattern's name paired with the name of its adjoint.
-- Because it goes through 'patName', this calls 'error' on any
-- pattern that does not bind exactly one name.
commonBasicOp :: Pat Type -> StmAux () -> BasicOp -> ADM () -> ADM (VName, VName)
commonBasicOp pat aux op m = do
  addStm $ Let pat aux $ BasicOp op
  -- NOTE(review): @m@ presumably processes the statements that follow
  -- this one, so the adjoint lookup below sees their contributions --
  -- confirm against the caller.
  m
  pat_v <- patName pat
  pat_adj <- lookupAdjVal pat_v
  pure (pat_v, pat_adj)

diffBasicOp :: Pat Type -> StmAux () -> BasicOp -> ADM () -> ADM ()
diffBasicOp :: Pat Type -> StmAux () -> BasicOp -> ADM () -> ADM ()
diffBasicOp Pat Type
pat StmAux ()
aux BasicOp
e ADM ()
m =
  case BasicOp
e of
    CmpOp CmpOp
cmp SubExp
x SubExp
y -> do
      (VName
_pat_v, VName
pat_adj) <- Pat Type -> StmAux () -> BasicOp -> ADM () -> ADM (VName, VName)
commonBasicOp Pat Type
pat StmAux ()
aux BasicOp
e ADM ()
m
      ADM () -> ADM ()
forall a. ADM a -> ADM a
returnSweepCode (ADM () -> ADM ()) -> ADM () -> ADM ()
forall a b. (a -> b) -> a -> b
$ do
        let t :: PrimType
t = CmpOp -> PrimType
cmpOpType CmpOp
cmp
            update :: VName -> ADM ()
update VName
contrib = do
              ADM () -> ADM ()
forall (f :: * -> *) a. Functor f => f a -> f ()
void (ADM () -> ADM ()) -> ADM () -> ADM ()
forall a b. (a -> b) -> a -> b
$ SubExp -> VName -> ADM ()
updateSubExpAdj SubExp
x VName
contrib
              ADM () -> ADM ()
forall (f :: * -> *) a. Functor f => f a -> f ()
void (ADM () -> ADM ()) -> ADM () -> ADM ()
forall a b. (a -> b) -> a -> b
$ SubExp -> VName -> ADM ()
updateSubExpAdj SubExp
y VName
contrib

        case PrimType
t of
          FloatType FloatType
ft ->
            VName -> ADM ()
update (VName -> ADM ())
-> (Exp SOACS -> ADM VName) -> Exp SOACS -> ADM ()
forall (m :: * -> *) b c a.
Monad m =>
(b -> m c) -> (a -> m b) -> a -> m c
<=< [Char] -> Exp (Rep ADM) -> ADM VName
forall (m :: * -> *).
MonadBuilder m =>
[Char] -> Exp (Rep m) -> m VName
letExp [Char]
"contrib" (Exp SOACS -> ADM ()) -> Exp SOACS -> ADM ()
forall a b. (a -> b) -> a -> b
$
              SubExp
-> Body SOACS
-> Body SOACS
-> IfDec (BranchType SOACS)
-> Exp SOACS
forall rep.
SubExp -> Body rep -> Body rep -> IfDec (BranchType rep) -> Exp rep
If
                (VName -> SubExp
Var VName
pat_adj)
                ([SubExp] -> Body SOACS
forall rep. Buildable rep => [SubExp] -> Body rep
resultBody [FloatValue -> SubExp
forall v. IsValue v => v -> SubExp
constant (FloatType -> Int -> FloatValue
forall num. Real num => FloatType -> num -> FloatValue
floatValue FloatType
ft (Int
1 :: Int))])
                ([SubExp] -> Body SOACS
forall rep. Buildable rep => [SubExp] -> Body rep
resultBody [FloatValue -> SubExp
forall v. IsValue v => v -> SubExp
constant (FloatType -> Int -> FloatValue
forall num. Real num => FloatType -> num -> FloatValue
floatValue FloatType
ft (Int
0 :: Int))])
                ([TypeBase ExtShape NoUniqueness]
-> IfSort -> IfDec (TypeBase ExtShape NoUniqueness)
forall rt. [rt] -> IfSort -> IfDec rt
IfDec [PrimType -> TypeBase ExtShape NoUniqueness
forall shape u. PrimType -> TypeBase shape u
Prim (FloatType -> PrimType
FloatType FloatType
ft)] IfSort
IfNormal)
          IntType IntType
it ->
            VName -> ADM ()
update (VName -> ADM ())
-> (Exp SOACS -> ADM VName) -> Exp SOACS -> ADM ()
forall (m :: * -> *) b c a.
Monad m =>
(b -> m c) -> (a -> m b) -> a -> m c
<=< [Char] -> Exp (Rep ADM) -> ADM VName
forall (m :: * -> *).
MonadBuilder m =>
[Char] -> Exp (Rep m) -> m VName
letExp [Char]
"contrib" (Exp SOACS -> ADM ()) -> Exp SOACS -> ADM ()
forall a b. (a -> b) -> a -> b
$ BasicOp -> Exp SOACS
forall rep. BasicOp -> Exp rep
BasicOp (BasicOp -> Exp SOACS) -> BasicOp -> Exp SOACS
forall a b. (a -> b) -> a -> b
$ ConvOp -> SubExp -> BasicOp
ConvOp (IntType -> ConvOp
BToI IntType
it) (VName -> SubExp
Var VName
pat_adj)
          PrimType
Bool ->
            VName -> ADM ()
update VName
pat_adj
          PrimType
Unit ->
            () -> ADM ()
forall (f :: * -> *) a. Applicative f => a -> f a
pure ()
    --
    ConvOp ConvOp
op SubExp
x -> do
      (VName
_pat_v, VName
pat_adj) <- Pat Type -> StmAux () -> BasicOp -> ADM () -> ADM (VName, VName)
commonBasicOp Pat Type
pat StmAux ()
aux BasicOp
e ADM ()
m
      ADM () -> ADM ()
forall a. ADM a -> ADM a
returnSweepCode (ADM () -> ADM ()) -> ADM () -> ADM ()
forall a b. (a -> b) -> a -> b
$ do
        VName
contrib <-
          [Char] -> Exp (Rep ADM) -> ADM VName
forall (m :: * -> *).
MonadBuilder m =>
[Char] -> Exp (Rep m) -> m VName
letExp [Char]
"contrib" (Exp (Rep ADM) -> ADM VName) -> Exp (Rep ADM) -> ADM VName
forall a b. (a -> b) -> a -> b
$ BasicOp -> Exp SOACS
forall rep. BasicOp -> Exp rep
BasicOp (BasicOp -> Exp SOACS) -> BasicOp -> Exp SOACS
forall a b. (a -> b) -> a -> b
$ ConvOp -> SubExp -> BasicOp
ConvOp (ConvOp -> ConvOp
flipConvOp ConvOp
op) (SubExp -> BasicOp) -> SubExp -> BasicOp
forall a b. (a -> b) -> a -> b
$ VName -> SubExp
Var VName
pat_adj
        SubExp -> VName -> ADM ()
updateSubExpAdj SubExp
x VName
contrib
    --
    UnOp UnOp
op SubExp
x -> do
      (VName
_pat_v, VName
pat_adj) <- Pat Type -> StmAux () -> BasicOp -> ADM () -> ADM (VName, VName)
commonBasicOp Pat Type
pat StmAux ()
aux BasicOp
e ADM ()
m

      ADM () -> ADM ()
forall a. ADM a -> ADM a
returnSweepCode (ADM () -> ADM ()) -> ADM () -> ADM ()
forall a b. (a -> b) -> a -> b
$ do
        let t :: PrimType
t = UnOp -> PrimType
unOpType UnOp
op
        VName
contrib <- do
          let x_pe :: PrimExp VName
x_pe = PrimType -> SubExp -> PrimExp VName
primExpFromSubExp PrimType
t SubExp
x
              pat_adj' :: PrimExp VName
pat_adj' = PrimType -> SubExp -> PrimExp VName
primExpFromSubExp PrimType
t (VName -> SubExp
Var VName
pat_adj)
              dx :: PrimExp VName
dx = UnOp -> PrimExp VName -> PrimExp VName
pdUnOp UnOp
op PrimExp VName
x_pe
          [Char] -> Exp (Rep ADM) -> ADM VName
forall (m :: * -> *).
MonadBuilder m =>
[Char] -> Exp (Rep m) -> m VName
letExp [Char]
"contrib" (Exp SOACS -> ADM VName)
-> (PrimExp VName -> ADM (Exp SOACS)) -> PrimExp VName -> ADM VName
forall (m :: * -> *) b c a.
Monad m =>
(b -> m c) -> (a -> m b) -> a -> m c
<=< PrimExp VName -> ADM (Exp SOACS)
forall a (m :: * -> *).
(ToExp a, MonadBuilder m) =>
a -> m (Exp (Rep m))
toExp (PrimExp VName -> ADM VName) -> PrimExp VName -> ADM VName
forall a b. (a -> b) -> a -> b
$ PrimExp VName
pat_adj' PrimExp VName -> PrimExp VName -> PrimExp VName
forall v. PrimExp v -> PrimExp v -> PrimExp v
~*~ PrimExp VName
dx

        SubExp -> VName -> ADM ()
updateSubExpAdj SubExp
x VName
contrib
    --
    BinOp BinOp
op SubExp
x SubExp
y -> do
      (VName
_pat_v, VName
pat_adj) <- Pat Type -> StmAux () -> BasicOp -> ADM () -> ADM (VName, VName)
commonBasicOp Pat Type
pat StmAux ()
aux BasicOp
e ADM ()
m

      ADM () -> ADM ()
forall a. ADM a -> ADM a
returnSweepCode (ADM () -> ADM ()) -> ADM () -> ADM ()
forall a b. (a -> b) -> a -> b
$ do
        let t :: PrimType
t = BinOp -> PrimType
binOpType BinOp
op
            (PrimExp VName
wrt_x, PrimExp VName
wrt_y) =
              BinOp
-> PrimExp VName -> PrimExp VName -> (PrimExp VName, PrimExp VName)
pdBinOp BinOp
op (PrimType -> SubExp -> PrimExp VName
primExpFromSubExp PrimType
t SubExp
x) (PrimType -> SubExp -> PrimExp VName
primExpFromSubExp PrimType
t SubExp
y)

            pat_adj' :: PrimExp VName
pat_adj' = PrimType -> SubExp -> PrimExp VName
primExpFromSubExp PrimType
t (SubExp -> PrimExp VName) -> SubExp -> PrimExp VName
forall a b. (a -> b) -> a -> b
$ VName -> SubExp
Var VName
pat_adj

        VName
adj_x <- [Char] -> Exp (Rep ADM) -> ADM VName
forall (m :: * -> *).
MonadBuilder m =>
[Char] -> Exp (Rep m) -> m VName
letExp [Char]
"binop_x_adj" (Exp SOACS -> ADM VName)
-> (PrimExp VName -> ADM (Exp SOACS)) -> PrimExp VName -> ADM VName
forall (m :: * -> *) b c a.
Monad m =>
(b -> m c) -> (a -> m b) -> a -> m c
<=< PrimExp VName -> ADM (Exp SOACS)
forall a (m :: * -> *).
(ToExp a, MonadBuilder m) =>
a -> m (Exp (Rep m))
toExp (PrimExp VName -> ADM VName) -> PrimExp VName -> ADM VName
forall a b. (a -> b) -> a -> b
$ PrimExp VName
pat_adj' PrimExp VName -> PrimExp VName -> PrimExp VName
forall v. PrimExp v -> PrimExp v -> PrimExp v
~*~ PrimExp VName
wrt_x
        VName
adj_y <- [Char] -> Exp (Rep ADM) -> ADM VName
forall (m :: * -> *).
MonadBuilder m =>
[Char] -> Exp (Rep m) -> m VName
letExp [Char]
"binop_y_adj" (Exp SOACS -> ADM VName)
-> (PrimExp VName -> ADM (Exp SOACS)) -> PrimExp VName -> ADM VName
forall (m :: * -> *) b c a.
Monad m =>
(b -> m c) -> (a -> m b) -> a -> m c
<=< PrimExp VName -> ADM (Exp SOACS)
forall a (m :: * -> *).
(ToExp a, MonadBuilder m) =>
a -> m (Exp (Rep m))
toExp (PrimExp VName -> ADM VName) -> PrimExp VName -> ADM VName
forall a b. (a -> b) -> a -> b
$ PrimExp VName
pat_adj' PrimExp VName -> PrimExp VName -> PrimExp VName
forall v. PrimExp v -> PrimExp v -> PrimExp v
~*~ PrimExp VName
wrt_y
        SubExp -> VName -> ADM ()
updateSubExpAdj SubExp
x VName
adj_x
        SubExp -> VName -> ADM ()
updateSubExpAdj SubExp
y VName
adj_y
    --
    SubExp SubExp
se -> do
      (VName
_pat_v, VName
pat_adj) <- Pat Type -> StmAux () -> BasicOp -> ADM () -> ADM (VName, VName)
commonBasicOp Pat Type
pat StmAux ()
aux BasicOp
e ADM ()
m
      ADM () -> ADM ()
forall a. ADM a -> ADM a
returnSweepCode (ADM () -> ADM ()) -> ADM () -> ADM ()
forall a b. (a -> b) -> a -> b
$ SubExp -> VName -> ADM ()
updateSubExpAdj SubExp
se VName
pat_adj
    --
    Assert {} ->
      ADM (VName, VName) -> ADM ()
forall (f :: * -> *) a. Functor f => f a -> f ()
void (ADM (VName, VName) -> ADM ()) -> ADM (VName, VName) -> ADM ()
forall a b. (a -> b) -> a -> b
$ Pat Type -> StmAux () -> BasicOp -> ADM () -> ADM (VName, VName)
commonBasicOp Pat Type
pat StmAux ()
aux BasicOp
e ADM ()
m
    --
    ArrayLit [SubExp]
elems Type
_ -> do
      (VName
_pat_v, VName
pat_adj) <- Pat Type -> StmAux () -> BasicOp -> ADM () -> ADM (VName, VName)
commonBasicOp Pat Type
pat StmAux ()
aux BasicOp
e ADM ()
m
      Type
t <- VName -> ADM Type
forall rep (m :: * -> *). HasScope rep m => VName -> m Type
lookupType VName
pat_adj
      ADM () -> ADM ()
forall a. ADM a -> ADM a
returnSweepCode (ADM () -> ADM ()) -> ADM () -> ADM ()
forall a b. (a -> b) -> a -> b
$ do
        [(Int64, SubExp)] -> ((Int64, SubExp) -> ADM ()) -> ADM ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([Int64] -> [SubExp] -> [(Int64, SubExp)]
forall a b. [a] -> [b] -> [(a, b)]
zip [(Int64
0 :: Int64) ..] [SubExp]
elems) (((Int64, SubExp) -> ADM ()) -> ADM ())
-> ((Int64, SubExp) -> ADM ()) -> ADM ()
forall a b. (a -> b) -> a -> b
$ \(Int64
i, SubExp
se) -> do
          let slice :: Slice SubExp
slice = Type -> [DimIndex SubExp] -> Slice SubExp
fullSlice Type
t [SubExp -> DimIndex SubExp
forall d. d -> DimIndex d
DimFix (Int64 -> SubExp
forall v. IsValue v => v -> SubExp
constant Int64
i)]
          SubExp -> VName -> ADM ()
updateSubExpAdj SubExp
se (VName -> ADM ())
-> (Exp SOACS -> ADM VName) -> Exp SOACS -> ADM ()
forall (m :: * -> *) b c a.
Monad m =>
(b -> m c) -> (a -> m b) -> a -> m c
<=< [Char] -> Exp (Rep ADM) -> ADM VName
forall (m :: * -> *).
MonadBuilder m =>
[Char] -> Exp (Rep m) -> m VName
letExp [Char]
"elem_adj" (Exp SOACS -> ADM ()) -> Exp SOACS -> ADM ()
forall a b. (a -> b) -> a -> b
$ BasicOp -> Exp SOACS
forall rep. BasicOp -> Exp rep
BasicOp (BasicOp -> Exp SOACS) -> BasicOp -> Exp SOACS
forall a b. (a -> b) -> a -> b
$ VName -> Slice SubExp -> BasicOp
Index VName
pat_adj Slice SubExp
slice
    --
    Index VName
arr Slice SubExp
slice -> do
      (VName
_pat_v, VName
pat_adj) <- Pat Type -> StmAux () -> BasicOp -> ADM () -> ADM (VName, VName)
commonBasicOp Pat Type
pat StmAux ()
aux BasicOp
e ADM ()
m
      ADM () -> ADM ()
forall a. ADM a -> ADM a
returnSweepCode (ADM () -> ADM ()) -> ADM () -> ADM ()
forall a b. (a -> b) -> a -> b
$ do
        ADM () -> ADM ()
forall (f :: * -> *) a. Functor f => f a -> f ()
void (ADM () -> ADM ()) -> ADM () -> ADM ()
forall a b. (a -> b) -> a -> b
$ Slice SubExp -> VName -> VName -> ADM ()
updateAdjSlice Slice SubExp
slice VName
arr VName
pat_adj
    FlatIndex {} -> [Char] -> ADM ()
forall a. HasCallStack => [Char] -> a
error [Char]
"FlatIndex not handled by AD yet."
    FlatUpdate {} -> [Char] -> ADM ()
forall a. HasCallStack => [Char] -> a
error [Char]
"FlatUpdate not handled by AD yet."
    --
    Opaque OpaqueOp
_ SubExp
se -> do
      (VName
_pat_v, VName
pat_adj) <- Pat Type -> StmAux () -> BasicOp -> ADM () -> ADM (VName, VName)
commonBasicOp Pat Type
pat StmAux ()
aux BasicOp
e ADM ()
m
      ADM () -> ADM ()
forall a. ADM a -> ADM a
returnSweepCode (ADM () -> ADM ()) -> ADM () -> ADM ()
forall a b. (a -> b) -> a -> b
$ SubExp -> VName -> ADM ()
updateSubExpAdj SubExp
se VName
pat_adj
    --
    Reshape ShapeChange SubExp
_ VName
arr -> do
      (VName
_pat_v, VName
pat_adj) <- Pat Type -> StmAux () -> BasicOp -> ADM () -> ADM (VName, VName)
commonBasicOp Pat Type
pat StmAux ()
aux BasicOp
e ADM ()
m
      ADM () -> ADM ()
forall a. ADM a -> ADM a
returnSweepCode (ADM () -> ADM ()) -> ADM () -> ADM ()
forall a b. (a -> b) -> a -> b
$ do
        [SubExp]
arr_dims <- Type -> [SubExp]
forall u. TypeBase Shape u -> [SubExp]
arrayDims (Type -> [SubExp]) -> ADM Type -> ADM [SubExp]
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> VName -> ADM Type
forall rep (m :: * -> *). HasScope rep m => VName -> m Type
lookupType VName
arr
        ADM () -> ADM ()
forall (f :: * -> *) a. Functor f => f a -> f ()
void (ADM () -> ADM ()) -> ADM () -> ADM ()
forall a b. (a -> b) -> a -> b
$
          VName -> VName -> ADM ()
updateAdj VName
arr (VName -> ADM ())
-> (Exp SOACS -> ADM VName) -> Exp SOACS -> ADM ()
forall (m :: * -> *) b c a.
Monad m =>
(b -> m c) -> (a -> m b) -> a -> m c
<=< [Char] -> Exp (Rep ADM) -> ADM VName
forall (m :: * -> *).
MonadBuilder m =>
[Char] -> Exp (Rep m) -> m VName
letExp [Char]
"adj_reshape" (Exp SOACS -> ADM ()) -> Exp SOACS -> ADM ()
forall a b. (a -> b) -> a -> b
$
            BasicOp -> Exp SOACS
forall rep. BasicOp -> Exp rep
BasicOp (BasicOp -> Exp SOACS) -> BasicOp -> Exp SOACS
forall a b. (a -> b) -> a -> b
$ ShapeChange SubExp -> VName -> BasicOp
Reshape ((SubExp -> DimChange SubExp) -> [SubExp] -> ShapeChange SubExp
forall a b. (a -> b) -> [a] -> [b]
map SubExp -> DimChange SubExp
forall d. d -> DimChange d
DimNew [SubExp]
arr_dims) VName
pat_adj
    --
    Rearrange [Int]
perm VName
arr -> do
      (VName
_pat_v, VName
pat_adj) <- Pat Type -> StmAux () -> BasicOp -> ADM () -> ADM (VName, VName)
commonBasicOp Pat Type
pat StmAux ()
aux BasicOp
e ADM ()
m
      ADM () -> ADM ()
forall a. ADM a -> ADM a
returnSweepCode (ADM () -> ADM ()) -> ADM () -> ADM ()
forall a b. (a -> b) -> a -> b
$
        ADM () -> ADM ()
forall (f :: * -> *) a. Functor f => f a -> f ()
void (ADM () -> ADM ()) -> ADM () -> ADM ()
forall a b. (a -> b) -> a -> b
$
          VName -> VName -> ADM ()
updateAdj VName
arr (VName -> ADM ())
-> (Exp SOACS -> ADM VName) -> Exp SOACS -> ADM ()
forall (m :: * -> *) b c a.
Monad m =>
(b -> m c) -> (a -> m b) -> a -> m c
<=< [Char] -> Exp (Rep ADM) -> ADM VName
forall (m :: * -> *).
MonadBuilder m =>
[Char] -> Exp (Rep m) -> m VName
letExp [Char]
"adj_rearrange" (Exp SOACS -> ADM ()) -> Exp SOACS -> ADM ()
forall a b. (a -> b) -> a -> b
$
            BasicOp -> Exp SOACS
forall rep. BasicOp -> Exp rep
BasicOp (BasicOp -> Exp SOACS) -> BasicOp -> Exp SOACS
forall a b. (a -> b) -> a -> b
$ [Int] -> VName -> BasicOp
Rearrange ([Int] -> [Int]
rearrangeInverse [Int]
perm) VName
pat_adj
    --
    Rotate [SubExp]
rots VName
arr -> do
      (VName
_pat_v, VName
pat_adj) <- Pat Type -> StmAux () -> BasicOp -> ADM () -> ADM (VName, VName)
commonBasicOp Pat Type
pat StmAux ()
aux BasicOp
e ADM ()
m
      ADM () -> ADM ()
forall a. ADM a -> ADM a
returnSweepCode (ADM () -> ADM ()) -> ADM () -> ADM ()
forall a b. (a -> b) -> a -> b
$ do
        let neg :: SubExp -> Exp rep
neg = BasicOp -> Exp rep
forall rep. BasicOp -> Exp rep
BasicOp (BasicOp -> Exp rep) -> (SubExp -> BasicOp) -> SubExp -> Exp rep
forall b c a. (b -> c) -> (a -> b) -> a -> c
. BinOp -> SubExp -> SubExp -> BasicOp
BinOp (IntType -> Overflow -> BinOp
Sub IntType
Int64 Overflow
OverflowWrap) (IntType -> Integer -> SubExp
intConst IntType
Int64 Integer
0)
        [SubExp]
rots' <- (SubExp -> ADM SubExp) -> [SubExp] -> ADM [SubExp]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM ([Char] -> Exp (Rep ADM) -> ADM SubExp
forall (m :: * -> *).
MonadBuilder m =>
[Char] -> Exp (Rep m) -> m SubExp
letSubExp [Char]
"rot_neg" (Exp SOACS -> ADM SubExp)
-> (SubExp -> Exp SOACS) -> SubExp -> ADM SubExp
forall b c a. (b -> c) -> (a -> b) -> a -> c
. SubExp -> Exp SOACS
forall rep. SubExp -> Exp rep
neg) [SubExp]
rots
        ADM () -> ADM ()
forall (f :: * -> *) a. Functor f => f a -> f ()
void (ADM () -> ADM ()) -> ADM () -> ADM ()
forall a b. (a -> b) -> a -> b
$
          VName -> VName -> ADM ()
updateAdj VName
arr (VName -> ADM ())
-> (Exp SOACS -> ADM VName) -> Exp SOACS -> ADM ()
forall (m :: * -> *) b c a.
Monad m =>
(b -> m c) -> (a -> m b) -> a -> m c
<=< [Char] -> Exp (Rep ADM) -> ADM VName
forall (m :: * -> *).
MonadBuilder m =>
[Char] -> Exp (Rep m) -> m VName
letExp [Char]
"adj_rotate" (Exp SOACS -> ADM ()) -> Exp SOACS -> ADM ()
forall a b. (a -> b) -> a -> b
$
            BasicOp -> Exp SOACS
forall rep. BasicOp -> Exp rep
BasicOp (BasicOp -> Exp SOACS) -> BasicOp -> Exp SOACS
forall a b. (a -> b) -> a -> b
$ [SubExp] -> VName -> BasicOp
Rotate [SubExp]
rots' VName
pat_adj
    --
    Replicate (Shape [SubExp]
ns) SubExp
x -> do
      (VName
_pat_v, VName
pat_adj) <- Pat Type -> StmAux () -> BasicOp -> ADM () -> ADM (VName, VName)
commonBasicOp Pat Type
pat StmAux ()
aux BasicOp
e ADM ()
m
      ADM () -> ADM ()
forall a. ADM a -> ADM a
returnSweepCode (ADM () -> ADM ()) -> ADM () -> ADM ()
forall a b. (a -> b) -> a -> b
$ do
        Type
x_t <- SubExp -> ADM Type
forall t (m :: * -> *). HasScope t m => SubExp -> m Type
subExpType SubExp
x
        Lambda SOACS
lam <- Type -> ADM (Lambda SOACS)
addLambda Type
x_t
        SubExp
ne <- [Char] -> Exp (Rep ADM) -> ADM SubExp
forall (m :: * -> *).
MonadBuilder m =>
[Char] -> Exp (Rep m) -> m SubExp
letSubExp [Char]
"zero" (Exp (Rep ADM) -> ADM SubExp) -> Exp (Rep ADM) -> ADM SubExp
forall a b. (a -> b) -> a -> b
$ Type -> Exp SOACS
forall rep. Type -> Exp rep
zeroExp Type
x_t
        SubExp
n <- [Char] -> Exp (Rep ADM) -> ADM SubExp
forall (m :: * -> *).
MonadBuilder m =>
[Char] -> Exp (Rep m) -> m SubExp
letSubExp [Char]
"rep_size" (Exp SOACS -> ADM SubExp) -> ADM (Exp SOACS) -> ADM SubExp
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< BinOp -> SubExp -> [SubExp] -> ADM (Exp (Rep ADM))
forall (m :: * -> *).
MonadBuilder m =>
BinOp -> SubExp -> [SubExp] -> m (Exp (Rep m))
foldBinOp (IntType -> Overflow -> BinOp
Mul IntType
Int64 Overflow
OverflowUndef) (IntType -> Integer -> SubExp
intConst IntType
Int64 Integer
1) [SubExp]
ns
        VName
pat_adj_flat <-
          [Char] -> Exp (Rep ADM) -> ADM VName
forall (m :: * -> *).
MonadBuilder m =>
[Char] -> Exp (Rep m) -> m VName
letExp (VName -> [Char]
baseString VName
pat_adj [Char] -> [Char] -> [Char]
forall a. Semigroup a => a -> a -> a
<> [Char]
"_flat") (Exp (Rep ADM) -> ADM VName) -> Exp (Rep ADM) -> ADM VName
forall a b. (a -> b) -> a -> b
$
            BasicOp -> Exp SOACS
forall rep. BasicOp -> Exp rep
BasicOp (BasicOp -> Exp SOACS) -> BasicOp -> Exp SOACS
forall a b. (a -> b) -> a -> b
$ ShapeChange SubExp -> VName -> BasicOp
Reshape ((SubExp -> DimChange SubExp) -> [SubExp] -> ShapeChange SubExp
forall a b. (a -> b) -> [a] -> [b]
map SubExp -> DimChange SubExp
forall d. d -> DimChange d
DimNew ([SubExp] -> ShapeChange SubExp) -> [SubExp] -> ShapeChange SubExp
forall a b. (a -> b) -> a -> b
$ SubExp
n SubExp -> [SubExp] -> [SubExp]
forall a. a -> [a] -> [a]
: Type -> [SubExp]
forall u. TypeBase Shape u -> [SubExp]
arrayDims Type
x_t) VName
pat_adj
        ScremaForm SOACS
reduce <- [Reduce SOACS] -> ADM (ScremaForm SOACS)
forall rep (m :: * -> *).
(Buildable rep, MonadFreshNames m) =>
[Reduce rep] -> m (ScremaForm rep)
reduceSOAC [Commutativity -> Lambda SOACS -> [SubExp] -> Reduce SOACS
forall rep. Commutativity -> Lambda rep -> [SubExp] -> Reduce rep
Reduce Commutativity
Commutative Lambda SOACS
lam [SubExp
ne]]
        SubExp -> VName -> ADM ()
updateSubExpAdj SubExp
x
          (VName -> ADM ()) -> ADM VName -> ADM ()
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< [Char] -> Exp (Rep ADM) -> ADM VName
forall (m :: * -> *).
MonadBuilder m =>
[Char] -> Exp (Rep m) -> m VName
letExp [Char]
"rep_contrib" (Op SOACS -> Exp SOACS
forall rep. Op rep -> Exp rep
Op (Op SOACS -> Exp SOACS) -> Op SOACS -> Exp SOACS
forall a b. (a -> b) -> a -> b
$ SubExp -> [VName] -> ScremaForm SOACS -> SOAC SOACS
forall rep. SubExp -> [VName] -> ScremaForm rep -> SOAC rep
Screma SubExp
n [VName
pat_adj_flat] ScremaForm SOACS
reduce)
    --
    Concat Int
d (VName
arr :| [VName]
arrs) SubExp
_ -> do
      (VName
_pat_v, VName
pat_adj) <- Pat Type -> StmAux () -> BasicOp -> ADM () -> ADM (VName, VName)
commonBasicOp Pat Type
pat StmAux ()
aux BasicOp
e ADM ()
m
      ADM () -> ADM ()
forall a. ADM a -> ADM a
returnSweepCode (ADM () -> ADM ()) -> ADM () -> ADM ()
forall a b. (a -> b) -> a -> b
$ do
        let sliceAdj :: SubExp -> [VName] -> ADM [VName]
sliceAdj SubExp
_ [] = [VName] -> ADM [VName]
forall (f :: * -> *) a. Applicative f => a -> f a
pure []
            sliceAdj SubExp
start (VName
v : [VName]
vs) = do
              Type
v_t <- VName -> ADM Type
forall rep (m :: * -> *). HasScope rep m => VName -> m Type
lookupType VName
v
              let w :: SubExp
w = Int -> Type -> SubExp
forall u. Int -> TypeBase Shape u -> SubExp
arraySize Int
0 Type
v_t
                  slice :: DimIndex SubExp
slice = SubExp -> SubExp -> SubExp -> DimIndex SubExp
forall d. d -> d -> d -> DimIndex d
DimSlice SubExp
start SubExp
w (IntType -> Integer -> SubExp
intConst IntType
Int64 Integer
1)
              VName
pat_adj_slice <-
                [Char] -> Exp (Rep ADM) -> ADM VName
forall (m :: * -> *).
MonadBuilder m =>
[Char] -> Exp (Rep m) -> m VName
letExp (VName -> [Char]
baseString VName
pat_adj [Char] -> [Char] -> [Char]
forall a. Semigroup a => a -> a -> a
<> [Char]
"_slice") (Exp (Rep ADM) -> ADM VName) -> Exp (Rep ADM) -> ADM VName
forall a b. (a -> b) -> a -> b
$
                  BasicOp -> Exp SOACS
forall rep. BasicOp -> Exp rep
BasicOp (BasicOp -> Exp SOACS) -> BasicOp -> Exp SOACS
forall a b. (a -> b) -> a -> b
$ VName -> Slice SubExp -> BasicOp
Index VName
pat_adj (Type -> Int -> [DimIndex SubExp] -> Slice SubExp
sliceAt Type
v_t Int
d [DimIndex SubExp
slice])
              SubExp
start' <- [Char] -> Exp (Rep ADM) -> ADM SubExp
forall (m :: * -> *).
MonadBuilder m =>
[Char] -> Exp (Rep m) -> m SubExp
letSubExp [Char]
"start" (Exp (Rep ADM) -> ADM SubExp) -> Exp (Rep ADM) -> ADM SubExp
forall a b. (a -> b) -> a -> b
$ BasicOp -> Exp SOACS
forall rep. BasicOp -> Exp rep
BasicOp (BasicOp -> Exp SOACS) -> BasicOp -> Exp SOACS
forall a b. (a -> b) -> a -> b
$ BinOp -> SubExp -> SubExp -> BasicOp
BinOp (IntType -> Overflow -> BinOp
Add IntType
Int64 Overflow
OverflowUndef) SubExp
start SubExp
w
              [VName]
slices <- SubExp -> [VName] -> ADM [VName]
sliceAdj SubExp
start' [VName]
vs
              [VName] -> ADM [VName]
forall (f :: * -> *) a. Applicative f => a -> f a
pure ([VName] -> ADM [VName]) -> [VName] -> ADM [VName]
forall a b. (a -> b) -> a -> b
$ VName
pat_adj_slice VName -> [VName] -> [VName]
forall a. a -> [a] -> [a]
: [VName]
slices

        [VName]
slices <- SubExp -> [VName] -> ADM [VName]
sliceAdj (IntType -> Integer -> SubExp
intConst IntType
Int64 Integer
0) ([VName] -> ADM [VName]) -> [VName] -> ADM [VName]
forall a b. (a -> b) -> a -> b
$ VName
arr VName -> [VName] -> [VName]
forall a. a -> [a] -> [a]
: [VName]
arrs

        (VName -> VName -> ADM ()) -> [VName] -> [VName] -> ADM ()
forall (m :: * -> *) a b c.
Applicative m =>
(a -> b -> m c) -> [a] -> [b] -> m ()
zipWithM_ VName -> VName -> ADM ()
updateAdj (VName
arr VName -> [VName] -> [VName]
forall a. a -> [a] -> [a]
: [VName]
arrs) [VName]
slices
    --
    Copy VName
se -> do
      (VName
_pat_v, VName
pat_adj) <- Pat Type -> StmAux () -> BasicOp -> ADM () -> ADM (VName, VName)
commonBasicOp Pat Type
pat StmAux ()
aux BasicOp
e ADM ()
m
      ADM () -> ADM ()
forall a. ADM a -> ADM a
returnSweepCode (ADM () -> ADM ()) -> ADM () -> ADM ()
forall a b. (a -> b) -> a -> b
$ ADM () -> ADM ()
forall (f :: * -> *) a. Functor f => f a -> f ()
void (ADM () -> ADM ()) -> ADM () -> ADM ()
forall a b. (a -> b) -> a -> b
$ VName -> VName -> ADM ()
updateAdj VName
se VName
pat_adj
    --
    Manifest [Int]
_ VName
se -> do
      (VName
_pat_v, VName
pat_adj) <- Pat Type -> StmAux () -> BasicOp -> ADM () -> ADM (VName, VName)
commonBasicOp Pat Type
pat StmAux ()
aux BasicOp
e ADM ()
m
      ADM () -> ADM ()
forall a. ADM a -> ADM a
returnSweepCode (ADM () -> ADM ()) -> ADM () -> ADM ()
forall a b. (a -> b) -> a -> b
$ ADM () -> ADM ()
forall (f :: * -> *) a. Functor f => f a -> f ()
void (ADM () -> ADM ()) -> ADM () -> ADM ()
forall a b. (a -> b) -> a -> b
$ VName -> VName -> ADM ()
updateAdj VName
se VName
pat_adj
    --
    Scratch {} ->
      ADM (VName, VName) -> ADM ()
forall (f :: * -> *) a. Functor f => f a -> f ()
void (ADM (VName, VName) -> ADM ()) -> ADM (VName, VName) -> ADM ()
forall a b. (a -> b) -> a -> b
$ Pat Type -> StmAux () -> BasicOp -> ADM () -> ADM (VName, VName)
commonBasicOp Pat Type
pat StmAux ()
aux BasicOp
e ADM ()
m
    --
    Iota SubExp
n SubExp
_ SubExp
_ IntType
t -> do
      (VName
_pat_v, VName
pat_adj) <- Pat Type -> StmAux () -> BasicOp -> ADM () -> ADM (VName, VName)
commonBasicOp Pat Type
pat StmAux ()
aux BasicOp
e ADM ()
m
      ADM () -> ADM ()
forall a. ADM a -> ADM a
returnSweepCode (ADM () -> ADM ()) -> ADM () -> ADM ()
forall a b. (a -> b) -> a -> b
$ do
        SubExp
ne <- [Char] -> Exp (Rep ADM) -> ADM SubExp
forall (m :: * -> *).
MonadBuilder m =>
[Char] -> Exp (Rep m) -> m SubExp
letSubExp [Char]
"zero" (Exp (Rep ADM) -> ADM SubExp) -> Exp (Rep ADM) -> ADM SubExp
forall a b. (a -> b) -> a -> b
$ Type -> Exp SOACS
forall rep. Type -> Exp rep
zeroExp (Type -> Exp SOACS) -> Type -> Exp SOACS
forall a b. (a -> b) -> a -> b
$ PrimType -> Type
forall shape u. PrimType -> TypeBase shape u
Prim (PrimType -> Type) -> PrimType -> Type
forall a b. (a -> b) -> a -> b
$ IntType -> PrimType
IntType IntType
t
        Lambda SOACS
lam <- Type -> ADM (Lambda SOACS)
addLambda (Type -> ADM (Lambda SOACS)) -> Type -> ADM (Lambda SOACS)
forall a b. (a -> b) -> a -> b
$ PrimType -> Type
forall shape u. PrimType -> TypeBase shape u
Prim (PrimType -> Type) -> PrimType -> Type
forall a b. (a -> b) -> a -> b
$ IntType -> PrimType
IntType IntType
t
        ScremaForm SOACS
reduce <- [Reduce SOACS] -> ADM (ScremaForm SOACS)
forall rep (m :: * -> *).
(Buildable rep, MonadFreshNames m) =>
[Reduce rep] -> m (ScremaForm rep)
reduceSOAC [Commutativity -> Lambda SOACS -> [SubExp] -> Reduce SOACS
forall rep. Commutativity -> Lambda rep -> [SubExp] -> Reduce rep
Reduce Commutativity
Commutative Lambda SOACS
lam [SubExp
ne]]
        SubExp -> VName -> ADM ()
updateSubExpAdj SubExp
n
          (VName -> ADM ()) -> ADM VName -> ADM ()
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< [Char] -> Exp (Rep ADM) -> ADM VName
forall (m :: * -> *).
MonadBuilder m =>
[Char] -> Exp (Rep m) -> m VName
letExp [Char]
"iota_contrib" (Op SOACS -> Exp SOACS
forall rep. Op rep -> Exp rep
Op (Op SOACS -> Exp SOACS) -> Op SOACS -> Exp SOACS
forall a b. (a -> b) -> a -> b
$ SubExp -> [VName] -> ScremaForm SOACS -> SOAC SOACS
forall rep. SubExp -> [VName] -> ScremaForm rep -> SOAC rep
Screma SubExp
n [VName
pat_adj] ScremaForm SOACS
reduce)
    --
    Update Safety
safety VName
arr Slice SubExp
slice SubExp
v -> do
      (VName
_pat_v, VName
pat_adj) <- Pat Type -> StmAux () -> BasicOp -> ADM () -> ADM (VName, VName)
commonBasicOp Pat Type
pat StmAux ()
aux BasicOp
e ADM ()
m
      ADM () -> ADM ()
forall a. ADM a -> ADM a
returnSweepCode (ADM () -> ADM ()) -> ADM () -> ADM ()
forall a b. (a -> b) -> a -> b
$ do
        VName
v_adj <- [Char] -> Exp (Rep ADM) -> ADM VName
forall (m :: * -> *).
MonadBuilder m =>
[Char] -> Exp (Rep m) -> m VName
letExp [Char]
"update_val_adj" (Exp (Rep ADM) -> ADM VName) -> Exp (Rep ADM) -> ADM VName
forall a b. (a -> b) -> a -> b
$ BasicOp -> Exp SOACS
forall rep. BasicOp -> Exp rep
BasicOp (BasicOp -> Exp SOACS) -> BasicOp -> Exp SOACS
forall a b. (a -> b) -> a -> b
$ VName -> Slice SubExp -> BasicOp
Index VName
pat_adj Slice SubExp
slice
        Type
t <- VName -> ADM Type
forall rep (m :: * -> *). HasScope rep m => VName -> m Type
lookupType VName
v_adj
        VName
v_adj_copy <-
          case Type
t of
            Array {} -> [Char] -> Exp (Rep ADM) -> ADM VName
forall (m :: * -> *).
MonadBuilder m =>
[Char] -> Exp (Rep m) -> m VName
letExp [Char]
"update_val_adj_copy" (Exp (Rep ADM) -> ADM VName) -> Exp (Rep ADM) -> ADM VName
forall a b. (a -> b) -> a -> b
$ BasicOp -> Exp SOACS
forall rep. BasicOp -> Exp rep
BasicOp (BasicOp -> Exp SOACS) -> BasicOp -> Exp SOACS
forall a b. (a -> b) -> a -> b
$ VName -> BasicOp
Copy VName
v_adj
            Type
_ -> VName -> ADM VName
forall (f :: * -> *) a. Applicative f => a -> f a
pure VName
v_adj
        SubExp -> VName -> ADM ()
updateSubExpAdj SubExp
v VName
v_adj_copy
        SubExp
zeroes <- [Char] -> Exp (Rep ADM) -> ADM SubExp
forall (m :: * -> *).
MonadBuilder m =>
[Char] -> Exp (Rep m) -> m SubExp
letSubExp [Char]
"update_zero" (Exp SOACS -> ADM SubExp)
-> (Type -> Exp SOACS) -> Type -> ADM SubExp
forall b c a. (b -> c) -> (a -> b) -> a -> c
. Type -> Exp SOACS
forall rep. Type -> Exp rep
zeroExp (Type -> ADM SubExp) -> ADM Type -> ADM SubExp
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< SubExp -> ADM Type
forall t (m :: * -> *). HasScope t m => SubExp -> m Type
subExpType SubExp
v
        ADM () -> ADM ()
forall (f :: * -> *) a. Functor f => f a -> f ()
void (ADM () -> ADM ()) -> ADM () -> ADM ()
forall a b. (a -> b) -> a -> b
$
          VName -> VName -> ADM ()
updateAdj VName
arr
            (VName -> ADM ()) -> ADM VName -> ADM ()
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< [Char] -> Exp (Rep ADM) -> ADM VName
forall (m :: * -> *).
MonadBuilder m =>
[Char] -> Exp (Rep m) -> m VName
letExp [Char]
"update_src_adj" (BasicOp -> Exp SOACS
forall rep. BasicOp -> Exp rep
BasicOp (BasicOp -> Exp SOACS) -> BasicOp -> Exp SOACS
forall a b. (a -> b) -> a -> b
$ Safety -> VName -> Slice SubExp -> SubExp -> BasicOp
Update Safety
safety VName
pat_adj Slice SubExp
slice SubExp
zeroes)
    -- See Note [Adjoints of accumulators]
    UpdateAcc VName
_ [SubExp]
is [SubExp]
vs -> do
      Stm (Rep ADM) -> ADM ()
forall (m :: * -> *). MonadBuilder m => Stm (Rep m) -> m ()
addStm (Stm (Rep ADM) -> ADM ()) -> Stm (Rep ADM) -> ADM ()
forall a b. (a -> b) -> a -> b
$ Pat (LetDec SOACS)
-> StmAux (ExpDec SOACS) -> Exp SOACS -> Stm SOACS
forall rep.
Pat (LetDec rep) -> StmAux (ExpDec rep) -> Exp rep -> Stm rep
Let Pat Type
Pat (LetDec SOACS)
pat StmAux ()
StmAux (ExpDec SOACS)
aux (Exp SOACS -> Stm SOACS) -> Exp SOACS -> Stm SOACS
forall a b. (a -> b) -> a -> b
$ BasicOp -> Exp SOACS
forall rep. BasicOp -> Exp rep
BasicOp BasicOp
e
      ADM ()
m
      [VName]
pat_adjs <- (VName -> ADM VName) -> [VName] -> ADM [VName]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM VName -> ADM VName
lookupAdjVal (Pat Type -> [VName]
forall dec. Pat dec -> [VName]
patNames Pat Type
pat)
      ADM () -> ADM ()
forall a. ADM a -> ADM a
returnSweepCode (ADM () -> ADM ()) -> ADM () -> ADM ()
forall a b. (a -> b) -> a -> b
$ do
        [(VName, SubExp)] -> ((VName, SubExp) -> ADM ()) -> ADM ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([VName] -> [SubExp] -> [(VName, SubExp)]
forall a b. [a] -> [b] -> [(a, b)]
zip [VName]
pat_adjs [SubExp]
vs) (((VName, SubExp) -> ADM ()) -> ADM ())
-> ((VName, SubExp) -> ADM ()) -> ADM ()
forall a b. (a -> b) -> a -> b
$ \(VName
adj, SubExp
v) -> do
          VName
adj_i <- [Char] -> Exp (Rep ADM) -> ADM VName
forall (m :: * -> *).
MonadBuilder m =>
[Char] -> Exp (Rep m) -> m VName
letExp [Char]
"updateacc_val_adj" (Exp (Rep ADM) -> ADM VName) -> Exp (Rep ADM) -> ADM VName
forall a b. (a -> b) -> a -> b
$ BasicOp -> Exp SOACS
forall rep. BasicOp -> Exp rep
BasicOp (BasicOp -> Exp SOACS) -> BasicOp -> Exp SOACS
forall a b. (a -> b) -> a -> b
$ VName -> Slice SubExp -> BasicOp
Index VName
adj (Slice SubExp -> BasicOp) -> Slice SubExp -> BasicOp
forall a b. (a -> b) -> a -> b
$ [DimIndex SubExp] -> Slice SubExp
forall d. [DimIndex d] -> Slice d
Slice ([DimIndex SubExp] -> Slice SubExp)
-> [DimIndex SubExp] -> Slice SubExp
forall a b. (a -> b) -> a -> b
$ (SubExp -> DimIndex SubExp) -> [SubExp] -> [DimIndex SubExp]
forall a b. (a -> b) -> [a] -> [b]
map SubExp -> DimIndex SubExp
forall d. d -> DimIndex d
DimFix [SubExp]
is
          SubExp -> VName -> ADM ()
updateSubExpAdj SubExp
v VName
adj_i

-- | The bundle of differentiation callbacks ('diffLambda' for lambdas
-- and 'diffStm' for single statements) that is passed to 'vjpSOAC'.
vjpOps :: VjpOps
vjpOps = VjpOps diffLambda diffStm

-- | Differentiate one statement in reverse mode.  The second argument
-- is the action that differentiates everything that comes after the
-- statement.
--
-- 'BasicOp', SOAC and loop statements are handed over to
-- 'diffBasicOp', 'vjpSOAC' and 'diffLoop' respectively.  The cases
-- handled directly here first re-emit the statement unchanged, then
-- run the continuation, and only then generate the code that updates
-- adjoints.  Any statement without a rule is reported via 'error'.
diffStm :: Stm SOACS -> ADM () -> ADM ()
diffStm (Let pat aux (BasicOp op)) m =
  diffBasicOp pat aux op m
-- Application of a function listed in 'builtInFunctions'.  If the
-- lookup fails we fall through to the remaining equations.
diffStm stm@(Let pat _ (Apply fname args _ _)) m
  | Just (ret_t, arg_ts) <- M.lookup fname builtInFunctions = do
      addStm stm
      m

      res_adj <- lookupAdjVal =<< patName pat
      let arg_pes = zipWith primExpFromSubExp arg_ts arg_ses
          res_adj_pe = primExpFromSubExp ret_t $ Var res_adj
          -- One contribution per argument: the result adjoint times
          -- the partial derivative, converted from the return type to
          -- the type of that argument.
          mkContrib (deriv, arg_t) =
            letExp "contrib"
              =<< toExp (convert ret_t arg_t (res_adj_pe ~*~ deriv))

      contribs <-
        case pdBuiltin fname arg_pes of
          Just derivs ->
            mapM mkContrib $ zip derivs arg_ts
          Nothing ->
            error $ "No partial derivative defined for builtin function: " ++ pretty fname

      zipWithM_ updateSubExpAdj arg_ses contribs
  where
    arg_ses = map fst args

    -- Conversion of a primitive expression from type @ft@ to type
    -- @tt@; unsupported combinations are reported via 'error'.
    convert ft tt
      | ft == tt = id
    convert (IntType ft) (IntType tt) = ConvOpExp (SExt ft tt)
    convert (FloatType ft) (FloatType tt) = ConvOpExp (FPConv ft tt)
    convert Bool (FloatType tt) = ConvOpExp (BToF tt)
    convert (FloatType ft) Bool = ConvOpExp (FToB ft)
    convert ft tt = error $ "diffStm.convert: " ++ pretty (fname, ft, tt)
-- Branches: build an 'If' over the same condition whose branches are
-- the differentiated bodies, and record the trailing results as the
-- adjoints of the variables free in either branch.
diffStm stm@(Let pat _ (If cond tbody fbody _)) m = do
  addStm stm
  m
  returnSweepCode $ do
    let free_in_branches = namesToList (freeIn tbody <> freeIn fbody)

    pat_adjs <- mapM lookupAdj (patNames pat)

    let diffBranch = diffBody pat_adjs free_in_branches
    if_adj <- eIf (eSubExp cond) (diffBranch tbody) (diffBranch fbody)
    if_adj' <- renameExp if_adj
    if_res <- letTupExp "branch_adj" if_adj'

    zipWithM_ insAdj free_in_branches $
      takeLast (length free_in_branches) if_res
diffStm (Let pat aux (Op soac)) m =
  vjpSOAC vjpOps pat aux soac m
diffStm (Let pat aux loop@DoLoop {}) m =
  diffLoop diffStms pat aux loop m
-- See Note [Adjoints of accumulators]
diffStm stm@(Let pat _ (WithAcc inputs lam)) m = do
  addStm stm
  m
  returnSweepCode $ do
    pat_adjs <- mapM lookupAdj (patNames pat)
    lam_renamed <- renameLambda lam
    -- Active free variables of the lambda, minus those that are
    -- themselves accumulators.
    active_free <- filterM isActive (namesToList (freeIn lam_renamed))
    acc_free <- filterM (fmap isAcc . lookupType) active_free
    let non_acc_free = active_free \\ acc_free
    lam_adj <- diffLambda' pat_adjs non_acc_free lam_renamed
    inputs_renamed <- mapM renameInputLambda inputs
    contribs <-
      letTupExp "with_acc_contrib" (WithAcc inputs_renamed lam_adj)
    zipWithM_ insAdj (arrs <> non_acc_free) contribs
  where
    -- All arrays underlying the accumulator inputs, in order.
    arrs = [arr | (_, as, _) <- inputs, arr <- as]

    -- Rename the combining lambda of an input, if it has one.
    renameOp (f, nes) = do
      f' <- renameLambda f
      pure (f', nes)
    renameInputLambda (shape, as, op) =
      (,,) shape as <$> traverse renameOp op

    -- Like 'diffLambda', except that the first @length inputs@
    -- results (and their types) of the lambda are kept in front of
    -- the adjoints of @get_adjs_for@.
    diffLambda' res_adjs get_adjs_for (Lambda params body ts) =
      localScope (scopeOfLParams params) $ do
        let num_inputs = length inputs
            num_adjs = length get_adjs_for
        Body () stms res <- diffBody res_adjs get_adjs_for body
        adj_ts <- mapM lookupType get_adjs_for
        let res' = take num_inputs res <> takeLast num_adjs res
        pure $ Lambda params (Body () stms res') (take num_inputs ts <> adj_ts)
diffStm stm _ =
  error ("diffStm unhandled:\n" ++ pretty stm)

-- | Differentiate a sequence of statements, front to back.
--
-- For the first statement, 'copyConsumedArrsInStm' produces a name
-- substitution together with extra statements.  The extra statements
-- are differentiated first; then the substituted statement is
-- differentiated with the (substituted) remainder as its
-- continuation.  Finally every substituted name is given the adjoint
-- of the name that replaced it.  An empty sequence does nothing.
diffStms :: Stms SOACS -> ADM ()
diffStms all_stms =
  case stmsHead all_stms of
    Nothing -> pure ()
    Just (stm, rest) -> do
      (subst, copy_stms) <- copyConsumedArrsInStm stm
      let (stm', rest') = substituteNames subst (stm, rest)
      diffStms copy_stms
      diffStm stm' $ diffStms rest'
      forM_ (M.toList subst) $ \(from, to) ->
        lookupAdj to >>= setAdj from

-- | Transformation applied to the statements of a body before they
-- are differentiated.  At present this is only stripmining, done by
-- 'stripmineStms'.
preprocess :: Stms SOACS -> ADM (Stms SOACS)
preprocess stms = stripmineStms stms

-- | Differentiate a body, running inside 'subAD' and 'subSubsts'.
--
-- The adjoints in @res_adjs@ are paired with the last
-- @length res_adjs@ results of the body and used to update the
-- adjoints of those results (constant results are ignored).  The
-- preprocessed statements are then differentiated.  The returned body
-- contains the statements generated in the process, and its result is
-- the original result followed by the adjoints of @get_adjs_for@.
diffBody :: [Adj] -> [VName] -> Body SOACS -> ADM (Body SOACS)
diffBody res_adjs get_adjs_for (Body () stms res) = subAD . subSubsts $ do
  (adjs, stms') <- collectStms $ do
    zipWithM_ seedAdj (takeLast (length res_adjs) res) res_adjs
    diffStms =<< preprocess stms
    mapM lookupAdjVal get_adjs_for
  pure $ Body () stms' (res <> varsRes adjs)
  where
    -- Add the given adjoint to a variable result; constants have no
    -- adjoint to update.
    seedAdj (SubExpRes _ (Constant _)) _ = pure ()
    seedAdj (SubExpRes _ (Var v)) v_adj = void $ updateAdj v =<< adjVal v_adj

-- | Differentiate a lambda via 'diffBody', with the lambda parameters
-- brought into scope.
--
-- The resulting lambda has the same parameters as the original, but
-- returns only the last @length get_adjs_for@ results of the
-- differentiated body, i.e. the adjoints of @get_adjs_for@.  Its
-- return types are the types of the @get_adjs_for@ variables
-- themselves.
diffLambda :: [Adj] -> [VName] -> Lambda SOACS -> ADM (Lambda SOACS)
diffLambda res_adjs get_adjs_for (Lambda params body _) =
  localScope (scopeOfLParams params) $ do
    Body () stms res <- diffBody res_adjs get_adjs_for body
    adj_ts <- mapM lookupType get_adjs_for
    let adj_res = takeLast (length get_adjs_for) res
    pure $ Lambda params (Body () stms adj_res) adj_ts

-- | Construct the reverse-mode (VJP) version of a lambda.
--
-- The produced lambda takes the original parameters followed by one
-- adjoint parameter per result of the original body.  Its body comes
-- from 'diffBody', asked for the adjoints of the original parameters,
-- and its declared return types are the original return types
-- followed by the types of the original parameters.  The given scope,
-- extended with the lambda parameters, is in effect throughout.
revVJP :: MonadFreshNames m => Scope SOACS -> Lambda SOACS -> m (Lambda SOACS)
revVJP scope (Lambda params body ts) =
  runADM . localScope (scope <> scopeOfLParams params) $ do
    params_adj <- zipWithM mkAdjParam (map resSubExp (bodyResult body)) ts

    body' <-
      localScope (scopeOfLParams params_adj) $
        diffBody (map adjFromParam params_adj) (map paramName params) body

    pure $ Lambda (params ++ params_adj) body' (ts <> map paramType params)
  where
    -- Adjoint parameter for one result: named via 'adjVName' when the
    -- result is a variable, and given a fresh name otherwise.
    mkAdjParam se t = do
      v <- maybe (newVName "const_adj") adjVName (subExpVar se)
      pure $ Param mempty v t

-- Note [Adjoints of accumulators]
--
-- The general case of taking adjoints of WithAcc is tricky.  We make
-- some assumptions and lay down a basic design.
--
-- First, we assume that any WithAccs that occur in the program are
-- the result of previous invocations of VJP.  This means we can rely
-- on the operator having a constant adjoint (it's some kind of
-- addition).
--
-- Second, the adjoint of an accumulator is an array of the same type
-- as the underlying array.  For example, the adjoint type of the
-- primal type 'acc(c, [n], {f64})' is '[n]f64'.  In principle the
-- adjoint of 'acc(c, [n], {f64,f32})' should be two arrays of type
-- '[]f64', '[]f32'.  Our current design assumes that adjoints are
-- single variables.  This is fixable.
--
-- # Adjoint of UpdateAcc
--
--   Consider primal code
--
--     update_acc(acc, i, v)
--
--   Interpreted as an imperative statement, this means
--
--     acc[i] ⊕= v
--
--   for some '⊕'.  Normally all the compiler knows of '⊕' is that it
--   is associative and commutative, but because we assume that all
--   accumulators are the result of previous AD transformations, we
--   can assume that '⊕' actually behaves like addition - that is, has
--   unit partial derivatives.  So the return sweep is
--
--     v += acc_adj[i]
--
-- # Adjoint of Map
--
-- Suppose we have primal code
--
--   let acc' =
--     map (...) acc
--
-- where "acc : acc(c, [n], {f64})" and the width of the Map is "w".
-- Our normal transformation for Map input arrays is to similarly map
-- their adjoint, but clearly this doesn't work here because the
-- semantics of mapping an adjoint is an "implicit replicate".  So
-- when generating the return sweep we actually perform that
-- replication:
--
--   map (...) (replicate w acc_adj)
--
-- But what about the contributions to "acc'"?  Those we also have to
-- take special care of.  The result of the map itself is actually a
-- multidimensional array:
--
--   let acc_contribs =
--     map (...) (replicate w acc'_adj)
--
-- which we must then sum to add to the contribution.
--
--   acc_adj += sum(acc_contribs)
--
-- I'm slightly worried about the asymptotics of this, since my
-- intuition is that the contributions might be rather sparse.
-- (Maybe completely zero?  If so it will be simplified away
-- entirely.)  Perhaps a better solution is to treat
-- accumulator-inputs in the primal code as we do free variables, and
-- create accumulators for them in the return sweep.
--
-- # Consumption
--
-- A minor problem is that our usual way of handling consumption (Note
-- [Consumption]) is not viable, because accumulators are not
-- copyable.  Fortunately, while the accumulators that are consumed in
-- the forward sweep will also be present in the return sweep given
-- our current translation rules, they will be dead code.  As long as
-- we are careful to run dead code elimination after revVJP, we should
-- be good.