{-# LANGUAGE PolyKinds #-}
{-# OPTIONS_GHC -fno-warn-incomplete-patterns #-}
module Numeric.Dimensions.Fold
  ( overDim, overDim_, overDimIdx, overDimIdx_
  , overDimOff, overDimOff_
  , overDimReverse, overDimReverseIdx
  , foldDim, foldDimIdx, foldDimOff
  , foldDimReverse, foldDimReverseIdx
  , overDimPart, overDimPartIdx
  ) where
import           Control.Monad           ((>=>))
import           Numeric.Dimensions.Idxs
-- | Monadic traversal over every index of a space, threading an accumulator
--   and tracking the flat offset of each index.
--
--   The first dimension varies fastest: for a fixed tail index the head
--   index runs from @1@ up to its extent and the offset grows by @step@ on
--   each move. Every following dimension uses the previous step multiplied
--   by the previous extent, so index @(i1, i2, ..)@ gets the offset
--   @offset + (i1 - 1) * step1 + (i2 - 1) * step2 + ..@.
overDim :: Monad m
        => Dims ds                      -- ^ Shape of the space
        -> (Idxs ds -> Int -> a -> m a) -- ^ Action run at each index and offset
        -> Int                          -- ^ Offset of the first visited index
        -> Int                          -- ^ Offset step of the first dimension
        -> a                            -- ^ Initial accumulator
        -> m a
overDim U k offset _step = k U offset
overDim (d :* ds) k offset step = overDim ds visitHead offset (extent * step)
  where
    extent = fromIntegral (dimVal d)
    -- For a fixed tail index @tl@ (whose offset is @base@), chain the actions
    -- for head indices 1, 2, .. left to right via Kleisli composition.
    visitHead tl base = foldr chain return [1 .. dimVal d]
      where
        chain i rest =
          k (Idx i :* tl) (base + (fromIntegral i - 1) * step) >=> rest
{-# INLINE overDim #-}
-- | Like 'overDim', but visits the indices in reverse order.
--
--   Every index runs from its extent down to @1@, the first dimension
--   varying fastest, and the offset shrinks by the corresponding step on
--   each move. The given offset belongs to the first visited index, i.e. the
--   one with every component at its maximum.
overDimReverse :: Monad m
               => Dims ds                      -- ^ Shape of the space
               -> (Idxs ds -> Int -> a -> m a) -- ^ Action run at each index and offset
               -> Int                          -- ^ Offset of the first visited (last) index
               -> Int                          -- ^ Offset step of the first dimension
               -> a                            -- ^ Initial accumulator
               -> m a
overDimReverse U k offset _step = k U offset
overDimReverse (d :* ds) k offset step =
    overDimReverse ds visitHead offset (step * fromIntegral n)
  where
    n = dimVal d
    -- For a fixed tail index @tl@, walk the head index downwards from @n@;
    -- once the counter reaches zero every head index has been visited.
    visitHead tl = descend n
      where
        descend j off
          | j <= 0    = return
          | otherwise = k (Idx j :* tl) off >=> descend (j - 1) (off - step)
{-# INLINE overDimReverse #-}
-- | Like 'overDim', but without an accumulator: run the action for its
--   effects at every index (first dimension fastest, each index from @1@
--   up to its extent), tracking offsets the same way.
overDim_ :: Monad m
         => Dims ds                  -- ^ Shape of the space
         -> (Idxs ds -> Int -> m ()) -- ^ Action run at each index and offset
         -> Int                      -- ^ Offset of the first visited index
         -> Int                      -- ^ Offset step of the first dimension
         -> m ()
overDim_ U k offset _step = k U offset
overDim_ (d :* ds) k offset step = overDim_ ds visitHead offset (extent * step)
  where
    extent = fromIntegral (dimVal d)
    -- Sequence the actions for head indices 1 .. extent, left to right.
    visitHead tl base =
      mapM_ (\i -> k (Idx i :* tl) (base + (fromIntegral i - 1) * step))
            [1 .. dimVal d]
{-# INLINE overDim_ #-}
-- | Monadic traversal over every index of a space, threading an
--   accumulator. The first dimension varies fastest and each index runs
--   from @1@ up to its extent. No offsets are tracked.
overDimIdx :: Monad m
           => Dims ds               -- ^ Shape of the space
           -> (Idxs ds -> a -> m a) -- ^ Action run at each index
           -> a                     -- ^ Initial accumulator
           -> m a
overDimIdx U k = k U
overDimIdx (d :* ds) k = overDimIdx ds visitHead
  where
    -- Kleisli-chain the actions for head indices 1 .. extent.
    visitHead tl =
      foldr (\i rest -> k (Idx i :* tl) >=> rest) return [1 .. dimVal d]
{-# INLINE overDimIdx #-}
-- | Run an effectful action at every index of a space, discarding results.
--   The first dimension varies fastest and each index runs from @1@ up to
--   its extent.
overDimIdx_ :: Monad m
            => Dims ds           -- ^ Shape of the space
            -> (Idxs ds -> m ()) -- ^ Action run at each index
            -> m ()
overDimIdx_ U k = k U
overDimIdx_ (d :* ds) k = overDimIdx_ ds visitHead
  where
    -- Sequence the actions for head indices 1 .. extent, left to right.
    visitHead tl = mapM_ (\i -> k (Idx i :* tl)) [1 .. dimVal d]
{-# INLINE overDimIdx_ #-}
-- | Monadic traversal over the flat offsets of a space, without building
--   indices. The action runs @totalDim ds@ times, at offsets @offset@,
--   @offset + step@, @offset + 2 * step@, .. in that order.
overDimOff :: Monad m
           => Dims ds           -- ^ Shape of the space
           -> (Int -> a -> m a) -- ^ Action run at each offset
           -> Int               -- ^ First offset
           -> Int               -- ^ Distance between consecutive offsets
           -> a                 -- ^ Initial accumulator
           -> m a
overDimOff ds k offset step = foldr visit return [1 .. totalDim ds]
  where
    -- The j-th visited element (counting from 1) sits @j - 1@ steps away
    -- from the first offset.
    visit j rest = k (offset + (fromIntegral j - 1) * step) >=> rest
{-# INLINE overDimOff #-}
-- | Run an effectful action at every flat offset of a space, discarding
--   results. The action runs @totalDim ds@ times, at offsets @offset@,
--   @offset + step@, @offset + 2 * step@, .. in that order.
overDimOff_ :: Monad m
            => Dims ds      -- ^ Shape of the space
            -> (Int -> m ()) -- ^ Action run at each offset
            -> Int          -- ^ First offset
            -> Int          -- ^ Distance between consecutive offsets
            -> m ()
overDimOff_ ds k offset step =
    mapM_ (\j -> k (offset + (fromIntegral j - 1) * step)) [1 .. totalDim ds]
{-# INLINE overDimOff_ #-}
-- | Like 'overDimIdx', but visits the indices in reverse order: every index
--   runs from its extent down to @1@, the first dimension varying fastest.
overDimReverseIdx :: Monad m
                  => Dims ds               -- ^ Shape of the space
                  -> (Idxs ds -> a -> m a) -- ^ Action run at each index
                  -> a                     -- ^ Initial accumulator
                  -> m a
overDimReverseIdx U k = k U
overDimReverseIdx (d :* ds) k = overDimReverseIdx ds visitHead
  where
    -- Walk the head index downwards from its extent; a zero counter ends
    -- the walk.
    visitHead tl = descend (dimVal d)
      where
        descend j
          | j <= 0    = return
          | otherwise = k (Idx j :* tl) >=> descend (j - 1)
{-# INLINE overDimReverseIdx #-}
-- | Pure left fold over every index of a space, tracking the offset.
--
--   Visiting order and offsets match 'overDim': the function is first
--   applied at index @(1, 1, ..)@ with offset @offset@, the first dimension
--   varies fastest, and each following dimension's step is the previous
--   step times the previous extent.
foldDim :: Dims ds                    -- ^ Shape of the space
        -> (Idxs ds -> Int -> a -> a) -- ^ Update applied at each index and offset
        -> Int                        -- ^ Offset of the first visited index
        -> Int                        -- ^ Offset step of the first dimension
        -> a                          -- ^ Initial accumulator
        -> a
foldDim U k offset _step = k U offset
foldDim (d :* ds) k offset step = foldDim ds visitHead offset (extent * step)
  where
    extent = fromIntegral (dimVal d)
    -- Compose the per-index updates so that head index 1 is applied first.
    visitHead tl base = foldr after id [1 .. dimVal d]
      where
        after i rest =
          rest . k (Idx i :* tl) (base + (fromIntegral i - 1) * step)
{-# INLINE foldDim #-}
-- | Like 'foldDim', but visits the indices in reverse order.
--
--   Every index runs from its extent down to @1@ and the offset shrinks by
--   the corresponding step on each move. The given offset belongs to the
--   first visited index, i.e. the one with every component at its maximum.
foldDimReverse :: Dims ds                    -- ^ Shape of the space
               -> (Idxs ds -> Int -> a -> a) -- ^ Update applied at each index and offset
               -> Int                        -- ^ Offset of the first visited (last) index
               -> Int                        -- ^ Offset step of the first dimension
               -> a                          -- ^ Initial accumulator
               -> a
foldDimReverse U k offset _step = k U offset
foldDimReverse (d :* ds) k offset step =
    foldDimReverse ds visitHead offset (step * fromIntegral n)
  where
    n = dimVal d
    -- Thread the accumulator explicitly while the head index walks down
    -- from @n@; a zero counter ends the walk.
    visitHead tl = descend n
      where
        descend j off acc
          | j <= 0    = acc
          | otherwise = descend (j - 1) (off - step) (k (Idx j :* tl) off acc)
{-# INLINE foldDimReverse #-}
-- | Pure left fold over every index of a space. The first dimension varies
--   fastest and each index runs from @1@ up to its extent. No offsets are
--   tracked.
foldDimIdx :: Dims ds              -- ^ Shape of the space
           -> (Idxs ds -> a -> a)  -- ^ Update applied at each index
           -> a                    -- ^ Initial accumulator
           -> a
foldDimIdx U k = k U
foldDimIdx (d :* ds) k = foldDimIdx ds visitHead
  where
    -- Pass the accumulator along explicitly while the head index sweeps
    -- upwards from 1 to the extent.
    visitHead tl = sweep 1
      where
        n = dimVal d
        sweep i acc
          | i > n     = acc
          | otherwise = sweep (i + 1) (k (Idx i :* tl) acc)
{-# INLINE foldDimIdx #-}
-- | Pure left fold over the flat offsets of a space, without building
--   indices. The function is applied @totalDim ds@ times, at offsets
--   @offset@, @offset + step@, @offset + 2 * step@, .. in that order.
foldDimOff :: Dims ds          -- ^ Shape of the space
           -> (Int -> a -> a)  -- ^ Update applied at each offset
           -> Int              -- ^ First offset
           -> Int              -- ^ Distance between consecutive offsets
           -> a                -- ^ Initial accumulator
           -> a
foldDimOff ds k offset step = foldr after id [1 .. totalDim ds]
  where
    -- Compose so that the j = 1 update (at @offset@) is applied first.
    after j rest = rest . k (offset + (fromIntegral j - 1) * step)
{-# INLINE foldDimOff #-}
-- | Like 'foldDimIdx', but visits the indices in reverse order: every index
--   runs from its extent down to @1@, the first dimension varying fastest.
foldDimReverseIdx :: Dims ds             -- ^ Shape of the space
                  -> (Idxs ds -> a -> a) -- ^ Update applied at each index
                  -> a                   -- ^ Initial accumulator
                  -> a
foldDimReverseIdx U k = k U
foldDimReverseIdx (d :* ds) k = foldDimReverseIdx ds visitHead
  where
    -- Pass the accumulator along explicitly while the head index walks
    -- down from its extent; a zero counter ends the walk.
    visitHead tl = descend (dimVal d)
      where
        descend j acc
          | j <= 0    = acc
          | otherwise = descend (j - 1) (k (Idx j :* tl) acc)
{-# INLINE foldDimReverseIdx #-}
-- | Traverse a rectangular part of a space, from index @imin@ to index
--   @imax@, keeping track of offsets.
--
--   Every dimension is walked over the range between its components of
--   @imin@ and @imax@ (downwards when the @imax@ component is smaller), the
--   first dimension varying fastest. The offset handed to the action for
--   index @(i1, i2, ..)@ is @offset + (i1 - 1) * step1 + (i2 - 1) * step2 + ..@,
--   where @step1 = step@ and each further step is the previous one times the
--   previous extent of the space.
overDimPart :: (Dimensions ds, Monad m)
            => Idxs ds                      -- ^ First visited index
            -> Idxs ds                      -- ^ Last visited index
            -> (Idxs ds -> Int -> a -> m a) -- ^ Action run at each index and offset
            -> Int                          -- ^ Offset corresponding to index 1 in every dimension
            -> Int                          -- ^ Offset step of the first dimension
            -> a                            -- ^ Initial accumulator
            -> m a
overDimPart imin imax f offset step =
    overDimPart' (stepsOf (dims `inSpaceOf` imin) step) imin imax f offset
  where
    -- Per-dimension offset steps: a running product of the extents, seeded
    -- with @step@ for the first dimension.
    stepsOf :: Dims ns -> Int -> TypedList StepSize ns
    stepsOf U _ = U
    stepsOf (d :* rest) acc =
      StepSize acc :* stepsOf rest (acc * fromIntegral (dimVal d))
-- | Worker for 'overDimPart': walk the box between two indices, given the
--   offset step of every dimension.
--
--   Each dimension is traversed over the inclusive range from its start
--   component to its end component, ascending when @end >= start@ and
--   descending otherwise; the first dimension varies fastest. The offset
--   passed to the action for index @(i1, i2, ..)@ is
--   @off0 + (i1 - 1) * step1 + (i2 - 1) * step2 + ..@.
overDimPart' :: Monad m
             => TypedList StepSize ns        -- ^ Offset step of every dimension
             -> Idxs ds                      -- ^ First visited index
             -> Idxs ds                      -- ^ Last visited index
             -> (Idxs ds -> Int -> a -> m a) -- ^ Action run at each index and offset
             -> Int                          -- ^ Offset corresponding to index 1 in every dimension
             -> a -> m a
overDimPart' U U U k off0 = k U off0
overDimPart' (siW :* iws) (Idx iStart :* starts) (Idx iEnd :* ends) k off0
  = overDimPart' iws starts ends (walk iStart) (off0 + headOff)
  where
    StepSize iW = siW
    -- Offset contribution of the first visited head index (indices start at 1).
    headOff = iW * (fromIntegral iStart - 1)
    ascending = iEnd >= iStart
    -- Visit head index @i@, then move one step towards @iEnd@ unless @i@
    -- already is @iEnd@. Stopping on @iEnd@ (instead of stepping past it and
    -- testing @i > iEnd@ / @i < iEnd@) matters if the index component is
    -- unsigned (presumably 'Word'): stepping below an end of 0 would wrap
    -- around and the walk would never terminate.
    walk i js off
      | i == iEnd = here
      | ascending = here >=> walk (i + 1) js (off + iW)
      | otherwise = here >=> walk (i - 1) js (off - iW)
      where
        here = k (Idx i :* js) off

-- | Offset step of a single dimension; the phantom parameter lets one step
--   per dimension be stored in a 'TypedList' indexed like the space.
newtype StepSize n = StepSize Int
-- | Traverse a rectangular part of a space, from index @start@ to index
--   @end@, threading an accumulator. No offsets are tracked.
--
--   Each dimension is walked over the inclusive range between its start and
--   end components (downwards when the end component is smaller), the first
--   dimension varying fastest.
overDimPartIdx :: Monad m
               => Idxs ds               -- ^ First visited index
               -> Idxs ds               -- ^ Last visited index
               -> (Idxs ds -> a -> m a) -- ^ Action run at each index
               -> a                     -- ^ Initial accumulator
               -> m a
overDimPartIdx U U k = k U
overDimPartIdx (start :* starts) (end :* ends) k
  = overDimPartIdx starts ends (walk iStart)
  where
    Idx iStart = start
    Idx iEnd   = end
    ascending  = iEnd >= iStart
    -- Visit head index @i@, then move one step towards @iEnd@ unless @i@
    -- already is @iEnd@. Stopping on @iEnd@ (instead of stepping past it and
    -- testing @i > iEnd@ / @i < iEnd@) matters if the index component is
    -- unsigned (presumably 'Word'): stepping below an end of 0 would wrap
    -- around and the walk would never terminate.
    walk i is
      | i == iEnd = here
      | ascending = here >=> walk (i + 1) is
      | otherwise = here >=> walk (i - 1) is
      where
        here = k (Idx i :* is)