{-# LANGUAGE TypeFamilies #-}

module Futhark.AD.Rev.Loop (diffLoop, stripmineStms) where

import Control.Monad
import Data.Foldable (toList)
import Data.List ((\\))
import Data.Map qualified as M
import Data.Maybe
import Futhark.AD.Rev.Monad
import Futhark.Analysis.Alias qualified as Alias
import Futhark.Analysis.PrimExp.Convert
import Futhark.Builder
import Futhark.IR.Aliases (consumedInStms)
import Futhark.IR.SOACS
import Futhark.Tools
import Futhark.Transform.Rename
import Futhark.Transform.Substitute
import Futhark.Util (nubOrd, traverseFold)

-- | Bring the components of a for-loop into scope by passing them to a
-- continuation, in this order:
--
-- * the loop parameters paired with their initial values,
-- * the whole 'LoopForm',
-- * the index variable and its integer type,
-- * the iteration bound,
-- * the loop body.
--
-- Calls 'error' (with the pretty-printed expression) if the given 'Exp'
-- is not a 'ForLoop'.
bindForLoop ::
  (PrettyRep rep) =>
  Exp rep ->
  ( [(Param (FParamInfo rep), SubExp)] ->
    LoopForm ->
    VName ->
    IntType ->
    SubExp ->
    Body rep ->
    a
  ) ->
  a
bindForLoop e k =
  case e of
    Loop merge loop_form@(ForLoop idx idx_t iters) loop_body ->
      k merge loop_form idx idx_t iters loop_body
    _ ->
      error ("bindForLoop: not a for-loop:\n" ++ prettyString e)

-- | Give a for-loop fresh names with 'renameExp', then pass both the
-- renamed loop and its components (exactly as 'bindForLoop' provides
-- them) to the continuation. Calls 'error' if the expression is not a
-- for-loop.
renameForLoop ::
  (MonadFreshNames m, Renameable rep, PrettyRep rep) =>
  Exp rep ->
  ( Exp rep ->
    [(Param (FParamInfo rep), SubExp)] ->
    LoopForm ->
    VName ->
    IntType ->
    SubExp ->
    Body rep ->
    m a
  ) ->
  m a
renameForLoop loop f = do
  renamed <- renameExp loop
  bindForLoop renamed (f renamed)

-- | Does the expression denote a while-loop? Every other expression,
-- including a for-loop, yields 'False'.
isWhileLoop :: Exp rep -> Bool
isWhileLoop e =
  case e of
    Loop _ WhileLoop {} _ -> True
    _ -> False

-- | Augments a while-loop to also compute the number of iterations.
--
-- An extra 'Int64' loop parameter, initialised to zero, is prepended to
-- the loop. It is incremented once per iteration, before the original
-- body statements run. The augmented loop is bound in the current
-- builder (together with all of its original results). Only the
-- 'SubExp' holding the final iteration count is returned.
--
-- Calls 'error' if the expression is not a while-loop.
computeWhileIters :: Exp SOACS -> ADM SubExp
computeWhileIters (Loop val_pats (WhileLoop b) body) = do
  bound_v <- newVName "bound"
  let t = Prim $ IntType Int64
      bound_param = Param mempty bound_v t
  bound_init <- letSubExp "bound_init" $ zeroExp t
  body' <- localScope (scopeOfFParams [bound_param]) $
    buildBody_ $ do
      bound_plus_one <-
        let one = Constant $ IntValue $ intValue Int64 (1 :: Int)
         in letSubExp "bound+1" $
              BasicOp $
                BinOp (Add Int64 OverflowUndef) (Var bound_v) one
      addStms $ bodyStms body
      -- The counter result goes first, matching the position of
      -- 'bound_param' in the loop parameters below.
      pure (pure (subExpRes bound_plus_one) <> bodyResult body)
  res <-
    letTupExp' "loop" $
      Loop ((bound_param, bound_init) : val_pats) (WhileLoop b) body'
  -- The counter parameter is always present, so 'res' is never empty;
  -- match explicitly rather than relying on partial 'head'.
  case res of
    iters : _ -> pure iters
    [] -> error "computeWhileIters: loop produced no results"
computeWhileIters e =
  error $ "computeWhileIters: not a while-loop:\n" <> prettyString e

-- | @convertWhileLoop bound loop@ converts the while-loop @loop@ into a
-- 'ForLoop' that runs exactly @bound@ iterations, using a fresh 'Int64'
-- counter (so @bound@ must be an @i64@).
--
-- Each iteration runs the original body only while the loop condition
-- still holds. Otherwise it passes the loop parameters through
-- unchanged. The result therefore agrees with the while-loop only when
-- @bound@ is an upper bound on its number of iterations. Presumably the
-- caller obtains it from a @#[bound(n)]@ attribute on the loop; confirm
-- at the call site. Every iteration up to @bound@ is executed, so the
-- tighter the bound the better.
--
-- Calls 'error' if the expression is not a while-loop.
convertWhileLoop :: SubExp -> Exp SOACS -> ADM (Exp SOACS)
convertWhileLoop bound_se (Loop val_pats (WhileLoop cond) body) =
  localScope (scopeOfFParams $ map fst val_pats) $ do
    i <- newVName "i"
    body' <-
      eBody
        [ eIf
            (pure $ BasicOp $ SubExp $ Var cond)
            (pure body)
            -- Condition already false: keep the current parameter values.
            (resultBodyM $ map (Var . paramName . fst) val_pats)
        ]
    pure $ Loop val_pats (ForLoop i Int64 bound_se) body'
convertWhileLoop _ e =
  error $ "convertWhileLoop: not a while-loop:\n" <> prettyString e

-- | @nestifyLoop bound n loop@ transforms the for-loop @loop@ into a
-- depth-@n@ nest of @bound@-iteration for-loops.
--
-- Iteration @(i_1, ..., i_n)@ of the nest runs the original body with
-- its index set to @i_1 * bound^(n-1) + ... + i_(n-1) * bound + i_n@.
-- The nest therefore covers original iterations @0 .. bound^n - 1@ in
-- order.
--
-- This transformation does not preserve the original semantics of the
-- loop: @n@ and @bound@ may be arbitrary and have no relation to the
-- number of iterations of @loop@.
--
-- * @n == 1@: only the loop's bound is replaced.
-- * @n <= 0@: the loop is returned unchanged.
nestifyLoop ::
  SubExp ->
  Integer ->
  Exp SOACS ->
  ADM (Exp SOACS)
nestifyLoop bound_se = nestifyLoop'
  where
    nestifyLoop' :: Integer -> Exp SOACS -> ADM (Exp SOACS)
    nestifyLoop' n loop = bindForLoop loop nestify
      where
        nestify val_pats _form i it _bound body
          | n > 1 =
              renameForLoop loop $ \_loop' val_pats' _form' i' it' _bound' body' -> do
                let loop_params = map fst val_pats
                    loop_params' = map fst val_pats'
                    loop_inits' = map (Var . paramName) loop_params
                    val_pats'' = zip loop_params' loop_inits'
                outer_body <-
                  buildBody_ $ do
                    -- One iteration of this level spans bound^(n-1)
                    -- iterations of the nest inside it. Using the stride
                    -- (not the offset inherited from the enclosing
                    -- level) keeps the index linear for any depth.
                    stride <-
                      letSubExp "stride" . BasicOp $
                        BinOp
                          (Pow it)
                          bound_se
                          (Constant $ IntValue $ intValue it (n - 1))
                    offset <-
                      letSubExp "offset" . BasicOp $
                        BinOp (Mul it OverflowUndef) stride (Var i)

                    -- The renamed body sees index offset + i', where i'
                    -- becomes the counter of the inner loop.
                    inner_body <- insertStmsM $ do
                      i_inner <-
                        letExp "i_inner" . BasicOp $
                          BinOp (Add it OverflowUndef) offset (Var i')
                      pure $ substituteNames (M.singleton i' i_inner) body'

                    inner_loop <-
                      letTupExp "inner_loop"
                        =<< nestifyLoop'
                          (n - 1)
                          (Loop val_pats'' (ForLoop i' it' bound_se) inner_body)
                    pure $ varsRes inner_loop
                pure $ Loop val_pats (ForLoop i it bound_se) outer_body
          | n == 1 =
              pure $ Loop val_pats (ForLoop i it bound_se) body
          | otherwise = pure loop

-- | @stripmine n pat loop@ splits the for-loop @loop@, bound to @pat@,
-- into a depth-@n@ loop nest followed by a remainder loop.
--
-- Let @b = floor(bound ** (1/n))@, computed through a 'Float64' round
-- trip.
--
-- * The nest produced by 'nestifyLoop' runs @b^n@ iterations and binds
--   a renamed copy of @pat@.
-- * A remainder loop then runs the final @bound - b^n@ iterations,
--   starting from the nest's results, and binds @pat@ itself.
--
-- Together the two loops are meant to be equivalent to the original
-- loop.
--
-- Only the two loop statements are returned. The statements computing
-- @b@, @b^n@ and the remainder count are emitted into the enclosing
-- builder. NOTE(review): confirm callers keep @bound@ in scope there.
stripmine :: Integer -> Pat Type -> Exp SOACS -> ADM (Stms SOACS)
stripmine n pat loop =
  bindForLoop loop $ \_val_pats _form _i it bound _body -> do
    let inv_n :: Double
        inv_n = 1 / fromIntegral n
        root_exponent = Constant $ FloatValue $ floatValue Float64 inv_n
        depth_exponent = Constant $ IntValue $ intValue it n

    -- b = floor(bound ** (1/n)), computed in f64.
    bound_f64 <-
      letSubExp "bound_f64" $ BasicOp $ ConvOp (UIToFP it Float64) bound
    root_f64 <-
      letSubExp "bound" $ BasicOp $ BinOp (FPow Float64) bound_f64 root_exponent
    per_level <-
      letSubExp "bound_int" $ BasicOp $ ConvOp (FPToUI Float64 it) root_f64

    -- The nest covers b^n iterations; the remainder loop covers the rest.
    mined_count <-
      letSubExp "total_iters" $ BasicOp $ BinOp (Pow it) per_level depth_exponent
    rest_count <-
      letSubExp "remain_iters" $
        BasicOp $
          BinOp (Sub it OverflowUndef) bound mined_count

    nest <- nestifyLoop per_level n loop
    nest_pat <- renamePat pat

    renameForLoop loop $ \_ merge _ idx idx_t _ loop_body -> do
      rest_body <- insertStmsM $ do
        -- Remainder iteration k stands for original iteration b^n + k.
        idx_shifted <-
          letExp "i_remain" $
            BasicOp $
              BinOp (Add it OverflowUndef) mined_count (Var idx)
        pure $ substituteNames (M.singleton idx idx_shifted) loop_body
      let rest_params = map fst merge
          rest_inits = map (Var . patElemName) (patElems nest_pat)
          rest_loop =
            Loop (zip rest_params rest_inits) (ForLoop idx idx_t rest_count) rest_body
      collectStms_ $ do
        letBind nest_pat nest
        letBind pat rest_loop

-- | Stripmines a statement. This only has an effect when the
-- statement's expression is a for-loop carrying a @#[stripmine(n)]@
-- attribute with nesting depth @n > 1@. If several such attributes are
-- present, the first one is used.
--
-- Other depths leave the statement unchanged. Depth 1 is the identity.
-- For @n <= 0@, 'stripmine' would still append a remainder loop after
-- the unmodified original loop, running extra iterations.
stripmineStm :: Stm SOACS -> ADM (Stms SOACS)
stripmineStm (Let pat aux loop@(Loop _ ForLoop {} _))
  | depth : _ <- depths,
    depth > 1 =
      stripmine depth pat loop
  where
    depths = catMaybes $ mapAttrs stripmineDepth $ stmAuxAttrs aux
    stripmineDepth (AttrComp "stripmine" [AttrInt d]) = Just d
    stripmineDepth _ = Nothing
-- Any other statement, or a loop without a usable depth, is kept as is.
stripmineStm stm = pure $ oneStm stm

-- | Applies 'stripmineStm' to every statement of the sequence, in order,
-- and concatenates the resulting statement sequences.
stripmineStms :: Stms SOACS -> ADM (Stms SOACS)
stripmineStms = traverseFold stripmineStm

-- | Forward pass transformation of a loop. This includes modifying the loop
-- to save the loop values at each iteration onto a tape as well as copying
-- any consumed arrays in the loop's body and consuming said copies in lieu of
-- the originals (which will be consumed later in the reverse pass).
--
-- Every loop parameter not marked @#[true_dep]@ gets a tape:
--
-- * an extra loop parameter of shape @[bound]@, initialised with a blank
--   array bound before the loop;
-- * an in-place write of the parameter's value into slot @i@ of that tape
--   at the end of each iteration;
-- * a corresponding extra element of the result pattern, whose name is
--   recorded with 'setLoopTape'.
--
-- The rewritten loop is added to the current builder with the given 'StmAux'.
fwdLoop :: Pat Type -> StmAux () -> Exp SOACS -> ADM ()
fwdLoop pat aux loop =
  bindForLoop loop $ \val_pats form i _it bound body -> do
    bound64 <- asIntS Int64 bound
    let loop_params = map fst val_pats
        true_dep_params =
          filter (inAttrs (AttrName "true_dep") . paramAttrs) loop_params
        true_dep_names = map paramName true_dep_params
        taped_params = loop_params \\ true_dep_params
        tape_shape = Shape [bound64]

    -- Blank tapes, one per taped parameter; these are the initial values of
    -- the extra loop parameters.
    tape_inits <- forM taped_params $ \p -> do
      blank_tape <- eBlank $ arrayOf (paramDec p) tape_shape NoUniqueness
      letSubExp (baseString (paramName p) <> "_empty_saved") blank_tape

    (body', (tape_pats, tape_params)) <-
      buildBody
        . localScope (scopeOfFParams loop_params)
        . localScope (scopeOfLoopForm form)
        $ do
          -- copy_substs is applied to the value written to the tape below
          -- (presumably mapping consumed arrays to their copies).
          copy_substs <- copyConsumedArrsInBody true_dep_names body
          addStms $ bodyStms body
          i_i64 <- asIntS Int64 $ Var i
          let saveParam p = do
                let v = paramName p
                tape_param_name <- newVName $ baseString v <> "_saved"
                tape_pat_name <- newVName $ baseString v <> "_saved"
                setLoopTape v tape_pat_name
                let tape_param =
                      Param mempty tape_param_name $
                        arrayOf (paramDec p) tape_shape Unique
                    tape_pat =
                      PatElem tape_pat_name $
                        arrayOf (paramDec p) tape_shape NoUniqueness
                    write_slot =
                      fullSlice (fromDecl $ paramDec tape_param) [DimFix i_i64]
                    saved_value =
                      substituteNames copy_substs $ BasicOp $ SubExp $ Var v
                tape_update <-
                  localScope (scopeOfFParams [tape_param]) $
                    letInPlace
                      (baseString v <> "_saved_update")
                      tape_param_name
                      write_slot
                      saved_value
                pure (tape_update, (tape_pat, tape_param))
          (tape_updates, tape_pats_params) <- mapAndUnzipM saveParam taped_params
          pure (bodyResult body <> varsRes tape_updates, unzip tape_pats_params)

    let pat' = pat <> Pat tape_pats
        val_pats' = val_pats <> zip tape_params tape_inits
    addStm $ Let pat' aux $ Loop val_pats' form body'

-- | Construct a loop value-pattern for the adjoint of the
-- given variable.
--
-- The parameter is named by 'adjVName' and has the (unique) type of the
-- current adjoint value of @v@, which also serves as its initial value.
valPatAdj :: VName -> ADM (Param DeclType, SubExp)
valPatAdj v = do
  adj_name <- adjVName v
  adj_init <- lookupAdjVal v
  adj_type <- lookupType adj_init
  let adj_param = Param mempty adj_name $ toDecl adj_type Unique
  pure (adj_param, Var adj_init)

-- | Construct adjoint loop value-patterns (see 'valPatAdj') for every
-- variable in each component of a 'LoopInfo', preserving its structure.
valPatAdjs :: LoopInfo [VName] -> ADM (LoopInfo [(Param DeclType, SubExp)])
valPatAdjs = traverse (traverse valPatAdj)

-- | Reverses a loop by substituting the loop index.
--
-- A statement computing @bound - 1@ is added to the current builder.
-- The statements computing @i_rev = (bound - 1) - i@ are collected rather
-- than added, and returned together with the substitution @i -> i_rev@.
-- Presumably the caller places them inside the reversed loop body, where
-- @i@ is bound.
reverseIndices :: Exp SOACS -> ADM (Substitutions, Stms SOACS)
reverseIndices loop =
  bindForLoop loop $ \_ form i it bound _ -> do
    let loop_scope = scopeOfLoopForm form
        one = Constant $ IntValue $ intValue it (1 :: Int)

    last_index <-
      localScope loop_scope . letSubExp "bound-1" . BasicOp $
        BinOp (Sub it OverflowUndef) bound one

    (i_rev, i_rev_stms) <-
      collectStms . localScope loop_scope . letExp (baseString i <> "_rev") . BasicOp $
        BinOp (Sub it OverflowWrap) last_index (Var i)

    pure (M.singleton i i_rev, i_rev_stms)

-- | Returns a substitution which substitutes values in the reverse
-- loop body with values from the tape.
--
-- A parameter is handled only if it is not marked @#[true_dep]@ and
-- 'lookupLoopTape' finds a tape for it. For each such parameter, a
-- statement reading the tape at index @i'@ (converted to 'Int64') is
-- emitted, and the parameter's name is mapped to the result.
--
-- If the parameter is an array that is consumed by @stms_adj@, the restored
-- value is additionally copied (via @Replicate mempty@) before being used.
restore :: Stms SOACS -> [Param DeclType] -> VName -> ADM Substitutions
restore stms_adj loop_params' i' =
  M.fromList . catMaybes <$> mapM f loop_params'
  where
    dont_copy =
      map paramName $ filter (inAttrs (AttrName "true_dep") . paramAttrs) loop_params'

    -- The set of names consumed by the adjoint statements is independent of
    -- the parameter being restored, so run the alias analysis once here
    -- instead of once per parameter.
    consumed = consumedInStms $ fst $ Alias.analyseStms mempty stms_adj

    f p
      | v `notElem` dont_copy = do
          m_vs <- lookupLoopTape v
          case m_vs of
            Nothing -> pure Nothing
            Just vs -> do
              vs_t <- lookupType vs
              i_i64' <- asIntS Int64 $ Var i'
              v' <- letExp "restore" $ BasicOp $ Index vs $ fullSlice vs_t [DimFix i_i64']
              t <- lookupType v
              v'' <- case (t, v `nameIn` consumed) of
                (Array {}, True) ->
                  letExp "restore_copy" $ BasicOp $ Replicate mempty $ Var v'
                _ -> pure v'
              pure $ Just (v, v'')
      | otherwise = pure Nothing
      where
        v = paramName p

-- | A type to keep track of and separate values corresponding to different
-- parts of the loop.
data LoopInfo a = LoopInfo
  { forall a. LoopInfo a -> a
loopRes :: a,
    forall a. LoopInfo a -> a
loopFree :: a,
    forall a. LoopInfo a -> a
loopVals :: a
  }
  deriving ((forall a b. (a -> b) -> LoopInfo a -> LoopInfo b)
-> (forall a b. a -> LoopInfo b -> LoopInfo a) -> Functor LoopInfo
forall a b. a -> LoopInfo b -> LoopInfo a
forall a b. (a -> b) -> LoopInfo a -> LoopInfo b
forall (f :: * -> *).
(forall a b. (a -> b) -> f a -> f b)
-> (forall a b. a -> f b -> f a) -> Functor f
$cfmap :: forall a b. (a -> b) -> LoopInfo a -> LoopInfo b
fmap :: forall a b. (a -> b) -> LoopInfo a -> LoopInfo b
$c<$ :: forall a b. a -> LoopInfo b -> LoopInfo a
<$ :: forall a b. a -> LoopInfo b -> LoopInfo a
Functor, (forall m. Monoid m => LoopInfo m -> m)
-> (forall m a. Monoid m => (a -> m) -> LoopInfo a -> m)
-> (forall m a. Monoid m => (a -> m) -> LoopInfo a -> m)
-> (forall a b. (a -> b -> b) -> b -> LoopInfo a -> b)
-> (forall a b. (a -> b -> b) -> b -> LoopInfo a -> b)
-> (forall b a. (b -> a -> b) -> b -> LoopInfo a -> b)
-> (forall b a. (b -> a -> b) -> b -> LoopInfo a -> b)
-> (forall a. (a -> a -> a) -> LoopInfo a -> a)
-> (forall a. (a -> a -> a) -> LoopInfo a -> a)
-> (forall a. LoopInfo a -> [a])
-> (forall a. LoopInfo a -> Bool)
-> (forall a. LoopInfo a -> Int)
-> (forall a. Eq a => a -> LoopInfo a -> Bool)
-> (forall a. Ord a => LoopInfo a -> a)
-> (forall a. Ord a => LoopInfo a -> a)
-> (forall a. Num a => LoopInfo a -> a)
-> (forall a. Num a => LoopInfo a -> a)
-> Foldable LoopInfo
forall a. Eq a => a -> LoopInfo a -> Bool
forall a. Num a => LoopInfo a -> a
forall a. Ord a => LoopInfo a -> a
forall m. Monoid m => LoopInfo m -> m
forall a. LoopInfo a -> Bool
forall a. LoopInfo a -> Int
forall a. LoopInfo a -> [a]
forall a. (a -> a -> a) -> LoopInfo a -> a
forall m a. Monoid m => (a -> m) -> LoopInfo a -> m
forall b a. (b -> a -> b) -> b -> LoopInfo a -> b
forall a b. (a -> b -> b) -> b -> LoopInfo a -> b
forall (t :: * -> *).
(forall m. Monoid m => t m -> m)
-> (forall m a. Monoid m => (a -> m) -> t a -> m)
-> (forall m a. Monoid m => (a -> m) -> t a -> m)
-> (forall a b. (a -> b -> b) -> b -> t a -> b)
-> (forall a b. (a -> b -> b) -> b -> t a -> b)
-> (forall b a. (b -> a -> b) -> b -> t a -> b)
-> (forall b a. (b -> a -> b) -> b -> t a -> b)
-> (forall a. (a -> a -> a) -> t a -> a)
-> (forall a. (a -> a -> a) -> t a -> a)
-> (forall a. t a -> [a])
-> (forall a. t a -> Bool)
-> (forall a. t a -> Int)
-> (forall a. Eq a => a -> t a -> Bool)
-> (forall a. Ord a => t a -> a)
-> (forall a. Ord a => t a -> a)
-> (forall a. Num a => t a -> a)
-> (forall a. Num a => t a -> a)
-> Foldable t
$cfold :: forall m. Monoid m => LoopInfo m -> m
fold :: forall m. Monoid m => LoopInfo m -> m
$cfoldMap :: forall m a. Monoid m => (a -> m) -> LoopInfo a -> m
foldMap :: forall m a. Monoid m => (a -> m) -> LoopInfo a -> m
$cfoldMap' :: forall m a. Monoid m => (a -> m) -> LoopInfo a -> m
foldMap' :: forall m a. Monoid m => (a -> m) -> LoopInfo a -> m
$cfoldr :: forall a b. (a -> b -> b) -> b -> LoopInfo a -> b
foldr :: forall a b. (a -> b -> b) -> b -> LoopInfo a -> b
$cfoldr' :: forall a b. (a -> b -> b) -> b -> LoopInfo a -> b
foldr' :: forall a b. (a -> b -> b) -> b -> LoopInfo a -> b
$cfoldl :: forall b a. (b -> a -> b) -> b -> LoopInfo a -> b
foldl :: forall b a. (b -> a -> b) -> b -> LoopInfo a -> b
$cfoldl' :: forall b a. (b -> a -> b) -> b -> LoopInfo a -> b
foldl' :: forall b a. (b -> a -> b) -> b -> LoopInfo a -> b
$cfoldr1 :: forall a. (a -> a -> a) -> LoopInfo a -> a
foldr1 :: forall a. (a -> a -> a) -> LoopInfo a -> a
$cfoldl1 :: forall a. (a -> a -> a) -> LoopInfo a -> a
foldl1 :: forall a. (a -> a -> a) -> LoopInfo a -> a
$ctoList :: forall a. LoopInfo a -> [a]
toList :: forall a. LoopInfo a -> [a]
$cnull :: forall a. LoopInfo a -> Bool
null :: forall a. LoopInfo a -> Bool
$clength :: forall a. LoopInfo a -> Int
length :: forall a. LoopInfo a -> Int
$celem :: forall a. Eq a => a -> LoopInfo a -> Bool
elem :: forall a. Eq a => a -> LoopInfo a -> Bool
$cmaximum :: forall a. Ord a => LoopInfo a -> a
maximum :: forall a. Ord a => LoopInfo a -> a
$cminimum :: forall a. Ord a => LoopInfo a -> a
minimum :: forall a. Ord a => LoopInfo a -> a
$csum :: forall a. Num a => LoopInfo a -> a
sum :: forall a. Num a => LoopInfo a -> a
$cproduct :: forall a. Num a => LoopInfo a -> a
product :: forall a. Num a => LoopInfo a -> a
Foldable, Functor LoopInfo
Foldable LoopInfo
(Functor LoopInfo, Foldable LoopInfo) =>
(forall (f :: * -> *) a b.
 Applicative f =>
 (a -> f b) -> LoopInfo a -> f (LoopInfo b))
-> (forall (f :: * -> *) a.
    Applicative f =>
    LoopInfo (f a) -> f (LoopInfo a))
-> (forall (m :: * -> *) a b.
    Monad m =>
    (a -> m b) -> LoopInfo a -> m (LoopInfo b))
-> (forall (m :: * -> *) a.
    Monad m =>
    LoopInfo (m a) -> m (LoopInfo a))
-> Traversable LoopInfo
forall (t :: * -> *).
(Functor t, Foldable t) =>
(forall (f :: * -> *) a b.
 Applicative f =>
 (a -> f b) -> t a -> f (t b))
-> (forall (f :: * -> *) a. Applicative f => t (f a) -> f (t a))
-> (forall (m :: * -> *) a b.
    Monad m =>
    (a -> m b) -> t a -> m (t b))
-> (forall (m :: * -> *) a. Monad m => t (m a) -> m (t a))
-> Traversable t
forall (m :: * -> *) a. Monad m => LoopInfo (m a) -> m (LoopInfo a)
forall (f :: * -> *) a.
Applicative f =>
LoopInfo (f a) -> f (LoopInfo a)
forall (m :: * -> *) a b.
Monad m =>
(a -> m b) -> LoopInfo a -> m (LoopInfo b)
forall (f :: * -> *) a b.
Applicative f =>
(a -> f b) -> LoopInfo a -> f (LoopInfo b)
$ctraverse :: forall (f :: * -> *) a b.
Applicative f =>
(a -> f b) -> LoopInfo a -> f (LoopInfo b)
traverse :: forall (f :: * -> *) a b.
Applicative f =>
(a -> f b) -> LoopInfo a -> f (LoopInfo b)
$csequenceA :: forall (f :: * -> *) a.
Applicative f =>
LoopInfo (f a) -> f (LoopInfo a)
sequenceA :: forall (f :: * -> *) a.
Applicative f =>
LoopInfo (f a) -> f (LoopInfo a)
$cmapM :: forall (m :: * -> *) a b.
Monad m =>
(a -> m b) -> LoopInfo a -> m (LoopInfo b)
mapM :: forall (m :: * -> *) a b.
Monad m =>
(a -> m b) -> LoopInfo a -> m (LoopInfo b)
$csequence :: forall (m :: * -> *) a. Monad m => LoopInfo (m a) -> m (LoopInfo a)
sequence :: forall (m :: * -> *) a. Monad m => LoopInfo (m a) -> m (LoopInfo a)
Traversable, Int -> LoopInfo a -> [Char] -> [Char]
[LoopInfo a] -> [Char] -> [Char]
LoopInfo a -> [Char]
(Int -> LoopInfo a -> [Char] -> [Char])
-> (LoopInfo a -> [Char])
-> ([LoopInfo a] -> [Char] -> [Char])
-> Show (LoopInfo a)
forall a. Show a => Int -> LoopInfo a -> [Char] -> [Char]
forall a. Show a => [LoopInfo a] -> [Char] -> [Char]
forall a. Show a => LoopInfo a -> [Char]
forall a.
(Int -> a -> [Char] -> [Char])
-> (a -> [Char]) -> ([a] -> [Char] -> [Char]) -> Show a
$cshowsPrec :: forall a. Show a => Int -> LoopInfo a -> [Char] -> [Char]
showsPrec :: Int -> LoopInfo a -> [Char] -> [Char]
$cshow :: forall a. Show a => LoopInfo a -> [Char]
show :: LoopInfo a -> [Char]
$cshowList :: forall a. Show a => [LoopInfo a] -> [Char] -> [Char]
showList :: [LoopInfo a] -> [Char] -> [Char]
Show)

-- | Transforms a for-loop into its reverse-mode derivative.
--
-- The loop is renamed first. The adjoints of its variable body results are
-- then seeded from the adjoints of the pattern the loop is bound to, and the
-- body is differentiated (via @diffStms@) to produce the body of an adjoint
-- loop. The adjoint loop's parameters are produced by 'valPatAdjs' for each
-- 'LoopInfo' component:
--
-- * the body results;
-- * the variables free in the loop;
-- * the variables used as initial loop values.
--
-- In the return sweep, the adjoint loop's results update the adjoints of:
--
-- * the initial loop values ('updateSubExpAdj');
-- * the free variables ('insAdj');
-- * the initial-value variables ('updateAdj').
revLoop :: (Stms SOACS -> ADM ()) -> Pat Type -> Exp SOACS -> ADM ()
revLoop diffStms pat loop =
  bindForLoop loop $ \val_pats _form _i _it _bound _body ->
    renameForLoop loop $
      \loop' val_pats' form' i' _it' _bound' body' -> do
        let loop_params = map fst val_pats
            (loop_params', loop_vals') = unzip val_pats'
            getVName Constant {} = Nothing
            getVName (Var v) = Just v
            -- Variables (as opposed to constants) among the renamed initial
            -- loop values; computed once and shared by two fields below.
            loop_val_vnames = mapMaybe getVName loop_vals'
            loop_vnames =
              LoopInfo
                { loopRes = mapMaybe subExpResVName $ bodyResult body',
                  loopFree = namesToList (freeIn loop') \\ loop_val_vnames,
                  loopVals = nubOrd loop_val_vnames
                }

        -- Map the original loop parameter names to their renamed
        -- counterparts (see 'renameLoopTape').
        renameLoopTape $
          M.fromList $
            zip (map paramName loop_params) (map paramName loop_params')

        -- Seed the adjoint of each variable body result with the adjoint of
        -- the corresponding pattern element; constant results are skipped.
        forM_ (zip (bodyResult body') $ patElems pat) $ \(se_res, pe) ->
          case subExpResVName se_res of
            Just v -> setAdj v =<< lookupAdj (patElemName pe)
            Nothing -> pure ()

        -- NOTE(review): presumably these make the adjoint body see the
        -- iteration index in reverse order; confirm against 'reverseIndices'.
        (i_subst, i_stms) <- reverseIndices loop'

        val_pat_adjs <- valPatAdjs loop_vnames
        let val_pat_adjs_list = concat $ toList val_pat_adjs

        (loop_adjs, stms_adj) <- collectStms $
          localScope
            ( scopeOfLoopForm form'
                <> scopeOfFParams (map fst val_pat_adjs_list <> loop_params')
            )
            $ do
              addStms i_stms
              -- Differentiate the loop body in a nested AD context, using the
              -- adjoint loop parameters as the current adjoints.
              (inner_adjs, inner_stms) <- collectStms $
                subAD $ do
                  zipWithM_
                    (\val_pat v -> insAdj v (paramName $ fst val_pat))
                    val_pat_adjs_list
                    (concat $ toList loop_vnames)
                  diffStms $ bodyStms body'

                  loop_res_adjs <- mapM (lookupAdjVal . paramName) loop_params'
                  loop_free_adjs <- mapM lookupAdjVal $ loopFree loop_vnames
                  loop_vals_adjs <- mapM lookupAdjVal $ loopVals loop_vnames

                  pure $
                    LoopInfo
                      { loopRes = loop_res_adjs,
                        loopFree = loop_free_adjs,
                        loopVals = loop_vals_adjs
                      }
              -- NOTE(review): 'restore' presumably re-establishes the values
              -- of the loop parameters for iteration i' before the adjoint
              -- statements run; confirm.
              (substs, restore_stms) <-
                collectStms $ restore inner_stms loop_params' i'
              addStms $ substituteNames i_subst restore_stms
              addStms $ substituteNames i_subst $ substituteNames substs inner_stms
              pure inner_adjs

        inScopeOf stms_adj $
          localScope (scopeOfFParams $ map fst val_pat_adjs_list) $ do
            -- NOTE(review): the adjoint loop's parameters come from
            -- 'valPatAdjs' on @loop_vnames@, whose 'loopRes' keeps only the
            -- variable body results. The body below has one 'loopRes' result
            -- per loop parameter. This seems to assume every body result is a
            -- variable; confirm.
            let body_adj =
                  mkBody stms_adj $ varsRes $ concat $ toList loop_adjs
                -- Loop parameters carrying the @true_dep@ attribute are
                -- replaced, in the adjoint loop, by the name bound by the
                -- corresponding pattern element of the original loop.
                restore_true_deps =
                  M.fromList
                    [ (paramName p, patElemName pe)
                      | (p, pe) <- zip loop_params' $ patElems pat,
                        inAttrs (AttrName "true_dep") $ paramAttrs p
                    ]
            adjs' <-
              letTupExp "loop_adj" $
                substituteNames restore_true_deps $
                  Loop val_pat_adjs_list form' body_adj
            -- The adjoint loop returns its results in 'LoopInfo' field order.
            let (loop_res_adjs, loop_free_var_val_adjs) =
                  splitAt (length $ loopRes loop_adjs) adjs'
                (loop_free_adjs, loop_val_adjs) =
                  splitAt (length $ loopFree loop_adjs) loop_free_var_val_adjs
            returnSweepCode $ do
              zipWithM_ updateSubExpAdj loop_vals' loop_res_adjs
              zipWithM_ insAdj (loopFree loop_vnames) loop_free_adjs
              zipWithM_ updateAdj (loopVals loop_vnames) loop_val_adjs

-- | Transforms a loop into its reverse-mode derivative.
--
-- While-loops are first converted to for-loops and then differentiated
-- recursively:
--
-- * If the statement carries a @bound@ attribute with an integer argument,
--   the first such value (as an 'Int64' constant) is used as the trip count.
-- * Otherwise the trip count is obtained from 'computeWhileIters', and a
--   renamed copy of the loop is converted.
--
-- A for-loop is handled by emitting the forward sweep ('fwdLoop'), then
-- running the continuation @m@, and finally emitting the reverse sweep
-- ('revLoop').
diffLoop :: (Stms SOACS -> ADM ()) -> Pat Type -> StmAux () -> Exp SOACS -> ADM () -> ADM ()
diffLoop diffStms pat aux loop m
  | isWhileLoop loop = do
      for_loop <- case attrBounds of
        bound : _ ->
          convertWhileLoop (Constant $ IntValue $ intValue Int64 bound) loop
        [] -> do
          iters <- computeWhileIters loop
          convertWhileLoop iters =<< renameExp loop
      diffLoop diffStms pat aux for_loop m
  | otherwise = do
      fwdLoop pat aux loop
      m
      revLoop diffStms pat loop
  where
    -- Integer arguments of every @bound(n)@ attribute on the statement.
    attrBounds = catMaybes $ mapAttrs boundOf $ stmAuxAttrs aux
    boundOf (AttrComp "bound" [AttrInt b]) = Just b
    boundOf _ = Nothing