module Database.Persist.Sql.Run where
import Database.Persist.Sql.Types
import Database.Persist.Sql.Raw
import Control.Monad.Trans.Control
import Data.Pool as P
import Control.Monad.Trans.Reader hiding (local)
import Control.Monad.Trans.Resource
import Control.Monad.Logger
import Control.Monad.Base
import Control.Exception.Lifted (onException)
import Control.Monad.IO.Class
import Control.Exception.Lifted (bracket)
import Control.Exception (mask)
import System.Timeout (timeout)
import Control.Monad.Trans.Control (control)
import Data.IORef (readIORef)
import qualified Data.Map as Map
import Control.Exception.Lifted (throwIO)
import Control.Exception (mask)
import Control.Monad (liftM)
import System.Timeout (timeout)
-- | Run a database action on a connection borrowed from the pool.
--
-- Waits at most 2000000 microseconds (two seconds) for a connection to
-- become available; if none is obtained in that time,
-- 'Couldn'tGetSQLConnection' is thrown. The action itself is executed via
-- 'runSqlConn'.
runSqlPool :: MonadBaseControl IO m => SqlPersistT m a -> Pool Connection -> m a
runSqlPool r pconn =
    withResourceTimeout 2000000 pconn (runSqlConn r) >>= \outcome ->
        case outcome of
            Nothing -> throwIO Couldn'tGetSQLConnection
            Just x  -> return x
-- | Like a pooled \"with resource\" helper, but gives up if a resource
-- cannot be taken from the pool within the given time limit.
--
-- Returns 'Nothing' when the limit expires before a resource is obtained,
-- and @'Just' result@ otherwise. If the action throws, the resource is
-- destroyed instead of being returned to the pool and the exception
-- propagates to the caller.
withResourceTimeout ::
    (MonadBaseControl IO m)
  => Int -- ^ Time limit in microseconds (passed straight to 'timeout').
  -> Pool a
  -> (a -> m b)
  -> m (Maybe b)
withResourceTimeout micros pool act = control $ \runInIO -> mask $ \restore ->
    -- Acquisition and bookkeeping run masked; only the user action is
    -- executed with the caller's masking state restored.
    timeout micros (takeResource pool) >>= \acquired ->
        case acquired of
            Nothing -> runInIO (return Nothing)
            Just (resource, localPool) -> do
                outcome <- onException
                    (restore (runInIO (liftM Just (act resource))))
                    (destroyResource pool localPool resource)
                -- Reached only when the action finished without throwing.
                putResource localPool resource
                return outcome
-- | Run a database action against a single connection.
--
-- The connection's 'connBegin' hook is invoked first, then the action is
-- run with the connection as its reader environment, and finally
-- 'connCommit' is invoked. If the action throws, 'connRollback' is invoked
-- and the exception is rethrown, so 'connCommit' is skipped.
runSqlConn :: MonadBaseControl IO m => SqlPersistT m a -> Connection -> m a
runSqlConn r conn = do
    liftBase (connBegin conn stmtGetter)
    result <- runReaderT r conn
        `onException` liftBase (connRollback conn stmtGetter)
    liftBase (connCommit conn stmtGetter)
    return result
  where
    -- Statement lookup handed to each of the begin/commit/rollback hooks.
    stmtGetter = getStmtConn conn
-- | Run a 'SqlPersistM' action on a single connection in plain 'IO',
-- discarding log output ('runNoLoggingT') and managing resources with
-- 'runResourceT'.
runSqlPersistM :: SqlPersistM a -> Connection -> IO a
runSqlPersistM x conn = runResourceT . runNoLoggingT $ runSqlConn x conn
-- | Run a 'SqlPersistM' action on a pooled connection in plain 'IO',
-- discarding log output ('runNoLoggingT') and managing resources with
-- 'runResourceT'.
runSqlPersistMPool :: SqlPersistM a -> Pool Connection -> IO a
runSqlPersistMPool x pool = runResourceT . runNoLoggingT $ runSqlPool x pool
-- | Create a connection pool, run the supplied action with it, and tear the
-- pool down afterwards.
--
-- The pool is released with 'destroyAllResources' whether the action
-- returns normally or throws, so connections opened by the pool do not
-- outlive the callback. The pool must therefore not be used after
-- 'withSqlPool' returns; use 'createSqlPool' for a pool with a longer
-- lifetime.
withSqlPool :: (MonadIO m, MonadLogger m, MonadBaseControl IO m)
            => (LogFunc -> IO Connection) -- ^ Opens a new connection.
            -> Int -- ^ Connection count, forwarded to 'createSqlPool'.
            -> (Pool Connection -> m a) -- ^ Action to run with the pool.
            -> m a
withSqlPool mkConn connCount f =
    bracket
        (createSqlPool mkConn connCount)
        (liftIO . destroyAllResources)
        f
-- | Build a connection pool whose connections are opened with the given
-- function and closed with 'close''.
--
-- The logging callback handed to the opener is captured from the current
-- monad via 'askLogFunc'. The literals @1@ and @20@ are the third and
-- fourth arguments of 'createPool' (stripe count and idle time in seconds
-- according to the resource-pool documentation); the requested size is
-- passed as the last argument.
createSqlPool :: (MonadIO m, MonadLogger m, MonadBaseControl IO m)
              => (LogFunc -> IO Connection)
              -> Int
              -> m (Pool Connection)
createSqlPool mkConn size =
    askLogFunc >>= \logFunc ->
        liftIO (createPool (mkConn logFunc) close' 1 20 size)
-- | Capture the current monad's logging behaviour as a plain 'IO' callback.
--
-- The returned function runs 'monadLoggerLog' through the captured
-- run-in-base function and throws away whatever that run returns, so it
-- can be handed to code that only lives in 'IO'.
askLogFunc :: (MonadBaseControl IO m, MonadLogger m) => m LogFunc
askLogFunc = do
    -- Grab the function that runs an @m@ computation in the base monad.
    runInBase <- control $ \run -> run $ return run
    return $ \loc src lvl msg ->
        runInBase (monadLoggerLog loc src lvl msg) >> return ()
-- | Open a single connection, run the supplied action with it, and close
-- the connection with 'close'' afterwards.
--
-- 'bracket' ensures the connection is closed even when the action throws.
-- The logging callback handed to the opener comes from 'askLogFunc'.
withSqlConn :: (MonadIO m, MonadBaseControl IO m, MonadLogger m)
            => (LogFunc -> IO Connection) -> (Connection -> m a) -> m a
withSqlConn open f =
    askLogFunc >>= \logger ->
        bracket (acquire logger) release f
  where
    acquire logger = liftIO (open logger)
    release conn = liftIO (close' conn)
-- | Finalize every statement currently stored in the connection's statement
-- map, then close the connection itself.
close' :: Connection -> IO ()
close' conn = do
    stmts <- readIORef (connStmtMap conn)
    mapM_ stmtFinalize (Map.elems stmts)
    connClose conn