{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE ConstraintKinds #-}
module Data.Array.Accelerate.FFTW.Manifest (
   Transform,
   Element,
   dft,
   idft,
   dft2d,
   idft2d,
   ) where

import qualified Math.FFT as FFT
import Math.FFT.Base (FFTWReal)

import qualified Data.Array.Accelerate.CArray.Conversion as Conv

import qualified Data.Array.Accelerate.IO as AIO
import qualified Data.Array.Accelerate as A
import Data.Array.Accelerate.Array.Sugar (EltRepr)
import Data.Array.Accelerate (Array, (:.)((:.)))

import Foreign.Storable.Complex ()

import Data.Array.Accelerate.Data.Complex (Complex)

import Foreign.Ptr (Ptr)


-- | Inclusive lower and upper index bounds, as a pair @(lower, upper)@.
type Bounds ix = (ix, ix)

-- | A shape- and element-preserving array transformation.
type Transform sh e = Array sh e -> Array sh e

-- | Real component types usable by the transforms in this module.
--
-- @a@ must be an Accelerate element, an FFTW-supported real type, and its
-- Accelerate block-pointer representation must be a plain @'Ptr' a@.
type Element a = (A.Elt a, FFTWReal a, AIO.BlockPtrs (EltRepr a) ~ Ptr a)

-- | Forward DFT along the innermost dimension.
--
-- All outer dimensions are flattened into independent rows (see 'bnds1'),
-- and each row of length @n@ is transformed with 'FFT.dftN' on axis 1, the
-- second component of the bounds built by 'bnds1'.
dft ::
   (A.Shape sh, Element a) =>
   Transform (sh:.Int) (Complex a)
dft = Conv.withCArrayComplex bnds1 (FFT.dftN innerAxis)
   where
      innerAxis = [1]

-- | Inverse DFT along the innermost dimension.
--
-- All outer dimensions are flattened into independent rows (see 'bnds1'),
-- and each row of length @n@ is transformed with 'FFT.idftN' on axis 1, the
-- second component of the bounds built by 'bnds1'.
--
-- This must call 'FFT.idftN', not 'FFT.dftN'. Calling 'FFT.dftN' would make
-- 'idft' identical to 'dft'. The inverse here mirrors 'idft2d'.
-- NOTE(review): normalisation is whatever 'FFT.idftN' applies; confirm
-- against the fft package docs if exact round-trip scaling matters.
idft ::
   (A.Shape sh, Element a) =>
   Transform (sh:.Int) (Complex a)
idft = Conv.withCArrayComplex bnds1 $ FFT.idftN [1]

-- | Two-dimensional bounds passed to 'Conv.withCArrayComplex' for a
-- @sh:.Int@ array.
--
-- The first index ranges over all outer elements, flattened via
-- 'A.arraySize'. The second ranges over the innermost extent @n@. Both
-- ranges start at zero.
bnds1 :: (A.Shape sh) => sh:.Int -> Bounds (Int,Int)
bnds1 (outer :. n) = (lower, upper)
   where
      lower = (0, 0)
      upper = (A.arraySize outer - 1, n - 1)


-- | Forward two-dimensional DFT over the two innermost dimensions.
--
-- All outer dimensions are flattened into independent planes (see 'bnds2'),
-- and each plane is transformed with 'FFT.dftN' on axes 1 and 2, the second
-- and third components of the bounds built by 'bnds2'.
dft2d ::
   (A.Shape sh, Element a) =>
   Transform (sh:.Int:.Int) (Complex a)
dft2d = Conv.withCArrayComplex bnds2 (FFT.dftN planeAxes)
   where
      planeAxes = [1, 2]

-- | Inverse two-dimensional DFT over the two innermost dimensions.
--
-- All outer dimensions are flattened into independent planes (see 'bnds2'),
-- and each plane is transformed with 'FFT.idftN' on axes 1 and 2, the second
-- and third components of the bounds built by 'bnds2'.
idft2d ::
   (A.Shape sh, Element a) =>
   Transform (sh:.Int:.Int) (Complex a)
idft2d = Conv.withCArrayComplex bnds2 (FFT.idftN planeAxes)
   where
      planeAxes = [1, 2]

-- | Three-dimensional bounds passed to 'Conv.withCArrayComplex' for a
-- @sh:.Int:.Int@ array.
--
-- The first index ranges over all outer elements, flattened via
-- 'A.arraySize'. The second and third range over the two innermost
-- extents. All ranges start at zero.
bnds2 :: (A.Shape sh) => sh:.Int:.Int -> Bounds (Int,Int,Int)
bnds2 (outer :. rows :. cols) = (lower, upper)
   where
      lower = (0, 0, 0)
      upper = (A.arraySize outer - 1, rows - 1, cols - 1)