{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE NamedFieldPuns #-}
{-# LANGUAGE TemplateHaskell #-}
{-# LANGUAGE TupleSections #-}
{-# LANGUAGE ViewPatterns #-}
{-# LANGUAGE MultiParamTypeClasses #-}
module RSolve.HM
where
import RSolve.Logic
import RSolve.Solver
import RSolve.MultiState
import Control.Lens (Lens', view, over, makeLenses)
import Control.Applicative
import Control.Monad
import Debug.Trace
import qualified Data.List as L
import qualified Data.Map as M
import qualified Data.Set as S
-- | An endomorphism on its argument type. 'unify' is written against it for
-- open recursion: it takes the function to call on sub-problems and returns
-- the completed step, which the solver instance then ties into a knot.
type Fix self = self -> self
-- Both type operators share precedence 6 and associate to the LEFT, so
-- @a :-> b :-> c@ parses as @(a :-> b) :-> c@. Parenthesise the right
-- operand explicitly when building curried arrows.
infixl 6 :->, :*

-- | Types manipulated by the checker.
--
-- The constructor order is significant: 'Ord' is derived, and types are
-- stored in sets (the disequality set in 'TCEnv').
data T
  = TVar Int
    -- ^ Unification variable; the index is a key of the type-variable
    -- table in 'TCEnv'.
  | TFresh String
    -- ^ Named placeholder, bound by 'TForall' and substituted by 'free'.
  | T :-> T
    -- ^ Arrow, rendered as @a -> b@.
  | T :* T
    -- ^ Product, rendered as @a * b@.
  | TForall (S.Set String) T
    -- ^ Quantifies the given 'TFresh' names over the body.
  | TApp T T
    -- ^ Type application, rendered by juxtaposition.
  | TNom Int
    -- ^ Nominal type identified by its index; two nominal types unify
    -- only when their indices are equal.
  deriving (Eq, Ord)
-- | Split a binary type constructor (arrow, product or application) into a
-- function that rebuilds it plus its two operands; 'Nothing' for every other
-- constructor. Lets traversals such as 'free', 'prune' and 'occurIn' treat
-- the three binary constructors uniformly.
deConsTOp :: T -> Maybe (T -> T -> T, T, T)
deConsTOp (l :-> r) = Just ((:->), l, r)
deConsTOp (l :* r) = Just ((:*), l, r)
deConsTOp (TApp f x) = Just (TApp, f, x)
deConsTOp _ = Nothing
-- | Human-readable rendering of types. Type variables show as an at-sign
-- followed by their index, nominal types as an at-sign, a @t@ and their
-- index, and fresh names verbatim.
--
-- Operands are parenthesised whenever leaving them bare would be ambiguous:
-- the left operand of an arrow or product, the argument of an application,
-- and the head of an application unless it is itself an application
-- (application associates to the left). The right operand of an arrow or
-- product is printed bare, so @a :-> (b :-> c)@ shows as @a -> b -> c@.
instance Show T where
  show = \case
      TVar idx -> "@" ++ show idx
      TFresh s -> s
      a :-> b -> showNest a ++ " -> " ++ show b
      a :* b -> showNest a ++ " * " ++ show b
      TForall l t -> "forall " ++ unwords (S.toList l) ++ ". " ++ show t
      TApp t1 t2 -> showHead t1 ++ " " ++ showNest t2
      TNom i -> "@t" ++ show i
    where
      -- Head of an application: a nested application needs no parentheses,
      -- but an arrow, product or forall does (otherwise "f -> g x" would
      -- read as an arrow into an application).
      showHead t@(TApp _ _) = show t
      showHead t = showNest t
      showNest s
        | isNest s = "(" ++ show s ++ ")"
        | otherwise = show s
      -- A forall body extends as far right as possible, so a forall used as
      -- an operand is always parenthesised, whatever its body.
      isNest = \case
        TApp _ _ -> True
        TForall _ _ -> True
        _ :-> _ -> True
        _ :* _ -> True
        _ -> False
-- | A single constraint between two types: an equality, or a disequality
-- when 'neq' is set.
data Unif = Unif
  { lhs :: T
  , rhs :: T
  , neq :: Bool
    -- ^ 'True' for @lhs /= rhs@, 'False' for @lhs == rhs@.
  }
  deriving (Eq, Ord)
-- | Renders a constraint as @lhs == rhs@, or @lhs /= rhs@ for a disequality.
instance Show Unif where
  show u = show (lhs u) ++ relation ++ show (rhs u)
    where
      relation
        | neq u = " /= "
        | otherwise = " == "
-- | Negating a constraint flips it between equality and disequality; the
-- negation is a single constraint.
instance AtomF Unif where
  notA u = [u {neq = not (neq u)}]
-- | State of the type checker, threaded through 'MS'.
data TCEnv = TCEnv
  { _noms :: M.Map Int T
    -- ^ Nominal types allocated by 'newTNom'; key @i@ maps to @TNom i@.
  , _tvars :: M.Map Int T
    -- ^ Type-variable table. An unbound variable maps to itself; a bound
    -- one maps to its current binding.
  , _neqs :: S.Set (T, T)
    -- ^ Disequality constraints recorded by 'addNEq'.
  }
  deriving Show
-- | The initial checker state: no nominal types, no type variables and no
-- recorded disequalities.
emptyTCEnv :: TCEnv
emptyTCEnv = TCEnv {_noms = M.empty, _tvars = M.empty, _neqs = S.empty}
-- Template Haskell: derive the lenses noms, tvars and neqs from the
-- underscore-prefixed fields of TCEnv. This splice starts a new declaration
-- group, so only declarations after it can use those lenses.
makeLenses ''TCEnv
-- | Allocate a fresh, unbound type variable and return its index.
--
-- The index is the current size of the type-variable table, and the new
-- entry maps the variable to itself, which is how 'prune' and 'occurIn'
-- recognise an unbound variable. NOTE(review): this relies on the table's
-- keys being exactly @0 .. size-1@, i.e. on every key being created here.
newTVar :: MS TCEnv Int
newTVar =
  getsMS (M.size . view tvars) >>= \fresh ->
    fresh <$ modifyMS (over tvars (M.insert fresh (TVar fresh)))
-- | Allocate a fresh nominal type index @i@, register @TNom i@ under key @i@
-- in the nominal table, and return @i@. The index is the table's current
-- size.
newTNom :: MS TCEnv Int
newTNom = do
  nom <- getsMS (\env -> M.size (view noms env))
  modifyMS (over noms (M.insert nom (TNom nom)))
  pure nom
-- | Look up the current entry for type variable @var@: the variable itself
-- when unbound, otherwise its binding. Partial: fails through 'M.!' if
-- @var@ has no entry in the table.
loadTVar :: Int -> MS TCEnv T
loadTVar var = getsMS (\env -> view tvars env M.! var)
-- | @occurIn v t@ is 'True' when type variable @v@ occurs in @t@, looking
-- through the bindings of any other variables met on the way. This is the
-- occurs check used by 'unify'.
occurIn :: Int -> T -> MS TCEnv Bool
occurIn target = go
  where
    go ty = case deConsTOp ty of
      Just (_, a, b) -> or <$> traverse go [a, b]
      Nothing -> leaf ty
    leaf (TVar v)
      | v == target = pure True
      | otherwise = do
          bound <- loadTVar v
          case bound of
            -- an unbound variable maps to itself: nothing more to search
            TVar v' | v' == v -> pure False
            _ -> go bound
    leaf (TForall _ body) = go body
    -- nominal types and fresh names contain no type variables
    leaf _ = pure False
-- | Substitute fresh names: each @TFresh n@ whose @n@ is a key of the map is
-- replaced by the mapped type. A 'TForall' shadows its own names, so they
-- are removed from the map inside its body. Type variables and nominal types
-- are left as they are.
free :: M.Map String T -> T -> T
free subst = go
  where
    go ty = case ty of
      TFresh name -> M.findWithDefault ty name subst
      TForall bound body -> TForall bound (free (M.withoutKeys subst bound) body)
      TVar _ -> ty
      TNom _ -> ty
      -- only the binary constructors remain
      _ -> maybe ty (\(rebuild, l, r) -> rebuild (go l) (go r)) (deConsTOp ty)
-- | Resolve every bound type variable inside a type. Each bound variable is
-- replaced by its own pruned binding, and that result is written back to the
-- table (path compression) so later lookups are shorter. Unbound variables,
-- which map to themselves, are returned unchanged; 'TForall' binders are kept
-- and only their bodies are pruned. Does not terminate if the bindings
-- contain a cycle.
prune :: T -> MS TCEnv T
prune ty = case ty of
  TVar v -> do
    bound <- loadTVar v
    case bound of
      TVar v' | v' == v -> pure bound
      _ -> do
        resolved <- prune bound
        update v resolved
        pure resolved
  TForall names body -> TForall names <$> prune body
  TNom _ -> pure ty
  TFresh _ -> pure ty
  -- only the binary constructors remain; operands are pruned left to right
  _ -> case deConsTOp ty of
    Just (rebuild, l, r) -> rebuild <$> prune l <*> prune r
    Nothing -> pure ty
-- | Store a type as the entry of type variable @var@, overwriting whatever
-- was there before.
update :: Int -> T -> MS TCEnv ()
update var = modifyMS . over tvars . M.insert var
-- | Record a disequality between two types in the state. It is only stored
-- here; this function does not check it.
addNEq :: (T, T) -> MS TCEnv ()
addNEq = modifyMS . over neqs . S.insert
-- | One unification step, written with open recursion ('Fix'): the first
-- argument is the function to use for sub-goals (the 'CtxSolver' instance
-- passes one that prunes both sides first).
--
-- * A disequality is only recorded with 'addNEq'; it is not checked here.
-- * Two nominal types unify iff their indices are equal.
-- * A variable unifies trivially with itself; otherwise it is bound to the
--   other side after an occurs check that rejects cyclic bindings such as
--   @a = a -> b@ (a cycle would make 'prune' loop forever).
-- * Arrows, products and applications unify component-wise.
-- * A 'TForall' is instantiated with fresh type variables, then unified.
-- * Any other combination (e.g. a nominal type against an arrow) fails with
--   'empty' rather than crashing on a missing pattern; identical 'TFresh'
--   names unify.
unify :: Fix (Unif -> MS TCEnv ())
unify _ Unif {lhs, rhs, neq = True} = addNEq (lhs, rhs)
unify _ Unif {lhs = TNom a, rhs = TNom b}
  | a == b = return ()
  | otherwise = empty
unify _ Unif {lhs = TVar a, rhs = TVar b}
  -- Already the same variable: nothing to bind. The occurs check below
  -- would otherwise misreport this as a cyclic definition.
  | a == b = return ()
unify _ Unif {lhs = TVar a, rhs} = do
  -- Occurs check: binding a variable to a type that mentions it would
  -- create a cyclic binding.
  recursive <- occurIn a rhs
  if recursive
    then error "ill formed definition like a = a -> b"
    else update a rhs
unify self u@Unif {lhs, rhs = rhs@(TVar _)} = self u {lhs = rhs, rhs = lhs}
unify self Unif {lhs = l1 :-> l2, rhs = r1 :-> r2} =
  self Unif {lhs = l1, rhs = r1, neq = False}
    >> self Unif {lhs = l2, rhs = r2, neq = False}
unify self Unif {lhs = l1 :* l2, rhs = r1 :* r2} =
  self Unif {lhs = l1, rhs = r1, neq = False}
    >> self Unif {lhs = l2, rhs = r2, neq = False}
unify self Unif {lhs = TApp l1 l2, rhs = TApp r1 r2} =
  self Unif {lhs = l1, rhs = r1, neq = False}
    >> self Unif {lhs = l2, rhs = r2, neq = False}
unify self Unif {lhs = TForall freevars poly, rhs} = do
  pairs <- mapM freepair $ S.toList freevars
  let l = free (M.fromList pairs) poly
  self Unif {lhs = l, rhs = rhs, neq = False}
  where
    freepair freevar = (freevar,) . TVar <$> newTVar
unify self u@Unif {lhs, rhs = rhs@(TForall _ _)} = self u {lhs = rhs, rhs = lhs}
-- Structural mismatch: fail this branch of the search instead of crashing.
unify _ Unif {lhs, rhs}
  | lhs == rhs = return ()
  | otherwise = empty
-- | Solve one constraint by pruning both sides and handing it to 'unify'.
-- The recursion function given to 'unify' prunes every sub-goal before
-- unifying it, so equality constraints always reach 'unify' with resolved
-- types. Disequalities are passed through without pruning.
instance CtxSolver TCEnv Unif where
  solve = normalise >=> step
    where
      step = unify (normalise >=> step)
      normalise u@Unif {neq = True} = return u
      normalise u@Unif {lhs = l, rhs = r} = do
        l' <- prune l
        r' <- prune r
        return u {lhs = l', rhs = r'}