-----------------------------------------------------------------------------
-- |
-- Module: Network.Socket.Enumerator
-- Copyright: 2010 John Millikin
-- License: MIT
--
-- Maintainer: jmillikin@gmail.com
-- Portability: portable
--
-----------------------------------------------------------------------------
module Network.Socket.Enumerator
	( enumSocket
	, enumSocketFrom
	, enumSocketTimed
	, iterSocket
	, iterSocketTo
	, iterSocketTimed
	) where
import qualified Control.Exception as Exc
import           Control.Monad.IO.Class (MonadIO, liftIO)
import qualified Data.ByteString as B
import           Data.Enumerator ((>>==))
import qualified Data.Enumerator as E
import qualified Network.Socket as S
import           Network.Socket.ByteString (sendMany, sendManyTo, recv, recvFrom)
import           System.Timeout (timeout)

-- | Enumerate binary data from a connected 'S.Socket', reading with 'recv'.
--
-- Each successful read is handed to the iteratee as a single chunk. An
-- empty read (presumably the peer closing the connection) ends the
-- enumeration and returns the iteratee still waiting for input. I/O errors
-- from 'recv' are reported as iteratee errors.
--
-- The buffer size should be a small power of 2, such as 4096.
enumSocket :: MonadIO m
           => Integer -- ^ Buffer size
           -> S.Socket
           -> E.Enumerator B.ByteString m b
enumSocket bufferSize sock = enumLoop readChunk where
	readChunk loop k = try (recv sock (fromInteger bufferSize)) >>= deliver where
		deliver bytes
			| B.null bytes = E.continue k
			| otherwise = k (E.Chunks [bytes]) >>== loop

-- | Enumerate binary data from a 'S.Socket', reading with 'recvFrom'. The
-- socket does not have to be connected. Each chunk received is paired with
-- the address it came from.
--
-- An empty read ends the enumeration and returns the iteratee still waiting
-- for input. NOTE(review): on a datagram socket this presumably means a
-- zero-length datagram also ends enumeration; confirm that is intended.
enumSocketFrom :: MonadIO m
               => Integer -- ^ Buffer size
               -> S.Socket
               -> E.Enumerator (B.ByteString, S.SockAddr) m b
enumSocketFrom bufferSize sock = enumLoop readChunk where
	readChunk loop k = try (recvFrom sock (fromInteger bufferSize)) >>= deliver where
		deliver received@(bytes, _)
			| B.null bytes = E.continue k
			| otherwise = k (E.Chunks [received]) >>== loop

-- | Enumerate binary data from a connected 'S.Socket', reading with 'recv'
-- under a per-read timeout.
--
-- The buffer size should be a small power of 2, such as 4096.
--
-- If any single call to 'recv' takes longer than the timeout, enumeration
-- stops with an iteratee error (an 'Exc.ErrorCall' whose message is
-- @\"enumSocketTimed: timeout exceeded\"@). To add a timeout for the entire
-- session, wrap the call to 'E.run' in 'timeout'.
--
-- NOTE(review): the timeout is narrowed to 'Int' with 'fromInteger', so a
-- value larger than @maxBound :: Int@ wraps around; keep it in range.
--
-- Since: 0.1.2
enumSocketTimed :: MonadIO m
                => Integer -- ^ Buffer size
                -> Integer -- ^ Timeout, in microseconds
                -> S.Socket
                -> E.Enumerator B.ByteString m b
enumSocketTimed bufferSize maxWait sock = enumLoop readChunk where
	timedOut = Exc.ErrorCall "enumSocketTimed: timeout exceeded"
	-- 'timeout' yields Nothing when the read did not finish in time.
	readChunk loop k = do
		result <- try (timeout (fromInteger maxWait) (recv sock (fromInteger bufferSize)))
		maybe (E.throwError timedOut) (deliver loop k) result
	deliver loop k bytes
		| B.null bytes = E.continue k
		| otherwise = k (E.Chunks [bytes]) >>== loop

-- | Write every chunk of the stream to a 'S.Socket' with 'sendMany', making
-- one call per non-empty batch of chunks received. The socket must be
-- connected. I/O errors are reported as iteratee errors.
iterSocket :: MonadIO m
           => S.Socket
           -> E.Iteratee B.ByteString m ()
iterSocket sock = foldMany (try . sendMany sock)

-- | Write every chunk of the stream to a fixed destination address with
-- 'sendManyTo', making one call per non-empty batch of chunks received. The
-- socket does not have to be connected. I/O errors are reported as iteratee
-- errors.
iterSocketTo :: MonadIO m
             => S.Socket
             -> S.SockAddr
             -> E.Iteratee B.ByteString m ()
iterSocketTo sock addr = foldMany sendBatch where
	sendBatch chunks = try (sendManyTo sock chunks addr)

-- | Write data to a connected 'S.Socket' with 'sendMany', under a per-call
-- timeout.
--
-- If any call to 'sendMany' takes longer than the timeout, writing stops
-- with an iteratee error (an 'Exc.ErrorCall' whose message is
-- @\"iterSocketTimed: timeout exceeded\"@). To add a timeout for the entire
-- session, wrap the call to 'E.run' in 'timeout'.
--
-- NOTE(review): the timeout is narrowed to 'Int' with 'fromInteger', so a
-- value larger than @maxBound :: Int@ wraps around; keep it in range.
--
-- Since: 0.1.2
iterSocketTimed :: MonadIO m
                => Integer -- ^ Timeout, in microseconds
                -> S.Socket
                -> E.Iteratee B.ByteString m ()
iterSocketTimed maxWait sock = foldMany sendBatch where
	timedOut = Exc.ErrorCall "iterSocketTimed: timeout exceeded"
	-- 'timeout' yields Nothing when the send did not finish in time.
	sendBatch chunks = do
		sent <- try (timeout (fromInteger maxWait) (sendMany sock chunks))
		maybe (E.throwError timedOut) (const (return ())) sent

-- | Run an 'IO' action inside an iteratee. An 'Exc.IOException' thrown by
-- the action is reported as an iteratee error via 'E.throwError' instead of
-- escaping as an exception.
--
-- Only 'Exc.IOException' is caught, on the assumption that socket failures
-- from the @network@ calls surface as I/O errors; TODO confirm. Catching
-- 'Exc.SomeException' here would also intercept asynchronous exceptions:
-- @ThreadKilled@ from @killThread@, @UserInterrupt@ from Ctrl-C, or the
-- exception used by an enclosing 'timeout'. Those would become ordinary
-- iteratee errors, so a killed thread would keep running and a 'timeout'
-- wrapped around 'E.run' would return @Just (Left ...)@ instead of
-- 'Nothing'. Every other exception propagates unchanged.
try :: MonadIO m => IO b -> E.Iteratee a m b
try io = do
	tried <- liftIO (Exc.try io)
	case tried of
		Left err -> E.throwError (err :: Exc.IOException)
		Right b -> return b

-- | Turn a "handle one input step" function into a full enumerator.
--
-- @iter loop k@ is invoked only while the iteratee is in the 'E.Continue'
-- state. It receives the continuation @k@, plus @loop@ itself, so it can
-- feed @k@ and recurse with @>>== loop@. Any other step is returned
-- unchanged via 'E.returnI'.
enumLoop :: Monad m
         => ((E.Step a m b -> E.Iteratee a m b)
          -> (E.Stream a -> E.Iteratee a m b)
          -> E.Iteratee a m b)
         -> E.Enumerator a m b
enumLoop iter = go where
	go step = case step of
		E.Continue k -> iter go k
		_ -> E.returnI step

-- | Build an iteratee that passes each non-empty list of chunks to @f@, in
-- order, discarding @f@'s results. Empty chunk lists are skipped. At
-- 'E.EOF' it yields @()@ and leaves 'E.EOF' as the unconsumed remainder.
foldMany :: Monad m => ([a] -> E.Iteratee a m b) -> E.Iteratee a m ()
foldMany f = E.continue consume where
	consume stream = case stream of
		E.EOF -> E.yield () E.EOF
		E.Chunks [] -> E.continue consume
		E.Chunks xs -> f xs >> E.continue consume