module Network.ONCRPC.Transport
  ( sendTransport
  , recvTransport
  , TransportState
  , transportStart
  , recvGetFirst
  , recvGetNext
  ) where

import qualified Data.ByteString as BS
import qualified Data.ByteString.Lazy as BSL
import qualified Data.Serialize.Get as S
import qualified Network.Socket as Net

import           Network.ONCRPC.RecordMarking

-- |Send a message over a socket.
-- Stream sockets are handled by 'sendRecord'; any other socket type fails in 'IO'.
sendTransport :: Net.Socket -> BSL.ByteString -> IO ()
sendTransport sock msg = case sock of
  Net.MkSocket _ _ Net.Stream _ _ -> sendRecord sock msg
  _ -> fail "ONCRPC: Unsupported socket type"

-- |Receive the next block of data from a socket, given the current record state.
-- Stream sockets are handled by 'recvRecord', whose result (the block and the
-- updated record state) is returned as is; any other socket type fails in 'IO'.
recvTransport :: Net.Socket -> RecordState -> IO (BS.ByteString, RecordState)
recvTransport sock rs = case sock of
  Net.MkSocket _ _ Net.Stream _ _ -> recvRecord sock rs
  _ -> fail "ONCRPC: Unsupported socket type"

-- |Receive-side state threaded through successive calls to 'recvGetFirst' and 'recvGetNext'.
data TransportState = TransportState
  { _bufferState :: BS.ByteString -- ^input left over from the previous parse, not yet consumed
  , recordState :: RecordState -- ^record state to hand to the next 'recvTransport' call
  }
  deriving (Eq, Show)

-- |Build a 'TransportState' with nothing buffered and the given record state.
transportNext :: RecordState -> TransportState
transportNext rs = TransportState
  { _bufferState = BS.empty
  , recordState = rs
  }

-- |The initial 'TransportState': an empty buffer in the 'RecordStart' record state.
transportStart :: TransportState
transportStart = transportNext initialRecord
  where
    initialRecord = RecordStart

-- |Receive one block with 'recvTransport' and pass it, together with the
-- updated record state, to the continuation.
-- An empty block short-circuits to 'Nothing' without calling the continuation.
recvTransportWith :: Net.Socket -> RecordState -> (BS.ByteString -> RecordState -> IO (Maybe a)) -> IO (Maybe a)
recvTransportWith sock rs f = recvTransport sock rs >>= dispatch
  where
    dispatch (block, rs')
      | BS.null block = return Nothing
      | otherwise = f block rs'

-- |Parse the next value from the current record, after 'recvGetFirst' has started it.
--
-- Buffered input left over from the previous parse is consumed first; further
-- blocks are received from the socket only while the parser asks for more.
-- Once the record state is 'RecordStart' no more data is read: the parser is
-- fed an empty chunk instead, so that it can finish.
--
-- The result is 'Nothing' if an empty block was received before the parser
-- finished; otherwise it is the parse outcome ('Left' with the parser's error
-- message on failure) paired with the state to pass to the next call.
recvGetNext :: Net.Socket -> S.Get a -> TransportState -> IO (Maybe (Either String a, TransportState))
recvGetNext sock getter (TransportState buffered state0)
  | BS.null buffered = more Nothing state0 -- nothing buffered: look for more input
  | otherwise = feed Nothing buffered state0 -- begin with the buffered input
  where
    -- Obtain another chunk for the (possibly not yet started) parser.
    more cont RecordStart = feed cont BS.empty RecordStart -- record is over: signal end of input
    more cont rs = recvTransportWith sock rs (feed cont)
    -- Hand a chunk to the parser, starting it if there is no continuation yet.
    feed cont chunk rs = step rs $ case cont of
      Nothing -> S.runGetChunk getter (recordRemaining rs) chunk
      Just k -> k chunk
    -- Act on the parser's answer: keep reading, or stop with the leftover input.
    step rs result = case result of
      S.Partial k -> more (Just k) rs
      S.Done r rest -> finish (Right r) rest rs
      S.Fail e rest -> finish (Left e) rest rs
    finish outcome rest rs = return $ Just (outcome, TransportState rest rs)

-- |Parse the first value of the next record, discarding whatever remains of the current one.
--
-- Blocks are received and dropped until one is read while the record state is
-- 'RecordStart'; that block opens the next record and is handed to 'recvGetNext'.
-- Any buffered input in the given state is ignored, since only its record
-- state is consulted. The result is 'Nothing' if an empty block is received.
recvGetFirst :: Net.Socket -> S.Get a -> TransportState -> IO (Maybe (Either String a, TransportState))
recvGetFirst sock getter ts = skipTo (recordState ts)
  where
    skipTo prev = recvTransportWith sock prev $ \block next ->
      case prev of
        RecordStart -> recvGetNext sock getter (TransportState block next) -- block opens a new record
        _ -> skipTo next -- still inside the old record: drop the block