{-# LANGUAGE FlexibleContexts      #-}
{-# LANGUAGE FlexibleInstances     #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE OverloadedStrings     #-}
{-# LANGUAGE TypeFamilies          #-}

module Snap.Snaplet.Persistent
  ( initPersist
  , initPersistGeneric
  , PersistState(..)
  , HasPersistPool(..)
  , mkPgPool
  , mkSnapletPgPool
  , runPersist
  , withPool

  -- * Utility Functions
  , mkKey
  , mkKeyBS
  , mkKeyT
  , showKey
  , showKeyBS
  , mkInt
  , mkWord64
  , followForeignKey
  , fromPersistValue'
  ) where

-------------------------------------------------------------------------------
import           Control.Monad.Logger
import           Control.Monad.State
import           Control.Monad.Trans.Reader
import           Control.Monad.Trans.Resource
import           Data.ByteString              (ByteString)
import           Data.Configurator
import           Data.Configurator.Types
import           Data.Maybe
import           Data.Pool
import           Data.Readable
import           Data.Text                    (Text)
import qualified Data.Text                    as T
import qualified Data.Text.Encoding           as T
import           Data.Word
import           Database.Persist
import           Database.Persist.Postgresql  hiding (get)
import qualified Database.Persist.Postgresql  as DB
import           Paths_snaplet_persistent
import           Snap.Snaplet
-------------------------------------------------------------------------------


-------------------------------------------------------------------------------
-- | Snaplet state holding the database connection pool; the
-- 'HasPersistPool' instance for 'Handler' reads it via 'persistPool'.
newtype PersistState = PersistState
    { persistPool :: ConnectionPool
    }


-------------------------------------------------------------------------------
-- | Implement this type class to have any monad work with
-- snaplet-persistent. A default instance is provided for
-- (Handler b PersistState).
class MonadIO m => HasPersistPool m where
    -- | The connection pool that 'runPersist' draws connections from.
    getPersistPool :: m ConnectionPool


-- | Makes the underlying monad's pool available through 'NoLoggingT'.
instance HasPersistPool m => HasPersistPool (NoLoggingT m) where
    -- Wrap the inner monad's action directly. Writing
    -- @runNoLoggingT getPersistPool@ here would select this very instance
    -- at @NoLoggingT (NoLoggingT m)@ and recurse forever without ever
    -- reaching the @HasPersistPool m@ dictionary.
    getPersistPool = NoLoggingT getPersistPool

-- | Handlers running in the 'PersistState' snaplet read the pool from
-- the snaplet's state.
instance HasPersistPool (Handler b PersistState) where
    getPersistPool = gets persistPool

-- | A reader whose environment is a 'ConnectionPool' supplies that
-- environment as the pool.
instance MonadIO m => HasPersistPool (ReaderT ConnectionPool m) where
    getPersistPool = ask


-------------------------------------------------------------------------------
-- | Initialize Persistent with an initial SQL function called right
-- after the connection pool has been created. This is most useful for
-- calling migrations upfront right after initialization.
--
-- The pool is built by 'mkSnapletPgPool' from the snaplet's user config.
--
-- Example:
--
-- > initPersist (runMigrationUnsafe migrateAll)
--
-- where migrateAll is the migration function that was auto-generated
-- by the QQ statement in your persistent schema definition in the
-- call to 'mkMigrate'.
initPersist :: SqlPersistT (NoLoggingT IO) a -> SnapletInit b PersistState
initPersist = initPersistGeneric mkSnapletPgPool


-------------------------------------------------------------------------------
-- | Backend-agnostic initialization with an initial SQL function called
-- right after the connection pool has been created. This is most useful
-- for calling migrations upfront right after initialization.
--
-- Example:
--
-- > initPersistGeneric mkPool (runMigrationUnsafe migrateAll)
--
-- where migrateAll is the migration function that was auto-generated
-- by the QQ statement in your persistent schema definition in the
-- call to 'mkMigrate'.
--
-- mkPool is a function to construct a pool of connections to your
-- database. The initial SQL action is run once against that pool; its
-- result is discarded and any log output it produces is dropped.
initPersistGeneric
    :: Initializer b PersistState (Pool SqlBackend)
    -> SqlPersistT (NoLoggingT IO) a
    -> SnapletInit b PersistState
initPersistGeneric mkPool migration =
    makeSnaplet "persist" description datadir $ do
        pool <- mkPool
        _ <- liftIO . runNoLoggingT $ runSqlPool migration pool
        return (PersistState pool)
  where
    description :: Text
    description = "Snaplet for persistent DB library"

    -- Bundled resources, located relative to the package's data directory.
    datadir :: Maybe (IO FilePath)
    datadir = Just $ fmap (++ "/resources/db") getDataDir


-------------------------------------------------------------------------------
-- | Constructs a PostgreSQL connection pool from Config.
--
-- Both settings are looked up with 'require', which fails if a key is
-- absent:
--
--   * @postgre-con-str@   — the connection string
--
--   * @postgre-pool-size@ — the number of pooled connections
--
-- Log output from pool creation is discarded.
mkPgPool :: MonadIO m => Config -> m ConnectionPool
mkPgPool conf = liftIO $ do
    pgConStr <- require conf "postgre-con-str"
    poolSize <- require conf "postgre-pool-size"
    runNoLoggingT $ createPostgresqlPool pgConStr poolSize


-------------------------------------------------------------------------------
-- | Constructs a connection pool in a snaplet context, reading its
-- settings from the snaplet's user config (see 'mkPgPool').
mkSnapletPgPool :: (MonadIO (m b v), MonadSnaplet m) => m b v ConnectionPool
mkSnapletPgPool = getSnapletUserConfig >>= mkPgPool


-------------------------------------------------------------------------------
-- | Runs a SqlPersist action in any monad with a HasPersistPool instance,
-- using the pool returned by 'getPersistPool' (see 'withPool').
runPersist :: (HasPersistPool m)
           => SqlPersistM b
           -- ^ Run given Persistent action in the defined monad.
           -> m b
runPersist action = do
    pool <- getPersistPool
    withPool pool action


------------------------------------------------------------------------------
-- | Run a database action against the given pool, lifting the result
-- into any 'MonadIO'. The action's log output is discarded
-- ('runNoLoggingT'), and the action runs inside a 'runResourceT' scope.
withPool :: MonadIO m
         => ConnectionPool
         -> SqlPersistM a -> m a
withPool cp f = liftIO . runResourceT . runNoLoggingT $ runSqlPool f cp


-------------------------------------------------------------------------------
-- | Make a Key from an Int, via the SQL backend's 'Int64' key
-- representation.
mkKey :: ToBackendKey SqlBackend entity => Int -> Key entity
mkKey = fromBackendKey . SqlBackendKey . fromIntegral


-------------------------------------------------------------------------------
-- | Makes a Key from a ByteString that contains an integer.
-- Calls 'error' if the ByteString cannot be read as an 'Int'.
mkKeyBS :: ToBackendKey SqlBackend entity => ByteString -> Key entity
mkKeyBS = mkKey . fromMaybe (error "Can't ByteString value") . fromBS


-------------------------------------------------------------------------------
-- | Makes a Key from Text that contains an integer.
-- Calls 'error' if the Text cannot be read as an 'Int'.
mkKeyT :: ToBackendKey SqlBackend entity => Text -> Key entity
mkKeyT = mkKey . fromMaybe (error "Can't Text value") . fromText


-------------------------------------------------------------------------------
-- | Makes a Text representation of a Key: the decimal rendering of its
-- integer value (see 'mkInt').
showKey :: ToBackendKey SqlBackend e => Key e -> Text
showKey = T.pack . show . mkInt


-------------------------------------------------------------------------------
-- | Makes a UTF-8 encoded ByteString representation of a Key
-- (see 'showKey').
showKeyBS :: ToBackendKey SqlBackend e => Key e -> ByteString
showKeyBS = T.encodeUtf8 . showKey


-------------------------------------------------------------------------------
-- | Converts a Key to Int. The backend's 'Int64' value is converted with
-- 'fromIntegral', so this never calls 'error'. On platforms where 'Int'
-- is narrower than 64 bits, large keys are silently truncated.
mkInt :: ToBackendKey SqlBackend a => Key a -> Int
mkInt = fromIntegral . unSqlBackendKey . toBackendKey


-------------------------------------------------------------------------------
-- | Converts a Key to Word64. The backend's 'Int64' value is converted
-- with 'fromIntegral', so this never calls 'error'. A negative key wraps
-- around modulo 2^64 instead.
mkWord64 :: ToBackendKey SqlBackend a => Key a -> Word64
mkWord64 = fromIntegral . unSqlBackendKey . toBackendKey


-------------------------------------------------------------------------------
-- | Converts a PersistValue to a more concrete type. Calls 'error' if the
-- conversion fails. The error message includes the reason reported by
-- 'fromPersistValue'.
fromPersistValue' :: PersistField c => PersistValue -> c
fromPersistValue' = either failed id . fromPersistValue
  where
    -- Keep the library's explanation instead of discarding it.
    failed reason = error ("Persist conversion failed: " ++ T.unpack reason)


------------------------------------------------------------------------------
-- | Follows a foreign key field in one entity and retrieves the
-- corresponding entity from the database.
--
-- The first argument extracts the foreign key from the source entity's
-- value. Returns 'Nothing' when no row exists for that key.
followForeignKey :: (PersistEntity a, HasPersistPool m,
                     PersistEntityBackend a ~ SqlBackend)
                 => (t -> Key a) -> Entity t -> m (Maybe (Entity a))
followForeignKey toKey (Entity _ val) =
    fmap (Entity key') <$> runPersist (DB.get key')
  where
    key' = toKey val