module Database.Redis.Monadic
       ( -- * Typeclass
         HasRedis(..)
         -- * Default transformer
       , RedisReaderT(..)
       , runRedisReaderT
         -- * Redis interconnection
       , queryRedis
       , runRedisTrans
       , queryRedisTrans
       ) where

import Control.Applicative
import Control.Concurrent
import Control.Monad.Base
import Control.Monad.Cont.Class
import Control.Monad.Reader
import Control.Monad.State
import Control.Monad.Trans.Cont
import Control.Monad.Trans.Control
import Control.Monad.Trans.Identity
import Control.Monad.Trans.Maybe
import Control.Monad.Writer
import Database.Redis

#if MIN_VERSION_mtl(2,2,1)
import Control.Monad.Except
#else
import Control.Monad.Error
#endif

import qualified Control.Monad.Trans.State.Lazy as SL
import qualified Control.Monad.Trans.State.Strict as ST
import qualified Control.Monad.Trans.Writer.Lazy as WL
import qualified Control.Monad.Trans.Writer.Strict as WS


-- | Monads that can reach a Redis 'Connection' and that can run 'IO'
-- actions through their 'MonadBase' instance.
class (MonadBase IO m) => HasRedis m where
    -- | Obtain the Redis connection available in this monad.
    getRedis :: m Connection

-- Every instance below obtains the connection by lifting the getter of the
-- inner monad. The macro covers transformers whose instance needs no extra
-- constraint; the WriterT variants are written out because they also
-- require Monoid w.
#define HASREDIS(T)                            \
instance (HasRedis m) => HasRedis (T m) where { \
    getRedis = lift getRedis;                  \
    {-# INLINEABLE getRedis #-}                \
}

#if MIN_VERSION_mtl(2,2,1)
HASREDIS(ExceptT e)
#else
-- Older mtl only provides Control.Monad.Error, which does not export
-- ExceptT, so the error transformer available there gets the instance.
instance (HasRedis m, Error e) => HasRedis (ErrorT e m) where
    getRedis = lift getRedis
    {-# INLINEABLE getRedis #-}
#endif
HASREDIS(IdentityT)
HASREDIS(MaybeT)
HASREDIS(ReaderT r)
HASREDIS(SL.StateT s)
HASREDIS(ST.StateT s)
HASREDIS(ContT r)
instance (HasRedis m, Monoid w) => HasRedis (WL.WriterT w m) where
    getRedis = lift getRedis
    {-# INLINEABLE getRedis #-}
instance (HasRedis m, Monoid w) => HasRedis (WS.WriterT w m) where
    getRedis = lift getRedis
    {-# INLINEABLE getRedis #-}


-- | Default transformer providing 'HasRedis': a 'ReaderT' whose
-- environment is the Redis 'Connection'. Run it with 'runRedisReaderT'.
--
-- Most classes are derived through the underlying 'ReaderT'. 'MonadReader'
-- is not derived; its instance forwards to the inner monad, so the
-- connection is only reachable through 'getRedis'.
newtype RedisReaderT m a =
    RedisReaderT
    { getRedisReader :: ReaderT Connection m a
    } deriving ( Monad, MonadWriter w, MonadState s
               , MonadError e, MonadTrans, Functor, MonadFix
               , MonadPlus, Applicative, Alternative, MonadIO
               , MonadCont, MonadBase b
               )

-- | The connection is read from the environment of 'RedisReaderT' itself,
-- not from the inner monad.
instance (MonadBase IO m) => HasRedis (RedisReaderT m) where
    getRedis = RedisReaderT ask
    {-# INLINEABLE getRedis #-}

-- | Run a 'RedisReaderT' computation, supplying the Redis connection that
-- it reads.
runRedisReaderT :: Connection -> RedisReaderT m a -> m a
runRedisReaderT con action = runReaderT (getRedisReader action) con

-- | 'MonadReader' is forwarded to the inner monad @m@: 'ask', 'reader' and
-- 'local' work on the environment of @m@, never on the Redis 'Connection'
-- held by 'RedisReaderT'.
instance (MonadReader r m) => MonadReader r (RedisReaderT m) where
    ask = lift ask
    {-# INLINEABLE ask #-}

    reader f = lift (reader f)
    {-# INLINEABLE reader #-}

    -- Run the wrapped action inside the inner monad with the modified
    -- environment, passing the current connection through unchanged.
    local f action =
        RedisReaderT ask >>= \con ->
            lift (local f (runRedisReaderT con action))
    {-# INLINEABLE local #-}

-- MonadBaseControl and MonadTransControl for RedisReaderT. Both instances
-- unwrap to the underlying ReaderT Connection and reuse its instances. The
-- two CPP branches differ only in how monad-control declares the
-- associated state types StM and StT.
#if MIN_VERSION_monad_control(1,0,0)
-- monad-control >= 1.0: StM and StT are type synonym families, so the
-- state types of ReaderT Connection are reused directly without wrapping.
instance (MonadBaseControl b m) => MonadBaseControl b (RedisReaderT m) where
    type StM (RedisReaderT m) a = StM (ReaderT Connection m) a
    -- Hand the callback a runner that strips the RedisReaderT wrapper
    -- before using the ReaderT runner.
    liftBaseWith action = RedisReaderT $ do
        liftBaseWith $ \runInBase -> do
            action (runInBase . getRedisReader)
    restoreM = RedisReaderT . restoreM
    {-# INLINEABLE liftBaseWith #-}
    {-# INLINEABLE restoreM #-}

instance MonadTransControl RedisReaderT where
    type StT RedisReaderT a = StT (ReaderT Connection) a
    liftWith action = RedisReaderT $ do
        liftWith $ \runTrans -> action (runTrans . getRedisReader)
    restoreT st = RedisReaderT $ restoreT st
    {-# INLINEABLE liftWith #-}
    {-# INLINEABLE restoreT #-}
#else
-- monad-control < 1.0: StM and StT are data families, so the ReaderT state
-- is wrapped in the RRStM / RRStT newtypes by the runners below and
-- unwrapped again in restoreM / restoreT.
instance (MonadBaseControl b m) => MonadBaseControl b (RedisReaderT m) where
    newtype StM (RedisReaderT m) a
        = RRStM (StM (ReaderT Connection m) a)
    liftBaseWith action = RedisReaderT $ do
        liftBaseWith $ \runInBase -> do
            action ((RRStM `liftM`) . runInBase . getRedisReader)
    restoreM (RRStM st) = RedisReaderT $ restoreM st
    {-# INLINEABLE liftBaseWith #-}
    {-# INLINEABLE restoreM #-}

instance MonadTransControl RedisReaderT where
    newtype StT RedisReaderT a
        = RRStT { unRRStT :: StT (ReaderT Connection) a }
    liftWith action = RedisReaderT $ do
        liftWith $ \runTrans -> do
            action ((RRStT `liftM`) . runTrans . getRedisReader)
    restoreT st = RedisReaderT $ restoreT $ unRRStT `liftM` st
    {-# INLINEABLE liftWith #-}
    {-# INLINEABLE restoreT #-}
#endif



-- | Run a 'Redis' action with the connection provided by the current monad.
queryRedis :: (HasRedis m) => Redis a -> m a
queryRedis action = getRedis >>= \con -> liftBase (runRedis con action)

{- | Run a Redis transaction, running it again if it was aborted. Between
attempts the thread sleeps for the number of microseconds returned by the
delay action (a random value in the example below).

@
runRedisTrans con (randomRIO (100, 1000)) 10 $ do
    watch [key1, key2, key3]
    lIndex key3 0 >>= \case
        Nothing -> unwatch *> pure TxAborted
        Just val -> multiExec $ do
            lRem  key3 1 val
            lPush key1 [val]
            lPush key2 [val]
@

In the example above, the first value of key3 is transactionally copied
to two other lists (key1 and key2) and removed from key3. If any of key1,
key2 or key3 is modified between the 'watch' command and 'exec' (inside
'multiExec'), the transaction is aborted. The thread then waits between
100 and 1000 microseconds, and the whole action is run again (at most
10 attempts in total).

__The caller is responsible for not performing mutating actions outside of 'multiExec', because such actions can be run multiple times.__

Consider the following example:

@
runRedisTrans con (randomRIO (100, 1000)) 10 $ do
    watch [key1, key2, key3]
    lPop key3 >>= \case
        Nothing -> unwatch *> pure TxAborted
        Just val -> multiExec $ do
            lPush key1 [val]
            lPush key2 [val]
@

Doing this is strongly discouraged: 'lPop' runs outside of 'multiExec',
so every aborted attempt removes another element, and the pop may be
performed several times (up to 10).

-}

runRedisTrans :: Connection          -- ^ connection used for every attempt
              -> IO Int              -- ^ delay in microseconds before a retry (non-positive: no sleep)
              -> Int                 -- ^ maximum number of attempts (below 1: no attempt is made)
              -> Redis (TxResult a)  -- ^ transaction to run
              -> IO (TxResult a)     -- ^ first non-aborted result, or 'TxAborted'
runRedisTrans con delay maxRepeats raction = go 0
  where
    -- attempt is the number of runs already performed.
    go attempt
        | attempt >= maxRepeats = return TxAborted
        | otherwise = do
              res <- runRedis con raction
              case res of
                  -- Only pause when another attempt will actually follow;
                  -- after the final aborted attempt return immediately
                  -- instead of sleeping for nothing.
                  TxAborted
                      | attempt + 1 < maxRepeats -> do
                            pause
                            go (attempt + 1)
                  _ -> return res

    -- Sleep for the duration chosen by the delay action.
    pause = do
        t <- delay
        when (t > 0) $ threadDelay t

-- | Same as 'runRedisTrans', but the connection is taken from a 'HasRedis'
-- monad.
queryRedisTrans :: (HasRedis m)
                => IO Int              -- ^ delay in microseconds before a rerun
                -> Int                 -- ^ maximum number of transaction runs
                -> Redis (TxResult a)  -- ^ transaction to run
                -> m (TxResult a)
queryRedisTrans rr maxRepeats raction =
    getRedis >>= \con -> liftBase (runRedisTrans con rr maxRepeats raction)