{-# LANGUAGE CPP, DeriveDataTypeable, RecordWildCards, ScopedTypeVariables #-}
-----------------------------------------------------------------------------
{- |
 Module      :  Data.Acid.Remote.TLS
 Copyright   :  PublicDomain

 Maintainer  :  lemmih@gmail.com, jeremy@n-heptane.com
 Portability :  non-portable (uses GHC extensions)

This module provides the same functionality as "Data.Acid.Remote" but
over a secured TLS socket.

-}
module Data.Acid.Remote.TLS
    (
    -- * Server/Client
      acidServerTLS
    , openRemoteStateTLS
    -- * Authentication
    , skipAuthenticationCheck
    , skipAuthenticationPerform
    , sharedSecretCheck
    , sharedSecretPerform

    ) where

import Control.Concurrent        ( forkIO, threadDelay )
import Control.Exception         ( Handler(..), IOException, SomeException, catch, catches, handle
                                 , finally, throwIO )
import Control.Monad             ( forever, when )
import Data.Acid                 ( AcidState, IsAcidic )
import Data.Acid.Remote          ( CommChannel(..), process, processRemoteState, skipAuthenticationCheck
                                 , skipAuthenticationPerform, sharedSecretCheck, sharedSecretPerform )
import Data.SafeCopy             ( SafeCopy )
import GHC.IO.Exception          ( IOErrorType(..) )
import           OpenSSL         ( withOpenSSL )
import           OpenSSL.Session ( SomeSSLException, SSL, SSLContext )
import qualified OpenSSL.Session as SSL
import Network                   ( HostName, PortID(..), Socket, listenOn, sClose, withSocketsDo )
import Network.Socket            as Socket ( Family(..), SockAddr(..), SocketType(..), accept, socket, connect )
import Network.BSD               ( getHostByName, getProtocolNumber, getServicePortNumber, hostAddress )
import System.Directory          ( removeFile )
import System.IO.Error           ( ioeGetErrorType, isFullError, isDoesNotExistError )

debugStrLn :: String -> IO ()
debugStrLn s =
    do putStrLn s -- uncomment to enable debugging
       return ()

initSSLContext :: FilePath  -- ^ path to ssl certificate
               -> FilePath  -- ^ path to ssl private key
               -> IO SSLContext
initSSLContext cert key =
    do ctx <- SSL.context
       SSL.contextSetPrivateKeyFile  ctx key
       SSL.contextSetCertificateFile ctx cert
       SSL.contextSetDefaultCiphers  ctx

       certOk <- SSL.contextCheckPrivateKey ctx
       when (not certOk) $ error $ "OpenTLS certificate and key do not match."

       return ctx

-- | accept a TLS connection
acceptTLS :: SSLContext -> Socket -> IO (Socket, SSL, SockAddr)
acceptTLS ctx sck' =
    do -- do normal accept
      (sck, sockAddr) <- accept sck'
      --  then TLS accept
      handle (\ (e :: SomeException) -> sClose sck >> throwIO e) $ do
          ssl <- SSL.connection ctx sck
          SSL.accept ssl
          return (sck, ssl, sockAddr)

{- | Accept connections on @port@ and handle requests using the given 'AcidState'.
     This call doesn't return.

     The connection is secured using TLS/SSL.

     On Unix®-like systems you can use 'UnixSocket' to communicate
     using a socket file. To control access, you can set the permissions of
     the parent directory which contains the socket file.

     see also: 'openRemoteStateTLS' and 'sharedSecretCheck'.
 -}
acidServerTLS :: SafeCopy st =>
                 FilePath                 -- ^ path to ssl certificate
              -> FilePath                 -- ^ path to ssl private key
              -> (CommChannel -> IO Bool) -- ^ authorization function
              -> PortID                   -- ^ port to list on
              -> AcidState st             -- ^ 'AcidState' to serve
              -> IO ()
acidServerTLS sslCert sslKey checkAuth port acidState
  = withSocketsDo $
    do withOpenSSL $ return ()
       debugStrLn $ "acidServerTLS: listenOn " ++ show port
       tlsSocket <- listenOn port
       debugStrLn $ "acidServerTLS: initSSLContext"
       ctx       <- initSSLContext sslCert sslKey
       let worker :: (Socket, SSL, SockAddr) -> IO ()
           worker (socket, ssl, _sockAddr) =
               do -- TODO: log this connection, sockAddr
                  let socketCommChannel :: CommChannel
                      socketCommChannel = CommChannel
                        { ccPut     = SSL.write ssl
                        , ccGetSome = SSL.read ssl
                        , ccClose   = shutdownClose socket ssl
                        }
                  forkIO $ (do authorized <- checkAuth socketCommChannel
                               when authorized $
                                    ignoreSome $ (process socketCommChannel acidState)
                               ccClose socketCommChannel) `catch` (\(e::SomeException) -> do
                                                                     shutdownClose socket ssl
                                                                     throwIO e)
                  return ()
           loop :: IO ()
           loop = do ignoreSome $ (forever $ worker =<< acceptTLS ctx tlsSocket)
                     loop
       loop `finally` (cleanup tlsSocket `catch` ignoreException)

    where
      cleanup tlsSocket
        = do debugStrLn "acidServerTLS: cleanup."
             sClose tlsSocket
#if !defined(mingw32_HOST_OS) && !defined(cygwin32_HOST_OS) && !defined(_WIN32)
             case port of
               (UnixSocket path) ->
                   removeFile path `catch` (\e -> if isDoesNotExistError e then return () else throwIO e)
               _ -> return ()
#endif

      -- exception handlers
      ignoreConnectionAbruptlyTerminated :: SSL.ConnectionAbruptlyTerminated -> IO ()
      ignoreConnectionAbruptlyTerminated _ = return ()

      ignoreSSLException :: SSL.SomeSSLException -> IO ()
      ignoreSSLException _ = return ()

      ignoreException :: SomeException -> IO ()
      ignoreException _ = return ()

      shutdownClose :: Socket -> SSL -> IO ()
      shutdownClose socket ssl =
          do debugStrLn "acidServerTLS: shutdownClose."
             SSL.shutdown ssl SSL.Unidirectional `catch` ignoreException
             sClose socket `catch` ignoreException

      ignoreSome op =
               op `catches` [ Handler $ ignoreSSLException
                            , Handler $ ignoreConnectionAbruptlyTerminated
                            , Handler $ \(e :: IOException)    ->
                                  if isFullError e || isDoesNotExistError e || isResourceVanishedError e
                                  then return ()
                                  else throwIO e
                              ]

      isResourceVanishedError :: IOException -> Bool
      isResourceVanishedError = isResourceVanishedType . ioeGetErrorType

      isResourceVanishedType :: IOErrorType -> Bool
      isResourceVanishedType ResourceVanished = True
      isResourceVanishedType _                = False

{- | Connect to an acid-state server which is sharing an 'AcidState'.

The connection is secured using SSL/TLS.
 -}
openRemoteStateTLS  :: IsAcidic st =>
                       (CommChannel -> IO ()) -- ^ authentication function, see 'sharedSecretPerform'
                    -> HostName               -- ^ remote host to connect to (ignored when 'PortID' is 'UnixSocket')
                    -> PortID                 -- ^ remote port to connect to
                    -> IO (AcidState st)
openRemoteStateTLS performAuthorization host port
  = do withOpenSSL $ return ()
       processRemoteState reconnect
    where
      sslCommChannel ssl =
         CommChannel { ccGetSome = SSL.read ssl
                     , ccPut     = SSL.write ssl
                     , ccClose   = do SSL.shutdown ssl SSL.Unidirectional
                                      --   close ssl
                     }

      -- | reconnect
      reconnect :: IO CommChannel
      reconnect
          = (do ssl <- connectToTLS host port
                let cc = sslCommChannel ssl
                performAuthorization cc
                return cc
            )
            `catch`
            ((\e -> threadDelay 1000000 >> reconnect) :: IOError -> IO CommChannel)

-- IPV4 support only, sorry
connectToTLS :: HostName
             -> PortID
             -> IO SSL
connectToTLS hostName (Service serv)
  = do port <- getServicePortNumber serv
       connectToTLS hostName (PortNumber port)

connectToTLS hostName (PortNumber port)
  = do proto <- getProtocolNumber "tcp"
       sock <- socket AF_INET Stream proto
       (do he <- getHostByName hostName
           Socket.connect sock (SockAddrInet port (hostAddress he))
           ctx <- SSL.context
           ssl <- SSL.connection ctx sock
           SSL.connect ssl
           return ssl) `catch` (\e -> do print (e :: SomeSSLException)
                                         sClose sock
                                         throwIO e
                               )
#if !defined(mingw32_HOST_OS) && !defined(cygwin32_HOST_OS) && !defined(_WIN32)
connectToTLS _hostName p@(UnixSocket path)
  = do debugStrLn $ "connectToTLS: " ++ show p
       sock <- socket AF_UNIX Stream 0
       (do debugStrLn $ "connectToTLS: connect."
           Socket.connect sock (SockAddrUnix path)
           ctx <- SSL.context
           ssl <- SSL.connection ctx sock
           debugStrLn $ "connectToTLS: connect ssl."
           SSL.connect ssl
           debugStrLn $ "connectToTLS: done."
           return ssl) `catch` (\e -> do print (e :: SomeSSLException)
                                         sClose sock
                                         throwIO e
                               )
#endif