{-# LANGUAGE CPP #-}
{-# LANGUAGE OverloadedStrings #-}
#if (__GLASGOW_HASKELL__ >= 706)
{-# LANGUAGE RecursiveDo #-}
#else
{-# LANGUAGE DoRec #-}
#endif
module Database.MongoDB.Transport.Tls
( connect
, connectWithTlsParams
)
where
import Data.IORef
import qualified Data.ByteString as ByteString
import qualified Data.ByteString.Lazy as Lazy.ByteString
import Data.Default.Class (def)
import Control.Exception (bracketOnError)
import Control.Monad (when, unless)
import System.IO
import Database.MongoDB.Internal.Protocol (Pipe, newPipeWith)
import Database.MongoDB.Transport (Transport(Transport))
import qualified Database.MongoDB.Transport as T
import System.IO.Error (mkIOError, eofErrorType)
import Database.MongoDB.Internal.Network (connectTo, HostName, PortID)
import qualified Network.TLS as TLS
import qualified Network.TLS.Extra.Cipher as TLS
import Database.MongoDB.Query (access, slaveOk, retrieveServerData)
-- | Open a TLS connection to a MongoDB server at @host@:@port@.
--
-- The client parameters are built from 'TLS.defaultParamsClient' for @host@
-- (with an empty @\"\"@ second argument), restricted to
-- 'TLS.ciphersuite_default', and then handed to 'connectWithTlsParams'.
--
-- WARNING: the 'TLS.onServerCertificate' hook installed here ignores its
-- arguments and always reports no failure reasons, so every server
-- certificate chain is accepted and the server's identity is NOT verified.
-- Call 'connectWithTlsParams' with your own 'TLS.ClientParams' when
-- certificate validation is required.
connect :: HostName -> PortID -> IO Pipe
connect host port =
  let acceptAnyCertificate _ _ _ _ = pure []
      ciphers = def { TLS.supportedCiphers = TLS.ciphersuite_default }
      hooks = def { TLS.onServerCertificate = acceptAnyCertificate }
      insecureParams = (TLS.defaultParamsClient host "")
        { TLS.clientSupported = ciphers
        , TLS.clientHooks = hooks
        }
  in connectWithTlsParams insecureParams host port
-- | Open a TLS connection to a MongoDB server using caller-supplied
-- 'TLS.ClientParams'.
--
-- A handle is obtained with 'connectTo'. A TLS context is created over it
-- and the handshake is performed. The resulting transport is then wrapped
-- in a 'Pipe'.
--
-- Building the pipe needs its server data, and that data is fetched by
-- running 'retrieveServerData' (against @admin@, with 'slaveOk') on the
-- same pipe. The two are therefore tied together in a @rec@ block.
--
-- If any step throws, 'bracketOnError' closes the handle and the exception
-- propagates. On success the handle is left open for the returned pipe.
connectWithTlsParams :: TLS.ClientParams -> HostName -> PortID -> IO Pipe
connectWithTlsParams clientParams host port =
  bracketOnError (connectTo host port) hClose startSession
  where
    -- Run the TLS handshake over the raw handle and build the pipe on top.
    startSession h = do
      ctx <- TLS.contextNew h clientParams
      TLS.handshake ctx
      transport <- tlsConnection ctx
      rec
        pipe <- newPipeWith serverData transport
        serverData <- access pipe slaveOk "admin" retrieveServerData
      pure pipe
-- | Wrap an established TLS 'TLS.Context' as a MongoDB 'Transport'.
--
-- * 'T.read' returns exactly the requested number of bytes. It serves them
--   from a pushback buffer first and then from 'TLS.recvData' (see
--   'readExactly').
-- * 'T.write' sends via 'TLS.sendData'.
-- * 'T.flush' delegates to 'TLS.contextFlush'.
-- * 'T.close' delegates to 'TLS.contextClose'.
--
-- NOTE(review): the pushback buffer is an unsynchronised 'IORef'. It
-- presumably assumes a single reader at a time; confirm against the Pipe
-- implementation.
tlsConnection :: TLS.Context -> IO Transport
tlsConnection ctx = do
  leftoverRef <- newIORef ByteString.empty
  return Transport
    { T.read  = readExactly (TLS.recvData ctx) leftoverRef
    , T.write = TLS.sendData ctx . Lazy.ByteString.fromStrict
    , T.flush = TLS.contextFlush ctx
    , T.close = TLS.contextClose ctx
    }

-- | Read exactly @count@ bytes.
--
-- Bytes are taken first from those buffered in @leftoverRef@ and then from
-- @recv@. Any surplus from the last chunk is pushed back into @leftoverRef@
-- for the next call.
--
-- A non-positive @count@ returns an empty string immediately, without
-- reading.
--
-- Throws an EOF 'IOError' if @recv@ yields an empty chunk before @count@
-- bytes have arrived.
readExactly
  :: IO ByteString.ByteString     -- ^ action fetching the next chunk
  -> IORef ByteString.ByteString  -- ^ pushback buffer of unread bytes
  -> Int                          -- ^ number of bytes wanted
  -> IO ByteString.ByteString
readExactly recv leftoverRef count = go [] count
  where
    -- Drain the pushback buffer if it has data; otherwise fetch a new chunk.
    readSome :: IO ByteString.ByteString
    readSome = do
      buffered <- readIORef leftoverRef
      writeIORef leftoverRef ByteString.empty
      if ByteString.null buffered then recv else return buffered

    -- Chunks are accumulated in reverse and concatenated once at the end.
    -- This avoids the quadratic cost of repeated left-nested appends.
    go :: [ByteString.ByteString] -> Int -> IO ByteString.ByteString
    go chunks n
      | n <= 0 = return (ByteString.concat (reverse chunks))
      | otherwise = do
          chunk <- readSome
          -- An empty chunk means the peer has no more data for us.
          when (ByteString.null chunk) $ ioError eof
          let len = ByteString.length chunk
          if len >= n
            then do
              let (wanted, surplus) = ByteString.splitAt n chunk
              unless (ByteString.null surplus) $
                modifyIORef' leftoverRef (surplus <>)
              return (ByteString.concat (reverse (wanted : chunks)))
            else go (chunk : chunks) (n - len)

    eof :: IOError
    eof = mkIOError eofErrorType "Database.MongoDB.Transport" Nothing Nothing