{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE UndecidableInstances #-}
{-# LANGUAGE FlexibleContexts #-}
{-# OPTIONS_GHC -fno-warn-orphans #-}
module Synthesizer.LLVM.Frame.StereoInterleavedCode (
T,
Value,
interleave,
deinterleave,
fromMono,
assemble, dissect,
zero,
scale,
amplify,
envelope,
) where
import qualified Synthesizer.LLVM.Frame.Stereo as Stereo
import qualified Synthesizer.LLVM.Frame.SerialVector.Code as Serial
import qualified LLVM.Extra.Multi.Vector as MultiVector
import qualified LLVM.Extra.Multi.Value.Storable as Storable
import qualified LLVM.Extra.Multi.Value.Marshal as Marshal
import qualified LLVM.Extra.Multi.Value as MultiValue
import qualified LLVM.Extra.Arithmetic as A
import qualified LLVM.Core as LLVM
import LLVM.Core (Vector)
import qualified Type.Data.Num.Decimal as TypeNum
import qualified Foreign.Storable as St
import Foreign.Ptr (Ptr, castPtr)
import qualified Control.Applicative.HT as AppHT
import Control.Applicative (liftA2, pure)
import qualified Data.Foldable as Fold
import Data.Tuple.HT (mapPair)
import qualified Algebra.Additive as Additive
-- | A block of stereo frames stored as two @n@-element vectors.
--
-- The two vectors are not the left and right channels. 'assemble' and
-- 'dissect' treat their concatenation as one flat sequence of
-- interleaved samples, consumed pairwise as (l, r) (assuming the
-- 'Fold.Foldable' instance of 'Stereo.T' yields left then right).
data T n a = Cons (Vector n a) (Vector n a)
-- | LLVM code-level value of an interleaved block 'T'.
type Value n a = MultiValue.T (T n a)
-- | Run a continuation with the type-level vector size @n@ reflected as
-- an 'Int'.
--
-- The helper receives the 'TypeNum.Singleton' explicitly. Its signature
-- ties the singleton to the @n@ in the continuation's result type, so no
-- ScopedTypeVariables is needed.
withSize :: (TypeNum.Natural n) => (Int -> m (Value n a)) -> m (Value n a)
withSize = withSingleton TypeNum.singleton
  where
    withSingleton ::
      (TypeNum.Natural n) =>
      TypeNum.Singleton n -> (Int -> m (Value n a)) -> m (Value n a)
    withSingleton sizeWitness k =
      k (TypeNum.integralFromSingleton sizeWitness)
-- | Convert a stereo pair of serial (per-channel) vectors into the
-- interleaved representation.
--
-- The channel vectors are dissected into a list of stereo frames, and
-- the frames are then packed by 'assemble'.
interleave ::
  (TypeNum.Positive n, MultiVector.C a) =>
  Stereo.T (Serial.Value n a) ->
  LLVM.CodeGenFunction r (Value n a)
interleave stereo = do
  frames <- Serial.dissect (Stereo.multiValueSerial stereo)
  assemble (map Stereo.unMultiValue frames)
-- | Counterpart of 'interleave'.
--
-- Unpacks the interleaved value into stereo frames with 'dissect' and
-- reassembles them as one serial vector per channel.
deinterleave ::
  (TypeNum.Positive n, MultiVector.C a) =>
  Value n a ->
  LLVM.CodeGenFunction r (Stereo.T (Serial.Value n a))
deinterleave v = do
  frames <- dissect v
  Stereo.unMultiValueSerial <$> Serial.assemble (map Stereo.multiValue frames)
-- | Build an interleaved value from a mono serial vector.
--
-- Each sample is wrapped with 'pure' of 'Stereo.T' (presumably the same
-- sample on both channels) and the resulting frames are packed by
-- 'assemble'.
fromMono ::
  (TypeNum.Positive n, MultiVector.C a) =>
  Serial.Value n a ->
  LLVM.CodeGenFunction r (Value n a)
fromMono mono = do
  samples <- Serial.dissect mono
  assemble [pure s | s <- samples]
-- | Pack a list of stereo frames into an interleaved value.
--
-- The frames are flattened to @[l0, r0, l1, r1, ...]@. The first @n@
-- elements form the first vector and the remainder forms the second.
assemble ::
  (TypeNum.Positive n, MultiVector.C a) =>
  [Stereo.T (MultiValue.T a)] -> LLVM.CodeGenFunction r (Value n a)
assemble frames =
  withSize $ \n ->
    let flat = concatMap Fold.toList frames
        (front, back) =
          mapPair (MultiVector.assemble, MultiVector.assemble) (splitAt n flat)
    in  liftA2 merge front back
-- | Unpack an interleaved value into its list of stereo frames.
--
-- The elements of both vectors are concatenated and consumed pairwise as
-- (l, r). If each half dissects into exactly @n@ elements, the total is
-- even and the error branch is unreachable.
dissect ::
  (TypeNum.Positive n, MultiVector.C a) =>
  Value n a -> LLVM.CodeGenFunction r [Stereo.T (MultiValue.T a)]
dissect v =
  pairUp <$> liftA2 (++) (MultiVector.dissect v0) (MultiVector.dissect v1)
  where
    -- lazy pattern binding, as in a 'let'
    (v0, v1) = split v
    pairUp (l:r:rest) = Stereo.cons l r : pairUp rest
    pairUp [] = []
    pairUp _ = error "odd number of stereo elements"
-- | Combine two code-level half vectors into one interleaved value.
merge :: MultiVector.T n a -> MultiVector.T n a -> MultiValue.T (T n a)
merge x y =
  case (x, y) of
    (MultiVector.Cons front, MultiVector.Cons back) ->
      MultiValue.Cons (front, back)
-- | Separate an interleaved value into its two code-level half vectors.
split :: MultiValue.T (T n a) -> (MultiVector.T n a, MultiVector.T n a)
split v =
  case v of
    MultiValue.Cons (front, back) ->
      (MultiVector.Cons front, MultiVector.Cons back)
-- | Like 'merge', but the halves are 'MultiValue.T' of plain 'Vector's
-- (the form used by the storable load/store code).
merge_ ::
  MultiValue.T (Vector n a) -> MultiValue.T (Vector n a) ->
  MultiValue.T (T n a)
merge_ x y =
  case (x, y) of
    (MultiValue.Cons front, MultiValue.Cons back) ->
      MultiValue.Cons (front, back)
-- | Like 'split', but returns the halves as 'MultiValue.T' of plain
-- 'Vector's.
split_ ::
  MultiValue.T (T n a) ->
  (MultiValue.T (Vector n a), MultiValue.T (Vector n a))
split_ v =
  case v of
    MultiValue.Cons (front, back) -> (MultiValue.Cons front, MultiValue.Cons back)
-- | Code-level representation of 'T' as a pair of vector representations.
-- Every method works on the two halves separately, using 'split' and
-- 'merge'.
instance (TypeNum.Positive n, MultiVector.C a) => MultiValue.C (T n a) where
  type Repr (T n a) = (MultiVector.Repr n a, MultiVector.Repr n a)
  cons (Cons front back) =
    merge (MultiVector.cons front) (MultiVector.cons back)
  undef = merge MultiVector.undef MultiVector.undef
  zero = merge MultiVector.zero MultiVector.zero
  -- one phi per half, recombined into an interleaved value
  phi bb x =
    uncurry merge <$>
      AppHT.mapPair (MultiVector.phi bb, MultiVector.phi bb) (split x)
  addPhi bb a b =
    case split a of
      (a0, a1) ->
        case split b of
          (b0, b1) ->
            MultiVector.addPhi bb a0 b0 >> MultiVector.addPhi bb a1 b1
-- | Marshal 'T' as the pair of its two vectors.
instance (Marshal.Vector n a) => Marshal.C (T n a) where
  pack (Cons front back) = Marshal.pack (front, back)
  -- lazy tuple binding, matching the laziness of 'uncurry'
  unpack struct =
    let (front, back) = Marshal.unpack struct
    in  Cons front back
-- | Store 'T' as two consecutive 'Vector's.
--
-- The size is the sum of both vectors' sizes, and the alignment is that
-- of the first vector. The lazy patterns let 'St.sizeOf' and
-- 'St.alignment' be called on an undefined argument.
instance
  (TypeNum.Positive n, MultiVector.C a, St.Storable a) =>
    St.Storable (T n a) where
  sizeOf ~(Cons front back) = St.sizeOf front + St.sizeOf back
  alignment ~(Cons front _) = St.alignment front
  peek ptr = do
    let vecPtr = castPtr ptr
    front <- St.peekElemOff vecPtr 0
    back <- St.peekElemOff vecPtr 1
    return (Cons front back)
  poke ptr (Cons front back) = do
    let vecPtr = castPtr ptr
    St.pokeElemOff vecPtr 0 front
    St.pokeElemOff vecPtr 1 back
-- | LLVM load/store of 'T' as two consecutive vectors: the first half at
-- the given address and the second half one vector further on.
instance (TypeNum.Positive n, Storable.Vector a) => Storable.C (T n a) where
  load ptrV = do
    firstPtr <- castHalfPtr ptrV
    front <- Storable.load firstPtr
    back <- Storable.load =<< Storable.incrementPtr firstPtr
    return (merge_ front back)
  store v ptrV = do
    firstPtr <- castHalfPtr ptrV
    let (front, back) = split_ v
    secondPtr <- Storable.storeNext front firstPtr
    Storable.store back secondPtr
-- | Reinterpret a pointer to an interleaved block as a pointer to its
-- first half vector (LLVM bitcast).
castHalfPtr ::
  LLVM.Value (Ptr (T n a)) ->
  LLVM.CodeGenFunction r (LLVM.Value (Ptr (Vector n a)))
castHalfPtr ptr = LLVM.bitcast ptr
-- | Addition, subtraction and negation, applied to each half independently.
instance
  (TypeNum.Positive n, MultiVector.Additive a) =>
    MultiValue.Additive (T n a) where
  add x y = zipV merge A.add x y
  sub x y = zipV merge A.sub x y
  neg x = mapV A.neg x
-- | A plain Haskell 'T' whose vectors are filled with 'Additive.zero'.
zero :: (TypeNum.Positive n, Additive.C a) => T n a
zero =
  let zeros = pure Additive.zero
  in  Cons zeros zeros
-- | Multiply every sample of an interleaved value by one scalar.
--
-- The scalar is replicated into a vector once and used for both halves.
scale ::
  (TypeNum.Positive n, MultiVector.PseudoRing a) =>
  MultiValue.T a -> Value n a -> LLVM.CodeGenFunction r (Value n a)
scale a v =
  MultiVector.replicate a >>= \factors -> mapV (A.mul factors) v
-- | 'scale' by a constant Haskell value, lifted with 'MultiValue.cons'.
amplify ::
  (TypeNum.Positive n, MultiVector.PseudoRing a) =>
  a -> Value n a -> LLVM.CodeGenFunction r (Value n a)
amplify = scale . MultiValue.cons
-- | Multiply an interleaved signal by a mono envelope.
--
-- The envelope is first spread to an interleaved value with 'fromMono',
-- then multiplied element-wise with the signal, half by half.
envelope ::
  (TypeNum.Positive n, MultiVector.PseudoRing a) =>
  Serial.Value n a -> Value n a -> LLVM.CodeGenFunction r (Value n a)
envelope e a = do
  gains <- fromMono e
  zipV merge (\x g -> A.mul g x) a gains
-- | Apply an effectful function to both halves of an interleaved value
-- and recombine the results with 'merge'.
--
-- The effect of the first half is combined before that of the second
-- (left-to-right 'liftA2').
mapV :: (Applicative m) =>
  (MultiVector.T n a -> m (MultiVector.T n a)) ->
  Value n a -> m (Value n a)
mapV f x =
  case split x of
    -- combine directly instead of pairing with (,) and then uncurrying
    (x0, x1) -> liftA2 merge (f x0) (f x1)
-- | Apply an effectful binary function to corresponding halves of two
-- interleaved values and combine the two results with @g@.
--
-- The first halves are processed before the second halves.
zipV :: (Applicative m) =>
  (c -> c -> d) ->
  (MultiVector.T n a ->
   MultiVector.T n b ->
   m c) ->
  Value n a ->
  Value n b ->
  m d
zipV g f x y =
  case split x of
    (x0, x1) ->
      case split y of
        (y0, y1) -> g <$> f x0 y0 <*> f x1 y1