-- Segmented folds over chunked flows.
module Data.Repa.Flow.Chunked.Folds
        ( folds_i
        , FoldsDict)
where
import Data.Repa.Flow.Chunked.Base
import Data.Repa.Flow.States
import Data.Repa.Scalar.Option
import Data.Repa.Array.Generic.Index            as A
import Data.Repa.Array.Generic                  as A hiding (FoldsDict)
import Data.Repa.Array.Meta.Window              as A
import Data.Repa.Array.Meta.Tuple               as A
import qualified Data.Repa.Flow.Generic         as G
#include "repa-flow.h"

-- | Constraint bundle required to run a segmented fold with 'folds_i'.
--
--   The type parameters are: stream index @i@, monad @m@, the layouts of the
--   segment-length chunks (@lSeg@), element chunks (@lElt@), group names
--   (@lGrp@) and results (@lRes@), then the group name type @n@, the input
--   element type @a@ and the fold result type @b@.
type FoldsDict i m lSeg lElt lGrp lRes n a b
        = ( States i m
          -- Both input chunk layouts are windowed when a chunk has only
          -- been partially consumed.
          , Windowable lSeg (n, Int)
          , Windowable lElt a
          -- Layouts that are read from.
          , BulkI   lSeg (n, Int)
          , BulkI   lElt a
          , BulkI   lGrp n
          , BulkI   lRes b
          -- Layouts that are written to.
          , TargetI lElt a
          , TargetI lGrp n
          , TargetI lRes b)

-- Folds ----------------------------------------------------------------------
-- | Segmented fold over vectors of segment lengths and input values.
--
--   Each stream of the result pairs a segment name with the value produced
--   by folding the worker function over the elements of that segment.
--   Segments and elements arrive in chunks that need not line up with each
--   other, so the partial fold state, and any unconsumed part of the current
--   chunks, is held in per-stream references between pulls.
--
--   The result bundle has as many streams as the smaller of the two input
--   bundles.
folds_i :: FoldsDict i m lSeg lElt lGrp lRes n a b
        => Name lGrp                 -- ^ Layout for group names.
        -> Name lRes                 -- ^ Layout for fold results.
        -> (a -> b -> b)             -- ^ Worker function.
        -> b                         -- ^ Initial state when folding each segment.
        -> Sources i m lSeg (n, Int) -- ^ Segment lengths.
        -> Sources i m lElt a        -- ^ Input elements to fold.
        -> m (Sources i m (T2 lGrp lRes) (n, b)) -- ^ Result elements.

folds_i _ _ f z sLens@(G.Sources nLens _)
            sVals@(G.Sources nVals _)
 = do
        -- Arity of the result bundle is the minimum of the two inputs.
        let nFolds = min nLens nVals

        -- Refs to hold partial fold states between chunks.
        refsState    <- newRefs nFolds None3

        -- Refs to hold the current chunk of lengths data for each stream.
        refsNameLens <- newRefs nFolds Nothing

        -- Refs to hold the current chunk of vals data for each stream,
        -- along with a flag recording that the values source has ejected.
        refsVals     <- newRefs nFolds Nothing
        refsValsDone <- newRefs nFolds False

        let pull_folds i eat eject
             = do mNameLens <- folds_loadChunkNameLens sLens refsNameLens i
                  mVals     <- folds_loadChunkVals     sVals refsVals refsValsDone i 

                  case (mNameLens, mVals) of
                   -- If we couldn't get a chunk for both sides then we can't
                   -- produce anymore results, and the merge is done.
                   (Nothing, _)       -> eject
                   (_,       Nothing) -> eject

                   -- We've got a chunk for both sides, time to do some work.
                   -- Both chunks are forced before the fold is run.
                   (Just cNameLens, Just cVals) 
                    -> cNameLens `seq` cVals `seq`
                       do 
                          -- Resume from the partial state of the segment that
                          -- was still open at the end of the previous chunk.
                          mState    <- readRefs refsState i

                          let (cResults, sFolds) 
                                = A.foldsWith name name f z 
                                        (fromOption3 mState) cNameLens cVals

                          -- Save the new partial state, and keep whatever
                          -- parts of the two chunks were not consumed.
                          folds_update 
                                refsState refsNameLens refsVals i 
                                cNameLens cVals sFolds

                          valsDone <- readRefs refsValsDone i

                          -- If we're not producing output while we still
                          -- have segment lengths then we're done.
                          --
                          -- NOTE(review): 'A.length cNameLens >= 0' is always
                          -- true; going by the comment above, '> 0' may have
                          -- been intended. Left unchanged to preserve the
                          -- existing behaviour -- confirm before altering.
                          if  A.length cResults   == 0
                           && A.length cNameLens  >= 0
                           && valsDone
                                then eject
                                else eat cResults
            {-# INLINE pull_folds #-} 

        return $ G.Sources nFolds pull_folds
{-# INLINE_FLOW folds_i #-}

-- Load the current chunk of lengths data for stream 'i'.
--
-- If we already have one in the state then use that, 
-- otherwise try to pull a new chunk from the source and stash it
-- in the state before returning it.
--
-- The result is 'Nothing' when there was no stashed chunk and the pull
-- did not deliver a new one, as the eject action leaves the reference
-- untouched.
folds_loadChunkNameLens 
    :: States  i m
    => Sources i m l1 (n, Int)
    -> Refs i m (Maybe (Array l1 (n, Int)))
    -> i 
    -> m (Maybe (Array l1 (n, Int)))

folds_loadChunkNameLens (G.Sources _ pullLens) refsLens i
 = do mChunkLens <- readRefs refsLens i
      case mChunkLens of 
       -- No chunk in the state, so ask the source for the next one.
       Nothing
        -> let eatLens_folds chunk
                = writeRefs refsLens i (Just chunk)
               {-# INLINE eatLens_folds #-}

               ejectLens_folds = return ()
               {-# INLINE ejectLens_folds #-}

           in do
               pullLens i eatLens_folds ejectLens_folds
               -- Read back whatever the callbacks left in the reference.
               readRefs refsLens i

       -- Still have an unconsumed chunk, so reuse it.
       jc@(Just _)
        ->    return jc
{-# NOINLINE folds_loadChunkNameLens #-}
--  NOINLINE as this doesn't need to be specialized,
--  and we want to hide the case from the simplifier.

-- Grab the current chunk of values data for stream 'i'.
--
-- If we already have one in the state then use that,
-- otherwise try to pull a new chunk from the source and stash it
-- in the state before returning it.
--
-- When the source ejects, an empty chunk is stashed instead and the
-- 'done' flag for the stream is set to 'True'.
folds_loadChunkVals 
        :: (States i m, TargetI l2 a)
        => Sources i m l2 a
        -> Refs i m (Maybe (Array l2 a))
        -> Refs i m Bool
        -> i 
        -> m (Maybe (Array l2 a))

folds_loadChunkVals (G.Sources _ pullVals) refsVals refsValsDone i
 = do mChunkVals <- readRefs refsVals i
      case mChunkVals of
       -- No chunk in the state, so ask the source for the next one.
       Nothing
        -> let  eatVals_folds chunk
                 = writeRefs refsVals i (Just chunk)
                {-# INLINE eatVals_folds #-}

                -- When there are no more values then shim in an 
                -- empty chunk so that we can keep calling A.folds.
                -- This is needed when there are zero lengthed
                -- segments on the end of the stream.
                ejectVals_folds 
                 = do writeRefs refsVals     i (Just $ A.fromList name [])
                      writeRefs refsValsDone i True
                {-# INLINE ejectVals_folds #-}

           in do
                pullVals i eatVals_folds ejectVals_folds
                -- Read back whatever the callbacks left in the reference.
                readRefs refsVals i

       -- Still have an unconsumed chunk, so reuse it.
       jc@(Just _)    
        ->     return jc
{-# NOINLINE folds_loadChunkVals #-}
--  NOINLINE as this doesn't need to be specialized, 
--  and we want to hide the case from the simplifier.

        :: ( States i m
           , Windowable l1 (n, Int), Windowable l2 a
           , A.Index l1 ~ Int,       A.Index l2 ~ Int)
        => Refs i m (Option3 n Int b)
        -> Refs i m (Maybe (Array l1 (n, Int)))
        -> Refs i m (Maybe (Array l2 a))
        -> i
        -> Array l1 (n, Int)
        -> Array l2 a
        -> Folds Int Int n a b
        -> m ()

folds_update refsState refsLens refsVals i cLens cVals sFolds 
 = do 
        -- Remember state for the final segment.
        writeRefs refsState i 
         $ case _nameSeg sFolds of
            Some n      -> Some3 n (_lenSeg sFolds) (_valSeg sFolds)
            None        -> None3

        -- Slice down the lengths chunk to just the elements
        -- that we haven't already consumed. If we've consumed
        -- them all then clear the chunk reference so a new one
        -- will be loaded the next time around.
        let !posLens     = _stateLens sFolds
        let !nLensRemain = A.length cLens - posLens
        writeRefs refsLens i 
         $ if nLensRemain <= 0 
              then Nothing
              else Just $ A.window posLens nLensRemain cLens

        -- Likewise for the values chunk.
        let !posVals     = _stateVals sFolds
        let !nValsRemain = A.length cVals - posVals
        writeRefs refsVals i
         $ if nValsRemain <= 0
              then Nothing
              else Just $ A.window posVals nValsRemain cVals
{-# NOINLINE folds_update #-}
--  NOINLINE because it is only called once per chunk
--  and does not need to be specialised.