{-| Implementation of a PostgreSQL-based backend pool. -}
module Data.CQRS.EventStore.Backend.PostgreSQL
       ( createBackendPool
       ) where

import           Control.Monad (when, forM_, void)
import           Data.ByteString (ByteString)
import           Data.Conduit (ResourceT, Source, ($=), ($$), runResourceT)
import qualified Data.Conduit.List as CL
import           Data.Conduit.Pool (Pool, createPool)
import           Data.CQRS.EventStore.Backend (EventStoreBackend(..), RawEvent, RawSnapshot(..))
import           Data.CQRS.GUID (GUID)
import qualified Data.CQRS.GUID as G
import           Data.CQRS.PersistedEvent (PersistedEvent(..))
import           Database.PostgreSQL.LibPQ (Connection)
import qualified Database.PostgreSQL.LibPQ as P

import           Data.CQRS.EventStore.Backend.PostgreSQLUtils

-- SQL statements used by this backend. Positional parameters ($1, $2, ...)
-- are supplied by the call sites in the order described on each statement.

-- | Creates the @events@ table when it does not exist yet. Rows hold a
-- @guid@, the serialized event (@ev_data@) and a @version@; the primary key
-- is @(guid, version)@.
createEventsSql :: ByteString
createEventsSql = "CREATE TABLE IF NOT EXISTS events ( guid BYTEA , ev_data BYTEA , version INTEGER , PRIMARY KEY (guid, version) );"

-- | Selects @version@ and @ev_data@ for the aggregate @$1@, restricted to
-- versions greater than or equal to @$2@, in ascending version order.
selectEventsSql :: ByteString
selectEventsSql = "SELECT version, ev_data FROM events WHERE guid = $1 AND version >= $2 ORDER BY version ASC;"

-- | Selects @guid@, @version@ and @ev_data@ for every row in @events@,
-- ordered by @version@ only.
-- NOTE(review): rows of different aggregates that share a version have no
-- ordering imposed between them by this statement.
enumerateAllEventsSql :: ByteString
enumerateAllEventsSql = "SELECT guid, version, ev_data FROM events ORDER BY version ASC;"

-- | Inserts one event row: @$1@ = guid, @$2@ = version, @$3@ = ev_data.
insertEventSql :: ByteString
insertEventSql = "INSERT INTO events ( guid, version, ev_data ) VALUES ($1, $2, $3);"

-- | Inserts @(guid = $1, version = $2)@ into @versions@, but only when no
-- row for that guid is present yet.
insertVersionSql :: ByteString
insertVersionSql = "INSERT INTO versions ( guid , version ) SELECT $1 , $2 WHERE $1 NOT IN ( SELECT guid FROM versions );"

-- | Inserts @(guid = $1, data = $2, version = $3)@ into @snapshots@, but only
-- when no row for that guid is present yet.
insertSnapshotSql :: ByteString
insertSnapshotSql = "INSERT INTO snapshots ( guid, data, version) SELECT $1 , $2, $3 WHERE $1 NOT IN ( SELECT guid FROM snapshots );"

-- | Creates the @versions@ table when it does not exist yet: one row per
-- @guid@ (primary key) carrying a @version@.
createAggregateVersionsSql :: ByteString
createAggregateVersionsSql = "CREATE TABLE IF NOT EXISTS versions ( guid BYTEA PRIMARY KEY , version INTEGER );"

-- | Selects the @version@ recorded in @versions@ for the guid @$1@.
getCurrentVersionSql :: ByteString
getCurrentVersionSql = "SELECT version FROM versions WHERE guid = $1;"

-- | Sets the recorded version to @$1@ for the guid @$2@. Note the parameter
-- order: version first, guid second (the reverse of 'insertVersionSql').
updateCurrentVersionSql :: ByteString
updateCurrentVersionSql = "UPDATE versions SET version = $1 WHERE guid = $2;"

-- | Creates the @snapshots@ table when it does not exist yet: one row per
-- @guid@ (primary key) carrying @data@ and a @version@.
createSnapshotSql :: ByteString
createSnapshotSql = "CREATE TABLE IF NOT EXISTS snapshots ( guid BYTEA PRIMARY KEY , data BYTEA , version INTEGER );"

-- | Overwrites the snapshot row of guid @$3@ with @data = $1@ and
-- @version = $2@.
writeSnapshotSql :: ByteString
writeSnapshotSql = "UPDATE snapshots SET data=$1, version=$2 WHERE guid=$3;"

-- | Selects @data@ and @version@ of the snapshot row for the guid @$1@.
selectSnapshotSql :: ByteString
selectSnapshotSql = "SELECT data, version FROM snapshots WHERE guid = $1;"

-- | Build the diagnostic text used when a result row does not have the
-- column layout a query expected. The message embeds the 'show'n list of
-- query parameters and the columns of the offending row.
badQueryResultMsg :: [String] -> [SqlValue] -> String
badQueryResultMsg params columns =
  "Invalid query result shape. Params: " ++ show params
    ++ ". Result columns: " ++ show columns

-- | Abort the current 'IO' action with a user error stating that the
-- expected version (@ov@) differs from the version actually seen (@cv@).
-- The action never returns normally, hence the unconstrained result type.
versionConflict :: (Show a, Show b) => a -> b -> IO c
versionConflict ov cv = ioError (userError message)
  where
    message =
      "Version conflict detected (expected " ++ show ov
        ++ ", saw " ++ show cv ++ ")"

-- | Wrap a 'GUID' as a non-null byte-array SQL parameter, using the GUID's
-- 'G.toByteString' representation.
sqlGuid :: GUID -> SqlValue
sqlGuid = SqlByteArray . Just . G.toByteString

-- | Append @events@ to the aggregate identified by @guid@.
--
-- The version recorded in the @versions@ table must equal
-- @originatingVersion@ (an aggregate without a row counts as version 0);
-- otherwise the action fails through 'versionConflict' before any write is
-- issued. On success the recorded version becomes
-- @originatingVersion + length events@ and every event is inserted under its
-- own sequence number.
storeEvents :: Connection -> GUID -> Int -> [RawEvent] -> IO ()
storeEvents c guid originatingVersion events = do
  -- Look up the version currently recorded for the aggregate. The fold starts
  -- at 0, so a missing row yields 0.
  storedVersion <- runResourceT $
    sourceQuery c getCurrentVersionSql [guidParam] $$ CL.fold pickLargest 0

  -- Refuse to write on top of a version the caller did not expect.
  when (storedVersion /= originatingVersion) $
    versionConflict originatingVersion storedVersion

  -- Record the new de-normalized version: the UPDATE covers an existing row,
  -- the guarded INSERT covers an aggregate that has no row yet.
  void $ run c updateCurrentVersionSql [newVersionParam, guidParam]
  void $ run c insertVersionSql [guidParam, newVersionParam]

  -- Persist the supplied events one by one.
  forM_ events $ \ev ->
    void $ run c insertEventSql
      [ guidParam
      , SqlInt32 (Just (fromIntegral (peSequenceNumber ev)))
      , SqlByteArray (Just (peEvent ev))
      ]
  where
    guidParam = sqlGuid guid

    newVersionParam =
      SqlInt32 (Just (fromIntegral (originatingVersion + length events)))

    -- Keep the larger of the running maximum and the version in this row.
    pickLargest :: Int -> [SqlValue] -> Int
    pickLargest best row = max best (rowVersion row)

    -- A result row must consist of exactly one non-null integer column.
    rowVersion :: [SqlValue] -> Int
    rowVersion [SqlInt32 (Just v)] = fromIntegral v
    rowVersion columns = error $ badQueryResultMsg [show guid] columns


-- | Stream the events of aggregate @guid@ whose version is at least @v0@,
-- in ascending version order (as requested by 'selectEventsSql').
retrieveEvents :: Connection -> GUID -> Int -> Source (ResourceT IO) RawEvent
retrieveEvents connection guid v0 =
  sourceQuery connection selectEventsSql queryParams $= CL.map toEvent
  where
    queryParams = [sqlGuid guid, SqlInt32 (Just (fromIntegral v0))]

    -- Turn a (version, ev_data) row into an event; any other column layout
    -- is reported through 'error'.
    toEvent [SqlInt32 (Just version), SqlByteArray (Just eventData)] =
      PersistedEvent guid eventData (fromIntegral version)
    toEvent columns =
      error $ badQueryResultMsg [show guid, show v0] columns

-- | Stream every stored event, across all aggregates, in the order produced
-- by 'enumerateAllEventsSql'.
enumerateAllEvents :: Connection -> Source (ResourceT IO) RawEvent
enumerateAllEvents connection =
  sourceQuery connection enumerateAllEventsSql [] $= CL.map decodeRow
  where
    -- Turn a (guid, version, ev_data) row into an event; any other column
    -- layout is reported through 'error'.
    decodeRow [SqlByteArray (Just g), SqlInt32 (Just v), SqlByteArray (Just ed)] =
      PersistedEvent (G.fromByteString g) ed (fromIntegral v)
    decodeRow columns =
      error $ badQueryResultMsg [] columns

-- | Store the given snapshot as the snapshot row of aggregate @guid@.
--
-- An UPDATE overwrites an existing row; the guarded INSERT that follows
-- creates the row when the aggregate had none.
writeSnapshot :: Connection -> GUID -> RawSnapshot -> IO ()
writeSnapshot c guid (RawSnapshot v d) = do
  void $ run c writeSnapshotSql [payload, version, key]
  void $ run c insertSnapshotSql [key, payload, version]
  where
    key = sqlGuid guid
    payload = SqlByteArray (Just d)
    version = SqlInt32 (Just (fromIntegral v))

-- | Fetch the snapshot stored for aggregate @guid@, or 'Nothing' when the
-- query yields no row. Only the first row of the result is consumed.
getLatestSnapshot :: Connection -> GUID -> IO (Maybe RawSnapshot)
getLatestSnapshot connection guid = do
  firstRows <- runResourceT $
    (sourceQuery connection selectSnapshotSql [sqlGuid guid] $= CL.map fromRow)
      $$ CL.take 1
  case firstRows of
    []         -> return Nothing
    (d, v) : _ -> return (Just (RawSnapshot v d))
  where
    -- A result row must be a non-null (data, version) pair; any other column
    -- layout is reported through 'error'.
    fromRow :: [SqlValue] -> (ByteString, Int)
    fromRow [SqlByteArray (Just d), SqlInt32 (Just v)] = (d, fromIntegral v)
    fromRow columns = error $ badQueryResultMsg [show guid] columns

-- | PostgreSQL backend: a wrapper around a single libpq 'Connection'.
-- The type is not part of this module's export list; values are produced by
-- the pool returned from 'createBackendPool'.
newtype PostgreSQLEventStoreBackend = ESB Connection

-- | Instance of 'EventStoreBackend' for 'PostgreSQLEventStoreBackend'.
-- Every method unwraps the connection from 'ESB' and passes it as the first
-- argument of the implementing function; the remaining arguments are
-- forwarded unchanged.
instance EventStoreBackend PostgreSQLEventStoreBackend where
    esbStoreEvents (ESB c) = storeEvents c
    esbRetrieveEvents (ESB c) = retrieveEvents c
    esbEnumerateAllEvents (ESB c) = enumerateAllEvents c
    esbWriteSnapshot (ESB c) = writeSnapshot c
    esbGetLatestSnapshot (ESB c) = getLatestSnapshot c
    -- NOTE(review): 'withTransaction' is not defined in this file; presumably
    -- it comes from the PostgreSQLUtils import -- confirm.
    esbWithTransaction (ESB c) = withTransaction c

-- | Create a pool of PostgreSQL-based event store backends.
--
-- @n@ is passed to 'createPool' as its last argument (presumably the maximum
-- number of pooled connections -- confirm against the pool library), and
-- @connectionString@ is handed to libpq's @connectdb@ for every connection
-- the pool opens.
--
-- Each freshly opened connection is checked for a successful connection
-- status and then has the @events@, @versions@ and @snapshots@ tables
-- created if they are missing.
--
-- Opening a connection fails with a user error carrying libpq's error message
-- when the connection could not be established.
createBackendPool :: Int -> ByteString -> IO (Pool PostgreSQLEventStoreBackend)
createBackendPool n connectionString =
  createPool open close 1 1 n
  where
    open = do
      -- Connect. libpq hands back a connection object even when the attempt
      -- failed, so the status has to be inspected explicitly.
      c <- P.connectdb connectionString
      st <- P.status c
      case st of
        P.ConnectionOk -> return ()
        _ -> do
          reason <- P.errorMessage c
          -- Release the failed connection object before reporting.
          P.finish c
          fail $ concat
            [ "Could not connect to PostgreSQL (status ", show st, "): "
            , maybe "no error message available" show reason
            ]
      -- Set up tables if necessary.
      void $ run c createEventsSql []
      void $ run c createAggregateVersionsSql []
      void $ run c createSnapshotSql []
      -- Return backend
      return $ ESB c
    -- Closing a pooled backend closes its libpq connection.
    close (ESB c) = P.finish c