{-# LANGUAGE CPP #-}
module Database.PostgreSQL.PQTypes.Internal.Monad (
    DBT_(..)
  , DBT
  , runDBT
  , mapDBT
  ) where

import Control.Applicative
import Control.Concurrent.MVar
import Control.Monad.Base
import Control.Monad.Catch
import Control.Monad.Error.Class
import Control.Monad.Reader.Class
import Control.Monad.State.Strict
import Control.Monad.Trans.Control
import Control.Monad.Writer.Class
import Data.Monoid
import Prelude
import qualified Control.Monad.Trans.State.Strict as S

import Database.PostgreSQL.PQTypes.Class
import Database.PostgreSQL.PQTypes.Internal.Connection
import Database.PostgreSQL.PQTypes.Internal.Error
import Database.PostgreSQL.PQTypes.Internal.Notification
import Database.PostgreSQL.PQTypes.Internal.Query
import Database.PostgreSQL.PQTypes.Internal.State
import Database.PostgreSQL.PQTypes.SQL
import Database.PostgreSQL.PQTypes.SQL.Class
import Database.PostgreSQL.PQTypes.Transaction
import Database.PostgreSQL.PQTypes.Transaction.Settings
import Database.PostgreSQL.PQTypes.Utils

-- | The transformer wrapped by 'DBT_': a strict 'StateT' carrying a
-- 'DBState' indexed by @m@.
--
-- The synonym is intentionally declared without the base monad and result
-- parameters so that it can be used unsaturated as a transformer (it is
-- referred to as @InnerDBT m@ in the 'MonadTransControl' instance).
type InnerDBT m = StateT (DBState m)

-- | Transformer that layers database interaction on top of a base monad.
--
-- It is a newtype over 'InnerDBT': the parameter @m@ indexes the 'DBState'
-- that is threaded through, while @n@ is the monad the computation actually
-- runs in. Use the 'DBT' synonym when both are the same monad.
newtype DBT_ m n a = DBT
  { unDBT :: InnerDBT m n a
  }
  deriving
    ( Alternative
    , Applicative
    , Functor
    , Monad
    , MonadBase b
    , MonadCatch
    , MonadIO
    , MonadMask
    , MonadPlus
    , MonadThrow
    , MonadTrans
    )

-- | 'DBT_' with both monad parameters set to the same monad. This is the
-- form consumed by 'runDBT' and consumed and produced by 'mapDBT'.
type DBT m = DBT_ m m

-- | Run a 'DBT' computation in the underlying monad.
--
-- A connection is obtained from the given connection source for the duration
-- of the computation. The initial state records that connection, the source
-- and the supplied transaction settings, starts with an empty last query and
-- has no query result.
--
-- If 'tsAutoTransaction' is set in the settings, the whole computation is
-- wrapped with @withTransaction'@, which is handed a copy of the settings
-- with 'tsAutoTransaction' switched off. The settings stored in the state
-- are the ones originally supplied.
runDBT
  :: (MonadBase IO m, MonadMask m)
  => ConnectionSourceM m
  -> TransactionSettings
  -> DBT m a
  -> m a
runDBT cs ts m = withConnection cs (evalStateT (unDBT body) . initialState)
  where
    -- State the computation starts from, given the acquired connection.
    initialState conn = DBState
      { dbConnection = conn
      , dbConnectionSource = cs
      , dbTransactionSettings = ts
      , dbLastQuery = SomeSQL (mempty :: SQL)
      , dbQueryResult = Nothing
      }

    -- The computation to run, optionally wrapped in a transaction.
    body
      | tsAutoTransaction ts =
          withTransaction' (ts { tsAutoTransaction = False }) m
      | otherwise = m

-- | Change the monad a 'DBT' computation runs in.
--
-- The first argument converts the incoming state of the target monad into
-- the state expected by the original computation. The second argument
-- converts the original computation's result (the value paired with the
-- final state) into an action of the target monad.
mapDBT
  :: (DBState n -> DBState m)
  -> (m (a, DBState m) -> n (b, DBState n))
  -> DBT m a
  -> DBT n b
mapDBT f g m = DBT $ StateT $ \st -> g (runStateT (unDBT m) (f st))

----------------------------------------

instance (m ~ n, MonadBase IO m, MonadMask m) => MonadDB (DBT_ m n) where
  -- Run the query in the base monad; runQueryIO yields the result together
  -- with the updated state.
  runQuery sql = DBT $ StateT $ \st -> liftBase (runQueryIO sql st)
  getLastQuery = DBT (gets dbLastQuery)

  -- Read the connection MVar and fail with HPQTypesError if it holds
  -- no connection data.
  getConnectionStats = do
    mconn <- DBT $ do
      connVar <- gets (unConnection . dbConnection)
      liftBase (readMVar connVar)
    maybe
      (throwDB $ HPQTypesError "getConnectionStats: no connection")
      (return . cdStats)
      mconn

  getTransactionSettings = DBT (gets dbTransactionSettings)
  setTransactionSettings ts =
    DBT $ modify (\st -> st { dbTransactionSettings = ts })

  getQueryResult = DBT (gets dbQueryResult)
  clearQueryResult = DBT $ modify (\st -> st { dbQueryResult = Nothing })

  -- The state is passed through unchanged.
  getNotification time = DBT . StateT $ \st -> do
    mnotification <- liftBase (getNotificationIO st time)
    return (mnotification, st)

  -- Run the action through a fresh runDBT call that reuses the current
  -- connection source and transaction settings; the current state is
  -- returned untouched.
  withNewConnection m = DBT . StateT $ \st ->
    (\res -> (res, st))
      <$> runDBT (dbConnectionSource st) (dbTransactionSettings st) m

----------------------------------------

-- | Control operations are lifted through @DBT_ m@ by delegating to the
-- wrapped @InnerDBT m@ transformer via the default helper functions.
instance MonadTransControl (DBT_ m) where
#if MIN_VERSION_monad_control(1,0,0)
  -- monad-control >= 1.0.0: the transformer state is declared as a type
  -- synonym for the state of the wrapped transformer, so the default helpers
  -- only need the newtype constructor and accessor of DBT_.
  type StT (DBT_ m) a = StT (InnerDBT m) a
  liftWith = defaultLiftWith DBT unDBT
  restoreT = defaultRestoreT DBT
  {-# INLINE liftWith #-}
  {-# INLINE restoreT #-}
#else
  -- Older monad-control: the transformer state is declared as a dedicated
  -- newtype, so its constructor and accessor are passed to the default
  -- helpers as well.
  newtype StT (DBT_ m) a = StDBT { unStDBT :: StT (InnerDBT m) a }
  liftWith = defaultLiftWith DBT unDBT StDBT
  restoreT = defaultRestoreT DBT unStDBT
  {-# INLINE liftWith #-}
  {-# INLINE restoreT #-}
#endif

-- | Base-monad control operations for @DBT_ m n@, obtained from the
-- 'MonadTransControl' instance of @DBT_ m@ through the default helpers.
instance (m ~ n, MonadBaseControl b m) => MonadBaseControl b (DBT_ m n) where
#if MIN_VERSION_monad_control(1,0,0)
  -- monad-control >= 1.0.0: the monadic state is a type synonym composed
  -- from the transformer state and the state of the underlying monad.
  type StM (DBT_ m n) a = ComposeSt (DBT_ m) m a
  liftBaseWith = defaultLiftBaseWith
  restoreM     = defaultRestoreM
  {-# INLINE liftBaseWith #-}
  {-# INLINE restoreM #-}
#else
  -- Older monad-control: the monadic state is a dedicated newtype. The
  -- instance is indexed by @DBT_ m n@ so that it matches the instance head;
  -- the synonym @DBT m@ would expand to @DBT_ m m@ and be rejected.
  newtype StM (DBT_ m n) a = StMDBT { unStMDBT :: ComposeSt (DBT_ m) m a }
  liftBaseWith = defaultLiftBaseWith StMDBT
  restoreM     = defaultRestoreM unStMDBT
  {-# INLINE liftBaseWith #-}
  {-# INLINE restoreM #-}
#endif

-- | Errors are thrown in the underlying monad; handlers are run through
-- the state transformer using @liftCatch@.
instance (m ~ n, MonadError e m) => MonadError e (DBT_ m n) where
  throwError e = lift (throwError e)
  catchError action handler =
    DBT (S.liftCatch catchError (unDBT action) (\e -> unDBT (handler e)))

-- | The reader environment is that of the underlying monad; @local@ is
-- applied to the underlying action while the database state is left as is.
instance (m ~ n, MonadReader r m) => MonadReader r (DBT_ m n) where
  ask = lift ask
  local f action = mapDBT id (local f) action
  reader g = lift (reader g)

-- | Exposes the state of the underlying monad. Every operation is lifted,
-- so the internal database state is not accessible through this instance.
instance (m ~ n, MonadState s m) => MonadState s (DBT_ m n) where
  get = lift get
  put s = lift (put s)
  state f = lift (state f)

-- | Writer operations are delegated to the underlying monad; @listen@ and
-- @pass@ go through the state transformer using @liftListen@ and @liftPass@.
instance (m ~ n, MonadWriter w m) => MonadWriter w (DBT_ m n) where
  writer aw = lift (writer aw)
  tell w = lift (tell w)
  listen action = DBT (S.liftListen listen (unDBT action))
  pass action = DBT (S.liftPass pass (unDBT action))