module Language.Haskell.TH.Desugar.Util where
import Language.Haskell.TH
import Language.Haskell.TH.Syntax ( Quasi(..) )
import qualified Data.Set as S
import Data.Foldable
import Control.Applicative
-- | Reify a name with 'qReify'. If reification fails, the failure is
-- replaced (via 'qRecover') with a 'fail' whose message explains that a
-- declaration cannot be looked up from the same splice that defines it.
reifyWithWarning :: Quasi q => Name -> q Info
reifyWithWarning name = qRecover (fail lookupFailure) (qReify name)
  where
    lookupFailure =
      "Looking up " ++ show name ++ " in the list of available " ++
      "declarations failed.\nThis lookup fails if the declaration " ++
      "referenced was made in the same Template\nHaskell splice as the use " ++
      "of the declaration. If this is the case, put\nthe reference to " ++
      "the declaration in a new splice."
-- | Fail with the given message, followed by a note saying the situation
-- should not arise and asking the user to report it by email.
impossible :: Quasi q => String -> q a
impossible err = fail $ err ++ reportRequest
  where
    reportRequest =
      "\nThis should not happen in Haskell." ++
      "\nPlease email eir@cis.upenn.edu with your code if you see this."
-- | Reify a type-constructor name and return its type-variable binders and
-- its constructors. A @newtype@ yields a one-element constructor list.
-- Fails (appending @err@ to the message) when the name does not reify to a
-- @data@ or @newtype@ declaration.
getDataD :: Quasi q
         => String                   -- ^ extra text appended to the failure message
         -> Name                     -- ^ name of the type to inspect
         -> q ([TyVarBndr], [Con])
getDataD err name = reifyWithWarning name >>= fromInfo
  where
    fromInfo (TyConI (DataD _ _ tvbs cons _))   = return (tvbs, cons)
    fromInfo (TyConI (NewtypeD _ _ tvbs con _)) = return (tvbs, [con])
    fromInfo _                                  = badDeclaration

    badDeclaration =
      fail $ "The name (" ++ show name ++ ") refers to something " ++
             "other than a datatype. " ++ err
-- | Look up the full 'Con' definition for a data-constructor name. The name
-- is reified to find its parent type, the parent's constructors are fetched
-- with 'getDataD', and the one whose name matches is returned. Fails via
-- 'impossible' if the name is not a data constructor, or if the parent type
-- does not list it.
dataConNameToCon :: Quasi q => Name -> q Con
dataConNameToCon con_name = do
  parent <- reifyWithWarning con_name >>= parentOf
  (_, cons) <- getDataD "This seems to be an error in GHC." parent
  maybe (impossible "Datatype does not contain one of its own constructors.")
        return
        (find (\c -> conName c == con_name) cons)
  where
    parentOf (DataConI _ _ parent_name _) = return parent_name
    parentOf _ = impossible "Non-data-con used to construct a record."

    -- Constructor name, looking through any 'ForallC' wrappers.
    conName (NormalC n _)     = n
    conName (RecC n _)        = n
    conName (InfixC _ n _)    = n
    conName (ForallC _ _ con) = conName con
-- | Return the bound name if the pattern is a plain variable pattern
-- ('VarP'); otherwise 'Nothing'.
stripVarP_maybe :: Pat -> Maybe Name
stripVarP_maybe pat = case pat of
  VarP var -> Just var
  _        -> Nothing
-- | Return the variable name if the binder is a 'PlainTV' (a binder with no
-- kind annotation); otherwise 'Nothing'.
stripPlainTV_maybe :: TyVarBndr -> Maybe Name
stripPlainTV_maybe tvb = case tvb of
  PlainTV tv -> Just tv
  _          -> Nothing
-- | Collect the names a @do@/comprehension statement brings into scope:
-- a bind's pattern variables, a @let@'s declared names, and (for parallel
-- statements) the names bound by every inner statement. Expression
-- statements bind nothing.
extractBoundNamesStmt :: Stmt -> S.Set Name
extractBoundNamesStmt stmt = case stmt of
  BindS pat _ -> extractBoundNamesPat pat
  LetS decs   -> S.unions (map extractBoundNamesDec decs)
  NoBindS _   -> S.empty
  ParS stmtss -> S.unions [ extractBoundNamesStmt s
                          | branch <- stmtss
                          , s <- branch ]
-- | Names bound at the value level by a declaration. A function declaration
-- binds its own name, and a value declaration binds its pattern's variables.
-- Every other declaration form is treated as binding nothing.
extractBoundNamesDec :: Dec -> S.Set Name
extractBoundNamesDec dec = case dec of
  FunD fun_name _ -> S.singleton fun_name
  ValD pat _ _    -> extractBoundNamesPat pat
  _               -> S.empty
-- | All term variables bound by a pattern, including @as@-pattern names and
-- variables nested inside tuples, constructors, records, lists, infix
-- patterns, lazy/bang/paren wrappers, signatures and view patterns. Literal
-- and wildcard patterns bind nothing. The view function and signature type
-- are not searched.
extractBoundNamesPat :: Pat -> S.Set Name
extractBoundNamesPat = go
  where
    go (LitP _)            = S.empty
    go WildP               = S.empty
    go (VarP var)          = S.singleton var
    go (AsP var sub)       = S.insert var (go sub)
    go (TupP subs)         = unionAll subs
    go (UnboxedTupP subs)  = unionAll subs
    go (ConP _ subs)       = unionAll subs
    go (ListP subs)        = unionAll subs
    go (RecP _ field_pats) = unionAll (map snd field_pats)
    go (InfixP l _ r)      = go l `S.union` go r
    go (UInfixP l _ r)     = go l `S.union` go r
    go (ParensP sub)       = go sub
    go (TildeP sub)        = go sub
    go (BangP sub)         = go sub
    go (SigP sub _)        = go sub
    go (ViewP _ sub)       = go sub

    -- Union of the names bound by each sub-pattern.
    unionAll = S.unions . map go
-- | Effectful 'concatMap'. Runs the action on each element from left to
-- right and concatenates the resulting lists in order.
concatMapM :: Applicative m => (a -> m [b]) -> [a] -> m [b]
concatMapM f = walk
  where
    walk []           = pure []
    walk (x : rest)   = liftA2 (++) (f x) (walk rest)
-- | Split the second list after as many elements as the first list has
-- (the first list's elements are ignored; only its length matters). If the
-- first list is at least as long, the whole second list is the prefix.
splitAtList :: [a] -> [b] -> ([b], [b])
splitAtList (_ : counter) (y : ys) = (y : front, back)
  where
    (front, back) = splitAtList counter ys
splitAtList _ ys = ([], ys)