-- | Generalisation of Flite expressions.
--
-- Both 'generalise' and 'generalise1' compute a skeleton shared by two
-- expressions by repeatedly splitting a hole whose fillers (one from each
-- expression) are related by '=~', in the style of a most-specific
-- generalisation (msg) algorithm.
module Optimus.Generalise where

import Flite.Fresh
import Flite.Syntax
import Flite.Traversals
import Optimus.Homeo
import Optimus.Trace
import Optimus.Uniplate

import Control.Monad
import Data.Generics.Uniplate
import Data.List
import Data.Maybe

-- | A hole binding produced by 'generalise1': a flag, the hole variable,
-- and the sub-expression of the first input that fills the hole.
type GenBinding = (GenFlag, Id, Exp)

-- | 'Safe' bindings are emitted as @let@ bindings by 'generalise1'.
-- 'Unsafe' bindings mention variables that are bound at the position the
-- sub-expression was taken from, so they are substituted back into the
-- body instead of being let-bound.
data GenFlag = Safe | Unsafe deriving Show

-- | @generalise s t@: if @'couple' t s@ holds, return @Let bind x'@. Here
-- @x'@ is the skeleton shared by @s@ and @t@, with fresh variables in its
-- holes, and @bind@ binds each hole variable to the matching sub-term of
-- @s@. Otherwise @s@ is returned unchanged.
generalise :: Exp -> Exp -> Fresh Exp
generalise s t
  | couple t s = do
      v <- fresh
      (x', bind, _) <- msg (Var v, [(v, s)], [(v, t)])
      return (Let bind x')
  | otherwise = return s
  where
    -- State: (skeleton, hole bindings for s, hole bindings for t).
    -- Pick a hole whose two fillers are related by '=~'. Replace it in the
    -- skeleton by the first filler rebuilt over fresh variables, bind those
    -- variables to the children of each filler, and repeat until no such
    -- hole remains.
    msg :: (Exp, [Binding], [Binding]) -> Fresh (Exp, [Binding], [Binding])
    msg m = case candidates of
        [] -> return m
        (v, sv, tv) : _ -> do
          let (ss, rebuild) = uniplate sv
              ts = children tv
              th0' = [ b | b@(w, _) <- th0, w /= v ]
              th1' = [ b | b@(w, _) <- th1, w /= v ]
          ys <- replicateM (length ss) fresh
          msg ( subst (rebuild (map Var ys)) v tg
              , th0' ++ zip ys ss
              , th1' ++ zip ys ts
              )
      where
        -- NOTE(review): 't_G' presumably traces its argument and returns it
        -- unchanged; confirm in "Optimus.Trace".
        (tg, th0, th1) = t_G (show m) m
        candidates =
          [ (w, a, b) | (w, a) <- th0, (w', b) <- th1, w == w', a =~ b ]

-- | Like 'generalise', but each sub-term split out of @s@ is checked
-- against the skeleton. If the sub-term has free variables that are bound
-- (by a @let@ or a case pattern) at the position it was taken from, it is
-- flagged 'Unsafe' and substituted back into the body rather than
-- let-bound. If no 'Safe' binding remains, @s@ is returned unchanged.
generalise1 :: Exp -> Exp -> Fresh Exp
generalise1 s t
  | couple t s = do
      v <- fresh
      (x', bind, _) <- msg (Var v, [(Safe, v, s)], [(v, t)])
      return $
        if any isSafe bind
          then Let [ (w, e) | (Safe, w, e) <- bind ]
                   (substMany x' [ (e, w) | (Unsafe, w, e) <- bind ])
          else s
  | otherwise = return s
  where
    isSafe :: GenBinding -> Bool
    isSafe (Safe, _, _) = True
    isSafe _ = False

    -- A sub-term is safe to lift iff none of its free variables are bound
    -- at the position it came from.
    flag :: [Id] -> GenFlag
    flag xs = if null xs then Safe else Unsafe

    msg :: (Exp, [GenBinding], [Binding]) -> Fresh (Exp, [GenBinding], [Binding])
    msg m = case candidates of
        [] -> return m
        (v, sv, tv) : _ -> do
          let (ss, rebuild) = uniplate sv
              ts = children tv
              -- Context for child i: the skeleton with hole v filled by sv
              -- rebuilt with the marker at position i and placeholders
              -- elsewhere.
              contexts = [ subst (rebuild ms) v tg | ms <- markers ss ]
              flags = zipWith (\c ctx -> flag (inaccessibleVars c ctx)) ss contexts
              th0' = [ b | b@(_, w, _) <- th0, w /= v ]
              th1' = [ b | b@(w, _) <- th1, w /= v ]
          ys <- replicateM (length ss) fresh
          msg ( subst (rebuild (map Var ys)) v tg
              , th0' ++ zip3 flags ys ss
              , th1' ++ zip ys ts
              )
      where
        -- NOTE(review): 't_G' presumably traces its argument and returns it
        -- unchanged; confirm in "Optimus.Trace".
        (tg, th0, th1) = t_G (show m) m
        candidates =
          [ (w, a, b) | (_, w, a) <- th0, (w', b) <- th1, w == w', a =~ b ]

-- | @markers xs@ has one entry per element of @xs@. The @i@-th entry is a
-- list of @length xs@ expressions: the marker @Var "%"@ at position @i@ and
-- the placeholder @Var ""@ everywhere else. @markers [] == []@.
--
-- Built directly. The previous @nub . permutations@ formulation produced
-- the same lists in the same order, but enumerated up to (n-1)!
-- permutations to find them.
markers :: [a] -> [[Exp]]
markers xs =
  [ replicate i hole ++ Var "%" : replicate (n - i - 1) hole | i <- [0 .. n - 1] ]
  where
    n = length xs
    hole = Var ""

-- | Free variables of @x@ that are bound at the marker position in @y@
-- (see 'bindingsAtMarker').
inaccessibleVars :: Exp -> Exp -> [Id]
inaccessibleVars x y = freeVars x `intersect` bindingsAtMarker y

-- | Variables bound at the hole of the context @ctx@, found by plugging the
-- marker @Var "%"@ into it.
bindingsAtContext :: (Exp -> Exp) -> [Id]
bindingsAtContext ctx = bindingsAtMarker (ctx (Var "%"))

-- | Variables bound on the path from the root of @x@ to the first marker
-- @Var "%"@. The search goes through the children of 'App', 'Let' and
-- 'Case' nodes left to right; other nodes are not entered.
--
-- An enclosing @let@ contributes its bound names. A @case@ contributes the
-- pattern variables of the alternative containing the marker; its first
-- child gets none (presumably the scrutinee — confirm against 'children').
-- Outer binders come first, and duplicates are removed.
--
-- Calls 'error' if no marker is reachable.
bindingsAtMarker :: Exp -> [Id]
bindingsAtMarker x = maybe (error "Never reached marker.") nub (go x)
  where
    go :: Exp -> Maybe [Id]
    go (Var "%") = Just []
    go e@(Let bs _) = fmap (map fst bs ++) (firstHit (children e))
    go e@(Case _ alts) =
      msum (zipWith (\vs c -> fmap (vs ++) (go c))
                    ([] : map (patVars . fst) alts)
                    (children e))
    go e@(App _ _) = firstHit (children e)
    go _ = Nothing

    -- Result of the first child (left to right) that reaches the marker.
    firstHit :: [Exp] -> Maybe [Id]
    firstHit = msum . map go