-- | A reservoir that keeps a fixed-capacity buffer of 'Double' samples.
--
-- While the buffer still has free slots every value is stored; once it is
-- full, each new value either overwrites a randomly chosen slot or is
-- dropped, using a generator restored from a saved MWC 'Seed'.
--
-- Two flavours of the mutating operations are exported: 'clear' and 'update'
-- copy the underlying vector before writing to it, while 'unsafeClear' and
-- 'unsafeUpdate' write to the existing vector in place.
module Data.Metrics.Reservoir.Uniform (
  UniformReservoir,
  reservoir,
  unsafeReservoir,
  clear,
  unsafeClear,
  size,
  snapshot,
  update,
  unsafeUpdate
) where
import Control.Monad.ST
import Data.Metrics.Internal
import Data.Time.Clock
import qualified Data.Metrics.Reservoir as R
import qualified Data.Metrics.Snapshot as S
import Data.Primitive.MutVar
import System.Random.MWC
import qualified Data.Vector.Unboxed as I
import qualified Data.Vector.Unboxed.Mutable as V
-- | Build a 'R.Reservoir' backed by a 'UniformReservoir' whose operations
-- copy the sample buffer before modifying it ('clear' and 'update').
--
-- The buffer starts out zero-filled with a recorded count of 0.
reservoir
  :: Seed -- ^ initial state of the random number generator
  -> Int -- ^ number of slots in the sample buffer
  -> R.Reservoir
reservoir seed capacity = R.Reservoir
  { R._reservoirState = emptyState
  , R._reservoirUpdate = update
  , R._reservoirSnapshot = snapshot
  , R._reservoirSize = size
  , R._reservoirClear = clear
  }
  where
    emptyState = UniformReservoir 0 (I.replicate capacity 0) seed
-- | Build a 'R.Reservoir' backed by a 'UniformReservoir' whose mutating
-- operations ('unsafeClear' and 'unsafeUpdate') write to the sample buffer
-- in place instead of copying it.
--
-- The buffer starts out zero-filled with a recorded count of 0.
unsafeReservoir
  :: Seed -- ^ initial state of the random number generator
  -> Int -- ^ number of slots in the sample buffer
  -> R.Reservoir
unsafeReservoir seed capacity = R.Reservoir
  { R._reservoirState = emptyState
  , R._reservoirUpdate = unsafeUpdate
  , R._reservoirSnapshot = snapshot
  , R._reservoirSize = size
  , R._reservoirClear = unsafeClear
  }
  where
    emptyState = UniformReservoir 0 (I.replicate capacity 0) seed
-- | State of a uniform reservoir: a fixed-length sample buffer, a counter of
-- the values offered to it, and the saved state of the random generator.
data UniformReservoir = UniformReservoir
  { _urCount :: !Int
  -- ^ Number of values recorded since creation or the last clear; it is
  -- incremented on every update, so it can exceed the buffer length.
  , _urReservoir :: !(I.Vector Double)
  -- ^ The sample buffer. Only the first @min count length@ slots hold
  -- recorded values (see 'size'); the remaining slots are zero.
  , _urSeed :: !Seed
  -- ^ Saved generator state, restored at the start of each update and
  -- saved again afterwards.
  }
-- | Reset the reservoir without touching the existing buffer.
--
-- The result has a count of 0 and a fresh all-zero buffer with the same
-- number of slots as the input; the generator seed is carried over. The time
-- argument is ignored.
clear :: NominalDiffTime -> UniformReservoir -> UniformReservoir
clear _ r = r { _urCount = 0, _urReservoir = zeroed }
  where
    -- A new zero-filled vector of the same length; the old one is left as is.
    zeroed = I.replicate (I.length (_urReservoir r)) 0
-- | Reset the reservoir by zeroing its buffer in place.
--
-- The result has a count of 0 and reuses the input's buffer, which is
-- overwritten with zeros without being copied. Any other value still holding
-- the old buffer observes the overwrite. The time argument is ignored.
unsafeClear :: NominalDiffTime -> UniformReservoir -> UniformReservoir
unsafeClear _ r = r { _urCount = 0, _urReservoir = zeroed }
  where
    zeroed = runST $ do
      -- No copy: the mutable view aliases the reservoir's own storage.
      buf <- I.unsafeThaw (_urReservoir r)
      V.set buf 0
      I.unsafeFreeze buf
-- | Number of slots that currently hold recorded values: the recorded count,
-- capped at the length of the sample buffer.
size :: UniformReservoir -> Int
size r = min (_urCount r) (I.length (_urReservoir r))
-- | Take a snapshot of the recorded values, i.e. the first 'size' slots of
-- the sample buffer.
snapshot :: UniformReservoir -> S.Snapshot
snapshot r = runST $ do
  -- NOTE(review): no copy is made here, so the mutable view aliases the
  -- reservoir's own buffer. This is only sound if 'S.takeSnapshot' does not
  -- write to the vector it is given -- confirm against its implementation.
  buf <- I.unsafeThaw (_urReservoir r)
  let recorded = V.slice 0 (size r) buf
  S.takeSnapshot recorded
-- | Record a value, working on a copy of the sample buffer so the input
-- reservoir is left untouched.
--
-- While the buffer still has free slots the value is stored in the next one.
-- Once it is full, the value replaces a uniformly chosen slot with probability
-- @length / newCount@ and is dropped otherwise (reservoir sampling). The time
-- argument is ignored.
--
-- The result carries the incremented count, the updated buffer and the
-- generator state saved after the draw.
update :: Double -> NominalDiffTime -> UniformReservoir -> UniformReservoir
update x _ c = c { _urCount = newCount, _urReservoir = newRes, _urSeed = newSeed }
  where
    seen = _urCount c
    newCount = succ seen
    (newSeed, newRes) = runST $ do
      v' <- I.thaw $ _urReservoir c
      g <- restore (_urSeed c)
      if newCount <= V.length v'
        -- Still filling: @seen@ is the index of the first free slot.
        then V.unsafeWrite v' seen x
        else do
          -- 'uniformR' is inclusive at both ends for integral types, so the
          -- upper bound must be @newCount - 1@ (= @seen@) to draw from exactly
          -- @newCount@ candidates. Using @newCount@ itself would under-sample
          -- new values (probability k/(n+1) instead of k/n).
          i <- uniformR (0, seen) g
          if i < V.length v'
            then V.unsafeWrite v' i x
            else return ()
      v'' <- I.unsafeFreeze v'
      s <- save g
      return (s, v'')
-- | Record a value by writing into the existing sample buffer in place.
--
-- Behaves like the copying update, but the buffer is not duplicated: the
-- returned reservoir shares its storage with the input, so the input value
-- must not be used again after this call.
--
-- While the buffer still has free slots the value is stored in the next one.
-- Once it is full, the value replaces a uniformly chosen slot with probability
-- @length / newCount@ and is dropped otherwise (reservoir sampling). The time
-- argument is ignored.
unsafeUpdate :: Double -> NominalDiffTime -> UniformReservoir -> UniformReservoir
unsafeUpdate x _ c = c { _urCount = newCount, _urReservoir = newRes, _urSeed = newSeed }
  where
    seen = _urCount c
    newCount = succ seen
    (newSeed, newRes) = runST $ do
      -- No copy: the mutable view aliases the reservoir's own storage.
      v' <- I.unsafeThaw $ _urReservoir c
      g <- restore (_urSeed c)
      if newCount <= V.length v'
        -- Still filling: @seen@ is the index of the first free slot.
        then V.unsafeWrite v' seen x
        else do
          -- 'uniformR' is inclusive at both ends for integral types, so the
          -- upper bound must be @newCount - 1@ (= @seen@) to draw from exactly
          -- @newCount@ candidates. Using @newCount@ itself would under-sample
          -- new values (probability k/(n+1) instead of k/n).
          i <- uniformR (0, seen) g
          if i < V.length v'
            then V.unsafeWrite v' i x
            else return ()
      v'' <- I.unsafeFreeze v'
      s <- save g
      return (s, v'')