```{-# LANGUAGE TypeOperators, PatternGuards, RankNTypes, ScopedTypeVariables, BangPatterns, FlexibleContexts #-}
{-# OPTIONS -fno-warn-incomplete-patterns #-}

-- | Fast computation of Discrete Fourier Transforms using the Cooley-Tukey algorithm.
--   Time complexity is O(n log n) in the size of the input.
--
--   This uses a naive divide-and-conquer algorithm; its absolute performance is about
--   50x slower than FFTW in estimate mode.
--
module Data.Array.Repa.Algorithms.FFT
( Mode(..)
, isPowerOfTwo
, fft3dP
, fft2dP
, fft1dP)
where
import Data.Array.Repa.Algorithms.Complex
import Data.Array.Repa				as R
import Data.Array.Repa.Eval                     as R
import Data.Array.Repa.Unsafe                   as R
import Prelude                                  as P

-- | Which variant of the discrete Fourier transform to compute.
data Mode
  = Forward   -- ^ Forward transform: twiddle factors use sign @-1@; no scaling.
  | Reverse   -- ^ Reverse transform: twiddle factors use sign @+1@; no scaling.
  | Inverse   -- ^ Inverse transform: like 'Reverse', then every element is
              --   divided by the total number of elements in the array.
  deriving (Eq, Show)

-- | Sign of the imaginary part of every twiddle factor used by a transform:
--   @-1@ for 'Forward', @+1@ for both 'Reverse' and 'Inverse'.
signOfMode :: Mode -> Double
signOfMode Forward = -1
signOfMode Reverse = 1
signOfMode Inverse = 1
{-# INLINE signOfMode #-}

-- | Check if an `Int` is a power of two usable as a transform length,
--   i.e. one of 2, 4, 8, 16, ...
--
--   Zero and negative numbers are rejected. @1@ (= 2^0) is also rejected,
--   so the transforms in this module only accept dimensions of at least 2.
isPowerOfTwo :: Int -> Bool
isPowerOfTwo n
  | n == 2    = True
    -- 0 in particular must be rejected: halving it never reaches length 2,
    -- so a zero-length dimension would never finish transforming.
  | n < 2     = False
  | even n    = isPowerOfTwo (n `div` 2)
  | otherwise = False
{-# INLINE isPowerOfTwo #-}

-- 3D Transform -----------------------------------------------------------------------------------
-- | Compute the DFT of a 3d array. Array dimensions must be powers of two else `error`.
--
--   Three passes of 'fftTrans3d' (transform the innermost axis, then rotate the
--   axes) transform along every axis and restore the original axis order.
--   'Inverse' additionally divides every element by @depth * height * width@.
fft3dP  :: (Source r Complex, Monad m)
        => Mode
        -> Array r DIM3 Complex
        -> m (Array U DIM3 Complex)
fft3dP mode arr
  | not (isPowerOfTwo depth && isPowerOfTwo height && isPowerOfTwo width)
  = error $ unlines
      [ "Data.Array.Repa.Algorithms.FFT: fft3d"
      , "  Array dimensions must be powers of two,"
      , "  but the provided array is "
          -- Report dimensions in extent order (depth x height x width),
          -- matching fft2dP's height x width.
          P.++ show depth P.++ "x" P.++ show height P.++ "x" P.++ show width ]

  | otherwise
  = arr `deepSeqArray`
    case mode of
      Inverse -> computeP $ R.map (/ scale) transformed
      _       -> now transformed   -- Forward and Reverse differ only in sign
  where
    _ :. depth :. height :. width = extent arr
    !sign       = signOfMode mode
    !scale      = fromIntegral (depth * width * height)
    transformed = fftTrans3d sign $ fftTrans3d sign $ fftTrans3d sign arr
{-# INLINE fft3dP #-}

-- | Transform every innermost-axis vector of a 3d array, then rotate the axes
--   with 'rotate3d' so that the next pass works on a different axis.
fftTrans3d
        :: Source r Complex
        => Double
        -> Array r DIM3 Complex
        -> Array U DIM3 Complex
fftTrans3d sign arr
  = suspendedComputeP (rotate3d transformed)
  where
    (outer :. len) = extent arr
    transformed    = fft sign outer len arr
{-# INLINE fftTrans3d #-}

-- | Cyclically permute the axes of a 3d array so that the innermost axis
--   becomes the outermost. An input of extent @(a, b, c)@ yields an output of
--   extent @(c, a, b)@, and input element @(i, j, k)@ appears at @(k, i, j)@.
rotate3d
        :: Source r Complex
        => Array r DIM3 Complex -> Array D DIM3 Complex
rotate3d arr
  = backpermute rotated fromRotated arr
  where
    (z :. a :. b :. c)              = extent arr
    rotated                         = z :. c :. a :. b
    -- Map an index of the rotated array back to the source index it reads.
    fromRotated (z' :. k :. i :. j) = z' :. i :. j :. k
{-# INLINE rotate3d #-}

-- Matrix Transform -------------------------------------------------------------------------------
-- | Compute the DFT of a matrix. Array dimensions must be powers of two else `error`.
--
--   Each pass of 'fftTrans2d' transforms the rows and then transposes, so two
--   passes cover both axes and restore the original orientation.
--   'Inverse' additionally divides every element by @height * width@.
fft2dP  :: (Source r Complex, Monad m)
        => Mode
        -> Array r DIM2 Complex
        -> m (Array U DIM2 Complex)
fft2dP mode arr
  | isPowerOfTwo height && isPowerOfTwo width
  = arr `deepSeqArray`
    (if mode == Inverse
       then computeP (R.map (/ scale) bothAxes)
       else now bothAxes)

  | otherwise
  = error $ unlines
      [ "Data.Array.Repa.Algorithms.FFT: fft2d"
      , "  Array dimensions must be powers of two,"
      , "  but the provided array is " P.++ show height P.++ "x" P.++ show width ]
  where
    _ :. height :. width = extent arr
    sign     = signOfMode mode
    scale    = fromIntegral (width * height)
    bothAxes = fftTrans2d sign (fftTrans2d sign arr)
{-# INLINE fft2dP #-}

-- | Transform every row (innermost-axis vector) of a matrix, then transpose it.
fftTrans2d
        :: Source r Complex
        => Double
        -> Array r DIM2 Complex
        -> Array U DIM2 Complex
fftTrans2d sign arr
  = suspendedComputeP . transpose $ fft sign outer rowLen arr
  where
    (outer :. rowLen) = extent arr
{-# INLINE fftTrans2d #-}

-- Vector Transform -------------------------------------------------------------------------------
-- | Compute the DFT of a vector. Array dimensions must be powers of two else `error`.
--
--   'Inverse' additionally divides every element by the vector length.
fft1dP  :: (Source r Complex, Monad m)
        => Mode
        -> Array r DIM1 Complex
        -> m (Array U DIM1 Complex)
fft1dP mode arr
  | isPowerOfTwo len
  = arr `deepSeqArray`
    (case mode of
       Inverse -> computeP (R.map (/ scale) spectrum)
       _       -> now spectrum)

  | otherwise
  = error $ unlines
      [ "Data.Array.Repa.Algorithms.FFT: fft1d"
      , "  Array dimensions must be powers of two, "
      , "  but the provided array is " P.++ show len ]
  where
    _ :. len = extent arr
    spectrum = fftTrans1d (signOfMode mode) arr
    scale    = fromIntegral len
{-# INLINE fft1dP #-}

-- | Transform a single vector. Unlike the 2d and 3d passes, no axis
--   reordering is needed afterwards.
fftTrans1d
        :: Source r Complex
        => Double
        -> Array r DIM1 Complex
        -> Array U DIM1 Complex
fftTrans1d sign arr
  = fft sign outer n arr
  where
    (outer :. n) = extent arr
{-# INLINE fftTrans1d #-}

-- Rank Generalised Worker ------------------------------------------------------------------------
-- | Radix-2 divide-and-conquer FFT along the innermost axis of an array of any rank.
--
--   @fft sign sh len vec@ transforms every length-@len@ innermost vector of @vec@,
--   where @sh :. len@ is expected to be the extent of @vec@. @sign@ is the sign of
--   the imaginary part of the twiddle factors (see 'twiddle'). The length must be
--   a power of two. Any other length, including 0, raises an `error` instead of
--   recursing forever.
fft     :: (Shape sh, Source r Complex)
        => Double -> sh -> Int
        -> Array r (sh :. Int) Complex
        -> Array U (sh :. Int) Complex

fft !sign !sh !lenVec !vec
  = go lenVec 0 1
  where
    -- Transform, for every innermost vector, the length-@len@ subsequence that
    -- starts at @offset@ and takes every @stride@-th element.
    go !len !offset !stride
      | len == 2
      = suspendedComputeP $ fromFunction (sh :. 2) swivel

      -- Split into even- and odd-indexed halves and recombine.
      | len > 2, even len
      = combine len
          (go (len `div` 2) offset            (stride * 2))
          (go (len `div` 2) (offset + stride) (stride * 2))

      -- Only reachable when the top-level length is 1: recursion from an even
      -- len > 2 always yields len >= 2. The DFT of a single point is itself.
      | len == 1
      = suspendedComputeP $ fromFunction (sh :. 1) single

      -- 0 or a length with an odd factor: previously this recursed forever.
      | otherwise
      = error $ "Data.Array.Repa.Algorithms.FFT: fft: vector length "
          P.++ show lenVec P.++ " is not a power of two"
      where
        -- Size-2 butterfly: [x0 + x1, x0 - x1].
        swivel (sh' :. ix)
          = case ix of
              0 -> (vec `unsafeIndex` (sh' :. offset)) + (vec `unsafeIndex` (sh' :. (offset + stride)))
              1 -> (vec `unsafeIndex` (sh' :. offset)) - (vec `unsafeIndex` (sh' :. (offset + stride)))

        -- Copy the one selected element of each innermost vector.
        single (sh' :. _)
          = vec `unsafeIndex` (sh' :. offset)

        -- Butterfly of two half-length transforms: the odd half is multiplied
        -- by twiddle factors, then [evens + odds', evens - odds'] is concatenated
        -- along the innermost axis.
        {-# INLINE combine #-}
        combine !len' evens odds
          = evens `deepSeqArray` odds `deepSeqArray`
            let odds' = unsafeTraverse odds id (\get ix@(_ :. k) -> twiddle sign k len' * get ix)
            in  suspendedComputeP $ (evens +^ odds') R.++ (evens -^ odds')
{-# INLINE fft #-}

-- | Compute a twiddle factor: the pair @(cos theta, sign * sin theta)@
--   where @theta = 2 * pi * k / n@.
twiddle :: Double               -- ^ sign applied to the imaginary part
        -> Int                  -- ^ index @k@
        -> Int                  -- ^ length @n@
        -> Complex
twiddle sign k n
  = (cos angle, sign * sin angle)
  where
    angle = 2 * pi * fromIntegral k / fromIntegral n
{-# INLINE twiddle #-}

```