{-# language ScopedTypeVariables #-}
{-# language RecordWildCards #-}
{-# language TupleSections #-}
{-# language FlexibleInstances #-}
{-# language MultiParamTypeClasses #-}
{-# language FlexibleContexts #-}
{-# language GADTs #-}
{-# language LambdaCase #-}
{-# language ViewPatterns #-}
module Language.Haskell.TH.Instances.Internal
(module Language.Haskell.TH.Instances.Internal
,module X) where
import Language.Haskell.TH as X
import Language.Haskell.TH.Instances.Internal.Utils as X
import Language.Haskell.TH.Instances.Defaults as X
import Language.Haskell.TH.Syntax as X hiding (lift)
import Language.Haskell.Meta.Parse as X (parseDecs)
import Language.Haskell.TH.Quote as X (QuasiQuoter(..))
import Data.Map as X (Map)
import qualified Data.Map as M
import Data.Maybe as X (mapMaybe)
import Control.Monad.Writer as X
import Data.Foldable as X
-- | Quasiquoter taking the body of an instance declaration (without the
-- leading @instance@ keyword). The text is parsed with 'parseDecs' after
-- prepending @"instance "@, and every parsed declaration is expanded with
-- 'splitInstances'; the results are concatenated.
--
-- Only declaration position is supported: using it as an expression, pattern
-- or type calls 'error'. A parse failure is also reported with 'error'.
instances :: QuasiQuoter
instances = QuasiQuoter
  { quoteExp = notInDecContext "Exp"
  , quotePat = notInDecContext "Pat"
  , quoteType = notInDecContext "Type"
  , quoteDec = either error (fmap concat . traverse splitInstances) . parseDecs . ("instance " ++)
  }
  where
    -- Ignores the quoted text and reports which context the quoter was used in.
    notInDecContext ctxName _ =
      error $ "quasiquoter `instances` expected Dec, instead used as " ++ ctxName
-- | Split one parsed instance declaration into one instance declaration per
-- class (the head class or one of its transitive superclasses) that does not
-- yet have an instance for the target type.
--
-- * Every user-written method must belong to some class in the hierarchy;
--   otherwise this fails with 'error', listing the offending names.
-- * Methods the user did not write are filled in from the defaults returned
--   by 'transitiveProvidedDefaults' (restricted to methods of the hierarchy).
-- * Classes for which 'isInstance' already reports an instance are skipped.
-- * Every generated instance reuses the original instance context.
--
-- Any declaration that is not an instance of the expected shape is rejected
-- with 'error'.
splitInstances :: Dec -> DecsQ
splitInstances = \case
  InstanceD Nothing ctx (AppT (ConT className) instancesFor) declaredMethods -> do
    -- User-written bindings keyed by the in-scope name each one resolves to.
    declaredMethods' <- M.fromList <$> traverse globalizeDef declaredMethods
    superclasses <- getTransitiveSuperclassNames className
    -- Union of the method names of every class in the hierarchy.
    requiredMethods <- fold <$> M.traverseWithKey (\k _ -> getClassMethods k) superclasses
    let badMethods = filter (\x -> not $ M.member x requiredMethods) $ M.keys declaredMethods'
    unless (null badMethods) $
      error $ "splitInstances: Trying to declare methods not in the superclass heirarchy\n"
        ++ unlines (map show badMethods)
    defaultMethods <- (`M.intersection` requiredMethods) <$> transitiveProvidedDefaults className
    -- 'M.union' is left-biased, so an explicit definition beats a default.
    let declaredMethods'' = declaredMethods' `M.union` defaultMethods
    superclassHasInstance <- M.traverseWithKey (\k _ -> isInstance k [instancesFor]) superclasses
    -- Only classes that still lack an instance get one generated.
    let missingClasses = M.keys (M.filter not superclassHasInstance)
    superclasses' <- fromKeys M.empty <$> traverse globalizeClass missingClasses
    classOps <- getClassOps (M.elems declaredMethods'') superclasses'
    -- Look each class's method names up in the name-keyed map directly rather
    -- than re-keying a Map by 'Dec' (which needed 'Ord Dec' and could merge
    -- structurally equal declarations).
    pure
      [ declInstance ctx c instancesFor (map (declaredMethods'' M.!) (M.keys ops))
      | (c, ops) <- M.toList classOps
      ]
  d -> error $ "splitInstances: Not an instance declaration\n" ++ pprint d
  where
    -- Build @instance ctx => className targetType where methods@.
    declInstance ctx className targetType methods =
      InstanceD Nothing ctx (AppT (ConT className) targetType) methods
    -- Resolve a method binding's unqualified name in the current scope.
    globalizeDef d = (lookupValueName . occName . defName) d >>= \case
      Nothing -> error $ "globalizeDef: instance method " ++ show (occName (defName d)) ++ " not in scope"
      Just n -> pure (n, d)
-- | Re-resolve a class name by looking up its unqualified base name
-- ('occName') as a type name in the current scope. Fails with 'error' when
-- nothing of that name is in scope.
--
-- NOTE(review): only the base name is used, so a class that is in scope only
-- qualified (or shadowed by another class of the same base name) will not
-- resolve to the original — confirm this is intended.
globalizeClass :: Name -> Q Name
globalizeClass c = do
  resolved <- lookupTypeName (occName c)
  maybe (error $ "globalizeClass: class " ++ show (occName c) ++ " not in scope") pure resolved
-- | Attribute each method binding to the class that owns it.
--
-- The name of every binding is reified; the owning class is read from the
-- resulting 'ClassOpI'. The @(class, method)@ pairs are folded into the given
-- class map with 'adjustMany', inserting the method into that class's set.
-- (Presumably classes absent from the map are left out — confirm against
-- 'adjustMany'.) A binding whose name is not a class operation is an 'error'.
getClassOps :: Traversable t => t Dec -> Map ParentName (Set Name) -> Q (Map ParentName (Set Name))
getClassOps decs superclasses = do
  owners <- traverse (fmap opClass . reify . defName) decs
  pure $ adjustMany (\method -> M.insert method ()) superclasses owners
  where
    opClass :: Info -> (ParentName, Name)
    opClass info = case info of
      ClassOpI method _ parent -> (parent, method)
      other -> error $ "opClass: not a class operation\n" ++ pprint other
-- | The name bound by a function ('FunD') or simple variable ('ValD' with a
-- 'VarP' pattern) definition. Any other declaration is an 'error'.
defName :: Dec -> Name
defName (FunD name _) = name
defName (ValD (VarP name) _ _) = name
defName other =
  error $ "defName: Declaration is not a Function or Value definition\n" ++ pprint other
-- | The name declared by a type signature ('SigD'). Any other declaration is
-- an 'error'.
sigName :: Dec -> Name
sigName (SigD name _) = name
sigName other =
  error $ "sigName: Declaration is not a type signature\n" ++ pprint other
-- | Direct superclasses of a class, read from the context of its reified
-- 'ClassD'.
--
-- Each context constraint is scanned right to left along its application
-- spine. If an argument is a bare type variable bound by the class, the head
-- of the remaining application (via 'headAppT') is reported. If the scan
-- reaches a constructor applied to a non-variable argument, the constraint
-- contributes nothing. Any other shape raises 'error'. Reifying something
-- that is not a class fails in 'Q'.
--
-- NOTE(review): an equality superclass constraint such as @a ~ b@ appears to
-- reach 'headAppT' with 'EqualityT' and hit its error path — confirm.
getSuperclassNames :: Name -> Q [Name]
getSuperclassNames className = do
  ClassI (ClassD superCxt _ binders _ _) _ <- reify className
  let classVars = fromKeys () (map tvName binders)
      classOf :: Type -> Maybe Name
      classOf ty = case ty of
        AppT f (VarT v) | M.member v classVars -> Just (headAppT f)
        AppT (ConT _) _ -> Nothing
        AppT f _ -> classOf f
        other -> error (show other)
  return (mapMaybe classOf superCxt)
  where
    tvName (PlainTV n) = n
    tvName (KindedTV n _) = n
-- | The type constructor at the head of a (possibly applied) type, e.g.
-- @Either@ for @Either a b@. Any head other than 'ConT' is an 'error'.
headAppT :: Type -> Name
headAppT ty = case ty of
  AppT f _ -> headAppT f
  ConT n -> n
  other -> error $ "headAppT: Malformed type\n" ++ show other
-- | The names of the methods declared by a class.
--
-- The reified 'ClassD' body holds a 'SigD' for every method. It can also hold
-- associated type family declarations and 'DefaultSigD' entries (from
-- DefaultSignatures), so only the 'SigD' entries are collected. Mapping
-- 'sigName' over the whole body would crash on such classes. Reifying a
-- non-class name is reported with 'error'.
getClassMethods :: Name -> Q (Set Name)
getClassMethods className = reify className >>= \case
  ClassI (ClassD _ _ _ _ body) _ ->
    pure $ fromKeys () [method | SigD method _ <- body]
  info -> error $ "getClassMethods: not a class\n" ++ pprint info
-- | Every class in the transitive superclass hierarchy of a class, the class
-- itself included, as the keys of a map. The values are always empty.
getTransitiveSuperclassNames :: Name -> Q (Map Name (Set a))
getTransitiveSuperclassNames cls = do
  supers <- getSuperclassNames cls
  below <- traverse getTransitiveSuperclassNames supers
  pure $ M.unions (M.singleton cls M.empty : below)
-- | The unqualified base name of a 'Name', dropping any module qualifier or
-- other flavour information. This is exactly 'nameBase'.
occName :: Name -> String
occName = nameBase
-- | Every class in the transitive superclass hierarchy of a class, paired with
-- its depth below that class. The class itself has depth 0.
--
-- The depth is the length of the /longest/ superclass chain from the root.
-- Whenever one class is a (transitive) superclass of another, it therefore
-- gets a strictly larger depth, and callers that prefer the smallest depth
-- consistently favour the more specific class.
--
-- The previous approach kept the depth of whichever path a depth-first walk
-- reached first (a left-biased map union). That made the result depend on the
-- order in which superclass constraints were written.
getTransitiveSuperclassNames' :: Name -> Q (Map Name Int)
getTransitiveSuperclassNames' = go 0
  where
    go depth cls = do
      supers <- getSuperclassNames cls
      below <- traverse (go (depth + 1)) supers
      -- A class reachable along several paths keeps its greatest depth.
      pure $ M.unionsWith max (M.singleton cls depth : below)
-- | Default method bindings for a class, gathered from the annotations read
-- by 'providedDefaults' on the class and each of its transitive superclasses.
--
-- When several classes provide a default for the same method, the one with
-- the smallest depth (as computed by @getTransitiveSuperclassNames'@) wins.
-- On a tie, the class that comes first in 'Name' order is kept. Each result
-- is a value binding @method = definition@.
transitiveProvidedDefaults :: Name -> Q (Map Name Dec)
transitiveProvidedDefaults root = do
  depths <- getTransitiveSuperclassNames' root
  tagged <- traverse withDepth (M.toList depths)
  let chosen = M.map fst (M.unionsWith nearer tagged)
  pure (M.mapWithKey bindDefault chosen)
  where
    -- Tag every default a class provides with that class's depth.
    withDepth (cls, depth) = fmap (\impl -> (impl, depth)) <$> providedDefaults cls
    -- Keep the shallower provider; the earlier one wins ties.
    nearer x@(_, dx) y@(_, dy)
      | dx <= dy = x
      | otherwise = y
    bindDefault method impl = ValD (VarP method) (NormalB (VarE impl)) []
-- | Defaults registered on a declaration through 'Defaults' annotations, as a
-- map from each annotation's @defining@ name to its @definition@ name. If
-- several annotations share a @defining@ name, the last one read wins.
providedDefaults :: Name -> Q (Map Name Name)
providedDefaults className =
  M.fromList . toPairs <$> reifyAnnotations (AnnLookupName className)
  where
    toPairs anns = [(defining, definition) | Defaults{..} <- anns]