-- | Compiles a list of @(pattern, expression)@ case branches into a 'Match'
-- tree that tests one variable at a time.
module Generate.Cases (toMatch, Match (..), Clause (..), matchSubst, newVar) where

import Control.Applicative ((<$>))
import Control.Arrow (first)
import Control.Monad.State (State)
import qualified Control.Monad.State as State
import Data.List (groupBy,sortBy)
import Data.Maybe (fromMaybe)

import qualified AST.Annotation as A
import qualified AST.Expression.General as Expr
import qualified AST.Expression.Canonical as Canonical
import qualified AST.Literal as Literal
import qualified AST.Pattern as P
import qualified AST.Variable as Var
import Transform.Substitute (subst)


-- | Build a 'Match' for a list of branches.
--
-- A fresh variable is allocated for the scrutinee and returned together with
-- the match tree; every branch starts with its single pattern wrapped in a
-- one-element list, and 'Fail' is the initial fallback.
toMatch :: [(P.CanonicalPattern, Canonical.Expr)] -> State Int (String, Match)
toMatch patterns = do
    scrutinee <- newVar
    tree <- match [scrutinee] (map (first (:[])) patterns) Fail
    return (scrutinee, tree)


-- | Produce a fresh variable name of the form @_vN@ and bump the counter
-- held in the state.
newVar :: State Int String
newVar = do
    n <- State.get
    State.put (n + 1)
    return ("_v" ++ show n)


-- | A compiled pattern match.
--
-- * @Match var clauses fallback@ inspects @var@ against the clauses, with
--   @fallback@ as the match that follows them.
-- * 'Fail' is the fallback 'toMatch' starts with; 'Break' is the fallback
--   handed to the sub-match inside each clause (see 'matchCon').
--   NOTE(review): their exact runtime meaning is decided by the code
--   generator, which is not visible here.
-- * 'Other' wraps a branch body.
-- * 'Seq' is a list of matches built when several branches remain after all
--   variables have been consumed.
data Match
    = Match String [Clause] Match
    | Break
    | Fail
    | Other Canonical.Expr
    | Seq [Match]
    deriving Show


-- | One alternative of a 'Match': the constructor name or literal being
-- tested, the fresh variables introduced for the constructor's arguments,
-- and the match to continue with.
data Clause
    = Clause (Either Var.Canonical Literal.Literal) [String] Match
    deriving Show


-- | A branch in progress: the patterns still to be matched (one per
-- remaining variable) and the branch body.
type Case =
    ([P.CanonicalPattern], Canonical.Expr)


-- | Rename variables throughout a 'Match'.
--
-- Each pair is @(old, new)@. In 'Other' bodies every pair is applied with
-- 'subst'; in 'Match' nodes the inspected variable and the clause variables
-- are replaced when they appear as the first component of a pair.
matchSubst :: [(String,String)] -> Match -> Match
matchSubst pairs tree =
    case tree of
      Break ->
          Break

      Fail ->
          Fail

      Seq ms ->
          Seq (map (matchSubst pairs) ms)

      Other (A.A ann body) ->
          -- foldr applies the substitution of the last pair first.
          Other (A.A ann (foldr renameIn body pairs))
        where
          renameIn (old, new) = subst old (Expr.localVar new)

      Match name clauses rest ->
          Match (rename name) (map renameClause clauses) (matchSubst pairs rest)
        where
          rename v = fromMaybe v (lookup v pairs)
          renameClause (Clause tag vars inner) =
              Clause tag (map rename vars) (matchSubst pairs inner)


-- | True when the first pattern of the branch is a 'P.Data' or 'P.Literal'
-- pattern. A branch with no patterns is reported through 'noMatch'.
isCon :: ([P.Pattern var], expr) -> Bool
isCon ([], _) = noMatch "isCon"
isCon (P.Data _ _ : _, _) = True
isCon (P.Literal _ : _, _) = True
isCon _ = False


-- | Negation of 'isCon': the first pattern is anything other than a
-- constructor or literal pattern.
isVar :: ([P.Pattern var], expr) -> Bool
isVar = not . isCon


-- | Core of the algorithm.
--
-- With no variables left, the result is built directly from the remaining
-- branches and the fallback. Otherwise aliases on the first pattern of every
-- branch are removed and the work is dispatched on whether those first
-- patterns are all variable-like, all constructor-like, or a mixture.
match :: [String] -> [Case] -> Match -> State Int Match
match variables cases fallback =
    case (variables, cases, fallback) of
      ([], [], _) ->
          return fallback

      ([], [([], expr)], Fail) ->
          return (Other expr)

      ([], [([], expr)], Break) ->
          return (Other expr)

      ([], _, _) ->
          return (Seq (map (Other . snd) cases ++ [fallback]))

      (var : _, _, _)
          | all isVar cases' -> matchVar variables cases' fallback
          | all isCon cases' -> matchCon variables cases' fallback
          | otherwise -> matchMix variables cases' fallback
        where
          cases' = map (dealias var) cases


-- | Remove alias patterns from the head of a branch.
--
-- For @P.Alias x p@ the alias name @x@ is substituted in the body by the
-- variable being matched and @p@ takes the alias's place. The function
-- recurses so that an alias wrapped directly around another alias is fully
-- unwrapped; stopping after one layer would leave a 'P.Alias' at the head,
-- which 'isCon' classifies as variable-like and 'matchVar' then rejects.
dealias :: String -> Case -> Case
dealias _ ([], _) = noMatch "dealias"
dealias v branch@(p : ps, A.A ann body) =
    case p of
      P.Alias x inner ->
          dealias v (inner : ps, A.A ann (subst x (Expr.localVar v) body))

      _ ->
          branch


-- | Handle the case where every branch starts with a variable-like pattern.
--
-- The first variable is consumed, and the first pattern of each branch is
-- turned into substitutions on the branch body: 'P.Anything' changes
-- nothing, @P.Var x@ replaces @x@ by the variable, and @P.Record fs@
-- replaces each field name by an access of that field on the variable.
matchVar :: [String] -> [Case] -> Match -> State Int Match
matchVar [] _ _ = noMatch "matchVar"
matchVar (v : vs) cs def =
    match vs (map subVar cs) def
  where
    subVar ([], _) = noMatch "matchVar.subVar"
    subVar (p : ps, A.A ann body) =
        (ps, A.A ann bound)
      where
        target = Expr.localVar v

        bound =
            case p of
              P.Anything ->
                  body

              P.Var x ->
                  subst x target body

              P.Record fields ->
                  foldr (\x -> subst x (Expr.Access (A.A ann target) x)) body fields

              _ ->
                  noMatch "matchVar.subVar"


-- | Handle the case where every branch starts with a constructor or literal
-- pattern.
--
-- Branches are sorted and grouped by their first pattern (constructor
-- patterns compare by name only), and each group becomes one 'Clause' of a
-- 'Match' on the first variable. Sub-matches inside a clause use 'Break' as
-- their fallback.
matchCon :: [String] -> [Case] -> Match -> State Int Match
matchCon variables cases fallback =
    case variables of
      [] ->
          noMatch "matchCon"

      var : _ -> do
          clauses <- mapM toClause caseGroups
          return (Match var clauses fallback)
  where
    caseGroups :: [[Case]]
    caseGroups =
        groupBy eq (sortBy cmp cases)

    eq :: Case -> Case -> Bool
    eq (patterns1, _) (patterns2, _) =
        case (patterns1, patterns2) of
          (P.Data name1 _ : _, P.Data name2 _ : _) -> name1 == name2
          (pattern1 : _, pattern2 : _) -> pattern1 == pattern2
          (_, _) -> noMatch "matchCon.eq"

    cmp :: Case -> Case -> Ordering
    cmp (patterns1, _) (patterns2, _) =
        case (patterns1, patterns2) of
          (P.Data name1 _ : _, P.Data name2 _ : _) -> compare name1 name2
          (pattern1 : _, pattern2 : _) -> compare pattern1 pattern2
          (_, _) -> noMatch "matchCon.cmp"

    -- The group is matched directly instead of going through 'head', so an
    -- (unexpected) empty group is reported through 'noMatch' like every
    -- other impossible case in this module.
    toClause :: [Case] -> State Int Clause
    toClause grouped =
        case grouped of
          (P.Data name _ : _, _) : _ ->
              matchClause (Left name) variables grouped Break

          (P.Literal lit : _, _) : _ ->
              matchClause (Right lit) variables grouped Break

          _ ->
              noMatch "matchCon"


-- | Build the 'Clause' for one group of branches that share a constructor
-- or literal.
--
-- One fresh variable is created per argument pattern of the first branch's
-- constructor (none for a literal). The variable just tested is dropped,
-- the fresh variables are put in front of the remaining ones, and each
-- branch has its leading pattern replaced by that pattern's arguments.
matchClause
    :: Either Var.Canonical Literal.Literal
    -> [String]
    -> [Case]
    -> Match
    -> State Int Clause
matchClause _ [] _ _ = noMatch "matchClause"
matchClause tag (_ : vs) cs mtch = do
    fresh <- getVars
    inner <- match (fresh ++ vs) (map flatten cs) mtch
    return (Clause tag fresh inner)
  where
    flatten ([], _) = noMatch "matchClause.flatten"
    flatten (p : ps, e) =
        case p of
          P.Data _ args -> (args ++ ps, e)
          P.Literal _ -> (ps, e)
          _ -> noMatch "matchClause.flatten"

    -- Matches on the list itself rather than calling 'head', so an empty
    -- group is reported through 'noMatch'.
    getVars :: State Int [String]
    getVars =
        case cs of
          (P.Data _ args : _, _) : _ -> mapM (const newVar) args
          (P.Literal _ : _, _) : _ -> return []
          _ -> noMatch "matchClause.getVars"


-- | Handle a mixture of constructor-like and variable-like first patterns.
--
-- Branches are split into maximal consecutive runs that agree on 'isCon'.
-- The runs are compiled from last to first, each one receiving the match
-- built for the runs after it as its fallback.
matchMix :: [String] -> [Case] -> Match -> State Int Match
matchMix vs cs def =
    State.foldM (\later run -> match vs run later) def (reverse runs)
  where
    runs = groupBy (\c1 c2 -> isCon c1 == isCon c2) cs


-- | Abort with an internal-error message naming the function that met a
-- pattern it cannot handle.
--
-- NOTE(review): the message ends in "to ." and looks as if the place to
-- report to was lost; the text is kept as-is until the intended target is
-- confirmed.
noMatch :: String -> a
noMatch name =
    error $ "unexpected pattern in '" ++ name ++ "' function. Report a compiler issue to ."