module Control.Monad.Progress (
  
  WithProgress,
  
  runWithProgress,
  runWithPercentage,
  withProgressFromList,
  withProgressM,
  setWeight,
  printComponentTime,
  
  (C.>>>)
  ) where
import Control.DeepSeq
import Control.Monad       ( forM_, when )
import Control.Monad.Trans ( MonadIO (..) )
import qualified Control.Category as C
import Data.List           ( genericLength )
import Data.IORef          ( newIORef, atomicModifyIORef', readIORef )
import Data.Time           ( getCurrentTime )
-- | A pipeline of monadic steps from @a@ to @b@ that can report progress.
--
-- Every pipeline has a weight (see 'getWeight'); when it is run, the progress
-- reported by each step is rescaled by that step's share of the total weight.
data WithProgress m a b where
  -- | The empty pipeline: returns its input unchanged and has weight 0.
  Id
    :: WithProgress m a a
  -- | A single step of weight 1. It is given a progress callback and its
  -- input.
  WithProgressM
    :: ((Double -> m ()) -> a -> m b)
    -> WithProgress m a b
  -- | Sequential composition in function-composition order: the second
  -- argument runs first and its result is fed to the first argument.
  Combine
    :: WithProgress m b c
    -> WithProgress m a b
    -> WithProgress m a c
  -- | Replace the weight of a sub-pipeline with an explicit value.
  SetWeight
    :: Double
    -> WithProgress m a b
    -> WithProgress m a b
-- | Pipelines compose like functions. The identity is the empty pipeline, and
-- composing @q@ after @p@ builds a pipeline that runs @p@ and then @q@.
instance C.Category (WithProgress m) where
  id    = Id
  q . p = Combine q p
-- | Turn a pure list-producing function into a pipeline step.
--
-- The step reports @0@, then forces the result list one element at a time to
-- normal form. After each element it reports the fraction of elements forced
-- so far. For an empty result it reports @1@ after the initial @0@. The list
-- is returned once every element has been forced.
withProgressFromList :: forall a b m. (Monad m, NFData b) => (a -> [b]) -> WithProgress m a [b]
withProgressFromList f = WithProgressM step
  where
    step :: (Double -> m ()) -> a -> m [b]
    step notify input = do
      let items = f input
          total = genericLength items :: Double
          -- The `() <-` match makes the deepseq happen when this action is
          -- sequenced, even in lazy monads.
          forceAndNotify (i, item) = do
            () <- item `deepseq` return ()
            notify (i / total)
      notify 0
      mapM_ forceAndNotify (zip [1 ..] items)
      if null items then notify 1 else return ()
      return items
-- | Give a pipeline an explicit weight. This replaces the weight it would
-- otherwise get from its components (1 per step, summed over compositions).
setWeight :: Double -> WithProgress m a b -> WithProgress m a b
setWeight weight pipeline = SetWeight weight pipeline
-- | Lift a monadic step into a pipeline of weight 1.
--
-- The step receives a callback for reporting its own progress and its input.
-- Because the step's weight is 1, it is expected to report values in
-- @[0, 1]@.
withProgressM :: ((Double -> m ()) -> a -> m b) -> WithProgress m a b
withProgressM = WithProgressM
-- | Run a pipeline, reporting overall progress as a fraction of the whole.
--
-- Raw component reports lie in @[0, w]@, where @w@ is the pipeline's total
-- weight (see 'getWeight'). They are divided by @w@, so steps that report
-- within @[0, 1]@ yield an overall fraction in @[0, 1]@.
--
-- A pipeline with total weight 0 cannot be normalised; examples are 'Id',
-- @C.id >>> C.id@, or a top-level @setWeight 0@. Such a pipeline is run
-- without intermediate reports and reports @1@ once it has finished. This
-- avoids dividing by zero, and makes every zero-weight pipeline behave like
-- 'Id', which reports @1@.
runWithProgress :: Monad m => WithProgress m a b -> (Double -> m ()) -> a -> m b
runWithProgress p r a
  | w == 0    = runWithProgress' p (const (return ())) a >>= \b -> r 1 >> return b
  | otherwise = runWithProgress' p (r . (/ w)) a
  where
    w = getWeight p
-- | Run a pipeline, reporting progress as an integer percentage.
--
-- 'Id' just reports @100@. Any other pipeline reports @0@ first. Each raw
-- report @d@ (in @[0, w]@, with @w@ the total weight) is converted to
-- @floor (d * 100 / w)@ and forwarded only when it differs from the last
-- value sent. After the pipeline finishes, @100@ is reported if it has not
-- been already.
--
-- A non-'Id' pipeline with total weight 0 has no meaningful intermediate
-- progress. It reports @0@, runs silently, then reports @100@, rather than
-- converting NaN/Infinity to an 'Int'.
runWithPercentage :: MonadIO m => WithProgress m a b -> (Int -> m ()) -> a -> m b
runWithPercentage Id r a = r 100 >> return a
runWithPercentage p  r a
  | w == 0 = do
      r 0
      ret <- runWithProgress' p (const (return ())) a
      r 100
      return ret
  | otherwise = do
      r 0
      -- Last percentage sent to `r`, used to suppress duplicate reports.
      prevR <- liftIO $ newIORef 0
      let report d = do
            let new = floor (d * 100 / w)
            isNew <- liftIO $ atomicModifyIORef' prevR $ \prev -> (new, prev /= new)
            when isNew $ r new
      ret <- runWithProgress' p report a
      final <- liftIO $ readIORef prevR
      -- Report 100 directly. `report` expects a raw value in [0, w] and would
      -- rescale 100 to floor (100 * 100 / w).
      when (final /= 100) $ r 100
      return ret
  where
    w = getWeight p
-- | Run a pipeline with a reporter that receives raw progress values.
--
-- Raw values lie in @[0, getWeight p]@ rather than being normalised.
--
-- * 'Id' does no work and reports nothing.
-- * 'WithProgressM' passes the reporter straight to the step. The step has
--   weight 1, so it is expected to report within @[0, 1]@.
-- * 'SetWeight' rescales the inner pipeline's range @[0, wp]@ onto @[0, w]@.
--   If the inner weight @wp@ is 0 the range cannot be rescaled. Its reports
--   are then mapped to 0 (the start of this component) instead of NaN.
-- * 'Combine' runs @p@ first, then @q@ with its reports shifted by @p@'s
--   weight, so the two ranges are laid end to end.
runWithProgress' :: Monad m => WithProgress m a b -> (Double -> m ()) -> a -> m b
runWithProgress' Id                _ a = return a
runWithProgress' (WithProgressM p) r a = p r a
runWithProgress' (SetWeight w p)   r a = runWithProgress' p (r . rescale) a where
  wp = getWeight p
  -- Dividing by a zero inner weight would send NaN/Infinity to the reporter.
  rescale d
    | wp == 0   = 0
    | otherwise = (d / wp) * w
runWithProgress' (Combine q p)     r a = runWithProgress' p r a >>= runWithProgress' q (r . (+wp)) where
  wp = getWeight p
-- | Run a pipeline without progress reporting, printing timestamps.
--
-- The current time is printed before the pipeline starts, at every
-- composition boundary, and after it finishes. Weights set with 'setWeight'
-- are ignored, and steps are given a reporter that does nothing.
printComponentTime :: MonadIO m => WithProgress m a b -> a -> m b
printComponentTime pipeline input = do
  printTime
  result <- timed pipeline input
  printTime
  return result
  where
    -- Walk the pipeline, printing a timestamp between the two halves of each
    -- composition.
    timed :: MonadIO m => WithProgress m x y -> x -> m y
    timed Id                      x = return x
    timed (SetWeight _ inner)     x = timed inner x
    timed (WithProgressM action)  x = action (\_ -> return ()) x
    timed (Combine later earlier) x = do
      mid <- timed earlier x
      printTime
      timed later mid
-- | Print the current wall-clock time (a 'Data.Time.UTCTime') to stdout.
printTime :: MonadIO m => m ()
printTime = liftIO $ do
  now <- getCurrentTime
  print now
-- | The weight of a pipeline: the width of the raw progress range it reports
-- over.
--
-- The empty pipeline weighs 0 and a single step weighs 1. A composition
-- weighs the sum of its parts, and 'SetWeight' overrides the weight of
-- whatever it wraps.
getWeight :: WithProgress m a b -> Double
getWeight pipeline = case pipeline of
  Id                    -> 0
  WithProgressM _       -> 1
  SetWeight weight _    -> weight
  Combine later earlier -> getWeight later + getWeight earlier