-- | Run transactions against event stores.
module Data.CQRS.Transaction
       ( TransactionT
       , runTransactionT
       , publishEvent
       , getAggregateRoot
       ) where

import           Control.Applicative (Applicative)
import           Control.Monad (forM_)
import           Control.Monad.IO.Class (MonadIO)
import           Control.Monad.Trans.Class (MonadTrans(..), lift)
import           Control.Monad.Trans.State (StateT, get, modify, runStateT)
import           Data.Binary (Binary)
import           Data.CQRS.Internal.AggregateRef (AggregateRef, mkAggregateRef)
import qualified Data.CQRS.Internal.AggregateRef as AR
import           Data.CQRS.Event (Event(..))
import           Data.CQRS.EventStore (EventStore, withTransaction)
import qualified Data.CQRS.EventStore as ES
import           Data.CQRS.GUID (GUID)
import           Data.Default (Default(..))
import           Data.List (find)
import           Data.Typeable (Typeable, cast)

-- | Transaction monad transformer.
--
-- A newtype around the internal 'Transaction' state monad. Only the type is
-- exported, not its constructor, so client code can build transactions solely
-- from the operations this module provides.
--
-- 'Applicative' is derived together with 'Functor' and 'Monad': it has been
-- a superclass of 'Monad' since GHC 7.10, so a 'Monad' instance cannot be
-- derived without it on current compilers.
newtype TransactionT m a = TransactionT (Transaction m a)
                           deriving (Functor, Applicative, Monad)

-- Lifting an action of the base monad first lifts it into the underlying
-- 'Transaction' state monad and then wraps the result in the newtype.
instance MonadTrans TransactionT where
  lift = TransactionT . lift

-- Existentially quantified box around an 'AggregateRef'.
--
-- The aggregate type and the event type are hidden, which lets references to
-- different kinds of aggregates live in a single list. The constructor packs
-- the listed class instances, so they become available again whenever a box
-- is opened by pattern matching.
data BoxedAggregateRef
  = forall a e .
    ( Typeable a
    , Typeable e
    , Event e a
    , Default a
    , Binary e
    ) => BoxedAggregateRef (AggregateRef a e)

-- | Transaction monad itself: a 'StateT' layer that threads the
-- 'TransactionState' over a base monad. The synonym takes no parameters, so
-- the base monad and the result type are supplied at each use site.
type Transaction = StateT TransactionState

-- State threaded through a running transaction.
data TransactionState = TransactionState
  { -- Event store that aggregates are retrieved from.
    eventStore :: EventStore
    -- Aggregate references handed out so far, most recently added first.
    -- 'runTransactionT' walks this list after the computation has finished
    -- and stores the events of each reference.
  , aggregateRefsToCommit :: [BoxedAggregateRef]
  }

-- | Run transaction against an event store.
--
-- The computation is executed inside 'withTransaction', starting from a
-- state that holds the given event store and no aggregate references. Once
-- it has finished, the events read back from every aggregate reference
-- recorded in the final state are passed to 'ES.storeEvents' together with
-- the reference's GUID and start version. The computation's result is then
-- returned.
runTransactionT :: EventStore -> TransactionT IO a -> IO a
runTransactionT eventStore_ (TransactionT transaction) =
  withTransaction eventStore_ $ do
    (result, finalState) <- runStateT transaction initialState
    mapM_ persist (aggregateRefsToCommit finalState)
    return result
  where
    -- No aggregate references have been handed out at the start.
    initialState = TransactionState eventStore_ [ ]
    -- Hand the events read from a single reference over to the event store.
    persist (BoxedAggregateRef ref) =
      AR.readEvents ref >>=
        ES.storeEvents eventStore_ (AR.arGUID ref) (AR.arStartVersion ref)

-- | Get an aggregate ref by GUID.
--
-- References already handed out during this transaction are searched first,
-- so repeated lookups of the same GUID yield the same reference. Only when
-- none matches is the aggregate loaded from the event store.
getById :: forall a e . (Typeable a, Typeable e, Event e a, Default a, Binary e) => GUID a -> TransactionT IO (AggregateRef a e)
getById guid = TransactionT $ do
  handedOut <- fmap aggregateRefsToCommit get
  case find hasRequestedGuid handedOut of
    Nothing ->
      getByIdFromEventStore guid
    Just (BoxedAggregateRef ref) ->
      -- The GUID cast in the predicate has already matched at type 'GUID a',
      -- so this cast can presumably only fail when the same GUID was handed
      -- out earlier with a different event type.
      maybe (fail $ concat ["Duplicate GUID ", show guid, "!" ])
            return
            (cast ref :: Maybe (AggregateRef a e))
  where
    -- True when the boxed reference's GUID casts to the requested GUID type
    -- and is equal to the requested GUID.
    hasRequestedGuid (BoxedAggregateRef ref) =
      (cast (AR.arGUID ref) :: Maybe (GUID a)) == Just guid

-- Retrieve aggregate from event store.
--
-- Fetches the latest version number and the events for the GUID, rebuilds
-- the aggregate value from 'def', creates a reference for it and records
-- that reference in the transaction state before returning it.
getByIdFromEventStore :: forall a e . (Typeable a, Typeable e, Event e a, Default a, Binary e) => GUID a -> Transaction IO (AggregateRef a e)
getByIdFromEventStore guid = do
  store <- fmap eventStore get
  (latestVersion, events) <-
    lift (ES.retrieveEvents store guid :: IO (Int, [e]))
  -- 'foldr' applies the last event of the list to 'def' first and the head
  -- of the list last.
  -- NOTE(review): confirm this matches the order 'ES.retrieveEvents' returns.
  let rebuilt = foldr applyEvent def events
  ref <- lift (mkAggregateRef rebuilt guid latestVersion)
  -- Remember the reference so its events get stored at the end of the run.
  modify (remember ref)
  return ref
  where
    -- Prepend the boxed reference to the list of references to commit.
    remember ref s =
      s { aggregateRefsToCommit = BoxedAggregateRef ref : aggregateRefsToCommit s }

-- | Publish event for an aggregate root.
--
-- Delegates to 'AR.publishEvent' on the given reference, lifted into the
-- transaction.
publishEvent :: (MonadIO m, Event e a) => AggregateRef a e -> e -> TransactionT m ()
publishEvent aggregateRef = lift . AR.publishEvent aggregateRef

-- | Get aggregate root.
--
-- Looks the reference up by GUID and pairs it with the value currently read
-- from it.
getAggregateRoot :: (Default a, Binary e, Event e a, Typeable a, Typeable e) => GUID a -> TransactionT IO (AggregateRef a e, a)
getAggregateRoot guid =
  getById guid >>= \ref ->
    fmap ((,) ref) (lift (AR.readValue ref))