module Data.Array.Repa.Algorithms.FFT
( Mode(..)
, isPowerOfTwo
, fft3dP
, fft2dP
, fft1dP)
where
import Data.Array.Repa.Algorithms.Complex
import Data.Array.Repa as R
import Data.Array.Repa.Eval as R
import Data.Array.Repa.Unsafe as R
import Prelude as P
-- | Which transform to perform.
data Mode
        = Forward       -- ^ Forward transform.
        | Reverse       -- ^ Uses the same kernel sign as 'Inverse', but the result is not rescaled.
        | Inverse       -- ^ Like 'Reverse', with the result divided by the total number of elements.
        deriving (Eq, Show)
-- | Sign of the exponent used in the twiddle factors for each 'Mode'.
--
--   The twiddle factor is @(cos theta, sign * sin theta)@, i.e.
--   @exp (sign * i * theta)@. The forward transform therefore needs the
--   negative exponent. 'Reverse' and 'Inverse' both use the positive one;
--   'Inverse' additionally rescales the result.
signOfMode :: Mode -> Double
signOfMode mode
        = case mode of
                Forward -> (-1)
                Reverse -> 1
                Inverse -> 1
-- | Check whether an 'Int' is a power of two accepted by the transforms in
--   this module: 2, 4, 8, 16, ...
--
--   Zero and negative numbers give 'False'. Previously 0 was accepted, which
--   let zero-length dimensions through validation and made the FFT recursion
--   diverge. 1 (= 2^0) also gives 'False', so every dimension passed to
--   'fft1dP', 'fft2dP' or 'fft3dP' must have at least two elements.
isPowerOfTwo :: Int -> Bool
isPowerOfTwo n
        | n < 2         = False
        | n == 2        = True
        | even n        = isPowerOfTwo (n `div` 2)
        | otherwise     = False
-- | Compute the DFT of a 3D array, in parallel.
--
--   All three dimensions must be powers of two (see 'isPowerOfTwo');
--   otherwise this calls 'error'. Each pass transforms the innermost
--   dimension and then rotates the dimensions. After three passes every axis
--   has been transformed and the original orientation is restored.
--   In 'Inverse' mode the result is divided by the total number of elements.
fft3dP  :: (Source r Complex, Monad m)
        => Mode
        -> Array r DIM3 Complex
        -> m (Array U DIM3 Complex)
fft3dP mode arr
 = let  _ :. depth :. height :. width = extent arr
        !sign   = signOfMode mode
        !scale  = fromIntegral (depth * width * height)

        -- Three transform-and-rotate passes, one per axis.
        transformed = fftTrans3d sign $ fftTrans3d sign $ fftTrans3d sign arr

   in   if not (isPowerOfTwo depth && isPowerOfTwo height && isPowerOfTwo width)
         then error $ unlines
                [ "Data.Array.Repa.Algorithms.FFT: fft3d"
                , " Array dimensions must be powers of two,"
                -- Report the extent in the same outer-to-inner order as the
                -- shape itself (depth :. height :. width).
                , " but the provided array is "
                        P.++ show depth P.++ "x" P.++ show height P.++ "x" P.++ show width ]
         else arr `deepSeqArray`
              case mode of
                Forward -> now transformed
                Reverse -> now transformed
                Inverse -> computeP $ R.map (/ scale) transformed
-- | One pass of the 3D transform: FFT along the innermost dimension, then
--   rotate the dimensions (see 'rotate3d') so the next pass works on a
--   different axis.
fftTrans3d
        :: Source r Complex
        => Double
        -> Array r DIM3 Complex
        -> Array U DIM3 Complex
fftTrans3d sign arr
        = suspendedComputeP (rotate3d transformed)
        where
          outer :. n   = extent arr
          transformed  = fft sign outer n arr
-- | Cyclically rotate the three innermost dimensions of an array.
--   An input of extent @sh :. k :. l :. m@ produces an output of extent
--   @sh :. m :. k :. l@. The result element at @rest :. m' :. k' :. l'@ is
--   read from input index @rest :. k' :. l' :. m'@.
rotate3d
        :: Source r Complex
        => Array r DIM3 Complex -> Array D DIM3 Complex
rotate3d arr
        = backpermute rotatedExtent toSource arr
        where
          sh :. k :. l :. m = extent arr
          rotatedExtent     = sh :. m :. k :. l
          toSource (rest :. m' :. k' :. l') = rest :. k' :. l' :. m'
-- | Compute the DFT of a matrix, in parallel.
--
--   Both dimensions must be powers of two (see 'isPowerOfTwo'); otherwise
--   this calls 'error'. Each pass transforms the innermost dimension and then
--   transposes. After two passes both axes are transformed and the original
--   orientation is restored. In 'Inverse' mode the result is divided by the
--   number of elements.
fft2dP  :: (Source r Complex, Monad m)
        => Mode
        -> Array r DIM2 Complex
        -> m (Array U DIM2 Complex)
fft2dP mode arr
        | isPowerOfTwo height && isPowerOfTwo width
        = arr `deepSeqArray` finish mode

        | otherwise
        = error $ unlines
                [ "Data.Array.Repa.Algorithms.FFT: fft2d"
                , " Array dimensions must be powers of two,"
                , " but the provided array is " P.++ show height P.++ "x" P.++ show width ]
        where
          _ :. height :. width = extent arr
          sign          = signOfMode mode
          scale         = fromIntegral (width * height)

          -- Two transform-and-transpose passes, one per axis.
          bothAxes      = fftTrans2d sign (fftTrans2d sign arr)

          finish Inverse = computeP (R.map (/ scale) bothAxes)
          finish _       = now bothAxes
-- | One pass of the 2D transform: FFT along the innermost dimension, then
--   transpose so the next pass works on the other axis.
fftTrans2d
        :: Source r Complex
        => Double
        -> Array r DIM2 Complex
        -> Array U DIM2 Complex
fftTrans2d sign arr
        = suspendedComputeP . transpose $ fft sign outer n arr
        where outer :. n = extent arr
-- | Compute the DFT of a vector, in parallel.
--
--   The length must be a power of two (see 'isPowerOfTwo'); otherwise this
--   calls 'error'. In 'Inverse' mode the result is divided by the length.
fft1dP  :: (Source r Complex, Monad m)
        => Mode
        -> Array r DIM1 Complex
        -> m (Array U DIM1 Complex)
fft1dP mode arr
        = if isPowerOfTwo len
            then arr `deepSeqArray` run mode
            else error $ unlines
                    [ "Data.Array.Repa.Algorithms.FFT: fft1d"
                    , " Array dimensions must be powers of two, "
                    , " but the provided array is " P.++ show len ]
        where
          _ :. len      = extent arr
          scale         = fromIntegral len
          spectrum      = fftTrans1d (signOfMode mode) arr

          run Inverse   = computeP $ R.map (/ scale) spectrum
          run _         = now spectrum
-- | The 1D transform: a single FFT along the only dimension. No reordering
--   of axes is needed.
fftTrans1d
        :: Source r Complex
        => Double
        -> Array r DIM1 Complex
        -> Array U DIM1 Complex
fftTrans1d sign arr
        = fft sign outer n arr
        where outer :. n = extent arr
-- | Radix-2 decimation-in-time FFT along the innermost dimension of an array.
--   The transform is applied independently to every vector selected by the
--   outer shape @sh@.
--
--   @sign@ is the sign of the exponent in the twiddle factors (see 'twiddle').
--   @lenVec@ is the length of the innermost dimension; callers are expected
--   to ensure it is a power of two. Lengths 0 and 1 are handled directly
--   (the DFT of a single point is that point), so the recursion always
--   terminates.
fft     :: (Shape sh, Source r Complex)
        => Double -> sh -> Int
        -> Array r (sh :. Int) Complex
        -> Array U (sh :. Int) Complex
fft !sign !sh !lenVec !vec
  = go lenVec 0 1
  where
    -- go len offset stride: transform of the len-element subsequence
    -- vec[offset], vec[offset + stride], vec[offset + 2*stride], ...
    go !len !offset !stride
      | len <= 1
      = suspendedComputeP $ fromFunction (sh :. len) single

      | len == 2
      = suspendedComputeP $ fromFunction (sh :. 2) swivel

      | otherwise
      = combine len
          (go (len `div` 2) offset            (stride * 2))
          (go (len `div` 2) (offset + stride) (stride * 2))

      where
        -- Zero- or one-point transform: copy the (at most one) element.
        single (sh' :. _) = vec `unsafeIndex` (sh' :. offset)

        -- Two-point butterfly: (x0 + x1, x0 - x1).
        swivel (sh' :. ix)
          = case ix of
              0 -> (vec `unsafeIndex` (sh' :. offset)) + (vec `unsafeIndex` (sh' :. (offset + stride)))
              1 -> (vec `unsafeIndex` (sh' :. offset)) - (vec `unsafeIndex` (sh' :. (offset + stride)))

    -- Merge the transforms of the even- and odd-indexed halves: the odd half
    -- is multiplied by the twiddle factors, then the sums form the first half
    -- of the result and the differences the second half.
    combine !len' evens odds
      = evens `deepSeqArray` odds `deepSeqArray`
        let odds' = unsafeTraverse odds id (\get ix@(_ :. k) -> twiddle sign k len' * get ix)
        in  suspendedComputeP $ (evens +^ odds') R.++ (evens -^ odds')
-- | Twiddle factor for butterfly index @k@ of an @n@-point transform:
--   the pair @(cos theta, sign * sin theta)@ with @theta = 2 * pi * k / n@.
twiddle :: Double
        -> Int
        -> Int
        -> Complex
twiddle sign k n
        = (cos theta, sign * sin theta)
        where theta = 2 * pi * fromIntegral k / fromIntegral n