module Network.TLS.Client
( TLSClientParams(..)
, TLSClientCallbacks(..)
, TLSStateClient
, newIState
, recvPacket
, sendPacket
, connect
, sendData
, recvData
, close
, clientEnum
, clientEnumSimple
) where
import Control.Applicative ((<$>))
import Control.Applicative (Applicative(..))
import Control.Monad (when, unless)
import Data.Bits
import Data.IORef
import Data.List (find)
import Data.Maybe
import Data.Word
import System.IO (Handle, hFlush)

import qualified Codec.Crypto.AES.Random as AESRand
import qualified Control.Monad.IO.Class as Trans
import qualified Control.Monad.Trans.Class as Trans
import qualified Data.ByteString as B
import qualified Data.ByteString.Lazy as L
import Data.Certificate.X509
import qualified Data.Enumerator as E

import Network.TLS.Cipher
import Network.TLS.Packet
import Network.TLS.Receiving
import Network.TLS.SRandom
import Network.TLS.Sending
import Network.TLS.State
import Network.TLS.Struct
-- | Hooks supplied by the user and consulted while the handshake runs.
data TLSClientCallbacks = TLSClientCallbacks
    { cbCertificates :: Maybe ([Certificate] -> IO Bool)
      -- ^ Optional check applied to the certificate list sent by the server.
      -- When it is 'Nothing' the certificates are accepted as they are; when
      -- the callback answers 'False' the handshake is aborted with 'error'.
    }
-- | The callbacks are functions and cannot be rendered, so every value is
-- shown as the fixed placeholder @[callbacks]@.
instance Show TLSClientCallbacks where
    showsPrec _ _ = showString "[callbacks]"
-- | Static configuration of a client connection.
data TLSClientParams = TLSClientParams
    { cpConnectVersion :: Version
      -- ^ Version placed in the ClientHello and in the client key exchange;
      -- also the initial version of the TLS state.
    , cpAllowedVersions :: [Version]
      -- ^ Versions the server is allowed to select; any other version in the
      -- ServerHello aborts the handshake.
    , cpSession :: Maybe [Word8]
      -- ^ Session identifier.
      -- NOTE(review): nothing in this module reads this field; the
      -- ClientHello is always built with @Session Nothing@.
    , cpCiphers :: [Cipher]
      -- ^ Ciphers offered to the server (by their 'cipherID'); the cipher
      -- picked by the server is looked up in this list.
    , cpCertificate :: Maybe Certificate
      -- ^ Certificate sent when the server asks for one; an empty certificate
      -- list is sent when this is 'Nothing'.
    , cpCallbacks :: TLSClientCallbacks
      -- ^ User hooks, see 'TLSClientCallbacks'.
    } deriving (Show)
-- | Everything the client keeps for one connection.
data TLSStateClient = TLSStateClient
    { scParams :: TLSClientParams
      -- ^ The configuration the connection was created with.
    , scTLSState :: TLSState
      -- ^ Protocol state shared with the packet reading/writing code through
      -- the 'MonadTLSState' instance of 'IStateWrapper'.
    , scCertRequested :: Bool
      -- ^ Set once a CertRequest handshake message has been received; decides
      -- whether a client certificate message is sent.
    } deriving (Show)
-- | Mutable reference to the client connection state; this is the value
-- threaded through every function of this module.
type IState = IORef TLSStateClient
-- | Allocate the mutable state of a fresh client connection.
--
-- The TLS state is created from the given random generator, starts at the
-- configured connect version and is flagged as a client context. No
-- certificate has been requested yet.
newIState :: TLSClientParams -> SRandomGen -> IO IState
newIState params rng = newIORef initial
  where
    tlsState = (newTLSState rng)
        { stVersion = cpConnectVersion params
        , stClientContext = True
        }
    initial = TLSStateClient
        { scParams = params
        , scTLSState = tlsState
        , scCertRequested = False
        }
-- | Read one record from the handle and decode it into a 'Packet'.
--
-- Five bytes are read and decoded as the record header; the header's length
-- field then tells how many more bytes to read. Those bytes are handed, as
-- 'EncryptedData', to 'readPacket' running against the connection state.
-- A header that fails to decode is returned as 'Left' without reading further.
--
-- NOTE(review): the lengths of the two reads are not checked here, so a short
-- read is passed on unchanged to the decoding functions.
recvPacket :: IState -> Handle -> IO (Either TLSError Packet)
recvPacket istate handle = do
    rawHeader <- B.hGet handle 5
    case decodeHeader rawHeader of
        Left err -> return (Left err)
        Right header@(Header _ _ readlen) -> do
            body <- B.hGet handle (fromIntegral readlen)
            runIStateWrapper (readPacket header (EncryptedData body)) istate
-- | A computation that is given the connection state reference and runs in
-- 'IO'. It lets the state-polymorphic packet code operate on an 'IState'.
newtype IStateWrapper a = IStateWrapper
    { runIStateWrapper :: IState -> IO a
      -- ^ Run the computation against a state reference.
    }
-- | Map over the result of the wrapped 'IO' action.
--
-- 'Functor' and 'Applicative' are superclasses of 'Monad' since GHC 7.10, so
-- the 'Monad' instance below is rejected without these two instances.
instance Functor IStateWrapper where
    fmap f (IStateWrapper run) = IStateWrapper (fmap f . run)

-- | Both sides receive the same state reference; effects run left to right.
instance Applicative IStateWrapper where
    pure x = IStateWrapper (const (return x))
    IStateWrapper runF <*> IStateWrapper runX = IStateWrapper $ \i -> do
        f <- runF i
        x <- runX i
        return (f x)

-- | Sequencing passes the same state reference to every step.
instance Monad IStateWrapper where
    return = pure
    IStateWrapper run >>= g = IStateWrapper $ \i -> do
        x <- run i
        runIStateWrapper (g x) i
-- | The TLS state lives in the 'scTLSState' field of the referenced
-- 'TLSStateClient': reading projects it out, writing replaces that field and
-- leaves the other fields untouched.
instance MonadTLSState IStateWrapper where
    getTLSState = IStateWrapper $ \ref -> scTLSState <$> readIORef ref
    putTLSState s = IStateWrapper $ \ref ->
        modifyIORef ref (\cs -> cs { scTLSState = s })
-- | Serialise a packet with 'writePacket' (which runs against, and may update,
-- the connection state) and write the resulting bytes to the handle.
-- The handle is not flushed here.
sendPacket :: IState -> Handle -> Packet -> IO ()
sendPacket istate handle pkt =
    runIStateWrapper (writePacket pkt) istate >>= B.hPut handle
-- | Consume the server's handshake messages up to and including
-- ServerHelloDone.
--
-- One packet is read per iteration and must be a handshake message:
--
-- * ServerHello: the version must be one of 'cpAllowedVersions' and the cipher
--   one of 'cpCiphers'; both are recorded in the TLS state (version first).
-- * CertRequest: remembers that a client certificate was asked for.
-- * Certificates: runs the user's 'cbCertificates' callback, if any.
-- * ServerHelloDone: stops.
--
-- Every failure (receive error, unexpected packet or message, disallowed
-- version, unknown cipher, rejected certificates) is raised with 'error'.
recvServerHello :: IState -> Handle -> IO ()
recvServerHello istate handle = do
    -- the parameters never change, so they are read once up front
    params <- scParams <$> readIORef istate
    received <- recvPacket istate handle
    let msg = case received of
            Right (Handshake h) -> h
            Left err -> error ("error received: " ++ show err)
            Right x -> error ("unexpected packet received, expecting handshake " ++ show x)
    case msg of
        ServerHello ver _ _ cipher _ _ -> do
            unless (ver `elem` cpAllowedVersions params) $
                error ("received version which is not allowed: " ++ show ver)
            updateTLSState $ \st -> st { stVersion = ver }
            case find ((cipher ==) . cipherID) (cpCiphers params) of
                Nothing -> error "no cipher in common with the server"
                Just c -> updateTLSState $ \st -> st { stCipher = Just c }
            next
        CertRequest _ _ _ -> do
            modifyIORef istate $ \sc -> sc { scCertRequested = True }
            next
        Certificates certs -> do
            valid <- maybe (return True) ($ certs) (cbCertificates (cpCallbacks params))
            unless valid $ error "certificates received deemed invalid by user"
            next
        ServerHelloDone -> return ()
        _ -> error "unexpected handshake message received in server hello messages"
  where
    -- keep reading until ServerHelloDone shows up
    next = recvServerHello istate handle
    -- apply a change to the TLS state stored inside the client state
    updateTLSState f =
        modifyIORef istate $ \sc -> sc { scTLSState = f (scTLSState sc) }
-- | Send the ClientHello handshake message.
--
-- It carries the configured connect version, the given client random, no
-- session, the identifiers of all configured ciphers, the single-entry list
-- @[0]@ (presumably the compression methods -- confirm against 'ClientHello')
-- and 'Nothing' for the last field.
connectSendClientHello :: IState -> Handle -> ClientRandom -> IO ()
connectSendClientHello istate handle crand = do
    params <- scParams <$> readIORef istate
    let hello = ClientHello
            (cpConnectVersion params)
            crand
            (Session Nothing)
            (map cipherID (cpCiphers params))
            [ 0 ]
            Nothing
    sendPacket istate handle (Handshake hello)
-- | Send the client certificate message, but only when the server asked for
-- one ('scCertRequested'). The message holds the configured certificate, or
-- no certificate at all when none is configured.
connectSendClientCertificate :: IState -> Handle -> IO ()
connectSendClientCertificate istate handle = do
    st <- readIORef istate
    when (scCertRequested st) $
        sendPacket istate handle $
            Handshake (Certificates (maybeToList (cpCertificate (scParams st))))
-- | Send the client key exchange message built from the configured connect
-- version (not the version negotiated with the server) and the given
-- key data.
connectSendClientKeyXchg :: IState -> Handle -> ClientKeyData -> IO ()
connectSendClientKeyXchg istate handle prerand =
    readIORef istate >>= \st ->
        let ver = cpConnectVersion (scParams st)
        in sendPacket istate handle (Handshake (ClientKeyXchg ver prerand))
-- | Send the Finished handshake message. Its payload is the handshake digest
-- obtained from the TLS state with @getHandshakeDigest True@, unpacked into
-- a byte list.
connectSendFinish :: IState -> Handle -> IO ()
connectSendFinish istate handle =
    runIStateWrapper (getHandshakeDigest True) istate
        >>= sendPacket istate handle . Handshake . Finished . B.unpack
-- | Run a complete client handshake over the handle.
--
-- Order of operations: ClientHello, the server's hello messages up to
-- ServerHelloDone, optional client certificate, client key exchange,
-- ChangeCipherSpec, Finished; then the server's ChangeCipherSpec and Finished
-- are awaited.
--
-- The handle is flushed before every point where the client blocks waiting
-- for the server. Without that, a buffered handle can keep the ClientHello or
-- the Finished message in its write buffer while the client waits for a reply
-- that the server will never send.
--
-- Any unexpected reply is raised with 'error'.
--
-- NOTE(review): the content of the server's Finished message is discarded by
-- the pattern below; confirm that it is verified elsewhere (for instance
-- while the packet is read).
connect :: IState -> Handle -> ClientRandom -> ClientKeyData -> IO ()
connect istate handle crand premasterRandom = do
    connectSendClientHello istate handle crand
    -- the server cannot answer a ClientHello it has not received
    hFlush handle
    recvServerHello istate handle
    connectSendClientCertificate istate handle
    connectSendClientKeyXchg istate handle premasterRandom
    sendPacket istate handle ChangeCipherSpec
    hFlush handle
    connectSendFinish istate handle
    -- make sure Finished is on the wire before waiting for the server's reply
    hFlush handle
    pktCCS <- recvPacket istate handle
    case pktCCS of
        Right ChangeCipherSpec -> return ()
        x -> error ("unexpected reply. expecting change cipher spec " ++ show x)
    pktFin <- recvPacket istate handle
    case pktFin of
        Right (Handshake (Finished _)) -> return ()
        x -> error ("unexpected reply. expecting finished " ++ show x)
-- | Send one strict chunk as application data, cutting it into packets of at
-- most 16384 bytes. A chunk that already fits is sent as a single packet,
-- even when it is empty.
sendDataChunk :: IState -> Handle -> B.ByteString -> IO ()
sendDataChunk istate handle d
    | B.length d > 16384 = do
        let (sending, remain) = B.splitAt 16384 d
        sendPacket istate handle (AppData sending)
        sendDataChunk istate handle remain
    | otherwise = sendPacket istate handle (AppData d)
-- | Send a lazy bytestring as application data, one strict chunk after the
-- other, each through 'sendDataChunk'.
sendData :: IState -> Handle -> L.ByteString -> IO ()
sendData istate handle d =
    sequence_ [ sendDataChunk istate handle chunk | chunk <- L.toChunks d ]
-- | Receive the next piece of application data.
--
-- A HelloRequest from the server triggers a new handshake: 32 bytes of client
-- random and 46 bytes of key data are drawn from the generator kept in the
-- TLS state, the advanced generator is stored back, 'connect' is run, and
-- reading resumes afterwards. A receive error or any other kind of packet is
-- raised with 'error'.
recvData :: IState -> Handle -> IO L.ByteString
recvData istate handle = recvPacket istate handle >>= dispatch
  where
    dispatch (Right (AppData x)) = return (L.fromChunks [x])
    dispatch (Right (Handshake HelloRequest)) = renegotiate >> recvData istate handle
    dispatch (Left err) = error ("error received: " ++ show err)
    dispatch _ = error "unexpected item"

    -- perform a fresh handshake with newly generated random material
    renegotiate = do
        state <- readIORef istate
        let st = scTLSState state
            (bytes, rng') = getRandomBytes (stRandomGen st) 32
            (premaster, rng'') = getRandomBytes rng' 46
        writeIORef istate state { scTLSState = st { stRandomGen = rng'' } }
        connect istate handle
            (fromJust (clientRandom bytes))
            (ClientKeyData (B.pack premaster))
-- | Tell the peer that the connection is ending by sending a warning-level
-- CloseNotify alert. The handle itself is neither flushed nor closed here.
close :: IState -> Handle -> IO ()
close istate handle =
    sendPacket istate handle (Alert (AlertLevel_Warning, CloseNotify))
-- | Run 'clientEnum' with a default configuration and freshly drawn random
-- material.
--
-- Defaults: connect with TLS10, allow TLS10 and TLS11, no session, no client
-- certificate, no certificate callback (so server certificates are accepted
-- unchecked), and four offered ciphers.
--
-- Random material comes from "Codec.Crypto.AES.Random": first one byte is
-- drawn and that many further bytes are discarded, then 32 bytes of client
-- random, 46 bytes of key data and 4 bytes that seed the 'SRandomGen'.
clientEnumSimple
    :: Trans.MonadIO m
    => Handle
    -> (E.Iteratee B.ByteString m () -> E.Enumerator B.ByteString m a -> m b)
    -> m b
clientEnumSimple h f = do
    (crand, premaster, seed) <- Trans.liftIO $ do
        skip <- B.head <$> AESRand.randBytes 1
        _ <- AESRand.randBytes (fromIntegral skip)
        crandBytes <- AESRand.randBytes 32
        premasterBytes <- AESRand.randBytes 46
        seedBytes <- AESRand.randBytes 4
        return
            ( fromJust (clientRandom (B.unpack crandBytes))
            , ClientKeyData premasterBytes
            , conv (B.unpack seedBytes)
            )
    clientEnum params (makeSRandomGen seed) h crand premaster f
  where
    params = TLSClientParams
        { cpConnectVersion = TLS10
        , cpAllowedVersions = [ TLS10, TLS11 ]
        , cpSession = Nothing
        , cpCiphers = offered
        , cpCertificate = Nothing
        , cpCallbacks = TLSClientCallbacks { cbCertificates = Nothing }
        }
    offered =
        [ cipher_AES128_SHA1
        , cipher_AES256_SHA1
        , cipher_RC4_128_MD5
        , cipher_RC4_128_SHA1
        ]
    -- combine exactly four bytes into one number, first byte most significant
    conv bytes = case map fromIntegral bytes of
        [a, b, c, d] ->
            (a `shiftL` 24) .|. (b `shiftL` 16) .|. (c `shiftL` 8) .|. d
-- | Establish a TLS session on the handle and hand the user a sink and a
-- source for application data.
--
-- A new connection state is created, the handshake is performed, and the
-- continuation is called with an iteratee that sends every chunk it is fed
-- and an enumerator that feeds received data to its consumer. When the
-- continuation returns, the close alert is sent and its result is returned.
-- If the continuation throws, the close alert is not sent.
clientEnum :: Trans.MonadIO m
           => TLSClientParams -> SRandomGen -> Handle -> ClientRandom -> ClientKeyData
           -> (E.Iteratee B.ByteString m () -> E.Enumerator B.ByteString m a -> m b)
           -> m b
clientEnum tcp srg h cr ckd f = do
    istate <- Trans.liftIO $ do
        ref <- newIState tcp srg
        connect ref h cr ckd
        return ref
    result <- f (sink istate) (source istate)
    Trans.liftIO (close istate h)
    return result
  where
    -- sends everything it receives over the session; finishes on end of input
    sink :: Trans.MonadIO m => IState -> E.Iteratee B.ByteString m ()
    sink istate = E.continue step
      where
        step E.EOF = return ()
        step (E.Chunks xs) =
            Trans.liftIO (sendData istate h (L.fromChunks xs)) >> E.continue step

    -- keeps feeding received data for as long as the consumer wants more
    source :: Trans.MonadIO m => IState -> E.Enumerator B.ByteString m a
    source istate (E.Continue k) = do
        lbs <- Trans.liftIO (recvData istate h)
        following <- Trans.lift (E.runIteratee (k (E.Chunks (L.toChunks lbs))))
        source istate following
    source _ finished = E.returnI finished