{-# LANGUAGE CPP #-}

module Wingman.Context where

import           Control.Arrow
import           Control.Monad.Reader
import           Data.Coerce (coerce)
import           Data.Foldable.Extra (allM)
import           Data.Maybe (fromMaybe, isJust, mapMaybe)
import qualified Data.Set as S
import           Development.IDE.GHC.Compat
import           Development.IDE.GHC.Compat.Util
import           Wingman.GHC (normalizeType)
import           Wingman.Judgements.Theta
import           Wingman.Types

#if __GLASGOW_HASKELL__ >= 900
import GHC.Tc.Utils.TcType
#endif


------------------------------------------------------------------------------
-- | Assemble the 'Context' for a Wingman run from typechecker output.
--
-- The resulting record carries:
--
--   * the caller-supplied local bindings ('ctxDefiningFuncs') and the
--     functions/class methods defined in the current module
--     ('ctxModuleFuncs'), both with their types passed through
--     'normalizeType';
--   * instance and family-instance environments merged from the external
--     package state and the current module;
--   * the given constraints derived from the supplied 'Evidence'.
--
-- 'normalizeType' itself takes a 'Context', so the record is defined in
-- terms of itself (a lazy knot). NOTE(review): presumably normalisation
-- only reads fields that do not depend on the normalised function lists
-- (e.g. 'ctxFamInstEnvs'); forcing those lists from inside
-- 'normalizeType' would loop — confirm in Wingman.GHC.
mkContext
    :: Config
    -> [(OccName, CType)]
    -> TcGblEnv
    -> HscEnv
    -> ExternalPackageState
    -> [Evidence]
    -> Context
mkContext cfg locals tcg hscenv eps ev = ctx
  where
    -- Self-referential binding; equivalent to @fix (\ctx -> Context {..})@.
    ctx = Context
      { ctxDefiningFuncs = fmap (second normalize) locals
      , ctxModuleFuncs   = fmap (second normalize . splitId) moduleIds
      , ctxConfig        = cfg
      , ctxFamInstEnvs   = (eps_fam_inst_env eps, tcg_fam_inst_env tcg)
      , ctxInstEnvs      =
          InstEnvs
            (eps_inst_env eps)
            (tcg_inst_env tcg)
            (tcVisibleOrphanMods tcg)
      , ctxTheta         = evidenceToThetaType ev
      , ctx_hscEnv       = hscenv
      , ctx_occEnv       = tcg_rdr_env tcg
      , ctx_module       = extractModule tcg
      }

    -- Normalise the 'Type' wrapped inside a 'CType' using the finished
    -- context.
    normalize :: CType -> CType
    normalize = coerce $ normalizeType ctx

    -- Class methods declared in this module, followed by the polymorphic
    -- Ids exported from every top-level 'AbsBinds' in the typechecked
    -- bindings.
    moduleIds :: [Id]
    moduleIds =
      locallyDefinedMethods tcg
        <> concatMap (getFunBindId . unLoc) (bagToList $ tcg_binds tcg)


------------------------------------------------------------------------------
-- | All class-method 'Id's belonging to the classes found among the
-- module's 'tcg_tcs'.
--
-- Type constructors that are not classes are skipped. The result keeps
-- 'tcg_tcs' order, with each class's 'classMethods' in their own order.
locallyDefinedMethods :: TcGblEnv -> [Id]
locallyDefinedMethods tcg =
  [ method
  | tc          <- tcg_tcs tcg
  , Just cls    <- [tyConClass_maybe tc]
  , method      <- classMethods cls
  ]



------------------------------------------------------------------------------
-- | Pair an 'Id' with its 'OccName' and its type wrapped in 'CType'.
splitId :: Id -> (OccName, CType)
splitId ident = (occName ident, CType (idType ident))


------------------------------------------------------------------------------
-- | The polymorphic 'Id's exported by an 'AbsBinds' group.
--
-- Only 'ABE' exports contribute; any other export constructor is skipped.
-- Every binding form other than 'AbsBinds' yields no 'Id's.
getFunBindId :: HsBindLR GhcTc GhcTc -> [Id]
getFunBindId (AbsBinds _ _ _ exports _ _ _) =
  -- A failed pattern match in a list comprehension drops that element.
  [ polyId | ABE _ polyId _ _ _ <- exports ]
getFunBindId _ = []


------------------------------------------------------------------------------
-- | Look for an instance of the given 'Class' at the given argument types.
--
-- Only the first match reported by 'lookupInstEnv' is considered. Its
-- dictionary function's type is instantiated at the matched types, and any
-- type variable left unconstrained by the match is filled with one of
-- 'alphaTys'. Every constraint in the instance's context must then itself
-- be satisfiable according to 'hasClassInstance'.
--
-- On success, returns the class together with the instantiated dictionary
-- type (the dfun's result once its context has been split off).
--
-- NOTE(review): the recursion through 'hasClassInstance' has no depth
-- bound, so an instance whose context does not shrink (possible under
-- UndecidableInstances) could make this loop — confirm whether that
-- matters for callers.
getInstance :: MonadReader Context m => Class -> [Type] -> m (Maybe (Class, PredType))
getInstance cls tys = do
  instEnvs <- asks ctxInstEnvs
  let (matches, _unifiers, _unsafe) = lookupInstEnv False instEnvs cls tys
  case matches of
    [] -> pure Nothing
    (inst, mapps) : _ -> do
      let instArgs         = zipWith fromMaybe alphaTys mapps
          dfunTy           = piResultTys (idType $ is_dfun inst) instArgs
          (context, dictTy) = tcSplitPhiTy dfunTy
      satisfiable <- allM hasClassInstance context
      pure $ if satisfiable then Just (cls, dictTy) else Nothing


------------------------------------------------------------------------------
-- | Like 'getInstance', but only reports whether it succeeded.
--
-- Fast path: if the predicate is already one of the given constraints
-- cached in 'ctxTheta', the answer is 'True' without any instance lookup.
-- Otherwise the predicate is split into a class and its arguments and
-- handed to 'getInstance'.
--
-- Predicates that are not a type-constructor application (for example a
-- constraint headed by a type variable, or a quantified constraint), or
-- whose head is not a class, are reported as 'False'. The previous version
-- used the partial 'tcSplitTyConApp', which panics on such predicates.
hasClassInstance :: MonadReader Context m => PredType -> m Bool
hasClassInstance predty = do
  theta <- asks ctxTheta
  if S.member (CType predty) theta
    then pure True
    else case tcSplitTyConApp_maybe predty of
      Just (con, apps)
        | Just cls <- tyConClass_maybe con
        -> isJust <$> getInstance cls apps
      -- Not a class constraint we know how to look up.
      _ -> pure False