---------------------------------------------------------------------------
-- | Module    : Math.Statistics.Dirichlet.Density
-- Copyright   : (c) 2009-2012 Felipe Lessa
--
-- Maintainer  : felipe.lessa@gmail.com
-- Stability   : experimental
-- Portability : portable
--
--------------------------------------------------------------------------

module Math.Statistics.Dirichlet.Density
( DirichletDensity(..)
, empty
, fromList
, toList
, derive
, cost
) where

import qualified Data.Vector as V
import qualified Data.Vector.Generic as G
import qualified Data.Vector.Unboxed as U

import Control.DeepSeq (NFData(..))
import Numeric.GSL.Special.Gamma (lngamma)
import Numeric.GSL.Special.Psi (psi)

import Math.Statistics.Dirichlet.Options
import Math.Statistics.Dirichlet.Util

-- | A Dirichlet density, stored as an unboxed vector of 'Double's
-- wrapped by the 'DD' constructor.
newtype DirichletDensity = DD
    { unDD :: U.Vector Double
      -- ^ Unwrap the underlying vector.
    }
    deriving (Eq)

-- | Renders a density as @fromList [..]@, adding surrounding
-- parentheses when the context precedence is above function
-- application (10).
instance Show DirichletDensity where
    showsPrec outerPrec (DD alphas) = showParen needsParens rendered
      where
        needsParens = outerPrec > 10
        rendered    = showString "fromList " . showsPrec 11 (U.toList alphas)

-- | Parses the textual form @fromList [a1,a2,...]@, optionally
-- wrapped in parentheses.  Input that does not have this shape
-- yields no parse (an empty result list) instead of a
-- pattern-match failure.
instance Read DirichletDensity where
    -- @readParen False@ keeps the parentheses optional regardless
    -- of the surrounding precedence, so both the bare and the
    -- parenthesised spelling are accepted everywhere.
    readsPrec _ = readParen False $ \input ->
        [ (fromList alphas, rest)
          -- 'lex' skips leading white space and must yield exactly
          -- the keyword @fromList@.
        | ("fromList", afterKeyword) <- lex input
          -- The remainder must start with a list of 'Double's.
        , (alphas, rest)             <- readsPrec 11 afterKeyword
        ]

-- | Forces the wrapped vector.
--
-- 'DD' is a @newtype@ constructor, so pattern matching on it never
-- evaluates anything; the vector itself has to be forced
-- explicitly.  Evaluating an unboxed vector to weak head normal
-- form is enough, because its elements are stored unboxed and
-- therefore cannot be unevaluated thunks.
instance NFData DirichletDensity where
    rnf (DD v) = v `seq` ()

-- | @empty n x@ builds an \"empty\" Dirichlet density holding @n@
-- alphas, every one of them equal to @x@.
empty :: Int -> Double -> DirichletDensity
empty size alpha = DD (U.replicate size alpha)
{-# INLINE empty #-}

-- | @fromList xs@ wraps the alpha values in the list @xs@ into a
-- Dirichlet density, keeping their order.
fromList :: [Double] -> DirichletDensity
fromList alphas = DD (U.fromList alphas)
{-# INLINE fromList #-}

-- | @toList d@ returns the alpha values of the Dirichlet density
-- @d@ as a list.
toList :: DirichletDensity -> [Double]
toList = U.toList . unDD
{-# INLINE toList #-}

-- | Derive a Dirichlet density using a maximum likelihood method
-- as described by Karplus et al (equation 26).  All training
-- vectors should have the same length, however this is not
-- verified.
--
-- Every iteration updates the logarithm of each alpha by a
-- step-scaled amount and re-exponentiates it.  The cost is only
-- recomputed on iterations whose counter is a multiple of
-- @deltaSteps@, and the training can only stop on such an
-- iteration.
--
-- The returned 'Result' carries the stop reason ('Delta' or
-- 'MaxIter'), the iteration counter, the last delta, the last
-- cost and the trained density.
--
-- Calls 'error' (with a message prefixed by @Dirichlet.derive: @)
-- when the training data or the initial density is empty, when
-- @maxIter@ or @deltaSteps@ is smaller than one, when @minDelta@
-- is negative, or when the step is not strictly between zero and
-- one.
derive :: DirichletDensity -> Predicate -> StepSize
       -> TrainingVectors -> Result DirichletDensity
derive (DD initial) (Pred maxIter' minDelta_ deltaSteps' _ _)
                    (Step step) trainingData
    | V.null trainingData = err "empty training data"
    | U.null initial      = err "empty initial vector"
    | maxIter' < 1        = err "non-positive maxIter"
    | minDelta_ < 0       = err "negative minDelta"
    | deltaSteps' < 1     = err "non-positive deltaSteps"
    | step <= 0           = err "non-positive step"
    -- A step of exactly one is rejected here as well, even though
    -- the message only mentions "greater than".
    | step >= 1           = err "step greater than one"
    | otherwise           = train
    where
      err = error . ("Dirichlet.derive: " ++)

      -- Compensate the different deltaSteps.
      !minDelta'    = minDelta_ * fromIntegral deltaSteps'

      -- Number of training sequences.
      !trainingSize = fromIntegral $ V.length trainingData

      -- Sums of each training sequence.
      trainingSums :: U.Vector Double
      !trainingSums = G.unstream $ G.stream $ V.map U.sum trainingData

      -- Functions that work on the alphas only (and not their logs).
      calcSumAs = U.sum . snd . U.unzip
      finish    = DD    . snd . U.unzip

      -- Start training with the iteration counter at one and with
      -- 'infinity' as the previous cost.  Each alpha is paired
      -- with its logarithm, since the updates below are applied
      -- to the logarithms.
      train = train' 1 infinity (U.sum initial) $
              U.map (\x -> (log x, x)) initial

      train' !iter !oldCost !sumAs !alphas =
        -- Reestimate alpha's.
        let !alphas'  = U.imap calculateAlphas alphas
            !psiSumAs = psi sumAs
            !psiSums  = U.sum $ U.map (\sumT -> psi $ sumT + sumAs) trainingSums
            -- @w@ is the logarithm of the alpha @a@ at index @i@.
            calculateAlphas !i (!w, !a) =
              let !s1 = trainingSize * (psiSumAs - psi a)
                  !s2 = V.sum $ V.map (\t -> psi $ t U.! i + a) trainingData
                  !w' = w + step * a * (s1 + s2 - psiSums)
                  !a' = exp w'
              in (w', a')

            -- Recalculate constants.
            !sumAs'   = calcSumAs alphas'
            !calcCost = iter `mod` deltaSteps' == 0
            !cost'    = if calcCost then newCost else oldCost
              where newCost = costWorker (snd $ U.unzip alphas') sumAs'
                                         trainingData trainingSums
            !delta    = abs (cost' - oldCost)

        -- Verify convergence.  Even with MaxIter we only stop
        -- iterating if the delta was calculated.  Otherwise we
        -- wouldn't be able to tell the caller why the delta was
        -- still big when we reached the limit.
        in case (calcCost, delta <= minDelta', iter >= maxIter') of
             (True, True, _) -> Result Delta   iter delta cost' $ finish alphas'
             (True, _, True) -> Result MaxIter iter delta cost' $ finish alphas'
             _               -> train' (iter+1) cost' sumAs' alphas'

-- | Cost function for deriving a Dirichlet density (equation
-- 18).  'derive' minimizes this function.
cost :: TrainingVectors -> DirichletDensity -> Double
cost trainingData (DD alphas) =
    costWorker alphas alphaTotal trainingData vectorTotals
  where
    -- Sum of all alphas.
    alphaTotal = U.sum alphas

    -- Sum of the components of every training vector.
    vectorTotals :: U.Vector Double
    vectorTotals = G.unstream . G.stream $ V.map U.sum trainingData

-- | Worker shared by 'cost' and 'derive'.  Besides the alphas
-- and the training vectors it takes the sum of the alphas and
-- the sum of every training vector, so that 'derive' does not
-- have to recalculate the latter each time it needs the cost.
costWorker :: U.Vector Double -> Double -> TrainingVectors -> U.Vector Double -> Double
costWorker !alphas !sumAs !trainingData !trainingSums =
    negate (componentTotal + sumTotal + normTotal)
  where
    !logGammaSumAs = lngamma sumAs

    -- Contribution of the individual components of every
    -- training vector.
    componentTotal = V.sum (V.map perVector trainingData)
    perVector counts = U.sum (U.zipWith perComponent counts alphas)
    perComponent n a = lngamma (n + a) - lngamma (n + 1) - lngamma a

    -- Contribution of the per-vector sums.
    sumTotal = U.sum (U.map perSum trainingSums)
    perSum total = lngamma (total + 1) - lngamma (total + sumAs)

    -- The alpha-sum term, counted once per training vector.
    normTotal = logGammaSumAs * fromIntegral (U.length trainingSums)
{-# INLINE costWorker #-}