{-# LANGUAGE NoImplicitPrelude #-}
{-# LANGUAGE FlexibleContexts #-}
{- |
Copyright   :  (c) Henning Thielemann 2008-2009
License     :  GPL

Maintainer  :  synthesizer@henning-thielemann.de
Stability   :  provisional
Portability :  requires multi-parameter type classes
-}
module Synthesizer.Storable.Filter.NonRecursive where

import qualified Synthesizer.Storable.Signal as SigSt
import qualified Data.StorableVector as V
import qualified Data.StorableVector.Lazy as VL
import qualified Data.StorableVector.Lazy.Pattern as VP

import qualified Synthesizer.Generic.Signal as SigG
import qualified Synthesizer.State.Signal as SigS
import qualified Synthesizer.Plain.Filter.NonRecursive as Filt
import qualified Synthesizer.Generic.Filter.NonRecursive as FiltG

import qualified Algebra.Module         as Module
import qualified Algebra.Field          as Field
import qualified Algebra.Ring           as Ring
import qualified Algebra.Additive       as Additive

import Foreign.Storable (Storable, )
import Foreign.Storable.Tuple ()

import Algebra.Module( {- linearComb, -} (*>), )

import Control.Monad (mplus, )

import qualified Data.List as List
import Data.Tuple.HT (mapFst, mapSnd, mapPair, swap, )

import qualified Numeric.NonNegative.Chunky as NonNegChunky

import PreludeBase
import NumericPrelude as NP
import qualified Prelude as P


{- |
Sum adjacent pairs of a strict vector, halving its length.

The Maybe type carries an unpaired value from one block to the next one:
the first argument is an element left over from the previous block,
which is paired with the first element of @ys@.
The first component of the result is the element left over
at the end of this block, if any.
If @ys@ is empty, an incoming carry is passed through unchanged
instead of being lost.
-}
sumsDownsample2Strict ::
   (Additive.C v, Storable v) =>
   Maybe v -> V.Vector v -> (Maybe v, V.Vector v)
sumsDownsample2Strict carry ys =
   mapFst takeLeftover $ swap $
   V.unfoldrN (div (V.length ys + maybe 0 (const 1) carry) 2) (\(carry0,xs0) ->
      do (x0,xs1) <- mplus (fmap (\c -> (c, xs0)) carry0) (V.viewL xs0)
         (x1,xs2) <- V.viewL xs1
         return (x0+x1, (Nothing, xs2)))
      (carry, ys)
   where
      {- The state left over by unfoldrN contains either
         a carry that was never consumed
         (this happens only if no pair could be formed, i.e. ys is empty)
         or at most one unpaired element of ys.
         The carry must be checked first, otherwise it would be dropped. -}
      takeLeftover rest = do
         (pending, xs) <- rest
         mplus pending (fmap fst (V.viewL xs))

{- |
Sum adjacent pairs of samples of a lazy storable signal.

Every chunk is processed with 'sumsDownsample2Strict'.
An element left unpaired at the end of a chunk
is paired with the first element of the next chunk.
A final unpaired element is emitted unchanged as a one-element chunk.
Chunks that end up empty are dropped.
-}
sumsDownsample2 ::
   (Additive.C v, Storable v) =>
   SigSt.T v -> SigSt.T v
sumsDownsample2 xs =
   let (leftover, sums) =
          List.mapAccumL sumsDownsample2Strict Nothing (SigSt.chunks xs)
       trailer = maybe [] (\cr -> [V.singleton cr]) leftover
   in  SigSt.fromChunks $ filter (not . V.null) $ sums ++ trailer

{- |
Alternative implementation of pairwise summation.

Adjacent pairs of samples are summed.
A trailing unpaired sample (odd signal length) is emitted unchanged.
The chunk structure of the output is derived from the lazy length
of the input via 'halfLazySize', not from the input chunks directly.
-}
sumsDownsample2Alt ::
   (Additive.C v, Storable v) =>
   SigSt.T v -> SigSt.T v
sumsDownsample2Alt ys =
   fst $
   VP.unfoldrN (halfLazySize (VP.length ys)) nextPair (SigS.fromStorableSignal ys)
   where
      -- Take one sample and, if available, add its successor.
      nextPair xs = fmap addSuccessor (SigS.viewL xs)
      -- If no successor exists, the lone sample is passed through.
      addSuccessor cell@(x0, rest) =
         SigS.switchL cell (\x1 rest1 -> (x0+x1, rest1)) rest

{- |
Halve a lazy chunked size, rounding up.

Each chunk size is halved (rounding down).
The remainder of one chunk is added to the next chunk before halving.
The final remainder (0 or 1) is appended as an extra chunk.
Zero-sized chunks are removed.
-}
halfLazySize :: NonNegChunky.T VP.ChunkSize -> NonNegChunky.T VP.ChunkSize
halfLazySize =
   NonNegChunky.fromChunks .
   filter (/= VL.ChunkSize zero) .
   appendRemainder .
   List.mapAccumL halveChunk zero .
   NonNegChunky.toChunks
   where
      halveChunk remainder (VL.ChunkSize l) =
         let (q, r) = divMod (remainder + l) 2
         in  (r, VL.ChunkSize q)
      appendRemainder (remainder, halves) =
         halves ++ [VL.ChunkSize remainder]

{- |
Keep every second element of a strict vector.

The offset must be zero or one.
With offset 0 the elements at indices 0, 2, 4, ... are kept.
With offset 1 the elements at indices 1, 3, 5, ... are kept.
-}
downsample2Strict ::
   (Storable v) =>
   Int -> V.Vector v -> V.Vector v
downsample2Strict offset ys =
   let -- number of kept elements: ceiling ((length ys - offset) / 2)
       count = - div (offset - V.length ys) 2
       start
          | offset == 0 = ys
          | otherwise   = laxTailStrict ys
       -- emit the head, then skip the element that follows it
       step xs = fmap (mapSnd laxTailStrict) (V.viewL xs)
   in  fst (V.unfoldrN count step start)

{- |
Drop the first element of a strict vector, if there is one.
An empty vector is returned unchanged.
-}
laxTailStrict ::
   (Storable v) =>
   V.Vector v -> V.Vector v
laxTailStrict ys =
   maybe ys snd (V.viewL ys)

{- |
Keep the samples at even indices (0, 2, 4, ...) of a lazy storable signal.

Each chunk is downsampled separately.
The parity of the running position tells
whether a chunk must start at its first or at its second element.
Chunks that become empty are dropped.
-}
downsample2 ::
   (Storable v) =>
   SigSt.T v -> SigSt.T v
downsample2 sig =
   let step offset chunk =
          (mod (offset + V.length chunk) 2, downsample2Strict offset chunk)
       (_, parts) = List.mapAccumL step zero (SigSt.chunks sig)
   in  SigSt.fromChunks (filter (not . V.null) parts)


{- |
Build a summation pyramid with @height+1@ levels.

Level 0 is the input signal itself.
Every further level is 'sumsDownsample2' applied to the previous level.
-}
pyramid ::
   (Additive.C v, Storable v) =>
   Int -> SigSt.T v -> [SigSt.T v]
pyramid height xs =
   take (height+1) (iterate sumsDownsample2 xs)

{- |
Sum of the samples of the first pyramid level within a range.

The coarser levels of the pyramid are used
to avoid adding up every sample individually.
The pyramid is expected to have the shape produced by 'pyramid':
the first element is the original signal,
and each further element holds sums of adjacent pairs of its predecessor.
The range @(l,r)@ is half-open, i.e. it covers the indices @l@ to @r-1@.
NOTE(review): the range first passes through 'Filt.sumRangePrepare';
how that treats e.g. @l > r@ is not visible in this module.
An empty pyramid list is an error.

This function uses the efficient Storable.index function.
If @Generic.index@ becomes as fast as @Storable.index@
then we can replace this function by its generic counterpart.
-}
sumRangeFromPyramid ::
   (Additive.C v, Storable v) =>
   [SigSt.T v] -> (Int,Int) -> v
sumRangeFromPyramid =
   Filt.sumRangePrepare $ \(l0,r0) pyr0 ->
   case pyr0 of
      [] -> error "empty pyramid"
      (ps0:pss) ->
         -- Build a chain of summation steps, one per coarser level.
         -- Each step receives the continuation 'k' for the next coarser level,
         -- the range at the current level, the current level
         -- and the running sum.
         foldr
            (\psNext k (l,r) ps s ->
               case r-l of
                  0 -> s
                  1 -> s + VL.index ps l
                  _ ->
                     -- lh = ceiling (l/2); ll is -1 if l is odd, otherwise 0.
                     -- For odd l the pair partner l-1 lies outside the range,
                     -- so element l is added individually.
                     -- rh = floor (r/2); rl = r mod 2.
                     -- For odd r the element r-1 has its partner r outside
                     -- the range, so it is added individually as well.
                     let (lh,ll) = NP.negate $ divMod (NP.negate l) 2
                         (rh,rl) = divMod r 2
                         {-# INLINE inc #-}
                         inc b x = if b==0 then id else (x+)
                     in  k (lh,rh) psNext $
                         inc ll (VL.index ps l) $
                         inc rl (VL.index ps (r-1)) $
                         s)
            -- Coarsest level reached: sum the remaining range directly.
            (\(l,r) ps s ->
               s + (SigG.sum $ SigSt.take (r-l) $ SigSt.drop l ps))
            pss (l0,r0) ps0 zero

{- |
Moving sum with a time-varying window.

For output time @t@ with control value @(l,r)@,
the result is the sum of the input samples with indices in @[t+l, t+r)@.
Both window bounds must always be non-negative.

The computation uses a summation pyramid of the given height.
It proceeds in blocks of @2^height@ control values,
and the output is emitted in chunks of at most that size.
-}
sumsPosModulatedPyramid ::
   (Additive.C v, Storable (Int,Int), Storable v) =>
   Int -> SigSt.T (Int,Int) -> SigSt.T v -> SigSt.T v
sumsPosModulatedPyramid height ctrl xs =
   SigSt.fromChunks $ zipWith sumBlock pyrStarts ctrlBlocks
   where
      -- how far each pyramid level advances per block: 2^height, ..., 2, 1
      sizes = reverse $ take (1+height) $ iterate (2*) 1
      blockSize = head sizes
      -- the pyramid, shifted forward by one more block per list element
      pyrStarts = iterate (zipWith SigSt.drop sizes) (pyramid height xs)
      ctrlBlocks = SigS.toList $ SigG.sliceVertical blockSize ctrl
      -- Within a block, shift the window bounds by the position in the block.
      sumBlock pyr =
         SigS.toStrictStorableSignal blockSize .
         SigS.map (sumRangeFromPyramid pyr) .
         SigS.zipWith (\d -> mapPair ((d+), (d+))) (SigS.iterate (1+) 0) .
         SigS.fromStorableSignal

{- |
Moving average whose window radius is controlled by a signal.

The first argument is the amplification.
The main reason to introduce it
was to have only a Module constraint instead of Field.
This way we can also filter stereo signals.

For a control value @c@ at time @t@, the output is
@amp / (2*c+1)@ times the sum of the @2*c+1@ input samples
at times @t-c@ up to and including @t+c@.
The input is delayed by @maxC@ (the maximum radius).
The left window bound @maxC-c@ must be non-negative,
so control values must not exceed @maxC@.
NOTE(review): samples before the start of the input come from
'FiltG.delay'; presumably these are zeros.
@height@ is the height of the summation pyramid.
-}
movingAverageModulatedPyramid ::
   (Field.C a, Module.C a v,
    Storable Int, Storable v) =>
   a -> Int -> Int -> SigSt.T Int -> SigSt.T v -> SigSt.T v
movingAverageModulatedPyramid amp height maxC ctrl xs =
   SigSt.zipWith (\c x -> (amp / fromIntegral (2*c+1)) *> x) ctrl $
   sumsPosModulatedPyramid height
      {- After the delay, the sample of time t sits at index t+maxC.
         The window is half-open, so [maxC-c, maxC+c+1) relative to t
         covers exactly the 2*c+1 samples we normalize by.
         Without the +1 only 2*c samples were summed
         and c=0 always yielded zero. -}
      (SigSt.map (\c -> (maxC - c, maxC + c + 1)) ctrl)
      (FiltG.delay maxC xs)