module Data.Conduit.Network.TLS
(
TLSConfig
, tlsConfigBS
, tlsConfig
, tlsHost
, tlsPort
, tlsNeedLocalAddr
, tlsAppData
, runTCPServerTLS
, runTCPServerStartTLS
, TLSClientConfig
, tlsClientConfig
, runTLSClient
, runTLSClientStartTLS
, tlsClientPort
, tlsClientHost
, tlsClientUseTLS
, tlsClientTLSSettings
, tlsClientSockSettings
, tlsClientConnectionContext
) where
import Prelude hiding (FilePath, readFile)
import Data.Aeson (FromJSON (parseJSON), (.:), (.:?), (.!=), Value (Object))
import Control.Applicative ((<$>), (<*>))
import Control.Monad (mzero, forever)
import Data.String (fromString)
import Filesystem.Path.CurrentOS ((</>), FilePath)
import Filesystem (readFile)
import qualified Data.ByteString.Lazy as L
import qualified Data.Certificate.KeyRSA as KeyRSA
import qualified Data.PEM as PEM
import qualified Network.TLS as TLS
import qualified Data.Certificate.X509 as X509
import Data.Conduit.Network (HostPreference, Application, bindPort, sinkSocket, acceptSafe, runTCPServerWithHandle, ConnectionHandle(..), serverSettings, sourceSocket)
import Data.Conduit.Network.Internal (AppData (..))
import Data.Conduit.Network.TLS.Internal
import Data.Conduit (($$), yield, awaitForever, Producer, Consumer)
import qualified Data.Conduit.List as CL
import Data.Either (rights)
import Network.Socket (sClose, getSocketName, SockAddr (SockAddrInet))
import Network.Socket.ByteString (recv, sendAll)
import Control.Exception (bracket, finally)
import Control.Concurrent (forkIO)
import Control.Monad.Trans.Class (lift)
import Control.Monad.IO.Class (liftIO, MonadIO)
import qualified Network.TLS.Extra as TLSExtra
import Crypto.Random.API (getSystemRandomGen, SystemRandom)
import Network.Socket (Socket)
import qualified Data.ByteString as S
import qualified Data.ByteString.Char8 as S8
import qualified Crypto.Random.AESCtr
import qualified Network.Connection as NC
import Control.Monad.Trans.Control
import Data.Default
-- | Certificate data backed by files on disk.
--
-- Nothing is read here. The first field is an action that reads the
-- certificate file and the second an action that reads the private-key file;
-- each file is read whenever its action is run.
makeCertDataPath :: FilePath -> FilePath -> TlsCertData
makeCertDataPath certPath keyPath = TlsCertData readCert readKey
  where
    readCert = readFile certPath
    readKey = readFile keyPath
-- | Certificate data backed by in-memory bytes: the certificate and the
-- private key are returned unchanged whenever the corresponding action runs.
makeCertDataBS :: S.ByteString -> S.ByteString -> TlsCertData
makeCertDataBS cert key = TlsCertData certAction keyAction
  where
    certAction = return cert
    keyAction = return key
-- | Server TLS configuration whose certificate and key are loaded from the
-- given file paths (see 'makeCertDataPath'); 'tlsNeedLocalAddr' is 'False'.
tlsConfig :: HostPreference -- ^ host to bind
          -> Int            -- ^ port to listen on
          -> FilePath       -- ^ certificate file
          -> FilePath       -- ^ private-key file
          -> TLSConfig
tlsConfig host port certPath keyPath = TLSConfig
    { tlsHost = host
    , tlsPort = port
    , tlsCertData = makeCertDataPath certPath keyPath
    , tlsNeedLocalAddr = False
    }
-- | Server TLS configuration whose certificate and key are supplied directly
-- as bytes (see 'makeCertDataBS'); 'tlsNeedLocalAddr' is 'False'.
tlsConfigBS :: HostPreference -- ^ host to bind
            -> Int            -- ^ port to listen on
            -> S.ByteString   -- ^ certificate bytes
            -> S.ByteString   -- ^ private-key bytes
            -> TLSConfig
tlsConfigBS host port certBytes keyBytes = TLSConfig
    { tlsHost = host
    , tlsPort = port
    , tlsCertData = makeCertDataBS certBytes keyBytes
    , tlsNeedLocalAddr = False
    }
-- | Run a server-side TLS handshake over an already-accepted socket and
-- return the established context.
--
-- Records are read with 'recvExact' and written with 'sendAll'. The backend's
-- flush and close are no-ops, so this function never closes the socket.
-- NOTE(review): socket cleanup is presumably done by the caller
-- ('runTCPServerWithHandle'); confirm.
--
-- Client certificates are not requested. Only TLS 1.0-1.2 are offered.
serverHandshake :: Socket -> TLS.Credentials -> IO (TLS.Context)
serverHandshake socket creds = do
    gen <- Crypto.Random.AESCtr.makeSystem
    ctx <- TLS.contextNew
        TLS.Backend
            { TLS.backendFlush = return ()
            , TLS.backendClose = return ()
            , TLS.backendSend = sendAll socket
            , TLS.backendRecv = recvExact socket
            }
        params
        gen
    TLS.handshake ctx
    return ctx
  where
    params = def
        { TLS.serverWantClientCert = False
        , TLS.serverSupported = def
            { TLS.supportedCiphers = ciphers
              -- SSL 3.0 is intentionally not offered: RFC 7568 prohibits it
              -- (POODLE), and allowing it permits downgrade attacks.
            , TLS.supportedVersions = [TLS.TLS10, TLS.TLS11, TLS.TLS12]
            }
        , TLS.serverShared = def
            { TLS.sharedCredentials = creds
            }
        }
-- | Run a TCP server on 'tlsHost' / 'tlsPort'. Each accepted connection
-- first completes a TLS handshake ('serverHandshake'), and then @app@ runs
-- over the encrypted stream.
--
-- Credentials are loaded once, before the server starts. 'readCreds' raises
-- an error if they are invalid. 'tlsNeedLocalAddr' is not consulted here.
runTCPServerTLS :: TLSConfig -> Application IO -> IO ()
runTCPServerTLS TLSConfig{..} app = do
    creds <- readCreds tlsCertData
    let onConnect socket addr mlocal =
            serverHandshake socket creds >>= \ctx ->
                app (tlsAppData ctx addr mlocal)
    runTCPServerWithHandle (serverSettings tlsPort tlsHost) (ConnectionHandle onConnect)
-- | An application that starts on a cleartext connection (STARTTLS style).
--
-- It receives two things. The first is the plaintext 'AppData' for the
-- connection. The second is an upgrade action: given a TLS application, it
-- secures the same connection and then runs that application over the
-- encrypted stream.
type ApplicationStartTLS = (AppData IO, Application IO -> IO ()) -> IO ()
-- | Run a TCP server whose connections start in cleartext.
--
-- For each connection, @app@ receives the raw-socket 'AppData' plus an
-- upgrade action. The upgrade action performs 'serverHandshake' on the same
-- socket and runs the supplied application over the TLS stream.
--
-- Credentials are loaded once, before the server starts.
runTCPServerStartTLS :: TLSConfig -> ApplicationStartTLS -> IO ()
runTCPServerStartTLS TLSConfig{..} app = do
    creds <- readCreds tlsCertData
    let onConnect socket addr mlocal =
            let plainData = AppData
                    { appSource = sourceSocket socket
                    , appSink = sinkSocket socket
                    , appSockAddr = addr
                    , appLocalAddr = mlocal
                    }
                upgrade tlsApp = do
                    ctx <- serverHandshake socket creds
                    tlsApp (tlsAppData ctx addr mlocal)
            in app (plainData, upgrade)
    runTCPServerWithHandle (serverSettings tlsPort tlsHost) (ConnectionHandle onConnect)
-- | Wrap an established server-side TLS context as an 'AppData'.
--
-- The source yields decrypted chunks from 'TLS.recvData'. It ends as soon as
-- an empty chunk is received, which is treated as end of stream, matching
-- 'sourceConnection'. The sink encrypts each incoming chunk with
-- 'TLS.sendData'.
tlsAppData :: TLS.Context
           -> SockAddr
           -> Maybe SockAddr
           -> AppData IO
tlsAppData ctx addr mlocal = AppData
    { appSource = source
    , appSink = CL.mapM_ $ TLS.sendData ctx . L.fromChunks . return
    , appSockAddr = addr
    , appLocalAddr = mlocal
    }
  where
    -- The previous 'forever' loop never terminated. After the stream ended it
    -- kept yielding empty chunks (or calling recvData on a finished context),
    -- so downstream consumers could not observe end of input.
    source = do
        bs <- lift $ TLS.recvData ctx
        if S.null bs
            then return ()
            else yield bs >> source
-- | Cipher suites the server advertises (used as 'TLS.supportedCiphers' by
-- 'serverHandshake').
--
-- The RC4 suites (@cipher_RC4_128_MD5@, @cipher_RC4_128_SHA1@) were removed.
-- RFC 7465 prohibits RC4 in TLS because of practical plaintext-recovery
-- attacks.
ciphers :: [TLS.Cipher]
ciphers =
    [ TLSExtra.cipher_AES128_SHA1
    , TLSExtra.cipher_AES256_SHA1
    ]
-- | Run the certificate action, then the key action, and load them as a
-- single X.509 credential.
--
-- Calls 'error' (prefixed with "Error reading TLS credentials: ") if
-- 'TLS.credentialLoadX509FromMemory' rejects the input.
readCreds :: TlsCertData -> IO TLS.Credentials
readCreds (TlsCertData iocert iokey) = do
    certBytes <- iocert
    keyBytes <- iokey
    case TLS.credentialLoadX509FromMemory certBytes keyBytes of
        Left err -> error ("Error reading TLS credentials: " ++ err)
        Right credential -> return (TLS.Credentials [credential])
-- | Parse the PEM certificate data and decode every @CERTIFICATE@ /
-- @TRUSTED CERTIFICATE@ block.
--
-- Blocks that fail to decode are silently dropped. Calls 'error' if the PEM
-- text cannot be parsed or if no certificate decodes successfully.
-- NOTE(review): not referenced elsewhere in the visible module; 'readCreds'
-- is what the servers use.
readCertificates :: TlsCertData -> IO [X509.X509]
readCertificates certData = do
    pemResult <- PEM.pemParseBS <$> getTLSCert certData
    case rights (parseCerts pemResult) of
        [] -> error "no valid certificate found"
        certs -> return certs
  where
    isCertificate pem = PEM.pemName pem `elem` ["CERTIFICATE", "TRUSTED CERTIFICATE"]
    decodeOne pem = X509.decodeCertificate (L.fromChunks [PEM.pemContent pem])
    parseCerts (Left err) = error $ "cannot parse PEM file: " ++ err
    parseCerts (Right pems) = map decodeOne (filter isCertificate pems)
-- | Parse the PEM key data and return the first @RSA PRIVATE KEY@ block that
-- decodes successfully.
--
-- Calls 'error' if the PEM text cannot be parsed or if no RSA key decodes.
-- NOTE(review): not referenced elsewhere in the visible module.
readPrivateKey :: TlsCertData -> IO TLS.PrivKey
readPrivateKey certData = do
    pemResult <- PEM.pemParseBS <$> getTLSKey certData
    case rights (parseKey pemResult) of
        [] -> error "no valid RSA key found"
        (key:_) -> return key
  where
    isRsaKey pem = PEM.pemName pem == "RSA PRIVATE KEY"
    decodeOne pem =
        fmap (TLS.PrivKeyRSA . snd) (KeyRSA.decodePrivate (L.fromChunks [PEM.pemContent pem]))
    parseKey (Left err) = error $ "Cannot parse PEM file: " ++ err
    parseKey (Right pems) = map decodeOne (filter isRsaKey pems)
-- | Receive up to the requested number of bytes from the socket, calling
-- 'recv' repeatedly until that many have arrived.
--
-- If 'recv' returns an empty chunk (treated as end of stream), the bytes
-- collected so far are returned, which may be fewer than requested. A
-- negative request is a programming error and calls 'error'.
recvExact :: Socket -> Int -> IO S.ByteString
recvExact socket =
    loop id
  where
    -- @front@ is a difference list of the chunks received so far.
    -- @rest@ is the number of bytes still wanted.
    loop front rest
        | rest < 0 = error "Data.Conduit.Network.TLS.recvExact: rest < 0"
        | rest == 0 = return $ S.concat $ front []
        | otherwise = do
            next <- recv socket rest
            if S.null next
                then return $ S.concat $ front []
                -- Fixed: the remaining count must be decreased by the number
                -- of bytes just received ('rest S.length next' did not even
                -- type-check).
                else loop (front . (next:)) $ rest - S.length next
-- | Settings for 'runTLSClient' and 'runTLSClientStartTLS'; use
-- 'tlsClientConfig' for defaults.
--
-- NOTE(review): the type parameter @m@ is not used by any field (phantom).
data TLSClientConfig (m :: * -> *) = TLSClientConfig
    { tlsClientPort :: Int
      -- ^ Port to connect to.
    , tlsClientHost :: S.ByteString
      -- ^ Host name to connect to (unpacked with 'S8.unpack').
    , tlsClientUseTLS :: Bool
      -- ^ Whether 'runTLSClient' secures the connection on connect.
      -- 'runTLSClientStartTLS' ignores it and always connects in cleartext.
    , tlsClientTLSSettings :: NC.TLSSettings
      -- ^ TLS settings used whenever the connection is secured.
    , tlsClientSockSettings :: Maybe NC.SockSettings
      -- ^ Passed through as 'NC.connectionUseSocks'.
    , tlsClientConnectionContext :: Maybe NC.ConnectionContext
      -- ^ Context to reuse. When 'Nothing', a new one is created with
      -- 'NC.initConnectionContext' on each run.
    }
-- | Default client settings for a port and host. TLS is enabled with the
-- default 'NC.TLSSettings'; there are no SOCKS settings and no shared
-- connection context.
tlsClientConfig :: Int          -- ^ port
                -> S.ByteString -- ^ host
                -> TLSClientConfig m
tlsClientConfig portNum hostName = TLSClientConfig
    { tlsClientHost = hostName
    , tlsClientPort = portNum
    , tlsClientUseTLS = True
    , tlsClientTLSSettings = def
    , tlsClientSockSettings = Nothing
    , tlsClientConnectionContext = Nothing
    }
-- | Connect to the configured host and port and run @app@ over the
-- connection. The connection is closed afterwards, even if @app@ throws
-- ('bracket').
--
-- The connection is secured with 'tlsClientTLSSettings' only when
-- 'tlsClientUseTLS' is set.
--
-- NOTE(review): 'appSockAddr' is a placeholder (the configured port with
-- host address 0), not the peer's real address; 'appLocalAddr' is 'Nothing'.
runTLSClient :: (MonadIO m, MonadBaseControl IO m)
             => TLSClientConfig m
             -> Application m
             -> m ()
runTLSClient TLSClientConfig {..} app = do
    context <- case tlsClientConnectionContext of
        Just existing -> return existing
        Nothing -> liftIO NC.initConnectionContext
    control $ \run ->
        bracket (NC.connectTo context connParams) NC.connectionClose $ \conn ->
            run $ app (connAppData conn)
  where
    secureSettings
        | tlsClientUseTLS = Just tlsClientTLSSettings
        | otherwise = Nothing
    connParams = NC.ConnectionParams
        { NC.connectionHostname = S8.unpack tlsClientHost
        , NC.connectionPort = fromIntegral tlsClientPort
        , NC.connectionUseSecure = secureSettings
        , NC.connectionUseSocks = tlsClientSockSettings
        }
    connAppData conn = AppData
        { appSource = sourceConnection conn
        , appSink = sinkConnection conn
        , appSockAddr = SockAddrInet (fromIntegral tlsClientPort) 0
        , appLocalAddr = Nothing
        }
-- | Connect in cleartext and give @app@ two things: the plaintext 'AppData',
-- and an upgrade action that secures the same connection with
-- 'tlsClientTLSSettings' ('NC.connectionSetSecure') before running the
-- supplied application over it.
--
-- 'tlsClientUseTLS' is not consulted, because the connection always starts
-- unencrypted. The connection is closed when @app@ returns or throws.
--
-- NOTE(review): 'appSockAddr' is a placeholder (the configured port with
-- host address 0), not the peer's real address; 'appLocalAddr' is 'Nothing'.
runTLSClientStartTLS :: TLSClientConfig IO
                     -> ApplicationStartTLS
                     -> IO ()
runTLSClientStartTLS TLSClientConfig {..} app = do
    -- Already in IO, so no liftIO/control is needed (the unused
    -- 'tlsSettings' binding was also dropped).
    context <- maybe NC.initConnectionContext return tlsClientConnectionContext
    bracket
        (NC.connectTo context params)
        NC.connectionClose
        (\conn -> app
            ( connAppData conn
            , \app' -> do
                NC.connectionSetSecure context conn tlsClientTLSSettings
                app' (connAppData conn)
            ))
  where
    params = NC.ConnectionParams
        { NC.connectionHostname = S8.unpack tlsClientHost
        , NC.connectionPort = fromIntegral tlsClientPort
        , NC.connectionUseSecure = Nothing
        , NC.connectionUseSocks = tlsClientSockSettings
        }
    -- The plaintext and secured views read and write through the same
    -- 'NC.Connection'; after 'NC.connectionSetSecure' that connection
    -- carries TLS.
    connAppData conn = AppData
        { appSource = sourceConnection conn
        , appSink = sinkConnection conn
        , appSockAddr = SockAddrInet (fromIntegral tlsClientPort) 0
        , appLocalAddr = Nothing
        }
-- | Stream chunks read from the connection with 'NC.connectionGetChunk'.
-- The stream stops at the first empty chunk, which is treated as end of
-- stream.
sourceConnection :: MonadIO m => NC.Connection -> Producer m S.ByteString
sourceConnection conn = do
    chunk <- liftIO (NC.connectionGetChunk conn)
    if S.null chunk
        then return ()
        else do
            yield chunk
            sourceConnection conn
-- | Write every incoming chunk to the connection with 'NC.connectionPut',
-- until upstream is exhausted.
sinkConnection :: MonadIO m => NC.Connection -> Consumer S.ByteString m ()
sinkConnection conn = awaitForever $ \chunk -> liftIO (NC.connectionPut conn chunk)