module Data.Array.Repa.Operators.IndexSpace
( reshape
, append, (++)
, transpose
, extend
, slice
, backpermute
, backpermuteDft)
where
import Data.Array.Repa.Index
import Data.Array.Repa.Slice
import Data.Array.Repa.Internals.Elt
import Data.Array.Repa.Internals.Base
import Data.Array.Repa.Operators.Traverse
import Data.Array.Repa.Shape as S
import Prelude hiding ((++))
import qualified Prelude as P
-- | Name of this module, prepended to the messages of every 'error'
--   raised by the operators defined here.
stage :: String
stage = "Data.Array.Repa.Operators.IndexSpace"
-- | Give an array a new extent without moving any of its elements.
--
--   The new extent @sh'@ must describe the same total number of elements as
--   the source extent, otherwise this calls 'error'.  Only an array made of a
--   single 'RangeAll' region can be reshaped; any other region layout is
--   also rejected with 'error'.
--
--   A manifest vector is reused as-is.  For a cursor-based generator, the new
--   cursor is the target index itself.  Each load converts that index to a
--   linear position with 'toIndex' (in the new shape) and back to a source
--   index with 'fromIndex' (in the old shape).
reshape :: (Shape sh, Shape sh', Elt a)
        => sh'
        -> Array sh a
        -> Array sh' a
reshape sh' arr
        | S.size sh' /= S.size (extent arr)
        = error $ stage P.++ ".reshape: reshaped array will not match size of the original"

        | Array sh [Region RangeAll gen] <- arr
        = let regen = case gen of
                GenManifest vec
                  -> GenManifest vec
                GenCursor makeCursor _ loadElem
                  -> GenCursor id addDim
                        (loadElem . makeCursor . fromIndex sh . toIndex sh')
          in  Array sh' [Region RangeAll regen]

        | otherwise
        = error $ stage P.++ ".reshape: can't reshape a partitioned array"
-- | Append two arrays along their innermost dimension.
--
--   The innermost extent of the result is the sum of the two inputs'
--   innermost extents.  All outer dimensions are taken from the first array.
--   This function does not check that the outer dimensions of both arrays
--   agree; callers are assumed to guarantee it.
--
--   Innermost indices below the first array's innermost extent @n@ read from
--   the first array.  The remaining indices read from the second array,
--   shifted down by @n@.
append, (++)
        :: (Shape sh, Elt a)
        => Array (sh :. Int) a
        -> Array (sh :. Int) a
        -> Array (sh :. Int) a

append arr1 arr2
        = unsafeTraverse2 arr1 arr2 fnExtent fnElem
        where
          -- Innermost extent of the first array: the split point between
          -- elements drawn from arr1 and those drawn from arr2.
          (_ :. n) = extent arr1

          fnExtent (sh :. i) (_ :. j)
                = sh :. (i + j)

          fnElem f1 f2 (sh :. i)
                | i < n     = f1 (sh :. i)
                -- Rebase the index into the second array's coordinates.
                | otherwise = f2 (sh :. (i - n))

-- Infix synonym for 'append'.
(++) arr1 arr2 = append arr1 arr2
-- | Swap the two innermost dimensions of an array.
--
--   Any outer dimensions are carried through unchanged.  The element at
--   @rest :. r :. c@ of the result is the element at @rest :. c :. r@ of
--   the source.
transpose
        :: (Shape sh, Elt a)
        => Array (sh :. Int :. Int) a
        -> Array (sh :. Int :. Int) a
transpose arr
        = unsafeTraverse arr swapExtent swapLookup
        where
          swapExtent (rest :. rows :. cols) = rest :. cols :. rows
          swapLookup get (rest :. r :. c)   = get (rest :. c :. r)
-- | Build an array of shape @FullShape sl@ from an array of shape
--   @SliceShape sl@, as directed by the slice specifier @sl@.
--
--   The result extent is @fullOfSlice sl@ applied to the source extent.
--   Each result element at index @ix@ is read from the source at
--   @sliceOfFull sl ix@.
extend
        :: ( Slice sl
           , Shape (FullShape sl)
           , Shape (SliceShape sl)
           , Elt e)
        => sl
        -> Array (SliceShape sl) e
        -> Array (FullShape sl) e
extend sl arr
        = backpermute fullExtent toSliceIx arr
        where
          fullExtent = fullOfSlice sl (extent arr)
          toSliceIx  = sliceOfFull sl
-- | Extract an array of shape @SliceShape sl@ from an array of shape
--   @FullShape sl@, as directed by the slice specifier @sl@.
--
--   The result extent is @sliceOfFull sl@ applied to the source extent.
--   Each result element at index @ix@ is read from the source at
--   @fullOfSlice sl ix@.
slice   :: ( Slice sl
           , Shape (FullShape sl)
           , Shape (SliceShape sl)
           , Elt e)
        => Array (FullShape sl) e
        -> sl
        -> Array (SliceShape sl) e
slice arr sl
        = backpermute sliceExtent toFullIx arr
        where
          sliceExtent = sliceOfFull sl (extent arr)
          toFullIx    = fullOfSlice sl
-- | Build an array of extent @newExtent@ whose element at index @ix@ is the
--   source element at @perm ix@.
--
--   The permutation maps each index of the result back to an index of the
--   source.  Elements are fetched through 'traverse'.
backpermute
        :: forall sh sh' a
        .  (Shape sh, Shape sh', Elt a)
        => sh'
        -> (sh' -> sh)
        -> Array sh a
        -> Array sh' a
backpermute newExtent perm arr
        = traverse arr
                (\_ -> newExtent)
                (\lookupOld ix -> lookupOld (perm ix))
-- | Backwards permutation with a default array.
--
--   The result has the extent of @arrDft@.  For each index @ix@:
--
--   * if @fnIndex ix@ is @Just ix'@, the element is @arrSrc ! ix'@;
--
--   * otherwise, the element is @arrDft ! ix@.
backpermuteDft
        :: forall sh sh' a
        .  (Shape sh, Shape sh', Elt a)
        => Array sh' a
        -> (sh' -> Maybe sh)
        -> Array sh a
        -> Array sh' a
backpermuteDft arrDft fnIndex arrSrc
        = fromFunction (extent arrDft) pick
        where
          pick ix = maybe (arrDft ! ix) (arrSrc !) (fnIndex ix)