```{-# LANGUAGE TypeOperators, RankNTypes, PatternGuards #-}

-- | Compute the Discrete Fourier Transform (DFT) along the low order dimension
--   of an array.
--
--   This uses the naive algorithm and takes O(n^2) time.
--   However, you can transform an array with an arbitrary extent, unlike with FFT which requires
--   each dimension to be a power of two.
--
--   The `dftP` and `idftP` functions also compute the roots of unity needed.
--   If you need to transform several arrays with the same extent then it is faster to
--   compute the roots once using `calcRootsOfUnityP` or `calcInverseRootsOfUnityP`,
--   then call `dftWithRootsP` directly.
--
--   You can also compute single values of the transform using `dftWithRootsSingleS`.
module Data.Array.Repa.Algorithms.DFT
( dftP
, idftP
, dftWithRootsP
, dftWithRootsSingleS)
where
import Data.Array.Repa.Algorithms.DFT.Roots
import Data.Array.Repa.Algorithms.Complex
import Data.Array.Repa				as R
import Prelude					as P

-- | Compute the DFT along the low order dimension of an array.
--
--   The roots of unity for the extent of the input are computed first with
--   `calcRootsOfUnityP`; the actual transform is then delegated to
--   `dftWithRootsP`.
dftP    :: (Shape sh, Monad m)
        => Array U (sh :. Int) Complex
        -> m (Array U (sh :. Int) Complex)

dftP v
 = calcRootsOfUnityP (extent v) >>= \rofu -> dftWithRootsP rofu v
{-# INLINE dftP #-}

-- | Compute the inverse DFT along the low order dimension of an array.
--
--   The input is transformed using the inverse roots of unity
--   (`calcInverseRootsOfUnityP`), and every element of that result is then
--   divided by the length of the low order dimension.
idftP   :: (Shape sh, Monad m)
        => Array U (sh :. Int) Complex
        -> m (Array U (sh :. Int) Complex)

idftP v
 = do   let _ :. len    = extent v

        -- The low order length as a `Complex` pair, used as the divisor below.
        let scale       = (fromIntegral len, 0)

        rofu            <- calcInverseRootsOfUnityP (extent v)

        -- Transform with the inverse roots; the result still needs scaling.
        unscaled        <- dftWithRootsP rofu v
        computeP $ R.map (/ scale) unscaled
{-# INLINE idftP #-}

-- | Generic function for computation of forward or inverse DFT.
--
--   Whether the result is the forward or the inverse transform is determined
--   by the roots that are passed in. This function is also useful if you
--   transform many arrays with the same extent and don't want to recompute
--   the roots for each one.
--
--   Calls `error` if the length of the low order dimension of the roots
--   differs from that of the input array. Only that innermost length is
--   checked here.
dftWithRootsP
        :: (Shape sh, Monad m)
        => Array U (sh :. Int) Complex          -- ^ Roots of unity.
        -> Array U (sh :. Int) Complex          -- ^ Input array.
        -> m (Array U (sh :. Int) Complex)

dftWithRootsP rofu arr
 | _ :. rLen    <- extent rofu
 , _ :. vLen    <- extent arr
 , rLen /= vLen
 = error $    "dftWithRoots: length of vector (" P.++ show vLen P.++ ")"
        P.++ " does not match the length of the roots (" P.++ show rLen P.++ ")"

 | otherwise
   -- The result has the same extent as the input; each element is computed
   -- independently from its index.
 = computeP $ traverse arr id (\_ k -> dftWithRootsSingleS rofu arr k)
{-# INLINE dftWithRootsP #-}

-- | Compute a single value of the DFT.
--
--   The outer part of the requested index selects which row (along the low
--   order dimension) of the input is transformed; the innermost component
--   @k@ selects which coefficient of that row is computed. Only elements of
--   the selected row contribute to the result.
--
--   Calls `error` if the length of the low order dimension of the roots
--   differs from that of the input array.
dftWithRootsSingleS
        :: Shape sh
        => Array U (sh :. Int) Complex          -- ^ Roots of unity.
        -> Array U (sh :. Int) Complex          -- ^ Input array.
        -> (sh :. Int)                          -- ^ Index of the value we want.
        -> Complex

dftWithRootsSingleS rofu arrX (outer :. k)
 | _ :. rLen    <- extent rofu
 , _ :. vLen    <- extent arrX
 , rLen /= vLen
 = error $    "dftWithRootsSingle: length of vector (" P.++ show vLen P.++ ")"
        P.++ " does not match the length of the roots (" P.++ show rLen P.++ ")"

 | otherwise
 = let  _ :. len        = extent arrX

        -- The n-th term of the sum: the n-th element of the selected row
        -- multiplied by the root at position (k * n) mod len of that row.
        term (Z :. n)
                = (arrX ! (outer :. n))
                * (rofu ! (outer :. ((k * n) `mod` len)))

        -- Summing over a one-dimensional delayed array of length `len`
        -- restricts the sum to the selected row, instead of folding over
        -- every element of a higher-rank input.
   in   R.sumAllS $ fromFunction (Z :. len) term
{-# INLINE dftWithRootsSingleS #-}

```