{-# LANGUAGE TypeOperators, FlexibleInstances, ScopedTypeVariables #-}

-- | Index types.
module Data.Array.Repa.Index
	(
	-- * Index types
	  Z	(..)
	, (:.)	(..)

	-- * Common dimensions.
	, DIM0, DIM1, DIM2, DIM3, DIM4, DIM5
        ,       ix1,  ix2,  ix3,  ix4,  ix5)
where
import Data.Array.Repa.Shape
import GHC.Base 		(quotInt, remInt)

-- | Name of this module, used as the prefix of the error messages raised
--   by the 'Shape' instances below.
stage :: String
stage = "Data.Array.Repa.Index"

-- | The zero-dimensional index (and shape). It has a single value, 'Z',
--   which is the base onto which higher-rank indices are built with ':.'.
data Z = Z
    deriving (Eq, Ord, Read, Show)

-- | Index / shape constructor, used for both shapes and indices.
--   @sh :. i@ extends the lower-rank @sh@ with one more right-most
--   component. Both fields are strict, so evaluating the constructor
--   forces both components to weak head normal form.
data tl :. hd
    = !tl :. !hd
    deriving (Eq, Ord, Read, Show)

infixl 3 :.

-- Common dimensions. Because ':.' is declared @infixl 3@, a chain such as
-- @Z :. Int :. Int@ associates to the left, i.e. @(Z :. Int) :. Int@.

-- | Rank-0 shape / index.
type DIM0 = Z

-- | Rank-1 shape / index.
type DIM1 = Z :. Int

-- | Rank-2 shape / index.
type DIM2 = Z :. Int :. Int

-- | Rank-3 shape / index.
type DIM3 = Z :. Int :. Int :. Int

-- | Rank-4 shape / index.
type DIM4 = Z :. Int :. Int :. Int :. Int

-- | Rank-5 shape / index.
type DIM5 = Z :. Int :. Int :. Int :. Int :. Int


-- | Helper for index construction: build a rank-1 index from its single
--   component.
--
--   Use this instead of writing the constructors out, as in
--   @(Z :. (x :: Int))@. The type signature pins the component to @Int@,
--   which is otherwise sometimes needed to constrain @x@.
ix1 :: Int -> DIM1
ix1 = (Z :.)
{-# INLINE ix1 #-}

-- | Build a rank-2 index. Arguments are given outermost first, so the last
--   argument becomes the right-most component.
ix2 :: Int -> Int -> DIM2
ix2 y x = (Z :. y) :. x
{-# INLINE ix2 #-}

-- | Build a rank-3 index. Arguments are given outermost first, so the last
--   argument becomes the right-most component.
ix3 :: Int -> Int -> Int -> DIM3
ix3 z y x = ((Z :. z) :. y) :. x
{-# INLINE ix3 #-}

-- | Build a rank-4 index. Arguments are given outermost first, so the last
--   argument becomes the right-most component.
ix4 :: Int -> Int -> Int -> Int -> DIM4
ix4 a z y x = (((Z :. a) :. z) :. y) :. x
{-# INLINE ix4 #-}

-- | Build a rank-5 index. Arguments are given outermost first, so the last
--   argument becomes the right-most component.
ix5 :: Int -> Int -> Int -> Int -> Int -> DIM5
ix5 b a z y x = ((((Z :. b) :. a) :. z) :. y) :. x
{-# INLINE ix5 #-}


-- Shape ----------------------------------------------------------------------
instance Shape Z where
    -- Z has exactly one value, so every method either returns Z itself
    -- or a fixed constant.

    {-# INLINE [1] rank #-}
    rank _                  = 0

    {-# INLINE [1] zeroDim #-}
    zeroDim                 = Z

    {-# INLINE [1] unitDim #-}
    unitDim                 = Z

    {-# INLINE [1] intersectDim #-}
    intersectDim _ _        = Z

    {-# INLINE [1] addDim #-}
    addDim _ _              = Z

    -- The size of Z is 1, the neutral element of the product used by the
    -- (sh :. Int) instance to compute sizes.
    {-# INLINE [1] size #-}
    size _                  = 1

    {-# INLINE [1] sizeIsValid #-}
    sizeIsValid _           = True

    -- The only index of Z maps to flat index 0, and every flat index maps
    -- back to Z.
    {-# INLINE [1] toIndex #-}
    toIndex _ _             = 0

    {-# INLINE [1] fromIndex #-}
    fromIndex _ _           = Z

    {-# INLINE [1] inShapeRange #-}
    inShapeRange Z Z Z      = True

    -- Z has no extents, so it corresponds to the empty list.
    {-# NOINLINE listOfShape #-}
    listOfShape _           = []

    -- Only the empty list converts to Z. A non-empty list here means the
    -- list had more elements than the target shape has dimensions.
    {-# NOINLINE shapeOfList #-}
    shapeOfList []          = Z
    shapeOfList _           = error $ stage ++ ".shapeOfList: non-empty list when converting to Z."

    {-# INLINE deepSeq #-}
    deepSeq Z x             = x


instance Shape sh => Shape (sh :. Int) where
    -- Every method handles the right-most extent directly and recurses
    -- into the lower-rank shape on the left.

    {-# INLINE [1] rank #-}
    rank (sh :. _)
        = rank sh + 1

    {-# INLINE [1] zeroDim #-}
    zeroDim = zeroDim :. 0

    {-# INLINE [1] unitDim #-}
    unitDim = unitDim :. 1

    -- Component-wise minimum of the two shapes.
    {-# INLINE [1] intersectDim #-}
    intersectDim (sh1 :. n1) (sh2 :. n2)
        = intersectDim sh1 sh2 :. min n1 n2

    -- Component-wise sum of the two shapes.
    {-# INLINE [1] addDim #-}
    addDim (sh1 :. n1) (sh2 :. n2)
        = addDim sh1 sh2 :. (n1 + n2)

    -- Product of all extents.
    {-# INLINE [1] size #-}
    size (sh1 :. n)
        = size sh1 * n

    -- True when the product of the extents fits in an Int.
    -- The lower-rank prefix must be valid itself. Without that check an
    -- overflowed, wrapped-around 'size sh1' can look positive and small,
    -- and the overflowing shape would be accepted.
    -- A prefix whose size is not positive is rejected.
    -- NOTE(review): a negative right-most extent @n@ is still accepted;
    -- confirm whether callers rely on that.
    {-# INLINE [1] sizeIsValid #-}
    sizeIsValid (sh1 :. n)
        | sizeIsValid sh1 && size sh1 > 0
        = n <= maxBound `div` size sh1

        | otherwise
        = False

    -- Flatten an index: the right-most component varies fastest.
    -- No bounds checking is done here.
    {-# INLINE [1] toIndex #-}
    toIndex (sh1 :. sh2) (sh1' :. sh2')
        = toIndex sh1 sh1' * sh2 + sh2'

    -- Inverse of 'toIndex' for in-range flat indices.
    {-# INLINE [1] fromIndex #-}
    fromIndex (ds :. d) n
        = fromIndex ds (n `quotInt` d) :. r
        where
        -- If we assume that the index is in range, there is no point in
        -- computing the remainder for the outermost dimension (the one
        -- whose tail has rank 0), since n < d must already hold there.
        -- This saves one remInt per element access, which is quite a
        -- big deal.
        r   | rank ds == 0  = n
            | otherwise     = n `remInt` d

    -- Arguments are (lower bound, extent, index). Each component of the
    -- index must satisfy lower <= index < extent.
    {-# INLINE [1] inShapeRange #-}
    inShapeRange (zs :. z) (sh1 :. n1) (sh2 :. n2)
        = (n2 >= z) && (n2 < n1) && inShapeRange zs sh1 sh2

    -- Extents are listed right-most first.
    {-# NOINLINE listOfShape #-}
    listOfShape (sh :. n)
        = n : listOfShape sh

    -- Inverse of 'listOfShape': the head of the list becomes the
    -- right-most extent. A list that is too short fails here. A list that
    -- is too long fails in the Z instance.
    {-# NOINLINE shapeOfList #-}
    shapeOfList xx
        = case xx of
            []   -> error $ stage ++ ".shapeOfList: empty list when converting to (_ :. Int)"
            x:xs -> shapeOfList xs :. x

    {-# INLINE deepSeq #-}
    deepSeq (sh :. n) x = deepSeq sh (n `seq` x)