{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
module Synthesizer.LLVM.Filter.NonRecursive (
convolve,
convolvePacked,
) where
import qualified Synthesizer.LLVM.Causal.Process as Causal
import qualified Synthesizer.LLVM.Causal.Private as CausalPriv
import qualified Synthesizer.LLVM.Generator.Source as Source
import qualified Synthesizer.LLVM.Generator.Signal as Sig
import qualified Synthesizer.LLVM.RingBuffer as RingBuffer
import qualified Synthesizer.LLVM.Frame.SerialVector.Code as Serial
import qualified Synthesizer.Causal.Class as CausalClass
import Synthesizer.Causal.Class (($<))
import qualified LLVM.DSL.Expression as Expr
import LLVM.DSL.Expression (Exp)
import qualified LLVM.Extra.Multi.Value.Storable as Storable
import qualified LLVM.Extra.Multi.Value.Marshal as Marshal
import qualified LLVM.Extra.Multi.Value as MultiValue
import qualified LLVM.Extra.Multi.Vector as MultiVector
import qualified LLVM.Extra.Control as C
import qualified LLVM.Extra.Arithmetic as A
import qualified LLVM.Extra.Tuple as Tuple
import qualified LLVM.Core as LLVM
import qualified Type.Data.Num.Decimal as TypeNum
import Foreign.Ptr (Ptr)
import Data.Word (Word)
import Control.Arrow ((<<<), (&&&))
import Control.Monad (liftM2)
import NumericPrelude.Numeric
import NumericPrelude.Base
import Prelude ()
convolve ::
(Storable.C a, Marshal.C a, MultiValue.PseudoRing a, MultiValue.T a ~ am) =>
Exp (Source.StorableVector a) -> Causal.T am am
convolve mask =
let len = Source.storableVectorLength mask
in (CausalPriv.zipWith (\(MultiValue.Cons l) -> scalarProduct l)
$< Sig.constant len)
<<<
Causal.track Expr.zero len &&& provideMask mask
convolvePacked ::
(Marshal.Vector n a, MultiVector.PseudoRing a) =>
(Storable.C a, MultiValue.PseudoRing a, Serial.Value n a ~ v) =>
Exp (Source.StorableVector a) -> Causal.T v v
convolvePacked = convolvePackedAux TypeNum.singleton
convolvePackedAux ::
(Marshal.Vector n a, MultiVector.PseudoRing a) =>
(Storable.C a, MultiValue.PseudoRing a, Serial.Value n a ~ v) =>
TypeNum.Singleton n -> Exp (Source.StorableVector a) -> Causal.T v v
convolvePackedAux vectorSize mask =
let len = Source.storableVectorLength mask
in (CausalPriv.zipWith (\(MultiValue.Cons l) -> scalarProductPacked l)
$< Sig.constant len)
<<<
Causal.track Expr.zero
(divUp (TypeNum.integralFromSingleton vectorSize) len)
&&&
provideMask mask
divUp :: Exp Word -> Exp Word -> Exp Word
divUp k n = Expr.idiv (n+(k-1)) k
provideMask ::
(Storable.C a) =>
Exp (Source.StorableVector a) -> Causal.T x (LLVM.Value (Ptr a))
provideMask mask =
CausalClass.fromSignal $
fmap (\(MultiValue.Cons (ptr,_l)) -> ptr) $
Sig.constant mask
scalarProduct ::
(Storable.C a, Marshal.C a, MultiValue.T a ~ am, MultiValue.PseudoRing a) =>
LLVM.Value Word ->
(RingBuffer.T am, LLVM.Value (Ptr a)) ->
LLVM.CodeGenFunction r am
scalarProduct n (rb,mask) =
fmap snd $
Storable.arrayLoop n mask (A.zero, A.zero) $ \ptr (k, s) -> do
a <- RingBuffer.index k rb
b <- Storable.load ptr
liftM2 (,) (A.inc k) (A.add s =<< A.mul a b)
scalarProductPacked ::
(Storable.C a, Marshal.Vector n a, MultiVector.PseudoRing a) =>
LLVM.Value Word ->
(RingBuffer.T (Serial.Value n a), LLVM.Value (Ptr a)) ->
LLVM.CodeGenFunction r (Serial.Value n a)
scalarProductPacked n0 (rb,mask0) = do
(ax, rx) <- readSerialStart rb
bx <- Storable.load mask0
sx <- Serial.scale bx ax
n1 <- A.dec n0
mask1 <- Storable.incrementPtr mask0
fmap snd $ Storable.arrayLoop n1 mask1 (rx, sx) $ \ptr (r1, s1) -> do
(a,r2) <- readSerialNext rb r1
b <- Storable.load ptr
fmap ((,) r2) (A.add s1 =<< Serial.scale b a)
type
Iterator n a =
((Serial.Value n a,
Serial.Value n a,
LLVM.Value Word),
LLVM.Value Word)
readSerialStart ::
(TypeNum.Positive n, Marshal.Vector n a) =>
RingBuffer.T (Serial.Value n a) ->
LLVM.CodeGenFunction r (Serial.Value n a, Iterator n a)
readSerialStart rb = do
a <- RingBuffer.index A.zero rb
return (a, ((a, Tuple.undef, A.zero), A.zero))
readSerialNext ::
(MultiValue.C a, Marshal.Vector n a) =>
RingBuffer.T (Serial.Value n a) ->
Iterator n a ->
LLVM.CodeGenFunction r (Serial.Value n a, Iterator n a)
readSerialNext rb ((a0,r0,j0), k0) = do
vectorEnd <- A.cmp LLVM.CmpEQ j0 A.zero
((r1,j1), k1) <-
C.ifThen vectorEnd ((r0,j0), k0) $ do
k <- A.inc k0
r <- RingBuffer.index k rb
return ((r, LLVM.valueOf (Serial.size r :: Word)), k)
j2 <- A.dec j1
(ai,r2) <- Serial.shiftUp Tuple.undef r1
(_, a1) <- Serial.shiftUp ai a0
return (a1, ((a1,r2,j2), k1))