module Language.Haskell.FreeTheorems.Frontend.CheckLocal (
checkLocal
, checkDataAndNewtypeDeclarations
) where
import Data.Generics (Data, everything, mkQ)
import Data.List (group, sort)
import Data.Maybe (mapMaybe, fromJust, isJust)
import qualified Data.Set as Set
( Set, union, empty, difference, fromList, null, elems, isSubsetOf
, singleton)
import Language.Haskell.FreeTheorems.Syntax
import Language.Haskell.FreeTheorems.Frontend.Error
import Language.Haskell.FreeTheorems.Frontend.TypeExpressions
-- | Applies the declaration-specific local check to every declaration.
--
-- Each declaration is dispatched on its constructor to the matching checker.
-- How the individual results are combined is decided by 'foldChecks'.
checkLocal :: [Declaration] -> Checked [Declaration]
checkLocal = foldChecks dispatch
  where
    -- Select the checker that belongs to this kind of declaration.
    dispatch :: Declaration -> ErrorOr ()
    dispatch decl = case decl of
      DataDecl d    -> checkDataDecl d
      NewtypeDecl d -> checkNewtypeDecl d
      TypeDecl d    -> checkTypeDecl d
      ClassDecl d   -> checkClassDecl d
      TypeSig sig   -> checkSignature sig
-- | Checks a single @data@ declaration.
--
-- The checks run in this order: the declared name is not a primitive type,
-- the type variables are declared once and cover the right-hand side, there
-- is at least one data constructor, no constructor nests the declared type,
-- and no constructor contains a fixed type expression.
checkDataDecl :: DataDeclaration -> ErrorOr ()
checkDataDecl d =
  inDecl (DataDecl d) $ do
    checkNotPrimitive declName
    checkVariables declVars rhsVars
    checkNotEmpty constructors
    mapM_ (checkNotNested declName (map TypeVar declVars)) namedFields
    mapM_ (checkNoFixedTEsNamed "data constructor") namedFields
  where
    declName = dataName d
    declVars = dataVars d
    constructors = dataCons d

    -- Union of the free type variables of every value inside the
    -- constructors that 'withoutBang' can be applied to.
    rhsVars =
      everything
        Set.union
        (Set.empty `mkQ` (freeTypeVariables . withoutBang))
        constructors

    -- Each constructor name paired with its field types after 'withoutBang'.
    namedFields =
      [ (dataConName c, map withoutBang (dataConTypes c))
      | c <- constructors
      ]
-- | Checks a single @newtype@ declaration.
--
-- The checks run in this order: the declared name is not a primitive type,
-- the type variables are declared once and cover the right-hand side, the
-- constructor does not nest the declared type, and the right-hand side
-- contains no fixed type expression.
checkNewtypeDecl :: NewtypeDeclaration -> ErrorOr ()
checkNewtypeDecl d =
  inDecl (NewtypeDecl d) $ do
    checkNotPrimitive declName
    checkVariables declVars (freeTypeVariables rhs)
    checkNotNested declName (map TypeVar declVars) namedField
    checkNoFixedTEsNamed "data constructor" namedField
  where
    declName = newtypeName d
    declVars = newtypeVars d
    rhs = newtypeRhs d

    -- The single constructor paired with its one field type, in the same
    -- shape the data-declaration checks use.
    namedField = (newtypeCon d, [rhs])
-- | Checks a single type synonym declaration.
--
-- The checks run in this order: the declared name is not a primitive type,
-- the type variables are declared once and cover the right-hand side, the
-- synonym does not mention itself, and the right-hand side contains no fixed
-- type expression.
checkTypeDecl :: TypeDeclaration -> ErrorOr ()
checkTypeDecl d =
  inDecl (TypeDecl d) $ do
    checkNotPrimitive synonymName
    checkVariables (typeVars d) (freeTypeVariables rhs)
    checkTypeDeclNotRecursive synonymName rhs
    checkNoFixedTEs rhs
  where
    synonymName = typeName d
    rhs = typeRhs d
-- | Checks a single type class declaration.
--
-- The checks run in this order: the class name is not a primitive type, the
-- method names are distinct, the class variable occurs free in every method
-- type, the class is not mentioned in its own method types, and no method
-- type contains a fixed type expression.
checkClassDecl :: ClassDeclaration -> ErrorOr ()
checkClassDecl d =
  inDecl (ClassDecl d) $ do
    checkNotPrimitive name
    checkClassMethodsDistinct (map signatureName methods)
    checkClassVarInMethods (classVar d) methods
    checkClassDeclNotRecursive name methods
    mapM_ (checkNoFixedTEsNamed "class method") namedMethodTypes
  where
    name = className d
    methods = classFuns d

    -- Each method name paired with a one-element list holding its type.
    namedMethodTypes =
      [ (signatureName m, [signatureType m]) | m <- methods ]
-- | Checks a type signature: its type must not contain a fixed type
-- expression.
checkSignature :: Signature -> ErrorOr ()
checkSignature s = inDecl (TypeSig s) (checkNoFixedTEs (signatureType s))
-- | Checks that @data@ and @newtype@ declarations use neither function types
-- nor type abstractions on their right-hand sides (see 'checkAbsFun').
--
-- All other kinds of declarations pass unchecked. How the individual results
-- are combined is decided by 'foldChecks'.
checkDataAndNewtypeDeclarations :: [Declaration] -> Checked [Declaration]
checkDataAndNewtypeDeclarations = foldChecks checkDN
  where
    checkDN :: Declaration -> ErrorOr ()
    checkDN decl@(DataDecl dataDecl) =
      inDecl decl (mapM_ checkAbsFun (constructorFields dataDecl))
    checkDN decl@(NewtypeDecl newtypeDecl) =
      inDecl decl (checkAbsFun (constructorField newtypeDecl))
    checkDN _ = return ()

    -- Each data constructor paired with its field types after 'withoutBang'.
    constructorFields dataDecl =
      [ (dataConName c, map withoutBang (dataConTypes c))
      | c <- dataCons dataDecl
      ]

    -- The newtype constructor paired with its single field type.
    constructorField newtypeDecl =
      (newtypeCon newtypeDecl, [newtypeRhs newtypeDecl])
-- | Reports an error when the identifier is the name of one of the
-- primitive types @Int@, @Integer@, @Float@, @Double@ or @Char@.
checkNotPrimitive :: Identifier -> ErrorOr ()
checkNotPrimitive (Ident name) =
  errorIf (name `elem` primitiveTypeNames) (pp message)
  where
    primitiveTypeNames = ["Int", "Integer", "Float", "Double", "Char"]
    message = "A primitive type must not have a declaration."
-- | Checks the type variables of a declaration.
--
-- The first argument holds the variables declared on the left-hand side,
-- the second the variables occurring on the right-hand side. Two checks run
-- in order: no left-hand variable is listed more than once, and every
-- right-hand variable is declared on the left.
checkVariables :: [TypeVariable] -> Set.Set TypeVariable -> ErrorOr ()
checkVariables vs rvs = do
  errorIf (not (null repeated)) $
    pp ("Type variables must not be given more than once on the left-hand "
        ++ "side of a declaration. "
        ++ violating "variable" (map nameOf repeated))
  errorIf (not (Set.null undeclared)) $
    pp ("Type variables occurring on the right-hand side of a declaration must "
        ++ "be declared on the left-hand side. "
        ++ violating "variable" (map nameOf (Set.elems undeclared)))
  where
    -- Left-hand variables that are listed at least twice.
    repeated = extractRepeatingElements vs

    -- Right-hand variables that have no left-hand declaration.
    undeclared = Set.difference rvs (Set.fromList vs)

    nameOf (TV v) = unpackIdent v
-- | Reports an error when the list of data constructors is empty.
checkNotEmpty :: [DataConstructorDeclaration] -> ErrorOr ()
checkNotEmpty cons = errorIf (null cons) (pp message)
  where
    message =
      "The declaration of an algebraic data type must have at least one "
        ++ "data constructor."
-- | Reports an error when a data constructor uses the declared type in a
-- nested way.
--
-- The arguments are the name of the declared type, the declared type
-- variables as type expressions, and a data constructor paired with its
-- field types. A use of the declared type constructor anywhere inside a
-- field type counts as nested when its argument list differs from the
-- declared type variables.
checkNotNested ::
  Identifier -> [TypeExpression] -> (Identifier, [TypeExpression])
  -> ErrorOr ()
checkNotNested con vs (dcon, ts) =
  errorIf (any (satisfiesSomewhere isNested) ts) $
    -- The trailing space keeps the sentence apart from the text appended by
    -- 'violating', as the other checks in this module do.
    pp ("Declarations must not be nested. "
        ++ violating "data constructor" [unpackIdent dcon])
  where
    -- A type constructor application of the declared type whose arguments
    -- are not exactly the declared type variables.
    isNested (TypeCon (Con c) args) = c == con && args /= vs
    isNested _ = False
-- | Reports an error when a type synonym mentions itself: the type
-- expression must not contain a type constructor carrying the given name.
checkTypeDeclNotRecursive :: Identifier -> TypeExpression -> ErrorOr ()
checkTypeDeclNotRecursive ident t =
  errorIf (satisfiesSomewhere usesSynonym t) (pp message)
  where
    message = "A type synonym must not be declared recursively."

    -- True for a type constructor application naming the synonym itself.
    usesSynonym (TypeCon (Con c) _) = c == ident
    usesSynonym _ = False
-- | Reports an error when a class method name occurs more than once in the
-- given list of method names.
checkClassMethodsDistinct :: [Identifier] -> ErrorOr ()
checkClassMethodsDistinct is =
  errorIf (not (null repeated)) $
    pp ("Class methods must not be declared more than once. "
        ++ violating "class method" (map unpackIdent repeated))
  where
    repeated = extractRepeatingElements is
-- | Reports an error when the class variable does not occur free in the
-- type of some class method. The offending methods are named in the error.
checkClassVarInMethods :: TypeVariable -> [Signature] -> ErrorOr ()
checkClassVarInMethods v@(TV vName) ss =
  errorIf (not (null offenders)) $
    pp ("The type variable `" ++ unpackIdent vName ++ "' must occur free "
        ++ "in the type expressions of every class method. "
        ++ violating "class method"
                     (map (unpackIdent . signatureName) offenders))
  where
    classVarOnly = Set.singleton v

    -- Does the class variable occur free in the type of this signature?
    mentionsClassVar sig =
      classVarOnly `Set.isSubsetOf` freeTypeVariables (signatureType sig)

    -- Methods whose type lacks a free occurrence of the class variable.
    offenders = [ sig | sig <- ss, not (mentionsClassVar sig) ]
-- | Reports an error when the class being declared is mentioned in the type
-- of one of its own methods. The offending methods are named in the error.
checkClassDeclNotRecursive :: Identifier -> [Signature] -> ErrorOr ()
checkClassDeclNotRecursive ident sigs =
  errorIf (not (null offenders)) $
    pp ("The type class `" ++ unpackIdent ident ++ "' must not occur in a "
        ++ "type expression of any class method of this class. "
        ++ violating "class method"
                     (map (unpackIdent . signatureName) offenders))
  where
    -- Does the type of this signature contain the class somewhere?
    mentionsThisClass sig =
      satisfiesSomewhere (== TC ident) (signatureType sig)

    offenders = filter mentionsThisClass sigs
-- | Reports an error when any of the given type expressions contains a
-- fixed type expression.
--
-- The first argument names the kind of item being checked (for example
-- @"data constructor"@). The pair holds the item's identifier and its type
-- expressions. Only the first failure message is reported, followed by the
-- text produced by 'violating' for the identifier.
checkNoFixedTEsNamed :: String -> (Identifier, [TypeExpression]) -> ErrorOr ()
checkNoFixedTEsNamed tag (con, ts) =
  let failures = mapMaybe checkNoFixedTEsPlain ts
      -- 'take 1' picks the first message without the partial 'head'; the
      -- result is only meaningful when 'failures' is non-empty.
      firstMessage = concat (take 1 failures)
   in errorIf (not (null failures)) $
        -- The explicit space keeps the sentence apart from the text appended
        -- by 'violating', as the other checks in this module do.
        pp (firstMessage ++ " " ++ violating tag [unpackIdent con])
-- | Reports an error when the type expression contains a fixed type
-- expression somewhere.
checkNoFixedTEs :: TypeExpression -> ErrorOr ()
checkNoFixedTEs t = errorIf (isJust failure) (pp (fromJust failure))
  where
    -- NOTE(review): 'fromJust' is only safe if 'errorIf' ignores its message
    -- when the condition is False; presumably it does, confirm in the Error
    -- module.
    failure = checkNoFixedTEsPlain t
-- | Looks for a fixed type expression anywhere inside a type expression.
--
-- Returns 'Just' an error message when one is found and 'Nothing'
-- otherwise.
checkNoFixedTEsPlain :: TypeExpression -> Maybe String
checkNoFixedTEsPlain t
  | satisfiesSomewhere isFixedTE t =
      Just "A fixed type expression must not occur in a type expression."
  | otherwise = Nothing
  where
    isFixedTE (TypeExp _) = True
    isFixedTE _ = False
-- | Reports an error when a data constructor's field types contain a
-- function type or a type abstraction anywhere.
--
-- The pair holds the constructor's identifier, which is named in the error
-- message, and its field types.
checkAbsFun :: (Identifier, [TypeExpression]) -> ErrorOr ()
checkAbsFun (con, ts) =
  errorIf (satisfiesSomewhere isAbsOrFun ts) $
    -- The trailing space after "side." keeps the sentence apart from the
    -- text appended by 'violating', as the other checks in this module do.
    pp ("Algebraic data types and type renamings must be declared without type "
        ++ "abstractions and function type constructors occurring on the "
        ++ "right-hand side. "
        ++ violating "data constructor" [unpackIdent con])
  where
    isAbsOrFun (TypeFun _ _) = True
    isAbsOrFun (TypeAbs _ _ _) = True
    isAbsOrFun _ = False
-- | Applies two functions to the same argument and pairs the results, the
-- first function's result on the left.
makePair :: (a -> b) -> (a -> c) -> a -> (b, c)
makePair f g x =
  let left = f x
      right = g x
   in (left, right)
-- | Wraps a value in a one-element list.
singletonList :: a -> [a]
singletonList x = x : []
-- | Returns every element that occurs more than once in the list.
--
-- Each such element is listed exactly once, in ascending order.
extractRepeatingElements :: Ord a => [a] -> [a]
extractRepeatingElements xs =
  -- After sorting, equal elements are adjacent, so 'group' gathers them; a
  -- group with at least two members matches the pattern @(x : _ : _)@.
  [ x | (x : _ : _) <- group (sort xs) ]
-- | Tests whether the predicate holds for at least one value of its
-- argument type found anywhere inside the given structure.
--
-- The structure is traversed generically with 'everything'. Values of any
-- other type contribute 'False', and the results are combined with '||'.
satisfiesSomewhere :: (Data a, Data b) => (a -> Bool) -> b -> Bool
satisfiesSomewhere predicate = everything (||) (mkQ False predicate)