{-# LANGUAGE BangPatterns, FlexibleContexts #-}
-- |
-- Module    : Statistics.Transform
-- Copyright : (c) 2011 Bryan O'Sullivan
-- License   : BSD3
--
-- Maintainer  : bos@serpentine.com
-- Stability   : experimental
-- Portability : portable
--
-- Fourier-related transformations of mathematical functions.
--
-- These functions are written for simplicity and correctness, not
-- speed.  If you need a fast FFT implementation for your application,
-- you should strongly consider using a library of FFTW bindings
-- instead.

module Statistics.Transform
    (
    -- * Type synonyms
      CD
    -- * Discrete cosine transform
    , dct
    , dct_
    , idct
    , idct_
    -- * Fast Fourier transform
    , fft
    , ifft
    ) where

import Control.Monad (when)
import Control.Monad.ST (ST)
import Data.Bits (shiftL, shiftR)
import Data.Complex (Complex(..), conjugate, realPart)
import Numeric.SpecFunctions (log2)
import qualified Data.Vector.Generic         as G
import qualified Data.Vector.Generic.Mutable as M
import qualified Data.Vector.Unboxed         as U
import qualified Data.Vector                 as V

-- | Complex number with 'Double' components; the element type used by the
-- Fourier transforms in this module.
type CD = Complex Double

-- | Discrete cosine transform (DCT-II).
--
-- For an input @x@ of length @N@ the result is
--
-- \[ X_k = 2 \sum_{n=0}^{N-1} x_n \cos\left(\frac{\pi k (2n+1)}{2N}\right) \]
--
-- i.e. twice the unnormalised DCT-II. The input length must be a power of
-- two; otherwise 'error' is called.
dct :: (G.Vector v CD, G.Vector v Double, G.Vector v Int) => v Double -> v Double
dct xs = dctWorker (G.map (\x -> x :+ 0) xs)
{-# INLINABLE  dct #-}
{-# SPECIALIZE dct :: U.Vector Double -> U.Vector Double #-}
{-# SPECIALIZE dct :: V.Vector Double -> V.Vector Double #-}

-- | Discrete cosine transform (DCT-II) of a complex vector.
--
-- Only the real part of each element is transformed. Imaginary parts are
-- replaced by zero before the transform, so the result equals 'dct'
-- applied to the real parts. The input length must be a power of two;
-- otherwise 'error' is called.
dct_ :: (G.Vector v CD, G.Vector v Double, G.Vector v Int) => v CD -> v Double
dct_ xs = dctWorker (G.map ((:+ 0) . realPart) xs)
{-# INLINABLE  dct_ #-}
{-# SPECIALIZE dct_ :: U.Vector CD -> U.Vector Double #-}
{-# SPECIALIZE dct_ :: V.Vector CD -> V.Vector Double #-}

-- | Shared implementation of 'dct' and 'dct_'.
--
-- Both callers zero the imaginary parts before calling this.
--
-- The algorithm is FFT-based:
--
-- 1. The input is reordered as @x_0, x_2, ..., x_(N-2), x_(N-1), x_(N-3), ..., x_1@,
--    i.e. even positions ascending, then odd positions descending.
-- 2. The reordered vector is transformed with 'fft'.
-- 3. Bin 0 is multiplied by 2 and bin @k >= 1@ by @2 * exp (-i*pi*k / (2N))@.
-- 4. The real parts of the result are returned.
dctWorker :: (G.Vector v CD, G.Vector v Double, G.Vector v Int) => v CD -> v Double
{-# INLINE dctWorker #-}
dctWorker xs
  -- A single element is handled directly: for length 1 both index ranges
  -- below are empty, so the general path cannot be used.
  | len == 1    = G.map (\z -> 2 * realPart z) xs
  | vectorOK xs = G.map realPart (G.zipWith (*) twiddles (fft reordered))
  | otherwise   = error "Statistics.Transform.dct: bad vector length"
  where
    len = G.length xs
    n   = fi len
    -- Even positions ascending, then odd positions descending.
    evens     = G.enumFromThenTo 0 2 (len - 2)
    oddsDown  = G.enumFromThenTo (len - 1) (len - 3) 1
    reordered = G.backpermute xs (evens G.++ oddsDown)
    -- Bin 0 gets weight 2; bin k >= 1 gets 2 * exp (-i*pi*k / (2N)).
    twiddles  = G.cons 2 (G.generate (len - 1) twiddle)
    twiddle k = 2 * exp ((0 :+ (-1)) * fi (k + 1) * pi / (2 * n))



-- | Inverse discrete cosine transform (DCT-III).
--
-- For an input @X@ of length @N@ the result is
--
-- \[ x_n = X_0 + 2 \sum_{k=1}^{N-1} X_k \cos\left(\frac{\pi k (2n+1)}{2N}\right) \]
--
-- It is the inverse of 'dct' only up to a scale factor of @2N@:
--
-- > (idct . dct) x == G.map (* (2 * fromIntegral (G.length x))) x
--
-- (up to floating-point rounding). The input length must be a power of
-- two; otherwise 'error' is called.
idct :: (G.Vector v CD, G.Vector v Double) => v Double -> v Double
idct = idctWorker . G.map (:+ 0)
{-# INLINABLE  idct #-}
{-# SPECIALIZE idct :: U.Vector Double -> U.Vector Double #-}
{-# SPECIALIZE idct :: V.Vector Double -> V.Vector Double #-}

-- | Inverse discrete cosine transform (DCT-III) of a complex vector.
--
-- Only the real part of each element is used. Imaginary parts are replaced
-- by zero before the transform, so the result equals 'idct' applied to the
-- real parts. The input length must be a power of two; otherwise 'error'
-- is called.
idct_ :: (G.Vector v CD, G.Vector v Double) => v CD -> v Double
idct_ xs = idctWorker (G.map ((:+ 0) . realPart) xs)
{-# INLINABLE  idct_ #-}
{-# SPECIALIZE idct_ :: U.Vector CD -> U.Vector Double #-}
{-# SPECIALIZE idct_ :: V.Vector CD -> V.Vector Double #-}

-- | Shared implementation of 'idct' and 'idct_'.
--
-- Both callers zero the imaginary parts before calling this.
--
-- The algorithm is FFT-based:
--
-- 1. Bin 0 is weighted by @N@ and bin @k >= 1@ by @2N * exp (i*pi*k / (2N))@.
-- 2. The weighted vector is passed through 'ifft'.
-- 3. The real parts are un-shuffled: output position @2j@ takes element
--    @j@, and output position @2j+1@ takes element @N-1-j@.
idctWorker :: (G.Vector v CD, G.Vector v Double) => v CD -> v Double
{-# INLINE idctWorker #-}
idctWorker xs
  | vectorOK xs = G.generate len interleave
  | otherwise   = error "Statistics.Transform.idct: bad vector length"
  where
    -- Undo the even/odd reordering used by the DCT.
    interleave z
      | even z    = vals `G.unsafeIndex` halve z
      | otherwise = vals `G.unsafeIndex` (len - halve z - 1)
    vals = G.map realPart . ifft $ G.zipWith (*) weights xs
    weights
      = G.cons n
      $ G.generate (len - 1) $ \x ->
          2 * n * exp ((0 :+ 1) * fi (x + 1) * pi / (2 * n))
      where n = fi len
    len = G.length xs



-- | Inverse fast Fourier transform.
--
-- Computed with the conjugation trick: conjugate the input, apply 'fft',
-- conjugate the result and divide by the vector length. The length must
-- be a power of two; otherwise 'error' is called.
ifft :: G.Vector v CD => v CD -> v CD
ifft xs
  | not (vectorOK xs) = error "Statistics.Transform.ifft: bad vector length"
  | otherwise         = G.map unscale (fft (G.map conjugate xs))
  where
    unscale z = conjugate z / scale
    scale     = fi (G.length xs)
{-# INLINABLE  ifft #-}
{-# SPECIALIZE ifft :: U.Vector CD -> U.Vector CD #-}
{-# SPECIALIZE ifft :: V.Vector CD -> V.Vector CD #-}

-- | Radix-2 decimation-in-time fast Fourier transform.
--
-- A copy of the input is transformed in place by 'mfft'. The length must
-- be a power of two; otherwise 'error' is called.
fft :: G.Vector v CD => v CD -> v CD
fft v
  | not (vectorOK v) = error "Statistics.Transform.fft: bad vector length"
  | otherwise        = G.modify mfft v
{-# INLINABLE  fft #-}
{-# SPECIALIZE fft :: U.Vector CD -> U.Vector CD #-}
{-# SPECIALIZE fft :: V.Vector CD -> V.Vector CD #-}

-- | In-place radix-2 decimation-in-time FFT on a mutable vector.
--
-- The vector length must be a power of two. This is NOT checked here;
-- callers validate it with 'vectorOK'.
--
-- The transform runs in two phases:
--
-- 1. 'permute' reorders the elements into bit-reversed index order.
-- 2. 'pass' performs @log2 len@ rounds of butterflies.
mfft :: (M.MVector v CD) => v s CD -> ST s ()
{-# INLINE mfft #-}
mfft vec = permute 0 0
  where
    len    = M.length vec
    rounds = log2 len

    -- Phase 1: bit-reversal permutation.
    --
    -- @rev@ is the bit reverse of @i@. Each pair is swapped only once,
    -- when i < rev. Then rev is advanced to the bit reverse of i+1.
    permute i rev
      | i == len - 1 = pass 0 1
      | otherwise    = do
          when (i < rev) $ M.swap vec i rev
          carry (len `shiftR` 1) rev
      where
        -- Reversed-binary increment: clear set bits from the top down,
        -- then set the first clear bit.
        carry bit r
          | bit <= r  = carry (bit `shiftR` 1) (r - bit)
          | otherwise = permute (i + 1) (r + bit)

    -- Phase 2: butterfly rounds.
    --
    -- In round @r@, @half@ is the distance between butterfly partners.
    -- @width = 2 * half@ is the size of the sub-transforms being merged.
    pass r !half
      | r == rounds = return ()
      | otherwise   = do
          let !width = half `shiftL` 1
              -- Angle increment between successive twiddle factors: -2*pi/width.
              !step  = -6.283185307179586 / fromIntegral width
              -- For offset j in [0, half), with accumulated angle a, run
              -- every butterfly that starts at j. Afterwards move on to
              -- offset j+1 or to the next round.
              twiddleLoop j !a
                | j == half = pass (r + 1) width
                | otherwise = do
                    let butterflies p
                          | p >= len  = twiddleLoop (j + 1) (a + step)
                          | otherwise = do
                              let q = p + half
                              xq :+ yq <- M.read vec q
                              let !cosA = cos a
                                  !sinA = sin a
                                  t     = (cosA*xq - sinA*yq) :+ (sinA*xq + cosA*yq)
                              xp <- M.read vec p
                              M.write vec q (xp - t)
                              M.write vec p (xp + t)
                              butterflies (p + width)
                    butterflies j
          twiddleLoop 0 0


----------------------------------------------------------------
-- Helpers
----------------------------------------------------------------

-- | Convert an 'Int' (a vector length or index) into a complex number
-- with zero imaginary part.
fi :: Int -> CD
fi n = fromIntegral n

-- | Divide by two, rounding towards negative infinity.
--
-- Implemented as an arithmetic right shift by one bit.
halve :: Int -> Int
halve n = shiftR n 1

-- | Check that the vector's length is a positive power of two, as the
-- radix-2 transforms in this module require.
--
-- The empty vector is rejected explicitly. 'log2' is only defined for
-- positive arguments, so without the @n > 0@ guard an empty input would
-- fail inside 'log2' instead of reaching the caller's
-- "bad vector length" error.
vectorOK :: G.Vector v a => v a -> Bool
{-# INLINE vectorOK #-}
vectorOK v = n > 0 && (1 `shiftL` log2 n) == n
  where
    n = G.length v