{-# LANGUAGE CPP                  #-}
{-# LANGUAGE DataKinds            #-}
{-# LANGUAGE FlexibleContexts     #-}
{-# LANGUAGE LambdaCase           #-}
{-# LANGUAGE RecordWildCards      #-}
{-# LANGUAGE UndecidableInstances #-}
{-# LANGUAGE ViewPatterns         #-}

{-# OPTIONS_GHC -Wno-orphans #-}

-- | Thin layer around ghc-tcplugin-api
module Data.Record.Anon.Internal.Plugin.TC.GhcTcPluginAPI (
    -- * Standard exports
    module GHC.TcPlugin.API
  , module GHC.Builtin.Names
  , module GHC.Builtin.Types
  , module GHC.Builtin.Types.Prim
  , module GHC.Core.Make
  , module GHC.Utils.Outputable

    -- * New functionality
  , isCanonicalVarEq
  , getModule
  ) where

import GHC.Stack

#if __GLASGOW_HASKELL__ < 900
import Data.List.NonEmpty (NonEmpty, toList)
#endif

import GHC.TcPlugin.API
import GHC.Builtin.Names
import GHC.Builtin.Types
import GHC.Builtin.Types.Prim
import GHC.Core.Make
import GHC.Utils.Outputable

#if __GLASGOW_HASKELL__ >= 808 &&  __GLASGOW_HASKELL__ < 810
import TcRnTypes (Ct(..))
#endif

#if __GLASGOW_HASKELL__ >= 810 &&  __GLASGOW_HASKELL__ < 900
import Constraint (Ct(..))
#endif

#if __GLASGOW_HASKELL__ >= 900 &&  __GLASGOW_HASKELL__ < 902
import GHC.Tc.Types.Constraint (Ct(..))
#endif

#if __GLASGOW_HASKELL__ >= 902
import GHC.Tc.Types.Constraint (Ct(..), CanEqLHS(..))
#endif

-- | Extract a @(type variable, type)@ pair from a canonical equality
-- constraint, or 'Nothing' for any other kind of constraint.
--
-- * GHC 8.8 up to (but not including) 9.2:
--
--     * a 'CTyEqCan' yields its 'cc_tyvar' together with its 'cc_rhs';
--     * a 'CFunEqCan' yields its 'cc_fsk' variable together with
--       @'mkTyConApp' cc_fun cc_tyargs@.
--
-- * GHC 9.2 and later ('CEqCan' with a 'CanEqLHS'):
--
--     * a 'TyVarLHS' left-hand side yields that variable together with
--       'cc_rhs';
--     * a 'TyFamLHS' left-hand side whose 'cc_rhs' is a bare type variable
--       yields that right-hand variable together with the type family
--       application (note the swapped orientation);
--     * a 'TyFamLHS' whose right-hand side is not a type variable yields
--       'Nothing'.
--
-- NOTE(review): the equality relation ('cc_eq_rel') is never inspected, so
-- representational equalities are returned too; confirm this is intended.
isCanonicalVarEq :: Ct -> Maybe (TcTyVar, Type)
#if __GLASGOW_HASKELL__ >= 808 &&  __GLASGOW_HASKELL__ < 902
isCanonicalVarEq = \case
    CTyEqCan{..}  -> Just (cc_tyvar, cc_rhs)
    CFunEqCan{..} -> Just (cc_fsk, mkTyConApp cc_fun cc_tyargs)
    _otherwise    -> Nothing
#endif
#if __GLASGOW_HASKELL__ >= 902
isCanonicalVarEq = \case
    CEqCan{..}
      | TyVarLHS var <- cc_lhs
      -> Just (var, cc_rhs)
      -- Family application on the left, variable on the right: report the
      -- variable, paired with the family application it equals.
      | TyFamLHS tyCon args <- cc_lhs
      , Just var            <- getTyVar_maybe cc_rhs
      -> Just (var, mkTyConApp tyCon args)
    -- Also reached when neither guard above succeeds.
    _otherwise
      -> Nothing
#endif

-- | Placeholder rendering for constraint locations: always prints the fixed
-- text @<CtLoc>@, ignoring the location itself.
--
-- TODO: Ideally we would actually show the location information.
instance Outputable CtLoc where
  ppr _ = text "<CtLoc>"

#if __GLASGOW_HASKELL__ < 900
-- | Render a non-empty list exactly as the equivalent ordinary list.
--
-- Orphan instance; only defined when compiling with GHC older than 9.0.
instance Outputable a => Outputable (NonEmpty a) where
  ppr xs = ppr (toList xs)
#endif

#if __GLASGOW_HASKELL__ >= 902
-- | Render a located value as @(L \<location\> \<value\>)@.
--
-- Orphan instance; only defined when compiling with GHC 9.2 or later.
instance (Outputable l, Outputable e) => Outputable (GenLocated l e) where
  ppr (L l e) = parens $ text "L" <+> ppr l <+> ppr e
#endif

-- | Look up a 'Module' given its package name and module name.
--
-- The import is first resolved with 'resolveImport' (qualified by the given
-- package name) and then located with 'findImportedModule'.
--
-- Calls 'error' (reporting the 'HasCallStack' call site) if the lookup does
-- not return 'Found'.
getModule ::
     (HasCallStack, MonadTcPlugin m)
  => String    -- ^ Package name
  -> String    -- ^ Module name
  -> m Module
getModule pkg modl = do
    let modl' = mkModuleName modl
    pkg' <- resolveImport      modl' (Just $ fsLit pkg)
    res  <- findImportedModule modl' pkg'
    case res of
      Found _ m  -> pure m
      _otherwise -> error $ concat [
          "getModule: could not find "
        , modl
        , " in package "
        , pkg
        ]