{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE DeriveDataTypeable #-}
module Database.MSSQLServer.Connection
(
ConnectInfo(..)
, defaultConnectInfo
, Connection(..)
, connect
, connectWithoutEncryption
, close
, ProtocolError(..)
, AuthError(..)
) where
import Control.Exception (Exception(..),bracketOnError,onException,throwIO)
import Control.Monad (when)
import Data.Int (Int32)
import Data.Monoid ((<>),mempty)
import Data.Typeable(Typeable)
import Data.Word (Word8,Word32)
import qualified Data.Binary.Get as Get
import qualified Data.Binary.Put as Put
import qualified Data.ByteString as B
import qualified Data.Text as T
import qualified Data.Text.Encoding as T
import qualified Data.Text.IO as T
import Network.HostName (getHostName)
import qualified Network.Socket as Socket
import Network.Socket (AddrInfo(..),SocketType(..),Socket(..))
import Network.Socket.ByteString (recv)
import Network.Socket.ByteString.Lazy (sendAll)
import qualified Network.TLS as TLS
import Database.Tds.Message
import Database.Tds.Transport (contextNew)
-- | Exception carrying a human-readable description of a handshake problem.
--
-- In this module it is thrown by 'connect' when the prelogin response lacks
-- an encryption or MARS option, or reports a mode different from the one
-- requested, and by 'validLoginAck' when the login response contains neither
-- a login acknowledgement nor an error token, or acknowledges a different
-- TDS version.
data ProtocolError = ProtocolError String
  deriving (Show,Typeable)
instance Exception ProtocolError
-- | Exception thrown by 'validLoginAck' when the login response contains no
-- login acknowledgement token. It carries the 'Info' payload of the first
-- error token found in that response.
data AuthError = AuthError !Info
  deriving (Show,Typeable)
instance Exception AuthError
-- | Settings consumed by 'connect'.
--
-- 'defaultConnectInfo' provides a starting value; override the fields you
-- need with record-update syntax.
data ConnectInfo = ConnectInfo { connectHost :: !String -- ^ Host handed to address resolution and to the TLS context.
                               , connectPort :: !String -- ^ Service name / port handed to address resolution.
                               , connectDatabase :: !String -- ^ Copied into the login record's @l7Database@.
                               , connectUser :: !String -- ^ Copied into the login record's @l7UserName@.
                               , connectPassword :: !String -- ^ Copied into the login record's @l7Password@.
                               , connectEncryption :: !Word8 -- ^ Encryption mode sent in the prelogin request; 'connect' handles @0x00@ (login sent through TLS) and @0x02@ (login sent on the plain socket).
                               , connectPacketSize :: !Word32 -- ^ Packet size passed to @putClientMessage@ and stored in the resulting 'Connection'.
                               , connectOptionFlags1 :: !Word8 -- ^ Copied into the login record's @l7OptionFlags1@.
                               , connectOptionFlags2 :: !Word8 -- ^ Copied into the login record's @l7OptionFlags2@.
                               , connectOptionFlags3 :: !Word8 -- ^ Copied into the login record's @l7OptionFlags3@.
                               , connectTypeFlags :: !Word8 -- ^ Copied into the login record's @l7TypeFlags@.
                               , connectTimeZone :: !Int32 -- ^ Copied into the login record's @l7TimeZone@.
                               , connectCollation :: !Collation32 -- ^ Copied into the login record's @l7Collation@.
                               , connectLanguage :: !String -- ^ Copied into the login record's @l7Language@.
                               , connectAppName :: !String -- ^ Copied into the login record's @l7AppName@.
                               , connectServerName :: !String -- ^ Copied into the login record's @l7ServerName@.
                               }
-- | Baseline 'ConnectInfo': empty host and port, encryption mode @0x00@,
-- and every other field taken from 'defaultLogin7'.
defaultConnectInfo :: ConnectInfo
defaultConnectInfo =
  ConnectInfo
    { connectHost = mempty
    , connectPort = mempty
    , connectDatabase = unpacked l7Database
    , connectUser = unpacked l7UserName
    , connectPassword = unpacked l7Password
    , connectEncryption = 0x00
    , connectPacketSize = l7PacketSize base
    , connectOptionFlags1 = l7OptionFlags1 base
    , connectOptionFlags2 = l7OptionFlags2 base
    , connectOptionFlags3 = l7OptionFlags3 base
    , connectTypeFlags = l7TypeFlags base
    , connectTimeZone = l7TimeZone base
    , connectCollation = l7Collation base
    , connectLanguage = unpacked l7Language
    , connectAppName = unpacked l7AppName
    , connectServerName = unpacked l7ServerName
    }
  where
    -- The login record whose defaults are mirrored here.
    base = defaultLogin7
    -- Read a textual field of the default login record as a 'String'.
    unpacked field = T.unpack (field base)
-- | An established connection as returned by 'connect': the connected socket
-- together with the packet size ('connectPacketSize') used during the handshake.
data Connection = Connection Socket Word32
-- | Open a TCP connection to the server described by the 'ConnectInfo',
-- exchange prelogin messages and log in.
--
-- Supported values of 'connectEncryption':
--
-- * @0x00@ - a TLS handshake is performed and the login message is sent
--   through the TLS context; the login response is then read from the plain
--   socket.
-- * @0x02@ - the login message is sent on the plain socket.
--
-- Throws 'ProtocolError' when the prelogin response lacks an encryption or
-- MARS option, when the server reports a mode different from the requested
-- one, or when an encryption mode other than the two above is requested.
-- Throws 'AuthError' (via 'validLoginAck') when the login is rejected.
--
-- If any step after the TCP connect fails, the socket is closed before the
-- exception propagates.
connect :: ConnectInfo -> IO Connection
connect ci = do
  addr <- resolve host port
  sock <- connect' addr
  -- From here on the socket is ours: close it if a later step throws,
  -- otherwise a failed handshake would leak the descriptor.
  login sock `onException` Socket.close sock
  return $ Connection sock ps
  where
    host = connectHost ci
    port = connectPort ci
    encrypt = connectEncryption ci
    ps = connectPacketSize ci

    -- Prelogin negotiation followed by the login exchange on an open socket.
    login :: Socket -> IO ()
    login sock = do
      Prelogin plResOpts <- performPrelogin sock ps encrypt
      modeEnc <- case filter isPLOEncryption plResOpts of
        PLOEncryption mode:_ -> return mode
        _ -> throwIO $ ProtocolError "connect: PLOEncryption is necessary"
      modeMars <- case filter isPLOMars plResOpts of
        PLOMars mode:_ -> return mode
        _ -> throwIO $ ProtocolError "connect: PLOMars is necessary"
      when (modeEnc /= encrypt) $ throwIO $ ProtocolError "connect: Server reported unsupported encryption mode"
      when (modeMars /= 0) $ throwIO $ ProtocolError "connect: Server reported unsupported mars mode"
      login7 <- newLogin7 ci
      tss <- case encrypt of
        0x00 -> do
          tlsContext <- contextNew sock host
          TLS.handshake tlsContext
          TLS.sendData tlsContext $ Put.runPut $ putClientMessage ps $ CMLogin7 login7
          readMessage sock $ Get.runGetIncremental getServerMessage
        0x02 -> do
          sendAll sock $ Put.runPut $ putClientMessage ps $ CMLogin7 login7
          readMessage sock $ Get.runGetIncremental getServerMessage
        -- Previously an unmatched mode ended in a pattern-match failure.
        _ -> throwIO $ ProtocolError "connect: Unsupported encryption mode requested"
      validLoginAck login7 tss
-- | Same as 'connect', but with 'connectEncryption' forced to @0x02@, the
-- mode in which 'connect' sends the login message on the plain socket.
connectWithoutEncryption :: ConnectInfo -> IO Connection
connectWithoutEncryption = connect . plainMode
  where
    plainMode info = info { connectEncryption = 0x02 }
-- | Close the socket held by a 'Connection'.
close :: Connection -> IO ()
close conn =
  case conn of
    Connection handle _ -> Socket.close handle
-- | Send the client prelogin message on the socket and decode the server's
-- prelogin reply.
--
-- The request advertises version 8.0.341.0, the given encryption mode,
-- instance option @MSSQLServer@, thread id 1000 and MARS mode 0. The packet
-- size is passed to 'putClientMessage' when serialising the request.
performPrelogin :: Socket -> Word32 -> Word8 -> IO Prelogin
performPrelogin sock ps enc = do
  sendAll sock request
  readMessage sock (Get.runGetIncremental getServerMessage)
  where
    request = Put.runPut . putClientMessage ps . CMPrelogin $ Prelogin options
    options =
      [ PLOVersion 8 0 341 0
      , PLOEncryption enc
      , PLOInstopt "MSSQLServer"
      , PLOThreadid (Just 1000)
      , PLOMars 0
      ]
-- | Build the login record for a 'ConnectInfo'.
--
-- Starts from 'defaultLogin7', overrides the flag, time-zone, collation and
-- textual fields with the values from the 'ConnectInfo', and fills in fixed
-- client identification (program version 1, interface name @mssql-simple@,
-- process id 1, an all-zero MAC address). The client host name is looked up
-- with 'getHostName', which is why the result lives in 'IO'.
newLogin7 :: ConnectInfo -> IO Login7
newLogin7 ci = withHostName <$> getHostName
  where
    withHostName hostname =
      defaultLogin7
        { l7ClientProgVer = 1
        , l7OptionFlags1 = connectOptionFlags1 ci
        , l7OptionFlags2 = connectOptionFlags2 ci
        , l7OptionFlags3 = connectOptionFlags3 ci
        , l7TypeFlags = connectTypeFlags ci
        , l7TimeZone = connectTimeZone ci
        , l7Collation = connectCollation ci
        , l7CltIntName = T.pack "mssql-simple"
        , l7Language = T.pack (connectLanguage ci)
        , l7ClientPID = 1
        , l7ClientMacAddr = B.pack (replicate 6 0x00)
        , l7ClientHostName = T.pack hostname
        , l7AppName = T.pack (connectAppName ci)
        , l7ServerName = T.pack (connectServerName ci)
        , l7UserName = T.pack (connectUser ci)
        , l7Password = T.pack (connectPassword ci)
        , l7Database = T.pack (connectDatabase ci)
        }
-- | Check the token streams returned in response to a login message.
--
-- * With at least one login acknowledgement, the TDS version of the first
--   one must equal 'tdsVersion'; otherwise a 'ProtocolError' is thrown.
-- * With no acknowledgement, an 'AuthError' carrying the first error token's
--   payload is thrown, or a 'ProtocolError' when there is no error token.
--
-- The 'Login7' argument is not inspected.
validLoginAck :: Login7 -> TokenStreams -> IO ()
validLoginAck _login7 (TokenStreams tokens) =
  case [version | TSLoginAck _ version _ _ <- tokens] of
    ackVersion:_ ->
      when (tdsVersion /= ackVersion) $
        throwIO $ ProtocolError "validLoginAck: Server reported unsupported tds version"
    [] ->
      case [info | TSError info <- tokens] of
        info:_ -> throwIO $ AuthError info
        [] -> throwIO $ ProtocolError "validLoginAck: TSError is necessary"
-- | Debugging helper: print a one-line description of an environment-change
-- token to standard output.
--
-- Depending on the change type, the old and/or new value is printed either
-- decoded as UTF-16LE text or via 'show' on the raw bytes. Change types with
-- no dedicated branch are printed generically, and tokens that are not
-- environment changes are ignored, so the function is total.
printEnvChange :: TokenStream -> IO ()
printEnvChange (TSEnvChange t o n) = do
  putStr "TSEnvChange: "
  case t of
    1 -> textBoth "Database: "
    2 -> textBoth "Language: "
    3 -> textBoth "Charset: "
    4 -> textBoth "PacketSize: "
    5 -> textNew "DSLID: "
    6 -> textNew "DSCFlags: "
    7 -> putStr $ "Collaction: " <> show o <> " -> " <> show n
    8 -> rawNew "BeginTran: "
    9 -> rawOld "CommitTran: "
    10 -> rawOld "RollbackTran: "
    11 -> rawOld "EnlistDTCTran: "
    12 -> rawNew "DefactTran: "
    13 -> textNew "MirrorPartner: "
    15 -> rawNew "PromoteTran: "
    16 -> rawNew "TranManAddr: "
    17 -> rawOld "TranEndedr: "
    18 -> putStr $ "ResetAck: "
    19 -> textNew "SendsBackInfo: "
    20 -> rawNew "Routing: "
    -- Unlisted type codes used to abort with a pattern-match failure.
    _ -> putStr $ "Unknown(" <> show t <> "): " <> show o <> " -> " <> show n
  putStrLn mempty
  where
    -- Old and new value, both decoded as UTF-16LE text.
    textBoth :: T.Text -> IO ()
    textBoth label = T.putStr $ label <> T.decodeUtf16LE o <> " -> " <> T.decodeUtf16LE n
    -- New value only, decoded as UTF-16LE text.
    textNew :: T.Text -> IO ()
    textNew label = T.putStr $ label <> T.decodeUtf16LE n
    -- Old value only, shown as raw bytes.
    rawOld :: String -> IO ()
    rawOld label = putStr $ label <> show o
    -- New value only, shown as raw bytes.
    rawNew :: String -> IO ()
    rawNew label = putStr $ label <> show n
-- Tokens other than environment changes carry nothing to print.
printEnvChange _ = return ()
-- | 'True' exactly for prelogin options built with the @PLOEncryption@
-- constructor.
isPLOEncryption :: PreloginOption -> Bool
isPLOEncryption option =
  case option of
    PLOEncryption{} -> True
    _ -> False
-- | 'True' exactly for prelogin options built with the @PLOMars@
-- constructor.
isPLOMars :: PreloginOption -> Bool
isPLOMars option =
  case option of
    PLOMars{} -> True
    _ -> False
-- | Resolve a host and service name to the first stream-socket address
-- returned by 'Socket.getAddrInfo'.
--
-- Raises an 'IOError' naming the host and port if the lookup yields no
-- address at all.
resolve :: String -> String -> IO AddrInfo
resolve host port = do
  let hints = Socket.defaultHints { addrSocketType = Stream }
  addrs <- Socket.getAddrInfo (Just hints) (Just host) (Just port)
  case addrs of
    addr:_ -> return addr
    -- Stay within IOError (as the former refutable bind did) but say what failed.
    [] -> ioError $ userError $ "resolve: no address found for " ++ host ++ ":" ++ port
-- | Create a socket matching the given address info and connect it to that
-- address.
--
-- If connecting throws, the freshly created socket is closed before the
-- exception propagates, so a failed attempt does not leak a descriptor.
connect' :: AddrInfo -> IO Socket
connect' addr =
  bracketOnError
    (Socket.socket (addrFamily addr) (addrSocketType addr) (addrProtocol addr))
    Socket.close
    (\sock -> do
      Socket.connect sock $ addrAddress addr
      return sock)
-- | Read from the socket in chunks of up to 512 bytes, feeding an
-- incremental decoder until it produces a value.
--
-- Throws 'ProtocolError' when the decoder fails, or when the peer closes the
-- connection (signalled by an empty read) before a complete message has been
-- decoded. Without these checks a failed decoder or a closed connection
-- would make this function loop forever.
--
-- Input left over after the decoded value is discarded.
readMessage :: Socket -> Get.Decoder a -> IO a
readMessage sock decoder = do
  bs <- recv sock 512
  if B.null bs
    -- An empty read means end of stream; tell the decoder so it can finish
    -- or fail instead of waiting for more input.
    then atEnd (Get.pushEndOfInput decoder)
    else step (Get.pushChunk decoder bs)
  where
    step (Get.Done _ _ msg) = return msg
    step (Get.Fail _ _ err) = throwIO $ ProtocolError $ "readMessage: " ++ err
    step decoder' = readMessage sock decoder'

    atEnd (Get.Done _ _ msg) = return msg
    atEnd (Get.Fail _ _ err) = throwIO $ ProtocolError $ "readMessage: " ++ err
    atEnd _ = throwIO $ ProtocolError "readMessage: connection closed before a complete message was received"