{-# LANGUAGE ConstraintKinds     #-}
{-# LANGUAGE FlexibleContexts    #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeOperators       #-}
-- |
-- Module      : Data.Array.Accelerate.Math.DFT
-- Copyright   : [2012..2020] The Accelerate Team
-- License     : BSD3
--
-- Maintainer  : Trevor L. McDonell <trevor.mcdonell@gmail.com>
-- Stability   : experimental
-- Portability : non-portable (GHC extensions)
--
-- Compute the Discrete Fourier Transform (DFT) along the lower order dimension
-- of an array.
--
-- This uses a naïve algorithm which takes O(n^2) time. However, you can
-- transform an array with an arbitrary extent, unlike with FFT which requires
-- each dimension to be a power of two.
--
-- The `dft` and `idft` functions compute the roots of unity on every call. If
-- you need to transform several arrays with the same extent, it is faster to
-- compute the roots once using `rootsOfUnity` or `inverseRootsOfUnity`
-- respectively, and then call `dftG` directly.
--
-- You can also compute single values of the transform using `dftGS`.
--
module Data.Array.Accelerate.Math.DFT (

  dft, idft, dftG, dftGS,

) where

import Prelude                                  as P hiding ((!!))
import Data.Array.Accelerate                    as A
import Data.Array.Accelerate.Math.DFT.Roots
import Data.Array.Accelerate.Data.Complex


-- | Forward Discrete Fourier Transform along the innermost (lowest-order)
-- dimension of an array.
--
-- The roots of unity for the input's extent are generated with
-- 'rootsOfUnity' on every call and then passed to 'dftG' together with the
-- input.
--
dft :: (Shape sh, Slice sh, A.RealFloat e, A.FromIntegral Int e)
    => Acc (Array (sh:.Int) (Complex e))
    -> Acc (Array (sh:.Int) (Complex e))
dft input = dftG roots input
  where
    -- roots matching the extent of the input, as 'dftG' requires
    roots = rootsOfUnity (shape input)


-- | Inverse Discrete Fourier Transform along the innermost (lowest-order)
-- dimension of an array.
--
-- Runs 'dftG' with the 'inverseRootsOfUnity' for the input's extent, then
-- divides every element of the result by @n :+ 0@, where @n@ is the length of
-- the innermost dimension.
--
idft :: (Shape sh, Slice sh, A.RealFloat e, A.FromIntegral Int e)
     => Acc (Array (sh:.Int) (Complex e))
     -> Acc (Array (sh:.Int) (Complex e))
idft input = A.map (/ divisor) (dftG invRoots input)
  where
    extent   = shape input
    invRoots = inverseRootsOfUnity extent
    -- normalisation factor: the innermost extent, as a complex number
    divisor  = lift (A.fromIntegral (indexHead extent) :+ 0)


-- | Generic function for computation of forward and inverse DFT. This function
--   is also useful if you transform many arrays of the same extent, and don't
--   want to recompute the roots for each one.
--
--   The extent of the input and roots must match.
--
--   Every row (the @sh@ part of an index @sh :. k@) is transformed
--   independently:
--
--   > out ! (sh :. k) = sum [ arr ! (sh :. n) * roots ! (sh :. (k*n) `mod` l) | n <- [0 .. l-1] ]
--
--   where @l@ is the length of the innermost dimension.
--
dftG :: forall sh e. (Shape sh, Slice sh, A.RealFloat e)
     => Acc (Array (sh:.Int) (Complex e))       -- ^ roots of unity
     -> Acc (Array (sh:.Int) (Complex e))       -- ^ input array
     -> Acc (Array (sh:.Int) (Complex e))
dftG roots arr
  = A.fold (+) 0
  $ A.zipWith (*) arr' roots'
  where
    base :: Exp (sh :. Int)
    base = shape arr

    -- length of the innermost dimension, i.e. of each row being transformed
    l :: Exp Int
    l = indexHead base

    -- Add a new innermost axis of length 'l' (the summation index 'n') to the
    -- extent of the input; 'A.fold' then reduces that axis away again. The
    -- new axis has length 'l', not the total number of input elements, so
    -- each output element only sums over its own row.
    extend :: Exp ((sh :. Int) :. Int)
    extend = lift (base :. l)

    -- arr' ! (sh :. k :. n) = arr ! (sh :. n)
    arr' :: Acc (Array ((sh :. Int) :. Int) (Complex e))
    arr' = A.generate extend $ \ix' ->
      let ix :. n = unlift ix' :: Exp (sh :. Int) :. Exp Int
          sh :. _ = unlift ix  :: Exp sh :. Exp Int
      in  arr ! lift (sh :. n)

    -- roots' ! (sh :. k :. n) = roots ! (sh :. (k*n) `mod` l)
    roots' :: Acc (Array ((sh :. Int) :. Int) (Complex e))
    roots' = A.generate extend $ \ix' ->
      let ix :. n = unlift ix' :: Exp (sh :. Int) :. Exp Int
          sh :. k = unlift ix  :: Exp sh :. Exp Int
      in  roots ! lift (sh :. ((k * n) `mod` l))


-- | Compute a single value of the DFT.
--
--   For an index @ix = sh :. k@ the result is
--
--   > sum [ arr ! (sh :. n) * roots ! (sh :. (k*n) `mod` l) | n <- [0 .. l-1] ]
--
--   where @l@ is the length of the innermost dimension of the input. Only
--   row @sh@ of the input (and of the roots) contributes to the result.
--
dftGS :: forall sh e. (Shape sh, Slice sh, A.RealFloat e)
      => Exp (sh :. Int)                        -- ^ index of the value we want
      -> Acc (Array (sh:.Int) (Complex e))      -- ^ roots of unity
      -> Acc (Array (sh:.Int) (Complex e))      -- ^ input array
      -> Acc (Scalar (Complex e))
dftGS ix roots arr
  = let sh :. k = unlift ix :: Exp sh :. Exp Int
        l       = indexHead (shape arr)

        -- The l products to sum, taken only from row 'sh'. Other rows of the
        -- input must not contribute to this output element.
        terms :: Acc (Vector (Complex e))
        terms   = A.generate (index1 l) $ \nx ->
                    let n = unindex1 nx
                    in  (arr ! lift (sh :. n))
                      * (roots ! lift (sh :. ((k * n) `mod` l)))
    in
    A.foldAll (+) 0 terms