{-# LANGUAGE CPP                        #-}
{-# LANGUAGE PatternSynonyms            #-}
#if __GLASGOW_HASKELL__ >= 800
{-# OPTIONS_GHC -Wno-unrecognised-pragmas #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
#else
{-# LANGUAGE GADTs                      #-}
{-# LANGUAGE StandaloneDeriving         #-}
#endif
module Data.Massiv.Core.Index.Stride
  ( Stride(SafeStride)
  , pattern Stride
  , unStride
  , oneStride
  , toLinearIndexStride
  , strideStart
  , strideSize
  ) where
import           Control.DeepSeq
import           Data.Massiv.Core.Index.Class
#if __GLASGOW_HASKELL__ >= 800
-- | A stride: one step size for each dimension of the index type @ix@.
--
-- The raw 'SafeStride' constructor stores its argument unchanged and performs
-- no validation. The 'Stride' pattern synonym clamps every component to at
-- least @1@ when used as a constructor. 'NFData' is derived via
-- GeneralizedNewtypeDeriving (enabled in the pragmas for this branch).
newtype Stride ix = SafeStride ix deriving (Eq, Ord, NFData)
-- Declares that matching on the 'Stride' pattern synonym alone is exhaustive.
-- Some compilers in this CPP branch may not recognise COMPLETE pragmas, which
-- is why @-Wno-unrecognised-pragmas@ is set at the top of the module.
{-# COMPLETE Stride #-}
#else
-- | A stride: one step size for each dimension of the index type @ix@.
--
-- Legacy-GHC form: a GADT whose constructor carries an 'Index' constraint.
-- Instances are derived standalone because of the GADT syntax.
-- NOTE(review): presumably the GADT form exists so the 'Stride' pattern
-- synonym type-checks on pre-8.0 compilers — confirm.
data Stride ix where
  SafeStride :: Index ix => ix -> Stride ix
deriving instance Eq ix => Eq (Stride ix)
deriving instance Ord ix => Ord (Stride ix)
-- Forces the wrapped index to normal form.
instance NFData ix => NFData (Stride ix) where
  rnf (SafeStride ix) = rnf ix
#endif
-- | Smart constructor and matcher for 'Stride'.
--
-- As a pattern it exposes the stored stride unchanged. As an expression it
-- applies @max 1@ to every component, so zero or negative entries become @1@.
pattern Stride :: Index ix => ix -> Stride ix
pattern Stride s <- SafeStride s
  where
    Stride s = SafeStride $ liftIndex (max 1) s
-- | Renders a stride as @Stride (ix)@.
--
-- Implemented through 'showsPrec' rather than 'show' so that a stride nested
-- in another constructor is parenthesised (e.g. @Just (Stride (3))@ instead of
-- the invalid @Just Stride (3)@). At precedence 0 the output is identical to
-- @"Stride (" ++ show ix ++ ")"@.
instance Index ix => Show (Stride ix) where
  showsPrec n (SafeStride ix) =
    -- 10 is the precedence of function application; wrap when used as an
    -- argument of another constructor/function.
    showParen (n > 10) $ showString "Stride (" . shows ix . showChar ')'
-- | Extract the per-dimension stride value stored in a 'Stride'.
unStride :: Stride ix -> ix
unStride stride =
  case stride of
    SafeStride s -> s
{-# INLINE unStride #-}
-- | Round an index up, componentwise, to the nearest multiple of the stride.
--
-- Each component becomes @i + ((s - i \`mod\` s) \`mod\` s)@. For example, with
-- stride @3@ an index of @4@ becomes @6@, and @6@ stays @6@.
strideStart :: Index ix => Stride ix -> ix -> ix
strideStart (SafeStride stride) ix = liftIndex2 (+) ix offset
  where
    -- Distance of each component past the previous multiple of the stride.
    remainder = liftIndex2 mod ix stride
    -- Distance to the next multiple (0 when already aligned).
    offset = liftIndex2 mod (liftIndex2 (-) stride remainder) stride
{-# INLINE strideStart #-}
-- | Number of positions visited along each dimension of size @sz@ when
-- stepping by the stride from 0.
--
-- Each component is @(n - 1) \`div\` s + 1@, which equals @ceiling (n / s)@
-- for positive stride components and yields @0@ for a size of @0@.
strideSize :: Index ix => Stride ix -> ix -> ix
strideSize (SafeStride stride) sz =
  liftIndex2 (\n s -> (n - 1) `div` s + 1) sz stride
{-# INLINE strideSize #-}
-- | Convert an index into a linear offset after dividing each of its
-- components by the stride.
--
-- Components that are not multiples of the stride are rounded down by 'div'.
-- NOTE(review): @sz@ is presumably the strided size (as returned by
-- 'strideSize') — confirm against callers.
toLinearIndexStride :: Index ix => Stride ix -> ix -> ix -> Int
toLinearIndexStride (SafeStride stride) sz ix = toLinearIndex sz scaledIx
  where
    scaledIx = liftIndex2 div ix stride
{-# INLINE toLinearIndexStride #-}
-- | A stride of @1@ along every dimension.
--
-- With this stride 'strideStart' and 'strideSize' return their index/size
-- argument unchanged.
oneStride :: Index ix => Stride ix
oneStride = SafeStride $ pureIndex 1
{-# INLINE oneStride #-}