{-# LANGUAGE TemplateHaskell #-}
{-# LANGUAGE PackageImports #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE CPP #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE UndecidableInstances #-}
{-# LANGUAGE EmptyDataDecls #-}
module Database.Persist.GenericSql.Raw
    ( withStmt
    , execute
    , SqlPersist (..)
    , getStmt'
    , getStmt
    , SqlBackend
    , MonadSqlPersist (..)
    ) where

import qualified Database.Persist.GenericSql.Internal as I
import Database.Persist.GenericSql.Internal hiding (execute, withStmt)
import Database.Persist.Store (PersistValue)
import Data.IORef
import Control.Monad.IO.Class
import Control.Monad.Trans.Reader
import qualified Data.Map as Map
import Control.Applicative (Applicative)
import Control.Monad.Trans.Class (MonadTrans (..))
import Control.Monad.Base (MonadBase (liftBase))
import Control.Monad.Trans.Control (MonadBaseControl (..), ComposeSt, defaultLiftBaseWith, defaultRestoreM, MonadTransControl (..))
import Control.Monad (liftM)
#define MBCIO MonadBaseControl IO
import Data.Text (Text)
import Control.Monad (MonadPlus)
import Control.Monad.Trans.Resource (MonadResource (..))
import Data.Conduit
import Control.Monad.Logger (MonadLogger (..))
import Data.Monoid (Monoid)

import Control.Monad.Logger (LoggingT)
import Control.Monad.Trans.Identity ( IdentityT)
import Control.Monad.Trans.List     ( ListT    )
import Control.Monad.Trans.Maybe    ( MaybeT   )
import Control.Monad.Trans.Error    ( ErrorT, Error)
import Control.Monad.Trans.Cont     ( ContT  )
import Control.Monad.Trans.State    ( StateT   )
import Control.Monad.Trans.Writer   ( WriterT  )
import Control.Monad.Trans.RWS      ( RWST     )

import qualified Control.Monad.Trans.RWS.Strict    as Strict ( RWST   )
import qualified Control.Monad.Trans.State.Strict  as Strict ( StateT )
import qualified Control.Monad.Trans.Writer.Strict as Strict ( WriterT )

data SqlBackend

newtype SqlPersist m a = SqlPersist { unSqlPersist :: ReaderT Connection m a }
    deriving (Monad, MonadIO, MonadTrans, Functor, Applicative, MonadPlus)

instance MonadThrow m => MonadThrow (SqlPersist m) where
    monadThrow = lift . monadThrow

instance MonadBase backend m => MonadBase backend (SqlPersist m) where
    liftBase = lift . liftBase

instance MonadBaseControl backend m => MonadBaseControl backend (SqlPersist m) where
     newtype StM (SqlPersist m) a = StMSP {unStMSP :: ComposeSt SqlPersist m a}
     liftBaseWith = defaultLiftBaseWith StMSP
     restoreM     = defaultRestoreM   unStMSP
instance MonadTransControl SqlPersist where
    newtype StT SqlPersist a = StReader {unStReader :: a}
    liftWith f = SqlPersist $ ReaderT $ \r -> f $ \t -> liftM StReader $ runReaderT (unSqlPersist t) r
    restoreT = SqlPersist . ReaderT . const . liftM unStReader

instance MonadResource m => MonadResource (SqlPersist m) where
#if MIN_VERSION_resourcet(0,4,0)
    liftResourceT = lift . liftResourceT
#else
    register = lift . register
    release = lift . release
    allocate a = lift . allocate a
    resourceMask = lift . resourceMask
#endif

class (MonadIO m, MonadLogger m) => MonadSqlPersist m where
    askSqlConn :: m Connection

instance (MonadIO m, MonadLogger m) => MonadSqlPersist (SqlPersist m) where
    askSqlConn = SqlPersist ask

#define GO(T) instance (MonadSqlPersist m) => MonadSqlPersist (T m) where askSqlConn = lift askSqlConn
#define GOX(X, T) instance (X, MonadSqlPersist m) => MonadSqlPersist (T m) where askSqlConn = lift askSqlConn
GO(LoggingT)
GO(IdentityT)
GO(ListT)
GO(MaybeT)
GOX(Error e, ErrorT e)
GO(ReaderT r)
GO(ContT r)
GO(StateT s)
GO(ResourceT)
GO(Pipe l i o u)
GOX(Monoid w, WriterT w)
GOX(Monoid w, RWST r w s)
GOX(Monoid w, Strict.RWST r w s)
GO(Strict.StateT s)
GOX(Monoid w, Strict.WriterT w)
#undef GO
#undef GOX

instance MonadLogger m => MonadLogger (SqlPersist m) where
    monadLoggerLog a b c = lift $ monadLoggerLog a b c
    monadLoggerLogSource a b c = lift . monadLoggerLogSource a b c

withStmt :: (MonadSqlPersist m, MonadResource m)
         => Text
         -> [PersistValue]
         -> Source m [PersistValue]
withStmt sql vals = do
    lift $ $logSQL sql vals
    conn <- lift askSqlConn
    bracketP
        (getStmt' conn sql)
        I.reset
        (flip I.withStmt vals)

execute :: MonadSqlPersist m => Text -> [PersistValue] -> m ()
execute sql vals = do
    $logSQL sql vals
    stmt <- getStmt sql
    liftIO $ I.execute stmt vals
    liftIO $ reset stmt

getStmt :: MonadSqlPersist m => Text -> m Statement
getStmt sql = do
    conn <- askSqlConn
    liftIO $ getStmt' conn sql

getStmt' :: Connection -> Text -> IO Statement
getStmt' conn sql = do
    smap <- liftIO $ readIORef $ stmtMap conn
    case Map.lookup sql smap of
        Just stmt -> return stmt
        Nothing -> do
            stmt <- liftIO $ prepare conn sql
            liftIO $ writeIORef (stmtMap conn) $ Map.insert sql stmt smap
            return stmt