-- | Run transactions against event stores.
module Data.CQRS.Transaction
       ( TransactionT
       , getAggregateRoot
       , publishEvent
       , runTransactionT
       ) where

import           Control.Applicative (Applicative)
import           Control.Monad (forM_, when, liftM)
import           Control.Monad.IO.Class (MonadIO, liftIO)
import           Control.Monad.Trans.Class (MonadTrans(..), lift)
import           Control.Monad.Trans.State (StateT, get, gets, modify, runStateT)
import           Data.CQRS.Aggregate (Aggregate(..))
import           Data.CQRS.Internal.AggregateRef (AggregateRef, mkAggregateRef)
import qualified Data.CQRS.Internal.AggregateRef as AR
import           Data.CQRS.Eventable (Eventable(..))
import           Data.CQRS.EventStore (EventStore)
import qualified Data.CQRS.Internal.EventStore as ES
import           Data.CQRS.GUID (GUID)
import           Data.Default (Default(..))
import           Data.List (find, foldl')
import           Data.Serialize (Serialize)
import           Data.Typeable (Typeable, cast)

-- | Transaction monad transformer.
--
-- A newtype over 'Transaction'; the data constructor is not exported, so
-- callers can only build transactions through the functions in this module.
-- 'Applicative' is derived alongside 'Monad' because, since the
-- Functor-Applicative-Monad change (GHC 7.10), every 'Monad' instance needs
-- an 'Applicative' superclass instance.
newtype TransactionT e m a = TransactionT (Transaction e m a)
                           deriving (Functor, Applicative, Monad)

-- Lifting passes straight through to the wrapped 'StateT' and then
-- re-wraps the result in the 'TransactionT' newtype.
instance MonadTrans (TransactionT e) where
  lift = TransactionT . lift

-- Existential wrapper for AggregateRef: it hides the aggregate type @a@ so
-- that refs to aggregates of different types can share one list. The packed
-- class dictionaries are what later code needs to use the ref again
-- ('cast' requires the Typeable instances; the snapshot code calls
-- 'encodeAggregate').
data BoxedAggregateRef e =
  forall a . ( Typeable a, Typeable e, Serialize e
             , Default a, Aggregate a, Eventable a e
             ) => BoxedAggregateRef (AggregateRef a e)

-- | Transaction monad itself: a state transformer over 'TransactionState',
-- which carries the event store, the aggregate refs handed out so far, and
-- the global version number to assign to the next published event.
type Transaction e = StateT (TransactionState e)

-- | State threaded through a running transaction.
--
-- The fields are strict. 'tsCurrentVersion' is rewritten by a lazy 'modify'
-- on every published event, and a lazy field would accumulate a chain of
-- unevaluated @+ 1@ thunks.
data TransactionState e =
  TransactionState { eventStore :: !(EventStore e)
                     -- ^ Event store the transaction runs against.
                   , aggregateRefsToCommit :: ![BoxedAggregateRef e]
                     -- ^ Every aggregate ref handed out so far; all of
                     -- them are written back when the transaction ends.
                   , tsCurrentVersion :: !Int
                     -- ^ Global version to stamp on the next published event.
                   }

-- | Run transaction against an event store.
--
-- The whole computation runs inside 'ES.withTransaction'. Published events
-- are numbered starting one past the store's latest global version. When
-- the computation finishes, every aggregate ref obtained during the
-- transaction has its accumulated events stored. A new snapshot is written
-- for any aggregate whose current version is more than 'snapshotInterval'
-- past its last snapshot.
runTransactionT :: (Typeable e, Serialize e) => EventStore e -> TransactionT e IO c -> IO c
runTransactionT eventStore_ (TransactionT transaction) =
  ES.withTransaction eventStore_ $ do
    firstVersion <- liftM (+ 1) (ES.getLatestVersion eventStore_)
    (result, finalState) <-
      runStateT transaction (TransactionState eventStore_ [] firstVersion)
    forM_ (aggregateRefsToCommit finalState) commit
    return result
  where
    -- Snapshot once an aggregate is more than this many versions past its
    -- most recent snapshot.
    snapshotInterval = 10

    -- Persist one aggregate: store its pending events, then snapshot it if
    -- it has drifted far enough from the previous snapshot.
    commit (BoxedAggregateRef ref) = do
      let guid = AR.arGUID ref
      pending <- AR.readEvents ref
      ES.storeEvents eventStore_ guid (AR.arStartVersion ref) pending
      version <- AR.getCurrentVersion ref
      when (version - AR.arSnapshotVersion ref > snapshotInterval) $
        AR.readValue ref >>= \value ->
          ES.writeSnapshot eventStore_ guid (version, encodeAggregate value)

-- | Get an aggregate ref by GUID.
--
-- If a ref for this GUID was already handed out in the current transaction,
-- that same ref is returned. Otherwise the aggregate is loaded from the
-- event store and registered for commit. The monad fails if the cached
-- ref's type doesn't match the requested one; that can only happen when one
-- GUID is shared by different aggregate/event types.
getById :: forall a e . (Typeable a, Typeable e, Serialize e, Default a, Aggregate a, Eventable a e) => GUID -> TransactionT e IO (AggregateRef a e)
getById guid = TransactionT $ do
  cached <- gets aggregateRefsToCommit
  case find (\(BoxedAggregateRef ref) -> AR.arGUID ref == guid) cached of
    Nothing -> getByIdFromEventStore guid
    Just (BoxedAggregateRef ref) ->
      maybe (fail $ concat ["Duplicate GUID ", show guid, "!" ])
            return
            (cast ref :: Maybe (AggregateRef a e))

-- Fetch the most recent snapshot for an aggregate as (version, state).
-- Falls back to version 0 with the 'def' state when either no snapshot is
-- stored or the stored state cannot be decoded.
getLatestSnapshot :: forall a e . (Typeable a, Typeable e, Serialize e, Default a, Aggregate a) => GUID -> Transaction e IO (Int,a)
getLatestSnapshot guid = do
  store <- gets eventStore
  latest <- liftIO $ ES.getLatestSnapshot store guid
  case latest >>= decodeSnapshot of
    Just snapshot -> return snapshot
    Nothing -> return (0, def)
  where
    decodeSnapshot (version, encoded) =
      fmap ((,) version) (decodeAggregate encoded :: Maybe a)

-- Retrieve aggregate from event store.
--
-- Starts from the latest usable snapshot (or 'def' at version 0) and
-- replays the events retrieved after that version. The result is wrapped in
-- a fresh 'AggregateRef', which is also recorded in the transaction state so
-- that 'runTransactionT' writes it back at the end.
getByIdFromEventStore :: forall a e . (Typeable a, Typeable e, Serialize e, Default a, Aggregate a, Eventable a e) => GUID -> Transaction e IO (AggregateRef a e)
getByIdFromEventStore guid = do
  es <- gets eventStore
  -- Get latest snapshot (if any).
  (v0,a0) <- getLatestSnapshot guid
  -- Get events recorded after the snapshot version.
  (latestVersion, events) <- lift $ ES.retrieveEvents es guid v0
  -- Replay events oldest-first on top of the snapshot state. The previous
  -- 'foldr applyEvent' applied the LAST event in the list first, which
  -- reverses history. 'foldl'' also avoids building a chain of thunks.
  -- NOTE(review): assumes ES.retrieveEvents returns events in ascending
  -- version order — confirm against the event store implementation.
  let a = foldl' (flip applyEvent) a0 events
  -- Make the aggregate itself
  (a' :: AggregateRef a e) <- lift $ mkAggregateRef a guid latestVersion v0
  -- Add to set of aggregates to commit later.
  modify $ \s -> s { aggregateRefsToCommit = BoxedAggregateRef a' : aggregateRefsToCommit s }
  -- Return the aggregate.
  return a'

-- | Publish event for an aggregate root.
--
-- The event is handed to the aggregate ref together with the transaction's
-- current global version, and the global version is then advanced by one so
-- that the next published event gets the following number.
publishEvent :: (MonadIO m, Serialize e, Typeable a, Typeable e, Aggregate a, Eventable a e, Default a) => AggregateRef a e -> e -> TransactionT e m ()
publishEvent aggregateRef event = TransactionT $ do
  stamp <- gets tsCurrentVersion
  lift (AR.publishEvent aggregateRef event stamp)
  modify (bumpVersion stamp)
  where
    bumpVersion stamp st = st { tsCurrentVersion = stamp + 1 }

-- | Get aggregate root.
--
-- Looks the aggregate up by GUID and returns both its ref (the handle that
-- 'publishEvent' accepts) and the aggregate value read from that ref.
getAggregateRoot :: (Default a, Serialize e, Typeable a, Typeable e, Aggregate a, Eventable a e) => GUID -> TransactionT e IO (AggregateRef a e, a)
getAggregateRoot guid =
  getById guid >>= \ref -> fmap ((,) ref) (lift (AR.readValue ref))