module Control.Monad.Weighted.Filter
(FilterT
,pattern FilterT
,runFilterT
,evalFilterT
,execFilterT)
where
import Data.Semiring
import Data.Functor.Classes
import Control.Applicative
import Control.Monad.State.Strict
import Control.Monad.Cont.Class
import Control.Monad.Error.Class
import Control.Monad.Fail
import Control.Monad.Reader.Class
import Control.Monad.Weighted.Class
import Control.Monad.Writer.Class
import Data.Coerce
import Data.Monoid
-- | Apply a function to the first component of a pair, leaving the second
-- component untouched. The pair itself is matched strictly.
first :: (a -> b) -> (a, c) -> (b, c)
first f pair = case pair of
  (l, r) -> (f l, r)
-- | Guard a weight-consuming continuation: when the weight 'isZero' the
-- whole branch collapses to 'empty'; otherwise the continuation is run on
-- that weight.
catchZero
  :: (DetectableZero s, Alternative m)
  => (s -> m a) -> s -> m a
catchZero k w = if isZero w then empty else k w
-- | Drop every result whose weight 'isZero', keeping all other
-- (value, weight) pairs unchanged.
remZeroes
  :: (DetectableZero s, Alternative m, Monad m)
  => m (a, s) -> m (a, s)
remZeroes xs = xs >>= keepNonZero
  where
    keepNonZero (v, w)
      | isZero w = empty
      | otherwise = pure (v, w)
-- | A weighted computation that threads a semiring weight through
-- 'StateT'. The hand-written 'Functor', 'Applicative' and 'Monad'
-- instances in this module use 'catchZero' to collapse a branch to
-- 'empty' once its weight is zero. The classes in the deriving clause
-- reuse the underlying 'StateT' instances.
--
-- NOTE(review): transformers >= 0.6 gives 'MonadTrans' a quantified
-- @forall m. Monad m => Monad (t m)@ superclass, which the constrained
-- 'Monad' instance of this type may not satisfy; confirm against the
-- transformers version this package builds with.
newtype FilterT s m a = FilterT_
  { unFilterT :: StateT s m a
    -- ^ the underlying weight-threading state computation
  } deriving ( MonadTrans, MonadCont, MonadError e, MonadReader r
             , MonadFix, MonadFail, MonadIO, Alternative, MonadPlus
             , MonadWriter w )
-- | Map over the result value. The incoming weight is guarded with
-- 'catchZero'; the weight produced by the wrapped action is passed
-- through as-is (it is not re-checked here).
instance (Alternative m, DetectableZero s) =>
         Functor (FilterT s m) where
  fmap f (FilterT_ (StateT run)) = FilterT_ (StateT (catchZero mapped))
    where
      mapped w = fmap (first f) (run w)
-- | 'pure' keeps the current weight. '<*>' runs the function action,
-- then the argument action, guarding with 'catchZero' on the incoming
-- weight, the weight between the two actions, and the final weight.
instance (Alternative m, Monad m, DetectableZero s) =>
         Applicative (FilterT s m) where
  pure x = FilterT_ (StateT (catchZero (\w -> pure (x, w))))
  FilterT_ (StateT runF) <*> FilterT_ (StateT runX) =
    FilterT_ (StateT (catchZero apply))
    where
      apply w0 =
        runF w0 >>= \(f, w1) ->
          catchZero runX w1 >>= \(x, w2) ->
            catchZero (\w3 -> pure (f x, w3)) w2
-- | Sequencing with zero short-circuiting. 'catchZero' guards the weight
-- entering the bind, the weight handed to the continuation, and the
-- weight the continuation produces.
instance (Alternative m, Monad m, DetectableZero s) =>
         Monad (FilterT s m) where
  FilterT_ (StateT run) >>= k = FilterT_ (StateT (catchZero step))
    where
      step w0 =
        run w0 >>= \(x, w1) ->
          catchZero (runStateT (unFilterT (k x))) w1 >>= \(y, w2) ->
            catchZero (\w3 -> pure (y, w3)) w2
-- | Run a 'FilterT' starting from weight 'one', returning every result
-- paired with its final weight. Results whose weight 'isZero' are removed
-- by 'remZeroes'.
runFilterT
  :: (DetectableZero s, Alternative m, Monad m)
  => FilterT s m a -> m (a, s)
runFilterT = remZeroes . (`runStateT` one) . unFilterT
-- | Run a 'FilterT' starting from weight 'one' and keep only the result
-- values.
--
-- NOTE(review): unlike 'runFilterT', no 'remZeroes' pass is applied here,
-- so results whose final weight is zero are NOT filtered out. The weaker
-- constraints (no 'DetectableZero' / 'Alternative') are presumably the
-- reason; confirm this is intended.
evalFilterT
  :: (Monad m, Semiring s)
  => FilterT s m a -> m a
evalFilterT (FilterT_ act) = evalStateT act one
-- | Run a 'FilterT' starting from weight 'one' and keep only the final
-- weights.
--
-- NOTE(review): unlike 'runFilterT', no 'remZeroes' pass is applied here,
-- so zero weights can appear in the output; confirm this is intended.
execFilterT
  :: (Monad m, Semiring s)
  => FilterT s m a -> m s
execFilterT = flip execStateT one . unFilterT
-- | Bidirectional view of a 'FilterT' as a collection of weighted results.
--
-- Matching runs the computation with 'runFilterT', so it starts from
-- weight 'one' and zero-weighted results are removed. Building multiplies
-- each supplied weight by the incoming weight (the incoming weight is the
-- left operand of '<.>'). If the incoming weight is zero, 'catchZero'
-- yields 'empty'.
pattern FilterT :: (Alternative m, DetectableZero s, Monad m) =>
  m (a, s) -> FilterT s m a
pattern FilterT x <- (runFilterT -> x)
  where
    FilterT y =
      FilterT_ (StateT (catchZero (\s -> fmap (\(v, w) -> (v, s <.> w)) y)))
-- | 'weighted' embeds a single weighted result through the 'FilterT'
-- builder. 'weigh' pairs the action's result with the state weight
-- observed after it runs (via 'get'). 'scale' runs the action and applies
-- the returned function to the state weight (via 'modify').
instance (DetectableZero w, Monad m, Alternative m) =>
         MonadWeighted w (FilterT w m) where
  weighted (v, p) = FilterT (pure (v, p))
  weigh (FilterT_ act) = FilterT_ (fmap (,) act <*> get)
  scale (FilterT_ act) = FilterT_ $ do
    (v, adjust) <- act
    v <$ modify adjust
-- | Fold over the result values produced by 'runFilterT'. Entries whose
-- weight 'isZero' contribute 'mempty'. 'runFilterT' already drops such
-- entries via 'remZeroes', so this check is a defensive second guard.
instance (Foldable m, DetectableZero w, Alternative m, Monad m) =>
         Foldable (FilterT w m) where
  foldMap f = foldMap visit . runFilterT
    where
      visit (v, p)
        | isZero p = mempty
        | otherwise = f v
-- | Traverse the first component of a pair with an effectful function,
-- reattaching the untouched second component to each result.
first_
  :: Applicative f
  => (a -> f b) -> (a, c) -> f (b, c)
first_ f (l, r) = fmap (\l' -> (l', r)) (f l)
-- | Traverse the result values of 'runFilterT' (weights ride along
-- unchanged) and rebuild with the 'FilterT' builder.
instance (Traversable m, DetectableZero w, Alternative m, Monad m) =>
         Traversable (FilterT w m) where
  traverse f = fmap FilterT . traverse (first_ f) . runFilterT
-- | Two computations are equal when their 'runFilterT' outputs are equal,
-- comparing values with the supplied equality and weights with '=='.
instance (Eq1 m, Eq w, DetectableZero w, Monad m, Alternative m) =>
         Eq1 (FilterT w m) where
  liftEq eq lhs rhs = liftEq sameEntry (runFilterT lhs) (runFilterT rhs)
    where
      sameEntry (a, wa) (b, wb) = eq a b && wa == wb
-- | Order computations by their 'runFilterT' outputs. Each entry is
-- compared by value first (with the supplied comparison), then by weight.
instance (Ord1 m, Ord w, DetectableZero w, Monad m, Alternative m) =>
         Ord1 (FilterT w m) where
  liftCompare cmp lhs rhs =
    liftCompare byEntry (runFilterT lhs) (runFilterT rhs)
    where
      byEntry (a, wa) (b, wb) = cmp a b <> compare wa wb
-- | Parse @FilterT <m (a, w)>@ and rebuild the computation with the
-- 'FilterT' builder.
instance (Read w, Read1 m, DetectableZero w, Alternative m, Monad m) =>
         Read1 (FilterT w m) where
  liftReadsPrec readElem readElems =
    readsData (readsUnaryWith readInner "FilterT" FilterT)
    where
      readInner = liftReadsPrec readPair readPairs
      readPair = liftReadsPrec2 readElem readElems readsPrec readList
      readPairs = liftReadList2 readElem readElems readsPrec readList
-- | Render as @FilterT <m (a, w)>@, where the inner collection is the
-- output of 'runFilterT'.
instance (Show w, Show1 m, DetectableZero w, Monad m, Alternative m) =>
         Show1 (FilterT w m) where
  liftShowsPrec showElem showElems d comp =
    showsUnaryWith showInner "FilterT" d (runFilterT comp)
    where
      showInner = liftShowsPrec showPair showPairs
      showPair = liftShowsPrec2 showElem showElems showsPrec showList
      showPairs = liftShowList2 showElem showElems showsPrec showList
-- | Equality via the 'Eq1' instance.
instance (Eq w, Eq1 m, Eq a, DetectableZero w, Monad m, Alternative m) =>
         Eq (FilterT w m a) where
  lhs == rhs = eq1 lhs rhs
-- | Ordering via the 'Ord1' instance.
instance (Ord w, Ord1 m, Ord a, DetectableZero w, Monad m, Alternative m) =>
         Ord (FilterT w m a) where
  compare lhs rhs = compare1 lhs rhs
-- | Parsing via the 'Read1' instance.
instance (Read w, Read1 m, Read a, DetectableZero w, Alternative m, Monad m) =>
         Read (FilterT w m a) where
  readsPrec prec = readsPrec1 prec
-- | Rendering via the 'Show1' instance.
instance (Show w, Show1 m, Show a, DetectableZero w, Alternative m, Monad m) =>
         Show (FilterT w m a) where
  showsPrec prec = showsPrec1 prec
-- | Expose the base monad's state by lifting each operation; the weight
-- state threaded by 'FilterT' itself is not touched by these.
instance (Alternative m, MonadState s m, DetectableZero w) =>
         MonadState s (FilterT w m) where
  get = lift get
  put new = lift (put new)
  state step = lift (state step)