{-
Copyright (c) 2010-2011, Alexander Bogdanov <andorn@gmail.com>
License: MIT
-}

{-# LANGUAGE CPP, OverloadedStrings, BangPatterns #-}
module Database.Redis.Internal where

import Prelude hiding (putStrLn, putStr, catch)
import Control.Concurrent (ThreadId, myThreadId)
import qualified Control.Concurrent.RLock as RLock
import Data.IORef
import qualified System.IO as IO
import System.IO.UTF8 (putStrLn, putStr)
import qualified Data.ByteString as B
import Data.ByteString (ByteString)
import Data.ByteString.Char8 (readInt)
import qualified Data.ByteString.UTF8 as U
import Data.Maybe (fromJust, isNothing, isJust)
import Data.List (intersperse)
import qualified Data.Map as Map
import Data.Map (Map(..))
import Control.Monad (when)
import Control.Exception (bracket, bracketOnError, catch, SomeException)

import Database.Redis.ByteStringClass

#if __GLASGOW_HASKELL__ < 700
-- GHC before 7.0: use the block combinator shipped with Control.Exception.
import Control.Exception (block)
#else
-- GHC 7.0 and later: define block here in terms of mask.
import Control.Exception.Base (mask)
-- | Run an action with asynchronous exceptions masked.  The restore
-- function handed over by mask is ignored, so the whole action runs masked.
block f = mask $ \ _ -> f
#endif

-- | Debug helper: decode a UTF-8 'ByteString' and print it followed by a
-- newline.
tracebs :: ByteString -> IO ()
tracebs bs = putStrLn (U.toString bs)

-- | Debug helper: like 'tracebs' but without the trailing newline.
tracebs' :: ByteString -> IO ()
tracebs' bs = putStr (U.toString bs)

-- | Mutable part of a connection, kept behind the 'IORef' of a 'Redis' value.
data RedisState = RedisState
    { server          :: (String, String)           -- ^ hostname and port pair
    , database        :: Int                        -- ^ currently selected database
    , handle          :: IO.Handle                  -- ^ real network connection
    , isSubscribed    :: Int                        -- ^ PUB/SUB mode marker
    , renamedCommands :: Map ByteString ByteString  -- ^ map of the renamed commands
    }

-- | Redis connection descriptor: the connection state plus the lock that is
-- held while the state is in use.
data Redis = Redis
    { r_lock :: RLock.RLock       -- ^ lock acquired by 'takeState'
    , r_st   :: IORef RedisState  -- ^ reference holding the connection state
    } deriving Eq

-- | Build a connection descriptor for an already opened handle.
--
-- The initial state has database 0 selected, the PUB/SUB marker set to 0 and
-- no renamed commands.
newRedis :: (String, String) -> IO.Handle -> IO Redis
newRedis addr h = do
    lock <- RLock.new
    ref  <- newIORef RedisState { server          = addr
                                , database        = 0
                                , handle          = h
                                , isSubscribed    = 0
                                , renamedCommands = Map.empty
                                }
    return Redis { r_lock = lock, r_st = ref }

-- | Redis command variants.  Each constructor corresponds to one of the
-- wire formats produced by 'sendCommand'.
data Command
    = CInline ByteString            -- ^ a single pre-formatted command line
    | CMInline [ByteString]         -- ^ command name and arguments on one line
    | CBulk [ByteString] ByteString -- ^ command line followed by a sized payload
    | CMBulk [ByteString]           -- ^ every element sent as a sized chunk

-- | Redis reply variants
data Reply s
    = RTimeout                  -- ^ Timeout. Currently unused
    | RParseError String        -- ^ Error converting a value from ByteString. It is a client-side error.
    | ROk                       -- ^ \"Ok\" reply
    | RPong                     -- ^ Reply for the ping command
    | RQueued                   -- ^ Used inside a multi-exec block
    | RError String             -- ^ Some kind of server-side error
    | RInline s                 -- ^ Simple one-line reply
    | RInt Int                  -- ^ Integer reply
    | RBulk (Maybe s)           -- ^ Multiline reply
    | RMulti (Maybe [Reply s])  -- ^ Complex reply. It may consist of replies of various types
    deriving Eq

-- | Render any 'BS' value as a 'String' by converting it to a 'ByteString'
-- and decoding that as UTF-8.
showbs :: BS s => s -> String
showbs value = U.toString (toBS value)

-- | Human-readable rendering of replies; payloads are shown via 'showbs'.
instance BS s => Show (Reply s) where
    show reply = case reply of
        RTimeout          -> "RTimeout"
        RParseError msg   -> "RParseError: " ++ msg
        ROk               -> "ROk"
        RPong             -> "RPong"
        RQueued           -> "RQueued"
        RError msg        -> "RError: " ++ msg
        RInline s         -> "RInline (" ++ showbs s ++ ")"
        RInt a            -> "RInt " ++ show a
        RBulk (Just s)    -> "RBulk " ++ showbs s
        RBulk Nothing     -> "RBulk Nil"
        -- nested replies are rendered recursively, separated by ", "
        RMulti (Just rs)  -> "RMulti [" ++ concat (intersperse ", " (map show rs)) ++ "]"
        RMulti Nothing    -> "RMulti Nil"

-- | PUB/SUB notifications.
data Message s
    = MSubscribe s Int     -- ^ subscribed
    | MUnsubscribe s Int   -- ^ unsubscribed
    | MPSubscribe s Int    -- ^ pattern subscribed
    | MPUnsubscribe s Int  -- ^ pattern unsubscribed
    | MMessage s s         -- ^ message received
    | MPMessage s s s      -- ^ message received by pattern
    deriving (Eq, Show)

-- Protocol tokens, encoded once and shared by the send/receive code.

-- | Line terminator.
urn :: ByteString
urn = U.fromString "\r\n"

-- | Separator between tokens of an inline command.
uspace :: ByteString
uspace = U.fromString " "

uminus :: ByteString
uminus = U.fromString "-"

uplus :: ByteString
uplus = U.fromString "+"

ucolon :: ByteString
ucolon = U.fromString ":"

-- | Prefix of a sized chunk.
ubucks :: ByteString
ubucks = U.fromString "$"

-- | Prefix of a multi-chunk count.
uasterisk :: ByteString
uasterisk = U.fromString "*"

-- | Write the line terminator 'urn' to the given handle.
hPutRn :: IO.Handle -> IO ()
hPutRn hdl = hdl `B.hPut` urn
{-# INLINE hPutRn #-}

-- | Acquire the connection lock and return the current state.  Both steps
-- run inside 'block'.  The lock stays held until 'putState' or
-- 'putStateUnmodified' releases it.
takeState :: Redis -> IO RedisState
takeState r = block (RLock.acquire (r_lock r) >> readIORef (r_st r))

-- | Store a new connection state and release the lock taken by 'takeState'.
--
-- The state is written only when the calling thread is the current owner of
-- the lock.  Otherwise an error is raised and neither the stored state nor
-- the lock is touched.
putState :: Redis -> RedisState -> IO ()
putState r s = block $ do lstate <- RLock.state $ r_lock r
                          mytid <- myThreadId
                          case lstate of
                            -- The owner has to be compared explicitly: a bare
                            -- variable in the pattern would match any thread.
                            Just (owner, _)
                              | owner == mytid -> do writeIORef (r_st r) s
                                                     RLock.release $ r_lock r
                            _ -> error "putState: trying put state that was not took"

-- | Release the connection lock without writing a new state.
putStateUnmodified :: Redis -> IO ()
putStateUnmodified = RLock.release . r_lock

-- | Run an action that may replace the connection state and also yields a
-- result.  The state is taken with 'takeState'; on success the new state is
-- stored with 'putState', and if the action throws the lock is released with
-- 'putStateUnmodified', leaving the stored state as it was.
inState :: Redis -> (RedisState -> IO (RedisState, a)) -> IO a
inState r action = bracketOnError acquire rollback commit
  where
    acquire    = takeState r
    rollback _ = putStateUnmodified r
    commit old = action old >>= \(new, result) -> putState r new >> return result

-- | Variant of 'inState' for actions that only produce a new state.
inState_ :: Redis -> (RedisState -> IO RedisState) -> IO ()
inState_ r action =
    bracketOnError (takeState r) (const $ putStateUnmodified r) $ \old -> do
        new <- action old
        putState r new

-- | Run a read-only action on the connection state.  The lock is taken for
-- the duration of the action and always released afterwards; the stored
-- state is never rewritten.
withState :: Redis -> (RedisState -> IO a) -> IO a
withState r = bracket (takeState r) (const $ putStateUnmodified r)

-- | 'withState' with its arguments swapped.
withState' :: (RedisState -> IO a) -> Redis -> IO a
withState' action r = withState r action

-- | Write every chunk to the handle, each one followed by a single space
-- (including the last one).  Nothing is written for an empty list.
send :: IO.Handle -> [ByteString] -> IO ()
send h = mapM_ putChunk
  where putChunk chunk = B.hPut h chunk >> B.hPut h uspace

-- | Translate a command name through the 'renamedCommands' table.
--
-- Names without an entry are returned unchanged.  If the resulting name is
-- empty, the command is considered disabled and an error is raised.
lookupRenamed :: RedisState -> ByteString -> ByteString
lookupRenamed r c
    | B.null actual = error $ "Command " ++ (fromBS c :: String) ++ " is disabled"
    | otherwise     = actual
  where
    actual = maybe c id $ Map.lookup c (renamedCommands r)

-- | Serialise a command, write it to the connection handle and flush.
--
-- The command name (the first element, or the whole line for 'CInline') is
-- passed through 'lookupRenamed' first, so a disabled command raises an
-- error.  That lookup is always evaluated before any byte reaches the
-- handle, so a failing command never leaves a partial request in the output
-- buffer.
--
-- Wire formats:
--
-- * 'CInline'  - the line followed by CRLF
--
-- * 'CMInline' - space separated tokens followed by CRLF
--
-- * 'CBulk'    - space separated tokens, the payload size, CRLF, the
--   payload, CRLF
--
-- * 'CMBulk'   - @*@ and the number of chunks, then every chunk as @$@,
--   its size, CRLF, the chunk, CRLF
--
-- A command with an empty token list has no name and is rejected with an
-- error.
sendCommand :: RedisState -> Command -> IO ()
sendCommand r (CInline bs) = B.hPut h (lookupRenamed r bs) >> hPutRn h >> IO.hFlush h
  where
    h = handle r
sendCommand r (CMInline (l:ls)) = send h (lookupRenamed r l : ls) >> hPutRn h >> IO.hFlush h
  where
    h = handle r
sendCommand r (CBulk (l:ls) bs) = do
    send h (lookupRenamed r l : ls)
    B.hPut h uspace
    B.hPut h size
    hPutRn h
    B.hPut h bs
    hPutRn h
    IO.hFlush h
  where
    h    = handle r
    size = U.fromString $ show $ B.length bs
sendCommand r (CMBulk s@(c:cs)) =
    -- Force the lookup up front: otherwise the header "*N\r\n$" would already
    -- be buffered when a disabled command raises its error, and would be sent
    -- in front of the next command.
    c' `seq` do
        B.hPut h uasterisk
        B.hPut h $ U.fromString $ show $ length s
        hPutRn h
        mapM_ sendChunk (c':cs)
        IO.hFlush h
  where
    h  = handle r
    c' = lookupRenamed r c
    sendChunk chunk = do B.hPut h ubucks
                         B.hPut h $ U.fromString $ show $ B.length chunk
                         hPutRn h
                         B.hPut h chunk
                         hPutRn h
sendCommand _ _ = error "sendCommand: empty command (no command name given)"

-- | 'sendCommand' with its arguments swapped.
sendCommand' :: Command -> RedisState -> IO ()
sendCommand' cmd st = sendCommand st cmd

-- | Read and decode one complete reply from the connection.
--
-- The first line is cut at its line terminator and dispatched on the leading
-- type character: @-@ error, @+@ inline, @:@ integer, @$@ bulk and @*@
-- multi-bulk (whose elements are read recursively).
--
-- A reply that does not follow this format - an empty line, an unknown type
-- character, an unparsable number, or a negative length other than @-1@ -
-- is reported with 'error' and a descriptive message at the moment it is
-- read.  A payload that 'fromBS' fails to convert yields 'RParseError'.
recv :: BS s => RedisState -> IO (Reply s)
recv r = do first <- trim `fmap` B.hGetLine h
            case U.uncons first of
              Just ('-', rest)  -> recv_err rest
              Just ('+', rest)  -> recv_inline rest
              Just (':', rest)  -> recv_int rest
              Just ('$', rest)  -> recv_bulk rest
              Just ('*', rest)  -> recv_multi rest
              _                 -> protocolError $ "unexpected reply line " ++ show first
    where
      h = handle r
      -- keep everything before the first CR or LF byte
      trim = B.takeWhile (\c -> c /= 13 && c /= 10)

      -- protocolError :: String -> IO a
      protocolError msg = error $ "recv: " ++ msg

      -- parseInt :: String -> ByteString -> IO Int
      -- Parse the leading integer of a header line; as before, anything that
      -- follows the number is ignored.
      parseInt what bs = case readInt bs of
                           Just (n, _) -> return n
                           Nothing     -> protocolError $ "malformed " ++ what ++ " " ++ show bs

      -- Conversion failures are turned into an RParseError reply instead of
      -- propagating as exceptions.
      safeFromBS constructor bs = (return $! constructor $! fromBS bs)
                                  `catch`
                                  (\e -> let msg = show (e :: SomeException)
                                         in return $ RParseError msg)

      -- recv_err :: ByteString -> IO Reply
      recv_err rest = return $ RError $ U.toString rest

      -- recv_inline :: ByteString -> IO Reply
      recv_inline rest = case rest of
                           "OK"       -> return ROk
                           "PONG"     -> return RPong
                           "QUEUED"   -> return RQueued
                           _          -> safeFromBS RInline rest

      -- recv_int :: ByteString -> IO Reply
      recv_int rest = RInt `fmap` parseInt "integer reply" rest

      -- recv_bulk :: ByteString -> IO Reply
      recv_bulk rest = do size <- parseInt "bulk length" rest
                          body <- recv_bulk_body size
                          maybe (return $ RBulk Nothing) (safeFromBS (RBulk . Just)) body

      -- recv_bulk_body :: Int -> IO (Maybe ByteString)
      -- A length of -1 denotes a nil payload.
      recv_bulk_body size
        | size == (-1) = return Nothing
        | size < 0     = protocolError $ "negative bulk length " ++ show size
        | otherwise    = do body <- B.hGet h (size + 2)  -- payload plus trailing CRLF
                            return $ Just $ B.take size body

      -- recv_multi :: ByteString -> IO Reply
      recv_multi rest = do cnt <- parseInt "multi-bulk count" rest
                           RMulti `fmap` recv_multi_n cnt

      -- recv_multi_n :: Int -> IO (Maybe [Reply])
      -- A count of -1 denotes a nil list; otherwise read exactly n replies.
      recv_multi_n n
        | n == (-1) = return Nothing
        | n < 0     = protocolError $ "negative multi-bulk count " ++ show n
        | otherwise = Just `fmap` mapM (const $ recv r) [1 .. n]

-- | Call 'IO.hWaitForInput' on the connection handle with the given timeout
-- argument and return its result.
wait :: RedisState -> Int -> IO Bool
wait = IO.hWaitForInput . handle