module Network.ONCRPC.Transport
  ( sendTransport
  , recvTransport
  , TransportState
  , transportStart
  , recvGetFirst
  , recvGetNext
  ) where

import qualified Data.ByteString as BS
import qualified Data.ByteString.Lazy as BSL
import qualified Data.Serialize.Get as S
import qualified Network.Socket as Net

import           Network.ONCRPC.RecordMarking

-- |Send a payload over a socket by delegating to 'sendRecord'.
--
-- Only stream sockets ('Net.Stream') are supported. For any other socket
-- type this throws an 'IOError' (via 'fail') with the message
-- @"ONCRPC: Unsupported socket type"@.
--
-- NOTE(review): the socket type is queried with 'Net.getSocketType' on every
-- call.
sendTransport
  :: Net.Socket      -- ^ socket to write to
  -> BSL.ByteString  -- ^ payload handed to 'sendRecord'
  -> IO ()
sendTransport sock b = do
  t <- Net.getSocketType sock
  if t == Net.Stream
    then sendRecord sock b
    else fail "ONCRPC: Unsupported socket type"

-- |Receive the next block of data by delegating to 'recvRecord'.
--
-- The 'RecordState' is threaded through: the caller passes the current state
-- and gets back the received bytes together with the updated state. Within
-- this module an empty result is treated as "no more input".
--
-- Only stream sockets ('Net.Stream') are supported. For any other socket
-- type this throws an 'IOError' (via 'fail') with the message
-- @"ONCRPC: Unsupported socket type"@.
recvTransport
  :: Net.Socket                        -- ^ socket to read from
  -> RecordState                       -- ^ state before this read
  -> IO (BS.ByteString, RecordState)   -- ^ received bytes and new state
recvTransport sock r = do
  t <- Net.getSocketType sock
  if t == Net.Stream
    then recvRecord sock r
    else fail "ONCRPC: Unsupported socket type"

-- |Receive-side state carried between calls to 'recvGetFirst' and
-- 'recvGetNext'.
data TransportState = TransportState
  { _bufferState :: BS.ByteString
    -- ^ bytes already received but not yet consumed by a parser (the
    -- leftover input from the previous parse)
  , recordState :: RecordState
    -- ^ state to pass to the next 'recvTransport' call
  }
  deriving (Eq, Show)

-- |Build a 'TransportState' with an empty input buffer and the given record
-- state.
transportNext :: RecordState -> TransportState
transportNext = TransportState BS.empty

-- |Initial receive state: nothing buffered, and the record state is
-- 'RecordStart', so the next block read begins a new record.
transportStart :: TransportState
transportStart = transportNext RecordStart

-- |Read one block with 'recvTransport' and pass it to a continuation.
--
-- If the received block is empty, the continuation is not called and the
-- result is @Nothing@. Otherwise the continuation receives the block and the
-- updated 'RecordState'.
recvTransportWith
  :: Net.Socket                                      -- ^ socket to read from
  -> RecordState                                     -- ^ state before the read
  -> (BS.ByteString -> RecordState -> IO (Maybe a))  -- ^ handler for a non-empty block
  -> IO (Maybe a)
recvTransportWith sock rs f = do
  (b, rs') <- recvTransport sock rs
  if BS.null b
    then return Nothing
    else f b rs'

-- |Get the next part of the current record, after calling 'recvGetFirst' to start.
--
-- Parsing begins with any bytes buffered in the 'TransportState'. If the
-- parser needs more input, further blocks are read with 'recvTransport'.
-- Once the record state is 'RecordStart', an empty chunk is fed to the parser
-- to signal end of input instead of reading again. If this is called at
-- 'RecordStart' with nothing buffered, the parser therefore runs on empty
-- input.
--
-- Results:
--
-- * @Nothing@ if a read returns an empty block while the parser still wants
--   input (presumably end of stream; confirm against 'recvRecord');
-- * @Just (Right x, st)@ when the parser succeeds;
-- * @Just (Left err, st)@ when the parser fails.
--
-- In both @Just@ cases @st@ holds the parser's unconsumed leftover bytes and
-- the current record state.
recvGetNext
  :: Net.Socket      -- ^ socket to read further blocks from
  -> S.Get a         -- ^ parser to run
  -> TransportState  -- ^ state from a previous 'recvGetFirst'/'recvGetNext'
  -> IO (Maybe (Either String a, TransportState))
recvGetNext sock getter = start where
  -- Continue the record: parse buffered bytes first, else fetch input.
  start (TransportState b rs)
    | BS.null b = get Nothing rs    -- nothing buffered: check for more
    | otherwise = got Nothing b rs  -- buffered data
  -- Obtain the next chunk for the parser. The first argument is Nothing
  -- before parsing has started, or Just the parser's continuation.
  get f RecordStart = got f BS.empty RecordStart  -- end of record: signal end of input
  get f rs = recvTransportWith sock rs $ got f    -- read next block
  -- Feed one chunk, starting the parser on the first chunk. cereal's
  -- runGetChunk takes an optional length of input still to come, supplied
  -- here from 'recordRemaining'.
  got Nothing b rs = fed rs $ S.runGetChunk getter (recordRemaining rs) b  -- start parsing
  got (Just f) b rs = fed rs $ f b                                         -- parse block
  -- Act on the parser's result: keep feeding, or stop and buffer leftovers.
  fed rs (S.Partial f) = get (Just f) rs
  fed rs (S.Done r b) = return $ Just (Right r, TransportState b rs)
  fed rs (S.Fail e b) = return $ Just (Left e, TransportState b rs)

-- |Get the first part of the next record, possibly skipping over the rest of the current record.
--
-- Only the 'recordState' of the given state is used; any buffered bytes are
-- discarded. Blocks are read with 'recvTransport'. A block read while the
-- record state was anything other than 'RecordStart' belongs to the record in
-- progress and is dropped. The first block read at 'RecordStart' starts a new
-- record and is handed to 'recvGetNext', whose result is returned.
--
-- Returns @Nothing@ if a read returns an empty block (presumably end of
-- stream; confirm against 'recvRecord').
recvGetFirst
  :: Net.Socket      -- ^ socket to read from
  -> S.Get a         -- ^ parser for the start of the next record
  -> TransportState  -- ^ current receive state
  -> IO (Maybe (Either String a, TransportState))
recvGetFirst sock getter = get . recordState where
  -- Read one block, remembering the record state it was read from.
  get rs = recvTransportWith sock rs $ got rs
  -- A block read at RecordStart opens a new record: start parsing it.
  got RecordStart b rs' = recvGetNext sock getter $ TransportState b rs'
  -- Any other block is the tail of the current record: drop it and read on.
  got _ _ rs' = get rs'