{-# LANGUAGE LambdaCase #-} {-# LANGUAGE QuasiQuotes #-} {-# LANGUAGE TemplateHaskell #-} module Data.Constraint.Extras.TH (deriveArgDict, deriveArgDictV, gadtIndices) where import Data.Constraint.Extras import Control.Monad import Data.Constraint import Data.Maybe import Language.Haskell.TH import Data.Either deriveArgDict :: Name -> Q [Dec] deriveArgDict n = do ts <- gadtIndices n c <- newName "c" g <- newName "g" let xs = flip map ts $ \case Left t -> AppT (AppT (ConT ''ConstraintsFor) t) (VarT c) Right t -> (AppT (VarT c) t) xs' = flip map ts $ \case Left t -> AppT (AppT (AppT (ConT ''ConstraintsFor') t) (VarT c)) (VarT g) Right t -> AppT (VarT c) (AppT (VarT g) t) l = length xs constraints = foldl AppT (TupleT l) xs constraints' = foldl AppT (TupleT l) xs' {- runIO $ putStrLn "Constraints:" runIO . putStrLn . pprint $ constraints' -} [d| instance ArgDict $(pure $ ConT n) where type ConstraintsFor $(conT n) $(varT c) = $(pure constraints) type ConstraintsFor' $(conT n) $(varT c) $(varT g) = $(pure constraints') argDict = $(LamCaseE <$> matches n 'argDict) argDict' = $(LamCaseE <$> matches n 'argDict') |] deriveArgDictV :: Name -> Q [Dec] deriveArgDictV n = do vs <- gadtIndices n c <- newName "c" g <- newName "g" let xs = flip map vs $ \case Left t -> AppT (AppT (AppT (ConT ''ConstraintsForV) t) (VarT c)) (VarT g) Right v -> AppT (VarT c) $ AppT v (VarT g) l = length xs constraints = foldl AppT (TupleT l) xs {- runIO $ putStrLn "Constraints:" runIO . putStrLn . pprint $ constraints' -} ds <- deriveArgDict n d <- [d| instance ArgDictV $(pure $ ConT n) where type ConstraintsForV $(conT n) $(varT c) $(varT g) = $(pure constraints) argDictV = $(LamCaseE <$> matches n 'argDictV) |] return (d ++ ds) matches :: Name -> Name -> Q [Match] matches n argDictName = do x <- newName "x" reify n >>= \case TyConI (DataD _ _ _ _ cons _) -> fmap concat $ forM cons $ \case GadtC [name] _ (AppT (ConT _) (VarT _)) -> return $ [Match (ConP name [VarP x]) (NormalB $ AppE (VarE argDictName) (VarE x)) []] GadtC [name] _ _ -> return $ [Match (RecP name []) (NormalB $ ConE 'Dict) []] ForallC _ _ (GadtC [name] bts (AppT (ConT _) (VarT b))) -> do ps <- forM bts $ \case (_, AppT (ConT a) (VarT b')) | b == b' -> do hasArgDictInstance <- not . null <$> reifyInstances ''ArgDict [(ConT a)] return $ if hasArgDictInstance then Just x else Nothing _ -> return Nothing return $ case catMaybes ps of [] -> [Match (RecP name []) (NormalB $ ConE 'Dict) []] (v:_) -> let patf = \v' rest done -> if done then WildP : rest done else case v' of Nothing -> WildP : rest done Just _ -> VarP v : rest True pat = foldr patf (const []) ps False in [Match (ConP name pat) (NormalB $ AppE (VarE argDictName) (VarE v)) []] ForallC _ _ (GadtC [name] _ _) -> return $ [Match (RecP name []) (NormalB $ ConE 'Dict) []] a -> error $ "deriveArgDict matches: Unmatched 'Dec': " <> show a a -> error $ "deriveArgDict matches: Unmatched 'Info': " <> show a gadtIndices :: Name -> Q [Either Type Type] gadtIndices n = do reify n >>= \case TyConI (DataD _ _ _ _ cons _) -> fmap concat $ forM cons $ \x -> case x of GadtC _ _ (AppT (ConT _) (VarT _)) -> return [] GadtC _ _ (AppT _ typ) -> return [Right typ] ForallC _ _ (GadtC _ bts (AppT (ConT _) (VarT _))) -> fmap concat $ forM bts $ \case (_, AppT (ConT a) (VarT _)) -> do hasArgDictInstance <- fmap (not . null) $ reifyInstances ''ArgDict [(ConT a)] return $ if hasArgDictInstance then [Left (ConT a)] else [] _ -> return [] ForallC _ _ (GadtC _ _ (AppT _ typ)) -> return [Right typ] _ -> return [] a -> error $ "gadtResults: Unmatched 'Info': " <> show a