{-# OPTIONS_GHC -Wall #-}
{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE NoMonomorphismRestriction #-}
{-# LANGUAGE TypeFamilies #-}

-- |
-- Module: Numeric.MCMC.Slice
-- Copyright: (c) 2015 Jared Tobin
-- License: MIT
--
-- Maintainer: Jared Tobin <jared@jtobin.ca>
-- Stability: unstable
-- Portability: ghc
--
-- This implementation performs slice sampling one coordinate at a time: for
-- each coordinate it first finds a bracket around the current point (using a
-- simple doubling heuristic), and then draws a new value by rejection
-- sampling within that bracket.  The result is a reliable and computationally
-- inexpensive sampling routine.
--
-- The 'mcmc' function streams a trace to stdout to be processed elsewhere,
-- while the 'slice' transition can be used for more flexible purposes, such as
-- working with samples in memory.
--
-- See <http://people.ee.duke.edu/~lcarin/slice.pdf Neal, 2003> for the
-- definitive reference of the algorithm.

module Numeric.MCMC.Slice (
    mcmc
  , chain
  , slice

  -- * Re-exported
  , MWC.create
  , MWC.createSystemRandom
  , MWC.withSystemRandom
  , MWC.asGenIO
  ) where

import Control.Monad (replicateM)
import Control.Monad.Codensity (lowerCodensity)
import Control.Monad.Trans.State.Strict (put, get, execStateT)
import Control.Monad.Primitive (PrimMonad, PrimState)
import Control.Lens hiding (index)
import Data.Maybe (fromMaybe)
import Data.Sampling.Types
import Pipes hiding (next)
import qualified Pipes.Prelude as Pipes
import System.Random.MWC.Probability (Prob, Gen, Variate)
import qualified System.Random.MWC.Probability as MWC

-- | Trace @n@ iterations of a Markov chain and stream them to stdout.
--
-- The chain starts at the given position.  It is advanced by the 'slice'
-- transition, using @radial@ as the initial bracket step and @target@ as the
-- (unnormalised) log density.  Each of the first @n@ states produced is
-- printed with 'print' as soon as it is available.  The starting position
-- itself is not printed.
--
-- >>> let rosenbrock [x0, x1] = negate (5 * (x1 - x0 ^ 2) ^ 2 + 0.05 * (1 - x0) ^ 2)
-- >>> withSystemRandom . asGenIO $ mcmc 3 1 [0, 0] rosenbrock
-- -3.854097694213343e-2,0.16688601288358407
-- -9.310661272172682e-2,0.2562387977415508
-- -0.48500122500661846,0.46245400501919076
mcmc
  :: (MonadIO m, PrimMonad m,
     Show (t a), FoldableWithIndex (Index (t a)) t, Ixed (t a),
     Num (IxValue (t a)), Variate (IxValue (t a)))
  => Int
  -> IxValue (t a)
  -> t a
  -> (t a -> Double)
  -> Gen (PrimState m)
  -> m ()
mcmc n radial position target gen =
  runEffect $ drive radial initial gen
          >-> Pipes.take n
          >-> Pipes.mapM_ (liftIO . print)
  where
    logDensity = Target target Nothing

    -- The starting state; its score is the log density at the start point.
    initial = Chain
      { chainTarget   = logDensity
      , chainScore    = lTarget logDensity position
      , chainPosition = position
      , chainTunables = Nothing
      }

-- | Trace @n@ iterations of a Markov chain and collect them in a list.
--
-- The chain starts at @position@.  It is advanced by the 'slice' transition,
-- using @radial@ as the initial bracket step and @target@ as the
-- (unnormalised) log density.  The result holds the first @n@ states produced,
-- in order.  The starting state itself is not included.
--
-- >>> let rosenbrock [x0, x1] = negate (5 * (x1 - x0 ^ 2) ^ 2 + 0.05 * (1 - x0) ^ 2)
-- >>> results <- withSystemRandom . asGenIO $ chain 3 1 [0, 0] rosenbrock
chain
  :: (PrimMonad m, FoldableWithIndex (Index (f a)) f, Ixed (f a)
     , Variate (IxValue (f a)), Num (IxValue (f a)))
  => Int
  -> IxValue (f a)
  -> f a
  -> (f a -> Double)
  -> Gen (PrimState m)
  -> m [Chain (f a) b]
chain n radial position target gen = runEffect $
        drive radial origin gen
    >-> collect n
  where
    ctarget = Target target Nothing

    -- Initial state; its score is the log density at the starting position.
    origin = Chain {
        chainScore    = lTarget ctarget position
      , chainTunables = Nothing
      , chainTarget   = ctarget
      , chainPosition = position
      }

    -- Await exactly @size@ values from upstream and return them in order
    -- (an empty list when @size <= 0@).
    -- NOTE(review): 'replicateM' runs in 'Codensity', presumably so the
    -- repeated binds over 'Proxy' stay cheap; confirm before simplifying.
    collect :: Monad m => Int -> Consumer a m [a]
    collect size = lowerCodensity $
      replicateM size (lift Pipes.await)

-- A Markov chain driven by the slice transition operator.
--
-- Starting from the supplied state, this repeatedly applies 'slice' (with
-- initial bracket step @radial@), drawing randomness from @gen@, and yields
-- every resulting state.  The starting state itself is never yielded, and the
-- stream never ends; consumers decide how many states to take.
drive
  :: (PrimMonad m, FoldableWithIndex (Index (t a)) t, Ixed (t a),
     Num (IxValue (t a)), Variate (IxValue (t a)))
  => IxValue (t a)
  -> Chain (t a) b
  -> Gen (PrimState m)
  -> Producer (Chain (t a) b) m c
drive radial start gen = advance start
  where
    advance current = do
      successor <- lift (MWC.sample (execStateT (slice radial) current) gen)
      yield successor
      advance successor

-- | A slice sampling transition operator.
--
-- Performs one sweep over the chain's position and updates each coordinate in
-- turn (in the order 'ifor_' visits the indices), holding the others fixed:
--
-- 1. draw a slice height uniformly below the current (log) density;
--
-- 2. find a bracket around the current coordinate with 'findBracket', using
--    @step@ as the initial expansion increment;
--
-- 3. draw the new coordinate from that bracket with 'rejection'.
--
-- The chain's score is updated to the log density at the new position.
slice
  :: (PrimMonad m, FoldableWithIndex (Index (t a)) t, Ixed (t a),
      Num (IxValue (t a)), Variate (IxValue (t a)))
  => IxValue (t a)
  -> Transition m (Chain (t a) b)
slice step = do
  Chain _ _ position _ <- get
  -- 'position' only supplies the indices to visit.  The current values are
  -- re-read from the state, because earlier steps of the sweep change them.
  ifor_ position $ \index _ -> do
    Chain {..} <- get
    let logDensity = lTarget chainTarget chainPosition

    -- Draw the slice height directly in log space:
    --   log (u * exp logDensity) = logDensity + log u,  u ~ U(0, 1].
    -- Going through 'exp' and back through 'log' underflows to
    -- log 0 = -Infinity once the log density drops below roughly -745 (and
    -- overflows to +Infinity above roughly 709).  Either way, the bracket
    -- search or the rejection step would never terminate.  This assumes
    -- mwc-random's floating-point 'uniformR' excludes the lower bound, as
    -- documented.
    u <- lift (MWC.uniformR (0, 1))
    let height = logDensity + log u

        bracket =
          findBracket (lTarget chainTarget) index step height chainPosition

    perturbed <- lift $
      rejection (lTarget chainTarget) index bracket height chainPosition

    let perturbedScore = lTarget chainTarget perturbed
    put (Chain chainTarget perturbedScore perturbed chainTunables)

-- Find a bracket by expanding its bounds through powers of 2.
--
-- Starting from @xs@ at both ends, coordinate @index@ of each end is moved
-- outwards, but only for an end where @target@ is not yet strictly below
-- @height@.  The increment starts at @step@ and doubles after every
-- expansion, so an end moves by @step@, @2 * step@, @4 * step@, ...  The
-- search stops once @target@ at both ends is below @height@.  It returns the
-- values of coordinate @index@ at the left and right ends.
--
-- The search does not terminate if an end never drops below @height@.
-- Calls 'error' if @index@ is not a valid index of @xs@.
findBracket
  :: (Ord a, Ixed s, Num (IxValue s))
  => (s -> a)
  -> Index s
  -> IxValue s
  -> a
  -> s
  -> (IxValue s, IxValue s)
findBracket target index step height xs = go step xs xs
  where
    err = error "findBracket: invalid index -- please report this as a bug!"

    valueAt s = fromMaybe err (s ^? ix index)

    -- @e@ is the current increment: used for this expansion, then doubled.
    -- Each end's target is evaluated once per iteration, because 'target'
    -- may be expensive.
    go !e !bl !br = case (target bl < height, target br < height) of
      (True,  True)  -> (valueAt bl, valueAt br)
      (True,  False) -> go (2 * e) bl (expandBracketRight index e br)
      (False, True)  -> go (2 * e) (expandBracketLeft index e bl) br
      (False, False) ->
        let bl' = expandBracketLeft index e bl
            br' = expandBracketRight index e br
        in  go (2 * e) bl' br'

-- Move coordinate @index@ of a point down by @step@; every other coordinate
-- is left untouched.
expandBracketLeft
  :: (Ixed s, Num (IxValue s))
  => Index s
  -> IxValue s
  -> s
  -> s
expandBracketLeft index step = expandBracketBy (\v d -> v - d) index step

-- Move coordinate @index@ of a point up by @step@; every other coordinate
-- is left untouched.
expandBracketRight
  :: (Ixed s, Num (IxValue s))
  => Index s
  -> IxValue s
  -> s
  -> s
expandBracketRight index step = expandBracketBy (\v d -> v + d) index step

-- Replace the value @v@ at coordinate @index@ with @f v step@.  Other
-- coordinates are unchanged, and if @index@ is not present (the 'ix'
-- traversal has no target) the input is returned unchanged.
expandBracketBy
  :: Ixed s
  => (IxValue s -> t -> IxValue s)
  -> Index s
  -> t
  -> s
  -> s
expandBracketBy f index step = over (ix index) (\v -> f v step)

-- Perform rejection sampling within the supplied bracket.
--
-- Repeatedly sets coordinate @dimension@ to a value drawn uniformly from
-- @bracket@ (all other coordinates keep their input values).  It stops as
-- soon as @target@ at the proposed point is not below @height@ and returns
-- that point.
rejection
  :: (Ord a, PrimMonad m, Ixed b, Variate (IxValue b))
  => (b -> a)
  -> Index b
  -> (IxValue b, IxValue b)
  -> a
  -> b
  -> Prob m b
rejection target dimension bracket height = propose
  where
    propose point = do
      draw <- MWC.uniformR bracket
      let candidate = set (ix dimension) draw point
      if target candidate < height
        then propose candidate
        else pure candidate