{-# LANGUAGE AllowAmbiguousTypes  #-}
{-# LANGUAGE DeriveAnyClass       #-}
{-# LANGUAGE DeriveGeneric        #-}
{-# LANGUAGE FlexibleContexts     #-}
{-# LANGUAGE FlexibleInstances    #-}
{-# LANGUAGE ScopedTypeVariables  #-}
{-# LANGUAGE TypeApplications     #-}
{-# LANGUAGE TypeFamilies         #-}
{-# LANGUAGE TypeOperators        #-}
{-# LANGUAGE UndecidableInstances #-}
{-# OPTIONS_HADDOCK hide #-}
-- |
-- Module      : Data.Array.Accelerate.Sugar.Shape
-- Copyright   : [2008..2020] The Accelerate Team
-- License     : BSD3
--
-- Maintainer  : Trevor L. McDonell <trevor.mcdonell@gmail.com>
-- Stability   : experimental
-- Portability : non-portable (GHC extensions)
--
-- Array indices are snoc lists at both the type and value level. That is,
-- they're backwards, where the end-of-list token, 'Z', occurs first. For
-- example, the type of a rank-2 array index is @Z :. Int :. Int@, and
-- shape of a rank-2 array with 5 rows and 10 columns is @Z :. 5 :. 10@.
--
-- In Accelerate the rightmost dimension is the /fastest varying/ or
-- innermost; these values are adjacent in memory.
--

module Data.Array.Accelerate.Sugar.Shape
  where

import Data.Array.Accelerate.Sugar.Elt
import Data.Array.Accelerate.Representation.Tag
import Data.Array.Accelerate.Representation.Type
import qualified Data.Array.Accelerate.Representation.Shape         as R
import qualified Data.Array.Accelerate.Representation.Slice         as R

import Data.Kind
import GHC.Generics


-- Shorthand for common shape types. @DIMn@ is the type of a rank-@n@ shape
-- or index: 'Z' followed by @n@ dimensions of type 'Int'.
--
type DIM0 = Z
type DIM1 = Z :. Int
type DIM2 = Z :. Int :. Int
type DIM3 = Z :. Int :. Int :. Int
type DIM4 = Z :. Int :. Int :. Int :. Int
type DIM5 = Z :. Int :. Int :. Int :. Int :. Int
type DIM6 = Z :. Int :. Int :. Int :. Int :. Int :. Int
type DIM7 = Z :. Int :. Int :. Int :. Int :. Int :. Int :. Int
type DIM8 = Z :. Int :. Int :. Int :. Int :. Int :. Int :. Int :. Int
type DIM9 = Z :. Int :. Int :. Int :. Int :. Int :. Int :. Int :. Int :. Int

-- | Rank-0 index. 'Z' is the end-of-list token that begins every shape
-- (see the module header: shapes are snoc lists).
--
-- 'Elt' is obtained via @DeriveAnyClass@, using the class's
-- 'Generic'-based defaults.
--
data Z = Z
  deriving (Show, Eq, Generic, Elt)

-- | Increase an index rank by one dimension. The ':.' operator is used to
-- construct both values and types.
--
-- It is left-associative (@infixl 3@), so @Z :. a :. b@ parses as
-- @(Z :. a) :. b@. Both fields are strict.
--
infixl 3 :.
data tail :. head = !tail :. !head
  deriving (Eq, Generic)  -- Not deriving Elt or Show: both are hand-written in this module

-- We don't use a derived Show instance for (:.) because it would insert
-- parentheses to show the order in which the operator is applied, i.e.:
--
--   ((Z :. z) :. y) :. x
--
-- This is fine, but I find it a little unsightly. Instead, we drop all
-- parentheses and just display the shape thus:
--
--   Z :. z :. y :. x
--
-- and then require the down-stream user to wrap the whole thing in parentheses.
-- This works fine for the most important case, which is to show Acc and Exp
-- expressions via the pretty printer, although Show-ing a Shape directly
-- results in no parentheses being displayed.
--
-- The same precedence @p@ is passed unchanged to both components and no
-- 'showParen' is applied at any level.
--
-- One way around this might be to have specialised instances for DIM1, DIM2,
-- etc.
--
instance (Show sh, Show sz) => Show (sh :. sz) where
  showsPrec p (sh :. sz) =
    showsPrec p sh . showString " :. " . showsPrec p sz

-- | Marker for entire dimensions in 'Data.Array.Accelerate.Language.slice' and
-- 'Data.Array.Accelerate.Language.replicate' descriptors.
--
-- Occurrences of 'All' indicate the dimensions into which the array's existing
-- extent will be placed unchanged.
--
-- See 'Data.Array.Accelerate.Language.slice' and
-- 'Data.Array.Accelerate.Language.replicate' for examples.
--
data All = All
  deriving (Show, Eq, Generic, Elt)

-- | Marker for arbitrary dimensions in 'Data.Array.Accelerate.Language.slice'
-- and 'Data.Array.Accelerate.Language.replicate' descriptors.
--
-- 'Any' can be used in the leftmost position of a slice instead of 'Z',
-- indicating that any dimensionality is admissible in that position.
--
-- The type parameter @sh@ is phantom: it records the shape that 'Any' stands
-- for, but no value of that type is stored.
--
-- See 'Data.Array.Accelerate.Language.slice' and
-- 'Data.Array.Accelerate.Language.replicate' for examples.
--
data Any sh = Any
  deriving (Show, Eq, Generic)

-- | Marker for splitting along an entire dimension in division descriptors.
--
-- For example, when used in a division descriptor passed to
-- 'Data.Array.Accelerate.toSeq', a 'Split' indicates that the array should be
-- divided along this dimension forming the elements of the output sequence.
--
data Split = Split
  deriving (Show, Eq)

-- | Marker for arbitrary shapes in slice descriptors, where it is desired to
-- split along an unknown number of dimensions.
--
-- For example, in the following definition, 'Divide' matches against any shape
-- and flattens everything but the innermost dimension.
--
-- > vectors :: (Shape sh, Elt e) => Acc (Array (sh:.Int) e) -> Seq [Vector e]
-- > vectors = toSeq (Divide :. All)
--
data Divide sh = Divide
  deriving (Show, Eq)


-- | Number of dimensions of a /shape/ or /index/ (>= 0).
--
-- The shape type carries no value argument, so it must be supplied with a
-- type application, e.g. @rank \@DIM2@.
--
rank :: forall sh. Shape sh => Int
rank = R.rank (shapeR @sh)

-- | Total number of elements in an array of the given /shape/.
--
-- Converts to the representation type and delegates to 'R.size'.
--
size :: forall sh. Shape sh => sh -> Int
size = R.size (shapeR @sh) . fromElt

-- | The empty /shape/, as produced by 'R.empty' for this shape's
-- representation.
--
empty :: forall sh. Shape sh => sh
empty = toElt $ R.empty (shapeR @sh)

-- | Yield the intersection of two shapes.
--
-- Both arguments are converted to the representation type and combined
-- with 'R.intersect'.
--
intersect :: forall sh. Shape sh => sh -> sh -> sh
intersect x y = toElt $ R.intersect (shapeR @sh) (fromElt x) (fromElt y)

-- | Yield the union of two shapes.
--
-- Both arguments are converted to the representation type and combined
-- with 'R.union'.
--
union :: forall sh. Shape sh => sh -> sh -> sh
union x y = toElt $ R.union (shapeR @sh) (fromElt x) (fromElt y)

-- | Map a multi-dimensional index into one in a linear, row-major
-- representation of the array (first argument is the /shape/, second
-- argument is the index).
--
toIndex :: forall sh. Shape sh
        => sh       -- ^ Total shape (extent) of the array
        -> sh       -- ^ The argument index
        -> Int      -- ^ Corresponding linear index
toIndex sh ix = R.toIndex (shapeR @sh) (fromElt sh) (fromElt ix)

-- | Inverse of 'toIndex': recover the multi-dimensional index that
-- corresponds to a linear index, given the array's extent.
--
fromIndex :: forall sh. Shape sh
          => sh     -- ^ Total shape (extent) of the array
          -> Int    -- ^ The argument index
          -> sh     -- ^ Corresponding multi-dimensional index
fromIndex sh = toElt . R.fromIndex (shapeR @sh) (fromElt sh)

-- | Iterate through all of the indices of a shape, applying the given
-- function at each index. The index space is traversed in row-major order.
--
-- The combining function and the empty-space value are passed through to
-- 'R.iter' unchanged; only the extent and the per-index function are
-- adapted between the surface and representation types.
--
iter :: forall sh e. Shape sh
     => sh              -- ^ The total shape (extent) of the index space
     -> (sh -> e)       -- ^ Function to apply at each index
     -> (e -> e -> e)   -- ^ Function to combine results
     -> e               -- ^ Value to return in case of an empty iteration space
     -> e
iter sh f = R.iter (shapeR @sh) (fromElt sh) (f . toElt)

-- | Variant of 'iter' without an initial value.
--
-- NOTE(review): 'R.iter1' carries a 'HasCallStack' constraint, so it
-- presumably fails on an empty index space; confirm before relying on it.
--
iter1 :: forall sh e. Shape sh
      => sh             -- ^ The total shape (extent) of the index space
      -> (sh -> e)      -- ^ Function to apply at each index
      -> (e -> e -> e)  -- ^ Function to combine results
      -> e
iter1 sh f = R.iter1 (shapeR @sh) (fromElt sh) (f . toElt)

-- | Convert a minpoint-maxpoint index into a zero-indexed shape.
--
rangeToShape :: forall sh. Shape sh => (sh, sh) -> sh
rangeToShape (u, v) = toElt $ R.rangeToShape (shapeR @sh) (fromElt u, fromElt v)

-- | Convert a shape into a minpoint-maxpoint index.
--
shapeToRange :: forall sh. Shape sh => sh -> (sh, sh)
shapeToRange ix =
  let (u, v) = R.shapeToRange (shapeR @sh) (fromElt ix)
   in (toElt u, toElt v)

-- | Convert a shape to a list of dimensions.
--
-- The element order is whatever 'R.shapeToList' produces.
--
shapeToList :: forall sh. Shape sh => sh -> [Int]
shapeToList = R.shapeToList (shapeR @sh) . fromElt

-- | Convert a list of dimensions into a shape. It is an error if the list
-- does not contain exactly as many elements as the rank of the shape type.
-- Use 'listToShape'' for a total alternative.
--
listToShape :: forall sh. Shape sh => [Int] -> sh
listToShape = toElt . R.listToShape (shapeR @sh)

-- | Attempt to convert a list of dimensions into a shape, yielding 'Nothing'
-- whenever 'R.listToShape'' does.
--
listToShape' :: forall sh. Shape sh => [Int] -> Maybe sh
listToShape' = fmap toElt . R.listToShape' (shapeR @sh)

-- | Nicely format a shape as a string.
--
-- Starts from @\"Z\"@ and appends @\" :. d\"@ for each dimension @d@ of
-- 'shapeToList'. Because the fold is a 'foldr', the /last/ list element is
-- appended first (immediately after @Z@) and the first list element ends up
-- rightmost.
--
showShape :: Shape sh => sh -> String
showShape = foldr (\d acc -> acc ++ " :. " ++ show d) "Z" . shapeToList

-- | Project the shape of a slice from the full shape.
--
-- The slice index describes which dimensions of the full shape are kept in
-- the slice; the actual projection is performed by 'R.sliceShape' on the
-- representation types.
--
sliceShape
    :: forall slix co sl dim. (Shape sl, Shape dim)
    => R.SliceIndex slix (EltR sl) co (EltR dim)   -- ^ Slice index witness
    -> dim                                         -- ^ Full shape
    -> sl                                          -- ^ Shape of the slice
sliceShape slix = toElt . R.sliceShape slix . fromElt

-- | Enumerate all slices within a given bound. The innermost dimension
-- changes most rapidly.
--
-- Example:
--
-- > let slix = sliceIndex @(Z :. Int :. Int :. All)
-- >     sh   = Z :. 2 :. 3 :. 1 :: DIM3
-- > in
-- > enumSlices slix sh :: [ Z :. Int :. Int :. All ]
--
enumSlices :: forall slix co sl dim. (Elt slix, Elt dim)
           => R.SliceIndex (EltR slix) sl co (EltR dim)  -- ^ Slice index witness
           -> dim                                        -- ^ Bounds
           -> [slix]                                     -- ^ All slices within bounds
enumSlices slix = map toElt . R.enumSlices slix . fromElt

-- | Shapes and indices of multi-dimensional arrays.
--
-- The superclass context requires that @sh@ and its 'Any' marker are both
-- 'Elt' types, and that, when @sh@ is read as a slice specifier, its
-- 'FullShape' and 'CoSliceShape' are @sh@ itself while its 'SliceShape' is 'Z'.
--
class ( Elt sh
      , Elt (Any sh)
      , FullShape sh ~ sh
      , CoSliceShape sh ~ sh
      , SliceShape sh ~ Z
      )
      => Shape sh where

  -- | Reified type witness for shapes
  shapeR         :: R.ShapeR (EltR sh)

  -- | The slice index for slice specifier 'Any sh'
  sliceAnyIndex  :: R.SliceIndex (EltR (Any sh)) (EltR sh) () (EltR sh)

  -- | The slice index for specifying a slice with only the Z component projected
  sliceNoneIndex :: R.SliceIndex (EltR sh) () (EltR sh) (EltR sh)


-- | Slices, aka generalised indices, as /n/-tuples and mappings of slice
-- indices to slices, co-slices, and slice dimensions.
--
-- Each instance names three shapes through associated type families and
-- provides the 'R.SliceIndex' witness relating their representations.
--
class ( Elt sl
      , Shape (SliceShape   sl)
      , Shape (CoSliceShape sl)
      , Shape (FullShape    sl)
      )
      => Slice sl where
  -- | The projected slice
  type SliceShape   sl :: Type
  -- | The complement of the slice
  type CoSliceShape sl :: Type
  -- | The combined dimension
  type FullShape    sl :: Type

  sliceIndex
      :: R.SliceIndex (EltR sl)
                      (EltR (SliceShape   sl))
                      (EltR (CoSliceShape sl))
                      (EltR (FullShape    sl))

-- | Generalised array division. This is like 'Slice', but it is used for
-- splitting an array into many subarrays rather than extracting a single
-- subarray.
--
-- 'DivisionSlice' names the 'Slice' type that describes one subarray, and
-- 'slicesIndex' is the slice-index witness for it. The method's type does
-- not determine @sl@, so callers supply it with a type application.
--
class Slice (DivisionSlice sl) => Division sl where
  -- | The slice
  type DivisionSlice sl :: Type

  slicesIndex
      :: slix ~ DivisionSlice sl
      => R.SliceIndex (EltR slix)
                      (EltR (SliceShape   slix))
                      (EltR (CoSliceShape slix))
                      (EltR (FullShape    slix))

-- | A snoc-cell is represented as the pair of its components'
-- representations. Its tags are every pairing of a tail tag with a head tag.
instance (Elt t, Elt h) => Elt (t :. h) where
  type EltR (t :. h) = (EltR t, EltR h)
  eltR           = TupRpair (eltR @t) (eltR @h)
  tagsR          = [TagRpair t h | t <- tagsR @t, h <- tagsR @h]
  fromElt (t:.h) = (fromElt t, fromElt h)
  toElt (t, h)   = toElt t :. toElt h

-- 'Any Z' uses the Generic-based defaults of 'Elt'.
instance Elt (Any Z)

-- | @Any (sh :. Int)@ is represented as the representation of @Any sh@
-- paired with a unit for the extra dimension. No value is stored, so
-- 'toElt' ignores its argument.
instance Shape sh => Elt (Any (sh :. Int)) where
  type EltR (Any (sh :. Int)) = (EltR (Any sh), ())
  eltR      = TupRpair (eltR @(Any sh)) TupRunit
  tagsR     = [TagRpair t TagRunit | t <- tagsR @(Any sh)]
  fromElt _ = (fromElt (Any :: Any sh), ())
  toElt _   = Any

-- | The rank-0 shape: the base case of every shape witness.
instance Shape Z where
  shapeR         = R.ShapeRz
  sliceAnyIndex  = R.SliceNil
  sliceNoneIndex = R.SliceNil

-- | Each extra 'Int' dimension extends the witnesses of the smaller shape:
-- 'Any' keeps the new dimension ('R.SliceAll'), and the \"none\" slice fixes
-- it ('R.SliceFixed').
instance Shape sh => Shape (sh:.Int) where
  shapeR         = R.ShapeRsnoc (shapeR @sh)
  sliceAnyIndex  = R.SliceAll   (sliceAnyIndex  @sh)
  sliceNoneIndex = R.SliceFixed (sliceNoneIndex @sh)

-- | The empty slice specifier: slice, co-slice and full shape are all 'Z'.
instance Slice Z where
  type SliceShape   Z = Z
  type CoSliceShape Z = Z
  type FullShape    Z = Z
  sliceIndex = R.SliceNil

-- | An 'All' component keeps the dimension: it extends both the slice shape
-- and the full shape, and leaves the co-slice unchanged.
instance Slice sl => Slice (sl:.All) where
  type SliceShape   (sl:.All) = SliceShape   sl :. Int
  type CoSliceShape (sl:.All) = CoSliceShape sl
  type FullShape    (sl:.All) = FullShape    sl :. Int
  sliceIndex = R.SliceAll (sliceIndex @sl)

-- | An 'Int' component fixes the dimension: it extends the co-slice and the
-- full shape, and leaves the slice shape unchanged.
instance Slice sl => Slice (sl:.Int) where
  type SliceShape   (sl:.Int) = SliceShape   sl
  type CoSliceShape (sl:.Int) = CoSliceShape sl :. Int
  type FullShape    (sl:.Int) = FullShape    sl :. Int
  sliceIndex = R.SliceFixed (sliceIndex @sl)

-- | @Any sh@ keeps every dimension of @sh@, so the co-slice is 'Z'. The
-- witness is the one provided by the 'Shape' class.
instance Shape sh => Slice (Any sh) where
  type SliceShape   (Any sh) = sh
  type CoSliceShape (Any sh) = Z
  type FullShape    (Any sh) = sh
  sliceIndex = sliceAnyIndex @sh

-- | Dividing along no dimensions: the division slice is 'Z'.
instance Division Z where
  type DivisionSlice Z = Z
  slicesIndex = R.SliceNil

-- | An 'All' component in a division descriptor becomes an 'All' component
-- of the division slice.
instance Division sl => Division (sl:.All) where
  type DivisionSlice (sl:.All) = DivisionSlice sl :. All
  slicesIndex = R.SliceAll (slicesIndex @sl)

-- | A 'Split' component becomes a fixed ('Int') component of the division
-- slice.
instance Division sl => Division (sl:.Split) where
  type DivisionSlice (sl:.Split) = DivisionSlice sl :. Int
  slicesIndex = R.SliceFixed (slicesIndex @sl)

-- | @Any sh@ divides into itself and reuses the 'Shape' class's
-- 'sliceAnyIndex' witness.
instance Shape sh => Division (Any sh) where
  type DivisionSlice (Any sh) = Any sh
  slicesIndex = sliceAnyIndex @sh

-- | @Divide sh@ uses @sh@ itself as the division slice, with the
-- 'sliceNoneIndex' witness from the 'Shape' class. 'Slice' @sh@ is required
-- by 'Division''s superclass on 'DivisionSlice'.
instance (Shape sh, Slice sh) => Division (Divide sh) where
  type DivisionSlice (Divide sh) = sh
  slicesIndex = sliceNoneIndex @sh