module Language.Haskell.TH.KindInference (inferKind) where
import Control.Monad.Trans
import Data.Ord
import Debug.Trace
import Data.Map((!))
import Data.Set
import Control.Monad.State.Strict
import Text.ParserCombinators.ReadP hiding (get)
import Language.Haskell.TH hiding (AppE)
import Language.Haskell.TH.Unification
import Language.Haskell.TH.PprLib hiding (empty, char)
import qualified Language.Haskell.TH.PprLib as Ppr
-- | Unification term for kinds. In this module such terms are built with
-- 'App' 'KindArrow' (see '->-'), 'Atom' 'Star' (see 'star') and 'Var' applied
-- to a Template Haskell 'Type' (see 'tyCon' and 'tyVar').
type KindUTerm = Term KindFunc Type KindAtom
-- | The 'UnifT' transformer instantiated with the same three parameters as
-- 'KindUTerm'; the inference functions in this module run it over 'Q'.
type KindUT = UnifT KindFunc Type KindAtom
-- | Orphan ordering on Template Haskell 'Type's: two types are compared by
-- comparing the strings produced by 'show'.
instance Ord Type where
  compare lhs rhs = compare (show lhs) (show rhs)
-- | State layer carrying a set of 'Name's. 'examineDec' records each
-- declaration name here before processing it and skips names already present.
type LoopKillerT = StateT (Set Name)
-- | Function symbols of kind terms; the only one is the kind arrow.
data KindFunc = KindArrow deriving (Eq, Show)
-- | Atomic kinds; the only one is 'Star', which 'termToK' maps to 'StarK'.
data KindAtom = Star deriving (Eq, Show)
-- | Infer the kind of the type constructor called @name@.
--
-- Runs 'infer' on @'ConT' name@ with an empty visited set, solves the
-- resulting unification problem with 'defaultKind', and converts the solved
-- term to a 'Kind'. A 'Left' from the solver is passed through unchanged.
inferKind :: Name -> Q (Either String Kind)
inferKind name = do
  outcome <- solveUnification defaultKind (evalStateT (infer (ConT name)) empty)
  case outcome of
    Left failure -> return (Left failure)
    Right (root, solution) ->
      return (Right (termToK (subTerm defaultKind solution root)))
-- | The explicit kind @*@. 'inferKind' hands it to both 'solveUnification'
-- and 'subTerm' (presumably as the kind for variables left unconstrained --
-- TODO confirm against the Unification module).
defaultKind :: Explicit KindFunc KindAtom
defaultKind = AtomE Star
-- | Convert a fully explicit kind term into a Template Haskell 'Kind':
-- applications become 'ArrowK' and atoms become 'StarK'.
--
-- The function symbol and the atom are matched lazily because each type has
-- a single constructor.
termToK :: Explicit KindFunc KindAtom -> Kind
termToK term = case term of
  AppE ~KindArrow dom cod -> ArrowK (termToK dom) (termToK cod)
  AtomE ~Star -> StarK
-- | Produce a kind term for a type, emitting unification constraints along
-- the way.
--
-- * Tuple, arrow and list constructors get fixed kinds built by 'tupleKind'.
-- * An application constrains the function's kind to be
--   @argument kind ->- result@, where the result is a variable keyed by the
--   application itself.
-- * A kind signature constrains the annotated type's kind.
-- * A @forall@ processes its binders and context, then its body.
-- * A type variable is its own unification variable.
-- * A constructor is examined via 'examine' and represented by 'tyCon'.
infer :: Type -> LoopKillerT (KindUT Q) KindUTerm
infer ty = case ty of
  TupleT n -> return (tupleKind n star)
  ArrowT -> return (tupleKind 2 star)
  ListT -> return (tupleKind 1 star)
  AppT f x -> do
    funKind <- infer f
    argKind <- infer x
    -- The application node itself names the result kind variable.
    let resultKind = Var ty
    unify funKind (argKind ->- resultKind)
    return resultKind
  SigT t k -> do
    annotated <- infer t
    unify annotated (kToTerm k)
    return annotated
  ForallT bdrs cxt t -> do
    mapM_ handleBdr bdrs
    mapM_ handleCxt cxt
    infer t
  VarT{} -> return (Var ty)
  ConT t -> do
    examine (Just t) t
    return (tyCon t)
-- | Parse the textual form of an unboxed tuple constructor name: the opening
-- @(#@, any number of commas, then the closing @#)@. The result is one more
-- than the number of commas, so @(#,#)@ yields 2 and @(##)@ yields 1.
matchUnboxedTuple :: ReadP Int
matchUnboxedTuple = do
  _ <- string "(#"
  commas <- munch (== ',')
  _ <- string "#)"
  return (1 + length commas)
-- | Gather kind constraints for the entity called @name@.
--
-- First ties @name0@ (when present) to @'tyVar' name@. If the base name
-- parses completely as an unboxed tuple constructor, its arity fixes the
-- kind directly. Otherwise the name is reified: class and type constructor
-- declarations go to 'examineDec', primitive constructors get a kind from
-- their arity, type variables get the kind inferred for their type, and any
-- other kind of entity adds no constraints.
examine :: Maybe Name -> Name -> LoopKillerT (KindUT Q) ()
examine name0 name = do
  mUnify name0 (tyVar name)
  -- Only parses that consume the whole base name count.
  case [arity | (arity, "") <- readP_to_S matchUnboxedTuple (nameBase name)] of
    arity : _ -> unify (tyVar name) (tupleKind arity star)
    [] -> examineReified
  where
    examineReified :: LoopKillerT (KindUT Q) ()
    examineReified = do
      info <- lift (lift (reify name))
      case info of
        ClassI dec -> examineDec name0 dec
        TyConI dec -> examineDec name0 dec
        PrimTyConI primName arity _ ->
          unify (tyVar primName) (tupleKind arity star)
        TyVarI varName typ -> infer typ >>= unify (tyVar varName)
        _ -> return ()
-- | Unify the constructor variable of an optional name with the given kind
-- term; does nothing when no name is supplied.
mUnify :: Maybe Name -> KindUTerm -> LoopKillerT (KindUT Q) ()
mUnify Nothing _ = return ()
mUnify (Just conName) k = unify (tyCon conName) k
-- | Emit kind constraints for a reified declaration.
--
-- Each supported declaration is processed at most once: its name is added to
-- the visited set before its body is walked, and a name already in the set is
-- skipped. For data, newtype and class declarations the constructor's kind is
-- the binders' kinds followed by @*@; for a type synonym the result is the
-- kind inferred for the right-hand side. In every case @name0@ (when present)
-- is unified with the declared constructor. Other declarations are ignored.
examineDec :: Maybe Name -> Dec -> LoopKillerT (KindUT Q) ()
examineDec name0 decl = case decl of
  DataD cxt name bdrs cons _ -> once name $ do
    mapM_ handleCxt cxt
    args <- mapM handleBdr bdrs
    unify (tyCon name) (foldr (->-) star args)
    mUnify name0 (tyCon name)
    mapM_ handleCon cons
  NewtypeD cxt name bdrs con _ -> once name $ do
    mapM_ handleCxt cxt
    args <- mapM handleBdr bdrs
    unify (tyCon name) (foldr (->-) star args)
    mUnify name0 (tyCon name)
    handleCon con
  ClassD cxt name bdrs _ _ -> once name $ do
    mapM_ handleCxt cxt
    args <- mapM handleBdr bdrs
    unify (tyCon name) (foldr (->-) star args)
    mUnify name0 (tyCon name)
  FamilyD _ name bdrs mK -> once name $ do
    -- NOTE(review): unlike the other cases, the binder kinds are not folded
    -- into the family's kind; the annotation (or @*@) is used as the whole
    -- kind. Confirm whether reify reports the full kind or only the result
    -- kind here for the targeted template-haskell version.
    mapM_ handleBdr bdrs
    unify (tyCon name) (maybe star kToTerm mK)
    mUnify name0 (tyCon name)
  TySynD name bdrs typ -> once name $ do
    args <- mapM handleBdr bdrs
    rhsKind <- infer typ
    unify (tyCon name) (foldr (->-) rhsKind args)
    mUnify name0 (tyCon name)
  _ -> return ()
  where
    -- Run the body only the first time a name is seen, marking it first.
    once :: Name -> LoopKillerT (KindUT Q) () -> LoopKillerT (KindUT Q) ()
    once name body = do
      visited <- get
      unless (name `member` visited) $ do
        modify (insert name)
        body
-- | Constrain every field type of a data constructor to have kind @*@.
-- A quantified constructor first processes its binders and context, then the
-- constructor underneath.
handleCon :: Con -> LoopKillerT (KindUT Q) ()
handleCon con = case con of
  NormalC _ fields -> mapM_ fieldIsStar [t | (_, t) <- fields]
  RecC _ fields -> mapM_ fieldIsStar [t | (_, _, t) <- fields]
  InfixC (_, lhs) _ (_, rhs) -> fieldIsStar lhs >> fieldIsStar rhs
  ForallC bdrs cxt inner -> do
    mapM_ handleBdr bdrs
    mapM_ handleCxt cxt
    handleCon inner
  where
    -- Infer the field's kind and unify it with @*@.
    fieldIsStar :: Type -> LoopKillerT (KindUT Q) ()
    fieldIsStar t = infer t >>= unify star
-- | Unification variable standing for the kind of the type constructor with
-- the given name (keyed by @'ConT' name@).
tyCon :: Name -> KindUTerm
tyCon name = Var (ConT name)
-- | Unification variable standing for the kind of the type variable with the
-- given name (keyed by @'VarT' name@).
tyVar :: Name -> KindUTerm
tyVar name = Var (VarT name)
-- | Return the kind variable of a type-variable binder. A kind-annotated
-- binder additionally unifies that variable with the annotated kind.
handleBdr :: TyVarBndr -> LoopKillerT (KindUT Q) KindUTerm
handleBdr bdr = case bdr of
  PlainTV n -> return (tyVar n)
  KindedTV n k -> do
    let var = tyVar n
    unify var (kToTerm k)
    return var
-- | Emit constraints for one context predicate.
--
-- A class predicate constrains the class constructor to take the inferred
-- argument kinds and end in @*@, then examines the class itself. An equality
-- predicate unifies the kinds of its two sides.
handleCxt :: Pred -> LoopKillerT (KindUT Q) ()
handleCxt predicate = case predicate of
  ClassP name args -> do
    argKinds <- mapM infer args
    unify (tyCon name) (foldr (->-) star argKinds)
    examine (Just name) name
  EqualP lhs rhs -> do
    lhsKind <- infer lhs
    rhsKind <- infer rhs
    unify lhsKind rhsKind
-- | Translate a Template Haskell 'Kind' into a unification term: 'ArrowK'
-- becomes an arrow application and 'StarK' becomes 'star'.
kToTerm :: Kind -> KindUTerm
kToTerm kind = case kind of
  ArrowK dom cod -> kToTerm dom ->- kToTerm cod
  StarK -> star

-- | Build the arrow kind term from an argument kind and a result kind.
(->-) :: KindUTerm -> KindUTerm -> KindUTerm
dom ->- cod = App KindArrow dom cod
-- | The kind term for @*@.
star :: KindUTerm
star = Atom Star
-- | @tupleKind n k@ is the kind term taking @n@ arguments of kind @*@ and
-- ending in @k@, i.e. @* ->- (* ->- ... k)@. For @n <= 0@ it is just @k@.
tupleKind :: Int -> KindUTerm -> KindUTerm
tupleKind n k
  | n <= 0 = k
  | otherwise = star ->- tupleKind (n - 1) k
-- | Pretty-print an 'Either' as its constructor name followed by the
-- parenthesised payload.
instance (Ppr a, Ppr b) => Ppr (Either a b) where
  ppr =
    either
      (\l -> text "Left" <+> parens (ppr l))
      (\r -> text "Right" <+> parens (ppr r))
-- | Pretty-print a single character using the pretty-printing library's
-- 'Ppr.char'.
instance Ppr Char where
  ppr c = Ppr.char c