{-# LANGUAGE TypeFamilies #-}

module Futhark.AD.Rev.Loop (diffLoop, stripmineStms) where

import Control.Monad
import Data.Foldable (toList)
import Data.List ((\\))
import Data.Map qualified as M
import Data.Maybe
import Futhark.AD.Rev.Monad
import Futhark.Analysis.Alias qualified as Alias
import Futhark.Analysis.PrimExp.Convert
import Futhark.Builder
import Futhark.IR.Aliases (consumedInStms)
import Futhark.IR.SOACS
import Futhark.Tools
import Futhark.Transform.Rename
import Futhark.Transform.Substitute
import Futhark.Util (nubOrd, traverseFold)

-- | A convenience function to bring the components of a for-loop into
-- scope and throw an error if the passed 'Exp' is not a for-loop.
--
-- The continuation receives, in order: the loop parameters paired with
-- their initial values, the whole 'LoopForm', the loop counter, the
-- counter's integer type, the iteration bound, the loop variables
-- (parameters paired with the arrays they are drawn from), and the loop
-- body.
bindForLoop ::
  PrettyRep rep =>
  Exp rep ->
  ( [(Param (FParamInfo rep), SubExp)] ->
    LoopForm rep ->
    VName ->
    IntType ->
    SubExp ->
    [(Param (LParamInfo rep), VName)] ->
    Body rep ->
    a
  ) ->
  a
bindForLoop (DoLoop val_pats form@(ForLoop i it bound loop_vars) body) f =
  f val_pats form i it bound loop_vars body
bindForLoop e _ = error $ "bindForLoop: not a for-loop:\n" <> prettyString e

-- | A convenience function to rename a for-loop and then bind the
-- renamed components.  The continuation receives the renamed loop
-- expression as its first argument, followed by the same components
-- 'bindForLoop' provides (all taken from the renamed copy).  Fails with
-- an error if the expression is not a for-loop.
renameForLoop ::
  (MonadFreshNames m, Renameable rep, PrettyRep rep) =>
  Exp rep ->
  ( Exp rep ->
    [(Param (FParamInfo rep), SubExp)] ->
    LoopForm rep ->
    VName ->
    IntType ->
    SubExp ->
    [(Param (LParamInfo rep), VName)] ->
    Body rep ->
    m a
  ) ->
  m a
renameForLoop loop f = do
  loop' <- renameExp loop
  bindForLoop loop' (f loop')

-- | Is the loop a while-loop?  Returns 'False' both for for-loops and for
-- expressions that are not a 'DoLoop' at all.
isWhileLoop :: Exp rep -> Bool
isWhileLoop (DoLoop _ WhileLoop {} _) = True
isWhileLoop _ = False

-- | Transforms a 'ForLoop' into a 'ForLoop' with an empty list of
-- loop variables.
--
-- Each loop variable @x@ drawn from an array @xs@ is replaced by an
-- explicit statement @x' = xs[i]@ (with @i@ the loop counter) placed at
-- the start of the body, and uses of @x@ in the body are renamed to
-- @x'@.  The types of the arrays are looked up, so they must be in scope.
-- Fails with an error if the expression is not a for-loop.
removeLoopVars :: MonadBuilder m => Exp (Rep m) -> m (Exp (Rep m))
removeLoopVars loop =
  bindForLoop loop $ \val_pats _form i it bound loop_vars body -> do
    let indexify (x_param, xs) = do
          xs_t <- lookupType xs
          x' <-
            letExp (baseString $ paramName x_param) . BasicOp . Index xs $
              fullSlice xs_t [DimFix (Var i)]
          pure (paramName x_param, x')
    -- The indexing statements refer to the loop counter, so they are
    -- collected here and spliced into the loop body rather than emitted
    -- into the surrounding context.
    (substs_list, subst_stms) <- collectStms $ mapM indexify loop_vars
    let Body aux' stms' res' = substituteNames (M.fromList substs_list) body
    -- After the substitution no part of the body refers to the
    -- loop-variable parameters, so they are dropped from the loop form.
    pure $
      DoLoop val_pats (ForLoop i it bound []) $
        Body aux' (subst_stms <> stms') res'

-- | Augments a while-loop to also compute the number of iterations.
--
-- The loop is emitted again with an extra 'Int64' loop parameter that
-- starts at zero and is incremented once per iteration.  The final value
-- of that parameter is returned; the other results of the augmented loop
-- are bound but not returned.  Fails with an error if the expression is
-- not a while-loop.
computeWhileIters :: Exp SOACS -> ADM SubExp
computeWhileIters (DoLoop val_pats (WhileLoop b) body) = do
  bound_v <- newVName "bound"
  let t = Prim $ IntType Int64
      bound_param = Param mempty bound_v t
  bound_init <- letSubExp "bound_init" $ zeroExp t
  body' <- localScope (scopeOfFParams [bound_param]) $
    buildBody_ $ do
      let one = Constant $ IntValue $ intValue Int64 (1 :: Int)
      bound_plus_one <-
        letSubExp "bound+1" $
          BasicOp $
            BinOp (Add Int64 OverflowUndef) (Var bound_v) one
      addStms $ bodyStms body
      -- The counter result must come first to match the parameter order.
      pure (pure (subExpRes bound_plus_one) <> bodyResult body)
  res <-
    letTupExp' "loop" $
      DoLoop ((bound_param, bound_init) : val_pats) (WhileLoop b) body'
  case res of
    iters : _ -> pure iters
    [] -> error "computeWhileIters: augmented loop produced no results"
computeWhileIters e =
  error $ "computeWhileIters: not a while-loop:\n" <> prettyString e

-- | @convertWhileLoop bound loop@ converts a 'WhileLoop' into an 'Int64'
-- 'ForLoop' that runs for exactly @bound@ iterations.
--
-- Each iteration of the for-loop checks the while-loop's condition
-- parameter. If the condition is true, it runs the original body.
-- Otherwise it passes the loop parameters through unchanged.
--
-- @bound@ must therefore be an upper bound on the number of iterations of
-- the while-loop for the result to be equivalent.  Possible sources are a
-- @#[bound(n)]@ attribute or the iteration count computed by
-- 'computeWhileIters'.  The tighter the bound, the fewer wasted
-- iterations.  Fails with an error if the expression is not a while-loop.
convertWhileLoop :: SubExp -> Exp SOACS -> ADM (Exp SOACS)
convertWhileLoop bound_se (DoLoop val_pats (WhileLoop cond) body) =
  localScope (scopeOfFParams $ map fst val_pats) $ do
    i <- newVName "i"
    body' <-
      eBody
        [ eIf
            (pure $ BasicOp $ SubExp $ Var cond)
            (pure body)
            (resultBodyM $ map (Var . paramName . fst) val_pats)
        ]
    pure $ DoLoop val_pats (ForLoop i Int64 bound_se []) body'
convertWhileLoop _ e =
  error $ "convertWhileLoop: not a while-loop:\n" <> prettyString e

-- | @nestifyLoop bound n loop@ transforms a for-loop into a depth-@n@ loop
-- nest of @bound@-iteration loops, so the original body executes
-- @bound^n@ times in total.
--
-- Inside the nest, the original loop counter is recomputed from the
-- counters of all levels as @sum_k (bound^(k-1) * c_k)@. Here @c_k@ is
-- the counter of the loop with @k@ levels remaining, so the outermost
-- counter is the most significant.  As a result, the @j@-th execution of
-- the body sees counter value @j@.
--
-- This transformation does not preserve the original semantics of the
-- loop: @n@ and @bound@ may be arbitrary and have no relation to the
-- number of iterations of @loop@.  For @n == 1@ only the loop bound is
-- replaced; for @n < 1@ the loop is returned unchanged.
nestifyLoop ::
  SubExp ->
  Integer ->
  Exp SOACS ->
  ADM (Exp SOACS)
nestifyLoop bound_se = nestifyLoop'
  where
    nestifyLoop' :: Integer -> Exp SOACS -> ADM (Exp SOACS)
    nestifyLoop' n loop = bindForLoop loop nestify
      where
        nestify val_pats _form i it _bound loop_vars body
          | n > 1 =
              renameForLoop loop $ \_loop' val_pats' _form' i' it' _bound' loop_vars' body' -> do
                let loop_params = map fst val_pats
                    loop_params' = map fst val_pats'
                    loop_inits' = map (Var . paramName) loop_params
                    val_pats'' = zip loop_params' loop_inits'
                outer_body <-
                  buildBody_ $ do
                    -- One step of this level's counter covers bound^(n-1)
                    -- executions of the innermost body.  The per-level
                    -- offsets are summed by the stacked i_inner bindings,
                    -- so each level contributes stride * counter.
                    stride <-
                      letSubExp "stride" . BasicOp $
                        BinOp
                          (Pow it)
                          bound_se
                          (Constant $ IntValue $ intValue it (n - 1))
                    offset <-
                      letSubExp "offset" . BasicOp $
                        BinOp (Mul it OverflowUndef) stride (Var i)

                    inner_body <- insertStmsM $ do
                      i_inner <-
                        letExp "i_inner" . BasicOp $
                          BinOp (Add it OverflowUndef) offset (Var i')
                      pure $ substituteNames (M.singleton i' i_inner) body'

                    inner_loop <-
                      letTupExp "inner_loop"
                        =<< nestifyLoop'
                          (n - 1)
                          ( DoLoop
                              val_pats''
                              (ForLoop i' it' bound_se loop_vars')
                              inner_body
                          )
                    pure $ varsRes inner_loop
                pure $
                  DoLoop val_pats (ForLoop i it bound_se loop_vars) outer_body
          | n == 1 =
              pure $ DoLoop val_pats (ForLoop i it bound_se loop_vars) body
          | otherwise = pure loop

-- | @stripmine n pat loop@ stripmines a loop into a depth-@n@ loop nest.
-- An additional @bound - (floor(bound^(1/n)))^n@-iteration remainder loop is
-- inserted after the stripmined loop which executes the remaining iterations
-- so that the stripmined loop is semantically equivalent to the original loop.
--
-- The nest binds a renamed copy of @pat@.  Those values seed the
-- remainder loop, whose counter is offset by the number of iterations
-- already performed, and the remainder loop binds @pat@ itself.  The
-- root @bound^(1/n)@ is computed in 'Float64' and truncated.  Statements
-- computing the bounds are added to the surrounding builder; the returned
-- 'Stms' contain the two loops.
stripmine :: Integer -> Pat Type -> Exp SOACS -> ADM (Stms SOACS)
stripmine n pat loop = do
  -- Loop variables are indexed by the loop counter, which inside the nest
  -- and the remainder loop no longer equals the original iteration
  -- number.  They are therefore turned into explicit indexing on the
  -- original counter, which then gets recomputed correctly.  Both the nest
  -- and the remainder loop are built from this version.
  loop' <- removeLoopVars loop
  bindForLoop loop' $ \_val_pats _form _i it bound _loop_vars _body -> do
    let n_root =
          Constant $ FloatValue $ floatValue Float64 (1 / fromIntegral n :: Double)
    bound_float <-
      letSubExp "bound_f64" $ BasicOp $ ConvOp (UIToFP it Float64) bound
    bound' <-
      letSubExp "bound" $ BasicOp $ BinOp (FPow Float64) bound_float n_root
    bound_int <-
      letSubExp "bound_int" $ BasicOp $ ConvOp (FPToUI Float64 it) bound'
    total_iters <-
      letSubExp "total_iters" . BasicOp $
        BinOp (Pow it) bound_int (Constant $ IntValue $ intValue it n)
    remain_iters <-
      letSubExp "remain_iters" $
        BasicOp $
          BinOp (Sub it OverflowUndef) bound total_iters
    mined_loop <- nestifyLoop bound_int n loop'
    pat' <- renamePat pat
    renameForLoop loop' $ \_loop'' val_pats' _form' i' it' _bound' loop_vars' body' -> do
      -- The remainder loop continues where the nest stopped.
      remain_body <- insertStmsM $ do
        i_remain <-
          letExp "i_remain" . BasicOp $
            BinOp (Add it OverflowUndef) total_iters (Var i')
        pure $ substituteNames (M.singleton i' i_remain) body'
      let loop_params_rem = map fst val_pats'
          loop_inits_rem = map (Var . patElemName) $ patElems pat'
          val_pats_rem = zip loop_params_rem loop_inits_rem
          remain_loop =
            DoLoop
              val_pats_rem
              (ForLoop i' it' remain_iters loop_vars')
              remain_body
      collectStms_ $ do
        letBind pat' mined_loop
        letBind pat remain_loop

-- | Stripmines a statement. Only has an effect when the statement's
-- expression is a for-loop with a @#[stripmine(n)]@ attribute, where
-- @n@ is the nesting depth.
--
-- If several @stripmine@ attributes are present, only the first one
-- produced by 'mapAttrs' is used.  A non-positive depth leaves the
-- statement unchanged.  Such a depth cannot describe a loop nest, and
-- stripmining with it would run extra remainder iterations.
stripmineStm :: Stm SOACS -> ADM (Stms SOACS)
stripmineStm stm@(Let pat aux loop@(DoLoop _ ForLoop {} _)) =
  case nums of
    n : _ | n > 0 -> stripmine n pat loop
    _ -> pure $ oneStm stm
  where
    extractNum (AttrComp "stripmine" [AttrInt n]) = Just n
    extractNum _ = Nothing
    nums = catMaybes $ mapAttrs extractNum $ stmAuxAttrs aux
stripmineStm stm = pure $ oneStm stm

-- | Applies 'stripmineStm' to every statement in order and concatenates
-- the resulting statement sequences.
stripmineStms :: Stms SOACS -> ADM (Stms SOACS)
stripmineStms = traverseFold stripmineStm

-- | Forward pass transformation of a loop. This includes modifying the loop
-- to save the loop values at each iteration onto a tape as well as copying
-- any consumed arrays in the loop's body and consuming said copies in lieu of
-- the originals (which will be consumed later in the reverse pass).
fwdLoop :: Pat Type -> StmAux () -> Exp SOACS -> ADM ()
fwdLoop :: Pat Type -> StmAux () -> Exp SOACS -> ADM ()
fwdLoop Pat Type
pat StmAux ()
aux Exp SOACS
loop =
  forall rep a.
PrettyRep rep =>
Exp rep
-> ([(Param (FParamInfo rep), SubExp)]
    -> LoopForm rep
    -> VName
    -> IntType
    -> SubExp
    -> [(Param (LParamInfo rep), VName)]
    -> Body rep
    -> a)
-> a
bindForLoop Exp SOACS
loop forall a b. (a -> b) -> a -> b
$ \[(FParam SOACS, SubExp)]
val_pats LoopForm SOACS
form VName
i IntType
_it SubExp
bound [(Param (LParamInfo SOACS), VName)]
_loop_vars Body SOACS
body -> do
    SubExp
bound64 <- forall (m :: * -> *).
MonadBuilder m =>
IntType -> SubExp -> m SubExp
asIntS IntType
Int64 SubExp
bound
    let loop_params :: [Param (TypeBase Shape Uniqueness)]
loop_params = forall a b. (a -> b) -> [a] -> [b]
map forall a b. (a, b) -> a
fst [(FParam SOACS, SubExp)]
val_pats
        is_true_dep :: Param dec -> Bool
is_true_dep = Attr -> Attrs -> Bool
inAttrs (Name -> Attr
AttrName Name
"true_dep") forall b c a. (b -> c) -> (a -> b) -> a -> c
. forall dec. Param dec -> Attrs
paramAttrs
        dont_copy_params :: [Param (TypeBase Shape Uniqueness)]
dont_copy_params = forall a. (a -> Bool) -> [a] -> [a]
filter forall {dec}. Param dec -> Bool
is_true_dep [Param (TypeBase Shape Uniqueness)]
loop_params
        dont_copy :: [VName]
dont_copy = forall a b. (a -> b) -> [a] -> [b]
map forall dec. Param dec -> VName
paramName [Param (TypeBase Shape Uniqueness)]
dont_copy_params
        loop_params_to_copy :: [Param (TypeBase Shape Uniqueness)]
loop_params_to_copy = [Param (TypeBase Shape Uniqueness)]
loop_params forall a. Eq a => [a] -> [a] -> [a]
\\ [Param (TypeBase Shape Uniqueness)]
dont_copy_params

    [SubExp]
empty_saved_array <-
      forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
t a -> (a -> m b) -> m (t b)
forM [Param (TypeBase Shape Uniqueness)]
loop_params_to_copy forall a b. (a -> b) -> a -> b
$ \Param (TypeBase Shape Uniqueness)
p ->
        forall (m :: * -> *).
MonadBuilder m =>
[Char] -> Exp (Rep m) -> m SubExp
letSubExp (VName -> [Char]
baseString (forall dec. Param dec -> VName
paramName Param (TypeBase Shape Uniqueness)
p) forall a. Semigroup a => a -> a -> a
<> [Char]
"_empty_saved")
          forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< forall (m :: * -> *). MonadBuilder m => Type -> m (Exp (Rep m))
eBlank (forall shape u_unused u.
ArrayShape shape =>
TypeBase shape u_unused -> shape -> u -> TypeBase shape u
arrayOf (forall dec. Param dec -> dec
paramDec Param (TypeBase Shape Uniqueness)
p) (forall d. [d] -> ShapeBase d
Shape [SubExp
bound64]) NoUniqueness
NoUniqueness)

    (Body SOACS
body', ([PatElem Type]
saved_pats, [Param (TypeBase Shape Uniqueness)]
saved_params)) <- forall (m :: * -> *) a.
MonadBuilder m =>
m (Result, a) -> m (Body (Rep m), a)
buildBody forall a b. (a -> b) -> a -> b
$
      forall rep (m :: * -> *) a.
LocalScope rep m =>
Scope rep -> m a -> m a
localScope (forall rep dec. (FParamInfo rep ~ dec) => [Param dec] -> Scope rep
scopeOfFParams [Param (TypeBase Shape Uniqueness)]
loop_params) forall a b. (a -> b) -> a -> b
$
        forall rep a (m :: * -> *) b.
(Scoped rep a, LocalScope rep m) =>
a -> m b -> m b
inScopeOf LoopForm SOACS
form forall a b. (a -> b) -> a -> b
$ do
          Map VName VName
copy_substs <- [VName] -> Body SOACS -> ADM (Map VName VName)
copyConsumedArrsInBody [VName]
dont_copy Body SOACS
body
          forall (m :: * -> *). MonadBuilder m => Stms (Rep m) -> m ()
addStms forall a b. (a -> b) -> a -> b
$ forall rep. Body rep -> Stms rep
bodyStms Body SOACS
body
          SubExp
i_i64 <- forall (m :: * -> *).
MonadBuilder m =>
IntType -> SubExp -> m SubExp
asIntS IntType
Int64 forall a b. (a -> b) -> a -> b
$ VName -> SubExp
Var VName
i
          ([VName]
saved_updates, [(PatElem Type, Param (TypeBase Shape Uniqueness))]
saved_pats_params) <- forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap forall a b. [(a, b)] -> ([a], [b])
unzip forall a b. (a -> b) -> a -> b
$
            forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
t a -> (a -> m b) -> m (t b)
forM [Param (TypeBase Shape Uniqueness)]
loop_params_to_copy forall a b. (a -> b) -> a -> b
$ \Param (TypeBase Shape Uniqueness)
p -> do
              let v :: VName
v = forall dec. Param dec -> VName
paramName Param (TypeBase Shape Uniqueness)
p
                  t :: TypeBase Shape Uniqueness
t = forall dec. Param dec -> dec
paramDec Param (TypeBase Shape Uniqueness)
p
              VName
saved_param_v <- forall (m :: * -> *). MonadFreshNames m => [Char] -> m VName
newVName forall a b. (a -> b) -> a -> b
$ VName -> [Char]
baseString VName
v forall a. Semigroup a => a -> a -> a
<> [Char]
"_saved"
              VName
saved_pat_v <- forall (m :: * -> *). MonadFreshNames m => [Char] -> m VName
newVName forall a b. (a -> b) -> a -> b
$ VName -> [Char]
baseString VName
v forall a. Semigroup a => a -> a -> a
<> [Char]
"_saved"
              VName -> VName -> ADM ()
setLoopTape VName
v VName
saved_pat_v
              let saved_param :: Param (TypeBase Shape Uniqueness)
saved_param = forall dec. Attrs -> VName -> dec -> Param dec
Param forall a. Monoid a => a
mempty VName
saved_param_v forall a b. (a -> b) -> a -> b
$ forall shape u_unused u.
ArrayShape shape =>
TypeBase shape u_unused -> shape -> u -> TypeBase shape u
arrayOf TypeBase Shape Uniqueness
t (forall d. [d] -> ShapeBase d
Shape [SubExp
bound64]) Uniqueness
Unique
                  saved_pat :: PatElem Type
saved_pat = forall dec. VName -> dec -> PatElem dec
PatElem VName
saved_pat_v forall a b. (a -> b) -> a -> b
$ forall shape u_unused u.
ArrayShape shape =>
TypeBase shape u_unused -> shape -> u -> TypeBase shape u
arrayOf TypeBase Shape Uniqueness
t (forall d. [d] -> ShapeBase d
Shape [SubExp
bound64]) NoUniqueness
NoUniqueness
              VName
saved_update <-
                forall rep (m :: * -> *) a.
LocalScope rep m =>
Scope rep -> m a -> m a
localScope (forall rep dec. (FParamInfo rep ~ dec) => [Param dec] -> Scope rep
scopeOfFParams [Param (TypeBase Shape Uniqueness)
saved_param])
                  forall a b. (a -> b) -> a -> b
$ forall (m :: * -> *).
MonadBuilder m =>
[Char] -> VName -> Slice SubExp -> Exp (Rep m) -> m VName
letInPlace
                    (VName -> [Char]
baseString VName
v forall a. Semigroup a => a -> a -> a
<> [Char]
"_saved_update")
                    VName
saved_param_v
                    (Type -> [DimIndex SubExp] -> Slice SubExp
fullSlice (forall shape.
TypeBase shape Uniqueness -> TypeBase shape NoUniqueness
fromDecl forall a b. (a -> b) -> a -> b
$ forall dec. Param dec -> dec
paramDec Param (TypeBase Shape Uniqueness)
saved_param) [forall d. d -> DimIndex d
DimFix SubExp
i_i64])
                  forall a b. (a -> b) -> a -> b
$ forall a. Substitute a => Map VName VName -> a -> a
substituteNames Map VName VName
copy_substs
                  forall a b. (a -> b) -> a -> b
$ forall rep. BasicOp -> Exp rep
BasicOp
                  forall a b. (a -> b) -> a -> b
$ SubExp -> BasicOp
SubExp
                  forall a b. (a -> b) -> a -> b
$ VName -> SubExp
Var VName
v
              forall (f :: * -> *) a. Applicative f => a -> f a
pure (VName
saved_update, (PatElem Type
saved_pat, Param (TypeBase Shape Uniqueness)
saved_param))
          forall (f :: * -> *) a. Applicative f => a -> f a
pure (forall rep. Body rep -> Result
bodyResult Body SOACS
body forall a. Semigroup a => a -> a -> a
<> [VName] -> Result
varsRes [VName]
saved_updates, forall a b. [(a, b)] -> ([a], [b])
unzip [(PatElem Type, Param (TypeBase Shape Uniqueness))]
saved_pats_params)

    let pat' :: Pat Type
pat' = Pat Type
pat forall a. Semigroup a => a -> a -> a
<> forall dec. [PatElem dec] -> Pat dec
Pat [PatElem Type]
saved_pats
        val_pats' :: [(Param (TypeBase Shape Uniqueness), SubExp)]
val_pats' = [(FParam SOACS, SubExp)]
val_pats forall a. Semigroup a => a -> a -> a
<> forall a b. [a] -> [b] -> [(a, b)]
zip [Param (TypeBase Shape Uniqueness)]
saved_params [SubExp]
empty_saved_array
    forall (m :: * -> *). MonadBuilder m => Stm (Rep m) -> m ()
addStm forall a b. (a -> b) -> a -> b
$ forall rep.
Pat (LetDec rep) -> StmAux (ExpDec rep) -> Exp rep -> Stm rep
Let Pat Type
pat' StmAux ()
aux forall a b. (a -> b) -> a -> b
$ forall rep.
[(FParam rep, SubExp)] -> LoopForm rep -> Body rep -> Exp rep
DoLoop [(Param (TypeBase Shape Uniqueness), SubExp)]
val_pats' LoopForm SOACS
form Body SOACS
body'

-- | Build the loop merge parameter, together with its initial value,
-- that carries the adjoint of @v@ through the adjoint loop.
--
-- The parameter is named via 'adjVName', and it is given the (unique)
-- declared type of @v@'s current adjoint value. That adjoint value is
-- also used as the parameter's initial value.
valPatAdj :: VName -> ADM (Param DeclType, SubExp)
valPatAdj v = do
  param_name <- adjVName v
  adj0 <- lookupAdjVal v
  adj0_t <- lookupType adj0
  let param = Param mempty param_name $ toDecl adj0_t Unique
  pure (param, Var adj0)

-- | Apply 'valPatAdj' to every variable in every group of a 'LoopInfo',
-- visiting the groups in field order and each list left to right.
valPatAdjs :: LoopInfo [VName] -> ADM (LoopInfo [(Param DeclType, SubExp)])
valPatAdjs = traverse (traverse valPatAdj)

-- | Reverses a loop by substituting the loop index and reversing the
-- arrays that the loop variables are bound to.
--
-- Returns three things, in order:
--
--   * a substitution from each array iterated over by the loop to a
--     reversed slice of it;
--   * a substitution from the loop index @i@ to a fresh variable holding
--     @(bound - 1) - i@;
--   * the statements that compute that fresh variable.
--
-- The statements for @bound - 1@ and for the array reversals are added
-- to the current builder. Only the index statements are returned.
reverseIndices :: Exp SOACS -> ADM (Substitutions, Substitutions, Stms SOACS)
reverseIndices loop =
  bindForLoop loop $ \_val_pats form i it bound loop_vars _body -> do
    let one = Constant $ IntValue $ intValue it (1 :: Int)
        backwards = Constant $ IntValue $ Int64Value (-1)

    last_i <-
      inScopeOf form . letSubExp "bound-1" . BasicOp $
        BinOp (Sub it OverflowUndef) bound one

    -- Slice the array from its last iterated element down to the first.
    let reverseArray xs = do
          xs_t <- lookupType xs
          xs_rev <-
            letExp "reverse" . BasicOp . Index xs $
              fullSlice xs_t [DimSlice last_i bound backwards]
          pure (xs, xs_rev)
    array_substs <-
      M.fromList <$> inScopeOf form (mapM (reverseArray . snd) loop_vars)

    (i_rev, i_stms) <-
      collectStms . inScopeOf form . letExp (baseString i <> "_rev") . BasicOp $
        BinOp (Sub it OverflowWrap) last_i (Var i)

    pure (array_substs, M.singleton i i_rev, i_stms)

-- | Returns a substitution which substitutes values in the reverse
-- loop body with values from the tape.
--
-- A loop parameter in @loop_params'@ gets an entry when both of these
-- hold:
--
--   * it is not marked with the @true_dep@ attribute;
--   * it has an entry on the loop tape.
--
-- Its value at iteration @i'@ is read back from the tape, with @i'@
-- converted to @i64@ for indexing. If the parameter is an array that the
-- adjoint statements @stms_adj@ consume, the restored value is copied
-- first, presumably so that consuming it does not consume the tape.
restore :: Stms SOACS -> [Param DeclType] -> VName -> ADM Substitutions
restore stms_adj loop_params' i' =
  M.fromList . catMaybes <$> mapM f loop_params'
  where
    dont_copy =
      map paramName $ filter (inAttrs (AttrName "true_dep") . paramAttrs) loop_params'
    -- Names consumed by the adjoint statements. This does not depend on
    -- the parameter, so run the alias analysis once here rather than
    -- once per loop parameter.
    consumed =
      namesToList $ consumedInStms $ fst $ Alias.analyseStms mempty stms_adj
    f p
      | v `notElem` dont_copy = do
          m_vs <- lookupLoopTape v
          case m_vs of
            Nothing -> pure Nothing
            Just vs -> do
              vs_t <- lookupType vs
              i_i64' <- asIntS Int64 $ Var i'
              v' <- letExp "restore" $ BasicOp $ Index vs $ fullSlice vs_t [DimFix i_i64']
              t <- lookupType v
              v'' <- case (t, v `elem` consumed) of
                (Array {}, True) -> letExp "restore_copy" $ BasicOp $ Copy v'
                _ -> pure v'
              pure $ Just (v, v'')
      | otherwise = pure Nothing
      where
        v = paramName p

-- | Keeps the values associated with the different parts of a loop
-- separate from one another.
--
-- The derived 'Foldable' and 'Traversable' instances visit the fields in
-- declaration order. Code that flattens a 'LoopInfo' with 'toList' and
-- later splits the flat list back up depends on that order.
data LoopInfo a = LoopInfo
  { -- | Associated with the results of the loop body.
    loopRes :: a,
    -- | Associated with the variables free in the loop.
    loopFree :: a,
    -- | Associated with the arrays the loop iterates over.
    loopVars :: a,
    -- | Associated with the initial values of the loop parameters.
    loopVals :: a
  }
  deriving (Show, Functor, Foldable, Traversable)

-- | Transforms a for-loop into its reverse-mode derivative.
--
-- The loop is first renamed. A new loop is then built whose merge
-- parameters carry adjoints for four groups, as described by
-- 'LoopInfo':
--
--   * the body results;
--   * the free variables;
--   * the iterated arrays;
--   * the initial merge values.
--
-- The new loop's body is built from three parts:
--
--   * statements computing the reversed iteration index;
--   * statements restoring loop parameters from the loop tape;
--   * the adjoint of the renamed body.
--
-- The index and the iterated arrays are both substituted by their
-- reversed counterparts (see 'reverseIndices').
--
-- After the loop, the resulting adjoints are written back to the
-- corresponding variables in the return sweep.
revLoop :: (Stms SOACS -> ADM ()) -> Pat Type -> Exp SOACS -> ADM ()
revLoop diffStms pat loop =
  bindForLoop loop $ \val_pats _form _i _it _bound _loop_vars _body ->
    renameForLoop loop $
      \loop' val_pats' form' i' _it' _bound' loop_vars' body' -> do
        let loop_params = map fst val_pats
            (loop_params', loop_vals') = unzip val_pats'
            loop_var_arrays' = map snd loop_vars'
            getVName Constant {} = Nothing
            getVName (Var v) = Just v
            loop_vnames =
              LoopInfo
                { loopRes = mapMaybe subExpResVName $ bodyResult body',
                  loopFree =
                    (namesToList (freeIn loop') \\ loop_var_arrays')
                      \\ mapMaybe getVName loop_vals',
                  loopVars = loop_var_arrays',
                  loopVals = nubOrd $ mapMaybe getVName loop_vals'
                }
            isTrueDep = inAttrs (AttrName "true_dep") . paramAttrs

        -- Tape entries are keyed by the original parameter names; point
        -- them at the renamed parameters.
        renameLoopTape $
          M.fromList $
            zip (map paramName loop_params) (map paramName loop_params')

        -- Seed the adjoints of the body results with the adjoints of the
        -- loop's pattern elements.
        forM_ (zip (bodyResult body') (patElems pat)) $ \(se_res, pe) ->
          case subExpResVName se_res of
            Just v -> setAdj v =<< lookupAdj (patElemName pe)
            Nothing -> pure ()

        (var_array_substs, i_subst, i_stms) <- reverseIndices loop'

        val_pat_adjs <- valPatAdjs loop_vnames
        let val_pat_adjs_list = concat $ toList val_pat_adjs

        (loop_adjs, stms_adj) <- collectStms $
          inScopeOf form' $
            localScope (scopeOfFParams $ map fst val_pat_adjs_list <> loop_params') $ do
              addStms i_stms
              (inner_adjs, inner_stms) <- collectStms $
                subAD $ do
                  zipWithM_
                    (\val_pat v -> insAdj v (paramName $ fst val_pat))
                    val_pat_adjs_list
                    (concat $ toList loop_vnames)
                  diffStms $ bodyStms body'

                  -- Accumulate the adjoint of each loop variable into the
                  -- slice of its array at the current iteration.
                  let update_var_arrays v vs = do
                        vs_t <- lookupType vs
                        v_adj <- lookupAdjVal v
                        updateAdjSlice (fullSlice vs_t [DimFix $ Var i']) vs v_adj
                  zipWithM_
                    update_var_arrays
                    (map (paramName . fst) loop_vars')
                    (loopVars loop_vnames)

                  loop_res_adjs <- mapM (lookupAdjVal . paramName) loop_params'
                  loop_free_adjs <- mapM lookupAdjVal $ loopFree loop_vnames
                  loop_vars_adjs <- mapM lookupAdjVal $ loopVars loop_vnames
                  loop_vals_adjs <- mapM lookupAdjVal $ loopVals loop_vnames

                  pure $
                    LoopInfo
                      { loopRes = loop_res_adjs,
                        loopFree = loop_free_adjs,
                        loopVars = loop_vars_adjs,
                        loopVals = loop_vals_adjs
                      }
              (substs, restore_stms) <-
                collectStms $ restore inner_stms loop_params' i'
              addStms $ substituteNames i_subst restore_stms
              addStms $ substituteNames i_subst $ substituteNames substs inner_stms
              pure inner_adjs

        inScopeOf stms_adj $
          localScope (scopeOfFParams $ map fst val_pat_adjs_list) $ do
            let body_adj =
                  mkBody stms_adj $ varsRes $ concat $ toList loop_adjs
                -- In the adjoint loop, each true_dep parameter is replaced
                -- by the corresponding result of the original loop
                -- statement. The check is done directly on the parameter's
                -- attributes rather than by a quadratic membership test
                -- against the filtered parameter list.
                restore_true_deps =
                  M.fromList
                    [ (paramName p, patElemName pe)
                      | (p, pe) <- zip loop_params' (patElems pat),
                        isTrueDep p
                    ]
            adjs' <-
              letTupExp "loop_adj" $
                substituteNames (restore_true_deps <> var_array_substs) $
                  DoLoop val_pat_adjs_list form' body_adj
            -- Split the flat loop results back into the LoopInfo groups,
            -- in field order.
            let (loop_res_adjs, loop_free_var_val_adjs) =
                  splitAt (length $ loopRes loop_adjs) adjs'
                (loop_free_adjs, loop_var_val_adjs) =
                  splitAt (length $ loopFree loop_adjs) loop_free_var_val_adjs
                (loop_var_adjs, loop_val_adjs) =
                  splitAt (length $ loopVars loop_adjs) loop_var_val_adjs
            returnSweepCode $ do
              zipWithM_ updateSubExpAdj loop_vals' loop_res_adjs
              zipWithM_ insAdj (loopFree loop_vnames) loop_free_adjs
              zipWithM_ insAdj (loopVars loop_vnames) loop_var_adjs
              zipWithM_ updateAdj (loopVals loop_vnames) loop_val_adjs

-- | Transforms a loop into its reverse-mode derivative.
--
-- @diffLoop diffStms pat aux loop m@ distinguishes two cases:
--
-- * __While-loops__ are first converted into for-loops with
--   'convertWhileLoop' and then differentiated by a recursive call.
--   The iteration count comes from one of two places:
--
--     * the first @bound@ attribute on the statement
--       (@AttrComp "bound" [AttrInt b]@), used as an 'Int64' constant; or
--     * if there is no such attribute, 'computeWhileIters'.
--
-- * __For-loops__ are handled by running 'fwdLoop' on the loop, then the
--   continuation @m@, and finally 'revLoop' with @diffStms@.
diffLoop :: (Stms SOACS -> ADM ()) -> Pat Type -> StmAux () -> Exp SOACS -> ADM () -> ADM ()
diffLoop diffStms pat aux loop m
  | isWhileLoop loop = do
      for_loop <- case bounds of
        bound : _ -> do
          -- A user/compiler-supplied bound: use it directly as the trip count.
          let bound_se = Constant $ IntValue $ intValue Int64 bound
          convertWhileLoop bound_se loop
        [] -> do
          bound <- computeWhileIters loop
          -- The loop is renamed before conversion here, but not in the
          -- attribute case. NOTE(review): presumably because
          -- 'computeWhileIters' already binds code derived from @loop@, so the
          -- copy turned into a for-loop needs fresh names -- confirm.
          convertWhileLoop bound =<< renameExp loop
      diffLoop diffStms pat aux for_loop m
  | otherwise = do
      fwdLoop pat aux loop
      m
      revLoop diffStms pat loop
  where
    -- Extract the integer from a @bound@ attribute, if this is one.
    getBound :: Attr -> Maybe Integer
    getBound (AttrComp "bound" [AttrInt b]) = Just b
    getBound _ = Nothing

    -- All @bound@ values attached to the loop statement, in attribute order.
    bounds :: [Integer]
    bounds = catMaybes $ mapAttrs getBound $ stmAuxAttrs aux