{-# LANGUAGE TypeFamilies #-}

-- Naming scheme:
--
-- An adjoint-related object for "x" is named "x_adj".  This means
-- both actual adjoints and statements.
--
-- Do not assume "x'" means anything related to derivatives.
module Futhark.AD.Rev.Monad
  ( ADM,
    RState (..),
    runADM,
    Adj (..),
    InBounds (..),
    Sparse (..),
    adjFromParam,
    adjFromVar,
    lookupAdj,
    lookupAdjVal,
    adjVal,
    updateAdj,
    updateSubExpAdj,
    updateAdjSlice,
    updateAdjIndex,
    setAdj,
    insAdj,
    adjsReps,
    --
    copyConsumedArrsInStm,
    copyConsumedArrsInBody,
    addSubstitution,
    returnSweepCode,
    --
    adjVName,
    subAD,
    noAdjsFor,
    subSubsts,
    isActive,
    --
    tabNest,
    oneExp,
    zeroExp,
    unitAdjOfType,
    addLambda,
    --
    VjpOps (..),
    --
    setLoopTape,
    lookupLoopTape,
    substLoopTape,
    renameLoopTape,
  )
where

import Control.Monad
import Control.Monad.State.Strict
import Data.Bifunctor (second)
import Data.List (foldl')
import Data.Map qualified as M
import Data.Maybe
import Futhark.Analysis.Alias qualified as Alias
import Futhark.Analysis.PrimExp.Convert
import Futhark.Builder
import Futhark.IR.Aliases (consumedInStms)
import Futhark.IR.Prop.Aliases
import Futhark.IR.SOACS
import Futhark.Tools
import Futhark.Transform.Substitute
import Futhark.Util (chunks)

-- | An expression producing the blank value (as given by
-- 'blankPrimValue') of a primitive or array type.  For arrays, the
-- blank element is replicated across the whole shape.  Any other kind
-- of type is an internal error.
zeroExp :: Type -> Exp rep
zeroExp t = case t of
  Prim pt -> BasicOp . SubExp . Constant $ blankPrimValue pt
  Array pt shape _ -> BasicOp . Replicate shape . Constant $ blankPrimValue pt
  _ -> error $ "zeroExp: " ++ prettyString t

-- | The "one" value of a primitive type: 1 for integers and floats,
-- 'True' for booleans, and the unit value for 'Unit'.
onePrim :: PrimType -> PrimValue
onePrim pt = case pt of
  IntType it -> IntValue $ intValue it (1 :: Int)
  FloatType ft -> FloatValue $ floatValue ft (1 :: Double)
  Bool -> BoolValue True
  Unit -> UnitValue

-- | An expression producing the "one" value (see 'onePrim') of a
-- primitive or array type.  For arrays, the element is replicated
-- across the whole shape.  Any other kind of type is an internal
-- error.
oneExp :: Type -> Exp rep
oneExp t = case t of
  Prim pt -> BasicOp . SubExp . constant $ onePrim pt
  Array pt shape _ -> BasicOp . Replicate shape . Constant $ onePrim pt
  _ -> error $ "oneExp: " ++ prettyString t

-- | Whether 'Sparse' should check bounds or assume they are correct.
-- The latter results in simpler code.
data InBounds
  = -- | If a SubExp is provided, it references a boolean that is true
    -- when in-bounds.
    CheckBounds (Maybe SubExp)
  | -- | The index is trusted to be valid; no bounds check is emitted
    -- for the corresponding update.
    AssumeBounds
  | -- | Dynamically these will always fail, so don't bother
    -- generating code for the update.  This is only needed to ensure
    -- a consistent representation of sparse Jacobians.
    OutOfBounds
  deriving (Eq, Ord, Show)

-- | A symbolic representation of an array that is all zeroes, except
-- at certain indexes.
data Sparse = Sparse
  { -- | The shape of the array.
    sparseShape :: Shape,
    -- | Element type of the array.
    sparseType :: PrimType,
    -- | Locations and values of nonzero values, as triples of
    -- (bounds-checking mode, index, value).  Indexes may be
    -- negative, in which case the value is ignored (unless
    -- 'AssumeBounds' is used).
    sparseIdxVals :: [(InBounds, SubExp, SubExp)]
  }
  deriving (Eq, Ord, Show)

-- | The adjoint of a variable.  Besides a manifest value, an adjoint
-- may be kept in a symbolic form that is only turned into code when
-- needed (see 'adjVal').
data Adj
  = -- | Zero everywhere except at the listed indexes.
    AdjSparse Sparse
  | -- | A manifest adjoint.
    AdjVal SubExp
  | -- | An all-zero adjoint with the given shape and element type.
    AdjZero Shape PrimType
  deriving (Eq, Ord, Show)

-- | Only a manifest adjoint held in a variable is renamed; every other
-- adjoint is returned as-is.
--
-- NOTE(review): the 'SubExp's inside an 'AdjSparse' and the 'Shape'
-- of an 'AdjZero' are not substituted — confirm this is intended.
instance Substitute Adj where
  substituteNames m adj = case adj of
    AdjVal (Var v) -> AdjVal . Var $ substituteNames m v
    _ -> adj

-- | Bind a variable holding the blank value of @t@ (see 'zeroExp')
-- replicated over @shape@.  A rank-0 shape binds the blank value
-- directly.  Otherwise the replicate is marked with the @sequential@
-- attribute.
zeroArray :: MonadBuilder m => Shape -> Type -> m VName
zeroArray shape t
  | shapeRank shape /= 0 = do
      zero <- letSubExp "zero" $ zeroExp t
      attributing (oneAttr "sequential") . letExp "zeroes_" . BasicOp $
        Replicate shape zero
  | otherwise = letExp "zero" $ zeroExp t

-- | Manifest a 'Sparse' array.  Start from an all-zero array and
-- perform one in-place update per (check, index, value) entry, in
-- list order.  'AssumeBounds' entries use an 'Unsafe' update and
-- 'CheckBounds' entries a 'Safe' one (the optional flag is not
-- consulted here).  'OutOfBounds' entries generate no code.
sparseArray :: (MonadBuilder m, Rep m ~ SOACS) => Sparse -> m VName
sparseArray (Sparse shape t ivs) = do
  zeroes <- zeroArray shape (Prim t)
  foldM writeEntry zeroes ivs
  where
    arr_t = Prim t `arrayOfShape` shape
    update safety arr i se =
      letExp "sparse" . BasicOp $
        Update safety arr (fullSlice arr_t [DimFix i]) se
    writeEntry arr (check, i, se) =
      case check of
        AssumeBounds -> update Unsafe arr i se
        CheckBounds _ -> update Safe arr i se
        OutOfBounds -> pure arr

-- | A manifest adjoint stored in the given variable.
adjFromVar :: VName -> Adj
adjFromVar v = AdjVal (Var v)

-- | A manifest adjoint stored in the variable bound by the parameter.
adjFromParam :: Param t -> Adj
adjFromParam p = adjFromVar (paramName p)

-- | Bind the "one" value of the given type (see 'oneExp') and use it
-- as a manifest adjoint.
unitAdjOfType :: Type -> ADM Adj
unitAdjOfType t = do
  one <- letSubExp "adj_unit" $ oneExp t
  pure $ AdjVal one

-- | The values representing an adjoint in symbolic form.  This is
-- used for when we wish to return an Adj from a Body or similar
-- without forcing manifestation.  Also returns a function for
-- reassembling the Adj from a new representation (the list must have
-- the same length).
--
-- * An 'AdjVal' is represented by its single 'SubExp'.
--
-- * An 'AdjZero' is represented by no values at all.
--
-- * An 'AdjSparse' is represented by an index followed by a value for
--   each entry.  On reassembly, 'OutOfBounds' entries become
--   'CheckBounds' entries with a constant-false flag.
--
-- Applying the reassembly function to a list of the wrong length is
-- an internal error.  It is reported explicitly instead of silently
-- dropping (or ignoring) sparse entries.
adjRep :: Adj -> ([SubExp], [SubExp] -> Adj)
adjRep (AdjVal se) = ([se], rebuild)
  where
    rebuild [se'] = AdjVal se'
    rebuild ses = adjRepMismatch 1 ses
adjRep (AdjZero shape pt) = ([], rebuild)
  where
    rebuild [] = AdjZero shape pt
    rebuild ses = adjRepMismatch 0 ses
adjRep (AdjSparse (Sparse shape pt ivs)) = (concatMap ivRep ivs, rebuild)
  where
    expected = 2 * length ivs
    -- Validate up front so that a short list cannot truncate the
    -- sparse entries and a long one cannot be silently ignored.
    rebuild ses
      | length ses == expected = AdjSparse $ Sparse shape pt $ repIvs ivs ses
      | otherwise = adjRepMismatch expected ses
    ivRep (_, i, v) = [i, v]
    repIvs ((check, _, _) : ivs') (i : v : ses) =
      (check', i, v) : repIvs ivs' ses
      where
        check' = case check of
          AssumeBounds -> AssumeBounds
          CheckBounds b -> CheckBounds b
          -- Deliberate (marked "sic!" in the original): an OutOfBounds
          -- entry is rebuilt as a bounds-checked one.
          OutOfBounds -> CheckBounds (Just (constant False)) -- sic!
    repIvs _ _ = []

-- | Report that an 'adjRep' reassembly function was applied to a
-- representation with the wrong number of 'SubExp's.
adjRepMismatch :: Int -> [SubExp] -> a
adjRepMismatch expected ses =
  error $
    "adjRep: expected "
      <> show expected
      <> " SubExps when reassembling an adjoint, but got "
      <> show (length ses)

-- | Conveniently convert a list of Adjs to their representation, as
-- well as produce a function for converting back.  The reassembly
-- function splits its input into chunks of the original sizes and
-- hands each chunk to the matching per-adjoint reassembler from
-- 'adjRep'.
adjsReps :: [Adj] -> ([SubExp], [SubExp] -> [Adj])
adjsReps adjs = (concat reps, reassemble)
  where
    (reps, rebuilders) = unzip $ map adjRep adjs
    reassemble ses = zipWith ($) rebuilders $ chunks (map length reps) ses

-- | The state threaded through the 'ADM' monad.
data RState = RState
  { -- | The adjoint currently recorded for each variable.
    stateAdjs :: M.Map VName Adj,
    -- | Loop-tape substitutions (presumably maintained by
    -- 'setLoopTape' and related functions — not visible here).
    stateLoopTape :: Substitutions,
    -- | Substitutions applied by 'returnSweepCode'.
    stateSubsts :: Substitutions,
    -- | Source of fresh names.
    stateNameSource :: VNameSource
  }

-- | The reverse-mode AD monad: a SOACS code builder on top of a
-- strict state monad over 'RState'.  All instances are inherited
-- from the underlying 'BuilderT'.
newtype ADM a = ADM (BuilderT SOACS (State RState) a)
  deriving
    ( Functor,
      Applicative,
      Monad,
      MonadState RState,
      MonadFreshNames,
      HasScope SOACS,
      LocalScope SOACS
    )

-- | Every builder operation is delegated to the wrapped 'BuilderT'.
instance MonadBuilder ADM where
  type Rep ADM = SOACS
  mkExpDecM pat = ADM . mkExpDecM pat
  mkBodyM stms = ADM . mkBodyM stms
  mkLetNamesM names = ADM . mkLetNamesM names
  addStms stms = ADM (addStms stms)
  collectStms (ADM m) = ADM (collectStms m)
m

-- | Fresh names are drawn from the 'stateNameSource' field.
instance MonadFreshNames (State RState) where
  getNameSource = gets stateNameSource
  putNameSource src = modify (\s -> s {stateNameSource = src})

-- | Run an 'ADM' computation starting with no adjoints, loop tapes or
-- substitutions, and an empty scope.  Statements the computation adds
-- at the top level are discarded; only its result is returned.  The
-- updated name source is passed back to the caller.
runADM :: MonadFreshNames m => ADM a -> m a
runADM (ADM m) = modifyNameSource $ \src ->
  let initial =
        RState
          { stateAdjs = mempty,
            stateLoopTape = mempty,
            stateSubsts = mempty,
            stateNameSource = src
          }
      (x, final) = runState (fst <$> runBuilderT m mempty) initial
   in (x, stateNameSource final)

-- | Manifest an adjoint as a variable.  A manifest value is rebound,
-- a sparse adjoint is built with 'sparseArray', and a zero adjoint
-- with 'zeroArray'.
adjVal :: Adj -> ADM VName
adjVal adj = case adj of
  AdjVal se -> letExp "const_adj" . BasicOp $ SubExp se
  AdjSparse sparse -> sparseArray sparse
  AdjZero shape t -> zeroArray shape (Prim t)

-- | Set a specific adjoint, replacing any adjoint already recorded for
-- the variable.
setAdj :: VName -> Adj -> ADM ()
setAdj v v_adj = modify $ \s ->
  let adjs = stateAdjs s
   in s {stateAdjs = M.insert v v_adj adjs}

-- | Set an 'AdjVal' adjoint.  Simple wrapper around 'setAdj'.
insAdj :: VName -> VName -> ADM ()
insAdj v v_adj = setAdj v $ AdjVal $ Var v_adj

-- | A fresh name for the adjoint of a variable: its base name with an
-- @_adj@ suffix.
adjVName :: VName -> ADM VName
adjVName v = newVName $ baseString v <> "_adj"

-- | Create copies of all arrays consumed in the given statement, and
-- return statements which include copies of the consumed arrays.
--
-- The returned substitution maps each consumed array to its copy.
-- For each copy, 'addSubstitution' also records the reverse mapping
-- (copy to original).  Consumed non-array values are not copied.
--
-- See Note [Consumption].
copyConsumedArrsInStm :: Stm SOACS -> ADM (Substitutions, Stms SOACS)
copyConsumedArrsInStm s =
  inScopeOf s . collectStms $
    M.fromList . mconcat <$> mapM copyIfArray consumed
  where
    consumed =
      namesToList . consumedInStms . fst $
        Alias.analyseStms mempty (oneStm s)
    copyIfArray v = inScopeOf s $ do
      v_t <- lookupType v
      case v_t of
        Array {} -> do
          v' <- letExp (baseString v <> "_ad_copy") . BasicOp $ Copy v
          addSubstitution v' v
          pure [(v, v')]
        _ -> pure mempty

-- | Copy every array consumed by the body, except those listed in
-- @dontCopy@.  Returns a substitution from each copied array to its
-- copy.  Consumed non-array values are skipped; a consumed
-- accumulator is an internal error.
copyConsumedArrsInBody :: [VName] -> Body SOACS -> ADM Substitutions
copyConsumedArrsInBody dontCopy b =
  mconcat <$> mapM onConsumed toCopy
  where
    toCopy =
      filter (`notElem` dontCopy) . namesToList . consumedInBody $
        Alias.analyseBody mempty b
    onConsumed v = do
      v_t <- lookupType v
      case v_t of
        Acc {} -> error $ "copyConsumedArrsInBody: Acc " <> prettyString v
        Array {} ->
          M.singleton v
            <$> letExp (baseString v <> "_ad_copy") (BasicOp $ Copy v)
        _ -> pure mempty

-- | Run an action and re-emit the statements it produced, with the
-- substitutions recorded via 'addSubstitution' applied.  The
-- substitutions are read after the action has run, so any the action
-- itself adds are included.
returnSweepCode :: ADM a -> ADM a
returnSweepCode m = do
  (result, sweep) <- collectStms m
  substs <- gets stateSubsts
  addStms $ substituteNames substs sweep
  pure result

-- | Record that @v@ is to be replaced by @v'@ in code emitted through
-- 'returnSweepCode'.  This overrides any earlier substitution for @v@.
addSubstitution :: VName -> VName -> ADM ()
addSubstitution v v' = modify $ \s ->
  let substs = stateSubsts s
   in s {stateSubsts = M.insert v v' substs}

-- | While evaluating this action, pretend these variables have no
-- adjoints.  Restore their previous adjoints afterwards.  This is used
-- for handling certain nested operations.  XXX: feels like this should
-- really be part of subAD, somehow.  Main challenge is that we don't
-- want to blank out Accumulator adjoints.  Also, might be inefficient
-- to blank out array adjoints.
--
-- Each variable that had an adjoint beforehand gets exactly that
-- adjoint back, overriding whatever the action set for it.  An
-- adjoint that the action sets for a variable that previously had none
-- is kept.
noAdjsFor :: Names -> ADM a -> ADM a
noAdjsFor names m = do
  -- Keep each saved adjoint paired with its own variable.  Zipping a
  -- filtered list of adjoints against the full name list would
  -- misattribute adjoints whenever some name had no adjoint.
  old <- gets $ \env ->
    let adjs = stateAdjs env
     in [(v, v_adj) | v <- names', Just v_adj <- [M.lookup v adjs]]
  modify $ \env -> env {stateAdjs = foldl' (flip M.delete) (stateAdjs env) names'}
  x <- m
  -- 'Map' union via '<>' is left-biased, so the saved adjoints win.
  modify $ \env -> env {stateAdjs = M.fromList old <> stateAdjs env}
  pure x
  where
    names' = namesToList names

-- | The binary operator used to add two values of a primitive type.
-- Integer addition wraps on overflow.  Booleans and units are combined
-- with 'LogAnd'.
--
-- NOTE(review): 'LogAnd' for 'Bool' is presumably an arbitrary choice
-- because boolean adjoints carry no derivative information — confirm.
addBinOp :: PrimType -> BinOp
addBinOp pt = case pt of
  IntType it -> Add it OverflowWrap
  FloatType ft -> FAdd ft
  Bool -> LogAnd
  Unit -> LogAnd

-- | @tabNest k vs f@ wraps @f@ in @k@ nested @map@s over the arrays @vs@.
--
-- Each nesting level does three things:
--
-- * binds a fresh index parameter to an @iota@ of the outermost
--   dimension of @vs@;
-- * replaces every array in @vs@ by a parameter of its row type;
-- * descends one level.
--
-- At depth zero, @f is vs'@ is called. Here @is@ are the index names,
-- outermost level first, and @vs'@ are the innermost row names.
-- The result is the list of names bound by the outermost map. When
-- @k@ is zero, it is simply the result of @f [] vs@.
tabNest :: Int -> [VName] -> ([VName] -> [VName] -> ADM [VName]) -> ADM [VName]
tabNest depth arrs f = nest [] depth arrs
  where
    -- Indices are accumulated innermost-first, hence the reverse.
    nest idxs 0 xs = f (reverse idxs) xs
    nest idxs k xs = do
      w <- arraysSize 0 <$> mapM lookupType xs
      iota <-
        letExp "tab_iota" $
          BasicOp $
            Iota w (intConst Int64 0) (intConst Int64 1) Int64
      i_p <- newParam "i" $ Prim int64
      row_ps <- forM xs $ \x -> do
        x_t <- lookupType x
        newParam (baseString x <> "_p") $ rowType x_t
      -- Build the next level inside the scope of this level's
      -- parameters, capturing its statements as the lambda body.
      ((ret, res), stms) <-
        collectStms $
          localScope (scopeOfLParams (i_p : row_ps)) $ do
            inner <- nest (paramName i_p : idxs) (k - 1) (map paramName row_ps)
            inner_ts <- mapM lookupType inner
            pure (inner_ts, varsRes inner)
      let lam =
            Lambda
              { lambdaParams = i_p : row_ps,
                lambdaBody = Body () stms res,
                lambdaReturnType = ret
              }
      letTupExp "tab" $ Op $ Screma w (iota : xs) (mapSOAC lam)

-- | Construct a lambda for adding two values of the given type.
--
-- * For a primitive type, this is the 'addBinOp' operator.
-- * For an array type, the lambda takes parameters @xs@ and @ys@ of
--   that type. It returns the element-wise @map@ of the row-type adder
--   over them; the row-type adder is built by calling 'addLambda'
--   recursively.
--
-- Any other type is rejected with an error.
addLambda :: Type -> ADM (Lambda SOACS)
addLambda (Prim pt) = binOpLambda (addBinOp pt) pt
addLambda t@Array {} = do
  xs_p <- newParam "xs" t
  ys_p <- newParam "ys" t
  lam <- addLambda $ rowType t
  body <- insertStmsM $ do
    res <-
      letSubExp "lam_map" . Op $
        Screma (arraySize 0 t) [paramName xs_p, paramName ys_p] (mapSOAC lam)
    pure $ resultBody [res]
  pure
    Lambda
      { lambdaParams = [xs_p, ys_p],
        lambdaReturnType = [t],
        lambdaBody = body
      }
addLambda t =
  -- Render the type with 'prettyString' (as 'addExp' does) instead of
  -- the derived 'Show' instance, so the diagnostic is readable.
  error $ "addLambda: " ++ prettyString t

-- | Construct an expression that adds the variables @x@ and @y@.
-- Only @x@'s type is looked up; @y@ is assumed to have the same type.
--
-- * Scalars are combined with 'addBinOp'.
-- * Arrays are added element-wise by mapping 'addLambda' (for the row
--   type) over both arrays.
--
-- Any other type is an error.
addExp :: VName -> VName -> ADM (Exp SOACS)
addExp x y = build =<< lookupType x
  where
    build (Prim pt) =
      pure . BasicOp $ BinOp (addBinOp pt) (Var x) (Var y)
    build x_t@Array {} = do
      row_add <- addLambda $ rowType x_t
      pure . Op $ Screma (arraySize 0 x_t) [x, y] (mapSOAC row_add)
    build x_t =
      error $ "addExp: unexpected type: " ++ prettyString x_t

-- | Look up the adjoint recorded for @v@.
--
-- If none is recorded, return a symbolic zero ('AdjZero'):
--
-- * For an accumulator whose single element type is primitive, the
--   shape and element type come from the accumulator type.
-- * Otherwise they come from @v@'s array shape and element type.
lookupAdj :: VName -> ADM Adj
lookupAdj v = maybe zeroAdj pure =<< gets (M.lookup v . stateAdjs)
  where
    zeroAdj = do
      v_t <- lookupType v
      case v_t of
        Acc _ shape [Prim t] _ -> pure $ AdjZero shape t
        _ -> pure $ AdjZero (arrayShape v_t) (elemType v_t)

-- | Look up the adjoint of @v@ (see 'lookupAdj') and turn it into a
-- variable with 'adjVal'.
lookupAdjVal :: VName -> ADM VName
lookupAdjVal = adjVal <=< lookupAdj

-- | @updateAdj v d@ adds the contribution @d@ to the adjoint of @v@.
--
-- If no adjoint is recorded for @v@, @d@ simply becomes its adjoint.
-- Otherwise the existing adjoint is turned into a variable with
-- 'adjVal' and combined with @d@:
--
-- * an accumulator is updated with 'UpdateAcc', one scalar at a time;
-- * any other value is summed with 'addExp'.
updateAdj :: VName -> VName -> ADM ()
updateAdj v d =
  maybe (insAdj v d) combine =<< gets (M.lookup v . stateAdjs)
  where
    combine adj = do
      v_adj <- adjVal adj
      v_adj_t <- lookupType v_adj
      v_adj' <- case v_adj_t of
        Acc {} -> do
          d_dims <- arrayDims <$> lookupType d
          -- Nest one map per dimension of d, so that the innermost d_p
          -- is a scalar.
          ~[acc'] <-
            tabNest (length d_dims) [d, v_adj] $ \is [d_p, acc_p] ->
              letTupExp "acc" $ BasicOp $ UpdateAcc acc_p (map Var is) [Var d_p]
          pure acc'
        _ -> letExp (baseString v <> "_adj") =<< addExp v_adj d
      insAdj v v_adj'

-- | @updateAdjSlice slice v d@ adds @d@ to the part of @v@'s adjoint
-- selected by @slice@.
--
-- A slice consisting of a single fixed index is delegated to
-- 'updateAdjIndex', with bounds assumed.
updateAdjSlice :: Slice SubExp -> VName -> VName -> ADM ()
updateAdjSlice (Slice [DimFix i]) v d =
  updateAdjIndex v (AssumeBounds, i) (Var d)
updateAdjSlice slice v d = do
  v_t <- lookupType v
  v_adj <- lookupAdjVal v
  v_adj_t <- lookupType v_adj
  v_adj' <- case v_adj_t of
    Acc {} -> accumulate v_adj
    _ -> addInPlace v_t v_adj
  insAdj v v_adj'
  where
    -- Accumulator adjoint: one UpdateAcc per scalar of d. The map
    -- indices are translated back into indices of the full array with
    -- 'fixSlice'.
    accumulate acc = do
      ~[acc'] <-
        tabNest (length (sliceDims slice)) [d, acc] $ \is [d_p, acc_p] -> do
          idxs <-
            mapM (toSubExp "index") $
              fixSlice (pe64 <$> slice) (map le64 is)
          letTupExp (baseString acc_p) $
            BasicOp $
              UpdateAcc acc_p idxs [Var d_p]
      pure acc'

    -- Any other adjoint: read the slice (the whole value for a
    -- primitive primal), add d, and write the sum back in place.
    addInPlace v_t v_adj = do
      cur <-
        if primType v_t
          then pure v_adj
          else letExp (baseString v ++ "_slice") $ BasicOp $ Index v_adj slice
      letInPlace "updated_adj" v_adj slice =<< addExp cur d

-- | Add the contribution @d@ to the adjoint of a subexpression.
-- Constants carry no adjoint and are ignored; variables are handled by
-- 'updateAdj'.
updateSubExpAdj :: SubExp -> VName -> ADM ()
updateSubExpAdj se d =
  case se of
    Var v -> updateAdj v d
    Constant {} -> pure ()

-- | @updateAdjIndex v (check, i) se@ adds @se@ to element @i@ (along the
-- outermost dimension) of the adjoint of @v@.
--
-- While the adjoint is absent, zero or sparse, the update is only
-- recorded as a sparse entry @(check, i, se)@.
--
-- Once the adjoint is a materialised value, code is emitted instead:
--
-- * accumulators get an 'UpdateAcc';
-- * other values get a read-add-write of element @i@. The 'Update' is
--   'Safe' for 'CheckBounds' and 'Unsafe' for 'AssumeBounds'.
--
-- With 'OutOfBounds', a materialised adjoint is left unchanged.
--
-- NOTE(review): an earlier comment said "the index may be negative, in
-- which case the update has no effect". For sparse adjoints that
-- presumably depends on how they are materialised later (not visible
-- here) — confirm.
updateAdjIndex :: VName -> (InBounds, SubExp) -> SubExp -> ADM ()
updateAdjIndex v (check, i) se = do
  existing <- gets $ M.lookup v . stateAdjs
  v_t <- lookupType v
  let entry = (check, i, se)
      startSparse =
        setAdj v . AdjSparse $ Sparse (arrayShape v_t) (elemType v_t) [entry]
  case existing of
    Nothing -> startSparse
    Just AdjZero {} -> startSparse
    Just (AdjSparse (Sparse shape pt entries)) ->
      setAdj v . AdjSparse . Sparse shape pt $ entry : entries
    Just adj@AdjVal {} -> do
      v_adj <- adjVal adj
      v_adj_t <- lookupType v_adj
      se_v <- letExp "se_v" $ BasicOp $ SubExp se
      let at_i = fullSlice v_adj_t [DimFix i]
          -- Dense, non-accumulator adjoint: read element i, add se,
          -- and write the result back.
          addAt safety = do
            v_adj_i <-
              letExp (baseString v_adj <> "_i") $ BasicOp $ Index v_adj at_i
            se_update <- letSubExp "updated_adj_i" =<< addExp se_v v_adj_i
            letExp (baseString v_adj) $
              BasicOp $
                Update safety v_adj at_i se_update
          -- Accumulator adjoint: one UpdateAcc per scalar element of se.
          addToAcc = do
            se_dims <- arrayDims <$> lookupType se_v
            ~[acc'] <-
              tabNest (length se_dims) [se_v, v_adj] $ \is [se_p, acc_p] ->
                letTupExp "acc" $
                  BasicOp $
                    UpdateAcc acc_p (i : map Var is) [Var se_p]
            pure acc'
      v_adj' <- case (v_adj_t, check) of
        (_, OutOfBounds) -> pure v_adj
        (Acc {}, _) -> addToAcc
        (_, CheckBounds _) -> addAt Safe
        (_, AssumeBounds) -> addAt Unsafe
      insAdj v v_adj'

-- | Is this primal variable active in the AD sense?  Currently every
-- variable whose type is not 'Unit' counts as active.  FIXME: this is
-- (obviously) much too conservative.
isActive :: VName -> ADM Bool
isActive v = do
  v_t <- lookupType v
  pure $ v_t /= Prim Unit

-- | Ignore any changes to adjoints made while evaluating this action.
-- The adjoint map ('stateAdjs') is saved before running the action and
-- restored afterwards. Nothing else in the state is rolled back.
subAD :: ADM a -> ADM a
subAD m = do
  saved_adjs <- gets stateAdjs
  m <* modify (\s -> s {stateAdjs = saved_adjs})

-- | Run an action, then restore 'stateSubsts' to its value from before
-- the action, discarding any substitutions it added.  Nothing else in
-- the state is rolled back.
subSubsts :: ADM a -> ADM a
subSubsts m = do
  saved_substs <- gets stateSubsts
  m <* modify (\s -> s {stateSubsts = saved_substs})

-- | A record of differentiation operations passed around as values.
-- NOTE(review): presumably this lets modules that cannot import the
-- main VJP driver call back into it — confirm against the callers.
data VjpOps = VjpOps
  { -- | Differentiate a lambda, given a list of adjoints and a list of
    -- names.  The precise meaning of both lists is fixed by the
    -- implementation, which is not visible here.
    vjpLambda :: [Adj] -> [VName] -> Lambda SOACS -> ADM (Lambda SOACS),
    -- | Differentiate a statement.  The second argument is an 'ADM'
    -- action supplied by the caller; its role is defined by the
    -- implementation.
    vjpStm :: Stm SOACS -> ADM () -> ADM ()
  }

-- | @setLoopTape v vs@ records @vs@ as the name of the array that
-- stores the values of loop parameter @v@ from the forward pass.
-- Any previous entry for @v@ is overwritten.
setLoopTape :: VName -> VName -> ADM ()
setLoopTape v vs = modify record
  where
    record env = env {stateLoopTape = M.insert v vs (stateLoopTape env)}

-- | Look up the name of the array holding the forward-pass values of
-- loop parameter @v@, if one was registered with 'setLoopTape'.
lookupLoopTape :: VName -> ADM (Maybe VName)
lookupLoopTape v = M.lookup v <$> gets stateLoopTape

-- | @substLoopTape v v'@ makes the key @v'@ refer to the same tape as
-- @v@: if @v |-> vs@, then afterwards also @v' |-> vs@.
--
-- The entry for @v@ is /not/ removed; it keeps pointing at @vs@.  If
-- @v@ has no tape, nothing happens.
substLoopTape :: VName -> VName -> ADM ()
substLoopTape v v' = do
  tape <- lookupLoopTape v
  forM_ tape (setLoopTape v')

-- | Renames the keys of the loop tape by applying 'substLoopTape' to
-- every @(old, new)@ pair of the substitution, in ascending order of
-- @old@.  Useful for fixing the names in the loop tape after a loop
-- rename.
renameLoopTape :: Substitutions -> ADM ()
renameLoopTape substs =
  forM_ (M.toList substs) $ \(old, new) -> substLoopTape old new

-- Note [Consumption]
--
-- Parts of this transformation depend on duplicating computation.
-- This is a problem when a primal expression consumes arrays (via
-- e.g. Update).  For example, consider how we handle this conditional:
--
--   if b then ys with [0] = 0 else ys
--
-- This consumes the array 'ys', which means that when we later
-- generate code for the return sweep, we can no longer use 'ys'.
-- This is a problem, because when we call 'diffBody' on the branch
-- bodies, we'll keep the primal code (maybe it'll be removed by
-- simplification later - we cannot know).  A similar issue occurs for
-- SOACs.  Our solution is to make copies of all consumed arrays:
--
--  let ys_copy = copy ys
--
-- Then we generate code for the return sweep as normal, but replace
-- _every instance_ of 'ys' in the generated code with 'ys_copy'.
-- This works because Futhark does not have *semantic* in-place
-- updates - any uniqueness violation can be replaced with copies (on
-- arrays, anyway).
--
-- If we are lucky, the uses of 'ys_copy' will be removed by
-- simplification, and there will be no overhead.  But even if not,
-- this is still (asymptotically) efficient because the array that is
-- being consumed must in any case have been produced within the code
-- that we are differentiating, so a copy adds at most a constant-factor
-- overhead.  This is _not_ the case when loops are involved.
--
-- Also, the above only works for arrays, not accumulator variables.
-- Those will need some other mechanism.