{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE ConstraintKinds #-}
{-# LANGUAGE DeriveGeneric #-}
{-# LANGUAGE DefaultSignatures #-}

module BayesStack.DirMulti
  ( -- * Dirichlet/multinomial pair
    Multinom, dirMulti, symDirMulti, multinom
    -- | Do not do record updates with these
  , dmTotal, dmAlpha, dmDomain
  , setMultinom, SetUnset (..)
  , decMultinom, incMultinom
  , prettyMultinom
  , updatePrior
    -- * Parameter estimation
  , estimatePrior, reestimatePriors, reestimateSymPriors
    -- * Convenience functions
  , probabilities, decProbabilities
  ) where

import Data.EnumMap (EnumMap)
import qualified Data.EnumMap as EM

import Data.Sequence (Seq)
import qualified Data.Sequence as SQ

import qualified Data.Foldable
import Data.Foldable (toList, Foldable, foldMap)
import Data.Function (on)

import Text.PrettyPrint
import Text.Printf

import GHC.Generics (Generic)
import Data.Serialize
import Data.Serialize.EnumMap ()
import Data.Serialize.LogFloat ()

import BayesStack.Core
import BayesStack.Dirichlet

import Data.Number.LogFloat hiding (realToFrac, isNaN, isInfinite)
import Numeric.Digamma
import Math.Gamma hiding (p)

-- | Make error handling a bit easier: return the value unchanged unless it
-- is NaN or infinite, in which case raise an error that names the calling
-- location @loc@.
checkNaN :: RealFloat a => String -> a -> a
checkNaN loc x
  | isNaN x      = error $ "BayesStack.DirMulti."++loc++": Not a number"
  | isInfinite x = error $ "BayesStack.DirMulti."++loc++": Infinity"
  | otherwise    = x

-- | Count adjusters meant to be passed to 'EM.alter'. A missing entry is
-- treated as a count of zero: incrementing it yields 1, decrementing a
-- count of 1 removes the entry, and decrementing a missing entry is an
-- error.
maybeInc, maybeDec :: Maybe Int -> Maybe Int
maybeInc Nothing  = Just 1
maybeInc (Just n) = Just (n+1)
maybeDec Nothing  = error "Can't decrement zero count"
maybeDec (Just 1) = Nothing
maybeDec (Just n) = Just (n-1)

{-# INLINEABLE decMultinom #-}
{-# INLINEABLE incMultinom #-}
-- | Remove ('decMultinom') or add ('incMultinom') one observation of the
-- given item, adjusting both its per-item count and the running total.
decMultinom, incMultinom :: (Ord a, Enum a) => a -> Multinom a -> Multinom a
decMultinom k dm = dm { dmCounts = EM.alter maybeDec k $ dmCounts dm
                      , dmTotal = dmTotal dm - 1 }
incMultinom k dm = dm { dmCounts = EM.alter maybeInc k $ dmCounts dm
                      , dmTotal = dmTotal dm + 1 }

-- | Whether an observation is being added ('Set') or removed ('Unset').
data SetUnset = Set | Unset

-- | Add or remove one observation depending on the 'SetUnset' flag.
setMultinom :: (Enum a, Ord a) => SetUnset -> a -> Multinom a -> Multinom a
setMultinom Set s   = incMultinom s
setMultinom Unset s = decMultinom s

-- | 'Multinom a' represents multinomial distribution over domain 'a'.
-- Optionally, this can include a collapsed Dirichlet prior.
--
-- * 'DirMulti' carries a Dirichlet prior ('dmAlpha').
-- * 'Multinom' carries fixed probabilities ('dmProbs') and no prior.
--
-- Both constructors track per-item observation counts ('dmCounts'), the
-- total number of observations ('dmTotal') and the domain ('dmDomain').
-- Note that 'dmAlpha' and 'dmProbs' are each defined for only one of the
-- two constructors.
data Multinom a = DirMulti { dmAlpha :: Alpha a
                           , dmCounts :: EnumMap a Int
                           , dmTotal :: !Int
                           , dmDomain :: Seq a
                           }
                | Multinom { dmProbs :: !(EnumMap a Double)
                           , dmCounts :: !(EnumMap a Int)
                           , dmTotal :: !Int
                           , dmDomain :: !(Seq a)
                           }
                deriving (Show, Eq, Generic)

instance (Enum a, Serialize a) => Serialize (Multinom a)

-- | 'symDirMultiFromPrecision d p' is a symmetric Dirichlet/multinomial over a
-- domain 'd' with precision 'p'
--
-- NOTE(review): the per-item concentration is always half the precision,
-- independent of the size of the domain. If precision is meant to be the
-- sum of the concentrations this only holds for a two-element domain --
-- confirm the intended definition before relying on this.
symDirMultiFromPrecision :: Enum a => [a] -> DirPrecision -> Multinom a
symDirMultiFromPrecision domain prec = symDirMulti (0.5*prec) domain

-- | 'dirMultiFromPrecision m p' is an asymmetric Dirichlet/multinomial
-- with mean 'm' and precision 'p'
dirMultiFromPrecision :: Enum a => DirMean a -> DirPrecision -> Multinom a
dirMultiFromPrecision m p = dirMultiFromAlpha $ meanPrecisionToAlpha m p

-- | Create a symmetric Dirichlet/multinomial
symDirMulti :: Enum a => Double -> [a] -> Multinom a
symDirMulti alpha domain = dirMultiFromAlpha $ symAlpha domain alpha

-- | A multinomial without a prior. The domain is the list of items in the
-- order given; counts start out empty.
multinom :: Enum a => [(a,Double)] -> Multinom a
multinom probs = Multinom { dmProbs = EM.fromList probs
                          , dmCounts = EM.empty
                          , dmTotal = 0
                          , dmDomain = SQ.fromList $ map fst probs
                          }

-- | Create an asymmetric Dirichlet/multinomial from items and alphas
dirMulti :: Enum a => [(a,Double)] -> Multinom a
dirMulti domain = dirMultiFromAlpha $ asymAlpha $ EM.fromList domain

-- | Create a Dirichlet/multinomial with a given prior and no observations
dirMultiFromAlpha :: Enum a => Alpha a -> Multinom a
dirMultiFromAlpha alpha = DirMulti { dmAlpha = alpha
                                   , dmCounts = EM.empty
                                   , dmTotal = 0
                                   , dmDomain = alphaDomain alpha
                                   }

-- | Number of observations of an item; zero if it has never been seen.
dmGetCounts :: Enum a => Multinom a -> a -> Int
dmGetCounts dm k = EM.findWithDefault 0 k (dmCounts dm)

instance HasLikelihood Multinom where
  type LContext Multinom a = (Ord a, Enum a)

  -- Without a prior: product over observed items of probability^count.
  likelihood dm@(Multinom {}) =
      product $ map (\(k,n) -> realToFrac (dmProbs dm EM.! k) ^ n)
              $ EM.assocs $ dmCounts dm
  -- With a prior: product over the domain of Gamma(count + alpha_k),
  -- divided by the alpha normalizer and by Gamma(total + sum of alphas).
  -- The Gamma terms are computed as log-Gamma and lifted into 'LogFloat'.
  likelihood dm =
      let alpha = dmAlpha dm
          factor k = logToLogFloat $ checkNaN "likelihood(factor)"
                     $ lnGamma (realToFrac (dmGetCounts dm k) + alpha `alphaOf` k)
      in 1 / alphaNormalizer alpha
         * product (map factor $ toList $ dmDomain dm)
         / logToLogFloat (checkNaN "likelihood" $ lnGamma $ realToFrac (dmTotal dm) + sumAlpha alpha)
  {-# INLINEABLE likelihood #-}

  prob dm@(Multinom {}) k = realToFrac $ dmProbs dm EM.! k
  -- NOTE(review): this is the 'likelihood' expression with the product
  -- replaced by the single factor for @k@; it is not the normalized
  -- predictive probability computed by 'sampleProb' -- confirm intent.
  prob dm k =
      let alpha = dmAlpha dm
          factor j = logToLogFloat $ checkNaN "prob(factor)"
                     $ lnGamma (realToFrac (dmGetCounts dm j) + alpha `alphaOf` j)
      in 1 / alphaNormalizer alpha
         * factor k
         / logToLogFloat (checkNaN "prob" $ lnGamma $ realToFrac (dmTotal dm) + sumAlpha alpha)
  {-# INLINEABLE prob #-}

instance FullConditionable Multinom where
  type FCContext Multinom a = (Ord a, Enum a)

  -- Without a prior the stored probability is returned directly.
  sampleProb (Multinom {dmProbs=ps}) k = ps EM.! k
  -- With a prior: (count_k + alpha_k) / (total + sum of alphas).
  sampleProb dm@(DirMulti {dmAlpha=a}) k =
      let alpha = a `alphaOf` k
          n = realToFrac $ dmGetCounts dm k
          total = realToFrac $ dmTotal dm
      in (n + alpha) / (total + sumAlpha a)
  {-# INLINEABLE sampleProb #-}

{-# INLINEABLE probabilities #-}
-- | The 'sampleProb' of every item of the domain, in domain order.
probabilities :: (Ord a, Enum a) => Multinom a -> Seq (Double, a)
probabilities dm = fmap (\a -> (sampleProb dm a, a)) $ dmDomain dm -- FIXME

-- | Probabilities sorted decreasingly
decProbabilities :: (Ord a, Enum a) => Multinom a -> Seq (Double, a)
decProbabilities = SQ.sortBy (flip (compare `on` fst)) . probabilities

-- | Render a multinomial as a header line followed by its @n@ most probable
-- items, each shown with @showA@ and its probability in scientific
-- notation. For a 'DirMulti' the header also shows the prior; a plain
-- 'Multinom' has no prior, so only the constructor name is shown.
prettyMultinom :: (Ord a, Enum a) => Int -> (a -> String) -> Multinom a -> Doc
prettyMultinom n showA dm =
    header $$ nest 5 (fsep $ punctuate comma $ map prettyEntry topEntries)
  where
    header = case dm of
      DirMulti {} -> text "DirMulti" <+> parens (text "alpha=" <> prettyAlpha showA (dmAlpha dm))
      Multinom {} -> text "Multinom"
    topEntries = take n $ Data.Foldable.toList $ decProbabilities dm
    prettyEntry (p, a) = text (showA a) <> parens (text $ printf "%1.2e" p)

-- | Update the prior of a Dirichlet/multinomial. Not yet defined for a
-- multinomial without a prior, for which it raises an error.
updatePrior :: (Alpha a -> Alpha a) -> Multinom a -> Multinom a
updatePrior _ (Multinom {}) = error "TODO: updatePrior"
updatePrior f dm = dm {dmAlpha=f $ dmAlpha dm}

-- | Relative tolerance in precision for prior estimation
estimationTol :: Double
estimationTol = 1e-8

-- | Shared implementation of 'reestimatePriors' and 'reestimateSymPriors'.
--
-- Only multinomials with more than 5 observations take part in the
-- estimate. If 3 or fewer of them qualify, every prior is left as it is;
-- otherwise every prior is replaced by the estimate after applying @post@.
reestimateWith :: (Foldable f, Functor f, Enum a)
               => (Alpha a -> Alpha a) -> f (Multinom a) -> f (Multinom a)
reestimateWith post dms = fmap (updatePrior update) dms
  where
    usableDms = filter (\dm -> dmTotal dm > 5) $ toList dms
    update
      | length usableDms <= 3 = id
      | otherwise             = const $ post $ estimatePrior estimationTol usableDms

-- | Re-estimate a common asymmetric prior for a collection of
-- Dirichlet/multinomials and install it in each of them.
reestimatePriors :: (Foldable f, Functor f, Enum a)
                 => f (Multinom a) -> f (Multinom a)
reestimatePriors = reestimateWith id

-- | As 'reestimatePriors', but the estimated prior is passed through
-- 'symmetrizeAlpha' before it is installed.
reestimateSymPriors :: (Foldable f, Functor f, Enum a)
                    => f (Multinom a) -> f (Multinom a)
reestimateSymPriors = reestimateWith symmetrizeAlpha

-- | One update step of the prior estimate. For every item @k@ of the domain
-- of the first multinomial, the new concentration is
--
-- > alpha_k * num_k / denom
--
-- where @num_k@ sums @digamma (n_ik + alpha_k) - digamma alpha_k@ over the
-- multinomials @i@ that have seen @k@, and @denom@ sums
-- @digamma (n_i + sum alpha) - digamma (sum alpha)@ over all multinomials.
-- Raises an error on an empty list or when the update is not finite.
estimatePrior' :: (Enum a) => [Multinom a] -> Alpha a -> Alpha a
estimatePrior' [] _ =
    error "BayesStack.DirMulti.estimatePrior': no multinomials given"
estimatePrior' dms@(dm0:_) alpha =
    asymAlpha $ foldMap (\k -> EM.singleton k (f k)) domain
  where
    domain = toList $ dmDomain dm0

    -- These do not depend on the item, so compute them once.
    alphaSum = sum $ map (alphaOf alpha) domain
    total i  = realToFrac $ sum $ map (dmGetCounts i) domain
    denom    = sum $ map (\i -> digamma (total i + alphaSum) - digamma alphaSum) dms

    f k
      | isNaN num      = error $ "BayesStack.DirMulti.estimatePrior': num = NaN: "++show (map (\i -> (digamma (realToFrac (dmGetCounts i k) + a), digamma a)) dms)
      | denom == 0     = error "BayesStack.DirMulti.estimatePrior': denom=0"
      | isInfinite num = error "BayesStack.DirMulti.estimatePrior': num is infinity "
      | isNaN updated  = error $ "NaN"++show (num, denom)
      | otherwise      = updated
      where
        a       = alphaOf alpha k
        num     = sum $ map (\i -> digamma (realToFrac (dmGetCounts i k) + a) - digamma a)
                      $ filter (\i -> dmGetCounts i k > 0) dms
        updated = a * num / denom

-- | Estimate the prior alpha from a set of Dirichlet/multinomials.
--
-- Starts from the prior of the first multinomial and repeats
-- 'estimatePrior'' until the relative change in precision between two
-- successive estimates is no larger than @tol@. Raises an error on an
-- empty list.
estimatePrior :: (Enum a) => Double -> [Multinom a] -> Alpha a
estimatePrior _ [] =
    error "BayesStack.DirMulti.estimatePrior: no multinomials given"
estimatePrior tol dms@(dm0:_) = iter $ dmAlpha dm0
  where
    iter alpha
      | abs ((prec' - prec) / prec) > tol = iter alpha'
      | otherwise                         = alpha'
      where
        alpha'     = estimatePrior' dms alpha
        (_, prec)  = alphaToMeanPrecision alpha
        (_, prec') = alphaToMeanPrecision alpha'