{-# LANGUAGE BangPatterns, ScopedTypeVariables, TypeOperators #-}

module Criterion.Measurement
    (
      getTime
    , runForAtLeast
    , secs
    , time
    , time_
    ) where
    
import Control.Monad (when)
import Data.Array.Vector ((:*:)(..))
import Data.Time.Clock.POSIX (getPOSIXTime)
import Text.Printf (printf)
        
-- | Run an action once and measure how long it took.
--
-- The clock is read with 'getTime' immediately before and immediately
-- after the action.  The result pairs the difference between the two
-- readings (in seconds) with the value the action produced.
time :: IO a -> IO (Double :*: a)
time act = do
  before <- getTime
  value <- act
  after <- getTime
  let elapsed = after - before
  return (elapsed :*: value)

-- | Run an action once, discard its result, and return only the time
-- it took (in seconds), as measured by two calls to 'getTime'.
--
-- The elapsed time is forced before it is returned, so the caller does
-- not receive an unevaluated subtraction.
time_ :: IO a -> IO Double
time_ act = do
  before <- getTime
  _ <- act
  after <- getTime
  let elapsed = after - before
  elapsed `seq` return elapsed

-- | Read the current POSIX time and convert it to a 'Double' number of
-- seconds.
getTime :: IO Double
getTime = fmap realToFrac getPOSIXTime

-- | Run an action with a growing seed until one run takes long enough.
--
-- The action is first run with @initSeed@.  Whenever a single run
-- finishes in less than @howLong@ seconds, the seed is doubled and the
-- action is run again.  The first run that takes at least @howLong@
-- seconds ends the search, and its elapsed time, the seed it was given,
-- and its result are returned.
--
-- Before every run the total time since the start is checked; once it
-- exceeds ten times @howLong@ the computation is aborted with 'fail',
-- reporting the current seed and the number of runs completed so far.
runForAtLeast :: Double -> Int -> (Int -> IO a) -> IO (Double :*: Int :*: a)
runForAtLeast howLong initSeed act = do
  began <- getTime
  let budget = howLong * 10
      -- Both arguments are kept strict so no chain of thunks builds up
      -- across attempts.
      go !n !attempts = do
        current <- getTime
        when (current - began > budget) $
          fail (printf "took too long to run: seed %d, iters %d" n attempts)
        took :*: value <- time (act n)
        if took < howLong
          then go (n * 2) (attempts + 1)
          else return (took :*: n :*: value)
  go initSeed (0 :: Int)

-- | Render a duration given in seconds as a short human-readable string.
--
-- Negative values are rendered as a minus sign followed by the rendering
-- of their absolute value.  Otherwise the largest unit among s, ms, us,
-- ns and ps whose scaled value is at least 1 is chosen; values too small
-- for any of those units (including zero) fall back to a plain @%g@
-- rendering in seconds.
--
-- The number of decimal places shrinks as the scaled value grows, so the
-- numeric part keeps a roughly constant width.
secs :: Double -> String
secs k
    | k < 0     = '-' : secs (negate k)
    | otherwise = choose unitTable
  where
    -- (smallest value handled by the unit, scale factor, unit suffix),
    -- ordered from the largest unit to the smallest.
    unitTable :: [(Double, Double, String)]
    unitTable =
        [ (1,     1,    "s")
        , (1e-3,  1e3,  "ms")
        , (1e-6,  1e6,  "us")
        , (1e-9,  1e9,  "ns")
        , (1e-12, 1e12, "ps")
        ]

    -- Use the first unit whose lower bound the value reaches.
    choose :: [(Double, Double, String)] -> String
    choose [] = printf "%g s" k
    choose ((lowest, factor, unit) : rest)
        | k >= lowest = render (k * factor) unit
        | otherwise   = choose rest

    -- (lower bound of the scaled value, printf format), largest first.
    formatTable :: [(Double, String)]
    formatTable =
        [ (1e9, "%.4g %s")
        , (1e6, "%.0f %s")
        , (1e5, "%.1f %s")
        , (1e4, "%.2f %s")
        , (1e3, "%.3f %s")
        , (1e2, "%.4f %s")
        , (1e1, "%.5f %s")
        ]

    -- Print a scaled value with its unit, using the first format whose
    -- bound the value reaches, or six decimals when it reaches none.
    render :: Double -> String -> String
    render t unit = printf (pickFormat formatTable) t unit
      where
        pickFormat :: [(Double, String)] -> String
        pickFormat [] = "%.6f %s"
        pickFormat ((bound, fmt) : more)
            | t >= bound = fmt
            | otherwise  = pickFormat more