module Network.TLS.Core
(
TLSParams(..)
, TLSLogging(..)
, TLSCertificateUsage(..)
, TLSCertificateRejectReason(..)
, defaultLogging
, defaultParams
, TLSCtx
, ctxHandle
, ctxEOF
, sendPacket
, recvPacket
, client
, server
, bye
, handshake
, sendData
, recvData
) where
import Network.TLS.Struct
import Network.TLS.Cipher
import Network.TLS.Compression
import Network.TLS.Crypto
import Network.TLS.Packet
import Network.TLS.State
import Network.TLS.Sending
import Network.TLS.Receiving
import Data.Maybe
import Data.Certificate.X509
import Data.List (intersect, intercalate, find)
import qualified Data.ByteString as B
import qualified Data.ByteString.Lazy as L
import Crypto.Random
import Control.Applicative ((<$>))
import Control.Concurrent.MVar
import Control.Monad.State
import Control.Exception (throwIO, Exception(), onException, fromException, catch)
import Data.IORef
import System.IO (Handle, hSetBuffering, BufferMode(..), hFlush)
import System.IO.Error (mkIOError, eofErrorType)
import Prelude hiding (catch)
-- | Logging hooks invoked by a context as packets flow in and out.
--
-- Every hook runs in 'IO'; 'defaultLogging' provides no-op versions.
data TLSLogging = TLSLogging
    { loggingPacketSent :: String -> IO ()
      -- ^ called with 'show' of each packet just before it is encoded and sent
    , loggingPacketRecv :: String -> IO ()
      -- ^ called with 'show' of each packet that was decoded successfully
    , loggingIOSent :: Bytes -> IO ()
      -- ^ called with the encoded bytes right before they are written to the handle
    , loggingIORecv :: Header -> Bytes -> IO ()
      -- ^ called with the record header and the raw record content as read from the handle
    }
-- | Why a peer certificate chain was refused by the 'onCertificatesRecv'
-- callback. Each reason is turned into a matching fatal alert by the client
-- handshake.
data TLSCertificateRejectReason
    = CertificateRejectExpired       -- ^ a certificate has expired
    | CertificateRejectRevoked       -- ^ a certificate has been revoked
    | CertificateRejectUnknownCA     -- ^ the issuing CA is unknown
    | CertificateRejectOther String  -- ^ any other reason, with a free-form message
    deriving (Eq, Show)
-- | Verdict returned by the 'onCertificatesRecv' callback.
data TLSCertificateUsage
    = CertificateUsageAccept                            -- ^ continue the handshake
    | CertificateUsageReject TLSCertificateRejectReason -- ^ abort the handshake for this reason
    deriving (Eq, Show)
-- | Configuration used by both the client and the server handshakes.
data TLSParams = TLSParams
    { pConnectVersion :: Version
      -- ^ version the client announces in its ClientHello (also sent in ClientKeyXchg)
    , pAllowedVersions :: [Version]
      -- ^ versions accepted from the peer during negotiation
    , pCiphers :: [Cipher]
      -- ^ ciphers offered by the client / accepted by the server
    , pCompressions :: [Compression]
      -- ^ compressions offered by the client / accepted by the server
    , pWantClientCert :: Bool
      -- ^ read by the server handshake: when True a CertRequest is sent
    , pUseSecureRenegotiation :: Bool
      -- ^ read by the client handshake: when True the 0xff01 extension is added to ClientHello
    , pCertificates :: [(X509, Maybe PrivateKey)]
      -- ^ local certificates; the server installs the first entry's key when it has one
    , pLogging :: TLSLogging
      -- ^ logging hooks for connections using these parameters
    , onCertificatesRecv :: ([X509] -> IO TLSCertificateUsage)
      -- ^ called by the client handshake with the server's certificates to accept or reject them
    }
-- | Logging hooks that discard everything they receive.
defaultLogging :: TLSLogging
defaultLogging = TLSLogging
    { loggingPacketSent = ignore
    , loggingPacketRecv = ignore
    , loggingIOSent     = ignore
    , loggingIORecv     = const ignore
    }
  where
    -- accept any single argument and do nothing
    ignore :: a -> IO ()
    ignore _ = return ()
-- | Baseline parameters:
--
-- * connect with TLS 1.0, accept TLS 1.0 and 1.1;
-- * an empty cipher list (handshakes cannot agree on a cipher until callers
--   supply some);
-- * null compression only, no client-certificate request, secure
--   renegotiation enabled, no local certificates, silent logging;
-- * every peer certificate chain is accepted.
--   NOTE(review): accepting all chains is insecure outside of testing.
defaultParams :: TLSParams
defaultParams = TLSParams
    { pConnectVersion         = TLS10
    , pAllowedVersions        = [TLS10, TLS11]
    , pCiphers                = []
    , pCompressions           = [nullCompression]
    , pWantClientCert         = False
    , pUseSecureRenegotiation = True
    , pCertificates           = []
    , pLogging                = defaultLogging
    , onCertificatesRecv      = const (return CertificateUsageAccept)
    }
-- | Human-readable summary of the parameters. Certificates are shown only as
-- a count and the callbacks/logging hooks are omitted.
instance Show TLSParams where
    show p = "TLSParams { " ++ intercalate "," (map render fields) ++ " }"
      where
        render (key, val) = key ++ "=" ++ val
        fields =
            [ ("connectVersion",   show $ pConnectVersion p)
            , ("allowedVersions",  show $ pAllowedVersions p)
            , ("ciphers",          show $ pCiphers p)
            , ("compressions",     show $ pCompressions p)
            , ("want-client-cert", show $ pWantClientCert p)
            , ("certificates",     show $ length $ pCertificates p)
            ]
-- | A TLS connection: the transport handle, its parameters and mutable state.
data TLSCtx = TLSCtx
    { ctxHandle :: Handle
      -- ^ transport handle; 'newCtx' switches it to 'NoBuffering'
    , ctxParams :: TLSParams
      -- ^ parameters the context was created with
    , ctxState :: MVar TLSState
      -- ^ protocol state, held exclusively while a 'TLSSt' action runs on it
    , ctxEOF_ :: IORef Bool
      -- ^ set on a short read, or when a fatal / close_notify alert is received
    }
-- | Whether the context has been marked as closed.
ctxEOF :: MonadIO m => TLSCtx -> m Bool
ctxEOF = liftIO . readIORef . ctxEOF_
-- | Throw an exception from any 'MonadIO' context, using 'throwIO' so the
-- exception is raised when the action runs rather than when it is evaluated.
throwCore :: (MonadIO m, Exception e) => e -> m a
throwCore e = liftIO (throwIO e)
-- | Build a fresh context around a handle. The handle is switched to
-- 'NoBuffering', the initial state is stored in a new 'MVar', and the EOF
-- flag starts as False.
newCtx :: Handle -> TLSParams -> TLSState -> IO TLSCtx
newCtx handle params st = do
    hSetBuffering handle NoBuffering
    -- fields in declaration order: handle, params, state var, EOF flag
    liftM2 (TLSCtx handle params) (newMVar st) (newIORef False)
-- | The logging hooks configured for this context.
ctxLogging :: TLSCtx -> TLSLogging
ctxLogging ctx = pLogging (ctxParams ctx)
-- | Run a 'TLSSt' action against the context's state and store the resulting
-- state back.
--
-- The state 'MVar' is held for the whole update, so concurrent users of the
-- same context are serialised. 'modifyMVar' masks asynchronous exceptions
-- while the state is taken and puts the previous state back if the update
-- throws. The 'MVar' therefore can never be left empty. A take followed by a
-- separately installed handler leaves a window in which an async exception
-- would strand the state and deadlock every later call.
--
-- Errors raised by the 'TLSSt' action are returned as 'Left', not thrown.
usingState :: MonadIO m => TLSCtx -> TLSSt a -> m (Either TLSError a)
usingState ctx f = liftIO $ modifyMVar (ctxState ctx) $ \st ->
    let (result, newst) = runTLSState f st
     in return (newst, result)
-- | Like 'usingState', but a 'TLSError' returned by the action is thrown
-- (via 'throwCore') instead of being returned.
usingState_ :: MonadIO m => TLSCtx -> TLSSt a -> m a
usingState_ ctx f = usingState ctx f >>= either throwCore return
-- | Draw @n@ random bytes from the generator kept in the context's state.
getStateRNG :: MonadIO m => TLSCtx -> Int -> m Bytes
getStateRNG ctx = usingState_ ctx . genTLSRandom
-- | Repeat the action as long as the predicate holds for the current
-- handshake status. The status is re-read before every iteration, and the
-- action is never run if the predicate is false initially.
whileStatus :: MonadIO m => TLSCtx -> (TLSStatus -> Bool) -> m a -> m ()
whileStatus ctx p a = go
  where
    go = do
        keepGoing <- usingState_ ctx (gets (p . stStatus))
        when keepGoing (a >> go)
-- | Fatal alert to send to the peer for a given error. A protocol error
-- carries its own alert description; everything else becomes InternalError.
errorToAlert :: TLSError -> Packet
errorToAlert err = Alert [(AlertLevel_Fatal, description err)]
  where
    description (Error_Protocol (_, _, ad)) = ad
    description _                           = InternalError
-- | Mark the context as closed; later 'sendData' / 'recvData' calls fail.
setEOF :: MonadIO m => TLSCtx -> m ()
setEOF ctx = liftIO (writeIORef (ctxEOF_ ctx) True)
-- | Read exactly @sz@ bytes from the context's handle.
--
-- On a short read the context is marked EOF. 'Error_EOF' is thrown if
-- nothing at all was read; otherwise 'Error_Packet' is thrown describing the
-- partial read.
readExact :: MonadIO m => TLSCtx -> Int -> m Bytes
readExact ctx sz = do
    bytes <- liftIO $ B.hGet (ctxHandle ctx) sz
    let got = B.length bytes
    when (got < sz) $ do
        setEOF ctx
        throwCore $
            if got == 0
                then Error_EOF
                else Error_Packet ("partial packet: expecting " ++ show sz ++ " bytes, got: " ++ show got)
    return bytes
-- | Read one record from the handle and decode it into a 'Packet'.
--
-- The 5-byte header is read first. A record whose declared length exceeds
-- 16384 + 2048 bytes is refused with a RecordOverflow protocol error before
-- its body is read. The raw record and, on success, the decoded packet are
-- passed to the logging hooks. Decoding errors are returned as 'Left'.
-- Read failures are thrown by 'readExact'.
recvPacket :: MonadIO m => TLSCtx -> m (Either TLSError Packet)
recvPacket ctx = do
    hdrbs <- readExact ctx 5
    either (return . Left) checkLength (decodeHeader hdrbs)
  where
    checkLength header@(Header _ _ readlen)
        | readlen > 16384 + 2048 =
            return $ Left $ Error_Protocol ("record exceeding maximum size",True, RecordOverflow)
        | otherwise = readRecord header readlen

    readRecord header readlen = do
        content <- readExact ctx (fromIntegral readlen)
        let logging = ctxLogging ctx
        liftIO $ loggingIORecv logging header content
        pkt <- usingState ctx $ readPacket header (EncryptedData content)
        either (const (return ())) (liftIO . loggingPacketRecv logging . show) pkt
        return pkt
-- | Receive one packet, discarding it; a decoding error is thrown.
recvPacketSuccess :: MonadIO m => TLSCtx -> m ()
recvPacketSuccess ctx = recvPacket ctx >>= either throwCore (const (return ()))
-- | Encode a packet with the current state and write it to the handle.
-- The packet is logged before encoding and the bytes before writing.
-- An encoding error is thrown by 'usingState_'.
sendPacket :: MonadIO m => TLSCtx -> Packet -> m ()
sendPacket ctx pkt = do
    let logging = ctxLogging ctx
    liftIO $ loggingPacketSent logging (show pkt)
    bytes <- usingState_ ctx (writePacket pkt)
    liftIO $ do
        loggingIOSent logging bytes
        B.hPut (ctxHandle ctx) bytes
-- | Create a context that will act as the client side of a connection.
client :: (MonadIO m, CryptoRandomGen g) => TLSParams -> g -> Handle -> m TLSCtx
client params rng handle =
    let initial = newTLSState rng
     in liftIO $ newCtx handle params (initial { stClientContext = True })
-- | Create a context that will act as the server side of a connection.
server :: (MonadIO m, CryptoRandomGen g) => TLSParams -> g -> Handle -> m TLSCtx
server params rng handle =
    let initial = newTLSState rng
     in liftIO $ newCtx handle params (initial { stClientContext = False })
-- | Tell the peer we are closing by sending a close_notify warning alert.
-- The handle itself is left open.
bye :: MonadIO m => TLSCtx -> m ()
bye ctx = sendPacket ctx closeNotify
  where
    closeNotify = Alert [(AlertLevel_Warning, CloseNotify)]
-- | Drive the client side of a full handshake over the context.
--
-- Steps:
--
-- 1. Send a ClientHello built from the context's 'TLSParams'.
-- 2. Process the server's packets until the handshake status reaches
--    ServerHelloDone.
-- 3. Send ClientKeyXchg, ChangeCipherSpec and Finished.
-- 4. Receive two more packets from the server (presumably its
--    ChangeCipherSpec and Finished; here they are only required to decode).
--
-- Throws a 'TLSError' (via 'throwCore') when:
--
-- * a packet cannot be received or decoded;
-- * the server selects SSL2 or a version not in 'pAllowedVersions';
-- * the server selects a cipher that was not offered;
-- * 'onCertificatesRecv' rejects the server's certificates.
handshakeClient :: MonadIO m => TLSCtx -> m ()
handshakeClient ctx = do
    crand <- liftM ClientRandom (getStateRNG ctx 32)
    extensions <- getExtensions
    usingState_ ctx (startHandshakeClient ver crand)
    sendPacket ctx $ Handshake
        [ ClientHello ver crand (Session Nothing) (map cipherID ciphers)
                      (map compressionID compressions) extensions
        ]
    whileStatus ctx (/= (StatusHandshake HsStatusServerHelloDone)) $ do
        pkts <- recvPacket ctx
        case pkts of
            Left err -> throwCore err
            Right l  -> processServerInfo l
    -- NOTE(review): a CertRequest from the server is not recorded anywhere
    -- yet, so this stays False and no client certificate is ever sent.
    let certRequested = False
    when certRequested (sendPacket ctx $ Handshake [Certificates clientCerts])
    sendClientKeyXchg
    sendPacket ctx ChangeCipherSpec
    liftIO $ hFlush $ ctxHandle ctx
    cf <- usingState_ ctx $ getHandshakeDigest True
    sendPacket ctx (Handshake [Finished cf])
    recvPacketSuccess ctx
    recvPacketSuccess ctx
  where
    params       = ctxParams ctx
    ver          = pConnectVersion params
    allowedvers  = pAllowedVersions params
    ciphers      = pCiphers params
    compressions = pCompressions params
    clientCerts  = map fst $ pCertificates params

    -- The renegotiation_info (0xff01) body is the length-prefixed client
    -- verify data. It is encoded with the same helper the server side uses,
    -- so an empty verify data still yields a well-formed body.
    getExtensions
        | pUseSecureRenegotiation params = do
            vd <- usingState_ ctx (getVerifiedData True)
            return [ (0xff01, encodeExtSecureRenegotiation vd Nothing) ]
        | otherwise = return []

    processServerInfo (Handshake hss) = mapM_ processHandshake hss
    processServerInfo _               = return ()

    processHandshake (ServerHello rver _ _ cipher _ _) = do
        when (rver == SSL2) $
            throwCore $ Error_Protocol ("ssl2 is not supported", True, ProtocolVersion)
        -- The server's choice (rver), not the version we offered, is what
        -- the rest of the connection runs under.
        unless (rver `elem` allowedvers) $
            throwCore $ Error_Protocol ("version " ++ show rver ++ " is not supported", True, ProtocolVersion)
        usingState_ ctx $ setVersion rver
        case find ((==) cipher . cipherID) ciphers of
            Nothing -> throwCore $ Error_Protocol ("no cipher in common with the server", True, HandshakeFailure)
            Just c  -> usingState_ ctx $ setCipher c
    processHandshake (Certificates certs) = do
        usage <- liftIO $ onCertificatesRecv params certs
        case usage of
            CertificateUsageAccept        -> return ()
            CertificateUsageReject reason -> certificateRejected reason
    processHandshake (CertRequest _ _ _) = return ()
    processHandshake _                   = return ()

    -- ClientKeyXchg carries the version offered in ClientHello (ver),
    -- not the negotiated one.
    sendClientKeyXchg = do
        prerand <- liftM ClientKeyData (getStateRNG ctx 46)
        sendPacket ctx $ Handshake [ClientKeyXchg ver prerand]

    -- map the callback's rejection reason onto a fatal protocol error
    certificateRejected CertificateRejectRevoked =
        throwCore $ Error_Protocol ("certificate is revoked", True, CertificateRevoked)
    certificateRejected CertificateRejectExpired =
        throwCore $ Error_Protocol ("certificate has expired", True, CertificateExpired)
    certificateRejected CertificateRejectUnknownCA =
        throwCore $ Error_Protocol ("certificate has unknown CA", True, UnknownCa)
    certificateRejected (CertificateRejectOther s) =
        throwCore $ Error_Protocol ("certificate rejected: " ++ s, True, CertificateUnknown)
-- | Run the server side of a handshake in response to the given ClientHello.
--
-- Steps:
--
-- 1. Validate the client's version.
-- 2. Pick the first cipher and the first compression in the client's list
--    that this server also supports (client preference wins).
-- 3. Record the version and cipher in the state.
-- 4. Send ServerHello, Certificates, an optional ServerKeyXchg, an optional
--    CertRequest, and ServerHelloDone.
-- 5. Read client packets until the status reaches ClientFinished.
-- 6. Send ChangeCipherSpec and Finished.
--
-- Throws a 'TLSError' (via 'throwCore') on SSL2, a version not in
-- 'pAllowedVersions', no common cipher or compression, or when given any
-- handshake message other than ClientHello.
handshakeServerWith :: MonadIO m => TLSCtx -> Handshake -> m ()
handshakeServerWith ctx (ClientHello ver _ _ ciphers compressions _) = do
    when (ver == SSL2) $
        throwCore $ Error_Protocol ("ssl2 is not supported", True, ProtocolVersion)
    unless (ver `elem` pAllowedVersions params) $
        throwCore $ Error_Protocol ("version " ++ show ver ++ " is not supported", True, ProtocolVersion)
    usedCipher <- case commonCiphers of
        (c:_) -> return c
        []    -> throwCore $ Error_Protocol ("no cipher in common with the client", True, HandshakeFailure)
    usedCompression <- case commonCompressions of
        (c:_) -> return c
        []    -> throwCore $ Error_Protocol ("no compression in common with the client", True, HandshakeFailure)
    usingState_ ctx $ modify (\st -> st
        { stVersion = ver
        , stCipher  = Just usedCipher
        })
    handshakeSendServerData usedCipher usedCompression
    liftIO $ hFlush $ ctxHandle ctx
    whileStatus ctx (/= (StatusHandshake HsStatusClientFinished)) (recvPacketSuccess ctx)
    sendPacket ctx ChangeCipherSpec
    cf <- usingState_ ctx $ getHandshakeDigest False
    sendPacket ctx (Handshake [Finished cf])
    liftIO $ hFlush $ ctxHandle ctx
  where
    params   = ctxParams ctx
    srvCerts = map fst $ pCertificates params
    privKeys = map snd $ pCertificates params

    -- Server-side entries for the client's IDs, kept in the client's order.
    -- The first one is the negotiated choice.
    commonCiphers =
        mapMaybe (\cid -> find ((== cid) . cipherID) (pCiphers params)) ciphers
    commonCompressions =
        mapMaybe (\cid -> find ((== cid) . compressionID) (pCompressions params)) compressions

    -- send the server's first flight, from ServerHello to ServerHelloDone
    handshakeSendServerData usedCipher usedCompression = do
        srand <- liftM ServerRandom (getStateRNG ctx 32)
        -- only the first certificate's key, when present, is installed
        case privKeys of
            (Just privkey : _) -> usingState_ ctx $ setPrivateKey privkey
            _                  -> return ()
        secReneg <- usingState_ ctx getSecureRenegotiation
        extensions <-
            if secReneg
                then do
                    vf <- usingState_ ctx $ do
                        cvf <- getVerifiedData True
                        svf <- getVerifiedData False
                        return $ encodeExtSecureRenegotiation cvf (Just svf)
                    return [ (0xff01, vf) ]
                else return []
        usingState_ ctx (setVersion ver >> setServerRandom srand)
        sendPacket ctx $ Handshake
            [ ServerHello ver srand (Session Nothing) (cipherID usedCipher)
                          (compressionID usedCompression) extensions
            , Certificates srvCerts
            ]
        when (cipherExchangeNeedMoreData $ cipherKeyExchange usedCipher) $
            sendPacket ctx (Handshake [ServerKeyXchg (SKX_RSA Nothing)])
        when (pWantClientCert params) $
            sendPacket ctx (Handshake [CertRequest [CertificateType_RSA_Sign] Nothing [0,0,0]])
        sendPacket ctx (Handshake [ServerHelloDone])
handshakeServerWith _ _ =
    throwCore $ Error_Misc "unexpected handshake type received. expecting client hello"
-- | Run the server side of a handshake. It waits for a packet holding a
-- single handshake message and hands that message to 'handshakeServerWith',
-- which accepts only ClientHello.
--
-- A receive/decode error is rethrown as the original 'TLSError', so the
-- caller can translate it into the matching alert. Any other packet is
-- reported as an 'Error_Packet'.
handshakeServer :: MonadIO m => TLSCtx -> m ()
handshakeServer ctx = do
    pkts <- recvPacket ctx
    case pkts of
        Right (Handshake [hs]) -> handshakeServerWith ctx hs
        Left err               -> throwCore err
        Right p                ->
            throwCore $ Error_Packet ("unexpected type received. expecting handshake: " ++ show p)
-- | Perform the handshake for the context's role (client or server).
--
-- Returns True on success. If any exception escapes the handshake, a fatal
-- alert is sent and False is returned. The alert uses the description
-- carried by an 'Error_Protocol'; for any other 'TLSError' or non-TLS
-- exception it is InternalError. An exception raised while sending that
-- alert is not caught.
handshake :: MonadIO m => TLSCtx -> m Bool
handshake ctx = do
    isClient <- usingState_ ctx (gets stClientContext)
    liftIO $ (runRole isClient >> return True) `catch` onFailure
  where
    runRole True  = handshakeClient ctx
    runRole False = handshakeServer ctx

    onFailure e = do
        sendPacket ctx $ errorToAlert $ fromMaybe (Error_Misc "") (fromException e)
        return False
-- | Send application data, split into AppData records of at most 16384
-- bytes each, in order.
--
-- Throws an EOF 'IOError' if the context has already been marked closed.
sendData :: MonadIO m => TLSCtx -> L.ByteString -> m ()
sendData ctx dataToSend = do
    closed <- ctxEOF ctx
    when closed $
        liftIO $ throwIO $ mkIOError eofErrorType "sendData" (Just (ctxHandle ctx)) Nothing
    mapM_ (sendPacket ctx . AppData) (concatMap fragment (L.toChunks dataToSend))
  where
    -- cut one strict chunk into pieces no longer than the record limit
    fragment chunk
        | B.length chunk > 16384 = let (front, rest) = B.splitAt 16384 chunk in front : fragment rest
        | otherwise              = [chunk]
-- | Receive the next piece of application data.
--
-- Handshake packets arriving here are treated as renegotiation:
--
-- * a ClientHello runs the server handshake;
-- * a HelloRequest runs the client handshake.
--
-- In both cases reading resumes afterwards. A fatal alert, or a close_notify
-- warning, marks the context EOF and yields an empty string.
--
-- Throws an EOF 'IOError' if the context was already marked EOF. Calls
-- 'error' on a receive error or on any other packet.
recvData :: MonadIO m => TLSCtx -> m L.ByteString
recvData ctx = do
    eofed <- ctxEOF ctx
    when eofed $
        liftIO $ throwIO $ mkIOError eofErrorType "recvData" (Just (ctxHandle ctx)) Nothing
    pkt <- recvPacket ctx
    case pkt of
        Right (Handshake [ch@(ClientHello _ _ _ _ _ _)]) ->
            handshakeServerWith ctx ch >> recvData ctx
        Right (Handshake [HelloRequest]) ->
            handshakeClient ctx >> recvData ctx
        Right (Alert [(AlertLevel_Fatal, _)]) -> do
            setEOF ctx
            return L.empty
        Right (Alert [(AlertLevel_Warning, CloseNotify)]) -> do
            setEOF ctx
            return L.empty
        Right (AppData x) -> return $ L.fromChunks [x]
        Right p  -> error ("error unexpected packet: " ++ show p)
        Left err -> error ("error received: " ++ show err)