```{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE UndecidableInstances #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE TypeFamilies #-}

-- | The categorical distribution models discrete data.  It is sometimes also called the discrete distribution or, loosely, the multinomial distribution.  For more, see the Wikipedia entry: <https://en.wikipedia.org/wiki/Categorical_distribution>
module HLearn.Models.Distributions.Univariate.Categorical
(
-- * Data types
Categorical (Categorical)

-- * Helper functions
, dist2list
, mostLikely
)
where

import Control.DeepSeq
import Data.List
import Data.List.Extras
import Debug.Trace

import qualified Data.Map.Strict as Map
import qualified Data.Foldable as F

import HLearn.Algebra
import HLearn.Models.Distributions.Common
import HLearn.Models.Distributions.Visualization.Gnuplot

-------------------------------------------------------------------------------
-- Categorical

-- | A categorical distribution, stored as a map from each observed sample
-- value to the weight accumulated for it.  The weights are kept
-- unnormalised; the instances in this module divide by the sum of all
-- weights when a probability is needed.
data Categorical sampletype prob = Categorical
    { pdfmap :: !(Map.Map sampletype prob) -- ^ weight recorded for each sample value
    }

-- | Fully evaluating a 'Categorical' means fully evaluating its weight map.
instance (NFData sampletype, NFData prob) => NFData (Categorical sampletype prob) where
    rnf (Categorical weights) = rnf weights

-------------------------------------------------------------------------------
-- Algebra

-- Marks 'Categorical' as 'Abelian'; this instance defines no methods.
-- NOTE(review): presumably this asserts that 'mappend' is commutative, which
-- holds whenever (+) on @prob@ is commutative -- confirm against HLearn.Algebra.
instance (Ord label, Num prob) => Abelian (Categorical label prob)
-- | Distributions combine by adding their weights label by label.  A label
-- present in only one operand keeps its weight; the empty map is the identity.
instance (Ord label, Num prob) => Monoid (Categorical label prob) where
    mempty = Categorical Map.empty
    -- Matching on the constructor forces both operands, as the original bang
    -- patterns did.
    mappend (Categorical lhs) (Categorical rhs) =
        Categorical (Map.unionWith (+) lhs rhs)

-- | The inverse of a distribution has every weight subtracted from zero, so
-- that combining a distribution with its inverse leaves zero weights.
instance (Ord label, Num prob) => Group (Categorical label prob) where
    inverse (Categorical weights) =
        Categorical (Map.map (\w -> 0 - w) weights)

-- | The scalars of a 'Categorical' are the values of its weight type.
instance Num prob => HasRing (Categorical label prob) where
    type Ring (Categorical label prob) = prob
-- | Scalar multiplication scales every stored weight by the given factor.
instance (Ord label, Num prob) => Module (Categorical label prob) where
    scalar .* Categorical weights =
        Categorical (Map.map (\w -> w * scalar) weights)

-------------------------------------------------------------------------------
-- Training

-- | Training data are bare labels; a single data point becomes a distribution
-- holding that one label with weight 1.
instance (Ord label, Num prob) => HomTrainer (Categorical label prob) where
    type Datapoint (Categorical label prob) = label
    train1dp point = Categorical (Map.singleton point 1)

-- | The number of data points is the sum of all stored weights.
instance Num prob => NumDP (Categorical label prob) where
    numdp = F.foldl' (+) 0 . pdfmap

-------------------------------------------------------------------------------
-- Distribution

-- | Probabilities reported for a 'Categorical' have the same type as its
-- stored weights.
instance Probabilistic (Categorical label prob) where
    type Probability (Categorical label prob) = prob

-- | The probability of a label is its weight divided by the sum of all
-- weights.  Labels that were never recorded have weight zero.
instance (Ord label, Ord prob, Fractional prob) => PDF (Categorical label prob) where

    {-# INLINE pdf #-}
    pdf dist label = weight / total
        where
            weights = pdfmap dist
            -- Missing labels contribute zero weight.
            weight  = Map.findWithDefault 0 label weights
            total   = F.foldl' (+) 0 weights

-- | Cumulative distribution over the labels, taken in their 'Ord' order.
instance (Ord label, Ord prob, Fractional prob) => CDF (Categorical label prob) where

    -- Fraction of the total weight carried by labels less than or equal to
    -- the given label.
    {-# INLINE cdf #-}
    cdf dist label = below / total
        where
            weights = pdfmap dist
            below   = Map.foldl' (+) 0 (Map.filterWithKey (\k _ -> k <= label) weights)
            total   = Map.foldl' (+) 0 weights

    -- Quantile function: the smallest label whose cumulative probability is
    -- at least @prob@.  If no label reaches @prob@ (for example @prob@ > 1,
    -- or rounding leaves the final cumulative value just short), the largest
    -- label is returned.  Calling this on an empty distribution is an error,
    -- because there is no label to return.
    --
    -- The previous implementation compared @prob@ against individual pdf
    -- values sorted in descending order, so it could only ever return the
    -- most likely or the least likely label, and it failed with an
    -- incomplete-pattern error on an empty distribution.
    {-# INLINE cdfInverse #-}
    cdfInverse dist prob = go 0 (Map.toAscList (pdfmap dist))
        where
            total = Map.foldl' (+) 0 (pdfmap dist)

            go _ [] = error "Categorical.cdfInverse: empty distribution"
            -- The last label is the fallback when nothing earlier qualified.
            go _ [(k, _)] = k
            go !acc ((k, w) : rest)
                | prob <= acc' / total = k
                | otherwise            = go acc' rest
                where
                    -- Running sum of the weights up to and including k.
                    acc' = acc + w
--     cdfInverse dist prob = argmax (cdf dist) \$ Map.keys \$ pdfmap dist

--     {-# INLINE mean #-}
--     mean dist = fst \$ argmax snd \$ Map.toList \$ pdfmap dist
--
--     {-# INLINE drawSample #-}
--     drawSample dist = do
--         x <- getRandomR (0,1)
--         return \$ cdfInverse dist (x::prob)

-- | For categorical data the "mean" is taken to be the label with the
-- largest stored weight, which is exactly what 'mostLikely' computes.
instance (Num prob, Ord prob, Ord label) => Mean (Categorical label prob) where
    mean = mostLikely

-- | Returns the label carrying the largest weight in the distribution.
-- The selection is delegated to 'argmax' over the (label, weight) pairs, so
-- tie-breaking and the behaviour on an empty distribution are whatever
-- 'argmax' provides.
mostLikely :: Ord prob => Categorical label prob -> label
mostLikely = fst . argmax snd . Map.toList . pdfmap

-- | Converts a distribution into a list of (sample, weight) pairs, in
-- ascending order of sample.  The weights are the raw stored values; they are
-- not normalised to sum to one.
dist2list :: Categorical sampletype prob -> [(sampletype,prob)]
dist2list = Map.toList . pdfmap

-- | Plotting support: one sample point per stored label, drawn as a bar chart.
instance
    ( Ord label, Show label
    , Ord prob, Show prob, Fractional prob
    ) => PlottableDistribution (Categorical label prob)
        where
    samplePoints = Map.keys . pdfmap
    plotType _ = Bar

-------------------------------------------------------------------------------
-- Morphisms

-- instance
--     ( Ord label
--     , Num prob
--     ) => Morphism (Categorical label prob) FreeModParams (FreeMod prob label)
--         where
--     Categorical pdf \$> FreeModParams = FreeMod pdf
--
```