{-# LANGUAGE ExistentialQuantification #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE Safe #-}
module Types.TypeInstance (
AnyTypeResolver(..),
CategoryName(..),
DefinesInstance(..),
FilterDirection(..),
GeneralInstance,
InstanceFilters,
InstanceParams,
InstanceVariances,
ParamFilters,
ParamVariances,
ParamName(..),
StorageType(..),
TypeFilter(..),
TypeInstance(..),
TypeInstanceOrParam(..),
TypeResolver(..),
ValueType(..),
checkDefinesMatch,
checkGeneralMatch,
checkValueTypeMatch,
uncheckedSubFilter,
uncheckedSubFilters,
uncheckedSubInstance,
uncheckedSubValueType,
getValueForParam,
isBuiltinCategory,
isDefinesFilter,
isRequiresFilter,
isWeakValue,
requiredParam,
requiredSingleton,
validateAssignment,
validateDefinesInstance,
validateDefinesVariance,
validateGeneralInstance,
validateInstanceVariance,
validateTypeFilter,
validateTypeInstance,
) where
import Control.Monad (when)
import Data.List (intercalate)
import qualified Data.Map as Map
import Base.CompileError
import Base.Mergeable
import Types.GeneralType
import Types.Positional
import Types.Variance
-- | A type expression: a single category instance or param, or a union /
-- intersection of nested type expressions.
type GeneralInstance = GeneralType TypeInstanceOrParam

-- | Renders a type expression. The empty union is shown as @all@ and the
-- empty intersection as @any@; non-empty merges are bracketed and joined with
-- @|@ (union) or @&@ (intersection).
instance Show GeneralInstance where
  show (SingleType t) = show t
  show (TypeMerge MergeUnion []) = "all"
  show (TypeMerge MergeIntersect []) = "any"
  show (TypeMerge m ts) = "[" ++ intercalate (separator m) (map show ts) ++ "]"
    where
      separator MergeUnion = "|"
      separator MergeIntersect = "&"
-- | Storage requirement of a value. Constructor order is significant: the
-- derived 'Ord' gives @WeakValue < OptionalValue < RequiredValue@, and
-- 'checkValueTypeMatch' rejects a conversion whose source storage compares
-- less than its target storage.
data StorageType =
  WeakValue |
  OptionalValue |
  RequiredValue
  deriving (Eq,Ord)

-- | A value's type expression together with its storage requirement.
data ValueType =
  ValueType {
    -- | Weak, optional, or required storage.
    vtRequired :: StorageType,
    -- | The value's type expression.
    vtType :: GeneralInstance
  }
  deriving (Eq,Ord)
-- | Renders a value type, prefixed by its storage qualifier ("weak " or
-- "optional "); required values have no prefix.
instance Show ValueType where
  show (ValueType s t) = prefix s ++ show t
    where
      prefix WeakValue = "weak "
      prefix OptionalValue = "optional "
      prefix RequiredValue = ""
-- | True exactly when the value's storage is 'WeakValue'.
isWeakValue :: ValueType -> Bool
isWeakValue v = vtRequired v == WeakValue

-- | A required value whose type is the given category with no type params.
requiredSingleton :: CategoryName -> ValueType
requiredSingleton n =
  let unparameterized = TypeInstance n (Positional [])
   in ValueType RequiredValue (SingleType (JustTypeInstance unparameterized))

-- | A required value whose type is the given param.
requiredParam :: ParamName -> ValueType
requiredParam = ValueType RequiredValue . SingleType . JustParamName
-- | The name of a type category: a user-defined name, one of the built-in
-- categories, or the 'CategoryNone' placeholder.
data CategoryName =
  CategoryName {
    tnName :: String
  } |
  BuiltinBool |
  BuiltinChar |
  BuiltinInt |
  BuiltinFloat |
  BuiltinString |
  BuiltinFormatted |
  CategoryNone

-- | Renders the category's source-level name.
instance Show CategoryName where
  show name =
    case name of
      CategoryName n   -> n
      BuiltinBool      -> "Bool"
      BuiltinChar      -> "Char"
      BuiltinInt       -> "Int"
      BuiltinFloat     -> "Float"
      BuiltinString    -> "String"
      BuiltinFormatted -> "Formatted"
      CategoryNone     -> "(none)"

-- | Names compare by their rendered text, so e.g. @CategoryName "Bool"@ is
-- equal to 'BuiltinBool'.
instance Eq CategoryName where
  (==) a b = show a == show b

-- | Ordering by rendered text, consistent with the 'Eq' instance.
instance Ord CategoryName where
  compare a b = compare (show a) (show b)

-- | NOTE(review): returns False for every constructor, including the
-- @Builtin*@ ones; confirm this is intended.
isBuiltinCategory :: CategoryName -> Bool
isBuiltinCategory = const False
-- | The name of a type parameter.
newtype ParamName =
  ParamName {
    pnName :: String
  }
  deriving (Eq,Ord)

-- | Shows the raw parameter name.
instance Show ParamName where
  show = pnName
-- | A category applied to positional type arguments. Field order matters for
-- the derived 'Ord'.
data TypeInstance =
  TypeInstance {
    -- | The category being instantiated.
    tiName :: CategoryName,
    -- | Positional type arguments.
    tiParams :: InstanceParams
  }
  deriving (Eq,Ord)
-- | Renders an instance as @Name@ or @Name<P1,P2,...>@.
instance Show TypeInstance where
  show (TypeInstance n (Positional ps))
    | null ps = show n
    | otherwise = show n ++ "<" ++ intercalate "," (map show ps) ++ ">"
-- | A category applied to positional type arguments, as used in a
-- @defines@ filter. Field order matters for the derived 'Ord'.
data DefinesInstance =
  DefinesInstance {
    -- | The category being instantiated.
    diName :: CategoryName,
    -- | Positional type arguments.
    diParams :: InstanceParams
  }
  deriving (Eq,Ord)
-- | Renders a defines-instance as @Name@ or @Name<P1,P2,...>@.
instance Show DefinesInstance where
  show (DefinesInstance n (Positional ps)) =
    case ps of
      [] -> show n
      _  -> show n ++ "<" ++ intercalate "," (map show ps) ++ ">"
-- | The leaf of a type expression: either a concrete category instance or a
-- reference to a type param. Constructor order matters for the derived 'Ord'.
data TypeInstanceOrParam =
  JustTypeInstance {
    jtiType :: TypeInstance
  } |
  JustParamName {
    jpnName :: ParamName
  }
  deriving (Eq,Ord)
-- | Delegates to the wrapped instance or param name.
instance Show TypeInstanceOrParam where
  show leaf =
    case leaf of
      JustTypeInstance i -> show i
      JustParamName p    -> show p
-- | Direction of a type filter on a param. A 'FilterRequires' filter lets the
-- param convert to the filter type (see 'checkParamToInstance'); a
-- 'FilterAllows' filter lets the filter type convert to the param (see
-- 'checkInstanceToParam').
data FilterDirection =
  FilterRequires |
  FilterAllows
  deriving (Eq,Ord)

-- | A constraint on a type param: either a requires/allows bound against a
-- type, or a requirement that the param defines the given instance.
data TypeFilter =
  TypeFilter {
    tfDirection :: FilterDirection,
    tfType :: TypeInstanceOrParam
  } |
  DefinesFilter {
    dfType :: DefinesInstance
  }
  deriving (Eq,Ord)
-- | Renders a filter with its source keyword (@requires@, @allows@, or
-- @defines@) followed by the constrained type.
instance Show TypeFilter where
  show filt =
    case filt of
      TypeFilter FilterRequires t -> "requires " ++ show t
      TypeFilter FilterAllows t   -> "allows " ++ show t
      DefinesFilter t             -> "defines " ++ show t
-- | True for requires/allows filters, False for defines filters.
isTypeFilter :: TypeFilter -> Bool
isTypeFilter TypeFilter{}    = True
isTypeFilter DefinesFilter{} = False

-- | True only for requires filters.
isRequiresFilter :: TypeFilter -> Bool
isRequiresFilter (TypeFilter d _) = d == FilterRequires
isRequiresFilter DefinesFilter{}  = False

-- | True only for defines filters.
isDefinesFilter :: TypeFilter -> Bool
isDefinesFilter DefinesFilter{} = True
isDefinesFilter TypeFilter{}    = False
-- | Renders a filter attached to a param, e.g. @#x requires Foo@, for use in
-- error messages.
viewTypeFilter :: ParamName -> TypeFilter -> String
viewTypeFilter n f = unwords [show n, show f]
-- | Positional type arguments of an instance.
type InstanceParams = Positional GeneralInstance
-- | Positional variances of a category's params.
type InstanceVariances = Positional Variance
-- | Positional filter lists, one per param of a category.
type InstanceFilters = Positional [TypeFilter]
-- | Filters currently in scope, keyed by param name.
type ParamFilters = Map.Map ParamName [TypeFilter]
-- | Declared variance, keyed by param name.
type ParamVariances = Map.Map ParamName Variance
-- | Category-graph queries needed by the type checker. Every method runs in a
-- monad that can report compile errors.
class TypeResolver r where
  -- | Params of the given instance when viewed as the named category; used
  -- to convert between different category names before comparing params.
  trRefines :: (MergeableM m, CompileErrorM m) =>
    r -> TypeInstance -> CategoryName -> m InstanceParams
  -- | Params of the given instance's defines-relationship with the named
  -- category; used when checking @defines@ filters.
  trDefines :: (MergeableM m, CompileErrorM m) =>
    r -> TypeInstance -> CategoryName -> m InstanceParams
  -- | Variance of each param of the named category.
  trVariance :: (MergeableM m, CompileErrorM m) =>
    r -> CategoryName -> m InstanceVariances
  -- | Filters on each param of the instance's category, positionally.
  trTypeFilters :: (MergeableM m, CompileErrorM m) =>
    r -> TypeInstance -> m InstanceFilters
  -- | Filters on each param of the defines-instance's category, positionally.
  trDefinesFilters :: (MergeableM m, CompileErrorM m) =>
    r -> DefinesInstance -> m InstanceFilters
  -- | NOTE(review): not used in this module; presumably reports whether the
  -- category is concrete. Confirm against implementations.
  trConcrete :: (MergeableM m, CompileErrorM m) =>
    r -> CategoryName -> m Bool
-- | Existential wrapper that hides the concrete 'TypeResolver'; every method
-- forwards to the wrapped resolver.
data AnyTypeResolver = forall r. TypeResolver r => AnyTypeResolver r

instance TypeResolver AnyTypeResolver where
  trRefines (AnyTypeResolver inner) t n = trRefines inner t n
  trDefines (AnyTypeResolver inner) t n = trDefines inner t n
  trVariance (AnyTypeResolver inner) n = trVariance inner n
  trTypeFilters (AnyTypeResolver inner) t = trTypeFilters inner t
  trDefinesFilters (AnyTypeResolver inner) t = trDefinesFilters inner t
  trConcrete (AnyTypeResolver inner) n = trConcrete inner n
-- | Filters in scope for a param; a compile error if the param is unknown.
filterLookup :: (CompileErrorM m) =>
  ParamFilters -> ParamName -> m [TypeFilter]
filterLookup ps n =
  maybe (compileError $ "Param " ++ show n ++ " not found") return (Map.lookup n ps)
-- | The type assigned to a param; a compile error if the param is unknown.
getValueForParam :: (CompileErrorM m) =>
  Map.Map ParamName GeneralInstance -> ParamName -> m GeneralInstance
getValueForParam pa n = maybe missing return (Map.lookup n pa)
  where
    missing = compileError $ "Param " ++ show n ++ " does not exist"
-- | Checks that a value of the first type can be used as the second. Storage
-- may only weaken (per the 'StorageType' ordering); the types themselves are
-- compared covariantly.
checkValueTypeMatch :: (MergeableM m, CompileErrorM m, TypeResolver r) =>
  r -> ParamFilters -> ValueType -> ValueType -> m ()
checkValueTypeMatch r f from@(ValueType s1 t1) to@(ValueType s2 t2) =
  if s1 < s2
    then compileError $ "Cannot convert " ++ show from ++ " to " ++ show to
    else checkGeneralMatch r f Covariant t1 t2
-- | Checks convertibility between two type expressions under a variance:
-- covariant checks @ts1 -> ts2@, contravariant checks @ts2 -> ts1@, and
-- invariant checks both (forward first).
checkGeneralMatch :: (MergeableM m, CompileErrorM m, TypeResolver r) =>
  r -> ParamFilters -> Variance ->
  GeneralInstance -> GeneralInstance -> m ()
checkGeneralMatch r f v ts1 ts2 =
  case v of
    Covariant     -> check ts1 ts2
    Contravariant -> check ts2 ts1
    Invariant     -> check ts1 ts2 >> check ts2 ts1
  where
    check a b = checkGeneralType (checkSingleMatch r f Covariant) a b
-- | Dispatches a leaf-to-leaf conversion check on whether each side is a
-- concrete instance or a param.
checkSingleMatch :: (MergeableM m, CompileErrorM m, TypeResolver r) =>
  r -> ParamFilters -> Variance ->
  TypeInstanceOrParam -> TypeInstanceOrParam -> m ()
checkSingleMatch r f v lhs rhs =
  case (lhs, rhs) of
    (JustTypeInstance t1, JustTypeInstance t2) -> checkInstanceToInstance r f v t1 t2
    (JustParamName p1,    JustTypeInstance t2) -> checkParamToInstance r f v p1 t2
    (JustTypeInstance t1, JustParamName p2)    -> checkInstanceToParam r f v t1 p2
    (JustParamName p1,    JustParamName p2)    -> checkParamToParam r f v p1 p2
-- | Checks conversion between two concrete instances. With different category
-- names, the source is first re-expressed as the target category via
-- 'trRefines'. With the same name, the params are compared pairwise using the
-- category's declared variances.
checkInstanceToInstance :: (MergeableM m, CompileErrorM m, TypeResolver r) =>
  r -> ParamFilters -> Variance -> TypeInstance -> TypeInstance -> m ()
checkInstanceToInstance r f Invariant t1 t2
  | t1 /= t2 =
      mergeAllM [checkInstanceToInstance r f Covariant t1 t2,
                 checkInstanceToInstance r f Contravariant t1 t2]
  | otherwise = mergeDefaultM
checkInstanceToInstance r f Contravariant t1 t2 =
  checkInstanceToInstance r f Covariant t2 t1
checkInstanceToInstance r f Covariant t1@(TypeInstance n1 ps1) t2@(TypeInstance n2 ps2)
  | n1 /= n2 = do
      -- View t1 as category n2, then retry with matching names.
      refined <- trRefines r t1 n2
      checkInstanceToInstance r f Covariant (TypeInstance n2 refined) t2
  | otherwise = do
      pairs <- processPairs alwaysPair ps1 ps2
      variances <- trVariance r n1
      processPairs_ (\v -> uncurry (checkGeneralMatch r f v)) variances (Positional pairs)
      mergeDefaultM
-- | Checks conversion from a param to a concrete instance. Covariantly, this
-- succeeds if any @requires@ filter on the param converts to the instance.
checkParamToInstance :: (MergeableM m, CompileErrorM m, TypeResolver r) =>
  r -> ParamFilters -> Variance -> ParamName -> TypeInstance -> m ()
checkParamToInstance r f Invariant n1 t2 =
  mergeAllM [checkParamToInstance r f v n1 t2 | v <- [Covariant, Contravariant]]
checkParamToInstance r f Contravariant n1 t2 =
  checkInstanceToParam r f Covariant t2 n1
checkParamToInstance r f Covariant n1 t2@(TypeInstance _ _) = do
  bounds <- filter isTypeFilter <$> filterLookup f n1
  reviseError (mergeAnyM (map viaBound bounds))
              ("No filters imply " ++ show n1 ++ " -> " ++ show t2)
  where
    -- Only an upper bound (requires) on the param can imply the conversion.
    viaBound (TypeFilter FilterRequires upper) =
      checkSingleMatch r f Covariant upper (JustTypeInstance t2)
    viaBound other =
      compileError $ "Constraint " ++ viewTypeFilter n1 other ++
                     " does not imply " ++ show n1 ++ " -> " ++ show t2
-- | Checks conversion from a concrete instance to a param. Covariantly, this
-- succeeds if the instance converts to any @allows@ filter on the param.
checkInstanceToParam :: (MergeableM m, CompileErrorM m, TypeResolver r) =>
  r -> ParamFilters -> Variance -> TypeInstance -> ParamName -> m ()
checkInstanceToParam r f Invariant t1 n2 =
  mergeAllM [checkInstanceToParam r f v t1 n2 | v <- [Covariant, Contravariant]]
checkInstanceToParam r f Contravariant t1 n2 =
  checkParamToInstance r f Covariant n2 t1
checkInstanceToParam r f Covariant t1@(TypeInstance _ _) n2 = do
  bounds <- filter isTypeFilter <$> filterLookup f n2
  reviseError (mergeAnyM [viaBound b | b <- bounds])
              ("No filters imply " ++ show t1 ++ " -> " ++ show n2)
  where
    -- Only a lower bound (allows) on the param can imply the conversion.
    viaBound (TypeFilter FilterAllows lower) =
      checkSingleMatch r f Covariant (JustTypeInstance t1) lower
    viaBound other =
      compileError $ "Constraint " ++ viewTypeFilter n2 other ++
                     " does not imply " ++ show t1 ++ " -> " ++ show n2
-- | Checks conversion from param @n1@ to param @n2@ under a variance.
--
-- Covariantly, identical params trivially match. Otherwise it tries, in
-- order:
--
-- * every pairing of an @n1@ type filter with an @n2@ type filter;
-- * @n1@ itself (as a requires filter) paired with each @n2@ filter;
-- * each @n1@ filter paired with @n2@ itself (as an allows filter).
--
-- The check succeeds if any pairing does ('mergeAnyM').
checkParamToParam :: (MergeableM m, CompileErrorM m, TypeResolver r) =>
  r -> ParamFilters -> Variance -> ParamName -> ParamName -> m ()
checkParamToParam r f Invariant n1 n2
  | n1 == n2 = mergeDefaultM
  | otherwise =
    -- Invariance requires both directions to hold.
    mergeAllM [checkParamToParam r f Covariant n1 n2,
               checkParamToParam r f Contravariant n1 n2]
checkParamToParam r f Contravariant n1 n2 =
  -- Contravariance is covariance with the direction swapped.
  checkParamToParam r f Covariant n2 n1
checkParamToParam r f Covariant n1 n2
  | n1 == n2 = mergeDefaultM
  | otherwise = do
    cs1 <- fmap (filter isTypeFilter) $ f `filterLookup` n1
    cs2 <- fmap (filter isTypeFilter) $ f `filterLookup` n2
    -- Candidate (n1-side, n2-side) filter pairs; list order is the order in
    -- which alternatives are attempted.
    let typeFilters = [(c1,c2) | c1 <- cs1, c2 <- cs2] ++
                      [(self1,c2) | c2 <- cs2] ++
                      [(c1,self2) | c1 <- cs1]
    mergeAnyM (map (\(c1,c2) -> checkConstraintToConstraint c1 c2) typeFilters) `reviseError`
      ("No filters imply " ++ show n1 ++ " -> " ++ show n2)
  where
    -- Synthetic filters standing for the params themselves.
    self1 = TypeFilter FilterRequires (JustParamName n1)
    self2 = TypeFilter FilterAllows (JustParamName n2)
    -- Only an upper bound (requires) of n1 paired with a lower bound (allows)
    -- of n2 can establish n1 -> n2.
    checkConstraintToConstraint (TypeFilter FilterRequires t1) (TypeFilter FilterAllows t2)
      -- Matching n1 against n2 here would re-enter this same call through
      -- checkSingleMatch, so it is rejected instead of looping.
      | t1 == (JustParamName n1) && t2 == (JustParamName n2) =
        compileError $ "Infinite recursion in " ++ show n1 ++ " -> " ++ show n2
      | otherwise = checkSingleMatch r f Covariant t1 t2
    checkConstraintToConstraint f1 f2 =
      compileError $ "Constraints " ++ viewTypeFilter n1 f1 ++ " and " ++
                     viewTypeFilter n2 f2 ++ " do not imply " ++
                     show n1 ++ " -> " ++ show n2
-- | Validates a type expression against the params in scope.
--
-- A merge of exactly one type is rejected as ambiguous. Any other merge
-- (including the empty @all@ / @any@) validates each member. Instances are
-- validated via 'validateTypeInstance'. Params must exist in the filter map.
validateGeneralInstance :: (MergeableM m, CompileErrorM m, TypeResolver r) =>
  r -> ParamFilters -> GeneralInstance -> m ()
validateGeneralInstance _ _ (TypeMerge _ [_]) =
  compileError "Unions and intersections must have at least 2 types to avoid ambiguity"
-- Unions and intersections are validated identically, so one clause covers both.
validateGeneralInstance r f (TypeMerge _ ts) =
  mergeAllM (map (validateGeneralInstance r f) ts)
validateGeneralInstance r f (SingleType (JustTypeInstance t)) =
  validateTypeInstance r f t
validateGeneralInstance _ f (SingleType (JustParamName n)) =
  when (not $ n `Map.member` f) $
    compileError $ "Param " ++ show n ++ " does not exist"
-- | Validates an instance. Each type argument must satisfy the filters of the
-- corresponding category param, and each argument must itself be valid.
validateTypeInstance :: (MergeableM m, CompileErrorM m, TypeResolver r) =>
  r -> ParamFilters -> TypeInstance -> m ()
validateTypeInstance r f t@(TypeInstance _ ps) = do
  filters <- trTypeFilters r t
  processPairs_ (validateAssignment r f) ps filters
  reviseError (mergeAllM [validateGeneralInstance r f p | p <- pValues ps])
              ("Recursive error in " ++ show t)
-- | Validates a defines-instance. Each type argument must satisfy the filters
-- of the corresponding category param, and each argument must itself be
-- valid.
validateDefinesInstance :: (MergeableM m, CompileErrorM m, TypeResolver r) =>
  r -> ParamFilters -> DefinesInstance -> m ()
validateDefinesInstance r f t@(DefinesInstance _ ps) = do
  trDefinesFilters r t >>= processPairs_ (validateAssignment r f) ps
  let nested = map (validateGeneralInstance r f) (pValues ps)
  reviseError (mergeAllM nested) ("Recursive error in " ++ show t)
-- | Validates the type mentioned by a filter.
validateTypeFilter :: (MergeableM m, CompileErrorM m, TypeResolver r) =>
  r -> ParamFilters -> TypeFilter -> m ()
validateTypeFilter r f filt =
  case filt of
    TypeFilter _ t  -> validateGeneralInstance r f (SingleType t)
    DefinesFilter d -> validateDefinesInstance r f d
-- | Checks that a type expression satisfies every filter in the list:
--
-- * @requires@: the type converts to the filter type;
-- * @allows@: the filter type converts to the type;
-- * @defines@: a single instance must define the category, and a param must
--   have a matching @defines@ filter. Merged types cannot satisfy @defines@.
validateAssignment :: (MergeableM m, CompileErrorM m, TypeResolver r) =>
  r -> ParamFilters -> GeneralInstance -> [TypeFilter] -> m ()
validateAssignment r f t fs = mergeAllM [checkFilter filt | filt <- fs]
  where
    checkFilter (TypeFilter FilterRequires t2) =
      checkGeneralMatch r f Covariant t (SingleType t2)
    checkFilter (TypeFilter FilterAllows t2) =
      checkGeneralMatch r f Contravariant t (SingleType t2)
    checkFilter (DefinesFilter d) =
      case t of
        TypeMerge _ _ ->
          compileError $ "Merged type " ++ show t ++ " cannot satisfy defines constraint " ++ show d
        SingleType t1 -> checkDefines d t1
    checkDefines d@(DefinesInstance n2 _) (JustTypeInstance t1) = do
      viewed <- trDefines r t1 n2
      checkDefinesMatch r f d (DefinesInstance n2 viewed)
    checkDefines d (JustParamName n1) = do
      defines <- map dfType . filter isDefinesFilter <$> filterLookup f n1
      reviseError (mergeAnyM (map (checkDefinesMatch r f d) defines))
                  ("No filters imply " ++ show n1 ++ " defines " ++ show d)
-- | Checks that the available defines-instance @f1@ satisfies the required
-- @f2@. The category names must be equal, and params are compared pairwise
-- using the category's variances.
checkDefinesMatch :: (MergeableM m, CompileErrorM m, TypeResolver r) =>
  r -> ParamFilters -> DefinesInstance -> DefinesInstance -> m ()
checkDefinesMatch r f f2@(DefinesInstance n2 ps2) f1@(DefinesInstance n1 ps1)
  | n1 /= n2 = compileError $ "Constraint " ++ show f1 ++ " does not imply " ++ show f2
  | otherwise = do
      pairs <- processPairs alwaysPair ps1 ps2
      variances <- trVariance r n2
      processPairs_ (\v -> uncurry (checkGeneralMatch r f v)) variances (Positional pairs)
      mergeDefaultM
-- | Checks that every param occurring in a type expression is used at a
-- variance its declaration allows. The variance is composed with each
-- category param's variance when descending into type arguments.
validateInstanceVariance :: (MergeableM m, CompileErrorM m, TypeResolver r) =>
  r -> ParamVariances -> Variance -> GeneralInstance -> m ()
validateInstanceVariance r vm v (SingleType (JustTypeInstance (TypeInstance n ps))) = do
  vs <- trVariance r n
  paired <- processPairs alwaysPair vs ps
  mergeAllM (map (\(v2,p) -> validateInstanceVariance r vm (v `composeVariance` v2) p) paired)
-- Unions and intersections do not change variance, so both are handled alike.
validateInstanceVariance r vm v (TypeMerge _ ts) =
  mergeAllM (map (validateInstanceVariance r vm v) ts)
validateInstanceVariance _ vm v (SingleType (JustParamName n)) =
  case n `Map.lookup` vm of
    Nothing -> compileError $ "Param " ++ show n ++ " is undefined"
    Just v0 -> when (not $ v0 `paramAllowsVariance` v) $
                 compileError $ "Param " ++ show n ++ " cannot be " ++ show v
-- | Variance check for a defines-instance: each argument is validated at the
-- given variance composed with the category param's variance.
validateDefinesVariance :: (MergeableM m, CompileErrorM m, TypeResolver r) =>
  r -> ParamVariances -> Variance -> DefinesInstance -> m ()
validateDefinesVariance r vm v (DefinesInstance n ps) = do
  variances <- trVariance r n
  pairs <- processPairs alwaysPair variances ps
  mergeAllM (map checkArg pairs)
  where
    checkArg (v2, p) = validateInstanceVariance r vm (composeVariance v v2) p
-- | Substitutes params in a value's type, keeping its storage requirement.
-- No validation is performed on the result.
uncheckedSubValueType :: (MergeableM m, CompileErrorM m) =>
  (ParamName -> m GeneralInstance) -> ValueType -> m ValueType
uncheckedSubValueType replace (ValueType s t) =
  fmap (ValueType s) (uncheckedSubInstance replace t)
-- | Substitutes every param in a type expression using @replace@, recursing
-- through merges and instance arguments. No validation is performed.
uncheckedSubInstance :: (MergeableM m, CompileErrorM m) =>
  (ParamName -> m GeneralInstance) -> GeneralInstance -> m GeneralInstance
uncheckedSubInstance replace = go
  where
    go (TypeMerge m ts) = TypeMerge m <$> collectAllOrErrorM (map go ts)
    go (SingleType (JustTypeInstance (TypeInstance n (Positional ps)))) = do
      ps' <- collectAllOrErrorM (map go ps)
      return $ SingleType $ JustTypeInstance $ TypeInstance n (Positional ps')
    go (SingleType (JustParamName n)) = replace n
-- | Substitutes params inside a filter. No validation is performed.
--
-- A requires/allows filter holds a single leaf type, so when substitution
-- yields a union or intersection, the result cannot be represented. That case
-- is reported as a compile error; previously the partial selector @stType@
-- crashed at runtime here.
uncheckedSubFilter :: (MergeableM m, CompileErrorM m) =>
  (ParamName -> m GeneralInstance) -> TypeFilter -> m TypeFilter
uncheckedSubFilter replace filt@(TypeFilter d t) = do
  t' <- uncheckedSubInstance replace (SingleType t)
  case t' of
    SingleType t2 -> return (TypeFilter d t2)
    _ -> compileError $ "Cannot substitute merged type " ++ show t' ++
                        " into filter " ++ show filt
uncheckedSubFilter replace (DefinesFilter (DefinesInstance n ts)) = do
  ts' <- collectAllOrErrorM $ map (uncheckedSubInstance replace) (pValues ts)
  return (DefinesFilter (DefinesInstance n (Positional ts')))
-- | Substitutes params in every filter of every param in the map. Errors are
-- gathered via 'collectAllOrErrorM'; no validation is performed.
uncheckedSubFilters :: (MergeableM m, CompileErrorM m) =>
  (ParamName -> m GeneralInstance) -> ParamFilters -> m ParamFilters
uncheckedSubFilters replace fa =
  fmap Map.fromList $ collectAllOrErrorM [subParam entry | entry <- Map.toList fa]
  where
    subParam (n, fs) = (,) n <$> collectAllOrErrorM (map (uncheckedSubFilter replace) fs)