{-# LANGUAGE CPP #-}
{-# LANGUAGE NoMonomorphismRestriction #-}
module Polysemy.Plugin
( plugin
) where
import Class
import CoAxiom
import Control.Monad
import CoreMonad
import Data.Maybe
import DynFlags
import FastString (fsLit)
import GHC (ModuleName, moduleName)
import GHC.TcPluginM.Extra (lookupModule, lookupName)
import Module (mkModuleName, moduleSetElts)
import OccName (mkTcOcc)
import Outputable
import TcPluginM (TcPluginM, tcLookupClass)
import TcRnTypes
import TcSMonad hiding (tcLookupClass)
import TyCoRep (Type (..))
import Type
import Plugins (Plugin (..), defaultPlugin
#if __GLASGOW_HASKELL__ >= 806
, PluginRecompile(..)
#endif
)
-- | The 'Plugin' value exported by this module.
--
-- It registers 'fundepPlugin' as a type-checker plugin. On GHC 8.10 and later
-- it also hooks 'installTodos' into the Core-to-Core pipeline. On GHC 8.6 and
-- later it reports 'NoForceRecompile', so using the plugin does not by itself
-- force recompilation. Any command-line options passed to the plugin are
-- ignored.
plugin :: Plugin
plugin = defaultPlugin
  { tcPlugin = \_options -> Just fundepPlugin
#if __GLASGOW_HASKELL__ >= 810
  , installCoreToDos = \_options -> installTodos
#endif
#if __GLASGOW_HASKELL__ >= 806
  , pluginRecompile = \_options -> pure NoForceRecompile
#endif
  }
-- | Name of the "Polysemy.Internal" module. 'installTodos' looks for it among
-- the visible orphan modules.
polysemyInternal :: ModuleName
polysemyInternal =
  mkModuleName "Polysemy.Internal"
-- | Name of the "Polysemy.Internal.Union" module, where 'fundepPlugin' looks
-- up the @Find@ class.
polysemyInternalUnion :: ModuleName
polysemyInternalUnion =
  mkModuleName "Polysemy.Internal.Union"
-- | Core-to-Core hook. It currently returns the pipeline unchanged on every
-- path.
--
-- At optimisation level 2 it checks whether "Polysemy.Internal" is among the
-- visible orphan modules. That check is the placeholder trigger for extra
-- passes that have not been written yet (see the TODO).
installTodos :: [CoreToDo] -> CoreM [CoreToDo]
installTodos todos = do
  level <- optLevel <$> getDynFlags
  if level /= 2
    then pure todos
    else do
      orphans <- moduleSetElts <$> getVisibleOrphanMods
      let importsPolysemy = any ((== polysemyInternal) . moduleName) orphans
      pure $
        if importsPolysemy
          then todos -- TODO(sandy): install extra passes
          else todos
-- | The type-checker plugin.
--
-- Initialisation resolves the @Find@ class from "Polysemy.Internal.Union" in
-- the @polysemy@ package. GHC passes that 'Class' to 'solveFundep' each time
-- it invokes the plugin's solver. Stopping the plugin does nothing.
fundepPlugin :: TcPlugin
fundepPlugin = TcPlugin
  { tcPluginInit  = lookupFindClass
  , tcPluginSolve = solveFundep
  , tcPluginStop  = \_ -> pure ()
  }
  where
    -- Locate the module, then the @Find@ type-class name inside it, then the class.
    lookupFindClass = do
      unionModule <- lookupModule polysemyInternalUnion (fsLit "polysemy")
      findName <- lookupName unionModule (mkTcOcc "Find")
      tcLookupClass findName
-- | Select every canonical dictionary constraint for the given class whose
-- type-argument list has exactly three entries @[_, r, eff]@. The first
-- argument is ignored.
--
-- For each match it returns the constraint's location together with
-- @(getEffName eff, eff, r)@. Input order is preserved, and constraints of any
-- other shape are dropped.
allMonadEffectConstraints :: Class -> [Ct] -> [(CtLoc, (Type, Type, Type))]
allMonadEffectConstraints cls = mapMaybe fromFindDict
  where
    fromFindDict ct@CDictCan{cc_class = ctCls, cc_tyargs = [_, r, eff]}
      | ctCls == cls = Just (ctLoc ct, (getEffName eff, eff, r))
    -- Wrong class, wrong arity, or not a CDictCan: skip it.
    fromFindDict _ = Nothing
-- | @Just x@ when the list is exactly @[x]@. Returns 'Nothing' for the empty
-- list and for any list with two or more elements.
singleListToJust :: [a] -> Maybe a
singleListToJust xs = case xs of
  [x] -> Just x
  _   -> Nothing
-- | Given a wanted @(effect head, effect, row)@ triple and a list of given
-- triples, keep the givens whose effect head and row are both 'eqType'-equal
-- to the wanted one's. Return that given's effect only when exactly one given
-- matches; zero or several matches give 'Nothing'.
findMatchingEffectIfSingular :: (Type, Type, Type) -> [(Type, Type, Type)] -> Maybe Type
findMatchingEffectIfSingular (effName, _, row) =
  singleListToJust . map effectOf . filter sameHeadAndRow
  where
    sameHeadAndRow (effName', _, row') = eqType effName effName' && eqType row row'
    effectOf (_, eff', _) = eff'
-- | The head of an effect type once 'splitAppTys' has peeled off all of its
-- type applications. The arguments are discarded.
getEffName :: Type -> Type
getEffName = fst . splitAppTys
-- Unused: 'canUnify' calls GHC's 'isTyVarTy' instead of this helper.
-- isTyVar :: Type -> Bool
-- isTyVar = isJust . getTyVar_maybe
-- | Heuristic gate used by 'mkWanted' before emitting @wanted ~ given@.
--
-- Both types are split with 'splitAppTys', and the heads must be 'eqType'-equal.
-- The argument lists are then compared pairwise. Because they are zipped, any
-- surplus arguments on the longer side are not examined. The only rejected
-- pair is one where the given's argument is a type variable but the wanted's
-- argument is not.
canUnify :: Type -> Type -> Bool
canUnify wanted given =
  and (zipWith argsCompatible wantedArgs givenArgs) && eqType wantedHead givenHead
  where
    (wantedHead, wantedArgs) = splitAppTys wanted
    (givenHead, givenArgs) = splitAppTys given
    -- Equivalent to: if the given side is a tyvar, the wanted side must be one too.
    argsCompatible wt gt = not (isTyVarTy gt) || isTyVarTy wt
-- | Create a new wanted nominal equality @wanted ~ given@ at @loc@, wrapped as
-- a 'CNonCanonical' constraint.
--
-- When @mustUnify@ is 'True', the constraint is created only if 'canUnify'
-- accepts the pair; otherwise the result is 'Nothing'. When @mustUnify@ is
-- 'False', the constraint is always created.
mkWanted :: Bool -> CtLoc -> Type -> Type -> TcPluginM (Maybe Ct)
mkWanted mustUnify loc wanted given
  | mustUnify && not (canUnify wanted given) = pure Nothing
  | otherwise = do
      -- newWantedEq is a TcS action: run it with runTcSDeriveds to get a TcM
      -- action, then lift that into TcPluginM. The coercion is not used.
      (evidence, _coercion) <-
        unsafeTcPluginTcM (runTcSDeriveds (newWantedEq loc Nominal wanted given))
      pure (Just (CNonCanonical evidence))
-- | Solver step run by 'fundepPlugin'.
--
-- For each wanted constraint of the class (as selected by
-- 'allMonadEffectConstraints'):
--
-- * If exactly one given has the same effect head and row, emit
--   @wanted effect ~ given effect@, but only if 'canUnify' accepts it.
-- * Otherwise, if the row splits into a three-argument application,
--   unconditionally emit @wanted effect ~ middle argument@. Presumably this
--   is a promoted cons @(e ': rest)@ whose middle argument is the row's
--   head — TODO confirm.
--
-- It never marks anything as solved; it only returns new wanted equalities.
-- The derived constraints (third argument) are ignored.
solveFundep :: Class -> [Ct] -> [Ct] -> [Ct] -> TcPluginM TcPluginResult
solveFundep effCls giv _ want =
  TcPluginOk [] . catMaybes <$> traverse improve wantedEffs
  where
    wantedEffs = allMonadEffectConstraints effCls want
    givenEffs = map snd (allMonadEffectConstraints effCls giv)

    improve (loc, e@(_, eff, r)) =
      case findMatchingEffectIfSingular e givenEffs of
        Just eff' -> mkWanted True loc eff eff'
        Nothing -> case splitAppTys r of
          (_, [_, eff', _]) -> mkWanted False loc eff eff'
          _ -> pure Nothing