{-# LANGUAGE ForeignFunctionInterface, OverloadedStrings #-}
{-# LANGUAGE CPP #-}

module Network.Wai.Handler.Warp.Recv (
    receive
  , receiveBuf
  , makeReceiveN
  , makePlainReceiveN
  , spell
  ) where

import qualified Control.Exception as E
import qualified Data.ByteString as BS
import Data.IORef
import Foreign.C.Error (eAGAIN, getErrno, throwErrno)
import Foreign.C.Types
import Foreign.ForeignPtr (withForeignPtr)
import Foreign.Ptr (Ptr, castPtr, plusPtr)
import GHC.Conc (threadWaitRead)
import qualified GHC.IO.Exception as E
import Network.Socket (Socket)
import qualified System.IO.Error as E
#if MIN_VERSION_network(3,1,0)
import Network.Socket (withFdSocket)
#else
import Network.Socket (fdSocket)
#endif
import System.Posix.Types (Fd(..))

import Network.Wai.Handler.Warp.Buffer
import Network.Wai.Handler.Warp.Imports
import Network.Wai.Handler.Warp.Types

#ifdef mingw32_HOST_OS
import GHC.IO.FD (FD(..), readRawBufferPtr)
import Network.Wai.Handler.Warp.Windows
#endif

----------------------------------------------------------------

-- | Build a receiving function of type @BufSize -> IO ByteString@
--   on top of two lower-level receiving functions: a chunk receiver
--   ('Recv') and a fill-this-buffer receiver ('RecvBuf').
--
--   The returned function efficiently manages data that has been
--   received but not yet consumed. That cache lives in an 'IORef' and
--   is initialized with the first argument.
--   The returned function may allocate a byte string with malloc().
makeReceiveN :: ByteString -> Recv -> RecvBuf -> IO (BufSize -> IO ByteString)
makeReceiveN bs0 recv recvBuf =
    -- Only the cache is allocated here; the receivers are captured as-is.
    (\cache -> receiveN cache recv recvBuf) <$> newIORef bs0

-- | Like 'makeReceiveN', but both underlying receiving functions are
--   created internally from the given socket: 'receive' (backed by a
--   freshly allocated buffer pool) and 'receiveBuf'.
--   The second argument is data that has already been received; it is
--   handed out before anything new is read from the socket.
--   The returned function may allocate a byte string with malloc().
makePlainReceiveN :: Socket -> ByteString -> IO (BufSize -> IO ByteString)
makePlainReceiveN s bs0 = do
    -- Cache of received-but-unconsumed bytes, seeded with 'bs0'.
    ref <- newIORef bs0
    -- 'receive' draws its read buffers from this pool.
    pool <- newBufferPool
    return $ receiveN ref (receive s pool) (receiveBuf s)

-- | Return the next @size@ bytes of the stream.
--
--   Bytes cached in @ref@ are used first; the rest come from @recv@ or
--   @recvBuf@ (see 'spell' for which one is used). Any surplus beyond
--   @size@ is stored back into @ref@ for the next call.
--
--   An empty 'ByteString' is returned when the peer closes before enough
--   data arrives, or when receiving raises a synchronous exception
--   (best effort). Asynchronous exceptions, i.e. anything wrapped in
--   'E.SomeAsyncException' such as 'E.ThreadKilled', are re-thrown so
--   that the calling thread can still be cancelled.
receiveN :: IORef ByteString -> Recv -> RecvBuf -> BufSize -> IO ByteString
receiveN ref recv recvBuf size = E.handle handler $ do
    cached <- readIORef ref
    (bs, leftover) <- spell cached size recv recvBuf
    writeIORef ref leftover
    return bs
  where
    handler :: E.SomeException -> IO ByteString
    handler se
      -- Swallowing an async exception would silently defeat throwTo/killThread.
      | isAsync se = E.throwIO se
      | otherwise  = return ""

    isAsync :: E.SomeException -> Bool
    isAsync se = case (E.fromException se :: Maybe E.SomeAsyncException) of
        Just _  -> True
        Nothing -> False

----------------------------------------------------------------

-- | @spell initial size recv recvBuf@ assembles @size@ bytes, starting
--   with the already-received bytes in @initial@, and returns
--   @(bytes, leftover)@.
--
--   * If @initial@ already holds at least @size@ bytes, it is simply
--     split at @size@.
--   * Otherwise, if @size <= 4096@, chunks from @recv@ are accumulated
--     until enough bytes are available. The surplus of the last chunk
--     becomes the leftover.
--   * Otherwise a @size@-byte buffer is allocated with 'mallocBS',
--     @initial@ is copied to its start, and @recvBuf@ fills the rest.
--     The leftover is then always empty.
--
--   In the last two cases @("", "")@ is returned if the peer stops
--   early: @recv@ yields an empty chunk, or @recvBuf@ returns 'False'.
spell :: ByteString -> BufSize -> IO ByteString -> RecvBuf -> IO (ByteString, ByteString)
spell init0 siz0 recv recvBuf
  | siz0 <= len0 = return $ BS.splitAt siz0 init0
  -- fixme: hard coding 4096
  | siz0 <= 4096 = loop [init0] (siz0 - len0)
  | otherwise = do
      bs@(PS fptr _ _) <- mallocBS siz0
      withForeignPtr fptr $ \ptr -> do
          -- NOTE(review): 'copy' is assumed to return the pointer just
          -- past the bytes it copied; confirm in Warp.Buffer.
          ptr' <- copy ptr init0
          full <- recvBuf ptr' (siz0 - len0)
          -- fixme: any partially received data is discarded here
          if full then return (bs, "") else return ("", "")
  where
    len0 = BS.length init0

    -- 'bss' holds the chunks received so far, newest first;
    -- 'siz' is how many more bytes are still needed.
    loop bss siz = do
        bs <- recv
        let len = BS.length bs
        if len == 0
            then return ("", "")
            else if len >= siz
                then do
                    let (consume, leftover) = BS.splitAt siz bs
                        ret = BS.concat $ reverse (consume : bss)
                    return (ret, leftover)
                else loop (bs : bss) (siz - len)

-- The timeout manager may close the socket.
-- In that case, an error of "Bad file descriptor" occurs.
-- We ignore it because we expect TimeoutThread.
-- | Receive one chunk of data from the socket into a buffer taken from
--   @pool@, reading at most the size of that buffer.
--
--   An 'E.IOException' whose error type is 'E.InvalidArgument' yields
--   an empty 'ByteString'; any other 'E.IOException' is re-thrown.
--   NOTE(review): 'withBufferPool' is assumed to turn the byte count
--   returned by the action into the resulting 'ByteString'.
receive :: Socket -> BufferPool -> Recv
receive sock pool = E.handle handler $ withBufferPool pool $ \ (ptr, size) -> do
#if MIN_VERSION_network(3,1,0)
  withFdSocket sock $ \fd -> do
#elif MIN_VERSION_network(3,0,0)
    fd <- fdSocket sock
#else
    let fd = fdSocket sock
#endif
    let size' = fromIntegral size
    fromIntegral <$> receiveloop fd ptr size'
 where
    handler :: E.IOException -> IO ByteString
    handler e
      | E.ioeGetErrorType e == E.InvalidArgument = return ""
      | otherwise                                = E.throwIO e

-- | Fill exactly @siz0@ bytes starting at @buf0@ by calling @recv(2)@
--   repeatedly.
--
--   Returns 'True' once all @siz0@ bytes have been received, and 'False'
--   if the peer closes the connection first (a read of 0 bytes).
--   A non-positive size is treated as already complete: it returns
--   'True' without touching the socket. This also prevents a negative
--   size from being converted to a huge 'CSize' and passed to @recv@,
--   which would let it write past the buffer.
receiveBuf :: Socket -> RecvBuf
receiveBuf sock buf0 siz0 = do
#if MIN_VERSION_network(3,1,0)
  withFdSocket sock $ \fd -> do
#elif MIN_VERSION_network(3,0,0)
    fd <- fdSocket sock
#else
    let fd = fdSocket sock
#endif
    loop fd buf0 siz0
 where
    loop fd buf siz
      | siz <= 0  = return True
      | otherwise = do
          n <- fromIntegral <$> receiveloop fd buf (fromIntegral siz)
          -- fixme: what should we do in the case of n == 0
          if n == 0
            then return False
            else loop fd (buf `plusPtr` n) (siz - n)

-- | Call @recv@ on a raw file descriptor, retrying until it either
--   succeeds or fails with something other than @EAGAIN@.
--
--   On @EAGAIN@ the calling Haskell thread blocks in 'threadWaitRead'
--   until the descriptor is readable, and then retries. Any other
--   failure is raised via 'throwErrno'. The result is the number of
--   bytes received; 0 means the peer closed the connection.
--   On Windows the read goes through 'readRawBufferPtr', wrapped in
--   'windowsThreadBlockHack'.
receiveloop :: CInt -> Ptr Word8 -> CSize -> IO CInt
receiveloop sock ptr size = do
#ifdef mingw32_HOST_OS
    bytes <- windowsThreadBlockHack $ fromIntegral <$> readRawBufferPtr "recv" (FD sock 1) (castPtr ptr) 0 size
#else
    bytes <- c_recv sock (castPtr ptr) size 0
#endif
    if bytes == -1 then do
        errno <- getErrno
        if errno == eAGAIN then do
            threadWaitRead (Fd sock)
            receiveloop sock ptr size
          else
            throwErrno "receiveloop"
      else
        return bytes

#ifndef mingw32_HOST_OS
-- fixme: the type of the return value.
-- recv(2) returns ssize_t; declaring it as CInt is only exact while
-- results fit in 32 bits.
foreign import ccall unsafe "recv"
    c_recv :: CInt -> Ptr CChar -> CSize -> CInt -> IO CInt
#endif