module Language.Haskell.FreeTheorems.Parser.Haskell98 (parse) where
import Control.Monad (foldM, liftM, liftM2)
import Control.Monad.Error (throwError)
import Control.Monad.Writer (Writer, tell)
import Data.Generics (everywhere, mkT)
import Data.List (nub)
import Language.Haskell.Parser (parseModule, ParseResult(..))
import Language.Haskell.Syntax
import Text.PrettyPrint
import qualified Language.Haskell.FreeTheorems.Syntax as S
import Language.Haskell.FreeTheorems.Frontend.Error
-- | Parses a Haskell98 module given as text and converts every supported
--   top-level declaration (type synonyms, @data@, @newtype@, type classes
--   and type signatures) into the free-theorems syntax.
--
--   Errors never abort the whole run: a syntax error is reported via 'tell'
--   and yields no declarations at all, while a single declaration that
--   cannot be converted is reported via 'tell' and skipped. The remaining
--   declarations are returned in source order.
parse :: String -> Parsed [S.Declaration]
parse text = case parseModule text of
    ParseOk hsModule ->
      let decls = transform . filterDeclarations $ hsModule
      -- Accumulate in reverse (O(1) cons per declaration) and restore the
      -- source order once at the end; appending with (++) was quadratic.
      in  liftM reverse (foldM collectDeclarations [] decls)
    ParseFailed l _ -> do
      tell [pp ("Parse error at (" ++ show (srcLine l)
                ++ ":" ++ show (srcColumn l) ++ ").")]
      return []
  where
    -- Converts one declaration. On success it is prepended to the reversed
    -- accumulator; on failure the error is recorded and the accumulator is
    -- passed on unchanged.
    collectDeclarations :: [S.Declaration] -> HsDecl -> Parsed [S.Declaration]
    collectDeclarations ds d =
      case mkDeclaration d of
        Left e   -> tell [e] >> return ds
        Right d' -> return (d' : ds)
-- | Keeps only the top-level declarations the conversion understands:
--   type synonyms, @data@ and @newtype@ declarations, type classes and type
--   signatures. Every other declaration is dropped silently.
filterDeclarations :: HsModule -> [HsDecl]
filterDeclarations (HsModule _ _ _ _ topDecls) =
  [ d | d <- topDecls, supported d ]
  where
    supported :: HsDecl -> Bool
    supported (HsTypeDecl {})    = True
    supported (HsDataDecl {})    = True
    supported (HsNewTypeDecl {}) = True
    supported (HsClassDecl {})   = True
    supported (HsTypeSig {})     = True
    supported _                  = False
-- | Splits every type signature that names several variables at once
--   (@f, g :: t@) into one signature per variable (@f :: t@, @g :: t@).
--   'everywhere' visits every declaration list bottom-up, including every
--   tail and nested lists such as class bodies, so each signature is split.
transform :: [HsDecl] -> [HsDecl]
transform = everywhere (mkT splitSignature)
  where
    splitSignature :: [HsDecl] -> [HsDecl]
    splitSignature (HsTypeSig loc names ty : rest) =
      [ HsTypeSig loc [name] ty | name <- names ] ++ rest
    splitSignature decls = decls
-- | Converts a single declaration into the free-theorems syntax.
--
--   For the supported declarations, any conversion error is prefixed with
--   the declaration's name and source location (see 'addErr'). Type classes
--   must constrain exactly one type variable. Type signatures must name
--   exactly one variable ('transform' splits the others).
--
--   Any other declaration yields an error instead of a pattern-match failure.
mkDeclaration :: HsDecl -> ErrorOr S.Declaration
mkDeclaration decl = case decl of
  HsTypeDecl l n vs t               -> addErr l n (mkType n vs t)
  HsDataDecl l _ n vs cs _          -> addErr l n (mkData n vs cs)
  HsNewTypeDecl l _ n vs c _        -> addErr l n (mkNewtype n vs c)
  HsClassDecl l scs n [v] ds        -> addErr l n (mkClass scs n v ds)
  HsTypeSig l [n] (HsQualType cx t) -> addErr l n (mkSignature cx n t)
  HsClassDecl l _ n [] _            -> addErr l n (throwError missingVar)
  HsClassDecl l _ n (_:_:_) _       -> addErr l n (throwError noMultiParam)
  -- Not reached for input prepared by 'filterDeclarations' and 'transform',
  -- but keeps the function total for any other input.
  _ -> throwError (pp "Unsupported kind of declaration.")

-- | Error for a type class declaration without a type variable.
missingVar :: Doc
missingVar = pp "Missing type variable to be constrained by type class."

-- | Error for a type class declaration with more than one type variable.
noMultiParam :: Doc
noMultiParam = pp "Multi-parameter type classes are not allowed."
-- | Places the error of a failed declaration conversion under a header that
--   names the declaration and its source position. A successful result is
--   passed through unchanged.
addErr :: SrcLoc -> HsName -> ErrorOr S.Declaration -> ErrorOr S.Declaration
addErr loc name result =
  maybe result (throwError . inContext) (getError result)
  where
    inContext detail = header $$ nest 2 detail
    header = pp ("In the declaration of `" ++ hsNameToString name
                 ++ "' at (" ++ position ++ "):")
    position = show (srcLine loc) ++ ":" ++ show (srcColumn loc)
-- | Converts a type synonym declaration @type name vars = ty@.
mkType :: HsName -> [HsName] -> HsType -> ErrorOr S.Declaration
mkType name vars ty =
  mkIdentifier name        >>= \ident ->
  mapM mkTypeVariable vars >>= \tvs   ->
  liftM (S.TypeDecl . S.Type ident tvs) (mkTypeExpression ty)
-- | Converts a @data@ declaration with its type parameters and constructors.
mkData :: HsName -> [HsName] -> [HsConDecl] -> ErrorOr S.Declaration
mkData name vars cons =
  mkIdentifier name        >>= \ident ->
  mapM mkTypeVariable vars >>= \tvs   ->
  liftM (S.DataDecl . S.Data ident tvs)
        (mapM mkDataConstructorDeclaration cons)
-- | Converts one constructor of a @data@ declaration. A record constructor
--   is treated like an ordinary constructor with one argument per field
--   name, so @C { a, b :: t }@ is handled like @C t t@.
mkDataConstructorDeclaration ::
  HsConDecl -> ErrorOr S.DataConstructorDeclaration
mkDataConstructorDeclaration con = case con of
  HsConDecl _ name args   -> mkDataConDecl name args
  HsRecDecl _ name fields ->
    mkDataConDecl name
      [ fieldTy | (fieldNames, fieldTy) <- fields, _ <- fieldNames ]
-- | Builds a data constructor from its name and argument types. For each
--   argument it records whether it carries a strictness flag (@!@).
mkDataConDecl ::
  HsName -> [HsBangType] -> ErrorOr S.DataConstructorDeclaration
mkDataConDecl name btys =
  mkIdentifier name >>= \ident ->
  liftM (S.DataCon ident) (mapM convert btys)
  where
    convert bty = case bty of
      HsBangedTy ty   -> liftM S.Banged (mkTypeExpression ty)
      HsUnBangedTy ty -> liftM S.Unbanged (mkTypeExpression ty)
-- | Converts a @newtype@ declaration. Its constructor must have exactly one
--   argument (for a record constructor: exactly one field name), and that
--   argument must not carry a strictness flag.
mkNewtype :: HsName -> [HsName] -> HsConDecl -> ErrorOr S.Declaration
mkNewtype name vars con = do
  ident <- mkIdentifier name
  tvs <- mapM mkTypeVariable vars
  (conIdent, t) <- mkNewtypeConDecl con
  return (S.NewtypeDecl (S.Newtype ident tvs conIdent t))
  where
    mkNewtypeConDecl (HsConDecl _ c bts) = mkNCD c bts
    -- Count one argument per field name, exactly as
    -- 'mkDataConstructorDeclaration' does, so that a field group such as
    -- @{ x, y :: t }@ counts as two arguments and is rejected.
    mkNewtypeConDecl (HsRecDecl _ c fields) =
      mkNCD c [ bty | (fieldNames, bty) <- fields, _ <- fieldNames ]

    mkNCD c [bty] = liftM2 (,) (mkIdentifier c) (bang bty)
    mkNCD _ _     = throwError errNewtype

    errNewtype =
      pp "A `newtype' declaration must have exactly one type expression."

    bang (HsUnBangedTy ty) = mkTypeExpression ty
    bang (HsBangedTy _) =
      throwError (pp "A `newtype' declaration must not use a strictness flag.")
-- | Converts a single-parameter type class declaration.
--
--   The superclass context may only constrain the class's own type
--   variable. Of the class body only the type signatures are kept; all
--   other member declarations are ignored.
mkClass :: HsContext -> HsName -> HsName -> [HsDecl] -> ErrorOr S.Declaration
mkClass ctx name var decls = do
  ident   <- mkIdentifier name
  tv      <- mkTypeVariable var
  superCs <- mkContext ctx >>= onlyConstrains tv
  members <- mapM mkDeclaration [ d | d@(HsTypeSig {}) <- decls ]
  return (S.ClassDecl (S.Class superCs ident tv (map unSig members)))
  where
    -- Converting a type signature always yields an 'S.TypeSig'.
    unSig :: S.Declaration -> S.Signature
    unSig (S.TypeSig sig) = sig

    onlyConstrains ::
         S.TypeVariable
      -> [(S.TypeClass, S.TypeVariable)]
      -> ErrorOr [S.TypeClass]
    onlyConstrains tv@(S.TV (S.Ident v)) constraints
      | all (== tv) constrained = return classes
      | otherwise               = throwError (errClass v)
      where
        (classes, constrained) = unzip constraints

    errClass v =
      pp $ "Only `" ++ v ++ "' can be constrained by the superclasses."
-- | Converts a type signature @var :: ctx => ty@.
--
--   Every type variable mentioned in the context is bound by an
--   'S.TypeAbs' around the type, annotated with all classes that constrain
--   it. The first variable mentioned in the context is bound outermost.
mkSignature :: HsContext -> HsName -> HsType -> ErrorOr S.Declaration
mkSignature ctx var ty =
  mkContext ctx       >>= \constraints ->
  mkIdentifier var    >>= \ident       ->
  mkTypeExpression ty >>= \t           ->
  return (S.TypeSig (S.Signature ident (quantify constraints t)))
  where
    quantify ::
         [(S.TypeClass, S.TypeVariable)]
      -> S.TypeExpression
      -> S.TypeExpression
    quantify constraints t = foldr bind t (nub (map snd constraints))
      where
        bind v = S.TypeAbs v [ c | (c, v') <- constraints, v == v' ]
-- | Converts a context in which every class is applied to exactly one type
--   variable (e.g. @(Eq a, Show b)@). Any other assertion is rejected.
mkContext :: HsContext -> ErrorOr [(S.TypeClass, S.TypeVariable)]
mkContext = mapM convert
  where
    convert (qname, [HsTyVar var]) =
      liftM2 (,) (liftM S.TC (mkIdentifierQ qname)) (mkTypeVariable var)
    convert _ = throwError errContext

    errContext =
      pp "Only a type variable may be constrained by a class in a context."
-- | Converts a Haskell type expression. Type applications are flattened by
--   'mkAppTyEx', so a type constructor receives all of its arguments at once.
mkTypeExpression :: HsType -> ErrorOr S.TypeExpression
mkTypeExpression ty = case ty of
  HsTyVar var     -> liftM S.TypeVar (mkTypeVariable var)
  HsTyCon qname   -> mkTypeConstructorApp qname []
  HsTyApp fun arg -> mkAppTyEx fun [arg]
  HsTyFun arg res ->
    liftM2 S.TypeFun (mkTypeExpression arg) (mkTypeExpression res)
  HsTyTuple tys   ->
    liftM (\ts -> S.TypeCon (S.ConTuple (length ts)) ts)
          (mapM mkTypeExpression tys)
-- | Converts the application of @ty@ to the arguments @tys@. It walks down
--   the spine of nested applications and collects the arguments until it
--   reaches the head. Only a type constructor may head an application.
mkAppTyEx :: HsType -> [HsType] -> ErrorOr S.TypeExpression
mkAppTyEx (HsTyApp fun arg) args = mkAppTyEx fun (arg : args)
mkAppTyEx (HsTyCon qname) args =
  mapM mkTypeExpression args >>= mkTypeConstructorApp qname
mkAppTyEx (HsTyFun _ _) _ =
  throwError $ pp ("A function type must not be applied to a " ++ "type.")
mkAppTyEx (HsTyTuple _) _ =
  throwError (pp "A tuple type must not be applied to a type.")
mkAppTyEx (HsTyVar _) _ =
  throwError (pp "A variable must not be applied to a type.")
-- | Applies a type constructor to already converted arguments. The function
--   arrow @(->)@ becomes a function type and therefore needs exactly two
--   arguments.
mkTypeConstructorApp ::
     HsQName
  -> [S.TypeExpression]
  -> ErrorOr S.TypeExpression
mkTypeConstructorApp qname ts = case (qname, ts) of
  (Special HsFunCon, [t1, t2]) -> return (S.TypeFun t1 t2)
  (Special HsFunCon, _)        -> throwError errorTypeConstructorApp
  _                            ->
    mkTypeConstructor qname >>= \con -> return (S.TypeCon con ts)

-- | Error for the function arrow applied to a number of types other than two.
errorTypeConstructorApp :: Doc
errorTypeConstructorApp =
  pp "The function type constructor `->' must be applied to exactly two types."
-- | Converts a type constructor name.
--
--   Unqualified names and names qualified with @Prelude@ are checked
--   against the built-in base types (see 'asCon'). Names qualified with any
--   other module always become user-defined type constructors. Unit, list
--   and tuple constructors map to their dedicated representations.
mkTypeConstructor :: HsQName -> ErrorOr S.TypeConstructor
mkTypeConstructor (UnQual hsName) = return (asCon hsName)
mkTypeConstructor (Qual (Module m) hsName)
  | m == "Prelude" = return (asCon hsName)
  | otherwise      = return (S.Con (hsNameToIdentifier hsName))
mkTypeConstructor (Special special) = case special of
  HsUnitCon    -> return S.ConUnit
  HsListCon    -> return S.ConList
  HsTupleCon n -> return (S.ConTuple n)
-- | Maps the names of the built-in base types @Int@, @Integer@, @Float@,
--   @Double@ and @Char@ to their dedicated constructors. Every other name
--   becomes a user-defined type constructor.
asCon :: HsName -> S.TypeConstructor
asCon (HsIdent "Int")     = S.ConInt
asCon (HsIdent "Integer") = S.ConInteger
asCon (HsIdent "Float")   = S.ConFloat
asCon (HsIdent "Double")  = S.ConDouble
asCon (HsIdent "Char")    = S.ConChar
asCon name                = S.Con (hsNameToIdentifier name)
-- | Converts a type variable name. This conversion never fails.
mkTypeVariable :: HsName -> ErrorOr S.TypeVariable
mkTypeVariable hsName = return (S.TV (hsNameToIdentifier hsName))
-- | Converts a possibly qualified name into an identifier, dropping any
--   module qualifier. Special constructors (unit, list, function arrow,
--   cons and tuples) cannot serve as identifiers and yield an error.
mkIdentifierQ :: HsQName -> ErrorOr S.Identifier
mkIdentifierQ qname = case qname of
  UnQual hsName            -> return (hsNameToIdentifier hsName)
  Qual (Module _) hsName   -> return (hsNameToIdentifier hsName)
  Special HsUnitCon        -> throwErrorIdentifierQ "`()'"
  Special HsListCon        -> throwErrorIdentifierQ "`[]'"
  Special HsFunCon         -> throwErrorIdentifierQ "`->'"
  Special HsCons           -> throwErrorIdentifierQ "`:'"
  Special (HsTupleCon _)   -> throwErrorIdentifierQ "for tuples"

-- | Fails with a message saying that the described special constructor
--   cannot be used as an identifier.
throwErrorIdentifierQ :: String -> ErrorOr a
throwErrorIdentifierQ what = throwError . pp $
  "The constructor " ++ what ++ " must not be used as an identifier."
-- | Converts an unqualified name into an identifier. This never fails.
mkIdentifier :: HsName -> ErrorOr S.Identifier
mkIdentifier hsName = return (hsNameToIdentifier hsName)
-- | Wraps the textual form of a name (see 'hsNameToString') as an identifier.
hsNameToIdentifier :: HsName -> S.Identifier
hsNameToIdentifier hsName = S.Ident (hsNameToString hsName)
-- | Renders a name as text. Operator symbols are enclosed in parentheses
--   (e.g. @+@ becomes @(+)@), the way they are written in prefix position.
hsNameToString :: HsName -> String
hsNameToString name = case name of
  HsIdent ident -> ident
  HsSymbol sym  -> concat ["(", sym, ")"]