-- | Facilities for creating, inspecting, and simplifying reshape and
-- coercion operations.
module Futhark.IR.Prop.Reshape
  ( -- * Construction
    shapeCoerce,

    -- * Execution
    reshapeOuter,
    reshapeInner,

    -- * Simplification

    -- * Shape calculations
    reshapeIndex,
    flattenIndex,
    unflattenIndex,
    sliceSizes,
  )
where

import Data.Foldable
import Futhark.IR.Syntax
import Futhark.Util.IntegralExp
import Prelude hiding (product, quot, sum)

-- | Construct a 'Reshape' that is a 'ReshapeCoerce'.
--
-- @shapeCoerce newdims arr@ builds a 'BasicOp' expression that
-- reshapes the array @arr@ to the shape given by @newdims@, using the
-- 'ReshapeCoerce' reshape kind.
shapeCoerce :: [SubExp] -> VName -> Exp rep
shapeCoerce newdims arr =
  BasicOp $ Reshape ReshapeCoerce (Shape newdims) arr

-- | @reshapeOuter newshape n oldshape@ returns the shape obtained by
-- replacing the outer @n@ dimensions of @oldshape@ with @newshape@.
-- The remaining inner dimensions of @oldshape@ (all but the first @n@)
-- follow the dimensions of @newshape@ unchanged.  If @n@ is at least
-- the rank of @oldshape@, the result is just @newshape@.
reshapeOuter :: Shape -> Int -> Shape -> Shape
reshapeOuter newshape n oldshape =
  newshape <> Shape (drop n (shapeDims oldshape))

-- | @reshapeInner newshape n oldshape@ returns the shape obtained by
-- keeping the outer @n@ dimensions of @oldshape@ and replacing the
-- inner @m-n@ dimensions (where @m@ is the rank of @oldshape@) with
-- @newshape@.  If @n@ is at least @m@, all of @oldshape@ is kept and
-- @newshape@ is appended after it.
reshapeInner :: Shape -> Int -> Shape -> Shape
reshapeInner newshape n oldshape =
  Shape (take n (shapeDims oldshape)) <> newshape

-- | @reshapeIndex to_dims from_dims is@ transforms the index list
-- @is@ (which is into an array of shape @from_dims@) into an index
-- list @is'@, which is into an array of shape @to_dims@.  @is@ must
-- have the same length as @from_dims@, and @is'@ will have the same
-- length as @to_dims@.
--
-- The translation goes through the row-major flat index: @is@ is
-- flattened with respect to @from_dims@ using 'flattenIndex', and the
-- result is unflattened with respect to @to_dims@ using
-- 'unflattenIndex'.
reshapeIndex ::
  IntegralExp num =>
  [num] ->
  [num] ->
  [num] ->
  [num]
reshapeIndex to_dims from_dims is =
  unflattenIndex to_dims $ flattenIndex from_dims is

-- | @unflattenIndex dims i@ computes a list of indices into an array
-- with dimension @dims@ given the flat index @i@.  The resulting list
-- will have the same size as @dims@.
--
-- 'sliceSizes' yields one more element than @dims@ (the total size
-- first).  Dropping that element leaves exactly one slice size per
-- dimension, which 'unflattenIndexFromSlices' turns into one index per
-- dimension.
unflattenIndex ::
  IntegralExp num =>
  [num] ->
  num ->
  [num]
unflattenIndex = unflattenIndexFromSlices . drop 1 . sliceSizes

-- | @unflattenIndexFromSlices slices i@ decomposes the flat index @i@
-- into one index per element of @slices@, outermost first.  Each
-- element of @slices@ is the number of flat positions covered by one
-- step along the corresponding dimension; 'unflattenIndex' obtains
-- this list as @drop 1 (sliceSizes dims)@.  The result has the same
-- length as @slices@.
unflattenIndexFromSlices ::
  IntegralExp num =>
  [num] ->
  num ->
  [num]
unflattenIndexFromSlices [] _ = []
unflattenIndexFromSlices (size : slices) i =
  -- The offset passed on to the inner dimensions is what is left of
  -- @i@ after removing @q@ whole slices.
  -- NOTE(review): this is spelled @i - q * size@ rather than with
  -- 'rem'; presumably deliberate -- confirm before simplifying.
  q : unflattenIndexFromSlices slices (i - q * size)
  where
    -- Index along the current dimension.
    q = i `quot` size

-- | @flattenIndex dims is@ computes the flat index of @is@ into an
-- array with dimensions @dims@.  The length of @dims@ and @is@ must
-- be the same.
--
-- The flat index is row-major.  Each index is multiplied by the
-- product of all dimensions inside its own dimension (1 for the
-- innermost dimension), and the products are summed.  If the lengths
-- differ, 'zipWith' silently ignores the excess elements of the longer
-- list.
flattenIndex ::
  IntegralExp num =>
  [num] ->
  [num] ->
  num
flattenIndex dims is =
  sum $ zipWith (*) is slicesizes
  where
    -- Drop the total size (the first element of 'sliceSizes'), leaving
    -- one stride per dimension.
    slicesizes = drop 1 $ sliceSizes dims

-- | Given a length @n@ list of dimensions @dims@, @sliceSizes dims@
-- will compute a length @n+1@ list of the size of each possible array
-- slice.  The first element of this list will be the product of
-- @dims@, and the last element will be 1.
--
-- For example, @sliceSizes [a, b, c]@ is
-- @[product [a, b, c], product [b, c], product [c], 1]@.
sliceSizes ::
  IntegralExp num =>
  [num] ->
  [num]
sliceSizes [] = [1]
sliceSizes (n : ns) =
  -- NOTE(review): each suffix product is computed independently, which
  -- is quadratic in the rank.  The HLINT ignore pragma for this
  -- function suggests the form is intentional (e.g. for the shape of
  -- symbolic expressions); confirm before rewriting with a scan.
  product (n : ns) : sliceSizes ns

{- HLINT ignore sliceSizes -}