{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE QuasiQuotes #-}
{-# LANGUAGE TemplateHaskell #-}

module Data.Constraint.Extras.TH (deriveArgDict, deriveArgDictV, gadtIndices) where

import Control.Monad
import Data.List (foldl', nub)
import Data.Maybe

import Data.Constraint
import Data.Constraint.Extras
import Language.Haskell.TH

-- | Derive an 'ArgDict' instance for a GADT. The argument is either the
-- name of the @data@ type or a data-family-instance constructor (see
-- 'getDeclInfo').
--
-- The generated instance has the shape
--
-- > instance ArgDict c T where
-- >   type ConstraintsFor T c = (c Int, ConstraintsFor Tag c, ...)
-- >   argDict = \case ...
--
-- The constraints come from 'gadtIndices' and the method body from 'matches'.
-- Duplicate constraints are emitted only once. A single constraint is used
-- as-is rather than wrapped in a 1-tuple. Long constraint lists are split
-- into nested tuples so that GHC's constraint-tuple arity limit is never
-- exceeded.
deriveArgDict :: Name -> Q [Dec]
deriveArgDict n = do
  (typeHead, constrs) <- getDeclInfo n
  c <- newName "c"
  ts <- gadtIndices c constrs
  let toConstraint :: Either Type Type -> Type
      toConstraint = \case
        -- The constructor recurses into a field of type @t b@: require
        -- everything that field's own instance requires.
        Left t -> AppT (AppT (ConT ''ConstraintsFor) t) (VarT c)
        -- The constructor fixes the index to @t@: require @c t@ directly.
        Right t -> AppT (VarT c) t
      -- Several constructors frequently share an index; keep one copy each.
      constraints = conjoin (nub (map toConstraint ts))
  [d| instance ArgDict $(varT c) $(pure typeHead) where
        type ConstraintsFor  $(pure typeHead) $(varT c) = $(pure constraints)
        argDict = $(LamCaseE <$> matches c constrs 'argDict)
    |]
  where
    -- Combine constraints into a single constraint: @()@ for none, the
    -- constraint itself for one, a (possibly nested) tuple otherwise.
    conjoin :: [Type] -> Type
    conjoin [single] = single
    conjoin cs
      | arity <= maxConstraintTupleArity = foldl' AppT (TupleT arity) cs
      | otherwise = conjoin (map conjoin (chunk cs))
      where
        arity = length cs

    -- Split into groups no larger than the allowed tuple arity.
    chunk :: [Type] -> [[Type]]
    chunk [] = []
    chunk cs =
      let (here, rest) = splitAt maxConstraintTupleArity cs
       in here : chunk rest

    -- GHC rejects constraint tuples above a fixed arity (62 on older
    -- releases, 64 on newer ones -- NOTE(review): confirm against the
    -- supported GHC range); stay at the conservative value.
    maxConstraintTupleArity :: Int
    maxConstraintTupleArity = 62

{-# DEPRECATED deriveArgDictV "Just use 'deriveArgDict'" #-}
-- | Older name for 'deriveArgDict'; it forwards its argument unchanged.
deriveArgDictV :: Name -> Q [Dec]
deriveArgDictV n = deriveArgDict n

-- | Build the alternatives of the @\\case@ expression that becomes the body of
-- the generated 'argDict' method. Every name of every constructor gets one
-- alternative:
--
-- * A constructor whose result index is not a bare type variable (e.g.
--   @C :: ... -> T Int@) returns 'Dict'. The matching @c Int@ is supplied by
--   the generated @ConstraintsFor@ instance.
-- * A constructor polymorphic in its index @b@ binds its first field of type
--   @t b@ for which an @ArgDict c t@ instance exists, and applies
--   @argDictName@ to it. All other fields are ignored. If no field qualifies,
--   it returns 'Dict'.
--
-- Arguments: the constraint variable @c@, the constructors to cover, and the
-- name of the method to recurse with (normally @'argDict@).
--
-- Constructor patterns are built with the 'conP' smart constructor rather
-- than 'ConP' directly. The number of 'ConP' fields differs between
-- template-haskell releases, and 'conP' hides that difference.
--
-- Calls 'error' on any other constructor shape (e.g. record-syntax GADT
-- constructors).
matches :: Name -> [Con] -> Name -> Q [Match]
matches c constrs argDictName = do
  x <- newName "x"
  -- @C {}@ matches @C@ whatever its arity.
  let dictMatch :: Name -> Match
      dictMatch name = Match (RecP name []) (NormalB (ConE 'Dict)) []

      -- Bind the first flagged field to @x@ and wildcard every other field.
      fieldPats :: [Bool] -> [Pat]
      fieldPats flags = case break id flags of
        (before, _ : after) ->
          map (const WildP) before ++ VarP x : map (const WildP) after
        (before, []) -> map (const WildP) before

      -- @C _ x _ ... -> argDict x@
      recurseMatch :: [Bool] -> Name -> Q Match
      recurseMatch flags name = do
        pat <- conP name (map pure (fieldPats flags))
        pure (Match pat (NormalB (AppE (VarE argDictName) (VarE x))) [])

      -- Is this field of the form @t b@ with an @ArgDict c t@ instance?
      isDictField :: Name -> Type -> Q Bool
      isDictField b = \case
        AppT t (VarT b')
          | b == b' -> not . null <$> reifyInstances ''ArgDict [VarT c, t]
        _ -> pure False

  fmap concat $ forM constrs $ \case
    GadtC names _ _ -> pure (map dictMatch names)
    ForallC _ _ (GadtC names bts (AppT _ (VarT b))) -> do
      flags <- mapM (isDictField b . snd) bts
      if or flags
        then mapM (recurseMatch flags) names
        else pure (map dictMatch names)
    ForallC _ _ (GadtC names _ _) -> pure (map dictMatch names)
    a -> error $ "deriveArgDict matches: Unmatched 'Dec': " ++ show a

-- | Count the arrows along the result spine of a kind. This is the number
-- of arguments a type of that kind accepts, e.g. @Type -> Type -> Type@
-- gives 2. @forall@s, kind signatures and parentheses are looked through.
-- Any other kind counts as 0.
kindArity :: Kind -> Int
kindArity (ForallT _ _ body) = kindArity body
kindArity (AppT (AppT ArrowT _) result) = succ (kindArity result)
kindArity (SigT inner _) = kindArity inner
kindArity (ParensT inner) = kindArity inner
kindArity _ = 0

-- | Resolve the name passed to 'deriveArgDict'. Returns the type to use as
-- the instance head and the constructors to generate 'argDict' cases for.
--
-- * For a @data@ declaration, the head is the type constructor applied to
--   fresh type variables for every argument except the last (the GADT
--   index). The argument count is the number of declared binders plus the
--   arity of the optional kind signature.
-- * For a data constructor whose reified type is a @forall@, the head is
--   taken from that type. The constructors are those of the unique
--   data/newtype family instance, of the constructor's parent family, that
--   declares it.
--
-- Anything else is reported with 'error'.
getDeclInfo :: Name -> Q (Type, [Con])
getDeclInfo n = do
  info <- reify n
  case info of
    TyConI (DataD _ _ binders mKind constrs _) -> do
      let arity = fromMaybe 0 (kindArity <$> mKind) + length binders
      vars <- replicateM (arity - 1) (newName "a")
      -- The variables are applied in reverse order of creation. They are all
      -- fresh, so only their number matters.
      let typeHead = foldr (\v acc -> acc `AppT` VarT v) (ConT n) vars
      pure (typeHead, constrs)
    DataConI _ (ForallT _ _ (AppT typeHead _)) parent ->
      reify parent >>= fromFamilyParent typeHead
    a -> error $ "getDeclInfo: Unmatched 'Info': " ++ show a
  where
    -- Find the single family instance that declares constructor @n@.
    fromFamilyParent :: Type -> Info -> Q (Type, [Con])
    fromFamilyParent typeHead = \case
      FamilyI _ instances ->
        case filter (\i -> n `elem` concatMap conNames (instCons i)) instances of
          [i] -> pure (typeHead, instCons i)
          [] ->
            error $
              "getDeclInfo: Couldn't find data family instance for constructor "
                ++ show n
          l ->
            error $
              "getDeclInfo: Expected one data family instance for constructor "
                ++ show n
                ++ " but found multiple: "
                ++ show l
      a -> error $ "getDeclInfo: Unmatched parent of data family instance: " ++ show a

    -- Constructors declared by a data or newtype family instance.
    instCons :: InstanceDec -> [Con]
    instCons = \case
      DataInstD _ _ _ _ cons _ -> cons
      NewtypeInstD _ _ _ _ con _ -> [con]
      _ -> error $ "getDeclInfo: Expected a data or newtype family instance"

    -- Every constructor name introduced by a 'Con'.
    conNames :: Con -> [Name]
    conNames = \case
      NormalC name _ -> [name]
      RecC name _ -> [name]
      InfixC _ name _ -> [name]
      ForallC _ _ con -> conNames con
      GadtC names _ _ -> names
      RecGadtC names _ _ -> names

-- | For each constructor of a GADT, compute what the generated
-- @ConstraintsFor@ instance must require so that the constructor's 'argDict'
-- alternative type-checks. The first argument is the constraint variable @c@.
--
-- * @'Right' t@: the constructor's result index is the type @t@, so the
--   instance needs @c t@.
-- * @'Left' t@: the constructor is polymorphic in its index variable @b@.
--   Its first field of type @t b@ with an @ArgDict c t@ instance is the
--   field the generated 'argDict' alternative recurses into, so the instance
--   needs @ConstraintsFor t c@.
--
-- Only that first field is reported, and only fields indexed by @b@ itself
-- are considered. This is the same rule the 'argDict' alternative uses to
-- pick its field, so @ConstraintsFor@ contains no constraints the method
-- never uses. Constructors of any other shape contribute nothing.
gadtIndices :: Name -> [Con] -> Q [Either Type Type]
gadtIndices c constrs = concat <$> mapM indicesOf constrs
  where
    indicesOf :: Con -> Q [Either Type Type]
    indicesOf = \case
      GadtC _ _ (AppT _ typ) -> pure [Right typ]
      ForallC _ _ (GadtC _ bts (AppT _ (VarT b))) ->
        maybe [] (\t -> [Left t]) <$> firstDictField b (map snd bts)
      ForallC _ _ (GadtC _ _ (AppT _ typ)) -> pure [Right typ]
      _ -> pure []

    -- The first field type @t b@ (with @b@ the constructor's index variable)
    -- whose head @t@ has an @ArgDict c t@ instance, if any.
    firstDictField :: Name -> [Type] -> Q (Maybe Type)
    firstDictField _ [] = pure Nothing
    firstDictField b (field : rest) = case field of
      AppT t (VarT b')
        | b == b' -> do
            hasInstance <- not . null <$> reifyInstances ''ArgDict [VarT c, t]
            if hasInstance then pure (Just t) else firstDictField b rest
      _ -> firstDictField b rest