{-# LANGUAGE BangPatterns    #-}
{-# LANGUAGE DataKinds       #-}
{-# LANGUAGE GADTs           #-}
{-# LANGUAGE PatternSynonyms #-}
{-# LANGUAGE TypeOperators   #-}
module Data.Massiv.Core.Index
  ( module Data.Massiv.Core.Index.Ix
  , Stride
  , pattern Stride
  , unStride
  , toLinearIndexStride
  , strideStart
  , strideSize
  , oneStride
  , Border(..)
  , handleBorderIndex
  , module Data.Massiv.Core.Index.Class
  , zeroIndex
  , isSafeSize
  , isNonEmpty
  , headDim
  , tailDim
  , lastDim
  , initDim
  , getIndex'
  , setIndex'
  , getDim'
  , setDim'
  , dropDim'
  , pullOutDim'
  , insertDim'
  , fromDimension
  , getDimension
  , setDimension
  , dropDimension
  , pullOutDimension
  , insertDimension
  , iterLinearM
  , iterLinearM_
  , module Data.Massiv.Core.Iterator
  ) where
import           Control.DeepSeq
import           Data.Massiv.Core.Index.Class
import           Data.Massiv.Core.Index.Ix
import           Data.Massiv.Core.Index.Stride
import           Data.Massiv.Core.Iterator
import           GHC.TypeLits
-- | Technique for resolving an index that falls outside of an array's
-- bounds. It is interpreted by 'handleBorderIndex'.
--
-- The illustration below follows the formulas in 'handleBorderIndex'. It
-- uses a single dimension holding @1 2 3 4@ and shows four positions on
-- each side:
--
-- @
--              outside |  Array  | outside
-- (Fill 0)   : 0 0 0 0 | 1 2 3 4 | 0 0 0 0
-- Wrap       : 1 2 3 4 | 1 2 3 4 | 1 2 3 4
-- Edge       : 1 1 1 1 | 1 2 3 4 | 4 4 4 4
-- Reflect    : 4 3 2 1 | 1 2 3 4 | 4 3 2 1
-- Continue   : 1 4 3 2 | 1 2 3 4 | 3 2 1 4
-- @
data Border e
  = Fill e
    -- ^ Return the supplied constant for any out-of-bounds index.
  | Wrap
    -- ^ Wrap around to the opposite side (component taken modulo the size).
  | Edge
    -- ^ Repeat the element at the nearest edge.
  | Reflect
    -- ^ Mirror across the border, repeating the edge element.
  | Continue
    -- ^ Mirror across the border without repeating the edge element.
  deriving (Eq, Show)
-- | Only 'Fill' carries a payload. It is forced to normal form, and every
-- other constructor is already fully evaluated once matched.
instance NFData e => NFData (Border e) where
  rnf (Fill e) = rnf e
  rnf Wrap     = ()
  rnf Edge     = ()
  rnf Reflect  = ()
  rnf Continue = ()
-- | Look up an element at a possibly out-of-bounds index, resolving the
-- index according to the 'Border' strategy.
--
-- 'Fill' checks the index with 'isSafeIndex' and returns the constant when
-- it is not safe. Every other strategy hands 'repairIndex' two per-dimension
-- functions of the dimension size @k@ and the offending component @i@.
-- Judging by the 'Edge' case (0 on one side, @k - 1@ on the other), the first
-- function repairs components below zero and the second repairs components
-- at or past the size. NOTE(review): confirm this against 'repairIndex'.
-- Because the repairs use @mod k@, components more than one size away wrap
-- around rather than tiling further reflections.
handleBorderIndex ::
     Index ix
  => Border e   -- ^ Border resolution technique
  -> ix         -- ^ Size
  -> (ix -> e)  -- ^ Function producing an element for an in-bounds index
  -> ix         -- ^ Index, possibly outside of the size
  -> e
handleBorderIndex border !sz getVal !ix =
  case border of
    Fill val
      | isSafeIndex sz ix -> getVal ix
      | otherwise         -> val
    Wrap     -> repairWith (flip mod) (flip mod)
    Edge     -> repairWith (\ _ _ -> 0) (\ !k _ -> k - 1)
    Reflect  -> repairWith (\ !k !i -> (abs i - 1) `mod` k)
                           (\ !k !i -> (-i - 1) `mod` k)
    Continue -> repairWith (\ !k !i -> abs i `mod` k)
                           (\ !k !i -> (-i - 2) `mod` k)
  where
    -- Rebuild the index with the given below/above repair functions and
    -- fetch the element stored there.
    repairWith below above = getVal (repairIndex sz ix below above)
{-# INLINE [1] handleBorderIndex #-}
-- | The index obtained from 'pureIndex' applied to @0@. This is presumably
-- zero in every dimension.
zeroIndex :: Index ix => ix
zeroIndex = pureIndex 0
{-# INLINE [1] zeroIndex #-}
-- | Compare a size against 'zeroIndex' using the index type's 'Ord'
-- instance. Returns 'True' exactly when @zeroIndex >= sz@.
--
-- NOTE(review): with this direction, an all-zero size yields 'True'. A
-- typical positive size such as @(3, 4)@ presumably yields 'False', which
-- looks inverted relative to the name. 'Ord' on multi-dimensional indices
-- is also presumably lexicographic, so this does not check each dimension
-- independently. Confirm the intended semantics against callers.
isSafeSize :: Index ix => ix -> Bool
isSafeSize sz = zeroIndex >= sz
{-# INLINE [1] isSafeSize #-}
-- | 'True' when 'zeroIndex' is accepted by 'isSafeIndex' for this size.
-- This presumably means an array of this size holds at least one element,
-- assuming 'isSafeIndex' checks @0 <= i < size@ per dimension.
isNonEmpty :: Index ix => ix -> Bool
isNonEmpty !sz = sz `isSafeIndex` zeroIndex
{-# INLINE [1] isNonEmpty #-}
-- | First component of 'unconsDim'.
headDim :: Index ix => ix -> Int
headDim ix = fst (unconsDim ix)
{-# INLINE [1] headDim #-}
-- | Lower-dimensional remainder returned by 'unconsDim'.
tailDim :: Index ix => ix -> Lower ix
tailDim ix = snd (unconsDim ix)
{-# INLINE [1] tailDim #-}
-- | Final component split off by 'unsnocDim'.
lastDim :: Index ix => ix -> Int
lastDim ix = let (_, i) = unsnocDim ix in i
{-# INLINE [1] lastDim #-}
-- | Lower-dimensional prefix returned by 'unsnocDim'.
initDim :: Index ix => ix -> Lower ix
initDim ix = case unsnocDim ix of
  (ixl, _) -> ixl
{-# INLINE [1] initDim #-}
-- | Partial variant of 'setDim'. Replaces the component at the given
-- dimension, and calls 'error' when 'setDim' reports the dimension as
-- invalid.
setDim' :: Index ix => ix -> Dim -> Int -> ix
setDim' ix dim i = maybe (errorDim "setDim'" dim) id (setDim ix dim i)
{-# INLINE [1] setDim' #-}
-- | Partial variant of 'getDim'. Reads the component at the given
-- dimension, and calls 'error' when 'getDim' reports the dimension as
-- invalid.
getDim' :: Index ix => ix -> Dim -> Int
getDim' ix dim = maybe (errorDim "getDim'" dim) id (getDim ix dim)
{-# INLINE [1] getDim' #-}
-- | Deprecated predecessor of 'setDim''. It performs the same update but
-- names itself in the error message.
setIndex' :: Index ix => ix -> Dim -> Int -> ix
setIndex' ix dim i = maybe (errorDim "setIndex'" dim) id (setDim ix dim i)
{-# INLINE [1] setIndex' #-}
{-# DEPRECATED setIndex' "In favor of `setDim'`" #-}
-- | Deprecated predecessor of 'getDim''. It performs the same lookup but
-- names itself in the error message.
getIndex' :: Index ix => ix -> Dim -> Int
getIndex' ix dim = maybe (errorDim "getIndex'" dim) id (getDim ix dim)
{-# INLINE [1] getIndex' #-}
{-# DEPRECATED getIndex' "In favor of `getDim'`" #-}
-- | Partial variant of 'dropDim'. Removes the given dimension, and calls
-- 'error' when 'dropDim' reports the dimension as invalid.
dropDim' :: Index ix => ix -> Dim -> Lower ix
dropDim' ix dim = maybe (errorDim "dropDim'" dim) id (dropDim ix dim)
{-# INLINE [1] dropDim' #-}
-- | Partial variant of 'pullOutDim'. Returns the extracted component
-- together with the remaining lower-dimensional index, and calls 'error'
-- when 'pullOutDim' reports the dimension as invalid.
pullOutDim' :: Index ix => ix -> Dim -> (Int, Lower ix)
pullOutDim' ix dim = maybe (errorDim "pullOutDim'" dim) id (pullOutDim ix dim)
{-# INLINE [1] pullOutDim' #-}
-- | Partial variant of 'insertDim'. Inserts a component at the given
-- dimension of a lower-dimensional index, and calls 'error' when
-- 'insertDim' reports the dimension as invalid.
insertDim' :: Index ix => Lower ix -> Dim -> Int -> ix
insertDim' ixl dim i = maybe (errorDim "insertDim'" dim) id (insertDim ixl dim i)
{-# INLINE [1] insertDim' #-}
-- | Abort with a message naming the calling function and the offending
-- dimension. Marked NOINLINE, so this error path stays out of the inlined
-- callers.
errorDim :: String -> Dim -> a
errorDim funName dim =
  error (concat [funName, ": Dimension is out of reach: ", show dim])
{-# NOINLINE errorDim #-}
-- | Value-level 'Dim' corresponding to a type-level dimension. It reads the
-- natural with 'natVal' and converts it with 'fromIntegral'.
fromDimension :: KnownNat n => Dimension n -> Dim
fromDimension d = fromIntegral (natVal d)
{-# INLINE [1] fromDimension #-}
-- | 'setDim'' addressed by a type-level dimension. The 'IsIndexDimension'
-- constraint presumably rules out invalid dimensions statically.
setDimension :: IsIndexDimension ix n => ix -> Dimension n -> Int -> ix
setDimension ix = setDim' ix . fromDimension
{-# INLINE [1] setDimension #-}
-- | 'getDim'' addressed by a type-level dimension.
getDimension :: IsIndexDimension ix n => ix -> Dimension n -> Int
getDimension ix = getDim' ix . fromDimension
{-# INLINE [1] getDimension #-}
-- | 'dropDim'' addressed by a type-level dimension.
dropDimension :: IsIndexDimension ix n => ix -> Dimension n -> Lower ix
dropDimension ix = dropDim' ix . fromDimension
{-# INLINE [1] dropDimension #-}
-- | 'pullOutDim'' addressed by a type-level dimension.
pullOutDimension :: IsIndexDimension ix n => ix -> Dimension n -> (Int, Lower ix)
pullOutDimension ix = pullOutDim' ix . fromDimension
{-# INLINE [1] pullOutDimension #-}
-- | 'insertDim'' addressed by a type-level dimension.
insertDimension :: IsIndexDimension ix n => Lower ix -> Dimension n -> Int -> ix
insertDimension ixl = insertDim' ixl . fromDimension
{-# INLINE [1] insertDimension #-}
-- | Monadic fold over a range of linear indices. The action receives each
-- linear index, its multi-dimensional counterpart (via 'fromLinearIndex'),
-- and the accumulator.
--
-- Iteration is delegated to 'loopM', which is given the start index, the
-- predicate @\\i -> cond i end@, and a step that adds the increment.
-- Presumably it loops while the predicate holds.
iterLinearM :: (Index ix, Monad m)
            => ix                       -- ^ Size used to convert linear indices
            -> Int                      -- ^ Starting linear index
            -> Int                      -- ^ End linear index (second argument of the condition)
            -> Int                      -- ^ Increment added after each step
            -> (Int -> Int -> Bool)     -- ^ Condition, applied as @cond i end@
            -> a                        -- ^ Initial accumulator
            -> (Int -> ix -> a -> m a)  -- ^ Action on linear index, index and accumulator
            -> m a
iterLinearM !sz !k0 !k1 !inc cond !acc f = loopM k0 keepGoing (+ inc) acc step
  where
    keepGoing i = i `cond` k1
    step !i !acc0 = f i (fromLinearIndex sz i) acc0
{-# INLINE iterLinearM #-}
-- | Like 'iterLinearM' but without an accumulator. The action is run for
-- each linear index and its multi-dimensional counterpart (via
-- 'fromLinearIndex'), with iteration delegated to 'loopM_'.
iterLinearM_ :: (Index ix, Monad m) =>
                ix                      -- ^ Size used to convert linear indices
             -> Int                     -- ^ Starting linear index
             -> Int                     -- ^ End linear index (second argument of the condition)
             -> Int                     -- ^ Increment added after each step
             -> (Int -> Int -> Bool)    -- ^ Condition, applied as @cond i end@
             -> (Int -> ix -> m ())     -- ^ Action on linear index and index
             -> m ()
iterLinearM_ !sz !k0 !k1 !inc cond f = loopM_ k0 keepGoing (+ inc) visit
  where
    keepGoing i = i `cond` k1
    visit !i = f i (fromLinearIndex sz i)
{-# INLINE iterLinearM_ #-}