-- | TLS-secured TCP servers and clients layered on the plain
-- "Data.Conduit.Network" helpers.
module Data.Conduit.Network.TLS
(
-- * Common
ApplicationStartTLS
-- * Server
, TLSConfig
, tlsConfigBS
, tlsConfig
, tlsConfigChainBS
, tlsConfigChain
, tlsHost
, tlsPort
, tlsNeedLocalAddr
, tlsAppData
, runTCPServerTLS
, runGeneralTCPServerTLS
, runTCPServerStartTLS
-- * Client
, TLSClientConfig
, tlsClientConfig
, runTLSClient
, runTLSClientStartTLS
, tlsClientPort
, tlsClientHost
, tlsClientUseTLS
, tlsClientTLSSettings
, tlsClientSockSettings
, tlsClientConnectionContext
) where
import Control.Applicative ((<$>), (<*>))
import Control.Monad (forever, void)
import qualified Data.ByteString.Lazy as L
import qualified Network.TLS as TLS
import Data.Conduit.Network (sinkSocket, runTCPServerWithHandle, serverSettings, sourceSocket)
import Data.Streaming.Network.Internal (AppData (..), HostPreference)
import Data.Streaming.Network (ConnectionHandle, safeRecv)
import Data.Conduit.Network.TLS.Internal
import Data.Conduit (yield, awaitForever, Producer, Consumer)
import qualified Data.Conduit.List as CL
import Network.Socket (SockAddr (SockAddrInet), sClose)
import Network.Socket.ByteString (sendAll)
import Control.Exception (bracket)
import Control.Monad.Trans.Class (lift)
import Control.Monad.IO.Class (liftIO, MonadIO)
import qualified Network.TLS.Extra as TLSExtra
import Network.Socket (Socket)
import qualified Data.ByteString as S
import qualified Data.ByteString.Char8 as S8
import qualified Crypto.Random.AESCtr
import qualified Network.Connection as NC
import Control.Monad.Trans.Control
import Data.Default
-- | Build certificate data backed by files on disk.
--
-- Nothing is read here: the resulting 'TlsCertData' only stores the 'IO'
-- actions that load the certificate, the chain certificates and the key.
makeCertDataPath :: FilePath -> [FilePath] -> FilePath -> TlsCertData
makeCertDataPath certFile chainFiles keyFile =
    TlsCertData loadCert loadChain loadKey
  where
    loadCert  = S.readFile certFile
    loadChain = mapM S.readFile chainFiles
    loadKey   = S.readFile keyFile
-- | Build certificate data from material that is already in memory.
--
-- Each 'S.ByteString' is wrapped in a trivial action so that it fits the same
-- 'TlsCertData' shape as the file-backed variant.
makeCertDataBS :: S.ByteString -> [S.ByteString] -> S.ByteString -> TlsCertData
makeCertDataBS cert chain key =
    TlsCertData
        (return cert)
        (return chain)
        (return key)
-- | Server configuration that loads its certificate and key from files and
-- uses no chain certificates.
tlsConfig :: HostPreference
          -> Int       -- ^ port
          -> FilePath  -- ^ certificate file
          -> FilePath  -- ^ key file
          -> TLSConfig
tlsConfig host port certFile keyFile =
    tlsConfigChain host port certFile [] keyFile
-- | Server configuration that takes its certificate and key as in-memory
-- 'S.ByteString's and uses no chain certificates.
tlsConfigBS :: HostPreference
            -> Int           -- ^ port
            -> S.ByteString  -- ^ certificate
            -> S.ByteString  -- ^ key
            -> TLSConfig
tlsConfigBS host port certBytes keyBytes =
    tlsConfigChainBS host port certBytes [] keyBytes
-- | Server configuration that loads the certificate, the chain certificates
-- and the key from files.
tlsConfigChain :: HostPreference
               -> Int         -- ^ port
               -> FilePath    -- ^ certificate file
               -> [FilePath]  -- ^ chain certificate files
               -> FilePath    -- ^ key file
               -> TLSConfig
tlsConfigChain host port certFile chainFiles keyFile =
    -- NOTE(review): the trailing 'False' is presumably the 'tlsNeedLocalAddr'
    -- flag; confirm against the 'TLSConfig' definition in the Internal module.
    TLSConfig host port certData False
  where
    certData = makeCertDataPath certFile chainFiles keyFile
-- | Server configuration that takes the certificate, the chain certificates
-- and the key as in-memory 'S.ByteString's.
tlsConfigChainBS :: HostPreference
                 -> Int             -- ^ port
                 -> S.ByteString    -- ^ certificate
                 -> [S.ByteString]  -- ^ chain certificates
                 -> S.ByteString    -- ^ key
                 -> TLSConfig
tlsConfigChainBS host port certBytes chainBytes keyBytes =
    -- NOTE(review): the trailing 'False' is presumably the 'tlsNeedLocalAddr'
    -- flag; confirm against the 'TLSConfig' definition in the Internal module.
    TLSConfig host port certData False
  where
    certData = makeCertDataBS certBytes chainBytes keyBytes
-- | Create a server-side TLS context on top of an accepted socket and run
-- the handshake, returning the established context.
--
-- The backend never flushes or closes the socket itself; it only sends with
-- 'sendAll' and receives with 'recvExact'.
serverHandshake :: Socket -> TLS.Credentials -> IO (TLS.Context)
serverHandshake socket creds = do
#if !MIN_VERSION_tls(1,3,0)
    -- Older tls versions take an explicit random generator.
    gen <- Crypto.Random.AESCtr.makeSystem
#endif
    ctx <- TLS.contextNew
        backend
        params
#if !MIN_VERSION_tls(1,3,0)
        gen
#endif
    TLS.handshake ctx
    return ctx
  where
    backend = TLS.Backend
        { TLS.backendFlush = return ()
        , TLS.backendClose = return ()
        , TLS.backendSend  = sendAll socket
        , TLS.backendRecv  = recvExact socket
        }

    -- Client certificates are not requested.
    -- NOTE(review): SSL3 is still offered here; it is deprecated (RFC 7568),
    -- so consider dropping it once compatibility requirements are confirmed.
    params = def
        { TLS.serverWantClientCert = False
        , TLS.serverSupported = def
            { TLS.supportedCiphers  = ciphers
            , TLS.supportedVersions = [TLS.SSL3,TLS.TLS10,TLS.TLS11,TLS.TLS12]
            }
        , TLS.serverShared = def
            { TLS.sharedCredentials = creds
            }
        }
-- | Run a TCP server in which every accepted connection is wrapped in TLS
-- before the application sees it.
--
-- Credentials are loaded once, before the server starts. For each
-- connection the handshake is performed, the application runs on the
-- encrypted channel, and the TLS session is then terminated with 'TLS.bye'.
runTCPServerTLS :: TLSConfig -> (AppData -> IO ()) -> IO ()
runTCPServerTLS TLSConfig{..} app = do
    creds <- readCreds tlsCertData
    runTCPServerWithHandle (serverSettings tlsPort tlsHost) (secured creds)
  where
    -- Per-connection handler: handshake, run the application, say goodbye.
    secured creds socket addr mlocal = do
        ctx <- serverHandshake socket creds
        app (tlsAppData ctx addr mlocal)
        TLS.bye ctx
-- | An application that starts on an unencrypted connection and may upgrade
-- it to TLS later.
--
-- It receives the plain-text 'AppData' together with a function that takes a
-- continuation and runs that continuation on 'AppData' for the secured
-- connection.
type ApplicationStartTLS = (AppData, (AppData -> IO ()) -> IO ()) -> IO ()
-- | Like 'runTCPServerTLS', but the connection handler may live in any monad
-- that can be run on top of 'IO'.
--
-- The result of running the handler in the base monad (including any monadic
-- state it carries) is discarded for every connection.
runGeneralTCPServerTLS :: MonadBaseControl IO m => TLSConfig -> (AppData -> m ()) -> m ()
runGeneralTCPServerTLS config app =
    liftBaseWith $ \runInBase ->
        runTCPServerTLS config (\appData -> void (runInBase (app appData)))
-- | Run a TCP server whose connections start unencrypted and can be upgraded
-- to TLS on demand.
--
-- The application is handed the plain-socket 'AppData' and an upgrade
-- function. Calling the upgrade function performs the server handshake on
-- the same socket, runs the given continuation on the encrypted channel and
-- finally terminates the TLS session with 'TLS.bye'.
runTCPServerStartTLS :: TLSConfig -> ApplicationStartTLS -> IO ()
runTCPServerStartTLS TLSConfig{..} app = do
    creds <- readCreds tlsCertData
    runTCPServerWithHandle (serverSettings tlsPort tlsHost) (handleConn creds)
  where
    handleConn creds socket addr mlocal =
        app (plainData, upgrade)
      where
        -- Unencrypted view of the connection, reading up to 4096 bytes at a time.
        plainData = AppData
            { appRead' = safeRecv socket 4096
            , appWrite' = sendAll socket
            , appSockAddr' = addr
            , appLocalAddr' = mlocal
#if MIN_VERSION_streaming_commons(0,1,6)
            , appCloseConnection' = sClose socket
#endif
#if MIN_VERSION_streaming_commons(0,1,12)
            , appRawSocket' = Just socket
#endif
            }

        -- Switch the socket to TLS and run the continuation on it.
        upgrade secureApp = do
            ctx <- serverHandshake socket creds
            secureApp (tlsAppData ctx addr mlocal)
            TLS.bye ctx
-- | Build the 'AppData' for a connection whose traffic goes through an
-- established TLS context.
--
-- Reads and writes are routed through the context; no raw socket is exposed.
tlsAppData :: TLS.Context       -- ^ a TLS context
           -> SockAddr          -- ^ remote address
           -> Maybe SockAddr    -- ^ local address, if known
           -> AppData
tlsAppData ctx addr mlocal = AppData
    { appRead' = TLS.recvData ctx
      -- 'TLS.sendData' takes a lazy ByteString, so wrap the strict chunk.
    , appWrite' = \chunk -> TLS.sendData ctx (L.fromChunks [chunk])
    , appSockAddr' = addr
    , appLocalAddr' = mlocal
#if MIN_VERSION_streaming_commons(0,1,6)
    , appCloseConnection' = TLS.contextClose ctx
#endif
#if MIN_VERSION_streaming_commons(0,1,12)
    , appRawSocket' = Nothing
#endif
    }
-- | Cipher suites offered by the server, in this order: the AES suites
-- first, then the RC4 suites.
--
-- NOTE(review): RC4 suites are prohibited for TLS by RFC 7465; consider
-- removing them once client compatibility requirements are confirmed.
ciphers :: [TLS.Cipher]
ciphers = aesSuites ++ rc4Suites
  where
    aesSuites =
        [ TLSExtra.cipher_AES128_SHA1
        , TLSExtra.cipher_AES256_SHA1
        ]
    rc4Suites =
        [ TLSExtra.cipher_RC4_128_MD5
        , TLSExtra.cipher_RC4_128_SHA1
        ]
-- | Run the loading actions stored in a 'TlsCertData' and parse the result
-- into TLS credentials holding a single certificate chain.
--
-- The certificate, the chain and the key are loaded in that order. If the
-- material cannot be parsed, 'error' is raised with the parser's message.
readCreds :: TlsCertData -> IO TLS.Credentials
readCreds (TlsCertData loadCert loadChain loadKey) = do
    cert  <- loadCert
    chain <- loadChain
    key   <- loadKey
    case TLS.credentialLoadX509ChainFromMemory cert chain key of
        Left err   -> error ("Error reading TLS credentials: " ++ err)
        Right cred -> return (TLS.Credentials [cred])
-- | Receive exactly the requested number of bytes from the socket.
--
-- This is used as the receive function of the TLS backend. A single
-- 'safeRecv' call may return fewer bytes than asked for, so reads are
-- repeated until the requested amount has been collected. If a read returns
-- an empty chunk (treated as end of stream), the bytes gathered so far are
-- returned, so the result can be shorter than requested.
recvExact :: Socket -> Int -> IO S.ByteString
recvExact socket =
    loop id
  where
    -- @front@ is a difference list of the chunks received so far;
    -- @rest@ is the number of bytes still outstanding.
    loop front rest
        | rest < 0 = error "Data.Conduit.Network.TLS.recvExact: rest < 0"
        | rest == 0 = return $ S.concat $ front []
        | otherwise = do
            next <- safeRecv socket rest
            if S.null next
                -- Nothing more to read: return what we have rather than
                -- looping forever.
                then return $ S.concat $ front []
                -- Subtract what was just received from the outstanding count.
                else loop (front . (next:)) $ rest - S.length next
-- | Settings for connecting to a server as a (possibly TLS-secured) client.
data TLSClientConfig = TLSClientConfig
    { tlsClientPort :: Int
    -- ^ Port to connect to.
    , tlsClientHost :: S.ByteString
    -- ^ Host name to connect to.
    , tlsClientUseTLS :: Bool
    -- ^ Whether 'runTLSClient' secures the connection with TLS right away.
    -- 'runTLSClientStartTLS' does not consult this field: it always starts
    -- unencrypted.
    , tlsClientTLSSettings :: NC.TLSSettings
    -- ^ TLS settings used when the connection is secured.
    , tlsClientSockSettings :: Maybe NC.SockSettings
    -- ^ Optional SOCKS settings handed to the connection parameters.
    , tlsClientConnectionContext :: Maybe NC.ConnectionContext
    -- ^ Connection context to reuse; when 'Nothing', a fresh context is
    -- initialised on every run.
    }
-- | Default client configuration for the given port and host: TLS enabled
-- with default TLS settings, no SOCKS settings and no shared connection
-- context.
tlsClientConfig :: Int           -- ^ port
                -> S.ByteString  -- ^ host
                -> TLSClientConfig
tlsClientConfig port host =
    TLSClientConfig
        { tlsClientHost = host
        , tlsClientPort = port
        , tlsClientUseTLS = True
        , tlsClientTLSSettings = def
        , tlsClientSockSettings = Nothing
        , tlsClientConnectionContext = Nothing
        }
-- | Connect to a server and run the application on the connection, secured
-- with TLS when 'tlsClientUseTLS' is set.
--
-- The connection is closed when the application finishes or throws. The
-- reported peer address is a placeholder built from the configured port and
-- host address 0; no local address and no raw socket are exposed.
runTLSClient :: (MonadIO m, MonadBaseControl IO m)
             => TLSClientConfig
             -> (AppData -> m a)
             -> m a
runTLSClient TLSClientConfig {..} app = do
    -- Reuse the supplied connection context or initialise a new one.
    context <- case tlsClientConnectionContext of
        Just existing -> return existing
        Nothing       -> liftIO NC.initConnectionContext
    control $ \runInBase ->
        bracket
            (NC.connectTo context params)
            NC.connectionClose
            (\conn -> runInBase (app (connData conn)))
  where
    params = NC.ConnectionParams
        { NC.connectionHostname = S8.unpack tlsClientHost
        , NC.connectionPort = fromIntegral tlsClientPort
        , NC.connectionUseSecure = secureSettings
        , NC.connectionUseSocks = tlsClientSockSettings
        }

    secureSettings
        | tlsClientUseTLS = Just tlsClientTLSSettings
        | otherwise       = Nothing

    connData conn = AppData
        { appRead' = NC.connectionGetChunk conn
        , appWrite' = NC.connectionPut conn
        , appSockAddr' = SockAddrInet (fromIntegral tlsClientPort) 0
        , appLocalAddr' = Nothing
#if MIN_VERSION_streaming_commons(0,1,6)
        , appCloseConnection' = NC.connectionClose conn
#endif
#if MIN_VERSION_streaming_commons(0,1,12)
        , appRawSocket' = Nothing
#endif
        }
-- | Connect to a server without encryption and give the application the
-- means to upgrade the connection to TLS later.
--
-- The application receives 'AppData' for the connection plus an upgrade
-- function. Calling the upgrade function secures the existing connection
-- with 'tlsClientTLSSettings' and then runs the given continuation on it.
-- The connection is closed when the application finishes or throws.
runTLSClientStartTLS :: TLSClientConfig
                     -> ApplicationStartTLS
                     -> IO ()
runTLSClientStartTLS TLSClientConfig {..} app = do
    -- Reuse the supplied connection context or initialise a new one.
    context <- case tlsClientConnectionContext of
        Just existing -> return existing
        Nothing       -> liftIO NC.initConnectionContext
    control $ \runInBase ->
        bracket
            (NC.connectTo context params)
            NC.connectionClose
            (\conn -> runInBase (app (connData conn, upgrade context conn)))
  where
    -- Always start in the clear; TLS is only negotiated through 'upgrade'.
    params = NC.ConnectionParams
        { NC.connectionHostname = S8.unpack tlsClientHost
        , NC.connectionPort = fromIntegral tlsClientPort
        , NC.connectionUseSecure = Nothing
        , NC.connectionUseSocks = tlsClientSockSettings
        }

    upgrade context conn secureApp = do
        NC.connectionSetSecure context conn tlsClientTLSSettings
        secureApp (connData conn)

    -- The same view serves both phases: all I/O goes through @conn@.
    connData conn = AppData
        { appRead' = NC.connectionGetChunk conn
        , appWrite' = NC.connectionPut conn
        , appSockAddr' = SockAddrInet (fromIntegral tlsClientPort) 0
        , appLocalAddr' = Nothing
#if MIN_VERSION_streaming_commons(0,1,6)
        , appCloseConnection' = NC.connectionClose conn
#endif
#if MIN_VERSION_streaming_commons(0,1,12)
        , appRawSocket' = Nothing
#endif
        }
-- | Stream the chunks read from a connection, stopping at the first empty
-- chunk.
sourceConnection :: MonadIO m => NC.Connection -> Producer m S.ByteString
sourceConnection conn =
    pull
  where
    pull = liftIO (NC.connectionGetChunk conn) >>= emit

    emit chunk
        | S.null chunk = return ()
        | otherwise    = yield chunk >> pull
-- | Write every incoming chunk to the connection.
sinkConnection :: MonadIO m => NC.Connection -> Consumer S.ByteString m ()
sinkConnection conn =
    awaitForever $ \chunk ->
        liftIO (NC.connectionPut conn chunk)