{-# LANGUAGE DataKinds                 #-}
{-# LANGUAGE GADTs                     #-}
{-# LANGUAGE KindSignatures            #-}
{-# LANGUAGE MagicHash                 #-}
{-# LANGUAGE ScopedTypeVariables       #-}
{-# LANGUAGE TypeApplications          #-}
{-# LANGUAGE UnboxedTuples             #-}
{-# LANGUAGE BangPatterns              #-}
{-# LANGUAGE Strict                    #-}
-----------------------------------------------------------------------------
-- |
-- Module      :  Numeric.Dimensions.Traverse
-- Copyright   :  (c) Artem Chirkin
-- License     :  BSD3
--
-- Maintainer  :  chirkin@arch.ethz.ch
--
-- Map a function over all dimensions, providing dimension indices or offsets.
--
-----------------------------------------------------------------------------

module Numeric.Dimensions.Traverse
  ( overDim#, overDim_#, overDimIdx#, overDimIdx_#, overDimOff#, overDimOff_#
  , overDimPart#
  , foldDim, foldDimIdx, foldDimOff
  , foldDimReverse, foldDimReverseIdx
  ) where


import           GHC.Exts

import           Numeric.Dimensions.Dim
import           Numeric.Dimensions.Idx



-- | Traverse over all dimensions keeping track of both the index and a
--   running offset.
--
--   The offset handed to the function starts at the initial offset and grows
--   by the offset step after every call. Indices of the first dimension vary
--   fastest.
overDim# :: Dim (ds :: [Nat])
         -> (Idx ds -> Int# -> a -> State# s -> (# State# s, a #)) -- ^ function to map over each dimension
         -> Int# -- ^ Initial offset
         -> Int# -- ^ offset step
         -> a
         -> State# s
         -> (# State# s, a #)
overDim# ds f off0# step# a0 s0
    = case overDim'# ds withStep off0# a0 s0 of
        (# sEnd, _, aEnd #) -> (# sEnd, aEnd #)
  where
    -- wrap the user action so that it also reports the offset of the next call
    withStep idx cur# acc st = case f idx cur# acc st of
      (# st', acc' #) -> (# st', cur# +# step#, acc' #)
{-# INLINE overDim# #-}

-- | Fold over all dimensions keeping track of both the index and a running
--   offset.
--
--   The offset handed to the function starts at the initial offset and grows
--   by the offset step after every call.
foldDim :: Dim (ds :: [Nat])
        -> (Idx ds -> Int# -> a -> a) -- ^ function to map over each dimension
        -> Int# -- ^ Initial offset
        -> Int# -- ^ offset step
        -> a -> a
foldDim ds f off0# step# a0 =
    let advance idx cur# acc = (# cur# +# step#, f idx cur# acc #)
    in  case foldDim' ds advance off0# a0 of
          (# _, result #) -> result
{-# INLINE foldDim #-}

-- | Fold over all dimensions in reverse order keeping track of both the index
--   and a running offset.
--
--   The first call receives the offset of the last element,
--   @off0 + (dimVal ds - 1) * step@. Every following call receives an offset
--   that is one step smaller.
foldDimReverse :: Dim (ds :: [Nat])
               -> (Idx ds -> Int# -> a -> a) -- ^ function to map over each dimension
               -> Int# -- ^ Initial offset
               -> Int# -- ^ offset step (subtracted from the offset after every call)
               -> a -> a
foldDimReverse ds f off0# step# a0 =
    case foldDimReverse' ds retreat lastOff# a0 of
      (# _, result #) -> result
  where
    -- offset of the last element: off0 + (n - 1) * step
    lastOff# = case dimVal ds of
      I# n# -> off0# +# (n# -# 1#) *# step#
    retreat idx cur# acc = (# cur# -# step#, f idx cur# acc #)
{-# INLINE foldDimReverse #-}



-- | Same as 'overDim#', but the traversed function produces no value; only
--   the state token is threaded through.
overDim_# :: Dim (ds :: [Nat])
          -> (Idx ds -> Int# -> State# s -> State# s) -- ^ function to map over each dimension
          -> Int# -- ^ Initial offset
          -> Int# -- ^ offset step
          -> State# s
          -> State# s
overDim_# ds f off0# step# s0 =
    case overDim_'# ds withStep off0# s0 of
      (# sEnd, _ #) -> sEnd
  where
    -- run the action and report the offset to use for the next call
    withStep idx cur# st = (# f idx cur# st, cur# +# step# #)
{-# INLINE overDim_# #-}

-- | Traverse over all dimensions keeping track of indices.
--
--   Each index of a dimension of size @n@ runs from 1 to @n@. The first
--   dimension varies fastest.
overDimIdx# :: Dim (ds :: [Nat])
            -> (Idx ds -> a -> State# s -> (# State# s, a #))
            -> a
            -> State# s
            -> (# State# s, a #)
overDimIdx# D f = f Z
overDimIdx# ((Dn :: Dim n) :* (!ds)) f = overDimIdx# ds (go 1)
  where
    extent = dimVal' @n
    -- visit indices k, k+1 .. extent of the current (innermost) dimension
    go k rest acc st
      | k <= extent = case f (k :! rest) acc st of
          (# st', acc' #) -> go (k + 1) rest acc' st'
      | otherwise = (# st, acc #)

-- | Fold all dimensions keeping track of indices.
--
--   Each index of a dimension of size @n@ runs from 1 to @n@. The first
--   dimension varies fastest. The accumulator is forced to WHNF after every
--   step.
foldDimIdx :: Dim (ds :: [Nat])
            -> (Idx ds -> a -> a)
            -> a -> a
foldDimIdx D f = f Z
foldDimIdx ((Dn :: Dim n) :* (!ds)) f = foldDimIdx ds (go 1)
  where
    extent = dimVal' @n
    go k rest acc
      | k <= extent = let acc' = f (k :! rest) acc
                      in  acc' `seq` go (k + 1) rest acc'
      | otherwise   = acc

-- | Fold all dimensions in reverse order keeping track of indices.
--
--   Each index of a dimension of size @n@ runs from @n@ down to 1. The first
--   dimension varies fastest. The accumulator is forced to WHNF after every
--   step.
foldDimReverseIdx :: Dim (ds :: [Nat])
                  -> (Idx ds -> a -> a)
                  -> a -> a
foldDimReverseIdx D f = f Z
foldDimReverseIdx ((Dn :: Dim n) :* (!ds)) f = foldDimReverseIdx ds (loop n)
  where
    n = dimVal' @n
    -- The index counts down from n, so the loop must stop once it falls
    -- below 1 (the same bound 'foldDimReverse'' uses).
    loop i js a | i <= 0 = a
                | otherwise = loop (i-1) js $! f (i:!js) a



-- | Traverse over all dimensions keeping track of indices. Only the state
--   token is threaded through; no value is accumulated.
--
--   Each index of a dimension of size @n@ runs from 1 to @n@. The first
--   dimension varies fastest.
overDimIdx_# :: Dim (ds :: [Nat])
             -> (Idx ds -> State# s -> State# s)
             -> State# s
             -> State# s
overDimIdx_# D f = f Z
overDimIdx_# ((Dn :: Dim n) :* (!ds)) f = overDimIdx_# ds (go 1)
  where
    extent = dimVal' @n
    go k rest st
      | k <= extent = go (k + 1) rest (f (k :! rest) st)
      | otherwise   = st

-- | Traverse over all dimensions keeping track of total offset.
--
--   The function is called exactly @dimVal ds@ times, with offsets
--   @off0, off0 + step, off0 + 2 * step, ...@.
--   The number of calls is driven by an explicit counter rather than by
--   comparing offsets. As a result, a zero (or negative) step still visits
--   every element, consistent with 'overDim#'.
overDimOff# :: Dim (ds :: [Nat])
            -> (Int# -> a -> State# s -> (# State# s, a #)) -- ^ function to map over each dimension
            -> Int# -- ^ Initial offset
            -> Int# -- ^ offset step
            -> a -> State# s -> (# State# s, a #)
overDimOff# ds f off0# step# = loop 0# off0#
  where
    -- total number of elements to visit
    n# = case dimVal ds of I# k# -> k#
    loop i# off# a s
      | isTrue# (i# >=# n#) = (# s, a #)
      | otherwise = case f off# a s of
          (# s', b #) -> loop (i# +# 1#) (off# +# step#) b s'

-- | Fold over all dimensions keeping track of total offset.
--
--   The function is called exactly @dimVal ds@ times, with offsets
--   @off0, off0 + step, off0 + 2 * step, ...@. The accumulator is forced to
--   WHNF after every step.
--   The number of calls is driven by an explicit counter rather than by
--   comparing offsets. As a result, a zero step still visits every element,
--   consistent with 'foldDim'.
foldDimOff :: Dim (ds :: [Nat])
           -> (Int# -> a -> a) -- ^ function to map over each dimension
           -> Int# -- ^ Initial offset
           -> Int# -- ^ offset step
           -> a -> a
foldDimOff ds f off0# step# = loop 0# off0#
  where
    -- total number of elements to visit
    n# = case dimVal ds of I# k# -> k#
    loop i# off# a
      | isTrue# (i# >=# n#) = a
      | otherwise           = loop (i# +# 1#) (off# +# step#) $! f off# a


-- | Traverse over all dimensions keeping track of total offset, with no
--   return value.
--
--   The function is called exactly @dimVal ds@ times, with offsets
--   @off0, off0 + step, off0 + 2 * step, ...@.
--   The number of calls is driven by an explicit counter rather than by
--   comparing offsets. As a result, a zero step still visits every element,
--   consistent with 'overDim_#'.
overDimOff_# :: Dim (ds :: [Nat])
             -> (Int# -> State# s -> State# s) -- ^ function to map over each dimension
             -> Int# -- ^ Initial offset
             -> Int# -- ^ offset step
             -> State# s -> State# s
overDimOff_# ds f off0# step# = loop 0# off0#
  where
    -- total number of elements to visit
    n# = case dimVal ds of I# k# -> k#
    loop i# off# s
      | isTrue# (i# >=# n#) = s
      | otherwise           = loop (i# +# 1#) (off# +# step#) (f off# s)

-- | Traverse from the first index to the second index (both inclusive) in
--   each dimension.
--   Indices must be within the Dim range; this is not checked.
--   Traversal directions may differ between dimensions: a dimension whose
--   first index is greater than its second one is walked backwards.
overDimPart# :: forall (ds :: [Nat]) a s
              . Dimensions ds
             => Idx ds
             -> Idx ds
             -> (Idx ds -> Int# -> a -> State# s -> (# State# s, a #)) -- ^ function to map over each dimension
             -> Int# -- ^ Initial offset
             -> Int# -- ^ offset step
             -> a
             -> State# s
             -> (# State# s, a #)
overDimPart# imin imax f off0 step = overDimPart'# strides imin imax f off0
    where
      -- stride of each dimension: step, step * n1, step * n1 * n2, ...
      strides = mkStrides (dim @ds) (I# step)
      mkStrides :: forall (ns :: [Nat]) . Dim ns -> Int -> Idx ns
      mkStrides D _ = Z
      mkStrides ((Dn :: Dim n) :* (!rest)) acc =
        acc :! mkStrides rest (acc * dimVal' @n)






-- | Worker for 'overDim#'. The traversed function also returns the offset to
--   use for the next call, and the final offset is returned alongside the
--   result. The first dimension varies fastest.
overDim'# :: Dim (ds :: [Nat])
          -> (Idx ds -> Int# -> a -> State# s -> (# State# s, Int#, a #)) -- ^ function to map over each dimension
          -> Int# -- ^ Initial offset
          -> a
          -> State# s
          -> (# State# s, Int#,  a #)
overDim'# D f = f Z
overDim'# ((Dn :: Dim n) :* (!ds)) f = overDim'# ds (go 1)
  where
    extent = dimVal' @n
    go k rest cur# acc st
      | k <= extent = case f (k :! rest) cur# acc st of
          (# st', next#, acc' #) -> go (k + 1) rest next# acc' st'
      | otherwise = (# st, cur#, acc #)



-- | Worker for 'foldDim'. The folded function also returns the offset to use
--   for the next call, and the final offset is returned alongside the result.
--   Indices run from 1 upwards, with the first dimension varying fastest.
foldDim' :: Dim (ds :: [Nat])
         -> (Idx ds -> Int# -> a -> (# Int#, a #)) -- ^ function to map over each dimension
         -> Int# -- ^ Initial offset
         -> a -> (# Int#,  a #)
foldDim' D f = f Z
foldDim' ((Dn :: Dim n) :* (!ds)) f = foldDim' ds (go 1)
  where
    extent = dimVal' @n
    go k rest cur# acc
      | k <= extent = case f (k :! rest) cur# acc of
          (# next#, acc' #) -> go (k + 1) rest next# acc'
      | otherwise = (# cur#, acc #)

-- | Worker for 'foldDimReverse'. Same as 'foldDim'', but indices of every
--   dimension run from its size down to 1.
foldDimReverse' :: Dim (ds :: [Nat])
                -> (Idx ds -> Int# -> a -> (# Int#, a #)) -- ^ function to map over each dimension
                -> Int# -- ^ Initial offset
                -> a -> (# Int#,  a #)
foldDimReverse' D f = f Z
foldDimReverse' ((Dn :: Dim n) :* (!ds)) f = foldDimReverse' ds (go extent)
  where
    extent = dimVal' @n
    go k rest cur# acc
      | k > 0     = case f (k :! rest) cur# acc of
          (# next#, acc' #) -> go (k - 1) rest next# acc'
      | otherwise = (# cur#, acc #)



-- | Worker for 'overDim_#'. The traversed action returns the new state token
--   together with the offset to use for the next call. The final offset is
--   returned alongside the final state.
overDim_'# :: Dim (ds :: [Nat])
           -> (Idx ds -> Int# -> State# s -> (# State# s, Int# #)) -- ^ function to map over each dimension
           -> Int# -- ^ Initial offset
           -> State# s
           -> (# State# s, Int# #)
overDim_'# D f = f Z
overDim_'# ((Dn :: Dim n) :* (!ds)) f = overDim_'# ds (go 1)
  where
    extent = dimVal' @n
    go k rest cur# st
      | k <= extent = case f (k :! rest) cur# st of
          (# st', next# #) -> go (k + 1) rest next# st'
      | otherwise = (# st, cur# #)


-- | Worker for 'overDimPart#'. The first argument holds the offset stride of
--   every dimension, and the next two hold the first and last index
--   (inclusive) to visit. A dimension is walked upwards when its first index
--   is not greater than its last one, and downwards otherwise. The offset
--   moves by that dimension's stride per step.
overDimPart'# :: Idx (ds :: [Nat])
              -> Idx (ds :: [Nat])
              -> Idx (ds :: [Nat])
              -> (Idx ds -> Int# -> a -> State# s -> (# State# s, a #)) -- ^ function to map over each dimension
              -> Int# -- ^ Initial offset
              -> a
              -> State# s
              -> (# State# s, a #)
overDimPart'# _ Z Z f off0# = f Z off0#
overDimPart'# (I# stride# :! strides) (lo :! los) (hi :! his) f off0#
    | lo <= hi  = overDimPart'# strides los his (up lo) start#
    | otherwise = overDimPart'# strides los his (down lo) start#
  where
    -- offset of the first visited index (indices are 1-based)
    start# = case lo of I# l# -> off0# +# stride# *# (l# -# 1#)
    up k rest cur# acc st
      | k > hi    = (# st, acc #)
      | otherwise = case f (k :! rest) cur# acc st of
          (# st', acc' #) -> up (k + 1) rest (cur# +# stride#) acc' st'
    down k rest cur# acc st
      | k < hi    = (# st, acc #)
      | otherwise = case f (k :! rest) cur# acc st of
          (# st', acc' #) -> down (k - 1) rest (cur# -# stride#) acc' st'