module Flite.Matching (desugarEqn, desugarCase) where
import Flite.Syntax
import Flite.Traversals
import Flite.Descend
import Flite.Fresh
import Data.List
import Data.Maybe
import Control.Monad
-- | Translate the program's function equations into one 'Func' per function
-- whose body is the case tree produced by 'match'.
--
-- Equations are first grouped per function by 'groupEqn'. Each function then
-- receives @arity@ fresh parameter names. These become both the 'Func'
-- parameters and the scrutinee columns handed to 'match'.
desugarEqn :: Prog -> Fresh Prog
desugarEqn p = forM (groupEqn p) $ \(f, arity, qs) -> do
  params <- replicateM arity fresh
  body <- match params qs
  return (Func f (map Var params) body)
-- | Group the program's equations by function name.
--
-- Each result is @(name, arity, equations)@, where every equation is the
-- pair of its argument patterns and its right-hand side. Grouping uses
-- 'groupBy', so only *adjacent* equations with the same name are merged.
-- NOTE(review): equations for one function that are separated by another
-- definition end up in separate groups; confirm this is intended.
--
-- Fails with 'error' if any group mixes equations of different arities.
groupEqn :: Prog -> [(String, Int, [Equation])]
groupEqn p
  | all (rect . map funcArgs) groups = map summarise groups
  | otherwise = error "Function equations cannot have different arities!"
  where
    groups = groupBy sameName p

    sameName a b = funcName a == funcName b

    -- groupBy never yields an empty group, so 'head' is safe here.
    summarise ds =
      let rep = head ds
      in ( funcName rep
         , length (funcArgs rep)
         , [(funcArgs d, funcRhs d) | d <- ds]
         )
-- | Are all rows the same length?
--
-- An empty list of rows is vacuously rectangular. The previous
-- @groupBy@-based version returned 'False' for it. Comparing each length
-- with the first one also stops at the first mismatch.
rect :: [[a]] -> Bool
rect rows = case map length rows of
  [] -> True
  (n : ns) -> all (== n) ns
-- | Rewrite every @case@ expression in the program using 'match'.
--
-- Each expression visited by 'onExpM' goes through two passes:
--
-- * 'caseVar' makes every case scrutinee a plain variable.
-- * The local @desugar@ pass compiles each @Case (Var v) alts@ with
--   @match [v]@, after first desugaring the alternatives' right-hand sides.
--
-- The resulting alternatives match one constructor applied to variables, as
-- built by 'matchClause'. Expressions other than such a 'Case' are traversed
-- with 'descendM'; presumably it applies the pass to their immediate
-- sub-expressions.
desugarCase :: Prog -> Fresh Prog
desugarCase = onExpM (caseVar >=> desugar)
  where
    desugar (Case (Var scrut) alts) = do
      alts' <- mapM desugarAlt alts
      match [scrut] [([pat], rhs) | (pat, rhs) <- alts']
    desugar other = descendM desugar other

    desugarAlt (pat, rhs) = do
      rhs' <- desugar rhs
      return (pat, rhs')
-- | Ensure every 'Case' scrutinee is a bare variable.
--
-- * A scrutinee that 'getVar' recognises as a variable is replaced by that
--   plain @Var@.
-- * Any other scrutinee is bound to a fresh variable with a 'Let', and the
--   case then scrutinises that variable.
--
-- The rewrite continues into sub-expressions via 'descendM'.
--
-- The old @where v = getVar e@ binding was dead: every branch shadowed @v@.
-- It has been removed.
caseVar :: Exp -> Fresh Exp
caseVar (Case e alts) =
  case getVar e of
    Just v -> descendM caseVar (Case (Var v) alts)
    Nothing -> do
      v <- fresh
      -- Re-run on the Let so that the bound expression itself is also
      -- processed; the inner Case now hits the 'Just' branch above.
      caseVar (Let [(v, e)] (Case (Var v) alts))
caseVar e = descendM caseVar e
-- | Return the variable an expression denotes, if it is one.
--
-- A bare @Var v@ yields @v@. An 'App' with an empty argument list is looked
-- through, so @App (Var v) []@ also yields @v@. Anything else is 'Nothing'.
getVar :: Exp -> Maybe Id
getVar expr = case expr of
  Var v -> Just v
  App f [] -> getVar f
  _ -> Nothing
-- | One row of a pattern-matching problem.
--
-- It pairs the patterns still to be matched, one per remaining scrutinee
-- column consumed by 'match', with the right-hand side to return when they
-- all match.
type Equation = ([Pat], Exp)
-- | Does the equation's first pattern bind a variable?
--
-- This is total. Any other leading pattern, and an empty pattern list, give
-- 'False'. Callers therefore reach their own error handling instead of
-- crashing here with a pattern-match failure, as the previous two-clause
-- definition did.
isVar :: Equation -> Bool
isVar (Var _ : _, _) = True
isVar _ = False
-- | Does the equation's first pattern apply a constructor, i.e. have the
-- shape @App (Con c) args@?
--
-- This is defined directly on that shape rather than as @not . isVar@. A
-- leading pattern that is neither a variable nor a constructor application
-- is then not misclassified as a constructor pattern. Such a column makes
-- 'match' report an error instead of passing it on to 'getCon'.
isCon :: Equation -> Bool
isCon (App (Con _) _ : _, _) = True
isCon _ = False
-- | Split the equation's leading constructor pattern into the constructor
-- name and its argument patterns.
--
-- Only meaningful for equations accepted by 'isCon'. Any other input is an
-- internal error and is reported with an explicit message rather than an
-- anonymous pattern-match failure.
getCon :: Equation -> (Id, [Pat])
getCon (App (Con c) args : _, _) = (c, args)
getCon _ = error "getCon: equation does not start with a constructor pattern"
-- | Compile a column-wise pattern match into nested 'Case' expressions.
--
-- @match us qs@ matches the scrutinee variables @us@ (one per pattern
-- column) against the equations @qs@:
--
-- * No columns left: exactly one equation must remain, and its right-hand
--   side is the result. Zero equations, or several (overlapping equations),
--   are rejected with a message saying which case occurred. Previously both
--   were misreported as non-uniform matching.
-- * Every equation starts with a variable pattern: the column is dropped
--   and the variable is rewritten to the current scrutinee @u@ with 'subst'.
--   This assumes @subst x v e@ replaces @v@ by @x@ in @e@.
-- * Every equation starts with a constructor pattern: the equations are
--   grouped per constructor ('groupEqns'), and each group becomes one
--   alternative ('matchClause') of a 'Case' on @u@.
--
-- Mixing variable and constructor patterns in one column is disallowed.
match :: [Id] -> [Equation] -> Fresh Exp
match [] [q] = return (snd q)
match [] [] = error "Pattern match has no equations!"
match [] _ = error "Overlapping equations in pattern match!"
match (u:us) qs
  | all isVar qs = match us [(ps, subst (Var u) v e) | (Var v:ps, e) <- qs]
  | all isCon qs = do
      alts <- mapM (matchClause us) (groupEqns qs)
      return (Case (Var u) alts)
match _ _ = error "Non-uniform pattern matching is disallowed!"
-- | Partition constructor-headed equations by their leading constructor.
--
-- Each result is @(constructor, arity, equations)@. Groups appear in order
-- of each constructor's first occurrence, and every group keeps its
-- equations in their original relative order (via 'partition').
-- Fails with 'error' if one constructor is used with differing numbers of
-- arguments.
groupEqns :: [Equation] -> [(Id, Int, [Equation])]
groupEqns [] = []
groupEqns qs@(q:_)
  | all hasArity same = (name, arity, same) : groupEqns rest
  | otherwise = error ("Constructor `" ++ name ++ "` has different arities!")
  where
    (name, args) = getCon q
    arity = length args
    sameCon eqn = fst (getCon eqn) == name
    hasArity eqn = length (snd (getCon eqn)) == arity
    (same, rest) = partition sameCon qs
-- | Build the 'Case' alternative for one constructor group from 'groupEqns'.
--
-- Fresh variables are generated for the constructor's @arity@ fields. They
-- form the alternative's pattern @App (Con c) [Var ...]@ and become new
-- leading scrutinee columns, ahead of the remaining columns @us@. Each
-- equation contributes its constructor-argument patterns followed by its
-- remaining patterns.
matchClause :: [Id] -> (Id, Int, [Equation]) -> Fresh Alt
matchClause us (c, arity, qs) = do
  fields <- replicateM arity fresh
  -- groupEqns already put only constructor c in this group, so the name in
  -- the generator pattern is ignored rather than compared.
  rhs <- match (fields ++ us)
               [(args ++ ps, e) | (App (Con _) args : ps, e) <- qs]
  return (App (Con c) (map Var fields), rhs)