{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE Strict #-}
{-# LANGUAGE UndecidableInstances #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE PatternSynonyms #-}
{-# LANGUAGE ViewPatterns #-}

-- | This module provides a weighted monad which filters out zero-weighted
-- results from a computation at every opportunity.
--
-- The weight is threaded through the computation as 'StateT' state. Before
-- each step of 'fmap', 'pure', '<*>' and '>>=' the current weight is checked
-- with 'isZero'; if it is zero, that branch is replaced by 'empty'.
-- 'runFilterT' additionally removes any result whose final weight is zero.
module Control.Monad.Weighted.Filter
  (FilterT
  ,pattern FilterT
  ,runFilterT
  ,evalFilterT
  ,execFilterT)
  where

import Data.Semiring
import Data.Functor.Classes
import Control.Applicative
import Control.Monad.State.Strict
import Control.Monad.Cont.Class
import Control.Monad.Error.Class
import Control.Monad.Fail
import Control.Monad.Reader.Class
import Control.Monad.Weighted.Class
import Control.Monad.Writer.Class
import Data.Coerce
import Data.Monoid

-- | Apply a function to the first component of a pair.
first :: (a -> b) -> (a, c) -> (b, c)
first f (x, y) = (f x, y)

-- | @catchZero f s@ runs @f s@, unless the weight @s@ is zero, in which case
-- it short-circuits with 'empty' and never runs @f@.
catchZero :: (DetectableZero s, Alternative m) => (s -> m a) -> s -> m a
catchZero f s
  | isZero s = empty
  | otherwise = f s
{-# INLINE catchZero #-}

-- | Replace every result whose weight is zero with 'empty'; all other
-- results are passed through unchanged.
remZeroes :: (DetectableZero s, Alternative m, Monad m) => m (a, s) -> m (a, s)
remZeroes xs = xs >>= (\(x,p) -> if isZero p then empty else pure (x, p))

-- | A weighted monad which discards results which are zero as it goes.
--
-- Internally this is a 'StateT' whose state is the accumulated weight. The
-- raw constructor 'FilterT_' is not exported; the @FilterT@ pattern is the
-- public way to build and inspect values.
--
-- The classes in the @deriving@ clause are obtained through
-- GeneralizedNewtypeDeriving, i.e. they reuse the corresponding 'StateT'
-- instances directly. Those instances have no 'DetectableZero' constraint,
-- so they do not filter zero weights themselves.
newtype FilterT s m a = FilterT_
  { unFilterT :: StateT s m a
  } deriving (MonadTrans,MonadCont,MonadError e,MonadReader r,MonadFix
             ,MonadFail,MonadIO,Alternative,MonadPlus,MonadWriter w)

-- | 'fmap' maps over the result and passes the weight through untouched.
-- It first checks the incoming weight and yields 'empty' if it is zero.
instance (Alternative m, DetectableZero s) =>
         Functor (FilterT s m) where
    fmap f (FilterT_ (StateT st)) =
        (FilterT_ . StateT . catchZero) ((fmap . first) f . st)
    {-# INLINE fmap #-}

-- | 'pure' keeps the current weight unless it is zero.
--
-- '<*>' checks the weight at three points:
--
-- * before running the function computation,
-- * before running the argument computation,
-- * on the weight the argument computation leaves behind.
instance (Alternative m, Monad m, DetectableZero s) =>
         Applicative (FilterT s m) where
    pure x = (FilterT_ . StateT) (catchZero (pure . (,) x))
    {-# INLINE pure #-}
    FilterT_ (StateT fs) <*> FilterT_ (StateT xs) =
        FilterT_ . StateT . catchZero $ \s -> do
            (f,s') <- fs s
            -- Prune here so 'xs' is never run on a zero weight.
            (x,s'') <- catchZero xs s'
            -- Final check on the weight produced by 'xs'.
            catchZero (\s''' -> pure (f x, s''')) s''
    {-# INLINE (<*>) #-}

-- | '>>=' checks the weight at three points:
--
-- * before running the first computation,
-- * before running the continuation,
-- * on the weight the continuation leaves behind.
instance (Alternative m, Monad m, DetectableZero s) =>
         Monad (FilterT s m) where
    FilterT_ (StateT st) >>= f =
        FilterT_ . StateT . catchZero $ \s -> do
            (x,s') <- st s
            -- Prune here so the continuation is never run on a zero weight.
            (y,s'') <- catchZero (runStateT (unFilterT (f x))) s'
            -- Final check on the weight produced by the continuation.
            catchZero (\s''' -> pure (y, s''')) s''
    {-# INLINE (>>=) #-}

-- | Run a filtered computation in the underlying monad.
--
-- The computation starts with weight 'one', and any result whose final
-- weight is zero is removed (via 'remZeroes').
--
-- The final pass is needed because values built with the @FilterT@ pattern
-- (and hence 'weighted') multiply in a new weight without re-checking it. A
-- computation can therefore finish with a zero weight.
runFilterT
    :: (DetectableZero s, Alternative m, Monad m)
    => FilterT s m a -> m (a, s)
runFilterT =
    remZeroes .
    (coerce :: (StateT s m a -> m (a, s)) -> FilterT s m a -> m (a, s))
        (`runStateT` one)
{-# INLINE runFilterT #-}

-- | Evaluate a filtered computation in the underlying monad and return its result.
--
-- Starts from weight 'one'. Unlike 'runFilterT', this does not apply
-- 'remZeroes', so results whose final weight is zero are not removed.
--
-- NOTE(review): this may be deliberate, given the weaker
-- (Monad m, Semiring s) constraints. Confirm whether callers rely on it.
evalFilterT :: (Monad m, Semiring s) => FilterT s m a -> m a
evalFilterT =
    (coerce :: (StateT s m a -> m a) -> FilterT s m a -> m a)
        (`evalStateT` one)
{-# INLINE evalFilterT #-}

-- | Evaluate a filtered computation in the underlying monad and collect its weight.
--
-- Starts from weight 'one'. Unlike 'runFilterT', this does not apply
-- 'remZeroes', so the collected weights may include zero.
--
-- NOTE(review): this may be deliberate, given the weaker
-- (Monad m, Semiring s) constraints. Confirm whether callers rely on it.
execFilterT :: (Monad m, Semiring s) => FilterT s m a -> m s
execFilterT =
    (coerce :: (StateT s m a -> m s) -> FilterT s m a -> m s)
        (`execStateT` one)
{-# INLINE execFilterT #-}

-- | This pattern gives an interface to the 'FilterT' monad which makes it look as if
-- it were defined without the state monad.
--
-- * Matching runs the computation with 'runFilterT', starting from weight
--   'one'. Zero-weighted results are therefore already removed from the
--   matched value.
--
-- * Building from an @m (a, s)@ first checks that the accumulated weight
--   @s@ is non-zero. It then replaces each result's weight @w@ with
--   @s <.> w@. The combined weight itself is not re-checked.
pattern FilterT :: (Alternative m, DetectableZero s, Monad m) => m (a, s) -> FilterT s m a
pattern FilterT x <- (runFilterT -> x)
  where
    FilterT y = FilterT_ . StateT . catchZero $ \ s -> (fmap . fmap) (s <.>) y

-- | The weight of a 'FilterT' is its internal 'StateT' state.
--
-- * 'weighted' multiplies the accumulated weight (on the left) by the given
--   weight, via the @FilterT@ pattern.
--
-- * 'weigh' pairs a computation's result with the accumulated weight after
--   that computation has run.
--
-- * 'scale' runs a computation returning @(x, g)@, applies the function @g@
--   to the accumulated weight with 'modify', and returns @x@.
--
-- None of these re-check the weight they produce.
instance (DetectableZero w, Monad m, Alternative m) =>
         MonadWeighted w (FilterT w m) where
    weighted (x,s) = FilterT (pure (x, s))
    {-# INLINE weighted #-}
    weigh (FilterT_ s) = FilterT_ ((,) <$> s <*> get)
    {-# INLINE weigh #-}
    scale (FilterT_ s) = FilterT_ (scaleS s)
      where
        -- scaleS s == s >>= \(x, g) -> x <$ modify g
        scaleS = (=<<) (uncurry (<$) . fmap modify)
    {-# INLINE scale #-}

-- | Folds over the results of 'runFilterT', ignoring their weights.
--
-- The 'isZero' test below should be redundant, since 'runFilterT' has
-- already removed zero-weighted results. It is harmless.
instance (Foldable m, DetectableZero w, Alternative m, Monad m) =>
         Foldable (FilterT w m) where
    foldMap f =
        foldMap (\(x,p) -> if isZero p then mempty else f x) . runFilterT

-- | Apply an effectful function to the first component of a pair, keeping
-- the second component as-is.
first_ :: Applicative f => (a -> f b) -> (a, c) -> f (b, c)
first_ f (x,y) = flip (,) y <$> f x

-- | Traverses the results of 'runFilterT', keeping their weights, and
-- rebuilds the computation with the @FilterT@ pattern.
instance (Traversable m, DetectableZero w, Alternative m, Monad m) =>
         Traversable (FilterT w m) where
    traverse f x = FilterT <$> (traverse . first_) f (runFilterT x)

-- | Two computations are equal when their 'runFilterT' outputs are equal.
-- Results are compared with the supplied function and weights with '=='.
instance (Eq1 m, Eq w, DetectableZero w, Monad m, Alternative m) =>
         Eq1 (FilterT w m) where
    liftEq eq x y =
        liftEq (\(xx,xy) (yx,yy) -> eq xx yx && xy == yy) (runFilterT x) (runFilterT y)

-- | Compares 'runFilterT' outputs. Each (result, weight) pair is ordered by
-- result first, then by weight.
instance (Ord1 m, Ord w, DetectableZero w, Monad m, Alternative m) =>
         Ord1 (FilterT w m) where
    liftCompare cmp x y =
        liftCompare (\(xx,xy) (yx,yy) -> cmp xx yx <> compare xy yy) (runFilterT x) (runFilterT y)

-- | Reads the form produced by the 'Show1' instance: the name @FilterT@
-- applied to an @m (a, w)@. The computation is rebuilt with the @FilterT@
-- pattern.
instance (Read w, Read1 m, DetectableZero w, Alternative m, Monad m) =>
         Read1 (FilterT w m) where
    liftReadsPrec rp rl =
        readsData $ readsUnaryWith (liftReadsPrec rp' rl') "FilterT" FilterT
      where
        -- readers for a single (a, w) pair and for a list of such pairs
        rp' = liftReadsPrec2 rp rl readsPrec readList
        rl' = liftReadList2 rp rl readsPrec readList

-- | Shows a computation as the name @FilterT@ applied to its 'runFilterT'
-- output.
instance (Show w, Show1 m, DetectableZero w, Monad m, Alternative m) =>
         Show1 (FilterT w m) where
    liftShowsPrec sp sl d m =
        showsUnaryWith (liftShowsPrec sp' sl') "FilterT" d (runFilterT m)
      where
        -- showers for a single (a, w) pair and for a list of such pairs
        sp' = liftShowsPrec2 sp sl showsPrec showList
        sl' = liftShowList2 sp sl showsPrec showList

-- | Defined via the 'Eq1' instance.
instance (Eq w, Eq1 m, Eq a, DetectableZero w, Monad m, Alternative m) =>
         Eq (FilterT w m a) where
    (==) = eq1

-- | Defined via the 'Ord1' instance.
instance (Ord w, Ord1 m, Ord a, DetectableZero w, Monad m, Alternative m) =>
         Ord (FilterT w m a) where
    compare = compare1

-- | Defined via the 'Read1' instance.
instance (Read w, Read1 m, Read a, DetectableZero w, Alternative m, Monad m) =>
         Read (FilterT w m a) where
    readsPrec = readsPrec1

-- | Defined via the 'Show1' instance.
instance (Show w, Show1 m, Show a, DetectableZero w, Alternative m, Monad m) =>
         Show (FilterT w m a) where
    showsPrec = showsPrec1

-- | Gives access to the state of the /underlying/ monad @m@ by lifting.
--
-- This is not the accumulated weight. The weight lives in the internal
-- 'StateT' layer and is reached through the 'MonadWeighted' methods
-- ('weigh', 'scale') instead.
instance (Alternative m, MonadState s m, DetectableZero w) =>
         MonadState s (FilterT w m) where
    get = lift get
    put = lift . put
    state = lift . state