module TypeLet.Plugin (plugin) where
import Prelude hiding (cycle)
import Data.Traversable (forM)
import GHC.Plugins (Plugin(..), defaultPlugin, purePlugin)
import TypeLet.Plugin.Constraints
import TypeLet.Plugin.GhcTcPluginAPI
import TypeLet.Plugin.NameResolution
import TypeLet.Plugin.Substitution
-- | The typelet type-checker plugin.
--
-- The plugin is declared pure ('purePlugin'), so enabling it does not force
-- recompilation. It installs a single type-checker plugin:
--
-- * 'resolveNames' runs once at start-up. Its result is threaded through
--   to the other phases.
-- * 'solve' does all constraint solving.
-- * No type-family rewriters are registered (empty map).
-- * No cleanup is needed when type checking stops.
--
-- Command-line options passed to the plugin are ignored.
plugin :: Plugin
plugin = defaultPlugin {
      pluginRecompile = purePlugin
    , tcPlugin        = \_cmdline -> Just . mkTcPlugin $ TcPlugin {
          tcPluginInit    = resolveNames
        , tcPluginSolve   = solve
        , tcPluginRewrite = \_st -> emptyUFM
        , tcPluginStop    = \_st -> return ()
        }
    }
-- | Entry point of the constraint solver.
--
-- An empty list of wanteds means only the givens are up for simplification.
-- That case is delegated to 'simplifyGivens'. Otherwise all of the work
-- happens in 'simplifyWanteds'.
solve :: ResolvedNames -> TcPluginSolver
solve rn given wanted
  | null wanted = simplifyGivens rn given
  | otherwise   = simplifyWanteds rn given wanted
-- | Simplify givens.
--
-- The plugin never simplifies givens. It always reports success with no
-- solved constraints and no new constraints.
simplifyGivens ::
     ResolvedNames  -- ^ Names resolved at plugin start-up (unused)
  -> [Ct]           -- ^ Givens (unused)
  -> TcPluginM 'Solve TcPluginSolveResult
simplifyGivens _st _given = return $ TcPluginOk [] []
-- | Simplify wanteds. This is the core of the plugin.
--
-- 1. Every given is parsed with 'parseLet'. If parsing reports an invalid
--    let-constraint, it is turned into a type error (via 'formatInvalidLet').
-- 2. The parsed lets are combined into a single substitution by
--    'letsToSubst'. If that reports a cycle, the cycle becomes a type error
--    (via 'formatLetCycle').
-- 3. Each wanted that 'parseEqual' recognises is solved, with evidence from
--    'evidenceEqual'. It is replaced by a new nominal equality wanted between
--    its left- and right-hand sides, both with the substitution applied.
--
-- Errors are reported as a 'TcPluginContradiction' carrying a new wanted whose
-- type is the custom error message.
simplifyWanteds ::
     ResolvedNames  -- ^ Names resolved at plugin start-up
  -> [Ct]           -- ^ Givens
  -> [Ct]           -- ^ Wanteds
  -> TcPluginM 'Solve TcPluginSolveResult
simplifyWanteds rn given wanted =
    case parseAll (parseLet rn) given of
      Left err ->
        errWith $ formatInvalidLet <$> err
      Right lets ->
        case letsToSubst lets of
          Left cycle ->
            errWith $ formatLetCycle cycle
          Right subst -> do
            (solved, new) <- fmap unzip $
              forM (parseAll' (withOrig (parseEqual rn)) wanted) $
                uncurry (solveEqual subst)
            return $ TcPluginOk solved new
  where
    -- | Create a new wanted of the given type, running 'newWanted' under
    -- 'setCtLocM' so the constraint is associated with location @l@
    -- (presumably so that resulting errors point at the original site).
    newWanted' :: CtLoc -> PredType -> TcPluginM 'Solve CtEvidence
    newWanted' l w = setCtLocM l $ newWanted l w

    -- | Report an error at the given location. This builds the error type
    -- with 'mkTcPluginErrorTy', emits it as a new wanted and declares that
    -- wanted a contradiction.
    errWith ::
         GenLocated CtLoc TcPluginErrorMessage
      -> TcPluginM 'Solve TcPluginSolveResult
    errWith (L l err) = do
        errAsTyp <- mkTcPluginErrorTy err
        mkErr <$> newWanted' l errAsTyp
      where
        mkErr :: CtEvidence -> TcPluginSolveResult
        mkErr = TcPluginContradiction . (:[]) . mkNonCanonical

    -- | Solve a single parsed equality wanted.
    --
    -- Returns two things. The first is the evidence for the original
    -- constraint @orig@, paired with it (a "solved" entry). The second is a
    -- new wanted stating that the two sides are nominally equal once the
    -- let-substitution has been applied.
    solveEqual ::
         TCvSubst
      -> Ct
      -> GenLocated CtLoc CEqual
      -> TcPluginM 'Solve ((EvTerm, Ct), Ct)
    solveEqual subst orig (L l parsed) = do
        ev <- newWanted' l $
                mkPrimEqPredRole
                  Nominal
                  (substTy subst (equalLHS parsed))
                  (substTy subst (equalRHS parsed))
        return (
            (evidenceEqual rn parsed, orig)
          , mkNonCanonical ev
          )