{-# LANGUAGE CPP, TemplateHaskell #-}
-- | For the generated instances you'll typically need the following
-- extensions:
--
-- >{-# LANGUAGE TemplateHaskell, MultiParamTypeClasses, FlexibleInstances, ConstraintKinds, UndecidableInstances #-}
module Data.Generics.Traversable.TH
  ( deriveGTraversable
  , gtraverseExpr
  ) where

import Language.Haskell.TH
import Control.Monad
import Data.Generics.Traversable.Core
import Data.List

-- | Abort with 'error', prefixing the message with this module's name so
-- that a failure raised during derivation can be traced back to its origin.
err :: String -> a
err s = error $ "Data.Generics.Traversable.TH: " ++ s

-- | Reify the named type and summarise its declaration as
-- @(type name, type parameter names, constructors)@, where every
-- constructor is described by 'conA' as
-- @(constructor name, number of fields, field types)@.
--
-- The failure cases are reported through 'err'. They are built lazily, so
-- the exception is raised when the returned tuple is inspected:
--
-- * the name does not reify to a plain type constructor ('TyConI');
--
-- * the declaration is neither a @data@ nor a @newtype@ declaration.
getDataInfo :: Name -> Q (Name, [Name], [(Name, Int, [Type])])
getDataInfo name = do
  info <- reify name
  let
    decl =
      case info of
        TyConI d -> d
        _ -> err "can't be used on anything but a type constructor of an algebraic data type"

  return $
    case decl of
#if MIN_VERSION_template_haskell(2,11,0)
      -- From template-haskell 2.11 on, both constructors carry one more
      -- field before the constructor list; it is ignored here.
      DataD    _ n ps _ cs _ -> (n, map varName ps, map conA cs)
      NewtypeD _ n ps _ c  _ -> (n, map varName ps, [conA c])
#else
      DataD    _ n ps   cs _ -> (n, map varName ps, map conA cs)
      NewtypeD _ n ps   c  _ -> (n, map varName ps, [conA c])
#endif
      _ -> err ("not a data type declaration: " ++ show decl)

-- | Return a lambda expression which implements 'gtraverse' for the given
-- data type.
--
-- The generated expression takes the function to apply to every field and
-- the value to traverse, and has one @case@ alternative per constructor:
--
-- >\f x -> case x of
-- >  Con a1 ... an -> pure Con <*> f a1 <*> ... <*> f an
gtraverseExpr :: Name -> Q Exp
gtraverseExpr typeName = do
  -- Only the constructor list is needed; the type's own name and its
  -- parameters are ignored.
  (_name, _params, constructors) <- getDataInfo typeName
  f <- newName "f"
  x <- newName "x"

  let
    lam = lamE [varP f, varP x] $ caseE (varE x) matches

    -- Con a1 ... -> pure Con <*> f a1 <*> ...
    mkMatch (c, n, _)
     = do -- one fresh pattern variable per constructor field
          args <- replicateM n (newName "arg")
          let
            -- e <*> f arg
            applyF e arg =
              varE '(<*>) `appE` e `appE`
                (varE f `appE` varE arg)
            -- start from @pure Con@ and apply the fields left to right
            body = foldl applyF [| $(varE 'pure) $(conE c) |] args
          match (conP c $ map varP args) (normalB body) []
    matches = map mkMatch constructors

  lam

-- | Example usage:
--
-- >data MyType = MyType
-- >
-- >deriveGTraversable ''MyType
--
-- It tries to create the necessary instance constraints, but is not very
-- smart about it. For tricky types, it may fail or produce an
-- overconstrained instance. In that case, write the instance declaration
-- yourself and use 'gtraverseExpr' to derive the implementation:
--
-- >data MyType a = MyType
-- >
-- >instance GTraversable (MyType a) where
-- >  gtraverse = $(gtraverseExpr ''MyType)
deriveGTraversable :: Name -> Q [Dec]
deriveGTraversable name = do
  -- Type variable standing for the constraint parameter of 'GTraversable'.
  ctx <- newName "c"

  (typeName, typeParams, constructors) <- getDataInfo name

  let
    -- The type constructor applied to all of its own type variables.
    appliedType = foldl AppT (ConT typeName) $ map VarT typeParams

    -- instance (...) => GTraversable ctx MyType where { ... }
    inst =
      instanceD context (conT ''GTraversable `appT` varT ctx `appT` pure appliedType) [ do
          -- gtraverse = ...
          funD 'gtraverse [ clause [] (normalB $ gtraverseExpr typeName) [] ]
        ]

    context = sequence userContext

    -- Every distinct field type occurring in any constructor.
    types = nub [ t | (_,_,ts) <- constructors, t <- ts ]

    -- One constraint @ctx t@ per distinct field type; how a constraint is
    -- represented depends on the template-haskell version.
#if MIN_VERSION_template_haskell(2,10,0)
-- see https://ghc.haskell.org/trac/ghc/ticket/9270
    userContext = [ varT ctx `appT` pure t | t <- types ]
#else
    userContext = [ classP ctx [pure t] | t <- types ]
#endif

  sequence [inst]

-- | Describe a constructor as @(name, number of fields, field types)@.
--
-- * Infix constructors are treated as ordinary two-field constructors.
--
-- * For a 'ForallC' wrapper, the bound variables and context are dropped
--   and the wrapped constructor is described instead.
--
-- * Any other constructor form is rejected through 'err'.
conA :: Con -> (Name, Int, [Type])
conA (NormalC c xs)   = (c, length xs, map snd xs)
conA (InfixC x1 c x2) = conA (NormalC c [x1, x2])
conA (ForallC _ _ c)  = conA c
conA (RecC c xs)      = (c, length xs, map (\(_,_,t)->t) xs)
conA _ = err "GADTs are not supported yet"

-- | Extract the variable name from a type variable binder, ignoring any
-- kind annotation (and, from template-haskell 2.17 on, the binder's flag).
#if MIN_VERSION_template_haskell(2,17,0)
varName :: TyVarBndr flag -> Name
varName (PlainTV n _) = n
varName (KindedTV n _ _) = n
#else
varName :: TyVarBndr -> Name
varName (PlainTV n) = n
varName (KindedTV n _) = n
#endif