{-# OPTIONS_GHC -Wall #-}
{-# OPTIONS_GHC -fno-warn-type-defaults #-}
{-# OPTIONS_GHC -fno-warn-name-shadowing #-}
{-# LANGUAGE OverloadedStrings #-}

module NumHask.Histogram
  ( Histogram(..)
  , freq
  , fill
  , DealOvers(..)
  , fromHist
  , hist
  , labels
  , insert
  , insertW
  , insertWs
  ) where

import NumHask.Rect

import Protolude
import qualified Control.Foldl as L
import qualified Data.Map.Strict as Map
import Linear hiding (identity)
import Data.List
import Formatting
import Control.Lens

-- | A histogram: a list of bucket boundaries together with a value per bucket.
--
-- The fillers in this module ('fill', 'insertW', 'insertWs', 'hist') key a
-- bucket by the number of cuts that lie strictly below the observed value,
-- so with @n@ cuts the keys they produce range over @0 .. n@.
data Histogram = Histogram
  { _cuts :: [Double]
    -- ^ bucket boundaries
  , _values :: Map.Map Int Double
    -- ^ bucket counts, keyed by bucket index
  }
  deriving (Show, Eq)

-- | Rescale bucket counts so that they sum to one.
--
-- Every value is multiplied by the reciprocal of the total of all values.
-- When the total is zero (an empty map, or counts that are all zero or
-- cancel out) there is nothing meaningful to normalise by; the map is
-- returned unchanged rather than being filled with NaN/Infinity.
freq' :: Map.Map Int Double -> Map.Map Int Double
freq' m
  | total == 0 = m
  | otherwise = Map.map (* recip total) m
  where
    -- strict left fold over the values, in key order
    total = Map.foldl' (+) 0 m

-- | Turn the counts of a histogram into relative frequencies by applying
-- @freq'@ to its values; the cuts are left untouched.
freq :: Histogram -> Histogram
freq h = h { _values = freq' (_values h) }

-- | Fold that tallies bucket indices, giving every occurrence a weight of 1.
count :: L.Fold Int (Map Int Double)
count = L.premap unitWeight countW
  where
    -- pair an index with the unit weight expected by 'countW'
    unitWeight k = (k, 1.0)

-- | Fold that accumulates weights per bucket index, adding the weight of
-- each incoming @(index, weight)@ pair to whatever the index already holds.
countW :: L.Fold (Int,Double) (Map Int Double)
countW = L.Fold step Map.empty identity
  where
    step acc (k, w) = Map.insertWith (+) k w acc

-- | Fold that counts how many of the incoming booleans are 'True'.
countBool :: L.Fold Bool Int
countBool = L.Fold step 0 identity
  where
    step n flag
      | flag = n + 1
      | otherwise = n

-- | Bucket a collection of values against a collection of cuts.
--
-- A value's bucket index is the number of cuts strictly smaller than it;
-- the result maps each index that occurs to the number of values in it.
histMap :: (Functor f, Functor g, Ord a, Foldable f, Foldable g) =>
    f a -> g a -> Map Int Double
histMap cuts xs = L.fold count (bucketOf <$> xs)
  where
    -- number of cuts lying strictly below the value
    bucketOf x = L.fold countBool ((x >) <$> cuts)

-- | Weighted variant of bucketing: each @(value, weight)@ pair adds its
-- weight to the bucket whose index is the number of cuts strictly smaller
-- than the value.
histMapW :: (Functor f, Functor g, Ord a, Foldable f, Foldable g) =>
    f a -> g (a,Double) -> Map Int Double
histMapW cuts xs = L.fold countW (toBucket <$> xs)
  where
    -- replace the value by its bucket index, keep the weight as is
    toBucket (x, w) = (L.fold countBool ((x >) <$> cuts), w)

-- | Build a histogram from a list of cuts and a list of observations,
-- counting each observation once.
fill :: [Double] -> [Double] -> Histogram
fill cuts = Histogram cuts . histMap cuts

-- | Add a single weighted observation to a histogram: the weight is added
-- to the bucket the value falls into, all other buckets are unchanged.
insertW :: Histogram -> Double -> Double -> Histogram
insertW h value weight = insertWs h [(value, weight)]

-- | Add a batch of @(value, weight)@ observations to a histogram. The batch
-- is bucketed against the histogram's own cuts and the resulting weights
-- are summed into the existing values.
insertWs :: Histogram -> [(Double, Double)] -> Histogram
insertWs (Histogram cuts vs) vws = Histogram cuts merged
    where
      added = histMapW cuts vws
      merged = Map.unionWith (+) vs added

-- | How to treat the two outer buckets (values below the first cut and
-- above the last cut) when rendering a histogram.
data DealOvers
  = IgnoreOvers
    -- ^ leave the outer buckets out
  | IncludeOvers Double
    -- ^ include the outer buckets, using the given width for each of them

-- | Convert a histogram into one rectangle per bucket.
--
-- For each bucket a @V4 lower 0 upper height@ is built and converted with
-- @view rect@. The raw height of a bucket is its count divided by its
-- width; heights are then rescaled so that they sum to one.
--
-- * 'IgnoreOvers' uses only the buckets between consecutive cuts
--   (indices @1 .. length cuts - 1@).
-- * @'IncludeOvers' outw@ additionally renders the two outer buckets
--   (indices @0@ and @length cuts@), each given the width @outw@.
--
-- A histogram without cuts yields no rectangles (previously this crashed
-- under 'IncludeOvers'). When the raw heights sum to zero they are left
-- unscaled instead of being divided by zero.
fromHist :: DealOvers -> Histogram -> [Rect Double]
fromHist _ (Histogram [] _) = []
fromHist o (Histogram cuts@(c0 : _) counts) =
    view rect <$> zipWith4 V4 lowers (repeat 0) uppers heights
  where
      -- safe: this equation only matches a non-empty list of cuts
      cN = Data.List.last cuts
      -- first and last bucket index to render, and the lower edges to use
      (firstBucket, lastBucket, lowers) = case o of
        IgnoreOvers -> (1, length cuts - 1, cuts)
        IncludeOvers outw ->
          (0, length cuts, [c0 - outw] <> cuts <> [cN + outw])
      uppers = drop 1 lowers
      -- count per unit of width; buckets missing from the map count as 0
      densities = zipWith (/)
          ((\k -> Map.findWithDefault 0 k counts) <$> [firstBucket..lastBucket])
          (zipWith (-) uppers lowers)
      total = Protolude.sum densities
      heights
        | total == 0 = densities
        | otherwise = (/ total) <$> densities

-- | Text labels for the buckets defined by a list of cuts.
--
-- Every bucket between two consecutive cuts is labelled with its midpoint,
-- formatted with @prec 2@. With 'IncludeOvers' (and at least one cut) a
-- leading @"< first cut"@ and a trailing @"> last cut"@ label are added for
-- the two outer buckets. With no cuts at all the result is empty for either
-- mode (previously 'IncludeOvers' crashed on an empty list).
labels :: DealOvers -> [Double] -> [Text]
labels o cuts =
    case (o, cuts) of
      (IncludeOvers _, c0 : _) ->
        -- 'Data.List.last' is safe here: cuts is known to be non-empty
        [ "< " <> fmt c0 ] <> inside <> [ "> " <> fmt (Data.List.last cuts) ]
      _ -> inside
  where
    fmt :: Double -> Text
    fmt = sformat (prec 2)
    inside = fmt <$> zipWith (\l u -> (l+u)/2) cuts (drop 1 cuts)

-- | A fold that builds a histogram over the given cuts with a decay rate.
--
-- For every observation the existing counts are first multiplied by @r@
-- and then 1 is added to the bucket the observation falls into (the bucket
-- index is the number of cuts strictly below the observation). The fold
-- starts from a histogram with the given cuts and no counts.
--
-- The counts map is forced on every step: 'Histogram' has lazy fields, so
-- evaluating only the constructor would otherwise leave a chain of
-- unevaluated map updates that grows with the number of observations.
hist :: [Double] -> Double -> L.Fold Double Histogram
hist cuts r = L.Fold step (Histogram cuts mempty) identity
  where
    step (Histogram cs counts) a =
      let bucket = L.fold countBool (fmap (a>) cs)
          counts' =
            Map.unionWith (+)
              (Map.map (*r) counts)
              (Map.singleton bucket 1)
      in counts' `seq` Histogram cs counts'