module Synthesizer.LLVM.Interpolation (
C(margin),
loadNodes,
indexNodes,
Margin(..),
toMargin,
T,
Nodes02(..),
linear,
linearVector,
Nodes13(..),
cubic,
cubicVector,
) where
import qualified Synthesizer.LLVM.Simple.Value as Value
import qualified Synthesizer.LLVM.Frame.SerialVector as Serial
import qualified Synthesizer.Interpolation.Core as Interpolation
import qualified LLVM.Extra.Scalar as Scalar
import qualified LLVM.Extra.Arithmetic as A
import qualified LLVM.Extra.Class as Class
import qualified LLVM.Extra.Memory as Memory
import qualified LLVM.Util.Loop as Loop
import qualified LLVM.Core as LLVM
import LLVM.Core (CodeGenFunction, Value, )
import Foreign.Ptr (Ptr, )
import Data.Word (Word32, )
import qualified Type.Data.Num.Decimal as TypeNum
import qualified Control.Monad.Trans.State as MS
import Control.Applicative (Applicative, liftA2, pure, (<*>), )
import Data.Traversable (Traversable, traverse, sequenceA, foldMapDefault, )
import Data.Foldable (Foldable, foldMap, )
{- |
Node containers that interpolation functions operate on.
'margin' reports how many nodes a container holds and its node offset.
The element type in @nodes a@ is irrelevant for the result.
-}
class (Applicative nodes, Traversable nodes) => C nodes where
   margin :: Margin (nodes a)

{- |
Size description of a node container.
The type parameter is phantom: it only ties a margin to a container type.
-}
data Margin nodes =
   Margin {
      -- | total number of nodes (2 for 'Nodes02', 4 for 'Nodes13')
      marginNumber :: Int,
      -- | 0 for 'Nodes02', 1 for 'Nodes13';
      -- presumably the number of nodes before the interpolation point (confirm)
      marginOffset :: Int
   }
   deriving (Show, Eq)

{- |
An interpolation function.
It receives the interpolation parameter of type @a@ and the node values,
and generates LLVM code computing a value of the node type @v@.
-}
type T r nodes a v = a -> nodes v -> CodeGenFunction r v
{- |
Obtain the 'Margin' of the node container an interpolation function works on.
The function argument is never evaluated; it only selects the container type.
-}
toMargin ::
   (C nodes) =>
   (forall r. T r nodes a v) ->
   Margin (nodes v)
toMargin = \ _interpolation -> margin
{- |
Two interpolation nodes, as consumed by 'linear' and 'linearVector'.
-}
data Nodes02 a =
   Nodes02 {nodes02_0, nodes02_1 :: a}

-- | Two nodes with offset zero.
instance C Nodes02 where
   margin = Margin {marginOffset = 0, marginNumber = 2}

instance Functor Nodes02 where
   fmap g (Nodes02 a b) = Nodes02 (g a) (g b)

-- | Field-wise (zip-like) application; 'pure' fills both nodes.
instance Applicative Nodes02 where
   pure a = Nodes02 a a
   Nodes02 g h <*> Nodes02 a b = Nodes02 (g a) (h b)

instance Foldable Nodes02 where
   foldMap = foldMapDefault

-- | Effects run in field order: 'nodes02_0' first, then 'nodes02_1'.
instance Traversable Nodes02 where
   traverse g (Nodes02 a b) = fmap Nodes02 (g a) <*> g b
-- | A container of serial values has the serial size of its elements.
instance (Serial.Sized v) => Serial.Sized (Nodes02 v) where
   type Size (Nodes02 v) = Serial.Size v

-- | Serial reading is delegated field-wise via the 'Traversable' instance.
instance (Serial.Read v) => Serial.Read (Nodes02 v) where
   type Element (Nodes02 v) = Nodes02 (Serial.Element v)
   type ReadIt (Nodes02 v) = Nodes02 (Serial.ReadIt v)
   readStart = Serial.readStartTraversable
   readNext = Serial.readNextTraversable
   extract = Serial.extractTraversable

-- | Serial writing is delegated field-wise via the 'Traversable' instance.
instance (Serial.C v) => Serial.C (Nodes02 v) where
   type WriteIt (Nodes02 v) = Nodes02 (Serial.WriteIt v)
   writeStart = Serial.writeStartTraversable
   writeNext = Serial.writeNextTraversable
   writeStop = Serial.writeStopTraversable
   insert = Serial.insertTraversable

-- | Undefined nodes are built with the 'Applicative' 'pure'.
instance (Class.Undefined v) => Class.Undefined (Nodes02 v) where
   undefTuple = Class.undefTuplePointed

-- | Phi nodes are created and extended per field.
instance (Loop.Phi v) => Loop.Phi (Nodes02 v) where
   addPhis = Class.addPhisFoldable
   phis = Class.phisTraversable
-- | LLVM structure with two fields of the same type.
type Struct02 a = LLVM.Struct (a, (a, ()))

{- |
Memory layout of 'Nodes02':
'nodes02_0' is stored in struct element 0, 'nodes02_1' in element 1.
-}
memory02 ::
   (Memory.C l) =>
   Memory.Record r (Struct02 (Memory.Struct l)) (Nodes02 l)
memory02 =
   pure Nodes02
      <*> Memory.element nodes02_0 TypeNum.d0
      <*> Memory.element nodes02_1 TypeNum.d1

-- | All memory operations go through the record description 'memory02'.
instance (Memory.C l) => Memory.C (Nodes02 l) where
   type Struct (Nodes02 l) = Struct02 (Memory.Struct l)
   compose = Memory.composeRecord memory02
   decompose = Memory.decomposeRecord memory02
   store = Memory.storeRecord memory02
   load = Memory.loadRecord memory02
{- |
Linear interpolation between two scalar nodes.
Delegates to 'Interpolation.linear', passing the two nodes first and the
interpolation parameter last. 'Scalar.unliftM3' presumably adapts the plain
scalar type to the module-based generic function (confirm).
-}
linear ::
   (A.PseudoRing a, A.IntegerConstant a) =>
   T r Nodes02 a a
linear r nodes =
   case nodes of
      Nodes02 a b ->
         Scalar.unliftM3 (Value.unlift3 Interpolation.linear) a b r
{- |
Linear interpolation between two nodes of a module type @v@.
The interpolation parameter is of the scalar type of @v@.
Delegates to 'Interpolation.linear' (nodes first, parameter last).
-}
linearVector ::
   (A.PseudoModule v, A.Scalar v ~ a, A.IntegerConstant a) =>
   T r Nodes02 a v
linearVector r nodes =
   case nodes of
      Nodes02 a b -> Value.unlift3 Interpolation.linear a b r
{- |
Four interpolation nodes, as consumed by 'cubic' and 'cubicVector'.
-}
data Nodes13 a =
   Nodes13 {nodes13_0, nodes13_1, nodes13_2, nodes13_3 :: a}

-- | Four nodes with offset one.
instance C Nodes13 where
   margin = Margin {marginOffset = 1, marginNumber = 4}

instance Functor Nodes13 where
   fmap g (Nodes13 a b c d) = Nodes13 (g a) (g b) (g c) (g d)

-- | Field-wise (zip-like) application; 'pure' fills all four nodes.
instance Applicative Nodes13 where
   pure a = Nodes13 a a a a
   Nodes13 g h i j <*> Nodes13 a b c d =
      Nodes13 (g a) (h b) (i c) (j d)

instance Foldable Nodes13 where
   foldMap = foldMapDefault

-- | Effects run in field order, from 'nodes13_0' to 'nodes13_3'.
instance Traversable Nodes13 where
   traverse g (Nodes13 a b c d) =
      fmap Nodes13 (g a) <*> g b <*> g c <*> g d
-- | A container of serial values has the serial size of its elements.
instance (Serial.Sized v) => Serial.Sized (Nodes13 v) where
   type Size (Nodes13 v) = Serial.Size v

-- | Serial reading is delegated field-wise via the 'Traversable' instance.
instance (Serial.Read v) => Serial.Read (Nodes13 v) where
   type Element (Nodes13 v) = Nodes13 (Serial.Element v)
   type ReadIt (Nodes13 v) = Nodes13 (Serial.ReadIt v)
   readStart = Serial.readStartTraversable
   readNext = Serial.readNextTraversable
   extract = Serial.extractTraversable

-- | Serial writing is delegated field-wise via the 'Traversable' instance.
instance (Serial.C v) => Serial.C (Nodes13 v) where
   type WriteIt (Nodes13 v) = Nodes13 (Serial.WriteIt v)
   writeStart = Serial.writeStartTraversable
   writeNext = Serial.writeNextTraversable
   writeStop = Serial.writeStopTraversable
   insert = Serial.insertTraversable

-- | Undefined nodes are built with the 'Applicative' 'pure'.
instance (Class.Undefined v) => Class.Undefined (Nodes13 v) where
   undefTuple = Class.undefTuplePointed

-- | Phi nodes are created and extended per field.
instance (Loop.Phi v) => Loop.Phi (Nodes13 v) where
   addPhis = Class.addPhisFoldable
   phis = Class.phisTraversable
-- | LLVM structure with four fields of the same type.
type Struct13 a = LLVM.Struct (a, (a, (a, (a, ()))))

{- |
Memory layout of 'Nodes13':
field @nodes13_k@ is stored in struct element @k@, for k = 0..3.
-}
memory13 ::
   (Memory.C l) =>
   Memory.Record r (Struct13 (Memory.Struct l)) (Nodes13 l)
memory13 =
   fmap Nodes13 (Memory.element nodes13_0 TypeNum.d0)
      <*> Memory.element nodes13_1 TypeNum.d1
      <*> Memory.element nodes13_2 TypeNum.d2
      <*> Memory.element nodes13_3 TypeNum.d3

-- | All memory operations go through the record description 'memory13'.
instance (Memory.C l) => Memory.C (Nodes13 l) where
   type Struct (Nodes13 l) = Struct13 (Memory.Struct l)
   compose = Memory.composeRecord memory13
   decompose = Memory.decomposeRecord memory13
   store = Memory.storeRecord memory13
   load = Memory.loadRecord memory13
{- |
Cubic interpolation over four scalar nodes.
Delegates to 'Interpolation.cubic', passing the four nodes first and the
interpolation parameter last. 'Scalar.unliftM5' presumably adapts the plain
scalar type to the module-based generic function (confirm).
-}
cubic ::
   (A.Field a, A.RationalConstant a) =>
   T r Nodes13 a a
cubic r nodes =
   case nodes of
      Nodes13 a b c d ->
         Scalar.unliftM5 (Value.unlift5 Interpolation.cubic) a b c d r
{- |
Cubic interpolation over four nodes of a module type @v@.
The interpolation parameter is of the scalar type of @v@.
Delegates to 'Interpolation.cubic' (nodes first, parameter last).
-}
cubicVector ::
   (A.PseudoModule v, A.Scalar v ~ a, A.Field a, A.RationalConstant a) =>
   T r Nodes13 a v
cubicVector r nodes =
   case nodes of
      Nodes13 a b c d -> Value.unlift5 Interpolation.cubic a b c d r
{- |
Load all nodes of a container from memory.

Beginning at @startPtr@, each node is read with @loadNode@. The pointer is
then advanced by @step@ elements before the next node is read. Fields are
filled in the order of the container's 'Traversable' instance. The pointer
advanced past the last node is discarded.
-}
loadNodes ::
   (C nodes) =>
   (Value (Ptr am) -> CodeGenFunction r a) ->
   Value Word32 ->
   Value (Ptr am) -> CodeGenFunction r (nodes a)
loadNodes loadNode step startPtr =
   MS.evalStateT (sequenceA (pure fetch)) startPtr
   where
      -- one load-and-advance action, repeated for every field
      fetch = loadNext loadNode step
{- |
State step for 'loadNodes'.
It loads the node at the current pointer, then advances the pointer by
@step@ elements using 'LLVM.getElementPtr'.
-}
loadNext ::
   (Value (Ptr am) -> CodeGenFunction r a) ->
   Value Word32 ->
   MS.StateT (Value (Ptr am)) (CodeGenFunction r) a
loadNext loadNode step =
   MS.StateT $ \ptr -> do
      node <- loadNode ptr
      nextPtr <- LLVM.getElementPtr ptr (step, ())
      return (node, nextPtr)
{- |
Fetch all nodes of a container by index.

Beginning at @startIndex@, each node is obtained with @indexNode@. The index
is then incremented by @step@ before the next node is obtained. Fields are
filled in the order of the container's 'Traversable' instance. The index
after the last node is discarded.
-}
indexNodes ::
   (C nodes) =>
   (Value Word32 -> CodeGenFunction r v) ->
   Value Word32 ->
   Value Word32 -> CodeGenFunction r (nodes v)
indexNodes indexNode step startIndex =
   MS.evalStateT (sequenceA (pure fetch)) startIndex
   where
      -- one fetch-and-increment action, repeated for every field
      fetch = indexNext indexNode step
{- |
State step for 'indexNodes'.
It fetches the node at the current index, then adds @step@ to the index.
-}
indexNext ::
   (Value Word32 -> CodeGenFunction r v) ->
   Value Word32 ->
   MS.StateT (Value Word32) (CodeGenFunction r) v
indexNext indexNode step =
   MS.StateT $ \i -> do
      node <- indexNode i
      nextIndex <- A.add i step
      return (node, nextIndex)