module DDC.Core.Transform.Rewrite.Env
        ( RewriteEnv
        , empty
        , extend
        , extendLets
        , containsRegion
        , containsWitness
        , getWitnesses
        , insertDef
        , getDef
        , hasDef
        , lift
        , liftValue)
where
import DDC.Core.Exp
import qualified DDC.Type.Exp.Simple            as T
import qualified DDC.Type.Transform.BoundT      as L
import qualified DDC.Core.Transform.BoundX      as L
import Data.Maybe (fromMaybe, listToMaybe, isJust)


-- | A summary of the environment that we perform a rewrite in.
--
--   As we descend into the program looking for expressions to rewrite,
--   we keep track of what has been defined so far in a `RewriteEnv`.
--
--   Each field is a stack of scope levels, with the innermost level at the
--   head of the outer list. When we go under an anonymous binder we push a
--   fresh innermost level, instead of eagerly lifting the de Bruijn indices
--   of everything already in the environment. Lookups compensate for this
--   by adjusting indices by the number of levels they cross.
--
data RewriteEnv a n
        = RewriteEnv
        { -- | Types of all witnesses in scope.
          --   We use these to satisfy constraints on rewrite rules like @Const r@.
          witnesses     :: [[T.Type n]]

          -- | Names of letregion-bound regions:
          --   this is interesting because they must be distinct.
        , letregions    :: [[Bind n]]

          -- | Assoc of known values.
          --   If we are going to inline them, they must only reference
          --   de Bruijn binders. These are value-level bindings, so be
          --   careful when lifting: a new level is pushed by anonymous
          --   /value/ binders only (see 'liftValue').
          --   NOTE(review): 'lift' does not touch this stack, so any type
          --   indices inside stored definitions are not adjusted under
          --   anonymous type binders; confirm this is intended.
        , defs          :: [[RewriteDef a n]] }
        deriving (Show, Eq)


-- | A binder, paired with its definition when one is known.
type RewriteDef a n
        = (Bind n, Maybe (Exp a n))


-- | An empty environment, with no levels in any of its stacks.
empty :: RewriteEnv a n
empty   = RewriteEnv
        { witnesses     = []
        , letregions    = []
        , defs          = [] }


-- | Extend an environment with some lambda-bound binder (XLam).
--
--   If the binder has a witness type, that type is added to the innermost
--   witness level. Otherwise the binder is recorded as a value binder with
--   no known definition. In both cases an anonymous binder first opens a
--   new definitions level (via 'liftValue').
extend :: Ord n => Bind n -> RewriteEnv a n -> RewriteEnv a n
extend b env
 | T.isWitnessType ty
 = liftValue b $ env { witnesses = pushInner ty (witnesses env) }

 | otherwise
 = insertDef b Nothing (liftValue b env)

 where  ty = T.typeOfBind b

        -- Add an element to the innermost level, creating one if needed.
        pushInner x []               = [[x]]
        pushInner x (level : levels) = (x : level) : levels


-- | Extend an environment with the variables bound by these let-bindings.
--
--   If it's a letregion, remember the region's name and any witnesses.
--
--   Non-recursive let-bindings record their right-hand side as the known
--   definition. Recursive bindings are recorded as bound but without a
--   definition.
extendLets :: Ord n => Lets a n -> RewriteEnv a n -> RewriteEnv a n
extendLets (LPrivate bs _mt cs) renv
 = foldl (flip extend) withRegions cs
 where  -- Regions are recorded before the witness binders, so the witnesses
        -- land in the innermost level after any anonymous region opened it.
        withRegions = foldl bindRegion renv bs

        bindRegion env b
         = case b of
                -- An anonymous region opens a new witness level and a new
                -- letregion level that holds just this binder.
                BAnon{} -> env { witnesses  = [] : witnesses env
                               , letregions = [b] : letregions env }

                -- A named region joins the innermost letregion level.
                BName{} -> env { letregions = pushInner b (letregions env) }

                BNone{} -> env

        -- Add an element to the innermost level, creating one if needed.
        pushInner x []               = [[x]]
        pushInner x (level : levels) = (x : level) : levels

extendLets (LLet b def) env
 = insertDef b (Just def') (liftValue b env)
 where  -- 'def' was written outside the new binder. When the binder is
        -- anonymous, 'liftValue' opens a new level, so shift the free
        -- indices of 'def' by one to be valid inside that level.
        def'
         | BAnon{} <- b = L.liftX 1 def
         | otherwise    = def

extendLets (LRec bxs) env
 = foldl bindRec env [b | (b, _) <- bxs]
 where  -- Each recursive binder is recorded as bound, without a definition.
        bindRec acc b = insertDef b Nothing (liftValue b acc)


-- Witnesses ------------------------------------------------------------------
-- | Check whether the given witness type is available in the environment.
--
--   The query is compared against the innermost witness level first. It is
--   then lowered by 1 before each further level, so that its indices line
--   up with the context that level was recorded in. If nothing matches,
--   the query may end up with negative indices, which will definitely not
--   match anything else.
containsWitness :: Ord n => Type n -> RewriteEnv a n -> Bool
containsWitness c env
 = or $ zipWith elem
        (iterate (L.liftT (-1)) c)
        (witnesses env)


-- | Get a list of all the witness types in an environment, innermost level
--   first. Each type is lifted by the depth of its level, so that all
--   indices are relative to the current (innermost) context.
getWitnesses :: Ord n => RewriteEnv a n -> [Type n]
getWitnesses env
 = [ L.liftT depth t
   | (depth, level) <- zip [0 ..] (witnesses env)
   , t              <- level ]


-- Regions --------------------------------------------------------------------
-- | Check whether an environment contains the given region,
--   bound by a letregion.
--
--   An anonymous region @^n@ is looked for only at the level @n@ steps out,
--   where it must be that level's anonymous binder. A named region may
--   appear at any level. Primitive bounds are never letregion-bound.
containsRegion :: Ord n => Bound n -> RewriteEnv a n -> Bool
containsRegion r env
 = case r of
        UIx ix
         | ix < 0    -> False
         | otherwise -> case drop ix (letregions env) of
                          level : _ -> any (T.boundMatchesBind (UIx 0)) level
                          []        -> False

        UName _      -> any (any (T.boundMatchesBind r)) (letregions env)

        UPrim _ _    -> False


-- Defs -----------------------------------------------------------------------
-- | Insert a rewrite definition into the innermost definitions level,
--   creating that level if there is none yet.
insertDef :: Bind n -> Maybe (Exp a n) -> RewriteEnv a n -> RewriteEnv a n
insertDef b def env
 = env { defs = pushInner (defs env) }
 where  entry = (b, def)

        pushInner []               = [[entry]]
        pushInner (level : levels) = (entry : level) : levels


-- | Check whether a value-level variable is bound in the environment,
--   whether or not its definition is known.
hasDef :: (Ord n, L.MapBoundX (Exp a) n) => Bound n -> RewriteEnv a n -> Bool
hasDef b env
 = maybe False (const True) (getDef' b env)


-- | Lookup the definition of some let-bound variable from the environment.
--
--   Returns 'Nothing' both when the variable is not bound and when it is
--   bound without a known definition.
getDef :: (Ord n, L.MapBoundX (Exp a) n) => Bound n -> RewriteEnv a n -> Maybe (Exp a n)
getDef b env
 = getDef' b env >>= id


-- | Find the first binder matching the given variable, searching from the
--   innermost level outwards.
--
--   The outer 'Maybe' says whether a matching binder was found. The inner
--   one is its definition, if known. The query is lowered by 1 for each
--   level we step out. A found definition is lifted by the number of levels
--   crossed, so that its indices are valid in the current context.
getDef' :: (Ord n, L.MapBoundX (Exp a) n) => Bound n -> RewriteEnv a n -> Maybe (Maybe (Exp a n))
getDef' b env
 = listToMaybe
        [ fmap (L.liftX depth) def
        | (depth, u, level) <- zip3 [0 ..] (iterate (L.liftX (-1)) b) (defs env)
        , (bind, def)       <- level
        , T.boundMatchesBind u bind ]


-- | Raise all elements in the witness and letregion maps if the binder is
--   anonymous, by opening a fresh innermost level in each.
--   Only call with type binders: ie XLAM, not XLam.
lift :: Bind n -> RewriteEnv a n -> RewriteEnv a n
lift b env
 = case b of
        BAnon{} -> env { witnesses  = [] : witnesses env
                       , letregions = [] : letregions env }
        _       -> env


-- | Raise all elements in the definitions map if the binder is anonymous,
--   by opening a fresh innermost definitions level.
--   Use for *value* binders, not type binders.
liftValue :: Bind n -> RewriteEnv a n -> RewriteEnv a n
liftValue b env
 = case b of
        BAnon{} -> env { defs = [] : defs env }
        _       -> env