{-# LANGUAGE CPP #-}
{-# LANGUAGE DeriveFoldable #-}
{-# LANGUAGE DeriveTraversable #-}
{-# LANGUAGE DeriveFunctor #-}
{-# LANGUAGE BangPatterns, DeriveDataTypeable, DeriveGeneric, FlexibleContexts #-}
module Statistics.Resampling
    ( 
      Resample(..)
    , Bootstrap(..)
    , Estimator(..)
    , estimate
      
    , resampleST
    , resample
    , resampleVector
      
    , jackknife
    , jackknifeMean
    , jackknifeVariance
    , jackknifeVarianceUnb
    , jackknifeStdDev
      
    , splitGen
    ) where
import Control.Applicative
import Control.Concurrent (forkIO, newChan, readChan, writeChan)
import Control.Exception (SomeException, throwIO, try)
import Control.Monad (forM_, forM, replicateM, replicateM_, liftM2)
import Control.Monad.Primitive (PrimMonad(..))
import Data.Aeson (FromJSON, ToJSON)
import Data.Binary (Binary(..))
import Data.Data (Data, Typeable)
import Data.Vector.Algorithms.Intro (sort)
import Data.Vector.Binary ()
import Data.Vector.Generic (unsafeFreeze,unsafeThaw)
import Data.Word (Word32)
import qualified Data.Foldable as T
import qualified Data.Traversable as T
import qualified Data.Vector.Generic as G
import qualified Data.Vector.Unboxed as U
import qualified Data.Vector.Unboxed.Mutable as MU
import GHC.Conc (numCapabilities)
import GHC.Generics (Generic)
import Numeric.Sum (Summation(..), kbn)
import Statistics.Function (indices)
import Statistics.Sample (mean, stdDev, variance, varianceUnbiased)
import Statistics.Types (Sample)
import System.Random.MWC (Gen, GenIO, initialize, uniformR, uniformVector)
-- | A single resample, stored as an unboxed vector of 'Double' values.
--
-- The JSON instances below have empty bodies, so they use aeson's default
-- (Generic-based) methods.  Renaming the constructor or the
-- 'fromResample' field would therefore change the JSON representation.
newtype Resample = Resample {
      fromResample :: U.Vector Double
    } deriving (Eq, Read, Show, Typeable, Data, Generic)
-- JSON (de)serialisation via the default Generic-based methods.
instance FromJSON Resample
instance ToJSON Resample
-- | Serialised exactly like the wrapped vector; the newtype adds no tag.
instance Binary Resample where
    put (Resample xs) = put xs
    get = Resample <$> get
-- | A value computed from the full sample, paired with a container of the
-- corresponding values computed from resamples.
--
-- In this module, 'resampleST' and 'resample' store the estimator applied
-- to the whole sample in 'fullSample' and the sorted estimates over all
-- resamples in 'resamples'.
--
-- The @Typeable@ and @Data@ instances are derived only when compiling with
-- GHC 7.8 or later (see the CPP guard in the deriving clause).
data Bootstrap v a = Bootstrap
  { fullSample :: !a
    -- ^ Value for the full sample (strict: forced on construction).
  , resamples  :: v a
    -- ^ Values for the resamples.
  }
  deriving (Eq, Read, Show , Generic, Functor, T.Foldable, T.Traversable
#if __GLASGOW_HASKELL__ >= 708
           , Typeable, Data
#endif
           )
-- | Binary encoding: the full-sample value, then the resamples.
instance (Binary a, Binary (v a)) => Binary (Bootstrap v a) where
  put (Bootstrap full rest) = do
    put full
    put rest
  get = Bootstrap <$> get <*> get

-- JSON instances rely on aeson's default (Generic-based) methods.
instance (FromJSON a, FromJSON (v a)) => FromJSON (Bootstrap v a)
instance (ToJSON a, ToJSON (v a)) => ToJSON (Bootstrap v a)
-- | A statistic that can be evaluated on a 'Sample'; see 'estimate' for
-- the function each constructor stands for.
--
-- No instances are derived: the 'Function' constructor carries an
-- arbitrary function.
data Estimator = Mean             -- ^ 'mean' of the sample.
               | Variance         -- ^ 'variance' of the sample.
               | VarianceUnbiased -- ^ 'varianceUnbiased' of the sample.
               | StdDev           -- ^ 'stdDev' of the sample.
               | Function (Sample -> Double)
                 -- ^ A user-supplied statistic.
-- | Turn an 'Estimator' into the function it denotes and apply it to a
-- sample.
estimate :: Estimator -> Sample -> Double
estimate e = case e of
  Mean             -> mean
  Variance         -> variance
  VarianceUnbiased -> varianceUnbiased
  StdDev           -> stdDev
  Function f       -> f
-- | Run a bootstrap in an arbitrary 'PrimMonad'.
--
-- For each estimator in turn, @numResamples@ resamples of @sample@ are
-- drawn (a fresh set per estimator) and the estimator is applied to each
-- of them.  The estimates are sorted in ascending order and paired with
-- the estimator evaluated on the full sample.
resampleST :: PrimMonad m
           => Gen (PrimState m)
           -> [Estimator]         -- ^ Estimation functions.
           -> Int                 -- ^ Number of resamples to compute.
           -> U.Vector Double     -- ^ Original sample.
           -> m [Bootstrap U.Vector Double]
resampleST gen ests numResamples sample = do
  sortedEstimates <- forM ests $ \e -> do
    xs <- U.replicateM numResamples $ do
            v <- resampleVector gen sample
            return $! estimate e v
    -- Sort in place; 'xs' is not used again after being thawed.
    mxs <- unsafeThaw xs
    sort mxs
    unsafeFreeze mxs
  return [ Bootstrap (estimate e sample) s
         | (e, s) <- zip ests sortedEstimates ]
-- | Run a bootstrap in parallel in 'IO'.
--
-- The @numResamples@ resample slots are split as evenly as possible across
-- 'numCapabilities' worker threads, each with its own generator obtained
-- from 'splitGen'.  Every resample a worker draws is shared by all
-- estimators.  The estimates for each estimator are sorted in ascending
-- order and paired with that estimator evaluated on the full sample.
--
-- If an estimator throws inside a worker thread, the exception is
-- re-thrown in the calling thread once every worker has finished.  The
-- caller no longer blocks on a completion signal the failed worker never
-- sent.
resample :: GenIO
         -> [Estimator]         -- ^ Estimation functions.
         -> Int                 -- ^ Number of resamples to compute.
         -> U.Vector Double     -- ^ Original sample.
         -> IO [(Estimator, Bootstrap U.Vector Double)]
resample gen ests numResamples samples = do
  -- Chunk boundaries: every worker gets q slots, the first r get one more.
  let ixs = scanl (+) 0 $
            zipWith (+) (replicate numCapabilities q)
                        (replicate r 1 ++ repeat 0)
          where (q,r) = numResamples `quotRem` numCapabilities
  results <- mapM (const (MU.new numResamples)) ests
  done <- newChan
  gens <- splitGen numCapabilities gen
  forM_ (zip3 ixs (tail ixs) gens) $ \ (start,!end,gen') ->
    forkIO $ do
      let loop k ers | k >= end = return ()
                     | otherwise = do
            re <- resampleVector gen' samples
            forM_ ers $ \(est,arr) ->
                MU.write arr k . est $ re
            loop (k+1) ers
      -- Always signal completion, successful or not.  Otherwise the wait
      -- below could never finish.
      outcome <- try (loop start (zip ests' results))
      writeChan done (outcome :: Either SomeException ())
  -- Collect every worker's outcome before re-throwing, so no thread is
  -- still writing into 'results' when control leaves this function.
  outcomes <- replicateM numCapabilities (readChan done)
  forM_ outcomes $ either throwIO return
  mapM_ sort results

  res <- mapM unsafeFreeze results
  return $ zip ests
         $ zipWith Bootstrap [estimate e samples | e <- ests]
                             res
 where
  ests' = map estimate ests
-- | Draw one resample of @v@: a vector of the same length whose elements
-- are chosen uniformly at random, with replacement, from @v@.  Each chosen
-- element is forced before it is stored.
resampleVector :: (PrimMonad m, G.Vector v a)
               => Gen (PrimState m) -> v a -> m (v a)
resampleVector gen v = G.replicateM len $
    uniformR (0, len - 1) gen >>= \ix -> return $! G.unsafeIndex v ix
  where
    len = G.length v
-- | Evaluate an estimator on every leave-one-out sub-sample of @sample@.
--
-- The built-in estimators dispatch to the specialised jackknife functions
-- below.  A 'Function' estimator is applied to each sub-sample built with
-- 'dropAt', and fails via 'singletonErr' on a one-element sample.
jackknife :: Estimator -> Sample -> U.Vector Double
jackknife est sample = case est of
  Mean             -> jackknifeMean sample
  Variance         -> jackknifeVariance sample
  VarianceUnbiased -> jackknifeVarianceUnb sample
  StdDev           -> jackknifeStdDev sample
  Function f
    | G.length sample == 1 -> singletonErr "jackknife"
    | otherwise            -> U.map (\i -> f (dropAt i sample)) (indices sample)
-- | Leave-one-out means: element @i@ is the mean of the sample with
-- element @i@ removed.
--
-- Each leave-one-out sum is built from a prefix sum ('pfxSumL') and a
-- suffix sum ('pfxSumR'), so no sub-sample is materialised.  Fails via
-- 'singletonErr' on a one-element sample.
jackknifeMean :: Sample -> U.Vector Double
jackknifeMean samp
  | n == 1    = singletonErr "jackknifeMean"
  | otherwise = G.zipWith (\lo hi -> (lo + hi) / denom) (pfxSumL samp) (pfxSumR samp)
  where
    n     = G.length samp
    denom = fromIntegral (n - 1)
-- | Shared worker for the jackknife variance functions.
--
-- Element @i@ of the result is the variance of the sample with element
-- @i@ removed, using the divisor @q - c@, where @q = n - 1@ is the
-- sub-sample size.  So @c = 0@ gives the biased estimate and @c = 1@ the
-- unbiased one.
--
-- Deviations are taken about the full-sample mean @m@.  Subtracting
-- @b * b / q@, where @b@ is the leave-one-out sum of deviations, recentres
-- the sum of squares on the sub-sample's own mean.  Fails via
-- 'singletonErr' on a one-element sample.
jackknifeVariance_ :: Double -> Sample -> U.Vector Double
jackknifeVariance_ c samp
  | len == 1  = singletonErr "jackknifeVariance"
  | otherwise = G.zipWith4 combine sqLeft sqRight devLeft devRight
  where
    len  = G.length samp
    m    = mean samp
    q    = fromIntegral len - 1
    devs = G.map (subtract m) samp
    sqs  = G.map (\x -> let d = x - m in d * d) samp
    sqLeft   = pfxSumL sqs
    sqRight  = pfxSumR sqs
    devLeft  = pfxSumL devs
    devRight = pfxSumR devs
    combine sl sr dl dr =
      let b = dl + dr
      in (sl + sr - (b * b) / q) / (q - c)
-- | Leave-one-out unbiased variances: for a sample of length @n@, each
-- sub-sample variance uses the divisor @n - 2@.
--
-- Fails via 'singletonErr' on a sample of one or two elements.  With two
-- elements the unbiased divisor would be zero.  The error names this
-- function, consistent with the other callers of 'singletonErr'.
jackknifeVarianceUnb :: Sample -> U.Vector Double
jackknifeVarianceUnb samp
  | len == 1 || len == 2 = singletonErr "jackknifeVarianceUnb"
  | otherwise            = jackknifeVariance_ 1 samp
  where
    len = G.length samp
-- | Leave-one-out biased variances: each sub-sample variance is divided by
-- the sub-sample size @n - 1@.
jackknifeVariance :: Sample -> U.Vector Double
jackknifeVariance samp = jackknifeVariance_ 0 samp
-- | Leave-one-out standard deviations: the element-wise square roots of
-- 'jackknifeVarianceUnb'.
jackknifeStdDev :: Sample -> U.Vector Double
jackknifeStdDev samp = U.map sqrt (jackknifeVarianceUnb samp)
-- | Compensated (Kahan-Babuška-Neumaier) prefix sums.
--
-- The result is one element longer than the input: element @i@ is the sum
-- of the first @i@ inputs, so element @0@ is @0@.
pfxSumL :: U.Vector Double -> U.Vector Double
pfxSumL xs = G.map kbn (G.scanl add zero xs)
-- | Compensated (Kahan-Babuška-Neumaier) suffix sums, the same length as
-- the input.
--
-- Element @i@ is the sum of all inputs strictly after position @i@, so the
-- last element is @0@.
pfxSumR :: U.Vector Double -> U.Vector Double
pfxSumR xs = G.tail (G.map kbn (G.scanr (\x acc -> add acc x) zero xs))
-- | Remove the element at index @n@.
--
-- No bounds check is done here.  Callers are expected to pass a valid
-- index; out-of-range values are rejected by 'U.slice'.
dropAt :: U.Unbox e => Int -> U.Vector e -> U.Vector e
dropAt n v = before U.++ after
  where
    before = U.slice 0 n v
    after  = U.slice (n + 1) (U.length v - n - 1) v
-- | Abort with an error that names the function (qualified with this
-- module's name) that received too small a sample.
singletonErr :: String -> a
singletonErr func =
  error (concat ["Statistics.Resampling.", func, ": not enough elements in sample"])
-- | Produce @n@ generators: the given one first, followed by @n - 1@ new
-- generators.  Each new generator is seeded with 256 'Word32' values drawn
-- from the given one.  Returns an empty list when @n <= 0@.
splitGen :: Int -> GenIO -> IO [GenIO]
splitGen n gen
  | n <= 0    = return []
  | otherwise = do
      children <- replicateM (n - 1) spawn
      return (gen : children)
  where
    -- Draw a seed and build a generator from it, one child at a time.
    spawn = do
      seed <- uniformVector gen 256 :: IO (U.Vector Word32)
      initialize seed