{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE TypeFamilies #-}

module Futhark.AD.Rev.Map (vjpMap) where

import Control.Monad
import Data.Bifunctor (first)
import Futhark.AD.Rev.Monad
import Futhark.Analysis.PrimExp.Convert
import Futhark.Builder
import Futhark.IR.SOACS
import Futhark.Tools
import Futhark.Transform.Rename
import Futhark.Util (splitAt3)

-- | A classification of a free variable based on its adjoint (see
-- 'classifyAdjVars', which decides by looking at the type of the
-- adjoint).  The 'VName' stored is *not* the adjoint, but the primal
-- variable.
data AdjVar
  = -- | Adjoint is already an accumulator.
    FreeAcc VName
  | -- | Currently has no accumulator adjoint, but should be given one;
    -- the adjoint is an array with this shape and element type.
    FreeArr VName Shape PrimType
  | -- | Does not need an accumulator adjoint (might still be an array).
    FreeNonAcc VName

-- | Classify each of the given primal variables by the type of its
-- adjoint: an array-typed adjoint yields 'FreeArr' (carrying the
-- shape and element type of that array), an accumulator-typed adjoint
-- yields 'FreeAcc', and anything else yields 'FreeNonAcc'.  The order
-- of the input list is preserved.
classifyAdjVars :: [VName] -> ADM [AdjVar]
classifyAdjVars = mapM f
  where
    f v = do
      -- Look up the adjoint of the primal variable, and then the type
      -- of that adjoint; only the type decides the classification.
      v_adj <- lookupAdjVal v
      v_adj_t <- lookupType v_adj
      case v_adj_t of
        Array pt shape _ ->
          pure $ FreeArr v shape pt
        Acc {} ->
          pure $ FreeAcc v
        _ ->
          pure $ FreeNonAcc v

-- | Split classified variables into three lists, preserving relative
-- order within each: the 'FreeArr' variables (paired with their shape
-- and element type), the 'FreeAcc' variables, and the 'FreeNonAcc'
-- variables.
partitionAdjVars :: [AdjVar] -> ([(VName, (Shape, PrimType))], [VName], [VName])
partitionAdjVars [] = ([], [], [])
partitionAdjVars (fv : fvs) =
  case fv of
    FreeArr v shape t -> ((v, (shape, t)) : xs, ys, zs)
    FreeAcc v -> (xs, v : ys, zs)
    FreeNonAcc v -> (xs, ys, v : zs)
  where
    -- Partition of the tail; the head is consed onto exactly one of
    -- these three lists.
    (xs, ys, zs) = partitionAdjVars fvs

-- | Like 'buildBody', but the constructed body is additionally passed
-- through 'renameBody' before being returned.  The auxiliary value
-- produced by the builder action is returned unchanged.
buildRenamedBody ::
  MonadBuilder m =>
  m (Result, a) ->
  m (Body (Rep m), a)
buildRenamedBody m = do
  (body, x) <- buildBody m
  body' <- renameBody body
  pure (body', x)

-- | Construct a 'WithAcc' expression over the given inputs and bind
-- its results, returning the names bound.
--
-- The second argument produces the body of the 'WithAcc' lambda; it
-- is handed the names of the accumulator parameters, one per input
-- and in the same order as the inputs.
--
-- When there are no inputs, no 'WithAcc' is constructed: the body is
-- run directly with an empty list and each of its results is bound to
-- a fresh name.
withAcc ::
  [(Shape, [VName], Maybe (Lambda SOACS, [SubExp]))] ->
  ([VName] -> ADM Result) ->
  ADM [VName]
withAcc [] m =
  mapM (letExp "withacc_res" . BasicOp . SubExp . resSubExp) =<< m []
withAcc inputs m = do
  -- For every input, create a unit-typed certificate parameter and an
  -- accumulator parameter tied to that certificate.
  (cert_params, acc_params) <- fmap unzip $
    forM inputs $ \(shape, arrs, _) -> do
      cert_param <- newParam "acc_cert_p" $ Prim Unit
      -- The accumulator element types are the types of the arrays
      -- with as many outer dimensions stripped as the rank of 'shape'.
      ts <- mapM (fmap (stripArray (shapeRank shape)) . lookupType) arrs
      acc_param <-
        newParam "acc_p" $
          Acc (paramName cert_param) shape ts NoUniqueness
      pure (cert_param, acc_param)
  -- The lambda takes all certificates first, then all accumulators.
  acc_lam <-
    subAD $
      mkLambda (cert_params ++ acc_params) $
        m $ map paramName acc_params
  -- NOTE: the name hint is spelled "withhacc_res" in the original and
  -- is reproduced as-is.
  letTupExp "withhacc_res" $ WithAcc inputs acc_lam

-- | Generate return sweep code for a map.  The arguments are, in
-- order: the VJP operations (passed on to 'vjpLambda'), the adjoints
-- of the results of the map, the auxiliary statement information of
-- the original statement, the width of the map, the map function, and
-- the input arrays.
--
-- This first equation handles the special case where every result
-- adjoint is sparse with shape exactly @[w]@.  In that case the map
-- function is differentiated once per nonzero position instead of
-- constructing a new map; the statement auxiliary is not used.  All
-- other cases fall through to the next equation.
vjpMap :: VjpOps -> [Adj] -> StmAux () -> SubExp -> Lambda SOACS -> [VName] -> ADM ()
vjpMap ops res_adjs _ w map_lam as
  | Just res_ivs <- mapM isSparse res_adjs = returnSweepCode $ do
      -- Since at most a constant number of adjoints are nonzero
      -- (length res_ivs), there is no need for the return sweep code to
      -- contain a Map at all.

      free <- filterM isActive $ namesToList $ freeIn map_lam
      free_ts <- mapM lookupType free
      let adjs_for = map paramName (lambdaParams map_lam) ++ free
          adjs_ts = map paramType (lambdaParams map_lam) ++ free_ts

      let -- Adjoints for the results of the map function where only
          -- result number 'res_i' is nonzero (namely 'adj_v').
          oneHot res_i adj_v = zipWith f [0 :: Int ..] $ lambdaReturnType map_lam
            where
              f j t
                | res_i == j = adj_v
                | otherwise = AdjZero (arrayShape t) (elemType t)
          -- Values for the out-of-bounds case do not matter, as we will
          -- be writing to an out-of-bounds index anyway, which is ignored.
          ooBounds adj_i = subAD . buildRenamedBody $ do
            forM_ (zip as adjs_ts) $ \(a, t) -> do
              scratch <- letSubExp "oo_scratch" =<< eBlank t
              updateAdjIndex a (OutOfBounds, adj_i) scratch
            first subExpsRes . adjsReps <$> mapM lookupAdj as
          -- Body that differentiates a single application of the map
          -- function at index 'adj_i': the lambda parameters are bound
          -- to the elements of the input arrays at that index, and the
          -- resulting contributions are added to the adjoints of the
          -- input arrays at the same index.
          inBounds res_i adj_i adj_v = subAD . buildRenamedBody $ do
            forM_ (zip (lambdaParams map_lam) as) $ \(p, a) -> do
              a_t <- lookupType a
              letBindNames [paramName p] $
                BasicOp $ Index a $ fullSlice a_t [DimFix adj_i]
            adj_elems <-
              fmap (map resSubExp) . bodyBind . lambdaBody
                =<< vjpLambda ops (oneHot res_i (AdjVal adj_v)) adjs_for map_lam
            forM_ (zip as adj_elems) $ \(a, a_adj_elem) -> do
              updateAdjIndex a (AssumeBounds, adj_i) a_adj_elem
            first subExpsRes . adjsReps <$> mapM lookupAdj as

          -- Generate an iteration of the map function for every
          -- position.  This is a bit inefficient - probably we could do
          -- some deduplication.
          forPos res_i (check, adj_i, adj_v) = do
            as_adj <-
              case check of
                CheckBounds b -> do
                  (obbranch, mkadjs) <- ooBounds adj_i
                  (ibbranch, _) <- inBounds res_i adj_i adj_v
                  -- Use the provided condition if there is one;
                  -- otherwise check 'adj_i' against the map width.
                  fmap mkadjs . letTupExp' "map_adj_elem"
                    =<< eIf
                      (maybe (eDimInBounds (eSubExp w) (eSubExp adj_i)) eSubExp b)
                      (pure ibbranch)
                      (pure obbranch)
                AssumeBounds -> do
                  (body, mkadjs) <- inBounds res_i adj_i adj_v
                  mkadjs . map resSubExp <$> bodyBind body
                OutOfBounds ->
                  mapM lookupAdj as

            zipWithM setAdj as as_adj

          -- Generate an iteration of the map function for every result.
          forRes res_i = mapM_ (forPos res_i)

      zipWithM_ forRes [0 ..] res_ivs
  where
    -- The index/value triples of a sparse adjoint, provided its shape
    -- is exactly the width of the map.
    isSparse (AdjSparse (Sparse shape _ ivs)) = do
      guard $ shapeDims shape == [w]
      Just ivs
    isSparse _ =
      Nothing
-- See Note [Adjoints of accumulators] for how we deal with
-- accumulators - it's a bit tricky here.
vjpMap VjpOps
ops [Adj]
pat_adj StmAux ()
aux SubExp
w Lambda SOACS
map_lam [VName]
as = ADM () -> ADM ()
forall a. ADM a -> ADM a
returnSweepCode (ADM () -> ADM ()) -> ADM () -> ADM ()
forall a b. (a -> b) -> a -> b
$ do
  [VName]
pat_adj_vals <- [(Adj, Type)] -> ((Adj, Type) -> ADM VName) -> ADM [VName]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
t a -> (a -> m b) -> m (t b)
forM ([Adj] -> [Type] -> [(Adj, Type)]
forall a b. [a] -> [b] -> [(a, b)]
zip [Adj]
pat_adj (Lambda SOACS -> [Type]
forall rep. Lambda rep -> [Type]
lambdaReturnType Lambda SOACS
map_lam)) (((Adj, Type) -> ADM VName) -> ADM [VName])
-> ((Adj, Type) -> ADM VName) -> ADM [VName]
forall a b. (a -> b) -> a -> b
$ \(Adj
adj, Type
t) ->
    case Type
t of
      Acc {} -> String -> Exp (Rep ADM) -> ADM VName
forall (m :: * -> *).
MonadBuilder m =>
String -> Exp (Rep m) -> m VName
letExp String
"acc_adj_rep" (Exp SOACS -> ADM VName)
-> (VName -> Exp SOACS) -> VName -> ADM VName
forall b c a. (b -> c) -> (a -> b) -> a -> c
. BasicOp -> Exp SOACS
forall rep. BasicOp -> Exp rep
BasicOp (BasicOp -> Exp SOACS) -> (VName -> BasicOp) -> VName -> Exp SOACS
forall b c a. (b -> c) -> (a -> b) -> a -> c
. Shape -> SubExp -> BasicOp
Replicate ([SubExp] -> Shape
forall d. [d] -> ShapeBase d
Shape [SubExp
w]) (SubExp -> BasicOp) -> (VName -> SubExp) -> VName -> BasicOp
forall b c a. (b -> c) -> (a -> b) -> a -> c
. VName -> SubExp
Var (VName -> ADM VName) -> ADM VName -> ADM VName
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< Adj -> ADM VName
adjVal Adj
adj
      Type
_ -> Adj -> ADM VName
adjVal Adj
adj
  [Param Type]
pat_adj_params <-
    (VName -> ADM (Param Type)) -> [VName] -> ADM [Param Type]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM (String -> Type -> ADM (Param Type)
forall (m :: * -> *) dec.
MonadFreshNames m =>
String -> dec -> m (Param dec)
newParam String
"map_adj_p" (Type -> ADM (Param Type))
-> (Type -> Type) -> Type -> ADM (Param Type)
forall b c a. (b -> c) -> (a -> b) -> a -> c
. Type -> Type
forall u. TypeBase Shape u -> TypeBase Shape u
rowType (Type -> ADM (Param Type))
-> (VName -> ADM Type) -> VName -> ADM (Param Type)
forall (m :: * -> *) b c a.
Monad m =>
(b -> m c) -> (a -> m b) -> a -> m c
<=< VName -> ADM Type
forall rep (m :: * -> *). HasScope rep m => VName -> m Type
lookupType) [VName]
pat_adj_vals

  Lambda SOACS
map_lam' <- Lambda SOACS -> ADM (Lambda SOACS)
forall rep (m :: * -> *).
(Renameable rep, MonadFreshNames m) =>
Lambda rep -> m (Lambda rep)
renameLambda Lambda SOACS
map_lam
  [VName]
free <- (VName -> ADM Bool) -> [VName] -> ADM [VName]
forall (m :: * -> *) a.
Applicative m =>
(a -> m Bool) -> [a] -> m [a]
filterM VName -> ADM Bool
isActive ([VName] -> ADM [VName]) -> [VName] -> ADM [VName]
forall a b. (a -> b) -> a -> b
$ Names -> [VName]
namesToList (Names -> [VName]) -> Names -> [VName]
forall a b. (a -> b) -> a -> b
$ Lambda SOACS -> Names
forall a. FreeIn a => a -> Names
freeIn Lambda SOACS
map_lam'

  [VName] -> ([VName] -> Names -> ADM ()) -> ADM ()
accAdjoints [VName]
free (([VName] -> Names -> ADM ()) -> ADM ())
-> ([VName] -> Names -> ADM ()) -> ADM ()
forall a b. (a -> b) -> a -> b
$ \[VName]
free_with_adjs Names
free_without_adjs -> do
    [VName]
free_adjs <- (VName -> ADM VName) -> [VName] -> ADM [VName]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM VName -> ADM VName
lookupAdjVal [VName]
free_with_adjs
    [Type]
free_adjs_ts <- (VName -> ADM Type) -> [VName] -> ADM [Type]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM VName -> ADM Type
forall rep (m :: * -> *). HasScope rep m => VName -> m Type
lookupType [VName]
free_adjs
    [Param Type]
free_adjs_params <- (Type -> ADM (Param Type)) -> [Type] -> ADM [Param Type]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM (String -> Type -> ADM (Param Type)
forall (m :: * -> *) dec.
MonadFreshNames m =>
String -> dec -> m (Param dec)
newParam String
"free_adj_p") [Type]
free_adjs_ts
    let lam_rev_params :: [Param Type]
lam_rev_params =
          Lambda SOACS -> [LParam SOACS]
forall rep. Lambda rep -> [LParam rep]
lambdaParams Lambda SOACS
map_lam' [Param Type] -> [Param Type] -> [Param Type]
forall a. [a] -> [a] -> [a]
++ [Param Type]
pat_adj_params [Param Type] -> [Param Type] -> [Param Type]
forall a. [a] -> [a] -> [a]
++ [Param Type]
free_adjs_params
        adjs_for :: [VName]
adjs_for = (Param Type -> VName) -> [Param Type] -> [VName]
forall a b. (a -> b) -> [a] -> [b]
map Param Type -> VName
forall dec. Param dec -> VName
paramName (Lambda SOACS -> [LParam SOACS]
forall rep. Lambda rep -> [LParam rep]
lambdaParams Lambda SOACS
map_lam') [VName] -> [VName] -> [VName]
forall a. [a] -> [a] -> [a]
++ [VName]
free
    Lambda SOACS
lam_rev <-
      [LParam (Rep ADM)] -> ADM Result -> ADM (Lambda (Rep ADM))
forall (m :: * -> *).
MonadBuilder m =>
[LParam (Rep m)] -> m Result -> m (Lambda (Rep m))
mkLambda [Param Type]
[LParam (Rep ADM)]
lam_rev_params (ADM Result -> ADM (Lambda SOACS))
-> (ADM Result -> ADM Result) -> ADM Result -> ADM (Lambda SOACS)
forall b c a. (b -> c) -> (a -> b) -> a -> c
. ADM Result -> ADM Result
forall a. ADM a -> ADM a
subAD (ADM Result -> ADM Result)
-> (ADM Result -> ADM Result) -> ADM Result -> ADM Result
forall b c a. (b -> c) -> (a -> b) -> a -> c
. Names -> ADM Result -> ADM Result
forall a. Names -> ADM a -> ADM a
noAdjsFor Names
free_without_adjs (ADM Result -> ADM (Lambda SOACS))
-> ADM Result -> ADM (Lambda SOACS)
forall a b. (a -> b) -> a -> b
$ do
        (VName -> VName -> ADM ()) -> [VName] -> [VName] -> ADM ()
forall (m :: * -> *) a b c.
Applicative m =>
(a -> b -> m c) -> [a] -> [b] -> m ()
zipWithM_ VName -> VName -> ADM ()
insAdj [VName]
free_with_adjs ([VName] -> ADM ()) -> [VName] -> ADM ()
forall a b. (a -> b) -> a -> b
$ (Param Type -> VName) -> [Param Type] -> [VName]
forall a b. (a -> b) -> [a] -> [b]
map Param Type -> VName
forall dec. Param dec -> VName
paramName [Param Type]
free_adjs_params
        Body SOACS -> ADM Result
forall (m :: * -> *). MonadBuilder m => Body (Rep m) -> m Result
bodyBind (Body SOACS -> ADM Result)
-> (Lambda SOACS -> Body SOACS) -> Lambda SOACS -> ADM Result
forall b c a. (b -> c) -> (a -> b) -> a -> c
. Lambda SOACS -> Body SOACS
forall rep. Lambda rep -> Body rep
lambdaBody
          (Lambda SOACS -> ADM Result) -> ADM (Lambda SOACS) -> ADM Result
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< VjpOps -> [Adj] -> [VName] -> Lambda SOACS -> ADM (Lambda SOACS)
vjpLambda VjpOps
ops ((Param Type -> Adj) -> [Param Type] -> [Adj]
forall a b. (a -> b) -> [a] -> [b]
map Param Type -> Adj
forall t. Param t -> Adj
adjFromParam [Param Type]
pat_adj_params) [VName]
adjs_for Lambda SOACS
map_lam'

    ([VName]
param_contribs, [VName]
free_contribs) <-
      ([VName] -> ([VName], [VName]))
-> ADM [VName] -> ADM ([VName], [VName])
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap (Int -> [VName] -> ([VName], [VName])
forall a. Int -> [a] -> ([a], [a])
splitAt ([Param Type] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length (Lambda SOACS -> [LParam SOACS]
forall rep. Lambda rep -> [LParam rep]
lambdaParams Lambda SOACS
map_lam'))) (ADM [VName] -> ADM ([VName], [VName]))
-> ADM [VName] -> ADM ([VName], [VName])
forall a b. (a -> b) -> a -> b
$
        StmAux () -> ADM [VName] -> ADM [VName]
forall (m :: * -> *) anyrep a.
MonadBuilder m =>
StmAux anyrep -> m a -> m a
auxing StmAux ()
aux (ADM [VName] -> ADM [VName])
-> (SOAC SOACS -> ADM [VName]) -> SOAC SOACS -> ADM [VName]
forall b c a. (b -> c) -> (a -> b) -> a -> c
. String -> Exp (Rep ADM) -> ADM [VName]
forall (m :: * -> *).
MonadBuilder m =>
String -> Exp (Rep m) -> m [VName]
letTupExp String
"map_adjs" (Exp SOACS -> ADM [VName])
-> (SOAC SOACS -> Exp SOACS) -> SOAC SOACS -> ADM [VName]
forall b c a. (b -> c) -> (a -> b) -> a -> c
. SOAC SOACS -> Exp SOACS
forall rep. Op rep -> Exp rep
Op (SOAC SOACS -> ADM [VName]) -> SOAC SOACS -> ADM [VName]
forall a b. (a -> b) -> a -> b
$
          SubExp -> [VName] -> ScremaForm SOACS -> SOAC SOACS
forall rep. SubExp -> [VName] -> ScremaForm rep -> SOAC rep
Screma SubExp
w ([VName]
as [VName] -> [VName] -> [VName]
forall a. [a] -> [a] -> [a]
++ [VName]
pat_adj_vals [VName] -> [VName] -> [VName]
forall a. [a] -> [a] -> [a]
++ [VName]
free_adjs) (Lambda SOACS -> ScremaForm SOACS
forall rep. Lambda rep -> ScremaForm rep
mapSOAC Lambda SOACS
lam_rev)

    -- Crucial that we handle the free contribs first in case 'free'
    -- and 'as' intersect.
    (VName -> VName -> ADM ()) -> [VName] -> [VName] -> ADM ()
forall (m :: * -> *) a b c.
Applicative m =>
(a -> b -> m c) -> [a] -> [b] -> m ()
zipWithM_ VName -> VName -> ADM ()
freeContrib [VName]
free [VName]
free_contribs
    let param_ts :: [Type]
param_ts = (Param Type -> Type) -> [Param Type] -> [Type]
forall a b. (a -> b) -> [a] -> [b]
map Param Type -> Type
forall dec. Typed dec => Param dec -> Type
paramType (Lambda SOACS -> [LParam SOACS]
forall rep. Lambda rep -> [LParam rep]
lambdaParams Lambda SOACS
map_lam')
    [(Type, VName, VName)]
-> ((Type, VName, VName) -> ADM ()) -> ADM ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([Type] -> [VName] -> [VName] -> [(Type, VName, VName)]
forall a b c. [a] -> [b] -> [c] -> [(a, b, c)]
zip3 [Type]
param_ts [VName]
as [VName]
param_contribs) (((Type, VName, VName) -> ADM ()) -> ADM ())
-> ((Type, VName, VName) -> ADM ()) -> ADM ()
forall a b. (a -> b) -> a -> b
$ \(Type
param_t, VName
a, VName
param_contrib) ->
      case Type
param_t of
        Acc {} -> VName -> VName -> ADM ()
freeContrib VName
a VName
param_contrib
        Type
_ -> VName -> VName -> ADM ()
updateAdj VName
a VName
param_contrib
  where
    -- Prepend @n@ fresh parameters of type @int64@ (created with the
    -- base name "idx") to the parameter list of @lam@.  The lambda body
    -- and return type are left untouched.
    --
    -- The signature is specialised to the only use in this module
    -- ('accAddLambda'); the definition itself would also typecheck for
    -- any fresh-name monad and any representation whose lambda
    -- parameters carry a 'TypeBase'.
    addIdxParams :: Int -> Lambda SOACS -> ADM (Lambda SOACS)
    addIdxParams n lam = do
      idxs <- replicateM n $ newParam "idx" $ Prim int64
      pure $ lam {lambdaParams = idxs ++ lambdaParams lam}

    -- Build the lambda returned by @addLambda t@ and prepend @n@ index
    -- parameters to it via 'addIdxParams'.  The result is used as the
    -- combining operator of an accumulator in 'withAccInput', where @n@
    -- is the rank of the accumulated array.
    accAddLambda :: Int -> Type -> ADM (Lambda SOACS)
    accAddLambda n t = addIdxParams n =<< addLambda t

    -- Turn a free array variable @v@ (with the shape and element type
    -- of its adjoint) into one input triple for 'withAcc':
    --
    -- * the shape, passed through unchanged;
    -- * the current adjoint value of @v@, as a singleton list;
    -- * the combining lambda (one index parameter per dimension of the
    --   shape) paired with a freshly bound zero of the element type.
    withAccInput ::
      (VName, (Shape, PrimType)) ->
      ADM (Shape, [VName], Maybe (Lambda SOACS, [SubExp]))
    withAccInput (v, (shape, pt)) = do
      v_adj <- lookupAdjVal v
      add_lam <- accAddLambda (shapeRank shape) $ Prim pt
      zero <- letSubExp "zero" $ zeroExp $ Prim pt
      pure (shape, [v_adj], Just (add_lam, [zero]))

    -- Run @m@ in a context where the adjoints of the free array
    -- variables in @free@ have been replaced by accumulators.
    --
    -- The variables in @free@ are classified with 'classifyAdjVars' and
    -- split with 'partitionAdjVars' into those whose adjoint is an array
    -- (@arr_free@), those whose adjoint is already an accumulator
    -- (@acc_free@), and the rest (@nonacc_free@).  Inside 'withAcc',
    -- every array-adjoint variable is given the corresponding
    -- accumulator as its adjoint before @m@ is invoked with
    --
    -- * the variables that have accumulator adjoints (@acc_free@
    --   followed by the @arr_free@ names), and
    -- * the names of the remaining free variables.
    --
    -- The 'withAcc' body returns the adjoints in the order array,
    -- accumulator, non-accumulator, non-free inputs; afterwards they are
    -- split in that same order and re-inserted as the adjoints of the
    -- respective primal variables.
    accAdjoints :: [VName] -> ([VName] -> Names -> ADM ()) -> ADM ()
    accAdjoints free m = do
      (arr_free, acc_free, nonacc_free) <-
        partitionAdjVars <$> classifyAdjVars free
      arr_free' <- mapM withAccInput arr_free
      let arr_free_names = map fst arr_free
          -- We only consider those input arrays that are also not free in
          -- the lambda.
          as_nonfree = filter (`notElem` free) as
      (arr_adjs, acc_adjs, rest_adjs) <-
        fmap (splitAt3 (length arr_free) (length acc_free))
          . withAcc arr_free'
          $ \accs -> do
            zipWithM_ insAdj arr_free_names accs
            () <- m (acc_free ++ arr_free_names) (namesFromList nonacc_free)
            acc_free_adj <- mapM lookupAdjVal acc_free
            arr_free_adj <- mapM lookupAdjVal arr_free_names
            nonacc_free_adj <- mapM lookupAdjVal nonacc_free
            as_nonfree_adj <- mapM lookupAdjVal as_nonfree
            -- The order here must match the splitting done on the
            -- result below: arrays, accumulators, then the rest.
            pure . varsRes $
              arr_free_adj <> acc_free_adj <> nonacc_free_adj <> as_nonfree_adj
      zipWithM_ insAdj acc_free acc_adjs
      zipWithM_ insAdj arr_free_names arr_adjs
      let (nonacc_adjs, as_nonfree_adjs) =
            splitAt (length nonacc_free) rest_adjs
      zipWithM_ insAdj nonacc_free nonacc_adjs
      zipWithM_ insAdj as_nonfree as_nonfree_adjs

    -- Fold the per-iteration contributions @contribs@ into the adjoint
    -- of @v@.
    --
    -- If the row type of @contribs@ is an accumulator, @contribs@ simply
    -- becomes the adjoint of @v@.  Otherwise the contributions are
    -- summed with a commutative reduction of width @w@ (operator from
    -- 'addLambda', neutral element from 'zeroExp'), and the sum is added
    -- to the adjoint of @v@ with 'updateAdj'.
    freeContrib :: VName -> VName -> ADM ()
    freeContrib v contribs = do
      contribs_t <- lookupType contribs
      case rowType contribs_t of
        Acc {} -> insAdj v contribs
        t -> do
          lam <- addLambda t
          zero <- letSubExp "zero" $ zeroExp t
          reduce <- reduceSOAC [Reduce Commutative lam [zero]]
          contrib_sum <-
            letExp (baseString v <> "_contrib_sum") . Op $
              Screma w [contribs] reduce
          updateAdj v contrib_sum