{-# LANGUAGE NoImplicitPrelude #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE FunctionalDependencies #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE UndecidableInstances #-}
{-# LANGUAGE ExistentialQuantification #-}
{-# LANGUAGE Rank2Types #-}
{-# LANGUAGE ForeignFunctionInterface #-}
{- |
Signal generators that generate the signal in chunks
that can be processed natively by the processor.
Some of the functions for plain signals can be re-used without modification.
E.g. rendering a signal and reading from and to signals work
because the vector type as element type warrants correct alignment.
We can convert between atomic and chunked signals.

The article
<http://perilsofparallel.blogspot.com/2008/09/larrabee-vs-nvidia-mimd-vs-simd.html>
explains the difference between Vector and SIMD computing.
According to that the SSE extensions in Intel processors
must be called Vector computing.
But since we use the term Vector already in the mathematical sense,
I like to use the term "packed" that is used in Intel mnemonics like mulps.
-}
module Synthesizer.LLVM.Parameterized.SignalPacked where

import Synthesizer.LLVM.Parameterized.Signal (T(Cons), )
import qualified Synthesizer.LLVM.Parameterized.Signal as Sig
import qualified Synthesizer.LLVM.Parameter as Param

import qualified Synthesizer.LLVM.Random as Rnd
import qualified LLVM.Extra.Representation as Rep
import qualified LLVM.Extra.ScalarOrVector as SoV
import qualified LLVM.Extra.Vector as Vector
import qualified LLVM.Extra.MaybeContinuation as Maybe
import qualified LLVM.Extra.Control as U
import LLVM.Extra.Control (whileLoop, )

import qualified Data.TypeLevel.Num as TypeNum

import qualified LLVM.Extra.Class as Class
import qualified LLVM.Extra.Arithmetic as A

import LLVM.Core as LLVM

-- we can also use <$> for parameters
import Control.Arrow ((^<<), )
import Control.Applicative (liftA2, )

import qualified Algebra.Transcendental as Trans
import qualified Algebra.Algebraic as Algebraic
import qualified Algebra.RealField as RealField
import qualified Algebra.Field as Field
import qualified Algebra.Ring as Ring

import Data.Word (Word32, )
import Foreign.Storable (Storable, )

import qualified Data.List as List

import NumericPrelude.Numeric as NP
import NumericPrelude.Base hiding (and, iterate, map, zip, zipWith, )



{- |
Convert a signal of scalar values into one using processor vectors.
If the signal length is not divisible by the chunk size,
then the last chunk is dropped.

'packRotate' and 'packIndex' are two implementations of the same conversion:
'packRotate' fills a chunk by repeatedly shifting the vector down,
whereas 'packIndex' inserts every element at its index.
'pack' is an alias for 'packRotate'.
-}
pack, packRotate, packIndex ::
   (Vector.Access n a v) =>
   T p a -> T p v
-- the rotating variant is used as default implementation
pack = packRotate

{-
Variant of 'pack' that assembles a chunk by shifting.
The loop starts with an undefined vector
and a counter that is initialized with the vector size.
Every iteration fetches one scalar from the input generator,
shifts it into the vector using 'Vector.shiftDown'
and decrements the counter.
The loop ends when the counter reaches zero
or when the input generator signals its end.
The state, the IO context creation and deletion
are passed through from the input generator unchanged.
-}
packRotate (Cons next start createIOContext deleteIOContext) = Cons
   (\param state0 -> do
      let emptyChunk = undefTuple
          chunkSize = fromIntegral (Vector.sizeInTuple emptyChunk) :: Word32
      (chunk,_,state) <-
         Maybe.fromBool $
         U.whileLoop
            (valueOf True, (emptyChunk, valueOf chunkSize, state0))
            -- proceed while input is available and elements are missing
            (\(running,(_chunk,remaining,_state)) ->
               A.icmp IntUGT remaining (value LLVM.zero) >>= A.and running)
            (\(_,(chunk0,remaining0,state1)) -> Maybe.toBool $ do
               (x,state2) <- next param state1
               Maybe.lift $ do
                  -- the element that is shifted out is not needed
                  (_,chunk1) <- Vector.shiftDown x chunk0
                  remaining1 <- A.dec remaining0
                  return (chunk1,remaining1,state2))
      return (chunk, state))
   start
   createIOContext
   deleteIOContext

{-
Variant of 'pack' that assembles a chunk by indexed insertion.
The loop starts with an undefined vector and the index zero.
Every iteration fetches one scalar from the input generator,
stores it at the current index using 'Vector.insert'
and increments the index.
The loop ends when the index reaches the vector size
or when the input generator signals its end.
-}
packIndex (Cons next start createIOContext deleteIOContext) = Cons
   (\param state0 -> do
      (chunk,_,state) <-
         Maybe.fromBool $
         U.whileLoop
            (valueOf True, (undefTuple, value LLVM.zero, state0))
            -- proceed while input is available and the chunk is not full
            (\(running,(chunk0,index,_state)) ->
               let chunkSize =
                      valueOf (fromIntegral (Vector.sizeInTuple chunk0))
               in  A.icmp IntULT index chunkSize >>= A.and running)
            (\(_,(chunk0,index0,state1)) -> Maybe.toBool $ do
               (x,state2) <- next param state1
               Maybe.lift $ do
                  chunk1 <- Vector.insert index0 x chunk0
                  index1 <- A.inc index0
                  return (chunk1,index1,state2))
      return (chunk, state))
   start
   createIOContext
   deleteIOContext


{- |
Like 'pack' but without a loop in the generated code.
The code of the input signal is emitted once for every vector element,
that is, n times for vectors of size n.
Thus it is only efficient for simple input generators.
-}
packSmall ::
   (Vector.Access n a v, Class.Zero v) =>
   T p a -> T p v
packSmall (Cons next start createIOContext deleteIOContext) = Cons
   (\param state0 ->
      let emptyChunk = undefTuple
          -- one index for every element of the vector
          indices = take (Vector.sizeInTuple emptyChunk) [0..]
      in  foldr
             -- fetch one scalar, store it at index k, then continue
             (\k continue (chunk,state) -> do
                (x,state1) <- next param state
                chunk1 <- Maybe.lift $ Vector.insert (valueOf k) x chunk
                continue (chunk1,state1))
             return
             indices
             (emptyChunk, state0))
   start
   createIOContext
   deleteIOContext


{- |
Convert a signal of processor vectors into a signal of scalar values.

'unpackRotate' and 'unpackIndex' are two implementations of this conversion:
'unpackRotate' always reads the element at index zero
and rotates the vector afterwards,
whereas 'unpackIndex' reads the element at a running index.
'unpack' is an alias for 'unpackRotate'.
-}
unpack, unpackRotate, unpackIndex ::
   (Vector.Access n a v, Rep.Memory v vp, IsSized vp vs) =>
   T p v -> T p a
-- the rotating variant is used as default implementation
unpack = unpackRotate

{-
Variant of 'unpack' that rotates the current chunk.
The state consists of the number of elements
that are still to be delivered from the current chunk,
the (rotated) current chunk and the state of the input generator.
When the counter is zero, a new chunk is fetched from the input
and the counter is reset to the vector size.
Then the element at index zero is emitted,
the chunk is rotated down and the counter is decremented.
The initial counter is zero,
such that the first step fetches a chunk from the input.
-}
unpackRotate (Cons next start createIOContext deleteIOContext) = Cons
   (\param (remaining0,chunk0,state0) -> do
      exhausted <-
         Maybe.lift $ A.icmp IntEQ remaining0 (valueOf 0)
      (remaining1,chunk1,state1) <-
         Maybe.fromBool $
         U.ifThen exhausted (valueOf True, (remaining0,chunk0,state0)) $ do
            (running, (chunk,state)) <- Maybe.toBool $ next param state0
            let chunkSize = valueOf $ fromIntegral $ Vector.sizeInTuple chunk0
            return (running, (chunkSize, chunk, state))
      Maybe.lift $ do
         -- asTypeOf ties the type of the literal index to the counter type
         x <- Vector.extract (valueOf 0 `asTypeOf` remaining0) chunk1
         chunk2 <- Vector.rotateDown chunk1
         remaining2 <- A.dec remaining1
         return (x, (remaining2,chunk2,state1)))
   (\p ->
      start p >>= \state ->
      return (valueOf 0, undefTuple, state))
   createIOContext
   deleteIOContext

{-
Variant of 'unpack' that addresses the elements of the current chunk by index.
The state consists of the index of the next element,
the current chunk and the state of the input generator.
When the index has reached the vector size,
a new chunk is fetched from the input and the index is reset to zero.
Then the element at the index is emitted and the index is incremented.
The initial index equals the vector size,
such that the first step fetches a chunk from the input.
-}
unpackIndex (Cons next start createIOContext deleteIOContext) = Cons
   (\param (index0,chunk0,state0) -> do
      let chunkSize = valueOf $ fromIntegral $ Vector.sizeInTuple chunk0
      exhausted <- Maybe.lift $ A.icmp IntUGE index0 chunkSize
      (index1,chunk1,state1) <-
         Maybe.fromBool $
         U.ifThen exhausted (valueOf True, (index0,chunk0,state0)) $ do
            (running, (chunk,state)) <- Maybe.toBool $ next param state0
            return (running, (value LLVM.zero, chunk, state))
      Maybe.lift $ do
         x <- Vector.extract index1 chunk1
         index2 <- A.inc index1
         return (x, (index2,chunk1,state1)))
   (\p ->
      start p >>= \state ->
      let emptyChunk = undefTuple
          chunkSize = valueOf $ fromIntegral $ Vector.sizeInTuple emptyChunk
      in  return (chunkSize, emptyChunk, state))
   createIOContext
   deleteIOContext


{- |
Give a signal constructor access to the type level vector size @n@
of the signal it is going to produce.
The constructor is applied to 'undefined',
thus it may only use the type of its argument
but must never evaluate the argument itself.
-}
withSize ::
   (n -> T p (Value (Vector n a))) ->
   T p (Value (Vector n a))
withSize makeSignal = makeSignal undefined


{- |
Packed signal where every chunk is built from the same scalar parameter.
The scalar is passed to 'LLVM.vector' as a one-element list;
presumably 'LLVM.vector' replicates it across all vector elements
(to be confirmed against the LLVM binding).
-}
constant ::
   (Storable a,  MakeValueTuple a (Value a),
    IsConst a, IsPrimitive a,
    IsPowerOf2 n, IsSized (Vector n a) s) =>
--    IsPowerOf2 n, IsSized a s, TypeNum.Pos vs, TypeNum.Mul n s vs) =>
   Param.T p a -> T p (Value (Vector n a))
constant x =
   Sig.constant (fmap (\a -> LLVM.vector [a]) x)


{- |
Packed exponential decay that is described by its half life
(measured in scalar samples) and its start value.
The initial chunk contains the first values of the scalar decay,
and 'Sig.exponentialCore' gets the decay factor
for a step of a whole chunk, i.e. @0.5 ** (n / halfLife)@
for vector size @n@.
-}
exponential2 ::
   (Trans.C a, Storable a, MakeValueTuple a (Value a),
    IsFirstClass a, IsSized a s, IsSized (Vector n a) vs,
    IsPrimitive a, IsArithmetic a, IsConst a,
    IsPowerOf2 n) =>
   Param.T p a -> Param.T p a -> T p (Value (Vector n a))
exponential2 halfLife start = withSize $ \size ->
   let -- decay factor for advancing by one chunk
       chunkDecay = 0.5 ** (fromIntegral (TypeNum.toInt size) / halfLife)
       -- start value followed by the per-sample decayed values
       initialChunk =
          liftA2
             (\h y0 -> LLVM.vector (List.iterate ((0.5 ** recip h) *) y0))
             halfLife start
   in  Sig.exponentialCore
          (fmap (\k -> LLVM.vector [k]) chunkDecay)
          initialChunk

{- |
Like 'exponential2' but with an additional bound parameter
that is passed on to 'Sig.exponentialBoundedCore'.
The decay factor per chunk and the initial chunk
are computed the same way as in 'exponential2'.
-}
exponentialBounded2 ::
   (Trans.C a, Storable a, MakeValueTuple a (Value a),
    IsFirstClass a, IsSized a s, IsSized (Vector n a) vs,
    IsPrimitive a, Vector.Real a, IsConst a,
    IsPowerOf2 n) =>
   Param.T p a -> Param.T p a -> Param.T p a ->
   T p (Value (Vector n a))
exponentialBounded2 bound halfLife start = withSize $ \size ->
   let single k = LLVM.vector [k]
       -- decay factor for advancing by one chunk
       chunkDecay = 0.5 ** (fromIntegral (TypeNum.toInt size) / halfLife)
       -- start value followed by the per-sample decayed values
       initialChunk =
          liftA2
             (\h y0 -> LLVM.vector (List.iterate ((0.5 ** recip h) *) y0))
             halfLife start
   in  Sig.exponentialBoundedCore
          (fmap single bound)
          (fmap single chunkDecay)
          initialChunk


{- |
Packed phase generator for oscillators.
The first argument of 'Sig.osciCore' is the initial chunk of phases,
that starts at @phase@ and advances by @freq@ per element,
wrapped to the fractional part.
The second argument is the phase increment per chunk,
i.e. the fractional part of @n * freq@ for vector size @n@.
-}
osciCore ::
   (Storable t, MakeValueTuple t (Value t),
    IsFirstClass t, IsSized t size, IsSized (Vector n t) vsize,
    Vector.Real t, IsFloating t, RealField.C t, IsConst t,
    IsPowerOf2 n) =>
   Param.T p t -> Param.T p t -> T p (Value (Vector n t))
osciCore phase freq = withSize $ \size ->
   let chunkSize = fromIntegral (TypeNum.toInt size)
       startPhases =
          liftA2
             (\f p0 -> LLVM.vector (List.iterate (\p -> fraction (f + p)) p0))
             freq phase
       chunkFreq =
          fmap (\f -> LLVM.vector [fraction (chunkSize * f)]) freq
   in  Sig.osciCore startPhases chunkFreq

{- |
Packed oscillator with a parameterized waveform.
The waveform is applied via 'Sig.map' to the chunks of phases
that are generated by 'osciCore' from start phase and frequency.
-}
osci ::
   (Storable t, MakeValueTuple t (Value t),
    Storable c, MakeValueTuple c cl,
    IsFirstClass t, IsSized t size, IsSized (Vector n t) vsize,
    Rep.Memory cl cp, IsSized cp cs,
    Vector.Real t, IsFloating t, RealField.C t, IsConst t,
    IsPowerOf2 n) =>
   (forall r. cl -> Value (Vector n t) -> CodeGenFunction r y) ->
   Param.T p c ->
   Param.T p t -> Param.T p t -> T p y
osci wave waveParam phase freq =
   Sig.map wave waveParam (osciCore phase freq)

{- |
Packed oscillator with a waveform that has no parameters.
It is 'osci' with a unit waveform parameter that the waveform ignores.
-}
osciSimple ::
   (Storable t, MakeValueTuple t (Value t),
    IsFirstClass t, IsSized t size, IsSized (Vector n t) vsize,
    Vector.Real t, IsFloating t, RealField.C t, IsConst t,
    IsPowerOf2 n) =>
   (forall r. Value (Vector n t) -> CodeGenFunction r y) ->
   Param.T p t -> Param.T p t -> T p y
osciSimple wave phase freq =
   osci (\_ -> wave) (return ()) phase freq


{- |
Packed ramps and parabolic fades without end.
'rampSlope' takes the increment per scalar sample,
'rampInf' takes the number of samples the ramp needs to grow by one
and is defined via 'rampSlope' with the reciprocal argument.
-}
rampInf, rampSlope,
 parabolaFadeInInf, parabolaFadeOutInf ::
   (RealField.C a, Storable a, MakeValueTuple a (Value a),
    IsPrimitive a, IsArithmetic a, IsConst a,
    IsPowerOf2 n, IsSized (Vector n a) s) =>
   Param.T p a -> T p (Value (Vector n a))
rampSlope slope = withSize $ \size ->
   let chunkSize = fromIntegral (TypeNum.toInt size)
       -- increment for advancing by one chunk
       chunkSlope = fmap (\s -> LLVM.vector [chunkSize * s]) slope
       -- first chunk: 0, s, 2*s, ...
       initialChunk = fmap (\s -> LLVM.vector (List.iterate (s +) 0)) slope
   in  Sig.rampCore chunkSlope initialChunk
rampInf dur = rampSlope (recip dur)

{-
Packed fade-in parabola @f t = t*(2-t)@ with @t = k/dur@ for sample index @k@.
Every vector lane @j@ runs its own second order difference scheme
with a stride of @n@ samples, where @n@ is the vector size.
The arguments of 'Sig.parabolaCore' are, in this order,
the second difference, the initial first differences and the initial values.
(It is assumed that 'Sig.parabolaCore' adds the first difference to the value
and the second difference to the first difference in every step,
which is consistent with the values for lane 0 -
to be confirmed against its definition.)

With @d = n/dur@ the required quantities for lane @j@ are:

> second difference:          -2*d*d
> initial first difference:   f ((j+n)/dur) - f (j/dur) = d*(2-d) - 2*j*n/dur^2
> initial value:              f (j/dur)
-}
parabolaFadeInInf dur = withSize $ \ni ->
   let n = fromIntegral (TypeNum.toInt ni)
   in  Sig.parabolaCore
          (fmap
             (\dr ->
                let d = n / dr
                in  LLVM.vector [-2*d*d]) dur)
          (fmap
             (\dr ->
                let d = n / dr
                    -- neighbouring lanes differ by 2*n/dr^2,
                    -- the factor n accounts for the stride of a chunk
                    laneStep = 2 * n / dr ^ 2
                in  LLVM.vector $ List.iterate (subtract laneStep) (d*(2-d)))
             dur)
          (fmap
             (\dr ->
                LLVM.vector $ List.map (\t -> t*(2-t)) $ List.iterate (recip dr +) 0)
             dur)

{-
Packed fade-out parabola @f t = 1-t*t@ with @t = k/dur@ for sample index @k@.
Every vector lane @j@ runs its own second order difference scheme
with a stride of @n@ samples, where @n@ is the vector size.
The arguments of 'Sig.parabolaCore' are, in this order,
the second difference, the initial first differences and the initial values.
(It is assumed that 'Sig.parabolaCore' adds the first difference to the value
and the second difference to the first difference in every step,
which is consistent with the values for lane 0 -
to be confirmed against its definition.)

With @d = n/dur@ the required quantities for lane @j@ are:

> second difference:          -2*d*d
> initial first difference:   f ((j+n)/dur) - f (j/dur) = -d*d - 2*j*n/dur^2
> initial value:              f (j/dur)
-}
parabolaFadeOutInf dur = withSize $ \ni ->
   let n = fromIntegral (TypeNum.toInt ni)
   in  Sig.parabolaCore
          (fmap
             (\dr ->
                let d = n / dr
                in  LLVM.vector [-2*d*d]) dur)
          (fmap
             (\dr ->
                let d = n / dr
                    -- neighbouring lanes differ by 2*n/dr^2,
                    -- the factor n accounts for the stride of a chunk
                    laneStep = 2 * n / dr ^ 2
                in  LLVM.vector $ List.iterate (subtract laneStep) (-d*d))
             dur)
          (fmap
             (\dr ->
                LLVM.vector $ List.map (\t -> 1-t*t) $ List.iterate (recip dr +) 0)
             dur)


{- |
Packed noise that is derived from the random integers of 'noiseCore'.
Every chunk of integers is converted to floating point,
shifted by @div Rnd.modulus 2 + 1@
and scaled by @sqrt (3 * rate) / div Rnd.modulus 2@.
For the mysterious rate parameter see 'Sig.noise'.
-}
noise ::
   (Algebraic.C a, IsFloating a, IsConst a, IsPrimitive a,
    IsPowerOf2 n, IsSized (Vector n Word32) s,
    IsSized a as, TypeNum.Mul n as vas, TypeNum.Pos vas,
    MakeValueTuple a (Value a), Storable a) =>
   Param.T p Word32 ->
   Param.T p a ->
   T p (Value (Vector n a))
noise seed rate =
   let halfModulus = fromInteger (div Rnd.modulus 2)
       amplitude =
          fmap (\k -> LLVM.vector [k]) (sqrt (3 * rate) / return halfModulus)
   in  Sig.map
          (\amp y -> do
             {-
             In principle it must be uitofp,
             but sitofp is a single instruction on x86
             and our numbers are below 2^31.
             -}
             yFloat <- sitofp y
             centred <- A.sub yFloat (SoV.replicateOf (halfModulus + 1))
             A.mul amp centred)
          amplitude
          (noiseCore seed)

{- |
Packed streams of random integers.
The seed parameter is mapped to @mod seed (Rnd.modulus-1) + 1@
before it is expanded by 'Rnd.vectorSeed' to the initial chunk.
'noiseCore' advances the chunk with 'Rnd.nextVector',
'noiseCoreAlt' advances it with 'Rnd.nextVector64'.
-}
noiseCore, noiseCoreAlt ::
   (IsPowerOf2 n, IsSized (Vector n Word32) s) =>
   Param.T p Word32 ->
   T p (Value (Vector n Word32))
noiseCore seed =
   let initialChunk =
          fmap (\s -> Rnd.vectorSeed (mod s (Rnd.modulus - 1) + 1)) seed
   in  Sig.iterate (\_ -> Rnd.nextVector) (return ()) initialChunk

{-
Alternative to 'noiseCore' that advances the chunk with 'Rnd.nextVector64'.
The seed is prepared the same way:
it is mapped to @mod seed (Rnd.modulus-1) + 1@
and expanded by 'Rnd.vectorSeed'.
-}
noiseCoreAlt seed =
   let initialChunk =
          fmap (\s -> Rnd.vectorSeed (mod s (Rnd.modulus - 1) + 1)) seed
   in  Sig.iterate (\_ -> Rnd.nextVector64) (return ()) initialChunk