{-# LANGUAGE CPP                        #-}
{-# LANGUAGE TemplateHaskell            #-}

module Data.API.Tools.CBOR
    ( cborTool
    ) where

import           Data.API.TH
import           Data.API.Tools.Combinators
import           Data.API.Tools.Datatypes
import           Data.API.Tools.Enum
import           Data.API.Types

import           Control.Applicative
import qualified Control.Monad.Fail as Fail
import           Codec.Serialise.Class
import           Codec.Serialise.Decoding
import           Codec.Serialise.Encoding
import           Data.Binary.Serialise.CBOR.Extra
import           Data.List (foldl', sortBy)
import qualified Data.Map                       as Map
import           Data.Monoid
import           Data.Ord (comparing)
import qualified Data.Text                      as T
import           Language.Haskell.TH
import           Prelude

-- | Tool to generate 'Serialise' instances for types generated by
-- 'datatypesTool'. This depends on 'enumTool'.
cborTool :: APITool
cborTool = apiNodeTool $
    -- Instances for the representation type of each node, chosen by the
    -- kind of specification (newtype, record, union, enum).  The tool for
    -- plain 'APIType' specifications is 'mempty': nothing is generated for
    -- those nodes here.
    apiSpecTool gen_sn_to gen_sr_to gen_su_to gen_se_to mempty
    -- Extra instances for the user-facing type of nodes that declare a
    -- conversion (see 'anConvert').
    <> gen_pr

{-
instance Serialise JobId where
    encode = encode . _JobId
    decode = JobId <$> decode

In this version we don't check the @snFilter@, for simplicity and speed.
This is safe, since the CBOR code is used only internally as a data
representation format, not as a communication format with clients
that could potentially send faulty data.
-}

gen_sn_to :: Tool (APINode, SpecNewtype)
gen_sn_to = mkTool $ \ ts (an, sn) ->
    optionalInstanceD ts ''Serialise [nodeRepT an]
        [ simpleD 'encode (bdy_in an sn)
        , simpleD 'decode (bdy_out ts an sn)
        ]
  where
    -- encode: unwrap the newtype, then encode the underlying value with
    -- the encoder chosen for its basic type.
    bdy_in :: APINode -> SpecNewtype -> ExpQ
    bdy_in an sn = [e| $(ine sn) . $(newtypeProjectionE an) |]

    -- decode: decode the underlying value and wrap it with the newtype
    -- constructor (no @snFilter@ check, see the note above).
    bdy_out :: ToolSettings -> APINode -> SpecNewtype -> ExpQ
    bdy_out ts an sn = [e| $(nodeNewtypeConE ts an sn) <$> $(oute sn) |]

    -- Encoder for the newtype's basic type.  Binary and UTC values use
    -- their own 'Serialise' instances via 'encode'.
    ine :: SpecNewtype -> ExpQ
    ine sn = case snType sn of
        BTstring -> [e| encodeString |]
        BTbinary -> [e| encode |]
        BTbool   -> [e| encodeBool |]
        BTint    -> [e| encodeInt |]
        BTutc    -> [e| encode |]

    -- Decoder matching 'ine' for each basic type.
    oute :: SpecNewtype -> ExpQ
    oute sn = case snType sn of
        BTstring -> [e| decodeString |]
        BTbinary -> [e| decode |]
        BTbool   -> [e| decodeBool |]
        BTint    -> [e| decodeInt |]
        BTutc    -> [e| decode |]



{-
instance Serialise JobSpecId where
     encode = \ x ->
        encodeMapLen 4 <>
        encodeRecordFields
            [ encodeString "Id"         <> encode (jsiId         x)
            , encodeString "Input"      <> encode (jsiInput      x)
            , encodeString "Output"     <> encode (jsiOutput     x)
            , encodeString "PipelineId" <> encode (jsiPipelineId x)
            ]
     decode =
        decodeMapLen >>
        JobSpecId <$> (decodeString >> decode)
                  <*> (decodeString >> decode)
                  <*> (decodeString >> decode)
                  <*> (decodeString >> decode)

Note that fields are stored in alphabetical order of field name, so
that the encoding is insensitive to changes in field order in the schema.
When decoding, the map length and the field names are read but not
checked.
-}

gen_sr_to :: Tool (APINode, SpecRecord)
gen_sr_to = mkTool $ \ ts (an, sr) -> do
    x <- newName "x"
    optionalInstanceD ts ''Serialise [nodeRepT an]
        [ simpleD 'encode (bdy_in an sr x)
        , simpleD 'decode (cl an sr)
        ]
  where
    -- encode: builds
    --   \ x -> encodeMapLen n <> encodeRecordFields [ encodeString "f" <> encode (f x), ... ]
    -- with one entry per field, in alphabetical field-name order.
    bdy_in :: APINode -> SpecRecord -> Name -> ExpQ
    bdy_in an sr x =
        let fields = sortFields sr
            len :: Integer
            len = fromIntegral (length fields)
            -- Splice the length as an Integer literal and convert it with
            -- 'fromIntegral' to the Word that 'encodeMapLen' takes.
            lenE = varE 'fromIntegral
                     `appE` (sigE (litE (integerL len)) (conT ''Integer))
            -- Micro-optimization: we use the statically known @len@ value
            -- instead of creating a list of thunks from the argument of
            -- @encodeRecordFields@ and dynamically calculating its length,
            -- long before the list is fully forced.
            writeRecordHeader = varE 'encodeMapLen `appE` lenE
            encFields =
                varE 'encodeRecordFields `appE`
                    listE [ [e| encodeString $(fieldNameE fn)
                                <> encode ($(nodeFieldE an fn) $(varE x)) |]
                          | (fn, _fty) <- fields ]
        in lamE [varP x] $
               varE '(<>) `appE` writeRecordHeader `appE` encFields

    -- decode: read the map length, then each field (name, then value) in
    -- the same alphabetical order that the encoder writes them.
    cl :: APINode -> SpecRecord -> ExpQ
    cl an sr = varE '(>>)
                 `appE` varE 'decodeMapLen  -- TODO (extra check): check len with srFields
                 `appE` bdy
      where
        original_fields :: [FieldName]
        original_fields = map fst (srFields sr)

        sorted_fields :: [FieldName]
        sorted_fields = map fst (sortFields sr)

        bdy :: ExpQ
        bdy = applicativeE dataCon (map project sorted_fields)

        -- Skip the stored field name and decode the field value.
        project :: FieldName -> ExpQ
        project _fn = [e| decodeString >> decode |]
          -- TODO (correctness): check that $(fieldNameE fn) matches the decoded name
          -- and if not, use the default value, etc.

        -- If the fields are already sorted, just use the data constructor.
        -- Otherwise generate a reordering function like
        --   \ _foo_a _foo_b -> Con _foo_b _foo_a
        -- that takes arguments in sorted order and applies the constructor
        -- in declaration order.
        dataCon :: ExpQ
        dataCon
          | sorted_fields == original_fields = nodeConE an
          | otherwise =
              lamE (map (nodeFieldP an) sorted_fields)
                   (foldl' appE (nodeConE an) (map (nodeFieldE an) original_fields))

    -- The record's fields sorted by field name: the order used on the wire.
    sortFields :: SpecRecord -> [(FieldName, FieldType)]
    sortFields sr = sortBy (comparing fst) (srFields sr)

{-
instance Serialise Foo where
    encode (Bar x) = encodeUnion "x" (encode x)
    encode (Baz x) = encodeUnion "y" (encode x)
    decode = decodeUnion [ ("x", fmap Bar decode)
                         , ("y", fmap Baz decode) ]

-}

gen_su_to :: Tool (APINode, SpecUnion)
gen_su_to = mkTool $ \ ts (an, su) ->
    optionalInstanceD ts ''Serialise [nodeRepT an]
        [ funD    'encode (cls an su)
        , simpleD 'decode (bdy_out an su)
        ]
  where
    -- One 'encode' clause per union alternative.
    cls :: APINode -> SpecUnion -> [ClauseQ]
    cls an su = map (cl an) (suFields su)

    -- Clause of the form:  encode (Con x) = encodeUnion "fn" (encode x)
    cl :: APINode -> (FieldName, a) -> ClauseQ
    cl an (fn, _) = do
        x <- newName "x"
        clause [nodeAltConP an fn [varP x]] (bdy fn x) []

    bdy :: FieldName -> Name -> BodyQ
    bdy fn x = normalB [e| encodeUnion $(fieldNameE fn) (encode $(varE x)) |]

    -- decode = decodeUnion [ ("fn", fmap Con decode), ... ]
    bdy_out :: APINode -> SpecUnion -> ExpQ
    bdy_out an su = varE 'decodeUnion `appE` listE (map (alt an) (suFields su))

    alt :: APINode -> (FieldName, b) -> ExpQ
    alt an (fn, _) = [e| ( $(fieldNameE fn) , fmap $(nodeAltConE an fn) decode ) |]


{-
instance Serialise FrameRate where
    encode = encodeString . _text_FrameRate
    decode = decodeString >>= cborStrMap_p _map_FrameRate
-}

gen_se_to :: Tool (APINode, SpecEnum)
gen_se_to = mkTool $ \ ts (an, _se) ->
    optionalInstanceD ts ''Serialise [nodeRepT an]
        [ simpleD 'encode (bdy_in an)
        , simpleD 'decode (bdy_out an)
        ]
  where
    -- encode: render the constructor to its text key (using the function
    -- named by 'text_enum_nm', which comes from 'enumTool') and encode that
    -- string.
    bdy_in :: APINode -> ExpQ
    bdy_in an = [e| encodeString . $(varE (text_enum_nm an)) |]

    -- decode: read a string and look it up in the key map named by
    -- 'map_enum_nm'; unknown keys fail in the decoder.
    bdy_out :: APINode -> ExpQ
    bdy_out an = [e| decodeString >>= cborStrMap_p $(varE (map_enum_nm an)) |]

-- | Look up a decoded enumeration key in the table of valid keys.
--
-- This runs in a 'Fail.MonadFail' monad, so an unknown key is reported
-- via 'Fail.fail' instead of crashing with @error@.
cborStrMap_p :: Fail.MonadFail m => Map.Map T.Text a -> T.Text -> m a
cborStrMap_p mp t =
    maybe (Fail.fail "Unexpected enumeration key in CBOR") pure (Map.lookup t mp)


gen_pr :: Tool APINode
gen_pr = mkTool $ \ ts an -> case anConvert an of
    -- No conversion declared: no extra instance.
    Nothing               -> return []
    -- With a conversion (inj_fn, prj_fn), the instance for the node's type
    -- encodes by projecting to the representation with prj_fn.  It decodes
    -- by decoding the representation and injecting with inj_fn, which is
    -- bound with >>= and so runs in the decoder.
    Just (inj_fn, prj_fn) ->
        optionalInstanceD ts ''Serialise [nodeT an]
            [ simpleD 'encode [e| encode . $(fieldNameVarE prj_fn) |]
            , simpleD 'decode [e| decode >>= $(fieldNameVarE inj_fn) |]
            ]