-- |
-- Analysis of benchmark samples: classifying and reporting outliers,
-- estimating how strongly outliers affect the sample variance, and
-- bootstrap analysis of the sample mean and standard deviation.
module Criterion.Analysis
    (
      -- * Types
      Outliers (..)
    , OutlierEffect(..)
    , OutlierVariance(..)
    , SampleAnalysis(..)
      -- * Functions
    , analyseSample
    , scale
    , analyseMean
    , countOutliers
    , classifyOutliers
    , noteOutliers
    , outlierVariance
    ) where
import Control.Monad (when)
import Criterion.Analysis.Types
import Criterion.IO (note)
import Criterion.Measurement (secs)
import Criterion.Monad (Criterion)
import Data.Int (Int64)
import Data.Monoid (Monoid(..))
import Statistics.Function (sort)
import Statistics.Quantile (weightedAvg)
import Statistics.Resampling (Resample, resample)
import Statistics.Sample (mean, stdDev)
import Statistics.Types (Sample)
import System.Random.MWC (withSystemRandom)
import qualified Data.Vector.Unboxed as U
import qualified Statistics.Resampling.Bootstrap as B
-- | Classify the outliers in a sample.
--
-- The sample is sorted, its first and third quartiles are estimated
-- with 'weightedAvg', and every element is compared against four
-- fences placed 1.5 and 3 interquartile ranges outside the quartiles:
--
-- * at or below @q1 - 3 * iqr@: low severe
--
-- * above that, but at or below @q1 - 1.5 * iqr@: low mild
--
-- * at or above @q3 + 1.5 * iqr@, but below @q3 + 3 * iqr@: high mild
--
-- * at or above @q3 + 3 * iqr@: high severe
--
-- The per-element results are combined with 'mappend', starting from
-- 'mempty'.
--
-- NOTE(review): when the interquartile range is zero all four fences
-- coincide, so an element equal to the quartiles satisfies both the
-- low-severe and the high-severe test.
classifyOutliers :: Sample -> Outliers
classifyOutliers sa = U.foldl' accumulate mempty ssa
    where -- Fold step: merge the classification of one element into
          -- the running totals.
          accumulate acc e = acc `mappend` outlier e
          -- Classification of a single element.
          outlier e = Outliers {
                        samplesSeen = 1
                      , lowSevere = if e <= loS then 1 else 0
                      , lowMild = if e > loS && e <= loM then 1 else 0
                      , highMild = if e >= hiM && e < hiS then 1 else 0
                      , highSevere = if e >= hiS then 1 else 0
                      }
          -- Low fences lie below the first quartile, mirroring the
          -- high fences above the third quartile.
          loS = q1 - (iqr * 3)
          loM = q1 - (iqr * 1.5)
          hiM = q3 + (iqr * 1.5)
          hiS = q3 + (iqr * 3)
          q1  = weightedAvg 1 4 ssa
          q3  = weightedAvg 3 4 ssa
          ssa = sort sa
          -- Interquartile range.
          iqr = q3 - q1
-- | Compute the extent to which outliers in the sample data affect
-- the sample variance.
--
-- The result pairs a coarse 'OutlierEffect' classification with the
-- underlying fraction @varOutMin@.  The classification thresholds are
-- 1%, 10% and 50%: below 1% is 'Unaffected', below 10% is 'Slight',
-- below 50% is 'Moderate', anything else is 'Severe'.
--
-- NOTE(review): if the standard deviation's point estimate is zero
-- the final division by @σb2@ is a division by zero; callers are
-- presumably expected to pass a non-degenerate estimate.
outlierVariance :: B.Estimate  -- ^ Bootstrap estimate of sample mean.
                -> B.Estimate  -- ^ Bootstrap estimate of sample
                               --   standard deviation.
                -> Double      -- ^ Number of samples, as a 'Double'.
                -> OutlierVariance
outlierVariance µ σ a = OutlierVariance effect varOutMin
  where
    effect | varOutMin < 0.01 = Unaffected
           | varOutMin < 0.1  = Slight
           | varOutMin < 0.5  = Moderate
           | otherwise        = Severe
    -- Smallest candidate outlier variance, as a fraction of the
    -- observed variance.
    varOutMin = (minBy varOut 1 (minBy cMax 0 µgMin)) / σb2
    varOut c  = (ac / a) * (σb2 - ac * σg2) where ac = a - c
    σb        = B.estPoint σ
    µa        = B.estPoint µ / a
    µgMin     = µa / 2
    σg        = min (µgMin / 4) (σb / sqrt a)
    σg2       = σg * σg
    σb2       = σb * σb
    -- Apply 'f' to both candidates and keep the smaller result.
    minBy f q r = min (f q) (f r)
    -- Floor of the non-negative root of
    -- @σg2 * c^2 + k1 * c + k0 = 0@, written as
    -- @-2 * k0 / (k1 + sqrt det)@ so that no subtraction of nearly
    -- equal quantities is needed.  Because @k0 <= 0@ the discriminant
    -- is never negative.
    cMax x    = fromIntegral (floor (-2 * k0 / (k1 + sqrt det)) :: Int)
      where
        k1    = σb2 - a * σg2 + ad
        k0    = -a * ad
        ad    = a * d
        -- Squared deviation: it is added to the variance @σb2@ in
        -- @k1@, so it must be quadratic in @k@.
        d     = k * k where k = µa - x
        det   = k1 * k1 - 4 * σg2 * k0
-- | Count the total number of outliers in a sample.
--
-- Adds up every field of 'Outliers' except the first one, which is
-- ignored (presumably the count of samples seen -- confirm against
-- the field order of the type).
countOutliers :: Outliers -> Int64
countOutliers (Outliers _ w x y z) = sum [w, x, y, z]
-- | Compute the mean of a sample, report it together with the
-- iteration count via 'note', then classify the sample's outliers and
-- report those too.  The mean is returned.
analyseMean :: Sample
            -> Int              -- ^ Iteration count; only used in the
                                -- reported message.
            -> Criterion Double
analyseMean a iters = do
  _ <- note "mean is %s (%d iterations)\n" (secs avg) iters
  noteOutliers (classifyOutliers a)
  return avg
  where
    avg = mean a
-- | Apply 'B.scale' with the given factor to the mean and standard
-- deviation estimates of an analysis.  All other fields are left
-- untouched.
scale :: Double                 -- ^ Value to scale by.
      -> SampleAnalysis -> SampleAnalysis
scale factor analysis = analysis {
                          anMean = rescale (anMean analysis)
                        , anStdDev = rescale (anStdDev analysis)
                        }
  where
    rescale = B.scale factor
-- | Perform a bootstrap analysis of a sample.
--
-- The sample is resampled for the mean and standard deviation
-- estimators using a generator obtained from 'withSystemRandom'.
-- 'B.bootstrapBCA' turns the resamples into estimates, and
-- 'outlierVariance' is computed from those estimates and the sample
-- size.
analyseSample :: Double         -- ^ Confidence level handed to
                                -- 'B.bootstrapBCA' (presumably
                                -- between 0 and 1).
              -> Sample         -- ^ Sample data.
              -> Int            -- ^ Number of resamples to perform
                                -- when bootstrapping.
              -> IO SampleAnalysis
analyseSample ci samples numResamples = do
  resamples <- withSystemRandom $ \gen ->
               resample gen estimators numResamples samples :: IO [Resample]
  -- One estimate comes back per estimator, in the same order.
  let [estMean,estStdDev] = B.bootstrapBCA ci samples estimators resamples
  return SampleAnalysis {
               anMean = estMean
             , anStdDev = estStdDev
             , anOutliers = outlierVariance estMean estStdDev sampleCount
             }
  where
    estimators  = [mean,stdDev]
    sampleCount = fromIntegral (U.length samples)
-- | Display a report of the 'Outliers' present in a sample.
--
-- Nothing is printed when there are no outliers.  Otherwise a summary
-- line is followed by one line per category whose share of the
-- samples strictly exceeds its threshold: 0% for the severe
-- categories, 1% for the mild ones.
noteOutliers :: Outliers -> Criterion ()
noteOutliers o = when (total > 0) $ do
    _ <- note "found %d outliers among %d samples (%.1g%%)\n"
         total (samplesSeen o) (percent total)
    report (lowSevere o) 0 "low severe"
    report (lowMild o) 1 "low mild"
    report (highMild o) 1 "high mild"
    report (highSevere o) 0 "high severe"
  where
    total = countOutliers o

    -- Share of all samples seen, as a percentage.
    percent :: Int64 -> Double
    percent n = 100 * fromIntegral n / fromIntegral (samplesSeen o)

    -- Print one category line if its percentage is above the threshold.
    report :: Int64 -> Double -> String -> Criterion ()
    report k threshold label
      | percent k > threshold = note "  %d (%.1g%%) %s\n" k (percent k) label
      | otherwise             = return ()