{-# LANGUAGE OverloadedStrings #-} module Database.MSSQLServer.Connection ( ConnectInfo(..) , Connection(..) , connect , connectWithoutEncription , close , ProtocolError(..) , AuthError(..) ) where import qualified Network.Socket as Socket import Network.Socket (AddrInfo(..),SocketType(..),Socket(..)) import Network.Socket.ByteString (recv) import Network.Socket.ByteString.Lazy (sendAll) import Data.Monoid ((<>)) import qualified Data.ByteString as B import qualified Data.ByteString.Lazy as LB import qualified Data.Text as T import qualified Data.Text.IO as T import qualified Data.Text.Encoding as T import Data.Binary (Binary(..),encode) import qualified Data.Binary.Put as Put import qualified Data.Binary.Get as Get import Control.Monad (when) import Control.Exception (Exception(..),throwIO) import qualified Network.TLS as TLS import Network.BSD (getHostName) import Database.Tds.Message import Database.Tds.Transport (contextNew) import Data.Word (Word8) data ProtocolError = ProtocolError String deriving (Show) instance Exception ProtocolError data AuthError = AuthError !Info deriving (Show) instance Exception AuthError data ConnectInfo = ConnectInfo { connectHost :: String , connectPort :: String , connectDatabase :: String , connectUser :: String , connectPassword :: String } newtype Connection = Connection Socket connect :: ConnectInfo -> IO Connection connect ci@(ConnectInfo host port database user pass) = do addr <- resolve host port sock <- connect' addr Prelogin plResOpts <- performPrelogin sock 0x00 -- ENCRYPT_OFF (Encrypt login packet only) [PLOEncription modeEnc] <- case filter isPLOEncription plResOpts of [] -> throwIO $ ProtocolError "connect: PLOEncription is necessary" xs -> return xs [PLOMars modeMars] <- case filter isPLOMars plResOpts of [] -> throwIO $ ProtocolError "connect: PLOMars is necessary" xs -> return xs when (modeEnc/=0x00) $ throwIO $ ProtocolError "connect: Server reported unsupported encription mode" when (modeMars/=0) $ throwIO $ ProtocolError "connect: Server reported unsupported mars mode" login7 <- newLogin7 ci --- --- TLS handshake --- tlsContext <- contextNew sock host TLS.handshake tlsContext --- --- Login with encripted packet --- TLS.sendData tlsContext $ encode $ CMLogin7 login7 ServerMessage tss <- readMessage sock $ Get.runGetIncremental get --- --- Verify Ack --- validLoginAck login7 tss return $ Connection sock connectWithoutEncription :: ConnectInfo -> IO Connection connectWithoutEncription ci@(ConnectInfo host port database user pass) = do addr <- resolve host port sock <- connect' addr Prelogin plResOpts <- performPrelogin sock 0x02 -- ENCRYPT_NOT_SUP (No encription) [PLOEncription modeEnc] <- case filter isPLOEncription plResOpts of [] -> throwIO $ ProtocolError "connectWithoutEncription: PLOEncription is necessary" xs -> return xs [PLOMars modeMars] <- case filter isPLOMars plResOpts of [] -> throwIO $ ProtocolError "connectWithoutEncription: PLOMars is necessary" xs -> return xs when (modeEnc/=0x02) $ throwIO $ ProtocolError "connectWithoutEncription: Server reported unsupported encription mode" when (modeMars/=0) $ throwIO $ ProtocolError "connectWithoutEncription: Server reported unsupported mars mode" login7 <- newLogin7 ci --- --- Login without encripted packet --- sendAll sock $ encode $ CMLogin7 login7 ServerMessage tss <- readMessage sock $ Get.runGetIncremental get --- --- Verify Ack --- validLoginAck login7 tss return $ Connection sock close :: Connection -> IO () close (Connection sock) = Socket.close sock performPrelogin :: Socket -> Word8 -> IO Prelogin performPrelogin sock enc = do -- https://docs.microsoft.com/en-us/openspecs/windows_protocols/ms-tds/60f56408-0188-4cd5-8b90-25c6f2423868 -- -- Prelogin -- -- [TODO] Threadid support -- [TODO] Mars support let clientPrelogin = Prelogin [ PLOVersion 9 0 0 0 , PLOEncription enc , PLOInstopt "MSSQLServer" , PLOThreadid (Just 1000) -- [TODO] , PLOMars 0 -- [TODO] ] sendAll sock $ encode $ CMPrelogin clientPrelogin ServerMessage serverPrelogin <- readMessage sock $ Get.runGetIncremental get return serverPrelogin newLogin7 :: ConnectInfo -> IO Login7 newLogin7 (ConnectInfo host port database user pass) = do --- --- Login7 --- -- [TODO] Improve default params -- [TODO] process ID support -- [TODO] MAC address support hostname <- getHostName let login7 = defaultLogin7 { l7ClientPID = 1 -- [TODO] , l7ClientMacAddr = B.pack [0x00,0x00,0x00,0x00,0x00,0x00] -- [TODO] , l7ClientHostName = (T.pack hostname) , l7AppName = "mssql-simple" -- [TODO] more nice name , l7ServerName = (T.pack host) , l7UserName = (T.pack user) , l7Password = (T.pack pass) , l7Database = (T.pack database) } return login7 validLoginAck :: Login7 -> TokenStreams -> IO () validLoginAck login7 (TokenStreams loginResTokenStreams) = do let loginAcks = filter isTSLoginAck loginResTokenStreams when (null loginAcks) $ do [TSError info] <- case filter isTSError loginResTokenStreams of [] -> throwIO $ ProtocolError "validLoginAck: TSError is necessary" xs -> return xs throwIO $ AuthError info let [TSLoginAck _ tdsVersion _ _] = loginAcks when (l7TdsVersion login7 /= tdsVersion) $ throwIO $ ProtocolError "validLoginAck: Server reported unsupported tds version" return () where isTSLoginAck :: TokenStream -> Bool isTSLoginAck (TSLoginAck{}) = True isTSLoginAck _ = False isTSError :: TokenStream -> Bool isTSError (TSError{}) = True isTSError _ = False printEnvChange :: TokenStream -> IO () printEnvChange (TSEnvChange t o n) = do putStr "TSEnvChange: " case t of 1 -> T.putStr $ "Database: " <> T.decodeUtf16LE o <> " -> " <> T.decodeUtf16LE n 2 -> T.putStr $ "Language: " <> T.decodeUtf16LE o <> " -> " <> T.decodeUtf16LE n 3 -> T.putStr $ "Charset: " <> T.decodeUtf16LE o <> " -> " <> T.decodeUtf16LE n 4 -> T.putStr $ "PacketSize: " <> T.decodeUtf16LE o <> " -> " <> T.decodeUtf16LE n 5 -> T.putStr $ "DSLID: " <> T.decodeUtf16LE n 6 -> T.putStr $ "DSCFlags: " <> T.decodeUtf16LE n 7 -> putStr $ "Collaction: " <> show o <> " -> " <> show n 8 -> putStr $ "BeginTran: " <> show n 9 -> putStr $ "CommitTran: " <> show o 10 -> putStr $ "RollbackTran: " <> show o 11 -> putStr $ "EnlistDTCTran: " <> show o 12 -> putStr $ "DefactTran: " <> show n 13 -> T.putStr $ "MirrorPartner: " <> T.decodeUtf16LE n 15 -> putStr $ "PromoteTran: " <> show n 16 -> putStr $ "TranManAddr: " <> show n 17 -> putStr $ "TranEndedr: " <> show o 18 -> putStr $ "ResetAck: " 19 -> T.putStr $ "SendsBackInfo: " <> T.decodeUtf16LE n 20 -> putStr $ "Routing: " <> show n putStrLn mempty isPLOEncription :: PreloginOption -> Bool isPLOEncription (PLOEncription{}) = True isPLOEncription _ = False isPLOMars :: PreloginOption -> Bool isPLOMars (PLOMars{}) = True isPLOMars _ = False resolve host port = do let hints = Socket.defaultHints { addrSocketType = Stream } addr:_ <- Socket.getAddrInfo (Just hints) (Just host) (Just port) return addr connect' addr = do sock <- Socket.socket (addrFamily addr) (addrSocketType addr) (addrProtocol addr) Socket.connect sock $ addrAddress addr return sock readMessage :: Socket -> Get.Decoder a -> IO a readMessage sock decoder = do bs <- recv sock 512 -- [TODO] optimize case Get.pushChunk decoder bs of Get.Done _ _ msg -> return msg decoder' -> readMessage sock decoder'