module Matching (casificate) where

import Language.Haskell.Exts.Syntax
import Control.Monad.State

{- Convert the equations of every function binding into a single pattern
   binding of curried lambdas whose body is a "case ... of", e.g.

     f 0 y = e1    --->   f = \v2 -> \v1 -> case (v2, v1) of
     f x y = e2                               (0, y) -> e1
                                              (x, y) -> e2

   Arguments are named <seed><k>, counting down from the arity so that the
   outermost lambda binds the first argument. With three or more arguments
   the scrutinee and the patterns are right-nested pairs, e.g.
   (v3, (v2, v1)) matched against (p1, (p2, p3)).

   The seed is "v" unless the module already contains a name of the form
   v<digits>, in which case it is lengthened ("vv", "vvv", ...) until no
   such clash exists, so a generated lambda can never capture a variable
   the user's code refers to.

   NOTE(review): the result is a pattern binding (f = \... -> ...), which
   is subject to the monomorphism restriction when f has no type signature;
   confirm downstream passes are fine with that.
-}

-- | Apply the transformation to every top-level declaration of a module.
-- Declarations that are not function bindings are returned unchanged.
casificate :: Module -> Module
casificate m@(Module a b c d i decls) =
  Module a b c d i (evalState (mapM casDecl decls) (freshSeed m))

-- | The state carries the seed used to name generated lambda arguments.
-- It is only read, never updated: functions nested in @where@ clauses reuse
-- the same names, and each inner lambda shadows an outer one only inside
-- its own body, where the outer argument is never referenced.
type ST a = State String a

-- | Rewrite a function binding whose equations take at least one argument
-- into a pattern binding of a lambda/case expression. Every other
-- declaration is returned unchanged. That includes a degenerate
-- zero-argument FunBind, for which no lambda can be built.
casDecl :: Decl -> ST Decl
casDecl (FunBind ms@(Match _ name pats _ _ : _))
  | not (null pats) = do
      body <- casMatches ms
      return $ PatBind mkLoc (PVar name) (UnGuardedRhs body) (BDecls [])
casDecl d = return d

-- | Build the curried lambdas and the case expression for the equations of
-- one function. All equations are assumed to take the same number of
-- patterns, as Haskell requires; the first equation fixes the arity, which
-- must be at least one.
casMatches :: [Match] -> ST Exp
casMatches [] = error "Matching.casMatches: function binding without equations"
casMatches ms@(Match _ _ pats _ _ : _) = do
  alts <- mapM casAlt ms
  seed <- get
  let arity = length pats
      -- Counting down keeps the first argument in the outermost lambda and
      -- in the first component of the scrutinee.
      names = [seed ++ show k | k <- [arity, arity - 1 .. 1]]
      scrut = nestTuple Tuple (map mkVar names)
  return $ foldr (\v e -> Lambda mkLoc [mkPVar v] e) (Case scrut alts) names

-- | Turn one equation into a case alternative:
--
-- * its argument patterns become one right-nested tuple pattern;
-- * guards carry over unchanged, since a case alternative falls through to
--   the next alternative when all its guards fail, just like an equation;
-- * function bindings in its @where@ clause are transformed recursively.
casAlt :: Match -> ST Alt
casAlt (Match loc _ pats rhs binds) = do
  binds' <- casBinds binds
  return $ Alt loc (nestTuple PTuple pats) (toAlt rhs) binds'
  where
    toAlt (UnGuardedRhs e) = UnGuardedAlt e
    toAlt (GuardedRhss gs) = GuardedAlts (map toGuardedAlt gs)
    toGuardedAlt (GuardedRhs l g e) = GuardedAlt l g e

-- | Transform the declarations of a @where@ clause. Any other kind of
-- bindings (e.g. implicit-parameter bindings) is passed through unchanged
-- rather than failing with a pattern-match error.
casBinds :: Binds -> ST Binds
casBinds (BDecls ds) = fmap BDecls (mapM casDecl ds)
casBinds other       = return other

-- | Right-nest a non-empty list into pairs using the given tuple
-- constructor: [a] becomes a, and [a, b, c] becomes tup [a, tup [b, c]].
nestTuple :: ([a] -> a) -> [a] -> a
nestTuple _   []       = error "Matching.nestTuple: empty list"
nestTuple _   [x]      = x
nestTuple tup (x : xs) = tup [x, nestTuple tup xs]

-- | Pick the shortest seed among "v", "vv", "vvv", ... such that no quoted
-- string of the form "<seed><digits>" occurs in the module's 'Show' output.
--
-- This assumes, as with haskell-src-exts' derived instances, that every
-- identifier is shown as a quoted string. String literals and file names
-- are scanned too, which can only make the chosen seed longer, never
-- unsafe. The search terminates because the module mentions only finitely
-- many names.
freshSeed :: Module -> String
freshSeed m = until (not . usedIn) ('v' :) "v"
  where
    txt = show m
    usedIn seed = scan txt
      where
        scan ('"' : rest) = generated rest || scan rest
        scan (_ : rest)   = scan rest
        scan []           = False
        -- Does this suffix start with <seed><one or more digits>" ?
        generated r = case splitAt (length seed) r of
          (pre, post) | pre == seed ->
            case span (`elem` ['0' .. '9']) post of
              (_ : _, '"' : _) -> True
              _                -> False
          _ -> False

-- auxiliary functions

-- | An unqualified variable expression.
mkVar :: String -> Exp
mkVar = Var . UnQual . Ident

-- | A variable pattern.
mkPVar :: String -> Pat
mkPVar = PVar . Ident

-- | Dummy source location for generated syntax.
mkLoc :: SrcLoc
mkLoc = SrcLoc "" 0 0