{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE TemplateHaskell #-}
{-# LANGUAGE TupleSections #-}

-- |
-- Module      :   Grisette.Internal.TH.GADT.DeriveSerial
-- Copyright   :   (c) Sirui Lu 2024
-- License     :   BSD-3-Clause (see the LICENSE file)
--
-- Maintainer  :   siruilu@cs.washington.edu
-- Stability   :   Experimental
-- Portability :   GHC only
module Grisette.Internal.TH.GADT.DeriveSerial
  ( deriveGADTSerial,
    deriveGADTSerial1,
    deriveGADTSerial2,
  )
where

import Control.Monad (zipWithM)
import Data.Bytes.Serial
  ( Serial (deserialize, serialize),
    Serial1 (deserializeWith, serializeWith),
    Serial2 (deserializeWith2, serializeWith2),
  )
import qualified Data.Map as M
import Data.Maybe (mapMaybe)
import qualified Data.Set as S
import Grisette.Internal.TH.GADT.Common (DeriveConfig)
import Grisette.Internal.TH.GADT.UnaryOpCommon
  ( FieldFunExp,
    UnaryOpClassConfig
      ( UnaryOpClassConfig,
        unaryOpAllowExistential,
        unaryOpConfigs,
        unaryOpExtraVars,
        unaryOpInstanceNames,
        unaryOpInstanceTypeFromConfig
      ),
    UnaryOpConfig (UnaryOpConfig),
    UnaryOpFieldConfig
      ( UnaryOpFieldConfig,
        extraLiftedPatNames,
        extraPatNames,
        fieldCombineFun,
        fieldFunExp,
        fieldResFun
      ),
    UnaryOpFunConfig (genUnaryOpFun),
    defaultFieldFunExp,
    defaultUnaryOpInstanceTypeFromConfig,
    genUnaryOpClass,
  )
import Grisette.Internal.TH.Util (integerE)
import Language.Haskell.TH
  ( Body (NormalB),
    Clause (Clause),
    Dec (FunD),
    Lit (IntegerL),
    Match (Match),
    Name,
    Pat (LitP, VarP, WildP),
    Q,
    Type (VarT),
    bindS,
    caseE,
    conE,
    conT,
    doE,
    match,
    mkName,
    newName,
    noBindS,
    normalB,
    sigP,
    varE,
    varP,
    wildP,
  )
import Language.Haskell.TH.Datatype
  ( ConstructorInfo (constructorFields, constructorName),
    TypeSubstitution (freeVariables),
    resolveTypeSynonyms,
  )

-- | Per-method configuration for generating the deserialization method
-- ('deserialize' \/ 'deserializeWith' \/ 'deserializeWith2') of a derived
-- 'Serial' \/ 'Serial1' \/ 'Serial2' instance.
newtype UnaryOpDeserializeConfig = UnaryOpDeserializeConfig
  { -- | Builds the deserializer expression for one constructor field,
    -- given the mapping from type variables to the lifted deserializer
    -- argument names bound by the generated method.
    fieldDeserializeFun :: FieldFunExp
  }

-- | Generates the deserialization method ('deserialize',
-- 'deserializeWith' or 'deserializeWith2', selected by the arity index)
-- for a GADT.
--
-- The generated method first reads an 'Int' constructor tag with
-- 'deserialize'. It then dispatches on that tag: tag @i@ rebuilds the
-- @i@-th constructor of the given constructor list by applicatively
-- chaining the per-field deserializers produced by 'fieldDeserializeFun'.
-- Any other tag makes the decoder 'fail' with a message that includes the
-- offending tag, rather than crashing with 'undefined'.
instance UnaryOpFunConfig UnaryOpDeserializeConfig where
  genUnaryOpFun
    _
    UnaryOpDeserializeConfig {..}
    funNames
    n
    _
    _
    argTypes
    _
    constructors = do
      allFields <-
        mapM resolveTypeSynonyms $ concatMap constructorFields constructors
      -- Only type variables that occur in some field need a lifted
      -- deserializer argument bound to a name; the rest become wildcards.
      let usedArgs = S.fromList $ freeVariables allFields
      -- One entry per argument type: for a used type variable, the
      -- variable paired with a fresh name for its deserializer argument.
      argPats <-
        traverse
          ( \(ty, _) -> case ty of
              VarT nm
                | S.member nm usedArgs -> Just . (nm,) <$> newName "p"
              _ -> return Nothing
          )
          argTypes
      let argToFunPat = M.fromList $ mapMaybe id argPats
          funPats = maybe WildP (VarP . snd) <$> argPats
      -- Builds the case alternative for one constructor:
      --   conIdx -> return Con <*> field1 <*> field2 <*> ...
      let genAuxFunMatch conIdx conInfo = do
            fields <- mapM resolveTypeSynonyms $ constructorFields conInfo
            fieldDeserializers <-
              traverse (fieldDeserializeFun argToFunPat M.empty) fields
            body <-
              foldl
                (\acc fieldFun -> [|$acc <*> $(return fieldFun)|])
                [|return $(conE (constructorName conInfo))|]
                fieldDeserializers
            return $ Match (LitP (IntegerL conIdx)) (NormalB body) []
      auxMatches <- zipWithM genAuxFunMatch [0 ..] constructors
      -- Use a fresh (hygienic) name for the tag binder so it can never
      -- capture or shadow a name used by the generated field deserializers.
      selName <- newName "sel"
      -- Unknown tags (corrupt or foreign input) fail in the decoder monad
      -- instead of evaluating to 'undefined'.
      auxFallbackMatch <-
        match
          wildP
          ( normalB
              [|
                fail
                  ( "deserialize: invalid constructor tag "
                      ++ show $(varE selName)
                  )
                |]
          )
          []
      -- The tag is read as an 'Int', matching the serializer side, which
      -- writes the constructor index as an 'Int'.
      methodBody <-
        doE
          [ bindS (sigP (varP selName) (conT ''Int)) [|deserialize|],
            noBindS $
              caseE (varE selName) $
                return <$> auxMatches ++ [auxFallbackMatch]
          ]
      return $ FunD (funNames !! n) [Clause funPats (NormalB methodBody) []]

-- | Class configuration used to derive 'Serial', 'Serial1' and 'Serial2'.
--
-- The serializing method writes the constructor index (@conIdx@, supplied
-- by the generic driver) as an 'Int'. It then writes every field in order
-- with 'serialize' \/ 'serializeWith' \/ 'serializeWith2', sequenced by
-- '>>'. The deserializing method is generated through
-- 'UnaryOpDeserializeConfig'. Existential constructors are not supported.
serialConfig :: UnaryOpClassConfig
serialConfig =
  UnaryOpClassConfig
    { unaryOpConfigs = [serializeConfig, deserializeConfig],
      unaryOpInstanceNames = [''Serial, ''Serial1, ''Serial2],
      unaryOpExtraVars = \_ -> return [],
      unaryOpInstanceTypeFromConfig = defaultUnaryOpInstanceTypeFromConfig,
      unaryOpAllowExistential = False
    }
  where
    -- Method names, indexed by arity (0, 1, 2).
    serializers = ['serialize, 'serializeWith, 'serializeWith2]
    deserializers = ['deserialize, 'deserializeWith, 'deserializeWith2]

    serializeConfig =
      UnaryOpConfig
        UnaryOpFieldConfig
          { extraPatNames = [],
            extraLiftedPatNames = \_ -> [],
            -- serialize (conIdx :: Int) >> field1 >> field2 >> ...
            fieldCombineFun = \conIdx _ _ [] fieldExps -> do
              let tagExp = [|serialize ($(integerE conIdx) :: Int)|]
              combined <-
                foldl
                  (\acc fieldExp -> [|$acc >> $(return fieldExp)|])
                  tagExp
                  fieldExps
              return (combined, [True]),
            -- Apply the field's serializer to the bound field variable.
            fieldResFun = \_ _ _ _ fieldPat fieldFun -> do
              applied <- [|$(return fieldFun) $(return fieldPat)|]
              return (applied, [True]),
            fieldFunExp = defaultFieldFunExp serializers
          }
        serializers

    deserializeConfig =
      UnaryOpConfig
        UnaryOpDeserializeConfig
          { fieldDeserializeFun = defaultFieldFunExp deserializers
          }
        deserializers

-- | Derive a 'Serial' instance for a GADT.
--
-- Delegates to 'genUnaryOpClass' with 'serialConfig' at index @0@, which
-- presumably selects the first entry of each per-arity name list
-- ('Serial', 'serialize', 'deserialize').
deriveGADTSerial :: DeriveConfig -> Name -> Q [Dec]
deriveGADTSerial deriveConfig typName =
  genUnaryOpClass deriveConfig serialConfig 0 typName

-- | Derive a 'Serial1' instance for a GADT.
--
-- Delegates to 'genUnaryOpClass' with 'serialConfig' at index @1@, which
-- presumably selects the second entry of each per-arity name list
-- ('Serial1', 'serializeWith', 'deserializeWith').
deriveGADTSerial1 :: DeriveConfig -> Name -> Q [Dec]
deriveGADTSerial1 deriveConfig typName =
  genUnaryOpClass deriveConfig serialConfig 1 typName

-- | Derive a 'Serial2' instance for a GADT.
--
-- Delegates to 'genUnaryOpClass' with 'serialConfig' at index @2@, which
-- presumably selects the third entry of each per-arity name list
-- ('Serial2', 'serializeWith2', 'deserializeWith2').
deriveGADTSerial2 :: DeriveConfig -> Name -> Q [Dec]
deriveGADTSerial2 deriveConfig typName =
  genUnaryOpClass deriveConfig serialConfig 2 typName