{-# LANGUAGE NoImplicitPrelude #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE FlexibleContexts #-}
module Synthesizer.LLVM.Filter.NonRecursive (
   convolve,
   convolvePacked,
   ) where

import qualified Synthesizer.LLVM.CausalParameterized.ProcessPrivate as CausalP
import qualified Synthesizer.LLVM.CausalParameterized.RingBuffer as RingBuffer
import qualified Synthesizer.LLVM.Frame.SerialVector as Serial

import qualified Synthesizer.LLVM.Storable.Vector as SVU
import qualified Data.StorableVector as SV

import qualified LLVM.DSL.Parameter as Param

import qualified LLVM.Extra.Storable as Storable
import qualified LLVM.Extra.Memory as Memory
import qualified LLVM.Extra.Control as C
import qualified LLVM.Extra.Arithmetic as A
import qualified LLVM.Extra.Tuple as Tuple

import qualified LLVM.Core as LLVM
import LLVM.Core (Value, valueOf, CodeGenFunction, IsSized, SizeOf)

import qualified Type.Data.Num.Decimal as TypeNum
import Type.Data.Num.Decimal.Number ((:*:))

import Foreign.ForeignPtr (touchForeignPtr)
import Foreign.Ptr (Ptr)
import Data.Word (Word)

import Control.Arrow ((<<<), (&&&))
import Control.Monad (liftM2)

import qualified Algebra.IntegralDomain as Integral

import NumericPrelude.Numeric
import NumericPrelude.Base



{-
This is a brute-force implementation:
no Karatsuba, no Toom-Cook, no Fourier transform.
-}
{- |
Convolve a signal with a mask that is supplied as a parameter.

Each output value is computed by 'scalarProduct'
from the mask and a ring buffer provided by 'RingBuffer.trackConst'.
The ring buffer is set up with 'A.zero' as initial value
and the mask length as size.
-}
convolve ::
   (Storable.C a, Tuple.ValueOf a ~ al, Memory.C al, A.PseudoRing al) =>
   Param.T p (SV.Vector a) -> CausalP.T p al al
convolve mask =
   CausalP.zipWith scalarProduct maskLength
   <<<
   (history &&& provideMask mask)
   where
      -- mask length as Haskell-side parameter
      maskSize = fmap SV.length mask
      -- the same length in the type expected by 'scalarProduct'
      maskLength = fmap (fromIntegral :: Int -> Word) maskSize
      history = RingBuffer.trackConst A.zero maskSize

{- |
Variant of 'convolve' that works on 'Serial.Value' chunks
and combines them with a mask of scalar elements
by means of 'scalarProductPacked'.

The size of the ring buffer is computed with 'Integral.divUp'
from the mask length and the size passed in by 'Serial.withSize'
(presumably the number of elements per chunk).
-}
convolvePacked ::
   (TypeNum.Positive n, TypeNum.Positive (n :*: asize),
    Storable.C a, Tuple.ValueOf a ~ Value al,
    LLVM.IsArithmetic al, LLVM.IsPrimitive al, IsSized al, SizeOf al ~ asize) =>
   Param.T p (SV.Vector a) ->
   CausalP.T p (Serial.Value n al) (Serial.Value n al)
convolvePacked mask =
   Serial.withSize $ \vectorSize ->
      let maskSize = fmap SV.length mask
          maskLength = fmap (fromIntegral :: Int -> Word) maskSize
          numChunks = fmap (\n -> Integral.divUp n vectorSize) maskSize
      in  CausalP.zipWith scalarProductPacked maskLength
          <<<
          (RingBuffer.trackConst A.zero numChunks &&& provideMask mask)

{- |
Turn the mask vector into a process
that ignores its input and emits the data pointer of the mask.

NOTE(review): the roles of the individual 'CausalP.Cons' fields
are defined outside this module;
the remarks below only describe what the given functions do.
-}
provideMask ::
   (Storable.C a) => Param.T p (SV.Vector a) -> CausalP.T p x (Value (Ptr a))
provideMask mask =
   CausalP.Cons
      -- pass on the first argument as output, ignore the input value
      (\maskPtr () _input () -> return (maskPtr, ()))
      (return ())
      return
      (\_ _ -> return ())
      -- obtain foreign pointer and data pointer of the current mask;
      -- the lazy pattern matches the laziness of a let binding
      (\p ->
         case SVU.unsafeToPointers (Param.get mask p) of
            ~(fp, ptr, _len) -> return (fp, (ptr, ())))
      -- keep the foreign ptr alive
      touchForeignPtr


{- |
Generate code for the scalar product
of the ring buffer content with the mask.

'Storable.arrayLoop' is run with count @n@ and start pointer @mask@.
The loop state is a pair of ring buffer index and running sum,
both starting at zero.
In every iteration the ring buffer element at the current index
is multiplied with the value loaded from the current pointer
and the product is added to the sum.
Only the sum is returned.
-}
scalarProduct ::
   (Storable.C a, Tuple.ValueOf a ~ al, Memory.C al, A.PseudoRing al) =>
   Value Word ->
   RingBuffer.T al -> Value (Ptr a) ->
   CodeGenFunction r al
scalarProduct n rb mask = do
   loopResult <-
      Storable.arrayLoop n mask (A.zero, A.zero) $ \ptr (pos, acc) -> do
         x <- RingBuffer.index pos rb
         coeff <- Storable.load ptr
         -- advance the index before the arithmetic,
         -- in order to emit the instructions in the established order
         posNext <- A.inc pos
         term <- A.mul x coeff
         accNext <- A.add acc term
         return (posNext, accNext)
   return (snd loopResult)

{- |
'scalarProduct' restricted to element types @a@
that are represented by @Value a@ in LLVM code.
It is neither exported nor used in the visible part of this module;
presumably it only serves as a check
that 'scalarProduct' can be instantiated this way.
-}
_scalarProduct ::
   (Storable.C a, IsSized a,
    Tuple.ValueOf a ~ Value a, LLVM.IsArithmetic a) =>
   Value Word ->
   RingBuffer.T (Value a) -> Value (Ptr a) ->
   CodeGenFunction r (Value a)
_scalarProduct n rb mask = scalarProduct n rb mask


{- |
Generate code for the scalar product of a mask of scalar elements
with serial vectors read via 'readSerialStart' and 'readSerialNext'.

The first mask element is processed before the loop:
It scales the vector delivered by 'readSerialStart'.
The loop then runs with count @n0 - 1@
starting at the incremented mask pointer.
In every iteration the vector delivered by 'readSerialNext'
is scaled by the loaded mask element and added to the running sum.

NOTE(review): The first mask element is loaded unconditionally
and the loop count is obtained by decrementing @n0@.
Thus the function seems to rely on a non-empty mask;
for @n0 == 0@ the decrement presumably wraps around - confirm with callers.
-}
scalarProductPacked ::
   (Storable.C a,
    Tuple.ValueOf a ~ Value al, LLVM.IsArithmetic al,
    LLVM.IsPrimitive al, IsSized al, SizeOf al ~ asize,
    TypeNum.Positive n, TypeNum.Positive (n :*: asize)) =>
   Value Word ->
   RingBuffer.T (Serial.Value n al) -> Value (Ptr a) ->
   CodeGenFunction r (Serial.Value n al)
scalarProductPacked n0 rb mask0 = do
   -- first mask element, handled outside the loop
   (chunk0, iter0) <- readSerialStart rb
   coeff0 <- Storable.load mask0
   acc0 <- A.scale coeff0 chunk0
   -- remaining mask elements
   restCount <- A.dec n0
   restMask <- Storable.incrementPtr mask0
   loopResult <-
      Storable.arrayLoop restCount restMask (iter0, acc0) $
      \ptr (iter, acc) -> do
         (chunk, iterNext) <- readSerialNext rb iter
         coeff <- Storable.load ptr
         term <- A.scale coeff chunk
         accNext <- A.add acc term
         return (iterNext, accNext)
   return (snd loopResult)


{- |
State that is threaded through 'readSerialStart' and 'readSerialNext'.
It has the shape @((window, rest, count), index)@.
Judging from 'readSerialNext' the components are used as follows:

* @window@ is the vector that was returned most recently,
* @rest@ is the vector from which the next element is taken,
* @count@ is compared with zero in order to decide
  whether a new vector must be fetched from the ring buffer,
* @index@ is the ring buffer index of the vector fetched most recently.
-}
type
   Iterator n a =
      ((Serial.Value n a,
        {-
        I would like to use Serial.Iterator,
        but we need to read in reversed order,
        that is, from high to low indices.
        -}
        Serial.Value n a,
        Value Word),
       Value Word)

{- |
Fetch the ring buffer element at index zero
and return it together with the initial 'Iterator'.

In the initial iterator the fetched vector is the current window,
the vector of remaining elements is undefined,
and both the element count and the ring buffer index are zero.
A zero element count makes 'readSerialNext' fetch a new vector
before the undefined one is used.
-}
readSerialStart ::
   (LLVM.IsPrimitive a, IsSized a, SizeOf a ~ asize,
    TypeNum.Positive n, TypeNum.Positive (n :*: asize)) =>
   RingBuffer.T (Serial.Value n a) ->
   CodeGenFunction r (Serial.Value n a, Iterator n a)
readSerialStart rb =
   fmap
      (\newest -> (newest, ((newest, Tuple.undef, A.zero), A.zero)))
      (RingBuffer.index A.zero rb)

{- |
Advance the 'Iterator' by one element
and return the updated window vector along with the new iterator.

If the element count of the iterator is zero,
then the ring buffer index is incremented,
the vector at that index becomes the new source of elements
and the count is set to the size of that vector.
Otherwise the iterator components are kept.
After that the count is decremented,
one element is taken from the source vector with 'Serial.shiftUp'
and inserted into the previous window with another 'Serial.shiftUp'.

NOTE(review): 'Serial.shiftUp' presumably inserts its first argument
at the lowest index and returns the element dropped at the highest index,
which would match the remark about reading from high to low indices
in the definition of 'Iterator' - confirm against the Serial module.
-}
readSerialNext ::
   (LLVM.IsPrimitive a, IsSized a, SizeOf a ~ asize,
    TypeNum.Positive n, TypeNum.Positive (n :*: asize)) =>
   RingBuffer.T (Serial.Value n a) ->
   Iterator n a ->
   CodeGenFunction r (Serial.Value n a, Iterator n a)
readSerialNext rb ((prevWindow, rest0, left0), pos0) = do
   exhausted <- A.cmp LLVM.CmpEQ left0 A.zero
   ((rest1, left1), pos1) <-
      C.ifThen exhausted ((rest0, left0), pos0) $ do
         pos <- A.inc pos0
         chunk <- RingBuffer.index pos rb
         let chunkSize = fromIntegral (Serial.size chunk) :: Word
         return ((chunk, valueOf chunkSize), pos)
   left2 <- A.dec left1
   -- the value shifted into the source vector is never read
   (x, rest2) <- Serial.shiftUp Tuple.undef rest1
   -- the element dropped from the window is not needed anymore
   (_, window) <- Serial.shiftUp x prevWindow
   return (window, ((window, rest2, left2), pos1))