{-|
Copyright  :  (C) 2016     , University of Twente,
                  2017-2018, QBayLogic B.V.,
                  2017     , Google Inc.
License    :  BSD2 (see the file LICENSE)
Maintainer :  Christiaan Baaij <christiaan.baaij@gmail.com>

A type checker plugin for GHC that can derive \"complex\" @KnownNat@
constraints from other simple/variable @KnownNat@ constraints. i.e. without
this plugin, you must have both a @KnownNat n@ and a @KnownNat (n+2)@
constraint in the type signature of the following function:

@
f :: forall n . (KnownNat n, KnownNat (n+2)) => Proxy n -> Integer
f _ = natVal (Proxy :: Proxy n) + natVal (Proxy :: Proxy (n+2))
@

Using the plugin you can omit the @KnownNat (n+2)@ constraint:

@
f :: forall n . KnownNat n => Proxy n -> Integer
f _ = natVal (Proxy :: Proxy n) + natVal (Proxy :: Proxy (n+2))
@

The plugin can derive @KnownNat@ constraints for types consisting of:

* Type variables, when there is a corresponding @KnownNat@ constraint
* Type-level naturals
* Applications of the arithmetic operators: @{+,-,*,^}@
* Type functions, when there is either:
  * a matching given @KnownNat@ constraint; or
  * a corresponding @KnownNat\<N\>@ instance for the type function

To elaborate the latter points, given the type family @Min@:

@
type family Min (a :: Nat) (b :: Nat) :: Nat where
  Min 0 b = 0
  Min a b = If (a <=? b) a b
@

the plugin can derive a @KnownNat (Min x y + 1)@ constraint given only a
@KnownNat (Min x y)@ constraint:

@
g :: forall x y . (KnownNat (Min x y)) => Proxy x -> Proxy y -> Integer
g _ _ = natVal (Proxy :: Proxy (Min x y + 1))
@

And, given the type family @Max@:

@
type family Max (a :: Nat) (b :: Nat) :: Nat where
  Max 0 b = b
  Max a b = If (a <=? b) b a
@

and a corresponding @KnownNat2@ instance, whose first argument is the fully
qualified name of the type family:

@
instance (KnownNat a, KnownNat b) => KnownNat2 \"TestFunctions.Max\" a b where
  natSing2 = let x = natVal (Proxy @ a)
                 y = natVal (Proxy @ b)
                 z = max x y
             in  SNatKn z
  \{\-# INLINE natSing2 \#-\}
@

the plugin can derive a @KnownNat (Max x y + 1)@ constraint given only a
@KnownNat x@ and @KnownNat y@ constraint:

@
h :: forall x y . (KnownNat x, KnownNat y) => Proxy x -> Proxy y -> Integer
h _ _ = natVal (Proxy :: Proxy (Max x y + 1))
@

To use the plugin, add the

@
OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver
@

pragma to the header of your file.

-}

{-# LANGUAGE CPP           #-}
{-# LANGUAGE LambdaCase    #-}
{-# LANGUAGE TupleSections #-}
{-# LANGUAGE ViewPatterns  #-}

{-# LANGUAGE Trustworthy   #-}

{-# OPTIONS_HADDOCK show-extensions #-}

module GHC.TypeLits.KnownNat.Solver
  ( plugin )
where

-- external
import Control.Arrow                ((&&&), first)
import Control.Monad.Trans.Maybe    (MaybeT (..))
import Control.Monad.Trans.Writer.Strict
import Data.Maybe                   (catMaybes,mapMaybe)
import GHC.TcPluginM.Extra          (lookupModule, lookupName, newWanted,
                                     tracePlugin)
#if MIN_VERSION_ghc(8,4,0)
import GHC.TcPluginM.Extra          (flattenGivens, mkSubst', substType)
#endif
import GHC.TypeLits.Normalise.SOP   (SOP (..), Product (..), Symbol (..))
import GHC.TypeLits.Normalise.Unify (CType (..),normaliseNat,reifySOP)

-- GHC API
import Class      (Class, classMethods, className, classTyCon)
#if MIN_VERSION_ghc(8,6,0)
import Coercion   (Role (Representational), mkUnivCo)
#endif
import FamInst    (tcInstNewTyCon_maybe)
import FastString (fsLit)
import Id         (idType)
import InstEnv    (instanceDFunId,lookupUniqueInstEnv)
#if MIN_VERSION_ghc(8,5,0)
import MkCore     (mkNaturalExpr)
#endif
import Module     (mkModuleName, moduleName, moduleNameString)
import Name       (nameModule_maybe, nameOccName)
import OccName    (mkTcOcc, occNameString)
import Plugins    (Plugin (..), defaultPlugin)
#if MIN_VERSION_ghc(8,6,0)
import Plugins    (purePlugin)
#endif
import PrelNames  (knownNatClassName)
#if MIN_VERSION_ghc(8,5,0)
import TcEvidence (EvTerm (..), EvExpr, evDFunApp, mkEvCast, mkTcSymCo, mkTcTransCo)
#else
import TcEvidence (EvTerm (..), EvLit (EvNum), mkEvCast, mkTcSymCo, mkTcTransCo)
#endif
#if MIN_VERSION_ghc(8,5,0)
import TcPluginM  (unsafeTcPluginTcM)
#endif
#if !MIN_VERSION_ghc(8,4,0)
import TcPluginM  (zonkCt)
#endif
import TcPluginM  (TcPluginM, tcLookupClass, getInstEnvs)
import TcRnTypes  (Ct, TcPlugin(..), TcPluginResult (..), ctEvidence, ctEvLoc,
#if MIN_VERSION_ghc(8,5,0)
                   ctEvPred, ctEvExpr, ctLoc, ctLocSpan, isWanted,
#else
                   ctEvPred, ctEvTerm, ctLoc, ctLocSpan, isWanted,
#endif
                   mkNonCanonical, setCtLoc, setCtLocSpan)
import TcTypeNats (typeNatAddTyCon, typeNatSubTyCon)
import Type
  (EqRel (NomEq), PredTree (ClassPred,EqPred), PredType, classifyPredType,
   dropForAlls, eqType, funResultTy, mkNumLitTy, mkStrLitTy, mkTyConApp,
   piResultTys, splitFunTys, splitTyConApp_maybe, tyConAppTyCon_maybe, typeKind)
import TyCon      (tyConName)
import TyCoRep    (Type (..), TyLit (..))
#if MIN_VERSION_ghc(8,6,0)
import TyCoRep    (UnivCoProvenance (PluginProv))
import TysWiredIn (boolTy)
#endif
import Var        (DFunId)

-- | Classes from "GHC.TypeLits.KnownNat" that the solver uses to build
-- evidence. They are looked up by 'lookupKnownNatDefs'.
data KnownNatDefs
  = KnownNatDefs
  { knownBool     :: Class               -- ^ @KnownBool@
  , knownBoolNat2 :: Class               -- ^ @KnownBoolNat2@
  , knownNat2Bool :: Class               -- ^ @KnownNat2Bool@
  , knownNatN     :: Int -> Maybe Class  -- ^ @KnownNat{N}@ for N = 1..3,
                                         --   'Nothing' for any other arity
  }

-- | A [W]anted constraint selected for solving by 'toKnConstraint'
type KnConstraint = (Ct    -- The constraint
                    ,Class -- KnownNat (or KnownBool) class
                    ,Type  -- The argument to the class
                    )

{-|
A type checker plugin for GHC that can derive \"complex\" @KnownNat@
constraints from other simple/variable @KnownNat@ constraints. i.e. without
this plugin, you must have both a @KnownNat n@ and a @KnownNat (n+2)@
constraint in the type signature of the following function:

@
f :: forall n . (KnownNat n, KnownNat (n+2)) => Proxy n -> Integer
f _ = natVal (Proxy :: Proxy n) + natVal (Proxy :: Proxy (n+2))
@

Using the plugin you can omit the @KnownNat (n+2)@ constraint:

@
f :: forall n . KnownNat n => Proxy n -> Integer
f _ = natVal (Proxy :: Proxy n) + natVal (Proxy :: Proxy (n+2))
@

The plugin can derive @KnownNat@ constraints for types consisting of:

* Type variables, when there is a corresponding @KnownNat@ constraint
* Type-level naturals
* Applications of the arithmetic operators: @{+,-,*,^}@
* Type functions, when there is either:
  * a matching given @KnownNat@ constraint; or
  * a corresponding @KnownNat\<N\>@ instance for the type function

To elaborate the latter points, given the type family @Min@:

@
type family Min (a :: Nat) (b :: Nat) :: Nat where
  Min 0 b = 0
  Min a b = If (a <=? b) a b
@

the plugin can derive a @KnownNat (Min x y + 1)@ constraint given only a
@KnownNat (Min x y)@ constraint:

@
g :: forall x y . (KnownNat (Min x y)) => Proxy x -> Proxy y -> Integer
g _ _ = natVal (Proxy :: Proxy (Min x y + 1))
@

And, given the type family @Max@:

@
type family Max (a :: Nat) (b :: Nat) :: Nat where
  Max 0 b = b
  Max a b = If (a <=? b) b a
@

and a corresponding @KnownNat2@ instance, whose first argument is the fully
qualified name of the type family:

@
instance (KnownNat a, KnownNat b) => KnownNat2 \"TestFunctions.Max\" a b where
  natSing2 = let x = natVal (Proxy @ a)
                 y = natVal (Proxy @ b)
                 z = max x y
             in  SNatKn z
  \{\-# INLINE natSing2 \#-\}
@

the plugin can derive a @KnownNat (Max x y + 1)@ constraint given only a
@KnownNat x@ and @KnownNat y@ constraint:

@
h :: forall x y . (KnownNat x, KnownNat y) => Proxy x -> Proxy y -> Integer
h _ _ = natVal (Proxy :: Proxy (Max x y + 1))
@

To use the plugin, add the

@
OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver
@

pragma to the header of your file.

-}
plugin :: Plugin
plugin
  = defaultPlugin
  -- Command-line options passed to the plugin are ignored
  { tcPlugin = const $ Just normalisePlugin
#if MIN_VERSION_ghc(8,6,0)
  -- Declared pure, so using the plugin does not by itself force
  -- recompilation of modules
  , pluginRecompile = purePlugin
#endif
  }

-- | The type-checker plugin proper.
--
-- * Initialisation looks up the "GHC.TypeLits.KnownNat" classes once.
-- * 'solveKnownNat' receives them on every solver call.
-- * Stopping does nothing.
--
-- The plugin is wrapped with 'tracePlugin' from ghc-tcplugins-extra, under
-- the name @ghc-typelits-knownnat@.
normalisePlugin :: TcPlugin
normalisePlugin = tracePlugin "ghc-typelits-knownnat"
  TcPlugin { tcPluginInit  = lookupKnownNatDefs
           , tcPluginSolve = solveKnownNat
           , tcPluginStop  = const (return ())
           }

-- | Solver entry point: try to discharge every [W]anted @KnownNat@ /
-- @KnownBool@ constraint using the [G]iven constraints.
--
-- Arguments are the looked-up classes, the [G]ivens, the [D]eriveds
-- (unused), and the [W]anteds.
--
-- The result is 'TcPluginOk' with:
--
-- * the evidence for every constraint that could be solved;
-- * the new [W]anted constraints that solving them requires.
--
-- An empty [W]anted list, or one with no relevant constraints, yields
-- @TcPluginOk [] []@.
solveKnownNat :: KnownNatDefs -> [Ct] -> [Ct] -> [Ct]
              -> TcPluginM TcPluginResult
solveKnownNat _defs _givens _deriveds []      = return (TcPluginOk [] [])
solveKnownNat defs  givens  _deriveds wanteds = do
  -- GHC 7.10 puts deriveds with the wanteds, so filter them out
  let wanteds'   = filter (isWanted . ctEvidence) wanteds
#if MIN_VERSION_ghc(8,4,0)
      -- Substitution built from the [G]ivens (mkSubst'), applied to the
      -- argument of every selected [W]anted constraint
      subst      = map fst (mkSubst' givens)
      kn_wanteds = map (\(x,y,z) -> (x,y,substType subst z))
                 $ mapMaybe (toKnConstraint defs) wanteds'
#else
      kn_wanteds = mapMaybe (toKnConstraint defs) wanteds'
#endif
  case kn_wanteds of
    [] -> return (TcPluginOk [] [])
    _  -> do
      -- Make a lookup table for all the [G]iven constraints
#if MIN_VERSION_ghc(8,4,0)
      let given_map = map toGivenEntry (flattenGivens givens)
#else
      given_map <- mapM (fmap toGivenEntry . zonkCt) givens
#endif
      -- Try to solve the wanted KnownNat constraints given the [G]iven
      -- KnownNat constraints; unsolvable ones are dropped by catMaybes
      (solved,new) <- unzip . catMaybes
                        <$> mapM (constraintToEvTerm defs given_map) kn_wanteds
      return (TcPluginOk solved (concat new))

-- | Select a constraint for solving if it is a single-argument class
-- constraint of either @KnownNat@ or @KnownBool@.
--
-- Returns the constraint, its class, and the class argument. Any other
-- constraint yields 'Nothing'.
toKnConstraint :: KnownNatDefs -> Ct -> Maybe KnConstraint
toKnConstraint defs ct = case classifyPredType $ ctEvPred $ ctEvidence ct of
  ClassPred cls [ty]
    |  className cls == knownNatClassName ||
       className cls == className (knownBool defs)
    -> Just (ct,cls,ty)
  _ -> Nothing

-- | Create a look-up entry for a [G]iven constraint.
--
-- The entry pairs the constraint's predicate type, wrapped in 'CType' so it
-- can serve as a 'lookup' key, with the constraint's evidence. Before GHC
-- 8.5 the evidence is an 'EvTerm'; from 8.5 on it is an 'EvExpr'.
#if MIN_VERSION_ghc(8,5,0)
toGivenEntry :: Ct -> (CType,EvExpr)
#else
toGivenEntry :: Ct -> (CType,EvTerm)
#endif
toGivenEntry ct = (CType (ctEvPred ct_ev), ev)
  where
    ct_ev = ctEvidence ct
#if MIN_VERSION_ghc(8,5,0)
    ev    = ctEvExpr ct_ev
#else
    ev    = ctEvTerm ct_ev
#endif

-- | Find the \"magic\" classes in "GHC.TypeLits.KnownNat", from package
-- @ghc-typelits-knownnat@.
--
-- The classes are @KnownBool@, @KnownBoolNat2@, @KnownNat2Bool@ and
-- @KnownNat1@ through @KnownNat3@.
lookupKnownNatDefs :: TcPluginM KnownNatDefs
lookupKnownNatDefs = do
    md     <- lookupModule myModule myPackage
    kbC    <- look md "KnownBool"
    kbn2C  <- look md "KnownBoolNat2"
    kn2bC  <- look md "KnownNat2Bool"
    kn1C   <- look md "KnownNat1"
    kn2C   <- look md "KnownNat2"
    kn3C   <- look md "KnownNat3"
    return KnownNatDefs
           { knownBool     = kbC
           , knownBoolNat2 = kbn2C
           , knownNat2Bool = kn2bC
             -- Arity of the type-level operation -> matching KnownNat{N}
           , knownNatN     = \case { 1 -> Just kn1C
                                   ; 2 -> Just kn2C
                                   ; 3 -> Just kn3C
                                   ; _ -> Nothing
                                   }
           }
  where
    -- Resolve a type-class by its unqualified name in the given module
    look md s = do
      nm <- lookupName md (mkTcOcc s)
      tcLookupClass nm

    myModule  = mkModuleName "GHC.TypeLits.KnownNat"
    myPackage = fsLit "ghc-typelits-knownnat"

-- | Try to create evidence for a wanted constraint
constraintToEvTerm
  :: KnownNatDefs     -- ^ The "magic" KnownNatN classes
#if MIN_VERSION_ghc(8,5,0)
  -> [(CType,EvExpr)]
#else
  -> [(CType,EvTerm)]
#endif
  -- All the [G]iven constraints

  -> KnConstraint
  -> TcPluginM (Maybe ((EvTerm,Ct),[Ct]))
constraintToEvTerm :: KnownNatDefs
-> [(CType, EvExpr)]
-> (Ct, Class, TcType)
-> TcPluginM (Maybe ((EvTerm, Ct), [Ct]))
constraintToEvTerm defs :: KnownNatDefs
defs givens :: [(CType, EvExpr)]
givens (ct :: Ct
ct,cls :: Class
cls,op :: TcType
op) = do
    -- 1. Determine if we are an offset apart from a [G]iven constraint
    Maybe (EvTerm, [Ct])
offsetM <- TcType -> TcPluginM (Maybe (EvTerm, [Ct]))
offset TcType
op
    Maybe (EvTerm, [Ct])
evM     <- case Maybe (EvTerm, [Ct])
offsetM of
                 -- 3.a If so, we are done
                 found :: Maybe (EvTerm, [Ct])
found@Just {} -> Maybe (EvTerm, [Ct]) -> TcPluginM (Maybe (EvTerm, [Ct]))
forall (m :: * -> *) a. Monad m => a -> m a
return Maybe (EvTerm, [Ct])
found
                 -- 3.b If not, we check if the outer type-level operation
                 -- has a corresponding KnownNat<N> instance.
                 _ -> TcType -> TcPluginM (Maybe (EvTerm, [Ct]))
go TcType
op
    Maybe ((EvTerm, Ct), [Ct])
-> TcPluginM (Maybe ((EvTerm, Ct), [Ct]))
forall (m :: * -> *) a. Monad m => a -> m a
return (((EvTerm -> (EvTerm, Ct)) -> (EvTerm, [Ct]) -> ((EvTerm, Ct), [Ct])
forall (a :: * -> * -> *) b c d.
Arrow a =>
a b c -> a (b, d) (c, d)
first (,Ct
ct)) ((EvTerm, [Ct]) -> ((EvTerm, Ct), [Ct]))
-> Maybe (EvTerm, [Ct]) -> Maybe ((EvTerm, Ct), [Ct])
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> Maybe (EvTerm, [Ct])
evM)
  where
    -- Determine whether the outer type-level operation has a corresponding
    -- KnownNat<N> instance, where /N/ corresponds to the arity of the
    -- type-level operation
    go :: Type -> TcPluginM (Maybe (EvTerm,[Ct]))
    go :: TcType -> TcPluginM (Maybe (EvTerm, [Ct]))
go (TcType -> Maybe EvTerm
go_other -> Just ev :: EvTerm
ev) = Maybe (EvTerm, [Ct]) -> TcPluginM (Maybe (EvTerm, [Ct]))
forall (m :: * -> *) a. Monad m => a -> m a
return ((EvTerm, [Ct]) -> Maybe (EvTerm, [Ct])
forall a. a -> Maybe a
Just (EvTerm
ev,[]))
    go ty :: TcType
ty@(TyConApp tc :: TyCon
tc args0 :: [TcType]
args0)
      | let tcNm :: Name
tcNm = TyCon -> Name
tyConName TyCon
tc
      , Just m :: Module
m <- Name -> Maybe Module
nameModule_maybe Name
tcNm
      = do
        InstEnvs
ienv <- TcPluginM InstEnvs
getInstEnvs
        let mS :: CommandLineOption
mS  = ModuleName -> CommandLineOption
moduleNameString (Module -> ModuleName
moduleName Module
m)
            tcS :: CommandLineOption
tcS = OccName -> CommandLineOption
occNameString (Name -> OccName
nameOccName Name
tcNm)
            fn0 :: CommandLineOption
fn0 = CommandLineOption
mS CommandLineOption -> CommandLineOption -> CommandLineOption
forall a. [a] -> [a] -> [a]
++ "." CommandLineOption -> CommandLineOption -> CommandLineOption
forall a. [a] -> [a] -> [a]
++ CommandLineOption
tcS
            fn1 :: TcType
fn1 = FastString -> TcType
mkStrLitTy (CommandLineOption -> FastString
fsLit CommandLineOption
fn0)
            args1 :: [TcType]
args1 = TcType
fn1TcType -> [TcType] -> [TcType]
forall a. a -> [a] -> [a]
:[TcType]
args0
            instM :: Maybe (ClsInst, Class, [TcType], [TcType])
instM = case () of
              () | Just knN_cls :: Class
knN_cls    <- KnownNatDefs -> Int -> Maybe Class
knownNatN KnownNatDefs
defs ([TcType] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [TcType]
args0)
                 , Right (inst :: ClsInst
inst, _) <- InstEnvs -> Class -> [TcType] -> Either MsgDoc (ClsInst, [TcType])
lookupUniqueInstEnv InstEnvs
ienv Class
knN_cls [TcType]
args1
                 -> (ClsInst, Class, [TcType], [TcType])
-> Maybe (ClsInst, Class, [TcType], [TcType])
forall a. a -> Maybe a
Just (ClsInst
inst,Class
knN_cls,[TcType]
args0,[TcType]
args1)
                 | [TcType] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [TcType]
args0 Int -> Int -> Bool
forall a. Eq a => a -> a -> Bool
== 2
                 , let knN_cls :: Class
knN_cls = KnownNatDefs -> Class
knownBoolNat2 KnownNatDefs
defs
                       ki :: TcType
ki      = HasDebugCallStack => TcType -> TcType
TcType -> TcType
typeKind ([TcType] -> TcType
forall a. [a] -> a
head [TcType]
args0)
                       args1N :: [TcType]
args1N  = TcType
kiTcType -> [TcType] -> [TcType]
forall a. a -> [a] -> [a]
:[TcType]
args1
                 , Right (inst :: ClsInst
inst, _) <- InstEnvs -> Class -> [TcType] -> Either MsgDoc (ClsInst, [TcType])
lookupUniqueInstEnv InstEnvs
ienv Class
knN_cls [TcType]
args1N
                 -> (ClsInst, Class, [TcType], [TcType])
-> Maybe (ClsInst, Class, [TcType], [TcType])
forall a. a -> Maybe a
Just (ClsInst
inst,Class
knN_cls,[TcType]
args0,[TcType]
args1N)
                 | [TcType] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [TcType]
args0 Int -> Int -> Bool
forall a. Eq a => a -> a -> Bool
== 4
                 , CommandLineOption
fn0 CommandLineOption -> CommandLineOption -> Bool
forall a. Eq a => a -> a -> Bool
== "Data.Type.Bool.If"
                 , let args0N :: [TcType]
args0N = [TcType] -> [TcType]
forall a. [a] -> [a]
tail [TcType]
args0
                       args1N :: [TcType]
args1N = [TcType] -> TcType
forall a. [a] -> a
head [TcType]
args0TcType -> [TcType] -> [TcType]
forall a. a -> [a] -> [a]
:TcType
fn1TcType -> [TcType] -> [TcType]
forall a. a -> [a] -> [a]
:[TcType] -> [TcType]
forall a. [a] -> [a]
tail [TcType]
args0
                       knN_cls :: Class
knN_cls = KnownNatDefs -> Class
knownNat2Bool KnownNatDefs
defs
                 , Right (inst :: ClsInst
inst, _) <- InstEnvs -> Class -> [TcType] -> Either MsgDoc (ClsInst, [TcType])
lookupUniqueInstEnv InstEnvs
ienv Class
knN_cls [TcType]
args1N
                 -> (ClsInst, Class, [TcType], [TcType])
-> Maybe (ClsInst, Class, [TcType], [TcType])
forall a. a -> Maybe a
Just (ClsInst
inst,Class
knN_cls,[TcType]
args0N,[TcType]
args1N)
                 | Bool
otherwise
                 -> Maybe (ClsInst, Class, [TcType], [TcType])
forall a. Maybe a
Nothing
        case Maybe (ClsInst, Class, [TcType], [TcType])
instM of
          Just (inst :: ClsInst
inst,knN_cls :: Class
knN_cls,args0N :: [TcType]
args0N,args1N :: [TcType]
args1N) -> do
            let df_id :: TcTyVar
df_id   = ClsInst -> TcTyVar
instanceDFunId ClsInst
inst
                df :: (Class, TcTyVar)
df      = (Class
knN_cls,TcTyVar
df_id)
                df_args :: [TcType]
df_args = ([TcType], TcType) -> [TcType]
forall a b. (a, b) -> a
fst                  -- [KnownNat x, KnownNat y]
                        (([TcType], TcType) -> [TcType])
-> (TcType -> ([TcType], TcType)) -> TcType -> [TcType]
forall b c a. (b -> c) -> (a -> b) -> a -> c
. TcType -> ([TcType], TcType)
splitFunTys          -- ([KnownNat x, KnowNat y], DKnownNat2 "+" x y)
                        (TcType -> ([TcType], TcType))
-> (TcType -> TcType) -> TcType -> ([TcType], TcType)
forall b c a. (b -> c) -> (a -> b) -> a -> c
. (HasDebugCallStack => TcType -> [TcType] -> TcType
TcType -> [TcType] -> TcType
`piResultTys` [TcType]
args0N) -- (KnowNat x, KnownNat y) => DKnownNat2 "+" x y
                        (TcType -> [TcType]) -> TcType -> [TcType]
forall a b. (a -> b) -> a -> b
$ TcTyVar -> TcType
idType TcTyVar
df_id         -- forall a b . (KnownNat a, KnownNat b) => DKnownNat2 "+" a b
            (evs :: [EvExpr]
evs,new :: [[Ct]]
new) <- [(EvExpr, [Ct])] -> ([EvExpr], [[Ct]])
forall a b. [(a, b)] -> ([a], [b])
unzip ([(EvExpr, [Ct])] -> ([EvExpr], [[Ct]]))
-> TcPluginM [(EvExpr, [Ct])] -> TcPluginM ([EvExpr], [[Ct]])
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> (TcType -> TcPluginM (EvExpr, [Ct]))
-> [TcType] -> TcPluginM [(EvExpr, [Ct])]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM TcType -> TcPluginM (EvExpr, [Ct])
go_arg [TcType]
df_args
            if Class -> Name
className Class
cls Name -> Name -> Bool
forall a. Eq a => a -> a -> Bool
== Class -> Name
className (KnownNatDefs -> Class
knownBool KnownNatDefs
defs)
               then Maybe (EvTerm, [Ct]) -> TcPluginM (Maybe (EvTerm, [Ct]))
forall (m :: * -> *) a. Monad m => a -> m a
return ((,[[Ct]] -> [Ct]
forall (t :: * -> *) a. Foldable t => t [a] -> [a]
concat [[Ct]]
new) (EvTerm -> (EvTerm, [Ct])) -> Maybe EvTerm -> Maybe (EvTerm, [Ct])
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> (Class, TcTyVar)
-> Class
-> [TcType]
-> [TcType]
-> TcType
-> [EvExpr]
-> Maybe EvTerm
makeOpDictByFiat (Class, TcTyVar)
df Class
cls [TcType]
args1N [TcType]
args0N TcType
op [EvExpr]
evs)
               else Maybe (EvTerm, [Ct]) -> TcPluginM (Maybe (EvTerm, [Ct]))
forall (m :: * -> *) a. Monad m => a -> m a
return ((,[[Ct]] -> [Ct]
forall (t :: * -> *) a. Foldable t => t [a] -> [a]
concat [[Ct]]
new) (EvTerm -> (EvTerm, [Ct])) -> Maybe EvTerm -> Maybe (EvTerm, [Ct])
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> (Class, TcTyVar)
-> Class
-> [TcType]
-> [TcType]
-> TcType
-> [EvExpr]
-> Maybe EvTerm
makeOpDict (Class, TcTyVar)
df Class
cls [TcType]
args1N [TcType]
args0N TcType
op [EvExpr]
evs)
          _ -> Maybe (EvTerm, [Ct]) -> TcPluginM (Maybe (EvTerm, [Ct]))
forall (m :: * -> *) a. Monad m => a -> m a
return ((,[]) (EvTerm -> (EvTerm, [Ct])) -> Maybe EvTerm -> Maybe (EvTerm, [Ct])
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> TcType -> Maybe EvTerm
go_other TcType
ty)

    go (LitTy (NumTyLit i :: Integer
i))
      -- Let GHC solve simple Literal constraints
      | LitTy _ <- TcType
op
      = Maybe (EvTerm, [Ct]) -> TcPluginM (Maybe (EvTerm, [Ct]))
forall (m :: * -> *) a. Monad m => a -> m a
return Maybe (EvTerm, [Ct])
forall a. Maybe a
Nothing
      -- This plugin only solves Literal KnownNat's that needed to be normalised
      -- first
      | Bool
otherwise
#if MIN_VERSION_ghc(8,5,0)
      = ((EvTerm -> (EvTerm, [Ct])) -> Maybe EvTerm -> Maybe (EvTerm, [Ct])
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap (,[])) (Maybe EvTerm -> Maybe (EvTerm, [Ct]))
-> TcPluginM (Maybe EvTerm) -> TcPluginM (Maybe (EvTerm, [Ct]))
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> Class -> TcType -> Integer -> TcPluginM (Maybe EvTerm)
makeLitDict Class
cls TcType
op Integer
i
#else
      = return ((,[]) <$> makeLitDict cls op i)
#endif
    go _ = Maybe (EvTerm, [Ct]) -> TcPluginM (Maybe (EvTerm, [Ct]))
forall (m :: * -> *) a. Monad m => a -> m a
return Maybe (EvTerm, [Ct])
forall a. Maybe a
Nothing

    -- Get EvTerm arguments for type-level operations. If they do not exist
    -- as [G]iven constraints, then generate new [W]anted constraints
#if MIN_VERSION_ghc(8,5,0)
    go_arg :: PredType -> TcPluginM (EvExpr,[Ct])
#else
    go_arg :: PredType -> TcPluginM (EvTerm,[Ct])
#endif
    go_arg :: TcType -> TcPluginM (EvExpr, [Ct])
go_arg ty :: TcType
ty = case CType -> [(CType, EvExpr)] -> Maybe EvExpr
forall a b. Eq a => a -> [(a, b)] -> Maybe b
lookup (TcType -> CType
CType TcType
ty) [(CType, EvExpr)]
givens of
      Just ev :: EvExpr
ev -> (EvExpr, [Ct]) -> TcPluginM (EvExpr, [Ct])
forall (m :: * -> *) a. Monad m => a -> m a
return (EvExpr
ev,[])
      _ -> do
        (ev :: EvExpr
ev,wanted :: Ct
wanted) <- Ct -> TcType -> TcPluginM (EvExpr, Ct)
makeWantedEv Ct
ct TcType
ty
        (EvExpr, [Ct]) -> TcPluginM (EvExpr, [Ct])
forall (m :: * -> *) a. Monad m => a -> m a
return (EvExpr
ev,[Ct
wanted])

    -- Fall-through case: look the normalised [W]anted up in the [G]ivens.
    -- We search for a given @cls ty@ (where @cls@ is the class of the
    -- constraint being solved). If @ty@ is syntactically the type we are
    -- solving for ('op'), its evidence is reused directly; otherwise it is
    -- coerced from @cls ty@ to @cls op@ with 'makeKnCoercion'.
    go_other :: Type -> Maybe EvTerm
    go_other ty = lookup (CType givenPred) givens >>= reuse
      where
        givenPred = mkTyConApp (classTyCon cls) [ty]
        reuse
          | CType ty == CType op
#if MIN_VERSION_ghc(8,6,0)
          = Just . EvExpr
#else
          = Just
#endif
          | otherwise
          = makeKnCoercion cls ty op

    -- Find a known constraint for a wanted, so that (modulo normalization)
    -- the two are a constant offset apart.
    --
    -- Candidates are the arguments @k@ of all [G]iven @KnownNat k@
    -- constraints, plus every type such a @k@ is equated to by a single
    -- [G]iven nominal equality. For each candidate, @k - want@ is normalised.
    -- If the result is a lone literal @n@ or a lone type variable @v@, then
    -- @want@ is rebuilt as @k@ (n == 0), @k + |n|@ (n < 0), @k - n@, or
    -- @k - v@, and handed to 'go' to construct the evidence. Only the first
    -- such candidate is tried; if there is none the result is 'Nothing'.
    offset :: Type -> TcPluginM (Maybe (EvTerm,[Ct]))
    offset want = runMaybeT $ do
      let -- Get the argument of a [G]iven KnownNat constraint
          unKn ty' = case classifyPredType ty' of
                       ClassPred cls' [ty'']
                         | className cls' == knownNatClassName
                         -> Just ty''
                       _ -> Nothing
          -- Get both sides of a [G]iven nominal equality (the rewrites)
          unEq ty' = case classifyPredType ty' of
                       EqPred NomEq ty1 ty2 -> Just (ty1,ty2)
                       _ -> Nothing
          rewrites = mapMaybe (unEq . unCType . fst) givens
          -- Rewrite tyK along one equality, in whichever direction matches
          rewriteTy tyK (ty1,ty2) | ty1 `eqType` tyK = Just ty2
                                  | ty2 `eqType` tyK = Just ty1
                                  | otherwise        = Nothing
          -- Get only the [G]iven KnownNat constraints
          knowns   = mapMaybe (unKn . unCType . fst) givens
          -- Get all the rewritten KNs (ordered by known, then by rewrite)
          knownsR  = catMaybes [ rewriteTy k rw | k <- knowns, rw <- rewrites ]
          -- pair up the sum-of-products of (known - want)
          -- with the original known type
          subWant  = mkTyConApp typeNatSubTyCon . (:[want])
          exploded = map (fst . runWriter . normaliseNat . subWant &&& id)
                         (knowns ++ knownsR)
          -- interesting cases for us are those where
          -- wanted and given only differ by a constant
          examineDiff (S [P [I n]]) entire = Just (entire,I n)
          examineDiff (S [P [V v]]) entire = Just (entire,V v)
          examineDiff _ _ = Nothing
          interesting = mapMaybe (uncurry examineDiff) exploded
      -- Take the first suitable candidate, or give up with Nothing. Matching
      -- explicitly avoids relying on a failing pattern bind (MonadFail).
      (h,corr) <- case interesting of
                    (hit:_) -> return hit
                    []      -> MaybeT (return Nothing)
      -- h - want == corr, hence want == h - corr
      let x = case corr of
                I 0 -> h
                I i | i < 0     -> mkTyConApp typeNatAddTyCon [h,mkNumLitTy (negate i)]
                    | otherwise -> mkTyConApp typeNatSubTyCon [h,mkNumLitTy i]
                _ -> mkTyConApp typeNatSubTyCon [h,reifySOP (S [P [corr]])]
      MaybeT (go x)

-- | Create a fresh [W]anted constraint for the given predicate type and
-- return its evidence together with the new constraint.
--
-- The new constraint's source span is overwritten with the span of @ct@,
-- the [W]anted constraint we are currently trying to solve.
makeWantedEv
  :: Ct
  -- ^ The [W]anted constraint currently being solved
  -> Type
  -- ^ Predicate type of the new [W]anted constraint
#if MIN_VERSION_ghc(8,5,0)
  -> TcPluginM (EvExpr,Ct)
#else
  -> TcPluginM (EvTerm,Ct)
#endif
makeWantedEv ct ty = do
  ctEv <- newWanted origLoc ty
  let spannedLoc = setCtLocSpan (ctEvLoc ctEv) (ctLocSpan origLoc)
      newCt      = setCtLoc (mkNonCanonical ctEv) spannedLoc
#if MIN_VERSION_ghc(8,5,0)
  return (ctEvExpr ctEv, newCt)
#else
  return (ctEvTerm ctEv, newCt)
#endif
  where
    origLoc = ctLoc ct

{- |
Given:

* A "magic" class, and corresponding instance dictionary function, for a
  type-level arithmetic operation
* The KnownNat dictionaries for the instance's context (two for a binary
  operation)

makeOpDict instantiates the dictionary function with the KnownNat dictionaries,
and coerces it to a KnownNat dictionary. i.e. for KnownNat2, the "magic"
dictionary for binary functions, the coercion happens in the following steps:

1. KnownNat2 "+" a b           -> SNatKn (KnownNatF2 "+" a b)
2. SNatKn (KnownNatF2 "+" a b) -> Integer
3. Integer                     -> SNat (a + b)
4. SNat (a + b)                -> KnownNat (a + b)

This process is mirrored for the dictionary functions of a higher arity.

Returns 'Nothing' when either class, or the representation types of their
methods, do not have the expected newtype/single-method shape.
-}
makeOpDict
  :: (Class,DFunId)
  -- ^ "magic" class function and dictionary function id
  -> Class
  -- ^ KnownNat class
  -> [Type]
  -- ^ Argument types for the Class
  -> [Type]
  -- ^ Argument types for the Instance
  -> Type           -- ^ Type of the result
#if MIN_VERSION_ghc(8,5,0)
  -> [EvExpr]
#else
  -> [EvTerm]
#endif
  -- ^ Evidence arguments
  -> Maybe EvTerm
makeOpDict (opCls,dfid) knCls tyArgsC tyArgsI z evArgs
  | Just (_, kn_co_dict) <- tcInstNewTyCon_maybe (classTyCon knCls) [z]
    -- KnownNat n ~ SNat n
  , [ kn_meth ] <- classMethods knCls
  , Just kn_tcRep <- tyConAppTyCon_maybe -- SNat
                      $ funResultTy      -- SNat n
                      $ dropForAlls      -- KnownNat n => SNat n
                      $ idType kn_meth   -- forall n. KnownNat n => SNat n
  , Just (_, kn_co_rep) <- tcInstNewTyCon_maybe kn_tcRep [z]
    -- SNat n ~ Integer
  , Just (_, op_co_dict) <- tcInstNewTyCon_maybe (classTyCon opCls) tyArgsC
    -- KnownNatAdd a b ~ SNatKn (a+b)
  , [ op_meth ] <- classMethods opCls
  , Just (op_tcRep,op_args) <- splitTyConApp_maybe        -- (SNatKn, [KnownNatF2 f x y])
                                 $ funResultTy            -- SNatKn (KnownNatF2 f x y)
                                 $ (`piResultTys` tyArgsC) -- KnownNatAdd f x y => SNatKn (KnownNatF2 f x y)
                                 $ idType op_meth         -- forall f a b . KnownNat2 f a b => SNatKn (KnownNatF2 f a b)
  , Just (_, op_co_rep) <- tcInstNewTyCon_maybe op_tcRep op_args
    -- SNatKn (a+b) ~ Integer
#if MIN_VERSION_ghc(8,5,0)
    -- KnownNatAdd a b: the instantiated dictionary function. Matched with a
    -- pattern guard (as in makeOpDictByFiat) instead of a lazy let-pattern,
    -- so an unexpected EvTerm shape gives Nothing rather than an
    -- irrefutable-pattern failure inside the type checker.
  , EvExpr dfun_inst <- evDFunApp dfid tyArgsI evArgs
#else
    -- KnownNatAdd a b: the instantiated dictionary function
  , let dfun_inst = EvDFunApp dfid tyArgsI evArgs
#endif
    -- op_to_kn :: KnownNatAdd a b ~ KnownNat (a+b)
  , let op_to_kn  = mkTcTransCo (mkTcTransCo op_co_dict op_co_rep)
                                (mkTcSymCo (mkTcTransCo kn_co_dict kn_co_rep))
        ev_tm     = mkEvCast dfun_inst op_to_kn
  = Just ev_tm
  | otherwise
  = Nothing

{- |
Given:

* A KnownNat dictionary evidence over a type x
* a desired type z

makeKnCoercion assembles a coercion from a KnownNat x dictionary to a
KnownNat z dictionary and applies it to the passed-in evidence. The coercion
happens in the following steps:

1. KnownNat x -> SNat x
2. SNat x     -> Integer
3. Integer    -> SNat z
4. SNat z     -> KnownNat z

The coercion is built without checking that x and z are equal.
NOTE(review): callers are presumably expected to establish that (e.g. equal
modulo normalisation); confirm at the call sites.

Returns 'Nothing' when the class or SNat does not have the expected
newtype/single-method shape.
-}
makeKnCoercion :: Class          -- ^ KnownNat class
               -> Type           -- ^ Type of the argument
               -> Type           -- ^ Type of the result
#if MIN_VERSION_ghc(8,5,0)
               -> EvExpr
#else
               -> EvTerm
#endif
               -- ^ KnownNat dictionary for the argument
               -> Maybe EvTerm
makeKnCoercion knCls x z xEv = do
  -- KnownNat z ~ SNat z
  (_, dictToSNatZ) <- tcInstNewTyCon_maybe knTyCon [z]
  -- The single method: forall n. KnownNat n => SNat n
  [method] <- pure (classMethods knCls)
  -- SNat, read off the method's result type
  sNatTc <- tyConAppTyCon_maybe (funResultTy (dropForAlls (idType method)))
  -- SNat z ~ Integer
  (_, sNatToRepZ) <- tcInstNewTyCon_maybe sNatTc [z]
  -- SNat x ~ Integer
  (_, sNatToRepX) <- tcInstNewTyCon_maybe sNatTc [x]
  -- KnownNat x ~ SNat x
  (_, dictToSNatX) <- tcInstNewTyCon_maybe knTyCon [x]
  let xToRep = dictToSNatX `mkTcTransCo` sNatToRepX             -- KnownNat x ~ Integer
      repToZ = mkTcSymCo (dictToSNatZ `mkTcTransCo` sNatToRepZ) -- Integer ~ KnownNat z
  return (mkEvCast xEv (xToRep `mkTcTransCo` repToZ))
  where
    knTyCon = classTyCon knCls

-- | ADAPTED FROM:
-- https://github.com/ghc/ghc/blob/8035d1a5dc7290e8d3d61446ee4861e0b460214e/compiler/typecheck/TcInteract.hs#L1973
--
-- makeLitDict adds a coercion that will convert the literal into a dictionary
-- of the appropriate type.  See Note [KnownNat & KnownSymbol and EvLit]
-- in TcEvidence.  The coercion happens in 2 steps:
--
--     Integer -> SNat n     -- representation of literal to singleton
--     SNat n  -> KnownNat n -- singleton to dictionary
--
-- Yields 'Nothing' when the class or its method does not have the expected
-- newtype/single-method shape. On newer GHCs the literal expression itself
-- is built in 'TcM' (via 'mkNaturalExpr'), hence the 'TcPluginM' result.
#if MIN_VERSION_ghc(8,5,0)
makeLitDict :: Class -> Type -> Integer -> TcPluginM (Maybe EvTerm)
#else
makeLitDict :: Class -> Type -> Integer -> Maybe EvTerm
#endif
makeLitDict clas ty i =
  case litCoercion of
#if MIN_VERSION_ghc(8,5,0)
    Nothing -> return Nothing
    Just co -> do
      litExpr <- unsafeTcPluginTcM (mkNaturalExpr i)
      return (Just (mkEvCast litExpr co))
#else
    Nothing -> Nothing
    Just co -> Just (mkEvCast (EvLit (EvNum i)) co)
#endif
  where
    -- Integer ~ KnownNat ty, built as sym (KnownNat ty ~ SNat ty ~ Integer)
    litCoercion = do
      -- KnownNat ty ~ SNat ty
      (_, dictToSNat) <- tcInstNewTyCon_maybe (classTyCon clas) [ty]
      -- The single method: forall n. KnownNat n => SNat n
      [meth] <- pure (classMethods clas)
      -- SNat, read off the method's result type
      sNatTc <- tyConAppTyCon_maybe (funResultTy (dropForAlls (idType meth)))
      -- SNat ty ~ Integer
      (_, sNatToRep) <- tcInstNewTyCon_maybe sNatTc [ty]
      return (mkTcSymCo (mkTcTransCo dictToSNat sNatToRep))

{- |
Given:

* A "magic" class, and corresponding instance dictionary function, for a
  type-level boolean operation
* The evidence for the instance's context

makeOpDictByFiat instantiates the dictionary function with that evidence, and
coerces it to a KnownBool dictionary. i.e. for KnownBoolNat2, the "magic"
dictionary for binary functions, the coercion happens in the following steps:

1. KnownBoolNat2 "<=?" x y     -> SBoolF "<=?"
2. SBoolF "<=?"                -> Bool
3. Bool                        -> SBool (x <=? y)  THE BY FIAT PART!
4. SBool (x <=? y)             -> KnownBool (x <=? y)

Step 3 is not justified by any newtype. It is an unchecked representational
coercion ('mkUnivCo' with a 'PluginProv' provenance) asserted by this plugin.

This process is mirrored for the dictionary functions of a higher arity.
On GHC versions older than 8.6 this always returns 'Nothing'.
-}
makeOpDictByFiat
  :: (Class,DFunId)
  -- ^ "magic" class function and dictionary function id
  -> Class
  -- ^ KnownBool class
  -> [Type]
  -- ^ Argument types for the Class
  -> [Type]
  -- ^ Argument types for the Instance
  -> Type
  -- ^ Type of the result
#if MIN_VERSION_ghc(8,6,0)
  -> [EvExpr]
#else
  -> [EvTerm]
#endif
  -- ^ Evidence arguments
  -> Maybe EvTerm
#if MIN_VERSION_ghc(8,6,0)
makeOpDictByFiat (opCls,dfid) knCls tyArgsC tyArgsI z evArgs = do
  -- KnownBool z ~ SBool z
  (_, knDictToSBool) <- tcInstNewTyCon_maybe (classTyCon knCls) [z]
  -- The single method: forall b. KnownBool b => SBool b
  [knMeth] <- pure (classMethods knCls)
  -- SBool, read off the method's result type
  sBoolTc <- tyConAppTyCon_maybe (funResultTy (dropForAlls (idType knMeth)))
  -- KnownBoolNat2 f x y ~ SBoolF f
  (_, opDictToSBoolF) <- tcInstNewTyCon_maybe (classTyCon opCls) tyArgsC
  -- The single method: forall f a b . KnownBoolNat2 f a b => SBoolF f
  [opMeth] <- pure (classMethods opCls)
  -- (SBoolF, [f])
  (sBoolFTc, sBoolFArgs) <-
    splitTyConApp_maybe (funResultTy (idType opMeth `piResultTys` tyArgsC))
  -- SBoolF f ~ Bool
  (_, sBoolFToBool) <- tcInstNewTyCon_maybe sBoolFTc sBoolFArgs
  -- The instantiated dictionary function: KnownBoolNat2 f x y
  EvExpr dfunInst <- pure (evDFunApp dfid tyArgsI evArgs)
  let -- SBool z ~R Bool: asserted by fiat, there is no newtype to unwrap
      sBoolToBool = mkUnivCo (PluginProv "ghc-typelits-knownnat")
                             Representational
                             (mkTyConApp sBoolTc [z])
                             boolTy
      -- KnownBoolNat2 f x y ~ SBoolF f ~ Bool ~ SBool z ~ KnownBool z
      opToKn = mkTcTransCo (mkTcTransCo opDictToSBoolF sBoolFToBool)
                           (mkTcSymCo (mkTcTransCo knDictToSBool sBoolToBool))
  return (mkEvCast dfunInst opToKn)
#else
makeOpDictByFiat _ _ _ _ _ _ = Nothing
#endif