{-# LANGUAGE DeriveGeneric #-}

module BayesStack.Dirichlet ( -- * Dirichlet parameter
                              Alpha
                            , symAlpha, asymAlpha
                            , alphaDomain, alphaNormalizer, sumAlpha
                            , DirMean, DirPrecision
                            , alphaOf, setAlphaOf, setSymAlpha
                            , alphaToMeanPrecision, meanPrecisionToAlpha
                            , symmetrizeAlpha
                            , prettyAlpha
                            ) where

import Data.Foldable (toList, Foldable, fold)

import Data.EnumMap (EnumMap)
import qualified Data.EnumMap as EM

import Data.Sequence (Seq)
import qualified Data.Sequence as SQ

import Data.Number.LogFloat hiding (realToFrac, isNaN, isInfinite)
import Math.Gamma

import Text.Printf
import Text.PrettyPrint

import Data.Serialize
import Data.Serialize.EnumMap ()
import Data.Serialize.LogFloat ()
import GHC.Generics (Generic)

-- | Make error handling a bit easier: pass a value through unchanged unless
-- it is NaN or infinite, in which case 'error' is raised with a message
-- naming the location @loc@ inside this module.
checkNaN :: RealFloat a => String -> a -> a
checkNaN loc x
  | isNaN x      = error $ "BayesStack.Dirichlet."++loc++": Not a number"
  | isInfinite x = error $ "BayesStack.Dirichlet."++loc++": Infinity"
  | otherwise    = x

-- | A Dirichlet prior
data Alpha a = SymAlpha { aDomain :: Seq a      -- ^ elements the prior ranges over
                        , aAlpha :: !Double     -- ^ the single value shared by every element
                        , aNorm :: LogFloat     -- ^ cached normalizer; kept lazy so it can
                                                --   refer to the record it lives in
                        }
             | Alpha { aAlphas :: EnumMap a Double  -- ^ per-element values
                     , aSumAlphas :: !Double        -- ^ cached sum of the per-element values
                     , aNorm :: LogFloat            -- ^ cached normalizer (lazy, as above)
                     }
             deriving (Show, Eq, Generic)

instance (Enum a, Serialize a) => Serialize (Alpha a)

-- | Mean of a prior, as a map from element to its share.
type DirMean a = EnumMap a Double

-- | Precision of a prior.
type DirPrecision = Double

-- | Construct a symmetric Alpha over the given domain, with every element
-- sharing the given value. Calls 'error' when the domain is empty.
symAlpha :: Enum a => [a] -> Double -> Alpha a
symAlpha domain _ | null domain = error "Dirichlet over null domain is undefined"
symAlpha domain alpha = prior
  where
    -- The lazy 'aNorm' field refers back to the record under construction,
    -- so the domain sequence is built only once (same trick as 'setSymAlpha').
    prior = SymAlpha { aDomain = SQ.fromList domain
                     , aAlpha = alpha
                     , aNorm = alphaNorm prior
                     }

-- | Construct an asymmetric Alpha. Calls 'error' when the map is empty.
asymAlpha :: Enum a => EnumMap a Double -> Alpha a
asymAlpha alphas | EM.null alphas = error "Dirichlet over null domain is undefined"
asymAlpha alphas = prior
  where
    -- As in 'symAlpha', the normalizer is tied to this very record so the
    -- strict sum is not recomputed when the normalizer is forced.
    prior = Alpha { aAlphas = alphas
                  , aSumAlphas = sum $ EM.elems alphas
                  , aNorm = alphaNorm prior
                  }

-- | 'setSymAlpha x a' symmetrizes the prior 'a' and sets the shared value of
-- every element to 'x', recomputing the normalizer accordingly.
setSymAlpha :: Enum a => Double -> Alpha a -> Alpha a
setSymAlpha alpha a =
  let b = (symmetrizeAlpha a) { aAlpha = alpha
                              , aNorm = alphaNorm b
                              }
  in b

-- | Compute the normalizer of the likelihood involving alphas,
-- (product_k gamma(alpha_k)) / gamma(sum_k alpha_k)
--
-- Every log-gamma value is passed through 'checkNaN', so a NaN or infinite
-- intermediate result raises an 'error'.
alphaNorm :: Enum a => Alpha a -> LogFloat
alphaNorm alpha = normNum / normDenom
  where
    normNum = case alpha of
      Alpha {aAlphas = alphas} ->
        product $ map (\a->logToLogFloat
                           $ checkNaN ("alphaNorm.normNum(asym) alpha="++show a)
                           $ lnGamma a
                      ) $ EM.elems alphas
      -- Fields are taken by pattern match so the partial selector 'aDomain'
      -- is never applied to an asymmetric prior.
      SymAlpha {aDomain = dom, aAlpha = a} ->
        logToLogFloat $ checkNaN "alphaNorm.normNum(sym)"
                      $ realToFrac (SQ.length dom) * lnGamma a
    normDenom = logToLogFloat $ checkNaN "alphaNorm.normDenom"
                              $ lnGamma $ sumAlpha alpha

-- | 'alphaDomain a' is the domain of prior 'a'
alphaDomain :: Enum a => Alpha a -> Seq a
alphaDomain (SymAlpha {aDomain=d}) = d
alphaDomain (Alpha {aAlphas=a}) = SQ.fromList $ EM.keys a

-- | 'alphaNormalizer a' is the cached normalizer stored in prior 'a'
alphaNormalizer :: Enum a => Alpha a -> LogFloat
alphaNormalizer = aNorm

-- | 'alphaOf alpha k' is the value of element 'k' in prior 'alpha'
--
-- A symmetric prior returns its shared value for every key, without
-- consulting the domain; an asymmetric prior looks the key up with 'EM.!'.
alphaOf :: Enum a => Alpha a -> a -> Double
alphaOf (SymAlpha {aAlpha=alpha}) = const alpha
alphaOf (Alpha {aAlphas=alphas}) = (alphas EM.!)
-- | 'sumAlpha alpha' is the sum of all alphas
--
-- For a symmetric prior this is the domain size times the shared value; for
-- an asymmetric prior it is the cached sum stored in the record.
sumAlpha :: Enum a => Alpha a -> Double
sumAlpha (SymAlpha {aDomain=domain, aAlpha=alpha}) = realToFrac (SQ.length domain) * alpha
sumAlpha (Alpha {aSumAlphas=total}) = total

-- | Set a particular alpha element
--
-- A symmetric prior is first expanded to an asymmetric one; the result is
-- always rebuilt through 'asymAlpha'.
setAlphaOf :: Enum a => a -> Double -> Alpha a -> Alpha a
setAlphaOf k a alpha@(SymAlpha {}) = setAlphaOf k a $ asymmetrizeAlpha alpha
setAlphaOf k a (Alpha {aAlphas=alphas}) = asymAlpha $ EM.insert k a alphas

-- | 'alphaToMeanPrecision a' is the mean/precision representation of the prior 'a'
--
-- The precision is the sum of all alphas and the mean is each alpha divided
-- by that sum.
alphaToMeanPrecision :: Enum a => Alpha a -> (DirMean a, DirPrecision)
alphaToMeanPrecision (SymAlpha {aDomain=dom, aAlpha=alpha}) =
  let prec = realToFrac (SQ.length dom) * alpha
  in (EM.fromList $ map (\a->(a, alpha/prec)) $ toList dom, prec)
alphaToMeanPrecision (Alpha {aAlphas=alphas, aSumAlphas=prec}) =
  (fmap (/prec) alphas, prec)

-- | 'meanPrecisionToAlpha m p' is a prior with mean 'm' and precision 'p'
--
-- The result is always an asymmetric prior, built by scaling every entry of
-- the mean by the precision.
meanPrecisionToAlpha :: Enum a => DirMean a -> DirPrecision -> Alpha a
meanPrecisionToAlpha mean prec = asymAlpha $ fmap (*prec) mean

-- | Symmetrize a Dirichlet prior: every element of the domain receives the
-- average of the original alphas (their sum divided by the number of
-- elements), which makes the mean uniform. A prior that is already symmetric
-- is returned unchanged.
symmetrizeAlpha :: Enum a => Alpha a -> Alpha a
symmetrizeAlpha alpha@(SymAlpha {}) = alpha
symmetrizeAlpha alpha@(Alpha {}) = symmetric
  where
    -- The lazy 'aNorm' field refers back to the record being built, so the
    -- domain sequence is not rebuilt when the normalizer is forced.
    symmetric = SymAlpha { aDomain = alphaDomain alpha
                         , aAlpha = sumAlpha alpha / realToFrac (EM.size $ aAlphas alpha)
                         , aNorm = alphaNorm symmetric
                         }

-- | Turn a symmetric alpha into an asymmetric alpha. For internal use.
-- Every element of a symmetric prior's domain is mapped to the shared value
-- and the result is rebuilt through 'asymAlpha'; a prior that is already
-- asymmetric is returned unchanged.
asymmetrizeAlpha :: Enum a => Alpha a -> Alpha a
asymmetrizeAlpha (SymAlpha {aDomain=domain, aAlpha=alpha}) =
  asymAlpha $ fold $ fmap (\k->EM.singleton k alpha) domain
asymmetrizeAlpha alpha@(Alpha {}) = alpha

-- | Pretty-print a Dirichlet prior
--
-- The first argument renders an element of the domain. It is only used for
-- asymmetric priors, where each entry is shown as the element followed by
-- its value in parentheses (formatted with @%1.2e@). Only the first 100
-- entries of an asymmetric prior are printed.
prettyAlpha :: Enum a => (a -> String) -> Alpha a -> Doc
prettyAlpha _ (SymAlpha {aAlpha=alpha}) = text "Symmetric" <+> double alpha
prettyAlpha showA (Alpha {aAlphas=alphas}) =
  text "Asymmetric"
  <+> fsep (punctuate comma $ map entry $ take 100 $ EM.toList alphas)
  where
    -- One "element(value)" item of the comma-separated listing.
    entry (a, alpha) = text (showA a) <> parens (text $ printf "%1.2e" alpha)