{-# LANGUAGE CPP #-} {-# LANGUAGE DataKinds #-} {-# LANGUAGE DeriveDataTypeable #-} {-# LANGUAGE DeriveGeneric #-} {-# LANGUAGE FlexibleInstances #-} {-# LANGUAGE LambdaCase #-} {-# LANGUAGE MultiParamTypeClasses #-} {-# LANGUAGE NoImplicitPrelude #-} {-# LANGUAGE RecordWildCards #-} {-# LANGUAGE TemplateHaskell #-} {-# LANGUAGE TupleSections #-} {-# LANGUAGE TypeFamilies #-} {-# LANGUAGE TypeSynonymInstances #-} #ifndef HAVE_OVERLOADED_LABELS {-# LANGUAGE FlexibleContexts #-} {-# LANGUAGE MagicHash #-} #endif #if 0 #if HAVE_MONAD_FAIL && MIN_VERSION_template_haskell(2,11,0) #define _FAIL_IN_MONAD #else #define _FAIL_IN_MONAD , fail #endif #endif -- Package template-haskell 2.11 is missing MonadFail instance for Q. #define _FAIL_IN_MONAD , fail -- | -- Module: $HEADER$ -- Description: Derive magic instances for OverloadedRecordFields. -- Copyright: (c) 2016, Peter Trško -- License: BSD3 -- -- Maintainer: peter.trsko@gmail.com -- Stability: experimental -- Portability: CPP, DataKinds, DeriveDataTypeable, DeriveGeneric, -- FlexibleContexts (GHC <8), FlexibleInstances, LambdaCase, -- MagicHash (GHC <8), MultiParamTypeClasses, NoImplicitPrelude, -- RecordWildCards, TemplateHaskell, TupleSections, TypeFamilies, -- TypeSynonymInstances -- -- Derive magic instances for OverloadedRecordFields. module Data.OverloadedRecords.TH ( -- * Derive OverloadedRecordFields instances overloadedRecord , overloadedRecords , overloadedRecordFor , overloadedRecordsFor -- ** Customize Derivation Process , DeriveOverloadedRecordsParams #ifndef HAVE_OVERLOADED_LABELS , fieldDerivation #endif , FieldDerivation , OverloadedField(..) 
, defaultFieldDerivation , defaultMakeFieldName -- * Low-level Deriving Mechanism , field , simpleField , fieldGetter , fieldSetter , simpleFieldSetter ) where import Prelude (Num((-)), fromIntegral) import Control.Applicative (Applicative((<*>))) import Control.Monad (Monad((>>=) _FAIL_IN_MONAD, return), replicateM) #if 0 #if HAVE_MONAD_FAIL && MIN_VERSION_template_haskell(2,11,0) import Control.Monad.Fail (MonadFail(fail)) #endif #endif import Data.Bool (Bool(False), otherwise) import qualified Data.Char as Char (toLower) import Data.Foldable (concat, foldl) import Data.Function ((.), ($)) import Data.Functor (Functor(fmap), (<$>)) import qualified Data.List as List ( drop , isPrefixOf , length , map , replicate , zip ) import Data.Maybe (Maybe(Just, Nothing), fromMaybe) import Data.Monoid ((<>)) import Data.String (String) import Data.Traversable (forM, mapM) import Data.Typeable (Typeable) import Data.Word (Word) import GHC.Generics (Generic) #ifndef HAVE_OVERLOADED_LABELS import GHC.Exts (Proxy#, proxy#) #endif import Text.Show (Show(show)) import Language.Haskell.TH ( Con(ForallC, InfixC, NormalC, RecC) , Dec(DataD, NewtypeD) , DecsQ , ExpQ , Info(TyConI) , Name , Pat(ConP, VarP, WildP) , PatQ , Q , Strict , Type , TypeQ , TyVarBndr(KindedTV, PlainTV) , appE , appT , conE , conP , conT , lamE , litT , nameBase , newName , recUpdE , reify , strTyLit , varE , varP , varT , wildP ) import Data.Default.Class (Default(def)) #ifndef HAVE_OVERLOADED_LABELS import Data.OverloadedLabels (IsLabel(fromLabel)) #endif import Data.OverloadedRecords ( FieldType , HasField(getField) , SetField(setField) , UpdateType ) #ifndef HAVE_OVERLOADED_LABELS -- | Overloaded label that can be used for accessing function of type -- 'FieldDerivation' from 'DeriveOverloadedRecordsParams'. -- -- This definition is available only if compiled with GHC <8. 
fieldDerivation :: IsLabel "fieldDerivation" a => a
fieldDerivation = fromLabel (proxy# :: Proxy# "fieldDerivation")
#endif

-- | Parameters for customization of deriving process. Use 'def' to get
-- default behaviour.
data DeriveOverloadedRecordsParams = DeriveOverloadedRecordsParams
    { _strictFields :: Bool
    -- ^ Make setter and getter strict. **Currently unused.**
    , _fieldDerivation :: FieldDerivation
    -- ^ See 'FieldDerivation' for description.
    }
  deriving (Generic, Typeable)

type instance FieldType "fieldDerivation" DeriveOverloadedRecordsParams =
    FieldDerivation

instance
    HasField "fieldDerivation" DeriveOverloadedRecordsParams FieldDerivation
  where
    getField _proxy = _fieldDerivation

type instance
    UpdateType "fieldDerivation" DeriveOverloadedRecordsParams FieldDerivation
  = DeriveOverloadedRecordsParams

instance
    SetField "fieldDerivation" DeriveOverloadedRecordsParams FieldDerivation
  where
    setField _proxy s b = s{_fieldDerivation = b}

-- | Describes what should be the name of overloaded record field, and can also
-- provide custom implementation of getter and setter. The 'String' argument
-- of each constructor is the name of the overloaded label.
data OverloadedField
    = GetterOnlyField String (Maybe ExpQ)
    -- ^ Derive only getter instances. If second argument is 'Just', then it
    -- contains custom definition of getter function.
    | GetterAndSetterField String (Maybe (ExpQ, ExpQ))
    -- ^ Derive both getter and setter instances. If second argument is
    -- 'Just', then it contains custom definitions of getter and setter
    -- functions, respectively.
  deriving (Generic, Typeable)

-- | Type signature of a function that can customize the derivation of each
-- individual overloaded record field.
--
-- If the field has a selector, then the function will get its name, or
-- 'Nothing' otherwise. The function has to return 'Nothing' in case when
-- generating overloaded record field instances is not desired.
type FieldDerivation
    = String
    -- ^ Name of the type, of which this field is part of.
    -> String
    -- ^ Name of the constructor, of which this field is part of.
    -> Word
    -- ^ Field position as an argument of the constructor it is part of.
    -- Indexing starts from zero.
    -> Maybe String
    -- ^ Name of the field (record) accessor; 'Nothing' means that there is no
    -- record accessor defined for it.
    -> Maybe OverloadedField
    -- ^ Describes how overloaded record field should be generated for this
    -- specific constructor field. 'Nothing' means that no overloaded record
    -- field should be derived. See also 'OverloadedField' for details.

-- | Suppose we have a weird type definition as this:
--
-- @
-- data SomeType a b c = SomeConstructor
--     { _fieldX :: a
--     , someTypeFieldY :: b
--     , someConstructorFieldZ :: c
--     , anythingElse :: (a, b, c)
--     }
-- @
--
-- Then for each of those fields, 'defaultMakeFieldName' will produce
-- expected OverloadedLabel name:
--
-- * @_fieldX --> fieldX@
--
-- * @someTypeFieldY --> fieldY@
--
-- * @someConstructorFieldZ --> fieldZ@
--
-- * @anythingElse@ is ignored
--
-- Prefixes are tried in this order: a leading underscore, then the type name
-- with its first letter lower-cased, then the constructor name with its first
-- letter lower-cased. After the prefix is removed, the first character of the
-- remainder is lower-cased. Fields without a record accessor are ignored, and
-- so are fields whose name consists only of the matched prefix (e.g. field
-- @someType@ of type @SomeType@), since they would yield an empty label.
defaultMakeFieldName
    :: String
    -- ^ Name of the type, of which this field is part of.
    -> String
    -- ^ Name of the constructor, of which this field is part of.
    -> Word
    -- ^ Field position as an argument of the constructor it is part of.
    -- Indexing starts from zero.
    -> Maybe String
    -- ^ Name of the field (record) accessor; 'Nothing' means that there is no
    -- record accessor defined for it.
    -> Maybe String
    -- ^ Overloaded record field name to be used for this specific constructor
    -- field; 'Nothing' means that there shouldn't be a label associated with
    -- it.
defaultMakeFieldName typeName constructorName _fieldPosition = \case
    Nothing -> Nothing
    Just fieldName
      | startsWith "_" -> dropPrefix "_" fieldName
      | startsWith typePrefix -> dropPrefix typePrefix fieldName
      | startsWith conPrefix -> dropPrefix conPrefix fieldName
      | otherwise -> Nothing
      where
        startsWith :: String -> Bool
        startsWith = (`List.isPrefixOf` fieldName)

        -- Remove the prefix and lower-case the first character of what is
        -- left. An empty remainder would result in an empty label, which is
        -- rejected by returning 'Nothing'.
        dropPrefix :: String -> String -> Maybe String
        dropPrefix s name = case headToLower (List.drop (List.length s) name) of
            "" -> Nothing
            label -> Just label

        headToLower :: String -> String
        headToLower "" = ""
        headToLower (x : xs) = Char.toLower x : xs

        typePrefix, conPrefix :: String
        typePrefix = headToLower typeName
        conPrefix = headToLower constructorName

-- | Function used by default value of 'DeriveOverloadedRecordsParams'. It
-- names fields using 'defaultMakeFieldName' and requests both getter and
-- setter, using the derived (not custom) implementations.
defaultFieldDerivation :: FieldDerivation
defaultFieldDerivation typeName constructorName fieldPosition accessor =
    (`GetterAndSetterField` Nothing)
        <$> defaultMakeFieldName typeName constructorName fieldPosition accessor

-- | Default parameters:
--
-- @
-- 'def' = 'DeriveOverloadedRecordsParams'
--     { _strictFields = 'False'
--     , _fieldDerivation = 'defaultFieldDerivation'
--     }
-- @
--
-- The field derivation function can be accessed and modified through the
-- overloaded label @\#fieldDerivation@.
instance Default DeriveOverloadedRecordsParams where
    def = DeriveOverloadedRecordsParams
        { _strictFields = False
        , _fieldDerivation = defaultFieldDerivation
        }

-- | Derive magic OverloadedRecordFields instances for specified type.
--
-- Only plain @newtype@ and @data@ declarations without a datatype context
-- are supported; anything else makes the splice fail with an error message.
overloadedRecord
    :: DeriveOverloadedRecordsParams
    -- ^ Parameters for customization of deriving process. Use 'def' to get
    -- default behaviour.
    -> Name
    -- ^ Name of the type for which magic instances should be derived.
    -> DecsQ
overloadedRecord params = withReified $ \name -> \case
    TyConI dec -> case dec of
        -- Not supporting DatatypeContexts, hence the [] required as the first
        -- argument to NewtypeD and DataD.
#if MIN_VERSION_template_haskell(2,11,0)
        NewtypeD [] typeName typeVars _kindSignature constructor _deriving ->
#else
        NewtypeD [] typeName typeVars constructor _deriving ->
#endif
            deriveForConstructor params typeName typeVars constructor

#if MIN_VERSION_template_haskell(2,11,0)
        DataD [] typeName typeVars _kindSignature constructors _deriving ->
#else
        DataD [] typeName typeVars constructors _deriving ->
#endif
            fmap concat . forM constructors
                $ deriveForConstructor params typeName typeVars

        x -> canNotDeriveError name x

    x -> canNotDeriveError name x
  where
    -- Reify the name and pass both the name and its info to the handler.
    withReified :: (Name -> Info -> Q a) -> Name -> Q a
    withReified f t = reify t >>= f t

    canNotDeriveError :: Show a => Name -> a -> Q b
    canNotDeriveError = (fail .) . errMessage

    errMessage :: Show a => Name -> a -> String
    errMessage n x =
        "`" <> show n <> "' is neither newtype nor data type: " <> show x

-- | Derive magic OverloadedRecordFields instances for specified types.
overloadedRecords
    :: DeriveOverloadedRecordsParams
    -- ^ Parameters for customization of deriving process. Use 'def' to get
    -- default behaviour.
    -> [Name]
    -- ^ Names of the types for which magic instances should be derived.
    -> DecsQ
overloadedRecords params = fmap concat . mapM (overloadedRecord params)

-- | Derive magic OverloadedRecordFields instances for specified type.
--
-- Similar to 'overloadedRecord', but instead of a
-- 'DeriveOverloadedRecordsParams' value it takes a function which can modify
-- its default value.
--
-- @
-- data Coordinates2D a = Coordinates2D
--     { coordinateX :: a
--     , coordinateY :: a
--     }
--
-- 'overloadedRecordFor' ''Coordinates2D
--     $ \#fieldDerivation .~ \\_ _ _ -> \\case
--         Nothing -> Nothing
--         Just field -> lookup field
--             [ (\"coordinateX\", 'GetterOnlyField' \"x\" Nothing)
--             , (\"coordinateY\", 'GetterOnlyField' \"y\" Nothing)
--             ]
-- @
overloadedRecordFor
    :: Name
    -- ^ Name of the type for which magic instances should be derived.
    -> (DeriveOverloadedRecordsParams -> DeriveOverloadedRecordsParams)
    -- ^ Function that modifies parameters for customization of deriving
    -- process.
    -> DecsQ
overloadedRecordFor typeName f = overloadedRecord (f def) typeName

-- | Derive magic OverloadedRecordFields instances for specified types.
--
-- Same as 'overloadedRecordFor', but for multiple types at once.
overloadedRecordsFor
    :: [Name]
    -- ^ Names of the types for which magic instances should be derived.
    -> (DeriveOverloadedRecordsParams -> DeriveOverloadedRecordsParams)
    -- ^ Function that modifies parameters for customization of deriving
    -- process.
    -> DecsQ
overloadedRecordsFor typeNames f = overloadedRecords (f def) typeNames

-- | Derive magic instances for all fields of a specific data constructor of a
-- specific type.
deriveForConstructor
    :: DeriveOverloadedRecordsParams
    -- ^ Parameters for customization of deriving process. Use 'def' to get
    -- default behaviour.
    -> Name
    -- ^ Name of the type constructor the data constructor belongs to.
    -> [TyVarBndr]
    -- ^ Type variables of the type constructor.
    -> Con
    -- ^ Data constructor whose fields are processed.
    -> DecsQ
deriveForConstructor params name typeVars = \case
    NormalC conName args ->
        deriveFor conName args $ \(strict, argType) f ->
            f Nothing strict argType

    RecC conName args ->
        deriveFor conName args $ \(accessor, strict, argType) f ->
            f (Just accessor) strict argType

    InfixC arg0 conName arg1 ->
        deriveFor conName [arg0, arg1] $ \(strict, argType) f ->
            f Nothing strict argType

    -- Existentials aren't supported.
    ForallC _typeVariables _context _constructor -> return []

#if MIN_VERSION_template_haskell(2,11,0)
    -- GADT-style constructors (GadtC, RecGadtC) aren't supported either.
    -- Without this alternative the match would be non-exhaustive and the
    -- splice would crash with a pattern-match failure.
    _ -> return []
#endif
  where
    -- Derive instances for every argument of a constructor. The callback
    -- converts the constructor-specific argument representation into
    -- accessor name (if any), strictness and type.
    deriveFor
        :: Name
        -> [a]
        -> (a -> (Maybe Name -> Strict -> Type -> DecsQ) -> DecsQ)
        -> DecsQ
    deriveFor constrName args f =
        fmap concat . forM (withIndexes args) $ \(idx, arg) ->
            f arg $ \accessor strict fieldType' ->
                deriveForField params DeriveFieldParams
                    { typeName = name
                    , typeVariables = List.map getTypeName typeVars
                    , constructorName = constrName
                    , numberOfArgs = fromIntegral $ List.length args
                    , currentIndex = idx
                    , accessorName = accessor
                    , strictness = strict
                    , fieldType = fieldType'
                    }
      where
        getTypeName :: TyVarBndr -> Name
        getTypeName = \case
            PlainTV n -> n
            KindedTV n _kind -> n

        withIndexes = List.zip [(0 :: Word) ..]

-- | Parameters for 'deriveForField' function.
data DeriveFieldParams = DeriveFieldParams
    { typeName :: Name
    -- ^ Record name, i.e. type constructor name.
    , typeVariables :: [Name]
    -- ^ Free type variables of a type constructor.
    , constructorName :: Name
    -- ^ Data constructor name.
    , numberOfArgs :: Word
    -- ^ Number of arguments that data constructor takes.
    , currentIndex :: Word
    -- ^ Index of the current argument of a data constructor for which we are
    -- deriving overloaded record field instances. Indexing starts from zero.
    -- In other words 'currentIndex' is between zero (including) and
    -- 'numberOfArgs' (excluding).
    , accessorName :: Maybe Name
    -- ^ Record field accessor, if available, otherwise 'Nothing'.
    , strictness :: Strict
    -- ^ Strictness annotation of the current data constructor argument.
    , fieldType :: Type
    -- ^ Type of the current data constructor argument.
    }

-- | Derive magic instances for a specific field of a specific type.
--
-- The user supplied 'FieldDerivation' decides whether, and under which label,
-- instances are generated. When no custom getter/setter is provided, they are
-- derived from the record accessor if there is one, or built by pattern
-- matching on the data constructor otherwise.
deriveForField
    :: DeriveOverloadedRecordsParams
    -- ^ Parameters for customization of deriving process. Use 'def' to get
    -- default behaviour.
    -> DeriveFieldParams
    -- ^ All the necessary information for derivation procedure.
    -> DecsQ
deriveForField params DeriveFieldParams{..} = case possiblyLabel of
    Nothing -> return []

    Just (GetterOnlyField label customGetterExpr) ->
        deriveGetter' (strTyLitT label)
            $ fromMaybe derivedGetterExpr customGetterExpr

    Just (GetterAndSetterField label customGetterAndSetterExpr) ->
        (<>) <$> deriveGetter' labelType getterExpr
             <*> deriveSetter' labelType setterExpr
      where
        labelType = strTyLitT label
        (getterExpr, setterExpr) = fromMaybe
            (derivedGetterExpr, derivedSetterExpr)
            customGetterAndSetterExpr
  where
    possiblyLabel = _fieldDerivation params
        (nameBase typeName)
        (nameBase constructorName)
        currentIndex
        (fmap nameBase accessorName)

    deriveGetter' labelType =
        deriveGetter labelType recordType (return fieldType)

    deriveSetter' labelType = deriveSetter labelType recordType
        (return fieldType) newRecordType newFieldType

    -- Type constructor applied to all of its type variables.
    recordType = foldl appT (conT typeName) $ List.map varT typeVariables

    -- TODO: When field type is polymorphic, then we should allow to change it.
    newFieldType = return fieldType
    newRecordType = recordType

    -- Number of variables, i.e. arguments of a constructor, to the right of
    -- the currently processed field.
    numVarsOnRight = numberOfArgs - currentIndex - 1

    inbetween :: (a -> [b]) -> a -> a -> b -> [b]
    inbetween f a1 a2 b = f a1 <> (b : f a2)

    derivedGetterExpr = case accessorName of
        Just name -> varE name
        Nothing -> do
            a <- newName "a"
            -- \(C _ _ ... _ a _ _ ... _) -> a
            lamE [return . ConP constructorName $ nthArg (VarP a)] (varE a)
      where
        nthArg :: Pat -> [Pat]
        nthArg = inbetween wildPs currentIndex numVarsOnRight

    derivedSetterExpr = case accessorName of
        Just name -> do
            s <- newName "s"
            b <- newName "b"
            -- \s b -> s{name = b}
            lamE [varP s, varP b] $ recUpdE (varE s) [(name, ) <$> varE b]
        Nothing -> do
            varsBefore <- newNames currentIndex "a"
            b <- newName "b"
            varsAfter <- newNames numVarsOnRight "a"
            -- \(C a_0 a_1 ... a_(i - 1) _ a_(i + 1) a_(i + 2) ... a_(n)) b ->
            --     C a_0 a_1 ... a_(i - 1) b a_(i + 1) a_(i + 2) ... a_(n)
            lamE [constrPattern varsBefore varsAfter, varP b]
                $ constrExpression varsBefore (varE b) varsAfter
      where
        constrPattern before after =
            conP constructorName $ inbetween varPs before after wildP

        constrExpression before b after =
            foldl appE (conE constructorName) $ varEs before <> (b : varEs after)

-- | Derive instances for overloaded record field, both getter and setter.
field
    :: String
    -- ^ Overloaded label name.
    -> TypeQ
    -- ^ Record type.
    -> TypeQ
    -- ^ Field type.
    -> TypeQ
    -- ^ Record type after update.
    -> TypeQ
    -- ^ Setter will set field to a value of this type.
    -> ExpQ
    -- ^ Getter function.
    -> ExpQ
    -- ^ Setter function.
    -> DecsQ
field label recType fldType newRecType newFldType getterExpr setterExpr =
    (<>) <$> deriveGetter labelType recType fldType getterExpr
         <*> deriveSetter labelType recType fldType newRecType newFldType
                setterExpr
  where
    labelType = strTyLitT label

-- | Derive instances for overloaded record field, both getter and setter. Same
-- as 'field', but record type is the same before and after update and so is
-- the field type.
simpleField
    :: String
    -- ^ Overloaded label name.
    -> TypeQ
    -- ^ Record type.
    -> TypeQ
    -- ^ Field type.
    -> ExpQ
    -- ^ Getter function.
    -> ExpQ
    -- ^ Setter function.
    -> DecsQ
simpleField label recType fldType = field label recType fldType recType fldType

-- | Derive instances for overloaded record field getter.
fieldGetter
    :: String
    -- ^ Overloaded label name.
    -> TypeQ
    -- ^ Record type.
    -> TypeQ
    -- ^ Field type.
    -> ExpQ
    -- ^ Getter function.
    -> DecsQ
fieldGetter = deriveGetter . strTyLitT

-- | Derive only getter related instances: 'FieldType' and 'HasField'.
deriveGetter :: TypeQ -> TypeQ -> TypeQ -> ExpQ -> DecsQ
deriveGetter labelType recordType fldType getter = [d|
    type instance FieldType $(labelType) $(recordType) = $(fldType)

    instance HasField $(labelType) $(recordType) $(fldType) where
        getField _proxy = $(getter)
  |]

-- | Derive instances for overloaded record field setter. Same as
-- 'fieldSetter', but record type is the same before and after update and so is
-- the field type.
simpleFieldSetter
    :: String
    -- ^ Overloaded label name.
    -> TypeQ
    -- ^ Record type.
    -> TypeQ
    -- ^ Field type.
    -> ExpQ
    -- ^ Setter function.
    -> DecsQ
simpleFieldSetter label recordType fldType =
    fieldSetter label recordType fldType recordType fldType

-- | Derive instances for overloaded record field setter.
fieldSetter
    :: String
    -- ^ Overloaded label name.
    -> TypeQ
    -- ^ Record type.
    -> TypeQ
    -- ^ Field type.
    -> TypeQ
    -- ^ Record type after update.
    -> TypeQ
    -- ^ Setter will set field to a value of this type.
    -> ExpQ
    -- ^ Setter function.
    -> DecsQ
fieldSetter = deriveSetter . strTyLitT

-- | Derive only setter related instances: 'UpdateType' and 'SetField'.
--
-- Both instances are indexed by the type of the /new/ field value, i.e. the
-- value the setter accepts. Previously 'SetField' was indexed by the old field
-- type, which disagreed with the 'UpdateType' instance whenever the two types
-- differed. The old field type is therefore not needed here.
deriveSetter :: TypeQ -> TypeQ -> TypeQ -> TypeQ -> TypeQ -> ExpQ -> DecsQ
deriveSetter labelType recordType _fldType newRecordType newFldType setter = [d|
    type instance UpdateType $(labelType) $(recordType) $(newFldType) =
        $(newRecordType)

    instance SetField $(labelType) $(recordType) $(newFldType) where
        setField _proxy = $(setter)
  |]

-- | Construct list of wildcard patterns ('WildP').
wildPs :: Word -> [Pat]
wildPs n = List.replicate (fromIntegral n) WildP

-- | Construct list of new names using 'newName'.
newNames :: Word -> String -> Q [Name]
newNames count prefix = replicateM (fromIntegral count) (newName prefix)

-- | One variable pattern ('varP') per given name, preserving order.
varPs :: [Name] -> [PatQ]
varPs names = List.map varP names

-- | One variable expression ('varE') per given name, preserving order.
varEs :: [Name] -> [ExpQ]
varEs names = List.map varE names

-- | Promote a 'String' to a type-level string literal ('strTyLit').
strTyLitT :: String -> TypeQ
strTyLitT str = litT (strTyLit str)