-- This Source Code Form is subject to the terms of the Mozilla Public
-- License, v. 2.0. If a copy of the MPL was not distributed with this
-- file, You can obtain one at http://mozilla.org/MPL/2.0/.

{-# LANGUAGE MultiWayIf          #-}
{-# LANGUAGE OverloadedStrings   #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TemplateHaskell     #-}

module Database.CQL.IO.Pool
    ( Pool
    , create
    , destroy
    , purge
    , with

    , PoolSettings
    , defSettings
    , idleTimeout
    , maxConnections
    , maxTimeouts
    , poolStripes
    ) where

import Control.AutoUpdate
import Control.Concurrent
import Control.Concurrent.Async
import Control.Concurrent.STM
import Control.Exception
import Control.Lens ((^.), makeLenses, view)
import Control.Monad.IO.Class
import Control.Monad
import Data.Foldable (forM_, mapM_, find)
import Data.Function (on)
import Data.Hashable
import Data.IORef
import Data.Sequence (Seq, ViewL (..), (|>), (><))
import Data.Semigroup ((<>))
import Data.Time.Clock (UTCTime, NominalDiffTime, getCurrentTime, diffUTCTime)
import Data.Vector (Vector, (!))
import Database.CQL.IO.Connection (Connection)
import Database.CQL.IO.Exception (ConnectionError (..), ignore)
import Database.CQL.IO.Log

import qualified Data.Sequence as Seq
import qualified Data.Vector   as Vec

-----------------------------------------------------------------------------
-- API

data PoolSettings = PoolSettings
    { _idleTimeout    :: !NominalDiffTime
    , _maxConnections :: !Int
    , _maxTimeouts    :: !Int
    , _poolStripes    :: !Int
    }

data Pool = Pool
    { _createFn    :: !(IO Connection)
    , _destroyFn   :: !(Connection -> IO ())
    , _logger      :: !Logger
    , _settings    :: !PoolSettings
    , _maxRefs     :: !Int
    , _currentTime :: !(IO UTCTime)
    , _stripes     :: !(Vector Stripe)
    , _finaliser   :: !(IORef ())
    }

data Resource = Resource
    { tstamp   :: !UTCTime
    , refcnt   :: !Int
    , timeouts :: !Int
    , value    :: !Connection
    } deriving Show

data Box
    = New  !(IO Resource)
    | Used !Resource
    | Empty

data Stripe = Stripe
    { conns :: !(TVar (Seq Resource))
    , inUse :: !(TVar Int)
    }

makeLenses ''PoolSettings
makeLenses ''Pool

defSettings :: PoolSettings
defSettings = PoolSettings
    60 -- idle timeout
    2  -- max connections per stripe
    16 -- max timeouts per connection
    4  -- max stripes

create :: IO Connection -> (Connection -> IO ()) -> Logger -> PoolSettings -> Int -> IO Pool
create mk del g s k = do
    p <- Pool mk del g s k
            <$> mkAutoUpdate defaultUpdateSettings { updateAction = getCurrentTime }
            <*> Vec.replicateM (s^.poolStripes) (Stripe <$> newTVarIO Seq.empty <*> newTVarIO 0)
            <*> newIORef ()
    r <- async $ reaper p
    void $ mkWeakIORef (p^.finaliser) (cancel r >> destroy p)
    return p

destroy :: Pool -> IO ()
destroy = purge

with :: MonadIO m => Pool -> (Connection -> IO a) -> m (Maybe a)
with p f = liftIO $ do
    s <- stripe p
    mask $ \restore -> do
        r <- take1 p s
        case r of
            Just  v -> do
                x <- restore (f (value v)) `catch` cleanup p s v
                put p s v id
                return (Just x)
            Nothing -> return Nothing

purge :: Pool -> IO ()
purge p = Vec.forM_ (p^.stripes) $ \s -> do
    cs <- atomically (swapTVar (conns s) Seq.empty)
    mapM_ (ignore . view destroyFn p . value) cs

-----------------------------------------------------------------------------
-- Internal

cleanup :: Pool -> Stripe -> Resource -> SomeException -> IO a
cleanup p s r x = do
    case fromException x of
        Just (ResponseTimeout {}) -> onTimeout
        _                         -> destroyR p s r
    throwIO x
  where
    onTimeout =
        if timeouts r > p^.settings.maxTimeouts
            then do
                logInfo (p^.logger) $ string8 (show (value r)) <> ": Too many timeouts."
                destroyR p s r
            else put p s r incrTimeouts

take1 :: Pool -> Stripe -> IO (Maybe Resource)
take1 p s = do
    r <- atomically $ do
        c <- readTVar (conns s)
        u <- readTVar (inUse s)
        let n = Seq.length c
        check (u == n)
        let r :< rr = Seq.viewl $ Seq.unstableSortBy (compare `on` refcnt) c
        if | u < p^.settings.maxConnections -> do
                writeTVar (inUse s) $! u + 1
                mkNew p
           | n > 0 && refcnt r < p^.maxRefs -> use s r rr
           | otherwise                      -> return Empty
    case r of
        New io -> do
            x <- io `onException` atomically (modifyTVar' (inUse s) (subtract 1))
            atomically (modifyTVar' (conns s) (|> x))
            return (Just x)
        Used x -> return (Just x)
        Empty  -> return Nothing

use :: Stripe -> Resource -> Seq Resource -> STM Box
use s r rr = do
    writeTVar (conns s) $! rr |> r { refcnt = refcnt r + 1 }
    return (Used r)
{-# INLINE use #-}

mkNew :: Pool -> STM Box
mkNew p = return (New $ Resource <$> p^.currentTime <*> pure 1 <*> pure 0 <*> p^.createFn)
{-# INLINE mkNew #-}

put :: Pool -> Stripe -> Resource -> (Resource -> Resource) -> IO ()
put p s r f = do
    now <- p^.currentTime
    let updated x = f x { tstamp = now, refcnt = refcnt x - 1 }
    atomically $ do
        rs <- readTVar (conns s)
        let (xs, rr) = Seq.breakl ((value r ==) . value) rs
        case Seq.viewl rr of
            EmptyL  -> writeTVar (conns s) $! xs         |> updated r
            y :< ys -> writeTVar (conns s) $! (xs >< ys) |> updated y

destroyR :: Pool -> Stripe -> Resource -> IO ()
destroyR p s r = do
    atomically $ do
        rs <- readTVar (conns s)
        case find ((value r ==) . value) rs of
            Nothing -> return ()
            Just  _ -> do
                modifyTVar' (inUse s) (subtract 1)
                writeTVar (conns s) $! Seq.filter ((value r /=) . value) rs
    ignore $ p^.destroyFn $ value r

reaper :: Pool -> IO ()
reaper p = forever $ do
    threadDelay 1000000
    now <- p^.currentTime
    let isStale r = refcnt r == 0 && now `diffUTCTime` tstamp r > p^.settings.idleTimeout
    Vec.forM_ (p^.stripes) $ \s -> do
        x <- atomically $ do
                (stale, okay) <- Seq.partition isStale <$> readTVar (conns s)
                unless (Seq.null stale) $ do
                    writeTVar   (conns s) okay
                    modifyTVar' (inUse s) (subtract (Seq.length stale))
                return stale
        forM_ x $ \v -> ignore $ do
            logDebug (p^.logger) $ "Reaping idle connection: " <> string8 (show (value v))
            p^.destroyFn $ (value v)

stripe :: Pool -> IO Stripe
stripe p = ((p^.stripes) !) <$> ((`mod` (p^.settings.poolStripes)) . hash) <$> myThreadId
{-# INLINE stripe #-}

incrTimeouts :: Resource -> Resource
incrTimeouts r = r { timeouts = timeouts r + 1 }
{-# INLINE incrTimeouts #-}