{-# LANGUAGE FlexibleInstances, FlexibleContexts, UndecidableInstances #-}
module Futhark.Representation.AST.Attributes.Names
(
Names
, nameIn
, oneName
, namesFromList
, namesToList
, namesIntersection
, namesIntersect
, namesSubtract
, mapNames
, FreeIn (..)
, freeIn
, freeInStmsAndRes
, boundInBody
, boundByStm
, boundByStms
, boundByLambda
, FreeAttr(..)
, FV
, fvBind
, fvName
, fvNames
)
where
import Control.Monad.State.Strict
import qualified Data.IntMap.Strict as IM
import qualified Data.Map.Strict as M
import Data.Foldable
import Futhark.Representation.AST.Syntax
import Futhark.Representation.AST.Traversals
import Futhark.Representation.AST.Attributes.Patterns
import Futhark.Representation.AST.Attributes.Scope
import Futhark.Util.Pretty
-- | A set of variable names, keyed by each name's 'baseTag'.
newtype Names = Names { unNames :: IM.IntMap VName }
  deriving (Eq, Show)

-- | Set union, left-biased on equal tags.
instance Semigroup Names where
  Names xs <> Names ys = Names $ xs <> ys

instance Monoid Names where
  mempty = Names IM.empty

-- | Pretty-prints the names as a list.
instance Pretty Names where
  ppr names = ppr $ namesToList names
-- | Is the name a member of the set?  Membership is decided by the
-- name's 'baseTag' alone.
nameIn :: VName -> Names -> Bool
nameIn v names = IM.member (baseTag v) (unNames names)
-- | Build a name set from a list.  If two names share a 'baseTag', the
-- later one in the list is kept.
namesFromList :: [VName] -> Names
namesFromList = Names . IM.fromList . map (\v -> (baseTag v, v))
-- | The names in the set, in ascending order of 'baseTag'.
namesToList :: Names -> [VName]
namesToList (Names vs) = IM.elems vs
-- | The singleton set containing just the given name.
oneName :: VName -> Names
oneName v = Names (IM.singleton tag v)
  where tag = baseTag v
-- | The names present in both sets (values taken from the first set).
namesIntersection :: Names -> Names -> Names
namesIntersection xs ys = Names $ unNames xs `IM.intersection` unNames ys
-- | Do the two sets have at least one name in common?
namesIntersect :: Names -> Names -> Bool
namesIntersect (Names xs) (Names ys) = not (xs `IM.disjoint` ys)
-- | @namesSubtract xs ys@ contains the names of @xs@ that are not in @ys@.
namesSubtract :: Names -> Names -> Names
namesSubtract xs ys = Names $ unNames xs `IM.difference` unNames ys
-- | Apply a function to every name in the set.  The set is rebuilt with
-- 'namesFromList', so results that share a 'baseTag' collapse into one.
mapNames :: (VName -> VName) -> Names -> Names
mapNames f = namesFromList . map f . namesToList
-- | An accumulator of free variable names, as produced by 'freeIn''.
newtype FV = FV { unFV :: Names }

-- | Union of the underlying name sets.
instance Semigroup FV where
  FV xs <> FV ys = FV (xs <> ys)

instance Monoid FV where
  mempty = FV mempty
-- | Remove the given (bound) names from a set of free variables.
fvBind :: Names -> FV -> FV
fvBind bound fv = FV $ unFV fv `namesSubtract` bound
-- | A single free variable.
fvName :: VName -> FV
fvName = FV . oneName
-- | Treat every name in the set as a free variable.
fvNames :: Names -> FV
fvNames names = FV names
-- | A walker that adds the free variables of every subexpression, body,
-- variable name and operation it visits to the 'FV' held in its state.
--
-- The state is updated with 'modify'' rather than 'modify'.  Each step
-- therefore forces the accumulated 'FV' as it goes, instead of building
-- a chain of deferred '<>' applications inside the state.
freeWalker :: (FreeAttr (ExpAttr lore),
               FreeAttr (BodyAttr lore),
               FreeIn (FParamAttr lore),
               FreeIn (LParamAttr lore),
               FreeIn (LetAttr lore),
               FreeIn (Op lore)) =>
              Walker lore (State FV)
freeWalker = identityWalker { walkOnSubExp = collect . freeIn'
                            , walkOnBody = collect . freeIn'
                            , walkOnVName = collect . fvName
                            , walkOnOp = collect . freeIn'
                            }
  where -- New free variables go on the left, as in the original formulation.
        collect fv = modify' (fv <>)
-- | The free variables of a sequence of statements together with a
-- result, excluding every name bound by the statements' patterns.
freeInStmsAndRes :: (FreeIn (Op lore),
                     FreeIn (LetAttr lore),
                     FreeIn (LParamAttr lore),
                     FreeIn (FParamAttr lore),
                     FreeAttr (BodyAttr lore),
                     FreeAttr (ExpAttr lore)) =>
                    Stms lore -> Result -> FV
freeInStmsAndRes stms res =
  -- 'foldMap' combines the statements' free variables directly, without
  -- first building an intermediate sequence as @fold . fmap@ does.
  fvBind (boundByStms stms) $ foldMap freeIn' stms <> freeIn' res
-- | A class indicating that we can obtain free variable information
-- from values of this type.
--
-- Every instance must define 'freeIn''.  The class deliberately has no
-- default method.  The obvious candidate, @fvNames . freeIn@, is
-- circular because 'freeIn' is itself defined as @unFV . freeIn'@, so
-- an instance relying on it would loop forever.  Without a default, GHC
-- reports instances that omit the method (-Wmissing-methods).
class FreeIn a where
  -- | The free variables of a value, as an 'FV' accumulator.
  freeIn' :: a -> FV
-- | The set of free variables of a value.
freeIn :: FreeIn a => a -> Names
freeIn x = unFV (freeIn' x)
-- | An 'FV' is its own set of free variables.
instance FreeIn FV where
  freeIn' fv = fv

-- | Unit has no free variables.
instance FreeIn () where
  freeIn' () = mempty

-- | Integers have no free variables.
instance FreeIn Int where
  freeIn' _ = mempty
-- | The free variables of a pair are the union of its components'.
instance (FreeIn a, FreeIn b) => FreeIn (a,b) where
  freeIn' (x, y) = mconcat [freeIn' x, freeIn' y]

-- | The free variables of a triple are the union of its components'.
instance (FreeIn a, FreeIn b, FreeIn c) => FreeIn (a,b,c) where
  freeIn' (x, y, z) = mconcat [freeIn' x, freeIn' y, freeIn' z]
-- | The free variables of a list are the union of its elements'.
-- Uses 'foldMap' rather than @fold . fmap@, which avoids an
-- intermediate list of 'FV's.
instance FreeIn a => FreeIn [a] where
  freeIn' = foldMap freeIn'
-- | A lambda's free variables are those of its return type, parameter
-- attributes and body, with the parameter names themselves removed.
instance (FreeAttr (ExpAttr lore),
          FreeAttr (BodyAttr lore),
          FreeIn (FParamAttr lore),
          FreeIn (LParamAttr lore),
          FreeIn (LetAttr lore),
          FreeIn (Op lore)) => FreeIn (Lambda lore) where
  freeIn' (Lambda params body rettype) = fvBind boundHere used
    where boundHere = namesFromList (map paramName params)
          used = freeIn' rettype <> freeIn' params <> freeIn' body
-- | A body's free variables come from its attribute, its statements and
-- its result, with names bound by the statements removed.  The combined
-- set is passed through the attribute's 'precomputed' hook.
instance (FreeAttr (ExpAttr lore),
          FreeAttr (BodyAttr lore),
          FreeIn (FParamAttr lore),
          FreeIn (LParamAttr lore),
          FreeIn (LetAttr lore),
          FreeIn (Op lore)) => FreeIn (Body lore) where
  freeIn' (Body attr stms res) =
    let traversed = freeIn' attr <> freeInStmsAndRes stms res
    in precomputed attr traversed
-- | Free variables of an expression.
--
-- Loops are handled explicitly.  The free variables of the merge
-- initialisers, the loop form, the merge parameters and the loop body
-- are combined, and then the names bound by the loop form and the merge
-- parameters are removed.  Every other expression is traversed with
-- 'freeWalker'.
instance (FreeAttr (ExpAttr lore),
          FreeAttr (BodyAttr lore),
          FreeIn (FParamAttr lore),
          FreeIn (LParamAttr lore),
          FreeIn (LetAttr lore),
          FreeIn (Op lore)) => FreeIn (Exp lore) where
  freeIn' (DoLoop ctxmerge valmerge form loopbody) =
    fvBind boundHere $
      freeIn' (ctxinits ++ valinits) <> freeIn' form <>
      freeIn' mergeParams <> freeIn' loopbody
    where (ctxparams, ctxinits) = unzip ctxmerge
          (valparams, valinits) = unzip valmerge
          mergeParams = ctxparams ++ valparams
          boundHere = namesFromList $ M.keys $
                      scopeOf form <> scopeOfFParams mergeParams
  freeIn' e = execState (walkExpM freeWalker e) mempty
-- | A statement's free variables come from its certificates, its
-- attribute, its expression and its pattern.  The certificates are added
-- outside the 'precomputed' hook, so that hook never sees or drops them.
instance (FreeAttr (ExpAttr lore),
          FreeAttr (BodyAttr lore),
          FreeIn (FParamAttr lore),
          FreeIn (LParamAttr lore),
          FreeIn (LetAttr lore),
          FreeIn (Op lore)) => FreeIn (Stm lore) where
  freeIn' (Let pat (StmAux cs attr) e) = certsFV <> precomputed attr ownFV
    where certsFV = freeIn' cs
          ownFV = mconcat [freeIn' attr, freeIn' e, freeIn' pat]
-- | The free variables of a statement sequence are the union of each
-- statement's free variables.  Names bound by earlier statements are
-- not removed here (see 'freeInStmsAndRes' for that).  Uses 'foldMap'
-- rather than @fold . fmap@, which avoids building an intermediate
-- sequence.
instance FreeIn (Stm lore) => FreeIn (Stms lore) where
  freeIn' = foldMap freeIn'
-- | Every name in a 'Names' set is free.
instance FreeIn Names where
  freeIn' names = fvNames names

-- | Booleans have no free variables.
instance FreeIn Bool where
  freeIn' = const mempty

-- | 'Nothing' has no free variables; 'Just' has those of its contents.
instance FreeIn a => FreeIn (Maybe a) where
  freeIn' Nothing = mempty
  freeIn' (Just x) = freeIn' x

-- | A name is free in itself.
instance FreeIn VName where
  freeIn' v = fvName v
-- | Only the identifier's type contributes free variables; the
-- identifier's own name is not counted.
instance FreeIn Ident where
  freeIn' ident = freeIn' (identType ident)

-- | A variable is free; a constant has no free variables.
instance FreeIn SubExp where
  freeIn' se = case se of
    Var v -> freeIn' v
    Constant{} -> mempty
-- | The free variables of a shape are those of its dimensions.
instance FreeIn d => FreeIn (ShapeBase d) where
  freeIn' shape = freeIn' (shapeDims shape)

-- | Existential ('Ext') components have no free variables; 'Free' ones
-- have those of their contents.
instance FreeIn d => FreeIn (Ext d) where
  freeIn' (Ext _) = mempty
  freeIn' (Free x) = freeIn' x
-- | Only array types contribute free variables (through their shape).
instance FreeIn shape => FreeIn (TypeBase shape u) where
  freeIn' t = case t of
    Array _ shape _ -> freeIn' shape
    Mem _ -> mempty
    Prim _ -> mempty
-- | A parameter's free variables are those of its attribute; its own
-- name is not counted.
instance FreeIn attr => FreeIn (Param attr) where
  freeIn' (Param _name attr) = freeIn' attr

-- | A pattern element's free variables are those of its attribute; its
-- own name is not counted.
instance FreeIn attr => FreeIn (PatElemT attr) where
  freeIn' (PatElem _name attr) = freeIn' attr
-- | A while loop's free variables are those of its condition.  A for
-- loop's are those of its bound and its loop variables; the first two
-- fields of 'ForLoop' are ignored.
instance FreeIn (LParamAttr lore) => FreeIn (LoopForm lore) where
  freeIn' (WhileLoop cond) = freeIn' cond
  freeIn' (ForLoop _ _ bound loopVars) = freeIn' bound <> freeIn' loopVars
-- | Union of the free variables of every element the 'DimChange' folds over.
instance FreeIn d => FreeIn (DimChange d) where
  freeIn' change = foldMap freeIn' change

-- | Union of the free variables of every element the 'DimIndex' folds over.
instance FreeIn d => FreeIn (DimIndex d) where
  freeIn' index = foldMap freeIn' index
-- | A pattern's free variables are those of its elements (context and
-- value), minus the names the pattern itself binds.
instance FreeIn attr => FreeIn (PatternT attr) where
  freeIn' (Pattern context values) =
    let elems = context ++ values
        bound = namesFromList (map patElemName elems)
    in fvBind bound (freeIn' elems)
-- | Certificates are free in the names they list.
instance FreeIn Certificates where
  freeIn' (Certificates names) = freeIn' names

-- | Auxiliary statement information contributes its certificates and its
-- attribute.
instance FreeIn attr => FreeIn (StmAux attr) where
  freeIn' (StmAux certs attr) = mconcat [freeIn' certs, freeIn' attr]

-- | Only the first field of an 'IfAttr' contributes free variables.
instance FreeIn a => FreeIn (IfAttr a) where
  freeIn' (IfAttr rets _) = freeIn' rets
-- | A hook for attribute types used in the 'Body' and 'Stm' instances.
-- 'precomputed' receives the attribute and the free variables found by
-- traversal, and returns the free variables to report.  The default
-- returns the traversal result unchanged.
class FreeIn attr => FreeAttr attr where
  precomputed :: attr -> FV -> FV
  precomputed _ fv = fv
-- | Unit uses the default (identity) hook.
instance FreeAttr ()

-- | A pair defers to its first component.
instance (FreeAttr a, FreeIn b) => FreeAttr (a,b) where
  precomputed (x, _) = precomputed x

-- | A list defers to its first element; an empty list leaves the
-- traversal result unchanged.
instance FreeAttr a => FreeAttr [a] where
  precomputed xs = case xs of
    [] -> id
    x : _ -> precomputed x

-- | 'Nothing' leaves the traversal result unchanged; 'Just' defers to
-- its contents.
instance FreeAttr a => FreeAttr (Maybe a) where
  precomputed = maybe id precomputed
-- | The names bound by the statements of a body (its result binds nothing).
boundInBody :: Body lore -> Names
boundInBody body = boundByStms (bodyStms body)

-- | The names bound by a statement's pattern.
boundByStm :: Stm lore -> Names
boundByStm stm = namesFromList (patternNames (stmPattern stm))
-- | The union of the names bound by each statement in the sequence.
-- Uses 'foldMap' rather than @fold . fmap@, which avoids building an
-- intermediate sequence of name sets.
boundByStms :: Stms lore -> Names
boundByStms = foldMap boundByStm
-- | The names of a lambda's parameters, in declaration order.
boundByLambda :: Lambda lore -> [VName]
boundByLambda lam = [paramName p | p <- lambdaParams lam]