```{-# LANGUAGE TypeOperators, RankNTypes, PatternGuards #-}

-- | Compute the Discrete Fourier Transform (DFT) along the low order dimension
--   of an array.
--
--   This uses the naive algorithm and takes O(n^2) time.
--   However, you can transform an array with an arbitrary extent, unlike with FFT which requires
--   each dimension to be a power of two.
--
--   The `dft` and `idft` functions also compute the roots of unity needed.
--   If you need to transform several arrays with the same extent then it is faster to
--   compute the roots once using `calcRootsOfUnity` or `calcInverseRootsOfUnity`,
--   then call `dftWithRoots` directly.
--
--   You can also compute single values of the transform using `dftWithRootsSingle`.
module Data.Array.Repa.Algorithms.DFT
( dft
, idft
, dftWithRoots
, dftWithRootsSingle)
where
import Data.Array.Repa.Algorithms.DFT.Roots
import Data.Array.Repa.Algorithms.Complex
import Data.Array.Repa				as A
import Prelude					as P

-- | Forward DFT along the low order (innermost) dimension of an array.
--   The roots of unity are derived from the extent of the input on every
--   call and then handed to `dftWithRoots`.
dft     :: forall sh
        .  Shape sh
        => Array U (sh :. Int) Complex
        -> Array U (sh :. Int) Complex
dft arr = dftWithRoots roots arr
  where
        -- Forward roots, sized to the extent of the input.
        roots   = calcRootsOfUnity (extent arr)

-- | Inverse DFT along the low order (innermost) dimension of an array.
--   Runs `dftWithRoots` with the inverse roots of unity for the extent of
--   the input, then divides every element by the length of the innermost
--   dimension.
idft    :: forall sh
        .  Shape sh
        => Array U (sh :. Int) Complex
        -> Array U (sh :. Int) Complex
idft arr = computeP (A.map (/ divisor) transformed)
  where
        ext@(_ :. n)    = extent arr

        -- The transform length as a complex number with zero imaginary part.
        divisor         = (fromIntegral n, 0)

        -- Unnormalised result, computed with the inverse roots.
        transformed     = dftWithRoots (calcInverseRootsOfUnity ext) arr

-- | Shared worker behind the forward and inverse DFT: transforms an array
--   using roots of unity supplied by the caller.
--   Passing precomputed roots avoids recalculating them when many arrays of
--   the same extent are transformed.
--   The innermost lengths of the roots and the input are compared first; if
--   they differ the result is an `error`. Otherwise every element of the
--   result (which has the extent of the input) is obtained by calling
--   `dftWithRootsSingle` at that element's index.
dftWithRoots
        :: forall sh
        .  Shape sh
        => Array U (sh :. Int) Complex          -- ^ Roots of unity.
        -> Array U (sh :. Int) Complex          -- ^ Input array.
        -> Array U (sh :. Int) Complex
dftWithRoots rofu arr
 = case (extent rofu, extent arr) of
        (_ :. rLen, _ :. vLen)
         | rLen /= vLen
         -> error (  "dftWithRoots: length of vector (" P.++ show vLen P.++ ")"
                P.++ " does not match the length of the roots (" P.++ show rLen P.++ ")")

         | otherwise
         -> computeP (A.traverse arr id (\_ ix -> dftWithRootsSingle rofu arr ix))

-- | Compute a single value of the DFT along the low order dimension.
--
--   For the requested index @row :. k@ the result is the sum, over every
--   position @n@ of the innermost dimension, of
--   @arrX ! (row :. n) * rofu ! (row :. (k * n) \`mod\` len)@,
--   where @len@ is the innermost length of the input.
--
--   Only the row selected by the higher-order part of the index contributes
--   to the sum. Summing over the whole array instead would mix all rows
--   together for arrays of rank greater than one, so the transform would no
--   longer be taken along the low order dimension only. For rank-1 arrays
--   the two formulations coincide.
--
--   The innermost lengths of the roots and of the input must match, else
--   `error`. Only those innermost lengths are compared.
--
--   NOTE(review): `sumAllP` is a parallel reduction and `dftWithRoots` calls
--   this function once per element from inside `computeP`; confirm that this
--   nesting is acceptable for the Repa version in use.
dftWithRootsSingle
        :: forall sh
        .  Shape sh
        => Array U (sh :. Int) Complex          -- ^ Roots of unity.
        -> Array U (sh :. Int) Complex          -- ^ Input array.
        -> (sh :. Int)                          -- ^ Index of the value we want.
        -> Complex

{-# INLINE dftWithRootsSingle #-}
dftWithRootsSingle rofu arrX (row :. k)
        | _ :. rLen     <- extent rofu
        , _ :. vLen     <- extent arrX
        , rLen /= vLen
        = error $    "dftWithRootsSingle: length of vector (" P.++ show vLen P.++ ")"
                P.++ " does not match the length of the roots (" P.++ show rLen P.++ ")"

        | otherwise
        = let   _ :. len        = extent arrX

                -- One term of the sum: the n-th sample of the selected row
                -- multiplied by the root it pairs with. Reducing the root
                -- index modulo len keeps it inside the innermost dimension.
                term (Z :. n)
                        = (arrX ! (row :. n)) * (rofu ! (row :. ((k * n) `mod` len)))

          in    A.sumAllP $ fromFunction (Z :. len) term

```