module E.Subst(
doSubst,
doSubst',
eAp,
litSMapM,
subst,
subst',
substMap,
substMap',
substMap'',
typeSubst,
typeSubst'
) where
import Control.Monad.Reader
import qualified Data.Set as Set
import qualified Data.Traversable as T
import E.E
import E.FreeVars()
import GenUtil
import Name.Id
import Name.Names
import Support.FreeVars
import Util.HasSize
import Util.SetLike as S
-- | Replace one variable by an expression inside a term.
--
-- A binder whose id is 'emptyId' names nothing, so the term is returned
-- unchanged in that case.  Otherwise the work is handed to 'doSubst'' with
-- both flags off: the types stored on variable occurrences are left alone and
-- binders are renamed only when their id is unusable or already taken.  An id
-- counts as taken when it is free in the replacement or in the term.
subst
    :: TVr  -- ^ variable to replace
    -> E    -- ^ expression to put in its place
    -> E    -- ^ term to rewrite
    -> E    -- ^ rewritten term
subst TVr { tvrIdent = ident } replacement body
    | ident == emptyId = body
    | otherwise = doSubst' False False (msingleton ident replacement) taken body
  where
    taken n = n `member` (freeVars replacement `union` freeVars body :: IdSet)
-- | Like 'subst', but calls 'doSubst'' with its first flag set, so the type
-- carried by a variable occurrence that is not itself replaced is rewritten
-- as well.  A binder whose id is 'emptyId' leaves the term unchanged.
subst' :: TVr -> E -> E -> E
subst' TVr { tvrIdent = ident } replacement body
    | ident == emptyId = body
    | otherwise = doSubst' True False (msingleton ident replacement) taken body
  where
    -- ids free in either the replacement or the term are considered taken
    taken n = n `member` (freeVars replacement `union` freeVars body :: IdSet)
-- | Monadically map a function over the expressions held by a literal.
--
-- For a constructor literal the type is transformed first and the arguments
-- afterwards (the order matters for monads with observable effects); the name
-- and alias are carried over unchanged.  For an integer literal only the type
-- is transformed.
litSMapM f lit = case lit of
    LitCons { litName = name, litArgs = args, litType = ty, litAliasFor = alias } -> do
        ty' <- f ty
        args' <- mapM f args
        return $ LitCons name args' ty' alias
    LitInt num ty -> liftM (LitInt num) (f ty)
-- | Apply a whole map of substitutions to a term at once.
--
-- Ids that are free in the term or in any expression of the map are reported
-- to 'doSubst'' as taken, so binders are not renamed onto them.
substMap :: IdMap E -> E -> E
substMap im e = doSubst' False False im taken e
  where
    taken n = n `member` occupied
    occupied = unions $ (freeVars e :: IdSet) : map freeVars (values im)
-- | Variant of 'substMap' that does not collect free variables: the only ids
-- reported to 'doSubst'' as taken are the keys of the substitution map itself.
substMap' :: IdMap E -> E -> E
substMap' im = doSubst' False False im isKey
  where
    isKey n = n `member` im
-- | Variant of 'substMap'' whose map may contain 'Nothing' entries.
--
-- Only the 'Just' entries are passed on as actual substitutions, but every
-- key of the map, including the 'Nothing' ones, counts as a taken id.
substMap'' :: IdMap (Maybe E) -> E -> E
substMap'' im = doSubst' False False replacements isKey
  where
    replacements = mapMaybeIdMap id im
    isKey n = n `member` im
-- | Front end to 'doSubst'' taking a map with optional entries.
--
-- The two flags are forwarded untouched.  Entries holding 'Just' become the
-- substitution proper; all keys of the map (whatever they hold) are treated
-- as taken ids.
doSubst :: Bool -> Bool -> IdMap (Maybe E) -> E -> E
doSubst substInVars allShadow bm = doSubst' substInVars allShadow replacements isKey
  where
    replacements = mapMaybeIdMap id bm
    isKey n = n `member` bm
-- | Worker shared by the substitution functions of this module.
--
-- Arguments, in order:
--
-- * @substInVars@: when set, a variable occurrence that is not replaced still
--   has its stored type rewritten.
--
-- * @allShadow@: passed to 'mnv' when choosing binder ids; additionally a
--   lambda or pi binder that does not occur free in its body is replaced by
--   'emptyId'.
--
-- * @bm@: the substitution to start with.
--
-- * @check@: predicate telling 'mnv' which ids must be considered taken.
--
-- * @e@: the term to rewrite.
--
-- The traversal is written in the reader monad of plain functions.  The
-- environment is a pair: the set of binder ids seen so far (old and new) and
-- the substitution currently in force.
doSubst' :: Bool -> Bool -> IdMap E -> (Id -> Bool) -> E -> E
doSubst' substInVars allShadow bm check e = go e (Set.empty, bm)
  where
    go :: E -> (Set.Set Id, IdMap E) -> E
    go var@(EVar tv@TVr { tvrIdent = i, tvrType = ty }) = do
        (_, sub) <- ask
        case mlookup i sub of
            Just replacement -> return replacement
            Nothing
                | substInVars -> liftM (\ty' -> EVar (tv { tvrType = ty' })) (go ty)
                | otherwise -> return var
    go (ELam tv body) = binder ELam tv body
    go (EPi tv body) = binder EPi tv body
    -- applications are rebuilt with 'eAp' rather than the raw constructor
    go (EAp fun arg) = liftM2 eAp (go fun) (go arg)
    go (EError msg ty) = liftM (EError msg) (go ty)
    go (EPrim prim args ty) = liftM2 (EPrim prim) (mapM go args) (go ty)
    go ELetRec { eDefs = defs, eBody = body } = do
        (binders, renames) <- renameGroup (fsts defs)
        -- every right-hand side and the body see all binders of the group
        local (foldr (.) id renames) $ do
            rhss <- mapM go (snds defs)
            body' <- go body
            return $ ELetRec (zip binders rhss) body'
    go (ELit lit) = liftM ELit $ litSMapM go lit
    go Unknown = return Unknown
    go srt@(ESort {}) = return srt
    go ec@(ECase {}) = do
        scrut' <- go $ eCaseScrutinee ec
        (bind', withBind) <- renameBinder Set.empty $ eCaseBind ec
        dflt' <- local withBind $ T.mapM go $ eCaseDefault ec
        let alt (Alt lc@LitCons { litArgs = vs, litType = ty } rhs) = do
                ty' <- go ty
                (vs', renames) <- renameGroup vs
                rhs' <- local (foldr (.) id renames) $ go rhs
                return $ Alt lc { litArgs = vs', litType = ty' } rhs'
            alt (Alt lit rhs) = do
                lit' <- T.mapM go lit
                rhs' <- go rhs
                return $ Alt lit' rhs'
        -- the case binder is in scope in every alternative
        alts' <- local withBind (mapM alt $ eCaseAlts ec)
        ty' <- go (eCaseType ec)
        return $ caseUpdate ec
            { eCaseScrutinee = scrut'
            , eCaseDefault = dflt'
            , eCaseBind = bind'
            , eCaseAlts = alts'
            , eCaseType = ty'
            }

    -- Rebuild a lambda or pi; @mk@ is the constructor to use.
    binder mk tv@TVr { tvrIdent = n, tvrType = ty } body
        | n == emptyId || (allShadow && n `notElem` freeVars body) = do
            ty' <- go ty
            -- the binder is dropped: record its id as seen and stop
            -- substituting for it underneath
            body' <- local (\(seen, sub) -> (Set.insert n seen, delete n sub)) $ go body
            return $ mk (tv { tvrIdent = emptyId, tvrType = ty' }) body'
        | otherwise = do
            (tv', withBinder) <- renameBinder Set.empty tv
            body' <- local withBinder $ go body
            return $ mk tv' body'

    -- Rename a group of binders left to right.  Each binder is processed in
    -- the environment extended by the ones before it, and the ids of the
    -- whole group are handed to 'renameBinder'.  Returns the new binders and
    -- the matching environment updates, both in the original order.
    renameGroup tvs = walk tvs []
      where
        groupIds = Set.fromList [tvrIdent tv | tv <- tvs]
        walk [] acc = return $ unzip $ reverse acc
        walk (tv : rest) acc = do
            (tv', withTv) <- renameBinder groupIds tv
            local withTv $ walk rest ((tv', withTv) : acc)

    -- Process a single binder: rewrite its type, pick its id through 'mnv'
    -- and return the new binder together with the environment update that
    -- brings it into scope.  Binders with 'emptyId' only get their type
    -- rewritten and leave the environment alone.
    renameBinder groupIds tv@TVr { tvrIdent = i, tvrType = ty }
        | i == emptyId = do
            ty' <- go ty
            return (tv { tvrType = ty' }, id)
        | otherwise = do
            ty' <- go ty
            (seen, sub) <- ask
            let i' = mnv allShadow groupIds i check seen sub
                tv' = tv { tvrIdent = i', tvrType = ty' }
                -- remember both ids, map the old id to the new variable and
                -- make sure the new id itself is not substituted
                extend (s, m) =
                    ( Set.insert i' . Set.insert i $ s
                    , minsert i (EVar tv') . delete i' $ m
                    )
            return (tv', extend)
-- | Choose the id a binder should carry after substitution.
--
-- * With @allShadow@ set, a new id is always requested from 'newId', accepting
--   only ids that are not in scope.
--
-- * Otherwise, if the current id is invalid or already in scope, a new id is
--   requested that is neither in scope nor a member of @xs@.
--
-- * Otherwise the current id is kept.
--
-- An id is "in scope" when it is a key of @ss@, a member of @s@, or
-- @checkTaken@ holds for it.  The seed given to 'newId' mixes the sizes of
-- @xs@, @s@ and @ss@.
-- NOTE(review): this assumes 'newId' returns an id satisfying the predicate
-- it is given; confirm against its definition.
mnv :: Bool -> Set.Set Id -> Id -> (Id -> Bool) -> Set.Set Id -> IdMap a -> Id
mnv allShadow xs i checkTaken s ss
    | allShadow = fresh (not . inScope)
    | isInvalidId i || inScope i = fresh (not . reserved)
    | otherwise = i
  where
    fresh acceptable = newId (Set.size xs `mixInt` Set.size s `mixInt` size ss) acceptable
    inScope n = n `member` ss || n `member` s || checkTaken n
    reserved n = inScope n || n `member` xs
-- | Apply one expression to another, simplifying where the head allows it.
--
-- * A pi or lambda is instantiated: the argument is substituted for the
--   binder, or the body is returned as is when the binder is 'emptyId'.
--
-- * The arrow constructor holding exactly one argument and a pi type becomes
--   an 'EPi' whose anonymous binder has that argument as its type and whose
--   body is the new argument.
--
-- * Any other constructor literal with a pi type absorbs the argument into
--   its argument list, and its type is instantiated accordingly.
--
-- * A constructor literal that is an alias is replaced by the aliased
--   expression applied to all collected arguments plus the new one.
--
-- * An application of 'EError' pushes the application into its type.
--
-- * Everything else becomes a plain 'EAp'.
eAp :: E -> E -> E
eAp (EPi t b) e
    | tvrIdent t == emptyId = b
    | otherwise = subst t e b
eAp (ELam t b) e
    | tvrIdent t == emptyId = b
    | otherwise = subst t e b
eAp (ELit LitCons { litName = con, litArgs = [dom], litType = EPi _ _ }) cod
    | con == tc_Arrow = EPi tvr { tvrType = dom } cod
eAp (ELit lc@LitCons { litArgs = args, litType = EPi t res }) arg =
    ELit lc { litArgs = args ++ [arg], litType = subst t arg res }
eAp (ELit LitCons { litArgs = args, litAliasFor = Just alias }) arg =
    foldl eAp alias (args ++ [arg])
eAp (EError msg ty) arg = EError msg (eAp ty arg)
eAp fun arg = EAp fun arg
-- | Convenience wrapper around 'typeSubst' for a plain term-level map.
--
-- When both maps are empty the term is returned unchanged.  Otherwise the
-- term-level environment handed to 'typeSubst' consists of the given term
-- substitution plus one entry for every variable free in the term or in
-- either map; such an entry holds the variable's replacement if the term
-- substitution has one and 'Nothing' otherwise.
typeSubst' :: IdMap E -> IdMap E -> E -> E
typeSubst' termSub typeSub e
    | isEmpty termSub && isEmpty typeSub = e
    | otherwise = typeSubst termEnv typeSub e
  where
    termEnv = fmap Just termSub `union` fmap replacementFor fvs
    replacementFor tv = tvrIdent tv `mlookup` termSub
    fvs :: IdMap TVr
    fvs = freeVars e `union` fvmap termSub `union` fvmap typeSub
    -- free variables of all expressions stored in a map
    fvmap m = unions (map freeVars (values m))
-- | Substitution that can replace a variable by different expressions
-- depending on whether the occurrence is at term level or inside a type.
--
-- The traversal runs in the reader monad of plain functions over a triple:
-- a flag that is 'True' while inside a type, the term-level map and the
-- type-level map.  A variable occurrence is replaced by its 'Just' entry in
-- the term map when the flag is 'False', by its entry in the type map when
-- the flag is 'True', and is left unchanged otherwise.
--
-- Binder types, the type components of 'EError', 'EPrim' and literals, and
-- the result type of a case are traversed with the flag set.
--
-- If both maps are empty the term is returned unchanged.
typeSubst
    :: IdMap (Maybe E)  -- ^ term-level substitution; a 'Nothing' entry
                        --   replaces nothing but marks the id as taken
    -> IdMap E          -- ^ type-level substitution
    -> (E -> E)         -- ^ the resulting substitution function
typeSubst termSubst tySubst e | isEmpty termSubst && isEmpty tySubst = e
typeSubst termSubst tySubst e = f e (False, termSubst', tySubst)
  where
    -- ids substituted at type level are reserved at term level as well
    termSubst' = termSubst `union` fmap (const Nothing) tySubst

    f :: E -> (Bool, IdMap (Maybe E), IdMap E) -> E
    f eo@(EVar TVr { tvrIdent = i }) = do
        (wh, trm, tp) <- ask
        case (wh, mlookup i trm, mlookup i tp) of
            (False, Just (Just v), _) -> return v
            (True, _, Just v) -> return v
            _ -> return eo
    f (ELam tvr e) = lp ELam tvr e
    f (EPi tvr e) = lp EPi tvr e
    f (EAp a b) = liftM2 eAp (f a) (f b)
    f (EError x e) = liftM (EError x) (inType $ f e)
    f (EPrim x es e) = liftM2 (EPrim x) (mapM f es) (inType $ f e)
    f ELetRec { eDefs = dl, eBody = e } = do
        (as, rs) <- liftM unzip $ mapMntvr (fsts dl)
        local (foldr (.) id rs) $ do
            ds <- mapM f (snds dl)
            e' <- f e
            return $ ELetRec (zip as ds) e'
    f (ELit l) = liftM ELit $ substLit l
    f Unknown = return Unknown
    f e@(ESort {}) = return e
    f ec@(ECase {}) = do
        e' <- f $ eCaseScrutinee ec
        (b', r) <- ntvr Set.empty $ eCaseBind ec
        d <- local r $ T.mapM f $ eCaseDefault ec
        let da (Alt lc@LitCons { litArgs = vs, litType = t } e) = do
                t' <- inType $ f t
                (as, rs) <- liftM unzip $ mapMntvr vs
                e' <- local (foldr (.) id rs) $ f e
                return $ Alt lc { litArgs = as, litType = t' } e'
            da (Alt (LitInt n t) e) = do
                t' <- inType (f t)
                e' <- f e
                return $ Alt (LitInt n t') e'
        -- The case binder scopes over the alternatives just as it does over
        -- the default branch, so its renaming has to be in force here too;
        -- otherwise occurrences of a renamed binder inside an alternative
        -- would keep the old id.
        alts <- local r (mapM da $ eCaseAlts ec)
        nty <- inType (f $ eCaseType ec)
        return $ caseUpdate ec { eCaseScrutinee = e', eCaseDefault = d, eCaseBind = b', eCaseAlts = alts, eCaseType = nty }

    -- Rebuild a lambda or pi; @lam@ is the constructor to use.  A binder
    -- with 'emptyId' binds nothing, so the environment is left alone.
    lp lam tvr@(TVr { tvrIdent = eid, tvrType = t }) e | eid == emptyId = do
        t' <- inType (f t)
        e' <- f e
        return $ lam (tvr { tvrIdent = emptyId, tvrType = t' }) e'
    lp lam tvr e = do
        (tv, r) <- ntvr Set.empty tvr
        e' <- local r $ f e
        return $ lam tv e'

    -- Rename a group of binders left to right, each one in the environment
    -- extended by those before it.  Returns (new binder, environment update)
    -- pairs in the original order.
    mapMntvr ts = go ts []
      where
        go [] xs = return $ reverse xs
        go (t : rest) rs = do
            (t', r) <- ntvr vs t
            local r $ go rest ((t', r) : rs)
        vs = Set.fromList [ tvrIdent x | x <- ts ]

    -- Run a traversal with the "inside a type" flag set.
    inType = local (\(_, trm, typ) -> (True, trm, typ))

    -- Record a binding: a 'Just' goes into both maps, a 'Nothing' only
    -- reserves the id in the term map.
    addMap i (Just x) (b, trm, typ) = (b, minsert i (Just x) trm, minsert i x typ)
    addMap i Nothing (b, trm, typ) = (b, minsert i Nothing trm, typ)

    -- Substitute inside a literal; its type is traversed as a type.
    substLit lc@LitCons { litArgs = es, litType = t } = do
        t' <- inType $ f t
        es' <- mapM f es
        return $ lc { litArgs = es', litType = t' }
    substLit (LitInt n t) = do
        t' <- inType $ f t
        return $ LitInt n t'

    -- Process one binder: rewrite its type, pick an id through 'mnv' using
    -- the keys of the term map as the ids in scope, and return the new
    -- binder with the environment update that brings it into scope.
    ntvr _ tvr@(TVr { tvrIdent = eid, tvrType = t }) | eid == emptyId = do
        t' <- inType (f t)
        let nvr = tvr { tvrType = t' }
        return (nvr, id)
    ntvr xs tvr@(TVr { tvrIdent = i, tvrType = t }) = do
        t' <- inType (f t)
        (_, trm, _) <- ask
        let i' = mnv False xs i (\_ -> False) Set.empty trm
        let nvr = tvr { tvrIdent = i', tvrType = t' }
        case i == i' of
            True -> return (nvr, addMap i (Just $ EVar nvr))
            -- when renamed, the new id is reserved so later binders avoid it
            False -> return (nvr, addMap i (Just $ EVar nvr) . addMap i' Nothing)