module Synthesizer.LLVM.Filter.NonRecursive (
convolve,
convolvePacked,
) where
import qualified Synthesizer.LLVM.CausalParameterized.ProcessPrivate as CausalP
import qualified Synthesizer.LLVM.CausalParameterized.RingBuffer as RingBuffer
import qualified Synthesizer.LLVM.Frame.SerialVector as Serial
import qualified Synthesizer.LLVM.Storable.Vector as SVU
import qualified Data.StorableVector as SV
import qualified LLVM.DSL.Parameter as Param
import qualified LLVM.Extra.Storable as Storable
import qualified LLVM.Extra.Memory as Memory
import qualified LLVM.Extra.Control as C
import qualified LLVM.Extra.Arithmetic as A
import qualified LLVM.Extra.Tuple as Tuple
import qualified LLVM.Core as LLVM
import LLVM.Core (Value, valueOf, CodeGenFunction, IsSized, SizeOf)
import qualified Type.Data.Num.Decimal as TypeNum
import Type.Data.Num.Decimal.Number ((:*:))
import Foreign.ForeignPtr (touchForeignPtr)
import Foreign.Ptr (Ptr)
import Data.Word (Word)
import Control.Arrow ((<<<), (&&&))
import Control.Monad (liftM2)
import qualified Algebra.IntegralDomain as Integral
import NumericPrelude.Numeric
import NumericPrelude.Base
-- | Non-recursive (FIR) filter: convolve the input signal with @mask@.
--
-- The last @length mask@ inputs are kept in a ring buffer created by
-- 'RingBuffer.trackConst', which is seeded with 'A.zero'. Each output sample
-- is produced by 'scalarProduct', which multiplies ring-buffer element @k@
-- with mask element @k@ and sums the products.
convolve ::
   (Storable.C a, Tuple.ValueOf a ~ al, Memory.C al, A.PseudoRing al) =>
   Param.T p (SV.Vector a) -> CausalP.T p al al
convolve mask =
   CausalP.zipWith scalarProduct maskLength
   <<<
   (RingBuffer.trackConst A.zero len &&& provideMask mask)
   where
      len = fmap SV.length mask
      -- 'scalarProduct' expects the loop count as an LLVM 'Word'.
      maskLength = fmap (fromIntegral :: Int -> Word) len
-- | Packed (SIMD) variant of 'convolve' for serial vectors of width @n@.
--
-- The ring buffer stores whole vectors (chunks), so it only needs
-- @ceiling (length mask / vectorSize)@ entries. 'scalarProductPacked' still
-- iterates over every mask coefficient and assembles the shifted input
-- windows from those chunks.
convolvePacked ::
   (TypeNum.Positive n, TypeNum.Positive (n :*: asize),
    Storable.C a, Tuple.ValueOf a ~ Value al,
    LLVM.IsArithmetic al, LLVM.IsPrimitive al, IsSized al, SizeOf al ~ asize) =>
   Param.T p (SV.Vector a) ->
   CausalP.T p (Serial.Value n al) (Serial.Value n al)
convolvePacked mask =
   Serial.withSize $ \vectorSize ->
      let len = fmap SV.length mask
          -- number of vector chunks needed to cover the whole mask
          chunkCount = fmap (`Integral.divUp` vectorSize) len
          maskLength = fmap (fromIntegral :: Int -> Word) len
      in  CausalP.zipWith scalarProductPacked maskLength
          <<<
          (RingBuffer.trackConst A.zero chunkCount &&& provideMask mask)
-- | Causal process that ignores its input and yields a raw pointer to the
-- contents of the mask vector on every step.
--
-- The pointer is extracted once from the parameter with
-- 'SVU.unsafeToPointers'. The accompanying foreign pointer is handed to
-- 'touchForeignPtr' in the last 'CausalP.Cons' slot. NOTE(review): this
-- presumably keeps the vector alive while generated code uses the raw
-- pointer; confirm against the 'CausalP.Cons' field semantics.
provideMask ::
   (Storable.C a) => Param.T p (SV.Vector a) -> CausalP.T p x (Value (Ptr a))
provideMask mask =
   CausalP.Cons
      (\param () _input () -> return (param, ()))
      (return ())
      return
      (\_ _ -> return ())
      createPointer
      touchForeignPtr
   where
      -- lazy pattern binding, as in the original 'let'
      createPointer param =
         let (foreignPtr, pointer, _len) =
               SVU.unsafeToPointers $ Param.get mask param
         in  return (foreignPtr, (pointer, ()))
-- | Generate code for @sum [rb!k * mask[k] | k <- [0 .. n-1]]@.
--
-- 'Storable.arrayLoop' walks the mask pointer @n@ times. The ring-buffer
-- index @k@ and the running sum are carried as loop state, both starting
-- at 'A.zero'. Only the sum is returned.
scalarProduct ::
   (Storable.C a, Tuple.ValueOf a ~ al, Memory.C al, A.PseudoRing al) =>
   Value Word ->
   RingBuffer.T al -> Value (Ptr a) ->
   CodeGenFunction r al
scalarProduct n rb mask =
   fmap snd $
   Storable.arrayLoop n mask (A.zero, A.zero) $ \coeffPtr (idx, acc) -> do
      sample <- RingBuffer.index idx rb
      coeff <- Storable.load coeffPtr
      nextIdx <- A.inc idx
      nextAcc <- A.add acc =<< A.mul sample coeff
      return (nextIdx, nextAcc)
-- | 'scalarProduct' specialised to plain LLVM values @Value a@.
--
-- This binding is not exported. NOTE(review): it appears to exist only so
-- that 'scalarProduct' is typechecked at this instance; the leading
-- underscore suppresses GHC's unused-binding warning.
_scalarProduct ::
   (Storable.C a, IsSized a,
    Tuple.ValueOf a ~ Value a, LLVM.IsArithmetic a) =>
   Value Word ->
   RingBuffer.T (Value a) -> Value (Ptr a) ->
   CodeGenFunction r (Value a)
_scalarProduct n rb mask = scalarProduct n rb mask
-- | Packed counterpart of 'scalarProduct'.
--
-- The first tap is the vector at ring-buffer index 0, scaled by the first
-- mask coefficient. Each of the remaining @n0 - 1@ taps obtains the next
-- shifted input window from 'readSerialNext', scales it by the next mask
-- coefficient and adds it to the accumulator.
--
-- NOTE(review): @n0@ must be positive. For @n0 == 0@, 'A.dec' on a 'Word'
-- wraps around, and the first mask element is loaded unconditionally.
scalarProductPacked ::
   (Storable.C a,
    Tuple.ValueOf a ~ Value al, LLVM.IsArithmetic al,
    LLVM.IsPrimitive al, IsSized al, SizeOf al ~ asize,
    TypeNum.Positive n, TypeNum.Positive (n :*: asize)) =>
   Value Word ->
   RingBuffer.T (Serial.Value n al) -> Value (Ptr a) ->
   CodeGenFunction r (Serial.Value n al)
scalarProductPacked n0 rb mask0 = do
   (firstWindow, iter0) <- readSerialStart rb
   firstCoeff <- Storable.load mask0
   acc0 <- A.scale firstCoeff firstWindow
   remaining <- A.dec n0
   maskRest <- Storable.incrementPtr mask0
   fmap snd $
      Storable.arrayLoop remaining maskRest (iter0, acc0) $
         \coeffPtr (iter, acc) -> do
            (window, nextIter) <- readSerialNext rb iter
            coeff <- Storable.load coeffPtr
            term <- A.scale coeff window
            nextAcc <- A.add acc term
            return (nextIter, nextAcc)
-- | Loop state for 'readSerialStart' / 'readSerialNext':
-- @((window, chunkRemainder, lanesLeftInChunk), chunkIndex)@.
--
-- @window@ is the current shifted input vector. @chunkRemainder@ holds the
-- not-yet-consumed lanes of the ring-buffer chunk currently being read, and
-- @lanesLeftInChunk@ counts them. @chunkIndex@ is the ring-buffer index of
-- that chunk.
type Iterator n a =
   ((Serial.Value n a, Serial.Value n a, Value Word), Value Word)
-- | Read the vector at ring-buffer index 0 and build the initial 'Iterator'.
--
-- The window is that vector, the chunk remainder is undefined, and both
-- counters start at zero. The zero lane count makes the first
-- 'readSerialNext' load chunk index 1.
readSerialStart ::
   (LLVM.IsPrimitive a, IsSized a, SizeOf a ~ asize,
    TypeNum.Positive n, TypeNum.Positive (n :*: asize)) =>
   RingBuffer.T (Serial.Value n a) ->
   CodeGenFunction r (Serial.Value n a, Iterator n a)
readSerialStart rb =
   fmap
      (\firstChunk ->
         (firstChunk, ((firstChunk, Tuple.undef, A.zero), A.zero)))
      (RingBuffer.index A.zero rb)
-- | Advance the sliding input window by one lane.
--
-- When the lane counter @j0@ of the current chunk reaches zero, the next
-- ring-buffer chunk (index @k0 + 1@) is loaded and the counter is reset to
-- the vector width ('Serial.size'). One lane is then moved out of the chunk
-- remainder and shifted into the window. The updated window is returned
-- together with the updated 'Iterator'.
--
-- NOTE(review): the exact lane direction of 'Serial.shiftUp' is not visible
-- here. From its use, it returns a pair whose first component is the
-- element pushed out and whose second is the shifted vector; confirm in
-- "Synthesizer.LLVM.Frame.SerialVector".
readSerialNext ::
(LLVM.IsPrimitive a, IsSized a, SizeOf a ~ asize,
TypeNum.Positive n, TypeNum.Positive (n :*: asize)) =>
RingBuffer.T (Serial.Value n a) ->
Iterator n a ->
CodeGenFunction r (Serial.Value n a, Iterator n a)
readSerialNext rb ((a0,r0,j0), k0) = do
-- current chunk exhausted?
vectorEnd <- A.cmp LLVM.CmpEQ j0 A.zero
-- if so, fetch the next chunk and reset the lane counter to the vector width
((r1,j1), k1) <-
C.ifThen vectorEnd ((r0,j0), k0) $ do
k <- A.inc k0
r <- RingBuffer.index k rb
return ((r, valueOf (fromIntegral $ Serial.size r :: Word)), k)
-- one lane is consumed from the chunk remainder
j2 <- A.dec j1
-- take one element out of the remainder (filling with undef) ...
(ai,r2) <- Serial.shiftUp Tuple.undef r1
-- ... and shift it into the window; the lane pushed out of the window is dropped
(_, a1) <- Serial.shiftUp ai a0
return (a1, ((a1,r2,j2), k1))