{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE TypeFamilies #-}
-- | Alias and consumption information for the core AST.
--
-- For expressions, statements, patterns and lambdas this module
-- computes which names each result may alias ('expAliases',
-- 'patternAliases', ...) and which names are consumed
-- ('consumedInExp', 'consumedInStm', 'consumedByLambda').  A lore
-- supplies the body- and operation-level information through the
-- 'Aliased', 'AliasedOp' and 'CanBeAliased' classes.
module Futhark.Representation.AST.Attributes.Aliases
       ( vnameAliases
       , subExpAliases
       , primOpAliases
       , expAliases
       , patternAliases
       , Aliased (..)
       , AliasesOf (..)
         -- * Consumption
       , consumedInStm
       , consumedInExp
       , consumedByLambda
         -- * Extensibility
       , AliasedOp (..)
       , CanBeAliased (..)
       )
       where

import Control.Arrow (first)
import Data.Monoid ((<>))
import qualified Data.Set as S

import Futhark.Representation.AST.Attributes (IsOp)
import Futhark.Representation.AST.Syntax
import Futhark.Representation.AST.Attributes.Patterns
import Futhark.Representation.AST.Attributes.Types

-- | A lore that can report alias and consumption information for its
-- bodies.  Its operations must be 'AliasedOp', and the annotation on
-- let-bound pattern elements must carry aliases ('AliasesOf').
class (Annotations lore, AliasedOp (Op lore),
       AliasesOf (LetAttr lore)) => Aliased lore where
  -- | The aliases of the body's results, one entry per result; the
  -- entries are consumed positionally by 'expAliases'.
  bodyAliases :: Body lore -> [Names]
  -- | The names consumed by the body.
  consumedInBody :: Body lore -> Names

-- | A variable aliases (exactly) itself.
vnameAliases :: VName -> Names
vnameAliases = S.singleton

-- | Constants alias nothing; a variable aliases itself.
subExpAliases :: SubExp -> Names
subExpAliases Constant{} = mempty
subExpAliases (Var v)    = vnameAliases v

-- | The aliases of each result of a basic operation.  Operations that
-- index or re-view an existing array ('Index', 'Repeat', 'Reshape',
-- 'Rearrange', 'Rotate') alias that array; 'SubExp' and 'Opaque'
-- alias their operand; every other operation is treated as producing
-- a result that aliases nothing.
primOpAliases :: BasicOp lore -> [Names]
primOpAliases (SubExp se) = [subExpAliases se]
primOpAliases (Opaque se) = [subExpAliases se]
primOpAliases (ArrayLit _ _) = [mempty]
primOpAliases BinOp{} = [mempty]
primOpAliases ConvOp{} = [mempty]
primOpAliases CmpOp{} = [mempty]
primOpAliases UnOp{} = [mempty]
primOpAliases (Index ident _) = [vnameAliases ident]
primOpAliases Update{} = [mempty]
primOpAliases Iota{} = [mempty]
primOpAliases Replicate{} = [mempty]
primOpAliases (Repeat _ _ v) = [vnameAliases v]
primOpAliases Scratch{} = [mempty]
primOpAliases (Reshape _ e) = [vnameAliases e]
primOpAliases (Rearrange _ e) = [vnameAliases e]
primOpAliases (Rotate _ e) = [vnameAliases e]
primOpAliases Concat{} = [mempty]
primOpAliases Copy{} = [mempty]
primOpAliases Manifest{} = [mempty]
primOpAliases Assert{} = [mempty]

-- | Combine the per-result aliases of the two branches of an @if@.
--
-- Arguments are @(aliases, consumed)@ for the true and false branch.
-- Result @i@ aliases the union of what result @i@ of either branch
-- aliases, minus every name consumed in either branch.  If the two
-- alias lists differ in length, the extra entries are dropped.
ifAliases :: ([Names], Names) -> ([Names], Names) -> [Names]
ifAliases (als1,cons1) (als2,cons2) =
  map (S.filter notConsumed) $ zipWith mappend als1 als2
  where notConsumed = not . (`S.member` cons)
        cons = cons1 <> cons2

-- | Aliases of the results of a function call, given the actual
-- arguments paired with their diets and the declared return types.
funcallAliases :: [(SubExp, Diet)] -> [TypeBase shape Uniqueness] -> [Names]
funcallAliases args t =
  returnAliases t [(subExpAliases se, d) | (se,d) <- args ]

-- | The aliases of each result of an expression.
expAliases :: (Aliased lore) => Exp lore -> [Names]
expAliases (If _ tb fb attr) =
  -- Keep only the trailing @length ts@ entries.  NOTE(review):
  -- presumably the leading entries belong to context results that
  -- 'ifReturns' does not describe -- confirm.
  drop (length all_aliases - length ts) all_aliases
  where ts = ifReturns attr
        all_aliases = ifAliases
                      (bodyAliases tb, consumedInBody tb)
                      (bodyAliases fb, consumedInBody fb)
expAliases (BasicOp op) = primOpAliases op
expAliases (DoLoop ctxmerge valmerge _ loopbody) =
  -- The loop's value results alias what the body's value results
  -- alias, excluding the loop's own merge parameters.
  map (`S.difference` merge_names) val_aliases
  where (_ctx_aliases, val_aliases) =
          splitAt (length ctxmerge) $ bodyAliases loopbody
        merge_names = S.fromList $
                      map (paramName . fst) $ ctxmerge ++ valmerge
expAliases (Apply _ args t _) =
  funcallAliases args $ retTypeValues t
expAliases (Op op) = opAliases op

-- | Aliases of each declared return type of a call, given the aliases
-- and diet of every argument.  A non-unique array result may alias
-- every observed (not consumed) argument; unique arrays and primitive
-- values alias nothing.  Memory return types are not supported and
-- cause a runtime 'error'.
returnAliases :: [TypeBase shape Uniqueness] -> [(Names, Diet)] -> [Names]
returnAliases rts args = map returnType' rts
  where returnType' (Array _ _ Nonunique) =
          mconcat $ map (uncurry maskAliases) args
        returnType' (Array _ _ Unique) =
          mempty
        returnType' (Prim _) =
          mempty
        returnType' Mem{} =
          error "returnAliases Mem"

-- | A consumed argument contributes no aliases; an observed one
-- contributes all of its aliases.
maskAliases :: Names -> Diet -> Names
maskAliases _   Consume = mempty
maskAliases als Observe = als

-- | The names consumed by a statement, i.e. by its expression.
consumedInStm :: Aliased lore => Stm lore -> Names
consumedInStm = consumedInExp . stmExp

-- | The names consumed by an expression.
consumedInExp :: (Aliased lore) => Exp lore -> Names
consumedInExp (Apply _ args _ _) =
  -- Every argument passed with the 'Consume' diet consumes all of its
  -- aliases.
  mconcat (map (consumeArg . first subExpAliases) args)
  where consumeArg (als, Consume) = als
        consumeArg (_,   Observe) = mempty
consumedInExp (If _ tb fb _) =
  consumedInBody tb <> consumedInBody fb
consumedInExp (DoLoop _ merge _ _) =
  -- Only the value merge parameters are inspected: the initial value
  -- of each one declared unique is consumed.
  mconcat (map (subExpAliases . snd) $
           filter (unique . paramDeclType . fst) merge)
consumedInExp (BasicOp (Update src _ _)) = S.singleton src
consumedInExp (Op op) = consumedInOp op
consumedInExp _ = mempty

-- | The names consumed by the body of a lambda.
consumedByLambda :: Aliased lore => Lambda lore -> Names
consumedByLambda = consumedInBody . lambdaBody

-- | The aliases of each element of a pattern, in order.
patternAliases :: AliasesOf attr => PatternT attr -> [Names]
patternAliases = map (aliasesOf . patElemAttr) . patternElements

-- | Something that contains alias information.
class AliasesOf a where
  -- | The aliases recorded in the argument.
  aliasesOf :: a -> Names

instance AliasesOf Names where
  aliasesOf = id

-- | A pattern element's aliases are those of its annotation.
instance AliasesOf attr => AliasesOf (PatElemT attr) where
  aliasesOf = aliasesOf . patElemAttr

-- | Operations that can report their aliases and consumption, in the
-- same way 'expAliases' and 'consumedInExp' do for expressions.
class IsOp op => AliasedOp op where
  -- | The aliases of each result of the operation.
  opAliases :: op -> [Names]
  -- | The names consumed by the operation.
  consumedInOp :: op -> Names

-- | The unit operation has no results and consumes nothing.
instance AliasedOp () where
  opAliases () = []
  consumedInOp () = mempty

-- | Operations that can be converted to and from a variant carrying
-- alias information.
class AliasedOp (OpWithAliases op) => CanBeAliased op where
  -- | The alias-annotated counterpart of the operation.
  type OpWithAliases op :: *
  -- | Drop the alias information.
  removeOpAliases :: OpWithAliases op -> op
  -- | Attach alias information.
  addOpAliases :: op -> OpWithAliases op

-- | The unit operation carries no alias information, so both
-- conversions are the identity.
instance CanBeAliased () where
  type OpWithAliases () = ()
  removeOpAliases = id
  addOpAliases = id