{-# Options_GHC -Wno-unused-do-bind #-}
{-# Language OverloadedStrings #-}
module Client.Network.Async
( NetworkConnection
, NetworkEvent(..)
, createConnection
, Client.Network.Async.send
, Client.Network.Async.recv
, upgrade
, abortConnection
, TerminationReason(..)
) where
import Client.Configuration.ServerSettings
import Client.Network.Connect
import Control.Concurrent
import Control.Concurrent.Async
import Control.Concurrent.STM
import Control.Exception
import Control.Lens
import Control.Monad
import Data.ByteString (ByteString)
import qualified Data.ByteString as B
import Data.Foldable
import Data.List
import Data.List.Split (chunksOf)
import Data.Text (Text)
import qualified Data.Text as Text
import Data.Time
import Data.Traversable
import Data.Word (Word8)
import Hookup
import Hookup.OpenSSL (getPubKeyDer)
import Irc.RateLimit
import Numeric (showHex)
import qualified OpenSSL.EVP.Digest as Digest
import OpenSSL.X509 (X509, printX509, writeDerX509)
data NetworkConnection = NetworkConnection
{ connOutQueue :: TQueue ByteString
, connInQueue :: TQueue NetworkEvent
, connAsync :: Async ()
, connUpgrade :: MVar (IO ())
}
upgrade :: NetworkConnection -> IO ()
upgrade c = join (swapMVar (connUpgrade c) (pure ()))
data NetworkEvent
= NetworkOpen !ZonedTime
| NetworkTLS [Text]
| NetworkLine !ZonedTime !ByteString
| NetworkError !ZonedTime !SomeException
| NetworkClose !ZonedTime
instance Show NetworkConnection where
showsPrec p _ = showParen (p > 10)
$ showString "NetworkConnection _"
data TerminationReason
= PingTimeout
| ForcedDisconnect
| StsUpgrade
| StartTLSFailed
| BadCertFingerprint ByteString (Maybe ByteString)
| BadPubkeyFingerprint ByteString (Maybe ByteString)
deriving Show
instance Exception TerminationReason where
displayException PingTimeout = "connection killed due to ping timeout"
displayException ForcedDisconnect = "connection killed by client command"
displayException StsUpgrade = "connection killed by sts policy"
displayException StartTLSFailed = "connection killed due to failed STARTTLS"
displayException (BadCertFingerprint expect got) =
"Expected certificate fingerprint: " ++ formatDigest expect ++
"; got: " ++ maybe "none" formatDigest got
displayException (BadPubkeyFingerprint expect got) =
"Expected public key fingerprint: " ++ formatDigest expect ++
"; got: " ++ maybe "none" formatDigest got
send :: NetworkConnection -> ByteString -> IO ()
send c msg = atomically (writeTQueue (connOutQueue c) msg)
recv :: NetworkConnection -> STM [NetworkEvent]
recv = flushTQueue . connInQueue
abortConnection :: TerminationReason -> NetworkConnection -> IO ()
abortConnection reason c = cancelWith (connAsync c) reason
createConnection ::
Int ->
ServerSettings ->
IO NetworkConnection
createConnection delay settings =
do outQueue <- newTQueueIO
inQueue <- newTQueueIO
upgradeMVar <- newEmptyMVar
supervisor <- async $
threadDelay (delay * 1000000) >>
withConnection settings
(startConnection settings inQueue outQueue upgradeMVar)
let recordFailure :: SomeException -> IO ()
recordFailure ex =
do now <- getZonedTime
atomically (writeTQueue inQueue (NetworkError now ex))
recordNormalExit :: IO ()
recordNormalExit =
do now <- getZonedTime
atomically (writeTQueue inQueue (NetworkClose now))
forkIO $ do outcome <- waitCatch supervisor
case outcome of
Right{} -> recordNormalExit
Left e -> recordFailure e
return NetworkConnection
{ connOutQueue = outQueue
, connInQueue = inQueue
, connAsync = supervisor
, connUpgrade = upgradeMVar
}
startConnection ::
ServerSettings ->
TQueue NetworkEvent ->
TQueue ByteString ->
MVar (IO ()) ->
Connection ->
IO ()
startConnection settings inQueue outQueue upgradeMVar h =
do reportNetworkOpen inQueue
ready <- presend
when ready $
do checkFingerprints
race_ receiveMain sendMain
where
receiveMain = receiveLoop h inQueue
sendMain =
do rate <- newRateLimit (view ssFloodPenalty settings)
(view ssFloodThreshold settings)
sendLoop h outQueue rate
presend =
case view ssTls settings of
TlsNo -> True <$ putMVar upgradeMVar (pure ())
TlsYes ->
do txts <- describeCertificates h
putMVar upgradeMVar (pure ())
atomically (writeTQueue inQueue (NetworkTLS txts))
pure True
TlsStart ->
do Hookup.send h "STARTTLS\n"
r <- withAsync receiveMain $ \t ->
do putMVar upgradeMVar (cancel t)
waitCatch t
case r of
Right () -> pure False
Left e | Just AsyncCancelled <- fromException e ->
do Hookup.upgradeTls (tlsParams settings) (view ssHostName settings) h
txts <- describeCertificates h
atomically (writeTQueue inQueue (NetworkTLS txts))
pure True
Left e -> throwIO e
checkFingerprints =
case view ssTls settings of
TlsNo -> pure ()
_ ->
do for_ (view ssTlsCertFingerprint settings) (checkCertFingerprint h)
for_ (view ssTlsPubkeyFingerprint settings) (checkPubkeyFingerprint h)
checkCertFingerprint :: Connection -> Fingerprint -> IO ()
checkCertFingerprint h fp =
do (expect, got) <-
case fp of
FingerprintSha1 expect -> (,) expect <$> getPeerCertFingerprintSha1 h
FingerprintSha256 expect -> (,) expect <$> getPeerCertFingerprintSha256 h
FingerprintSha512 expect -> (,) expect <$> getPeerCertFingerprintSha512 h
unless (Just expect == got)
(throwIO (BadCertFingerprint expect got))
checkPubkeyFingerprint :: Connection -> Fingerprint -> IO ()
checkPubkeyFingerprint h fp =
do (expect, got) <-
case fp of
FingerprintSha1 expect -> (,) expect <$> getPeerPubkeyFingerprintSha1 h
FingerprintSha256 expect -> (,) expect <$> getPeerPubkeyFingerprintSha256 h
FingerprintSha512 expect -> (,) expect <$> getPeerPubkeyFingerprintSha512 h
unless (Just expect == got)
(throwIO (BadPubkeyFingerprint expect got))
reportNetworkOpen :: TQueue NetworkEvent -> IO ()
reportNetworkOpen inQueue =
do now <- getZonedTime
atomically (writeTQueue inQueue (NetworkOpen now))
describeCertificates :: Connection -> IO [Text]
describeCertificates h =
do mbServer <- getPeerCertificate h
mbClient <- getClientCertificate h
cTxts <- certText "Server" mbServer
sTxts <- certText "Client" mbClient
pure (reverse (cTxts ++ sTxts))
certText :: String -> Maybe X509 -> IO [Text]
certText label mbX509 =
case mbX509 of
Nothing -> pure []
Just x509 ->
do str <- printX509 x509
fps <- getFingerprints x509
pure $ map Text.pack
$ ('\^B' : label)
: map colorize (lines str ++ fps)
where
colorize x@(' ':_) = x
colorize xs = "\^C07" ++ xs
getFingerprints :: X509 -> IO [String]
getFingerprints x509 =
do certDer <- writeDerX509 x509
spkiDer <- getPubKeyDer x509
xss <- for ["sha1", "sha256", "sha512"] $ \alg ->
do mb <- Digest.getDigestByName alg
pure $ case mb of
Nothing -> []
Just d ->
("Certificate " ++ alg ++ " fingerprint:")
: fingerprintLines (Digest.digestLBS d certDer)
++ ("SPKI " ++ alg ++ " fingerprint:")
: fingerprintLines (Digest.digestBS d spkiDer)
pure (concat xss)
fingerprintLines :: ByteString -> [String]
fingerprintLines
= map (" "++)
. chunksOf (16*3)
. formatDigest
formatDigest :: ByteString -> String
formatDigest
= intercalate ":"
. map showByte
. B.unpack
showByte :: Word8 -> String
showByte x
| x < 0x10 = '0' : showHex x ""
| otherwise = showHex x ""
sendLoop :: Connection -> TQueue ByteString -> RateLimit -> IO a
sendLoop h outQueue rate =
forever $
do msg <- atomically (readTQueue outQueue)
tickRateLimit rate
Hookup.send h msg
ircMaxMessageLength :: Int
ircMaxMessageLength = 512
receiveLoop :: Connection -> TQueue NetworkEvent -> IO ()
receiveLoop h inQueue =
do mb <- recvLine h (4*ircMaxMessageLength)
for_ mb $ \msg ->
do unless (B.null msg) $
do now <- getZonedTime
atomically $ writeTQueue inQueue
$ NetworkLine now msg
receiveLoop h inQueue