{-# LANGUAGE TypeOperators, PatternGuards, RankNTypes, ScopedTypeVariables, BangPatterns, FlexibleContexts #-}
{-# OPTIONS -fno-warn-incomplete-patterns #-}

-- | Fast computation of Discrete Fourier Transforms using the Cooley-Tukey algorithm.
--   Time complexity is O(n log n) in the size of the input. 
--
--   This uses a naive divide-and-conquer algorithm; its absolute performance is about
--   50x slower than FFTW in estimate mode.
--
module Data.Array.Repa.Algorithms.FFT
        ( Mode(..)
        , isPowerOfTwo
        , fft3dP
        , fft2dP
        , fft1dP)
where
import Data.Array.Repa.Algorithms.Complex
import Data.Array.Repa                          as R
import Data.Array.Repa.Eval                     as R
import Data.Array.Repa.Unsafe                   as R
import Prelude                                  as P


-- | Direction (and normalisation) of a transform.
data Mode
        = Forward       -- ^ Forward DFT: negative exponent, result not scaled.
        | Reverse       -- ^ Inverse DFT: positive exponent, result not scaled.
        | Inverse       -- ^ Inverse DFT: positive exponent, result divided by the element count.
        deriving (Eq, Show)


-- | Sign of the exponent used in the twiddle factors for a transform direction:
--   negative for the forward transform, positive for both inverse variants.
signOfMode :: Mode -> Double
signOfMode Forward      = -1
signOfMode Reverse      =  1
signOfMode Inverse      =  1
{-# INLINE signOfMode #-}


-- | Check if an `Int` is a power of two no smaller than 2 (2, 4, 8, ...).
--
--   Zero and negative numbers are rejected. A zero-length dimension must not
--   pass this check, because the FFT entry points use it to validate extents
--   before handing them to the recursive worker.
--
--   NOTE(review): 1 (= 2^0) is also reported `False`, so length-1 dimensions
--   are rejected by the FFT entry points. Confirm the FFT worker handles
--   length 1 before changing this.
isPowerOfTwo :: Int -> Bool
isPowerOfTwo n
        | n < 2                 = False
        | n == 2                = True
        | even n                = isPowerOfTwo (n `div` 2)
        | otherwise             = False
{-# INLINE isPowerOfTwo #-}


-- 3D Transform -----------------------------------------------------------------------------------
-- | Compute the DFT of a 3d array. Array dimensions must be powers of two else `error`.
--
--   The extent is read as @depth :. height :. width@. Each of the three
--   passes transforms the innermost axis and then rotates the axes, so after
--   three passes every axis has been transformed and the original axis order
--   is restored. `Forward` and `Reverse` return the unscaled transform;
--   `Inverse` additionally divides every element by @depth * height * width@.
fft3dP  :: (Source r Complex, Monad m)
        => Mode
        -> Array r DIM3 Complex
        -> m (Array U DIM3 Complex)
fft3dP mode arr
 = let  _ :. depth :. height :. width   = extent arr
        !sign   = signOfMode mode
        !scale  = fromIntegral (depth * width * height)
        transformed
                = fftTrans3d sign $ fftTrans3d sign $ fftTrans3d sign arr

   in   if not (isPowerOfTwo depth && isPowerOfTwo height && isPowerOfTwo width)
         then error $ unlines
                [ "Data.Array.Repa.Algorithms.FFT: fft3d"
                , "  Array dimensions must be powers of two,"
                  -- Report the dimensions in the same order as the extent.
                , "  but the provided array is "
                        P.++ show depth P.++ "x" P.++ show height P.++ "x" P.++ show width ]

         else arr `deepSeqArray`
                -- Forward and Reverse differ only in sign; Inverse also rescales.
                case mode of
                        Inverse -> computeP $ R.map (/ scale) transformed
                        _       -> now transformed
{-# INLINE fft3dP #-}


-- | One pass of the 3d transform. It transforms the innermost axis, then
--   rotates the axes with 'rotate3d' so the innermost axis moves to the
--   front. Three passes transform every axis and restore the original order.
fftTrans3d
        :: Source r Complex
        => Double
        -> Array r DIM3 Complex
        -> Array U DIM3 Complex

fftTrans3d sign arr
        = suspendedComputeP (rotate3d transformed)
        where   sh :. len       = extent arr
                transformed     = fft sign sh len arr
{-# INLINE fftTrans3d #-}


-- | Cyclically rotate the three axes of an array so that the innermost axis
--   becomes the outermost: an extent @k :. l :. m@ becomes @m :. k :. l@.
--   The result is a delayed array.
rotate3d
        :: Source r Complex
        => Array r DIM3 Complex -> Array D DIM3 Complex
rotate3d arr
 = let  outer :. d0 :. d1 :. d2         = extent arr

        -- Map an index of the rotated array back to its source index.
        toSource (ix :. i2 :. i0 :. i1) = ix :. i0 :. i1 :. i2

   in   backpermute (outer :. d2 :. d0 :. d1) toSource arr
{-# INLINE rotate3d #-}



-- Matrix Transform -------------------------------------------------------------------------------
-- | Compute the DFT of a matrix. Array dimensions must be powers of two else `error`.
--
--   The extent is read as @height :. width@. Each 'fftTrans2d' pass
--   transforms the innermost axis and then transposes, so two passes
--   transform both axes and restore the original orientation. `Forward` and
--   `Reverse` return the unscaled transform; `Inverse` additionally divides
--   every element by @height * width@.
fft2dP  :: (Source r Complex, Monad m)
        => Mode
        -> Array r DIM2 Complex
        -> m (Array U DIM2 Complex)
fft2dP mode arr
        | isPowerOfTwo height && isPowerOfTwo width
        = arr `deepSeqArray`
          case mode of
                Inverse -> computeP $ R.map (/ scale) transformed
                _       -> now transformed

        | otherwise
        = error $ unlines
                [ "Data.Array.Repa.Algorithms.FFT: fft2d"
                , "  Array dimensions must be powers of two,"
                , "  but the provided array is " P.++ show height P.++ "x" P.++ show width ]

 where  _ :. height :. width    = extent arr
        sign                    = signOfMode mode
        scale                   = fromIntegral (width * height)
        transformed             = fftTrans2d sign (fftTrans2d sign arr)
{-# INLINE fft2dP #-}


-- | One pass of the matrix transform. It transforms the innermost axis of
--   every row, then transposes, so the next pass works on the other axis.
fftTrans2d
        :: Source r Complex
        => Double
        -> Array r DIM2 Complex
        -> Array U DIM2 Complex

fftTrans2d sign arr
        = suspendedComputeP (transpose spectrum)
        where   sh :. len       = extent arr
                spectrum        = fft sign sh len arr
{-# INLINE fftTrans2d #-}


-- Vector Transform -------------------------------------------------------------------------------
-- | Compute the DFT of a vector. Array dimensions must be powers of two else `error`.
--
--   `Forward` and `Reverse` return the unscaled transform; `Inverse`
--   additionally divides every element by the vector length.
fft1dP  :: (Source r Complex, Monad m)
        => Mode
        -> Array r DIM1 Complex
        -> m (Array U DIM1 Complex)
fft1dP mode arr
        | not (isPowerOfTwo len)
        = error $ unlines
                [ "Data.Array.Repa.Algorithms.FFT: fft1d"
                , "  Array dimensions must be powers of two, "
                , "  but the provided array is " P.++ show len ]

        | otherwise
        = arr `deepSeqArray`
          case mode of
                Inverse -> computeP $ R.map (/ scale) spectrum
                _       -> now spectrum

 where  _ :. len        = extent arr
        sign            = signOfMode mode
        scale           = fromIntegral len
        spectrum        = fftTrans1d sign arr
{-# INLINE fft1dP #-}


-- | Transform a vector. This is a thin wrapper over 'fft' with the outer
--   shape of a 'DIM1' array, which is 'Z'.
fftTrans1d
        :: Source r Complex
        => Double
        -> Array r DIM1 Complex
        -> Array U DIM1 Complex

fftTrans1d sign arr
        = fft sign outer len arr
        where   outer :. len    = extent arr
{-# INLINE fftTrans1d #-}


-- Rank Generalised Worker ------------------------------------------------------------------------
-- | Rank-generalised radix-2 FFT over the innermost dimension of an array.
--
--   Each innermost vector (one per index of @sh@) is transformed
--   independently, and the result is not scaled.
--
--   * @sign@   - sign of the exponent in the twiddle factors (see 'signOfMode').
--   * @sh@     - extent of all but the innermost dimension.
--   * @lenVec@ - length of the innermost dimension. The recursion halves it
--                at each level, so it is expected to be a power of two;
--                other lengths are not transformed correctly.
--   * @vec@    - the input array.
fft     :: (Shape sh, Source r Complex)
        => Double -> sh -> Int
        -> Array r (sh :. Int) Complex
        -> Array U (sh :. Int) Complex

fft !sign !sh !lenVec !vec
 = go lenVec 0 1
 where  -- Transform the sub-sequence of every innermost vector that starts at
        -- @offset@ and takes every @stride@-th element, @len@ elements in all.
        go !len !offset !stride
         -- Zero or one element: the DFT is the input itself. Without this
         -- case such lengths recurse forever, because halving never reaches 2.
         | len < 2
         = suspendedComputeP $ fromFunction (sh :. len) single

         -- Two elements: the butterfly (a + b, a - b).
         | len == 2
         = suspendedComputeP $ fromFunction (sh :. 2) swivel

         -- Otherwise split into the even- and odd-indexed halves (decimation
         -- in time), transform each recursively, and combine.
         | otherwise
         = combine len
                (go (len `div` 2) offset            (stride * 2))
                (go (len `div` 2) (offset + stride) (stride * 2))

         where  single (sh' :. _)
                 = vec `unsafeIndex` (sh' :. offset)

                swivel (sh' :. ix)
                 = case ix of
                        0       -> (vec `unsafeIndex` (sh' :. offset)) + (vec `unsafeIndex` (sh' :. (offset + stride)))
                        1       -> (vec `unsafeIndex` (sh' :. offset)) - (vec `unsafeIndex` (sh' :. (offset + stride)))

                -- Butterfly step. With E = evens, O = odds and W_k the k-th
                -- twiddle factor for length len', the result is
                -- [E_k + W_k * O_k] ++ [E_k - W_k * O_k].
                {-# INLINE combine #-}
                combine !len'   evens odds
                 = evens `deepSeqArray` odds `deepSeqArray`
                   let  odds'   = unsafeTraverse odds id (\get ix@(_ :. k) -> twiddle sign k len' * get ix)
                   in   suspendedComputeP $ (evens +^ odds') R.++ (evens -^ odds')
{-# INLINE fft #-}


-- | Compute a twiddle factor: the pair @(cos theta, sign * sin theta)@ with
--   @theta = 2 * pi * k / n@. For @sign = +1@ or @-1@ this is
--   @exp(sign * i * theta)@, assuming 'Complex' is a (real, imaginary) pair.
twiddle :: Double
        -> Int                  -- index
        -> Int                  -- length
        -> Complex

twiddle sign k' n'
        = (cos theta, sign * sin theta)
        where   theta   = 2 * pi * fromIntegral k' / fromIntegral n'
{-# INLINE twiddle #-}