-- | Hyperparameters ("alphas") of Dirichlet distributions.
--
-- The 'Alpha' type is exported abstractly (no constructors), so values are
-- built with 'symAlpha', 'asymAlpha' or 'meanPrecisionToAlpha' and inspected
-- through the accessor functions listed here. 'asymmetrizeAlpha' and
-- 'alphaNorm' are internal helpers and are not exported.
module BayesStack.Dirichlet (
Alpha
, symAlpha, asymAlpha
, alphaDomain, alphaNormalizer, sumAlpha
, DirMean, DirPrecision
, alphaOf, setAlphaOf, setSymAlpha
, alphaToMeanPrecision, meanPrecisionToAlpha
, symmetrizeAlpha
, prettyAlpha
) where
import Data.Foldable (toList, Foldable, fold)
import Data.EnumMap (EnumMap)
import qualified Data.EnumMap as EM
import Data.Sequence (Seq)
import qualified Data.Sequence as SQ
import Data.Number.LogFloat hiding (realToFrac, isNaN, isInfinite)
import Math.Gamma
import Text.Printf
import Text.PrettyPrint
import Data.Serialize
import Data.Serialize.EnumMap ()
import Data.Serialize.LogFloat ()
import GHC.Generics (Generic)
-- | Pass a floating-point value through unchanged unless it is NaN or
-- infinite, in which case abort with 'error'.
--
-- The first argument names the call site; it is spliced into the error
-- message after the @BayesStack.Dirichlet.@ prefix.
checkNaN :: RealFloat a => String -> a -> a
checkNaN loc x
  | isNaN x      = bail ": Not a number"
  | isInfinite x = bail ": Infinity"
  | otherwise    = x
  where
    -- Assemble the module-qualified message and raise it.
    bail suffix = error ("BayesStack.Dirichlet." ++ loc ++ suffix)
-- | Parameters of a Dirichlet distribution over elements of type @a@.
--
-- Both constructors carry a lazy 'aNorm' field holding the normalizer
-- (product of Gamma(alpha_k) divided by Gamma(sum of alpha_k), kept as a
-- 'LogFloat'); being lazy, it is only computed if somebody asks for it.
data Alpha a
  -- Symmetric case: every element of the domain shares one concentration.
  = SymAlpha
      { aDomain :: Seq a     -- ^ elements the distribution ranges over
      , aAlpha  :: !Double   -- ^ the single shared concentration
      , aNorm   :: LogFloat  -- ^ cached normalizer (lazy)
      }
  -- Asymmetric case: each element has its own concentration.
  | Alpha
      { aAlphas    :: EnumMap a Double  -- ^ concentration per element
      , aSumAlphas :: !Double           -- ^ cached sum of the concentrations
      , aNorm      :: LogFloat          -- ^ cached normalizer (lazy)
      }
  deriving (Show, Eq, Generic)
-- | Serialization for 'Alpha'. The instance body is empty, so the class's
-- default method implementations are used (presumably the Generic-based
-- defaults, given that 'Alpha' derives 'Generic' -- confirm against the
-- Data.Serialize version in use).
instance (Enum a, Serialize a) => Serialize (Alpha a)
-- | Mean of a Dirichlet: one weight per domain element. As produced by
-- 'alphaToMeanPrecision', each weight is that element's alpha divided by
-- the sum of all alphas.
type DirMean a = EnumMap a Double
-- | Precision of a Dirichlet. As produced by 'alphaToMeanPrecision', this
-- is the sum of all alphas.
type DirPrecision = Double
-- | Build a symmetric Dirichlet parameter: every element of @domain@
-- receives the same concentration @alpha@.
--
-- Calls 'error' if @domain@ is empty. The normalizer is stored lazily and
-- is only computed when first demanded.
symAlpha :: Enum a => [a] -> Double -> Alpha a
symAlpha domain alpha
  | null domain = error "Dirichlet over null domain is undefined"
  | otherwise   = self
  where
    -- The lazy 'aNorm' field refers back to the record under construction
    -- instead of calling 'symAlpha' a second time. 'alphaNorm' only reads
    -- the domain and the concentration, so the self-reference is
    -- well-founded, and forcing the normalizer no longer rebuilds the
    -- domain sequence.
    self = SymAlpha { aDomain = SQ.fromList domain
                    , aAlpha  = alpha
                    , aNorm   = alphaNorm self
                    }
-- | Build an asymmetric Dirichlet parameter from a map giving each domain
-- element its own concentration.
--
-- Calls 'error' if the map is empty. The sum of the concentrations is
-- computed once and cached; the normalizer is stored lazily and is only
-- computed when first demanded.
asymAlpha :: Enum a => EnumMap a Double -> Alpha a
asymAlpha alphas
  | EM.null alphas = error "Dirichlet over null domain is undefined"
  | otherwise      = self
  where
    -- The lazy 'aNorm' field refers back to the record under construction
    -- instead of calling 'asymAlpha' a second time. 'alphaNorm' only reads
    -- the concentrations and their cached sum, so the self-reference is
    -- well-founded, and forcing the normalizer no longer re-sums the map.
    self = Alpha { aAlphas    = alphas
                 , aSumAlphas = sum $ EM.elems alphas
                 , aNorm      = alphaNorm self
                 }
-- | Replace the concentration of a Dirichlet parameter with a single
-- shared value. The input is first passed through 'symmetrizeAlpha', so
-- the result is always symmetric over the same domain.
setSymAlpha :: Enum a => Double -> Alpha a -> Alpha a
setSymAlpha alpha a = updated
  where
    -- Symmetric view of the input; only its domain survives the update.
    base = symmetrizeAlpha a
    -- The normalizer is recomputed lazily from the updated record itself.
    updated = base { aAlpha = alpha
                   , aNorm  = alphaNorm updated
                   }
-- | Normalizer of a Dirichlet parameter: the product of Gamma(alpha_k) over
-- the domain divided by Gamma(sum of alpha_k), computed from log-Gamma
-- values and returned as a 'LogFloat'.
--
-- Every intermediate log-Gamma result goes through 'checkNaN', so a NaN or
-- infinite value aborts with 'error' rather than propagating.
alphaNorm :: Enum a => Alpha a -> LogFloat
alphaNorm alpha = numerator / denominator
  where
    numerator = case alpha of
      -- One log-Gamma factor per element, multiplied in the log domain.
      Alpha { aAlphas = perElem } ->
        product
          [ logToLogFloat
              $ checkNaN ("alphaNorm.normNum(asym) alpha="++show a)
              $ lnGamma a
          | a <- EM.elems perElem
          ]
      -- All factors are equal, so scale a single log-Gamma by the
      -- number of elements.
      SymAlpha { aDomain = dom, aAlpha = conc } ->
        let dim = realToFrac (SQ.length dom)
        in logToLogFloat
             $ checkNaN "alphaNorm.normNum(sym)"
             $ dim * lnGamma conc
    denominator =
      logToLogFloat
        $ checkNaN "alphaNorm.normDenom"
        $ lnGamma (sumAlpha alpha)
-- | The elements a Dirichlet parameter ranges over: the stored domain for
-- the symmetric case, the keys of the concentration map otherwise.
alphaDomain :: Enum a => Alpha a -> Seq a
alphaDomain alpha = case alpha of
  SymAlpha { aDomain = dom }  -> dom
  Alpha { aAlphas = perElem } -> SQ.fromList (EM.keys perElem)
-- | The cached normalizer stored in the parameter's 'aNorm' field.
alphaNormalizer :: Enum a => Alpha a -> LogFloat
alphaNormalizer alpha = aNorm alpha
-- | Concentration assigned to a given element.
--
-- In the symmetric case the element is ignored and the shared value is
-- returned. In the asymmetric case the lookup uses 'EM.!' and therefore
-- fails for an element that is absent from the map.
alphaOf :: Enum a => Alpha a -> a -> Double
alphaOf dist = case dist of
  SymAlpha { aAlpha = shared } -> \_ -> shared
  Alpha { aAlphas = perElem }  -> \k -> perElem EM.! k
-- | Sum of all concentrations: domain size times the shared value in the
-- symmetric case, the cached total in the asymmetric case.
sumAlpha :: Enum a => Alpha a -> Double
sumAlpha alpha = case alpha of
  SymAlpha { aDomain = dom, aAlpha = conc } ->
    realToFrac (SQ.length dom) * conc
  Alpha { aSumAlphas = total } ->
    total
-- | Set the concentration of one element. The result is always
-- asymmetric: a symmetric input is expanded to one entry per element
-- before the new value is inserted.
setAlphaOf :: Enum a => a -> Double -> Alpha a -> Alpha a
setAlphaOf k a alpha = case alpha of
  -- Expand to the per-element form, then retry on that.
  SymAlpha {} -> setAlphaOf k a (asymmetrizeAlpha alpha)
  -- Rebuild through 'asymAlpha' so the cached sum and normalizer are
  -- derived from the updated map.
  Alpha { aAlphas = perElem } -> asymAlpha (EM.insert k a perElem)
-- | Split a Dirichlet parameter into its mean (each concentration divided
-- by the total) and its precision (the total itself).
alphaToMeanPrecision :: Enum a => Alpha a -> (DirMean a, DirPrecision)
alphaToMeanPrecision dist = case dist of
  SymAlpha { aDomain = dom, aAlpha = conc } ->
    let total = realToFrac (SQ.length dom) * conc
        -- Every element gets the same share of the total.
        share = conc / total
    in (EM.fromList [ (k, share) | k <- toList dom ], total)
  Alpha { aAlphas = perElem, aSumAlphas = total } ->
    (fmap (/total) perElem, total)
-- | Rebuild an asymmetric Dirichlet parameter from a mean and a precision
-- by scaling every mean weight by the precision.
meanPrecisionToAlpha :: Enum a => DirMean a -> DirPrecision -> Alpha a
meanPrecisionToAlpha mean prec = asymAlpha scaled
  where
    scaled = fmap (*prec) mean
-- | Collapse a Dirichlet parameter to its symmetric form.
--
-- A symmetric input is returned unchanged. For an asymmetric input the
-- result ranges over the same domain, and the shared concentration is the
-- average of the original concentrations (their sum divided by the number
-- of elements), so the total is preserved.
symmetrizeAlpha :: Enum a => Alpha a -> Alpha a
symmetrizeAlpha alpha@(SymAlpha {}) = alpha
symmetrizeAlpha alpha@(Alpha {})    = sym
  where
    -- The lazy 'aNorm' field refers back to the record under construction
    -- instead of calling 'symmetrizeAlpha' again, so forcing the
    -- normalizer no longer rebuilds the domain and the average.
    sym = SymAlpha { aDomain = alphaDomain alpha
                   , aAlpha  = sumAlpha alpha / realToFrac (EM.size $ aAlphas alpha)
                   , aNorm   = alphaNorm sym
                   }
-- | Expand a Dirichlet parameter to its asymmetric form.
--
-- A symmetric input becomes a map giving every domain element the shared
-- concentration; an asymmetric input is returned unchanged.
asymmetrizeAlpha :: Enum a => Alpha a -> Alpha a
asymmetrizeAlpha dist = case dist of
  SymAlpha { aDomain = dom, aAlpha = conc } ->
    asymAlpha (EM.fromList [ (k, conc) | k <- toList dom ])
  Alpha {} ->
    dist
-- | Render a Dirichlet parameter for display.
--
-- The first argument renders a domain element; it is only used in the
-- asymmetric case. A symmetric parameter prints as @Symmetric@ followed by
-- its shared concentration. An asymmetric parameter prints as @Asymmetric@
-- followed by a comma-separated list of @element(value)@ entries with the
-- value in @%1.2e@ format. Only the first 100 entries are shown; anything
-- beyond that is dropped without a marker.
prettyAlpha :: Enum a => (a -> String) -> Alpha a -> Doc
prettyAlpha _ (SymAlpha {aAlpha = alpha}) = text "Symmetric" <+> double alpha
prettyAlpha showA (Alpha {aAlphas = alphas}) =
  text "Asymmetric" <+> fsep (punctuate comma entries)
  where
    -- One "element(value)" document per shown entry.
    entries =
      [ text (showA k) <> parens (text $ printf "%1.2e" v)
      | (k, v) <- take 100 (EM.toList alphas)
      ]