-- |Wrappers for "Network.Socket" calls that run over entire buffers or builders.
{-# LANGUAGE ScopedTypeVariables #-}
module Network.Socket.All
  ( recvAllBuf
  , recvStorable
  , sendAllBuf
  , sendStorable
  , sendBuilderWith
  , sendBuilder
  ) where

import           Control.Monad (unless)
import qualified Data.ByteString.Builder as B
import qualified Data.ByteString.Builder.Extra as B
import           Data.Word (Word8)
import           Foreign.Ptr (Ptr, castPtr, plusPtr)
import           Foreign.Marshal (alloca, allocaBytes, with)
import           Foreign.Storable (Storable, sizeOf, peek)
import           Network.Socket
import qualified Network.Socket.ByteString as BS

-- | Drive a partial-transfer action over a whole buffer.
--
-- The action is called with a pointer to the first unprocessed byte (offset
-- from @p@) and the number of bytes still remaining, and must return how many
-- bytes it handled. Calls repeat until at least @n@ bytes have been handled or
-- the action returns 0. The result is the total handled; it is below @n@ only
-- when the action returned 0 first.
allBufWith :: (Ptr a -> Int -> IO Int) -> Ptr a -> Int -> IO Int
allBufWith f p n = loop 0
  where
    loop done
      | done >= n = return done
      | otherwise = f (p `plusPtr` done) (n - done) >>= \got ->
          if got == 0 then return done else loop (done + got)

-- | Receive data from a socket, trying to fill the whole buffer and blocking
-- as necessary.
--
-- Returns the number of bytes actually received. A result smaller than the
-- requested length means 'recvBuf' returned 0 (connection closed) before the
-- buffer was full.
recvAllBuf :: Socket -> Ptr Word8 -> Int -> IO Int
recvAllBuf sock = allBufWith (recvBuf sock)

-- | Receive a raw memory object from a socket.
--
-- Reads exactly @'sizeOf' (undefined :: a)@ bytes into temporary storage and
-- decodes them with 'peek'. Returns 'Nothing' on a short read (the peer closed
-- the connection before the value was complete); any partially received
-- bytes are discarded.
recvStorable :: forall a . Storable a => Socket -> IO (Maybe a)
recvStorable s = alloca receiveInto
  where
    size :: Int
    size = sizeOf (undefined :: a)

    receiveInto :: Ptr a -> IO (Maybe a)
    receiveInto ptr = do
      got <- recvAllBuf s (castPtr ptr) size
      if got >= size
        then Just <$> peek ptr
        else return Nothing

-- | Send an entire buffer to a socket, blocking as necessary.
--
-- If fewer than @n@ bytes end up sent (which happens when 'sendBuf' reports
-- 0 bytes written), an 'IOError' is raised via 'fail' whose message gives the
-- sent and requested byte counts.
sendAllBuf :: Socket -> Ptr Word8 -> Int -> IO ()
sendAllBuf s p n =
  allBufWith (sendBuf s) p n >>= \sent ->
    if sent == n
      then return ()
      else fail (concat ["sendAllBuf: sent ", show sent, "/", show n])

-- | Send a raw memory object to a socket.
--
-- The value is copied into temporary storage with 'with' and its
-- @'sizeOf' x@ bytes are written with 'sendAllBuf'. A short write is not
-- signalled through the result (which is always @()@); instead 'sendAllBuf'
-- raises an 'IOError'.
sendStorable :: Storable a => Socket -> a -> IO ()
sendStorable s x =
  with x $ \p -> sendAllBuf s (castPtr p) (sizeOf x)

-- | Efficiently send an entire builder to a network socket, using a specific
-- buffer size.
--
-- The builder is run into a temporary buffer of @z0@ bytes, and whatever each
-- write step produces is flushed with 'sendAllBuf'. When the builder asks for
-- a bigger buffer ('B.More' with a larger size), a fresh buffer of the
-- requested size is allocated and used from then on. Ready-made chunks
-- ('B.Chunk') are sent directly with 'BS.sendAll'.
--
-- This could be made even more efficient by using something like
-- 'B.AllocationStrategy', but it is good enough for most purposes.
sendBuilderWith :: Socket -> Int -> B.Builder -> IO ()
sendBuilderWith s z0 builder = withBuffer z0 (B.runBuilder builder)
  where
    -- Allocate a buffer of the given size and drain the writer through it,
    -- moving to a new buffer whenever the writer needs more room.
    withBuffer :: Int -> B.BufferWriter -> IO ()
    withBuffer size writer = do
      next <- allocaBytes size (drain size writer)
      case next of
        B.More size' writer' -> withBuffer size' writer'
        _ -> return ()

    -- Fill and flush one buffer repeatedly. Stop and return the writer's
    -- state when it finishes or needs a buffer larger than this one.
    drain :: Int -> B.BufferWriter -> Ptr Word8 -> IO B.Next
    drain size writer ptr = do
      (used, next) <- writer ptr size
      sendAllBuf s ptr used
      case next of
        B.Chunk bytes writer' ->
          BS.sendAll s bytes >> drain size writer' ptr
        B.More size' writer'
          | size' <= size -> drain size writer' ptr
        _ -> return next

-- | Efficiently send an entire builder to a network socket, staging output
-- through buffers of 'B.defaultChunkSize' bytes.
sendBuilder :: Socket -> B.Builder -> IO ()
sendBuilder sock builder = sendBuilderWith sock B.defaultChunkSize builder