module Hookup
(
ConnectionParams(..),
SocksParams(..),
TlsParams(..),
Connection,
connect,
recvLine,
send,
close,
ConnectionFailure(..),
) where
import Control.Concurrent
import Control.Exception
import Control.Monad
import Data.ByteString (ByteString)
import qualified Data.ByteString as B
import Data.Foldable
import Network (PortID(..))
import Network.Socket (Socket, AddrInfo, PortNumber, HostName)
import qualified Network.Socket as Socket
import qualified Network.Socket.ByteString as SocketB
import Network.Socks5
import OpenSSL.Session (SSL, SSLContext)
import qualified OpenSSL as SSL
import qualified OpenSSL.Session as SSL
import qualified OpenSSL.X509 as SSL
import OpenSSL.X509.SystemStore
import qualified OpenSSL.PEM as PEM
import Hookup.OpenSSL (installVerification)
-- | Parameters describing where and how to open a 'Connection'.
data ConnectionParams = ConnectionParams
  { cpHost  :: HostName
    -- ^ Destination host name. When a SOCKS proxy is configured it is
    -- passed as-is to the proxy instead of being resolved locally; when
    -- TLS is enabled it is also the name sent for SNI.
  , cpPort  :: PortNumber
    -- ^ Destination port.
  , cpSocks :: Maybe SocksParams
    -- ^ 'Just' to connect through a SOCKS proxy, 'Nothing' to connect directly.
  , cpTls   :: Maybe TlsParams
    -- ^ 'Just' to run a TLS handshake over the socket, 'Nothing' for plaintext.
  }
-- | Location of the SOCKS proxy used when 'cpSocks' is set.
data SocksParams = SocksParams
  { spHost :: HostName   -- ^ Proxy host name.
  , spPort :: PortNumber -- ^ Proxy port.
  }
-- | TLS configuration used when 'cpTls' is set.
data TlsParams = TlsParams
  { tpClientCertificate :: Maybe FilePath
    -- ^ Path to a PEM client certificate to present, if any.
  , tpClientPrivateKey  :: Maybe FilePath
    -- ^ Path to the PEM private key for the client certificate. It is read
    -- without a password, so it must not be encrypted.
  , tpServerCertificate :: Maybe FilePath
    -- ^ CA file used to verify the server. 'Nothing' loads the system
    -- certificate store instead.
  , tpCipherSuite       :: String
    -- ^ OpenSSL cipher-list string handed to the SSL context.
  , tpInsecure          :: Bool
    -- ^ 'True' disables peer certificate verification entirely.
  }
-- | Connection-level failures this module throws with 'throwIO'.
data ConnectionFailure
  = HostnameResolutionFailure IOError
    -- ^ Resolving the destination host with 'Socket.getAddrInfo' failed.
  | ConnectionFailure [IOError]
    -- ^ Every resolved address was tried and each connect attempt failed.
    -- Holds one error per attempt, most recent attempt first. The list is
    -- empty if resolution returned no addresses at all.
  | LineTooLong
    -- ^ 'recvLine' had buffered at least its byte limit without seeing a
    -- newline and would have needed to read more.
  | LineTruncated
    -- ^ The network read returned no data (end of stream) while a partial
    -- line with no newline was still buffered.
  deriving Show

-- | Lets these failures be thrown with 'throwIO' and caught with 'catch' / 'try'.
instance Exception ConnectionFailure
-- | Open a TCP socket to the destination in the parameters. The socket goes
-- through the SOCKS proxy in 'cpSocks' when one is configured, and connects
-- directly otherwise.
openSocket :: ConnectionParams -> IO Socket
openSocket params = connectVia (cpHost params) (cpPort params)
  where
    -- Pick the connection strategy once; both share the host/port shape.
    connectVia = maybe openSocket' openSocks (cpSocks params)
-- | Connect to @host:port@ by way of the given SOCKS proxy. The destination
-- host name is handed to the proxy unresolved.
openSocks :: SocksParams -> HostName -> PortNumber -> IO Socket
openSocks sp host port =
  socksConnectTo' proxyHost proxyPort host (PortNumber port)
  where
    proxyHost = spHost sp
    proxyPort = PortNumber (spPort sp)
-- | Resolve @host:port@ to stream-socket addresses and connect to the first
-- one that accepts.
--
-- Resolution uses @AI_ADDRCONFIG@ and @AI_NUMERICSERV@ (the port is passed as
-- its decimal 'show'). An 'IOError' from resolution is rethrown as
-- 'HostnameResolutionFailure'. If every address fails, 'attemptConnections'
-- throws 'ConnectionFailure'.
openSocket' :: HostName -> PortNumber -> IO Socket
openSocket' host port =
  try (Socket.getAddrInfo (Just hints) (Just host) (Just (show port)))
    >>= either (throwIO . HostnameResolutionFailure) (attemptConnections [])
  where
    hints = Socket.defaultHints
      { Socket.addrSocketType = Socket.Stream
      , Socket.addrFlags      = [Socket.AI_ADDRCONFIG, Socket.AI_NUMERICSERV]
      }
-- | Try each candidate address in order and return the first connected socket.
--
-- The first argument accumulates the 'IOError' of every failed attempt, most
-- recent first. When the candidates run out, the accumulated list is thrown
-- as 'ConnectionFailure'. A socket whose connect fails is closed before the
-- next address is tried.
attemptConnections :: [IOError] -> [Socket.AddrInfo] -> IO Socket
attemptConnections exs [] = throwIO (ConnectionFailure exs)
attemptConnections exs (ai:ais) =
  do s <- socket' ai
     -- 'try' only intercepts IOError. Any other exception raised while
     -- connect blocks (e.g. an asynchronous exception from 'timeout' or
     -- 'killThread') would otherwise escape with the socket still open, so
     -- close it on that path too.
     res <- try (Socket.connect s (Socket.addrAddress ai))
              `onException` Socket.close s
     case res of
       Left ex -> do Socket.close s
                     attemptConnections (ex:exs) ais
       Right{} -> return s
-- | Create an unconnected socket with the family, socket type and protocol
-- recorded in an 'AddrInfo' result.
socket' :: AddrInfo -> IO Socket
socket' ai = Socket.socket family sockType proto
  where
    family   = Socket.addrFamily ai
    sockType = Socket.addrSocketType ai
    proto    = Socket.addrProtocol ai
-- | The transport underneath a 'Connection': either an OpenSSL session or a
-- plain socket. 'networkSend', 'networkRecv' and 'closeNetworkHandle'
-- dispatch on it.
data NetworkHandle = SSL SSL | Socket Socket
-- | Open the socket described by the parameters and, when 'cpTls' is set,
-- start TLS on it using 'cpHost' as the server name.
--
-- If TLS setup or the handshake throws, the socket is closed before the
-- exception propagates, so no descriptor is leaked.
openNetworkHandle :: ConnectionParams -> IO NetworkHandle
openNetworkHandle params =
  do s <- openSocket params
     case cpTls params of
       Nothing -> return (Socket s)
       Just tp ->
         -- Nothing else holds a reference to 's' yet, so a failure here
         -- would otherwise leave the connected socket open.
         SSL <$> (startTls (cpHost params) tp s `onException` Socket.close s)
-- | Release a network handle.
--
-- For a TLS handle, a one-way TLS shutdown is sent and then the underlying
-- socket is closed. 'SSL.shutdown' alone does not close the socket, so
-- without the second step every TLS connection leaked its descriptor. The
-- socket is closed even if the shutdown throws (for example because the
-- peer already went away), and the shutdown's exception is then rethrown.
-- A plain socket handle is simply closed.
closeNetworkHandle :: NetworkHandle -> IO ()
closeNetworkHandle (SSL s) =
  SSL.shutdown s SSL.Unidirectional
    `finally` traverse_ Socket.close (SSL.sslSocket s)
closeNetworkHandle (Socket s) = Socket.close s
-- | Write the whole byte string to the handle, using the TLS session or the
-- raw socket as appropriate.
networkSend :: NetworkHandle -> ByteString -> IO ()
networkSend handle bytes =
  case handle of
    SSL ssl     -> SSL.write ssl bytes
    Socket sock -> SocketB.sendAll sock bytes
-- | Read up to the given number of bytes from the handle. 'recvLine' treats
-- an empty result as end of stream.
networkRecv :: NetworkHandle -> Int -> IO ByteString
networkRecv handle limit =
  case handle of
    SSL ssl     -> SSL.read ssl limit
    Socket sock -> SocketB.recv sock limit
-- | An open line-oriented connection. It pairs a buffer of received bytes
-- not yet returned as a line with the underlying 'NetworkHandle'.
-- 'recvLine' holds the buffer 'MVar' (via 'modifyMVar') while it reads, so
-- concurrent 'recvLine' calls on one connection are serialized. 'send' does
-- not touch the buffer.
data Connection = Connection (MVar ByteString) NetworkHandle
-- | Open a connection as described by the parameters (direct or via SOCKS,
-- optionally wrapped in TLS) and give it an empty receive buffer.
connect :: ConnectionParams -> IO Connection
connect params =
  openNetworkHandle params >>= \handle ->
    fmap (`Connection` handle) (newMVar B.empty)
-- | Close the connection's underlying network handle. The receive buffer is
-- not consulted.
close :: Connection -> IO ()
close conn =
  case conn of
    Connection _ handle -> closeNetworkHandle handle
-- | Receive the next newline-terminated line from the connection.
--
-- Buffered bytes are searched first. After that, data is read from the
-- network @n@ bytes at most per read until a newline (byte 10) shows up.
-- The returned line excludes the newline and one trailing carriage return,
-- if present. Bytes after the newline stay buffered for the next call.
--
-- Returns 'Nothing' when a read hits end of stream with nothing buffered.
-- Throws 'LineTooLong' if at least @n@ bytes have accumulated without a
-- newline and more input is needed. Throws 'LineTruncated' if the stream
-- ends with a partial line buffered. Because the buffer is updated through
-- 'modifyMVar', an exception restores the buffer to its state before the
-- call, and any chunks read during that call are discarded.
recvLine :: Connection -> Int -> IO (Maybe ByteString)
recvLine (Connection buf h) n =
  modifyMVar buf (\buffered -> scan (B.length buffered) buffered [])
  where
    -- @seen@ counts every byte examined so far in this call. @chunk@ is the
    -- piece currently being searched. @earlier@ holds previously searched
    -- pieces of the same line, newest first.
    scan seen chunk earlier =
      case B.break (== 10) chunk of
        (before, rest)
          | not (B.null rest) ->
              -- 'rest' starts with the newline; drop it and keep the tail.
              let line = B.concat (reverse (before : earlier))
              in return (B.drop 1 rest, Just (cleanEnd line))
        _ ->
          do when (seen >= n) (throwIO LineTooLong)
             more <- networkRecv h n
             if not (B.null more)
               then scan (seen + B.length more) more (chunk : earlier)
               else if B.null chunk
                 then return (B.empty, Nothing)
                 else throwIO LineTruncated
-- | Remove a single trailing carriage return (byte 13), if there is one, so
-- CRLF-terminated lines come back without the CR. Any other input is
-- returned unchanged.
cleanEnd :: ByteString -> ByteString
cleanEnd bs
  | B.singleton 13 `B.isSuffixOf` bs = B.take (B.length bs - 1) bs
  | otherwise                        = bs
-- | Send raw bytes over the connection. No line terminator is appended; the
-- caller supplies any framing it needs.
send :: Connection -> ByteString -> IO ()
send conn bytes =
  case conn of
    Connection _ handle -> networkSend handle bytes
-- | Wrap an already-connected socket in an OpenSSL client session and
-- perform the TLS handshake.
--
-- A fresh context is configured from the 'TlsParams' before the connection
-- object is created:
--
-- * cipher list
-- * verification mode
-- * CA certificates
-- * optional client certificate and key
--
-- The host name is then set as the SNI extension. Exceptions from setup or
-- the handshake propagate to the caller, and this function does not close
-- the socket on failure.
startTls ::
  HostName {- ^ server name, used for verification setup and SNI -} ->
  TlsParams ->
  Socket {- ^ connected socket to run TLS over -} ->
  IO SSL
startTls host tp s = SSL.withOpenSSL $
  do ctx <- SSL.context
     SSL.contextSetCiphers ctx (tpCipherSuite tp)
     -- NOTE(review): installVerification is defined in Hookup.OpenSSL;
     -- presumably it configures host-name checking against 'host' on this
     -- context. Confirm there.
     installVerification ctx host
     SSL.contextSetVerificationMode ctx (verificationMode (tpInsecure tp))
     -- Enable the SSL_OP_ALL workaround set, then clear
     -- SSL_OP_DONT_INSERT_EMPTY_FRAGMENTS again. Per the OpenSSL docs this
     -- keeps the empty-fragment CBC countermeasure on; confirm against the
     -- linked OpenSSL version.
     SSL.contextAddOption ctx SSL.SSL_OP_ALL
     SSL.contextRemoveOption ctx SSL.SSL_OP_DONT_INSERT_EMPTY_FRAGMENTS
     setupCaCertificates ctx (tpServerCertificate tp)
     traverse_ (setupCertificate ctx) (tpClientCertificate tp)
     traverse_ (setupPrivateKey ctx) (tpClientPrivateKey tp)
     ssl <- SSL.connection ctx s
     SSL.setTlsextHostName ssl host
     SSL.connect ssl
     return ssl
-- | Install the trust anchors used to verify the server. An explicit CA
-- file takes precedence; without one, the system certificate store is
-- loaded.
setupCaCertificates :: SSLContext -> Maybe FilePath -> IO ()
setupCaCertificates ctx =
  maybe (contextLoadSystemCerts ctx) (SSL.contextSetCAFile ctx)
-- | Load a PEM-encoded X.509 client certificate from a file and install it
-- on the context.
setupCertificate :: SSLContext -> FilePath -> IO ()
setupCertificate ctx path =
  do pem  <- readFile path
     cert <- PEM.readX509 pem
     SSL.contextSetCertificate ctx cert
-- | Load a PEM-encoded private key from a file and install it on the
-- context. The key is read with 'PEM.PwNone', so it must not be
-- password-protected.
setupPrivateKey :: SSLContext -> FilePath -> IO ()
setupPrivateKey ctx path =
  SSL.contextSetPrivateKey ctx
    =<< flip PEM.readPrivateKey PEM.PwNone
    =<< readFile path
-- | Choose the peer verification mode.
--
-- When the insecure flag is set, verification is turned off. Otherwise the
-- peer must present a certificate that verifies, with no custom
-- verification callback installed.
verificationMode :: Bool -> SSL.VerificationMode
verificationMode insecure =
  if insecure
    then SSL.VerifyNone
    else SSL.VerifyPeer
           { SSL.vpFailIfNoPeerCert = True
           , SSL.vpClientOnce       = True
           , SSL.vpCallback         = Nothing
           }