-- | Defines local checks, i.e. checks which only look at one declaration at a
--   time.
module Language.Haskell.FreeTheorems.Frontend.CheckLocal (
    checkLocal
  , checkDataAndNewtypeDeclarations
) where



import Data.Generics (Data, everything, mkQ)
import Data.List (group, sort)
import Data.Maybe (mapMaybe, fromJust, isJust)
import qualified Data.Set as Set
    ( Set, union, empty, difference, fromList, null, elems, isSubsetOf
    , singleton)

import Language.Haskell.FreeTheorems.Syntax
import Language.Haskell.FreeTheorems.Frontend.Error
import Language.Haskell.FreeTheorems.Frontend.TypeExpressions



------- Local checks ----------------------------------------------------------


-- | Checks the validity of every declaration.
--
--   Among other restrictions, the checks ensure three things. First, fixed
--   type expressions occur nowhere. Second, only declared type variables
--   occur in right-hand sides. Third, no primitive type is declared.
--
--   Local checks are all those which can be done by looking at a single
--   declaration in isolation. The per-declaration results are combined by
--   'foldChecks'.
checkLocal :: [Declaration] -> Checked [Declaration]
checkLocal = foldChecks validate
  where
    -- Dispatches each kind of declaration to its dedicated checker.
    validate :: Declaration -> ErrorOr ()
    validate decl =
      case decl of
        DataDecl dd    -> checkDataDecl dd
        NewtypeDecl nd -> checkNewtypeDecl nd
        TypeDecl td    -> checkTypeDecl td
        ClassDecl cd   -> checkClassDecl cd
        TypeSig sig    -> checkSignature sig



-- | Checks a @data@ declaration. The following restrictions must hold:
--
--   * The declared type constructor is not a primitive type.
--
--   * The variables occurring on the right-hand side have to be mentioned on
--     the left-hand side, and the left-hand side variables are pairwise
--     distinct.
--
--   * The declaration is not nested, i.e. if the declared type constructor
--     occurs on the right-hand side, it has only type variables as arguments.
--
--   * No fixed type expression occurs in any type expression.
checkDataDecl :: DataDeclaration -> ErrorOr ()
checkDataDecl decl = inDecl (DataDecl decl) $ do
    let name    = dataName decl
        lhsVars = dataVars decl
        cons    = dataCons decl
        named   = map namedTypesOf cons
        -- Free type variables of every (unbanged) constructor argument type.
        rhsVars = everything Set.union
                    (Set.empty `mkQ` (freeTypeVariables . withoutBang))
                    cons
    checkNotPrimitive name
    checkVariables lhsVars rhsVars
    checkNotEmpty cons
    mapM_ (checkNotNested name (map TypeVar lhsVars)) named
    mapM_ (checkNoFixedTEsNamed "data constructor") named
  where
    -- Pairs a data constructor's name with its argument types, each passed
    -- through 'withoutBang'.
    namedTypesOf con = (dataConName con, map withoutBang (dataConTypes con))



-- | Checks a @newtype@ declaration. The following restrictions must hold:
--
--   * The declared type constructor is not a primitive type.
--
--   * The variables occurring on the right-hand side have to be mentioned on
--     the left-hand side, and the left-hand side variables are pairwise
--     distinct.
--
--   * The declaration is not nested, i.e. if the declared type constructor
--     occurs on the right-hand side, it has only type variables as arguments.
--
--   * No fixed type expression occurs in the right-hand side type expression.
checkNewtypeDecl :: NewtypeDeclaration -> ErrorOr ()
checkNewtypeDecl decl = inDecl (NewtypeDecl decl) $ do
    let name    = newtypeName decl
        lhsVars = newtypeVars decl
        rhs     = newtypeRhs decl
        -- The constructor together with its single right-hand side type.
        named   = (newtypeCon decl, [rhs])
    checkNotPrimitive name
    checkVariables lhsVars (freeTypeVariables rhs)
    checkNotNested name (map TypeVar lhsVars) named
    checkNoFixedTEsNamed "data constructor" named



-- | Checks a @type@ declaration. The following restrictions must hold:
--
--   * The declared type constructor is not a primitive type.
--
--   * The variables occurring on the right-hand side have to be mentioned on
--     the left-hand side, and the left-hand side variables are pairwise
--     distinct.
--
--   * The declaration must not be recursive, i.e. the type constructor
--     declared by this declaration must not occur on the right-hand side.
--
--   * No fixed type expression occurs in the right-hand side type expression.
checkTypeDecl :: TypeDeclaration -> ErrorOr ()
checkTypeDecl decl = inDecl (TypeDecl decl) $ do
    let name = typeName decl
        rhs  = typeRhs decl
    checkNotPrimitive name
    checkVariables (typeVars decl) (freeTypeVariables rhs)
    checkTypeDeclNotRecursive name rhs
    checkNoFixedTEs rhs



-- | Checks a @class@ declaration. The following restrictions must hold:
--
--   * The declared type class does not equal a primitive type.
--
--   * The names of the class methods are pairwise distinct.
--
--   * The class variable occurs in the type expression of every class method.
--
--   * The name of the class does not occur in a type expression of any class
--     method.
--
--   * No fixed type expression occurs in a type expression of any class
--     method.
checkClassDecl :: ClassDeclaration -> ErrorOr ()
checkClassDecl decl = inDecl (ClassDecl decl) $ do
    let cls     = className decl
        methods = classFuns decl
        -- Each method name paired with its (single) type expression.
        named   = [ (signatureName m, [signatureType m]) | m <- methods ]
    checkNotPrimitive cls
    checkClassMethodsDistinct (map signatureName methods)
    checkClassVarInMethods (classVar decl) methods
    checkClassDeclNotRecursive cls methods
    mapM_ (checkNoFixedTEsNamed "class method") named



-- | Checks a type signature. The following restriction must hold:
--
--   * No fixed type expression occurs in the type expression of this type
--     signature.
checkSignature :: Signature -> ErrorOr ()
checkSignature sig = inDecl (TypeSig sig) (checkNoFixedTEs (signatureType sig))



------- Special checks for data and newtype declarations ----------------------


-- | Checks data and newtype declarations for occurring function type
--   constructors or type abstraction constructors. If any such declaration
--   contains one of these, an error message is created. All other
--   declarations pass.
checkDataAndNewtypeDeclarations :: [Declaration] -> Checked [Declaration]
checkDataAndNewtypeDeclarations = foldChecks checkDN
  where
    checkDN :: Declaration -> ErrorOr ()
    checkDN d = case d of
      DataDecl d'    -> inDecl d (mapM_ checkAbsFun (dataConsAndTypes d'))
      NewtypeDecl d' -> inDecl d (checkAbsFun (newtypeConAndType d'))
      -- Every other kind of declaration passes this check unconditionally.
      -- (A wildcard is used instead of binding a variable named 'otherwise',
      -- which would shadow Prelude.otherwise.)
      _              -> return ()

    -- Pairs each data constructor's name with its argument types, each
    -- passed through 'withoutBang'.
    dataConsAndTypes =
      map (makePair dataConName (map withoutBang . dataConTypes)) . dataCons

    -- Pairs the newtype constructor with its single right-hand side type.
    newtypeConAndType = makePair newtypeCon (singletonList . newtypeRhs)



------- Checking restrictions -------------------------------------------------


-- | Checks that the given identifier is not the name of a primitive type
--   (see 'primitiveTypeNames'). Otherwise, an error message is created.
checkNotPrimitive :: Identifier -> ErrorOr ()
checkNotPrimitive (Ident name) =
  errorIf (name `elem` primitiveTypeNames) $
    pp "A primitive type must not have a declaration."



-- | Names of the primitive types, which must not be declared by the user.
primitiveTypeNames :: [String]
primitiveTypeNames = ["Int", "Integer", "Float", "Double", "Char"]



-- | Checks the type variables of a declaration.
--
--   The first argument lists the variables of the left-hand side; they must
--   be pairwise distinct. The second argument is the set of variables
--   occurring on the right-hand side; every one of them must also occur in
--   the first argument. An error message is created for the first violated
--   condition, checked in this order.
checkVariables :: [TypeVariable] -> Set.Set TypeVariable -> ErrorOr ()
checkVariables vs rvs = do
    let es = extractRepeatingElements vs
    errorIf (not $ null es) $
      pp ("Type variables must not be given more than once on the left-hand "
          ++ "side of a declaration. "
          ++ violating "variable" (map varName es))

    let undeclared = rvs `Set.difference` Set.fromList vs
    errorIf (not (Set.null undeclared)) $
      pp ("Type variables occurring on the right-hand side of a declaration must "
          ++ "be declared on the left-hand side. "
          ++ violating "variable" (map varName (Set.elems undeclared)))
  where
    varName (TV v) = unpackIdent v



-- | Checks that there is at least one data constructor declaration in the
--   declaration of an algebraic data type.
checkNotEmpty :: [DataConstructorDeclaration] -> ErrorOr ()
checkNotEmpty cons =
  errorIf (null cons) $
    pp ("The declaration of an algebraic data type must have at least one "
        ++ "data constructor.")



-- | Checks whether the identifier @con@ occurs as a type constructor in any
--   of the type expressions @ts@. If it does and is applied to arguments
--   other than exactly @vs@, the declaration is called nested and an error
--   message naming the data constructor @dcon@ is created.
checkNotNested ::
    Identifier -> [TypeExpression] -> (Identifier, [TypeExpression])
    -> ErrorOr ()
checkNotNested con vs (dcon, ts) =
    errorIf (any (satisfiesSomewhere isNested) ts) $
      pp ("Declarations must not be nested. "
          ++ violating "data constructor" [unpackIdent dcon])
  where
    isNested t = case t of
      TypeCon (Con c) args -> c == con && args /= vs
      _                    -> False



-- | Checks if a type declaration is recursive, i.e. whether the identifier
--   occurs in the given type expression as a type constructor.
--   If so, an error message is created.
checkTypeDeclNotRecursive :: Identifier -> TypeExpression -> ErrorOr ()
checkTypeDeclNotRecursive ident t =
    errorIf (satisfiesSomewhere mentionsSelf t) $
      pp "A type synonym must not be declared recursively."
  where
    mentionsSelf te = case te of
      TypeCon (Con c) _ -> c == ident
      _                 -> False



-- | Checks that the names of class methods are pairwise distinct.
--   If not, an error message is created.
checkClassMethodsDistinct :: [Identifier] -> ErrorOr ()
checkClassMethodsDistinct is =
  let es = extractRepeatingElements is
  in  errorIf (not $ null es) $
        pp ("Class methods must not be declared more than once. "
            ++ violating "class method" (map unpackIdent es))



-- | Checks if the given type variable occurs as a free type variable in the
--   type of every signature. If not, an error message naming the offending
--   class methods is created.
checkClassVarInMethods :: TypeVariable -> [Signature] -> ErrorOr ()
checkClassVarInMethods v@(TV vName) ss =
  let setOfv      = Set.singleton v
      vIsFreeIn t = setOfv `Set.isSubsetOf` freeTypeVariables t
      ms          = filter (not . vIsFreeIn . signatureType) ss
  in  errorIf (not $ null ms) $
        pp ("The type variable `" ++ unpackIdent vName ++ "' must occur free "
            ++ "in the type expressions of every class method. "
            ++ violating "class method" (map (unpackIdent . signatureName) ms))



-- | Checks that the name of a type class does not occur in any of the class
--   methods. Otherwise, an error message is created.
checkClassDeclNotRecursive :: Identifier -> [Signature] -> ErrorOr ()
checkClassDeclNotRecursive ident sigs =
  let hasThisClass = satisfiesSomewhere (\c -> c == TC ident)
      ms           = filter (hasThisClass . signatureType) sigs
  in  errorIf (not $ null ms) $
        pp ("The type class `" ++ unpackIdent ident ++ "' must not occur in a "
            ++ "type expression of any class method of this class. "
            ++ violating "class method" (map (unpackIdent . signatureName) ms))



-- | Checks that no FixedTypeExpression occurs in the given named list of
--   type expressions. The first argument is used in generating a helpful
--   error message and describes what kind of item the name denotes.
--   Only the first problem found is reported.
checkNoFixedTEsNamed :: String -> (Identifier, [TypeExpression]) -> ErrorOr ()
checkNoFixedTEsNamed tag (con, ts) =
  -- Pattern matching replaces the former 'head', which was partial.
  case mapMaybe checkNoFixedTEsPlain ts of
    []      -> return ()
    (e : _) -> errorIf True $
                 pp (e ++ " " ++ violating tag [unpackIdent con])



-- | Checks that no FixedTypeExpression occurs in a type expression.
--   If it does, an error message is created.
checkNoFixedTEs :: TypeExpression -> ErrorOr ()
checkNoFixedTEs t =
  case checkNoFixedTEsPlain t of
    Nothing -> return ()
    Just e  -> errorIf True (pp e)



-- | Returns an error message if a FixedTypeExpression occurs in the argument,
--   otherwise returns @Nothing@.
checkNoFixedTEsPlain :: TypeExpression -> Maybe String
checkNoFixedTEsPlain t
    | satisfiesSomewhere isFixedTE t =
        Just "A fixed type expression must not occur in a type expression."
    | otherwise = Nothing
  where
    isFixedTE te = case te of
      TypeExp _ -> True
      _         -> False



-- | Checks that no function type constructor and no type abstraction
--   constructor occur in the given named list of type expressions.
checkAbsFun :: (Identifier, [TypeExpression]) -> ErrorOr ()
checkAbsFun (con, ts) =
    errorIf (satisfiesSomewhere isAbsOrFun ts) $
      pp ("Algebraic data types and type renamings must be declared without type "
          ++ "abstractions and function type constructors occurring on the "
          ++ "right-hand side. "
          ++ violating "data constructor" [unpackIdent con])
  where
    isAbsOrFun t = case t of
      TypeFun _ _   -> True
      TypeAbs _ _ _ -> True
      _             -> False



------- Helper functions ------------------------------------------------------


-- | Applies two functions to a value and creates a pair of the results.
makePair :: (a -> b) -> (a -> c) -> a -> (b, c)
makePair f g x = (f x, g x)



-- | Creates a list containing just one element.
singletonList :: a -> [a]
singletonList x = [x]



-- | Returns the elements which occur more than once in the given list.
--   Only one representative is returned for every group of equal items,
--   and the result is in ascending order.
extractRepeatingElements :: Ord a => [a] -> [a]
extractRepeatingElements xs = [ x | (x : _ : _) <- group (sort xs) ]



-- | Tests if a predicate holds somewhere in an arbitrary tree.
satisfiesSomewhere :: (Data a, Data b) => (a -> Bool) -> b -> Bool
satisfiesSomewhere predicate x = everything (||) (False `mkQ` predicate) x