module Data.Array.Accelerate.Math.DFT.Centre (
centre1D, centre2D, centre3D,
shift1D, shift2D, shift3D,
) where
import Prelude as P
import Data.Array.Accelerate as A
import Data.Array.Accelerate.Math.Complex
-- | Modulate a vector of complex samples with alternating signs: the element
-- at index @x@ is multiplied by @(-1)^x@ (as a complex number with zero
-- imaginary part). The shape of the result equals the shape of the input.
--
-- This is the usual pre-transform step for centring a DFT: transforming the
-- modulated signal yields a spectrum whose zero-frequency term sits in the
-- middle of the array rather than at index 0.
--
centre1D :: (Elt e, IsFloating e)
         => Acc (Array DIM1 (Complex e))
         -> Acc (Array DIM1 (Complex e))
centre1D arr
  = A.generate (shape arr)
               (\ix -> let Z :. x = unlift ix :: Z :. Exp Int
                       -- The base must be (-1): a base of 1 would make the
                       -- factor constant and the whole function an identity.
                       in  lift (((-1) ** A.fromIntegral x) :+ A.constant 0) * arr!ix)
-- | Modulate a matrix of complex samples with a checkerboard of signs: the
-- element at index @(y, x)@ is multiplied by @(-1)^(y + x)@ (as a complex
-- number with zero imaginary part). The shape of the result equals the shape
-- of the input.
--
-- This is the usual pre-transform step for centring a 2D DFT: transforming
-- the modulated signal yields a spectrum whose zero-frequency term sits in
-- the middle of the array rather than at the origin.
--
centre2D :: (Elt e, IsFloating e)
         => Acc (Array DIM2 (Complex e))
         -> Acc (Array DIM2 (Complex e))
centre2D arr
  = A.generate (shape arr)
               (\ix -> let Z :. y :. x = unlift ix :: Z :. Exp Int :. Exp Int
                       -- The base must be (-1): a base of 1 would make the
                       -- factor constant and the whole function an identity.
                       in  lift (((-1) ** A.fromIntegral (y + x)) :+ A.constant 0) * arr!ix)
-- | Modulate a cube of complex samples with a 3D checkerboard of signs: the
-- element at index @(z, y, x)@ is multiplied by @(-1)^(z + y + x)@ (as a
-- complex number with zero imaginary part). The shape of the result equals
-- the shape of the input.
--
-- This is the usual pre-transform step for centring a 3D DFT: transforming
-- the modulated signal yields a spectrum whose zero-frequency term sits in
-- the middle of the array rather than at the origin.
--
centre3D :: (Elt e, IsFloating e)
         => Acc (Array DIM3 (Complex e))
         -> Acc (Array DIM3 (Complex e))
centre3D arr
  = A.generate (shape arr)
               (\ix -> let Z :. z :. y :. x = unlift ix :: Z :. Exp Int :. Exp Int :. Exp Int
                       -- The base must be (-1): a base of 1 would make the
                       -- factor constant and the whole function an identity.
                       in  lift (((-1) ** A.fromIntegral (z + y + x)) :+ A.constant 0) * arr!ix)
-- | Swap the two halves of a vector. With @mw = w `div` 2@, output position
-- @x@ reads input position @x + mw@ when @x < mw@, and @x - mw@ otherwise.
-- The result has the same shape as the input.
--
-- For an even length this is an exact exchange of the lower and upper halves.
-- For an odd length the mapping is not a permutation: the last input element
-- is never read and one element is read twice (e.g. length 3 reads source
-- positions 1, 0, 1).
--
shift1D :: Elt e => Acc (Vector e) -> Acc (Vector e)
shift1D arr
  = A.backpermute (A.shape arr) p arr
  where
    -- Map an output index to the input index it is copied from.
    p ix
      = let Z:.x = unlift ix :: Z :. Exp Int
        in  index1 (swapHalf mw x)

    -- Source coordinate for output coordinate @i@ on an axis whose
    -- half-extent is @m@: the lower half reads from the upper half and
    -- vice versa.
    swapHalf :: Exp Int -> Exp Int -> Exp Int
    swapHalf m i = i <* m ? (i + m, i - m)

    Z:.w = unlift (A.shape arr)
    mw   = w `div` 2
-- | Swap the halves of a matrix along both axes, i.e. exchange diagonally
-- opposite quadrants. With @mh = h `div` 2@ and @mw = w `div` 2@, output
-- position @(y, x)@ reads the input at the row obtained by adding @mh@ to @y@
-- when @y < mh@ (subtracting otherwise), and likewise for the column with
-- @mw@. The result has the same shape as the input.
--
-- For even extents this is an exact exchange of quadrants. For an odd extent
-- the mapping along that axis is not a permutation: the last row/column of
-- the input is never read and one is read twice.
--
shift2D :: Elt e => Acc (Array DIM2 e) -> Acc (Array DIM2 e)
shift2D arr
  = A.backpermute (A.shape arr) p arr
  where
    -- Map an output index to the input index it is copied from.
    p ix
      = let Z:.y:.x = unlift ix :: Z :. Exp Int :. Exp Int
        in  index2 (swapHalf mh y)
                   (swapHalf mw x)

    -- Source coordinate for output coordinate @i@ on an axis whose
    -- half-extent is @m@: the lower half reads from the upper half and
    -- vice versa.
    swapHalf :: Exp Int -> Exp Int -> Exp Int
    swapHalf m i = i <* m ? (i + m, i - m)

    Z:.h:.w  = unlift (A.shape arr)
    (mh,mw)  = (h `div` 2, w `div` 2)
-- | Swap the halves of a 3D array along all three axes, i.e. exchange
-- diagonally opposite octants. For each axis with extent @n@ and
-- @m = n `div` 2@, output coordinate @i@ reads input coordinate @i + m@ when
-- @i < m@, and @i - m@ otherwise. The result has the same shape as the input.
--
-- For even extents this is an exact exchange of octants. For an odd extent
-- the mapping along that axis is not a permutation: the last slice of the
-- input is never read and one slice is read twice.
--
shift3D :: Elt e => Acc (Array DIM3 e) -> Acc (Array DIM3 e)
shift3D arr
  = A.backpermute (A.shape arr) p arr
  where
    -- Map an output index to the input index it is copied from.
    p ix
      = let Z:.z:.y:.x = unlift ix :: Z :. Exp Int :. Exp Int :. Exp Int
        in  index3 (swapHalf md z)
                   (swapHalf mh y)
                   (swapHalf mw x)

    -- Source coordinate for output coordinate @i@ on an axis whose
    -- half-extent is @m@: the lower half reads from the upper half and
    -- vice versa.
    swapHalf :: Exp Int -> Exp Int -> Exp Int
    swapHalf m i = i <* m ? (i + m, i - m)

    -- Extents are bound outer-to-inner, in the same order as the index
    -- components (z, y, x), so each coordinate is compared against half of
    -- its own axis. Binding them in any other order pairs coordinates with
    -- the wrong extents on non-cubic arrays.
    Z:.d:.h:.w  = unlift (A.shape arr)
    (md,mh,mw)  = (d `div` 2, h `div` 2, w `div` 2)

    -- Build a lifted rank-3 index from its three components.
    index3 i j k = lift (Z:.i:.j:.k)