module Control.Monad.Weighted
(
WeightedT
,runWeightedT
,pattern WeightedT
,execWeightedT
,evalWeightedT
,
Weighted
,runWeighted
,pattern Weighted
,execWeighted
,evalWeighted)
where
import Control.Applicative
import Control.Monad.Fail
import Control.Monad.Identity
import Control.Monad.Reader.Class
import Control.Monad.Weighted.Class
import Control.Monad.Writer.Class
import Control.Monad.Cont.Class
import Control.Monad.Error.Class
import Control.Monad.State.Strict
import Data.Coerce
import Data.Functor.Classes
import Data.Monoid
import Data.Semiring
-- | A computation in @m@ that additionally threads a weight drawn from @s@.
--
-- Internally this is a 'StateT' whose state is the running weight. The raw
-- constructor 'WeightedT_' is not exported; the 'WeightedT' pattern is the
-- public way to build and inspect values. Every class below is derived
-- from the underlying @'StateT' s m@ (newtype deriving).
newtype WeightedT s m a = WeightedT_ (StateT s m a)
  deriving
    ( Functor
    , Applicative
    , Alternative
    , Monad
    , MonadFail
    , MonadFix
    , MonadPlus
    , MonadIO
    , MonadTrans
    , MonadCont
    , MonadError e
    , MonadReader r
    , MonadWriter w
    )
-- | Run a weighted computation, starting the running weight at 'one'.
--
-- Returns the final result paired with the accumulated weight.
runWeightedT
  :: Semiring s
  => WeightedT s m a -> m (a, s)
runWeightedT (WeightedT_ computation) = runStateT computation one
-- | Bidirectional view of a 'WeightedT' as an @m@-computation of
-- result/weight pairs.
--
-- Matching runs the computation from weight 'one' (see 'runWeightedT').
-- Constructing from @m (a, s)@ yields an action that multiplies the incoming
-- weight @w@ on the right by each produced weight, i.e. @w <.> p@.
pattern WeightedT :: (Functor m, Semiring s) => m (a, s) -> WeightedT s m a
pattern WeightedT x <- (runWeightedT -> x)
  where
    -- The inner 'fmap' is the pair functor: it rescales only the weight.
    WeightedT y = WeightedT_ (StateT (\w -> fmap (fmap (w <.>)) y))
-- | The pure weighted monad: 'WeightedT' over 'Identity'.
type Weighted s = WeightedT s Identity
-- | Bidirectional pattern for the pure 'Weighted' monad.
--
-- Matching yields the result paired with its weight (via 'runWeighted').
-- Constructing @Weighted (x, p)@ builds an action that returns @x@ and
-- multiplies the running weight @s@ by @p@ on the right (@s <.> p@).
--
-- Right multiplication is the same order used by the 'WeightedT' builder.
-- That makes @Weighted (x, p)@ and @WeightedT (Identity (x, p))@ the same
-- computation even for non-commutative semirings. The previous @p <.> s@
-- reversed the order in which sequenced weights were combined.
pattern Weighted :: Semiring s => (a, s) -> Weighted s a
pattern Weighted x <- (runWeighted -> x)
  where
    Weighted (y, p) = WeightedT_ (StateT (\s -> Identity (y, s <.> p)))
-- | Run a pure weighted computation from weight 'one'.
--
-- Returns the result together with the accumulated weight.
runWeighted
  :: Semiring s
  => Weighted s a -> (a, s)
runWeighted = runIdentity . runWeightedT
-- | State operations are forwarded to the underlying monad @m@.
--
-- The weight lives in the internal 'StateT' layer and is not reachable
-- through this instance.
instance MonadState s m => MonadState s (WeightedT w m) where
  get = lift get
  put newState = lift (put newState)
  state transition = lift (state transition)
-- | Run a weighted computation from weight 'one', keeping only the result.
--
-- The accumulated weight is discarded.
evalWeightedT
  :: (Monad m, Semiring s)
  => WeightedT s m a -> m a
evalWeightedT (WeightedT_ computation) = evalStateT computation one
-- | Run a weighted computation from weight 'one', keeping only the final
-- accumulated weight.
execWeightedT
  :: (Monad m, Semiring s)
  => WeightedT s m a -> m s
execWeightedT (WeightedT_ computation) = execStateT computation one
-- | Pure counterpart of 'evalWeightedT': run from weight 'one' and return
-- only the result.
evalWeighted
  :: Semiring s
  => Weighted s a -> a
evalWeighted (WeightedT_ computation) = evalState computation one
-- | Pure counterpart of 'execWeightedT': run from weight 'one' and return
-- only the accumulated weight.
execWeighted
  :: Semiring s
  => Weighted s a -> s
execWeighted (WeightedT_ computation) = execState computation one
-- | Fold over the results of the computation (run from weight 'one').
--
-- Weights are ignored.
instance (Foldable m, Semiring w) => Foldable (WeightedT w m) where
  foldMap f computation = foldMap onResult (runWeightedT computation)
    where
      -- Matching the pair pattern keeps the original's strictness in it.
      onResult (x, _) = f x
-- | Apply an effectful function to the first component of a pair.
--
-- The second component is carried through the effect unchanged.
first_
  :: Applicative f
  => (a -> f b) -> (a, c) -> f (b, c)
first_ f (x, y) = fmap (\x' -> (x', y)) (f x)
-- | Traverse the results of the computation (run from weight 'one').
--
-- The result is rebuilt with the 'WeightedT' builder, so each rebuilt action
-- again multiplies its weights into the running weight.
instance (Traversable m, Semiring w) => Traversable (WeightedT w m) where
  traverse f = fmap WeightedT . traverse (first_ f) . runWeightedT
-- | Two computations are equal when their runs from weight 'one' are equal.
--
-- Results are compared with the supplied equality; weights with '=='.
instance (Eq1 m, Eq w, Semiring w) => Eq1 (WeightedT w m) where
  liftEq eqResult lhs rhs =
    liftEq (liftEq2 eqResult (==)) (runWeightedT lhs) (runWeightedT rhs)
-- | Order computations by their runs from weight 'one'.
--
-- Pairs are compared lexicographically: result first (supplied comparison),
-- then weight ('compare').
instance (Ord1 m, Ord w, Semiring w) => Ord1 (WeightedT w m) where
  liftCompare cmpResult lhs rhs =
    liftCompare (liftCompare2 cmpResult compare) (runWeightedT lhs) (runWeightedT rhs)
-- | Parse @WeightedT@ followed by a representation of the underlying
-- @m (a, w)@.
--
-- The parsed computation is rebuilt with the 'WeightedT' builder. This
-- mirrors the format produced by the 'Show1' instance.
instance (Read w, Read1 m, Semiring w, Functor m) => Read1 (WeightedT w m) where
  liftReadsPrec readResult readResults =
    readsData (readsUnaryWith readInner "WeightedT" WeightedT)
    where
      readInner = liftReadsPrec readPair readPairs
      readPair = liftReadsPrec2 readResult readResults readsPrec readList
      readPairs = liftReadList2 readResult readResults readsPrec readList
-- | Show as @WeightedT@ applied to the run of the computation from weight
-- 'one'.
instance (Show w, Show1 m, Semiring w) => Show1 (WeightedT w m) where
  liftShowsPrec showResult showResults prec =
    showsUnaryWith (liftShowsPrec showPair showPairs) "WeightedT" prec . runWeightedT
    where
      showPair = liftShowsPrec2 showResult showResults showsPrec showList
      showPairs = liftShowList2 showResult showResults showsPrec showList
-- | Equality delegated to the 'Eq1' instance at the element's own '=='.
instance (Eq w, Eq1 m, Eq a, Semiring w) => Eq (WeightedT w m a) where
  lhs == rhs = liftEq (==) lhs rhs
-- | Ordering delegated to the 'Ord1' instance at the element's own
-- 'compare'.
instance (Ord w, Ord1 m, Ord a, Semiring w) => Ord (WeightedT w m a) where
  compare lhs rhs = liftCompare compare lhs rhs
-- | Parsing delegated to the 'Read1' instance at the element's own
-- 'readsPrec'/'readList'.
instance (Read w, Read1 m, Read a, Semiring w, Functor m) => Read (WeightedT w m a) where
  readsPrec = liftReadsPrec readsPrec readList
-- | Rendering delegated to the 'Show1' instance at the element's own
-- 'showsPrec'/'showList'.
instance (Show w, Show1 m, Show a, Semiring w) => Show (WeightedT w m a) where
  showsPrec = liftShowsPrec showsPrec showList
-- | Weighted operations implemented directly on the internal weight state.
instance (Semiring w, Monad m) => MonadWeighted w (WeightedT w m) where
  -- Return @a@ and multiply weight @p@ into the running weight (via the
  -- 'WeightedT' builder).
  weighted (a, p) = WeightedT (pure (a, p))

  -- Pair the action's result with the running weight read right after it.
  -- This is the whole accumulated state, not only what the action
  -- contributed.
  weigh (WeightedT_ action) = WeightedT_ $ do
    result <- action
    current <- get
    pure (result, current)

  -- Run the action, then apply the function it returned to the running
  -- weight.
  scale (WeightedT_ action) = WeightedT_ $ do
    (result, adjust) <- action
    result <$ modify adjust