module Network.TLS.Server
( TLSServerParams(..)
, TLSServerCallbacks(..)
, TLSStateServer
, runTLSServer
, recvPacket
, sendPacket
, listen
, sendData
, recvData
, close
) where
import Data.Word
import Data.Maybe
import Data.List (intersect, find)
import Control.Monad.Trans
import Control.Monad.State
import Control.Applicative ((<$>))
import Data.Certificate.X509
import qualified Data.Certificate.KeyRSA as KeyRSA
import qualified Data.Certificate.KeyDSA as KeyDSA
import Network.TLS.Cipher
import Network.TLS.Crypto
import Network.TLS.Struct
import Network.TLS.Packet
import Network.TLS.State
import Network.TLS.Sending
import Network.TLS.Receiving
import Network.TLS.SRandom
import qualified Data.ByteString as B
import qualified Data.ByteString.Lazy as L
import System.IO (Handle, hFlush)
import qualified Crypto.Cipher.RSA as RSA
-- | Server certificate material.
--
-- The tuple holds:
--
-- * a 'B.ByteString' that is not used in this module — presumably the raw
--   encoded certificate; TODO confirm;
-- * the parsed 'X509' certificate that is sent to clients in the
--   Certificate handshake message;
-- * the RSA private key matching that certificate.
type TLSServerCert = (B.ByteString, X509, KeyRSA.Private)
-- | User-supplied hooks for a TLS server.
data TLSServerCallbacks = TLSServerCallbacks
  { cbCertificates :: Maybe ([Certificate] -> IO Bool)
    -- ^ Optional certificate check. It is not invoked anywhere in this
    -- module. NOTE(review): presumably it receives the client's
    -- certificates and answers whether to accept them — confirm.
  }
-- | Functions cannot be shown, so callbacks always render as a fixed
-- placeholder. This lets 'TLSServerParams' derive 'Show'.
instance Show TLSServerCallbacks where
  show = const "[callbacks]"
-- | Orphan instance: renders any RSA private key as a fixed placeholder.
--
-- The derived 'Show' for 'TLSServerParams' needs it, because
-- 'spCertificate' contains a 'KeyRSA.Private'. It also keeps key material
-- out of 'show' output.
instance Show KeyRSA.Private where
  show = const "[privatekey]"
-- | Static configuration for a TLS server.
data TLSServerParams = TLSServerParams
  { spAllowedVersions :: [Version]
    -- ^ Protocol versions accepted in a ClientHello.
  , spSessions        :: [[Word8]]
    -- ^ Not referenced in this module. NOTE(review): presumably session
    -- IDs for resumption — confirm.
  , spCiphers         :: [Cipher]
    -- ^ Ciphers the server is willing to negotiate.
  , spCertificate     :: Maybe TLSServerCert
    -- ^ Certificate and RSA private key used during the handshake.
  , spWantClientCert  :: Bool
    -- ^ When 'True', the server sends a CertificateRequest.
  , spCallbacks       :: TLSServerCallbacks
    -- ^ User-supplied hooks.
  } deriving (Show)
-- | Full mutable state threaded through the 'TLSServer' monad.
data TLSStateServer = TLSStateServer
  { scParams   :: TLSServerParams
    -- ^ Configuration the server was started with.
  , scTLSState :: TLSState
    -- ^ Protocol state shared with the generic sending/receiving code.
  }
-- | Server-side TLS monad: a 'StateT' over 'TLSStateServer' on top of the
-- base monad @m@. The 'Monad' and 'MonadState' instances are derived
-- through the newtype.
newtype TLSServer m a = TLSServer
  { runTLSC :: StateT TLSStateServer m a
  } deriving (Monad, MonadState TLSStateServer)
-- | Exposes the embedded 'TLSState' to the generic TLS code.
-- Writes replace only 'scTLSState'; the server params are left untouched.
instance Monad m => MonadTLSState (TLSServer m) where
  getTLSState   = TLSServer (gets scTLSState)
  putTLSState s = TLSServer (modify (\st -> st { scTLSState = s }))
-- | Lifts base-monad actions through the underlying 'StateT'.
instance MonadTrans TLSServer where
  lift m = TLSServer (lift m)
-- | Maps over the result by delegating to the underlying 'StateT'.
instance (Monad m, Functor m) => Functor (TLSServer m) where
  fmap f (TLSServer inner) = TLSServer (fmap f inner)
-- | Run a 'TLSServer' action from an explicit starting state. Returns the
-- result together with the final state.
runTLSServerST :: TLSServer m a -> TLSStateServer -> m (a, TLSStateServer)
runTLSServerST = runStateT . runTLSC
-- | Run a 'TLSServer' action with the given parameters and RNG.
--
-- The initial protocol state is 'newTLSState' seeded with @rng@, with
-- 'stClientContext' set to 'False' to mark the server side.
runTLSServer :: TLSServer m a -> TLSServerParams -> SRandomGen -> m (a, TLSStateServer)
runTLSServer f params rng = runStateT (runTLSC f) initial
  where
    initial = TLSStateServer
      { scParams   = params
      , scTLSState = (newTLSState rng) { stClientContext = False }
      }
-- | Read one TLS record from the handle and decode it.
--
-- It reads the 5-byte record header and then exactly the payload length
-- that the header announces. The payload is handed to 'readPacket' as
-- 'EncryptedData'. A header decoding error is returned as 'Left' without
-- reading any payload.
recvPacket :: Handle -> TLSServer IO (Either TLSError [Packet])
recvPacket handle = do
  decoded <- lift (decodeHeader <$> B.hGet handle 5)
  case decoded of
    Left e -> return (Left e)
    Right hdr@(Header _ _ len) ->
      lift (B.hGet handle (fromIntegral len)) >>= readPacket hdr . EncryptedData
-- | Serialise a packet with 'writePacket' and write the bytes to the
-- handle. The handle is not flushed.
sendPacket :: Handle -> Packet -> TLSServer IO ()
sendPacket handle pkt = writePacket pkt >>= lift . B.hPut handle
-- | Validate a ClientHello against the server parameters. On success, the
-- negotiated version and cipher are recorded in the TLS state.
--
-- It fails (via 'fail') in these cases:
--
-- * the client's version is not in 'spAllowedVersions';
-- * no cipher ID is shared between the client's list and 'spCiphers';
-- * the client's compression list does not contain method 0;
-- * the message is any handshake other than a ClientHello.
--
-- The chosen cipher is the first ID in the /client's/ list that the server
-- also supports, so the client's preference order wins.
handleClientHello :: Handshake -> TLSServer IO ()
handleClientHello (ClientHello ver _ _ ciphers compressionID _) = do
  cfg <- gets scParams
  let serverCiphers = spCiphers cfg
      serverIDs     = map cipherID serverCiphers
  when (ver `notElem` spAllowedVersions cfg) $
    fail "unsupported version"
  -- Pattern-match instead of a separate emptiness test followed by 'head'.
  chosenID <- case intersect ciphers serverIDs of
    (c:_) -> return c
    []    -> fail ("unsupported cipher: " ++ show ciphers ++ " : server : " ++ show serverIDs)
  when (0 `notElem` compressionID) $
    fail "unsupported compression"
  modifyTLSState (\st -> st
    { stVersion = ver
    , stCipher  = find ((== chosenID) . cipherID) serverCiphers
    })
handleClientHello _ =
  fail "unexpected handshake type received. expecting client hello"
-- | Send the server's first handshake flight, in this order:
--
-- * ServerHello, with a fresh 32-byte server random;
-- * Certificate, carrying the X509 from 'spCertificate';
-- * ServerKeyExchange, only when the negotiated cipher's key exchange needs
--   extra data;
-- * CertificateRequest, only when 'spWantClientCert' is set;
-- * ServerHelloDone.
--
-- It also installs the RSA private key from 'spCertificate' into the TLS
-- state via 'setPrivateKey'.
--
-- 'handleClientHello' must already have stored the negotiated cipher.
-- When no cipher is stored, or no certificate is configured, this fails
-- with a descriptive message /before/ anything is written to the handle.
-- The handle is not flushed; the caller does that.
handshakeSendServerData :: Handle -> TLSServer IO ()
handshakeSendServerData handle = do
  srand <- serverRandom <$> withTLSRNG (\rng -> getRandomBytes rng 32)
             >>= orFail "handshakeSendServerData: invalid server random"
  sp <- gets scParams
  st <- getTLSState
  cipher <- orFail "handshakeSendServerData: no cipher negotiated" (stCipher st)
  (_, cert, privkeycert) <-
    orFail "handshakeSendServerData: no server certificate configured" (spCertificate sp)
  let srvhello    = ServerHello (stVersion st) srand (Session Nothing) (cipherID cipher) 0 Nothing
      srvcert     = Certificates [cert]
      needkeyxchg = cipherExchangeNeedMoreData $ cipherKeyExchange cipher
      privkey     = PrivRSA $ RSA.PrivateKey
        { RSA.private_sz   = fromIntegral $ KeyRSA.lenmodulus privkeycert
        , RSA.private_n    = KeyRSA.modulus privkeycert
        , RSA.private_d    = KeyRSA.private_exponant privkeycert
        , RSA.private_p    = KeyRSA.p1 privkeycert
        , RSA.private_q    = KeyRSA.p2 privkeycert
        , RSA.private_dP   = KeyRSA.exp1 privkeycert
        , RSA.private_dQ   = KeyRSA.exp2 privkeycert
        , RSA.private_qinv = KeyRSA.coef privkeycert
        }
  setPrivateKey privkey
  sendPacket handle (Handshake srvhello)
  sendPacket handle (Handshake srvcert)
  when needkeyxchg $
    sendPacket handle (Handshake $ ServerKeyXchg (SKX_RSA Nothing))
  -- NOTE(review): [0,0,0] is passed through unchanged from the original;
  -- presumably an empty certificate-authorities list — confirm.
  when (spWantClientCert sp) $
    sendPacket handle (Handshake $ CertRequest [CertificateType_RSA_Sign] Nothing [0,0,0])
  sendPacket handle (Handshake ServerHelloDone)
  where
    -- Turn a missing value into a monadic failure with a clear message,
    -- instead of a lazy 'fromJust' crash after output has already started.
    orFail msg = maybe (fail msg) return
-- | Send the server's Finished message. Its content is the handshake
-- digest from 'getHandshakeDigest' called with 'False', converted to a
-- byte list.
handshakeSendFinish :: Handle -> TLSServer IO ()
handshakeSendFinish handle =
  getHandshakeDigest False >>= sendPacket handle . Handshake . Finished . B.unpack
-- | Drive the server side of a full handshake after the ClientHello has
-- been processed. The steps are:
--
-- 1. Send the server flight and flush the handle.
-- 2. Read records until the status reaches the client-Finished state.
-- 3. Send ChangeCipherSpec and the server Finished message, then flush.
handshake :: Handle -> TLSServer IO ()
handshake handle = do
  handshakeSendServerData handle
  flush
  whileStatus (/= StatusHandshake HsStatusClientFinished) (recvPacket handle)
  sendPacket handle ChangeCipherSpec
  handshakeSendFinish handle
  flush
  where
    flush = lift (hFlush handle)
-- | Wait for the client's opening handshake and complete it.
--
-- The first record must contain exactly one handshake message. That
-- message is validated by 'handleClientHello' (which rejects anything that
-- is not a ClientHello). Anything else fails with a message showing what
-- was received. After that, 'handshake' runs the rest of the exchange.
listen :: Handle -> TLSServer IO ()
listen handle = do
  pkts <- recvPacket handle
  case pkts of
    Right [Handshake hs] -> handleClientHello hs
    -- The previous literal "expecting handshake ++ " leaked a stray
    -- concatenation operator into the message.
    other -> fail ("unexpected type received. expecting handshake: " ++ show other)
  handshake handle
-- | Send a strict ByteString as one or more AppData records.
--
-- Each record carries at most 16384 bytes, the TLS plaintext fragment
-- limit. An empty input still produces a single empty AppData record.
sendDataChunk :: Handle -> B.ByteString -> TLSServer IO ()
sendDataChunk handle d
  | B.length d > maxFragment = do
      let (now, later) = B.splitAt maxFragment d
      sendPacket handle (AppData now)
      sendDataChunk handle later
  | otherwise = sendPacket handle (AppData d)
  where
    maxFragment = 16384
-- | Send a lazy ByteString as application data.
--
-- Each underlying strict chunk goes through 'sendDataChunk' separately.
-- An empty lazy string therefore sends nothing. The handle is not flushed.
sendData :: Handle -> L.ByteString -> TLSServer IO ()
sendData handle = mapM_ (sendDataChunk handle) . L.toChunks
-- | Receive the next record of application data.
--
-- It handles each kind of record as follows:
--
-- * A lone ClientHello (renegotiation) reruns 'handshake' and then keeps
--   reading.
-- * A lone AppData record is returned as a one-chunk lazy ByteString.
-- * A receive error, or any other record shape, aborts with 'error'.
recvData :: Handle -> TLSServer IO L.ByteString
recvData handle = recvPacket handle >>= dispatch
  where
    dispatch (Right [Handshake (ClientHello {})]) = handshake handle >> recvData handle
    dispatch (Right [AppData x])                  = return (L.fromChunks [x])
    dispatch (Left err)                           = error ("error received: " ++ show err)
    dispatch _                                    = error "unexpected item"
-- | Send a warning-level close_notify alert to the peer.
--
-- The handle is flushed afterwards so the alert is actually written out,
-- rather than left in the Handle's buffer while the caller waits for the
-- peer's reply. The handle itself stays open; closing it is the caller's
-- job.
close :: Handle -> TLSServer IO ()
close handle = do
  sendPacket handle $ Alert (AlertLevel_Warning, CloseNotify)
  lift $ hFlush handle