{-# LANGUAGE DeriveDataTypeable #-}
{-# LANGUAGE ExistentialQuantification #-}

-- | <https://www.fpcomplete.com/user/agocorona/the-hardworking-programmer-ii-practical-backtracking-to-undo-actions>

module Transient.Backtrack (onUndo, undo, retry, undoCut,registerUndo,

-- * generalized versions of backtracking with an extra parameter that gives the reason for going back.
-- Different kinds of backtracking with different reasons can be managed in the same program
onBack, back, forward, backCut,registerBack,

-- * finalization primitives
finish, onFinish, onFinish' ,initFinish , noFinish, killOnFinish ,checkFinalize , FinishReason
) where

import Transient.Base
import Transient.Internals(EventF(..),killChildren,onNothing,runClosure,runContinuation)
import Data.Typeable
import Control.Applicative
import Control.Monad.State
import Unsafe.Coerce
import System.Mem.StableName
import Control.Exception
import Control.Concurrent.STM hiding (retry)
import Data.Maybe

data Backtrack b= Show b =>Backtrack{backtracking :: Maybe b
                                    ,backStack :: [EventF] }
                                    deriving Typeable

-- | assures that backtracking will not go further back
backCut :: (Typeable reason, Show reason) => reason -> TransientIO ()
backCut reason= Transient $ do
     delData $ Backtrack (Just reason)  []
     return $ Just ()

undoCut ::  TransientIO ()
undoCut = backCut ()

-- | the second parameter will be executed when backtracking
{-# NOINLINE onBack #-}
onBack :: (Typeable b, Show b) => TransientIO a -> ( b -> TransientIO a) -> TransientIO a
onBack ac  bac= registerBack (typeof bac) $ Transient $ do
     Backtrack mreason _  <- getData `onNothing` backStateOf (typeof bac)
     runTrans $ case mreason of
                  Nothing     -> ac
                  Just reason -> bac reason
     where
     typeof :: (b -> TransIO a) -> b
     typeof = undefined

onUndo ::  TransientIO a -> TransientIO a -> TransientIO a
onUndo x y= onBack x (\() -> y)


-- | register an action that will be executed when backtracking
{-# NOINLINE registerUndo #-}
registerBack :: (Typeable b, Show b) => b -> TransientIO a -> TransientIO a
registerBack witness f  = Transient $ do
   cont@(EventF _ _ x _ _ _ _ _ _ _ _)  <- get   -- !!> "backregister"

   md <- getData `asTypeOf` (Just <$> backStateOf witness)

   case md of
            Just (bss@(Backtrack b (bs@((EventF _ _ x'  _ _ _ _ _ _ _ _):_)))) ->
               when (isNothing b) $ do
                   addrx  <- addr x
                   addrx' <- addr x'         -- to avoid duplicate backtracking points
                   setData $ if addrx == addrx' then bss else  Backtrack mwit (cont:bs)
            Nothing ->  setData $ Backtrack mwit [cont]

   runTrans f
   where
   mwit= Nothing `asTypeOf` (Just witness)
   addr x = liftIO $ return . hashStableName =<< (makeStableName $! x)


registerUndo :: TransientIO a -> TransientIO a
registerUndo f= registerBack ()  f

-- | restart the flow forward from this point on
forward :: (Typeable b, Show b) => b -> TransIO ()
forward reason= Transient $ do
    Backtrack _ stack <- getData `onNothing`  (backStateOf reason)
    setData $ Backtrack(Nothing `asTypeOf` Just reason)  stack
    return $ Just ()

retry= forward ()

noFinish= forward (FinishReason Nothing)

-- | execute backtracking. It execute the registered actions in reverse order.
--
-- If the backtracking flag is changed the flow proceed  forward from that point on.
--
-- If the backtrack stack is finished or undoCut executed, `undo` will stop.
back :: (Typeable b, Show b) => b -> TransientIO a
back reason = Transient $ do
  bs <- getData  `onNothing`  backStateOf  reason           -- !!>"GOBACK"
  goBackt  bs

  where

  goBackt (Backtrack _ [] )= return Nothing                      -- !!> "END"
  goBackt (Backtrack b (stack@(first : bs)) )= do
        (setData $ Backtrack (Just reason) stack)

        mr <-  runClosure first                                  -- !> "RUNCLOSURE"

        Backtrack back _ <- getData `onNothing`  backStateOf  reason
                                                                 -- !> "END RUNCLOSURE"
        case back of
           Nothing -> case mr of
                   Nothing ->  return empty                      -- !> "FORWARD END"
                   Just x  ->  runContinuation first x           -- !> "FORWARD EXEC"
           justreason -> goBackt $ Backtrack justreason bs       -- !> ("BACK AGAIN",back)

backStateOf :: (Monad m, Show a, Typeable a) => a -> m (Backtrack a)
backStateOf reason= return $ Backtrack (Nothing `asTypeOf` (Just reason)) []

undo ::  TransIO a
undo= back ()

------ finalization

newtype FinishReason= FinishReason (Maybe SomeException) deriving (Typeable, Show)

-- | initialize the event variable for finalization.
-- all the following computations in different threads will share it
-- it also isolate this event from other branches that may have his own finish variable
initFinish= backCut (FinishReason Nothing)

-- | set a computation to be called when the finish event happens
onFinish :: ((Maybe SomeException) ->TransIO ()) -> TransIO ()
onFinish f= onFinish' (return ()) f


-- | set a computation to be called when the finish event happens this only apply for
onFinish' ::TransIO a ->((Maybe SomeException) ->TransIO a) -> TransIO a
onFinish' proc f= proc `onBack`   \(FinishReason reason) ->
    f reason


-- | trigger the event, so this closes all the resources
finish :: Maybe SomeException -> TransIO a
finish reason= back (FinishReason reason)


-- | kill all the processes generated by the parameter when finish event occurs
killOnFinish comp= do
   chs <- liftIO $ newTVarIO []
   onFinish $ const $ liftIO $ killChildren chs   -- !> "killOnFinish event"
   r <- comp
   modify $ \ s -> s{children= chs}
   return r

-- | trigger finish when the stream of data ends
checkFinalize v=
           case v of
              SDone ->  finish Nothing >> stop
              SLast x ->  return x
              SError e -> liftIO ( print e) >> finish  Nothing >> stop
              SMore x -> return x