{-# LANGUAGE TypeOperators, PatternGuards, RankNTypes, ScopedTypeVariables, BangPatterns, FlexibleContexts #-}
{-# OPTIONS -fno-warn-incomplete-patterns #-}

-- | Fast computation of Discrete Fourier Transforms using the Cooley-Tukey algorithm.
--   Time complexity is O(n log n) in the size of the input.
--
--   This uses a naive divide-and-conquer algorithm; its absolute performance is
--   about 50x slower than FFTW in estimate mode.
--
module Data.Array.Repa.Algorithms.FFT
        ( Mode(..)
        , isPowerOfTwo
        , fft3dP
        , fft2dP
        , fft1dP)
where
import Data.Array.Repa.Algorithms.Complex
import Data.Array.Repa                  as R
import Data.Array.Repa.Eval             as R
import Data.Array.Repa.Unsafe           as R
import Prelude                          as P


-- | Direction of a transform.
--
--   'Forward' uses sign @-1@ on the imaginary part of each twiddle factor.
--   'Reverse' uses sign @+1@ and does no normalisation.
--   'Inverse' uses sign @+1@ and then divides every element by the total
--   number of elements in the array.
data Mode
        = Forward
        | Reverse
        | Inverse
        deriving (Show, Eq)


-- | Sign applied to the imaginary part of the twiddle factors for a mode.
signOfMode :: Mode -> Double
signOfMode mode
 = case mode of
        Forward -> (-1)
        Reverse -> 1
        Inverse -> 1
{-# INLINE signOfMode #-}


-- | Check if an `Int` is a power of two (1, 2, 4, 8, ...).
--
--   Zero and negative numbers are not powers of two. Rejecting zero matters:
--   the transform worker only terminates for lengths that are powers of two.
isPowerOfTwo :: Int -> Bool
isPowerOfTwo n
        | n <= 0        = False
        | n == 1        = True
        | even n        = isPowerOfTwo (n `div` 2)
        | otherwise     = False
{-# INLINE isPowerOfTwo #-}


-- 3D Transform -----------------------------------------------------------------------------------
-- | Compute the DFT of a 3d array.
--   Array dimensions must be powers of two (1 is allowed), else `error`.
fft3dP  :: (Source r Complex, Monad m)
        => Mode
        -> Array r DIM3 Complex
        -> m (Array U DIM3 Complex)
fft3dP mode arr
 = let  _ :. depth :. height :. width = extent arr
        !sign   = signOfMode mode
        !scale  = fromIntegral (depth * width * height)
   in   if not (isPowerOfTwo depth && isPowerOfTwo height && isPowerOfTwo width)
         then error $ unlines
                [ "Data.Array.Repa.Algorithms.FFT: fft3d"
                , " Array dimensions must be powers of two,"
                , " but the provided array is "
                        P.++ show depth P.++ "x" P.++ show height P.++ "x" P.++ show width ]
         else arr `deepSeqArray`
              -- Each pass transforms the innermost dimension and then rotates
              -- the dimensions, so three passes cover all three axes.
              let transformed = fftTrans3d sign $ fftTrans3d sign $ fftTrans3d sign arr
              in  case mode of
                        Forward -> now transformed
                        Reverse -> now transformed
                        Inverse -> computeP $ R.map (/ scale) transformed
{-# INLINE fft3dP #-}


-- | One pass of the 3D transform: run the 1D FFT along the innermost
--   dimension, then rotate the dimensions with 'rotate3d'.
fftTrans3d
        :: Source r Complex
        => Double
        -> Array r DIM3 Complex
        -> Array U DIM3 Complex
fftTrans3d sign arr
 = let  (sh :. len) = extent arr
   in   suspendedComputeP $ rotate3d $ fft sign sh len arr
{-# INLINE fftTrans3d #-}


-- | Rotate the three dimensions of an array.
--   The element at index @(k, l, m)@ of the input ends up at index @(m, k, l)@
--   of the result.
rotate3d
        :: Source r Complex
        => Array r DIM3 Complex -> Array D DIM3 Complex
rotate3d arr
 = backpermute (sh :. m :. k :. l) f arr
 where  (sh :. k :. l :. m)             = extent arr
        f (sh' :. m' :. k' :. l')       = sh' :. k' :. l' :. m'
{-# INLINE rotate3d #-}


-- Matrix Transform -------------------------------------------------------------------------------
-- | Compute the DFT of a matrix.
--   Array dimensions must be powers of two (1 is allowed), else `error`.
fft2dP  :: (Source r Complex, Monad m)
        => Mode
        -> Array r DIM2 Complex
        -> m (Array U DIM2 Complex)
fft2dP mode arr
 = let  _ :. height :. width = extent arr
        sign    = signOfMode mode
        scale   = fromIntegral (width * height)
   in   if not (isPowerOfTwo height && isPowerOfTwo width)
         then error $ unlines
                [ "Data.Array.Repa.Algorithms.FFT: fft2d"
                , " Array dimensions must be powers of two,"
                , " but the provided array is " P.++ show height P.++ "x" P.++ show width ]
         else arr `deepSeqArray`
              -- Each pass transforms rows and then transposes, so two passes
              -- cover both axes and restore the original orientation.
              let transformed = fftTrans2d sign $ fftTrans2d sign arr
              in  case mode of
                        Forward -> now transformed
                        Reverse -> now transformed
                        Inverse -> computeP $ R.map (/ scale) transformed
{-# INLINE fft2dP #-}


-- | One pass of the 2D transform: run the 1D FFT along the innermost
--   dimension, then transpose the matrix.
fftTrans2d
        :: Source r Complex
        => Double
        -> Array r DIM2 Complex
        -> Array U DIM2 Complex
fftTrans2d sign arr
 = let  (sh :. len) = extent arr
   in   suspendedComputeP $ transpose $ fft sign sh len arr
{-# INLINE fftTrans2d #-}


-- Vector Transform -------------------------------------------------------------------------------
-- | Compute the DFT of a vector.
--   Array dimensions must be powers of two (1 is allowed), else `error`.
fft1dP  :: (Source r Complex, Monad m)
        => Mode
        -> Array r DIM1 Complex
        -> m (Array U DIM1 Complex)
fft1dP mode arr
 = let  _ :. len = extent arr
        sign    = signOfMode mode
        scale   = fromIntegral len
   in   if not $ isPowerOfTwo len
         then error $ unlines
                [ "Data.Array.Repa.Algorithms.FFT: fft1d"
                , " Array dimensions must be powers of two, "
                , " but the provided array is " P.++ show len ]
         else arr `deepSeqArray`
              let transformed = fftTrans1d sign arr
              in  case mode of
                        Forward -> now transformed
                        Reverse -> now transformed
                        Inverse -> computeP $ R.map (/ scale) transformed
{-# INLINE fft1dP #-}


-- | The 1D transform: run the FFT along the only dimension.
fftTrans1d
        :: Source r Complex
        => Double
        -> Array r DIM1 Complex
        -> Array U DIM1 Complex
fftTrans1d sign arr
 = let  (sh :. len) = extent arr
   in   fft sign sh len arr
{-# INLINE fftTrans1d #-}


-- Rank Generalised Worker ------------------------------------------------------------------------
-- | Radix-2 FFT along the innermost dimension of an array of any rank.
--
--   The sub-vector at a given @offset@ and @stride@ is split into its
--   even-indexed part (same offset, doubled stride) and its odd-indexed part
--   (offset shifted by the stride, doubled stride). Each part is transformed
--   recursively, and the results are merged with twiddle factors.
--
--   The caller must ensure the length is a power of two (including 1).
--   Otherwise the recursion never reaches a base case.
fft     :: (Shape sh, Source r Complex)
        => Double -> sh -> Int
        -> Array r (sh :. Int) Complex
        -> Array U (sh :. Int) Complex

fft !sign !sh !lenVec !vec
 = go lenVec 0 1
 where
        -- Transform the 'len' elements at indices
        -- offset, offset + stride, offset + 2 * stride, ...
        go !len !offset !stride
         | len == 1
         = suspendedComputeP $ fromFunction (sh :. 1) single

         | len == 2
         = suspendedComputeP $ fromFunction (sh :. 2) swivel

         | otherwise
         = combine len
                (go (len `div` 2) offset            (stride * 2))
                (go (len `div` 2) (offset + stride) (stride * 2))

         where
                -- The DFT of a single element is the element itself.
                single (sh' :. _)
                 = vec `unsafeIndex` (sh' :. offset)

                -- Two-point butterfly: index 0 is the sum, index 1 the difference.
                swivel (sh' :. ix)
                 = let  a = vec `unsafeIndex` (sh' :. offset)
                        b = vec `unsafeIndex` (sh' :. (offset + stride))
                   in   if ix == 0 then a + b else a - b

        -- Merge the transforms of the even and odd halves: the first half of
        -- the result is evens + w*odds, the second half is evens - w*odds.
        {-# INLINE combine #-}
        combine !len' evens odds
         = evens `deepSeqArray` odds `deepSeqArray`
           let  odds' = unsafeTraverse odds id
                          (\get ix@(_ :. k) -> twiddle sign k len' * get ix)
           in   suspendedComputeP $ (evens +^ odds') R.++ (evens -^ odds')
{-# INLINE fft #-}


-- | Compute a twiddle factor @(cos t, sign * sin t)@ where @t = 2 * pi * k / n@.
twiddle :: Double
        -> Int                  -- ^ index
        -> Int                  -- ^ length
        -> Complex
twiddle sign k' n'
        = (cos (2 * pi * k / n), sign * sin (2 * pi * k / n))
        where   k       = fromIntegral k'
                n       = fromIntegral n'
{-# INLINE twiddle #-}