{-# LANGUAGE FlexibleContexts #-}
{-# language BangPatterns        #-}
{-# language LambdaCase          #-}
{-# options_ghc -Wno-unused-imports #-}
module Data.Vector.FFT (
  fft, ifft
  -- * Useful results
  , crossCorrelation
  ) where

import Control.Monad (when)
import Control.Monad.Primitive (PrimMonad(..))

import Control.Monad.ST (runST)
import Data.Bits (shiftR,shiftL,(.&.),(.|.))
import Data.Bool (Bool,otherwise)
import Data.Complex (Complex(..),conjugate)
import Data.Foldable (forM_)

import Data.Vector.Unboxed as V (Vector, Unbox, map, zipWith, length, unsafeFreeze, (!))
import qualified Data.Vector.Unboxed.Mutable as VM (MVector, read, write, new, length)
import qualified Data.Vector.Generic as VG (Vector(..), copy)

import Prelude hiding (read)

-- Rewrite rules cancelling a forward/inverse transform pair.
-- NOTE(review): these hold only up to floating-point rounding, and they are
-- not semantics-preserving when the input length is not a power of two:
-- both transforms zero-pad to the next power of two, so e.g. @fft (ifft x)@
-- is then longer than @x@, while the rewritten program returns @x@ itself.
-- Consider restricting or removing them.
{-# RULES
"fft/ifft" forall x. fft (ifft x) = x
"ifft/fft" forall x. ifft (fft x) = x
  #-}


-- | (Circular) cross-correlation of two vectors.
--
--   Computed in the frequency domain for efficiency: the conjugated spectrum
--   of the first vector is multiplied element-wise with the spectrum of the
--   second, and the product is transformed back with 'ifft'.
--
--   NB the source vectors should have matching length for meaningful results
--   (the element-wise product stops at the shorter spectrum).
crossCorrelation :: Vector (Complex Double)
                 -> Vector (Complex Double)
                 -> Vector (Complex Double)
crossCorrelation v1 v2 =
  let conjSpectrum1 = cmap conjugate (fft v1)
      spectrum2     = fft v2
  in ifft (V.zipWith (*) conjSpectrum1 spectrum2)



-- | Radix-2 decimation-in-time fast Fourier Transform.
--
--   The given array (and therefore the output as well) is zero-padded to the
--   next power of two if necessary. The transform runs in place on a fresh
--   padded copy, which is then frozen without copying.
fft :: Vector (Complex Double) -> Vector (Complex Double)
fft arr = runST (copyPadded arr >>= \buf -> mfft buf >> V.unsafeFreeze buf)
{-# inlinable [1] fft #-}

-- | Inverse fast Fourier transform.
--
--   The given array (and therefore the output as well) is zero-padded to the
--   next power of two if necessary. The output is scaled by @1 / N@, where
--   @N@ is the padded (output) length, so that @ifft (fft x)@ gives back the
--   zero-padded @x@ up to rounding.
ifft :: Vector (Complex Double) -> Vector (Complex Double)
ifft arr = cmap ((/ scale) . conjugate) spectrum
  where
    -- Inverse DFT via the conjugation identity:
    --   ifft x = conj (fft (conj x)) / N
    spectrum = fft (cmap conjugate arr)
    -- Normalise by the length 'fft' actually produced (after padding),
    -- not by the caller's possibly shorter input length.
    scale = intToComplexDouble (V.length spectrum)
{-# inlinable [1] ifft #-}




-- | Allocate a mutable vector of length @nextPow2 (length arr)@, copy the
--   source into its prefix and write zeros into every remaining slot.
--
--   Every slot is written explicitly, because freshly allocated mutable
--   storage is not relied upon to be zeroed.
copyPadded :: (PrimMonad m, Num a, Unbox a) =>
              Vector a -> m (VM.MVector (PrimState m) a)
copyPadded arr = do
  let srcLen = V.length arr
      dstLen = nextPow2 srcLen
  dst <- VM.new dstLen
  forM_ [0 .. dstLen - 1] $ \ix ->
    VM.write dst ix (if ix < srcLen then arr V.! ix else 0)
  pure dst
{-# inline copyPadded #-}





-- | In-place radix-2 decimation-in-time fast Fourier Transform.
--
--   The given array must have a length that is a power of two,
--   though this property is not checked.
--   Arrays of length 0 or 1 are left unchanged (a length-1 DFT is the
--   identity, and an empty array has nothing to transform).
mfft :: (PrimMonad m) => VM.MVector (PrimState m) (Complex Double) -> m ()
mfft mut = when (len > 1) (bitReverse 0 0)
  where
    len :: Int
    len = VM.length mut

    -- Number of butterfly stages (log2 of the length), computed once.
    stages :: Int
    stages = log2 len

    -- Permute the elements into bit-reversed index order. @j@ is the
    -- bit reversal of @i@, advanced incrementally. Once the last index is
    -- reached, the butterfly stages start.
    bitReverse !i !j
      | i == len - 1 = stage 0 1
      | otherwise = do
          when (i < j) $ swap mut i j
          advance (len `shiftR` 1) j
      where
        -- Increment @j@ in bit-reversed arithmetic: clear the run of set
        -- bits from the top, then set the first clear bit found.
        advance k l
          | k <= l    = advance (k `shiftR` 1) (l - k)
          | otherwise = bitReverse (i + 1) (l + k)

    -- Stage @l@ merges pairs of sub-transforms of size @l1@ into
    -- transforms of size @l2 = 2 * l1@.
    stage !l !l1
      | l == stages = pure ()
      | otherwise   = flight 0
      where
        l2 = l1 `shiftL` 1
        -- Angle step between consecutive twiddle factors of this stage.
        e = negate twoPi / intToDouble l2

        -- Apply twiddle factor number @j@ to every butterfly that uses it.
        flight j
          | j == l1   = stage (l + 1) l2
          | otherwise = butterfly j
          where
            -- The angle is derived from @j@ directly instead of by repeated
            -- addition, so rounding error does not accumulate. cos/sin are
            -- evaluated once per @j@ rather than once per butterfly.
            a  = e * intToDouble j
            co = cos a
            si = sin a

            butterfly i
              | i >= len  = flight (j + 1)
              | otherwise = do
                  let i1 = i + l1
                  xi1 :+ yi1 <- VM.read mut i1
                  -- d = twiddle * x[i1]
                  let d = (co * xi1 - si * yi1) :+ (si * xi1 + co * yi1)
                  ci <- VM.read mut i
                  VM.write mut i1 (ci - d)
                  VM.write mut i  (ci + d)
                  butterfly (i + l2)

-- Lookup tables for the integer base-2 logarithm in 'log2'.
-- @b i@ is a mask selecting the upper half of a @2 * s i@-bit window, and
-- @s i@ (= 2^i) is the shift applied when that mask hits.
-- 'log2' only asks for indices 0..5. Any other index yields 0, which makes
-- the corresponding test in 'log2' a harmless no-op.
b, s :: Int -> Int
b 0 = 0x02
b 1 = 0x0c
b 2 = 0xf0
b 3 = 0xff00
b 4 = wordToInt 0xffff0000
b 5 = wordToInt 0xffffffff00000000
b _ = 0

s i
  | i >= 0 && i <= 5 = 1 `shiftL` i
  | otherwise        = 0
{-# inline b #-}
{-# inline s #-}

-- | Smallest power of two that is greater than or equal to @n@.
--
--   A value that is already a power of two (including 1) is returned
--   unchanged, so 'fft' only pads when necessary. Non-positive inputs yield
--   0, so an empty vector stays empty.
--
--   NB assumes @n <= 2^62@ (always true for vector lengths). Beyond that,
--   the doubling would overflow.
nextPow2 :: Int -> Int
nextPow2 n
  | n <= 0    = 0
  | otherwise = until (>= n) (* 2) 1


-- | Floor of the base-2 logarithm of a positive 'Int'.
--
--   Runs a fixed six-step binary search over bit windows using the tables
--   'b' and 's' (which cover 64 bits). Calls 'error' on non-positive input.
log2 :: Int -> Int
log2 v0
  | v0 <= 0   = error $ "Data.Vector.FFT: nonpositive input, got " ++ show v0
  | otherwise = loop 5 0 v0
  where
    loop :: Int -> Int -> Int -> Int
    loop !step !acc !v
      | step < 0 = acc
      | otherwise =
          let width = s step
              hit   = v .&. b step /= 0
          in if hit
               then loop (step - 1) (acc .|. width) (v `shiftR` width)
               else loop (step - 1) acc v


{-# inline swap #-}
-- | Exchange the elements at indices @i@ and @j@ of a mutable vector.
swap :: (PrimMonad m, Unbox a) =>
        VM.MVector (PrimState m) a -> Int -> Int -> m ()
swap mut i j = do
  fromI <- VM.read mut i
  fromJ <- VM.read mut j
  VM.write mut j fromI
  VM.write mut i fromJ

-- | One full turn in radians, @2 * pi@.
twoPi :: Double
twoPi = 2 * pi
{-# inline twoPi #-}

-- | Convert an 'Int' to a 'Double' (monomorphic 'fromIntegral').
intToDouble :: Int -> Double
intToDouble n = fromIntegral n
{-# inline intToDouble #-}

-- | Reinterpret a 'Word' as an 'Int' (monomorphic 'fromIntegral'; values
--   above 'maxBound' of 'Int' wrap around).
wordToInt :: Word -> Int
wordToInt w = fromIntegral w
{-# inline wordToInt #-}

-- | Embed an 'Int' as a purely real 'Complex' 'Double'.
intToComplexDouble :: Int -> Complex Double
intToComplexDouble n = fromIntegral n :+ 0
{-# inline intToComplexDouble #-}


{-# inline cmap #-}
-- | Apply a function to every element of a vector of complex numbers.
--   This is 'V.map' at a complex element type; the 'Floating' constraint is
--   not used by the body.
cmap :: (Floating a, Unbox a) => (Complex a -> Complex a) -> V.Vector (Complex a) -> V.Vector (Complex a)
cmap f xs = V.map f xs


--

-- {-# inline copyWhole #-}
-- copyWhole :: (PrimMonad m, VG.Vector Vector a, Unbox a) => V.Vector a -> m (VM.MVector (PrimState m) a)
-- copyWhole arr = do
--   let len = V.length arr
--   marr <- VM.new len
--   VG.copy marr arr
--   pure marr

-- {-# inline arrOK #-}
-- arrOK :: Unbox a => Vector a -> Bool
-- arrOK arr =
--   let n = V.length arr
--   in (1 `shiftL` log2 n) == n