{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE OverloadedStrings #-}

module Network.Socket.ByteString.Lazy.Posix (
    -- * Send data to a socket
    send
  , sendAll
  ) where

import qualified Data.ByteString.Lazy               as L
import           Data.ByteString.Unsafe             (unsafeUseAsCStringLen)
import           Foreign.Marshal.Array              (allocaArray)

import           Network.Socket.ByteString.IO       (waitWhen0)
import           Network.Socket.ByteString.Internal (c_writev)
import           Network.Socket.Imports
import           Network.Socket.Internal
import           Network.Socket.Posix.IOVec    (IOVec (IOVec))
import           Network.Socket.Types

-- -----------------------------------------------------------------------------
-- Sending
-- | Send data to the socket with a single @writev@ call.
--
-- Only the first @maxNumChunks@ chunks of the lazy 'L.ByteString' are
-- considered, and no further chunk is queued once the bytes already queued
-- reach @maxNumBytes@.  The result is the byte count reported by @writev@,
-- which can be smaller than the input; callers that need the whole payload
-- transmitted must loop (see 'sendAll').
send
    :: Socket -- ^ Connected socket
    -> L.ByteString -- ^ Data to send
    -> IO Int64 -- ^ Number of bytes sent
send s lbs = do
    let chunks  = take maxNumChunks (L.toChunks lbs)
        nchunks = length chunks
    siz <- withFdSocket s $ \fd ->
        -- One iovec slot per candidate chunk; the loop may fill fewer.
        allocaArray nchunks $ \iovs ->
            withPokes chunks iovs $ \niovs ->
                throwSocketErrorWaitWrite s "writev" $ c_writev fd iovs niovs
    return $ fromIntegral siz
  where
    -- Fill the iovec array from the chunks and run the continuation with the
    -- number of entries written.  The continuation runs inside the nested
    -- 'unsafeUseAsCStringLen' scopes, so every chunk buffer referenced by
    -- the array is still in scope when the continuation executes.
    withPokes ss p f = loop ss p 0 0
      where
        -- q: next free iovec slot, k: bytes queued so far,
        -- niovs: entries filled so far (kept strict).
        loop (c:rest) q k !niovs
            | k < maxNumBytes = unsafeUseAsCStringLen c $ \(buf, n) -> do
                poke q $ IOVec (castPtr buf) (fromIntegral n)
                loop rest (q `plusPtr` iovecSize) (k + n) (niovs + 1)
            | otherwise = f niovs
        loop _ _ _ niovs = f niovs
    -- Size in bytes of one iovec entry, used to step through the array.
    iovecSize    = sizeOf (IOVec nullPtr 0)
    maxNumBytes  = 4194304 :: Int -- maximum number of bytes to transmit in one system call
    maxNumChunks = 1024 :: Int -- maximum number of chunks to transmit in one system call

-- | Send data to the socket, calling 'send' repeatedly until nothing is
-- left to transmit.  An empty input returns immediately without calling
-- 'send'.
sendAll
    :: Socket -- ^ Connected socket
    -> L.ByteString -- ^ Data to send
    -> IO ()
sendAll s bs0
    | L.null bs0 = return ()
    | otherwise  = loop bs0
  where
    loop bs = do
        -- "send" throws an exception.
        sent <- send s bs
        waitWhen0 (fromIntegral sent) s
        -- Check the unsent remainder with 'L.null' instead of comparing
        -- @sent@ against @L.length bs@: the length would traverse (and so
        -- force) every remaining chunk of the lazy input on each iteration.
        let rest = L.drop sent bs
        when (not (L.null rest)) $ loop rest