-- | SMTP connections secured with TLS, either encrypted from the first byte
-- ('connectSMTPSSL', implicit TLS) or upgraded from a plain-text session via
-- the SMTP @STARTTLS@ command ('connectSMTPSTARTTLS'). The @do*@ variants
-- bracket the connection so it is closed when the supplied action finishes.
module Network.HaskellNet.SMTP.SSL
    ( -- * Establishing connection
      connectSMTPSSL
    , connectSMTPSSLWithSettings
    , connectSMTPSTARTTLS
    , connectSMTPSTARTTLSWithSettings
      -- * Other Useful Operations
    , doSMTPSSL
    , doSMTPSSLWithSettings
    , doSMTPSTARTTLS
    , doSMTPSTARTTLSWithSettings
      -- * Settings
    , Settings(..)
    , defaultSettingsSMTPSSL
    , defaultSettingsSMTPSTARTTLS
      -- * Network.HaskellNet.SMTP re-exports
    , module Network.HaskellNet.SMTP
    ) where

import Control.Exception
import Control.Monad
import Data.IORef
import Data.Maybe (fromMaybe)
import Text.Read (readMaybe)

import qualified Data.ByteString.Char8 as B
import Network.BSD (getHostName)

import Network.HaskellNet.BSStream
import Network.HaskellNet.SMTP
import Network.HaskellNet.SSL
import Network.HaskellNet.SSL.Internal

-- | Open an SMTP session to the given host over implicit TLS, using
-- 'defaultSettingsSMTPSSL'.
connectSMTPSSL :: String -> IO SMTPConnection
connectSMTPSSL = flip connectSMTPSSLWithSettings defaultSettingsSMTPSSL

-- | Open an SMTP session over implicit TLS with explicit connection
-- settings: the TLS stream is established first, then handed to
-- 'connectStream' to run the SMTP handshake.
connectSMTPSSLWithSettings :: String -> Settings -> IO SMTPConnection
connectSMTPSSLWithSettings hostname cfg = do
    tlsStream <- connectSSL hostname cfg
    connectStream tlsStream

-- | Open an SMTP session to the given host, upgrading it to TLS with
-- @STARTTLS@, using 'defaultSettingsSMTPSTARTTLS'.
connectSMTPSTARTTLS :: String -> IO SMTPConnection
connectSMTPSTARTTLS = flip connectSMTPSTARTTLSWithSettings defaultSettingsSMTPSTARTTLS

-- | Open an SMTP session upgraded with @STARTTLS@ using explicit connection
-- settings: the plain connection is upgraded first, then the resulting
-- stream is handed to 'connectStream'.
connectSMTPSTARTTLSWithSettings :: String -> Settings -> IO SMTPConnection
connectSMTPSTARTTLSWithSettings hostname cfg = do
    upgraded <- connectSTARTTLS hostname cfg
    connectStream upgraded

-- | Connect to @hostname@ in plain text, negotiate @STARTTLS@, and return the
-- now-encrypted stream.
--
-- The server's greeting is consumed here (it must be read before issuing
-- HELO/EHLO), but 'connectStream' expects to read a 220 greeting itself. The
-- returned stream therefore replays every greeting line before reading from
-- the network again.
--
-- If any reply carries an unexpected code (or an unparseable one), the
-- stream is closed and the action fails via 'failIfNot'.
connectSTARTTLS :: String -> Settings -> IO BSStream
connectSTARTTLS hostname cfg = do
    (bs, startTLS) <- connectPlain hostname cfg

    -- The greeting may span several "220-" continuation lines. Consume the
    -- whole reply so later responses stay in sync, and keep every line so it
    -- can be replayed to 'connectStream'.
    (greetingCont, greetingFinal) <- readReply bs
    failIfNot bs 220 $ parse greetingFinal

    hn <- getHostName
    -- NOTE(review): HELO followed by EHLO is redundant (RFC 5321 lets a
    -- client open with EHLO); kept to preserve the existing wire behaviour.
    bsPut bs $ B.pack ("HELO " ++ hn ++ "\r\n")
    getResponse bs >>= failIfNot bs 250
    bsPut bs $ B.pack ("EHLO " ++ hn ++ "\r\n")
    getResponse bs >>= failIfNot bs 250
    bsPut bs $ B.pack "STARTTLS\r\n"
    getResponse bs >>= failIfNot bs 220

    startTLS

    prefixRef <- newIORef (greetingCont ++ [greetingFinal])
    return $ bs { bsGetLine = prefixedGetLine prefixRef (bsGetLine bs) }
  where
    -- Read one complete, possibly multi-line reply. Returns the continuation
    -- lines in arrival order, plus the final line separately.
    readReply :: BSStream -> IO ([B.ByteString], B.ByteString)
    readReply stream = go []
      where
        go acc = do
            line <- bsGetLine stream
            if isContinuation line
                then go (line : acc)
                else return (reverse acc, line)

    -- A continuation line has '-' immediately after the three-digit code.
    -- This is total: a short line such as a bare "250" is treated as a final
    -- line rather than raising an index error.
    isContinuation :: B.ByteString -> Bool
    isContinuation = B.isPrefixOf (B.pack "-") . B.drop 3

    -- Pair a reply line with its numeric code. An unparseable code becomes
    -- 0, which is never an expected code, so 'failIfNot' closes the stream
    -- and reports the raw line, instead of 'read' throwing and leaking it.
    parse :: B.ByteString -> (Integer, String)
    parse raw = (fromMaybe 0 (readMaybe (take 3 s)), s)
      where
        s = B.unpack raw

    -- Read a full reply and parse its final line.
    getResponse :: BSStream -> IO (Integer, String)
    getResponse stream = parse . snd <$> readReply stream

-- | Compare a parsed @(code, line)@ reply against the expected code. On a
-- match, do nothing. Otherwise, close the stream and fail in 'IO' with a
-- message that includes the server's reply line.
failIfNot :: BSStream -> Integer -> (Integer, String) -> IO ()
failIfNot bs expected (actual, reply)
    | expected == actual = return ()
    | otherwise = do
        bsClose bs
        fail ("cannot connect to server: " ++ reply)

-- | Build a @getLine@ action that first hands out the queued lines in the
-- 'IORef', one per call, and only then falls through to the real reader.
--
-- This is a deliberate workaround: 'Network.HaskellNet.SMTP.connectStream'
-- expects to read the server's 220 greeting as soon as it connects, but
-- that greeting has already been consumed to negotiate STARTTLS. Queuing
-- the original greeting here lets HaskellNet see it as if it were fresh.
prefixedGetLine :: IORef [B.ByteString] -> IO B.ByteString -> IO B.ByteString
prefixedGetLine prefix rawGetLine = do
    queued <- readIORef prefix
    case queued of
        [] -> rawGetLine
        (next : remaining) -> do
            writeIORef prefix remaining
            return next

-- | Run an action on a freshly opened SMTP connection, guaranteeing that
-- 'closeSMTP' runs afterwards even if the action throws.
bracketSMTP :: IO SMTPConnection -> (SMTPConnection -> IO a) -> IO a
bracketSMTP open = bracket open closeSMTP
closeSMTP

-- | Connect with implicit TLS ('connectSMTPSSL'), run the action, and close
-- the connection afterwards.
doSMTPSSL :: String -> (SMTPConnection -> IO a) -> IO a
doSMTPSSL = bracketSMTP . connectSMTPSSL

-- | Like 'doSMTPSSL', but with explicit connection 'Settings'.
doSMTPSSLWithSettings :: String -> Settings -> (SMTPConnection -> IO a) -> IO a
doSMTPSSLWithSettings host settings =
    bracketSMTP (connectSMTPSSLWithSettings host settings)

-- | Connect and upgrade with STARTTLS ('connectSMTPSTARTTLS'), run the
-- action, and close the connection afterwards.
doSMTPSTARTTLS :: String -> (SMTPConnection -> IO a) -> IO a
doSMTPSTARTTLS = bracketSMTP . connectSMTPSTARTTLS

-- | Like 'doSMTPSTARTTLS', but with explicit connection 'Settings'.
doSMTPSTARTTLSWithSettings :: String -> Settings -> (SMTPConnection -> IO a) -> IO a
doSMTPSTARTTLSWithSettings host settings =
    bracketSMTP (connectSMTPSTARTTLSWithSettings host settings)

-- | Default settings for SMTP over implicit TLS: 'defaultSettingsWithPort'
-- with port 465.
defaultSettingsSMTPSSL :: Settings
defaultSettingsSMTPSSL = defaultSettingsWithPort smtpsPort
  where
    smtpsPort = 465

-- | Default settings for SMTP upgraded with STARTTLS:
-- 'defaultSettingsWithPort' with port 587.
defaultSettingsSMTPSTARTTLS :: Settings
defaultSettingsSMTPSTARTTLS = defaultSettingsWithPort submissionPort
  where
    submissionPort = 587