{-# LANGUAGE CPP #-}

module Database.Tds.Transport (contextNew) where

import Data.Monoid((<>),mempty)
import Control.Applicative((<$>),(<*>))

import Network.Socket (Socket,close)
import Network.Socket.ByteString (recv,sendAll)

import qualified Data.ByteString as B
import qualified Data.ByteString.Lazy as LB

import Data.Binary (decode,encode)

import Data.Default.Class (def)
import qualified Network.TLS as TLS
import Network.TLS (ClientParams(..),Supported(..),Shared(..),ValidationCache(..),ValidationCacheResult(..))
import Network.TLS.Extra.Cipher (ciphersuite_strong)
import Data.X509.CertificateStore (CertificateStore(..))
import System.X509 (getSystemCertificateStore)

import Control.Concurrent(MVar(..),newMVar,readMVar,modifyMVar_)

import Database.Tds.Message.Header
#if !MIN_VERSION_tls(1,3,0)
import Crypto.Random(createEntropyPool,cprgCreate,SystemRNG(..))
#endif


-- | [\[MS-TDS\] 3.2.5.2 Sent TLS/SSL Negotiation Packet State](https://docs.microsoft.com/en-us/openspecs/windows_protocols/ms-tds/d62e225b-d865-4ccc-8f73-de1ef49e30d4)
--
-- Create a TLS client context on top of an already-connected socket.
--
-- The system certificate store is loaded first, then the socket is wrapped in
-- a 'SecureSocket'; the context is built from the backend returned by
-- 'getBackend' and the client parameters returned by 'getTlsParams'.
--
-- The second argument is the host name handed to 'getTlsParams'.
contextNew :: Socket -> String -> IO TLS.Context
contextNew sock host = do
  certStore <- getSystemCertificateStore
  sock' <- newSecureSocket sock
  let backend = getBackend sock'
      params = getTlsParams host certStore
#if MIN_VERSION_tls(1,3,0)
  TLS.contextNew backend params
#else
  -- Older tls releases need an explicit random generator for the context.
  pool <- createEntropyPool
  TLS.contextNew backend params (cprgCreate pool :: SystemRNG)
#endif



-- | A connected 'Socket' bundled with the mutable state used by the TLS
-- backend built in 'getBackend'.
data SecureSocket = SecureSocket
  { getSocket :: Socket
    -- ^ The underlying network socket.
  , getSendBuff :: MVar B.ByteString
    -- ^ Outgoing bytes held back until they are sent as one TDS packet.
  , getSendStep :: MVar Int
    -- ^ Number of send calls made so far; starts at 0.
  , getRecvBuff :: MVar B.ByteString
    -- ^ Received payload bytes not yet handed to the caller.
  }
                    
-- | Wrap a socket in a fresh 'SecureSocket': both buffers start empty and the
-- send-step counter starts at 0.
newSecureSocket :: Socket -> IO SecureSocket
newSecureSocket sock = do
  sendBuff <- newMVar mempty
  sendStep <- newMVar 0
  recvBuff <- newMVar mempty
  return SecureSocket
    { getSocket = sock
    , getSendBuff = sendBuff
    , getSendStep = sendStep
    , getRecvBuff = recvBuff
    }

-- | Build the 'TLS.Backend' that carries TLS traffic inside TDS packets.
--
-- The backend is assembled from, in order: a no-op flush, an action that
-- closes the underlying socket, @sendAll'@ and @recvAll@.
--
-- Sending wraps outgoing bytes in a TDS packet header for the first sends
-- (see @sendAll'@); receiving strips the 8-byte TDS header from every
-- incoming packet and serves the payload in the sizes the caller asks for.
getBackend :: SecureSocket -> TLS.Backend
getBackend sock' = TLS.Backend flush (close sock) sendAll' recvAll
  where
    sock = getSocket sock'

    -- Nothing is buffered at this layer beyond what sendAll' manages itself.
    flush = return ()

    -- Size in bytes of the TDS packet header read by recvPacketBody and
    -- counted into the length field by tdsHeader.
    headerSize :: Int
    headerSize = 8

    -- Encoded TDS packet header for the given payload: type 0x12, status 1,
    -- length = payload size + header size, remaining fields 0.
    -- NOTE(review): 0x12 / status 1 are presumably PRELOGIN / end-of-message
    -- per MS-TDS -- confirm against Database.Tds.Message.Header.
    tdsHeader payload =
      LB.toStrict $ encode $
        Header 0x12 1 (fromIntegral $ B.length payload + headerSize) 0 0 0

    -- Append bytes to the pending send buffer without touching the socket.
    appendBuff bs = modifyMVar_ (getSendBuff sock') (return . (<> bs))

    -- [MEMO] Put them into TDS packets at regular intervals
    -- [TODO] Consider a better implementation
    --
    -- Dispatches on how many sends have happened so far:
    --   step 0    : send immediately as one TDS packet
    --   step 1, 2 : only buffer
    --   step 3    : send buffer + current bytes as one TDS packet, clear buffer
    --   later     : send the bytes unwrapped
    -- The hex values in the trailing comments look like the TLS record
    -- content type expected at each step -- TODO confirm.
    sendAll' bs = do
      step <- readMVar (getSendStep sock')
      case step of
        0 -> sendAll sock $ tdsHeader bs <> bs -- 0x16
        1 -> appendBuff bs -- 0x16
        2 -> appendBuff bs -- 0x14
        3 -> do
          buff <- readMVar (getSendBuff sock')
          let bs' = buff <> bs
          sendAll sock $ tdsHeader bs' <> bs' -- 0x16
          modifyMVar_ (getSendBuff sock') (\_ -> return mempty)
        _ -> sendAll sock bs -- 0x17
      modifyMVar_ (getSendStep sock') (return . (+ 1))

    -- [MEMO] This doesn't work
    -- [MEMO] Want to do this
    --
    -- Not referenced by the backend above; kept as a sketch of dispatching on
    -- the first byte of the outgoing data instead of counting sends.
    sendAll'' bs =
      case B.head bs of
        0x17 -> sendAll sock bs
        _ -> sendAll sock $ tdsHeader bs <> bs

    -- [MEMO] Remove TDS header
    -- [MEMO] Receive as much as possible from the source. and return only sink's requested size for each turn.
    -- [TODO] Consider a better implementation
    --
    -- Returns up to @len@ payload bytes. Buffered bytes are served first; if
    -- they are not enough, further TDS packets are read until @len@ bytes are
    -- available or the peer stops sending. Whatever is left over is kept in
    -- the receive buffer for the next call. A result shorter than @len@ means
    -- the stream ended.
    recvAll len = do
      buffered <- readMVar (getRecvBuff sock')
      available <- fillTo len buffered
      let (out, rest) = B.splitAt len available
      modifyMVar_ (getRecvBuff sock') (\_ -> return rest)
      return out

    -- Grow the accumulated payload by whole TDS packets until it holds at
    -- least @wanted@ bytes, or no further packet can be read.
    fillTo wanted acc
      | B.length acc >= wanted = return acc
      | otherwise = do
          mBody <- recvPacketBody
          case mBody of
            Nothing -> return acc
            Just body -> fillTo wanted (acc <> body)

    -- Read one TDS packet and return its payload with the header removed.
    -- Nothing means a complete header could not be read (connection ended).
    recvPacketBody = do
      header <- recvExactly headerSize
      if B.length header < headerSize
        then return Nothing
        else do
          let Header _ _ totalLen _ _ _ = decode (LB.fromStrict header)
              -- Subtract in Int so a bogus length below the header size
              -- yields a non-positive count instead of wrapping around.
              bodyLen = fromIntegral totalLen - headerSize
          body <- recvExactly bodyLen
          return (Just body)

    -- 'recv' may return fewer bytes than requested, so keep reading until
    -- @n@ bytes have arrived. An empty chunk is treated as end of stream, in
    -- which case the bytes collected so far are returned.
    recvExactly n = go n []
      where
        go remaining chunks
          | remaining <= 0 = return (B.concat (reverse chunks))
          | otherwise = do
              chunk <- recv sock remaining
              if B.null chunk
                then return (B.concat (reverse chunks))
                else go (remaining - B.length chunk) (chunk : chunks)


-- | Client-side TLS parameters for the given host name and CA store.
--
-- Starts from 'TLS.defaultParamsClient' (with an empty service identifier)
-- and overrides:
--
-- * supported versions: only 'TLS.TLS10'
-- * supported ciphers: 'ciphersuite_strong'
-- * CA store: the supplied store
-- * validation cache: a cache whose query callback always answers
--   'ValidationCachePass' and whose add callback does nothing
--
-- NOTE(review): an always-passing validation cache presumably means the
-- server certificate is never actually checked -- confirm this is intended.
-- NOTE(review): offering TLS 1.0 only may be rejected by servers that have
-- it disabled -- confirm against the targeted server versions.
getTlsParams :: String -> CertificateStore -> ClientParams
getTlsParams host store =
  baseParams
    { clientSupported = supported
    , clientShared = shared
    }
  where
    baseParams = TLS.defaultParamsClient host mempty

    supported =
      def
        { supportedVersions = [TLS.TLS10]
        , supportedCiphers = ciphersuite_strong
        }

    shared =
      def
        { sharedCAStore = store
        , sharedValidationCache = validateCache
        }

    validateCache = ValidationCache queryCache addToCache

    -- Answer every lookup as a cache hit.
    queryCache _ _ _ = return ValidationCachePass

    -- Never record anything.
    addToCache _ _ _ = return ()