{-# LANGUAGE CPP #-}
{-# LANGUAGE OverloadedStrings #-}

#if (__GLASGOW_HASKELL__ >= 706)
{-# LANGUAGE RecursiveDo #-}
#else
{-# LANGUAGE DoRec #-}
#endif

{-|
Module      : MongoDB TLS
Copyright   : (c) Yuras Shumovich, 2016
License     : Apache 2.0
Maintainer  : Victor Denisov denisovenator@gmail.com
Stability   : experimental
Portability : POSIX

This module is for connecting to TLS-enabled mongodb servers.
ATTENTION!!! Be aware that this module is highly experimental and is
barely tested. 'connect' does not verify the server's identity: it accepts
any server certificate and only provides an encrypted connection using the
TLS protocol. Use 'connectWithTlsParams' with your own TLS client parameters
if you need certificate validation.
-}

module Database.MongoDB.Transport.Tls
( connect
, connectWithTlsParams
)
where

import Data.IORef
import qualified Data.ByteString as ByteString
import qualified Data.ByteString.Lazy as Lazy.ByteString
import Data.Default.Class (def)
import Control.Exception (bracketOnError)
import Control.Monad (when, unless)
import System.IO
import Database.MongoDB.Internal.Protocol (Pipe, newPipeWith)
import Database.MongoDB.Transport (Transport(Transport))
import qualified Database.MongoDB.Transport as T
import System.IO.Error (mkIOError, eofErrorType)
import Database.MongoDB.Internal.Network (connectTo, HostName, PortID)
import qualified Network.TLS as TLS
import qualified Network.TLS.Extra.Cipher as TLS
import Database.MongoDB.Query (access, slaveOk, retrieveServerData)

-- | Connect to a TLS-enabled mongodb server at the given host and port.
--
-- The client parameters start from 'TLS.defaultParamsClient' for @host@,
-- restrict the cipher list to 'TLS.ciphersuite_default', and install a
-- server-certificate hook that reports no failures for any certificate
-- chain.
--
-- WARNING: because of that hook the server's identity is NOT verified.
-- Use 'connectWithTlsParams' to supply parameters that validate it.
connect :: HostName -> PortID -> IO Pipe
connect host port = connectWithTlsParams insecureParams host port
  where
    insecureParams :: TLS.ClientParams
    insecureParams = baseParams
      { TLS.clientSupported = supported
      , TLS.clientHooks     = hooks
      }

    baseParams = TLS.defaultParamsClient host ""

    supported = def { TLS.supportedCiphers = TLS.ciphersuite_default }

    hooks = def { TLS.onServerCertificate = acceptAnyCertificate }

    -- An empty list of failure reasons means "certificate accepted",
    -- whatever store, cache, service id or chain we are handed.
    acceptAnyCertificate _ _ _ _ = return []

-- | Connect to mongodb over TLS using caller-supplied TLS client parameters.
--
-- Opens a plain connection with 'connectTo', runs the TLS handshake over
-- it, wraps the resulting context as a 'Transport', and builds a 'Pipe'
-- whose server data is fetched from the @admin@ database over that same
-- pipe. If any step throws, the underlying handle is closed
-- ('bracketOnError') and the exception propagates to the caller.
connectWithTlsParams :: TLS.ClientParams -> HostName -> PortID -> IO Pipe
connectWithTlsParams clientParams host port =
  bracketOnError (connectTo host port) hClose openPipeOver
  where
    openPipeOver :: Handle -> IO Pipe
    openPipeOver sock = do
      ctx <- TLS.contextNew sock clientParams
      TLS.handshake ctx
      transport <- tlsConnection ctx
      -- The pipe needs the server data, and the server data is obtained
      -- by querying through the pipe, hence the recursive binding.
      rec
        pipe <- newPipeWith serverData transport
        serverData <- access pipe slaveOk "admin" retrieveServerData
      return pipe

-- | Expose an established TLS 'TLS.Context' as a MongoDB 'Transport'.
--
-- 'TLS.recvData' hands back whatever decrypted chunk is available, which
-- need not match the number of bytes the caller asked for. 'T.read'
-- therefore keeps surplus bytes in an 'IORef' and serves them before
-- receiving more data from the context.
--
-- 'T.read' throws an EOF 'IOError' when the context yields an empty chunk
-- before the requested number of bytes has been collected. A request for
-- zero (or a negative number of) bytes returns an empty string at once,
-- without reading from the network.
--
-- NOTE(review): the read buffer is not guarded against concurrent 'T.read'
-- calls; presumably the 'Pipe' serializes reads — confirm.
tlsConnection :: TLS.Context -> IO Transport
tlsConnection ctx = do
  leftoverRef <- newIORef ByteString.empty
  return Transport
    { T.read  = readExactly leftoverRef
    , T.write = TLS.sendData ctx . Lazy.ByteString.fromStrict
    , T.flush = TLS.contextFlush ctx
    , T.close = TLS.contextClose ctx
    }
  where
    -- Serve buffered leftover bytes if there are any, otherwise receive a
    -- fresh chunk from the TLS context. The buffer is emptied either way.
    nextChunk :: IORef ByteString.ByteString -> IO ByteString.ByteString
    nextChunk ref = do
      leftover <- readIORef ref
      writeIORef ref ByteString.empty
      if ByteString.null leftover
        then TLS.recvData ctx
        else return leftover

    -- Read exactly @count@ bytes, pushing any surplus back into the buffer.
    readExactly
      :: IORef ByteString.ByteString -> Int -> IO ByteString.ByteString
    readExactly ref count = ByteString.concat . reverse <$> go [] count
      where
        -- Chunks are collected newest-first (O(1) cons) and joined once at
        -- the end, rather than repeatedly appending to a growing result.
        go :: [ByteString.ByteString] -> Int -> IO [ByteString.ByteString]
        go acc n
          | n <= 0 = return acc  -- nothing (more) requested: don't block
          | otherwise = do
              chunk <- nextChunk ref
              -- An empty chunk from the context is treated as end of stream.
              when (ByteString.null chunk) $
                ioError eof
              let len = ByteString.length chunk
              if len >= n
                then do
                  let (wanted, surplus) = ByteString.splitAt n chunk
                  -- Strict update so no thunk is left in the buffer.
                  unless (ByteString.null surplus) $
                    modifyIORef' ref (surplus <>)
                  return (wanted : acc)
                else go (chunk : acc) (n - len)

    eof :: IOError
    eof = mkIOError eofErrorType "Database.MongoDB.Transport" Nothing Nothing