module Test.Hspec.Expectations.Match
( shouldMatch
, shouldReturnAndMatch
, assertDo
) where
import Control.Monad.Base (MonadBase, liftBase)
import Data.Maybe (fromMaybe)
import GHC.Stack (HasCallStack)
import Test.Hspec.Expectations (expectationFailure)
import Language.Haskell.TH.Ppr
import Language.Haskell.TH.Syntax
-- | Report that a value did not match a pattern, via hspec's
-- 'expectationFailure' lifted into the caller's base-'IO' monad.
--
-- The value is rendered with 'showsPrec' at precedence 11, so compound values
-- appear parenthesised. The returned @b@ is an 'error' placeholder that only
-- satisfies the result type; presumably it is never forced because
-- 'expectationFailure' aborts the test (NOTE(review): confirm).
assertPatternMatchFailure :: (HasCallStack, MonadBase IO m, Show a) => String -> a -> m b
assertPatternMatchFailure pat val = do
  liftBase (expectationFailure message)
  pure (error "assertPatternMatchFailure: internal error")
  where
    message = showsPrec 11 val "" ++ " failed to match pattern " ++ pat
-- | Build a @case@ expression that matches the value of the quoted expression
-- against the quoted pattern.
--
-- On a match, the generated code returns (via 'pure') the variables bound by
-- the pattern, packed by 'patBindingsToTupleExp'. Otherwise it binds the whole
-- value and passes it to 'assertPatternMatchFailure', together with the
-- pattern rendered by 'showsPat'.
shouldMatch :: Q Exp -> Q Pat -> Q Exp
shouldMatch qExpr qPat = do
  subject <- qExpr
  wanted <- qPat
  renderPat <- showsPat 11 wanted
  leftover <- newName "val"
  let onSuccess = AppE (VarE 'pure) (patBindingsToTupleExp wanted)
      onFailure =
        AppE
          (AppE (VarE 'assertPatternMatchFailure) (LitE (StringL (renderPat ""))))
          (VarE leftover)
      arms =
        [ Match wanted (NormalB onSuccess) []
        , Match (VarP leftover) (NormalB onFailure) []
        ]
  pure (CaseE subject arms)
-- | Like 'shouldMatch', but the quoted expression is a monadic action.
--
-- The generated code is @action >>= \\val -> case val of ...@. On a match it
-- returns (via 'pure') the pattern's bound variables, packed by
-- 'patBindingsToTupleExp'. Any other value goes to
-- 'assertPatternMatchFailure', together with the pattern rendered by
-- 'showsPat'.
shouldReturnAndMatch :: Q Exp -> Q Pat -> Q Exp
shouldReturnAndMatch qExpr qPat = do
  action <- qExpr
  wanted <- qPat
  renderPat <- showsPat 11 wanted
  result <- newName "val"
  let onSuccess = AppE (VarE 'pure) (patBindingsToTupleExp wanted)
      onFailure =
        AppE
          (AppE (VarE 'assertPatternMatchFailure) (LitE (StringL (renderPat ""))))
          (VarE result)
      arms =
        [ Match wanted (NormalB onSuccess) []
        , Match WildP (NormalB onFailure) []
        ]
      continuation = LamE [VarP result] (CaseE (VarE result) arms)
  pure (AppE (AppE (VarE '(>>=)) action) continuation)
-- | A tuple expression that references every variable bound by the pattern,
-- in the order 'patBindings' returns them.
patBindingsToTupleExp :: Pat -> Exp
patBindingsToTupleExp pat = TupE [VarE nm | nm <- patBindings pat]
-- | A tuple pattern that binds every variable of the given pattern, in the
-- order 'patBindings' returns them. This is the same order
-- 'patBindingsToTupleExp' uses, so the two line up.
patBindingsToTuplePat :: Pat -> Pat
patBindingsToTuplePat pat = TupP [VarP nm | nm <- patBindings pat]
-- | Every variable name bound by a pattern, collected left to right.
--
-- As-pattern names come before the names bound inside them. View-pattern
-- expressions and signature types bind nothing and are skipped.
patBindings :: Pat -> [Name]
patBindings pat = case pat of
    VarP nm -> [nm]
    AsP nm inner -> nm : patBindings inner
    LitP _ -> []
    WildP -> []
    TupP subPats -> fromAll subPats
    UnboxedTupP subPats -> fromAll subPats
    ConP _ subPats -> fromAll subPats
    ListP subPats -> fromAll subPats
    RecP _ fields -> fromAll (map snd fields)
    InfixP lhs _ rhs -> fromAll [lhs, rhs]
    UInfixP lhs _ rhs -> fromAll [lhs, rhs]
    ParensP inner -> patBindings inner
    TildeP inner -> patBindings inner
    BangP inner -> patBindings inner
    SigP inner _ -> patBindings inner
    ViewP _ inner -> patBindings inner
#if MIN_VERSION_GLASGOW_HASKELL(8,2,1,0)
    UnboxedSumP inner _ _ -> patBindings inner
#endif
  where
    -- Concatenate the bindings of several sub-patterns, preserving order.
    fromAll subPats = concatMap patBindings subPats
-- | Render a TH pattern as Haskell-like source text, for failure messages.
--
-- The 'Int' is the surrounding precedence, following the 'showsPrec'
-- convention. Constructor applications are parenthesised when it exceeds 10,
-- and their arguments are rendered at 11. The only 'Q' operation used is
-- 'reifyFixity', which infix patterns need to decide on parentheses.
showsPat :: Int -> Pat -> Q ShowS
showsPat prec p = case p of
    LitP lit -> pure $ showString (showLit lit)
    VarP nm -> pure $ showString (nameBase nm)
    TupP [] -> pure $ showString "()"
    TupP pats -> do
      pats' <- traverse (showsPat 0) pats
      pure $ showChar '(' . joinWith (showString ", ") pats' . showChar ')'
    UnboxedTupP [] -> pure $ showString "(# #)"
    UnboxedTupP pats -> do
      pats' <- traverse (showsPat 0) pats
      pure $ showString "(# " . joinWith (showString ", ") pats' . showString " #)"
    ConP nm [] -> pure $ showString (prefixConName nm)
    ConP nm pats -> do
      pats' <- traverse (showsPat 11) pats
      pure . showParen (prec > 10) $
        showString (prefixConName nm) . showChar ' ' . joinWith (showChar ' ') pats'
    InfixP patA nm patB -> showInfix patA nm patB
    UInfixP patA nm patB -> showInfix patA nm patB
    ParensP pat -> showParen True <$> showsPat 0 pat
    TildeP pat -> (showChar '~' .) <$> showsPat 11 pat
    BangP pat -> (showChar '!' .) <$> showsPat 11 pat
    AsP nm pat -> ((showString (nameBase nm) . showChar '@') .) <$> showsPat 11 pat
    WildP -> pure $ showChar '_'
    RecP nm [] -> pure $ showString (prefixConName nm) . showString " {}"
    RecP nm fieldPats -> do
      fieldPats' <- traverse showFieldPat fieldPats
      pure $
        showString (prefixConName nm) . showString " { "
          . joinWith (showString ", ") fieldPats' . showString " }"
    ListP [] -> pure $ showString "[]"
    ListP pats -> do
      pats' <- traverse (showsPat 0) pats
      pure $ showChar '[' . joinWith (showString ", ") pats' . showChar ']'
    SigP pat ty -> do
      pat' <- showsPat 10 pat
      pure . showParen (prec > 0) $ pat' . showString " :: " . showsPrec 10 (ppr ty)
    ViewP expr pat -> do
      pat' <- showsPat 10 pat
      pure . showParen (prec > 0) $ showsPrec 10 (ppr expr) . showString " -> " . pat'
#if MIN_VERSION_GLASGOW_HASKELL(8,2,1,0)
    UnboxedSumP pat alt arity -> do
      -- 'alt' is the 1-based alternative index within a sum of 'arity'
      -- alternatives: (alt - 1) bars go before the pattern and
      -- (arity - alt) bars after it, e.g. alt 2 of 3 renders as "(# | p | #)".
      pat' <- showsPat 0 pat
      pure $
        showString "(# " . showString (concat (replicate (alt - 1) "| ")) . pat'
          . showString (concat (replicate (arity - alt) " |")) . showString " #)"
#endif
  where
    -- Join rendered pieces with a separator. Every call site above passes a
    -- non-empty list, because the empty cases are matched first.
    joinWith sep = foldr1 (\s r -> s . sep . r)

    -- Render an infix constructor pattern, parenthesising it when the
    -- surrounding precedence is higher than the operator's fixity. Both
    -- operands are rendered one level tighter, so the output may carry
    -- redundant (but never missing) parentheses.
    showInfix patA nm patB = do
      Fixity nmPrec _ <- fromMaybe defaultFixity <$> reifyFixity nm
      patA' <- showsPat (nmPrec + 1) patA
      patB' <- showsPat (nmPrec + 1) patB
      pure . showParen (prec > nmPrec) $
        patA' . showChar ' ' . showString (infixConName nm) . showChar ' ' . patB'

    -- Render one record field binding as "field = pat".
    showFieldPat (nm, pat) =
      ((showString (nameBase nm) . showString " = ") .) <$> showsPat 0 pat

    -- Constructor operators (names starting with ':') must be parenthesised
    -- when written in prefix or record position, e.g. "(:|) x xs".
    prefixConName nm = case nameBase nm of
      base@(':' : _) -> '(' : base ++ ")"
      base -> base

    -- Alphanumeric constructors written infix need backticks, e.g. "a `Foo` b".
    -- Constructor operators (starting with ':') are shown as they are.
    infixConName nm = case nameBase nm of
      base@(':' : _) -> base
      base -> '`' : base ++ "`"
-- | Render a TH literal roughly as it would appear in Haskell source.
--
-- Fractional literals go through 'Double' (or 'Float' for 'FloatPrimL'), so
-- very precise rationals are approximated. Primitive literals keep their
-- @#@ / @##@ suffix.
showLit :: Lit -> String
showLit (CharL c) = show c
showLit (StringL s) = show s
showLit (IntegerL i) = show i
showLit (RationalL r) = show (fromRational r :: Double)
showLit (IntPrimL i) = show i ++ "#"
showLit (WordPrimL i) = show i ++ "##"
showLit (FloatPrimL r) = show (fromRational r :: Float) ++ "#"
showLit (DoublePrimL r) = show (fromRational r :: Double) ++ "##"
-- 'StringPrimL' stores raw bytes. Map each byte to the code point of the same
-- value (Latin-1) so the message reads "abc"# rather than [97,98,99]#.
showLit (StringPrimL s) = show (map (toEnum . fromIntegral) s :: String) ++ "#"
showLit (CharPrimL c) = show c ++ "#"
-- | Rewrite the top-level statements of a quoted @do@ block so that refutable
-- binds are checked explicitly.
--
-- A statement @pat <- expr@ whose pattern could fail to match is replaced with
-- a bind of the pattern's variables (as a tuple, via 'patBindingsToTuplePat')
-- to @'shouldReturnAndMatch' expr pat@. A mismatch is then reported through
-- 'assertPatternMatchFailure' instead of the @do@ block's own
-- pattern-match-failure handling. Other statements are left untouched, and
-- nested blocks are not visited. Fails if the quoted expression is not a @do@
-- block.
assertDo :: Q Exp -> Q Exp
assertDo qDoExp = qDoExp >>= \case
    DoE stmts -> DoE <$> traverse annotateStatement stmts
    _ -> fail "assertDo: expected a do block"
  where
    -- Only bind statements can fail to match. Plain variables and wildcards
    -- are skipped without any reification work.
    annotateStatement stmt@(BindS pat expr) = case pat of
      VarP _ -> pure stmt
      WildP -> pure stmt
      _ -> do
        hasOtherCases <- patternHasOtherCases pat
        if hasOtherCases
          then BindS (patBindingsToTuplePat pat) <$> shouldReturnAndMatch (pure expr) (pure pat)
          else pure stmt
    annotateStatement stmt = pure stmt

    -- Whether a constructor name could fail to match, i.e. its type has more
    -- than one data constructor. Anything this cannot decide cheaply is
    -- conservatively reported as 'True'; wrapping an irrefutable bind is
    -- harmless. That covers data-family constructors (whose parent reifies to
    -- 'FamilyI') and non-data-constructor names such as pattern synonyms.
    constructorHasOtherCases nm = reify nm >>= \case
      DataConI _ _ tyNm -> reify tyNm >>= \case
        TyConI dec -> case dec of
          DataD _ _ _ _ cons _ -> pure (length cons > 1)
          NewtypeD{} -> pure False
          _ -> fail ("patternHasOtherCases: internal error; unexpected declaration in TyConI: " ++ show dec)
        FamilyI{} -> pure True
        tyInfo -> fail ("patternHasOtherCases: internal error; unexpected Info in DataConI: " ++ show tyInfo)
      _ -> pure True

    -- Whether a pattern is refutable, i.e. some value of its type fails to
    -- match it. Literal, list and view patterns are treated as always
    -- refutable.
    patternHasOtherCases (LitP _) = pure True
    patternHasOtherCases (VarP _) = pure False
    patternHasOtherCases (TupP pats) = or <$> traverse patternHasOtherCases pats
    patternHasOtherCases (UnboxedTupP pats) = or <$> traverse patternHasOtherCases pats
    patternHasOtherCases (ConP nm pats) = do
      conHasOtherCases <- constructorHasOtherCases nm
      if conHasOtherCases
        then pure True
        else or <$> traverse patternHasOtherCases pats
    patternHasOtherCases (InfixP patA nm patB) = patternHasOtherCases (ConP nm [patA, patB])
    patternHasOtherCases (UInfixP patA nm patB) = patternHasOtherCases (ConP nm [patA, patB])
    patternHasOtherCases (ParensP pat) = patternHasOtherCases pat
    -- A lazy pattern always matches; any failure is deferred until a bound
    -- variable is forced. It never needs an explicit check, and wrapping it
    -- would only produce an unreachable fallback branch.
    patternHasOtherCases (TildeP _) = pure False
    patternHasOtherCases (BangP pat) = patternHasOtherCases pat
    patternHasOtherCases (AsP _ pat) = patternHasOtherCases pat
    patternHasOtherCases WildP = pure False
    -- A record pattern can fail on its constructor, not just on its fields
    -- (e.g. @Foo {}@ when @Foo@'s type has other constructors). Check it
    -- exactly like the equivalent prefix constructor pattern.
    patternHasOtherCases (RecP nm fieldPats) = patternHasOtherCases (ConP nm (map snd fieldPats))
    patternHasOtherCases (ListP _) = pure True
    patternHasOtherCases (SigP pat _) = patternHasOtherCases pat
    patternHasOtherCases (ViewP _ _) = pure True