{-# LANGUAGE DeriveGeneric #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE UndecidableInstances #-}

-- | A representation where all bindings are annotated with aliasing
-- information.
module Futhark.IR.Aliases
  ( -- * The Lore definition
    Aliases,
    AliasDec (..),
    VarAliases,
    ConsumedInExp,
    BodyAliasing,
    module Futhark.IR.Prop.Aliases,

    -- * Module re-exports
    module Futhark.IR.Prop,
    module Futhark.IR.Traversals,
    module Futhark.IR.Pretty,
    module Futhark.IR.Syntax,

    -- * Adding aliases
    addAliasesToPattern,
    mkAliasedLetStm,
    mkAliasedBody,
    mkPatternAliases,
    mkBodyAliases,

    -- * Removing aliases
    removeProgAliases,
    removeFunDefAliases,
    removeExpAliases,
    removeStmAliases,
    removeLambdaAliases,
    removePatternAliases,
    removeScopeAliases,

    -- * Tracking aliases
    AliasesAndConsumed,
    trackAliases,
    consumedInStms,
  )
where

import Control.Monad.Identity
import Control.Monad.Reader
import qualified Data.Map.Strict as M
import Data.Maybe
import Futhark.Analysis.Rephrase
import Futhark.Binder
import Futhark.IR.Pretty
import Futhark.IR.Prop
import Futhark.IR.Prop.Aliases
import Futhark.IR.Syntax
import Futhark.IR.Traversals
import Futhark.Transform.Rename
import Futhark.Transform.Substitute
import qualified Futhark.Util.Pretty as PP
import GHC.Generics
import Language.SexpGrammar as Sexp
import Language.SexpGrammar.Generic

-- | A lore that wraps another lore, @lore@, keeping all of its
-- decorations and additionally annotating let-bindings, expressions
-- and bodies with aliasing information.
data Aliases lore

-- | A wrapper around 'Names' used as the aliasing part of a
-- decoration.  Its 'Eq' and 'Ord' instances deliberately treat all
-- values as equal, so aliasing information never influences the
-- equality or ordering of decorations that contain it.
newtype AliasDec = AliasDec {unAliases :: Names}
  deriving (Show, Generic)

-- | An 'AliasDec' is serialised exactly as the 'Names' it wraps: the
-- 'Names' grammar is followed by the (single) constructor.
instance SexpIso AliasDec where
  sexpIso = with (sexpIso >>>)

-- | Combining two alias decorations combines the wrapped 'Names'
-- with their own '<>'.
instance Semigroup AliasDec where
  AliasDec xs <> AliasDec ys = AliasDec (xs <> ys)

-- | The empty alias decoration wraps the empty 'Names'.
instance Monoid AliasDec where
  mempty = AliasDec mempty

-- Aliasing information is intentionally invisible to comparisons:
-- every 'AliasDec' is equal to every other one, so decorations that
-- contain an 'AliasDec' compare the same regardless of their aliases.
instance Eq AliasDec where
  (==) _ _ = True

-- | Consistent with the 'Eq' instance: all values compare as 'EQ'.
instance Ord AliasDec where
  compare _ _ = EQ

-- | Renaming an 'AliasDec' renames the names it wraps.
instance Rename AliasDec where
  rename = fmap AliasDec . rename . unAliases

-- | Substitution is applied to the wrapped names.
instance Substitute AliasDec where
  substituteNames substs = AliasDec . substituteNames substs . unAliases

-- | The names recorded as aliases are not reported as free variables:
-- an 'AliasDec' contributes nothing to the free-variable set.
instance FreeIn AliasDec where
  freeIn' _ = mempty

-- | Print the aliased names as a comma-separated list.
instance PP.Pretty AliasDec where
  ppr (AliasDec als) = PP.commasep [PP.ppr v | v <- namesToList als]

-- | The aliases of the let-bound variable.  This is the aliasing
-- component of a let-binding's decoration.
type VarAliases = AliasDec

-- | Everything consumed in the expression.  This is the aliasing
-- component of an expression's decoration.
type ConsumedInExp = AliasDec

-- | The aliasing decoration of a t'Body': a list of aliases for what
-- the body returns (read back by 'bodyAliases'), paired with
-- everything consumed inside of it (read back by 'consumedInBody').
type BodyAliasing = ([VarAliases], ConsumedInExp)

-- | 'Aliases' keeps every decoration of the underlying lore.  It pairs
-- let-bindings with 'VarAliases', expressions with 'ConsumedInExp' and
-- bodies with 'BodyAliasing'.  Parameter, return and branch types are
-- inherited unchanged, and operations are wrapped in 'OpWithAliases'.
instance
  (Decorations lore, CanBeAliased (Op lore)) =>
  Decorations (Aliases lore)
  where
  type LetDec (Aliases lore) = (VarAliases, LetDec lore)
  type ExpDec (Aliases lore) = (ConsumedInExp, ExpDec lore)
  type BodyDec (Aliases lore) = (BodyAliasing, BodyDec lore)
  type FParamInfo (Aliases lore) = FParamInfo lore
  type LParamInfo (Aliases lore) = LParamInfo lore
  type RetType (Aliases lore) = RetType lore
  type BranchType (Aliases lore) = BranchType lore
  type Op (Aliases lore) = OpWithAliases (Op lore)

-- | The aliases of a let-bound variable are read from the 'AliasDec'
-- half of its decoration; the inner decoration is ignored.
instance AliasesOf (VarAliases, dec) where
  aliasesOf (AliasDec als, _) = als

-- | No methods are defined here; 'AliasDec' uses the class defaults.
instance FreeDec AliasDec

-- | Run a computation that expects a scope of the underlying @lore@.
-- The current 'Aliases' scope is stripped of its aliasing information
-- (with 'removeScopeAliases') and supplied as the reader environment.
withoutAliases ::
  (HasScope (Aliases lore) m, Monad m) =>
  ReaderT (Scope lore) m a ->
  m a
withoutAliases action =
  asksScope removeScopeAliases >>= runReaderT action

instance (ASTLore lore, CanBeAliased (Op lore)) => ASTLore (Aliases lore) where
  -- Strip the aliases from the pattern and delegate to the underlying
  -- lore, inside a scope that has likewise been stripped of aliases.
  expTypesFromPattern pat =
    withoutAliases $ expTypesFromPattern $ removePatternAliases pat

instance (ASTLore lore, CanBeAliased (Op lore)) => Aliased (Aliases lore) where
  -- A body decoration has the shape
  -- ((result aliases, consumed names), inner decoration).
  bodyAliases body =
    case bodyDec body of
      ((result_als, _), _) -> map unAliases result_als
  consumedInBody body =
    case bodyDec body of
      ((_, consumed), _) -> unAliases consumed

instance
  PrettyAnnot (PatElemT dec) =>
  PrettyAnnot (PatElemT (VarAliases, dec))
  where
  -- Emit the alias comment (if any), followed by the annotation of the
  -- underlying decoration (if any).  When both exist, they are joined
  -- with 'PP.</>'.
  ppAnnot (PatElem name (AliasDec als, dec)) =
    case (PP.oneLine <$> aliasComment name als, ppAnnot (PatElem name dec)) of
      (Just alias_doc, Just inner_doc) -> Just $ alias_doc PP.</> inner_doc
      (Nothing, inner) -> inner
      (alias, Nothing) -> alias

instance (ASTLore lore, CanBeAliased (Op lore)) => PrettyLore (Aliases lore) where
  -- Annotate an expression with what it consumes, with the aliases of
  -- loop results (for 'DoLoop'), and with whatever the underlying lore
  -- prints for its own decoration.
  ppExpLore (consumed, inner) e =
    maybeComment $
      catMaybes
        [ exp_dec,
          merge_dec,
          ppExpLore inner $ removeExpAliases e
        ]
    where
      -- For loops, describe what the result for each non-primitive loop
      -- parameter aliases.  The body returns the context results first
      -- and then the value results; 'mkContextAliases' in this module
      -- zips @ctxmerge ++ valmerge@ against 'bodyAliases' in the same
      -- way.  Zipping only the value parameters here would attribute
      -- context results to value parameters whenever the context is
      -- non-empty.
      merge_dec =
        case e of
          DoLoop ctxmerge valmerge _ body ->
            let mergeParamAliases fparam als
                  | primType (paramType fparam) =
                    Nothing
                  | otherwise =
                    resultAliasComment (paramName fparam) als
             in maybeComment $
                  catMaybes $
                    zipWith mergeParamAliases (map fst $ ctxmerge ++ valmerge) $
                      bodyAliases body
          _ -> Nothing

      -- Report everything consumed by the expression, if anything.
      exp_dec = case namesToList $ unAliases consumed of
        [] -> Nothing
        als ->
          Just $
            PP.oneLine $
              PP.text "-- Consumes " <> PP.commasep (map PP.ppr als)

-- | Join a list of comment documents with 'PP.</>'.  Returns 'Nothing'
-- when there is nothing to print.
maybeComment :: [PP.Doc] -> Maybe PP.Doc
maybeComment cs
  | null cs = Nothing
  | otherwise = Just (PP.folddoc (PP.</>) cs)

-- | A one-line comment of the form @-- NAME aliases A, B, C@.  Returns
-- 'Nothing' if the alias set is empty.
aliasComment :: PP.Pretty a => a -> Names -> Maybe PP.Doc
aliasComment = aliasCommentWithPrefix "-- "

-- | A one-line comment of the form @-- Result of NAME aliases A, B, C@.
-- Returns 'Nothing' if the alias set is empty.
resultAliasComment :: PP.Pretty a => a -> Names -> Maybe PP.Doc
resultAliasComment = aliasCommentWithPrefix "-- Result of "

-- | Shared implementation of 'aliasComment' and 'resultAliasComment'.
-- Renders @PREFIX NAME aliases A, B, C@ on a single line, or returns
-- 'Nothing' if the alias set is empty.
aliasCommentWithPrefix :: PP.Pretty a => String -> a -> Names -> Maybe PP.Doc
aliasCommentWithPrefix prefix name als =
  case namesToList als of
    [] -> Nothing
    als' ->
      Just $
        PP.oneLine $
          PP.text prefix <> PP.ppr name <> PP.text " aliases "
            <> PP.commasep (map PP.ppr als')

-- | A 'Rephraser' that drops the aliasing component of every
-- decoration and leaves the underlying lore's decorations intact.
-- Decorations that 'Aliases' inherits unchanged are passed through,
-- and operations are stripped with 'removeOpAliases'.
removeAliases :: CanBeAliased (Op lore) => Rephraser Identity (Aliases lore) lore
removeAliases =
  Rephraser
    { rephraseExpLore = pure . snd,
      rephraseLetBoundLore = pure . snd,
      rephraseBodyLore = pure . snd,
      rephraseFParamLore = pure,
      rephraseLParamLore = pure,
      rephraseRetType = pure,
      rephraseBranchType = pure,
      rephraseOp = pure . removeOpAliases
    }

-- | Strip aliasing information from every let-bound name in a scope.
-- Parameters and index names are carried over unchanged.
removeScopeAliases :: Scope (Aliases lore) -> Scope lore
removeScopeAliases = M.map stripAliases
  where
    stripAliases :: NameInfo (Aliases l) -> NameInfo l
    stripAliases info =
      case info of
        LetName (_, dec) -> LetName dec
        FParamName dec -> FParamName dec
        LParamName dec -> LParamName dec
        IndexName it -> IndexName it

-- | Strip all aliasing information from a program.
removeProgAliases ::
  CanBeAliased (Op lore) =>
  Prog (Aliases lore) ->
  Prog lore
removeProgAliases prog = runIdentity $ rephraseProg removeAliases prog

-- | Strip all aliasing information from a function definition.
removeFunDefAliases ::
  CanBeAliased (Op lore) =>
  FunDef (Aliases lore) ->
  FunDef lore
removeFunDefAliases fundef = runIdentity $ rephraseFunDef removeAliases fundef

-- | Strip all aliasing information from an expression.
removeExpAliases ::
  CanBeAliased (Op lore) =>
  Exp (Aliases lore) ->
  Exp lore
removeExpAliases expr = runIdentity $ rephraseExp removeAliases expr

-- | Strip all aliasing information from a statement.
removeStmAliases ::
  CanBeAliased (Op lore) =>
  Stm (Aliases lore) ->
  Stm lore
removeStmAliases stm = runIdentity $ rephraseStm removeAliases stm

-- | Strip all aliasing information from a lambda.
removeLambdaAliases ::
  CanBeAliased (Op lore) =>
  Lambda (Aliases lore) ->
  Lambda lore
removeLambdaAliases lam = runIdentity $ rephraseLambda removeAliases lam

-- | Drop the 'AliasDec' component from the decoration of every element
-- of a pattern.
removePatternAliases ::
  PatternT (AliasDec, a) ->
  PatternT a
removePatternAliases pat = runIdentity $ rephrasePattern (pure . snd) pat

-- | Annotate every element of a pattern with aliasing information
-- derived from the expression it is bound to (see 'mkPatternAliases').
addAliasesToPattern ::
  (ASTLore lore, CanBeAliased (Op lore), Typed dec) =>
  PatternT dec ->
  Exp (Aliases lore) ->
  PatternT (VarAliases, dec)
addAliasesToPattern pat e =
  let (ctx, vals) = mkPatternAliases pat e
   in Pattern ctx vals

-- | Build a body whose decoration pairs the aliasing information
-- computed by 'mkBodyAliases' (result aliases and consumed names) with
-- the given decoration of the underlying lore.
mkAliasedBody ::
  (ASTLore lore, CanBeAliased (Op lore)) =>
  BodyDec lore ->
  Stms (Aliases lore) ->
  Result ->
  Body (Aliases lore)
mkAliasedBody innerlore stms res = Body dec stms res
  where
    dec = (mkBodyAliases stms res, innerlore)

-- | Compute aliasing annotations for the context and value elements of
-- a pattern bound to the given expression.  Only elements of array or
-- memory type keep their aliases; every other element is annotated
-- with the empty alias set.
mkPatternAliases ::
  (Aliased lore, Typed dec) =>
  PatternT dec ->
  Exp lore ->
  ( [PatElemT (VarAliases, dec)],
    [PatElemT (VarAliases, dec)]
  )
mkPatternAliases pat e =
  ( zipWith annotateBindee (patternContextElements pat) context_als,
    zipWith annotateBindee (patternValueElements pat) value_als
  )
  where
    -- 'expAliases' says nothing about the context part of the pattern,
    -- so its aliases are approximated separately.
    context_als = mkContextAliases pat e

    -- Pad with empty alias sets in case the pattern has more value
    -- elements than the expression has results (which would imply a
    -- type error).
    value_als = expAliases e ++ repeat mempty

    annotateBindee bindee names =
      bindee `setPatElemLore` (AliasDec (keepIfAliasable bindee names), patElemDec bindee)

    -- Only arrays and memory blocks can alias anything.
    keepIfAliasable bindee names =
      case patElemType bindee of
        Array {} -> names
        Mem _ -> names
        _ -> mempty

mkContextAliases ::
  Aliased lore =>
  PatternT dec ->
  Exp lore ->
  [Names]
mkContextAliases :: PatternT dec -> Exp lore -> [Names]
mkContextAliases PatternT dec
pat (DoLoop [(FParam lore, SubExp)]
ctxmerge [(FParam lore, SubExp)]
valmerge LoopForm lore
_ BodyT lore
body) =
  let ctx :: [FParam lore]
ctx = ((FParam lore, SubExp) -> FParam lore)
-> [(FParam lore, SubExp)] -> [FParam lore]
forall a b. (a -> b) -> [a] -> [b]
map (FParam lore, SubExp) -> FParam lore
forall a b. (a, b) -> a
fst [(FParam lore, SubExp)]
ctxmerge
      init_als :: [(VName, Names)]
init_als = [VName] -> [Names] -> [(VName, Names)]
forall a b. [a] -> [b] -> [(a, b)]
zip [VName]
mergenames ([Names] -> [(VName, Names)]) -> [Names] -> [(VName, Names)]
forall a b. (a -> b) -> a -> b
$ ((FParam lore, SubExp) -> Names)
-> [(FParam lore, SubExp)] -> [Names]
forall a b. (a -> b) -> [a] -> [b]
map (SubExp -> Names
subExpAliases (SubExp -> Names)
-> ((FParam lore, SubExp) -> SubExp)
-> (FParam lore, SubExp)
-> Names
forall b c a. (b -> c) -> (a -> b) -> a -> c
. (FParam lore, SubExp) -> SubExp
forall a b. (a, b) -> b
snd) ([(FParam lore, SubExp)] -> [Names])
-> [(FParam lore, SubExp)] -> [Names]
forall a b. (a -> b) -> a -> b
$ [(FParam lore, SubExp)]
ctxmerge [(FParam lore, SubExp)]
-> [(FParam lore, SubExp)] -> [(FParam lore, SubExp)]
forall a. [a] -> [a] -> [a]
++ [(FParam lore, SubExp)]
valmerge
      expand :: Names -> Names
expand Names
als = Names
als Names -> Names -> Names
forall a. Semigroup a => a -> a -> a
<> [Names] -> Names
forall a. Monoid a => [a] -> a
mconcat ((VName -> Maybe Names) -> [VName] -> [Names]
forall a b. (a -> Maybe b) -> [a] -> [b]
mapMaybe (VName -> [(VName, Names)] -> Maybe Names
forall a b. Eq a => a -> [(a, b)] -> Maybe b
`lookup` [(VName, Names)]
init_als) (Names -> [VName]
namesToList Names
als))
      merge_als :: [(VName, Names)]
merge_als =
        [VName] -> [Names] -> [(VName, Names)]
forall a b. [a] -> [b] -> [(a, b)]
zip [VName]
mergenames ([Names] -> [(VName, Names)]) -> [Names] -> [(VName, Names)]
forall a b. (a -> b) -> a -> b
$
          (Names -> Names) -> [Names] -> [Names]
forall a b. (a -> b) -> [a] -> [b]
map ((Names -> Names -> Names
`namesSubtract` Names
mergenames_set) (Names -> Names) -> (Names -> Names) -> Names -> Names
forall b c a. (b -> c) -> (a -> b) -> a -> c
. Names -> Names
expand) ([Names] -> [Names]) -> [Names] -> [Names]
forall a b. (a -> b) -> a -> b
$
            BodyT lore -> [Names]
forall lore. Aliased lore => Body lore -> [Names]
bodyAliases BodyT lore
body
   in if [FParam lore] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [FParam lore]
ctx Int -> Int -> Bool
forall a. Eq a => a -> a -> Bool
== [PatElemT dec] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length (PatternT dec -> [PatElemT dec]
forall dec. PatternT dec -> [PatElemT dec]
patternContextElements PatternT dec
pat)
        then (FParam lore -> Names) -> [FParam lore] -> [Names]
forall a b. (a -> b) -> [a] -> [b]
map (Names -> Maybe Names -> Names
forall a. a -> Maybe a -> a
fromMaybe Names
forall a. Monoid a => a
mempty (Maybe Names -> Names)
-> (FParam lore -> Maybe Names) -> FParam lore -> Names
forall b c a. (b -> c) -> (a -> b) -> a -> c
. (VName -> [(VName, Names)] -> Maybe Names)
-> [(VName, Names)] -> VName -> Maybe Names
forall a b c. (a -> b -> c) -> b -> a -> c
flip VName -> [(VName, Names)] -> Maybe Names
forall a b. Eq a => a -> [(a, b)] -> Maybe b
lookup [(VName, Names)]
merge_als (VName -> Maybe Names)
-> (FParam lore -> VName) -> FParam lore -> Maybe Names
forall b c a. (b -> c) -> (a -> b) -> a -> c
. FParam lore -> VName
forall dec. Param dec -> VName
paramName) [FParam lore]
ctx
        else (PatElemT dec -> Names) -> [PatElemT dec] -> [Names]
forall a b. (a -> b) -> [a] -> [b]
map (Names -> PatElemT dec -> Names
forall a b. a -> b -> a
const Names
forall a. Monoid a => a
mempty) ([PatElemT dec] -> [Names]) -> [PatElemT dec] -> [Names]
forall a b. (a -> b) -> a -> b
$ PatternT dec -> [PatElemT dec]
forall dec. PatternT dec -> [PatElemT dec]
patternContextElements PatternT dec
pat
  where
    mergenames :: [VName]
mergenames = ((FParam lore, SubExp) -> VName)
-> [(FParam lore, SubExp)] -> [VName]
forall a b. (a -> b) -> [a] -> [b]
map (FParam lore -> VName
forall dec. Param dec -> VName
paramName (FParam lore -> VName)
-> ((FParam lore, SubExp) -> FParam lore)
-> (FParam lore, SubExp)
-> VName
forall b c a. (b -> c) -> (a -> b) -> a -> c
. (FParam lore, SubExp) -> FParam lore
forall a b. (a, b) -> a
fst) ([(FParam lore, SubExp)] -> [VName])
-> [(FParam lore, SubExp)] -> [VName]
forall a b. (a -> b) -> a -> b
$ [(FParam lore, SubExp)]
ctxmerge [(FParam lore, SubExp)]
-> [(FParam lore, SubExp)] -> [(FParam lore, SubExp)]
forall a. [a] -> [a] -> [a]
++ [(FParam lore, SubExp)]
valmerge
    mergenames_set :: Names
mergenames_set = [VName] -> Names
namesFromList [VName]
mergenames
mkContextAliases PatternT dec
pat (If SubExp
_ BodyT lore
tbranch BodyT lore
fbranch IfDec (BranchType lore)
_) =
  Int -> [Names] -> [Names]
forall a. Int -> [a] -> [a]
take ([VName] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length ([VName] -> Int) -> [VName] -> Int
forall a b. (a -> b) -> a -> b
$ PatternT dec -> [VName]
forall dec. PatternT dec -> [VName]
patternContextNames PatternT dec
pat) ([Names] -> [Names]) -> [Names] -> [Names]
forall a b. (a -> b) -> a -> b
$
    (Names -> Names -> Names) -> [Names] -> [Names] -> [Names]
forall a b c. (a -> b -> c) -> [a] -> [b] -> [c]
zipWith Names -> Names -> Names
forall a. Semigroup a => a -> a -> a
(<>) (BodyT lore -> [Names]
forall lore. Aliased lore => Body lore -> [Names]
bodyAliases BodyT lore
tbranch) (BodyT lore -> [Names]
forall lore. Aliased lore => Body lore -> [Names]
bodyAliases BodyT lore
fbranch)
mkContextAliases PatternT dec
pat Exp lore
_ =
  Int -> Names -> [Names]
forall a. Int -> a -> [a]
replicate ([PatElemT dec] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length ([PatElemT dec] -> Int) -> [PatElemT dec] -> Int
forall a b. (a -> b) -> a -> b
$ PatternT dec -> [PatElemT dec]
forall dec. PatternT dec -> [PatElemT dec]
patternContextElements PatternT dec
pat) Names
forall a. Monoid a => a
mempty

-- | Compute the aliasing of a body from its statements and result.
--
-- The first component holds one 'AliasDec' per result element.  The
-- second holds everything consumed by the statements.  Names bound by
-- the statements' own patterns are removed from both, because they are
-- not visible outside the body.
mkBodyAliases ::
  Aliased lore =>
  Stms lore ->
  Result ->
  BodyAliasing
mkBodyAliases stms res =
  ( map (AliasDec . hideLocal) resultAliases,
    AliasDec $ hideLocal consumedNames
  )
  where
    -- 'mkStmsAliases' has already expanded the result aliases through
    -- the alias map built from these statements.  Dropping the locally
    -- bound names afterwards therefore leaves only names from outside.
    (resultAliases, consumedNames) = mkStmsAliases stms res
    locallyBound = foldMap (namesFromList . patternNames . stmPattern) stms
    hideLocal = (`namesSubtract` locallyBound)

-- | Run 'trackAliases' over the statements in order, starting from an
-- empty alias map and an empty consumption set.
--
-- Returns the aliases of each result element, expanded through the
-- accumulated alias map, together with the accumulated consumption set.
mkStmsAliases ::
  Aliased lore =>
  Stms lore ->
  [SubExp] ->
  ([Names], Names)
mkStmsAliases stms res = finish $ go mempty $ stmsToList stms
  where
    go acc [] = acc
    go acc (stm : rest) = go (trackAliases acc stm) rest

    finish (aliasmap, consumed) =
      (map (closeOver aliasmap . subExpAliases) res, consumed)

    -- Entries recorded by 'trackAliases' already include the aliases of
    -- their aliases, so only a single lookup per name is done here.
    closeOver :: M.Map VName Names -> Names -> Names
    closeOver aliasmap names =
      names
        <> foldMap
          (\v -> M.findWithDefault mempty v aliasmap)
          (namesToList names)

-- | Everything consumed by the given statements.  This includes names
-- consumed transitively, through the aliases of consumed variables.
consumedInStms :: Aliased lore => Stms lore -> Names
consumedInStms stms = snd $ mkStmsAliases stms []

-- | A map from each bound variable to the names it aliases, paired
-- with the set of names consumed so far.
type AliasesAndConsumed = (M.Map VName Names, Names)

-- | Extend an alias map and consumption set with the effects of one
-- statement.
--
-- Each name bound by the statement's pattern is mapped to its aliases.
-- Those aliases are augmented with whatever the incoming map already
-- records for them.  The statement's consumption is expanded the same
-- way and added to the consumption set.  Because 'Data.Map' union is
-- left-biased, new entries take precedence over existing entries for
-- the same name.
trackAliases ::
  Aliased lore =>
  AliasesAndConsumed ->
  Stm lore ->
  AliasesAndConsumed
trackAliases (aliasmap, consumed) stm =
  (newEntries <> aliasmap, consumed <> withIndirect (consumedInStm stm))
  where
    pat = stmPattern stm

    newEntries =
      M.fromList
        [ (v, withIndirect als)
          | (v, als) <- zip (patternNames pat) (patternAliases pat)
        ]

    -- The given names, plus everything the incoming map says they alias.
    withIndirect names =
      names
        <> mconcat [M.findWithDefault mempty v aliasmap | v <- namesToList names]

-- | Build an aliased 'Let' statement from an unaliased pattern and
-- statement auxiliary.
--
-- The pattern is annotated with 'addAliasesToPattern'.  The auxiliary
-- decoration is paired with the names consumed by the expression.
-- Certificates and attributes are carried over unchanged.
mkAliasedLetStm ::
  (ASTLore lore, CanBeAliased (Op lore)) =>
  Pattern lore ->
  StmAux (ExpDec lore) ->
  Exp (Aliases lore) ->
  Stm (Aliases lore)
mkAliasedLetStm pat (StmAux cs attrs dec) e = Let aliasedPat aliasedAux e
  where
    aliasedPat = addAliasesToPattern pat e
    aliasedAux = StmAux cs attrs (AliasDec (consumedInExp e), dec)

-- | Each method strips the alias annotations from its inputs and
-- delegates to the underlying lore's 'Bindable' method.  It then
-- recomputes the alias information for the aliased result.
instance (Bindable lore, CanBeAliased (Op lore)) => Bindable (Aliases lore) where
  mkExpDec pat e =
    ( AliasDec (consumedInExp e),
      mkExpDec (removePatternAliases pat) (removeExpAliases e)
    )

  mkExpPat ctx val e =
    let plainPat = mkExpPat ctx val (removeExpAliases e)
     in addAliasesToPattern plainPat e

  mkLetNames names e = do
    -- The underlying 'mkLetNames' needs an unaliased scope, so run it in
    -- a reader over the scope with aliases removed.
    plainScope <- asksScope removeScopeAliases
    (`runReaderT` plainScope) $ do
      Let pat aux _ <- mkLetNames names (removeExpAliases e)
      pure $ mkAliasedLetStm pat aux e

  mkBody stms res =
    let Body plainDec _ _ = mkBody (fmap removeStmAliases stms) res
     in mkAliasedBody plainDec stms res

-- | Uses the default 'BinderOps' methods; none are overridden here.
instance
  (ASTLore (Aliases lore), Bindable (Aliases lore)) =>
  BinderOps (Aliases lore)