-- |
-- Module:     Control.ContStuff.Classes
-- Copyright:  (c) 2010 Ertugrul Soeylemez
-- License:    BSD3
-- Maintainer: Ertugrul Soeylemez <es@ertes.de>
-- Stability:  experimental
--
-- This module implements the various effect classes supported by
-- contstuff.

{-# LANGUAGE MultiParamTypeClasses, TypeFamilies #-}

module Control.ContStuff.Classes
    ( -- * Effect classes
      -- ** Abortion
      Abortable(..),
      -- ** Call with current continuation
      CallCC(..), Label, labelCC, goto,
      -- ** Exceptions
      HasExceptions(..), bracket, bracket_, catch, finally, forbid,
                         handle, raiseUnless, raiseWhen, require,
      -- ** Lifting
      Transformer(..),
      LiftBase(..), io,
      -- ** Running
      Runnable(..),
      -- ** State
      Stateful(..), getField, modify, modifyField, modifyFieldLazy,
                    modifyLazy,
      -- ** Logging support (writers)
      Writable(..)
    )
    where

import Control.Applicative
import Control.Monad
import Prelude hiding (catch)


--------------
-- Abortion --
--------------

-- | Monads supporting abortion, i.e. dropping the rest of the
-- computation and finishing with a given end result.

class Abortable m where
    -- | End result of the computation.  This is the type of value that
    -- 'abort' takes.
    type Result m

    -- | Ignore current continuation and abort.  Since the continuation
    -- is ignored, the returned computation is polymorphic in @a@.
    abort :: Result m -> m a


------------
-- CallCC --
------------

-- | Monads supporting /call-with-current-continuation/ (aka callCC).

class CallCC m where
    -- | Call with current continuation.  The argument function receives
    -- the current continuation as a function of type @a -> m b@.
    callCC :: ((a -> m b) -> m a) -> m a


-- | A label as produced by 'labelCC' and consumed by 'goto'.  It wraps
-- a function that takes a value of type @a@ together with a label of
-- the same type.

newtype Label m a = Label (a -> Label m a -> m ())


-- | Capture the current continuation as a 'Label'.  The result pairs
-- the given value with the label.  Passing the label to 'goto' along
-- with a new value invokes the captured continuation again, this time
-- with the new value as the first component of the pair.

labelCC :: (Applicative m, CallCC m) => a -> m (a, Label m a)
labelCC x = callCC (\k -> pure (x, Label (\y lbl -> k (y, lbl))))


-- | Jump to a label, handing it a new value.  The label itself is
-- passed along as well, so the same label is available again at the
-- jump target.

goto :: Label m a -> a -> m ()
goto label x = case label of
    Label jump -> jump x label


----------------
-- Exceptions --
----------------

-- | Monads with exception support.

class HasExceptions m where
    -- | The exception type.
    type Exception m

    -- | Raise an exception.  The resulting computation is polymorphic
    -- in its result type @a@.
    raise :: Exception m -> m a

    -- | Run computation catching exceptions.  An exception is returned
    -- as a 'Left' value, a regular result as a 'Right' value.
    try :: m a -> m (Either (Exception m) a)


-- | Get a resource, run a computation, then release the resource, even
-- if an exception is raised:
--
-- > bracket acquire release use
--
-- Please note that this function behaves slightly differently from the
-- usual 'Control.Exception.bracket'.  Exceptions raised by the releaser
-- are caught and discarded, so if both the user and the releaser raise
-- an exception, the user exception is significant.  An exception raised
-- by the acquiring computation is not caught here.

bracket :: (HasExceptions m, Monad m) => m res -> (res -> m b) -> (res -> m a) -> m a
bracket acquire release use = do
    resource <- acquire
    outcome <- try (use resource)
    -- The releaser's outcome (result or exception) is ignored on purpose.
    _ <- try (release resource)
    either raise return outcome


-- | Initialize, then run, then clean up safely, even if an exception is
-- raised:
--
-- > bracket_ setup cleanup run
--
-- Please note that this function behaves slightly differently from the
-- usual 'Control.Exception.bracket_'.  Exceptions raised by the cleanup
-- computation are caught and discarded, so if both the user and the
-- releaser raise an exception, the user exception is significant.  An
-- exception raised by the initializing computation is not caught here.

bracket_ :: (HasExceptions m, Monad m) => m a -> m b -> m c -> m c
bracket_ setup cleanup run = do
    -- The initializer's result is not needed.
    _ <- setup
    outcome <- try run
    -- The cleanup's outcome (result or exception) is ignored on purpose.
    _ <- try cleanup
    either raise return outcome


-- | Run a computation under 'try'.  If it raised an exception, pass
-- that exception to the given handler; otherwise return its result.

catch :: (HasExceptions m, Monad m) => m a -> (Exception m -> m a) -> m a
catch c h = do
    outcome <- try c
    either h return outcome


-- | Run a final computation regardless of whether the first
-- computation raised an exception.  An exception from the first
-- computation is re-raised after the final computation has run.  The
-- final computation itself is not run under 'try', so an exception it
-- raises is not caught here.

finally :: (HasExceptions m, Monad m) => m a -> m b -> m a
finally c d = do
    outcome <- try c
    case outcome of
        Left ex -> d >> raise ex
        Right x -> d >> return x


-- | Fail (in the sense of the given transformer), if the given
-- underlying computation returns 'True'.  The computation is lifted
-- into the transformer and @()@ is raised as the exception.

forbid ::
    ( Exception (t m) ~ (), HasExceptions (t m),
      Monad m, Monad (t m), Transformer t ) =>
    m Bool -> t m ()
forbid check = raiseWhen () (lift check)


-- | Catch exceptions using an exception handler.  This is 'catch' with
-- its arguments flipped: the handler comes first, the computation
-- second.

handle :: (HasExceptions m, Monad m) => (Exception m -> m a) -> m a -> m a
handle h c = either h return =<< try c


-- | Run the given computation and raise the given exception, if the
-- computation returns 'False'.  Otherwise do nothing.

raiseUnless :: (HasExceptions m, Monad m) => Exception m -> m Bool -> m ()
raiseUnless ex c = c >>= \ok -> unless ok (raise ex)


-- | Run the given computation and raise the given exception, if the
-- computation returns 'True'.  Otherwise do nothing.

raiseWhen :: (HasExceptions m, Monad m) => Exception m -> m Bool -> m ()
raiseWhen ex c = c >>= \hit -> when hit (raise ex)


-- | Fail (in the sense of the given transformer), if the given
-- underlying computation returns 'False'.  The computation is lifted
-- into the transformer and @()@ is raised as the exception.

require ::
    ( Exception (t m) ~ (), HasExceptions (t m),
      Monad m, Monad (t m), Transformer t ) =>
    m Bool -> t m ()
require check = raiseUnless () (lift check)


-------------
-- Lifting --
-------------

-- | Monads that support lifting base monad computations.

class LiftBase m where
    -- | Base monad of @m@.
    type Base m :: * -> *

    -- | Promote a base monad computation to a computation in @m@.
    base :: Base m a -> m a


-- | Handy alias for lifting 'IO' computations: this is 'base'
-- restricted to monads whose base monad is 'IO'.

io :: (LiftBase m, Base m ~ IO) => Base m a -> m a
io action = base action


-------------
-- Running --
-------------

-- | Every monad transformer @t@ that supports transforming @t m a@ to
-- @m a@ can be an instance of this class.
--
-- The class parameters are the transformer @t@, the type @r@ of the
-- value produced by running, the underlying monad @m@ and the result
-- type @a@ of the transformed computation.

class Runnable t r m a where
    -- | Arguments needed to run.
    type Argument t r m a

    -- | Run the transformer: given the arguments and a computation in
    -- @t m@, produce a computation in the underlying monad @m@.
    runT :: Argument t r m a -> t m a -> m r


-----------
-- State --
-----------

-- | Stateful monads.
--
-- Minimal complete definition: 'StateOf', 'get' and 'putLazy'.

class Stateful m where
    -- | State type of @m@.
    type StateOf m

    -- | Retrieve the current state.
    get :: m (StateOf m)

    -- | Replace the current state and force the new one.  The default
    -- implementation evaluates the new state to weak head normal form
    -- and hands it to 'putLazy'.
    put :: StateOf m -> m ()
    put x = putLazy $! x

    -- | Replace the current state without forcing the new one.
    putLazy :: StateOf m -> m ()


-- | Get a certain field: apply the given accessor function to the
-- current state and return the result.

getField :: (Functor m, Stateful m) => (StateOf m -> a) -> m a
getField accessor = fmap accessor get


-- | Apply a function to the current state and store the result using
-- 'put'.

modify :: (Monad m, Stateful m) => (StateOf m -> StateOf m) -> m ()
modify f = do
    s <- get
    put (f s)


-- | Get a field and modify the state: the accessor extracts a field
-- from the current state, the second function builds a new state from
-- that field, and the new state is stored using 'put'.

modifyField :: (Monad m, Stateful m) =>
               (StateOf m -> a) -> (a -> StateOf m) -> m ()
modifyField accessor f = do
    s <- get
    put (f (accessor s))


-- | Get a field and modify the state.  Lazy version: the accessor
-- extracts a field from the current state, the second function builds
-- a new state from that field, and the new state is stored using
-- 'putLazy'.

modifyFieldLazy :: (Monad m, Stateful m) =>
                   (StateOf m -> a) -> (a -> StateOf m) -> m ()
modifyFieldLazy accessor f = do
    s <- get
    putLazy (f (accessor s))


-- | Apply a function to the current state.  Lazy version: the result
-- is stored using 'putLazy'.

modifyLazy :: (Monad m, Stateful m) => (StateOf m -> StateOf m) -> m ()
modifyLazy f = do
    s <- get
    putLazy (f s)


--------------------------
-- Monad transformation --
--------------------------

-- | The monad transformer class.  Its single method lifts a
-- computation one level up the monad stack, or stated differently, it
-- promotes a computation of the underlying monad to the transformer.

class Transformer t where
    -- | Promote a monadic computation to the transformer.
    lift :: Monad m => m a -> t m a


-------------
-- Logging --
-------------

-- | Monads with support for logging.  Traditionally these are called
-- /writer monads/.  The class relates a monad @m@ to a type @w@ of
-- values that can be logged in it.

class Writable m w where
    -- | Log a value.
    tell :: w -> m ()