{- Language/Haskell/TH/Desugar/Match.hs

(c) Richard Eisenberg 2013
eir@cis.upenn.edu

Simplifies case statements in desugared TH. After this pass, there are no
more nested patterns.

This code is directly based on the analogous operation as written in GHC.
-}

{-# LANGUAGE CPP, TemplateHaskell #-}

#if __GLASGOW_HASKELL__ <= 708
{-# LANGUAGE StandaloneDeriving #-}
{-# OPTIONS_GHC -fno-warn-orphans #-}
  -- we need Ord Lit. argh.
#endif

module Language.Haskell.TH.Desugar.Match (scExp, scLetDec) where

import Prelude hiding ( fail, exp )

#if __GLASGOW_HASKELL__ < 709
import Control.Applicative
#endif
import Control.Monad hiding ( fail )
import qualified Data.Set as S
import qualified Data.Map as Map

import Language.Haskell.TH.Instances ()
import Language.Haskell.TH.Syntax

import Language.Haskell.TH.Desugar.Core
import Language.Haskell.TH.Desugar.Util
import Language.Haskell.TH.Desugar.Reify

-- | Remove all nested pattern-matches within this expression. This also
-- removes all 'DTildePa's and 'DBangPa's. After this is run, the pattern of
-- every @case@ alternative is guaranteed to be either a 'DConPa' with bare
-- variables as arguments, a 'DLitPa', or a 'DWildPa'.
--
-- Simplification recurses through applications, lambda bodies, @let@
-- declarations and bodies, and type-annotated expressions. For a @case@
-- expression it simplifies the scrutinee and the right-hand side of every
-- alternative /before/ compiling the alternatives. As a result, a @case@
-- nested inside another @case@ (in its scrutinee or in one of its
-- alternatives) is flattened as well.
--
-- Patterns of 'DValD' let-bindings are not rewritten; only their
-- right-hand sides are simplified (see 'scLetDec').
scExp :: DsMonad q => DExp -> q DExp
scExp (DAppE e1 e2) = DAppE <$> scExp e1 <*> scExp e2
scExp (DLamE names exp) = DLamE names <$> scExp exp
scExp (DCaseE scrut matches)
  | DVarE name <- scrut
    -- The scrutinee is already a variable (nothing to simplify in it), so
    -- the match compiler can refer to it directly.
  = do clauses <- mapM sc_match matches
       simplCaseExp [name] clauses
  | otherwise
  = do scrut' <- scExp scrut
       clauses <- mapM sc_match matches
       scrut_name <- newUniqueName "scrut"
       case_exp <- simplCaseExp [scrut_name] clauses
       -- Bind the (simplified) scrutinee to a fresh variable, because the
       -- match compiler only works over scrutinee *names*.
       return $ DLetE [DValD (DVarPa scrut_name) scrut'] case_exp
  where
    -- Turn a case alternative into a one-pattern clause, simplifying its
    -- right-hand side first (mirrors sc_clause_rhs in scLetDec).
    sc_match (DMatch pat exp) = DClause [pat] <$> scExp exp
scExp (DLetE decs body) = DLetE <$> mapM scLetDec decs <*> scExp body
scExp (DSigE exp ty) = DSigE <$> scExp exp <*> pure ty
scExp e = return e

-- | Like 'scExp', but for a 'DLetDec'.
-- Simplify one let-declaration.
--
-- * A function declaration with at least one clause becomes a function
--   with a single clause. That clause binds one fresh @_arg@ variable per
--   pattern of the first clause. Its body is the match tree built by
--   'simplCaseExp' over the original clauses, whose right-hand sides are
--   first simplified with 'scExp'.
-- * A value declaration keeps its pattern as written; only its right-hand
--   side is simplified.
-- * Any other declaration is returned unchanged.
scLetDec :: DsMonad q => DLetDec -> q DLetDec
scLetDec (DFunD name clauses@(DClause pats1 _ : _)) = do
  arg_names <- mapM (const (newUniqueName "_arg")) pats1
  clauses' <- mapM sc_clause_rhs clauses
  case_exp <- simplCaseExp arg_names clauses'
  return $ DFunD name [DClause (map DVarPa arg_names) case_exp]
  where
    sc_clause_rhs (DClause pats exp) = DClause pats <$> scExp exp
scLetDec (DValD pat exp) = DValD pat <$> scExp exp
scLetDec dec = return dec

-- | An expression with a hole. The argument is the expression to use if
-- the match fails (the "failure continuation").
type MatchResult = DExp -> DExp

-- | Close a 'MatchResult' by filling its failure hole with a call to
-- @error "Pattern-match failure"@.
matchResultToDExp :: MatchResult -> DExp
matchResultToDExp mr = mr failed_pattern_match
  where
    failed_pattern_match = DAppE (DVarE 'error)
                                 (DLitE $ StringL "Pattern-match failure")

-- | Compile a list of clauses that scrutinise the given variables into a
-- single expression without nested patterns. Each clause's right-hand side
-- ignores the failure continuation. Failure of the whole match produces
-- the @error@ call from 'matchResultToDExp'.
simplCaseExp :: DsMonad q
             => [Name]
             -> [DClause]
             -> q DExp
simplCaseExp vars clauses =
  do let eis = [ EquationInfo pats (\_ -> rhs) | DClause pats rhs <- clauses ]
     matchResultToDExp `liftM` simplCase vars eis

-- | One equation being compiled: the patterns still to be matched, and
-- the right-hand side as a 'MatchResult'.
data EquationInfo = EquationInfo [DPat] MatchResult  -- like DClause, but with a hole

-- | Core of the match compiler.
--
-- With no scrutinees left, every equation matches. The first equation's
-- result is used, and @foldr1 (.)@ feeds each later equation's result in as
-- the previous one's failure continuation. @foldr1@ needs at least one
-- equation here.
--
-- Otherwise, the first pattern of each equation is tidied against the
-- first scrutinee @v@. Consecutive equations with the same kind of first
-- pattern are grouped and each group is compiled. The groups are then
-- chained with @(.)@, so a failing group falls through to the next.
--
-- NOTE(review): the wrappers produced by tidying (e.g. @let x = v@ for a
-- variable pattern) are composed and applied around the /whole/ combined
-- result, not per equation. This mirrors GHC's @match@, but with
-- user-chosen (e.g. @mkName@) names a binding from one equation may shadow
-- a free name used in another equation's right-hand side. Confirm whether
-- callers guarantee distinct names.
simplCase :: DsMonad q
          => [Name]         -- the names of the scrutinees
          -> [EquationInfo] -- the matches (where the # of pats == length (1st arg))
          -> q MatchResult
simplCase [] clauses = return (foldr1 (.) match_results)
  where
    match_results = [ mr | EquationInfo _ mr <- clauses ]
simplCase vars@(v:_) clauses = do
  (aux_binds, tidy_clauses) <- mapAndUnzipM (tidyClause v) clauses
  let grouped = groupClauses tidy_clauses
  match_results <- match_groups grouped
  return (adjustMatchResult (foldr (.) id aux_binds) $
          foldr1 (.) match_results)
  where
    -- No equations at all: only force/inspect v (see matchEmpty).
    match_groups :: DsMonad q => [[(PatGroup, EquationInfo)]] -> q [MatchResult]
    match_groups [] = matchEmpty v
    match_groups gs = mapM match_group gs

    -- Dispatch a run of same-kind equations to the matching strategy.
    -- Constructor and literal runs are further bucketed by constructor
    -- name / literal value with subGroup.
    match_group :: DsMonad q => [(PatGroup, EquationInfo)] -> q MatchResult
    match_group [] = error "Internal error in th-desugar (match_group)"
    match_group eqns@((group,_) : _) =
      case group of
        PgCon _ -> matchConFamily vars (subGroup [(c,e) | (PgCon c, e) <- eqns])
        PgLit _ -> matchLiterals vars (subGroup [(l,e) | (PgLit l, e) <- eqns])
        PgBang  -> matchBangs vars (drop_group eqns)
        PgAny   -> matchVariables vars (drop_group eqns)

    drop_group = map snd

-- | Tidy the first pattern of an equation against scrutinee @v@. Returns
-- a wrapper for the result expression together with the tidied equation.
-- analogous to GHC's tidyEqnInfo
tidyClause :: DsMonad q => Name -> EquationInfo -> q (DExp -> DExp, EquationInfo)
tidyClause _ (EquationInfo [] _) =
  error "Internal error in th-desugar: no patterns in tidyClause."
tidyClause v (EquationInfo (pat : pats) body) = do
  (wrap, pat') <- tidy1 v pat
  return (wrap, EquationInfo (pat' : pats) body)

-- | Rewrite a single pattern into one of: a literal, a constructor, a
-- wildcard, or a bang on a variable/wildcard. Any bindings the original
-- pattern introduced are moved into the returned wrapper:
--
-- * a variable pattern becomes a wildcard plus @let var = v@ (via 'wrapBind');
-- * a lazy (@~@) pattern becomes a wildcard plus the selector declarations
--   from 'mkSelectorDecs';
-- * a bang on a literal or constructor pattern is dropped, and nested
--   @~@ / @!@ directly under a @!@ are collapsed.
tidy1 :: DsMonad q
      => Name   -- the name of the variable that ...
      -> DPat   -- ... this pattern is matching against
      -> q (DExp -> DExp, DPat)   -- a wrapper and tidied pattern
tidy1 _ p@(DLitPa {}) = return (id, p)
tidy1 v (DVarPa var) = return (wrapBind var v, DWildPa)
tidy1 _ p@(DConPa {}) = return (id, p)
tidy1 v (DTildePa pat) = do
  sel_decs <- mkSelectorDecs pat v
  return (maybeDLetE sel_decs, DWildPa)
tidy1 v (DBangPa pat) =
  case pat of
    DLitPa _   -> tidy1 v pat               -- already strict
    DVarPa _   -> return (id, DBangPa pat)  -- no change
    DConPa _ _ -> tidy1 v pat               -- already strict
    DTildePa p -> tidy1 v (DBangPa p)       -- discard ~ under !
    DBangPa p  -> tidy1 v (DBangPa p)       -- discard ! under !
    DWildPa    -> return (id, DBangPa pat)  -- no change
tidy1 _ DWildPa = return (id, DWildPa)

-- | @wrapBind new old@ wraps an expression in @let new = old@. When the
-- two names are equal it is the identity.
wrapBind :: Name -> Name -> DExp -> DExp
wrapBind new old
  | new == old = id
  | otherwise  = DLetE [DValD (DVarPa new) (DVarE old)]

-- | Produce let-declarations that bind every variable of a lazy pattern
-- matched against @name@:
--
-- * a bare variable pattern binds it directly to @name@;
-- * a pattern with no variables produces no declarations;
-- * a pattern with exactly one variable binds a fresh alias of @name@ and
--   a fresh @error "Irrefutable match failed"@ value. It then binds the
--   variable to a match of the pattern against the alias, whose failure
--   branch is that error value;
-- * otherwise the pattern is matched once to build a tuple of all its
--   variables (in 'S.toAscList' order), and each variable is bound to a
--   projection out of that tuple.
-- like GHC's mkSelectorBinds
mkSelectorDecs :: DsMonad q
               => DPat      -- pattern to deconstruct
               -> Name      -- variable being matched against
               -> q [DLetDec]
mkSelectorDecs (DVarPa v) name = return [DValD (DVarPa v) (DVarE name)]
mkSelectorDecs pat name
  | S.null binders
  = return []

  | S.size binders == 1
  = do val_var <- newUniqueName "var"
       err_var <- newUniqueName "err"
       bind    <- mk_bind val_var err_var (head $ S.elems binders)
       return [DValD (DVarPa val_var) (DVarE name),
               DValD (DVarPa err_var) (DVarE 'error `DAppE`
                                       (DLitE $ StringL "Irrefutable match failed")),
               bind]

  | otherwise
  = do tuple_expr <- simplCaseExp [name] [DClause [pat] local_tuple]
       tuple_var <- newUniqueName "tuple"
       projections <- mapM (mk_projection tuple_var) [0 .. tuple_size-1]
       return (DValD (DVarPa tuple_var) tuple_expr :
               zipWith DValD (map DVarPa binders_list) projections)

  where
    binders = extractBoundNamesDPat pat
    binders_list = S.toAscList binders
    tuple_size = length binders_list
    local_tuple = mkTupleDExp (map DVarE binders_list)

    -- case tup of (_, .., x, .., _) -> x   (x at position i)
    mk_projection :: DsMonad q
                  => Name   -- of the tuple
                  -> Int    -- which element to get (0-indexed)
                  -> q DExp
    mk_projection tup_name i = do
      var_name <- newUniqueName "proj"
      return $ DCaseE (DVarE tup_name)
                      [DMatch (DConPa (tupleDataName tuple_size) (mk_tuple_pats var_name i))
                              (DVarE var_name)]

    -- i wildcards, then the variable, then wildcards up to tuple_size.
    mk_tuple_pats :: Name   -- of the projected element
                  -> Int    -- which element to get (0-indexed)
                  -> [DPat]
    mk_tuple_pats elt_name i = replicate i DWildPa ++ DVarPa elt_name : replicate (tuple_size - i - 1) DWildPa

    -- bndr_var = <match pat against scrut_var, yielding bndr_var; on failure err_var>
    mk_bind scrut_var err_var bndr_var = do
      rhs_mr <- simplCase [scrut_var] [EquationInfo [pat] (\_ -> DVarE bndr_var)]
      return (DValD (DVarPa bndr_var) (rhs_mr (DVarE err_var)))

-- | The set of variables a pattern binds, looking through constructor,
-- lazy and bang patterns.
extractBoundNamesDPat :: DPat -> S.Set Name
extractBoundNamesDPat (DLitPa _)      = S.empty
extractBoundNamesDPat (DVarPa n)      = S.singleton n
extractBoundNamesDPat (DConPa _ pats) = S.unions (map extractBoundNamesDPat pats)
extractBoundNamesDPat (DTildePa p)    = extractBoundNamesDPat p
extractBoundNamesDPat (DBangPa p)     = extractBoundNamesDPat p
extractBoundNamesDPat DWildPa         = S.empty

-- | Classification of a /tidied/ first pattern, used to pick a matching
-- strategy.
data PatGroup
  = PgAny         -- immediate match (wilds, vars, lazies)
  | PgCon Name
  | PgLit Lit
  | PgBang

-- | Split equations into maximal runs of consecutive equations whose first
-- patterns have the same kind ('sameGroup'), tagging each with its group.
-- like GHC's groupEquations
groupClauses :: [EquationInfo] -> [[(PatGroup, EquationInfo)]]
groupClauses clauses
  = runs same_gp [(patGroup (firstPat clause), clause) | clause <- clauses]
  where
    same_gp :: (PatGroup, EquationInfo) -> (PatGroup, EquationInfo) -> Bool
    (pg1,_) `same_gp` (pg2,_) = pg1 `sameGroup` pg2

-- | Classify a tidied pattern. Variable and lazy patterns must already
-- have been removed by 'tidy1'; meeting one here is an internal error.
patGroup :: DPat -> PatGroup
patGroup (DLitPa l)     = PgLit l
patGroup (DVarPa {})    = error "Internal error in th-desugar (patGroup DVarP)"
patGroup (DConPa con _) = PgCon con
patGroup (DTildePa {})  = error "Internal error in th-desugar (patGroup DTildeP)"
patGroup (DBangPa {})   = PgBang
patGroup DWildPa        = PgAny

-- | Whether two groups have the same /kind/. The constructor name or
-- literal value is deliberately ignored here; 'subGroup' buckets by it
-- later.
sameGroup :: PatGroup -> PatGroup -> Bool
sameGroup PgAny     PgAny     = True
sameGroup PgBang    PgBang    = True
sameGroup (PgCon _) (PgCon _) = True
sameGroup (PgLit _) (PgLit _) = True
sameGroup _         _         = False

-- | Bucket equations by key (constructor name or literal). Within a
-- bucket the original equation order is kept (accumulated in reverse,
-- then reversed). The buckets themselves come out in ascending key order.
subGroup :: Ord a => [(a, EquationInfo)] -> [[EquationInfo]]
subGroup group
  = map reverse $ Map.elems $ foldl accumulate Map.empty group
  where
    accumulate pg_map (pg, eqn)
      = case Map.lookup pg pg_map of
          Just eqns -> Map.insert pg (eqn:eqns) pg_map
          Nothing   -> Map.insert pg [eqn]      pg_map

-- | The first remaining pattern of an equation. Errors if none remain.
firstPat :: EquationInfo -> DPat
firstPat (EquationInfo (pat : _) _) = pat
firstPat _ = error "Clause encountered with no patterns -- should never happen"

-- | One alternative of a generated constructor @case@.
-- (The leading underscores on the unused selectors presumably silence
-- unused-binding warnings.)
data CaseAlt = CaseAlt { alt_con   :: Name        -- con name
                       , _alt_args :: [Name]      -- bound var names
                       , _alt_rhs  :: MatchResult -- RHS
                       }

-- | Match a family of constructor patterns on the first scrutinee. There
-- is one 'CaseAlt' per constructor bucket, combined by 'mkDataConCase'.
-- from GHC's MatchCon.lhs
matchConFamily :: DsMonad q => [Name] -> [[EquationInfo]] -> q MatchResult
matchConFamily (var:vars) groups
  = do alts <- mapM (matchOneCon vars) groups
       mkDataConCase var alts
matchConFamily [] _ = error "Internal error in th-desugar (matchConFamily)"

-- | Build the alternative for one constructor. Fresh variables are chosen
-- for the constructor's arguments (from the first equation's
-- sub-patterns). Every equation's sub-patterns are then spliced in front
-- of its remaining patterns and matched against those fresh variables
-- followed by the remaining scrutinees.
-- like matchOneConLike from MatchCon
matchOneCon :: DsMonad q => [Name] -> [EquationInfo] -> q CaseAlt
matchOneCon vars eqns@(eqn1 : _)
  = do arg_vars <- selectMatchVars (pat_args pat1)
       match_result <- match_group arg_vars
       return $ CaseAlt (pat_con pat1) arg_vars match_result
  where
    pat1 = firstPat eqn1

    pat_args (DConPa _ pats) = pats
    pat_args _ = error "Internal error in th-desugar (pat_args)"

    pat_con (DConPa con _) = con
    pat_con _ = error "Internal error in th-desugar (pat_con)"

    match_group :: DsMonad q => [Name] -> q MatchResult
    match_group arg_vars
      = simplCase (arg_vars ++ vars) (map shift eqns)

    shift (EquationInfo (DConPa _ args : pats) exp) = EquationInfo (args ++ pats) exp
    shift _ = error "Internal error in th-desugar (shift)"
matchOneCon _ _ = error "Internal error in th-desugar (matchOneCon)"

-- | Build a @case var of ...@ with one alternative per 'CaseAlt'. A
-- wildcard alternative yielding the failure continuation is appended
-- unless every constructor of the data type is mentioned.
--
-- The constructor set is obtained by reifying the type of the first
-- alternative's constructor ('head' assumes at least one alternative).
-- If the reified declaration is not a 'DDataD' / 'DDataInstD', the set is
-- empty, and the case then counts as exhaustive (no wildcard is added).
-- If 'dsReify' does not return a 'DTyConI', the pattern-bind failure is
-- left to the monad's @fail@.
mkDataConCase :: DsMonad q => Name -> [CaseAlt] -> q MatchResult
mkDataConCase var case_alts = do
  all_ctors <- get_all_ctors (alt_con $ head case_alts)
  return $ \fail ->
    let matches = map (mk_alt fail) case_alts in
    DCaseE (DVarE var) (matches ++ mk_default all_ctors fail)
  where
    mk_alt fail (CaseAlt con args body_fn)
      = let body = body_fn fail in
        DMatch (DConPa con (map DVarPa args)) body

    mk_default all_ctors fail | exhaustive_case all_ctors = []
                              | otherwise = [DMatch DWildPa fail]

    mentioned_ctors = S.fromList $ map alt_con case_alts
    exhaustive_case all_ctors = all_ctors `S.isSubsetOf` mentioned_ctors

    get_all_ctors :: DsMonad q => Name -> q (S.Set Name)
    get_all_ctors con_name = do
      ty_name <- dataConNameToDataName con_name
      Just (DTyConI tycon_dec _) <- dsReify ty_name
      return $ S.fromList $ map get_con_name $ get_cons tycon_dec

    get_cons (DDataD _ _ _ _ cons _) = cons
    get_cons (DDataInstD _ _ _ _ cons _) = cons
    get_cons _ = []

    get_con_name (DCon _ _ n _ _) = n

-- | Result for an empty list of equations: @case var of _ -> fail@.
matchEmpty :: DsMonad q => Name -> q [MatchResult]
matchEmpty var = return [mk_seq]
  where
    mk_seq fail = DCaseE (DVarE var) [DMatch DWildPa fail]

-- | Match literal patterns on the first scrutinee. Each bucket (one
-- literal value) has its first pattern dropped and is matched against the
-- remaining scrutinees.
matchLiterals :: DsMonad q => [Name] -> [[EquationInfo]] -> q MatchResult
matchLiterals (var:vars) sub_groups
  = do alts <- mapM match_group sub_groups
       return (mkCoPrimCaseMatchResult var alts)
  where
    match_group :: DsMonad q => [EquationInfo] -> q (Lit, MatchResult)
    match_group eqns
      = do let DLitPa lit = firstPat (head eqns)
           match_result <- simplCase vars (shiftEqns eqns)
           return (lit, match_result)
matchLiterals [] _ = error "Internal error in th-desugar (matchLiterals)"

-- | @case var of { lit1 -> ...; ...; _ -> fail }@. A wildcard default is
-- always appended.
mkCoPrimCaseMatchResult :: Name -- Scrutinee
                        -> [(Lit, MatchResult)]
                        -> MatchResult
mkCoPrimCaseMatchResult var match_alts = mk_case
  where
    mk_case fail = let alts = map (mk_alt fail) match_alts in
                   DCaseE (DVarE var) (alts ++ [DMatch DWildPa fail])
    mk_alt fail (lit, body_fn)
      = DMatch (DLitPa lit) (body_fn fail)

-- | Match bang patterns. The outer @!@ is stripped from each first
-- pattern, the same scrutinees are matched again, and the result is
-- wrapped in @var \`seq\` ...@.
matchBangs :: DsMonad q => [Name] -> [EquationInfo] -> q MatchResult
matchBangs (var:vars) eqns
  = do match_result <- simplCase (var:vars) $
                       map (decomposeFirstPat getBangPat) eqns
       return (mkEvalMatchResult var match_result)
matchBangs [] _ = error "Internal error in th-desugar (matchBangs)"

-- | Apply a function to the first pattern of an equation.
decomposeFirstPat :: (DPat -> DPat) -> EquationInfo -> EquationInfo
decomposeFirstPat extractpat (EquationInfo (pat:pats) body)
  = EquationInfo (extractpat pat : pats) body
decomposeFirstPat _ _ = error "Internal error in th-desugar (decomposeFirstPat)"

-- | The pattern under a 'DBangPa'. Errors on anything else.
getBangPat :: DPat -> DPat
getBangPat (DBangPa p) = p
getBangPat _ = error "Internal error in th-desugar (getBangPat)"

-- | Build @seq var (body fail)@.
mkEvalMatchResult :: Name -> MatchResult -> MatchResult
mkEvalMatchResult var body_fn fail
  = foldl DAppE (DVarE 'seq) [DVarE var, body_fn fail]

-- | The first patterns are all wildcards after tidying, so drop them and
-- the first scrutinee, and match the rest.
matchVariables :: DsMonad q => [Name] -> [EquationInfo] -> q MatchResult
matchVariables (_:vars) eqns = simplCase vars (shiftEqns eqns)
matchVariables _ _ = error "Internal error in th-desugar (matchVariables)"

-- | Drop the first pattern of every equation ('tail' assumes one exists).
shiftEqns :: [EquationInfo] -> [EquationInfo]
shiftEqns = map shift
  where
    shift (EquationInfo pats rhs) = EquationInfo (tail pats) rhs

-- | Post-compose a wrapper onto a 'MatchResult'. The failure expression
-- is passed through unchanged and ends up inside the wrapper.
adjustMatchResult :: (DExp -> DExp) -> MatchResult -> MatchResult
adjustMatchResult wrap mr fail = wrap $ mr fail

-- | Fresh variable names for a list of sub-patterns (see 'selectMatchVar').
-- from DsUtils
selectMatchVars :: DsMonad q => [DPat] -> q [Name]
selectMatchVars = mapM selectMatchVar

-- | A fresh name for the value a sub-pattern is matched against. It is
-- derived from the variable's base name (prefixed with @_@) when the
-- pattern is a variable, possibly under @!@ or @~@, and is otherwise based
-- on @_pat@.
-- from DsUtils
selectMatchVar :: DsMonad q => DPat -> q Name
selectMatchVar (DBangPa pat) = selectMatchVar pat
selectMatchVar (DTildePa pat) = selectMatchVar pat
selectMatchVar (DVarPa var) = newUniqueName ('_' : nameBase var)
selectMatchVar _ = newUniqueName "_pat"

-- | Split a list into maximal runs in which every element is related by
-- the predicate to the run's first element.
-- like GHC's runs
runs :: (a -> a -> Bool) -> [a] -> [[a]]
runs _ [] = []
runs p (x:xs) = case span (p x) xs of
                  (first, rest) -> (x:first) : (runs p rest)