{-# language CPP
, TemplateHaskell #-}
module Yaya.Retrofit
( module Yaya.Fold
, PatternFunctorRules (..)
, defaultRules
, extractPatternFunctor
) where
import Control.Exception (Exception (..), throw)
import Control.Monad ((<=<))
import Data.Bifunctor (bimap)
import Data.Either.Validation (Validation (..), validationToEither)
import Data.Functor.Identity (Identity (..))
import Data.List.NonEmpty (NonEmpty)
import Language.Haskell.TH as TH
import Language.Haskell.TH.Datatype as TH.Abs
import Language.Haskell.TH.Syntax (mkNameG_tc)
import Text.Read.Lex (isSymbolChar)
import Yaya.Fold
( Corecursive (..)
, Projectable (..)
, Recursive (..)
, Steppable (..)
, recursiveEq
, recursiveShowsPrec
)
-- | Reify the named datatype, then generate its pattern functor together
-- with the recursion-scheme instances relating the original type to it.
--
-- When the datatype is not supported, the 'UnsupportedDatatype' reason is
-- raised with 'throw'. This happens for a data/newtype family instance, or
-- when an argument is not a (possibly kind-annotated) type variable.
extractPatternFunctor :: PatternFunctorRules -> Name -> Q [Dec]
extractPatternFunctor rules = generate <=< reifyDatatype
  where
    generate = either throw id . makePrimForDI rules
-- | How to derive the names used by the generated pattern functor from the
-- corresponding names of the original datatype.
data PatternFunctorRules = PatternFunctorRules
  { patternType :: Name -> Name
    -- ^ rename the datatype itself
  , patternCon :: Name -> Name
    -- ^ rename each constructor
  , patternField :: Name -> Name
    -- ^ rename each record field
  }

-- | The default naming scheme. It applies 'toFName' to the type,
-- constructor, and field names alike: @F@ is appended to alphanumeric names
-- and @$@ to operator names.
defaultRules :: PatternFunctorRules
defaultRules = PatternFunctorRules toFName toFName toFName
-- | Default renaming: take the unqualified base of a name and append @$@
-- if it is an operator name, or @F@ otherwise (so @List@ becomes @ListF@
-- and @:|@ becomes @:|$@).
toFName :: Name -> Name
toFName original = mkName (base ++ suffix)
  where
    base = nameBase original
    suffix
      | isInfixName base = "$"
      | otherwise = "F"

-- | Whether a name consists only of symbol characters, i.e. is written as
-- an operator. Vacuously 'True' for the empty string.
isInfixName :: String -> Bool
isInfixName = not . any (not . isSymbolChar)
-- | Why 'extractPatternFunctor' refused to process a datatype.
data UnsupportedDatatype
  = UnsupportedInstTypes (NonEmpty Type)
    -- ^ the datatype arguments that are not (kind-annotated) type variables
  | UnsupportedVariant DatatypeVariant
    -- ^ the datatype is a data family instance

instance Show UnsupportedDatatype where
  show (UnsupportedInstTypes tys) =
    "extractPatternFunctor: Couldn't process the following types " <> show tys
  show (UnsupportedVariant _variant) =
    "extractPatternFunctor: Data families are currently not supported."

instance Exception UnsupportedDatatype
-- | Check that a reified datatype is supported. If it is, produce the
-- generator for its pattern functor and instances.
--
-- The two failure cases are:
--
-- * 'UnsupportedVariant' for data and newtype family instances.
-- * 'UnsupportedInstTypes', listing every argument of the datatype that is
--   not a type variable (with or without a kind signature).
makePrimForDI
  :: PatternFunctorRules -> DatatypeInfo -> Either UnsupportedDatatype (Q [Dec])
makePrimForDI
  rules
  (DatatypeInfo { datatypeName = tyName
                , datatypeInstTypes = instTys
                , datatypeCons = cons
                , datatypeVariant = variant })
  | isDataFamInstance variant = Left (UnsupportedVariant variant)
  | otherwise =
      bimap UnsupportedInstTypes generate
        . validationToEither
        $ traverse asBinder instTys
  where
    -- 'Validation' accumulates failures, so every offending argument is
    -- reported, not just the first one.
    asBinder :: Type -> Validation (NonEmpty Type) TyVarBndr
    asBinder ty = maybe (Failure (pure ty)) Success (toTyVarBndr ty)

    generate bndrs = makePrimForDI' rules (variant == Newtype) tyName bndrs cons

    isDataFamInstance :: DatatypeVariant -> Bool
    isDataFamInstance DataInstance = True
    isDataFamInstance NewtypeInstance = True
    isDataFamInstance Datatype = False
    isDataFamInstance Newtype = False

    toTyVarBndr :: Type -> Maybe TyVarBndr
    toTyVarBndr (VarT n) = Just (PlainTV n)
    toTyVarBndr (SigT (VarT n) k) = Just (KindedTV n k)
    toTyVarBndr _ = Nothing
-- | Generate the pattern-functor declaration for a datatype, plus
-- 'Projectable', 'Steppable', 'Recursive' and 'Corecursive' instances
-- relating the original type to that functor.
--
-- Arguments, in order:
--
-- 1. the naming rules;
-- 2. whether the original was declared as a @newtype@;
-- 3. the original type name;
-- 4. its type-variable binders;
-- 5. its constructors.
--
-- The generated type takes the original binders plus one fresh final type
-- parameter (@r@). In every constructor field, that parameter replaces
-- occurrences of the fully-applied original type, as far as 'substType'
-- recurses. Type synonyms in fields are resolved first.
makePrimForDI'
  :: PatternFunctorRules -> Bool -> Name -> [TyVarBndr] -> [ConstructorInfo] -> Q [Dec]
makePrimForDI' rules isNewtype tyName vars cons = do
  let vars' = map VarT (typeVars vars)
  let tyNameF = patternType rules tyName
  -- The original type applied to all of its own variables.
  let s = conAppsT tyName vars'
  rName <- newName "r"
  let r = VarT rName
  let varsF = vars ++ [PlainTV rName]
  -- Expand synonyms so recursive occurrences behind an alias are found.
  cons' <- traverse (conTypeTraversal resolveTypeSynonyms) cons
  -- Rename constructors and fields, and swap the recursive type for @r@.
  let consF
        = toCon
        . conNameMap (patternCon rules)
        . conFieldNameMap (patternField rules)
        . conTypeMap (substType s r)
        <$> cons'
  -- A @newtype@ functor is only emitted for a single-constructor newtype.
  -- The deriving clause lists Functor, Foldable and Traversable; the splice
  -- site presumably needs the matching Derive* extensions (verify).
  let dataDec =
        case consF of
          [conF] | isNewtype ->
            NewtypeD [] tyNameF varsF Nothing conF deriveds
          _ -> DataD [] tyNameF varsF Nothing consF deriveds
        where
          deriveds =
#if MIN_VERSION_template_haskell(2,12,0)
            pure $ DerivClause Nothing
#endif
            [ ConT functorTypeName
            , ConT foldableTypeName
            , ConT traversableTypeName ]
  -- project/embed are constructor-by-constructor renamings in each
  -- direction; cata/ana are defined through them.
  recursiveDec <-
    [d|
      instance Projectable (->) $(pure s) $(pure $ conAppsT tyNameF vars') where
        project = $(LamCaseE <$> mkMorphism id (patternCon rules) cons')
      instance Steppable (->) $(pure s) $(pure $ conAppsT tyNameF vars') where
        embed = $(LamCaseE <$> mkMorphism (patternCon rules) id cons')
      instance Recursive (->) $(pure s) $(pure $ conAppsT tyNameF vars') where
        cata φ = φ . fmap (cata φ) . project
      instance Corecursive (->) $(pure s) $(pure $ conAppsT tyNameF vars') where
        ana ψ = embed . fmap (ana ψ) . ψ
    |]
  pure ([dataDec] <> recursiveDec)
-- | Build one @case@ alternative per constructor.
--
-- Each alternative matches the constructor renamed by the first function,
-- binding every field to a fresh variable. It then rebuilds the value with
-- the constructor renamed by the second function, passing the fields
-- through unchanged and in order.
mkMorphism
  :: (Name -> Name)
  -> (Name -> Name)
  -> [ConstructorInfo]
  -> Q [Match]
mkMorphism nFrom nTo = traverse alternative
  where
    alternative ci = do
      let con = constructorName ci
      xs <- mapM (\_ -> newName "x") (constructorFields ci)
      let pat = ConP (nFrom con) (VarP <$> xs)
          body = NormalB (foldl AppE (ConE (nTo con)) (VarE <$> xs))
      pure (Match pat body [])
-- | Focus on a constructor's name.
conNameTraversal :: Traversal' ConstructorInfo Name
conNameTraversal = lens constructorName setName
  where
    setName ci n = ci { constructorName = n }

-- | Focus on a constructor's record field names. Normal and infix
-- constructors have no field names, so nothing is visited for them.
conFieldNameTraversal :: Traversal' ConstructorInfo Name
conFieldNameTraversal = variantLens . recordFields
  where
    variantLens :: Lens' ConstructorInfo ConstructorVariant
    variantLens = lens constructorVariant (\ci v -> ci { constructorVariant = v })

    recordFields :: Traversal' ConstructorVariant Name
    recordFields f variant = case variant of
      RecordConstructor fs -> RecordConstructor <$> traverse f fs
      NormalConstructor -> pure NormalConstructor
      InfixConstructor -> pure InfixConstructor

-- | Focus on each of a constructor's field types, in order.
conTypeTraversal :: Traversal' ConstructorInfo Type
conTypeTraversal = fieldsLens . traverse
  where
    fieldsLens :: Lens' ConstructorInfo [Type]
    fieldsLens = lens constructorFields (\ci tys -> ci { constructorFields = tys })
-- | Rename a constructor.
conNameMap :: (Name -> Name) -> ConstructorInfo -> ConstructorInfo
conNameMap rename ci = over conNameTraversal rename ci

-- | Rename every record field of a constructor (no-op for non-records).
conFieldNameMap :: (Name -> Name) -> ConstructorInfo -> ConstructorInfo
conFieldNameMap rename ci = over conFieldNameTraversal rename ci

-- | Transform every field type of a constructor.
conTypeMap :: (Type -> Type) -> ConstructorInfo -> ConstructorInfo
conTypeMap retype ci = over conTypeTraversal retype ci
-- | A van Laarhoven lens: focuses on exactly one @a@ inside an @s@.
type Lens' s a = forall f. Functor f => (a -> f a) -> s -> f s

-- | A van Laarhoven traversal: focuses on zero or more @a@s inside an @s@.
type Traversal' s a = forall f. Applicative f => (a -> f a) -> s -> f s

-- | Build a lens from a getter and a setter.
lens :: (s -> a) -> (s -> a -> s) -> Lens' s a
lens getter setter focus whole = fmap (setter whole) (focus (getter whole))
{-# INLINE lens #-}

-- | Apply a function to every target of a traversal.
over :: Traversal' s a -> (a -> a) -> s -> s
over optic f whole = runIdentity (optic (Identity . f) whole)
{-# INLINE over #-}
-- | The names bound by a list of type-variable binders, in order.
typeVars :: [TyVarBndr] -> [Name]
typeVars bndrs = [tvName bndr | bndr <- bndrs]
-- | Apply a type constructor to arguments from left to right, so
-- @conAppsT T [a, b]@ is @T a b@.
conAppsT :: Name -> [Type] -> Type
conAppsT conName = go (ConT conName)
  where
    go acc [] = acc
    go acc (arg : args) = go (AppT acc arg) args
-- | @substType from to ty@ replaces every subterm of @ty@ equal to @from@
-- with @to@.
--
-- The search descends through type applications, forall bodies, kind
-- signatures (not the kinds themselves), infix applications and
-- parentheses. Any other form is left unchanged unless it is itself equal
-- to @from@.
substType
  :: Type
  -> Type
  -> Type
  -> Type
substType from to = go
  where
    go ty
      | ty == from = to
      | otherwise = case ty of
          AppT f x -> AppT (go f) (go x)
          ForallT bndrs ctx body -> ForallT bndrs ctx (go body)
          SigT t k -> SigT (go t) k
          InfixT l op r -> InfixT (go l) op (go r)
          UInfixT l op r -> UInfixT (go l) op (go r)
          ParensT t -> ParensT (go t)
          other -> other
-- | Convert a th-abstraction 'ConstructorInfo' back into a Template Haskell
-- 'Con' for use in a data declaration, preserving strictness annotations.
--
-- Constructors that bind their own type variables or carry a context are
-- rejected with 'error', because the generator does not support them.
-- An infix constructor must have exactly two fields; anything else is
-- reported with a descriptive 'error' rather than a bare pattern-match
-- failure.
toCon :: ConstructorInfo -> Con
toCon (ConstructorInfo { constructorName = name
                       , constructorVars = vars
                       , constructorContext = ctxt
                       , constructorFields = ftys
                       , constructorStrictness = fstricts
                       , constructorVariant = variant })
  | not (null vars && null ctxt)
  = error "extractPatternFunctor: GADTs are not currently supported."
  | otherwise
  = case variant of
      NormalConstructor -> NormalC name fields
      RecordConstructor fnames -> RecC name $ zip3 fnames bangs ftys
      InfixConstructor -> case fields of
        [lhs, rhs] -> InfixC lhs name rhs
        _ -> error $
          "extractPatternFunctor: infix constructor " <> show name
          <> " does not have exactly two fields."
  where
    bangs = map toBang fstricts
    -- Each field's bang paired with its type, in declaration order.
    fields = zip bangs ftys
    toBang (FieldStrictness upkd strct) =
      Bang (toSourceUnpackedness upkd) (toSourceStrictness strct)
      where
        toSourceUnpackedness :: Unpackedness -> SourceUnpackedness
        toSourceUnpackedness UnspecifiedUnpackedness = NoSourceUnpackedness
        toSourceUnpackedness NoUnpack = SourceNoUnpack
        toSourceUnpackedness Unpack = SourceUnpack
        toSourceStrictness :: Strictness -> SourceStrictness
        toSourceStrictness UnspecifiedStrictness = NoSourceStrictness
        toSourceStrictness Lazy = SourceLazy
        toSourceStrictness TH.Abs.Strict = SourceStrict
-- | Global name of the @Functor@ class, attributed to package @base@,
-- module @GHC.Base@.
functorTypeName :: Name
functorTypeName = baseTypeName "GHC.Base" "Functor"

-- | Global name of the @Foldable@ class, attributed to package @base@,
-- module @Data.Foldable@.
foldableTypeName :: Name
foldableTypeName = baseTypeName "Data.Foldable" "Foldable"

-- | Global name of the @Traversable@ class, attributed to package @base@,
-- module @Data.Traversable@.
traversableTypeName :: Name
traversableTypeName = baseTypeName "Data.Traversable" "Traversable"

-- | A global type-constructor-namespace name in package @base@, given the
-- defining module and the unqualified name.
baseTypeName :: String -> String -> Name
baseTypeName = mkNameG_tc "base"