{-# LANGUAGE ForeignFunctionInterface, OverloadedStrings #-}
{-# LANGUAGE CPP #-}

module Network.Wai.Handler.Warp.Recv (
    receive
  , receiveBuf
  , makeReceiveN
  , makePlainReceiveN
  , spell
  ) where

import qualified UnliftIO
import qualified Data.ByteString as BS
import Data.IORef
import Foreign.C.Error (eAGAIN, eINTR, eWOULDBLOCK, getErrno, throwErrno)
import Foreign.C.Types
import Foreign.ForeignPtr (withForeignPtr)
import Foreign.Ptr (Ptr, castPtr, plusPtr)
import GHC.Conc (threadWaitRead)
import qualified GHC.IO.Exception as E
import Network.Socket (Socket)
import qualified System.IO.Error as E
#if MIN_VERSION_network(3,1,0)
import Network.Socket (withFdSocket)
#else
import Network.Socket (fdSocket)
#endif
import System.Posix.Types (Fd(..))

import Network.Wai.Handler.Warp.Buffer
import Network.Wai.Handler.Warp.Imports
import Network.Wai.Handler.Warp.Types

#ifdef mingw32_HOST_OS
import GHC.IO.FD (FD(..), readRawBufferPtr)
import Network.Wai.Handler.Warp.Windows
#endif

----------------------------------------------------------------

-- | Build a sized receiving function from a plain 'Recv' (returns whatever
--   chunk is available) and a 'RecvBuf' (fills a raw buffer in place).
--
--   The returned function keeps a private cache of not-yet-consumed bytes,
--   seeded with the first argument, and delegates each request to
--   'receiveN'.
--   The returned function may allocate a byte string with malloc().
makeReceiveN :: ByteString -> Recv -> RecvBuf -> IO (BufSize -> IO ByteString)
makeReceiveN initial recvChunk recvInto =
    (\cache -> receiveN cache recvChunk recvInto) <$> newIORef initial

-- | Build a sized receiving function that reads straight from a socket.
--
--   Unlike 'makeReceiveN', the two underlying receiving functions are
--   derived from the socket itself: 'receive' backed by a freshly created
--   buffer pool, and 'receiveBuf'.
--   The second argument is data that has already been received; it seeds
--   the cache of unconsumed bytes.
--   The returned function may allocate a byte string with malloc().
makePlainReceiveN :: Socket -> ByteString -> IO (BufSize -> IO ByteString)
makePlainReceiveN s bs0 = do
    cache <- newIORef bs0
    pool <- newBufferPool
    let recvChunk = receive s pool
        recvInto  = receiveBuf s
    pure $ receiveN cache recvChunk recvInto

-- | Return the next @size@ bytes of input, consuming the bytes cached in
--   @ref@ first and topping up with @recv@ / @recvBuf@ via 'spell'.
--   Whatever 'spell' reports as surplus is stored back into @ref@ for the
--   next call.
--
--   Any exception caught by 'UnliftIO.handleAny' (synchronous exceptions)
--   is swallowed, and the empty string is returned instead. In that case
--   @ref@ is not updated.
receiveN :: IORef ByteString -> Recv -> RecvBuf -> BufSize -> IO ByteString
receiveN ref recv recvBuf size = UnliftIO.handleAny (const (return "")) go
  where
    go :: IO ByteString
    go = do
        (wanted, rest) <- readIORef ref >>= \cached -> spell cached size recv recvBuf
        writeIORef ref rest
        return wanted

----------------------------------------------------------------

-- | @spell cached wanted recv recvBuf@ assembles the next @wanted@ bytes,
--   starting with the already-received bytes in @cached@.
--
--   The result is @(chunk, leftover)@. @chunk@ holds exactly @wanted@ bytes.
--   @leftover@ holds any surplus bytes that were read past that point.
--   @("", "")@ is returned if @recv@ yields an empty chunk, or @recvBuf@
--   returns 'False', before @wanted@ bytes are available. Data received so
--   far is discarded in that case.
--
--   Requests of up to 'smallLimit' bytes are built by concatenating chunks
--   from @recv@. Larger requests are read straight into one malloc()ed
--   buffer with @recvBuf@, so they never produce leftovers.
spell :: ByteString -> BufSize -> IO ByteString -> RecvBuf -> IO (ByteString, ByteString)
spell cached wanted recv recvBuf
  | wanted <= have       = return (BS.splitAt wanted cached)
  | wanted <= smallLimit = gather [cached] (wanted - have)
  | otherwise            = readLarge
  where
    have :: BufSize
    have = BS.length cached

    -- fixme: hard-coded threshold between the two strategies.
    smallLimit :: BufSize
    smallLimit = 4096

    -- Pull chunks (kept newest-first in @acc@) until @need@ more bytes have
    -- arrived. An empty chunk means the input ended early.
    gather :: [ByteString] -> BufSize -> IO (ByteString, ByteString)
    gather acc need = do
        chunk <- recv
        let got = BS.length chunk
        if got == 0
            then return ("", "")
            else if got < need
                then gather (chunk : acc) (need - got)
                else do
                    let (used, surplus) = BS.splitAt need chunk
                    return (BS.concat (reverse (used : acc)), surplus)

    -- Allocate the whole result up front, copy the cached prefix in, and
    -- let recvBuf fill the remainder in place.
    readLarge :: IO (ByteString, ByteString)
    readLarge = do
        buf@(PS fptr _ _) <- mallocBS wanted
        withForeignPtr fptr $ \base -> do
            dst <- copy base cached
            filled <- recvBuf dst (wanted - have)
            -- fixme: a short read throws away everything received so far.
            return $ if filled then (buf, "") else ("", "")

-- | Receive one chunk of data from the socket into a buffer taken from
--   @pool@.
--
--   The timeout manager may close the socket underneath us. The next read
--   then fails with "Bad file descriptor", which GHC classifies as
--   'E.InvalidArgument'. That closure is expected (it is the
--   TimeoutThread's doing), so such errors are turned into an empty
--   result. Every other 'IOException' is rethrown unchanged.
receive :: Socket -> BufferPool -> Recv
receive sock pool = UnliftIO.handleIO onIOError readChunk
  where
    readChunk :: Recv
    readChunk = withBufferPool pool $ \(ptr, size) ->
        withFd $ \fd -> fromIntegral <$> receiveloop fd ptr (fromIntegral size)

    -- Run an action with the socket's raw descriptor, whichever way the
    -- installed network version exposes it.
    withFd :: (CInt -> IO a) -> IO a
#if MIN_VERSION_network(3,1,0)
    withFd = withFdSocket sock
#elif MIN_VERSION_network(3,0,0)
    withFd act = fdSocket sock >>= act
#else
    withFd act = act (fdSocket sock)
#endif

    onIOError :: UnliftIO.IOException -> IO ByteString
    onIOError err
      | isBadFd   = return ""
      | otherwise = UnliftIO.throwIO err
      where
        isBadFd = E.ioeGetErrorType err == E.InvalidArgument

-- | Fill exactly @siz0@ bytes starting at @buf0@ with data from the socket,
--   calling 'receiveloop' repeatedly until the buffer is full.
--
--   Returns 'True' once every byte has arrived (immediately when @siz0@ is
--   0). Returns 'False' as soon as a read yields 0 bytes; whatever was
--   received before that stays in the buffer.
receiveBuf :: Socket -> RecvBuf
receiveBuf sock buf0 siz0 = withFd $ \fd -> fill fd buf0 siz0
  where
    -- Run an action with the socket's raw descriptor, whichever way the
    -- installed network version exposes it.
    withFd :: (CInt -> IO a) -> IO a
#if MIN_VERSION_network(3,1,0)
    withFd = withFdSocket sock
#elif MIN_VERSION_network(3,0,0)
    withFd act = fdSocket sock >>= act
#else
    withFd act = act (fdSocket sock)
#endif

    fill :: CInt -> RecvBuf
    fill fd ptr remaining
      | remaining == 0 = return True
      | otherwise = do
          got <- fromIntegral <$> receiveloop fd ptr (fromIntegral remaining)
          -- fixme: what should we do in the case of a 0-byte read?
          if got == 0
            then return False
            else fill fd (ptr `plusPtr` got) (remaining - got)

-- | Receive up to @size@ bytes from the raw descriptor @sock@ into @ptr@ and
--   return the byte count reported by the underlying receive call. Per
--   recv(2), 0 means the peer performed an orderly shutdown.
--
--   When the call fails with EAGAIN or EWOULDBLOCK (POSIX allows the two
--   to differ), the thread blocks in 'threadWaitRead' until the descriptor
--   is readable, then tries again. NOTE(review): this assumes the
--   descriptor is non-blocking, which the EAGAIN handling implies.
--   A call interrupted by a signal (EINTR) is retried immediately instead
--   of being reported as an error.
--   Any other failure is thrown as an 'IOError' via 'throwErrno'.
receiveloop :: CInt -> Ptr Word8 -> CSize -> IO CInt
receiveloop sock ptr size = do
#ifdef mingw32_HOST_OS
    bytes <- windowsThreadBlockHack $ fromIntegral <$> readRawBufferPtr "recv" (FD sock 1) (castPtr ptr) 0 size
#else
    bytes <- c_recv sock (castPtr ptr) size 0
#endif
    if bytes /= -1
      then return bytes
      else do
        errno <- getErrno
        if errno == eINTR
          then
            -- Interrupted before any data was transferred: just retry.
            receiveloop sock ptr size
          else if errno == eAGAIN || errno == eWOULDBLOCK
            then do
              -- Nothing to read yet: park until the IO manager reports
              -- the descriptor readable, then retry.
              threadWaitRead (Fd sock)
              receiveloop sock ptr size
            else
              throwErrno "receiveloop"

#ifndef mingw32_HOST_OS
-- | Raw FFI binding to the C @recv@ function. The arguments are the
--   descriptor, the destination buffer, the buffer length, and the flags
--   (this module always passes 0 for the flags).
--
--   fixme: the type of the return value.
--   C's recv returns @ssize_t@, but it is declared here as 'CInt'.
--   NOTE(review): where ssize_t is wider than int, a single call that
--   returns more than 2^31-1 bytes would be truncated. Confirm that the
--   buffer sizes passed by callers stay below that limit.
foreign import ccall unsafe "recv"
    c_recv :: CInt -> Ptr CChar -> CSize -> CInt -> IO CInt
#endif