{-# LANGUAGE FlexibleContexts #-} module Type.State where import Control.Applicative ( Applicative, (<$>), (<*>), (<|>) ) import Control.Monad.State import Data.Map ((!)) import qualified Data.Map as Map import qualified Data.Traversable as Traversable import qualified Data.UnionFind.IO as UF import qualified Type.Hint as Hint import Type.Type -- Pool -- Holds a bunch of variables -- The rank of each variable is less than or equal to the pool's "maxRank" -- The young pool exists to make it possible to identify these vars in constant time. data Pool = Pool { maxRank :: Int , inhabitants :: [Variable] } emptyPool :: Pool emptyPool = Pool { maxRank = outermostRank , inhabitants = [] } type Env = Map.Map String Variable -- Keeps track of the environment, type variable pool, and a list of errors data SolverState = SS { sEnv :: Env , sSavedEnv :: Env , sPool :: Pool , sMark :: Int , sHint :: [Hint.Hint] } initialState :: SolverState initialState = SS { sEnv = Map.empty , sSavedEnv = Map.empty , sPool = emptyPool , sMark = noMark + 1 -- The mark must never be equal to noMark! , sHint = [] } modifyEnv :: (MonadState SolverState m) => (Env -> Env) -> m () modifyEnv f = modify $ \state -> state { sEnv = f (sEnv state) } modifyPool :: (MonadState SolverState m) => (Pool -> Pool) -> m () modifyPool f = modify $ \state -> state { sPool = f (sPool state) } addHint :: Hint.Hint -> StateT SolverState IO () addHint hint = modify $ \state -> state { sHint = hint : sHint state } switchToPool :: (MonadState SolverState m) => Pool -> m () switchToPool pool = modifyPool (\_ -> pool) getPool :: StateT SolverState IO Pool getPool = gets sPool getEnv :: StateT SolverState IO Env getEnv = gets sEnv saveLocalEnv :: StateT SolverState IO () saveLocalEnv = do currentEnv <- getEnv modify $ \state -> state { sSavedEnv = currentEnv } uniqueMark :: StateT SolverState IO Int uniqueMark = do state <- get let mark = sMark state put $ state { sMark = mark + 1 } return mark nextRankPool :: StateT SolverState IO Pool nextRankPool = do pool <- getPool return $ Pool { maxRank = maxRank pool + 1, inhabitants = [] } register :: Variable -> StateT SolverState IO Variable register variable = do modifyPool $ \pool -> pool { inhabitants = variable : inhabitants pool } return variable introduce :: Variable -> StateT SolverState IO Variable introduce variable = do pool <- getPool liftIO $ UF.modifyDescriptor variable (\desc -> desc { rank = maxRank pool }) register variable flatten :: Type -> StateT SolverState IO Variable flatten term = flattenHelp Map.empty term flattenHelp :: Map.Map String Variable -> Type -> StateT SolverState IO Variable flattenHelp aliasDict term = case term of PlaceHolder name -> return (aliasDict ! name) VarN maybeAlias v -> do maybeAlias' <- Traversable.traverse flattenAlias maybeAlias liftIO $ UF.modifyDescriptor v $ \desc -> desc { alias = maybeAlias' <|> alias desc } return v TermN maybeAlias subTerm -> do maybeAlias' <- Traversable.traverse flattenAlias maybeAlias let localDict = maybe aliasDict (Map.fromList . snd) maybeAlias' flatStructure <- traverseTerm (flattenHelp localDict) subTerm pool <- getPool var <- liftIO . UF.fresh $ Descriptor { structure = Just flatStructure , rank = maxRank pool , flex = Flexible , name = Nothing , copy = Nothing , mark = noMark , alias = maybeAlias' } register var where flattenAlias (name, args) = let flattenPair (arg, subTerm) = (,) arg <$> flattenHelp aliasDict subTerm in (,) name <$> mapM flattenPair args makeInstance :: Variable -> StateT SolverState IO Variable makeInstance var = do alreadyCopied <- uniqueMark freshVar <- makeCopy alreadyCopied var restore alreadyCopied var return freshVar makeCopy :: Int -> Variable -> StateT SolverState IO Variable makeCopy alreadyCopied variable = do desc <- liftIO $ UF.descriptor variable case () of () | mark desc == alreadyCopied -> case copy desc of Just v -> return v Nothing -> error $ "Error copying type variable. This should be impossible." ++ " Please report an error to the github repo!" | rank desc /= noRank || flex desc == Constant -> return variable | otherwise -> do pool <- getPool newVar <- liftIO $ UF.fresh $ Descriptor { structure = Nothing , rank = maxRank pool , mark = noMark , flex = case flex desc of Is s -> Is s _ -> Flexible , copy = Nothing , name = case flex desc of Rigid -> Nothing _ -> name desc , alias = Nothing } register newVar -- Link the original variable to the new variable. This lets us -- avoid making multiple copies of the variable we are instantiating. -- -- Need to do this before recursively copying the structure of -- the variable to avoid looping on cyclic terms. liftIO $ UF.modifyDescriptor variable $ \desc -> desc { mark = alreadyCopied, copy = Just newVar } -- Now we recursively copy the structure of the variable. -- We have already marked the variable as copied, so we -- will not repeat this work or crawl this variable again. case structure desc of Nothing -> return newVar Just term -> do newTerm <- traverseTerm (makeCopy alreadyCopied) term liftIO $ UF.modifyDescriptor newVar $ \desc -> desc { structure = Just newTerm } return newVar restore :: Int -> Variable -> StateT SolverState IO Variable restore alreadyCopied variable = do desc <- liftIO $ UF.descriptor variable case mark desc /= alreadyCopied of True -> return variable False -> do restoredStructure <- Traversable.traverse (traverseTerm (restore alreadyCopied)) (structure desc) liftIO $ UF.modifyDescriptor variable $ \desc -> desc { mark = noMark, rank = noRank, structure = restoredStructure } return variable traverseTerm :: (Monad f, Applicative f) => (a -> f b) -> Term1 a -> f (Term1 b) traverseTerm f term = case term of App1 a b -> App1 <$> f a <*> f b Fun1 a b -> Fun1 <$> f a <*> f b Var1 x -> Var1 <$> f x EmptyRecord1 -> return EmptyRecord1 Record1 fields ext -> Record1 <$> Traversable.traverse (mapM f) fields <*> f ext