{-# LANGUAGE NoImplicitPrelude #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE UndecidableInstances #-}
module Synthesizer.LLVM.Filter.SecondOrderCascade (
   causal,  causalPacked,
   causalP, causalPackedP,
   ParameterValue(..),
   ParameterStruct,
   fixSize, constArray,
   ) where

import qualified Synthesizer.LLVM.Filter.SecondOrder as Filt2
import qualified Synthesizer.Plain.Filter.Recursive.SecondOrder as Filt2Core

import qualified Synthesizer.LLVM.CausalParameterized.Functional as Func
import qualified Synthesizer.LLVM.CausalParameterized.Process as CausalP
import qualified Synthesizer.LLVM.Causal.Process as Causal
import qualified Synthesizer.LLVM.Simple.SignalPrivate as Sig
import Synthesizer.LLVM.CausalParameterized.Functional (($&), (&|&))

import qualified Synthesizer.LLVM.Frame.SerialVector as Serial
import Synthesizer.Causal.Class (($<))

import qualified LLVM.Extra.Arithmetic as A
import qualified LLVM.Extra.Tuple as Tuple
import qualified LLVM.Extra.ScalarOrVector as SoV
import qualified LLVM.Extra.Memory as Memory

import qualified LLVM.Core as LLVM
import LLVM.Core (Value, IsArithmetic, IsSized, CodeGenFunction)

import qualified Type.Data.Num.Decimal as TypeNum
import Type.Data.Num.Decimal.Number ((:*:))
import Type.Base.Proxy (Proxy)

import Data.Word (Word)

import qualified Control.Arrow as Arrow
import Control.Arrow ((>>>), (<<<), (^<<), (&&&), arr)
import Control.Applicative (liftA2)


import NumericPrelude.Numeric
import NumericPrelude.Base


-- | In-memory layout of the parameters of a whole cascade:
-- an LLVM array with one 'Filt2.ParameterStruct' entry per stage.
type ParameterStruct stages t = LLVM.Array stages (Filt2.ParameterStruct t)

-- | Parameters of an @n@-stage cascade, carried as a single LLVM value
-- of the array type 'ParameterStruct'.
newtype ParameterValue n a = ParameterValue
   { parameterValue :: Value (ParameterStruct n a) }
{-
Automatic deriving is not possible, even with GeneralizedNewtypeDeriving,
because of the IsSized constraint;
deriving Functor and friends would also be wrong.
      deriving
         (Tuple.Phi, Tuple.Undefined, Tuple.Zero,
          Functor, App.Applicative, Fold.Foldable, Trav.Traversable)
-}

-- | Phi handling is delegated to the wrapped parameter-array value.
instance (TypeNum.Natural n, IsSized a) =>
      Tuple.Phi (ParameterValue n a) where
   phi bb = fmap ParameterValue . Tuple.phi bb . parameterValue
   addPhi bb x y =
      Tuple.addPhi bb (parameterValue x) (parameterValue y)

-- | Wraps 'Tuple.undef' of the underlying parameter array.
instance (TypeNum.Natural n, IsSized a) =>
      Tuple.Undefined (ParameterValue n a) where
   undef = ParameterValue {parameterValue = Tuple.undef}

-- | Wraps 'Tuple.zero' of the underlying parameter array.
instance (TypeNum.Natural n, IsSized a) =>
      Tuple.Zero (ParameterValue n a) where
   zero = ParameterValue {parameterValue = Tuple.zero}

-- | Stored exactly like the wrapped 'ParameterStruct'; every method is
-- built from the newtype helpers of "LLVM.Extra.Memory".
instance (TypeNum.Natural n, IsSized a,
          TypeNum.Positive (n :*: LLVM.UnknownSize)) =>
      Memory.C (ParameterValue n a) where
   type Struct (ParameterValue n a) = ParameterStruct n a
   decompose = Memory.decomposeNewtype ParameterValue
   compose = Memory.composeNewtype parameterValue
   load = Memory.loadNewtype ParameterValue
   store = Memory.storeNewtype parameterValue

-- | A 'ParameterValue' is treated as one atomic functional argument:
-- its argument representation is just the wrapped functor value itself.
type instance Func.Arguments f (ParameterValue n a) = f (ParameterValue n a)
instance Func.MakeArguments (ParameterValue n a) where
   makeArgs args = args


-- | Supply the type-level stage count @n@ to a process builder as a
-- value-level singleton, so that the builder can inspect @n@
-- (e.g. with 'TypeNum.integralFromSingleton').
withSize ::
   (TypeNum.Natural n) =>
   (TypeNum.Singleton n -> process (ParameterValue n a, x) y) ->
   process (ParameterValue n a, x) y
withSize = ($ TypeNum.singleton)

-- | Return the process unchanged. The only purpose is to pin the
-- type-level stage count @n@ to the given proxy.
fixSize ::
   Proxy n ->
   process (ParameterValue n a, x) y ->
   process (ParameterValue n a, x) y
fixSize _ cascade = cascade

-- | Build a constant LLVM array value from a list of constants.
-- The array length @n@ in the result type is fixed by the proxy argument.
--
-- NOTE(review): the list length is not checked against @n@ here;
-- presumably callers must pass exactly @n@ elements. Confirm against
-- 'LLVM.constArray'.
constArray ::
   (TypeNum.Natural n, IsSized a) =>
   Proxy n -> [LLVM.ConstValue a] ->
   LLVM.Value (LLVM.Array n a)
constArray _ elems = LLVM.value (LLVM.constArray elems)


-- | Deprecated alias of 'causal', restricted to the parameterized
-- process type 'CausalP.T'. Use 'causal' instead.
causalP ::
   (Memory.C v, A.PseudoModule v, A.Scalar v ~ LLVM.Value a,
    IsSized a, IsArithmetic a, SoV.IntegerConstant a, TypeNum.Natural n,
    TypeNum.Positive (n :*: LLVM.UnknownSize)) =>
   CausalP.T p (ParameterValue n a, v) v
causalP = causal

-- | Deprecated alias of 'causalPacked', restricted to the parameterized
-- process type 'CausalP.T'. Use 'causalPacked' instead.
causalPackedP ::
   (Memory.C v, A.PseudoRing v, A.IntegerConstant v, A.PseudoModule v,
    Serial.C v, Serial.Element v ~ LLVM.Value a,
    A.Scalar v ~ LLVM.Value a,
    SoV.IntegerConstant a, LLVM.IsPrimitive a, IsSized a,
    TypeNum.Positive (n :*: LLVM.UnknownSize),
    TypeNum.Natural n) =>
   CausalP.T p (ParameterValue n a, v) v
causalPackedP = causalPacked


-- | Cascade of second-order filter stages, each stage being 'Filt2.causal'.
--
-- The number of stages is the type-level array length @n@ of the
-- 'ParameterValue' input. The @n@ stage copies are combined by
-- 'Causal.replicateControlled' (see 'causalGen').
causal ::
   (Causal.C process,
    Memory.C v, A.PseudoModule v, A.Scalar v ~ LLVM.Value a,
    IsSized a, IsArithmetic a, SoV.IntegerConstant a, TypeNum.Natural n,
    TypeNum.Positive (n :*: LLVM.UnknownSize)) =>
   process (ParameterValue n a, v) v
causal = causalGen Filt2.causal

-- | Like 'causal', but each stage is 'Filt2.causalPacked'.
-- The signal type @v@ must be a 'Serial.C' vector whose elements are
-- @Value a@.
causalPacked ::
   (Causal.C process,
    A.PseudoRing v, A.IntegerConstant v,
    Memory.C v, A.PseudoModule v, A.Scalar v ~ LLVM.Value a,
    Serial.C v, Serial.Element v ~ LLVM.Value a,
    SoV.IntegerConstant a, LLVM.IsPrimitive a, IsSized a,
    TypeNum.Positive (n :*: LLVM.UnknownSize),
    TypeNum.Natural n) =>
   process (ParameterValue n a, v) v
causalPacked = causalGen Filt2.causalPacked

-- | Turn a single second-order stage into a cascade of @n@ stages, where
-- @n@ is the type-level array length of the 'ParameterValue' input.
--
-- For every sample, the incoming parameter array is stored through a
-- pointer obtained from 'Sig.alloca'. Then 'Causal.replicateControlled'
-- runs @n@ copies of 'paramStage'. Each copy reads its own parameters
-- through that pointer, at a running stage index that starts at zero.
-- Only the signal part of the final output is returned; the index is
-- dropped.
causalGen ::
   (Causal.C process, IsSized a, Tuple.Phi v, Tuple.Undefined v,
    TypeNum.Natural n, TypeNum.Positive (n :*: LLVM.UnknownSize)) =>
   process (Filt2Core.Parameter (Value a), v) v ->
   process (ParameterValue n a, v) v
causalGen stage =
   withSize $ \n ->
      let stageCount = TypeNum.integralFromSingleton n
          cascade = Causal.replicateControlled stageCount (paramStage stage)
          -- write the whole parameter array once per sample;
          -- the stage index starts at zero
          storeParams (ptr, (params, x)) = do
             LLVM.store (parameterValue params) ptr
             return (ptr, (A.zero, x))
      in  snd ^<< cascade <<< Causal.map storeParams $< Sig.alloca

-- | Adapt one filter stage so that 'Causal.replicateControlled' can chain it.
--
-- Input: pointer to the stored parameter array, then the current stage
-- index and the signal. The stage's parameters are loaded with
-- 'getStageParameterGEP' at the current index. Output: the incremented
-- index together with the stage's output signal.
paramStage ::
   (Causal.C process, IsSized a,
    TypeNum.Natural n, TypeNum.Positive (n :*: LLVM.UnknownSize)) =>
   process (Filt2Core.Parameter (Value a), v) v ->
   process
      (Value (LLVM.Ptr (ParameterStruct n a)), (Value Word, v)) (Value Word, v)
paramStage stage =
   nextIndex &&& filtered
   where
      ptrIn = arr fst
      indexIn = arr (fst . snd)
      signalIn = arr (snd . snd)
      nextIndex = Causal.map A.inc <<< indexIn
      stageParams =
         Causal.zipWith getStageParameterGEP <<< (ptrIn &&& indexIn)
      filtered = stage <<< (stageParams &&& signalIn)

-- | Functional-style ('Func.withGuidedArgs') formulation of 'paramStage'
-- for 'CausalP.T'. It reads the stage parameters via
-- 'getStageParameterGEP' at the current index and returns the index plus
-- one together with the stage output.
-- NOTE(review): the leading underscore suggests this is currently unused.
_paramStage ::
   (IsSized a,
    TypeNum.Natural n, TypeNum.Positive (n :*: LLVM.UnknownSize)) =>
   CausalP.T p (Filt2Core.Parameter (Value a), v) v ->
   CausalP.T p
      (Value (LLVM.Ptr (ParameterStruct n a)), (Value Word, v)) (Value Word, v)
_paramStage stage =
   Func.withGuidedArgs (Func.atom, (Func.atom, Func.atom)) $ \(ptr, (idx, x)) ->
      let params = Causal.zipWith getStageParameterGEP $& ptr &|& idx
          filtered = stage $& params &|& x
      in  liftA2 (,) (idx + 1) filtered

-- | Alternative to 'causalGen' that unrolls the cascade while the process
-- is being built. It creates @n@ stage copies; copy @k@ reads its
-- parameters with 'getStageParameter' at index @k@. The copies are chained
-- left to right, and the parameter array is forwarded to every copy.
-- NOTE(review): the leading underscore suggests this is currently unused.
_causalGenP ::
   (Causal.C process, IsSized a,
    TypeNum.Natural n, TypeNum.Positive (n :*: LLVM.UnknownSize)) =>
   process (Filt2Core.Parameter (Value a), v) v ->
   process (ParameterValue n a, v) v
_causalGenP stage =
   withSize $ \n ->
      let stageAt k =
             stage <<< Arrow.first (Causal.map (flip getStageParameter k))
          -- keep the parameters (fst) next to the running signal
          chain acc next = (arr fst &&& acc) >>> next
          stageIndices = take (TypeNum.integralFromSingleton n) [0..]
      in  foldl chain (arr snd) (map stageAt stageIndices)


-- | Take the parameters of stage @k@ out of the parameter-array value
-- (with 'LLVM.extractvalue') and decompose them into a
-- 'Filt2Core.Parameter'.
getStageParameter ::
   (IsSized a,
    TypeNum.Natural n, TypeNum.Positive (n :*: LLVM.UnknownSize)) =>
   ParameterValue n a ->
   Word ->
   CodeGenFunction r (Filt2Core.Parameter (Value a))
getStageParameter ps k = do
   stageStruct <- LLVM.extractvalue (parameterValue ps) k
   Filt2.decomposeParameter stageStruct
{-
   Alternative formulation via the Memory.C instance:

   Memory.decompose =<<
   flip LLVM.extractvalue k =<<
   Memory.compose ps
-}

-- | Load the parameters of stage @k@ through a pointer to the parameter
-- array. It computes the element address with 'LLVM.getElementPtr0',
-- loads the stage struct and decomposes it into a 'Filt2Core.Parameter'.
getStageParameterGEP ::
   (IsSized a,
    TypeNum.Natural n, TypeNum.Positive (n :*: LLVM.UnknownSize)) =>
   Value (LLVM.Ptr (ParameterStruct n a)) ->
   Value Word -> CodeGenFunction r (Filt2Core.Parameter (Value a))
getStageParameterGEP ptr k = do
   stagePtr <- LLVM.getElementPtr0 ptr (k, ())
   stageStruct <- LLVM.load stagePtr
   Filt2.decomposeParameter stageStruct


-- The CausalP-specific variants are plain aliases of the generic
-- functions and are kept only for backwards compatibility.
{-# DEPRECATED causalP          "use 'causal' instead" #-}
{-# DEPRECATED causalPackedP    "use 'causalPacked' instead" #-}