{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE OverloadedStrings #-}
module Network.TLS.Handshake.Common
    ( handshakeFailed
    , handleException
    , unexpected
    , newSession
    , handshakeTerminate
    -- * sending packets
    , sendChangeCipherAndFinish
    -- * receiving packets
    , recvChangeCipherAndFinish
    , RecvState(..)
    , runRecvState
    , recvPacketHandshake
    , onRecvStateHandshake
    , ensureRecvComplete
    , processExtendedMasterSec
    , extensionLookup
    , getSessionData
    , storePrivInfo
    , isSupportedGroup
    , checkSupportedGroup
    , errorToAlert
    , errorToAlertMessage
    ) where

import qualified Data.ByteString as B
import Control.Concurrent.MVar

import Network.TLS.Parameters
import Network.TLS.Compression
import Network.TLS.Context.Internal
import Network.TLS.Extension
import Network.TLS.Session
import Network.TLS.Struct
import Network.TLS.Struct13
import Network.TLS.IO
import Network.TLS.State
import Network.TLS.Handshake.Key
import Network.TLS.Handshake.Process
import Network.TLS.Handshake.State
import Network.TLS.Record.State
import Network.TLS.Measurement
import Network.TLS.Types
import Network.TLS.Cipher
import Network.TLS.Crypto
import Network.TLS.Util
import Network.TLS.X509
import Network.TLS.Imports

import Control.Monad.State.Strict
import Control.Exception (IOException, handle, fromException, throwIO)
import Data.IORef (writeIORef)

-- | Abort the current handshake: the given error is wrapped in
-- 'HandshakeFailed' and thrown as an IO exception.
handshakeFailed :: TLSError -> IO ()
handshakeFailed = throwIO . HandshakeFailed

-- | Run a handshake action and turn any exception it raises into a
-- handshake failure.
--
-- On failure the context is marked 'NotEstablished', an alert derived from
-- the error (see 'errorToAlert') is sent to the peer on a best-effort basis
-- (TLS 1.3 or legacy framing depending on the negotiated version), and the
-- error is rethrown via 'handshakeFailed'.  Exceptions that are not a
-- 'TLSError' are wrapped in 'Error_Misc' using their 'show' text.
handleException :: Context -> IO () -> IO ()
handleException ctx f = catchException f onFailure
  where
    onFailure exc = do
        let tlserror = fromMaybe (Error_Misc (show exc)) (fromException exc)
            alert    = errorToAlert tlserror
        setEstablished ctx NotEstablished
        -- Delivering the alert is best effort: an IO error while sending
        -- must not hide the original failure.
        handle ignoreIOErr $ do
            tls13 <- tls13orLater ctx
            if tls13
                then sendPacket13 ctx (Alert13 [alert])
                else sendPacket ctx (Alert [alert])
        handshakeFailed tlserror

    ignoreIOErr :: IOException -> IO ()
    ignoreIOErr _ = return ()

-- | Choose the alert (level and description) to send to the peer for a
-- library error.
--
-- * 'Error_Protocol' carries its own description; its flag selects fatal
--   ('True') or warning ('False').
-- * 'Error_Packet_unexpected' maps to a fatal UnexpectedMessage.
-- * 'Error_Packet_Parsing' is refined by inspecting its message text.
-- * Anything else is a fatal InternalError.
errorToAlert :: TLSError -> (AlertLevel, AlertDescription)
errorToAlert err = case err of
    Error_Protocol (_, fatal, desc) -> (levelFor fatal, desc)
    Error_Packet_unexpected _ _     -> fatalWith UnexpectedMessage
    Error_Packet_Parsing msg        -> fatalWith (parsingAlert msg)
    _                               -> fatalWith InternalError
  where
    levelFor :: Bool -> AlertLevel
    levelFor True  = AlertLevel_Fatal
    levelFor False = AlertLevel_Warning

    fatalWith :: AlertDescription -> (AlertLevel, AlertDescription)
    fatalWith desc = (AlertLevel_Fatal, desc)

    parsingAlert :: String -> AlertDescription
    parsingAlert msg
        | "invalid version" `isInfixOf` msg = ProtocolVersion
        | "request_update"  `isInfixOf` msg = IllegalParameter
        | otherwise                         = DecodeError

-- | Return the message that a TLS endpoint can add to its local log for the
-- specified library error.
--
-- Errors that carry a message string yield that string unchanged; any other
-- error is rendered with 'show'.
errorToAlertMessage :: TLSError -> String
errorToAlertMessage err = case err of
    Error_Protocol (msg, _, _)    -> msg
    Error_Packet_unexpected msg _ -> msg
    Error_Packet_Parsing msg      -> msg
    _                             -> show err

-- | Throw an 'Error_Packet_unexpected' for a message that was not expected.
--
-- The first field is the description of what was received.  The second is
-- @" expected: <name>"@ when the expected message is given, otherwise empty.
unexpected :: MonadIO m => String -> Maybe String -> m a
unexpected msg expected =
    throwCore (Error_Packet_unexpected msg expectation)
  where
    expectation = case expected of
        Nothing   -> ""
        Just name -> " expected: " ++ name

-- | Create the 'Session' for a new handshake.
--
-- When session support is enabled in the context's 'Supported' settings, a
-- fresh 32-byte session ID is drawn from the context RNG.  Otherwise the
-- session carries no ID.
newSession :: Context -> IO Session
newSession ctx
    | supportedSession (ctxSupported ctx) = do
        sid <- getStateRNG ctx sessionIdLength
        return (Session (Just sid))
    | otherwise = return (Session Nothing)
  where
    sessionIdLength :: Int
    sessionIdLength = 32

-- | When a new handshake is done, wrap up and clean up:
--
-- * if the session has an ID, pass the session data to the session
--   manager's 'sessionEstablish' callback;
-- * replace the handshake state by an empty one that keeps only the client
--   version and random, the server random, the master secret, the extended
--   master secret flag and the negotiated group;
-- * reset the byte counters and mark the context 'Established'.
handshakeTerminate :: Context -> IO ()
handshakeTerminate ctx = do
    session <- usingState_ ctx getSession
    -- only call back the session manager if we actually have a session ID
    case session of
        Session Nothing    -> return ()
        Session (Just sid) -> do
            sessionData <- getSessionData ctx
            -- B.copy detaches the ID from its original buffer (presumably so
            -- the session manager does not keep that buffer alive)
            let !sid'   = B.copy sid
                manager = sharedSessionManager (ctxShared ctx)
            sessionEstablish manager sid' (fromJust "session-data" sessionData)
    -- forget most handshake data and reset bytes counters.
    modifyMVar_ (ctxHandshake ctx) $ \mhshake -> return $! fmap trimHandshake mhshake
    updateMeasure ctx resetBytesCounters
    -- mark the secure connection up and running.
    setEstablished ctx Established
  where
    trimHandshake :: HandshakeState -> HandshakeState
    trimHandshake hshake =
        (newEmptyHandshake (hstClientVersion hshake) (hstClientRandom hshake))
            { hstServerRandom      = hstServerRandom hshake
            , hstMasterSecret      = hstMasterSecret hshake
            , hstExtendedMasterSec = hstExtendedMasterSec hshake
            , hstNegotiatedGroup   = hstNegotiatedGroup hshake
            }

-- | Send ChangeCipherSpec, then our Finished message, flushing the context
-- after each.
--
-- The Finished payload is the handshake digest for the current version and
-- the given role.  It is also recorded in 'ctxFinished'.
sendChangeCipherAndFinish :: Context
                          -> Role
                          -> IO ()
sendChangeCipherAndFinish ctx role = do
    sendPacket ctx ChangeCipherSpec
    contextFlush ctx
    ver <- usingState_ ctx getVersion
    verifyData <- usingHState ctx (getHandshakeDigest ver role)
    sendPacket ctx (Handshake [Finished verifyData])
    writeIORef (ctxFinished ctx) (Just verifyData)
    contextFlush ctx

-- | Receive the peer's ChangeCipherSpec and then its Finished handshake
-- message.  Anything else in either position is rejected with 'unexpected'.
recvChangeCipherAndFinish :: Context -> IO ()
recvChangeCipherAndFinish ctx =
    runRecvState ctx (RecvStateNext awaitChangeCipher)
  where
    awaitChangeCipher pkt = case pkt of
        ChangeCipherSpec -> return (RecvStateHandshake awaitFinished)
        _                -> unexpected (show pkt) (Just "change cipher")

    awaitFinished hs = case hs of
        Finished _ -> return RecvStateDone
        _          -> unexpected (show hs) (Just "Handshake Finished")

-- | State of a receive-side state machine whose transition functions run in
-- the monad @m@; it is driven by 'runRecvState' and 'onRecvStateHandshake'.
data RecvState m =
      RecvStateNext (Packet -> m (RecvState m))
      -- ^ Expect a whole packet: the next received 'Packet' is passed to the
      -- function, which returns the following state.
    | RecvStateHandshake (Handshake -> m (RecvState m))
      -- ^ Expect handshake messages: each one is passed to the function
      -- (and then to 'processHandshake') in turn.
    | RecvStateDone
      -- ^ No further messages are expected.

-- | Receive the next record and return the handshake messages it contains.
--
-- Application data is silently skipped while the context is in
-- 'EarlyDataNotAllowed' with a positive budget; each skip decrements the
-- budget.  Any other non-handshake packet is rejected with 'unexpected',
-- and a receive error is rethrown with 'throwCore'.
recvPacketHandshake :: Context -> IO [Handshake]
recvPacketHandshake ctx = recvPacket ctx >>= dispatch
  where
    dispatch (Left err)              = throwCore err
    dispatch (Right (Handshake hss)) = return hss
    dispatch (Right pkt@(AppData _)) = do
        -- If a TLS13 server decides to reject RTT0 data, the server should
        -- skip records for RTT0 data up to the maximum limit.
        established <- ctxEstablished ctx
        case established of
            EarlyDataNotAllowed budget | budget > 0 -> do
                setEstablished ctx (EarlyDataNotAllowed (budget - 1))
                recvPacketHandshake ctx
            _ -> notHandshake pkt
    dispatch (Right pkt)             = notHandshake pkt

    notHandshake pkt = unexpected (show pkt) (Just "handshake")

-- | Feed a list of handshake messages to the receive state machine.
--
-- * An empty list leaves the state unchanged.
-- * In 'RecvStateNext' the whole list is handed over as one 'Handshake'
--   packet.
-- * In 'RecvStateHandshake' each message goes first to the transition
--   function, then to 'processHandshake', before the next one is handled.
-- * Messages arriving in 'RecvStateDone' are a spurious handshake.
onRecvStateHandshake :: Context -> RecvState IO -> [Handshake] -> IO (RecvState IO)
onRecvStateHandshake ctx state msgs = case (state, msgs) of
    (_, [])                            -> return state
    (RecvStateNext next, _)            -> next (Handshake msgs)
    (RecvStateHandshake step, hs : rest) -> do
        state' <- step hs
        processHandshake ctx hs
        onRecvStateHandshake ctx state' rest
    _                                  -> unexpected "spurious handshake" Nothing

-- | Drive a receive state machine until it reaches 'RecvStateDone'.
--
-- A 'RecvStateNext' state consumes one packet; a receive error is rethrown.
-- A 'RecvStateHandshake' state consumes the messages of the next handshake
-- record.
runRecvState :: Context -> RecvState IO -> IO ()
runRecvState ctx state = case state of
    RecvStateDone -> return ()
    RecvStateNext f -> do
        pkt  <- recvPacket ctx >>= either throwCore return
        next <- f pkt
        runRecvState ctx next
    RecvStateHandshake _ -> do
        msgs <- recvPacketHandshake ctx
        next <- onRecvStateHandshake ctx state msgs
        runRecvState ctx next

-- | Throw a fatal UnexpectedMessage protocol error unless 'isRecvComplete'
-- reports that no partially received message is pending.  Per its error
-- text, this check is meant to run at a key change.
ensureRecvComplete :: MonadIO m => Context -> m ()
ensureRecvComplete ctx = do
    complete <- liftIO $ isRecvComplete ctx
    unless complete $ throwCore incomplete
  where
    incomplete :: TLSError
    incomplete = Error_Protocol ( "received incomplete message at key change"
                                , True
                                , UnexpectedMessage )

-- | Decide whether Extended Master Secret (EMS) is used for this handshake.
-- The decision depends on the local 'supportedExtendedMasterSec' setting and
-- the peer's extensions for the given message type.
--
-- * Below TLS 1.0 EMS is never used.  Above TLS 1.2 this function must not
--   be called (it calls 'error').
-- * With 'NoEMS' configured, EMS is not used.
-- * If the peer sent the extension, EMS is recorded in the handshake state
--   and 'True' is returned.
-- * If it did not, 'RequireEMS' aborts with a fatal HandshakeFailure;
--   otherwise 'False' is returned.
processExtendedMasterSec :: MonadIO m => Context -> Version -> MessageType -> [ExtensionRaw] -> m Bool
processExtendedMasterSec ctx ver msgt exts
    | ver < TLS10       = return False
    | ver > TLS12       = error "EMS processing is not compatible with TLS 1.3"
    | ems == NoEMS      = return False
    | peerSentEMS       = do
        usingHState ctx (setExtendedMasterSec True)
        return True
    | ems == RequireEMS = throwCore $ Error_Protocol (missingEMS, True, HandshakeFailure)
    | otherwise         = return False
  where
    ems = supportedExtendedMasterSec (ctxSupported ctx)
    missingEMS = "peer does not support Extended Master Secret"
    peerSentEMS =
        case extensionLookup extensionID_ExtendedMasterSecret exts >>= extensionDecode msgt of
            Just ExtendedMasterSecret -> True
            Nothing                   -> False

-- | Build the 'SessionData' describing the current connection, or 'Nothing'
-- when the handshake state holds no master secret.
--
-- The data records:
--
-- * the version, client SNI, master secret and negotiated ALPN protocol;
-- * the cipher and compression IDs, taken from the transmit record state;
-- * a 'SessionEMS' flag when extended master secret is in use.
--
-- Group, TLS 1.3 ticket info and max early-data size are left as
-- 'Nothing' / 0.
--
-- Fails (via 'fromJust') only when a master secret exists but the transmit
-- state has no cipher.
getSessionData :: Context -> IO (Maybe SessionData)
getSessionData ctx = do
    ver <- usingState_ ctx getVersion
    sni <- usingState_ ctx getClientSNI
    mms <- usingHState ctx (gets hstMasterSecret)
    !ems <- usingHState ctx getExtendedMasterSec
    tx  <- liftIO $ readMVar (ctxTxState ctx)
    alpn <- usingState_ ctx getNegotiatedProtocol
    case mms of
        Nothing -> return Nothing
        Just ms -> do
            -- Forced only here: previously these were bang-bound before the
            -- case, so a missing cipher crashed the call even when there was
            -- no master secret and the answer should have been 'Nothing'.
            let !cipher      = cipherID $ fromJust "cipher" $ stCipher tx
                !compression = compressionID $ stCompression tx
            return $ Just SessionData
                { sessionVersion          = ver
                , sessionCipher           = cipher
                , sessionCompression      = compression
                , sessionClientSNI        = sni
                , sessionSecret           = ms
                , sessionGroup            = Nothing
                , sessionTicketInfo       = Nothing
                , sessionALPN             = alpn
                , sessionMaxEarlyDataSize = 0
                , sessionFlags            = [SessionEMS | ems]
                }

-- | Return the raw payload of the first extension in the list whose ID
-- equals the requested one, or 'Nothing' if there is none.
extensionLookup :: ExtensionID -> [ExtensionRaw] -> Maybe ByteString
extensionLookup toFind exts =
    lookup toFind [ (eid, content) | ExtensionRaw eid content <- exts ]

-- | Store the specified keypair.  Whether the public key and private key
-- actually match is left for the peer to discover.  We're not presently
-- burning CPU to detect that misconfiguration.  We verify only that the
-- types of keys match and that it does not include an algorithm that would
-- not be safe.
--
-- The public key is taken from the first (leaf) certificate of the chain,
-- stored in the handshake state together with the private key, and
-- returned.  Both failure cases raise a fatal InternalError protocol error:
-- an empty chain, and a pair rejected by 'isDigitalSignaturePair'.
storePrivInfo :: MonadIO m
              => Context
              -> CertificateChain
              -> PrivKey
              -> m PubKey
storePrivInfo ctx cc privkey =
    case cc of
        -- previously an irrefutable pattern: an empty chain surfaced as a
        -- pattern-match failure instead of a TLS error
        CertificateChain []      -> failWith "empty certificate chain"
        CertificateChain (c : _) -> do
            let pubkey = certPubKey $ getCertificate c
            unless (isDigitalSignaturePair (pubkey, privkey)) $
                failWith "mismatched or unsupported private key pair"
            usingHState ctx $ setPublicPrivateKeys (pubkey, privkey)
            return pubkey
  where
    failWith :: MonadIO n => String -> n a
    failWith msg = throwCore $ Error_Protocol (msg, True, InternalError)

-- | Verify that the group selected by the peer is supported in the local
-- configuration.  If it is not, fail with a fatal IllegalParameter protocol
-- error.
checkSupportedGroup :: Context -> Group -> IO ()
checkSupportedGroup ctx grp
    | isSupportedGroup ctx grp = return ()
    | otherwise = throwCore $ Error_Protocol (msg, True, IllegalParameter)
  where
    msg = "unsupported (EC)DHE group: " ++ show grp

-- | Whether the group is listed in the context's 'supportedGroups'.
isSupportedGroup :: Context -> Group -> Bool
isSupportedGroup ctx = (`elem` supportedGroups (ctxSupported ctx))