module Data.Array.Repa.Operators.IndexSpace
( reshape
, append, (++)
, transpose
, extend
, slice
, backpermute, unsafeBackpermute
, backpermuteDft, unsafeBackpermuteDft)
where
import Data.Array.Repa.Index
import Data.Array.Repa.Slice
import Data.Array.Repa.Base
import Data.Array.Repa.Repr.Delayed
import Data.Array.Repa.Operators.Traversal
import Data.Array.Repa.Shape as S
import Prelude hiding ((++))
import qualified Prelude as P
-- | Module name used as the prefix of error messages raised in this module
--   (see 'reshape').
stage :: String
stage = "Data.Array.Repa.Operators.IndexSpace"
-- | Impose a new shape on the elements of an array.
--
--   Each index of the result is flattened with 'toIndex' against the new
--   shape, then expanded with 'fromIndex' against the source extent, and
--   the element is read with 'unsafeIndex'.
--
--   Calls 'error' when the new shape does not hold exactly as many
--   elements as the source array.
reshape :: ( Shape sh2, Shape sh1
           , Repr r1 e)
        => sh2
        -> Array r1 sh1 e
        -> Array D sh2 e
reshape sh2 arr
        | S.size sh2 /= S.size srcExtent
        = error $ stage P.++ ".reshape: reshaped array will not match size of the original"
        | otherwise
        = fromFunction sh2 (unsafeIndex arr . fromIndex srcExtent . toIndex sh2)
  where
    srcExtent = extent arr
-- | Append two arrays along their innermost dimension.
--
--   The result takes its outer shape from the first array, and its
--   innermost size is the sum of both arrays' innermost sizes. Innermost
--   indices below the first array's innermost size read from the first
--   array. The remaining indices read from the second array, shifted down
--   by that size.
--
--   NOTE(review): the two arrays' outer shapes are assumed to agree. This
--   is not checked here, and elements are fetched via 'unsafeTraverse2'.
append, (++)
        :: ( Shape sh
           , Repr r1 e, Repr r2 e)
        => Array r1 (sh :. Int) e
        -> Array r2 (sh :. Int) e
        -> Array D (sh :. Int) e
append arr1 arr2
        = unsafeTraverse2 arr1 arr2 fnExtent fnElem
  where
    -- Innermost size of the first array; the split point between inputs.
    (_ :. n) = extent arr1

    fnExtent (sh :. i) (_ :. j)
            = sh :. (i + j)

    fnElem f1 f2 (sh :. i)
            | i < n     = f1 (sh :. i)
            -- Past the end of arr1: translate into arr2's index space.
            | otherwise = f2 (sh :. (i - n))

-- Infix synonym for 'append'.
(++) arr1 arr2 = append arr1 arr2
-- | Swap the two innermost dimensions of an array.
--
--   The result extent has its two innermost sizes exchanged, and the
--   element at @(sh :. i :. j)@ is the source element at @(sh :. j :. i)@.
--   Elements are fetched through 'unsafeTraverse'.
transpose
        :: ( Shape sh
           , Repr r e)
        => Array r (sh :. Int :. Int) e
        -> Array D (sh :. Int :. Int) e
transpose arr = unsafeTraverse arr swapExtent swapLookup
  where
    swapExtent (sh :. rows :. cols) = sh :. cols :. rows
    swapLookup get (sh :. i :. j)   = get (sh :. j :. i)
-- | Extend an array to the full shape described by a slice specifier.
--
--   The result extent is @fullOfSlice sl@ applied to the source extent.
--   Each result index is projected back to a source index with
--   @sliceOfFull sl@. The lookup goes through 'unsafeBackpermute'.
extend
        :: ( Slice sl
           , Shape (FullShape sl)
           , Shape (SliceShape sl)
           , Repr r e)
        => sl
        -> Array r (SliceShape sl) e
        -> Array D (FullShape sl) e
extend sl arr = unsafeBackpermute bigExtent toSource arr
  where
    bigExtent = fullOfSlice sl (extent arr)
    toSource  = sliceOfFull sl
-- | Take a slice of an array, as described by a slice specifier.
--
--   The result extent is @sliceOfFull sl@ applied to the source extent.
--   Each result index is mapped back into the source with
--   @fullOfSlice sl@. The lookup goes through 'unsafeBackpermute'.
slice :: ( Slice sl
         , Shape (FullShape sl)
         , Shape (SliceShape sl)
         , Repr r e)
      => Array r (FullShape sl) e
      -> sl
      -> Array D (SliceShape sl) e
slice arr sl = unsafeBackpermute smallExtent toFull arr
  where
    smallExtent = sliceOfFull sl (extent arr)
    toFull      = fullOfSlice sl
-- | Build a delayed array of extent @newExtent@ whose element at index
--   @ix@ is the source element at @perm ix@.
--
--   'backpermute' is built on 'traverse' and 'unsafeBackpermute' on
--   'unsafeTraverse'. NOTE(review): the unsafe variant presumably skips
--   bounds checks; confirm against "Data.Array.Repa.Operators.Traversal".
backpermute, unsafeBackpermute
        :: forall r sh1 sh2 e
        .  ( Shape sh1, Shape sh2
           , Repr r e)
        => sh2
        -> (sh2 -> sh1)
        -> Array r sh1 e
        -> Array D sh2 e
backpermute newExtent perm arr
        = traverse arr (\_ -> newExtent) (\get -> get . perm)

unsafeBackpermute newExtent perm arr
        = unsafeTraverse arr (\_ -> newExtent) (\get -> get . perm)
-- | Backwards permutation with a default array.
--
--   The result has the extent of @arrDft@. For each index @ix@: if
--   @fnIndex ix@ is @Just ix'@, the element comes from @arrSrc@ at @ix'@.
--   Otherwise it comes from @arrDft@ at @ix@.
--
--   'backpermuteDft' reads with 'index' and 'unsafeBackpermuteDft' with
--   'unsafeIndex'. NOTE(review): presumably the checked vs. unchecked
--   lookups; confirm against "Data.Array.Repa.Base".
backpermuteDft, unsafeBackpermuteDft
        :: forall r0 r1 sh1 sh2 e
        .  ( Shape sh1, Shape sh2
           , Repr r0 e, Repr r1 e)
        => Array r0 sh2 e
        -> (sh2 -> Maybe sh1)
        -> Array r1 sh1 e
        -> Array D sh2 e
backpermuteDft arrDft fnIndex arrSrc
        = fromFunction (extent arrDft) pick
  where
    pick ix = maybe (arrDft `index` ix) (arrSrc `index`) (fnIndex ix)

unsafeBackpermuteDft arrDft fnIndex arrSrc
        = fromFunction (extent arrDft) pick
  where
    pick ix = maybe (arrDft `unsafeIndex` ix) (arrSrc `unsafeIndex`) (fnIndex ix)