{-# LANGUAGE TypeFamilies #-}

module Futhark.Optimise.InPlaceLowering.LowerIntoStm
  ( lowerUpdateGPU,
    lowerUpdate,
    LowerUpdate,
    DesiredUpdate (..),
  )
where

import Control.Monad
import Control.Monad.Writer
import Data.Either
import Data.List (find, unzip4)
import Data.Maybe (isNothing, mapMaybe)
import Futhark.Analysis.PrimExp.Convert
import Futhark.Construct
import Futhark.IR.Aliases
import Futhark.IR.GPU
import Futhark.Optimise.InPlaceLowering.SubstituteIndices

-- | Description of an in-place update that we would like to lower
-- into the statement producing the value being written.  The @dec@
-- parameter is the decoration (type information) attached to the
-- result.
data DesiredUpdate dec = DesiredUpdate
  { -- | Name of result.
    updateName :: VName,
    -- | Type of result.
    updateType :: dec,
    -- | Certificates of the update; these are attached to the
    -- statement that the update is lowered into.
    updateCerts :: Certs,
    -- | Name of the array that is being updated.
    updateSource :: VName,
    -- | The slice of the source that is written.
    updateIndices :: Slice SubExp,
    -- | Name of the value that is written into the slice.
    updateValue :: VName
  }
  deriving (Show)

-- | Mapping over a 'DesiredUpdate' transforms only the decoration
-- stored in 'updateType'; every other field is left untouched.
instance Functor DesiredUpdate where
  fmap f u = u {updateType = f (updateType u)}

-- | Does the given update write the value with the given name?
updateHasValue :: VName -> DesiredUpdate dec -> Bool
updateHasValue name = (name ==) . updateValue

-- | The type of a function that tries to lower a list of desired
-- updates into a single statement.  It is given the scope and the
-- statement, and yields 'Nothing' when the updates cannot be lowered
-- into that statement; otherwise it yields an action in @m@ producing
-- the statements that replace the original one.
type LowerUpdate rep m =
  Scope (Aliases rep) ->
  Stm (Aliases rep) ->
  [DesiredUpdate (LetDec (Aliases rep))] ->
  Maybe (m [Stm (Aliases rep)])

-- | Representation-generic lowering of updates into a statement.
--
-- Two statement shapes are handled:
--
--   * A loop: the work is delegated to 'lowerUpdateIntoLoop', and the
--     loop is rebuilt from the pattern, merge parameters and body it
--     returns, surrounded by the pre- and post-statements it produces.
--
--   * A statement that merely binds a variable @v@, where the single
--     desired update has exactly the names bound by that statement as
--     its source: the statement is replaced by an update directly on
--     @v@.
--
-- Anything else yields 'Nothing'.
lowerUpdate ::
  ( MonadFreshNames m,
    Buildable rep,
    LetDec rep ~ Type,
    AliasableRep rep
  ) =>
  LowerUpdate rep m
lowerUpdate scope (Let pat aux (DoLoop merge form body)) updates = do
  -- Fails (in 'Maybe') if the updates cannot be lowered into the loop.
  canDo <- lowerUpdateIntoLoop scope updates pat merge form body
  Just $ do
    (prestms, poststms, pat', merge', body') <- canDo
    pure $
      prestms
        ++ [ certify (stmAuxCerts aux) $
               mkLet pat' $
                 DoLoop merge' form body'
           ]
        ++ poststms
lowerUpdate
  _
  (Let pat aux (BasicOp (SubExp (Var v))))
  [DesiredUpdate bindee_nm bindee_dec cs src (Slice is) val]
    | patNames pat == [src] =
        -- Pad the slice so that it covers every dimension of the
        -- result type.
        let is' = fullSlice (typeOf bindee_dec) is
         in Just . pure $
              [ -- The new statement carries the certificates of both
                -- the original statement and the update.
                certify (stmAuxCerts aux <> cs) $
                  mkLet [Ident bindee_nm $ typeOf bindee_dec] $
                    BasicOp $
                      Update Unsafe v is' $
                        Var val
              ]
lowerUpdate _ _ _ =
  Nothing

-- | Lowering of updates for the GPU representation.
--
-- A 'SegMap' is handled specially via 'lowerUpdatesIntoSegMap',
-- provided that every update writes a value bound by the pattern of
-- the 'SegMap', and that no update source (or alias of one) is used
-- inside the kernel body.  The resulting statement carries the
-- certificates of the original statement and of all the updates.
-- Every other statement is passed on to 'lowerUpdate'.
lowerUpdateGPU :: MonadFreshNames m => LowerUpdate GPU m
lowerUpdateGPU
  scope
  (Let pat aux (Op (SegOp (SegMap lvl space ts kbody))))
  updates
    | all ((`elem` patNames pat) . updateValue) updates,
      not source_used_in_kbody = do
        mk <- lowerUpdatesIntoSegMap scope pat updates space kbody
        Just $ do
          (pat', kbody', poststms) <- mk
          let cs = stmAuxCerts aux <> foldMap updateCerts updates
          pure $
            certify cs (Let pat' aux $ Op $ SegOp $ SegMap lvl space ts kbody')
              : stmsToList poststms
    where
      -- This check is a bit more conservative than ideal.  In a perfect
      -- world, we would allow indexing a[i,j] if the update is also
      -- to exactly a[i,j], as that will not create cross-iteration
      -- dependencies.  (Although the type checker wouldn't be able to
      -- permit this anyway.)
      source_used_in_kbody =
        mconcat (map (`lookupAliases` scope) (namesToList (freeIn kbody)))
          `namesIntersect` mconcat (map ((`lookupAliases` scope) . updateSource) updates)
lowerUpdateGPU scope stm updates = lowerUpdate scope stm updates

-- | Try to lower the given updates into the body of a 'SegMap'.
--
-- Each pattern element is paired with the corresponding kernel
-- result.  If the element is the value written by one of the updates,
-- the kernel result is turned into a 'WriteReturns' that writes
-- directly into the update source, and the pattern element is
-- replaced by the result of the update.  All other pattern elements
-- and results are left as they are.
--
-- On success, the result is an action producing the new pattern, the
-- kernel body extended with the statements that compute the write
-- indices, and the statements that must follow the 'SegMap' in order
-- to rebind the original pattern names.  'Nothing' is produced if any
-- single update cannot be lowered.
lowerUpdatesIntoSegMap ::
  MonadFreshNames m =>
  Scope (Aliases GPU) ->
  Pat (LetDec (Aliases GPU)) ->
  [DesiredUpdate (LetDec (Aliases GPU))] ->
  SegSpace ->
  KernelBody (Aliases GPU) ->
  Maybe
    ( m
        ( Pat (LetDec (Aliases GPU)),
          KernelBody (Aliases GPU),
          Stms (Aliases GPU)
        )
    )
lowerUpdatesIntoSegMap scope pat updates kspace kbody = do
  -- The updates are all-or-nothing.  Being more liberal would require
  -- changes to the in-place-lowering pass itself.
  mk <- zipWithM onRet (patElems pat) (kernelBodyResult kbody)
  pure $ do
    (pes, bodystms, krets, poststms) <- unzip4 <$> sequence mk
    pure
      ( Pat pes,
        kbody
          { kernelBodyStms = kernelBodyStms kbody <> mconcat bodystms,
            kernelBodyResult = krets
          },
        mconcat poststms
      )
  where
    -- The thread indices of the segment space.
    (gtids, _dims) = unzip $ unSegSpace kspace

    -- Handle a single pattern element and its kernel result.
    onRet (PatElem v v_dec) ret
      | Just (DesiredUpdate bindee_nm bindee_dec _cs src slice _val) <-
          find ((== v) . updateValue) updates = do
          -- Only a plain 'Returns' can be turned into a write; any
          -- other kind of result makes the pattern match fail, which
          -- aborts the lowering.
          Returns _ cs se <- Just ret

          -- The slice we're writing per thread must fully cover the
          -- underlying dimensions.
          guard $
            let (dims', slice') =
                  unzip . drop (length gtids) . filter (isNothing . dimFix . snd) $
                    zip (arrayDims (typeOf bindee_dec)) (unSlice slice)
             in isFullSlice (Shape dims') (Slice slice')

          Just $ do
            -- Compute, inside the kernel body, the concrete indices
            -- that the current thread writes to.
            (slice', bodystms) <-
              flip runBuilderT scope $
                traverse (toSubExp "index") $
                  fixSlice (fmap pe64 slice) $
                    map (pe64 . Var) gtids

            let res_dims = take (length slice') $ arrayDims $ snd bindee_dec
                ret' = WriteReturns cs (Shape res_dims) src [(Slice $ map DimFix slice', se)]

            v_aliased <- newName v

            pure
              ( PatElem bindee_nm bindee_dec,
                bodystms,
                ret',
                -- Rebind the original name by reading the written
                -- slice back out of the updated array.
                stmsFromList
                  [ mkLet [Ident v_aliased $ typeOf v_dec] $ BasicOp $ Index bindee_nm slice,
                    mkLet [Ident v $ typeOf v_dec] $ BasicOp $ Replicate mempty $ Var v_aliased
                  ]
              )
    -- Not the target of any update: keep the element and result as is.
    onRet pe ret =
      Just $ pure (pe, mempty, ret, mempty)

-- | Try to lower a set of desired in-place updates into a loop.
--
-- Returns 'Nothing' when a safety condition fails or when
-- 'summariseLoop' cannot produce a summary for the loop results.
-- Otherwise returns an action that produces, in order:
--
--   * statements collected as "pre" statements (the 'Update's that
--     build the new initial merge values),
--
--   * statements collected as "post" statements (the 'Index' reads and
--     copies that recover the updated values),
--
--   * the identifiers of the new loop pattern,
--
--   * the new merge parameters with their initial values, and
--
--   * the rewritten loop body.
lowerUpdateIntoLoop ::
  ( Buildable rep,
    BuilderOps rep,
    Aliased rep,
    LetDec rep ~ (als, Type),
    MonadFreshNames m
  ) =>
  Scope rep ->
  [DesiredUpdate (LetDec rep)] ->
  Pat (LetDec rep) ->
  [(FParam rep, SubExp)] ->
  LoopForm rep ->
  Body rep ->
  Maybe
    ( m
        ( [Stm rep],
          [Stm rep],
          [Ident],
          [(FParam rep, SubExp)],
          Body rep
        )
    )
lowerUpdateIntoLoop scope updates pat val form body = do
  -- Algorithm:
  --
  --   0) Map each result of the loop body to a corresponding in-place
  --      update, if one exists.
  --
  --   1) Create new merge variables corresponding to the arrays being
  --      updated; extend the pattern and the @res@ list with these,
  --      and remove the parts of the result list that have a
  --      corresponding in-place update.
  --
  --      (The creation of the new merge variable identifiers is
  --      actually done at the same time as step (0)).
  --
  --   2) Create in-place updates at the end of the loop body.
  --
  --   3) Create index expressions that read back the values written
  --      in (2).  If the merge parameter corresponding to this value
  --      is unique, also @copy@ this value.
  --
  --   4) Update the result of the loop body to properly pass the new
  --      arrays and indexed elements to the next iteration of the
  --      loop.
  --
  -- We also check that the merge parameters we work with have
  -- loop-invariant shapes.

  -- Safety condition (8): no merge parameter may occur among the
  -- aliases of the corresponding body result.
  forM_ (zip val $ bodyAliases body) $ \((p, _), als) ->
    guard $ paramName p `notNameIn` als

  mk_in_place_map <- summariseLoop scope updates usedInBody resmap val

  Just $ do
    in_place_map <- mk_in_place_map
    (val', prestms, poststms) <- mkMerges in_place_map
    let valpat = mkResAndPat in_place_map
        idxsubsts = indexSubstitutions in_place_map
    (idxsubsts', newstms) <- substituteIndices idxsubsts $ bodyStms body
    (body_res, res_stms) <- manipulateResult in_place_map idxsubsts'
    let body' = mkBody (newstms <> res_stms) body_res
    pure (prestms, poststms, valpat, val', body')
  where
    -- Aliases (looked up in the scope) of every name free in the loop
    -- body or the loop form.
    usedInBody :: Names
    usedInBody =
      mconcat $
        map (`lookupAliases` scope) $
          namesToList $
            freeIn body <> freeIn form

    -- Each body result paired with the pattern identifier binding it.
    resmap :: [(SubExpRes, Ident)]
    resmap = zip (bodyResult body) $ patIdents pat

    -- Build the new merge parameter list (untouched parameters first,
    -- then the ones introduced for updates) together with the pre- and
    -- post-statements emitted while doing so.
    mkMerges ::
      (MonadFreshNames m, Buildable rep) =>
      [LoopResultSummary (als, Type)] ->
      m ([(Param DeclType, SubExp)], [Stm rep], [Stm rep])
    mkMerges summaries = do
      ((origmerge, extramerge), (prestms, poststms)) <-
        runWriterT $ partitionEithers <$> mapM mkMerge summaries
      pure (origmerge ++ extramerge, prestms, poststms)

    -- 'Right' for a result with a related update (a fresh unique merge
    -- parameter initialised from the updated source), 'Left' for a
    -- result whose merge parameter is kept as it is.
    mkMerge summary
      | Just (update, mergename, mergedec) <- relatedUpdate summary = do
          source <- newVName "modified_source"
          precopy <- newVName $ baseString (updateValue update) <> "_precopy"
          let source_t = snd $ updateType update
              -- The same full slice is used for writing the initial
              -- value and for reading the final value back.
              full_is = fullSlice source_t $ unSlice $ updateIndices update
              elm_t = source_t `setArrayDims` sliceDims (updateIndices update)
              merge_init = snd $ mergeParam summary
          tell
            ( [ mkLet [Ident source source_t] . BasicOp $
                  Update Unsafe (updateSource update) full_is merge_init
              ],
              [ mkLet [Ident precopy elm_t] . BasicOp $
                  Index (updateName update) full_is,
                mkLet [Ident (updateValue update) elm_t] . BasicOp $
                  Replicate mempty (Var precopy)
              ]
            )
          pure $
            Right
              ( Param mempty mergename (toDecl (typeOf mergedec) Unique),
                Var source
              )
      | otherwise = pure $ Left $ mergeParam summary

    -- Identifiers of the new loop pattern, in the same order as the
    -- merge parameters produced by 'mkMerges'.
    mkResAndPat :: [LoopResultSummary (a, Type)] -> [Ident]
    mkResAndPat summaries =
      let (origpat, extrapat) = partitionEithers $ map mkResAndPat' summaries
       in origpat ++ extrapat

    mkResAndPat' :: LoopResultSummary (a, Type) -> Either Ident Ident
    mkResAndPat' summary
      | Just (update, _, _) <- relatedUpdate summary =
          Right (Ident (updateName update) (snd $ updateType update))
      | otherwise =
          Left (inPatAs summary)

-- | Pair every loop result with its merge parameter and with the
-- desired update whose value is that result.
--
-- The arguments are the scope, the desired updates, the names
-- considered used in the loop body, the body results paired with the
-- pattern identifiers binding them, and the merge parameters with
-- their initial values.
--
-- The outcome is 'Nothing' as soon as a single result cannot be
-- summarised; otherwise it is an action that picks a fresh name for
-- each lowered array and yields one summary per result.
summariseLoop ::
  ( Aliased rep,
    MonadFreshNames m
  ) =>
  Scope rep ->
  [DesiredUpdate (als, Type)] ->
  Names ->
  [(SubExpRes, Ident)] ->
  [(Param DeclType, SubExp)] ->
  Maybe (m [LoopResultSummary (als, Type)])
summariseLoop scope updates usedInBody resmap merge =
  sequence <$> zipWithM summariseLoopResult resmap merge
  where
    summariseLoopResult (se, v) (fparam, mergeinit)
      | Just update <- find (updateHasValue $ identName v) updates,
        -- Safety condition (7): the source of the update must not
        -- alias anything used in the body.
        not $ usedInBody `namesIntersect` lookupAliases (updateSource update) scope,
        hasLoopInvariantShape fparam =
          Just $ do
            lowered_array <- newVName "lowered_array"
            pure
              LoopResultSummary
                { resultSubExp = se,
                  inPatAs = v,
                  mergeParam = (fparam, mergeinit),
                  relatedUpdate =
                    Just
                      ( update,
                        lowered_array,
                        updateType update
                      )
                }
      -- Covers both a failed check above and a result that has no
      -- matching update.
      -- XXX: conservative; but this entire pass is going away.
      | otherwise = Nothing

    -- True when no dimension of the parameter's type is itself one of
    -- the merge parameters.
    hasLoopInvariantShape :: Param DeclType -> Bool
    hasLoopInvariantShape = all loopInvariant . arrayDims . paramType

    merge_param_names :: [VName]
    merge_param_names = map (paramName . fst) merge

    loopInvariant :: SubExp -> Bool
    loopInvariant (Var v) = v `notElem` merge_param_names
    loopInvariant Constant {} = True

-- | What is known about a single loop result while lowering updates
-- into the loop.
data LoopResultSummary dec = LoopResultSummary
  { -- | The result as it appears in the loop body.
    resultSubExp :: SubExpRes,
    -- | The loop pattern identifier that binds this result.
    inPatAs :: Ident,
    -- | The merge parameter for this result and its initial value.
    mergeParam :: (Param DeclType, SubExp),
    -- | If an update is being lowered for this result: the update
    -- itself, the name chosen for the lowered array, and the
    -- decoration to give it.
    relatedUpdate :: Maybe (DesiredUpdate dec, VName, dec)
  }
  deriving (Show)

-- | Collect one index substitution for every summary that has a
-- related update.  Each substitution is keyed by the name of the
-- summary's merge parameter and carries the update's certificates, the
-- name of the lowered array, the type of that array, and the update's
-- slice.  Summaries without a related update contribute nothing.
indexSubstitutions :: Typed dec => [LoopResultSummary dec] -> IndexSubstitutions
indexSubstitutions = mapMaybe getSubstitution
  where
    getSubstitution res = do
      (update, nm, dec) <- relatedUpdate res
      let name = paramName $ fst $ mergeParam res
      -- Record selectors rather than a positional pattern, so the
      -- result does not depend on the field order of 'DesiredUpdate'.
      pure (name, (updateCerts update, nm, typeOf dec, updateIndices update))

manipulateResult ::
  (Buildable rep, MonadFreshNames m) =>
  [LoopResultSummary (LetDec rep)] ->
  IndexSubstitutions ->
  m (Result, Stms rep)
manipulateResult :: forall rep (m :: * -> *).
(Buildable rep, MonadFreshNames m) =>
[LoopResultSummary (LetDec rep)]
-> IndexSubstitutions -> m (Result, Stms rep)
manipulateResult [LoopResultSummary (LetDec rep)]
summaries IndexSubstitutions
substs = do
  let (Result
orig_ses, Result
updated_ses) = [Either SubExpRes SubExpRes] -> (Result, Result)
forall a b. [Either a b] -> ([a], [b])
partitionEithers ([Either SubExpRes SubExpRes] -> (Result, Result))
-> [Either SubExpRes SubExpRes] -> (Result, Result)
forall a b. (a -> b) -> a -> b
$ (LoopResultSummary (LetDec rep) -> Either SubExpRes SubExpRes)
-> [LoopResultSummary (LetDec rep)] -> [Either SubExpRes SubExpRes]
forall a b. (a -> b) -> [a] -> [b]
map LoopResultSummary (LetDec rep) -> Either SubExpRes SubExpRes
forall {dec}. LoopResultSummary dec -> Either SubExpRes SubExpRes
unchangedRes [LoopResultSummary (LetDec rep)]
summaries
  (Result
subst_ses, [Stm rep]
res_stms) <- WriterT [Stm rep] m Result -> m (Result, [Stm rep])
forall w (m :: * -> *) a. WriterT w m a -> m (a, w)
runWriterT (WriterT [Stm rep] m Result -> m (Result, [Stm rep]))
-> WriterT [Stm rep] m Result -> m (Result, [Stm rep])
forall a b. (a -> b) -> a -> b
$ (SubExpRes
 -> (VName, IndexSubstitution) -> WriterT [Stm rep] m SubExpRes)
-> Result -> IndexSubstitutions -> WriterT [Stm rep] m Result
forall (m :: * -> *) a b c.
Applicative m =>
(a -> b -> m c) -> [a] -> [b] -> m [c]
zipWithM SubExpRes
-> (VName, IndexSubstitution) -> WriterT [Stm rep] m SubExpRes
forall {rep} {f :: * -> *} {t}.
(MonadFreshNames f, MonadWriter [Stm rep] f, Buildable rep,
 Typed t) =>
SubExpRes
-> (VName, (Certs, VName, t, Slice SubExp)) -> f SubExpRes
substRes Result
updated_ses IndexSubstitutions
substs
  (Result, Stms rep) -> m (Result, Stms rep)
forall a. a -> m a
forall (f :: * -> *) a. Applicative f => a -> f a
pure (Result
orig_ses Result -> Result -> Result
forall a. [a] -> [a] -> [a]
++ Result
subst_ses, [Stm rep] -> Stms rep
forall rep. [Stm rep] -> Stms rep
stmsFromList [Stm rep]
res_stms)
  where
    -- Classify one loop result summary.  The result sub-expression of the
    -- summary is wrapped in 'Left' when the summary has no related update,
    -- and in 'Right' when it has one.
    unchangedRes summary
      | isNothing (relatedUpdate summary) = Left res
      | otherwise = Right res
      where
        res = resultSubExp summary
    -- Rewrite one loop result according to its index substitution.
    --
    -- If the result is a variable with the same name as the substitution
    -- key, the result is simply redirected to the substitution's array
    -- name, keeping the result's own certificates.
    --
    -- Otherwise an in-place 'Update' statement is emitted through the
    -- writer: it writes the result value into the substitution's array at
    -- the given indices (extended by 'fullSlice' against the array type),
    -- certified by both the result's and the substitution's certificates.
    -- The freshly bound identifier becomes the new result.
    substRes (SubExpRes res_cs res_se) (subst_v, (cs, nm, dec, Slice is))
      | Var res_v <- res_se,
        res_v == subst_v =
          pure $ SubExpRes res_cs (Var nm)
      | otherwise = do
          let arr_t = typeOf dec
          -- Fresh identifier for the updated array, created from the
          -- target array's identifier with the "_updated" suffix function.
          updated <- newIdent' (++ "_updated") (Ident nm arr_t)
          let upd = BasicOp $ Update Unsafe nm (fullSlice arr_t is) res_se
          tell [certify (res_cs <> cs) $ mkLet [updated] upd]
          pure . varRes $ identName updated