{-# LANGUAGE OverloadedStrings #-}
module Database.MSSQLServer.Connection ( ConnectInfo(..)
, Connection(..)
, connect
, connectWithoutEncription
, close
, ProtocolError(..)
, AuthError(..)
) where
import Control.Exception (Exception(..),throwIO,onException)
import Control.Monad (when)
import Data.Monoid ((<>))
import Data.Word (Word8)

import Data.Binary (Binary(..),encode)
import qualified Data.Binary.Get as Get
import qualified Data.Binary.Put as Put
import qualified Data.ByteString as B
import qualified Data.ByteString.Lazy as LB
import qualified Data.Text as T
import qualified Data.Text.Encoding as T
import qualified Data.Text.IO as T
import Network.HostName (getHostName)
import qualified Network.Socket as Socket
import Network.Socket (AddrInfo(..),SocketType(..),Socket(..))
import Network.Socket.ByteString (recv)
import Network.Socket.ByteString.Lazy (sendAll)
import qualified Network.TLS as TLS

import Database.Tds.Message
import Database.Tds.Transport (contextNew)
-- | Thrown when the server's replies do not match what this client expects
-- while setting up a connection (for example a missing or unsupported
-- PRELOGIN option, or an unusable login acknowledgement). The 'String'
-- names the check that failed.
data ProtocolError
  = ProtocolError String
  deriving Show

instance Exception ProtocolError
-- | Thrown when the server answers the login request with an error token
-- instead of a login acknowledgement. Carries that error token's 'Info'.
data AuthError
  = AuthError !Info
  deriving Show

instance Exception AuthError
-- | Parameters needed to open a 'Connection'.
data ConnectInfo = ConnectInfo
  { connectHost     :: String  -- ^ server host; used for address resolution, TLS setup and the LOGIN7 server name
  , connectPort     :: String  -- ^ port (service name) used for address resolution
  , connectDatabase :: String  -- ^ database name sent in the LOGIN7 packet
  , connectUser     :: String  -- ^ user name sent in the LOGIN7 packet
  , connectPassword :: String  -- ^ password sent in the LOGIN7 packet
  }
-- | An open connection to the server: the TCP socket produced by 'connect'
-- or 'connectWithoutEncription'.
newtype Connection
  = Connection Socket
-- | Open a connection to SQL Server, using TLS for the login exchange.
--
-- The client advertises encryption mode @0x00@ in PRELOGIN. It then requires
-- the server's reply to carry an encryption option equal to @0x00@ and a MARS
-- option equal to @0@. Next, a TLS session is set up on the socket, the LOGIN7
-- packet is sent through it, and the login response is read from the raw
-- socket.
--
-- Throws 'ProtocolError' if a required PRELOGIN option is missing or has an
-- unsupported value, or if the login acknowledgement is unusable. Throws
-- 'AuthError' if the server answers the login with an error token. If any
-- step after the socket is opened throws, the socket is closed before the
-- exception propagates.
connect :: ConnectInfo -> IO Connection
connect ci@(ConnectInfo host port _ _ _) = do
  addr <- resolve host port
  sock <- connect' addr
  -- Without this the socket would leak on every failed login or negotiation.
  negotiate sock `onException` Socket.close sock
  return $ Connection sock
  where
    negotiate sock = do
      Prelogin plResOpts <- performPrelogin sock 0x00
      -- Take the first occurrence; a duplicated option must not turn into an
      -- opaque pattern-match failure.
      modeEnc <- requireOption "PLOEncription" [m | PLOEncription m <- plResOpts]
      modeMars <- requireOption "PLOMars" [m | PLOMars m <- plResOpts]
      when (modeEnc /= 0x00) $
        throwIO $ ProtocolError "connect: Server reported unsupported encription mode"
      when (modeMars /= 0) $
        throwIO $ ProtocolError "connect: Server reported unsupported mars mode"
      login7 <- newLogin7 ci
      tlsContext <- contextNew sock host
      TLS.handshake tlsContext
      TLS.sendData tlsContext $ encode $ CMLogin7 login7
      -- The login response is read from the raw socket, not via tlsContext.
      ServerMessage tss <- readMessage sock $ Get.runGetIncremental get
      validLoginAck login7 tss

    -- Return the first value of a PRELOGIN option, or fail if it is absent.
    requireOption :: String -> [a] -> IO a
    requireOption name opts = case opts of
      x : _ -> return x
      []    -> throwIO $ ProtocolError $ "connect: " <> name <> " is necessary"
-- | Open a connection to SQL Server without TLS.
--
-- The client advertises encryption mode @0x02@ in PRELOGIN. It then requires
-- the server's reply to carry an encryption option equal to @0x02@ and a MARS
-- option equal to @0@. After that, the LOGIN7 packet (including the password)
-- is written directly to the socket and the response is read back from it.
--
-- Throws 'ProtocolError' if a required PRELOGIN option is missing or has an
-- unsupported value, or if the login acknowledgement is unusable. Throws
-- 'AuthError' if the server answers the login with an error token. If any
-- step after the socket is opened throws, the socket is closed before the
-- exception propagates.
connectWithoutEncription :: ConnectInfo -> IO Connection
connectWithoutEncription ci@(ConnectInfo host port _ _ _) = do
  addr <- resolve host port
  sock <- connect' addr
  -- Without this the socket would leak on every failed login or negotiation.
  negotiate sock `onException` Socket.close sock
  return $ Connection sock
  where
    negotiate sock = do
      Prelogin plResOpts <- performPrelogin sock 0x02
      -- Take the first occurrence; a duplicated option must not turn into an
      -- opaque pattern-match failure.
      modeEnc <- requireOption "PLOEncription" [m | PLOEncription m <- plResOpts]
      modeMars <- requireOption "PLOMars" [m | PLOMars m <- plResOpts]
      when (modeEnc /= 0x02) $
        throwIO $ ProtocolError "connectWithoutEncription: Server reported unsupported encription mode"
      when (modeMars /= 0) $
        throwIO $ ProtocolError "connectWithoutEncription: Server reported unsupported mars mode"
      login7 <- newLogin7 ci
      sendAll sock $ encode $ CMLogin7 login7
      ServerMessage tss <- readMessage sock $ Get.runGetIncremental get
      validLoginAck login7 tss

    -- Return the first value of a PRELOGIN option, or fail if it is absent.
    requireOption :: String -> [a] -> IO a
    requireOption name opts = case opts of
      x : _ -> return x
      []    -> throwIO $ ProtocolError $ "connectWithoutEncription: " <> name <> " is necessary"
-- | Close the socket underlying a 'Connection'.
close :: Connection -> IO ()
close conn =
  case conn of
    Connection s -> Socket.close s
-- | Send a PRELOGIN request advertising encryption mode @enc@ and return the
-- server's PRELOGIN reply.
--
-- Apart from the encryption mode, the request carries fixed values: version
-- 9.0.0.0, instance option "MSSQLServer", thread id 1000 and MARS 0.
performPrelogin :: Socket -> Word8 -> IO Prelogin
performPrelogin sock enc = do
  sendAll sock . encode $ CMPrelogin request
  ServerMessage reply <- readMessage sock (Get.runGetIncremental get)
  pure reply
  where
    request =
      Prelogin
        [ PLOVersion 9 0 0 0
        , PLOEncription enc
        , PLOInstopt "MSSQLServer"
        , PLOThreadid (Just 1000)
        , PLOMars 0
        ]
-- | Build the LOGIN7 request for the given connection parameters.
--
-- Starts from 'defaultLogin7' and sets:
--
-- * the client host name, read from the local machine;
-- * a fixed client PID of 1 and an all-zero six-byte MAC address;
-- * the application name "mssql-simple";
-- * the server, user, password and database taken from the 'ConnectInfo'.
--
-- The port is not used.
newLogin7 :: ConnectInfo -> IO Login7
newLogin7 (ConnectInfo srvHost _ db usr pwd) = do
  clientHost <- getHostName
  pure defaultLogin7
    { l7ClientPID      = 1
    , l7ClientMacAddr  = B.replicate 6 0x00
    , l7ClientHostName = T.pack clientHost
    , l7AppName        = "mssql-simple"
    , l7ServerName     = T.pack srvHost
    , l7UserName       = T.pack usr
    , l7Password       = T.pack pwd
    , l7Database       = T.pack db
    }
-- | Check the server's response to a LOGIN7 request.
--
-- Succeeds when the token stream contains a LOGINACK whose TDS version equals
-- the version requested in @login7@. Otherwise it throws:
--
-- * 'AuthError' with the first error token's 'Info', if there is no LOGINACK
--   but at least one ERROR token;
-- * 'ProtocolError', if there is neither a LOGINACK nor an ERROR token;
-- * 'ProtocolError', if the LOGINACK reports a different TDS version.
--
-- Only the first LOGINACK or ERROR token is looked at. A response with several
-- error tokens therefore still surfaces as 'AuthError' rather than as a
-- pattern-match failure.
validLoginAck :: Login7 -> TokenStreams -> IO ()
validLoginAck login7 (TokenStreams tokens) =
  case [ver | TSLoginAck _ ver _ _ <- tokens] of
    ver : _ ->
      when (l7TdsVersion login7 /= ver) $
        throwIO $ ProtocolError "validLoginAck: Server reported unsupported tds version"
    [] ->
      case [info | TSError info <- tokens] of
        info : _ -> throwIO $ AuthError info
        []       -> throwIO $ ProtocolError "validLoginAck: TSError is necessary"
-- | Print a one-line, human-readable description of an ENVCHANGE token to
-- stdout.
--
-- Text-valued changes are decoded from UTF-16LE. Other values are printed
-- with 'show'. Change types without a dedicated label, such as 14 or anything
-- above 20, are printed generically instead of crashing on a non-exhaustive
-- @case@. Tokens that are not ENVCHANGE are ignored, so the function can be
-- mapped over a whole token stream.
printEnvChange :: TokenStream -> IO ()
printEnvChange (TSEnvChange t o n) = do
  putStr "TSEnvChange: "
  case t of
    1  -> T.putStr $ "Database: " <> oldToNew
    2  -> T.putStr $ "Language: " <> oldToNew
    3  -> T.putStr $ "Charset: " <> oldToNew
    4  -> T.putStr $ "PacketSize: " <> oldToNew
    5  -> T.putStr $ "DSLID: " <> T.decodeUtf16LE n
    6  -> T.putStr $ "DSCFlags: " <> T.decodeUtf16LE n
    7  -> putStr $ "Collation: " <> show o <> " -> " <> show n
    8  -> putStr $ "BeginTran: " <> show n
    9  -> putStr $ "CommitTran: " <> show o
    10 -> putStr $ "RollbackTran: " <> show o
    11 -> putStr $ "EnlistDTCTran: " <> show o
    12 -> putStr $ "DefectTran: " <> show n
    13 -> T.putStr $ "MirrorPartner: " <> T.decodeUtf16LE n
    15 -> putStr $ "PromoteTran: " <> show n
    16 -> putStr $ "TranManAddr: " <> show n
    17 -> putStr $ "TranEnded: " <> show o
    18 -> putStr "ResetAck: "
    19 -> T.putStr $ "SendsBackInfo: " <> T.decodeUtf16LE n
    20 -> putStr $ "Routing: " <> show n
    -- Previously a crash ("Non-exhaustive patterns") for unlisted types.
    _  -> putStr $ "Unknown(" <> show t <> "): " <> show o <> " -> " <> show n
  putStrLn mempty
  where
    -- "old -> new", both decoded as UTF-16LE text.
    oldToNew = T.decodeUtf16LE o <> " -> " <> T.decodeUtf16LE n
printEnvChange _ = pure ()
-- | True when the PRELOGIN option is the encryption option.
isPLOEncription :: PreloginOption -> Bool
isPLOEncription opt =
  case opt of
    PLOEncription{} -> True
    _               -> False
-- | True when the PRELOGIN option is the MARS option.
isPLOMars :: PreloginOption -> Bool
isPLOMars opt =
  case opt of
    PLOMars{} -> True
    _         -> False
-- | Resolve a host and port (service name) to the first TCP stream address
-- returned by 'Socket.getAddrInfo'.
--
-- An empty result throws an 'IOError' that names the host and port, instead
-- of a generic pattern-match failure.
resolve :: String -> String -> IO AddrInfo
resolve host port = do
  let hints = Socket.defaultHints { addrSocketType = Stream }
  addrs <- Socket.getAddrInfo (Just hints) (Just host) (Just port)
  case addrs of
    addr : _ -> return addr
    []       -> ioError $ userError $ "resolve: no address found for " <> host <> ":" <> port
-- | Create a socket matching the given address and connect it.
--
-- If the connect attempt throws, the new socket is closed before the
-- exception is re-thrown, so no file descriptor is leaked.
connect' :: AddrInfo -> IO Socket
connect' addr = do
  sock <- Socket.socket (addrFamily addr) (addrSocketType addr) (addrProtocol addr)
  Socket.connect sock (addrAddress addr) `onException` Socket.close sock
  return sock
-- | Feed bytes from the socket, in reads of up to 512 bytes, into an
-- incremental decoder until it produces a value.
--
-- Any bytes the decoder leaves unconsumed after finishing are discarded.
--
-- Throws 'ProtocolError' if:
--
-- * the decoder fails, which previously re-pushed into a failed decoder and
--   blocked in 'recv' forever; or
-- * the peer closes the connection ('recv' returns an empty chunk) before a
--   complete message arrives, which previously pushed empty chunks and spun
--   forever.
readMessage :: Socket -> Get.Decoder a -> IO a
readMessage sock decoder = do
  chunk <- recv sock 512
  let eof = B.null chunk
      decoder'
        | eof       = Get.pushEndOfInput decoder
        | otherwise = Get.pushChunk decoder chunk
  case decoder' of
    Get.Done _ _ msg -> return msg
    Get.Fail _ _ err ->
      throwIO $ ProtocolError $ "readMessage: failed to decode server message: " <> err
    Get.Partial _
      | eof       -> throwIO $ ProtocolError "readMessage: connection closed before a complete message was received"
      | otherwise -> readMessage sock decoder'