{-# LANGUAGE CPP #-}

module Database.Tds.Transport (contextNew) where

import Data.Monoid((<>),mempty)
import Control.Applicative((<$>),(<*>))

import Network.Socket (Socket,close)
import Network.Socket.ByteString (recv,sendAll)

import qualified Data.ByteString as B
import qualified Data.ByteString.Lazy as LB

import Data.Binary (decode,encode)

import Data.Default.Class (def)
import qualified Network.TLS as TLS
import Network.TLS (ClientParams(..),Supported(..),Shared(..),ValidationCache(..),ValidationCacheResult(..))
import Network.TLS.Extra.Cipher (ciphersuite_strong)
import Data.X509.CertificateStore (CertificateStore(..))
import System.X509 (getSystemCertificateStore)

import Control.Concurrent(MVar(..),newMVar,readMVar,modifyMVar_)

import Database.Tds.Message.Header
#if !MIN_VERSION_tls(1,3,0)
import Crypto.Random(createEntropyPool,cprgCreate,SystemRNG(..))
#endif


-- | Create a TLS client 'TLS.Context' on top of an already connected TDS
-- socket.
--
-- The TLS records are exchanged through the TDS-aware transport built by
-- @getBackend@, and the client parameters come from @getTlsParams@ using
-- @host@ as the server name and the system certificate store.
--
-- See [\[MS-TDS\] 3.2.5.2 Sent TLS/SSL Negotiation Packet State](https://docs.microsoft.com/en-us/openspecs/windows_protocols/ms-tds/d62e225b-d865-4ccc-8f73-de1ef49e30d4)
contextNew :: Socket -> String -> IO TLS.Context
contextNew sock host = do
  params  <- getTlsParams host <$> getSystemCertificateStore
  backend <- getBackend <$> newSecureSocket sock
#if MIN_VERSION_tls(1,3,0)
  TLS.contextNew backend params
#else
  rng <- cprgCreate <$> createEntropyPool
  TLS.contextNew backend params (rng :: SystemRNG)
#endif



-- | A connected socket together with the mutable state used to carry TLS
-- records inside TDS packets.
data SecureSocket = SecureSocket
  { getSocket   :: Socket
    -- ^ The underlying connected socket.
  , getSendBuff :: MVar B.ByteString
    -- ^ Outgoing TLS bytes held back so they can be sent together in a
    -- single TDS packet.
  , getSendStep :: MVar Int
    -- ^ How many sends have been performed so far; selects how the next
    -- send is framed.
  , getRecvBuff :: MVar B.ByteString
    -- ^ Received TLS bytes not yet handed to the TLS layer.
  }

-- | Wrap a connected socket with empty send/receive buffers and a send
-- counter starting at 0.
newSecureSocket :: Socket -> IO SecureSocket
newSecureSocket sock =
  SecureSocket sock <$> newMVar mempty <*> newMVar 0 <*> newMVar mempty

-- | Build the TLS 'TLS.Backend' that carries TLS traffic over a TDS
-- connection.
--
-- The send side wraps (and partly batches) the first four sends in TDS
-- packet headers and writes later data unchanged. The receive side strips
-- the 8-byte TDS header from every packet it reads.
getBackend :: SecureSocket -> TLS.Backend
getBackend sock' = TLS.Backend flush (close sock) sendAll' recvAll
  where
    sock = getSocket sock'

    -- No separate flush step: sendAll' itself decides when the bytes held
    -- in the send buffer are written to the socket.
    flush = return ()

    -- Prefix a payload with an 8-byte TDS packet header of type 0x12
    -- (PRELOGIN) and status 0x01 (end of message). The header's length
    -- field counts the 8 header bytes as well as the payload.
    wrapTds payload =
      LB.toStrict (encode (Header 0x12 1 (fromIntegral (B.length payload + 8)) 0 0 0))
        <> payload

    -- [MEMO] Put them into TDS packets at regular intervals
    -- [TODO] Consider a better implementation
    --
    -- The framing depends only on how many times this has been called. The
    -- TLS record type expected at each send is noted in parentheses:
    --
    --   * send 0 (0x16): wrapped in its own TDS packet and sent at once
    --   * sends 1 (0x16) and 2 (0x14): appended to the send buffer
    --   * send 3 (0x16): appended to the buffered bytes; everything is sent
    --     as one TDS packet and the buffer is cleared
    --   * later sends (0x17): written to the socket without a TDS header
    --
    -- An alternative was tried and did not work: choosing the framing from
    -- the record's first byte, i.e. wrapping every non-0x17 record
    -- individually.
    sendAll' bs = do
      let appendBuff = modifyMVar_ (getSendBuff sock') (return . (<> bs))
      step <- readMVar (getSendStep sock')
      case step of
        0 -> sendAll sock (wrapTds bs)
        1 -> appendBuff
        2 -> appendBuff
        3 -> do
          buff <- readMVar (getSendBuff sock')
          sendAll sock (wrapTds (buff <> bs))
          modifyMVar_ (getSendBuff sock') (\_ -> return mempty)
        _ -> sendAll sock bs
      modifyMVar_ (getSendStep sock') (return . (+ 1))

    -- Return exactly @len@ bytes of TLS data, or fewer only if the peer
    -- closed the connection. The tls backend contract expects the requested
    -- number of bytes.
    --
    -- A TLS record may be split across several TDS packets, and one TDS
    -- packet may carry several records. So whole TDS payloads are
    -- accumulated until @len@ bytes are available; any surplus stays in
    -- the receive buffer for the next call.
    recvAll len = do
      buff <- readMVar (getRecvBuff sock')
      available <- fillTo len buff
      let (bs, rest) = B.splitAt len available
      modifyMVar_ (getRecvBuff sock') (\_ -> return rest)
      return bs

    -- Append TDS payloads to @acc@ until it holds at least @len@ bytes or
    -- the connection has been closed.
    fillTo len acc
      | B.length acc >= len = return acc
      | otherwise = do
          payload <- recvTdsPayload
          case payload of
            Nothing -> return acc
            Just p  -> fillTo len (acc <> p)

    -- Read one TDS packet and return its payload without the 8-byte
    -- header. Returns 'Nothing' if the connection closed before a complete
    -- header arrived, instead of failing inside 'decode'.
    -- NOTE(review): every incoming packet is assumed to start with a TDS
    -- header, including any data received after the handshake. Confirm.
    recvTdsPayload = do
      hdr <- recvExact 8
      if B.length hdr < 8
        then return Nothing
        else do
          let Header _ _ totalLen _ _ _ = decode (LB.fromStrict hdr)
          Just <$> recvExact (fromIntegral totalLen - 8)

    -- Read exactly @n@ bytes from the socket, looping because 'recv' may
    -- return fewer bytes than requested. The result is shorter only on EOF.
    -- Non-positive counts never reach 'recv'.
    recvExact n = B.concat <$> go n
      where
        go remaining
          | remaining <= 0 = return []
          | otherwise = do
              chunk <- recv sock remaining
              if B.null chunk
                then return []
                else (chunk :) <$> go (remaining - B.length chunk)


-- | TLS client parameters for connecting to @host@ with CA store @store@.
--
-- Only 'TLS.TLS10' is offered, with the 'ciphersuite_strong' cipher list.
--
-- NOTE(review): the validation cache used here answers
-- 'ValidationCachePass' for every certificate. x509-validation skips
-- checking a certificate whose cache lookup passes, so the server
-- certificate is effectively never verified against @store@. That leaves
-- no protection against a man-in-the-middle. This is presumably
-- deliberate, to accept self-signed server certificates. Confirm, and
-- consider making it opt-in.
getTlsParams :: String -> CertificateStore -> ClientParams
getTlsParams host store =
  base { clientSupported = supported, clientShared = shared }
  where
    base = TLS.defaultParamsClient host mempty

    supported = def
      { supportedVersions = [TLS.TLS10]
      , supportedCiphers  = ciphersuite_strong
      }

    shared = def
      { sharedCAStore         = store
      , sharedValidationCache = acceptAll
      }

    -- Every lookup reports a pass; adding to the cache does nothing.
    acceptAll =
      ValidationCache (\_ _ _ -> return ValidationCachePass) (\_ _ _ -> return ())