module CSPM.TypeChecker.Decl (typeCheckDecls) where
import Control.Monad
import Control.Monad.Trans
import Data.Graph.Wrapper
import qualified Data.Map as M
import qualified Data.Set as S
import Data.List (nub, intersect, (\\), sortBy)
import CSPM.DataStructures.Names
import CSPM.DataStructures.Syntax hiding (getType)
import CSPM.DataStructures.Types
import CSPM.PrettyPrinter
import CSPM.TypeChecker.BuiltInFunctions
import CSPM.TypeChecker.Common
import CSPM.TypeChecker.Dependencies
import CSPM.TypeChecker.Exceptions
import CSPM.TypeChecker.Expr
import CSPM.TypeChecker.Monad
import CSPM.TypeChecker.Pat
import CSPM.TypeChecker.Unification
import Util.Annotated
import Util.List
import Util.Monad
import Util.PartialFunctions
import Util.PrettyPrint
-- | Type checks a list of declarations that may refer to one another.
--
-- Steps:
--
-- 1. Report duplicated definitions.
-- 2. Register every channel and data type constructor name.
-- 3. Build a dependency graph between the declarations, restricted to names
--    bound in this list.
-- 4. Type check each strongly connected component as one mutually
--    recursive group.
--
-- When a group fails, every group reachable from it in the transposed SCC
-- graph is skipped. The remaining groups are still checked. Once all groups
-- have been processed, the whole check fails ('failM') if any group failed.
typeCheckDecls :: [PDecl] -> TypeCheckMonad ()
typeCheckDecls decls = do
    namesBoundByDecls <- forM decls $ \ decl -> do
        namesBound <- namesBoundByDecl decl
        return (decl, namesBound)
    let
        -- Each declaration is given an integer id (its position in decls).
        -- NOTE(review): ids are looked up by declaration *value* (apply /
        -- invert), so two declarations that compare equal would share an
        -- id; confirm the Eq instance on PDecl tells them apart.
        declMap = zip decls [0..]
        invDeclMap = invert declMap
        varToDeclIdMap =
            [(n, apply declMap d) | (d, ns) <- namesBoundByDecls, n <- ns]
        boundVars = map fst varToDeclIdMap
        namesToLocations =
            [(n, loc d) | (d, ns) <- namesBoundByDecls, n <- ns]
    manyErrorsIfFalse (noDups boundVars)
        (duplicatedDefinitionsMessage namesToLocations)
    mapM_ registerChannelsAndDataTypes (map unAnnotate decls)
    -- For every declaration: its id, and the ids of the declarations in this
    -- list that bind the names it depends on.
    declDeps <- forM decls $ \ decl -> do
        deps <- dependencies decl
        let depsInThisGroup = intersect deps boundVars
        return (apply declMap decl, mapPF varToDeclIdMap depsInThisGroup)
    let
        -- Adjacency list: a declaration id maps to the ids it depends on.
        declGraph :: Graph Int Int
        declGraph = fromListSimple declDeps
        -- After transposing, edges run from a group to the groups that
        -- depend on it. reachableVertices below therefore finds the
        -- dependents of a failed group. This assumes topologicalSort puts
        -- edge sources before targets, in which case dependencies are
        -- checked first.
        sccgraph :: Graph (S.Set Int) (M.Map Int Int)
        sccgraph = transpose (sccGraph declGraph)
        sccs :: [S.Set Int]
        sccs = topologicalSort sccgraph
        typeInferenceGroup = mapPF invDeclMap
        -- The Bool records whether any earlier group has failed.
        typeCheckGroups [] failed = when failed failM
        typeCheckGroups (g:gs) failed = do
            groupFailed <- tryAndRecover
                (typeCheckMutualyRecursiveGroup
                    (typeInferenceGroup (S.toList g)) >> return False)
                (return True)
            if groupFailed
                then typeCheckGroups (gs \\ reachableVertices sccgraph g) True
                else typeCheckGroups gs failed
    typeCheckGroups sccs False
-- | Passes every constructor name of a data type declaration, and every
-- name of a channel declaration, to 'addDataTypeOrChannel'. All other
-- declarations are ignored.
registerChannelsAndDataTypes :: Decl -> TypeCheckMonad ()
registerChannelsAndDataTypes decl = case decl of
    DataType _ clauses ->
        forM_ clauses $ \ clause -> case unAnnotate clause of
            DataTypeClause constructorName _ ->
                addDataTypeOrChannel constructorName
    Channel channelNames _ -> forM_ channelNames addDataTypeOrChannel
    _ -> return ()
-- | Type checks one group of mutually recursive declarations together.
--
-- Steps:
--
-- 1. Every name bound by the group gets a fresh monomorphic type variable
--    ('setType' with @ForAll []@).
-- 2. The declarations are checked together through 'generaliseGroup'.
-- 3. Each declaration's symbol table is filled with its part of the result.
-- 4. Every bound name's type scheme is replaced by its
--    'compressTypeScheme' form.
typeCheckMutualyRecursiveGroup :: [PDecl] -> TypeCheckMonad ()
typeCheckMutualyRecursiveGroup ds' = do
    let
        -- Move data type declarations to the front. sortBy is stable, so
        -- the remaining order is kept. Presumably this lets constructors be
        -- typed before the declarations that use them (TODO confirm).
        cmp x y = case (unAnnotate x, unAnnotate y) of
            (DataType _ _, DataType _ _) -> EQ
            (DataType _ _, _) -> LT
            (_, DataType _ _) -> GT
            (_, _) -> EQ
        ds = sortBy cmp ds'
    -- The names bound by ds are computed once. The original recomputed the
    -- identical list a second time after the setType calls.
    fvs <- liftM nub (concatMapM namesBoundByDecl ds)
    ftvs <- replicateM (length fvs) freshTypeVar
    zipWithM_ setType fvs (map (ForAll []) ftvs)
    -- nts is assumed to line up with ds, one entry per declaration.
    nts <- generaliseGroup fvs (map typeCheck ds)
    zipWithM_ annotate nts ds
    forM_ fvs $ \ n -> do
        t <- getType n
        t' <- compressTypeScheme t
        setType n t'
  where
    -- Stores a declaration's (name, type) results in its symbol table.
    annotate nts (An _ psymbtable _) = setPSymbolTable (snd psymbtable) nts
-- | Converts the type of an expression used as a type (for example in a
-- channel or data type declaration) into the type of its values.
--
-- * Tuples and dotted types are processed componentwise, left to right.
-- * Any other type is unified with @TSet a@ for a fresh variable @a@, and
--   @a@ is returned. In other words, the expression must denote a set, and
--   the result is the set's element type.
evalTypeExpression :: Type -> TypeCheckMonad Type
evalTypeExpression ty = case ty of
    TTuple components -> liftM TTuple (mapM evalTypeExpression components)
    TDot lhs rhs ->
        liftM2 TDot (evalTypeExpression lhs) (evalTypeExpression rhs)
    _ -> do
        elemType <- freshTypeVar
        _ <- unify ty (TSet elemType)
        return elemType
-- | Annotated declarations: the wrapped declaration is checked under
-- 'setSrcSpan' with the annotation's location. No extra error context is
-- added at this level.
instance TypeCheckable PDecl [(Name, Type)] where
    errorContext _ = Nothing
    typeCheck' annotated =
        setSrcSpan (loc annotated) (typeCheck (inner annotated))
-- | Type checks a single declaration. The result pairs the names the
-- declaration binds with their types. Each equation reads a bound name's
-- current type with 'getType' and expects it to be monomorphic
-- (@ForAll []@); the pattern binding case instead accepts any 'ForAll'.
instance TypeCheckable Decl [(Name, Type)] where
    errorContext (FunBind n _) = Just $
        text "In the declaration of:" <+> prettyPrint n
    errorContext p@(PatBind _ _) = Just $
        hang (text "In a pattern binding:") tabWidth (prettyPrint p)
    errorContext (DataType n _) = Just $
        text "In the declaration of:" <+> prettyPrint n
    errorContext (NameType n _) = Just $
        text "In the declaration of:" <+> prettyPrint n
    errorContext (Channel ns _) = Just $
        text "In the declaration of:" <+> list (map prettyPrint ns)
    errorContext (Assert a) = Just $
        text "In the assertion:" <+> prettyPrint a
    errorContext (Transparent _) = Nothing
    errorContext (External _) = Nothing

    -- Check every equation, each under its own error context. Unify the
    -- equation types with the name's current type; the unified type must
    -- be a function type.
    typeCheck' (FunBind n ms) = do
        ts <- mapM (\ m -> addErrorContext (matchCtxt m) $ typeCheck m) ms
        ForAll [] t <- getType n
        -- The as-pattern has no whitespace around '@'. GHC >= 9.0 rejects
        -- the spaced form "t' @ (...)".
        t'@(TFunction _ _) <- unifyAll (t:ts)
        return [(n, t')]
      where
        matchCtxt an =
            hang (text "In an equation for" <+> prettyPrint n <> colon)
                tabWidth (prettyPrintMatch n an)

    -- Check the pattern and then the right-hand side, evaluate dots in both
    -- (in that order), and unify them. The names bound by the pattern are
    -- returned with the types getType now reports.
    typeCheck' d@(PatBind pat e) = do
        tpat0 <- typeCheck pat
        texp0 <- typeCheck e
        tpat <- evaluateDots tpat0
        texp <- evaluateDots texp0
        disallowSymmetricUnification (unify texp tpat)
        ns <- namesBoundByDecl' d
        ts <- mapM getType ns
        return $ zip ns [t | ForAll _ t <- ts]

    -- A channel without fields: every name has type TEvent.
    typeCheck' (Channel ns Nothing) = do
        forM_ ns $ \ n -> do
            ForAll [] t <- getType n
            unify TEvent t
        return [(n, TEvent) | n <- ns]

    -- A channel with fields: the field types are dotted onto TEvent.
    typeCheck' (Channel ns (Just e)) = do
        te <- typeCheck e
        valueType <- evalTypeExpression te
        dotList <- typeToDotList valueType
        let tchan = foldr TDotable TEvent dotList
        forM_ ns $ \ n -> do
            ForAll [] t <- getType n
            unify t tchan
        return [(n, tchan) | n <- ns]

    -- Each constructor gets its field types dotted onto the data type. The
    -- data type's own name has the type of a set of that data type.
    typeCheck' (DataType n clauses) = do
        nts <- forM clauses $ \ clause -> do
            let
                n' = case unAnnotate clause of
                    DataTypeClause x _ -> x
            ForAll [] t <- getType n'
            (conName, ts) <- typeCheck clause
            let texp = foldr TDotable (TDatatype n) ts
            t' <- unify texp t
            return (conName, t')
        ForAll [] t <- getType n
        t' <- unify t (TSet (TDatatype n))
        return $ (n, t'):nts

    typeCheck' (NameType n e) = do
        t <- typeCheck e
        valueType <- evalTypeExpression t
        return [(n, TSet valueType)]

    -- Every name must appear in transparentFunctions; its type is unified
    -- with the type recorded there.
    typeCheck' (Transparent ns) = do
        forM_ ns $ \ n@(Name s) -> do
            texp <- applyPFOrError (transparentFunctionNotRecognised n)
                transparentFunctions s
            ForAll [] t <- getType n
            unify texp t
        return []

    -- Every name must appear in externalFunctions; its type is unified with
    -- the type recorded there.
    typeCheck' (External ns) = do
        forM_ ns $ \ n@(Name s) -> do
            texp <- applyPFOrError (externalFunctionNotRecognised n)
                externalFunctions s
            ForAll [] t <- getType n
            unify texp t
        return []

    typeCheck' (Assert a) = typeCheck a >> return []
-- | Assertions: every process operand must be a process, and the model
-- options of a refinement are type checked in order.
instance TypeCheckable Assertion () where
    errorContext assertion = Just $
        hang (text "In the assertion" <> colon) tabWidth (prettyPrint assertion)
    typeCheck' (PropertyCheck process _ _) =
        ensureIsProc process >> return ()
    typeCheck' (Refinement lhs _ rhs opts) = do
        mapM_ ensureIsProc [lhs, rhs]
        mapM_ typeCheck opts
-- | Model options: the argument of a tau-priority option must be a set of
-- events.
instance TypeCheckable ModelOption () where
    errorContext _ = Nothing
    typeCheck' (TauPriority e) = typeCheckExpect e (TSet TEvent) >> return ()
-- | Annotated data type clauses: the wrapped clause is checked under
-- 'setSrcSpan' with the annotation's location.
instance TypeCheckable PDataTypeClause (Name, [Type]) where
    errorContext _ = Nothing
    typeCheck' annotated =
        setSrcSpan (loc annotated) (typeCheck (inner annotated))
-- | A data type clause yields its constructor name and the list of its
-- field types. A clause without fields has an empty list.
instance TypeCheckable DataTypeClause (Name, [Type]) where
    errorContext clause = Just $
        hang (text "In the data type clause" <> colon) tabWidth
            (prettyPrint clause)
    typeCheck' (DataTypeClause n' Nothing) = return (n', [])
    typeCheck' (DataTypeClause n' (Just e)) = do
        fields <- typeCheck e >>= evalTypeExpression >>= typeToDotList
        return (n', fields)
-- | Annotated equations: the wrapped equation is checked under
-- 'setSrcSpan' with the annotation's location.
instance TypeCheckable PMatch Type where
    errorContext _ = Nothing
    typeCheck' annotated =
        setSrcSpan (loc annotated) (typeCheck (inner annotated))
-- | Type checks one equation of a function definition. The free variables
-- of the parameter patterns are brought into scope with 'local'. The result
-- is a curried type: one 'TFunction' per group of parameter patterns,
-- ending in the type of the right-hand side.
instance TypeCheckable Match Type where
    errorContext (Match _ _) = Nothing
    typeCheck' (Match groups exp) = do
        fvs <- liftM concat (mapM freeVars groups)
        local fvs $ do
            -- First pass over the patterns, run for its effects only; its
            -- result was never used.
            -- NOTE(review): the patterns are checked again after the
            -- right-hand side below. Presumably this is so evaluateDots
            -- sees constraints the right-hand side added. Confirm before
            -- removing either pass.
            mapM_ (mapM_ (\ pat -> typeCheck pat >>= evaluateDots)) groups
            tr <- typeCheck exp >>= evaluateDots
            tgroups <- mapM (mapM (\ pat -> typeCheck pat >>= evaluateDots))
                groups
            return $ foldr TFunction tr tgroups
-- | Applies a partial function to a key. If the key is outside the
-- function's domain, the given error is raised with 'raiseMessageAsError'.
applyPFOrError
    :: Eq a => Error -> PartialFunction a b -> a -> TypeCheckMonad b
applyPFOrError err pf key =
    maybe (raiseMessageAsError err) return (safeApply pf key)