module Network.SSH.Stream where

import           Control.Exception    ( throwIO )
import           Control.Monad        ( when )
import           Foreign.Ptr
import qualified Data.ByteString   as BS
import qualified Data.ByteArray    as BA

-- | A `DuplexStream` is anything that is both an `InputStream` and an
--   `OutputStream`: it can be read from and written to, much like a
--   file handle or a socket.
class (InputStream stream, OutputStream stream) => DuplexStream stream

-- | An `OutputStream` is a sink that chunks of bytes can be written to.
class OutputStream stream where
    -- | Write a chunk of bytes into the stream.
    --
    -- (1) Blocks until at least one byte has been sent or the
    --     connection has been closed.
    -- (2) Returns how many bytes were actually sent, or 0 when the
    --     other side closed the connection. Callers that send in a
    --     loop must inspect this value, otherwise the loop never
    --     terminates!
    send          :: stream -> BS.ByteString -> IO Int

    -- | Like `send`, but takes the bytes from a memory view. Instances
    --   may override it to avoid extra allocations when working with
    --   builders and re-usable buffers. The result means the same as
    --   for `send`.
    --
    --   The default implementation copies the view into a fresh
    --   `BS.ByteString` and hands that to `send`.
    sendUnsafe    :: stream -> BA.MemView -> IO Int
    sendUnsafe stream view =
        BA.copy view (\_ -> pure ()) >>= send stream

    {-# MINIMAL send #-}

-- | An `InputStream` is something that bytes can be read from.
class InputStream stream where
    -- | Like `receive`, but does not actually remove anything
    --   from the input buffer.
    --
    -- (1) Use with care! There are very few legitimate use cases
    --     for this.
    peek          :: stream -> Int -> IO BS.ByteString
    -- | Receive a chunk of bytes from the stream.
    --
    -- (1) This method shall block until at least one byte becomes
    --     available or the connection got closed.
    -- (2) As with sockets, chunk boundaries are not guaranteed to be
    --     preserved during transmission, although they most often
    --     are. Never rely on this behaviour!
    -- (3) The second parameter determines how many bytes to receive at
    --     most; the `BS.ByteString` returned might be shorter.
    -- (4) The returned chunk is guaranteed to be no longer than the
    --     given limit. It is empty when the connection got closed, and
    --     all subsequent attempts to read shall return the empty string.
    --     This must be checked when collecting chunks in a loop, or the
    --     program will get stuck in endless recursion!
    receive       :: stream -> Int -> IO BS.ByteString
    -- | Like `receive`, but writes the received bytes directly into the
    --   given memory view. This allows for more efficiency with fewer
    --   memory allocations when working with builders and re-usable
    --   buffers.
    --
    --   Returns the number of bytes written into the view (0 when the
    --   connection got closed).
    --
    --   The default implementation delegates to `receive` and copies the
    --   result into the view. If `receive` breaks its contract by
    --   returning more bytes than the view can hold, an `IOError` is
    --   thrown instead of writing past the end of the buffer.
    receiveUnsafe :: stream -> BA.MemView -> IO Int
    receiveUnsafe stream (BA.MemView ptr n) = do
        bs <- receive stream n
        let len = BS.length bs
        -- The copy below writes `len` bytes to a raw pointer; trusting a
        -- misbehaving instance here would corrupt memory.
        when (len > max 0 n) $
            throwIO $ userError "receiveUnsafe: received more bytes than requested"
        BA.copyByteArrayToPtr bs ptr
        pure len
    {-# MINIMAL peek, receive #-}

-- | Send the complete `BS.ByteString`, calling `sendUnsafe` repeatedly
--   until every byte has been written.
--
--   * Returns immediately for an empty `BS.ByteString`.
--   * Otherwise blocks until all bytes have been sent, and throws an
--     `IOError` if the connection gets closed before that (i.e. when
--     `sendUnsafe` reports zero or fewer bytes sent).
sendAll :: OutputStream stream => stream -> BS.ByteString -> IO ()
sendAll stream bs =
    if BS.null bs
        then pure ()
        else BA.withByteArray bs (go (BS.length bs))
    where
        -- `left` bytes starting at `cursor` still have to be sent.
        go left cursor = when (left > 0) $ do
            n <- sendUnsafe stream (BA.MemView cursor left)
            if n <= 0
                then throwIO (userError "sendAll: connection lost")
                else go (left - n) (cursor `plusPtr` n)

-- | Receive exactly the requested number of bytes, calling
--   `receiveUnsafe` repeatedly until the freshly allocated buffer is full.
--
--   * Returns the empty `BS.ByteString` when the requested length is
--     zero or negative.
--   * Otherwise blocks until all bytes have arrived, and throws an
--     `IOError` if the connection gets closed before that (i.e. when
--     `receiveUnsafe` reports zero or fewer bytes received).
receiveAll :: InputStream stream => stream -> Int -> IO BS.ByteString
receiveAll stream n
    | n > 0     = BA.alloc n (fill n)
    | otherwise = pure mempty
    where
        -- `left` bytes starting at `cursor` still have to be filled.
        fill left cursor = when (left > 0) $ do
            got <- receiveUnsafe stream (BA.MemView cursor left)
            if got <= 0
                then throwIO (userError "receiveAll: connection lost")
                else fill (left - got) (cursor `plusPtr` got)