{-# LANGUAGE CPP, ScopedTypeVariables #-}
module Happstack.Server.Internal.TLS where
import Control.Concurrent (forkIO, killThread, myThreadId)
import Control.Exception.Extensible as E
import Control.Monad (forever, when)
import Data.Time (UTCTime)
import GHC.IO.Exception (IOErrorType(..))
import Happstack.Server.Internal.Listen (listenOn)
import Happstack.Server.Internal.Handler (request)
import Happstack.Server.Internal.Socket (acceptLite)
import Happstack.Server.Internal.TimeoutManager (cancel, initialize, register)
import Happstack.Server.Internal.TimeoutSocketTLS as TSS
import Happstack.Server.Internal.Types (Request, Response)
import Network.Socket (HostName, PortNumber, Socket, close, socketPort)
import Prelude hiding (catch)
import OpenSSL (withOpenSSL)
import OpenSSL.Session (SSL, SSLContext)
import qualified OpenSSL.Session as SSL
import Happstack.Server.Types (LogAccess, logMAccess)
import System.IO.Error (ioeGetErrorType, isFullError, isDoesNotExistError)
import System.Log.Logger (Priority(..), logM)
#ifndef mingw32_HOST_OS
import System.Posix.Signals (Handler(Ignore), installHandler, openEndedPipe)
#endif
-- | Log a message at the given priority to the
-- @"Happstack.Server.Internal.TLS"@ logger via 'logM'.
log' :: Priority -> String -> IO ()
log' priority msg = logM tlsLoggerName priority msg
  where
    tlsLoggerName = "Happstack.Server.Internal.TLS"
-- | Settings consumed by 'listenTLS' to serve @https://@.
data TLSConf = TLSConf
    { tlsPort      :: Int
      -- ^ port number passed to 'listenOn'
    , tlsCert      :: FilePath
      -- ^ certificate file, loaded with 'SSL.contextSetCertificateFile'
    , tlsKey       :: FilePath
      -- ^ private key file, loaded with 'SSL.contextSetPrivateKeyFile'
    , tlsCA        :: Maybe FilePath
      -- ^ optional CA file, loaded with 'SSL.contextSetCAFile' when present
    , tlsTimeout   :: Int
      -- ^ connection timeout; 'listenTLS'' multiplies it by @10^6@ before
      -- creating the timeout manager, so presumably seconds (TODO confirm)
    , tlsLogAccess :: Maybe (LogAccess UTCTime)
      -- ^ access logger handed through to 'request' for each connection
    , tlsValidator :: Maybe (Response -> IO Response)
      -- ^ optional response validator; not referenced by the listener code
      -- here, presumably applied by callers (verify)
    }
-- | Baseline 'TLSConf': port 443, a 'tlsTimeout' of 30, access logging
-- through 'logMAccess', and no CA file or validator.
--
-- 'tlsCert' and 'tlsKey' are empty strings; presumably callers override
-- them with real paths before calling 'listenTLS'.
nullTLSConf :: TLSConf
nullTLSConf = TLSConf
    { tlsPort      = 443
    , tlsTimeout   = 30
    , tlsLogAccess = Just logMAccess
    , tlsCert      = ""
    , tlsKey       = ""
    , tlsCA        = Nothing
    , tlsValidator = Nothing
    }
-- | A listening socket paired with the OpenSSL context used to accept TLS
-- connections on it. Values are built by 'httpsOnSocket'.
data HTTPS = HTTPS
    { httpsSocket :: Socket      -- ^ socket handed to 'acceptLite' by 'listenTLS''
    , sslContext  :: SSLContext  -- ^ context handed to 'acceptTLS' per connection
    }
-- | Create an OpenSSL context from a private key, a certificate and an
-- optional CA file, set the default cipher list, and pair the context with
-- the given socket.
--
-- Fails via 'error' when 'SSL.contextCheckPrivateKey' reports that the key
-- does not match the certificate.
httpsOnSocket :: FilePath        -- ^ certificate file
              -> FilePath        -- ^ private key file
              -> Maybe FilePath  -- ^ optional CA file
              -> Socket          -- ^ listening socket to pair with the context
              -> IO HTTPS
httpsOnSocket cert key mca socket = do
    ctx <- SSL.context
    SSL.contextSetPrivateKeyFile ctx key
    SSL.contextSetCertificateFile ctx cert
    -- only touch the CA file when one was supplied
    mapM_ (SSL.contextSetCAFile ctx) mca
    SSL.contextSetDefaultCiphers ctx
    keyMatchesCert <- SSL.contextCheckPrivateKey ctx
    when (not keyMatchesCert) $
        error "OpenTLS certificate and key do not match."
    return (HTTPS socket ctx)
-- | Perform the server side of a TLS handshake on an already-accepted socket.
--
-- If creating the connection or the handshake throws anything, the socket
-- is closed and the exception is rethrown.
acceptTLS :: Socket      -- ^ socket returned by 'acceptLite'
          -> SSLContext  -- ^ context prepared by 'httpsOnSocket'
          -> IO SSL
acceptTLS sck ctx = handle closeAndRethrow $ do
    conn <- SSL.connection ctx sck
    SSL.accept conn
    pure conn
  where
    closeAndRethrow :: SomeException -> IO SSL
    closeAndRethrow err = close sck >> throwIO err
-- | Open the configured port and serve @https://@ requests with @hand@.
--
-- OpenSSL is initialised through 'withOpenSSL' (with a no-op action), a
-- listening socket is opened with 'listenOn', and the TLS context is built
-- by 'httpsOnSocket'. If building the context fails (for example a missing
-- file or a key/certificate mismatch), the listening socket is closed before
-- the exception is rethrown. Once 'listenTLS'' is running, it owns the
-- socket and closes it itself.
listenTLS :: TLSConf                  -- ^ tls configuration
          -> (Request -> IO Response) -- ^ request handler
          -> IO ()
listenTLS tlsConf hand = do
    withOpenSSL $ return ()
    tlsSocket <- listenOn (tlsPort tlsConf)
    -- previously a failure here leaked the freshly opened listening socket
    https <- handle (\(e :: SomeException) -> close tlsSocket >> throwIO e) $
               httpsOnSocket (tlsCert tlsConf) (tlsKey tlsConf) (tlsCA tlsConf) tlsSocket
    listenTLS' (tlsTimeout tlsConf) (tlsLogAccess tlsConf) https hand
-- | Serve HTTPS requests on the listener in @https@ until the accept loop
-- terminates.
--
-- Every connection accepted by 'acceptLite' is handled in its own thread.
-- That thread performs the TLS handshake with 'acceptTLS' and then passes
-- the connection to 'request' together with @mlog@ and @handler@.
--
-- A single timeout manager, created with @timeout * 10^6@, guards both the
-- handshake and the request phase. The action registered for each phase
-- closes the connection and kills the connection thread. The listening
-- socket is closed (via 'finally') when this function returns or throws.
listenTLS' :: Int -> Maybe (LogAccess UTCTime) -> HTTPS -> (Request -> IO Response) -> IO ()
listenTLS' timeout mlog https@(HTTPS lsocket _) handler = do
#ifndef mingw32_HOST_OS
  -- ignore SIGPIPE so that writing to a connection the peer has closed does
  -- not terminate the process
  _ <- installHandler openEndedPipe Ignore Nothing
#endif
  tm <- initialize (timeout * (10 ^ (6 :: Int)))
  -- work: serve one connection whose TLS handshake has already completed
  let work :: (Socket, SSL, HostName, PortNumber) -> IO ()
      work (socket, ssl, hn, p) = do
        -- add this thread to the timeout table
        tid     <- myThreadId
        thandle <- register tm $ do
          shutdownClose socket ssl
          killThread tid
        -- handle the request; failures are ignored or logged by the handlers
        let timeoutIO = TSS.timeoutSocketIO thandle socket ssl
        request timeoutIO mlog (hn, fromIntegral p) handler
          `E.catches` [ Handler ignoreConnectionAbruptlyTerminated
                      , Handler ehs
                      ]
        -- remove this thread from the timeout table, then close the connection
        cancel thandle
        shutdownClose socket ssl

      -- Run the TLS handshake under the timeout manager as well. Otherwise a
      -- client that connects but never completes the handshake would hold
      -- this thread and its socket forever. The entry is cancelled however
      -- the handshake ends; 'acceptTLS' itself closes the socket on failure.
      tlsHandshake :: Socket -> IO SSL
      tlsHandshake sck = do
        tid     <- myThreadId
        hhandle <- register tm $ do
          close sck `E.catch` ignoreException
          killThread tid
        acceptTLS sck (sslContext https) `finally` cancel hhandle

      loop :: IO ()
      loop = forever $ do
        -- do a normal accept, then the TLS accept/handshake in a new thread
        (sck, peer, port) <- acceptLite (httpsSocket https)
        _ <- forkIO $ do
          ssl <- tlsHandshake sck
          work (sck, ssl, peer, port)
            `catch` (\(e :: SomeException) -> do
                       shutdownClose sck ssl
                       throwIO e)
        return ()

      pe :: SomeException -> IO ()
      pe e = log' ERROR ("ERROR in https accept thread: " ++ show e)

      -- restart the accept loop after every exception 'catchSome' handles
      infi :: IO ()
      infi = loop `catchSome` pe >> infi

  sockPort <- socketPort lsocket
  log' NOTICE ("Listening for https:// on port " ++ show sockPort)
  (infi `catch` (\e -> do
                   log' ERROR ("https:// terminated by " ++ show (e :: SomeException))
                   throwIO e))
    `finally` close lsocket
  where
    -- Best-effort TLS shutdown and socket close; failures are deliberately
    -- ignored so cleanup never throws.
    shutdownClose :: Socket -> SSL -> IO ()
    shutdownClose socket ssl = do
      SSL.shutdown ssl SSL.Unidirectional `E.catch` ignoreException
      close socket                        `E.catch` ignoreException

    -- exception handlers
    ignoreConnectionAbruptlyTerminated :: SSL.ConnectionAbruptlyTerminated -> IO ()
    ignoreConnectionAbruptlyTerminated _ = return ()

    ignoreSSLException :: SSL.SomeSSLException -> IO ()
    ignoreSSLException _ = return ()

    ignoreException :: SomeException -> IO ()
    ignoreException _ = return ()

    -- Log a failed request, except ThreadKilled, which the timeout actions
    -- registered above deliver via killThread.
    ehs :: SomeException -> IO ()
    ehs x = when (fromException x /= Just ThreadKilled) $
      log' ERROR ("HTTPS request failed with: " ++ show x)

    -- Run the accept loop and survive selected exceptions:
    --   * SSL exceptions are ignored;
    --   * arithmetic and array exceptions are passed to h;
    --   * IO exceptions are dropped silently (full, does-not-exist,
    --     resource-vanished) or logged.
    -- Any other exception propagates.
    catchSome :: IO () -> (SomeException -> IO ()) -> IO ()
    catchSome op h =
      op `E.catches`
        [ Handler ignoreSSLException
        , Handler $ \(e :: ArithException) -> h (toException e)
        , Handler $ \(e :: ArrayException) -> h (toException e)
        , Handler $ \(e :: IOException) ->
            if isFullError e || isDoesNotExistError e || isResourceVanishedError e
              then return ()
              else log' ERROR ("HTTPS accept loop ignoring " ++ show e)
        ]

    isResourceVanishedError :: IOException -> Bool
    isResourceVanishedError = isResourceVanishedType . ioeGetErrorType

    isResourceVanishedType :: IOErrorType -> Bool
    isResourceVanishedType ResourceVanished = True
    isResourceVanishedType _                = False