{-# LANGUAGE RankNTypes #-}

module Pinch.Transport
  ( Transport(..)
  , framedTransport
  , unframedTransport
  , Connection(..)
  , ReadResult(..)
  ) where

import Data.IORef (newIORef, readIORef, writeIORef)
import Network.Socket (Socket)
import Network.Socket.ByteString (sendAll, recv)
import System.IO (Handle)

import qualified Data.ByteString as BS
import qualified Data.Serialize.Get as G

import qualified Pinch.Internal.Builder as B

-- | A byte-oriented connection that the transports in this module read
-- from and write to.
class Connection c where
  -- | Writes the given bytestring.
  cPut :: c -> BS.ByteString -> IO ()
  -- | Gets up to n bytes. Returns an empty bytestring if EOF is reached.
  cGetSome :: c -> Int -> IO BS.ByteString

-- | Handles are used through "Data.ByteString"'s handle I/O functions.
instance Connection Handle where
  -- Reads up to the requested number of bytes from the handle.
  cGetSome = BS.hGetSome
  -- Writes the whole bytestring to the handle.
  cPut = BS.hPut

-- | Connected network sockets.
instance Connection Socket where
  -- Each call receives at most 4096 bytes, however many were requested.
  cGetSome sock = recv sock . min 4096
  -- Sends the entire bytestring.
  cPut = sendAll

-- | Outcome of reading a single message from a transport.
data ReadResult a
  = RRSuccess a
    -- ^ A complete message was read and decoded.
  | RRFailure String
    -- ^ Reading or decoding the message failed; carries an error description.
  | RREOF
    -- ^ The input ended before a complete message could be read.
  deriving (Eq, Show)

-- | A bidirectional message transport: messages are written to it and
-- read back from it.
data Transport = Transport
  { writeMessage :: B.Builder -> IO ()
    -- ^ Writes the message produced by the given builder.
  , readMessage :: forall a . G.Get a -> IO (ReadResult a)
    -- ^ Reads the next message, decoding it with the given parser.
  }

-- | Creates a thrift framed transport. See also <https://github.com/apache/thrift/blob/master/doc/specs/thrift-rpc.md#framed-vs-unframed-transport>.
--
-- Each message is preceded by its size in bytes, encoded as a big-endian
-- signed 32-bit integer.
--
-- Reading yields:
--
-- * 'RREOF' when the connection ends before a complete frame is available.
-- * 'RRFailure' when the announced frame size is negative or the payload
--   fails to parse.
framedTransport :: Connection c => c -> IO Transport
framedTransport c = pure $ Transport writeMsg readMsg
  where
    writeMsg :: B.Builder -> IO ()
    writeMsg msg = do
      cPut c $ B.runBuilder $ B.int32BE (fromIntegral $ B.getSize msg)
      cPut c $ B.runBuilder msg

    readMsg :: G.Get a -> IO (ReadResult a)
    readMsg p = do
      szBs <- getExactly c 4
      if BS.length szBs < 4
        then pure RREOF
        else case G.runGet G.getInt32be szBs of
          Left s -> pure $ RRFailure $ "Invalid frame size: " ++ show s
          Right sz
            -- The size comes from the peer. A negative size cannot describe
            -- a frame, so report it instead of using it as a read length.
            | sz < 0 -> pure $ RRFailure $ "Invalid frame size: " ++ show sz
            | otherwise -> readPayload p (fromIntegral sz)

    -- Reads a payload of exactly @sz@ bytes and runs the parser on it.
    readPayload :: G.Get a -> Int -> IO (ReadResult a)
    readPayload p sz = do
      msgBs <- getExactly c sz
      -- Less data than announced means EOF was reached mid-frame.
      pure $ if BS.length msgBs < sz
        then RREOF
        else either RRFailure RRSuccess $ G.runGet p msgBs

-- | Creates a thrift unframed transport. See also <https://github.com/apache/thrift/blob/master/doc/specs/thrift-rpc.md#framed-vs-unframed-transport>.
--
-- Messages are written as-is, without a length prefix.
unframedTransport :: Connection c => c -> IO Transport
unframedTransport c = do
  -- As we do not know how long messages are, we may read more data than
  -- the current message needs. The surplus is kept in this buffer and is
  -- consumed first when reading the next message.
  pending <- newIORef BS.empty
  let fetch :: IO BS.ByteString
      fetch = cGetSome c 1024

      readMsg :: G.Get a -> IO (ReadResult a)
      readMsg parser = do
        buffered <- readIORef pending
        -- Only read from the connection when nothing is left over.
        input <- if BS.null buffered then fetch else pure buffered
        (surplus, result) <- runGetWith fetch parser input
        writeIORef pending surplus
        pure result
  pure $ Transport (cPut c . B.runBuilder) readMsg

-- | Runs a Get parser incrementally, starting from the given input and
-- pulling more chunks from the supplied action whenever the parser needs
-- them.
--
-- Returns the unconsumed input together with the result. If the action
-- returns an empty chunk while the parser still needs input, the result is
-- 'RREOF'.
runGetWith :: IO BS.ByteString -> G.Get a -> BS.ByteString -> IO (BS.ByteString, ReadResult a)
runGetWith getBs p initial = step (G.runGetPartial p initial)
  where
    step (G.Fail err rest) = pure (rest, RRFailure err)
    step (G.Done a rest)   = pure (rest, RRSuccess a)
    step (G.Partial more)  = getBs >>= feed
      where
        -- An empty chunk from the reader signals EOF.
        feed chunk
          | BS.null chunk = pure (chunk, RREOF)
          | otherwise     = step (more chunk)
  
-- | Gets exactly n bytes.
--
-- If EOF is reached before n bytes have been read, an empty string is
-- returned and any bytes read so far are discarded. A request for zero (or
-- fewer) bytes returns an empty string without reading from the connection.
getExactly :: Connection c => c -> Int -> IO BS.ByteString
getExactly c sz = B.runBuilder <$> go sz mempty
  where
    -- @n@ is the number of bytes still missing; @b@ holds what was read.
    go :: Int -> B.Builder -> IO B.Builder
    go n b
      -- Nothing left to read. Checking this first also means a zero-length
      -- request never issues a zero-byte read on the connection.
      | n <= 0 = pure b
      | otherwise = do
          bs <- cGetSome c n
          if BS.null bs
            -- EOF: drop the partial data so the caller gets an empty string.
            then pure mempty
            else go (n - BS.length bs) (b <> B.byteString bs)