{-# LANGUAGE RankNTypes #-}

{- |
  Module      :  Control.Monad.Tx
  Copyright   :  (c) Matt Morrow, 2009
  License     :  BSD3
  Maintainer  :  <morrow@moonpatio.com>
  Stability   :  experimental
  Portability :  non-portable (RankNTypes)

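  A transactional state monad. 'begin' runs a transaction body that ends
  by either keeping or reverting its state changes, yielding a result or
  an optional error (see 'TxStat').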
-}


module Control.Monad.Tx (
   TxM,Tx
  ,TxStat(..)
  ,runTxM,runTxM_
  ,begin,abort,dirty,rollback,commit
  ,get,gets,set,modify
  ,test0,runTest0
) where

import Control.Monad (when)

-----------------------------------------------------------------------------

{- |
Run 'test0' as a transaction once for every initial state from @0@ to
@4@, pairing each final transaction status with the final state.

> ghci> mapM_ print runTest0
> (Abort Nothing,0)
> (Abort (Just "1"),1)
> (Dirty (Just "2"),99)
> (Rollback "rollback!",3)
> (Commit "wooo",99)
-}
runTest0 :: [(TxStat String String, Int)]
runTest0 = [ runTxM_ (begin test0) initial | initial <- [0 .. 4] ]

{- |
Example transaction body. It reads the current state @n@ and overwrites
the state with @99@. It then finishes according to @n@:

  * @0@: returns without ending the transaction;

  * @1@: 'abort' with @Just (show n)@;

  * @2@: 'dirty' with @Just (show n)@;

  * @3@: 'rollback' with the string @rollback!@;

  * anything else: 'commit' with the string @wooo@.

See 'runTest0' for the resulting statuses and states.
-}
test0 :: Tx o Int String String -> TxM o Int ()
test0 tx = get >>= \n -> set 99 >> finish n
  where
    finish n
      | n == 0    = return ()
      | n == 1    = abort    tx (Just (show n))
      | n == 2    = dirty    tx (Just (show n))
      | n == 3    = rollback tx "rollback!"
      | otherwise = commit   tx "wooo"

-----------------------------------------------------------------------------

-- | The transaction monad: a state monad in continuation-passing style.
-- A computation receives the current state and a continuation that takes
-- the final state and the result. Because the continuation is an explicit
-- argument, the escape combinators in this module can drop it and resume
-- a different one instead.
newtype TxM o s a = TxM
  { unTxM :: s -> (s -> a -> o) -> o
  }

-- | A handle to a running transaction, given by 'begin' to the
-- transaction body. It wraps the label of the transaction's checkpoint.
-- 'abort', 'dirty', 'rollback' and 'commit' jump to that label.
newtype Tx o s e a = Tx
  { unTx :: Lbl o s (TxStat e a)
  }

-- | Status of a transaction; 'begin' returns the final status.
data TxStat e a
  = Begin                 -- ^ Initial status of a checkpoint; not a final outcome.
  | Abort    (Maybe e)    -- ^ Reverted  state, returned an error.
  | Dirty    (Maybe e)    -- ^ Committed state, returned an error.
  | Rollback a            -- ^ Reverted  state, returned a result.
  | Commit   a            -- ^ Committed state, returned a result.
  deriving (Eq, Ord, Read, Show)

-- | Run a 'TxM' computation from an initial state. The final state and
-- the result are passed to the given continuation.
runTxM :: TxM o s a -> s -> (s -> a -> o) -> o
runTxM (TxM run) = run

-- | Run a 'TxM' computation from an initial state. Returns the result
-- paired with the final state.
runTxM_ :: TxM (a,s) s a -> s -> (a, s)
runTxM_ (TxM run) initial = run initial (flip (,))

-- | Run a transaction. @begin f@ calls @f@ with a handle to the new
-- transaction. @f@ ends the transaction with 'abort', 'dirty',
-- 'rollback' or 'commit', and @begin@ returns the matching 'TxStat'.
--
-- 'abort' and 'rollback' restore the state to what it was when @begin@
-- was entered. 'dirty' and 'commit' keep the state as it is at the
-- moment they are called.
--
-- If @f@ returns without ending the transaction, @begin@ returns
-- @'Abort' 'Nothing'@ and the state is restored.
begin :: (Tx o s e a -> TxM o s ()) -> TxM o s (TxStat e a)
begin f = withRollback (\bailOut -> do
            (stat, lbl) <- checkpoint
            -- The checkpoint first yields 'Begin', so the body runs.
            -- Ending the transaction jumps back to the checkpoint. It then
            -- yields the final status, and we fall through to return it.
            when (isBegin stat)
              (f (Tx lbl) >> bailOut (Abort Nothing))
            return stat)

-- | End the transaction with an optional error. The state is reverted to
-- what it was when the transaction began. Control does not return to the
-- caller.
abort :: Tx o s e a -> Maybe e -> TxM o s ()
abort (Tx lbl) = jump lbl . Abort

-- | End the transaction with an optional error, keeping the current
-- state. Control does not return to the caller.
dirty :: Tx o s e a -> Maybe e -> TxM o s ()
dirty (Tx lbl) = jump lbl . Dirty

-- | End the transaction with a result. The state is reverted to what it
-- was when the transaction began. Control does not return to the caller.
rollback :: Tx o s e a -> a -> TxM o s ()
rollback (Tx lbl) = jump lbl . Rollback

-- | End the transaction with a result, keeping the current state.
-- Control does not return to the caller.
commit :: Tx o s e a -> a -> TxM o s ()
commit (Tx lbl) = jump lbl . Commit

-- | Return the current state.
get :: TxM o s s
get = TxM (\current cont -> cont current current)

-- | Return a projection of the current state, leaving the state unchanged.
gets :: (s -> a) -> TxM o s a
gets project = TxM (\current cont -> cont current (project current))

-- | Replace the current state.
set :: s -> TxM o s ()
set replacement = TxM (\_ cont -> cont replacement ())

-- | Apply a function to the current state.
modify :: (s -> s) -> TxM o s ()
modify update = TxM (\current cont -> cont (update current) ())

-----------------------------------------------------------------------------

-- | A label: a point in a computation that can be resumed. It is invoked
-- with a value and the label itself, so that it can be invoked again.
-- Labels are created by 'checkpoint' and invoked by 'jump'.
newtype Lbl o s a = Lbl
  { unLbl :: (a, Lbl o s a) -> TxM o s ()
  }

-- | Whether a status is the initial 'Begin' status of a checkpoint.
isBegin :: TxStat e a -> Bool
isBegin stat = case stat of
  Begin -> True
  _     -> False

-- | Create a point that can be resumed. The first time through it returns
-- 'Begin' together with a label for this point.
--
-- Jumping to the label later with a final status (see 'jump') makes
-- @checkpoint@ return again with that status and the same label:
--
--   * 'Abort' and 'Rollback' restore the state current when @checkpoint@
--     was entered.
--
--   * 'Dirty' and 'Commit' keep the state current at the time of the jump.
checkpoint :: TxM o s (TxStat e a, Lbl o s (TxStat e a))
checkpoint = withCommit (\keep ->
  withRollback (\revert ->
    let resume p@(stat, _) = case stat of
          -- The exported operations only ever jump with one of the four
          -- final statuses, so this branch is not reached through them.
          Begin      -> error "TODO: nested transactions?"
          Abort _    -> revert p
          Dirty _    -> keep p
          Rollback _ -> revert p
          Commit _   -> keep p
    in return (Begin, Lbl resume)))

-- | Invoke a label with a value, transferring control to the point where
-- the label was created. The label is handed back to itself so the
-- resumed point can be jumped to again.
--
-- Labels made by 'checkpoint' always escape (or call 'error'). The code
-- after the jump is therefore never run, which is why the result type @b@
-- can be anything. The 'error' below documents that unreachable path
-- instead of a bare 'undefined'.
jump :: Lbl o s a -> a -> TxM o s b
jump lbl@(Lbl k) a =
  k (a, lbl) >> error "Control.Monad.Tx.jump: label returned instead of escaping"

-- | Run an action that is given an escape function. Calling the escape
-- with @x@ abandons the rest of the action and makes @withCommit@ return
-- @x@. The state is kept as it is at the moment of the call.
withCommit :: ((forall b. a -> TxM o s b) -> TxM o s a) -> TxM o s a
withCommit f = TxM run
  where
    run entryState k = unTxM (f (escape k)) entryState k
    -- Ignore the continuation at the call site and resume withCommit's
    -- own continuation with the state current at the call.
    escape k x = TxM (\stateNow _ -> k stateNow x)

-- | Run an action that is given an escape function. Calling the escape
-- with @x@ abandons the rest of the action and makes @withRollback@
-- return @x@. The state is reset to what it was when @withRollback@ was
-- entered.
withRollback :: ((forall b. a -> TxM o s b) -> TxM o s a) -> TxM o s a
withRollback f = TxM run
  where
    run entryState k = unTxM (f (escape entryState k)) entryState k
    -- Ignore both the state and the continuation at the call site; resume
    -- withRollback's own continuation with the saved entry state.
    escape savedState k x = TxM (\_ _ -> k savedState x)

-- | Maps over the result; the state is threaded through unchanged.
instance Functor (TxM o s) where
  fmap f (TxM g) = TxM (\s k -> g s (\s' a -> k s' (f a)))

-- | Required superclass of 'Monad' since GHC 7.10 (the Applicative-Monad
-- proposal); without it this module does not compile. Effects run left
-- to right, with the state threaded from the function to the argument.
instance Applicative (TxM o s) where
  pure a = TxM (\s k -> k s a)
  TxM gf <*> TxM gx =
    TxM (\s k -> gf s (\s' f -> gx s' (\s'' x -> k s'' (f x))))

-- | Sequencing passes the intermediate state and result on to the next
-- computation. 'return' uses its default, 'pure'.
instance Monad (TxM o s) where
  TxM g >>= f = TxM (\s k -> g s (\s' a -> unTxM (f a) s' k))
-----------------------------------------------------------------------------