{-# LANGUAGE CPP, ScopedTypeVariables #-}
{- | core functions and types for HTTPS support
-}
module Happstack.Server.Internal.TLS where

import Control.Concurrent                         (forkIO, killThread, myThreadId)
import Control.Exception.Extensible               as E
import Control.Monad                              (forever, when)
import Data.Time                                  (UTCTime)
import GHC.IO.Exception                           (IOErrorType(..))
import Happstack.Server.Internal.Listen           (listenOn)
import Happstack.Server.Internal.Handler          (request)
import Happstack.Server.Internal.Socket           (acceptLite)
import Happstack.Server.Internal.TimeoutManager   (cancel, initialize, register)
import Happstack.Server.Internal.TimeoutSocketTLS as TSS
import Happstack.Server.Internal.Types            (Request, Response)
import Network.Socket                             (HostName, PortNumber, Socket, close, socketPort)
import Prelude                                    hiding (catch)
import           OpenSSL                          (withOpenSSL)
import           OpenSSL.Session                  (SSL, SSLContext)
import qualified OpenSSL.Session                  as SSL
import Happstack.Server.Types                     (LogAccess, logMAccess)
import System.IO.Error                            (ioeGetErrorType, isFullError, isDoesNotExistError)
import System.Log.Logger                          (Priority(..), logM)
#ifndef mingw32_HOST_OS
import System.Posix.Signals                       (Handler(Ignore), installHandler, openEndedPipe)
#endif

-- | wrapper around 'logM' for this module
--
-- The logger name is fixed to @Happstack.Server.Internal.TLS@, so callers
-- only supply the priority and the message.
log' :: Priority -> String -> IO ()
log' = logM "Happstack.Server.Internal.TLS"


-- | configuration for using https:\/\/
data TLSConf = TLSConf {
      tlsPort      :: Int        -- ^ port (usually 443)
    , tlsCert      :: FilePath   -- ^ path to SSL certificate
    , tlsKey       :: FilePath   -- ^ path to SSL private key
    , tlsCA        :: Maybe FilePath -- ^ path to a PEM encoded list of CA certificates
    , tlsTimeout   :: Int        -- ^ kill the connection after this timeout (in seconds)
    , tlsLogAccess :: Maybe (LogAccess UTCTime) -- ^ see 'logMAccess'
    , tlsValidator :: Maybe (Response -> IO Response) -- ^ a function to validate the output on-the-fly
    }

-- | a partially complete 'TLSConf'. You must set 'tlsCert' and 'tlsKey' at a minimum.
--
-- Defaults: port 443, no CA file, a timeout of 30, access logging via
-- 'logMAccess', and no validator. The certificate and key paths are
-- empty strings and have to be filled in by the caller.
nullTLSConf :: TLSConf
nullTLSConf =
    TLSConf { tlsPort      = 443
            , tlsCert      = ""
            , tlsKey       = ""
            , tlsCA        = Nothing
            , tlsTimeout   = 30
            , tlsLogAccess = Just logMAccess
            , tlsValidator = Nothing
            }


-- | record that holds the 'Socket' and 'SSLContext' needed to start
-- the https:\/\/ event loop. Used with 'simpleHTTPWithSocket''
--
-- see also: 'httpsOnSocket'
data HTTPS = HTTPS
    { httpsSocket :: Socket      -- ^ listening socket that connections are accepted on
    , sslContext  :: SSLContext  -- ^ context used for the TLS handshake of each accepted connection
    }

-- | generate the 'HTTPS' record needed to start the https:\/\/ event loop
--
-- A fresh 'SSLContext' is created and configured with the private key,
-- the certificate, the optional CA file and the default ciphers. The
-- context is then asked whether the private key matches the certificate;
-- if it does not, this function calls 'error'.
httpsOnSocket :: FilePath  -- ^ path to ssl certificate
              -> FilePath  -- ^ path to ssl private key
              -> Maybe FilePath -- ^ path to PEM encoded list of CA certificates
              -> Socket    -- ^ listening socket (on which listen() has been called, but not accept())
              -> IO HTTPS
httpsOnSocket cert key mca socket =
    do ctx <- SSL.context
       SSL.contextSetPrivateKeyFile  ctx key
       SSL.contextSetCertificateFile ctx cert
       -- the CA file is optional; nothing to configure when it is absent
       case mca of
         Nothing   -> return ()
         (Just ca) -> SSL.contextSetCAFile ctx ca
       SSL.contextSetDefaultCiphers  ctx

       -- refuse to continue with a key that does not belong to the certificate
       certOk <- SSL.contextCheckPrivateKey ctx
       when (not certOk) $ error "OpenTLS certificate and key do not match."

       return (HTTPS socket ctx)

-- | accept a TLS connection
--
-- Wraps the already-accepted socket in an 'SSL' connection and performs
-- the server side of the handshake. If any exception is raised while
-- doing so, the socket is closed and the exception is re-thrown.
acceptTLS :: Socket      -- ^ the socket returned from 'acceptLite'
          -> SSLContext
          -> IO SSL
acceptTLS sck ctx =
      handle (\ (e :: SomeException) -> close sck >> throwIO e) $ do
          ssl <- SSL.connection ctx sck
          SSL.accept ssl
          return ssl

-- | https:// 'Request'/'Response' loop
--
-- This function initializes SSL, opens a listening socket on 'tlsPort',
-- builds the 'HTTPS' record from 'tlsCert', 'tlsKey' and 'tlsCA', and then
-- hands control to 'listenTLS'', which accepts and handles 'Request's and
-- sends 'Response's.
--
-- Only 'tlsTimeout' and 'tlsLogAccess' are forwarded to 'listenTLS'';
-- 'tlsValidator' is not consulted here.
--
-- Each accepted connection is handled in a separate thread (see 'listenTLS'').
listenTLS :: TLSConf                  -- ^ tls configuration
          -> (Request -> IO Response) -- ^ request handler
          -> IO ()
listenTLS tlsConf hand =
    do -- running an empty action under 'withOpenSSL' initializes the library
       withOpenSSL $ return ()
       tlsSocket <- listenOn (tlsPort tlsConf)
       https     <- httpsOnSocket (tlsCert tlsConf) (tlsKey tlsConf) (tlsCA tlsConf) tlsSocket
       listenTLS' (tlsTimeout tlsConf) (tlsLogAccess tlsConf) https hand

-- | low-level https:// 'Request'/'Response' loop
--
-- This is the low-level loop that reads 'Request's and sends
-- 'Response's. It assumes that SSL has already been initialized and
-- that socket is listening.
--
-- The first argument is the connection timeout; it is scaled by 10^6
-- (seconds to microseconds) before being handed to the timeout manager.
--
-- Each accepted connection gets its own thread, which performs the TLS
-- handshake and then runs the request handler. The listening socket is
-- closed when the accept loop terminates.
--
-- see also: 'listenTLS'
listenTLS' :: Int -> Maybe (LogAccess UTCTime) -> HTTPS -> (Request -> IO Response) -> IO ()
listenTLS' timeout mlog https@(HTTPS lsocket _) handler = do
#ifndef mingw32_HOST_OS
  -- ignore the broken-pipe signal so that writing to a closed connection
  -- does not terminate the process
  _ <- installHandler openEndedPipe Ignore Nothing
#endif
  tm <- initialize (timeout * (10^(6 :: Int)))
  let work :: (Socket, SSL, HostName, PortNumber) -> IO ()
      work (socket, ssl, hn, p) = do
          -- add this thread to the timeout table; the registered action
          -- closes the connection and kills this thread
          tid     <- myThreadId
          thandle <- register tm $ do
              shutdownClose socket ssl
              killThread tid
          -- handle the request
          let timeoutIO = TSS.timeoutSocketIO thandle socket ssl

          request timeoutIO mlog (hn, fromIntegral p) handler
              `E.catches` [ Handler ignoreConnectionAbruptlyTerminated
                          , Handler ehs
                          ]

          -- remove thread from timeout table
          cancel thandle

          -- close connection
          shutdownClose socket ssl

      loop :: IO ()
      loop = forever $ do
          -- do a normal accept
          (sck, peer, port) <- acceptLite (httpsSocket https)
          _ <- forkIO $ do
              -- do the TLS accept/handshake
              ssl <- acceptTLS sck (sslContext https)
              -- make sure the connection is closed if the worker fails,
              -- then let the exception end this thread
              work (sck, ssl, peer, port)
                `catch` (\(e :: SomeException) -> do
                             shutdownClose sck ssl
                             throwIO e)
          return ()

      pe :: SomeException -> IO ()
      pe e = log' ERROR ("ERROR in https accept thread: " ++ show e)

      -- restart the accept loop whenever 'catchSome' handled an exception
      infi :: IO ()
      infi = loop `catchSome` pe >> infi
  -- sockName <- getSocketName lsocket
  sockPort <- socketPort lsocket
  log' NOTICE ("Listening for https:// on port " ++ show sockPort)
  (infi `catch` (\e -> do
            log' ERROR ("https:// terminated by " ++ show (e :: SomeException))
            throwIO e))
    `finally` close lsocket

  where
    -- best-effort teardown: failures of either step are ignored
    shutdownClose :: Socket -> SSL -> IO ()
    shutdownClose socket ssl =
        do SSL.shutdown ssl SSL.Unidirectional `E.catch` ignoreException
           close socket                       `E.catch` ignoreException

    -- exception handlers
    ignoreConnectionAbruptlyTerminated :: SSL.ConnectionAbruptlyTerminated -> IO ()
    ignoreConnectionAbruptlyTerminated _ = return ()

    ignoreSSLException :: SSL.SomeSSLException -> IO ()
    ignoreSSLException _ = return ()

    ignoreException :: SomeException -> IO ()
    ignoreException _ = return ()

    -- log every request failure except 'ThreadKilled'
    ehs :: SomeException -> IO ()
    ehs x = when ((fromException x) /= Just ThreadKilled) $ log' ERROR ("HTTPS request failed with: " ++ show x)

    -- run @op@, handling only the exception classes the accept loop is
    -- expected to survive; anything else propagates to the caller
    catchSome :: IO () -> (SomeException -> IO ()) -> IO ()
    catchSome op h =
        op `E.catches` [ Handler $ ignoreSSLException
                       , Handler $ \(e :: ArithException) -> h (toException e)
                       , Handler $ \(e :: ArrayException) -> h (toException e)
                       , Handler $ \(e :: IOException)    ->
                           if isFullError e || isDoesNotExistError e || isResourceVanishedError e
                           then return () -- h (toException e) -- we could log the exception, but there could be thousands of them
                           else log' ERROR ("HTTPS accept loop ignoring " ++ show e)
                       ]

    isResourceVanishedError :: IOException -> Bool
    isResourceVanishedError = isResourceVanishedType . ioeGetErrorType

    isResourceVanishedType :: IOErrorType -> Bool
    isResourceVanishedType ResourceVanished = True
    isResourceVanishedType _                = False