{-# LANGUAGE DeriveDataTypeable #-}
{-# LANGUAGE LambdaCase         #-}
{-# LANGUAGE OverloadedStrings  #-}
module Data.Constraint.Deriving.ClassDict
  ( ClassDict (..)
  , classDictPass
  , CorePluginEnvRef, initCorePluginEnv
  ) where

import           Control.Monad (join, unless, when)
import           Data.Data     (Data)
import           Data.Maybe    (fromMaybe, isJust)
import           GhcPlugins    hiding (OverlapMode (..), overlapMode, mkFunTy)
import qualified Unify

import Data.Constraint.Deriving.CorePluginM

{- | A marker to tell the core plugin to replace the implementation of a
     top-level function by a corresponding class data constructor
     (wrapped into `Data.Constraint.Dict`).

     Example:

@
class BarClass a => FooClass a where
  fooFun1 :: a -> a -> Int
  fooFun2 :: a -> Bool

{\-\# ANN deriveFooClass ClassDict \#-\}
deriveFooClass :: forall a . BarClass a
               => (a -> a -> Int)
               -> (a -> Bool)
               -> Dict (FooClass a)
deriveFooClass = deriveFooClass
@

     That is, the plugin replaces the RHS of @deriveFooClass@ function with
     `DataCon.classDataCon` wrapped by `bareToDict`.

     Note:

     * The plugin requires you to create a dummy function `deriveFooClass`
       and annotate it with `ClassDict` instead of automatically creating
       this function for you; this way, the function is visible during
       type checking: you can use it in the same module (avoiding orphans)
       and you see its type signature.

     * You have to provide a correct signature for `deriveFooClass` function;
       the plugin compares this signature against visible classes and their
       constructors. An incorrect signature will result in a compile-time
       error.

     * The dummy implementation @deriveFooClass = deriveFooClass@ is used here
       to prevent GHC from inlining the function before the plugin can
       replace it. But you can implement it in any way you like at your own
       risk.
 -}
data ClassDict = ClassDict
  deriving (Eq, Show, Read, Data)

-- | Run `ClassDict` plugin pass.
--
--   The pass is registered under the name
--   @"Data.Constraint.Deriving.ClassDict"@.
classDictPass :: CorePluginEnvRef -> CoreToDo
classDictPass eref = CoreDoPluginPass "Data.Constraint.Deriving.ClassDict"
  -- if a plugin pass totally fails to do anything useful,
  -- copy original ModGuts as its output, so that next passes can do their jobs.
  (\x -> fromMaybe x <$> runCorePluginM (classDictPass' x) eref)

-- | The body of the pass: walk over all top-level bindings of the module,
--   rewrite the RHS of every binding that carries a `ClassDict` annotation,
--   and warn about annotations that were never matched against a binding.
classDictPass' :: ModGuts -> CorePluginM ModGuts
classDictPass' guts = do
    (remAnns, processedBinds)
      <- runWithAnns (traverse go (mg_binds guts)) annotateds
    -- whatever is left in the map was not consumed by any binding
    unless (isNullUFM remAnns) $
      pluginWarning $ "One or more ClassDict annotations are ignored:"
        $+$ vcat
          (map pprBulletNameLoc . join $ eltsUFM remAnns)
        $$ "Note possible issues:"
        $$ pprNotes
         [ "ClassDict is meant to be used only on bindings of type Ctx => Dict (Class t1 .. tn)."
         , "GHC may remove the annotated definition if it is not reachable from module exports."
         ]
    return guts { mg_binds = processedBinds}
  where
    -- names of all bindings annotated with ClassDict in this module
    annotateds :: UniqFM [Name]
    annotateds = map fst <$> (getModuleAnns guts :: UniqFM [(Name, ClassDict)])

    -- process every binder of a (possibly recursive) binding group
    go :: CoreBind -> WithAnns CoreBind
    go (NonRec b e) = NonRec b <$> classDict' b e
    go (Rec bes)    = Rec <$> traverse (\(b, e) -> (,) b <$> classDict' b e) bes

    pprBulletNameLoc n = hsep
      [" " , bullet, ppr $ occName n, ppr $ nameSrcSpan n]
    pprNotes = vcat . map (\x -> hsep [" ", bullet, x])

    -- Look up the binder among the remaining annotations.
    -- If it is annotated, drop it from the map and try to replace its RHS;
    -- if the replacement fails, the original RHS is kept.
    classDict' x origBind = WithAnns $ \anns -> case lookupUFM anns x of
      Just (xn:xns) -> do
        unless (null xns) $
          pluginLocatedWarning (nameSrcSpan xn) $
            "Ignoring redundant ClassDict annotations" $$
            hcat
            [ "(the plugin needs only one annotation per binding, but got "
            , speakN (length xns + 1)
            , ")"
            ]
        -- add new definitions and continue
        (,) (delFromUFM anns x) . fromMaybe origBind <$> try (classDict x)
      _ -> return (anns, origBind)

-- | A small state transformer for tracking remaining annotations.
--
--   The state is the map of annotations that have not been consumed yet;
--   it is threaded through `CorePluginM` actions from left to right.
newtype WithAnns a = WithAnns
  { runWithAnns :: UniqFM [Name] -> CorePluginM (UniqFM [Name], a) }

instance Functor WithAnns where
  -- the inner fmap maps over the second component of the (state, result) pair
  fmap f m = WithAnns $ \anns -> fmap f <$> runWithAnns m anns

instance Applicative WithAnns where
  pure x = WithAnns $ \anns -> pure (anns, x)
  mf <*> mx = WithAnns $ \anns0 -> do
    (anns1, f) <- runWithAnns mf anns0
    (anns2, x) <- runWithAnns mx anns1
    pure (anns2, f x)

-- | Replace a given CoreBind with a corresponding class DataCon fun
--   implementation.
--
--   The core bind must have type `Ctx => Dict (Class t1 .. tn)`;
--   it does not change.
--
--   Reports a located plugin error if the result type is not an application
--   of @Dict@ to a single argument, if that argument is not a class
--   constraint, or if the binder's type cannot match the type of the class
--   data constructor.
classDict :: CoreBndr -> CorePluginM CoreExpr
classDict bindVar = do
    -- get necessary definitions
    tcDict <- ask tyConDict
    let conDict = tyConSingleDataCon tcDict

    -- check that the outermost constructor of the result type is Dict
    -- and unwrap it.
    dictContentTy <- case splitTyConApp_maybe dictTy of
      Just (tcDict', [resTy])
        | tcDict' == tcDict -> pure resTy
      -- reached both when the split fails and when the guard above fails
      err -> pluginLocatedError loc $ vcat
        [ hsep [ "Expected `Dict (Cls t1..tn)', but found", ppr dictTy]
        , if isJust err
          then "(constructor or number of arguments do not match)."
          else "(I could not split apart a constructor application)."
        , notGoodMsg
        ]

    -- check if the content of the result Dict is indeed a class constraint
    -- and get the class and its arguments.
    (klass, instanceArgs) <- case splitTyConApp_maybe dictContentTy of
      Just (klassTyCon, iArgs)
        | Just klas <- tyConClass_maybe klassTyCon
          -> pure (klas, iArgs)
        | otherwise
          -> pluginLocatedError loc $ vcat
            [ hsep [ "Expected a class constructor, but found", ppr klassTyCon]
            , "(not a class data constructor)."
            , notGoodMsg
            ]
      Nothing
          -> pluginLocatedError loc $ vcat
            [ hsep [ "Expected a class constructor, but found", ppr dictContentTy]
            , "(I could not split apart a constructor application)."
            , notGoodMsg
            ]

    -- the core of the plugin: use a class data constructor
    let klassDataCon = classDataCon klass

    -- check if the types agree:
    -- the expected type is the type of the class data constructor worker
    -- with its result wrapped into Dict
    let expectedType = mapResultType (mkTyConApp tcDict . (:[]))
                     . idType $ dataConWorkId klassDataCon
    when (Unify.typesCantMatch [(origBindTy, expectedType)]) $
      pluginLocatedError loc $ vcat
        [ hsep [ "Cannot match the expected type (the type of the data constructor of the given class)"
               , "and the found type (the user-supplied binding)."]
        , hsep ["Expected type:", ppr expectedType]
        , hsep ["Found type: ", ppr origBindTy]
        ]

    -- one fresh local variable per value argument of the binder
    argVars <- traverse (flip newLocalVar "t") argTys
    return
      . mkCoreLams (bndrs ++ argVars)
      $ mkCoreConApps conDict
        [ mkTyArg dictContentTy
        , klassDataCon `mkCoreConApps`
            (map mkTyArg instanceArgs ++ varsToCoreExprs argVars)
        ]
  where
    origBindTy = idType bindVar
    -- forall-bound variables, then function arguments, then the result type
    (bndrs, bindTy) = splitForAllTys origBindTy
    (argTys, dictTy) = splitFunTys bindTy
    loc = nameSrcSpan $ getName bindVar
    notGoodMsg =
         "ClassDict plugin pass failed to process a Dict declaration."
      $$ "The declaration must have form `forall a1..an . Ctx => Dict (Cls t1..tn)'"
      $$ "Declaration:"
      $$ hcat
         [ " "
         , ppr bindVar, " :: "
         , ppr origBindTy
         ]

-- | Transform the result type in a more complex fun type.
--
--   Descends under leading foralls and function arrows and applies the
--   given function to the innermost result type.
mapResultType :: (Type -> Type) -> Type -> Type
mapResultType f t
  | (bndrs@(_:_), t') <- splitForAllTys t
    = mkSpecForAllTys bndrs $ mapResultType f t'
  | Just (vis, at, rt) <- splitFunTyArg_maybe t
    = mkFunTy vis at (mapResultType f rt)
  | otherwise
    = f t