{-# language ScopedTypeVariables #-}
{-# language RecordWildCards #-}
{-# language TupleSections #-}
{-# language FlexibleInstances #-}
{-# language MultiParamTypeClasses #-}
{-# language FlexibleContexts #-}
{-# language GADTs #-}
{-# language LambdaCase #-}
{-# language ViewPatterns #-}
module Language.Haskell.TH.Instances.Internal
  (module Language.Haskell.TH.Instances.Internal
  ,module X) where

import Language.Haskell.TH as X
import Language.Haskell.TH.Instances.Internal.Utils as X
import Language.Haskell.TH.Instances.Defaults as X
import Language.Haskell.TH.Syntax as X hiding (lift)
import Language.Haskell.Meta.Parse as X (parseDecs)
import Language.Haskell.TH.Quote as X (QuasiQuoter(..))
import Data.Map as X (Map)
import qualified Data.Map as M
import Data.Maybe as X (mapMaybe)
import Control.Monad.Writer as X
import Data.Foldable as X

-- | @QuasiQuoter@ for providing an instance of a class together with
-- instances of all of its superclasses from a single declaration.
--
-- Example:
--
-- > {-# language TemplateHaskell,QuasiQuotes,FlexibleInstances,UndecidableInstances #-}
-- > import Prelude hiding (Monoid(..))
-- > import Language.Haskell.TH.Instances
-- >
-- > class Semigroup a where mappend :: a -> a -> a
-- > class Semigroup a => Commutative a
-- > class Semigroup a => Monoid a where mempty :: a
-- > class Monoid a => Group a where inverse :: a -> a
-- > class (Commutative a, Group a) => CommutativeGroup a
-- > $(return []) -- Only needed if classes are defined in the same module, to make sure they're in scope below
-- > [instances| Num a => CommutativeGroup a where
-- >   mempty = fromInteger 0
-- >   mappend a b = a + b
-- >   inverse = negate
-- >   |]
--
-- will generate the appropriate instances for @Semigroup@, @Monoid@, and @Group@:
--
-- > instance Num a => Semigroup a where mappend a b = a + b
-- > instance Num a => Commutative a
-- > instance Num a => Monoid a where mempty = fromInteger 0
-- > instance Num a => Group a where inverse = negate
-- > instance Num a => CommutativeGroup a
--
-- The quoter is only valid in declaration position; using it as an
-- expression, pattern or type fails the splice.
instances :: QuasiQuoter
instances = QuasiQuoter
  { quoteExp  = err "Exp"
  , quotePat  = err "Pat"
  , quoteType = err "Type"
    -- The quoted text is everything after the @instance@ keyword, so the
    -- keyword is prepended before parsing.
  , quoteDec  = \s -> case parseDecs ("instance " ++ s) of
      Left e  -> fail e
      Right d -> concat <$> mapM splitInstances d
  }
  where
    -- Reported through 'fail' so that GHC shows a located compile error.
    err s = const $ fail $ "quasiquoter `instances` expected Dec, instead used as " ++ s

-- | Implements the @instances@ quasiquoter ast transform.
--
-- Given a single instance declaration for a class, produces one instance
-- declaration for that class and for each of its transitive superclasses
-- that does not already have an instance for the target type, distributing
-- the declared (and annotated default) methods to the class that owns them.
--
-- Fails the splice if the declaration is not a plain instance declaration,
-- or if a declared method belongs to no class in the superclass hierarchy.
splitInstances :: Dec -> DecsQ
splitInstances = \case
  InstanceD Nothing ctx (AppT (ConT className) instancesFor) declaredMethods -> do
    -- Key every user-written definition by its globally resolved name.
    declaredMethods' <- M.fromList <$> traverse globalizeDef declaredMethods
    superclasses <- getTransitiveSuperclassNames className
    -- Methods that will eventually need to be filled in
    requiredMethods <- fold <$> M.traverseWithKey (\k _ -> getClassMethods k) superclasses
    let badMethods = filter (\x -> not $ M.member x requiredMethods) $ M.keys declaredMethods'
    unless (null badMethods) $
      fail $ "splitInstances: Trying to declare methods not in the superclass hierarchy\n"
          ++ unlines (map show badMethods)
    defaultMethods <- (`M.intersection` requiredMethods) <$> transitiveProvidedDefaults className
    {-defaultMethods <- M.traverseMaybeWithKey (\k _ -> getDefault k) requiredMethods-}
    -- 'M.union' is left-biased: explicitly declared methods win over defaults.
    let declaredMethods'' = declaredMethods' `M.union` defaultMethods
    superclassHasInstance <- M.traverseWithKey (\k _ -> isInstance k [instancesFor]) superclasses
    -- Only classes without an existing instance for the target type get one.
    superclasses' <- fmap (fromKeys M.empty)
                   $ traverse globalizeClass
                   $ filter (\k -> not $ superclassHasInstance M.! k)
                   $ M.keys superclasses
    classOps <- getClassOps (M.elems declaredMethods'') superclasses'
    -- Swap each method name for its definition.
    let classDefs = M.map (\names -> (declaredMethods'' M.!) `M.mapKeys` names) classOps
    let instanceDecls = M.foldrWithKey (\c ms -> (declInstance ctx c instancesFor ms :)) [] classDefs
    pure instanceDecls
  d -> fail $ "splitInstances: Not an instance declaration\n" ++ pprint d
  where
    -- Build one instance declaration whose body is the keys of @ms@.
    declInstance ctx className targetType ms =
      InstanceD Nothing ctx (AppT (ConT className) targetType) (M.keys ms)
    -- Associate a definition with its toplevel qualified identifier
    globalizeDef d = (lookupValueName . occName . defName) d >>= \case
      Nothing -> fail $ "globalizeDef: instance method " ++ show (occName (defName d)) ++ " not in scope"
      Just n  -> pure (n,d)

-- | Get the fully qualified name of a class.
--
-- Looks the unqualified part of the given name up in the type namespace and
-- fails the splice if nothing by that name is in scope.
globalizeClass :: Name -> Q Name
globalizeClass c = (lookupTypeName . occName) c >>= \case
  Nothing -> fail $ "globalizeClass: class " ++ show (occName c) ++ " not in scope"
  Just n  -> pure n

-- | Create a Map of className to method declaration from a list of instance method definitions
--
-- Each definition's name is reified to find the class it is an operation of;
-- the resulting (class, method) pairs are merged into the supplied map via
-- @adjustMany@. Calls 'error' if a definition does not reify to a class
-- operation.
getClassOps :: Traversable t => t Dec -> Map ParentName (Set Name) -> Q (Map ParentName (Set Name))
getClassOps decs superclasses =
  adjustMany (`M.insert` ()) superclasses <$> mapM (\d -> opClass <$> reify (defName d)) decs
  where
    -- Project (owning class, method name) out of a reified class operation.
    opClass :: Info -> (ParentName, Name)
    opClass (ClassOpI n _t p) = (p,n)
    opClass x = error $ "opClass: not a class operation\n" ++ pprint x

-- | Get the name of a function or value declaration
--
-- Calls 'error' for any declaration other than a 'FunD' or a 'ValD' binding
-- a plain variable pattern.
defName :: Dec -> Name
defName x = case x of
  FunD n _          -> n
  ValD (VarP n) _ _ -> n
  d -> error $ "defName: Declaration is not a Function or Value definition\n" ++ pprint d

-- | Get the name bound by a type signature declaration.
--
-- Calls 'error' for any declaration other than a 'SigD'.
sigName :: Dec -> Name
sigName = \case
  SigD n _ -> n
  d -> error $ "sigName: Declaration is not a type signature\n" ++ pprint d

-- | reify the names of the direct superclasses for a class name
--
-- A constraint in the class context counts as a superclass when it is
-- applied to one of the class's own type variables.
getSuperclassNames :: Name -> Q [Name]
getSuperclassNames className = do
  ClassI (ClassD ctx _ (fromKeys () . map _TyVarBndr_name -> classVars) _ _) _ <- reify className
  let -- if t represents a superclass of n then `superclass t` is Just the
      -- superclass name, and Nothing otherwise
      superclass :: Type -> Maybe Name
      superclass = \case
        AppT t (VarT v) | M.member v classVars -> Just $ headAppT t
        AppT ConT{} _ -> Nothing
        AppT t _ -> superclass t
        x -> error $ "getSuperclassNames: unexpected constraint in class context\n" ++ show x
  pure $ mapMaybe superclass ctx
  where
    _TyVarBndr_name = \case {PlainTV n -> n; KindedTV n _ -> n}

-- | Project the innermost @ConT@ in a chain of @AppT@.
--
-- Calls 'error' if the head of the application chain is not a 'ConT'.
headAppT :: Type -> Name
headAppT = \case
  ConT n   -> n
  AppT t _ -> headAppT t
  x -> error $ "headAppT: Malformed type\n" ++ show x

-- | reify the names of the methods declared by a class.
--
-- Only type signatures ('SigD') in the reified class body are considered;
-- any other declaration found there is skipped rather than treated as an
-- error.
getClassMethods :: Name -> Q (Set Name)
getClassMethods className = do
  ClassI (ClassD _ _ _ _ decs) _ <- reify className
  pure $ fromKeys () [n | SigD n _ <- decs]

-- | reify the names of all transitive superclasses for a class name, including itself
--
-- Every class is mapped to an empty value; only the keys carry information.
getTransitiveSuperclassNames :: Name -> Q (Map Name (Set a))
getTransitiveSuperclassNames = execWriterT . go
  where
    -- Depth-first walk, recording each visited class in the writer.
    go n = do
      tell $ M.singleton n M.empty
      traverse_ go =<< lift (getSuperclassNames n)

-- | Extract the unqualified part from a @Name@. For example:
--
-- > show ''Show === "GHC.Show.Show"
-- > occName ''Show === "Show"
occName :: Name -> String
occName (Name (OccName s) _) = s

-- | reify the names of all transitive superclasses for a class name,
-- including itself, each paired with its distance from the starting class
-- (0 for the class itself, 1 for its direct superclasses, and so on).
--
-- The results are combined with the left-biased 'Map' monoid, so a class
-- reachable along several paths keeps the depth of the path visited first.
getTransitiveSuperclassNames' :: Name -> Q (Map Name Int)
getTransitiveSuperclassNames' = execWriterT . go 0
  where
    go i n = do
      tell $ M.singleton n i
      traverse_ (go (i+1)) =<< lift (getSuperclassNames n)

-- | Collect the default method implementations provided by a class and all
-- of its transitive superclasses, as value declarations binding each method
-- to its default.
--
-- NOTE(review): precedence relies on the depths reported by
-- 'getTransitiveSuperclassNames'', which are first-visit depths rather than
-- shortest or longest paths -- confirm this is the intended tie-breaking for
-- diamond-shaped hierarchies.
transitiveProvidedDefaults :: Name -> Q (Map Name Dec)
transitiveProvidedDefaults n = do
  sc <- M.toList <$> getTransitiveSuperclassNames' n
  -- Defaults from deeper subclasses take precedence: when two classes
  -- provide a default for the same method, the one at the smaller depth wins.
  defaultsMap <- M.map fst . M.unionsWith lowest
             <$> mapM (\(n',i) -> M.map (,i) <$> providedDefaults n') sc
  return $ M.mapWithKey name_dec defaultsMap
  where
    -- @method = defaultImplementation@
    name_dec n' def = ValD (VarP n') (NormalB (VarE def)) []
    lowest a@(_,i) b@(_,i') = if i <= i' then a else b

-- | Get the default superclass method implementations provided by a subclass
--
-- Reads the @Defaults@ annotations attached to the given name and returns a
-- map from each @defining@ field to the corresponding @definition@ field.
providedDefaults :: Name -> Q (Map Name Name)
providedDefaults (AnnLookupName -> n) = do
  defaults <- reifyAnnotations n
  return $ M.fromList [(defining,definition) | Defaults{..} <- defaults]