{-# LANGUAGE ConstraintKinds #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE UndecidableInstances #-}

-- | A symbol table mapping the variable names in scope to 'Entry'
-- records describing how each name was bound (let-bound, function
-- parameter, lambda parameter, loop variable, or free variable),
-- together with alias, consumption and delayed-indexing information.
module Futhark.Analysis.SymbolTable
  ( SymbolTable (bindings, loopDepth, availableAtClosestLoop, simplifyMemory)
  , empty
  , fromScope
  , toScope
    -- * Entries
  , Entry
  , deepen
  , entryDepth
  , entryLetBoundDec
  , entryIsSize
    -- * Lookup
  , elem
  , lookup
  , lookupStm
  , lookupExp
  , lookupBasicOp
  , lookupType
  , lookupSubExp
  , lookupAliases
  , lookupLoopVar
  , available
  , consume
  , index
  , index'
  , Indexed (..)
  , indexedAddCerts
  , IndexOp (..)
    -- * Insertion
  , insertStm
  , insertStms
  , insertFParams
  , insertLParam
  , insertLoopVar
    -- * Misc
  , hideCertified
  )
where

import Control.Arrow ((&&&))
import Control.Monad
import Data.List (elemIndex, foldl')
import qualified Data.Map.Strict as M
import Data.Maybe
import Data.Ord
import Prelude hiding (elem, lookup)

import Futhark.Analysis.PrimExp.Convert
import Futhark.IR hiding (FParam, lookupType)
import qualified Futhark.IR as AST
import qualified Futhark.IR.Prop.Aliases as Aliases

-- | The variables in scope at some program point, plus some context
-- about where that point is.
data SymbolTable lore = SymbolTable
  { loopDepth :: Int
    -- ^ Loop nesting depth; incremented by 'deepen'.
  , bindings :: M.Map VName (Entry lore)
    -- ^ Every name in scope, mapped to its entry.
  , availableAtClosestLoop :: Names
    -- ^ Which names are available just before the most enclosing
    -- loop?
  , simplifyMemory :: Bool
    -- ^ We are in a situation where we should
    -- simplify/hoist/un-existentialise memory as much as possible -
    -- typically, inside a kernel.
  }

-- | Combines two tables.  'bindings' is a left-biased map union; the
-- loop depth is the maximum of the two; the other fields are joined.
instance Semigroup (SymbolTable lore) where
  table1 <> table2 =
    SymbolTable
      { loopDepth = max (loopDepth table1) (loopDepth table2)
      , bindings = bindings table1 <> bindings table2
      , availableAtClosestLoop =
          availableAtClosestLoop table1 <> availableAtClosestLoop table2
      , simplifyMemory = simplifyMemory table1 || simplifyMemory table2
      }

instance Monoid (SymbolTable lore) where
  mempty = empty

-- | The table with no bindings, at loop depth 0.
empty :: SymbolTable lore
empty = SymbolTable 0 M.empty mempty False

-- | Build a table in which every name of the scope is a free variable.
fromScope :: ASTLore lore => Scope lore -> SymbolTable lore
fromScope = M.foldlWithKey' insertFreeVar' empty
  where
    insertFreeVar' m k dec = insertFreeVar k dec m

-- | Forget everything but the 'NameInfo' of each binding.
toScope :: SymbolTable lore -> Scope lore
toScope = M.map entryInfo . bindings

-- | Enter a loop: increment the loop depth and record every name
-- currently bound as available just before this (now innermost) loop.
deepen :: SymbolTable lore -> SymbolTable lore
deepen vtable =
  vtable
    { loopDepth = loopDepth vtable + 1
    , availableAtClosestLoop = namesFromList $ M.keys $ bindings vtable
    }

-- | The result of indexing a delayed array.
data Indexed
  = -- | A PrimExp based on the indexes (that is, without
    -- accessing any actual array).
    Indexed Certificates (PrimExp VName)
  | -- | The indexing corresponds to another (perhaps more
    -- advantageous) array.
    IndexedArray Certificates VName [PrimExp VName]

-- | Prepend extra certificates to an indexing result.
indexedAddCerts :: Certificates -> Indexed -> Indexed
indexedAddCerts cs1 (Indexed cs2 v) = Indexed (cs1 <> cs2) v
indexedAddCerts cs1 (IndexedArray cs2 arr v) = IndexedArray (cs1 <> cs2) arr v

instance FreeIn Indexed where
  freeIn' (Indexed cs v) = freeIn' cs <> freeIn' v
  freeIn' (IndexedArray cs arr v) = freeIn' cs <> freeIn' arr <> freeIn' v

-- | Indexing a delayed array if possible.
type IndexArray = [PrimExp VName] -> Maybe Indexed

-- | Everything the table knows about a single name.
data Entry lore = Entry
  { entryConsumed :: Bool
    -- ^ True if consumed.
  , entryDepth :: Int
    -- ^ The table's 'loopDepth' when this entry was inserted.
  , entryIsSize :: Bool
    -- ^ True if this name has been used as an array size,
    -- implying that it is non-negative.
  , entryType :: EntryType lore
    -- ^ How the name was bound.
  }

-- | The kind of binding an 'Entry' stems from.
data EntryType lore
  = LoopVar (LoopVarEntry lore)
  | LetBound (LetBoundEntry lore)
  | FParam (FParamEntry lore)
  | LParam (LParamEntry lore)
  | FreeVar (FreeVarEntry lore)

data LoopVarEntry lore = LoopVarEntry
  { loopVarType :: IntType
  , loopVarBound :: SubExp
  }

data LetBoundEntry lore = LetBoundEntry
  { letBoundDec :: LetDec lore
  , letBoundAliases :: Names
  , letBoundStm :: Stm lore
  , letBoundIndex :: Int -> IndexArray
    -- ^ Index a delayed array, if possible.
  }

data FParamEntry lore = FParamEntry
  { fparamDec :: FParamInfo lore
  , fparamAliases :: Names
  }

data LParamEntry lore = LParamEntry
  { lparamDec :: LParamInfo lore
  , lparamIndex :: IndexArray
  }

data FreeVarEntry lore = FreeVarEntry
  { freeVarDec :: NameInfo lore
  , freeVarIndex :: VName -> IndexArray
    -- ^ Index a delayed array, if possible.
  }

instance ASTLore lore => Typed (Entry lore) where
  typeOf = typeOf . entryInfo

-- | The 'NameInfo' corresponding to an entry, whatever its kind.
entryInfo :: Entry lore -> NameInfo lore
entryInfo e = case entryType e of
  LetBound entry -> LetName $ letBoundDec entry
  LoopVar entry -> IndexName $ loopVarType entry
  FParam entry -> FParamName $ fparamDec entry
  LParam entry -> LParamName $ lparamDec entry
  FreeVar entry -> freeVarDec entry

isLetBound :: Entry lore -> Maybe (LetBoundEntry lore)
isLetBound e = case entryType e of
  LetBound entry -> Just entry
  _ -> Nothing

-- | The statement that bound the entry, if it is let-bound.
entryStm :: Entry lore -> Maybe (Stm lore)
entryStm = fmap letBoundStm . isLetBound

-- | The let-decoration of the entry, if it is let-bound.
entryLetBoundDec :: Entry lore -> Maybe (LetDec lore)
entryLetBoundDec = fmap letBoundDec . isLetBound

-- | Is the name in the table (consumed or not)?
elem :: VName -> SymbolTable lore -> Bool
elem name = isJust . lookup name

lookup :: VName -> SymbolTable lore -> Maybe (Entry lore)
lookup name = M.lookup name . bindings

-- | The statement that bound the name, if it is let-bound.
lookupStm :: VName -> SymbolTable lore -> Maybe (Stm lore)
lookupStm name vtable = entryStm =<< lookup name vtable

-- | The expression (and statement certificates) that bound the name.
lookupExp :: VName -> SymbolTable lore -> Maybe (Exp lore, Certificates)
lookupExp name vtable = (stmExp &&& stmCerts) <$> lookupStm name vtable

-- | Like 'lookupExp', but only succeeds for 'BasicOp' expressions.
lookupBasicOp :: VName -> SymbolTable lore -> Maybe (BasicOp, Certificates)
lookupBasicOp name vtable = case lookupExp name vtable of
  Just (BasicOp e, cs) -> Just (e, cs)
  _ -> Nothing

lookupType :: ASTLore lore => VName -> SymbolTable lore -> Maybe Type
lookupType name vtable = typeOf <$> lookup name vtable

-- | Constants always have a known (primitive) type; variables are
-- looked up in the table.
lookupSubExpType :: ASTLore lore => SubExp -> SymbolTable lore -> Maybe Type
lookupSubExpType (Var v) = lookupType v
lookupSubExpType (Constant v) = const $ Just $ Prim $ primValueType v

-- | If the name is bound to a plain 'SubExp', return that 'SubExp'.
lookupSubExp :: VName -> SymbolTable lore -> Maybe (SubExp, Certificates)
lookupSubExp name vtable = do
  (e, cs) <- lookupExp name vtable
  case e of
    BasicOp (SubExp se) -> Just (se, cs)
    _ -> Nothing

-- | The recorded aliases of a name.  Only let-bound names and function
-- parameters track aliases; every other entry yields 'mempty'.
lookupAliases :: VName -> SymbolTable lore -> Names
lookupAliases name vtable =
  case entryType <$> M.lookup name (bindings vtable) of
    Just (LetBound e) -> letBoundAliases e
    Just (FParam e) -> fparamAliases e
    _ -> mempty

-- | If the given variable name is the name of a 'ForLoop' parameter,
-- then return the bound of that loop.
lookupLoopVar :: VName -> SymbolTable lore -> Maybe SubExp
lookupLoopVar name vtable = do
  LoopVar e <- entryType <$> M.lookup name (bindings vtable)
  return $ loopVarBound e

-- | In symbol table and not consumed.
available :: VName -> SymbolTable lore -> Bool
available name = maybe False (not . entryConsumed) . M.lookup name . bindings

-- | Index a (possibly delayed) array with 'SubExp' indices.  Fails if
-- the type of any index is unknown or not primitive.
index ::
  ASTLore lore =>
  VName ->
  [SubExp] ->
  SymbolTable lore ->
  Maybe Indexed
index name is table = do
  is' <- mapM asPrimExp is
  index' name is' table
  where
    asPrimExp i = do
      Prim t <- lookupSubExpType i table
      return $ primExpFromSubExp t i

-- | Index a (possibly delayed) array with 'PrimExp' indices.
-- Let-bound names are indexed through the statement that bound them
-- (using the name's position among the pattern's value names); free
-- variables and lambda parameters use their stored index functions.
-- Every other kind of entry yields 'Nothing'.
index' :: VName -> [PrimExp VName] -> SymbolTable lore -> Maybe Indexed
index' name is vtable = do
  entry <- lookup name vtable
  case entryType entry of
    LetBound entry'
      | Just k <- elemIndex name $ patternValueNames $ stmPattern $ letBoundStm entry' ->
          letBoundIndex entry' k is
    FreeVar entry' -> freeVarIndex entry' name is
    LParam entry' -> lparamIndex entry' is
    _ -> Nothing

-- | Lore-specific operations that know how to index their own results.
-- The default implementation never succeeds.
class IndexOp op where
  indexOp ::
    (ASTLore lore, IndexOp (Op lore)) =>
    SymbolTable lore ->
    Int ->
    op ->
    [PrimExp VName] ->
    Maybe Indexed
  indexOp _ _ _ _ = Nothing

instance IndexOp ()

-- | Try to compute element @is@ of result number @k@ of an expression
-- without materialising the array.
indexExp ::
  (IndexOp (Op lore), ASTLore lore) =>
  SymbolTable lore ->
  Exp lore ->
  Int ->
  IndexArray
indexExp vtable (Op op) k is =
  indexOp vtable k op is
indexExp _ (BasicOp (Iota _ x s to_it)) _ [i] =
  -- iota element i is x + i*s.
  Just $ Indexed mempty $
    sExt to_it i * primExpFromSubExp (IntType to_it) s
      + primExpFromSubExp (IntType to_it) x
indexExp table (BasicOp (Replicate (Shape ds) v)) _ is
  | length ds == length is,
    Just (Prim t) <- lookupSubExpType v table =
      -- Fully indexing a replicated scalar yields the scalar itself.
      Just $ Indexed mempty $ primExpFromSubExp t v
indexExp table (BasicOp (Replicate (Shape ds) (Var v))) _ is
  | not (null ds),
    length is >= length ds =
      -- The leading (length ds) indices select a replicated copy of v;
      -- the remaining indices index into v itself.  Previously only
      -- single-dimensional replicates were handled.
      index' v (drop (length ds) is) table
indexExp table (BasicOp (Reshape newshape v)) _ is
  | Just oldshape <- arrayDims <$> lookupType v table =
      let is' =
            reshapeIndex
              (map (primExpFromSubExp int32) oldshape)
              (map (primExpFromSubExp int32) $ newDims newshape)
              is
       in index' v is' table
indexExp table (BasicOp (Index v slice)) _ is =
  index' v (adjust slice is) table
  where
    -- Fixed dimensions contribute their index directly; a slice
    -- dimension with offset j and stride s maps index i to j + i*s.
    -- NOTE(review): once the slice is exhausted any remaining indices
    -- are dropped; this presumes the slice covers every dimension.
    adjust (DimFix j : js') is' =
      pe j : adjust js' is'
    adjust (DimSlice j _ s : js') (i : is') =
      let i_t_s = i * pe s
          j_p_i_t_s = pe j + i_t_s
       in j_p_i_t_s : adjust js' is'
    adjust _ _ = []
    pe = primExpFromSubExp (IntType Int32)
indexExp _ _ _ _ = Nothing

-- | The entry for one pattern element of a statement.  The index
-- function closes over the table as it was when the statement was
-- inserted, and adds the statement's certificates to any result.
defBndEntry ::
  (ASTLore lore, IndexOp (Op lore)) =>
  SymbolTable lore ->
  PatElem lore ->
  Names ->
  Stm lore ->
  LetBoundEntry lore
defBndEntry vtable patElem als bnd =
  LetBoundEntry
    { letBoundDec = patElemDec patElem
    , letBoundAliases = als
    , letBoundStm = bnd
    , letBoundIndex = \k ->
        fmap (indexedAddCerts (stmAuxCerts $ stmAux bnd))
          . indexExp vtable (stmExp bnd) k
    }

-- | One entry per pattern element (context and value) of a statement.
bindingEntries ::
  (ASTLore lore, Aliases.Aliased lore, IndexOp (Op lore)) =>
  Stm lore ->
  SymbolTable lore ->
  [LetBoundEntry lore]
bindingEntries bnd@(Let pat _ _) vtable = do
  pat_elem <- patternElements pat
  return $ defBndEntry vtable pat_elem (Aliases.aliasesOf pat_elem) bnd

-- | Apply 'M.adjust' with the same function at several keys.
adjustSeveral :: Ord k => (v -> v) -> [k] -> M.Map k v -> M.Map k v
adjustSeveral f = flip $ foldl' $ flip $ M.adjust f

-- | Insert a fresh, unconsumed entry at the current loop depth, and
-- mark every variable used in its array dimensions as a size.
insertEntry ::
  ASTLore lore =>
  VName ->
  EntryType lore ->
  SymbolTable lore ->
  SymbolTable lore
insertEntry name entry vtable =
  let entry' =
        Entry
          { entryConsumed = False
          , entryDepth = loopDepth vtable
          , entryIsSize = False
          , entryType = entry
          }
      dims = mapMaybe subExpVar $ arrayDims $ typeOf entry'
      isSize e = e {entryIsSize = True}
   in vtable
        { bindings = adjustSeveral isSize dims $ M.insert name entry' $ bindings vtable
        }

insertEntries ::
  ASTLore lore =>
  [(VName, EntryType lore)] ->
  SymbolTable lore ->
  SymbolTable lore
insertEntries entries vtable =
  foldl' add vtable entries
  where
    add vtable' (name, entry) = insertEntry name entry vtable'

-- | Insert a statement: add a let-bound entry for every pattern
-- element, record each element's name as an alias of the (let-bound
-- or parameter) names it aliases, and finally mark everything the
-- statement consumes (plus their recorded aliases) as consumed.
insertStm ::
  (ASTLore lore, IndexOp (Op lore), Aliases.Aliased lore) =>
  Stm lore ->
  SymbolTable lore ->
  SymbolTable lore
insertStm stm vtable =
  flip (foldl' $ flip consume) (namesToList stm_consumed) $
    flip (foldl' addRevAliases) (patternElements $ stmPattern stm) $
      insertEntries (zip names $ map LetBound $ bindingEntries stm vtable) vtable
  where
    names = patternNames $ stmPattern stm
    stm_consumed = expandAliases (Aliases.consumedInStm stm) vtable
    addRevAliases vtable' pe =
      vtable' {bindings = adjustSeveral update inedges $ bindings vtable'}
      where
        inedges = namesToList $ expandAliases (Aliases.aliasesOf pe) vtable'
        update e = e {entryType = update' $ entryType e}
        update' (LetBound entry) =
          LetBound entry {letBoundAliases = oneName (patElemName pe) <> letBoundAliases entry}
        update' (FParam entry) =
          FParam entry {fparamAliases = oneName (patElemName pe) <> fparamAliases entry}
        update' e = e

-- | Insert statements in order.
insertStms ::
  (ASTLore lore, IndexOp (Op lore), Aliases.Aliased lore) =>
  Stms lore ->
  SymbolTable lore ->
  SymbolTable lore
insertStms stms vtable = foldl' (flip insertStm) vtable $ stmsToList stms

-- | The names plus their recorded aliases.  This is a single lookup
-- step per name, not a transitive closure.
expandAliases :: Names -> SymbolTable lore -> Names
expandAliases names vtable = names <> aliasesOfAliases
  where
    aliasesOfAliases =
      mconcat . map (`lookupAliases` vtable) . namesToList $ names

insertFParam ::
  ASTLore lore =>
  AST.FParam lore ->
  SymbolTable lore ->
  SymbolTable lore
insertFParam fparam = insertEntry name entry
  where
    name = AST.paramName fparam
    entry =
      FParam
        FParamEntry
          { fparamDec = AST.paramDec fparam
          , fparamAliases = mempty
          }

-- | Insert function parameters (with no aliases) in order.
insertFParams ::
  ASTLore lore =>
  [AST.FParam lore] ->
  SymbolTable lore ->
  SymbolTable lore
insertFParams fparams symtable = foldl' (flip insertFParam) symtable fparams

-- | Insert a lambda parameter; it can never be indexed symbolically.
insertLParam :: ASTLore lore => LParam lore -> SymbolTable lore -> SymbolTable lore
insertLParam param = insertEntry name bind
  where
    bind =
      LParam
        LParamEntry
          { lparamDec = AST.paramDec param
          , lparamIndex = const Nothing
          }
    name = AST.paramName param

-- | Insert a loop variable of the given integer type and loop bound.
insertLoopVar :: ASTLore lore => VName -> IntType -> SubExp -> SymbolTable lore -> SymbolTable lore
insertLoopVar name it bound = insertEntry name bind
  where
    bind =
      LoopVar
        LoopVarEntry
          { loopVarType = it
          , loopVarBound = bound
          }

-- | Insert a free variable; it can never be indexed symbolically.
insertFreeVar :: ASTLore lore => VName -> NameInfo lore -> SymbolTable lore -> SymbolTable lore
insertFreeVar name dec = insertEntry name entry
  where
    entry =
      FreeVar
        FreeVarEntry
          { freeVarDec = dec
          , freeVarIndex = \_ _ -> Nothing
          }

-- | Mark a name, and its recorded aliases (see 'expandAliases'), as
-- consumed.  Names not in the table are ignored.
consume :: VName -> SymbolTable lore -> SymbolTable lore
consume consumee vtable =
  foldl' consume' vtable $ namesToList $ expandAliases (oneName consumee) vtable
  where
    consume' vtable' v =
      vtable' {bindings = M.adjust consume'' v $ bindings vtable'}
    consume'' e = e {entryConsumed = True}

-- | Hide definitions of those entries that satisfy some predicate.
-- A hidden entry becomes a free variable with the same 'NameInfo'
-- (losing its statement and indexing ability); its consumption, depth
-- and size flags are kept.
hideIf :: (Entry lore -> Bool) -> SymbolTable lore -> SymbolTable lore
hideIf hide vtable = vtable {bindings = M.map maybeHide $ bindings vtable}
  where
    maybeHide entry
      | hide entry =
          entry
            { entryType =
                FreeVar
                  FreeVarEntry
                    { freeVarDec = entryInfo entry
                    , freeVarIndex = \_ _ -> Nothing
                    }
            }
      | otherwise = entry

-- | Hide these definitions, if they are protected by certificates in
-- the set of names.
hideCertified :: Names -> SymbolTable lore -> SymbolTable lore
hideCertified to_hide = hideIf $ maybe False hide . entryStm
  where
    hide = any (`nameIn` to_hide) . unCertificates . stmCerts