{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE PatternSynonyms #-}
{-# LANGUAGE Strict #-}
{-# LANGUAGE UndecidableInstances #-}
{-# LANGUAGE ViewPatterns #-}

-- | This module provides a monad transformer similar to
-- 'Control.Monad.Writer.Strict.WriterT', implemented using 'StateT', making it
-- tail recursive. (The traditional writer always leaks space; see the email to
-- the Haskell libraries committee on this subject for more information.)
--
-- Pattern synonyms are used to provide the same interface as
-- 'Control.Monad.Writer.Strict.WriterT'. Unfortunately, current GHC warns
-- whenever these patterns are used that there are unmatched patterns: the
-- @COMPLETE@ pragma should solve this problem in future versions of GHC.
--
-- A pattern synonym is also provided for a non-transformer version of the
-- writer. Again, this is just 'StateT' underneath, but its interface looks as
-- if it was defined like so:
--
-- > newtype Weighted w a = Weighted { runWeighted :: (a, w) }
--
-- The other difference between this monad and
-- 'Control.Monad.Writer.Strict.WriterT' is that it relies on '<.>' from
-- 'Semiring', rather than 'mappend' from 'Monoid'.
module Control.Monad.Weighted
  ( -- * Transformer
    WeightedT
  , runWeightedT
  , pattern WeightedT
  , execWeightedT
  , evalWeightedT
    -- * Plain Weighted
    -- The type synonym is exported so that users can write @Weighted s a@
    -- in their own signatures.
  , Weighted
  , runWeighted
  , pattern Weighted
  , execWeighted
  , evalWeighted
  ) where

import Control.Applicative
import Control.Monad.Cont.Class
import Control.Monad.Error.Class
import Control.Monad.Fail
import Control.Monad.Identity
import Control.Monad.Reader.Class
import Control.Monad.State.Strict
import Control.Monad.Weighted.Class
import Control.Monad.Writer.Class
import Data.Coerce
import Data.Functor.Classes
import Data.Monoid
import Data.Semiring

-- | A monad transformer similar to 'Control.Monad.Writer.Strict.WriterT',
-- except that it does not leak space. It is implemented using a state monad,
-- so that combining weights with '<.>' is tail recursive.
-- See the email to the Haskell libraries committee on the space leak in
-- 'Control.Monad.Writer.Strict.WriterT' for more information.
--
-- It also uses '<.>' from 'Semiring', rather than 'mappend' from 'Monoid',
-- when combining computations.
--
-- Wherever possible, coercions are used to eliminate any overhead from the
-- newtype wrapper.
newtype WeightedT s m a = WeightedT_ (StateT s m a)
  deriving
    ( Functor, Applicative, Monad, MonadTrans, MonadCont
    , MonadError e, MonadReader r, MonadFix, MonadFail
    , MonadIO, Alternative, MonadPlus, MonadWriter w
    )

-- | Run a weighted computation in the underlying monad, starting from a
-- weight of 'one', and return its result paired with the accumulated weight.
runWeightedT :: Semiring s => WeightedT s m a -> m (a, s)
runWeightedT =
  (coerce :: (StateT s m a -> m (a, s)) -> WeightedT s m a -> m (a, s))
    (`runStateT` one)
{-# INLINE runWeightedT #-}

{-# ANN module "HLint: ignore Use second" #-}

-- | This pattern gives the newtype wrapper around 'StateT' the same interface
-- as 'Control.Monad.Writer.Strict.WriterT'. Unfortunately, GHC currently warns
-- that a function is incomplete wherever this pattern is used. This issue
-- should be solved in a future version of GHC, when the @COMPLETE@ pragma is
-- implemented.
--
-- Matching runs the computation from a weight of 'one' (via 'runWeightedT').
-- Constructing multiplies the weight produced by the inner action onto the
-- right of the weight accumulated so far.
pattern WeightedT :: (Functor m, Semiring s) => m (a, s) -> WeightedT s m a
pattern WeightedT x <- (runWeightedT -> x)
  where
    WeightedT y = WeightedT_ (StateT (\s -> fmap (\(x, p) -> (x, s <.> p)) y))

-- | A type synonym for the plain (non-transformer) version of 'WeightedT'.
-- This can be used as if it were defined as:
--
-- > newtype Weighted w a = Weighted { runWeighted :: (a, w) }
type Weighted s = WeightedT s Identity

-- | This pattern gives the newtype wrapper around 'StateT' the same interface
-- as if it was defined like so:
--
-- > newtype Weighted w a = Weighted { runWeighted :: (a, w) }
--
-- Unfortunately GHC warns that a function is incomplete wherever this pattern
-- is used. This issue should be solved in a future version of GHC, when the
-- @COMPLETE@ pragma is implemented.
--
-- >>> execWeighted $ traverse (\x -> Weighted ((), x)) [1..5]
-- 120
pattern Weighted :: Semiring s => (a, s) -> Weighted s a
pattern Weighted x <- (runWeighted -> x)
  where
    -- The new weight goes on the right of the accumulated weight (s <.> p),
    -- exactly as in the 'WeightedT' pattern and in 'weighted'. Previously this
    -- was p <.> s, which accumulated weights in reverse order and disagreed
    -- with the transformer for non-commutative semirings.
    Weighted (y, p) = WeightedT_ (StateT (\s -> Identity (y, s <.> p)))

-- | Run a weighted computation, starting from a weight of 'one'.
--
-- >>> runWeighted $ traverse (\x -> Weighted (show x, x)) [1..5]
-- (["1","2","3","4","5"],120)
runWeighted :: Semiring s => Weighted s a -> (a, s)
runWeighted =
  (coerce :: (WeightedT s Identity a -> Identity (a, s))
          -> (WeightedT s Identity a -> (a, s)))
    runWeightedT
{-# INLINE runWeighted #-}

-- | State operations are lifted to the underlying monad, so they leave the
-- accumulated weight untouched.
instance MonadState s m => MonadState s (WeightedT w m) where
  get = lift get
  put = lift . put
  state = lift . state

-- | Run a weighted computation in the underlying monad, and return its result.
evalWeightedT :: (Monad m, Semiring s) => WeightedT s m a -> m a
evalWeightedT =
  (coerce :: (StateT s m a -> m a) -> WeightedT s m a -> m a)
    (`evalStateT` one)
{-# INLINE evalWeightedT #-}

-- | Run a weighted computation in the underlying monad, and collect its weight.
execWeightedT :: (Monad m, Semiring s) => WeightedT s m a -> m s
execWeightedT =
  (coerce :: (StateT s m a -> m s) -> WeightedT s m a -> m s)
    (`execStateT` one)
{-# INLINE execWeightedT #-}

-- | Run a weighted computation, and return its result.
evalWeighted :: Semiring s => Weighted s a -> a
evalWeighted = (coerce :: (State s a -> a) -> Weighted s a -> a) (`evalState` one)
{-# INLINE evalWeighted #-}

-- | Run a weighted computation, and collect its weight.
execWeighted :: Semiring s => Weighted s a -> s
execWeighted = (coerce :: (State s a -> s) -> Weighted s a -> s) (`execState` one)
{-# INLINE execWeighted #-}

-- | Folds over the results of the computation (run via 'runWeightedT'),
-- ignoring the weights.
instance (Foldable m, Semiring w) => Foldable (WeightedT w m) where
  foldMap f = foldMap (\(x, _) -> f x) . runWeightedT

-- | Apply an effectful function to the first component of a pair.
first_ :: Applicative f => (a -> f b) -> (a, c) -> f (b, c)
first_ f (x, y) = flip (,) y <$> f x

-- | Traverses the results of the computation, keeping each result's weight.
instance (Traversable m, Semiring w) => Traversable (WeightedT w m) where
  traverse f x = WeightedT <$> (traverse . first_) f (runWeightedT x)

instance (Eq1 m, Eq w, Semiring w) => Eq1 (WeightedT w m) where
  liftEq eq x y =
    liftEq
      (\(xx, xy) (yx, yy) -> eq xx yx && xy == yy)
      (runWeightedT x)
      (runWeightedT y)

instance (Ord1 m, Ord w, Semiring w) => Ord1 (WeightedT w m) where
  liftCompare cmp x y =
    liftCompare
      (\(xx, xy) (yx, yy) -> cmp xx yx <> compare xy yy)
      (runWeightedT x)
      (runWeightedT y)

instance (Read w, Read1 m, Semiring w, Functor m) => Read1 (WeightedT w m) where
  liftReadsPrec rp rl =
    readsData $
      readsUnaryWith (liftReadsPrec rp' rl') "WeightedT" WeightedT
    where
      rp' = liftReadsPrec2 rp rl readsPrec readList
      rl' = liftReadList2 rp rl readsPrec readList

instance (Show w, Show1 m, Semiring w) => Show1 (WeightedT w m) where
  liftShowsPrec sp sl d m =
    showsUnaryWith (liftShowsPrec sp' sl') "WeightedT" d (runWeightedT m)
    where
      sp' = liftShowsPrec2 sp sl showsPrec showList
      sl' = liftShowList2 sp sl showsPrec showList

instance (Eq w, Eq1 m, Eq a, Semiring w) => Eq (WeightedT w m a) where
  (==) = eq1

instance (Ord w, Ord1 m, Ord a, Semiring w) => Ord (WeightedT w m a) where
  compare = compare1

instance (Read w, Read1 m, Read a, Semiring w, Functor m) => Read (WeightedT w m a) where
  readsPrec = readsPrec1

instance (Show w, Show1 m, Show a, Semiring w) => Show (WeightedT w m a) where
  showsPrec = showsPrec1

instance (Semiring w, Monad m) => MonadWeighted w (WeightedT w m) where
  -- Multiply the given weight onto the right of the accumulated weight.
  weighted (x, s) = WeightedT (pure (x, s))
  {-# INLINE weighted #-}
  -- Pair the result with the current accumulated weight, read with 'get'.
  -- NOTE(review): this includes weight accumulated *before* the action, unlike
  -- 'listen' on WriterT, which reports only the inner action's output.
  -- Confirm this is intended.
  weigh (WeightedT_ s) = WeightedT_ ((,) <$> s <*> get)
  {-# INLINE weigh #-}
  -- Run the action, then apply the function it returns to the whole
  -- accumulated weight (via 'modify'), returning the action's result.
  scale (WeightedT_ s) = WeightedT_ (scaleS s)
    where
      scaleS = (=<<) (uncurry (<$) . fmap modify)
  {-# INLINE scale #-}