{-# LANGUAGE CPP, DataKinds, TemplateHaskell, QuasiQuotes, TypeApplications, TypeOperators #-}
{-# OPTIONS_GHC -Wno-incomplete-uni-patterns #-}
{-# OPTIONS_HADDOCK hide #-}
module Data.Sum.Templates
( mkElemIndexTypeFamily
, mkApplyInstance
) where

import Control.Monad
import Data.Kind
import Data.Traversable
import Language.Haskell.TH hiding (Type)
import qualified Language.Haskell.TH as TH (Type)
import Language.Haskell.TH.Quote
import Unsafe.Coerce (unsafeCoerce)
import GHC.TypeLits

{- This generates a type family of the form

type family ElemIndex (t :: GHC.Types.Type
                            -> GHC.Types.Type) (ts :: [GHC.Types.Type
                                                       -> GHC.Types.Type]) :: Nat where
  ElemIndex t0 ('(:) t0 _) = 0
  ElemIndex t1 ('(:) t0 ('(:) t1 _)) = 1
  ElemIndex t2 ('(:) t0 ('(:) t1 ('(:) t2 _))) = 2
  ElemIndex t3 ('(:) t0 ('(:) t1 ('(:) t2 ('(:) t3 _)))) = 3
  ElemIndex t4 ('(:) t0 ('(:) t1 ('(:) t2 ('(:) t3 ('(:) t4 _))))) = 4
  etc...
  ElemIndex t ts = TypeError ('(:$$:) ('(:<>:) ('(:<>:) ('Text "'") ('ShowType t)) ('Text "' is not a member of the type-level list")) ('ShowType ts))

-}
-- | Generate the closed @ElemIndex@ type family declaration.
--
-- The family takes a type constructor @t :: Type -> Type@ and a type-level
-- list @ts :: [Type -> Type]@ and returns a 'Nat'. One equation is emitted
-- for each index in @[0 .. paramN - 1]@, followed by a final catch-all
-- equation whose right-hand side is a 'TypeError' naming the missing member
-- and the list that was searched.
--
-- When @paramN <= 0@ the index range is empty, so only the catch-all
-- 'TypeError' equation is generated.
--
-- The result is a single-element declaration list, suitable for a
-- top-level splice.
mkElemIndexTypeFamily :: Integer -> DecsQ
mkElemIndexTypeFamily paramN = do
  -- Start by declaring some names. These are plain 'mkName' names (not
  -- gensyms) so that the generated family reads naturally.
  let elemIndex = mkName "ElemIndex"
      t = mkName "t"
      ts = mkName "ts"
      -- Helper for building more readable type names (t0, t1, ...) rather
      -- than verbose gensyms.
      mkT :: Integer -> Q TH.Type
      mkT = pure . VarT . mkName . ('t' :) . show
      -- We want to make the kind signatures explicit here.
      binders =
        [ kindedTV t <$> [t| Type -> Type |]
        , kindedTV ts <$> [t| [Type -> Type] |]
        ]
      -- This family ends up returning a Nat.
      resultKind = kindSig <$> [t| Nat |]
      -- We have to build n ElemIndex entries; the error equation goes last
      -- so that it only matches when no indexed equation does.
      equations :: [TySynEqnQ]
      equations = fmap buildEquation [0 .. pred paramN] ++ [errorCase]
      errorBody :: Q TH.Type
      errorBody = [t|
        TypeError ('Text "'" ':<>: ('ShowType $(varT t)) ':<>:
                   'Text "' is not a member of the type-level list" ':$$:
                   'ShowType $(varT ts))
        |]
      -- The tySynEqn API changed in template-haskell 2.15, so we need a
      -- guard here.
      --
      -- * buildEquation builds a single family equation mapping the n-th
      --   position to the literal n; it uses lhsMatch for the left-hand side.
      -- * lhsMatch produces the pattern  ElemIndex tn (t0 ': ... ': tn ': _).
      -- * errorCase is the catch-all equation appended above to provide a
      --   readable error.
#if MIN_VERSION_template_haskell(2,15,0)
      buildEquation n = tySynEqn Nothing (lhsMatch n) . litT . numTyLit $ n
      lhsMatch n = [t| $(conT elemIndex) $(mkT n) $(typeListT WildCardT <$> traverse mkT [0..n]) |]
      errorCase = tySynEqn Nothing [t| $(conT elemIndex) $(varT t) $(varT ts) |] errorBody
#else
      buildEquation n = tySynEqn (lhsMatch n) (litT . numTyLit $ n)
      lhsMatch n = [mkT n, typeListT WildCardT <$> traverse mkT [0..n] ]
      errorCase = tySynEqn [varT t, varT ts] errorBody
#endif

  -- Run the binder and result-kind quotes, then assemble the declaration.
  bndrs <- sequenceA binders
  kind <- resultKind
  familyDec <- closedTypeFamilyD elemIndex bndrs kind Nothing equations
  pure [familyDec]


-- | Build an @Apply@ instance declaration for a type-level list of
-- @paramN@ members.
--
-- The generated declaration has the shape
--
-- > instance (constraint f0, ..., constraint fN) => Apply constraint '[f0, ..., fN] where
-- >   apply f (Sum 0 r) = f (unsafeCoerce r :: f0 a)
-- >   ...
-- >   apply f (Sum N r) = f (unsafeCoerce r :: fN a)
-- >   {-# INLINABLE apply #-}
--
-- where @N = paramN - 1@. All names are created with 'mkName', so they
-- resolve against whatever @Apply@, @apply@ and @Sum@ are in scope at the
-- splice site.
--
-- When @paramN <= 0@ the parameter list is empty and the 'FunD' carries no
-- clauses.
mkApplyInstance :: Integer -> Dec
mkApplyInstance paramN =
  InstanceD Nothing instanceCxt instanceHead
    [ FunD apply (zipWith mkClause [0 ..] typeParams)
    , PragmaD (InlineP apply Inlinable FunLike AllPhases)
    ]
  where
    -- Type variables f0 .. f(paramN - 1), one per member of the list.
    typeParams :: [TH.Type]
    typeParams = VarT . mkName . ('f' :) . show <$> [0 .. pred paramN]

    -- One @constraint fi@ premise per member.
    instanceCxt = AppT constraint <$> typeParams

    -- @Apply constraint '[f0, ..., fN]@
    instanceHead =
      AppT (AppT (ConT applyC) constraint) (typeListT PromotedNilT typeParams)

    applyC = mkName "Apply"
    apply = mkName "apply"
    f = mkName "f"
    r = mkName "r"
    union = mkName "Sum"

    constraint = VarT (mkName "constraint")
    a = VarT (mkName "a")

    -- Clause for the i-th member: match the literal tag i, then hand the
    -- payload to @f@ at type @nthType a@. The payload is converted with
    -- 'unsafeCoerce' and annotated with the member type selected by the tag.
    -- NOTE(review): soundness relies on the tag always agreeing with the
    -- stored payload's type -- confirm against the definition of @Sum@.
    mkClause :: Integer -> TH.Type -> Clause
    mkClause i nthType = Clause
      -- ConP gained a list of type applications in template-haskell 2.18.
#if MIN_VERSION_template_haskell(2,18,0)
      [ VarP f, ConP union [] [ LitP (IntegerL i), VarP r ] ]
#else
      [ VarP f, ConP union [ LitP (IntegerL i), VarP r ] ]
#endif
      (NormalB (AppE (VarE f) (SigE (AppE (VarE 'unsafeCoerce) (VarE r)) (AppT nthType a))))
      []

-- | Build a promoted type-level list from a tail and a list of element types.
--
-- The first argument is the type placed at the end of the spine (for
-- example 'PromotedNilT' for a closed list, or 'WildCardT' to match any
-- remainder). Each element is prepended with 'PromotedConsT', so
--
-- > typeListT end [x, y]
--
-- is @AppT (AppT PromotedConsT x) (AppT (AppT PromotedConsT y) end)@, and an
-- empty element list returns the tail unchanged.
typeListT :: TH.Type -> [TH.Type] -> TH.Type
typeListT = foldr promotedCons
  where
    -- @hd ': rest@ expressed as nested type applications.
    promotedCons hd = AppT (AppT PromotedConsT hd)