{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE PatternSynonyms #-}
{-# LANGUAGE UndecidableInstances #-}
{-# LANGUAGE ViewPatterns #-}

-- | A weighted monad transformer, implemented on top of a strict 'StateT',
-- which replaces a computation with 'empty' whenever its weight 'isZero'.
module Control.Monad.Weighted.Filter
  ( catchZero
  , FilterT
  , pattern FilterT
  , runFilterT
  , evalFilterT
  , execFilterT
  ) where

import Data.Semiring
import Data.Functor.Classes
import Control.Applicative
import Control.Monad.State.Strict
import Control.Monad.Cont.Class
import Control.Monad.Error.Class
import Control.Monad.Fail
import Control.Monad.Reader.Class
import Control.Monad.Weighted.Class
import Control.Monad.Writer.Class
import Control.Arrow (first)
import Data.Coerce
import Data.Monoid

-- | Guard a weight-consuming continuation: when the weight 'isZero' the
-- result is 'empty' and the continuation is never applied; otherwise the
-- continuation is applied to the weight.
catchZero
    :: (DetectableZero s, Alternative m)
    => (s -> m a) -> s -> m a
catchZero f s
  | isZero s = empty
  | otherwise = f s

-- | Replace every @(value, weight)@ result whose weight 'isZero' with
-- 'empty'; all other results are passed through unchanged.
remZeroes
    :: (DetectableZero s, Alternative m, Monad m)
    => m (a, s) -> m (a, s)
remZeroes xs =
    xs >>=
    (\(x,p) ->
          if isZero p
              then empty
              else pure (x, p))

-- | Discards results which are zero.
--
-- The representation is a strict 'StateT' whose state is the weight. The raw
-- constructor @FilterT_@ is not exported; values are built and inspected
-- through the bidirectional 'FilterT' pattern synonym.
newtype FilterT s m a = FilterT_
    { unFilterT :: StateT s m a
    } deriving (MonadTrans,MonadCont,MonadError e
               ,MonadReader r,MonadFix,MonadFail,MonadIO,Alternative,MonadPlus
               ,MonadWriter w)

-- | Maps over the result value. The incoming weight is checked with
-- 'catchZero' first, so a zero weight yields 'empty'.
instance (Alternative m, DetectableZero s) =>
         Functor (FilterT s m) where
    fmap f (FilterT_ (StateT st)) =
        (FilterT_ . StateT . catchZero) ((fmap . first) f . st)
    {-# INLINE fmap #-}

-- | Threads the weight through both computations, checking it with
-- 'catchZero' before and after each one.
instance (Alternative m, Monad m, DetectableZero s) =>
         Applicative (FilterT s m) where
    pure x = (FilterT_ . StateT) (catchZero (pure . (,) x))
    {-# INLINE pure #-}
    FilterT_ (StateT fs) <*> FilterT_ (StateT xs) =
        FilterT_ . StateT . catchZero $
        \s -> do
            (f,s') <- fs s
            -- A zero weight left by the function computation stops here.
            (x,s'') <- catchZero xs s'
            -- A zero weight left by the argument computation stops here.
            catchZero (\s''' -> pure (f x, s''')) s''
    {-# INLINE (<*>) #-}

-- | Sequential composition with the same zero checks as the 'Applicative'
-- instance: before the first action, between the two, and on the final
-- weight.
instance (Alternative m, Monad m, DetectableZero s) =>
         Monad (FilterT s m) where
    FilterT_ (StateT st) >>= f =
        FilterT_ . StateT . catchZero $
        \s -> do
            (x,s') <- st s
            (y,s'') <- catchZero (runStateT (unFilterT (f x))) s'
            catchZero (\s''' -> pure (y, s''')) s''
    {-# INLINE (>>=) #-}

-- | Run the computation with an initial weight of 'one', then drop every
-- result whose final weight 'isZero'.
runFilterT
    :: (DetectableZero s, Alternative m, Monad m)
    => FilterT s m a -> m (a, s)
runFilterT =
    remZeroes .
    (coerce :: (StateT s m a -> m (a, s)) -> FilterT s m a -> m (a, s))
        (`runStateT` one)
{-# INLINE runFilterT #-}

-- | Run the computation with an initial weight of 'one' and return only the
-- result values. Unlike 'runFilterT', the final weight is not passed through
-- a zero filter here.
evalFilterT
    :: (Monad m, Semiring s)
    => FilterT s m a -> m a
evalFilterT =
    (coerce :: (StateT s m a -> m a) -> FilterT s m a -> m a)
        (`evalStateT` one)
{-# INLINE evalFilterT #-}

-- | Run the computation with an initial weight of 'one' and return only the
-- final weights. Unlike 'runFilterT', the final weight is not passed through
-- a zero filter here.
execFilterT
    :: (Monad m, Semiring s)
    => FilterT s m a -> m s
execFilterT =
    (coerce :: (StateT s m a -> m s) -> FilterT s m a -> m s)
        (`execStateT` one)
{-# INLINE execFilterT #-}

-- | Bidirectional view of a 'FilterT' as its weighted results.
--
-- Matching runs the computation with 'runFilterT'. Constructing multiplies
-- the incoming weight (on the left, with '<.>') into the weight of every
-- supplied result, and yields 'empty' when the incoming weight 'isZero'.
pattern FilterT :: (Alternative m, DetectableZero s, Monad m) => m (a, s) -> FilterT s m a
pattern FilterT x <- (runFilterT -> x)
  where FilterT y = FilterT_ . StateT . catchZero $ \s -> (fmap.fmap) (s<.>) y

instance (DetectableZero w, Monad m, Alternative m) =>
         MonadWeighted w (FilterT w m) where
    -- Embed a single weighted value via the 'FilterT' pattern builder.
    weighted (x,s) = FilterT (pure (x, s))
    {-# INLINE weighted #-}
    -- Pair the result with the weight held in the state after running it.
    weigh (FilterT_ s) = FilterT_ ((,) <$> s <*> get)
    {-# INLINE weigh #-}
    -- Run the action to obtain a (value, function) pair, apply the function
    -- to the state with 'modify', and return the value.
    scale (FilterT_ s) = FilterT_ (scaleS s)
      where
        scaleS = (=<<) (uncurry (<$) . fmap modify)
    {-# INLINE scale #-}

-- | Folds over the values produced by 'runFilterT'. 'runFilterT' has already
-- applied 'remZeroes'; the 'isZero' test here additionally skips any result
-- that still carries a zero weight.
instance (Foldable m, DetectableZero w, Alternative m, Monad m) =>
         Foldable (FilterT w m) where
    foldMap f =
        foldMap
            (\(x,p) ->
                  if isZero p
                      then mempty
                      else f x) .
        runFilterT

-- | Apply an effectful function to the first component of a pair, leaving
-- the second component untouched.
first_
    :: Applicative f
    => (a -> f b) -> (a, c) -> f (b, c)
first_ f (x,y) = flip (,) y <$> f x

-- | Traverses the values produced by 'runFilterT' and rebuilds the result
-- with the 'FilterT' pattern builder.
instance (Traversable m, DetectableZero w, Alternative m, Monad m) =>
         Traversable (FilterT w m) where
    traverse f x = FilterT <$> (traverse . first_) f (runFilterT x)

-- | Compares the outputs of 'runFilterT': values with the supplied equality,
-- weights with '=='.
instance (Eq1 m, Eq w, DetectableZero w, Monad m, Alternative m) =>
         Eq1 (FilterT w m) where
    liftEq eq x y =
        liftEq
            (\(xx,xy) (yx,yy) ->
                  eq xx yx && xy == yy)
            (runFilterT x)
            (runFilterT y)

-- | Orders the outputs of 'runFilterT': values first, using the supplied
-- comparison, then weights with 'compare'.
instance (Ord1 m, Ord w, DetectableZero w, Monad m, Alternative m) =>
         Ord1 (FilterT w m) where
    liftCompare cmp x y =
        liftCompare
            (\(xx,xy) (yx,yy) ->
                  cmp xx yx <> compare xy yy)
            (runFilterT x)
            (runFilterT y)

-- | Reads the constructor name @FilterT@ applied to the weighted results,
-- building the value with the 'FilterT' pattern builder.
instance (Read w, Read1 m, DetectableZero w, Alternative m, Monad m) =>
         Read1 (FilterT w m) where
    liftReadsPrec rp rl =
        readsData $ readsUnaryWith (liftReadsPrec rp' rl') "FilterT" FilterT
      where
        rp' = liftReadsPrec2 rp rl readsPrec readList
        rl' = liftReadList2 rp rl readsPrec readList

-- | Shows the constructor name @FilterT@ applied to the output of
-- 'runFilterT'.
instance (Show w, Show1 m, DetectableZero w, Monad m, Alternative m) =>
         Show1 (FilterT w m) where
    liftShowsPrec sp sl d m =
        showsUnaryWith (liftShowsPrec sp' sl') "FilterT" d (runFilterT m)
      where
        sp' = liftShowsPrec2 sp sl showsPrec showList
        sl' = liftShowList2 sp sl showsPrec showList

instance (Eq w, Eq1 m, Eq a, DetectableZero w, Monad m, Alternative m) =>
         Eq (FilterT w m a) where
    (==) = eq1

instance (Ord w, Ord1 m, Ord a, DetectableZero w, Monad m, Alternative m) =>
         Ord (FilterT w m a) where
    compare = compare1

instance (Read w, Read1 m, Read a, DetectableZero w, Alternative m, Monad m) =>
         Read (FilterT w m a) where
    readsPrec = readsPrec1

instance (Show w, Show1 m, Show a, DetectableZero w, Alternative m, Monad m) =>
         Show (FilterT w m a) where
    showsPrec = showsPrec1

-- | State operations are lifted to the underlying monad @m@; the weight held
-- by 'FilterT' itself is not exposed through this instance.
instance (Alternative m, MonadState s m, DetectableZero w) =>
         MonadState s (FilterT w m) where
    get = lift get
    put = lift . put
    state = lift . state