{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE TypeFamilies #-}

-- Naming scheme:
--
-- An adjoint-related object for "x" is named "x_adj".  This means
-- both actual adjoints and statements.
--
-- Do not assume "x'" means anything related to derivatives.
module Futhark.AD.Rev.Monad
  ( ADM,
    RState (..),
    runADM,
    Adj (..),
    InBounds (..),
    Sparse (..),
    adjFromParam,
    adjFromVar,
    lookupAdj,
    lookupAdjVal,
    adjVal,
    updateAdj,
    updateSubExpAdj,
    updateAdjSlice,
    updateAdjIndex,
    setAdj,
    insAdj,
    insSubExpAdj,
    adjsReps,
    --
    copyConsumedArrsInStm,
    copyConsumedArrsInBody,
    addSubstitution,
    returnSweepCode,
    --
    adjVName,
    subAD,
    noAdjsFor,
    subSubsts,
    isActive,
    --
    tabNest,
    oneExp,
    zeroExp,
    unitAdjOfType,
    addLambda,
    --
    VjpOps (..),
    --
    setLoopTape,
    lookupLoopTape,
    substLoopTape,
    renameLoopTape,
  )
where

import Control.Monad
import Control.Monad.State.Strict
import Data.Bifunctor (second)
import Data.List (foldl')
import qualified Data.Map as M
import Data.Maybe
import qualified Futhark.Analysis.Alias as Alias
import Futhark.Analysis.PrimExp.Convert
import Futhark.Builder
import Futhark.IR.Aliases (consumedInStms)
import Futhark.IR.Prop.Aliases
import Futhark.IR.SOACS
import Futhark.Tools
import Futhark.Transform.Substitute
import Futhark.Util (chunks)

-- | Build an expression for the zero value of the given type, using
-- 'blankPrimValue': a constant for a primitive type, or a 'Replicate'
-- of that constant over the shape of an array type.  Any other type is
-- an internal error.
zeroExp :: Type -> Exp rep
zeroExp t = case t of
  Prim pt -> BasicOp . SubExp . Constant $ blankPrimValue pt
  Array pt shape _ -> BasicOp . Replicate shape . Constant $ blankPrimValue pt
  _ -> error $ "zeroExp: " ++ pretty t

-- | The \"one\" value of a primitive type: the number 1 for integer and
-- floating-point types, 'True' for booleans, and the unit value for
-- 'Unit'.
onePrim :: PrimType -> PrimValue
onePrim pt = case pt of
  IntType it -> IntValue $ intValue it (1 :: Int)
  FloatType ft -> FloatValue $ floatValue ft (1 :: Double)
  Bool -> BoolValue True
  Unit -> UnitValue

-- | Build an expression for the \"one\" value (see 'onePrim') of the
-- given type: a constant for a primitive type, or a 'Replicate' of that
-- constant over the shape of an array type.  Any other type is an
-- internal error.
oneExp :: Type -> Exp rep
oneExp t = case t of
  Prim pt -> BasicOp . SubExp . constant $ onePrim pt
  Array pt shape _ -> BasicOp . Replicate shape . Constant $ onePrim pt
  _ -> error $ "oneExp: " ++ pretty t

-- | Whether 'Sparse' should check bounds or assume they are correct.
-- The latter results in simpler code.
--
-- Note that the derived 'Ord' instance follows constructor order, so
-- reordering the constructors changes it.
data InBounds
  = -- | The update is emitted with a bounds check.  If a SubExp is
    -- provided, it references a boolean that is true when in-bounds.
    CheckBounds (Maybe SubExp)
  | -- | The update is emitted without a bounds check.
    AssumeBounds
  | -- | Dynamically these will always fail, so don't bother
    -- generating code for the update.  This is only needed to ensure
    -- a consistent representation of sparse Jacobians.
    OutOfBounds
  deriving (Eq, Ord, Show)

-- | A symbolic representation of an array that is all zeroes, except at
-- the indexes listed in 'sparseIdxVals'.
data Sparse = Sparse
  { -- | The shape of the array.
    sparseShape :: Shape,
    -- | Element type of the array.
    sparseType :: PrimType,
    -- | Locations and values of nonzero values, each tagged with how
    -- its bounds are to be handled.  Indexes may be negative, in which
    -- case the value is ignored (unless 'AssumeBounds' is used).
    sparseIdxVals :: [(InBounds, SubExp, SubExp)]
  }
  deriving (Eq, Ord, Show)

-- | The adjoint of a variable.
data Adj
  = -- | An array adjoint that is zero except at the entries listed in
    -- the 'Sparse' value.
    AdjSparse Sparse
  | -- | An adjoint held directly in a 'SubExp'.
    AdjVal SubExp
  | -- | An all-zero adjoint of the given shape and element type that
    -- has not been manifested as an array.
    AdjZero Shape PrimType
  deriving (Eq, Ord, Show)

-- | Apply a name substitution to every name an adjoint mentions.
--
-- This covers the adjoint value itself, the dimensions of its shape,
-- and the index, value and optional in-bounds flag of every sparse
-- entry.  Previously only @AdjVal (Var v)@ was rewritten.  Names inside
-- 'AdjZero' and 'AdjSparse' were silently left unsubstituted, so a
-- renaming applied to an adjoint could leave stale references behind.
instance Substitute Adj where
  substituteNames m (AdjVal se) =
    -- Constants are left unchanged by the SubExp instance.
    AdjVal $ substituteNames m se
  substituteNames m (AdjZero shape pt) =
    AdjZero (substituteNames m shape) pt
  substituteNames m (AdjSparse (Sparse shape pt ivs)) =
    AdjSparse $ Sparse (substituteNames m shape) pt $ map substEntry ivs
    where
      substEntry (check, i, v) =
        (substCheck check, substituteNames m i, substituteNames m v)
      -- Only 'CheckBounds' carries a name (its optional in-bounds flag).
      substCheck (CheckBounds b) = CheckBounds $ substituteNames m <$> b
      substCheck other = other

-- | Bind a zero-valued array whose outer dimensions are @shape@ and
-- whose remaining type is @t@ (see 'zeroExp').  With a rank-0 @shape@
-- this is just @zeroExp t@.  Otherwise the zero is bound first and then
-- replicated over @shape@, with the replicate tagged @sequential@.
zeroArray :: MonadBuilder m => Shape -> Type -> m VName
zeroArray shape t
  | shapeRank shape == 0 = letExp "zero" zero_e
  | otherwise = do
      zero <- letSubExp "zero" zero_e
      attributing (oneAttr "sequential") $
        letExp "zeroes_" . BasicOp $ Replicate shape zero
  where
    zero_e = zeroExp t

-- | Manifest a 'Sparse' adjoint as an array.
--
-- This starts from a zero array of the sparse shape and applies each
-- entry of 'sparseIdxVals', in order, as an in-place update at its
-- index:
--
-- * 'AssumeBounds' entries become 'Unsafe' updates.
-- * 'CheckBounds' entries become 'Safe' updates; the optional in-bounds
--   flag is not consulted here.
-- * 'OutOfBounds' entries are skipped entirely.
sparseArray :: (MonadBuilder m, Rep m ~ SOACS) => Sparse -> m VName
sparseArray (Sparse shape t ivs) = do
  zeroes <- zeroArray shape (Prim t)
  foldM writeEntry zeroes ivs
  where
    arr_t = Prim t `arrayOfShape` shape
    writeEntry arr (check, i, se) =
      case check of
        AssumeBounds -> emitUpdate Unsafe
        CheckBounds _ -> emitUpdate Safe
        OutOfBounds -> pure arr
      where
        emitUpdate safety =
          letExp "sparse" . BasicOp $
            Update safety arr (fullSlice arr_t [DimFix i]) se

-- | An adjoint that is held in the given variable.
adjFromVar :: VName -> Adj
adjFromVar v = AdjVal (Var v)

-- | An adjoint that is held in the variable named by the given
-- parameter.
adjFromParam :: Param t -> Adj
adjFromParam p = adjFromVar (paramName p)

-- | Bind the \"one\" value of the given type (see 'oneExp') and return
-- it as an adjoint value.
unitAdjOfType :: Type -> ADM Adj
unitAdjOfType t = do
  unit_se <- letSubExp "adj_unit" $ oneExp t
  pure (AdjVal unit_se)

-- | The values representing an adjoint in symbolic form.  This is
-- used for when we wish to return an Adj from a Body or similar
-- without forcing manifestation.  Also returns a function for
-- reassembling the Adj from a new representation (the list must have
-- the same length).
--
-- Reassembling an 'AdjVal' or 'AdjZero' from a list of the wrong length
-- is an internal error, reported with a descriptive message.  For an
-- 'AdjSparse', a list that is too short yields fewer entries, and
-- surplus elements are ignored.
adjRep :: Adj -> ([SubExp], [SubExp] -> Adj)
adjRep (AdjVal se) = ([se], rebuild)
  where
    rebuild [se'] = AdjVal se'
    rebuild ses =
      error $ "adjRep: AdjVal expects 1 SubExp, got " ++ show (length ses)
adjRep (AdjZero shape pt) = ([], rebuild)
  where
    rebuild [] = AdjZero shape pt
    rebuild ses =
      error $ "adjRep: AdjZero expects 0 SubExps, got " ++ show (length ses)
adjRep (AdjSparse (Sparse shape pt ivs)) =
  (concatMap ivRep ivs, AdjSparse . Sparse shape pt . repIvs ivs)
  where
    -- Each sparse entry is represented by its index followed by its value.
    ivRep (_, i, v) = [i, v]
    repIvs ((check, _, _) : ivs') (i : v : ses) =
      (check', i, v) : repIvs ivs' ses
      where
        check' = case check of
          AssumeBounds -> AssumeBounds
          CheckBounds b -> CheckBounds b
          -- NOTE(review): an OutOfBounds entry is deliberately reassembled
          -- as a check against a constant-false in-bounds flag.  The
          -- rationale is not visible in this module.
          OutOfBounds -> CheckBounds (Just (constant False)) -- sic!
    repIvs _ _ = []

-- | Conveniently convert a list of Adjs to their representation, as
-- well as produce a function for converting back.
--
-- The per-adjoint representations (see 'adjRep') are concatenated in
-- order.  The returned function splits a list back into chunks of the
-- original lengths and reassembles each adjoint from its chunk.
adjsReps :: [Adj] -> ([SubExp], [SubExp] -> [Adj])
adjsReps adjs =
  (concat reps, zipWith ($) rebuilders . chunks (map length reps))
  where
    (reps, rebuilders) = unzip (map adjRep adjs)

-- | The state carried by 'ADM'.
data RState = RState
  { -- | The adjoint recorded for each variable (written by 'setAdj').
    stateAdjs :: M.Map VName Adj,
    -- | Names used by the loop-tape operations ('setLoopTape',
    -- 'lookupLoopTape', ...).  NOTE(review): the exact meaning is
    -- defined by those functions, which are not visible here.
    stateLoopTape :: Substitutions,
    -- | Name substitutions.  These are presumably the ones recorded by
    -- 'addSubstitution'; TODO confirm against its definition.
    stateSubsts :: Substitutions,
    -- | Source of fresh names; backs the 'MonadFreshNames' instance
    -- for @State RState@.
    stateNameSource :: VNameSource
  }

-- | The monad for reverse-mode AD: a SOACS statement builder running on
-- top of a (strict) 'State' monad over 'RState'.  The instances listed
-- below are inherited from the wrapped 'BuilderT' via newtype deriving.
newtype ADM a = ADM (BuilderT SOACS (State RState) a)
  deriving
    ( Functor,
      Applicative,
      Monad,
      MonadState RState,
      MonadFreshNames,
      HasScope SOACS,
      LocalScope SOACS
    )

-- | Every builder operation is delegated to the wrapped 'BuilderT'.
instance MonadBuilder ADM where
  type Rep ADM = SOACS
  mkExpDecM pat e = ADM (mkExpDecM pat e)
  mkBodyM stms res = ADM (mkBodyM stms res)
  mkLetNamesM names e = ADM (mkLetNamesM names e)

  addStms stms = ADM (addStms stms)
  collectStms (ADM m) = ADM (collectStms m)

-- | Fresh names are read from, and written back to, the
-- 'stateNameSource' field of the state.
instance MonadFreshNames (State RState) where
  getNameSource = gets stateNameSource
  putNameSource src = modify $ \s -> s {stateNameSource = src}

-- | Run an 'ADM' computation in an empty scope.
--
-- The computation starts from empty adjoint, loop-tape and substitution
-- maps, and the caller's name source is threaded through it.
-- Statements added at the top level of the computation are discarded;
-- only its result is returned.
runADM :: MonadFreshNames m => ADM a -> m a
runADM (ADM m) = modifyNameSource run
  where
    initial src = RState mempty mempty mempty src
    run src =
      second stateNameSource $
        runState (fst <$> runBuilderT m mempty) (initial src)

-- | Manifest an adjoint as a variable.
--
-- * A plain value is bound afresh under the name @const_adj@.
-- * A sparse adjoint is expanded with 'sparseArray'.
-- * A zero adjoint becomes a zero array of its shape and element type.
adjVal :: Adj -> ADM VName
adjVal adj =
  case adj of
    AdjVal se -> letExp "const_adj" . BasicOp $ SubExp se
    AdjSparse sparse -> sparseArray sparse
    AdjZero shape t -> zeroArray shape (Prim t)

-- | Record the adjoint of a variable, replacing any adjoint previously
-- recorded for it.
setAdj :: VName -> Adj -> ADM ()
setAdj v v_adj = modify withAdj
  where
    withAdj env = env {stateAdjs = M.insert v v_adj (stateAdjs env)}

-- | Record the variable @v_adj@ as the adjoint of @v@.
insAdj :: VName -> VName -> ADM ()
insAdj v v_adj = setAdj v (AdjVal (Var v_adj))

-- | A fresh name for the adjoint of a variable, formed by appending
-- @_adj@ to the variable's base name.
adjVName :: VName -> ADM VName
adjVName v = newVName (baseString v ++ "_adj")

-- | Create copies of all arrays consumed in the given statement.
-- Returns a map from each consumed array to its copy, together with the
-- statements that perform the copies.  Non-array consumed names are not
-- copied.  Each copy is also registered with 'addSubstitution' as a
-- renaming from the copy back to the original array.
--
-- See Note [Consumption].
copyConsumedArrsInStm :: Stm SOACS -> ADM (Substitutions, Stms SOACS)
copyConsumedArrsInStm s =
  inScopeOf s . collectStms $ do
    copies <- mapM copyIfArray consumed
    pure $ M.fromList $ mconcat copies
  where
    -- Names consumed by the statement, according to alias analysis.
    consumed :: [VName]
    consumed =
      namesToList . consumedInStms . fst $
        Alias.analyseStms mempty (oneStm s)

    copyIfArray :: VName -> ADM [(VName, VName)]
    copyIfArray v = inScopeOf s $ do
      v_t <- lookupType v
      case v_t of
        Array {} -> do
          v_copy <- letExp (baseString v <> "_ad_copy") $ BasicOp $ Copy v
          addSubstitution v_copy v
          pure [(v, v_copy)]
        _ -> pure mempty

-- | Copy every array consumed by the given body, except those named in
-- @dontCopy@.  Returns a map from each copied array to its copy.
-- Consumed accumulators are an internal error.  Other non-array names
-- are skipped.  Unlike 'copyConsumedArrsInStm', no substitution is
-- recorded in the ADM state.
copyConsumedArrsInBody :: [VName] -> Body SOACS -> ADM Substitutions
copyConsumedArrsInBody dontCopy b =
  M.unions <$> mapM copyOne toCopy
  where
    toCopy :: [VName]
    toCopy =
      filter (`notElem` dontCopy) . namesToList $
        consumedInBody (Alias.analyseBody mempty b)

    copyOne :: VName -> ADM Substitutions
    copyOne v = do
      v_t <- lookupType v
      case v_t of
        Acc {} -> error $ "copyConsumedArrsInBody: Acc " <> pretty v
        Array {} -> do
          v_copy <- letExp (baseString v <> "_ad_copy") $ BasicOp $ Copy v
          pure $ M.singleton v v_copy
        _ -> pure M.empty

-- | Run an action and collect the statements it emits.  Then re-emit
-- them with the state's substitution map (as it stands after the action
-- ran) applied.  The action's result is returned unchanged.
returnSweepCode :: ADM a -> ADM a
returnSweepCode m = do
  (x, stms) <- collectStms m
  substs <- gets stateSubsts
  x <$ addStms (substituteNames substs stms)

-- | Record the renaming @v@ to @v'@ in the ADM state's substitution map,
-- which 'returnSweepCode', for example, applies.  A renaming for a name
-- that already has one replaces it.
addSubstitution :: VName -> VName -> ADM ()
addSubstitution v v' =
  modify $ \env -> env {stateSubsts = M.insert v v' (stateSubsts env)}

-- While evaluating this action, pretend these variables have no
-- adjoints.  Restore current adjoints afterwards.  This is used for
-- handling certain nested operations. XXX: feels like this should
-- really be part of subAD, somehow.  Main challenge is that we don't
-- want to blank out Accumulator adjoints.  Also, might be inefficient
-- to blank out array adjoints.
--
-- The saved adjoints are kept paired with their own variables.  This
-- matters when some of the names have no adjoint: their absence must
-- not shift the restored adjoints onto the wrong variables.  After the
-- action, the saved adjoints take precedence over anything the action
-- recorded for the same names.
noAdjsFor :: Names -> ADM a -> ADM a
noAdjsFor names m = do
  old <- gets $ \env ->
    M.fromList
      [ (v, v_adj)
        | v <- names',
          Just v_adj <- [M.lookup v (stateAdjs env)]
      ]
  modify $ \env ->
    env {stateAdjs = foldl' (flip M.delete) (stateAdjs env) names'}
  x <- m
  -- Left-biased union: the saved adjoints win.
  modify $ \env -> env {stateAdjs = old <> stateAdjs env}
  pure x
  where
    names' :: [VName]
    names' = namesToList names

-- | The binary operator used to add adjoints of a primitive type:
-- wrapping addition for integers, float addition for floats, and
-- 'LogAnd' for both booleans and unit.
--
-- NOTE(review): booleans are combined with 'LogAnd' rather than an
-- additive operator; confirm this is intended.
addBinOp :: PrimType -> BinOp
addBinOp pt = case pt of
  IntType it -> Add it OverflowWrap
  FloatType ft -> FAdd ft
  Bool -> LogAnd
  Unit -> LogAnd

-- | @tabNest n vs f@ builds @n@ nested maps over the arrays @vs@.
--
-- At the innermost level, @f@ receives the map indices (outermost
-- first) and the variables bound to the corresponding elements of @vs@.
-- The variables @f@ returns become the results of the innermost map.
-- The results of the outermost map are returned.
--
-- With @n <= 0@, @f@ is applied directly to an empty index list and
-- @vs@.  Previously a negative @n@ recursed forever.
tabNest :: Int -> [VName] -> ([VName] -> [VName] -> ADM [VName]) -> ADM [VName]
tabNest = tabNest' []
  where
    -- The first argument accumulates the index parameters, innermost
    -- first.
    tabNest' ::
      [VName] ->
      Int ->
      [VName] ->
      ([VName] -> [VName] -> ADM [VName]) ->
      ADM [VName]
    tabNest' is n vs f
      | n <= 0 = f (reverse is) vs
      | otherwise = do
          vs_ts <- mapM lookupType vs
          let w = arraysSize 0 vs_ts
          iota <-
            letExp "tab_iota" . BasicOp $
              Iota w (intConst Int64 0) (intConst Int64 1) Int64
          iparam <- newParam "i" $ Prim int64
          -- Reuse the types fetched above instead of looking them up again.
          params <- forM (zip vs vs_ts) $ \(v, v_t) ->
            newParam (baseString v <> "_p") (rowType v_t)
          ((ret, res), stms) <-
            collectStms . localScope (scopeOfLParams (iparam : params)) $ do
              res <- tabNest' (paramName iparam : is) (n - 1) (map paramName params) f
              ret <- mapM lookupType res
              pure (ret, varsRes res)
          let lam = Lambda (iparam : params) (Body () stms res) ret
          letTupExp "tab" $ Op $ Screma w (iota : vs) (mapSOAC lam)

-- | Construct a lambda for adding two values of the given type.
-- Primitive types use the operator from 'addBinOp'.  Array types map
-- the addition lambda for the row type over both arguments.  Any other
-- type is an internal error.
addLambda :: Type -> ADM (Lambda SOACS)
addLambda t = case t of
  Prim pt -> binOpLambda (addBinOp pt) pt
  Array {} -> do
    xs_p <- newParam "xs" t
    ys_p <- newParam "ys" t
    row_lam <- addLambda (rowType t)
    body <- insertStmsM $ do
      let w = arraySize 0 t
          soac = Screma w [paramName xs_p, paramName ys_p] (mapSOAC row_lam)
      res <- letSubExp "lam_map" $ Op soac
      pure $ resultBody [res]
    pure $
      Lambda
        { lambdaParams = [xs_p, ys_p],
          lambdaReturnType = [t],
          lambdaBody = body
        }
  _ -> error $ "addLambda: " ++ show t

-- Construct an expression for adding the two variables.  The type of
-- @x@ decides the form: a primitive 'BinOp' for scalars, or a map of
-- 'addLambda' for arrays.  @y@ is presumably of the same type as @x@;
-- this is not checked here.
addExp :: VName -> VName -> ADM (Exp SOACS)
addExp x y = do
  x_t <- lookupType x
  case x_t of
    Prim pt -> pure . BasicOp $ BinOp (addBinOp pt) (Var x) (Var y)
    Array {} -> do
      row_add <- addLambda (rowType x_t)
      let w = arraySize 0 x_t
      pure . Op $ Screma w [x, y] (mapSOAC row_add)
    _ -> error $ "addExp: unexpected type: " ++ pretty x_t

-- | Look up the adjoint of a variable.
--
-- If none is recorded, a zero adjoint is returned instead.  For an
-- accumulator over a single primitive type, it uses the accumulator's
-- shape and element type.  Otherwise it uses the variable's own shape
-- and element type.  The ADM state is not modified.
lookupAdj :: VName -> ADM Adj
lookupAdj v = gets (M.lookup v . stateAdjs) >>= maybe zeroAdj pure
  where
    zeroAdj :: ADM Adj
    zeroAdj = do
      v_t <- lookupType v
      pure $ case v_t of
        Acc _ shape [Prim t] _ -> AdjZero shape t
        _ -> AdjZero (arrayShape v_t) (elemType v_t)

-- | Look up the adjoint of a variable with 'lookupAdj' and turn it into
-- a variable with 'adjVal'.
lookupAdjVal :: VName -> ADM VName
lookupAdjVal v = do
  v_adj <- lookupAdj v
  adjVal v_adj

-- | Add the variable @d@ to the adjoint of @v@.
--
-- * If @v@ has no recorded adjoint, @d@ becomes its adjoint.
-- * If the current adjoint is an accumulator, @d@ is written into it
--   element-wise with 'UpdateAcc', using nested maps of depth equal to
--   the rank of @d@.
-- * Otherwise the current adjoint and @d@ are summed with 'addExp'.
updateAdj :: VName -> VName -> ADM ()
updateAdj v d = do
  existing <- gets $ M.lookup v . stateAdjs
  case existing of
    Nothing -> insAdj v d
    Just adj -> do
      v_adj <- adjVal adj
      v_adj_t <- lookupType v_adj
      v_adj' <- case v_adj_t of
        Acc {} -> accumulate v_adj
        _ -> letExp (baseString v <> "_adj") =<< addExp v_adj d
      insAdj v v_adj'
  where
    accumulate :: VName -> ADM VName
    accumulate acc = do
      rank <- length . arrayDims <$> lookupType d
      ~[acc'] <- tabNest rank [d, acc] $ \is [d', acc_in] ->
        letTupExp "acc" . BasicOp $ UpdateAcc acc_in (map Var is) [Var d']
      pure acc'

-- | Add @d@ to the part of the adjoint of @v@ selected by the slice.
--
-- * A slice of a single fixed index is passed to 'updateAdjIndex',
--   with bounds assumed.
-- * An accumulator adjoint is updated element-wise with 'UpdateAcc' at
--   the indices the slice maps to.
-- * Any other adjoint has the sum of the selected part and @d@ written
--   back in place.  For a scalar @v@, the whole adjoint is the
--   selected part.
updateAdjSlice :: Slice SubExp -> VName -> VName -> ADM ()
updateAdjSlice (Slice [DimFix i]) v d =
  updateAdjIndex v (AssumeBounds, i) (Var d)
updateAdjSlice slice v d = do
  t <- lookupType v
  v_adj <- lookupAdjVal v
  v_adj_t <- lookupType v_adj
  v_adj' <- case v_adj_t of
    Acc {} -> accumulate v_adj
    _ -> do
      current <-
        if primType t
          then pure v_adj
          else letExp (baseString v ++ "_slice") $ BasicOp $ Index v_adj slice
      letInPlace "updated_adj" v_adj slice =<< addExp current d
  insAdj v v_adj'
  where
    accumulate :: VName -> ADM VName
    accumulate acc = do
      let rank = length $ sliceDims slice
      ~[acc'] <- tabNest rank [d, acc] $ \is [d', acc_in] -> do
        slice' <-
          traverse (toSubExp "index") $
            fixSlice (fmap pe64 slice) (map le64 is)
        letTupExp (baseString acc_in) . BasicOp $
          UpdateAcc acc_in slice' [Var d']
      pure acc'

-- | 'updateAdj' lifted to a 'SubExp'.  Constants have no adjoint, so
-- updating one is a no-op.
updateSubExpAdj :: SubExp -> VName -> ADM ()
updateSubExpAdj se d = case se of
  Var v -> updateAdj v d
  Constant {} -> pure ()

-- | 'insAdj' lifted to a 'SubExp'.  Constants have no adjoint, so
-- nothing is recorded for them.
insSubExpAdj :: SubExp -> VName -> ADM ()
insSubExpAdj se d = case se of
  Var v -> insAdj v d
  Constant {} -> pure ()

-- | @updateAdjIndex v (check, i) se@ records the contribution @se@ to
-- element @i@ of the adjoint of @v@.
--
-- * No adjoint yet, or a zero adjoint: a fresh sparse adjoint holding
--   only this update is stored.
-- * A sparse adjoint: the update is prepended to its list of updates.
-- * A materialised adjoint ('AdjVal'):
--
--     * if it is an accumulator, @se@ is fed into it with 'UpdateAcc'
--       (nested over the dimensions of @se@ via 'tabNest');
--     * otherwise element @i@ is read, @se@ is added to it, and the
--       result is written back with 'Update' (bounds-checked for
--       'CheckBounds', unchecked for 'AssumeBounds').
--
-- When @check@ is 'OutOfBounds', a materialised adjoint is left
-- unchanged.  NOTE(review): an earlier comment said a negative index
-- makes the update a no-op.  In this code, only the @check@ flag (never
-- the value of @i@) decides that.
updateAdjIndex :: VName -> (InBounds, SubExp) -> SubExp -> ADM ()
updateAdjIndex v (check, i) se = do
  maybeAdj <- gets $ M.lookup v . stateAdjs
  t <- lookupType v
  let iv = (check, i, se)
      -- Shared by the "no adjoint" and "zero adjoint" cases.
      freshSparse =
        setAdj v $ AdjSparse $ Sparse (arrayShape t) (elemType t) [iv]
  case maybeAdj of
    Nothing -> freshSparse
    Just AdjZero {} -> freshSparse
    Just (AdjSparse (Sparse shape pt ivs)) ->
      setAdj v $ AdjSparse $ Sparse shape pt $ iv : ivs
    Just adj@AdjVal {} -> do
      v_adj <- adjVal adj
      v_adj_t <- lookupType v_adj
      se_v <- letExp "se_v" $ BasicOp $ SubExp se
      insAdj v
        =<< case v_adj_t of
          Acc {}
            | check == OutOfBounds ->
                pure v_adj
            | otherwise -> do
                dims <- arrayDims <$> lookupType se_v
                ~[v_adj'] <-
                  tabNest (length dims) [se_v, v_adj] $ \is [se_v', acc] ->
                    letTupExp "acc" $
                      BasicOp $
                        UpdateAcc acc (i : map Var is) [Var se_v']
                pure v_adj'
          _ -> do
            -- The slice selecting element @i@ is used both to read and
            -- to write back.
            let slice_i = fullSlice v_adj_t [DimFix i]
                stms s = do
                  v_adj_i <-
                    letExp (baseString v_adj <> "_i") $
                      BasicOp $
                        Index v_adj slice_i
                  se_update <-
                    letSubExp "updated_adj_i" =<< addExp se_v v_adj_i
                  letExp (baseString v_adj) $
                    BasicOp $
                      Update s v_adj slice_i se_update
            case check of
              CheckBounds _ -> stms Safe
              AssumeBounds -> stms Unsafe
              OutOfBounds -> pure v_adj

-- | Is this primal variable active in the AD sense?  Currently any
-- variable whose type is not the unit type is treated as active.
-- FIXME: this is (obviously) much too conservative.
isActive :: VName -> ADM Bool
isActive v = do
  v_t <- lookupType v
  pure $ v_t /= Prim Unit

-- | Run an action, then reset the adjoint map ('stateAdjs') to what it
-- was before the action started.  Every other state change made by the
-- action is kept, and the action's result is returned.
subAD :: ADM a -> ADM a
subAD m = do
  saved_adjs <- gets stateAdjs
  m <* modify (\s -> s {stateAdjs = saved_adjs})

-- | Run an action, then reset the substitution map ('stateSubsts') to
-- what it was before the action started.  Every other state change made
-- by the action is kept, and the action's result is returned.
subSubsts :: ADM a -> ADM a
subSubsts m = do
  saved_substs <- gets stateSubsts
  m <* modify (\s -> s {stateSubsts = saved_substs})

-- | A record of differentiation operations, so that code in this module
-- can call back into the differentiator without naming it directly.
data VjpOps = VjpOps
  { -- | Produce a differentiated version of a lambda.  NOTE(review):
    -- the roles of the @[Adj]@ and @[VName]@ arguments are not visible
    -- in this module.
    vjpLambda :: [Adj] -> [VName] -> Lambda SOACS -> ADM (Lambda SOACS),
    -- | Differentiate a statement.  The @ADM ()@ argument is presumably
    -- the code to run in between the forward and return sweeps — TODO
    -- confirm.
    vjpStm :: Stm SOACS -> ADM () -> ADM ()
  }

-- | @setLoopTape v vs@ records @vs@ as the name of the array in which
-- the forward-pass values of loop parameter @v@ are stored.  Any tape
-- previously recorded for @v@ is replaced.
setLoopTape :: VName -> VName -> ADM ()
setLoopTape v vs = modify withTape
  where
    withTape env = env {stateLoopTape = M.insert v vs (stateLoopTape env)}

-- | Look up the name of the tape array recorded for @v@, if any.
lookupLoopTape :: VName -> ADM (Maybe VName)
lookupLoopTape v = M.lookup v <$> gets stateLoopTape

-- | @substLoopTape v v'@ makes @v'@ refer to the same tape as @v@.  That
-- is, if @v |-> vs@ then afterwards also @v' |-> vs@.  The entry for @v@
-- is left in place; it is not removed.  If @v@ has no tape, nothing
-- happens.
substLoopTape :: VName -> VName -> ADM ()
substLoopTape v v' = mapM_ (setLoopTape v') =<< lookupLoopTape v

-- | Renames the keys of the loop tape.  For every @v |-> v'@ in the
-- substitution map where @v@ has a tape @vs@, @v'@ is made to refer to
-- @vs@.  The existing entries for the old keys are kept.
--
-- All lookups use the tape as it was before the call, so the
-- substitution is applied simultaneously.  A mapping whose target is also
-- one of the sources (e.g. a swap) therefore does not see
-- partially-updated entries.  Useful for fixing the names in the loop
-- tape after a loop has been renamed.
renameLoopTape :: Substitutions -> ADM ()
renameLoopTape substs = modify $ \env ->
  let tape = stateLoopTape env
      renamed =
        M.fromList
          [ (v', vs)
            | (v, v') <- M.toList substs,
              Just vs <- [M.lookup v tape]
          ]
   in -- 'M.union' is left-biased, so the renamed entries take precedence.
      env {stateLoopTape = renamed `M.union` tape}

-- Note [Consumption]
--
-- Parts of this transformation depend on duplicating computation.
-- This is a problem when a primal expression consumes arrays (via
-- e.g. Update).  For example, consider how we handle this conditional:
--
--   if b then ys with [0] = 0 else ys
--
-- This consumes the array 'ys', which means that when we later
-- generate code for the return sweep, we can no longer use 'ys'.
-- This is a problem, because when we call 'diffBody' on the branch
-- bodies, we'll keep the primal code (maybe it'll be removed by
-- simplification later - we cannot know).  A similar issue occurs for
-- SOACs.  Our solution is to make copies of all consumed arrays:
--
--  let ys_copy = copy ys
--
-- Then we generate code for the return sweep as normal, but replace
-- _every instance_ of 'ys' in the generated code with 'ys_copy'.
-- This works because Futhark does not have *semantic* in-place
-- updates - any uniqueness violation can be replaced with copies (on
-- arrays, anyway).
--
-- If we are lucky, the uses of 'ys_copy' will be removed by
-- simplification, and there will be no overhead.  But even if not,
-- this is still (asymptotically) efficient because the array that is
-- being consumed must in any case have been produced within the code
-- that we are differentiating, so a copy adds at most a constant-factor
-- overhead.  This is _not_ the case when loops are involved.
--
-- Also, the above only works for arrays, not accumulator variables.
-- Those will need some other mechanism.