{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE CPP #-}

{-|

The 'Thunk' API provides a way to defer potentially recursive computations:

* 'thunk' is lazy in its argument, and does not run it directly
* the first 'force' triggers execution of the action passed to thunk
* that action is run at most once, and returns a list of other thunks
* 'force' forces these thunks as well, and does not return before all of them have executed
* Cycles are allowed: The action passed to 'thunk' may return a thunk whose action returns the first thunk.

The implementation is hopefully thread safe: Even if multiple threads force or
kick related thunks, all actions are still run at most once, and all calls to
force terminate (no deadlock).

>>> :set -XRecursiveDo
>>> :{
  mdo t1 <- thunk $ putStrLn "Hello" >> pure [t1, t2]
      t2 <- thunk $ putStrLn "World" >> pure [t1, t2]
      putStrLn "Nothing happened so far, but now:"
      force t1
      putStrLn "No more will happen now:"
      force t1
      putStrLn "That's it"
:}
Nothing happened so far, but now:
Hello
World
No more will happen now:
That's it

-}
module System.IO.RecThunk
    ( Thunk
    , thunk
    , doneThunk
    , force
    )
where


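-- I want to test this code with dejafu, without carrying it as a dependency
-- of the main library. So here is a bit of CPP to take care of that: the
-- macros below make the code generic over a MonadConc monad when building
-- for dejafu testing, and specialise it to plain IO otherwise.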

#ifdef DEJAFU

#define Ctxt   MonadConc m =>
#define Thunk_  (Thunk m)
#define ResolvingState_  (ResolvingState m)
#define KickedThunk_  (KickedThunk m)
#define ThreadId_  (ThreadId m)
#define IORef_ IORef m
#define MVar_  MVar m
#define M      m

import Control.Concurrent.Classy hiding (wait)

#else

#define Ctxt
#define Thunk_  Thunk
#define ResolvingState_  ResolvingState
#define KickedThunk_  KickedThunk
#define ThreadId_  ThreadId
#define IORef_ IORef
#define MVar_  MVar
#define M      IO

import Control.Concurrent.MVar
import Control.Concurrent
import Data.IORef

#endif



-- | An @IO@ action that is to be run at most once.
--
-- The 'MVar' holds @Left act@ while the action has not been started, and
-- @Right kt@ once @kick@ has claimed it. Taking and refilling this 'MVar'
-- is what makes claiming the action mutually exclusive between threads.
newtype Thunk_ = Thunk (MVar_ (Either (M [Thunk_]) KickedThunk_))
-- Progress of the waiting phase (see wait) for one kicked thunk:
--   NotStarted       no thread has started waiting for its dependencies;
--   ProcessedBy t d  thread t is currently waiting for the dependencies,
--                    and fills d once it has finished;
--   Done             the dependencies of this thunk have been waited for.
data ResolvingState_ = NotStarted | ProcessedBy ThreadId_ (MVar_ ()) | Done
-- | A 'Thunk' that is being evaluated.
--
-- The first field is filled (once) by @kick@ with the kicked thunks returned
-- by the action; until then it is empty. The second field tracks the
-- progress of the waiting phase.
data KickedThunk_ = KickedThunk (MVar_ [KickedThunk_]) (MVar_ ResolvingState_)

-- | Create a new 'Thunk' from an 'IO' action.
--
-- The 'IO' action may return other thunks that should be forced together
-- whenever this thunk is forced (in arbitrary order).
--
-- Creating the thunk does not run the action; it is merely stored, to be
-- run (at most once) by the first 'force'.
thunk :: Ctxt M [Thunk_] -> M Thunk_
thunk act = Thunk <$> newMVar (Left act)

-- | A 'Thunk' that is already done.
--
-- Equivalent to @do {t <- thunk (pure []); force t; pure t }@
doneThunk :: Ctxt M Thunk_
doneThunk = do
    -- No dependencies, and nothing left to wait for.
    mv_ts <- newMVar []
    mv_s <- newMVar Done
    -- Stored as already kicked, so no action will ever be run for it.
    Thunk <$> newMVar (Right (KickedThunk mv_ts mv_s))

-- Recursively explores the thunk, and kicks the execution: runs the action
-- of every not-yet-kicked thunk reachable from this one.
-- May return before execution is done (if started by another thread).
kick :: Ctxt Thunk_ -> M KickedThunk_
kick (Thunk t) = takeMVar t >>= \case
    Left act -> do
        mv_thunks <- newEmptyMVar
        mv_state <- newMVar NotStarted
        let kt = KickedThunk mv_thunks mv_state
        -- Publish the kicked thunk before running the action, so that cycles
        -- (and other threads) see it as kicked and never run the action again.
        putMVar t (Right kt)

        ts <- act
        kts <- mapM kick ts
        -- Anyone waiting on this thunk blocks in readMVar until this point.
        putMVar mv_thunks kts
        pure kt

    -- Thunk was already kicked, nothing to do
    Right kt -> do
        putMVar t (Right kt)
        pure kt

-- Waits until the action of this kicked thunk has run and, recursively,
-- until its dependencies have been waited for.
--
-- Several threads may wait for the same thunk. They are ordered by their
-- ThreadId: a thread only ever blocks on a thread with a smaller ThreadId,
-- and otherwise takes over the work itself. This ordering is meant to rule
-- out deadlocks between threads.
wait :: Ctxt KickedThunk_ -> M ()
wait (KickedThunk mv_deps mv_s) = do
    my_id <- myThreadId
    s <- takeMVar mv_s
    case s of
        -- Thunk and all dependencies are done
        Done -> putMVar mv_s s
        -- Thunk is being processed by a higher priority thread, so simply wait
        ProcessedBy other_id done_mv | other_id < my_id -> do
            putMVar mv_s s
            readMVar done_mv
        -- Thunk is already being processed by this thread (a cycle), ignore
        ProcessedBy other_id _done_mv | other_id == my_id ->
            putMVar mv_s s
        -- Thunk is not yet processed, or processed by a lower priority thread, so process now
        _ -> do
            done_mv <- newEmptyMVar
            putMVar mv_s (ProcessedBy my_id done_mv)
            -- Blocks until kick has run the action and recorded its dependencies
            ts <- readMVar mv_deps
            mapM_ wait ts
            -- Mark kicked thunk as done
            _ <- swapMVar mv_s Done
            -- Wake up waiting threads
            putMVar done_mv ()

-- | Force the execution of the thunk. If it has been forced already, it will
-- do nothing. Else it will run the action passed to 'thunk', force thunks
-- returned by that action, and not return until all of them are forced.
--
-- Works in two phases: first every reachable action is started (kick), then
-- the call blocks until all of them have completed (wait).
force :: Ctxt Thunk_ -> M ()
force t = kick t >>= wait