{-
Copyright (C) 2009 John Millikin

This program is free software: you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
the Free Software Foundation, either version 3 of the License, or
any later version.

This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
GNU General Public License for more details.

You should have received a copy of the GNU General Public License
along with this program.  If not, see <http://www.gnu.org/licenses/>.
-}

-- | XMPP stream setup and stanza-level I/O over a plain or TLS-wrapped
-- handle.
module Network.Protocol.XMPP.Stream
    ( Stream
        ( streamLanguage
        , streamVersion
        , streamFeatures
        )
    , StreamFeature
        ( FeatureStartTLS
        , FeatureSASL
        , FeatureRegister
        , FeatureBind
        , FeatureSession
        )
    , beginStream
    , restartStream
    , getTree
    , putTree
    ) where

import Control.Monad (unless, when)
import qualified System.IO as IO
import Data.AssocList (lookupDef)
import Data.Char (toUpper)

-- XML Parsing
import Text.XML.HXT.Arrow ((>>>))
import qualified Text.XML.HXT.Arrow as A
import qualified Text.XML.HXT.DOM.Interface as DOM
import qualified Text.XML.HXT.DOM.XmlNode as XN
import qualified Text.XML.LibXML.SAX as SAX

-- TLS support
import qualified Network.GnuTLS as GnuTLS
import Foreign (allocaBytes)
import Foreign.C (peekCAStringLen)

import Network.Protocol.XMPP.JID (JID, jidFormat)
import qualified Network.Protocol.XMPP.Util as Util

-- | Highest XMPP protocol version this module declares support for.
maxXMPPVersion :: XMPPVersion
maxXMPPVersion = XMPPVersion 1 0

-- | An open XMPP stream: the transport, the JID it was opened with, the
-- incremental SAX parser fed from the transport, and the attributes and
-- features the server announced when the stream was opened.
data Stream = Stream
    { streamHandle   :: Handle
    , streamJID      :: JID
    , streamParser   :: SAX.Parser
    , streamLanguage :: XMLLanguage
    , streamVersion  :: XMPPVersion
    , streamFeatures :: [StreamFeature]
    }

-- | A single child of the server's @\<stream:features\>@ element.
data StreamFeature
    = FeatureStartTLS Bool            -- ^ STARTTLS; 'True' if @\<required/\>@ is present
    | FeatureSASL [String]            -- ^ SASL mechanism names, upper-cased
    | FeatureRegister
    | FeatureBind
    | FeatureSession
    | FeatureUnknown DOM.XmlTree      -- ^ any feature not recognised above
    deriving (Show, Eq)

newtype XMLLanguage = XMLLanguage String
    deriving (Show, Eq)

data XMPPVersion = XMPPVersion Int Int
    deriving (Show, Eq)

-- | Transport underlying a 'Stream': either the raw handle, or the raw
-- handle together with the GnuTLS client session layered over it.
data Handle
    = PlainHandle IO.Handle
    | SecureHandle IO.Handle (GnuTLS.Session GnuTLS.Client)

------------------------------------------------------------------------------

-- | Re-open the stream on the same transport and JID, with a fresh parser.
-- The server's new stream attributes and features are read again.
restartStream :: Stream -> IO Stream
restartStream s = beginStream' (streamJID s) (streamHandle s)

-- | Open a stream on @rawHandle@, negotiate STARTTLS, perform the TLS
-- handshake, and open a second stream over the encrypted channel.
--
-- Throws an 'IOError' if the server does not answer the STARTTLS request
-- with @\<proceed/\>@.
beginStream :: JID -> IO.Handle -> IO Stream
beginStream jid rawHandle = do
    IO.hSetBuffering rawHandle IO.NoBuffering
    plainStream <- beginStream' jid (PlainHandle rawHandle)

    putTree plainStream $ Util.mkElement ("", "starttls")
        [("", "xmlns", "urn:ietf:params:xml:ns:xmpp-tls")]
        []
    -- The server answers either <proceed/> or <failure/>; only the former
    -- permits a TLS handshake on this connection.
    reply <- getTree plainStream
    unless (isProceed reply) $
        ioError (userError "beginStream: server did not accept STARTTLS")

    session <- GnuTLS.tlsClient
        [ GnuTLS.handle GnuTLS.:= rawHandle
        , GnuTLS.priorities GnuTLS.:= [GnuTLS.CrtX509]
        , GnuTLS.credentials GnuTLS.:= GnuTLS.certificateCredentials
        ]
    GnuTLS.handshake session
    beginStream' jid (SecureHandle rawHandle session)
  where
    proceedName = Util.mkQName "urn:ietf:params:xml:ns:xmpp-tls" "proceed"
    -- Same tree shape as parseFeatures: the element is a child of the
    -- root returned by getTree.
    isProceed t = not . null $
        A.runLA (A.getChildren >>> A.hasQName proceedName) t

-- | Write the client's stream header, then read the server's opening
-- @\<stream:stream\>@ tag and its @\<stream:features\>@ element.
beginStream' :: JID -> Handle -> IO Stream
beginStream' jid h = do
    -- Since only the opening tag should be written, normal XML
    -- serialization cannot be used. Be careful to escape any embedded
    -- attributes.
    let xmlHeader =
            "<?xml version='1.0'?>\n"
            ++ "<stream:stream xmlns='jabber:client'"
            ++ " to='" ++ escapeAttr (jidFormat jid) ++ "'"
            ++ " version='1.0'"
            ++ " xmlns:stream='http://etherx.jabber.org/streams'>"
    parser <- SAX.mkParser
    hPutStr h xmlHeader
    initialEvents <- readEventsUntil startOfStream h parser
    featureTree <- getTree' h parser

    -- readEventsUntil only returns once startOfStream matched, so the list
    -- is non-empty and its last element is the <stream:stream> event.
    let startStreamEvent = last initialEvents
        (language, version) = parseStartStream startStreamEvent
        features = parseFeatures featureTree

    return $ Stream h jid parser language version features
  where
    streamName = Util.mkQName "http://etherx.jabber.org/streams" "stream"

    startOfStream depth event = case (depth, event) of
        (1, SAX.BeginElement elemName _) ->
            streamName == Util.convertQName elemName
        _ -> False

    -- Escape a value for use inside a single-quoted XML attribute.
    escapeAttr = concatMap escapeChar
    escapeChar c = case c of
        '&'  -> "&amp;"
        '<'  -> "&lt;"
        '>'  -> "&gt;"
        '\'' -> "&apos;"
        '"'  -> "&quot;"
        _    -> [c]

-- | Extract the language and version from the server's stream-start event.
-- Currently always returns English and version 1.0.
parseStartStream :: SAX.Event -> (XMLLanguage, XMPPVersion)
parseStartStream _ = (XMLLanguage "en", XMPPVersion 1 0) -- TODO: read xml:lang / version attributes

-- | Parse the children of a @\<stream:features\>@ element found under the
-- given root tree.
parseFeatures :: DOM.XmlTree -> [StreamFeature]
parseFeatures t =
    A.runLA (A.getChildren
        >>> A.hasQName featuresName
        >>> A.getChildren
        >>> A.arrL (\t' -> [parseFeature t'])) t
  where
    featuresName = Util.mkQName "http://etherx.jabber.org/streams" "features"

-- | Classify one feature element by its (namespace, local name) pair,
-- falling back to 'FeatureUnknown'.
parseFeature :: DOM.XmlTree -> StreamFeature
parseFeature t = lookupDef FeatureUnknown qname
    [ (("urn:ietf:params:xml:ns:xmpp-tls", "starttls"), parseFeatureTLS)
    , (("urn:ietf:params:xml:ns:xmpp-sasl", "mechanisms"), parseFeatureSASL)
    , (("http://jabber.org/features/iq-register", "register"), const FeatureRegister)
    , (("urn:ietf:params:xml:ns:xmpp-bind", "bind"), const FeatureBind)
    , (("urn:ietf:params:xml:ns:xmpp-session", "session"), const FeatureSession)
    ] t
  where
    qname = maybe ("", "") (\n -> (DOM.namespaceUri n, DOM.localPart n)) (XN.getName t)

-- | Parse a @\<starttls\>@ feature; it is marked required when it has a
-- @\<required/\>@ child in the TLS namespace.
parseFeatureTLS :: DOM.XmlTree -> StreamFeature
parseFeatureTLS t = FeatureStartTLS required
  where
    requiredName = Util.mkQName "urn:ietf:params:xml:ns:xmpp-tls" "required"
    required = not . null $
        A.runLA (A.getChildren >>> A.hasQName requiredName) t

-- | Parse a @\<mechanisms\>@ feature into the upper-cased text of each
-- @\<mechanism\>@ child.
parseFeatureSASL :: DOM.XmlTree -> StreamFeature
parseFeatureSASL t =
    let mechName = Util.mkQName "urn:ietf:params:xml:ns:xmpp-sasl" "mechanism"
        mechanisms = A.runLA (
            A.getChildren
            >>> A.hasQName mechName
            >>> A.getChildren
            >>> A.getText) t
    in FeatureSASL $ map (map toUpper) mechanisms

-------------------------------------------------------------------------------

-- | Read the next complete top-level element from the stream.
getTree :: Stream -> IO DOM.XmlTree
getTree s = getTree' (streamHandle s) (streamParser s)

-- | Read events until an element that opened at depth 0 is closed, and
-- build a tree from them.
getTree' :: Handle -> SAX.Parser -> IO DOM.XmlTree
getTree' h p = do
    events <- readEventsUntil finished h p
    return $ Util.eventsToTree events
  where
    finished 0 (SAX.EndElement _) = True
    finished _ _ = False

-- | Serialize a tree (without an XML declaration) and send it.
putTree :: Stream -> DOM.XmlTree -> IO ()
putTree s t = do
    let root = XN.mkRoot [] [t]
        h = streamHandle s
    rendered <- A.runX (A.constA root >>> A.writeDocumentToString
        [ (A.a_no_xml_pi, "1")
        ])
    case rendered of
        [text] -> hPutStr h text
        _ -> ioError (userError "putTree: XML serialization did not produce exactly one document")

-------------------------------------------------------------------------------

-- | Feed the parser one character at a time from the handle, collecting
-- SAX events until @done depth event@ holds. Depth starts at 0 on each call.
readEventsUntil :: (Int -> SAX.Event -> Bool) -> Handle -> SAX.Parser -> IO [SAX.Event]
readEventsUntil done h parser = readEventsUntil' done 0 [] $ do
    char <- hGetChar h
    SAX.parse parser [char] False

-- | Loop driver for 'readEventsUntil'. The accumulator is kept newest-first
-- so each event is added in O(1); it is reversed once on completion.
readEventsUntil'
    :: (Int -> SAX.Event -> Bool)
    -> Int
    -> [SAX.Event]
    -> IO [SAX.Event]
    -> IO [SAX.Event]
readEventsUntil' done depth accum getEvents = do
    events <- getEvents
    let (done', depth', accum') = readEventsStep done events depth accum
    if done'
        then return (reverse accum')
        else readEventsUntil' done depth' accum' getEvents

-- | Process one batch of events, tracking element depth and prepending each
-- event to the (newest-first) accumulator. Stops at the first event for
-- which @done@ holds.
--
-- NOTE(review): any events in the batch after the terminating one are
-- discarded, as in the original implementation.
readEventsStep
    :: (Int -> SAX.Event -> Bool)
    -> [SAX.Event]
    -> Int
    -> [SAX.Event]
    -> (Bool, Int, [SAX.Event])
readEventsStep _ [] depth accum = (False, depth, accum)
readEventsStep done (e:es) depth accum =
    let depth' = depth + case e of
            SAX.BeginElement _ _ -> 1
            SAX.EndElement _ -> (-1)
            _ -> 0
        accum' = e : accum
    in if done depth' e
        then (True, depth', accum')
        else readEventsStep done es depth' accum'

-------------------------------------------------------------------------------

-- | Write a string to the transport, through TLS when the stream is secure.
hPutStr :: Handle -> String -> IO ()
hPutStr (PlainHandle h) = IO.hPutStr h
hPutStr (SecureHandle _ session) = GnuTLS.tlsSendString session

-- | Read a single character from the transport. For a secure handle, wait
-- on the raw handle when GnuTLS has no buffered data, then receive one byte.
--
-- Throws an 'IOError' if the TLS receive returns no data or an error code.
hGetChar :: Handle -> IO Char
hGetChar (PlainHandle h) = IO.hGetChar h
hGetChar (SecureHandle h session) = allocaBytes 1 $ \ptr -> do
    pending <- GnuTLS.tlsCheckPending session
    when (pending == 0) $ do
        _ <- IO.hWaitForInput h (-1)
        return ()
    len <- GnuTLS.tlsRecv session ptr 1
    -- 0 means the peer closed the connection; negative is an error code.
    when (len <= 0) $
        ioError (userError "hGetChar: TLS connection closed or receive failed")
    [char] <- peekCAStringLen (ptr, len)
    return char