module SDR.FilterInternal where
import           Control.Monad.Primitive 
import           Control.Monad
import           Foreign.C.Types
import           Foreign.Ptr
import           Data.Coerce
import           Data.Complex
import           Foreign.Marshal.Array
import           Foreign.Marshal.Alloc
import           Foreign.Storable
import qualified Data.Vector.Generic               as VG
import qualified Data.Vector.Generic.Mutable       as VGM
import qualified Data.Vector.Storable              as VS
import qualified Data.Vector.Storable.Mutable      as VSM
import qualified Data.Vector.Fusion.Bundle         as VFB
import           SDR.VectorUtils
import           SDR.Util
filterHighLevel :: (PrimMonad m, Functor m, Num a, Mult a b, VG.Vector v a, VG.Vector v b, VGM.MVector vm a) => v b -> Int -> v a -> vm (PrimState m) a -> m ()
filterHighLevel coeffs num inBuf outBuf = fill (VFB.generate num dotProd) outBuf
    where
    dotProd offset = VG.sum $ VG.zipWith mult (VG.unsafeDrop offset inBuf) coeffs
filterImperative1 :: (PrimMonad m, Functor m, Num a, Mult a b, VG.Vector v a, VG.Vector v b, VGM.MVector vm a) => v b -> Int -> v a -> vm (PrimState m) a -> m ()
filterImperative1 coeffs num inBuf outBuf = go 0
    where
    go offset 
        | offset < num = do
            let res = dotProd offset
            VGM.unsafeWrite outBuf offset res
            go $ offset + 1
        | otherwise    = return ()
    dotProd offset = VG.sum $ VG.zipWith mult (VG.unsafeDrop offset inBuf) coeffs
filterImperative2 :: (PrimMonad m, Functor m, Num a, Mult a b, VG.Vector v a, VG.Vector v b, VGM.MVector vm a) => v b -> Int -> v a -> vm (PrimState m) a -> m ()
filterImperative2 coeffs num inBuf outBuf = go 0
    where
    go offset 
        | offset < num = do
            let res = dotProd (VG.unsafeDrop offset inBuf)
            VGM.unsafeWrite outBuf offset res
            go $ offset + 1
        | otherwise    = return ()
    dotProd buf = go 0 0
        where
        go !accum j 
            | j < VG.length coeffs = go (VG.unsafeIndex buf j `mult` VG.unsafeIndex coeffs j  + accum) (j + 1)
            | otherwise            = accum
type FilterCRR = CInt -> CInt -> Ptr CFloat -> Ptr CFloat -> Ptr CFloat -> IO ()
type FilterRR  = VS.Vector Float -> Int -> VS.Vector Float -> VS.MVector RealWorld Float -> IO ()
type FilterRC  = VS.Vector Float -> Int -> VS.Vector (Complex Float) -> VS.MVector RealWorld (Complex Float) -> IO ()
filterFFIR :: FilterCRR -> FilterRR 
filterFFIR func coeffs num inBuf outBuf = 
    VS.unsafeWith (coerce coeffs) $ \cPtr -> 
        VS.unsafeWith (coerce inBuf) $ \iPtr -> 
            VSM.unsafeWith (coerce outBuf) $ \oPtr -> 
                func (fromIntegral num) (fromIntegral $ VG.length coeffs) cPtr iPtr oPtr
filterFFIC :: FilterCRR -> FilterRC 
filterFFIC func coeffs num inBuf outBuf = 
    VS.unsafeWith (coerce coeffs) $ \cPtr -> 
        VS.unsafeWith (coerce inBuf) $ \iPtr -> 
            VSM.unsafeWith (coerce outBuf) $ \oPtr -> 
                func (fromIntegral num) (fromIntegral $ VG.length coeffs) cPtr iPtr oPtr
foreign import ccall unsafe "filterRR"
    filterRR_c :: CInt -> CInt -> Ptr CFloat -> Ptr CFloat -> Ptr CFloat -> IO ()
filterCRR :: FilterRR
filterCRR = filterFFIR filterRR_c 
foreign import ccall unsafe "filterRC"
    filterRC_c :: CInt -> CInt -> Ptr CFloat -> Ptr CFloat -> Ptr CFloat -> IO ()
filterCRC :: FilterRC
filterCRC = filterFFIC filterRC_c
foreign import ccall unsafe "filterSSERR"
    filterSSERR_c :: CInt -> CInt -> Ptr CFloat -> Ptr CFloat -> Ptr CFloat -> IO ()
filterCSSERR :: FilterRR
filterCSSERR = filterFFIR filterSSERR_c
foreign import ccall unsafe "filterSSESymmetricRR"
    filterSSESymmetricRR_c :: CInt -> CInt -> Ptr CFloat -> Ptr CFloat -> Ptr CFloat -> IO ()
filterCSSESymmetricRR :: FilterRR
filterCSSESymmetricRR = filterFFIR filterSSESymmetricRR_c
foreign import ccall unsafe "filterSSERC"
    filterSSERC_c :: CInt -> CInt -> Ptr CFloat -> Ptr CFloat -> Ptr CFloat -> IO ()
filterCSSERC :: FilterRC
filterCSSERC = filterFFIC filterSSERC_c
foreign import ccall unsafe "filterSSERC2"
    filterSSERC2_c :: CInt -> CInt -> Ptr CFloat -> Ptr CFloat -> Ptr CFloat -> IO ()
filterCSSERC2 :: FilterRC
filterCSSERC2 = filterFFIC filterSSERC2_c
foreign import ccall unsafe "filterAVXRR"
    filterAVXRR_c :: CInt -> CInt -> Ptr CFloat -> Ptr CFloat -> Ptr CFloat -> IO ()
filterCAVXRR :: FilterRR
filterCAVXRR = filterFFIR filterAVXRR_c
foreign import ccall unsafe "filterAVXSymmetricRR"
    filterAVXSymmetricRR_c :: CInt -> CInt -> Ptr CFloat -> Ptr CFloat -> Ptr CFloat -> IO ()
filterCAVXSymmetricRR :: FilterRR
filterCAVXSymmetricRR = filterFFIR filterAVXSymmetricRR_c
foreign import ccall unsafe "filterAVXRC"
    filterAVXRC_c :: CInt -> CInt -> Ptr CFloat -> Ptr CFloat -> Ptr CFloat -> IO ()
filterCAVXRC :: FilterRC
filterCAVXRC = filterFFIC filterAVXRC_c
foreign import ccall unsafe "filterAVXRC2"
    filterAVXRC2_c :: CInt -> CInt -> Ptr CFloat -> Ptr CFloat -> Ptr CFloat -> IO ()
filterCAVXRC2 :: FilterRC
filterCAVXRC2 = filterFFIC filterAVXRC2_c
foreign import ccall unsafe "filterSSESymmetricRC"
    filterSSESymmetricRC_c :: CInt -> CInt -> Ptr CFloat -> Ptr CFloat -> Ptr CFloat -> IO ()
filterCSSESymmetricRC :: FilterRC
filterCSSESymmetricRC = filterFFIC filterSSESymmetricRC_c
foreign import ccall unsafe "filterAVXSymmetricRC"
    filterAVXSymmetricRC_c :: CInt -> CInt -> Ptr CFloat -> Ptr CFloat -> Ptr CFloat -> IO ()
filterCAVXSymmetricRC :: FilterRC
filterCAVXSymmetricRC = filterFFIC filterAVXSymmetricRC_c
decimateHighLevel :: (PrimMonad m, Functor m, Num a, Mult a b, VG.Vector v a, VG.Vector v b, VGM.MVector vm a) => Int -> v b -> Int -> v a -> vm (PrimState m) a -> m ()
decimateHighLevel factor coeffs num inBuf outBuf = fill x outBuf
    where 
    x = VFB.map dotProd (VFB.iterateN num (+ factor) 0)
    dotProd offset = VG.sum $ VG.zipWith mult (VG.unsafeDrop offset inBuf) coeffs
type DecimateCRR = CInt -> CInt -> CInt -> Ptr CFloat -> Ptr CFloat -> Ptr CFloat -> IO ()
type DecimateRR  = Int -> VS.Vector Float -> Int -> VS.Vector Float -> VS.MVector RealWorld Float -> IO ()
type DecimateRC  = Int -> VS.Vector Float -> Int -> VS.Vector (Complex Float) -> VS.MVector RealWorld (Complex Float) -> IO ()
decimateFFIR :: DecimateCRR -> DecimateRR 
decimateFFIR func factor coeffs num inBuf outBuf = 
    VS.unsafeWith (coerce coeffs) $ \cPtr -> 
        VS.unsafeWith (coerce inBuf) $ \iPtr -> 
            VSM.unsafeWith (coerce outBuf) $ \oPtr -> 
                func (fromIntegral num) (fromIntegral factor) (fromIntegral $ VG.length coeffs) cPtr iPtr oPtr
decimateFFIC :: DecimateCRR -> DecimateRC 
decimateFFIC func factor coeffs num inBuf outBuf = 
    VS.unsafeWith (coerce coeffs) $ \cPtr -> 
        VS.unsafeWith (coerce inBuf) $ \iPtr -> 
            VSM.unsafeWith (coerce outBuf) $ \oPtr -> 
                func (fromIntegral num) (fromIntegral factor) (fromIntegral $ VG.length coeffs) cPtr iPtr oPtr
foreign import ccall unsafe "decimateRR"
    decimateCRR_c :: CInt -> CInt -> CInt -> Ptr CFloat -> Ptr CFloat -> Ptr CFloat -> IO ()
decimateCRR :: DecimateRR
decimateCRR = decimateFFIR decimateCRR_c
foreign import ccall unsafe "decimateRC"
    decimateCRC_c :: CInt -> CInt -> CInt -> Ptr CFloat -> Ptr CFloat -> Ptr CFloat -> IO ()
decimateCRC :: DecimateRC
decimateCRC = decimateFFIC decimateCRC_c
foreign import ccall unsafe "decimateSSERR"
    decimateSSERR_c :: CInt -> CInt -> CInt -> Ptr CFloat -> Ptr CFloat -> Ptr CFloat -> IO ()
decimateCSSERR :: DecimateRR
decimateCSSERR = decimateFFIR decimateSSERR_c
foreign import ccall unsafe "decimateSSERC"
    decimateSSERC_c :: CInt -> CInt -> CInt -> Ptr CFloat -> Ptr CFloat -> Ptr CFloat -> IO ()
decimateCSSERC :: DecimateRC
decimateCSSERC = decimateFFIC decimateSSERC_c
foreign import ccall unsafe "decimateSSERC2"
    decimateSSERC2_c :: CInt -> CInt -> CInt -> Ptr CFloat -> Ptr CFloat -> Ptr CFloat -> IO ()
decimateCSSERC2 :: DecimateRC
decimateCSSERC2 = decimateFFIC decimateSSERC2_c
foreign import ccall unsafe "decimateAVXRR"
    decimateAVXRR_c :: CInt -> CInt -> CInt -> Ptr CFloat -> Ptr CFloat -> Ptr CFloat -> IO ()
decimateCAVXRR :: DecimateRR
decimateCAVXRR = decimateFFIR decimateAVXRR_c
foreign import ccall unsafe "decimateAVXRC"
    decimateAVXRC_c :: CInt -> CInt -> CInt -> Ptr CFloat -> Ptr CFloat -> Ptr CFloat -> IO ()
decimateCAVXRC :: DecimateRC
decimateCAVXRC = decimateFFIC decimateAVXRC_c
foreign import ccall unsafe "decimateAVXRC2"
    decimateAVXRC2_c :: CInt -> CInt -> CInt -> Ptr CFloat -> Ptr CFloat -> Ptr CFloat -> IO ()
decimateCAVXRC2 :: DecimateRC
decimateCAVXRC2 = decimateFFIC decimateAVXRC2_c
foreign import ccall unsafe "decimateSSESymmetricRR"
    decimateSSESymmetricRR_c :: CInt -> CInt -> CInt -> Ptr CFloat -> Ptr CFloat -> Ptr CFloat -> IO ()
decimateCSSESymmetricRR :: DecimateRR
decimateCSSESymmetricRR = decimateFFIR decimateSSESymmetricRR_c
foreign import ccall unsafe "decimateAVXSymmetricRR"
    decimateAVXSymmetricRR_c :: CInt -> CInt -> CInt -> Ptr CFloat -> Ptr CFloat -> Ptr CFloat -> IO ()
decimateCAVXSymmetricRR :: DecimateRR
decimateCAVXSymmetricRR = decimateFFIR decimateAVXSymmetricRR_c
foreign import ccall unsafe "decimateSSESymmetricRC"
    decimateSSESymmetricRC_c :: CInt -> CInt -> CInt -> Ptr CFloat -> Ptr CFloat -> Ptr CFloat -> IO ()
decimateCSSESymmetricRC :: DecimateRC
decimateCSSESymmetricRC = decimateFFIC decimateSSESymmetricRC_c
foreign import ccall unsafe "decimateAVXSymmetricRC"
    decimateAVXSymmetricRC_c :: CInt -> CInt -> CInt -> Ptr CFloat -> Ptr CFloat -> Ptr CFloat -> IO ()
decimateCAVXSymmetricRC :: DecimateRC
decimateCAVXSymmetricRC = decimateFFIC decimateAVXSymmetricRC_c
resampleHighLevel :: (PrimMonad m, Num a, Mult a b, VG.Vector v a, VG.Vector v b, VGM.MVector vm a) => Int -> Int -> v b -> Int -> Int -> v a -> vm (PrimState m) a -> m Int
resampleHighLevel interpolation decimation coeffs filterOffset count inBuf outBuf = fill 0 filterOffset 0
    where
    fill i filterOffset inputOffset
        | i < count = do
            let dp = dotProd filterOffset inputOffset
            VGM.unsafeWrite outBuf i dp
            let (q, r)        = divMod (decimation  filterOffset  1) interpolation
                inputOffset'  = inputOffset + q + 1
                filterOffset' = interpolation  1  r
            filterOffset' `seq` inputOffset' `seq` fill (i + 1) filterOffset' inputOffset'
        | otherwise = return filterOffset
    dotProd filterOffset offset = VG.sum $ VG.zipWith mult (VG.unsafeDrop offset inBuf) (stride interpolation (VG.unsafeDrop filterOffset coeffs))
foreign import ccall unsafe "resampleRR"
    resample_c :: CInt -> CInt -> CInt -> CInt -> CInt -> Ptr CFloat -> Ptr CFloat -> Ptr CFloat -> IO ()
resampleCRR :: Int -> Int -> Int -> Int -> VS.Vector Float -> VS.Vector Float -> VS.MVector RealWorld Float -> IO ()
resampleCRR num interpolation decimation offset coeffs inBuf outBuf = 
    VS.unsafeWith (coerce coeffs) $ \cPtr -> 
        VS.unsafeWith (coerce inBuf) $ \iPtr -> 
            VSM.unsafeWith (coerce outBuf) $ \oPtr -> 
                resample_c (fromIntegral num) (fromIntegral $ VG.length coeffs) (fromIntegral interpolation) (fromIntegral decimation) (fromIntegral offset) cPtr iPtr oPtr
pad :: a -> Int -> [a] -> [a]
pad with num list = list ++ replicate (num  length list) with 
strideList :: Int -> [a] -> [a]
strideList s xs = go 0 xs
    where
    go _ []     = []
    go 0 (x:xs) = x : go (s1) xs
    go n (x:xs) = go (n  1) xs
roundUp :: Int -> Int -> Int
roundUp num div = ((num + div  1) `quot` div) * div
data Coeffs = Coeffs {
    numCoeffs  :: Int,
    numGroups  :: Int,
    increments :: [Int],
    groups     :: [[Float]]
}
prepareCoeffs :: Int -> Int -> Int -> [Float] -> Coeffs
prepareCoeffs n interpolation decimation coeffs = Coeffs {..}
    where
    numCoeffs   = maximum $ map (length . snd) dats
    numGroups   = length groups
    increments  = map fst dats
    groups      :: [[Float]]
    groups      = map (pad 0 (roundUp numCoeffs n)) $ map snd dats
    dats :: [(Int, [Float])]
    dats = func 0
        where
        func' 0      = []
        func' x      = func x
        func :: Int -> [(Int, [Float])]
        func offset = (increment, strideList interpolation $ drop offset coeffs) : func' offset'
            where
            (q, r)    = divMod (decimation  offset  1) interpolation
            increment = q + 1
            offset'   = interpolation  1  r
resampleFFIR :: (Ptr CFloat -> Ptr CFloat -> IO CInt) -> VS.Vector Float -> VSM.MVector RealWorld Float -> IO Int
resampleFFIR func inBuf outBuf = liftM fromIntegral $
    VS.unsafeWith (coerce inBuf) $ \iPtr -> 
        VSM.unsafeWith (coerce outBuf) $ \oPtr -> 
            func iPtr oPtr
resampleFFIC :: (Ptr CFloat -> Ptr CFloat -> IO CInt) -> VS.Vector (Complex Float) -> VSM.MVector RealWorld (Complex Float) -> IO Int
resampleFFIC func inBuf outBuf = liftM fromIntegral $
    VS.unsafeWith (coerce inBuf) $ \iPtr -> 
        VSM.unsafeWith (coerce outBuf) $ \oPtr -> 
            func iPtr oPtr
type ResampleR = CInt -> CInt -> CInt -> CInt -> Ptr CInt -> Ptr (Ptr CFloat) -> Ptr CFloat -> Ptr CFloat -> IO CInt
mkResampler :: ResampleR -> Int -> Int -> Int -> [Float] -> IO (Int -> Int -> VS.Vector Float -> VS.MVector RealWorld Float -> IO Int)
mkResampler func n interpolation decimation coeffs = do
    groupsP     <- mapM newArray $ map (map realToFrac) groups
    groupsPP    <- newArray groupsP
    incrementsP <- newArray $ map fromIntegral increments
    return $ \offset num -> resampleFFIR $ func (fromIntegral num) (fromIntegral numCoeffs) (fromIntegral offset) (fromIntegral numGroups) incrementsP groupsPP
    where
    Coeffs {..} = prepareCoeffs n interpolation decimation coeffs
type ResampleRR = Int -> Int -> [Float] -> IO (Int -> Int -> VS.Vector Float -> VS.MVector RealWorld Float -> IO Int)
foreign import ccall unsafe "resample2RR"
    resample2_c :: CInt -> CInt -> CInt -> CInt -> Ptr CInt -> Ptr (Ptr CFloat) -> Ptr CFloat -> Ptr CFloat -> IO CInt
resampleCRR2 :: ResampleRR
resampleCRR2 = mkResampler resample2_c 1
foreign import ccall unsafe "resampleSSERR"
    resampleCSSERR_c :: CInt -> CInt -> CInt -> CInt -> Ptr CInt -> Ptr (Ptr CFloat) -> Ptr CFloat -> Ptr CFloat -> IO CInt
resampleCSSERR :: ResampleRR
resampleCSSERR = mkResampler resampleCSSERR_c 4
foreign import ccall unsafe "resampleAVXRR"
    resampleAVXRR_c :: CInt -> CInt -> CInt -> CInt -> Ptr CInt -> Ptr (Ptr CFloat) -> Ptr CFloat -> Ptr CFloat -> IO CInt
resampleCAVXRR :: ResampleRR
resampleCAVXRR = mkResampler resampleAVXRR_c 8
type ResampleRC = Int -> Int -> [Float] -> IO (Int -> Int -> VS.Vector (Complex Float) -> VS.MVector RealWorld (Complex Float) -> IO Int)
mkResamplerC :: ResampleR -> Int -> Int -> Int -> [Float] -> IO (Int -> Int -> VS.Vector (Complex Float) -> VS.MVector RealWorld (Complex Float) -> IO Int)
mkResamplerC func n interpolation decimation coeffs = do
    groupsP     <- mapM newArray $ map (map realToFrac) groups
    groupsPP    <- newArray groupsP
    incrementsP <- newArray $ map fromIntegral increments
    return $ \offset num -> resampleFFIC $ func (fromIntegral num) (fromIntegral numCoeffs) (fromIntegral offset) (fromIntegral numGroups) incrementsP groupsPP
    where
    Coeffs {..} = prepareCoeffs n interpolation decimation coeffs
foreign import ccall unsafe "resample2RC"
    resample2RC_c :: CInt -> CInt -> CInt -> CInt -> Ptr CInt -> Ptr (Ptr CFloat) -> Ptr CFloat -> Ptr CFloat -> IO CInt
resampleCRC :: ResampleRC
resampleCRC = mkResamplerC resample2RC_c 1
foreign import ccall unsafe "resampleSSERC"
    resampleCSSERC_c :: CInt -> CInt -> CInt -> CInt -> Ptr CInt -> Ptr (Ptr CFloat) -> Ptr CFloat -> Ptr CFloat -> IO CInt
resampleCSSERC :: ResampleRC
resampleCSSERC = mkResamplerC resampleCSSERC_c 4
foreign import ccall unsafe "resampleAVXRC"
    resampleAVXRC_c :: CInt -> CInt -> CInt -> CInt -> Ptr CInt -> Ptr (Ptr CFloat) -> Ptr CFloat -> Ptr CFloat -> IO CInt
resampleCAVXRC :: ResampleRC
resampleCAVXRC = mkResamplerC resampleAVXRC_c 8
decimateCrossHighLevel :: (PrimMonad m, Functor m, Num a, Mult a b, VG.Vector v a, VG.Vector v b, VGM.MVector vm a) => Int -> v b -> Int -> v a -> v a -> vm (PrimState m) a -> m ()
decimateCrossHighLevel factor coeffs num lastBuf nextBuf outBuf = fill x outBuf
    where
    x = VFB.map dotProd (VFB.iterateN num (+ factor) 0)
    dotProd i = VG.sum $ VG.zipWith mult (VG.unsafeDrop i lastBuf VG.++ nextBuf) coeffs
filterCrossHighLevel :: (PrimMonad m, Functor m, Num a, Mult a b, VG.Vector v a, VG.Vector v b, VGM.MVector vm a) => v b -> Int -> v a -> v a -> vm (PrimState m) a -> m ()
filterCrossHighLevel coeffs num lastBuf nextBuf outBuf = fill (VFB.generate num dotProd) outBuf
    where
    dotProd i = VG.sum $ VG.zipWith mult (VG.unsafeDrop i lastBuf VG.++ nextBuf) coeffs
resampleCrossHighLevel :: (PrimMonad m, Num a, Mult a b, VG.Vector v a, VG.Vector v b, VGM.MVector vm a) => Int -> Int -> v b -> Int -> Int -> v a -> v a -> vm (PrimState m) a -> m Int
resampleCrossHighLevel interpolation decimation coeffs filterOffset count lastBuf nextBuf outBuf = fill 0 filterOffset 0
    where
    fill i filterOffset inputOffset
        | i < count = do
            let dp = dotProd filterOffset inputOffset
            VGM.unsafeWrite outBuf i dp
            let (q, r)        = divMod (decimation  filterOffset  1) interpolation
                inputOffset'  = inputOffset + q + 1
                filterOffset' = interpolation  1  r
            filterOffset' `seq` inputOffset' `seq` fill (i + 1) filterOffset' inputOffset'
        | otherwise = return filterOffset
    dotProd filterOffset i = VG.sum $ VG.zipWith mult (VG.unsafeDrop i lastBuf VG.++ nextBuf) (stride interpolation (VG.unsafeDrop filterOffset coeffs))
foreign import ccall unsafe "dcBlocker"
    c_dcBlocker :: CInt -> CFloat -> CFloat -> Ptr CFloat -> Ptr CFloat -> Ptr CFloat -> Ptr CFloat -> IO ()
dcBlocker :: Int -> Float -> Float -> VS.Vector Float -> VS.MVector RealWorld Float -> IO (Float, Float)
dcBlocker num lastSample lastOutput inBuf outBuf = 
    alloca $ \fsp -> 
        alloca $ \fop -> 
            VS.unsafeWith (coerce inBuf) $ \iPtr -> 
                VSM.unsafeWith (coerce outBuf) $ \oPtr -> do
                    c_dcBlocker (fromIntegral num) (realToFrac lastSample) (realToFrac lastOutput) fsp fop iPtr oPtr
                    r1 <- peek fsp
                    r2 <- peek fop
                    return (realToFrac r1, realToFrac r2)