{-# LANGUAGE ScopedTypeVariables #-}
-- | Progress estimates.
--
--   'progress' is good for functions whose recursion trees are very unbalanced.
--   'progressWithCalls' is good for functions that finish consuming their input
--   long before they finish running.
module Data.Progress (progress, progressWithFile, progressWithCalls, progress', progressWithCalls') where

import System.IO.Unsafe
import System.IO
import Data.Data
import Control.Monad.Identity
import Control.Monad
import Control.Concurrent
import Control.Concurrent.MVar
import Control.Exception

newtype Size t = Size { unSize :: Int }

-- | Number of constructor nodes in a value: one for its outermost
--   constructor plus, recursively, the nodes of every immediate subterm.
size :: (Data t) => t -> Int
size = unSize . gfoldl addChild (\_ -> Size 1)
	where
		-- Add one more immediate subterm's node count to the tally.
		addChild :: (Data d) => Size (d -> b) -> d -> Size b
		addChild (Size acc) child = Size (acc + size child)

-- | How many fiftieths of @y@ the amount @x@ makes up, rounded toward zero,
--   i.e. how many of the 50 bar slots @x@ out of @y@ fills.  A zero total
--   has no meaningful fraction, so it yields 0 instead of throwing a
--   divide-by-zero exception (an empty file has size 0).
fiftieth :: (Integral a) => a -> a -> a
fiftieth _ 0 = 0
fiftieth x y = x * 50 `quot` y

-- | Print the bars that take the display from @prev@ out of @sz@ to @n@ out
--   of @sz@: one @'|'@ per fiftieth gained.  Nothing is printed when no
--   fiftieth was gained (or the count went down).
putBar :: Int -> Int -> Int -> IO ()
putBar n prev sz = replicateM_ gained (putChar '|')
	where
		gained = fiftieth n sz - fiftieth prev sz

-- | Estimate progress based on thunks forced.
--
--   @f@ is run on a lazy copy of @dat@.  As @f@ forces the copy, bars are
--   printed in proportion to the number of constructor nodes (see 'size')
--   reached so far.  @f@'s result is forced to weak head normal form; once
--   @f@ returns or throws, no further bars are printed and @]@ is written.
progress f dat = do
	putChar '['

	-- Only non-root nodes are counted below (the combining function of
	-- 'gfoldl' runs once per subterm), so a fully forced copy reaches
	-- @size dat - 1@.  Using that as the total lets the bar fill all 50
	-- slots; 'max 1' covers values that have no subterms.
	let sz = max 1 (size dat - 1)

	-- Nodes counted so far, or -1 once 'f' has finished.
	count <- newMVar 0

	-- 'rec' makes a copy of a value with I/O effects attached: forcing a
	-- copied constructor counts each of its immediate subterms and prints
	-- any bars that became due.  The effects only touch 'count' and
	-- stdout; the rebuilt value is the same as the original.
	let
		rec :: (Data t) => t -> t
		rec node = runIdentity $ gfoldl
			(\(Identity build) child -> unsafePerformIO $ do
				modifyMVar_ count $ \n ->
					-- After 'f' has finished (-1), stay silent.
					if n == -1 then
						return n
					else do
						let n' = n + 1
						putBar n' n sz
						return n'
				return $ Identity $ build $ rec child)
			Identity
			node

	-- Run the function on the copy, forcing its result to WHNF.
	finally
		(f (rec dat) >>= evaluate)
		(do
		-- Record that the function is done so no more bars are printed.
		modifyMVar_ count $ const $ return (-1)
		putStrLn "]")

-- | 'try' fixed to catch every exception, i.e. at type 'SomeException'.
try' :: IO t -> IO (Either SomeException t)
try' action = try action

-- | ...based on amount of file consumed.
--
--   While @f@ runs, a background thread polls @hdl@'s position ('hTell')
--   every half second and prints bars for the fraction of its size
--   ('hFileSize') consumed so far.  If the size cannot be read, or the file
--   is empty, no bars are printed.  If the position stops being readable
--   (for example because @f@ closed the handle), polling stops quietly
--   instead of the thread dying with an error message.
progressWithFile f hdl = do
	putChar '['

	-- Check the position of the handle periodically and print
	-- a progress bar.
	thd <- try' $ do
		sz <- liftM fromInteger $ hFileSize hdl
		let poll prev = do
			n <- liftM fromInteger $ hTell hdl
			putBar n prev sz
			threadDelay 500000
			poll n
		-- An empty file has no meaningful fraction to show; an I/O
		-- error means the handle can no longer be polled.
		forkIO $ when (sz > 0) $
			handle (\(_ :: IOException) -> return ()) (poll 0)

	finally
		-- Run the function.
		(f hdl)
		(do
		-- Again, prevent the progress bar from being printed once
		-- the function is done.
		try' $ either (\_ -> return ()) killThread thd
		putStrLn "]")

-- | ...based on number of recursive calls.
--
--   It returns a result equivalent to that of /fix f x/.
--
--   @f@ receives the function to use for its recursive calls.  As each call
--   completes, an estimate of the recursion tree is updated: a call that
--   made no recursive calls raises the maximum leaf depth, any other call
--   raises the maximum number of direct recursive calls per call (the
--   branching factor).  The bar shows the completed calls as a fraction of
--   the calls in a complete tree of that depth and branching factor, and
--   never grows past 50 bars.  Nothing is drawn until more than ten calls
--   have completed and at least one call has made recursive calls.
progressWithCalls f x = do
	putChar '['

	-- As the function runs, the procedure will estimate the
	-- depth and branching factor of the recursion tree.
	-- State: (max leaf depth, max branching factor, completed calls,
	-- bars printed); a negative call count means the run is over.
	parms <- newMVar (0 :: Int, 0 :: Int, 0 :: Int, 0 :: Int)
	let rec depth count arg = do
		modifyMVar_ count $ \n -> return $! n + 1

		-- Do a recursive call. The call gets a fresh recursion counter.
		count' <- newMVar 0
		res <- f (rec (depth + 1) count') arg >>= evaluate

		calls <- readMVar count'
		modifyMVar_ parms $ \st@(mxDep, mxCount, total, shown) ->
			if total < 0 then return st else do
				-- Calculate the new maxima.
				let (mxDep', mxCount') = if calls == 0
					then (depth `max` mxDep, mxCount)
					else (mxDep, calls `max` mxCount)
				let total' = total + 1

				-- Bars the updated estimate calls for, at most 50.  The
				-- estimate is computed in 'Integer' so that large trees
				-- cannot overflow.  Without a branching factor yet there
				-- is no estimate, so nothing new is shown.
				let scaled = toInteger total' * 50
				let target = if total >= 10 && mxCount' > 0
					then fromInteger $ min 50 $ scaled `quot`
						estimatedCalls (scaled + 1) (toInteger mxCount') (toInteger mxDep')
					else shown

				-- Printed bars cannot be taken back: print only the
				-- growth and remember the high-water mark.
				putBar target shown 50
				return (mxDep', mxCount', total', shown `max` target)

		return res
	count <- newMVar 0
	finally
		(rec 0 count x >>= evaluate)
		(do
		modifyMVar_ parms $ const $ return (0, 0, -1, 0)
		putStrLn "]")

-- | @estimatedCalls cap b d@ is the number of calls in a complete recursion
--   tree with branching factor @b@ and leaf depth @d >= 0@ (the root is at
--   depth 0), i.e. @1 + b + b^2 + ... + b^d@, limited to at most @cap@.  A
--   branching factor of at most one is treated as a single chain of @d + 1@
--   calls.  The limit keeps the arithmetic cheap for deep, wide trees; pass
--   a @cap@ above which the exact count no longer matters.
estimatedCalls :: Integer -> Integer -> Integer -> Integer
estimatedCalls cap b d
	| b <= 1 = min cap (d + 1)
	| otherwise = go 1 1 d
	where
		-- acc: calls counted so far; width: calls on the deepest level
		-- counted so far; k: levels still to add below it.
		go acc width k
			| k <= 0 || acc >= cap = min cap acc
			| otherwise = go (acc + width * b) (width * b) (k - 1)

-- | Adapter for pure functions: 'progress' for a function that does no I/O.
progress' f = progress lifted
	where
		lifted input = return (f input)

-- | Adapter for pure functions: 'progressWithCalls' for an open-recursive
--   function that does no I/O.  Its recursive calls are run through
--   'unsafePerformIO', so the call counting happens when the pure
--   results are demanded.
progressWithCalls' f = progressWithCalls wrapped
	where
		wrapped recurse = return . f (unsafePerformIO . recurse)