{-# LANGUAGE DeriveDataTypeable #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE OverloadedStrings #-}
module Data.Constraint.Deriving.ClassDict
( ClassDict (..)
, classDictPass
, CorePluginEnvRef, initCorePluginEnv
) where
import Control.Monad (join, unless, when)
import Data.Data (Data)
import Data.Maybe (fromMaybe, isJust)
import GhcPlugins hiding (OverlapMode (..), overlapMode, mkFunTy)
import qualified Unify
import Data.Constraint.Deriving.CorePluginM
-- | Marker annotation consumed by 'classDictPass'.
--
-- It carries no payload.  When a binding carries this annotation, the pass
-- tries to replace the binding's right-hand side with a constructed class
-- dictionary.  The binding is expected to have a type of the shape
-- @Ctx => Dict (Class t1 .. tn)@.
data ClassDict
  = ClassDict
  deriving (Data, Eq, Read, Show)
-- | The Core-to-Core pass that rewrites bindings annotated with 'ClassDict'.
--
-- The plugin action is run through 'runCorePluginM' with the shared
-- environment reference.  If it yields 'Nothing', the module is passed
-- through unchanged.
classDictPass :: CorePluginEnvRef -> CoreToDo
classDictPass eref =
    CoreDoPluginPass "Data.Constraint.Deriving.ClassDict" runPass
  where
    runPass guts = do
      outcome <- runCorePluginM (classDictPass' guts) eref
      return (fromMaybe guts outcome)
-- | Visit every top-level binding of the module and rewrite, via 'classDict',
-- those whose binder carries a 'ClassDict' annotation.
--
-- The map of not-yet-consumed annotations is threaded through the traversal.
-- Any entries still left in it afterwards are reported together in a single
-- plugin warning.  The result is the input 'ModGuts' with 'mg_binds'
-- replaced by the processed bindings.
classDictPass' :: ModGuts -> CorePluginM ModGuts
classDictPass' guts = do
    (leftoverAnns, newBinds) <-
      runWithAnns (traverse rewriteBind (mg_binds guts)) annotatedNames
    unless (isNullUFM leftoverAnns) $
      pluginWarning (ignoredAnnsWarning leftoverAnns)
    return (guts { mg_binds = newBinds })
  where
    -- Only the annotated names matter; 'ClassDict' itself carries no data.
    annotatedNames :: UniqFM [Name]
    annotatedNames =
      fmap (map fst) (getModuleAnns guts :: UniqFM [(Name, ClassDict)])

    -- Handle both non-recursive and recursive binding groups.
    rewriteBind :: CoreBind -> WithAnns CoreBind
    rewriteBind (NonRec var rhs) = NonRec var <$> rewriteRhs var rhs
    rewriteBind (Rec pairs) = Rec <$> traverse rewritePair pairs
      where
        rewritePair (var, rhs) = (,) var <$> rewriteRhs var rhs

    -- For an annotated binder:
    --   * warn if it has more than one annotation;
    --   * try to build the dictionary, keeping the old rhs if that fails;
    --   * remove the binder's entry from the map either way.
    -- A binder with no entry (or an empty one) is returned untouched.
    rewriteRhs :: CoreBndr -> CoreExpr -> WithAnns CoreExpr
    rewriteRhs var rhs = WithAnns $ \anns ->
      case lookupUFM anns var of
        Just (firstName : extraNames) -> do
          unless (null extraNames) $
            pluginLocatedWarning (nameSrcSpan firstName) $
              "Ignoring redundant ClassDict annotations" $$
              hcat
                [ "(the plugin needs only one annotation per binding, but got "
                , speakN (length extraNames + 1)
                , ")"
                ]
          newRhs <- try (classDict var)
          return (delFromUFM anns var, fromMaybe rhs newRhs)
        _ -> return (anns, rhs)

    -- Warning text listing every annotation that was never consumed.
    ignoredAnnsWarning :: UniqFM [Name] -> SDoc
    ignoredAnnsWarning anns =
      "One or more ClassDict annotations are ignored:"
        $+$ vcat (map bulletNameLoc (join (eltsUFM anns)))
        $$ "Note possible issues:"
        $$ bulletList
             [ "ClassDict is meant to be used only on bindings of type Ctx => Dict (Class t1 .. tn)."
             , "GHC may remove the annotated definition if it is not reachable from module exports."
             ]

    bulletNameLoc :: Name -> SDoc
    bulletNameLoc n =
      hsep [" ", bullet, ppr (occName n), ppr (nameSrcSpan n)]

    bulletList :: [SDoc] -> SDoc
    bulletList items = vcat [hsep [" ", bullet, item] | item <- items]
-- | A 'CorePluginM' computation that threads a map of annotation names
-- through a traversal.
--
-- In this module, the map starts as the names annotated with 'ClassDict'.
-- An entry is deleted once its binder has been handled.
newtype WithAnns a
  = WithAnns
  { runWithAnns
      :: UniqFM [Name]
      -> CorePluginM (UniqFM [Name], a)
  }
-- | Map over the produced value; the threaded annotation map is untouched.
instance Functor WithAnns where
  fmap f (WithAnns run) = WithAnns (fmap (fmap f) . run)
-- | Sequence two steps left to right.  The annotation map produced by the
-- first step is the input of the second; the second step's map is returned.
instance Applicative WithAnns where
  pure x = WithAnns (\s -> pure (s, x))
  WithAnns runF <*> WithAnns runX = WithAnns $ \s ->
    runF s >>= \(s', f) ->
      runX s' >>= \(s'', x) ->
        pure (s'', f x)
-- | Build the Core expression that replaces the body of a binder annotated
-- with 'ClassDict'.
--
-- The binder's type is split into:
--
--   * leading type variables (via 'splitForAllTys');
--   * argument types (via 'splitFunTys');
--   * a result type, which must be @Dict (Cls t1 .. tn)@ for a class @Cls@.
--
-- The whole type must not be provably incompatible with the type of
-- @Cls@'s data constructor once its result is wrapped in @Dict@.
--
-- The produced expression abstracts over the type variables and fresh
-- argument variables.  It then applies the @Dict@ constructor to the class
-- data constructor, saturated with @t1 .. tn@ and those arguments.
--
-- Every shape mismatch is reported with 'pluginLocatedError' at the
-- binder's source location.
classDict :: CoreBndr -> CorePluginM CoreExpr
classDict bindVar = do
    tcDict <- ask tyConDict
    let conDict = tyConSingleDataCon tcDict
    -- The result type must be exactly @Dict c@ for a single argument @c@.
    dictContentTy <- case splitTyConApp_maybe dictTy of
      Just (tcDict', [resTy])
        | tcDict' == tcDict -> pure resTy
      err -> pluginLocatedError loc $ vcat
        [ hsep
            [ "Expected `Dict (Cls t1..tn)', but found", ppr dictTy ]
        , if isJust err
            then "(constructor or number of arguments do not match)."
            else "(I could not split apart a constructor application)."
        , notGoodMsg
        ]
    -- The content of the @Dict@ must be a saturated class constraint.
    (klass, instanceArgs) <- case splitTyConApp_maybe dictContentTy of
      Just (klassTyCon, iArgs)
        | Just klas <- tyConClass_maybe klassTyCon
          -> pure (klas, iArgs)
        | otherwise
          -> pluginLocatedError loc $ vcat
            [ hsep
                [ "Expected a class constructor, but found", ppr klassTyCon ]
            , "(not a class data constructor)."
            , notGoodMsg
            ]
      Nothing -> pluginLocatedError loc $ vcat
        [ hsep
            [ "Expected a class constructor, but found", ppr dictContentTy ]
        , "(I could not split apart a constructor application)."
        , notGoodMsg
        ]
    let klassDataCon = classDataCon klass
        -- The class data constructor's own type, with its result wrapped
        -- in @Dict@; the user's binding is checked against this.
        expectedType =
          mapResultType (mkTyConApp tcDict . (: []))
                        (idType (dataConWorkId klassDataCon))
    -- Only reject when the types are surely apart.
    when (Unify.typesCantMatch [(origBindTy, expectedType)]) $
      pluginLocatedError loc $ vcat
        [ hsep
            [ "Cannot match the expected type (the type of the data constructor of the given class)"
            , "and the found type (the user-supplied binding)."
            ]
        , hsep ["Expected type:", ppr expectedType]
        , hsep ["Found type: ", ppr origBindTy]
        ]
    -- One fresh lambda-bound variable per argument type of the binding.
    argVars <- traverse (flip newLocalVar "t") argTys
    return
      . mkCoreLams (bndrs ++ argVars)
      $ mkCoreConApps conDict
          [ mkTyArg dictContentTy
          , klassDataCon `mkCoreConApps`
              (map mkTyArg instanceArgs ++ varsToCoreExprs argVars)
          ]
  where
    origBindTy = idType bindVar
    (bndrs, bindTy) = splitForAllTys origBindTy
    (argTys, dictTy) = splitFunTys bindTy
    loc = nameSrcSpan $ getName bindVar
    notGoodMsg :: SDoc
    notGoodMsg =
      "ClassDict plugin pass failed to process a Dict declaration."
      $$ "The declaration must have form `forall a1..an . Ctx => Dict (Cls t1..tn)'"
      $$ "Declaration:"
      $$ hcat
           [ " "
           , ppr bindVar, " :: "
           , ppr origBindTy
           ]
-- | Apply a function to the final result type of a possibly quantified,
-- possibly multi-argument type.  The quantifiers and arrows around that
-- result are rebuilt.
--
-- Leading quantified variables are re-bound with 'mkSpecForAllTys'.  Each
-- arrow is rebuilt with 'mkFunTy', reusing the flag that
-- 'splitFunTyArg_maybe' returned for it.  A type with neither a leading
-- forall nor an arrow is handed to the function directly.
mapResultType :: (Type -> Type) -> Type -> Type
mapResultType f = go
  where
    go ty = case splitForAllTys ty of
      (tvs@(_ : _), body) -> mkSpecForAllTys tvs (go body)
      ([], _) -> case splitFunTyArg_maybe ty of
        Just (vis, argTy, resTy) -> mkFunTy vis argTy (go resTy)
        Nothing -> f ty