-- SPDX-FileCopyrightText: 2020 Tocqueville Group
--
-- SPDX-License-Identifier: LicenseRef-MIT-TQ
--
-- | Generic deriving with unbalanced trees.
--
-- Instead of the balanced representation produced by @deriving Generic@, the
-- shape of both the constructors' tree and of each constructor's fields tree
-- is chosen through a 'GenericStrategy' (see 'customGeneric').
module Util.CustomGeneric
  ( -- * Custom Generic strategies
    withDepths
  , rightBalanced
  , leftBalanced
  , rightComb
  , leftComb
    -- * Depth usage helpers
  , cstr
  , fld
    -- * Instance derivation
  , customGeneric
  ) where

import qualified GHC.Generics as G
import Generics.Deriving.TH (makeRep0Inline)
import Language.Haskell.TH

----------------------------------------------------------------------------
-- Simple type synonyms
----------------------------------------------------------------------------

-- | Simple tuple specifying the depth of a constructor and a list of depths
-- for its fields.
--
-- This is used as a way to specify the tree topology of the Generic instance
-- to derive: the root of a tree has depth 0 and two sibling nodes at depth
-- @d@ are joined into a parent at depth @d - 1@ (see 'unbalancedFold').
type CstrDepth = (Natural, [Natural])

-- | Simple tuple that defines the "shape" of a constructor: its name and number
-- of fields. Used only in this module.
type CstrShape = (Name, Int)

-- | Type of a strategy to derive 'G.Generic' instances: it will be given the
-- actual 'CstrShape's of a data type and needs to return the 'CstrDepth's for
-- it. When possible, it should perform consistency checks and 'fail', using the
-- constructors' 'Name's provided by the 'CstrShape's.
type GenericStrategy = [CstrShape] -> Q [CstrDepth]

-- | Simple type synonym used (internally) between functions, basically extending
-- 'CstrDepth' with the 'Name's of the constructor and of its fields.
type NamedCstrDepths = (Natural, Name, [(Natural, Name)])

----------------------------------------------------------------------------
-- Custom Generic strategies
----------------------------------------------------------------------------

-- | In this strategy the desired depths of constructors (in the type tree) and
-- fields (in each constructor's tree) are provided manually and simply checked
-- against the number of actual constructors and fields.
withDepths :: [CstrDepth] -> GenericStrategy
withDepths treeDepths cstrShape = do
  -- one depth per constructor
  when (length treeDepths /= length cstrShape) $
    fail "Number of constructors' depths does not match number of data constructors."
  -- and, for each constructor, one depth per field
  forM_ (zip (map snd treeDepths) cstrShape) $ \(fDepths, (constrName, fldNum)) ->
    when (length fDepths /= fldNum) . fail $
      "Number of fields' depths does not match number of fields for data "
      <> "constructor: " <> show constrName
  return treeDepths

-- | Strategy to make right-balanced instances (both in constructors and fields).
rightBalanced :: GenericStrategy
rightBalanced = fromDepthsStrategy makeRightBalDepths

-- | Strategy to make left-balanced instances (both in constructors and fields).
leftBalanced :: GenericStrategy
leftBalanced = fromDepthsStrategy (reverse . makeRightBalDepths)

-- | Strategy to make fully right-leaning instances (both in constructors and fields).
rightComb :: GenericStrategy
rightComb = fromDepthsStrategy (reverse . makeLeftCombDepths)

-- | Strategy to make fully left-leaning instances (both in constructors and fields).
leftComb :: GenericStrategy
leftComb = fromDepthsStrategy makeLeftCombDepths

----------------------------------------------------------------------------
-- Generic strategies' builders
----------------------------------------------------------------------------

-- | Helper to make a strategy that creates depths for constructors and fields
-- in the same way, just from their number.
fromDepthsStrategy :: (Int -> [Natural]) -> GenericStrategy
fromDepthsStrategy dStrategy cShapes = return $
  zip (dStrategy $ length cShapes) $ map (dStrategy . snd) cShapes

-- | Depths (in order) of the @n@ leaves of a right-balanced binary tree.
--
-- Leaves are added one at a time, starting from a single depth-0 root: each
-- addition splits the last leaf of the leading run of equal-depth leaves into
-- two, so that deeper leaves accumulate on the right.
--
-- For example: @0 -> []@, @1 -> [0]@, @2 -> [1, 1]@, @3 -> [1, 2, 2]@,
-- @4 -> [2, 2, 2, 2]@, @5 -> [2, 2, 2, 3, 3]@.
makeRightBalDepths :: Int -> [Natural]
makeRightBalDepths n = foldr (const addRightBalDepth) [] [1..n]
  where
    addRightBalDepth :: [Natural] -> [Natural]
    addRightBalDepth = \case
      [] -> [0]
      [x] -> [x + 1, x + 1]
      (x : y : xs) | x == y -> x : addRightBalDepth (x : xs)
      (_ : y : xs) -> y : y : y : xs

-- | Depths (in order) of the @n@ leaves of a fully left-leaning binary tree
-- (a "comb"): the two leftmost leaves are the deepest, at depth @n - 1@, and
-- each following leaf is one level shallower.
--
-- For example: @0 -> []@, @1 -> [0]@, @2 -> [1, 1]@, @3 -> [2, 2, 1]@.
makeLeftCombDepths :: Int -> [Natural]
makeLeftCombDepths 0 = []
makeLeftCombDepths n = map fromIntegral $ (n - 1) : [n - 1, n - 2..1]

----------------------------------------------------------------------------
-- Depth usage helpers
----------------------------------------------------------------------------

-- | Helper for making a constructor depth.
--
-- Note that this is only intended to be more readable than directly using a
-- tuple with 'withDepths' and for the ability to be used in places where
-- @RebindableSyntax@ overrides the number literal resolution.
cstr :: forall n. KnownNat n => [Natural] -> CstrDepth
cstr flds = (natVal (Proxy @n), flds)

-- | Helper for making a field depth.
--
-- Note that this is only intended to be more readable than directly using a
-- tuple with 'withDepths' and for the ability to be used in places where
-- @RebindableSyntax@ overrides the number literal resolution.
fld :: forall n. KnownNat n => Natural
fld = natVal $ Proxy @n

----------------------------------------------------------------------------
-- Instance derivation
----------------------------------------------------------------------------

-- | Derives the 'G.Generic' instance for a type given its name and a
-- 'GenericStrategy' to use.
--
-- The strategy is used to calculate the depths of the data-type constructors
-- and of each constructor's fields.
--
-- The depths are used to generate the tree of the 'G.Generic' representation,
-- allowing for a custom one, in contrast with the one derived automatically.
--
-- This only supports "plain" @data@ types (no GADTs, no @newtype@s, etc.)
-- and requires the depths to describe a full and well-defined tree (see
-- 'unbalancedFold').
--
-- For example, this is valid (and uses the 'withDepths' strategy with the 'cstr'
-- and 'fld' helpers) and results in a balanced instance, equivalent to the
-- auto-derived one:
--
-- @@@
-- data CustomType a
--   = CustomUp Integer Integer
--   | CustomMid {unMid :: Natural}
--   | CustomDown a
--   | CustomNone
--
-- $(customGeneric "CustomType" $ withDepths
--   [ cstr @2 [fld @1, fld @1]
--   , cstr @2 [fld @0]
--   , cstr @2 [fld @0]
--   , cstr @2 []
--   ])
-- @@@
--
-- and this is a valid, but fully left-leaning one:
--
-- @@@
-- $(customGeneric "CustomType" $ withDepths
--   [ cstr @3 [fld @1, fld @1]
--   , cstr @3 [fld @0]
--   , cstr @2 [fld @0]
--   , cstr @1 []
--   ])
-- @@@
--
-- and, just as a demonstration, this is the same fully left-leaning one, but
-- made using the simpler 'leftComb' strategy:
--
-- @@@
-- $(customGeneric "CustomType" leftComb)
-- @@@
--
customGeneric :: String -> GenericStrategy -> Q [Dec]
customGeneric typeStr genStrategy = do
  -- look the data type up and take its name, kind, variables and constructors
  (typeName, mKind, vars, constructors) <- reifyDataType typeStr
  let fullType = deriveFullType typeName mKind vars
  -- ask the strategy for the tree depths, then name each constructor argument
  shapes <- cstrShapes constructors
  depths <- genStrategy shapes
  namedDepths <- makeWeightedConstrs depths shapes
  -- assemble the instance out of its 'G.Rep' type, 'G.from' and 'G.to'
  let repDecl = tySynInstD $ tySynEqn Nothing
        (conT ''G.Rep `appT` fullType)
        (makeUnbalancedRep typeName depths fullType)
  genericInstance <- instanceD (pure []) (conT ''G.Generic `appT` fullType)
    [repDecl, makeUnbalancedFrom namedDepths, makeUnbalancedTo namedDepths]
  pure [genericInstance]

-- | Reifies info from a type name (given as a 'String').
-- The lookup happens from the current splice's scope (see 'lookupTypeName') and
-- the only accepted result is a "plain" data type (no GADTs).
reifyDataType :: String -> Q (Name, Maybe Kind, [TyVarBndr], [Con])
reifyDataType typeStr = do
  typeInfo <- lookupTypeName typeStr >>= \case
    Nothing -> fail $ "Failed type name lookup for: '" <> typeStr <> "'."
    Just tn -> reify tn
  case typeInfo of
    TyConI (DataD _ typeName vars mKind constrs _) ->
      return (typeName, mKind, vars, constrs)
    _ -> fail $ "Only plain datatypes are supported for derivation, but '"
      <> typeStr <> "' instead reifies to:\n" <> show typeInfo

-- | Derives, as well as possible, the fully-applied type of a data type from
-- its name, its kind (where known, @*@ otherwise) and its variables, which are
-- applied in order and keep their kind annotation when they have one.
deriveFullType :: Name -> Maybe Kind -> [TyVarBndr] -> TypeQ
deriveFullType tName mKind = addTypeSig . foldl appT (conT tName) . makeVarsType
  where
    addTypeSig :: TypeQ -> TypeQ
    addTypeSig = flip sigT $ fromMaybe StarT mKind

    makeVarsType :: [TyVarBndr] -> [TypeQ]
    makeVarsType = map $ \case
      PlainTV vName -> varT vName
      KindedTV vName kind -> sigT (varT vName) kind

-- | Calculate the "shape" for each of the given constructors.
-- The shape is simply the 'Name' of the constructor and the number of its args.
-- Constructors other than normal, record and infix ones are rejected with 'fail'.
cstrShapes :: [Con] -> Q [CstrShape]
cstrShapes constructors = forM constructors $ \case
  NormalC name lst -> return (name, length lst)
  RecC name lst -> return (name, length lst)
  InfixC _ name _ -> return (name, 2)
  constr -> fail $ "Unsupported constructor: " <> show constr

-- | Combines depths with constructors and generates fresh 'Name's for the
-- constructors' arguments.
--
-- 'fail's if the number of constructors' depths differs from the number of
-- constructors, or if the number of fields' depths of a constructor differs
-- from its number of fields.
makeWeightedConstrs :: [CstrDepth] -> [CstrShape] -> Q [NamedCstrDepths]
makeWeightedConstrs treeDepths cSizes = do
  -- a 'GenericStrategy' is not required to validate its result, and 'zip'
  -- would otherwise silently drop the unmatched constructors or fields
  when (length treeDepths /= length cSizes) $
    fail "Number of constructors' depths does not match number of data constructors."
  forM (zip treeDepths cSizes) $ \((cDepth, fDepths), (cName, fNum)) -> do
    when (length fDepths /= fNum) . fail $
      "Number of fields' depths does not match number of fields for data "
      <> "constructor: " <> show cName
    constrVarsNames <- replicateM fNum $ newName "v"
    return (cDepth, cName, zip fDepths constrVarsNames)

-- | Creates the 'G.Rep' type for an unbalanced 'G.Generic' instance, for a type
-- given its name, constructors' depths and derived full type.
--
-- Note: generating these type definitions directly would be very complex,
-- especially their metadata, so @generic-deriving@ first builds a balanced
-- representation (see 'makeRep0Inline'), whose constructors' and fields'
-- trees are then taken apart and rebuilt according to the given depths.
makeUnbalancedRep :: Name -> [CstrDepth] -> TypeQ -> TypeQ
makeUnbalancedRep typeName treeDepths derivedType = do
  balancedRep <- makeRep0Inline typeName derivedType
  -- split the type's metadata from the (balanced) list of constructor trees
  (typeMeta, cstrTrees) <-
    dismantleGenericTree [t| (G.:+:) |] [t| G.C1 |] balancedRep
  weightedCstrs <- mapM unbalanceCstrTree (zip cstrTrees treeDepths)
  appT (pure typeMeta) $ unbalancedFold weightedCstrs (joinTypesWith ''(G.:+:))
  where
    -- apply a binary type constructor to a left and a right subtree
    joinTypesWith :: Name -> TypeQ -> TypeQ -> TypeQ
    joinTypesWith op l r = conT op `appT` l `appT` r

    -- pair a constructor tree with its depth, rebuilding its fields' tree
    -- (when it has any fields) according to the fields' depths
    unbalanceCstrTree :: (Type, CstrDepth) -> Q (Natural, Type)
    unbalanceCstrTree (cstrTree, (cDepth, [])) = pure (cDepth, cstrTree)
    unbalanceCstrTree (cstrTree, (cDepth, fDepths)) = do
      (cstrMeta, fieldTrees) <-
        dismantleGenericTree [t| (G.:*:) |] [t| G.S1 |] cstrTree
      fieldsTree <- unbalancedFold (zip fDepths fieldTrees) (joinTypesWith ''(G.:*:))
      pure (cDepth, AppT cstrMeta fieldsTree)

-- | Breaks down a tree of @Generic@ types, given the constructor of its "nodes"
-- and the metadata constructor of its "leaves".
--
-- This expects (and this should always be the case) the "root" to be a
-- @Generic@ metadata constructor, which is returned in the result alongside
-- the list of leaves (in order).
dismantleGenericTree :: TypeQ -> TypeQ -> Type -> Q (Type, [Type])
dismantleGenericTree nodeConstrQ leafMetaQ (AppT meta nodes) = do
  nodeConstr <- nodeConstrQ
  leafMeta <- leafMetaQ
  let collectLeafsTypes :: Type -> Q [Type]
      collectLeafsTypes tp@(AppT a b) = case a of
        -- a leaf: the leaf metadata constructor applied to its arguments
        AppT md _ | md == leafMeta -> return [tp]
        -- the node constructor applied only to its left subtree @b@
        nd | nd == nodeConstr -> collectLeafsTypes b
        -- a full node application: left part first, to preserve leaves' order
        _ -> (<>) <$> collectLeafsTypes a <*> collectLeafsTypes b
      -- reported through 'fail', like every other error in this module, so
      -- that it surfaces as a splice error instead of a raw exception
      collectLeafsTypes x =
        fail $ "Unexpected lack of Generic constructor application: " <> show x
  leaves <- collectLeafsTypes nodes
  return (meta, leaves)
dismantleGenericTree _ _ x =
  fail $ "Unexpected lack of Generic Metadata: " <> show x

-- | Create the unbalanced 'G.from' function declaration for a type starting from
-- its list of weighted constructors.
makeUnbalancedFrom :: [NamedCstrDepths] -> DecQ
makeUnbalancedFrom wConstrs = do
  (cPatts, cDepthExp) <- fmap unzip . forM wConstrs $ \(cDepth, cName, wFields) -> do
    (fPatts, fDepthExp) <- fmap unzip . forM wFields $ \(fDepth, fName) -> do
      -- make pattern for field variable
      fPat <- varP fName
      -- make expression to assemble a Generic Field from its variable
      fExpr <- appE [| G.M1 |] . appE [| G.K1 |] $ varE fName
      return (fPat, (fDepth, fExpr))
    -- make pattern for this constructor
    let cPatt = ConP cName fPatts
    -- make expression to assemble its fields as an isolated Generic Constructor
    cExp <- appE [| G.M1 |] $ case fDepthExp of
      [] -> conE 'G.U1
      _ -> unbalancedFold fDepthExp (appE . appE [| (G.:*:) |])
    return (cPatt, (cDepth, [cExp]))
  -- make expressions to assemble all Generic Constructors: each one ends up
  -- wrapped in the L1/R1 path leading to its position in the constructors' tree
  cExps <- mapQ (appE [| G.M1 |]) $ unbalancedFold cDepthExp $ \xs ys ->
    (<>) <$> mapQ (appE [| G.L1 |]) xs <*> mapQ (appE [| G.R1 |]) ys
  -- make function definition
  funD 'G.from $ zipWith (\p e -> clause [pure p] (normalB $ pure e) []) cPatts cExps

-- | Create the unbalanced 'G.to' function declaration for a type starting from
-- its list of weighted constructors.
makeUnbalancedTo :: [NamedCstrDepths] -> DecQ
makeUnbalancedTo wConstrs = do
  (cExps, cDepthPat) <- fmap unzip . forM wConstrs $ \(cDepth, cName, wFields) -> do
    (fExps, fDepthPat) <- fmap unzip . forM wFields $ \(fDepth, fName) -> do
      -- make expression for field variable
      fExp <- varE fName
      -- make pattern for a Generic Field from its variable
      fPatt <- conP1 'G.M1 . conP1 'G.K1 $ varP fName
      return (fExp, (fDepth, fPatt))
    -- make pattern for this isolated Generic Constructor
    cPatt <- conP1 'G.M1 $ case fDepthPat of
      [] -> conP 'G.U1 []
      _ -> unbalancedFold fDepthPat (conP2 '(G.:*:))
    -- make expression to assemble this constructor
    let cExp = foldl AppE (ConE cName) fExps
    return (cExp, (cDepth, [cPatt]))
  -- make patterns for all Generic Constructors: each one ends up wrapped in
  -- the L1/R1 path leading to its position in the constructors' tree
  cPatts <- mapQ (conP1 'G.M1) $ unbalancedFold cDepthPat $ \xs ys ->
    (<>) <$> mapQ (conP1 'G.L1) xs <*> mapQ (conP1 'G.R1) ys
  -- make function definition
  funD 'G.to $ zipWith (\p e -> clause [pure p] (normalB $ pure e) []) cPatts cExps

-- | Recursively aggregates the values in the given list by merging (with the
-- given function) the leftmost pair of adjacent values that have the same
-- non-zero depth @d@ into a single value of depth @d - 1@, until no such pair
-- is left.
--
-- Depth 0 is the root level, so two depth-0 values are never merged (this
-- would otherwise underflow the 'Natural' depth).
--
-- This will fail for every case in which the list cannot be folded into a single
-- 0-depth value.
unbalancedFold :: forall a. [(Natural, a)] -> (Q a -> Q a -> Q a) -> Q a
unbalancedFold lst f = unbalancedFoldRec lst >>= \case
  [(0, result)] -> return result
  [(n, _)] -> fail $ "Resulting unbalanced tree has a single root, but of depth "
    <> show n <> " instead of 0. Check your depths definitions."
  _ -> fail $ "Cannot create a tree from nodes of depths: "
    <> show (map fst lst) <> ". Check your depths definitions."
  where
    -- keeps merging until a pass makes no merge; every merge shortens the
    -- list, so this terminates
    unbalancedFoldRec :: [(Natural, a)] -> Q [(Natural, a)]
    unbalancedFoldRec xs = unbalancedFoldSingle xs >>= \case
      Nothing -> return xs
      Just ys -> unbalancedFoldRec ys

    -- merges the leftmost adjacent pair with equal, non-zero depth, or returns
    -- 'Nothing' when there is no such pair
    unbalancedFoldSingle :: [(Natural, a)] -> Q (Maybe [(Natural, a)])
    unbalancedFoldSingle = \case
      (dx, x) : (dy, y) : xs | dx == dy, dx > 0 -> do
        dxy <- f (pure x) (pure y)
        return $ Just ((dx - 1, dxy) : xs)
      x : xs -> fmap (x :) <$> unbalancedFoldSingle xs
      [] -> return Nothing

----------------------------------------------------------------------------
-- Utility functions
----------------------------------------------------------------------------

-- | Constructor pattern with a single argument pattern.
conP1 :: Name -> PatQ -> PatQ
conP1 name pat = conP name [pat]

-- | Constructor pattern with two argument patterns.
conP2 :: Name -> PatQ -> PatQ -> PatQ
conP2 name pat1 pat2 = conP name [pat1, pat2]

-- | Applies a 'Q'-level transformation to every element of a 'Q'-produced list.
mapQ :: (Q a -> Q a) -> Q [a] -> Q [a]
mapQ f qlst = qlst >>= mapM (f . pure)