{-# LANGUAGE CPP #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
module Network.TLS.IO
( checkValid
, sendPacket
, sendPacket13
, recvPacket
, recvPacket13
, PacketFlightM
, runPacketFlight
, loadPacket13
) where
import Network.TLS.Context.Internal
import Network.TLS.Struct
import Network.TLS.Struct13
import Network.TLS.Record
import Network.TLS.Record.Types13
import Network.TLS.Record.Disengage13
import Network.TLS.Packet
import Network.TLS.Hooks
import Network.TLS.Sending
import Network.TLS.Sending13
import Network.TLS.Receiving
import Network.TLS.Imports
import Network.TLS.Receiving13
import Network.TLS.State
import qualified Data.ByteString as B
import Data.IORef
import Control.Monad.Reader
import Control.Exception (finally, throwIO)
import System.IO.Error (mkIOError, eofErrorType)
-- | Verify that the context can currently be used.
--
-- Throws 'ConnectionNotEstablished' while the context state is still
-- 'NotEstablished', and an end-of-file 'IOError' once the context has been
-- flagged as EOF.
checkValid :: Context -> IO ()
checkValid ctx = do
    state <- ctxEstablished ctx
    unless (state /= NotEstablished) $
        throwIO ConnectionNotEstablished
    atEOF <- ctxEOF ctx
    if atEOF
        then throwIO (mkIOError eofErrorType "data" Nothing Nothing)
        else return ()
-- | Read exactly @sz@ bytes through 'contextRecv'.
--
-- When fewer bytes than requested come back, the context is flagged as EOF
-- and a 'Left' is returned: 'Error_EOF' if nothing was read at all,
-- otherwise an 'Error_Packet' describing the partial read.
readExact :: Context -> Int -> IO (Either TLSError ByteString)
readExact ctx sz = do
    chunk <- contextRecv ctx sz
    let got = B.length chunk
    if got /= sz
        then do
            setEOF ctx
            return $ Left (shortRead chunk got)
        else return (Right chunk)
  where
    -- Pick the error matching how much of the request was satisfied.
    shortRead chunk got
        | B.null chunk = Error_EOF
        | otherwise    = Error_Packet ("partial packet: expecting " ++ show sz ++ " bytes, got: " ++ show got)
-- | Receive one record from the context and run it through
-- 'disengageRecord' in the receive state.
--
-- The regular path reads a 5-byte header, decodes it, rejects announced
-- lengths above @16384 + 2048@ with a 'RecordOverflow' protocol error, then
-- reads exactly that many bytes of content.
--
-- When built with @SSLV2_COMPATIBLE@ and the first argument is 'True', only
-- 2 bytes are read first: if the first byte is @>= 0x80@ they are decoded as
-- a deprecated header length (limited to @1024 * 4@), otherwise 3 more bytes
-- are read and the 5 bytes are decoded as a regular header.
recvRecord :: Bool
           -> Context
           -> IO (Either TLSError (Record Plaintext))
recvRecord compatSSLv2 ctx
#ifdef SSLV2_COMPATIBLE
    | compatSSLv2 = do
        eshort <- readExact ctx 2
        case eshort of
            Left err    -> return (Left err)
            Right short -> sslv2Header short
#endif
    | otherwise = do
        ehdr <- readExact ctx 5
        case ehdr of
            Left err  -> return (Left err)
            Right raw -> recvLengthE (decodeHeader raw)
  where
    -- Continue with the record body only if the header decoded successfully.
    recvLengthE (Left err)     = return (Left err)
    recvLengthE (Right header) = recvLength header

    -- Read the announced number of content bytes, unless it is too large.
    recvLength header@(Header _ _ readlen)
        | readlen <= 16384 + 2048 = do
            ebody <- readExact ctx (fromIntegral readlen)
            case ebody of
                Left err   -> return (Left err)
                Right body -> getRecord header body
        | otherwise = return $ Left maximumSizeExceeded
#ifdef SSLV2_COMPATIBLE
    -- Decide between a deprecated-style header and a regular one from the
    -- first two bytes already read.
    sslv2Header header
        | B.head header >= 0x80 =
            case decodeDeprecatedHeaderLength header of
                Left err      -> return (Left err)
                Right readlen -> recvDeprecatedLength readlen
        | otherwise = do
            erest <- readExact ctx 3
            case erest of
                Left err   -> return (Left err)
                Right rest -> recvLengthE (decodeHeader (B.append header rest))

    -- Read the content announced by a deprecated header and decode the
    -- header from it.
    recvDeprecatedLength readlen
        | readlen > 1024 * 4 = return $ Left maximumSizeExceeded
        | otherwise = do
            res <- readExact ctx (fromIntegral readlen)
            case res of
                Left e -> return $ Left e
                Right content ->
                    case decodeDeprecatedHeader readlen content of
                        Left e       -> return $ Left e
                        Right header -> getRecord header content
#endif
    maximumSizeExceeded = Error_Protocol ("record exceeding maximum size", True, RecordOverflow)

    -- Log the raw header and content, then disengage the record.
    getRecord :: Header -> ByteString -> IO (Either TLSError (Record Plaintext))
    getRecord header content = do
        withLog ctx $ \logging -> loggingIORecv logging header content
        runRxState ctx $ disengageRecord $ rawToRecord header (fragmentCiphertext content)
-- | 'True' exactly when the record's protocol type is
-- 'ProtocolType_ChangeCipherSpec'.
isCCS :: Record a -> Bool
isCCS record = case record of
    Record ProtocolType_ChangeCipherSpec _ _ -> True
    _                                        -> False
-- | Receive one record via 'recvRecord' and turn it into a 'Packet'.
--
-- While 'getTLS13HRR' reports 'True', a ChangeCipherSpec record is skipped
-- and the next record is read instead.  Handshake messages of a successfully
-- processed packet are passed through the 'hookRecvHandshake' hook, and the
-- resulting packet is logged with 'loggingPacketRecv'.  If the SSLv2
-- ClientHello flag was set on entry, it is disabled once the record has been
-- processed.
recvPacket :: MonadIO m => Context -> m (Either TLSError Packet)
recvPacket ctx = liftIO $ do
    compatSSLv2 <- ctxHasSSLv2ClientHello ctx
    recvRecord compatSSLv2 ctx >>= either (return . Left) (onRecord compatSSLv2)
  where
    onRecord compatSSLv2 record = do
        hrr <- usingState_ ctx getTLS13HRR
        if hrr && isCCS record
            then recvPacket ctx
            else do
                decoded <- processPacket ctx record
                pkt <- runHooks decoded
                logPacket pkt
                when compatSSLv2 $ ctxDisableSSLv2ClientHello ctx
                return pkt

    -- Only successfully decoded handshake packets go through the hook.
    runHooks (Right (Handshake hss)) =
        ctxWithHooks ctx $ \hooks ->
            Right . Handshake <$> mapM (hookRecvHandshake hooks) hss
    runHooks other = return other

    -- Errors are not logged here; they are simply returned to the caller.
    logPacket (Right p) = withLog ctx $ \logging -> loggingPacketRecv logging $ show p
    logPacket _         = return ()
-- | Encode a 'Packet' with 'writePacket' and send the resulting bytes.
--
-- When the packet is non-empty application data and the context's
-- 'ctxNeedEmptyPacket' reference holds 'True', an empty 'AppData' packet is
-- sent first.  An encoding failure is raised through 'throwCore'.
sendPacket :: MonadIO m => Context -> Packet -> m ()
sendPacket ctx pkt = do
    needEmpty <- liftIO $ readIORef (ctxNeedEmptyPacket ctx)
    when (carriesAppData && needEmpty) $
        sendPacket ctx (AppData B.empty)
    encoded <- liftIO $ do
        withLog ctx $ \logging -> loggingPacketSent logging (show pkt)
        writePacket ctx pkt
    either throwCore (sendBytes ctx) encoded
  where
    -- True only for application data with a non-empty payload.
    carriesAppData = case pkt of
        AppData payload -> not (B.null payload)
        _               -> False
-- | Encode a 'Packet13' with 'writePacketBytes13' and send the bytes
-- immediately with 'sendBytes'.
sendPacket13 :: MonadIO m => Context -> Packet13 -> m ()
sendPacket13 ctx pkt = do
    encoded <- writePacketBytes13 ctx pkt
    sendBytes ctx encoded
-- | Log a 'Packet13' with 'loggingPacketSent' and encode it with
-- 'writePacket13', without sending anything.
--
-- An encoding failure is raised through 'throwCore'.
writePacketBytes13 :: MonadIO m => Context -> Packet13 -> m ByteString
writePacketBytes13 ctx pkt = do
    result <- liftIO $ logSent >> writePacket13 ctx pkt
    case result of
        Left err    -> throwCore err
        Right bytes -> return bytes
  where
    logSent = withLog ctx $ \logging -> loggingPacketSent logging (show pkt)
-- | Log the outgoing bytes with 'loggingIOSent', then hand them to
-- 'contextSend'.
sendBytes :: MonadIO m => Context -> ByteString -> m ()
sendBytes ctx dataToSend = liftIO $
    withLog ctx (\logging -> loggingIOSent logging dataToSend)
        >> contextSend ctx dataToSend
-- | Receive one TLS 1.3 record from the context and run it through
-- 'disengageRecord13' in the receive state.
--
-- A 5-byte header is read and decoded, then exactly the announced number of
-- content bytes.  Any short read or header decoding failure is returned as
-- 'Left'.  A record announcing more than @2^14 + 256@ bytes is rejected with
-- a 'RecordOverflow' protocol error before its content is read.
recvRecord13 :: Context
             -> IO (Either TLSError Record13)
recvRecord13 ctx = readExact ctx 5 >>= either (return . Left) (recvLengthE . decodeHeader)
  where
    recvLengthE = either (return . Left) recvLength

    recvLength header@(Header _ _ readlen)
        -- RFC 8446 section 5.2: a TLS 1.3 ciphertext record must not exceed
        -- 2^14 + 256 bytes; larger records must trigger record_overflow.
        -- (The 2^14 + 2048 bound only applies to TLS 1.2 and earlier.)
        | readlen > 16384 + 256 = return $ Left maximumSizeExceeded
        | otherwise =
            readExact ctx (fromIntegral readlen) >>=
                either (return . Left) (getRecord header)

    maximumSizeExceeded = Error_Protocol ("record exceeding maximum size", True, RecordOverflow)

    -- Log the raw header and content, then disengage the record.
    getRecord :: Header -> ByteString -> IO (Either TLSError Record13)
    getRecord header content = do
        withLog ctx $ \logging -> loggingIORecv logging header content
        runRxState ctx $ disengageRecord13 $ rawToRecord13 header content
-- | Receive one record via 'recvRecord13' and turn it into a 'Packet13'
-- with 'processPacket13', logging a successfully decoded packet.
--
-- An 'Error_Protocol' carrying 'BadRecordMac' with its flag set to 'True'
-- is tolerated while the context state is @'EarlyDataNotAllowed' n@ with
-- @n > 0@: the counter is decremented and the next record is read instead.
-- Every other error is returned unchanged.
recvPacket13 :: MonadIO m => Context -> m (Either TLSError Packet13)
recvPacket13 ctx = liftIO $ recvRecord13 ctx >>= dispatch
  where
    dispatch (Right record) = do
        pkt <- processPacket13 ctx record
        either (const (return ()))
               (\p -> withLog ctx $ \logging -> loggingPacketRecv logging $ show p)
               pkt
        return pkt
    dispatch (Left err@(Error_Protocol (_, True, BadRecordMac))) = do
        established <- ctxEstablished ctx
        case established of
            EarlyDataNotAllowed n
                | n > 0 -> do
                    setEstablished ctx $ EarlyDataNotAllowed (n - 1)
                    recvPacket13 ctx
            _ -> return $ Left err
    dispatch (Left err) = return $ Left err
-- | Monad for building a flight of packets.
--
-- It is a 'ReaderT' over 'IO' whose environment is a mutable list of
-- encoded packets.  'loadPacket13' prepends each encoded packet to that
-- list, and 'runPacketFlight' sends the accumulated chunks (in insertion
-- order, concatenated) when the computation ends.
newtype PacketFlightM a = PacketFlightM (ReaderT (IORef [ByteString]) IO a)
    deriving (Functor, Applicative, Monad, MonadFail, MonadIO)
-- | Run a 'PacketFlightM' computation against a fresh, empty buffer.
--
-- When the computation finishes, whether normally or by exception (the
-- flush runs under 'finally'), any buffered chunks are concatenated in the
-- order they were loaded and sent with a single 'sendBytes' call.  Nothing
-- is sent if the buffer is empty.
runPacketFlight :: Context -> PacketFlightM a -> IO a
runPacketFlight ctx (PacketFlightM f) = do
    ref <- newIORef []
    runReaderT f ref `finally` flush ref
  where
    -- The buffer is stored newest-first, hence the 'reverse'.
    flush ref = do
        pending <- readIORef ref
        case pending of
            [] -> return ()
            _  -> sendBytes ctx (B.concat (reverse pending))
-- | Encode a 'Packet13' with 'writePacketBytes13' and prepend the bytes to
-- the flight buffer instead of sending them right away.
--
-- An encoding failure is raised by 'writePacketBytes13' (through
-- 'throwCore') and nothing is added to the buffer in that case.
loadPacket13 :: Context -> Packet13 -> PacketFlightM ()
loadPacket13 ctx pkt = PacketFlightM $ do
    bs <- writePacketBytes13 ctx pkt
    ref <- ask
    -- Strict modify: evaluate the new cons cell now rather than leaving a
    -- chain of unevaluated @(bs :)@ applications inside the reference.
    liftIO $ modifyIORef' ref (bs :)