-- | Helpers for working with 'ShapeChange's: extracting their
-- dimensions, building reshape expressions, simplifying shape
-- changes, and converting indices between array shapes.
module Futhark.IR.Prop.Reshape
  ( -- * Basic tools
    newDim,
    newDims,
    newShape,
    -- * Construction
    shapeCoerce,
    -- * Reshaping outer or inner dimensions
    reshapeOuter,
    reshapeInner,
    -- * Inspection
    shapeCoercion,
    -- * Simplification
    fuseReshape,
    informReshape,
    -- * Index and size calculations
    reshapeIndex,
    flattenIndex,
    unflattenIndex,
    sliceSizes,
  )
where
import Data.Foldable
import Futhark.IR.Syntax
import Futhark.Util.IntegralExp
import Prelude hiding (product, quot, sum)

-- | The dimension carried by a 'DimChange', whether it is a
-- 'DimCoercion' or a 'DimNew'.
newDim :: DimChange d -> d
newDim (DimCoercion se) = se
newDim (DimNew se) = se

-- | The dimensions of a 'ShapeChange', obtained by applying 'newDim'
-- to every element.
newDims :: ShapeChange d -> [d]
newDims = map newDim

-- | The 'Shape' built from the dimensions ('newDims') of a
-- 'ShapeChange'.
newShape :: ShapeChange SubExp -> Shape
newShape = Shape . newDims

-- | @shapeCoerce newdims arr@ constructs a 'Reshape' expression on
-- @arr@ in which every dimension of @newdims@ is wrapped in
-- 'DimCoercion'.
shapeCoerce :: [SubExp] -> VName -> Exp rep
shapeCoerce newdims arr =
  BasicOp (Reshape (map DimCoercion newdims) arr)

-- | @reshapeOuter newshape n oldshape@ produces a shape change
-- consisting of @newshape@ followed by all but the first @n@
-- dimensions of @oldshape@.
--
-- The dimensions retained from @oldshape@ are marked 'DimCoercion'
-- when @newshape@ has exactly @n@ elements, and 'DimNew' otherwise.
reshapeOuter :: ShapeChange SubExp -> Int -> Shape -> ShapeChange SubExp
reshapeOuter newshape n oldshape =
  newshape ++ map coercion_or_new (drop n (shapeDims oldshape))
  where
    -- One constructor is chosen for all retained dimensions.
    coercion_or_new :: SubExp -> DimChange SubExp
    coercion_or_new
      | length newshape == n = DimCoercion
      | otherwise = DimNew

-- | @reshapeInner newshape n oldshape@ produces a shape change
-- consisting of the first @n@ dimensions of @oldshape@ followed by
-- @newshape@.
--
-- The dimensions retained from @oldshape@ are marked 'DimCoercion'
-- when the length of @newshape@ equals the rank of @oldshape@ minus
-- @n@, and 'DimNew' otherwise.
reshapeInner :: ShapeChange SubExp -> Int -> Shape -> ShapeChange SubExp
reshapeInner newshape n oldshape =
  map coercion_or_new (take n (shapeDims oldshape)) ++ newshape
  where
    -- One constructor is chosen for all retained dimensions.
    coercion_or_new :: SubExp -> DimChange SubExp
    coercion_or_new
      | length newshape == m - n = DimCoercion
      | otherwise = DimNew

    -- Rank of the original shape.
    m :: Int
    m = shapeRank oldshape

-- | If every element of the shape change is a 'DimCoercion', return
-- 'Just' the list of coerced dimensions.  If any element is a
-- 'DimNew', return 'Nothing'.
shapeCoercion :: ShapeChange d -> Maybe [d]
shapeCoercion = mapM dimCoercion
  where
    dimCoercion :: DimChange a -> Maybe a
    dimCoercion (DimCoercion d) = Just d
    dimCoercion (DimNew _) = Nothing

-- | @fuseReshape s1 s2@ combines two shape changes into one.
-- (Presumably @s1@ is the reshape applied first and @s2@ the one
-- applied to its result; confirm against callers.)
--
-- When the two have the same length they are combined pointwise:
--
-- * a 'DimNew' followed by a 'DimCoercion' becomes a 'DimNew' of the
--   second dimension;
--
-- * a 'DimCoercion' followed by a 'DimNew' becomes a 'DimCoercion' if
--   the two dimensions are equal, and stays a 'DimNew' otherwise;
--
-- * in every other case the second element is kept as is.
--
-- When the lengths differ the result is simply @s2@.
fuseReshape :: Eq d => ShapeChange d -> ShapeChange d -> ShapeChange d
fuseReshape s1 s2
  | length s1 == length s2 =
      zipWith comb s1 s2
  where
    comb :: Eq a => DimChange a -> DimChange a -> DimChange a
    comb (DimNew _) (DimCoercion d2) =
      DimNew d2
    comb (DimCoercion d1) (DimNew d2)
      | d1 == d2 = DimCoercion d2
      | otherwise = DimNew d2
    comb _ d2 =
      d2
fuseReshape _ s2 = s2

-- | @informReshape shape sc@ refines the shape change @sc@ using the
-- dimensions in @shape@.
--
-- When @shape@ and @sc@ have the same length, every 'DimNew' whose
-- dimension equals the corresponding element of @shape@ is turned
-- into a 'DimCoercion'; all other elements are left unchanged.  When
-- the lengths differ, @sc@ is returned unchanged.
informReshape :: Eq d => [d] -> ShapeChange d -> ShapeChange d
informReshape shape sc
  | length shape == length sc =
      zipWith inform shape sc
  where
    inform :: Eq a => a -> DimChange a -> DimChange a
    inform d1 (DimNew d2)
      | d1 == d2 = DimCoercion d2
    -- Also reached when the guard above fails.
    inform _ dc =
      dc
informReshape _ sc = sc

-- | @reshapeIndex to_dims from_dims is@ converts the index @is@ into
-- an array of shape @from_dims@ to an index into an array of shape
-- @to_dims@, by flattening with 'flattenIndex' and then unflattening
-- with 'unflattenIndex'.
reshapeIndex ::
  IntegralExp num =>
  [num] ->
  [num] ->
  [num] ->
  [num]
reshapeIndex to_dims from_dims is =
  unflattenIndex to_dims (flattenIndex from_dims is)

-- | @unflattenIndex dims i@ splits the flat index @i@ into one index
-- per dimension of @dims@.  The per-dimension divisors are the
-- 'sliceSizes' of @dims@ with the first element (the total size)
-- dropped.
unflattenIndex ::
  IntegralExp num =>
  [num] ->
  num ->
  [num]
unflattenIndex = unflattenIndexFromSlices . drop 1 . sliceSizes

-- | Split a flat index into per-dimension indices, given a list of
-- slice sizes.  For each slice size, the index component is the
-- quotient of the remaining flat index by that size, and the
-- remaining flat index is reduced by the part that quotient accounts
-- for.
unflattenIndexFromSlices ::
  IntegralExp num =>
  [num] ->
  num ->
  [num]
unflattenIndexFromSlices [] _ = []
unflattenIndexFromSlices (size : slices) i =
  q : unflattenIndexFromSlices slices (i - q * size)
  where
    -- Bound once; used both as the index component and to compute
    -- the remainder.
    q = i `quot` size

-- | @flattenIndex dims is@ computes the flat index corresponding to
-- the multidimensional index @is@ into an array of shape @dims@: the
-- sum of each index component multiplied by the slice size of its
-- dimension.
--
-- The pairing uses 'zipWith', so if @is@ and the slice sizes differ
-- in length the surplus elements of the longer list are ignored.
flattenIndex ::
  IntegralExp num =>
  [num] ->
  [num] ->
  num
flattenIndex dims is =
  sum (zipWith (*) is slicesizes)
  where
    -- Drop the leading element of 'sliceSizes' (the total size) so
    -- that each index lines up with the size of one step along its
    -- dimension.
    slicesizes = drop 1 (sliceSizes dims)

-- | Given a list of dimension sizes, return a list with one more
-- element in which each position holds the product of the sizes from
-- that position onwards; the final element is always 1.
--
-- For example, @sliceSizes [a, b, c]@ is the list of
-- @product [a, b, c]@, @product [b, c]@, @product [c]@ and @1@.
sliceSizes ::
  IntegralExp num =>
  [num] ->
  [num]
sliceSizes [] = [1]
sliceSizes (n : ns) =
  product (n : ns) : sliceSizes ns