module Control.Supermonad.Plugin.Solving
( solveConstraints
) where
import Data.Maybe
( catMaybes
, isJust, isNothing
, fromJust )
import Data.List ( partition, nubBy )
import qualified Data.Set as S
import Control.Monad ( forM, forM_, filterM, liftM2 )
import TcRnTypes ( Ct(..) )
import TyCon ( TyCon )
import Class ( classTyCon )
import Type
( Type, TyVar
, substTyVar, substTys
, eqType )
import TcType ( isAmbiguousTyVar )
import InstEnv ( ClsInst, instanceHead )
import Unify ( tcUnifyTy )
import qualified Outputable as O
import Control.Supermonad.Plugin.Debug ( sDocToStr )
import Control.Supermonad.Plugin.Wrapper
( TypeVarSubst, mkTypeVarSubst )
import Control.Supermonad.Plugin.Environment
( SupermonadPluginM
, getGivenConstraints, getWantedConstraints
, getReturnClass, getBindClass
, getSupermonadFor
, addTyVarEquality, addTyVarEqualities
, addTypeEqualities
, getTyVarEqualities
, printMsg, printObj, printErr
, addWarning, displayWarnings
, throwPluginError, throwPluginErrorSDoc )
import Control.Supermonad.Plugin.Environment.Lift
( isPotentiallyInstantiatedCt
, isBindConstraint, isReturnConstraint
, partiallyApplyTyCons )
import Control.Supermonad.Plugin.Constraint
( WantedCt
, isClassConstraint
, constraintClassType
, constraintClassTyArgs )
import Control.Supermonad.Plugin.Separation
( separateContraints
, componentTopTyCons, componentTopTcVars
, componentMonoTyCon )
import Control.Supermonad.Plugin.Utils
( collectTopTcVars
, collectTopTyCons
, collectTyVars
, associations
, allM )
import Control.Supermonad.Plugin.Log
( formatConstraint )
-- | Entry point of the solver for the given wanted constraints.
--
--   The constraints are separated into groups ('separateContraints').
--
--   * Groups for which 'componentMonoTyCon' yields a type constructor are
--     solved with 'solveMonoConstraintGroup'.
--   * All other groups are solved with 'solvePolyConstraintGroup'.
--
--   All mono groups are processed before the poly groups. Each kind keeps
--   its original order. Afterwards 'solveSolvedTyConIndices' runs and the
--   collected warnings are displayed.
solveConstraints :: [WantedCt] -> SupermonadPluginM ()
solveConstraints wantedCts = do
  ctGroups <- separateContraints wantedCts
  markedCtGroups <- forM ctGroups $ \g -> do
    mMonoTyCon <- componentMonoTyCon g
    return (mMonoTyCon, g)
  -- Split by pattern matching on the Maybe instead of isJust + fromJust,
  -- so no partial function is involved.
  let monoGroups = [ (tc, g) | (Just tc, g) <- markedCtGroups ]
      polyGroups = [ g | (Nothing, g) <- markedCtGroups ]
  forM_ monoGroups solveMonoConstraintGroup
  forM_ polyGroups solvePolyConstraintGroup
  solveSolvedTyConIndices
  displayWarnings
-- | Solve a constraint group paired with the single type constructor
--   'tyCon' that its ambiguous variables are to be instantiated with.
--
--   For every 'Return' or 'Bind' constraint in the group, each ambiguous
--   top-level type constructor variable of its class arguments is
--   associated with 'tyCon' via 'partiallyApplyTyCons'. The resulting
--   equality is recorded with 'addTyVarEquality'.
--
--   If the de-duplicated result does not contain exactly one entry, a
--   plugin error is thrown. An empty group is skipped.
solveMonoConstraintGroup :: (TyCon, [WantedCt]) -> SupermonadPluginM ()
solveMonoConstraintGroup (_, []) = return ()
solveMonoConstraintGroup (tyCon, ctGroup) = do
  printMsg "Solve mono group..."
  returnOrBindCts <- filterM isReturnOrBind ctGroup
  forM_ returnOrBindCts $ \ct ->
    forM_ (ambiguousTopVars ct) $ \tyVar -> do
      applied <- either throwPluginErrorSDoc return
                   =<< partiallyApplyTyCons [(tyVar, Left tyCon)]
      recordSingle ct (nubBy sameApplication applied)
  where
    -- Return constraints are checked first, then bind constraints.
    isReturnOrBind :: WantedCt -> SupermonadPluginM Bool
    isReturnOrBind ct = liftM2 (||) (isReturnConstraint ct) (isBindConstraint ct)

    -- Ambiguous type variables at the top level of the constraint's class
    -- arguments. A constraint without class arguments contributes none.
    ambiguousTopVars :: WantedCt -> S.Set TyVar
    ambiguousTopVars ct =
      S.filter isAmbiguousTyVar (collectTopTcVars (maybe [] id (constraintClassTyArgs ct)))

    -- Exactly one distinct application is expected; anything else is an
    -- internal error.
    recordSingle ct distinctApps = case distinctApps of
      [(tv, ty, _)] -> addTyVarEquality ct tv ty
      [] -> throwPluginError "How did this become an empty list?"
      _ -> throwPluginError "How did this become a list with more then one element?"

    -- Structural equality of partial applications (types compared with 'eqType').
    sameApplication :: (TyVar, Type, [TyVar]) -> (TyVar, Type, [TyVar]) -> Bool
    sameApplication (tvA, tyA, varsA) (tvB, tyB, varsB) =
      tvA == tvB && varsA == varsB && eqType tyA tyB
-- | Solve a constraint group whose ambiguous type constructor variables are
--   not fixed to a single type constructor.
--
--   All valid associations for the group are computed with
--   'determineValidConstraintGroupAssocs' and partially applied with
--   'partiallyApplyTyCons'. The outcome depends on how many remain:
--
--   * none: a warning is added, unless the group's bind/return constraints
--     have no top-level type constructor variables at all;
--   * exactly one: its equalities are recorded against the group's first
--     constraint;
--   * several: they are printed and a plugin error is thrown.
--
--   An empty group is skipped up front, consistent with
--   'solveMonoConstraintGroup'. The check must come before the call because
--   'determineValidConstraintGroupAssocs' throws on an empty group.
solvePolyConstraintGroup :: [WantedCt] -> SupermonadPluginM ()
solvePolyConstraintGroup [] = return ()
solvePolyConstraintGroup ctGroup@(firstCt : _) = do
  printMsg "Solve poly group..."
  (_, assocs) <- determineValidConstraintGroupAssocs ctGroup
  appliedAssocs <- forM assocs $ \assoc ->
    either throwPluginErrorSDoc return =<< partiallyApplyTyCons assoc
  case appliedAssocs of
    [] -> do
      topTcVars <- concat <$> mapM collectBindReturnTopTcVars ctGroup
      -- Only warn if the group's bind/return constraints still contain
      -- top-level type constructor variables.
      if null topTcVars
        then return ()
        else addWarning
          "There are no possible associations for the current constraint group!"
          ( O.hang (O.text "There are two possible reasons for this warning:") 2
            $ O.vcat $
              [ O.text "1. Either the group can not be solved or"
              , O.text "2. further iterations between the plugin and type checker "
              , O.text " have to resolve for sufficient information to arise."
              ] ++ fmap (O.text . formatConstraint) ctGroup )
    [appliedAssoc] ->
      -- The equalities of the unique association are attributed to the
      -- first constraint of the group.
      forM_ appliedAssoc $ \(tv, ty, _flexVars) ->
        addTyVarEquality firstCt tv ty
    _ -> do
      printMsg "Possible associations:"
      forM_ appliedAssocs printObj
      throwPluginError "There is more then one possible association for the current constraint group!"
  where
    -- Top-level type constructor variables of a bind or return constraint's
    -- class arguments. Other constraints contribute nothing.
    collectBindReturnTopTcVars :: Ct -> SupermonadPluginM [TyVar]
    collectBindReturnTopTcVars ct = do
      isBindOrReturn <- liftM2 (||) (isBindConstraint ct) (isReturnConstraint ct)
      case (isBindOrReturn, constraintClassType ct) of
        (True, Just (_cls, tyArgs)) -> return $ S.toList $ collectTopTcVars tyArgs
        _ -> return []
-- | Compute the associations of a constraint group's ambiguous top-level
--   type constructor variables that could potentially instantiate every
--   constraint of the group.
--
--   The candidates each variable may be associated with are:
--
--   * the top-level type constructors and non-ambiguous top-level type
--     variables of the group's 'Return'/'Bind' constraints;
--   * the same items from given 'Return'/'Bind' constraints that share at
--     least one such candidate with the wanted ones.
--
--   Every non-empty combination produced by 'associations' is kept only if
--   'isPotentiallyInstantiatedCt' accepts it for all constraints of the
--   group.
--
--   Returns the input group unchanged together with the valid associations.
--   Throws a plugin error for an empty group.
determineValidConstraintGroupAssocs :: [WantedCt] -> SupermonadPluginM ([WantedCt], [[(TyVar, Either TyCon TyVar)]])
determineValidConstraintGroupAssocs [] = throwPluginError "Solving received an empty constraint group!"
determineValidConstraintGroupAssocs ctGroup = do
  givens <- getGivenConstraints
  wantedSmCts <- onlySupermonadCts ctGroup
  ambVars <- S.toList <$> ambiguousTopVarsOf wantedSmCts
  wantedBase <- candidateBaseOf wantedSmCts
  relatedGivens <- relatedSupermonadCts givens wantedBase
  givenBase <- candidateBaseOf relatedGivens
  let candidates :: [Either TyCon TyVar]
      candidates = S.toList (wantedBase `S.union` givenBase)
      candidateAssocs :: [[(TyVar, Either TyCon TyVar)]]
      candidateAssocs = filter (not . null) $ associations [ (tv, candidates) | tv <- ambVars ]
  validAssocs <- filterM (\assoc -> allM (\ct -> isPotentiallyInstantiatedCt ct assoc) ctGroup) candidateAssocs
  return (ctGroup, validAssocs)
  where
    -- Keep only the constraints of the 'Return' or 'Bind' class.
    onlySupermonadCts :: [Ct] -> SupermonadPluginM [Ct]
    onlySupermonadCts cts = do
      returnCls <- getReturnClass
      bindCls <- getBindClass
      let isSupermonadCt ct = isClassConstraint returnCls ct || isClassConstraint bindCls ct
      return [ ct | ct <- cts, isSupermonadCt ct ]

    -- Supermonad constraints among 'cts' whose candidate base overlaps 'base'.
    relatedSupermonadCts :: [Ct] -> S.Set (Either TyCon TyVar) -> SupermonadPluginM [Ct]
    relatedSupermonadCts cts base = do
      smCts <- onlySupermonadCts cts
      filterM sharesBase smCts
      where
        sharesBase :: Ct -> SupermonadPluginM Bool
        sharesBase ct = (not . S.null . S.intersection base) <$> candidateBaseOf [ct]

    -- Top-level type constructors (Left) and non-ambiguous top-level type
    -- variables (Right) of the supermonad constraints among 'cts'.
    candidateBaseOf :: [Ct] -> SupermonadPluginM (S.Set (Either TyCon TyVar))
    candidateBaseOf cts = do
      smCts <- onlySupermonadCts cts
      let nonAmbVars = S.map Right $ S.filter (not . isAmbiguousTyVar) $ componentTopTcVars smCts
          topCons = S.map Left $ componentTopTyCons smCts
      return (nonAmbVars `S.union` topCons)

    -- Ambiguous top-level type variables of the supermonad constraints among 'cts'.
    ambiguousTopVarsOf :: [Ct] -> SupermonadPluginM (S.Set TyVar)
    ambiguousTopVarsOf cts =
      S.filter isAmbiguousTyVar . componentTopTcVars <$> onlySupermonadCts cts
-- | Solve remaining 'Return' and 'Bind' constraints by unifying them with
--   the instance of their top-level type constructor.
--
--   The type variable equalities recorded so far ('getTyVarEqualities') are
--   turned into a substitution and applied to the class arguments of all
--   wanted constraints.
--
--   Two passes follow, first over return constraints and then over bind
--   constraints. Each pass selects constraints whose substituted arguments
--   still mention an ambiguous type variable but have no top-level type
--   constructor variable. Each selected constraint is unified with the
--   instance 'getSupermonadFor' associates with its unique top-level type
--   constructor. The implied equalities are recorded with
--   'addTyVarEqualities' and 'addTypeEqualities'.
--
--   Problems with individual constraints are printed with 'printErr' and do
--   not abort solving.
--
--   NOTE(review): both passes use the substitution computed before the
--   return pass; the bind pass does not see a recomputed substitution —
--   confirm this is intended.
solveSolvedTyConIndices :: SupermonadPluginM ()
solveSolvedTyConIndices = do
  tyVarEqs <- getTyVarEqualities
  let tvSubst = mkTypeVarSubst [ (tv, ty) | (_ct, tv, ty) <- tyVarEqs ]
  wantedCts <- getWantedConstraints
  let prepWantedCts = catMaybes $ fmap (prepCt tvSubst) wantedCts
  printMsg "Unification solve return constraints..."
  unificationSolve prepWantedCts isReturnConstraint (\tc -> fmap snd <$> getSupermonadFor tc)
  printMsg "Unification solve bind constraints..."
  unificationSolve prepWantedCts isBindConstraint (\tc -> fmap fst <$> getSupermonadFor tc)
  where
    -- Solve the prepared constraints selected by 'isRequiredConstraint'
    -- against the instance 'getTyConInst' yields for their top-level type
    -- constructor. Failures are printed, not thrown.
    unificationSolve :: [(Ct, TyCon, [Type])]
                     -> (Ct -> SupermonadPluginM Bool)
                     -> (TyCon -> SupermonadPluginM (Maybe ClsInst))
                     -> SupermonadPluginM ()
    unificationSolve prepWantedCts isRequiredConstraint getTyConInst = do
      cts <- filterTopTyConSolvedConstraints prepWantedCts isRequiredConstraint
      forM_ cts $ \prepared -> do
        eResult <- withTopTyCon prepared getTyConInst $ \_topTyCon _ctArgs inst ->
          case deriveUnificationConstraints prepared inst of
            Left err -> printErr $ sDocToStr err
            Right (tvTyEqs, tyEqs) -> do
              addTyVarEqualities tvTyEqs
              addTypeEqualities tyEqs
        either (printErr . sDocToStr) return eResult

    -- Keep the constraints accepted by 'p' whose arguments mention at least
    -- one ambiguous type variable and whose top-level type constructors are
    -- all fixed (no top-level type constructor variables).
    filterTopTyConSolvedConstraints :: [(WantedCt, TyCon, [Type])]
                                    -> (WantedCt -> SupermonadPluginM Bool)
                                    -> SupermonadPluginM [(WantedCt, TyCon, [Type])]
    filterTopTyConSolvedConstraints cts p = do
      predFilteredCts <- filterM (\(ct, _tc, _args) -> p ct) cts
      return $ filter (\(_ct, _tc, args) -> hasAmbiguousVar args && hasNoTopTcVar args) predFilteredCts
      where
        hasAmbiguousVar :: [Type] -> Bool
        hasAmbiguousVar args = any isAmbiguousTyVar $ S.unions $ fmap collectTyVars args
        hasNoTopTcVar :: [Type] -> Bool
        hasNoTopTcVar args = S.null (collectTopTcVars args)

    -- Run 'process' with the unique top-level type constructor of the
    -- constraint's arguments and its associated instance. If there is no
    -- unique top-level type constructor, or it has no instance, return a
    -- 'Left' describing the problem.
    withTopTyCon :: (Ct, TyCon, [Type])
                 -> (TyCon -> SupermonadPluginM (Maybe ClsInst))
                 -> (TyCon -> [Type] -> ClsInst -> SupermonadPluginM a)
                 -> SupermonadPluginM (Either O.SDoc a)
    withTopTyCon (ct, _ctClsTyCon, ctArgs) getTyConInst process =
      case S.toList $ collectTopTyCons ctArgs of
        [topTyCon] -> do
          mInst <- getTyConInst topTyCon
          case mInst of
            Just inst -> Right <$> process topTyCon ctArgs inst
            Nothing ->
              return $ Left
                $ O.text "Constraints top type constructor does not have an associated instance:"
                O.$$ O.ppr topTyCon
        _ ->
          return $ Left
            $ O.text "Constraint misses a unqiue top-level type constructor:"
            O.$$ O.ppr ct

    -- Unify the instance head arguments pairwise with the constraint
    -- arguments.
    --
    -- On success, returns:
    --   * the equalities of the constraint's ambiguous type variables;
    --   * equalities between the types each instance variable was unified with.
    --
    -- Fails if any argument pair does not unify.
    deriveUnificationConstraints :: (Ct, TyCon, [Type]) -> ClsInst -> Either O.SDoc ([(Ct, TyVar, Type)], [(Ct, Type, Type)])
    deriveUnificationConstraints (ct, _ctClsTyCon, ctArgs) inst =
      -- 'sequence' succeeds only if every pair unifies, which replaces the
      -- separate 'any isNothing' check followed by 'catMaybes'.
      case sequence (zipWith tcUnifyTy instArgs ctArgs) of
        Nothing ->
          Left $ O.hang (O.text "Unification constraint solving not possible, because instance and constraint are not unifiable!") 2
            $ (O.hang (O.text "Instance:") 2 $ O.ppr inst) O.$$
              (O.hang (O.text "Constraint:") 2 $ O.ppr ct) O.$$
              (O.hang (O.text "Constraint arguments:") 2 $ O.ppr ctArgs)
        Just substs ->
          let instVarEqGroups = collectEqualityGroup substs instVars
              ctVarEqGroups = collectEqualityGroup substs $ filter isAmbiguousTyVar ctVars
          in Right (mkEqStarGroup ct ctVarEqGroups, concatMap (mkEqGroup ct . snd) instVarEqGroups)
      where
        (instVars, _instCls, instArgs) = instanceHead inst
        ctVars = S.toList $ S.unions $ fmap collectTyVars ctArgs

    -- Equate the first type of a group with each of the remaining ones.
    mkEqGroup :: Ct -> [Type] -> [(Ct, Type, Type)]
    mkEqGroup _ [] = []
    mkEqGroup baseCt (ty : tys) = [ (baseCt, ty, ty') | ty' <- tys ]

    -- Equate each type variable with every type in its group.
    mkEqStarGroup :: Ct -> [(TyVar, [Type])] -> [(Ct, TyVar, Type)]
    mkEqStarGroup baseCt eqGroups = [ (baseCt, tv, ty) | (tv, tys) <- eqGroups, ty <- tys ]

    -- For each variable, the distinct (by 'eqType') types the substitutions
    -- map it to. Types mentioning the variable itself are excluded; this
    -- includes the variable left unmapped by a substitution.
    collectEqualityGroup :: [TypeVarSubst] -> [TyVar] -> [(TyVar, [Type])]
    collectEqualityGroup substs tvs =
      [ (tv, nubBy eqType [ ty | subst <- substs
                               , let ty = substTyVar subst tv
                               , all (tv /=) (collectTyVars ty) ])
      | tv <- tvs ]

    -- Pair a class constraint with its class type constructor and its
    -- substituted class arguments. Non-class constraints yield 'Nothing'.
    prepCt :: TypeVarSubst -> Ct -> Maybe (Ct, TyCon, [Type])
    prepCt subst ct =
      fmap (\(cls, args) -> (ct, classTyCon cls, substTys subst args)) (constraintClassType ct)