{-# LANGUAGE TypeOperators, RankNTypes, PatternGuards #-}

-- | Compute the Discrete Fourier Transform (DFT) along the low order dimension
--   of an array. 
--
--   This uses the naive algorithm and takes O(n^2) time.
--   However, you can transform an array with an arbitrary extent, unlike with FFT which requires
--   each dimension to be a power of two.
--
--   The `dftP` and `idftP` functions also compute the roots of unity needed.
--   If you need to transform several arrays with the same extent then it is faster to
--   compute the roots once using `calcRootsOfUnityP` or `calcInverseRootsOfUnityP`,
--   then call `dftWithRootsP` directly.
--
--   You can also compute single values of the transform using `dftWithRootsSingleS`.
module Data.Array.Repa.Algorithms.DFT 
	( dftP
	, idftP
	, dftWithRootsP
	, dftWithRootsSingleS)
where
import Data.Array.Repa.Algorithms.DFT.Roots
import Data.Array.Repa.Algorithms.Complex
import Data.Array.Repa				as R
import Prelude					as P


-- | Forward DFT along the low-order (innermost) dimension of an array.
--
--   The roots of unity are recomputed on every call; to transform several
--   arrays with the same extent, compute them once and use `dftWithRootsP`.
dftP    :: (Shape sh, Monad m)
        => Array U (sh :. Int) Complex
        -> m (Array U (sh :. Int) Complex)

dftP v
 = calcRootsOfUnityP (extent v) >>= \rofu -> dftWithRootsP rofu v
{-# INLINE dftP #-}


-- | Inverse DFT along the low-order (innermost) dimension of an array.
--
--   The inverse roots of unity are recomputed on every call. The unscaled
--   transform is then divided, element-wise, by the complex value
--   @(len, 0)@, where @len@ is the length of the innermost dimension.
idftP   :: (Shape sh, Monad m)
        => Array U (sh :. Int) Complex
        -> m (Array U (sh :. Int) Complex)

idftP v
 = do   rofu     <- calcInverseRootsOfUnityP ext
        unscaled <- dftWithRootsP rofu v
        computeP $ R.map (/ scale) unscaled
 where  ext@(_ :. len) = extent v
        scale          = (fromIntegral len, 0)
{-# INLINE idftP #-}


-- | Generic function for computation of forward or inverse DFT.
--
--   This function is also useful if you transform many arrays with the same
--   extent, and don't want to recompute the roots for each one.
--
--   The extent of the given roots must match that of the input array, else
--   `error`. A mismatch in the length of the innermost (transformed)
--   dimension is reported with both lengths. Any other extent mismatch is
--   reported separately.
dftWithRootsP
        :: (Shape sh, Monad m)
        => Array U (sh :. Int) Complex          -- ^ Roots of unity.
        -> Array U (sh :. Int) Complex          -- ^ Input array.
        -> m (Array U (sh :. Int) Complex)

dftWithRootsP rofu arr
        | _ :. rLen     <- extent rofu
        , _ :. vLen     <- extent arr
        , rLen /= vLen
        = error $    "dftWithRoots: length of vector (" P.++ show vLen P.++ ")"
                P.++ " does not match the length of the roots (" P.++ show rLen P.++ ")"

        -- The per-element kernel indexes the roots with the outer index of
        -- the input, so the outer dimensions must agree as well. Without this
        -- check a mismatch would read the wrong (or out-of-range) roots.
        | extent rofu /= extent arr
        = error "dftWithRoots: extent of the roots does not match the extent of the vector"

        -- Build the result directly from its extent. This is equivalent to
        -- `traverse arr id`, but avoids the unqualified name `traverse`, which
        -- is ambiguous with Prelude's Traversable method on base >= 4.8.
        | otherwise
        = computeP $ fromFunction (extent arr) (dftWithRootsSingleS rofu arr)
{-# INLINE dftWithRootsP #-}


-- | Compute a single value of the DFT.
--
--   For the index @sh0 :. k@, the result is the k-th DFT coefficient of the
--   row of the input selected by @sh0@. That is, the sum over @n@ in
--   @[0 .. len-1]@ of the input at @sh0 :. n@ times the root at
--   @sh0 :. ((k * n) mod len)@, where @len@ is the length of the innermost
--   dimension.
--
--   The extent of the given roots must match that of the input array, else `error`.
dftWithRootsSingleS
        :: Shape sh
        => Array U (sh :. Int) Complex          -- ^ Roots of unity.
        -> Array U (sh :. Int) Complex          -- ^ Input array.
        -> (sh :. Int)                          -- ^ Index of the value we want.
        -> Complex

dftWithRootsSingleS rofu arrX (sh0 :. k)
        | _ :. rLen     <- extent rofu
        , _ :. vLen     <- extent arrX
        , rLen /= vLen
        = error $    "dftWithRootsSingle: length of vector (" P.++ show vLen P.++ ")"
                P.++ " does not match the length of the roots (" P.++ show rLen P.++ ")"

        -- The roots are indexed with the outer index sh0, so the outer
        -- dimensions of the roots must agree with those of the input.
        | extent rofu /= extent arrX
        = error "dftWithRootsSingle: extent of the roots does not match the extent of the vector"

        | otherwise
        = let   _ :. len        = extent arrX

                -- Only the row selected by sh0 contributes. Summing over the
                -- whole array would mix every row into every coefficient
                -- whenever the array has more than one dimension.
                term (Z :. n)
                        = (arrX ! (sh0 :. n)) * (rofu ! (sh0 :. ((k * n) `mod` len)))

          in  R.sumAllS $ fromFunction (Z :. len) term
{-# INLINE dftWithRootsSingleS #-}