{-# language NoMonoLocalBinds, TemplateHaskell #-}
module FCI.Internal.TH (
mkInst
, unsafeMkInst
, getClassDictInfo
, dictInst
) where
import Language.Haskell.TH.Syntax
import Control.Monad (when, unless)
import Control.Monad.ST (runST)
import Data.Char (isAlpha)
import qualified Data.Kind as K
import Data.List (foldl1')
import qualified Data.Map.Strict as M
import Data.Maybe (mapMaybe)
import Data.STRef (newSTRef, readSTRef, modifySTRef)
import Language.Haskell.TH (thisModule)
import FCI.Internal.Types (Inst, Dict)
-- | Generate the 'Inst' type instance and the 'Dict' instance for the class
-- with the given name.
--
-- Runs 'checkSafeInst' first, so the splice fails unless the class is
-- declared in the current module and @QuantifiedConstraints@ is off.
mkInst :: Name -> Q [Dec]
mkInst name = do
  checkSafeInst name
  unsafeMkInst name
-- | Reject inputs that 'mkInst' cannot handle safely.
--
-- The splice fails when either of these holds:
--
-- * @QuantifiedConstraints@ is enabled in the module running the splice;
-- * the class name's module is not the module running the splice.
checkSafeInst :: Name -> Q ()
checkSafeInst name = do
  quantified <- isExtEnabled QuantifiedConstraints
  when quantified $
    fail "'QuantifiedConstraints' are not supported yet"
  Module _ (ModName currentModule) <- thisModule
  let declaredHere = nameModule name == Just currentModule
  unless declaredHere $ fail $
    '\'' : nameBase name ++ "' is not declared in current module '"
      ++ currentModule ++ "'"
-- | Like 'mkInst', but skips the 'checkSafeInst' checks. It reifies the
-- class and emits the declarations built by 'dictInst'.
unsafeMkInst :: Name -> Q [Dec]
unsafeMkInst name = dictInst <$> getClassDictInfo name
-- | Reify a class and collect what is needed to build its dictionary:
--
-- * the class applied to its binders, keeping kind annotations;
-- * the dictionary constructor name;
-- * one field per superclass constraint, followed by one per method
--   signature.
--
-- Fails with a message when the name does not reify to a class.
getClassDictInfo :: Name -> Q ClassDictInfo
getClassDictInfo className = do
  info <- reify className
  case info of
    ClassI (ClassD constraints _ args _ methods) _ -> do
      dictConName <- dictConFromClassName className
      let appliedClass = foldl1' AppT (ConT className : map bndrToType args)
          superFields  = superFieldsFromCxt constraints
          methodFields = mapMaybe methodFieldFromDec methods
      pure CDI
        { className
        , dictTyArg   = appliedClass
        , dictConName
        , dictFields  = superFields ++ methodFields
        }
    _ -> fail $ '\'' : nameBase className ++ "' is not a class"
-- | Name of the dictionary data constructor for a class.
--
-- * Alphanumeric (or @_@-initial) class names are reused unchanged.
-- * Names beginning with @(@ are rejected as restricted classes.
-- * Any other (operator) name gets a leading @:@, making it a constructor
--   operator.
dictConFromClassName :: Name -> Q Name
dictConFromClassName className = case nameBase className of
  [] -> error "dictConFromClassName: empty 'Name'"
  name@(c : _)
    | isAlpha_ c -> pure (mkName name)
    | c == '('   -> fail $ "Attempt to use restricted class '" ++ name ++ "'"
    | otherwise  -> pure (mkName (':' : name))
-- | Turn superclass constraints into dictionary fields, preserving order.
--
-- Constraints whose head has no name ('appHeadName' returns 'Nothing') are
-- dropped. Each remaining constraint is numbered by how many earlier
-- constraints share its head name. That number is passed to
-- 'fieldFromClassName' so repeated superclass heads get distinct field
-- names.
superFieldsFromCxt :: [Pred] -> [ClassDictField]
superFieldsFromCxt constraints = runST do
  seen <- newSTRef M.empty
  let headed = [ (c, n) | c <- constraints, Just n <- [appHeadName c] ]
  traverse (uncurry (nextField seen)) headed
  where
    -- Read how often @n@ was seen so far, bump the tally, emit the field.
    nextField seen c n = do
      count <- M.findWithDefault 0 n <$> readSTRef seen
      modifySTRef seen (M.insertWith (+) n 1)
      pure CDF
        { fieldName   = fieldFromClassName n count
        , fieldSource = Superclass
        , origName    = n
        , origType    = c
        }
-- | Field name for a superclass constraint headed by the given name. The
-- @count@ argument is the number of earlier constraints with the same head.
--
-- * Alphanumeric names give @_Name@, plus @count@ as a suffix when it is
--   non-zero.
-- * Names starting with @(@ give @_Tuple@, with the same suffix.
-- * Operator names give @||op@, followed by @count@ extra @|@ characters.
fieldFromClassName :: Name -> Int -> Name
fieldFromClassName className count = case nameBase className of
  [] -> error "fieldFromClassName: empty 'Name'"
  name@(c : _) -> mkName (render c name)
  where
    suffix
      | count == 0 = ""
      | otherwise  = show count
    render c name
      | isAlpha_ c = '_' : name ++ suffix
      | c == '('   = "_Tuple" ++ suffix
      | otherwise  = "||" ++ name ++ replicate count '|'
-- | The type variable bound by a binder. An explicit kind annotation on the
-- binder is kept as a 'SigT'.
bndrToType :: TyVarBndr -> Type
bndrToType (PlainTV n)    = VarT n
bndrToType (KindedTV n k) = SigT (VarT n) k
-- | The name at the head of a type or constraint, if it has one.
--
-- Type applications, kind signatures, parentheses and @forall@s are looked
-- through. Infix applications yield their operator.
--
-- Built-in syntax (arrow, list, equality, @Type@, @Constraint@, promoted
-- list constructors) maps to the corresponding quoted name. Tuple and
-- unboxed-sum constructors map to an unqualified 'mkName' of GHC's
-- occurrence name for them, e.g. @(,)@, @(#,,#)@ or @(#|#)@.
--
-- Type literals and wildcards have no head name.
appHeadName :: Type -> Maybe Name
appHeadName = \case
  ForallT _ _ t    -> appHeadName t
  AppT t _         -> appHeadName t
  SigT t _         -> appHeadName t
  ParensT t        -> appHeadName t
  VarT n           -> Just n
  ConT n           -> Just n
  PromotedT n      -> Just n
  InfixT _ n _     -> Just n
  UInfixT _ n _    -> Just n
  TupleT i         -> prod "(" ',' (i - 1) ")"
  UnboxedTupleT i  -> prod "(#" ',' (i - 1) "#)"
  -- An unboxed sum of arity i has i - 1 bars (e.g. "(#|#)" for arity 2),
  -- matching GHC's own name and TH's 'unboxedSumTypeName'. This used to
  -- emit i + 1 bars.
  UnboxedSumT i    -> prod "(#" '|' (i - 1) "#)"
  PromotedTupleT i -> prod "(" ',' (i - 1) ")"
  ArrowT           -> Just ''(->)
  EqualityT        -> Just ''(~)
  ListT            -> Just ''[]
  PromotedNilT     -> Just '[]
  PromotedConsT    -> Just '(:)
  StarT            -> Just ''K.Type
  ConstraintT      -> Just ''K.Constraint
  LitT{}           -> Nothing
  WildCardT        -> Nothing
  where
    -- 'replicate' gives [] for non-positive counts, so no explicit guard is
    -- needed for arities 0 and 1.
    prod open sep count close =
      Just $ mkName $ open ++ replicate count sep ++ close
-- | Dictionary field for a class method signature.
--
-- Only a 'SigD' whose type is a 'ForallT' produces a field. The outer
-- quantifier and its context are dropped, and the remaining type becomes
-- the field type. Every other declaration yields 'Nothing'.
--
-- NOTE(review): this presumes the reified method type has the shape
-- @forall <class vars>. C <class vars> => ...@. Confirm for the GHC
-- version in use.
methodFieldFromDec :: Dec -> Maybe ClassDictField
methodFieldFromDec (SigD n (ForallT _ _ t)) = Just CDF
  { fieldName   = fieldFromMethodName n
  , fieldSource = Method
  , origName    = n
  , origType    = t
  }
methodFieldFromDec _ = Nothing
-- | Field name for a class method. Alphanumeric (or @_@-initial) names give
-- @_name@; operator names give @|op@.
fieldFromMethodName :: Name -> Name
fieldFromMethodName methodName = case nameBase methodName of
  []           -> error "fieldFromMethodName: empty 'Name'"
  name@(c : _) -> mkName (prefix c : name)
  where
    prefix c = if isAlpha_ c then '_' else '|'
-- | Declarations for one class dictionary:
--
-- * @type instance Inst (C a ..) = Dict (C a ..)@;
-- * a @Dict (C a ..)@ instance whose constructor is 'dictConName'. It is a
--   nullary @data instance@ when there are no fields, a @newtype instance@
--   with a record constructor for exactly one field, and a record
--   @data instance@ otherwise.
dictInst :: ClassDictInfo -> [Dec]
dictInst cdi = [instSynonym, dictDecl]
  where
    arg = dictTyArg cdi
    con = dictConName cdi
    instSynonym = TySynInstD ''Inst (TySynEqn [arg] (AppT (ConT ''Dict) arg))
    dictDecl = case map classDictToRecField (dictFields cdi) of
      []      -> dictInstance DataInstD [NormalC con []]
      [field] -> dictInstance NewtypeInstD (RecC con [field])
      fields  -> dictInstance DataInstD [RecC con fields]
    -- Used at both the [Con] (data) and Con (newtype) body types, so it
    -- must be generalised; this relies on the NoMonoLocalBinds pragma.
    dictInstance mkDec body = mkDec [] ''Dict [arg] Nothing body []
-- | Record field for one dictionary entry, with no strictness or unpacking
-- annotations.
--
-- A superclass field's type is @Inst@ applied to the constraint. A method
-- field's type is the method type unchanged.
classDictToRecField :: ClassDictField -> VarBangType
classDictToRecField cdf = (fieldName cdf, noAnnotations, fieldType)
  where
    noAnnotations = Bang NoSourceUnpackedness NoSourceStrictness
    fieldType = case fieldSource cdf of
      Superclass -> AppT (ConT ''Inst) (origType cdf)
      Method     -> origType cdf
-- | Everything needed to emit the instances for one class dictionary.
data ClassDictInfo = CDI
  { className   :: Name
    -- ^ name of the class
  , dictTyArg   :: Pred
    -- ^ the class applied to its type variables, e.g. @C a b@
  , dictConName :: Name
    -- ^ constructor of the generated @Dict@ instance
  , dictFields  :: [ClassDictField]
    -- ^ superclass fields first, then method fields
  } deriving Show
-- | One field of a generated dictionary.
data ClassDictField = CDF
  { fieldName   :: Name
    -- ^ record field name in the generated constructor
  , fieldSource :: ClassDictFieldSource
    -- ^ whether the field comes from a superclass or a method
  , origName    :: Name
    -- ^ head name of the superclass constraint, or the method name
  , origType    :: Type
    -- ^ the superclass constraint, or the method type with its outer
    -- @forall@ stripped
  } deriving Show
-- | Where a dictionary field comes from.
data ClassDictFieldSource
  = Superclass  -- ^ a superclass constraint
  | Method      -- ^ a class method signature
  deriving Show
-- | 'isAlpha' extended with @_@. Used in this module to tell alphanumeric
-- names from operator names by their first character.
isAlpha_ :: Char -> Bool
isAlpha_ c = c == '_' || isAlpha c