{-# LANGUAGE TypeOperators, PatternGuards, RankNTypes #-}

-- | Fast computation of Discrete Fourier Transforms using the Cooley-Tukey algorithm.
--
--   Time complexity is O(n log n) in the size of the input.
--
--   Input dimensions must be powers of two, else `error`.
--
--   The `fft` and `ifft` functions (and friends) also compute the roots of unity needed.
--   If you need to transform several arrays with the same extent then it is faster to
--   compute the roots once using `calcRootsOfUnity` or `calcInverseRootsOfUnity`, 
--   then call `fftWithRoots` directly.
--
--   The inverse transforms provided also perform post-scaling so that `ifft` is the true inverse of `fft`. 
--   If you don't want that then call `fftWithRoots` directly.
--
--   The functions `fft2d` and `fft3d` require their inputs to be square and cubic, respectively.
--   This allows them to reuse the same roots of unity when transforming along each axis. If you
--   need to transform rectangular arrays then call `fftWithRoots` directly.
module Data.Array.Repa.Algorithms.FFT
	( fft,   ifft
	, fft2d, ifft2d
	, fft3d, ifft3d
	, fftWithRoots )
where
import Data.Array.Repa.Algorithms.DFT.Roots
import Data.Array.Repa.Algorithms.Complex
import Data.Array.Repa				as A
import Data.Ratio

-- Vector Transform -------------------------------------------------------------------------------
-- | Compute the DFT along the low order (last) dimension of an array.
--
--   The roots of unity are recomputed from the extent of the input on every
--   call; to reuse them across several transforms call `fftWithRoots`.
fft     :: Shape sh
        => Array (sh :. Int) Complex
        -> Array (sh :. Int) Complex

fft v
  = force (fftWithRoots roots v)
  where
    roots = calcRootsOfUnity (extent v)


-- | Compute the inverse DFT along the low order (last) dimension of an array.
--
--   The result is divided by the length of that dimension so that `ifft`
--   inverts `fft`.
ifft    :: Shape sh
        => Array (sh :. Int) Complex
        -> Array (sh :. Int) Complex

ifft v
  = force
  $ A.map (/ scale)
  $ fftWithRoots (calcInverseRootsOfUnity sh) v
  where
    sh          = extent v
    _ :. len    = sh
    -- Post-scaling factor: the transform length, as a complex number.
    scale       = fromIntegral len :*: 0


-- Matrix Transform -------------------------------------------------------------------------------
-- | Compute the DFT of a square matrix.
--   If the matrix is not square then `error`.
--
--   Each pass transforms along the rows and then transposes, so after two
--   passes both axes have been transformed and the orientation is restored.
fft2d   :: Array DIM2 Complex
        -> Array DIM2 Complex

fft2d arr
  | height /= width
  = error $ "fft2d: height of matrix (" ++ show height ++ ")"
          ++  " does not match width (" ++ show width  ++ ")"

  | otherwise
  = force $ pass $ pass arr

  where
    Z :. height :. width = extent arr

    -- A square matrix has equal row and column lengths, so one set of roots
    -- serves both passes.
    roots                = calcRootsOfUnity (extent arr)
    pass                 = transpose . fftWithRoots roots


-- | Compute the inverse DFT of a square matrix.
--   If the matrix is not square then `error`.
--
--   The result is divided by the number of elements (@height * width@) so
--   that `ifft2d` inverts `fft2d`.
ifft2d  :: Array DIM2 Complex
        -> Array DIM2 Complex

ifft2d arr
  | height /= width
  = error $ "ifft2d: height of matrix (" ++ show height ++ ")"
          ++  " does not match width (" ++ show width  ++ ")"

  | otherwise
  = force $ A.map (/ scale) $ pass $ pass arr

  where
    Z :. height :. width = extent arr
    scale                = fromIntegral (height * width) :*: 0
    roots                = calcInverseRootsOfUnity (extent arr)

    -- Transform along the rows, then transpose so the next pass sees the
    -- columns; two passes restore the original orientation.
    pass                 = transpose . fftWithRoots roots
	

-- Cube Transform ---------------------------------------------------------------------------------
-- | Compute the DFT of a 3d cube.
--   If the array is not a cube then `error`.
--
--   Each pass transforms along the low order dimension and then rotates the
--   axes by one place, so three passes transform every axis and return the
--   array to its original orientation.
fft3d   :: Array DIM3 Complex
        -> Array DIM3 Complex

fft3d arrIn
  | height /= width || height /= depth
  = error $ "fft3d: array is not a cube"

  | otherwise
  = force $ pass $ pass $ pass arrIn

  where
    Z :. depth :. height :. width = extent arrIn
    roots = calcRootsOfUnity (extent arrIn)
    pass  = rotate . fftWithRoots roots

    -- Cyclic axis rotation: the element at (k, l, m) moves to (l, m, k).
    rotate arr
      = A.traverse arr
          (\(Z :. k :. l :. m)     -> Z :. l :. m :. k)
          (\get (Z :. l :. m :. k) -> get (Z :. k :. l :. m))


-- | Compute the inverse DFT of a 3d cube.
--   If the array is not a cube then `error`.
--
--   The result is divided by the number of elements
--   (@height * width * depth@) so that `ifft3d` inverts `fft3d`.
ifft3d  :: Array DIM3 Complex
        -> Array DIM3 Complex

ifft3d arrIn
  | height /= width || height /= depth
  = error $ "ifft3d: array is not a cube"

  | otherwise
  = force $ A.map (/ scale) $ pass $ pass $ pass arrIn

  where
    Z :. depth :. height :. width = extent arrIn
    scale = fromIntegral (height * width * depth) :*: 0
    roots = calcInverseRootsOfUnity (extent arrIn)

    -- Transform along the low order dimension, then rotate the axes so the
    -- next pass works on a different one; three rotations are the identity.
    pass  = rotate . fftWithRoots roots

    -- Cyclic axis rotation: the element at (k, l, m) moves to (l, m, k).
    rotate arr
      = A.traverse arr
          (\(Z :. k :. l :. m)     -> Z :. l :. m :. k)
          (\get (Z :. l :. m :. k) -> get (Z :. k :. l :. m))

	
-- Worker -----------------------------------------------------------------------------------------
-- | Generic function for computation of forward or inverse Discrete Fourier Transforms.
--   Computation is along the low order dimension of the array.
--
--   Roots from `calcRootsOfUnity` give the forward transform; roots from
--   `calcInverseRootsOfUnity` give the inverse transform without any
--   post-scaling.
--
--   The length of the low order dimension must be a power of two, and must
--   equal the length of the low order dimension of the roots, else `error`.
fftWithRoots
        :: forall sh
        .  Shape sh
        => Array (sh :. Int) Complex            -- ^ Roots of unity.
        -> Array (sh :. Int) Complex            -- ^ Input values.
        -> Array (sh :. Int) Complex

fftWithRoots rofu v
  | not (isPowerOfTwo vLen)
  = error $ "fft: vector length of " ++ show vLen ++ " is not a power of 2"

  | rLen /= vLen
  = error $  "fft: length of vector (" ++ show vLen ++ ")"
          ++ " does not match the length of the roots (" ++ show rLen ++ ")"

  | otherwise
  = fftWithRoots' rofu v

  where
    _ :. rLen = extent rofu
    _ :. vLen = extent v

    -- Exact integer test. Going through @logBase 2@ on a Double can round
    -- to a non-integer for some large powers of two, and would accept a
    -- length of 0 (@logBase 2 0@ is -Infinity, whose Rational form has
    -- denominator 1), which makes the recursive split never terminate.
    isPowerOfTwo :: Int -> Bool
    isPowerOfTwo n
      | n < 1     = False
      | n == 1    = True
      | otherwise = even n && isPowerOfTwo (n `div` 2)

-- | Worker for `fftWithRoots`, without the argument checks.
--   Recursively splits the input into its even- and odd-indexed elements
--   until the low order dimension has length 2, which is transformed
--   directly. A length-1 DFT is the identity, so such input is returned
--   unchanged rather than being split into empty halves without end.
fftWithRoots'
        :: Shape sh
        => Array (sh :. Int) Complex            -- ^ Roots of unity.
        -> Array (sh :. Int) Complex            -- ^ Input values.
        -> Array (sh :. Int) Complex

{-# INLINE fftWithRoots' #-}
fftWithRoots' rofu v
  = case extent v of
      _ :. 1  -> v
      _ :. 2  -> fft_two   v
      _       -> fft_split rofu v

-- | Base case of the recursion: the DFT of a length-2 vector, computed
--   along the low order dimension as @[x0 + x1, x0 - x1]@.
fft_two :: Shape sh
        => Array (sh :. Int) Complex
        -> Array (sh :. Int) Complex

{-# INLINE fft_two #-}
fft_two v
  = A.traverse v id butterfly
  where
    butterfly get (sh :. i)
      = case i of
          0 -> get (sh :. 0) + get (sh :. 1)
          1 -> get (sh :. 0) - get (sh :. 1)
          _ -> error "Data.Array.Repa.Algorithms.FFT fft_two fail"
	
-- | Recursive case of the Cooley-Tukey algorithm.
--   Transforms the even- and odd-indexed elements as two half-length
--   vectors, then combines them: output element @i@ is
--   @even_i + w_i * odd_i@ and output element @i + n/2@ is
--   @even_i - w_i * odd_i@, where @w_i@ is element @i@ of the roots.
fft_split
        :: Shape sh
        => Array (sh :. Int) Complex            -- roots of unity
        -> Array (sh :. Int) Complex            -- input values
        -> Array (sh :. Int) Complex

{-# INLINE fft_split #-}
fft_split rofu v
  = combine (+) +:+ combine (-)
  where
    -- Row 0 holds the transform of the even elements, row 1 that of the odd
    -- ones. Both output halves read from it, so it is forced up front
    -- (presumably so the nested transform is not recomputed per consumer).
    halves = force $ fftWithRoots' (splitRofu rofu) (splitVector v)

    -- One half of the butterfly; `op` selects the (+) or (-) half.
    combine op
      = A.traverse2 halves rofu
          (\(sh :. 2 :. n) _ -> sh :. n)
          (\get root (sh :. i) ->
                get (sh :. 0 :. i) `op` (root (sh :. i) * get (sh :. 1 :. i)))

-- | Roots of unity for the half-length transforms in `fft_split`.
--   Keeps the even-indexed roots, duplicated into both rows of a new
--   dimension of size 2 so they line up with the output of `splitVector`.
splitRofu
        :: Shape sh
        => Array (sh :. Int) Complex
        -> Array (sh :. Int :. Int) Complex

{-# INLINE splitRofu #-}
splitRofu rofu
  = A.traverse rofu halveExtent evens
  where
    halveExtent (rSh :. rLen) = rSh :. (2 :: Int) :. (rLen `div` 2)
    evens get (sh :. _ :. i)  = get (sh :. 2 * i)

-- | Split the low order dimension into its even- and odd-indexed elements,
--   stored as rows 0 and 1 of a new dimension of size 2.
splitVector
        :: Shape sh
        => Array (sh :. Int) Complex
        -> Array (sh :. Int :. Int) Complex

{-# INLINE splitVector #-}
splitVector v
  = A.traverse v halveExtent pick
  where
    halveExtent (vSh :. vLen) = vSh :. 2 :. (vLen `div` 2)

    pick get (sh :. row :. i)
      | row == 0  = get (sh :. 2 * i)
      | row == 1  = get (sh :. 2 * i + 1)
      | otherwise = error "Data.Array.Repa.Algorithms.FFT splitVector fail"