{-| 
    Module      : Vocoder
    Description : Phase vocoder
    Copyright   : (c) Celina Pawlińska, 2020
                      Marek Materzok, 2021
    License     : BSD2

This module implements the phase vocoder algorithms. 
The implementation is designed to be used directly or to be integrated
into some convenient abstraction (streaming or FRP).
-}
module Vocoder (
      Moduli,
      Phase,
      PhaseInc,
      Frame,
      Window,
      HopSize,
      Length,
      STFTFrame,
      FFTOutput,
      VocoderParams,
      vocoderParams,
      vocFrameLength,
      vocInputFrameLength,
      vocFreqFrameLength,
      vocHopSize,
      vocWindow,
      doFFT,
      doIFFT,
      analysisBlock,
      analysisStep,
      analysisStage,
      synthesisBlock,
      synthesisStep,
      synthesisStage,
      zeroPhase,
      volumeCoeff,
      frameFromComplex,
      frameToComplex,
      addFrames
    ) where

import Data.List
import Data.Complex
import Data.Fixed
import Data.Tuple
import Control.Arrow
import Numeric.FFT.Vector.Invertible as FFT
import Numeric.FFT.Vector.Plan as FFTp
import qualified Data.Vector.Storable as V

-- | Complex moduli of FFT frames, one value per frequency bin.
-- Represent signal amplitudes.
type Moduli = V.Vector Double

-- | Complex arguments of FFT frames, one value per frequency bin, in radians.
-- Represent signal phases.
type Phase = V.Vector Double

-- | Phase increments. Represent the deviation of the phase difference
-- between successive frames from the expected difference for the center
-- frequencies of the FFT bins.
type PhaseInc = V.Vector Double

-- | Time domain frame (a vector of real samples).
type Frame = V.Vector Double

-- | Sampled STFT window function. Structurally the same type as 'Frame'.
type Window = Frame

-- | Offset between successive STFT frames, in samples.
type HopSize = Int

-- | Size in samples.
type Length = Int

-- | STFT processing unit: bin magnitudes paired with per-bin phase increments.
type STFTFrame = (Moduli, PhaseInc)

-- | Frequency domain frame (complex FFT bins).
type FFTOutput = V.Vector (Complex Double)

-- | Type of FFT plans for real signals (real input, complex output).
type FFTPlan = FFTp.Plan Double (Complex Double)

-- | Type of IFFT plans for real signals (complex input, real output).
type IFFTPlan = FFTp.Plan (Complex Double) Double

-- | Configuration parameters for the phase vocoder algorithm.
-- Values are built with 'vocoderParams', which creates both FFT plans
-- for the same transform length.
data VocoderParams = VocoderParams{
    -- | FFT plan used in analysis stage.
    vocFFTPlan  :: FFTPlan,
    -- | FFT plan used in synthesis stage.
    vocIFFTPlan :: IFFTPlan,
    -- | STFT hop size.
    vocHopSize :: HopSize,
    -- | Window function used during analysis and synthesis.
    vocWindow :: Window
    -- TODO thread safety?
}

-- | FFT frequency frame length, i.e. the number of complex bins that the
-- analysis FFT plan produces for one frame.
vocFreqFrameLength :: VocoderParams -> Length
vocFreqFrameLength = planOutputSize . vocFFTPlan

-- | FFT frame length, i.e. the number of real samples the analysis FFT plan
-- consumes. Can be larger than 'vocInputFrameLength' for zero-padding.
vocFrameLength :: VocoderParams -> Length
vocFrameLength = planInputSize . vocFFTPlan

-- | STFT frame length. Defined as the number of samples in the configured
-- window ('vocWindow').
vocInputFrameLength :: VocoderParams -> Length
vocInputFrameLength = V.length . vocWindow

-- | Create a vocoder configuration.
--
-- The first argument is the FFT frame length used to plan both the forward
-- (real-to-complex) and the inverse (complex-to-real) transform, the second
-- is the hop size, and the third is the window. The window length becomes
-- the STFT frame length ('vocInputFrameLength'); no check is made here that
-- it fits into the FFT frame.
vocoderParams :: Length -> HopSize -> Window -> VocoderParams
vocoderParams len hs wnd = VocoderParams
    { vocFFTPlan  = plan dftR2C len
    , vocIFFTPlan = plan dftC2R len
    , vocHopSize  = hs
    , vocWindow   = wnd
    }

-- | Apply a window function on a time domain frame by multiplying the two
-- vectors sample by sample. The result is as long as the shorter argument.
applyWindow :: Window -> Frame -> Frame
applyWindow wnd frame = V.zipWith (*) wnd frame

-- | Change the vector indexing so that the sample at the middle has the number 0.
-- This is done so that the FFT of the window has zero phase, and therefore does not
-- introduce phase shifts in the signal.
--
-- The vector is split at index @length `div` 2@ and the two parts are
-- exchanged, so the second part comes first in the result.
rewind :: (V.Storable a) => V.Vector a -> V.Vector a
rewind vec = back V.++ front
    where
    (front, back) = V.splitAt (V.length vec `div` 2) vec

-- | Zero-pad the signal symmetrically from both sides up to the given
-- target length.
--
-- When the number of padding samples is odd, the extra zero goes in front
-- of the signal. A frame that already has the target length is returned
-- unchanged. Calls 'error' when the frame is longer than the target length.
addZeroPadding :: Length
    -> Frame
    -> Frame
addZeroPadding len v
    -- The excess is reported as a positive number of samples.
    | diff < 0  = error $ "addZeroPadding: input is " ++ show (negate diff) ++ " samples longer than target length"
    | diff == 0 = v
    | otherwise = V.replicate front 0 V.++ v V.++ V.replicate back 0
    where
    -- Total number of zeros to add.
    diff  = len - V.length v
    -- Trailing zeros (rounded down) and leading zeros (rounded up).
    back  = diff `div` 2
    front = diff - back

-- | Perform FFT processing, which includes the actual FFT, rewinding, zero-padding
-- and windowing.
--
-- The steps are applied to the input frame in this order: multiply by the
-- window, zero-pad to the FFT frame length, rewind, execute the FFT plan.
doFFT :: VocoderParams -> Frame -> FFTOutput
doFFT par =
    FFT.execute (vocFFTPlan par)
        . rewind
        . addZeroPadding (vocFrameLength par)
        . applyWindow (vocWindow par)

-- | Perform analysis on a sequence of frames. This consists of FFT processing
-- and performing analysis on frequency domain frames.
--
-- The phase is threaded from left to right through the container: each frame
-- is analyzed against the phase returned for the previous one. Returns the
-- phase of the last frame together with the analyzed frames.
analysisStage :: Traversable t => VocoderParams -> Phase -> t Frame -> (Phase, t STFTFrame)
analysisStage par = mapAccumL (analysisBlock par)

-- | Perform FFT transform and frequency-domain analysis of a single time
-- domain frame. Takes the phase of the previous frame and returns the phase
-- of this frame along with the analysis result.
analysisBlock :: VocoderParams -> Phase -> Frame -> (Phase, STFTFrame)
analysisBlock par prev_ph vec = analysisStep hop frameLen prev_ph spectrum
    where
    hop      = vocHopSize par
    frameLen = vocFrameLength par
    spectrum = doFFT par vec

-- | Analyze a frequency domain frame. Phase from a previous frame must be supplied.
-- It returns the phase of the analyzed frame and the result.
--
-- Arguments: hop size, FFT frame length, previous phase, frequency domain
-- frame. The phase increment of each bin is computed by 'calcPhaseInc' from
-- the bin index and the difference between the current and previous phase.
analysisStep :: HopSize -> Length -> Phase -> FFTOutput -> (Phase, STFTFrame)
analysisStep h eN prev_ph vec = (ph, (mag, ph_inc))
    where
    (mag, ph) = frameFromComplex vec
    ph_diff   = V.zipWith (-) ph prev_ph
    ph_inc    = V.imap (calcPhaseInc eN h) ph_diff

-- | Wraps an angle (in radians) to the half-open range [-pi, pi).
-- An input of exactly @pi@ is mapped to @-pi@.
wrap :: Double -> Double
wrap e = (e + pi) `mod'` (2 * pi) - pi

-- | Compute the phase increment of a single bin.
--
-- Arguments: FFT frame length, hop size, bin index, and the phase difference
-- between the current and the previous frame for that bin.
--
-- @omega@ is the phase advance expected over one hop for the center
-- frequency of bin @k@. The measured difference is compared with it, the
-- deviation is wrapped with 'wrap', added back to @omega@, and the sum is
-- divided by the hop size.
calcPhaseInc :: Length -> HopSize -> Int -> Double -> Double
calcPhaseInc eN hop k ph_diff = (omega + deviation) / hopD
    where
    hopD      = fromIntegral hop
    omega     = (2 * pi * fromIntegral k * hopD) / fromIntegral eN
    deviation = wrap (ph_diff - omega)

-- | Perform synthesis on a sequence of frames. This consists of performing
-- synthesis and IFFT processing.
--
-- The phase is threaded from left to right through the container: each frame
-- is synthesized starting from the phase returned for the previous one.
-- Returns the phase after the last frame together with the time domain frames.
synthesisStage :: Traversable t => VocoderParams -> Phase -> t STFTFrame -> (Phase, t Frame)
synthesisStage par = mapAccumL (synthesisBlock par)

-- | Perform frequency-domain synthesis and IFFT transform of a single frame.
-- Takes the phase of the previously synthesized frame and returns the new
-- phase along with the time domain frame.
synthesisBlock :: VocoderParams -> Phase -> STFTFrame -> (Phase, Frame)
synthesisBlock par ph fr =
    -- The phase passes through unchanged; only the spectrum is transformed.
    second (doIFFT par) $ synthesisStep (vocHopSize par) ph fr

-- | Synthesize a frequency domain frame. Phase from the previously synthesized frame
-- must be supplied. It returns the phase of the synthesized frame and the result.
--
-- The new phase of each bin is the previous phase plus the bin's phase
-- increment multiplied by the hop size. The output spectrum combines the
-- frame's magnitudes with that new phase.
synthesisStep :: HopSize -> Phase -> STFTFrame -> (Phase, FFTOutput)
synthesisStep hop ph (mag, ph_inc) = (new_ph, frameToComplex (mag, new_ph))
    where
    advance = V.map (* fromIntegral hop) ph_inc
    new_ph  = V.zipWith (+) ph advance

-- | Perform IFFT processing, which includes the actual IFFT, rewinding, removing padding
-- and windowing.
--
-- The steps are applied to the spectrum in this order: execute the IFFT
-- plan, rewind, cut out the central 'vocInputFrameLength' samples, multiply
-- by the window.
doIFFT :: VocoderParams -> FFTOutput -> Frame
doIFFT par =
    applyWindow (vocWindow par)
        . cutCenter (vocInputFrameLength par)
        . rewind
        . FFT.execute (vocIFFTPlan par)

-- | Cut the center of a time domain frame, discarding zero padding.
--
-- This undoes 'addZeroPadding': when the number of surplus samples is odd,
-- the extra sample is removed from the front, because that is where
-- 'addZeroPadding' puts the extra zero.
cutCenter :: (V.Storable a) => Length -> V.Vector a -> V.Vector a
cutCenter len vec = V.take len $ V.drop front vec
    where
    -- Total number of surplus samples around the center.
    diff  = V.length vec - len
    -- Leading surplus, rounded up to mirror the padding split.
    front = diff - diff `div` 2

-- | Zero phase for a given vocoder configuration: a vector of zeros with one
-- entry per frequency bin ('vocFreqFrameLength').
-- Can be used to initialize the synthesis stage.
zeroPhase :: VocoderParams -> Phase
zeroPhase par = V.replicate (vocFreqFrameLength par) 0

-- | An amplitude change coefficient for the processing pipeline.
-- Can be used to ensure that the output has the same volume as the input.
--
-- Computed as the hop size divided by the sum of the squared window samples.
volumeCoeff :: VocoderParams -> Double
volumeCoeff par = hop / windowEnergy
    where
    hop          = fromIntegral (vocHopSize par)
    windowEnergy = V.sum (V.map (** 2) (vocWindow par))

-- | Converts frame representation to complex numbers.
-- The first component of the pair is used as the magnitude and the second
-- as the phase (in radians) of each bin.
frameToComplex :: STFTFrame -> FFTOutput
frameToComplex = uncurry (V.zipWith mkPolar)

-- | Converts frame representation to magnitude and phase.
-- The result pairs the magnitudes of all bins with their phases.
frameFromComplex :: FFTOutput -> STFTFrame
frameFromComplex spectrum = (V.map magnitude spectrum, V.map phase spectrum)

-- | Adds STFT frames.
-- Both frames are converted to complex form with 'frameToComplex', summed
-- bin by bin, and the sum is converted back with 'frameFromComplex'.
addFrames :: STFTFrame -> STFTFrame -> STFTFrame
addFrames f1 f2 = frameFromComplex summed
    where
    summed = V.zipWith (+) (frameToComplex f1) (frameToComplex f2)