module Network.BufferedSocket
    ( BufferedSocket()
    , makeBuffered
    , socketPort
    ) where

import           Control.Monad                  ( unless )
import           Control.Monad.IO.Class         ( MonadIO(..), liftIO )
import qualified Data.ByteString                as BS ( ByteString, append, empty, length, null, splitAt )
import qualified Data.ByteString.Lazy           as LBS ( ByteString )
import           Data.IORef                     ( IORef, atomicModifyIORef', newIORef, writeIORef )
import qualified Network.Socket                 as S ( PortNumber, Socket, close, socketPort )
import qualified Network.Socket.ByteString      as NBS ( recv )
import qualified Network.Socket.ByteString.Lazy as NBL ( sendAll )
import           Util.BufferedIOx

--------------------------------------------------------------------------------
newtype BufferedSocket = BufferedSocket (S.Socket, IORef BS.ByteString)

instance BufferedIOx BufferedSocket where
    readBuffered a = liftIO . socketRecv a
    unreadBuffered a = liftIO . pushback a
    writeBuffered a = liftIO . socketSend a
    closeBuffered = liftIO . socketClose

makeBuffered :: S.Socket -> IO BufferedSocket
makeBuffered sock = do
    bufIO <- newIORef BS.empty
    return $ BufferedSocket (sock, bufIO)

socketPort :: BufferedSocket -> IO S.PortNumber
socketPort (BufferedSocket (sock, _)) =
    S.socketPort sock

socketRecv :: BufferedSocket -> Int -> IO BS.ByteString
socketRecv (BufferedSocket (sock, bufIO)) len
    | len < 0 = error $ "Bad length: " ++ show len
    | len == 0 = return BS.empty
    | otherwise = do
          atomicModifyIORef' bufIO
                             (\buf -> if BS.null buf
                                      then (BS.empty, Nothing)
                                      else let bufLen = BS.length buf
                                           in
                                               if len > bufLen
                                               then (BS.empty, (Just buf))
                                               else let (buf0, buf1) = BS.splitAt len buf
                                                    in
                                                        (buf1, Just buf0)) >>=
              maybe (NBS.recv sock len) (return)

pushback :: BufferedSocket -> BS.ByteString -> IO ()
pushback (BufferedSocket (_, bufIO)) bytes = do
    unless (BS.null bytes) $ do
        atomicModifyIORef' bufIO (\buf -> (buf `BS.append` bytes, ()))

socketSend :: BufferedSocket -> LBS.ByteString -> IO ()
socketSend (BufferedSocket (sock, _)) bl = do
    NBL.sendAll sock bl

socketClose :: BufferedSocket -> IO ()
socketClose (BufferedSocket (sock, bufIO)) = do
    writeIORef bufIO BS.empty
    S.close sock