{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE TemplateHaskell #-}

module SumTypes.TH
  ( -- * Constructing sum types
    constructSumType
  , SumTypeOptions
  , defaultSumTypeOptions
  , sumTypeOptionsTagOptions
  , SumTypeTagOptions (..)
  , sumTypeOptionsConstructorStrictness
  , SumTypeConstructorStrictness (..)
    -- * Converting between sum types
  , sumTypeConverter
  , partialSumTypeConverter
  ) where

import Language.Haskell.TH

-- | Template Haskell generator for a @data@ declaration with one constructor
-- per listed type. Each generated constructor wraps exactly one field of the
-- corresponding type. Its name (the /tag/) is derived from the sum type's
-- name and the wrapped type's name, as selected by 'sumTypeOptionsTagOptions'.
--
-- > data TypeA = TypeA
-- > data TypeB = TypeB
-- > data TypeC = TypeC
-- >
-- > constructSumType "MySum" defaultSumTypeOptions [''TypeA, ''TypeB, ''TypeC]
--
-- With the default options, the splice above expands to:
--
-- > data MySum
-- >   = MySumTypeA TypeA
-- >   | MySumTypeB TypeB
-- >   | MySumTypeC TypeC
--
-- The generated declaration derives no instances. Add whatever you need with
-- standalone deriving (requires the @StandaloneDeriving@ extension):
--
-- > deriving instance Show MySum
-- > deriving instance Eq MySum
constructSumType :: String -> SumTypeOptions -> [Name] -> Q [Dec]
constructSumType typeName SumTypeOptions{..} types =
  pure [DataD [] (mkName typeName) [] Nothing (map wrap types) []]
  where
    -- Every field gets the same bang: the source strictness comes from the
    -- options and unpacking is never requested.
    fieldBang =
      Bang NoSourceUnpackedness (constructorStrictness sumTypeOptionsConstructorStrictness)

    -- One single-field constructor for each wrapped type.
    wrap inner =
      NormalC
        (constructorName sumTypeOptionsTagOptions typeName inner)
        [(fieldBang, ConT inner)]

-- | Configuration accepted by 'constructSumType'.
--
-- Only the type and its field accessors are exported, not its constructor.
-- Start from 'defaultSumTypeOptions' and override fields with record-update
-- syntax. This lets new options be added later without breaking callers.
data SumTypeOptions = SumTypeOptions
  { sumTypeOptionsTagOptions :: SumTypeTagOptions
    -- ^ How the generated constructors (tags) are named.
  , sumTypeOptionsConstructorStrictness :: SumTypeConstructorStrictness
    -- ^ Whether the field of each generated constructor is strict.
  }

-- | The baseline 'SumTypeOptions': tags are the sum type name followed by the
-- wrapped type's name, and constructor fields are lazy.
--
-- @
-- 'SumTypeOptions'
-- { 'sumTypeOptionsTagOptions' = 'PrefixTagsWithTypeName'
-- , 'sumTypeOptionsConstructorStrictness' = 'LazySumTypeConstructors'
-- }
-- @
defaultSumTypeOptions :: SumTypeOptions
defaultSumTypeOptions = SumTypeOptions
  { sumTypeOptionsConstructorStrictness = LazySumTypeConstructors
  , sumTypeOptionsTagOptions = PrefixTagsWithTypeName
  }

-- | Naming scheme 'constructSumType' uses for the constructor (tag) it
-- generates for each wrapped type. Only the wrapped type's unqualified base
-- name takes part in the tag.
data SumTypeTagOptions
  = PrefixTagsWithTypeName
    -- ^ Sum type name followed by the wrapped type's name,
    -- e.g. @MySumTypeA@.
  | AppendTypeNameToTags
    -- ^ Wrapped type's name followed by the sum type name,
    -- e.g. @TypeAMySum@.
  | ConstructTagName (String -> String)
    -- ^ Compute the tag with an arbitrary function. The function receives the
    -- wrapped type's base name; the sum type name is not passed to it.

-- | Compute the constructor name for one wrapped type from the tag options
-- and the sum type's name. Only the wrapped type's unqualified base name
-- ('nameBase') is used, so module qualifiers never leak into the tag.
constructorName :: SumTypeTagOptions -> String -> Name -> Name
constructorName tagOptions typeName wrapped = mkName (tag (nameBase wrapped))
  where
    tag base = case tagOptions of
      PrefixTagsWithTypeName -> typeName ++ base
      AppendTypeNameToTags -> base ++ typeName
      ConstructTagName mkTag -> mkTag base

-- | Whether the single field of each constructor generated by
-- 'constructSumType' carries a strictness (bang) annotation.
data SumTypeConstructorStrictness
  = LazySumTypeConstructors
    -- ^ Fields are emitted without a strictness annotation,
    -- e.g. @MySumTypeA TypeA@.
  | StrictSumTypeConstructors
    -- ^ Fields are marked strict, e.g. @MySumTypeA !TypeA@.
  deriving (Eq, Show)

-- | Translate the user-facing strictness option into the Template Haskell
-- source-strictness marker placed on each generated field.
constructorStrictness :: SumTypeConstructorStrictness -> SourceStrictness
constructorStrictness option = case option of
  StrictSumTypeConstructors -> SourceStrict
  LazySumTypeConstructors -> NoSourceStrictness

-- | Template Haskell generator for a total conversion function between two
-- sum types. Constructors are paired up by the field type they wrap. Every
-- type wrapped by the source sum type must also be wrapped by the target sum
-- type; otherwise code generation fails.
--
-- > data MySum
-- >   = MySumTypeA TypeA
-- >   | MySumTypeB TypeB
-- >   | MySumTypeC TypeC
-- >
-- > data OtherSum
-- >   = OtherSumTypeA TypeA
-- >   | OtherSumTypeB TypeB
-- >
-- > sumTypeConverter "otherSumToMySum" ''OtherSum ''MySum
--
-- This produces the following code:
--
-- > otherSumToMySum :: OtherSum -> MySum
-- > otherSumToMySum (OtherSumTypeA typeA) = MySumTypeA typeA
-- > otherSumToMySum (OtherSumTypeB typeB) = MySumTypeB typeB
sumTypeConverter :: String -> Name -> Name -> Q [Dec]
sumTypeConverter functionName sourceType targetType = do
  pairs <- matchTypeConstructors sourceType targetType
  -- One equation per source constructor, re-tagging the wrapped value.
  equations <- traverse mkSerializeFunc pairs
  sigType <- [t| $(conT sourceType) -> $(conT targetType) |]
  let fnName = mkName functionName
  pure [SigD fnName sigType, FunD fnName equations]

-- | Like 'sumTypeConverter', but the generated function returns 'Maybe'. The
-- source sum type may wrap types that the target lacks; values built with
-- those constructors convert to 'Nothing'.
--
-- The generated function is /partial/ only in the mathematical sense. In
-- Haskell terms it is total: a final wildcard equation catches every
-- unmatched constructor, so it never calls 'error'.
--
-- The converse is not allowed: every type wrapped by the target sum type must
-- also be wrapped by the source sum type, otherwise code generation fails.
--
-- > data MySum
-- >   = MySumTypeA TypeA
-- >   | MySumTypeB TypeB
-- >   | MySumTypeC TypeC
-- >
-- > data OtherSum
-- >   = OtherSumTypeA TypeA
-- >   | OtherSumTypeB TypeB
-- >
-- > partialSumTypeConverter "mySumToOtherSum" ''MySum ''OtherSum
--
-- This produces the following code:
--
-- > mySumToOtherSum :: MySum -> Maybe OtherSum
-- > mySumToOtherSum (MySumTypeA typeA) = Just (OtherSumTypeA typeA)
-- > mySumToOtherSum (MySumTypeB typeB) = Just (OtherSumTypeB typeB)
-- > mySumToOtherSum _ = Nothing
partialSumTypeConverter :: String -> Name -> Name -> Q [Dec]
partialSumTypeConverter functionName sourceType targetType = do
  -- Matching is driven by the target's constructors (note the swapped
  -- arguments). Target types missing from the source therefore fail, while
  -- extra source types are left to the wildcard equation.
  pairs <- matchTypeConstructors targetType sourceType
  -- Because of the swap, each pair's targetConstructor belongs to the source
  -- type and is the one pattern-matched by mkDeserializeFunc.
  matched <- traverse mkDeserializeFunc pairs
  sigType <- [t| $(conT sourceType) -> Maybe $(conT targetType) |]
  let
    fnName = mkName functionName
    -- Catch-all equation for source constructors with no counterpart.
    fallback = Clause [WildP] (NormalB (ConE 'Nothing)) []
  pure [SigD fnName sigType, FunD fnName (matched ++ [fallback])]

-- | Reify both sum types, the first before the second. Then pair every
-- constructor of the first with the constructor of the second that wraps the
-- same field type. Fails in 'Q' if any constructor of the first type has no
-- counterpart in the second.
matchTypeConstructors :: Name -> Name -> Q [BothConstructors]
matchTypeConstructors sourceType targetType = do
  fromCons <- typeConstructors sourceType
  toCons <- typeConstructors targetType
  traverse (matchConstructor toCons) fromCons

-- | Extract the constructors and their wrapped field types for the given sum
-- type, in declaration order, as @(fieldType, constructorName)@ pairs.
--
-- Each constructor must carry exactly one field. Both positional
-- constructors (@A TypeA@) and single-field record constructors
-- (@A { unA :: TypeA }@) are accepted, and a @newtype@ is treated as a
-- one-alternative sum.
--
-- Fails in 'Q' in these cases:
--
-- * the name does not refer to a @data@ or @newtype@ declaration;
-- * a constructor has zero fields, or more than one field;
-- * a constructor has an unsupported shape (e.g. existential or GADT syntax).
typeConstructors :: Name -> Q [(Type, Name)]
typeConstructors typeName = do
  info <- reify typeName
  case info of
    TyConI (DataD _ _ _ _ constructors _) -> mapM singleField constructors
    TyConI (NewtypeD _ _ _ _ constructor _) -> mapM singleField [constructor]
    _ -> fail $ nameBase typeName ++ " must be a sum type"
  where
    -- Accept a constructor only if it carries exactly one field.
    singleField con = case constructorFields con of
      Just (name, []) ->
        fail $ "Constructor " ++ nameBase name ++ " doesn't have any arguments"
      Just (name, [type']) -> return (type', name)
      Just (name, _) ->
        fail $ "Constructor " ++ nameBase name ++ " has more than one argument"
      Nothing -> fail $ "Invalid constructor in " ++ nameBase typeName

    -- Name and field types for the constructor shapes we support.
    -- Everything else (infix, existential, GADT) is rejected.
    constructorFields (NormalC name fields) = Just (name, map snd fields)
    constructorFields (RecC name fields) =
      Just (name, map (\(_, _, fieldType) -> fieldType) fields)
    constructorFields _ = Nothing

-- | Pair one source constructor with the first target constructor that wraps
-- an equal (reified) field type. Fails in 'Q' when no target constructor
-- wraps that type.
matchConstructor :: [(Type, Name)] -> (Type, Name) -> Q BothConstructors
matchConstructor targetConstructors (fieldType, fromCon) =
  case lookup fieldType targetConstructors of
    Just toCon -> return (BothConstructors fieldType fromCon toCon)
    Nothing ->
      fail $ "Can't find constructor in target type corresponding to " ++ nameBase fromCon

-- | A matched pair of constructors from two sum types that wrap the same
-- field type. Built by 'matchConstructor' and consumed by the clause
-- builders. 'partialSumTypeConverter' matches with its type arguments
-- swapped, so for it the two roles are reversed.
data BothConstructors = BothConstructors
  { innerType :: Type
    -- ^ The field type shared by both constructors.
  , sourceConstructor :: Name
    -- ^ Constructor from the first type given to 'matchTypeConstructors'.
  , targetConstructor :: Name
    -- ^ Matching constructor from the second type given to
    -- 'matchTypeConstructors'.
  }

-- | Build one equation of a total converter for a matched constructor pair:
-- @SourceCon value = TargetCon value@.
--
-- The pattern is built with the 'conP' smart constructor rather than the raw
-- 'ConP' data constructor. 'ConP' gained an extra field for explicit type
-- arguments in template-haskell 2.18 (GHC 9.2), so the two-field form no
-- longer compiles there. 'conP' keeps the same signature across versions.
mkSerializeFunc :: BothConstructors -> Q Clause
mkSerializeFunc BothConstructors{..} = do
  -- A fresh name so the generated binder cannot capture anything in the
  -- splicing module.
  varName <- newName "value"
  clause
    [conP sourceConstructor [varP varName]]
    (normalB (appE (conE targetConstructor) (varE varName)))
    []

-- | Build one equation of a partial converter: match a constructor and wrap
-- the re-tagged value in 'Just', producing
-- @TargetCon value = Just (SourceCon value)@.
--
-- The 'targetConstructor' field is the one pattern-matched and
-- 'sourceConstructor' the one rebuilt. 'partialSumTypeConverter' builds its
-- pairs with swapped arguments, so in practice this matches constructors of
-- the converter's input type.
--
-- Uses the 'conP' smart constructor instead of the raw 'ConP' data
-- constructor, whose arity changed in template-haskell 2.18 (GHC 9.2).
mkDeserializeFunc :: BothConstructors -> Q Clause
mkDeserializeFunc BothConstructors{..} = do
  -- A fresh name so the generated binder cannot capture anything in the
  -- splicing module.
  varName <- newName "value"
  clause
    [conP targetConstructor [varP varName]]
    (normalB (appE (conE 'Just) (appE (conE sourceConstructor) (varE varName))))
    []