module Database.Tds.Transport (contextNew) where
import Data.Monoid((<>))
import Network.Socket (Socket,close)
import Network.Socket.ByteString (recv,sendAll)
import qualified Data.ByteString as B
import qualified Data.ByteString.Lazy as LB
import Data.Binary (decode,encode)
import Data.Default.Class (def)
import qualified Network.TLS as TLS
import Network.TLS (ClientParams(..),Supported(..),Shared(..),ValidationCache(..),ValidationCacheResult(..))
import Network.TLS.Extra.Cipher (ciphersuite_default)
import Data.X509.CertificateStore (CertificateStore(..))
import System.X509 (getSystemCertificateStore)
import Control.Concurrent(MVar(..),newMVar,readMVar,modifyMVar_)
import Database.Tds.Message.Header
-- | Build a TLS client context for @host@ that runs over @sock@.
--
-- The peer is checked against the system certificate store (see
-- 'getTlsParams' for how validation is configured). All traffic goes through
-- a fresh 'SecureSocket' wrapper rather than the raw socket.
contextNew :: Socket -> TLS.HostName -> IO TLS.Context
contextNew sock host =
  getSystemCertificateStore >>= \store ->
    newSecureSocket sock >>= \backend ->
      TLS.contextNew backend (getTlsParams host store)
-- | A connected socket plus the mutable state the TLS backend instance needs.
--
-- That state is used to frame TLS data inside TDS packets.
data SecureSocket = SecureSocket
  { getSocket   :: Socket
    -- ^ The underlying network socket.
  , getSendBuff :: MVar B.ByteString
    -- ^ Outgoing data held back so that several sends can go out together
    -- as one TDS packet.
  , getSendStep :: MVar Int
    -- ^ How many times the backend's send function has been called so far.
  , getRecvBuff :: MVar B.ByteString
    -- ^ Payload bytes already read from the socket but not yet handed to TLS.
  }
-- | Wrap a connected socket for use as a TLS backend.
--
-- Both buffers start empty and the send counter starts at zero.
newSecureSocket :: Socket -> IO SecureSocket
newSecureSocket sock =
  SecureSocket sock <$> newMVar mempty <*> newMVar 0 <*> newMVar mempty
-- | TLS backend that tunnels data through TDS packets.
--
-- Outgoing:
--
-- * The first send goes out wrapped in its own TDS packet.
-- * Sends 1 and 2 are held back.
-- * Send 3 goes out together with the held-back data as one packet.
-- * Every later send is written to the socket raw.
--
-- All wrapped packets use type 0x12 (PRELOGIN in MS-TDS) and status 1
-- (end of message).
--
-- Incoming: every read expects TDS-framed packets. Packet payloads are
-- concatenated until the requested byte count is reached, and any surplus is
-- kept in 'getRecvBuff' for the next read.
instance TLS.HasBackend SecureSocket where
  initializeBackend _ = return ()
  getBackend sock' = TLS.Backend flush (close sock) sendAll' recvAll
    where
      sock = getSocket sock'

      -- Sends are written straight to the socket (or held in the send
      -- buffer, which 'sendAll'' releases itself), so there is nothing
      -- to flush here.
      flush = return ()

      -- Prefix a payload with an 8-byte TDS header. The header's length
      -- field counts the header itself, hence the + 8.
      withHeader bs =
        LB.toStrict (encode (Header 0x12 1 (B.length bs + 8) 0 0 0)) <> bs

      -- The number of earlier sends decides how this one is framed:
      --   0    : wrap and send immediately
      --   1, 2 : append to the send buffer, send nothing yet
      --   3    : wrap (buffer <> data) as a single packet and send it
      --   4+   : send raw, without a TDS header
      -- NOTE(review): presumably this mirrors the client TLS handshake
      -- flights; re-check if the TLS version or cipher suites change.
      sendAll' bs = do
        step <- readMVar (getSendStep sock')
        case step of
          0 -> sendAll sock (withHeader bs)
          1 -> holdBack
          2 -> holdBack
          3 -> do
            held <- readMVar (getSendBuff sock')
            sendAll sock (withHeader (held <> bs))
            modifyMVar_ (getSendBuff sock') (\_ -> return mempty)
          _ -> sendAll sock bs
        modifyMVar_ (getSendStep sock') (return . (+ 1))
        where
          holdBack = modifyMVar_ (getSendBuff sock') (return . (<> bs))

      -- Return exactly @len@ bytes of payload. Fewer bytes are returned
      -- only if the connection reaches EOF, which the TLS library reports
      -- as an EOF / partial-packet error. A TLS record may span several
      -- TDS packets, so more packets are read until enough bytes are
      -- available.
      recvAll len = do
        buffered <- readMVar (getRecvBuff sock')
        (out, rest) <- collect len [] buffered
        modifyMVar_ (getRecvBuff sock') (\_ -> return rest)
        return out

      -- @collect need acc avail@: take @need@ bytes from @avail@, reading
      -- further TDS packets when @avail@ runs short. @acc@ holds earlier
      -- chunks in reverse order. Returns (result, unconsumed remainder).
      collect need acc avail
        | B.length avail >= need = do
            let (now, rest) = B.splitAt need avail
            return (B.concat (reverse (now : acc)), rest)
        | otherwise = do
            next <- recvPacket
            case next of
              Nothing -> return (B.concat (reverse (avail : acc)), B.empty)
              Just payload ->
                collect (need - B.length avail) (avail : acc) payload

      -- Read one TDS packet and return its payload.
      -- Nothing means the header could not be read completely (EOF).
      recvPacket = do
        hdr <- recvExactly 8
        if B.length hdr < 8
          then return Nothing
          else do
            let Header _ _ totalLen _ _ _ = decode (LB.fromStrict hdr)
            Just <$> recvExactly (totalLen - 8)

      -- 'recv' may return fewer bytes than requested, so loop until @n@
      -- bytes have arrived. Stop early on EOF (an empty chunk).
      recvExactly n = B.concat <$> go n
        where
          go k
            | k <= 0 = return []
            | otherwise = do
                chunk <- recv sock k
                if B.null chunk
                  then return []
                  else (chunk :) <$> go (k - B.length chunk)
-- | Client parameters for the TLS session with @host@.
--
-- The settings are:
--
-- * Only TLS 1.0 is offered.
-- * The default cipher suites are used.
-- * Empty packets are disabled.
-- * The given certificate store is attached.
--
-- NOTE(review): the validation cache answers 'ValidationCachePass' for every
-- query and stores nothing. Per the tls / x509-validation docs, this appears
-- to accept any server certificate without chain validation. Confirm that
-- this is intentional, e.g. for self-signed server certificates.
--
-- NOTE(review): TLS 1.0 is deprecated. Before widening 'supportedVersions',
-- check that the fixed send-step framing in the 'SecureSocket' backend
-- still matches the handshake.
getTlsParams :: TLS.HostName -> CertificateStore -> ClientParams
getTlsParams host store =
  baseParams { clientSupported = supported, clientShared = shared }
  where
    baseParams = TLS.defaultParamsClient host mempty

    supported = def
      { supportedVersions    = [TLS.TLS10]
      , supportedCiphers     = ciphersuite_default
      , supportedEmptyPacket = False
      }

    shared = def
      { sharedCAStore         = store
      , sharedValidationCache = alwaysPass
      }

    alwaysPass =
      ValidationCache
        (\_ _ _ -> return ValidationCachePass)
        (\_ _ _ -> return ())