module Futhark.Representation.AST.Attributes.Reshape ( -- * Basic tools newDim , newDims , newShape -- * Construction , shapeCoerce , repeatShapes -- * Execution , reshapeOuter , reshapeInner , repeatDims -- * Inspection , shapeCoercion -- * Simplification , fuseReshape , informReshape -- * Shape calculations , reshapeIndex , flattenIndex , unflattenIndex , sliceSizes ) where import Data.Foldable import Prelude hiding (sum, product, quot) import Futhark.Representation.AST.Attributes.Types import Futhark.Representation.AST.Syntax import Futhark.Util.IntegralExp -- | The new dimension. newDim :: DimChange d -> d newDim (DimCoercion se) = se newDim (DimNew se) = se -- | The new dimensions resulting from a reshape operation. newDims :: ShapeChange d -> [d] newDims = map newDim -- | The new shape resulting from a reshape operation. newShape :: ShapeChange SubExp -> Shape newShape = Shape . newDims -- ^ Construct a 'Reshape' where all dimension changes are -- 'DimCoercion's. shapeCoerce :: [SubExp] -> VName -> Exp lore shapeCoerce newdims arr = BasicOp $ Reshape (map DimCoercion newdims) arr -- | Construct a pair suitable for a 'Repeat'. repeatShapes :: [Shape] -> Type -> ([Shape], Shape) repeatShapes shapes t = case splitAt t_rank shapes of (outer_shapes, [inner_shape]) -> (outer_shapes, inner_shape) _ -> (shapes ++ replicate (length shapes - t_rank) (Shape []), Shape []) where t_rank = arrayRank t -- | Modify the shape of an array type as 'Repeat' would do repeatDims :: [Shape] -> Shape -> Type -> Type repeatDims shape innershape = modifyArrayShape repeatDims' where repeatDims' (Shape ds) = Shape $ concat (zipWith (++) (map shapeDims shape) (map pure ds)) ++ shapeDims innershape -- | @reshapeOuter newshape n oldshape@ returns a 'Reshape' expression -- that replaces the outer @n@ dimensions of @oldshape@ with @newshape@. reshapeOuter :: ShapeChange SubExp -> Int -> Shape -> ShapeChange SubExp reshapeOuter newshape n oldshape = newshape ++ map coercion_or_new (drop n (shapeDims oldshape)) where coercion_or_new | length newshape == n = DimCoercion | otherwise = DimNew -- | @reshapeInner newshape n oldshape@ returns a 'Reshape' expression -- that replaces the inner @m-n@ dimensions (where @m@ is the rank of -- @oldshape@) of @src@ with @newshape@. reshapeInner :: ShapeChange SubExp -> Int -> Shape -> ShapeChange SubExp reshapeInner newshape n oldshape = map coercion_or_new (take n (shapeDims oldshape)) ++ newshape where coercion_or_new | length newshape == m-n = DimCoercion | otherwise = DimNew m = shapeRank oldshape -- | If the shape change is nothing but shape coercions, return the new dimensions. Otherwise, return -- 'Nothing'. shapeCoercion :: ShapeChange d -> Maybe [d] shapeCoercion = mapM dimCoercion where dimCoercion (DimCoercion d) = Just d dimCoercion (DimNew _) = Nothing -- | @fuseReshape s1 s2@ creates a new 'ShapeChange' that is -- semantically the same as first applying @s1@ and then @s2@. This -- may take advantage of properties of 'DimCoercion' versus 'DimNew' -- to preserve information. fuseReshape :: Eq d => ShapeChange d -> ShapeChange d -> ShapeChange d fuseReshape s1 s2 | length s1 == length s2 = zipWith comb s1 s2 where comb (DimNew _) (DimCoercion d2) = DimNew d2 comb (DimCoercion d1) (DimNew d2) | d1 == d2 = DimCoercion d2 | otherwise = DimNew d2 comb _ d2 = d2 -- TODO: intelligently handle case where s1 is a prefix of s2. fuseReshape _ s2 = s2 -- | Given concrete information about the shape of the source array, -- convert some 'DimNew's into 'DimCoercion's. informReshape :: Eq d => [d] -> ShapeChange d -> ShapeChange d informReshape shape sc | length shape == length sc = zipWith inform shape sc where inform d1 (DimNew d2) | d1 == d2 = DimCoercion d2 inform _ dc = dc informReshape _ sc = sc -- | @reshapeIndex to_dims from_dims is@ transforms the index list -- @is@ (which is into an array of shape @from_dims@) into an index -- list @is'@, which is into an array of shape @to_dims@. @is@ must -- have the same length as @from_dims@, and @is'@ will have the same -- length as @to_dims@. reshapeIndex :: IntegralExp num => [num] -> [num] -> [num] -> [num] reshapeIndex to_dims from_dims is = unflattenIndex to_dims $ flattenIndex from_dims is -- | @unflattenIndex dims i@ computes a list of indices into an array -- with dimension @dims@ given the flat index @i@. The resulting list -- will have the same size as @dims@. unflattenIndex :: IntegralExp num => [num] -> num -> [num] unflattenIndex = unflattenIndexFromSlices . drop 1 . sliceSizes unflattenIndexFromSlices :: IntegralExp num => [num] -> num -> [num] unflattenIndexFromSlices [] _ = [] unflattenIndexFromSlices (size : slices) i = (i `quot` size) : unflattenIndexFromSlices slices (i - (i `quot` size) * size) -- | @flattenIndex dims is@ computes the flat index of @is@ into an -- array with dimensions @dims@. The length of @dims@ and @is@ must -- be the same. flattenIndex :: IntegralExp num => [num] -> [num] -> num flattenIndex dims is = sum $ zipWith (*) is slicesizes where slicesizes = drop 1 $ sliceSizes dims -- | Given a length @n@ list of dimensions @dims@, @sizeSizes dims@ -- will compute a length @n+1@ list of the size of each possible array -- slice. The first element of this list will be the product of -- @dims@, and the last element will be 1. sliceSizes :: IntegralExp num => [num] -> [num] sliceSizes [] = [1] sliceSizes (n:ns) = product (n : ns) : sliceSizes ns {- HLINT ignore sliceSizes -}