-- | Run commands against repositories.
module Data.CQRS.Internal.UnitOfWork
       ( UnitOfWorkT
       , createAggregate
       , createOrLoadAggregate
       , findAggregate
       , loadAggregate
       , publishEvent
       , runUnitOfWorkT
       ) where

-- External imports
import           Control.Applicative (Applicative)
import           Control.DeepSeq (NFData)
import           Control.Monad (forM, when)
import           Control.Monad.IO.Class (MonadIO, liftIO)
import           Control.Monad.Trans.Class (MonadTrans(..), lift)
import           Control.Monad.Trans.State (StateT, get, modify, runStateT)
import           Data.Conduit (($$), runResourceT)
import qualified Data.Conduit.List as CL
import           Data.Foldable (forM_)
import           Data.List (find, foldl')
import           Data.Typeable (Typeable, cast)

-- Internal imports
import           Data.CQRS.Aggregate
import           Data.CQRS.Eventable (Eventable(..))
import           Data.CQRS.Internal.EventBus
import           Data.CQRS.EventStore.Backend (EventStoreBackend(..), RawSnapshot(..))
import           Data.CQRS.GUID (GUID)
import           Data.CQRS.Internal.AggregateRef (AggregateRef, mkAggregateRef)
import qualified Data.CQRS.Internal.AggregateRef as AR
import           Data.CQRS.Internal.EventStore (EventStore(..))
import qualified Data.CQRS.Internal.EventStore as ES
import           Data.CQRS.Internal.Repository
import           Data.CQRS.PersistedEvent (PersistedEvent(..))
import           Data.CQRS.Serializable

-- | UnitOfWork monad transformer.
--
-- A newtype around 'UnitOfWorkM'. The 'Functor', 'Applicative' and 'Monad'
-- instances are all derived from the wrapped type; 'Applicative' is derived
-- explicitly because it is a superclass of 'Monad' on newer compilers, and
-- so that applicative combinators are available to users of the transformer.
newtype UnitOfWorkT e m a = UnitOfWorkT (UnitOfWorkM e m a)
    deriving (Functor, Applicative, Monad)

-- Lifting goes through the wrapped transformer's own 'lift' and then
-- re-wraps the result in the newtype.
instance MonadTrans (UnitOfWorkT e) where
  lift = UnitOfWorkT . lift

-- Existential wrapper for AggregateRef: hides the aggregate type @a@ while
-- keeping the class dictionaries needed to cast, serialize and replay it.
data BoxedAggregateRef e
  = forall a .
    ( Typeable a
    , Typeable e
    , Serializable e
    , Aggregate a
    , Eventable a e
    ) => BoxedAggregateRef (AggregateRef a e)

-- Existential wrapper for event store: hides the backend type @b@.
data BoxedEventStore e
  = forall b . EventStoreBackend b => BoxedEventStore (EventStore e b)

-- Existential wrapper for event store backend: hides the backend type @b@.
data BoxedEventStoreBackend
  = forall b . EventStoreBackend b => BoxedEventStoreBackend b

-- | UnitOfWork monad: a 'StateT' layer whose state is the 'UnitOfWork'
-- for events of type @e@. The base monad and the result type are left
-- unapplied, so the synonym is used as @UnitOfWorkM e m a@.
type UnitOfWorkM e = StateT (UnitOfWork e)

-- State threaded through a unit of work.
data UnitOfWork e =
    UnitOfWork { txnEventStore :: BoxedEventStore e
               -- Event store the unit of work retrieves events from.
               , txnEventStoreBackend :: BoxedEventStoreBackend
               -- Backend the latest snapshot of an aggregate is fetched from.
               , aggregateRefsToCommit :: [BoxedAggregateRef e]
               -- Aggregate references handed out so far; their accumulated
               -- events are written out when the unit of work is run.
               }

-- | Run transaction against an event store.
--
-- The computation is executed inside a backend transaction. Afterwards the
-- events accumulated on every aggregate reference handed out during the
-- computation are stored, snapshots are written where the configured
-- snapshot frequency calls for one, and finally (after the backend
-- transaction has returned) all stored events are published to the
-- repository's event bus. The result of the computation is returned.
runUnitOfWorkT :: forall b c e . (Typeable e, Serializable e, EventStoreBackend b) => Repository e b -> UnitOfWorkT e IO c -> IO c
runUnitOfWorkT repository (UnitOfWorkT transaction) = do
  (result, committedEvents) <-
    withEventStoreBackend repository $ \backend ->
      esbWithTransaction backend $ do
        let store = EventStore backend
            snapshotFrequency = settingsSnapshotFrequency (repositorySettings repository)
            initialState = UnitOfWork (BoxedEventStore store) (BoxedEventStoreBackend backend) []
        -- Run the user's computation, starting with no aggregates to commit.
        (answer, finalState) <- runStateT transaction initialState
        -- Persist every aggregate that was touched.
        eventsPerAggregate <- forM (aggregateRefsToCommit finalState) $ \(BoxedAggregateRef ref) -> do
          -- Store the events accumulated on this reference.
          pending <- AR.readEvents ref
          ES.storeEvents store (AR.arGUID ref) (AR.arStartVersion ref) pending
          -- When a snapshot frequency is configured and the aggregate has
          -- moved more than that many versions past its last snapshot,
          -- write a fresh snapshot (provided the aggregate has a value).
          forM_ snapshotFrequency $ \frequency -> do
            version <- AR.getCurrentVersion ref
            when (version - AR.arSnapshotVersion ref > frequency) $ do
              mValue <- AR.readValue ref
              forM_ mValue $ \value ->
                esbWriteSnapshot backend (AR.arGUID ref) (RawSnapshot version (serialize value))
          -- Hand the stored events back so they can be published later.
          return pending
        return (answer, concat eventsPerAggregate)
  -- Publish everything that was stored.
  publishEventsToBus (repositoryEventBus repository) committedEvents
  -- Return the computation's own result.
  return result

-- Get an aggregate ref by GUID.
--
-- A reference that was already handed out during this unit of work is
-- reused; otherwise the aggregate is retrieved from the event store.
getById :: forall a e . (Typeable a, Typeable e, Serializable e, Aggregate a, Eventable a e) => GUID -> UnitOfWorkT e IO (AggregateRef a e)
getById guid = UnitOfWorkT $ do
  handedOut <- fmap aggregateRefsToCommit get
  let hasGuid (BoxedAggregateRef ref) = AR.arGUID ref == guid
  case find hasGuid handedOut of
    -- Not seen before in this unit of work: go to the event store.
    Nothing ->
      getByIdFromEventStore guid
    -- Seen before: recover the concrete reference type.
    Just (BoxedAggregateRef ref) ->
      case cast ref of
        Nothing ->
          -- The cast can only really fail when the same GUID is used for
          -- different types of aggregates/events.
          fail $ concat ["Duplicate GUID ", show guid, "!" ]
        Just (typedRef :: AggregateRef a e) ->
          return typedRef

-- Get the latest snapshot from the backend as a (version, state) pair.
-- Falls back to version 0 with no state when a) no snapshot exists, or
-- b) the snapshot state could not be decoded.
getLatestSnapshot :: forall a e . (Typeable a, Typeable e, Serializable e, Aggregate a) => GUID -> UnitOfWorkM e IO (Int, Maybe a)
getLatestSnapshot guid = do
  BoxedEventStoreBackend backend <- fmap txnEventStoreBackend get
  mSnapshot <- liftIO $ esbGetLatestSnapshot backend guid
  let noSnapshot = (0, Nothing)
  case mSnapshot of
    Nothing ->
      return noSnapshot
    Just (RawSnapshot version payload) ->
      maybe (return noSnapshot)
            (\value -> return (version, Just value))
            (deserialize payload :: Maybe a)

-- Retrieve aggregate from event store.
--
-- Starts from the latest snapshot (if any), replays every event retrieved
-- from the snapshot's version onwards, wraps the resulting state in a new
-- aggregate reference and registers that reference so that it is committed
-- when the unit of work is run.
getByIdFromEventStore :: forall a e . (Typeable a, Typeable e, Serializable e, Aggregate a, Eventable a e) => GUID -> UnitOfWorkM e IO (AggregateRef a e)
getByIdFromEventStore guid = do
  (BoxedEventStore es) <- fmap txnEventStore get
  -- Get latest snapshot (if any).
  (v0, ma0) <- getLatestSnapshot guid
  -- Get events.
  events <- lift $ runResourceT (ES.retrieveEvents es guid v0 $$ CL.consume)
  -- The list always contains v0, so 'maximum' is safe even with no events.
  let latestVersion = maximum (v0 : map peSequenceNumber events)
  -- Build the aggregate state from all the events. A strict left fold is
  -- used so that a long event history does not pile up a chain of
  -- unevaluated 'applyEvent' applications.
  let a = foldl' (\a0 e -> applyEvent a0 $ peEvent e) ma0 events
  -- Make the aggregate itself.
  (a' :: AggregateRef a e) <- lift $ mkAggregateRef a guid latestVersion v0
  -- Add to set of aggregates to commit later.
  modify $ \s -> s { aggregateRefsToCommit = BoxedAggregateRef a' : aggregateRefsToCommit s }
  -- Return the aggregate.
  return a'

-- | Publish event for an aggregate root.
--
-- Delegates to the aggregate reference's own publish operation, lifted
-- into the unit of work.
publishEvent :: (MonadIO m, Serializable e, Typeable a, Typeable e, Aggregate a, Eventable a e, NFData a, NFData e) => AggregateRef a e -> e -> UnitOfWorkT e m ()
publishEvent aggregateRef event =
  UnitOfWorkT . lift $ AR.publishEvent aggregateRef event

-- | Find aggregate root.
--
-- Returns the aggregate reference paired with its current value, or
-- 'Nothing' when the reference has no value.
findAggregate :: (Serializable e, Typeable a, Typeable e, Aggregate a, Eventable a e) => GUID -> UnitOfWorkT e IO (Maybe (AggregateRef a e, a))
findAggregate guid = do
  ref <- getById guid
  mValue <- lift $ AR.readValue ref
  maybe (return Nothing) (\value -> return $ Just (ref, value)) mValue

-- | Load aggregate root. The aggregate root must exist.
--
-- Returns the aggregate reference together with its current value. If no
-- aggregate is found for the GUID, the computation is aborted with 'fail'
-- and a message naming the GUID.
loadAggregate :: (Serializable e, Typeable a, Typeable e, Aggregate a, Eventable a e) => GUID -> UnitOfWorkT e IO (AggregateRef a e, a)
loadAggregate guid = do
  mAggregate <- findAggregate guid
  case mAggregate of
    -- The message is handed to 'fail' as-is; passing it through 'show'
    -- first would wrap it in quotes and escape its contents.
    Nothing -> fail $ "Aggregate with GUID " ++ show guid ++ " does not exist"
    Just a -> return a

-- | Add aggregate root. The aggregate root will be created upon
-- transaction commit.
--
-- Returns a reference to the new aggregate. If the reference for the GUID
-- already has a value, the computation is aborted with 'fail' and a
-- message naming the GUID.
createAggregate :: (Serializable e, Typeable a, Typeable e, Aggregate a, Eventable a e) => GUID -> UnitOfWorkT e IO (AggregateRef a e)
createAggregate guid = do
  aggregateRef <- getById guid
  aggregate <- lift $ AR.readValue aggregateRef
  case aggregate of
    Nothing ->
      return aggregateRef
    -- The message is handed to 'fail' as-is; passing it through 'show'
    -- first would wrap it in quotes and escape its contents.
    Just _ -> fail $ "Aggregate with GUID " ++ show guid ++ " already exists"

-- | Create or load aggregate. The aggregate root will be created (if necessary)
-- upon transaction commit.
--
-- This is a plain lookup of the aggregate reference by GUID; no existence
-- check is made either way.
createOrLoadAggregate :: (Serializable e, Typeable a, Typeable e, Aggregate a, Eventable a e) => GUID -> UnitOfWorkT e IO (AggregateRef a e)
createOrLoadAggregate = getById