--
-- (c) Susumu Katayama
--
\begin{code}
module MagicHaskeller.Types(Type(..), Kind, TyCon, TyVar, TypeName, Typed(..), tyvars, Subst, plusSubst,
emptySubst, apply, mgu, varBind, match, maxVarID, normalizeVarIDs, normalize,
Decoder(..), typer, typee, negateTVIDs, limitType, saferQuantify, quantify, quantify', unquantify, lookupSubst, unifyFunAp,
alltyvars, mapTV, size, unitSubst, applyCheck, assertsubst, substOK, eqType, getRet, getArity, getArities, getAritiesRet, splitArgs, getArgs, pushArgs, popArgs, mguFunAp, revSplitArgs, revGetArgs, module Data.Int
) where
import Data.List
import Control.Monad
import Data.Char(ord)
#ifdef QUICKCHECK
import Test.QuickCheck
#endif
import Data.Int
-- Both arrow constructors are right-associative with the default
-- precedence 9, so @a :-> b :-> c@ means @a :-> (b :-> c)@.
infixr 9 :->, :>

-- | No-op debugging hook: the message is ignored and the second argument
-- is returned unchanged (presumably a stand-in for "Debug.Trace".trace).
trace :: msg -> a -> a
trace _ x = x

-- | Type representation.
--
-- * 'TV' is a type variable and 'TC' a type constructor, each identified
--   by a numeric ID.
-- * 'TA' is type application, function part first.
-- * ':->' is the function arrow; ':>' is a second arrow form, which the
--   'Show' instance prints in prefix form @(->) a b@.
-- * ':=>' attaches a context (left operand) to a type (right operand).
--
-- The constructor order below determines the derived 'Ord' instance.
data Type
  = TV !TyVar
  | TC !TyCon
  | TA Type Type
  | Type :>  Type
  | Type :-> Type
  | Type :=> Type
  deriving (Eq, Ord, Read)
-- | Number of leaves ('TV' and 'TC' nodes) in a type.
--
-- Every binary constructor contributes the sizes of both operands and is
-- not counted itself.  This includes the context arrow ':=>', so the
-- function is total over 'Type'.
size :: Type -> Int
size (TC _)      = 1
size (TV _)      = 1
size (TA t0 t1)  = size t0 + size t1
size (t0 :> t1)  = size t0 + size t1
size (t0 :-> t1) = size t0 + size t1
size (t0 :=> t1) = size t0 + size t1
#ifdef QUICKCHECK
-- Random types for property tests; only 'TV', 'TC', 'TA' and ':->' nodes
-- are ever generated.
instance Arbitrary Type where
  arbitrary = sized arbType

-- Size-bounded generator.  Size 0 always yields a leaf.  Otherwise a leaf
-- is picked with weight 8, and a 'TA' node or a ':->' node with weight 2
-- each; both children of a node are generated at half the size.
arbType :: Int -> Gen Type
arbType 0 = oneof [fmap TV arbitrary, fmap TC arbitrary]
arbType n = frequency [(8, arbType 0), (2, node TA), (2, node (:->))]
  where
    child     = arbType (n `div` 2)
    node make = liftM2 make child child
#endif
-- | Rename every type variable with the given function.
--
-- Type constructors are left untouched.  Both operands of every binary
-- constructor are traversed, including the context side of ':=>'.
mapTV :: (TyVar -> TyVar) -> Type -> Type
mapTV f = go
  where
    go ty = case ty of
      TV v    -> TV (f v)
      TC _    -> ty
      TA l r  -> TA (go l) (go r)
      l :=> r -> go l :=> go r
      l :-> r -> go l :-> go r
      l :>  r -> go l :> go r
-- | Map every type-variable ID @v@ to @-1 - v@.
--
-- This is the bitwise complement of @v@, so non-negative IDs become
-- negative and vice versa.  Applying it twice gives back the original type.
negateTVIDs :: Type -> Type
negateTVIDs = mapTV (\tvid -> -1 - tvid)
-- | Spend a budget of @n@ while walking a type.
--
-- Every leaf ('TC' or 'TV') costs one unit.  For a binary node, the
-- result/function part (@t@ in @u :-> t@ and @u :> t@, @t@ in @TA t u@)
-- is paid for first.  The other operand is only visited while budget
-- remains.
--
-- Returns the budget left after paying for all leaves.  It returns @-1@
-- as soon as the budget drops to zero or below before the second operand
-- of some node is reached.
--
-- NOTE(review): there is no case for ':=>', so context types cause a
-- pattern-match failure.
limitType :: (Num a, Ord a) => a -> Type -> a
limitType n (TC _)    = n - 1
limitType n (TV _)    = n - 1
limitType n (u :-> t) = lt n t u
limitType n (u :> t)  = lt n t u
limitType n (TA t u)  = lt n t u

-- Pay for the first operand, then for the second one if budget is left.
lt :: (Num a, Ord a) => a -> Type -> Type -> a
lt n t u = case limitType n t of
  m | m > 0     -> limitType m u
    | otherwise -> -1
-- | Type variables of a type, resolved through a substitution.
--
-- A variable bound in the substitution is replaced by the variables of its
-- binding, looked up again recursively.  Unbound variables are listed
-- themselves.
--
-- Duplicates are kept, and the context part of ':=>' is skipped.  For
-- 'TA' the function part is listed first; for ':->' and ':>' the result
-- part is listed before the argument part.
--
-- NOTE(review): a cyclic substitution makes this loop forever.
alltyvars :: Type -> Subst -> [TyVar]
alltyvars ty s = alltyvars' ty s []

-- Difference-list worker: prepends the variables of the type to the
-- accumulator it is eventually applied to.
alltyvars' :: Type -> Subst -> [TyVar] -> [TyVar]
alltyvars' ty s = case ty of
  TV tv   -> maybe (tv :) (\bound -> alltyvars' bound s) (lookupSubst s tv)
  TC _    -> id
  TA f x  -> alltyvars' f s . alltyvars' x s
  a :>  r -> alltyvars' r s . alltyvars' a s
  a :-> r -> alltyvars' r s . alltyvars' a s
  _ :=> r -> alltyvars' r s
-- | Type variables occurring in a type, with duplicates.
--
-- The context part of ':=>' is ignored.  For 'TA' the function part is
-- listed first; for ':->' and ':>' the result part is listed before the
-- argument part.
tyvars :: Type -> [TyVar]
tyvars ty = tyvars' ty []

-- Accumulating worker: prepends the variables of the type to @acc@.
tyvars' :: Type -> [TyVar] -> [TyVar]
tyvars' (TV tv)   acc = tv : acc
tyvars' (TC _)    acc = acc
tyvars' (TA f x)  acc = tyvars' f (tyvars' x acc)
tyvars' (a :> r)  acc = tyvars' r (tyvars' a acc)
tyvars' (a :-> r) acc = tyvars' r (tyvars' a acc)
tyvars' (_ :=> r) acc = tyvars' r acc
-- | Largest type-variable ID in a type, ignoring the context part of ':=>'.
--
-- A type without type variables yields @-1@.  Hence @maxVarID t + 1@ is
-- greater than every variable ID outside the contexts of @t@, and is 0
-- for ground types (barring 'TyVar' overflow).
maxVarID :: Type -> TyVar
maxVarID (TV tv)   = tv
maxVarID (TC _)    = -1
maxVarID (TA t u)  = maxVarID t `max` maxVarID u
maxVarID (t :> u)  = maxVarID t `max` maxVarID u
maxVarID (t :-> u) = maxVarID t `max` maxVarID u
maxVarID (_ :=> u) = maxVarID u
-- | Type-variable IDs.  They are 'Int8', so arithmetic on them (e.g. the
-- shift performed by 'mguFunAp') wraps around silently on overflow.
type TyVar = Int8

-- | Type-constructor IDs share the representation of type variables.
type TyCon = TyVar

-- | Kinds are represented as plain 'Int's.
type Kind = Int

type TypeName = String

-- | Decoder produced by 'normalizeVarIDs'.
--
-- The list holds the original variable IDs, with position @i@ containing
-- the ID that was renumbered to @i@.  The second field is the margin value
-- that 'normalizeVarIDs' computes.
data Decoder = Dec [TyVar] TyVar
  deriving Show
-- | Renumber the type variables of a type as 0, 1, 2, ...
--
-- Variables are numbered in order of first appearance in 'tyvars'.  The
-- renumbered type is returned together with a 'Decoder' listing the
-- original IDs.
--
-- The decoder also stores @mx + 1 - n@, where @n@ is the number of
-- distinct variables.  @mx@ is presumably the largest variable ID in use
-- (TODO confirm against callers).  The field is lazy, so 'normalize' can
-- pass an error value for @mx@.
--
-- NOTE(review): two limitations apply.
--
-- * 'tyvars' skips the context part of ':=>' while 'mapTV' does not, so a
--   variable occurring only in a context makes the lookup below fail.
-- * 'TyVar' is 'Int8', so @[0 ..]@ stops at 127 and at most 128 distinct
--   variables can be renumbered.
normalizeVarIDs :: Type -> TyVar -> (Type, Decoder)
normalizeVarIDs ty mx = (encoType, Dec decoList margin)
  where
    decoList = nub (tyvars ty)
    tup      = zip decoList [0 ..]
    encoType = mapTV (\tv -> case lookup tv tup of Just n -> n) ty
    len      = genericLength decoList
    margin   = mx + 1 - len
-- | Renumber type variables as 'normalizeVarIDs' does, dropping the
-- decoder.  The maximum-ID argument only feeds the discarded decoder, so an
-- error value is passed in its place.
normalize :: Type -> Type
normalize ty = case normalizeVarIDs ty (error "undef of normalize") of
  (renumbered, _) -> renumbered
-- | Equality of types up to a consistent renaming of type variables, by
-- comparing their 'normalize'd forms.
eqType :: Type -> Type -> Bool
eqType t0 t1 = canonical t0 == canonical t1
  where
    canonical = normalize
-- Conversions between type variables and negative ("quantified") type
-- constructors.
--
-- 'quantify'' maps the variable @v@ to the constructor @-v-1@.
-- 'unquantify' maps a negative constructor @c@ back to the variable
-- @-1-c@.  For a type without ':=>' whose variable and constructor IDs are
-- all non-negative, @unquantify (quantify' t) == t@.
saferQuantify, quantify, quantify', unquantify :: Type -> Type

-- Like 'quantify', but first turns negative constructors already present
-- into variables (keeping their IDs), so they are renumbered together
-- with the other variables.
saferQuantify = quantify . negUnquantify

-- Turn negative type constructors into type variables with the same ID.
negUnquantify :: Type -> Type
negUnquantify (TC i) | i < 0 = TV $ fromIntegral i
negUnquantify (TA t u)  = TA (negUnquantify t) (negUnquantify u)
negUnquantify (u :-> t) = negUnquantify u :-> negUnquantify t
negUnquantify (u :> t)  = negUnquantify u :> negUnquantify t
negUnquantify (_ :=> _) = error "negUnquantify: applied to types with contexts"
negUnquantify t         = t

-- Renumber the variables from 0 (see 'normalize'), then make them
-- negative constructors.
quantify ty = quantify' (normalize ty)

-- Replace each type variable @v@ with the type constructor @-v-1@.
quantify' (TV iD)   = TC $ fromIntegral (-iD - 1)
quantify' tc@(TC _) = tc
quantify' (TA t u)  = TA (quantify' t) (quantify' u)
quantify' (u :-> t) = quantify' u :-> quantify' t
quantify' (u :> t)  = quantify' u :> quantify' t
quantify' (_ :=> _) = error "quantify': applied to types with contexts"

-- Inverse of quantify': each negative constructor @c@ becomes the
-- variable @-1-c@.
unquantify (TC tc) | tc < 0 = TV $ fromIntegral (-1 - tc)
unquantify (TA t u)  = TA (unquantify t) (unquantify u)
unquantify (u :-> t) = unquantify u :-> unquantify t
unquantify (u :> t)  = unquantify u :> unquantify t
unquantify (_ :=> _) = error "unquantify: applied to types with contexts"
unquantify t         = t
-- | Result type of applying function type @t@ to argument type @u@ (see
-- 'uniFunAp'), or an 'error' prefixed with @str@ when they do not fit.
unifyFunAp :: String -> Type -> Type -> Type
unifyFunAp str t u = maybe failure traced (uniFunAp t u)
  where
    traced v = trace (str ++ ". unify "++show t ++" and "++show u) v
    failure  = error (str ++ ". unifyFunAp: t = "++show t++", and u = "++show u)
-- | Result type of applying a function type to an argument type.
--
-- Only the final result types (see 'getRet') of the expected argument and
-- of the actual argument are unified.  The function's result is returned
-- under that unifier.  ':=>' is treated like ':->'; any other type,
-- including ':>', fails with 'mzero'.
uniFunAp :: MonadPlus m => Type -> Type -> m Type
uniFunAp fun t = case fun of
    a :-> r -> applyTo a r
    a :=> r -> applyTo a r
    _       -> mzero
  where
    applyTo a r = mgu (getRet a) (getRet t) >>= \subst -> return (apply subst r)
-- | Like 'uniFunAp', but unifies the whole expected argument type with
-- the actual argument type.
--
-- The variables of the function type are first shifted by
-- @maxVarID t1 + 1@, so the two types do not share variables (assuming
-- non-negative IDs and no 'TyVar' overflow).
mguFunAp :: MonadPlus m => Type -> Type -> m Type
mguFunAp t0 t1 = trace ("mguFunAp t0 = "++ show t0++", and t1 = "++show t1) $
    mfa (mapTV (offset +) t0) t1
  where
    offset = maxVarID t1 + 1

-- Worker for 'mguFunAp'.  ':>' and ':=>' are treated like ':->'.  A bare
-- type variable in function position is returned unchanged; any other
-- type fails with 'mzero'.
mfa :: MonadPlus m => Type -> Type -> m Type
mfa fun t = case fun of
    a :-> r -> unifyArg a r
    a :>  r -> unifyArg a r
    a :=> r -> unifyArg a r
    TV _    -> return fun
    _       -> mzero
  where
    unifyArg a r = mgu a t >>= \subst ->
      let retv = apply subst r
      in trace ("retv = "++show retv) (return retv)
-- | Walk down the arrow spine of a type and hand the results to a
-- continuation.
--
-- The continuation receives four arguments:
--
-- 1. the number of ':=>' contexts passed;
-- 2. the number of ':->' / ':>' arguments passed;
-- 3. the given list with every argument and context type prepended, most
--    recent first;
-- 4. the remaining result type.
pushArgsCPS :: Integral i => (i -> i -> [Type] -> Type -> a) -> [Type] -> Type -> a
pushArgsCPS k = walk 0 0
  where
    walk nCxts nArgs acc ty = case ty of
      arg :-> rest -> walk nCxts (nArgs + 1) (arg : acc) rest
      arg :>  rest -> walk nCxts (nArgs + 1) (arg : acc) rest
      cxt :=> rest -> walk (nCxts + 1) nArgs (cxt : acc) rest
      _            -> k nCxts nArgs acc ty
-- | Push the arguments and contexts of a type onto the given list and
-- return it with the final result type; see 'pushArgsCPS'.
pushArgs :: [Type] -> Type -> ([Type],Type)
pushArgs = pushArgsCPS (\_ _ pushed ret -> (pushed, ret))

-- | Final result type, after all arguments and contexts.  The accumulator
-- is never inspected, hence 'undefined'.
getRet :: Type -> Type
getRet ty = snd (pushArgs undefined ty)

-- | Arguments and contexts of a type, last one first.
getArgs :: Type -> [Type]
getArgs ty = fst (pushArgs [] ty)
getNumCxts, getArity :: Integral i => Type -> i
-- Number of ':=>' contexts on the arrow spine.
getNumCxts = fst . getArities
-- Number of ':->' / ':>' arguments on the arrow spine.
getArity = snd . getArities

-- | (number of contexts, number of arguments) on the arrow spine.
getArities :: Integral i => Type -> (i,i)
getArities ty = case getAritiesRet ty of
  (nCxts, nArgs, _) -> (nCxts, nArgs)

-- | Context count, argument count and final result type in one pass.
getAritiesRet :: Integral i => Type -> (i,i,Type)
getAritiesRet = pushArgsCPS (\c i _ r -> (c,i,r)) undefined
-- | Arguments and contexts of a type (last one first) and its final result
-- type.
splitArgs :: Type -> ([Type],Type)
splitArgs ty = pushArgs [] ty
-- | Split a type along its ':->' spine into (argument count, arguments in
-- order from first to last, final result).  Unlike 'pushArgsCPS', ':>' and
-- ':=>' are not unfolded.
revSplitArgs :: Integral i => Type -> (i,[Type],Type)
revSplitArgs = go 0 []
  where
    go n acc (arg :-> rest) = go (n + 1) (arg : acc) rest
    go n acc ret            = (n, reverse acc, ret)
-- | Arguments along the ':->' spine, first one first.
revGetArgs :: Type -> [Type]
revGetArgs = middle . revSplitArgs
  where
    middle (_, args, _) = args
-- | Rebuild a function type from an argument list in the order produced by
-- 'pushArgs' (last argument first) and a result type.  For example,
-- @popArgs [b, a] r == a :-> b :-> r@.  Every argument is attached with
-- ':->'.
popArgs :: [Type] -> Type -> Type
popArgs []           ret = ret
popArgs (arg : rest) ret = popArgs rest (arg :-> ret)
\end{code}
The data type \texttt{Typed} (a value paired with its type), taken from obsolete/Binding.hs.
\begin{code}
-- | A value annotated with its type.
data Typed a = a ::: !Type
  deriving (Show, Eq, Ord)

-- | The type annotation.
typer :: Typed a -> Type
typer (_ ::: ty) = ty

-- | The annotated value.
typee :: Typed a -> a
typee (x ::: _) = x

-- Maps over the annotated value, keeping the type annotation.
instance Functor Typed where
  fmap f (x ::: ty) = f x ::: ty
\end{code}
\section{Type inference tools}
\begin{code}
-- | A substitution: an association list from type variables to types.
-- Lookups ('lookupSubst') use the first entry for a variable.
type Subst = [(TyVar,Type)]

-- Render an association list with one entry per line.  Each entry is a
-- space, the key, a tab, @|-> @ and the value.
showsAssoc :: (Show k, Show v) => [(k, v)] -> ShowS
showsAssoc = foldr entry id
  where
    entry (k, v) rest =
      showChar ' ' . shows k . showString "\t|-> " . shows v . showChar '\n' . rest

-- | The substitution that binds nothing.
emptySubst :: [a]
emptySubst = []

-- | The substitution that binds only @k@, to @e@.
unitSubst :: k -> v -> [(k, v)]
unitSubst k e = (k, e) : emptySubst
match, mgu :: MonadPlus m => Type -> Type -> m Subst
-- One-way matching of a pattern type (first argument) against a target
-- type.  Only variables of the pattern get bound, via 'varBind', which
-- still performs its occurs check.  ':->' and 'TA' nodes are matched
-- structurally.  Pattern nodes built with ':>' or ':=>' never match.
match pat target = case (pat, target) of
    (l :-> r, l' :-> r') -> match2Ap l r l' r'
    (TA l r,  TA l' r')  -> match2Ap l r l' r'
    (TV u,    t)         -> varBind u t
    (TC c0,   TC c1) | c0 == c1 -> return emptySubst
    _                    -> mzero

-- Match both operand pairs left to right.  The substitution found for the
-- left pair is applied to the pattern side of the right pair.
match2Ap :: MonadPlus m => Type -> Type -> Type -> Type -> m Subst
match2Ap l r l' r' =
  match l l' >>= \s1 ->
  match (apply s1 r) r' >>= \s2 ->
  return (s2 `plusSubst` s1)
-- Most general unifier of two types; fails with 'mzero'.
--
-- ':->' and ':>' unify with each other, and ':>' with ':>', only when
-- compiled with REALDYNAMIC.  Otherwise a ':>' node only unifies with a
-- variable.  ':=>' never unifies (see 'varBind').
mgu ty0 ty1 = case (ty0, ty1) of
    (l :-> r, l' :-> r') -> mgu2Ap l r l' r'
#ifdef REALDYNAMIC
    (l :-> r, l' :> r')  -> mgu2Ap l r l' r'
    (l :> r,  l' :-> r') -> mgu2Ap l r l' r'
    (l :> r,  l' :> r')  -> mgu2Ap l r l' r'
#endif
    (TA l r,  TA l' r')  -> mgu2Ap l r l' r'
    (TV u,    t)         -> varBind u t
    (t,       TV u)      -> varBind u t
    (TC c0,   TC c1) | c0 == c1 -> return emptySubst
    _                    -> mzero

-- Unify both operand pairs left to right.  The left unifier is applied to
-- both right operands before they are unified.
mgu2Ap :: MonadPlus m => Type -> Type -> Type -> Type -> m Subst
mgu2Ap l r l' r' = do
  leftSubst  <- mgu l l'
  rightSubst <- mgu (apply leftSubst r) (apply leftSubst r')
  return (rightSubst `plusSubst` leftSubst)
-- | Bind a type variable to a type as a singleton substitution.
--
-- Fails ('mzero') when the type is a ':=>' type, or when the variable
-- occurs in it (occurs check via 'tyvars').  Binding a variable to itself
-- gives the empty substitution.
varBind :: MonadPlus m => TyVar -> Type -> m Subst
varBind u t = case t of
  _ :=> _ -> mzero
  TV v | v == u -> return emptySubst
  _ | u `elem` tyvars t -> mzero
    | otherwise         -> return (unitSubst u t)
-- | True when no variable occurs in the type it is bound to.
substOK :: Subst -> Bool
substOK s = and [ i `notElem` tyvars ty | (i, ty) <- s ]
-- | Return the substitution unchanged if it passes 'substOK'.  Otherwise
-- raise an 'error' prefixed with the given label.
assertsubst :: String -> Subst -> Subst
assertsubst str s
  | substOK s = s
  | otherwise = error (str ++ ": assertsubst failed. substitution = " ++ show s)
-- Rendering rules:
--
-- * Every binary node is fully parenthesised.
-- * Variables print as @a<id>@.
-- * Constructors print as @K<k>I<id>@, where @k@ is the number of 'TA'
--   arguments the constructor is applied to.
-- * ':>' is printed in prefix form @(->) t u@.
--
-- The precedence argument is ignored.
instance Show Type where
  showsPrec _ = render 0
    where
      render k ty = case ty of
        TV i    -> showChar 'a' . shows i
        TC i    -> showChar 'K' . shows k . showChar 'I' . shows i
        TA f x  -> showParen True (render (k + 1) f . showChar ' ' . render 0 x)
        a :=> r -> showParen True (render 0 a . showString "=>" . render 0 r)
        a :-> r -> showParen True (render 0 a . showString "->" . render 0 r)
        a :>  r -> showParen True (showString "(->) " . render 0 a
                                     . showChar ' ' . render 0 r)
-- | Compose substitutions: applying @s0 `plusSubst` s1@ acts like applying
-- @s1@ and then @s0@.  The bindings of @s1@ are rewritten by @s0@ and come
-- first, so they shadow @s0@'s bindings for the same variable.
plusSubst :: Subst -> Subst -> Subst
s0 `plusSubst` s1 = map (\(u, t) -> (u, applyCheck s0 t)) s1 ++ s0
-- | Binding of a variable (first matching entry), or 'mzero' if unbound.
lookupSubst :: MonadPlus m => Subst -> TyVar -> m Type
lookupSubst subst i = maybe mzero return (lookup i subst)
-- | Apply a substitution once.  Each bound variable is replaced by its
-- binding, and the replacement is not substituted any further.  All
-- constructors, including both sides of ':=>', are traversed.
apply :: Subst -> Type -> Type
apply subst = go
  where
    go ty = case ty of
      TV tv   -> maybe ty id (lookupSubst subst tv)
      TC _    -> ty
      TA f x  -> TA (go f) (go x)
      a :-> r -> go a :-> go r
      a :>  r -> go a :> go r
      a :=> r -> go a :=> go r
-- | Identical to 'apply'.  It is kept as a separate name, presumably as a
-- hook for adding a consistency check.
applyCheck :: Subst -> Type -> Type
applyCheck = apply
\end{code}