{-# LANGUAGE GeneralizedNewtypeDeriving #-}
module Polysemy.Plugin.Fundep (fundepPlugin) where
import Class
import CoAxiom
import Control.Monad
import Data.Bifunctor
import Data.Function (on)
import Data.IORef
import Data.List
import Data.Maybe
import qualified Data.Set as S
import FastString (fsLit)
import GHC (ModuleName)
import GHC.TcPluginM.Extra (lookupModule, lookupName)
import Module (mkModuleName)
import OccName (mkTcOcc)
import TcPluginM (TcPluginM, tcLookupClass, tcPluginIO)
import TcRnTypes
import TcSMonad hiding (tcLookupClass)
import TyCoRep (Type (..))
import Type
-- | Name of the polysemy module in which 'fundepPlugin' looks up the
-- @Find@ class at initialisation time.
polysemyInternalUnion :: ModuleName
polysemyInternalUnion = mkModuleName $ "Polysemy.Internal.Union"
-- | The type-checker plugin.
--
-- Initialisation resolves the @Find@ class from "Polysemy.Internal.Union"
-- (package @polysemy@) and allocates an empty set that 'solveFundep' uses to
-- remember which equalities it has already emitted. Stopping does nothing.
fundepPlugin :: TcPlugin
fundepPlugin = TcPlugin
  { tcPluginInit = do
      md <- lookupModule polysemyInternalUnion (fsLit "polysemy")
      findTcNm <- lookupName md (mkTcOcc "Find")
      -- The memo set is created before the class lookup, as in the
      -- applicative original.
      emitted <- tcPluginIO (newIORef S.empty)
      findCls <- tcLookupClass findTcNm
      pure (emitted, findCls)
  , tcPluginSolve = solveFundep
  , tcPluginStop = \_ -> pure ()
  }
-- | Select the canonical dictionary constraints for class @cls@ that have
-- exactly three type arguments.
--
-- For each one this returns its location together with
-- @(effect name, effect, row)@. The effect is the third argument and the
-- row is the second; the first argument is ignored. The effect name is
-- computed with 'getEffName'. All other constraints are dropped.
allMonadEffectConstraints :: Class -> [Ct] -> [(CtLoc, (Type, Type, Type))]
allMonadEffectConstraints cls = mapMaybe fromDict
  where
    fromDict ct@CDictCan{cc_class = cls', cc_tyargs = [_, r, eff]}
      | cls == cls' = Just (ctLoc ct, (getEffName eff, eff, r))
    fromDict _ = Nothing
-- | 'Just' the sole element of a one-element list; 'Nothing' for an empty
-- list or a list with two or more elements.
singleListToJust :: [a] -> Maybe a
singleListToJust xs = case xs of
  [x] -> Just x
  _ -> Nothing
-- | Look for the given effect that corresponds to a wanted one.
--
-- The first argument is a wanted @(effect name, effect, row)@ triple and the
-- second is the list of given triples. Returns the given's effect when
-- exactly one given has both the same effect name and the same row
-- (compared with 'eqType'). Returns 'Nothing' when there are zero matches or
-- when the match is ambiguous.
findMatchingEffectIfSingular :: (Type, Type, Type) -> [(Type, Type, Type)] -> Maybe Type
findMatchingEffectIfSingular (effName, _, mon) ts =
  singleListToJust (mapMaybe sameNameAndRow ts)
  where
    sameNameAndRow (effName', eff', mon')
      | eqType effName effName' && eqType mon mon' = Just eff'
      | otherwise = Nothing
-- | The head of an effect type: whatever 'splitAppTys' leaves once it has
-- stripped off the type arguments.
getEffName :: Type -> Type
getEffName = fst . splitAppTys
-- | Conservative check of whether @wanted@ could unify with @given@.
--
-- Both types are split with 'splitAppTys'. The check succeeds only when
-- all of the following hold:
--
-- * the two heads are equal ('eqType');
-- * both sides are applied to the same number of arguments. 'zip' alone
--   would silently ignore surplus arguments, and types of different arity
--   can never unify;
-- * every pair of arguments either has a type variable on the /wanted/
--   side, or is equal, or recursively passes this same check.
--
-- A type variable in /head/ position is never treated as a wildcard; it
-- must be 'eqType'-equal to the other head.
canUnify :: Type -> Type -> Bool
canUnify wanted given =
  eqType w g && length ws == length gs && and (zipWith argOk ws gs)
  where
    (w, ws) = splitAppTys wanted
    (g, gs) = splitAppTys given
    argOk wt gt = isTyVarTy wt || eqType wt gt || canUnify wt gt
-- | Emit a new wanted nominal equality @wanted ~ given@ at @loc@.
--
-- When @must_unify@ is 'True', nothing is emitted (the result is 'Nothing')
-- unless 'canUnify' accepts the pair. Otherwise the result pairs the
-- equality's two sides (used by the caller to avoid emitting the same
-- equality twice) with the new constraint wrapped in 'CNonCanonical'.
mkWanted
  :: Bool
  -> CtLoc
  -> Type
  -> Type
  -> TcPluginM (Maybe ((OrdType, OrdType), Ct))
mkWanted must_unify loc wanted given
  | must_unify && not (canUnify wanted given) = pure Nothing
  | otherwise = do
      -- The TcS action is run inside TcPluginM through the TcM escape hatch;
      -- the coercion it also returns is not needed here.
      (ev, _) <- unsafeTcPluginTcM . runTcSDeriveds $
        newWantedEq loc Nominal wanted given
      pure $ Just ((OrdType wanted, OrdType given), CNonCanonical ev)
-- | Last component of a triple.
thd :: (a, b, c) -> c
thd (_, _, z) = z
-- | Count how many times each element occurs in a list, using @eq@ as the
-- equality test.
--
-- Returns one @(representative, count)@ pair per equivalence class, in order
-- of first occurrence; the representative is that first occurrence.
--
-- Unlike a 'groupBy'-based count, equal elements need not be adjacent:
-- @[a, b, a]@ gives @a@ a count of 2, not two separate runs of 1.
-- 'solveFundep' relies on this total count to decide whether a row is
-- mentioned by only one wanted constraint.
--
-- Assumes @eq@ is an equivalence relation. This is quadratic in the worst
-- case, because only an equality test (no ordering) is available.
countLength :: (a -> a -> Bool) -> [a] -> [(a, Int)]
countLength _ [] = []
countLength eq (x : xs) =
  let (same, rest) = partition (eq x) xs
   in (x, 1 + length same) : countLength eq rest
-- | Wrapper that gives GHC 'Type's the 'Eq' and 'Ord' instances needed to
-- store them in a 'S.Set'.
newtype OrdType = OrdType { getOrdType :: Type }
-- | Equality of the wrapped types as decided by 'eqType'.
instance Eq OrdType where
  OrdType a == OrdType b = eqType a b
-- | Ordering of the wrapped types as decided by 'nonDetCmpType'.
--
-- NOTE(review): judging by its name, this ordering is presumably not stable
-- across compilations. It is only used here to key an in-memory 'S.Set'.
instance Ord OrdType where
  compare (OrdType a) (OrdType b) = nonDetCmpType a b
-- | Solver entry point.
--
-- For every wanted @Find@ constraint, with row @r@ and effect @eff@, try to
-- emit a new wanted equality that pins down @eff@:
--
-- * If exactly one given @Find@ constraint has the same effect name and the
--   same row, emit @eff ~ eff'@ for that given's effect. This is only done
--   when 'canUnify' accepts the pair.
-- * Otherwise, if @r@ splits into a head applied to exactly three arguments,
--   emit @eff ~ eff'@ where @eff'@ is the middle argument. NOTE(review):
--   presumably @r@ is a promoted cons @eff' ': rest@ whose first argument is
--   the kind; confirm. The 'canUnify' check is skipped only when @r@ is the
--   row of a single wanted constraint in this batch.
--
-- Equalities already emitted by an earlier call are recorded in the 'IORef'
-- set and are not emitted again. No constraint is reported as solved.
solveFundep
  :: (IORef (S.Set (OrdType, OrdType)), Class)
  -> [Ct]
  -> [Ct]
  -> [Ct]
  -> TcPluginM TcPluginResult
solveFundep _ _ _ [] = pure $ TcPluginOk [] []
solveFundep (ref, effCls) giv _ want = do
  let wantedEffs = allMonadEffectConstraints effCls want
      givenEffs = snd <$> allMonadEffectConstraints effCls giv
      num_wanteds_by_r = countLength eqType $ fmap (thd . snd) wantedEffs
      -- Only a row mentioned by exactly one wanted may be unified without
      -- the 'canUnify' check. Every @r@ queried here comes from
      -- @wantedEffs@, so the lookup should always succeed. If it ever does
      -- not, fall back to the conservative answer instead of crashing on a
      -- partial pattern match.
      must_unify r =
        maybe True ((/= 1) . snd) $ find (eqType r . fst) num_wanteds_by_r
  eqs <- forM wantedEffs $ \(loc, e@(_, eff, r)) ->
    case findMatchingEffectIfSingular e givenEffs of
      Just eff' -> mkWanted True loc eff eff'
      Nothing ->
        case splitAppTys r of
          (_, [_, eff', _]) -> mkWanted (must_unify r) loc eff eff'
          _ -> pure Nothing
  already_emitted <- tcPluginIO $ readIORef ref
  let new_wanteds =
        filter (not . flip S.member already_emitted . fst) $ catMaybes eqs
  -- Strict update so the memo set does not accumulate a chain of unevaluated
  -- unions across solver calls.
  tcPluginIO $ modifyIORef' ref $ S.union $ S.fromList $ fmap fst new_wanteds
  pure . TcPluginOk [] $ fmap snd new_wanteds