module Simulation.Aivika.Internal.Cont
       (ContCancellation(..),
        ContId,
        ContEvent(..),
        Cont(..),
        ContParams,
        FrozenCont,
        newContId,
        contSignal,
        contCancellationInitiated,
        contCancellationInitiate,
        contCancellationInitiating,
        contCancellationActivated,
        contCancellationBind,
        contCancellationConnect,
        contPreemptionBegun,
        contPreemptionBegin,
        contPreemptionBeginning,
        contPreemptionEnd,
        contPreemptionEnding,
        invokeCont,
        runCont,
        rerunCont,
        spawnCont,
        contParallel,
        contParallel_,
        catchCont,
        finallyCont,
        throwCont,
        resumeCont,
        resumeECont,
        reenterCont,
        freezeCont,
        freezeContReentering,
        unfreezeCont,
        substituteCont,
        contCanceled,
        contAwait,
        transferCont,
        traceCont) where
import Data.IORef
import Data.Array
import Data.Array.IO.Safe
import Data.Monoid
import Control.Exception
import Control.Monad
import Control.Monad.Trans
import Control.Applicative
import Debug.Trace
import Simulation.Aivika.Internal.Specs
import Simulation.Aivika.Internal.Parameter
import Simulation.Aivika.Internal.Simulation
import Simulation.Aivika.Internal.Dynamics
import Simulation.Aivika.Internal.Event
import Simulation.Aivika.Signal
-- | Strategy used by 'contCancellationConnect' to decide in which direction
-- cancellation propagates between a parent and a child computation.
data ContCancellation
  = CancelTogether
    -- ^ Cancelling either computation also cancels the other one.
  | CancelChildAfterParent
    -- ^ Cancelling the parent cancels the child, but not vice versa.
  | CancelParentAfterChild
    -- ^ Cancelling the child cancels the parent, but not vice versa.
  | CancelInIsolation
    -- ^ Neither computation's cancellation affects the other.
-- | Mutable identity of a running computation, used for cancellation and
-- preemption bookkeeping.
data ContId =
  ContId
  { contCancellationInitiatedRef :: IORef Bool
    -- ^ Set to 'True' once cancellation has been initiated; this module
    -- never resets it.
  , contCancellationActivatedRef :: IORef Bool
    -- ^ Set together with the initiated flag and cleared again by
    -- 'cancelCont' (via 'contCancellationDeactivate').
  , contPreemptionCountRef :: IORef Int
    -- ^ Current preemption nesting depth.
  , contSignalSource :: SignalSource ContEvent
    -- ^ Source through which 'ContEvent' notifications are triggered.
  }
-- | Identifiers are equal exactly when they share the same
-- cancellation-initiated reference (reference identity, not contents).
instance Eq ContId where
  a == b = key a == key b
    where key = contCancellationInitiatedRef
-- | Notification published through the signal of a 'ContId'.
data ContEvent
  = ContCancellationInitiating
    -- ^ Cancellation of the computation has just been initiated.
  | ContPreemptionBeginning
    -- ^ The preemption depth went from zero to one.
  | ContPreemptionEnding
    -- ^ The preemption depth returned to zero.
  deriving (Eq, Ord, Show)
-- | Allocate a fresh identifier: cancellation not initiated, not activated,
-- preemption depth zero, and a brand-new signal source.
newContId :: Simulation ContId
newContId =
  Simulation $ \run ->
  do initiated   <- newIORef False
     activated   <- newIORef False
     preemptions <- newIORef 0
     source      <- invokeSimulation run newSignalSource
     return (ContId initiated activated preemptions source)
-- | Every cancellation/preemption notification of the identifier.
contSignal :: ContId -> Signal ContEvent
contSignal cid = publishSignal (contSignalSource cid)
-- | Fires when cancellation of the identified computation is initiated.
contCancellationInitiating :: ContId -> Signal ()
contCancellationInitiating cid =
  filterSignal_ (== ContCancellationInitiating) (contSignal cid)
-- | Whether cancellation of the computation has already been initiated.
contCancellationInitiated :: ContId -> Event Bool
contCancellationInitiated cid =
  Event $ \_ -> readIORef $ contCancellationInitiatedRef cid
-- | Read the cancellation-activated flag (set on initiation, cleared by
-- 'contCancellationDeactivate').
contCancellationActivated :: ContId -> IO Bool
contCancellationActivated cid = readIORef (contCancellationActivatedRef cid)
-- | Clear the cancellation-activated flag; the initiated flag is untouched.
contCancellationDeactivate :: ContId -> IO ()
contCancellationDeactivate cid =
  let ref = contCancellationActivatedRef cid
  in writeIORef ref False
-- | Bind cancellation of @x@ to each of @ys@ in both directions: cancelling
-- @x@ cancels every @y@, and cancelling any @y@ cancels @x@. The result
-- disposes of all installed handlers.
contCancellationBind :: ContId -> [ContId] -> Event DisposableEvent
contCancellationBind x ys =
  Event $ \p ->
  do let link from to =
           invokeEvent p $
           handleSignal (contCancellationInitiating from) $ \_ ->
           contCancellationInitiate to
     -- All x -> y handlers first, then all y -> x handlers.
     forward  <- mapM (link x) ys
     backward <- mapM (`link` x) ys
     return (mconcat forward <> mconcat backward)
-- | Connect the cancellation of a parent and a child computation according
-- to the given 'ContCancellation' strategy.
contCancellationConnect :: ContId
                           -- ^ the parent computation
                        -> ContCancellation
                           -- ^ how cancellation propagates
                        -> ContId
                           -- ^ the child computation
                        -> Event DisposableEvent
                           -- ^ disposes of the handlers that were installed
contCancellationConnect parent cancellation child =
  Event $ \p ->
  do let (downward, upward) =
           case cancellation of
             CancelTogether         -> (True, True)
             CancelChildAfterParent -> (True, False)
             CancelParentAfterChild -> (False, True)
             CancelInIsolation      -> (False, False)
         propagate from to =
           handleSignal (contCancellationInitiating from) $ \_ ->
           contCancellationInitiate to
         install enabled m
           | enabled   = invokeEvent p m
           | otherwise = return mempty
     -- The parent -> child handler is installed before the child -> parent one.
     h1 <- install downward (propagate parent child)
     h2 <- install upward (propagate child parent)
     return (h1 <> h2)
-- | Initiate cancellation of the computation. This sets both the initiated
-- and activated flags and triggers 'ContCancellationInitiating'. It is a
-- no-op if cancellation was already initiated.
contCancellationInitiate :: ContId -> Event ()
contCancellationInitiate x =
  Event $ \p ->
  do initiated <- readIORef (contCancellationInitiatedRef x)
     if initiated
       then return ()
       else do
         writeIORef (contCancellationInitiatedRef x) True
         writeIORef (contCancellationActivatedRef x) True
         invokeEvent p $
           triggerSignal (contSignalSource x) ContCancellationInitiating
-- | Enter one more level of preemption. On the transition from depth 0 to 1,
-- 'ContPreemptionBeginning' is triggered. Ignored once cancellation has
-- been initiated.
contPreemptionBegin :: ContId -> Event ()
contPreemptionBegin x =
  Event $ \p ->
  do initiated <- readIORef (contCancellationInitiatedRef x)
     unless initiated $
       do depth <- readIORef countRef
          let depth' = depth + 1
          -- force the new depth so no thunk accumulates in the reference
          writeIORef countRef $! depth'
          when (depth == 0) $
            invokeEvent p $
            triggerSignal (contSignalSource x) ContPreemptionBeginning
  where
    countRef = contPreemptionCountRef x
-- | Leave one level of preemption. When the depth returns to 0,
-- 'ContPreemptionEnding' is triggered. Ignored once cancellation has been
-- initiated.
--
-- Raises an error if called more often than 'contPreemptionBegin', since a
-- negative preemption depth would desynchronise the begin/end signals.
contPreemptionEnd :: ContId -> Event ()
contPreemptionEnd x =
  Event $ \p ->
  do f <- readIORef (contCancellationInitiatedRef x)
     unless f $
       do n <- readIORef (contPreemptionCountRef x)
          let n' = n - 1
          when (n' < 0) $
            error "The preemption count cannot be negative: contPreemptionEnd."
          -- force the new depth so no thunk accumulates in the reference
          n' `seq` writeIORef (contPreemptionCountRef x) n'
          when (n' == 0) $
            invokeEvent p $
            triggerSignal (contSignalSource x) ContPreemptionEnding
-- | Fires when the computation's preemption depth goes from 0 to 1.
contPreemptionBeginning :: ContId -> Signal ()
contPreemptionBeginning cid =
  filterSignal_ (== ContPreemptionBeginning) (contSignal cid)
-- | Fires when the computation's preemption depth returns to 0.
contPreemptionEnding :: ContId -> Signal ()
contPreemptionEnding cid =
  filterSignal_ (== ContPreemptionEnding) (contSignal cid)
-- | Whether the computation is currently preempted (depth above zero).
contPreemptionBegun :: ContId -> Event Bool
contPreemptionBegun x =
  Event $ \_ -> fmap (> 0) (readIORef $ contPreemptionCountRef x)
-- | A continuation-passing computation: a function from the continuation
-- parameters ('ContParams') to an 'Event' action.
newtype Cont a = Cont (ContParams a -> Event ())
-- | Parameters handed to a 'Cont' computation when it is run.
data ContParams a =
  ContParams
  { contCont :: a -> Event ()
    -- ^ Success continuation, receiving the computed value.
  , contAux  :: ContParamsAux
    -- ^ The parameters that do not depend on the result type.
  }
-- | Result-type-independent part of the continuation parameters.
data ContParamsAux =
  ContParamsAux
  { contECont     :: SomeException -> Event ()
    -- ^ Exception continuation.
  , contCCont     :: () -> Event ()
    -- ^ Cancellation continuation, invoked by 'cancelCont'.
  , contId        :: ContId
    -- ^ Identifier of the running computation.
  , contCancelRef :: IORef Bool
    -- ^ Flag read by 'contCanceled'; 'runCont' points it at the
    -- identifier's cancellation-activated reference.
  , contCatchFlag :: Bool
    -- ^ When set, lifted actions run under an exception handler.
  }
-- | 'Cont' is a monad. Sequencing is 'bindC', which checks for cancellation
-- before running each step.
instance Monad Cont where
  -- Canonical definition: 'return' delegates to 'pure' so the two cannot
  -- diverge (and GHC's -Wnoncanonical-monad-instances stays quiet).
  return = pure
  m >>= k = bindC m k

instance ParameterLift Cont where
  liftParameter = liftPC

instance SimulationLift Cont where
  liftSimulation = liftSC

instance DynamicsLift Cont where
  liftDynamics = liftDC

instance EventLift Cont where
  liftEvent = liftEC

instance Functor Cont where
  fmap = liftM

-- | 'pure' is the primitive here; the Monad's 'return' is defined from it.
instance Applicative Cont where
  pure = returnC
  (<*>) = ap

instance MonadIO Cont where
  liftIO = liftIOC
-- | Run a 'Cont' computation with the given continuation parameters.
invokeCont :: ContParams a -> Cont a -> Event ()
invokeCont params (Cont run) = run params
-- | Complete a computation by cancellation. The identifier's
-- cancellation-activated flag is cleared first, then the cancellation
-- continuation is invoked.
cancelCont :: Point -> ContParams a -> IO ()
cancelCont p c =
  do let aux = contAux c
     contCancellationDeactivate (contId aux)
     invokeEvent p (contCCont aux ())
-- | Lift a pure value. The value goes to the success continuation unless the
-- computation has been cancelled.
returnC :: a -> Cont a
returnC a =
  Cont $ \c ->
  Event $ \p ->
  do cancelled <- contCanceled c
     if cancelled then cancelCont p c else invokeEvent p (contCont c a)
                          
-- | Monadic bind. The first computation runs with a success continuation
-- that feeds its result into @k@; everything else in the parameters is
-- inherited.
bindC :: Cont a -> (a -> Cont b) -> Cont b
bindC (Cont m) k =
  Cont $ \c ->
  Event $ \p ->
  do cancelled <- contCanceled c
     if cancelled
       then cancelCont p c
       else let next a = invokeCont c (k a)
            in invokeEvent p (m (c { contCont = next }))
-- | Apply a Kleisli arrow and run the result with the given parameters,
-- unless the computation has been cancelled.
callCont :: (a -> Cont b) -> a -> ContParams b -> Event ()
callCont k a c =
  Event $ \p ->
  contCanceled c >>= \cancelled ->
  if cancelled
    then cancelCont p c
    else invokeEvent p (invokeCont c (k a))
-- | Run a computation with exception catching enabled. Exceptions matching
-- the handler's type go to the handler; any other exception is passed to the
-- enclosing exception continuation.
catchCont :: Exception e => Cont a -> (e -> Cont a) -> Cont a
catchCont (Cont m) h =
  Cont $ \c0 ->
  Event $ \p ->
  do let c = c0 { contAux = (contAux c0) { contCatchFlag = True } }
         dispatch e0 =
           case fromException e0 of
             Just e  -> callCont h e c
             Nothing -> contECont (contAux c) e0
         guarded = c { contAux = (contAux c) { contECont = dispatch } }
     cancelled <- contCanceled c
     if cancelled
       then cancelCont p c
       else invokeEvent p (m guarded)
               
-- | @finallyCont m m'@ runs @m@ with exception catching enabled and then runs
-- the finalizer @m'@ on every way @m@ can complete:
--
-- * success: after @m'@ the result of @m@ goes to the original success
--   continuation (the finalizer's own result is discarded);
-- * exception: after @m'@ the exception goes to the original exception
--   continuation;
-- * cancellation: after @m'@ the original cancellation continuation is
--   called; an exception raised by @m'@ on this path also ends in the
--   cancellation continuation.
--
-- NOTE(review): the finalizer can run on the cancellation path because
-- 'cancelCont' clears the activated flag before invoking the cancellation
-- continuation. Confirm this ordering is relied on.
finallyCont :: Cont a -> Cont b -> Cont a
finallyCont (Cont m) (Cont m') = 
  Cont $ \c0 ->
  Event $ \p ->
  do let c = c0 { contAux = (contAux c0) { contCatchFlag = True } }
     z <- contCanceled c
     if z 
       then cancelCont p c
       else invokeEvent p $ m $
            -- success path: run the finalizer, then pass 'a' on
            let cont a   = 
                  Event $ \p ->
                  invokeEvent p $ m' $
                  let cont b = contCont c a
                  in c { contCont = cont }
                -- exception path: run the finalizer, then re-raise 'e'
                econt e  =
                  Event $ \p ->
                  invokeEvent p $ m' $
                  let cont b = (contECont . contAux $ c) e
                  in c { contCont = cont }
                -- cancellation path: run the finalizer, then cancel
                -- whether it succeeds or throws
                ccont () = 
                  Event $ \p ->
                  invokeEvent p $ m' $
                  let cont b  = (contCCont . contAux $ c) ()
                      econt e = (contCCont . contAux $ c) ()
                  in c { contCont = cont,
                         contAux  = (contAux c) { contECont = econt } }
            in c { contCont = cont,
                   contAux  = (contAux c) { contECont = econt,
                                            contCCont = ccont } }
-- | Throw an exception inside a 'Cont' computation.
--
-- Uses 'throwIO' rather than 'throw', so the exception is raised when the
-- lifted IO action executes (inside any handler installed by
-- 'liftIOWithCatch') rather than whenever a pure value is forced. The type
-- is generalized from 'IOException' to any 'Exception'; existing callers
-- are unaffected.
throwCont :: Exception e => e -> Cont a
throwCont = liftIO . throwIO
-- | Run a 'Cont' computation with explicitly supplied continuations. The
-- cancellation flag consulted by 'contCanceled' is the identifier's
-- cancellation-activated reference.
runCont :: Cont a
           -- ^ the computation to run
        -> (a -> Event ())
           -- ^ the success continuation
        -> (SomeException -> Event ())
           -- ^ the exception continuation
        -> (() -> Event ())
           -- ^ the cancellation continuation
        -> ContId
           -- ^ the computation identifier
        -> Bool
           -- ^ whether lifted actions should catch exceptions
        -> Event ()
runCont (Cont m) cont econt ccont cid catchFlag =
  m (ContParams cont aux)
  where
    aux = ContParamsAux { contECont     = econt,
                          contCCont     = ccont,
                          contId        = cid,
                          contCancelRef = contCancellationActivatedRef cid,
                          contCatchFlag = catchFlag }
-- | Lift a 'Parameter' computation, with or without exception catching
-- according to the current catch flag.
liftPC :: Parameter a -> Cont a
liftPC (Parameter m) =
  Cont $ \c ->
  Event $ \p ->
  let runner | contCatchFlag (contAux c) = liftIOWithCatch
             | otherwise                 = liftIOWithoutCatch
  in runner (m $ pointRun p) p c
-- | Lift a 'Simulation' computation, with or without exception catching
-- according to the current catch flag.
liftSC :: Simulation a -> Cont a
liftSC (Simulation m) =
  Cont $ \c ->
  Event $ \p ->
  let runner | contCatchFlag (contAux c) = liftIOWithCatch
             | otherwise                 = liftIOWithoutCatch
  in runner (m $ pointRun p) p c
     
-- | Lift a 'Dynamics' computation, with or without exception catching
-- according to the current catch flag.
liftDC :: Dynamics a -> Cont a
liftDC (Dynamics m) =
  Cont $ \c ->
  Event $ \p ->
  let runner | contCatchFlag (contAux c) = liftIOWithCatch
             | otherwise                 = liftIOWithoutCatch
  in runner (m p) p c
     
-- | Lift an 'Event' computation, with or without exception catching
-- according to the current catch flag.
liftEC :: Event a -> Cont a
liftEC (Event m) =
  Cont $ \c ->
  Event $ \p ->
  let runner | contCatchFlag (contAux c) = liftIOWithCatch
             | otherwise                 = liftIOWithoutCatch
  in runner (m p) p c
     
-- | Lift a plain IO action, with or without exception catching according to
-- the current catch flag.
liftIOC :: IO a -> Cont a
liftIOC m =
  Cont $ \c ->
  Event $ \p ->
  let runner | contCatchFlag (contAux c) = liftIOWithCatch
             | otherwise                 = liftIOWithoutCatch
  in runner m p c
  
-- | Run an IO action and pass its result to the success continuation, unless
-- the computation is cancelled. Exceptions from the action propagate
-- unhandled.
liftIOWithoutCatch :: IO a -> Point -> ContParams a -> IO ()
liftIOWithoutCatch m p c =
  contCanceled c >>= \cancelled ->
  if cancelled
    then cancelCont p c
    else m >>= invokeEvent p . contCont c
-- | Run an IO action unless the computation is cancelled. Any exception the
-- action throws is routed to the exception continuation; otherwise the
-- result goes to the success continuation.
--
-- Only the action itself is guarded: exceptions raised by the continuations
-- are not caught here. Uses 'try' instead of a pre-filled
-- @newIORef undefined@ plus a separate error reference.
liftIOWithCatch :: IO a -> Point -> ContParams a -> IO ()
liftIOWithCatch m p c =
  do z <- contCanceled c
     if z
       then cancelCont p c
       else do r <- try m
               case r of
                 Right a -> invokeEvent p $ contCont c a
                 -- the exception type is fixed to SomeException by contECont
                 Left e  -> invokeEvent p $ (contECont . contAux) c e
-- | Resume a computation with a value, or cancel it if cancellation is
-- active.
resumeCont :: ContParams a -> a -> Event ()
resumeCont c a =
  Event $ \p ->
  do cancelled <- contCanceled c
     if cancelled then cancelCont p c else invokeEvent p (contCont c a)
-- | Resume a computation with an exception, or cancel it if cancellation is
-- active.
resumeECont :: ContParams a -> SomeException -> Event ()
resumeECont c e =
  Event $ \p ->
  do cancelled <- contCanceled c
     if cancelled
       then cancelCont p c
       else invokeEvent p (contECont (contAux c) e)
-- | Read the cancellation flag referenced by the continuation parameters.
contCanceled :: ContParams a -> IO Bool
contCanceled = readIORef . contCancelRef . contAux
-- | Run the given computations side by side, each under its own 'ContId',
-- and wait until every one of them has completed (by result, exception or
-- cancellation).
--
-- Cancellation of the current computation and of the children is bound in
-- both directions ('contCancellationBind'). Once the last child completes,
-- the binding is disposed and:
--
-- * if the current computation is cancelled, it is cancelled;
-- * otherwise, if any child raised an exception, the first recorded one is
--   rethrown;
-- * otherwise the results are returned in the order of the input list.
--
-- An empty list yields @[]@ immediately.
contParallel :: [(Cont a, ContId)]
                -- ^ the computations to run, each paired with the
                -- identifier it runs under
                --
                -> Cont [a]
contParallel xs =
  Cont $ \c ->
  Event $ \p ->
  do let n = length xs
         worker =
           do results   <- newArray_ (1, n) :: IO (IOArray Int a)
              counter   <- newIORef 0
              catchRef  <- newIORef Nothing
              hs <- invokeEvent p $
                    contCancellationBind (contId $ contAux c) $
                    map snd xs
              -- Called after every child completion; acts only when all
              -- n children have reported.
              -- NOTE(review): if a child was cancelled while the parent's
              -- flag reads False, getElems would read an unwritten slot;
              -- the bidirectional binding is presumably what prevents this.
              let propagate =
                    Event $ \p ->
                    do n' <- readIORef counter
                       when (n' == n) $
                         do invokeEvent p $ disposeEvent hs  
                            f1 <- contCanceled c
                            f2 <- readIORef catchRef
                            case (f1, f2) of
                              (False, Nothing) ->
                                do rs <- getElems results
                                   invokeEvent p $ resumeCont c rs
                              (False, Just e) ->
                                invokeEvent p $ resumeECont c e
                              (True, _) ->
                                cancelCont p c
                  -- success of child i: store its result at index i
                  cont i a =
                    Event $ \p ->
                    do modifyIORef counter (+ 1)
                       writeArray results i a
                       invokeEvent p propagate
                  -- exception from a child: only the first one is kept
                  econt e =
                    Event $ \p ->
                    do modifyIORef counter (+ 1)
                       r <- readIORef catchRef
                       case r of
                         Nothing -> writeIORef catchRef $ Just e
                         Just e' -> return ()  
                       invokeEvent p propagate
                  -- cancellation of a child: counted, nothing stored
                  ccont e =
                    Event $ \p ->
                    do modifyIORef counter (+ 1)
                       
                       invokeEvent p propagate
              forM_ (zip [1..n] xs) $ \(i, (x, cid)) ->
                invokeEvent p $
                runCont x (cont i) econt ccont cid (contCatchFlag $ contAux c)
     z <- contCanceled c
     if z
       then cancelCont p c
       else if n == 0
            then invokeEvent p $ contCont c []
            else worker
-- | Like 'contParallel', but the children's results are discarded.
--
-- All children run under their own identifiers with cancellation bound to
-- the current computation in both directions. After the last child
-- completes: cancellation wins, then the first recorded exception,
-- otherwise @()@ is returned. An empty list returns @()@ immediately.
contParallel_ :: [(Cont a, ContId)]
                 -- ^ the computations to run, each paired with the
                 -- identifier it runs under
                 --
                 -> Cont ()
contParallel_ xs =
  Cont $ \c ->
  Event $ \p ->
  do let n = length xs
         worker =
           do counter   <- newIORef 0
              catchRef  <- newIORef Nothing
              hs <- invokeEvent p $
                    contCancellationBind (contId $ contAux c) $
                    map snd xs
              -- Called after every child completion; acts only when all
              -- n children have reported.
              let propagate =
                    Event $ \p ->
                    do n' <- readIORef counter
                       when (n' == n) $
                         do invokeEvent p $ disposeEvent hs  
                            f1 <- contCanceled c
                            f2 <- readIORef catchRef
                            case (f1, f2) of
                              (False, Nothing) ->
                                invokeEvent p $ resumeCont c ()
                              (False, Just e) ->
                                invokeEvent p $ resumeECont c e
                              (True, _) ->
                                cancelCont p c
                  -- success of a child: only counted, the value is dropped
                  cont i a =
                    Event $ \p ->
                    do modifyIORef counter (+ 1)
                       
                       invokeEvent p propagate
                  -- exception from a child: only the first one is kept
                  econt e =
                    Event $ \p ->
                    do modifyIORef counter (+ 1)
                       r <- readIORef catchRef
                       case r of
                         Nothing -> writeIORef catchRef $ Just e
                         Just e' -> return ()  
                       invokeEvent p propagate
                  -- cancellation of a child: only counted
                  ccont e =
                    Event $ \p ->
                    do modifyIORef counter (+ 1)
                       
                       invokeEvent p propagate
              forM_ (zip [1..n] xs) $ \(i, (x, cid)) ->
                invokeEvent p $
                runCont x (cont i) econt ccont cid (contCatchFlag $ contAux c)
     z <- contCanceled c
     if z
       then cancelCont p c
       else if n == 0
            then invokeEvent p $ contCont c ()
            else worker
-- | Run a computation under a different identifier @cid@, with cancellation
-- bound in both directions between the current computation and @cid@. On
-- completion the binding is disposed and the outcome (value, exception or
-- cancellation) is forwarded to the current computation.
rerunCont :: Cont a -> ContId -> Cont a
rerunCont x cid =
  Cont $ \c ->
  Event $ \p ->
  do cancelled <- contCanceled c
     if cancelled
       then cancelCont p c
       else do
         hs <- invokeEvent p $
               contCancellationBind (contId $ contAux c) [cid]
         let release q = invokeEvent q (disposeEvent hs)
             onValue a = Event $ \q ->
               release q >> invokeEvent q (resumeCont c a)
             onError e = Event $ \q ->
               release q >> invokeEvent q (resumeECont c e)
             onCancel _ = Event $ \q ->
               release q >> cancelCont q c
         invokeEvent p $
           runCont x onValue onError onCancel cid (contCatchFlag $ contAux c)
-- | Spawn a child computation under identifier @cid@ and continue the
-- current computation immediately with @()@.
--
-- The child is enqueued at the current modelling time and run without
-- exception catching. Its cancellation is connected to the parent according
-- to @cancellation@ ('contCancellationConnect'). When the child finishes,
-- the connection is disposed. An exception from the child is rethrown via
-- 'throwEvent' rather than being delivered to the parent.
spawnCont :: ContCancellation -> Cont () -> ContId -> Cont ()
spawnCont cancellation x cid =
  Cont $ \c ->
  Event $ \p ->
  do let worker =
           do hs <- invokeEvent p $
                    contCancellationConnect
                    (contId $ contAux c) cancellation cid
              let cont a  =
                    Event $ \p ->
                    do invokeEvent p $ disposeEvent hs  
                       
                  econt e =
                    Event $ \p ->
                    do invokeEvent p $ disposeEvent hs  
                       invokeEvent p $ throwEvent e  
                  ccont e =
                    Event $ \p ->
                    do invokeEvent p $ disposeEvent hs  
                       
              -- the child starts later at the same time; the parent goes on now
              invokeEvent p $
                enqueueEvent (pointTime p) $
                runCont x cont econt ccont cid False
              invokeEvent p $
                resumeCont c ()
     z <- contCanceled c
     if z
       then cancelCont p c
       else worker
-- | Continuation parameters set aside for later resumption.
newtype FrozenCont a =
  FrozenCont
  { unfreezeCont :: Event (Maybe (ContParams a))
    -- ^ Take the stored parameters out, if they are still there.
  }
-- | Freeze the continuation parameters so the computation can be resumed
-- later through 'unfreezeCont'.
--
-- A handler on 'contCancellationInitiating' is installed. If cancellation is
-- initiated while the parameters are still frozen, they are taken out and an
-- event is enqueued at the current time that cancels the computation when
-- cancellation is still active. 'unfreezeCont' removes the handler and
-- returns the parameters at most once; later calls yield 'Nothing'.
freezeCont :: ContParams a -> Event (FrozenCont a)
freezeCont c =
  Event $ \p ->
  do rh <- newIORef Nothing
     rc <- newIORef $ Just c
     h <- invokeEvent p $
          handleSignal (contCancellationInitiating $
                        contId $ contAux c) $ \e ->
          Event $ \p ->
          do h <- readIORef rh
             case h of
               Nothing ->
                 error "The handler was lost: freezeCont."
               Just h ->
                 do invokeEvent p $ disposeEvent h
                    c <- readIORef rc
                    case c of
                      Nothing -> return ()
                      Just c  ->
                        do writeIORef rc Nothing
                           invokeEvent p $
                             enqueueEvent (pointTime p) $
                             Event $ \p ->
                             do z <- contCanceled c
                                when z $ cancelCont p c
     -- store the handler so it can dispose of itself when it fires
     writeIORef rh (Just h)
     return $
       FrozenCont $
       Event $ \p ->
       do invokeEvent p $ disposeEvent h
          c <- readIORef rc
          writeIORef rc Nothing
          return c
-- | Like 'freezeCont', but preemption-aware on unfreezing.
--
-- When 'unfreezeCont' finds the computation currently preempted, it does not
-- hand back the parameters. Instead it puts the computation to sleep with a
-- continuation that runs @m@ once preemption ends, and returns 'Nothing'.
-- If cancellation is initiated while frozen, the computation is cancelled
-- exactly as in 'freezeCont'.
freezeContReentering :: ContParams a -> a -> Event () -> Event (FrozenCont a)
freezeContReentering c a m =
  Event $ \p ->
  do rh <- newIORef Nothing
     rc <- newIORef $ Just c
     h <- invokeEvent p $
          handleSignal (contCancellationInitiating $
                        contId $ contAux c) $ \_ ->
          Event $ \p ->
          do mh <- readIORef rh
             case mh of
               Nothing ->
                 error "The handler was lost: freezeContReentering."
               Just h' ->
                 do invokeEvent p $ disposeEvent h'
                    mc <- readIORef rc
                    case mc of
                      Nothing -> return ()
                      Just c' ->
                        do writeIORef rc Nothing
                           invokeEvent p $
                             enqueueEvent (pointTime p) $
                             Event $ \p ->
                             do z <- contCanceled c'
                                when z $ cancelCont p c'
     -- store the handler so it can dispose of itself when it fires
     writeIORef rh (Just h)
     return $
       FrozenCont $
       Event $ \p ->
       do invokeEvent p $ disposeEvent h
          mc <- readIORef rc
          writeIORef rc Nothing
          case mc of
            Nothing -> return Nothing
            z@(Just c') ->
              do f <- invokeEvent p $
                      contPreemptionBegun $
                      contId $ contAux c'
                 if not f
                   then return z
                   else do -- A fresh name is required here: the original
                           -- `let c = c { .. }` referred to itself and looped.
                           let reentering = c' { contCont = \_ -> m }
                           invokeEvent p $ sleepCont reentering a
                           return Nothing
-- | Resume the computation with @a@ at the current time, respecting
-- preemption. If the computation is preempted (now, or when the enqueued
-- event runs), it sleeps until preemption ends; otherwise it resumes from an
-- event enqueued at the current time.
reenterCont :: ContParams a -> a -> Event ()
reenterCont c a =
  Event $ \p ->
  do preempted <- isPreempted p
     invokeEvent p $
       if preempted
         then sleepCont c a
         else enqueueEvent (pointTime p) $
              Event $ \q ->
              do stillPreempted <- isPreempted q
                 invokeEvent q $
                   if stillPreempted then sleepCont c a else resumeCont c a
  where
    isPreempted q = invokeEvent q $ contPreemptionBegun $ contId $ contAux c
-- | Suspend a preempted computation until the next notification on its
-- identifier's signal.
--
-- The handler disposes of itself on the first notification, then:
--
-- * 'ContCancellationInitiating': enqueues, at the current time, an event
--   that cancels the computation if cancellation is still active;
-- * 'ContPreemptionEnding': enqueues 'reenterCont' with the saved value;
-- * 'ContPreemptionBeginning': raises an error, since the computation is
--   expected to be preempted already.
sleepCont :: ContParams a -> a -> Event ()
sleepCont c a =
  Event $ \p ->
  do rh <- newIORef Nothing
     h  <- invokeEvent p $
           handleSignal (contSignal $
                         contId $ contAux c) $ \e ->
           Event $ \p ->
           do h <- readIORef rh
              case h of
                Nothing ->
                  error "The handler was lost: sleepCont."
                Just h ->
                  do invokeEvent p $ disposeEvent h
                     case e of
                       ContCancellationInitiating ->
                         invokeEvent p $
                         enqueueEvent (pointTime p) $
                         Event $ \p ->
                         do z <- contCanceled c
                            when z $ cancelCont p c
                       ContPreemptionEnding ->
                         invokeEvent p $
                         enqueueEvent (pointTime p) $
                         reenterCont c a
                       ContPreemptionBeginning ->
                         error "The computation was already preempted: sleepCont."
     -- store the handler so it can dispose of itself when it fires
     writeIORef rh (Just h)
-- | Replace the success continuation, keeping every other parameter.
substituteCont :: ContParams a -> (a -> Event ()) -> ContParams a
substituteCont params k = params { contCont = k }
-- | Suspend the computation until the signal fires, then re-enter it with the
-- signal's value.
--
-- The continuation is frozen with 'freezeCont'. The signal handler disposes
-- of itself on first use; if the frozen parameters were already taken (for
-- example by cancellation), nothing happens.
contAwait :: Signal a -> Cont a
contAwait signal =
  Cont $ \c ->
  Event $ \p ->
  do frozen <- invokeEvent p $ freezeCont c
     handlerRef <- newIORef Nothing
     h <- invokeEvent p $
          handleSignal signal $ \a ->
          Event $ \q ->
          do mh <- readIORef handlerRef
             case mh of
               Nothing ->
                 error "The signal was lost: contAwait."
               Just handler ->
                 do invokeEvent q $ disposeEvent handler
                    mc <- invokeEvent q $ unfreezeCont frozen
                    case mc of
                      Nothing -> return ()
                      Just c' -> invokeEvent q $ reenterCont c' a
     writeIORef handlerRef $ Just h
-- | Hand control over to another computation that runs under the same
-- identifier; the current continuation is never resumed.
--
-- The transferred computation runs without exception catching. On success or
-- cancellation it simply ends, and an exception is rethrown with
-- 'throwEvent'. It is an error to call this inside exception handling
-- ('catchCont' / 'finallyCont'). The error message now names this function
-- instead of the non-existent @unsafeTransferCont@.
transferCont :: Cont () -> Cont a
transferCont x =
  Cont $ \c ->
  Event $ \p ->
  do z <- contCanceled c
     if z
       then cancelCont p c
       else do
         when (contCatchFlag $ contAux c) $
           error "Cannot be combined with the exception handling: transferCont"
         invokeEvent p $
           runCont x return throwEvent return (contId $ contAux c) False
-- | Emit a debug trace tagged with the current modelling time, then run the
-- computation (unless it has been cancelled).
traceCont :: String -> Cont a -> Cont a
traceCont message (Cont m) =
  Cont $ \c ->
  Event $ \p ->
  do cancelled <- contCanceled c
     if cancelled
       then cancelCont p c
       else let line = "t = " ++ show (pointTime p) ++ ": " ++ message
            in trace line $ invokeEvent p (m c)