module Data.CQRS.Internal.AggregateRef
       ( AggregateRef
       , arGUID
       , arSnapshotVersion
       , arStartVersion
       , getCurrentVersion
       , mkAggregateRef
       , publishEvent
       , readEvents
       , readValue
       ) where

import           Control.Monad (liftM)
import           Control.Monad.IO.Class (MonadIO, liftIO)
import           Data.CQRS.Eventable (Eventable(..))
import           Data.CQRS.GUID (GUID)
import           Data.CQRS.PersistedEvent (PersistedEvent(..))
import           Data.Foldable (toList)
import           Data.IORef (IORef, modifyIORef, modifyIORef', newIORef, readIORef)
import           Data.Sequence (Seq, (|>))
import qualified Data.Sequence as S
import           Data.Typeable (Typeable)

-- | Aggregate root reference.
--
-- Holds the mutable in-memory state of an aggregate together with the
-- events published through this reference since it was created.
--
-- All fields are strict: they are fixed at construction time and read on
-- every publish / version query, so there is no benefit to storing thunks.
data AggregateRef a e =
  AggregateRef { arValue :: !(IORef a)
                 -- ^ Current aggregate state (initial value with every
                 -- published event applied).
               , arEvents :: !(IORef (Seq (PersistedEvent e)))
                 -- ^ Events published through this reference, oldest first.
               , arGUID :: !GUID
                 -- ^ GUID of the aggregate; copied into every
                 -- 'PersistedEvent' published through this reference.
               , arStartVersion :: !Int
                 -- ^ Version of the aggregate when this reference was
                 -- created; published events are numbered from
                 -- @arStartVersion + 1@ upwards.
               , arSnapshotVersion :: !Int
                 -- ^ Snapshot version supplied at construction; stored as
                 -- given and not interpreted by this module.
               }
  deriving (Typeable)

-- | Create a fresh aggregate reference.
--
-- The reference starts out holding the given initial state and an empty
-- sequence of published events.  The two 'Int' arguments are, in order,
-- the version the aggregate is at when the reference is created (the
-- "start" version) and the snapshot version.
mkAggregateRef :: (MonadIO m) => a -> GUID -> Int -> Int -> m (AggregateRef a e)
mkAggregateRef initialValue guid originatingVersion snapshotVersion = liftIO $ do
  valueRef <- newIORef initialValue
  eventsRef <- newIORef S.empty
  return AggregateRef { arValue = valueRef
                      , arEvents = eventsRef
                      , arGUID = guid
                      , arStartVersion = originatingVersion
                      , arSnapshotVersion = snapshotVersion
                      }

-- | Publish an event to the aggregate.
--
-- The event is applied to the in-memory state with 'applyEvent' and then
-- appended, wrapped in a 'PersistedEvent', to the sequence of events
-- published through this reference.
--
-- The event's version is @arStartVersion + 1 + n@, where @n@ is the number
-- of events already published through this reference, so successive
-- events get consecutive versions.  The @gv@ argument is stored unchanged
-- as the last field of the 'PersistedEvent' (presumably a global version
-- number -- confirm against "Data.CQRS.PersistedEvent").
publishEvent :: (MonadIO m, Eventable a e) => AggregateRef a e -> e -> Int -> m ()
publishEvent aggregateRef event gv = liftIO $ do
  -- Apply event to aggregate state.  The strict modify forces the new state
  -- to WHNF immediately.  The lazy variant would pile up a chain of
  -- 'applyEvent' thunks in the IORef across many publishes (a space leak).
  modifyIORef' (arValue aggregateRef) $ applyEvent event
  -- Append the event.  The strict modify evaluates the append right away.
  -- Forcing the version first means the new element does not capture (and
  -- so retain) the previous 'events' sequence through an unevaluated
  -- 'S.length' thunk.
  modifyIORef' (arEvents aggregateRef) $ \events ->
    let version = arStartVersion aggregateRef + 1 + S.length events
    in version `seq`
         (events |> PersistedEvent (arGUID aggregateRef) event version gv)

-- | Read the events published through this reference so far, as a list
-- in publication order (oldest first).
readEvents :: (MonadIO m) => AggregateRef a e -> m [PersistedEvent e]
readEvents ref = liftIO $ fmap toList (readIORef (arEvents ref))

-- | Read the current aggregate state: the value the reference was created
-- with, with every event published through it applied.
readValue :: (MonadIO m) => AggregateRef a e -> m a
readValue ref = liftIO (readIORef (arValue ref))

-- | Get the current version of the aggregate held by the reference.
--
-- This is the start version plus the number of events published through
-- the reference.  It equals the version assigned to the most recently
-- published event, or the start version if nothing has been published.
getCurrentVersion :: (MonadIO m) => AggregateRef a e -> m Int
getCurrentVersion ref = liftIO $ do
  pending <- readIORef (arEvents ref)
  return (arStartVersion ref + S.length pending)