-- | This is a nondeterminism monad which allows you to give computations
-- weights, such that the lowest-weight computations will be returned first.
-- This allows you to search infinite spaces productively, by guarding
-- recursive calls with weights.  Example:
--
-- > import qualified Control.Monad.WeightedSearch as W
-- > import Control.Applicative
-- > 
-- > -- All naturals, weighted by the size of the number
-- > naturals :: W.T Integer Integer
-- > naturals = go 0
-- >     where
-- >     go n = pure n <|> W.weight 1 (go $! n+1)
-- > 
-- > -- All finite lists, weighted by the length of the list
-- > finiteLists :: W.T Integer a -> W.T Integer [a]
-- > finiteLists w = pure [] <|> W.weight 1 ((:) <$> w <*> finiteLists w)
-- > 
-- > -- A list of all finite lists of naturals
-- > finiteListsOfNaturals = W.toList (finiteLists naturals)
-- >    -- [ [], [0], [0,0], [1], [0,0,0], [0,1], [1,0], [2], [0,0,0,0], [0,0,1], ... ]
--
-- Weights must be strictly positive for this to be well-defined.

module Control.Monad.WeightedSearch 
    ( T, Weight(..), weight, toList )
where

import Control.Applicative
import Control.Monad (ap, MonadPlus(..))
import Control.Arrow (first)
import Data.Ratio (Ratio)
import Data.Foldable (Foldable, foldMap, toList)
import Data.Traversable (Traversable, sequenceA)
import Data.Monoid (Monoid(..))

-- | Weighted nondeterministic computations over the weight @w@.
--
-- A computation is a (possibly infinite) stream of results interleaved with
-- weight markers.  'toList' (via 'Foldable') returns the results in stream
-- order and skips the markers.  The lowest-weight-first ordering comes from
-- how the choice operator merges two streams.
data T w a
    = Fail                -- ^ No (further) results.
    | Yield a (T w a)     -- ^ A result, followed by the rest of the computation.
    | Weight w (T w a)    -- ^ The rest of the computation, delayed by the given weight.

-- | The class of positive weights. We need to know how to subtract.  Weights
-- must be strictly positive.
--
-- 'Ord' decides which of two weighted alternatives is explored first.
-- 'difference' re-expresses the heavier alternative's weight relative to
-- the lighter one.
class (Ord w) => Weight w where
    -- | @difference a b@ is how much heavier @a@ is than @b@.  Within this
    -- module it is only called with @a > b@.  Its result is used as a weight
    -- again, so it must itself be strictly positive.
    difference :: w -> w -> w

-- | Delay a computation by a strictly positive weight.  When it is combined
-- with alternatives, its results come after those of any branch whose
-- accumulated weight is smaller.
weight :: w -> T w a -> T w a
weight cost rest = Weight cost rest

-- Built-in numeric weights: the gap between two weights is ordinary
-- subtraction.  These instances do not check positivity; only strictly
-- positive weights should be used with them.
instance Weight Int where
    difference heavier lighter = heavier - lighter
instance Weight Integer where
    difference heavier lighter = heavier - lighter
instance Weight Float where
    difference heavier lighter = heavier - lighter
instance Weight Double where
    difference heavier lighter = heavier - lighter
instance (Integral a) => Weight (Ratio a) where
    difference heavier lighter = heavier - lighter

-- | Apply a function to every result; weight markers are left in place.
instance Functor (T w) where
    fmap f = go
      where
        go Fail          = Fail
        go (Yield x r)   = Yield (f x) (go r)
        go (Weight wt r) = Weight wt (go r)

-- | 'pure' yields exactly one result, with no weight in front of it.
-- '<*>' is derived from the monadic bind.
instance (Weight w) => Applicative (T w) where
    pure x = Yield x Fail
    (<*>) = ap

-- | Each result @x@ is replaced by @f x@.  That is merged, with '<|>', with
-- the bind over the rest of the stream.  Weight markers stay where they are.
instance (Weight w) => Monad (T w) where
    -- Canonical definition: the implementation lives in 'pure'.
    return = pure
    Fail       >>= _ = Fail
    Yield x m  >>= f = f x <|> (m >>= f)
    Weight w m >>= f = Weight w (m >>= f)

-- | Choice merges two computations in order of accumulated weight.
--
-- * A result available now ('Yield') on either side is emitted before
--   anything sitting behind a 'Weight'; the left side wins when both have one.
-- * When both sides are behind weights, the lighter weight is emitted first.
--   The heavier side is re-weighted by the 'difference'.
-- * Equal weights are merged, and the left side is continued first.
instance (Weight w) => Alternative (T w) where
    empty = Fail

    Fail       <|> n          = n
    Yield x m  <|> n          = Yield x (m <|> n)
    Weight w m <|> Fail       = Weight w m
    Weight w m <|> Yield x n  = Yield x (Weight w m <|> n)
    Weight w m <|> Weight w' n
        = case compare w w' of
            LT -> Weight w (m <|> Weight (difference w' w) n)
            EQ -> Weight w (m <|> n)
            GT -> Weight w' (Weight (difference w w') m <|> n)

-- | Identical to the 'Alternative' instance (canonical delegation).
instance (Weight w) => MonadPlus (T w) where
    mzero = empty
    mplus = (<|>)

-- | Folds over the results in stream order.  Weight markers contribute
-- nothing and are skipped.
instance Foldable (T w) where
    foldMap f = go
      where
        go Fail         = mempty
        go (Yield x r)  = f x `mappend` go r
        go (Weight _ r) = go r

-- | Runs the effects of the results in stream order.  The same shape,
-- including every weight marker, is rebuilt inside the applicative.
instance Traversable (T w) where
    sequenceA = go
      where
        go Fail          = pure Fail
        go (Yield fx r)  = Yield <$> fx <*> go r
        go (Weight wt r) = Weight wt <$> go r