-- | Server-side support for the Engine.IO protocol: session management over
-- HTTP long-polling and WebSocket transports, together with the encoders and
-- parsers for Engine.IO packets and payloads.
module Network.EngineIO
  ( -- * Running an Engine.IO server
    initialize
  , handler
  , EngineIO
  , ServerAPI (..)
  , SocketApp(..)

    -- * Interacting with sockets
  , send
  , receive
  , Socket
  , SocketId
  , socketId
  , getOpenSockets
  , dupRawReader

    -- * The Engine.IO protocol
    -- ** Packets
  , Packet(..)
  , parsePacket
  , encodePacket
  , PacketType

    -- ** Packet contents
  , PacketContent(..)

    -- ** Payloads
  , Payload(..)
  , parsePayload
  , encodePayload

    -- ** Transports
  , TransportType(..)
  , parseTransportType
  ) where
import Prelude hiding (any)
import Control.Applicative
import Control.Concurrent.MVar (MVar, newMVar, withMVar)
import Control.Exception (SomeException(SomeException), try)
import Control.Monad (MonadPlus, forever, guard, mzero, replicateM)
import Control.Monad.Trans.Iter (cutoff, delay, retract)
import Control.Monad.IO.Class (MonadIO, liftIO)
import Control.Monad.Loops (unfoldM)
import Control.Monad.Trans.Class (lift)
import Control.Monad.Trans.Either (eitherT, left)
import Control.Monad.Trans.Maybe (runMaybeT)
import Data.Aeson ((.=))
import Data.Char (digitToInt, intToDigit)
import Data.Foldable (asum, for_)
import Data.Function (fix, on)
import Data.Ix (inRange)
import Data.List (foldl')
import Data.Monoid ((<>), mconcat, mempty)
import Data.Ord (comparing)
import Data.Traversable (for)
import qualified Control.Concurrent.Async as Async
import qualified Control.Concurrent.STM as STM
import qualified Control.Concurrent.STM.Delay as STMDelay
import qualified Data.Aeson as Aeson
import qualified Data.Attoparsec.ByteString as Attoparsec
import qualified Data.Attoparsec.ByteString.Char8 as AttoparsecC8
import qualified Data.ByteString as BS
import qualified Data.ByteString.Base64 as Base64
import qualified Data.ByteString.Builder as Builder
import qualified Data.ByteString.Char8 as BSChar8
import qualified Data.ByteString.Lazy as LBS
import qualified Data.HashMap.Strict as HashMap
import qualified Data.Text as Text
import qualified Data.Text.Encoding as Text
import qualified Data.Vector as V
import qualified Network.WebSockets as WebSockets
import qualified Network.WebSockets.Connection as WebSockets
import qualified System.Random.MWC as Random
-- | The kind of an Engine.IO packet. On the wire each constructor is written
-- as an index from 0 ('Open') to 6 ('Noop'); see 'packetTypeToIndex' and
-- 'packetTypeFromIndex'.
data PacketType = Open | Close | Ping | Pong | Message | Upgrade | Noop
  deriving (Bounded, Enum, Eq, Read, Show)
-- | The wire index of a packet type, from 0 ('Open') to 6 ('Noop').
packetTypeToIndex :: Num i => PacketType -> i
packetTypeToIndex Open = 0
packetTypeToIndex Close = 1
packetTypeToIndex Ping = 2
packetTypeToIndex Pong = 3
packetTypeToIndex Message = 4
packetTypeToIndex Upgrade = 5
packetTypeToIndex Noop = 6
-- | Decode a wire index into a 'PacketType'. Indices outside 0-6 yield
-- 'mzero', so a parser built on this simply fails on them.
packetTypeFromIndex :: (Eq i, MonadPlus m, Num i) => i -> m PacketType
packetTypeFromIndex 0 = return Open
packetTypeFromIndex 1 = return Close
packetTypeFromIndex 2 = return Ping
packetTypeFromIndex 3 = return Pong
packetTypeFromIndex 4 = return Message
packetTypeFromIndex 5 = return Upgrade
packetTypeFromIndex 6 = return Noop
packetTypeFromIndex _ = mzero
-- | A single Engine.IO packet: its type together with its content.
data Packet = Packet !PacketType !PacketContent
  deriving (Eq, Show)
-- | The body of a 'Packet'.
data PacketContent
  = BinaryPacket !BS.ByteString
    -- ^ Arbitrary bytes.
  | TextPacket !Text.Text
    -- ^ Text; 'encodePacket' writes it as UTF-8.
  deriving (Eq, Show)
-- | Parse one packet whose body extends to the end of the input.
parsePacket :: Attoparsec.Parser Packet
parsePacket = parsePacket' restOfInput
  where
  restOfInput = Attoparsec.takeByteString
-- | Parse a packet whose body is read by the given parser. Three encodings
-- are tried, in this order:
--
-- * the letter @b@, an ASCII type digit (0-6) and base64-encoded binary data;
--
-- * a raw type byte (0-6) followed by binary data;
--
-- * an ASCII type digit (0-6) followed by UTF-8 text.
--
-- Invalid base64 or invalid UTF-8 makes the parser fail; since packets come
-- from clients, this must not surface as an exception.
parsePacket' :: Attoparsec.Parser BS.ByteString -> Attoparsec.Parser Packet
parsePacket' body = parseBase64 <|> parseBinary <|> parseText
  where
  parseBase64 = do
    _ <- AttoparsecC8.char 'b'
    Packet <$> c8PacketType
           <*> (either fail (return . BinaryPacket) . Base64.decode =<< body)
  parseBinary =
    Packet <$> (packetTypeFromIndex =<< Attoparsec.satisfy (inRange (0, 6)))
           <*> (BinaryPacket <$> body)
  -- decodeUtf8' rather than the partial decodeUtf8: malformed UTF-8 becomes a
  -- parse failure instead of an exception thrown when the packet is forced.
  parseText =
    Packet <$> c8PacketType
           <*> (either (fail . show) (return . TextPacket) . Text.decodeUtf8' =<< body)
  c8PacketType =
    packetTypeFromIndex . digitToInt =<< AttoparsecC8.satisfy (inRange ('0', '6'))
-- | Serialise a single packet.
--
-- Text packets become the ASCII type digit followed by the UTF-8 text.
-- Binary packets become the raw type byte followed by the bytes when binary
-- is supported. Otherwise they become the letter @b@, the type digit and the
-- base64-encoded bytes.
encodePacket
  :: Bool
  -- ^ Whether the receiver accepts raw binary; when 'False', binary packets
  -- are base64 encoded.
  -> Packet
  -> Builder.Builder
encodePacket supportsBinary (Packet t content) =
  case content of
    TextPacket text ->
      typeDigit <> Builder.byteString (Text.encodeUtf8 text)
    BinaryPacket bytes
      | supportsBinary ->
          Builder.word8 (packetTypeToIndex t) <> Builder.byteString bytes
      | otherwise ->
          Builder.char8 'b' <> typeDigit <> Builder.byteString (Base64.encode bytes)
  where
  typeDigit = Builder.char8 (intToDigit (packetTypeToIndex t))
-- | A batch of packets exchanged in a single polling request or response.
newtype Payload = Payload (V.Vector Packet)
  deriving (Eq, Show)
-- | Parse a binary-framed payload containing at least one packet.
--
-- Each packet is preceded by:
--
-- * a marker byte, 0 or 1 ('encodePayload' writes 1 for binary packets);
--
-- * its length, written as raw decimal digit bytes (each 0-9);
--
-- * a 255 separator byte.
parsePayload :: Attoparsec.Parser Payload
parsePayload = Payload <$> go
  where
  go = do
    _ <- Attoparsec.satisfy (`elem` [0, 1])
    len <- parseLength =<< Attoparsec.many1 (Attoparsec.satisfy (inRange (0, 9)))
    _ <- Attoparsec.word8 maxBound
    -- The length covers the whole encoded packet including its type prefix,
    -- which is two bytes for a "b"-prefixed base64 packet. Isolate exactly
    -- that many bytes and parse them as one packet.
    packetBytes <- Attoparsec.take len
    packet <- either fail return (Attoparsec.parseOnly parsePacket packetBytes)
    (V.singleton packet <$ Attoparsec.endOfInput) <|> (V.cons packet <$> go)
  parseLength digits = do
    guard (length digits <= 319)
    let len = foldl' (\n x -> n * 10 + x) 0 (map fromIntegral digits)
    -- Every packet has at least a type byte. This also rejects long digit
    -- strings that overflowed Int to a non-positive value.
    guard (len >= 1)
    return len
-- | Serialise a payload using binary framing. Each packet is written as:
--
-- * a marker byte (1 for binary packets, 0 otherwise);
--
-- * its encoded length, one raw byte (0-9) per decimal digit;
--
-- * a 255 separator byte;
--
-- * the packet as produced by 'encodePacket'.
--
-- NOTE(review): this framing is used even when @supportsBinary@ is 'False';
-- only the packet bodies change in that case.
encodePayload
  :: Bool
  -- ^ Whether the receiver accepts raw binary; passed on to 'encodePacket'.
  -> Payload
  -> Builder.Builder
encodePayload supportsBinary (Payload packets) = foldMap encodeOne packets
  where
  encodeOne packet =
    mconcat
      [ Builder.word8 (if isBinaryPacket packet then 1 else 0)
      , foldMap (Builder.word8 . fromIntegral . digitToInt) (show (LBS.length encoded))
      , Builder.word8 maxBound
      , Builder.lazyByteString encoded
      ]
    where
    -- Encode once, so the same bytes are both measured and emitted.
    encoded = Builder.toLazyByteString (encodePacket supportsBinary packet)
  isBinaryPacket (Packet _ (BinaryPacket _)) = True
  isBinaryPacket _ = False
-- | The transports a session can be carried over.
data TransportType
  = Polling
    -- ^ HTTP long-polling: @GET@ requests fetch packets and @POST@ requests
    -- deliver them.
  | Websocket
    -- ^ A WebSocket connection, reached by upgrading a polling session.
  deriving (Eq, Show)
-- | Transports are encoded as JSON strings holding their protocol names.
instance Aeson.ToJSON TransportType where
  toJSON Polling = Aeson.String "polling"
  toJSON Websocket = Aeson.String "websocket"
-- | Recognise a transport by its protocol name, as used in the @transport@
-- query parameter. Unknown names yield 'Nothing'.
parseTransportType :: Text.Text -> Maybe TransportType
parseTransportType name =
  lookup name [ ("polling", Polling), ("websocket", Websocket) ]
-- | Identifies a session. Fresh ids are base64-encoded random bytes (see
-- 'newSocketId').
type SocketId = BS.ByteString
-- | The channels connecting a session to whichever transport currently
-- carries it.
data Transport = Transport
  { transIn :: STM.TChan Packet
    -- ^ Packets received from the client.
  , transOut :: STM.TChan Packet
    -- ^ Packets waiting to be sent to the client.
  , transType :: !TransportType
    -- ^ Which kind of transport this is.
  }
-- | A single Engine.IO session.
data Socket = Socket
  { socketId :: !SocketId
    -- ^ The session's id.
  , socketTransport :: STM.TVar Transport
    -- ^ The transport currently carrying the session. It is replaced when
    -- the session is upgraded to a WebSocket.
  , socketIncomingMessages :: STM.TChan PacketContent
    -- ^ Contents of 'Message' packets sent by the client; read with
    -- 'receive'.
  , socketOutgoingMessages :: STM.TChan PacketContent
    -- ^ Messages for the client, written with 'send' and delivered as
    -- 'Message' packets.
  , socketRawIncomingBroadcast :: STM.TChan Packet
    -- ^ Broadcast channel carrying every packet received from the client;
    -- read it through 'dupRawReader'.
  }
-- | Sockets are equal exactly when their ids are equal.
instance Eq Socket where
  a == b = socketId a == socketId b
-- | Sockets are ordered by id.
instance Ord Socket where
  compare a b = compare (socketId a) (socketId b)
-- | Take the next message the client sent on this socket, retrying the
-- transaction until one is available.
receive :: Socket -> STM.STM PacketContent
receive = STM.readTChan . socketIncomingMessages
-- | Queue a message for delivery to the client as a 'Message' packet.
send :: Socket -> PacketContent -> STM.STM ()
send = STM.writeTChan . socketOutgoingMessages
-- | The operations 'handler' needs from the hosting web server. Passing them
-- as a record keeps this module independent of any particular web
-- framework.
data ServerAPI m = ServerAPI
  { srvGetQueryParams :: m (HashMap.HashMap BS.ByteString [BS.ByteString])
    -- ^ The query parameters of the current request, each name mapped to
    -- every value supplied for it.

  , srvTerminateWithResponse :: Int -> BS.ByteString -> Builder.Builder -> forall a . m a
    -- ^ Respond with the given status code, content type and body. The
    -- result type is fully polymorphic, so an implementation is presumably
    -- expected to escape the handler after responding rather than return
    -- normally. Confirm against the hosting framework's binding.

  , srvParseRequestBody :: forall a. Attoparsec.Parser a -> m (Either String a)
    -- ^ Run a parser over the request body; 'Left' signals a parse failure.

  , srvGetRequestMethod :: m BS.ByteString
    -- ^ The HTTP method of the current request. Polling dispatches on
    -- @GET@ and @POST@.

  , srvRunWebSocket :: WebSockets.ServerApp -> m ()
    -- ^ Hand the current request over to a WebSocket application; used to
    -- upgrade a polling session.
  }
-- | Server-wide Engine.IO state, created by 'initialize'.
data EngineIO = EngineIO
  { eioOpenSessions :: STM.TVar (HashMap.HashMap SocketId Socket)
    -- ^ Every currently open session, keyed by its id.
  , eioRng :: MVar Random.GenIO
    -- ^ Random generator used to create session ids. The 'MVar' serialises
    -- access to it.
  }
-- | Create the server state: an empty session table and a random generator
-- seeded via 'Random.createSystemRandom'.
initialize :: IO EngineIO
initialize = do
  sessions <- STM.newTVarIO mempty
  rng <- Random.createSystemRandom >>= newMVar
  return (EngineIO sessions rng)
-- | All currently open sessions, keyed by session id.
getOpenSockets :: EngineIO -> STM.STM (HashMap.HashMap SocketId Socket)
getOpenSockets eio = STM.readTVar (eioOpenSessions eio)
-- | Request errors reported to the client. 'serveError' maps each one to an
-- Engine.IO error code and message.
data EngineIOError = BadRequest | TransportUnknown | SessionIdUnknown
  deriving (Bounded, Enum, Eq, Show)
-- | What to run for a newly opened socket, as returned by the callback given
-- to 'handler'.
data SocketApp = SocketApp
  { saApp :: IO ()
    -- ^ The main action for the socket, run in its own thread once the
    -- session is opened. When it finishes, normally or by exception, the
    -- session is torn down.

  , saOnDisconnect :: IO ()
    -- ^ Run after the session has ended and been removed from the set of
    -- open sockets.
  }
-- | Handle one Engine.IO HTTP request.
--
-- Query parameters:
--
-- * @transport@ (required): @polling@ or @websocket@.
--
-- * @sid@ (optional): the id of an existing session. Without it a new
--   session is opened and the socket callback is run for it.
--
-- * @b64@ (optional): @1@ requests base64-encoded binary packets and @0@
--   raw binary. Raw binary is the default.
--
-- Requests that cannot be served are answered through 'serveError'.
handler :: MonadIO m => EngineIO -> (Socket -> m SocketApp) -> ServerAPI m -> m ()
handler eio socketHandler api@ServerAPI{..} = do
  queryParams <- srvGetQueryParams
  eitherT (serveError api) return $ do
    -- The query string is client-controlled. Decode it with decodeUtf8' so
    -- that invalid UTF-8 yields TransportUnknown rather than an exception.
    reqTransport <- maybe (left TransportUnknown) return $ do
      [t] <- HashMap.lookup "transport" queryParams
      either (const Nothing) parseTransportType (Text.decodeUtf8' t)
    socket <-
      for (HashMap.lookup "sid" queryParams) $ \sids -> do
        sid <- case sids of
                 [sid] -> return sid
                 _ -> left SessionIdUnknown
        mSocket <- liftIO (STM.atomically (HashMap.lookup sid <$> getOpenSockets eio))
        case mSocket of
          Nothing -> left SessionIdUnknown
          Just s -> return s
    supportsBinary <-
      case HashMap.lookup "b64" queryParams of
        Just ["1"] -> return False
        Just ["0"] -> return True
        Nothing    -> return True
        _          -> left BadRequest
    case socket of
      Just s -> do
        transport <- liftIO $ STM.atomically $ STM.readTVar (socketTransport s)
        case transType transport of
          -- A polling session can keep polling or upgrade to WebSocket. A
          -- session that is already on WebSocket accepts no HTTP requests.
          Polling
            | reqTransport == Polling -> lift (handlePoll api transport supportsBinary)
            | reqTransport == Websocket -> lift (upgrade api s)
          _ -> left BadRequest
      Nothing ->
        lift (freshSession eio socketHandler api supportsBinary)
-- | Open a brand-new session and answer the request with its handshake.
--
-- A 'Socket' with a fresh 'Polling' transport is registered in
-- 'eioOpenSessions' under a random id that is not already in use. Three
-- threads are then started for it:
--
-- * the application's 'saApp';
--
-- * a heartbeat, which finishes once 'pingTimeout' seconds pass without the
--   brain handling a packet (incoming or outgoing);
--
-- * the brain, which moves packets between the transport and the socket's
--   message channels, and stops on a 'Close' packet.
--
-- When the first of these finishes, the other two are cancelled. The session
-- is then removed from 'eioOpenSessions' and 'saOnDisconnect' runs. If no
-- unused id is found within the retry budget, a 500 response is sent
-- instead.
freshSession
  :: MonadIO m
  => EngineIO
  -> (Socket -> m SocketApp)
  -> ServerAPI m
  -> Bool
  -- ^ Whether the client accepts raw binary; passed to 'encodePayload'.
  -> m ()
freshSession eio socketHandler api supportsBinary = do
  socket <- do
    -- Create all channels once; only the id varies between allocation
    -- attempts.
    mkSocket <- liftIO $ do
      transport <- STM.newTVarIO =<< (Transport <$> STM.newTChanIO <*> STM.newTChanIO <*> pure Polling)
      incoming <- STM.newTChanIO
      outgoing <- STM.newTChanIO
      rawInBroadcast <- STM.newBroadcastTChanIO
      return (\sId -> Socket sId transport incoming outgoing rawInBroadcast)
    let
      -- Draw a random id and, in a single transaction, register the socket
      -- under it unless the id is already taken (yielding Nothing then).
      tryAllocation = liftIO $ do
        sId <- newSocketId eio
        STM.atomically $ runMaybeT $ do
          openSessions <- lift (STM.readTVar (eioOpenSessions eio))
          guard (not (HashMap.member sId openSessions))
          let socket = mkSocket sId
          lift (STM.modifyTVar' (eioOpenSessions eio) (HashMap.insert sId socket))
          return socket
      -- Repeat f until it yields a value, with a 'delay' step between tries.
      untilSuccess f = maybe (delay (untilSuccess f)) return =<< f
    -- 'cutoff 10' bounds the number of retry steps before giving up.
    maybeSocket <- retract (cutoff 10 (untilSuccess tryAllocation))
    maybe (srvTerminateWithResponse api 500 "text/plain" "Session allocation failed")
          return maybeSocket
  -- NOTE(review): the socket is already registered at this point. If
  -- socketHandler throws, nothing removes it from eioOpenSessions.
  app <- socketHandler socket
  userSpace <- liftIO $ Async.async (saApp app)
  -- Heartbeat: completes when this delay expires. The brain pushes the
  -- deadline back after every packet it handles.
  pingTimeoutDelay <- liftIO $ STMDelay.newDelay (pingTimeout * 1000000)
  heartbeat <- liftIO $ Async.async $
    STM.atomically (STMDelay.waitDelay pingTimeoutDelay)
  brain <- liftIO $ Async.async $ fix $ \loop -> do
    -- Each iteration handles one packet, from whichever side is ready first:
    --   * from the client: deliver Message contents to the application,
    --     answer Ping with a Pong carrying the same content, and publish
    --     every packet on the raw broadcast channel;
    --   * from the application: wrap the content in a Message packet for
    --     the client.
    -- The transport is re-read on every iteration so that a WebSocket
    -- upgrade, which swaps the TVar's contents, takes effect.
    mMessage <- STM.atomically $ do
      transport <- STM.readTVar (socketTransport socket)
      asum
        [ do req <- STM.readTChan (transIn transport)
             case req of
               Packet Message m ->
                 STM.writeTChan (socketIncomingMessages socket) m
               Packet Ping m ->
                 STM.writeTChan (transOut transport) (Packet Pong m)
               _ ->
                 return ()
             STM.writeTChan (socketRawIncomingBroadcast socket) req
             return (Just req)
        , do STM.readTChan (socketOutgoingMessages socket)
               >>= STM.writeTChan (transOut transport) . Packet Message
             return Nothing
        ]
    case mMessage of
      -- A Close packet on the transport's input ends the brain, and with it
      -- the session.
      Just (Packet Close _) ->
        return ()
      _ -> do
        STMDelay.updateDelay pingTimeoutDelay (pingTimeout * 1000000)
        loop
  -- Supervisor: wait for the first of the three threads to end (normally or
  -- by exception), cancel the others, unregister the session and notify the
  -- application.
  _ <- liftIO $ Async.async $ do
    _ <- Async.waitAnyCatchCancel [ userSpace, brain, heartbeat ]
    STM.atomically (STM.modifyTVar' (eioOpenSessions eio) (HashMap.delete (socketId socket)))
    saOnDisconnect app
  -- Handshake sent in an Open packet. pingTimeout is converted from seconds
  -- to milliseconds; the interval is presumably milliseconds too.
  let openMessage = OpenMessage { omSocketId = socketId socket
                                , omUpgrades = [ Websocket ]
                                , omPingTimeout = pingTimeout * 1000
                                , omPingInterval = 25000
                                }
      payload = Payload $ V.singleton $
                  Packet Open (TextPacket $ Text.decodeUtf8 $ LBS.toStrict $ Aeson.encode openMessage)
  writeBytes api (encodePayload supportsBinary payload)
  where
  -- Seconds without handled packets after which the heartbeat ends the
  -- session.
  pingTimeout = 60
-- | Upgrade an existing polling session to a WebSocket connection.
--
-- The exchange proceeds as follows:
--
-- * The client sends a 'Ping' packet with text @probe@ over the WebSocket
--   and is answered with a 'Pong' @probe@.
--
-- * A 'Noop' is queued on the session's current transport output.
--
-- * The client sends an 'Upgrade' packet with an empty body.
--
-- * The socket's transport is switched to the WebSocket.
--
-- If any step does not match, the WebSocket application returns without
-- upgrading.
upgrade :: MonadIO m => ServerAPI m -> Socket -> m ()
upgrade ServerAPI{..} socket = srvRunWebSocket go
  where
  -- The WebSocket application performing the upgrade.
  go pending = do
    conn <- WebSockets.acceptRequest $
      -- Replace the pending connection's on-accept hook with a no-op.
      -- NOTE(review): presumably to suppress the library's default hook;
      -- confirm the intent.
      pending { WebSockets.pendingOnAccept = (const $ return ()) }
    mWsTransport <- runMaybeT $ do
      Packet Ping (TextPacket "probe") <- lift (receivePacket conn)
      lift (sendPacket conn (Packet Pong (TextPacket "probe")))
      -- Queue a Noop on the current (polling) transport, presumably so that
      -- an outstanding long-poll returns; and reuse its channels for the
      -- WebSocket transport.
      (wsIn, wsOut) <- liftIO $ STM.atomically $ do
        currentTransport <- STM.readTVar (socketTransport socket)
        STM.writeTChan (transOut currentTransport) (Packet Noop (TextPacket Text.empty))
        return (transIn currentTransport, transOut currentTransport)
      Packet Upgrade body <- lift (receivePacket conn)
      guard (body == TextPacket Text.empty || body == BinaryPacket BS.empty)
      return (Transport wsIn wsOut Websocket)
    for_ mWsTransport $ \wsTransport@Transport { transIn = wsIn, transOut = wsOut } -> do
      -- From here on the socket's brain talks to the WebSocket. The channels
      -- are shared with the old polling transport, so queued packets are
      -- kept.
      STM.atomically (STM.writeTVar (socketTransport socket) wsTransport)
      -- Despite its name, this thread writes: it forwards outgoing packets to
      -- the client.
      reader <- Async.async $ forever $ do
        p <- STM.atomically (STM.readTChan wsOut)
        sendPacket conn p
      -- Feed client packets to the brain until receiving fails (e.g. the
      -- connection closes). NOTE(review): catching SomeException also
      -- swallows asynchronous exceptions.
      fix $ \loop -> do
        e <- try (receivePacket conn >>= STM.atomically . STM.writeTChan wsIn)
        case e of
          Left (SomeException _) ->
            return ()
          Right _ -> loop
      Async.cancel reader
      -- Inject a Close so the brain stops and the session is torn down.
      STM.atomically (STM.writeTChan wsIn (Packet Close (TextPacket Text.empty)))
  -- Read the next well-formed packet. Malformed text frames and non-text
  -- messages are logged to stdout and skipped.
  -- NOTE(review): binary frames therefore never reach the session.
  receivePacket conn = do
    msg <- WebSockets.receiveDataMessage conn
    case msg of
      WebSockets.Text bytes ->
        case Attoparsec.parseOnly parsePacket (LBS.toStrict bytes)  of
          Left ex -> do
            putStrLn $ "Malformed packet received: " ++ show bytes ++ " (" ++ show ex ++ ")"
            receivePacket conn
          Right p -> return p
      other -> do
        putStrLn $ "Unknown WebSocket message: " ++ show other
        receivePacket conn
  -- Text packets go out as a text frame: the type digit followed by the text.
  sendPacket conn (Packet t (TextPacket text)) =
    WebSockets.sendTextData conn $
      Text.encodeUtf8 $
        Text.pack (pure $ intToDigit (packetTypeToIndex t)) <> text
  -- Binary packets go out as a binary frame: the raw type byte followed by
  -- the bytes.
  sendPacket conn p@(Packet _ (BinaryPacket _)) = do
    WebSockets.sendBinaryData conn (Builder.toLazyByteString (encodePacket True p))
-- | Serve one long-polling request for a session.
--
-- @GET@ waits up to 45 seconds for an outgoing packet, then responds with
-- everything queued at that point. If nothing arrives in time, it responds
-- with a single 'Ping'.
--
-- @POST@ parses the body as a 'Payload' and queues its packets as received
-- from the client.
--
-- Any other method is a 'BadRequest'.
handlePoll :: MonadIO m => ServerAPI m -> Transport -> Bool -> m ()
handlePoll api@ServerAPI{..} transport supportsBinary =
  srvGetRequestMethod >>= dispatch
  where
  dispatch method
    | method == "GET" = poll
    | method == "POST" = post
    | otherwise = serveError api BadRequest

  outbox = transOut transport

  poll = do
    expired <- liftIO $ STM.registerDelay (45 * 1000000)
    packets <- liftIO $ do
      firstPacket <- STM.atomically $
        (Just <$> STM.readTChan outbox)
          <|> (Nothing <$ (STM.readTVar expired >>= STM.check))
      case firstPacket of
        Nothing ->
          return [ Packet Ping (BinaryPacket mempty) ]
        Just p ->
          -- Drain whatever else is already queued, without blocking.
          (p :) <$> unfoldM (STM.atomically (STM.tryReadTChan outbox))
    writeBytes api (encodePayload supportsBinary (Payload (V.fromList packets)))

  post = do
    parsed <- srvParseRequestBody parsePayload
    either rejectBody enqueue parsed

  rejectBody ex = do
    liftIO $ putStrLn $ "WARNING: Parse failure in Network.EngineIO.handlePoll: " ++ show ex
    srvTerminateWithResponse 400 "text/plain" "Empty request body"

  enqueue (Payload packets) =
    liftIO $ STM.atomically (V.mapM_ (STM.writeTChan (transIn transport)) packets)
-- | Finish the request with a 200 response whose body is the given bytes,
-- served as @application/octet-stream@.
writeBytes :: Monad m => ServerAPI m -> Builder.Builder -> m a
writeBytes api builder =
  srvTerminateWithResponse api 200 "application/octet-stream" builder
-- | Generate a candidate session id: 15 random bytes, each in the range
-- 0-63, base64 encoded. Uniqueness is not checked here; the caller must
-- verify it.
newSocketId :: EngineIO -> IO SocketId
newSocketId eio = do
  octets <- withMVar (eioRng eio) $ \gen ->
    replicateM 15 (Random.uniformR (0, 63) gen)
  return (Base64.encode (BS.pack octets))
-- | The handshake sent to the client in the 'Open' packet of a new session.
data OpenMessage = OpenMessage
  { omSocketId :: !SocketId
    -- ^ Id of the new session.
  , omUpgrades :: [TransportType]
    -- ^ Transports the client may upgrade to.
  , omPingTimeout :: !Int
    -- ^ Ping timeout; the session code fills it in milliseconds.
  , omPingInterval :: !Int
    -- ^ Ping interval; presumably milliseconds, like 'omPingTimeout'.
  }
-- | The handshake object: @sid@, @upgrades@, @pingTimeout@ and
-- @pingInterval@.
instance Aeson.ToJSON OpenMessage where
  toJSON msg = Aeson.object
    [ "sid" .= Text.decodeUtf8 (omSocketId msg)
    , "upgrades" .= omUpgrades msg
    , "pingTimeout" .= omPingTimeout msg
    , "pingInterval" .= omPingInterval msg
    ]
-- | Answer with HTTP 400 and a JSON body of the form
-- @{"code": ..., "message": ...}@ describing the error.
serveError :: Monad m => ServerAPI m -> EngineIOError -> m a
serveError api e =
  srvTerminateWithResponse api 400 "application/json" $
    Builder.lazyByteString $ Aeson.encode $
      Aeson.object [ "code" .= code, "message" .= message ]
  where
  (code, message) = describeError e

  -- Engine.IO error code and human-readable message for each error.
  describeError :: EngineIOError -> (Int, Text.Text)
  describeError err =
    case err of
      TransportUnknown -> (0, "Transport unknown")
      SessionIdUnknown -> (1, "Session ID unknown")
      BadRequest -> (3, "Bad request")
-- | Subscribe to every packet the client sends on this socket from now on,
-- including control packets. Returns the action that reads the next one.
dupRawReader :: Socket -> IO (STM.STM Packet)
dupRawReader s =
  STM.readTChan <$> STM.atomically (STM.dupTChan (socketRawIncomingBroadcast s))