module Network.Xmpp.Tls where
import Control.Concurrent.STM.TMVar
import qualified Control.Exception.Lifted as Ex
import Control.Monad
import Control.Monad.Error
import Control.Monad.State.Strict
import Crypto.Random.API
import qualified Data.ByteString as BS
import qualified Data.ByteString.Char8 as BSC8
import qualified Data.ByteString.Lazy as BL
import Data.Conduit
import qualified Data.Conduit.Binary as CB
import Data.IORef
import Data.Typeable
import Data.XML.Types
import Network.TLS
import Network.TLS.Extra
import Network.Xmpp.Stream
import Network.Xmpp.Types
import System.Log.Logger (debugM, errorM)
-- | Wrap a stream handle in the 'Backend' record consumed by the TLS library.
--
-- Each backend operation delegates to the corresponding field of the handle;
-- the value returned by 'streamSend' is discarded because 'backendSend'
-- yields unit.
mkBackend hdl = Backend
    { backendFlush = streamFlush hdl
    , backendClose = streamClose hdl
    , backendSend  = void . streamSend hdl
    , backendRecv  = streamReceive hdl
    }
-- | The @starttls@ request element, with no attributes and no child nodes.
-- The name literal is written in @{namespace}local-name@ form.
starttlsE :: Element
starttlsE = Element
    { elementName       = "{urn:ietf:params:xml:ns:xmpp-tls}starttls"
    , elementAttributes = []
    , elementNodes      = []
    }
-- | Attempt to secure a stream with TLS through the XMPP STARTTLS exchange.
--
-- The stream must currently be in the 'Plain' state: a 'Closed' stream yields
-- 'XmppNoStream' and an already 'Secured' stream yields 'TlsStreamSecured'.
-- Whether STARTTLS is actually performed is decided from the configured
-- 'tlsBehaviour' and the 'streamTls' entry of the stream features:
--
-- * 'RequireTls': negotiate when 'streamTls' is @Just _@, otherwise fail
--   with 'TlsNoServerSupport'.
--
-- * 'PreferTls': negotiate when 'streamTls' is @Just _@, otherwise do nothing.
--
-- * 'PreferPlain': negotiate only when 'streamTls' is @Just True@
--   (presumably "TLS required by the server" -- TODO confirm against the
--   feature parser).
--
-- * 'RefuseTls': fail with 'XmppOtherFailure' when 'streamTls' is
--   @Just True@, otherwise do nothing.
--
-- Exceptions of the type wrapped by 'TlsError' that escape the negotiation
-- are caught and returned as @Left (TlsError e)@.
tls :: Stream -> IO (Either XmppFailure ())
tls con = Ex.handle (return . Left . TlsError)
        . flip withStream con
        . runErrorT $ do
    conf <- gets streamConfiguration
    sState <- gets streamConnectionState
    -- Only a plain (open, not yet secured) stream may be upgraded.
    case sState of
        Plain -> return ()
        Closed -> do
            liftIO $ errorM "Pontarius.XMPP" "startTls: The stream is closed."
            throwError XmppNoStream
        Secured -> do
            liftIO $ errorM "Pontarius.XMPP" "startTls: The stream is already secured."
            throwError TlsStreamSecured
    features <- lift $ gets streamFeatures
    case (tlsBehaviour conf, streamTls features) of
        (RequireTls  , Just _   ) -> startTls
        (RequireTls  , Nothing  ) -> throwError TlsNoServerSupport
        (PreferTls   , Just _   ) -> startTls
        (PreferTls   , Nothing  ) -> return ()
        (PreferPlain , Just True) -> startTls
        (PreferPlain , _        ) -> return ()
        (RefuseTls   , Just True) -> throwError XmppOtherFailure
        (RefuseTls   , _        ) -> return ()
  where
    -- Run the STARTTLS exchange, swap the stream handle for a TLS-wrapped
    -- one, restart the stream and mark it as secured.
    startTls = do
        params <- gets $ tlsParams . streamConfiguration
        _ <- lift $ pushElement starttlsE
        answer <- lift pullElement
        -- The negotiation must be aborted with 'throwError' here: merely
        -- returning a 'Left' from a non-final statement would discard it and
        -- let the handshake proceed after a stream error or a <failure/>.
        case answer of
            Left e -> throwError e
            Right (Element "{urn:ietf:params:xml:ns:xmpp-tls}proceed" [] []) ->
                return ()
            Right (Element "{urn:ietf:params:xml:ns:xmpp-tls}failure" _ _) -> do
                liftIO $ errorM "Pontarius.XMPP" "startTls: TLS initiation failed."
                throwError XmppOtherFailure
            -- Any other reply is a protocol violation; fail cleanly instead
            -- of crashing on an incomplete pattern match.
            Right _ -> do
                liftIO $ errorM "Pontarius.XMPP"
                                "startTls: Unexpected element received."
                throwError XmppOtherFailure
        hand <- gets streamHandle
        (_src, _snk, psh, recv, ctx) <- lift $ tlsinit params (mkBackend hand)
        let newHand = StreamHandle { streamSend = catchPush . psh
                                   , streamReceive = recv
                                   , streamFlush = contextFlush ctx
                                     -- Say goodbye on the TLS layer before
                                     -- closing the underlying handle.
                                   , streamClose = bye ctx >> streamClose hand
                                   }
        lift $ modify (\x -> x { streamHandle = newHand })
        -- A failed stream restart is raised as an exception via 'Ex.throwIO'
        -- rather than reported through the error monad.
        either (lift . Ex.throwIO) return =<< lift restartStream
        modify (\s -> s { streamConnectionState = Secured })
-- | Create a new TLS context by handing the arguments to 'contextNew' in the
-- order it expects: backend first, then parameters, then the random
-- generator.
client params gen backend = contextNew backend params gen
-- | Default TLS parameters; an alias for 'defaultParamsClient'.
defaultParams = defaultParamsClient
-- | Set up a TLS session over the given backend and perform the handshake.
--
-- The result is a tuple of:
--
-- 1. a source that repeatedly calls 'recvData' and yields each chunk,
--
-- 2. a sink that passes every awaited chunk to 'sendData' until upstream is
--    exhausted,
--
-- 3. a function that sends a single strict 'BS.ByteString' via 'sendData',
--
-- 4. a reader built with 'mkReadBuffer' on top of 'recvData' that returns at
--    most the requested number of bytes, and
--
-- 5. the TLS 'Context' itself.
--
-- All data passing through the source, the sink and the send function is
-- logged at debug level under @Pontarius.Xmpp.TLS@.
tlsinit :: (MonadIO m, MonadIO m1) =>
           TLSParams
        -> Backend
        -> m ( Source m1 BS.ByteString
             , Sink BS.ByteString m1 ()
             , BS.ByteString -> IO ()
             , Int -> m1 BS.ByteString
             , Context
             )
tlsinit tlsParams backend = do
    liftIO $ debugM "Pontarius.Xmpp.TLS" "TLS with debug mode enabled."
    rng <- liftIO getSystemRandomGen
    ctx <- client tlsParams rng backend
    handshake ctx
    bufferedRecv <- liftIO $ mkReadBuffer (recvData ctx)
    let -- Endless producer of decrypted chunks.
        source = forever $ do
            chunk <- liftIO (recvData ctx)
            logBytes "in :" chunk
            yield chunk
        -- Consumer that stops once upstream has no more input.
        sink = await >>= maybe (return ()) forward
        -- Send first, log afterwards, then continue consuming.
        forward chunk = do
            sendData ctx (BL.fromChunks [chunk])
            logBytes "out :" chunk
            sink
        -- One-shot send: log first, then send.
        push bytes = do
            logBytes "out :" bytes
            sendData ctx (BL.fromChunks [bytes])
    return (source, sink, push, liftIO . bufferedRecv, ctx)
  where
    -- Debug-log a chunk of traffic behind a direction prefix.
    logBytes :: MonadIO io => String -> BS.ByteString -> io ()
    logBytes prefix bytes =
        liftIO $ debugM "Pontarius.Xmpp.TLS" (prefix ++ BSC8.unpack bytes)
-- | Turn a chunk-producing action into a reader that takes a byte count.
--
-- The returned function hands out at most @n@ bytes per call. Bytes left
-- over from a chunk are kept in an 'IORef' and served by later calls; the
-- underlying action is only run when that leftover buffer is empty. A call
-- may return fewer than @n@ bytes, since it never combines the leftover with
-- a freshly fetched chunk.
mkReadBuffer :: IO BS.ByteString -> IO (Int -> IO BS.ByteString)
mkReadBuffer fetch = do
    leftover <- newIORef BS.empty
    return $ \n -> do
        pending <- readIORef leftover
        chunk <- if BS.null pending
                     then fetch
                     else return pending
        let (out, keep) = BS.splitAt n chunk
        writeIORef leftover keep
        return out