module Data.Array.Repa.Algorithms.FFT
( fft, ifft
, fft2d, ifft2d
, fft3d, ifft3d
, fftWithRoots )
where
import Data.Array.Repa.Algorithms.DFT.Roots
import Data.Array.Repa.Algorithms.Complex
import Data.Array.Repa as A
import Data.Ratio
-- | Transform an array along its last index.
--
--   The roots are obtained from 'calcRootsOfUnity' applied to the extent of
--   the input; the transform itself is delegated to 'fftWithRoots' and the
--   result is passed through 'force'.
fft :: Shape sh
    => Array (sh :. Int) Complex
    -> Array (sh :. Int) Complex
fft v = force (fftWithRoots roots v)
  where
    roots = calcRootsOfUnity (extent v)
-- | Inverse counterpart of 'fft', acting along the last index.
--
--   Runs 'fftWithRoots' with the roots from 'calcInverseRootsOfUnity' and
--   then divides every element by the length of the last dimension
--   (as a complex number with zero imaginary part).
ifft :: Shape sh
     => Array (sh :. Int) Complex
     -> Array (sh :. Int) Complex
ifft v = force (A.map (/ norm) (fftWithRoots roots v))
  where
    _ :. n = extent v
    norm   = fromIntegral n :*: 0
    roots  = calcInverseRootsOfUnity (extent v)
-- | Transform a square matrix in both dimensions.
--
--   Calls 'error' when the two dimensions differ. Otherwise the matrix is
--   transformed along its last index and transposed, twice in a row, so that
--   each dimension is transformed once and the original orientation is
--   restored.
fft2d :: Array DIM2 Complex
      -> Array DIM2 Complex
fft2d arr
  | rows /= cols
  = error $ "fft2d: height of matrix (" ++ show rows ++ ")"
         ++ " does not match width (" ++ show cols ++ ")"
  | otherwise
  = force (pass (pass arr))
  where
    Z :. rows :. cols = extent arr
    roots = calcRootsOfUnity (extent arr)
    -- One pass: transform along the last index, then swap the two axes.
    pass  = transpose . fftWithRoots roots
-- | Inverse counterpart of 'fft2d' for a square matrix.
--
--   Calls 'error' when the two dimensions differ. Otherwise the matrix is
--   transformed along its last index and transposed, twice in a row, using
--   the roots from 'calcInverseRootsOfUnity'; every element is then divided
--   by @height * width@.
ifft2d :: Array DIM2 Complex
       -> Array DIM2 Complex
ifft2d arr
  | height /= width
  -- Report the failure under this function's own name, not fft2d's.
  = error $ "ifft2d: height of matrix (" ++ show height ++ ")"
         ++ " does not match width (" ++ show width ++ ")"
  | otherwise
  = force $ A.map (/ scale) $ fftTrans $ fftTrans arr
  where
    Z :. height :. width = extent arr
    scale    = fromIntegral (height * width) :*: 0
    rofu     = calcInverseRootsOfUnity (extent arr)
    -- One pass: transform along the last index, then swap the two axes.
    fftTrans = transpose . fftWithRoots rofu
-- | Transform a cube-shaped array in all three dimensions.
--
--   Calls 'error' unless depth, height and width are all equal. Otherwise
--   three passes are made; each pass transforms along the last index and then
--   cycles the axes, so every dimension is transformed once and the axes end
--   up in their original order.
fft3d :: Array DIM3 Complex
      -> Array DIM3 Complex
fft3d arrIn
  | height == width && height == depth
  = force (pass (pass (pass arrIn)))
  | otherwise
  = error $ "fft3d: array is not a cube"
  where
    Z :. depth :. height :. width = extent arrIn
    roots = calcRootsOfUnity (extent arrIn)

    -- Cycle the axes: element (l, m, k) of the result is element (k, l, m)
    -- of the argument.
    cycleAxes arr
      = traverse arr
          (\(Z :. k :. l :. m) -> Z :. l :. m :. k)
          (\look (Z :. l :. m :. k) -> look (Z :. k :. l :. m))

    pass = cycleAxes . fftWithRoots roots
-- | Inverse counterpart of 'fft3d' for a cube-shaped array.
--
--   Calls 'error' unless depth, height and width are all equal. Otherwise
--   three transform-and-cycle passes are made with the roots from
--   'calcInverseRootsOfUnity', and every element is finally divided by
--   @height * width * depth@.
ifft3d :: Array DIM3 Complex
       -> Array DIM3 Complex
ifft3d arrIn
  | height == width && height == depth
  = force (A.map (/ norm) (pass (pass (pass arrIn))))
  | otherwise
  = error $ "ifft3d: array is not a cube"
  where
    Z :. depth :. height :. width = extent arrIn
    roots = calcInverseRootsOfUnity (extent arrIn)
    norm  = fromIntegral (height * width * depth) :*: 0

    -- Cycle the axes: element (l, m, k) of the result is element (k, l, m)
    -- of the argument.
    cycleAxes arr
      = traverse arr
          (\(Z :. k :. l :. m) -> Z :. l :. m :. k)
          (\look (Z :. l :. m :. k) -> look (Z :. k :. l :. m))

    pass = cycleAxes . fftWithRoots roots
-- | Validate the arguments and run the transform along the last index,
--   using the supplied roots.
--
--   Calls 'error' when the length of the last dimension of the input is not
--   a power of two, or when it differs from the length of the last dimension
--   of the roots array.
--
--   The power-of-two test is done in exact integer arithmetic. A test based
--   on a floating-point logarithm lets length 0 through (the logarithm is
--   negative infinity, whose rational form has denominator 1) and may reject
--   valid lengths through rounding.
fftWithRoots
    :: forall sh
    .  Shape sh
    => Array (sh :. Int) Complex    -- ^ Roots, same last-dimension length as the input.
    -> Array (sh :. Int) Complex    -- ^ Input values.
    -> Array (sh :. Int) Complex
fftWithRoots rofu v
  | not (isPowerOfTwo vLen)
  = error $ "fft: vector length of " ++ show vLen ++ " is not a power of 2"
  | rLen /= vLen
  = error $ "fft: length of vector (" ++ show vLen ++ ")"
         ++ " does not match the length of the roots (" ++ show rLen ++ ")"
  -- A length-1 transform is the identity; the worker only bottoms out at
  -- length 2, so this case must not be passed on to it.
  | vLen == 1
  = v
  | otherwise
  = fftWithRoots' rofu v
  where
    _ :. rLen = extent rofu
    _ :. vLen = extent v

    -- True exactly for 1, 2, 4, 8, ...; False for zero and negative lengths.
    isPowerOfTwo :: Int -> Bool
    isPowerOfTwo n
      | n < 1     = False
      | n == 1    = True
      | otherwise = even n && isPowerOfTwo (n `div` 2)
-- | Recursive worker behind 'fftWithRoots'; performs no argument checking.
--
--   A last dimension of length 2 is handled directly by 'fft_two'; any other
--   length is handed to 'fft_split'.
fftWithRoots'
    :: Shape sh
    => Array (sh :. Int) Complex
    -> Array (sh :. Int) Complex
    -> Array (sh :. Int) Complex
fftWithRoots' rofu v
  | len == 2  = fft_two v
  | otherwise = fft_split rofu v
  where
    _ :. len = extent v
-- | Base case of the recursion: transform along a last dimension of length 2.
--
--   Position 0 of the result holds the sum of the two input elements and
--   position 1 holds their difference. Any other position in the last
--   dimension is reported with 'error'.
fft_two :: Shape sh
        => Array (sh :. Int) Complex
        -> Array (sh :. Int) Complex
fft_two v = traverse v id vFn'
  where
    vFn' vFn (sh :. 0) = vFn (sh :. 0) + vFn (sh :. 1)
    -- The second output is the difference of the two inputs.
    vFn' vFn (sh :. 1) = vFn (sh :. 0) - vFn (sh :. 1)
    vFn' _   _         = error "Data.Array.Repa.Algorithms.FFT fft_two fail"
-- | Recursive step: transform the even- and odd-indexed halves of the last
--   dimension, then combine them.
--
--   'splitVector' and 'splitRofu' add a dimension of size 2 in front of the
--   halved last dimension; slice 0 carries the even-indexed elements and
--   slice 1 the odd-indexed ones. After both slices have been transformed by
--   'fftWithRoots'', output position @i@ of the first half is
--   @even i + root i * odd i@ and position @i@ of the second half is
--   @even i - root i * odd i@. The two halves are joined with '+:+'.
fft_split :: Shape sh
          => Array (sh :. Int) Complex
          -> Array (sh :. Int) Complex
          -> Array (sh :. Int) Complex
fft_split rofu v
 = let  -- Forced so that both halves below read the same computed result.
        fft_lr = force $ fftWithRoots' (splitRofu rofu) (splitVector v)

        fft_l  = traverse2 fft_lr rofu
                    (\(sh :. 2 :. n) _ -> sh :. n)
                    (\f r (sh :. i) -> f (sh :. 0 :. i) + r (sh :. i) * f (sh :. 1 :. i))

        -- Same as fft_l but subtracting the root-weighted odd term.
        fft_r  = traverse2 fft_lr rofu
                    (\(sh :. 2 :. n) _ -> sh :. n)
                    (\f r (sh :. i) -> f (sh :. 0 :. i) - r (sh :. i) * f (sh :. 1 :. i))

   in   fft_l +:+ fft_r
-- | Build the roots array for the next, half-length level of the recursion.
--
--   The last dimension of length @n@ becomes @2 :. (n `div` 2)@. Both slices
--   of the new size-2 dimension hold the same data: position @i@ reads
--   element @2*i@ of the original roots.
splitRofu rofu = traverse rofu halve pick
  where
    halve (outer :. len)
      = outer :. (2::Int) :. (len `div` 2)

    -- The slice index is ignored, so both slices are identical.
    pick look (outer :. _ :. ix)
      = look (outer :. 2*ix)
-- | Separate the even- and odd-indexed elements of the last dimension.
--
--   The last dimension of length @n@ becomes @2 :. (n `div` 2)@. Slice 0
--   holds the elements at even positions (@2*i@) and slice 1 those at odd
--   positions (@2*i+1@). Any other slice index is reported with 'error'.
splitVector v = traverse v reshape select
  where
    reshape (outer :. len)
      = outer :. 2 :. (len `div` 2)

    select look (outer :. half :. ix)
      = case half of
          0 -> look (outer :. 2*ix)
          1 -> look (outer :. 2*ix+1)
          _ -> error "Data.Array.Repa.Algorithms.FFT splitVector fail"