module Network.WebSockets.Stream
    ( Stream
    , makeStream
    , makeSocketStream
    , makeEchoStream
    , parse
    , write
    , close
    ) where
import           Control.Concurrent.MVar        (MVar, newEmptyMVar, newMVar,
                                                 putMVar, takeMVar, withMVar)
import           Control.Exception              (onException, throwIO)
import           Control.Monad                  (forM_, when)
import qualified Data.Attoparsec.ByteString     as Atto
import qualified Data.ByteString                as B
import qualified Data.ByteString.Lazy           as BL
import           Data.IORef                     (IORef, atomicModifyIORef,
                                                 newIORef, readIORef,
                                                 writeIORef)
import qualified Network.Socket                 as S
import qualified Network.Socket.ByteString      as SB (recv)
#if !defined(mingw32_HOST_OS)
import qualified Network.Socket.ByteString.Lazy as SBL (sendAll)
#else
import qualified Network.Socket.ByteString      as SB (sendAll)
#endif
import           Network.WebSockets.Types
-- | Internal state of a 'Stream': whether it is still open, together with any
-- bytes that were received but have not yet been consumed by a parser.
data StreamState
    = Closed !B.ByteString  -- ^ Stream is closed; holds leftover unparsed input
    | Open   !B.ByteString  -- ^ Stream is open; holds buffered unparsed input
-- | Lightweight abstraction over a bidirectional byte stream.
data Stream = Stream
    { streamIn    :: IO (Maybe B.ByteString)
      -- ^ Receive the next chunk of input; 'Nothing' means end of input.
    , streamOut   :: (Maybe BL.ByteString -> IO ())
      -- ^ Send data ('Just'), or request that the stream be closed ('Nothing').
    , streamState :: !(IORef StreamState)
      -- ^ Open/closed marker plus the input that has not been parsed yet.
    }
-- | Create a 'Stream' from a raw receive action and a raw send action.
--
-- The returned stream wraps both actions so that:
--
-- * concurrent receives are serialised by one lock and concurrent sends by
--   another;
--
-- * receiving, or sending data, on a stream that has been marked closed
--   throws 'ConnectionClosed';
--
-- * the stream is marked closed when the receive action yields 'Nothing',
--   when a close is requested, or when either underlying action throws (the
--   exception is re-thrown);
--
-- * a close request (@send Nothing@) is /always/ forwarded to the underlying
--   send action, even if the stream is already marked closed. The supplied
--   send action may therefore see 'Nothing' more than once and should
--   tolerate that.
makeStream
    :: IO (Maybe B.ByteString)         -- ^ Reading; 'Nothing' signals end of input
    -> (Maybe BL.ByteString -> IO ())  -- ^ Writing; 'Nothing' requests a close
    -> IO Stream                       -- ^ Resulting stream
makeStream receive send = do
    ref         <- newIORef (Open B.empty)
    receiveLock <- newMVar ()
    sendLock    <- newMVar ()
    return $ Stream (receive' ref receiveLock) (send' ref sendLock) ref
  where
    -- Mark the stream as closed, keeping whatever input is still buffered.
    closeRef :: IORef StreamState -> IO ()
    closeRef ref = atomicModifyIORef ref $ \state -> case state of
        Open   buf -> (Closed buf, ())
        Closed buf -> (Closed buf, ())

    isClosed :: StreamState -> Bool
    isClosed (Closed _) = True
    isClosed (Open   _) = False

    -- Throw 'ConnectionClosed' unless the stream is still open.
    assertOpen :: IORef StreamState -> IO ()
    assertOpen ref = do
        state <- readIORef ref
        when (isClosed state) $ throwIO ConnectionClosed

    receive' :: IORef StreamState -> MVar () -> IO (Maybe B.ByteString)
    receive' ref lock = withMVar lock $ \() -> do
        assertOpen ref
        mbBs <- onException receive (closeRef ref)
        case mbBs of
            Nothing -> closeRef ref >> return Nothing
            Just bs -> return (Just bs)

    send' :: IORef StreamState -> MVar () -> (Maybe BL.ByteString -> IO ())
    send' ref lock mbBs = withMVar lock $ \() -> do
        case mbBs of
            -- Closing must not fail just because the peer hung up first or
            -- an earlier operation failed: the underlying transport still
            -- has to be told to shut down.
            Nothing -> closeRef ref
            -- Writing data requires an open stream.
            Just _  -> assertOpen ref
        onException (send mbBs) (closeRef ref)
-- | Build a 'Stream' on top of a network socket.
--
-- Input is read in chunks of at most 1024 bytes; an empty read is reported as
-- end of input. A close request is a no-op here: the socket itself is left
-- untouched by this function.
makeSocketStream :: S.Socket -> IO Stream
makeSocketStream socket = makeStream readChunk writeChunks
  where
    readChunk :: IO (Maybe B.ByteString)
    readChunk = do
        chunk <- SB.recv socket 1024
        if B.null chunk
            then return Nothing
            else return (Just chunk)

    -- 'Nothing' (close request) does nothing; 'Just' data is sent in full.
    writeChunks :: Maybe BL.ByteString -> IO ()
    writeChunks = maybe (return ()) sendLazy

    sendLazy :: BL.ByteString -> IO ()
#if !defined(mingw32_HOST_OS)
    sendLazy = SBL.sendAll socket
#else
    sendLazy lbs = forM_ (BL.toChunks lbs) (SB.sendAll socket)
#endif
-- | Build a 'Stream' that echoes: everything written to it becomes available
-- for reading, chunk by chunk, and a close request is read back as end of
-- input.
--
-- The hand-off goes through a single 'MVar', so each write of a chunk blocks
-- until the previously written chunk has been taken by a reader.
makeEchoStream :: IO Stream
makeEchoStream = do
    channel <- newEmptyMVar
    let push            = putMVar channel
        echo Nothing    = push Nothing
        echo (Just lbs) = forM_ (BL.toChunks lbs) (push . Just)
    makeStream (takeMVar channel) echo
-- | Run a parser against the stream's input.
--
-- Parsing starts with whatever input is buffered in the stream state and
-- pulls further chunks from 'streamIn' while the parser asks for more. On
-- success the unconsumed input is stored back into the stream state and the
-- result is returned in 'Just'.
--
-- Returns 'Nothing' when the stream is closed with no leftover input, or when
-- the first read on an empty buffer reports end of input. Throws
-- 'ParseException' when the parser fails.
parse :: Stream -> Atto.Parser a -> IO (Maybe a)
parse stream parser = readIORef stateRef >>= start
  where
    stateRef = streamState stream

    -- Pick the initial input depending on the current stream state.
    start (Closed leftover)
        | B.null leftover = return Nothing
        | otherwise       = drive True (Atto.parse parser leftover)
    start (Open buffered)
        | B.null buffered = do
            mbChunk <- streamIn stream
            case mbChunk of
                Nothing    -> do
                    writeIORef stateRef (Closed B.empty)
                    return Nothing
                Just chunk -> drive False (Atto.parse parser chunk)
        | otherwise       = drive False (Atto.parse parser buffered)

    -- Advance the parser until it finishes or fails. The flag records
    -- whether end of input has been reached; once it has, the parser is fed
    -- an empty chunk instead of reading from the stream again.
    drive atEof result = case result of
        Atto.Fail _ _ err    -> throwIO (ParseException err)
        Atto.Done leftover x -> do
            writeIORef stateRef
                (if atEof then Closed leftover else Open leftover)
            return (Just x)
        Atto.Partial continue
            | atEof     -> drive True (continue B.empty)
            | otherwise -> do
                mbChunk <- streamIn stream
                case mbChunk of
                    Nothing    -> drive True (continue B.empty)
                    Just chunk -> drive False (continue chunk)
-- | Send a lazy 'BL.ByteString' through the stream's output action.
write :: Stream -> BL.ByteString -> IO ()
write stream payload = streamOut stream (Just payload)
-- | Request that the stream be closed by passing 'Nothing' to its output
-- action.
close :: Stream -> IO ()
close Stream { streamOut = out } = out Nothing