{-# LANGUAGE GeneralizedNewtypeDeriving #-}
module Polysemy.Plugin.Fundep (fundepPlugin) where
import Class
import CoAxiom
import Control.Applicative
import Control.Monad
import Data.Bifunctor
import Data.Bool
import Data.Function (on)
import Data.IORef
import Data.List
import Data.Maybe
import qualified Data.Set as S
import FastString (fsLit)
import GHC (ModuleName)
import GHC.TcPluginM.Extra (lookupModule, lookupName)
import Module (mkModuleName)
import OccName (mkTcOcc)
import TcPluginM (TcPluginM, tcLookupClass, tcPluginIO)
import TcRnTypes
import TcSMonad hiding (tcLookupClass)
import TyCoRep (Type (..))
import Type
-- | Name of the polysemy module in which the plugin looks up the @Find@
-- type constructor during initialisation.
polysemyInternalUnion :: ModuleName
polysemyInternalUnion = mkModuleName unionModule
  where
    unionModule = "Polysemy.Internal.Union"
-- | The type-checker plugin exported by this module.
--
-- Initialisation looks up the @Find@ class in "Polysemy.Internal.Union"
-- (package @polysemy@) and allocates an empty set recording the equalities
-- already emitted. Solving is delegated to 'solveFundep'; stopping does
-- nothing.
fundepPlugin :: TcPlugin
fundepPlugin = TcPlugin
  { tcPluginInit  = initState
  , tcPluginSolve = solveFundep
  , tcPluginStop  = \_ -> pure ()
  }
  where
    -- The IORef is created before the class lookup, as in the original
    -- applicative chain.
    initState = do
      unionMod <- lookupModule polysemyInternalUnion (fsLit "polysemy")
      findName <- lookupName unionMod (mkTcOcc "Find")
      emitted  <- tcPluginIO (newIORef S.empty)
      findCls  <- tcLookupClass findName
      pure (emitted, findCls)
-- | Collect every canonical dictionary constraint of class @cls@ that has
-- exactly three type arguments.
--
-- For each such constraint, return its location and a triple of
-- (effect head, effect, row). The effect head is computed by 'getEffName'.
-- The first type argument is ignored; it is presumably the kind argument of
-- @Find@ (TODO confirm). All other constraints are dropped.
allMonadEffectConstraints :: Class -> [Ct] -> [(CtLoc, (Type, Type, Type))]
allMonadEffectConstraints cls cts = mapMaybe fromFind cts
  where
    fromFind ct@CDictCan{cc_class = cls', cc_tyargs = [_, row, effect]}
      | cls == cls' = Just (ctLoc ct, (getEffName effect, effect, row))
    fromFind _ = Nothing
-- | 'Just' the element of a singleton list; 'Nothing' for an empty list or a
-- list with two or more elements.
singleListToJust :: [a] -> Maybe a
singleListToJust xs = case xs of
  [only] -> Just only
  _      -> Nothing
-- | Given a wanted (effect head, effect, row) triple, look for given triples
-- with the same effect head and the same row (compared with 'eqType').
--
-- If exactly one such given exists, return its effect. If there are none,
-- or several, return 'Nothing' (the choice would be ambiguous).
findMatchingEffectIfSingular :: (Type, Type, Type) -> [(Type, Type, Type)] -> Maybe Type
findMatchingEffectIfSingular (wantedName, _, wantedRow) givens =
  singleListToJust (map effectOf (filter sameNameAndRow givens))
  where
    sameNameAndRow (givenName, _, givenRow) =
      eqType wantedName givenName && eqType wantedRow givenRow
    effectOf (_, givenEff, _) = givenEff
-- | The head of a type application, with all arguments stripped, as
-- returned by 'splitAppTys'. This is used as the "name" of an effect.
getEffName :: Type -> Type
getEffName = fst . splitAppTys
-- | Heuristically decide whether a wanted effect type could unify with a
-- given/candidate effect type.
--
-- Both types are split into head and arguments with 'splitAppTys':
--
-- * At the top level, the heads must be equal ('eqType').
-- * Below the top level, the heads only need to satisfy 'canUnify'.
-- * Each pair of arguments must either satisfy 'canUnify' or match
--   recursively by the same rule.
--
-- A type variable on the given side is only accepted in an 'InterpreterUse'
-- context; in a 'FunctionDef' context it is not.
canUnifyRecursive :: SolveContext -> Type -> Type -> Bool
canUnifyRecursive solve_ctx = matches True
  where
    polyGivenOk :: Bool
    polyGivenOk = case solve_ctx of
      FunctionDef      -> False
      InterpreterUse _ -> True

    matches :: Bool -> Type -> Type -> Bool
    matches atTop wanted given = argsMatch && headsMatch
      where
        (wantedHead, wantedArgs) = splitAppTys wanted
        (givenHead, givenArgs)   = splitAppTys given
        headsMatch
          | atTop     = eqType wantedHead givenHead
          | otherwise = canUnify polyGivenOk wantedHead givenHead
        -- NOTE(review): 'zip' pairs arguments from the left and silently
        -- drops extras when the argument counts differ. This is only exact
        -- when both sides have the same arity; confirm that this is acceptable.
        argsMatch = and
          [ canUnify polyGivenOk wArg gArg || matches False wArg gArg
          | (wArg, gArg) <- zip wantedArgs givenArgs
          ]
-- | Shallow unifiability test between a wanted type @wt@ and a given type
-- @gt@. It succeeds in any of these cases:
--
-- * the wanted type is a type variable;
-- * the given type is a type variable and @poly_given_ok@ is set;
-- * the two types are equal ('eqType').
canUnify :: Bool -> Type -> Type -> Bool
canUnify poly_given_ok wt gt
  | isTyVarTy wt                 = True
  | poly_given_ok, isTyVarTy gt  = True
  | otherwise                    = eqType wt gt
-- | Like 'Control.Monad.when', but collects the result in an 'Alternative'.
--
-- When the condition holds, run the action and wrap its result with 'pure'.
-- Otherwise, skip the action entirely and return 'empty'.
whenA
  :: (Monad m, Alternative z)
  => Bool
  -> m a
  -> m (z a)
whenA cond action
  | cond      = pure <$> action
  | otherwise = pure empty
-- | Try to create a new wanted nominal equality @wanted ~ given@ at @loc@.
--
-- When the context requires it ('mustUnify'), the equality is only created
-- if 'canUnifyRecursive' accepts the pair; otherwise the result is
-- 'Nothing'. On success, the result pairs the (wanted, given) key used for
-- de-duplication with the non-canonical constraint carrying the new
-- evidence.
mkWanted
  :: SolveContext
  -> CtLoc
  -> Type
  -> Type
  -> TcPluginM (Maybe ( (OrdType, OrdType)
                      , Ct
                      ))
mkWanted solve_ctx loc wanted given = whenA shouldEmit emitEquality
  where
    shouldEmit =
      not (mustUnify solve_ctx) || canUnifyRecursive solve_ctx wanted given
    emitEquality = do
      -- The coercion returned alongside the evidence is not needed here.
      (ev, _) <- unsafeTcPluginTcM
        (runTcSDeriveds (newWantedEq loc Nominal wanted given))
      pure ((OrdType wanted, OrdType given), CNonCanonical ev)
-- | The third component of a triple.
thd :: (a, b, c) -> c
thd (_, _, third) = third
-- | For each distinct element (under @eq@), count how many elements of the
-- list are equivalent to it, wherever they appear.
--
-- Each result pair holds the first element of an equivalence class and the
-- size of that class. Pairs are ordered by first occurrence, for example
-- @countLength (==) [1, 2, 1] == [(1, 2), (2, 1)]@.
--
-- Unlike 'groupBy', which only merges /adjacent/ runs, this counts every
-- occurrence, so it is a correct histogram for unsorted input. The solver
-- relies on that when it asks how many wanteds share a row. @eq@ is assumed
-- to be an equivalence relation.
--
-- Only an equality test is available (no 'Ord'), so this costs O(n * k)
-- for @n@ elements in @k@ classes.
countLength :: (a -> a -> Bool) -> [a] -> [(a, Int)]
countLength _ [] = []
countLength eq (x : xs) =
  let (same, rest) = partition (eq x) xs
   in (x, 1 + length same) : countLength eq rest
-- | Wrapper that gives GHC's 'Type' 'Eq' and 'Ord' instances, so that
-- (wanted, given) type pairs can be kept in a 'S.Set'.
--
-- Equality uses 'eqType'. Ordering uses 'nonDetCmpType'; as its name says,
-- that ordering is not deterministic. It is fine for set membership, but
-- should not be relied on for anything whose output order matters.
newtype OrdType = OrdType { getOrdType :: Type }

instance Eq OrdType where
  OrdType lhs == OrdType rhs = eqType lhs rhs

instance Ord OrdType where
  compare (OrdType lhs) (OrdType rhs) = nonDetCmpType lhs rhs
-- | How the solver arrived at the candidate effect for a wanted @Find@
-- constraint.
data SolveContext
  = -- | A single given @Find@ constraint with a matching effect head and row
    -- supplied the candidate.
    FunctionDef
  | -- | No such given; the candidate was taken from the row type itself. The
    -- flag says whether unifiability must be checked before an equality is
    -- emitted.
    InterpreterUse Bool
  deriving (Eq, Ord, Show)
-- | Whether an equality may only be emitted after 'canUnifyRecursive'
-- accepts it.
--
-- This is always the case for 'FunctionDef'. For 'InterpreterUse', it
-- depends on the carried flag.
mustUnify :: SolveContext -> Bool
mustUnify ctx = case ctx of
  FunctionDef           -> True
  InterpreterUse forced -> forced
-- | The plugin's solver.
--
-- For every wanted @Find@ constraint (as collected by
-- 'allMonadEffectConstraints'), try to emit a new wanted equality between the
-- wanted effect and a candidate effect:
--
--   * If exactly one given @Find@ constraint has the same effect head and
--     row, its effect is the candidate ('FunctionDef' context).
--   * Otherwise, if the row splits into a head applied to exactly three
--     arguments, the second argument is the candidate ('InterpreterUse'
--     context). The row is presumably a promoted cons @e ': r'@ with its kind
--     argument; TODO confirm. In this context, a unifiability check is
--     required when more than one wanted shares that row.
--
-- Pairs already recorded in the 'IORef' are skipped, and so are duplicates
-- within the current batch. Each (wanted, candidate) equality is therefore
-- emitted at most once over the plugin's lifetime.
solveFundep
  :: (IORef (S.Set (OrdType, OrdType)), Class)
  -> [Ct]
  -> [Ct]
  -> [Ct]
  -> TcPluginM TcPluginResult
solveFundep _ _ _ [] = pure $ TcPluginOk [] []
solveFundep (ref, effCls) giv _ want = do
  let wantedEffs = allMonadEffectConstraints effCls want
      givenEffs  = snd <$> allMonadEffectConstraints effCls giv
      num_wanteds_by_r = countLength eqType $ fmap (thd . snd) wantedEffs
      -- Every @r@ looked up here comes from 'wantedEffs', so the lookup
      -- should always succeed. If it ever fails, fall back to the
      -- conservative choice (require a unifiability check) instead of
      -- crashing the type checker with a partial pattern.
      must_unify r =
        maybe True ((/= 1) . snd) $ find (eqType r . fst) num_wanteds_by_r
  eqs <- forM wantedEffs $ \(loc, e@(_, eff, r)) ->
    case findMatchingEffectIfSingular e givenEffs of
      Just eff' -> mkWanted FunctionDef loc eff eff'
      Nothing ->
        case splitAppTys r of
          (_, [_, eff', _]) ->
            mkWanted (InterpreterUse $ must_unify r) loc eff eff'
          _ -> pure Nothing
  already_emitted <- tcPluginIO $ readIORef ref
  let new_wanteds = dropSeen already_emitted $ catMaybes eqs
  -- Strict modify, so the set is not accumulated as a chain of thunks.
  tcPluginIO $ modifyIORef' ref $ S.union $ S.fromList $ fmap fst new_wanteds
  pure . TcPluginOk [] $ fmap snd new_wanteds
  where
    -- Keep only entries whose key is not in @seen@, and add each kept key to
    -- @seen@ so later duplicates in the same batch are also dropped.
    dropSeen _ [] = []
    dropSeen seen (w@(key, _) : ws)
      | key `S.member` seen = dropSeen seen ws
      | otherwise           = w : dropSeen (S.insert key seen) ws