module Network.TLS.Context.Internal
(
ClientParams(..)
, ServerParams(..)
, defaultParamsClient
, SessionID
, SessionData(..)
, MaxFragmentEnum(..)
, Measurement(..)
, Context(..)
, Hooks(..)
, ctxEOF
, ctxHasSSLv2ClientHello
, ctxDisableSSLv2ClientHello
, ctxEstablished
, withLog
, ctxWithHooks
, contextModifyHooks
, setEOF
, setEstablished
, contextFlush
, contextClose
, contextSend
, contextRecv
, updateMeasure
, withMeasure
, withReadLock
, withWriteLock
, withStateLock
, withRWLock
, Information(..)
, contextGetInformation
, throwCore
, usingState
, usingState_
, runTxState
, runRxState
, usingHState
, getHState
, getStateRNG
) where
import Network.TLS.Backend
import Network.TLS.Extension
import Network.TLS.Cipher
import Network.TLS.Struct
import Network.TLS.Compression (Compression)
import Network.TLS.State
import Network.TLS.Handshake.State
import Network.TLS.Hooks
import Network.TLS.Record.State
import Network.TLS.Parameters
import Network.TLS.Measurement
import qualified Data.ByteString as B
import Control.Concurrent.MVar
import Control.Monad.State
import Control.Exception (throwIO, Exception())
import Data.IORef
import Data.Tuple
-- | Summary of a connection's negotiated parameters, as assembled by
-- 'contextGetInformation'.
data Information = Information
    { infoVersion      :: Version
      -- ^ protocol version taken from the TLS state
    , infoCipher       :: Cipher
      -- ^ cipher taken from the receive record state
    , infoCompression  :: Compression
      -- ^ compression taken from the receive record state
    , infoMasterSecret :: Maybe Bytes
      -- ^ master secret of the handshake state; 'Nothing' when there is no
      -- handshake state or it holds no master secret
    , infoClientRandom :: Maybe ClientRandom
      -- ^ client random of the handshake state; 'Nothing' when there is no
      -- handshake state
    , infoServerRandom :: Maybe ServerRandom
      -- ^ server random of the handshake state; 'Nothing' when there is no
      -- handshake state or it holds no server random
    } deriving (Show, Eq)
-- | Everything the functions of this module need in order to operate on a
-- connection: the backend, the parameters, the mutable states and the locks.
data Context = Context
    { ctxConnection       :: Backend
      -- ^ backend used by 'contextSend', 'contextRecv', 'contextFlush' and
      -- 'contextClose'
    , ctxSupported        :: Supported
      -- ^ supported parameters; 'runTxState' reads 'supportedVersions' from it
    , ctxShared           :: Shared
      -- ^ shared parameters (not accessed in this module)
    , ctxCiphers          :: [Cipher]
      -- ^ list of ciphers (not accessed in this module; presumably the
      -- ciphers usable on this connection -- confirm where the context is built)
    , ctxState            :: MVar TLSState
      -- ^ TLS state, stepped through 'usingState'
    , ctxMeasurement      :: IORef Measurement
      -- ^ counters, modified through 'updateMeasure'
    , ctxEOF_             :: IORef Bool
      -- ^ flag read by 'ctxEOF' and set to 'True' by 'setEOF'
    , ctxEstablished_     :: IORef Bool
      -- ^ flag read by 'ctxEstablished' and written by 'setEstablished'
    , ctxNeedEmptyPacket  :: IORef Bool
      -- ^ flag not accessed in this module
    , ctxSSLv2ClientHello :: IORef Bool
      -- ^ flag read by 'ctxHasSSLv2ClientHello' and cleared by
      -- 'ctxDisableSSLv2ClientHello'
    , ctxTxState          :: MVar RecordState
      -- ^ record state stepped by 'runTxState'
    , ctxRxState          :: MVar RecordState
      -- ^ record state stepped by 'runRxState'
    , ctxHandshake        :: MVar (Maybe HandshakeState)
      -- ^ current handshake state, if any; 'usingHState' throws when it
      -- holds 'Nothing'
    , ctxDoHandshake      :: Context -> IO ()
      -- ^ stored handshake action (not invoked in this module)
    , ctxDoHandshakeWith  :: Context -> Handshake -> IO ()
      -- ^ stored handshake action taking a 'Handshake' value (not invoked in
      -- this module)
    , ctxHooks            :: IORef Hooks
      -- ^ hooks, read by 'ctxWithHooks' and changed by 'contextModifyHooks'
    , ctxLockWrite        :: MVar ()
      -- ^ lock taken by 'withWriteLock'
    , ctxLockRead         :: MVar ()
      -- ^ lock taken by 'withReadLock'
    , ctxLockState        :: MVar ()
      -- ^ lock taken by 'withStateLock'
    }
-- | Apply a pure transformation to the 'Measurement' counters of the context.
--
-- The update goes through 'atomicModifyIORef'' instead of a separate read
-- followed by a write: both 'contextSend' and 'contextRecv' update the same
-- reference, and the context carries distinct read and write locks, so two
-- threads may well be updating the counters at the same time. With a plain
-- read/write pair one of the two updates could be lost.
--
-- The new value is forced to weak head normal form, as before, so that
-- thunks do not accumulate inside the reference.
updateMeasure :: Context -> (Measurement -> Measurement) -> IO ()
updateMeasure ctx f =
    atomicModifyIORef' (ctxMeasurement ctx) $ \old -> (f old, ())
-- | Read the current 'Measurement' of the context and hand it to the given
-- action, returning the action's result.
withMeasure :: Context -> (Measurement -> IO a) -> IO a
withMeasure ctx f = do
    measurement <- readIORef (ctxMeasurement ctx)
    f measurement
-- | Run the flush operation of the context's backend.
contextFlush :: Context -> IO ()
contextFlush ctx = backendFlush (ctxConnection ctx)
-- | Run the close operation of the context's backend.
contextClose :: Context -> IO ()
contextClose ctx = backendClose (ctxConnection ctx)
-- | Collect the negotiated parameters of the connection.
--
-- The version comes from the TLS state, the secrets and randoms from the
-- handshake state (when there is one), and the cipher and compression from
-- the receive record state. The result is 'Nothing' unless both a version
-- and a cipher are available. An error reported by 'runRxState' is thrown.
contextGetInformation :: Context -> IO (Maybe Information)
contextGetInformation ctx = do
    mVersion   <- usingState_ ctx (gets stVersion)
    mHandshake <- getHState ctx
    -- Without a handshake state none of the three handshake values is known.
    let (masterSecret, clientRandom, serverRandom) =
            maybe (Nothing, Nothing, Nothing) handshakeValues mHandshake
    (mCipher, compression) <- failOnEitherError (runRxState ctx (gets rxParameters))
    case (mVersion, mCipher) of
        (Just version, Just cipher) ->
            return . Just $
                Information version cipher compression masterSecret clientRandom serverRandom
        _ -> return Nothing
  where
    rxParameters st = (stCipher st, stCompression st)
    handshakeValues hst =
        ( hstMasterSecret hst
        , Just (hstClientRandom hst)
        , hstServerRandom hst
        )
-- | Add the length of the payload to the sent-bytes counter, then pass the
-- payload to the backend's send operation.
contextSend :: Context -> Bytes -> IO ()
contextSend c b = do
    updateMeasure c (addBytesSent (B.length b))
    backendSend (ctxConnection c) b
-- | Ask the backend for @sz@ bytes and return whatever it delivers.
--
-- The received-bytes counter is increased by the length of the data that
-- was actually returned, and only once the read has completed. Counting the
-- requested size up front would over-count whenever the read throws, and
-- also whenever the backend hands back fewer bytes than requested
-- (NOTE(review): presumably possible at end of stream -- confirm against the
-- 'backendRecv' contract).
contextRecv :: Context -> Int -> IO Bytes
contextRecv c sz = do
    received <- backendRecv (ctxConnection c) sz
    updateMeasure c (addBytesReceived (B.length received))
    return received
-- | Current value of the end-of-file flag (see 'setEOF').
ctxEOF :: Context -> IO Bool
ctxEOF = readIORef . ctxEOF_
-- | Current value of the SSLv2 client hello flag
-- (see 'ctxDisableSSLv2ClientHello').
ctxHasSSLv2ClientHello :: Context -> IO Bool
ctxHasSSLv2ClientHello = readIORef . ctxSSLv2ClientHello
-- | Set the SSLv2 client hello flag to 'False'.
ctxDisableSSLv2ClientHello :: Context -> IO ()
ctxDisableSSLv2ClientHello ctx = ctxSSLv2ClientHello ctx `writeIORef` False
-- | Set the end-of-file flag of the context to 'True'.
setEOF :: Context -> IO ()
setEOF ctx = ctxEOF_ ctx `writeIORef` True
-- | Current value of the established flag (see 'setEstablished').
ctxEstablished :: Context -> IO Bool
ctxEstablished = readIORef . ctxEstablished_
-- | Read the hooks currently installed in the context and run the given
-- action on them.
ctxWithHooks :: Context -> (Hooks -> IO a) -> IO a
ctxWithHooks ctx f = do
    hooks <- readIORef (ctxHooks ctx)
    f hooks
-- | Replace the hooks of the context by the result of applying the given
-- function to the current ones. The new value is stored lazily.
contextModifyHooks :: Context -> (Hooks -> Hooks) -> IO ()
contextModifyHooks = modifyIORef . ctxHooks
-- | Overwrite the established flag of the context with the given value.
setEstablished :: Context -> Bool -> IO ()
setEstablished = writeIORef . ctxEstablished_
-- | Run an action on the 'Logging' part ('hookLogging') of the hooks
-- currently installed in the context.
withLog :: Context -> (Logging -> IO ()) -> IO ()
withLog ctx f = ctxWithHooks ctx $ \hooks -> f (hookLogging hooks)
-- | Throw an exception with 'throwIO' from any monad that can lift 'IO'.
throwCore :: (MonadIO m, Exception e) => e -> m a
throwCore e = liftIO (throwIO e)
-- | Run an action producing an 'Either': a 'Left' error is thrown with
-- 'throwCore', a 'Right' value is returned as is.
failOnEitherError :: MonadIO m => m (Either TLSError a) -> m a
failOnEitherError f = f >>= either throwCore return
-- | Run a 'TLSSt' computation against the TLS state held by the context.
--
-- The state produced by 'runTLSState' is forced to weak head normal form and
-- stored back in the 'MVar'; the outcome of the computation is returned.
usingState :: Context -> TLSSt a -> IO (Either TLSError a)
usingState ctx f = modifyMVar (ctxState ctx) step
  where
    step st =
        case runTLSState f st of
            (outcome, st') -> st' `seq` return (st', outcome)
-- | Like 'usingState', but an error result is thrown instead of returned.
usingState_ :: Context -> TLSSt a -> IO a
usingState_ ctx = failOnEitherError . usingState ctx
-- | Run a 'HandshakeM' computation against the handshake state held by the
-- context, storing the resulting state back and returning the result.
--
-- Throws @Error_Misc "missing handshake"@ when the context holds no
-- handshake state.
usingHState :: Context -> HandshakeM a -> IO a
usingHState ctx f = modifyMVar (ctxHandshake ctx) step
  where
    step Nothing   = throwCore $ Error_Misc "missing handshake"
    step (Just st) =
        case runHandshake st f of
            (result, st') -> return (Just st', result)
-- | Read the handshake state currently held by the context, if any, without
-- modifying it.
getHState :: Context -> IO (Maybe HandshakeState)
getHState = readMVar . ctxHandshake
-- | Run a 'RecordM' computation against the transmit record state.
--
-- The version handed to 'runRecordM' is obtained with
-- 'getVersionWithDefault', the default being the highest of the supported
-- versions. On success the new record state is stored and the result is
-- returned in 'Right'; on failure the record state is left untouched and the
-- error is returned in 'Left'.
--
-- NOTE(review): 'maximum' is partial, so an empty 'supportedVersions' list
-- raises an exception if the default is ever demanded.
runTxState :: Context -> RecordM a -> IO (Either TLSError a)
runTxState ctx f = do
    ver <- usingState_ ctx (getVersionWithDefault highestSupported)
    modifyMVar (ctxTxState ctx) $ \st ->
        either (keep st) advance (runRecordM f ver st)
  where
    highestSupported = maximum (supportedVersions (ctxSupported ctx))
    keep st err = return (st, Left err)
    advance (result, st') = return (st', Right result)
-- | Run a 'RecordM' computation against the receive record state.
--
-- The version handed to 'runRecordM' is obtained with 'getVersion'. On
-- success the new record state is stored and the result is returned in
-- 'Right'; on failure the record state is left untouched and the error is
-- returned in 'Left'.
runRxState :: Context -> RecordM a -> IO (Either TLSError a)
runRxState ctx f = do
    ver <- usingState_ ctx getVersion
    modifyMVar (ctxRxState ctx) $ \st ->
        either (keep st) advance (runRecordM f ver st)
  where
    keep st err = return (st, Left err)
    advance (result, st') = return (st', Right result)
-- | Run 'genRandom' with the given argument inside the TLS state of the
-- context; an error result is thrown.
getStateRNG :: Context -> Int -> IO Bytes
getStateRNG ctx = usingState_ ctx . genRandom
-- | Run an action while holding the read lock of the context.
withReadLock :: Context -> IO a -> IO a
withReadLock ctx action = withMVar (ctxLockRead ctx) $ \_ -> action
-- | Run an action while holding the write lock of the context.
withWriteLock :: Context -> IO a -> IO a
withWriteLock ctx action = withMVar (ctxLockWrite ctx) $ \_ -> action
-- | Run an action while holding both locks of the context; the read lock is
-- taken first, then the write lock.
withRWLock :: Context -> IO a -> IO a
withRWLock ctx = withReadLock ctx . withWriteLock ctx
-- | Run an action while holding the state lock of the context.
withStateLock :: Context -> IO a -> IO a
withStateLock ctx action = withMVar (ctxLockState ctx) $ \_ -> action