module FrontEnd.TypeSynonyms (
removeSynonymsFromType,
declsToTypeSynonyms,
TypeSynonyms,
restrictTypeSynonyms,
showSynonyms,
showSynonym
) where
import Control.Monad.Writer
import Data.Binary
import Data.List
import qualified Data.Map as Map
import qualified Data.Set as Set
import Doc.DocLike
import FrontEnd.HsSyn
import FrontEnd.SrcLoc
import FrontEnd.Syn.Traverse
import FrontEnd.Warning
import GenUtil
import Name.Name
import Support.FreeVars
import Support.MapBinaryInstance
import Util.HasSize
import Util.UniqueMonad
import qualified Util.Graph as G
-- | A table of type synonyms keyed by the synonym's type-constructor 'Name'.
-- Each entry carries the synonym's parameter names, its body, and the source
-- location of the declaration (used when reporting warnings).
-- 'Monoid' and 'HasSize' are inherited from the underlying 'Map.Map'.
newtype TypeSynonyms = TypeSynonyms (Map.Map Name ([HsName], HsType, SrcLoc))
    deriving (HasSize, Monoid)
-- | Serialise a synonym table through the generic map helpers
-- 'putMap' / 'getMap'.
instance Binary TypeSynonyms where
    put (TypeSynonyms table) = putMap table
    get = do
        table <- getMap
        return (TypeSynonyms table)
-- | Keep only the synonyms whose name satisfies the given predicate.
restrictTypeSynonyms :: (Name -> Bool) -> TypeSynonyms -> TypeSynonyms
restrictTypeSynonyms keep (TypeSynonyms table) =
    TypeSynonyms $ Map.filterWithKey (const . keep) table
-- | Render a single synonym as @Name arg1 .. argN = body@, using @pprint@ for
-- the body. Calls the monad's 'fail' with @"key not found"@ when the name is
-- not in the table.
showSynonym :: (DocLike d,Monad m) => (HsType -> d) -> Name -> TypeSynonyms -> m d
showSynonym pprint n (TypeSynonyms m) =
    maybe (fail "key not found") render (Map.lookup n m)
  where
    render (params, body, _) =
        return (hsep (tshow n : map tshow params) <+> text "=" <+> pprint body)
-- | Render every synonym in the table, one per line, as
-- @Name arg1 .. argN = body@ (in the map's key order).
showSynonyms :: DocLike d => (HsType -> d) -> TypeSynonyms -> d
showSynonyms pprint (TypeSynonyms m) =
    vcat [ renderLine name params body | (name, (params, body, _)) <- Map.toList m ]
  where
    renderLine name params body =
        hsep (tshow name : map tshow params) <+> text "=" <+> pprint body
-- | Build a synonym table from the @type@ declarations in @ds@.
--
-- Each declaration becomes a graph node whose edges are the 'Name's that
-- 'freeVars' reports for its (quantified) body. The graph is split into
-- strongly connected components by 'G.scc' (presumably dependencies first;
-- confirm against "Util.Graph").
--
-- * A non-recursive synonym has its body expanded with 'removeSynonymsFromType',
--   using @tsin@ plus every synonym already processed. It is then added to
--   the environment used for later synonyms.
-- * A recursive group produces a 'TypeSynonymRecursive' warning at the
--   location of its first member and is dropped.
--
-- The result holds only the synonyms declared in @ds@. @tsin@ is used for
-- expansion but is not merged into the returned table.
declsToTypeSynonyms :: MonadWarn m => TypeSynonyms -> [HsDecl] -> m TypeSynonyms
declsToTypeSynonyms tsin ds = go tsin sccs []
  where
    sccs = G.scc $ G.newGraph nodes fst (Set.toList . freeVars . synBody)

    -- Parameters of a type declaration are expected to be plain type
    -- variables. The lazy pattern defers any mismatch until the name is used.
    nodes =
        [ (toName TypeConstructor name, (args, quantifyHsType args (HsQualType [] t), sl))
        | HsTypeDecl sl name args' t <- ds
        , let args = [ n | ~(HsTyVar n) <- args' ]
        ]

    synBody (_, (_, t, _)) = t

    -- env: synonyms usable for expansion; acc: synonyms produced so far.
    go env (Right grp@((_, (_, _, sl)) : _) : rest) acc = do
        warn sl TypeSynonymRecursive ("Recursive type synonyms:" <+> show (fsts grp))
        go env rest acc
    -- An SCC is never empty; this clause only keeps the match total.
    go env (Right [] : rest) acc = go env rest acc
    go env (Left (n, (as, body, sl)) : rest) acc = do
        body' <- removeSynonymsFromType env body
        let entry = (as, body', sl)
        go (tsInsert n entry env) rest ((n, entry) : acc)
    go _ [] acc = return $ TypeSynonyms (Map.fromList acc)

    tsInsert x y (TypeSynonyms xs) = TypeSynonyms (Map.insert x y xs)
-- | Expand every synonym from the given table that occurs in a type.
-- Problems such as partial applications are reported through 'MonadWarn'.
removeSynonymsFromType :: MonadWarn m => TypeSynonyms -> HsType -> m HsType
removeSynonymsFromType = evalTypeSyms
-- | Wrap a qualified type in an explicit @forall@ over the free type variables
-- of its body that are not listed in @inscope@.
--
-- When there is nothing to quantify and the context is empty, the bare body
-- is returned instead. Only the body is scanned for variables, not the
-- context; nested @forall@/@exists@ binders are excluded from the result.
quantifyHsType :: [HsName] -> HsQualType -> HsType
quantifyHsType inscope qt
    | null binders && null (hsQualTypeContext qt) = hsQualTypeType qt
    | otherwise = HsTyForall binders qt
  where
    binders =
        [ hsTyVarBind { hsTyVarBindName = v }
        | v <- snub (freeTyVars (hsQualTypeType qt)) \\ inscope
        ]

    -- Type variables occurring free in a type (may contain repeats).
    freeTyVars ty = execWriter (collect ty)

    collect (HsTyVar v) = tell [v]
    collect (HsTyForall bound inner) = tell (freeUnder bound inner)
    collect (HsTyExists bound inner) = tell (freeUnder bound inner)
    collect other = traverseHsType (\sub -> collect sub >> return sub) other >> return ()

    -- Free variables of a quantified body, minus the variables it binds.
    freeUnder bound inner =
        snub (freeTyVars (hsQualTypeType inner)) \\ map hsTyVarBindName bound
-- | Expand every type-synonym application in a type.
--
-- A synonym is expanded only when it is applied to at least as many arguments
-- as it has parameters. Surplus arguments are applied to the expansion, which
-- is then evaluated again.
--
-- An under-applied synonym produces a 'TypeSynonymPartialAp' warning and is
-- left unexpanded, although its arguments are still expanded. A synonym whose
-- stored body is 'HsTyAssoc' is never expanded.
--
-- When a synonym body is instantiated, every @forall@/@exists@ binder inside it
-- is renamed with a fresh unique prefix. This keeps the free variables of the
-- substituted arguments from being captured.
evalTypeSyms :: MonadWarn m => TypeSynonyms -> HsType -> m HsType
evalTypeSyms (TypeSynonyms tmap) t = execUniqT 1 (eval [] t)
  where
    -- @eval stack ty@ evaluates @ty@ applied to the still-unevaluated
    -- arguments in @stack@ (leftmost argument first).
    eval stack x@(HsTyCon n)
        | Just (args, body, sl) <- Map.lookup (toName TypeConstructor n) tmap = do
            let excess = length stack - length args
            if excess < 0
                then do
                    -- excess is negative here; report how many are missing.
                    lift $ warn sl TypeSynonymPartialAp ("Partially applied typesym:" <+> show n <+> "need" <+> show (negate excess) <+> "more arguments.")
                    unwind x stack
                else case body of
                    HsTyAssoc -> unwind x stack
                    _ -> do
                        st <- subst (Map.fromList (zip args stack)) body
                        eval (drop (length args) stack) st
    eval stack (HsTyApp t1 t2) = eval (t2 : stack) t1
    eval stack x = do
        x' <- traverseHsType (eval []) x
        unwind x' stack

    -- Re-apply the pending arguments to an evaluated head, evaluating each one.
    unwind ty [] = return ty
    unwind ty (arg : rest) = do
        arg' <- eval [] arg
        unwind (HsTyApp ty arg') rest

    -- Substitute types for type variables. Quantifier binders are freshened first.
    subst sm (HsTyForall vs qt) = do
        (vs', qt') <- freshenBinders (\pre -> hsNameIdent_u (pre ++)) sm vs qt
        return $ HsTyForall vs' qt'
    subst sm (HsTyExists vs qt) = do
        (vs', qt') <- freshenBinders (\pre -> hsNameIdent_u (hsIdentString_u (pre ++))) sm vs qt
        return $ HsTyExists vs' qt'
    subst (sm :: Map.Map HsName HsType) (HsTyVar n) | Just v <- Map.lookup n sm = return v
    subst sm ty = traverseHsType (subst sm) ty

    -- Give each binder a name prefixed by a fresh unique, map every old binder
    -- name to its renamed variable (these entries win over @sm@), and
    -- substitute inside the quantified type.
    freshenBinders rename sm vs qt = do
        us <- mapM (const newUniq) vs
        let renamed =
                [ (hsTyVarBindName v, v { hsTyVarBindName = rename (show u ++ "00") (hsTyVarBindName v) })
                | (u, v) <- zip us vs
                ]
            sm' = Map.fromList [ (old, HsTyVar (hsTyVarBindName v')) | (old, v') <- renamed ] `Map.union` sm
        qt' <- substqt sm' qt
        return (snds renamed, qt')

    -- Substitute in a qualified type. Class-assertion arguments are plain names,
    -- so they only follow variable-to-variable renamings. A variable mapped to
    -- any other type is left unchanged there.
    substqt sm qt@HsQualType { hsQualTypeContext = ps, hsQualTypeType = ty } = do
        ty' <- subst sm ty
        let substAsst (HsAsst c xs) = return (HsAsst c (map renameVar xs))
            substAsst (HsAsstEq a b) = do
                a' <- subst sm a
                b' <- subst sm b
                return (HsAsstEq a' b')
            renameVar n = case Map.lookup n sm of
                Just (HsTyVar n') -> n'
                _ -> n
        ps' <- mapM substAsst ps
        return qt { hsQualTypeType = ty', hsQualTypeContext = ps' }