{-# LANGUAGE NoImplicitPrelude #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE UndecidableInstances #-}
module Synthesizer.LLVM.Filter.SecondOrderCascade (
   causal,  causalPacked,
   causalP, causalPackedP,
   ParameterValue(..),
   ParameterStruct,
   fixSize, constArray,
   ) where

import qualified Synthesizer.LLVM.Filter.SecondOrder as Filt2
import qualified Synthesizer.Plain.Filter.Recursive.SecondOrder as Filt2Core

import qualified Synthesizer.LLVM.CausalParameterized.Functional as Func
import qualified Synthesizer.LLVM.CausalParameterized.Process as CausalP
import qualified Synthesizer.LLVM.Causal.Process as Causal
import qualified Synthesizer.LLVM.Simple.SignalPrivate as Sig
import Synthesizer.LLVM.CausalParameterized.Functional (($&), (&|&), )

import qualified Synthesizer.LLVM.Frame.SerialVector as Serial
import Synthesizer.Causal.Class (($<), )

import qualified LLVM.Extra.Arithmetic as A
import qualified LLVM.Extra.Class as Class
import qualified LLVM.Extra.ScalarOrVector as SoV
import qualified LLVM.Extra.Memory as Memory
import LLVM.Extra.Class (Undefined, undefTuple, )

import qualified LLVM.Core as LLVM
import LLVM.Util.Loop (Phi, phis, addPhis, )
import LLVM.Core (Value, IsArithmetic, IsSized, CodeGenFunction, )

import qualified Type.Data.Num.Decimal as TypeNum
import Type.Data.Num.Decimal.Number ((:*:), )
import Type.Base.Proxy (Proxy, )

import Data.Word (Word32, )

import qualified Control.Arrow as Arrow
import Control.Arrow ((>>>), (<<<), (^<<), (&&&), arr, )
import Control.Applicative (liftA2, )

import Foreign.Ptr (Ptr, )

import NumericPrelude.Numeric
import NumericPrelude.Base


-- | LLVM representation of the parameters of a cascade of @n@ stages:
-- an LLVM array holding one 'Filt2.ParameterStruct' per stage.
type ParameterStruct n a = LLVM.Array n (Filt2.ParameterStruct a)

-- | Parameters of all stages of the cascade, held as a single LLVM value
-- of type 'ParameterStruct'.
-- The stage count @n@ is a type-level natural number.
newtype ParameterValue n a =
   ParameterValue {parameterValue :: Value (ParameterStruct n a)}
{-
The instances below are written by hand: they cannot be derived,
not even with GeneralizedNewtypeDeriving, because of the IsSized constraint.
Derived Functor, Applicative, Foldable and Traversable instances
would also be semantically wrong for this type.
      deriving
         (Phi, Class.Undefined, Class.Zero,
          Functor, App.Applicative, Fold.Foldable, Trav.Traversable)
-}

-- | Phi handling delegates to the wrapped array value.
instance (TypeNum.Natural n, IsSized a) =>
      Phi (ParameterValue n a) where
   phis bb = fmap ParameterValue . phis bb . parameterValue
   addPhis bb lhs rhs =
      addPhis bb (parameterValue lhs) (parameterValue rhs)

-- | The undefined 'ParameterValue' wraps an undefined array value.
instance (TypeNum.Natural n, IsSized a) =>
      Class.Undefined (ParameterValue n a) where
   undefTuple = ParameterValue $ undefTuple

-- | The zero 'ParameterValue' wraps a zero array value.
instance (TypeNum.Natural n, IsSized a) =>
      Class.Zero (ParameterValue n a) where
   zeroTuple = ParameterValue $ Class.zeroTuple

-- | Memory layout of 'ParameterValue': an array of the stored
-- per-stage parameter structures.
instance (TypeNum.Natural n,
          Memory.FirstClass a, Memory.Stored a ~ am, IsSized a, IsSized am,
          TypeNum.Positive (n :*: LLVM.UnknownSize)) =>
      Memory.C (ParameterValue n a) where
   type Struct (ParameterValue n a) = ParameterStruct n (Memory.Stored a)
   -- every method wraps or unwraps the newtype around the array value
   load      = Memory.loadNewtype      ParameterValue
   store     = Memory.storeNewtype     parameterValue
   decompose = Memory.decomposeNewtype ParameterValue
   compose   = Memory.composeNewtype   parameterValue

-- In functional-style code a 'ParameterValue' is a single, atomic argument.
type instance Func.Arguments f (ParameterValue n a) = f (ParameterValue n a)
instance Func.MakeArguments (ParameterValue n a) where
   makeArgs args = args


-- | Pass the singleton of the stage count @n@ to a process builder.
-- @n@ is determined by the type of the 'ParameterValue' input.
withSize ::
   (TypeNum.Natural n) =>
   (TypeNum.Singleton n -> process (ParameterValue n a, x) y) ->
   process (ParameterValue n a, x) y
withSize = ($ TypeNum.singleton)

-- | Identity on processes that fixes the type-level stage count @n@
-- to that of the proxy. The proxy value itself is ignored.
fixSize ::
   Proxy n ->
   process (ParameterValue n a, x) y ->
   process (ParameterValue n a, x) y
fixSize _ p = p

-- | Build a value of type @Array n a@ from a list of constants.
-- The proxy only fixes @n@; its value is ignored.
-- NOTE(review): behaviour when the list length differs from @n@ is
-- whatever 'LLVM.constArray' does - confirm before relying on it.
constArray ::
   (TypeNum.Natural n, IsSized a) =>
   Proxy n -> [LLVM.ConstValue a] ->
   LLVM.Value (LLVM.Array n a)
constArray _ elems = LLVM.value (LLVM.constArray elems)


-- | Deprecated alias of 'causal', specialised to 'CausalP.T'.
causalP ::
   (LLVM.Value a ~ A.Scalar v, A.PseudoModule v,
    Memory.FirstClass a, Memory.Stored a ~ am, IsSized a, IsSized am,
    Memory.C v,
    IsArithmetic a, SoV.IntegerConstant a, TypeNum.Natural n,
    TypeNum.Positive (n :*: LLVM.UnknownSize)) =>
   CausalP.T p (ParameterValue n a, v) v
causalP = causal

-- | Deprecated alias of 'causalPacked', specialised to 'CausalP.T'.
causalPackedP ::
   (LLVM.Value a ~ A.Scalar v, A.PseudoModule v,
    Serial.C v, Serial.Element v ~ LLVM.Value a,
    SoV.IntegerConstant a,
    A.PseudoRing v, A.IntegerConstant v, Memory.C v,
    Memory.FirstClass a, Memory.Stored a ~ am, IsSized a, IsSized am,
    LLVM.IsPrimitive a,
    LLVM.IsPrimitive am,
    TypeNum.Positive (n :*: LLVM.UnknownSize),
    TypeNum.Natural n) =>
   CausalP.T p (ParameterValue n a, v) v
causalPackedP = causalPacked


-- | Cascade of @n@ second-order stages, each one implemented by
-- 'Filt2.causal'.
-- The first input carries the parameters of all stages;
-- the second input is the signal passed through the cascade.
causal ::
   (Causal.C process,
    LLVM.Value a ~ A.Scalar v, A.PseudoModule v,
    Memory.FirstClass a, Memory.Stored a ~ am, IsSized a, IsSized am,
    Memory.C v,
    IsArithmetic a, SoV.IntegerConstant a, TypeNum.Natural n,
    TypeNum.Positive (n :*: LLVM.UnknownSize)) =>
   process (ParameterValue n a, v) v
causal = causalGen Filt2.causal

-- | Like 'causal', but each stage is implemented by 'Filt2.causalPacked'.
-- This is the variant for 'Serial.C' vector signals @v@ whose elements
-- are @LLVM.Value a@.
causalPacked ::
   (Causal.C process,
    LLVM.Value a ~ A.Scalar v, A.PseudoModule v,
    Serial.C v, Serial.Element v ~ LLVM.Value a,
    SoV.IntegerConstant a,
    A.PseudoRing v, A.IntegerConstant v, Memory.C v,
    Memory.FirstClass a, Memory.Stored a ~ am, IsSized a, IsSized am,
    LLVM.IsPrimitive a,
    LLVM.IsPrimitive am,
    TypeNum.Positive (n :*: LLVM.UnknownSize),
    TypeNum.Natural n) =>
   process (ParameterValue n a, v) v
causalPacked = causalGen Filt2.causalPacked

-- | Build the cascade from a single-stage process.
--
-- Each sample goes through these steps:
--
-- * the parameter array is stored into scratch memory obtained
--   from 'Sig.alloca', and the stage index is set to zero;
-- * @n@ copies of 'paramStage', chained by 'Causal.replicateControlled',
--   each read their parameters through that pointer at the current index
--   and increment it;
-- * the final index is dropped and only the signal is returned.
causalGen ::
   (Causal.C process, IsSized a, Phi v, Undefined v,
    TypeNum.Natural n, TypeNum.Positive (n :*: LLVM.UnknownSize)) =>
   process (Filt2Core.Parameter (Value a), v) v ->
   process (ParameterValue n a, v) v
causalGen stage =
   withSize $ \n ->
      let -- write the parameter array to memory and start at stage zero
          storeParameters (ptr, (params, x)) = do
             LLVM.store (parameterValue params) ptr
             return (ptr, (A.zero, x))
          cascade =
             Causal.replicateControlled
                (TypeNum.integralFromSingleton n)
                (paramStage stage)
      in  snd ^<< cascade <<< Causal.map storeParameters $< Sig.alloca

-- | One stage of the cascade, shaped for 'Causal.replicateControlled'.
-- The control input is the pointer to the parameter array.
-- The chained value pairs the current stage index with the signal.
-- The output pairs the incremented index with the signal after running
-- @stage@ on the parameters loaded for the current index.
paramStage ::
   (Causal.C process, IsSized a,
    TypeNum.Natural n, TypeNum.Positive (n :*: LLVM.UnknownSize)) =>
   process (Filt2Core.Parameter (Value a), v) v ->
   process
      (Value (Ptr (ParameterStruct n a)), (Value Word32, v)) (Value Word32, v)
paramStage stage =
   let paramPtr    = arr fst
       counter     = arr (fst . snd)
       signal      = arr (snd . snd)
       nextCounter = Causal.map A.inc <<< counter
       stageParams =
          Causal.zipWith getStageParameterGEP <<< paramPtr &&& counter
   in  nextCounter &&& (stage <<< stageParams &&& signal)

-- | Functional-style formulation of 'paramStage', restricted to 'CausalP.T'.
-- The leading underscore marks it as unused.
_paramStage ::
   (IsSized a,
    TypeNum.Natural n, TypeNum.Positive (n :*: LLVM.UnknownSize)) =>
   CausalP.T p (Filt2Core.Parameter (Value a), v) v ->
   CausalP.T p
      (Value (Ptr (ParameterStruct n a)), (Value Word32, v)) (Value Word32, v)
_paramStage stage =
   Func.withGuidedArgs (Func.atom, (Func.atom, Func.atom)) $ \(ptr, (k, x)) ->
      let params = Causal.zipWith getStageParameterGEP $& ptr &|& k
      in  liftA2 (,) (k+1) (stage $& params &|& x)

-- | Alternative to 'causalGen' that unrolls the cascade at the Haskell level.
-- It creates one copy of @stage@ per stage index. Each copy extracts its
-- parameters directly from the 'ParameterValue' input via
-- 'getStageParameter' instead of going through memory.
-- The leading underscore marks it as unused.
_causalGenP ::
   (Causal.C process, IsSized a,
    TypeNum.Natural n, TypeNum.Positive (n :*: LLVM.UnknownSize)) =>
   process (Filt2Core.Parameter (Value a), v) v ->
   process (ParameterValue n a, v) v
_causalGenP stage =
   withSize $ \n ->
      let indices = take (TypeNum.integralFromSingleton n) [0..]
          -- stage number k, fed with the k-th parameter set
          stageAt k =
             stage <<<
             Arrow.first (Causal.map (\ps -> getStageParameter ps k))
          -- keep the parameters alongside and append the next stage
          appendStage acc next = (arr fst &&& acc) >>> next
      in  foldl appendStage (arr snd) (map stageAt indices)


-- | Parameters of stage @k@, extracted from the array value.
-- The index is a Haskell-level constant, not an LLVM value.
getStageParameter ::
   (IsSized a,
    TypeNum.Natural n, TypeNum.Positive (n :*: LLVM.UnknownSize)) =>
   ParameterValue n a ->
   Word32 ->
   CodeGenFunction r (Filt2Core.Parameter (Value a))
getStageParameter ps k = do
   stageStruct <- LLVM.extractvalue (parameterValue ps) k
   Filt2.decomposeParameter stageStruct
{-
Alternative formulation going through the Memory.C instance:

   Memory.decompose =<<
   flip LLVM.extractvalue k =<<
   Memory.compose ps
-}

-- | Parameters of stage @k@, loaded through a pointer to the parameter array.
-- Unlike 'getStageParameter', the index is a run-time LLVM value.
getStageParameterGEP ::
   (IsSized a,
    TypeNum.Natural n, TypeNum.Positive (n :*: LLVM.UnknownSize)) =>
   Value (Ptr (ParameterStruct n a)) ->
   Value Word32 -> CodeGenFunction r (Filt2Core.Parameter (Value a))
getStageParameterGEP ptr k = do
   stagePtr <- LLVM.getElementPtr0 ptr (k, ())
   stageStruct <- LLVM.load stagePtr
   Filt2.decomposeParameter stageStruct


{-# DEPRECATED causalP          "use 'causal' instead" #-}
{-# DEPRECATED causalPackedP    "use 'causalPacked' instead" #-}