{-# LANGUAGE NoImplicitPrelude #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE UndecidableInstances #-}
{-# LANGUAGE TypeSynonymInstances #-}
module Synthesizer.LLVM.Filter.SecondOrderCascade where

import qualified Synthesizer.LLVM.Filter.SecondOrder as Filt2
import qualified Synthesizer.Plain.Filter.Recursive.SecondOrder as Filt2Core

import qualified Synthesizer.LLVM.CausalParameterized.Process as CausalP
import qualified LLVM.Extra.Representation as Rep
import qualified Synthesizer.LLVM.Simple.Value as Value

import qualified LLVM.Extra.Class as Class

import qualified LLVM.Core as LLVM
import LLVM.Util.Loop (Phi, phis, addPhis, )
import LLVM.Core
   (Value, valueOf, Vector,
    IsPowerOf2, IsConst, IsArithmetic, IsPrimitive, IsFirstClass, IsSized,
    CodeGenFunction, )

import qualified Data.TypeLevel.Num      as TypeNum
import qualified Data.TypeLevel.Num.Sets as TypeSet

import Data.Word (Word32, )

import qualified Control.Arrow as Arrow
import Control.Arrow ((>>>), (<<<), (&&&), arr, )

-- import qualified Algebra.Transcendental as Trans
-- import qualified Algebra.Field as Field
import qualified Algebra.Module as Module
import qualified Algebra.Ring as Ring

import NumericPrelude.Numeric
import NumericPrelude.Base


{- |
LLVM array type with @n@ elements,
each of them being the parameter struct of one second order filter.
-}
type Parameter n a = LLVM.Array n (Filt2.ParameterStruct a)

{- |
LLVM value of type 'Parameter', wrapped in a newtype.
The class instances for this type are defined manually in this module.
-}
newtype ParameterValue n a =
   ParameterValue {parameterValue :: Value (Parameter n a)}
{-
Automatic deriving of the following classes is not allowed,
even with GeneralizedNewtypeDeriving, because of the IsSized constraint.
It would also be wrong for Functor and friends.
      deriving
         (Phi, LLVM.Undefined, Class.Zero,
          Functor, App.Applicative, Fold.Foldable, Trav.Traversable)
-}

{- |
Both methods unwrap the newtype
and delegate to the 'Phi' instance of the contained LLVM value.
-}
instance
      (TypeNum.Nat n, IsSized a s) =>
      Phi (ParameterValue n a) where
   phis block param =
      fmap ParameterValue (phis block (parameterValue param))
   addPhis block x y =
      addPhis block (parameterValue x) (parameterValue y)

{- |
The undefined parameter set is the wrapped undefined LLVM value.
-}
instance
      (TypeNum.Nat n, IsSized a s) =>
      LLVM.Undefined (ParameterValue n a) where
   undefTuple =
      ParameterValue LLVM.undefTuple

{- |
The zero parameter set is the wrapped zero LLVM value.
-}
instance
      (TypeNum.Nat n, IsSized a s) =>
      Class.Zero (ParameterValue n a) where
   zeroTuple =
      ParameterValue Class.zeroTuple

{- |
All methods delegate to the newtype helpers of "LLVM.Extra.Representation".
The helpers get the constructor for wrapping
and the field accessor for unwrapping.
-}
instance
      (TypeNum.Nat n, IsSized a s) =>
      Rep.Memory (ParameterValue n a) (Parameter n a) where
   load = Rep.loadNewtype ParameterValue
   store = Rep.storeNewtype parameterValue
   decompose = Rep.decomposeNewtype ParameterValue
   compose = Rep.composeNewtype parameterValue



{- |
Pass a dummy value of the size type @n@ to a function that builds a process,
where @n@ is the same type as in the 'ParameterValue' input of that process.
The dummy value is 'undefined',
thus the function may only use its type, but must not evaluate it.
-}
withSize ::
   (n -> CausalP.T p (ParameterValue n a, x) y) ->
   CausalP.T p (ParameterValue n a, x) y
withSize = ($ undefined)

{- |
Return the process unchanged.
The first argument is ignored at runtime;
it only unifies its type with the size type @n@ of the process.
-}
fixSize ::
   n ->
   CausalP.T p (ParameterValue n a, x) y ->
   CausalP.T p (ParameterValue n a, x) y
fixSize _ process = process

{- |
Cascade of @n@ second order filters.
The process input is a pair of the parameter array and the signal.
Stage number @k@ (counted from zero) extracts element @k@
of the parameter array with 'getStageParameter'
and applies 'Filt2.causalP' to the output of the preceding stage.
The first stage processes the input signal.
For @n@ equal to zero the signal is passed through unchanged.
-}
causalP ::
   (Ring.C a, Module.C (Value.T a) (Value.T v),
    IsFirstClass a, IsSized a as, IsConst a,
    IsFirstClass v, IsSized v vs, IsConst v,
    IsArithmetic a, TypeSet.Nat n,
    TypeNum.Mul n LLVM.UnknownSize paramSize, TypeSet.Pos paramSize) =>
   CausalP.T p (ParameterValue n a, Value v) (Value v)
causalP =
   withSize $ \size ->
      foldl
         (\cascade stage ->
            -- keep the parameter array and pair it with the cascade output
            (arr fst &&& cascade)
            >>>
            -- replace the parameter array by the parameters of this stage
            Arrow.first
               (CausalP.mapSimple
                  (\params -> getStageParameter params stage))
            >>>
            Filt2.causalP)
         (arr snd)
         (take (TypeNum.toInt size) [0..])

{- |
Cascade of @n@ second order filters that operates on vectors of samples.
The process input is a pair of the parameter array and the vector signal.
Stage number @k@ (counted from zero) extracts element @k@
of the parameter array with 'getStageParameter'
and applies 'Filt2.causalPackedP' to the output of the preceding stage.
The first stage processes the input signal.
For @n@ equal to zero the signal is passed through unchanged.
-}
causalPackedP ::
   (Ring.C a,
    IsPrimitive a, IsSized a as, IsConst a,
    IsArithmetic a, TypeSet.Nat n,
    TypeNum.Mul n LLVM.UnknownSize paramSize, TypeSet.Pos paramSize,
    IsPowerOf2 d, TypeNum.Mul d as vas, TypeSet.Pos vas) =>
   CausalP.T p
      (ParameterValue n a, Value (Vector d a)) (Value (Vector d a))
causalPackedP =
   withSize $ \size ->
      foldl
         (\cascade stage ->
            -- keep the parameter array and pair it with the cascade output
            (arr fst &&& cascade)
            >>>
            -- replace the parameter array by the parameters of this stage
            Arrow.first
               (CausalP.mapSimple
                  (\params -> getStageParameter params stage))
            >>>
            Filt2.causalPackedP)
         (arr snd)
         (take (TypeNum.toInt size) [0..])

{- |
Fetch the parameters of the filter stage with the given index
from the parameter array.
The three functions differ only in the way they access the array element.
-}
getStageParameter, getStageParameterMalloc, getStageParameterAlloca ::
   (IsFirstClass a, TypeSet.Nat n, IsSized a sa,
    TypeNum.Mul n LLVM.UnknownSize s, TypeSet.Pos s) =>
   ParameterValue n a ->
   Word32 ->
   CodeGenFunction r (Filt2Core.Parameter (Value a))
{-
Extract the element directly from the array value
and decompose the resulting struct.
-}
getStageParameter (ParameterValue stages) k = do
   stage <- LLVM.extractvalue stages k
   Rep.decompose stage

{-
Store the array in freshly allocated heap memory,
load the requested element from there and free the memory again.

Expensive because we need a heap allocation for every sample.
However, we could allocate the memory once
in the Causal initialization routine.
-}
getStageParameterMalloc (ParameterValue stages) k = do
   buffer <- LLVM.malloc
   LLVM.store stages buffer
   stagePtr <- LLVM.getElementPtr0 buffer (valueOf k, ())
   stage <- Rep.load stagePtr
   LLVM.free buffer
   return stage

{-
Store the array in a stack variable
and load the requested element from there.

With this implementation, LLVM-2.6 generates a stack variable layout
that requires non-aligned access to vector values.
The result is a crash at runtime.
-}
getStageParameterAlloca (ParameterValue stages) k = do
   buffer <- LLVM.alloca
   LLVM.store stages buffer
   stagePtr <- LLVM.getElementPtr0 buffer (valueOf k, ())
   Rep.load stagePtr