{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE QuasiQuotes #-}
{-# LANGUAGE TemplateHaskell #-}
module Data.Constraint.Extras.TH (deriveArgDict, deriveArgDictV, gadtIndices) where
import Data.Constraint.Extras
import Control.Monad
import Data.Constraint
import Data.Maybe
import Language.Haskell.TH
import Data.Either
-- | Derive an 'ArgDict' instance for the GADT named by the argument.
--
-- One constraint is generated per entry returned by 'gadtIndices':
--
-- * @Right i@ becomes @c i@ in @ConstraintsFor@ and @c (g i)@ in
--   @ConstraintsFor'@;
-- * @Left t@ defers to a nested GADT: @ConstraintsFor t c@ and
--   @ConstraintsFor' t c g@ respectively.
--
-- The constraints are combined into a constraint tuple, and the method
-- bodies are @\\case@ expressions whose alternatives come from 'matches'.
deriveArgDict :: Name -> Q [Dec]
deriveArgDict n = do
  ts <- gadtIndices n
  c <- newName "c"
  g <- newName "g"
  let constraintFor = \case
        Left t -> AppT (AppT (ConT ''ConstraintsFor) t) (VarT c)
        Right t -> AppT (VarT c) t
      constraintFor' = \case
        Left t -> AppT (AppT (AppT (ConT ''ConstraintsFor') t) (VarT c)) (VarT g)
        Right t -> AppT (VarT c) (AppT (VarT g) t)
      constraints = tupled (map constraintFor ts)
      constraints' = tupled (map constraintFor' ts)
  [d| instance ArgDict $(conT n) where
        type ConstraintsFor $(conT n) $(varT c) = $(pure constraints)
        type ConstraintsFor' $(conT n) $(varT c) $(varT g) = $(pure constraints')
        argDict = $(LamCaseE <$> matches n 'argDict)
        argDict' = $(LamCaseE <$> matches n 'argDict')
    |]
  where
    -- Combine constraints into a single constraint tuple. A lone constraint
    -- is emitted bare instead of as @TupleT 1@, so the result does not depend
    -- on how a particular GHC version interprets a Template Haskell 1-tuple.
    -- An empty list still yields @()@ (the empty constraint).
    tupled [single] = single
    tupled cs = foldl AppT (TupleT (length cs)) cs
-- | Derive an 'ArgDictV' instance for the GADT named by the argument and
-- prepend it to the declarations produced by 'deriveArgDict' for the same
-- type.
--
-- One constraint is generated per entry returned by 'gadtIndices':
--
-- * @Right v@ becomes @c (v g)@;
-- * @Left t@ defers to a nested GADT via @ConstraintsForV t c g@.
--
-- The @argDictV@ body is a @\\case@ whose alternatives come from 'matches'.
deriveArgDictV :: Name -> Q [Dec]
deriveArgDictV n = do
  vs <- gadtIndices n
  c <- newName "c"
  g <- newName "g"
  let constraintFor = \case
        Left t -> AppT (AppT (AppT (ConT ''ConstraintsForV) t) (VarT c)) (VarT g)
        Right v -> AppT (VarT c) (AppT v (VarT g))
      constraints = tupled (map constraintFor vs)
  ds <- deriveArgDict n
  d <- [d| instance ArgDictV $(conT n) where
             type ConstraintsForV $(conT n) $(varT c) $(varT g) = $(pure constraints)
             argDictV = $(LamCaseE <$> matches n 'argDictV)
         |]
  return (d ++ ds)
  where
    -- Combine constraints into a single constraint tuple. A lone constraint
    -- is emitted bare instead of as @TupleT 1@, so the result does not depend
    -- on how a particular GHC version interprets a Template Haskell 1-tuple.
    -- An empty list still yields @()@ (the empty constraint).
    tupled [single] = single
    tupled cs = foldl AppT (TupleT (length cs)) cs
-- | Build the @\\case@ alternatives for an @argDict@-style method of the GADT
-- named by the first argument. The second argument is the method to call
-- when recursing into a nested GADT.
--
-- For each constructor of the reified data declaration:
--
-- * a non-@forall@ constructor whose result type is @T v@ (for a type
--   variable @v@) becomes @Con x -> method x@; it is expected to have exactly
--   one field;
-- * a non-@forall@ constructor with any other result becomes
--   @Con{} -> Dict@;
-- * a @forall@-quantified constructor whose result is @T b@ recurses into
--   its first field of shape @A b@ (the same @b@) whose head @A@ has an
--   'ArgDict' instance. All other fields are matched with wildcards. If no
--   field qualifies, it becomes @Con{} -> Dict@;
-- * any other @forall@-quantified constructor becomes @Con{} -> Dict@.
--
-- Calls 'error' (failing the splice) on constructor forms it does not handle
-- and when the name does not reify to a 'DataD'.
matches :: Name -> Name -> Q [Match]
matches n argDictName = do
  x <- newName "x"
  let dictMatch name = Match (RecP name []) (NormalB $ ConE 'Dict) []
      recurseMatch name pats =
        Match (ConP name pats) (NormalB $ AppE (VarE argDictName) (VarE x)) []
  reify n >>= \case
    TyConI (DataD _ _ _ _ cons _) -> fmap concat $ forM cons $ \case
      GadtC [name] _ (AppT (ConT _) (VarT _)) ->
        return [recurseMatch name [VarP x]]
      GadtC [name] _ _ ->
        return [dictMatch name]
      ForallC _ _ (GadtC [name] bts (AppT (ConT _) (VarT b))) -> do
        -- Mark each field that is of the form @A b@ (same index variable as
        -- the result) with an ArgDict instance for @A@.
        recursable <- forM bts $ \case
          (_, AppT (ConT a) (VarT b')) | b == b' ->
            not . null <$> reifyInstances ''ArgDict [ConT a]
          _ -> return False
        return $ case break id recursable of
          (_, []) -> [dictMatch name]
          (before, _ : after) ->
            -- Bind only the first qualifying field; ignore all the others.
            let wild = map (const WildP)
            in [recurseMatch name (wild before ++ [VarP x] ++ wild after)]
      ForallC _ _ (GadtC [name] _ _) ->
        return [dictMatch name]
      a -> error $ "deriveArgDict matches: Unmatched 'Con': " <> show a
    a -> error $ "deriveArgDict matches: Unmatched 'Info': " <> show a
-- | Collect the index types that the derived constraint families must cover
-- for the GADT named by the argument.
--
-- Each constructor of the reified 'DataD' contributes as follows:
--
-- * If its result type is @T v@ for a type variable @v@ and it is not
--   @forall@-quantified, it contributes nothing.
-- * If it is @forall@-quantified with result @T b@, it contributes @Left A@
--   for every field of shape @A b@ (the same @b@) whose head @A@ has an
--   'ArgDict' instance. These are the nested GADTs that 'matches' may
--   recurse into.
-- * If its result type's last argument is some other type @i@, it
--   contributes @Right i@.
-- * Any other constructor form contributes nothing.
--
-- Calls 'error' (failing the splice) if the name does not reify to a
-- 'DataD'.
gadtIndices :: Name -> Q [Either Type Type]
gadtIndices n =
  reify n >>= \case
    TyConI (DataD _ _ _ _ cons _) -> concat <$> mapM conIndices cons
    a -> error $ "gadtIndices: Unmatched 'Info': " <> show a
  where
    conIndices = \case
      GadtC _ _ (AppT (ConT _) (VarT _)) -> return []
      GadtC _ _ (AppT _ typ) -> return [Right typ]
      ForallC _ _ (GadtC _ bts (AppT (ConT _) (VarT b))) ->
        concat <$> mapM (fieldIndices b . snd) bts
      ForallC _ _ (GadtC _ _ (AppT _ typ)) -> return [Right typ]
      _ -> return []
    -- Only fields indexed by the constructor's own result variable matter,
    -- mirroring the @b == b'@ check in 'matches'. Fields indexed by some
    -- other variable would only add constraints that are never used.
    fieldIndices b = \case
      AppT (ConT a) (VarT b') | b == b' -> do
        hasArgDictInstance <- not . null <$> reifyInstances ''ArgDict [ConT a]
        return [Left (ConT a) | hasArgDictInstance]
      _ -> return []