module Database.Redis.Internal where
import Prelude hiding (putStrLn, catch)
import Control.Concurrent (ThreadId, myThreadId)
import Control.Concurrent.MVar
import Data.IORef
import qualified Network.Socket as S
import qualified System.IO as IO
import System.IO.UTF8 (putStrLn)
import qualified Data.ByteString as B
import Data.ByteString (ByteString)
import Data.ByteString.Char8 ()
import qualified Data.ByteString.UTF8 as U
import Data.Maybe (fromJust, isNothing, isJust)
import Data.List (intersperse)
import Control.Monad (when)
import Control.Exception (block, bracket, bracketOnError, catch, SomeException)
import Database.Redis.ByteStringClass
-- | Debug helper: decode a 'ByteString' as UTF-8 and print it on its own line.
tracebs :: ByteString -> IO ()
tracebs = putStrLn . U.toString
-- | Per-connection state, stored in 'r_st' of a 'Redis' and accessed under its lock.
data RedisState = RedisState
    { server       :: (String, String) -- ^ NOTE(review): presumably (host, port); confirm against the connect code
    , database     :: Int              -- ^ database number; how it is selected is not visible in this chunk
    , handle       :: IO.Handle        -- ^ connection handle written by 'sendCommand' and read by 'recv' / 'wait'
    , isSubscribed :: Int              -- ^ NOTE(review): looks like a pub/sub subscription count; confirm
    }
-- | Shared connection object.
--
-- The 'RedisState' lives in 'r_st' and is guarded by a re-entrant lock built
-- from 'r_lock' and 'r_lock_cnt' (see 'takeState' / 'putState').  The derived
-- 'Eq' compares the 'MVar' / 'IORef' cells themselves, i.e. identity, not the
-- state they contain.
data Redis = Redis
    { r_lock_cnt :: MVar (Maybe (ThreadId, Int)) -- ^ owning thread and its nesting depth, or 'Nothing' when unowned
    , r_lock     :: MVar ()                      -- ^ full while the lock is free, empty while some thread holds it
    , r_st       :: IORef RedisState             -- ^ the guarded connection state
    }
    deriving Eq
-- | A request in one of the wire encodings emitted by 'sendCommand'.
data Command
    = CInline ByteString             -- ^ one pre-formatted line, followed by CRLF
    | CMInline [ByteString]          -- ^ words, each followed by a space, then CRLF
    | CBulk [ByteString] ByteString  -- ^ inline words, a space, the payload length and CRLF, then payload and CRLF
    | CMBulk [ByteString]            -- ^ @*<count>@ CRLF, then each argument as @$<len>@ CRLF payload CRLF
-- | A decoded server reply, parameterised over the payload type @s@.
--
-- No datatype context is placed on @s@: such contexts need the deprecated
-- DatatypeContexts extension and give consumers no extra guarantees.  Code
-- that decodes payloads (e.g. 'recv') carries its own @BS s@ constraint.
data Reply s
    = RTimeout                  -- ^ NOTE(review): not produced by 'recv'; presumably set by a timed read elsewhere
    | RParseError String        -- ^ the reply could not be parsed or its payload could not be decoded
    | ROk                       -- ^ inline status @OK@
    | RPong                     -- ^ inline status @PONG@
    | RQueued                   -- ^ inline status @QUEUED@
    | RError String             -- ^ error line (@-@ prefix)
    | RInline s                 -- ^ any other inline status line (@+@ prefix)
    | RInt Int                  -- ^ integer reply (@:@ prefix)
    | RBulk (Maybe s)           -- ^ bulk reply (@$@ prefix); 'Nothing' for the nil bulk
    | RMulti (Maybe [Reply s])  -- ^ multi-bulk reply (@*@ prefix); 'Nothing' for the nil multi-bulk
    deriving Eq
-- | Render a payload for display: encode it with 'toBS', then decode as UTF-8.
showbs :: BS s => s -> String
showbs payload = U.toString (toBS payload)
-- | Human-readable rendering of replies, intended for debugging output.
-- Payloads are rendered through 'showbs'; nested multi-bulk items are
-- rendered recursively and separated by ", ".
instance BS s => Show (Reply s) where
    show reply = case reply of
        RTimeout            -> "RTimeout"
        RParseError msg     -> "RParseError: " ++ msg
        ROk                 -> "ROk"
        RPong               -> "RPong"
        RQueued             -> "RQueued"
        RError msg          -> "RError: " ++ msg
        RInline s           -> "RInline (" ++ showbs s ++ ")"
        RInt n              -> "RInt " ++ show n
        RBulk Nothing       -> "RBulk Nothing"
        RBulk (Just s)      -> "RBulk " ++ showbs s
        RMulti Nothing      -> "RMulti Nil"
        RMulti (Just items) -> "RMulti [" ++ concat (intersperse ", " (map show items)) ++ "]"
-- | A pub/sub notification, parameterised over the payload type @s@.
--
-- NOTE(review): constructor names presumably mirror the Redis pub/sub reply
-- kinds (subscribe, unsubscribe, psubscribe, punsubscribe, message, pmessage).
-- The code that builds them is not in this chunk; confirm the meaning of the
-- fields there.  As with 'Reply', no (deprecated) datatype context is placed
-- on @s@; functions that need @BS s@ state it themselves.
data Message s
    = MSubscribe s Int
    | MUnsubscribe s Int
    | MPSubscribe s Int
    | MPUnsubscribe s Int
    | MMessage s s
    | MPMessage s s s
    deriving Show
-- | Protocol tokens, encoded to 'ByteString' once and shared by the writers.
urn, uspace, uminus, uplus, ucolon, ubucks, uasterisk :: ByteString
urn       = U.fromString "\r\n"
uspace    = U.fromString " "
uminus    = U.fromString "-"
uplus     = U.fromString "+"
ucolon    = U.fromString ":"
ubucks    = U.fromString "$"
uasterisk = U.fromString "*"
-- | Write the CRLF line terminator to a handle (no flush).
hPutRn :: IO.Handle -> IO ()
hPutRn = flip B.hPut urn
-- | Acquire the connection lock and return the current 'RedisState'.
--
-- The lock is re-entrant.  A thread that already owns it only bumps its
-- nesting count in 'r_lock_cnt'.  Every call must be paired with one
-- 'putState' or 'putStateUnmodified' from the same thread.
takeState :: Redis -> IO RedisState
takeState r = block $ do
    lcnt <- takeMVar $ r_lock_cnt r
    mytid <- myThreadId
    case lcnt of
      Nothing -> do
        -- r_lock_cnt says "unowned", but r_lock may still be empty.  Another
        -- thread may have just won r_lock in 'wait_for_lock' and not yet
        -- recorded itself in r_lock_cnt.  That is a transient, legitimate
        -- state, so wait our turn instead of treating it as corruption.
        free <- tryTakeMVar $ r_lock r
        case free of
          Just () -> take_n_put mytid 1
          Nothing -> wait_for_lock lcnt mytid
      Just (tid, cnt)
        | tid == mytid -> let !cnt' = cnt + 1
                          in take_n_put tid cnt'
        | otherwise    -> wait_for_lock lcnt mytid
  where
    -- Give r_lock_cnt back untouched, block until r_lock is released, then
    -- claim ownership.  While we hold r_lock nobody else can set r_lock_cnt
    -- to Just, so finding Just here means the lock invariant is broken.
    -- NOTE(review): an async exception delivered while blocked on the second
    -- takeMVar would leave r_lock held; that window is short because the
    -- releaser refills r_lock_cnt right after refilling r_lock.
    wait_for_lock lcnt mytid = do
        putMVar (r_lock_cnt r) lcnt
        takeMVar $ r_lock r
        owner <- takeMVar $ r_lock_cnt r
        when (isJust owner)
            $ error "takeState: r_lock is locked by me BUT r_lock_cnt is not Nothing"
        take_n_put mytid 1
    -- Read the state, then publish (owner, depth) and unlock r_lock_cnt.
    take_n_put tid cnt = do
        st <- readIORef $ r_st r
        putMVar (r_lock_cnt r) $ Just (tid, cnt)
        return st
-- | Release one level of the lock taken by 'takeState', storing @s@ as the state.
--
-- The new state is written on every call.  Only the outermost release (depth 1)
-- actually frees 'r_lock'.  Calling this without owning the lock raises
-- 'error'.
putState :: Redis -> RedisState -> IO ()
putState r s = block $ do
    lcnt <- takeMVar $ r_lock_cnt r
    mytid <- myThreadId
    -- On misuse, refill r_lock_cnt before failing.  Otherwise the error would
    -- also leave it empty, and every later takeMVar on it would block forever.
    let misuse msg = putMVar (r_lock_cnt r) lcnt >> error msg
    case lcnt of
      Nothing -> misuse "putState: trying put state that was not took"
      Just (tid, cnt)
        | tid /= mytid -> misuse "putState: trying put state that was not took by me"
        | otherwise -> do
            writeIORef (r_st r) s
            if cnt > 1
              then let !cnt' = cnt - 1
                   in putMVar (r_lock_cnt r) $ Just (tid, cnt')
              -- Outermost release: free r_lock first, then mark unowned.
              else do putMVar (r_lock r) ()
                      putMVar (r_lock_cnt r) Nothing
-- | Release one level of the lock taken by 'takeState' without writing state back.
--
-- Only the outermost release (depth 1) frees 'r_lock'.  Calling this without
-- owning the lock raises 'error'.
putStateUnmodified :: Redis -> IO ()
putStateUnmodified r = block $ do
    lcnt <- takeMVar $ r_lock_cnt r
    mytid <- myThreadId
    -- On misuse, refill r_lock_cnt before failing so the error does not also
    -- leave it empty and block every later takeMVar on it forever.
    let misuse msg = putMVar (r_lock_cnt r) lcnt >> error msg
    case lcnt of
      Nothing -> misuse "putStateUnmodified: trying put state that was not took"
      Just (tid, cnt)
        | tid /= mytid -> misuse "putStateUnmodified: trying put state that was not took by me"
        | cnt > 1      -> let !cnt' = cnt - 1
                          in putMVar (r_lock_cnt r) $ Just (tid, cnt')
        -- Outermost release: free r_lock first, then mark unowned.
        | otherwise    -> do putMVar (r_lock r) ()
                             putMVar (r_lock_cnt r) Nothing
-- | Run a state-transforming action under the lock.
--
-- On success the returned state is stored via 'putState' and the action's
-- result is returned.  If anything throws, the lock is released with
-- 'putStateUnmodified', leaving the stored state untouched.
inState :: Redis -> (RedisState -> IO (RedisState, a)) -> IO a
inState r action =
    bracketOnError (takeState r) (const (putStateUnmodified r)) run
  where
    run st = do
        (st', result) <- action st
        putState r st'
        return result
-- | Like 'inState' for actions that only produce a new state.
inState_ :: Redis -> (RedisState -> IO RedisState) -> IO ()
inState_ r action = bracketOnError acquire release update
  where
    acquire   = takeState r
    release _ = putStateUnmodified r
    update st = do
        st' <- action st
        putState r st'
-- | Run a read-only action under the lock.  The lock is always released with
-- 'putStateUnmodified', whether the action returns or throws.
withState :: Redis -> (RedisState -> IO a) -> IO a
withState r = bracket (takeState r) (const (putStateUnmodified r))
-- | 'withState' with the arguments swapped.
withState' :: (RedisState -> IO a) -> Redis -> IO a
withState' action r = withState r action
-- | Write each word followed by a single space (including after the last one).
-- No line terminator and no flush.
send :: IO.Handle -> [ByteString] -> IO ()
send h = mapM_ putWord
  where
    putWord w = B.hPut h w >> B.hPut h uspace
-- | Serialise a 'Command' onto the connection handle and flush it.
--
-- * 'CInline': the line, CRLF.
-- * 'CMInline': each word plus a space, CRLF.
-- * 'CBulk': the words (each plus a space), another space, the payload length,
--   CRLF, the payload, CRLF.
-- * 'CMBulk': @*<count>@ CRLF, then for each argument @$<len>@ CRLF payload CRLF.
sendCommand :: RedisState -> Command -> IO ()
sendCommand r cmd = emit cmd >> IO.hFlush h
  where
    h = handle r
    crlf = hPutRn h
    sizeOf = U.fromString . show . B.length

    bulkArg bs = do
        B.hPut h ubucks
        B.hPut h (sizeOf bs)
        crlf
        B.hPut h bs
        crlf

    emit (CInline bs) = B.hPut h bs >> crlf
    emit (CMInline ls) = send h ls >> crlf
    emit (CBulk lcmd bs) = do
        send h lcmd
        B.hPut h uspace
        B.hPut h (sizeOf bs)
        crlf
        B.hPut h bs
        crlf
    emit (CMBulk strings) = do
        B.hPut h uasterisk
        B.hPut h (U.fromString (show (length strings)))
        crlf
        mapM_ bulkArg strings
-- | 'sendCommand' with the arguments swapped.
sendCommand' :: Command -> RedisState -> IO ()
sendCommand' cmd r = sendCommand r cmd
-- | Read one complete reply from the connection.
--
-- The first byte of the reply line selects the kind: @-@ error, @+@ inline
-- status, @:@ integer, @$@ bulk, @*@ multi-bulk (whose items are read by
-- recursive calls).  A negative bulk length or multi-bulk count yields the
-- nil form ('RBulk' 'Nothing' / 'RMulti' 'Nothing').  Malformed input yields
-- 'RParseError' instead of crashing:
--
-- * an empty line or an unknown type byte;
-- * a length or count that is not a decimal integer;
-- * a payload that 'fromBS' fails to decode.
recv :: BS s => RedisState -> IO (Reply s)
recv r = do
    first <- trim `fmap` B.hGetLine h
    case U.uncons first of
      Just ('-', rest) -> recv_err rest
      Just ('+', rest) -> recv_inline rest
      Just (':', rest) -> recv_int rest
      Just ('$', rest) -> recv_bulk rest
      Just ('*', rest) -> recv_multi rest
      _                -> parse_error "unexpected reply line" first
  where
    h = handle r
    -- Drop the trailing CR/LF (bytes 13 and 10) that hGetLine leaves behind.
    trim = B.takeWhile (\c -> c /= 13 && c /= 10)

    -- Total replacement for 'read': a (possibly negative) decimal Int, or Nothing.
    read_int :: ByteString -> Maybe Int
    read_int bs = case reads (U.toString bs) of
                    [(n, "")] -> Just n
                    _         -> Nothing

    parse_error what bs = return $ RParseError $ "recv: " ++ what ++ ": " ++ U.toString bs

    -- Force payload decoding now so a failing 'fromBS' becomes an RParseError.
    safeFromBS constructor bs = (return $! constructor $! fromBS bs)
                                `catch`
                                (\e -> let msg = show (e :: SomeException)
                                       in return $ RParseError msg)

    recv_err rest = return $ RError $ U.toString rest

    recv_inline rest = case rest of
                         "OK"     -> return ROk
                         "PONG"   -> return RPong
                         "QUEUED" -> return RQueued
                         _        -> safeFromBS RInline rest

    recv_int rest = case read_int rest of
                      Just n  -> return $ RInt n
                      Nothing -> parse_error "bad integer reply" rest

    -- A negative length (the protocol uses -1) is the nil bulk.  Otherwise
    -- read the body plus its trailing CRLF and keep only the body.
    recv_bulk rest = case read_int rest of
      Nothing -> parse_error "bad bulk length" rest
      Just size
        | size < 0  -> return $ RBulk Nothing
        | otherwise -> do body <- B.hGet h (size + 2)
                          safeFromBS (RBulk . Just) (B.take size body)

    -- A negative count (the protocol uses -1) is the nil multi-bulk.
    -- Otherwise read exactly that many nested replies.
    recv_multi rest = case read_int rest of
      Nothing -> parse_error "bad multi-bulk count" rest
      Just cnt
        | cnt < 0   -> return $ RMulti Nothing
        | otherwise -> (RMulti . Just) `fmap` recv_n cnt

    recv_n 0 = return []
    recv_n n = do this <- recv r
                  others <- recv_n (n - 1)
                  return (this : others)
-- | Wait for input on the connection handle for up to the given timeout.
-- Returns whether input became available; timeout semantics are those of
-- 'IO.hWaitForInput'.
wait :: RedisState -> Int -> IO Bool
wait rs timeout = IO.hWaitForInput (handle rs) timeout