{-# LANGUAGE NoImplicitPrelude #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE FlexibleContexts #-}
module Synthesizer.LLVM.Filter.NonRecursive (
   convolve,
   convolvePacked,
   ) where

import qualified Synthesizer.LLVM.CausalParameterized.ProcessPrivate as CausalP
import qualified Synthesizer.LLVM.CausalParameterized.RingBuffer as RingBuffer
import qualified Synthesizer.LLVM.Parameter as Param
import qualified Synthesizer.LLVM.Frame.SerialVector as Serial
import qualified Synthesizer.LLVM.Storable.Vector as SVU

import qualified Data.StorableVector as SV

import qualified LLVM.Extra.Memory as Memory
import qualified LLVM.Extra.Control as C
import qualified LLVM.Extra.Arithmetic as A
import qualified LLVM.Extra.Class as Class
import LLVM.Extra.Class (undefTuple, )

import qualified LLVM.Core as LLVM
import LLVM.Core (Value, valueOf, CodeGenFunction, IsSized, SizeOf, )

import qualified Type.Data.Num.Decimal as TypeNum
import Type.Data.Num.Decimal.Number ((:*:), )

import Foreign.ForeignPtr (touchForeignPtr, )
import Foreign.Storable (Storable, )
import Foreign.Ptr (Ptr, )
import Data.Word (Word32, )

import Control.Arrow ((<<<), (&&&), )
import Control.Monad (liftM2, )

import qualified Algebra.IntegralDomain as Integral

import NumericPrelude.Numeric
import NumericPrelude.Base


{-
This is a brute-force implementation:
no Karatsuba, no Toom-Cook, no Fourier transform.
Every output sample costs one multiply-accumulate per mask element.
-}

{- |
Convolve the input signal with a finite mask.

The input is tracked by a ring buffer whose size is the mask length
and which is initially filled with zeros.
For every input sample the output is the sum of
@RingBuffer.index k buffer * mask[k]@ over all mask positions @k@
(see 'scalarProduct').
Assuming ring buffer index @k@ denotes the input from @k@ steps ago
(TODO confirm against "Synthesizer.LLVM.CausalParameterized.RingBuffer"),
this is the FIR filter @y[t] = sum_k mask[k] * x[t-k]@.
An empty mask yields zero output.
-}
convolve ::
   (Storable a, Class.MakeValueTuple a, Class.ValueTuple a ~ al,
    Memory.C al, A.PseudoRing al) =>
   Param.T p (SV.Vector a) -> CausalP.T p al al
convolve mask =
   let len = fmap SV.length mask
   in  CausalP.zipWith scalarProduct (fmap (fromIntegral :: Int -> Word32) len)
         <<<
         (RingBuffer.trackConst A.zero len &&& provideMask mask)

{- |
Like 'convolve', but processes 'Serial.Value' vectors of @n@ samples per step.

The ring buffer stores whole vectors;
its size is @divUp (length mask) n@ vectors.
The per-step work is done by 'scalarProductPacked'.
An empty mask yields zero output.
-}
convolvePacked ::
   (LLVM.IsPrimitive a, Memory.FirstClass a,
    Memory.Stored a ~ am, LLVM.IsPrimitive am, IsSized am,
    SizeOf am ~ amsize,
    TypeNum.Positive n, TypeNum.Positive (n :*: amsize),
    Class.MakeValueTuple a, Class.ValueTuple a ~ al,
    Memory.Struct al ~ am,
    Storable a, Memory.C al, LLVM.IsArithmetic a) =>
   Param.T p (SV.Vector a) ->
   CausalP.T p (Serial.Value n a) (Serial.Value n a)
convolvePacked mask =
   Serial.withSize $ \vectorSize ->
      let len = fmap SV.length mask
      in  CausalP.zipWith scalarProductPacked
             (fmap (fromIntegral :: Int -> Word32) len)
            <<<
            (RingBuffer.trackConst A.zero
                (fmap (flip Integral.divUp vectorSize) len)
             &&& provideMask mask)

{- |
Output a pointer to the mask data on every step, ignoring the input.

The pointer is taken from the parameter's storable vector via
'SVU.unsafeToPointers'. The accompanying foreign pointer is passed to
'touchForeignPtr' in the last 'CausalP.Cons' field, so that the vector
is kept alive while the generated code may still read from the pointer.
-}
provideMask ::
   (Storable a, Class.MakeValueTuple a, Class.ValueTuple a ~ al,
    Memory.C al, Memory.Struct al ~ am) =>
   Param.T p (SV.Vector a) -> CausalP.T p x (Value (Ptr am))
provideMask mask =
   CausalP.Cons
      (\p () _x () -> return (p, ()))
      (return ())
      return
      (const $ const $ return ())
      (\p ->
         let (fp, ptr, _l) = SVU.unsafeToPointers $ Param.get mask p
         in  return (fp, (ptr, ())))
      -- keep the foreign ptr alive
      touchForeignPtr

{- |
@scalarProduct n rb mask@ computes the sum of
@RingBuffer.index k rb * mask[k]@ for @k = 0 .. n-1@,
where @mask[k]@ is the @k@-th element behind the pointer @mask@.
For @n = 0@ the loop body never runs and the result is zero.
-}
scalarProduct ::
   (Memory.C a, Memory.Struct a ~ am, A.PseudoRing a) =>
   Value Word32 -> RingBuffer.T a -> Value (Ptr am) -> CodeGenFunction r a
scalarProduct n rb mask =
   fmap snd $
   -- loop state: (ring buffer index k, running sum s)
   C.arrayLoop n mask (A.zero, A.zero) $ \ptr (k, s) -> do
      a <- RingBuffer.index k rb
      b <- Memory.load ptr
      liftM2 (,) (A.inc k) (A.add s =<< A.mul a b)

-- Unused (hence the leading underscore): only checks that 'scalarProduct'
-- can be instantiated with plain LLVM 'Value' elements.
_scalarProduct ::
   (Memory.FirstClass a, Memory.Stored a ~ am, IsSized am,
    LLVM.IsArithmetic a) =>
   Value Word32 -> RingBuffer.T (Value a) -> Value (Ptr am) ->
   CodeGenFunction r (Value a)
_scalarProduct = scalarProduct

{- |
Packed counterpart of 'scalarProduct'.

@scalarProductPacked n rb mask@ scales the ring buffer vector at index 0
by @mask[0]@. For each further mask element @mask[k]@ (@k = 1 .. n-1@),
it advances the window by one sample with 'readSerialNext'
and accumulates @mask[k] * window@.
Returns zero for @n = 0@, like 'scalarProduct'.
-}
scalarProductPacked ::
   (LLVM.IsPrimitive a, Memory.FirstClass a,
    Memory.Stored a ~ am, LLVM.IsPrimitive am, IsSized am,
    SizeOf am ~ amsize,
    TypeNum.Positive n, TypeNum.Positive (n :*: amsize),
    LLVM.IsArithmetic a) =>
   Value Word32 -> RingBuffer.T (Serial.Value n a) -> Value (Ptr am) ->
   CodeGenFunction r (Serial.Value n a)
scalarProductPacked n0 rb mask0 = do
   -- Without this guard an empty mask makes 'A.dec' below wrap the
   -- Word32 count around to maxBound; the loop would then read far
   -- beyond the end of the mask (and mask[0] would already be invalid).
   nonEmpty <- A.cmp LLVM.CmpNE n0 A.zero
   C.ifThen nonEmpty A.zero $ do
      (ax, rx) <- readSerialStart rb
      bx <- Memory.load mask0
      sx <- A.scale bx ax
      -- mask[0] is consumed above, the loop handles the remaining n0-1
      n1 <- A.dec n0
      mask1 <- A.advanceArrayElementPtr mask0
      fmap snd $
         C.arrayLoop n1 mask1 (rx, sx) $ \ptr (r1, s1) -> do
            (a, r2) <- readSerialNext rb r1
            b <- Memory.load ptr
            fmap ((,) r2) (A.add s1 =<< A.scale b a)

{- |
Iteration state for walking back through the ring buffer sample by sample:
@((window, rest, restCount), ringIndex)@.

* @window@: the vector returned by the most recent step,

* @rest@: samples of the ring buffer vector at @ringIndex@
  that have not yet been shifted into @window@,

* @restCount@: number of valid samples left in @rest@,

* @ringIndex@: ring buffer index of the vector held in @rest@.
-}
type Iterator n a =
   ((Serial.Value n a,
     {-
     I would like to use Serial.Iterator,
     but we need to read in reversed order,
     that is, from high to low indices.
     -}
     Serial.Value n a,
     Value Word32),
    Value Word32)

{- |
Start an iteration: the first window is the ring buffer vector at index 0.
@rest@ is undefined, @restCount@ is zero and @ringIndex@ is zero,
so the first 'readSerialNext' fetches the vector at ring buffer index 1.
-}
readSerialStart ::
   (LLVM.IsPrimitive a, Memory.FirstClass a,
    Memory.Stored a ~ am, LLVM.IsPrimitive am, IsSized am,
    SizeOf am ~ amsize,
    TypeNum.Positive n, TypeNum.Positive (n :*: amsize)) =>
   RingBuffer.T (Serial.Value n a) ->
   CodeGenFunction r (Serial.Value n a, Iterator n a)
readSerialStart rb = do
   a <- RingBuffer.index A.zero rb
   return (a, ((a, undefTuple, A.zero), A.zero))

{- |
Advance the window by one sample.

If @rest@ is exhausted (@restCount == 0@), the vector at
@ringIndex + 1@ is loaded into @rest@ and @restCount@ is reset to the vector size.
Then one sample is shifted out of @rest@ and shifted into @window@.
Presumably @Serial.shiftUp x v@ inserts @x@ at index 0 and returns the
element pushed out of the highest index, matching the high-to-low reading
order noted at 'Iterator' — TODO confirm against "Synthesizer.LLVM.Frame.SerialVector".
-}
readSerialNext ::
   (LLVM.IsPrimitive a, Memory.FirstClass a,
    Memory.Stored a ~ am, LLVM.IsPrimitive am, IsSized am,
    SizeOf am ~ amsize,
    TypeNum.Positive n, TypeNum.Positive (n :*: amsize)) =>
   RingBuffer.T (Serial.Value n a) ->
   Iterator n a -> CodeGenFunction r (Serial.Value n a, Iterator n a)
readSerialNext rb ((a0, r0, j0), k0) = do
   vectorEnd <- A.cmp LLVM.CmpEQ j0 A.zero
   ((r1, j1), k1) <-
      C.ifThen vectorEnd ((r0, j0), k0) $ do
         k <- A.inc k0
         r <- RingBuffer.index k rb
         return ((r, valueOf (fromIntegral $ Serial.size r :: Word32)), k)
   j2 <- A.dec j1
   (ai, r2) <- Serial.shiftUp undefTuple r1
   (_, a1) <- Serial.shiftUp ai a0
   return (a1, ((a1, r2, j2), k1))