{-# LANGUAGE CPP #-}
#if __GLASGOW_HASKELL__ >= 702
{-# LANGUAGE Safe #-}
{-# LANGUAGE DeriveGeneric #-}
#endif
#if __GLASGOW_HASKELL__ >= 710 && __GLASGOW_HASKELL__ < 802
{-# LANGUAGE AutoDeriveTypeable #-}
#endif
-----------------------------------------------------------------------------
-- |
-- Module      :  Control.Monad.Trans.Writer.CPS
-- Copyright   :  (c) Daniel Mendler 2016,
--                (c) Andy Gill 2001,
--                (c) Oregon Graduate Institute of Science and Technology, 2001
-- License     :  BSD-style (see the file LICENSE)
--
-- Maintainer  :  R.Paterson@city.ac.uk
-- Stability   :  experimental
-- Portability :  portable
--
-- The strict 'WriterT' monad transformer, which adds collection of
-- outputs (such as a count or string output) to a given monad.
--
-- This monad transformer provides only limited access to the output
-- during the computation. For more general access, use
-- "Control.Monad.Trans.State" instead.
--
-- This version builds its output strictly and uses continuation-passing-style
-- to achieve constant space usage. This transformer can be used as a
-- drop-in replacement for "Control.Monad.Trans.Writer.Strict".
-----------------------------------------------------------------------------

module Control.Monad.Trans.Writer.CPS (
    -- * The Writer monad
    Writer,
    writer,
    runWriter,
    execWriter,
    mapWriter,
    -- * The WriterT monad transformer
    WriterT,
    writerT,
    runWriterT,
    execWriterT,
    mapWriterT,
    -- * Writer operations
    tell,
    listen,
    listens,
    pass,
    censor,
    -- * Lifting other operations
    liftCallCC,
    liftCatch,
  ) where

import Control.Applicative
import Control.Monad
import Control.Monad.Fix
import Control.Monad.IO.Class
import Control.Monad.Trans.Class
import Control.Monad.Signatures
import Data.Functor.Identity

#if !(MIN_VERSION_base(4,8,0))
import Data.Monoid
#endif

#if MIN_VERSION_base(4,9,0)
import qualified Control.Monad.Fail as Fail
#endif
#if __GLASGOW_HASKELL__ >= 704
import GHC.Generics
#endif

-- ---------------------------------------------------------------------------
-- | A writer monad parameterized by the type @w@ of output to accumulate.
-- It is 'WriterT' specialised to the 'Identity' inner monad.
--
-- The 'return' function produces the output 'mempty', while @m '>>=' k@
-- combines the outputs of the subcomputations using 'mappend' (also
-- known as @<>@):
--
-- <<images/bind-WriterT.svg>>
--
-- The synonym is deliberately not applied to a result type, so that
-- @Writer w@ can itself be used wherever a type of kind @* -> *@ (such as
-- an inner monad) is expected.
type Writer w = WriterT w Identity

-- | Construct a writer computation from a (result, output) pair.
-- (The inverse of 'runWriter'.)
--
-- The given output is appended (with 'mappend') to the output accumulated
-- so far.  The new accumulator is forced to weak head normal form before
-- the result is returned, which keeps the accumulator from building up
-- a chain of thunks.
writer :: (Monoid w, Monad m) => (a, w) -> WriterT w m a
writer (result, output) = WriterT step
  where
    -- Extend the running accumulator with this computation's output.
    step acc =
        let acc' = acc `mappend` output
        in acc' `seq` return (result, acc')
{-# INLINE writer #-}

-- | Unwrap a writer computation as a (result, output) pair.
-- (The inverse of 'writer'.)
--
-- This runs the computation with 'runWriterT' and then strips the
-- 'Identity' wrapper.
runWriter :: (Monoid w) => Writer w a -> (a, w)
runWriter m = runIdentity (runWriterT m)
{-# INLINE runWriter #-}

-- | Extract the output from a writer computation, discarding its result.
--
-- * @'execWriter' m = 'snd' ('runWriter' m)@
execWriter :: (Monoid w) => Writer w a -> w
execWriter m = runIdentity $ execWriterT m
{-# INLINE execWriter #-}

-- | Map both the return value and output of a computation using
-- the given function.
--
-- * @'runWriter' ('mapWriter' f m) = f ('runWriter' m)@
mapWriter :: (Monoid w, Monoid w') =>
    ((a, w) -> (b, w')) -> Writer w a -> Writer w' b
mapWriter f = mapWriterT overIdentity
  where
    -- Apply @f@ underneath the 'Identity' wrapper.  'Identity' is a
    -- newtype, so this pattern never forces anything.
    overIdentity (Identity p) = Identity (f p)
{-# INLINE mapWriter #-}

-- ---------------------------------------------------------------------------
-- | A writer monad parameterized by:
--
--   * @w@ - the output to accumulate.
--
--   * @m@ - The inner monad.
--
-- The 'return' function produces the output 'mempty', while @m '>>=' k@
-- combines the outputs of the subcomputations using 'mappend' (also
-- known as @<>@):
--
-- <<images/bind-WriterT.svg>>
--
-- A computation is represented as a function from the output accumulated
-- so far to an inner-monad action yielding the result together with the
-- extended accumulator.
newtype WriterT w m a = WriterT
    { unWriterT :: w -> m (a, w)
      -- ^ Run the computation, starting from the given accumulated output.
    }
#if __GLASGOW_HASKELL__ >= 704
    deriving (Generic)
#endif

-- | Construct a writer computation from a (result, output) computation.
-- (The inverse of 'runWriterT'.)
--
-- The output produced by the inner action is appended to the accumulated
-- output.  The resulting pair is only built after the new accumulator has
-- been forced to weak head normal form.
writerT :: (Functor m, Monoid w) => m (a, w) -> WriterT w m a
writerT f = WriterT (\ w -> fmap (accumulate w) f)
  where
    -- Append the inner action's output @w'@ to the accumulator @acc@.
    accumulate acc (a, w') =
        let acc' = acc `mappend` w'
        in acc' `seq` (a, acc')
{-# INLINE writerT #-}

-- | Unwrap a writer computation.
-- (The inverse of 'writerT'.)
--
-- The computation is started with an empty ('mempty') accumulator.
runWriterT :: (Monoid w) => WriterT w m a -> m (a, w)
runWriterT = (`unWriterT` mempty)
{-# INLINE runWriterT #-}

-- | Extract the output from a writer computation, discarding its result.
--
-- * @'execWriterT' m = 'liftM' 'snd' ('runWriterT' m)@
execWriterT :: (Monad m, Monoid w) => WriterT w m a -> m w
-- The tuple is matched strictly as soon as the inner action has run.
execWriterT m = runWriterT m >>= \ (_, w) -> return w
{-# INLINE execWriterT #-}

-- | Map both the return value and output of a computation using
-- the given function.
--
-- * @'runWriterT' ('mapWriterT' f m) = f ('runWriterT' m)@
--
-- @f@ sees the output of @m@ alone, because @m@ is run from 'mempty'.
-- The output @f@ produces is then appended to the surrounding accumulator.
mapWriterT :: (Monad n, Monoid w, Monoid w') =>
    (m (a, w) -> n (b, w')) -> WriterT w m a -> WriterT w' n b
mapWriterT f m = WriterT $ \ w -> f (runWriterT m) >>= accumulate w
  where
    -- Append the mapped output to the accumulator and force it before
    -- returning, so the accumulator does not build up thunks.
    accumulate acc (b, w') =
        let acc' = acc `mappend` w'
        in acc' `seq` return (b, acc')
{-# INLINE mapWriterT #-}

-- Mapping transforms the result and passes the accumulated output
-- through unchanged.
instance (Functor m) => Functor (WriterT w m) where
    fmap f m = WriterT (fmap onResult . unWriterT m)
      where
        onResult (a, w') = (f a, w')
    {-# INLINE fmap #-}

-- The accumulator is threaded from the function computation into the
-- argument computation; effects run left to right.
instance (Functor m, Monad m) => Applicative (WriterT w m) where
    pure a = WriterT (\ w -> return (a, w))
    {-# INLINE pure #-}

    mf <*> mx = WriterT $ \ w0 ->
        unWriterT mf w0 >>= \ (f, w1) ->
        unWriterT mx w1 >>= \ (x, w2) ->
        return (f x, w2)
    {-# INLINE (<*>) #-}

-- Both alternatives are started from the same incoming accumulator and
-- combined with the inner monad's 'mplus'.
instance (Functor m, MonadPlus m) => Alternative (WriterT w m) where
    empty = WriterT (\ _ -> mzero)
    {-# INLINE empty #-}

    l <|> r = WriterT (\ w -> mplus (unWriterT l w) (unWriterT r w))
    {-# INLINE (<|>) #-}

-- Bind threads the accumulator from the first computation into the
-- continuation; the bind itself performs no 'mappend'.
instance (Monad m) => Monad (WriterT w m) where
#if !(MIN_VERSION_base(4,8,0))
    return a = WriterT (\ w -> return (a, w))
    {-# INLINE return #-}
#endif

    m >>= k = WriterT $ \ w ->
        unWriterT m w >>= \ (a, w') -> unWriterT (k a) w'
    {-# INLINE (>>=) #-}

#if !(MIN_VERSION_base(4,13,0))
    fail msg = WriterT (\ _ -> fail msg)
    {-# INLINE fail #-}
#endif

#if MIN_VERSION_base(4,9,0)
-- Failure ignores the accumulated output and fails in the inner monad.
instance (Fail.MonadFail m) => Fail.MonadFail (WriterT w m) where
    fail msg = WriterT (const (Fail.fail msg))
    {-# INLINE fail #-}
#endif

-- 'mzero' and 'mplus' delegate to the 'Alternative' instance.
instance (Functor m, MonadPlus m) => MonadPlus (WriterT w m) where
    mzero = empty
    {-# INLINE mzero #-}
    mplus l r = l <|> r
    {-# INLINE mplus #-}

-- The fixed point is taken over the result component only.  The lazy
-- (irrefutable) pattern stops the recursive result from being forced
-- before the inner 'mfix' has produced it.
instance (MonadFix m) => MonadFix (WriterT w m) where
    mfix f = WriterT (mfix . step)
      where
        step w ~(a, _) = unWriterT (f a) w
    {-# INLINE mfix #-}

-- Lifting runs the inner action and leaves the accumulator unchanged.
instance MonadTrans (WriterT w) where
    lift m = WriterT (\ w -> m >>= \ a -> return (a, w))
    {-# INLINE lift #-}

-- An IO action is first lifted into the inner monad, then through 'WriterT'.
instance (MonadIO m) => MonadIO (WriterT w m) where
    liftIO io = lift (liftIO io)
    {-# INLINE liftIO #-}

-- | @'tell' w@ is an action that produces the output @w@.
--
-- It is 'writer' applied to a pair whose result is @()@.
tell :: (Monoid w, Monad m) => w -> WriterT w m ()
tell output = writer ((), output)
{-# INLINE tell #-}

-- | @'listen' m@ is an action that executes the action @m@ and adds its
-- output to the value of the computation.
--
-- * @'runWriterT' ('listen' m) = 'liftM' (\\ (a, w) -> ((a, w), w)) ('runWriterT' m)@
--
-- It is 'listens' with the identity function.
listen :: (Monoid w, Monad m) => WriterT w m a -> WriterT w m (a, w)
listen m = listens id m
{-# INLINE listen #-}

-- | @'listens' f m@ is an action that executes the action @m@ and adds
-- the result of applying @f@ to the output to the value of the computation.
--
-- * @'listens' f m = 'liftM' (id *** f) ('listen' m)@
--
-- * @'runWriterT' ('listens' f m) = 'liftM' (\\ (a, w) -> ((a, f w), w)) ('runWriterT' m)@
listens :: (Monoid w, Monad m) =>
    (w -> b) -> WriterT w m a -> WriterT w m (a, b)
listens f m = WriterT $ \ w -> runWriterT m >>= accumulate w
  where
    -- @w'@ is the output of @m@ alone (it was run from 'mempty').  @f@ sees
    -- only that part, while the accumulator receives all of it and is
    -- forced before returning.
    accumulate acc (a, w') =
        let acc' = acc `mappend` w'
        in acc' `seq` return ((a, f w'), acc')
{-# INLINE listens #-}

-- | @'pass' m@ is an action that executes the action @m@, which returns
-- a value and a function, and returns the value, applying the function
-- to the output.
--
-- * @'runWriterT' ('pass' m) = 'liftM' (\\ ((a, f), w) -> (a, f w)) ('runWriterT' m)@
pass :: (Monoid w, Monoid w', Monad m) =>
    WriterT w m (a, w -> w') -> WriterT w' m a
pass m = WriterT $ \ w -> runWriterT m >>= accumulate w
  where
    -- @m@ runs in isolation (from 'mempty').  Its output @w'@ is
    -- transformed by the returned function before being appended to the
    -- outer accumulator, which is then forced.
    accumulate acc ((a, f), w') =
        let acc' = acc `mappend` f w'
        in acc' `seq` return (a, acc')
{-# INLINE pass #-}

-- | @'censor' f m@ is an action that executes the action @m@ and
-- applies the function @f@ to its output, leaving the return value
-- unchanged.
--
-- * @'censor' f m = 'pass' ('liftM' (\\ x -> (x,f)) m)@
--
-- * @'runWriterT' ('censor' f m) = 'liftM' (\\ (a, w) -> (a, f w)) ('runWriterT' m)@
censor :: (Monoid w, Monad m) => (w -> w) -> WriterT w m a -> WriterT w m a
censor f m = WriterT $ \ w -> runWriterT m >>= accumulate w
  where
    -- Only @m@'s own output (run from 'mempty') is passed through @f@.
    -- The output accumulated before @m@ is left untouched.
    accumulate acc (a, w') =
        let acc' = acc `mappend` f w'
        in acc' `seq` return (a, acc')
{-# INLINE censor #-}

-- | Uniform lifting of a @callCC@ operation to the new monad.
-- This version rolls back to the original state on entering the
-- continuation.
--
-- When the captured continuation is invoked, it resumes with the output
-- that had been accumulated when 'liftCallCC' was entered.  Any output
-- written after that point is dropped.
liftCallCC :: CallCC m (a, w) (b, w) -> CallCC (WriterT w m) a b
liftCallCC callCC f = WriterT $ \ w ->
    let escape c a = WriterT (\ _ -> c (a, w))
    in callCC (\ c -> unWriterT (f (escape c)) w)
{-# INLINE liftCallCC #-}

-- | Lift a @catchE@ operation to the new monad.
--
-- The protected computation and the handler both start from the
-- accumulated output as it was before @m@ ran.  Output written by @m@
-- before the exception is therefore not carried into the handler.
liftCatch :: Catch e m (a, w) -> Catch e (WriterT w m) a
liftCatch catchE m h = WriterT $ \ w ->
    catchE (unWriterT m w) (\ e -> unWriterT (h e) w)
{-# INLINE liftCatch #-}