-- | Class of types that can be used as array shapes and indices.
module Data.Repa.Array.Internals.Shape
        ( -- * Shapes
          Shape (..)

          -- * Shape operators
        , inShape
        , showShape

          -- * Polymorphic shapes
        , Z    (..)
        , (:.) (..)
        , SH0, SH1, SH2, SH3, SH4, SH5
        , ish0, ish1, ish2, ish3, ish4, ish5)
where
#include "repa-array.h"


-- | Class of types that can be used as array shapes and indices.
class Eq sh => Shape sh where

        -- | Get the number of dimensions in a shape.
        rank    :: sh -> Int

        -- | The shape of an array of size zero, with a particular
        --   dimensionality.
        zeroDim :: sh

        -- | The shape of an array with size one,
        --   with a particular dimensionality.
        unitDim :: sh

        -- | Compute the intersection of two shapes.
        intersectDim :: sh -> sh -> sh

        -- | Add the coordinates of two shapes componentwise.
        addDim  :: sh -> sh -> sh

        -- | Get the total number of elements in an array with this shape.
        size    :: sh -> Int

        -- | Given a starting and ending index, check if some index is within
        --   that range. In every dimension the start is inclusive and the
        --   end is exclusive, so @inShapeRange lo hi ix@ holds exactly when
        --   @lo <= ix < hi@ componentwise.
        inShapeRange :: sh -> sh -> sh -> Bool

        -- | Convert a shape into its list of dimensions, with the rightmost
        --   (most recently added) component first.
        listOfShape  :: sh -> [Int]

        -- | Convert a list of dimensions to a shape. The list uses the same
        --   ordering as 'listOfShape'. Yields 'Nothing' when the length of
        --   the list does not match the rank of the shape type.
        shapeOfList  :: [Int] -> Maybe sh


-------------------------------------------------------------------------------
-- | Given an array shape and index, check whether the index is in the shape,
--   that is, whether @0 <= ix < sh@ holds in every dimension.
inShape :: Shape sh => sh -> sh -> Bool
inShape sh ix = inShapeRange zeroDim sh ix
{-# INLINE_ARRAY inShape #-}


-- | Nicely format a shape as a string.
--   For example, @showShape (Z :. 2 :. 3)@ gives @"Z :. 2 :. 3"@.
showShape :: Shape sh => sh -> String
showShape = foldr (\n str -> str ++ " :. " ++ show n) "Z" . listOfShape
{-# NOINLINE showShape #-}


-------------------------------------------------------------------------------
instance Shape Int where
        rank _                  = 1
        zeroDim                 = 0
        unitDim                 = 1
        intersectDim s1 s2      = min s1 s2
        addDim s1 s2            = s1 + s2
        size s                  = s

        -- Half-open range, the same as the (sh :. Int) instance: an index
        -- equal to the upper bound is out of range. This makes
        -- 'inShape' reject index n for a shape of size n.
        inShapeRange i1 i2 i    = i >= i1 && i < i2

        listOfShape i           = [i]

        shapeOfList [i]         = Just i
        shapeOfList _           = Nothing

        {-# INLINE rank         #-}
        {-# INLINE zeroDim      #-}
        {-# INLINE unitDim      #-}
        {-# INLINE intersectDim #-}
        {-# INLINE addDim       #-}
        {-# INLINE size         #-}
        {-# INLINE inShapeRange #-}
        {-# INLINE listOfShape  #-}
        {-# INLINE shapeOfList  #-}


-------------------------------------------------------------------------------
-- | An index of dimension zero
data Z  = Z
        deriving (Show, Read, Eq, Ord)

-- | Our index type, used for both shapes and indices.
infixl 3 :.
data tail :. head
        = !tail :. !head
        deriving (Show, Read, Eq, Ord)


instance Shape Z where
        rank _                  = 0
        {-# INLINE rank #-}

        zeroDim                 = Z
        {-# INLINE zeroDim #-}

        unitDim                 = Z
        {-# INLINE unitDim #-}

        intersectDim _ _        = Z
        {-# INLINE intersectDim #-}

        addDim _ _              = Z
        {-# INLINE addDim #-}

        size _                  = 1
        {-# INLINE size #-}

        inShapeRange Z Z Z      = True
        {-# INLINE inShapeRange #-}

        listOfShape _           = []
        {-# NOINLINE listOfShape #-}

        shapeOfList []          = Just Z
        shapeOfList _           = Nothing
        {-# NOINLINE shapeOfList #-}


instance Shape sh => Shape (sh :. Int) where
        rank (sh :. _)
                = rank sh + 1
        {-# INLINE rank #-}

        zeroDim = zeroDim :. 0
        {-# INLINE zeroDim #-}

        unitDim = unitDim :. 1
        {-# INLINE unitDim #-}

        intersectDim (sh1 :. n1) (sh2 :. n2)
                = intersectDim sh1 sh2 :. min n1 n2
        {-# INLINE intersectDim #-}

        addDim (sh1 :. n1) (sh2 :. n2)
                = addDim sh1 sh2 :. (n1 + n2)
        {-# INLINE addDim #-}

        size (sh1 :. n)
                = size sh1 * n
        {-# INLINE size #-}

        -- Lower bound inclusive, upper bound exclusive, checked in this
        -- dimension and then recursively in the remaining ones.
        inShapeRange (lo :. l) (hi :. h) (ix :. i)
                = (i >= l) && (i < h) && inShapeRange lo hi ix
        {-# INLINE inShapeRange #-}

        listOfShape (sh :. n)
                = n : listOfShape sh
        {-# NOINLINE listOfShape #-}

        -- Inverse of 'listOfShape': the head of the list is the rightmost
        -- component, and the tail builds the remaining shape.
        shapeOfList xx
         = case xx of
                []      -> Nothing
                x : xs  -> fmap (:. x) (shapeOfList xs)
        {-# NOINLINE shapeOfList #-}


-------------------------------------------------------------------------------
-- Common shapes
type SH0        = Z
type SH1        = SH0 :. Int
type SH2        = SH1 :. Int
type SH3        = SH2 :. Int
type SH4        = SH3 :. Int
type SH5        = SH4 :. Int


ish0 :: SH0
ish0 = Z

ish1 :: Int -> SH1
ish1 x1
        = Z :. x1

ish2 :: Int -> Int -> SH2
ish2 x2 x1
        = Z :. x2 :. x1

ish3 :: Int -> Int -> Int -> SH3
ish3 x3 x2 x1
        = Z :. x3 :. x2 :. x1

ish4 :: Int -> Int -> Int -> Int -> SH4
ish4 x4 x3 x2 x1
        = Z :. x4 :. x3 :. x2 :. x1

ish5 :: Int -> Int -> Int -> Int -> Int -> SH5
ish5 x5 x4 x3 x2 x1
        = Z :. x5 :. x4 :. x3 :. x2 :. x1