module Network.Riak.Connection.Pool
(
Pool
, client
, create
, idleTime
, maxConnections
, numStripes
, withConnection
) where
import Control.Applicative ((<$>))
import Control.Concurrent (forkIO, killThread, myThreadId, threadDelay)
import Control.Concurrent.STM
import Control.Exception (AsyncException, SomeException, catch, fromException,
                          mask, onException, throwIO)
import Control.Monad (forM_, forever, join, liftM2, unless, when)
import Data.Hashable (hash)
import Data.List (partition)
import Data.Time.Clock (NominalDiffTime, UTCTime, diffUTCTime, getCurrentTime)
import Data.Typeable (Typeable)
import Network.Riak.Connection.Internal (connect, disconnect, makeClientID)
import Network.Riak.Debug (debug)
import Network.Riak.Types (Client(clientID), Connection)
import Prelude hiding (catch)
import System.Mem.Weak (addFinalizer)
import qualified Data.Vector as V
-- | An idle connection sitting in a stripe's free list, stamped with the
-- time it was last handed back to the pool.
data Entry = Entry
  { connection :: Connection  -- ^ The idle connection itself.
  , lastUse    :: UTCTime     -- ^ When the connection was last returned.
  }
-- | A single stripe of a 'Pool'.
data LocalPool = LocalPool
  { connected :: TVar Int
    -- ^ Number of connections currently open for this stripe, counting
    -- both idle connections and ones that are checked out.
  , entries   :: TVar [Entry]
    -- ^ Idle connections available for reuse; the most recently returned
    -- connection is at the head of the list.
  }
-- | A striped pool of Riak connections.
data Pool = Pool
  { client         :: Client
    -- ^ Client settings used whenever a new connection is opened.
  , numStripes     :: Int
    -- ^ Number of independent stripes (sub-pools).
  , idleTime       :: NominalDiffTime
    -- ^ How long a connection may stay unused before the reaper closes it.
  , maxConnections :: Int
    -- ^ Upper bound on open connections in each stripe.
  , localPools     :: V.Vector LocalPool
    -- ^ The stripes themselves, one per index in @[0, numStripes)@.
  } deriving (Typeable)
-- | Renders the pool's configuration only; the live stripes are omitted.
instance Show Pool where
  show p = concat
    [ "Pool { client = ", show (client p)
    , ", numStripes = ", show (numStripes p)
    , ", idleTime = ", show (idleTime p)
    , ", maxConnections = ", show (maxConnections p)
    , "}"
    ]
-- | Two pools are equal when their configuration matches; the mutable
-- stripe state is not compared.
instance Eq Pool where
  a == b = settings a == settings b
    where
      settings p = (client p, numStripes p, idleTime p, maxConnections p)
-- | Create a new connection pool.
--
-- The pool is split into @numStripes@ independent stripes; 'withConnection'
-- picks a stripe by hashing the calling thread's 'ThreadId'. Each stripe
-- keeps at most @maxConnections@ connections open. A background reaper
-- thread, started here, closes connections that have been idle longer than
-- @idleTime@; it is killed by a finalizer attached to the returned 'Pool'.
--
-- Calls 'error' if @numStripes < 1@, @idleTime < 0.5@ or
-- @maxConnections < 1@.
create :: Client          -- ^ Client settings used to open new connections.
       -> Int             -- ^ Number of stripes; must be at least 1.
       -> NominalDiffTime -- ^ Idle time before an unused connection is
                          --   closed; must be at least 0.5 seconds.
       -> Int             -- ^ Maximum open connections per stripe; must be
                          --   at least 1.
       -> IO Pool
create client numStripes idleTime maxConnections = do
  -- Errors name this function ('create'); previously they claimed to come
  -- from a nonexistent "pool " function.
  when (numStripes < 1) $
    modError "create" $ "invalid stripe count " ++ show numStripes
  when (idleTime < 0.5) $
    modError "create" $ "invalid idle time " ++ show idleTime
  when (maxConnections < 1) $
    modError "create" $ "invalid maximum connection count " ++
                        show maxConnections
  localPools <- atomically . V.replicateM numStripes $
                liftM2 LocalPool (newTVar 0) (newTVar [])
  reaperId <- forkIO $ reaper idleTime localPools
  let p = Pool { client, numStripes, idleTime, maxConnections, localPools }
  -- NOTE(review): finalizers on ordinary (non-primitive) values are fragile;
  -- GHC may duplicate or unbox the record, so this can fire early or never.
  -- Attaching it to a primitive object held by the Pool (e.g. via
  -- mkWeakIORef) would be more reliable; left as-is to keep 'Pool' unchanged.
  addFinalizer p $ killThread reaperId
  return p
-- | Background loop that, every two seconds, closes and discards pooled
-- idle connections whose last use was more than @idleTime@ ago. Each stripe's
-- open-connection count is reduced by the number of connections removed.
--
-- Runs forever; 'create' forks it and arranges for it to be killed when the
-- 'Pool' is finalized.
reaper :: NominalDiffTime -> V.Vector LocalPool -> IO ()
reaper idleTime pools = forever $ do
  threadDelay (2 * 1000000)
  now <- getCurrentTime
  let isStale Entry{..} = now `diffUTCTime` lastUse > idleTime
  V.forM_ pools $ \LocalPool{..} -> do
    conns <- atomically $ do
      (stale,fresh) <- partition isStale <$> readTVar entries
      unless (null stale) $ do
        writeTVar entries fresh
        modifyTVar_ connected (subtract (length stale))
      return (map connection stale)
    forM_ conns $ \conn -> do
      debug "reaper" "closing idle connection"
      closeQuietly conn
  where
    -- Closing is best effort: ordinary failures from 'disconnect' are
    -- ignored. Asynchronous exceptions (such as the ThreadKilled sent by the
    -- pool's finalizer) are re-thrown so the reaper can actually be stopped;
    -- catching SomeException alone would swallow them.
    closeQuietly conn = disconnect conn `catch` \(e :: SomeException) ->
      case fromException e of
        Just (ae :: AsyncException) -> throwIO ae
        Nothing                     -> return ()
-- | Run an action with a connection borrowed from the pool.
--
-- The stripe is chosen by hashing the calling thread's 'ThreadId'. An idle
-- connection from that stripe is reused when available. Otherwise a new one
-- is opened if the stripe has fewer than 'maxConnections' open. At the limit,
-- this blocks until a connection is returned or closed.
--
-- If the action finishes normally, the connection goes back onto the
-- stripe's idle list. If it throws, the connection is closed (errors from
-- closing are ignored), the open-connection count is decremented, and the
-- exception is re-thrown.
--
-- Everything except the user action runs under 'mask'. This stops an
-- asynchronous exception (e.g. from 'killThread' or 'System.Timeout.timeout')
-- from landing between taking a connection and installing its cleanup
-- handler, or while the connection is being put back. Either would
-- otherwise leak the connection and its slot in the stripe's count.
withConnection :: Pool -> (Connection -> IO a) -> IO a
withConnection Pool{..} act = mask $ \restore -> do
  i <- ((`mod` numStripes) . hash) <$> myThreadId
  let LocalPool{..} = localPools V.! i
  conn <- join . atomically $ do
    ents <- readTVar entries
    case ents of
      (Entry{..}:es) -> writeTVar entries es >> return (return connection)
      [] -> do
        inUse <- readTVar connected
        -- '>=' rather than '==' so the limit still holds should the count
        -- ever overshoot. A blocked 'retry' is interruptible even while
        -- masked, so a waiting thread can still be killed.
        when (inUse >= maxConnections) retry
        writeTVar connected $! inUse + 1
        return $ do
          cid <- makeClientID
          connect client { clientID = cid }
            `onException` atomically (modifyTVar_ connected (subtract 1))
  ret <- restore (act conn) `onException` do
    disconnect conn `catch` \(_::SomeException) -> return ()
    atomically (modifyTVar_ connected (subtract 1))
  now <- getCurrentTime
  atomically $ modifyTVar_ entries (Entry conn now:)
  return ret
-- | Apply a function to the contents of a 'TVar', forcing the new value to
-- weak head normal form before storing it so thunks do not pile up inside
-- the variable.
modifyTVar_ :: TVar a -> (a -> a) -> STM ()
modifyTVar_ var f = do
  old <- readTVar var
  let new = f old
  new `seq` writeTVar var new
-- | Abort via 'error' with a message of the form
-- @Network.Riak.Connection.Pool.\<func\>: \<msg\>@.
modError :: String -> String -> a
modError func msg =
  error (concat ["Network.Riak.Connection.Pool.", func, ": ", msg])