{-# LANGUAGE CPP #-}
module Database.Tds.Transport (contextNew) where
import Data.Monoid((<>),mempty)
import Control.Applicative((<$>),(<*>))
import Network.Socket (Socket,close)
import Network.Socket.ByteString (recv,sendAll)
import qualified Data.ByteString as B
import qualified Data.ByteString.Lazy as LB
import Data.Binary (decode,encode)
import Data.Default.Class (def)
import qualified Network.TLS as TLS
import Network.TLS (ClientParams(..),Supported(..),Shared(..),ValidationCache(..),ValidationCacheResult(..))
import Network.TLS.Extra.Cipher (ciphersuite_strong)
import Data.X509.CertificateStore (CertificateStore(..))
import System.X509 (getSystemCertificateStore)
import Control.Concurrent(MVar(..),newMVar,readMVar,modifyMVar_)
import Database.Tds.Message.Header
#if !MIN_VERSION_tls(1,3,0)
import Crypto.Random(createEntropyPool,cprgCreate,SystemRNG(..))
#endif
-- | Create a TLS client context over @sock@ for server @host@.
--
-- The context talks to the socket through 'getBackend', which wraps the
-- handshake in TDS packets. Its parameters come from 'getTlsParams', which is
-- given the system certificate store.
contextNew :: Socket -> String -> IO TLS.Context
contextNew sock host = do
  store <- getSystemCertificateStore
  backend <- getBackend <$> newSecureSocket sock
  let params = getTlsParams host store
#if MIN_VERSION_tls(1,3,0)
  TLS.contextNew backend params
#else
  -- Older tls versions need an explicit random generator.
  pool <- createEntropyPool
  TLS.contextNew backend params (cprgCreate pool :: SystemRNG)
#endif
-- | A socket bundled with the mutable state used by the TDS-wrapping TLS
-- backend.
data SecureSocket = SecureSocket
  { getSocket   :: Socket
    -- ^ The underlying connection.
  , getSendBuff :: MVar B.ByteString
    -- ^ Outgoing handshake bytes held back so they can be sent together in
    -- one packet.
  , getSendStep :: MVar Int
    -- ^ How many sends the TLS layer has performed so far.
  , getRecvBuff :: MVar B.ByteString
    -- ^ Received payload not yet handed to the TLS layer.
  }
-- | Wrap a socket with fresh backend state: both buffers start empty and the
-- send counter starts at zero.
newSecureSocket :: Socket -> IO SecureSocket
newSecureSocket sock = do
  sendBuff <- newMVar mempty
  sendStep <- newMVar 0
  recvBuff <- newMVar mempty
  return SecureSocket
    { getSocket   = sock
    , getSendBuff = sendBuff
    , getSendStep = sendStep
    , getRecvBuff = recvBuff
    }
-- | Build a 'TLS.Backend' that carries TLS data inside TDS packets.
--
-- Outgoing data is handled according to how many sends have happened so far
-- (counted in 'getSendStep'):
--
--   * send 0: wrapped in its own TDS packet (presumably the ClientHello);
--   * sends 1 and 2: appended to 'getSendBuff' without sending;
--   * send 3: appended to the buffered bytes, and the whole lot is sent as one
--     TDS packet, after which the buffer is cleared;
--   * every later send: written to the socket as-is, with no TDS header.
--
-- Incoming data is always read as TDS packets. Each 8-byte header is stripped
-- and the payload is buffered in 'getRecvBuff' until the TLS layer asks for
-- it.
getBackend :: SecureSocket -> TLS.Backend
getBackend sock' = TLS.Backend flush (close sock) sendAll' recvAll
  where
    sock :: Socket
    sock = getSocket sock'

    -- Every send either writes directly with 'sendAll' or buffers on purpose,
    -- so there is nothing to flush.
    flush :: IO ()
    flush = return ()

    -- | Prefix a payload with a single TDS packet header.
    -- The length field includes the 8 header bytes. Type 0x12 and status 1
    -- are, per MS-TDS, PRELOGIN and end-of-message.
    -- NOTE(review): confirm against Database.Tds.Message.Header.
    withHeader :: B.ByteString -> B.ByteString
    withHeader payload = hdr <> payload
      where
        hdr = LB.toStrict . encode $
          Header 0x12 1 (fromIntegral $ B.length payload + 8) 0 0 0

    sendAll' :: B.ByteString -> IO ()
    sendAll' bs = do
      step <- readMVar (getSendStep sock')
      case step of
        0 -> sendAll sock (withHeader bs)
        1 -> appendBuff
        2 -> appendBuff
        3 -> do
          buff <- readMVar (getSendBuff sock')
          sendAll sock (withHeader (buff <> bs))
          modifyMVar_ (getSendBuff sock') (\_ -> return mempty)
        _ -> sendAll sock bs
      modifyMVar_ (getSendStep sock') (return . (+ 1))
      where
        appendBuff = modifyMVar_ (getSendBuff sock') (return . (<> bs))

    -- | Return exactly @len@ bytes of TLS data, reading further TDS packets
    -- until enough payload is buffered. Fewer bytes are returned only if the
    -- peer closes the connection.
    -- A TLS record may be split across several TDS packets, and the TLS
    -- layer expects the full requested length, so the buffer may need to be
    -- refilled more than once.
    recvAll :: Int -> IO B.ByteString
    recvAll len = do
      buffered <- readMVar (getRecvBuff sock')
      available <- fillTo buffered
      let (bs, rest) = B.splitAt len available
      modifyMVar_ (getRecvBuff sock') (\_ -> return rest)
      return bs
      where
        fillTo acc
          | B.length acc >= len = return acc
          | otherwise = do
              packet <- recvPacket
              case packet of
                Nothing   -> return acc
                Just body -> fillTo (acc <> body)

    -- | Read one TDS packet and return its payload with the header stripped.
    -- Returns 'Nothing' if the connection closes before a full header
    -- arrives.
    recvPacket :: IO (Maybe B.ByteString)
    recvPacket = do
      hdr <- recvExact 8
      if B.length hdr < 8
        then return Nothing
        else do
          let Header _ _ totalLen _ _ _ = decode (LB.fromStrict hdr)
              -- Subtract in Int so a malformed length below 8 cannot wrap
              -- around in the header's Length type.
              bodyLen = max 0 (fromIntegral totalLen - 8)
          Just <$> recvExact bodyLen

    -- | Read exactly @n@ bytes from the socket. A shorter result means the
    -- peer closed the connection (recv returned an empty chunk).
    -- A single 'recv' may return fewer bytes than requested, hence the loop.
    recvExact :: Int -> IO B.ByteString
    recvExact n = B.concat <$> go n
      where
        go remaining
          | remaining <= 0 = return []
          | otherwise = do
              chunk <- recv sock remaining
              if B.null chunk
                then return []
                else (chunk :) <$> go (remaining - B.length chunk)
-- | TLS client parameters for connecting to @host@.
--
-- The cipher suites are 'ciphersuite_strong', and @store@ is installed as
-- the CA store.
--
-- TLS 1.2, 1.1 and 1.0 are offered. Servers that have disabled TLS 1.0 can
-- still be reached, and older servers can still negotiate down to 1.0.
-- TLS 1.3 is deliberately not offered: 'getBackend' groups outgoing
-- handshake data by send count, which presumes the TLS 1.0-1.2 handshake
-- flow.
--
-- NOTE(review): the validation cache answers 'ValidationCachePass' for every
-- certificate. Per the tls documentation this skips certificate validation,
-- so the server certificate is not really checked against @store@. This
-- accepts self-signed servers but also allows man-in-the-middle attacks.
-- It is kept as-is to preserve existing behaviour.
getTlsParams :: String -> CertificateStore -> ClientParams
getTlsParams host store =
  (TLS.defaultParamsClient host mempty)
    { clientSupported = def
        { supportedVersions = [TLS.TLS12, TLS.TLS11, TLS.TLS10]
        , supportedCiphers  = ciphersuite_strong
        }
    , clientShared = def
        { sharedCAStore         = store
        , sharedValidationCache = validateCache
        }
    }
  where
    -- Query callback: report every certificate as already validated.
    -- Add callback: remember nothing.
    validateCache :: ValidationCache
    validateCache =
      ValidationCache (\_ _ _ -> return ValidationCachePass) (\_ _ _ -> return ())