module Control.Supermonad.Plugin.Detect
(
supermonadModuleName
, bindClassName, returnClassName
, findSupermonadModule
, isBindClass, isReturnClass
, isSupermonadModule
, findBindClass, findReturnClass
, findSupermonads
, checkSupermonadInstances
, identityModuleName
, identityTyConName
, findIdentityModule
, findIdentityTyCon
, functorClassName, functorModuleName
, findInstancesInScope
) where
import Data.List ( find )
import Data.Maybe ( catMaybes, listToMaybe )
import qualified Data.Set as S
import qualified Data.Map as M
import Control.Monad ( forM, liftM )
import BasicTypes ( Arity )
import TcRnTypes
( TcGblEnv(..)
, TcTyThing(..)
, ImportAvails( imp_mods ) )
import Type ( TyThing(..) )
import TyCon ( TyCon )
import TcPluginM
( TcPluginM
, getEnvs, getInstEnvs
, tcLookup )
import Name
( nameModule
, getOccName )
import OccName
( OccName
, occNameString, mkTcOcc )
import RdrName
( GlobalRdrElt(..)
, Parent( NoParent )
, lookupGlobalRdrEnv )
import Module
( Module, ModuleName
, moduleName
, moduleEnvKeys
, mkModuleName )
import Class
( Class(..)
, className, classArity )
import InstEnv
( ClsInst(..)
, instEnvElts
, ie_global
, classInstances )
import PrelNames ( mAIN_NAME )
import Outputable ( SDoc, ($$), (<>), text, vcat, ppr )
import Control.Supermonad.Plugin.Wrapper
( UnitId, baseUnitId, moduleUnitId, isImportedFrom )
import Control.Supermonad.Plugin.Instance
( instanceTyArgs
, isMonoTyConInstance
, isPolyTyConInstance )
import Control.Supermonad.Plugin.Utils
( collectTopTyCons )
-- | Name of the unconstrained supermonad module.
supermonadModuleName :: String
supermonadModuleName = "Control.Supermonad"

-- | Name of the prelude module that accompanies the unconstrained
--   supermonad module.
supermonadPreludeModuleName :: String
supermonadPreludeModuleName = supermonadModuleName ++ ".Prelude"

-- | Name of the constrained supermonad module.
supermonadCtModuleName :: String
supermonadCtModuleName = supermonadModuleName ++ ".Constrained"

-- | Name of the prelude module that accompanies the constrained
--   supermonad module.
supermonadCtPreludeModuleName :: String
supermonadCtPreludeModuleName = supermonadCtModuleName ++ ".Prelude"

-- | Name of the supermonad bind type class.
bindClassName :: String
bindClassName = "Bind"

-- | Name of the supermonad return type class.
returnClassName :: String
returnClassName = "Return"

-- | Name of the module searched for the identity type constructor.
identityModuleName :: String
identityModuleName = "Data.Functor.Identity"

-- | Name of the identity type constructor.
identityTyConName :: String
identityTyConName = "Identity"

-- | Name of the functor type class.
functorClassName :: String
functorClassName = "Functor"

-- | Name of the module exporting the functor type class.
functorModuleName :: String
functorModuleName = "Data.Functor"
-- | Locate the supermonad module among the imports. Exactly one of the
--   unconstrained variant (see 'findSupermonadUnCtModule') and the
--   constrained variant (see 'findSupermonadCtModule') has to be found.
--   If neither or both are found, an error message is returned instead.
findSupermonadModule :: TcPluginM (Either SDoc Module)
findSupermonadModule = do
  unCtResult <- findSupermonadUnCtModule
  ctResult <- findSupermonadCtModule
  return $ selectModule unCtResult ctResult
  where
    selectModule :: Either SDoc Module -> Either SDoc Module -> Either SDoc Module
    selectModule (Right mdl) (Left _) = Right mdl
    selectModule (Left _) (Right mdl) = Right mdl
    selectModule (Left err) (Left errCt) =
      Left $ text "Could not find supermonad or constrained supermonad modules!" $$ err $$ errCt
    selectModule (Right _) (Right _) =
      Left $ text "Found unconstrained and constrained supermonad modules!"
-- | Look up the unconstrained supermonad module among the imports. If it is
--   not imported, fall back to the unconstrained supermonad prelude. When
--   both lookups fail, the error of the prelude lookup is returned.
findSupermonadUnCtModule :: TcPluginM (Either SDoc Module)
findSupermonadUnCtModule =
  getModule Nothing supermonadModuleName >>= either tryPrelude (return . Right)
  where
    tryPrelude _err = getModule Nothing supermonadPreludeModuleName
-- | Look up the constrained supermonad module among the imports. If it is
--   not imported, fall back to the constrained supermonad prelude. When
--   both lookups fail, the error of the prelude lookup is returned.
findSupermonadCtModule :: TcPluginM (Either SDoc Module)
findSupermonadCtModule =
  getModule Nothing supermonadCtModuleName >>= either tryCtPrelude (return . Right)
  where
    tryCtPrelude _err = getModule Nothing supermonadCtPreludeModuleName
-- | Is the given module one of the supermonad modules (unconstrained or
--   constrained, with or without the prelude)? Modules named @Main@ are
--   accepted as well. Only the module name is compared, not the unit.
isSupermonadModule :: Module -> Bool
isSupermonadModule mdl = any (moduleName mdl ==) acceptedNames
  where
    -- NOTE(review): accepting 'mAIN_NAME' presumably allows the classes to be
    -- defined in a Main module (e.g. in tests); confirm with the plugin authors.
    acceptedNames = fmap mkModuleName
      [ supermonadModuleName
      , supermonadPreludeModuleName
      , supermonadCtModuleName
      , supermonadCtPreludeModuleName
      ] ++ [mAIN_NAME]
-- | Is the given class the supermonad @Bind@ class? It must be named
--   'bindClassName', have arity 3 and be defined in a module accepted by
--   'isSupermonadModule'.
isBindClass :: Class -> Bool
isBindClass cls = isClass cls isSupermonadModule bindClassName bindArity
  where bindArity = 3

-- | Is the given class the supermonad @Return@ class? It must be named
--   'returnClassName', have arity 1 and be defined in a module accepted by
--   'isSupermonadModule'.
isReturnClass :: Class -> Bool
isReturnClass cls = isClass cls isSupermonadModule returnClassName returnArity
  where returnArity = 1

-- | Find the supermonad @Bind@ class through one of its instances in the
--   instance environments; 'Nothing' if no such instance is visible.
findBindClass :: TcPluginM (Maybe Class)
findBindClass = findClass isBindClass

-- | Find the supermonad @Return@ class through one of its instances in the
--   instance environments; 'Nothing' if no such instance is visible.
findReturnClass :: TcPluginM (Maybe Class)
findReturnClass = findClass isReturnClass
-- | Find a module in which to look for the identity type constructor.
--   Prefer 'identityModuleName' from the @base@ unit if it is imported,
--   otherwise use the module found by 'findSupermonadModule'.
--   If neither lookup succeeds, an error message naming the missing module
--   is returned.
findIdentityModule :: TcPluginM (Either SDoc Module)
findIdentityModule = do
  mdls <- findModules [getModule (Just baseUnitId) identityModuleName, findSupermonadModule]
  return $ case mdls of
    -- The old message ("Could not find module 'Identity' module.") named a
    -- non-existent module and hid that the supermonad fallback failed as well.
    [] -> Left $ text $ "Could not find module '" ++ identityModuleName ++ "' or a supermonad module."
    (mdl : _) -> Right mdl
-- | Look up the identity type constructor ('identityTyConName') among the
--   names in scope that are imported from the module found by
--   'findIdentityModule' or from the supermonad module. Returns 'Nothing'
--   if none of these modules is found or no matching type constructor exists.
findIdentityTyCon :: TcPluginM (Maybe TyCon)
findIdentityTyCon = do
  searchMdls <- findModules [findIdentityModule, findSupermonadModule]
  if null searchMdls
    then return Nothing
    else findTyConByNameAndModule (mkTcOcc identityTyConName) searchMdls
-- | Run all given module lookups in order and keep the modules that were
--   found, preserving their order. Lookup errors are discarded.
findModules :: [TcPluginM (Either SDoc Module)] -> TcPluginM [Module]
findModules = fmap keepFound . sequence
  where
    keepFound results = [ mdl | Right mdl <- results ]
-- | Find an imported module by name. If a unit is given, the module must
--   also belong to that unit. Only the modules recorded in 'imp_mods' of the
--   type checker's global environment are searched. The first match is
--   returned; if there is none, an error message is returned.
getModule :: Maybe UnitId -> String -> TcPluginM (Either SDoc Module)
getModule pkgKeyToFind mdlNameToFind = do
  gblEnv <- fst <$> getEnvs
  let importedMdls = moduleEnvKeys . imp_mods . tcg_imports $ gblEnv
  return $ maybe notFound Right $ find matches importedMdls
  where
    targetName = mkModuleName mdlNameToFind
    notFound = Left $ text $ "Could not find module '" ++ mdlNameToFind ++ "'"
    -- A missing unit restriction ('Nothing') accepts modules from any unit.
    matches mdl = maybe True (moduleUnitId mdl ==) pkgKeyToFind
               && moduleName mdl == targetName
-- | Find a class that satisfies the given predicate by looking for one of
--   its instances. The search order is: the instances in 'tcg_inst_env',
--   then 'tcg_insts' of the type checker's global environment, and then the
--   global instance environment. The class of the first matching instance
--   is returned; 'Nothing' if there is none.
findClass :: (Class -> Bool) -> TcPluginM (Maybe Class)
findClass isClass' = do
  gblEnv <- fst <$> getEnvs
  globalInsts <- instEnvElts . ie_global <$> getInstEnvs
  let candidates = instEnvElts (tcg_inst_env gblEnv) ++ tcg_insts gblEnv ++ globalInsts
  return $ is_cls <$> find (isClass' . is_cls) candidates
-- | Check that a class is defined in a module accepted by the given
--   predicate and has the given (unqualified) name and arity.
isClass :: Class -> (Module -> Bool) -> String -> Arity -> Bool
isClass cls isModule targetClassName targetArity =
  and [ isModule (nameModule name)
      , occNameString (getOccName name) == targetClassName
      , classArity cls == targetArity
      ]
  where
    name = className cls
-- | Look up a type constructor by its occurrence name in the global reader
--   environment. An entry is considered only if it is imported from one of
--   the given modules and has no parent. Every such entry is resolved with
--   'tcLookup', and then the first one that is a type constructor is returned.
findTyConByNameAndModule :: OccName -> [Module] -> TcPluginM (Maybe TyCon)
findTyConByNameAndModule occName mdls = do
  rdrEnv <- tcg_rdr_env . fst <$> getEnvs
  let candidates =
        [ gre
        | gre <- lookupGlobalRdrEnv rdrEnv occName
        , any (gre `isImportedFrom`) mdls
        , hasNoParent gre
        ]
  things <- mapM (tcLookup . gre_name) candidates
  return $ listToMaybe [ tc | Just tc <- fmap tcTyThingToTyCon things ]
-- | Extract the type constructor from a type checker thing, if the thing is
--   a global type constructor; 'Nothing' otherwise.
tcTyThingToTyCon :: TcTyThing -> Maybe TyCon
tcTyThingToTyCon thing = case thing of
  AGlobal (ATyCon tc) -> Just tc
  _ -> Nothing
-- | Does the given reader environment entry have 'NoParent' as its parent?
hasNoParent :: GlobalRdrElt -> Bool
hasNoParent = isNoParent . gre_par
  where
    isNoParent NoParent = True
    isNoParent _ = False
-- | All instances of the given class, as reported by 'classInstances' for
--   the type checker's current instance environments.
findInstancesInScope :: Class -> TcPluginM [ClsInst]
findInstancesInScope cls = (`classInstances` cls) <$> TcPluginM.getInstEnvs
-- | Report every instance of the given @Bind@ and @Return@ classes that
--   'isPolyTyConInstance' flags. Each reported instance is paired with an
--   error message. @Bind@ instances come first, then @Return@ instances.
checkSupermonadInstances
  :: Class -- ^ The @Bind@ class.
  -> Class -- ^ The @Return@ class.
  -> TcPluginM [(ClsInst, SDoc)]
checkSupermonadInstances bindCls returnCls = do
  polyBindInsts <- filter (isPolyTyConInstance bindCls) <$> findInstancesInScope bindCls
  polyReturnInsts <- filter (isPolyTyConInstance returnCls) <$> findInstancesInScope returnCls
  return [ (inst, invalidInstanceMsg inst) | inst <- polyBindInsts ++ polyReturnInsts ]
  where
    invalidInstanceMsg inst = text "Not a valid supermonad instance: " $$ ppr inst
-- | Pair up the @Bind@ and @Return@ instances of every candidate supermonad
--   type constructor. The candidates are the top-level type constructors of
--   the type arguments of all @Bind@ and @Return@ instances in scope
--   ('instTopTyCons').
--
--   A type constructor is added to the result map only if it has exactly
--   one @Bind@ instance and exactly one @Return@ instance, as selected by
--   'isMonoTyConInstance'. Otherwise one error entry is produced for it.
--   That entry describes every missing or duplicated instance.
findSupermonads
  :: Class -- ^ The @Bind@ class.
  -> Class -- ^ The @Return@ class.
  -> TcPluginM (M.Map TyCon (ClsInst, ClsInst), [(TyCon, SDoc)])
findSupermonads bindCls returnCls = do
  bindInsts <- findInstancesInScope bindCls
  returnInsts <- findInstancesInScope returnCls
  let supermonadTyCons = S.unions $ fmap instTopTyCons $ bindInsts ++ returnInsts
  return $ mconcat
         $ fmap (findSupermonad bindInsts returnInsts)
         $ S.toList supermonadTyCons
  where
    findSupermonad :: [ClsInst] -> [ClsInst] -> TyCon -> (M.Map TyCon (ClsInst, ClsInst), [(TyCon, SDoc)])
    findSupermonad bindInsts returnInsts tc =
      case ( filter (isMonoTyConInstance tc bindCls) bindInsts
           , filter (isMonoTyConInstance tc returnCls) returnInsts ) of
        ([bindInst], [returnInst]) -> (M.singleton tc (bindInst, returnInst), [])
        -- Here at least one side is not a singleton. Each side is reported on
        -- its own. Previously, one Bind instance with several Return instances
        -- was reported as "Multiple 'Bind' instances", and a type constructor
        -- lacking both instances only reported the missing Bind instance.
        (bindInsts', returnInsts') -> findError tc $ vcat $ catMaybes
          [ instanceProblem tc bindClassName bindInsts'
          , instanceProblem tc returnClassName returnInsts'
          ]

    -- Describe what is wrong with the instances found for one class, or
    -- 'Nothing' if there is exactly one instance.
    instanceProblem :: TyCon -> String -> [ClsInst] -> Maybe SDoc
    instanceProblem tc clsNameStr [] = Just $
      text ("Missing '" ++ clsNameStr ++ "' instance for supermonad '") <> ppr tc <> text "'."
    instanceProblem _ _ [_] = Nothing
    instanceProblem tc clsNameStr insts = Just $
      text ("Multiple '" ++ clsNameStr ++ "' instances for supermonad '") <> ppr tc <> text "':"
        $$ vcat (fmap ppr insts)

    findError :: TyCon -> SDoc -> (M.Map TyCon (ClsInst, ClsInst), [(TyCon, SDoc)])
    findError tc msg = (M.empty, [(tc, msg)])
-- | The set of top-level type constructors that 'collectTopTyCons' finds in
--   the type arguments of the given instance.
instTopTyCons :: ClsInst -> S.Set TyCon
instTopTyCons inst = collectTopTyCons (instanceTyArgs inst)