{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
module Synthesizer.LLVM.Filter.NonRecursive (
   convolve,
   convolvePacked,
   ) where

import qualified Synthesizer.LLVM.Causal.Process as Causal
import qualified Synthesizer.LLVM.Causal.Private as CausalPriv
import qualified Synthesizer.LLVM.Generator.Source as Source
import qualified Synthesizer.LLVM.Generator.Signal as Sig
import qualified Synthesizer.LLVM.RingBuffer as RingBuffer
import qualified Synthesizer.LLVM.Frame.SerialVector.Code as Serial

import qualified Synthesizer.Causal.Class as CausalClass
import Synthesizer.Causal.Class (($<))

import qualified LLVM.DSL.Expression as Expr
import LLVM.DSL.Expression (Exp)

import qualified LLVM.Extra.Multi.Value.Storable as Storable
import qualified LLVM.Extra.Multi.Value.Marshal as Marshal
import qualified LLVM.Extra.Multi.Value as MultiValue
import qualified LLVM.Extra.Multi.Vector as MultiVector
import qualified LLVM.Extra.Control as C
import qualified LLVM.Extra.Arithmetic as A
import qualified LLVM.Extra.Tuple as Tuple

import qualified LLVM.Core as LLVM

import qualified Type.Data.Num.Decimal as TypeNum

import Foreign.Ptr (Ptr)
import Data.Word (Word)

import Control.Arrow ((<<<), (&&&))
import Control.Monad (liftM2)

import NumericPrelude.Numeric
import NumericPrelude.Base
import Prelude ()


{-
This is a brute-force implementation.
No Karatsuba, No Toom-Cook, No Fourier.
-}
convolve ::
   (Storable.C a, Marshal.C a, MultiValue.PseudoRing a, MultiValue.T a ~ am) =>
   Exp (Source.StorableVector a) -> Causal.T am am
convolve mask =
   let len = Source.storableVectorLength mask
   in (CausalPriv.zipWith (\(MultiValue.Cons l) -> scalarProduct l)
         $< Sig.constant len)
      <<<
      Causal.track Expr.zero len &&& provideMask mask

convolvePacked ::
   (Marshal.Vector n a, MultiVector.PseudoRing a) =>
   (Storable.C a, MultiValue.PseudoRing a, Serial.Value n a ~ v) =>
   Exp (Source.StorableVector a) -> Causal.T v v
convolvePacked = convolvePackedAux TypeNum.singleton

convolvePackedAux ::
   (Marshal.Vector n a, MultiVector.PseudoRing a) =>
   (Storable.C a, MultiValue.PseudoRing a, Serial.Value n a ~ v) =>
   TypeNum.Singleton n -> Exp (Source.StorableVector a) -> Causal.T v v
convolvePackedAux vectorSize mask =
   let len = Source.storableVectorLength mask
   in (CausalPriv.zipWith (\(MultiValue.Cons l) -> scalarProductPacked l)
         $< Sig.constant len)
      <<<
      Causal.track Expr.zero
         (divUp (TypeNum.integralFromSingleton vectorSize) len)
      &&&
      provideMask mask

divUp :: Exp Word -> Exp Word -> Exp Word
divUp k n = Expr.idiv (n+(k-1)) k

provideMask ::
   (Storable.C a) =>
   Exp (Source.StorableVector a) -> Causal.T x (LLVM.Value (Ptr a))
provideMask mask =
   CausalClass.fromSignal $
   fmap (\(MultiValue.Cons (ptr,_l)) -> ptr) $
   Sig.constant mask


scalarProduct ::
   (Storable.C a, Marshal.C a, MultiValue.T a ~ am, MultiValue.PseudoRing a) =>
   LLVM.Value Word ->
   (RingBuffer.T am, LLVM.Value (Ptr a)) ->
   LLVM.CodeGenFunction r am
scalarProduct n (rb,mask) =
   fmap snd $
   Storable.arrayLoop n mask (A.zero, A.zero) $ \ptr (k, s) -> do
      a <- RingBuffer.index k rb
      b <- Storable.load ptr
      liftM2 (,) (A.inc k) (A.add s =<< A.mul a b)


scalarProductPacked ::
   (Storable.C a, Marshal.Vector n a, MultiVector.PseudoRing a) =>
   LLVM.Value Word ->
   (RingBuffer.T (Serial.Value n a), LLVM.Value (Ptr a)) ->
   LLVM.CodeGenFunction r (Serial.Value n a)
scalarProductPacked n0 (rb,mask0) = do
   (ax, rx) <- readSerialStart rb
   bx <- Storable.load mask0
   sx <- Serial.scale bx ax
   n1 <- A.dec n0
   mask1 <- Storable.incrementPtr mask0
   fmap snd $ Storable.arrayLoop n1 mask1 (rx, sx) $ \ptr (r1, s1) -> do
      (a,r2) <- readSerialNext rb r1
      b <- Storable.load ptr
      fmap ((,) r2) (A.add s1 =<< Serial.scale b a)


type
   Iterator n a =
      ((Serial.Value n a,
        {-
        I would like to use Serial.Iterator,
        but we need to read in reversed order,
        that is, from high to low indices.
        -}
        Serial.Value n a,
        LLVM.Value Word),
       LLVM.Value Word)

readSerialStart ::
   (TypeNum.Positive n, Marshal.Vector n a) =>
   RingBuffer.T (Serial.Value n a) ->
   LLVM.CodeGenFunction r (Serial.Value n a, Iterator n a)
readSerialStart rb = do
   a <- RingBuffer.index A.zero rb
   return (a, ((a, Tuple.undef, A.zero), A.zero))

readSerialNext ::
   (MultiValue.C a, Marshal.Vector n a) =>
   RingBuffer.T (Serial.Value n a) ->
   Iterator n a ->
   LLVM.CodeGenFunction r (Serial.Value n a, Iterator n a)
readSerialNext rb ((a0,r0,j0), k0) = do
   vectorEnd <- A.cmp LLVM.CmpEQ j0 A.zero
   ((r1,j1), k1) <-
      C.ifThen vectorEnd ((r0,j0), k0) $ do
         k <- A.inc k0
         r <- RingBuffer.index k rb
         return ((r, LLVM.valueOf (Serial.size r :: Word)), k)
   j2 <- A.dec j1
   (ai,r2) <- Serial.shiftUp Tuple.undef r1
   (_, a1) <- Serial.shiftUp ai a0
   return (a1, ((a1,r2,j2), k1))