{-# LANGUAGE CPP #-}
module Database.Tds.Transport (contextNew) where
import Data.Monoid((<>),mempty)
import Control.Applicative((<$>),(<*>))
import Network.Socket (Socket,close)
import Network.Socket.ByteString (recv,sendAll)
import qualified Data.ByteString as B
import qualified Data.ByteString.Lazy as LB
import Data.Binary (decode,encode)
import Data.Default.Class (def)
import qualified Network.TLS as TLS
import Network.TLS (ClientParams(..),Supported(..),Shared(..),ValidationCache(..),ValidationCacheResult(..))
import Network.TLS.Extra.Cipher (ciphersuite_strong)
import Data.X509.CertificateStore (CertificateStore(..))
import System.X509 (getSystemCertificateStore)
import Control.Concurrent(MVar(..),newMVar,readMVar,modifyMVar_)
import Database.Tds.Message.Header
#if !MIN_VERSION_tls(1,3,0)
import Crypto.Random(createEntropyPool,cprgCreate,SystemRNG(..))
#endif
-- | Create a TLS client context for @host@ on top of the connected socket
-- @sock@.
--
-- TLS traffic is routed through 'getBackend', which wraps it in TDS packets.
-- The system certificate store is loaded and passed to 'getTlsParams'.
contextNew :: Socket -> String -> IO TLS.Context
contextNew sock host = do
  params <- getTlsParams host <$> getSystemCertificateStore
  secure <- newSecureSocket sock
#if MIN_VERSION_tls(1,3,0)
  TLS.contextNew (getBackend secure) params
#else
  entropy <- createEntropyPool
  TLS.contextNew (getBackend secure) params (cprgCreate entropy :: SystemRNG)
#endif
-- | A socket together with the mutable state 'getBackend' uses to tunnel
-- TLS data through TDS packets.
data SecureSocket = SecureSocket
  { getSocket   :: Socket              -- ^ underlying connected socket
  , getSendBuff :: MVar B.ByteString   -- ^ outgoing bytes held back until they are written together
  , getSendStep :: MVar Int            -- ^ number of sends performed so far
  , getRecvBuff :: MVar B.ByteString   -- ^ received payload bytes not yet handed to the caller
  }
-- | Wrap @sock@ with an empty send buffer, a send counter of 0 and an empty
-- receive buffer.
newSecureSocket :: Socket -> IO SecureSocket
newSecureSocket sock = do
  pendingOut <- newMVar mempty
  sendCount  <- newMVar 0
  pendingIn  <- newMVar mempty
  return (SecureSocket sock pendingOut sendCount pendingIn)
-- | Build the 'TLS.Backend' used by 'contextNew'.
--
-- Outgoing handshake data is wrapped in TDS packets of type 0x12.
-- Every incoming read strips an 8-byte TDS header before the payload is
-- handed to the TLS layer.
getBackend :: SecureSocket -> TLS.Backend
getBackend sock' = TLS.Backend flush (close sock) sendAll' recvAll
  where
    sock = getSocket sock'
    flush = return ()

    -- 8-byte TDS header (type 0x12, status 1) for a packet carrying
    -- @payload@. The length field counts the header itself, hence the +8.
    header :: B.ByteString -> B.ByteString
    header payload =
      LB.toStrict $ encode $ Header 0x12 1 (fromIntegral $ B.length payload + 8) 0 0 0

    -- Dispatch on how many sends have happened so far (getSendStep):
    --   send 0     : written at once in its own TDS packet
    --   sends 1, 2 : appended to the send buffer
    --   send 3     : appended, then the whole buffer is written as one packet
    --   send 4+    : written raw, without a TDS header
    -- NOTE(review): presumably sends 0-3 are the client handshake flights
    -- (ClientHello, then key exchange, ChangeCipherSpec, Finished); confirm.
    sendAll' :: B.ByteString -> IO ()
    sendAll' bs = do
      step <- readMVar (getSendStep sock')
      case step of
        0 -> sendAll sock $ header bs <> bs
        1 -> appendBuff
        2 -> appendBuff
        3 -> do
          buff <- readMVar (getSendBuff sock')
          let bs' = buff <> bs
          sendAll sock $ header bs' <> bs'
          modifyMVar_ (getSendBuff sock') (\_ -> return mempty)
        _ -> sendAll sock bs
      modifyMVar_ (getSendStep sock') (return . (+ 1))
      where
        appendBuff = modifyMVar_ (getSendBuff sock') (return . (<> bs))

    -- Read exactly n bytes from the socket, looping because recv may
    -- return fewer bytes than requested. Fewer than n bytes are returned
    -- only when recv reports end of stream (an empty chunk).
    recvExact :: Int -> IO B.ByteString
    recvExact n = B.concat . reverse <$> go n []
      where
        go k acc
          | k <= 0 = return acc
          | otherwise = do
              chunk <- recv sock k
              if B.null chunk
                then return acc
                else go (k - B.length chunk) (chunk : acc)

    -- Read one TDS packet and return its payload.
    -- Returns Nothing when the stream ends before a full 8-byte header.
    recvPacket :: IO (Maybe B.ByteString)
    recvPacket = do
      hdr <- recvExact 8
      if B.length hdr < 8
        then return Nothing
        else do
          let Header _ _ totalLen _ _ _ = decode (LB.fromStrict hdr)
          Just <$> recvExact (fromIntegral totalLen - 8)

    -- Return exactly len payload bytes, pulling further TDS packets until
    -- enough are buffered. A TLS record may span several TDS packets, and
    -- the TLS layer treats a short read as a broken packet. Surplus bytes
    -- stay in the receive buffer for the next call. A short result is
    -- returned only if the stream ends first.
    recvAll :: Int -> IO B.ByteString
    recvAll len = do
      buffered <- readMVar (getRecvBuff sock')
      available <- fill buffered
      let (out, rest) = B.splitAt len available
      modifyMVar_ (getRecvBuff sock') (\_ -> return rest)
      return out
      where
        fill acc
          | B.length acc >= len = return acc
          | otherwise = recvPacket >>= maybe (return acc) (fill . (acc <>))
-- | TLS client parameters for the handshake with @host@.
--
-- Only TLS 1.0 is offered, restricted to 'ciphersuite_strong'. @store@ is
-- installed as the CA store. However, the validation-cache query always
-- answers 'ValidationCachePass', which the tls library treats as an
-- already-validated certificate, so the server certificate is not checked.
--
-- NOTE(review): this accepts any server certificate (no MITM protection).
-- It is presumably deliberate, to allow self-signed SQL Server
-- certificates; confirm.
getTlsParams :: String -> CertificateStore -> ClientParams
getTlsParams host store = base
  { clientSupported = supported
  , clientShared    = shared
  }
  where
    base = TLS.defaultParamsClient host mempty
    supported = def
      { supportedVersions = [TLS.TLS10]
      , supportedCiphers  = ciphersuite_strong
      }
    shared = def
      { sharedCAStore         = store
      , sharedValidationCache = ValidationCache alwaysPass storeNothing
      }
    alwaysPass _ _ _ = return ValidationCachePass
    storeNothing _ _ _ = return ()