-- | HTTP/2 session handling.
--
--   A session is created with 'http2Session', which starts an input thread,
--   a headers-output thread and a data-output thread. Frames and commands
--   are handed to the session with 'sendFrameToSession' and
--   'sendCommandToSession'; whatever the session produces is read back with
--   'getFrameFromSession'.
module SecondTransfer.Http2.Session(
http2Session
,getFrameFromSession
,sendFrameToSession
,sendCommandToSession
,defaultSessionsConfig
,sessionId
,reportErrorCallback
,sessionsCallbacks
,nextSessionId
,makeSessionsContext
,sessionsConfig
,sessionExceptionHandler
,CoherentSession
,SessionInput(..)
,SessionInputCommand(..)
,SessionOutput(..)
,SessionOutputCommand(..)
,SessionsContext(..)
,SessionCoordinates(..)
,SessionComponent(..)
,SessionsCallbacks
,SessionsConfig
,ErrorCallback
,OutputFrame
,InputFrame
) where
import Control.Concurrent (ThreadId, forkIO)
import Control.Concurrent.Chan
import Control.Exception (SomeException, throwTo)
import qualified Control.Exception as E
import Control.Monad (forever)
import Control.Monad.IO.Class (liftIO)
import Control.Monad.Trans.Reader
import Control.Concurrent.MVar
import qualified Data.ByteString as B
import qualified Data.ByteString.Builder as Bu
import qualified Data.ByteString.Lazy as Bl
import Data.Conduit
import Data.Conduit.List (foldMapM)
import qualified Data.HashTable.IO as H
import qualified Data.IntSet as NS
import Data.Monoid as Mo
import Control.Lens
import qualified Network.HPACK as HP
import qualified Network.HTTP2 as NH2
import System.Log.Logger
import SecondTransfer.MainLoop.CoherentWorker
import SecondTransfer.MainLoop.Tokens
import SecondTransfer.Utils (unfoldChannelAndSource)
import SecondTransfer.Exception
-- | What the session places on its output channel for every frame it wants
--   to send: the encoding information paired with the frame payload.
type OutputFrame = (NH2.EncodeInfo, NH2.FramePayload)
-- | The kind of frame that is fed into a session through
--   'sendFrameToSession'.
type InputFrame = NH2.Frame
-- | Chunk length handed to 'bytestringChunk' when header blocks and response
--   data are split into HEADERS / CONTINUATION / DATA frame payloads.
useChunkLength :: Int
useChunkLength = 16 * 1024
-- | Token put into a per-stream 'MVar' by the headers-output thread once the
--   headers of that stream have been queued for sending; the data sink of
--   the stream waits for it before writing any data.
data HeadersSent = HeadersSent
-- | Read-only environment of a worker thread (see 'WorkerMonad').
data WorkerThreadEnvironment = WorkerThreadEnvironment
    { -- | Stream the worker is serving. The session fills this in just
      --   before forking the worker.
      _streamId             :: GlobalStreamId
      -- | Where the worker posts the response headers, together with the
      --   'MVar' that is filled once those headers have been sent.
    , _headersOutput        :: Chan (GlobalStreamId, MVar HeadersSent, Headers)
      -- | Where the worker posts pieces of response data.
    , _dataOutput           :: Chan (GlobalStreamId, B.ByteString)
      -- | Identifiers of the streams that have been reset; the same 'MVar'
      --   the session keeps in its own state.
    , _streamsCancelled_WTE :: MVar NS.IntSet
    }

makeLenses ''WorkerThreadEnvironment
-- | The two handles returned by 'http2Session': one to feed the session and
--   one to read what it produces.
type Session = (SessionInput, SessionOutput)
-- | Handle used to feed a session. Commands travel as 'Left' values and
--   frames as 'Right' values over the same channel.
newtype SessionInput = SessionInput (Chan (Either SessionInputCommand InputFrame))

-- | Queues a frame for the session's input thread.
sendFrameToSession :: SessionInput -> InputFrame -> IO ()
sendFrameToSession (SessionInput chan) = writeChan chan . Right

-- | Queues a command for the session's input thread.
sendCommandToSession :: SessionInput -> SessionInputCommand -> IO ()
sendCommandToSession (SessionInput chan) = writeChan chan . Left
-- | Handle used to read what a session produces. Commands arrive as 'Left'
--   values and frames to be sent as 'Right' values.
newtype SessionOutput = SessionOutput (Chan (Either SessionOutputCommand OutputFrame))

-- | Blocks until the session has produced a command or a frame and returns it.
getFrameFromSession :: SessionOutput -> IO (Either SessionOutputCommand OutputFrame)
getFrameFromSession session_output =
    case session_output of
        SessionOutput chan -> readChan chan
-- | The mutable hash table flavour used throughout this module.
type HashTable k v = H.CuckooHashTable k v
-- | Header block fragments accumulated so far, keyed by stream.
type Stream2HeaderBlockFragment = HashTable GlobalStreamId Bu.Builder
-- | Monad in which worker threads run: 'IO' with read access to a
--   'WorkerThreadEnvironment'.
type WorkerMonad = ReaderT WorkerThreadEnvironment IO
-- | Commands that can be delivered to a session with 'sendCommandToSession'.
data SessionInputCommand
    = CancelSession_SIC
      -- ^ Makes the input thread interrupt every registered worker thread
      --   and then stop reading input.
    deriving (Show)
-- | Commands that can show up, next to frames, on the output side of a
--   session (see 'getFrameFromSession').
data SessionOutputCommand
    = CancelSession_SOC
    deriving (Show)
-- | Identifies a session when an error is reported; currently just the
--   numeric session id.
newtype SessionCoordinates = SessionCoordinates Int
    deriving (Show, Eq)

-- | Lens onto the numeric id carried by a 'SessionCoordinates'.
sessionId :: Functor f => (Int -> f Int) -> SessionCoordinates -> f SessionCoordinates
sessionId f (SessionCoordinates session_id) =
    fmap SessionCoordinates (f session_id)
-- | Names the part of a session in which an error was caught; handed to the
--   'ErrorCallback' together with the exception.
data SessionComponent
    = SessionInputThread_SessionComponent
    | SessionHeadersOutputThread_SessionComponent
    | SessionDataOutputThread_SessionComponent
    | Framer_SessionComponent
    deriving (Show)
-- | Receives the component where the error happened, the coordinates of the
--   session and the exception itself.
type ErrorCallback = (SessionComponent, SessionCoordinates, SomeException) -> IO ()

-- | Callbacks a user of the library can install.
data SessionsCallbacks = SessionsCallbacks
    { -- | Invoked by 'sessionExceptionHandler' when present; when it is
      --   'Nothing' the exception is logged instead.
      _reportErrorCallback :: Maybe ErrorCallback
    }

makeLenses ''SessionsCallbacks
-- | Configuration shared by all the sessions created from one
--   'SessionsContext'.
data SessionsConfig = SessionsConfig
    { _sessionsCallbacks :: SessionsCallbacks
    }

-- | Hand-written lens onto the callbacks stored in a 'SessionsConfig'.
sessionsCallbacks :: Lens' SessionsConfig SessionsCallbacks
sessionsCallbacks f (SessionsConfig current_callbacks) =
    fmap SessionsConfig (f current_callbacks)
-- | State shared among sessions; built with 'makeSessionsContext'.
data SessionsContext = SessionsContext
    { -- | Configuration the sessions were created with.
      _sessionsConfig :: SessionsConfig
      -- | Counter holding the next session id.
    , _nextSessionId  :: MVar Int
    }

makeLenses ''SessionsContext
-- | Builds a session from the shared sessions context.
type SessionMaker = SessionsContext -> IO Session
-- | A 'SessionMaker' that still needs the 'CoherentWorker' that will serve
--   the requests.
type CoherentSession = CoherentWorker -> SessionMaker
-- | Configuration without any callback installed: errors are only logged.
defaultSessionsConfig :: SessionsConfig
defaultSessionsConfig = SessionsConfig { _sessionsCallbacks = no_callbacks }
  where
    no_callbacks = SessionsCallbacks { _reportErrorCallback = Nothing }
-- | Creates the shared context for a family of sessions. The session id
--   counter starts at 1.
makeSessionsContext :: SessionsConfig -> IO SessionsContext
makeSessionsContext sessions_config =
    fmap with_counter (newMVar 1)
  where
    with_counter id_counter = SessionsContext
        { _nextSessionId  = id_counter
        , _sessionsConfig = sessions_config
        }
-- | Pairs the channel into which the session writes the request body of a
--   stream ('Just' for a piece of data, 'Nothing' once the stream has ended)
--   with the data stream that is handed to the worker for that request.
data PostInputMechanism = PostInputMechanism (Chan (Maybe B.ByteString), InputDataStream)
-- | Everything the threads of one session share.
data SessionData = SessionData
    { -- | Context common to all sessions.
      _sessionsContext            :: SessionsContext
      -- | Commands and frames waiting to be processed by the input thread.
    , _sessionInput               :: Chan (Either SessionInputCommand InputFrame)
      -- | Output channel. Writers take the 'MVar' while they write, so the
      --   frames of one writer are not interleaved with those of another.
    , _sessionOutput              :: MVar (Chan (Either SessionOutputCommand OutputFrame))
      -- | HPACK dynamic table used when encoding response headers.
    , _toEncodeHeaders            :: MVar HP.DynamicTable
      -- | HPACK dynamic table used when decoding request headers.
    , _toDecodeHeaders            :: MVar HP.DynamicTable
      -- | Header block fragments received so far, per stream.
    , _stream2HeaderBlockFragment :: Stream2HeaderBlockFragment
      -- | Template of the environment given to worker threads; the stream
      --   id is set for each worker before it is forked.
    , _forWorkerThread            :: WorkerThreadEnvironment
      -- | Produces the response for each request.
    , _coherentWorker             :: CoherentWorker
      -- | Identifiers of the streams for which RST_STREAM was received.
    , _streamsCancelled           :: MVar NS.IntSet
      -- | Request-body plumbing of the streams that carry a body.
    , _stream2PostInputMechanism  :: HashTable Int PostInputMechanism
      -- | Worker thread started for each stream.
    , _stream2WorkerThread        :: HashTable Int ThreadId
      -- | Numeric id of this session.
    , _sessionIdAtSession         :: Int
    }

makeLenses ''SessionData
-- | Creates a session and starts its three threads (input, headers output
--   and data output).
--
--   Each thread runs under a handler that passes any
--   'HTTP2SessionException' to 'sessionExceptionHandler', tagged with the
--   component it came from. The result is the pair of handles used to feed
--   the session and to read its output.
http2Session :: CoherentWorker -> Int -> SessionsContext -> IO Session
http2Session coherent_worker session_id sessions_context = do
    -- Channels shared with the caller.
    input_chan       <- newChan
    output_chan      <- newChan
    output_chan_mvar <- newMVar output_chan

    header_fragments <- H.new :: IO Stream2HeaderBlockFragment

    -- HPACK state, one table for each direction.
    decode_table_mvar <- newMVar =<< HP.newDynamicTableForDecoding 4096
    encode_table_mvar <- newMVar =<< HP.newDynamicTableForEncoding 4096

    -- Channels through which the workers hand over their output.
    headers_chan <- newChan :: IO (Chan (GlobalStreamId, MVar HeadersSent, Headers))
    data_chan    <- newChan :: IO (Chan (GlobalStreamId, B.ByteString))

    post_input_mechanisms <- H.new
    worker_threads        <- H.new

    cancelled_mvar <- newMVar NS.empty

    let worker_env = WorkerThreadEnvironment
            { -- Replaced with the real stream id before a worker is forked.
              _streamId             = error "NotInitialized"
            , _headersOutput        = headers_chan
            , _dataOutput           = data_chan
            , _streamsCancelled_WTE = cancelled_mvar
            }

        session_data = SessionData
            { _sessionsContext            = sessions_context
            , _sessionInput               = input_chan
            , _sessionOutput              = output_chan_mvar
            , _toDecodeHeaders            = decode_table_mvar
            , _toEncodeHeaders            = encode_table_mvar
            , _stream2HeaderBlockFragment = header_fragments
            , _forWorkerThread            = worker_env
            , _coherentWorker             = coherent_worker
            , _streamsCancelled           = cancelled_mvar
            , _stream2PostInputMechanism  = post_input_mechanisms
            , _stream2WorkerThread        = worker_threads
            , _sessionIdAtSession         = session_id
            }

        report :: SessionComponent -> HTTP2SessionException -> IO ()
        report component =
            sessionExceptionHandler component session_id sessions_context

        guarded :: SessionComponent -> IO () -> IO ()
        guarded component = E.handle (report component)

    _ <- forkIO $ guarded SessionInputThread_SessionComponent $
        runReaderT sessionInputThread session_data

    _ <- forkIO $ guarded SessionHeadersOutputThread_SessionComponent $
        runReaderT (headersOutputThread headers_chan output_chan_mvar) session_data

    _ <- forkIO $ guarded SessionDataOutputThread_SessionComponent $
        dataOutputThread data_chan output_chan_mvar

    return (SessionInput input_chan, SessionOutput output_chan)
-- | Main loop of a session: reads one command or frame from the session
--   input, reacts to it and, except for 'CancelSession_SIC', starts over.
--
--   * 'CancelSession_SIC': every registered worker thread is interrupted
--     with 'StreamCancelledException' and the loop ends.
--
--   * HEADERS / CONTINUATION: the fragment is accumulated; once the header
--     block is complete it is decoded and a worker thread is forked for the
--     stream.
--
--   * RST_STREAM: the stream is recorded as cancelled and its worker, if
--     there is one, is interrupted.
--
--   * DATA: the bytes are forwarded to the worker of the stream and the
--     flow-control windows are replenished.
--
--   * PING and SETTINGS: acknowledged, unless they are acknowledgements
--     themselves.
--
--   * Anything else is logged as a problematic frame.
sessionInputThread :: ReaderT SessionData IO ()
sessionInputThread = do
    liftIO $ debugM "HTTP2.Session" "Entering sessionInputThread"

    session_input             <- view sessionInput
    decode_headers_table_mvar <- view toDecodeHeaders
    stream_request_headers    <- view stream2HeaderBlockFragment
    cancelled_streams_mvar    <- view streamsCancelled
    coherent_worker           <- view coherentWorker
    for_worker_thread_uns     <- view forWorkerThread
    stream2workerthread       <- view stream2WorkerThread

    input <- liftIO $ readChan session_input
    liftIO $ debugM "HTTP2.Session" $ "Got a frame or a command: " ++ (show input)

    case input of

        -- No `continue` here: cancelling the session ends this loop.
        Left CancelSession_SIC ->
            liftIO $ H.mapM_
                (\ (_, thread_id) -> do
                    throwTo thread_id StreamCancelledException
                    infoM "HTTP2.Session" $ "Stream successfully interrupted"
                )
                stream2workerthread

        Right frame | Just (stream_id, bytes) <- frameIsHeaderOfStream frame -> do
            appendHeaderFragmentBlock stream_id bytes
            if frameEndsHeaders frame
              then do
                let for_worker_thread = set streamId stream_id for_worker_thread_uns
                headers_bytes <- getHeaderBytes stream_id
                -- modifyMVar puts the old table back if decoding throws, so
                -- the MVar is never left empty.
                header_list <- liftIO $
                    modifyMVar decode_headers_table_mvar $ \ dyn_table ->
                        HP.decodeHeader dyn_table headers_bytes
                liftIO $ H.delete stream_request_headers stream_id
                -- A request body is only expected when the header block
                -- did not close the stream.
                post_data_source <- if frameEndsStream frame
                  then
                    return Nothing
                  else do
                    mechanism <- createMechanismForStream stream_id
                    return $ Just (postDataSourceFromMechanism mechanism)
                liftIO $ do
                    thread_id <- forkIO $ runReaderT
                        (workerThread (header_list, post_data_source) coherent_worker)
                        for_worker_thread
                    H.insert stream2workerthread stream_id thread_id
              else
                return ()
            continue

        Right frame@(NH2.Frame _ (NH2.RSTStreamFrame error_code_id)) -> do
            let stream_id = streamIdFromFrame frame
            liftIO $ do
                infoM "HTTP2.Session" $ "Stream reset: " ++ (show error_code_id)
                infoM "HTTP2.Session" $ "Cancelled stream was: " ++ (show stream_id)
                modifyMVar_ cancelled_streams_mvar $ return . NS.insert stream_id
                maybe_thread_id <- H.lookup stream2workerthread stream_id
                case maybe_thread_id of
                    -- The peer decides which stream it resets, so a stream
                    -- without a worker must not bring down the input
                    -- thread; the event is only logged.
                    Nothing ->
                        warningM "HTTP2.Session" $
                            "Reset received for a stream without worker thread: "
                            ++ (show stream_id)
                    Just thread_id -> do
                        throwTo thread_id StreamCancelledException
                        infoM "HTTP2.Session" $ "Stream successfully interrupted"
            continue

        Right frame@(NH2.Frame (NH2.FrameHeader _ _ nh2_stream_id) (NH2.DataFrame somebytes)) -> do
            let stream_id      = NH2.fromStreamIdentifier nh2_stream_id
                received_count = B.length somebytes
            streamWorkerSendData stream_id somebytes
            -- A WINDOW_UPDATE with an increment of 0 is a protocol error, so
            -- an empty DATA frame (typically one that only carries
            -- END_STREAM) must not trigger any window update.
            if received_count > 0
              then do
                -- Stream-level window first ...
                sendOutFrame
                    (NH2.EncodeInfo NH2.defaultFlags nh2_stream_id Nothing)
                    (NH2.WindowUpdateFrame (fromIntegral received_count))
                -- ... then the connection-level one (stream 0).
                sendOutFrame
                    (NH2.EncodeInfo NH2.defaultFlags (NH2.toStreamIdentifier 0) Nothing)
                    (NH2.WindowUpdateFrame (fromIntegral received_count))
              else
                return ()
            if frameEndsStream frame
              then closePostDataSource stream_id
              else return ()
            continue

        -- Acknowledgement of a ping: nothing to do.
        Right (NH2.Frame (NH2.FrameHeader _ flags _) (NH2.PingFrame _)) | NH2.testAck flags ->
            continue

        Right (NH2.Frame (NH2.FrameHeader _ _ _) (NH2.PingFrame somebytes)) -> do
            liftIO $ debugM "HTTP2.Session" "Ping processed"
            sendOutFrame
                (NH2.EncodeInfo
                    (NH2.setAck NH2.defaultFlags)
                    (NH2.toStreamIdentifier 0)
                    Nothing
                )
                (NH2.PingFrame somebytes)
            continue

        -- Acknowledgement of settings: nothing to do.
        Right (NH2.Frame frame_header (NH2.SettingsFrame _)) | isSettingsAck frame_header ->
            continue

        Right (NH2.Frame _ (NH2.SettingsFrame settings_list)) -> do
            liftIO $ debugM "HTTP2.Session" $ "Received settings: " ++ (show settings_list)
            sendOutFrame
                (NH2.EncodeInfo
                    (NH2.setAck NH2.defaultFlags)
                    (NH2.toStreamIdentifier 0)
                    Nothing
                )
                (NH2.SettingsFrame [])
            continue

        Right somethingelse -> do
            liftIO $ errorM "HTTP2.Session" $ "Received problematic frame: "
            liftIO $ errorM "HTTP2.Session" $ ".. " ++ (show somethingelse)
            continue

  where
    continue = sessionInputThread
-- | Queues one frame on the session output.
--
--   The 'MVar' around the output channel acts as a lock shared with the
--   output threads. 'withMVar' guarantees that the lock is given back even
--   if the write is interrupted by an exception.
sendOutFrame :: NH2.EncodeInfo -> NH2.FramePayload -> ReaderT SessionData IO ()
sendOutFrame encode_info payload = do
    session_output_mvar <- view sessionOutput
    liftIO $ withMVar session_output_mvar $ \ session_output ->
        writeChan session_output $ Right (encode_info, payload)
-- | Tells whether the END_STREAM flag is set in the header of the frame.
frameEndsStream :: InputFrame -> Bool
frameEndsStream (NH2.Frame frame_header _) =
    case frame_header of
        NH2.FrameHeader _ flags _ -> NH2.testEndStream flags
-- | Creates the channel / data-stream pair through which the request body
--   of a stream reaches its worker, registers it under the stream id and
--   returns it.
createMechanismForStream :: GlobalStreamId -> ReaderT SessionData IO PostInputMechanism
createMechanismForStream stream_id = do
    (chan, source) <- liftIO unfoldChannelAndSource
    mechanisms     <- view stream2PostInputMechanism
    liftIO $ do
        let new_mechanism = PostInputMechanism (chan, source)
        H.insert mechanisms stream_id new_mechanism
        return new_mechanism
-- | Signals the end of the request body of a stream by writing 'Nothing'
--   into its channel. Fails with an internal error when nothing is
--   registered for the stream.
closePostDataSource :: GlobalStreamId -> ReaderT SessionData IO ()
closePostDataSource stream_id = do
    mechanisms <- view stream2PostInputMechanism
    registered <- liftIO $ H.lookup mechanisms stream_id
    liftIO $ maybe not_registered signal_end registered
  where
    not_registered = error "Internal error/closePostDataSource"
    signal_end (PostInputMechanism (chan, _)) = writeChan chan Nothing
-- | Forwards a piece of request body to the mechanism registered for the
--   stream. Fails with an internal error when nothing is registered for it.
streamWorkerSendData :: Int -> B.ByteString -> ReaderT SessionData IO ()
streamWorkerSendData stream_id bytes = do
    mechanisms <- view stream2PostInputMechanism
    registered <- liftIO $ H.lookup mechanisms stream_id
    maybe not_registered (\ pim -> sendBytesToPim pim bytes) registered
  where
    not_registered = error "Internal error"
-- | Writes a piece of data, wrapped in 'Just', into the channel of the
--   mechanism.
sendBytesToPim :: PostInputMechanism -> B.ByteString -> ReaderT SessionData IO ()
sendBytesToPim (PostInputMechanism (chan, _)) =
    liftIO . writeChan chan . Just
-- | Extracts the data stream half of a mechanism.
postDataSourceFromMechanism :: PostInputMechanism -> InputDataStream
postDataSourceFromMechanism (PostInputMechanism chan_and_source) =
    snd chan_and_source
-- | Tells whether the ACK flag is set in a frame header.
isSettingsAck :: NH2.FrameHeader -> Bool
isSettingsAck frame_header =
    case frame_header of
        NH2.FrameHeader _ flags _ -> NH2.testAck flags
-- | Checks, without blocking other readers of the set, whether the given
--   stream is currently recorded as cancelled.
isStreamCancelled :: GlobalStreamId -> WorkerMonad Bool
isStreamCancelled stream_id =
    view streamsCancelled_WTE >>= \ cancelled_mvar ->
        liftIO $ fmap (NS.member stream_id) (readMVar cancelled_mvar)
-- | Body of the thread that serves one request.
--
--   Runs the coherent worker on the request, posts the resulting headers to
--   the headers-output thread and then, unless the stream has been
--   cancelled in the meantime, drains the response body into
--   'sendDataOfStream'.
workerThread :: Request -> CoherentWorker -> WorkerMonad ()
workerThread req coherent_worker = do
    headers_chan <- view headersOutput
    my_stream    <- view streamId

    (headers, _, body_source) <- liftIO (coherent_worker req)

    -- Filled by the headers-output thread once the headers are on their way.
    headers_sent <- liftIO newEmptyMVar
    liftIO $ writeChan headers_chan (my_stream, headers_sent, headers)

    cancelled <- isStreamCancelled my_stream
    if cancelled
      then
        return ()
      else do
        -- Whatever the body source concludes with is currently not used.
        (_footers, _) <- runConduit $
            transPipe liftIO body_source
            `fuseBothMaybe`
            sendDataOfStream my_stream headers_sent
        return ()
-- | Sink for the response body of a stream. It first waits until the
--   headers of the stream have been sent and then posts every piece of data
--   it receives to the data-output thread.
sendDataOfStream :: GlobalStreamId -> MVar HeadersSent -> Sink B.ByteString (ReaderT WorkerThreadEnvironment IO) ()
sendDataOfStream stream_id headers_sent = do
    data_chan <- view dataOutput
    let post_piece piece = writeChan data_chan (stream_id, piece)
    transPipe liftIO $ do
        _ <- liftIO (takeMVar headers_sent)
        foldMapM post_piece
-- | Appends a header block fragment to whatever has been accumulated for
--   the stream so far, starting a new accumulation when there is none.
appendHeaderFragmentBlock :: GlobalStreamId -> B.ByteString -> ReaderT SessionData IO ()
appendHeaderFragmentBlock global_stream_id bytes = do
    fragments <- view stream2HeaderBlockFragment
    liftIO $ do
        accumulated <- H.lookup fragments global_stream_id
        let fresh = Bu.byteString bytes
        H.insert fragments global_stream_id $
            maybe fresh (`mappend` fresh) accumulated
-- | Returns, as one strict 'B.ByteString', the header block accumulated for
--   the stream.
--
--   It is meant to be called after at least one fragment has been stored
--   with 'appendHeaderFragmentBlock'. Should nothing be stored for the
--   stream, it fails with an error that names the stream instead of an
--   anonymous pattern-match failure.
getHeaderBytes :: GlobalStreamId -> ReaderT SessionData IO B.ByteString
getHeaderBytes global_stream_id = do
    ht <- view stream2HeaderBlockFragment
    maybe_block <- liftIO $ H.lookup ht global_stream_id
    case maybe_block of
        Just block ->
            return $ Bl.toStrict $ Bu.toLazyByteString block
        Nothing ->
            error $ "Internal error/getHeaderBytes: no header fragments for stream "
                ++ (show global_stream_id)
-- | For HEADERS and CONTINUATION frames, returns the stream the frame
--   belongs to together with the header block fragment it carries; any
--   other frame yields 'Nothing'.
frameIsHeaderOfStream :: InputFrame -> Maybe (GlobalStreamId, B.ByteString)
frameIsHeaderOfStream (NH2.Frame (NH2.FrameHeader _ _ stream_id) payload) =
    case payload of
        NH2.HeadersFrame _ block_fragment    -> with_stream block_fragment
        NH2.ContinuationFrame block_fragment -> with_stream block_fragment
        _                                    -> Nothing
  where
    with_stream fragment = Just (NH2.fromStreamIdentifier stream_id, fragment)
-- | Tells whether the END_HEADERS flag is set in the header of the frame.
frameEndsHeaders :: InputFrame -> Bool
frameEndsHeaders (NH2.Frame frame_header _) =
    case frame_header of
        NH2.FrameHeader _ flags _ -> NH2.testEndHeader flags
-- | Stream identifier found in the header of the frame.
streamIdFromFrame :: InputFrame -> GlobalStreamId
streamIdFromFrame (NH2.Frame frame_header _) =
    case frame_header of
        NH2.FrameHeader _ _ stream_id -> NH2.fromStreamIdentifier stream_id
-- | Thread that serializes response headers.
--
--   For every entry read from the input channel it encodes the headers with
--   the session's HPACK encoding table, splits the result in chunks of
--   'useChunkLength' bytes, writes a HEADERS frame followed by as many
--   CONTINUATION frames as needed to the session output, and finally fills
--   the stream's 'MVar' so that the data of the stream can follow.
--
--   Both the encoding table and the output channel are accessed through
--   'modifyMVar' / 'withMVar', so an exception cannot leave either 'MVar'
--   empty. The output lock is held while all the frames of one header block
--   are written, which keeps them contiguous in the output.
headersOutputThread :: Chan (GlobalStreamId, MVar HeadersSent, Headers)
                    -> MVar (Chan (Either SessionOutputCommand OutputFrame))
                    -> ReaderT SessionData IO ()
headersOutputThread input_chan session_output_mvar = forever $ do
    (stream_id, headers_ready_mvar, headers) <- liftIO $ readChan input_chan

    encode_dyn_table_mvar <- view toEncodeHeaders
    data_to_send <- liftIO $
        modifyMVar encode_dyn_table_mvar $ \ encode_dyn_table ->
            HP.encodeHeader HP.defaultEncodeStrategy encode_dyn_table headers

    let frames = headerBlockFrames stream_id $
            bytestringChunk useChunkLength data_to_send

    liftIO $ withMVar session_output_mvar $ \ session_output ->
        mapM_ (writeChan session_output . Right) frames

    -- Only now may the data of the stream be sent.
    liftIO $ putMVar headers_ready_mvar HeadersSent
  where
    -- The first chunk travels in a HEADERS frame, the following ones in
    -- CONTINUATION frames; the last frame carries END_HEADERS.
    headerBlockFrames :: GlobalStreamId -> [B.ByteString] -> [OutputFrame]
    headerBlockFrames _ [] = []
    headerBlockFrames stream_id (first_chunk:other_chunks) =
        (encodeInfo (null other_chunks), NH2.HeadersFrame Nothing first_chunk)
        : continuations other_chunks
      where
        continuations [] = []
        continuations (fragment:xs) =
            (encodeInfo (null xs), NH2.ContinuationFrame fragment)
            : continuations xs

        encodeInfo is_last = NH2.EncodeInfo
            { NH2.encodeFlags    = if is_last
                                     then NH2.setEndHeader NH2.defaultFlags
                                     else NH2.defaultFlags
            , NH2.encodeStreamId = NH2.toStreamIdentifier stream_id
            , NH2.encodePadding  = Nothing
            }
-- | Splits a 'B.ByteString' in consecutive chunks of at most @len@ bytes.
--
--   The result is never empty: an empty input yields a single empty chunk.
--   An input whose length is an exact multiple of @len@ yields only full
--   chunks, without a trailing empty one. A non-positive @len@ cannot be
--   honoured, so in that case the input is returned as a single chunk
--   instead of looping forever.
bytestringChunk :: Int -> B.ByteString -> [B.ByteString]
bytestringChunk len s
    | len <= 0 || B.length s <= len = [ s ]
    | otherwise                     = h : bytestringChunk len xs
  where
    (h, xs) = B.splitAt len s
-- | Thread that serializes response data.
--
--   Every piece of data read from the input channel is split in chunks of
--   'useChunkLength' bytes and written to the session output as DATA
--   frames; the last frame of the piece carries END_STREAM.
--
--   The output channel is accessed through 'withMVar', so the lock is given
--   back even if writing is interrupted by an exception, and the frames of
--   one piece stay contiguous in the output.
--
--   NOTE(review): END_STREAM is set once per piece read from the channel,
--   and 'sendDataOfStream' posts one piece per item of the response body.
--   A body made of several items therefore gets END_STREAM before its last
--   item. Confirm whether workers are expected to produce a single item.
dataOutputThread :: Chan (GlobalStreamId, B.ByteString)
                 -> MVar (Chan (Either SessionOutputCommand OutputFrame))
                 -> IO ()
dataOutputThread input_chan session_output_mvar = forever $ do
    (stream_id, contents) <- readChan input_chan
    let frames = dataFrames stream_id $ bytestringChunk useChunkLength contents
    withMVar session_output_mvar $ \ session_output -> do
        mapM_ (writeChan session_output . Right) frames
        debugM "HTTP2.Session" $ "Output capability restored"
  where
    -- One DATA frame per chunk; only the last one carries END_STREAM.
    dataFrames :: GlobalStreamId -> [B.ByteString] -> [OutputFrame]
    dataFrames _ [] = []
    dataFrames stream_id (chunk:other_chunks) =
        (encodeInfo, NH2.DataFrame chunk) : dataFrames stream_id other_chunks
      where
        encodeInfo = NH2.EncodeInfo
            { NH2.encodeFlags    = if null other_chunks
                                     then NH2.setEndStream NH2.defaultFlags
                                     else NH2.defaultFlags
            , NH2.encodeStreamId = NH2.toStreamIdentifier stream_id
            , NH2.encodePadding  = Nothing
            }
-- | Reports an exception raised inside a session component.
--
--   When an error callback is configured in the sessions context it
--   receives the component, the session coordinates and the exception;
--   otherwise the exception is written to the log.
sessionExceptionHandler :: E.Exception e => SessionComponent -> Int -> SessionsContext -> e -> IO ()
sessionExceptionHandler session_component session_id sessions_context e =
    case sessions_context ^. sessionsConfig . sessionsCallbacks . reportErrorCallback of
        Just callback -> callback error_report
        Nothing       -> errorM "HTTP2.Session" (show (e))
  where
    error_report =
        ( session_component
        , SessionCoordinates session_id
        , E.toException e
        )