{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE MultiWayIf        #-}
-- | A bounded, window-controlled byte queue living in 'STM'.
--
-- Writers hand in strict 'BS.ByteString's via 'enqueue'; readers take
-- bytes out again via 'dequeue' / 'dequeueShort' or inspect them via
-- 'lookAhead'. The number of bytes accepted by 'enqueue' is limited both
-- by the queue's free capacity and by a separately tracked window counter.
module Network.SSH.TStreamingQueue where

import           Control.Concurrent.STM.TChan
import           Control.Concurrent.STM.TVar
import           Control.Concurrent.STM.TMVar
import           Control.Monad.STM
import           Control.Applicative
import           Data.Word
import qualified Data.ByteString               as BS
import qualified Data.ByteString.Short         as SBS
import           Prelude                 hiding ( head
                                                , tail
                                                )

import qualified Network.SSH.Stream            as S
import           Network.SSH.Constants

-- | State of a streaming queue.
data TStreamingQueue
    = TStreamingQueue
    { qCapacity :: Word32                       -- ^ Fixed upper bound for 'qSize'.
    , qWindow   :: TVar Word32                  -- ^ Window counter; decremented by 'enqueue'.
    , qSize     :: TVar Word32                  -- ^ Number of bytes currently queued.
    , qEof      :: TVar Bool                    -- ^ Set by 'terminate'.
    , qHead     :: TMVar SBS.ShortByteString    -- ^ Remainder of a partially consumed chunk.
    , qTail     :: TChan SBS.ShortByteString    -- ^ Chunks not yet touched by a reader.
    }

-- | Create an empty queue with the given capacity.
--
-- The window 'TVar' is supplied by the caller and stored as-is; the queue
-- starts with size 0, no EOF and no pending chunks.
newTStreamingQueue :: Word32 -> TVar Word32 -> STM TStreamingQueue
newTStreamingQueue c window =
    TStreamingQueue c window
        <$> newTVar 0
        <*> newTVar False
        <*> newEmptyTMVar
        <*> newTChan

-- | The queue's fixed capacity in bytes.
capacity :: TStreamingQueue -> Word32
capacity = qCapacity

-- | Number of bytes currently held in the queue.
getSize :: TStreamingQueue -> STM Word32
getSize = readTVar . qSize

-- | Capacity minus current size.
getFree :: TStreamingQueue -> STM Word32
getFree q = (capacity q -) <$> getSize q

-- | Current value of the window counter.
getWindowSpace :: TStreamingQueue -> STM Word32
getWindowSpace = readTVar . qWindow

-- | Increase the window counter by @increment@.
--
-- The sum is checked in 'Word64' so that it cannot wrap; if it would
-- exceed @maxBound :: Word32@ the transaction retries (blocks) instead
-- of writing.
addWindowSpace :: TStreamingQueue -> Word32 -> STM ()
addWindowSpace q increment = do
    wndw <- getWindowSpace q :: STM Word32
    check $ (fromIntegral wndw + fromIntegral increment :: Word64)
                <= fromIntegral (maxBound :: Word32)
    writeTVar (qWindow q) $! wndw + increment

-- | Tell whether both the queue size and the window counter are at most
-- half of the queue's capacity.
askWindowSpaceAdjustRecommended :: TStreamingQueue -> STM Bool
askWindowSpaceAdjustRecommended q = do
    size <- getSize q
    wndw <- getWindowSpace q
    let threshold = capacity q `div` 2
    -- 1st condition: queue size must be at most half of the capacity
    -- 2nd condition: window size must be at most half of the capacity
    -- in order to avoid byte-wise adjustment and flapping
    pure $ size <= threshold && wndw <= threshold

-- | Add the currently free capacity to the window counter and return the
-- amount that was added.
--
-- NOTE(review): the increment is the full free capacity regardless of the
-- current window value, so the resulting window may exceed the free
-- capacity (and the addition is not overflow-checked like in
-- 'addWindowSpace'). Confirm that callers only invoke this in states
-- where that is intended.
fillWindowSpace :: TStreamingQueue -> STM Word32
fillWindowSpace q = do
    free <- getFree q
    wndw <- getWindowSpace q
    writeTVar (qWindow q) $! wndw + free
    pure free

-- | Mark the queue as terminated (sets the EOF flag).
terminate :: TStreamingQueue -> STM ()
terminate q = writeTVar (qEof q) True

-- | Append bytes to the queue and return how many were accepted.
--
-- * Returns 0 immediately for empty input or once the EOF flag is set.
-- * Otherwise blocks until both free capacity and window space are
--   non-zero, then accepts as many leading bytes as fit (bounded by free
--   capacity, window space and the project constant 'maxBoundIntWord32').
--
-- Size is increased and the window decreased by the accepted amount.
enqueue :: TStreamingQueue -> BS.ByteString -> STM Word32
enqueue q bs
    | BS.null bs = pure 0
    | otherwise  = do
        eof <- readTVar (qEof q)
        if eof then pure 0 else do
            size <- getSize q
            wndw <- getWindowSpace q
            let free      = capacity q - size
                available = min (min free wndw) maxBoundIntWord32 :: Word32
                -- Compare in Word64: narrowing the 'Int' length directly
                -- to Word32 would wrap for inputs of 2^32 bytes or more.
                len64     = fromIntegral (BS.length bs) :: Word64
                -- Bounded by 'available', so the narrowing is lossless.
                accepted  = fromIntegral (min len64 (fromIntegral available)) :: Word32
            check $ available > 0 -- Block until there's free capacity and window space.
            writeTVar (qSize q)   $! size + accepted
            writeTVar (qWindow q) $! wndw - accepted
            -- 'BS.take' returns the input unchanged when everything fits.
            writeTChan (qTail q)  $! SBS.toShort $ BS.take (fromIntegral accepted) bs
            pure accepted

-- | Like 'dequeueShort', but returns a strict 'BS.ByteString'.
dequeue :: TStreamingQueue -> Word32 -> STM BS.ByteString
dequeue q maxBufSize = SBS.fromShort <$> dequeueShort q maxBufSize

-- | Remove and return up to @maxBufSize@ bytes from the front of the queue.
--
-- Blocks until at least one byte is queued or the EOF flag is set. Returns
-- the empty string when the queue is empty and EOF is set. The request is
-- additionally capped by the project constant 'maxBoundIntWord32'.
dequeueShort :: TStreamingQueue -> Word32 -> STM SBS.ShortByteString
dequeueShort q maxBufSize = do
    size <- getSize q
    eof  <- readTVar (qEof q)
    check $ size > 0 || eof -- Block until there's at least 1 byte available.
    if size == 0 && eof
        then pure mempty
        else mconcat <$> go size requested
  where
    requested = min maxBufSize maxBoundIntWord32

    -- @go s j@: @s@ is the size observed at entry, @j@ the number of
    -- bytes still wanted. Every terminating branch writes the new size.
    go s 0 = do
        -- Exactly 'requested' bytes have been collected.
        writeTVar (qSize q) $! s - requested
        pure []
    go s j = do
        -- Prefer the leftover of a previously split chunk, then the next
        -- untouched chunk; an empty result means nothing is left.
        bs <- takeTMVar (qHead q) <|> readTChan (qTail q) <|> pure mempty
        let len = fromIntegral (SBS.length bs) :: Word32
        if  | SBS.null bs -> do
                -- Queue drained before the request was satisfied.
                writeTVar (qSize q) 0
                pure []
            | len <= j ->
                (bs :) <$> go s (j - len)
            | otherwise -> do
                -- Chunk is larger than what is still wanted: split it and
                -- park the remainder in the head slot for the next reader.
                let (taken, rest) = BS.splitAt (fromIntegral j) (SBS.fromShort bs)
                writeTVar (qSize q) $! s - requested
                putTMVar (qHead q) $! SBS.toShort rest
                pure [SBS.toShort taken]

-- | Return up to @maxBufSize@ bytes from the front of the queue without
-- removing them.
--
-- Blocks until at least one byte is queued or the EOF flag is set. Only
-- the first pending chunk is inspected, so the result may be shorter than
-- both @maxBufSize@ and the number of queued bytes.
lookAhead :: TStreamingQueue -> Word32 -> STM BS.ByteString
lookAhead q maxBufSize = do
    size <- getSize q
    eof  <- readTVar (qEof q)
    check $ size > 0 || eof
    if size == 0 && eof
        then pure mempty
        else do
            bs <- readTMVar (qHead q) <|> peekTChan (qTail q)
            pure $ BS.take (fromIntegral maxBufSize) (SBS.fromShort bs)

instance S.DuplexStream TStreamingQueue

instance S.OutputStream TStreamingQueue where
    send q bs = fromIntegral <$> atomically (enqueue q bs)

instance S.InputStream TStreamingQueue where
    peek q i = atomically $ lookAhead q $ fromIntegral $ min i maxBoundIntWord32
    receive q i = atomically $ dequeue q $ fromIntegral $ min i maxBoundIntWord32