module Snap.Internal.Http.Server.GnuTLS
( GnuTLSException(..)
, initTLS
, stopTLS
, bindHttps
, freePort
, createSession
, endSession
, recv
, send
) where
import Control.Exception
import Data.ByteString (ByteString)
import Data.Dynamic
import Foreign.C
import Snap.Internal.Debug
import Snap.Internal.Http.Server.Backend
#ifdef GNUTLS
import qualified Data.ByteString as B
import Data.ByteString.Internal (w2c)
import qualified Data.ByteString.Internal as BI
import qualified Data.ByteString.Unsafe as BI
import Foreign
import qualified Network.Socket as Socket
#endif
-- | Exception raised by this module when a TLS operation fails or when TLS
-- support is not compiled in.  The payload is a human-readable message.
data GnuTLSException = GnuTLSException String
deriving (Show, Typeable)
-- Uses the default Exception methods, so the value can be thrown with throwIO.
instance Exception GnuTLSException
#ifndef GNUTLS
-- | Stand-in used when the GNUTLS flag is not defined: initialisation
-- always fails with a @GnuTLSException@.
initTLS :: IO ()
initTLS = throwIO (GnuTLSException "TLS is not supported")
-- | Stand-in used when the GNUTLS flag is not defined: there is nothing to
-- tear down, so this is a no-op.
stopTLS :: IO ()
stopTLS = return ()
-- | Stand-in used when the GNUTLS flag is not defined: an HTTPS listener
-- cannot be created, so every call fails with a @GnuTLSException@.
-- All four arguments are ignored.
bindHttps :: ByteString -> Int -> FilePath -> FilePath -> IO ListenSocket
bindHttps _address _port _cert _key =
    throwIO (GnuTLSException "TLS is not supported")
-- | Stand-in used when the GNUTLS flag is not defined: no TLS resources
-- exist, so releasing a listen socket does nothing.
freePort :: ListenSocket -> IO ()
freePort _listenSocket = return ()
-- | Stand-in used when the GNUTLS flag is not defined: a TLS session can
-- never be created, so every call fails with a @GnuTLSException@.
-- All arguments are ignored.
createSession :: ListenSocket -> Int -> CInt -> IO () -> IO NetworkSession
createSession _listenSocket _recvSize _fd _onBlock =
    throwIO (GnuTLSException "TLS is not supported")
-- | Stand-in used when the GNUTLS flag is not defined: there is no TLS
-- state attached to a session, so ending it does nothing.
endSession :: NetworkSession -> IO ()
endSession _session = return ()
-- | Stand-in used when the GNUTLS flag is not defined.  Unlike the other
-- stubs that would produce data, this one succeeds silently and discards
-- the payload; all arguments are ignored.
send :: IO () -> IO () -> NetworkSession -> ByteString -> IO ()
send _tickleTimeout _onBlock _session _payload = return ()
-- | Stand-in used when the GNUTLS flag is not defined: nothing can be read
-- from a TLS session, so every call fails with a @GnuTLSException@.
recv :: IO b -> NetworkSession -> IO (Maybe ByteString)
recv _onBlock _session = throwIO (GnuTLSException "TLS is not supported")
#else
-- | Initialise the GnuTLS library.
--
-- First runs the @gnutls_set_threading_helper@ C routine, then calls
-- @gnutls_global_init@; a negative return code from the latter is turned
-- into a @GnuTLSException@ by @throwErrorIf@.
initTLS :: IO ()
initTLS = do
    gnutls_set_threading_helper
    throwErrorIf "TLS init" gnutls_global_init
-- | Shut the GnuTLS library down by calling @gnutls_global_deinit@.
stopTLS :: IO ()
stopTLS = gnutls_global_deinit
-- | Open a listening IPv4 TCP socket for HTTPS and prepare its TLS material.
--
-- The socket is bound to the given address and port with address reuse
-- enabled and a listen backlog of 150.  The certificate and key files are
-- then loaded and fresh Diffie-Hellman parameters are generated; both are
-- stored (as untyped pointers) in the returned @ListenHttps@ value.
--
-- NOTE(review): if loading credentials or generating DH parameters throws,
-- nothing here closes the already-opened socket -- confirm whether callers
-- rely on that.
bindHttps :: ByteString   -- ^ address to bind to (see @getHostAddr@)
          -> Int          -- ^ TCP port
          -> FilePath     -- ^ certificate file
          -> FilePath     -- ^ private key file
          -> IO ListenSocket
bindHttps bindAddress bindPort cert key = do
    listener <- Socket.socket Socket.AF_INET Socket.Stream 0
    sockAddr <- getHostAddr bindPort bindAddress
    Socket.setSocketOption listener Socket.ReuseAddr 1
    Socket.bindSocket listener sockAddr
    Socket.listen listener 150

    credentials <- loadCredentials cert key
    dhParams    <- regenerateDHParam credentials
    return (ListenHttps listener (castPtr credentials) (castPtr dhParams))
-- | Allocate a GnuTLS certificate-credentials structure and load an X.509
-- certificate / private-key pair into it.
--
-- Both files are passed to GnuTLS with the @gnutls_x509_fmt_pem@ format
-- constant.  A negative return code from either GnuTLS call raises a
-- @GnuTLSException@.
loadCredentials :: FilePath                      -- ^ certificate file
                -> FilePath                      -- ^ private key file
                -> IO (Ptr GnuTLSCredentials)
loadCredentials cert key = alloca allocateAndLoad
  where
    -- GnuTLS writes the freshly allocated pointer into the temporary slot.
    allocateAndLoad slot = do
        throwErrorIf "TLS allocate"
            (gnutls_certificate_allocate_credentials slot)
        credentials <- peek slot
        installKeyPair credentials
        return credentials

    -- Marshal both paths to C strings for the duration of the call.
    installKeyPair credentials =
        withCString cert $ \certPath ->
        withCString key  $ \keyPath  ->
            throwErrorIf "TLS set Certificate"
                (gnutls_certificate_set_x509_key_file
                     credentials certPath keyPath gnutls_x509_fmt_pem)
-- | Create a new set of Diffie-Hellman parameters (requested size: 1024)
-- and attach it to the given credentials.
--
-- Returns the parameter pointer so the caller can free it later.  A negative
-- return code from the init or generate call raises a @GnuTLSException@.
regenerateDHParam :: Ptr GnuTLSCredentials -> IO (Ptr GnuTLSDHParam)
regenerateDHParam creds = alloca generate
  where
    generate slot = do
        throwErrorIf "TLS allocate" (gnutls_dh_params_init slot)
        params <- peek slot
        throwErrorIf "TLS DHParm" (gnutls_dh_params_generate2 params 1024)
        gnutls_certificate_set_dh_params creds params
        return params
-- | Release the TLS resources held by an HTTPS listen socket.
--
-- For a @ListenHttps@ value the credentials are freed first and the DH
-- parameters second.  Any other kind of listen socket is left untouched.
-- The underlying network socket itself is not closed here.
freePort :: ListenSocket -> IO ()
freePort listenSocket = case listenSocket of
    ListenHttps _ credPtr dhPtr ->
        gnutls_certificate_free_credentials (castPtr credPtr) >>
        gnutls_dh_params_deinit (castPtr dhPtr)
    _ -> return ()
-- | Create a TLS session on top of an accepted socket and complete the
-- handshake before handing the session back.
--
-- Arguments: the HTTPS listen socket (supplies the credentials), the
-- receive-buffer size to record in the session, the accepted socket's file
-- descriptor, and an action to run whenever the handshake would block.
--
-- If any step after the session has been allocated throws, the session is
-- deinitialised before the exception propagates.  Calling this with a
-- listen socket that is not @ListenHttps@ is a programming error and
-- calls @error@.
createSession :: ListenSocket -> Int -> CInt -> IO () -> IO NetworkSession
createSession (ListenHttps _ creds _) recvSize socket on_block =
    alloca $ \slot -> do
        -- NOTE(review): flag 1 is presumably GNUTLS_SERVER -- confirm
        -- against gnutls.h.
        throwErrorIf "TLS allocate" (gnutls_init slot 1)
        tls <- peek slot
        configure tls `onException` gnutls_deinit tls
  where
    configure tls = do
        -- NOTE(review): credential type 1 is presumably
        -- GNUTLS_CRD_CERTIFICATE -- confirm against gnutls.h.
        throwErrorIf "TLS session"
            (gnutls_credentials_set tls 1 (castPtr creds))
        throwErrorIf "TLS session" (gnutls_set_default_priority tls)
        gnutls_certificate_send_x509_rdn_sequence tls 1
        gnutls_session_enable_compatibility_mode tls
        -- The file descriptor is smuggled through GnuTLS as an opaque
        -- transport pointer.
        gnutls_transport_set_ptr tls (intPtrToPtr (fromIntegral socket))
        let netSession = NetworkSession socket (castPtr tls)
                                        (fromIntegral recvSize)
        handshake netSession on_block
        return netSession
createSession _ _ _ _ = error "Invalid socket"
-- | Close a TLS session: send the TLS closure alert with @gnutls_bye@ and
-- then free the session with @gnutls_deinit@.
--
-- The deinit is attached to the @gnutls_bye@ call itself, so it has
-- already run by the time the return code is inspected; a negative code
-- is then reported as a @GnuTLSException@.
endSession :: NetworkSession -> IO ()
endSession (NetworkSession _ session _) =
    throwErrorIf "TLS bye" goodbye
  where
    tls     = castPtr session
    goodbye = gnutls_bye tls 1 `finally` gnutls_deinit tls
-- | Drive the TLS handshake to completion.
--
-- A non-negative return code means success.  A code recognised by
-- @isIntrCode@ causes an immediate retry; one recognised by @isAgainCode@
-- runs the supplied blocking action first and then retries.  Any other
-- negative code is raised as a @GnuTLSException@.
handshake :: NetworkSession -> IO () -> IO ()
handshake (NetworkSession { _session = session }) on_block = attempt
  where
    tls     = castPtr session
    attempt = gnutls_handshake tls >>= react

    react rc
        | rc >= 0        = return ()
        | isIntrCode rc  = attempt
        | isAgainCode rc = on_block >> attempt
        | otherwise      = throwError "TLS handshake" rc
-- | Write an entire 'ByteString' to a TLS session.
--
-- The first argument is run after every partial write (to keep the
-- connection's timeout alive); the second is run whenever GnuTLS reports
-- that the write would block.  The bytes are handed to GnuTLS directly from
-- the string's own buffer, without copying.
--
-- Outcomes of each @gnutls_record_send@ call:
--
--   * everything written (or zero returned): finished;
--
--   * partial write: advance past the bytes already sent and retry with the
--     /remaining/ length;
--
--   * a code recognised by @isIntrCode@: retry immediately;
--
--   * a code recognised by @isAgainCode@: run the blocking action, then
--     retry;
--
--   * anything else: raise a @GnuTLSException@.
send :: IO () -> IO () -> NetworkSession -> ByteString -> IO ()
send tickleTimeout onBlock (NetworkSession { _session = session}) bs =
    BI.unsafeUseAsCStringLen bs $ uncurry loop
  where
    loop ptr len = do
        sent <- gnutls_record_send (castPtr session) ptr $ fromIntegral len
        let sent' = fromIntegral sent
        case sent' of
          x | x == 0 || x == len -> return ()
            -- Partial write: the remaining length is len - sent'.  The
            -- previous code applied len to sent' instead of subtracting.
            | x > 0 && x < len   -> tickleTimeout >>
                                    loop (plusPtr ptr sent') (len - sent')
            | isIntrCode x       -> loop ptr len
            | isAgainCode x      -> onBlock >> loop ptr len
            | otherwise          -> throwError "TLS send" $
                                    fromIntegral sent'
-- | Read one chunk of decrypted data from a TLS session.
--
-- A fresh buffer of the session's configured receive length is allocated
-- and filled by @gnutls_record_recv@.  Returns @Nothing@ when zero bytes
-- were read and @Just@ the received bytes otherwise.
--
-- A code recognised by @isIntrCode@ triggers an immediate retry; one
-- recognised by @isAgainCode@ runs the supplied blocking action and then
-- retries.  Any other negative result raises a @GnuTLSException@.
recv :: IO b -> NetworkSession -> IO (Maybe ByteString)
recv onBlock (NetworkSession _ session recvLen) = do
    buffer   <- BI.mallocByteString recvLen
    received <- withForeignPtr buffer fill
    return $ if (received :: Int) <= 0
               then Nothing
               else Just (BI.fromForeignPtr buffer 0 (fromEnum received))
  where
    fill dest = do
        debug $ "TLS: calling record_recv with recvLen=" ++ show recvLen
        raw <- gnutls_record_recv (castPtr session) dest (toEnum recvLen)
        debug $ "TLS: record_recv returned with size=" ++ show raw
        classify dest (fromIntegral raw)

    -- Decide what to do with the (signed) byte count or error code.
    classify dest n
        | n >= 0        = return n
        | isIntrCode n  = fill dest
        | isAgainCode n = onBlock >> fill dest
        | otherwise     = throwError "TLS recv" (fromIntegral n)
-- | Raise a @GnuTLSException@ describing a GnuTLS return code.
--
-- The message is the given prefix, the shown return code in angle
-- brackets, a colon, and the text that @gnutls_strerror@ yields for the
-- code.
throwError :: String -> ReturnCode -> IO a
throwError prefix rc = do
    description <- gnutls_strerror rc >>= peekCString
    throwIO (GnuTLSException (label ++ description))
  where
    label = prefix ++ "<" ++ show rc ++ ">: "
-- | Run a GnuTLS action and raise a @GnuTLSException@ (via @throwError@,
-- using the given prefix) if its return code is negative.  Non-negative
-- codes are discarded.
throwErrorIf :: String -> IO ReturnCode -> IO ()
throwErrorIf prefix action = action >>= check
  where
    check rc
        | rc < 0    = throwError prefix rc
        | otherwise = return ()
-- | Does this return code mean "operation would block, try again"?
--
-- GnuTLS error codes are negative; @GNUTLS_E_AGAIN@ is @-28@.  Comparing
-- against positive 28 can never succeed at the call sites, because they
-- all handle non-negative results before asking this question.
isAgainCode :: (Integral a) => a -> Bool
isAgainCode x = fromIntegral x == (-28 :: Int)
-- | Does this return code mean "operation was interrupted, retry"?
--
-- GnuTLS error codes are negative; @GNUTLS_E_INTERRUPTED@ is @-52@.
-- Comparing against positive 52 can never succeed at the call sites,
-- because they all handle non-negative results before asking this question.
isIntrCode :: (Integral a) => a -> Bool
isIntrCode x = fromIntegral x == (-52 :: Int)
-- | Build the IPv4 socket address to bind to.
--
-- The address string @"*"@ selects the wildcard address
-- (@Socket.iNADDR_ANY@); anything else is handed to @Socket.inet_addr@
-- after converting the bytes to characters.
getHostAddr :: Int            -- ^ TCP port
            -> ByteString     -- ^ address string, or @"*"@ for any
            -> IO Socket.SockAddr
getHostAddr p s = fmap (Socket.SockAddrInet (fromIntegral p)) resolveHost
  where
    resolveHost
        | s == "*"  = return Socket.iNADDR_ANY
        | otherwise = Socket.inet_addr (map w2c (B.unpack s))
-- | Status code returned by the GnuTLS functions bound in this module.
-- The code in this module treats negative values as errors (see
-- throwErrorIf).  The numeric instances are derived from the wrapped CInt
-- so codes can be compared with literals and converted with fromIntegral.
newtype ReturnCode = ReturnCode CInt
deriving (Show, Eq, Ord, Num, Real, Enum, Integral)
-- Uninhabited tag types: they exist only to give the raw pointers passed to
-- and from the C library distinct Haskell types.

-- | Tag for a pointer to a GnuTLS certificate-credentials structure.
data GnuTLSCredentials
-- | Tag for a pointer to a GnuTLS session.
data GnuTLSSession
-- | Tag for a pointer to a set of GnuTLS Diffie-Hellman parameters.
data GnuTLSDHParam
-- Raw FFI bindings.  Each Haskell name matches the C symbol it imports.

-- | Threading set-up routine run by initTLS before global initialisation.
-- No header is named in the import string.
foreign import ccall safe
"gnutls_set_threading_helper"
gnutls_set_threading_helper :: IO ()
-- | Library-wide initialisation; called by initTLS.
foreign import ccall safe
"gnutls/gnutls.h gnutls_global_init"
gnutls_global_init :: IO ReturnCode
-- | Library-wide teardown; called by stopTLS.
foreign import ccall safe
"gnutls/gnutls.h gnutls_global_deinit"
gnutls_global_deinit :: IO ()
-- | Map a return code to a C string description; used by throwError.
foreign import ccall safe
"gnutls/gnutls.h gnutls_strerror"
gnutls_strerror :: ReturnCode -> IO CString
-- | Allocate a session, writing the new pointer through the first argument.
-- The CInt carries the flags passed by createSession.
foreign import ccall unsafe
"gnutls/gnutls.h gnutls_init"
gnutls_init :: Ptr (Ptr GnuTLSSession) -> CInt -> IO ReturnCode
-- | Free a session; used by endSession and by createSession on failure.
foreign import ccall unsafe
"gnutls/gnutls.h gnutls_deinit"
gnutls_deinit :: Ptr GnuTLSSession -> IO ()
-- | Perform (one step of) the TLS handshake; retried by handshake.
foreign import ccall safe
"gnutls/gnutls.h gnutls_handshake"
gnutls_handshake :: Ptr GnuTLSSession -> IO ReturnCode
-- | Terminate the TLS connection; the CInt is the shutdown mode passed by
-- endSession.
foreign import ccall safe
"gnutls/gnutls.h gnutls_bye"
gnutls_bye :: Ptr GnuTLSSession -> CInt -> IO ReturnCode
-- | Apply the library's default priorities to a session.
foreign import ccall unsafe
"gnutls/gnutls.h gnutls_set_default_priority"
gnutls_set_default_priority :: Ptr GnuTLSSession -> IO ReturnCode
-- | Switch a session into GnuTLS compatibility mode.
foreign import ccall unsafe
"gnutls/gnutls.h gnutls_session_enable_compatibility_mode"
gnutls_session_enable_compatibility_mode :: Ptr GnuTLSSession -> IO ()
-- | Configure whether the X.509 RDN sequence is sent; createSession passes 1.
foreign import ccall unsafe
"gnutls/gnutls.h gnutls_certificate_send_x509_rdn_sequence"
gnutls_certificate_send_x509_rdn_sequence
:: Ptr GnuTLSSession -> CInt -> IO ()
-- | Allocate a credentials structure, writing the new pointer through the
-- argument; used by loadCredentials.
foreign import ccall safe
"gnutls/gnutls.h gnutls_certificate_allocate_credentials"
gnutls_certificate_allocate_credentials
:: Ptr (Ptr GnuTLSCredentials) -> IO ReturnCode
-- | Free a credentials structure; used by freePort.
foreign import ccall safe
"gnutls/gnutls.h gnutls_certificate_free_credentials"
gnutls_certificate_free_credentials
:: Ptr GnuTLSCredentials -> IO ()
-- | Format constant passed to gnutls_certificate_set_x509_key_file.
-- NOTE(review): presumably mirrors GNUTLS_X509_FMT_PEM -- confirm against
-- gnutls.h.
gnutls_x509_fmt_pem :: CInt
gnutls_x509_fmt_pem = 1
-- | Load a certificate file and key file (in the given format) into a
-- credentials structure.
foreign import ccall safe
"gnutls/gnutls.h gnutls_certificate_set_x509_key_file"
gnutls_certificate_set_x509_key_file
:: Ptr GnuTLSCredentials -> CString -> CString -> CInt -> IO ReturnCode
-- | Attach credentials of the given type to a session.
foreign import ccall unsafe
"gnutls/gnutls.h gnutls_credentials_set"
gnutls_credentials_set
:: Ptr GnuTLSSession -> CInt -> Ptr a -> IO ReturnCode
-- | Set the opaque transport pointer of a session; createSession stores the
-- socket file descriptor here.
foreign import ccall unsafe
"gnutls/gnutls.h gnutls_transport_set_ptr"
gnutls_transport_set_ptr :: Ptr GnuTLSSession -> Ptr a -> IO ()
-- | Receive decrypted data into a buffer of the given size.
-- NOTE(review): the result is declared as CSize, yet callers convert it with
-- fromIntegral and test for negative values -- confirm the C return type
-- (likely ssize_t) and whether a signed Haskell type would be clearer.
foreign import ccall unsafe
"gnutls/gnutls.h gnutls_record_recv"
gnutls_record_recv :: Ptr GnuTLSSession -> Ptr a -> CSize -> IO CSize
-- | Send data from a buffer of the given size.
-- NOTE(review): same CSize-versus-signed-result question as for
-- gnutls_record_recv.
foreign import ccall unsafe
"gnutls/gnutls.h gnutls_record_send"
gnutls_record_send :: Ptr GnuTLSSession -> Ptr a -> CSize -> IO CSize
-- | Allocate a DH-parameters object, writing the new pointer through the
-- argument; used by regenerateDHParam.
foreign import ccall safe
"gnutls/gnutls.h gnutls_dh_params_init"
gnutls_dh_params_init :: Ptr (Ptr GnuTLSDHParam) -> IO ReturnCode
-- | Free a DH-parameters object; used by freePort.
foreign import ccall safe
"gnutls/gnutls.h gnutls_dh_params_deinit"
gnutls_dh_params_deinit :: Ptr GnuTLSDHParam -> IO ()
-- | Generate DH parameters of the requested size into an existing object.
foreign import ccall safe
"gnutls/gnutls.h gnutls_dh_params_generate2"
gnutls_dh_params_generate2 :: Ptr GnuTLSDHParam -> CUInt -> IO ReturnCode
-- | Associate DH parameters with a credentials structure.
foreign import ccall safe
"gnutls/gnutls.h gnutls_certificate_set_dh_params"
gnutls_certificate_set_dh_params
:: Ptr GnuTLSCredentials -> Ptr GnuTLSDHParam -> IO ()
#endif