module Flite.Case (caseElim, caseElimWithCaseStack) where import Flite.Syntax import Flite.Traversals import Flite.Descend import Flite.State import Control.Monad import Data.List as List import Data.Set as Set import Data.Map as Map -- Assumes that pattern matching has been desugared. caseElim :: Prog -> Prog caseElim = caseElim' False caseElimWithCaseStack :: Prog -> Prog caseElimWithCaseStack = caseElim' True caseElim' :: Bool -> Prog -> Prog caseElim' cstk p = elim cstk fs (expandCase ft p) where fs = families p ft = familyTable fs type Family = Set (Id, Int) families :: Prog -> [Family] families p | check = fams | otherwise = error "A constructor cannot have different arities!" where check = let ids = [id | (id, _) <- Set.toList (Set.unions fams)] in length ids == length (nub ids) fams = fixMerge (List.map Set.fromList ctrs) merge [] = [] merge (f:fs) = Set.unions (f:same) : merge different where (same, different) = List.partition (overlap f) fs fixMerge fs = if length fs == length fs' then fs' else fixMerge fs' where fs' = merge fs overlap f0 f1 = not (Set.null (Set.intersection f0 f1)) ctrs = fromExp fam p fam e = List.map (concatMap getCtr) (caseAlts e) getCtr (App (Con c) ps, e) = [(c, length ps)] getCtr (p, e) = [] familyTable :: [Family] -> Map Id Family familyTable fams = Map.fromList [(id, fam) | fam <- fams, (id, arity) <- Set.toList fam] expandCase :: Map Id Family -> Prog -> Prog expandCase table p = onExp expand p where expand (Case e ((Var v, rhs):as)) = expand (Let [(v, e)] rhs) expand (Case e alts@((App (Con c) ps, rhs):as)) = Case (expand e) alts' where alts' = [getAlt f n | (f, n) <- Set.toAscList (table Map.! c)] getAlt f n = head ([ (App (Con c) args, expand rhs) | (App (Con c) args, rhs) <- alts , c == f ] ++ [bottom f n]) bottom f n = (App (Con f) (replicate n (Var "_")), Bottom) expand e = descend expand e elim :: Bool -> [Family] -> Prog -> Prog elim cstk fams p = concatMap comp p where ctrInfo = [ (f, (arity, i)) | fs <- List.map Set.toAscList fams , ((f, arity), i) <- zip fs [0..] ] comp d = let ((_, ds), e) = runState (compFun (funcName d) (funcRhs d)) (1, []) in (d { funcRhs = e } : ds) compFun fun (Con c) | Prelude.null cinfo = return Bottom | otherwise = return (Ctr c (fst $ head cinfo) (snd $ head cinfo)) where cinfo = [ci | (d, ci) <- ctrInfo, c == d] compFun fun (Case e as) = return App `ap` compFun fun e `ap` calts fun as compFun fun e = descendM (compFun fun) e calts fun as = do es' <- mapM (compFun fun) es let fvs = nub $ concat $ zipWith (freeVarsExcept) vss es' fs <- zipWithM (calt fun fvs) vss es' let alts = Alts fs (length fvs) return ([alts] ++ [Int 0 | cstk && List.null fvs] ++ List.map Var fvs) where (ps, es) = unzip as vss = List.map (\(App _ args) -> [v | Var v <- args]) ps calt fun fvs vs e = do n <- newAlt let name = fun ++ "#" ++ show n let args = vs ++ ["$ct" | not cstk || (cstk && List.null fvs)] ++ fvs addDecl (Func name (List.map Var args) e) return name newAlt = S (\(i, ds) -> ((i+1, ds), i)) addDecl d = S (\(i, ds) -> ((i, ds ++ [d]), ()))