module Control.Supermonad.Plugin.Environment
(
SupermonadPluginM
, runSupermonadPlugin
, runTcPlugin
, getBindClass, getReturnClass
, getSupermonadModule
, getGivenConstraints, getWantedConstraints
, getInstEnvs
, getBindInstances
, getSupermonadFor
, addTypeEqualities, addTypeEquality
, addTyVarEqualities, addTyVarEquality
, getTypeEqualities, getTyVarEqualities
, whenNoResults
, addWarning, displayWarnings
, throwPluginError, throwPluginErrorSDoc, catchPluginError
, assert, assertM
, printErr, printMsg, printObj, printWarn
, printConstraints
) where
import Data.List ( groupBy )
import Data.Map ( Map )
import qualified Data.Map as M
import Control.Monad ( unless, forM_ )
import Control.Monad.Reader ( ReaderT, runReaderT, asks )
import Control.Monad.State ( StateT , runStateT , gets, modify )
import Control.Monad.Except ( ExceptT, runExceptT, throwError, catchError )
import Control.Monad.Trans.Class ( lift )
import Class ( Class )
import Module ( Module )
import InstEnv ( InstEnvs, ClsInst )
import Type ( TyVar, Type )
import TyCon ( TyCon )
import TcRnTypes ( Ct )
import TcPluginM ( TcPluginM, tcPluginIO )
import qualified TcPluginM
import Outputable ( Outputable )
import SrcLoc ( srcSpanFileName_maybe )
import FastString ( unpackFS )
import qualified Outputable as O
import qualified Control.Supermonad.Plugin.Log as L
import Control.Supermonad.Plugin.Constraint
( GivenCt, WantedCt
, constraintSourceLocation )
import Control.Supermonad.Plugin.Detect
( findSupermonadModule
, findBindClass, findReturnClass
, findInstancesInScope
, bindClassName, returnClassName
, findSupermonads
, checkSupermonadInstances )
-- | Error value of the plugin monad: a pretty-printable 'O.SDoc' document.
type SupermonadError = O.SDoc
-- | The plugin monad. From the outside in it provides:
--
--   * read-only access to a 'SupermonadPluginEnv',
--   * a mutable 'SupermonadPluginState',
--   * failure with a 'SupermonadError',
--
-- all layered on top of GHC's 'TcPluginM'.
type SupermonadPluginM =
  ReaderT SupermonadPluginEnv
    (StateT SupermonadPluginState
      (ExceptT SupermonadError TcPluginM))
-- | Read-only environment of the plugin monad. It is assembled once by
-- 'runSupermonadPlugin' before the plugin computation starts.
data SupermonadPluginEnv = SupermonadPluginEnv
  { smEnvSupermonadModule :: Module
    -- ^ The module returned by 'findSupermonadModule'.
  , smEnvBindClass :: Class
    -- ^ The class returned by 'findBindClass'.
  , smEnvReturnClass :: Class
    -- ^ The class returned by 'findReturnClass'.
  , smEnvBindInstances :: [ClsInst]
    -- ^ The bind class instances returned by 'findInstancesInScope'.
  , smEnvGivenConstraints :: [GivenCt]
    -- ^ The given constraints passed to 'runSupermonadPlugin'.
  , smEnvWantedConstraints :: [WantedCt]
    -- ^ The wanted constraints passed to 'runSupermonadPlugin'.
  , smEnvSupermonads :: Map TyCon (ClsInst, ClsInst)
    -- ^ The instance pairs returned by 'findSupermonads', keyed by type
    -- constructor. Presumably (bind instance, return instance) -- TODO
    -- confirm the order against 'findSupermonads'.
  }
-- | Mutable state of the plugin monad. All three lists are extended by
-- prepending, so the most recently added entry comes first.
data SupermonadPluginState = SupermonadPluginState
  { smStateTyVarEqualities :: [(Ct, TyVar, Type)]
    -- ^ Entries recorded through 'addTyVarEquality'.
  , smStateTypeEqualities :: [(Ct, Type, Type)]
    -- ^ Entries recorded through 'addTypeEquality'.
  , smStateWarningQueue :: [(String, O.SDoc)]
    -- ^ Warnings recorded through 'addWarning': a message and its details.
  }
-- | Build the plugin environment and run a 'SupermonadPluginM' computation
-- inside 'TcPluginM'.
--
-- The supermonad module and the bind and return classes are looked up first;
-- the supermonad instances are looked up and checked only when both classes
-- were found. If any of these steps reports a problem, the problem is printed
-- and returned as a 'Left' without running the computation. Otherwise the
-- computation runs with an empty initial state and either its result or the
-- error it threw is returned. The final plugin state is discarded.
runSupermonadPlugin
  :: [GivenCt]            -- ^ Given constraints exposed through the environment.
  -> [WantedCt]           -- ^ Wanted constraints exposed through the environment.
  -> SupermonadPluginM a  -- ^ The plugin computation to run.
  -> TcPluginM (Either SupermonadError a)
runSupermonadPlugin givenCts wantedCts smM = do
  mSupermonadMdl <- findSupermonadModule
  mBindCls <- findBindClass
  mReturnCls <- findReturnClass
  -- Instances can only be searched for when both classes are available.
  (smInsts, smErrors) <- case (mBindCls, mReturnCls) of
    (Just bindCls, Just returnCls) -> do
      (foundInsts, findErrors) <- findSupermonads bindCls returnCls
      checkErrors <- checkSupermonadInstances bindCls returnCls
      return (foundInsts, map snd findErrors ++ map snd checkErrors)
    _ -> return (M.empty, [])
  -- Failures are reported in a fixed priority: module, bind class,
  -- return class, instance problems.
  case mSupermonadMdl of
    Left mdlErrMsg ->
      failWith "Could not find supermonad module:" (Just mdlErrMsg)
    Right supermonadMdl -> case (mBindCls, mReturnCls) of
      (Nothing, _) ->
        failWith ("Could not find " ++ bindClassName ++ " class!") Nothing
      (_, Nothing) ->
        failWith ("Could not find " ++ returnClassName ++ " class!") Nothing
      (Just bindCls, Just returnCls)
        | null smErrors -> do
            bindInsts <- findInstancesInScope bindCls
            let env = SupermonadPluginEnv
                  { smEnvSupermonadModule = supermonadMdl
                  , smEnvBindClass = bindCls
                  , smEnvReturnClass = returnCls
                  , smEnvBindInstances = bindInsts
                  , smEnvGivenConstraints = givenCts
                  , smEnvWantedConstraints = wantedCts
                  , smEnvSupermonads = smInsts
                  }
            -- Keep only the computed value; the final state is dropped.
            fmap fst <$> runExceptT (runStateT (runReaderT smM env) emptyState)
        | otherwise ->
            failWith "Problems when finding supermonad instances:"
                     (Just (O.vcat smErrors))
  where
    -- State the plugin computation starts from: nothing recorded yet.
    emptyState :: SupermonadPluginState
    emptyState = SupermonadPluginState
      { smStateTyVarEqualities = []
      , smStateTypeEqualities = []
      , smStateWarningQueue = []
      }

    -- Print the message (and the details, if any) and return them as the
    -- plugin error.
    failWith :: String -> Maybe O.SDoc -> TcPluginM (Either SupermonadError b)
    failWith msg mDetails = do
      L.printErr msg
      case mDetails of
        Nothing -> return $ Left $ stringToSupermonadError msg
        Just details -> do
          L.printErr $ L.sDocToStr details
          return $ Left $ stringToSupermonadError msg O.$$ details
-- | Run a 'TcPluginM' action inside the plugin monad by lifting it through
-- the three transformer layers of 'SupermonadPluginM'.
runTcPlugin :: TcPluginM a -> SupermonadPluginM a
runTcPlugin tcAction = lift (lift (lift tcAction))
-- | The bind class stored in the plugin environment.
getBindClass :: SupermonadPluginM Class
getBindClass = asks smEnvBindClass
-- | The return class stored in the plugin environment.
getReturnClass :: SupermonadPluginM Class
getReturnClass = asks smEnvReturnClass
-- | The supermonad module stored in the plugin environment.
getSupermonadModule :: SupermonadPluginM Module
getSupermonadModule = asks smEnvSupermonadModule
-- | The given constraints stored in the plugin environment.
getGivenConstraints :: SupermonadPluginM [GivenCt]
getGivenConstraints = asks smEnvGivenConstraints
-- | The wanted constraints stored in the plugin environment.
getWantedConstraints :: SupermonadPluginM [WantedCt]
getWantedConstraints = asks smEnvWantedConstraints
-- | The bind class instances stored in the plugin environment.
getBindInstances :: SupermonadPluginM [ClsInst]
getBindInstances = asks smEnvBindInstances
-- | The instance environments, obtained by running 'TcPluginM.getInstEnvs'
-- in the underlying 'TcPluginM' (not read from the plugin environment).
getInstEnvs :: SupermonadPluginM InstEnvs
getInstEnvs = runTcPlugin TcPluginM.getInstEnvs
-- | Look up the instance pair stored in the environment for the given type
-- constructor. Yields 'Nothing' when the map has no entry for it.
getSupermonadFor :: TyCon -> SupermonadPluginM (Maybe (ClsInst, ClsInst))
getSupermonadFor tc = asks (M.lookup tc . smEnvSupermonads)
-- | Record a type-variable equality in the plugin state. The new entry is
-- placed at the front of the stored list.
addTyVarEquality :: Ct -> TyVar -> Type -> SupermonadPluginM ()
addTyVarEquality ct tv ty = modify prepend
  where
    prepend st =
      st { smStateTyVarEqualities = (ct, tv, ty) : smStateTyVarEqualities st }

-- | Record several type-variable equalities, one after another, via
-- 'addTyVarEquality'.
addTyVarEqualities :: [(Ct, TyVar, Type)] -> SupermonadPluginM ()
addTyVarEqualities eqs =
  forM_ eqs $ \(ct, tv, ty) -> addTyVarEquality ct tv ty
-- | Record a type equality in the plugin state. The new entry is placed at
-- the front of the stored list.
addTypeEquality :: Ct -> Type -> Type -> SupermonadPluginM ()
addTypeEquality ct ta tb = modify prepend
  where
    prepend st =
      st { smStateTypeEqualities = (ct, ta, tb) : smStateTypeEqualities st }

-- | Record several type equalities, one after another, via
-- 'addTypeEquality'.
addTypeEqualities :: [(Ct, Type, Type)] -> SupermonadPluginM ()
addTypeEqualities eqs =
  forM_ eqs $ \(ct, ta, tb) -> addTypeEquality ct ta tb
-- | All type-variable equalities recorded so far, exactly as stored in the
-- plugin state.
getTyVarEqualities :: SupermonadPluginM [(Ct, TyVar, Type)]
getTyVarEqualities = gets smStateTyVarEqualities

-- | All type equalities recorded so far, exactly as stored in the plugin
-- state.
getTypeEqualities :: SupermonadPluginM [(Ct, Type, Type)]
getTypeEqualities = gets smStateTypeEqualities
-- | Put a warning, made of a short message and a detail document, at the
-- front of the warning queue. Nothing is printed here; see 'displayWarnings'.
addWarning :: String -> O.SDoc -> SupermonadPluginM ()
addWarning msg details = modify enqueue
  where
    enqueue st =
      st { smStateWarningQueue = (msg, details) : smStateWarningQueue st }
-- | Run the given action only when neither a type-variable equality nor a
-- type equality has been recorded in the plugin state; otherwise do nothing.
whenNoResults :: SupermonadPluginM () -> SupermonadPluginM ()
whenNoResults m = do
  haveTyVarEqs <- not . null <$> getTyVarEqualities
  haveTyEqs <- not . null <$> getTypeEqualities
  unless (haveTyVarEqs || haveTyEqs) m
-- | Print all queued warnings, but only when no equalities have been
-- recorded (see 'whenNoResults'). For each warning the message is printed
-- with 'printWarn', followed by its rendered details. Warnings are printed in
-- the order in which they were added.
displayWarnings :: SupermonadPluginM ()
displayWarnings = whenNoResults $ do
  warns <- gets smStateWarningQueue
  -- 'addWarning' prepends, so the stored list is newest-first. Reverse it so
  -- that the output follows the order in which the warnings were raised.
  forM_ (reverse warns) $ \(msg, details) -> do
    printWarn msg
    internalPrint $ L.smObjMsg $ L.sDocToStr details
-- | Turn a plain message into a 'SupermonadError' using 'O.text'.
stringToSupermonadError :: String -> SupermonadError
stringToSupermonadError msg = O.text msg
-- | Do nothing when the condition holds; otherwise abort with the given
-- message via 'throwPluginError'.
assert :: Bool -> String -> SupermonadPluginM ()
assert cond msg
  | cond = return ()
  | otherwise = throwPluginError msg
-- | Like 'assert', but the condition is the result of a plugin computation,
-- which is run first.
assertM :: SupermonadPluginM Bool -> String -> SupermonadPluginM ()
assertM condM msg = condM >>= \cond -> assert cond msg
-- | Abort the plugin computation with a plain-text error message.
throwPluginError :: String -> SupermonadPluginM a
throwPluginError msg = throwError (stringToSupermonadError msg)

-- | Abort the plugin computation with an already built error document.
throwPluginErrorSDoc :: O.SDoc -> SupermonadPluginM a
throwPluginErrorSDoc errDoc = throwError errDoc
-- | Run a plugin computation and, if it throws a 'SupermonadError', pass
-- that error to the handler and continue with the handler's computation.
catchPluginError :: SupermonadPluginM a -> (SupermonadError -> SupermonadPluginM a) -> SupermonadPluginM a
catchPluginError action handler = catchError action handler
-- | Render an 'Outputable' value with 'L.pprToStr', wrap the text with
-- 'L.smObjMsg' and print it.
printObj :: Outputable o => o -> SupermonadPluginM ()
printObj obj = internalPrint (L.smObjMsg (L.pprToStr obj))

-- | Print a message wrapped with 'L.smDebugMsg'.
printMsg :: String -> SupermonadPluginM ()
printMsg msg = internalPrint (L.smDebugMsg msg)

-- | Print a message wrapped with 'L.smErrMsg'.
printErr :: String -> SupermonadPluginM ()
printErr msg = internalPrint (L.smErrMsg msg)

-- | Print a message wrapped with 'L.smWarnMsg'.
printWarn :: String -> SupermonadPluginM ()
printWarn msg = internalPrint (L.smWarnMsg msg)
-- | Write a string with 'putStr' (no newline is appended here), running the
-- IO action through 'tcPluginIO'.
internalPrint :: String -> SupermonadPluginM ()
internalPrint str = runTcPlugin (tcPluginIO (putStr str))

-- | Print an already rendered string wrapped with 'L.smObjMsg'.
printFormattedObj :: String -> SupermonadPluginM ()
printFormattedObj str = internalPrint (L.smObjMsg str)
-- | Print the given constraints, grouped by the source file they come from.
--
-- Consecutive constraints with the same file name form one group. Each group
-- is introduced by a header line naming the file (or stating that the file
-- is unknown), followed by every constraint of the group rendered with
-- 'L.formatConstraint'. Only adjacent constraints are merged, so a file may
-- appear in more than one header if its constraints are not contiguous.
printConstraints :: [Ct] -> SupermonadPluginM ()
printConstraints cts =
  forM_ groupedCts $ \(file, ctGroup) -> do
    printFormattedObj $ maybe "From unknown file:" (("From " ++) . (++ ":") . unpackFS) file
    mapM_ (printFormattedObj . L.formatConstraint) ctGroup
  where
    -- Pair each group with the file of its first element. The pattern match
    -- in the generator replaces a partial 'head'; 'groupBy' never produces an
    -- empty group, so no group is dropped.
    groupedCts =
      [ (getCtFile firstCt, ctGroup)
      | ctGroup@(firstCt : _) <- groupBy eqFileName cts
      ]
    eqFileName ct1 ct2 = getCtFile ct1 == getCtFile ct2
    getCtFile = srcSpanFileName_maybe . constraintSourceLocation