{-# LANGUAGE DeriveDataTypeable #-}
{-# LANGUAGE CPP #-}
module Database.MSSQLServer.Query (
sql
, ResultSet (..)
, Result (..)
, Row (..)
, Only (..)
, RowCount (..)
, ReturnStatus (..)
, rpc
, RpcResponseSet (..)
, RpcResponse (..)
, RpcOutputSet (..)
, RpcResultSet (..)
, RpcResult (..)
, RpcQuerySet (..)
, RpcQuery (..)
, RpcQueryId (..)
, StoredProcedure (..)
, RpcParamSet (..)
, RpcParam (..)
, RpcParamName
, bitVal
, tinyintVal
, smallintVal
, intVal
, bigintVal
, smallmoneyVal
, moneyVal
, smalldatetimeVal
, datetimeVal
, float24Val
, realVal
, float53Val
, doubleVal
, uniqueidentifierVal
, decimalVal
, numericVal
, charVal
, varcharVal
, textVal
, ncharVal
, nvarcharVal
, ntextVal
, binaryVal
, varbinaryVal
, imageVal
, bitRef
, tinyintRef
, smallintRef
, intRef
, bigintRef
, smallmoneyRef
, moneyRef
, smalldatetimeRef
, datetimeRef
, float24Ref
, realRef
, float53Ref
, doubleRef
, uniqueidentifierRef
, decimalRef
, numericRef
, charRef
, varcharRef
, textRef
, ncharRef
, nvarcharRef
, ntextRef
, binaryRef
, varbinaryRef
, imageRef
, bitDefRef
, tinyintDefRef
, smallintDefRef
, intDefRef
, bigintDefRef
, smallmoneyDefRef
, moneyDefRef
, smalldatetimeDefRef
, datetimeDefRef
, float24DefRef
, realDefRef
, float53DefRef
, doubleDefRef
, uniqueidentifierDefRef
, decimalDefRef
, numericDefRef
, charDefRef
, varcharDefRef
, textDefRef
, ncharDefRef
, nvarcharDefRef
, ntextDefRef
, binaryDefRef
, varbinaryDefRef
, imageDefRef
, withTransaction
, QueryError (..)
) where
import Data.Typeable(Typeable)
import Network.Socket (Socket)
import Network.Socket.ByteString (recv)
import Network.Socket.ByteString.Lazy (sendAll)
import qualified Data.ByteString as B
import qualified Data.Text as T
import qualified Data.Binary.Get as Get
import qualified Data.Binary.Put as Put
import Control.Exception (Exception(..),throwIO,onException,mask)
import Database.Tds.Message
import Database.MSSQLServer.Connection (Connection(..))
import Database.MSSQLServer.Query.Only
import Database.MSSQLServer.Query.Row
import Database.MSSQLServer.Query.ResultSet
import Database.MSSQLServer.Query.RpcResponseSet
import Database.MSSQLServer.Query.RpcQuerySet
import Database.MSSQLServer.Query.TokenStreamParser
#if MIN_VERSION_mtl(2,2,1)
import Control.Monad.Except
#else
import Control.Monad.Error
runExceptT = runErrorT
#endif
-- | Exception thrown by 'sql' when the response parser yields an 'Info'
-- (a @Left@ result) instead of a result set.  The offending 'Info' is
-- carried along so that callers catching the exception can inspect it.
data QueryError
  = QueryError !Info
  deriving (Show, Typeable)

-- | Relies on the default 'Exception' methods; this is what allows a
-- 'QueryError' to be raised with 'throwIO'.
instance Exception QueryError
-- | Send a SQL batch over the connection and decode the server's reply
-- into the requested 'ResultSet' type.
--
-- The query text is wrapped in a 'SqlBatch' client message and written to
-- the socket; one server message is then read back and its token streams
-- are handed to the result-set parser.
--
-- Failure modes:
--
-- * If the parser produces an 'Info' (a @Left@ result), it is rethrown as
--   a 'QueryError'.
--
-- * If the parser yields no result at all, the action fails (via 'fail')
--   with the message @"sql: failed to parse token streams"@.
sql :: ResultSet a => Connection -> T.Text -> IO a
sql (Connection sock ps) query = do
  sendAll sock (Put.runPut request)
  TokenStreams streams <- readMessage sock (Get.runGetIncremental getServerMessage)
  -- Only the first parse result is considered; its leftover input is ignored.
  case parse batchParser streams of
    (outcome, _) : _ -> either (throwIO . QueryError) return outcome
    []               -> fail "sql: failed to parse token streams"
  where
    -- Serialised client message carrying the batch text.
    request = putClientMessage ps (CMSqlBatch (SqlBatch query))

    -- The result-set parser with its error layer unwrapped into an 'Either'.
    batchParser :: (ResultSet a) => Parser (Either Info a)
    batchParser = runExceptT resultSetParser
-- | Issue a set of RPC queries over the connection and decode the server's
-- reply into the requested 'RpcResponseSet' type.
--
-- The queries are converted with 'toRpcRequest', wrapped in a
-- 'CMRpcRequest' client message and written to the socket; one server
-- message is then read back and its token streams are parsed with
-- 'rpcResponseSetParser'.
--
-- If the parser yields no result, the action fails (via 'fail') with the
-- message @"rpc: failed to parse token streams"@.
rpc :: (RpcQuerySet a, RpcResponseSet b) => Connection -> a -> IO b
rpc (Connection sock ps) queries = do
  sendAll sock (Put.runPut request)
  TokenStreams streams <- readMessage sock (Get.runGetIncremental getServerMessage)
  -- Only the first parse result is considered; its leftover input is ignored.
  case parse rpcResponseSetParser streams of
    (response, _) : _ -> return response
    []                -> fail "rpc: failed to parse token streams"
  where
    -- Serialised client message carrying the RPC request.
    request = putClientMessage ps (CMRpcRequest (toRpcRequest queries))
-- | Run an action inside a database transaction.
--
-- @BEGIN TRANSACTION@ is sent first, then the action runs.  If the action
-- throws, @ROLLBACK TRANSACTION@ is sent and the exception is rethrown;
-- otherwise @COMMIT TRANSACTION@ is sent and the action's result returned.
--
-- The bookkeeping runs under 'mask' so that an asynchronous exception
-- cannot slip in after the transaction has been opened but before the
-- rollback handler is installed (or between the action finishing and the
-- commit).  The user action itself is run with the caller's original
-- masking state via @restore@, so it remains interruptible exactly as
-- before.
--
-- Note that an exception raised by the commit itself does not trigger a
-- rollback, and an exception raised by the rollback replaces the one that
-- caused it.
withTransaction :: Connection -> IO a -> IO a
withTransaction conn act =
  mask $ \restore -> do
    begin
    r <- restore act `onException` rollback
    commit
    return r
  where
    begin, commit, rollback :: IO ()
    begin    = sql conn $ T.pack "BEGIN TRANSACTION"
    commit   = sql conn $ T.pack "COMMIT TRANSACTION"
    rollback = sql conn $ T.pack "ROLLBACK TRANSACTION"
-- | Read from the socket, feeding the incremental decoder until it has
-- produced a complete value.
--
-- Data is pulled in chunks of at most 512 bytes.  The loop ends in one of
-- three ways:
--
-- * the decoder finishes ('Get.Done') and its value is returned; any
--   unconsumed input is discarded;
--
-- * the decoder reports 'Get.Fail', in which case the action fails (via
--   'fail') with the decoder's message instead of reading forever;
--
-- * 'recv' returns an empty chunk, which is treated as the peer having
--   closed the connection: the decoder is told that input has ended, and
--   if it still cannot finish the action fails.
readMessage :: Socket -> Get.Decoder a -> IO a
readMessage sock decoder = do
  bs <- recv sock 512
  let eof = B.null bs
      -- Pushing an empty chunk would leave a 'Get.Partial' decoder
      -- unchanged and spin forever, so signal end-of-input instead.
      decoder'
        | eof       = Get.pushEndOfInput decoder
        | otherwise = Get.pushChunk decoder bs
  case decoder' of
    Get.Done _ _ msg -> return msg
    Get.Fail _ _ err -> fail $ "readMessage: failed to decode server message: " ++ err
    Get.Partial _
      -- Not expected after end-of-input, but handled so the loop
      -- can never spin on a closed socket.
      | eof       -> fail "readMessage: connection closed before a complete message was received"
      | otherwise -> readMessage sock decoder'