{-
Copyright (c) 2010-2011, Alexander Bogdanov <andorn@gmail.com>
License: MIT
-}

{-# LANGUAGE CPP, OverloadedStrings, BangPatterns #-}
module Database.Redis.Internal where

#if MIN_VERSION_utf8_string(1,0,0)
import Prelude hiding (catch)
#else
import Prelude hiding (catch, putStrLn, putStr)
import System.IO.UTF8 (putStrLn, putStr)
#endif
import Control.Concurrent (ThreadId, myThreadId)
import qualified Control.Concurrent.RLock as RLock
import Data.IORef
import qualified System.IO as IO
import qualified Data.ByteString as B
import Data.ByteString (ByteString)
import Data.ByteString.Char8 (readInt)
import qualified Data.ByteString.UTF8 as U
import Data.Maybe (fromJust, isNothing, isJust)
import Data.List (intersperse)
import qualified Data.Map as Map
import Data.Map (Map(..))
import Control.Monad (when)
import Control.Exception (bracket, bracketOnError, catch, SomeException)

import Database.Redis.ByteStringClass

#if __GLASGOW_HASKELL__ < 700
import Control.Exception (block)
#else
import Control.Exception.Base (mask)
-- Compatibility shim for GHC >= 7.0, where this module defines 'block' itself
-- on top of 'mask': the action @f@ is run inside 'mask' and the restore
-- function that 'mask' passes in is ignored.
block f = mask $ \ _ -> f
#endif

-- | Debugging helper: decode a UTF-8 encoded 'ByteString' and print it with
-- 'putStrLn'.
tracebs :: ByteString -> IO ()
tracebs bs = putStrLn (U.toString bs)

-- | Same as 'tracebs', but prints with 'putStr'.
tracebs' :: ByteString -> IO ()
tracebs' bs = putStr (U.toString bs)

-- | State of a single connection.
data RedisState = RedisState
    { server          :: (String, String)
      -- ^ hostname and port pair
    , database        :: Int
      -- ^ currently selected database
    , handle          :: IO.Handle
      -- ^ real network connection
    , isSubscribed    :: Int
      -- ^ PUB/SUB mode indicator; 'newRedis' initialises it to 0
      -- (presumably a count of active subscriptions - TODO confirm)
    , renamedCommands :: Map ByteString ByteString
      -- ^ map of the renamed commands; 'lookupRenamed' treats an empty
      -- replacement name as \"command disabled\"
    }

-- | Redis connection descriptor
data Redis = Redis
    { r_lock :: RLock.RLock
      -- ^ lock acquired by 'takeState' and released by 'putState' /
      -- 'putStateUnmodified'
    , r_st   :: IORef RedisState
      -- ^ reference holding the current connection state
    }
    deriving Eq

-- | Create a connection descriptor for the given server address and handle.
--
-- A fresh lock is allocated and the initial state has database 0 selected,
-- an 'isSubscribed' value of 0 and no renamed commands.
newRedis :: (String, String) -> IO.Handle -> IO Redis
newRedis addr h = do
    lock  <- RLock.new
    stRef <- newIORef (RedisState addr 0 h 0 Map.empty)
    return (Redis lock stRef)

-- | Redis command variants. The constructor selects how 'sendCommand'
-- serialises the command.
data Command
    = CInline ByteString
      -- ^ a single command name, sent followed by CRLF
    | CMInline [ByteString]
      -- ^ command name and arguments, sent on one line separated by spaces
    | CBulk [ByteString] ByteString
      -- ^ command name and arguments on one line followed by the byte
      -- length of the payload, then the payload on its own line
    | CMBulk [ByteString]
      -- ^ command name and arguments, sent as an element count (@*N@)
      -- followed by each element prefixed with its byte length (@$len@)

-- | Redis reply variants
data Reply s
    = RTimeout
      -- ^ Timeout. Currently unused
    | RParseError String
      -- ^ Error converting value from ByteString. It's a client-side error.
    | ROk
      -- ^ \"Ok\" reply
    | RPong
      -- ^ Reply for the ping command
    | RQueued
      -- ^ Used inside multi-exec block
    | RError String
      -- ^ Some kind of server-side error
    | RInline s
      -- ^ Simple one-line reply
    | RInt Int
      -- ^ Integer reply
    | RBulk (Maybe s)
      -- ^ Multiline reply
    | RMulti (Maybe [Reply s])
      -- ^ Complex reply. It may consist of replies of various types
    deriving Eq

-- | Render a 'BS' value as a 'String': convert it with 'toBS' and decode the
-- resulting bytes as UTF-8.
showbs :: BS s => s -> String
showbs value = U.toString (toBS value)

-- | Human-readable rendering of replies. Payloads are rendered with
-- 'showbs'; the elements of a multi reply are rendered recursively and
-- separated by @", "@.
instance BS s => Show (Reply s) where
    show reply = case reply of
        RTimeout          -> "RTimeout"
        RParseError msg   -> "RParseError: " ++ msg
        ROk               -> "ROk"
        RPong             -> "RPong"
        RQueued           -> "RQueued"
        RError msg        -> "RError: " ++ msg
        RInline s         -> "RInline (" ++ showbs s ++ ")"
        RInt a            -> "RInt " ++ show a
        RBulk (Just s)    -> "RBulk " ++ showbs s
        RBulk Nothing     -> "RBulk Nil"
        RMulti (Just rs)  -> "RMulti [" ++ concat (intersperse ", " (map show rs)) ++ "]"
        RMulti Nothing    -> "RMulti Nil"

-- | PUB/SUB notification variants
data Message s
    = MSubscribe s Int
      -- ^ subscribed
    | MUnsubscribe s Int
      -- ^ unsubscribed
    | MPSubscribe s Int
      -- ^ pattern subscribed
    | MPUnsubscribe s Int
      -- ^ pattern unsubscribed
    | MMessage s s
      -- ^ message received
    | MPMessage s s s
      -- ^ message received by pattern
    deriving (Eq, Show)

-- Byte sequences used when talking to the server: the CRLF line terminator,
-- the space that separates inline arguments, and the one-character markers
-- that prefix protocol lines.
urn, uspace, uminus, uplus, ucolon, ubucks, uasterisk :: ByteString
-- | CRLF line terminator
urn       = U.fromString "\r\n"
-- | single space
uspace    = U.fromString " "
-- | @-@ marker
uminus    = U.fromString "-"
-- | @+@ marker
uplus     = U.fromString "+"
-- | @:@ marker
ucolon    = U.fromString ":"
-- | @$@ marker
ubucks    = U.fromString "$"
-- | @*@ marker
uasterisk = U.fromString "*"

-- | Write the CRLF terminator ('urn') to the given handle.
hPutRn :: IO.Handle -> IO ()
hPutRn h = h `B.hPut` urn
{-# INLINE hPutRn #-}

-- | Acquire the connection lock and return the current state.
--
-- Runs under 'block'. The lock stays held on return; it has to be given back
-- with 'putState' or 'putStateUnmodified'.
takeState :: Redis -> IO RedisState
takeState r = block (RLock.acquire (r_lock r) >> readIORef (r_st r))

-- | Store a new connection state and release the lock taken by 'takeState'.
--
-- The state is written only when the lock is currently held by the calling
-- thread; in every other case (lock free, or held by a different thread) an
-- error is raised and the stored state is left untouched.
--
-- Runs under 'block'.
putState :: Redis -> RedisState -> IO ()
putState r s = block $ do
    lstate <- RLock.state $ r_lock r
    mytid  <- myThreadId
    case lstate of
      -- The owner has to be compared explicitly: reusing the name @mytid@ in
      -- the pattern would only bind a new variable and match any owner.
      Just (owner, _)
        | owner == mytid -> do writeIORef (r_st r) s
                               RLock.release $ r_lock r
      _ -> error "putState: trying put state that was not took"

-- | Release the connection lock without touching the stored state.
putStateUnmodified :: Redis -> IO ()
putStateUnmodified = RLock.release . r_lock

-- | Run an action that may change the connection state and also yields a
-- result.
--
-- The state is taken with 'takeState'; on success the state returned by the
-- action is stored with 'putState'. If an exception escapes, the lock is
-- released with 'putStateUnmodified' and the stored state is not replaced.
inState :: Redis -> (RedisState -> IO (RedisState, a)) -> IO a
inState r action = bracketOnError (takeState r) giveBack run
  where
    giveBack _ = putStateUnmodified r
    run st = action st >>= \(st', result) ->
             putState r st' >> return result

-- | Like 'inState' for actions that only produce a new state: the state
-- returned by the action is stored with 'putState'; if an exception escapes,
-- the lock is released with 'putStateUnmodified' instead.
inState_ :: Redis -> (RedisState -> IO RedisState) -> IO ()
inState_ r action =
    bracketOnError (takeState r) (const (putStateUnmodified r)) $ \st -> do
        st' <- action st
        putState r st'

-- | Run an action with the current connection state. The lock is taken for
-- the duration of the action and always released afterwards with
-- 'putStateUnmodified'; the stored state is never written.
withState :: Redis -> (RedisState -> IO a) -> IO a
withState r action = bracket acquire release action
  where
    acquire   = takeState r
    release _ = putStateUnmodified r

-- | 'withState' with its arguments swapped.
withState' :: (RedisState -> IO a) -> Redis -> IO a
withState' action r = withState r action

-- | Write every chunk of the list to the handle, each one followed by a
-- single space (this includes the last chunk). Nothing is written for an
-- empty list.
send :: IO.Handle -> [ByteString] -> IO ()
send h = mapM_ putChunk
  where
    putChunk chunk = B.hPut h chunk >> B.hPut h uspace

-- | Translate a command name through the 'renamedCommands' map.
--
-- Names without an entry are returned unchanged. An entry whose replacement
-- is empty marks the command as disabled, in which case the result is an
-- 'error' value.
lookupRenamed :: RedisState -> ByteString -> ByteString
lookupRenamed r c
    | B.null renamed = error $ "Command " ++ (fromBS c :: String) ++ " is disabled"
    | otherwise      = renamed
  where
    renamed = Map.findWithDefault c c (renamedCommands r)

-- | Serialise a command, write it to the connection handle and flush the
-- handle.
--
-- The first element of every command is the command name; it is translated
-- with 'lookupRenamed', so sending a disabled command raises an error. The
-- translation is always evaluated before the first byte is written, so a
-- disabled command never leaves a partial request in the output stream.
--
-- Wire formats:
--
-- * 'CInline': the name followed by CRLF.
--
-- * 'CMInline': name and arguments written with 'send' (each followed by a
--   space), then CRLF.
--
-- * 'CBulk': name and arguments written with 'send', one more space, the
--   payload length in bytes, CRLF, the payload, CRLF.
--
-- * 'CMBulk': @*@ and the number of elements, CRLF, then for every element
--   @$@, its length in bytes, CRLF, the element, CRLF.
--
-- A 'CMInline', 'CBulk' or 'CMBulk' command without a command name is a
-- programming error and raises an 'error'.
sendCommand :: RedisState -> Command -> IO ()
sendCommand r (CInline bs) =
    let h   = handle r
        cmd = lookupRenamed r bs
    in B.hPut h cmd >> hPutRn h >> IO.hFlush h
sendCommand r (CMInline (l:ls)) =
    let h   = handle r
        cmd = lookupRenamed r l
    in send h (cmd:ls) >> hPutRn h >> IO.hFlush h
sendCommand r (CBulk (l:ls) bs) =
    let h    = handle r
        size = U.fromString $ show $ B.length bs
        cmd  = lookupRenamed r l
    in do send h (cmd:ls)   -- 'send' already leaves a trailing space
          B.hPut h uspace
          B.hPut h size
          hPutRn h
          B.hPut h bs
          hPutRn h
          IO.hFlush h
sendCommand r (CMBulk s@(c:cs)) =
    let h = handle r
        putArg bs = do B.hPut h ubucks
                       B.hPut h $ U.fromString $ show $ B.length bs
                       hPutRn h
                       B.hPut h bs
                       hPutRn h
        c' = lookupRenamed r c
    -- Force the lookup first: otherwise the "*N" header would already be in
    -- the handle when the error for a disabled command is raised.
    in c' `seq` do B.hPut h uasterisk
                   B.hPut h $ U.fromString $ show $ length s
                   hPutRn h
                   mapM_ putArg (c':cs)
                   IO.hFlush h
sendCommand _ (CMInline []) = error "sendCommand: empty CMInline command"
sendCommand _ (CBulk [] _)  = error "sendCommand: empty CBulk command"
sendCommand _ (CMBulk [])   = error "sendCommand: empty CMBulk command"

-- | 'sendCommand' with its arguments swapped.
sendCommand' :: Command -> RedisState -> IO ()
sendCommand' cmd st = sendCommand st cmd

-- | Read one complete reply from the connection handle.
--
-- The first line is read, CR/LF are stripped, and its first character selects
-- the reply kind:
--
-- * @-@ gives 'RError' with the rest of the line;
--
-- * @+@ gives 'ROk', 'RPong' or 'RQueued' for @OK@, @PONG@ and @QUEUED@, and
--   'RInline' for anything else;
--
-- * @:@ gives 'RInt';
--
-- * @$@ gives 'RBulk': a length of -1 means 'Nothing', otherwise that many
--   bytes plus the two terminator bytes are read from the handle;
--
-- * @*@ gives 'RMulti': a count of -1 means 'Nothing', otherwise that many
--   replies are read recursively.
--
-- Anything that cannot be interpreted is reported as 'RParseError' instead of
-- raising an exception or returning a value that fails when it is evaluated
-- later: an empty line, an unknown first character, a number that does not
-- parse, a length or count below -1, and a payload that 'fromBS' fails to
-- convert.
recv :: BS s => RedisState -> IO (Reply s)
recv r = do first <- trim `fmap` B.hGetLine h
            case U.uncons first of
              Just ('-', rest)  -> recv_err rest
              Just ('+', rest)  -> recv_inline rest
              Just (':', rest)  -> recv_int rest
              Just ('$', rest)  -> recv_bulk rest
              Just ('*', rest)  -> recv_multi rest
              _                 -> parseError "unexpected reply header" (show (U.toString first))
    where
      h = handle r
      trim = B.takeWhile (\c -> c /= 13 && c /= 10)

      -- Build the 'RParseError' reply used for every malformed input.
      parseError what detail = return $ RParseError $ "recv: " ++ what ++ " " ++ detail

      -- Convert the payload with 'fromBS' inside 'catch', so that a failing
      -- conversion becomes an 'RParseError' reply.
      safeFromBS constructor bs = (return $! constructor $! fromBS bs)
                                  `catch`
                                  (\e -> let msg = show (e :: SomeException)
                                         in return $ RParseError msg)

      -- Parse the leading decimal number of @bs@ and continue with it;
      -- whatever follows the number is ignored.
      withInt what bs k = case readInt bs of
                            Just (n, _) -> k n
                            Nothing     -> parseError ("malformed " ++ what) (show (U.toString bs))

      -- recv_err :: ByteString -> IO Reply
      recv_err rest = return $ RError $ U.toString rest

      -- recv_inline :: ByteString -> IO Reply
      recv_inline rest = case rest of
                           "OK"       -> return ROk
                           "PONG"     -> return RPong
                           "QUEUED"   -> return RQueued
                           _          -> safeFromBS RInline rest

      -- recv_int :: ByteString -> IO Reply
      recv_int rest = withInt "integer reply" rest (return . RInt)

      -- recv_bulk :: ByteString -> IO Reply
      recv_bulk rest = withInt "bulk length" rest recv_bulk_body

      -- recv_bulk_body :: Int -> IO Reply
      recv_bulk_body size
          | size == (-1) = return $ RBulk Nothing
          | size < 0     = parseError "invalid bulk length" (show size)
          | otherwise    = do body <- B.hGet h (size + 2)  -- payload plus CRLF
                              safeFromBS (RBulk . Just) (B.take size body)

      -- recv_multi :: ByteString -> IO Reply
      recv_multi rest = withInt "multi-bulk count" rest recv_multi_n

      -- recv_multi_n :: Int -> IO Reply
      recv_multi_n cnt
          | cnt == (-1) = return $ RMulti Nothing
          | cnt < 0     = parseError "invalid multi-bulk count" (show cnt)
          | otherwise   = (RMulti . Just) `fmap` sequence (replicate cnt (recv r))

-- | Wait for input on the connection handle. The 'Int' argument and the
-- 'Bool' result are those of 'IO.hWaitForInput', which is applied to the
-- handle stored in the state.
wait :: RedisState -> Int -> IO Bool
wait = IO.hWaitForInput . handle