module Record.TH where
import BasePrelude hiding (Proxy)
import GHC.TypeLits
import Language.Haskell.TH
classPCompat :: Name -> [Type] -> Pred
#if MIN_VERSION_template_haskell(2,10,0)
classPCompat n tl = foldl AppT (ConT n) tl
#else
classPCompat = ClassP
#endif
instanceDCompat :: Cxt -> Type -> [Dec] -> Dec
#if MIN_VERSION_template_haskell(2,11,0)
instanceDCompat =
InstanceD Nothing
#else
instanceDCompat =
InstanceD
#endif
recordTypeDec :: Bool -> Int -> Dec
recordTypeDec strict arity =
#if MIN_VERSION_template_haskell(2,11,0)
DataD [] name varBndrs Nothing [NormalC name conTypes] derivingNames
#else
DataD [] name varBndrs [NormalC name conTypes] derivingNames
#endif
where
name =
recordName strict arity
varBndrs =
do
i <- [1 .. arity]
let
n = KindedTV (mkName ("n" <> show i)) (ConT ''Symbol)
v = PlainTV (mkName ("v" <> show i))
in [n, v]
conTypes =
do
i <- [1 .. arity]
return $ (,) strictness (VarT (mkName ("v" <> show i)))
where
strictness =
#if MIN_VERSION_template_haskell(2,11,0)
if strict then Bang NoSourceUnpackedness SourceStrict else Bang NoSourceUnpackedness NoSourceStrictness
#else
if strict then IsStrict else NotStrict
#endif
derivingNames =
#if MIN_VERSION_template_haskell(2,11,0)
map ConT
#endif
#if MIN_VERSION_base(4,7,0)
[''Show, ''Eq, ''Ord, ''Typeable, ''Generic]
#else
[''Show, ''Eq, ''Ord, ''Generic]
#endif
recordName :: Bool -> Int -> Name
recordName strict arity =
mkName $ recordNameString strict arity
recordNameString :: Bool -> Int -> String
recordNameString strict arity =
prefix <> "Record" <> show arity
where
prefix =
if strict then "Strict" else "Lazy"
recordFieldInstanceDec :: Bool -> Int -> Int -> Dec
recordFieldInstanceDec strict =
fieldInstanceDec (FieldInstanceDecMode_Record strict)
tupleFieldInstanceDec :: Int -> Int -> Dec
tupleFieldInstanceDec arity fieldIndex =
fieldInstanceDec FieldInstanceDecMode_Tuple arity fieldIndex
data FieldInstanceDecMode =
FieldInstanceDecMode_Tuple |
FieldInstanceDecMode_Record Bool
fieldInstanceDec :: FieldInstanceDecMode -> Int -> Int -> Dec
fieldInstanceDec mode arity fieldIndex =
instanceDCompat [] headType decs
where
headType =
foldl1 AppT
[classType, VarT selectedNVarName, recordType, recordPrimeType, VarT selectedVVarName, VarT selectedVPrimeVarName]
where
classType =
ConT (mkName "Field")
typeName =
case mode of
FieldInstanceDecMode_Tuple -> tupleTypeName arity
FieldInstanceDecMode_Record strict -> recordName strict arity
conName =
case mode of
FieldInstanceDecMode_Tuple -> tupleDataName arity
FieldInstanceDecMode_Record strict -> recordName strict arity
selectedNVarName =
mkName $ "n" <> show fieldIndex
selectedVVarName =
mkName $ "v" <> show fieldIndex
selectedVPrimeVarName =
mkName $ "v" <> show fieldIndex <> "'"
recordType =
foldl (\a i -> AppT (addNVar a i)
(VarT (mkName ("v" <> show i))))
(ConT typeName)
[1 .. arity]
recordPrimeType =
foldl (\a i -> AppT (addNVar a i)
(VarT (if i == fieldIndex then selectedVPrimeVarName
else mkName ("v" <> show i))))
(ConT typeName)
[1 .. arity]
addNVar =
case mode of
FieldInstanceDecMode_Tuple -> \a i -> a
FieldInstanceDecMode_Record _ -> \a i -> AppT a (VarT (mkName ("n" <> show i)))
decs =
[fieldLensDec]
where
fieldLensDec =
FunD (mkName "fieldLens") [Clause patterns (NormalB exp) []]
where
patterns =
[WildP, VarP fVarName, ConP conName (fmap VarP indexedVVarNames)]
fVarName =
mkName "f"
indexedVVarNames =
fmap (\i -> mkName ("v" <> show i)) [1..arity]
exp =
AppE (AppE (VarE 'fmap) (consLambda))
(AppE (VarE fVarName) (VarE selectedVVarName))
where
consLambda =
LamE [VarP selectedVPrimeVarName] exp
where
exp =
foldl AppE (ConE conName) $
map VarE $
map (\(i, n) -> if i == fieldIndex then selectedVPrimeVarName
else mkName ("v" <> show i)) $
zip [1 .. arity] indexedVVarNames
recordStorableInstanceDec :: Bool -> Int -> Dec
recordStorableInstanceDec strict arity =
instanceDCompat
context
(AppT (ConT (mkName "Storable")) recordType)
[sizeOfD, inlineP "sizeOf", alignmentD, inlineP "alignment", peekD, inlineP "peek", pokeD, inlineP "poke"]
where
name = recordName strict arity
recordType =
foldl (\a i -> AppT (AppT a (VarT (mkName ("n" <> show i))))
(VarT (mkName ("v" <> show i))))
(ConT name)
[1 .. arity]
context = map (\i -> classPCompat (mkName "Storable") [VarT (mkName ("v" <> show i))])
[1 .. arity]
inlineP name = PragmaD $ InlineP (mkName name) Inline FunLike AllPhases
calcDs = map alignmentkD [1 .. arity] ++
map sizekD [1 .. arity] ++
begin0D : map beginkD [1 .. arity] ++
end0D : map endkD [1 .. arity]
where
alignmentkD k = ValD (VarP (mkName ("a" <> show k)))
(NormalB (AppE (VarE (mkName "alignment"))
(SigE (AppE (VarE (mkName "error"))
(LitE (StringL "non-constant alignment not supported")))
(VarT (mkName ("v" <> show k))))))
[]
sizekD k = ValD (VarP (mkName ("s" <> show k)))
(NormalB (AppE (VarE (mkName "sizeOf"))
(SigE (AppE (VarE (mkName "error"))
(LitE (StringL "non-constant sizeOf not supported")))
(VarT (mkName ("v" <> show k))))))
[]
begin0D = ValD (VarP (mkName "b0")) (NormalB (LitE (IntegerL 0))) []
beginkD k = ValD (VarP (mkName ("b" <> show k)))
(NormalB (UInfixE (VarE (mkName ("e" <> show (k1))))
(VarE (mkName "+"))
(UInfixE (AppE (VarE (mkName "negate"))
(VarE (mkName ("e" <> show (k1)))))
(VarE (mkName "mod"))
(VarE (mkName ("a" <> show k))))))
[]
end0D = ValD (VarP (mkName "e0")) (NormalB (LitE (IntegerL 0))) []
endkD k = ValD (VarP (mkName ("e" <> show k)))
(NormalB (UInfixE (VarE (mkName ("b" <> show k)))
(VarE (mkName "+"))
(VarE (mkName ("s" <> show k)))))
[]
sizeOfD = FunD (mkName "sizeOf")
[Clause [WildP]
(NormalB (VarE (mkName ("e" <> show arity))))
calcDs]
alignmentD = FunD (mkName "alignment")
[Clause [WildP]
(NormalB (AppE (VarE (mkName "maximum"))
(ListE (map (\k -> (VarE (mkName ("a" <> show k))))
[1 .. arity]))))
calcDs]
peekD = FunD (mkName "peek")
[Clause [VarP (mkName "p")]
(NormalB (foldl (\e k -> UInfixE e (VarE (mkName "<*>"))
(AppE (VarE (mkName "peek"))
(AppE (AppE (VarE (mkName "plusPtr"))
(VarE (mkName "p")))
(VarE (mkName ("b" <> show k))))))
(AppE (VarE (mkName "pure")) (ConE name))
[1 .. arity]))
calcDs]
pokeD = FunD (mkName "poke")
[Clause [VarP (mkName "p"),
ConP name (map (\k -> VarP (mkName ("v" <> show k))) [1 .. arity])]
(NormalB (foldl (\e k -> UInfixE e (VarE (mkName "<*"))
(AppE (AppE (VarE (mkName "poke"))
(AppE (AppE (VarE (mkName "plusPtr"))
(VarE (mkName "p")))
(VarE (mkName ("b" <> show k)))))
(VarE (mkName ("v" <> show k)))))
(AppE (VarE (mkName "pure")) (TupE []))
[1 .. arity]))
calcDs]
recordConFunDecs :: Bool -> Int -> [Dec]
recordConFunDecs strict arity =
[inline, signature, fun]
where
inline =
PragmaD (InlineP name Inline FunLike AllPhases)
signature =
SigD name type_
where
type_ =
ForallT varBndrs [] $
foldr AppT recordType $
map (AppT ArrowT) $
interleave nameProxyTypes valueVariableTypes
where
varBndrs =
map PlainTV $
interleave nameVariableNames valueVariableNames
recordType =
foldl' AppT (ConT (recordName strict arity)) $
interleave nameVariableTypes valueVariableTypes
valueVariableTypes =
map VarT valueVariableNames
valueVariableNames =
map (\i -> mkName ("v" <> show i)) [1 .. arity]
nameVariableTypes =
map VarT nameVariableNames
nameVariableNames =
map (\i -> mkName ("n" <> show i)) [1 .. arity]
nameProxyTypes =
map (AppT (ConT (mkName "FieldName"))) nameVariableTypes
interleave a b =
join $ zipWith (\a b -> [a, b]) a b
fun =
FunD name [Clause [] (NormalB (recordConLambdaExp strict arity)) []]
name =
mkName string
where
string =
onHead toLower (recordNameString strict arity)
where
onHead f =
\case
a : b -> f a : b
[] -> []
recordConLambdaExp :: Bool -> Int -> Exp
recordConLambdaExp strict arity =
SigE exp t
where
name =
recordName strict arity
exp =
LamE pats exp
where
pats =
concat $ flip map [1 .. arity] $ \i -> [WildP, VarP (mkName ("v" <> show i))]
exp =
foldl AppE (ConE name) (map (\i -> VarE (mkName ("v" <> show i))) [1 .. arity])
t =
fnType name
where
fnType conName =
ForallT varBndrs [] $
foldr1 (\l r -> AppT (AppT ArrowT l) r)
(argTypes <> pure (resultType conName))
varBndrs =
concat $ flip map [1 .. arity] $ \i ->
PlainTV (mkName ("n" <> show i)) :
PlainTV (mkName ("v" <> show i)) :
[]
argTypes =
concat $ flip map [1 .. arity] $ \i ->
AppT (ConT (mkName "FieldName")) (VarT (mkName ("n" <> show i))) :
VarT (mkName ("v" <> show i)) :
[]
resultType conName =
foldl AppT (ConT conName) $ concat $ flip map [1 .. arity] $ \i ->
VarT (mkName ("n" <> show i)) :
VarT (mkName ("v" <> show i)) :
[]