{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE Rank2Types #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE NoMonomorphismRestriction #-}

module Database.Memcached.Binary.Internal where

import Network

import Foreign.Ptr
import Foreign.Storable
import Foreign.Marshal.Utils
import Foreign.Marshal.Alloc

import System.IO
import System.IO.Error (eofErrorType, mkIOError)

import Control.Monad
import Control.Exception
import Control.Concurrent.MVar

import Data.Word
import Data.Pool
import Data.Storable.Endian
import qualified Data.HashMap.Strict as H
import qualified Data.ByteString as S
import qualified Data.ByteString.Lazy as L
import qualified Data.ByteString.Unsafe as S

import Database.Memcached.Binary.Types
import Database.Memcached.Binary.Types.Exception
import Database.Memcached.Binary.Internal.Definition

-- | A connection to a memcached server, as produced by 'connect'.
--
-- 'connect' builds the 'Connection' form when exactly one connection is
-- requested and the 'ConnectionPool' form otherwise.
data Connection
    = Connection (MVar Handle)
    -- ^ A single server handle; the 'MVar' serialises access to it.
    | ConnectionPool (Pool Handle)
    -- ^ A pool of server handles.

-- | Open a connection for the given 'ConnectInfo', run the action with it,
-- and 'close' the connection afterwards, even if the action throws.
withConnection :: ConnectInfo -> (Connection -> IO a) -> IO a
withConnection info action =
    withSocketsDo (bracket acquire close action)
  where
    acquire = connect info

-- | Establish a 'Connection' according to the 'ConnectInfo'.
--
-- When exactly one connection is requested a single handle guarded by an
-- 'MVar' is created; otherwise a single-stripe pool is created that opens
-- handles on demand, up to @numConnection i@ of them.
connect :: ConnectInfo -> IO Connection
connect i
    | numConnection i == 1 = fmap Connection $ connect' i >>= newMVar
    | otherwise = fmap ConnectionPool $
        createPool (connect' i) release 1
        (connectionIdleTime i) (numConnection i)
  where
    -- Say goodbye to the server, but always close the handle: if the quit
    -- exchange fails (e.g. the peer is already gone) the handle must not leak.
    release h = quit h `finally` hClose h

-- | Open one handle to the server and authenticate it.
--
-- The credentials in @connectAuth i@ are tried in order, each on a freshly
-- opened handle:
--
-- * with no credentials the handle is returned unauthenticated;
--
-- * for the last credential a rejection is raised as a 'MemcachedException';
--
-- * for any earlier credential a rejection or an 'IOError' moves on to the
--   next credential.
--
-- A handle whose authentication did not succeed is always closed before
-- moving on or rethrowing, so failed attempts do not leak sockets.
connect' :: ConnectInfo -> IO Handle
connect' i = loop (connectAuth i)
  where
    open = connectTo (connectHost i) (connectPort i)

    loop [] = open

    loop [a] = do
        h <- open
        auth a (\_ -> return h) (\w m -> throwIO $ MemcachedException w m) h
            `onException` hClose h

    loop (a:as) = do
        h <- open
        -- Only the authentication attempt itself is guarded; the retry
        -- happens outside the handler so the remaining credentials are
        -- tried exactly once.
        ok <- handle (\(_ :: IOError) -> return False)
                  (auth a (\_ -> return True) (\_ _ -> return False) h)
                  `onException` hClose h
        if ok
            then return h
            else hClose h >> loop as

-- | Shut a 'Connection' down.
--
-- For a single connection the handle is taken out of the 'MVar' (leaving an
-- error value behind so later use fails loudly), a quit request is sent and
-- the handle is closed. The handle is closed even when the quit exchange
-- throws; the exception is still propagated to the caller.
--
-- For a pool, all pooled resources are destroyed.
close :: Connection -> IO ()
close (Connection mv) = do
    h <- swapMVar mv (error "connection already closed")
    quit h `finally` hClose h
close (ConnectionPool p) = destroyAllResources p

-- | Run an action with a server handle borrowed from the connection:
-- under the 'MVar' for a single connection, or as a pooled resource.
useConnection :: (Handle -> IO a) -> Connection -> IO a
useConnection f conn = case conn of
    Connection mv    -> withMVar mv f
    ConnectionPool p -> withResource p f

-- | Store a single byte at the given address, whatever the pointer's
-- nominal element type.
pokeWord8 :: Ptr a -> Word8 -> IO ()
pokeWord8 p w = poke (castPtr p) w

-- | Store a 16-bit word at the given address, wrapped in 'BE' so that it is
-- written through the big-endian 'Storable' instance.
pokeWord16be :: Ptr a -> Word16 -> IO ()
pokeWord16be p = poke (castPtr p) . BE

-- | Store a 32-bit word at the given address, wrapped in 'BE' so that it is
-- written through the big-endian 'Storable' instance.
pokeWord32be :: Ptr a -> Word32 -> IO ()
pokeWord32be p = poke (castPtr p) . BE

-- | Store a 64-bit word at the given address, wrapped in 'BE' so that it is
-- written through the big-endian 'Storable' instance.
pokeWord64be :: Ptr a -> Word64 -> IO ()
pokeWord64be p = poke (castPtr p) . BE

-- | Load a single byte from the given address, whatever the pointer's
-- nominal element type.
peekWord8 :: Ptr a -> IO Word8
peekWord8 p = peek (castPtr p)

-- | Load a 16-bit word from the given address through the big-endian 'BE'
-- wrapper and unwrap it.
peekWord16be :: Ptr a -> IO Word16
peekWord16be p = do
    BE w <- peek (castPtr p)
    return w

-- | Load a 32-bit word from the given address through the big-endian 'BE'
-- wrapper and unwrap it.
peekWord32be :: Ptr a -> IO Word32
peekWord32be p = do
    BE w <- peek (castPtr p)
    return w

-- | Load a 64-bit word from the given address through the big-endian 'BE'
-- wrapper and unwrap it.
peekWord64be :: Ptr a -> IO Word64
peekWord64be p = do
    BE w <- peek (castPtr p)
    return w

-- | Copy the bytes of a strict 'S.ByteString' to the given address.
-- The destination must have room for @S.length v@ bytes.
pokeByteString :: Ptr a -> S.ByteString -> IO ()
pokeByteString p v =
    S.unsafeUseAsCStringLen v $ \(src, len) ->
        copyBytes (castPtr p) src len

-- | Copy a lazy 'L.ByteString' to the given address, chunk by chunk, each
-- chunk placed directly after the previous one.
pokeLazyByteString :: Ptr a -> L.ByteString -> IO ()
pokeLazyByteString p v = foldM_ step 0 (L.toChunks v)
  where
    -- Write one chunk at the running offset and return the next offset.
    step off chunk = do
        pokeByteString (plusPtr p off) chunk
        return (off + S.length chunk)

-- | Uninhabited tag type: a @Ptr Header@ points at the 24-byte response
-- header buffer read by 'peekResponse'.
data Header
-- | Uninhabited tag type: a @Ptr Request@ points into a request buffer
-- built by 'mallocRequest'.
data Request

-- | Allocate and fill a request packet; the caller owns the returned buffer
-- and must 'free' it.
--
-- Arguments, in order: opcode, key, extras length, a writer for the extras
-- (given the address where the extras start), value length, a writer for the
-- value (given the address where the value starts), opaque, and CAS.
--
-- Buffer layout (offsets in bytes):
--
-- >  0  magic (0x80)        1  opcode
-- >  2  key length (16)     4  extras length     5  0x00     6  0x0000
-- >  8  body length (32) = key + extras + value
-- > 12  opaque (32)        16  CAS (64)
-- > 24  extras, then key, then value
--
-- If filling the buffer throws (the writers are caller-supplied), the buffer
-- is freed before the exception propagates.
--
-- NOTE(review): the key length is narrowed to 'Word16' without a range
-- check; callers are presumably expected to pass short keys.
mallocRequest :: OpCode -> Key -> Word8 -> (Ptr Request -> IO ())
              -> Int -> (Ptr Request -> IO ()) -> Word32 -> CAS -> IO (Ptr Request)
mallocRequest (OpCode o) key elen epoke vlen vpoke opaque (CAS cas) = do
    p <- mallocBytes (24 + tlen)
    -- This function typically runs as the acquire step of a 'bracket', whose
    -- release is not invoked when acquisition fails, so clean up here.
    fill p `onException` free p
    return p
  where
    klen = S.length key
    xlen = fromIntegral elen :: Int
    tlen = klen + xlen + vlen

    fill p = do
        pokeWord8             p     0x80
        pokeWord8    (plusPtr p  1) o
        pokeWord16be (plusPtr p  2) (fromIntegral klen)
        pokeWord8    (plusPtr p  4) elen
        pokeWord8    (plusPtr p  5) 0x00
        pokeWord16be (plusPtr p  6) 0x00
        pokeWord32be (plusPtr p  8) (fromIntegral tlen)
        pokeWord32be (plusPtr p 12) opaque
        pokeWord64be (plusPtr p 16) cas
        epoke (plusPtr p 24)
        pokeByteString (plusPtr p $ 24 + xlen) key
        vpoke (plusPtr p $ 24 + xlen + klen)
{-# INLINE mallocRequest #-}

-- | Build a request packet with 'mallocRequest', write it to the handle,
-- flush, and free the packet buffer (also when writing fails).
--
-- The arguments before the handle are passed to 'mallocRequest' unchanged.
sendRequest :: OpCode -> Key -> Word8 -> (Ptr Request -> IO ())
            -> Int -> (Ptr Request -> IO ()) -> Word32 -> CAS -> Handle -> IO ()
sendRequest op key elen epoke vlen vpoke opaque cas h =
    bracket build free send
  where
    build    = mallocRequest op key elen epoke vlen vpoke opaque cas
    -- 24-byte header followed by key, extras and value.
    size     = 24 + S.length key + fromIntegral elen + vlen
    send req = hPutBuf h req size >> hFlush h
{-# INLINE sendRequest #-}

-- | Continuation invoked by 'peekResponse' for a non-zero response status:
-- it receives the status word and the response body.
type Failure a = Word16 -> S.ByteString -> IO a

-- | Read a 24-byte response header from the handle and dispatch on its
-- status word (offset 6).
--
-- * Status 0: the success continuation is run with a pointer to the header.
--   The body has /not/ been consumed yet; the continuation must read it.
--   The pointer is only valid while the continuation runs.
--
-- * Any other status: the whole body (length taken from offset 8) is read
--   and passed, together with the status, to the failure continuation.
--
-- Throws an end-of-file 'IOError' if the connection ends before a complete
-- header has been received, rather than decoding uninitialised memory.
peekResponse :: (Ptr Header -> IO a) -> Failure a -> Handle -> IO a
peekResponse success failure h = bracket (mallocBytes 24) free $ \p -> do
    -- hGetBuf returns fewer bytes than requested only at end of input.
    n <- hGetBuf h p 24
    when (n < 24) $ ioError $
        mkIOError eofErrorType "peekResponse: truncated response header"
                  (Just h) Nothing
    st <- peekWord16be (plusPtr p 6)
    if st == 0
        then success p
        else do
            bl <- peekWord32be (plusPtr p 8)
            failure st =<< S.hGet h (fromIntegral bl)
{-# INLINE peekResponse #-}

-- | Send one request (with opaque 0) and process its response.
--
-- The success continuation additionally receives the handle so that it can
-- read the response body; see 'sendRequest' and 'peekResponse' for the
-- meaning of the remaining arguments.
withRequest :: OpCode -> Key -> Word8 -> (Ptr Request -> IO ())
            -> Int -> (Ptr Request -> IO ()) -> CAS
            -> (Handle -> Ptr Header -> IO a) -> Failure a -> Handle -> IO a
withRequest op key elen epoke vlen vpoke cas success failure h =
    sendRequest op key elen epoke vlen vpoke 0 cas h
        >> peekResponse (success h) failure h

-- | Extras length: the byte at offset 4 of the header.
getExtraLength :: Ptr Header -> IO Word8
getExtraLength hdr = peekWord8 (hdr `plusPtr` 4)

-- | Key length: the big-endian 16-bit word at offset 2 of the header.
getKeyLength :: Ptr Header -> IO Word16
getKeyLength hdr = peekWord16be (hdr `plusPtr` 2)

-- | Total body length: the big-endian 32-bit word at offset 8 of the header.
getTotalLength :: Ptr Header -> IO Word32
getTotalLength hdr = peekWord32be (hdr `plusPtr` 8)

-- | CAS value: the big-endian 64-bit word at offset 16 of the header.
getCAS :: Ptr Header -> IO CAS
getCAS hdr = do
    raw <- peekWord64be (hdr `plusPtr` 16)
    return (CAS raw)

-- | Opaque value: the big-endian 32-bit word at offset 12 of the header.
getOpaque :: Ptr Header -> IO Word32
getOpaque hdr = peekWord32be (hdr `plusPtr` 12)

-- | Writer that writes nothing; used for requests without extras or value.
nop :: Ptr Request -> IO ()
nop = const (return ())

-- | Read a complete response body from the handle and split it into
-- @(extras, key, value)@ using the lengths recorded in the header.
inspectResponse :: Handle -> Ptr Header 
                -> IO (S.ByteString, S.ByteString, L.ByteString)
inspectResponse h p = do
    extLen <- fmap fromIntegral (getExtraLength p)
    keyLen <- fmap fromIntegral (getKeyLength   p)
    totLen <- fmap fromIntegral (getTotalLength p)
    ext <- S.hGet h extLen
    key <- S.hGet h keyLen
    -- The value is whatever remains of the body after extras and key.
    val <- L.hGet h (totLen - extLen - keyLen)
    return (ext, key, val)

-- | Success continuation for responses that carry flags and a value.
--
-- Reads 4 bytes of extras from the handle as the big-endian flags word,
-- then reads the value (total body length minus extras length) and passes
-- both to the given continuation.
--
-- The header buffer is reused as scratch space for the flags, so the header
-- fields needed later are extracted first.
--
-- NOTE(review): exactly 4 bytes of extras are consumed regardless of the
-- extras length in the header, and no key length is subtracted; this assumes
-- responses with 4-byte extras and no key.
--
-- Throws an end-of-file 'IOError' if the connection ends before the flags
-- have been received.
getSuccessCallback :: (Flags -> Value -> IO a)
                   -> Handle -> Ptr Header -> IO a
getSuccessCallback success h p = do
    elen <- getExtraLength p
    tlen <- getTotalLength p
    n <- hGetBuf h p 4
    when (n < 4) $ ioError $
        mkIOError eofErrorType "getSuccessCallback: truncated flags"
                  (Just h) Nothing
    flags <- peekWord32be p
    value <- L.hGet h (fromIntegral tlen - fromIntegral elen)
    success flags value

-- | Fetch the flags and value stored under a key.
get :: (Flags -> Value -> IO a) -> Failure a
    -> Key -> Handle -> IO a
get success failure key =
    withRequest opGet key 0 nop 0 nop (CAS 0) onHit failure
  where
    onHit = getSuccessCallback success

-- | Like 'get', but also passes the CAS value from the response header to
-- the success continuation.
getWithCAS :: (CAS -> Flags -> Value -> IO a) -> Failure a
           -> Key -> Handle -> IO a
getWithCAS success failure key =
    withRequest opGet key 0 nop 0 nop (CAS 0) onHit failure
  where
    -- The CAS must be read before the header buffer is reused for the flags.
    onHit h p = do
        c <- getCAS p
        getSuccessCallback (success c) h p

-- | Store a value under a key using the given opcode.
--
-- The request carries 8 bytes of extras (flags, then expiry) followed by the
-- key and the value. On success the given action is run; the response body
-- is not read.
setAddReplace :: IO a -> Failure a -> OpCode -> CAS
              -> Key -> Value -> Flags -> Expiry -> Handle -> IO a
setAddReplace success failure o cas key value flags expiry =
    withRequest o key 8 pokeExtras valueLen pokeValue cas onSuccess failure
  where
    valueLen      = fromIntegral (L.length value)
    pokeExtras p  = do
        pokeWord32be p flags
        pokeWord32be (plusPtr p 4) expiry
    pokeValue p   = pokeLazyByteString p value
    onSuccess _ _ = success

-- | Like 'setAddReplace', but passes the CAS value from the response header
-- to the success continuation.
setAddReplaceWithCAS :: (CAS -> IO a) -> Failure a -> OpCode -> CAS
                     -> Key -> Value -> Flags -> Expiry -> Handle -> IO a
setAddReplaceWithCAS success failure o cas key value flags expiry =
    withRequest o key 8 pokeExtras valueLen pokeValue cas onSuccess failure
  where
    valueLen      = fromIntegral (L.length value)
    pokeExtras p  = do
        pokeWord32be p flags
        pokeWord32be (plusPtr p 4) expiry
    pokeValue p   = pokeLazyByteString p value
    onSuccess _ p = getCAS p >>= success

-- | Delete a key, sending the given CAS with the request. On success the
-- given action is run.
delete :: IO a -> Failure a -> CAS -> Key -> Handle -> IO a
delete success failure cas key =
    withRequest opDelete key 0 nop 0 nop cas onSuccess failure
  where
    onSuccess _ _ = success

-- | Increment or decrement a counter, depending on the opcode.
--
-- The request carries 20 bytes of extras: delta (64 bit), initial value
-- (64 bit) and expiry (32 bit), all big-endian. On success, 8 bytes are read
-- from the response body as a big-endian 64-bit word and passed to the
-- continuation.
--
-- Throws an end-of-file 'IOError' if the connection ends before the 8-byte
-- value has been received.
incrDecr :: (Word64 -> IO a) -> Failure a -> OpCode -> CAS
         -> Key -> Delta -> Initial -> Expiry -> Handle -> IO a
incrDecr success failure op cas key delta initial expiry =
    withRequest op key 20 extra 0 nop cas success' failure
  where
    extra p = do
        pokeWord64be          p     delta
        pokeWord64be (plusPtr p  8) initial
        pokeWord32be (plusPtr p 16) expiry

    -- The header buffer is reused as scratch space for the 8-byte value.
    success' h p = do
        n <- hGetBuf h p 8
        when (n < 8) $ ioError $
            mkIOError eofErrorType "incrDecr: truncated counter value"
                      (Just h) Nothing
        peekWord64be p >>= success

-- | Send a quit request and read the response header, ignoring whether the
-- server reported success or failure. The handle itself is not closed.
quit :: Handle -> IO ()
quit h =
    sendRequest opQuit "" 0 nop 0 nop 0 (CAS 0) h
        >> peekResponse ignoreHeader ignoreError h
  where
    ignoreHeader _  = return ()
    ignoreError _ _ = return ()

-- | Send a flush request without extras. On success the given action is run.
flushAll :: IO a -> Failure a -> Handle -> IO a
flushAll success failure h =
    withRequest opFlush "" 0 nop 0 nop (CAS 0) onDone failure h
  where
    onDone _ _ = success

-- | Send a flush request carrying the given expiry as a 4-byte extra.
-- On success the given action is run.
flushWithin :: IO a -> Failure a -> Expiry -> Handle -> IO a
flushWithin success failure w =
    withRequest opFlush "" 4 pokeExpiry 0 nop (CAS 0) onDone failure
  where
    pokeExpiry p = pokeWord32be p w
    onDone _ _   = success

-- | Send a no-op request. On success the given action is run.
noOp :: IO a -> Failure a -> Handle -> IO a
noOp success failure h =
    withRequest opNoOp "" 0 nop 0 nop (CAS 0) onDone failure h
  where
    onDone _ _ = success

-- | Send a version request and pass the whole response body to the success
-- continuation.
version :: (S.ByteString -> IO a) -> Failure a -> Handle -> IO a
version success failure h =
    withRequest opVersion "" 0 nop 0 nop (CAS 0) readBody failure h
  where
    readBody hdl p = do
        len  <- getTotalLength p
        body <- S.hGet hdl (fromIntegral len)
        success body

-- | Append or prepend (depending on the opcode) a value to the data stored
-- under a key. The request has no extras. On success the given action is run.
appendPrepend :: IO a -> Failure a -> OpCode -> CAS
              -> Key -> Value -> Handle -> IO a
appendPrepend success failure op cas key value =
    withRequest op key 0 nop valueLen pokeValue cas onDone failure
  where
    valueLen    = fromIntegral (L.length value)
    pokeValue p = pokeLazyByteString p value
    onDone _ _  = success

-- | Fetch the server statistics as a map from statistic name to value.
--
-- A single stat request is sent; the server answers with one response packet
-- per statistic and finishes with a packet whose body is empty. Packets are
-- read until that terminator is seen. Sending the request only once matters:
-- every additional request would trigger another full sequence of responses
-- that nobody reads, leaving the connection out of sync for later commands.
--
-- A non-zero response status is raised as a 'MemcachedException'.
stats :: Handle -> IO (H.HashMap S.ByteString S.ByteString)
stats h = do
    sendRequest opStat "" 0 nop 0 nop 0 (CAS 0) h
    loop H.empty
  where
    loop m = do
        entry <- peekResponse readEntry
                     (\w s -> throwIO $ MemcachedException w s) h
        case entry of
            Nothing     -> return m
            Just (k, v) -> loop $! H.insert k v m

    -- Decode one packet: 'Nothing' for the empty terminator, otherwise the
    -- key followed by the value (the rest of the body).
    readEntry p = do
        tl <- getTotalLength p
        if tl == 0
            then return Nothing
            else do
                kl <- getKeyLength p
                k  <- S.hGet h (fromIntegral kl)
                v  <- S.hGet h (fromIntegral tl - fromIntegral kl)
                return (Just (k, v))

-- | Send a verbosity request carrying the given level as a 4-byte extra.
-- On success the given action is run.
verbosity :: IO a -> Failure a -> Word32 -> Handle -> IO a
verbosity success failure v =
    withRequest opVerbosity "" 4 pokeLevel 0 nop (CAS 0) onDone failure
  where
    pokeLevel p = pokeWord32be p v
    onDone _ _  = success

-- | Send a request with the given opcode for a key, carrying the expiry as a
-- 4-byte extra, and decode the response with 'getSuccessCallback' (flags
-- and value).
touch :: (Flags -> Value -> IO a) -> Failure a -> OpCode
      -> Key -> Expiry -> Handle -> IO a
touch success failure op key e =
    withRequest op key 4 pokeExpiry 0 nop (CAS 0) onHit failure
  where
    pokeExpiry p = pokeWord32be p e
    onHit        = getSuccessCallback success

-- | Send a SASL list-mechanisms request and pass the whole response body to
-- the success continuation.
saslListMechs :: (S.ByteString -> IO a) -> Failure a
              -> Handle -> IO a
saslListMechs success failure =
    withRequest opSaslListMechs "" 0 nop 0 nop (CAS 0) readBody failure
  where
    readBody h p = do
        len  <- getTotalLength p
        body <- S.hGet h (fromIntegral len)
        success body

-- | Authenticate a handle with the given credentials.
--
-- For 'Plain' credentials a SASL auth request is sent with the mechanism
-- name @PLAIN@ as the key and, as the value, a zero byte, the first
-- credential field, another zero byte, and the second credential field.
-- On success the response body is passed to the success continuation;
-- otherwise the failure continuation receives the status and body.
auth :: Auth -> (S.ByteString -> IO a) -> Failure a -> Handle -> IO a
auth (Plain u w) success next h = do
    sendRequest opSaslAuth "PLAIN" 0 nop credLen pokeCred 0 (CAS 0) h
    peekResponse consumeResponse next h
  where
    userLen = S.length u
    -- Two extra bytes for the zero separators.
    credLen = userLen + S.length w + 2

    pokeCred p = do
        let userAt = plusPtr p 1
            sepAt  = plusPtr p (userLen + 1)
            passAt = plusPtr p (userLen + 2)
        pokeWord8      p      0
        pokeByteString userAt u
        pokeWord8      sepAt  0
        pokeByteString passAt w

    consumeResponse p =
        getTotalLength p >>= S.hGet h . fromIntegral >>= success