{-# LANGUAGE MultiParamTypeClasses, FlexibleInstances #-}
module Acme.TimeMachine.Undoable (
        Suspension,
        Undoable(..),
        evalUndoable,
        execUndoable,
        suspend,
        resume,
        undo
    )
    where

import Control.Applicative (Applicative(..))
import Control.Monad.State.Class (MonadState(..))
import Acme.TimeMachine.Suspension

-- | The undo-able stateful computation monad.
--
-- A computation is a transition on the 'Suspension' (the recorded history of
-- states): it receives the current history and returns the updated history
-- together with a result.
--
-- Declared as a @newtype@ rather than @data@: there is exactly one
-- constructor with one field, so the wrapper costs nothing at runtime and
-- matching on the 'Undoable' constructor never forces anything.
newtype Undoable s a = Undoable
    { getUndoable :: Suspension s -> (Suspension s, a)
      -- ^ Step function: given the current history, produce the updated
      -- history and the result.
    }

-- | Run an undo-able computation from an initial state and return both the
-- final state and the result.
--
-- The history starts out as @'mkSuspension' s@. Only the newest state of the
-- final history is returned; the rest of the history is discarded.
--
-- > runUndoable (do { modify (+3); return 5 }) 0  ==  (3, 5)
--
-- NOTE(review): 'runUndoable' does not appear in this module's export list,
-- so code outside the module cannot call it even though it is documented
-- like 'evalUndoable' and 'execUndoable' — confirm whether it should be
-- exported.
runUndoable :: Undoable s a -> s -> (s, a)
runUndoable (Undoable step) initial =
    -- Lazy pattern: neither component is forced until the caller demands it.
    case step (mkSuspension initial) of
        ~(Suspension final _, result) -> (final, result)

-- | Run an undo-able computation from an initial state and return only the
-- result, discarding the final state and history.
--
-- Defined via 'runUndoable' so all three runners share one way of starting
-- a computation.
--
-- > evalUndoable (return 5) 0  ==  5
evalUndoable :: Undoable s a -> s -> a
evalUndoable m s = snd (runUndoable m s)

-- | Run an undo-able computation from an initial state and return only the
-- newest state of the final history, discarding the result.
--
-- Defined via 'runUndoable' so all three runners share one way of starting
-- a computation.
--
-- > execUndoable (modify (+3)) 0  ==  3
execUndoable :: Undoable s a -> s -> s
execUndoable m s = fst (runUndoable m s)

-- | Capture the current history as a value so it can be restored later with
-- 'resume'. The history itself is passed through unchanged.
suspend :: Undoable s (Suspension s)
suspend = Undoable (\history -> (history, history))

-- | Replace the current history wholesale with one previously captured by
-- 'suspend'. The incoming history is ignored, so anything recorded since the
-- capture is dropped.
resume :: Suspension s -> Undoable s ()
resume saved = Undoable (const (saved, ()))

-- | Roll back the latest addition to the history. The newest entry is
-- dropped and the history beneath it becomes current.
--
-- NOTE(review): the behaviour when nothing has been recorded yet depends on
-- how 'mkSuspension' builds the initial history — confirm before relying on
-- it.
undo :: Undoable s ()
undo = Undoable dropLatest
  where
    dropLatest (Suspension _ previous) = (previous, ())

-- Apply a function to the result, keeping the history the computation
-- produced. The pair is bound lazily, so nothing is forced early.
instance Functor (Undoable s) where
    fmap g (Undoable run) = Undoable $ \history ->
        let (history', a) = run history
        in  (history', g a)

-- 'pure' leaves the history alone. '<*>' runs the function-producing
-- computation first, then the argument-producing one on the history it left
-- behind, and applies the function to the argument. All pairs are bound
-- lazily.
instance Applicative (Undoable s) where
    pure x = Undoable (\history -> (history, x))
    Undoable runF <*> Undoable runX = Undoable $ \h0 ->
        let (h1, g) = runF h0
            (h2, x) = runX h1
        in  (h2, g x)

-- Sequence two computations, threading the history left by the first into
-- the second.
--
-- 'return' is given its canonical definition in terms of 'Applicative', and
-- '(>>)' is left to the class default (@m >> k = m >>= \_ -> k@). GHC's
-- @-Wnoncanonical-monad-instances@ warns about any other definitions of
-- these two methods.
instance Monad (Undoable s) where
    return = pure
    Undoable run >>= k = Undoable $ \history ->
        -- Lazy binding, matching the lazy patterns used by the runners.
        let (history', a) = run history
        in  getUndoable (k a) history'

-- 'get' reads the newest state in the history without changing it. 'put'
-- pushes a new state on top of the existing history, so a later 'undo' can
-- bring the previous state back.
instance MonadState s (Undoable s) where
    get = Undoable $ \history -> case history of
        Suspension current _ -> (history, current)
    put new = Undoable (\history -> (Suspension new history, ()))