module Data.Array.Repa.FFTW
  ( 
    
    
    
    
    
    
    fft
  , ifft
    
  , fft2d
  , ifft2d
    
  , fft3d
  , ifft3d
  ) where
import Data.Complex (Complex(..))
import Foreign.ForeignPtr (withForeignPtr)
import Foreign.Storable (Storable(..))
import System.IO.Unsafe (unsafePerformIO)
import Data.Array.CArray (CArray)
import Data.Array.Repa ((:.)(..), Array, DIM1, DIM2, DIM3, Z(..))
import Data.Array.Repa.Repr.ForeignPtr (F)
import Foreign.Storable.Complex ()
import qualified Data.Array.CArray as C
import qualified Data.Array.Repa as R
import qualified Data.Array.Repa.Repr.ForeignPtr as RF
import qualified Math.FFT as FFT
-- | One-dimensional forward discrete Fourier transform.
--
-- The Repa array is viewed as a 'CArray', transformed with 'FFT.dft', and
-- converted back to a Repa 'F' array of the same length.
fft :: Array F DIM1 (Complex Double) -> Array F DIM1 (Complex Double)
fft input = c2r (FFT.dft (r2c input))
-- | One-dimensional inverse discrete Fourier transform.
--
-- The Repa array is viewed as a 'CArray', transformed with 'FFT.idft', and
-- converted back to a Repa 'F' array of the same length.
ifft :: Array F DIM1 (Complex Double) -> Array F DIM1 (Complex Double)
ifft input = c2r (FFT.idft (r2c input))
-- | Two-dimensional forward discrete Fourier transform, taken over both
-- axes (dimensions @0@ and @1@ passed to 'FFT.dftN').
fft2d :: Array F DIM2 (Complex Double) -> Array F DIM2 (Complex Double)
fft2d grid = c2r2d (FFT.dftN [0, 1] (r2c2d grid))
-- | Two-dimensional inverse discrete Fourier transform, taken over both
-- axes (dimensions @0@ and @1@ passed to 'FFT.idftN').
ifft2d :: Array F DIM2 (Complex Double) -> Array F DIM2 (Complex Double)
ifft2d grid = c2r2d (FFT.idftN [0, 1] (r2c2d grid))
-- | Three-dimensional forward discrete Fourier transform, taken over all
-- three axes (dimensions @0@, @1@ and @2@ passed to 'FFT.dftN').
fft3d :: Array F DIM3 (Complex Double) -> Array F DIM3 (Complex Double)
fft3d volume = c2r3d (FFT.dftN [0, 1, 2] (r2c3d volume))
-- | Three-dimensional inverse discrete Fourier transform, taken over all
-- three axes (dimensions @0@, @1@ and @2@ passed to 'FFT.idftN').
ifft3d :: Array F DIM3 (Complex Double) -> Array F DIM3 (Complex Double)
ifft3d volume = c2r3d (FFT.idftN [0, 1, 2] (r2c3d volume))
-- | View a one-dimensional Repa 'F' array as a zero-based 'CArray' with
-- bounds @(0, n - 1)@, wrapping the Repa array's own 'ForeignPtr' buffer.
--
-- 'unsafePerformIO' only runs 'C.unsafeForeignPtrToCArray' to wrap the
-- existing buffer. The result is pure as long as neither array's memory is
-- mutated afterwards. An empty input yields the empty range @(0, -1)@.
r2c :: Array F DIM1 (Complex Double) -> CArray Int (Complex Double)
r2c rarr = unsafePerformIO $
  C.unsafeForeignPtrToCArray (RF.toForeignPtr rarr) (0, nelem - 1)
  where
    -- The upper bound is inclusive, hence @nelem - 1@.
    Z :. nelem = R.extent rarr
-- | Convert a one-dimensional 'CArray' into a Repa 'F' array of the same
-- length.
--
-- The Repa array wraps the CArray's 'ForeignPtr' directly via
-- 'RF.fromForeignPtr'. No element-by-element copy is made, and no
-- 'unsafePerformIO' is needed.
c2r :: CArray Int (Complex Double) -> Array F DIM1 (Complex Double)
c2r carr =
  let (count, fptr) = C.toForeignPtr carr
  in  RF.fromForeignPtr (Z :. count) fptr
-- | View a two-dimensional Repa 'F' array of extent @Z :. n1 :. n2@ as a
-- 'CArray' with bounds @((0, 0), (n1 - 1, n2 - 1))@ over the same buffer.
--
-- Repa shapes and 'Ix' tuple ranges are both row-major (the last index
-- varies fastest), so the linear layout of the shared buffer agrees.
-- 'unsafePerformIO' only wraps the existing buffer.
r2c2d :: Array F DIM2 (Complex Double) -> CArray (Int, Int) (Complex Double)
r2c2d rarr = unsafePerformIO $
    C.unsafeForeignPtrToCArray (RF.toForeignPtr rarr) ((0, 0), (n1 - 1, n2 - 1))
  where
    -- Upper bounds are inclusive, hence the @- 1@.
    Z :. n1 :. n2 = R.extent rarr
-- | Convert a two-dimensional 'CArray' into a Repa 'F' array.
--
-- The extent is read from the CArray's own bounds, so non-square arrays keep
-- their shape. The previous @ceiling (sqrt n)@ guess produced a wrong shape
-- for them and could read past the end of the buffer. The Repa array wraps
-- the CArray's 'ForeignPtr' directly; no copy is made.
c2r2d :: CArray (Int, Int) (Complex Double) -> Array F DIM2 (Complex Double)
c2r2d carr = RF.fromForeignPtr sh fptr
  where
    ((lo1, lo2), (hi1, hi2)) = C.bounds carr
    -- Inclusive bounds; clamp at 0 so an empty range gives an empty extent.
    extentOf lo hi           = max 0 (hi - lo + 1)
    sh                       = Z :. extentOf lo1 hi1 :. extentOf lo2 hi2
    (_, fptr)                = C.toForeignPtr carr
-- | View a three-dimensional Repa 'F' array of extent @Z :. n1 :. n2 :. n3@
-- as a 'CArray' with bounds @((0, 0, 0), (n1 - 1, n2 - 1, n3 - 1))@ over
-- the same buffer.
--
-- Both representations are row-major (last index fastest), so the shared
-- buffer's layout agrees. 'unsafePerformIO' only wraps the existing buffer.
r2c3d :: Array F DIM3 (Complex Double)
      -> CArray (Int, Int, Int) (Complex Double)
r2c3d rarr = unsafePerformIO $
    C.unsafeForeignPtrToCArray (RF.toForeignPtr rarr)
                               ((0, 0, 0), (n1 - 1, n2 - 1, n3 - 1))
  where
    -- Upper bounds are inclusive, hence the @- 1@.
    Z :. n1 :. n2 :. n3 = R.extent rarr
-- | Convert a three-dimensional 'CArray' into a Repa 'F' array.
--
-- The extent is read from the CArray's own bounds, so non-cubic arrays keep
-- their shape. The previous @ceiling (n ** (1/3))@ guess was wrong for them
-- and depended on floating-point rounding. The Repa array wraps the
-- CArray's 'ForeignPtr' directly; no copy is made.
c2r3d :: CArray (Int, Int, Int) (Complex Double)
      -> Array F DIM3 (Complex Double)
c2r3d carr = RF.fromForeignPtr sh fptr
  where
    ((lo1, lo2, lo3), (hi1, hi2, hi3)) = C.bounds carr
    -- Inclusive bounds; clamp at 0 so an empty range gives an empty extent.
    extentOf lo hi = max 0 (hi - lo + 1)
    sh             = Z :. extentOf lo1 hi1 :. extentOf lo2 hi2 :. extentOf lo3 hi3
    (_, fptr)      = C.toForeignPtr carr