{-# LANGUAGE CPP #-}
{-# LANGUAGE TupleSections #-}

#if __GLASGOW_HASKELL__ >= 800
{-# LANGUAGE TemplateHaskellQuotes #-}
#endif

{-|
Module:      Data.Deriving.Via.Internal
Copyright:   (C) 2015-2017 Ryan Scott
License:     BSD-style (see the file LICENSE)
Maintainer:  Ryan Scott
Portability: Template Haskell

On @template-haskell-2.12@ or later (i.e., GHC 8.2 or later), this module
exports functionality which emulates the @GeneralizedNewtypeDeriving@ and
@DerivingVia@ GHC extensions (the latter of which was introduced in GHC 8.6).

On older versions of @template-haskell@/GHC, this module does not export
anything.

Note: this is an internal module, and as such, the API presented here is not
guaranteed to be stable, even between minor releases of this library.
-}
module Data.Deriving.Via.Internal where

#if MIN_VERSION_template_haskell(2,12,0)
import           Control.Monad ((<=<), unless)

import           Data.Deriving.Internal
import qualified Data.List as L
import qualified Data.Map as M
import           Data.Map (Map)
import           Data.Maybe (catMaybes)

import           GHC.Exts (Any)

import           Language.Haskell.TH
import           Language.Haskell.TH.Datatype
import           Language.Haskell.TH.Datatype.TyVarBndr

-------------------------------------------------------------------------------
-- Code generation
-------------------------------------------------------------------------------

{- | Generates an instance for a type class at a newtype by emulating the
behavior of the @GeneralizedNewtypeDeriving@ extension. For example:

@
newtype Foo a = MkFoo a
$('deriveGND' [t| forall a. 'Eq' a => 'Eq' (Foo a) |])
@

The instance head is expanded (infix applications and type synonyms resolved)
before the instance body is generated, but the emitted instance declaration
itself uses the context and head exactly as they were written by the caller.
-}
deriveGND :: Q Type -> Q [Dec]
deriveGND qty = do
  -- Split off the optional outer @forall@/context; the binders are unused.
  (_, ctxt, headTy) <- decomposeType `fmap` qty
  -- Only the body generation works on the expanded head.
  expandedHead <- resolveInfixT headTy >>= resolveTypeSynonyms
  -- 'Nothing' selects GND mode: the representation type is looked up from
  -- the newtype itself.
  bodyDecs <- deriveViaDecs expandedHead Nothing
  inst <- instanceD (return ctxt) (return headTy) (map return bodyDecs)
  return [inst]

{- | Generates an instance for a type class by emulating the behavior of the
@DerivingVia@ extension. For example:

@
newtype Foo a = MkFoo a
$('deriveVia' [t| forall a. 'Ord' a => 'Ord' (Foo a) ``Via`` Down a |])
@

As shown in the example above, the syntax is a tad strange. One must specify
the type by which to derive the instance using the 'Via' type. This
requirement is in place to ensure that the type variables are scoped
correctly across all the types being used (e.g., to make sure that the same
@a@ is used in @'Ord' a@, @'Ord' (Foo a)@, and @Down a@).

Type variables that occur only in the @via@ type (and in neither the instance
context nor the instance head) are replaced with 'Any' before the instance
body is generated.
-}
deriveVia :: Q Type -> Q [Dec]
deriveVia qty = do
  -- Split off the optional outer @forall@/context; the binders are unused.
  (_, ctxt, viaApp) <- decomposeType `fmap` qty
  expandedViaApp <- resolveInfixT viaApp >>= resolveTypeSynonyms
  (headTy, viaTy) <- splitViaApp (unapplyTy expandedViaApp)
  -- This is a stronger requirement than what GHC's implementation of
  -- DerivingVia imposes, but due to Template Haskell restrictions, we
  -- currently can't do better. See #27.
  let boundFVs    = freeVariables ctxt ++ freeVariables headTy
      floatingFVs = freeVariables viaTy L.\\ boundFVs
      anySubst    = M.fromList [ (fv, ConT ''Any) | fv <- floatingFVs ]
      closedViaTy = applySubstitution anySubst viaTy
  bodyDecs <- deriveViaDecs headTy (Just closedViaTy)
  inst <- instanceD (return ctxt) (return headTy) (map return bodyDecs)
  return [inst]
  where
    -- Accept exactly @Via instanceHead viaType@; anything else is a usage
    -- error that is reported to the user.
    splitViaApp :: (Type, [Type]) -> Q (Type, Type)
    splitViaApp (via, [headTy, viaTy])
      | via == ConT viaTypeName = return (headTy, viaTy)
    splitViaApp _ = fail $ unlines
      [ "Failure to meet ‘deriveVia‘ specification"
      , "\tThe ‘Via‘ type must be used, e.g."
      , "\t[t| forall a. C (T a) `Via` V a |]"
      ]

-- | Produces the declarations that make up the body of a derived instance.
--
-- The instance head must be a class constructor applied to at least one
-- argument; the last argument is the type the instance is being derived for.
-- In GND mode that last argument must be headed by a newtype whose
-- representation type can be eta-reduced far enough to match the kind of the
-- class's last type variable. Every failure is reported through 'fail' (or
-- 'etaReductionError').
deriveViaDecs :: Type       -- ^ The instance head (e.g., @Eq (Foo a)@)
              -> Maybe Type -- ^ If using 'deriveGND', this is 'Nothing'.
                            --   If using 'deriveVia', this is 'Just' the @via@ type.
              -> Q [Dec]
deriveViaDecs instanceTy mbViaTy =
  case clsTy of
    ConT clsName -> reify clsName >>= fromClassInfo clsName
    _            -> fail $ "Malformed instance: " ++ pprint instanceTy
  where
    (clsTy, clsArgs) = unapplyTy instanceTy

    -- Given the reified class, pick the representation type and derive one
    -- group of declarations per class declaration.
    fromClassInfo :: Name -> Info -> Q [Dec]
    fromClassInfo clsName (ClassI (ClassD _ _ clsTvbs _ clsDecs) _)
      | Just (_, dataApp)    <- unsnoc clsArgs
      , Just (_, clsLastTvb) <- unsnoc clsTvbs
      = do let (_, kindList)      = uncurryTy (tvbKind clsLastTvb)
               -- One fewer than the number of components in the kind of the
               -- class's last type variable.
               numArgsToEtaReduce = length kindList - 1
           repTy  <- maybe (gndRepType numArgsToEtaReduce dataApp) return mbViaTy
           perDec <- traverse (deriveViaDecs' clsName clsTvbs clsArgs repTy) clsDecs
           return (concat (catMaybes perDec))
      | otherwise
      = fail $ "Cannot derive instance for nullary class " ++ pprint clsTy
    fromClassInfo _ _ = fail $ "Not a type class: " ++ pprint clsTy

    -- GND mode: look up the newtype's representation type, eta-reduce it, and
    -- instantiate the newtype's own type variables with the arguments that
    -- appear in the instance head.
    gndRepType :: Int -> Type -> Q Type
    gndRepType numArgsToEtaReduce dataApp =
      case unapplyTy dataApp of
        (ConT dataName, dataArgs) -> do
          info <- reifyDatatype dataName
          case newtypeRepType (datatypeVariant info) (datatypeCons info) of
            Nothing -> fail $ "Not a newtype: " ++ nameBase dataName
            Just newtypeRepTy ->
              case etaReduce numArgsToEtaReduce newtypeRepTy of
                Nothing       -> etaReductionError instanceTy
                Just etaRepTy ->
                  let repTySubst = M.fromList
                        [ (varTToName var, arg)
                        | (var, arg) <- zip (datatypeInstTypes info) dataArgs
                        ]
                  in return (applySubstitution repTySubst etaRepTy)
        (dataTy, _) -> fail $ "Not a data type: " ++ pprint dataTy

-- | Derives the instance-body declarations corresponding to a single
-- declaration taken from a class definition.
--
-- * An open (associated) type family yields one type family instance whose
--   left-hand side uses the instance's class arguments and whose right-hand
--   side is the same family applied with the last class argument replaced by
--   the representation type.
--
-- * A method signature yields a signature plus a binding that applies the
--   coercion function (at explicit \"from\" and \"to\" types) to the method.
--
-- * Any other declaration yields 'Nothing'.
--
-- Before any of that, the number of class type variables is checked against
-- the number of supplied class arguments; a mismatch is reported via 'fail'.
deriveViaDecs' :: Name -> [TyVarBndr_ flag] -> [Type] -> Type -> Dec -> Q (Maybe [Dec])
deriveViaDecs' clsName clsTvbs clsArgs repTy dec = do
    unless (expected == supplied) $ fail arityMismatch
    case dec of
      OpenTypeFamilyD (TypeFamilyHead tfName tfTvbs _ _) -> do
        let famArgs = map tvbToType tfTvbs
            lhsTys  = map (applySubstitution instSubst) famArgs
            rhsTy   = applyTy (ConT tfName)
                              (map (applySubstitution repSubst) famArgs)
        tfInst <- tySynInstDCompat tfName Nothing (map pure lhsTys) (pure rhsTy)
        pure (Just [tfInst])

      SigD methName methTy ->
        let (fromTy, toTy) =
              mkCoerceClassMethEqn clsTvbs clsArgs repTy (stripOuterForallT methTy)
            -- The coercion function applied to the "from" and then the "to"
            -- type, each with any outer forall stripped.
            coerceAt = AppTypeE (AppTypeE (VarE coerceValName)
                                          (stripOuterForallT fromTy))
                                (stripOuterForallT toTy)
            methSig  = SigD methName toTy
            methDef  = ValD (VarP methName)
                            (NormalB (coerceAt `AppE` VarE methName))
                            []
        in return (Just [methSig, methDef])

      _ -> return Nothing
  where
    expected, supplied :: Int
    expected = length clsTvbs
    supplied = length clsArgs

    arityMismatch :: String
    arityMismatch =
         "Mismatched number of class arguments"
      ++ "\n\tThe class " ++ nameBase clsName ++ " expects " ++ show expected ++ " argument(s),"
      ++ "\n\tbut was provided " ++ show supplied ++ " argument(s)."

    -- Class variables mapped to the arguments written in the instance head.
    instSubst :: Map Name Type
    instSubst = zipTvbSubst clsTvbs clsArgs

    -- The same mapping, but with the last argument swapped for the
    -- representation type.
    repSubst :: Map Name Type
    repSubst = zipTvbSubst clsTvbs (changeLast clsArgs repTy)

-- | Instantiates a class method's type twice and returns both results as a
-- pair @(from, to)@.
--
-- The first component substitutes the class type variables with the class
-- arguments in which the last argument has been replaced by the
-- representation type; the second component substitutes them with the class
-- arguments unchanged.
mkCoerceClassMethEqn :: [TyVarBndr_ flag] -> [Type] -> Type -> Type -> (Type, Type)
mkCoerceClassMethEqn clsTvbs clsArgs repTy methTy =
    (instantiate (changeLast clsArgs repTy), instantiate clsArgs)
  where
    -- Substitute the class type variables in the method type with the given
    -- arguments, pairing them up positionally.
    instantiate :: [Type] -> Type
    instantiate args = applySubstitution (zipTvbSubst clsTvbs args) methTy

-- | Builds a substitution that maps the name of each type variable binder to
-- the type at the same position in the second list. Pairing stops at the end
-- of the shorter list.
zipTvbSubst :: [TyVarBndr_ flag] -> [Type] -> Map Name Type
zipTvbSubst tvbs tys = M.fromList (zip (map tvName tvbs) tys)

-- | Replace the last element of a list with another element.
--
-- Calling this on an empty list is an error.
changeLast :: [a] -> a -> [a]
changeLast []         _           = error "changeLast"
-- The list is known to be non-empty here, so 'init' is safe.
changeLast xs@(_ : _) replacement = init xs ++ [replacement]

-- | Removes a single outermost @forall@ (binders and context) from a type on
-- compilers older than GHC 8.7; on newer compilers the type is returned
-- unchanged.
stripOuterForallT :: Type -> Type
stripOuterForallT ty =
#if __GLASGOW_HASKELL__ < 807
  -- Before GHC 8.7, TH-reified classes would put a redundant forall/class
  -- context in front of each method's type signature, so we have to strip them
  -- off here.
  case ty of
    ForallT _ _ body -> body
    _                -> ty
#else
  ty
#endif

-- | Splits a type into the binders and context of its outermost @forall@ and
-- the type underneath. A type without an outer @forall@ is returned with
-- empty binder and context lists.
decomposeType :: Type -> ([TyVarBndrSpec], Cxt, Type)
decomposeType ty =
  case ty of
    ForallT tvbs ctxt body -> (tvbs, ctxt, body)
    _                      -> ([], [], ty)

-- | Returns the representation type of a newtype, i.e. the type of its single
-- field.
--
-- The result is 'Nothing' unless the variant is 'Newtype' or
-- 'NewtypeInstance' and there is exactly one constructor with no
-- constructor-bound type variables, an empty context, and exactly one field.
newtypeRepType :: DatatypeVariant -> [ConstructorInfo] -> Maybe Type
newtypeRepType dv cons
  | isNewtypeVariant
  , [ ConstructorInfo { constructorVars    = []
                      , constructorContext = []
                      , constructorFields  = [repTy]
                      } ] <- cons
  = Just repTy
  | otherwise
  = Nothing
  where
    isNewtypeVariant :: Bool
    isNewtypeVariant =
      case dv of
        Newtype         -> True
        NewtypeInstance -> True
        _               -> False

-- | Attempts to drop the last @num@ arguments from a type application.
--
-- The type is split into its head and arguments, and the trailing @num@
-- arguments are removed if 'canEtaReduce' permits it. The result is 'Nothing'
-- when the arguments cannot be dropped, including when the type has fewer
-- than @num@ arguments in the first place.
etaReduce :: Int -> Type -> Maybe Type
etaReduce num ty
  -- Without this guard, 'splitAt' would be handed a negative index and would
  -- silently drop fewer arguments than requested (e.g. reducing a bare type
  -- variable by one argument would "succeed" and return it unchanged),
  -- hiding the eta-reduction failure from the caller.
  | num > numArgs
  = Nothing
  | canEtaReduce tyArgsRemaining tyArgsDropped
  = Just (applyTy tyHead tyArgsRemaining)
  | otherwise
  = Nothing
  where
    (tyHead, tyArgs) = unapplyTy ty

    numArgs :: Int
    numArgs = length tyArgs

    (tyArgsRemaining, tyArgsDropped) = splitAt (numArgs - num) tyArgs
#endif