module Simulation.Aivika.Trans.Agent
       (Agent,
        AgentState,
        newAgent,
        newState,
        newSubstate,
        selectedState,
        selectedStateChanged,
        selectedStateChanged_,
        selectState,
        stateAgent,
        stateParent,
        addTimeout,
        addTimer,
        setStateActivation,
        setStateDeactivation,
        setStateTransition) where
import Control.Monad
import Simulation.Aivika.Trans.Ref.Base
import Simulation.Aivika.Trans.DES
import Simulation.Aivika.Trans.Internal.Specs
import Simulation.Aivika.Trans.Internal.Simulation
import Simulation.Aivika.Trans.Internal.Event
import Simulation.Aivika.Trans.Signal
-- | An agent: it keeps a mode, an optional selected state and a source of
-- the signal that carries the selected state.
data Agent m = Agent
  { -- | Reference to the current mode of the agent.
    agentModeRef :: Ref m AgentMode,
    -- | Reference to the selected state, or 'Nothing' when there is none.
    agentStateRef :: Ref m (Maybe (AgentState m)),
    -- | Source of the signal published by 'selectedStateChanged'.
    agentStateChangedSource :: SignalSource m (Maybe (AgentState m))
  }
-- | A state of some agent, optionally nested in a parent state.
data AgentState m = AgentState
  { -- | The agent this state belongs to.
    stateAgent :: Agent m,
    -- | The parent state, or 'Nothing' for a top-level state.
    stateParent :: Maybe (AgentState m),
    -- | Reference to the action run when the state is activated.
    stateActivateRef :: Ref m (Event m ()),
    -- | Reference to the action run when the state is deactivated.
    stateDeactivateRef :: Ref m (Event m ()),
    -- | Reference to the computation that may return a state to move on to
    -- after this state has been selected.
    stateTransitRef :: Ref m (Event m (Maybe (AgentState m))),
    -- | Version counter; it is incremented when the state is deactivated
    -- and compared by 'addTimeout' and 'addTimer' to detect stale handlers.
    stateVersionRef :: Ref m Int
  }
                  
-- | The mode an agent is in; 'selectState' dispatches on it.
data AgentMode
  = CreationMode
    -- ^ The mode a freshly created agent starts in.
  | TransientMode
    -- ^ Set while the agent is moving from one state to another.
  | ProcessingMode
    -- ^ Set once a traversal has finished without a further transition.
                      
-- | Two agents are equal when their selected-state references are equal.
instance MonadDES m => Eq (Agent m) where
  agent1 == agent2 =
    let ref1 = agentStateRef agent1
        ref2 = agentStateRef agent2
    in ref1 == ref2
  
-- | Two states are equal when their version references are equal.
instance MonadDES m => Eq (AgentState m) where
  state1 == state2 =
    let ref1 = stateVersionRef state1
        ref2 = stateVersionRef state2
    in ref1 == ref2
-- | Prepend to the accumulator the chain of states that leads from the
-- outermost ancestor down to the given state (the state itself included).
fullPath :: AgentState m -> [AgentState m] -> [AgentState m]
fullPath st acc = maybe path (`fullPath` path) (stateParent st)
  where
    -- the state goes in front of everything collected so far; its
    -- ancestors, if any, are then put in front of it
    path = st : acc
-- | Drop the common leading states of two root-first paths.
--
-- The last element of the second path is never dropped, even when it
-- equals the head of the first path. The result is the remainder of the
-- first path in reverse order, paired with the remainder of the second.
partitionPath :: MonadDES m => [AgentState m] -> [AgentState m] -> ([AgentState m], [AgentState m])
partitionPath (h1 : t1) (h2 : t2@(_ : _))
  -- a shared head can be skipped only while the second path has more
  -- states after it
  | h1 == h2 = partitionPath t1 t2
partitionPath path1 path2 = (reverse path1, path2)
-- | Compute the states to pass through when moving from the optional
-- source state to the target state.
--
-- Without a source the first list is empty and the second is the full
-- path to the target. Otherwise both full paths are handed over to
-- 'partitionPath'. Calls 'error' if the two states belong to different
-- agents.
findPath :: MonadDES m => Maybe (AgentState m) -> AgentState m -> ([AgentState m], [AgentState m])
findPath source target =
  case source of
    Nothing ->
      ([], targetPath)
    Just st
      | stateAgent st /= stateAgent target ->
          error "Different agents: findPath."
      | otherwise ->
          partitionPath (fullPath st []) targetPath
  where
    targetPath = fullPath target []
-- | Move the agent from the optional source state to the target state.
--
-- Nothing happens if both lists returned by 'findPath' are empty.
-- Otherwise the agent is put in 'TransientMode', the states of the first
-- list are deactivated, the states of the second list are activated and
-- the transition computation of the target is run. If it yields another
-- state, the traversal goes on from the target to that state; if not, the
-- agent is put in 'ProcessingMode' and the state-changed signal is
-- triggered.
traversePath :: MonadDES m => Maybe (AgentState m) -> AgentState m -> Event m ()
traversePath source target =
  Event $ \p ->
    unless (null leaving && null entering) $ do
      enterMode p TransientMode
      forM_ leaving $ \st -> do
        markSelected p st
        runDeactivation p st
        -- bumping the version makes the timeout and timer handlers that
        -- were registered for this state outdated
        invokeEvent p $ modifyRef (stateVersionRef st) (1 +)
      forM_ entering $ \st -> do
        markSelected p st
        runActivation p st
      next <- runTransition p target
      case next of
        Nothing -> do
          enterMode p ProcessingMode
          triggerAgentStateChanged p agent
        Just st' ->
          proceed p st'
  where
    (leaving, entering) = findPath source target
    agent = stateAgent target
    -- write the new mode of the agent
    enterMode p mode = invokeEvent p $ writeRef (agentModeRef agent) mode
    -- record the state as the selected one
    markSelected p st = invokeEvent p $ writeRef (agentStateRef agent) (Just st)
    -- each runner reads the stored computation and then runs it
    runActivation p st = invokeEvent p =<< invokeEvent p (readRef (stateActivateRef st))
    runDeactivation p st = invokeEvent p =<< invokeEvent p (readRef (stateDeactivateRef st))
    runTransition p st = invokeEvent p =<< invokeEvent p (readRef (stateTransitRef st))
    -- continue from the current target to the next state
    proceed p st = invokeEvent p $ traversePath (Just target) st
-- | Schedule the action to run once, the given time interval after the
-- current point.
--
-- The version of the state is remembered at the time of the call; the
-- action is skipped if the version differs when the handler fires.
addTimeout :: MonadDES m => AgentState m -> Double -> Event m () -> Event m ()
addTimeout st dt action =
  Event $ \p -> do
    version <- currentVersion p
    let handler =
          Event $ \p' -> do
            version' <- currentVersion p'
            -- run the action only if the version has not changed
            when (version == version') $
              invokeEvent p' action
    invokeEvent p $ enqueueEvent (pointTime p + dt) handler
  where
    currentVersion p = invokeEvent p $ readRef (stateVersionRef st)
-- | Schedule the action to run repeatedly; the delay before each run is
-- obtained by running the given computation.
--
-- The version of the state is remembered at the time of the call. When
-- the handler fires with a different version, it neither runs the action
-- nor schedules itself again.
addTimer :: MonadDES m => AgentState m -> Event m Double -> Event m () -> Event m ()
addTimer st dt action =
  Event $ \p -> do
    version <- currentVersion p
    let tick =
          Event $ \p' -> do
            version' <- currentVersion p'
            when (version == version') $ do
              -- the next tick is scheduled before the action is run
              invokeEvent p' schedule
              invokeEvent p' action
        schedule =
          Event $ \p' -> do
            delay <- invokeEvent p' dt
            invokeEvent p' $ enqueueEvent (pointTime p' + delay) tick
    invokeEvent p schedule
  where
    currentVersion p = invokeEvent p $ readRef (stateVersionRef st)
-- | Create a new top-level state of the agent.
--
-- The activation and deactivation actions initially do nothing, the
-- transition computation initially returns 'Nothing' and the version
-- starts at zero.
newState :: MonadDES m => Agent m -> Simulation m (AgentState m)
newState agent = do
  activateRef <- newRef (return ())
  deactivateRef <- newRef (return ())
  transitRef <- newRef (return Nothing)
  versionRef <- newRef 0
  return
    AgentState
      { stateAgent = agent,
        stateParent = Nothing,
        stateActivateRef = activateRef,
        stateDeactivateRef = deactivateRef,
        stateTransitRef = transitRef,
        stateVersionRef = versionRef
      }
-- | Create a new state nested in the given parent state; it belongs to
-- the same agent as the parent.
--
-- The activation and deactivation actions initially do nothing, the
-- transition computation initially returns 'Nothing' and the version
-- starts at zero.
newSubstate :: MonadDES m => AgentState m -> Simulation m (AgentState m)
newSubstate parent = do
  activateRef <- newRef (return ())
  deactivateRef <- newRef (return ())
  transitRef <- newRef (return Nothing)
  versionRef <- newRef 0
  return
    AgentState
      { stateAgent = stateAgent parent,
        stateParent = Just parent,
        stateActivateRef = activateRef,
        stateDeactivateRef = deactivateRef,
        stateTransitRef = transitRef,
        stateVersionRef = versionRef
      }
-- | Create a new agent: it starts in 'CreationMode' with no selected
-- state and a fresh source for the state-changed signal.
newAgent :: MonadDES m => Simulation m (Agent m)
newAgent = do
  mode <- newRef CreationMode
  selected <- newRef Nothing
  changedSource <- newSignalSource
  return
    Agent
      { agentModeRef = mode,
        agentStateRef = selected,
        agentStateChangedSource = changedSource
      }
-- | Return the currently selected state of the agent, or 'Nothing' if no
-- state has been recorded as selected.
selectedState :: MonadDES m => Agent m -> Event m (Maybe (AgentState m))
selectedState = readRef . agentStateRef
                   
-- | Select the given state of its agent.
--
-- What happens depends on the current mode of the agent:
--
-- * in 'CreationMode' and 'ProcessingMode' the agent is moved from its
--   currently recorded state to the given one;
--
-- * in 'TransientMode' the call is rejected with 'error', because a
--   follow-up state has to be supplied through 'setStateTransition'
--   rather than selected directly.
--
-- In 'ProcessingMode' a selected state is expected to exist; 'error' is
-- called if it does not.
selectState :: MonadDES m => AgentState m -> Event m ()
selectState st =
  Event $ \p ->
  do let agent = stateAgent st
     mode <- invokeEvent p $ readRef (agentModeRef agent)
     case mode of
       CreationMode ->
         do x0 <- invokeEvent p $ readRef (agentStateRef agent)
            invokeEvent p $ traversePath x0 st
       TransientMode ->
         error $
         "Use the setStateTransition function to define " ++
         "the transition state: selectState."
       ProcessingMode ->
         do x0 <- invokeEvent p $ readRef (agentStateRef agent)
            -- An explicit case is used instead of a refutable pattern in
            -- the bind, so that no MonadFail instance is needed and the
            -- violated invariant is reported with a clear message.
            case x0 of
              Nothing ->
                error $
                "The agent has no selected state " ++
                "in the processing mode: selectState."
              Just _ ->
                invokeEvent p $ traversePath x0 st
-- | Replace the action that is run when the state is activated.
setStateActivation :: MonadDES m => AgentState m -> Event m () -> Event m ()
setStateActivation = writeRef . stateActivateRef
  
-- | Replace the action that is run when the state is deactivated.
setStateDeactivation :: MonadDES m => AgentState m -> Event m () -> Event m ()
setStateDeactivation = writeRef . stateDeactivateRef
  
-- | Replace the transition computation of the state.
--
-- The computation is run after the state has been selected; if it
-- returns another state, the agent moves on to that state.
setStateTransition :: MonadDES m => AgentState m -> Event m (Maybe (AgentState m)) -> Event m ()
setStateTransition = writeRef . stateTransitRef
  
-- | Read the selected state of the agent at the given point and trigger
-- the state-changed signal with it.
triggerAgentStateChanged :: MonadDES m => Point m -> Agent m -> m ()
triggerAgentStateChanged p agent =
  invokeEvent p (readRef (agentStateRef agent))
    >>= invokeEvent p . triggerSignal (agentStateChangedSource agent)
-- | The signal published from the agent's state-changed source; it
-- carries the selected state.
selectedStateChanged :: Agent m -> Signal m (Maybe (AgentState m))
selectedStateChanged = publishSignal . agentStateChangedSource
-- | Like 'selectedStateChanged', but the signal carries no value.
selectedStateChanged_ :: MonadDES m => Agent m -> Signal m ()
selectedStateChanged_ = mapSignal (const ()) . selectedStateChanged