{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE PatternGuards #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeOperators #-}
-- The local helpers `swivel` and `combine` in the worker are deliberately
-- partial, so the incomplete-pattern warning is switched off for this module.
{-# OPTIONS_GHC -fno-warn-incomplete-patterns #-}

-- | Fast computation of Discrete Fourier Transforms using the Cooley-Tukey algorithm.
--   Time complexity is O(n log n) in the size of the input.
--
--   This uses a naive divide-and-conquer algorithm, the absolute performance is about
--   50x slower than FFTW in estimate mode.
--
module Data.Array.Repa.Algorithms.FFT
  ( Mode(..)
  , isPowerOfTwo
  , fft3d
  , fft2d
  , fft1d
  ) where

import Data.Array.Repa.Algorithms.Complex
import Data.Array.Repa as A
import Data.Bits ((.&.))


-- | Which transform to compute.
data Mode
  = Forward   -- ^ Twiddle sign of -1, result is not rescaled.
  | Reverse   -- ^ Twiddle sign of +1, result is not rescaled.
  | Inverse   -- ^ Twiddle sign of +1, result is divided by the number of elements.
  deriving (Show, Eq)


-- | Sign applied to the imaginary part of the twiddle factors for a given mode.
{-# INLINE signOfMode #-}
signOfMode :: Mode -> Double
signOfMode mode =
  case mode of
    Forward -> (-1)
    Reverse -> 1
    Inverse -> 1


-- | Check if an `Int` is a power of two.
--
--   The test uses integer bit arithmetic rather than floating-point
--   logarithms, so it is exact for every `Int`. Zero and negative values are
--   rejected; 1 (that is, 2^0) is accepted.
{-# INLINE isPowerOfTwo #-}
isPowerOfTwo :: Int -> Bool
isPowerOfTwo x = x > 0 && x .&. (x - 1) == 0


-- 3D Transform -----------------------------------------------------------------------------------

-- | Compute the DFT of a 3d array. Array dimensions must be powers of two else `error`.
fft3d :: Mode -> Array DIM3 Complex -> Array DIM3 Complex
fft3d mode arr =
  let _ :. depth :. height :. width = extent arr
      !sign  = signOfMode mode
      !scale = fromIntegral (depth * width * height)
   in if not (isPowerOfTwo depth && isPowerOfTwo height && isPowerOfTwo width)
        then error "Data.Array.Repa.Algorithms.FFT: fft3d -- array dimensions must be powers of two."
        else
          -- Three transform-and-rotate passes cover all three axes and bring
          -- the dimensions back to their original order.
          arr `deepSeqArray` case mode of
            Forward -> fftTrans3d sign $ fftTrans3d sign $ fftTrans3d sign arr
            Reverse -> fftTrans3d sign $ fftTrans3d sign $ fftTrans3d sign arr
            Inverse -> force $ A.map (/ scale) $ fftTrans3d sign $ fftTrans3d sign $ fftTrans3d sign arr


-- | Transform along the innermost dimension, then rotate the dimensions
--   so the next pass works on a different axis.
fftTrans3d :: Double -> Array DIM3 Complex -> Array DIM3 Complex
{-# NOINLINE fftTrans3d #-}
fftTrans3d sign arr' =
  let arr         = force arr'
      (sh :. len) = extent arr
   in force $ rotate3d $ fft sign sh len arr


-- | Cyclically rotate the three dimensions: an array of extent
--   @(sh :. k :. l :. m)@ becomes one of extent @(sh :. m :. k :. l)@.
rotate3d :: Array DIM3 Complex -> Array DIM3 Complex
{-# INLINE rotate3d #-}
rotate3d arr = backpermute (sh :. m :. k :. l) f arr
  where
    (sh :. k :. l :. m) = extent arr

    -- Map an index of the rotated array back to the source index.
    f (sh' :. m' :. k' :. l') = sh' :. k' :. l' :. m'


-- Matrix Transform -------------------------------------------------------------------------------

-- | Compute the DFT of a matrix. Array dimensions must be powers of two else `error`.
fft2d :: Mode -> Array DIM2 Complex -> Array DIM2 Complex
fft2d mode arr =
  let _ :. height :. width = extent arr
      sign  = signOfMode mode
      scale = fromIntegral (width * height)
   in if not (isPowerOfTwo height && isPowerOfTwo width)
        then error "Data.Array.Repa.Algorithms.FFT: fft2d -- array dimensions must be powers of two."
        else
          -- Two transform-and-transpose passes cover both axes and restore
          -- the original orientation.
          arr `deepSeqArray` case mode of
            Forward -> fftTrans2d sign $ fftTrans2d sign arr
            Reverse -> fftTrans2d sign $ fftTrans2d sign arr
            Inverse -> force $ A.map (/ scale) $ fftTrans2d sign $ fftTrans2d sign arr


-- | Transform along the innermost dimension, then transpose.
fftTrans2d :: Double -> Array DIM2 Complex -> Array DIM2 Complex
{-# NOINLINE fftTrans2d #-}
fftTrans2d sign arr' =
  let arr         = force arr'
      (sh :. len) = extent arr
   in force $ transpose $ fft sign sh len arr


-- Vector Transform -------------------------------------------------------------------------------

-- | Compute the DFT of a vector. Array dimensions must be powers of two else `error`.
fft1d :: Mode -> Array DIM1 Complex -> Array DIM1 Complex
fft1d mode arr =
  let _ :. len = extent arr
      sign  = signOfMode mode
      scale = fromIntegral len
   in if not $ isPowerOfTwo len
        then error "Data.Array.Repa.Algorithms.FFT: fft1d -- array dimensions must be powers of two."
        else
          arr `deepSeqArray` case mode of
            Forward -> fftTrans1d sign arr
            Reverse -> fftTrans1d sign arr
            Inverse -> force $ A.map (/ scale) $ fftTrans1d sign arr


-- | Force the input and transform along its only dimension.
fftTrans1d :: Double -> Array DIM1 Complex -> Array DIM1 Complex
{-# NOINLINE fftTrans1d #-}
fftTrans1d sign arr' =
  let arr         = force arr'
      (sh :. len) = extent arr
   in fft sign sh len arr


-- Rank Generalised Worker ------------------------------------------------------------------------

-- | Transform every vector along the innermost dimension of an array.
--
--   Arguments are the twiddle sign, the extent of the outer dimensions, the
--   length of the innermost dimension and the array itself. The length must
--   be a positive power of two; the exported wrappers check this before
--   calling.
{-# INLINE fft #-}
fft :: Shape sh
    => Double
    -> sh
    -> Int
    -> Array (sh :. Int) Complex
    -> Array (sh :. Int) Complex
fft !sign !sh !lenVec !vec
  -- The DFT of a single sample is that sample. Without this case the
  -- recursion below would halve 1 down to 0 and never terminate.
  | lenVec == 1 = vec
  | otherwise   = go lenVec 0 1
  where
    -- Transform the sub-sequence of `len` samples that starts at `offset`
    -- and takes every `stride`-th element of the innermost dimension.
    go !len !offset !stride
      | len == 2 = force $ fromFunction (sh :. 2) swivel
      | otherwise =
          combine len
            (go (len `div` 2) offset            (stride * 2))
            (go (len `div` 2) (offset + stride) (stride * 2))
      where
        -- Base case for two samples: their sum and their difference.
        swivel (sh' :. ix) =
          case ix of
            0 -> (vec `unsafeIndex` (sh' :. offset)) + (vec `unsafeIndex` (sh' :. (offset + stride)))
            1 -> (vec `unsafeIndex` (sh' :. offset)) - (vec `unsafeIndex` (sh' :. (offset + stride)))

        -- Merge the transforms of the even- and odd-indexed samples.
        -- The patterns only accept manifest arrays; every result of `go`
        -- is produced by `force`, so both arguments satisfy them.
        {-# INLINE combine #-}
        combine !len'
                evens@(Array _ [Region RangeAll GenManifest{}])
                odds@(Array _ [Region RangeAll GenManifest{}]) =
          evens `deepSeqArray` odds `deepSeqArray`
            let odds' = unsafeTraverse odds id (\get ix@(_ :. k) -> twiddle sign k len' * get ix)
             in force $ (evens +^ odds') A.++ (evens -^ odds')


-- | Compute a twiddle factor: @(cos t, sign * sin t)@ with @t = 2 * pi * k / n@.
twiddle :: Double   -- ^ sign of the imaginary part
        -> Int      -- ^ index
        -> Int      -- ^ length
        -> Complex
{-# INLINE twiddle #-}
twiddle sign k' n' = (cos theta, sign * sin theta)
  where
    theta = 2 * pi * k / n
    k     = fromIntegral k'
    n     = fromIntegral n'