-----------------------------------------------------------------------------
-- |
-- Module: Network.Socket.Enumerator
-- Copyright: 2010 John Millikin
-- License: MIT
--
-- Maintainer: jmillikin@gmail.com
-- Portability: portable
--
-- Enumerators that read from, and iteratees that write to, network
-- sockets ('S.Socket'), built on "Network.Socket.ByteString".
--
-----------------------------------------------------------------------------
module Network.Socket.Enumerator
    ( enumSocket
    , enumSocketFrom
    , enumSocketTimed
    , iterSocket
    , iterSocketTo
    , iterSocketTimed
    ) where

import qualified Control.Exception as Exc
import Control.Monad.IO.Class (MonadIO, liftIO)
import qualified Data.ByteString as B
import Data.Enumerator ((>>==))
import qualified Data.Enumerator as E
import qualified Network.Socket as S
import Network.Socket.ByteString (sendMany, sendManyTo, recv, recvFrom)
import System.Timeout (timeout)

-- | Enumerate binary data from a 'S.Socket', using 'recv'. The socket
-- must be connected.
--
-- The buffer size should be a small power of 2, such as 4096.
--
-- Enumeration stops as soon as 'recv' returns an empty chunk (for a
-- connected stream socket, the peer has closed its side) or the iteratee
-- stops asking for input. 'E.EOF' is not sent: the iteratee is handed back
-- still in the 'E.Continue' state.
enumSocket :: MonadIO m
           => Integer -- ^ Buffer size
           -> S.Socket
           -> E.Enumerator B.ByteString m b
enumSocket bufferSize sock = enumLoop $ \loop k -> do
    bytes <- try (recv sock intSize)
    if B.null bytes
        then E.continue k
        else k (E.Chunks [bytes]) >>== loop
  where
    -- Loop-invariant: convert the buffer size once, not on every 'recv'.
    intSize = fromInteger bufferSize

-- | Enumerate binary data from a 'S.Socket', using 'recvFrom'. The socket
-- does not have to be connected. Each chunk of data received will be paired
-- with its address.
--
-- The buffer size should be a small power of 2, such as 4096.
--
-- Like 'enumSocket', enumeration stops when a read returns an empty chunk,
-- without sending 'E.EOF'.
--
-- NOTE(review): on a datagram (e.g. UDP) socket a zero-length datagram is
-- a legitimate message, yet receiving one ends this enumeration early.
-- Confirm this is acceptable for datagram callers.
enumSocketFrom :: MonadIO m
               => Integer -- ^ Buffer size
               -> S.Socket
               -> E.Enumerator (B.ByteString, S.SockAddr) m b
enumSocketFrom bufferSize sock = enumLoop $ \loop k -> do
    (bytes, addr) <- try (recvFrom sock intSize)
    if B.null bytes
        then E.continue k
        else k (E.Chunks [(bytes, addr)]) >>== loop
  where
    -- Loop-invariant: convert the buffer size once, not on every 'recvFrom'.
    intSize = fromInteger bufferSize

-- | Enumerate binary data from a 'S.Socket', using 'recv'. The socket must
-- be connected.
--
-- The buffer size should be a small power of 2, such as 4096.
--
-- If any call to 'recv' takes longer than the timeout, 'enumSocketTimed'
-- fails with an 'Exc.ErrorCall' error (raised with 'E.throwError'). To add
-- a timeout for the entire session, wrap the call to 'E.run' in 'timeout'.
--
-- The timeout is clamped to the range of 'Int' before use, so very large
-- values (e.g. over roughly 35 minutes on a 32-bit platform) cannot wrap
-- around into a negative duration. A negative timeout disables the
-- per-call limit (see 'timeout').
--
-- Since: 0.1.2
enumSocketTimed :: MonadIO m
                => Integer -- ^ Buffer size
                -> Integer -- ^ Timeout, in microseconds
                -> S.Socket
                -> E.Enumerator B.ByteString m b
enumSocketTimed bufferSize maxWait sock = enumLoop $ \loop k -> do
    tried <- try (timeout intWait (recv sock intSize))
    case tried of
        Nothing -> E.throwError timedOut
        Just bytes
            | B.null bytes -> E.continue k
            | otherwise -> k (E.Chunks [bytes]) >>== loop
  where
    -- Loop invariants, computed once rather than on every 'recv'.
    intSize = fromInteger bufferSize
    intWait = clampToInt maxWait
    timedOut = Exc.ErrorCall "enumSocketTimed: timeout exceeded"

    -- 'fromInteger' into 'Int' silently wraps on overflow; a wrapped
    -- negative value would make 'timeout' wait forever instead of failing.
    clampToInt :: Integer -> Int
    clampToInt n = fromInteger (max lo (min hi n))
      where
        lo = toInteger (minBound :: Int)
        hi = toInteger (maxBound :: Int)

-- | Write data to a 'S.Socket', using 'sendMany'. The socket must be connected.
--
-- Each batch of chunks the iteratee receives is written with a single
-- 'sendMany' call; the iteratee finishes with @()@ at end of input.
iterSocket :: MonadIO m
           => S.Socket
           -> E.Iteratee B.ByteString m ()
iterSocket sock = foldMany (try . sendMany sock)

-- | Write data to a 'S.Socket', using 'sendManyTo'. The socket does not
-- have to be connected.
--
-- Each batch of chunks the iteratee receives is written with a single
-- 'sendManyTo' call to @addr@ (for a datagram socket, presumably one
-- datagram per batch -- confirm against the socket type in use).
iterSocketTo :: MonadIO m
             => S.Socket
             -> S.SockAddr
             -> E.Iteratee B.ByteString m ()
iterSocketTo sock addr = foldMany (\xs -> try (sendManyTo sock xs addr))

-- | Write data to a 'S.Socket', using 'sendMany'. The socket must be connected.
--
-- If any call to 'sendMany' takes longer than the timeout, 'iterSocketTimed'
-- fails with an 'Exc.ErrorCall' error (raised with 'E.throwError'). To add
-- a timeout for the entire session, wrap the call to 'E.run' in 'timeout'.
--
-- The timeout is clamped to the range of 'Int' before use, so very large
-- values (e.g. over roughly 35 minutes on a 32-bit platform) cannot wrap
-- around into a negative duration. A negative timeout disables the
-- per-call limit (see 'timeout').
--
-- Since: 0.1.2
iterSocketTimed :: MonadIO m
                => Integer -- ^ Timeout, in microseconds
                -> S.Socket
                -> E.Iteratee B.ByteString m ()
iterSocketTimed maxWait sock = foldMany $ \xs -> do
    tried <- try (timeout intWait (sendMany sock xs))
    case tried of
        Nothing -> E.throwError timedOut
        Just _ -> return ()
  where
    -- Invariants, computed once rather than on every 'sendMany'.
    intWait = clampToInt maxWait
    timedOut = Exc.ErrorCall "iterSocketTimed: timeout exceeded"

    -- 'fromInteger' into 'Int' silently wraps on overflow; a wrapped
    -- negative value would make 'timeout' wait forever instead of failing.
    clampToInt :: Integer -> Int
    clampToInt n = fromInteger (max lo (min hi n))
      where
        lo = toInteger (minBound :: Int)
        hi = toInteger (maxBound :: Int)

-- | Run an 'IO' action inside an iteratee. Any 'Exc.IOException' it throws
-- becomes an iteratee error (via 'E.throwError').
--
-- Only 'Exc.IOException' is caught, which is how socket failures from
-- 'recv' and the send functions are reported. Catching 'Exc.SomeException'
-- would also intercept asynchronous exceptions. That includes the one
-- 'timeout' throws to interrupt a blocked socket call when a caller wraps
-- 'E.run' in 'timeout', and 'Exc.ThreadKilled'. Those would be silently
-- converted into ordinary iteratee errors instead of propagating.
try :: MonadIO m => IO b -> E.Iteratee a m b
try io = do
    tried <- liftIO (Exc.try io)
    case tried of
        Left err -> E.throwError (err :: Exc.IOException)
        Right b -> return b

-- | Build an 'E.Enumerator' from a single step of a read loop.
--
-- The step function is only invoked while the iteratee is in the
-- 'E.Continue' state. It receives the loop itself, so it can feed the next
-- step back in with '>>==', and the iteratee's continuation. Any other step
-- ('E.Yield' or 'E.Error') is handed back unchanged with 'E.returnI'.
enumLoop :: Monad m
         => ((E.Step a m b -> E.Iteratee a m b)
             -> (E.Stream a -> E.Iteratee a m b)
             -> E.Iteratee a m b)
         -> E.Enumerator a m b
enumLoop iter = loop
  where
    loop (E.Continue k) = iter loop k
    loop step = E.returnI step

-- | Build an iteratee that hands each non-empty batch of chunks to @f@,
-- discarding @f@'s result, until end of input.
--
-- Empty 'E.Chunks' are skipped without calling @f@. On 'E.EOF' the
-- iteratee yields @()@ and leaves 'E.EOF' as the unconsumed input.
foldMany :: Monad m => ([a] -> E.Iteratee a m b) -> E.Iteratee a m ()
foldMany f = E.continue step
  where
    step E.EOF = E.yield () E.EOF
    step (E.Chunks []) = E.continue step
    step (E.Chunks xs) = f xs >> E.continue step