module Network.Secure.Connection
(
HostName
, ServiceName
, Connection
, peer
, Network.Secure.Connection.connect
, Network.Secure.Connection.read
, Network.Secure.Connection.write
, Network.Secure.Connection.readPtr
, Network.Secure.Connection.writePtr
, Network.Secure.Connection.close
, Network.Secure.Connection.Socket
, newServer
, Network.Secure.Connection.accept
) where
import Prelude hiding (read)
import Control.Applicative ((<$>))
import Control.Exception (bracketOnError, onException)
import Control.Monad (liftM, unless)
import Data.ByteString (ByteString)
import Data.Maybe (fromJust)
import OpenSSL.Session (ShutdownType(Unidirectional), SSLContext, SSL,
VerificationMode(VerifyPeer), accept, connect,
connection, context, contextSetPrivateKey,
contextSetCertificate, contextSetCiphers,
contextSetVerificationMode, contextGetCAStore,
getPeerCertificate, getVerifyResult, read, shutdown,
write, readPtr, writePtr)
import OpenSSL.X509 (compareX509)
import OpenSSL.X509.Store (addCertToStore)
import Network.Socket hiding (shutdown)
import Network.Secure.Identity
import Foreign.Ptr(Ptr)
-- | A TLS connection produced by 'connect' or 'accept'.
data Connection = C
    { ssl   :: SSL          -- ^ the OpenSSL session carrying the traffic
    , peer  :: PeerIdentity -- ^ remote identity, built from the peer's certificate
    , _addr :: SockAddr     -- ^ remote socket address (from 'getPeerName')
    }
-- | Connections compare equal when both the peer identity and the remote
-- address match; the underlying TLS session is deliberately ignored.
instance Eq Connection where
    C _ peerA addrA == C _ peerB addrB = peerA == peerB && addrA == addrB
-- | Record-like rendering of the peer and address; the TLS session itself
-- is not shown.
instance Show Connection where
    show (C _ who addr) =
        "Connection { peer = " ++ show who ++ ", addr = " ++ show addr ++ " }"
-- | A listening socket returned by 'newServer' and consumed by 'accept'.
-- Only the type is exported, so the wrapped socket stays opaque to callers.
newtype Socket = S { unSocket :: Network.Socket.Socket }
    deriving (Eq, Show)
-- | Open a TLS connection to @(host, port)@ as a client.
--
-- We present @myId@ and accept the server only if its certificate matches one
-- of @peerIds@. The TCP socket is closed if the TCP connect, the TLS
-- handshake, or peer verification throws. On success the socket is left open
-- for the TLS session.
connect :: LocalIdentity -> [PeerIdentity] -> (HostName, ServiceName)
        -> IO Connection
connect myId peerIds (host, port) = do
    info <- getSockAddr (Just host) port
    bracketOnError (newSock info) sClose $ \sock -> do
        setSocketOption sock ReuseAddr 1
        Network.Socket.connect sock (addrAddress info)
        connectSSL myId peerIds False sock
-- | Read from the connection by delegating to 'OpenSSL.Session.read' on its
-- TLS session. The 'Int' is passed through as the requested byte count.
read :: Connection -> Int -> IO ByteString
read conn byteCount = OpenSSL.Session.read (ssl conn) byteCount
-- | Send the given bytes by delegating to 'OpenSSL.Session.write' on the
-- connection's TLS session.
write :: Connection -> ByteString -> IO ()
write conn payload = OpenSSL.Session.write (ssl conn) payload
-- | Buffer-based read: forwards the pointer and length to
-- 'OpenSSL.Session.readPtr' on the connection's TLS session and returns its
-- result.
readPtr :: Connection -> Ptr a -> Int -> IO Int
readPtr = OpenSSL.Session.readPtr . ssl
-- | Buffer-based write: forwards the pointer and length to
-- 'OpenSSL.Session.writePtr' on the connection's TLS session.
writePtr :: Connection -> Ptr a -> Int -> IO ()
writePtr = OpenSSL.Session.writePtr . ssl
-- | Perform a unidirectional TLS shutdown on the connection's session.
--
-- NOTE(review): only the TLS layer is shut down here. 'Connection' does not
-- keep the underlying 'Network.Socket.Socket', and nothing in this function
-- closes it. Confirm whether the file descriptor is released elsewhere.
close :: Connection -> IO ()
close = flip shutdown Unidirectional . ssl
-- | Create a listening stream socket bound to @(host, port)@.
--
-- A 'Nothing' host binds the wildcard address (see 'getSockAddr'). If setting
-- options, binding or listening fails, the freshly created socket is closed
-- before the exception propagates, so no file descriptor leaks.
newServer :: (Maybe HostName, ServiceName)
          -> IO Network.Secure.Connection.Socket
newServer (host, port) = do
    info <- getSockAddr host port
    bracketOnError (newSock info) sClose $ \sock -> do
        setSocketOption sock ReuseAddr 1
        bindSocket sock (addrAddress info)
        listen sock 10 -- backlog of pending connections
        return (S sock)
-- | Accept one incoming TCP connection on a socket from 'newServer' and run
-- the server side of the TLS handshake.
--
-- The client is accepted only if its certificate matches one of @peerIds@.
-- If the handshake or peer verification throws, the accepted socket is
-- closed, mirroring 'connect', so its file descriptor is not leaked.
accept :: LocalIdentity -> [PeerIdentity] -> Network.Secure.Connection.Socket
       -> IO Connection
accept myId peerIds listenSock =
    bracketOnError acceptRaw sClose (connectSSL myId peerIds True)
  where
    -- Keep only the new socket; 'connectSSL' re-reads the remote address
    -- with 'getPeerName'.
    acceptRaw = fst <$> Network.Socket.accept (unSocket listenSock)
-- | Resolve a host/service pair to the first stream-socket address.
--
-- With 'AI_PASSIVE' and no host, the wildcard address is returned, which
-- suits 'newServer'. 'AI_ADDRCONFIG' limits results to address families
-- configured on this machine. An empty result raises an 'IOError' instead of
-- failing inside the partial 'head'.
getSockAddr :: Maybe HostName -> ServiceName -> IO Network.Socket.AddrInfo
getSockAddr hn sn = do
    infos <- getAddrInfo (Just hints) hn (Just sn)
    case infos of
        (info : _) -> return info
        []         -> ioError . userError $
            "getSockAddr: no address found for " ++ show (hn, sn)
  where
    hints = defaultHints { addrFlags      = [AI_PASSIVE, AI_ADDRCONFIG]
                         , addrSocketType = Stream
                         }
-- | Wrap a connected TCP socket in TLS and authenticate the peer.
--
-- Runs the server-side handshake when @isServer@ is 'True', otherwise the
-- client side. The peer is then checked: OpenSSL's verify result must
-- succeed and the presented certificate must compare equal to one of
-- @peerIds@. On failure an 'IOError' with message "Peer verification error"
-- is raised. If anything throws after the SSL object is created, a
-- unidirectional TLS shutdown is attempted. The TCP socket itself is left to
-- the caller.
connectSSL :: LocalIdentity -> [PeerIdentity] -> Bool -> Network.Socket.Socket
           -> IO Connection
connectSSL myId peerIds isServer sock = do
    sslCtx <- newSSLContext myId peerIds
    conn <- connection sslCtx sock
    flip onException (shutdown conn Unidirectional) $ do
        initiate conn
        -- Verification hands back the certificate it checked, so the peer
        -- identity is built from exactly that certificate, with no second
        -- lookup and no partial 'fromJust'.
        cert <- verifiedPeerCertificate conn
                >>= maybe (fail "Peer verification error") return
        peerId <- fromX509 cert
        C conn peerId <$> getPeerName sock
  where
    -- Server- or client-side handshake, depending on our role.
    initiate = if isServer
                   then OpenSSL.Session.accept
                   else OpenSSL.Session.connect

    -- 'Just' the peer's certificate when OpenSSL's verify result succeeded
    -- and the certificate matches a trusted peer; 'Nothing' otherwise.
    verifiedPeerCertificate conn = do
        verified <- getVerifyResult conn
        if not verified
            then return Nothing
            else do
                mCert <- getPeerCertificate conn
                case mCert of
                    Nothing   -> return Nothing
                    Just cert -> do
                        trusted <- anyM (matches cert) peerIds
                        return (if trusted then Just cert else Nothing)

    -- Whether the presented certificate is identical to a trusted peer's.
    matches cert = liftM (EQ ==) . compareX509 cert . piX509
-- | Create an unconnected socket whose family, type and protocol are taken
-- from a resolved 'AddrInfo'.
newSock :: Network.Socket.AddrInfo -> IO Network.Socket.Socket
newSock info = socket family sockType proto
  where
    family   = addrFamily info
    sockType = addrSocketType info
    proto    = addrProtocol info
-- | Build the TLS context used for every connection.
--
-- The context carries our key pair, a fixed cipher list, mandatory peer
-- verification, and the expected peers' certificates in its CA store.
newSSLContext :: LocalIdentity -> [PeerIdentity] -> IO SSLContext
newSSLContext localId validCerts = do
    sslCtx <- context
    -- Our own credentials, presented to the remote side.
    contextSetPrivateKey  sslCtx (liKey localId)
    contextSetCertificate sslCtx (liX509 localId)
    -- NOTE(review): a single legacy suite with no forward secrecy. Changing it
    -- affects interoperability with peers running this code, so confirm
    -- before upgrading.
    contextSetCiphers sslCtx "AES256-SHA"
    -- Verify the peer. Per HsOpenSSL's VerifyPeer fields this is presumably
    -- (fail if no peer cert = True, client once = False, no callback).
    contextSetVerificationMode sslCtx (VerifyPeer True False Nothing)
    -- Put each expected peer's certificate into the context's CA store so
    -- that chain verification can succeed for it.
    caStore <- contextGetCAStore sslCtx
    mapM_ (\pid -> addCertToStore caStore (piX509 pid)) validCerts
    return sslCtx
-- | Monadic 'any'. Runs the predicate on each element from left to right and
-- stops (running no further effects) at the first 'True'. Yields 'False' for
-- an empty list or when no element matches.
anyM :: (Monad m) => (a -> m Bool) -> [a] -> m Bool
anyM test = foldr step (return False)
  where
    step x rest = do
        hit <- test x
        if hit then return True else rest