module Control.Monad.TestFixture.TH.Internal.TypesQuasi (ts) where
import Control.Monad ((<=<))
import Language.Haskell.Exts.Lexer
import Language.Haskell.Exts.Parser
import Language.Haskell.Exts.SrcLoc
import Language.Haskell.Meta.Syntax.Translate (toType)
import Language.Haskell.TH.Instances ()
import Language.Haskell.TH.Syntax hiding (Loc)
import Language.Haskell.TH.Quote
ts :: QuasiQuoter
ts = QuasiQuoter
{ quoteExp = \str -> case parseTypesSplitOnCommas str of
ParseOk tys -> lift =<< mapM resolveTypeNames tys
ParseFailed _ msg -> fail msg
, quotePat = error "ts can only be used in an expression context"
, quoteType = error "ts can only be used in an expression context"
, quoteDec = error "ts can only be used in an expression context"
}
parseTypesSplitOnCommas :: String -> ParseResult [Type]
parseTypesSplitOnCommas = fmap (map toType) . mapM parseType <=< lexSplitOnCommas
lexSplitOnCommas :: String -> ParseResult [String]
lexSplitOnCommas str = splitOnSrcSpans str <$> lexSplittingCommas str
splitOnSrcSpans :: String -> [SrcSpan] -> [String]
splitOnSrcSpans str [] = [str]
splitOnSrcSpans str spans@(x:xs) = case x of
SrcSpan { srcSpanStartLine = line, srcSpanStartColumn = col }
| line > 1 ->
let (l, _:ls) = break (== '\n') str
(r:rs) = splitOnSrcSpans ls (map advanceLine spans)
in (l ++ "\n" ++ r) : rs
| col > 1 ->
let (currentLs, nextLs) = span ((== line) . srcSpanStartLine) spans
(c:cs) = str
(r:rs) = splitOnSrcSpans cs (map advanceColumn currentLs ++ nextLs)
in (c : r) : rs
| otherwise ->
let (currentLs, nextLs) = span ((== line) . srcSpanStartLine) xs
(_:cs) = str
in "" : splitOnSrcSpans cs (map advanceColumn currentLs ++ nextLs)
advanceLine :: SrcSpan -> SrcSpan
advanceLine s@SrcSpan { srcSpanStartLine = line } = s { srcSpanStartLine = line 1 }
advanceColumn :: SrcSpan -> SrcSpan
advanceColumn s@SrcSpan { srcSpanStartColumn = col } = s { srcSpanStartColumn = col 1 }
lexSplittingCommas :: String -> ParseResult [SrcSpan]
lexSplittingCommas = fmap splittingCommas . lexTokenStream
splittingCommas :: [Loc Token] -> [SrcSpan]
splittingCommas = map loc . go
where go [] = []
go (x@Loc{ unLoc = Comma }:xs) = x : go xs
go (Loc{ unLoc = LeftParen }:xs) = go $ skipUntil RightParen xs
go (Loc{ unLoc = LeftSquare }:xs) = go $ skipUntil RightSquare xs
go (Loc{ unLoc = LeftCurly }:xs) = go $ skipUntil RightCurly xs
go (_:xs) = go xs
skipUntil _ [] = []
skipUntil d (Loc{ unLoc = LeftParen }:xs) = skipUntil d $ skipUntil RightParen xs
skipUntil d (Loc{ unLoc = LeftSquare }:xs) = skipUntil d $ skipUntil RightSquare xs
skipUntil d (Loc{ unLoc = LeftCurly }:xs) = skipUntil d $ skipUntil RightCurly xs
skipUntil d (Loc{ unLoc = t }:xs)
| t == d = xs
| otherwise = skipUntil d xs
resolveTypeNames :: Type -> Q Type
resolveTypeNames (ConT nm) = ConT <$> resolveTypeName nm
resolveTypeNames (ForallT tyVars ctx t) = ForallT tyVars <$> mapM resolveTypeNames ctx <*> resolveTypeNames t
resolveTypeNames (AppT a b) = AppT <$> resolveTypeNames a <*> resolveTypeNames b
resolveTypeNames (SigT t k) = SigT <$> resolveTypeNames t <*> resolveTypeNames k
resolveTypeNames t@VarT{} = return t
resolveTypeNames t@PromotedT{} = return t
resolveTypeNames t@TupleT{} = return t
resolveTypeNames t@UnboxedTupleT{} = return t
resolveTypeNames t@ArrowT{} = return t
resolveTypeNames t@EqualityT = return t
resolveTypeNames t@ListT = return t
resolveTypeNames t@PromotedTupleT{} = return t
resolveTypeNames t@PromotedNilT = return t
resolveTypeNames t@PromotedConsT = return t
resolveTypeNames t@StarT = return t
resolveTypeNames t@ConstraintT = return t
resolveTypeNames t@LitT{} = return t
#if MIN_VERSION_template_haskell(2,11,0)
resolveTypeNames (InfixT a n b) = InfixT <$> resolveTypeNames a <*> resolveTypeName n <*> resolveTypeNames b
resolveTypeNames (UInfixT a n b) = UInfixT <$> resolveTypeNames a <*> resolveTypeName n <*> resolveTypeNames b
resolveTypeNames (ParensT t) = ParensT <$> resolveTypeNames t
resolveTypeNames t@WildCardT = return t
#endif
#if MIN_VERSION_template_haskell(2,12,0)
resolveTypeNames t@UnboxedSumT{} = return t
#endif
resolveTypeName :: Name -> Q Name
resolveTypeName (Name (OccName str) NameS) = lookupTypeName str >>= \case
Just nm -> return nm
Nothing -> fail $ "unbound type name ‘" ++ str ++ "’"
resolveTypeName nm = return nm