{-# LANGUAGE TemplateHaskell #-} module Data.Record.StateFields.Core (Field(..), record) where import Data.Char import Language.Haskell.TH import Language.Haskell.TH.Syntax -- | A primitive field descriptor. data Field c a = Field { getField :: c -> a , putField :: a -> c -> c } -- | Modify the given 'data' or 'newtype' declaration so that all field names -- are prefixed with an underscore followed by the given string, and -- generate declarations of field descriptors for all fields, each bound to -- the corresponding field name prefixed with the given string (but no -- underscore). -- -- Example usage (this goes at the top level of a module): -- -- > record "foo" [d| data Foo = Foo { bar :: Int, baz :: Int } |] -- -- Note: the second parameter is Q [Dec] because this is what the [d| |] -- form returns, which is the most convenient way to use this function. -- However, the list must contain exactly one declaration, and it must be -- a 'data' or 'newtype' declaration. -- -- Note: in addition to adding the given prefix to each name, the first -- character of the original name is capitalized. record :: String -> Q [Dec] -> Q [Dec] record pre ds = ds >>= \ds -> case ds of [DataD cxt name tvs cons dvs] -> sequence $ return (DataD cxt name tvs (map mkCon cons) dvs) : concatMap mkFields cons [NewtypeD cxt name tvs con dvs] -> sequence $ return (NewtypeD cxt name tvs (mkCon con) dvs) : mkFields con _ -> fail $ "A `record' declaration must be given exactly one " ++ "`data' or `newtype' declaration." where ucFirst (x : xs) = toUpper x : xs rawName name = mkName $ '_' : pre ++ ucFirst (showName name) fieldName name = mkName $ pre ++ ucFirst (showName name) mkCon (RecC name vs) = RecC name $ map mkVar vs mkCon x = x mkVar (name, str, ty) = (rawName name, str, ty) mkFields (RecC name vs) = map mkField vs mkFields _ = [] mkField (name, str, ty) = do r <- newName "r" v <- newName "v" valD (varP fName) (normalB [| Field { getField = $(varE rName) , putField = $(lamE [varP v, varP r] $ recUpdE (varE r) [return (rName, VarE v)]) } |]) [] where fName = fieldName name rName = rawName name record' :: IO () record' = runQ (record "bar" [d| data Bar = Bar { bar :: Int } |]) >>= print . ppr