module Web.Postie.Settings(
Settings(..)
, defaultSettings
, TLSSettings(..)
, StartTLSPolicy(..)
, tlsSettings
, defaultTLSSettings
, defaultExceptionHandler
, settingsStartTLSPolicy
, settingsConnectWithTLS
, settingsAllowStartTLS
, settingsDemandStartTLS
) where
import Web.Postie.Types
import Web.Postie.Address
import Web.Postie.SessionID
import qualified Web.Postie.Connection as Connection
import Network (HostName, PortID(..))
import System.IO (hPrint, stderr)
import System.IO.Error (ioeGetErrorType)
import Data.ByteString (ByteString)
import qualified Network.TLS as TLS
import qualified Network.TLS.Extra.Cipher as TLS
import Data.Default.Class
import Data.Maybe (fromMaybe)
import Control.Exception
import GHC.IO.Exception (IOErrorType(..))
import Control.Monad.Trans.Class
import Control.Monad.Trans.Maybe
import Control.Applicative ((<$>))
-- | Complete server configuration. Build one by overriding fields of
-- 'defaultSettings' with record-update syntax.
data Settings = Settings
    { settingsPort :: PortID
      -- ^ Port setting for the server (default: @PortNumber 3001@).
    , settingsTimeout, settingsMaxDataSize :: Int
      -- ^ Timeout and maximum data size (defaults: 1800 and 32000); the units
      -- are not visible in this module.
    , settingsHost :: Maybe HostName
      -- ^ Optional host name (default: 'Nothing').
    , settingsTLS :: Maybe TLSSettings
      -- ^ TLS configuration; with 'Nothing', 'settingsStartTLSPolicy' yields
      -- @Connection.NotAvailable@.
    , settingsOnException :: Maybe SessionID -> SomeException -> IO ()
      -- ^ Exception hook; the session is optional (presumably 'Nothing' for
      -- errors raised outside any session — confirm against the server loop).
    , settingsBeforeMainLoop :: IO ()
      -- ^ Action run before the main loop (default: does nothing).
    , settingsOnOpen, settingsOnClose, settingsOnStartTLS :: SessionID -> IO ()
      -- ^ Per-session notification hooks, named after the events they report.
    , settingsOnHello :: SessionID -> ByteString -> IO HandlerResponse
      -- ^ Decides the response to the client's greeting.
    , settingsOnMailFrom, settingsOnRecipient :: SessionID -> Address -> IO HandlerResponse
      -- ^ Decide the response to a sender / recipient address.
    }
-- | Permissive baseline configuration: port 3001, no TLS, every greeting,
-- sender and recipient is 'Accepted', the session hooks do nothing, and
-- exceptions are routed to 'defaultExceptionHandler'.
defaultSettings :: Settings
defaultSettings = Settings
    { settingsPort           = PortNumber 3001
    , settingsTimeout        = 1800
    , settingsMaxDataSize    = 32000
    , settingsHost           = Nothing
    , settingsTLS            = Nothing
    , settingsOnException    = defaultExceptionHandler
    , settingsBeforeMainLoop = return ()
    , settingsOnOpen         = ignoreSession
    , settingsOnClose        = ignoreSession
    , settingsOnStartTLS     = ignoreSession
    , settingsOnHello        = acceptAll
    , settingsOnMailFrom     = acceptAll
    , settingsOnRecipient    = acceptAll
    }
  where
    -- No-op session hook.
    ignoreSession :: SessionID -> IO ()
    ignoreSession _ = return ()

    -- Accepts whatever value the hook is given.
    acceptAll :: SessionID -> a -> IO HandlerResponse
    acceptAll _ _ = return Accepted
-- | TLS configuration used to build the server parameters.
data TLSSettings = TLSSettings
    { certFile, keyFile :: FilePath
      -- ^ Paths of the X.509 certificate and its key, passed to
      -- @TLS.credentialLoadX509@.
    , security :: StartTLSPolicy
      -- ^ How TLS is offered (see 'StartTLSPolicy').
    , tlsLogging :: TLS.Logging
      -- ^ TLS logging hooks. NOTE(review): not read anywhere in this module;
      -- presumably consumed elsewhere — verify it is actually wired up.
    , tlsAllowedVersions :: [TLS.Version]
      -- ^ Protocol versions placed in @TLS.supportedVersions@.
    , tlsCiphers :: [TLS.Cipher]
      -- ^ Cipher suites placed in @TLS.supportedCiphers@.
    }
-- | How the server offers TLS. 'settingsStartTLSPolicy' translates each
-- constructor into the matching @Connection.StartTLSPolicy@.
data StartTLSPolicy
    = AllowStartTLS   -- ^ STARTTLS permitted; becomes @Connection.Allow@.
    | DemandStartTLS  -- ^ STARTTLS required; becomes @Connection.Demand@.
    | ConnectWithTLS  -- ^ TLS from connection start; becomes @Connection.Always@.
    deriving (Eq, Show)
-- | Baseline TLS configuration: reads @certificate.pem@ and @key.pem@ (relative
-- paths), demands STARTTLS, uses default TLS logging, and allows TLS 1.0
-- through 1.2.
--
-- SSL 3.0 is deliberately not offered: it is broken (POODLE) and RFC 7568
-- says it MUST NOT be used. TLS 1.0/1.1 are kept for compatibility with old
-- clients, although RFC 8996 deprecates them; override 'tlsAllowedVersions'
-- to tighten this.
defaultTLSSettings :: TLSSettings
defaultTLSSettings = TLSSettings
    { certFile           = "certificate.pem"
    , keyFile            = "key.pem"
    , security           = DemandStartTLS
    , tlsLogging         = def
    , tlsAllowedVersions = [TLS.TLS10, TLS.TLS11, TLS.TLS12]
      -- NOTE(review): depending on the installed tls version, ciphersuite_all
      -- may include legacy suites; consider a stricter list such as
      -- ciphersuite_default/ciphersuite_strong if available.
    , tlsCiphers         = TLS.ciphersuite_all
    }
-- | 'defaultTLSSettings' with the certificate and key file paths replaced.
tlsSettings
    :: FilePath     -- ^ certificate file
    -> FilePath     -- ^ key file
    -> TLSSettings
tlsSettings certPath keyPath =
    defaultTLSSettings { keyFile = keyPath, certFile = certPath }
-- | Is the configured TLS policy 'ConnectWithTLS'? 'False' when no TLS
-- settings are present.
settingsConnectWithTLS :: Settings -> Bool
settingsConnectWithTLS settings = checkSecurity ConnectWithTLS settings

-- | Is the configured TLS policy 'AllowStartTLS'? 'False' when no TLS
-- settings are present.
settingsAllowStartTLS :: Settings -> Bool
settingsAllowStartTLS settings = checkSecurity AllowStartTLS settings

-- | Is the configured TLS policy 'DemandStartTLS'? 'False' when no TLS
-- settings are present.
settingsDemandStartTLS :: Settings -> Bool
settingsDemandStartTLS settings = checkSecurity DemandStartTLS settings
-- | Does 'settingsTLS' carry exactly the given policy? Always 'False' when
-- 'settingsTLS' is 'Nothing'.
checkSecurity :: StartTLSPolicy -> Settings -> Bool
checkSecurity policy settings =
    fromMaybe False ((== policy) . security <$> settingsTLS settings)
-- | Translate the configured 'StartTLSPolicy' into the connection layer's
-- policy, loading the TLS server parameters on the way.
--
-- Yields @Connection.NotAvailable@ when no TLS settings are configured. Any
-- exception from 'settingsServerParams' (e.g. an unloadable certificate)
-- propagates.
settingsStartTLSPolicy :: Settings -> IO Connection.StartTLSPolicy
settingsStartTLSPolicy settings =
    maybe Connection.NotAvailable choose <$> settingsServerParams settings
  where
    choose params
      | settingsDemandStartTLS settings = Connection.Demand params
      | settingsAllowStartTLS settings  = Connection.Allow params
      | settingsConnectWithTLS settings = Connection.Always params
      | otherwise                       = Connection.NotAvailable
-- | Build TLS server parameters from 'settingsTLS'.
--
-- Returns 'Nothing' when no TLS settings are configured. Otherwise loads the
-- certificate/key pair named by 'certFile' and 'keyFile' and installs it
-- together with 'tlsCiphers' and 'tlsAllowedVersions' on top of the library
-- defaults.
--
-- Throws 'TLS.Error_Certificate' (as an IO exception) if the credentials
-- cannot be loaded.
settingsServerParams :: Settings -> IO (Maybe TLS.ServerParams)
settingsServerParams settings = runMaybeT $ do
    tlss <- MaybeT . return $ settingsTLS settings
    credential <- lift $ loadCredentials tlss
    return def
      { TLS.serverShared = def
          { TLS.sharedCredentials = TLS.Credentials [credential]
          }
      , TLS.serverSupported = def
          { TLS.supportedCiphers  = tlsCiphers tlss
          , TLS.supportedVersions = tlsAllowedVersions tlss
          }
      }
  where
    -- Use throwIO (not a pure 'throw' under fmap) so that a bad certificate
    -- or key fails right here, instead of hiding inside a lazy Credential
    -- thunk that only explodes when the TLS library forces it mid-handshake.
    loadCredentials :: TLSSettings -> IO TLS.Credential
    loadCredentials tlss =
      TLS.credentialLoadX509 (certFile tlss) (keyFile tlss)
        >>= either (throwIO . TLS.Error_Certificate) return
-- | Default 'settingsOnException' hook; the session argument is ignored.
--
-- Re-throws the exception and dispatches it with 'catches'. Exceptions
-- treated as routine connection noise are silently dropped: 'ThreadKilled',
-- IO errors of type 'ResourceVanished' or 'InvalidArgument',
-- 'TLS.Terminated', 'TLS.HandshakeFailed', and the TLS errors 'TLS.Error_EOF',
-- 'TLS.Error_Packet_Parsing', 'TLS.Error_Packet' and 'TLS.Error_Protocol'.
-- Everything else is printed to stderr.
defaultExceptionHandler :: Maybe SessionID -> SomeException -> IO ()
defaultExceptionHandler _ err =
    throwIO err `catches`
      [ Handler onAsync
      , Handler onIO
      , Handler onTLSException
      , Handler onTLSError
      , Handler onAnythingElse
      ]
  where
    quiet :: IO ()
    quiet = return ()

    logIt :: Show a => a -> IO ()
    logIt = hPrint stderr

    onAsync :: AsyncException -> IO ()
    onAsync ex = case ex of
      ThreadKilled -> quiet
      _            -> logIt ex

    onIO :: IOException -> IO ()
    onIO ex
      | ioeGetErrorType ex `elem` [ResourceVanished, InvalidArgument] = quiet
      | otherwise = logIt ex

    onTLSException :: TLS.TLSException -> IO ()
    onTLSException ex = case ex of
      TLS.Terminated{}      -> quiet
      TLS.HandshakeFailed{} -> quiet
      _                     -> logIt ex

    onTLSError :: TLS.TLSError -> IO ()
    onTLSError ex = case ex of
      TLS.Error_EOF              -> quiet
      TLS.Error_Packet_Parsing{} -> quiet
      TLS.Error_Packet{}         -> quiet
      TLS.Error_Protocol{}       -> quiet
      _                          -> logIt ex

    onAnythingElse :: SomeException -> IO ()
    onAnythingElse = logIt