{-# LANGUAGE TypeOperators, RankNTypes, PatternGuards #-}

-- | Compute the Discrete Fourier Transform (DFT) along the low order dimension
--   of an array.
--
--   This uses the naive algorithm and takes O(n^2) time.
--   However, you can transform an array with an arbitrary extent, unlike with FFT which requires
--   each dimension to be a power of two.
--
--   The `dft` and `idft` functions also compute the roots of unity needed.
--   If you need to transform several arrays with the same extent then it is faster to
--   compute the roots once using `calcRootsOfUnity` or `calcInverseRootsOfUnity`,
--   then call `dftWithRoots` directly.
--
--   You can also compute single values of the transform using `dftWithRootsSingle`.
module Data.Array.Repa.Algorithms.DFT
        ( dft
        , idft
        , dftWithRoots
        , dftWithRootsSingle)
where
import Data.Array.Repa.Algorithms.DFT.Roots
import Data.Array.Repa.Algorithms.Complex
import Data.Array.Repa                          as A
import Prelude                                  as P


-- | Compute the DFT along the low order dimension of an array.
--
--   The roots of unity are computed from the extent of the input on every call,
--   and the result is forced into a manifest array.
dft     :: forall sh
        .  Shape sh
        => Array (sh :. Int) Complex    -- ^ Input array.
        -> Array (sh :. Int) Complex

dft v
 = let  rofu    = calcRootsOfUnity (extent v)
   in   force $ dftWithRoots rofu v


-- | Compute the inverse DFT along the low order dimension of an array.
--
--   This runs `dftWithRoots` with the inverse roots of unity, then divides
--   every element by the length of the low order dimension.
idft    :: forall sh
        .  Shape sh
        => Array (sh :. Int) Complex    -- ^ Input array.
        -> Array (sh :. Int) Complex

idft v
 = let  _ :. len        = extent v

        -- The length of the low order dimension, as a complex number.
        scale           = (fromIntegral len, 0)
        rofu            = calcInverseRootsOfUnity (extent v)
   in   force $ A.map (/ scale) $ dftWithRoots rofu v


-- | Generic function for computation of forward or inverse DFT.
--   This function is also useful if you transform many arrays with the same extent,
--   and don't want to recompute the roots for each one.
--
--   The extent of the given roots must match that of the input array.
--   Only the lengths of the low order dimension are compared here;
--   if they differ this calls `error`.
dftWithRoots
        :: forall sh
        .  Shape sh
        => Array (sh :. Int) Complex    -- ^ Roots of unity.
        -> Array (sh :. Int) Complex    -- ^ Input array.
        -> Array (sh :. Int) Complex

dftWithRoots rofu arr
        | _ :. rLen     <- extent rofu
        , _ :. vLen     <- extent arr
        , rLen /= vLen
        = error $  "dftWithRoots: length of vector (" P.++ show vLen P.++ ")"
                P.++ " does not match the length of the roots (" P.++ show rLen P.++ ")"

        -- The result has the same extent as the input; each element is
        -- computed independently from its own index.
        -- 'traverse' is qualified so it cannot clash with a Prelude export
        -- of the same name.
        | otherwise
        = A.traverse arr id (\_ k -> dftWithRootsSingle rofu arr k)


-- | Compute a single value of the DFT.
--
--   The value at index @(sh :. k)@ is the sum, over the low order index @n@, of
--   @arr ! (sh :. n)@ multiplied by @roots ! (sh :. (k * n) \`mod\` len)@.
--   Only the row selected by the outer part of the index contributes, so for
--   arrays of rank greater than one the other rows are not mixed into the result.
--
--   The extent of the given roots must match that of the input array.
--   Only the lengths of the low order dimension are compared here;
--   if they differ this calls `error`.
dftWithRootsSingle
        :: forall sh
        .  Shape sh
        => Array (sh :. Int) Complex    -- ^ Roots of unity.
        -> Array (sh :. Int) Complex    -- ^ Input array.
        -> (sh :. Int)                  -- ^ Index of the value we want.
        -> Complex

{-# INLINE dftWithRootsSingle #-}
dftWithRootsSingle rofu arrX (shOuter :. k)
        | _ :. rLen     <- extent rofu
        , _ :. vLen     <- extent arrX
        , rLen /= vLen
        = error $  "dftWithRootsSingle: length of vector (" P.++ show vLen P.++ ")"
                P.++ " does not match the length of the roots (" P.++ show rLen P.++ ")"

        | otherwise
        = let   _ :. len        = extent arrX

                -- One term per element of the selected row: the element times
                -- the root it needs. Summing a rank-1 array of just these terms
                -- keeps the reduction within the row at 'shOuter', rather than
                -- folding over every row of the input.
                terms           = fromFunction (Z :. len) termFn
                termFn (Z :. n)
                        = (arrX ! (shOuter :. n))
                        * (rofu ! (shOuter :. (k * n) `mod` len))

          in    A.sumAll terms