{-# LANGUAGE ConstraintKinds  #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE TypeOperators    #-}
-- |
-- Module      : Data.Array.Accelerate.Math.DFT.Centre
-- Copyright   : [2012..2017] Manuel M T Chakravarty, Gabriele Keller, Trevor L. McDonell
--               [2013..2017] Robert Clifton-Everest
-- License     : BSD3
--
-- Maintainer  : Trevor L. McDonell <tmcdonell@cse.unsw.edu.au>
-- Stability   : experimental
-- Portability : non-portable (GHC extensions)
--
-- These transforms allow the frequency domain of a DFT to be centred, so
-- that the zero frequency lies in the middle. The centring transform, when
-- applied to the input of a DFT, causes the zero frequency to be centred in
-- the middle of the output. The shifting transform instead rearranges the
-- output of a DFT to give the same result. Therefore, for even-length
-- dimensions, the relationship between the two is:
--
-- > fft(centre(X)) = shift(fft(X))
--
module Data.Array.Accelerate.Math.DFT.Centre (

  centre1D, centre2D, centre3D,
  shift1D,  shift2D,  shift3D,

) where

import Prelude                                  as P
import Data.Array.Accelerate                    as A
import Data.Array.Accelerate.Data.Complex


-- | Apply the centring transform to a vector.
--
-- The element at index @x@ is multiplied by @(-1)^x@. Elements at odd
-- indices are negated and elements at even indices are left unchanged.
--
-- The sign is taken from the parity of the integer index rather than from
-- @(-1) ** fromIntegral x@. Converting a large index to a floating-point
-- exponent can round it (e.g. above 2^24 for 'Float'), which would select
-- the wrong sign. Using the parity also avoids a @pow@ and a complex
-- multiplication per element, and keeps non-finite components intact.
--
-- NOTE: the 'A.FromIntegral' constraint is no longer needed by the
-- implementation; it is retained so the signature stays unchanged.
centre1D :: (A.RealFloat e, A.FromIntegral Int e)
         => Acc (Array DIM1 (Complex e))
         -> Acc (Array DIM1 (Complex e))
centre1D arr
  = A.generate (shape arr)
               (\ix -> let Z :. x = unlift ix :: Z :. Exp Int
                           v      = arr ! ix
                       in  x `rem` 2 A.== 0 ? (v, negate v))

-- | Apply the centring transform to a matrix.
--
-- The element at index @(y, x)@ is multiplied by @(-1)^(y + x)@. Elements
-- whose coordinate sum is odd are negated; all others are left unchanged.
--
-- The sign is taken from the parity of the integer coordinate sum rather
-- than from @(-1) ** fromIntegral (y + x)@. Converting a large index to a
-- floating-point exponent can round it (e.g. above 2^24 for 'Float'), which
-- would select the wrong sign. Using the parity also avoids a @pow@ and a
-- complex multiplication per element.
--
-- NOTE: the 'A.FromIntegral' constraint is no longer needed by the
-- implementation; it is retained so the signature stays unchanged.
centre2D :: (A.RealFloat e, A.FromIntegral Int e)
         => Acc (Array DIM2 (Complex e))
         -> Acc (Array DIM2 (Complex e))
centre2D arr
  = A.generate (shape arr)
               (\ix -> let Z :. y :. x = unlift ix :: Z :. Exp Int :. Exp Int
                           v           = arr ! ix
                       in  (y + x) `rem` 2 A.== 0 ? (v, negate v))

-- | Apply the centring transform to a 3D array.
--
-- The element at index @(z, y, x)@ is multiplied by @(-1)^(z + y + x)@.
-- Elements whose coordinate sum is odd are negated; all others are left
-- unchanged.
--
-- The sign is taken from the parity of the integer coordinate sum rather
-- than from a floating-point @(-1) ** fromIntegral (z + y + x)@. That
-- conversion can round large indices (e.g. above 2^24 for 'Float') and pick
-- the wrong sign. Using the parity also avoids a @pow@ and a complex
-- multiplication per element.
--
-- NOTE: the 'A.FromIntegral' constraint is no longer needed by the
-- implementation; it is retained so the signature stays unchanged.
centre3D :: (A.RealFloat e, A.FromIntegral Int e)
         => Acc (Array DIM3 (Complex e))
         -> Acc (Array DIM3 (Complex e))
centre3D arr
  = A.generate (shape arr)
               (\ix -> let Z :. z :. y :. x = unlift ix :: Z :. Exp Int :. Exp Int :. Exp Int
                           v                = arr ! ix
                       in  (z + y + x) `rem` 2 A.== 0 ? (v, negate v))


-- | Apply the shifting transform to a vector.
--
-- The vector is rotated so that output element @i@ is read from input index
-- @(i + (n - n `div` 2)) `mod` n@, where @n@ is the vector length.
--
-- * For even @n@ this swaps the two halves of the vector.
-- * For odd @n@ the result is still a permutation, with the element at
--   index 0 moved to index @n `div` 2@.
--
-- Swapping halves by @n `div` 2@ alone would duplicate one element and drop
-- the last one when @n@ is odd.
shift1D :: Elt e => Acc (Vector e) -> Acc (Vector e)
shift1D arr
  = A.backpermute (A.shape arr) p arr
  where
    p ix
      = let Z:.x = unlift ix :: Z :. Exp Int
        in index1 (rot w x)
    Z:.w = unlift (A.shape arr)

    -- Source index for destination @i@ along an axis of extent @n@, i.e.
    -- @(i + (n - n `div` 2)) `mod` n@ computed without a division by @n@.
    rot :: Exp Int -> Exp Int -> Exp Int
    rot n i = i A.< half ? (i + (n - half), i - half)
      where
        half = n `div` 2


-- | Apply the shifting transform to a 2D array.
--
-- Each axis is rotated independently by half of its own extent. Along an
-- axis of extent @n@, output coordinate @i@ is read from input coordinate
-- @(i + (n - n `div` 2)) `mod` n@.
--
-- * For even extents this swaps the halves, exchanging opposite quadrants.
-- * For odd extents it still yields a permutation, whereas a plain
--   half-swap would duplicate one row/column and drop another.
shift2D :: Elt e => Acc (Array DIM2 e) -> Acc (Array DIM2 e)
shift2D arr
  = A.backpermute (A.shape arr) p arr
  where
    p ix
      = let Z:.y:.x = unlift ix :: Z :. Exp Int :. Exp Int
        in index2 (rot h y) (rot w x)
    Z:.h:.w = unlift (A.shape arr)

    -- Source index for destination @i@ along an axis of extent @n@, i.e.
    -- @(i + (n - n `div` 2)) `mod` n@ computed without a division by @n@.
    rot :: Exp Int -> Exp Int -> Exp Int
    rot n i = i A.< half ? (i + (n - half), i - half)
      where
        half = n `div` 2


-- | Apply the shifting transform to a 3D array.
--
-- Each axis is rotated independently by half of its own extent. Along an
-- axis of extent @n@, output coordinate @i@ is read from input coordinate
-- @(i + (n - n `div` 2)) `mod` n@.
--
-- * For even extents this swaps the halves of each axis.
-- * For odd extents it still yields a permutation.
--
-- The extents are bound in the same order as the index (@Z:.z:.y:.x@), so
-- each coordinate is rotated by its own axis' half-extent. This is required
-- for arrays whose extents differ.
shift3D :: Elt e => Acc (Array DIM3 e) -> Acc (Array DIM3 e)
shift3D arr
  = A.backpermute (A.shape arr) p arr
  where
    p ix
      = let Z:.z:.y:.x = unlift ix :: Z :. Exp Int :. Exp Int :. Exp Int
        in index3 (rot d z) (rot h y) (rot w x)
    -- d: extent of the z axis, h: of the y axis, w: of the x axis
    Z:.d:.h:.w = unlift (A.shape arr)

    -- Source index for destination @i@ along an axis of extent @n@, i.e.
    -- @(i + (n - n `div` 2)) `mod` n@ computed without a division by @n@.
    rot :: Exp Int -> Exp Int -> Exp Int
    rot n i = i A.< half ? (i + (n - half), i - half)
      where
        half = n `div` 2