-- |
-- This module implements the desugaring pass which replaces top-level binders
-- with case expressions.
--
module Language.PureScript.Sugar.CaseDeclarations
  ( desugarCases
  , desugarCasesModule
  , desugarCaseGuards
  ) where

import Prelude
import Protolude (ordNub)

import Data.List (groupBy, foldl1')
import Data.Maybe (catMaybes, mapMaybe)

import Control.Monad ((<=<), forM, replicateM, join, unless)
import Control.Monad.Error.Class (MonadError(..))
import Control.Monad.Supply.Class (MonadSupply)

import Language.PureScript.AST
import Language.PureScript.Crash (internalError)
import Language.PureScript.Environment (NameKind(..))
import Language.PureScript.Errors (ErrorMessage(..), MultipleErrors(..), SimpleErrorMessage(..), addHint, errorMessage', parU, rethrow, withPosition)
import Language.PureScript.Names (pattern ByNullSourcePos, Ident, Qualified(..), freshIdent')
import Language.PureScript.TypeChecker.Monad (guardWith)

-- |
-- Replace all top-level binders in a module with case expressions.
--
-- The declarations are first checked by 'validateCases' (case head and binder
-- counts must agree), then lambdas with non-variable binders are rewritten by
-- 'desugarAbs', and finally multi-clause / pattern-matching declarations are
-- turned into case expressions by 'desugarCases'. Any error raised is tagged
-- with the module name.
--
desugarCasesModule
  :: (MonadSupply m, MonadError MultipleErrors m)
  => Module
  -> m Module
desugarCasesModule (Module ss coms name ds exps) =
  rethrow (addHint (ErrorInModule name)) $
    Module ss coms name
      <$> (desugarCases <=< desugarAbs <=< validateCases $ ds)
      <*> pure exps

-- |
-- Desugar guards (condition and pattern guards) in every case expression
-- reachable from the given declarations, using each declaration's own source
-- span for the generated nodes. See 'desugarGuardedExprs'.
--
desugarCaseGuards
  :: forall m. (MonadSupply m, MonadError MultipleErrors m)
  => [Declaration]
  -> m [Declaration]
desugarCaseGuards declarations = parU declarations go
  where
    go d =
      let (f, _, _) = everywhereOnValuesM return (desugarGuardedExprs (declSourceSpan d)) return
      in f d

-- |
-- Desugar case with pattern guards and pattern clauses to a
-- series of nested case expressions.
--
desugarGuardedExprs
  :: forall m. (MonadSupply m)
  => SourceSpan
  -> Expr
  -> m Expr
desugarGuardedExprs ss (Case scrut alternatives)
  | not $ all isTrivialExpr scrut = do
      -- In case the scrutinee is non trivial (e.g. not a Var or Literal)
      -- we may evaluate the scrutinee more than once when a guard occurs.
      -- We bind each scrutinee to a fresh Var here to mitigate this case,
      -- then desugar the (now trivial) case inside the generated let.
      (scrut', scrut_decls) <- fmap unzip . forM scrut $ \e -> do
        scrut_id <- freshIdent'
        pure ( Var ss (Qualified ByNullSourcePos scrut_id)
             , ValueDecl (ss, []) scrut_id Private [] [MkUnguarded e]
             )
      Let FromLet scrut_decls <$> desugarGuardedExprs ss (Case scrut' alternatives)
  where
    isTrivialExpr (Var _ _) = True
    isTrivialExpr (Literal _ _) = True
    isTrivialExpr (Accessor _ e) = isTrivialExpr e
    isTrivialExpr (Parens e) = isTrivialExpr e
    isTrivialExpr (PositionedValue _ _ e) = isTrivialExpr e
    isTrivialExpr (TypedValue _ e _) = isTrivialExpr e
    isTrivialExpr _ = False

desugarGuardedExprs ss (Case scrut alternatives) =
  let
    -- Alternatives which do not have guards are
    -- left as-is. Alternatives which
    --
    --   1) have multiple clauses of the form
    --        binder | g_1
    --               , g_2
    --               , ...
    --               , g_n
    --               -> expr
    --
    --   2) and/or contain pattern guards of the form
    --        binder | pat_bind <- e
    --               , ...
    --
    -- are desugared to a sequence of nested case expressions.
    --
    -- Consider an example case expression:
    --
    --   case e of
    --     (T s) | Just info <- Map.lookup s names
    --           , is_used info
    --           -> f info
    --
    -- We desugar this to
    --
    --   case e of
    --     (T s) -> case Map.lookup s names of
    --                Just info -> case is_used info of
    --                               True -> f info
    --                               (_ -> <partial>)
    --                (_ -> <partial>)
    --
    -- Note that if the original case is partial the desugared
    -- case is also partial.
    --
    -- Consider an exhaustive case expression:
    --
    --   case e of
    --     (T s) | Just info <- Map.lookup s names
    --           , is_used info
    --           -> f info
    --     _     -> Nothing
    --
    -- desugars to:
    --
    --   case e of
    --     _ -> let
    --            v _ = Nothing
    --          in
    --            case e of
    --              (T s) -> case Map.lookup s names of
    --                         Just info -> case is_used info of
    --                                        True -> f info
    --                                        _ -> v true
    --                         _ -> v true
    --              _ -> v true
    --
    -- This might look strange but simplifies the algorithm a lot.
    --
    desugarAlternatives :: [CaseAlternative] -> m [CaseAlternative]
    desugarAlternatives [] = pure []

    -- the trivial case: no guards
    desugarAlternatives (a@(CaseAlternative _ [MkUnguarded _]) : as) =
      (a :) <$> desugarAlternatives as

    -- Special case: CoreFn understands single condition guards on
    -- binders right hand side, so a leading run of them is kept as-is
    -- and only the remainder is desugared.
    desugarAlternatives (CaseAlternative ab ge : as)
      | not (null cond_guards) =
          (CaseAlternative ab cond_guards :)
            <$> desugarGuardedAlternative ab rest as
      | otherwise = desugarGuardedAlternative ab ge as
      where
        (cond_guards, rest) = span isSingleCondGuard ge

        isSingleCondGuard (GuardedExpr [ConditionGuard _] _) = True
        isSingleCondGuard _ = False

    desugarGuardedAlternative
      :: [Binder]
      -> [GuardedExpr]
      -> [CaseAlternative]
      -> m [CaseAlternative]
    desugarGuardedAlternative _vb [] rem_alts =
      desugarAlternatives rem_alts

    desugarGuardedAlternative vb (GuardedExpr gs e : ge) rem_alts = do
      rhs <- desugarAltOutOfLine vb ge rem_alts $ \alt_fail ->
        let
          -- if the binder is a var binder we must not add
          -- the fail case as it results in unreachable
          -- alternative
          alt_fail' n
            | all isIrrefutable vb = []
            | otherwise = alt_fail n

          -- we are here:
          --
          -- case scrut of
          --   ...
          --   _ -> let
          --         v _ = <out of line case>
          --        in case scrut of -- we are here
          --            ...
          --
        in Case scrut
             (CaseAlternative vb [MkUnguarded (desugarGuard gs e alt_fail)]
               : alt_fail' (length scrut))

      return [CaseAlternative scrut_nullbinder [MkUnguarded rhs]]

    -- Turn a list of guards into nested case expressions ending in `e`.
    -- `match_failed n` yields the fallback alternative(s) for a case with
    -- n scrutinees (possibly none).
    desugarGuard :: [Guard] -> Expr -> (Int -> [CaseAlternative]) -> Expr
    desugarGuard [] e _ = e
    desugarGuard (ConditionGuard c : gs) e match_failed
      | isTrueExpr c = desugarGuard gs e match_failed
      | otherwise =
          Case [c]
            (CaseAlternative [LiteralBinder ss (BooleanLiteral True)]
              [MkUnguarded (desugarGuard gs e match_failed)]
              : match_failed 1)

    desugarGuard (PatternGuard vb g : gs) e match_failed =
      Case [g]
        (CaseAlternative [vb] [MkUnguarded (desugarGuard gs e match_failed)]
          : match_failed')
      where
        -- don't consider match_failed case if the binder is irrefutable
        match_failed'
          | isIrrefutable vb = []
          | otherwise = match_failed 1

    -- We generate a let-binding for the remaining guards and
    -- alternatives. A function producing the fallback CaseAlternative
    -- (one that branches to the generated let-binding) is passed to
    -- mk_body; when nothing remains (the original case is partial at
    -- this point) a function producing no alternatives is passed instead.
    desugarAltOutOfLine
      :: [Binder]
      -> [GuardedExpr]
      -> [CaseAlternative]
      -> ((Int -> [CaseAlternative]) -> Expr)
      -> m Expr
    desugarAltOutOfLine alt_binder rem_guarded rem_alts mk_body
      | Just rem_case <- mkCaseOfRemainingGuardsAndAlts = do
          desugared <- desugarGuardedExprs ss rem_case
          rem_case_id <- freshIdent'
          unused_binder <- freshIdent'

          let
            goto_rem_case :: Expr
            goto_rem_case = Var ss (Qualified ByNullSourcePos rem_case_id)
              `App` Literal ss (BooleanLiteral True)

            alt_fail :: Int -> [CaseAlternative]
            alt_fail n = [CaseAlternative (replicate n NullBinder) [MkUnguarded goto_rem_case]]

          -- The remaining case is wrapped in a lambda over an unused
          -- binder so that it is only evaluated when branched to.
          pure $ Let FromLet
            [ ValueDecl (ss, []) rem_case_id Private []
                [MkUnguarded (Abs (VarBinder ss unused_binder) desugared)]
            ]
            (mk_body alt_fail)
      | otherwise = pure $ mk_body (const [])
      where
        mkCaseOfRemainingGuardsAndAlts
          | not (null rem_guarded) =
              Just $ Case scrut (CaseAlternative alt_binder rem_guarded : rem_alts)
          | not (null rem_alts) =
              Just $ Case scrut rem_alts
          | otherwise =
              Nothing

    scrut_nullbinder :: [Binder]
    scrut_nullbinder = replicate (length scrut) NullBinder

    -- case expressions with a single alternative which have
    -- a NullBinder occur frequently after desugaring
    -- complex guards. This function removes these superfluous
    -- cases.
    optimize :: Expr -> Expr
    optimize (Case _ [CaseAlternative vb [MkUnguarded v]])
      | all isNullBinder vb = v
      where
        isNullBinder NullBinder = True
        isNullBinder (PositionedBinder _ _ b) = isNullBinder b
        isNullBinder (TypedBinder _ b) = isNullBinder b
        isNullBinder _ = False
    optimize e = e
  in do
    alts' <- desugarAlternatives alternatives
    return $ optimize (Case scrut alts')

desugarGuardedExprs ss (TypedValue inferred e ty) =
  TypedValue inferred <$> desugarGuardedExprs ss e <*> pure ty

desugarGuardedExprs _ (PositionedValue ss comms e) =
  PositionedValue ss comms <$> desugarGuardedExprs ss e

desugarGuardedExprs _ v = pure v

-- |
-- Validates that case head and binder lengths match.
--
-- Every case expression is checked; all alternatives whose binder count
-- differs from the number of scrutinees are reported together as
-- 'CaseBinderLengthDiffers' errors.
--
validateCases
  :: forall m. (MonadSupply m, MonadError MultipleErrors m)
  => [Declaration]
  -> m [Declaration]
validateCases = flip parU f
  where
    (f, _, _) = everywhereOnValuesM return validate return

    validate :: Expr -> m Expr
    validate c@(Case vs alts) = do
      let l = length vs
          alts' = filter ((l /=) . length . caseAlternativeBinders) alts
      unless (null alts') $
        throwError . MultipleErrors $ fmap (altError l) (caseAlternativeBinders <$> alts')
      return c
    validate other = return other

    -- Build the error for one mismatching alternative. Its position is the
    -- union of the spans of its positioned binders; when none of the
    -- binders carries a position the error is reported without one
    -- (previously `foldl1'` crashed on the empty list here).
    altError :: Int -> [Binder] -> ErrorMessage
    altError l bs =
      case mapMaybe positionedBinder bs of
        [] -> err
        spans -> withPosition (foldl1' widenSpan spans) err
      where
        err = ErrorMessage [] $ CaseBinderLengthDiffers l bs

        widenSpan (SourceSpan n start end) (SourceSpan _ start' end') =
          SourceSpan n (min start start') (max end end')

        positionedBinder (PositionedBinder p _ _) = Just p
        positionedBinder _ = Nothing

-- |
-- Rewrite lambdas whose binder is not a (possibly positioned) variable:
-- `\pat -> body` becomes `\fresh -> case fresh of pat -> body`.
-- Lambdas over a variable binder keep it, with any PositionedBinder
-- wrappers removed.
--
desugarAbs
  :: forall m. (MonadSupply m, MonadError MultipleErrors m)
  => [Declaration]
  -> m [Declaration]
desugarAbs = flip parU f
  where
    (f, _, _) = everywhereOnValuesM return replace return

    replace :: Expr -> m Expr
    replace (Abs (stripPositioned -> (VarBinder ss i)) val) = pure (Abs (VarBinder ss i) val)
    replace (Abs binder val) = do
      ident <- freshIdent'
      return $ Abs (VarBinder nullSourceSpan ident) $
        Case [Var nullSourceSpan (Qualified ByNullSourcePos ident)]
          [CaseAlternative [binder] [MkUnguarded val]]
    replace other = return other

-- | Remove any number of outer 'PositionedBinder' wrappers.
stripPositioned :: Binder -> Binder
stripPositioned (PositionedBinder _ _ binder) = stripPositioned binder
stripPositioned binder = binder

-- |
-- Replace all top-level binders with case expressions.
--
-- Adjacent value declarations with the same identifier are grouped and each
-- group is converted by 'toDecls'. Afterwards the pass recurses into type
-- instance bodies and into the declarations of `let` expressions.
--
desugarCases
  :: forall m. (MonadSupply m, MonadError MultipleErrors m)
  => [Declaration]
  -> m [Declaration]
desugarCases = desugarRest <=< fmap join . flip parU toDecls . groupBy inSameGroup
  where
    desugarRest :: [Declaration] -> m [Declaration]
    desugarRest (TypeInstanceDeclaration sa na cd idx name constraints className tys ds : rest) =
      (:)
        <$> (TypeInstanceDeclaration sa na cd idx name constraints className tys
              <$> traverseTypeInstanceBody desugarCases ds)
        <*> desugarRest rest
    desugarRest (ValueDecl sa name nameKind bs result : rest) =
      let (_, f, _) = everywhereOnValuesTopDownM return go return
          f' = mapM (\(GuardedExpr gs e) -> GuardedExpr gs <$> f e)
      in (:) <$> (ValueDecl sa name nameKind bs <$> f' result) <*> desugarRest rest
      where
        go (Let w ds val') = Let w <$> desugarCases ds <*> pure val'
        go other = return other
    desugarRest (d : ds) = (:) d <$> desugarRest ds
    desugarRest [] = pure []

-- | Two declarations belong to the same group when both are value
-- declarations for the same identifier.
inSameGroup :: Declaration -> Declaration -> Bool
inSameGroup (ValueDeclaration vd1) (ValueDeclaration vd2) = valdeclIdent vd1 == valdeclIdent vd2
inSameGroup _ _ = False

-- |
-- Convert one group of same-named declarations.
--
-- * A single unguarded declaration whose binders are all irrefutable becomes
--   a chain of lambdas (fresh names for null binders); duplicate argument
--   names raise 'OverlappingArgNames'.
-- * Any other group of value declarations must have the same number of
--   binders in every clause ('ArgListLengthsDiffer' otherwise) and is turned
--   into one case declaration by 'makeCaseDeclaration'. A group whose first
--   clause has no binders and an unguarded body raises
--   'DuplicateValueDeclaration' (presumably only reachable with 2+ clauses,
--   since a single such clause matches the first equation).
-- * Non-value groups are returned unchanged.
--
toDecls
  :: forall m. (MonadSupply m, MonadError MultipleErrors m)
  => [Declaration]
  -> m [Declaration]
toDecls [ValueDecl sa@(ss, _) ident nameKind bs [MkUnguarded val]]
  | all isIrrefutable bs = do
      args <- mapM fromVarBinder bs
      let body = foldr (Abs . VarBinder ss) val args
      guardWith (errorMessage' ss (OverlappingArgNames (Just ident))) $
        length (ordNub args) == length args
      return [ValueDecl sa ident nameKind [] [MkUnguarded body]]
  where
    fromVarBinder :: Binder -> m Ident
    fromVarBinder NullBinder = freshIdent'
    fromVarBinder (VarBinder _ name) = return name
    fromVarBinder (PositionedBinder _ _ b) = fromVarBinder b
    fromVarBinder (TypedBinder _ b) = fromVarBinder b
    fromVarBinder _ = internalError "fromVarBinder: Invalid argument"
toDecls ds@(ValueDecl (ss, _) ident _ bs (result : _) : _) = do
  let tuples = map toTuple ds

      isGuarded (MkUnguarded _) = False
      isGuarded _ = True

  unless (all ((== length bs) . length . fst) tuples) .
    throwError . errorMessage' ss $ ArgListLengthsDiffer ident
  unless (not (null bs) || isGuarded result) .
    throwError . errorMessage' ss $ DuplicateValueDeclaration ident
  caseDecl <- makeCaseDeclaration ss ident tuples
  return [caseDecl]
toDecls ds = return ds

-- | Extract the binders and guarded bodies of a value declaration.
toTuple :: Declaration -> ([Binder], [GuardedExpr])
toTuple (ValueDecl _ _ _ bs result) = (bs, result)
toTuple _ = internalError "Not a value declaration"

-- |
-- Build `ident = \a_1 ... a_n -> case a_1 ... a_n of <one alternative per
-- clause>` from the clauses of a multi-clause (or pattern-matching)
-- declaration.
--
-- Expects a non-empty list of clauses (`foldl1` below); 'toDecls' always
-- passes a non-empty group.
--
-- NOTE(review): the result is always 'Public' with an empty comment list,
-- regardless of the original clauses' name kind and comments — confirm this
-- is intended.
--
makeCaseDeclaration
  :: forall m. (MonadSupply m)
  => SourceSpan
  -> Ident
  -> [([Binder], [GuardedExpr])]
  -> m Declaration
makeCaseDeclaration ss ident alternatives = do
  let namedArgs = map findName . fst <$> alternatives
      argNames = foldl1 resolveNames namedArgs
  -- Uniqueness is checked on identifiers alone: the same name bound at two
  -- different source positions would otherwise produce duplicate lambda
  -- parameters.
  args <-
    if allUnique (snd <$> catMaybes argNames)
      then mapM argName argNames
      else replicateM (length argNames) ((nullSourceSpan, ) <$> freshIdent')
  let vars = map (Var ss . Qualified ByNullSourcePos . snd) args
      binders = [CaseAlternative bs result | (bs, result) <- alternatives]
      value = foldr (Abs . uncurry VarBinder) (Case vars binders) args
  return $ ValueDecl (ss, []) ident Public [] [MkUnguarded value]
  where
    -- We will construct a table of potential names.
    -- VarBinders will become Just _ which is a potential name.
    -- Everything else becomes Nothing, which indicates that we
    -- have to generate a name.
    findName :: Binder -> Maybe (SourceSpan, Ident)
    findName (VarBinder ss' name) = Just (ss', name)
    findName (PositionedBinder _ _ binder) = findName binder
    findName _ = Nothing

    -- We still have to make sure the generated names are unique, or else
    -- we will end up constructing an invalid function.
    allUnique :: (Ord a) => [a] -> Bool
    allUnique xs = length xs == length (ordNub xs)

    argName :: Maybe (SourceSpan, Ident) -> m (SourceSpan, Ident)
    argName (Just (ss', name)) = return (ss', name)
    argName _ = (nullSourceSpan, ) <$> freshIdent'

    -- Combine two lists of potential names from two case alternatives
    -- by zipping corresponding columns.
    resolveNames
      :: [Maybe (SourceSpan, Ident)]
      -> [Maybe (SourceSpan, Ident)]
      -> [Maybe (SourceSpan, Ident)]
    resolveNames = zipWith resolveName

    -- Resolve a pair of names. A name is kept only when both columns hold
    -- the same (position, identifier) pair; differing names, or a
    -- non-variable binder on either side, result in Nothing (so a fresh
    -- name is generated for that column).
    resolveName
      :: Maybe (SourceSpan, Ident)
      -> Maybe (SourceSpan, Ident)
      -> Maybe (SourceSpan, Ident)
    resolveName (Just a) (Just b)
      | a == b = Just a
      | otherwise = Nothing
    resolveName _ _ = Nothing