{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE RecordWildCards #-}

-- ZGossip protocol, see https://github.com/zeromq/czmq/blob/master/src/zgossip_msg.bnf
-- A client sends HELLO, receives all stored tuples, and forwards them to other clients.

module Data.ZGossip (
    newZGS
  , parseZGS
  , encodeZGS
  , Key
  , Value
  , TTL
  , Peer
  , ZGSCmd(..)
  , ZGSMsg(..)
  ) where

import Prelude hiding (putStrLn, take)
import Data.ByteString (ByteString)

import GHC.Word

import Data.ZMQParse

-- | Version of the ZGossip protocol.
--
-- Written as the version byte of every encoded frame; decoding fails when
-- the received version byte differs from this value.
zgsVer :: Int
zgsVer = 1

-- | Signature of the ZGossip protocol.
--
-- The first two bytes (big-endian) of every frame; decoding fails with
-- a signature mismatch when they do not equal this value.
zgsSig :: Word16
zgsSig = 0xAAA0

-- | Identifier of a remote peer, kept as raw bytes.
type Peer = ByteString

-- | Key of a published tuple.
type Key = ByteString

-- | Value of a published tuple.
type Value = ByteString

-- | Time-to-live of a published tuple (encoded on the wire as a 32-bit
-- big-endian integer).
type TTL = Int

-- | A ZGossip message together with its origin.
data ZGSMsg = ZGSMsg
  { zgsFrom :: Maybe ByteString
    -- ^ Sender identity: 'Just' the identity frame for messages decoded
    -- by 'parseZGS', 'Nothing' for messages built locally with 'newZGS'.
    -- Not serialised by 'encodeZGS'.
  , zgsCmd  :: ZGSCmd
    -- ^ The command carried by the message.
  } deriving (Show, Eq, Ord)

-- | ZGossip commands. Each constructor's comment gives the one-byte
-- command id used on the wire.
data ZGSCmd
  = Hello
    -- ^ @0x01@: sent by a client to request all stored tuples.
  | Publish Key Value TTL
    -- ^ @0x02@: a key/value tuple with its time-to-live.
  | Ping
    -- ^ @0x03@.
  | PingOk
    -- ^ @0x04@; presumably the reply to 'Ping'.
  | Invalid
    -- ^ @0x05@.
  deriving (Show, Eq, Ord)

-- | One-byte command id written after the signature in every frame.
--
-- Must stay in sync with the command dispatch in 'parseCmd'.
cmdCode :: ZGSCmd -> Word8
cmdCode Hello     = 0x01
cmdCode Publish{} = 0x02
cmdCode Ping      = 0x03
cmdCode PingOk    = 0x04
cmdCode Invalid   = 0x05

-- | Wrap a command in a message with no sender ('zgsFrom' is 'Nothing'),
-- e.g. for serialising with 'encodeZGS'.
newZGS :: ZGSCmd -> ZGSMsg
newZGS = ZGSMsg Nothing

-- | Serialise a message into a single ZGossip frame:
--
-- > signature (2 bytes, big-endian) | command id (1 byte) | version (1 byte) | body
--
-- 'zgsFrom' is not written; on receipt the sender identity comes from a
-- separate frame (see 'parseZGS').
encodeZGS :: ZGSMsg -> ByteString
encodeZGS ZGSMsg{zgsCmd = cmd} = runPut $ do
  putWord16be zgsSig
  putWord8 (cmdCode cmd)
  putInt8 (fromIntegral zgsVer)
  encodeCmd cmd

-- | Encode the command-specific body that follows the frame header.
--
-- Only 'Publish' has a body: the key via 'putByteStringLen', the value via
-- 'putLongByteStringLen', then the TTL as a 32-bit big-endian integer.
-- Every other command has an empty body.
encodeCmd :: ZGSCmd -> PutM ()
encodeCmd (Publish k v ttl) = do
  putByteStringLen k
  putLongByteStringLen v
  -- NOTE(review): 'fromIntegral' wraps TTLs outside the Int32 range
  -- silently rather than rejecting them.
  putInt32be (fromIntegral ttl)
encodeCmd _ = pure ()

-- | Decode the body of a 'Publish' command: key, value and TTL, in the
-- same order 'encodeCmd' writes them.
parsePublish :: Get ZGSCmd
parsePublish = Publish <$> parseString <*> parseLongString <*> getInt32

-- | Decode the command id, protocol version and command body that follow
-- the signature, tagging the result with the sender identity @from@.
--
-- Fails with @"Protocol version mismatch"@ when the version byte is not
-- 'zgsVer' (checked before the body is read), and with
-- @"Unknown command"@ for an unrecognised command id.
parseCmd :: ByteString -> Get ZGSMsg
parseCmd from = do
  cmd <- getInt8 :: Get Int
  ver <- getInt8
  if ver /= zgsVer
    then fail "Protocol version mismatch"
    else ZGSMsg (Just from) <$> decodeBody cmd
  where
    -- Dispatch on the command id; must mirror 'cmdCode'.
    decodeBody :: Int -> Get ZGSCmd
    decodeBody 0x01 = pure Hello
    decodeBody 0x02 = parsePublish
    decodeBody 0x03 = pure Ping
    decodeBody 0x04 = pure PingOk
    decodeBody 0x05 = pure Invalid
    decodeBody _    = fail "Unknown command"

-- | Decode a two-frame multipart message: the sender identity followed by
-- the ZGossip frame. Returns 'Left' with an error description on failure.
parseZGS :: [ByteString] -> Either String ZGSMsg
parseZGS [from, msg] = parseZgs from msg
-- NOTE(review): this message is also returned for 1 or 3+ frames, not
-- only for an empty list.
parseZGS _           = Left "empty message"

-- | Decode one ZGossip frame @msg@, attributing the result to @from@.
--
-- Fails with @"Signature mismatch"@ when the leading two bytes are not
-- 'zgsSig'; otherwise defers to 'parseCmd' for the rest of the frame.
parseZgs :: ByteString -> ByteString -> Either String ZGSMsg
parseZgs from msg = runGet decoder msg
  where
    decoder = do
      sig <- getInt16
      if sig /= zgsSig
        then fail "Signature mismatch"
        else parseCmd from