module Flite.Matching (desugarEqn, desugarCase) where

import Flite.Syntax
import Flite.Traversals
import Flite.Descend
import Flite.Fresh
import Data.List
import Data.Maybe
import Control.Monad

-- | Replace each group of equations defining a function by a single
-- definition.  The parameters of that definition are fresh variables, and
-- its body is the expression produced by 'match' over the group's equations.
desugarEqn :: Prog -> Fresh Prog
desugarEqn p = mapM desugarGroup (groupEqn p)
  where
    desugarGroup (f, arity, qs) = do
      us <- replicateM arity fresh
      rhs <- match us qs
      return (Func f (map Var us) rhs)

-- | Group consecutive equations that share a function name.  For each group,
-- return the function's name, its arity, and its equations (argument
-- patterns paired with the right-hand side).
--
-- Fails with an error if the equations of some function take different
-- numbers of arguments.
--
-- NOTE(review): 'groupBy' only merges *adjacent* equations.  A function whose
-- equations are interleaved with another definition therefore yields two
-- separate groups, and hence two definitions with the same name.
groupEqn :: Prog -> [(String, Int, [Equation])]
groupEqn p
  | all (rect . map funcArgs) dss = map describe dss
  | otherwise = error "Function equations cannot have different arities!"
  where
    dss = groupBy (\a b -> funcName a == funcName b) p
    describe ds@(d : _) =
      ( funcName d
      , length (funcArgs d)
      , [(funcArgs x, funcRhs x) | x <- ds]
      )
    -- groupBy never produces empty groups.
    describe [] = error "Flite.Matching.groupEqn: unexpected empty group"

-- | Do all rows have the same length?  Returns 'False' for an empty list of
-- rows.  'groupEqn' never passes an empty list, since its groups come from
-- 'groupBy'.
rect :: [[a]] -> Bool
rect [] = False
rect (r : rs) = all ((== length r) . length) rs

-- | Compile every @case@ expression in the program with 'match'.  First,
-- 'caseVar' makes every scrutinee a plain variable, which is the form that
-- 'match' expects.  Alternative bodies are desugared before their enclosing
-- @case@.
desugarCase :: Prog -> Fresh Prog
desugarCase = onExpM (\e -> caseVar e >>= desugar)
  where
    desugar (Case (Var v) alts) = do
      eqns <- mapM (\(pat, rhs) -> do rhs' <- desugar rhs; return ([pat], rhs')) alts
      match [v] eqns
    desugar e = descendM desugar e

-- | Make the scrutinee of every @case@ a plain variable.  A scrutinee that is
-- not already a variable (as judged by 'getVar') is bound to a fresh
-- variable by a @let@ around the @case@.
caseVar :: Exp -> Fresh Exp
caseVar (Case e as) =
  case getVar e of
    Nothing -> do
      v <- fresh
      caseVar (Let [(v, e)] (Case (Var v) as))
    Just v -> descendM caseVar (Case (Var v) as)
caseVar e = descendM caseVar e

-- | Return the variable an expression consists of, if any.  Applications to
-- an empty argument list are looked through.
getVar :: Exp -> Maybe Id
getVar (Var v) = Just v
getVar (App e []) = getVar e
getVar _ = Nothing

-- Wadler's algorithm for compilation of *uniform* pattern matching,
-- from "The Implementation of Functional Programming Languages".
-- | One row of a matching problem.  It holds the patterns still to be
-- matched (one per remaining variable, left to right) and the right-hand
-- side to use once they all match.
type Equation = ([Pat], Exp)

-- | View a pattern as a constructor pattern, giving the constructor name and
-- its argument patterns.  A bare @Con c@ is read as a nullary constructor,
-- i.e. the same as @App (Con c) []@.
conPat :: Pat -> Maybe (Id, [Pat])
conPat (App (Con c) args) = Just (c, args)
conPat (Con c) = Just (c, [])
conPat _ = Nothing

-- | Is the equation's first pattern a variable?  Returns 'False' for a
-- constructor pattern.  Fails with a descriptive error for any other kind of
-- pattern, because only variables and constructors are supported.
isVar :: Equation -> Bool
isVar (Var _ : _, _) = True
isVar (p : _, _) =
  case conPat p of
    Just _ -> False
    Nothing -> error "Patterns must be variables or constructor applications!"
isVar ([], _) = error "Flite.Matching.isVar: equation has no patterns left"

-- | Is the equation's first pattern a constructor pattern?
isCon :: Equation -> Bool
isCon e = not (isVar e)

-- | The constructor name and argument patterns of the equation's first
-- pattern, which must be a constructor pattern.
getCon :: Equation -> (Id, [Pat])
getCon (p : _, _) =
  case conPat p of
    Just con -> con
    Nothing -> error "Flite.Matching.getCon: not a constructor pattern"
getCon ([], _) = error "Flite.Matching.getCon: equation has no patterns left"

-- | @match us qs@ compiles the equations @qs@ into an expression.  The
-- expression scrutinises the variables @us@, one per pattern column.
--
-- * Empty rule: when no columns are left, the first equation wins.  There
--   are no guards, so it always succeeds and later equations are redundant.
-- * Variable rule: when every equation starts with a variable, that variable
--   is substituted by @u@ in its right-hand side.
-- * Constructor rule: when every equation starts with a constructor, the
--   equations are grouped by constructor and each group becomes one
--   alternative of a @case@ on @u@.
--
-- Columns that mix variables and constructors (Wadler's mixture rule) are
-- rejected.
match :: [Id] -> [Equation] -> Fresh Exp
match [] ((_, rhs) : _) = return rhs
match (u : us) qs
  | all isVar qs = match us [(ps, subst (Var u) v e) | (Var v : ps, e) <- qs]
  | all isCon qs = do
      alts <- mapM (matchClause us) (groupEqns qs)
      return (Case (Var u) alts)
match _ _ = error "Non-uniform pattern matching is disallowed!"

-- | Group constructor equations by the constructor of their first pattern.
-- Constructors appear in order of first occurrence.  Equations keep their
-- original order within each group.  Each group is returned with its
-- constructor name and arity.  Fails if one constructor is used with
-- different numbers of arguments.
groupEqns :: [Equation] -> [(Id, Int, [Equation])]
groupEqns [] = []
groupEqns qs@(q : _)
  | all ((== arity) . length . snd . getCon) same =
      (name, arity, same) : groupEqns rest
  | otherwise = error ("Constructor `" ++ name ++ "` has different arities!")
  where
    (name, args) = getCon q
    arity = length args
    (same, rest) = partition ((== name) . fst . getCon) qs

-- | Build the @case@ alternative for one constructor group.  Fresh variables
-- are bound to the constructor's fields.  The argument patterns are
-- prepended to each equation's remaining patterns, and matching continues
-- on the extended variable list.
matchClause :: [Id] -> (Id, Int, [Equation]) -> Fresh Alt
matchClause us (c, arity, qs) = do
  us' <- replicateM arity fresh
  -- Every equation here starts with constructor c (see groupEqns).  Its
  -- argument patterns become new leading columns.
  rhs <- match (us' ++ us) [(snd (getCon q) ++ ps, e) | q@(_ : ps, e) <- qs]
  return (App (Con c) (map Var us'), rhs)