{-# LANGUAGE TypeOperators, RankNTypes, PatternGuards #-}

-- | Compute the Discrete Fourier Transform (DFT) along the low order dimension
--   of an array. 
--
--   This uses the naive algorithm and takes O(n^2) time. 
--   However, you can transform an array with an arbitrary extent, unlike with FFT which requires
--   each dimension to be a power of two.
--
--   The `dftP` and `idftP` functions also compute the roots of unity needed.
--   If you need to transform several arrays with the same extent then it is faster to
--   compute the roots once using `calcRootsOfUnityP` or `calcInverseRootsOfUnityP`,
--   then call `dftWithRootsP` directly.
--
--   You can also compute single values of the transform using `dftWithRootsSingleS`.
module Data.Array.Repa.Algorithms.DFT 
        ( dftP
        , idftP
        , dftWithRootsP
        , dftWithRootsSingleS)
where
import Data.Array.Repa.Algorithms.DFT.Roots     as R
import Data.Array.Repa.Algorithms.Complex       as R
import Data.Array.Repa                          as R
import Prelude                                  as P


-- | Compute the DFT along the low order dimension of an array.
--
--   The roots of unity for the extent of the input are computed with
--   `calcRootsOfUnityP` on every call, and are then passed to `dftWithRootsP`
--   together with the input array.
dftP    :: (Shape sh, Monad m)
        => Array U (sh :. Int) Complex          -- ^ Input array.
        -> m (Array U (sh :. Int) Complex)
dftP v
 = do   rofu    <- calcRootsOfUnityP (extent v)
        dftWithRootsP rofu v
{-# INLINE dftP #-}


-- | Compute the inverse DFT along the low order dimension of an array.
--
--   The input is transformed with `dftWithRootsP` using the roots produced by
--   `calcInverseRootsOfUnityP`, and every element of that result is then
--   divided by the length of the low order dimension.
idftP   :: (Shape sh, Monad m)
        => Array U (sh :. Int) Complex          -- ^ Input array.
        -> m (Array U (sh :. Int) Complex)
idftP v
 = do   let ext         = extent v
            _ :. len    = ext
            -- The length as a `Complex` value whose second component is zero.
            scale       = (fromIntegral len, 0)
        rofu            <- calcInverseRootsOfUnityP ext
        unscaled        <- dftWithRootsP rofu v
        computeP $ R.map (/ scale) unscaled
{-# INLINE idftP #-}


-- | Generic function for computation of forward or inverse DFT.
--      This function is also useful if you transform many arrays with the same extent,
--      and don't want to recompute the roots for each one.
--
--      The length of the low order dimension of the roots must match that of the
--      input array, else `error`. Only these lengths are compared; the outer
--      dimensions of the two extents are not checked here.
--
--      Each element of the result is computed independently with
--      `dftWithRootsSingleS`.
dftWithRootsP
        :: (Shape sh, Monad m)
        => Array U (sh :. Int) Complex          -- ^ Roots of unity.
        -> Array U (sh :. Int) Complex          -- ^ Input array.
        -> m (Array U (sh :. Int) Complex)
dftWithRootsP rofu arr
        | _ :. rLen     <- extent rofu
        , _ :. vLen     <- extent arr
        , rLen /= vLen
        = error $    "dftWithRoots: length of vector (" P.++ show vLen P.++ ")"
                P.++ " does not match the length of the roots (" P.++ show rLen P.++ ")"

        | otherwise
        -- The result has the same extent as the input; the element lookup
        -- function supplied by `traverse` is not needed.
        = computeP $ R.traverse arr id (\_ k -> dftWithRootsSingleS rofu arr k)
{-# INLINE dftWithRootsP #-}


-- | Compute a single value of the DFT.
--
--      The value at index @sh :. k@ is the sum, over every position @n@ along
--      the low order dimension, of the input element at @sh :. n@ multiplied by
--      the root at @sh :. ((k * n) `mod` len)@, where @len@ is the length of the
--      low order dimension. Only the row selected by the outer part of the
--      index contributes, so each row of a higher-rank array is handled
--      independently of the others.
--
--      The length of the low order dimension of the roots must match that of the
--      input array, else `error`. Only these lengths are compared; the outer
--      dimensions of the two extents are not checked here.
dftWithRootsSingleS
        :: Shape sh
        => Array U (sh :. Int) Complex          -- ^ Roots of unity.
        -> Array U (sh :. Int) Complex          -- ^ Input array.
        -> (sh :. Int)                          -- ^ Index of the value we want.
        -> Complex
dftWithRootsSingleS rofu arrX (row :. k)
        | _ :. rLen     <- extent rofu
        , _ :. vLen     <- extent arrX
        , rLen /= vLen
        = error $    "dftWithRootsSingle: length of vector (" P.++ show vLen P.++ ")"
                P.++ " does not match the length of the roots (" P.++ show rLen P.++ ")"

        | otherwise
        = let   _ :. len        = extent arrX

                -- One term of the sum: the input element at position n of the
                -- requested row, multiplied by the root it pairs with.
                term (Z :. n)
                        = (arrX ! (row :. n))
                        * (rofu ! (row :. (k * n) `mod` len))

          -- Sum over the requested row only, not over the whole array.
          in    R.sumAllS $ fromFunction (Z :. len) term
{-# INLINE dftWithRootsSingleS #-}