{-# LANGUAGE TypeFamilies #-}

-- | VJP transformation for Map SOACs.  This is a pretty complicated
-- case due to the possibility of free variables.
module Futhark.AD.Rev.Map (vjpMap) where

import Control.Monad
import Data.Bifunctor (first)
import Futhark.AD.Rev.Monad
import Futhark.Analysis.PrimExp.Convert
import Futhark.Builder
import Futhark.IR.SOACS
import Futhark.Tools
import Futhark.Transform.Rename
import Futhark.Util (splitAt3)

-- | A classification of a free variable based on its adjoint.  The
-- 'VName' stored is *not* the adjoint, but the primal variable.
-- Values of this type are produced by 'classifyAdjVars' and split
-- into groups by 'partitionAdjVars'.
data AdjVar
  = -- | Adjoint is already an accumulator.
    FreeAcc VName
  | -- | Adjoint is currently an array rather than an accumulator, but
    -- should be given an accumulator adjoint.  The shape and element
    -- type stored are those of the adjoint array.
    FreeArr VName Shape PrimType
  | -- | Does not need an accumulator adjoint (might still be an array).
    FreeNonAcc VName

-- | Classify each variable according to the type of its current
-- adjoint value (as given by 'lookupAdjVal'):
--
-- * an array-typed adjoint yields 'FreeArr', recording the shape and
--   element type of that adjoint;
--
-- * an accumulator-typed adjoint yields 'FreeAcc';
--
-- * anything else yields 'FreeNonAcc'.
--
-- The result has one entry per input variable, in the same order.
classifyAdjVars :: [VName] -> ADM [AdjVar]
classifyAdjVars = mapM classify
  where
    classify :: VName -> ADM AdjVar
    classify v = do
      v_adj <- lookupAdjVal v
      v_adj_t <- lookupType v_adj
      case v_adj_t of
        Array pt shape _ ->
          pure $ FreeArr v shape pt
        Acc {} ->
          pure $ FreeAcc v
        _ ->
          pure $ FreeNonAcc v

-- | Split classified variables into three groups, returned as a
-- triple:
--
-- 1. the 'FreeArr' variables, each paired with its shape and element
--    type;
--
-- 2. the 'FreeAcc' variables;
--
-- 3. the 'FreeNonAcc' variables.
--
-- The relative order of the variables within each group is that of
-- the input list.
partitionAdjVars :: [AdjVar] -> ([(VName, (Shape, PrimType))], [VName], [VName])
partitionAdjVars = foldr place ([], [], [])
  where
    -- The accumulated triple is matched lazily, so that the groups are
    -- produced incrementally just as with direct recursion.
    place fv ~(arrs, accs, nonaccs) =
      case fv of
        FreeArr v shape t -> ((v, (shape, t)) : arrs, accs, nonaccs)
        FreeAcc v -> (arrs, v : accs, nonaccs)
        FreeNonAcc v -> (arrs, accs, v : nonaccs)

-- | Like 'buildBody', but the constructed body is passed through
-- 'renameBody' before it is returned.  The auxiliary value produced
-- alongside the result is returned unchanged.
buildRenamedBody ::
  MonadBuilder m =>
  m (Result, a) ->
  m (Body (Rep m), a)
buildRenamedBody m = do
  (body, x) <- buildBody m
  body' <- renameBody body
  pure (body', x)

-- | Construct a 'WithAcc' expression over the given inputs and return
-- the names bound to its results.
--
-- For every input a unit-typed certificate parameter and an
-- accumulator parameter are created.  The accumulator's element types
-- are the types of the input arrays with as many outer dimensions
-- stripped as the rank of the input's shape.  The continuation
-- receives the names of the accumulator parameters and is run (inside
-- 'subAD') to build the body of the 'WithAcc' lambda.
--
-- If there are no inputs, no 'WithAcc' is constructed at all: the
-- continuation is run directly with an empty list of accumulators and
-- each of its results is bound to a fresh name.
withAcc ::
  [(Shape, [VName], Maybe (Lambda SOACS, [SubExp]))] ->
  ([VName] -> ADM Result) ->
  ADM [VName]
withAcc [] m =
  mapM (letExp "withacc_res" . BasicOp . SubExp . resSubExp) =<< m []
withAcc inputs m = do
  (cert_params, acc_params) <- fmap unzip $
    forM inputs $ \(shape, arrs, _) -> do
      cert_param <- newParam "acc_cert_p" $ Prim Unit
      ts <- mapM (fmap (stripArray (shapeRank shape)) . lookupType) arrs
      -- The accumulator type refers to its certificate parameter by name.
      acc_param <- newParam "acc_p" $ Acc (paramName cert_param) shape ts NoUniqueness
      pure (cert_param, acc_param)
  -- All certificate parameters come first, followed by all accumulator
  -- parameters.
  acc_lam <-
    subAD $ mkLambda (cert_params ++ acc_params) $ m $ map paramName acc_params
  -- NOTE(review): "withhacc_res" (double 'h') differs from the
  -- "withacc_res" used in the empty case; presumably it is only a name
  -- hint, so it is kept as-is.
  letTupExp "withhacc_res" $ WithAcc inputs acc_lam

-- | Perform VJP on a Map.  The 'Adj' list is the adjoints of the
-- result of the map.
--
-- The arguments are, in order: the VJP operations (used for
-- differentiating the lambda via 'vjpLambda'), the adjoints of the map
-- results, the auxiliary information of the original statement, the
-- width of the map, the map lambda, and the input arrays.
--
-- There are two cases:
--
-- * If every result adjoint is sparse with shape exactly @[w]@, the
--   return sweep contains no map.  Instead, the lambda is
--   differentiated once for every (index, value) pair of every sparse
--   adjoint.
--
-- * Otherwise, a reverse lambda is constructed and mapped over the
--   inputs, the result adjoints, and the adjoints of those free
--   variables that have been given accumulator adjoints.
vjpMap :: VjpOps -> [Adj] -> StmAux () -> SubExp -> Lambda SOACS -> [VName] -> ADM ()
-- NOTE(review): the 'StmAux' argument is ignored in this clause.
vjpMap ops res_adjs _ w map_lam as
  | Just res_ivs <- mapM isSparse res_adjs = returnSweepCode $ do
      -- Since at most only a constant number of adjoints are nonzero
      -- (length res_ivs), there is no need for the return sweep code to
      -- contain a Map at all.

      free <- filterM isActive $ namesToList $ freeIn map_lam `namesSubtract` namesFromList as
      free_ts <- mapM lookupType free
      let adjs_for :: [VName]
          adjs_for = map paramName (lambdaParams map_lam) ++ free
          adjs_ts :: [Type]
          adjs_ts = map paramType (lambdaParams map_lam) ++ free_ts

      let -- The adjoints of the lambda results when only result number
          -- 'res_i' has a nonzero adjoint ('adj_v'); every other result
          -- gets a zero adjoint of the appropriate shape and element type.
          oneHot :: Int -> Adj -> [Adj]
          oneHot res_i adj_v = zipWith f [0 :: Int ..] $ lambdaReturnType map_lam
            where
              f j t
                | res_i == j = adj_v
                | otherwise = AdjZero (arrayShape t) (elemType t)
          -- Values for the out-of-bounds case do not matter, as we will
          -- be writing to an out-of-bounds index anyway, which is ignored.
          ooBounds :: SubExp -> ADM (Body SOACS, [SubExp] -> [Adj])
          ooBounds adj_i = subAD . buildRenamedBody $ do
            -- 'zip' truncates 'adjs_ts' to the parameter types, which
            -- come first in that list.
            forM_ (zip as adjs_ts) $ \(a, t) -> do
              scratch <- letSubExp "oo_scratch" =<< eBlank t
              updateAdjIndex a (OutOfBounds, adj_i) scratch
            -- We must make sure that all free variables have the same
            -- representation in the oo-branch as in the ib-branch.
            -- In practice we do this by manifesting the adjoint.
            -- This is probably efficient, since the adjoint of a free
            -- variable is probably either a scalar or an accumulator.
            forM_ free $ \v -> insAdj v =<< adjVal =<< lookupAdj v
            first subExpsRes . adjsReps <$> mapM lookupAdj (as <> free)
          -- Differentiate the lambda at index 'adj_i', with 'adj_v' as
          -- the adjoint of result number 'res_i'.
          inBounds :: Int -> SubExp -> SubExp -> ADM (Body SOACS, [SubExp] -> [Adj])
          inBounds res_i adj_i adj_v = subAD . buildRenamedBody $ do
            -- Bind each lambda parameter to element 'adj_i' of the
            -- corresponding input array.
            forM_ (zip (lambdaParams map_lam) as) $ \(p, a) -> do
              a_t <- lookupType a
              letBindNames [paramName p] . BasicOp . Index a $
                fullSlice a_t [DimFix adj_i]
            adj_elems <-
              fmap (map resSubExp) . bodyBind . lambdaBody
                =<< vjpLambda ops (oneHot res_i (AdjVal adj_v)) adjs_for map_lam
            -- The results are ordered as 'adjs_for': first the
            -- parameters, then the free variables.
            let (as_adj_elems, free_adj_elems) = splitAt (length as) adj_elems
            forM_ (zip as as_adj_elems) $ \(a, a_adj_elem) ->
              updateAdjIndex a (AssumeBounds, adj_i) a_adj_elem
            forM_ (zip free free_adj_elems) $ \(v, adj_se) -> do
              adj_se_v <- letExp "adj_v" (BasicOp $ SubExp adj_se)
              insAdj v adj_se_v
            first subExpsRes . adjsReps <$> mapM lookupAdj (as <> free)

          -- Generate an iteration of the map function for every
          -- position.  This is a bit inefficient - probably we could do
          -- some deduplication.
          forPos :: Int -> (InBounds, SubExp, SubExp) -> ADM ()
          forPos res_i (check, adj_i, adj_v) = do
            adjs <-
              case check of
                CheckBounds b -> do
                  (obbranch, mkadjs) <- ooBounds adj_i
                  (ibbranch, _) <- inBounds res_i adj_i adj_v
                  -- If no explicit condition is given, check the
                  -- index against the width of the map.
                  fmap mkadjs . letTupExp' "map_adj_elem"
                    =<< eIf
                      (maybe (eDimInBounds (eSubExp w) (eSubExp adj_i)) eSubExp b)
                      (pure ibbranch)
                      (pure obbranch)
                AssumeBounds -> do
                  (body, mkadjs) <- inBounds res_i adj_i adj_v
                  mkadjs . map resSubExp <$> bodyBind body
                OutOfBounds ->
                  mapM lookupAdj as

            -- In the 'OutOfBounds' case 'adjs' only covers 'as'; the
            -- zip then leaves the adjoints of 'free' untouched.
            zipWithM_ setAdj (as <> free) adjs

          -- Generate an iteration of the map function for every result.
          forRes :: Int -> [(InBounds, SubExp, SubExp)] -> ADM ()
          forRes res_i = mapM_ (forPos res_i)

      zipWithM_ forRes [0 ..] res_ivs
  where
    -- An adjoint qualifies if it is sparse and its shape is exactly
    -- the single dimension 'w'.
    isSparse :: Adj -> Maybe [(InBounds, SubExp, SubExp)]
    isSparse (AdjSparse (Sparse shape _ ivs)) = do
      guard $ shapeDims shape == [w]
      Just ivs
    isSparse _ =
      Nothing
-- See Note [Adjoints of accumulators] for how we deal with
-- accumulators - it's a bit tricky here.
vjpMap ops pat_adj aux w map_lam as = returnSweepCode $ do
  -- An accumulator-typed result has its adjoint replicated to the
  -- width of the map, so that it can be passed as an input array.
  pat_adj_vals <- forM (zip pat_adj (lambdaReturnType map_lam)) $ \(adj, t) ->
    case t of
      Acc {} -> letExp "acc_adj_rep" . BasicOp . Replicate (Shape [w]) . Var =<< adjVal adj
      _ -> adjVal adj
  pat_adj_params <-
    mapM (newParam "map_adj_p" . rowType <=< lookupType) pat_adj_vals

  map_lam' <- renameLambda map_lam
  free <- filterM isActive $ namesToList $ freeIn map_lam'

  accAdjoints free $ \free_with_adjs free_without_adjs -> do
    free_adjs <- mapM lookupAdjVal free_with_adjs
    free_adjs_ts <- mapM lookupType free_adjs
    free_adjs_params <- mapM (newParam "free_adj_p") free_adjs_ts
    let lam_rev_params :: [Param Type]
        lam_rev_params =
          lambdaParams map_lam' ++ pat_adj_params ++ free_adjs_params
        adjs_for :: [VName]
        adjs_for = map paramName (lambdaParams map_lam') ++ free
    lam_rev <-
      mkLambda lam_rev_params . subAD . noAdjsFor free_without_adjs $ do
        -- Inside the reverse lambda, the adjoints of the free
        -- variables are the corresponding lambda parameters.
        zipWithM_ insAdj free_with_adjs $ map paramName free_adjs_params
        bodyBind . lambdaBody
          =<< vjpLambda ops (map adjFromParam pat_adj_params) adjs_for map_lam'

    -- The results are ordered as 'adjs_for': first the contributions
    -- to the parameters, then those to the free variables.
    (param_contribs, free_contribs) <-
      fmap (splitAt (length (lambdaParams map_lam'))) $
        auxing aux . letTupExp "map_adjs" . Op $
          Screma w (as ++ pat_adj_vals ++ free_adjs) (mapSOAC lam_rev)

    -- Crucial that we handle the free contribs first in case 'free'
    -- and 'as' intersect.
    zipWithM_ freeContrib free free_contribs
    let param_ts :: [Type]
        param_ts = map paramType (lambdaParams map_lam')
    forM_ (zip3 param_ts as param_contribs) $ \(param_t, a, param_contrib) ->
      case param_t of
        Acc {} -> freeContrib a param_contrib
        _ -> updateAdj a param_contrib
  where
    -- Prepend 'n' fresh @int64@ parameters to a lambda.
    addIdxParams :: Int -> Lambda SOACS -> ADM (Lambda SOACS)
    addIdxParams n lam = do
      idxs <- replicateM n $ newParam "idx" $ Prim int64
      pure $ lam {lambdaParams = idxs ++ lambdaParams lam}

    -- The lambda from 'addLambda' for type 't', with 'n' index
    -- parameters prepended.
    accAddLambda :: Int -> Type -> ADM (Lambda SOACS)
    accAddLambda n t = addIdxParams n =<< addLambda t

    -- Construct the 'withAcc' input for a free variable whose adjoint
    -- is an array: the shape, the adjoint array itself, and a combining
    -- operator built from 'addLambda' with a zero of the element type
    -- as its accompanying value.
    withAccInput ::
      (VName, (Shape, PrimType)) ->
      ADM (Shape, [VName], Maybe (Lambda SOACS, [SubExp]))
    withAccInput (v, (shape, pt)) = do
      v_adj <- lookupAdjVal v
      add_lam <- accAddLambda (shapeRank shape) $ Prim pt
      zero <- letSubExp "zero" $ zeroExp $ Prim pt
      pure (shape, [v_adj], Just (add_lam, [zero]))

    -- Run 'm' inside a 'withAcc' in which the array adjoints of the
    -- free variables have been replaced by accumulators.  'm' is given
    -- the free variables that have accumulator adjoints (those that
    -- already had one, followed by the newly converted ones) and the
    -- set of remaining free variables.  Afterwards, the adjoints of
    -- all the free variables, and of those input arrays that are not
    -- free, are set to the corresponding results of the 'withAcc'.
    accAdjoints :: [VName] -> ([VName] -> Names -> ADM ()) -> ADM ()
    accAdjoints free m = do
      (arr_free, acc_free, nonacc_free) <-
        partitionAdjVars <$> classifyAdjVars free
      arr_free' <- mapM withAccInput arr_free
      -- We only consider those input arrays that are also not free in
      -- the lambda.
      let as_nonfree = filter (`notElem` free) as
      (arr_adjs, acc_adjs, rest_adjs) <-
        fmap (splitAt3 (length arr_free) (length acc_free)) . withAcc arr_free' $ \accs -> do
          zipWithM_ insAdj (map fst arr_free) accs
          () <- m (acc_free ++ map fst arr_free) (namesFromList nonacc_free)
          acc_free_adj <- mapM lookupAdjVal acc_free
          arr_free_adj <- mapM (lookupAdjVal . fst) arr_free
          nonacc_free_adj <- mapM lookupAdjVal nonacc_free
          as_nonfree_adj <- mapM lookupAdjVal as_nonfree
          -- The order here must agree with the splitting of the
          -- results above and below.
          pure $ varsRes $ arr_free_adj <> acc_free_adj <> nonacc_free_adj <> as_nonfree_adj
      zipWithM_ insAdj acc_free acc_adjs
      zipWithM_ insAdj (map fst arr_free) arr_adjs
      let (nonacc_adjs, as_nonfree_adjs) = splitAt (length nonacc_free) rest_adjs
      zipWithM_ insAdj nonacc_free nonacc_adjs
      zipWithM_ insAdj as_nonfree as_nonfree_adjs

    -- Incorporate the per-iteration contributions 'contribs' into the
    -- adjoint of 'v'.  If the rows are accumulators, 'contribs' simply
    -- becomes the adjoint.  Otherwise the rows are summed with a
    -- commutative reduction and the sum is added to the adjoint.
    freeContrib :: VName -> VName -> ADM ()
    freeContrib v contribs = do
      contribs_t <- lookupType contribs
      case rowType contribs_t of
        Acc {} -> insAdj v contribs
        t -> do
          lam <- addLambda t
          zero <- letSubExp "zero" $ zeroExp t
          reduce <- reduceSOAC [Reduce Commutative lam [zero]]
          contrib_sum <-
            letExp (baseString v <> "_contrib_sum") . Op $
              Screma w [contribs] reduce
          updateAdj v contrib_sum