{-# LANGUAGE CPP #-}
{-# LANGUAGE PatternSynonyms #-}
#if __GLASGOW_HASKELL__ >= 800
{-# OPTIONS_GHC -Wno-unrecognised-pragmas #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
#else
{-# LANGUAGE GADTs #-}
{-# LANGUAGE StandaloneDeriving #-}
#endif
module Data.Massiv.Core.Index.Stride
( Stride(SafeStride)
, pattern Stride
, unStride
, oneStride
, toLinearIndexStride
, strideStart
, strideSize
) where
import Control.DeepSeq
import Data.Massiv.Core.Index.Class
-- | Stride of an index: how far apart, along each dimension, the selected
-- elements are. The raw constructor 'SafeStride' performs no check; the
-- 'Stride' pattern clamps every component to be at least 1.
#if __GLASGOW_HASKELL__ >= 800
newtype Stride ix = SafeStride ix
  deriving (Eq, Ord, NFData)
{-# COMPLETE Stride #-}
#else
-- Older compilers: a GADT is used instead of a newtype, so the 'Index'
-- constraint is captured by the constructor.
data Stride ix where
  SafeStride :: Index ix => ix -> Stride ix

deriving instance Eq ix => Eq (Stride ix)

deriving instance Ord ix => Ord (Stride ix)

instance NFData ix => NFData (Stride ix) where
  rnf stride = case stride of
    SafeStride inner -> rnf inner
#endif
-- | Bidirectional pattern for 'Stride'. Matching exposes the stored index
-- unchanged. Constructing replaces every component below 1 with 1 (via
-- @max 1@), so a stride built through this pattern has no zero or negative
-- components.
pattern Stride :: Index ix => ix -> Stride ix
pattern Stride s <- SafeStride s
  where
    Stride s = SafeStride $ liftIndex (max 1) s
-- | Renders a stride as @Stride (ix)@.
--
-- 'showsPrec' is defined instead of only 'show'. The whole expression is
-- wrapped in parentheses when it appears as an argument (precedence above
-- 10), so nested output such as @Just (Stride (2))@ stays well-formed rather
-- than @Just Stride (2)@. Top-level output is unchanged.
instance Index ix => Show (Stride ix) where
  showsPrec n (SafeStride ix) =
    showParen (n > 10) $ showString "Stride (" . shows ix . showString ")"
-- | Extract the index stored inside a 'Stride'.
unStride :: Stride ix -> ix
unStride stride =
  case stride of
    SafeStride inner -> inner
{-# INLINE unStride #-}
-- | Shift an index forward, per component, to the nearest multiple of the
-- stride. A component that is already a multiple is left as is.
--
-- NOTE(review): relies on every stride component being positive. The
-- 'Stride' pattern guarantees this; a raw 'SafeStride' does not.
strideStart :: Index ix => Stride ix -> ix -> ix
strideStart (SafeStride stride) ix = liftIndex2 (+) ix offset
  where
    -- How far each component sits past the previous multiple of the stride.
    overshoot = liftIndex2 mod ix stride
    -- Distance to the next multiple. The outer @mod@ turns a full stride
    -- (when the overshoot is 0) back into 0.
    offset = liftIndex2 mod (liftIndex2 subtract overshoot stride) stride
{-# INLINE strideStart #-}
-- | Number of positions, per component, that a stride visits within the
-- given size: @1 + (sz - 1) `div` stride@. A zero-sized component yields 0,
-- since @div@ rounds toward negative infinity.
strideSize :: Index ix => Stride ix -> Sz ix -> ix
strideSize (SafeStride stride) sz = liftIndex (1 +) steps
  where
    lastPosition = liftIndex (subtract 1) sz
    steps = liftIndex2 div lastPosition stride
{-# INLINE strideSize #-}
-- | Linear index of a strided position. The index is divided component-wise
-- by the stride, then linearised against the given size with
-- 'toLinearIndex'.
toLinearIndexStride
  :: Index ix
  => Stride ix -- ^ stride
  -> ix        -- ^ size (presumably of the strided result; confirm at call sites)
  -> ix        -- ^ index in the original, unstrided coordinates
  -> Int
toLinearIndexStride (SafeStride stride) sz ix = toLinearIndex sz reduced
  where
    reduced = liftIndex2 div ix stride
{-# INLINE toLinearIndexStride #-}
-- | A stride of 1 in every dimension, i.e. no skipping.
oneStride :: Index ix => Stride ix
oneStride = SafeStride $ pureIndex 1
{-# INLINE oneStride #-}