{-# LANGUAGE BangPatterns               #-}
module Data.Massiv.Core.Iterator
  ( loop
  , loopM
  , loopM_
  , loopDeepM
  , splitLinearly
  , splitLinearlyWith_
  , splitLinearlyWithM_
  ) where
-- | Strict, pure loop over an 'Int' index with an accumulator.
--
-- Starting from index @init'@ and accumulator @initAcc@, while
-- @condition@ holds for the current index the accumulator becomes
-- @f index acc@ and the index becomes @increment index@. The accumulator is
-- returned as soon as @condition@ fails (immediately if it fails for
-- @init'@). Both the index and the accumulator are forced on every step.
--
-- >>> loop 0 (< 5) (+ 1) 0 (+)
-- 10
loop :: Int -> (Int -> Bool) -> (Int -> Int) -> a -> (Int -> a -> a) -> a
loop !init' condition increment !initAcc f = step init' initAcc
  where
    step !i !acc
      | condition i = step (increment i) (f i acc)
      | otherwise = acc
{-# INLINE loop #-}
-- | Monadic counterpart of 'loop': thread an accumulator through a monadic
-- step function.
--
-- While @condition@ holds for the current index, @f index acc@ is run and
-- its result becomes the accumulator for @increment index@. The
-- accumulator is returned once @condition@ fails. Actions are sequenced in
-- visiting order; index and accumulator are forced before each step.
loopM :: Monad m => Int -> (Int -> Bool) -> (Int -> Int) -> a -> (Int -> a -> m a) -> m a
loopM !init' condition increment !initAcc f = step init' initAcc
  where
    step !i !acc
      | condition i = do
          acc' <- f i acc
          step (increment i) acc'
      | otherwise = return acc
{-# INLINE loopM #-}
-- | Run the monadic action @f@ on every index produced by starting at
-- @init'@ and repeatedly applying @increment@, as long as @condition@
-- holds. The results of @f@ are discarded; actions are sequenced in
-- visiting order.
loopM_ :: Monad m => Int -> (Int -> Bool) -> (Int -> Int) -> (Int -> m a) -> m ()
loopM_ !init' condition increment f = visit init'
  where
    visit !i
      | condition i = f i >> visit (increment i)
      | otherwise = return ()
{-# INLINE loopM_ #-}
-- | Like 'loopM', but the steps are nested instead of chained.
--
-- For each index satisfying @condition@, the remainder of the loop (the
-- later indices) is sequenced first, and its result is then passed to
-- @f index@. Consequently @initAcc@ is handed to the action for the last
-- visited index, and the action for @init'@ runs last. The recursion is
-- not in tail position, so each step waits on the rest of the loop.
loopDeepM :: Monad m => Int -> (Int -> Bool) -> (Int -> Int) -> a -> (Int -> a -> m a) -> m a
loopDeepM !init' condition increment !initAcc f = step init' initAcc
  where
    step !i !acc
      | condition i = f i =<< step (increment i) acc
      | otherwise = return acc
{-# INLINE loopDeepM #-}
-- | Lay out @totalLength@ elements as @numChunks@ equally sized chunks plus
-- a trailing remainder, and pass that layout to @action@.
--
-- @action@ receives the length of each regular chunk and the index at which
-- the remainder (the \"slack\") starts, i.e. @chunkLength * numChunks@.
-- Elements from the slack start up to @totalLength@ are not covered by any
-- regular chunk.
--
-- A non-positive @numChunks@ is treated as a single chunk. A count of @0@
-- would otherwise raise a divide-by-zero exception. A negative count would
-- otherwise yield a negative chunk length for a positive @totalLength@, and
-- a loop that steps by that length (as 'splitLinearlyWithM_' does) would
-- never terminate.
--
-- >>> splitLinearly 3 10 (,)
-- (3,9)
splitLinearly :: Int -> Int -> (Int -> Int -> a) -> a
splitLinearly numChunks totalLength action = action chunkLength slackStart
  where
    -- Clamp so the division below is always by a positive number.
    !chunks = max 1 numChunks
    !chunkLength = totalLength `quot` chunks
    !slackStart = chunkLength * chunks
{-# INLINE splitLinearly #-}
-- | Pure-producer variant of 'splitLinearlyWithM_'.
--
-- Each visited index @k@ is mapped through the pure function @index@, and
-- the result is passed to @write k@. Chunking and the use of @with@ are
-- exactly those of 'splitLinearlyWithM_'.
splitLinearlyWith_ :: Monad m => Int -> (m () -> m a) -> Int -> (Int -> b) -> (Int -> b -> m ()) -> m a
splitLinearlyWith_ numChunks with totalLength index write =
  splitLinearlyWithM_ numChunks with totalLength (\k -> pure (index k)) write
{-# INLINE splitLinearlyWith_ #-}
-- | Visit the indices @0 .. totalLength - 1@ in jobs submitted through
-- @with@. For each index @k@, a value is produced with @make k@ and then
-- consumed by @write k@.
--
-- The range is laid out by 'splitLinearly':
--
-- * Each regular chunk of @chunkLength@ consecutive indices becomes one
--   @with@ job. There are no such jobs when @totalLength < numChunks@,
--   because the chunk length is then 0.
-- * The remainder, from the slack start up to @totalLength@, is always
--   submitted as one final @with@ job, even when it is empty.
--
-- The overall result is whatever that final @with@ call returns. Within a
-- job, indices are visited in ascending order.
--
-- NOTE(review): @with@ is presumably a scheduler's "submit work" function;
-- confirm against callers.
splitLinearlyWithM_ ::
     Monad m => Int -> (m () -> m a) -> Int -> (Int -> m b) -> (Int -> b -> m c) -> m a
splitLinearlyWithM_ numChunks with totalLength make write =
  splitLinearly numChunks totalLength $ \chunkLength slackStart ->
    let runJob from to =
          with $ loopM_ from (< to) (+ 1) $ \ !k -> make k >>= write k
     in loopM_ 0 (< slackStart) (+ chunkLength) (\ !start -> runJob start (start + chunkLength))
          >> runJob slackStart totalLength
{-# INLINE splitLinearlyWithM_ #-}