{-# LANGUAGE Rank2Types #-}
{-# LANGUAGE TypeFamilies #-}
module Synthesizer.LLVM.Interpolation (
   C(margin),
   loadNodes,
   indexNodes,

   Margin(..),
   toMargin,

   T,

   Nodes02(..),
   linear,
   linearVector,

   Nodes13(..),
   cubic,
   cubicVector,
   ) where

import qualified Synthesizer.LLVM.Simple.Value as Value

import qualified Synthesizer.LLVM.Frame.SerialVector as Serial
import qualified Synthesizer.Interpolation.Core as Interpolation

import qualified LLVM.Extra.Scalar as Scalar
import qualified LLVM.Extra.Arithmetic as A
import qualified LLVM.Extra.Class as Class
import qualified LLVM.Extra.Memory as Memory
import qualified LLVM.Util.Loop as Loop
import qualified LLVM.Core as LLVM

import LLVM.Core (CodeGenFunction, Value, )

import Foreign.Ptr (Ptr, )
import Data.Word (Word32, )

import qualified Type.Data.Num.Decimal as TypeNum

import qualified Control.Monad.Trans.State as MS
import Control.Applicative (Applicative, liftA2, pure, (<*>), )
import Data.Traversable (Traversable, traverse, sequenceA, foldMapDefault, )
import Data.Foldable (Foldable, foldMap, )


{- |
Class of node containers used by the interpolation functions.
Every instance is an 'Applicative' and 'Traversable' container
and provides a 'Margin' value for its node type.
-}
class (Applicative nodes, Traversable nodes) => C nodes where
   -- | Margin of the node type; it does not depend on the element type @a@.
   margin :: Margin (nodes a)

{- |
Node requirements of an interpolation type.
The type parameter @nodes@ is a phantom:
it only ties the two numbers to a node type.

In the instances of this module
'marginNumber' is the total number of nodes
and 'marginOffset' is the number of nodes before index 0.
-}
data Margin nodes =
   Margin {
      marginNumber :: Int,
      marginOffset :: Int
   }
   deriving (Show, Eq)


{- |
Type of an interpolation function:
it takes a parameter of type @a@ and a container @nodes@ of values @v@
and yields a value @v@ inside 'CodeGenFunction'.
-}
type T r nodes a v = a -> nodes v -> CodeGenFunction r v


{- |
Determine the margin that belongs to the node type
of a given interpolation function.
The function argument is never evaluated,
it only selects the 'C' instance by its type.
-}
toMargin ::
   (C nodes) =>
   (forall r. T r nodes a v) ->
   Margin (nodes v)
toMargin _interpolation = margin


{- |
Container for two nodes:
no node lies before index 0, both nodes start from index 0.
-}
data Nodes02 a =
   Nodes02 {
      nodes02_0 :: a,
      nodes02_1 :: a
   }

{- | Two nodes in total, none of them before index 0. -}
instance C Nodes02 where
   margin =
      Margin {
         marginOffset = 0,
         marginNumber = 2
      }


{- | Applies the function to each of the two nodes. -}
instance Functor Nodes02 where
   fmap g nodes =
      case nodes of
         Nodes02 a b -> Nodes02 (g a) (g b)

{- |
Position-wise instance:
'pure' replicates a value to both nodes
and '<*>' applies the function at each position
to the argument at the same position.
-}
instance Applicative Nodes02 where
   pure a = Nodes02 a a
   Nodes02 g h <*> Nodes02 a b = Nodes02 (g a) (h b)

{- | Folding is derived from the 'Traversable' instance. -}
instance Foldable Nodes02 where
   foldMap f nodes = foldMapDefault f nodes

{- | Visits the nodes from first to second and rebuilds the container. -}
instance Traversable Nodes02 where
   traverse visit (Nodes02 a b) =
      pure Nodes02 <*> visit a <*> visit b


{- A node container has the same serial size as its element type. -}
instance (Serial.Sized value) => Serial.Sized (Nodes02 value) where
   type Size (Nodes02 value) = Serial.Size value

{-
Reading from a container of serial vectors.
The associated types wrap the element and reader types of @v@ in 'Nodes02'
and all methods delegate to the generic @...Traversable@ helpers
of the Serial module.
-}
instance (Serial.Read v) => Serial.Read (Nodes02 v) where
   type Element (Nodes02 v) = Nodes02 (Serial.Element v)
   type ReadIt (Nodes02 v) = Nodes02 (Serial.ReadIt v)

   extract = Serial.extractTraversable

   readStart = Serial.readStartTraversable
   readNext = Serial.readNextTraversable

{-
Writing to a container of serial vectors.
The writer type wraps the writer type of @v@ in 'Nodes02'
and all methods delegate to the generic @...Traversable@ helpers
of the Serial module.
-}
instance (Serial.C v) => Serial.C (Nodes02 v) where
   type WriteIt (Nodes02 v) = Nodes02 (Serial.WriteIt v)

   insert = Serial.insertTraversable

   writeStart = Serial.writeStartTraversable
   writeNext = Serial.writeNextTraversable
   writeStop = Serial.writeStopTraversable


{-
The undefined value is obtained from the generic helper
'Class.undefTuplePointed'.
-}
instance (Class.Undefined a) => Class.Undefined (Nodes02 a) where
   undefTuple = Class.undefTuplePointed

{-
Phi node handling is delegated to the generic helpers
for traversable and foldable containers.
-}
instance (Loop.Phi a) => Loop.Phi (Nodes02 a) where
   phis = Class.phisTraversable
   addPhis = Class.addPhisFoldable


{-
LLVM struct with two fields of the same type @a@.
It serves as the 'Memory.Struct' of 'Nodes02'.
-}
type Struct02 a = LLVM.Struct (a, (a, ()))

{-
Record description that relates the fields of 'Nodes02'
to the struct elements at positions 0 and 1.
-}
memory02 ::
   (Memory.C l) =>
   Memory.Record r (Struct02 (Memory.Struct l)) (Nodes02 l)
memory02 =
   pure Nodes02
      <*> Memory.element nodes02_0 TypeNum.d0
      <*> Memory.element nodes02_1 TypeNum.d1

{-
Memory representation of 'Nodes02' as a 'Struct02'.
All methods are the generic record functions applied to 'memory02'.
-}
instance (Memory.C l) => Memory.C (Nodes02 l) where
   type Struct (Nodes02 l) = Struct02 (Memory.Struct l)
   load = Memory.loadRecord memory02
   store = Memory.storeRecord memory02
   decompose = Memory.decomposeRecord memory02
   compose = Memory.composeRecord memory02


{- |
Interpolation over two scalar nodes.
The computation is 'Interpolation.linear' lifted to LLVM values,
called with the two nodes first and the parameter last.
-}
linear ::
   (A.PseudoRing a, A.IntegerConstant a) =>
   T r Nodes02 a a
linear frac nodes =
   case nodes of
      Nodes02 y0 y1 ->
         Scalar.unliftM3 (Value.unlift3 Interpolation.linear) y0 y1 frac

{- |
Variant of 'linear' where the nodes are of a module type @v@
and the parameter is of the corresponding scalar type.
It calls the lifted 'Interpolation.linear'
with the two nodes first and the parameter last.
-}
linearVector ::
   (A.PseudoModule v, A.Scalar v ~ a, A.IntegerConstant a) =>
   T r Nodes02 a v
linearVector frac nodes =
   case nodes of
      Nodes02 y0 y1 ->
         Value.unlift3 Interpolation.linear y0 y1 frac




{- |
Container for four nodes:
one node lies before index 0, the other three start from index 0.
-}
data Nodes13 a =
   Nodes13 {
      nodes13_0 :: a,
      nodes13_1 :: a,
      nodes13_2 :: a,
      nodes13_3 :: a
   }

{- | Four nodes in total, one of them before index 0. -}
instance C Nodes13 where
   margin =
      Margin {
         marginOffset = 1,
         marginNumber = 4
      }

{- | Applies the function to each of the four nodes. -}
instance Functor Nodes13 where
   fmap g nodes =
      case nodes of
         Nodes13 a b c d -> Nodes13 (g a) (g b) (g c) (g d)

{- |
Position-wise instance:
'pure' replicates a value to all four nodes
and '<*>' applies the function at each position
to the argument at the same position.
-}
instance Applicative Nodes13 where
   pure a = Nodes13 a a a a
   Nodes13 g0 g1 g2 g3 <*> Nodes13 a b c d =
      Nodes13 (g0 a) (g1 b) (g2 c) (g3 d)

{- | Folding is derived from the 'Traversable' instance. -}
instance Foldable Nodes13 where
   foldMap f nodes = foldMapDefault f nodes

{- | Visits the nodes in field order and rebuilds the container. -}
instance Traversable Nodes13 where
   traverse visit (Nodes13 a b c d) =
      liftA2 Nodes13 (visit a) (visit b) <*> visit c <*> visit d


{- A node container has the same serial size as its element type. -}
instance (Serial.Sized value) => Serial.Sized (Nodes13 value) where
   type Size (Nodes13 value) = Serial.Size value

{-
Reading from a container of serial vectors.
The associated types wrap the element and reader types of @v@ in 'Nodes13'
and all methods delegate to the generic @...Traversable@ helpers
of the Serial module.
-}
instance (Serial.Read v) => Serial.Read (Nodes13 v) where
   type Element (Nodes13 v) = Nodes13 (Serial.Element v)
   type ReadIt (Nodes13 v) = Nodes13 (Serial.ReadIt v)

   extract = Serial.extractTraversable

   readStart = Serial.readStartTraversable
   readNext = Serial.readNextTraversable

{-
Writing to a container of serial vectors.
The writer type wraps the writer type of @v@ in 'Nodes13'
and all methods delegate to the generic @...Traversable@ helpers
of the Serial module.
-}
instance (Serial.C v) => Serial.C (Nodes13 v) where
   type WriteIt (Nodes13 v) = Nodes13 (Serial.WriteIt v)

   insert = Serial.insertTraversable

   writeStart = Serial.writeStartTraversable
   writeNext = Serial.writeNextTraversable
   writeStop = Serial.writeStopTraversable


{-
The undefined value is obtained from the generic helper
'Class.undefTuplePointed'.
-}
instance (Class.Undefined a) => Class.Undefined (Nodes13 a) where
   undefTuple = Class.undefTuplePointed

{-
Phi node handling is delegated to the generic helpers
for traversable and foldable containers.
-}
instance (Loop.Phi a) => Loop.Phi (Nodes13 a) where
   phis = Class.phisTraversable
   addPhis = Class.addPhisFoldable


{-
LLVM struct with four fields of the same type @a@.
It serves as the 'Memory.Struct' of 'Nodes13'.
-}
type Struct13 a = LLVM.Struct (a, (a, (a, (a, ()))))

{-
Record description that relates the fields of 'Nodes13'
to the struct elements at positions 0 to 3.
-}
memory13 ::
   (Memory.C l) =>
   Memory.Record r (Struct13 (Memory.Struct l)) (Nodes13 l)
memory13 =
   liftA2 Nodes13
      (Memory.element nodes13_0 TypeNum.d0)
      (Memory.element nodes13_1 TypeNum.d1)
      <*> Memory.element nodes13_2 TypeNum.d2
      <*> Memory.element nodes13_3 TypeNum.d3

{-
Memory representation of 'Nodes13' as a 'Struct13'.
All methods are the generic record functions applied to 'memory13'.
-}
instance (Memory.C l) => Memory.C (Nodes13 l) where
   type Struct (Nodes13 l) = Struct13 (Memory.Struct l)
   load = Memory.loadRecord memory13
   store = Memory.storeRecord memory13
   decompose = Memory.decomposeRecord memory13
   compose = Memory.composeRecord memory13


{- |
Interpolation over four scalar nodes.
The computation is 'Interpolation.cubic' lifted to LLVM values,
called with the four nodes in field order and the parameter last.
-}
cubic ::
   (A.Field a, A.RationalConstant a) =>
   T r Nodes13 a a
cubic frac nodes =
   case nodes of
      Nodes13 y0 y1 y2 y3 ->
         Scalar.unliftM5 (Value.unlift5 Interpolation.cubic) y0 y1 y2 y3 frac

{- |
Variant of 'cubic' where the nodes are of a module type @v@
and the parameter is of the corresponding scalar type.
It calls the lifted 'Interpolation.cubic'
with the four nodes in field order and the parameter last.
-}
cubicVector ::
   (A.PseudoModule v, A.Scalar v ~ a, A.Field a, A.RationalConstant a) =>
   T r Nodes13 a v
cubicVector frac nodes =
   case nodes of
      Nodes13 y0 y1 y2 y3 ->
         Value.unlift5 Interpolation.cubic y0 y1 y2 y3 frac


{- |
Fill a node container by loading one node per position.
Starting at the given pointer,
every position is loaded with @loadNode@
and the pointer is advanced by @step@ elements after each load.
-}
loadNodes ::
   (C nodes) =>
   (Value (Ptr am) -> CodeGenFunction r a) ->
   Value Word32 ->
   Value (Ptr am) -> CodeGenFunction r (nodes a)
loadNodes loadNode step ptr =
   MS.evalStateT (sequenceA (pure (loadNext loadNode step))) ptr

{-
Single step of 'loadNodes':
load the node at the current pointer,
then replace the pointer state by the pointer advanced by @step@.
-}
loadNext ::
   (Value (Ptr am) -> CodeGenFunction r a) ->
   Value Word32 ->
   MS.StateT (Value (Ptr am)) (CodeGenFunction r) a
loadNext loadNode step =
   MS.StateT $ \ptr -> do
      node <- loadNode ptr
      nextPtr <- LLVM.getElementPtr ptr (step, ())
      return (node, nextPtr)



{- |
Fill a node container by calling @indexNode@ once per position.
Starting at the given index,
the index is increased by @step@ after each call.
-}
indexNodes ::
   (C nodes) =>
   (Value Word32 -> CodeGenFunction r v) ->
   Value Word32 ->
   Value Word32 -> CodeGenFunction r (nodes v)
indexNodes indexNode step start =
   MS.evalStateT (sequenceA (pure (indexNext indexNode step))) start

{-
Single step of 'indexNodes':
fetch the node for the current index,
then replace the index state by the index plus @step@.
-}
indexNext ::
   (Value Word32 -> CodeGenFunction r v) ->
   Value Word32 ->
   MS.StateT (Value Word32) (CodeGenFunction r) v
indexNext indexNode step =
   MS.StateT $ \i -> do
      node <- indexNode i
      next <- A.add i step
      return (node, next)