{-# LANGUAGE PatternGuards, NamedFieldPuns, RecordWildCards #-}


module Database.Cassandra.Pool where



import Control.Applicative ((<$>))
import Control.Concurrent.STM
import Control.Exception (SomeException, catch, mask, onException)
import Control.Monad (forM_, forever, join, liftM2, unless, when)
import Control.Monad.IO.Class (liftIO)
import Data.ByteString (ByteString)
import Data.List (partition)
import Data.Time.Clock (NominalDiffTime, UTCTime, diffUTCTime, getCurrentTime)
import Prelude hiding (catch)
import System.Mem.Weak (addFinalizer)
import System.IO (hClose, Handle(..))


import qualified Database.Cassandra.Thrift.Cassandra_Client as C
import Thrift.Transport
import Thrift.Transport.Handle
import Thrift.Transport.Framed
import Thrift.Protocol.Binary
import Network

------------------------------------------------------------------------------
-- | A round-robin pool of cassandra connections, one stripe per 'Server'.
type CPool = Pool Cassandra Server


-- | Where to find a Cassandra node: a host name and a port identifier.
type Server = (HostName, PortID)


-- | A localhost server with default configuration (127.0.0.1, port 9160).
defServer :: Server
defServer = (host, port)
  where
    host = "127.0.0.1"
    port = PortNumber 9160


-- | A single localhost server with default configuration; a one-element
-- list holding 'defServer'.
defServers :: [Server]
defServers = [defServer]


-- | Name of the keyspace a pool's connections are switched to.
type KeySpace = String


-- | One open connection to a Cassandra node. It is built in layers: the raw
-- handle, a framed transport over that handle, and a binary protocol over
-- the transport.
data Cassandra = Cassandra
  { cHandle :: Handle
    -- ^ Raw handle; this is what the pool closes when it discards the
    -- connection
  , cFramed :: FramedTransport Handle
    -- ^ Framed transport wrapping 'cHandle'
  , cProto  :: BinaryProtocol (FramedTransport Handle)
    -- ^ Binary protocol over 'cFramed', used to issue client calls
  }



-- | Create a pool of connections to a cluster of Cassandra boxes
--
-- Each box in the cluster will get up to n connections. The pool will send
-- queries in round-robin fashion to balance load on each box in the cluster.
--
-- Every connection is opened as handle -> framed transport -> binary
-- protocol, and is switched to the given 'KeySpace' before it is handed out.
-- If any step after opening the handle fails, the handle is closed before
-- the exception propagates, so a failed connection attempt does not leak it.
createCassandraPool
  :: [Server]
  -- ^ List of servers to connect to
  -> Int
  -- ^ Max connections per server (n)
  -> NominalDiffTime
  -- ^ Intended lifetime of each connection, in seconds (must be >= 0.5).
  -- NOTE(review): stored on each stripe but not enforced by this module.
  -> KeySpace
  -- ^ Each pool operates on a single KeySpace
  -> IO CPool
createCassandraPool servers n maxIdle ks = createPool cr dest n maxIdle servers
  where
    cr :: Server -> IO Cassandra
    cr srv = do
      h <- hOpen srv
      -- Close the freshly opened handle if framing or keyspace selection
      -- throws; otherwise the socket would be leaked.
      (ft, proto) <- initConnection h `onException` hClose h
      return $ Cassandra h ft proto

    -- Layer the framed transport and binary protocol over the raw handle
    -- and select the pool's keyspace on the new connection.
    initConnection
      :: Handle
      -> IO (FramedTransport Handle, BinaryProtocol (FramedTransport Handle))
    initConnection h = do
      ft <- openFramedTransport h
      let proto = BinaryProtocol ft
      C.set_keyspace (proto, proto) ks
      return (ft, proto)

    -- Discard a connection by closing its underlying handle.
    dest h = hClose $ cHandle h


------------------------------------------------------------------------------
-- Generic pool functionality - might want to factor out one day
--
------------------------------------------------------------------------------

-- | A round-robin ring of per-server 'Stripe's, shared through a 'TVar'.
newtype Pool a s = Pool { stripes :: TVar (Ring (Stripe a s)) }


-- | Build a 'Pool' with one 'Stripe' per server.
--
-- Arguments are validated eagerly. The call fails with 'error' (via
-- 'modError') if any of these holds:
--
-- * the idle time is below 0.5 seconds
-- * the per-server connection limit is below 1
-- * the server list is empty
--
-- The empty-list check matters because a ring cannot be built from no
-- stripes. Without it, the failure would only surface on the first use of
-- the pool.
createPool
  :: (s -> IO a)
  -- ^ open a connection to a server
  -> (a -> IO ())
  -- ^ close a connection
  -> Int
  -- ^ maximum connections per server
  -> NominalDiffTime
  -- ^ connection TTL, stored on every stripe
  -> [s]
  -- ^ servers, visited in round-robin order
  -> IO (Pool a s)
createPool cr dest n maxIdle servers = do
  when (maxIdle < 0.5) $
    modError "pool " $ "invalid idle time " ++ show maxIdle
  when (n < 1) $
    modError "pool " $ "invalid maximum resource count " ++ show n
  when (null servers) $
    modError "pool " "empty server list"
  stripes' <- mapM (createStripe cr dest n maxIdle) servers
  -- reaperId <- forkIO $ reaper destroy idleTime localPools
  -- addFinalizer p $ killThread reaperId
  -- Store the ring evaluated rather than as a thunk inside the TVar.
  tv <- atomically $ newTVar $! mkRing stripes'
  return $ Pool tv



-- | Run an action with a connection to the next server in round-robin order.
--
-- Each call advances the shared ring by one stripe, so successive calls
-- rotate through the servers. Checking the connection out and back in is
-- delegated to 'withStripe'.
withPool :: Pool a s -> (a -> IO b) -> IO b
withPool Pool{..} f = do
  Ring{..} <- atomically $ do
    r <- readTVar stripes
    -- Write the advanced ring in evaluated form, so the TVar never holds an
    -- unevaluated chain of 'next' applications.
    writeTVar stripes $! next r
    return r
  withStripe current f


-- | A non-empty cyclic sequence with a cursor.
--
-- 'current' is the focused element. 'used' holds the elements already
-- visited in this lap, most recent first. 'upcoming' holds the elements
-- still to come, in order.
data Ring a = Ring {
    current :: !a
  , used :: [a]
  , upcoming :: [a]
  }


-- | Build a ring focused on the first element of the list.
--
-- Calls 'error' on an empty list, since a ring always has a current element.
mkRing :: [a] -> Ring a
mkRing [] = error "Can't make a ring from empty list"
mkRing (a:as) = Ring a [] as


-- | Advance the cursor by one position. After the last element it wraps
-- back to the first; a one-element ring stays where it is.
next :: Ring a -> Ring a
next (Ring cur done (x:rest)) = Ring x (cur : done) rest
next (Ring cur done []) = case reverse done of
  -- Reversed 'done' followed by 'cur' is the whole ring in its original
  -- order; restart the lap at its head. Matching both cases keeps this total.
  []       -> Ring cur [] []
  (x:rest) -> Ring x [] (rest ++ [cur])


-- | The per-server part of a 'Pool': the connections to one server, plus the
-- actions used to open and close them.
data Stripe a s = Stripe {
    idle :: TVar [Connection a]
  -- ^ Idle connections, used as a LIFO stack: connections are both returned
  -- to and taken from the front of the list
  , inUse :: TVar Int
  -- ^ Number of connections this stripe currently owns (idle plus checked
  -- out); a new connection is only created while this is below 'cxns'
  , server :: s
  -- ^ Server this stripe is connected to
  , create :: s -> IO a
  -- ^ Create action
  , destroy :: (a -> IO ())
  -- ^ Destroy action
  , cxns :: Int
  -- ^ Max connections
  , ttl :: NominalDiffTime
  -- ^ TTL for each connection. NOTE(review): nothing in this module reads
  -- this field, so idle connections are never expired
  }


-- | Allocate an empty stripe for one server: no idle connections and a
-- connection count of zero. Both counters are created in one transaction.
createStripe
  :: (s -> IO a)
  -- ^ how to open a connection to the server
  -> (a -> IO ())
  -- ^ how to close a connection
  -> Int
  -- ^ maximum number of connections
  -> NominalDiffTime
  -- ^ connection TTL, stored as 'ttl'
  -> s
  -- ^ the server this stripe serves
  -> IO (Stripe a s)
createStripe cr dest n maxIdle s =
  atomically $ liftM2 build (newTVar []) (newTVar 0)
  where
    build idleVar countVar = Stripe
      { idle    = idleVar
      , inUse   = countVar
      , server  = s
      , create  = cr
      , destroy = dest
      , cxns    = n
      , ttl     = maxIdle
      }


-- | Run an action with a connection checked out of this stripe.
--
-- An idle connection is reused if one is available. Otherwise a new one is
-- created with 'create', provided the stripe owns fewer than 'cxns'
-- connections. If neither is possible, the transaction retries (blocks)
-- until one of those conditions holds.
--
-- When the action returns normally, the connection is pushed back onto the
-- idle list, stamped with the current time.
--
-- When the action throws:
--
-- * the connection is destroyed (errors from 'destroy' are ignored);
-- * its slot in 'inUse' is released;
-- * the original exception is rethrown.
--
-- Checkout, release and return run with async exceptions masked; only the
-- user action runs unmasked. An interrupt therefore cannot leak a
-- connection or a slot.
withStripe :: Stripe a s -> (a -> IO b) -> IO b
withStripe Stripe{..} f = mask $ \restore -> do
  res <- join . atomically $ do
    cs <- readTVar idle
    case cs of
      (Connection{..}:rest) -> writeTVar idle rest >> return (return cxn)
      [] -> do
        used <- readTVar inUse
        when (used >= cxns) retry
        writeTVar inUse $! used + 1
        -- Creation happens outside the transaction; give the slot back if
        -- it fails.
        return $ create server
          `onException` atomically (modifyTVar_ inUse (subtract 1))
  ret <- restore (f res) `onException` discardConn res
  now <- getCurrentTime
  atomically $ modifyTVar_ idle (Connection res now :)
  return ret
  where
    -- Best-effort teardown of a connection that may be broken. Failures from
    -- 'destroy' are swallowed so they cannot replace the exception that made
    -- us discard it. The slot is always handed back, so the stripe does not
    -- slowly lose capacity and end up blocking every caller.
    discardConn conn = do
      destroy conn `catch` ignoreErr
      atomically $ modifyTVar_ inUse (subtract 1)

    ignoreErr :: SomeException -> IO ()
    ignoreErr _ = return ()



-- | An idle pooled resource, together with the time it was last released
-- back to its stripe.
data Connection a = Connection
  { cxn     :: a
    -- ^ the pooled resource itself
  , lastUse :: UTCTime
    -- ^ when the resource was last returned to the idle list
  }



-- | Apply a function to a 'TVar's contents, forcing the new value to WHNF
-- before storing it, so no thunk is left inside the variable.
modifyTVar_ :: TVar a -> (a -> a) -> STM ()
modifyTVar_ var g = do
  old <- readTVar var
  let new = g old
  new `seq` writeTVar var new


-- | Abort with an 'error' message of the form @Data.Pool.<func>: <msg>@.
modError :: String -> String -> a
modError func msg = error (concat ["Data.Pool.", func, ": ", msg])