{-# LANGUAGE BangPatterns #-}
{-# OPTIONS_GHC -fno-warn-unused-do-bind #-}
module Main (main) where

import Control.Exception (evaluate)
import Control.Monad (when)
import GHC.Real
import System.CPUTime
import System.Environment (getArgs)
import Text.Printf
import Text.Read (readMaybe)

import Data.Array.Base (unsafeAt)
import Data.Array.Unboxed
import System.Random

import Implementation

-- | Sum @conv@ over @len + 1@ rationals read from @arr@; this is the
-- 'Double' benchmark workload.
--
-- For each @k@ from @len@ down to @0@ the rational is @n / (n*d + 1)@ with
-- @n = arr[2k]@ and @d = arr[2k+1]@ (0-based flat offsets), so for
-- @len >= 0@ the array must hold at least @2*len + 2@ elements. Terms are
-- added to the accumulator starting with @k = len@. A negative @len@
-- yields @0@ without touching the array.
test :: (Rational -> Double) -> UArray Int Int -> Int -> Double
test conv arr len = loop len 0.0
  where
    -- Built with ':%' to skip 'gcd' normalisation. Every common divisor of
    -- n and n*d + 1 divides 1, so the fraction is already in lowest terms.
    -- The denominator is positive only if the entries are positive, which
    -- is assumed here (mkIArr draws them from [2, maxBound-1]). The product
    -- is formed in 'Integer', so it cannot overflow.
    rat :: Int -> Rational
    rat k = fromIntegral n :% (fromIntegral n * fromIntegral d + 1)
      where
        !n = arr `unsafeAt` (2*k)
        !d = arr `unsafeAt` (2*k+1)
    -- Strict in both arguments so no chain of (+) thunks can build up.
    -- Stopping on k < 0, rather than matching exactly 0, keeps a negative
    -- len from indexing before the array with the unchecked 'unsafeAt'.
    loop :: Int -> Double -> Double
    loop !k !acc
      | k < 0     = acc
      | otherwise = loop (k-1) (acc + conv (rat k))

-- | Benchmark driver. Times 'fromRationalDouble' and 'fromRationalFloat'
-- (imported from "Implementation") against the Prelude 'fromRational' on
-- the same pseudo-random input.
--
-- Command line: @-bd N@ sets the workload size (default 200000), and each
-- benchmark converts @N + 1@ rationals. Other argument lists fall back to
-- the default.
main :: IO ()
main = do
    args <- getArgs
    size <-
        case args of
          ("-bd":sz:_) ->
            case readMaybe sz of
              -- Reject malformed or negative sizes up front with a clear
              -- message, instead of a lazy 'read' failure or a negative
              -- loop bound reaching the unchecked array indexing.
              Just n | n >= 0 -> pure n
              _ -> ioError (userError
                     ("-bd expects a non-negative integer, got " ++ show sz))
          _ -> pure 200000
    sg <- getStdGen
    let iarr = mkIArr size sg
        -- Forcing ds (bang) builds iarr and runs one full conversion pass
        -- before any timing starts. Array construction is therefore not
        -- charged to the first benchmark.
        !ds = test (recip . fromRationalDouble . recip) iarr size
    -- Use ds so the warm-up result is consumed rather than discarded.
    when (ds == 0) (putStrLn "Jackpot Double")
    sequence_
        [ bench "New fromRational :: Rational -> Double"
            (test fromRationalDouble iarr size)
        , bench "Old fromRational :: Rational -> Double"
            (test fromRational iarr size)
        -- Now for Rational -> Float
        , bench "New fromRational :: Rational -> Float"
            (fest fromRationalFloat iarr size)
        , bench "Old fromRational :: Rational -> Float"
            (fest fromRational iarr size)
        ]

-- | 'Float' counterpart of @test@: sum @conv@ over @len + 1@ rationals read
-- from @arr@.
--
-- For each @k@ from @len@ down to @0@ the rational is @n / (n*d + 1)@ with
-- @n = arr[2k]@ and @d = arr[2k+1]@ (0-based flat offsets), so for
-- @len >= 0@ the array must hold at least @2*len + 2@ elements. Terms are
-- added to the accumulator starting with @k = len@. A negative @len@
-- yields @0@ without touching the array.
fest :: (Rational -> Float) -> UArray Int Int -> Int -> Float
fest conv arr len = loop len 0.0
  where
    -- Built with ':%' to skip 'gcd' normalisation. Every common divisor of
    -- n and n*d + 1 divides 1, so the fraction is already in lowest terms.
    -- The denominator is positive only if the entries are positive, which
    -- is assumed here (mkIArr draws them from [2, maxBound-1]). The product
    -- is formed in 'Integer', so it cannot overflow.
    rat :: Int -> Rational
    rat k = fromIntegral n :% (fromIntegral n * fromIntegral d + 1)
      where
        !n = arr `unsafeAt` (2*k)
        !d = arr `unsafeAt` (2*k+1)
    -- Strict in both arguments so no chain of (+) thunks can build up.
    -- Stopping on k < 0, rather than matching exactly 0, keeps a negative
    -- len from indexing before the array with the unchecked 'unsafeAt'.
    loop :: Int -> Float -> Float
    loop !k !acc
      | k < 0     = acc
      | otherwise = loop (k-1) (acc + conv (rat k))

-- | Fill a 0-indexed array with @2*num + 2@ pseudo-random 'Int's in
-- @[2, maxBound - 1]@, drawn in order from the generator's 'randomRs'
-- stream. The array is empty when @num < 0@.
mkIArr :: Int -> StdGen -> UArray Int Int
mkIArr num gen = listArray bnds (randomRs (2, maxBound - 1) gen)
  where
    -- The stream is infinite; 'listArray' consumes exactly enough values.
    bnds = (0, 2*num + 1)

-- | Scale factor turning a 'getCPUTime' difference (picoseconds) into
-- seconds.
prec :: Double
prec = 1.0e-12

-- | Time how long it takes to force @val@ and print the elapsed CPU time in
-- seconds under the label @name@.
--
-- 'evaluate' only reaches weak head normal form. That is a full evaluation
-- for the 'Double' / 'Float' results benchmarked here, but not for lazy
-- structures.
bench :: String -> a -> IO ()
bench name val = do
    start <- getCPUTime
    _ <- evaluate val
    end <- getCPUTime
    let seconds = fromInteger (end - start) * prec
    printf "%s:\n    took %14.8fs\n" name seconds
