{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE RecordWildCards #-}

-- | Low-level message layer: builds the binary messages written to a
-- 'Connection' and parses the replies read back from it.
--
-- The two extensions above are required by code in this module:
-- @RecordWildCards@ for the @Con{..}@ patterns and constructions, and
-- @OverloadedStrings@ for the string literal used as a 'UString' in 'pwHash'.
module Database.MongoDB.Internal.Protocol (
  -- * Connection
  Connection, mkConnection,
  send, call,
  -- * Messages
  FullCollection,
  -- ** Notices (sent without reading a reply)
  Notice(..), UpdateOption(..), DeleteOption(..), CursorId,
  -- ** Requests (answered by a 'Reply')
  Request(..), QueryOption(..),
  -- ** Replies
  Reply(..), ResponseFlag(..),
  -- * Authentication helpers
  Username, Password, Nonce, pwHash, pwKey
) where
import Prelude as X
import Control.Applicative ((<$>))
import Control.Monad (unless, replicateM)
import System.IO (Handle)
import Data.ByteString.Lazy (ByteString)
import qualified Control.Pipeline as P
import Data.Bson (Document, UString)
import Data.Bson.Binary
import Data.Binary.Put
import Data.Binary.Get
import Data.Int
import Data.Bits
import Database.MongoDB.Internal.Util (bitOr)
import Data.IORef
import System.IO.Unsafe (unsafePerformIO)
import Data.Digest.OpenSSL.MD5 (md5sum)
import Data.UString as U (pack, append, toByteString)
-- | A pipeline over a 'Handle' whose messages are lazy 'ByteString's.
type Connection = P.Pipe Handle ByteString

-- | Build a 'Connection' on top of an open 'Handle' by giving the pipeline
-- the two functions it needs to write and read a message-length prefix.
mkConnection :: Handle -> IO Connection
mkConnection = P.newPipe sizeToBytes bytesToSize
  where
    -- Encode a size as an Int32 after adding 4 to it.
    -- NOTE(review): presumably the 4 accounts for the length field counting
    -- its own bytes on the wire; confirm against Control.Pipeline.
    sizeToBytes payloadLen = runPut (putInt32 (toEnum (payloadLen + 4)))
    -- Inverse of the above: decode an Int32 and take the 4 back off.
    bytesToSize lenBytes = fromEnum (runGet getInt32 lenBytes) - 4
-- | Serialise every notice (each one gets its own fresh request id via
-- 'noticeBytes') and hand the resulting messages to the pipeline's @send@.
-- No reply is read here.
send :: Connection -> [Notice] -> IO ()
send conn notices = do
  messages <- mapM noticeBytes notices
  P.send conn messages
-- | Send the notices followed by the request in a single pipeline @call@.
--
-- The outer action performs the call; the inner action it returns is the
-- pipeline's promise with 'bytesReply' mapped over it, so the reply is parsed
-- and checked against this request's id when that inner action's result is
-- used. (Presumably the inner action blocks until the reply arrives; that
-- depends on "Control.Pipeline".)
call :: Connection -> [Notice] -> Request -> IO (IO Reply)
call conn notices request = do
  -- Request ids are drawn in message order: notices first, then the request.
  noticeMsgs <- mapM noticeBytes notices
  rid <- genRequestId
  let outgoing = noticeMsgs ++ [runPut (putRequest request rid)]
  awaitBytes <- P.call conn outgoing
  return (fmap (bytesReply rid) awaitBytes)
-- | Serialise one notice under a freshly generated request id.
noticeBytes :: Notice -> IO ByteString
noticeBytes notice = do
  rid <- genRequestId
  return (runPut (putNotice notice rid))
-- | Parse a reply message and return it, provided its @responseTo@ field
-- equals the given request id. On a mismatch the result is an 'error' call
-- (raised when the result is evaluated) that reports both ids.
bytesReply :: RequestId -> ByteString -> Reply
bytesReply requestId bytes
  | requestId == responseTo = reply
  | otherwise =
      error $ "expected response id (" ++ show responseTo ++ ") to match request id (" ++ show requestId ++ ")"
  where
    (responseTo, reply) = runGet getReply bytes
-- | Collection name as written to the wire by 'putNotice' and 'putRequest'.
-- NOTE(review): presumably database-qualified (\"db.collection\"); confirm
-- with callers.
type FullCollection = UString

-- | Operation code stored in, and read back from, the message header.
type Opcode = Int32

-- | Identifier given to each outgoing message by 'genRequestId'.
type RequestId = Int32

-- | The request id a reply says it answers; 'bytesReply' compares it with the
-- id of the request that was sent.
type ResponseTo = RequestId
-- | Produce a fresh request id: atomically returns the counter's current
-- value and stores that value plus one. The first id handed out is 0.
genRequestId :: IO RequestId
genRequestId = atomicModifyIORef requestIdCounter $ \n -> (n + 1, n)

-- | Counter backing 'genRequestId', shared by every caller in the process.
--
-- It is created with 'unsafePerformIO', so it must be a single top-level CAF
-- that GHC never inlines: if the definition were inlined or duplicated, each
-- copy would allocate its own 'IORef' and request ids would repeat. The
-- NOINLINE pragma below is what guarantees this.
requestIdCounter :: IORef RequestId
requestIdCounter = unsafePerformIO (newIORef 0)
{-# NOINLINE requestIdCounter #-}
-- | Write the three Int32 header fields in order: the request id, a zero
-- (the slot 'getHeader' reads back as @responseTo@), and the opcode.
putHeader :: Opcode -> RequestId -> Put
putHeader opcode requestId = mapM_ putInt32 [requestId, 0, opcode]
-- | Read the three Int32 header fields. The first (the sender's own request
-- id) is discarded; the second and third are returned as
-- @(opcode, responseTo)@. Note that the result order is swapped relative to
-- the wire order.
getHeader :: Get (Opcode, ResponseTo)
getHeader =
  getInt32 >>                   -- request id: read and ignored
  getInt32 >>= \responseTo ->
  getInt32 >>= \opcode ->
  return (opcode, responseTo)
-- | Messages serialised by 'putNotice'. 'send' writes them without reading
-- any reply.
data Notice
  = Insert
      { iFullCollection :: FullCollection -- ^ target collection
      , iDocuments :: [Document]          -- ^ documents written one after another
      }
  | Update
      { uFullCollection :: FullCollection -- ^ target collection
      , uOptions :: [UpdateOption]        -- ^ combined into a flag word by 'uBits'
      , uSelector :: Document             -- ^ written first
      , uUpdater :: Document              -- ^ written second
      }
  | Delete
      { dFullCollection :: FullCollection -- ^ target collection
      , dOptions :: [DeleteOption]        -- ^ combined into a flag word by 'dBits'
      , dSelector :: Document
      }
  | KillCursors
      { kCursorIds :: [CursorId]          -- ^ written as a count followed by the ids
      }
  deriving (Show, Eq)
-- | Flags for an 'Update' notice. 'uBit' maps 'Upsert' to bit 0 and
-- 'MultiUpdate' to bit 1 of the flag word.
data UpdateOption
  = Upsert
  | MultiUpdate
  deriving (Show, Eq)

-- | Flags for a 'Delete' notice. 'dBit' maps 'SingleRemove' to bit 0.
data DeleteOption
  = SingleRemove
  deriving (Show, Eq)

-- | Cursor identifier; written and read as a 64-bit integer.
type CursorId = Int64
-- | Opcode written in the header for each kind of 'Notice'.
nOpcode :: Notice -> Opcode
nOpcode notice = case notice of
  Update{} -> 2001
  Insert{} -> 2002
  Delete{} -> 2006
  KillCursors{} -> 2007
-- | Serialise a 'Notice' under the given request id: the header, then a zero
-- Int32 that every notice carries, then the constructor-specific payload.
putNotice :: Notice -> RequestId -> Put
putNotice notice requestId = do
  putHeader (nOpcode notice) requestId
  putInt32 0
  payload
  where
    -- Constructors are matched positionally, in declaration order of fields.
    payload = case notice of
      Insert coll docs -> do
        putCString coll
        mapM_ putDocument docs
      Update coll opts selector updater -> do
        putCString coll
        putInt32 (uBits opts)
        mapM_ putDocument [selector, updater]
      Delete coll opts selector -> do
        putCString coll
        putInt32 (dBits opts)
        putDocument selector
      KillCursors cursorIds -> do
        -- Count first, then each 64-bit cursor id.
        putInt32 (toEnum (X.length cursorIds))
        mapM_ putInt64 cursorIds
-- | Flag-word value of a single update option.
uBit :: UpdateOption -> Int32
uBit option = case option of
  Upsert -> bit 0
  MultiUpdate -> bit 1

-- | Combine a list of update options into one flag word via 'bitOr'.
uBits :: [UpdateOption] -> Int32
uBits options = bitOr (map uBit options)
-- | Flag-word value of a single delete option.
dBit :: DeleteOption -> Int32
dBit option = case option of
  SingleRemove -> bit 0

-- | Combine a list of delete options into one flag word via 'bitOr'.
dBits :: [DeleteOption] -> Int32
dBits options = bitOr (map dBit options)
-- | Messages serialised by 'putRequest'. 'call' sends one and returns an
-- action yielding the matching 'Reply'.
data Request
  = Query
      { qOptions :: [QueryOption]         -- ^ combined into a flag word by 'qBits'
      , qFullCollection :: FullCollection -- ^ target collection
      , qSkip :: Int32
      , qBatchSize :: Int32
      , qSelector :: Document
      , qProjector :: Document            -- ^ omitted from the message when empty
      }
  | GetMore
      { gFullCollection :: FullCollection -- ^ target collection
      , gBatchSize :: Int32
      , gCursorId :: CursorId
      }
  deriving (Show, Eq)
-- | Flags for a 'Query' request. 'qBit' assigns them bits 1, 2, 4 and 5 of
-- the flag word, in declaration order.
data QueryOption
  = TailableCursor
  | SlaveOK
  | NoCursorTimeout
  | AwaitData
  deriving (Show, Eq)
-- | Opcode written in the header for each kind of 'Request'.
qOpcode :: Request -> Opcode
qOpcode request = case request of
  Query{} -> 2004
  GetMore{} -> 2005
-- | Serialise a 'Request' under the given request id: the header followed by
-- the constructor-specific payload.
putRequest :: Request -> RequestId -> Put
putRequest request requestId = do
  putHeader (qOpcode request) requestId
  -- Constructors are matched positionally, in declaration order of fields.
  case request of
    Query opts coll skip batchSize selector projector -> do
      putInt32 (qBits opts)
      putCString coll
      mapM_ putInt32 [skip, batchSize]
      putDocument selector
      -- The projector is optional on the wire: write it only when non-empty.
      if null projector then return () else putDocument projector
    GetMore coll batchSize cursorId -> do
      putInt32 0
      putCString coll
      putInt32 batchSize
      putInt64 cursorId
-- | Flag-word value of a single query option. Bits 0 and 3 are not assigned
-- to any option here.
qBit :: QueryOption -> Int32
qBit option = case option of
  TailableCursor -> bit 1
  SlaveOK -> bit 2
  NoCursorTimeout -> bit 4
  AwaitData -> bit 5

-- | Combine a list of query options into one flag word via 'bitOr'.
qBits :: [QueryOption] -> Int32
qBits options = bitOr (map qBit options)
-- | A parsed reply message, as produced by 'getReply'.
data Reply = Reply
  { rResponseFlags :: [ResponseFlag] -- ^ decoded from the flag word by 'rFlags'
  , rCursorId :: CursorId
  , rStartingFrom :: Int32
  , rDocuments :: [Document]         -- ^ as many as the count field announced
  }
  deriving (Show, Eq)
-- | Flags decoded from a reply's flag word. 'rBit' gives each constructor's
-- bit position; the derived 'Enum' instance is what lets 'rFlags' enumerate
-- every constructor.
data ResponseFlag
  = CursorNotFound
  | QueryError
  | AwaitCapable
  deriving (Show, Eq, Enum)
-- | The opcode 'getReply' requires in the header of every message it parses;
-- any other opcode makes the parse fail.
replyOpcode :: Opcode
replyOpcode = 1
-- | Parse a reply message: the header, then the flag word, cursor id,
-- starting offset, document count and that many documents. Returns the
-- header's @responseTo@ together with the parsed 'Reply'.
--
-- Fails in the 'Get' monad when the header opcode is not 'replyOpcode'.
getReply :: Get (ResponseTo, Reply)
getReply = do
  (opcode, responseTo) <- getHeader
  if opcode == replyOpcode
    then return ()
    else fail ("expected reply opcode (1) but got " ++ show opcode)
  flags <- fmap rFlags getInt32
  cursorId <- getInt64
  startingFrom <- getInt32
  docCount <- fmap fromIntegral getInt32
  docs <- replicateM docCount getDocument
  -- Positional construction: flags, cursor id, starting offset, documents.
  return (responseTo, Reply flags cursorId startingFrom docs)
-- | Decode a flag word: every 'ResponseFlag' whose bit (per 'rBit') is set,
-- listed in constructor order.
rFlags :: Int32 -> [ResponseFlag]
rFlags bits = [flag | flag <- [CursorNotFound ..], testBit bits (rBit flag)]

-- | Bit position of each response flag in the flag word. Bit 2 is not mapped
-- to any flag here.
rBit :: ResponseFlag -> Int
rBit flag = case flag of
  CursorNotFound -> 0
  QueryError -> 1
  AwaitCapable -> 3
-- | User name, as fed into 'pwHash' and 'pwKey'.
type Username = UString

-- | Plain-text password, as fed into 'pwHash' and 'pwKey'.
type Password = UString

-- | Nonce mixed into 'pwKey'.
-- NOTE(review): presumably issued by the server per authentication attempt;
-- confirm with callers.
type Nonce = UString
-- | 'md5sum' of the username, the literal @:mongo:@ and the password
-- concatenated in that order, packed back into a 'UString'.
pwHash :: Username -> Password -> UString
pwHash u p = pack (md5sum (toByteString credentials))
  where
    credentials = (u `U.append` ":mongo:") `U.append` p
-- | 'md5sum' of the nonce, the username and @'pwHash' u p@ concatenated in
-- that order, packed back into a 'UString'.
pwKey :: Nonce -> Username -> Password -> UString
pwKey n u p = pack (md5sum (toByteString keyInput))
  where
    keyInput = n `U.append` (u `U.append` pwHash u p)