{-# LANGUAGE CPP, OverloadedStrings, BangPatterns #-}
module Database.Redis.Internal where
#if MIN_VERSION_utf8_string(1,0,0)
import Prelude hiding (catch)
#else
import Prelude hiding (catch, putStrLn, putStr)
import System.IO.UTF8 (putStrLn, putStr)
#endif
import Control.Concurrent (ThreadId, myThreadId)
import qualified Control.Concurrent.RLock as RLock
import Data.IORef
import qualified System.IO as IO
import qualified Data.ByteString as B
import Data.ByteString (ByteString)
import Data.ByteString.Char8 (readInt)
import qualified Data.ByteString.UTF8 as U
import Data.Maybe (fromJust, isNothing, isJust)
import Data.List (intersperse)
import qualified Data.Map as Map
import Data.Map (Map(..))
import Control.Monad (when)
import Control.Exception (bracket, bracketOnError, catch, SomeException)
import Database.Redis.ByteStringClass
#if __GLASGOW_HASKELL__ < 700
import Control.Exception (block)
#else
import Control.Exception.Base (mask)
-- | Compatibility replacement for the removed 'Control.Exception.block':
-- run the action with asynchronous exceptions masked. The restore function
-- supplied by 'mask' is deliberately ignored.
block :: IO a -> IO a
block action = mask (\_restore -> action)
#endif
-- | Debug helper: print a ByteString, decoded as UTF-8, followed by a newline.
tracebs :: ByteString -> IO ()
tracebs bs = putStrLn (U.toString bs)
-- | Debug helper: print a ByteString, decoded as UTF-8, without a trailing
-- newline.
tracebs' :: ByteString -> IO ()
tracebs' bs = putStr (U.toString bs)
-- | Mutable per-connection state; a 'Redis' value keeps it in an 'IORef'.
data RedisState = RedisState
  { server          :: (String, String)
    -- ^ address given to 'newRedis'; presumably (host, port) — confirm at call sites
  , database        :: Int
    -- ^ database index; 'newRedis' initialises it to 0
  , handle          :: IO.Handle
    -- ^ handle that commands are written to and replies are read from
  , isSubscribed    :: Int
    -- ^ 'newRedis' initialises it to 0; NOTE(review): presumably a pub/sub subscription count — confirm
  , renamedCommands :: Map ByteString ByteString
    -- ^ command name -> name actually sent; an empty replacement disables the command
  }
-- | A connection object: a lock together with the mutable state it guards.
-- 'takeState' acquires the lock; 'putState' / 'putStateUnmodified' release it.
data Redis = Redis
  { r_lock :: RLock.RLock       -- ^ lock guarding 'r_st'
  , r_st   :: IORef RedisState  -- ^ current connection state
  }
  deriving Eq
-- | Wrap an already-open handle into a fresh connection object.
-- The lock is created first. The initial state uses database 0,
-- @isSubscribed = 0@ and no renamed commands.
newRedis :: (String, String) -> IO.Handle -> IO Redis
newRedis server h =
  RLock.new >>= \lock ->
    fmap (Redis lock) (newIORef (RedisState server 0 h 0 Map.empty))
-- | An outgoing command in one of the wire encodings understood by
-- 'sendCommand'. The first word of each command is its name.
data Command
  = CInline ByteString             -- ^ a single word followed by CRLF
  | CMInline [ByteString]          -- ^ space-separated words followed by CRLF
  | CBulk [ByteString] ByteString  -- ^ words, then a length-prefixed payload
  | CMBulk [ByteString]            -- ^ every word sent as a multi-bulk argument
-- | A decoded server reply; @s@ is the payload type chosen by the caller.
data Reply s
  = RTimeout                  -- ^ not produced by 'recv'; presumably used by callers on timeout — confirm
  | RParseError String        -- ^ the reply (or its payload) could not be decoded
  | ROk                       -- ^ status reply "OK"
  | RPong                     -- ^ status reply "PONG"
  | RQueued                   -- ^ status reply "QUEUED"
  | RError String             -- ^ error reply (line starting with '-')
  | RInline s                 -- ^ any other status reply (line starting with '+')
  | RInt Int                  -- ^ integer reply (line starting with ':')
  | RBulk (Maybe s)           -- ^ bulk reply; 'Nothing' for a nil bulk
  | RMulti (Maybe [Reply s])  -- ^ multi-bulk reply; 'Nothing' for a nil multi-bulk
  deriving Eq
-- | Render any 'BS' value as a String by decoding its ByteString form as UTF-8.
showbs :: BS s => s -> String
showbs x = U.toString (toBS x)
-- | Debug rendering of replies. Payloads are shown via 'showbs' (UTF-8
-- decoding). The output is not Haskell syntax, so it is not meant for 'read'.
instance BS s => Show (Reply s) where
  show reply = case reply of
    RTimeout         -> "RTimeout"
    RParseError msg  -> "RParseError: " ++ msg
    ROk              -> "ROk"
    RPong            -> "RPong"
    RQueued          -> "RQueued"
    RError msg       -> "RError: " ++ msg
    RInline s        -> "RInline (" ++ showbs s ++ ")"
    RInt n           -> "RInt " ++ show n
    RBulk Nothing    -> "RBulk Nil"
    RBulk (Just s)   -> "RBulk " ++ showbs s
    RMulti Nothing   -> "RMulti Nil"
    RMulti (Just rs) -> "RMulti [" ++ concat (intersperse ", " (map show rs)) ++ "]"
-- | Presumably a publish/subscribe notification. NOTE(review): field meanings
-- are not visible in this module. They are presumably channel/pattern names,
-- message payloads and subscription counts — confirm.
data Message s
  = MSubscribe s Int
  | MUnsubscribe s Int
  | MPSubscribe s Int
  | MPUnsubscribe s Int
  | MMessage s s
  | MPMessage s s s
  deriving (Eq, Show)
-- Protocol tokens, encoded once as ByteStrings.
urn, uspace, uminus, uplus, ucolon, ubucks, uasterisk :: ByteString
urn       = U.fromString "\r\n"  -- CRLF line terminator
uspace    = U.fromString " "
uminus    = U.fromString "-"
uplus     = U.fromString "+"
ucolon    = U.fromString ":"
ubucks    = U.fromString "$"     -- bulk length prefix
uasterisk = U.fromString "*"     -- multi-bulk count prefix
-- | Write the CRLF line terminator to a handle.
hPutRn :: IO.Handle -> IO ()
hPutRn out = out `B.hPut` urn
{-# INLINE hPutRn #-}
-- | Acquire the connection lock and read the current state, with
-- asynchronous exceptions masked. The lock stays held. The caller must
-- release it with 'putState' or 'putStateUnmodified'.
takeState :: Redis -> IO RedisState
takeState r = block (RLock.acquire (r_lock r) >> readIORef (r_st r))
-- | Store a new state and release the connection lock, with asynchronous
-- exceptions masked.
--
-- Must be called by the thread that currently holds the lock (taken via
-- 'takeState'). Otherwise it fails with an 'error' and neither writes
-- the state nor touches the lock.
putState :: Redis -> RedisState -> IO ()
putState r s = block $ do
  lstate <- RLock.state (r_lock r)
  mytid <- myThreadId
  case lstate of
    -- The owner must be compared explicitly: binding it to a pattern
    -- variable (as in @Just (mytid, _)@) would shadow the caller's id
    -- and accept ANY holder of the lock.
    Just (owner, _) | owner == mytid -> do
      writeIORef (r_st r) s
      RLock.release (r_lock r)
    _ -> error "putState: trying put state that was not took"
-- | Release the connection lock without writing a new state.
putStateUnmodified :: Redis -> IO ()
putStateUnmodified = RLock.release . r_lock
-- | Run an action on the locked state, store the state it returns and hand
-- back its result. If the action throws, the lock is released and the
-- stored state is left untouched.
inState :: Redis -> (RedisState -> IO (RedisState, a)) -> IO a
inState r action =
  bracketOnError (takeState r) (const (putStateUnmodified r)) $ \st -> do
    (st', result) <- action st
    putState r st'
    return result
-- | Like 'inState' for actions that only produce a new state. If the
-- action throws, the lock is released and the stored state is left untouched.
inState_ :: Redis -> (RedisState -> IO RedisState) -> IO ()
inState_ r action =
  bracketOnError (takeState r) (const (putStateUnmodified r)) $ \st ->
    putState r =<< action st
-- | Run a read-only action on the locked state. The lock is always
-- released afterwards, and the state is never written back.
withState :: Redis -> (RedisState -> IO a) -> IO a
withState r = bracket (takeState r) (const (putStateUnmodified r))
-- | 'withState' with its arguments swapped (action first).
withState' :: (RedisState -> IO a) -> Redis -> IO a
withState' action r = withState r action
-- | Write each chunk followed by a single space. The last chunk also gets a
-- trailing space.
send :: IO.Handle -> [ByteString] -> IO ()
send h = mapM_ (\chunk -> B.hPut h chunk >> B.hPut h uspace)
-- | Translate a command name through 'renamedCommands'. Names without an
-- entry are returned unchanged. A name mapped to the empty string counts as
-- disabled: the result is then an 'error', raised when it is forced.
lookupRenamed :: RedisState -> ByteString -> ByteString
lookupRenamed r c
  | B.null renamed = error $ "Command " ++ (fromBS c :: String) ++ " is disabled"
  | otherwise      = renamed
  where
    renamed = Map.findWithDefault c c (renamedCommands r)
-- | Serialise a 'Command' onto the state's connection handle and flush it.
--
-- * 'CInline': the command name, then CRLF.
-- * 'CMInline': name and arguments, each followed by a space, then CRLF.
-- * 'CBulk': like 'CMInline', then a space, the payload length, CRLF, the
--   payload and a final CRLF. NOTE(review): 'send' already leaves a trailing
--   space, so two spaces precede the length; preserved as-is.
-- * 'CMBulk': a @*count@ header, then every word as a @$len@ bulk argument.
--
-- The first word of every command goes through 'lookupRenamed', which raises
-- an 'error' for disabled commands. The error is raised before anything is
-- written. Commands with no words at all are rejected with an 'error'.
sendCommand :: RedisState -> Command -> IO ()
sendCommand r command = writeCommand command >> IO.hFlush h
  where
    h = handle r

    -- Decimal rendering of a length for protocol headers.
    showLen :: Int -> ByteString
    showLen = U.fromString . show

    writeCommand (CInline bs) = do
      B.hPut h (lookupRenamed r bs)
      hPutRn h
    writeCommand (CMInline (l:ls)) = do
      send h (lookupRenamed r l : ls)
      hPutRn h
    writeCommand (CBulk (l:ls) payload) = do
      send h (lookupRenamed r l : ls)
      B.hPut h uspace
      B.hPut h (showLen (B.length payload))
      hPutRn h
      B.hPut h payload
      hPutRn h
    writeCommand (CMBulk ws@(c:cs)) =
      -- Force the (possibly rejected) command name BEFORE writing the
      -- header. Otherwise a disabled command would leave a partial
      -- multi-bulk request on the connection and desynchronise the protocol.
      let !c' = lookupRenamed r c
      in do B.hPut h uasterisk
            B.hPut h (showLen (length ws))
            hPutRn h
            mapM_ writeBulk (c' : cs)
    writeCommand _ = error "sendCommand: command has no words"

    -- One multi-bulk argument: "$<len>\r\n<bytes>\r\n".
    writeBulk bs = do
      B.hPut h ubucks
      B.hPut h (showLen (B.length bs))
      hPutRn h
      B.hPut h bs
      hPutRn h
-- | 'sendCommand' with its arguments swapped (command first).
sendCommand' :: Command -> RedisState -> IO ()
sendCommand' cmd r = sendCommand r cmd
-- | Read one reply from the connection in the given state.
--
-- The first character of the reply line selects its kind: @-@ error,
-- @+@ status, @:@ integer, @$@ bulk, @*@ multi-bulk (read recursively).
-- A negative bulk size or multi-bulk count yields a nil reply.
-- Malformed input yields 'RParseError' instead of crashing the caller. This
-- covers an empty line, an unknown marker, a non-numeric length, or a
-- payload that 'fromBS' cannot decode.
recv :: BS s => RedisState -> IO (Reply s)
recv r = do
  first <- trim `fmap` B.hGetLine h
  case U.uncons first of
    Just ('-', rest) -> recv_err rest
    Just ('+', rest) -> recv_inline rest
    Just (':', rest) -> recv_int rest
    Just ('$', rest) -> recv_bulk rest
    Just ('*', rest) -> recv_multi rest
    -- Empty line or unknown marker: report it rather than failing with an
    -- incomplete-pattern exception.
    _ -> return $ RParseError $ "unexpected reply: " ++ show (U.toString first)
  where
    h = handle r

    -- Keep everything before the first CR or LF.
    trim = B.takeWhile (\c -> c /= 13 && c /= 10)

    -- Decode a payload with 'fromBS'. Any exception raised while forcing
    -- the decoded value becomes an 'RParseError'.
    safeFromBS constructor bs =
      (return $! constructor $! fromBS bs)
        `catch` (\e -> let msg = show (e :: SomeException)
                       in return $ RParseError msg)

    -- Parse the leading decimal number of a header (trailing bytes are
    -- ignored, as before); report a parse error if there is none.
    withInt rest k = case readInt rest of
      Just (n, _) -> k n
      Nothing     -> return $ RParseError $
                       "invalid integer in reply: " ++ show (U.toString rest)

    recv_err rest = return $ RError $ U.toString rest

    recv_inline rest = case rest of
      "OK"     -> return ROk
      "PONG"   -> return RPong
      "QUEUED" -> return RQueued
      _        -> safeFromBS RInline rest

    recv_int rest = withInt rest (return . RInt)

    recv_bulk rest = withInt rest $ \size -> do
      body <- recv_bulk_body size
      maybe (return $ RBulk Nothing) (safeFromBS (RBulk . Just)) body

    -- Negative size: nil bulk. Otherwise read the payload plus its trailing
    -- CRLF and keep only the payload.
    recv_bulk_body size
      | size < 0  = return Nothing
      | otherwise = do body <- B.hGet h (size + 2)
                       return $ Just (B.take size body)

    recv_multi rest = withInt rest $ \cnt -> RMulti `fmap` recv_multi_n cnt

    -- Negative count: nil multi-bulk. Previously a count below -1
    -- recursed forever. Otherwise read exactly @n@ nested replies in order.
    recv_multi_n n
      | n < 0     = return Nothing
      | otherwise = Just `fmap` mapM (const (recv r)) [1 .. n]
-- | Wait for input on the connection handle for at most the given timeout.
-- The timeout is in milliseconds, as defined by 'IO.hWaitForInput'. Returns
-- 'True' if input became available.
wait :: RedisState -> Int -> IO Bool
wait rs timeoutMs = IO.hWaitForInput (handle rs) timeoutMs