module Data.Array.Repa.Algorithms.FFT
( Mode(..)
, isPowerOfTwo
, fft3d
, fft2d
, fft1d)
where
import Data.Array.Repa.Algorithms.Complex
import Data.Array.Repa as A
-- | Direction of the transform requested from 'fft1d', 'fft2d' and 'fft3d'.
data Mode
  = Forward   -- ^ Twiddle sign @-1@; the result is not scaled.
  | Reverse   -- ^ Twiddle sign @+1@; the result is not scaled.
  | Inverse   -- ^ Twiddle sign @+1@; the result is divided by the number
              --   of elements in the array.
  deriving (Show, Eq)

-- | Sign applied to the imaginary (sine) part of the twiddle factors.
--
--   The forward transform must use the opposite sign from the reverse and
--   inverse transforms, otherwise 'Forward' and 'Reverse' compute the same
--   thing and 'Inverse' does not undo 'Forward'.
signOfMode :: Mode -> Double
signOfMode mode
  = case mode of
      Forward -> (-1)
      Reverse -> 1
      Inverse -> 1
-- | Check whether an 'Int' is an exact power of two (1, 2, 4, 8, ...).
--
--   Uses integer halving only. A test based on @log x / log 2@ is subject
--   to floating-point rounding and can reject genuine powers of two, and it
--   is not meaningful for zero or negative input.
--
--   Returns 'False' for every value less than 1.
isPowerOfTwo :: Int -> Bool
isPowerOfTwo x
  | x < 1     = False
  | otherwise = halve x
  where
    -- Strip factors of two; only a power of two ends at exactly 1.
    halve :: Int -> Bool
    halve 1 = True
    halve n = even n && halve (n `div` 2)
-- | Transform a 3d array of complex numbers.
--
--   Every dimension of the array must be a power of two, otherwise 'error'
--   is called. The transform is done one axis at a time: each call to
--   'fftTrans3d' transforms the innermost axis and rotates the axes, so
--   three calls cover all axes and restore the original axis order.
--   In 'Inverse' mode the result is divided by the number of elements.
fft3d :: Mode
      -> Array DIM3 Complex
      -> Array DIM3 Complex
fft3d mode arr
  = let _ :. depth :. height :. width = extent arr
        !sign    = signOfMode mode
        !scale   = fromIntegral (depth * width * height)
        dimsOk   = all isPowerOfTwo [depth, height, width]
        spectrum = fftTrans3d sign (fftTrans3d sign (fftTrans3d sign arr))
    in  if dimsOk
          then arr `deepSeqArray`
               (case mode of
                  Inverse -> force (A.map (/ scale) spectrum)
                  _       -> spectrum)
          else error "Data.Array.Repa.Algorithms.FFT: fft3d -- array dimensions must be powers of two."
-- | One pass of the 3d transform: force the input, run 'fft' along its
--   innermost axis, then rotate the axes with 'rotate3d' and force the
--   result.
fftTrans3d
  :: Double               -- ^ twiddle sign, as produced by 'signOfMode'
  -> Array DIM3 Complex
  -> Array DIM3 Complex
fftTrans3d sign arr'
  = force (rotate3d (fft sign outer len manifest))
  where
    manifest      = force arr'
    (outer :. len) = extent manifest
-- | Cyclically rotate the three axes of an array.
--
--   An input of extent @(k, l, m)@ yields an output of extent @(m, k, l)@;
--   the output element at index @(m', k', l')@ is the input element at
--   index @(k', l', m')@.
rotate3d :: Array DIM3 Complex -> Array DIM3 Complex
rotate3d arr
  = let (sh :. k :. l :. m) = extent arr
        rotated             = sh :. m :. k :. l
    in  backpermute rotated
          (\(sh' :. m' :. k' :. l') -> sh' :. k' :. l' :. m')
          arr
-- | Transform a 2d array of complex numbers.
--
--   Both dimensions must be powers of two, otherwise 'error' is called.
--   'fftTrans2d' transforms the innermost axis and transposes, so applying
--   it twice covers both axes and restores the original orientation.
--   In 'Inverse' mode the result is divided by the number of elements.
fft2d :: Mode
      -> Array DIM2 Complex
      -> Array DIM2 Complex
fft2d mode arr
  | not dimsOk
  = error "Data.Array.Repa.Algorithms.FFT: fft2d -- array dimensions must be powers of two."
  | otherwise
  = arr `deepSeqArray` result
  where
    _ :. height :. width = extent arr
    dimsOk   = isPowerOfTwo height && isPowerOfTwo width
    sign     = signOfMode mode
    scale    = fromIntegral (width * height)
    spectrum = fftTrans2d sign (fftTrans2d sign arr)
    result   = case mode of
                 Inverse -> force (A.map (/ scale) spectrum)
                 _       -> spectrum
-- | One pass of the 2d transform: force the input, run 'fft' along its
--   innermost axis, then transpose and force the result.
fftTrans2d
  :: Double               -- ^ twiddle sign, as produced by 'signOfMode'
  -> Array DIM2 Complex
  -> Array DIM2 Complex
fftTrans2d sign arr'
  = force (transpose (fft sign outer len manifest))
  where
    manifest       = force arr'
    (outer :. len) = extent manifest
-- | Transform a vector of complex numbers.
--
--   The length must be a power of two, otherwise 'error' is called.
--   In 'Inverse' mode the result is divided by the vector length.
fft1d :: Mode
      -> Array DIM1 Complex
      -> Array DIM1 Complex
fft1d mode arr
  | isPowerOfTwo len
  = arr `deepSeqArray` result
  | otherwise
  = error "Data.Array.Repa.Algorithms.FFT: fft1d -- array dimensions must be powers of two."
  where
    _ :. len = extent arr
    sign     = signOfMode mode
    scale    = fromIntegral len
    spectrum = fftTrans1d sign arr
    result   = case mode of
                 Inverse -> force (A.map (/ scale) spectrum)
                 _       -> spectrum
-- | The single pass of the 1d transform: force the input and run 'fft'
--   along its only axis.
fftTrans1d
  :: Double               -- ^ twiddle sign, as produced by 'signOfMode'
  -> Array DIM1 Complex
  -> Array DIM1 Complex
fftTrans1d sign arr'
  = fft sign outer len manifest
  where
    manifest       = force arr'
    (outer :. len) = extent manifest
-- | Recursive radix-2 transform along the innermost axis of an array.
--
--   Arguments, in order:
--
--   * @sign@   -- sign passed on to 'twiddle'.
--   * @sh@     -- extent of the array without its innermost axis.
--   * @lenVec@ -- length of the innermost axis; the callers in this module
--                 check that it is a power of two.
--   * @vec@    -- the array to transform; it is read with 'unsafeIndex',
--                 so indices are not bounds checked.
--
--   @go len offset stride@ transforms the sub-sequence of the innermost
--   axis that starts at @offset@ and takes every @stride@-th element. It
--   splits that sub-sequence into its even and odd members, transforms
--   both, and merges them with 'combine'.
--
--   Arrays whose innermost axis is shorter than 2 are returned unchanged:
--   the transform of a single element is that element, and without this
--   guard the recursion would halve the length to 0 and never terminate.
fft !sign !sh !lenVec !vec
  | lenVec < 2 = vec
  | otherwise  = go lenVec 0 1
  where
    go !len !offset !stride
      | len == 2
      = force $ fromFunction (sh :. 2) swivel
      | otherwise
      = combine len
          (go (len `div` 2) offset            (stride * 2))
          (go (len `div` 2) (offset + stride) (stride * 2))
      where
        -- Base case for two elements: the sum at index 0 and the
        -- difference at index 1. The extent is (sh :. 2), so the
        -- wildcard branch is only reached for index 1.
        swivel (sh' :. ix)
          = case ix of
              0 -> (vec `unsafeIndex` (sh' :. offset)) + (vec `unsafeIndex` (sh' :. (offset + stride)))
              _ -> (vec `unsafeIndex` (sh' :. offset)) - (vec `unsafeIndex` (sh' :. (offset + stride)))

        -- Merge the transforms of the even and odd members: multiply the
        -- odd half by the twiddle factors, then append the element-wise
        -- sum and difference. The patterns only match manifest arrays,
        -- which is what 'go' returns since both of its branches 'force'.
        combine !len' evens@(Array _ [Region RangeAll GenManifest{}])
                      odds@(Array _ [Region RangeAll GenManifest{}])
          = evens `deepSeqArray` odds `deepSeqArray`
            let odds' = unsafeTraverse odds id (\get ix@(_ :. k) -> twiddle sign k len' * get ix)
            in  force $ (evens +^ odds') A.++ (evens -^ odds')
-- | Twiddle factor for index @k@ of a transform of length @n@: the complex
--   number with real part @cos (2*pi*k/n)@ and imaginary part
--   @sign * sin (2*pi*k/n)@.
twiddle :: Double   -- ^ sign of the imaginary part
        -> Int      -- ^ index @k@
        -> Int      -- ^ transform length @n@
        -> Complex
twiddle sign k' n'
  = (cos theta, sign * sin theta)
  where
    theta = 2 * pi * fromIntegral k' / fromIntegral n'