{-# LANGUAGE CPP #-}
{-# LANGUAGE TypeFamilies #-}
module TcClassDcl ( tcClassSigs, tcClassDecl2,
                    findMethodBind, instantiateMethod,
                    tcClassMinimalDef,
                    HsSigFun, mkHsSigFun,
                    tcMkDeclCtxt, tcAddDeclCtxt, badMethodErr,
                    instDeclCtxt1, instDeclCtxt2, instDeclCtxt3,
                    tcATDefault
                  ) where
#include "HsVersions.h"
import GhcPrelude
import HsSyn
import TcEnv
import TcSigs
import TcEvidence ( idHsWrapper )
import TcBinds
import TcUnify
import TcHsType
import TcMType
import Type     ( getClassPredTys_maybe, piResultTys )
import TcType
import TcRnMonad
import DriverPhases (HscSource(..))
import BuildTyCl( TcMethInfo )
import Class
import Coercion ( pprCoAxiom )
import DynFlags
import FamInst
import FamInstEnv
import Id
import Name
import NameEnv
import NameSet
import Var
import VarEnv
import Outputable
import SrcLoc
import TyCon
import Maybes
import BasicTypes
import Bag
import FastString
import BooleanFormula
import Util
import Control.Monad
import Data.List ( mapAccumL, partition )
-- | Error message used when a class declared in an hsig file comes with
-- default method bindings.  The argument is the name of the class.
illegalHsigDefaultMethod :: Name -> SDoc
illegalHsigDefaultMethod cls_name
  = hsep [ text "Illegal default method(s) in class definition of"
         , ppr cls_name
         , text "in hsig file" ]
-- | Typecheck the method signatures in the body of a class declaration.
--
-- Arguments, in order: the name of the class, the signatures written in the
-- class body, and the value bindings written in the class body.
--
-- The result holds one 'TcMethInfo' triple (method name, method type,
-- default-method info) for every name bound by an ordinary 'ClassOpSig'.
-- The default-method info is a 'GenericDM' when the name also has a
-- generic-default signature, a 'VanillaDM' when it only has a binding, and
-- 'Nothing' otherwise.
--
-- Fails if a binding names something with no ordinary method signature.
-- In an hsig file it fails if there is any binding at all; elsewhere it
-- fails for a generic-default signature that has no binding.
tcClassSigs :: Name
            -> [LSig GhcRn]
            -> LHsBinds GhcRn
            -> TcM [TcMethInfo]
tcClassSigs clas sigs def_methods
  = do { traceTc "tcClassSigs 1" (ppr clas)

         -- Generic-default signatures are checked first, because the
         -- ordinary signatures consult the environment built from them
       ; generic_prs <- concat <$> mapM (addLocM check_generic_sig) generic_sigs
       ; let generic_env :: NameEnv (SrcSpan, Type)
             generic_env = mkNameEnv generic_prs

       ; meth_infos <- concat <$> mapM (addLocM (check_plain_sig generic_env)) plain_sigs

         -- Every binding in the class body must belong to a declared method
       ; let declared = mkNameSet [ nm | (nm, _, _) <- meth_infos ]
       ; mapM_ (failWithTc . badMethodErr clas)
               (filter (not . (`elemNameSet` declared)) bound_names)

       ; hsc_src <- tcg_src <$> getGblEnv
       ; if hsc_src == HsigFile
           then -- In an hsig file any value binding at all is an error
                unless (null def_methods) $
                  failWithTc (illegalHsigDefaultMethod clas)
           else -- Otherwise each generic-default signature needs a binding
                mapM_ (failWithTc . badGenericMethod clas)
                      [ nm | (nm, _) <- generic_prs, nm `notElem` bound_names ]

       ; traceTc "tcClassSigs 2" (ppr clas)
       ; return meth_infos }
  where
    -- Split the class-op signatures on their Bool flag: False for the
    -- ordinary ones, True for the generic-default ones
    plain_sigs   = [ L loc (nms, ty) | L loc (ClassOpSig _ False nms ty) <- sigs ]
    generic_sigs = [ L loc (nms, ty) | L loc (ClassOpSig _ True  nms ty) <- sigs ]

    -- Names that have a value binding in the class body
    bound_names :: [Name]
    bound_names = [ nm | L _ (FunBind { fun_id = L _ nm }) <- bagToList def_methods ]

    skol_info = TyConSkol ClassFlavour clas

    -- One ordinary signature may bind several names; they all share the type
    check_plain_sig :: NameEnv (SrcSpan, Type) -> ([Located Name], LHsSigType GhcRn)
                    -> TcM [TcMethInfo]
    check_plain_sig gen_env (lnames, hs_ty)
      = do { traceTc "ClsSig 1" (ppr lnames)
           ; meth_ty <- tcClassSigType skol_info lnames hs_ty
           ; traceTc "ClsSig 2" (ppr lnames)
           ; return [ (nm, meth_ty, dm_info nm) | L _ nm <- lnames ] }
      where
        -- A generic-default signature takes precedence over a plain binding
        dm_info nm = case lookupNameEnv gen_env nm of
          Just loc_ty -> Just (GenericDM loc_ty)
          Nothing
            | nm `elem` bound_names -> Just VanillaDM
            | otherwise             -> Nothing

    -- Pair each name of a generic-default signature with the location of
    -- that name and the typechecked signature type
    check_generic_sig (lnames, hs_ty)
      = do { generic_ty <- tcClassSigType skol_info lnames hs_ty
           ; return [ (nm, (loc, generic_ty)) | L loc nm <- lnames ] }
-- | Typecheck the default-method bindings of a single class declaration
-- and return the resulting bindings.
--
-- The work runs under 'recoverM', so a failure while checking yields an
-- empty set of bindings.  Panics when given a declaration that is not a
-- 'ClassDecl'.
tcClassDecl2 :: LTyClDecl GhcRn
             -> TcM (LHsBinds GhcTcId)
tcClassDecl2 (L _ (ClassDecl { tcdLName = class_name
                             , tcdSigs  = sigs
                             , tcdMeths = default_binds }))
  = recoverM (return emptyLHsBinds) $
    setSrcSpan (getLoc class_name)  $
    do { clas <- tcLookupLocatedClass class_name

       ; let (cls_tvs, _, _, op_items) = classBigSig clas
             skol_tvs                  = snd (tcSuperSkolTyVars cls_tvs)

         -- Evidence variable for the class constraint at the skolemised
         -- type variables; it is handed to the check of every default method
       ; this_dict <- newEvVar (mkClassPred clas (mkTyVarTys skol_tvs))

       ; let check_op = tcDefMeth clas skol_tvs this_dict default_binds
                                  (mkHsSigFun sigs)
                                  (mkPragEnv sigs default_binds)

         -- The class type variables are in scope while each item is checked
       ; unionManyBags <$> tcExtendTyVarEnv skol_tvs (mapM check_op op_items) }
tcClassDecl2 d = pprPanic "tcClassDecl2" (ppr d)
-- | Typecheck the default method, if any, of one class operation.
--
-- Arguments, in order: the class; the class type variables (they become
-- abs_tvs of the result); the evidence variable for the class constraint
-- (it becomes abs_ev_vars); the value bindings from the class declaration;
-- the signature lookup function; the pragma environment; and the class-op
-- item, i.e. the selector Id paired with optional default-method info.
--
-- Returns an empty bag when the operation has no default method, and a
-- single AbsBinds binding otherwise.
tcDefMeth :: Class -> [TyVar] -> EvVar -> LHsBinds GhcRn
          -> HsSigFun -> TcPragEnv -> ClassOpItem
          -> TcM (LHsBinds GhcTcId)
tcDefMeth _ _ _ _ _ prag_fn (sel_id, Nothing)
  = do { -- No default method: every pragma recorded for the selector is
         -- reported as an error (see the message built by badDmPrag)
         mapM_ (addLocM (badDmPrag sel_id))
               (lookupPragEnv prag_fn (idName sel_id))
       ; return emptyBag }
tcDefMeth clas tyvars this_dict binds_in hs_sig_fn prag_fn
          (sel_id, Just (dm_name, dm_spec))
  | Just (L bind_loc dm_bind, bndr_loc, prags) <- findMethodBind sel_name binds_in prag_fn
  = do { -- Look up the Id of the default method in the environment, then
         -- rebind the same variable name to the Id with the inline
         -- pragmas attached.
         --
         -- Overall shape of the result: an AbsBinds that abstracts over
         -- the class type variables and this_dict, exporting a local
         -- binder (local_dm_id) under the global default-method Id.
         --
         -- NOTE(review): the second global_dm_id deliberately shadows
         -- the first; all later uses see the version with pragmas.
         --
         global_dm_id  <- tcLookupId dm_name
       ; global_dm_id  <- addInlinePrags global_dm_id prags
       ; local_dm_name <- newNameAt (getOccName sel_name) bndr_loc
            -- The local name reuses the OccName of the selector and is
            -- located at the binder of the binding in the class body
       ; spec_prags <- discardConstraints $
                       tcSpecPrags global_dm_id prags
       ; warnTc NoReason
                (not (null spec_prags))
                (text "Ignoring SPECIALISE pragmas on default method"
                 <+> quotes (ppr sel_name))
       ; let hs_ty = hs_sig_fn sel_name
                     `orElse` pprPanic "tc_dm" (ppr sel_name)
             -- The method is expected to have a signature in the class
             -- declaration; a missing one is a panic.  Within this
             -- function hs_ty is only used for the location stored in
             -- sig_loc below.
             --
             -- local_dm_ty comes from instantiateMethod, which applies
             -- the type of global_dm_id to the class type variables and
             -- drops the leading class-predicate argument.
             --
             --
             --
             local_dm_ty = instantiateMethod clas global_dm_id (mkTyVarTys tyvars)
             lm_bind     = dm_bind { fun_id = L bind_loc local_dm_name }
                             -- Put the local name in as the binder.  findMethodBind
                             -- only returns FunBinds, so fun_id is always present
             warn_redundant = case dm_spec of
                                GenericDM {} -> True
                                VanillaDM    -> False
                -- True for a generic default, False for a vanilla one.
                -- NOTE(review): going by the name, this flag presumably controls
                -- redundant-constraint warnings via FunSigCtxt; confirm there
             ctxt = FunSigCtxt sel_name warn_redundant
       ; let local_dm_id = mkLocalId local_dm_name local_dm_ty
             local_dm_sig = CompleteSig { sig_bndr = local_dm_id
                                        , sig_ctxt  = ctxt
                                        , sig_loc   = getLoc (hsSigType hs_ty) }
         -- Check the renamed binding against the complete signature, with
         -- the class type variables as skolems and this_dict as the given.
         -- The pragmas were already dealt with above, hence no_prag_fn.
       ; (ev_binds, (tc_bind, _))
               <- checkConstraints (TyConSkol ClassFlavour (getName clas)) tyvars [this_dict] $
                  tcPolyCheck no_prag_fn local_dm_sig
                              (L bind_loc lm_bind)
       ; let export = ABE { abe_ext   = noExt
                          , abe_poly  = global_dm_id
                          , abe_mono  = local_dm_id
                          , abe_wrap  = idHsWrapper
                          , abe_prags = IsDefaultMethod }
             full_bind = AbsBinds { abs_ext      = noExt
                                  , abs_tvs      = tyvars
                                  , abs_ev_vars  = [this_dict]
                                  , abs_exports  = [export]
                                  , abs_ev_binds = [ev_binds]
                                  , abs_binds    = tc_bind
                                  , abs_sig      = True }
       ; return (unitBag (L bind_loc full_bind)) }
  -- The item claims a default method but no binding for it was found
  | otherwise = pprPanic "tcDefMeth" (ppr sel_id)
  where
    sel_name = idName sel_id
    no_prag_fn = emptyPragEnv   -- Empty pragma environment for tcPolyCheck;
                                -- inline and SPECIALISE pragmas are handled above
-- | Work out the minimal complete definition of a class.
--
-- Without a MINIMAL signature among the class signatures, the result
-- requires every method that has no default (third component 'Nothing' in
-- the method info).  With one, the user-written formula is returned as is.
-- Outside hsig files a warning is added when that formula fails to imply
-- the default requirement.
tcClassMinimalDef :: Name -> [LSig GhcRn] -> [TcMethInfo] -> TcM ClassMinimalDef
tcClassMinimalDef _clas sigs op_info
  | Just user_mindef <- findMinimalDef sigs
  = do { hsc_src <- tcg_src <$> getGblEnv
         -- The check is skipped for hsig files
       ; when (hsc_src /= HsigFile) $
           whenIsJust (isUnsatisfied (user_mindef `impliesAtom`) required_ops)
                      (addWarnTc NoReason . warningMinimalDefIncomplete)
       ; return user_mindef }

  | otherwise
  = return required_ops
  where
    -- Conjunction of all methods that lack a default implementation
    required_ops :: ClassMinimalDef
    required_ops = mkAnd [ noLoc (mkVar op) | (op, _, Nothing) <- op_info ]
-- | Compute the type of a method at a particular instantiation.
--
-- The type of the given Id is applied to the instance types, and the
-- first predicate argument of the result is split off; what remains is
-- returned.  Panics if there is no leading predicate argument.  The
-- assertion checks that the predicate split off is for the given class.
instantiateMethod :: Class -> TcId -> [TcType] -> TcType
instantiateMethod clas sel_id inst_tys
  = ASSERT( dict_is_for_clas ) meth_ty
  where
    applied_ty = idType sel_id `piResultTys` inst_tys

    (dict_pred, meth_ty)
      = case tcSplitPredFunTy_maybe applied_ty of
          Just split -> split
          Nothing    -> pprPanic "tcInstanceMethod" (ppr sel_id)

    dict_is_for_clas
      | Just (pred_cls, _) <- getClassPredTys_maybe dict_pred = clas == pred_cls
      | otherwise                                             = False
              
              
-- | Lookup from a method name to the signature type written for it.
type HsSigFun = Name -> Maybe (LHsSigType GhcRn)

-- | Build an 'HsSigFun' from a list of signatures.  Only 'ClassOpSig'
-- entries contribute; all other signatures are ignored.
mkHsSigFun :: [LSig GhcRn] -> HsSigFun
mkHsSigFun sigs = lookupNameEnv sig_env
  where
    sig_env = mkHsSigEnv class_op_sig sigs

    -- Pick out the names and the type of a class-op signature
    class_op_sig :: LSig GhcRn -> Maybe ([Located Name], LHsSigType GhcRn)
    class_op_sig lsig = case lsig of
      L _ (ClassOpSig _ _ names hs_ty) -> Just (names, hs_ty)
      _                                -> Nothing
-- | Find the binding for a method among a bag of bindings.
--
-- Returns the first 'FunBind' whose binder equals the given name, together
-- with the source span of that binder and the pragmas recorded for the
-- name in the pragma environment.  Returns 'Nothing' when no such binding
-- is present.
findMethodBind  :: Name
                -> LHsBinds GhcRn
                -> TcPragEnv
                -> Maybe (LHsBind GhcRn, SrcSpan, [LSig GhcRn])
findMethodBind sel_name binds prag_fn
  = firstJusts (map match_bind (bagToList binds))
  where
    sel_prags = lookupPragEnv prag_fn sel_name

    -- Only a function binding for exactly this name counts as a match
    match_bind lbind = case lbind of
      L _ (FunBind { fun_id = L bndr_loc bndr })
        | bndr == sel_name -> Just (lbind, bndr_loc, sel_prags)
      _                    -> Nothing
-- | Return the formula of the first 'MinimalSig' in the list, with the
-- source locations stripped from its atoms, or 'Nothing' if the list has
-- no such signature.
findMinimalDef :: [LSig GhcRn] -> Maybe ClassMinimalDef
findMinimalDef sigs = firstJusts [ minimal_of sig | sig <- sigs ]
  where
    minimal_of :: LSig GhcRn -> Maybe ClassMinimalDef
    minimal_of lsig = case lsig of
      L _ (MinimalSig _ _ (L _ formula)) -> Just (unLoc <$> formula)
      _                                  -> Nothing
-- | Context line naming the flavour of a type or class declaration and
-- the quoted name it declares.
tcMkDeclCtxt :: TyClDecl GhcRn -> SDoc
tcMkDeclCtxt decl
  = text "In the" <+> pprTyClDeclFlavour decl
    <+> text "declaration for" <+> quotes (ppr (tcdName decl))
-- | Run an action with the context line for the given declaration (as
-- built by 'tcMkDeclCtxt') added as error context.
tcAddDeclCtxt :: TyClDecl GhcRn -> TcM a -> TcM a
tcAddDeclCtxt decl thing_inside = addErrCtxt decl_ctxt thing_inside
  where
    decl_ctxt = tcMkDeclCtxt decl
-- | Error message for a method name that the given class does not have.
badMethodErr :: Outputable a => a -> Name -> SDoc
badMethodErr clas op
  = text "Class" <+> quotes (ppr clas)
    <+> text "does not have a method" <+> quotes (ppr op)
-- | Error message for a generic-default signature whose method has no
-- binding in the class declaration.
badGenericMethod :: Outputable a => a -> Name -> SDoc
badGenericMethod clas op
  = text "Class" <+> quotes (ppr clas)
    <+> text "has a generic-default signature without a binding"
    <+> quotes (ppr op)
-- | Add an error saying that the given pragma for a default method has no
-- accompanying binding.
badDmPrag :: TcId -> Sig GhcRn -> TcM ()
badDmPrag sel_id prag = addErrTc msg
  where
    msg = hsep [ text "The"
               , hsSigDoc prag
               , ptext (sLit "for default method")
               , quotes (ppr sel_id)
               , text "lacks an accompanying binding" ]
-- | Warning text for a MINIMAL pragma that leaves out something which has
-- no default implementation.  The argument is the formula that the pragma
-- fails to require.
warningMinimalDefIncomplete :: ClassMinimalDef -> SDoc
warningMinimalDefIncomplete mindef
  =  text "The MINIMAL pragma does not require:"
  $$ nest 2 (pprBooleanFormulaNice mindef)
  $$ text "but there is no default implementation."
-- | Instance-declaration context line built from the head of a
-- source-level instance type.
instDeclCtxt1 :: LHsSigType GhcRn -> SDoc
instDeclCtxt1 hs_inst_ty = inst_decl_ctxt (ppr inst_head)
  where
    inst_head = getLHsInstDeclHead hs_inst_ty
-- | Instance-declaration context line built from the type of a dictionary
-- function; only the class and its argument types are used.
instDeclCtxt2 :: Type -> SDoc
instDeclCtxt2 dfun_ty
  = let (_, _, cls, cls_tys) = tcSplitDFunTy dfun_ty
    in instDeclCtxt3 cls cls_tys
-- | Instance-declaration context line built from a class and the types it
-- is applied to.
instDeclCtxt3 :: Class -> [Type] -> SDoc
instDeclCtxt3 cls cls_tys = inst_decl_ctxt (ppr inst_pred)
  where
    inst_pred = mkClassPred cls cls_tys
-- | Shared layout for the instance-declaration context lines: a fixed
-- heading with the quoted document hanging beneath it, indented by 2.
inst_decl_ctxt :: SDoc -> SDoc
inst_decl_ctxt doc = hang heading 2 (quotes doc)
  where
    heading = text "In the instance declaration for"
-- | Produce the family instances that come from the default of one
-- associated type.
--
-- Arguments, in order: a source span (used as the location of the name of
-- the generated instance); a substitution; the set of names treated as
-- already defined; and the associated-type item, made of the family TyCon
-- and its optional default.
--
-- Returns an empty list when the family is in the defined set, a
-- singleton list holding the instance built from the default when there
-- is one, and otherwise an empty list after a call to warnMissingAT.
tcATDefault :: SrcSpan
            -> TCvSubst
            -> NameSet
            -> ClassATItem
            -> TcM [FamInst]
tcATDefault loc inst_subst defined_ats (ATI fam_tc defs)
  -- The family is in the set of defined names: nothing to generate
  | tyConName fam_tc `elemNameSet` defined_ats
  = return []
  
   -- Not in the defined set, but a default right-hand side exists:
   -- build the argument types from the type variables of the family
   -- (see subst_tv below), apply the extended substitution to the
   -- default right-hand side, and make a single-branch axiom from them
  | Just (rhs_ty, _loc) <- defs
  = do { let (subst', pat_tys') = mapAccumL subst_tv inst_subst
                                            (tyConTyVars fam_tc)
             rhs'     = substTyUnchecked subst' rhs_ty
             tcv' = tyCoVarsOfTypesList pat_tys'
             (tv', cv') = partition isTyVar tcv'
             tvs'     = toposortTyVars tv'
             cvs'     = toposortTyVars cv'
       ; rep_tc_name <- newFamInstTyConName (L loc (tyConName fam_tc)) pat_tys'
       ; let axiom = mkSingleCoAxiom Nominal rep_tc_name tvs' cvs'
                                     fam_tc pat_tys' rhs'
           -- The variables bound by the axiom are the free type and
           -- coercion variables of the argument types, each group put
           -- through toposortTyVars.
           -- NOTE(review): no validity check is run on the generated
           -- instance in this function; confirm it is checked elsewhere.
       ; traceTc "mk_deflt_at_instance" (vcat [ ppr fam_tc, ppr rhs_ty
                                              , pprCoAxiom axiom ])
       ; fam_inst <- newFamInst SynFamilyInst axiom
       ; return [fam_inst] }
   -- Neither in the defined set nor given a default: warn, generate nothing
  | otherwise  -- here defs did not match Just
  = do { warnMissingAT (tyConName fam_tc)
       ; return [] }
  where
    -- Accumulating step over the type variables of the family.  A
    -- variable already mapped by the substitution yields its image and
    -- leaves the substitution alone.  Any other variable is mapped to
    -- itself with the substitution applied to its kind, and that mapping
    -- is added to the substitution.
    subst_tv subst tc_tv
      | Just ty <- lookupVarEnv (getTvSubstEnv subst) tc_tv
      = (subst, ty)
      | otherwise
      = (extendTvSubst subst tc_tv ty', ty')
      where
        ty' = mkTyVarTy (updateTyVarKind (substTyUnchecked subst) tc_tv)
-- | Warn that the named associated type has neither an explicit
-- declaration nor a default.  The warning is only emitted when
-- Opt_WarnMissingMethods is on and the module is not an hsig file.
warnMissingAT :: Name -> TcM ()
warnMissingAT name
  = do { warn_on <- woptM Opt_WarnMissingMethods
       ; traceTc "warn" (ppr name <+> ppr warn_on)
       ; hsc_src <- tcg_src <$> getGblEnv
       ; let should_warn = warn_on && hsc_src /= HsigFile
       ; warnTc (Reason Opt_WarnMissingMethods) should_warn missing_msg }
  where
    missing_msg = hsep [ text "No explicit"
                       , text "associated type"
                       , text "or default declaration for"
                       , quotes (ppr name) ]