module Control.Supermonad.Plugin
( plugin ) where
import Plugins ( Plugin(tcPlugin), defaultPlugin )
import TcRnTypes
( Ct(..)
, TcPlugin(..), TcPluginResult(..) )
import TcPluginM ( TcPluginM )
import Control.Supermonad.Plugin.Log ( sDocToStr )
import qualified Control.Supermonad.Plugin.Log as L
import Control.Supermonad.Plugin.Solving
( solveConstraints )
import Control.Supermonad.Plugin.Environment
( SupermonadPluginM, runSupermonadPlugin
, getWantedConstraints
, getTypeEqualities, getTyVarEqualities
, printMsg
)
import Control.Supermonad.Plugin.Constraint
( mkDerivedTypeEqCt, mkDerivedTypeEqCtOfTypes )
-- | The plugin value GHC loads from this module.
--
-- Any command-line options passed to the plugin are ignored. The supermonad
-- type checker plugin is always installed.
plugin :: Plugin
plugin = defaultPlugin { tcPlugin = const (Just supermonadPlugin) }
-- | State threaded through the plugin's init, solve and stop phases.
--
-- It is currently unit: 'supermonadInit' produces @()@, and the other phases
-- ignore it.
type SupermonadState = ()
-- | The type checker plugin record.
--
-- It wires together the initialisation, solving and shutdown phases defined
-- in this module.
supermonadPlugin :: TcPlugin
supermonadPlugin = TcPlugin
  { tcPluginStop  = supermonadStop
  , tcPluginSolve = supermonadSolve
  , tcPluginInit  = supermonadInit
  }
-- | Initialise the plugin.
--
-- No setup is performed; the initial state is unit.
supermonadInit :: TcPluginM SupermonadState
supermonadInit = pure ()
-- | Shut the plugin down.
--
-- The state is discarded and nothing is cleaned up.
supermonadStop :: SupermonadState -> TcPluginM ()
supermonadStop = const (pure ())
-- | The solving phase of the plugin.
--
-- Given and derived constraints are concatenated and handed to
-- 'runSupermonadPlugin' as context, alongside the wanted constraints.
--
-- When there are no wanted constraints, nothing is done and 'noResult' is
-- produced. Otherwise the wanted constraints are solved. The type-variable
-- equalities and type equalities collected in the plugin environment are
-- then turned into derived equality constraints. These are reported as new
-- constraints, and no constraint is reported as solved.
--
-- If 'runSupermonadPlugin' yields an error, it is printed and 'noResult' is
-- returned.
supermonadSolve :: SupermonadState -> [Ct] -> [Ct] -> [Ct] -> TcPluginM TcPluginResult
supermonadSolve s given derived wanted = do
  outcome <- runSupermonadPlugin (given ++ derived) wanted solveAction
  either reportFailure return outcome
  where
    -- The work performed inside the plugin environment.
    solveAction :: SupermonadPluginM TcPluginResult
    solveAction
      | null wanted = return noResult
      | otherwise = do
          printMsg "Invoke supermonad plugin..."
          supermonadSolve' s
          varEqs <- getTyVarEqualities
          let varEqCts = [ mkDerivedTypeEqCt ct tv ty | (ct, tv, ty) <- varEqs ]
          typeEqs <- getTypeEqualities
          let typeEqCts = [ mkDerivedTypeEqCtOfTypes ct ta tb | (ct, ta, tb) <- typeEqs ]
          return $ TcPluginOk [] (varEqCts ++ typeEqCts)

    -- Report the environment's error and fall back to an empty result.
    reportFailure err = do
      L.printErr (sDocToStr err)
      return noResult
-- | Fetch the wanted constraints from the plugin environment and solve them.
--
-- The state argument is ignored.
supermonadSolve' :: SupermonadState -> SupermonadPluginM ()
supermonadSolve' _ = getWantedConstraints >>= solveConstraints
-- | A successful plugin result with no solved and no new constraints.
noResult :: TcPluginResult
noResult = TcPluginOk mempty mempty