module Network.Connection
(
Connection
, connectionID
, ConnectionParams(..)
, TLSSettings(..)
, SockSettings(..)
, initConnectionContext
, ConnectionContext
, connectFromHandle
, connectTo
, connectionClose
, connectionGet
, connectionGetChunk
, connectionPut
, connectionSetSecure
, connectionIsSecure
) where
import Control.Applicative
import Control.Concurrent.MVar
import Control.Exception (bracketOnError, finally)
import System.IO

import Data.ByteString (ByteString)
import qualified Data.ByteString as B
import qualified Data.ByteString.Lazy as L
import qualified Data.Map as M

import qualified Crypto.Random.AESCtr as RNG
import qualified Network as N
import Network.Socks5
import qualified Network.TLS as TLS
import qualified Network.TLS.Extra as TLS
import System.Certificate.X509 (getSystemCertificateStore)

import Network.Connection.Types
-- | Mutable store of TLS session data, keyed by session ID.
type Manager = MVar (M.Map TLS.SessionID TLS.SessionData)

-- | A 'TLS.SessionManager' backed by an in-memory 'Manager' map.
--
-- NOTE(review): nothing visible in this module constructs this manager or
-- installs it in the parameters built by 'makeTLSParams'; confirm whether it
-- is used anywhere.
data ConnectionSessionManager = ConnectionSessionManager Manager

instance TLS.SessionManager ConnectionSessionManager where
    -- Look up stored data for a session; the map itself is left unchanged.
    sessionResume (ConnectionSessionManager mvar) sessionID =
        withMVar mvar (return . M.lookup sessionID)

    -- Store (or overwrite) the data for a session. The updated map is forced
    -- with '$!' before being put back, so that a long series of updates does
    -- not accumulate a chain of unevaluated inserts/deletes inside the MVar.
    sessionEstablish (ConnectionSessionManager mvar) sessionID sessionData =
        modifyMVar_ mvar (\m -> return $! M.insert sessionID sessionData m)

    -- Forget a session; the updated map is forced for the same reason.
    sessionInvalidate (ConnectionSessionManager mvar) sessionID =
        modifyMVar_ mvar (\m -> return $! M.delete sessionID m)
-- | Create a 'ConnectionContext' wrapping the certificate store returned by
-- 'getSystemCertificateStore'.
initConnectionContext :: IO ConnectionContext
initConnectionContext = do
    store <- getSystemCertificateStore
    return (ConnectionContext store)
-- | Turn the user's 'TLSSettings' into client 'TLS.Params'.
--
-- * 'TLSSettings' already carries complete parameters, which are returned
--   unchanged (the context is ignored).
-- * 'TLSSettingsSimple' starts from 'TLS.defaultParamsClient'. It connects
--   with TLS 1.1, allows TLS 1.0 to 1.2, offers 'TLS.ciphersuite_all' and
--   presents no client certificate. Unless
--   'settingDisableCertificateValidation' is set, it checks the server's
--   chain with 'TLS.certificateVerifyChain' against the context's
--   'globalCertificateStore'. When the flag is set, any certificate is
--   accepted.
--
-- NOTE(review): only the chain is verified here. Nothing in this function
-- compares the certificate with the host being connected to; confirm whether
-- that check happens elsewhere.
makeTLSParams :: ConnectionContext -> TLSSettings -> TLS.Params
makeTLSParams _ (TLSSettings custom) = custom
makeTLSParams cg settings@(TLSSettingsSimple {}) =
    TLS.defaultParamsClient
        { TLS.pConnectVersion    = TLS.TLS11
        , TLS.pAllowedVersions   = [TLS.TLS10, TLS.TLS11, TLS.TLS12]
        , TLS.pCiphers           = TLS.ciphersuite_all
        , TLS.pCertificates      = []
        , TLS.onCertificatesRecv = verifyServer
        }
  where
    verifyServer
        | settingDisableCertificateValidation settings =
            const (return TLS.CertificateUsageAccept)
        | otherwise =
            TLS.certificateVerifyChain (globalCertificateStore cg)
-- | Run an action with exclusive access to the connection's backend.
--
-- The backend MVar is held for the whole action, so concurrent callers are
-- serialised. The same backend is put back afterwards, also when the action
-- throws.
withBackend :: (ConnectionBackend -> IO a) -> Connection -> IO a
withBackend action conn = withMVar (connectionBackend conn) action
-- | Atomically update the connection's read buffer.
--
-- The function receives the current buffer and returns the new buffer paired
-- with a result, which is handed back to the caller. On exception the old
-- buffer is kept.
withBuffer :: (ByteString -> IO (ByteString, b)) -> Connection -> IO b
withBuffer update = flip modifyMVar update . connectionBuffer
-- | Wrap a backend into a 'Connection'.
--
-- The new connection has an empty read buffer and records the hostname and
-- port taken from the parameters.
connectionNew :: ConnectionParams -> ConnectionBackend -> IO Connection
connectionNew p backend = do
    backendVar <- newMVar backend
    bufferVar  <- newMVar B.empty
    return $ Connection backendVar bufferVar (connectionHostname p, connectionPort p)
-- | Build a 'Connection' on top of an already-connected handle.
--
-- When 'connectionUseSecure' is 'Nothing', the handle is used as a plain
-- stream. Otherwise a TLS handshake is run over it, using parameters derived
-- from the given settings and context.
connectFromHandle :: ConnectionContext
                  -> Handle
                  -> ConnectionParams
                  -> IO Connection
connectFromHandle cg h p = maybe plain secure (connectionUseSecure p)
  where
    plain = connectionNew p (ConnectionStream h)
    secure tlsSettings = do
        ctx <- tlsEstablish h (makeTLSParams cg tlsSettings)
        connectionNew p (ConnectionTLS ctx)
-- | Open a connection to the host and port named in the parameters.
--
-- The socket is opened directly, or through the SOCKS proxy given by
-- 'connectionUseSocks'. The resulting handle is then passed to
-- 'connectFromHandle', which adds TLS when 'connectionUseSecure' is set.
--
-- If anything after the socket is opened fails (for example the TLS
-- handshake), the handle is closed before the exception is re-thrown, instead
-- of being leaked.
connectTo :: ConnectionContext
          -> ConnectionParams
          -> IO Connection
connectTo cg cParams =
    bracketOnError
        (conFct (connectionHostname cParams) (N.PortNumber $ connectionPort cParams))
        hClose
        (\h -> connectFromHandle cg h cParams)
  where
    conFct = case connectionUseSocks cParams of
        Nothing -> N.connectTo
        Just (SockSettingsSimple sockHost sockPort) ->
            socksConnectTo sockHost (N.PortNumber sockPort)
-- | Send a block of bytes over the connection.
--
-- On a plain stream the handle is flushed after writing. On TLS the bytes are
-- passed to 'TLS.sendData' as a single lazy chunk.
connectionPut :: Connection -> ByteString -> IO ()
connectionPut connection content = withBackend send connection
  where
    send backend = case backend of
        ConnectionTLS ctx -> TLS.sendData ctx (L.fromChunks [content])
        ConnectionStream h -> do
            B.hPut h content
            hFlush h
-- | Read at most @size@ bytes from the connection.
--
-- Bytes left over from a previous read are served first, without touching the
-- backend. If fewer than @size@ are buffered, all of them are returned.
-- Only when the buffer is empty is the backend asked for a new chunk. At most
-- @size@ bytes of that chunk are returned and the rest is kept for the next
-- call. The result may therefore be shorter than @size@.
connectionGet :: Connection -> Int -> IO ByteString
connectionGet con size = withBuffer getData con
  where
    getData buf
      | B.null buf = do
          chunk <- withBackend getMoreData con
          let (ret, remain) = B.splitAt size chunk
          return (remain, ret)
      | B.length buf >= size =
          let (ret, remain) = B.splitAt size buf
           in return (remain, ret)
      | otherwise = return (B.empty, buf)

    getMoreData (ConnectionTLS tlsctx) = TLS.recvData tlsctx
    -- A negative timeout makes hWaitForInput block until input is available.
    -- A positive value is a timeout in milliseconds: on expiry the
    -- non-blocking read would return an empty string, indistinguishable from
    -- end of stream. At end of stream, hWaitForInput raises an EOF IOError.
    getMoreData (ConnectionStream h) =
        hWaitForInput h (-1) >> B.hGetNonBlocking h (16 * 1024)
-- | Return whatever data is available, with no size limit.
--
-- If bytes are buffered from an earlier 'connectionGet', they are all
-- returned. Otherwise the next chunk is read from the backend. In both cases
-- the buffer is empty afterwards.
connectionGetChunk :: Connection -> IO ByteString
connectionGetChunk con = withBuffer getData con
  where
    getData buf
      | B.null buf = withBackend getMoreData con >>= \chunk -> return (B.empty, chunk)
      | otherwise  = return (B.empty, buf)

    getMoreData (ConnectionTLS tlsctx) = TLS.recvData tlsctx
    -- A negative timeout makes hWaitForInput block until input is available.
    -- A positive value is a timeout in milliseconds: on expiry the
    -- non-blocking read would return an empty string, indistinguishable from
    -- end of stream. At end of stream, hWaitForInput raises an EOF IOError.
    getMoreData (ConnectionStream h) =
        hWaitForInput h (-1) >> B.hGetNonBlocking h (16 * 1024)
-- | Close the connection.
--
-- A TLS connection first sends a close-notify with 'TLS.bye' and then closes
-- the context. A plain stream just closes its handle.
connectionClose :: Connection -> IO ()
connectionClose = withBackend backendClose
  where
    -- 'finally' ensures the TLS context is still closed when sending the
    -- close-notify throws (e.g. the peer already went away). The exception is
    -- re-thrown afterwards.
    backendClose (ConnectionTLS ctx) = TLS.bye ctx `finally` TLS.contextClose ctx
    backendClose (ConnectionStream h) = hClose h
-- | Upgrade a plain connection to TLS in place (e.g. for STARTTLS).
--
-- The buffer lock is taken first, then the backend lock. A connection that is
-- already TLS is left exactly as it is.
connectionSetSecure :: ConnectionContext
                    -> Connection
                    -> TLSSettings
                    -> IO ()
connectionSetSecure cg connection params =
    modifyMVar_ (connectionBuffer connection) $ \buffered ->
        modifyMVar (connectionBackend connection) (upgrade buffered)
  where
    -- Plain stream: run the handshake over the same handle and reset the
    -- buffer, discarding any bytes read before the upgrade.
    upgrade _ (ConnectionStream h) = do
        ctx <- tlsEstablish h (makeTLSParams cg params)
        return (ConnectionTLS ctx, B.empty)
    -- Already TLS: keep the backend and the buffered bytes untouched.
    upgrade buffered tls@(ConnectionTLS _) = return (tls, buffered)
-- | Report whether the connection is currently running over TLS.
--
-- This takes the backend lock, so it waits for any in-progress backend
-- operation to finish.
connectionIsSecure :: Connection -> IO Bool
connectionIsSecure = withBackend (return . isTLS)
  where
    isTLS (ConnectionTLS _)    = True
    isTLS (ConnectionStream _) = False
-- | Run a TLS handshake over an already-connected handle and return the
-- resulting context.
--
-- A fresh generator from 'RNG.makeSystem' is supplied to the new context.
-- Exceptions from 'TLS.handshake' propagate to the caller. This function does
-- not close the handle on failure.
--
-- The parameter type is written as 'TLS.Params', matching what
-- 'makeTLSParams' produces.
tlsEstablish :: Handle -> TLS.Params -> IO TLS.Context
tlsEstablish handle tlsParams = do
    rng <- RNG.makeSystem
    ctx <- TLS.contextNewOnHandle handle tlsParams rng
    TLS.handshake ctx
    return ctx