module Control.Monad.WeightedSearch
( T, Weight(..), weight, toList )
where
import Control.Applicative
import Control.Monad (ap, MonadPlus(..))
import Control.Arrow (first)
import Data.Ratio (Ratio)
import Data.Foldable (Foldable, foldMap, toList)
import Data.Traversable (Traversable, sequenceA)
import Data.Monoid (Monoid(..))
-- | A weighted search: a lazy sequence of results interleaved with
-- 'Weight' nodes. When two searches are combined with 'mplus', the one
-- whose leading 'Weight' is smaller is entered first.
data T w a
  = Fail
    -- ^ The search has no (further) results.
  | Yield a (T w a)
    -- ^ A result, followed by the remainder of the search.
  | Weight w (T w a)
    -- ^ A cost of the given weight, preceding the remainder of the search.

-- | Weights: ordered values that can be subtracted.
--
-- 'mplus' calls @'difference' a b@ only when @a > b@. It uses the result
-- to re-express the heavier of two competing weights relative to the
-- lighter one, so instances should return @a - b@.
class (Ord w) => Weight w where
  difference :: w -> w -> w

-- | Prefix a search with a cost of the given weight.
weight :: w -> T w a -> T w a
weight = Weight

-- Every numeric instance measures the gap between two weights with
-- ordinary subtraction; 'difference' must be a binary function on @w@.
instance Weight Int where difference = (-)
instance Weight Integer where difference = (-)
instance Weight Float where difference = (-)
instance Weight Double where difference = (-)
instance (Integral a) => Weight (Ratio a) where difference = (-)
-- | Apply a function to every result; 'Fail' and 'Weight' nodes keep
-- their position and weight.
instance Functor (T w) where
  fmap f = go
    where
      go Fail = Fail
      go (Yield x rest) = Yield (f x) (go rest)
      go (Weight cost rest) = Weight cost (go rest)
-- | Sequencing of weighted searches.
instance (Weight w) => Monad (T w) where
  -- A search with exactly one result and nothing after it.
  return x = Yield x Fail
  -- Run the continuation on each result and merge the resulting searches
  -- with 'mplus'; a 'Weight' node stays in front of everything bound
  -- beneath it.
  m >>= k = case m of
    Fail -> Fail
    Yield x rest -> k x `mplus` (rest >>= k)
    Weight cost rest -> Weight cost (rest >>= k)
-- | Choice between weighted searches.
--
-- A result at the front of either operand is emitted immediately, with the
-- left operand taking precedence. When both operands start with a
-- 'Weight', the lighter one is entered first and the heavier weight is
-- rewritten as its 'difference' from the lighter one. Equal weights are
-- merged under a single node.
instance (Weight w) => MonadPlus (T w) where
  mzero = Fail
  mplus l r = case l of
    Fail -> r
    Yield x l' -> Yield x (l' `mplus` r)
    Weight wl l' -> case r of
      Fail -> l
      Yield y r' -> Yield y (l `mplus` r')
      Weight wr r' -> case compare wl wr of
        LT -> Weight wl (l' `mplus` Weight (difference wr wl) r')
        EQ -> Weight wl (l' `mplus` r')
        GT -> Weight wr (Weight (difference wl wr) l' `mplus` r')
-- | 'pure' is the single-result search; '<*>' comes from the 'Monad'
-- instance via 'ap'.
instance (Weight w) => Applicative (T w) where
  pure x = Yield x Fail
  mf <*> mx = ap mf mx
-- | 'Alternative' delegates to the 'MonadPlus' instance.
instance (Weight w) => Alternative (T w) where
  empty = mzero
  l <|> r = mplus l r
-- | Fold over the results in the order they appear; weights are ignored.
instance Foldable (T w) where
  foldMap f = go
    where
      go t = case t of
        Fail -> mempty
        Yield a rest -> f a `mappend` go rest
        Weight _ rest -> go rest
-- | Sequence the effects stored at each result, front to back, rebuilding
-- the same 'Fail'/'Yield'/'Weight' structure.
instance Traversable (T w) where
  sequenceA t = case t of
    Fail -> pure Fail
    Yield fx rest -> Yield <$> fx <*> sequenceA rest
    Weight cost rest -> Weight cost <$> sequenceA rest