module Happstack.Server.Internal.Cryptonite.TLS where
import Control.Concurrent                         (forkIO, killThread, myThreadId)
import Control.Exception.Extensible               as E
import Control.Monad                              (forever, when)
import Crypto.Random.EntropyPool
import Data.Default.Class
import Data.Time                                  (UTCTime)
import GHC.IO.Exception                           (IOErrorType(..))
import Happstack.Server.Internal.Listen           (listenOn)
import Happstack.Server.Internal.Handler          (request)
import Happstack.Server.Internal.Socket           (acceptLite)
import Happstack.Server.Internal.TimeoutManager   (cancel, initialize, register)
import Happstack.Server.Internal.Cryptonite.TimeoutSocketTLS as TSS
import Happstack.Server.Internal.Types            (Request, Response)
import Network.Socket                             (HostName, PortNumber, Socket, sClose, socketPort)
import Network.TLS
import Network.TLS.Extra.Cipher
import Happstack.Server.Types                     (LogAccess, logMAccess)
import System.IO.Error                            (ioeGetErrorType, isFullError, isDoesNotExistError)
import System.Log.Logger                          (Priority(..), logM)
#ifndef mingw32_HOST_OS
import System.Posix.Signals                       (Handler(Ignore), installHandler, openEndedPipe)
#endif
-- | Emit a log message at the given 'Priority' through hslogger's 'logM',
-- using @"Happstack.Server.Internal.TLS"@ as the logger name.
log' :: Priority -> String -> IO ()
log' prio msg = logM loggerName prio msg
  where
    loggerName :: String
    loggerName = "Happstack.Server.Internal.TLS"
-- | Configuration for an HTTPS listener started with 'listenTLS'.
data TLSConf = TLSConf {
      tlsPort      :: Int        -- ^ port to listen on; passed to 'listenOn' by 'listenTLS'
    , tlsCert      :: FilePath   -- ^ path to the X.509 certificate file loaded with 'credentialLoadX509'
    , tlsKey       :: FilePath   -- ^ path to the private key file loaded together with 'tlsCert'
    , tlsCA        :: Maybe FilePath -- ^ optional CA file; NOTE(review): currently ignored by 'httpsOnSocket'
    , tlsTimeout   :: Int        -- ^ connection timeout; multiplied by 10^6 before being given to the timeout manager, so presumably seconds
    , tlsLogAccess :: Maybe (LogAccess UTCTime) -- ^ optional access logger, passed through to 'request'
    , tlsValidator :: Maybe (Response -> IO Response) -- ^ optional response validator; not read anywhere in this module (presumably consumed by callers — confirm)
    }
-- | Baseline 'TLSConf': port 443, a timeout of 30, access logging through
-- 'logMAccess', no CA file and no validator. The certificate and key paths
-- are left empty, so callers are expected to set 'tlsCert' and 'tlsKey'
-- before using the configuration.
nullTLSConf :: TLSConf
nullTLSConf =
    TLSConf
      { tlsPort      = 443
      , tlsCert      = noPath
      , tlsKey       = noPath
      , tlsCA        = Nothing
      , tlsTimeout   = 30
      , tlsLogAccess = Just logMAccess
      , tlsValidator = Nothing
      }
  where
    noPath :: FilePath
    noPath = ""
-- | A listening socket paired with the TLS server parameters used to accept
-- connections on it; built by 'httpsOnSocket'.
data HTTPS = HTTPS
    { httpsSocket :: Socket -- ^ listening socket that 'listenTLS'' accepts connections from
    , sslContext  :: ServerParams -- ^ TLS server parameters passed to 'acceptTLS' (despite the name, not a live TLS context)
    }
-- | Pair a listening socket with TLS server parameters built from a
-- certificate file and a private key file.
--
-- The credentials are loaded with 'credentialLoadX509', and the server is
-- restricted to 'ciphersuite_strong'. If the certificate or key cannot be
-- loaded, an 'ErrorCall' is thrown immediately. A misconfigured server
-- therefore fails at start-up instead of failing every TLS handshake later.
--
-- NOTE(review): the CA file argument is accepted but currently ignored.
httpsOnSocket :: FilePath       -- ^ path to the certificate file
              -> FilePath       -- ^ path to the private key file
              -> Maybe FilePath -- ^ CA file (currently unused)
              -> Socket         -- ^ listening socket
              -> IO HTTPS
httpsOnSocket cert key _ socket =
    do creds <- credentialLoadX509 cert key
       -- Force the load result here. A lazily bound 'error' would only fire
       -- when the TLS stack first reads the credentials, i.e. inside a
       -- connection thread during the first handshake.
       credentials <- either loadFailed return creds
       let params = def
             { serverSupported = def { supportedCiphers = ciphersuite_strong }
             , serverShared    = def { sharedCredentials = Credentials [credentials] }
             }
       return (HTTPS socket params)
  where
    loadFailed :: String -> IO a
    loadFailed msg =
        throwIO (ErrorCall ("Can't load certificate " ++ cert ++ " and key " ++ key ++ ": " ++ msg))
-- | Perform the server side of a TLS handshake on a freshly accepted socket
-- and return the resulting TLS 'Context'.
--
-- If creating the context or running the handshake throws anything, the
-- socket is closed before the exception is re-thrown to the caller.
acceptTLS :: Socket       -- ^ accepted client socket
          -> ServerParams -- ^ TLS server parameters
          -> IO Context
acceptTLS sck params = establish `catch` closeAndRethrow
  where
    establish = do
      ctx <- contextNew sck params
      handshake ctx
      return ctx

    closeAndRethrow :: SomeException -> IO Context
    closeAndRethrow err = sClose sck >> throwIO err
-- | Start an HTTPS server described by a 'TLSConf' and serve requests with
-- the given handler.
--
-- Steps:
--
-- * open a listening socket on 'tlsPort';
-- * load 'tlsCert' and 'tlsKey' (and pass 'tlsCA');
-- * run the accept loop 'listenTLS'' with 'tlsTimeout' and 'tlsLogAccess'.
--
-- If the TLS setup throws (for example a missing or invalid certificate or
-- key file), the listening socket is closed before the exception is
-- re-thrown, so it is not leaked. 'tlsValidator' is not read here.
listenTLS :: TLSConf                  -- ^ server configuration
          -> (Request -> IO Response) -- ^ request handler
          -> IO ()
listenTLS tlsConf hand =
    do
       tlsSocket <- listenOn (tlsPort tlsConf)
       https     <- httpsOnSocket (tlsCert tlsConf) (tlsKey tlsConf) (tlsCA tlsConf) tlsSocket
                      `catch` closeAndRethrow tlsSocket
       listenTLS' (tlsTimeout tlsConf) (tlsLogAccess tlsConf) https hand
  where
    closeAndRethrow :: Socket -> SomeException -> IO a
    closeAndRethrow sock e = sClose sock >> throwIO e
-- | Serve HTTPS on the listening socket of an 'HTTPS' value.
--
-- Every accepted connection is handled on its own thread. That thread
-- performs the TLS handshake and then serves requests with 'request' and
-- @handler@.
--
-- A timeout manager, initialised with @timeout * 10^6@, guards both the
-- handshake and the request phase. When a connection's timeout fires, the
-- connection is closed and its thread is killed.
--
-- Exception handling in the accept loop (the loop restarts in each case):
--
-- * TLS exceptions are ignored;
-- * arithmetic and array exceptions are logged;
-- * IO exceptions are ignored when they are full, does-not-exist or
--   resource-vanished errors, and logged otherwise.
--
-- Any other exception is logged and re-thrown. The listening socket is
-- closed when this function exits.
listenTLS' :: Int -> Maybe (LogAccess UTCTime) -> HTTPS -> (Request -> IO Response) -> IO ()
listenTLS' timeout mlog https@(HTTPS lsocket _) handler = do
#ifndef mingw32_HOST_OS
  -- Ignore SIGPIPE so that writing to a connection the peer has already
  -- closed cannot terminate the process.
  _ <- installHandler openEndedPipe Ignore Nothing
#endif
  tm <- initialize (timeout * (10^(6 :: Int)))
  -- Run the TLS handshake under the timeout manager. Without this guard, a
  -- client that connects but never completes the handshake would hold a
  -- thread and a file descriptor indefinitely.
  -- On expiry the thread is killed first, so 'acceptTLS' closes the socket
  -- while unwinding. The socket is then closed again in case the handshake
  -- had already returned.
  -- NOTE(review): relies on sClose being a no-op for an already-closed
  -- socket (network tracks socket status); confirm for the version in use.
  let handshakeWithTimeout :: Socket -> IO Context
      handshakeWithTimeout sck = do
        tid     <- myThreadId
        thandle <- register tm $ do killThread tid
                                    sClose sck
        ssl <- acceptTLS sck (sslContext https)
        cancel thandle
        return ssl

      -- Serve requests on an established TLS connection, then close it.
      work :: (Socket, Context, HostName, PortNumber) -> IO ()
      work (socket, ssl, hn, p) = do
        -- Register with the timeout manager: on expiry the connection is
        -- closed and this thread is killed.
        tid     <- myThreadId
        thandle <- register tm $ do shutdownClose socket ssl
                                    killThread tid
        let timeoutIO = TSS.timeoutSocketIO thandle socket ssl
        request timeoutIO mlog (hn, fromIntegral p) handler
          `E.catches` [ Handler ignoreConnectionAbruptlyTerminated
                      , Handler ehs
                      ]
        -- The request phase is over: stop the timer, then close.
        cancel thandle
        shutdownClose socket ssl

      -- Accept connections forever, serving each one on its own thread.
      loop :: IO ()
      loop = forever $ do
        (sck, peer, port) <- acceptLite (httpsSocket https)
        _ <- forkIO $ do
          ssl <- handshakeWithTimeout sck
          work (sck, ssl, peer, port)
            `catch` (\(e :: SomeException) -> do
                       shutdownClose sck ssl
                       throwIO e)
        return ()

      pe e = log' ERROR ("ERROR in https accept thread: " ++ show e)
      infi = loop `catchSome` pe >> infi

      -- Log whatever escapes the accept loop, then re-throw it.
      onFatal :: SomeException -> IO ()
      onFatal e = do
        log' ERROR ("https:// terminated by " ++ show e)
        throwIO e
  sockPort <- socketPort lsocket
  log' NOTICE ("Listening for https:// on port " ++ show sockPort)
  (infi `catch` onFatal) `finally` sClose lsocket
  where
    -- Best-effort close of a TLS connection: send 'bye', then call
    -- 'contextClose', ignoring any exception from either. The socket
    -- argument is not used.
    shutdownClose :: Socket -> Context -> IO ()
    shutdownClose _ ssl =
        do bye ssl          `E.catch` ignoreException
           contextClose ssl `E.catch` ignoreException

    -- NOTE: despite its name, this ignores every 'TLSException'.
    ignoreConnectionAbruptlyTerminated :: TLSException -> IO ()
    ignoreConnectionAbruptlyTerminated _ = return ()

    ignoreSSLException :: TLSException -> IO ()
    ignoreSSLException _ = return ()

    ignoreException :: SomeException -> IO ()
    ignoreException _ = return ()

    -- Log a failed request, unless the thread was simply killed (as the
    -- timeout manager does).
    ehs :: SomeException -> IO ()
    ehs x = when (fromException x /= Just ThreadKilled) $
              log' ERROR ("HTTPS request failed with: " ++ show x)

    -- Run @op@, recovering from selected exceptions:
    --
    -- * TLS exceptions are ignored;
    -- * arithmetic and array exceptions are passed to @h@;
    -- * IO exceptions are ignored when they are full, does-not-exist or
    --   resource-vanished errors, and logged otherwise.
    --
    -- Anything else propagates.
    catchSome op h =
        op `E.catches` [ Handler $ ignoreSSLException
                       , Handler $ \(e :: ArithException) -> h (toException e)
                       , Handler $ \(e :: ArrayException) -> h (toException e)
                       , Handler $ \(e :: IOException)    ->
                           if isFullError e || isDoesNotExistError e || isResourceVanishedError e
                           then return ()
                           else log' ERROR ("HTTPS accept loop ignoring " ++ show e)
                       ]

    isResourceVanishedError :: IOException -> Bool
    isResourceVanishedError = isResourceVanishedType . ioeGetErrorType

    isResourceVanishedType :: IOErrorType -> Bool
    isResourceVanishedType ResourceVanished = True
    isResourceVanishedType _                = False