{-# LANGUAGE ScopedTypeVariables #-}
module System.IO.Streams.TLS
  ( TLSConnection
    
  , connect
  , connectTLS
  , tLsToConnection
    
  , accept
    
  , module Data.TLSSetting
  ) where
import qualified Control.Exception     as E
import           Data.Connection
import qualified Data.ByteString       as B
import qualified Data.ByteString.Char8 as BC
import           Data.TLSSetting
import qualified Network.Socket        as N
import           Network.TLS           (ClientParams, Context, ServerParams)
import qualified Network.TLS           as TLS
import qualified System.IO.Streams     as Stream
import qualified System.IO.Streams.TCP as TCP
-- | A 'Connection' running over TLS. Its extra-info field holds the TLS
-- context together with the socket address obtained when the underlying TCP
-- connection was established ('connectTLS') or accepted ('accept').
type TLSConnection = Connection (TLS.Context, N.SockAddr)
-- | Wrap an established TLS context and its address as a 'TLSConnection'.
--
-- * Reads are served by 'TLS.recvData'. An empty chunk is reported to the
--   input stream as end-of-input. A synchronous exception raised while
--   receiving (e.g. the peer dropped the socket) is also reported as
--   end-of-input.
-- * Asynchronous exceptions (e.g. from 'Control.Concurrent.killThread' or
--   'System.Timeout.timeout') are re-thrown rather than being silently
--   converted into EOF.
-- * Writes go through 'TLS.sendData'; closing runs 'closeTLS'.
tLsToConnection :: (Context, N.SockAddr)    -- ^ TLS context and socket address
                -> IO TLSConnection
tLsToConnection (ctx, addr) = do
    is <- Stream.makeInputStream input
    return (Connection is (TLS.sendData ctx) (closeTLS ctx) (ctx, addr))
  where
    input :: IO (Maybe B.ByteString)
    input = (do
        s <- TLS.recvData ctx
        return $! if B.null s then Nothing else Just s
        ) `E.catch` eofOnSyncError

    -- A failed read is treated as the end of the stream, but asynchronous
    -- exceptions must propagate so that cancellation and timeouts still work.
    eofOnSyncError :: E.SomeException -> IO (Maybe B.ByteString)
    eofOnSyncError e = case E.fromException e of
        Just (asyncEx :: E.SomeAsyncException) -> E.throwIO asyncEx
        Nothing                                -> return Nothing
-- | Shut a TLS context down on a best-effort basis.
--
-- 'TLS.bye' is attempted first to notify the peer. It may fail if the
-- socket was already closed by the other side (e.g. a broken pipe), and such
-- a failure is ignored. 'TLS.contextClose' then runs unconditionally (via
-- 'E.finally'), so the underlying transport is released even when 'TLS.bye'
-- failed.
--
-- Synchronous exceptions from either step are swallowed; asynchronous
-- exceptions propagate.
closeTLS :: Context -> IO ()
closeTLS ctx =
    ignoreSync (TLS.bye ctx) `E.finally` ignoreSync (TLS.contextClose ctx)
  where
    ignoreSync :: IO () -> IO ()
    ignoreSync act = act `E.catch` (\(e :: E.SomeException) ->
        case E.fromException e of
            Just (asyncEx :: E.SomeAsyncException) -> E.throwIO asyncEx
            Nothing                                -> return ())
-- | Open a TCP connection to @host:port@ and run the TLS client handshake.
--
-- The server identification ('TLS.clientServerIdentification') is set to
-- the optional subject name, falling back to @host@ when it is 'Nothing'.
-- The port (rendered with 'show') is used as the service identifier.
--
-- On failure nothing is leaked and the exception is re-thrown:
--
-- * if creating the TLS context fails, the TCP socket is closed;
-- * if the handshake fails, the context is closed with 'closeTLS'.
connectTLS :: ClientParams         -- ^ TLS client parameters
           -> Maybe String         -- ^ optional subject name to use instead
                                   --   of @host@ for server identification
           -> N.HostName           -- ^ host to connect to
           -> N.PortNumber         -- ^ port to connect to
           -> IO (Context, N.SockAddr)
connectTLS prms subname host port = do
    let subname' = maybe host id subname
        prms'    = prms { TLS.clientServerIdentification =
                            (subname', BC.pack (show port)) }
    (sock, addr) <- TCP.connectSocket host port
    -- Until the context owns the socket, we must close it ourselves.
    ctx <- TLS.contextNew sock prms' `E.onException` N.close sock
    (TLS.handshake ctx >> return (ctx, addr)) `E.onException` closeTLS ctx
-- | Connect to a TLS server and expose the session as a 'TLSConnection'.
--
-- This is 'connectTLS' followed by 'tLsToConnection'.
connect :: ClientParams         -- ^ TLS client parameters
        -> Maybe String         -- ^ optional subject name to use instead
                                --   of the host name for server identification
        -> N.HostName           -- ^ host to connect to
        -> N.PortNumber         -- ^ port to connect to
        -> IO TLSConnection
connect prms subname host port =
    tLsToConnection =<< connectTLS prms subname host port
-- | Accept one client on a listening socket and run the TLS server handshake.
--
-- The returned connection carries the peer address reported by 'N.accept'.
-- On failure nothing is leaked and the exception is re-thrown:
--
-- * if creating the TLS context fails, the accepted socket is closed;
-- * if the handshake fails, the context is closed with 'closeTLS'.
accept :: ServerParams              -- ^ TLS server parameters
       -> N.Socket                  -- ^ listening socket
       -> IO TLSConnection
accept prms sock = do
    (sock', addr) <- N.accept sock
    -- Until the context owns the accepted socket, we must close it ourselves.
    ctx <- TLS.contextNew sock' prms `E.onException` N.close sock'
    (TLS.handshake ctx >> tLsToConnection (ctx, addr))
        `E.onException` closeTLS ctx