{-# LANGUAGE OverloadedStrings #-}
module Database.Redis.Internal where

import Prelude hiding (putStrLn)
import Control.Concurrent (ThreadId, myThreadId)
import Control.Concurrent.MVar
import Data.IORef
import qualified Network.Socket as S
import qualified System.IO as IO
import System.IO.UTF8 (putStrLn)
import qualified Data.ByteString as B
import Data.ByteString (ByteString)
import Data.ByteString.Char8 ()
import qualified Data.ByteString.UTF8 as U
import Data.Maybe (fromJust, isNothing, isJust)
import Data.List (intersperse)
import Control.Monad (when)
import Control.Exception (block, bracket, bracketOnError, onException)

import Database.Redis.ByteStringClass

-- | Debug helper: decode a 'ByteString' as UTF-8 and print it on its own line.
tracebs :: ByteString -> IO ()
tracebs = putStrLn . U.toString

-- | Per-connection state.  It lives in the 'IORef' held by 'Redis' ('r_st')
-- and is read and replaced through 'takeState' / 'putState'.
data RedisState = RedisState { server :: (String, String), -- ^ hostname and port pair
                               database :: Int,            -- ^ currently selected database
                               handle :: IO.Handle,        -- ^ handle used for all protocol I/O ('sendCommand', 'recv')
                               isSubscribed :: Int         -- ^ PUB/SUB mode indicator; an Int, presumably the number of active subscriptions (0 = not subscribed) — TODO confirm
                             }

-- | Redis connection descriptor
--
-- The state is protected by a re-entrant lock built from two MVars:
--
--   * 'r_lock_cnt' holds the owning thread and its nesting depth, or
--     'Nothing' when no thread owns the connection;
--
--   * 'r_lock' is full while the connection is unowned and is emptied
--     for as long as some thread owns it;
--
--   * 'r_st' holds the current 'RedisState'.
--
-- The derived 'Eq' compares the MVar/IORef references themselves (identity),
-- not their contents.
data Redis = Redis {r_lock_cnt :: MVar (Maybe (ThreadId, Int)),
                    r_lock     :: MVar (),
                    r_st       :: IORef RedisState}
             deriving Eq

-- | Redis command variants, as serialised by 'sendCommand'.
data Command = CInline ByteString                 -- ^ one raw line, written as-is and terminated by CRLF
             | CMInline [ByteString]              -- ^ words, each followed by a space, then CRLF
             | CBulk [ByteString] ByteString      -- ^ words, then the payload length and CRLF, then the payload and CRLF
             | CMBulk [ByteString]                -- ^ multi-bulk: @*N@ header, then @$len@ CRLF data CRLF per argument

-- | Redis reply variants, as produced by 'recv'.
--
-- Inline and bulk payloads are converted to @s@ with 'fromBS'.
data BS s => Reply s = RTimeout               -- ^ Timeout. Currently unused (never produced by 'recv')
                     | ROk                    -- ^ @+OK@ status reply
                     | RPong                  -- ^ @+PONG@ status reply (answer to PING)
                     | RQueued                -- ^ @+QUEUED@ status reply, seen inside a MULTI/EXEC block
                     | RError String          -- ^ @-@ error reply; carries the server's message text
                     | RInline s              -- ^ any other @+@ status line
                     | RInt Int               -- ^ @:@ integer reply
                     | RBulk (Maybe s)        -- ^ @$@ bulk reply; 'Nothing' for a nil bulk (length -1)
                     | RMulti (Maybe [Reply s]) -- ^ @*@ multi-bulk reply; elements may be replies of any kind; 'Nothing' for a nil multi-bulk (count -1)
                       deriving Eq

-- | Render any 'BS' value as a 'String' by decoding its bytes as UTF-8.
showbs :: BS s => s -> String
showbs value = U.toString (toBS value)

-- | Debug rendering of replies.  Payloads are shown through 'showbs' (decoded
-- as UTF-8).  A nil multi-bulk reply is rendered as @RMulti Nothing@, in line
-- with @RBulk Nothing@, so it cannot be mistaken for an empty multi-bulk
-- reply (which renders as @RMulti []@).
instance BS s => Show (Reply s) where
    show RTimeout           = "RTimeout"
    show ROk                = "ROk"
    show RPong              = "RPong"
    show RQueued            = "RQueued"
    show (RError msg)       = "RError: " ++ msg
    show (RInline s)        = "RInline (" ++ showbs s ++ ")"
    show (RInt a)           = "RInt " ++ show a
    show (RBulk (Just s))   = "RBulk " ++ showbs s
    show (RBulk Nothing)    = "RBulk Nothing"
    show (RMulti (Just rs)) = "RMulti [" ++ commaSep rs ++ "]"
        where commaSep = concat . intersperse ", " . map show
    show (RMulti Nothing)   = "RMulti Nothing"

-- | PUB/SUB notifications.
--
-- NOTE(review): field meanings are not established anywhere in this file;
-- presumably @MSubscribe channel count@, @MUnsubscribe channel count@ and
-- @MMessage channel payload@ — TODO confirm against the code that builds them.
data (BS s) => Message s = MSubscribe s Int
                         | MUnsubscribe s Int
                         | MMessage s s
                           deriving Show

-- Protocol fragments, encoded once and reused by the writers below.
urn, uspace, uminus, uplus, ucolon, ubucks, uasterisk :: ByteString
urn       = U.fromString "\r\n"  -- line terminator
uspace    = U.fromString " "     -- word separator
uminus    = U.fromString "-"     -- error reply marker
uplus     = U.fromString "+"     -- status reply marker
ucolon    = U.fromString ":"     -- integer reply marker
ubucks    = U.fromString "$"     -- bulk length marker
uasterisk = U.fromString "*"     -- multi-bulk count marker

-- | Write the protocol line terminator (CRLF) to the handle.
hPutRn :: IO.Handle -> IO ()
hPutRn = flip B.hPut urn

-- | Acquire the connection lock for the calling thread and return the current
-- 'RedisState'.
--
-- The lock is re-entrant: if the calling thread already owns it, only the
-- nesting depth in 'r_lock_cnt' is incremented.  Every call must be paired
-- with 'putState' or 'putStateUnmodified'.
--
-- A thread never blocks on 'r_lock' while holding 'r_lock_cnt'.  This ordering
-- is what keeps the two-MVar protocol free of deadlocks.
takeState :: Redis -> IO RedisState
takeState r = block $ do
    mytid <- myThreadId
    lcnt  <- takeMVar (r_lock_cnt r)
    case lcnt of
      -- Re-entrant call by the current owner: bump the nesting depth.
      Just (tid, cnt) | tid == mytid -> take_n_put mytid $! cnt + 1
      -- Another thread owns the connection: hand the counter back and queue
      -- up on the main lock.
      Just _ -> putMVar (r_lock_cnt r) lcnt >> waitForLock mytid
      -- No owner recorded: take the main lock if it is free.  It may still be
      -- empty here.  A releasing owner refills 'r_lock' before it resets
      -- 'r_lock_cnt', so a waiter can already hold 'r_lock' while it is queued
      -- for 'r_lock_cnt' behind us.  Treat that as ordinary contention rather
      -- than failing.
      Nothing -> do
        free <- tryTakeMVar (r_lock r)
        case free of
          Just () -> take_n_put mytid 1
          Nothing -> putMVar (r_lock_cnt r) lcnt >> waitForLock mytid
  where
    -- Block until the main lock is ours, then record ourselves as the owner.
    waitForLock mytid = do
      takeMVar (r_lock r)
      -- If we are interrupted while waiting for the counter, return the main
      -- lock so it is not lost forever.
      lcnt <- takeMVar (r_lock_cnt r) `onException` putMVar (r_lock r) ()
      when (isJust lcnt) $ do
        -- Invariant broken: restore both MVars before reporting it.
        putMVar (r_lock_cnt r) lcnt
        putMVar (r_lock r) ()
        error "takeState: r_lock is locked by me BUT r_lock_cnt is not Nothing"
      take_n_put mytid 1
    -- Publish ownership as (thread, depth) and return the current state.
    take_n_put tid cnt = do st <- readIORef (r_st r)
                            putMVar (r_lock_cnt r) $ Just (tid, cnt)
                            return st

-- | Release one level of the lock taken by 'takeState', storing @s@ as the new
-- connection state.
--
-- The state is written only when the caller is the recorded owner.  When the
-- outermost level is released, 'r_lock' is refilled and 'r_lock_cnt' is reset
-- to 'Nothing'.  On misuse (lock not held, or held by another thread), the
-- counter is put back before the error is raised.  This keeps the lock usable
-- for its real owner instead of deadlocking every other thread.
putState :: Redis -> RedisState -> IO ()
putState r s = block $ do
    lcnt  <- takeMVar (r_lock_cnt r)
    mytid <- myThreadId
    case lcnt of
      Nothing -> misuse lcnt "putState: trying put state that was not took"
      Just (tid, cnt)
        | tid /= mytid -> misuse lcnt "putState: trying put state that was not took by me"
        | cnt > 1 -> do
            -- Nested release: keep ownership, drop one level.
            writeIORef (r_st r) s
            let cnt' = cnt - 1
            cnt' `seq` putMVar (r_lock_cnt r) (Just (tid, cnt'))
        | otherwise -> do
            -- Outermost release: free the main lock, then clear ownership.
            writeIORef (r_st r) s
            putMVar (r_lock r) ()
            putMVar (r_lock_cnt r) Nothing
  where
    -- Put back the counter we took, then fail.
    misuse lcnt msg = putMVar (r_lock_cnt r) lcnt >> error msg

-- | Release one level of the lock taken by 'takeState' without touching the
-- stored state.
--
-- When the outermost level is released, 'r_lock' is refilled and
-- 'r_lock_cnt' is reset to 'Nothing'.  On misuse (lock not held, or held by
-- another thread), the counter is put back before the error is raised, so the
-- lock stays usable for its real owner.
putStateUnmodified :: Redis -> IO ()
putStateUnmodified r = block $ do
    lcnt  <- takeMVar (r_lock_cnt r)
    mytid <- myThreadId
    case lcnt of
      Nothing -> misuse lcnt "putStateUnmodified: trying put state that was not took"
      Just (tid, cnt)
        | tid /= mytid -> misuse lcnt "putStateUnmodified: trying put state that was not took by me"
        -- Nested release: keep ownership, drop one level.
        | cnt > 1 -> let cnt' = cnt - 1
                     in cnt' `seq` putMVar (r_lock_cnt r) (Just (tid, cnt'))
        -- Outermost release: free the main lock, then clear ownership.
        | otherwise -> do putMVar (r_lock r) ()
                          putMVar (r_lock_cnt r) Nothing
  where
    -- Put back the counter we took, then fail.
    misuse lcnt msg = putMVar (r_lock_cnt r) lcnt >> error msg

-- | Run @action@ while holding the connection lock and store the state it
-- returns.  If @action@ throws, the lock is released and the stored state is
-- left as it was.
inState :: Redis -> (RedisState -> IO (RedisState, a)) -> IO a
inState r action = bracketOnError acquire releaseUnchanged run
    where acquire            = takeState r
          releaseUnchanged _ = putStateUnmodified r
          run st             = do (st', result) <- action st
                                  putState r st'
                                  return result

-- | Like 'inState', but @action@ produces only the new state.
inState_ :: Redis -> (RedisState -> IO RedisState) -> IO ()
inState_ r action =
    bracketOnError (takeState r) (const $ putStateUnmodified r) $ \st -> do
        st' <- action st
        putState r st'

-- | Run a read-only @action@ while holding the connection lock.  The lock is
-- always released afterwards and the stored state is never modified.
withState :: Redis -> (RedisState -> IO a) -> IO a
withState r = bracket (takeState r) release
    where release _ = putStateUnmodified r

-- | 'withState' with its arguments swapped.
withState' :: (RedisState -> IO a) -> Redis -> IO a
withState' action r = withState r action

-- | Write each word followed by a single space.  A space is written after the
-- last word too; no line terminator is added.
send :: IO.Handle -> [ByteString] -> IO ()
send h = mapM_ (\word -> B.hPut h word >> B.hPut h uspace)

-- | Serialise a 'Command' onto the connection's handle, then flush the handle.
--
-- Bulk and multi-bulk forms write each payload length as a decimal number and
-- terminate every protocol line with CRLF.
sendCommand :: RedisState -> Command -> IO ()
sendCommand st cmd = write cmd >> IO.hFlush h
    where
      h = handle st
      crlf = hPutRn h
      decimal n = U.fromString (show n)

      write (CInline bs)    = B.hPut h bs >> crlf
      write (CMInline ws)   = send h ws >> crlf
      write (CBulk ws body) = do send h ws
                                 B.hPut h uspace
                                 B.hPut h (decimal (B.length body))
                                 crlf
                                 B.hPut h body
                                 crlf
      write (CMBulk args)   = do B.hPut h uasterisk
                                 B.hPut h (decimal (length args))
                                 crlf
                                 mapM_ bulk args

      -- One multi-bulk argument: $<len> CRLF <data> CRLF
      bulk arg = do B.hPut h ubucks
                    B.hPut h (decimal (B.length arg))
                    crlf
                    B.hPut h arg
                    crlf

-- | 'sendCommand' with its arguments swapped.
sendCommand' :: Command -> RedisState -> IO ()
sendCommand' cmd st = sendCommand st cmd

-- | Read and decode one reply from the connection.
--
-- The first byte of the reply line selects the kind: @-@ error, @+@ status,
-- @:@ integer, @$@ bulk, @*@ multi-bulk (whose elements are read recursively).
-- A negative bulk length or multi-bulk count yields the nil variant
-- (@RBulk Nothing@ / @RMulti Nothing@).
--
-- Any other leading byte means the stream is out of sync with the protocol,
-- and an error naming the offending line is raised.  Malformed numeric fields
-- make 'read' fail.
recv :: BS s => RedisState -> IO (Reply s)
recv r = do first <- trim `fmap` B.hGetLine h
            case U.uncons first of
              Just ('-', rest)  -> recv_err rest
              Just ('+', rest)  -> recv_inline rest
              Just (':', rest)  -> recv_int rest
              Just ('$', rest)  -> recv_bulk rest
              Just ('*', rest)  -> recv_multi rest
              _                 -> error $ "recv: unexpected reply line: " ++ show (U.toString first)
    where
      h = handle r
      -- Cut the line at the first CR or LF.
      trim = B.takeWhile (\c -> c /= 13 && c /= 10)

      -- recv_err :: ByteString -> IO Reply
      recv_err rest = return $ RError $ U.toString rest

      -- recv_inline :: ByteString -> IO Reply
      recv_inline rest = return $ case rest of
                                    "OK"       -> ROk
                                    "PONG"     -> RPong
                                    "QUEUED"   -> RQueued
                                    _          -> RInline $ fromBS rest

      -- recv_int :: ByteString -> IO Reply
      recv_int rest = let reply = read (U.toString rest) :: Int
                      in return $ RInt reply

      -- recv_bulk :: ByteString -> IO Reply
      recv_bulk rest = let size = read (U.toString rest) :: Int
                       in do body <- recv_bulk_body size
                             return $ RBulk (fromBS `fmap` body)

      -- recv_bulk_body :: Int -> IO (Maybe ByteString)
      -- A negative length denotes a nil bulk reply.
      recv_bulk_body size
          | size < 0  = return Nothing
          | otherwise = do body <- B.hGet h (size + 2)  -- payload plus trailing CRLF
                           return $ Just (B.take size body)

      -- recv_multi :: ByteString -> IO Reply
      recv_multi rest = let cnt = read (U.toString rest) :: Int
                        in do bulks <- recv_multi_n cnt
                              return $ RMulti bulks

      -- recv_multi_n :: Int -> IO (Maybe [Reply])
      -- A negative count denotes a nil multi-bulk reply; otherwise read
      -- exactly n nested replies, in order.
      recv_multi_n n
          | n < 0     = return Nothing
          | otherwise = Just `fmap` mapM (\_ -> recv r) [1 .. n]