module Database.MongoDB.Internal.Protocol (
	FullCollection,
	
	Pipe, newPipe, send, call,
	
	Notice(..), InsertOption(..), UpdateOption(..), DeleteOption(..), CursorId,
	
	Request(..), QueryOption(..),
	
	Reply(..), ResponseFlag(..),
	
	Username, Password, Nonce, pwHash, pwKey
) where
import Control.Applicative ((<$>))
import Control.Arrow ((***))
import Control.Exception (try)
import Control.Monad (forM_, replicateM, unless)
import Data.Binary.Get (Get, runGet)
import Data.Binary.Put (Put, runPut)
import Data.Bits (bit, testBit)
import Data.Int (Int32, Int64)
import Data.IORef (IORef, newIORef, atomicModifyIORef)
import System.IO (Handle, hClose, hFlush)
import System.IO.Unsafe (unsafePerformIO)
import qualified Data.ByteString.Lazy as L
import Control.Monad.Error (ErrorT(..))
import Control.Monad.Trans (MonadIO, liftIO)
import Data.Bson (Document)
import Data.Bson.Binary (getDocument, putDocument, getInt32, putInt32, getInt64,
                         putInt64, putCString)
import Data.Text (Text)
import qualified Crypto.Hash.MD5 as MD5
import qualified Data.Text as T
import qualified Data.Text.Encoding as TE
import Database.MongoDB.Internal.Util (whenJust, hGetN, bitOr, byteStringHex)
import System.IO.Pipeline (IOE, Pipeline, newPipeline, IOStream(..))
import qualified System.IO.Pipeline as P
-- | Connection pipeline: 'Message's are written to it and 'Response's are
-- read back from it (see 'newPipe', which wires 'writeMessage' and
-- 'readMessage' to a handle).
type Pipe = Pipeline Response Message
-- | Build a 'Pipe' over a handle.
--
-- Outgoing messages are framed by 'writeMessage', incoming replies are
-- decoded by 'readMessage', and @hClose handle@ is supplied as the stream's
-- close action.
newPipe :: Handle -> IO Pipe
newPipe handle = newPipeline stream
 where
	stream = IOStream writer reader closer
	writer = writeMessage handle
	reader = readMessage handle
	closer = hClose handle
-- | Push notices onto the pipe with no accompanying request, so no reply is
-- asked for (the request slot of the 'Message' is 'Nothing').
send :: Pipe -> [Notice] -> IOE ()
send pipe = P.send pipe . (\pending -> (pending, Nothing))
-- | Send notices followed by a single request, and return a promise for the
-- request's reply.
--
-- A fresh request id is attached to the request. Running the returned
-- promise waits for the pipeline's response and verifies that its responseTo
-- field matches that id.
--
-- The mismatch check is bound inside the promise action (rather than mapped
-- over it with a pure 'error'). As a result it fires when the promise is run,
-- inside whatever handler the caller wraps around it. Previously it fired
-- only when the 'Reply' was forced, which could happen much later or never.
-- It still raises an 'error' call with the same message.
call :: Pipe -> [Notice] -> Request -> IOE (IOE Reply)
call pipe notices request = do
	requestId <- genRequestId
	promise <- P.call pipe (notices, Just (request, requestId))
	return (promise >>= checkResponseTo requestId)
 where
	-- Accept the reply only if it answers the request we sent.
	checkResponseTo requestId (responseTo, reply)
		| requestId == responseTo = return reply
		| otherwise = error $ "expected response id (" ++ show responseTo ++ ") to match request id (" ++ show requestId ++ ")"
-- | One unit written by 'writeMessage': notices to send, optionally followed
-- by a request tagged with the request id that its reply must answer
-- ('send' uses 'Nothing', 'call' uses 'Just').
type Message = ([Notice], Maybe (Request, RequestId))
-- | Serialize one 'Message' onto the handle, then flush.
--
-- Every notice is stamped with its own fresh id from 'genRequestId'. The
-- optional request keeps the id it was given by 'call'. Each frame is
-- preceded by an int32 length that counts the four length bytes themselves.
-- Exceptions from the writes are captured by 'try' into the 'IOE' result
-- (of whatever exception type 'IOE' carries — presumably 'IOError').
writeMessage :: Handle -> Message -> IOE ()
writeMessage handle (notices, mRequest) = ErrorT . try $ do
	forM_ notices $ \notice -> genRequestId >>= frame (Left notice)
	whenJust mRequest (uncurry frame . (Right *** id))
	hFlush handle
 where
	-- Write one length-prefixed frame for a notice or a request.
	frame payload rid = L.hPut handle (sizePrefix body) >> L.hPut handle body
	 where
		body = runPut (either putNotice putRequest payload rid)
	-- Frame length field: serialized body length plus the 4-byte field itself.
	sizePrefix = runPut . putInt32 . (+ 4) . toEnum . fromEnum . L.length
-- | A decoded reply paired with the responseTo id read from its header
-- (see 'getReply'); 'call' compares it with the id of the request it sent.
type Response = (ResponseTo, Reply)
-- | Read one length-prefixed reply frame from the handle and decode it.
--
-- The leading int32 length includes its own four bytes, so four is
-- subtracted before reading the rest of the frame. Note that 'runGet' is
-- applied lazily: a malformed body is only detected when the returned
-- 'Response' is forced, after 'try' has already returned.
readMessage :: Handle -> IOE Response
readMessage handle = ErrorT (try receive)
 where
	receive = hGetN handle 4 >>= \sizeField ->
		runGet getReply <$> hGetN handle (payloadSize sizeField)
	-- Bytes remaining in the frame after the length field.
	payloadSize sizeField = fromEnum (runGet getInt32 sizeField - 4)
-- | Collection name written into messages with 'putCString'
-- (presumably the namespaced @db.collection@ form — confirm with callers).
type FullCollection = Text
-- | Operation code stored in a message header (see 'putHeader', 'getHeader').
type Opcode = Int32
-- | Id written into the header of every outgoing message; produced by
-- 'genRequestId'.
type RequestId = Int32
-- | The request id a reply says it answers, read from the reply header.
type ResponseTo = RequestId
-- | Process-wide request id counter.
--
-- This must be a top-level binding marked NOINLINE. Otherwise the
-- 'unsafePerformIO' call may be inlined, duplicated, or (when bound inside a
-- class-polymorphic definition and compiled without full laziness)
-- re-evaluated on every use. Each use would then see a fresh counter, and
-- callers would receive duplicate ids.
requestIdCounter :: IORef RequestId
requestIdCounter = unsafePerformIO (newIORef 0)
{-# NOINLINE requestIdCounter #-}

-- | Return the next request id (0, 1, 2, ... in call order, wrapping on
-- 'Int32' overflow) and advance the shared counter atomically.
genRequestId :: (MonadIO m) => m RequestId
genRequestId = liftIO $ do
	rid <- atomicModifyIORef requestIdCounter $ \n -> let n' = n + 1 in n' `seq` (n', n)
	-- Forcing the result also forces the stored successor, so the IORef never
	-- holds a growing chain of (+ 1) thunks.
	return $! rid
	
-- | Write the three int32 header fields that follow the length prefix:
-- request id, responseTo (always 0 for messages sent from here), and opcode.
putHeader :: Opcode -> RequestId -> Put
putHeader opcode requestId = mapM_ putInt32 [requestId, 0, opcode]
-- | Read the three int32 header fields that follow the length prefix and
-- return @(opcode, responseTo)@; the message's own request id is discarded.
getHeader :: Get (Opcode, ResponseTo)
getHeader =
	-- Skip the reply's own request id.
	getInt32 >>
	getInt32 >>= \responseTo ->
	getInt32 >>= \opcode ->
	return (opcode, responseTo)
-- | Messages written without a reply being awaited for them. They are
-- serialized by 'putNotice' and sent via 'send', or ahead of a request in
-- 'call'.
data Notice
	= Insert
		{ iFullCollection :: FullCollection
		, iOptions :: [InsertOption]
		, iDocuments :: [Document]  -- ^ written back to back after the collection name
		}
	| Update
		{ uFullCollection :: FullCollection
		, uOptions :: [UpdateOption]
		, uSelector :: Document
		, uUpdater :: Document
		}
	| Delete
		{ dFullCollection :: FullCollection
		, dOptions :: [DeleteOption]
		, dSelector :: Document
		}
	| KillCursors
		{ kCursorIds :: [CursorId]  -- ^ preceded on the wire by their count
		}
	deriving (Show, Eq)
-- | Flags for an 'Insert' notice; 'iBit' maps each to its bit.
data InsertOption
	= KeepGoing  -- ^ bit 0 of the insert flags (presumably the server's continue-on-error flag)
	deriving (Show, Eq)
-- | Flags for an 'Update' notice; 'uBit' maps each to its bit.
data UpdateOption
	= Upsert  -- ^ bit 0 of the update flags
	| MultiUpdate  -- ^ bit 1 of the update flags
	deriving (Show, Eq)
-- | Flags for a 'Delete' notice; 'dBit' maps each to its bit.
data DeleteOption
	= SingleRemove  -- ^ bit 0 of the delete flags
	deriving (Show, Eq)
-- | Cursor identifier, carried on the wire as a 64-bit integer
-- ('putInt64' / 'getInt64').
type CursorId = Int64
-- | Header opcode for each kind of notice.
nOpcode :: Notice -> Opcode
nOpcode notice = case notice of
	Update{} -> 2001
	Insert{} -> 2002
	Delete{} -> 2006
	KillCursors{} -> 2007
-- | Serialize a notice (header followed by its body) using the given
-- request id. The length prefix is not written here; 'writeMessage' adds it.
putNotice :: Notice -> RequestId -> Put
putNotice notice requestId = putHeader (nOpcode notice) requestId >> body notice
 where
	-- Insert puts its flags first; the other notices start with a zero int32.
	body (Insert collection options documents) = do
		putInt32 (iBits options)
		putCString collection
		mapM_ putDocument documents
	body (Update collection options selector updater) = do
		putInt32 0
		putCString collection
		putInt32 (uBits options)
		putDocument selector
		putDocument updater
	body (Delete collection options selector) = do
		putInt32 0
		putCString collection
		putInt32 (dBits options)
		putDocument selector
	body (KillCursors cursorIds) = do
		putInt32 0
		putInt32 (toEnum (length cursorIds))
		mapM_ putInt64 cursorIds
-- | Bit used for each insert option.
iBit :: InsertOption -> Int32
iBit option = case option of
	KeepGoing -> bit 0
-- | Combine insert options into the int32 flags field of an insert message.
iBits :: [InsertOption] -> Int32
iBits options = bitOr [iBit option | option <- options]
-- | Bit used for each update option.
uBit :: UpdateOption -> Int32
uBit option = case option of
	Upsert -> bit 0
	MultiUpdate -> bit 1
-- | Combine update options into the int32 flags field of an update message.
uBits :: [UpdateOption] -> Int32
uBits options = bitOr [uBit option | option <- options]
-- | Bit used for each delete option.
dBit :: DeleteOption -> Int32
dBit option = case option of
	SingleRemove -> bit 0
-- | Combine delete options into the int32 flags field of a delete message.
dBits :: [DeleteOption] -> Int32
dBits options = bitOr [dBit option | option <- options]
-- | Messages that are answered with a 'Reply'. They are serialized by
-- 'putRequest' and sent through 'call'.
data Request
	= Query
		{ qOptions :: [QueryOption]
		, qFullCollection :: FullCollection
		, qSkip :: Int32  -- ^ int32 written right after the collection name
		, qBatchSize :: Int32  -- ^ int32 written after 'qSkip'
		, qSelector :: Document
		, qProjector :: Document  -- ^ left out of the message entirely when empty
		}
	| GetMore
		{ gFullCollection :: FullCollection
		, gBatchSize :: Int32
		, gCursorId :: CursorId
		}
	deriving (Show, Eq)
-- | Flags for a 'Query' request; 'qBit' maps each to its bit.
data QueryOption
	= TailableCursor  -- ^ bit 1
	| SlaveOK  -- ^ bit 2
	| NoCursorTimeout  -- ^ bit 4
	| AwaitData  -- ^ bit 5
	| Partial  -- ^ bit 7
	deriving (Show, Eq)
-- | Header opcode for each kind of request.
qOpcode :: Request -> Opcode
qOpcode request = case request of
	Query{} -> 2004
	GetMore{} -> 2005
-- | Serialize a request (header followed by its body) using the given
-- request id. The length prefix is not written here; 'writeMessage' adds it.
putRequest :: Request -> RequestId -> Put
putRequest request requestId = putHeader (qOpcode request) requestId >> body request
 where
	body (Query options collection skip batchSize selector projector) = do
		putInt32 (qBits options)
		putCString collection
		putInt32 skip
		putInt32 batchSize
		putDocument selector
		-- An empty projector is omitted rather than sent as an empty document.
		unless (null projector) (putDocument projector)
	body (GetMore collection batchSize cursorId) = do
		putInt32 0
		putCString collection
		putInt32 batchSize
		putInt64 cursorId
-- | Bit used for each query option.
qBit :: QueryOption -> Int32
qBit option = case option of
	TailableCursor -> bit 1
	SlaveOK -> bit 2
	NoCursorTimeout -> bit 4
	AwaitData -> bit 5
	Partial -> bit 7
-- | Combine query options into the int32 flags field of a query message.
qBits :: [QueryOption] -> Int32
qBits options = bitOr [qBit option | option <- options]
-- | A decoded reply message (built by 'getReply').
data Reply = Reply
	{ rResponseFlags :: [ResponseFlag]  -- ^ decoded from the int32 flags field by 'rFlags'
	, rCursorId :: CursorId
	, rStartingFrom :: Int32
	, rDocuments :: [Document]  -- ^ as many as the reply's document count says
	} deriving (Show, Eq)
-- | Reply flags recognized by 'rFlags'; 'rBit' gives the bit of each.
-- Constructor order matters: 'rFlags' enumerates @[CursorNotFound ..]@
-- through the derived 'Enum' instance.
data ResponseFlag
	= CursorNotFound  -- ^ bit 0
	| QueryError  -- ^ bit 1
	| AwaitCapable  -- ^ bit 3
	deriving (Show, Eq, Enum)
-- | Opcode every reply header must carry; 'getReply' fails on any other.
replyOpcode :: Opcode
replyOpcode = 1
-- | Decode a reply frame (without its length prefix): the header, then the
-- flags, cursor id, starting position, document count, and that many
-- documents. Fails if the header's opcode is not 'replyOpcode'.
getReply :: Get (ResponseTo, Reply)
getReply = do
	(opcode, responseTo) <- getHeader
	if opcode == replyOpcode
		then return ()
		else fail $ "expected reply opcode (1) but got " ++ show opcode
	flags <- rFlags <$> getInt32
	cursorId <- getInt64
	startingFrom <- getInt32
	count <- getInt32
	documents <- replicateM (fromIntegral count) getDocument
	return (responseTo, Reply flags cursorId startingFrom documents)
-- | Decode the int32 reply flags field into the 'ResponseFlag's that are set,
-- in constructor order.
rFlags :: Int32 -> [ResponseFlag]
rFlags bits = [flag | flag <- [CursorNotFound ..], testBit bits (rBit flag)]
-- | Bit position of each response flag in the reply flags field.
rBit :: ResponseFlag -> Int
rBit flag = case flag of
	CursorNotFound -> 0
	QueryError -> 1
	AwaitCapable -> 3
-- | User name hashed by 'pwHash' and 'pwKey'.
type Username = Text
-- | Plain-text password hashed by 'pwHash' and 'pwKey'.
type Password = Text
-- | Nonce prepended to the key material in 'pwKey'
-- (presumably issued by the server during authentication — confirm).
type Nonce = Text
-- | Password digest: the MD5 hash of the UTF-8 bytes of
-- @user ++ ":mongo:" ++ password@, rendered with 'byteStringHex'.
pwHash :: Username -> Password -> Text
pwHash u p = T.pack (byteStringHex (MD5.hash (TE.encodeUtf8 salted)))
 where
	salted = T.concat [u, ":mongo:", p]
-- | Authentication key: the MD5 hash of the UTF-8 bytes of
-- @nonce ++ user ++ pwHash user password@, rendered with 'byteStringHex'.
pwKey :: Nonce -> Username -> Password -> Text
pwKey n u p = T.pack . byteStringHex . MD5.hash . TE.encodeUtf8 $ T.concat [n, u, pwHash u p]