{-# LANGUAGE TupleSections, Rank2Types, NoMonomorphismRestriction #-} module Numeric.MaxEnt.Moment ( ExpectationConstraint, (.=.), ExpectationFunction, average, variance, maxent, UU(..) ) where import Numeric.Optimization.Algorithms.HagerZhang05 import qualified Data.Vector.Unboxed as U import qualified Data.Vector.Storable as S import Numeric.AD import GHC.IO (unsafePerformIO) import Data.Traversable import Numeric.AD.Types import Numeric.AD.Internal.Classes import Data.List (transpose) import Control.Applicative import Numeric.MaxEnt.ConjugateGradient import Data.List (foldl') --import Data.Vector -- | Constraint type. A function and the constant it equals. -- -- Think of it as the pair @(f, c)@ in the constraint -- -- @ -- Σ pₐ f(xₐ) = c -- @ -- -- such that we are summing over all values . -- -- For example, for a variance constraint the @f@ would be @(\\x -> x*x)@ and @c@ would be the variance. type ExpectationConstraint a = (UU a, a) -- infixr 1 .=. (.=.) :: (forall s. Mode s => AD s a -> AD s a) -> a -> ExpectationConstraint a f .=. c = (UU f, c) -- | A function that takes an index and value and returns a value. -- See 'average' and 'variance' for examples. type ExpectationFunction a = (a -> a) newtype UU a = UU {unUU :: forall s. Mode s => ExpectationFunction (AD s a) } -- The average constraint average :: Num a => a -> ExpectationConstraint a average m = id .=. m -- The variance constraint variance :: Num a => a -> ExpectationConstraint a variance sigma = (^(2 :: Int)) .=. sigma --partialPart' ls fs x = exp . negate . S.sum . S.zipWith (\l f -> l * f x) ls $ fs --partitionFunc' values fs ls = S.sum . S.map (partialPart' ls fs) $ values probs values fs ls = result where lsList = S.toList ls norm = partitionFunc values fs lsList result = S.map (\x -> partialPart lsList fs x / norm) $ S.fromList values partialPart ls fs x = exp . negate . sum . zipWith (\l f -> l * f x) ls $ fs partitionFunc values fs ls = sum . map (partialPart ls fs) $ values objectiveFunc fs moments values ls = log (partitionFunc values fs ls) + (sum $ zipWith (*) ls moments) -- | Discrete maximum entropy solver where the constraints are all moment constraints. maxent :: Double -- ^ Tolerance for the numerical solver -> [Double] -- ^ values that the distributions is over -> [ExpectationConstraint Double] -- ^ The constraints -> Either (Result, Statistics) (S.Vector Double) -- ^ Either the a discription of what wrong or the probability distribution maxent tolerance values constraints = result where obj = objectiveFunc (map unUU fs') (map auto moments) (map auto values) count = length fs (fs', moments) = unzip constraints fs = map (\x -> lowerUU $ unUU x) fs' guess = U.fromList $ replicate count (1.0 / fromIntegral count :: Double) result = probs values fs <$> minimize tolerance count obj