{-# LANGUAGE ScopedTypeVariables #-}
{-# OPTIONS_GHC -Wno-x-partial -Wno-unrecognised-warning-flags #-}

module Data.Packed.TH.WriteCon (genConWrite, conWriteFName) where

import Data.List (group, sort)
import Data.Packed.FieldSize
import Data.Packed.Needs (NeedsWriter)
import qualified Data.Packed.Needs as N
import Data.Packed.Packable
import Data.Packed.TH.Flag (PackingFlag (..))
import Data.Packed.TH.Start (genStart, startFName)
import Data.Packed.TH.Utils
import Language.Haskell.TH

-- | Name of the write function generated for a data constructor: the
-- result of 'sanitizeConName' on the constructor's name, prefixed with
-- @writeCon@. For a constructor 'Leaf', this yields @writeConLeaf@.
conWriteFName :: Name -> Name
conWriteFName = mkName . ("writeCon" ++) . sanitizeConName

-- | Generates a function that serialises and writes a value to a 'Needs'.
-- The generated function is specific to a single data constructor.
--
-- __Example:__
--
-- For the 'Tree' data type, it generates the following function for the 'Leaf' constructor
--
-- @
-- writeConLeaf :: ('Packable' a) => a -> 'NeedsWriter (Tree a) r t'
-- writeConLeaf n  = startLeaf 'Data.Packed.Needs.>>' 'write' n
-- @
--
-- The returned declarations are those produced by 'genStart' for the same
-- constructor, followed by the signature and definition of the
-- @writeCon*@ function.
--
-- When 'InsertFieldSize' is among the flags, every field is written with
-- 'writeWithFieldSize' instead of 'write'. The only exception is the last
-- field when 'SkipLastFieldSize' is also set, which then uses plain 'write'.
--
-- Fails (in 'Q') if the given name does not reify to a data constructor.
genConWrite ::
    [PackingFlag] ->
    -- | The name of the data constructor to generate the function for
    Name ->
    -- | A unique (to the data type) 'Tag' to identify the packed data constructor.
    --
    -- For example, for a 'Tree' data type,
    -- we would typically use '0' for the 'Leaf' constructor and '1' for the 'Node' constructor
    Tag ->
    -- | The constructor's fields (strictness and type of each field)
    [BangType] ->
    Q [Dec]
genConWrite flags conName conIndex bangTypes = do
    conType <- reifyConType
    let r = VarT $ mkName "r"
        t = VarT $ mkName "t"
        fName = conWriteFName conName
        paramTypeList = snd <$> bangTypes
        parentType = getParentTypeFromConstructorType conType
    signature <- genConWriteSignature conName paramTypeList parentType r t
    -- One fresh variable per field; these become the lambda's parameters.
    paramNames <- mapM (const $ newName "t") paramTypeList
    -- Chain the constructor's start writer with one write per field, in order:
    --   startX N.>> write f1 N.>> write f2 N.>> ...
    -- Fields tagged with True use 'writeWithFieldSize', which inserts the
    -- field's size before the field itself.
    body <-
        foldl
            ( \rest (paramName, needsSizeTag) ->
                if needsSizeTag
                    then [|$rest N.>> writeWithFieldSize $(varE paramName)|]
                    else [|$rest N.>> write $(varE paramName)|]
            )
            (varE $ startFName conName)
            (withSizeTags paramNames)
    -- The lambda binds exactly the field variables, e.g. '\t1 t2' for a
    -- two-field constructor.
    let patt = VarP <$> paramNames
    start <- genStart flags conName conIndex paramTypeList
    return $
        start
            ++ [ signature
               , FunD fName [Clause [] (NormalB $ LamE patt body) []]
               ]
  where
    -- Look up the constructor's type, failing with an explicit message
    -- instead of a pattern-match failure if the name is not a constructor.
    reifyConType :: Q Type
    reifyConType = do
        info <- reify conName
        case info of
            DataConI _ conType _ -> return conType
            _ -> fail $ "genConWrite: " ++ show conName ++ " is not a data constructor"

    -- Pair each field with whether it must be preceded by its size.
    -- Without InsertFieldSize, no field gets a size. With it, every field
    -- gets one, except that the last field's tag is controlled by
    -- SkipLastFieldSize. An empty field list yields an empty result.
    withSizeTags :: [Name] -> [(Name, Bool)]
    withSizeTags names
        | InsertFieldSize `notElem` flags = (,False) <$> names
        | otherwise =
            zip names ((True <$ drop 1 names) ++ [SkipLastFieldSize `notElem` flags])

-- | Generates the type signature of a constructor-specific write function, e.g.
--
-- > writeConLeaf :: (Packable a) => a -> NeedsWriter (Tree a) r t
--
-- * Every distinct type variable occurring in the constructor's field types
--   receives exactly one 'Packable' constraint.
-- * The field types become the function's arguments, in the given order.
-- * The result type is @NeedsWriter parentType restType resultType@.
genConWriteSignature :: Name -> [Type] -> Type -> Type -> Type -> Q Dec
genConWriteSignature constructorName constructorArgumentsTypes parentType restType resultType =
    sigD funName $ forallT [] constraints funSignature
  where
    funName :: Name
    funName = conWriteFName constructorName

    typeVariables :: [Type]
    typeVariables = filterDuplicates $ concatMap getAllVarInType constructorArgumentsTypes

    -- The signature without the constructor's parameters
    needsWriterType :: Q Type
    needsWriterType = [t|NeedsWriter $(return parentType) $(return restType) $(return resultType)|]

    constraints :: Q [Type]
    constraints = mapM (\tyVar -> [t|Packable $(return tyVar)|]) typeVariables

    -- field1 -> field2 -> ... -> NeedsWriter parent rest result
    funSignature :: Q Type
    funSignature =
        foldr (\p rest -> [t|$(return p) -> $rest|]) needsWriterType constructorArgumentsTypes

    -- Collects the type variables of a type, descending only through
    -- type applications (AppT).
    getAllVarInType :: Type -> [Type]
    getAllVarInType (AppT a b) = getAllVarInType a ++ getAllVarInType b
    getAllVarInType v@(VarT _) = [v]
    getAllVarInType _ = []

    -- Removes duplicates. Sorting must happen BEFORE grouping: 'group' only
    -- merges adjacent equal elements, so grouping first (as the previous
    -- version did) left non-adjacent duplicates such as [a, b, a] in place.
    -- The comprehension pattern is total since 'group' never yields [].
    filterDuplicates :: [Type] -> [Type]
    filterDuplicates tys = [ty | (ty : _) <- group (sort tys)]