{-# LANGUAGE TypeFamilies #-}

-- | VJP transformation for Map SOACs.  This is a pretty complicated
-- case due to the possibility of free variables.
module Futhark.AD.Rev.Map (vjpMap) where

import Control.Monad
import Data.Bifunctor (first)
import Futhark.AD.Rev.Monad
import Futhark.Analysis.PrimExp.Convert
import Futhark.Builder
import Futhark.IR.SOACS
import Futhark.Tools
import Futhark.Transform.Rename
import Futhark.Util (splitAt3)

-- | How a free variable of a mapped lambda is treated, decided by
-- inspecting its adjoint.  Note that every constructor carries the
-- /primal/ variable, never the adjoint variable itself.
data AdjVar
  = FreeAcc VName
  -- ^ The adjoint is already an accumulator.
  | FreeArr VName Shape PrimType
  -- ^ The variable currently has no adjoint but should be given one;
  -- it is an array of the recorded shape and element type.
  | FreeNonAcc VName
  -- ^ No accumulator adjoint is required (the variable may still be
  -- an array).

-- | Classify each primal variable by the type of its adjoint value
-- (as returned by 'lookupAdjVal'):
--
-- * an array adjoint gives 'FreeArr', recording the array's shape and
--   element type;
-- * an accumulator adjoint gives 'FreeAcc';
-- * any other adjoint gives 'FreeNonAcc'.
--
-- The result list is in the same order as the input list.
classifyAdjVars :: [VName] -> ADM [AdjVar]
classifyAdjVars = mapM classify
  where
    classify :: VName -> ADM AdjVar
    classify v = do
      v_adj <- lookupAdjVal v
      v_adj_t <- lookupType v_adj
      pure $ case v_adj_t of
        Array pt shape _ -> FreeArr v shape pt
        Acc {} -> FreeAcc v
        _ -> FreeNonAcc v

-- | Split classified free variables into three groups, preserving the
-- relative order of the variables within each group:
--
-- 1. the 'FreeArr' variables, each paired with its shape and element type;
-- 2. the 'FreeAcc' variables;
-- 3. the 'FreeNonAcc' variables.
partitionAdjVars :: [AdjVar] -> ([(VName, (Shape, PrimType))], [VName], [VName])
partitionAdjVars = foldr step ([], [], [])
  where
    -- The lazy pattern keeps the same laziness as a recursive definition
    -- whose result tuple is bound in a @where@ clause.
    step (FreeArr v shape t) ~(xs, ys, zs) = ((v, (shape, t)) : xs, ys, zs)
    step (FreeAcc v) ~(xs, ys, zs) = (xs, v : ys, zs)
    step (FreeNonAcc v) ~(xs, ys, zs) = (xs, ys, v : zs)

-- | Like 'buildBody', but the constructed body is passed through
-- 'renameBody' before being returned, so the names bound inside it are
-- replaced by fresh ones.  The auxiliary value produced alongside the
-- result is returned unchanged.
buildRenamedBody ::
  (MonadBuilder m) =>
  m (Result, a) ->
  m (Body (Rep m), a)
buildRenamedBody m = do
  (body, x) <- buildBody m
  body' <- renameBody body
  pure (body', x)

-- | Run the body generator @m@ inside a 'WithAcc' over the given inputs
-- and return the names bound to the construct's results.
--
-- With no inputs, no 'WithAcc' is generated: @m@ is invoked with an
-- empty list of accumulators and each of its results is bound to a
-- fresh variable.
--
-- Otherwise, for every input @(shape, arrs, op)@ a unit-typed
-- certificate parameter and an accumulator parameter are created.  The
-- accumulator's element types are the types of @arrs@ with the leading
-- @shapeRank shape@ dimensions stripped.  @m@ receives the names of the
-- accumulator parameters and, run under 'subAD', builds the lambda body.
withAcc ::
  [(Shape, [VName], Maybe (Lambda SOACS, [SubExp]))] ->
  ([VName] -> ADM Result) ->
  ADM [VName]
withAcc [] m =
  mapM (letExp "withacc_res" . BasicOp . SubExp . resSubExp) =<< m []
withAcc inputs m = do
  (cert_params, acc_params) <- fmap unzip $
    forM inputs $ \(shape, arrs, _) -> do
      cert_param <- newParam "acc_cert_p" $ Prim Unit
      ts <- mapM (fmap (stripArray (shapeRank shape)) . lookupType) arrs
      acc_param <-
        newParam "acc_p" $ Acc (paramName cert_param) shape ts NoUniqueness
      pure (cert_param, acc_param)
  acc_lam <-
    subAD $ mkLambda (cert_params ++ acc_params) $ m $ map paramName acc_params
  letTupExp "withhacc_res" $ WithAcc inputs acc_lam

-- | Perform VJP on a Map.  The 'Adj' list is the adjoints of the
-- result of the map.
vjpMap :: VjpOps -> [Adj] -> StmAux () -> SubExp -> Lambda SOACS -> [VName] -> ADM ()
vjpMap :: VjpOps
-> [Adj]
-> StmAux ()
-> SubExp
-> Lambda SOACS
-> [VName]
-> ADM ()
vjpMap VjpOps
ops [Adj]
res_adjs StmAux ()
_ SubExp
w Lambda SOACS
map_lam [VName]
as
  | Just [[(InBounds, SubExp, SubExp)]]
res_ivs <- (Adj -> Maybe [(InBounds, SubExp, SubExp)])
-> [Adj] -> Maybe [[(InBounds, SubExp, SubExp)]]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
forall (m :: * -> *) a b. Monad m => (a -> m b) -> [a] -> m [b]
mapM Adj -> Maybe [(InBounds, SubExp, SubExp)]
isSparse [Adj]
res_adjs = ADM () -> ADM ()
forall a. ADM a -> ADM a
returnSweepCode (ADM () -> ADM ()) -> ADM () -> ADM ()
forall a b. (a -> b) -> a -> b
$ do
      -- Since at most a constant number of adjoints are nonzero
      -- (length res_ivs), there is no need for the return sweep code to
      -- contain a Map at all.

      [VName]
free <- (VName -> ADM Bool) -> [VName] -> ADM [VName]
forall (m :: * -> *) a.
Applicative m =>
(a -> m Bool) -> [a] -> m [a]
filterM VName -> ADM Bool
isActive ([VName] -> ADM [VName]) -> [VName] -> ADM [VName]
forall a b. (a -> b) -> a -> b
$ Names -> [VName]
namesToList (Names -> [VName]) -> Names -> [VName]
forall a b. (a -> b) -> a -> b
$ Lambda SOACS -> Names
forall a. FreeIn a => a -> Names
freeIn Lambda SOACS
map_lam Names -> Names -> Names
`namesSubtract` [VName] -> Names
namesFromList [VName]
as
      [Type]
free_ts <- (VName -> ADM Type) -> [VName] -> ADM [Type]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
forall (m :: * -> *) a b. Monad m => (a -> m b) -> [a] -> m [b]
mapM VName -> ADM Type
forall rep (m :: * -> *). HasScope rep m => VName -> m Type
lookupType [VName]
free
      let adjs_for :: [VName]
adjs_for = (Param Type -> VName) -> [Param Type] -> [VName]
forall a b. (a -> b) -> [a] -> [b]
map Param Type -> VName
forall dec. Param dec -> VName
paramName (Lambda SOACS -> [LParam SOACS]
forall rep. Lambda rep -> [LParam rep]
lambdaParams Lambda SOACS
map_lam) [VName] -> [VName] -> [VName]
forall a. [a] -> [a] -> [a]
++ [VName]
free
          adjs_ts :: [Type]
adjs_ts = (Param Type -> Type) -> [Param Type] -> [Type]
forall a b. (a -> b) -> [a] -> [b]
map Param Type -> Type
forall dec. Typed dec => Param dec -> Type
paramType (Lambda SOACS -> [LParam SOACS]
forall rep. Lambda rep -> [LParam rep]
lambdaParams Lambda SOACS
map_lam) [Type] -> [Type] -> [Type]
forall a. [a] -> [a] -> [a]
++ [Type]
free_ts

      let oneHot :: Int -> Adj -> [Adj]
oneHot Int
res_i Adj
adj_v = (Int -> Type -> Adj) -> [Int] -> [Type] -> [Adj]
forall a b c. (a -> b -> c) -> [a] -> [b] -> [c]
zipWith Int -> Type -> Adj
f [Int
0 :: Int ..] ([Type] -> [Adj]) -> [Type] -> [Adj]
forall a b. (a -> b) -> a -> b
$ Lambda SOACS -> [Type]
forall rep. Lambda rep -> [Type]
lambdaReturnType Lambda SOACS
map_lam
            where
              f :: Int -> Type -> Adj
f Int
j Type
t
                | Int
res_i Int -> Int -> Bool
forall a. Eq a => a -> a -> Bool
== Int
j = Adj
adj_v
                | Bool
otherwise = Shape -> PrimType -> Adj
AdjZero (Type -> Shape
forall shape u. ArrayShape shape => TypeBase shape u -> shape
arrayShape Type
t) (Type -> PrimType
forall shape u. TypeBase shape u -> PrimType
elemType Type
t)
          -- The values used in the out-of-bounds case do not matter, as we
          -- will be writing to an out-of-bounds index anyway, which is ignored.
          ooBounds :: SubExp -> ADM (Body SOACS, [SubExp] -> [Adj])
ooBounds SubExp
adj_i = ADM (Body SOACS, [SubExp] -> [Adj])
-> ADM (Body SOACS, [SubExp] -> [Adj])
forall a. ADM a -> ADM a
subAD (ADM (Body SOACS, [SubExp] -> [Adj])
 -> ADM (Body SOACS, [SubExp] -> [Adj]))
-> (ADM (Result, [SubExp] -> [Adj])
    -> ADM (Body SOACS, [SubExp] -> [Adj]))
-> ADM (Result, [SubExp] -> [Adj])
-> ADM (Body SOACS, [SubExp] -> [Adj])
forall b c a. (b -> c) -> (a -> b) -> a -> c
. ADM (Result, [SubExp] -> [Adj])
-> ADM (Body (Rep ADM), [SubExp] -> [Adj])
ADM (Result, [SubExp] -> [Adj])
-> ADM (Body SOACS, [SubExp] -> [Adj])
forall (m :: * -> *) a.
MonadBuilder m =>
m (Result, a) -> m (Body (Rep m), a)
buildRenamedBody (ADM (Result, [SubExp] -> [Adj])
 -> ADM (Body SOACS, [SubExp] -> [Adj]))
-> ADM (Result, [SubExp] -> [Adj])
-> ADM (Body SOACS, [SubExp] -> [Adj])
forall a b. (a -> b) -> a -> b
$ do
            [(VName, Type)] -> ((VName, Type) -> ADM ()) -> ADM ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([VName] -> [Type] -> [(VName, Type)]
forall a b. [a] -> [b] -> [(a, b)]
zip [VName]
as [Type]
adjs_ts) (((VName, Type) -> ADM ()) -> ADM ())
-> ((VName, Type) -> ADM ()) -> ADM ()
forall a b. (a -> b) -> a -> b
$ \(VName
a, Type
t) -> do
              SubExp
scratch <- String -> Exp (Rep ADM) -> ADM SubExp
forall (m :: * -> *).
MonadBuilder m =>
String -> Exp (Rep m) -> m SubExp
letSubExp String
"oo_scratch" (Exp SOACS -> ADM SubExp) -> ADM (Exp SOACS) -> ADM SubExp
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< Type -> ADM (Exp (Rep ADM))
forall (m :: * -> *). MonadBuilder m => Type -> m (Exp (Rep m))
eBlank Type
t
              VName -> (InBounds, SubExp) -> SubExp -> ADM ()
updateAdjIndex VName
a (InBounds
OutOfBounds, SubExp
adj_i) SubExp
scratch
            -- We must make sure that all free variables have the same
            -- representation in the oo-branch as in the ib-branch.
            -- In practice we do this by manifesting the adjoint.
            -- This is probably efficient, since the adjoint of a free
            -- variable is probably either a scalar or an accumulator.
            [VName] -> (VName -> ADM ()) -> ADM ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ [VName]
free ((VName -> ADM ()) -> ADM ()) -> (VName -> ADM ()) -> ADM ()
forall a b. (a -> b) -> a -> b
$ \VName
v -> VName -> VName -> ADM ()
insAdj VName
v (VName -> ADM ()) -> ADM VName -> ADM ()
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< Adj -> ADM VName
adjVal (Adj -> ADM VName) -> ADM Adj -> ADM VName
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< VName -> ADM Adj
lookupAdj VName
v
            ([SubExp] -> Result)
-> ([SubExp], [SubExp] -> [Adj]) -> (Result, [SubExp] -> [Adj])
forall a b c. (a -> b) -> (a, c) -> (b, c)
forall (p :: * -> * -> *) a b c.
Bifunctor p =>
(a -> b) -> p a c -> p b c
first [SubExp] -> Result
subExpsRes (([SubExp], [SubExp] -> [Adj]) -> (Result, [SubExp] -> [Adj]))
-> ([Adj] -> ([SubExp], [SubExp] -> [Adj]))
-> [Adj]
-> (Result, [SubExp] -> [Adj])
forall b c a. (b -> c) -> (a -> b) -> a -> c
. [Adj] -> ([SubExp], [SubExp] -> [Adj])
adjsReps ([Adj] -> (Result, [SubExp] -> [Adj]))
-> ADM [Adj] -> ADM (Result, [SubExp] -> [Adj])
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> (VName -> ADM Adj) -> [VName] -> ADM [Adj]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
forall (m :: * -> *) a b. Monad m => (a -> m b) -> [a] -> m [b]
mapM VName -> ADM Adj
lookupAdj ([VName]
as [VName] -> [VName] -> [VName]
forall a. Semigroup a => a -> a -> a
<> [VName]
free)
          inBounds :: Int -> SubExp -> SubExp -> ADM (Body SOACS, [SubExp] -> [Adj])
inBounds Int
res_i SubExp
adj_i SubExp
adj_v = ADM (Body SOACS, [SubExp] -> [Adj])
-> ADM (Body SOACS, [SubExp] -> [Adj])
forall a. ADM a -> ADM a
subAD (ADM (Body SOACS, [SubExp] -> [Adj])
 -> ADM (Body SOACS, [SubExp] -> [Adj]))
-> (ADM (Result, [SubExp] -> [Adj])
    -> ADM (Body SOACS, [SubExp] -> [Adj]))
-> ADM (Result, [SubExp] -> [Adj])
-> ADM (Body SOACS, [SubExp] -> [Adj])
forall b c a. (b -> c) -> (a -> b) -> a -> c
. ADM (Result, [SubExp] -> [Adj])
-> ADM (Body (Rep ADM), [SubExp] -> [Adj])
ADM (Result, [SubExp] -> [Adj])
-> ADM (Body SOACS, [SubExp] -> [Adj])
forall (m :: * -> *) a.
MonadBuilder m =>
m (Result, a) -> m (Body (Rep m), a)
buildRenamedBody (ADM (Result, [SubExp] -> [Adj])
 -> ADM (Body SOACS, [SubExp] -> [Adj]))
-> ADM (Result, [SubExp] -> [Adj])
-> ADM (Body SOACS, [SubExp] -> [Adj])
forall a b. (a -> b) -> a -> b
$ do
            [(Param Type, VName)] -> ((Param Type, VName) -> ADM ()) -> ADM ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([Param Type] -> [VName] -> [(Param Type, VName)]
forall a b. [a] -> [b] -> [(a, b)]
zip (Lambda SOACS -> [LParam SOACS]
forall rep. Lambda rep -> [LParam rep]
lambdaParams Lambda SOACS
map_lam) [VName]
as) (((Param Type, VName) -> ADM ()) -> ADM ())
-> ((Param Type, VName) -> ADM ()) -> ADM ()
forall a b. (a -> b) -> a -> b
$ \(Param Type
p, VName
a) -> do
              Type
a_t <- VName -> ADM Type
forall rep (m :: * -> *). HasScope rep m => VName -> m Type
lookupType VName
a
              [VName] -> Exp (Rep ADM) -> ADM ()
forall (m :: * -> *).
MonadBuilder m =>
[VName] -> Exp (Rep m) -> m ()
letBindNames [Param Type -> VName
forall dec. Param dec -> VName
paramName Param Type
p] (Exp SOACS -> ADM ())
-> (Slice SubExp -> Exp SOACS) -> Slice SubExp -> ADM ()
forall b c a. (b -> c) -> (a -> b) -> a -> c
. BasicOp -> Exp SOACS
forall rep. BasicOp -> Exp rep
BasicOp (BasicOp -> Exp SOACS)
-> (Slice SubExp -> BasicOp) -> Slice SubExp -> Exp SOACS
forall b c a. (b -> c) -> (a -> b) -> a -> c
. VName -> Slice SubExp -> BasicOp
Index VName
a (Slice SubExp -> ADM ()) -> Slice SubExp -> ADM ()
forall a b. (a -> b) -> a -> b
$
                Type -> [DimIndex SubExp] -> Slice SubExp
fullSlice Type
a_t [SubExp -> DimIndex SubExp
forall d. d -> DimIndex d
DimFix SubExp
adj_i]
            [SubExp]
adj_elems <-
              (Result -> [SubExp]) -> ADM Result -> ADM [SubExp]
forall a b. (a -> b) -> ADM a -> ADM b
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap ((SubExpRes -> SubExp) -> Result -> [SubExp]
forall a b. (a -> b) -> [a] -> [b]
map SubExpRes -> SubExp
resSubExp) (ADM Result -> ADM [SubExp])
-> (Lambda SOACS -> ADM Result) -> Lambda SOACS -> ADM [SubExp]
forall b c a. (b -> c) -> (a -> b) -> a -> c
. Body (Rep ADM) -> ADM Result
Body SOACS -> ADM Result
forall (m :: * -> *). MonadBuilder m => Body (Rep m) -> m Result
bodyBind (Body SOACS -> ADM Result)
-> (Lambda SOACS -> Body SOACS) -> Lambda SOACS -> ADM Result
forall b c a. (b -> c) -> (a -> b) -> a -> c
. Lambda SOACS -> Body SOACS
forall rep. Lambda rep -> Body rep
lambdaBody
                (Lambda SOACS -> ADM [SubExp])
-> ADM (Lambda SOACS) -> ADM [SubExp]
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< VjpOps -> [Adj] -> [VName] -> Lambda SOACS -> ADM (Lambda SOACS)
vjpLambda VjpOps
ops (Int -> Adj -> [Adj]
oneHot Int
res_i (SubExp -> Adj
AdjVal SubExp
adj_v)) [VName]
adjs_for Lambda SOACS
map_lam
            let ([SubExp]
as_adj_elems, [SubExp]
free_adj_elems) = Int -> [SubExp] -> ([SubExp], [SubExp])
forall a. Int -> [a] -> ([a], [a])
splitAt ([VName] -> Int
forall a. [a] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [VName]
as) [SubExp]
adj_elems
            [(VName, SubExp)] -> ((VName, SubExp) -> ADM ()) -> ADM ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([VName] -> [SubExp] -> [(VName, SubExp)]
forall a b. [a] -> [b] -> [(a, b)]
zip [VName]
as [SubExp]
as_adj_elems) (((VName, SubExp) -> ADM ()) -> ADM ())
-> ((VName, SubExp) -> ADM ()) -> ADM ()
forall a b. (a -> b) -> a -> b
$ \(VName
a, SubExp
a_adj_elem) ->
              VName -> (InBounds, SubExp) -> SubExp -> ADM ()
updateAdjIndex VName
a (InBounds
AssumeBounds, SubExp
adj_i) SubExp
a_adj_elem
            [(VName, SubExp)] -> ((VName, SubExp) -> ADM ()) -> ADM ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([VName] -> [SubExp] -> [(VName, SubExp)]
forall a b. [a] -> [b] -> [(a, b)]
zip [VName]
free [SubExp]
free_adj_elems) (((VName, SubExp) -> ADM ()) -> ADM ())
-> ((VName, SubExp) -> ADM ()) -> ADM ()
forall a b. (a -> b) -> a -> b
$ \(VName
v, SubExp
adj_se) -> do
              VName
adj_se_v <- String -> Exp (Rep ADM) -> ADM VName
forall (m :: * -> *).
MonadBuilder m =>
String -> Exp (Rep m) -> m VName
letExp String
"adj_v" (BasicOp -> Exp (Rep ADM)
forall rep. BasicOp -> Exp rep
BasicOp (BasicOp -> Exp (Rep ADM)) -> BasicOp -> Exp (Rep ADM)
forall a b. (a -> b) -> a -> b
$ SubExp -> BasicOp
SubExp SubExp
adj_se)
              VName -> VName -> ADM ()
insAdj VName
v VName
adj_se_v
            ([SubExp] -> Result)
-> ([SubExp], [SubExp] -> [Adj]) -> (Result, [SubExp] -> [Adj])
forall a b c. (a -> b) -> (a, c) -> (b, c)
forall (p :: * -> * -> *) a b c.
Bifunctor p =>
(a -> b) -> p a c -> p b c
first [SubExp] -> Result
subExpsRes (([SubExp], [SubExp] -> [Adj]) -> (Result, [SubExp] -> [Adj]))
-> ([Adj] -> ([SubExp], [SubExp] -> [Adj]))
-> [Adj]
-> (Result, [SubExp] -> [Adj])
forall b c a. (b -> c) -> (a -> b) -> a -> c
. [Adj] -> ([SubExp], [SubExp] -> [Adj])
adjsReps ([Adj] -> (Result, [SubExp] -> [Adj]))
-> ADM [Adj] -> ADM (Result, [SubExp] -> [Adj])
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> (VName -> ADM Adj) -> [VName] -> ADM [Adj]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
forall (m :: * -> *) a b. Monad m => (a -> m b) -> [a] -> m [b]
mapM VName -> ADM Adj
lookupAdj ([VName]
as [VName] -> [VName] -> [VName]
forall a. Semigroup a => a -> a -> a
<> [VName]
free)

          -- Generate an iteration of the map function for every
          -- position.  This is a bit inefficient - probably we could do
          -- some deduplication.
          forPos :: Int -> (InBounds, SubExp, SubExp) -> ADM [()]
forPos Int
res_i (InBounds
check, SubExp
adj_i, SubExp
adj_v) = do
            [Adj]
adjs <-
              case InBounds
check of
                CheckBounds Maybe SubExp
b -> do
                  (Body SOACS
obbranch, [SubExp] -> [Adj]
mkadjs) <- SubExp -> ADM (Body SOACS, [SubExp] -> [Adj])
ooBounds SubExp
adj_i
                  (Body SOACS
ibbranch, [SubExp] -> [Adj]
_) <- Int -> SubExp -> SubExp -> ADM (Body SOACS, [SubExp] -> [Adj])
inBounds Int
res_i SubExp
adj_i SubExp
adj_v
                  ([SubExp] -> [Adj]) -> ADM [SubExp] -> ADM [Adj]
forall a b. (a -> b) -> ADM a -> ADM b
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap [SubExp] -> [Adj]
mkadjs (ADM [SubExp] -> ADM [Adj])
-> (Exp SOACS -> ADM [SubExp]) -> Exp SOACS -> ADM [Adj]
forall b c a. (b -> c) -> (a -> b) -> a -> c
. String -> Exp (Rep ADM) -> ADM [SubExp]
forall (m :: * -> *).
MonadBuilder m =>
String -> Exp (Rep m) -> m [SubExp]
letTupExp' String
"map_adj_elem"
                    (Exp SOACS -> ADM [Adj]) -> ADM (Exp SOACS) -> ADM [Adj]
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< ADM (Exp (Rep ADM))
-> ADM (Body (Rep ADM))
-> ADM (Body (Rep ADM))
-> ADM (Exp (Rep ADM))
forall (m :: * -> *).
(MonadBuilder m, BranchType (Rep m) ~ ExtType) =>
m (Exp (Rep m))
-> m (Body (Rep m)) -> m (Body (Rep m)) -> m (Exp (Rep m))
eIf
                      (ADM (Exp SOACS)
-> (SubExp -> ADM (Exp SOACS)) -> Maybe SubExp -> ADM (Exp SOACS)
forall b a. b -> (a -> b) -> Maybe a -> b
maybe (ADM (Exp (Rep ADM)) -> ADM (Exp (Rep ADM)) -> ADM (Exp (Rep ADM))
forall (m :: * -> *).
MonadBuilder m =>
m (Exp (Rep m)) -> m (Exp (Rep m)) -> m (Exp (Rep m))
eDimInBounds (SubExp -> ADM (Exp (Rep ADM))
forall (m :: * -> *). MonadBuilder m => SubExp -> m (Exp (Rep m))
eSubExp SubExp
w) (SubExp -> ADM (Exp (Rep ADM))
forall (m :: * -> *). MonadBuilder m => SubExp -> m (Exp (Rep m))
eSubExp SubExp
adj_i)) SubExp -> ADM (Exp (Rep ADM))
SubExp -> ADM (Exp SOACS)
forall (m :: * -> *). MonadBuilder m => SubExp -> m (Exp (Rep m))
eSubExp Maybe SubExp
b)
                      (Body SOACS -> ADM (Body SOACS)
forall a. a -> ADM a
forall (f :: * -> *) a. Applicative f => a -> f a
pure Body SOACS
ibbranch)
                      (Body SOACS -> ADM (Body SOACS)
forall a. a -> ADM a
forall (f :: * -> *) a. Applicative f => a -> f a
pure Body SOACS
obbranch)
                InBounds
AssumeBounds -> do
                  (Body SOACS
body, [SubExp] -> [Adj]
mkadjs) <- Int -> SubExp -> SubExp -> ADM (Body SOACS, [SubExp] -> [Adj])
inBounds Int
res_i SubExp
adj_i SubExp
adj_v
                  [SubExp] -> [Adj]
mkadjs ([SubExp] -> [Adj]) -> (Result -> [SubExp]) -> Result -> [Adj]
forall b c a. (b -> c) -> (a -> b) -> a -> c
. (SubExpRes -> SubExp) -> Result -> [SubExp]
forall a b. (a -> b) -> [a] -> [b]
map SubExpRes -> SubExp
resSubExp (Result -> [Adj]) -> ADM Result -> ADM [Adj]
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> Body (Rep ADM) -> ADM Result
forall (m :: * -> *). MonadBuilder m => Body (Rep m) -> m Result
bodyBind Body (Rep ADM)
Body SOACS
body
                InBounds
OutOfBounds ->
                  (VName -> ADM Adj) -> [VName] -> ADM [Adj]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
forall (m :: * -> *) a b. Monad m => (a -> m b) -> [a] -> m [b]
mapM VName -> ADM Adj
lookupAdj [VName]
as

            (VName -> Adj -> ADM ()) -> [VName] -> [Adj] -> ADM [()]
forall (m :: * -> *) a b c.
Applicative m =>
(a -> b -> m c) -> [a] -> [b] -> m [c]
zipWithM VName -> Adj -> ADM ()
setAdj ([VName]
as [VName] -> [VName] -> [VName]
forall a. Semigroup a => a -> a -> a
<> [VName]
free) [Adj]
adjs

          -- Generate an iteration of the map function for every result.
          forRes :: Int -> [(InBounds, SubExp, SubExp)] -> ADM ()
forRes Int
res_i = ((InBounds, SubExp, SubExp) -> ADM [()])
-> [(InBounds, SubExp, SubExp)] -> ADM ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
(a -> m b) -> t a -> m ()
mapM_ (Int -> (InBounds, SubExp, SubExp) -> ADM [()]
forPos Int
res_i)

      (Int -> [(InBounds, SubExp, SubExp)] -> ADM ())
-> [Int] -> [[(InBounds, SubExp, SubExp)]] -> ADM ()
forall (m :: * -> *) a b c.
Applicative m =>
(a -> b -> m c) -> [a] -> [b] -> m ()
zipWithM_ Int -> [(InBounds, SubExp, SubExp)] -> ADM ()
forRes [Int
0 ..] [[(InBounds, SubExp, SubExp)]]
res_ivs
  where
    isSparse :: Adj -> Maybe [(InBounds, SubExp, SubExp)]
isSparse (AdjSparse (Sparse Shape
shape PrimType
_ [(InBounds, SubExp, SubExp)]
ivs)) = do
      Bool -> Maybe ()
forall (f :: * -> *). Alternative f => Bool -> f ()
guard (Bool -> Maybe ()) -> Bool -> Maybe ()
forall a b. (a -> b) -> a -> b
$ Shape -> [SubExp]
forall d. ShapeBase d -> [d]
shapeDims Shape
shape [SubExp] -> [SubExp] -> Bool
forall a. Eq a => a -> a -> Bool
== [SubExp
w]
      [(InBounds, SubExp, SubExp)] -> Maybe [(InBounds, SubExp, SubExp)]
forall a. a -> Maybe a
Just [(InBounds, SubExp, SubExp)]
ivs
    isSparse Adj
_ =
      Maybe [(InBounds, SubExp, SubExp)]
forall a. Maybe a
Nothing
-- See Note [Adjoints of accumulators] for how we deal with
-- accumulators - it's a bit tricky here.
vjpMap VjpOps
ops [Adj]
pat_adj StmAux ()
aux SubExp
w Lambda SOACS
map_lam [VName]
as = ADM () -> ADM ()
forall a. ADM a -> ADM a
returnSweepCode (ADM () -> ADM ()) -> ADM () -> ADM ()
forall a b. (a -> b) -> a -> b
$ do
  [VName]
pat_adj_vals <- [(Adj, Type)] -> ((Adj, Type) -> ADM VName) -> ADM [VName]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
t a -> (a -> m b) -> m (t b)
forM ([Adj] -> [Type] -> [(Adj, Type)]
forall a b. [a] -> [b] -> [(a, b)]
zip [Adj]
pat_adj (Lambda SOACS -> [Type]
forall rep. Lambda rep -> [Type]
lambdaReturnType Lambda SOACS
map_lam)) (((Adj, Type) -> ADM VName) -> ADM [VName])
-> ((Adj, Type) -> ADM VName) -> ADM [VName]
forall a b. (a -> b) -> a -> b
$ \(Adj
adj, Type
t) ->
    case Type
t of
      Acc {} -> String -> Exp (Rep ADM) -> ADM VName
forall (m :: * -> *).
MonadBuilder m =>
String -> Exp (Rep m) -> m VName
letExp String
"acc_adj_rep" (Exp SOACS -> ADM VName)
-> (VName -> Exp SOACS) -> VName -> ADM VName
forall b c a. (b -> c) -> (a -> b) -> a -> c
. BasicOp -> Exp SOACS
forall rep. BasicOp -> Exp rep
BasicOp (BasicOp -> Exp SOACS) -> (VName -> BasicOp) -> VName -> Exp SOACS
forall b c a. (b -> c) -> (a -> b) -> a -> c
. Shape -> SubExp -> BasicOp
Replicate ([SubExp] -> Shape
forall d. [d] -> ShapeBase d
Shape [SubExp
w]) (SubExp -> BasicOp) -> (VName -> SubExp) -> VName -> BasicOp
forall b c a. (b -> c) -> (a -> b) -> a -> c
. VName -> SubExp
Var (VName -> ADM VName) -> ADM VName -> ADM VName
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< Adj -> ADM VName
adjVal Adj
adj
      Type
_ -> Adj -> ADM VName
adjVal Adj
adj
  [Param Type]
pat_adj_params <-
    (VName -> ADM (Param Type)) -> [VName] -> ADM [Param Type]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
forall (m :: * -> *) a b. Monad m => (a -> m b) -> [a] -> m [b]
mapM (String -> Type -> ADM (Param Type)
forall (m :: * -> *) dec.
MonadFreshNames m =>
String -> dec -> m (Param dec)
newParam String
"map_adj_p" (Type -> ADM (Param Type))
-> (Type -> Type) -> Type -> ADM (Param Type)
forall b c a. (b -> c) -> (a -> b) -> a -> c
. Type -> Type
forall u. TypeBase Shape u -> TypeBase Shape u
rowType (Type -> ADM (Param Type))
-> (VName -> ADM Type) -> VName -> ADM (Param Type)
forall (m :: * -> *) b c a.
Monad m =>
(b -> m c) -> (a -> m b) -> a -> m c
<=< VName -> ADM Type
forall rep (m :: * -> *). HasScope rep m => VName -> m Type
lookupType) [VName]
pat_adj_vals

  Lambda SOACS
map_lam' <- Lambda SOACS -> ADM (Lambda SOACS)
forall rep (m :: * -> *).
(Renameable rep, MonadFreshNames m) =>
Lambda rep -> m (Lambda rep)
renameLambda Lambda SOACS
map_lam
  [VName]
free <- (VName -> ADM Bool) -> [VName] -> ADM [VName]
forall (m :: * -> *) a.
Applicative m =>
(a -> m Bool) -> [a] -> m [a]
filterM VName -> ADM Bool
isActive ([VName] -> ADM [VName]) -> [VName] -> ADM [VName]
forall a b. (a -> b) -> a -> b
$ Names -> [VName]
namesToList (Names -> [VName]) -> Names -> [VName]
forall a b. (a -> b) -> a -> b
$ Lambda SOACS -> Names
forall a. FreeIn a => a -> Names
freeIn Lambda SOACS
map_lam'

  [VName] -> ([VName] -> Names -> ADM ()) -> ADM ()
accAdjoints [VName]
free (([VName] -> Names -> ADM ()) -> ADM ())
-> ([VName] -> Names -> ADM ()) -> ADM ()
forall a b. (a -> b) -> a -> b
$ \[VName]
free_with_adjs Names
free_without_adjs -> do
    [VName]
free_adjs <- (VName -> ADM VName) -> [VName] -> ADM [VName]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
forall (m :: * -> *) a b. Monad m => (a -> m b) -> [a] -> m [b]
mapM VName -> ADM VName
lookupAdjVal [VName]
free_with_adjs
    [Type]
free_adjs_ts <- (VName -> ADM Type) -> [VName] -> ADM [Type]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
forall (m :: * -> *) a b. Monad m => (a -> m b) -> [a] -> m [b]
mapM VName -> ADM Type
forall rep (m :: * -> *). HasScope rep m => VName -> m Type
lookupType [VName]
free_adjs
    [Param Type]
free_adjs_params <- (Type -> ADM (Param Type)) -> [Type] -> ADM [Param Type]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
forall (m :: * -> *) a b. Monad m => (a -> m b) -> [a] -> m [b]
mapM (String -> Type -> ADM (Param Type)
forall (m :: * -> *) dec.
MonadFreshNames m =>
String -> dec -> m (Param dec)
newParam String
"free_adj_p") [Type]
free_adjs_ts
    let lam_rev_params :: [Param Type]
lam_rev_params =
          Lambda SOACS -> [LParam SOACS]
forall rep. Lambda rep -> [LParam rep]
lambdaParams Lambda SOACS
map_lam' [Param Type] -> [Param Type] -> [Param Type]
forall a. [a] -> [a] -> [a]
++ [Param Type]
pat_adj_params [Param Type] -> [Param Type] -> [Param Type]
forall a. [a] -> [a] -> [a]
++ [Param Type]
free_adjs_params
        adjs_for :: [VName]
adjs_for = (Param Type -> VName) -> [Param Type] -> [VName]
forall a b. (a -> b) -> [a] -> [b]
map Param Type -> VName
forall dec. Param dec -> VName
paramName (Lambda SOACS -> [LParam SOACS]
forall rep. Lambda rep -> [LParam rep]
lambdaParams Lambda SOACS
map_lam') [VName] -> [VName] -> [VName]
forall a. [a] -> [a] -> [a]
++ [VName]
free
    Lambda SOACS
lam_rev <-
      [LParam (Rep ADM)] -> ADM Result -> ADM (Lambda (Rep ADM))
forall (m :: * -> *).
MonadBuilder m =>
[LParam (Rep m)] -> m Result -> m (Lambda (Rep m))
mkLambda [Param Type]
[LParam (Rep ADM)]
lam_rev_params (ADM Result -> ADM (Lambda SOACS))
-> (ADM Result -> ADM Result) -> ADM Result -> ADM (Lambda SOACS)
forall b c a. (b -> c) -> (a -> b) -> a -> c
. ADM Result -> ADM Result
forall a. ADM a -> ADM a
subAD (ADM Result -> ADM Result)
-> (ADM Result -> ADM Result) -> ADM Result -> ADM Result
forall b c a. (b -> c) -> (a -> b) -> a -> c
. Names -> ADM Result -> ADM Result
forall a. Names -> ADM a -> ADM a
noAdjsFor Names
free_without_adjs (ADM Result -> ADM (Lambda SOACS))
-> ADM Result -> ADM (Lambda SOACS)
forall a b. (a -> b) -> a -> b
$ do
        (VName -> VName -> ADM ()) -> [VName] -> [VName] -> ADM ()
forall (m :: * -> *) a b c.
Applicative m =>
(a -> b -> m c) -> [a] -> [b] -> m ()
zipWithM_ VName -> VName -> ADM ()
insAdj [VName]
free_with_adjs ([VName] -> ADM ()) -> [VName] -> ADM ()
forall a b. (a -> b) -> a -> b
$ (Param Type -> VName) -> [Param Type] -> [VName]
forall a b. (a -> b) -> [a] -> [b]
map Param Type -> VName
forall dec. Param dec -> VName
paramName [Param Type]
free_adjs_params
        Body (Rep ADM) -> ADM Result
Body SOACS -> ADM Result
forall (m :: * -> *). MonadBuilder m => Body (Rep m) -> m Result
bodyBind (Body SOACS -> ADM Result)
-> (Lambda SOACS -> Body SOACS) -> Lambda SOACS -> ADM Result
forall b c a. (b -> c) -> (a -> b) -> a -> c
. Lambda SOACS -> Body SOACS
forall rep. Lambda rep -> Body rep
lambdaBody
          (Lambda SOACS -> ADM Result) -> ADM (Lambda SOACS) -> ADM Result
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< VjpOps -> [Adj] -> [VName] -> Lambda SOACS -> ADM (Lambda SOACS)
vjpLambda VjpOps
ops ((Param Type -> Adj) -> [Param Type] -> [Adj]
forall a b. (a -> b) -> [a] -> [b]
map Param Type -> Adj
forall t. Param t -> Adj
adjFromParam [Param Type]
pat_adj_params) [VName]
adjs_for Lambda SOACS
map_lam'

    ([VName]
param_contribs, [VName]
free_contribs) <-
      ([VName] -> ([VName], [VName]))
-> ADM [VName] -> ADM ([VName], [VName])
forall a b. (a -> b) -> ADM a -> ADM b
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap (Int -> [VName] -> ([VName], [VName])
forall a. Int -> [a] -> ([a], [a])
splitAt ([Param Type] -> Int
forall a. [a] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length (Lambda SOACS -> [LParam SOACS]
forall rep. Lambda rep -> [LParam rep]
lambdaParams Lambda SOACS
map_lam'))) (ADM [VName] -> ADM ([VName], [VName]))
-> ADM [VName] -> ADM ([VName], [VName])
forall a b. (a -> b) -> a -> b
$
        StmAux () -> ADM [VName] -> ADM [VName]
forall (m :: * -> *) anyrep a.
MonadBuilder m =>
StmAux anyrep -> m a -> m a
auxing StmAux ()
aux (ADM [VName] -> ADM [VName])
-> (SOAC SOACS -> ADM [VName]) -> SOAC SOACS -> ADM [VName]
forall b c a. (b -> c) -> (a -> b) -> a -> c
. String -> Exp (Rep ADM) -> ADM [VName]
forall (m :: * -> *).
MonadBuilder m =>
String -> Exp (Rep m) -> m [VName]
letTupExp String
"map_adjs" (Exp SOACS -> ADM [VName])
-> (SOAC SOACS -> Exp SOACS) -> SOAC SOACS -> ADM [VName]
forall b c a. (b -> c) -> (a -> b) -> a -> c
. Op SOACS -> Exp SOACS
SOAC SOACS -> Exp SOACS
forall rep. Op rep -> Exp rep
Op (SOAC SOACS -> ADM [VName]) -> SOAC SOACS -> ADM [VName]
forall a b. (a -> b) -> a -> b
$
          SubExp -> [VName] -> ScremaForm SOACS -> SOAC SOACS
forall rep. SubExp -> [VName] -> ScremaForm rep -> SOAC rep
Screma SubExp
w ([VName]
as [VName] -> [VName] -> [VName]
forall a. [a] -> [a] -> [a]
++ [VName]
pat_adj_vals [VName] -> [VName] -> [VName]
forall a. [a] -> [a] -> [a]
++ [VName]
free_adjs) (Lambda SOACS -> ScremaForm SOACS
forall rep. Lambda rep -> ScremaForm rep
mapSOAC Lambda SOACS
lam_rev)

    -- Crucial that we handle the free contribs first in case 'free'
    -- and 'as' intersect.
    (VName -> VName -> ADM ()) -> [VName] -> [VName] -> ADM ()
forall (m :: * -> *) a b c.
Applicative m =>
(a -> b -> m c) -> [a] -> [b] -> m ()
zipWithM_ VName -> VName -> ADM ()
freeContrib [VName]
free [VName]
free_contribs
    let param_ts :: [Type]
param_ts = (Param Type -> Type) -> [Param Type] -> [Type]
forall a b. (a -> b) -> [a] -> [b]
map Param Type -> Type
forall dec. Typed dec => Param dec -> Type
paramType (Lambda SOACS -> [LParam SOACS]
forall rep. Lambda rep -> [LParam rep]
lambdaParams Lambda SOACS
map_lam')
    [(Type, VName, VName)]
-> ((Type, VName, VName) -> ADM ()) -> ADM ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([Type] -> [VName] -> [VName] -> [(Type, VName, VName)]
forall a b c. [a] -> [b] -> [c] -> [(a, b, c)]
zip3 [Type]
param_ts [VName]
as [VName]
param_contribs) (((Type, VName, VName) -> ADM ()) -> ADM ())
-> ((Type, VName, VName) -> ADM ()) -> ADM ()
forall a b. (a -> b) -> a -> b
$ \(Type
param_t, VName
a, VName
param_contrib) ->
      case Type
param_t of
        Acc {} -> VName -> VName -> ADM ()
freeContrib VName
a VName
param_contrib
        Type
_ -> VName -> VName -> ADM ()
updateAdj VName
a VName
param_contrib
  where
    addIdxParams :: Int -> Lambda rep -> m (Lambda rep)
addIdxParams Int
n Lambda rep
lam = do
      [Param (TypeBase shape u)]
idxs <- Int -> m (Param (TypeBase shape u)) -> m [Param (TypeBase shape u)]
forall (m :: * -> *) a. Applicative m => Int -> m a -> m [a]
replicateM Int
n (m (Param (TypeBase shape u)) -> m [Param (TypeBase shape u)])
-> m (Param (TypeBase shape u)) -> m [Param (TypeBase shape u)]
forall a b. (a -> b) -> a -> b
$ String -> TypeBase shape u -> m (Param (TypeBase shape u))
forall (m :: * -> *) dec.
MonadFreshNames m =>
String -> dec -> m (Param dec)
newParam String
"idx" (TypeBase shape u -> m (Param (TypeBase shape u)))
-> TypeBase shape u -> m (Param (TypeBase shape u))
forall a b. (a -> b) -> a -> b
$ PrimType -> TypeBase shape u
forall shape u. PrimType -> TypeBase shape u
Prim PrimType
int64
      Lambda rep -> m (Lambda rep)
forall a. a -> m a
forall (f :: * -> *) a. Applicative f => a -> f a
pure (Lambda rep -> m (Lambda rep)) -> Lambda rep -> m (Lambda rep)
forall a b. (a -> b) -> a -> b
$ Lambda rep
lam {lambdaParams = idxs ++ lambdaParams lam}

    accAddLambda :: Int -> Type -> ADM (Lambda SOACS)
accAddLambda Int
n Type
t = Int -> Lambda SOACS -> ADM (Lambda SOACS)
forall {rep} {shape} {u} {m :: * -> *}.
(LParamInfo rep ~ TypeBase shape u, MonadFreshNames m) =>
Int -> Lambda rep -> m (Lambda rep)
addIdxParams Int
n (Lambda SOACS -> ADM (Lambda SOACS))
-> ADM (Lambda SOACS) -> ADM (Lambda SOACS)
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< Type -> ADM (Lambda SOACS)
addLambda Type
t

    withAccInput :: (VName, (a, PrimType))
-> ADM (a, [VName], Maybe (Lambda SOACS, [SubExp]))
withAccInput (VName
v, (a
shape, PrimType
pt)) = do
      VName
v_adj <- VName -> ADM VName
lookupAdjVal VName
v
      Lambda SOACS
add_lam <- Int -> Type -> ADM (Lambda SOACS)
accAddLambda (a -> Int
forall a. ArrayShape a => a -> Int
shapeRank a
shape) (Type -> ADM (Lambda SOACS)) -> Type -> ADM (Lambda SOACS)
forall a b. (a -> b) -> a -> b
$ PrimType -> Type
forall shape u. PrimType -> TypeBase shape u
Prim PrimType
pt
      SubExp
zero <- String -> Exp (Rep ADM) -> ADM SubExp
forall (m :: * -> *).
MonadBuilder m =>
String -> Exp (Rep m) -> m SubExp
letSubExp String
"zero" (Exp (Rep ADM) -> ADM SubExp) -> Exp (Rep ADM) -> ADM SubExp
forall a b. (a -> b) -> a -> b
$ Type -> Exp (Rep ADM)
forall rep. Type -> Exp rep
zeroExp (Type -> Exp (Rep ADM)) -> Type -> Exp (Rep ADM)
forall a b. (a -> b) -> a -> b
$ PrimType -> Type
forall shape u. PrimType -> TypeBase shape u
Prim PrimType
pt
      (a, [VName], Maybe (Lambda SOACS, [SubExp]))
-> ADM (a, [VName], Maybe (Lambda SOACS, [SubExp]))
forall a. a -> ADM a
forall (f :: * -> *) a. Applicative f => a -> f a
pure (a
shape, [VName
v_adj], (Lambda SOACS, [SubExp]) -> Maybe (Lambda SOACS, [SubExp])
forall a. a -> Maybe a
Just (Lambda SOACS
add_lam, [SubExp
zero]))

    accAdjoints :: [VName] -> ([VName] -> Names -> ADM ()) -> ADM ()
accAdjoints [VName]
free [VName] -> Names -> ADM ()
m = do
      ([(VName, (Shape, PrimType))]
arr_free, [VName]
acc_free, [VName]
nonacc_free) <-
        [AdjVar] -> ([(VName, (Shape, PrimType))], [VName], [VName])
partitionAdjVars ([AdjVar] -> ([(VName, (Shape, PrimType))], [VName], [VName]))
-> ADM [AdjVar]
-> ADM ([(VName, (Shape, PrimType))], [VName], [VName])
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> [VName] -> ADM [AdjVar]
classifyAdjVars [VName]
free
      [(Shape, [VName], Maybe (Lambda SOACS, [SubExp]))]
arr_free' <- ((VName, (Shape, PrimType))
 -> ADM (Shape, [VName], Maybe (Lambda SOACS, [SubExp])))
-> [(VName, (Shape, PrimType))]
-> ADM [(Shape, [VName], Maybe (Lambda SOACS, [SubExp]))]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
forall (m :: * -> *) a b. Monad m => (a -> m b) -> [a] -> m [b]
mapM (VName, (Shape, PrimType))
-> ADM (Shape, [VName], Maybe (Lambda SOACS, [SubExp]))
forall {a}.
ArrayShape a =>
(VName, (a, PrimType))
-> ADM (a, [VName], Maybe (Lambda SOACS, [SubExp]))
withAccInput [(VName, (Shape, PrimType))]
arr_free
      -- We only consider those input arrays that are also not free in
      -- the lambda.
      let as_nonfree :: [VName]
as_nonfree = (VName -> Bool) -> [VName] -> [VName]
forall a. (a -> Bool) -> [a] -> [a]
filter (VName -> [VName] -> Bool
forall (t :: * -> *) a. (Foldable t, Eq a) => a -> t a -> Bool
`notElem` [VName]
free) [VName]
as
      ([VName]
arr_adjs, [VName]
acc_adjs, [VName]
rest_adjs) <-
        ([VName] -> ([VName], [VName], [VName]))
-> ADM [VName] -> ADM ([VName], [VName], [VName])
forall a b. (a -> b) -> ADM a -> ADM b
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap (Int -> Int -> [VName] -> ([VName], [VName], [VName])
forall a. Int -> Int -> [a] -> ([a], [a], [a])
splitAt3 ([(VName, (Shape, PrimType))] -> Int
forall a. [a] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [(VName, (Shape, PrimType))]
arr_free) ([VName] -> Int
forall a. [a] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [VName]
acc_free)) (ADM [VName] -> ADM ([VName], [VName], [VName]))
-> (([VName] -> ADM Result) -> ADM [VName])
-> ([VName] -> ADM Result)
-> ADM ([VName], [VName], [VName])
forall b c a. (b -> c) -> (a -> b) -> a -> c
. [(Shape, [VName], Maybe (Lambda SOACS, [SubExp]))]
-> ([VName] -> ADM Result) -> ADM [VName]
withAcc [(Shape, [VName], Maybe (Lambda SOACS, [SubExp]))]
arr_free' (([VName] -> ADM Result) -> ADM ([VName], [VName], [VName]))
-> ([VName] -> ADM Result) -> ADM ([VName], [VName], [VName])
forall a b. (a -> b) -> a -> b
$ \[VName]
accs -> do
          (VName -> VName -> ADM ()) -> [VName] -> [VName] -> ADM ()
forall (m :: * -> *) a b c.
Applicative m =>
(a -> b -> m c) -> [a] -> [b] -> m ()
zipWithM_ VName -> VName -> ADM ()
insAdj (((VName, (Shape, PrimType)) -> VName)
-> [(VName, (Shape, PrimType))] -> [VName]
forall a b. (a -> b) -> [a] -> [b]
map (VName, (Shape, PrimType)) -> VName
forall a b. (a, b) -> a
fst [(VName, (Shape, PrimType))]
arr_free) [VName]
accs
          () <- [VName] -> Names -> ADM ()
m ([VName]
acc_free [VName] -> [VName] -> [VName]
forall a. [a] -> [a] -> [a]
++ ((VName, (Shape, PrimType)) -> VName)
-> [(VName, (Shape, PrimType))] -> [VName]
forall a b. (a -> b) -> [a] -> [b]
map (VName, (Shape, PrimType)) -> VName
forall a b. (a, b) -> a
fst [(VName, (Shape, PrimType))]
arr_free) ([VName] -> Names
namesFromList [VName]
nonacc_free)
          [VName]
acc_free_adj <- (VName -> ADM VName) -> [VName] -> ADM [VName]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
forall (m :: * -> *) a b. Monad m => (a -> m b) -> [a] -> m [b]
mapM VName -> ADM VName
lookupAdjVal [VName]
acc_free
          [VName]
arr_free_adj <- ((VName, (Shape, PrimType)) -> ADM VName)
-> [(VName, (Shape, PrimType))] -> ADM [VName]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
forall (m :: * -> *) a b. Monad m => (a -> m b) -> [a] -> m [b]
mapM (VName -> ADM VName
lookupAdjVal (VName -> ADM VName)
-> ((VName, (Shape, PrimType)) -> VName)
-> (VName, (Shape, PrimType))
-> ADM VName
forall b c a. (b -> c) -> (a -> b) -> a -> c
. (VName, (Shape, PrimType)) -> VName
forall a b. (a, b) -> a
fst) [(VName, (Shape, PrimType))]
arr_free
          [VName]
nonacc_free_adj <- (VName -> ADM VName) -> [VName] -> ADM [VName]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
forall (m :: * -> *) a b. Monad m => (a -> m b) -> [a] -> m [b]
mapM VName -> ADM VName
lookupAdjVal [VName]
nonacc_free
          [VName]
as_nonfree_adj <- (VName -> ADM VName) -> [VName] -> ADM [VName]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
forall (m :: * -> *) a b. Monad m => (a -> m b) -> [a] -> m [b]
mapM VName -> ADM VName
lookupAdjVal [VName]
as_nonfree
          Result -> ADM Result
forall a. a -> ADM a
forall (f :: * -> *) a. Applicative f => a -> f a
pure (Result -> ADM Result) -> Result -> ADM Result
forall a b. (a -> b) -> a -> b
$ [VName] -> Result
varsRes ([VName] -> Result) -> [VName] -> Result
forall a b. (a -> b) -> a -> b
$ [VName]
arr_free_adj [VName] -> [VName] -> [VName]
forall a. Semigroup a => a -> a -> a
<> [VName]
acc_free_adj [VName] -> [VName] -> [VName]
forall a. Semigroup a => a -> a -> a
<> [VName]
nonacc_free_adj [VName] -> [VName] -> [VName]
forall a. Semigroup a => a -> a -> a
<> [VName]
as_nonfree_adj
      (VName -> VName -> ADM ()) -> [VName] -> [VName] -> ADM ()
forall (m :: * -> *) a b c.
Applicative m =>
(a -> b -> m c) -> [a] -> [b] -> m ()
zipWithM_ VName -> VName -> ADM ()
insAdj [VName]
acc_free [VName]
acc_adjs
      (VName -> VName -> ADM ()) -> [VName] -> [VName] -> ADM ()
forall (m :: * -> *) a b c.
Applicative m =>
(a -> b -> m c) -> [a] -> [b] -> m ()
zipWithM_ VName -> VName -> ADM ()
insAdj (((VName, (Shape, PrimType)) -> VName)
-> [(VName, (Shape, PrimType))] -> [VName]
forall a b. (a -> b) -> [a] -> [b]
map (VName, (Shape, PrimType)) -> VName
forall a b. (a, b) -> a
fst [(VName, (Shape, PrimType))]
arr_free) [VName]
arr_adjs
      let ([VName]
nonacc_adjs, [VName]
as_nonfree_adjs) = Int -> [VName] -> ([VName], [VName])
forall a. Int -> [a] -> ([a], [a])
splitAt ([VName] -> Int
forall a. [a] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [VName]
nonacc_free) [VName]
rest_adjs
      (VName -> VName -> ADM ()) -> [VName] -> [VName] -> ADM ()
forall (m :: * -> *) a b c.
Applicative m =>
(a -> b -> m c) -> [a] -> [b] -> m ()
zipWithM_ VName -> VName -> ADM ()
insAdj [VName]
nonacc_free [VName]
nonacc_adjs
      (VName -> VName -> ADM ()) -> [VName] -> [VName] -> ADM ()
forall (m :: * -> *) a b c.
Applicative m =>
(a -> b -> m c) -> [a] -> [b] -> m ()
zipWithM_ VName -> VName -> ADM ()
insAdj [VName]
as_nonfree [VName]
as_nonfree_adjs

    freeContrib :: VName -> VName -> ADM ()
freeContrib VName
v VName
contribs = do
      Type
contribs_t <- VName -> ADM Type
forall rep (m :: * -> *). HasScope rep m => VName -> m Type
lookupType VName
contribs
      case Type -> Type
forall u. TypeBase Shape u -> TypeBase Shape u
rowType Type
contribs_t of
        Acc {} -> ADM () -> ADM ()
forall (f :: * -> *) a. Functor f => f a -> f ()
void (ADM () -> ADM ()) -> ADM () -> ADM ()
forall a b. (a -> b) -> a -> b
$ VName -> VName -> ADM ()
insAdj VName
v VName
contribs
        Type
t -> do
          Lambda SOACS
lam <- Type -> ADM (Lambda SOACS)
addLambda Type
t
          SubExp
zero <- String -> Exp (Rep ADM) -> ADM SubExp
forall (m :: * -> *).
MonadBuilder m =>
String -> Exp (Rep m) -> m SubExp
letSubExp String
"zero" (Exp (Rep ADM) -> ADM SubExp) -> Exp (Rep ADM) -> ADM SubExp
forall a b. (a -> b) -> a -> b
$ Type -> Exp SOACS
forall rep. Type -> Exp rep
zeroExp Type
t
          ScremaForm SOACS
reduce <- [Reduce SOACS] -> ADM (ScremaForm SOACS)
forall rep (m :: * -> *).
(Buildable rep, MonadFreshNames m) =>
[Reduce rep] -> m (ScremaForm rep)
reduceSOAC [Commutativity -> Lambda SOACS -> [SubExp] -> Reduce SOACS
forall rep. Commutativity -> Lambda rep -> [SubExp] -> Reduce rep
Reduce Commutativity
Commutative Lambda SOACS
lam [SubExp
zero]]
          VName
contrib_sum <-
            String -> Exp (Rep ADM) -> ADM VName
forall (m :: * -> *).
MonadBuilder m =>
String -> Exp (Rep m) -> m VName
letExp (VName -> String
baseString VName
v String -> String -> String
forall a. Semigroup a => a -> a -> a
<> String
"_contrib_sum") (Exp SOACS -> ADM VName)
-> (SOAC SOACS -> Exp SOACS) -> SOAC SOACS -> ADM VName
forall b c a. (b -> c) -> (a -> b) -> a -> c
. Op SOACS -> Exp SOACS
SOAC SOACS -> Exp SOACS
forall rep. Op rep -> Exp rep
Op (SOAC SOACS -> ADM VName) -> SOAC SOACS -> ADM VName
forall a b. (a -> b) -> a -> b
$
              SubExp -> [VName] -> ScremaForm SOACS -> SOAC SOACS
forall rep. SubExp -> [VName] -> ScremaForm rep -> SOAC rep
Screma SubExp
w [VName
contribs] ScremaForm SOACS
reduce
          ADM () -> ADM ()
forall (f :: * -> *) a. Functor f => f a -> f ()
void (ADM () -> ADM ()) -> ADM () -> ADM ()
forall a b. (a -> b) -> a -> b
$ VName -> VName -> ADM ()
updateAdj VName
v VName
contrib_sum