{-# LANGUAGE DeriveDataTypeable #-}
-- | Network streams for use with strict
-- `B.ByteString`s. For lazy ByteStrings, see
-- `Network.ByteString.Lazy.Stream`. Use this module with
-- `Data.Serialize.Serialize` to send data over a stream without worrying
-- about sending and receiving the lengths.
-- One can also send data in chunks, sending data whenever it is ready, and
-- the data will be collected transparently to the client interface.
module Network.ByteString.Stream



import Control.Applicative
import Control.Concurrent
import Control.Concurrent.STM
import Control.Concurrent.STM.TChan
import Control.Exception
import Control.Monad
import Control.Monad.IO.Class
import qualified Data.ByteString as B
import qualified Data.Enumerator as E
import Data.Serialize hiding (Result)
import Data.Typeable
import Data.Word
import Network
import System.IO

-- | One frame of the wire protocol: either a chunk of payload bytes, or one
-- of the three control markers (invalidate, success, failure).
data StreamItem
    = SBytes B.ByteString
    | SInvalidate
    | SSuccess
    | SFailure
    deriving (Read, Show)

-- | Wire encoding of a frame: a one-byte constructor tag (0 = bytes,
-- 1 = invalidate, 2 = success, 3 = failure). For 'SBytes', the tag is
-- followed by the payload length as a serialized 'Int' and then the raw
-- payload.
instance Serialize StreamItem where
    put (SBytes bytes) = do put (0 :: Word8)
                            put $ B.length bytes
                            putByteString bytes
    put SInvalidate = put (1 :: Word8)
    put SSuccess = put (2 :: Word8)
    put SFailure = put (3 :: Word8)

    -- Not very useful when reading from a handle: the tag byte and the
    -- length header (1 + intSize bytes) must be read before we know how much
    -- more to read, which is why readStreamItem decodes frames by hand.
    get = do cons <- get :: Get Word8
             case cons of
                0 -> do l <- get :: Get Int
                        SBytes <$> getByteString l
                1 -> return SInvalidate
                2 -> return SSuccess
                3 -> return SFailure
                -- Fail the parse instead of crashing on corrupt/foreign input.
                _ -> fail $ "StreamItem: unknown constructor tag " ++ show cons

-- | The core data type for a Stream: the handle that frames are written to,
-- paired with the queue of frames still waiting to be written by the
-- background writer. Intended to be obtained through `withStream`.
data Stream
    = Stream Handle (TChan StreamItem)

-- | Outcome of a stream transaction, returned by the function passed to
-- `withStream`. 'Success' closes the stream normally and 'Failure' aborts it.
-- The library does nothing with the attached value except hand it back to
-- the caller of `withStream`. Use @Result ()@ when there is nothing to
-- return.
data Result a
    = Success a
    | Failure a
    deriving (Read, Show)

-- | Number of bytes cereal uses to serialize an 'Int'; this is the size of
-- the length header that precedes every 'SBytes' payload.
intSize :: Int
intSize = B.length . encode $ (1 :: Int)

-- | Wraps a handle in a 'Stream' with a fresh, empty frame queue.
openStream :: Handle -> IO Stream
openStream h = Stream h <$> atomically newTChan

-- | Appends a success marker after everything already queued, so the writer
-- sends all pending data first and then terminates the stream.
closeStream :: Stream -> IO ()
closeStream (Stream _ queue) = atomically . writeTChan queue $ SSuccess

-- | Pushes a failure marker to the FRONT of the queue (via 'unGetTChan').
-- The writer therefore sends it next and stops, and any data still queued
-- behind it is never written.
failStream :: Stream -> IO ()
failStream (Stream _ queue) = atomically . unGetTChan queue $ SFailure

-- | Discards every item currently in the channel, as part of the enclosing
-- STM transaction.
clearChan :: TChan a -> STM ()
clearChan chan = go
  where
    go = do
        drained <- isEmptyTChan chan
        if drained
            then return ()
            else readTChan chan >> go

-- | Background writer for a 'Stream'. It repeatedly takes the next queued
-- frame, writes its serialized form to the handle and flushes. It returns
-- after writing an 'SSuccess' or 'SFailure' frame. It blocks while the
-- queue is empty, so run it in its own thread (`withStream` forks it).
streamProcess :: Stream -> IO ()
streamProcess (Stream h chan) = loop
  where
    loop = do
        item <- atomically $ readTChan chan
        B.hPut h $ encode item
        -- Flush every frame so it reaches the peer as soon as it is ready,
        -- instead of waiting in the handle's buffer. The buffering mode
        -- depends on how the caller opened the handle.
        hFlush h
        case item of
            SBytes _    -> loop
            SInvalidate -> loop
            SSuccess    -> return ()
            SFailure    -> return ()

-- | Opens a stream on the given handle, starts a background writer for it,
-- and passes the stream to the function.
--
-- When the function returns 'Success' the stream is closed normally. When it
-- returns 'Failure' the stream is aborted, and data still queued is dropped.
-- In both cases this waits until the writer has sent the terminating frame,
-- so the handle can safely be closed as soon as 'withStream' returns. The
-- value attached to the 'Result' is returned. If the writer thread failed
-- (for example an IO error on the handle), that exception is rethrown here.
--
-- If the function itself throws, the stream is marked as failed (so the
-- peer is not left waiting) and the exception propagates.
withStream :: Handle -> (Stream -> IO (Result a)) -> IO a
withStream h f = do
    s <- openStream h
    done <- newEmptyMVar
    _ <- forkIO $ do
        outcome <- try (streamProcess s)
        putMVar done (outcome :: Either SomeException ())
    -- Wait for the writer and surface any exception it raised.
    let awaitWriter = takeMVar done >>= either throwIO return
    res <- f s `onException` failStream s
    case res of
        Success r -> closeStream s >> awaitWriter >> return r
        Failure r -> failStream s >> awaitWriter >> return r

-- | Does not fail the stream. Instead it drops every frame still waiting in
-- the queue and enqueues an invalidate marker, in one atomic step. The
-- receiver then discards everything it has accumulated so far.
invalidate :: Stream -> IO ()
invalidate (Stream _ queue) =
    atomically $ clearChan queue >> writeTChan queue SInvalidate

-- | Queues a chunk of data (partial or complete) to be sent over the
-- 'Stream'. The background writer sends it after the chunks queued before
-- it.
write :: Stream -> B.ByteString -> IO ()
write (Stream _ queue) = atomically . writeTChan queue . SBytes

-- | Serializes a value and sends it as a single chunk over a newly created
-- 'Stream', which is then closed successfully.
send :: Serialize a => Handle -> a -> IO ()
send h x = withStream h sendOne
  where
    sendOne stream = write stream (encode x) >> return (Success ())

-- | Reads a length-prefixed payload: an 'Int' header of 'intSize' bytes,
-- then that many bytes. Returns 'Nothing' if the header cannot be decoded,
-- the length is negative, or the handle ends before the full payload has
-- been read. The data comes off the wire, so it is not trusted.
readBytes :: Handle -> IO (Maybe B.ByteString)
readBytes h = do
    sizeB <- B.hGet h intSize
    case decode sizeB :: Either String Int of
        Right size | size >= 0 -> do
            bytes <- B.hGet h size
            -- hGet returns fewer bytes at EOF; treat a truncated payload
            -- as a broken stream instead of returning partial data.
            return $ if B.length bytes == size then Just bytes else Nothing
        _ -> return Nothing

-- | Reads one frame from the handle: the tag byte, then the length-prefixed
-- payload for 'SBytes'. Returns 'Nothing' on EOF, on an undecodable or
-- truncated frame, or on an unknown tag.
readStreamItem :: Handle -> IO (Maybe StreamItem)
readStreamItem h = do
    codeB <- B.hGet h 1
    case decode codeB :: Either String Word8 of
        Left _  -> return Nothing
        Right 0 -> fmap SBytes <$> readBytes h
        Right 1 -> return $ Just SInvalidate
        Right 2 -> return $ Just SSuccess
        Right 3 -> return $ Just SFailure
        -- Unknown tag: corrupt or foreign data. Report it instead of
        -- crashing on a non-exhaustive match.
        Right _ -> return Nothing
-- | Receives one complete message sent through a `Stream` and returns the
-- concatenation of its chunks. Returns 'Nothing' if the sender failed the
-- stream or a frame could not be read.
receive :: Handle -> IO (Maybe B.ByteString)
receive h = receiveLoop h B.empty

-- | Reads frames until a terminator arrives, starting with the given bytes
-- as already-received data. An invalidate frame discards everything
-- accumulated so far, including the initial bytes. A success frame returns
-- the accumulated data. A failure frame or an unreadable frame returns
-- 'Nothing'.
receiveLoop :: Handle -> B.ByteString -> IO (Maybe B.ByteString)
receiveLoop h initial = go [initial]
  where
    -- Chunks are kept newest-first and concatenated once at the end. This
    -- avoids the quadratic cost of appending to a strict ByteString for
    -- every chunk.
    go acc = do
        itemM <- readStreamItem h
        case itemM of
            Nothing             -> return Nothing
            Just (SBytes chunk) -> go (chunk : acc)
            Just SInvalidate    -> go []
            Just SSuccess       -> return . Just . B.concat $ reverse acc
            Just SFailure       -> return Nothing

----- Enumerator-based API

-- | Control-flow signals raised inside the enumerator: the sender
-- invalidated the stream, the sender reported failure, or a frame could not
-- be decoded.
data StreamEnumException
    = InvalidateException
    | FailureException
    | DecodeException
    deriving (Read, Show, Typeable)

instance Exception StreamEnumException

-- | Enumerator-based version of `receive` that lets the client fold over the
-- data as it is being received. Each `B.ByteString` fed to the iteratee is
-- a single chunk sent with `write`.
--
-- If the sender invalidates the stream, the iteratee is restarted from its
-- initial state on the frames that follow. Any IO it already performed for
-- the discarded chunks is not undone. Be careful with IO when an
-- invalidation is possible; it is safest to use the iteratee purely, to
-- build up a result as the bytes stream in.
--
-- Returns @Just@ the iteratee's result on success. Returns @Nothing@ if the
-- sender reported failure, or if 'E.run' surfaced any error that is not a
-- 'StreamEnumException'. A frame that cannot be decoded raises an 'error'.
receiveE :: MonadIO m
         => Handle
         -> E.Iteratee B.ByteString m b
         -> m (Maybe b)
receiveE h iter = do
    res <- E.run $ streamEnum h E.$$ iter
    case res of
        Right val -> return $ Just val
        Left exception -> case fromException exception of
            Just InvalidateException -> receiveE h iter
            Just FailureException    -> return Nothing
            Just DecodeException     -> error "DecodeException"
            Nothing                  -> return Nothing

-- | An iteratee that immediately stops with the given exception as its
-- error.
returnError :: (Exception e, Monad m) => e -> E.Iteratee B.ByteString m b
returnError = E.returnI . E.Error . toException

-- | Enumerator that reads frames from the handle and feeds each payload
-- chunk to the iteratee. On a success frame it hands the still-continuing
-- step back, so 'E.run' can send EOF. Invalidate, failure and undecodable
-- frames stop the enumeration with the matching 'StreamEnumException'. If
-- the iteratee has already yielded or errored, its step is passed through
-- unchanged.
streamEnum :: MonadIO m => Handle -> E.Enumerator B.ByteString m b
streamEnum h = go
  where
    go (E.Continue k) = liftIO (readStreamItem h) >>= dispatch k
    go step           = E.returnI step

    dispatch _ Nothing              = returnError DecodeException
    dispatch k (Just (SBytes chunk)) = k (E.Chunks [chunk]) E.>>== go
    dispatch _ (Just SInvalidate)   = returnError InvalidateException
    dispatch k (Just SSuccess)      = E.continue k
    dispatch _ (Just SFailure)      = returnError FailureException