-- | Environment used when applying rewrite rules.
--
--   The environment records the witnesses, the regions bound by 'LPrivate'
--   lets, and the value definitions that are in scope. Each of these is kept
--   as a stack of binding levels, innermost level first.
module DDC.Core.Transform.Rewrite.Env
        ( RewriteEnv
        , empty
        , extend
        , extendLets
        , containsRegion
        , containsWitness
        , getWitnesses
        , insertDef
        , getDef
        , hasDef
        , lift
        , liftValue)
where
import DDC.Core.Exp
import qualified DDC.Type.Exp as T
import qualified DDC.Type.Compounds as T
import qualified DDC.Type.Predicates as T
import qualified DDC.Type.Transform.LiftT as L
import qualified DDC.Core.Transform.LiftX as L
import Data.Maybe (fromMaybe, listToMaybe, isJust)
-- | Rewrite environment.
--
--   Every field is a stack of binding levels with the innermost level at the
--   head of the outer list. Entries are stored relative to the scope of the
--   level that holds them. The query functions in this module shift either
--   the entries or the thing being looked up by the level depth before
--   comparing.
data RewriteEnv a n
        = RewriteEnv
        { witnesses     :: [[T.Type n]]
          -- ^ Witness types in scope, one list per binding level.
        , letregions    :: [[Bind n]]
          -- ^ Region binders introduced by 'LPrivate' lets, one list per
          --   binding level.
        , defs          :: [[RewriteDef a n]]
          -- ^ Value binders in scope, each with its definition if known,
          --   one list per binding level.
        }
        deriving (Show,Eq)
-- | A recorded value binder together with its bound expression.
--   The expression is 'Nothing' when no definition is known, e.g. for
--   binders added by 'extend' or bound by a recursive let.
type RewriteDef a n
        = (Bind n, Maybe (Exp a n))
-- | An environment with nothing in scope.
--   It has no witness levels, no region levels and no definition levels.
empty :: Ord n => RewriteEnv a n
empty
 = RewriteEnv
 { witnesses    = []
 , letregions   = []
 , defs         = [] }
-- | Bring a single binder into scope.
--
--   If the binder's type is a witness type, that type is pushed onto the
--   innermost witness level; a level is created if there are none.
--   Otherwise the binder is recorded as a definition whose value is unknown.
--   In both cases an anonymous binder also opens a new definition level
--   via 'liftValue'.
extend :: Ord n => Bind n -> RewriteEnv a n -> RewriteEnv a n
extend b env
 | T.isWitnessType ty
 = liftValue b (env { witnesses = pushInner (witnesses env) })
 | otherwise
 = insertDef b Nothing (liftValue b env)
 where  ty                      = T.typeOfBind b
        pushInner []            = [[ty]]
        pushInner (lvl : lvls)  = (ty : lvl) : lvls
-- | Extend the environment with the binders introduced by a let-form.
--
--   * 'LPrivate': each region binder is recorded first. An anonymous one
--     opens a fresh witness level and a fresh region level; a named one is
--     added to the innermost region level. Each witness binder is then
--     added with 'extend'.
--
--   * 'LLet': the binder is recorded with its bound expression. An
--     anonymous binder opens a new definition level, so the expression is
--     lifted by one to stay valid inside it.
--
--   * 'LRec': each binder is recorded with no known definition.
--
--   * Any other let-form leaves the environment unchanged.
extendLets :: Ord n => Lets a n -> RewriteEnv a n -> RewriteEnv a n
extendLets lts env0
 = case lts of
        LPrivate bs _mt cs
         -> foldl (flip extend) (foldl addRegion env0 bs) cs

        LLet b x
         -> let x' = case b of
                        BAnon{} -> L.liftX 1 x
                        _       -> x
            in  insertDef b (Just x') (liftValue b env0)

        LRec bxs
         -> foldl (\env b -> insertDef b Nothing (liftValue b env)) env0
                  (map fst bxs)

        _ -> env0
 where
        -- Record one region binder from an 'LPrivate' let.
        addRegion env b
         = case b of
                BAnon{} -> env { witnesses  = []  : witnesses  env
                               , letregions = [b] : letregions env }
                BName{} -> env { letregions = pushInner b (letregions env) }
                BNone{} -> env

        -- Add an element to the innermost level, creating it if needed.
        pushInner b []          = [[b]]
        pushInner b (r : rs)    = (b : r) : rs
-- | Check whether the given witness type is available in the environment.
--
--   The type is interpreted relative to the innermost scope. Witness levels
--   are searched from the innermost outwards. Each time the search steps out
--   one level, the type's free de Bruijn indices are lowered by one, so the
--   type is compared with that level's witnesses in their own scope. This is
--   the inverse of 'getWitnesses', which lifts level-@k@ witnesses by @k@.
containsWitness :: Ord n => Type n -> RewriteEnv a n -> Bool
containsWitness c env
 = go c (witnesses env)
 where  go _  []        = False
        go c' (w : ws)  = c' `elem` w || go (L.liftT (-1) c') ws
-- | All witnesses in scope, expressed relative to the innermost scope.
--
--   The free type indices of witnesses stored at depth @k@ (0 = innermost)
--   are lifted by @k@. Witnesses from the innermost level come first.
getWitnesses :: Ord n => RewriteEnv a n -> [Type n]
getWitnesses env
 = [ L.liftT depth t
   | (depth, level) <- zip [0 ..] (witnesses env)
   , t              <- level ]
-- | Check whether a region bound refers to a region binder recorded by an
--   'LPrivate' let.
--
--   Region levels are stored innermost first, and an anonymous region binder
--   opens its own level. A de Bruijn index is therefore resolved by stepping
--   out one level per decrement. A named bound may match at any level.
--   Primitive bounds never match.
containsRegion :: Ord n => Bound n -> RewriteEnv a n -> Bool
containsRegion r env
 = go r (letregions env)
 where
        go _ []
         = False

        -- Index 0 refers to the innermost remaining level.
        go (UIx 0) (w : _)
         = any (T.boundMatchesBind (UIx 0)) w

        -- Stepping out one level lowers the index by one.
        go (UIx n) (_ : ws)
         = go (UIx (n - 1)) ws

        -- Names are not shifted by levels; search every level.
        go u@(UName _) (w : ws)
         = any (T.boundMatchesBind u) w || go u ws

        go (UPrim _ _) _
         = False
-- | Record a binder, with its definition if known, in the innermost
--   definition level. A level is created if there are none. This function
--   never opens a new level itself; 'liftValue' does that.
insertDef :: Bind n -> Maybe (Exp a n) -> RewriteEnv a n -> RewriteEnv a n
insertDef b def env
 = env { defs = pushEntry (defs env) }
 where  entry                   = (b, def)
        pushEntry []            = [[entry]]
        pushEntry (lvl : lvls)  = (entry : lvl) : lvls
-- | Check whether a bound refers to a binder recorded in the definition
--   levels. The result is 'True' whether or not the binder's definition is
--   known.
hasDef  :: (Ord n, L.MapBoundX (Exp a) n)
        => Bound n -> RewriteEnv a n -> Bool
hasDef b
 = isJust . getDef' b
-- | Look up the definition of a bound, lifted into the innermost scope.
--
--   Yields 'Nothing' both when the bound is not recorded and when it is
--   recorded without a known definition.
getDef  :: (Ord n, L.MapBoundX (Exp a) n)
        => Bound n
        -> RewriteEnv a n
        -> Maybe (Exp a n)
getDef b env
 = case getDef' b env of
        Just found      -> found
        Nothing         -> Nothing
-- | Look up the binder for a bound, searching definition levels from the
--   innermost outwards.
--
--   There are three possible results:
--
--   * 'Nothing' when no recorded binder matches.
--   * @Just Nothing@ when the binder is recorded without a definition.
--   * @Just (Just x)@ otherwise, where @x@ has been lifted into the
--     innermost scope.
getDef' :: (Ord n, L.MapBoundX (Exp a) n)
        => Bound n
        -> RewriteEnv a n
        -> Maybe (Maybe (Exp a n))
getDef' b env
 = go b 0 (defs env)
 where
        go _  _ []
         = Nothing

        -- When stepping out one level, the bound is lowered by one so it is
        -- interpreted relative to that outer level's scope.
        go b' i (w : ws)
         = match b' i w `orM` go (L.liftX (-1) b') (i + 1) ws

        -- The first matching entry of a level wins. Entries are prepended
        -- by 'insertDef', so this is the most recent one. Its definition is
        -- lifted by the number of levels stepped out, so that it is valid in
        -- the innermost scope.
        match b' i ds
         = fmap (fmap $ L.liftX i)
         $ listToMaybe
         $ map snd
         $ filter (T.boundMatchesBind b' . fst) ds

        orM (Just x) _  = Just x
        orM Nothing  y  = y
-- | Open a new binding level for an anonymous binder.
--
--   An anonymous binder pushes an empty witness level and an empty region
--   level. Definition levels are left untouched. Named and empty binders
--   leave the environment unchanged.
lift :: Bind n -> RewriteEnv a n -> RewriteEnv a n
lift BAnon{} env
 = env  { witnesses  = [] : witnesses  env
        , letregions = [] : letregions env }
lift _ env
 = env
-- | Open a new definition level for an anonymous binder.
--   Named and empty binders leave the environment unchanged.
liftValue :: Bind n -> RewriteEnv a n -> RewriteEnv a n
liftValue BAnon{} env
 = env { defs = [] : defs env }
liftValue _ env
 = env