{-|
Copyright  :  (C) 2015-2016, University of Twente
License    :  BSD2 (see the file LICENSE)
Maintainer :  Christiaan Baaij <christiaan.baaij@gmail.com>

To use the plugin, add the

@
{\-\# OPTIONS_GHC -fplugin GHC.TypeLits.Extra.Solver \#-\}
@

pragma to the header of your file

-}

{-# LANGUAGE CPP           #-}
{-# LANGUAGE TupleSections #-}

{-# OPTIONS_HADDOCK show-extensions #-}

module GHC.TypeLits.Extra.Solver
  ( plugin )
where

-- external
import Control.Monad.Trans.Maybe (MaybeT (..))
import Data.Maybe                (catMaybes)
import GHC.TcPluginM.Extra       (evByFiat, lookupModule, lookupName
                                 ,tracePlugin, newWanted)
#if MIN_VERSION_ghc(8,4,0)
import GHC.TcPluginM.Extra (flattenGivens)
#else
import Control.Monad ((<=<))
#endif

-- GHC API
#if MIN_VERSION_ghc(9,0,0)
import GHC.Builtin.Names (eqPrimTyConKey, hasKey)
import GHC.Builtin.Types (promotedTrueDataCon, promotedFalseDataCon)
#if MIN_VERSION_ghc(9,2,0)
import GHC.Builtin.Types (boolTy, naturalTy)
#else
import GHC.Builtin.Types (typeNatKind)
#endif
import GHC.Builtin.Types.Literals (typeNatTyCons)
#if MIN_VERSION_ghc(9,2,0)
import GHC.Builtin.Types.Literals (typeNatCmpTyCon)
#else
import GHC.Builtin.Types.Literals (typeNatLeqTyCon)
#endif
import GHC.Core.Predicate (EqRel (NomEq), Pred (EqPred), classifyPredType)
import GHC.Core.TyCo.Rep (Type (..))
import GHC.Core.Type (Kind, eqType, mkTyConApp, splitTyConApp_maybe, typeKind)
import GHC.Data.FastString (fsLit)
import GHC.Driver.Plugins (Plugin (..), defaultPlugin, purePlugin)
import GHC.Tc.Plugin (TcPluginM, tcLookupTyCon, tcPluginTrace)
import GHC.Tc.Types (TcPlugin(..), TcPluginResult (..))
import GHC.Tc.Types.Constraint
  (Ct, ctEvidence, ctEvPred, ctLoc, isWantedCt, cc_ev)
#if MIN_VERSION_ghc(9,2,0)
import GHC.Tc.Types.Constraint (Ct (CQuantCan), qci_ev)
#endif
import GHC.Tc.Types.Evidence (EvTerm)
import GHC.Types.Name.Occurrence (mkTcOcc)
import GHC.Unit.Module (mkModuleName)
import GHC.Utils.Outputable (Outputable (..), (<+>), ($$), text)
#else
import FastString (fsLit)
import Module     (mkModuleName)
import OccName    (mkTcOcc)
import Outputable (Outputable (..), (<+>), ($$), text)
import Plugins    (Plugin (..), defaultPlugin)
#if MIN_VERSION_ghc(8,6,0)
import Plugins    (purePlugin)
#endif
import PrelNames  (eqPrimTyConKey, hasKey)
import TcEvidence (EvTerm)
import TcPluginM  (TcPluginM, tcLookupTyCon, tcPluginTrace)
import TcRnTypes  (TcPlugin(..), TcPluginResult (..))
import Type       (Kind, eqType, mkTyConApp, splitTyConApp_maybe)
import TyCoRep    (Type (..))
import TysWiredIn (typeNatKind, promotedTrueDataCon, promotedFalseDataCon)
import TcTypeNats (typeNatLeqTyCon)
#if MIN_VERSION_ghc(8,4,0)
import TcTypeNats (typeNatTyCons)
#else
import TcPluginM  (zonkCt)
#endif

#if MIN_VERSION_ghc(8,10,0)
import Constraint (Ct, ctEvidence, ctEvPred, ctLoc, isWantedCt, cc_ev)
import Predicate  (EqRel (NomEq), Pred (EqPred), classifyPredType)
import Type       (typeKind)
#else
import TcRnTypes  (Ct, ctEvidence, ctEvPred, ctLoc, isWantedCt, cc_ev)
import TcType     (typeKind)
import Type       (EqRel (NomEq), PredTree (EqPred), classifyPredType)
#endif
#endif

-- internal
import GHC.TypeLits.Extra.Solver.Operations
import GHC.TypeLits.Extra.Solver.Unify

#if MIN_VERSION_ghc(9,2,0)
-- | The kind that 'toSolverConstraint' compares against to recognise
-- type-level naturals. From GHC 9.2 onwards this module no longer imports
-- @typeNatKind@ from GHC (only the pre-9.2 import branch does), so it is
-- defined locally as an alias for 'naturalTy'.
typeNatKind :: Type
typeNatKind = naturalTy
#endif

-- | A solver implemented as a type-checker plugin for:
--
--     * 'Div': type-level 'div'
--
--     * 'Mod': type-level 'mod'
--
--     * 'FLog': type-level equivalent of <https://hackage.haskell.org/package/base-4.17.0.0/docs/GHC-Integer-Logarithms.html#v:integerLogBase-35- integerLogBase#>
--       i.e. the exact integer equivalent to "@'floor' ('logBase' x y)@"
--
--     * 'CLog': type-level equivalent of /the ceiling of/ <https://hackage.haskell.org/package/base-4.17.0.0/docs/GHC-Integer-Logarithms.html#v:integerLogBase-35- integerLogBase#>
--       i.e. the exact integer equivalent to "@'ceiling' ('logBase' x y)@"
--
--     * 'Log': type-level equivalent of <https://hackage.haskell.org/package/base-4.17.0.0/docs/GHC-Integer-Logarithms.html#v:integerLogBase-35- integerLogBase#>
--        where the operation only reduces when "@'floor' ('logBase' b x) ~ 'ceiling' ('logBase' b x)@"
--
--     * 'GCD': a type-level 'gcd'
--
--     * 'LCM': a type-level 'lcm'
--
-- To use the plugin, add
--
-- @
-- {\-\# OPTIONS_GHC -fplugin GHC.TypeLits.Extra.Solver \#-\}
-- @
--
-- to the header of your file.
plugin :: Plugin
plugin
  = defaultPlugin
  { -- Command-line options passed to the plugin are ignored.
    tcPlugin = const $ Just normalisePlugin
#if MIN_VERSION_ghc(8,6,0)
  , pluginRecompile = purePlugin
#endif
  }

-- | The type-checker plugin itself, wrapped in 'tracePlugin' under the name
-- @"ghc-typelits-extra"@.
--
-- Initialisation resolves the 'ExtraDefs' through 'lookupExtraDefs', solving
-- is delegated to 'decideEqualSOP', and stopping does nothing.
normalisePlugin :: TcPlugin
normalisePlugin = tracePlugin "ghc-typelits-extra"
  TcPlugin { tcPluginInit  = lookupExtraDefs
           , tcPluginSolve = decideEqualSOP
           , tcPluginStop  = const (return ())
           }

-- | Solver entry point installed as 'tcPluginSolve'.
--
-- Arguments, in order: the resolved 'ExtraDefs', the given constraints, the
-- derived constraints (ignored) and the wanted constraints.
--
-- * With no wanteds at all, or when none of the wanteds can be turned into a
--   'SolverConstraint', the result is @TcPluginOk [] []@.
--
-- * Otherwise the givens are converted as well and everything is handed to
--   'simplifyExtra'. A 'Simplified' result becomes a 'TcPluginOk' carrying
--   only the evidence that belongs to wanted constraints plus the newly
--   created constraints; an 'Impossible' result becomes a
--   'TcPluginContradiction' for the offending constraint.
decideEqualSOP :: ExtraDefs -> [Ct] -> [Ct] -> [Ct] -> TcPluginM TcPluginResult
decideEqualSOP _    _givens _deriveds []      = return (TcPluginOk [] [])
decideEqualSOP defs givens  _deriveds wanteds = do
  -- GHC 7.10.1 puts deriveds with the wanteds, so filter them out
  let wanteds' = filter isWantedCt wanteds
  unit_wanteds <- catMaybes <$> mapM (runMaybeT . toSolverConstraint defs) wanteds'
  case unit_wanteds of
    [] -> return (TcPluginOk [] [])
    _  -> do
#if MIN_VERSION_ghc(8,4,0)
      -- Both the givens as received and their 'flattenGivens' form are
      -- offered to the solver.
      unit_givens <- catMaybes <$> mapM (runMaybeT . toSolverConstraint defs) (givens ++ flattenGivens givens)
#else
      unit_givens <- catMaybes <$> mapM ((runMaybeT . toSolverConstraint defs) <=< zonkCt) givens
#endif
      sr <- simplifyExtra defs (unit_givens ++ unit_wanteds)
      tcPluginTrace "normalised" (ppr sr)
      case sr of
        -- Givens may have been "solved" along the way; only report wanteds.
        Simplified evs new -> return (TcPluginOk (filter (isWantedCt . snd) evs) new)
        Impossible eq      -> return (TcPluginContradiction [fromSolverConstraint eq])

-- A constraint in the form the solver works on. Both constructors keep the
-- originating 'Ct' (handed back by 'fromSolverConstraint') and a 'Normalised'
-- flag obtained by merging the flags of the two normalised operands.
--
--   * NatEquality ct lhs rhs norm
--       an equality between the two normalised sides of 'ct'.
--
--   * NatInequality ct x y b norm
--       the comparison "x <= y" equated with the promoted boolean 'b'
--       (True or False), as recognised by 'toSolverConstraint'.
data SolverConstraint
   = NatEquality Ct ExtraOp ExtraOp Normalised
   | NatInequality Ct ExtraOp ExtraOp Bool Normalised

-- Pretty-printer used for the plugin's trace output. Each component is put
-- on its own line below the constructor name. The originating 'Ct' is
-- printed for 'NatEquality' only.
instance Outputable SolverConstraint where
  ppr (NatEquality ct op1 op2 norm) =
    text "NatEquality" $$ ppr ct $$ ppr op1 $$ ppr op2 $$ ppr norm
  ppr (NatInequality _ op1 op2 b norm) =
    text "NatInequality" $$ ppr op1 $$ ppr op2 $$ ppr b $$ ppr norm

-- Outcome of 'simplifyExtra'.
--
--   * Simplified evs new
--       'evs' pairs each solved constraint with its evidence term; 'new'
--       lists the constraints that were created while solving.
--
--   * Impossible sct
--       'sct' was found to be unsatisfiable; 'decideEqualSOP' reports it as
--       a contradiction.
data SimplifyResult
  = Simplified [(EvTerm,Ct)] [Ct]
  | Impossible SolverConstraint

-- Pretty-printer used for the plugin's trace output.
instance Outputable SimplifyResult where
  ppr (Simplified evs new) =
    text "Simplified" $$ text "Solved:" $$ ppr evs $$ text "New:" $$ ppr new
  ppr (Impossible sct) =
    text "Impossible" <+> ppr sct

-- | Try to solve a list of 'SolverConstraint's (givens and wanteds mixed),
-- processing them front to back.
--
-- The result is either 'Simplified', with the evidence for every constraint
-- that could be discharged and the constraints created along the way, or
-- 'Impossible' for the first constraint found to be unsatisfiable.
simplifyExtra :: ExtraDefs -> [SolverConstraint] -> TcPluginM SimplifyResult
simplifyExtra defs eqs = tcPluginTrace "simplifyExtra" (ppr eqs) >> simples [] [] eqs
  where
    -- Pair a constraint with its evidence; 'Nothing' when 'evMagic' cannot
    -- build evidence for it. The 'Nothing's are dropped at the very end.
    solved :: Ct -> Maybe (EvTerm, Ct)
    solved ct = (, ct) <$> evMagic ct

    -- Worker: accumulated evidence, accumulated new constraints, and the
    -- constraints still to be processed.
    simples :: [Maybe (EvTerm, Ct)] -> [Ct] -> [SolverConstraint] -> TcPluginM SimplifyResult
    simples evs news [] = return (Simplified (catMaybes evs) news)
    simples evs news (eq@(NatEquality ct u v norm):eqs') = do
      ur <- unifyExtra ct u v
      tcPluginTrace "unifyExtra result" (ppr ur)
      case ur of
        Win                          -> simples (solved ct : evs) news eqs'
        -- A failed unification is only reported as impossible when nothing
        -- has been solved so far and no constraints remain.
        Lose | null evs && null eqs' -> return (Impossible eq)
        -- Normalisation changed a wanted: discharge the original and emit
        -- the normalised form as a new wanted instead.
        _ | norm == Normalised && isWantedCt ct -> do
          newCt <- createWantedFromNormalised defs eq
          simples (solved ct : evs) (newCt : news) eqs'
        Lose -> simples evs news eqs'
        Draw -> simples evs news eqs'
    simples evs news (eq@(NatInequality ct u v b norm):eqs') = do
      tcPluginTrace "unifyExtra leq result" (ppr (u,v,b))
      case (u,v) of
        -- Two literals: decide the comparison outright.
        (I i,I j)
          | (i <= j) == b -> simples (solved ct : evs) news eqs'
          | otherwise     -> return  (Impossible eq)
        -- p <= Max x y holds whenever p is one of the operands.
        (p, Max x y)
          | b && (p == x || p == y) -> simples (solved ct : evs) news eqs'

        -- transform:  q ~ Max x y => (p <=? q ~ True)
        -- to:         (p <=? Max x y) ~ True
        -- and try to solve that along with the rest of the eqs'
        (p, q@(V _))
          | b -> case findMax q eqs of
                   Just m  -> simples evs news (NatInequality ct p m b norm : eqs')
                   Nothing -> simples evs news eqs'
        _ | norm == Normalised && isWantedCt ct -> do
          newCt <- createWantedFromNormalised defs eq
          simples (solved ct : evs) (newCt : news) eqs'
        _ -> simples evs news eqs'

    -- look for given constraint with the form: c ~ Max x y
    -- (searches the complete input list 'eqs', not just the remainder)
    findMax :: ExtraOp -> [SolverConstraint] -> Maybe ExtraOp
    findMax c = go
      where
        go :: [SolverConstraint] -> Maybe ExtraOp
        go [] = Nothing
        go ((NatEquality ct a b@(Max _ _) _) :_)
          | c == a && not (isWantedCt ct)
            = Just b
        go ((NatEquality ct a@(Max _ _) b _) :_)
          | c == b && not (isWantedCt ct)
            = Just a
        go (_:rest) = go rest


-- | Extract the Nat equality and inequality constraints.
--
-- Only nominal equality predicates are considered:
--
-- * if either side has kind 'typeNatKind', both sides are normalised and a
--   'NatEquality' is produced;
--
-- * if the left-hand side is a @<=?@ comparison (spelled through the
--   'ordTyCon' applied to a 'typeNatCmpTyCon' comparison on GHC >= 9.2, and
--   through 'typeNatLeqTyCon' before that) and the right-hand side is the
--   promoted @True@ or @False@, a 'NatInequality' is produced.
--
-- Everything else, and any failure inside 'normaliseNat', makes the 'MaybeT'
-- computation fail.
toSolverConstraint :: ExtraDefs -> Ct -> MaybeT TcPluginM SolverConstraint
toSolverConstraint defs ct = case classifyPredType $ ctEvPred $ ctEvidence ct of
    EqPred NomEq t1 t2
      | isNatKind (typeKind t1) || isNatKind (typeKind t2)
      -> do
         (t1', n1) <- normaliseNat defs t1
         (t2', n2) <- normaliseNat defs t2
         pure (NatEquality ct t1' t2' (mergeNormalised n1 n2))
#if MIN_VERSION_ghc(9,2,0)
      -- Matches: ordTyCon _ (cmp x y) 'True 'True 'False
      | TyConApp tc [_,cmpNat,TyConApp tt1 [],TyConApp tt2 [],TyConApp ff1 []] <- t1
      , tc == ordTyCon defs
      , TyConApp cmpNatTc [x,y] <- cmpNat
      , cmpNatTc == typeNatCmpTyCon
      , tt1 == promotedTrueDataCon
      , tt2 == promotedTrueDataCon
      , ff1 == promotedFalseDataCon
#else
      | TyConApp tc [x,y] <- t1
      , tc == typeNatLeqTyCon
#endif
      , TyConApp tc' [] <- t2
      -> do
          (x', n1) <- normaliseNat defs x
          (y', n2) <- normaliseNat defs y
          let res | tc' == promotedTrueDataCon  = pure (NatInequality ct x' y' True  (mergeNormalised n1 n2))
                  | tc' == promotedFalseDataCon = pure (NatInequality ct x' y' False (mergeNormalised n1 n2))
                  | otherwise                   = fail "Nothing"
          res
    _ -> fail "Nothing"
  where
    isNatKind :: Kind -> Bool
    isNatKind = (`eqType` typeNatKind)

-- | Build a new wanted constraint that states the normalised form of the
-- given 'SolverConstraint'.
--
-- The two sides are reified back to 'Type's with 'reifyEOP' (for an
-- inequality the left side is the rebuilt comparison and the right side the
-- promoted boolean). They replace the last two arguments of the original
-- predicate, which must be an application of the primitive equality tycon to
-- four arguments; otherwise the action 'fail's in 'TcPluginM'. The fresh
-- evidence is created at the location of the original constraint and stored
-- in a copy of that constraint.
createWantedFromNormalised :: ExtraDefs -> SolverConstraint -> TcPluginM Ct
createWantedFromNormalised defs sct = do
  let extractCtSides (NatEquality ct t1 t2 _)   = (ct, reifyEOP defs t1, reifyEOP defs t2)
      extractCtSides (NatInequality ct x y b _) =
        let tc = if b then promotedTrueDataCon else promotedFalseDataCon
#if MIN_VERSION_ghc(9,2,0)
            -- Same shape that 'toSolverConstraint' matches on.
            t1 = TyConApp (ordTyCon defs)
                    [ boolTy
                    , TyConApp typeNatCmpTyCon [reifyEOP defs x, reifyEOP defs y]
                    , TyConApp promotedTrueDataCon []
                    , TyConApp promotedTrueDataCon []
                    , TyConApp promotedFalseDataCon []
                    ]
#else
            t1 = TyConApp typeNatLeqTyCon [reifyEOP defs x, reifyEOP defs y]
#endif
            t2 = TyConApp tc []
          in (ct, t1, t2)
  let (ct, t1, t2) = extractCtSides sct
  newPredTy <- case splitTyConApp_maybe $ ctEvPred $ ctEvidence ct of
    -- Keep the first two arguments of the original predicate; swap the sides.
    Just (tc, [a, b, _, _]) | tc `hasKey` eqPrimTyConKey -> pure (mkTyConApp tc [a, b, t1, t2])
    _ -> fail "Nothing"
  ev <- newWanted (ctLoc ct) newPredTy
  let ctN = case ct of
#if MIN_VERSION_ghc(9,2,0)
              -- Quantified constraints keep their evidence in 'qci_ev'.
              CQuantCan qc -> CQuantCan (qc { qci_ev = ev})
#endif
              ctX -> ctX { cc_ev = ev }
  return ctN

-- | Recover the GHC constraint that a 'SolverConstraint' was built from.
fromSolverConstraint :: SolverConstraint -> Ct
fromSolverConstraint (NatEquality   ct _ _ _)   = ct
fromSolverConstraint (NatInequality ct _ _ _ _) = ct

-- | Resolve all type constructors the solver needs and pack them into an
-- 'ExtraDefs' (eleven 'ExtraDefs' arguments, in the order written below).
--
-- The names @Max@, @Min@, @FLog@, @CLog@, @Log@, @GCD@ and @LCM@ are looked
-- up in module @GHC.TypeLits.Extra@ of package @ghc-typelits-extra@. On
-- GHC >= 9.2 @OrdCond@ is looked up in @Data.Type.Ord@ of package @base@.
lookupExtraDefs :: TcPluginM ExtraDefs
lookupExtraDefs = do
    md <- lookupModule myModule myPackage
#if MIN_VERSION_ghc(9,2,0)
    md2 <- lookupModule ordModule basePackage
#endif
    ExtraDefs <$> look md "Max"
              <*> look md "Min"
#if MIN_VERSION_ghc(8,4,0)
              -- NOTE(review): positional picks from GHC's 'typeNatTyCons';
              -- presumably the built-in Div and Mod, since the #else branch
              -- fills the same two slots with "Div" and "Mod" - confirm the
              -- indices whenever a new GHC version is supported.
              <*> pure (typeNatTyCons !! 5)
              <*> pure (typeNatTyCons !! 6)
#else
              <*> look md "Div"
              <*> look md "Mod"
#endif
              <*> look md "FLog"
              <*> look md "CLog"
              <*> look md "Log"
              <*> look md "GCD"
              <*> look md "LCM"
#if MIN_VERSION_ghc(9,2,0)
              -- NOTE(review): the last two slots receive the same tycon in
              -- both CPP branches; verify against the 'ExtraDefs' definition.
              <*> look md2 "OrdCond"
              <*> look md2 "OrdCond"
#else
              <*> pure typeNatLeqTyCon
              <*> pure typeNatLeqTyCon
#endif
  where
    -- Look up the type constructor named 's' in module 'md'.
    look md s = tcLookupTyCon =<< lookupName md (mkTcOcc s)
    myModule  = mkModuleName "GHC.TypeLits.Extra"
    myPackage = fsLit "ghc-typelits-extra"
#if MIN_VERSION_ghc(9,2,0)
    ordModule   = mkModuleName "Data.Type.Ord"
    basePackage = fsLit "base"
#endif

-- Utils
-- | Evidence for a constraint the solver has decided to accept.
--
-- For a nominal equality between @t1@ and @t2@ the evidence is built with
-- 'evByFiat', tagged @"ghc-typelits-extra"@; no proof term is constructed
-- here. Any other kind of predicate yields 'Nothing'.
evMagic :: Ct -> Maybe EvTerm
evMagic ct = case classifyPredType $ ctEvPred $ ctEvidence ct of
    EqPred NomEq t1 t2 -> Just (evByFiat "ghc-typelits-extra" t1 t2)
    _                  -> Nothing