module Futhark.IR.Prop.Reshape
(
shapeCoerce,
reshapeOuter,
reshapeInner,
reshapeIndex,
flattenIndex,
unflattenIndex,
sliceSizes,
)
where
import Data.Foldable
import Futhark.IR.Syntax
import Futhark.Util.IntegralExp
import Prelude hiding (product, quot, sum)
-- | Construct a 'Reshape' expression of kind 'ReshapeCoerce' that
-- reshapes the array @arr@ to the dimensions @newdims@.
--
-- NOTE(review): the exact guarantees implied by 'ReshapeCoerce' (e.g.
-- whether the element count is assumed unchanged) are defined in
-- "Futhark.IR.Syntax", not here — confirm there before relying on them.
shapeCoerce :: [SubExp] -> VName -> Exp rep
shapeCoerce newdims arr =
  BasicOp $ Reshape ReshapeCoerce (Shape newdims) arr
-- | @reshapeOuter newshape n oldshape@ replaces the outermost @n@
-- dimensions of @oldshape@ with @newshape@, keeping the remaining
-- (inner) dimensions of @oldshape@ unchanged.  If @n@ is at least the
-- rank of @oldshape@, the result is just @newshape@.
reshapeOuter :: Shape -> Int -> Shape -> Shape
reshapeOuter newshape n oldshape =
  newshape <> Shape (drop n (shapeDims oldshape))
-- | @reshapeInner newshape n oldshape@ keeps the outermost @n@
-- dimensions of @oldshape@ and replaces all remaining (inner)
-- dimensions with @newshape@.
reshapeInner :: Shape -> Int -> Shape -> Shape
reshapeInner newshape n oldshape =
  Shape (take n (shapeDims oldshape)) <> newshape
-- | @reshapeIndex to_dims from_dims is@ transforms the index list @is@,
-- which indexes an array of shape @from_dims@, into an index list that
-- indexes an array of shape @to_dims@.  It does this by going through
-- the corresponding flat index.  @is@ must have the same length as
-- @from_dims@ (otherwise 'flattenIndex' raises an error), and the result
-- has the same length as @to_dims@.
reshapeIndex ::
  (IntegralExp num) =>
  [num] ->
  [num] ->
  [num] ->
  [num]
reshapeIndex to_dims from_dims is =
  unflattenIndex to_dims $ flattenIndex from_dims is
-- | Given a list of dimensions @dims@, @unflattenIndex dims i@ computes
-- the list of indices (one per dimension) that the flat index @i@
-- corresponds to in an array of dimensions @dims@, using the
-- per-dimension strides produced by 'sliceSizes' (the outermost total
-- size is dropped, leaving one stride per dimension).
unflattenIndex ::
  (IntegralExp num) =>
  [num] ->
  num ->
  [num]
unflattenIndex = unflattenIndexFromSlices . drop 1 . sliceSizes
-- | Decompose a flat index into per-dimension indices, given the stride
-- (slice size) of each dimension, outermost first.  Each index is the
-- quotient of the remaining flat index by that dimension's stride; the
-- remainder (computed as @i - q * size@) is passed on to the inner
-- dimensions.  Produces one index per stride.
unflattenIndexFromSlices ::
  (IntegralExp num) =>
  [num] ->
  num ->
  [num]
unflattenIndexFromSlices [] _ = []
unflattenIndexFromSlices (size : slices) i =
  q : unflattenIndexFromSlices slices (i - q * size)
  where
    q = i `quot` size
-- | Given a list of dimensions @dims@ and a list of indices @is@ of the
-- same length, @flattenIndex dims is@ computes the corresponding flat
-- index: the sum of each index multiplied by its dimension's stride
-- (the strides come from 'sliceSizes', with the total size dropped).
--
-- Calls 'error' if @is@ and @dims@ have different lengths.
flattenIndex ::
  (IntegralExp num) =>
  [num] ->
  [num] ->
  num
flattenIndex dims is
  | length is /= length slicesizes = error "flattenIndex: length mismatch"
  | otherwise = sum $ zipWith (*) is slicesizes
  where
    -- One stride per dimension, so this has the same length as @dims@.
    slicesizes = drop 1 $ sliceSizes dims
-- | Given a length @n@ list of dimensions @dims@, @sliceSizes dims@
-- computes a length @n+1@ list holding the product of every suffix of
-- @dims@, longest suffix first.  The first element is the product of
-- all of @dims@ and the last element is @1@.
--
-- NOTE(review): each suffix product is computed independently with
-- 'product' rather than shared via 'scanr'.  This is kept as-is
-- because, for symbolic @num@ types, a different multiplication
-- order/association could yield a structurally different expression.
sliceSizes ::
  (IntegralExp num) =>
  [num] ->
  [num]
sliceSizes [] = [1]
sliceSizes (n : ns) =
  product (n : ns) : sliceSizes ns