{-# LANGUAGE Rank2Types #-}
{-# LANGUAGE TypeFamilies #-}
{- |
Fixed-size containers of interpolation nodes
together with interpolation functions that emit LLVM code.
-}
module Synthesizer.LLVM.Interpolation (
  C(margin),
  loadNodes,
  indexNodes,
  Margin(..),
  toMargin,
  T,
  Nodes02(..),
  linear,
  linearVector,
  Nodes13(..),
  cubic,
  cubicVector,
  ) where

import qualified Synthesizer.LLVM.Simple.Value as Value
import qualified Synthesizer.LLVM.Frame.SerialVector as Serial
import qualified Synthesizer.Interpolation.Core as Interpolation

import qualified LLVM.Extra.Scalar as Scalar
import qualified LLVM.Extra.Arithmetic as A
import qualified LLVM.Extra.Tuple as Tuple
import qualified LLVM.Extra.Storable as Storable
import qualified LLVM.Extra.Memory as Memory

import qualified LLVM.Core as LLVM
import LLVM.Core (CodeGenFunction, Value)

import Foreign.Ptr (Ptr)
import Data.Word (Word)

import qualified Type.Data.Num.Decimal as TypeNum

import qualified Control.Monad.Trans.State as MS
import Control.Applicative (Applicative, liftA2, pure, (<*>))
import Data.Traversable (Traversable, traverse, sequenceA, foldMapDefault)
import Data.Foldable (Foldable, foldMap)


{- |
Class of node containers usable for interpolation.
A container must be 'Applicative' and 'Traversable'
and must provide its 'Margin'.
-}
class (Applicative nodes, Traversable nodes) => C nodes where
  margin :: Margin (nodes a)

{- |
Margin description of a node container.
The type parameter is phantom and only ties the margin to a container type.

For the instances in this module,
'marginNumber' equals the number of nodes in the container
and 'marginOffset' equals the number of nodes located before index 0.
-}
data Margin nodes =
  Margin {
    marginNumber :: Int,
    marginOffset :: Int
  }
  deriving (Show, Eq)

{- |
Shape of an interpolation function:
it takes a parameter of type @a@ and a container of nodes of type @v@
and generates code that yields a value of type @v@.
-}
type T r nodes a v = a -> nodes v -> CodeGenFunction r v

{- |
Obtain the margin that belongs to an interpolation function.
The function itself is never evaluated;
it merely selects the node container type.
-}
toMargin ::
  (C nodes) =>
  (forall r. T r nodes a v) -> Margin (nodes v)
-- Written with an explicit wildcard pattern:
-- the argument has a rank-2 type and thus cannot be passed through 'const'.
toMargin _ = margin


{- |
Zero nodes before index 0 and two nodes starting from index 0.
-}
data Nodes02 a =
  Nodes02 {
    nodes02_0, nodes02_1 :: a
  }

instance C Nodes02 where
  margin = Margin { marginNumber = 2, marginOffset = 0 }

instance Functor Nodes02 where
  fmap g (Nodes02 a b) = Nodes02 (g a) (g b)

-- | Element-wise ("zipping") applicative.
instance Applicative Nodes02 where
  pure a = Nodes02 a a
  Nodes02 g h <*> Nodes02 a b = Nodes02 (g a) (h b)

instance Foldable Nodes02 where
  foldMap = foldMapDefault

-- | Visits the nodes in the order of the constructor fields.
instance Traversable Nodes02 where
  traverse g (Nodes02 a b) = pure Nodes02 <*> g a <*> g b

-- | The container has the same serial size as its elements.
instance (Serial.Sized value) => Serial.Sized (Nodes02 value) where
  type Size (Nodes02 value) = Serial.Size value

-- | All methods delegate to the generic @Traversable@ helpers of "Serial".
instance (Serial.Read v) => Serial.Read (Nodes02 v) where
  type Element (Nodes02 v) = Nodes02 (Serial.Element v)
  type ReadIt (Nodes02 v) = Nodes02 (Serial.ReadIt v)
  extract = Serial.extractTraversable
  readStart = Serial.readStartTraversable
  readNext = Serial.readNextTraversable

-- | All methods delegate to the generic @Traversable@ helpers of "Serial".
instance (Serial.C v) => Serial.C (Nodes02 v) where
  type WriteIt (Nodes02 v) = Nodes02 (Serial.WriteIt v)
  insert = Serial.insertTraversable
  writeStart = Serial.writeStartTraversable
  writeNext = Serial.writeNextTraversable
  writeStop = Serial.writeStopTraversable

instance (Tuple.Undefined a) => Tuple.Undefined (Nodes02 a) where
  undef = Tuple.undefPointed

instance (Tuple.Phi a) => Tuple.Phi (Nodes02 a) where
  phi = Tuple.phiTraversable
  addPhi = Tuple.addPhiFoldable

-- | LLVM struct with two members of the same type.
type Struct02 a = LLVM.Struct (a, (a, ()))

-- | Maps the two fields of 'Nodes02' to struct members 0 and 1.
memory02 ::
  (Memory.C l) =>
  Memory.Record r (Struct02 (Memory.Struct l)) (Nodes02 l)
memory02 =
  pure Nodes02
    <*> Memory.element nodes02_0 TypeNum.d0
    <*> Memory.element nodes02_1 TypeNum.d1

-- | Memory layout is given by 'memory02'.
instance (Memory.C l) => Memory.C (Nodes02 l) where
  type Struct (Nodes02 l) = Struct02 (Memory.Struct l)
  load = Memory.loadRecord memory02
  store = Memory.storeRecord memory02
  decompose = Memory.decomposeRecord memory02
  compose = Memory.composeRecord memory02

{- |
Scalar interpolation between two nodes
by means of 'Interpolation.linear'.
The two nodes are passed first, the parameter last.
-}
linear ::
  (A.PseudoRing a, A.IntegerConstant a) =>
  T r Nodes02 a a
linear t (Nodes02 y0 y1) =
  Scalar.unliftM3 (Value.unlift3 Interpolation.linear) y0 y1 t

{- |
Variant of 'linear' where the nodes are of a module type @v@
and the parameter is of the associated scalar type.
-}
linearVector ::
  (A.PseudoModule v, A.Scalar v ~ a, A.IntegerConstant a) =>
  T r Nodes02 a v
linearVector t (Nodes02 y0 y1) =
  Value.unlift3 Interpolation.linear y0 y1 t


{- |
One node before index 0 and three nodes starting from index 0.
-}
data Nodes13 a =
  Nodes13 {
    nodes13_0, nodes13_1, nodes13_2, nodes13_3 :: a
  }

instance C Nodes13 where
  margin = Margin { marginNumber = 4, marginOffset = 1 }

instance Functor Nodes13 where
  fmap g (Nodes13 a b c d) = Nodes13 (g a) (g b) (g c) (g d)

-- | Element-wise ("zipping") applicative.
instance Applicative Nodes13 where
  pure a = Nodes13 a a a a
  Nodes13 g0 g1 g2 g3 <*> Nodes13 a b c d =
    Nodes13 (g0 a) (g1 b) (g2 c) (g3 d)

instance Foldable Nodes13 where
  foldMap = foldMapDefault

-- | Visits the nodes in the order of the constructor fields.
instance Traversable Nodes13 where
  traverse g (Nodes13 a b c d) =
    liftA2 Nodes13 (g a) (g b) <*> g c <*> g d

-- | The container has the same serial size as its elements.
instance (Serial.Sized value) => Serial.Sized (Nodes13 value) where
  type Size (Nodes13 value) = Serial.Size value

-- | All methods delegate to the generic @Traversable@ helpers of "Serial".
instance (Serial.Read v) => Serial.Read (Nodes13 v) where
  type Element (Nodes13 v) = Nodes13 (Serial.Element v)
  type ReadIt (Nodes13 v) = Nodes13 (Serial.ReadIt v)
  extract = Serial.extractTraversable
  readStart = Serial.readStartTraversable
  readNext = Serial.readNextTraversable

-- | All methods delegate to the generic @Traversable@ helpers of "Serial".
instance (Serial.C v) => Serial.C (Nodes13 v) where
  type WriteIt (Nodes13 v) = Nodes13 (Serial.WriteIt v)
  insert = Serial.insertTraversable
  writeStart = Serial.writeStartTraversable
  writeNext = Serial.writeNextTraversable
  writeStop = Serial.writeStopTraversable

instance (Tuple.Undefined a) => Tuple.Undefined (Nodes13 a) where
  undef = Tuple.undefPointed

instance (Tuple.Phi a) => Tuple.Phi (Nodes13 a) where
  phi = Tuple.phiTraversable
  addPhi = Tuple.addPhiFoldable

-- | LLVM struct with four members of the same type.
type Struct13 a = LLVM.Struct (a, (a, (a, (a, ()))))

-- | Maps the four fields of 'Nodes13' to struct members 0 to 3.
memory13 ::
  (Memory.C l) =>
  Memory.Record r (Struct13 (Memory.Struct l)) (Nodes13 l)
memory13 =
  liftA2 Nodes13
    (Memory.element nodes13_0 TypeNum.d0)
    (Memory.element nodes13_1 TypeNum.d1)
    <*> Memory.element nodes13_2 TypeNum.d2
    <*> Memory.element nodes13_3 TypeNum.d3

-- | Memory layout is given by 'memory13'.
instance (Memory.C l) => Memory.C (Nodes13 l) where
  type Struct (Nodes13 l) = Struct13 (Memory.Struct l)
  load = Memory.loadRecord memory13
  store = Memory.storeRecord memory13
  decompose = Memory.decomposeRecord memory13
  compose = Memory.composeRecord memory13

{- |
Scalar interpolation between four nodes
by means of 'Interpolation.cubic'.
The four nodes are passed first, the parameter last.
-}
cubic ::
  (A.Field a, A.RationalConstant a) =>
  T r Nodes13 a a
cubic t (Nodes13 y0 y1 y2 y3) =
  Scalar.unliftM5 (Value.unlift5 Interpolation.cubic) y0 y1 y2 y3 t

{- |
Variant of 'cubic' where the nodes are of a module type @v@
and the parameter is of the associated scalar type.
-}
cubicVector ::
  (A.PseudoModule v, A.Scalar v ~ a, A.Field a, A.RationalConstant a) =>
  T r Nodes13 a v
cubicVector t (Nodes13 y0 y1 y2 y3) =
  Value.unlift5 Interpolation.cubic y0 y1 y2 y3 t


{- |
Fill a node container by running the given load action
once per node position, in traversal order.
The first load uses the given pointer;
after every load the pointer is advanced by @step@
using 'Storable.advancePtr'.
-}
loadNodes ::
  (C nodes, Storable.C am) =>
  (Value (Ptr am) -> CodeGenFunction r a) ->
  Value Int ->
  Value (Ptr am) -> CodeGenFunction r (nodes a)
loadNodes loadNode step startPtr =
  MS.evalStateT (sequenceA (pure (loadNext loadNode step))) startPtr

{- |
Single step of 'loadNodes':
load from the pointer held in the state,
then replace the state by the advanced pointer.
-}
loadNext ::
  (Storable.C am) =>
  (Value (Ptr am) -> CodeGenFunction r a) ->
  Value Int ->
  MS.StateT (Value (Ptr am)) (CodeGenFunction r) a
loadNext loadNode step =
  MS.StateT $ \ptr -> do
    node <- loadNode ptr
    nextPtr <- Storable.advancePtr step ptr
    return (node, nextPtr)

{- |
Fill a node container by running the given index action
once per node position, in traversal order.
The first access uses the given index;
after every access @step@ is added to the index.
-}
indexNodes ::
  (C nodes) =>
  (Value Word -> CodeGenFunction r v) ->
  Value Word ->
  Value Word -> CodeGenFunction r (nodes v)
indexNodes indexNode step startIndex =
  MS.evalStateT (sequenceA (pure (indexNext indexNode step))) startIndex

{- |
Single step of 'indexNodes':
access the element at the index held in the state,
then replace the state by the index increased by @step@.
-}
indexNext ::
  (Value Word -> CodeGenFunction r v) ->
  Value Word ->
  MS.StateT (Value Word) (CodeGenFunction r) v
indexNext indexNode step =
  MS.StateT $ \i -> do
    node <- indexNode i
    nextIndex <- A.add i step
    return (node, nextIndex)