-- | Sockets paired with an in-memory pushback buffer.
--
-- Bytes returned to a 'BufferedSocket' with 'pushback' are served by later
-- calls to 'socketRecv' before anything more is read from the underlying
-- socket.
module Util.BufferedSocket
  ( BufferedSocket ()
  , makeBuffered
  , socketPort
  , socketRecv
  , pushback
  , socketSend
  , socketClose
  ) where

import Control.Monad (unless)
import Data.IORef (IORef, newIORef, readIORef, writeIORef, modifyIORef')
import qualified Data.ByteString as BS
  (ByteString, empty, append, null, length, splitAt)
import qualified Data.ByteString.Lazy as BL (ByteString)
import qualified Network.Socket as S (Socket, PortNumber, socketPort, close)
import qualified Network.Socket.ByteString as NBS (recv)
import qualified Network.Socket.ByteString.Lazy as NBL (sendAll)

import Util.IOx

--------------------------------------------------------------------------------

-- | A socket plus a mutable buffer of pushed-back bytes that have not yet
-- been handed out again by 'socketRecv'.
--
-- The data constructor is not exported; build values with 'makeBuffered'.
newtype BufferedSocket = BufferedSocket (S.Socket, IORef BS.ByteString)

-- | Wrap a socket, starting with an empty pushback buffer.
--
-- NOTE(review): this uses 'liftIOx' while every other operation here uses
-- 'toIOx'; confirm the difference is intended.
makeBuffered :: S.Socket -> IOx BufferedSocket
makeBuffered sock = attach <$> liftIOx (newIORef BS.empty)
  where
    attach ref = BufferedSocket (sock, ref)

-- | Local port number of the wrapped socket.
socketPort :: BufferedSocket -> IOx S.PortNumber
socketPort (BufferedSocket (sock, _)) = toIOx (S.socketPort sock)

-- | Receive at most @len@ bytes.
--
-- * A negative @len@ is a programming error and calls 'error'.
--
-- * @len == 0@ yields an empty string without touching the socket or buffer.
--
-- * Otherwise, if the pushback buffer holds data, up to @len@ bytes are taken
--   from it and the socket is not read at all, so the result may be shorter
--   than @len@.
--
-- * With an empty buffer, a single 'NBS.recv' of @len@ bytes is issued.
--
-- NOTE(review): the buffer is read and then rewritten in two separate steps,
-- which is not atomic; confirm a 'BufferedSocket' is never shared between
-- threads that receive or push back concurrently.
socketRecv :: BufferedSocket -> Int -> IOx BS.ByteString
socketRecv (BufferedSocket (sock, bufIO)) len =
  case compare len 0 of
    LT -> error $ "Bad length: " ++ show len
    EQ -> return BS.empty
    GT -> toIOx $ readIORef bufIO >>= takeFrom
  where
    -- Serve the request from whatever is currently buffered.
    takeFrom pending
      | BS.null pending = NBS.recv sock len
      -- Fewer bytes buffered than requested: hand all of them out.
      | BS.length pending < len = pending <$ writeIORef bufIO BS.empty
      | otherwise = do
          let (front, back) = BS.splitAt len pending
          writeIORef bufIO back
          return front

-- | Return bytes to the front of the buffer, so they are served by
-- 'socketRecv' before anything already buffered. Empty input is a no-op.
pushback :: BufferedSocket -> BS.ByteString -> IOx ()
pushback (BufferedSocket (_, bufIO)) chunk =
  unless (BS.null chunk) . toIOx $ modifyIORef' bufIO (chunk `BS.append`)

-- | Send an entire lazy 'BL.ByteString' with 'NBL.sendAll'. The pushback
-- buffer is not involved.
socketSend :: BufferedSocket -> BL.ByteString -> IOx ()
socketSend (BufferedSocket (sock, _)) payload =
  toIOx (NBL.sendAll sock payload)

-- | Discard any pushed-back bytes, then close the socket.
socketClose :: BufferedSocket -> IOx ()
socketClose (BufferedSocket (sock, bufIO)) =
  toIOx $ writeIORef bufIO BS.empty >> S.close sock