module CSPM.DataStructures.Types where
import Control.Monad.Trans
import Data.IORef
import Data.List
import CSPM.DataStructures.Names
import Util.PartialFunctions
import Util.PrettyPrint
-- | A type variable, identified by an integer. Two variables are equal
-- exactly when their integers are equal.
newtype TypeVar = TypeVar Int
  deriving (Eq, Show)
-- | A type quantified over a list of type variables. Each quantified
-- variable is paired with the class constraints recorded for it.
data TypeScheme
  = ForAll [(TypeVar, [Constraint])] Type
  deriving (Eq, Show)
-- | Class constraints that can be attached to a type variable.
-- Constructor order matters: it determines the derived 'Ord' instance.
data Constraint
  = Eq
  | Ord
  deriving (Eq, Ord, Show)
-- | Types as represented by the type checker. The comment on each
-- constructor gives the form 'prettyPrintType' renders it in.
data Type
  = TVar TypeVarRef          -- ^ type variable: a letter, or its numeric id
  | TProc                    -- ^ @Proc@
  | TInt                     -- ^ @Int@
  | TBool                    -- ^ @Bool@
  | TEvent                   -- ^ @Event@
  | TEventable               -- ^ @Event or Channel@
  | TSet Type                -- ^ @{t}@
  | TSeq Type                -- ^ @\<t\>@
  | TDot Type Type           -- ^ @t1.t2@
  | TTuple [Type]            -- ^ @(t1, t2, ...)@
  | TFunction [Type] Type    -- ^ @(a1, a2, ...) -> r@
  | TDotable Type Type       -- ^ @t1=>t2@
  | TDatatype Name           -- ^ a named datatype, rendered as its name
  deriving (Eq, Show)
-- | A type variable together with its class constraints and a mutable
-- 'PType' cell, which holds either @Nothing@ or a 'Type' written by
-- 'setPType'.
data TypeVarRef
  = TypeVarRef TypeVar [Constraint] PType
-- | References are compared only by their 'TypeVar'. The constraint list
-- and the mutable cell are deliberately ignored ('IORef' contents cannot be
-- compared purely).
instance Eq TypeVarRef where
  TypeVarRef a _ _ == TypeVarRef b _ _ = a == b
-- | Debug rendering of a type variable reference. The mutable cell cannot be
-- shown, so only the variable and its constraints are printed, e.g.
-- @TypeVarRef (TypeVar 1) [Eq]@.
--
-- 'showsPrec' is implemented (rather than 'show') so that the fields are
-- separated by spaces and the whole value is parenthesised when nested. This
-- matters inside the derived 'Show' for 'Type', which yields
-- @TVar (TypeVarRef (TypeVar 1) [Eq])@. The previous 'show'-based version
-- produced the ambiguous @TVar TypeVarRef TypeVar 1[Eq]@.
instance Show TypeVarRef where
  showsPrec d (TypeVarRef tv cs _) =
    showParen (d > 10) $
      showString "TypeVarRef "
        . showsPrec 11 tv
        . showChar ' '
        . showsPrec 11 cs
-- | A wrapper around an optional value.
-- NOTE(review): not referenced anywhere in this module; confirm whether
-- other modules use it before removing.
newtype IORefMaybe a
  = IORefMaybe (Maybe a)
-- | Mutable cell holding either @Nothing@ or a 'Type'.
type PType = IORef (Maybe Type)

-- | Association from each 'Name' to its 'TypeScheme'.
type SymbolTable = PartialFunction Name TypeScheme

-- | Mutable reference to a 'SymbolTable'.
type PSymbolTable = IORef SymbolTable
-- | Reads the current contents of a 'PType' cell in any 'MonadIO'.
readPType :: (MonadIO m) => PType -> m (Maybe Type)
readPType = liftIO . readIORef

-- | Overwrites a 'PType' cell so that it holds @Just@ the given type.
setPType :: (MonadIO m) => PType -> Type -> m ()
setPType cell = liftIO . writeIORef cell . Just

-- | Allocates a new 'PType' cell initialised to @Nothing@.
freshPType :: (MonadIO m) => m PType
freshPType = liftIO (newIORef Nothing)
-- | Reads the symbol table currently stored in the reference.
readPSymbolTable :: (MonadIO m) => PSymbolTable -> m SymbolTable
readPSymbolTable = liftIO . readIORef

-- | Replaces the stored symbol table with the given one.
setPSymbolTable :: (MonadIO m) => PSymbolTable -> SymbolTable -> m ()
setPSymbolTable ref = liftIO . writeIORef ref

-- | Allocates a reference holding an empty symbol table.
-- NOTE(review): relies on 'PartialFunction' being list-based, since the
-- initial value is @[]@.
freshPSymbolTable :: (MonadIO m) => m PSymbolTable
freshPSymbolTable = liftIO (newIORef [])
-- | A constraint prints as its class name. The derived 'Show' instance
-- renders the constructors as exactly @Eq@ and @Ord@.
instance PrettyPrintable Constraint where
  prettyPrint = text . show
-- | Pretty prints several types using one shared naming of type variables,
-- so a variable occurring in more than one of the types gets the same letter
-- everywhere. Variables are lettered @a@..@z@ in order of first appearance
-- across the list. Variables beyond the 26th have no letter and are left to
-- 'prettyPrintType''s numeric fallback.
prettyPrintTypes :: [Type] -> [Doc]
prettyPrintTypes ts = [prettyPrintType letters t | t <- ts]
  where
    varIds = nub [n | t <- ts, (TypeVar n, _) <- collectConstraints t]
    letters = zip varIds ['a' .. 'z']
-- | A bare type prints as a scheme quantified over its own variables, so any
-- constraints on those variables appear as a leading context.
instance PrettyPrintable Type where
  prettyPrint ty = prettyPrint scheme
    where
      scheme = ForAll (collectConstraints ty) ty
-- | Renders a type scheme, preceded by its class constraints when it has
-- any: @Eq a => ...@ for a single constraint, or @(Eq a, Ord b) => ...@ for
-- several.
--
-- Quantified variables are lettered @a@..@z@ in order of first appearance
-- in the binder list. Any beyond the 26th are shown by their numeric id,
-- matching 'prettyPrintType'.
instance PrettyPrintable TypeScheme where
  prettyPrint (ForAll ts t) = context <> prettyPrintType vmap t
    where
      -- Deduplicate before lettering so that a variable listed twice does
      -- not consume two letters.
      vmap = zip (nub [n | (TypeVar n, _) <- ts]) ['a' .. 'z']

      -- One entry per (variable, constraint) pair, without repeats, so a
      -- duplicated binder cannot produce "(Eq a, Eq a) =>".
      varsWithCs = nub [(v, c) | (v, cs) <- ts, c <- cs]

      context = case varsWithCs of
        []  -> empty
        [_] -> constraintsText <+> text "=> "
        _   -> parens constraintsText <+> text "=> "

      constraintsText =
        hsep (punctuate comma
          [prettyPrint c <+> varDoc n | (TypeVar n, c) <- varsWithCs])

      -- Total lookup. The previous 'apply' crashed for schemes with more
      -- than 26 variables, because those have no entry in vmap.
      varDoc n = maybe (int n) char (safeApply vmap n)
-- | Renders a type using the given mapping from type-variable ids to letters.
-- A variable missing from the mapping is rendered as its numeric id.
-- A dotable type on the left of a dot is parenthesised, so the nesting is
-- visible.
prettyPrintType :: PartialFunction Int Char -> Type -> Doc
prettyPrintType vmap = go
  where
    go ty = case ty of
      TVar (TypeVarRef (TypeVar n) _ _) ->
        maybe (int n) char (safeApply vmap n)
      TFunction targs tr -> commaList targs <+> text "->" <+> go tr
      TSeq t -> char '<' <> go t <> char '>'
      TSet t -> char '{' <> go t <> char '}'
      TTuple ts -> commaList ts
      TDot t1@(TDotable _ _) t2 -> parens (go t1) <> text "." <> go t2
      TDot t1 t2 -> go t1 <> text "." <> go t2
      TDotable t1 t2 -> go t1 <> text "=>" <> go t2
      TDatatype (Name n) -> text n
      TBool -> text "Bool"
      TInt -> text "Int"
      TProc -> text "Proc"
      TEvent -> text "Event"
      TEventable -> text "Event or Channel"

    -- Parenthesised, comma-separated list, shared by tuples and function
    -- argument lists.
    commaList = parens . hsep . punctuate comma . map go
-- | Collects every type variable occurring in a type, together with its
-- constraints. Each variable appears exactly once, in order of first
-- occurrence. Its constraints are the union, without duplicates and in
-- first-seen order, of the constraints attached to all of its occurrences.
collectConstraints :: Type -> [(TypeVar, [Constraint])]
collectConstraints = combine . collect
  where
    -- Merge all entries for the same variable. 'groupBy' only merges
    -- adjacent entries. A variable that occurs non-contiguously (e.g. the
    -- @a@ in @(a, b) -> a@) would therefore be reported twice, with its
    -- constraints split between the copies. Comparing against every entry
    -- avoids that. 'TypeVar' has no 'Ord' instance, hence the scan that is
    -- quadratic in the number of occurrences.
    combine :: [(TypeVar, [Constraint])] -> [(TypeVar, [Constraint])]
    combine xs =
      [ (v, nub (concat [cs | (v', cs) <- xs, v' == v]))
      | v <- nub (map fst xs)
      ]

    -- Every variable occurrence, left to right, with the constraints
    -- recorded on that occurrence.
    collect :: Type -> [(TypeVar, [Constraint])]
    collect (TVar (TypeVarRef v cs _)) = [(v, cs)]
    collect (TFunction targs tr) = concatMap collect targs ++ collect tr
    collect (TSeq t) = collect t
    collect (TSet t) = collect t
    collect (TTuple ts) = concatMap collect ts
    collect (TDot t1 t2) = collect t1 ++ collect t2
    collect (TDotable t1 t2) = collect t1 ++ collect t2
    collect (TDatatype _) = []
    collect TBool = []
    collect TInt = []
    collect TProc = []
    collect TEvent = []
    collect TEventable = []