{-# LANGUAGE BangPatterns          #-}
{-# LANGUAGE FlexibleContexts      #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE ScopedTypeVariables   #-}
{-# LANGUAGE UndecidableInstances  #-}
module Data.Massiv.Array.Ops.Fold.Internal
  (
    foldlS
  , foldrS
  , ifoldlS
  , ifoldrS
  
  , foldlM
  , foldrM
  , foldlM_
  , foldrM_
  , ifoldlM
  , ifoldrM
  , ifoldlM_
  , ifoldrM_
  
  , fold
  , foldlInternal
  , foldrFB
  , lazyFoldlS
  , lazyFoldrS
  
  , foldlP
  , foldrP
  , ifoldlP
  , ifoldrP
  , foldlOnP
  , ifoldlIO
  , foldrOnP
  , ifoldlOnP
  , ifoldrOnP
  , ifoldrIO
  ) where
import           Control.Monad              (void, when)
import qualified Data.Foldable              as F
import           Data.Functor.Identity      (runIdentity)
import           Data.Massiv.Core
import           Data.Massiv.Core.Common
import           Data.Massiv.Core.Scheduler
import           Prelude                    hiding (all, and, any, foldl, foldr,
                                             maximum, minimum, or, product, sum)
import           System.IO.Unsafe           (unsafePerformIO)
-- | Reduce an array with a single binary function. The function and the
-- starting value are both handed to 'foldlInternal' twice: once for folding the
-- elements and once for combining the folded result(s). The starting value can
-- therefore take part in the reduction more than once, so a value that is not
-- neutral with respect to the function will be counted repeatedly.
fold :: Source r ix e =>
        (e -> e -> e) -- ^ Folding function; it is also used for combining
                      -- the partial results.
     -> e -- ^ Starting value, used for the element fold as well as for
          -- the combining step.
     -> Array r ix e -- ^ Array to reduce.
     -> e
fold f initAcc = foldlInternal f initAcc f initAcc
{-# INLINE fold #-}
-- | Monadic left fold over the elements of an array. It is 'ifoldlM' with a
-- step that simply ignores the index it is given.
foldlM :: (Source r ix e, Monad m) => (a -> e -> m a) -> a -> Array r ix e -> m a
foldlM f = ifoldlM step
  where
    step acc _ e = f acc e
{-# INLINE foldlM #-}
-- | Monadic left fold over the elements of an array that discards the final
-- accumulator. It is 'ifoldlM_' with a step that ignores the index.
foldlM_ :: (Source r ix e, Monad m) => (a -> e -> m a) -> a -> Array r ix e -> m ()
foldlM_ f = ifoldlM_ withoutIndex
  where
    withoutIndex acc _ e = f acc e
{-# INLINE foldlM_ #-}
-- | Monadic left fold that also supplies the index of every element to the
-- folding action. Iteration is delegated to 'iterM', which is asked to go from
-- 'zeroIndex' towards the size of the array with a step of @1@ in every
-- dimension, using @(<)@ as the continuation test. Both the initial accumulator
-- and the array are forced before iteration starts.
ifoldlM :: (Source r ix e, Monad m) => (a -> ix -> e -> m a) -> a -> Array r ix e -> m a
ifoldlM f !acc !arr = iterM zeroIndex (size arr) (pureIndex 1) (<) acc step
  where
    -- Look up the element at the visited index and hand it, together with the
    -- index and the current accumulator, to the folding action.
    step !ix !a = f a ix (unsafeIndex arr ix)
{-# INLINE ifoldlM #-}
-- | Index-aware monadic left fold that runs 'ifoldlM' only for its effects and
-- throws away the resulting accumulator.
ifoldlM_ :: (Source r ix e, Monad m) => (a -> ix -> e -> m a) -> a -> Array r ix e -> m ()
ifoldlM_ f acc = \arr -> void (ifoldlM f acc arr)
{-# INLINE ifoldlM_ #-}
-- | Monadic right fold over the elements of an array. It is 'ifoldrM' with a
-- step that ignores the index it is given.
foldrM :: (Source r ix e, Monad m) => (e -> a -> m a) -> a -> Array r ix e -> m a
foldrM f = ifoldrM step
  where
    step _ e acc = f e acc
{-# INLINE foldrM #-}
-- | Monadic right fold over the elements of an array that discards the final
-- accumulator. It is 'ifoldrM_' with a step that ignores the index.
foldrM_ :: (Source r ix e, Monad m) => (e -> a -> m a) -> a -> Array r ix e -> m ()
foldrM_ f = ifoldrM_ withoutIndex
  where
    withoutIndex _ e acc = f e acc
{-# INLINE foldrM_ #-}
-- | Monadic right fold that also supplies the index of every element to the
-- folding action. Iteration is delegated to 'iterM', which is asked to start at
-- the size of the array decremented by one in every dimension and to move
-- towards 'zeroIndex' with a step of @-1@, using @(>=)@ as the continuation
-- test. Both the initial accumulator and the array are forced before iteration
-- starts.
ifoldrM :: (Source r ix e, Monad m) => (ix -> e -> a -> m a) -> a -> Array r ix e -> m a
ifoldrM f !acc !arr = iterM lastIx zeroIndex (pureIndex (-1)) (>=) acc step
  where
    -- Size of the array with one subtracted from every dimension.
    lastIx = liftIndex (subtract 1) (size arr)
    -- Look up the element at the visited index and hand it to the folding action.
    step !ix !acc0 = f ix (unsafeIndex arr ix) acc0
{-# INLINE ifoldrM #-}
-- | Index-aware monadic right fold that runs 'ifoldrM' only for its effects and
-- throws away the resulting accumulator. The initial accumulator and the array
-- are forced by this function as well.
ifoldrM_ :: (Source r ix e, Monad m) => (ix -> e -> a -> m a) -> a -> Array r ix e -> m ()
ifoldrM_ f !acc !arr = () <$ ifoldrM f acc arr
{-# INLINE ifoldrM_ #-}
-- | Sequential left fold that walks the array by linear index, from @0@ up to
-- (but excluding) the total number of elements. The accumulator is never
-- forced by this function, so applications of the folding function pile up
-- lazily until the result is demanded.
lazyFoldlS :: Source r ix e => (a -> e -> a) -> a -> Array r ix e -> a
lazyFoldlS f initAcc arr = loop 0 initAcc
  where
    total = totalElem (size arr)
    loop pos acc =
      if pos < total
        then loop (pos + 1) (f acc (unsafeLinearIndex arr pos))
        else acc
{-# INLINE lazyFoldlS #-}
-- | Sequential right fold that is lazy in the accumulator. This is just another
-- name for 'foldrFB'.
lazyFoldrS :: Source r ix e => (e -> a -> a) -> a -> Array r ix e -> a
lazyFoldrS = foldrFB
{-# INLINE lazyFoldrS #-}
-- | Sequential left fold over the elements of an array. It is 'ifoldlS' with a
-- step that ignores the index it is given.
foldlS :: Source r ix e => (a -> e -> a) -> a -> Array r ix e -> a
foldlS f = ifoldlS step
  where
    step acc _ e = f acc e
{-# INLINE foldlS #-}
-- | Sequential left fold with access to the index of every element. It runs
-- 'ifoldlM' in the identity monad and unwraps the result.
ifoldlS :: Source r ix e
        => (a -> ix -> e -> a) -> a -> Array r ix e -> a
ifoldlS f acc = \arr -> runIdentity (ifoldlM step acc arr)
  where
    -- Lift the pure folding function into the monad expected by 'ifoldlM'.
    step a ix e = return (f a ix e)
{-# INLINE ifoldlS #-}
-- | Sequential right fold over the elements of an array. It is 'ifoldrS' with a
-- step that ignores the index it is given.
foldrS :: Source r ix e => (e -> a -> a) -> a -> Array r ix e -> a
foldrS f = ifoldrS step
  where
    step _ e acc = f e acc
{-# INLINE foldrS #-}
-- | Right fold that walks the array by linear index, starting at @0@ and
-- stopping once the index equals the total number of elements. Every element
-- is forced before it is passed to the combining function, while the fold of
-- the remaining elements is passed along unevaluated, so the combining
-- function decides whether the rest of the array is visited at all.
--
-- Inlining is postponed until simplifier phase 0 by the pragma below.
foldrFB :: Source r ix e => (e -> b -> b) -> b -> Array r ix e -> b
foldrFB c n arr = loop 0
  where
    !total = totalElem (size arr)
    loop !pos =
      if pos == total
        then n
        else case unsafeLinearIndex arr pos of
               !el -> c el (loop (pos + 1))
{-# INLINE [0] foldrFB #-}
-- | Sequential right fold with access to the index of every element. It runs
-- 'ifoldrM' in the identity monad and unwraps the result.
ifoldrS :: Source r ix e => (ix -> e -> a -> a) -> a -> Array r ix e -> a
ifoldrS f acc = \arr -> runIdentity (ifoldrM step acc arr)
  where
    -- Lift the pure folding function into the monad expected by 'ifoldrM'.
    step ix e a = return (f ix e a)
{-# INLINE ifoldrS #-}
-- | Left fold performed in chunks, ignoring element indices. It is 'ifoldlP'
-- with a folding function that drops the index it is given.
foldlP ::
     Source r ix e
  => (a -> e -> a) -- ^ Folding function applied to the elements of a chunk.
  -> a -- ^ Starting accumulator for the fold of every chunk.
  -> (b -> a -> b) -- ^ Function that combines the results of the chunks.
  -> b -- ^ Starting accumulator for combining the chunk results.
  -> Array r ix e
  -> IO b
foldlP f = ifoldlP dropIndex
  where
    dropIndex acc _ = f acc
{-# INLINE foldlP #-}
-- | Same as 'ifoldlOnP', but with a folding function that does not receive the
-- index. The list of worker ids is passed along untouched.
foldlOnP ::
     Source r ix e
  => [Int]
  -> (a -> e -> a)
  -> a
  -> (b -> a -> b)
  -> b
  -> Array r ix e
  -> IO b
foldlOnP wIds f = ifoldlOnP wIds dropIndex
  where
    dropIndex acc _ = f acc
{-# INLINE foldlOnP #-}
-- | Index-aware left fold in which both the folding and the combining
-- functions live in 'IO' and the work is split into chunks.
--
-- 'divideWork' supplies a scheduler, a chunk length, the total length and the
-- offset at which the leftover (slack) part begins. One job is scheduled for
-- every chunk-sized range of linear indices that starts below the slack offset,
-- followed by one more job for the slack range when it is not empty. Every job
-- folds its range starting from the same initial accumulator. The results
-- returned by 'divideWork' are then folded from the left with the combining
-- function.
--
-- NOTE(review): the order of the collected results, and with it the order in
-- which they are combined, is decided by 'divideWork' - presumably it matches
-- the order in which the jobs were scheduled; confirm in the scheduler module.
ifoldlIO ::
     Source r ix e
  => [Int] -- ^ Worker ids, forwarded to 'divideWork'.
  -> (a -> ix -> e -> IO a) -- ^ Folding action applied within a chunk.
  -> a -- ^ Starting accumulator for the fold of every chunk.
  -> (b -> a -> IO b) -- ^ Action that combines the results of the chunks.
  -> b -- ^ Starting accumulator for combining the chunk results.
  -> Array r ix e
  -> IO b
ifoldlIO wIds f !initAcc g !tAcc !arr = do
  let !sz = size arr
  -- Left fold over the linear index range that starts at @from@ and stays below @to@.
  let foldRange from to =
        iterLinearM sz from to 1 (<) initAcc $ \ !i ix !acc ->
          f acc ix (unsafeLinearIndex arr i)
  results <-
    divideWork wIds sz $ \ !scheduler !chunkLength !totalLength !slackStart -> do
      loopM_ 0 (< slackStart) (+ chunkLength) $ \ !chunkStart ->
        scheduleWork scheduler (foldRange chunkStart (chunkStart + chunkLength))
      when (slackStart < totalLength) $
        scheduleWork scheduler (foldRange slackStart totalLength)
  F.foldlM g tAcc results
{-# INLINE ifoldlIO #-}
-- | Index-aware chunked left fold with pure folding and combining functions,
-- run on the given workers. Both functions are lifted into 'IO' and handed to
-- 'ifoldlIO'.
ifoldlOnP ::
     Source r ix e
  => [Int]
  -> (a -> ix -> e -> a)
  -> a
  -> (b -> a -> b)
  -> b
  -> Array r ix e
  -> IO b
ifoldlOnP wIds f initAcc g = ifoldlIO wIds foldStep initAcc combineStep
  where
    foldStep acc ix e = return (f acc ix e)
    combineStep acc chunkResult = return (g acc chunkResult)
{-# INLINE ifoldlOnP #-}
-- | Index-aware chunked left fold: 'ifoldlOnP' applied to an empty list of
-- worker ids.
--
-- NOTE(review): presumably an empty list tells the scheduler to use all
-- available workers - confirm against 'divideWork'.
ifoldlP :: Source r ix e =>
           (a -> ix -> e -> a) -> a -> (b -> a -> b) -> b -> Array r ix e -> IO b
ifoldlP = ifoldlOnP []
{-# INLINE ifoldlP #-}
-- | Right fold performed in chunks, ignoring element indices. It is 'ifoldrP'
-- with a folding function that drops the index it is given.
foldrP ::
     Source r ix e
  => (e -> a -> a)
  -> a
  -> (a -> b -> b)
  -> b
  -> Array r ix e
  -> IO b
foldrP f = ifoldrP ignoreIndex
  where
    ignoreIndex _ = f
{-# INLINE foldrP #-}
-- | Same as 'ifoldrOnP', but with a folding function that does not receive the
-- index. The list of worker ids is passed along untouched.
foldrOnP ::
     Source r ix e
  => [Int]
  -> (e -> a -> a)
  -> a
  -> (a -> b -> b)
  -> b
  -> Array r ix e
  -> IO b
foldrOnP wIds f = ifoldrOnP wIds ignoreIndex
  where
    ignoreIndex _ = f
{-# INLINE foldrOnP #-}
-- | Index-aware right fold in which both the folding and the combining
-- functions live in 'IO' and the work is split into chunks.
--
-- 'divideWork' supplies a scheduler, a chunk length, the total length and the
-- offset at which the leftover (slack) part begins. The slack range, when it
-- is not empty, is scheduled first; after that one job is scheduled per
-- chunk-sized range, stepping the upper bound of the range down from the slack
-- offset while it stays above zero. Every job folds its range from the highest
-- linear index to the lowest, starting from the same initial accumulator. The
-- results returned by 'divideWork' are then folded from the left with the
-- combining function, which receives the chunk result as its first argument.
--
-- NOTE(review): the order of the collected results, and with it the order in
-- which they are combined, is decided by 'divideWork' - presumably it matches
-- the order in which the jobs were scheduled; confirm in the scheduler module.
ifoldrIO ::
     Source r ix e
  => [Int] -- ^ Worker ids, forwarded to 'divideWork'.
  -> (ix -> e -> a -> IO a) -- ^ Folding action applied within a chunk.
  -> a -- ^ Starting accumulator for the fold of every chunk.
  -> (a -> b -> IO b) -- ^ Action that combines the results of the chunks.
  -> b -- ^ Starting accumulator for combining the chunk results.
  -> Array r ix e
  -> IO b
ifoldrIO wIds f !initAcc g !tAcc !arr = do
  let !sz = size arr
  -- Right fold over linear indices, going from @from@ down to and including @to@.
  let foldDown from to =
        iterLinearM sz from to (-1) (>=) initAcc $ \ !i ix !acc ->
          f ix (unsafeLinearIndex arr i) acc
  results <-
    divideWork wIds sz $ \ !scheduler !chunkLength !totalLength !slackStart -> do
      when (slackStart < totalLength) $
        scheduleWork scheduler (foldDown (totalLength - 1) slackStart)
      loopM_ slackStart (> 0) (subtract chunkLength) $ \ !chunkEnd ->
        scheduleWork scheduler (foldDown (chunkEnd - 1) (chunkEnd - chunkLength))
  F.foldlM (\acc chunkResult -> g chunkResult acc) tAcc results
{-# INLINE ifoldrIO #-}
-- | Index-aware chunked right fold with pure folding and combining functions,
-- run on the given workers. Both functions are lifted into 'IO' and handed to
-- 'ifoldrIO'. The initial accumulator is forced by this function.
ifoldrOnP ::
     Source r ix e
  => [Int]
  -> (ix -> e -> a -> a)
  -> a
  -> (a -> b -> b)
  -> b
  -> Array r ix e
  -> IO b
ifoldrOnP wIds f !initAcc g = ifoldrIO wIds foldStep initAcc combineStep
  where
    foldStep ix e acc = return (f ix e acc)
    combineStep chunkResult acc = return (g chunkResult acc)
{-# INLINE ifoldrOnP #-}
-- | Index-aware chunked right fold: 'ifoldrOnP' applied to an empty list of
-- worker ids.
--
-- NOTE(review): presumably an empty list tells the scheduler to use all
-- available workers - confirm against 'divideWork'.
ifoldrP :: Source r ix e =>
           (ix -> e -> a -> a) -> a -> (a -> b -> b) -> b -> Array r ix e -> IO b
ifoldrP = ifoldrOnP []
{-# INLINE ifoldrP #-}
-- | Left fold that picks its implementation from the computation strategy
-- stored in the array (as reported by 'getComp'):
--
-- * 'Seq' - the elements are folded with 'foldlS' and the outcome is combined
--   once with the result accumulator;
--
-- * 'ParOn' - the fold is carried out by 'foldlOnP' on the listed workers and
--   its 'IO' result is extracted with 'unsafePerformIO'.
--
-- NOTE(review): the folding and combining functions are pure by their types,
-- so the only effects hidden behind 'unsafePerformIO' are those of the
-- scheduler itself; presumably that is what makes its use acceptable here -
-- confirm.
foldlInternal ::
     Source r ix e
  => (a -> e -> a) -- ^ Folding function applied to the elements.
  -> a -- ^ Starting accumulator for the element fold.
  -> (b -> a -> b) -- ^ Function that combines folded results.
  -> b -- ^ Starting accumulator for the combining step.
  -> Array r ix e
  -> b
foldlInternal g initAcc f resAcc = \arr ->
  let onWorkers wIds = foldlOnP wIds g initAcc f resAcc arr
   in case getComp arr of
        ParOn wIds -> unsafePerformIO (onWorkers wIds)
        Seq -> f resAcc (foldlS g initAcc arr)
{-# INLINE foldlInternal #-}