{-# LANGUAGE CPP                        #-}
{-# LANGUAGE TemplateHaskell            #-}

module Data.API.Tools.CBOR
    ( cborTool
    ) where

import           Data.API.TH
import           Data.API.Tools.Combinators
import           Data.API.Tools.Datatypes
import           Data.API.Tools.Enum
import           Data.API.Types

import           Control.Applicative
import qualified Control.Monad.Fail as Fail
import           Codec.Serialise.Class
import           Codec.Serialise.Decoding
import           Codec.Serialise.Encoding
import           Data.Binary.Serialise.CBOR.Extra
import           Data.List (foldl', sortBy)
import qualified Data.Map                       as Map
import           Data.Monoid
import           Data.Ord (comparing)
import qualified Data.Text                      as T
import           Language.Haskell.TH
import           Prelude

-- | Tool to generate 'Serialise' instances for types generated by
-- 'datatypesTool'. This depends on 'enumTool'.
--
-- Newtype, record, union and enum nodes are handled by the per-spec
-- generators below; nothing is generated for type synonyms ('mempty').
-- Independently of the spec kind, 'gen_pr' is run on every node.
cborTool :: APITool
cborTool = apiNodeTool $
             apiSpecTool gen_sn_to gen_sr_to gen_su_to gen_se_to mempty
             <> gen_pr

{-
instance Serialise JobId where
    encode = encode . _JobId
    decode = JobId <$> decode

In this version we don't check the @snFilter@, for simplicity and speed.
This is safe, since the CBOR code is used only internally as a data
representation format, not as a communication format with clients
that could potentially send faulty data.
-}

-- | Generate the 'Serialise' instance for a newtype node: 'encode'
-- projects out the wrapped value and encodes it, 'decode' decodes the
-- wrapped value and applies the newtype constructor.
gen_sn_to :: Tool (APINode, SpecNewtype)
gen_sn_to = mkTool $ \ ts (an, sn) -> optionalInstanceD ts ''Serialise [nodeRepT an]
                                          [ simpleD 'encode (bdy_in an sn)
                                          , simpleD 'decode (bdy_out ts an sn)]
  where
    -- Body of 'encode': base-type encoder composed with the projection.
    bdy_in :: APINode -> SpecNewtype -> ExpQ
    bdy_in an sn = [e| $(ine sn) . $(newtypeProjectionE an) |]

    -- Body of 'decode': constructor mapped over the base-type decoder.
    bdy_out :: ToolSettings -> APINode -> SpecNewtype -> ExpQ
    bdy_out ts an sn = [e| $(nodeNewtypeConE ts an sn) <$> $(oute sn) |]

    -- Encoder expression for the newtype's underlying basic type.
    ine :: SpecNewtype -> ExpQ
    ine sn = case snType sn of
            BTstring -> [e| encodeString |]
            BTbinary -> [e| encode |]
            BTbool   -> [e| encodeBool |]
            BTint    -> [e| encodeInt |]
            BTutc    -> [e| encode |]

    -- Decoder expression for the newtype's underlying basic type; must
    -- mirror 'ine' case by case.
    oute :: SpecNewtype -> ExpQ
    oute sn =
        case snType sn of
            BTstring -> [e| decodeString |]
            BTbinary -> [e| decode |]
            BTbool   -> [e| decodeBool |]
            BTint    -> [e| decodeInt |]
            BTutc    -> [e| decode |]



{-
instance Serialise JobSpecId where
     encode = \ (JobSpecId _jsi_id _jsi_input _jsi_output _jsi_pipelineId) ->
        encodeMapLen 4 <>
        encodeRecordFields
            [ encodeString "Id"         <> encode _jsi_id
            , encodeString "Input"      <> encode _jsi_input
            , encodeString "Output"     <> encode _jsi_output
            , encodeString "PipelineId" <> encode _jsi_pipelineId
            ]
     decode =
        decodeMapLen >>
        JobSpecId <$> (decodeString >> decode)
                  <*> (decodeString >> decode)
                  <*> (decodeString >> decode)
                  <*> (decodeString >> decode)

Note that fields are stored alphabetically ordered by field name, so
that we are insensitive to changes in field order in the schema.


Previously we generated code like this:

     encode = \ x ->
         encodeMapLen 4 <>
         encodeRecordFields
            [ encodeString "Id"         <> encode (_jsi_id         x)
            , encodeString "Input"      <> encode (_jsi_input      x)
            , encodeString "Output"     <> encode (_jsi_output     x)
            , encodeString "PipelineId" <> encode (_jsi_pipelineId x)
            ]

This binds the record to the variable `x` and uses the record selectors to
project out the components. As a consequence, we can end up retaining the entire
record until the very end of encoding it. This is a problem if the record is
constructed lazily and each component would otherwise have been freed once it
was encoded, because we end up realising the whole thing in memory rather than
being incremental.

The fix is to pattern-match once on the value to be serialised and bind its
components separately. Now the record constructor is garbage once we evaluate
the outer pattern-match, and we can free individual fields once they are
encoded.

One might hope that the selector thunk optimisation would squash this
automatically, but that is somewhat fragile and may not apply at all to large
records (see https://gitlab.haskell.org/ghc/ghc/-/issues/20139).

-}

-- | Generate the 'Serialise' instance for a record node.
--
-- The record is written as a CBOR map header followed by the fields
-- (name, then value) in alphabetical order of field name; the decoder
-- reads them back in that same order.
gen_sr_to :: Tool (APINode, SpecRecord)
gen_sr_to = mkTool $ \ ts (an, sr) ->
    optionalInstanceD ts ''Serialise [nodeRepT an] [ simpleD 'encode (bdy_in an sr)
                                                   , simpleD 'decode (cl an sr)
                                                   ]
  where
    -- Body of 'encode': a lambda that pattern-matches the constructor
    -- once (binding every field) and emits header <> fields.
    bdy_in :: APINode -> SpecRecord -> ExpQ
    bdy_in an sr =
        let fields = sortFields sr
            len :: Integer
            len = toInteger (length fields)
            lenE :: ExpQ
            lenE = varE 'fromIntegral  -- to Word
                     `appE` (sigE (litE (integerL len))
                                  (conT ''Integer))
            -- Micro-optimization: we use the statically known @len@ value
            -- instead of creating a list of thunks from the argument of
            -- @encodeRecordFields@ and dynamically calculating
            -- its length, long before the list is fully forced.
            writeRecordHeader :: ExpQ
            writeRecordHeader = varE 'encodeMapLen `appE` lenE
            encFields :: ExpQ
            encFields =
                varE 'encodeRecordFields `appE`
                    listE [ [e| encodeString $(fieldNameE fn)
                                <> encode $(nodeFieldE an fn) |]
                            | (fn, _fty) <- fields ]
        -- The pattern lists fields in schema order (constructor order),
        -- while the encoded fields above are in sorted order.
        in lamE [nodeConP an [nodeFieldP an fn | (fn, _) <- srFields sr ]] $
               varE '(<>)
                 `appE` writeRecordHeader
                 `appE` encFields

    -- Body of 'decode': skip the map header, then decode the fields
    -- applicatively in sorted order.
    cl :: APINode -> SpecRecord -> ExpQ
    cl an sr    = varE '(>>)
                    `appE` (varE 'decodeMapLen)  -- TODO (extra check): check len with srFields
                    `appE` bdy
      where
        sorted_fields :: [FieldName]
        sorted_fields   = map fst $ sortFields sr
        original_fields :: [FieldName]
        original_fields = map fst $ srFields sr

        bdy :: ExpQ
        bdy = applicativeE dataCon $ map project sorted_fields

        -- Decoder for one field: the field name is read and discarded.
        project :: FieldName -> ExpQ
        project _fn = [e| decodeString >> decode |]
          -- TODO (correctness): check that $(fieldNameE fn) matches the decoded name
          -- and if not, use the default value, etc.

        -- If the fields are sorted, just use the data constructor,
        -- but if not, generate a reordering function like
        --   \ _foo_a _foo_b -> Con _foo_b _foo_a
        dataCon :: ExpQ
        dataCon | sorted_fields == original_fields = nodeConE an
                | otherwise = lamE (map (nodeFieldP an) sorted_fields)
                                   (foldl' appE (nodeConE an) (map (nodeFieldE an) original_fields))

    -- Record fields ordered alphabetically by field name (the wire order).
    sortFields sr = sortBy (comparing fst) $ srFields sr

{-
instance Serialise Foo where
    encode (Bar x) = encodeUnion "x" (encode x)
    encode (Baz x) = encodeUnion "y" (encode x)
    decode = decodeUnion [ ("x", fmap Bar decode)
                         , ("y", fmap Baz decode) ]

-}

-- | Generate the 'Serialise' instance for a union node: 'encode' gets one
-- clause per alternative, tagging the encoded payload with the
-- alternative's field name via 'encodeUnion'; 'decode' dispatches on that
-- tag via 'decodeUnion'.
gen_su_to :: Tool (APINode, SpecUnion)
gen_su_to = mkTool $ \ ts (an, su) -> optionalInstanceD ts ''Serialise [nodeRepT an]
                                        [ funD    'encode (cls an su)
                                        , simpleD 'decode (bdy_out an su)
                                        ]
  where
    -- One 'encode' clause per union alternative.
    cls :: APINode -> SpecUnion -> [Q Clause]
    cls an su = map (cl an) (suFields su)

    -- Clause @encode (Con x) = ...@ for a single alternative.
    cl :: APINode -> (FieldName, (a, b)) -> Q Clause
    cl an (fn, (_ty, _)) = do
      x <- newName "x"
      clause [nodeAltConP an fn [varP x]] (bdy fn x) []

    -- Right-hand side of an 'encode' clause.
    bdy :: FieldName -> Name -> Q Body
    bdy fn x = normalB [e| encodeUnion $(fieldNameE fn) (encode $(varE x)) |]

    -- Body of 'decode': a table of (tag, decoder) pairs.
    bdy_out :: APINode -> SpecUnion -> ExpQ
    bdy_out an su = varE 'decodeUnion `appE` listE (map (alt an) (suFields su))

    -- Table entry for a single alternative.
    alt :: APINode -> (FieldName, b) -> ExpQ
    alt an (fn, _) = [e| ( $(fieldNameE fn) , fmap $(nodeAltConE an fn) decode ) |]


{-
instance Serialise FrameRate where
    encode = encodeString . _text_FrameRate
    decode = decodeString >>= cborStrMap_p _map_FrameRate
-}

-- | Generate the 'Serialise' instance for an enumeration node. Values are
-- encoded as their text form (via the @text_enum_nm@ function) and decoded
-- by looking the text up in the @map_enum_nm@ map, failing in the decoder
-- on an unknown key.
gen_se_to :: Tool (APINode, SpecEnum)
gen_se_to = mkTool $ \ ts (an, _se) -> optionalInstanceD ts ''Serialise [nodeRepT an]
                                         [ simpleD 'encode (bdy_in an)
                                         , simpleD 'decode (bdy_out an)
                                         ]
  where
    -- Body of 'encode'.
    bdy_in :: APINode -> ExpQ
    bdy_in an = [e| encodeString . $(varE (text_enum_nm an)) |]

    -- Body of 'decode'.
    bdy_out :: APINode -> ExpQ
    bdy_out an = [e| decodeString >>= cborStrMap_p $(varE (map_enum_nm an)) |]

-- | Look up an enumeration key in the given map.
--
-- In a monad, to @fail@ instead of crashing with @error@: an unknown key
-- results in a 'Fail.fail' with a fixed message, a known key is returned
-- with 'return'.
cborStrMap_p :: (Fail.MonadFail m, Ord a) => Map.Map T.Text a -> T.Text -> m a
cborStrMap_p mp t =
    -- 'Fail.fail' is named explicitly so that the 'MonadFail' method is
    -- used regardless of what the Prelude's @fail@ refers to.
    maybe (Fail.fail "Unexpected enumeration key in CBOR") return (Map.lookup t mp)


-- | Generate the 'Serialise' instance for the user-facing type of a node
-- that declares a conversion ('anConvert'). Encoding applies the
-- projection function and encodes the result; decoding decodes the
-- representation and binds it into the injection function. Nodes without
-- a conversion produce no declarations.
gen_pr :: Tool APINode
gen_pr = mkTool $ \ ts an -> case anConvert an of
  Nothing               -> return []
  Just (inj_fn, prj_fn) -> optionalInstanceD ts ''Serialise [nodeT an] [ simpleD 'encode bdy_in
                                                                       , simpleD 'decode bdy_out
                                                                       ]
    -- This @where@ belongs to the @Just@ alternative, so that @inj_fn@
    -- and @prj_fn@ are in scope.
    where
      bdy_in :: ExpQ
      bdy_in  = [e| encode . $(fieldNameVarE prj_fn) |]
      bdy_out :: ExpQ
      bdy_out = [e| decode >>= $(fieldNameVarE inj_fn) |]