module Control.Supermonad.Plugin.Separation
( separateContraints
, componentTopTyCons
, componentTopTcVars
, componentMonoTyCon
) where
import Data.Maybe ( fromJust )
import qualified Data.Set as S
import Control.Monad ( filterM, liftM2 )
import Data.Graph.Inductive.Graph
( LNode, Edge
, mkGraph, toLEdge )
import Data.Graph.Inductive.PatriciaTree ( Gr )
import Data.Graph.Inductive.Query.DFS ( components )
import TcRnTypes ( Ct )
import TyCon ( TyCon )
import Type ( Type, TyVar )
import TcType ( isAmbiguousTyVar )
import Control.Supermonad.Plugin.Constraint
( WantedCt, constraintClassTyArgs )
import Control.Supermonad.Plugin.Utils
( collectTopTyCons, collectTopTcVars, anyM )
import Control.Supermonad.Plugin.Environment
( SupermonadPluginM )
import Control.Supermonad.Plugin.Environment.Lift
( isReturnConstraint, isBindConstraint )
-- | A node of the constraint graph built by 'separateContraints': an fgl
--   labelled node (node index paired with its label) whose label is a
--   wanted constraint.
type SCNode = LNode WantedCt
-- | Check whether a group of wanted constraints mentions exactly one type
--   constructor.
--
--   Only the return and bind constraints of the group are inspected. The
--   result is @Just tc@ when 'componentTopTyCons' finds exactly one type
--   constructor @tc@ in them and 'componentTopTcVars' finds no type variable
--   that is not ambiguous. In every other case the result is 'Nothing'.
componentMonoTyCon :: [WantedCt] -> SupermonadPluginM (Maybe TyCon)
componentMonoTyCon cts = do
  supermonadCts <- filterM isSupermonadCt cts
  let topTyCons = S.toList (componentTopTyCons supermonadCts)
      fixedVars = S.filter (not . isAmbiguousTyVar) (componentTopTcVars supermonadCts)
  return $ case topTyCons of
    [tc] | S.null fixedVars -> Just tc
    _ -> Nothing
  where
    -- A constraint is relevant if it is a return or a bind constraint.
    isSupermonadCt ct = liftM2 (||) (isReturnConstraint ct) (isBindConstraint ct)
-- | All type constructors that 'collectTopTyCons' finds in the class type
--   arguments of the given constraints, merged into one set.
componentTopTyCons :: [WantedCt] -> S.Set TyCon
componentTopTyCons wanted = collect collectTopTyCons wanted
-- | All type variables that 'collectTopTcVars' finds in the class type
--   arguments of the given constraints, merged into one set.
componentTopTcVars :: [WantedCt] -> S.Set TyVar
componentTopTcVars wanted = collect collectTopTcVars wanted
-- | Apply @f@ to the class type arguments of every constraint and take the
--   union of the resulting sets.
--
--   Constraints for which 'constraintClassTyArgs' returns 'Nothing'
--   contribute the empty set.
collect :: (Ord a) => ([Type] -> S.Set a) -> [Ct] -> S.Set a
collect f cts = S.unions [ maybe S.empty f (constraintClassTyArgs ct) | ct <- cts ]
-- | Split the wanted constraints into groups and keep only the groups that
--   contain at least one bind or return constraint.
--
--   Two constraints are linked when the top-level type variables of their
--   class arguments (as found by 'collectTopTcVars') share an ambiguous type
--   variable. The groups are the connected components of the resulting
--   graph. A constraint for which 'constraintClassTyArgs' yields 'Nothing'
--   is never linked to anything.
separateContraints :: [WantedCt] -> SupermonadPluginM [[WantedCt]]
separateContraints wantedCts = filterM containsBindOrReturn comps
  where
    containsBindOrReturn :: [WantedCt] -> SupermonadPluginM Bool
    containsBindOrReturn = anyM $ \ct -> liftM2 (||) (isBindConstraint ct) (isReturnConstraint ct)

    -- Components mapped back to their constraints. Every node produced by
    -- 'components' comes from 'nodes', so this lookup cannot fail.
    comps :: [[WantedCt]]
    comps = fmap (\n -> fromJust $ lookup n nodes) <$> components g

    g :: Gr WantedCt ()
    g = mkGraph nodes (fmap (\e -> toLEdge e ()) edges)

    nodes :: [SCNode]
    nodes = zip [0..] wantedCts

    -- Ambiguous top-level type variables of a single constraint.
    ambiguousTopVars :: WantedCt -> S.Set TyVar
    ambiguousTopVars ct =
      maybe S.empty (S.filter isAmbiguousTyVar . collectTopTcVars) (constraintClassTyArgs ct)

    -- Compute each constraint's variable set once, instead of once per
    -- ordered pair with two linear 'lookup's (which made the original O(n^3)).
    nodeVars :: [(Int, S.Set TyVar)]
    nodeVars = [ (n, ambiguousTopVars ct) | (n, ct) <- nodes ]

    -- For each node @n@ and every later node @m@ linked to it, emit first all
    -- @(m, n)@ and then all @(n, m)@. This is exactly the edge list (and
    -- order) produced by testing every ordered pair. Because sharing a
    -- variable is symmetric, each unordered pair is tested only once.
    edges :: [Edge]
    edges = edgesFrom nodeVars

    edgesFrom :: [(Int, S.Set TyVar)] -> [Edge]
    edgesFrom [] = []
    edgesFrom ((n, vs) : rest) =
      let linked = [ m | (m, ws) <- rest, shareVar vs ws ]
      in fmap (\m -> (m, n)) linked ++ fmap (\m -> (n, m)) linked ++ edgesFrom rest

    shareVar :: S.Set TyVar -> S.Set TyVar -> Bool
    shareVar a b = not $ S.null $ S.intersection a b