{-# LANGUAGE TypeOperators, PatternGuards, RankNTypes, ScopedTypeVariables, BangPatterns, FlexibleContexts #-}
{-# OPTIONS -fno-warn-incomplete-patterns #-}

-- | Fast computation of Discrete Fourier Transforms using the Cooley-Tukey algorithm.
--   Time complexity is O(n log n) in the size of the input. 
--
--   This uses a naive divide-and-conquer algorithm; its absolute performance is about
--   50x slower than FFTW in estimate mode.
--
module Data.Array.Repa.Algorithms.FFT
	( Mode(..)
	, isPowerOfTwo
	, fft3d
	, fft2d
	, fft1d)
where
import Data.Array.Repa.Algorithms.Complex
import Data.Array.Repa				as A

-- | Selects which transform the FFT entry points compute.
data Mode
        -- | Twiddle factors are built with sign @-1@; the result is not rescaled.
        = Forward

        -- | Twiddle factors are built with sign @+1@; the result is not rescaled.
        | Reverse

        -- | Twiddle factors are built with sign @+1@, and every element of the
        --   result is divided by the number of elements in the array.
        | Inverse
        deriving (Eq, Show)

{-# INLINE signOfMode #-}
-- | Sign handed to the twiddle-factor computation for a given transform mode:
--   @-1@ for 'Forward', @+1@ for both 'Reverse' and 'Inverse'.
signOfMode :: Mode -> Double
signOfMode Forward = -1
signOfMode Reverse = 1
signOfMode Inverse = 1


{-# INLINE isPowerOfTwo #-}
-- | Check if an `Int` is a power of two.
--
--   The test uses exact integer arithmetic: the value is halved repeatedly
--   while it is even, and it is a power of two exactly when this ends at 1.
--   A floating-point @log x / log 2@ test is not reliable here, because the
--   quotient is not guaranteed to be an exact integer for true powers of two,
--   and large inputs lose precision when converted to `Double`.
--
--   Zero and negative numbers are not powers of two; 1 (= 2^0) is.
isPowerOfTwo :: Int -> Bool
isPowerOfTwo x
 | x <= 0       = False
 | otherwise    = halve x
 where  -- Strip factors of two; any odd value other than 1 disqualifies.
        halve 1 = True
        halve n = even n && halve (n `div` 2)




-- 3D Transform -----------------------------------------------------------------------------------
-- | Compute the DFT of a 3d array.
--   Calls `error` unless the depth, height and width are all powers of two.
--   In 'Inverse' mode every element of the result is divided by the total
--   number of elements (depth * width * height).
fft3d   :: Mode
        -> Array DIM3 Complex
        -> Array DIM3 Complex

fft3d mode arr
 = let  _ :. depth :. height :. width = extent arr
        !sign           = signOfMode mode
        !scale          = fromIntegral (depth * width * height)

        allPowersOfTwo  = isPowerOfTwo depth && isPowerOfTwo height && isPowerOfTwo width

        -- Three passes; each one transforms the innermost axis and then
        -- rotates the axes, so the original axis order is restored at the end.
        transformed     = fftTrans3d sign (fftTrans3d sign (fftTrans3d sign arr))

   in   if allPowersOfTwo
         then arr `deepSeqArray`
                (case mode of
                        -- Only the inverse transform is rescaled.
                        Inverse -> force (A.map (/ scale) transformed)
                        _       -> transformed)
         else error "Data.Array.Repa.Algorithms.FFT: fft3d -- array dimensions must be powers of two."

-- One pass of the 3d transform: force the input, run the worker along the
-- innermost axis, then rotate the axes and force the result.
fftTrans3d
        :: Double
        -> Array DIM3 Complex
        -> Array DIM3 Complex

{-# NOINLINE fftTrans3d #-}
fftTrans3d sign arr'
 = force (rotate3d transformed)
 where  arr             = force arr'
        (sh :. len)     = extent arr
        transformed     = fft sign sh len arr


-- Cyclically rotate the three innermost axes: an array of extent
-- (k, l, m) becomes one of extent (m, k, l), where the element at
-- result index (m', k', l') is read from source index (k', l', m').
rotate3d :: Array DIM3 Complex -> Array DIM3 Complex
{-# INLINE rotate3d #-}
rotate3d arr
 = let  (sh :. k :. l :. m)     = extent arr
        rotatedExtent           = sh :. m :. k :. l
   in   backpermute rotatedExtent
                (\(sh' :. m' :. k' :. l') -> sh' :. k' :. l' :. m')
                arr



-- Matrix Transform -------------------------------------------------------------------------------
-- | Compute the DFT of a matrix.
--   Calls `error` unless both the height and the width are powers of two.
--   In 'Inverse' mode every element of the result is divided by the total
--   number of elements (width * height).
fft2d   :: Mode
        -> Array DIM2 Complex
        -> Array DIM2 Complex

fft2d mode arr
 = let  _ :. height :. width    = extent arr
        sign                    = signOfMode mode
        scale                   = fromIntegral (width * height)

        bothPowersOfTwo         = isPowerOfTwo height && isPowerOfTwo width

        -- Two passes; each one transforms the innermost axis and transposes.
        transformed             = fftTrans2d sign (fftTrans2d sign arr)

   in   if bothPowersOfTwo
         then arr `deepSeqArray`
                (case mode of
                        -- Only the inverse transform is rescaled.
                        Inverse -> force (A.map (/ scale) transformed)
                        _       -> transformed)
         else error "Data.Array.Repa.Algorithms.FFT: fft2d -- array dimensions must be powers of two."

-- One pass of the 2d transform: force the input, run the worker along the
-- innermost axis, then transpose and force the result.
fftTrans2d
        :: Double
        -> Array DIM2 Complex
        -> Array DIM2 Complex

{-# NOINLINE fftTrans2d #-}
fftTrans2d sign arr'
 = force (transpose transformed)
 where  arr             = force arr'
        (sh :. len)     = extent arr
        transformed     = fft sign sh len arr


-- Vector Transform -------------------------------------------------------------------------------
-- | Compute the DFT of a vector.
--   Calls `error` unless the length is a power of two.
--   In 'Inverse' mode every element of the result is divided by the length.
fft1d   :: Mode
        -> Array DIM1 Complex
        -> Array DIM1 Complex

fft1d mode arr
 = let  _ :. len        = extent arr
        sign            = signOfMode mode
        scale           = fromIntegral len
        transformed     = fftTrans1d sign arr

   in   if isPowerOfTwo len
         then arr `deepSeqArray`
                (case mode of
                        -- Only the inverse transform is rescaled.
                        Inverse -> force (A.map (/ scale) transformed)
                        _       -> transformed)
         else error "Data.Array.Repa.Algorithms.FFT: fft1d -- array dimensions must be powers of two."

-- The single pass of the 1d transform: force the input and run the worker
-- along its only axis.
fftTrans1d
        :: Double
        -> Array DIM1 Complex
        -> Array DIM1 Complex

{-# NOINLINE fftTrans1d #-}
fftTrans1d sign arr'
 = fft sign sh len arr
 where  arr             = force arr'
        (sh :. len)     = extent arr


-- Rank Generalised Worker ------------------------------------------------------------------------
{-# INLINE fft #-}
-- Radix-2 decimation-in-time transform along the innermost axis of `vec`.
--
--   sign    sign passed on to `twiddle` (see `signOfMode`)
--   sh      extent of the outer dimensions of `vec`
--   lenVec  length of the innermost dimension; callers check it is a power of two
--   vec     source array, read with `unsafeIndex`
--
-- `go len offset stride` transforms the sub-sequence of `len` elements that
-- starts at `offset` and steps by `stride` along the innermost axis. It splits
-- the sequence into its even- and odd-indexed halves, transforms each
-- recursively, and merges the two results with `combine`.
fft !sign !sh !lenVec !vec
 = go lenVec 0 1
 where  go !len !offset !stride
         -- A sequence of length 1 is its own transform, and an empty one has no
         -- elements. Without this case the halving below never reaches the
         -- length-2 base case for these lengths, and the recursion would not
         -- terminate.
         | len < 2
         = force $ fromFunction (sh :. len) single

         -- Base case: the two-point transform.
         | len == 2
         = force $ fromFunction (sh :. 2) swivel

         | otherwise
         = combine len
                (go (len `div` 2) offset            (stride * 2))
                (go (len `div` 2) (offset + stride) (stride * 2))

         where  -- Copy the only element of a length-1 sub-sequence.
                single (sh' :. _)
                 = vec `unsafeIndex` (sh' :. offset)

                -- Two-point butterfly: sum at index 0, difference at index 1.
                swivel (sh' :. ix)
                 = case ix of
                        0       -> (vec `unsafeIndex` (sh' :. offset)) + (vec `unsafeIndex` (sh' :. (offset + stride)))
                        1       -> (vec `unsafeIndex` (sh' :. offset)) - (vec `unsafeIndex` (sh' :. (offset + stride)))

                -- Merge the transforms of the even and odd halves into one of
                -- length len': the odd half is multiplied by the twiddle
                -- factors, then sum and difference are appended.
                -- The patterns only accept arrays made of a single manifest
                -- region covering the whole range.
                {-# INLINE combine #-}
                combine !len'   evens@(Array _ [Region RangeAll GenManifest{}])
                                odds@(Array _ [Region RangeAll GenManifest{}])
                 = evens `deepSeqArray` odds `deepSeqArray`
                   let  odds'   = unsafeTraverse odds id (\get ix@(_ :. k) -> twiddle sign k len' * get ix)
                   in   force   $ (evens +^ odds') A.++ (evens -^ odds')


-- Compute the twiddle factor for index k of a transform of length n:
-- the pair (cos t, sign * sin t), where t = 2 * pi * k / n.
twiddle :: Double               -- sign applied to the sine component
        -> Int                  -- index
        -> Int                  -- length
        -> Complex

{-# INLINE twiddle #-}
twiddle sign k' n'
 = (cos theta, sign * sin theta)
 where  theta   = 2 * pi * fromIntegral k' / fromIntegral n'