{-# LANGUAGE ExistentialQuantification #-}
{-# LANGUAGE ScopedTypeVariables       #-}
{-# LANGUAGE TypeFamilies              #-}

-- | This module provides convenience functions for interfacing with @tls@.
--
-- This module is intended to be imported @qualified@, e.g.:
--
-- > import qualified Metro.TP.TLS as TLS
--
module Metro.TP.TLS
  ( TLS
    -- * re-export
  , module Metro.TP.TLSSetting
  , tlsConfig
  ) where

import           Control.Exception     (SomeException, bracketOnError, catch,
                                        onException)
import qualified Data.ByteString.Char8 as B (append, length, null)
import qualified Data.ByteString.Lazy  as BL (fromStrict)
import           Metro.Class           (Transport (..))
import           Metro.TP.TLSSetting
import           Network.TLS           (Context, TLSParams)
import qualified Network.TLS           as TLS


-- | A transport that speaks TLS on top of another 'Transport'.
--
-- It is represented by the @tls@ 'Context' built over that inner
-- transport.
newtype TLS = TLS Context

instance Transport TLS where
  -- A TLS transport is configured by the @tls@ parameters plus the
  -- configuration of the transport it runs on top of. Both types are
  -- hidden existentially, so any combination of them can be used.
  data TransportConfig TLS = forall params tp. (Transport tp, TLSParams params) => TLSConfig params (TransportConfig tp)

  -- | Open the underlying transport, wrap it in a TLS 'Context' and
  -- perform the handshake.
  --
  -- This operation may throw 'TLS.TLSException' on failure. Whatever
  -- was acquired before the failure is released before the exception
  -- propagates.
  newTransport (TLSConfig params config) = do
    transport <- newTransport config
    -- 'bracketOnError' below only protects the context once it exists,
    -- so guard its creation separately. Otherwise a failing
    -- 'TLS.contextNew' would leak the freshly opened transport.
    let acquire = TLS.contextNew (transportBackend transport) params
                    `onException` closeTransport transport
    -- If the handshake fails, the context is torn down with 'closeTLS'
    -- before the exception is rethrown.
    bracketOnError acquire closeTLS $ \ctx -> do
      TLS.handshake ctx
      return (TLS ctx)

  -- The requested size is ignored ('const'); 'TLS.recvData' decides how
  -- much decrypted data to hand back on each call.
  recvData (TLS ctx) = const $ TLS.recvData ctx

  -- 'TLS.sendData' takes a lazy ByteString, so the strict payload is
  -- converted first.
  sendData (TLS ctx) = TLS.sendData ctx . BL.fromStrict

  closeTransport (TLS ctx) = closeTLS ctx

-- | Adapt an arbitrary 'Transport' into a @tls@ 'TLS.Backend'.
--
-- * Flushing is a no-op.
-- * Closing and sending delegate directly to the wrapped transport.
-- * Receiving keeps reading until the requested number of bytes has been
--   collected, or the transport yields an empty chunk. In the empty-chunk
--   case whatever has arrived so far is returned.
transportBackend :: Transport tp => tp -> TLS.Backend
transportBackend transport = TLS.Backend
  { TLS.backendFlush = pure ()
  , TLS.backendClose = closeTransport transport
  , TLS.backendSend  = sendData transport
  , TLS.backendRecv  = recvUpTo
  }
  where
    -- Ask for the outstanding amount. Stop on an empty chunk, or once a
    -- single chunk covers the request. Otherwise prepend this chunk to
    -- the result of reading the remainder.
    recvUpTo wanted = do
      chunk <- recvData transport wanted
      let got = B.length chunk
      if B.null chunk || got >= wanted
        then pure chunk
        else B.append chunk <$> recvUpTo (wanted - got)


-- | Close a TLS 'Context' and the transport underneath it.
--
-- First 'TLS.bye' sends the closing alert. Then 'TLS.contextClose'
-- releases the context together with its backend.
--
-- Both steps are best-effort and never throw: the peer may already have
-- closed the connection, so 'TLS.bye' can fail with e.g. a "Broken pipe"
-- error. The steps are guarded separately so that a failing 'TLS.bye'
-- does not prevent 'TLS.contextClose' from running. If it were skipped,
-- the underlying transport would never be released.
closeTLS :: Context -> IO ()
closeTLS ctx = do
  ignoreErrors (TLS.bye ctx)
  ignoreErrors (TLS.contextClose ctx)
  where
    -- Deliberately swallow everything: closing is cleanup and callers
    -- (including 'bracketOnError' release actions) must not be disrupted.
    ignoreErrors :: IO () -> IO ()
    ignoreErrors act = act `catch` (\(_ :: SomeException) -> return ())


-- | Build the configuration for a TLS transport.
--
-- The first argument is the @tls@ parameters (client or server side).
-- The second is the configuration of the transport that TLS will run
-- over.
tlsConfig :: (Transport tp, TLSParams params) => params -> TransportConfig tp -> TransportConfig TLS
tlsConfig tlsParams innerConfig = TLSConfig tlsParams innerConfig