{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE PatternSynonyms #-}

{-# OPTIONS_HADDOCK show-extensions #-}

module Internal
  ( -- * Create new constraints

    newWanted
  , newGiven
  , newDerived
    -- * Creating evidence

  , evByFiat
    -- * Lookup

  , lookupModule
  , lookupName
    -- * Trace state of the plugin

  , tracePlugin
    -- * Substitutions

  , flattenGivens
  , mkSubst
  , mkSubst'
  , substType
  , substCt
  )
where

import GHC.Tc.Plugin (TcPluginM, lookupOrig, tcPluginTrace)
import qualified GHC.Tc.Plugin as TcPlugin
    (newDerived, newWanted, getTopEnv, tcPluginIO, findImportedModule)
import GHC.Tc.Types (TcPlugin(..), TcPluginResult(..))
import Control.Arrow (first, second)
import Data.Function (on)
import Data.List (groupBy, partition, sortOn)
import GHC.Tc.Utils.TcType (TcType)
import Data.Maybe (mapMaybe)

import GhcApi.Constraint (Ct(..), CtEvidence(..), CtLoc)
import GhcApi.GhcPlugins

import Internal.Type (substType)
import Internal.Constraint (newGiven, flatToCt, mkSubst, overEvidencePredType)
import Internal.Evidence (evByFiat)

{-# ANN fr_mod "HLint: ignore Use camelCase" #-}

-- | Matches a 'FindResult' built with 'Found', binding the 'Module' it holds.
-- This is a match-only (unidirectional) synonym: it cannot construct a
-- 'FindResult'.
pattern FoundModule :: Module -> FindResult
pattern FoundModule a <- Found _ a

-- | Identity function, applied to every module that 'lookupModule' returns.
--
-- NOTE(review): this looks like a compatibility shim, so the call sites read
-- the same as in other GHC-version variants of this module. Confirm before
-- removing.
fr_mod :: a -> a
fr_mod = id

-- | Create a new [W]anted constraint at the given location for the given
-- predicate type.
--
-- Delegates directly to @newWanted@ from "GHC.Tc.Plugin".
newWanted :: CtLoc -> PredType -> TcPluginM CtEvidence
newWanted = TcPlugin.newWanted

-- | Create a new [D]erived constraint at the given location for the given
-- predicate type.
--
-- Delegates directly to @newDerived@ from "GHC.Tc.Plugin".
newDerived :: CtLoc -> PredType -> TcPluginM CtEvidence
newDerived = TcPlugin.newDerived

-- | Find a module.
--
-- The lookup runs in two steps:
--
-- 1. Search with 'findPluginModule'.
-- 2. If that fails, fall back to 'TcPlugin.findImportedModule' with the
--    @\"this\"@ package qualifier, which GHC resolves to the home unit.
--
-- Panics (via 'panicDoc') when neither step finds the module.
lookupModule :: ModuleName -- ^ Name of the module
             -> FastString -- ^ Name of the package containing the module.
                           -- NOTE: This value is ignored on ghc>=8.0.
             -> TcPluginM Module
lookupModule mod_nm _pkg = do
  hsc_env <- TcPlugin.getTopEnv
  pluginResult <- TcPlugin.tcPluginIO (findPluginModule hsc_env mod_nm)
  case pluginResult of
    FoundModule h -> pure (fr_mod h)
    _ -> do
      -- Fallback: ordinary import search qualified with the "this" package.
      homeResult <- TcPlugin.findImportedModule mod_nm (Just (fsLit "this"))
      case homeResult of
        FoundModule h -> pure (fr_mod h)
        _             -> panicDoc "Couldn't find module" (ppr mod_nm)

-- | Find a 'Name' in a 'Module' given an 'OccName'.
--
-- This is 'lookupOrig' from "GHC.Tc.Plugin", exported under a stable name.
lookupName :: Module -> OccName -> TcPluginM Name
lookupName = lookupOrig

-- | Print out extra information about the initialisation, stop, and every run
-- of the plugin when @-ddump-tc-trace@ is enabled.
--
-- Each stage of the wrapped plugin is still run with the same arguments, and
-- its result is returned unchanged. The wrapper only adds 'tcPluginTrace'
-- calls around each stage, and every trace header ends with the given
-- 'String' so the plugin can be identified.
tracePlugin :: String -> TcPlugin -> TcPlugin
tracePlugin s TcPlugin{..} = TcPlugin { tcPluginInit  = traceInit
                                      , tcPluginSolve = traceSolve
                                      , tcPluginStop  = traceStop
                                      }
  where
    -- No local type signatures here: the plugin state type is existentially
    -- bound by the 'TcPlugin' pattern and cannot be named in this scope.
    traceInit = do
      -- workaround for https://ghc.haskell.org/trac/ghc/ticket/10301
      initializeStaticFlags
      tcPluginTrace ("tcPluginInit " ++ s) empty
      tcPluginInit

    traceStop z = do
      tcPluginTrace ("tcPluginStop " ++ s) empty
      tcPluginStop z

    traceSolve z given derived wanted = do
      tcPluginTrace ("tcPluginSolve start " ++ s)
                    (text "given   =" <+> ppr given
                  $$ text "derived =" <+> ppr derived
                  $$ text "wanted  =" <+> ppr wanted)
      r <- tcPluginSolve z given derived wanted
      traceResult r
      pure r

    -- Report the outcome of a single solver run.
    traceResult (TcPluginOk solved new) =
      tcPluginTrace ("tcPluginSolve ok " ++ s)
                    (text "solved =" <+> ppr solved
                  $$ text "new    =" <+> ppr new)
    traceResult (TcPluginContradiction bad) =
      tcPluginTrace ("tcPluginSolve contradiction " ++ s)
                    (text "bad =" <+> ppr bad)

-- Workaround for https://ghc.haskell.org/trac/ghc/ticket/10301.
-- In this version it is a deliberate no-op. 'tracePlugin' still calls it
-- before tracing plugin initialisation.
initializeStaticFlags :: TcPluginM ()
initializeStaticFlags = pure ()

-- | Flattens evidence of constraints by substituting each other's equalities.
--
-- The substitution entries built by 'mkSubst'' are grouped by the type
-- variable they bind:
--
-- * each group with two or more entries for the same variable is passed to
--   'flatToCt', and the constraints it produces are kept;
-- * the entries of all remaining (single-entry) variables form a
--   substitution, which is applied with 'substCt' to every given.
--
-- The result is the 'flatToCt' constraints followed by the substituted givens.
--
-- __NB:__ Should only be used on /[G]iven/ constraints!
--
-- __NB:__ Doesn't flatten under binders
flattenGivens :: [Ct] -> [Ct]
flattenGivens givens =
  mapMaybe flatToCt shared ++ map (substCt subst') givens
 where
  -- The type variable a substitution entry binds; used as the grouping key.
  tyVarOf = fst . fst

  (shared, subst')
    = second (map fst . concat)
    $ partition hasMultiple
    $ groupBy ((==) `on` tyVarOf)
    $ sortOn tyVarOf
    $ mkSubst' givens

  -- At least two elements, without traversing the whole group as 'length'
  -- would.
  hasMultiple (_:_:_) = True
  hasMultiple _       = False

-- | Create flattened substitutions from type equalities, i.e. the substitutions
-- have been applied to each other's right-hand sides.
--
-- Constraints for which 'mkSubst' yields 'Nothing' are dropped. The list is
-- built right to left. When an entry @(tv, t)@ is added in front of the
-- already-built entries:
--
-- * it gets those entries' bindings applied to @t@;
-- * its own original (unsubstituted) binding @[(tv, t)]@ is applied to their
--   right-hand sides.
mkSubst' :: [Ct] -> [((TcTyVar,TcType),Ct)]
mkSubst' = foldr substSubst [] . mapMaybe mkSubst
 where
  substSubst :: ((TcTyVar,TcType),Ct)
             -> [((TcTyVar,TcType),Ct)]
             -> [((TcTyVar,TcType),Ct)]
  substSubst ((tv,t),ct) rest =
    ((tv, substType (map fst rest) t), ct)
      : map (first (second (substType [(tv,t)]))) rest

-- | Apply a substitution to the evidence of a 'Ct'.
--
-- 'substType' is applied to the type that 'overEvidencePredType' exposes.
substCt :: [(TcTyVar, TcType)] -> Ct -> Ct
substCt subst = overEvidencePredType (substType subst)