{-# LANGUAGE DeriveDataTypeable #-}
module Database.MSSQLServer.Query (
sql
, ResultSet (..)
, Result (..)
, Row (..)
, Only (..)
, rpc
, RpcResponseSet (..)
, RpcResponse (..)
, RpcOutputSet (..)
, RpcResultSet (..)
, RpcResult (..)
, RpcQuerySet (..)
, RpcQuery (..)
, RpcQueryId (..)
, StoredProcedure (..)
, RpcParamSet (..)
, RpcParam (..)
, RpcParamName
, nvarcharVal
, ntextVal
, varcharVal
, textVal
, withTransaction
, QueryError (..)
) where
import Control.Applicative((<$>))
import Control.Exception (Exception(..),throwIO,onException)
import qualified Control.Exception as E
import Control.Monad (when)
import Data.Monoid ((<>))
import Data.Typeable(Typeable)

import qualified Data.Binary.Get as Get
import qualified Data.Binary.Put as Put
import qualified Data.ByteString as B
import qualified Data.ByteString.Lazy as LB
import qualified Data.Text as T
import Network.Socket (Socket)
import Network.Socket.ByteString (recv)
import Network.Socket.ByteString.Lazy (sendAll)

import Database.Tds.Message
import Database.MSSQLServer.Connection
import Database.MSSQLServer.Query.Only
import Database.MSSQLServer.Query.Row
import Database.MSSQLServer.Query.ResultSet
import Database.MSSQLServer.Query.RpcResponseSet
import Database.MSSQLServer.Query.RpcQuerySet
import Database.MSSQLServer.Query.TokenStreamParser
-- | Exception thrown by 'sql' and 'rpc' when the server's response contains
-- an error token ('TSError'). It carries that token's 'Info' payload.
data QueryError
  = QueryError !Info
  deriving (Show, Typeable)

instance Exception QueryError
-- | Send a SQL batch over the connection and decode the reply as a
-- 'ResultSet'.
--
-- If the reply contains any error token, the first one is rethrown as a
-- 'QueryError'. Otherwise the token streams are run through
-- 'resultSetParser' and the first successful parse is returned. Calls 'error'
-- if the parser yields no result.
sql :: ResultSet a => Connection -> T.Text -> IO a
sql (Connection sock ps) query = do
  let request = Put.runPut . putClientMessage ps . CMSqlBatch $ SqlBatch query
  sendAll sock request
  TokenStreams streams <- readMessage sock (Get.runGetIncremental getServerMessage)
  case [info | TSError info <- streams] of
    info : _ -> throwIO (QueryError info)
    [] -> case parse resultSetParser streams of
      (result, _) : _ -> return result
      [] -> error "sql: failed to parse token streams"
-- | Send one or more RPC (stored procedure) requests and decode the reply as
-- an 'RpcResponseSet'.
--
-- If the reply contains any error token, the first one is rethrown as a
-- 'QueryError'. Otherwise the token streams are run through
-- 'rpcResponseSetParser' and the first successful parse is returned. Calls
-- 'error' if the parser yields no result.
rpc :: (RpcQuerySet a, RpcResponseSet b) => Connection -> a -> IO b
rpc (Connection sock ps) queries = do
  sendAll sock . Put.runPut . putClientMessage ps . CMRpcRequest $ toRpcRequest queries
  TokenStreams streams <- readMessage sock (Get.runGetIncremental getServerMessage)
  case [info | TSError info <- streams] of
    info : _ -> throwIO (QueryError info)
    [] -> case parse rpcResponseSetParser streams of
      (response, _) : _ -> return response
      [] -> error "rpc: failed to parse token streams"
-- | Build an @NVARCHAR@ ('TINVarChar') RPC parameter from 'T.Text', using
-- an all-zero collation.
--
-- The declared type length is the value's size in bytes when encoded as
-- UTF-16. That is 2 bytes per character in the Basic Multilingual Plane and
-- 4 bytes (a surrogate pair) per character above U+FFFF. Counting
-- 'T.length' code points and doubling would under-declare the size of text
-- containing such characters.
--
-- NOTE(review): the length is narrowed with 'fromIntegral' into whatever
-- integer type 'TINVarChar' expects, so very long values may wrap. Confirm
-- against the type-info length range.
nvarcharVal :: RpcParamName -> T.Text -> RpcParam T.Text
nvarcharVal name ts =
  RpcParamVal name (TINVarChar (fromIntegral byteLen) (Collation 0x00000000 0x00)) ts
  where
    byteLen :: Int
    byteLen = T.foldl' (\n c -> n + utf16Bytes c) 0 ts
    -- Characters outside the BMP need a surrogate pair (two UTF-16 units).
    utf16Bytes :: Char -> Int
    utf16Bytes c = if c > '\xFFFF' then 4 else 2
-- | Build an @NTEXT@ ('TINText') RPC parameter from 'T.Text', using an
-- all-zero collation.
--
-- The declared type length is the value's size in bytes when encoded as
-- UTF-16. That is 2 bytes per character in the Basic Multilingual Plane and
-- 4 bytes (a surrogate pair) per character above U+FFFF. Counting
-- 'T.length' code points and doubling would under-declare the size of text
-- containing such characters.
ntextVal :: RpcParamName -> T.Text -> RpcParam T.Text
ntextVal name ts =
  RpcParamVal name (TINText (fromIntegral byteLen) (Collation 0x00000000 0x00)) ts
  where
    byteLen :: Int
    byteLen = T.foldl' (\n c -> n + utf16Bytes c) 0 ts
    -- Characters outside the BMP need a surrogate pair (two UTF-16 units).
    utf16Bytes :: Char -> Int
    utf16Bytes c = if c > '\xFFFF' then 4 else 2
-- | Build a @VARCHAR@ ('TIBigVarChar') RPC parameter from a raw
-- 'B.ByteString'.
--
-- The declared type length is the byte count of the value, and the collation
-- is all zeros. The bytes are passed through unchanged; no character-set
-- conversion happens here.
varcharVal :: RpcParamName -> B.ByteString -> RpcParam B.ByteString
varcharVal name bs = RpcParamVal name typeInfo bs
  where
    typeInfo = TIBigVarChar (fromIntegral byteCount) (Collation 0x00000000 0x00)
    byteCount = B.length bs
-- | Build a @TEXT@ ('TIText') RPC parameter from a raw 'B.ByteString'.
--
-- The declared type length is the byte count of the value, and the collation
-- is all zeros. The bytes are passed through unchanged; no character-set
-- conversion happens here.
textVal :: RpcParamName -> B.ByteString -> RpcParam B.ByteString
textVal name bs = RpcParamVal name typeInfo bs
  where
    typeInfo = TIText (fromIntegral byteCount) (Collation 0x00000000 0x00)
    byteCount = B.length bs
-- | Run an action inside a server-side transaction on the given connection.
--
-- The sequence is:
--
-- 1. Send @BEGIN TRANSACTION@ and run the action.
-- 2. If the action returns normally, send @COMMIT TRANSACTION@.
-- 3. If the action throws, send @ROLLBACK TRANSACTION@ and rethrow the
--    original exception.
--
-- Asynchronous exceptions are masked everywhere except while the action
-- runs. This stops an exception arriving between @BEGIN@ and the
-- installation of the rollback handler from leaving the transaction open.
--
-- A 'QueryError' or 'E.IOException' raised by the rollback itself is
-- discarded, so it cannot replace the exception that caused the rollback.
-- For example, the server may already have rolled the transaction back,
-- which makes @ROLLBACK@ fail.
--
-- If @COMMIT TRANSACTION@ fails, its exception propagates and no rollback is
-- attempted.
withTransaction :: Connection -> IO a -> IO a
withTransaction conn act = E.mask $ \restore -> do
  begin
  r <- restore act `onException` safeRollback
  commit
  return r
  where
    begin = sql conn $ T.pack "BEGIN TRANSACTION" :: IO ()
    commit = sql conn $ T.pack "COMMIT TRANSACTION" :: IO ()
    rollback = sql conn $ T.pack "ROLLBACK TRANSACTION" :: IO ()
    -- Best-effort rollback: the caller should see the action's exception,
    -- not a secondary failure from the cleanup.
    safeRollback = rollback `E.catch` ignoreQueryError `E.catch` ignoreIOError
    ignoreQueryError :: QueryError -> IO ()
    ignoreQueryError _ = return ()
    ignoreIOError :: E.IOException -> IO ()
    ignoreIOError _ = return ()
-- | True exactly when the token stream is an error token ('TSError').
isTSError :: TokenStream -> Bool
isTSError stream = case stream of
  TSError{} -> True
  _ -> False
-- | Receive bytes from the socket in chunks of up to 512 bytes and feed them
-- to an incremental decoder until it produces a complete value.
--
-- Any bytes left over after the value is decoded are discarded.
--
-- Throws an 'IOError' (via 'ioError') in two cases:
--
-- * the decoder rejects the input ('Get.Fail');
-- * the peer closes the connection ('recv' returns an empty chunk) before a
--   complete value is available.
--
-- Without these checks, a decode failure or a closed socket would make this
-- function loop forever.
readMessage :: Socket -> Get.Decoder a -> IO a
readMessage sock decoder = do
  bs <- recv sock 512
  if B.null bs
    -- Empty chunk: the connection is closed. Let the decoder finish with
    -- what it already has, or fail.
    then finish (Get.pushEndOfInput decoder)
    else case Get.pushChunk decoder bs of
      Get.Done _ _ msg -> return msg
      Get.Fail _ offset err -> decodeError offset err
      decoder' -> readMessage sock decoder'
  where
    finish (Get.Done _ _ msg) = return msg
    finish (Get.Fail _ offset err) = decodeError offset err
    finish Get.Partial{} =
      ioError $ userError "readMessage: connection closed before a complete message was received"
    decodeError offset err =
      ioError $ userError $
        "readMessage: failed to decode server message at byte offset "
          ++ show offset ++ ": " ++ err