{-| 
    Module      : Vocoder.Dunai
    Description : Phase vocoder in Dunai
    Copyright   : (c) Marek Materzok, 2021
    License     : BSD2

This module wraps phase vocoder algorithms for use in Dunai and Rhine.
-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE BangPatterns #-}
module Vocoder.Dunai (
    volumeFix,
    analysis,
    synthesis,
    processFrames,
    process,
    framesOfS,
    sumFramesS,
    sumFramesWithLengthS
    ) where

import Data.MonadicStreamFunction
import Data.Tuple(swap)
import Data.Maybe(fromMaybe)
import qualified Data.Vector.Storable as V
import Vocoder

-- | Perform the phase vocoder analysis phase.
--
--   Runs 'analysisStage' as a Mealy machine. The 'Phase' argument is the
--   initial state, and the phase returned by each step is fed into the next.
analysis :: (Traversable t, Monad m) => VocoderParams -> Phase -> MSF m (t Frame) (t STFTFrame)
analysis par = mealy step
    where
    -- 'analysisStage' yields (newPhase, output) while 'mealy' expects
    -- (output, newState), so the pair is swapped.
    step frames phase = swap (analysisStage par phase frames)

-- | Perform the phase vocoder synthesis phase.
--
--   Runs 'synthesisStage' as a Mealy machine. The 'Phase' argument is the
--   initial state, and the phase returned by each step is fed into the next.
synthesis :: (Traversable t, Monad m) => VocoderParams -> Phase -> MSF m (t STFTFrame) (t Frame)
synthesis par = mealy step
    where
    -- 'synthesisStage' yields (newPhase, output) while 'mealy' expects
    -- (output, newState), so the pair is swapped.
    step stftFrames phase = swap (synthesisStage par phase stftFrames)

-- | Perform frequency domain processing on overlapping frames.
--
--   The frames are analysed, passed through the given frequency domain
--   stream function, and synthesised again. Both the analysis and the
--   synthesis stage start from 'zeroPhase'.
processFrames :: (Traversable t, Monad m) => VocoderParams -> MSF m (t STFTFrame) (t STFTFrame) -> MSF m (t Frame) (t Frame)
processFrames par msf = analysis par initialPhase >>> msf >>> synthesis par initialPhase
    where
    initialPhase = zeroPhase par

-- | Corrects for volume change introduced by STFT processing.
--
--   Every sample of each frame is multiplied by 'volumeCoeff' of the
--   given parameters.
volumeFix :: Monad m => VocoderParams -> MSF m Frame Frame
volumeFix par = arr rescale
    where
    coeff = volumeCoeff par
    rescale = V.map (* coeff)

-- | Perform frequency domain processing on a chunked stream.
--
--   The processing runs in four stages:
--
--   1. Each input chunk is split into overlapping frames ('framesOfS').
--   2. The frames are run through 'processFrames'.
--   3. They are overlap-added into an output chunk whose length equals the
--      input chunk's length ('sumFramesWithLengthS').
--   4. The output chunk is rescaled by 'volumeFix'.
--
--   The chunks' size must be a multiple of the vocoder's hop size.
process :: Monad m => VocoderParams -> MSF m [STFTFrame] [STFTFrame] -> MSF m Frame Frame
process par msf =
    (toProcessedFrames &&& arr V.length)
        >>> sumFramesWithLengthS hop
        >>> volumeFix par
    where
    hop = vocHopSize par
    -- Overlapping frames of the current chunk after analysis, frequency
    -- domain processing and synthesis.
    toProcessedFrames = framesOfS (vocInputFrameLength par) hop >>> processFrames par msf

-- | A frame queued for overlap-adding in 'sumFramesWithLengthS'.
--
--   It holds the frame's starting offset, in samples, relative to the start
--   of the current output chunk. The offset can become negative once the
--   frame has been carried into later chunks. It also holds the frame's
--   samples.
data P a = P
    {-# UNPACK #-} !Length        -- starting offset
    {-# UNPACK #-} !(V.Vector a)  -- frame samples

-- | Transforms both components of a 'P'. The first function is applied to
--   the offset and the second to the sample vector.
mapP :: (Length -> Length) -> (V.Vector a1 -> V.Vector a2) -> P a1 -> P a2
mapP onOffset onSamples entry = case entry of
    P offset samples -> P (onOffset offset) (onSamples samples)

-- | Splits a chunked input stream into overlapping frames of constant size
--   suitable for STFT processing.
--
--   The function keeps a history buffer of
--   @((frameLen - 1) `div` hopSize) * hopSize@ samples, initially all zeros.
--   For every input chunk of length @len@, it emits @len `div` hopSize@ frames.
--   Frame @k@ consists of at most @frameLen@ samples taken from position
--   @k * hopSize@ of (history ++ chunk). The history then advances by @len@
--   samples.
--
--   The input chunks' size must be a multiple of the hop size; otherwise the
--   frame grid drifts between chunks.
framesOfS
    :: forall a m. (V.Storable a, Num a, Monad m)
    => Length   -- ^ frame length
    -> HopSize  -- ^ hop size
    -> MSF m (V.Vector a) [V.Vector a]
framesOfS frameLen hopSize = mealy step (V.replicate historyLen 0)
    where
    -- Largest multiple of the hop size not exceeding @frameLen - 1@.
    historyLen :: Length
    historyLen = ((frameLen - 1) `div` hopSize) * hopSize

    step :: V.Vector a -> V.Vector a -> ([V.Vector a], V.Vector a)
    step chunk history = (frames, history')
        where
        chunkLen = V.length chunk
        window = history V.++ chunk
        -- Keep the history length constant by dropping as many samples as
        -- were just appended.
        history' = V.drop chunkLen window
        starts = take (chunkLen `div` hopSize) [0, hopSize ..]
        frames = [V.take frameLen (V.drop start window) | start <- starts]

-- | Builds a chunked output stream from a stream of overlapping frames.
--
--   Every output chunk has the fixed length @chunkSize@. This is
--   'sumFramesWithLengthS' with a constant chunk length.
--   The chunks' size must be a multiple of the vocoder's hop size.
sumFramesS :: forall a m. (V.Storable a, Num a, Monad m) => Length -> HopSize -> MSF m [V.Vector a] (V.Vector a)
sumFramesS chunkSize hopSize = arr withChunkSize >>> sumFramesWithLengthS hopSize
    where
    withChunkSize frames = (frames, chunkSize)

-- | Overlap-adds a stream of frames into output chunks. The length of each
--   output chunk is supplied alongside each batch of frames.
--
--   On each step:
--
--   * The new frames are placed at offsets @0, hopSize, 2 * hopSize, ...@
--     relative to the start of the current output chunk and queued after
--     the frames still pending from earlier steps.
--   * Each output sample is the sum of the corresponding samples of all
--     queued frames. Positions a frame does not cover contribute 0.
--   * Leading frames that end within the chunk are dropped. The offsets of
--     the remaining frames are shifted back by the chunk length.
sumFramesWithLengthS :: forall a m. (V.Storable a, Num a, Monad m) => HopSize -> MSF m ([V.Vector a], Length) (V.Vector a)
sumFramesWithLengthS hopSize = mealy step []
    where
    step :: ([V.Vector a], Length) -> [P a] -> (V.Vector a, [P a])
    step (frames, outLen) pending = (output, carried)
        where
        -- Older pending frames first, then this step's frames on the hop grid.
        queued = pending ++ zipWith P [0, hopSize ..] frames
        -- Sample of a queued frame at output position i, or 0 if outside it.
        sampleAt i (P offset samples) = fromMaybe 0 (samples V.!? (i - offset))
        output = V.generate outLen (\i -> sum (map (sampleAt i) queued))
        endsInside (P offset samples) = offset + V.length samples <= outLen
        -- Only a leading run of finished frames is discarded.
        carried = map (mapP (subtract outLen) id) (dropWhile endsInside queued)