-- | Call-by-name evaluation of MProver expressions, a joinability check built
-- on top of it, and the pattern-matching machinery used to reduce @case@
-- expressions.
module MProver.Eval where

import MProver.Syntax
import MProver.Monad
import Control.Monad.Reader
import Control.Monad.Identity
import Control.Monad.Error (throwError)
import qualified Data.Map as Map
import qualified Data.Set as Set
import Data.Set hiding (map)
import Data.Maybe
import Unbound.LocallyNameless hiding (union,singleton,toList)

--import Debug.Trace
--trace _ = id

-- FIXME: trying to decide if some kind of let-generalization is in order here;
-- I think maybe not, because the let-bindings should be garbage collected if
-- possible.

-- | Try to show that two expressions are joinable.
--
-- Both sides are first evaluated with 'evalCBN'. Alpha-equivalent results
-- join immediately. Otherwise the results are compared structurally:
--
-- * two lambdas are compared by opening both bodies under a common binder;
-- * variables, constructors and literals must be equal;
-- * two applications join when their heads are the same constructor or the
--   same variable, they have the same number of arguments, and every pair of
--   arguments joins.
--
-- Anything else yields 'False'. 'False' therefore means "could not be shown
-- to join", not "provably distinct".
tryJoin :: (Monad m) => Expr -> Expr -> MPT m Bool
tryJoin e1 e2 = do
  e1' <- evalCBN e1
  e2' <- evalCBN e2
  if e1' `aeq` e2'
    then return True
    else case (e1', e2') of
      (Lambda b1, Lambda b2) -> do
        mr <- unbind2 b1 b2
        case mr of
          -- NOTE(review): the bodies are compared under 'unbindEnv x',
          -- presumably so that an outer definition of the same name is not
          -- unfolded; confirm against MProver.Monad.
          Just (x, e1'', _, e2'') -> localEE (unbindEnv x) (tryJoin e1'' e2'')
          Nothing -> return False
      (Var v1, Var v2) -> return (v1 == v2)
      (Ctor c1, Ctor c2) -> return (c1 == c2)
      (Literal l1, Literal l2) -> return (l1 == l2)
      (App _ _, App _ _) -> do
        let (f1, args1) = unnestApp e1'
            (f2, args2) = unnestApp e2'
        if sameHead f1 f2 && length args1 == length args2
          then allJoin (zip args1 args2)
          else return False
      _ -> return False
  where
    -- Application spines are only comparable when both heads are the same
    -- constructor or the same variable.
    sameHead (Ctor c) (Ctor c') = c == c'
    sameHead (Var x) (Var y) = x == y
    sameHead _ _ = False

    -- Join argument pairs left to right, stopping at the first failure so
    -- that the remaining (possibly divergent) arguments are not evaluated.
    allJoin [] = return True
    allJoin ((a1, a2) : rest) = do
      ok <- tryJoin a1 a2
      if ok then allJoin rest else return False

-- | Evaluate an expression call-by-name, as far as possible.
--
-- * A variable that has a definition in the evaluation environment
--   ('askEE') is replaced by that definition, and evaluation continues.
--   Any other variable is returned as-is.
-- * For an application, the function part is evaluated first.
--   - If it is a lambda, it is beta-reduced by substituting the
--     /unevaluated/ argument.
--   - If it is 'Bottom', the result is 'Bottom'.
--   - Otherwise the stuck application is returned with its argument left
--     unevaluated.
-- * A @case@ is reduced when 'doAlts' can commit to an alternative, and is
--   returned unchanged otherwise.
-- * A @let@ body is evaluated with the bindings added to the environment.
--   The @let@ is kept around the result only if at least one bound name is
--   still free in it.
-- * Every other expression is returned unchanged.
evalCBN :: (Monad m) => Expr -> MPT m Expr
evalCBN (Var x) = do
  ee <- askEE
  case Map.lookup x ee of
    Just (_, Just e) -> evalCBN e
    _ -> return (Var x)
evalCBN (App e1 e2) = do
  e1' <- evalCBN e1
  case e1' of
    Lambda b -> do
      (x, e) <- unbind b
      evalCBN (subst x e2 e)
    Bottom -> return Bottom
    _ -> return (App e1' e2)
evalCBN (Case e alts) = do
  r <- doAlts e alts
  case r of
    Just e' -> evalCBN e'
    Nothing -> return (Case e alts)
evalCBN (Let b) = do
  (r, e) <- unbind b
  let bs = unrec r
      -- Prelude.foldr is qualified explicitly: Data.Set (imported
      -- unqualified above) also exports a 'foldr' in newer containers.
      withBindings env =
        Prelude.foldr
          (\(x, e_) acc -> bindEnv x (Nothing, Just (unembed e_)) acc)
          env
          bs
  e' <- localEE withBindings (evalCBN e)
  let names = map fst bs
      -- 'fv' is polymorphic in the collection it returns; pinning it to the
      -- same list type as the binder names keeps 'elem' unambiguous.
      free = fv e' `asTypeOf` names
  if any (`elem` free) names
    then return (Let (bind r e'))
    else return e'
evalCBN e = return e

-- | Outcome of trying a single @case@ alternative.
data MR
  = Yes Expr -- ^ The alternative fires; this is its instantiated body
             --   ('Bottom' if the scrutinee diverged).
  | No       -- ^ The alternative definitely does not match.
  | Poss     -- ^ It cannot yet be decided whether the alternative matches.
  deriving Show

-- | Try the alternatives of a @case@ in order against the scrutinee.
--
-- * Returns the instantiated body of the first alternative that definitely
--   matches.
-- * Skips alternatives that definitely fail.
-- * Returns 'Nothing' (the @case@ is stuck) as soon as an alternative might
--   or might not match.
-- * If no alternative matches, the result is @'Just' 'Bottom'@.
doAlts :: (Monad m) => Expr -> [Alt] -> MPT m (Maybe Expr)
doAlts _ [] = return (Just Bottom)
doAlts e (a:as) = do
  r <- doAlt e a
  case r of
    Yes e' -> return (Just e')
    No -> doAlts e as
    Poss -> return Nothing

-- | Match a pattern against an expression, evaluating the expression (with
-- 'evalCBN') only when the pattern needs to inspect it.
--
-- * Variable and wildcard patterns match without evaluation.
-- * Constructor, literal and constructor-application patterns evaluate the
--   scrutinee first. A stuck result (a variable, @let@, @case@, or an
--   application whose head is not a constructor) gives 'Possible', and
--   'Bottom' gives 'Diverge'.
-- * A bottom pattern is an error here.
doPat :: (Monad m) => Expr -> Pat -> MPT m MatchResult
doPat e (PatVar x) = return (Match [(x, e)])
doPat e (PatCtor c) = do
  e' <- evalCBN e
  case e' of
    Lambda _ -> return NoMatch
    Var _ -> return Possible
    Ctor c' -> return (if c == c' then Match [] else NoMatch)
    Literal _ -> return NoMatch
    Let _ -> return Possible
    Case _ _ -> return Possible
    -- A constructor applied to arguments never matches a bare constructor
    -- pattern; any other head is still undetermined.
    App _ _ -> case fst (unnestApp e') of
      Ctor _ -> return NoMatch
      _ -> return Possible
    Bottom -> return Diverge
doPat e (PatLiteral l) = do
  e' <- evalCBN e
  case e' of
    Lambda _ -> return NoMatch
    Var _ -> return Possible
    Ctor _ -> return NoMatch
    Literal l' -> return (if l == l' then Match [] else NoMatch)
    Let _ -> return Possible
    Case _ _ -> return Possible
    App _ _ -> case fst (unnestApp e') of
      Ctor _ -> return NoMatch
      _ -> return Possible
    Bottom -> return Diverge
doPat _ PatWildcard = return (Match [])
doPat e (PatApp c ps) = do
  e' <- evalCBN e
  case e' of
    Lambda _ -> return NoMatch
    Var _ -> return Possible
    Ctor _ -> return NoMatch
    Literal _ -> return NoMatch
    Let _ -> return Possible
    Case _ _ -> return Possible
    App _ _ -> do
      let (f, args) = unnestApp e'
      case f of
        Ctor c'
          | c == c' && length ps == length args -> doPats args ps
          | otherwise -> return NoMatch
        _ -> return Possible
    Bottom -> return Diverge
doPat _ PatBottom = throwError "bottom pattern occurs in an expression"

-- | Match expressions against patterns pointwise, left to right.
--
-- * Substitutions are concatenated in order.
-- * The first non-'Match' result is returned immediately.
-- * Lists of different lengths give 'NoMatch', like 'zipPatMatch', rather
--   than failing with a pattern-match error.
doPats :: (Monad m) => [Expr] -> [Pat] -> MPT m MatchResult
doPats (e:es) (p:ps) = do
  r <- doPat e p
  case r of
    Match bs -> do
      r' <- doPats es ps
      return $ case r' of
        Match bs' -> Match (bs ++ bs')
        other -> other
    other -> return other
doPats [] [] = return (Match [])
doPats _ _ = return NoMatch

-- | Try one @case@ alternative against the scrutinee.
--
-- On a match, the pattern's bindings are substituted into the alternative's
-- body. A diverging scrutinee makes the alternative fire with 'Bottom'.
doAlt :: (Monad m) => Expr -> Alt -> MPT m MR
doAlt e alt = do
  (p, body) <- unbind alt
  r <- doPat e p
  return $ case r of
    Match bs -> Yes (substs bs body)
    NoMatch -> No
    Possible -> Poss
    Diverge -> Yes Bottom

-- | Result of matching a pattern against an expression.
data MatchResult
  = Match [(Name Expr,Expr)] -- ^ Definite match, with the variable bindings.
  | NoMatch                  -- ^ Definitely does not match.
  | Possible                 -- ^ Undecided: the expression is stuck.
  | Diverge                  -- ^ The expression is 'Bottom'.
  deriving Show

-- | Pure counterpart of 'doAlts'.
--
-- It uses 'patMatch' and does no evaluation of the scrutinee. Results mirror
-- 'doAlts':
--
-- * the first definite match wins;
-- * an undecided alternative gives 'Nothing';
-- * exhausting the alternatives, or a diverging scrutinee, gives
--   @'Just' 'Bottom'@.
altsMatch :: (Monad m) => [Alt] -> Expr -> MPT m (Maybe Expr)
altsMatch [] _ = return (Just Bottom)
altsMatch (a:as) e = do
  (p, b) <- unbind a
  case patMatch p e of
    Match bs -> return (Just (substs bs b))
    NoMatch -> altsMatch as e
    Possible -> return Nothing
    Diverge -> return (Just Bottom)

-- | Match a pattern against an expression /as it stands/, without evaluating
-- it.
--
-- Expressions whose shape cannot decide the match give 'Possible'. These
-- include variables, @let@, @case@, and applications with a non-constructor
-- head.
patMatch :: Pat -> Expr -> MatchResult
patMatch PatWildcard _ = Match []
patMatch (PatVar x) e = Match [(x, e)]
patMatch (PatApp ctor ps) e = case e of
  App _ _ -> case unnestApp e of
    (Ctor c, es)
      | ctor == c -> zipPatMatch ps es
      | otherwise -> NoMatch
    _ -> Possible
  Lambda _ -> NoMatch
  Ctor _ -> NoMatch
  Literal _ -> NoMatch
  Bottom -> Diverge
  _ -> Possible
patMatch (PatCtor ctor) e = case e of
  Ctor c
    | ctor == c -> Match []
    | otherwise -> NoMatch
  App _ _ -> case fst (unnestApp e) of
    Ctor _ -> NoMatch
    _ -> Possible
  Lambda _ -> NoMatch
  Literal _ -> NoMatch
  Bottom -> Diverge
  _ -> Possible
patMatch (PatLiteral l) e = case e of
  Literal l'
    | l == l' -> Match []
    | otherwise -> NoMatch
  App _ _ -> case fst (unnestApp e) of
    Ctor _ -> NoMatch
    _ -> Possible
  Ctor _ -> NoMatch
  Lambda _ -> NoMatch
  Bottom -> Diverge
  _ -> Possible
-- Same restriction as 'doPat': bottom patterns are not valid in expressions.
patMatch PatBottom _ = error "patMatch: bottom pattern occurs in an expression"

-- | Pointwise 'patMatch' over equal-length lists.
--
-- * Substitutions are concatenated in order.
-- * The first non-'Match' result is returned.
-- * Lists of different lengths give 'NoMatch'.
--
-- The lengths are compared once, up front, rather than at every step.
zipPatMatch :: [Pat] -> [Expr] -> MatchResult
zipPatMatch ps es
  | length ps /= length es = NoMatch
  | otherwise = go ps es
  where
    go (p:ps') (e:es') = case patMatch p e of
      Match bs -> case go ps' es' of
        Match bs' -> Match (bs ++ bs')
        r -> r
      r -> r
    -- Lengths are equal, so both lists run out together.
    go _ _ = Match []

-- | Split a left-nested application spine into its head and its arguments in
-- application order: @f a b@ becomes @(f, [a, b])@.
--
-- Built with an accumulator, so the cost is linear in the number of
-- arguments.
unnestApp :: Expr -> (Expr,[Expr])
unnestApp = go []
  where
    go acc (App e1 e2) = go (e2 : acc) e1
    go acc e = (e, acc)