{-# OPTIONS_HADDOCK hide #-}
{-# LANGUAGE DeriveDataTypeable #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE PackageImports #-}
{-# LANGUAGE FlexibleContexts #-}

module Network.Xmpp.Tls where

import                 Control.Applicative ((<$>))
import qualified       Control.Exception.Lifted as Ex
import                 Control.Monad
import                 Control.Monad.Except
import                 Control.Monad.State.Strict
import "crypto-random" Crypto.Random
import qualified       Data.ByteString as BS
import qualified       Data.ByteString.Char8 as BSC8
import qualified       Data.ByteString.Lazy as BL
import                 Data.Conduit
import                 Data.IORef
import                 Data.Monoid
import                 Data.XML.Types
import                 Network.DNS.Resolver (ResolvConf)
import                 Network.TLS
import                 Network.Xmpp.Stream
import                 Network.Xmpp.Types
import                 System.Log.Logger (debugM, errorM, infoM)
import                 System.X509

-- | Adapt a 'StreamHandle' into a TLS 'Backend'.
--
-- Sending, flushing and closing are forwarded to the handle; the result of
-- 'streamSend' is discarded. Receiving keeps reading from the handle until the
-- requested number of bytes has been collected or the handle delivers an
-- empty chunk, in which case whatever was collected so far is returned. A
-- 'Left' result from 'streamReceive' is rethrown as an exception.
mkBackend :: StreamHandle -> Backend
mkBackend con = Backend
    { backendSend  = void . streamSend con
    , backendRecv  = receiveExactly
    , backendFlush = streamFlush con
    , backendClose = streamClose con
    }
  where
    -- A request for zero bytes never touches the underlying handle.
    receiveExactly 0 = return BS.empty
    receiveExactly wanted = BS.concat <$> collect wanted

    -- Gather chunks until @remaining@ bytes were read or an empty chunk shows
    -- up.
    collect remaining = do
        chunk <- either Ex.throwIO return =<< streamReceive con remaining
        continue chunk (BS.length chunk)
      where
        continue chunk got
            | got == 0        = return []
            | got < remaining = (chunk :) <$> collect (remaining - got)
            | otherwise       = return [chunk]

-- | The @starttls@ element (namespace @urn:ietf:params:xml:ns:xmpp-tls@) that
-- is pushed to the server to request TLS negotiation. It carries neither
-- attributes nor child nodes.
starttlsE :: Element
starttlsE = Element qualifiedName noAttributes noChildren
  where
    qualifiedName = "{urn:ietf:params:xml:ns:xmpp-tls}starttls"
    noAttributes  = []
    noChildren    = []

-- | Checks for TLS support and runs the StartTLS procedure if applicable.
--
-- The stream must be in the 'Plain' state: a 'Closed' or 'Finished' stream
-- yields 'XmppNoStream' and a 'Secured' stream yields 'TlsStreamSecured'.
--
-- Whether TLS is negotiated depends on the configured 'tlsBehaviour' and on
-- the TLS entry of the stream features:
--
-- * 'RequireTls': negotiate if offered, otherwise fail with
--   'TlsNoServerSupport'.
--
-- * 'PreferTls': negotiate if offered, otherwise skip.
--
-- * 'PreferPlain': negotiate only if the feature is @Just True@ (presumably
--   "required by the server" -- confirm against 'streamFeaturesTls'),
--   otherwise skip.
--
-- * 'RefuseTls': fail with 'XmppOtherFailure' if the feature is @Just True@,
--   otherwise skip.
--
-- Failures are returned as 'Left'; exceptions raised while negotiating are
-- converted by 'wrapExceptions'.
tls :: Stream -> IO (Either XmppFailure ())
tls con = fmap join -- We can have Left values both from exceptions and the
                    -- error monad. Join unifies them into one error layer
          . wrapExceptions
          . flip withStream con
          . runExceptT $ do
    conf <- gets streamConfiguration
    sState <- gets streamConnectionState
    case sState of
        Plain -> return ()
        Closed -> do
            liftIO $ errorM "Pontarius.Xmpp.Tls" "The stream is closed."
            throwError XmppNoStream
        Finished -> do
            liftIO $ errorM "Pontarius.Xmpp.Tls" "The stream is finished."
            throwError XmppNoStream
        Secured -> do
            liftIO $ errorM "Pontarius.Xmpp.Tls" "The stream is already secured."
            throwError TlsStreamSecured
    features <- lift $ gets streamFeatures
    case (tlsBehaviour conf, streamFeaturesTls features) of
        (RequireTls  , Just _   ) -> startTls
        (RequireTls  , Nothing  ) -> throwError TlsNoServerSupport
        (PreferTls   , Just _   ) -> startTls
        (PreferTls   , Nothing  ) -> skipTls
        (PreferPlain , Just True) -> startTls
        (PreferPlain , _        ) -> skipTls
        (RefuseTls   , Just True) -> throwError XmppOtherFailure
        (RefuseTls   , _        ) -> skipTls
  where
    -- Leave the stream untouched; only record the decision in the log.
    skipTls = liftIO $ infoM "Pontarius.Xmpp.Tls" "Skipping TLS negotiation"
    -- Request StartTLS, run the handshake over the current handle, swap in
    -- the TLS-backed handle, restart the stream and mark it as secured.
    startTls = do
        liftIO $ infoM "Pontarius.Xmpp.Tls" "Running StartTLS"
        params <- gets $ tlsParams . streamConfiguration
        ExceptT $ pushElement starttlsE
        answer <- lift $ pullElement
        case answer of
            Left e -> throwError e
            Right (Element "{urn:ietf:params:xml:ns:xmpp-tls}proceed" [] []) ->
                return ()
            Right (Element "{urn:ietf:params:xml:ns:xmpp-tls}failure" _ _) -> do
                liftIO $ errorM "Pontarius.Xmpp" "startTls: TLS initiation failed."
                throwError XmppOtherFailure
            Right r -> do
                liftIO $ errorM "Pontarius.Xmpp.Tls" $
                            "Unexpected element: " ++ show r
                -- The server did not agree to proceed, so starting the
                -- handshake anyway would run TLS over a stream the peer
                -- still treats as plain XML. Abort instead of falling
                -- through.
                throwError XmppOtherFailure
        hand <- gets streamHandle
        (_raw, _snk, psh, recv, ctx) <- lift $ tlsinit params (mkBackend hand)
        -- All further traffic goes through the TLS context; closing sends
        -- the TLS close notification before closing the underlying handle.
        let newHand = StreamHandle { streamSend = catchPush . psh
                                   , streamReceive = wrapExceptions . recv
                                   , streamFlush = contextFlush ctx
                                   , streamClose = bye ctx >> streamClose hand
                                   }
        lift $ modify ( \x -> x {streamHandle = newHand})
        liftIO $ infoM "Pontarius.Xmpp.Tls" "Stream Secured."
        -- A failed restart is rethrown as an exception, which the outer
        -- 'wrapExceptions' turns back into a 'Left'.
        either (lift . Ex.throwIO) return =<< lift restartStream
        modify (\s -> s{streamConnectionState = Secured})

-- | Create a TLS 'Context' for the given client parameters on top of the
-- given 'Backend'. This is 'contextNew' with its arguments swapped.
client :: MonadIO m => ClientParams -> Backend -> m Context
client = flip contextNew

-- | Set up a TLS client session over the given 'Backend' and perform the
-- handshake.
--
-- The system certificate store is combined with the CA store already present
-- in the supplied parameters before the context is created.
--
-- The result is a five-tuple of
--
-- 1. a conduit source that endlessly yields (and debug-logs) received data,
--
-- 2. a conduit sink that sends every chunk it is given,
--
-- 3. a function sending a single strict 'BS.ByteString',
--
-- 4. a buffered read function (see 'mkReadBuffer'), and
--
-- 5. the TLS 'Context' itself.
tlsinit :: (MonadIO m, MonadIO m1) =>
        ClientParams
     -> Backend
     -> m ( ConduitT () BS.ByteString m1 ()
          , ConduitT BS.ByteString Void m1 ()
          , BS.ByteString -> IO ()
          , Int -> m1 BS.ByteString
          , Context
          )
tlsinit params backend = do
    liftIO $ debugM "Pontarius.Xmpp.Tls" "TLS with debug mode enabled."
    systemStore <- liftIO getSystemCertificateStore
    let shared       = clientShared params
        trustedStore = systemStore <> sharedCAStore shared
        params'      = params { clientShared =
                                    shared { sharedCAStore = trustedStore } }
    ctx <- client params' backend
    handshake ctx
    bufferedRead <- liftIO $ mkReadBuffer (recvData ctx)
    let -- The TLS layer sends lazy byte strings; wrap a single strict chunk.
        toLazy chunk = BL.fromChunks [chunk]
        source = forever $ do
            received <- liftIO $ recvData ctx
            liftIO $ debugM "Pontarius.Xmpp.Tls"
                            ("In :" ++ BSC8.unpack received)
            yield received
        sink = awaitForever (sendData ctx . toLazy)
    return ( source
           , sink
             -- Outgoing data is not logged here; the original note states
             -- that sendData already sends it to the debug output.
           , sendData ctx . toLazy
           , liftIO . bufferedRead
           , ctx
           )

-- | Turn a chunk-producing receive action into a size-limited read function.
--
-- The returned function yields at most @n@ bytes per call. Bytes left over
-- from a previously received chunk are kept in an internal buffer and served
-- first; the receive action is only run when that buffer is empty. A call may
-- therefore return fewer than @n@ bytes.
--
-- A request for zero (or fewer) bytes returns 'BS.empty' immediately without
-- running the receive action, so it can never block waiting for data.
mkReadBuffer :: IO BS.ByteString -> IO (Int -> IO BS.ByteString)
mkReadBuffer recv = do
    buffer <- newIORef BS.empty
    let read' n
            -- Nothing can be returned, so do not wait on the source.
            | n <= 0 = return BS.empty
            | otherwise = do
                pending <- readIORef buffer
                bs <- if BS.null pending then recv
                                         else return pending
                let (result, rest) = BS.splitAt n bs
                -- Keep the remainder for the next call.
                writeIORef buffer rest
                return result
    return read'

-- | Connect to an XMPP server and establish TLS on the connection before any
-- XMPP stream is started.
--
-- If the server identification in the supplied parameters has an empty host
-- name, the given host (with an empty second component) is used instead.
-- Fails with 'TcpConnectionFailure' when no connection could be made.
--
-- /NB/ RFC 6120 does not specify this method, but some servers, notably GCS,
-- seem to use it.
connectTls :: ResolvConf -- ^ Resolv conf to use (try 'defaultResolvConf' as a
                         -- default)
           -> ClientParams  -- ^ TLS parameters to use when securing the connection
           -> String     -- ^ Host to use when connecting (will be resolved
                         -- using SRV records)
           -> ExceptT XmppFailure IO StreamHandle
connectTls config params host = do
    rawHandle <- maybe (throwError TcpConnectionFailure) return
                     =<< connectSrv config host
    let plainHand = handleToStreamHandle rawHandle
        -- Only fill in the identification when the caller left it empty.
        identification = case clientServerIdentification params of
            (name, _) | null name -> (host, BS.empty)
            existing -> existing
        params' = params { clientServerIdentification = identification }
    (_raw, _snk, psh, recv, ctx) <- tlsinit params' (mkBackend plainHand)
    return StreamHandle
        { streamSend    = catchPush . psh
        , streamReceive = wrapExceptions . recv
        , streamFlush   = contextFlush ctx
          -- Send the TLS close notification, then close the plain handle.
        , streamClose   = bye ctx >> streamClose plainHand
        }

-- | Run an action and turn the exceptions it may raise into a 'Left'
-- 'XmppFailure'.
--
-- The handlers are tried in order: an I/O exception becomes
-- 'XmppIOException', the two TLS exception types become a 'TlsError'
-- (wrapping 'XmppTlsError' and 'XmppTlsException' respectively), and a
-- thrown 'XmppFailure' is returned unchanged. Any other exception
-- propagates.
wrapExceptions :: IO a -> IO (Either XmppFailure a)
wrapExceptions f = (Right <$> f) `Ex.catches` handlers
  where
    failed e    = return (Left e)
    tlsFailed e = failed (TlsError e)
    handlers =
        [ Ex.Handler (failed . XmppIOException)
        , Ex.Handler (tlsFailed . XmppTlsError)
        , Ex.Handler (tlsFailed . XmppTlsException)
        , Ex.Handler failed
        ]