-- | TLS transport for TDS connections.
--
-- The TLS handshake records are not written raw to the TCP socket: they are
-- carried as the payload of TDS packets (packet type @0x12@). This module
-- wraps a 'Socket' in a 'TLS.Backend' that adds the 8-byte TDS header on the
-- way out and strips it on the way in, so the @tls@ package can drive the
-- handshake without knowing about the TDS framing.
module Database.Tds.Transport (contextNew) where

import Control.Concurrent (MVar (..), modifyMVar, modifyMVar_, newMVar, readMVar)
import Data.Binary (decode, encode)
import qualified Data.ByteString as B
import qualified Data.ByteString.Lazy as LB
import Data.Default.Class (def)
import Data.Monoid ((<>))
import Data.X509.CertificateStore (CertificateStore (..))
import Network.Socket (Socket, close)
import Network.Socket.ByteString (recv, sendAll)
import Network.TLS
  ( ClientParams (..)
  , Shared (..)
  , Supported (..)
  , ValidationCache (..)
  , ValidationCacheResult (..)
  )
import qualified Network.TLS as TLS
import Network.TLS.Extra.Cipher (ciphersuite_default)
import System.X509 (getSystemCertificateStore)

import Database.Tds.Message.Header

-- | Create a TLS client context for @host@ on an already-connected socket.
--
-- Handshake traffic is framed in TDS packets (see 'SecureSocket'). The system
-- certificate store is passed as the CA store, but see 'getTlsParams':
-- certificate validation is currently bypassed.
contextNew :: Socket -> TLS.HostName -> IO TLS.Context
contextNew sock host = do
  certStore <- getSystemCertificateStore
  sock' <- newSecureSocket sock
  TLS.contextNew sock' (getTlsParams host certStore)

-- | A socket plus the mutable state needed to translate between TLS records
-- and TDS-framed packets.
data SecureSocket = SecureSocket
  { getSocket :: Socket
    -- ^ underlying connected socket
  , getSendBuff :: MVar B.ByteString
    -- ^ outgoing handshake bytes held back so they go out in one TDS packet
  , getSendStep :: MVar Int
    -- ^ number of send calls made so far (capped at 4, see 'getBackend')
  , getRecvBuff :: MVar B.ByteString
    -- ^ TDS payload bytes received but not yet handed to the TLS layer
  }

-- | Wrap a socket with empty send/receive buffers and the send step at 0.
newSecureSocket :: Socket -> IO SecureSocket
newSecureSocket sock =
  SecureSocket sock <$> newMVar mempty <*> newMVar 0 <*> newMVar mempty

instance TLS.HasBackend SecureSocket where
  initializeBackend _ = return ()

  getBackend sock' = TLS.Backend flush (close sock) sendTds recvTds
    where
      sock = getSocket sock'

      flush = return ()

      -- Outgoing data is framed by call count, not by inspecting the record.
      -- The expected sequence (record content type per call) is:
      --   step 0   : handshake (0x16)        -> sent in its own TDS packet
      --   step 1   : handshake (0x16)        -> held back
      --   step 2   : ChangeCipherSpec (0x14) -> held back
      --   step 3   : handshake (0x16)        -> sent with the held-back bytes
      --                                         in one TDS packet
      --   step 4.. : application data (0x17) -> sent raw, no TDS header
      -- NOTE(review): this relies on the tls package emitting exactly this
      -- sequence of send calls; a different handshake shape (other TLS
      -- versions, client certificates, ...) would break the framing.
      -- Framing each non-0x17 record individually was tried and did not work.
      sendTds bs = modifyMVar_ (getSendStep sock') $ \step -> do
        case step of
          0 -> sendAll sock (withTdsHeader bs)
          1 -> holdBack bs
          2 -> holdBack bs
          3 -> modifyMVar_ (getSendBuff sock') $ \held -> do
            sendAll sock (withTdsHeader (held <> bs))
            return mempty
          _ -> sendAll sock bs
        -- All steps from 4 on behave identically, so stop counting there
        -- instead of growing the counter for the life of the connection.
        return $! min 4 (step + 1)

      -- Append bytes to the outgoing buffer without sending them yet.
      holdBack bs = modifyMVar_ (getSendBuff sock') (return . (<> bs))

      -- Prefix a payload with an 8-byte TDS header: type 0x12, status 1,
      -- and a length that includes the header itself.
      withTdsHeader payload =
        LB.toStrict (encode (Header 0x12 1 (B.length payload + 8) 0 0 0))
          <> payload

      -- Return exactly @len@ bytes of de-framed payload to the TLS layer,
      -- reading further TDS packets as needed. A shorter result means the
      -- peer closed the connection.
      -- NOTE(review): unlike 'sendTds', this always expects TDS framing, so it
      -- presumably must only be used during the handshake — confirm.
      recvTds len = collect takeBuffered len

      -- Take up to @n@ buffered payload bytes, refilling the buffer from the
      -- next TDS packet first if it is empty. An empty result means EOF.
      takeBuffered n = modifyMVar (getRecvBuff sock') $ \buff -> do
        avail <- if B.null buff then recvTdsPayload else return buff
        let (front, rest) = B.splitAt n avail
        return (rest, front)

      -- Read one TDS packet and return its payload with the header stripped.
      -- Packets with an empty payload are skipped. An empty result means EOF.
      recvTdsPayload = do
        hdr <- recvExactly 8
        if B.length hdr < 8
          then return B.empty
          else do
            let Header _ _ totalLen _ _ _ = decode (LB.fromStrict hdr)
                bodyLen = totalLen - 8
            if bodyLen <= 0
              then recvTdsPayload
              else recvExactly bodyLen

      -- 'recv' may return fewer bytes than requested; keep reading until
      -- @n@ bytes have arrived or the peer closes the connection.
      recvExactly n = collect (recv sock) n

      -- Call @pull k@, where @k@ is the number of bytes still missing, until
      -- @n@ bytes are gathered; stop early when @pull@ returns an empty chunk
      -- (EOF). Returns everything gathered, in order.
      collect pull n = go n []
        where
          go k acc
            | k <= 0 = done
            | otherwise = do
                chunk <- pull k
                if B.null chunk
                  then done
                  else go (k - B.length chunk) (chunk : acc)
            where
              done = return (B.concat (reverse acc))

-- | Client parameters used for the TDS handshake.
--
-- Only TLS 1.0 is offered, and empty-packet insertion is disabled.
--
-- NOTE(review): the validation cache below answers 'ValidationCachePass' for
-- every certificate, so the server certificate is never actually checked
-- against @store@. This leaves the connection open to man-in-the-middle
-- attacks; it is presumably deliberate (e.g. for self-signed server
-- certificates) — confirm before relying on it.
getTlsParams :: TLS.HostName -> CertificateStore -> ClientParams
getTlsParams host store =
  (TLS.defaultParamsClient host mempty)
    { clientSupported =
        def
          { supportedVersions = [TLS.TLS10]
          , supportedCiphers = ciphersuite_default
          , supportedEmptyPacket = False
          }
    , clientShared =
        def
          { sharedCAStore = store
          , sharedValidationCache = acceptAll
          }
    }
  where
    acceptAll =
      ValidationCache
        (\_ _ _ -> return ValidationCachePass)
        (\_ _ _ -> return ())