module Control.Super.Plugin.Solving 
  ( solveConstraints
  ) where
import Data.Maybe 
  ( catMaybes
  , isJust, isNothing
  , fromJust, fromMaybe )
import Data.List ( partition, nubBy )
import qualified Data.Set as Set
import Control.Monad ( forM, forM_, filterM )
import TcRnTypes ( Ct(..) )
import TyCon ( TyCon )
import Class ( Class, classTyCon )
import Type 
  ( Type, TyVar
  , substTyVar, substTys
  , eqType )
import TcType ( isAmbiguousTyVar )
import InstEnv ( ClsInst, instanceHead )
import Unify ( tcUnifyTy )
import qualified Outputable as O
import qualified Control.Super.Plugin.Collection.Set as S
import Control.Super.Plugin.Debug ( sDocToStr )
import Control.Super.Plugin.InstanceDict ( InstanceDict )
import Control.Super.Plugin.Wrapper 
  ( TypeVarSubst, mkTypeVarSubst )
import Control.Super.Plugin.Environment 
  ( SupermonadPluginM
  , getGivenConstraints, getWantedConstraints
  , getInstanceFor
  , addTyVarEquality, addTyVarEqualities
  , addTypeEqualities
  , getTyVarEqualities
  , printMsg, printObj, printErr 
  , addWarning, displayWarnings
  , throwPluginError, throwPluginErrorSDoc )
import Control.Super.Plugin.Environment.Lift
  ( isPotentiallyInstantiatedCt
  , partiallyApplyTyCons )
import Control.Super.Plugin.Constraint 
  ( WantedCt
  , isClassConstraint
  , isAnyClassConstraint
  , constraintClassType
  , constraintClassTyArgs )
import Control.Super.Plugin.Separation 
  ( ConstraintGroup
  , separateContraints
  , componentTopTyCons, componentTopTcVars
  , componentMonoTyCon )
import Control.Super.Plugin.Utils 
  ( collectTopTcVars
  , collectTopTyCons
  , collectTyVars
  , associations
  , allM )
import Control.Super.Plugin.Log 
  ( formatConstraint )
  
-- | Solve a group of wanted constraints with respect to the given relevant
--   classes.
--
--   The group is split into independent components ('separateContraints').
--   Components for which 'componentMonoTyCon' yields a type constructor are
--   handed to 'solveMonoConstraintGroup' together with that constructor; the
--   remaining components are handed to 'solvePolyConstraintGroup'. Afterwards
--   'solveSolvedTyConIndices' runs and all collected warnings are displayed.
solveConstraints :: [Class] -> ConstraintGroup -> SupermonadPluginM InstanceDict ()
solveConstraints relevantClss wantedCts = do
  let components     = separateContraints wantedCts
      -- Tag every component with its unique top-level type constructor, if any.
      tagged         = [ (componentMonoTyCon relevantClss comp, comp) | comp <- components ]
      monoComponents = [ (tyCon, comp) | (Just tyCon, comp) <- tagged ]
      polyComponents = [ comp | (Nothing, comp) <- tagged ]
  mapM_ (solveMonoConstraintGroup relevantClss) monoComponents
  mapM_ (solvePolyConstraintGroup relevantClss) polyComponents
  solveSolvedTyConIndices relevantClss
  displayWarnings
-- | Solve a constraint group whose unique top-level type constructor is known.
--
--   For every relevant class constraint of the group, each ambiguous top-level
--   type constructor variable in its class arguments is equated (attributed to
--   that constraint) with the type constructor partially applied via
--   'partiallyApplyTyCons'. Exactly one distinct application (up to
--   'tyConAssocEq') is expected; zero or several raise a plugin error.
solveMonoConstraintGroup :: [Class] -> (TyCon, ConstraintGroup) -> SupermonadPluginM s ()
solveMonoConstraintGroup _relevantClss (_, []) = return ()
solveMonoConstraintGroup relevantClss (tyCon, ctGroup) =
  forM_ (filter (isAnyClassConstraint relevantClss) ctGroup) $ \ct ->
    forM_ (ambiguousTopVarsOf ct) $ \tyVar -> do
      applyResult <- partiallyApplyTyCons [(tyVar, Left tyCon)]
      applied <- either throwPluginErrorSDoc return applyResult
      case nubBy tyConAssocEq applied of
        [(tv, ty, _)] -> addTyVarEquality ct tv ty
        []            -> throwPluginError "How did this become an empty list?"
        _             -> throwPluginError "How did this become a list with more then one element?"
  where
    -- Ambiguous variables among the top-level type constructor variables of
    -- the constraint's class arguments (no arguments if it has none).
    ambiguousTopVarsOf :: Ct -> Set.Set TyVar
    ambiguousTopVarsOf = Set.filter isAmbiguousTyVar
                       . collectTopTcVars
                       . fromMaybe []
                       . constraintClassTyArgs

    -- Structural equality of partial applications: same variable, same
    -- remaining variables and equal types.
    tyConAssocEq :: (TyVar, Type, [TyVar]) -> (TyVar, Type, [TyVar]) -> Bool
    tyConAssocEq (tvA, tyA, tvsA) (tvB, tyB, tvsB) =
      tvA == tvB && tvsA == tvsB && eqType tyA tyB
-- | Solve a constraint group that has no single known top-level type
--   constructor.
--
--   All valid associations of the group's ambiguous top-level type constructor
--   variables are computed ('determineValidConstraintGroupAssocs') and
--   partially applied ('partiallyApplyTyCons'). Then:
--
--   * no association: if the relevant constraints of the group mention any
--     top-level type constructor variables, a warning is recorded (the group
--     may be unsolvable, or further type checker iterations may be needed);
--     otherwise nothing happens.
--   * exactly one association: its equalities are added, attributed to the
--     first constraint of the group.
--   * several associations: they are printed and a plugin error is thrown.
solvePolyConstraintGroup :: [Class] -> ConstraintGroup -> SupermonadPluginM s ()
-- An empty group needs no solving. This must be decided before calling
-- 'determineValidConstraintGroupAssocs', which rejects empty groups with a
-- plugin error.
solvePolyConstraintGroup _relevantClss [] = return ()
solvePolyConstraintGroup relevantClss ctGroup@(firstCt : _) = do
  (_, assocs) <- determineValidConstraintGroupAssocs relevantClss ctGroup
  appliedAssocs <- forM assocs $ \assoc ->
    either throwPluginErrorSDoc return =<< partiallyApplyTyCons assoc
  case appliedAssocs of
    [] -> do
      let topTcVars = concatMap collectRelevantTopTcVars ctGroup
      -- Without top-level type constructor variables there was nothing to
      -- associate, so finding no association is expected and not worth a warning.
      if null topTcVars
        then return ()
        else addWarning
          "There are no possible associations for the current constraint group!"
          ( O.hang (O.text "There are two possible reasons for this warning:") 2
          $ O.vcat $
            [ O.text "1. Either the group can not be solved or"
            , O.text "2. further iterations between the plugin and type checker "
            , O.text "   have to resolve for sufficient information to arise."
            ] ++ fmap (O.text . formatConstraint) ctGroup)
    [appliedAssoc] ->
      forM_ appliedAssoc $ \(tv, ty, _flexVars) ->
        addTyVarEquality firstCt tv ty
    _ -> do
      printMsg "Possible associations:"
      forM_ appliedAssocs printObj
      throwPluginError "There is more than one possible association for the current constraint group!"
  where
    -- Top-level type constructor variables of a relevant class constraint's
    -- arguments; empty for every other constraint.
    collectRelevantTopTcVars :: Ct -> [TyVar]
    collectRelevantTopTcVars ct
      | isAnyClassConstraint relevantClss ct
      , Just (_cls, tyArgs) <- constraintClassType ct
      = Set.toList $ collectTopTcVars tyArgs
      | otherwise = []
-- | A pair of a set of type constructors and a set of type variables, used
--   as the candidate "base" when associating type constructor variables.
type TcTvSet =
  ( S.Set TyCon    -- type constructors
  , Set.Set TyVar  -- type variables
  )
-- | Component-wise intersection of two 'TcTvSet's.
tctvIntersection :: TcTvSet -> TcTvSet -> TcTvSet
tctvIntersection (tyConsL, tyVarsL) (tyConsR, tyVarsR) =
  let tyCons = tyConsL `S.intersection` tyConsR
      tyVars = tyVarsL `Set.intersection` tyVarsR
  in (tyCons, tyVars)
-- | 'True' iff both the type constructor set and the type variable set are empty.
tctvNull :: TcTvSet -> Bool
tctvNull (tyCons, tyVars)
  | S.null tyCons = Set.null tyVars
  | otherwise     = False
-- | Component-wise union of two 'TcTvSet's.
tctvUnion :: TcTvSet -> TcTvSet -> TcTvSet
tctvUnion (tyConsL, tyVarsL) (tyConsR, tyVarsR) = (tyCons, tyVars)
  where
    tyCons = tyConsL `S.union` tyConsR
    tyVars = tyVarsL `Set.union` tyVarsR
-- | Flatten a 'TcTvSet' into one list: all type constructors (as 'Left', in
--   set order) followed by all type variables (as 'Right', in set order).
tctvToList :: TcTvSet -> [Either TyCon TyVar]
tctvToList (tyCons, tyVars) =
  [ Left tc | tc <- S.toList tyCons ] ++ [ Right tv | tv <- Set.toList tyVars ]
-- | Compute every association of the group's ambiguous top-level type
--   constructor variables to candidate type constructors / type variables
--   under which each constraint of the group is still potentially
--   instantiable ('isPotentiallyInstantiatedCt').
--
--   The candidates ("base") are the top-level type constructors and the
--   non-ambiguous top-level type constructor variables of the group's relevant
--   class constraints, extended by those of relevant given constraints that
--   share at least one element with that wanted base. Empty associations are
--   discarded. The input group is returned unchanged alongside the valid
--   associations.
--
--   Throws a plugin error when given an empty constraint group.
determineValidConstraintGroupAssocs :: [Class] -> ConstraintGroup -> SupermonadPluginM s ([WantedCt], [[(TyVar, Either TyCon TyVar)]])
determineValidConstraintGroupAssocs _relevantClss [] = throwPluginError "Solving received an empty constraint group!"
determineValidConstraintGroupAssocs relevantClss ctGroup = do
  givenCts <- getGivenConstraints
  let relevantWanted = keepRelevant ctGroup
      -- The variables that need to be associated with something.
      ambTyConVars = Set.toList
                   $ Set.filter isAmbiguousTyVar
                   $ componentTopTcVars relevantWanted
      wantedBase = baseOf relevantWanted
      -- Only given constraints overlapping the wanted base contribute candidates.
      givenBase = baseOf [ ct | ct <- keepRelevant givenCts
                              , not $ tctvNull $ tctvIntersection (baseOf [ct]) wantedBase ]
      candidates :: [Either TyCon TyVar]
      candidates = tctvToList $ tctvUnion wantedBase givenBase
      assocs :: [[(TyVar, Either TyCon TyVar)]]
      assocs = filter (not . null)
             $ associations [ (tv, candidates) | tv <- ambTyConVars ]
  -- Keep only associations under which every constraint of the group
  -- (not just the relevant ones) may still be instantiated.
  validAssocs <- filterM (\assoc -> allM (\ct -> isPotentiallyInstantiatedCt ct assoc) ctGroup) assocs
  return (ctGroup, validAssocs)
  where
    -- Class constraints of one of the relevant classes.
    keepRelevant :: [Ct] -> [Ct]
    keepRelevant = filter (isAnyClassConstraint relevantClss)

    -- Top-level type constructors and non-ambiguous top-level type
    -- constructor variables of the relevant constraints among the given ones.
    baseOf :: [Ct] -> TcTvSet
    baseOf cts =
      let relevant = keepRelevant cts
      in ( componentTopTyCons relevant
         , Set.filter (not . isAmbiguousTyVar) $ componentTopTcVars relevant )
-- | Derive further equalities from wanted constraints whose top-level type
--   constructor is already determined, by unifying them with the instance
--   registered for that type constructor.
--
--   The type variable equalities found so far ('getTyVarEqualities') are
--   applied as a substitution to the class arguments of every wanted class
--   constraint. Then, for each relevant class, every wanted constraint of that
--   class that
--
--   * still mentions an ambiguous type variable, but
--   * has no top-level type constructor variable,
--
--   is unified argument-wise with the instance 'getInstanceFor' returns for
--   the constraint's unique top-level type constructor, and the derived type
--   variable equalities and type equalities are registered. Problems with an
--   individual constraint are reported via 'printErr' and do not abort solving.
solveSolvedTyConIndices :: [Class] -> SupermonadPluginM InstanceDict ()
solveSolvedTyConIndices relevantClss = do
  -- Substitution built from the type variable equalities found so far.
  tyVarEqs <- getTyVarEqualities
  let tvSubst = mkTypeVarSubst $ fmap (\(_ct, tv, ty) -> (tv, ty)) tyVarEqs
  -- Wanted class constraints with that substitution applied to their arguments.
  wantedCts <- getWantedConstraints
  let prepWantedCts = catMaybes $ fmap (prepCt tvSubst) wantedCts
  printMsg "Unification solve constraints..."
  forM_ relevantClss $ \cls ->
    unificationSolve prepWantedCts (return . isClassConstraint cls) (\tc -> getInstanceFor tc cls)
  where
    -- Unify each selected constraint with the instance of its top-level type
    -- constructor and register the resulting equalities; errors are printed.
    unificationSolve :: [(Ct, TyCon, [Type])]
                     -- prepared wanted constraints
                     -> (Ct -> SupermonadPluginM s Bool)
                     -- selects the constraints to process
                     -> (TyCon -> SupermonadPluginM s (Maybe ClsInst))
                     -- looks up the instance for a top-level type constructor
                     -> SupermonadPluginM s ()
    unificationSolve prepWantedCts isRequiredConstraint getTyConInst = do
      cts <- filterTopTyConSolvedConstraints prepWantedCts isRequiredConstraint
      forM_ cts $ \ct -> do
        eResult <- withTopTyCon ct getTyConInst $ \_topTyCon _ctArgs inst ->
          case deriveUnificationConstraints ct inst of
            Left err -> printErr $ sDocToStr err
            Right (tvTyEqs, tyEqs) -> do
              addTyVarEqualities tvTyEqs
              addTypeEqualities tyEqs
        case eResult of
          Left err -> printErr $ sDocToStr err
          Right () -> return ()

    -- Keep the constraints accepted by the predicate that still mention an
    -- ambiguous type variable somewhere in their arguments, but have no
    -- top-level type constructor variable.
    filterTopTyConSolvedConstraints :: [(WantedCt, TyCon, [Type])]
                                    -> (WantedCt -> SupermonadPluginM s Bool)
                                    -> SupermonadPluginM s [(WantedCt, TyCon, [Type])]
    filterTopTyConSolvedConstraints cts p = do
      predFilteredCts <- filterM (\(ct, _tc, _args) -> p ct) cts
      let filterNoVarCts = filter (\(_ct, _tc, args) -> not $ Set.null
                                                            $ Set.filter isAmbiguousTyVar
                                                            $ Set.unions
                                                            $ fmap collectTyVars args)
                                  predFilteredCts
      return $ filter (Set.null . collectTopTcVars . (\(_ct, _tc, args) -> args)) filterNoVarCts

    -- Run the continuation with the constraint's unique top-level type
    -- constructor and its instance; report a missing unique constructor or a
    -- missing instance as 'Left'.
    withTopTyCon :: (Ct, TyCon, [Type])
                 -> (TyCon -> SupermonadPluginM s (Maybe ClsInst))
                 -> (TyCon -> [Type] -> ClsInst -> SupermonadPluginM s a)
                 -> SupermonadPluginM s (Either O.SDoc a)
    withTopTyCon (ct, _ctClsTyCon, ctArgs) getTyConInst process =
      case S.toList $ collectTopTyCons ctArgs of
        [topTyCon] -> do
          mInst <- getTyConInst topTyCon
          case mInst of
            Just inst -> Right <$> process topTyCon ctArgs inst
            Nothing ->
              return $ Left
                     $ O.text "Constraint's top type constructor does not have an associated instance:"
                       O.$$ O.ppr topTyCon
        _ ->
          return $ Left
                 $ O.text "Constraint misses a unique top-level type constructor:"
                   O.$$ O.ppr ct

    -- Unify the instance head arguments with the constraint arguments
    -- (argument by argument). Returns the equalities for ambiguous constraint
    -- variables and the equalities required between the different bindings
    -- of each instance variable.
    deriveUnificationConstraints :: (Ct, TyCon, [Type]) -> ClsInst -> Either O.SDoc ([(Ct, TyVar, Type)], [(Ct, Type, Type)])
    deriveUnificationConstraints (ct, _ctClsTyCon, ctArgs) inst
      | any isNothing mSubsts =
          Left $ O.hang (O.text "Unification constraint solving not possible, because instance and constraint are not unifiable!") 2
               $ (O.hang (O.text "Instance:") 2 $ O.ppr inst) O.$$
                 (O.hang (O.text "Constraint:") 2 $ O.ppr ct) O.$$
                 (O.hang (O.text "Constraint arguments:") 2 $ O.ppr ctArgs)
      | otherwise = Right (ctVarEqCts, instVarEqCts)
      where
        (instVars, _instCls, instArgs) = instanceHead inst
        ctVars = Set.toList $ Set.unions $ fmap collectTyVars ctArgs
        mSubsts = zipWith tcUnifyTy instArgs ctArgs
        substs = catMaybes mSubsts
        -- All bindings of one instance variable must agree with each other.
        instVarEqCts = concatMap (mkEqGroup ct . snd) $ collectEqualityGroup substs instVars
        -- Ambiguous constraint variables are equated with each of their bindings.
        ctVarEqCts = mkEqStarGroup ct $ collectEqualityGroup substs $ filter isAmbiguousTyVar ctVars

    -- Equate the first type of the group with each of the remaining ones.
    mkEqGroup :: Ct -> [Type] -> [(Ct, Type, Type)]
    mkEqGroup _ [] = []
    mkEqGroup baseCt (ty : tys) = fmap (\ty' -> (baseCt, ty, ty')) tys

    -- Equate each variable with every type in its group.
    mkEqStarGroup :: Ct -> [(TyVar, [Type])] -> [(Ct, TyVar, Type)]
    mkEqStarGroup baseCt eqGroups = concatMap (\(tv, tys) -> fmap (\ty -> (baseCt, tv, ty)) tys) eqGroups

    -- For each variable, the distinct types it is mapped to by the
    -- substitutions, excluding types that mention the variable itself.
    collectEqualityGroup :: [TypeVarSubst] -> [TyVar] -> [(TyVar, [Type])]
    collectEqualityGroup substs tvs = [ (tv, nubBy eqType $ filter (notElem tv . collectTyVars)
                                                          $ [ substTyVar subst tv | subst <- substs]
                                        ) | tv <- tvs]

    -- Class constraint with the substitution applied to its arguments;
    -- 'Nothing' for constraints without a class type.
    prepCt :: TypeVarSubst -> Ct -> Maybe (Ct, TyCon, [Type])
    prepCt subst ct = fmap (\(cls, args) -> (ct, classTyCon cls, substTys subst args)) $ constraintClassType ct