module Web.Postie.Connection(
    Connection,
    StartTLSPolicy(..),

    connRecv,
    connSend,
    connClose,
    connIsSecure,

    connStartTlsPolicy,
    connStartTls,
    connAllowStartTLS,
    connDemandStartTLS,

    socketConnection,

    connectionP
  ) where

import Network.Socket hiding (send, sendTo, recv, recvFrom)
import Network.Socket.ByteString.Lazy (sendAll)
import Network.Socket.ByteString hiding (sendAll)

import Network.TLS
import Crypto.Random.AESCtr

import qualified Data.ByteString as BS
import qualified Data.ByteString.Lazy as LBS
import Data.ByteString.Lazy.Internal (defaultChunkSize)

import Control.Exception (ErrorCall (..), finally, throwIO)
import Control.Monad.IO.Class

import qualified Pipes as P

-- | Low-level connection abstraction: a byte stream that can be read,
-- written and closed, and that may be upgraded to TLS via 'connStartTls'.
data Connection = Connection
  { connRecv           :: IO BS.ByteString
    -- ^ Read the next chunk of data. An empty 'BS.ByteString' signals that
    -- end of input has been reached.
  , connSend           :: LBS.ByteString -> IO ()
    -- ^ Write the given bytes to the peer.
  , connClose          :: IO ()
    -- ^ Close the connection.
  , connIsSecure       :: Bool
    -- ^ 'True' when this connection is protected by TLS.
  , connStartTlsPolicy :: StartTLSPolicy
    -- ^ The STARTTLS policy this connection was created with.
  , connStartTls       :: IO Connection
    -- ^ Negotiate TLS and return a new, secured connection. Fails on a
    -- connection that cannot (or can no longer) be upgraded.
  }

-- | How TLS may be negotiated on a connection. Every constructor except
-- 'NotAvailable' carries the server parameters used by 'connStartTls'.
data StartTLSPolicy
  = Always ServerParams
    -- ^ TLS parameters are available, but 'connAllowStartTLS' and
    -- 'connDemandStartTLS' never report STARTTLS for this policy;
    -- presumably the caller secures the connection up front (TODO confirm).
  | Allow ServerParams
    -- ^ STARTTLS may be offered on an insecure connection.
  | Demand ServerParams
    -- ^ STARTTLS is offered and reported as demanded on an insecure
    -- connection.
  | NotAvailable
    -- ^ No TLS parameters; upgrading via 'connStartTls' fails.

-- | Wrap a 'Socket' as a plain (unencrypted) 'Connection' that can later be
-- upgraded to TLS through 'connStartTls', according to @policy@.
--
-- The upgrade uses the 'ServerParams' carried by 'Always', 'Allow' or
-- 'Demand'. With 'NotAvailable', or when called on the already-secured
-- connection, 'connStartTls' throws an 'ErrorCall' when run, before any TLS
-- state is created.
socketConnection :: Socket -> StartTLSPolicy -> IO Connection
socketConnection socket policy = return connection
  where
    connection = Connection {
      connRecv           = recv socket defaultChunkSize
    , connSend           = sendAll socket
    , connClose          = sClose socket
    , connIsSecure       = False
    , connStartTlsPolicy = policy
    , connStartTls       = secureConnection
    }

    -- Perform a TLS handshake over the same socket and return a connection
    -- that reads and writes through the resulting TLS context.
    secureConnection = case upgradeParams of
      Nothing     -> throwIO (ErrorCall "no upgrade allowed")
      Just params -> do
        context <- contextNew socket params =<< makeSystem
        handshake context

        return Connection {
            connRecv           = recvData context
          , connSend           = sendData context
            -- Say goodbye to the peer, then release the context even if
            -- 'bye' fails.
          , connClose          = bye context `finally` contextClose context
          , connIsSecure       = True
          , connStartTlsPolicy = policy
            -- Upgrading twice is a caller error: fail when the action is
            -- run rather than when the field is merely evaluated.
          , connStartTls       = throwIO (ErrorCall "already on secure connection")
          }

    -- TLS parameters for the upgrade, if the policy provides any. Exhaustive
    -- so that adding a policy constructor triggers a pattern warning here.
    upgradeParams = case policy of
      Always p     -> Just p
      Allow p      -> Just p
      Demand p     -> Just p
      NotAvailable -> Nothing

-- | Whether STARTTLS may be offered on this connection.
--
-- 'True' only when the connection is not yet secure and its policy is
-- 'Allow' or 'Demand'; 'Always' and 'NotAvailable' never offer STARTTLS.
connAllowStartTLS :: Connection -> Bool
connAllowStartTLS conn =
  not (connIsSecure conn) && allowedByPolicy (connStartTlsPolicy conn)
  where
    -- Listed exhaustively so a new policy constructor is flagged by
    -- -Wincomplete-patterns instead of silently defaulting to False.
    allowedByPolicy (Allow _)    = True
    allowedByPolicy (Demand _)   = True
    allowedByPolicy (Always _)   = False
    allowedByPolicy NotAvailable = False

-- | Whether the connection's policy demands STARTTLS.
--
-- 'True' only when the connection is not yet secure and its policy is
-- 'Demand'.
connDemandStartTLS :: Connection -> Bool
connDemandStartTLS conn =
  not (connIsSecure conn) && demandedByPolicy (connStartTlsPolicy conn)
  where
    -- Listed exhaustively so a new policy constructor is flagged by
    -- -Wincomplete-patterns instead of silently defaulting to False.
    demandedByPolicy (Demand _)   = True
    demandedByPolicy (Allow _)    = False
    demandedByPolicy (Always _)   = False
    demandedByPolicy NotAvailable = False

-- | Stream every chunk read from the connection downstream, stopping at the
-- first empty chunk (end of input). The empty chunk itself is not yielded.
connectionP :: (MonadIO m) => Connection -> P.Producer' BS.ByteString m ()
connectionP conn = loop
  where
    loop = liftIO (connRecv conn) >>= emit
    emit chunk
      | BS.null chunk = return ()
      | otherwise     = P.yield chunk >> loop