{-# LANGUAGE ScopedTypeVariables #-}

-- | This module provides convenience functions for interfacing with the @tls@ library.
--
-- This module is intended to be imported @qualified@, e.g.:
--
-- @
-- import           "Data.Connection"
-- import qualified "System.IO.Streams.TLS" as TLS
-- @
--
module System.IO.Streams.TLS
  ( TLSConnection
    -- * client
  , connect
  , connectTLS
  , tLsToConnection
    -- * server
  , accept
    -- * re-export
  , module Data.TLSSetting
  ) where

import qualified Control.Exception     as E
import           Data.Connection
import qualified Data.ByteString       as B
import qualified Data.ByteString.Char8 as BC
import           Data.TLSSetting
import qualified Network.Socket        as N
import           Network.TLS           (ClientParams, Context, ServerParams)
import qualified Network.TLS           as TLS
import qualified System.IO.Streams     as Stream
import qualified System.IO.Streams.TCP as TCP


-- | Type alias for a TLS connection: a 'Connection' whose extra info is the
-- pair of a 'TLS.Context' and an 'N.SockAddr'.
--
-- Normally you shouldn't use the 'TLS.Context' in 'connExtraInfo' directly.
--
type TLSConnection = Connection (TLS.Context, N.SockAddr)

-- | Make a 'Connection' from a 'Context'.
--
-- Reads are served by 'TLS.recvData' and writes by 'TLS.sendData'; closing the
-- resulting 'Connection' runs 'closeTLS' on the context. The context and the
-- socket address are kept as the connection's extra info.
--
-- Any exception raised while receiving, as well as an empty chunk, is reported
-- as end of input ('Nothing') instead of being rethrown.
--
tLsToConnection :: (Context, N.SockAddr)    -- ^ TLS connection / socket address pair
                -> IO TLSConnection
tLsToConnection (ctx, addr) = do
    is <- Stream.makeInputStream input
    return (Connection is write (closeTLS ctx) (ctx, addr))
  where
    -- Fetch the next chunk for the input stream; 'Nothing' marks end of input.
    -- The result is forced with '$!' so the emptiness check happens here, inside
    -- the 'E.catch', rather than lazily at the consumer.
    input = (do
        s <- TLS.recvData ctx
        return $! if B.null s then Nothing else Just s
        ) `E.catch` (\(_::E.SomeException) -> return Nothing)
    -- Send one chunk through the TLS context.
    write s = TLS.sendData ctx s

-- | Close a TLS 'Context' and its underlying socket.
--
-- 'TLS.bye' is attempted first, and 'TLS.contextClose' is run afterwards even
-- when 'TLS.bye' throws, so a failed goodbye does not leave the context open.
-- Every exception from either step is swallowed: sometimes the socket was
-- closed before 'TLS.bye', which surfaces as a 'Broken pipe' error.
--
closeTLS :: Context -> IO ()
closeTLS ctx = (TLS.bye ctx `E.finally` TLS.contextClose ctx)
    `E.catch` (\(_::E.SomeException) -> return ())

-- | Convenience function for initiating a TLS connection to the given
-- @('HostName', 'PortNumber')@ combination.
--
-- The server identification in the client parameters is overridden with the
-- subject name (or the host name when none is given) and the port number.
-- If creating the context or the handshake fails, the connection is closed
-- before the exception propagates.
--
-- This operation may throw 'TLS.TLSException' on failure.
--
connectTLS :: ClientParams         -- ^ check "Data.TLSSetting"
           -> Maybe String         -- ^ Optional certificate subject name, if set to 'Nothing'
                                   -- then we will try to verify 'HostName' as subject name
           -> N.HostName           -- ^ hostname to connect to
           -> N.PortNumber         -- ^ port number to connect to
           -> IO (Context, N.SockAddr)
connectTLS prms subname host port = do
    let subname' = maybe host id subname
        prms' = prms { TLS.clientServerIdentification = (subname', BC.pack (show port)) }
    (sock, addr) <- TCP.connectSocket host port
    -- If 'TLS.contextNew' itself throws there is no 'Context' for 'closeTLS'
    -- to release yet, so the raw socket is closed directly in that case.
    let newCtx = TLS.contextNew sock prms' `E.onException` N.close sock
    E.bracketOnError newCtx closeTLS $ \ ctx -> do
        TLS.handshake ctx
        return (ctx, addr)

-- | Connect to server using TLS and return a 'Connection'.
--
-- This is 'connectTLS' followed by 'tLsToConnection'.
--
connect :: ClientParams         -- ^ check "Data.TLSSetting"
        -> Maybe String         -- ^ Optional certificate subject name, if set to 'Nothing'
                                -- then we will try to verify 'HostName' as subject name
        -> N.HostName           -- ^ hostname to connect to
        -> N.PortNumber         -- ^ port number to connect to
        -> IO TLSConnection
connect prms subname host port = connectTLS prms subname host port >>= tLsToConnection

-- | Accept a new TLS connection from remote client with listening socket.
--
-- After accepting the socket, a server-side 'Context' is created, the TLS
-- handshake is performed, and the context is wrapped with 'tLsToConnection'.
-- If any of these steps fails, the accepted connection is closed before the
-- exception propagates.
--
-- This operation may throw 'TLS.TLSException' on failure.
--
accept :: ServerParams              -- ^ check "Data.TLSSetting"
       -> N.Socket                  -- ^ the listening 'Socket'
       -> IO TLSConnection
accept prms sock = do
    (sock', addr) <- N.accept sock
    -- If 'TLS.contextNew' itself throws there is no 'Context' for 'closeTLS'
    -- to release yet, so the accepted socket is closed directly in that case.
    let newCtx = TLS.contextNew sock' prms `E.onException` N.close sock'
    E.bracketOnError newCtx closeTLS $ \ ctx -> do
        TLS.handshake ctx
        tLsToConnection (ctx, addr)