{-# LANGUAGE NoImplicitPrelude #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE ExistentialQuantification #-}
{-# LANGUAGE Rank2Types #-}
{-# LANGUAGE ForeignFunctionInterface #-}
module Synthesizer.LLVM.Simple.Signal (
   C(simple),
   T,
   amplify,
   amplifyStereo,
   constant,
   envelope,
   envelopeStereo,
   exponential2,
   iterate,
   map,
   mapAccum,
   mix,
   mixExt,
   takeWhile,
   empty,
   append,
   osci,
   osciPlain,
   osciSaw,
   zip,
   zipWith,

   fromStorableVector,
   fromStorableVectorLazy,

   render,
   renderChunky,
   runChunky,
   ) where

import Synthesizer.LLVM.Simple.SignalPrivate hiding (alloca)

import qualified Synthesizer.LLVM.Frame.Stereo as Stereo
import qualified Synthesizer.LLVM.Frame as Frame
import qualified Synthesizer.LLVM.Wave as Wave
import qualified Synthesizer.LLVM.Execution as Exec
import qualified Synthesizer.LLVM.ForeignPtr as ForeignPtr

import qualified Synthesizer.LLVM.Storable.ChunkIterator as ChunkIt
import qualified Synthesizer.LLVM.Storable.Vector as SVU
import qualified Data.StorableVector.Lazy as SVL
import qualified Data.StorableVector as SV
import qualified Data.StorableVector.Base as SVB

import qualified LLVM.Extra.Memory as Memory
import qualified LLVM.Extra.ScalarOrVector as SoV
import qualified LLVM.Extra.MaybeContinuation as MaybeCont
import qualified LLVM.Extra.Maybe as Maybe
import qualified LLVM.Extra.Arithmetic as A
import LLVM.Extra.Arithmetic (advanceArrayElementPtr, )
import LLVM.Extra.Class (Undefined, MakeValueTuple, ValueTuple, valueTupleOf, )

import qualified LLVM.Core as LLVM
import LLVM.Util.Loop (Phi, )
import LLVM.Core
          (CodeGenFunction, ret, Value, valueOf,
           IsSized, IsConst, IsArithmetic)

import Control.Monad (liftM2, )
import Control.Applicative (pure, liftA2, liftA3, (<$>), )

import Data.Monoid (Monoid, mappend, )

import qualified Algebra.Transcendental as Trans

import qualified System.Unsafe as Unsafe
import Foreign.Storable.Tuple ()
import Foreign.Storable (Storable, )
import Foreign.ForeignPtr (touchForeignPtr, withForeignPtr, )
import Foreign.Ptr (Ptr, )
import Data.Word (Word32, )
import Control.Exception (bracket, )

import NumericPrelude.Numeric
import NumericPrelude.Base hiding (and, iterate, map, zip, zipWith, takeWhile, )


{- |
Lift a Haskell constant into a signal:
The value is converted to an LLVM value with 'valueOf'
and then wrapped with 'pure'.
-}
constant :: (C signal, IsConst a) => a -> signal (Value a)
constant = pure . valueOf

{- |
Map over a signal while threading an accumulator of type @s@ through it.

The first argument computes the output frame and the next accumulator
from an input frame and the current accumulator.
The second argument generates the initial accumulator.
The accumulator is kept next to the state of the input signal
and only the state of the input signal is passed on to its @stop@ function.
-}
mapAccum ::
   (C signal, Memory.C s) =>
   (forall r. a -> s -> CodeGenFunction r (b,s)) ->
   (forall r. CodeGenFunction r s) ->
   signal a -> signal b
mapAccum step initAcc =
   alter
      (\(Core next start stop) ->
         Core
            (\context (sigState0,acc0) -> do
               (x,sigState1) <- next context sigState0
               (y,acc1) <- MaybeCont.lift (step x acc0)
               return (y, (sigState1,acc1)))
            (\context -> do
               sigState <- start context
               acc <- initAcc
               return (sigState,acc))
            (\state -> stop (fst state)))


{- |
Mix two signals frame by frame using 'Frame.mix'.

Warning:
The result is only as long as the shorter of the two input signals.
This is consistent with @Causal.mix@ but it may not be what you expect.
Consider using 'mixExt' if the result shall cover the longer input.
-}
mix ::
   (C signal, A.Additive a) =>
   signal a -> signal a -> signal a
mix xs ys = zipWith Frame.mix xs ys

{- |
Mix two signals such that the result
is as long as the longer of the two input signals.

Every frame of an input is paired with a @True@ flag
and the input is followed by a constant signal
of zero frames that are paired with a @False@ flag.
The extended signals are added frame by frame, the flags are combined
with 'A.or', and the result is cut off by 'takeWhile' on the combined flag.
-}
mixExt ::
   (C signal, Monoid (signal (Value Bool, a)),
    A.Additive a, Phi a, Undefined a) =>
   signal a -> signal a -> signal a
mixExt xs ys =
   fmap snd $
   takeWhile (return . fst) $
   zipWith
      (\(activeX,x) (activeY,y) ->
         liftA2 (,) (A.or activeX activeY) (A.add x y))
      (padded xs) (padded ys)
   where
      -- flag the actual frames and append the flagged zero padding
      padded zs =
         fmap ((,) (valueOf True)) zs
         `mappend`
         pure (valueOf False, A.zero)


{- |
Pass the frames of a signal on as long as the predicate holds.
Every frame is fetched from the input first and then checked
with 'MaybeCont.guard' before it is delivered.

You can apply @Causal.takeWhile@ instead,
but this requires a pretty complex type signature
including a @process@ variable that is not of interest for the user.
-}
takeWhile ::
   (C signal) =>
   (forall r. a -> CodeGenFunction r (Value Bool)) ->
   signal a -> signal a
takeWhile p =
   alter
      (\(Core next start stop) ->
         Core
            (\context state0 -> do
               frame@(a,_) <- next context state0
               keep <- MaybeCont.lift (p a)
               MaybeCont.guard keep
               return frame)
            start
            stop)


{- |
Combine a control signal and a mono signal
frame by frame using 'Frame.amplifyMono'.
Implemented with 'zipWith'.
-}
envelope ::
   (C signal, A.PseudoRing a) =>
   signal a -> signal a -> signal a
envelope ctrl sig = zipWith Frame.amplifyMono ctrl sig

{- |
Combine a control signal and a stereo signal
frame by frame using 'Frame.amplifyStereo'.
Implemented with 'zipWith'.
-}
envelopeStereo ::
   (C signal, A.PseudoRing a) =>
   signal a -> signal (Stereo.T a) -> signal (Stereo.T a)
envelopeStereo ctrl sig = zipWith Frame.amplifyStereo ctrl sig

{- |
Apply 'Frame.amplifyMono' with a constant factor to every frame of a signal.
The factor is a Haskell value that is turned into an LLVM constant.
-}
amplify ::
   (C signal, IsArithmetic a, IsConst a) =>
   a -> signal (Value a) -> signal (Value a)
amplify x sig =
   map (Frame.amplifyMono gain) sig
   where
      gain = valueOf x

{- |
Apply 'Frame.amplifyStereo' with a constant factor
to every frame of a stereo signal.
The factor is a Haskell value that is turned into an LLVM constant.
-}
amplifyStereo ::
   (C signal, IsArithmetic a, IsConst a) =>
   a -> signal (Stereo.T (Value a)) -> signal (Stereo.T (Value a))
amplifyStereo x sig =
   map (Frame.amplifyStereo gain) sig
   where
      gain = valueOf x



{- |
Signal generated by repeated application of a function.
The state starts with the given initial value.
In every step the current state is emitted as frame
and the function applied to the current state becomes the next state.
-}
iterate ::
   (C signal,
    Memory.FirstClass a, Memory.Stored a ~ am, IsSized am, IsConst a) =>
   (forall r. Value a -> CodeGenFunction r (Value a)) ->
   Value a -> signal (Value a)
iterate f initial =
   simple
      (\current ->
         MaybeCont.lift $ do
            following <- f current
            return (current, following))
      (return initial)

{- |
Exponential curve built with 'iterate'.
The first argument is the half-life:
every step multiplies the value by @0.5 ** recip halfLife@.
The second argument is the start value.
-}
exponential2 ::
   (C signal, Trans.C a, IsArithmetic a,
    Memory.FirstClass a, Memory.Stored a ~ am, IsSized am, IsConst a) =>
   a -> a -> signal (Value a)
exponential2 halfLife start =
   iterate (\y -> A.mul y decay) (valueOf start)
   where
      -- factor applied per step, computed once on the Haskell side
      decay = valueOf (0.5 ** recip halfLife)


{- |
Oscillator with phase and frequency given as LLVM values.
The phase signal starts at the given phase
and is advanced by @SoV.incPhase freq@ in every step.
The waveform function is applied to every phase value.
-}
osciPlain ::
   (C signal,
    Memory.FirstClass t, Memory.Stored t ~ tm, IsSized tm,
    SoV.Fraction t, IsConst t) =>
   (forall r. Value t -> CodeGenFunction r y) ->
   Value t -> Value t -> signal y
osciPlain wave phase freq =
   map wave (iterate (SoV.incPhase freq) phase)

{- |
Like 'osciPlain' but start phase and frequency are Haskell values
that are converted to LLVM constants with 'valueOf'.
-}
osci ::
   (C signal,
    Memory.FirstClass t, Memory.Stored t ~ tm, IsSized tm,
    SoV.Fraction t, IsConst t) =>
   (forall r. Value t -> CodeGenFunction r y) ->
   t -> t -> signal y
osci wave phase freq =
   osciPlain wave phaseValue freqValue
   where
      phaseValue = valueOf phase
      freqValue = valueOf freq

{- |
Oscillator that uses 'Wave.saw' as waveform.
The arguments are start phase and frequency as for 'osci'.
-}
osciSaw ::
   (C signal,
    SoV.IntegerConstant a,
    Memory.FirstClass a, Memory.Stored a ~ am, IsSized am,
    SoV.Fraction a, IsConst a) =>
   a -> a -> signal (Value a)
osciSaw phase freq = osci Wave.saw phase freq



{- |
Turn a strict storable vector into a signal.

The generated code keeps a pointer to the current element
and the number of remaining elements as state.
A frame is loaded only if the remaining count is greater than zero,
then the pointer is advanced and the count is decremented.
The foreign pointer of the vector serves as IO context
and is touched when the context is released.
-}
fromStorableVector ::
   (Storable a, MakeValueTuple a, ValueTuple a ~ value, Memory.C value) =>
   SV.Vector a ->
   T value
fromStorableVector xs =
   Cons
      (\_ () (pos0,remaining0) -> do
         nonEmpty <- MaybeCont.lift $ A.cmp LLVM.CmpGT remaining0 A.zero
         MaybeCont.withBool nonEmpty $ do
            frame <- Memory.load pos0
            pos1 <- advanceArrayElementPtr pos0
            remaining1 <- A.dec remaining0
            return (frame,(pos1,remaining1)))
      (return ())
      (\_ ->
         return
            (valueOf ptr,
             valueOf (fromIntegral len :: Word32)))
      -- keep the foreign ptr alive
      (return (fp, ()))
      touchForeignPtr
   where
      (fp,ptr,len) = SVU.unsafeToPointers xs

{- |
Turn a lazy (chunky) storable vector into a signal.

The generated code calls back into the Haskell function @nextChunk@
that returns a pointer to the data of the next chunk
and advances to the next chunk in the sequence.
The signal of chunks is provided by @storableVectorChunks@
and is flattened to a signal of frames by @flattenChunks@.
-}
fromStorableVectorLazy ::
   (Storable a, MakeValueTuple a, ValueTuple a ~ value, Memory.C value) =>
   SVL.Vector a ->
   T value
fromStorableVectorLazy sig =
   flattenChunks (storableVectorChunks sig)

{-
Signal of the chunks of a lazy storable vector.
Every frame consists of a pointer to the chunk data and a 'Word32'
(presumably the chunk length - to be confirmed at storableVectorNextChunk).

The chunk iterator created by ChunkIt.new is used twice:
as IO context, that is released with ChunkIt.dispose,
and as parameter for the generated code.
-}
storableVectorChunks ::
   (Storable a, MakeValueTuple a, ValueTuple a ~ value, Memory.C value,
    Memory.Struct value ~ struct) =>
   SVL.Vector a ->
   T (Value (Ptr struct), Value Word32)
storableVectorChunks sig =
   Cons
      (storableVectorNextChunk "Simple.Signal.fromStorableVectorLazy.nextChunk")
      LLVM.alloca
      (\_ -> return ())
      (do
         iterator <- ChunkIt.new sig
         return (iterator, iterator))
      ChunkIt.dispose


{-
Dynamic foreign import for the block filling routine.
It is passed to Exec.createFunction in 'compile'
together with the function name "fillsignalblock".
The imported function takes a 'Word32' and a buffer pointer
and returns a 'Word32' in IO.
(Exec.Importer presumably maps a function pointer
to a callable Haskell function - to be confirmed at its definition.)
-}
foreign import ccall safe "dynamic" derefFillPtr ::
   Exec.Importer (Word32 -> Ptr struct -> IO Word32)


{-
Compile a routine that fills a buffer with frames of a signal.

The generated function "fillsignalblock" receives
a size and a pointer to the buffer.
It creates the initial state and the local storage once.
Then MaybeCont.arrayLoop runs over the buffer:
every iteration computes a frame with the first argument
and stores it at the current buffer position.
The position delivered by the loop is the return value.
-}
compile ::
   (Memory.C value, Memory.Struct value ~ struct,
    Memory.C state, Memory.Struct state ~ stateStruct) =>
   (forall r z. (Phi z) => local -> state -> MaybeCont.T r z (value, state)) ->
   (forall r. CodeGenFunction r local) ->
   (forall r. CodeGenFunction r state) ->
   IO (Word32 -> Ptr struct -> IO Word32)
compile next allocLocal start =
   Exec.compileModule
      (Exec.createFunction derefFillPtr "fillsignalblock"
         (\ len buffer -> do
            initial <- start
            local <- allocLocal
            (written,_) <-
               MaybeCont.arrayLoop len buffer initial
                  (\ elemPtr state0 -> do
                     (frame,state1) <- next local state0
                     MaybeCont.lift (Memory.store frame elemPtr)
                     return state1)
            ret written))

{- |
Render a signal into a strict storable vector of at most @len@ frames.
The vector is trimmed to the number of frames
that the compiled fill routine reports.
The IO context of the signal is created before
and released after rendering by means of 'bracket'.

This parameter order would allow us to compile the code once
and apply it to different signal lengths.
However, we do not make use of this and instead bake
parts of the IO context into the code to allow constant folding.
The parameter order is consistent with that of @Parameterized.Signal.render@.
-}
render ::
   (Storable a, MakeValueTuple a, ValueTuple a ~ value, Memory.C value) =>
   T value -> Int -> SV.Vector a
render (Cons next alloca start createIOContext deleteIOContext) len =
   Unsafe.performIO $
   bracket createIOContext (\context -> deleteIOContext (fst context)) $
   \ (_ioContext, params) ->
      SVB.createAndTrim len $ \ ptr -> do
         fill <-
            compile
               (next (valueTupleOf params))
               alloca
               (start (valueTupleOf params))
         written <- fill (fromIntegral len) (Memory.castTuplePtr ptr)
         return (fromIntegral (written :: Word32) :: Int)


{-
Dynamic foreign imports for the three routines built by 'compileChunky'.
(Exec.Importer presumably maps a function pointer
to a callable Haskell function - to be confirmed at its definition.)
-}

-- used for "startsignal": returns a pointer to a newly allocated state
foreign import ccall safe "dynamic" derefStartPtr ::
   Exec.Importer (IO (Ptr a))

-- used for "stopsignal": takes the state pointer and returns nothing
foreign import ccall safe "dynamic" derefStopPtr ::
   Exec.Importer (Ptr a -> IO ())

-- used for "fillsignal": takes state pointer, a 'Word32' and a buffer pointer
foreign import ccall safe "dynamic" derefChunkPtr ::
   Exec.Importer (Ptr stateStruct -> Word32 -> Ptr struct -> IO Word32)


{-
Compile the three routines needed for chunk-wise rendering of a signal
and return them as a triple:

1. "startsignal" allocates a state record with LLVM.malloc,
   stores the result of the @start@ argument in it
   and returns the pointer to the record.

2. "stopsignal" is a finalizer that releases the state record with LLVM.free.

3. "fillsignal" loads the state from the record,
   runs MaybeCont.arrayLoop over the output buffer
   storing one frame per iteration,
   writes the resulting state back to the record
   and returns the position delivered by the loop.
-}
compileChunky ::
   (Memory.C value, Memory.Struct value ~ struct,
    Memory.C state, Memory.Struct state ~ stateStruct) =>
   (forall r z.
    (Phi z) =>
    local -> state -> MaybeCont.T r z (value, state)) ->
   (forall r. CodeGenFunction r local) ->
   (forall r.
    CodeGenFunction r state) ->
   IO (IO (Ptr stateStruct),
       Exec.Finalizer stateStruct,
       Ptr stateStruct -> Word32 -> Ptr struct -> IO Word32)
compileChunky next alloca start =
   Exec.compileModule $
      liftA3 (,,)
         -- start routine: allocate the state record and initialize it
         (Exec.createFunction derefStartPtr "startsignal" $
          do
             pptr <- LLVM.malloc
             flip Memory.store pptr =<< start
             ret pptr)
{- for debugging: allocation with initialization makes type inference difficult
         (Exec.createFunPtr "startsignal" $
          do
             pptr <- malloc
             let retn :: CodeGenFunction r state -> Value (Ptr state) -> CodeGenFunction (Ptr state) ()
                 retn _ ptr = ret ptr
             retn undefined pptr)
-}
         -- stop routine: free the state record
         (Exec.createFinalizer derefStopPtr "stopsignal" $
          \ pptr -> LLVM.free pptr >> ret ())
         -- fill routine: continue the signal from the stored state
         (Exec.createFunction derefChunkPtr "fillsignal" $
          \ sptr loopLen ptr -> do
             sInit <- Memory.load sptr
             local <- alloca
             (pos,sExit) <- MaybeCont.arrayLoop loopLen ptr sInit $
              \ ptri s0 -> do
                (y,s1) <- next local s0
                MaybeCont.lift $ Memory.store y ptri
                return s1
             -- write the state back such that the next call resumes from here
             Memory.store (Maybe.fromJust sExit) sptr
             ret pos)


{- |
Render a signal as lazy storable vector with the given chunk size.

The IO context is created and the code is compiled immediately,
whereas the chunks are computed on demand via @Unsafe.interleaveIO@.
A chunk of length zero is not added to the result
and a chunk shorter than the chunk size ends the sequence.

NOTE(review):
The IO context is created before 'compileChunky' is called
but @deleteIOContext@ is registered only after that call.
It looks like the context is not released
if compilation or the creation of the state pointer throws - to be confirmed.
-}
runChunky ::
   (Storable a, MakeValueTuple a, ValueTuple a ~ value, Memory.C value) =>
   T value -> SVL.ChunkSize -> IO (SVL.Vector a)
runChunky (Cons next alloca start createIOContext deleteIOContext)
      (SVL.ChunkSize size) = do
   (ioContext, params) <- createIOContext
   (startFunc, stopFunc, fill) <-
      compileChunky
         (next $ valueTupleOf params) alloca (start $ valueTupleOf params)

   -- presumably: run startFunc to obtain the state
   -- and attach stopFunc as finalizer - confirm at ForeignPtr.newInit
   statePtr <- ForeignPtr.newInit stopFunc startFunc
   -- presumably: a foreign pointer whose finalizer
   -- runs the given action - confirm at ForeignPtr.newAux
   ioContextPtr <- ForeignPtr.newAux (deleteIOContext ioContext)

   let go =
         Unsafe.interleaveIO $ do
            -- fill a buffer of 'size' frames
            -- and trim it to the count returned by 'fill'
            v <-
               withForeignPtr statePtr $ \sptr ->
               SVB.createAndTrim size $
               fmap (fromIntegral :: Word32 -> Int) .
               fill sptr (fromIntegral size) .
               Memory.castTuplePtr
            -- keep the IO context alive until the chunk is filled
            touchForeignPtr ioContextPtr
            -- omit an empty chunk from the result
            (if SV.length v > 0
               then fmap (v:)
               else id) $
               -- an incomplete chunk is the last one
               (if SV.length v < size
                  then return []
                  else go)
   fmap SVL.fromChunks go

{- |
Variant of 'runChunky' outside of 'IO':
The action is run with @Unsafe.performIO@.
Note that the argument order is flipped with respect to 'runChunky',
that is, the chunk size is the first argument and the signal the second one.
-}
renderChunky ::
   (Storable a, MakeValueTuple a, ValueTuple a ~ value, Memory.C value) =>
   SVL.ChunkSize -> T value -> SVL.Vector a
renderChunky size sig =
   Unsafe.performIO $ runChunky sig size