{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE Rank2Types #-}
{-# LANGUAGE RebindableSyntax #-}
module Synthesizer.LLVM.Causal.Helix (
static,
staticPacked,
dynamic,
dynamicLimited,
zigZag,
zigZagPacked,
zigZagLong,
zigZagLongPacked,
) where
import qualified Synthesizer.LLVM.Causal.ProcessPacked as CausalPS
import qualified Synthesizer.LLVM.Causal.Private as CausalPriv
import qualified Synthesizer.LLVM.Causal.Process as Causal
import qualified Synthesizer.LLVM.Causal.Functional as Func
import qualified Synthesizer.LLVM.Generator.Source as Source
import qualified Synthesizer.LLVM.Generator.SignalPacked as SigPS
import qualified Synthesizer.LLVM.Generator.Private as SigPriv
import qualified Synthesizer.LLVM.Generator.Signal as Sig
import qualified Synthesizer.LLVM.Causal.RingBufferForward as RingBuffer
import qualified Synthesizer.LLVM.Frame.SerialVector as SerialExp
import qualified Synthesizer.LLVM.Frame.SerialVector.Code as Serial
import qualified Synthesizer.LLVM.Frame.SerialVector.Class as SerialClass
import qualified Synthesizer.LLVM.Interpolation as Ip
import Synthesizer.LLVM.Causal.Functional (($&), (&|&))
import Synthesizer.LLVM.Private (noLocalPtr)
import Synthesizer.Causal.Class (($*), ($<))
import qualified LLVM.DSL.Expression.Vector as ExprVec
import qualified LLVM.DSL.Expression as Expr
import LLVM.DSL.Expression (Exp, (<*), (>=*))
import qualified LLVM.Extra.Multi.Value.Storable as Storable
import qualified LLVM.Extra.Multi.Value.Marshal as Marshal
import qualified LLVM.Extra.Multi.Value.Vector as MultiValueVec
import qualified LLVM.Extra.Multi.Value as MultiValue
import qualified LLVM.Extra.Multi.Vector as MultiVector
import qualified LLVM.Extra.Arithmetic as A
import qualified LLVM.Extra.Memory as Memory
import qualified LLVM.Core as LLVM
import qualified Type.Data.Num.Decimal as TypeNum
import Data.Word (Word)
import Control.Arrow (first, (<<<))
import Control.Category (id)
import Control.Functor.HT (unzip)
import Data.Traversable (mapM)
import Data.Tuple.HT (mapPair, mapFst)
import qualified Algebra.Ring as Ring
import NumericPrelude.Numeric hiding (splitFraction)
import NumericPrelude.Base hiding (unzip, zip, mapM, id)
import Prelude ()
-- | Read a stored waveform through two-level (step/leap) interpolation.
--
-- The input of the process is a pair @(shape, phase)@.
--
-- * The shape is clamped with 'limitShape'. Its bounds derive from the
--   combined margin of @ipLeap@/@ipStep@, @periodInt@ and the length of @vec@.
-- * 'flattenShapePhaseProc' turns the pair into a node index plus the leap
--   and step fractions.
-- * 'peekCell' loads the nodes at that index from @vec@.
-- * 'interpolateCell' combines them with @ipStep@ and then @ipLeap@.
--
-- NOTE(review): @periodInt@ and @period@ are presumably the same period in
-- integral and fractional form; confirm with callers.
static ::
   (Ip.C nodesStep, Ip.C nodesLeap) =>
   (Storable.C vh, MultiValue.T vh ~ v) =>
   (Marshal.C a, MultiValue.Field a, MultiValue.RationalConstant a) =>
   (MultiValue.Fraction a, MultiValue.NativeFloating a ar) =>
   (MultiValueVec.NativeFloating a ar, MultiValue.T a ~ am) =>
   (forall r. Ip.T r nodesLeap am v) ->
   (forall r. Ip.T r nodesStep am v) ->
   Exp Int ->
   Exp a ->
   Exp (Source.StorableVector vh) ->
   Causal.T (am, am) v
static ipLeap ipStep periodInt period vec =
   interpolateCell ipLeap ipStep
   <<< first (peekCell margin periodW vec)
   <<< flattenShapePhaseProc periodW period
   <<< first (limitShape margin periodInt vecLength)
   where
      periodW = wordFromInt periodInt
      margin = combineMarginParams ipLeap ipStep periodInt
      vecLength = intFromWord $ Source.storableVectorLength vec
-- | Reinterpret the bits of a 'Word' expression as an 'Int' using an LLVM
-- @bitcast@. No range check is performed.
intFromWord :: Exp Word -> Exp Int
intFromWord w = Expr.liftReprM LLVM.bitcast w
-- | Reinterpret the bits of an 'Int' expression as a 'Word' using an LLVM
-- @bitcast@. No range check is performed, so negative values wrap.
wordFromInt :: Exp Int -> Exp Word
wordFromInt i = Expr.liftReprM LLVM.bitcast i
-- | Serial-vector (packed) counterpart of 'static'.
--
-- Shape limiting and the shape/phase flattening work directly on serial
-- values ('limitShapePacked', 'flattenShapePhaseProcPacked'). Node
-- loading uses the scalar 'peekCell' lifted with 'CausalPS.pack'. Its
-- margin is retyped to the element type with 'elementMargin'.
staticPacked ::
   (Ip.C nodesStep, Ip.C nodesLeap) =>
   (Storable.C vh, MultiValue.T vh ~ ve, SerialClass.Element v ~ ve) =>
   (SerialClass.Size (nodesLeap (nodesStep v)) ~ n,
    SerialClass.Write (nodesLeap (nodesStep v)),
    SerialClass.Element (nodesLeap (nodesStep v)) ~
      nodesLeap (nodesStep (SerialClass.Element v))) =>
   (TypeNum.Positive n) =>
   (Marshal.C a, MultiVector.Field a, MultiVector.Real a,
    MultiVector.Fraction a, MultiVector.RationalConstant a,
    MultiVector.NativeFloating n a ar) =>
   (forall r. Ip.T r nodesLeap (Serial.Value n a) v) ->
   (forall r. Ip.T r nodesStep (Serial.Value n a) v) ->
   Exp Int ->
   Exp a ->
   Exp (Source.StorableVector vh) ->
   Causal.T (Serial.Value n a, Serial.Value n a) v
staticPacked ipLeap ipStep periodInt period vec =
   interpolateCell ipLeap ipStep
   <<< first (CausalPS.pack (peekCell (elementMargin margin) periodW vec))
   <<< flattenShapePhaseProcPacked periodW period
   <<< first (limitShapePacked margin periodInt vecLength)
   where
      periodW = wordFromInt periodInt
      margin = combineMarginParams ipLeap ipStep periodInt
      vecLength = intFromWord $ Source.storableVectorLength vec
-- | Signal-driven counterpart of 'static', built on 'dynamicGen'.
--
-- The waveform comes from @sig@ through 'RingBuffer.trackSkip'. The
-- ring buffer is parameterised by the combined margin count, and the
-- integer skips drive it. The skips and the fractional parts are each
-- passed through @'Causal.delay1' zero@ before use.
--
-- NOTE(review): how this behaves once @sig@ ends is decided by
-- 'RingBuffer.trackSkip' and is not visible here (compare 'dynamic').
dynamicLimited ::
   (Ip.C nodesStep, Ip.C nodesLeap) =>
   (Marshal.C a, MultiValue.Field a, MultiValue.Fraction a,
    MultiValue.Select a, MultiValue.Comparison a,
    MultiValue.NativeFloating a ar,
    MultiValue.RationalConstant a,
    MultiValueVec.NativeFloating a ar) =>
   (MultiValue.T a ~ am) =>
   (Memory.C v) =>
   (forall r. Ip.T r nodesLeap am v) ->
   (forall r. Ip.T r nodesStep am v) ->
   Exp Int ->
   Exp a ->
   Sig.T v ->
   Causal.T (am, am) v
dynamicLimited ipLeap ipStep periodInt period sig =
   dynamicGen skipWindows ipLeap ipStep periodInt period
   where
      -- Build the ring-buffer windows and the delayed skips/fractions.
      skipWindows cellMargin (skips, fracs) =
         let marginCount = wordFromInt $ Ip.marginNumberExp cellMargin
             windows = RingBuffer.trackSkip marginCount sig $& skips
         in (windows,
             Causal.delay1 zero $& skips,
             Causal.delay1 zero $& fracs)
-- | Signal-driven counterpart of 'static', built on 'dynamicGen'.
--
-- The waveform comes from @sig@ through 'RingBuffer.trackSkipHold'. Its
-- size argument is the combined margin count plus one. Besides the
-- windows, the tracker yields a @running@ flag and the skips it actually
-- performed, and those skips are used downstream. The fractional parts
-- are delayed by one sample. While @running@ is false, they are replaced
-- by 1.
dynamic ::
   (Ip.C nodesStep, Ip.C nodesLeap) =>
   (Marshal.C a, MultiValue.Field a, MultiValue.Fraction a,
    MultiValue.Select a, MultiValue.Comparison a,
    MultiValue.NativeFloating a ar,
    MultiValue.RationalConstant a,
    MultiValueVec.NativeFloating a ar) =>
   (MultiValue.T a ~ am) =>
   (Memory.C v) =>
   (forall r. Ip.T r nodesLeap am v) ->
   (forall r. Ip.T r nodesStep am v) ->
   Exp Int ->
   Exp a ->
   Sig.T v ->
   Causal.T (am, am) v
dynamic ipLeap ipStep periodInt period sig =
   dynamicGen heldWindows ipLeap ipStep periodInt period
   where
      heldWindows cellMargin (skips, fracs) =
         (windows, actualSkips, holdFracs)
         where
            trackSize = wordFromInt (Ip.marginNumberExp cellMargin) + 1
            tracked = RingBuffer.trackSkipHold trackSize sig $& skips
            ((running, actualSkips), windows) = mapFst unzip $ unzip tracked
            delayedFracs = Causal.delay1 zero $& fracs
            holdFracs =
               Causal.zipWith (\r fr -> Expr.select r fr 1)
               $&
               running &|& delayedFracs
-- | Shared implementation of 'dynamic' and 'dynamicLimited'.
--
-- The input of the resulting process is a pair @(shape, phase)@.
--
-- The first argument (@limitMaxShape@) receives the combined cell margin
-- and the (integer, fractional) parts produced by 'integrateFrac'. It
-- returns three things, all used below:
--
-- * the ring-buffer windows to read nodes from,
-- * the integer skips,
-- * the fractional parts.
dynamicGen ::
   (Ip.C nodesStep, Ip.C nodesLeap) =>
   (Marshal.C a, MultiValue.Field a, MultiValue.Fraction a,
    MultiValue.Select a, MultiValue.Comparison a,
    MultiValue.NativeFloating a ar,
    MultiValue.RationalConstant a,
    MultiValueVec.NativeFloating a ar) =>
   (MultiValue.T a ~ am) =>
   (Memory.C v) =>
   (Exp (Ip.Margin (nodesLeap (nodesStep v))) ->
      (Func.T (am, am) (MultiValue.T Word),
       Func.T (am, am) am) ->
      (Func.T (am, am) (RingBuffer.T v),
       Func.T (am, am) (MultiValue.T Word),
       Func.T (am, am) am)) ->
   (forall r. Ip.T r nodesLeap am v) ->
   (forall r. Ip.T r nodesStep am v) ->
   Exp Int ->
   Exp a ->
   Causal.T (am, am) v
dynamicGen limitMaxShape ipLeap ipStep periodInt period =
   let periodWord = wordFromInt periodInt
       cellMargin = combineMarginParams ipLeap ipStep periodInt
       -- Lower shape bound: the left margin from 'shapeMargin'.
       minShape = wordFromInt $ fst $ shapeMargin cellMargin periodInt
   in Func.withArgs $ \(shape, phase) ->
      -- The shape input first goes through 'limitMinShape', then
      -- 'integrateFrac' splits it into integer skips and fractional parts,
      -- which 'limitMaxShape' post-processes.
      let (windows, skips, fracs) =
             limitMaxShape cellMargin $
             unzip (integrateFrac $& (limitMinShape minShape $& shape))
          -- Flatten (minShape + fraction, phase) into a node offset and the
          -- leap/step fractions. The phase is fed through 'Causal.osciCoreSync'
          -- together with -skips/period.
          -- NOTE(review): presumably this compensates the phase for
          -- skipped samples; confirm.
          (offsets, shapePhases) =
             unzip
               (flattenShapePhaseProc periodWord period $&
                  (constantFromWord minShape + fracs)
                  &|&
                  (Causal.osciCoreSync $&
                     phase
                     &|&
                     negate
                        (Causal.map ((/period)) $&
                           (Causal.map Expr.fromIntegral $& skips))))
      -- Read the cell at each offset from the current ring-buffer window
      -- and interpolate it with the leap/step fractions.
      in interpolateCell ipLeap ipStep $&
            (CausalPriv.map
               (\(buffer, offset) -> do
                  p <- Expr.unExp periodWord
                  cellFromBuffer p buffer offset)
             $&
             windows
             &|&
             offsets)
            &|&
            shapePhases
-- | Constant functional value obtained by converting the 'Word' expression
-- @x@ to the floating type with 'Expr.fromIntegral'. The conversion runs
-- inside the signal.
constantFromWord ::
   (MultiValue.NativeFloating a ar) =>
   Exp Word -> Func.T inp (MultiValue.T a)
constantFromWord x = Func.fromSignal converted
   where
      converted = Causal.map Expr.fromIntegral $* Sig.constant x
-- | Subtract a running "debt" from each input sample.
--
-- The state starts at @xLim@ (converted to the floating type). For each
-- input @x@ and state @lim@:
--
-- * if @x >= lim@, emit @x - lim@ and set the state to zero;
-- * otherwise, emit zero and carry the shortfall @lim - x@ to the next
--   sample.
limitMinShape ::
   (Marshal.C a, MultiValue.Select a, MultiValue.Comparison a,
    MultiValue.NativeFloating a ar) =>
   Exp Word ->
   Causal.T (MultiValue.T a) (MultiValue.T a)
limitMinShape xLim =
   Causal.mapAccum
      (\x lim ->
         let covered = Expr.zip (x-lim) zero
             shortfall = Expr.zip zero (lim-x)
         in Expr.unzip (Expr.select (x>=*lim) covered shortfall))
      (Expr.fromIntegral xLim)
-- | Accumulate the input and split it into integer and fractional parts.
--
-- Each input is added to the fractional remainder carried from the
-- previous sample. The sum is split with 'ExprVec.splitFractionToInt'.
-- The whole @(integer, fraction)@ pair is emitted, and only the fraction
-- is kept as the new state (initially zero).
integrateFrac ::
   (Marshal.C a, MultiValue.Additive a,
    MultiValueVec.NativeFloating a ar, LLVM.IsPrimitive ar) =>
   Causal.T (MultiValue.T a) (MultiValue.T Word, MultiValue.T a)
integrateFrac =
   Causal.mapAccum
      (\x carry ->
         let parts@(_, remainder) = ExprVec.splitFractionToInt (x+carry)
         in (parts, remainder))
      zero
-- | Two-stage interpolation of a cell of nodes.
--
-- The input is @(nodes, (qLeap, qStep))@. Every inner node group is first
-- interpolated with @ipStep qStep@. The resulting outer group is then
-- interpolated with @ipLeap qLeap@.
interpolateCell ::
   (Ip.C nodesStep, Ip.C nodesLeap) =>
   (forall r. Ip.T r nodesLeap a v) ->
   (forall r. Ip.T r nodesStep a v) ->
   Causal.T (nodesLeap (nodesStep v), (a, a)) v
interpolateCell ipLeap ipStep =
   CausalPriv.map
      (\(nodes, (qLeap, qStep)) -> do
         stepped <- mapM (ipStep qStep) nodes
         ipLeap qLeap stepped)
-- | Fetch a cell of nodes from a ring buffer.
--
-- Nodes are read with 'RingBuffer.index', starting at @offset@. The
-- inner level uses step 'A.one' and the outer level uses step @period@.
-- NOTE(review): the steps are presumably buffer-index strides; confirm
-- against 'Ip.indexNodesExp'.
cellFromBuffer ::
   (Memory.C a, Ip.C nodesLeap, Ip.C nodesStep) =>
   MultiValue.T Word ->
   RingBuffer.T a ->
   MultiValue.T Word ->
   LLVM.CodeGenFunction r (nodesLeap (nodesStep a))
cellFromBuffer period buffer offset =
   Ip.indexNodesExp
      (Ip.indexNodesExp (`RingBuffer.index` buffer) A.one)
      period
      offset
-- | Retype a margin for nodes of @v@ as a margin for nodes of the serial
-- element type. The underlying representation is passed through
-- unchanged.
elementMargin ::
   Exp (Ip.Margin (nodesLeap (nodesStep v))) ->
   Exp (Ip.Margin (nodesLeap (nodesStep (SerialClass.Element v))))
elementMargin margin = Expr.liftReprM return margin
-- | Load the cell of nodes for index @n@ from a storable vector.
--
-- The first node is at element @n - marginOffset@. The inner level is
-- loaded with step 'A.one' and the outer level with step @periodWord@.
-- No bounds check happens here; 'static' clamps the shape beforehand.
peekCell ::
   (Storable.C a, MultiValue.T a ~ value, Ip.C nodesLeap, Ip.C nodesStep) =>
   Exp (Ip.Margin (nodesLeap (nodesStep value))) ->
   Exp Word ->
   Exp (Source.StorableVector a) ->
   Causal.T (MultiValue.T Word) (nodesLeap (nodesStep value))
peekCell margin periodWord vec =
   CausalPriv.map
      (\n -> do
         ~(MultiValue.Cons (basePtr,_len)) <- Expr.unExp vec
         ~(MultiValue.Cons start) <-
            Expr.unExp $ intFromWord (Expr.lift0 n) - Ip.marginOffsetExp margin
         period <- Expr.unExp $ intFromWord periodWord
         cellPtr <- Storable.advancePtr start basePtr
         Ip.loadNodesExp (Ip.loadNodesExp Storable.load A.one) period cellPtr)
-- | Scalar causal wrapper around 'flattenShapePhase'. It maps each
-- @(shape, phase)@ pair to @(nodeIndex, (qLeap, qStep))@.
flattenShapePhaseProc ::
   (MultiValue.Field a, MultiValue.RationalConstant a, MultiValue.Fraction a) =>
   (MultiValue.NativeFloating a ar, MultiValueVec.NativeFloating a ar) =>
   Exp Word ->
   Exp a ->
   Causal.T
      (MultiValue.T a, MultiValue.T a)
      (MultiValue.T Word, (MultiValue.T a, MultiValue.T a))
flattenShapePhaseProc periodWord period =
   Causal.map (uncurry (flattenShapePhase periodWord period))
-- | Alternative to 'flattenShapePhaseProc' built on the code-level
-- '_flattenShapePhase'. It is unused here and kept for reference.
_flattenShapePhaseProc ::
   (MultiValue.Field a, MultiValue.RationalConstant a, MultiValue.Fraction a) =>
   (MultiValue.NativeFloating a ar) =>
   Exp Word ->
   Exp a ->
   Causal.T
      (MultiValue.T a, MultiValue.T a)
      (MultiValue.T Word, (MultiValue.T a, MultiValue.T a))
_flattenShapePhaseProc periodWord period =
   CausalPriv.map
      (\(shape, phase) -> do
         periodW <- Expr.unExp periodWord
         periodF <- Expr.unExp period
         _flattenShapePhase periodW periodF shape phase)
-- | Serial-vector counterpart of 'flattenShapePhaseProc'. Both period
-- arguments are upsampled to all lanes before calling 'flattenShapePhase'.
flattenShapePhaseProcPacked ::
   (TypeNum.Positive n, MultiVector.Field a, MultiVector.RationalConstant a) =>
   (MultiVector.Fraction a, MultiVector.NativeFloating n a ar) =>
   Exp Word ->
   Exp a ->
   Causal.T
      (Serial.Value n a, Serial.Value n a)
      (Serial.Value n Word, (Serial.Value n a, Serial.Value n a))
flattenShapePhaseProcPacked periodWord period =
   Causal.zipWith (flattenShapePhase periodLanes periodFracLanes)
   where
      periodLanes = SerialExp.upsample periodWord
      periodFracLanes = SerialExp.upsample period
-- | Split a @(shape, phase)@ position into a node index and two
-- interpolation fractions.
--
-- * @qLeap = fraction (shape/period - phase)@.
-- * @shape - qLeap * periodInt@ is clamped below at zero.
-- * That clamped value is split into an integer node index @n@ and the
--   fractional @qStep@.
--
-- The result is @(n, (qLeap, qStep))@.
flattenShapePhase ::
   (MultiValue.Field a, MultiValue.RationalConstant a, MultiValue.Fraction a) =>
   (MultiValueVec.NativeFloating a ar, MultiValueVec.NativeInteger i ir) =>
   (LLVM.ShapeOf ir ~ LLVM.ShapeOf ar) =>
   Exp i -> Exp a ->
   Exp a -> Exp a ->
   (Exp i, (Exp a, Exp a))
flattenShapePhase periodInt period shape phase = (n, (qLeap, qStep))
   where
      qLeap = Expr.fraction (shape/period - phase)
      remaining =
         Expr.max zero (shape - qLeap * ExprVec.fromIntegral periodInt)
      (n, qStep) = ExprVec.splitFractionToInt remaining
-- | Code-level (scalar-only) version of 'flattenShapePhase'. It uses
-- 'Expr.splitFractionToInt' instead of the vector variant and is used only
-- by '_flattenShapePhaseProc'.
_flattenShapePhase ::
   (MultiValue.Field a, MultiValue.RationalConstant a, MultiValue.Fraction a) =>
   (MultiValue.NativeFloating a ar, MultiValue.NativeInteger i ir) =>
   MultiValue.T i ->
   MultiValue.T a ->
   MultiValue.T a -> MultiValue.T a ->
   LLVM.CodeGenFunction r (MultiValue.T i, (MultiValue.T a, MultiValue.T a))
_flattenShapePhase = Expr.unliftM4 $ \periodInt period shape phase ->
   let leapFrac = Expr.fraction (shape/period - phase)
       remaining =
          Expr.max zero (shape - leapFrac * Expr.fromIntegral periodInt)
       (nodeIndex, stepFrac) = Expr.splitFractionToInt remaining
   in (nodeIndex, (leapFrac, stepFrac))
-- | Clamp each input with 'Expr.limit' to the constant @(min, max)@ pair
-- produced by 'limitShapeSignal'.
limitShape ::
   (Ip.C nodesStep, Ip.C nodesLeap) =>
   (Marshal.C t, MultiValue.Real t, MultiValue.NativeFloating t tr) =>
   (i ~ Int) =>
   Exp (Ip.Margin (nodesLeap (nodesStep value))) ->
   Exp i -> Exp i -> Causal.MV t t
limitShape margin periodInt len =
   Causal.zipWith Expr.limit $< bounds
   where
      bounds = limitShapeSignal margin periodInt len
-- | Serial-vector counterpart of 'limitShape'. Both scalar bounds from
-- 'limitShapeSignal' are upsampled to all lanes before
-- 'SerialExp.limit' clamps the input.
limitShapePacked ::
   (Ip.C nodesStep, Ip.C nodesLeap) =>
   (Marshal.C t, MultiValue.NativeFloating t tr) =>
   (TypeNum.Positive n, MultiVector.Real t) =>
   (i ~ Int) =>
   Exp (Ip.Margin (nodesLeap (nodesStep value))) ->
   Exp i ->
   Exp i ->
   Causal.T (Serial.Value n t) (Serial.Value n t)
limitShapePacked margin periodInt len =
   Causal.zipWith
      (\(lo, hi) ->
         SerialExp.limit (SerialExp.upsample lo, SerialExp.upsample hi))
   $< limitShapeSignal margin periodInt len
-- | Signal that emits the same @(minShape, maxShape)@ pair on every step.
--
-- The pair is computed once, in the start action, from 'shapeLimits' and
-- converted with 'Expr.fromIntegral'. The start action returns it together
-- with a unit state. Each step hands that value back unchanged, and the
-- stop action does nothing.
limitShapeSignal ::
   (Ip.C nodesStep, Ip.C nodesLeap) =>
   (Marshal.C t, MultiValue.NativeFloating t tr) =>
   (i ~ Int) =>
   Exp (Ip.Margin (nodesLeap (nodesStep value))) ->
   Exp i ->
   Exp i ->
   Sig.T (MultiValue.T t, MultiValue.T t)
limitShapeSignal margin periodInt len =
   SigPriv.Cons
      -- step: return the precomputed limits, keep the unit state
      (\minMax -> noLocalPtr $ \() -> return (minMax, ()))
      -- start: compute the limits once
      (do
         limits <-
            Expr.bundle
               (mapPair (Expr.fromIntegral, Expr.fromIntegral) $
                shapeLimits margin periodInt len)
         return (limits, ()))
      -- stop: nothing to release
      (const $ return ())
-- | Admissible shape range for a vector of length @len@. The range is
-- @(leftMargin, len - rightMargin)@, with both margins taken from
-- 'shapeMargin'.
shapeLimits ::
   (Ip.C nodesLeap, Ip.C nodesStep, Exp Int ~ t) =>
   Exp (Ip.Margin (nodesLeap (nodesStep value))) ->
   t -> t -> (t, t)
shapeLimits margin periodInt len =
   let (lower, rightMargin) = shapeMargin margin periodInt
   in (lower, len - rightMargin)
-- | Left and right shape margins for a combined interpolation margin.
--
-- * left = marginOffset + periodInt
-- * right = marginNumber - left
shapeMargin ::
   (Ip.C nodesLeap, Ip.C nodesStep, Exp Int ~ i) =>
   Exp (Ip.Margin (nodesLeap (nodesStep value))) ->
   i -> (i, i)
shapeMargin margin periodInt = (left, right)
   where
      (count, offset) =
         Expr.unzip $
         Expr.lift1 (uncurry MultiValue.zip . Ip.unzipMargin) margin
      left = offset + periodInt
      right = count - left
-- | Host-side counterpart of 'shapeLimits'. It takes a plain 'Ip.Margin'
-- and returns the limits converted to the floating type. It is unused
-- here.
_shapeLimits ::
   (Ip.C nodesLeap, Ip.C nodesStep) =>
   (MultiValue.NativeFloating t tr) =>
   (MultiValue.Additive t) =>
   Ip.Margin (nodesLeap (nodesStep value)) ->
   Exp Word -> Exp t -> (Exp t, Exp t)
_shapeLimits margin periodInt len = (lower, upper)
   where
      (left, right) = _shapeMargin margin periodInt
      lower = Expr.fromIntegral left
      upper = len - Expr.fromIntegral right
-- | Host-side counterpart of 'shapeMargin' for a plain 'Ip.Margin' and any
-- ring type.
--
-- * left = marginOffset + periodInt
-- * right = marginNumber - left
_shapeMargin ::
   (Ip.C nodesLeap, Ip.C nodesStep, Ring.C i) =>
   Ip.Margin (nodesLeap (nodesStep value)) ->
   i -> (i, i)
_shapeMargin margin periodInt = (left, count - left)
   where
      left = fromIntegral (Ip.marginOffset margin) + periodInt
      count = fromIntegral (Ip.marginNumber margin)
-- | Combined margin of the nested step/leap interpolation as an expression.
--
-- Leap nodes are @periodInt@ elements apart, so the leap margin is scaled
-- by @periodInt@ and added to the step margin:
--
-- * number = stepNumber + leapNumber * periodInt
-- * offset = stepOffset + leapOffset * periodInt
combineMarginParams ::
   (Ip.C nodesStep, Ip.C nodesLeap) =>
   (forall r. Ip.T r nodesLeap a v) ->
   (forall r. Ip.T r nodesStep a v) ->
   Exp Int ->
   Exp (Ip.Margin (nodesLeap (nodesStep v)))
combineMarginParams ipLeap ipStep periodInt =
   Expr.lift2 Ip.zipMargin totalNumber totalOffset
   where
      leap = Ip.toMargin ipLeap
      step = Ip.toMargin ipStep
      totalNumber =
         fromIntegral (Ip.marginNumber step) +
         fromIntegral (Ip.marginNumber leap) * periodInt
      totalOffset =
         fromIntegral (Ip.marginOffset step) +
         fromIntegral (Ip.marginOffset leap) * periodInt
-- | Host-side counterpart of 'combineMarginParams'. It combines two plain
-- margins: each leap quantity is scaled by @periodInt@ and added to the
-- matching step quantity. It is unused here.
_combineMargins ::
   Ip.Margin (nodesLeap value) ->
   Ip.Margin (nodesStep value) ->
   Int ->
   Ip.Margin (nodesLeap (nodesStep value))
_combineMargins leap step period =
   Ip.Margin {
      Ip.marginOffset =
         Ip.marginOffset step + Ip.marginOffset leap * period,
      Ip.marginNumber =
         Ip.marginNumber step + Ip.marginNumber leap * period
   }
-- | 'zigZag' preceded by a linear lead-in.
--
-- Assume positive increments, @prefix@ and @loop@. Negative phases pass
-- through 'zigZag' unchanged, so the output first rises from 0 to
-- @prefix@. After that it bounces between @prefix@ and @prefix + loop@.
zigZagLong ::
   (Marshal.C a) =>
   (MultiValue.Select a, MultiValue.Comparison a, MultiValue.Fraction a) =>
   (MultiValue.Field a, MultiValue.RationalConstant a) =>
   Exp a -> Exp a -> Causal.MV a a
zigZagLong prefix loop =
   zigZagLongGen (\x -> Causal.fromSignal (Sig.constant x)) zigZag prefix loop
-- | Serial-vector counterpart of 'zigZagLong', built on 'zigZagPacked'.
zigZagLongPacked ::
   (Marshal.Vector n a) =>
   (MultiVector.Field a, MultiVector.Fraction a) =>
   (MultiVector.RationalConstant a) =>
   (MultiVector.Select a, MultiVector.Comparison a) =>
   Exp a -> Exp a -> Causal.T (Serial.Value n a) (Serial.Value n a)
zigZagLongPacked prefix loop =
   zigZagLongGen
      (\x -> Causal.fromSignal (SigPS.constant x)) zigZagPacked prefix loop
-- | Generic form of 'zigZagLong'.
--
-- * The input increments are divided by @loop@.
-- * The result is fed into @zz@, started at @-prefix/loop@.
-- * The output @x@ is mapped back to @x * loop + prefix@.
--
-- @constant@ lifts an expression to a constant process.
zigZagLongGen ::
   (MultiValue.RationalConstant a, MultiValue.Field a) =>
   (A.RationalConstant al, A.Field al) =>
   (Exp a -> Causal.T al al) ->
   (Exp a -> Causal.T al al) ->
   Exp a -> Exp a -> Causal.T al al
zigZagLongGen constant zz prefix loop = denormalize <<< normalize
   where
      start = negate (prefix/loop)
      normalize = id / constant loop
      denormalize = zz start * constant loop + constant prefix
-- | Triangle ("zig-zag") oscillator driven by per-sample phase increments.
--
-- * The phase starts at @start@ and is emitted before the current
--   increment is added.
-- * After adding, a positive phase is reduced with 'wrap'. Assuming
--   'Expr.fraction' yields values in [0,1), this lands in [0,2).
-- * Each emitted phase @x@ is mapped to @1 - |1 - x|@. Wrapped phases
--   therefore fold into a triangle between 0 and 1, and negative phases
--   pass through unchanged.
zigZag ::
   (Marshal.C a) =>
   (MultiValue.Select a, MultiValue.Comparison a, MultiValue.Fraction a) =>
   (MultiValue.Field a, MultiValue.RationalConstant a) =>
   Exp a -> Causal.MV a a
zigZag start =
   Causal.map (\x -> 1 - abs (1-x))
   <<<
   Causal.mapAccum
      (\delta phase -> (phase, wrap Expr.select (0<*) (phase+delta)))
      start
-- | Serial-vector counterpart of 'zigZag'.
--
-- 'SerialExp.cumulate' combines the carried scalar phase @t0@ with the
-- vector of increments @d@. It returns the next carry @t1@ (new state)
-- and the per-lane values @cum@ (emitted). Presumably these are the
-- running sums; confirm against 'SerialExp.cumulate'. Lanes greater than
-- zero are reduced with 'wrap' and then folded into a triangle like in
-- 'zigZag'.
--
-- NOTE(review): unlike 'zigZag', the carried state @t1@ is stored without
-- 'wrap'. It therefore grows with the accumulated increments, which may
-- cost floating-point precision over long runs. Confirm whether this is
-- intended.
zigZagPacked ::
   (TypeNum.Positive n) =>
   (Marshal.C a) =>
   (MultiVector.Field a, MultiVector.Fraction a) =>
   (MultiVector.RationalConstant a) =>
   (MultiVector.Select a, MultiVector.Comparison a) =>
   Exp a -> Causal.T (Serial.Value n a) (Serial.Value n a)
zigZagPacked start =
   Causal.map (\x -> 1 - abs (1-x))
   <<<
   Causal.mapAccum
      (\d t0 ->
         let (t1,cum) = SerialExp.cumulate t0 d
         in (wrap SerialExp.select (SerialExp.cmp LLVM.CmpLT zero) cum, t1))
      start
-- | Conditionally reduce a phase modulo 2.
--
-- If @positive a@ holds, the result is @2 * fraction (a/2)@. Otherwise
-- @a@ is returned unchanged. The selection primitive is passed in so that
-- both scalar ('Expr.select') and serial ('SerialExp.select') callers can
-- share this code.
wrap ::
   (MultiValue.Field a, MultiValue.Fraction a, MultiValue.RationalConstant a) =>
   (Exp b -> Exp a -> Exp a -> Exp a) ->
   (Exp a -> Exp b) ->
   Exp a -> Exp a
wrap select positive a =
   let reduced = 2 * Expr.fraction (a/2)
   in select (positive a) reduced a