module Data.Conduit.Network.TLS
(
TLSConfig
, tlsConfigBS
, tlsConfig
, tlsHost
, tlsPort
, tlsNeedLocalAddr
, tlsAppData
, runTCPServerTLS
, runTCPServerStartTLS
, TLSClientConfig
, tlsClientConfig
, runTLSClient
, runTLSClientStartTLS
, tlsClientPort
, tlsClientHost
, tlsClientUseTLS
, tlsClientTLSSettings
, tlsClientSockSettings
, tlsClientConnectionContext
) where
import Prelude hiding (FilePath, readFile)
import Control.Applicative ((<$>), (<*>))
import Control.Monad (forever)
import Filesystem.Path.CurrentOS (FilePath)
import Filesystem (readFile)
import qualified Data.ByteString.Lazy as L
import qualified Network.TLS as TLS
import Data.Conduit.Network (sinkSocket, runTCPServerWithHandle, serverSettings, sourceSocket)
import Data.Streaming.Network.Internal (AppData (..), HostPreference)
import Data.Streaming.Network (ConnectionHandle, safeRecv)
import Data.Conduit.Network.TLS.Internal
import Data.Conduit (yield, awaitForever, Producer, Consumer)
import qualified Data.Conduit.List as CL
import Network.Socket (SockAddr (SockAddrInet), sClose)
import Network.Socket.ByteString (recv, sendAll)
import Control.Exception (bracket)
import Control.Monad.Trans.Class (lift)
import Control.Monad.IO.Class (liftIO, MonadIO)
import qualified Network.TLS.Extra as TLSExtra
import Network.Socket (Socket)
import qualified Data.ByteString as S
import qualified Data.ByteString.Char8 as S8
import qualified Crypto.Random.AESCtr
import qualified Network.Connection as NC
import Control.Monad.Trans.Control
import Data.Default
-- | Build certificate data whose certificate and private key come from the
-- given files. Only the 'readFile' actions are stored here; nothing is read
-- until the consumer of the 'TlsCertData' runs them.
makeCertDataPath :: FilePath -> FilePath -> TlsCertData
makeCertDataPath certPath keyPath = TlsCertData loadCert loadKey
  where
    loadCert = readFile certPath
    loadKey = readFile keyPath
-- | Build certificate data from a certificate and private key that are
-- already held in memory. Each 'S.ByteString' is wrapped in a trivial action
-- so it has the same shape as the file-based variant.
makeCertDataBS :: S.ByteString -> S.ByteString -> TlsCertData
makeCertDataBS certBS keyBS = TlsCertData certAction keyAction
  where
    certAction = return certBS
    keyAction = return keyBS
-- | Server configuration whose certificate and private key are loaded from
-- files on disk.
tlsConfig :: HostPreference -- ^ host preference to listen on
          -> Int            -- ^ port
          -> FilePath       -- ^ certificate file
          -> FilePath       -- ^ private key file
          -> TLSConfig
tlsConfig host port certFile keyFile =
    -- The trailing 'False' fills the fourth constructor field
    -- (presumably 'tlsNeedLocalAddr' -- confirm against the Internal module).
    TLSConfig host port certData False
  where
    certData = makeCertDataPath certFile keyFile
-- | Server configuration whose certificate and private key are supplied
-- directly as in-memory 'S.ByteString's.
tlsConfigBS :: HostPreference -- ^ host preference to listen on
            -> Int            -- ^ port
            -> S.ByteString   -- ^ certificate contents
            -> S.ByteString   -- ^ private key contents
            -> TLSConfig
tlsConfigBS host port certBytes keyBytes =
    -- The trailing 'False' fills the fourth constructor field
    -- (presumably 'tlsNeedLocalAddr' -- confirm against the Internal module).
    TLSConfig host port certData False
  where
    certData = makeCertDataBS certBytes keyBytes
-- | Run the server side of a TLS handshake over an already-accepted socket
-- and return the resulting TLS context.
--
-- The backend writes with 'sendAll' and reads with 'recvExact'; its flush
-- and close hooks are no-ops, so closing the context here does not close
-- the socket.
serverHandshake :: Socket -> TLS.Credentials -> IO (TLS.Context)
serverHandshake socket creds = do
    rng <- Crypto.Random.AESCtr.makeSystem
    ctx <- TLS.contextNew backend serverParams rng
    TLS.handshake ctx
    return ctx
  where
    -- Socket-backed transport used by the TLS layer.
    backend = TLS.Backend
        { TLS.backendFlush = return ()
        , TLS.backendClose = return ()
        , TLS.backendSend = sendAll socket
        , TLS.backendRecv = recvExact socket
        }

    -- Protocol versions and cipher suites offered to clients.
    supported = def
        { TLS.supportedCiphers = ciphers
        , TLS.supportedVersions = [TLS.SSL3,TLS.TLS10,TLS.TLS11,TLS.TLS12]
        }

    -- Credentials presented by the server.
    shared = def
        { TLS.sharedCredentials = creds
        }

    -- Client certificates are not requested.
    serverParams = def
        { TLS.serverWantClientCert = False
        , TLS.serverSupported = supported
        , TLS.serverShared = shared
        }
-- | Run a TCP server in which every accepted connection is wrapped in TLS
-- before the application sees it.
--
-- Credentials are loaded once, before the server starts. For each
-- connection the handshake is performed, the application is run against the
-- TLS-backed 'AppData', and then 'TLS.bye' is sent. If the application
-- throws, 'TLS.bye' is not sent.
runTCPServerTLS :: TLSConfig -> (AppData -> IO ()) -> IO ()
runTCPServerTLS config app = do
    creds <- readCreds (tlsCertData config)
    runTCPServerWithHandle settings (secureHandler creds)
  where
    -- Plain TCP listener settings derived from the TLS configuration.
    settings = serverSettings (tlsPort config) (tlsHost config)

    -- Per-connection handler: handshake, run the app, say goodbye.
    secureHandler creds socket addr mlocal = do
        ctx <- serverHandshake socket creds
        app (tlsAppData ctx addr mlocal)
        TLS.bye ctx
-- | An application for connections that start in clear text and may later be
-- upgraded to TLS.
--
-- The application receives a pair. The first component is the 'AppData' for
-- the unencrypted connection. The second is an upgrade function: it takes a
-- continuation and runs that continuation with a TLS-backed 'AppData' for
-- the same connection.
type ApplicationStartTLS = (AppData, (AppData -> IO ()) -> IO ()) -> IO ()
-- | Run a TCP server whose connections begin unencrypted and can be upgraded
-- to TLS on demand by the application.
--
-- Credentials are loaded once, before the server starts. Each connection's
-- application receives the clear-text 'AppData' together with an upgrade
-- function that performs the server handshake, runs the supplied
-- continuation over TLS, and then sends 'TLS.bye'.
runTCPServerStartTLS :: TLSConfig -> ApplicationStartTLS -> IO ()
runTCPServerStartTLS config app = do
    creds <- readCreds (tlsCertData config)
    runTCPServerWithHandle settings (handler creds)
  where
    -- Plain TCP listener settings derived from the TLS configuration.
    settings = serverSettings (tlsPort config) (tlsHost config)

    handler creds socket addr mlocal = app (plainData, upgrade)
      where
        -- Unencrypted view of the connection, reading up to 4096 bytes
        -- per receive.
        plainData = AppData
            { appRead' = safeRecv socket 4096
            , appWrite' = sendAll socket
            , appSockAddr' = addr
            , appLocalAddr' = mlocal
#if MIN_VERSION_streaming_commons(0,1,6)
            , appCloseConnection' = sClose socket
#endif
            }

        -- Switch the same socket to TLS and run the continuation on it.
        upgrade secureApp = do
            ctx <- serverHandshake socket creds
            secureApp (tlsAppData ctx addr mlocal)
            TLS.bye ctx
-- | Build an 'AppData' whose reads and writes go through the given TLS
-- context. The two addresses are stored unchanged.
tlsAppData :: TLS.Context    -- ^ an established TLS context
           -> SockAddr       -- ^ address stored as 'appSockAddr''
           -> Maybe SockAddr -- ^ address stored as 'appLocalAddr''
           -> AppData
tlsAppData ctx addr mlocal = AppData
    { appRead' = TLS.recvData ctx
      -- 'TLS.sendData' takes a lazy ByteString, so wrap the strict chunk.
    , appWrite' = \chunk -> TLS.sendData ctx (L.fromChunks [chunk])
    , appSockAddr' = addr
    , appLocalAddr' = mlocal
#if MIN_VERSION_streaming_commons(0,1,6)
    , appCloseConnection' = TLS.contextClose ctx
#endif
    }
-- | Cipher suites offered by the server; used as 'TLS.supportedCiphers' in
-- 'serverHandshake'.
--
-- NOTE(review): the RC4 suites (and, in 'serverHandshake', SSL3) are widely
-- regarded as weak today -- confirm whether they are still needed for
-- compatibility before keeping them.
ciphers :: [TLS.Cipher]
ciphers =
    [ TLSExtra.cipher_AES128_SHA1
    , TLSExtra.cipher_AES256_SHA1
    , TLSExtra.cipher_RC4_128_MD5
    , TLSExtra.cipher_RC4_128_SHA1
    ]
-- | Run the certificate and key loaders and parse the result into TLS
-- credentials holding a single X.509 credential.
--
-- The certificate action runs before the key action. If parsing fails, the
-- returned action raises via 'error' with the parser's message appended.
readCreds :: TlsCertData -> IO TLS.Credentials
readCreds (TlsCertData loadCert loadKey) = do
    certBytes <- loadCert
    keyBytes <- loadKey
    case TLS.credentialLoadX509FromMemory certBytes keyBytes of
        Left msg -> error ("Error reading TLS credentials: " ++ msg)
        Right credential -> return (TLS.Credentials [credential])
-- | Receive exactly the requested number of bytes from the socket, issuing
-- as many 'recv' calls as needed.
--
-- Fewer bytes are returned only if the peer closes the connection first
-- ('recv' yields an empty chunk); whatever was collected so far is returned
-- in that case. A negative remaining count is a programming error and is
-- reported with 'error'.
recvExact :: Socket -> Int -> IO S.ByteString
recvExact socket =
    loop id
  where
    -- Chunks are accumulated in a difference list so appends stay cheap.
    assemble front = S.concat (front [])

    loop front rest
        | rest < 0 = error "Data.Conduit.Network.TLS.recvExact: rest < 0"
        | rest == 0 = return (assemble front)
        | otherwise = do
            next <- recv socket rest
            if S.null next
                -- Empty read: the connection is closed, stop early.
                then return (assemble front)
                -- Still short: ask for the bytes that remain.
                else loop (front . (next:)) (rest - S.length next)
-- | Settings for an outgoing client connection, consumed by 'runTLSClient'
-- and 'runTLSClientStartTLS'.
data TLSClientConfig = TLSClientConfig
    -- | Port to connect to.
    { tlsClientPort :: Int
    -- | Host to connect to; unpacked to a 'String' hostname when connecting.
    , tlsClientHost :: S.ByteString
    -- | Whether 'runTLSClient' secures the connection with TLS. When 'False'
    -- the connection is made without TLS settings.
    , tlsClientUseTLS :: Bool
    -- | TLS settings used when securing the connection.
    , tlsClientTLSSettings :: NC.TLSSettings
    -- | Optional SOCKS settings passed through to the connection parameters.
    , tlsClientSockSettings :: Maybe NC.SockSettings
    -- | Connection context to use. When 'Nothing', a fresh context is
    -- initialised for each run.
    , tlsClientConnectionContext :: Maybe NC.ConnectionContext
    }
-- | Default client configuration for the given port and host: TLS enabled
-- with default TLS settings, no SOCKS settings, and no pre-built connection
-- context.
tlsClientConfig :: Int          -- ^ port
                -> S.ByteString -- ^ host
                -> TLSClientConfig
tlsClientConfig port host =
    -- Arguments follow the field order of 'TLSClientConfig'.
    TLSClientConfig port host useTLS tlsSettings sockSettings connContext
  where
    useTLS = True
    tlsSettings = def
    sockSettings = Nothing
    connContext = Nothing
-- | Connect to the configured host and port and run the application against
-- the connection, closing it afterwards (also when an exception is raised).
--
-- The connection is secured with the configured TLS settings when
-- 'tlsClientUseTLS' is set, and left unsecured otherwise. A fresh connection
-- context is initialised unless the configuration supplies one.
runTLSClient :: (MonadIO m, MonadBaseControl IO m)
             => TLSClientConfig
             -> (AppData -> m a)
             -> m a
runTLSClient cfg app = do
    context <- case tlsClientConnectionContext cfg of
        Just existing -> return existing
        Nothing -> liftIO NC.initConnectionContext
    control $ \runInBase ->
        bracket
            (NC.connectTo context params)
            NC.connectionClose
            (\conn -> runInBase (app (connAppData conn)))
  where
    port = tlsClientPort cfg

    -- TLS settings are only supplied when TLS was requested.
    secure
        | tlsClientUseTLS cfg = Just (tlsClientTLSSettings cfg)
        | otherwise = Nothing

    params = NC.ConnectionParams
        { NC.connectionHostname = S8.unpack (tlsClientHost cfg)
        , NC.connectionPort = fromIntegral port
        , NC.connectionUseSecure = secure
        , NC.connectionUseSocks = tlsClientSockSettings cfg
        }

    -- The socket address carries only the port; its host part is 0.
    connAppData conn = AppData
        { appRead' = NC.connectionGetChunk conn
        , appWrite' = NC.connectionPut conn
        , appSockAddr' = SockAddrInet (fromIntegral port) 0
        , appLocalAddr' = Nothing
#if MIN_VERSION_streaming_commons(0,1,6)
        , appCloseConnection' = NC.connectionClose conn
#endif
        }
-- | Connect in clear text and let the application decide when to upgrade
-- the connection to TLS.
--
-- The application receives the 'AppData' for the connection plus an upgrade
-- function. The upgrade function secures the same connection with the
-- configured TLS settings and then runs the supplied continuation. The
-- connection is closed when the application finishes, also when an
-- exception is raised.
runTLSClientStartTLS :: TLSClientConfig
                     -> ApplicationStartTLS
                     -> IO ()
runTLSClientStartTLS cfg app = do
    context <- case tlsClientConnectionContext cfg of
        Just existing -> return existing
        Nothing -> liftIO NC.initConnectionContext
    control $ \runInBase ->
        bracket
            (NC.connectTo context plainParams)
            NC.connectionClose
            (\conn -> runInBase (app (connAppData conn, upgrade context conn)))
  where
    port = tlsClientPort cfg

    -- The initial connection is always made without TLS.
    plainParams = NC.ConnectionParams
        { NC.connectionHostname = S8.unpack (tlsClientHost cfg)
        , NC.connectionPort = fromIntegral port
        , NC.connectionUseSecure = Nothing
        , NC.connectionUseSocks = tlsClientSockSettings cfg
        }

    -- Used both before and after the upgrade: the connection value is the
    -- same one, it is merely switched to secure mode in place.
    connAppData conn = AppData
        { appRead' = NC.connectionGetChunk conn
        , appWrite' = NC.connectionPut conn
        , appSockAddr' = SockAddrInet (fromIntegral port) 0
        , appLocalAddr' = Nothing
#if MIN_VERSION_streaming_commons(0,1,6)
        , appCloseConnection' = NC.connectionClose conn
#endif
        }

    -- Secure the connection, then hand it to the continuation.
    upgrade context conn secureApp = do
        NC.connectionSetSecure context conn (tlsClientTLSSettings cfg)
        secureApp (connAppData conn)
-- | Stream chunks read from a connection downstream. The stream ends when
-- an empty chunk is read.
sourceConnection :: MonadIO m => NC.Connection -> Producer m S.ByteString
sourceConnection conn =
    pull
  where
    pull = liftIO (NC.connectionGetChunk conn) >>= emit

    -- An empty chunk ends the stream; anything else is yielded and the
    -- next chunk is pulled.
    emit chunk
        | S.null chunk = return ()
        | otherwise = yield chunk >> pull
-- | Write every chunk received from upstream to the connection.
sinkConnection :: MonadIO m => NC.Connection -> Consumer S.ByteString m ()
sinkConnection conn =
    awaitForever $ \chunk -> liftIO (NC.connectionPut conn chunk)