{-# LANGUAGE ScopedTypeVariables #-}

-- | This module provides convenience functions for interfacing @tls@.
--
-- This module is intended to be imported @qualified@, e.g.:
--
-- @
-- import           "Data.Connection"
-- import qualified "System.IO.Streams.TLS" as TLS
-- @
--
module System.IO.Streams.TLS
  ( TLSConnection
    -- * client
  , connect
  , connectTLS
  , tLsToConnection
    -- * server
  , accept
    -- * re-export
  , module Data.TLSSetting
  ) where

import qualified Control.Exception     as E
import           Data.Connection
import qualified Data.ByteString       as B
import qualified Data.ByteString.Char8 as BC
import           Data.TLSSetting
import qualified Network.Socket        as N
import           Network.TLS           (ClientParams, Context, ServerParams)
import qualified Network.TLS           as TLS
import qualified System.IO.Streams     as Stream
import qualified System.IO.Streams.TCP as TCP


-- | A 'Connection' whose extra info is the TLS 'Context' paired with the
-- remote socket address.
--
-- Normally you shouldn't use the 'TLS.Context' in 'connExtraInfo' directly.
--
type TLSConnection = Connection (Context, N.SockAddr)

-- | Make a 'Connection' from a 'Context'.
--
-- Reading from the resulting input stream yields 'Nothing' (end of stream)
-- when 'TLS.recvData' returns an empty chunk, or when receiving fails with a
-- synchronous exception. Asynchronous exceptions (e.g. from 'killThread' or
-- 'System.Timeout.timeout') are re-thrown, so they are not mistaken for end
-- of stream. Closing the connection runs 'closeTLS' on the context.
--
tLsToConnection :: (Context, N.SockAddr)    -- ^ TLS connection / socket address pair
                -> IO TLSConnection
tLsToConnection (ctx, addr) = do
    is <- Stream.makeInputStream input
    return (Connection is write (closeTLS ctx) (ctx, addr))
  where
    input :: IO (Maybe B.ByteString)
    input = recvChunk `E.catch` onRecvError

    -- An empty chunk from 'TLS.recvData' is treated as end of stream.
    recvChunk :: IO (Maybe B.ByteString)
    recvChunk = do
        s <- TLS.recvData ctx
        return $! if B.null s then Nothing else Just s

    -- Synchronous receive failures end the stream (previous behaviour);
    -- asynchronous exceptions must propagate to the caller instead of
    -- being silently converted into EOF.
    onRecvError :: E.SomeException -> IO (Maybe B.ByteString)
    onRecvError e = case E.fromException e of
        Just (ae :: E.SomeAsyncException) -> E.throwIO ae
        Nothing                           -> return Nothing

    write s = TLS.sendData ctx s

-- | Close a TLS 'Context' and its underlying socket.
--
-- Sends a close notification with 'TLS.bye', then releases the context with
-- 'TLS.contextClose'. Synchronous exceptions from either step are ignored,
-- because the socket may already be closed before 'TLS.bye' (giving a
-- 'Broken pipe' error). 'TLS.contextClose' always runs, so a failing
-- 'TLS.bye' no longer leaks the context. Asynchronous exceptions are
-- re-thrown rather than swallowed.
--
closeTLS :: Context -> IO ()
closeTLS ctx =
    ignoreSync (TLS.bye ctx) `E.finally` ignoreSync (TLS.contextClose ctx)
  where
    -- Best-effort runner: drop synchronous failures, propagate async ones.
    ignoreSync :: IO () -> IO ()
    ignoreSync act = act `E.catch` \e -> case E.fromException e of
        Just (ae :: E.SomeAsyncException) -> E.throwIO ae
        Nothing                           -> return ()

-- | Convenience function for initiating a TLS connection to the given
-- @('HostName', 'PortNumber')@ combination.
--
-- If creating the TLS context or the handshake fails, the TCP socket and
-- the context are released before the exception propagates.
--
-- This operation may throw 'TLS.TLSException' on failure.
--
connectTLS :: ClientParams         -- ^ check "Data.TLSSetting"
           -> Maybe String         -- ^ Optional certificate subject name, if set to 'Nothing'
                                   -- then we will try to verify 'HostName' as subject name
           -> N.HostName           -- ^ hostname to connect to
           -> N.PortNumber         -- ^ port number to connect to
           -> IO (Context, N.SockAddr)
connectTLS prms subname host port =
    -- The socket is guarded on its own, so a failing 'TLS.contextNew'
    -- cannot leak it.
    E.bracketOnError (TCP.connectSocket host port) (N.close . fst) $ \(sock, addr) ->
        -- On handshake failure 'closeTLS' closes the socket via the context,
        -- and the outer handler closes it again. NOTE(review): this relies on
        -- 'N.close' tolerating an already-closed socket, as network's close does.
        E.bracketOnError (TLS.contextNew sock prms') closeTLS $ \ctx -> do
            TLS.handshake ctx
            return (ctx, addr)
  where
    -- Certificate subject to verify: the explicit name if given, else the host.
    subname' = maybe host id subname
    -- The second component is the extra service-identification blob (the port).
    prms' = prms { TLS.clientServerIdentification = (subname', BC.pack (show port)) }

-- | Connect to a server using TLS and return a 'Connection'.
--
-- This runs 'connectTLS' and wraps the resulting context/address pair
-- with 'tLsToConnection'.
--
connect :: ClientParams         -- ^ check "Data.TLSSetting"
        -> Maybe String         -- ^ Optional certificate subject name, if set to 'Nothing'
                                -- then we will try to verify 'HostName' as subject name
        -> N.HostName           -- ^ hostname to connect to
        -> N.PortNumber         -- ^ port number to connect to
        -> IO TLSConnection
connect prms subname host port = do
    established <- connectTLS prms subname host port
    tLsToConnection established

-- | Accept a new TLS connection from a remote client on a listening socket.
--
-- If creating the TLS context or the handshake fails, the accepted socket
-- and the context are released before the exception propagates.
--
-- This operation may throw 'TLS.TLSException' on failure.
--
accept :: ServerParams              -- ^ check "Data.TLSSetting"
       -> N.Socket                  -- ^ the listening 'Socket'
       -> IO TLSConnection
accept prms sock =
    -- Guard the accepted socket itself, so a failing 'TLS.contextNew'
    -- cannot leak it.
    E.bracketOnError (N.accept sock) (N.close . fst) $ \(sock', addr) ->
        -- NOTE(review): on handshake failure the socket is closed by
        -- 'closeTLS' and again by the outer handler. This relies on 'N.close'
        -- tolerating an already-closed socket.
        E.bracketOnError (TLS.contextNew sock' prms) closeTLS $ \ctx -> do
            TLS.handshake ctx
            tLsToConnection (ctx, addr)