{-# LANGUAGE TupleSections, Rank2Types, NoMonomorphismRestriction #-}
module Numeric.MaxEnt.Moment (
        ExpectationConstraint,
        (.=.),
        ExpectationFunction,
        average,
        variance,
        maxent,
        UU(..)
    ) where
import Numeric.Optimization.Algorithms.HagerZhang05
import qualified Data.Vector.Unboxed as U
import qualified Data.Vector.Storable as S
import Numeric.AD
import GHC.IO                   (unsafePerformIO)
import Data.Traversable
import Numeric.AD.Types
import Numeric.AD.Internal.Classes
import Data.List (transpose)
import Control.Applicative
import Numeric.MaxEnt.ConjugateGradient
import Data.List (foldl')
--import Data.Vector

-- | Constraint type: a function together with the constant its expectation
--   must equal.
--
--   Think of it as the pair @(f, c)@ in the constraint
--
-- @
--     Σₐ pₐ f(xₐ) = c
-- @
--
--   where the sum runs over every value @xₐ@ the distribution is defined on.
--
--   For example, a second-moment constraint uses @f = (\\x -> x*x)@ and @c@
--   is the required value of @E[x²]@; this is the variance only when the
--   distribution's mean is zero.
type ExpectationConstraint a = (UU a, a)

infixr 1 .=.

-- | Pair a mode-polymorphic function with the value its expectation must
--   equal. For example, @id .=. 2.5@ constrains the mean to 2.5.
(.=.) :: (forall s. Mode s => AD s a -> AD s a) -> a -> ExpectationConstraint a
(.=.) f = (UU f,)

-- | A function from a value of the distribution's support to the quantity
--   whose expectation is being constrained (e.g. @id@ for the mean).
--   See 'average' and 'variance' for examples.
type ExpectationFunction a = (a -> a)

-- | Wrapper that keeps an 'ExpectationFunction' polymorphic over every AD
--   'Mode'. The same function can then be applied to AD numbers (to build
--   the objective in 'maxent') and, via 'lowerUU', to plain values.
newtype UU a = UU {unUU :: forall s. Mode s => ExpectationFunction (AD s a) }

-- | Mean constraint: requires @Σₐ pₐ xₐ@ to equal the given value.
average :: Num a => a -> ExpectationConstraint a
average = (id .=.)

-- | Second-moment constraint: requires @Σₐ pₐ xₐ²@ to equal the given value.
--
--   NOTE: this is the raw second moment @E[x²]@. It coincides with the
--   variance only when the mean is zero (e.g. together with @average 0@);
--   otherwise @E[x²] = variance + mean²@. The argument is therefore in squared
--   units (σ², not σ).
variance :: Num a => a -> ExpectationConstraint a
variance = ((^ (2 :: Int)) .=.)

-- | Probability of the @k@-th element (0-based) of @values@ under the
--   distribution with multipliers @ls@:
--   @exp (-Σⱼ lⱼ fⱼ(x_k)) / partitionFunc values fs ls@.
--   Calls '!!', so @k@ must be a valid index whenever @fs@ is non-empty.
pOfK :: [Double] -> [ExpectationFunction Double] -> S.Vector Double -> Int -> Double
pOfK values fs ls k = numerator / normalizer
  where
    multipliers = S.toList ls
    xk          = values !! k
    numerator   = exp (negate (sum [l * f xk | (l, f) <- zip multipliers fs]))
    normalizer  = partitionFunc values fs multipliers

-- | Probability distribution over @values@ (same order) for the Lagrange
--   multipliers @ls@: @pₖ ∝ exp (-Σⱼ lⱼ fⱼ(xₖ))@.
--
--   Each unnormalized weight is computed once and divided by the sum of all
--   weights. The result always sums to 1, and the cost is O(n·m) rather than
--   the O(n²·m) of indexing with '!!' and recomputing the normalizer for every
--   element.
probs :: [Double] -> [ExpectationFunction Double] -> S.Vector Double -> S.Vector Double
probs values fs ls = S.fromList (map (/ z) weights)
  where
    lsList  = S.toList ls
    -- Unnormalized weight of each value; same expression as pOfK's numerator.
    weights = [exp (negate (sum (zipWith (\l f -> l * f x) lsList fs))) | x <- values]
    -- Normalizer (partition function); only forced when values is non-empty.
    z       = sum weights

-- | Partition function (normalizing constant) of the maximum-entropy
--   distribution with Lagrange multipliers @ls@:
--
-- @
--     Z(λ) = Σₐ exp (-Σⱼ λⱼ fⱼ(xₐ))
-- @
--
--   The exponent sums over all constraints for each value. (Summing
--   @exp (-λⱼ fⱼ(xₐ))@ over each constraint separately is only correct for a
--   single constraint.) It is polymorphic in the number type, so it serves
--   both the 'Double' probabilities and the AD-typed objective.
partitionFunc :: Floating b => [a] -> [a -> b] -> [b] -> b
partitionFunc values fs ls =
    sum [exp (negate (sum (zipWith (\f l -> l * f x) fs ls))) | x <- values]

-- | Dual objective of the maximum-entropy problem,
--   @log Z(λ) + Σⱼ λⱼ μⱼ@, where @μⱼ@ are the target @moments@ and @ls@ the
--   multipliers λ. It is minimized over @ls@ by 'maxent'.
objectiveFunc :: Floating b => [a -> b] -> [b] -> [a] -> [b] -> b
objectiveFunc fs moments values ls = logZ + penalty
  where
    logZ    = log (partitionFunc values fs ls)
    penalty = sum (zipWith (*) ls moments)

-- TODO: consider splitting this function up.

-- | Discrete maximum entropy solver where every constraint is a moment
--   (expectation) constraint.
--
--   One Lagrange multiplier per constraint is found by minimizing the dual
--   objective ('objectiveFunc') with 'minimize'. The multipliers are then
--   turned into probabilities over @values@.
maxent :: Double
       -- ^ Tolerance for the numerical solver
       -> [Double]
       -- ^ Values that the distribution is over
       -> [ExpectationConstraint Double]
       -- ^ The constraints
       -> Either (Result, Statistics) (S.Vector Double)
       -- ^ Either the solver's 'Result' and 'Statistics' (presumably when it
       --   fails to converge; see "Numeric.MaxEnt.ConjugateGradient"), or
       --   the probabilities in the same order as @values@
maxent tolerance values constraints = result where
    (fs', moments) = unzip constraints

    -- Constraint functions lowered to plain Doubles, used for the final
    -- probabilities.
    fs = map (\x -> lowerUU (unUU x)) fs'

    -- Number of multipliers to solve for: one per constraint.
    count = length fs

    -- Objective over AD numbers, so the solver can work with it; the
    -- moments and values are lifted with 'auto'.
    obj = objectiveFunc (map unUU fs') (map auto moments) (map auto values)

    result = probs values fs <$> minimize tolerance count obj