{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE TemplateHaskell #-}

module SumTypesX.TH
  ( -- * Constructing sum types
    constructSumType,
    SumTypeOptions,
    defaultSumTypeOptions,
    sumTypeOptionsTagOptions,
    SumTypeTagOptions (..),
    sumTypeOptionsConstructorStrictness,
    SumTypeConstructorStrictness (..),

    -- * Converting between sum types
    sumTypeConverter,
    partialSumTypeConverter,
  )
where

import Language.Haskell.TH

-- | This is a Template Haskell function that creates a sum type from a list of
-- types. Here is an example:
--
-- > data TypeA = TypeA
-- > data TypeB = TypeB
-- > data TypeC = TypeC
-- >
-- > constructSumType "MySum" defaultSumTypeOptions [''TypeA, ''TypeB, ''TypeC]
--
-- This will produce the following sum type:
--
-- > data MySum
-- >   = MySumTypeA TypeA
-- >   | MySumTypeB TypeB
-- >   | MySumTypeC TypeC
--
-- Each generated constructor wraps exactly one field of the corresponding
-- type. Constructor names are derived according to
-- 'sumTypeOptionsTagOptions', and every field gets the strictness selected by
-- 'sumTypeOptionsConstructorStrictness'. The generated type has no context, no
-- type variables and no deriving clauses.
--
-- Note that you can use standalone deriving (the @StandaloneDeriving@
-- extension) to derive any instances you want:
--
-- > deriving instance Show MySum
-- > deriving instance Eq MySum
constructSumType :: String -> SumTypeOptions -> [Name] -> Q [Dec]
constructSumType typeName SumTypeOptions {..} types =
  pure [DataD [] (mkName typeName) [] Nothing constructors []]
  where
    -- The same bang (no unpacking, user-selected strictness) is applied to
    -- every field of the generated type.
    fieldBang =
      Bang NoSourceUnpackedness (constructorStrictness sumTypeOptionsConstructorStrictness)

    -- One single-field constructor per input type, e.g. @MySumTypeA TypeA@.
    mkConstructor name =
      NormalC
        (constructorName sumTypeOptionsTagOptions typeName name)
        [(fieldBang, ConT name)]

    constructors = map mkConstructor types

-- | Options for 'constructSumType'. Note that the constructor for this type is
-- not exported; please start from 'defaultSumTypeOptions' and override
-- individual fields with record update syntax. (This is done for the sake of
-- backwards compatibility in case we add options.)
data SumTypeOptions
  = SumTypeOptions
  { -- | How the tag (constructor name) for each wrapped type is generated.
    sumTypeOptionsTagOptions :: SumTypeTagOptions,
    -- | Whether the fields of the generated constructors are lazy or strict.
    sumTypeOptionsConstructorStrictness :: SumTypeConstructorStrictness
  }

-- | Default options for 'SumTypeOptions'
--
-- @
-- 'SumTypeOptions'
-- { 'sumTypeOptionsTagOptions' = 'PrefixTagsWithTypeName'
-- , 'sumTypeOptionsConstructorStrictness' = 'LazySumTypeConstructors'
-- }
-- @
defaultSumTypeOptions :: SumTypeOptions
defaultSumTypeOptions =
  SumTypeOptions
    { sumTypeOptionsTagOptions = PrefixTagsWithTypeName,
      sumTypeOptionsConstructorStrictness = LazySumTypeConstructors
    }

-- | How 'constructSumType' derives the tag (constructor name) used for each
-- wrapped type.
data SumTypeTagOptions
  = PrefixTagsWithTypeName
    -- ^ Put the sum type's name in front of the wrapped type's name, e.g.
    -- @MySum@ and @TypeA@ give the tag @MySumTypeA@.
  | AppendTypeNameToTags
    -- ^ Put the sum type's name after the wrapped type's name, e.g.
    -- @MySum@ and @TypeA@ give the tag @TypeAMySum@.
  | ConstructTagName (String -> String)
    -- ^ Build the tag with an arbitrary function. The function receives the
    -- (unqualified) name of the wrapped type and returns the tag.

-- | Compute the name of the constructor that wraps a given type.
--
-- The 'String' argument is the name of the sum type being generated; of the
-- wrapped type's 'Name' only the unqualified base name ('nameBase') is used.
constructorName :: SumTypeTagOptions -> String -> Name -> Name
constructorName PrefixTagsWithTypeName typeName =
  mkName . (typeName ++) . nameBase
constructorName AppendTypeNameToTags typeName =
  mkName . (++ typeName) . nameBase
constructorName (ConstructTagName mkTag) _ =
  -- The user-supplied function ignores the sum type's name entirely.
  mkName . mkTag . nameBase

-- | Defines if the constructors for the sum type should be lazy or strict.
data SumTypeConstructorStrictness
  = -- | Constructors will be lazy
    LazySumTypeConstructors
  | -- | Constructors will be strict
    StrictSumTypeConstructors
  deriving (Show, Eq)

-- | Translate the user-facing strictness option into the TH 'SourceStrictness'
-- attached to every generated constructor field.
--
-- "Lazy" is rendered as the absence of any annotation ('NoSourceStrictness'),
-- not as an explicit @~@ ('SourceLazy'). NOTE(review): a splicing module with
-- @StrictData@ enabled would presumably still get strict fields; confirm
-- whether that is intended.
constructorStrictness :: SumTypeConstructorStrictness -> SourceStrictness
constructorStrictness LazySumTypeConstructors = NoSourceStrictness
constructorStrictness StrictSumTypeConstructors = SourceStrict

-- | This Template Haskell function creates a conversion function between two
-- sum types. It works by matching up constructors that share the same inner
-- type. Note that all types in the source sum type must be present in the
-- target sum type, or you will get an error.
--
-- > data MySum
-- >   = MySumTypeA TypeA
-- >   | MySumTypeB TypeB
-- >   | MySumTypeC TypeC
-- >
-- > data OtherSum
-- >   = OtherSumTypeA TypeA
-- >   | OtherSumTypeB TypeB
-- >
-- > sumTypeConverter "otherSumToMySum" ''OtherSum ''MySum
--
-- This will produce the following code:
--
-- > otherSumToMySum :: OtherSum -> MySum
-- > otherSumToMySum (OtherSumTypeA typeA) = MySumTypeA typeA
-- > otherSumToMySum (OtherSumTypeB typeB) = MySumTypeB typeB
sumTypeConverter :: String -> Name -> Name -> Q [Dec]
sumTypeConverter functionName sourceType targetType = do
  -- Fails at compile time if some source constructor's inner type has no
  -- counterpart in the target type.
  pairs <- matchTypeConstructors sourceType targetType
  clauses <- traverse mkSerializeFunc pairs
  signature <- [t|$(conT sourceType) -> $(conT targetType)|]
  let funcName = mkName functionName
  pure [SigD funcName signature, FunD funcName clauses]

-- | Similar to 'sumTypeConverter', except not all types in the source sum type
-- need to be present in the target sum type. (Every type in the target sum
-- type must still be present in the source sum type, or you will get an
-- error.)
--
-- Note that this doesn't produce a partial function in the Haskell sense; you
-- won't get an 'error' with the generated function on any arguments. The word
-- partial is used mathematically to denote that not all types from the source
-- sum type are present in the target sum type.
--
-- > data MySum
-- >   = MySumTypeA TypeA
-- >   | MySumTypeB TypeB
-- >   | MySumTypeC TypeC
-- >
-- > data OtherSum
-- >   = OtherSumTypeA TypeA
-- >   | OtherSumTypeB TypeB
-- >
-- > partialSumTypeConverter "mySumToOtherSum" ''MySum ''OtherSum
--
-- This will produce the following code:
--
-- > mySumToOtherSum :: MySum -> Maybe OtherSum
-- > mySumToOtherSum (MySumTypeA typeA) = Just $ OtherSumTypeA typeA
-- > mySumToOtherSum (MySumTypeB typeB) = Just $ OtherSumTypeB typeB
-- > mySumToOtherSum _ = Nothing
--
-- The final catch-all equation is only emitted when it is reachable, i.e. when
-- some constructor of the source type has no counterpart in the target type.
partialSumTypeConverter :: String -> Name -> Name -> Q [Dec]
partialSumTypeConverter functionName sourceType targetType = do
  -- The arguments are deliberately swapped: every constructor of the target
  -- type must find a counterpart in the source type, not the other way round.
  -- Consequently, in each 'BothConstructors' below 'targetConstructor' is a
  -- constructor of the *source* type, which is what 'mkDeserializeFunc'
  -- pattern-matches on.
  bothConstructors <- matchTypeConstructors targetType sourceType
  sourceConstructors <- typeConstructors sourceType
  funcClauses <- mapM mkDeserializeFunc bothConstructors
  typeDecl <- [t|$(conT sourceType) -> Maybe $(conT targetType)|]
  let funcName = mkName functionName
      matchedSourceConstructors = map targetConstructor bothConstructors
      allSourceCovered =
        all ((`elem` matchedSourceConstructors) . snd) sourceConstructors
      -- Only add the catch-all when it can actually be reached; an
      -- unreachable wildcard may trigger GHC's redundant-pattern warning in
      -- the user's module. It is also kept when there are no other clauses,
      -- because a function needs at least one equation.
      wildcardClause = Clause [WildP] (NormalB (ConE 'Nothing)) []
      fallbackClauses
        | null funcClauses || not allSourceCovered = [wildcardClause]
        | otherwise = []
  pure [SigD funcName typeDecl, FunD funcName (funcClauses ++ fallbackClauses)]

-- | Pair each constructor of the first type with the constructor of the second
-- type that wraps the same inner type, preserving the first type's
-- constructor order. Fails at compile time if any constructor of the first
-- type has no counterpart in the second.
matchTypeConstructors :: Name -> Name -> Q [BothConstructors]
matchTypeConstructors sourceType targetType = do
  sourceConstructors <- typeConstructors sourceType
  targetConstructors <- typeConstructors targetType
  mapM (matchConstructor targetConstructors) sourceConstructors

-- | Extract the constructors and types for the given sum type.
--
-- Every constructor must wrap exactly one field. Both positional (@A Int@) and
-- single-field record (@A { unA :: Int }@) constructors are accepted, and a
-- @newtype@ is treated as a sum type with a single constructor. The result
-- pairs each constructor's field type with the constructor's name, in
-- declaration order. Any other shape fails at compile time.
typeConstructors :: Name -> Q [(Type, Name)]
typeConstructors typeName = do
  info <- reify typeName
  case info of
    TyConI (DataD _ _ _ _ constructors _) -> mapM go constructors
    TyConI (NewtypeD _ _ _ _ constructor _) -> mapM go [constructor]
    _ -> fail $ nameBase typeName ++ " must be a sum type"
  where
    go :: Con -> Q (Type, Name)
    go (NormalC name []) =
      fail $ "Constructor " ++ nameBase name ++ " doesn't have any arguments"
    go (NormalC name [(_, type')]) = return (type', name)
    go (NormalC name _) =
      fail $ "Constructor " ++ nameBase name ++ " has more than one argument"
    -- Record constructors are validated like positional ones: the generated
    -- converters only ever match and apply constructors positionally.
    go (RecC name fields) = go (NormalC name [(bang, ty) | (_, bang, ty) <- fields])
    go _ = fail $ "Invalid constructor in " ++ nameBase typeName

-- | Find the corresponding target constructor for a given source constructor.
--
-- Matching uses structural 'Type' equality on the wrapped field type (no
-- type-synonym expansion is performed here). If several target constructors
-- wrap the same type, the first one in the list wins. Fails at compile time
-- when no target constructor matches.
matchConstructor :: [(Type, Name)] -> (Type, Name) -> Q BothConstructors
matchConstructor targetConstructors (type', sourceCon) =
  case lookup type' targetConstructors of
    Nothing ->
      fail $
        "Can't find constructor in target type corresponding to "
          ++ nameBase sourceCon
    Just targetCon -> pure (BothConstructors type' sourceCon targetCon)

-- | Utility type to hold the source and target constructors for a given type.
data BothConstructors
  = BothConstructors
  { -- | The field type wrapped by both constructors.
    innerType :: Type,
    -- | Constructor of the first type given to 'matchTypeConstructors'.
    sourceConstructor :: Name,
    -- | Matching constructor found in the second type given to
    -- 'matchTypeConstructors'.
    targetConstructor :: Name
  }

-- | Construct the TH function 'Clause' for the serialization function for a
-- given type, i.e. the equation @f (SourceCon value) = TargetCon value@.
mkSerializeFunc :: BothConstructors -> Q Clause
mkSerializeFunc BothConstructors {..} = do
  -- A fresh name cannot capture or clash with identifiers in the user's module.
  varName <- newName "value"
  let -- The empty list is the (absent) visible type applications of the pattern.
      patternMatch = ConP sourceConstructor [] [VarP varName]
      constructor = AppE (ConE targetConstructor) (VarE varName)
  pure $ Clause [patternMatch] (NormalB constructor) []

-- | Construct the TH function 'Clause' for the deserialization function for a
-- given type, i.e. the equation @f (TargetCon value) = Just (SourceCon value)@.
--
-- Note the direction: the pattern uses 'targetConstructor' and the result uses
-- 'sourceConstructor'. 'partialSumTypeConverter' relies on this, because it
-- builds its 'BothConstructors' with the two types swapped.
mkDeserializeFunc :: BothConstructors -> Q Clause
mkDeserializeFunc BothConstructors {..} = do
  -- A fresh name cannot capture or clash with identifiers in the user's module.
  varName <- newName "value"
  let patternMatch = ConP targetConstructor [] [VarP varName]
      constructor =
        AppE (ConE 'Just) (AppE (ConE sourceConstructor) (VarE varName))
  pure $ Clause [patternMatch] (NormalB constructor) []