module System.Directory.Watchman.BSER.Protocol
    ( sendBSERMessage
    , readBSERMessage
    ) where

import Control.Monad (unless)
import Network.Socket (Socket)
import Network.Socket.ByteString.Lazy as BL
import Data.Binary
import Data.Binary.Put
import qualified Data.ByteString.Lazy as BL
import System.Directory.Watchman.BSER

-- | Serialise a 'BSERValue' and write it to the socket as a single framed
-- PDU.
--
-- The frame consists of three parts, in order:
--
--   1. the two-byte header @0x00 0x01@;
--   2. the payload length, encoded as a compact BSER integer;
--   3. the 'encode'd payload itself.
--
-- The whole frame is handed to 'BL.sendAll'.
sendBSERMessage :: Socket -> BSERValue -> IO ()
sendBSERMessage sock val = BL.sendAll sock frame
  where
    payload = encode val
    -- Length of the encoded payload, written as the smallest BSER int
    -- that can hold it.
    sizeField = runPut (put (compactBSERInt (BL.length payload)))
    frame = BL.concat [BL.pack [0x00, 0x01], sizeField, payload]

-- | Receive @n@ bytes from the socket, looping over short reads.
--
-- The result is shorter than @n@ bytes only when the remote side closes
-- the connection before @n@ bytes have arrived.
recvN :: Socket -> Int -> IO BL.ByteString
recvN sock = recvN' sock []

-- | Worker loop for 'recvN'.
--
-- Keeps calling 'BL.recv' until @n@ more bytes have been read or the
-- remote side closes the socket. @acc@ holds the chunks received so far,
-- newest first; they are reversed and concatenated when the loop stops.
recvN' :: Socket -> [BL.ByteString] -> Int -> IO BL.ByteString
recvN' sock acc n
    -- Nothing (more) to read. Never ask the socket for a non-positive
    -- number of bytes: the network library may reject that as an
    -- invalid argument.
    | n <= 0 = finish acc
    | otherwise = do
        chunk <- BL.recv sock (fromIntegral n)
        let remaining = n - fromIntegral (BL.length chunk)
        -- An empty chunk means the remote side closed the socket; stop
        -- and return whatever has arrived so far.
        if BL.null chunk || remaining == 0
            then finish (chunk : acc)
            else recvN' sock (chunk : acc) remaining
  where
    finish :: [BL.ByteString] -> IO BL.ByteString
    finish = pure . BL.concat . reverse

-- | Read one framed BSER PDU from the socket and decode its payload.
--
-- The frame is consumed in four steps:
--
--   1. a two-byte protocol header, which is read and discarded (its
--      contents are not checked);
--   2. a single tag byte saying how wide the length field is
--      (@0x03@ → 1 byte, @0x04@ → 2, @0x05@ → 4, @0x06@ → 8);
--   3. the length field, decoded together with its tag as a BSER integer;
--   4. exactly that many bytes of encoded 'BSERValue'.
--
-- Every failure is reported with 'fail', i.e. as an 'IOError' thrown in
-- 'IO'. Failures include a short read or closed connection, an unknown
-- tag, a non-positive length, or undecodable data.
readBSERMessage :: Socket -> IO BSERValue
readBSERMessage sock = do
    header <- recvN sock 2 -- Swallow the protocol header
    unless (BL.length header == 2) $
        fail $ "Error reading header. Received: " ++ show header

    -- 'BL.uncons' is total. An empty buffer means recvN hit end-of-stream.
    tagBuf <- recvN sock 1
    tag <- case BL.uncons tagBuf of
        Just (t, _) -> pure t
        Nothing -> fail "Error reading tag"

    lenWidth <- case tag of
        0x03 -> pure 1
        0x04 -> pure 2
        0x05 -> pure 4
        0x06 -> pure 8
        -- TODO Better error handling
        _ -> fail "Invalid BSER Message"
    lenBuf <- recvN sock lenWidth
    unless (BL.length lenBuf == fromIntegral lenWidth) $
        fail "Error reading message length"

    -- The tag byte followed by the length bytes forms a BSER integer, so
    -- the 'Binary' instance of 'BSERValue' can decode it directly.
    len <- case decodeOrFail (tagBuf `BL.append` lenBuf) of
        Left (_, _, err) -> fail err
        Right (_, _, val) -> either fail pure (readBSERInt val)
    -- Reject non-positive lengths before touching the socket again: a
    -- byte count <= 0 is not a meaningful read request.
    unless (len > 0) $
        fail $ "Invalid BSER message length: " ++ show len

    encoded <- recvN sock len
    -- recvN returns fewer bytes only if the remote side closed early.
    unless (BL.length encoded == fromIntegral len) $
        fail $ "Connection closed after " ++ show (BL.length encoded)
            ++ " of " ++ show len ++ " payload bytes"
    case decodeOrFail encoded of
        Left (_, _, err) -> fail err
        Right (_, _, val) -> pure val