module CSPM.TypeChecker.Decl (typeCheckDecls) where
import Control.Monad
import Data.Graph.Wrapper
import qualified Data.Map as M
import qualified Data.Set as S
import Data.List (intersect, (\\), sortBy)
import CSPM.DataStructures.FreeVars
import CSPM.DataStructures.Names
import CSPM.DataStructures.Syntax hiding (getType)
import CSPM.DataStructures.Types
import CSPM.PrettyPrinter
import CSPM.TypeChecker.Common
import CSPM.TypeChecker.Exceptions
import CSPM.TypeChecker.Expr()
import CSPM.TypeChecker.Monad
import CSPM.TypeChecker.Pat()
import CSPM.TypeChecker.Unification
import Util.Annotated
import Util.List
import Util.PartialFunctions
import Util.PrettyPrint
-- | Type checks a list of declarations that share a scope.
--
-- The declarations are first flattened: the contents of a timed section are
-- listed next to the section itself, and module bodies are flattened
-- recursively (each module stays a single declaration). A dependency graph
-- between the flattened declarations is built from their free variables,
-- split into strongly connected components, and each component is checked
-- as one mutually recursive group in topological order. When a component
-- fails, every component reachable from it in the transposed SCC graph is
-- skipped and the whole check fails after the remaining components have
-- been processed. Finally, the types of the names bound by each original
-- declaration are written into its symbol-table annotation.
typeCheckDecls :: [TCDecl] -> TypeCheckMonad ()
typeCheckDecls decls = do
    let flattenDecl :: TCDecl -> [TCDecl]
        flattenDecl (An a b (Module mn args ds1 ds2)) =
            [An a b (Module mn args
                (concatMap flattenDecl ds1)
                (concatMap flattenDecl ds2))]
        flattenDecl (d@(An _ _ (TimedSection _ _ ds))) =
            d : concatMap flattenDecl ds
        flattenDecl x = [x]
        flatDecls = concatMap flattenDecl decls
        -- Each flattened declaration is identified by its position.
        declMap = zip flatDecls [0..]
        invDeclMap = invert declMap
        isTimedSection (An _ _ (TimedSection _ _ _)) = True
        isTimedSection _ = False
        -- Timed sections contribute no bound names here; their contents
        -- appear as separate entries in flatDecls.
        namesBoundByDecls =
            [ (declId, boundNames decl)
            | (decl, declId) <- declMap, not (isTimedSection decl) ]
        varToDeclIdMap =
            [(n, declId) | (declId, ns) <- namesBoundByDecls, n <- ns]
        boundVars = map fst varToDeclIdMap
    unless (noDups boundVars) $ panic "Duplicates found after renaming."
    let -- Each declaration depends on the declarations of this group that
        -- bind one of its free variables.
        declDeps =
            [ (declId, mapPF varToDeclIdMap (intersect (freeVars decl) boundVars))
            | (decl, declId) <- declMap ]
        declGraph :: Graph Int Int
        declGraph = fromListSimple declDeps
        sccgraph :: Graph (S.Set Int) (M.Map Int Int)
        sccgraph = transpose (sccGraph declGraph)
        sccs :: [S.Set Int]
        sccs = topologicalSort sccgraph
        typeInferenceGroup = mapPF invDeclMap

        -- Raises an error when a module and an instance of that module are
        -- in the same strongly connected component (so they depend on each
        -- other). Only the first instance of each module is examined.
        checkSCCForModuleCycles :: [TCDecl] -> TypeCheckMonad ()
        checkSCCForModuleCycles groupDecls =
            let instances =
                    [i | An _ _ (i@(ModuleInstance _ _ _ _ _)) <- groupDecls]
                mods = [m | An _ _ (m@(Module _ _ _ _)) <- groupDecls]
                instancesOfMod n =
                    [i | i@(ModuleInstance _ nt _ _ _) <- instances, nt == n]
                checkMod (Module nm _ _ _) =
                    case instancesOfMod nm of
                        [] -> return ()
                        (ModuleInstance _ _ _ instanceMap _ : _) ->
                            mapM_ (reportCycle nm) (map fst instanceMap)
                reportCycle nm n =
                    raiseMessageAsError $
                        illegalModuleInstanceCycleErrorMessage nm n
                            (mapPF (invert varToDeclIdMap)
                                (pathBetweenVertices declGraph
                                    (apply varToDeclIdMap nm)
                                    (apply varToDeclIdMap n)))
                -- Depth-first search for a path from start to goal in g.
                -- Returns start followed by the path's vertices (ending in
                -- goal), or just [start] if goal is unreachable. Each vertex
                -- is expanded at most once, so the search terminates even
                -- when g contains other cycles (e.g. recursive functions).
                pathBetweenVertices :: Ord i => Graph i v -> i -> i -> [i]
                pathBetweenVertices g start goal =
                    let findPath visited i
                            | i `S.member` visited = ([], visited)
                            | otherwise =
                                let scs = successors g i
                                    visited' = S.insert i visited
                                in if goal `elem` scs
                                    then ([goal], visited')
                                    else check visited' scs
                        check visited [] = ([], visited)
                        check visited (x:xs) =
                            let (path, visited') = findPath visited x
                            in case path of
                                [] -> check visited' xs
                                _ -> (x : path, visited')
                    in start : fst (findPath S.empty start)
            in mapM_ checkMod mods

        -- Checks one strongly connected component as a single mutually
        -- recursive group; returns False when no error was raised.
        checkGroup g = do
            let ds = typeInferenceGroup (S.toList g)
            checkSCCForModuleCycles ds
            typeCheckMutualyRecursiveGroup ds
            return False

        -- Checks the components in order. The recovery action returns True,
        -- so 'failed' is True when a component raised an error; then every
        -- component reachable from it in sccgraph is dropped (presumably its
        -- dependents) and the whole check fails at the end.
        typeCheckGroups [] hadError = when hadError failM
        typeCheckGroups (g:gs) hadError = do
            failed <- tryAndRecover True (checkGroup g) (return True)
            if failed
                then typeCheckGroups (gs \\ reachableVertices sccgraph g) True
                else typeCheckGroups gs hadError
    typeCheckGroups sccs False
    -- Store the final types of the names bound by every original
    -- (unflattened) declaration, recursing into modules and timed sections.
    let annotate (decl@(An _ psymbtable d)) = do
            let ns = boundNames decl
            ts <- mapM getType ns
            setPSymbolTable (snd psymbtable) (zip ns ts)
            mapM_ annotate (nestedDecls d)
        nestedDecls (Module _ _ ds1 ds2) = ds1 ++ ds2
        nestedDecls (TimedSection _ _ ds) = ds
        nestedDecls _ = []
    mapM_ annotate decls
-- | Type checks a group of mutually recursive declarations together.
--
-- Data type declarations are moved to the front of the group, with the
-- relative order of the remaining declarations kept. Every name bound by
-- the group first gets a fresh monomorphic type variable. Each declaration
-- is then checked under 'generaliseGroup'; a declaration that binds names is
-- checked inside 'addDefinitionName' for its first bound name. Finally, each
-- name's type scheme is compressed and stored back.
typeCheckMutualyRecursiveGroup :: [TCDecl] -> TypeCheckMonad ()
typeCheckMutualyRecursiveGroup ds' = do
    let isDataType d = case unAnnotate d of
            DataType _ _ -> True
            _ -> False
        -- Stable partition: data types first, both halves in original order.
        ds = filter isDataType ds' ++ filter (not . isDataType) ds'
        fvs = boundNames ds
    ftvs <- replicateM (length fvs) freshTypeVar
    zipWithM_ setType fvs (map (ForAll []) ftvs)
    _ <- generaliseGroup fvs $ map (\ d ->
        case boundNames d of
            [] -> typeCheck d
            (n:_) -> addDefinitionName n (typeCheck d)) ds
    forM_ fvs $ \ n -> do
        t <- getType n
        t' <- compressTypeScheme t
        setType n t'
-- | Converts the type of an expression used as a type (such as a set
-- expression in a channel or data type declaration) into the type of the
-- values it denotes. Tuples and dotted types are converted component-wise.
-- Any other type is unified with a set of a fresh type variable, and that
-- variable is returned.
evalTypeExpression :: Type -> TypeCheckMonad Type
evalTypeExpression ty = case ty of
    TTuple components -> liftM TTuple (mapM evalTypeExpression components)
    TDot left right -> liftM2 TDot (evalTypeExpression left)
                                   (evalTypeExpression right)
    _ -> do
        elemType <- freshTypeVar
        unify ty (TSet elemType)
        return elemType
-- | Annotated declarations: check the inner declaration with the source
-- span of the annotation set.
instance TypeCheckable TCDecl [(Name, Type)] where
    errorContext _ = Nothing
    typeCheck' decl = setSrcSpan (loc decl) (typeCheck (inner decl))
-- | Type checking of a single declaration. Each equation returns the names
-- the declaration binds, paired with their (monomorphic) types.
instance TypeCheckable (Decl Name) [(Name, Type)] where
    errorContext (FunBind n _ _) = Just $
        text "In the declaration of:" <+> prettyPrint n
    errorContext (p@(PatBind _ _ _)) = Just $
        hang (text "In a pattern binding:") tabWidth (prettyPrint p)
    errorContext (DataType n _) = Just $
        text "In the declaration of:" <+> prettyPrint n
    errorContext (SubType n _) = Just $
        text "In the declaration of:" <+> prettyPrint n
    errorContext (NameType n _) = Just $
        text "In the declaration of:" <+> prettyPrint n
    errorContext (Channel ns _) = Just $
        text "In the declaration of:" <+> list (map prettyPrint ns)
    errorContext (Assert a) = Just $
        text "In the assertion:" <+> prettyPrint a
    errorContext (TimedSection _ _ _) = Nothing
    errorContext (Transparent _) = Nothing
    errorContext (External _) = Nothing
    errorContext (ModuleInstance n _ _ _ _) = Just $
        text "In the declaration of the module instance:" <+> prettyPrint n
    errorContext (Module _ _ _ _) = Nothing
    errorContext (PrintStatement _) = Nothing

    -- Each equation is checked (against the optional signature, whose type
    -- variables are in scope), then all equations are unified with the
    -- function's current type.
    typeCheck' (FunBind n ms mta) = do
        let boundTypeVars =
                case mta of
                    Just (An _ _ (STypeScheme boundNs _ _)) -> boundNs
                    _ -> []
        ts <- local boundTypeVars $ do
            msig <- case mta of
                Just ta -> do
                    ForAll _ t <- typeCheck ta
                    return $ Just t
                Nothing -> return Nothing
            mapM (\ m -> addErrorContext (matchCtxt m) $
                case msig of
                    Nothing -> typeCheck m
                    Just ta -> typeCheckExpect m ta) ms
        ForAll [] t <- getType n
        -- The pattern only asserts that a function type was produced. The
        -- as-pattern must be written without spaces around '@' (GHC >= 9.0).
        t'@(TFunction _ _) <- unifyAll (t:ts)
        return [(n, t')]
      where
        matchCtxt an =
            hang (text "In an equation for" <+> prettyPrint n <> colon)
                tabWidth (prettyPrintMatch n an)

    typeCheck' (p@(PatBind pat exp mta)) = do
        (tpat, texp) <-
            case mta of
                Nothing -> do
                    tpat <- typeCheck pat
                    texp <- typeCheck exp
                    return (tpat, texp)
                Just ta -> do
                    ForAll _ typ <- typeCheck ta
                    tpat <- typeCheckExpect pat typ
                    texp <- typeCheckExpect exp typ
                    return (tpat, texp)
        -- The expression type is evaluated first, and the pattern type is
        -- passed first to the asymmetric unification. This is exactly the
        -- order used before the variables were given their correct names.
        -- NOTE(review): confirm this orientation of 'unify' is intended.
        texp' <- evaluateDots texp
        tpat' <- evaluateDots tpat
        disallowSymmetricUnification (unify tpat' texp')
        let ns = boundNames p
        ts <- mapM getType ns
        return $ zip ns [t | ForAll _ t <- ts]

    typeCheck' (Channel ns Nothing) = do
        mapM_ (\ n -> do
            ForAll [] t <- getType n
            unify TEvent t) ns
        return [(n, TEvent) | n <- ns]
    typeCheck' (Channel ns (Just e)) = do
        t <- typeCheck e
        ensureHasConstraint CEq t
        valueType <- evalTypeExpression t
        dotList <- typeToDotList valueType
        let chanType = foldr TDotable TEvent dotList
        mapM_ (\ n -> do
            ForAll [] t' <- getType n
            unify t' chanType) ns
        return [(n, chanType) | n <- ns]

    -- Every clause constructor must build a value of one common (extendable)
    -- parent type; the subtype name denotes a set of that type.
    typeCheck' (SubType n clauses) = do
        parentType <- freshTypeVar
        forM_ clauses $ \ clause -> do
            let nclause = case unAnnotate clause of
                    DataTypeClause x _ -> x
            (_, tsFields) <- typeCheck clause
            ForAll [] typeCon <- getType nclause
            (actFields, dataType) <- dotableToDotList typeCon
            tvref' <- freshTypeVarRef []
            unify (TExtendable parentType tvref') dataType
            zipWithM_ unify actFields tsFields
        ForAll [] t <- getType n
        _ <- unify t (TSet parentType)
        return [(n, TSet parentType)]

    -- The datatype name denotes a set of the datatype; each constructor is a
    -- dotable function from its field types to the datatype. The datatype is
    -- marked comparable for equality only if all field types support it.
    typeCheck' (DataType n clauses) = do
        ForAll [] t <- getType n
        unify t (TSet (TDatatype n))
        ntss <- forM clauses $ \ clause -> do
            let n' = case unAnnotate clause of
                    DataTypeClause x _ -> x
            ForAll [] tcon <- getType n'
            (clauseName, ts) <- typeCheck clause
            let texp = foldr TDotable (TDatatype n) ts
            tcon' <- unify texp tcon
            return ((clauseName, tcon'), ts)
        let (nts, tcss) = unzip ntss
        tclauses <- mapM compress (concat tcss)
        markDatatypeAsComparableForEquality n
        b <- tryAndRecover False
            (mapM_ (ensureHasConstraint CEq) tclauses >> return True)
            (return False)
        unless b $ unmarkDatatypeAsComparableForEquality n
        ForAll [] t' <- getType n
        t'' <- unify t' (TSet (TDatatype n))
        return $ (n, t''):nts

    typeCheck' (NameType n e) = do
        t <- typeCheck e
        valueType <- evalTypeExpression t
        return [(n, TSet valueType)]
    typeCheck' (Transparent _) = return []
    typeCheck' (External _) = return []
    typeCheck' (Assert a) = typeCheck a >> return []

    -- NOTE(review): there is no equation for a timed section whose first
    -- field is Nothing; presumably the renamer always fills it in - confirm.
    typeCheck' (TimedSection (Just tn) f _) = do
        typeCheckExpect (Var tn) TEvent
        case f of
            Just tockFn ->
                typeCheckExpect tockFn (TFunction [TEvent] TInt) >> return ()
            Nothing -> return ()
        return []

    -- The module's own type is the tuple of its argument types.
    typeCheck' (Module n args pubDs privDs) = do
        let fvs = boundNames args
        local fvs $ do
            -- Check the arguments before the body; the argument types are
            -- re-read afterwards, once the body has constrained them.
            mapM_ (\ pat -> typeCheck pat >>= evaluateDots) args
            typeCheckDecls (pubDs ++ privDs)
            tpats <- mapM (\ pat -> typeCheck pat >>= evaluateDots) args
            return [(n, TTuple tpats)]

    -- NOTE(review): there is no equation for an instance without its module
    -- declaration (Nothing in the last field) - confirm this cannot occur.
    typeCheck' (ModuleInstance n nt args nm (Just modDecl)) = do
        stk <- getDefinitionStack
        when (nt `elem` stk) $ raiseMessageAsError $
            illegalModuleInstanceCycleErrorMessage nt n $ nt :
                reverse (takeWhile (/= nt) stk)
        modScheme <- getType nt
        (TTuple argTypes, instSub) <- instantiate' modScheme
        zipWithM_ typeCheckExpect args argTypes
        -- nm pairs each of our names with the module's name it copies;
        -- subName maps a module name to our copy (or leaves it unchanged).
        let subName name = case safeApply (invert nm) name of
                Just name' -> name'
                Nothing -> name
            -- Resolves bound type variables and renames datatypes declared
            -- inside the module to their instance copies.
            renameTy (TVar tvref) = do
                res <- readTypeRef tvref
                case res of
                    Left _ -> return $ TVar tvref
                    Right t -> renameTy t
            renameTy (TSet t) = renameTy t >>= return . TSet
            renameTy (TSeq t) = renameTy t >>= return . TSeq
            renameTy (TDot t1 t2) = do
                t1' <- renameTy t1
                t2' <- renameTy t2
                return $! TDot t1' t2'
            renameTy (TMap t1 t2) = do
                t1' <- renameTy t1
                t2' <- renameTy t2
                return $! TMap t1' t2'
            renameTy (TTuple ts) = mapM renameTy ts >>= return . TTuple
            renameTy (TFunction ts t) = do
                ts' <- mapM renameTy ts
                t' <- renameTy t
                return $! TFunction ts' t'
            renameTy (TDatatype dn) = return $ TDatatype $! subName dn
            renameTy (TDotable t1 t2) = do
                t1' <- renameTy t1
                t2' <- renameTy t2
                return $! TDotable t1' t2'
            renameTy (TExtendable t tvref) = do
                t' <- renameTy t
                res <- readTypeRef tvref
                case res of
                    -- Previously returned the unrenamed t here.
                    Left _ -> return $ TExtendable t' tvref
                    Right tv -> do
                        tv' <- renameTy tv
                        writeTypeRef tvref tv'
                        return $ TExtendable t' tvref
            renameTy TInt = return TInt
            renameTy TBool = return TBool
            renameTy TProc = return TProc
            renameTy TEvent = return TEvent
            renameTy TChar = return TChar
            renameTy TExtendableEmptyDotList = return TExtendableEmptyDotList
        nts <- forM nm $ \ (ourName, theirName) -> do
            ForAll xs t <- getType theirName
            t' <- substituteTypes instSub t
            t'' <- renameTy t'
            setType ourName $ ForAll xs t''
            return (ourName, t'')
        -- Re-derive equality comparability for the copied datatypes.
        let An _ _ (Module _ _ pubDs privDs) = modDecl
        forM_ (pubDs ++ privDs) $ \ d -> case unAnnotate d of
            DataType dn clauses -> do
                let dn' = subName dn
                tclauses <- forM clauses $ \ (An _ _ (DataTypeClause cn _)) -> do
                    ForAll _ t <- getType (subName cn)
                    return t
                markDatatypeAsComparableForEquality dn'
                b <- tryAndRecover False
                    (mapM_ (ensureHasConstraint CEq) tclauses >> return True)
                    (return False)
                unless b $ unmarkDatatypeAsComparableForEquality dn'
            _ -> return ()
        return nts
    typeCheck' (PrintStatement _) = return []
-- | Annotated assertions: check the inner assertion with the source span of
-- the annotation set.
instance TypeCheckable TCAssertion () where
    errorContext _ = Nothing
    typeCheck' assertion =
        setSrcSpan (loc assertion) (typeCheck (inner assertion))
-- | Assertions: every process operand must be a process; refinement model
-- options are checked too.
instance TypeCheckable (Assertion Name) () where
    errorContext assertion = Just $
        hang (text "In the assertion" <> colon) tabWidth (prettyPrint assertion)
    typeCheck' (PropertyCheck proc _ _) = ensureIsProc proc >> return ()
    typeCheck' (Refinement spec _ impl opts) =
        mapM_ ensureIsProc [spec, impl] >> mapM_ typeCheck opts
    typeCheck' (ASNot negated) = typeCheck negated
-- | Model options: a tau-priority option must be a set of events.
instance TypeCheckable (ModelOption Name) () where
    errorContext _ = Nothing
    typeCheck' (TauPriority events) =
        typeCheckExpect events (TSet TEvent) >> return ()
-- | Annotated data type clauses: check the inner clause with the source span
-- of the annotation set.
instance TypeCheckable TCDataTypeClause (Name, [Type]) where
    errorContext _ = Nothing
    typeCheck' clause = setSrcSpan (loc clause) (typeCheck (inner clause))
-- | Data type clauses: returns the constructor name and the types of its
-- fields (none when the clause has no field expression).
instance TypeCheckable (DataTypeClause Name) (Name, [Type]) where
    errorContext clause = Just $
        hang (text "In the data type clause" <> colon) tabWidth
            (prettyPrint clause)
    typeCheck' (DataTypeClause name fieldsExpr) =
        case fieldsExpr of
            Nothing -> return (name, [])
            Just e -> do
                exprType <- typeCheck e
                fields <- evalTypeExpression exprType >>= typeToDotList
                return (name, fields)
-- | Annotated matches: check (or check against an expected type) the inner
-- match with the source span of the annotation set.
instance TypeCheckable TCMatch Type where
    errorContext _ = Nothing
    typeCheck' m = setSrcSpan (loc m) (typeCheck (inner m))
    typeCheckExpect m expected =
        setSrcSpan (loc m) (typeCheckExpect (inner m) expected)
-- | A single equation of a (possibly curried) function: the result is
-- TFunction args1 (TFunction args2 (... result)), one argument list per
-- pattern group. The pattern types are re-read after the body has been
-- checked, so that they reflect what the body inferred.
instance TypeCheckable (Match Name) Type where
    errorContext _ = Nothing
    typeCheck' (Match groups body) =
        local (boundNames groups) $ do
            _ <- patternTypes
            resultType <- typeCheck body >>= evaluateDots
            argTypes <- patternTypes
            return (foldr TFunction resultType argTypes)
      where
        patternTypes =
            mapM (mapM (\ pat -> typeCheck pat >>= evaluateDots)) groups
    typeCheckExpect (Match groups body) tsig =
        local (boundNames groups) $ do
            rt <- freshTypeVar
            expectedArgs <- mapM (\ group -> replicateM (length group) freshTypeVar) groups
            unify tsig (foldr TFunction rt expectedArgs)
            zipWithM_ (zipWithM_ (\ pat argt ->
                typeCheckExpect pat argt >>= evaluateDots)) groups expectedArgs
            resultType <- typeCheckExpect body rt >>= evaluateDots
            argTypes <- patternTypes
            return (foldr TFunction resultType argTypes)
      where
        patternTypes =
            mapM (mapM (\ pat -> typeCheck pat >>= evaluateDots)) groups
-- | Annotated syntactic type schemes: check the inner scheme with the source
-- span of the annotation set.
instance TypeCheckable TCSTypeScheme TypeScheme where
    errorContext _ = Nothing
    typeCheck' scheme = setSrcSpan (loc scheme) (typeCheck (inner scheme))
-- | Syntactic type schemes: each bound name becomes a fresh rigid type
-- variable carrying the constraints written for it, and is given that
-- variable as its (monomorphic) type while the body is checked.
instance TypeCheckable (STypeScheme Name) TypeScheme where
    errorContext _ = Nothing
    typeCheck' (STypeScheme boundNs cs body) = do
        let constraintsOn n =
                [c | STypeConstraint c n' <- map unAnnotate cs, n == n']
        tyVars <- forM boundNs $ \ n -> do
            tv@(TVar tvref) <- freshRigidTypeVarWithConstraints n (constraintsOn n)
            setType n (ForAll [] tv)
            return (typeVar tvref, constraints tvref)
        liftM (ForAll tyVars) (typeCheck body)
-- | Annotated syntactic types: check the inner type with the source span of
-- the annotation set.
instance TypeCheckable TCSType Type where
    errorContext _ = Nothing
    typeCheck' stype = setSrcSpan (loc stype) (typeCheck (inner stype))
-- | Syntactic types: translated structurally into 'Type'. Type variable
-- names are looked up and must have a monomorphic type.
instance TypeCheckable (SType Name) Type where
    errorContext _ = Nothing
    typeCheck' stype = case stype of
        STVar var -> monoTypeOf var
        STExtendable body var -> do
            bodyType <- typeCheck body
            TVar tvref <- monoTypeOf var
            return $ TExtendable bodyType tvref
        STSet elemType -> liftM TSet (typeCheck elemType)
        STSeq elemType -> liftM TSeq (typeCheck elemType)
        STDot left right -> liftM2 TDot (typeCheck left) (typeCheck right)
        STMap keyType valType ->
            liftM2 TMap (typeCheck keyType) (typeCheck valType)
        STTuple components -> liftM TTuple (mapM typeCheck components)
        STFunction args rt -> liftM2 TFunction (mapM typeCheck args) (typeCheck rt)
        STDotable left right ->
            liftM2 TDotable (typeCheck left) (typeCheck right)
        STParen body -> typeCheck body
        STDatatype n -> return (TDatatype n)
        STProc -> return TProc
        STInt -> return TInt
        STBool -> return TBool
        STChar -> return TChar
        STEvent -> return TEvent
      where
        monoTypeOf var = getType var >>= \ (ForAll [] t) -> return t