{-# LANGUAGE UndecidableInstances #-} {-# LANGUAGE FlexibleContexts #-} {-# LANGUAGE TypeFamilies #-} {-# LANGUAGE FlexibleInstances #-} -- | A representation where all bindings are annotated with aliasing -- information. module Futhark.Representation.Aliases ( -- * The Lore definition Aliases , Names' (..) , VarAliases , ConsumedInExp , BodyAliasing , module Futhark.Representation.AST.Attributes.Aliases -- * Module re-exports , module Futhark.Representation.AST.Attributes , module Futhark.Representation.AST.Traversals , module Futhark.Representation.AST.Pretty , module Futhark.Representation.AST.Syntax -- * Adding aliases , addAliasesToPattern , mkAliasedLetStm , mkAliasedBody , mkPatternAliases , mkBodyAliases -- * Removing aliases , removeProgAliases , removeFunDefAliases , removeExpAliases , removeBodyAliases , removeStmAliases , removeLambdaAliases , removePatternAliases , removeScopeAliases -- * Tracking aliases , AliasesAndConsumed , trackAliases , consumedInStms ) where import Control.Monad.Identity import Control.Monad.Reader import Data.Foldable import Data.Maybe import Data.Monoid ((<>)) import qualified Data.Map.Strict as M import qualified Data.Set as S import Futhark.Representation.AST.Syntax import Futhark.Representation.AST.Attributes import Futhark.Representation.AST.Attributes.Aliases import Futhark.Representation.AST.Traversals import Futhark.Representation.AST.Pretty import Futhark.Transform.Rename import Futhark.Binder import Futhark.Transform.Substitute import Futhark.Analysis.Rephrase import qualified Futhark.Util.Pretty as PP -- | The lore for the basic representation. data Aliases lore -- | A wrapper around 'Names' to get around the fact that we need an -- 'Ord' instance, which 'Names' does not have. newtype Names' = Names' { unNames :: Names } deriving (Show) instance Semigroup Names' where x <> y = Names' $ unNames x <> unNames y instance Monoid Names' where mempty = Names' mempty instance Eq Names' where _ == _ = True instance Ord Names' where _ `compare` _ = EQ instance Rename Names' where rename (Names' names) = Names' <$> rename names instance Substitute Names' where substituteNames substs (Names' names) = Names' $ substituteNames substs names instance FreeIn Names' where freeIn = const mempty instance PP.Pretty Names' where ppr = PP.commasep . map PP.ppr . S.toList . unNames -- | The aliases of the let-bound variable. type VarAliases = Names' -- | Everything consumed in the expression. type ConsumedInExp = Names' -- | The aliases of what is returned by the 'Body', and what is -- consumed inside of it. type BodyAliasing = ([VarAliases], ConsumedInExp) instance (Annotations lore, CanBeAliased (Op lore)) => Annotations (Aliases lore) where type LetAttr (Aliases lore) = (VarAliases, LetAttr lore) type ExpAttr (Aliases lore) = (ConsumedInExp, ExpAttr lore) type BodyAttr (Aliases lore) = (BodyAliasing, BodyAttr lore) type FParamAttr (Aliases lore) = FParamAttr lore type LParamAttr (Aliases lore) = LParamAttr lore type RetType (Aliases lore) = RetType lore type BranchType (Aliases lore) = BranchType lore type Op (Aliases lore) = OpWithAliases (Op lore) instance AliasesOf (VarAliases, attr) where aliasesOf = unNames . fst instance FreeAttr Names' where withoutAliases :: (HasScope (Aliases lore) m, Monad m) => ReaderT (Scope lore) m a -> m a withoutAliases m = do scope <- asksScope removeScopeAliases runReaderT m scope instance (Attributes lore, CanBeAliased (Op lore)) => Attributes (Aliases lore) where expTypesFromPattern = withoutAliases . expTypesFromPattern . removePatternAliases instance (Attributes lore, CanBeAliased (Op lore)) => Aliased (Aliases lore) where bodyAliases = map unNames . fst . fst . bodyAttr consumedInBody = unNames . snd . fst . bodyAttr instance PrettyAnnot (PatElemT attr) => PrettyAnnot (PatElemT (VarAliases, attr)) where ppAnnot (PatElem name (Names' als, attr)) = let alias_comment = PP.oneLine <$> aliasComment name als in case (alias_comment, ppAnnot (PatElem name attr)) of (_, Nothing) -> alias_comment (Just alias_comment', Just inner_comment) -> Just $ alias_comment' PP. inner_comment (Nothing, Just inner_comment) -> Just inner_comment instance (Attributes lore, CanBeAliased (Op lore)) => PrettyLore (Aliases lore) where ppExpLore (consumed, inner) e = maybeComment $ catMaybes [expAttr, mergeAttr, ppExpLore inner $ removeExpAliases e] where mergeAttr = case e of DoLoop _ merge _ body -> let mergeParamAliases fparam als | primType (paramType fparam) = Nothing | otherwise = resultAliasComment (paramName fparam) als in maybeComment $ catMaybes $ zipWith mergeParamAliases (map fst merge) $ bodyAliases body _ -> Nothing expAttr = case S.toList $ unNames consumed of [] -> Nothing als -> Just $ PP.oneLine $ PP.text "-- Consumes " <> PP.commasep (map PP.ppr als) maybeComment :: [PP.Doc] -> Maybe PP.Doc maybeComment [] = Nothing maybeComment cs = Just $ PP.folddoc (PP.) cs aliasComment :: (PP.Pretty a, PP.Pretty b) => a -> S.Set b -> Maybe PP.Doc aliasComment name als = case S.toList als of [] -> Nothing als' -> Just $ PP.oneLine $ PP.text "-- " <> PP.ppr name <> PP.text " aliases " <> PP.commasep (map PP.ppr als') resultAliasComment :: (PP.Pretty a, PP.Pretty b) => a -> S.Set b -> Maybe PP.Doc resultAliasComment name als = case S.toList als of [] -> Nothing als' -> Just $ PP.oneLine $ PP.text "-- Result of " <> PP.ppr name <> PP.text " aliases " <> PP.commasep (map PP.ppr als') removeAliases :: CanBeAliased (Op lore) => Rephraser Identity (Aliases lore) lore removeAliases = Rephraser { rephraseExpLore = return . snd , rephraseLetBoundLore = return . snd , rephraseBodyLore = return . snd , rephraseFParamLore = return , rephraseLParamLore = return , rephraseRetType = return , rephraseBranchType = return , rephraseOp = return . removeOpAliases } removeScopeAliases :: Scope (Aliases lore) -> Scope lore removeScopeAliases = M.map unAlias where unAlias (LetInfo (_, attr)) = LetInfo attr unAlias (FParamInfo attr) = FParamInfo attr unAlias (LParamInfo attr) = LParamInfo attr unAlias (IndexInfo it) = IndexInfo it removeProgAliases :: CanBeAliased (Op lore) => Prog (Aliases lore) -> Prog lore removeProgAliases = runIdentity . rephraseProg removeAliases removeFunDefAliases :: CanBeAliased (Op lore) => FunDef (Aliases lore) -> FunDef lore removeFunDefAliases = runIdentity . rephraseFunDef removeAliases removeExpAliases :: CanBeAliased (Op lore) => Exp (Aliases lore) -> Exp lore removeExpAliases = runIdentity . rephraseExp removeAliases removeBodyAliases :: CanBeAliased (Op lore) => Body (Aliases lore) -> Body lore removeBodyAliases = runIdentity . rephraseBody removeAliases removeStmAliases :: CanBeAliased (Op lore) => Stm (Aliases lore) -> Stm lore removeStmAliases = runIdentity . rephraseStm removeAliases removeLambdaAliases :: CanBeAliased (Op lore) => Lambda (Aliases lore) -> Lambda lore removeLambdaAliases = runIdentity . rephraseLambda removeAliases removePatternAliases :: PatternT (Names', a) -> PatternT a removePatternAliases = runIdentity . rephrasePattern (return . snd) addAliasesToPattern :: (Attributes lore, CanBeAliased (Op lore), Typed attr) => PatternT attr -> Exp (Aliases lore) -> PatternT (VarAliases, attr) addAliasesToPattern pat e = uncurry Pattern $ mkPatternAliases pat e mkAliasedBody :: (Attributes lore, CanBeAliased (Op lore)) => BodyAttr lore -> Stms (Aliases lore) -> Result -> Body (Aliases lore) mkAliasedBody innerlore bnds res = Body (mkBodyAliases bnds res, innerlore) bnds res mkPatternAliases :: (Attributes lore, Aliased lore, Typed attr) => PatternT attr -> Exp lore -> ([PatElemT (VarAliases, attr)], [PatElemT (VarAliases, attr)]) mkPatternAliases pat e = -- Some part of the pattern may be the context. This does not have -- aliases from expAliases, so we use a hack to compute aliases of -- the context. let als = expAliases e ++ repeat mempty -- In case the pattern has -- more elements (this -- implies a type error). context_als = mkContextAliases pat e in (zipWith annotateBindee (patternContextElements pat) context_als, zipWith annotateBindee (patternValueElements pat) als) where annotateBindee bindee names = bindee `setPatElemLore` (Names' names', patElemAttr bindee) where names' = case patElemType bindee of Array {} -> names Mem _ _ -> names _ -> mempty mkContextAliases :: (Attributes lore, Aliased lore) => PatternT attr -> Exp lore -> [Names] mkContextAliases pat (DoLoop ctxmerge valmerge _ body) = let ctx = loopResultContext (map fst ctxmerge) (map fst valmerge) init_als = zip mergenames $ map (subExpAliases . snd) $ ctxmerge ++ valmerge expand als = als <> S.unions (mapMaybe (`lookup` init_als) (S.toList als)) merge_als = zip mergenames $ map ((`S.difference` mergenames_set) . expand) $ bodyAliases body in if length ctx == length (patternContextElements pat) then map (fromMaybe mempty . flip lookup merge_als . paramName) ctx else map (const mempty) $ patternContextElements pat where mergenames = map (paramName . fst) $ ctxmerge ++ valmerge mergenames_set = S.fromList mergenames mkContextAliases pat (If _ tbranch fbranch _) = take (length $ patternContextNames pat) $ zipWith (<>) (bodyAliases tbranch) (bodyAliases fbranch) mkContextAliases pat _ = replicate (length $ patternContextElements pat) mempty mkBodyAliases :: Aliased lore => Stms lore -> Result -> BodyAliasing mkBodyAliases bnds res = -- We need to remove the names that are bound in bnds from the alias -- and consumption sets. We do this by computing the transitive -- closure of the alias map (within bnds), then removing anything -- bound in bnds. let (aliases, consumed) = mkStmsAliases bnds res boundNames = fold $ fmap (S.fromList . patternNames . stmPattern) bnds bound = (`S.member` boundNames) aliases' = map (S.filter (not . bound)) aliases consumed' = S.filter (not . bound) consumed in (map Names' aliases', Names' consumed') mkStmsAliases :: Aliased lore => Stms lore -> [SubExp] -> ([Names], Names) mkStmsAliases bnds res = delve mempty $ stmsToList bnds where delve (aliasmap, consumed) [] = (map (aliasClosure aliasmap . subExpAliases) res, consumed) delve (aliasmap, consumed) (bnd:bnds') = delve (trackAliases (aliasmap, consumed) bnd) bnds' aliasClosure aliasmap names = names `S.union` mconcat (map look $ S.toList names) where look k = M.findWithDefault mempty k aliasmap -- | Everything consumed in the given bindings and result (even transitively). consumedInStms :: Aliased lore => Stms lore -> [SubExp] -> Names consumedInStms bnds res = snd $ mkStmsAliases bnds res type AliasesAndConsumed = (M.Map VName Names, Names) trackAliases :: Aliased lore => AliasesAndConsumed -> Stm lore -> AliasesAndConsumed trackAliases (aliasmap, consumed) bnd = let pat = stmPattern bnd als = M.fromList $ zip (patternNames pat) (map addAliasesOfAliases $ patternAliases pat) aliasmap' = als <> aliasmap consumed' = consumed <> addAliasesOfAliases (consumedInStm bnd) in (aliasmap', consumed') where addAliasesOfAliases names = names <> aliasesOfAliases names aliasesOfAliases = mconcat . map look . S.toList look k = M.findWithDefault mempty k aliasmap mkAliasedLetStm :: (Attributes lore, CanBeAliased (Op lore)) => Pattern lore -> StmAux (ExpAttr lore) -> Exp (Aliases lore) -> Stm (Aliases lore) mkAliasedLetStm pat (StmAux cs attr) e = Let (addAliasesToPattern pat e) (StmAux cs (Names' $ consumedInExp e, attr)) e instance (Bindable lore, CanBeAliased (Op lore)) => Bindable (Aliases lore) where mkExpAttr pat e = let attr = mkExpAttr (removePatternAliases pat) $ removeExpAliases e in (Names' $ consumedInExp e, attr) mkExpPat ctx val e = addAliasesToPattern (mkExpPat ctx val $ removeExpAliases e) e mkLetNames names e = do env <- asksScope removeScopeAliases flip runReaderT env $ do Let pat attr _ <- mkLetNames names $ removeExpAliases e return $ mkAliasedLetStm pat attr e mkBody bnds res = let Body bodylore _ _ = mkBody (fmap removeStmAliases bnds) res in mkAliasedBody bodylore bnds res instance (Attributes (Aliases lore), Bindable (Aliases lore)) => BinderOps (Aliases lore) where mkBodyB = bindableMkBodyB mkExpAttrB = bindableMkExpAttrB mkLetNamesB = bindableMkLetNamesB