{-# LANGUAGE DeriveDataTypeable #-}

{-|
This module provides facilities for building 'IO' actions in such a way that, if one 'IO' action in a sequence
throws an exception, the effects of previous actions will be undone.

Here's an example of how to use this module.  Suppose you have two files that, every so often, must be updated
from some external data source.  The new contents for a particular file are retrieved from the external data
source via a function @getNewContents :: 'FilePath' -> 'IO' 'String'@.  @getNewContents@ could throw an exception, as 
could any of the other 'IO' actions that we invoke, and if an exception is thrown while the files are being updated,
we want all changes made so far to either of the files to be rolled back.  Using this module, we could do this thus:

> import Control.IoTransaction
> import System.IO
> import System.FilePath.Posix
> 
> getNewContents :: FilePath -> IO String
> getNewContents path = ...
> 
> updateFile :: FilePath -> UndoableIO ()
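> updateFile path = do -- get current contents; force the whole string so the
>                      -- lazily-read file is closed before it is overwritten
>                      oldContents <- doAction $ readFile path >>= \s -> length s `seq` return s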
> 
>                      -- get new contents from external data source
>                      newContents <- doAction $ getNewContents path
> 
>                      -- write new contents
>                      doAction $ writeFile path newContents
> 
>                      -- add an undo action that rewrites the old contents
>                      addUndoer $ writeFile path oldContents
> 
> main :: IO ()
> main = exec $ do updateFile "file1"
>                  updateFile "file2"

In this code, we use the following from this module: the 'UndoableIO' monad, and the functions 'doAction', 'addUndoer',
and 'exec'.  The 'UndoableIO' monad is like a context for combining 'IO' actions together into a transaction.  Inside
'UndoableIO', we invoke 'IO' actions using the function 'doAction'.  When executed, these actions will be sequenced as they would
be if they had been combined as usual inside the 'IO' monad.  When we invoke an 'IO' action whose effect should be undone if an
exception occurs later, we add an \"undoer\" --- that is, an 'IO' action that undoes the effect --- using the function
'addUndoer'.  'UndoableIO' maintains a stack of undoers, and if an exception occurs during execution, the undoers will
be executed in the reverse of the order in which they were added, and then the exception will be rethrown.

So, in @updateFile@ we use 'doAction' to call 'IO' actions that read from and write to the files and retrieve strings from the 
external data source, and at the end we add an undoer that restores the original contents.  In @main@, we combine the 'UndoableIO'
actions returned by two calls to @updateFile@ into one.  We pass the resulting 'UndoableIO' action to 'exec', which
converts it into an 'IO' action.  If an exception occurs when this 'IO' action is executed, then any changes so far made to the
files will be undone using the undoers added by @updateFile@.
-}
module Control.IoTransaction(
-- * IO Transactions
UndoableIO, doAction, addUndoer, exec, rollback, makeUndoable, ManualUndo(ManualUndo),
-- * Internal Stuff
UndoableM(Do), ExceptionalMonad(throwM, catchM), doActionM, addUndoerM, execM, rollbackM, makeUndoableM
) where

import Control.Applicative (Applicative (..))
import qualified Control.Exception as C
import Data.Typeable.Internal

-- * Impl

{-| A monad for combining other, side-effectual monads in a transaction
that can be rolled back if an exception is thrown.  @m@ must implement
'ExceptionalMonad'.

A value @Do op@ wraps an action @op@ that, when run, yields a result paired
with the undoer accumulated for everything @op@ did.

This type is for implementing transactions and should not be used directly
by code that uses transactions.
-}
-- A newtype rather than data: a single-field wrapper has no runtime cost.
newtype UndoableM m a = Do (m (a, m ()))

-- | Unwrap an 'UndoableM', exposing the underlying action that yields the
-- result paired with its accumulated undoer.
runUndoableM :: UndoableM m a -> m (a, m ())
runUndoableM u = case u of
    Do op -> op

{-| Run an 'UndoableM' in its underlying monad and return its result.

On success the accumulated undoer is discarded.  If a later step throws, the
'Monad' instance's '>>=' has already run the undoers of the steps that
completed before the failure, and the exception propagates to the caller.
-}
execM :: Monad m => UndoableM m a -> m a
execM u = runUndoableM u >>= \(val, _) -> return val

{-| An @ExceptionalMonad@ is a monad in which 'C.Exception's can be thrown and caught.

A monad @m@ must implement @ExceptionalMonad@ in order to work with 'UndoableM'.
-}
class Monad m => ExceptionalMonad m where
    -- | Throw an exception of any 'C.Exception' type inside @m@.
    throwM :: C.Exception ex => ex -> m r
    -- | Run an action; if it throws an exception of type @ex@, run the handler on it instead.
    catchM :: C.Exception ex => m r -> (ex -> m r) -> m r

-- Mapping over the result leaves the accumulated undoer untouched.
instance Monad m => Functor (UndoableM m) where
    fmap f (Do op) = Do $ do
        (val, undo) <- op
        return (f val, undo)

instance ExceptionalMonad m => Applicative (UndoableM m) where
    -- A pure value performs no effects, so it contributes no undoer.
    pure val = Do $ return (val, return ())

    -- Defined via '>>=' so '<*>' agrees with the Monad instance, including
    -- rollback of earlier steps when a later one throws.
    mf <*> mx = mf >>= \f -> mx >>= \x -> pure (f x)

instance ExceptionalMonad m => Monad (UndoableM m) where
    {-
    We combine two UndoableMs by making a new one that executes the
    first one and then tries to execute the second; if the first fails,
    the UndoableM fails; if the second fails, the UndoableM undoes the first and
    then fails.  The new UndoableM returns the value returned by the second
    along with an undo action that is a combination of the undo actions of the
    first and second UndoableMs.

    The combined undo action runs the second undoer before the first
    (undo' >> undo).  This keeps the undoers a stack, unwound newest-first,
    whether binds are nested to the left or to the right.  Composing them as
    undo >> undo' would run left-nested undoers oldest-first, and would make
    (mv >>= f) >>= g and mv >>= (\x -> f x >>= g) roll back in different
    orders.
    -}
    Do op >>= f = Do $ do
        (val, undo) <- op
        -- If the continuation throws anything, undo this step and re-raise.
        (val', undo') <- runUndoableM (f val)
                           `catchM` (\e -> undo >> throwM (e :: C.SomeException))
        -- Newest undoer first, so the undo stack unwinds in reverse order.
        return (val', undo' >> undo)

    return = pure

-- | Pair an action (the doer) with the action that reverses it (the undoer).
-- The undoer is returned only once the doer has completed without throwing.
makeUndoableM :: ExceptionalMonad m => m a -> m () -> UndoableM m a
makeUndoableM op undo = Do (op >>= \result -> return (result, undo))

-- | Lift an action into 'UndoableM' without registering any undoer for it.
doActionM :: ExceptionalMonad m => m a -> UndoableM m a
doActionM = flip makeUndoableM (return ())

-- | Register an undoer without performing any other action.
addUndoerM :: ExceptionalMonad m => m () -> UndoableM m ()
addUndoerM undoer = makeUndoableM (return ()) undoer

-- | The exception thrown by 'rollbackM' (and so by 'rollback') to abort a
-- transaction deliberately.
data ManualUndo = ManualUndo
    deriving (Typeable, Show)

instance C.Exception ManualUndo

-- | Abort the transaction by throwing 'ManualUndo'.  When this is not the
-- first step, the handler installed by '>>=' in the 'Monad' instance for
-- 'UndoableM' catches it, runs the undoers of the steps completed so far,
-- and rethrows it.
rollbackM :: ExceptionalMonad m => UndoableM m ()
rollbackM = doActionM (throwM ManualUndo)

{-|
An \"undoable action\" is a wrapper for an 'IO' action (the \"doer\") that combines it with another
'IO' action (the \"undoer\") that undoes the effects of the first one.

Undoable actions are monads, and when sequenced together they act like transactions involving 'IO' operations.
As undoable actions are sequenced together, their doers are also sequenced together and their undoers
are placed into a stack.  When the doers are executed, if one of them throws an exception, the undoers
so far added to the stack are executed in the reverse of the order in which they were added to the stack,
and then the exception is rethrown; no other doers (or undoers) are executed.  If no exception is thrown, none of
the undoers are executed.
-}
-- Eta-reduced so the synonym can also be used unapplied; @UndoableIO a@ still works.
type UndoableIO = UndoableM IO

-- | In 'IO', exceptions are thrown with 'C.throwIO' and caught with 'C.catch'.
instance ExceptionalMonad IO where
    -- throwIO (not throw) so the exception is raised when the action runs,
    -- ordered with respect to the other IO effects.
    throwM ex = C.throwIO ex
    catchM action handler = action `C.catch` handler

-- | 'runUndoableM' specialised to 'IO'.
-- NOTE(review): not exported and not referenced elsewhere in this module.
runUndoable :: UndoableIO a -> IO (a, IO ())
runUndoable u = runUndoableM u

{-|
Turn an 'UndoableIO' transaction into a plain 'IO' action.  Running it
performs the transaction's actions in order.  If one of them throws, the
undoers registered so far are run and the exception is rethrown.
-}
exec :: UndoableIO a -> IO a
exec transaction = execM transaction

-- | Build an undoable action from a doer and the undoer that reverses it.
makeUndoable :: IO a -- ^ The \"doer\": the action to perform.
             -> IO () -- ^ An \"undoer\": an action that undoes the effect of the other one.
             -> UndoableIO a
makeUndoable doer undoer = makeUndoableM doer undoer

{-|
Wrap an 'IO' action as an undoable action with no undoer.

Nothing is pushed onto the undoer stack for this action.
-}
doAction :: IO a -- ^ The \"doer\": the action to perform.
         -> UndoableIO a
doAction action = doActionM action

-- | Push an undoer onto the undoer stack without performing any other action.
addUndoer :: IO () -- ^ An \"undoer\": an action that will be added to the undoer stack.
          -> UndoableIO ()
addUndoer undoer = addUndoerM undoer

-- | Abort the transaction on purpose: stop execution, run the undoers added
-- so far, and throw 'ManualUndo'.
rollback :: UndoableIO ()
rollback = rollbackM

{-
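UndoableIO satisfies the first two monad laws; the third (associativity) has not been verified.
The derivation below is unfinished and uses names from an earlier version of this module (e.g. runUio, ioError).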

Monad Laws:
1. return x >>= f    ==  f x
2. mv >>= return     ==  mv
3. (mv >>= f) >>= g  ==  mv >>= (\x -> (f x >>= g))

let gK = \(v, u) -> runUio (g v) `C.catch` (\e -> u >> ioError e)
let fK = \(v, u) -> runUio (f v) `C.catch` (\e -> u >> ioError e)
let J = \u -> \(val', undoIo') -> return (val', (u >> undoIo'))

// (Do io >>= f) >>= g  ==  Do io >>= (\x -> (f x >>= g))
\x -> (f x >>= g) == \x -> Do $ runUio (f x) >>= \(val, undoIo) -> gK (val, undoIo) >>= J undoIo

(Do io >>= f) >>= g == Do io' >>= g
    where io' = io >>= \(val, undoIo) -> fK (val, undoIo) >>= J undoIo
          
Do io' >>= g == Do $ (io >>= \(val, undoIo) -> fK (val, undoIo) >>= J undoIo) >>= \(val, undoIo) -> gK (val, undoIo) >>= J undoIo

Do io >>= (\x -> (f x >>= g)) == Do io >>= (\x -> Do $ runUio (f x) >>= \(val, undoIo) -> gK (val, undoIo) >>= J undoIo)
    == Do $ do (v, u) <- io
               (v', u') <- runUio (Do $ runUio (f v) >>= \(val, undoIo) -> gK (val, undoIo) >>= J undoIo) `C.catch` (\e -> undoIo >> ioError e)
               return (v', (u >> u'))
    == Do $ io >>= \(v, u) -> 
-}