{-# LANGUAGE TypeOperators, FlexibleInstances, ScopedTypeVariables #-}

-- | Index types.
module Data.Array.Repa.Index
        (
        -- * Index types
          Z     (..)
        , (:.)  (..)

        -- * Common dimensions.
        , DIM0, DIM1, DIM2, DIM3, DIM4, DIM5
        ,       ix1,  ix2,  ix3,  ix4,  ix5)
where
import Data.Array.Repa.Shape
import GHC.Base                 (quotInt, remInt)

-- | Module name, used as a prefix for the error messages raised by the
--   'Shape' instances in this module.
stage :: String
stage = "Data.Array.Repa.Index"

-- | An index of dimension zero.
--
--   'Z' is the base case of every shape and index: higher-dimensional
--   indices are built by appending 'Int' components with '(:.)'.
data Z  = Z
        deriving (Show, Read, Eq, Ord)

-- | Our index type, used for both shapes and indices.
--
--   The operator is left-associative, so @Z :. y :. x@ parses as
--   @(Z :. y) :. x@; the rightmost component is the innermost one.
--   Both fields are strict, so building an index forces its components.
infixl 3 :.
data tail :. head
        = !tail :. !head
        deriving (Show, Read, Eq, Ord)

-- Common dimensions.
--   DIM0 is the zero-dimensional shape 'Z'.  Each DIMn appends one more
--   'Int' component on the right; the parentheses spell out the
--   left-nested structure that '(:.)' produces.
type DIM0       = Z
type DIM1       = Z :. Int
type DIM2       = (Z :. Int) :. Int
type DIM3       = ((Z :. Int) :. Int) :. Int
type DIM4       = (((Z :. Int) :. Int) :. Int) :. Int
type DIM5       = ((((Z :. Int) :. Int) :. Int) :. Int) :. Int


-- | Helper for index construction.
--
--   Use this instead of explicit constructors like @(Z :. (x :: Int))@.
--   This is sometimes needed to ensure that @x@ is constrained to
--   be an @Int@.
ix1 :: Int -> DIM1
ix1 x = Z :. x
{-# INLINE ix1 #-}

-- | Build the two-dimensional index @Z :. y :. x@, where @x@ is the
--   rightmost (innermost) component.
ix2 :: Int -> Int -> DIM2
ix2 y x = Z :. y :. x
{-# INLINE ix2 #-}

-- | Build the three-dimensional index @Z :. z :. y :. x@, where @x@ is the
--   rightmost (innermost) component.
ix3 :: Int -> Int -> Int -> DIM3
ix3 z y x = Z :. z :. y :. x
{-# INLINE ix3 #-}

-- | Build the four-dimensional index @Z :. a :. z :. y :. x@, where @x@ is
--   the rightmost (innermost) component.
ix4 :: Int -> Int -> Int -> Int -> DIM4
ix4 a z y x = Z :. a :. z :. y :. x
{-# INLINE ix4 #-}

-- | Build the five-dimensional index @Z :. b :. a :. z :. y :. x@, where @x@
--   is the rightmost (innermost) component.
ix5 :: Int -> Int -> Int -> Int -> Int -> DIM5
ix5 b a z y x = Z :. b :. a :. z :. y :. x
{-# INLINE ix5 #-}


-- Shape ----------------------------------------------------------------------
-- | The zero-dimensional shape: rank 0, exactly one element, and a single
--   (trivial) index.
instance Shape Z where
        {-# INLINE [1] rank #-}
        rank _                  = 0

        {-# INLINE [1] zeroDim #-}
        zeroDim                 = Z

        {-# INLINE [1] unitDim #-}
        unitDim                 = Z

        {-# INLINE [1] intersectDim #-}
        intersectDim _ _        = Z

        {-# INLINE [1] addDim #-}
        addDim _ _              = Z

        -- A zero-dimensional shape holds exactly one element.
        {-# INLINE [1] size #-}
        size _                  = 1

        {-# INLINE [1] sizeIsValid #-}
        sizeIsValid _           = True


        -- The only index into a zero-dimensional shape maps to offset 0.
        {-# INLINE [1] toIndex #-}
        toIndex _ _             = 0

        {-# INLINE [1] fromIndex #-}
        fromIndex _ _           = Z


        -- Matching on 'Z' forces all three arguments.
        {-# INLINE [1] inShapeRange #-}
        inShapeRange Z Z Z      = True

        {-# NOINLINE listOfShape #-}
        listOfShape _           = []

        -- Only the empty list describes a zero-dimensional shape.
        {-# NOINLINE shapeOfList #-}
        shapeOfList []          = Z
        shapeOfList _           = error $ stage ++ ".fromList: non-empty list when converting to Z."

        -- Forcing 'Z' by pattern match is all there is to evaluate.
        {-# INLINE deepSeq #-}
        deepSeq Z x             = x


-- | A shape of rank n+1: a rank-n shape @sh@ extended with one more 'Int'
--   component on the right, which is the innermost dimension.
instance Shape sh => Shape (sh :. Int) where
        {-# INLINE [1] rank #-}
        rank (sh :. _)
                = rank sh + 1

        {-# INLINE [1] zeroDim #-}
        zeroDim = zeroDim :. 0

        {-# INLINE [1] unitDim #-}
        unitDim = unitDim :. 1

        -- Component-wise minimum of two shapes.
        {-# INLINE [1] intersectDim #-}
        intersectDim (sh1 :. n1) (sh2 :. n2)
                = intersectDim sh1 sh2 :. min n1 n2

        -- Component-wise sum of two shapes.
        {-# INLINE [1] addDim #-}
        addDim (sh1 :. n1) (sh2 :. n2)
                = addDim sh1 sh2 :. (n1 + n2)

        -- Total element count: the product of all extents.
        {-# INLINE [1] size #-}
        size (sh1 :. n)
                = size sh1 * n

        -- True when the total element count is representable as an 'Int'.
        -- The outer shape must itself be valid: otherwise its product may
        -- already have wrapped around to a positive value, and the check on
        -- the innermost extent alone would wrongly accept the shape.
        {-# INLINE [1] sizeIsValid #-}
        sizeIsValid (sh1 :. n)
                | sizeIsValid sh1
                , inner > 0
                = n <= maxBound `div` inner

                | otherwise
                = False
                where
                        inner = size sh1

        -- Row-major linearisation.  The first argument is the shape (its
        -- rightmost extent @sh2@ scales the outer offset); the second is the
        -- index, whose rightmost component @sh2'@ is added directly.
        {-# INLINE [1] toIndex #-}
        toIndex (sh1 :. sh2) (sh1' :. sh2')
                = toIndex sh1 sh1' * sh2 + sh2'

        -- Inverse of 'toIndex': split a linear offset @n@ back into an index
        -- for the shape @ds :. d@.
        {-# INLINE [1] fromIndex #-}
        fromIndex (ds :. d) n
                = fromIndex ds (n `quotInt` d) :. r
                where
                    -- If we assume that the index is in range, there is no point
                    -- in computing the remainder for the highest dimension since
                    -- n < d must hold. This saves one remInt per element access which
                    -- is quite a big deal.
                    r   | rank ds == 0  = n
                        | otherwise     = n `remInt` d

        -- @inShapeRange lo hi ix@: every component of @ix@ satisfies
        -- @lo <= ix < hi@.
        {-# INLINE [1] inShapeRange #-}
        inShapeRange (zs :. z) (sh1 :. n1) (sh2 :. n2)
                = (n2 >= z) && (n2 < n1) && inShapeRange zs sh1 sh2

        -- The list starts with the rightmost (innermost) component.
        {-# NOINLINE listOfShape #-}
        listOfShape (sh :. n)
         = n : listOfShape sh

        -- Inverse of 'listOfShape': the head of the list becomes the
        -- rightmost component.
        {-# NOINLINE shapeOfList #-}
        shapeOfList xx
         = case xx of
                []      -> error $ stage ++ ".toList: empty list when converting to  (_ :. Int)"
                x:xs    -> shapeOfList xs :. x

        -- Force every component of the shape before returning @x@.
        {-# INLINE deepSeq #-}
        deepSeq (sh :. n) x = deepSeq sh (n `seq` x)