{-# LANGUAGE BangPatterns        #-}
{-# LANGUAGE MagicHash           #-}
{-# LANGUAGE ScopedTypeVariables #-}

{-# OPTIONS_GHC -Wall #-}

module Streaming.FFT
  ( -- * streaming fft
    streamFFT

    -- * types
  , Transform(..)
  , Bin(..)
  , Signal(..)
  ) where

import Prelude (RealFloat)

import Control.Monad (Monad(return))
import Control.Monad.Primitive
import Data.Complex (Complex(..))
import Data.Either (Either(..))
import Data.Eq (Eq((==)))
import Data.Function (($))
import Data.Ord (Ord(..))
import Data.Primitive.PrimArray
import Data.Primitive.Types
import GHC.Classes (modInt#)
import GHC.Num (Num(..))
import GHC.Real (fromIntegral, RealFrac(..))
import GHC.Types (Int(..))
import Streaming.FFT.Internal (initialDFT, subDFT, updateWindow', rToComplex)
import Streaming.FFT.Types (Window(..), Transform(..), Signal(..), Bin(..))
import Streaming
import Streaming.Prelude (next, yield)

data Depleted
  = NotDepleted -- ^ bin is not depleted
  | Past !Int   -- ^ how many bins we have past

binDepleted :: forall e. (Num e, Ord e, RealFrac e)
  => Bin e
  -> e
  -> e
  -> Depleted
binDepleted (Bin binSize) old new =
  let !k = new - (old + fromIntegral binSize)
  in if k > 0 then Past (floor k) else NotDepleted

-- [NOTE]: A drawback of the dense-stream optimisation
-- is that we must keep track of the number of bins that
-- we ingest that are 0. If too many are 0 w.r.t. the signal
-- size, then we must fall back to the /O(n log n)/ computation
-- until we reach another dense area of the stream. This amounts
-- to keeping an Int around that counts the number of bins that
-- were equal to zero; it gets incremented after each bin is finished
-- loading. So, there should really be two 'thereafter' functions,
-- and 'loadInitial' should do some additional checks.
-- This is currently not the case.
loadInitial :: forall m e b. (Prim e, PrimMonad m, RealFloat e)
  => MutablePrimArray (PrimState m) (Complex e) -- ^ array to which we should allocate
  -> Bin e    -- ^ bin size
  -> Signal e -- ^ signal size
  -> Int      -- ^ index
  -> Int      -- ^ bin accumulator
  -> e        -- ^ bin pivot
  -> Int      -- ^ have we finished consuming the signal
  -> Stream (Of e) m b     -- first part of stream
  -> m (Stream (Of e) m b) -- stream minus original signal
loadInitial !mpa !b s@(Signal !sigSize) !ix !binAccum !binFirst !untilSig st =
  if (untilSig >= sigSize)
    then return st
    else do
      e <- next st
      case e of
        Left _ -> return st
        Right (x, rest) -> if ix == 0
          then loadInitial mpa b s (ix + 1) binAccum x untilSig st
          else do
            let isDepleted = binDepleted b binFirst x
            case isDepleted of
              NotDepleted -> loadInitial mpa b s ix (binAccum + 1) binFirst untilSig rest
              Past i -> do
                let !k = rToComplex (fromIntegral binAccum) :: Complex e
                !_ <- writePrimArray mpa (unsafeMod (ix - 1 + untilSig) sigSize) k :: m ()
                loadInitial mpa b s (ix + i) 0 x (untilSig + 1) rest
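-- To make the bin bookkeeping above concrete: 'binDepleted' compares an
-- incoming value against the current bin pivot plus the bin size. For
-- instance (an illustration of the function as written above, not
-- additional library API):
--
-- @
-- binDepleted (Bin 4) 0 3 -- NotDepleted: 3 - (0 + 4) <= 0, still inside this bin
-- binDepleted (Bin 4) 0 5 -- Past 1:      5 - (0 + 4) = 1, one bin boundary crossed
-- @
--
-- 'loadInitial' and 'thereafter' use this to decide when a bin is complete
-- and its accumulated count should be pushed into the window.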
thereafter :: forall m e b c. (Prim e, PrimMonad m, RealFloat e)
  => (Transform m e -> m c) -- ^ extract
  -> Bin e         -- ^ bin size
  -> Signal e      -- ^ signal size
  -> Int           -- ^ index
  -> Int           -- ^ have we filled a bin
  -> e             -- ^ first thing in the bin
  -> Window m e    -- ^ window
  -> Transform m e -- ^ transform
  -> Stream (Of e) m b
  -> Stream (Of c) m b
thereafter extract !b !s !ix !binAccum !binFirst win trans st = do
  e <- lift $ next st
  case e of
    Left r -> return r
    Right (x, rest) -> if ix == 0
      then thereafter extract b s (ix + 1) binAccum x win trans st
      else do
        let isDepleted = binDepleted b binFirst x
        case isDepleted of
          NotDepleted -> thereafter extract b s ix (binAccum + 1) binFirst win trans rest
          Past i -> do
            let k :: Complex e
                !k = rToComplex (fromIntegral binAccum)
            !trans' <- lift $ subDFT s win k trans
            !info <- lift $ extract trans'
            yield info
            -- a problem is that if too many empty bins pass,
            -- the optimised streaming-fft algorithm fails, and we
            -- need to revert (temporarily) to the original O(n log n)
            -- algorithm.
            !_ <- lift $ updateWindow' win k i
            thereafter extract b s (ix + i) 0 x win trans' rest

-- | 'streamFFT' is based off ideas from signal processing, with an optimisation
-- outlined in .
-- Here, I will give an outline of how this works. The idea is that we
-- have a stream of data, which we divide into 'Signal's, and each 'Signal'
-- is something for which we want to compute the DFT. Each signal is divided into
-- 'Bin's (more on this later, but you can think of a 'Bin' as a chunk of a
-- 'Signal', where all the chunks are of equal length). We treat our stream not as
-- contiguous blocks of 'Signal's, but as overlapping 'Signal's, where each overlap
-- is one 'Bin'-length. The motivation for the blog post is to reduce the work of
-- this overlap; it shows a way to compute the DFT of each 'Signal' subsequent
-- to the initial one in /O(n)/ time, instead of the typical /O(n log n)/ time,
-- by exploiting the overlap.
--
-- Suppose you would like to compute the Fourier Transform of the signal
--
-- \[
-- x_{i-n+1}, x_{i-n+2}, ..., x_{i-1}, x_{i}.
-- \]
--
-- However, this means that when you receive \( x_{i+1} \), you will be computing
-- the Fourier Transform of
--
-- \[
-- x_{i-n+2}, x_{i-n+3}, ..., x_{i}, x_{i+1},
-- \]
--
-- which is almost identical to the first sequence. How do we avoid the extra work?
--
-- Assume data windows of length \( N \) (this corresponds to the number of
-- 'Bin's in the 'Signal'). Let
-- the original data window be \( x_{1} \), whose first sample is \( x_{old} = x_{1}[0] \)
-- (here, \( a[k] \) denotes the element of a sequence \( a \) at index \( k \),
-- counting from zero). Let the new data window be denoted \( x_{2} \), whose
-- bins are a one-step left-shifted version of \( x_{1} \), i.e.
-- \( x_{2}[k] = x_{1}[k+1] \) for \( k = 0, 1, ..., N - 2 \), plus a newly arrived datum
-- at position \( N - 1 \), which is denoted \( x_{new} = x_{2}[N - 1] \).
--
-- The following computes the N-point DFT \( X_{2} \) of the new data set
-- \( x_{2} \) from the already computed and stored N-point DFT
-- \( X_{1} \) of the old data set \( x_{1} \):
--
-- \[
-- X_{2}[k] = e^{2 \pi i k / N} (X_{1}[k] + (x_{new} - x_{old}))
-- \]
--
-- for each \( k = 0, 1, ..., N - 1 \). This update of \( X_{2} \) from the
-- pre-computed \( X_{1} \) requires \( N \) complex multiplications and \( N \)
-- real additions. Compared to recomputing the N-point DFT with an FFT, which
-- requires \( N \log_{2}(N) \) complex multiply-accumulate operations, this is
-- an improvement by a factor of \( \log_{2}(N) \), which for example at
-- \( N = 1024 \) would translate to a speedup of about 10.
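--
-- As a sketch of what one update step looks like (this is only an illustration
-- over immutable lists, not the in-place implementation used internally via
-- 'subDFT' and 'updateWindow''; the names here are hypothetical):
--
-- @
-- -- Given the previous transform xs1 (length N), the sample leaving the
-- -- window (xOld) and the sample entering it (xNew), produce the new transform.
-- slideDFT :: RealFloat a => [Complex a] -> Complex a -> Complex a -> [Complex a]
-- slideDFT xs1 xOld xNew =
--   [ cis (2 * pi * fromIntegral k / n) * (xk + (xNew - xOld))
--   | (k, xk) <- zip [0 :: Int ..] xs1
--   ]
--   where n = fromIntegral (length xs1)
-- @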
--
-- Another advantage of this algorithm is that it is amenable to being performed
-- in-place. 'streamFFT' does in fact do this, and for that reason allocations are
-- kept to an absolute minimum.
streamFFT :: forall m a b c. (Prim a, PrimMonad m, RealFloat a)
  => (Transform m a -> m c)
     -- ^ extraction method. This is a function that takes a 'Transform'
     --   and produces (or 'extracts') some value from it. It is used
     --   to produce the values in the output stream.
  -> Bin a    -- ^ bin size
  -> Signal a -- ^ signal size
  -> Stream (Of a) m b -- ^ input stream
  -> Stream (Of c) m b -- ^ output stream
{-# INLINABLE streamFFT #-}
streamFFT extract b s@(Signal sigSize) strm = do
  -- Allocate the one array
  mpaW <- lift $ newPrimArray sigSize
  let win = Window mpaW

  -- Grab the first signal from the stream
  subStrm <- lift $ loadInitial mpaW b s 0 0 0 0 strm

  -- Compute the transform on the signal we just grabbed,
  -- so we can perform our dense-stream optimisation
  !initialT <- lift $ initialDFT win

  -- Extract information from that transform
  !initialInfo <- lift $ extract initialT

  -- Yield that information to the new stream
  !_ <- yield initialInfo

  -- Now go
  thereafter extract b s 0 0 0 win initialT subStrm

-- | Only safe when the second argument is not 0
unsafeMod :: Int -> Int -> Int
unsafeMod (I# x#) (I# y#) = I# (modInt# x# y#)
{-# INLINE unsafeMod #-} -- this should happen anyway. trust but verify.
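-- A minimal usage sketch of 'streamFFT' (illustrative only; the extraction
-- function here simply ignores the 'Transform' and returns (), whereas a real
-- extractor would read information out of the transform, and it is assumed,
-- as this module's own pattern matches suggest, that 'Bin' and 'Signal' wrap
-- an element count):
--
-- @
-- onePerSignal :: PrimMonad m => Stream (Of Double) m r -> Stream (Of ()) m r
-- onePerSignal = streamFFT (\_transform -> return ()) (Bin 16) (Signal 256)
-- @
--
-- Each element of the output stream corresponds to one (overlapped) 'Signal'
-- whose transform has just been updated.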