module Data.Array.Repa.Algorithms.DFT
( dft
, idft
, dftWithRoots
, dftWithRootsSingle)
where
import Data.Array.Repa.Algorithms.DFT.Roots
import Data.Array.Repa.Algorithms.Complex
import Data.Array.Repa as A
-- | Discrete Fourier Transform, taken over the innermost (@Int@) dimension
--   of the array.
--
--   The roots of unity are obtained from 'calcRootsOfUnity' using the full
--   extent of the input. The transform itself is delegated to
--   'dftWithRoots', and the result is passed through 'force'.
dft     :: forall sh
        .  Shape sh
        => Array (sh :. Int) Complex
        -> Array (sh :. Int) Complex
dft v   = force (dftWithRoots roots v)
  where
    roots = calcRootsOfUnity (extent v)
-- | Inverse Discrete Fourier Transform, taken over the innermost (@Int@)
--   dimension of the array.
--
--   The function first runs 'dftWithRoots' with the inverse roots of unity
--   from 'calcInverseRootsOfUnity'. It then normalises by dividing every
--   element by the innermost length @len@, represented as the complex
--   number @len :*: 0@. The result is passed through 'force'.
idft    :: forall sh
        .  Shape sh
        => Array (sh :. Int) Complex
        -> Array (sh :. Int) Complex
idft v  = force . A.map (/ scale) $ dftWithRoots invRoots v
  where
    _ :. len = extent v
    scale    = fromIntegral len :*: 0
    invRoots = calcInverseRootsOfUnity (extent v)
-- | Compute a DFT using a precomputed array of roots of unity.
--
--   The result has the same extent as the input. Each of its elements is
--   produced by 'dftWithRootsSingle' from the same roots and input,
--   evaluated at that element's index.
--
--   Calls 'error' if the innermost lengths of the roots and of the input
--   differ.
dftWithRoots
        :: forall sh
        .  Shape sh
        => Array (sh :. Int) Complex    -- ^ Roots of unity.
        -> Array (sh :. Int) Complex    -- ^ Input values.
        -> Array (sh :. Int) Complex
dftWithRoots rofu arr
        | _ :. rLen <- extent rofu
        , _ :. vLen <- extent arr
        , rLen /= vLen
        = error $ "dftWithRoots: length of vector (" ++ show vLen ++ ")"
                ++ " does not match the length of the roots (" ++ show rLen ++ ")"

        -- 'traverse' is qualified, like A.map / A.zipWith / A.sumAll in this
        -- module, because newer Preludes (base >= 4.8) also export a
        -- 'traverse' and the unqualified name would be ambiguous.
        | otherwise
        = A.traverse arr id (\_ ix -> dftWithRootsSingle rofu arr ix)
-- | Compute a single element of the DFT.
--
--   The requested index @row :. k@ selects an innermost row @row@ and an
--   output frequency @k@. With @len@ the innermost length, the result is
--
--   > sum [ (arrX !: (row :. n)) * (rofu !: (row :. (k * n) `mod` len))
--   >     | n <- [0 .. len - 1] ]
--
--   Only elements of row @row@ contribute, so arrays with several rows
--   (non-trivial @sh@) are transformed row by row.
--
--   Calls 'error' if the innermost lengths of the roots and of the input
--   differ.
dftWithRootsSingle
        :: forall sh
        .  Shape sh
        => Array (sh :. Int) Complex    -- ^ Roots of unity.
        -> Array (sh :. Int) Complex    -- ^ Input values.
        -> (sh :. Int)                  -- ^ The index we want to compute.
        -> Complex
dftWithRootsSingle rofu arrX (row :. k)
        | _ :. rLen <- extent rofu
        , _ :. vLen <- extent arrX
        , rLen /= vLen
        = error $ "dftWithRootsSingle: length of vector (" ++ show vLen ++ ")"
                ++ " does not match the length of the roots (" ++ show rLen ++ ")"

        | otherwise
        = let   _ :. len = extent arrX

                -- Term n of the sum for this row. The sum is built over a
                -- 1-D array of length len rather than the whole input, so
                -- other rows of arrX do not leak into this element.
                term (Z :. n)
                        = (arrX !: (row :. n))
                        * (rofu !: (row :. ((k * n) `mod` len)))

          in    A.sumAll $ fromFunction (Z :. len) term