{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE TupleSections #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE UndecidableInstances #-}

-- | A representation where all bindings are annotated with aliasing
-- information.
module Futhark.IR.Aliases
  ( -- * The Lore definition
    Aliases,
    AliasDec (..),
    VarAliases,
    ConsumedInExp,
    BodyAliasing,
    module Futhark.IR.Prop.Aliases,

    -- * Module re-exports
    module Futhark.IR.Prop,
    module Futhark.IR.Traversals,
    module Futhark.IR.Pretty,
    module Futhark.IR.Syntax,

    -- * Adding aliases
    addAliasesToPattern,
    mkAliasedLetStm,
    mkAliasedBody,
    mkPatternAliases,
    mkBodyAliases,

    -- * Removing aliases
    removeProgAliases,
    removeFunDefAliases,
    removeExpAliases,
    removeStmAliases,
    removeLambdaAliases,
    removePatternAliases,
    removeScopeAliases,

    -- * Tracking aliases
    AliasesAndConsumed,
    trackAliases,
    mkStmsAliases,
  )
where

import Control.Monad.Identity
import Control.Monad.Reader
import qualified Data.Map.Strict as M
import Data.Maybe
import Futhark.Analysis.Rephrase
import Futhark.Binder
import Futhark.IR.Pretty
import Futhark.IR.Prop
import Futhark.IR.Prop.Aliases
import Futhark.IR.Syntax
import Futhark.IR.Traversals
import Futhark.Transform.Rename
import Futhark.Transform.Substitute
import qualified Futhark.Util.Pretty as PP

-- | A representation that extends the underlying representation @lore@
-- with aliasing information: let-bound names, expressions and bodies
-- additionally carry alias annotations (see the 'Decorations' instance
-- for 'Aliases').
data Aliases lore

-- | An alias annotation: a wrapper around a set of 'Names'.
--
-- The wrapper exists because of its 'Eq' and 'Ord' instances, which
-- deliberately consider all values equal.  As a result, alias
-- information never influences comparisons of values that contain an
-- 'AliasDec'.
newtype AliasDec = AliasDec {unAliases :: Names}
  deriving (Show)

-- | Two annotations are combined by combining their alias sets with the
-- 'Semigroup' instance of 'Names'.
instance Semigroup AliasDec where
  AliasDec xs <> AliasDec ys = AliasDec (xs <> ys)

-- | The empty annotation wraps the empty set of 'Names'.
instance Monoid AliasDec where
  mempty = AliasDec mempty

-- | Every pair of alias annotations is considered equal, so aliasing
-- information is ignored when comparing anything that contains it.
instance Eq AliasDec where
  (==) _ _ = True

-- | Consistent with the 'Eq' instance: every pair compares as 'EQ'.
instance Ord AliasDec where
  compare _ _ = EQ

-- | Renaming is applied to the names inside the alias set.
instance Rename AliasDec where
  rename = fmap AliasDec . rename . unAliases

-- | Substitution is applied to the names inside the alias set.
instance Substitute AliasDec where
  substituteNames substs = AliasDec . substituteNames substs . unAliases

-- | An alias annotation contributes no free variables, whatever names
-- it mentions.
instance FreeIn AliasDec where
  freeIn' _ = mempty

-- | Prints the alias set as a comma-separated list enclosed in braces.
instance PP.Pretty AliasDec where
  ppr (AliasDec names) =
    PP.braces $ PP.commasep $ map PP.ppr $ namesToList names

-- | The aliases of the let-bound variable.  Used as the first
-- component of the let decoration in the 'Aliases' representation.
type VarAliases = AliasDec

-- | Everything consumed in the expression.  Used as the first
-- component of the expression decoration in the 'Aliases'
-- representation.
type ConsumedInExp = AliasDec

-- | The aliases of what is returned by the t'Body' (one entry per
-- result), and what is consumed inside of it.  Used as the first
-- component of the body decoration in the 'Aliases' representation.
type BodyAliasing = ([VarAliases], ConsumedInExp)

-- | Let-bound names, expressions and bodies carry their aliasing
-- information paired with the decoration of the underlying lore.
-- Parameter, return and branch types are those of the underlying lore,
-- unchanged, and operations are wrapped in 'OpWithAliases'.
instance
  (Decorations lore, CanBeAliased (Op lore)) =>
  Decorations (Aliases lore)
  where
  type LetDec (Aliases lore) = (VarAliases, LetDec lore)
  type ExpDec (Aliases lore) = (ConsumedInExp, ExpDec lore)
  type BodyDec (Aliases lore) = (BodyAliasing, BodyDec lore)
  type FParamInfo (Aliases lore) = FParamInfo lore
  type LParamInfo (Aliases lore) = LParamInfo lore
  type RetType (Aliases lore) = RetType lore
  type BranchType (Aliases lore) = BranchType lore
  type Op (Aliases lore) = OpWithAliases (Op lore)

-- | The aliases of a let-bound name are stored in the first component
-- of its decoration.
instance AliasesOf (VarAliases, dec) where
  aliasesOf letdec = unAliases (fst letdec)

instance FreeDec AliasDec

-- | Run a computation that needs a 'Scope' of the underlying lore in a
-- monad whose scope is annotated with aliases.  The alias annotations
-- are stripped from the current scope first.
withoutAliases ::
  (HasScope (Aliases lore) m, Monad m) =>
  ReaderT (Scope lore) m a ->
  m a
withoutAliases m = runReaderT m =<< asksScope removeScopeAliases

-- | Type information is computed by the underlying lore.  Aliases are
-- removed from both the pattern and the scope before delegating.
instance (ASTLore lore, CanBeAliased (Op lore)) => ASTLore (Aliases lore) where
  expTypesFromPattern pat =
    withoutAliases $ expTypesFromPattern $ removePatternAliases pat

-- | Body alias information is read directly from the 'BodyAliasing'
-- stored in the body decoration.  Its first component gives one alias
-- set per result; its second gives what the body consumes.
instance (ASTLore lore, CanBeAliased (Op lore)) => Aliased (Aliases lore) where
  bodyAliases body =
    let ((result_als, _), _) = bodyDec body
     in map unAliases result_als
  consumedInBody body =
    let ((_, consumed), _) = bodyDec body
     in unAliases consumed

-- | Pretty-printing annotations for the aliased representation.  These
-- are emitted alongside whatever the underlying lore prints for an
-- expression:
--
-- * a @-- Consumes@ comment listing everything the expression
--   consumes, if anything;
--
-- * for loops, a @-- Result of@ comment for each non-primitive loop
--   parameter, listing what the corresponding body result aliases.
instance (ASTLore lore, CanBeAliased (Op lore)) => PrettyLore (Aliases lore) where
  ppExpLore (consumed, inner) e =
    maybeComment $
      catMaybes
        [ exp_dec,
          merge_dec,
          ppExpLore inner $ removeExpAliases e
        ]
    where
      merge_dec :: Maybe PP.Doc
      merge_dec =
        case e of
          -- The loop body returns the context results followed by the
          -- value results (the same layout 'mkContextAliases' assumes).
          -- The body aliases are therefore paired with the context
          -- parameters followed by the value parameters.  Pairing them
          -- with the value parameters alone would misattribute every
          -- annotation whenever the loop has a context.
          DoLoop ctxmerge valmerge _ body ->
            let mergeParamAliases fparam als
                  | primType (paramType fparam) = Nothing
                  | otherwise = resultAliasComment (paramName fparam) als
             in maybeComment $
                  catMaybes $
                    zipWith mergeParamAliases (map fst $ ctxmerge ++ valmerge) $
                      bodyAliases body
          _ -> Nothing

      exp_dec :: Maybe PP.Doc
      exp_dec = case namesToList $ unAliases consumed of
        [] -> Nothing
        als ->
          Just $
            PP.oneLine $
              PP.text "-- Consumes " <> PP.commasep (map PP.ppr als)

-- | Join a list of comment documents with 'PP.</>'.  An empty list
-- yields 'Nothing'.
maybeComment :: [PP.Doc] -> Maybe PP.Doc
maybeComment cs
  | null cs = Nothing
  | otherwise = Just (PP.folddoc (PP.</>) cs)

-- | A one-line comment stating that the result associated with @name@
-- aliases the given names.  Returns 'Nothing' if the alias set is empty.
resultAliasComment :: PP.Pretty a => a -> Names -> Maybe PP.Doc
resultAliasComment name als
  | null aliased = Nothing
  | otherwise =
    Just . PP.oneLine $
      PP.text "-- Result of " <> PP.ppr name <> PP.text " aliases "
        <> PP.commasep (map PP.ppr aliased)
  where
    aliased = namesToList als

-- | A 'Rephraser' that strips every alias annotation:
--
-- * the alias component is dropped from expression, let and body
--   decorations;
--
-- * operations are de-aliased with 'removeOpAliases';
--
-- * all other decorations pass through unchanged.
removeAliases :: CanBeAliased (Op lore) => Rephraser Identity (Aliases lore) lore
removeAliases =
  Rephraser
    { rephraseExpLore = Identity . snd,
      rephraseLetBoundLore = Identity . snd,
      rephraseBodyLore = Identity . snd,
      rephraseFParamLore = Identity,
      rephraseLParamLore = Identity,
      rephraseRetType = Identity,
      rephraseBranchType = Identity,
      rephraseOp = Identity . removeOpAliases
    }

-- | Strip the alias annotations from the let-bound entries of a scope.
-- Parameter and index entries are carried over unchanged.
removeScopeAliases :: Scope (Aliases lore) -> Scope lore
removeScopeAliases = M.map $ \info ->
  case info of
    LetName (_, dec) -> LetName dec
    FParamName dec -> FParamName dec
    LParamName dec -> LParamName dec
    IndexName it -> IndexName it

-- | Strip all alias annotations from a program.
removeProgAliases ::
  CanBeAliased (Op lore) =>
  Prog (Aliases lore) ->
  Prog lore
removeProgAliases prog = runIdentity $ rephraseProg removeAliases prog

-- | Strip all alias annotations from a function definition.
removeFunDefAliases ::
  CanBeAliased (Op lore) =>
  FunDef (Aliases lore) ->
  FunDef lore
removeFunDefAliases fundef = runIdentity $ rephraseFunDef removeAliases fundef

-- | Strip all alias annotations from an expression.
removeExpAliases ::
  CanBeAliased (Op lore) =>
  Exp (Aliases lore) ->
  Exp lore
removeExpAliases expr = runIdentity $ rephraseExp removeAliases expr

-- | Strip all alias annotations from a statement.
removeStmAliases ::
  CanBeAliased (Op lore) =>
  Stm (Aliases lore) ->
  Stm lore
removeStmAliases stm = runIdentity $ rephraseStm removeAliases stm

-- | Strip all alias annotations from a lambda.
removeLambdaAliases ::
  CanBeAliased (Op lore) =>
  Lambda (Aliases lore) ->
  Lambda lore
removeLambdaAliases lam = runIdentity $ rephraseLambda removeAliases lam

-- | Drop the alias component from the decoration of every element of
-- a pattern.
removePatternAliases ::
  PatternT (AliasDec, a) ->
  PatternT a
removePatternAliases pat = runIdentity $ rephrasePattern (Identity . snd) pat

-- | Annotate every element of a pattern with the aliases it receives
-- from the expression bound to it (as computed by 'mkPatternAliases').
addAliasesToPattern ::
  (ASTLore lore, CanBeAliased (Op lore), Typed dec) =>
  PatternT dec ->
  Exp (Aliases lore) ->
  PatternT (VarAliases, dec)
addAliasesToPattern pat e =
  let (ctx_elems, val_elems) = mkPatternAliases pat e
   in Pattern ctx_elems val_elems

-- | Construct a body in the aliased representation.  Its aliasing
-- decoration is computed from the statements and result by
-- 'mkBodyAliases' and paired with the given decoration of the
-- underlying lore.
mkAliasedBody ::
  (ASTLore lore, CanBeAliased (Op lore)) =>
  BodyDec lore ->
  Stms (Aliases lore) ->
  Result ->
  Body (Aliases lore)
mkAliasedBody innerlore stms res =
  Body (aliasing, innerlore) stms res
  where
    aliasing = mkBodyAliases stms res

-- | Compute alias annotations for the context and value elements of a
-- pattern bound to the given expression.  Only elements of array or
-- memory type keep their aliases; every other element gets an empty
-- alias set.
mkPatternAliases ::
  (Aliased lore, Typed dec) =>
  PatternT dec ->
  Exp lore ->
  ( [PatElemT (VarAliases, dec)],
    [PatElemT (VarAliases, dec)]
  )
mkPatternAliases pat e =
  ( zipWith annotateBindee (patternContextElements pat) context_als,
    zipWith annotateBindee (patternValueElements pat) value_als
  )
  where
    -- 'expAliases' provides nothing for the context part of the
    -- pattern, so its aliases are approximated by 'mkContextAliases'.
    context_als = mkContextAliases pat e
    -- Padded with empty sets in case the pattern has more value
    -- elements than 'expAliases' returns (which would imply a type
    -- error).
    value_als = expAliases e ++ repeat mempty

    annotateBindee bindee names =
      bindee `setPatElemLore` (AliasDec (aliasable (patElemType bindee)), patElemDec bindee)
      where
        aliasable Array {} = names
        aliasable Mem {} = names
        aliasable _ = mempty

mkContextAliases ::
  Aliased lore =>
  PatternT dec ->
  Exp lore ->
  [Names]
mkContextAliases :: forall lore dec.
Aliased lore =>
PatternT dec -> Exp lore -> [Names]
mkContextAliases PatternT dec
pat (DoLoop [(FParam lore, SubExp)]
ctxmerge [(FParam lore, SubExp)]
valmerge LoopForm lore
_ BodyT lore
body) =
  let ctx :: [FParam lore]
ctx = ((FParam lore, SubExp) -> FParam lore)
-> [(FParam lore, SubExp)] -> [FParam lore]
forall a b. (a -> b) -> [a] -> [b]
map (FParam lore, SubExp) -> FParam lore
forall a b. (a, b) -> a
fst [(FParam lore, SubExp)]
ctxmerge
      init_als :: [(VName, Names)]
init_als = [VName] -> [Names] -> [(VName, Names)]
forall a b. [a] -> [b] -> [(a, b)]
zip [VName]
mergenames ([Names] -> [(VName, Names)]) -> [Names] -> [(VName, Names)]
forall a b. (a -> b) -> a -> b
$ ((FParam lore, SubExp) -> Names)
-> [(FParam lore, SubExp)] -> [Names]
forall a b. (a -> b) -> [a] -> [b]
map (SubExp -> Names
subExpAliases (SubExp -> Names)
-> ((FParam lore, SubExp) -> SubExp)
-> (FParam lore, SubExp)
-> Names
forall b c a. (b -> c) -> (a -> b) -> a -> c
. (FParam lore, SubExp) -> SubExp
forall a b. (a, b) -> b
snd) ([(FParam lore, SubExp)] -> [Names])
-> [(FParam lore, SubExp)] -> [Names]
forall a b. (a -> b) -> a -> b
$ [(FParam lore, SubExp)]
ctxmerge [(FParam lore, SubExp)]
-> [(FParam lore, SubExp)] -> [(FParam lore, SubExp)]
forall a. [a] -> [a] -> [a]
++ [(FParam lore, SubExp)]
valmerge
      expand :: Names -> Names
expand Names
als = Names
als Names -> Names -> Names
forall a. Semigroup a => a -> a -> a
<> [Names] -> Names
forall a. Monoid a => [a] -> a
mconcat ((VName -> Maybe Names) -> [VName] -> [Names]
forall a b. (a -> Maybe b) -> [a] -> [b]
mapMaybe (VName -> [(VName, Names)] -> Maybe Names
forall a b. Eq a => a -> [(a, b)] -> Maybe b
`lookup` [(VName, Names)]
init_als) (Names -> [VName]
namesToList Names
als))
      merge_als :: [(VName, Names)]
merge_als =
        [VName] -> [Names] -> [(VName, Names)]
forall a b. [a] -> [b] -> [(a, b)]
zip [VName]
mergenames ([Names] -> [(VName, Names)]) -> [Names] -> [(VName, Names)]
forall a b. (a -> b) -> a -> b
$
          (Names -> Names) -> [Names] -> [Names]
forall a b. (a -> b) -> [a] -> [b]
map ((Names -> Names -> Names
`namesSubtract` Names
mergenames_set) (Names -> Names) -> (Names -> Names) -> Names -> Names
forall b c a. (b -> c) -> (a -> b) -> a -> c
. Names -> Names
expand) ([Names] -> [Names]) -> [Names] -> [Names]
forall a b. (a -> b) -> a -> b
$
            BodyT lore -> [Names]
forall lore. Aliased lore => Body lore -> [Names]
bodyAliases BodyT lore
body
   in if [FParam lore] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [FParam lore]
ctx Int -> Int -> Bool
forall a. Eq a => a -> a -> Bool
== [PatElemT dec] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length (PatternT dec -> [PatElemT dec]
forall dec. PatternT dec -> [PatElemT dec]
patternContextElements PatternT dec
pat)
        then (FParam lore -> Names) -> [FParam lore] -> [Names]
forall a b. (a -> b) -> [a] -> [b]
map (Names -> Maybe Names -> Names
forall a. a -> Maybe a -> a
fromMaybe Names
forall a. Monoid a => a
mempty (Maybe Names -> Names)
-> (FParam lore -> Maybe Names) -> FParam lore -> Names
forall b c a. (b -> c) -> (a -> b) -> a -> c
. (VName -> [(VName, Names)] -> Maybe Names)
-> [(VName, Names)] -> VName -> Maybe Names
forall a b c. (a -> b -> c) -> b -> a -> c
flip VName -> [(VName, Names)] -> Maybe Names
forall a b. Eq a => a -> [(a, b)] -> Maybe b
lookup [(VName, Names)]
merge_als (VName -> Maybe Names)
-> (FParam lore -> VName) -> FParam lore -> Maybe Names
forall b c a. (b -> c) -> (a -> b) -> a -> c
. FParam lore -> VName
forall dec. Param dec -> VName
paramName) [FParam lore]
ctx
        else (PatElemT dec -> Names) -> [PatElemT dec] -> [Names]
forall a b. (a -> b) -> [a] -> [b]
map (Names -> PatElemT dec -> Names
forall a b. a -> b -> a
const Names
forall a. Monoid a => a
mempty) ([PatElemT dec] -> [Names]) -> [PatElemT dec] -> [Names]
forall a b. (a -> b) -> a -> b
$ PatternT dec -> [PatElemT dec]
forall dec. PatternT dec -> [PatElemT dec]
patternContextElements PatternT dec
pat
  where
    mergenames :: [VName]
mergenames = ((FParam lore, SubExp) -> VName)
-> [(FParam lore, SubExp)] -> [VName]
forall a b. (a -> b) -> [a] -> [b]
map (FParam lore -> VName
forall dec. Param dec -> VName
paramName (FParam lore -> VName)
-> ((FParam lore, SubExp) -> FParam lore)
-> (FParam lore, SubExp)
-> VName
forall b c a. (b -> c) -> (a -> b) -> a -> c
. (FParam lore, SubExp) -> FParam lore
forall a b. (a, b) -> a
fst) ([(FParam lore, SubExp)] -> [VName])
-> [(FParam lore, SubExp)] -> [VName]
forall a b. (a -> b) -> a -> b
$ [(FParam lore, SubExp)]
ctxmerge [(FParam lore, SubExp)]
-> [(FParam lore, SubExp)] -> [(FParam lore, SubExp)]
forall a. [a] -> [a] -> [a]
++ [(FParam lore, SubExp)]
valmerge
    mergenames_set :: Names
mergenames_set = [VName] -> Names
namesFromList [VName]
mergenames
mkContextAliases PatternT dec
pat (If SubExp
_ BodyT lore
tbranch BodyT lore
fbranch IfDec (BranchType lore)
_) =
  Int -> [Names] -> [Names]
forall a. Int -> [a] -> [a]
take ([VName] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length ([VName] -> Int) -> [VName] -> Int
forall a b. (a -> b) -> a -> b
$ PatternT dec -> [VName]
forall dec. PatternT dec -> [VName]
patternContextNames PatternT dec
pat) ([Names] -> [Names]) -> [Names] -> [Names]
forall a b. (a -> b) -> a -> b
$
    (Names -> Names -> Names) -> [Names] -> [Names] -> [Names]
forall a b c. (a -> b -> c) -> [a] -> [b] -> [c]
zipWith Names -> Names -> Names
forall a. Semigroup a => a -> a -> a
(<>) (BodyT lore -> [Names]
forall lore. Aliased lore => Body lore -> [Names]
bodyAliases BodyT lore
tbranch) (BodyT lore -> [Names]
forall lore. Aliased lore => Body lore -> [Names]
bodyAliases BodyT lore
fbranch)
mkContextAliases PatternT dec
pat ExpT lore
_ =
  Int -> Names -> [Names]
forall a. Int -> a -> [a]
replicate ([PatElemT dec] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length ([PatElemT dec] -> Int) -> [PatElemT dec] -> Int
forall a b. (a -> b) -> a -> b
$ PatternT dec -> [PatElemT dec]
forall dec. PatternT dec -> [PatElemT dec]
patternContextElements PatternT dec
pat) Names
forall a. Monoid a => a
mempty

-- | Compute the aliasing information of a body made of the given
-- statements and result.  The first component holds, for each
-- result, the names it may alias; the second holds the names
-- consumed by the statements.
--
-- Alias and consumption sets are first computed with
-- 'mkStmsAliases', which closes them over the alias map built from
-- the statements.  Any name bound by the statements' patterns is then
-- removed from both, because such names are local to the body.
mkBodyAliases ::
  Aliased lore =>
  Stms lore ->
  Result ->
  BodyAliasing
mkBodyAliases stms res =
  (map (AliasDec . hideLocal) resAliases, AliasDec (hideLocal consumedNames))
  where
    (resAliases, consumedNames) = mkStmsAliases stms res
    -- Every name bound by some statement in the body.
    localNames = foldMap (namesFromList . patternNames . stmPattern) stms
    hideLocal = (`namesSubtract` localNames)

-- | The aliases of the result and everything consumed in the given
-- statements.
--
-- The statements are fed in order through 'trackAliases', starting
-- from an empty alias map and an empty consumption set.  Each result's
-- aliases are its own 'subExpAliases' extended with whatever those
-- names map to in the final alias map.
mkStmsAliases ::
  Aliased lore =>
  Stms lore ->
  [SubExp] ->
  ([Names], Names)
mkStmsAliases stms res = summarise $ go mempty $ stmsToList stms
  where
    -- Thread the alias state through the statements, first to last.
    go acc [] = acc
    go acc (s : rest) = go (trackAliases acc s) rest

    summarise (aliasmap, consumed) =
      (map (withAliasesIn aliasmap . subExpAliases) res, consumed)

    -- A set of names together with everything they map to in the
    -- given alias map (one lookup step).
    withAliasesIn aliasmap names =
      names
        <> foldMap
          (\k -> M.findWithDefault mempty k aliasmap)
          (namesToList names)

-- | The state threaded through 'trackAliases'.  The first component
-- maps a name to the names it is related to by aliasing;
-- 'trackAliases' records entries in both directions.  The second
-- component is the set of names consumed so far.
type AliasesAndConsumed = (M.Map VName Names, Names)

-- | Extend an alias map and consumption set with the effects of a
-- single statement.
--
-- Every name bound by the statement's pattern is recorded as aliasing
-- its pattern aliases, plus whatever those names already alias in the
-- incoming map.  The map is also updated in the reverse direction:
-- each aliased name records the pattern names that now alias it.
-- Everything consumed by the statement is added to the consumption
-- set, together with its aliases as found in the incoming map.
trackAliases ::
  Aliased lore =>
  AliasesAndConsumed ->
  Stm lore ->
  AliasesAndConsumed
trackAliases (aliasmap, consumed) stm =
  let pat = stmPattern stm
      pe_als =
        zip (patternNames pat) $ map addAliasesOfAliases $ patternAliases pat
      als = M.fromList pe_als
      -- Reverse edges: for each (v, v_als), every w in v_als gets v.
      -- Several pattern elements may alias the same name, so the
      -- per-element maps must be merged with 'M.unionsWith'.  The
      -- 'Monoid' instance of 'M.Map' (as used by 'foldMap') is a
      -- left-biased union and would silently keep only one of them,
      -- under-approximating what a later consumption invalidates.
      rev_als = M.unionsWith (<>) $ map revAls pe_als
      revAls (v, v_als) =
        M.fromList $ map (,oneName v) $ namesToList v_als
      comb = M.unionWith (<>)
      aliasmap' = rev_als `comb` als `comb` aliasmap
      consumed' = consumed <> addAliasesOfAliases (consumedInStm stm)
   in (aliasmap', consumed')
  where
    -- A set of names extended with their aliases in the incoming map.
    addAliasesOfAliases names = names <> aliasesOfAliases names
    aliasesOfAliases = mconcat . map look . namesToList
    look k = M.findWithDefault mempty k aliasmap

-- | Build a 'Let' statement in the 'Aliases' representation from a
-- pattern and statement auxiliary of the underlying representation.
--
-- The pattern is annotated with 'addAliasesToPattern'.  The
-- expression decoration is paired with an 'AliasDec' holding the
-- names consumed by the expression.  Certificates and attributes are
-- carried over unchanged.
mkAliasedLetStm ::
  (ASTLore lore, CanBeAliased (Op lore)) =>
  Pattern lore ->
  StmAux (ExpDec lore) ->
  Exp (Aliases lore) ->
  Stm (Aliases lore)
mkAliasedLetStm pat (StmAux cs attrs dec) e = Let aliasedPat aliasedAux e
  where
    aliasedPat = addAliasesToPattern pat e
    aliasedAux = StmAux cs attrs (AliasDec (consumedInExp e), dec)

-- Building in the 'Aliases' representation.  Each method strips the
-- alias annotations and delegates to the underlying representation's
-- 'Bindable' instance.  It then re-attaches aliasing information
-- computed from the aliased expression or statements.
instance (Bindable lore, CanBeAliased (Op lore)) => Bindable (Aliases lore) where
  mkExpDec pat e =
    ( AliasDec (consumedInExp e),
      mkExpDec (removePatternAliases pat) (removeExpAliases e)
    )

  mkExpPat ctx val e = addAliasesToPattern plainPat e
    where
      plainPat = mkExpPat ctx val (removeExpAliases e)

  mkLetNames names e = do
    scope <- asksScope removeScopeAliases
    -- Build the plain statement in the alias-free scope, then
    -- rebuild it with alias information for the original expression.
    Let pat dec _ <- runReaderT (mkLetNames names (removeExpAliases e)) scope
    pure $ mkAliasedLetStm pat dec e

  mkBody stms res = mkAliasedBody plainDec stms res
    where
      -- Only the body decoration of the plain body is reused.
      Body plainDec _ _ = mkBody (fmap removeStmAliases stms) res

-- 'BinderOps' for the 'Aliases' representation.  No methods are
-- given here, so every method uses the class's default definition.
instance
  (ASTLore (Aliases lore), Bindable (Aliases lore)) =>
  BinderOps (Aliases lore)