module Type.State where
import Control.Applicative ( (<$>), (<*>), Applicative, (<|>) )
import Control.Monad.State
import qualified Data.Map as Map
import qualified Data.Traversable as Traversable
import qualified Data.UnionFind.IO as UF
import qualified AST.Annotation as A
import AST.PrettyPrint
import Text.PrettyPrint as P
import Type.Type
-- | A group of type variables registered together, plus the rank that
-- 'introduce' (and fresh variables created while solving) are given.
data Pool = Pool
  { maxRank :: Int
    -- ^ rank assigned to variables introduced while this pool is current
  , inhabitants :: [Variable]
    -- ^ variables added by 'register', most recently registered first
  } deriving Show

-- | A pool at 'outermostRank' with no registered variables.
emptyPool :: Pool
emptyPool = Pool { maxRank = outermostRank, inhabitants = [] }

-- | Environment mapping 'String' keys to solver variables
-- (presumably source-level names; confirm at the call sites).
type Env = Map.Map String Variable
-- | State threaded through the solver's @StateT SolverState IO@ actions.
data SolverState = SS
  { sEnv :: Env
    -- ^ current environment (read by 'getEnv', updated by 'modifyEnv')
  , sSavedEnv :: Env
    -- ^ snapshot of 'sEnv' taken by 'saveLocalEnv'
  , sPool :: Pool
    -- ^ pool that 'register' adds variables to
  , sMark :: Int
    -- ^ next value handed out by 'uniqueMark'
  , sErrors :: [P.Doc]
    -- ^ rendered type errors from 'addError', newest first
  }

-- | Solver state before any constraint has been processed.
initialState :: SolverState
initialState = SS
  { sEnv = Map.empty
  , sSavedEnv = Map.empty
  , sPool = emptyPool
    -- Start one past 'noMark' and only ever increment (see 'uniqueMark'),
    -- so a freshly issued mark is never confused with the "unmarked" value.
  , sMark = noMark + 1
  , sErrors = []
  }
-- | Transform the environment held in the solver state.
modifyEnv f = modify (\st -> let env = sEnv st in st { sEnv = f env })

-- | Transform the current pool held in the solver state.
modifyPool f = modify (\st -> let pool = sPool st in st { sPool = f pool })
-- | Record a mismatch between an expected type (@t1@) and an actual type
-- (@t2@) found in @region@.
--
-- Both variables are converted with 'toSrcType' and pretty-printed, then a
-- message is built from the region header, the optional hint, the region's
-- source docs (indented by 8) and the two types, and prepended to 'sErrors'.
addError :: A.Region -> Maybe String -> UF.Point Descriptor -> UF.Point Descriptor
         -> StateT SolverState IO ()
addError region hint t1 t2 = do
  expected <- liftIO (pretty <$> toSrcType t1)
  actual <- liftIO (pretty <$> toSrcType t2)
  let header = (P.text "Type error" <+> pretty region) P.<> P.colon
      message = P.vcat
        [ header
        , maybe P.empty P.text hint
        , P.text ""
        , P.nest 8 (A.getRegionDocs region)
        , P.text ""
        , P.text " Expected Type:" <+> expected
        , P.text " Actual Type:" <+> actual
        ]
  modify (\st -> st { sErrors = message : sErrors st })
-- | Make @pool@ the current pool, discarding whichever pool was current.
switchToPool pool = modifyPool (const pool)
-- | The pool new variables are currently registered in.
getPool :: StateT SolverState IO Pool
getPool = gets sPool

-- | The current environment.
getEnv :: StateT SolverState IO Env
getEnv = gets sEnv
-- | Copy the current environment into 'sSavedEnv'.
saveLocalEnv :: StateT SolverState IO ()
saveLocalEnv = modify (\st -> st { sSavedEnv = sEnv st })
-- | Hand out the current mark counter and advance it by one, so every call
-- in a solver run returns a distinct value.
uniqueMark :: StateT SolverState IO Int
uniqueMark = state $ \st ->
  let issued = sMark st
  in (issued, st { sMark = issued + 1 })
-- | A fresh, empty pool one rank deeper than the current one. The current
-- pool is only read; switching to the new pool is left to the caller.
nextRankPool :: StateT SolverState IO Pool
nextRankPool = fmap deeper getPool
  where
    deeper current = Pool { maxRank = maxRank current + 1, inhabitants = [] }
-- | Add a variable to the front of the current pool's inhabitants and
-- return it unchanged.
register :: Variable -> StateT SolverState IO Variable
register variable = modifyPool addInhabitant >> return variable
  where
    addInhabitant p = p { inhabitants = variable : inhabitants p }
-- | Give an existing variable the current pool's 'maxRank' and register it
-- in that pool.
introduce :: Variable -> StateT SolverState IO Variable
introduce variable = do
  currentRank <- maxRank <$> getPool
  liftIO $ UF.modifyDescriptor variable (\d -> d { rank = currentRank })
  register variable
-- | Turn a 'Type' into a single solver variable.
--
-- * @VarN@: the variable itself is returned, after merging in the given
--   alias (the new alias wins over an existing one, via '<|>').
-- * @TermN@: children are flattened first, then a fresh 'Flexible' variable
--   holding the flattened structure is created at the current pool's
--   'maxRank' and registered in that pool.
flatten :: Type -> StateT SolverState IO Variable
flatten (VarN maybeAlias var) = do
  liftIO $ UF.modifyDescriptor var (\d -> d { alias = maybeAlias <|> alias d })
  return var
flatten (TermN maybeAlias term) = do
  children <- traverseTerm flatten term
  currentRank <- maxRank <$> getPool
  node <- liftIO $ UF.fresh Descriptor
    { structure = Just children
    , rank = currentRank
    , flex = Flexible
    , name = Nothing
    , copy = Nothing
    , mark = noMark
    , alias = maybeAlias
    }
  register node
-- | Instantiate a variable: copy the copyable part of its graph under a fresh
-- mark, then clear that mark from the originals via 'restore'. Returns the
-- copy.
makeInstance :: Variable -> StateT SolverState IO Variable
makeInstance var = do
  stamp <- uniqueMark
  instanceVar <- makeCopy stamp var
  _ <- restore stamp var
  return instanceVar
-- | Copy one node of a variable graph for 'makeInstance'.
--
-- @alreadyCopied@ is the mark stamped on every variable copied during the
-- current instantiation:
--
-- * a variable already carrying that mark returns its recorded 'copy'
--   (it is an internal error if none was recorded);
-- * a variable whose rank is not 'noRank', or whose flex is 'Constant', is
--   shared, i.e. returned as-is;
-- * otherwise a fresh variable is created at the current pool's 'maxRank'
--   and registered. An @Is@ flex is kept and any other flex becomes
--   'Flexible'. A 'Rigid' variable's name is dropped. The original is
--   stamped with the mark and a pointer to its copy *before* its structure
--   is copied, so reaching the same node again during this traversal
--   (sharing or cycles) yields the same copy.
makeCopy :: Int -> Variable -> StateT SolverState IO Variable
makeCopy alreadyCopied variable = do
  original <- liftIO $ UF.descriptor variable
  if mark original == alreadyCopied
    then maybe copyMissing return (copy original)
    else if rank original /= noRank || flex original == Constant
      then return variable
      else freshCopy original
  where
    copyMissing =
      error $ "Error copying type variable. This should be impossible." ++
              " Please report an error to the github repo!"

    freshCopy original = do
      pool <- getPool
      duplicate <- liftIO $ UF.fresh $ Descriptor
        { structure = Nothing
        , rank = maxRank pool
        , mark = noMark
        , flex = copiedFlex (flex original)
        , copy = Nothing
        , name = copiedName (flex original) (name original)
        , alias = Nothing
        }
      _ <- register duplicate
      -- Record the copy before recursing so revisits terminate.
      liftIO $ UF.modifyDescriptor variable $ \d ->
        d { mark = alreadyCopied, copy = Just duplicate }
      case structure original of
        Nothing -> return duplicate
        Just term -> do
          copiedTerm <- traverseTerm (makeCopy alreadyCopied) term
          liftIO $ UF.modifyDescriptor duplicate $ \d ->
            d { structure = Just copiedTerm }
          return duplicate

    copiedFlex (Is s) = Is s
    copiedFlex _ = Flexible

    copiedName Rigid _ = Nothing
    copiedName _ n = n
-- | Undo the marking left by 'makeCopy' for the given mark.
--
-- A variable not carrying @alreadyCopied@ is returned untouched. A marked
-- variable gets 'noMark', 'noRank' and no 'copy', and then its structure's
-- children are restored recursively. Always returns its argument.
restore :: Int -> Variable -> StateT SolverState IO Variable
restore alreadyCopied variable = do
  desc <- liftIO $ UF.descriptor variable
  if mark desc /= alreadyCopied
    then return variable
    else do
      -- Clear the mark *before* visiting children, as 'makeCopy' sets it
      -- before recursing: a node reached again (e.g. through a cycle) then
      -- sees 'noMark' and stops, instead of recursing forever.
      -- Dropping 'copy' releases the instance graph built by the last
      -- 'makeCopy'; it is only read while the mark matches, which a later
      -- 'uniqueMark' value never does without setting a new copy.
      liftIO $ UF.modifyDescriptor variable $ \d ->
        d { mark = noMark, rank = noRank, copy = Nothing }
      -- 'restore' returns its argument unchanged, so the structure is left
      -- as-is; the traversal is run only for its effects on the children.
      _ <- Traversable.traverse (traverseTerm (restore alreadyCopied)) (structure desc)
      return variable
-- | Apply an effectful function to each immediate child of a term,
-- rebuilding the same constructor around the results. Effects run left to
-- right: function before argument, record fields before the extension.
traverseTerm :: (Monad f, Applicative f) => (a -> f b) -> Term1 a -> f (Term1 b)
traverseTerm visit (App1 fun arg) = App1 <$> visit fun <*> visit arg
traverseTerm visit (Fun1 dom cod) = Fun1 <$> visit dom <*> visit cod
traverseTerm visit (Var1 x) = Var1 <$> visit x
traverseTerm _ EmptyRecord1 = return EmptyRecord1
traverseTerm visit (Record1 fields ext) =
  Record1 <$> Traversable.traverse (mapM visit) fields <*> visit ext