{-# LANGUAGE NoImplicitPrelude #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE ExistentialQuantification #-}
{-# LANGUAGE Rank2Types #-}
{-# LANGUAGE ForeignFunctionInterface #-}
module Synthesizer.LLVM.Simple.Signal (
   C(simple),
   T,
   amplify,
   amplifyStereo,
   constant,
   envelope,
   envelopeStereo,
   exponential2,
   iterate,
   map,
   mapAccum,
   mix,
   mixExt,
   takeWhile,
   empty,
   append,
   osci,
   osciPlain,
   osciSaw,
   zip,
   zipWith,

   fromStorableVector,
   fromStorableVectorLazy,

   render,
   renderChunky,
   runChunky,
   ) where

import Synthesizer.LLVM.Simple.SignalPrivate hiding (alloca)

import qualified Synthesizer.LLVM.Frame.Stereo as Stereo
import qualified Synthesizer.LLVM.Frame as Frame
import qualified Synthesizer.LLVM.Wave as Wave
import qualified Synthesizer.LLVM.ForeignPtr as ForeignPtr

import qualified Synthesizer.LLVM.Storable.ChunkIterator as ChunkIt
import qualified Synthesizer.LLVM.Storable.Vector as SVU
import qualified Data.StorableVector.Lazy as SVL
import qualified Data.StorableVector as SV
import qualified Data.StorableVector.Base as SVB

import qualified LLVM.DSL.Execution as Exec

import qualified LLVM.Extra.Storable as Storable
import qualified LLVM.Extra.Memory as Memory
import qualified LLVM.Extra.ScalarOrVector as SoV
import qualified LLVM.Extra.MaybeContinuation as MaybeCont
import qualified LLVM.Extra.Maybe as Maybe
import qualified LLVM.Extra.Arithmetic as A
import qualified LLVM.Extra.Tuple as Tuple

import qualified LLVM.Core as LLVM
import LLVM.Core
          (CodeGenFunction, ret, Value, valueOf,
           IsFirstClass, IsSized, IsConst, IsArithmetic)

import Control.Monad (liftM2)
import Control.Applicative (pure, liftA2, liftA3, (<$>))

import Data.Monoid (Monoid, mappend)

import qualified Algebra.Transcendental as Trans

import qualified System.Unsafe as Unsafe
import Foreign.ForeignPtr (touchForeignPtr)
import Foreign.Ptr (Ptr)
import Data.Word (Word)
import Control.Exception (bracket)

import NumericPrelude.Numeric
import NumericPrelude.Base hiding (and, iterate, map, zip, zipWith, takeWhile)


-- | Signal built from a single constant: the LLVM constant for @x@
-- is lifted into the signal with 'pure'.
constant :: (C signal, IsConst a) => a -> signal (Value a)
constant = pure . valueOf

{- |
Map over a signal while threading an accumulator.

@f@ receives the current input element and the current accumulator
and returns the output element together with the updated accumulator.
@startS@ generates the initial accumulator.
The result ends when the input signal ends.
-}
mapAccum ::
   (C signal, Memory.C s) =>
   (forall r. a -> s -> CodeGenFunction r (b,s)) ->
   (forall r. CodeGenFunction r s) ->
   signal a -> signal b
mapAccum f startS =
   alter
      (\(Core nextIn startIn stopIn) ->
         Core
            -- advance the input, then feed its element through @f@
            (\context (inState0, acc0) -> do
               (x, inState1) <- nextIn context inState0
               (y, acc1) <- MaybeCont.lift (f x acc0)
               return (y, (inState1, acc1)))
            -- the combined state is (input state, accumulator)
            (\context -> liftA2 (,) (startIn context) startS)
            (stopIn . fst))


{- |
Mix two signals element-wise using 'Frame.mix'.

Warning:
The result is only as long as the shorter of the two input signals.
This is consistent with @Causal.mix@, but it may not be what you expect.
Consider using 'mixExt' instead.
-}
mix ::
   (C signal, A.Additive a) =>
   signal a -> signal a -> signal a
mix xs ys = zipWith Frame.mix xs ys

{- |
Mix two signals such that the result is as long as the longer input signal.

Each input is tagged with a @True@ flag and then continued with
@(False, zero)@ elements.  The padded signals are added element-wise,
and the sum is kept for as long as at least one of the flags is @True@.
-}
mixExt ::
   (C signal, Monoid (signal (Value Bool, a)),
    A.Additive a, Tuple.Phi a, Tuple.Undefined a) =>
   signal a -> signal a -> signal a
mixExt xs ys =
   fmap snd $
   takeWhile (return . fst) $
   zipWith
      (\(validX, x) (validY, y) ->
         liftA2 (,) (A.or validX validY) (A.add x y))
      (padded xs) (padded ys)
   where
      -- tag the real elements as valid and append invalid zero padding
      padded zs =
         mappend
            (fmap ((,) (valueOf True)) zs)
            (pure (valueOf False, A.zero))


{- |
Keep the elements of a signal while the predicate holds.
The signal ends at the first element for which the predicate yields @False@,
or earlier if the input ends.

You could apply @Causal.takeWhile@ instead,
but that requires a rather complex type signature,
including a @process@ type variable that is of no interest to the user.
-}
takeWhile ::
   (C signal) =>
   (forall r. a -> CodeGenFunction r (Value Bool)) ->
   signal a -> signal a
takeWhile p =
   alter
      (\(Core nextIn startIn stopIn) ->
         Core
            (\context state0 -> do
               (x, state1) <- nextIn context state0
               keep <- MaybeCont.lift (p x)
               MaybeCont.guard keep
               return (x, state1))
            startIn
            stopIn)


-- | Apply a mono envelope (first argument) to a signal (second argument),
-- element by element, via 'Frame.amplifyMono'.
-- Like 'zipWith', the result ends with the shorter input.
envelope ::
   (C signal, A.PseudoRing a) =>
   signal a -> signal a -> signal a
envelope env xs = zipWith Frame.amplifyMono env xs

-- | Apply a mono envelope (first argument) to a stereo signal,
-- element by element, via 'Frame.amplifyStereo'.
-- Like 'zipWith', the result ends with the shorter input.
envelopeStereo ::
   (C signal, A.PseudoRing a) =>
   signal a -> signal (Stereo.T a) -> signal (Stereo.T a)
envelopeStereo env xs = zipWith Frame.amplifyStereo env xs

-- | Scale every element of a mono signal by the constant gain @x@,
-- using 'Frame.amplifyMono'.
amplify ::
   (C signal, IsArithmetic a, IsConst a) =>
   a -> signal (Value a) -> signal (Value a)
amplify x =
   let gain = valueOf x
   in  map (Frame.amplifyMono gain)

-- | Scale both channels of every element of a stereo signal
-- by the constant gain @x@, using 'Frame.amplifyStereo'.
amplifyStereo ::
   (C signal, IsArithmetic a, IsConst a) =>
   a -> signal (Stereo.T (Value a)) -> signal (Stereo.T (Value a))
amplifyStereo x =
   let gain = valueOf x
   in  map (Frame.amplifyStereo gain)



{- |
@iterate f y0@ is the infinite signal @y0, f y0, f (f y0), ...@.
The current value is emitted, and @f@ computes the next state from it.
-}
iterate ::
   (C signal, IsFirstClass a, IsSized a, IsConst a) =>
   (forall r. Value a -> CodeGenFunction r (Value a)) ->
   Value a -> signal (Value a)
iterate f initial =
   simple
      (\current -> MaybeCont.lift ((,) current <$> f current))
      (return initial)

{- |
@exponential2 halfLife start@ is an exponential curve that begins at @start@.
Each element is the previous one multiplied by @0.5 ** recip halfLife@,
so the value halves every @halfLife@ samples.
-}
exponential2 ::
   (C signal, Trans.C a, IsArithmetic a, IsSized a, IsConst a) =>
   a -> a -> signal (Value a)
exponential2 halfLife start =
   let factor = valueOf (0.5 ** recip halfLife)
   in  iterate (\y -> A.mul y factor) (valueOf start)


{- |
Oscillator with phase and frequency given as LLVM values.
The phase starts at @phase@ and is advanced by @freq@ per sample
with 'SoV.incPhase'.  Each phase value is mapped through @wave@.
-}
osciPlain ::
   (C signal, SoV.Fraction t, IsSized t, IsConst t) =>
   (forall r. Value t -> CodeGenFunction r y) ->
   Value t -> Value t -> signal y
osciPlain wave phase freq =
   map wave (iterate (SoV.incPhase freq) phase)

-- | Like 'osciPlain', but with Haskell constants for phase and frequency.
-- They are turned into LLVM constants with 'valueOf'.
osci ::
   (C signal, SoV.Fraction t, IsSized t, IsConst t) =>
   (forall r. Value t -> CodeGenFunction r y) ->
   t -> t -> signal y
osci wave phase freq =
   let phaseV = valueOf phase
       freqV = valueOf freq
   in  osciPlain wave phaseV freqV

-- | Sawtooth oscillator: 'osci' with the 'Wave.saw' waveform,
-- taking the start phase and the frequency.
osciSaw ::
   (C signal, SoV.IntegerConstant a, SoV.Fraction a, IsSized a, IsConst a) =>
   a -> a -> signal (Value a)
osciSaw phase freq = osci Wave.saw phase freq



{- |
Turn a strict storable vector into a signal.
The signal emits the vector's elements in order and ends after the last one.

The vector's start pointer and length are baked into the generated code
as constants.  The vector's foreign pointer is used as the I/O context,
and the I/O context is deleted with 'touchForeignPtr'.
This keeps the vector's memory alive until the I/O context is deleted.
-}
fromStorableVector ::
   (Storable.C a, Tuple.ValueOf a ~ value) => SV.Vector a -> T value
fromStorableVector xs =
   let (fp,ptr,l) = SVU.unsafeToPointers xs
   in  Cons
          -- next: state is (current pointer, remaining length);
          -- stop once the remaining length is no longer positive,
          -- otherwise emit the current element and advance both
          (\_ () (p0,l0) -> do
             cont <- MaybeCont.lift $ A.cmp LLVM.CmpGT l0 A.zero
             MaybeCont.withBool cont $ do
                y1 <- Storable.load p0
                p1 <- Storable.incrementPtr p0
                l1 <- A.dec l0
                return (y1,(p1,l1)))
          -- no local storage is needed
          (return ())
          -- initial state: start pointer and element count as constants
          (const $ return
             (valueOf ptr,
              valueOf (fromIntegral l :: Word)))
          -- keep the foreign ptr alive:
          -- it is the I/O context that the deletion action below receives
          (return (fp, ()))
          touchForeignPtr

{- |
Turn a lazy storable vector into a signal.

'storableVectorChunks' exposes the vector chunk by chunk.
To get a chunk, the generated code calls back into Haskell; the callback
returns a pointer to the data of the next chunk and advances
to the next chunk in the sequence.
'flattenChunks' turns the resulting sequence of chunks into a flat signal.
-}
fromStorableVectorLazy ::
   (Storable.C a, Tuple.ValueOf a ~ value) => SVL.Vector a -> T value
fromStorableVectorLazy vec = flattenChunks (storableVectorChunks vec)

{- |
Signal of the chunks of a lazy storable vector.
Each element is a pointer together with a 'Word', presumably the chunk's
start address and element count (confirm against 'storableVectorNextChunk').

The I/O context is a chunk iterator created by 'ChunkIt.new',
and it is released with 'ChunkIt.dispose'.
-}
storableVectorChunks ::
   (Storable.C a) => SVL.Vector a -> T (Value (Ptr a), Value Word)
storableVectorChunks vec =
   Cons
      (storableVectorNextChunk "Simple.Signal.fromStorableVectorLazy.nextChunk")
      LLVM.alloca
      (\_params -> return ())
      (fmap (\iterator -> (iterator, iterator)) (ChunkIt.new vec))
      ChunkIt.dispose


-- | \"dynamic\" import: turns a pointer to the compiled @fillsignalblock@
-- function (see 'compile') into a Haskell function.
-- The function takes the buffer capacity and the buffer pointer
-- and returns the number of elements written.
foreign import ccall safe "dynamic" derefFillPtr ::
   Exec.Importer (Word -> Ptr struct -> IO Word)


{- |
JIT-compile a function that fills a buffer with signal elements.

The resulting Haskell function takes the buffer capacity and the buffer
pointer, and it returns the number of elements it wrote.
'render' uses this count to trim the result, so it may be smaller
than the capacity when @next@ stops early.
Every call re-runs @start@, i.e. it begins from the initial state.
-}
compile ::
   (Storable.C a, Tuple.ValueOf a ~ value, Memory.C state) =>
   (forall r z.
    (Tuple.Phi z) => local -> state -> MaybeCont.T r z (value, state)) ->
   (forall r. CodeGenFunction r local) ->
   (forall r. CodeGenFunction r state) ->
   IO (Word -> Ptr a -> IO Word)
compile next alloca start =
   Exec.compile "signal" $
      Exec.createFunction derefFillPtr "fillsignalblock" $ \ size bPtr -> do
         s <- start
         -- @local@ is created once per call and shared by all 'next' steps
         -- (e.g. 'storableVectorChunks' passes 'LLVM.alloca' here)
         local <- alloca
         -- store each generated element; the final state is not needed
         (pos,_) <-
               Storable.arrayLoopMaybeCont size bPtr s $ \ ptri s0 -> do
            (y,s1) <- next local s0
            MaybeCont.lift $ Storable.store y ptri
            return s1
         ret pos

{-
This parameter order (signal first, length last) would allow us
to compile the code once and apply it to different signal lengths.
However, we do not make use of this and instead bake
parts of the IO context into the code to allow constant folding.
The parameter order is consistent with that of @Parameterized.Signal.render@.
-}
{- |
Render at most @len@ elements of a signal into a strict storable vector.
The result is shorter if the signal ends earlier.

The I/O context is created before rendering and deleted afterwards
(via 'bracket').  The generator is compiled anew on every call,
with the I/O context parameters baked in as constants.
-}
render ::
   (Storable.C a, Tuple.ValueOf a ~ value, Memory.C value) =>
   T value -> Int -> SV.Vector a
render (Cons next alloca start createIOContext deleteIOContext) len =
   -- rendering is exposed as a pure function
   Unsafe.performIO $
   bracket createIOContext (deleteIOContext . fst) $ \ (_ioContext, params) ->
   -- 'SVB.createAndTrim' trims the buffer to the element count returned by @fill@
   SVB.createAndTrim len $ \ ptr ->
      do fill <-
            compile
               (next $ Tuple.valueOf params) alloca (start $ Tuple.valueOf params)
         fmap (fromIntegral :: Word -> Int) $ fill (fromIntegral len) ptr


-- | \"dynamic\" import for the compiled @startsignal@ function
-- (see 'compileChunky'). It allocates and initialises the signal state
-- and returns a pointer to it.
foreign import ccall safe "dynamic" derefStartPtr ::
   Exec.Importer (IO (LLVM.Ptr a))

-- | \"dynamic\" import for the compiled @stopsignal@ function
-- (see 'compileChunky'), which frees the signal state.
foreign import ccall safe "dynamic" derefStopPtr ::
   Exec.Importer (LLVM.Ptr a -> IO ())

-- | \"dynamic\" import for the compiled @fillsignal@ function
-- (see 'compileChunky'). It takes the state pointer, the buffer capacity
-- and the buffer, writes elements, stores the updated state back,
-- and returns the number of elements written.
foreign import ccall safe "dynamic" derefChunkPtr ::
   Exec.Importer (LLVM.Ptr stateStruct -> Word -> Ptr struct -> IO Word)


{- |
JIT-compile the three functions needed for chunk-wise rendering:

* @startsignal@ allocates the signal state with 'LLVM.malloc',
  initialises it with @start@ and returns the pointer;

* @stopsignal@ frees that state and is wrapped as an 'Exec.Finalizer';

* @fillsignal@ loads the state, fills the buffer with at most the given
  number of elements, stores the state back and returns the number
  of elements written.
-}
compileChunky ::
   (Storable.C a, Tuple.ValueOf a ~ value,
    Memory.C state, Memory.Struct state ~ stateStruct) =>
   (forall r z.
    (Tuple.Phi z) =>
    local -> state -> MaybeCont.T r z (value, state)) ->
   (forall r. CodeGenFunction r local) ->
   (forall r. CodeGenFunction r state) ->
   IO (IO (LLVM.Ptr stateStruct),
       Exec.Finalizer stateStruct,
       LLVM.Ptr stateStruct -> Word -> Ptr a -> IO Word)
compileChunky next alloca start =
   Exec.compile "signal-chunky" $
      liftA3 (,,)
         -- start: allocate the state and write the initial state into it
         (Exec.createFunction derefStartPtr "startsignal" $
          do
             pptr <- LLVM.malloc
             flip Memory.store pptr =<< start
             ret pptr)
{- for debugging: allocation with initialization makes type inference difficult
         (Exec.createFunPtr "startsignal" $
          do
             pptr <- malloc
             let retn :: CodeGenFunction r state -> Value (Ptr state) -> CodeGenFunction (Ptr state) ()
                 retn _ ptr = ret ptr
             retn undefined pptr)
-}
         -- stop: release the state allocated by startsignal
         (Exec.createFinalizer derefStopPtr "stopsignal" $
          \ pptr -> LLVM.free pptr >> ret ())
         -- fill: continue from the stored state and write it back afterwards
         (Exec.createFunction derefChunkPtr "fillsignal" $
          \ sptr loopLen ptr -> do
             sInit <- Memory.load sptr
             local <- alloca
             (pos,sExit) <-
                Storable.arrayLoopMaybeCont loopLen ptr sInit $
                   \ ptri s0 -> do
                (y,s1) <- next local s0
                MaybeCont.lift $ Storable.store y ptri
                return s1
             -- NOTE(review): when the signal ended inside this chunk, sExit is
             -- presumably invalid and 'Maybe.fromJust' stores an unspecified
             -- state; this relies on callers not calling fillsignal again after
             -- a short chunk ('runChunky' stops there) - confirm.
             Memory.store (Maybe.fromJust sExit) sptr
             ret pos)


{- |
Compile a signal and run it lazily.
The result is a lazy storable vector made of chunks of the given size.

Each chunk is computed on demand (via 'Unsafe.interleaveIO').
The signal state lives in memory that is managed by a 'ForeignPtr'
created with 'ForeignPtr.newInit'.
The deletion of the I/O context is handed to 'ForeignPtr.newAux' right after
the context is created, so it is not lost if compilation fails.
(Presumably it runs as a finalizer once the pointer is unreachable - confirm.)
The chunk list ends after the first chunk that is shorter than the chunk size.
Empty chunks are not included.

The chunk size must be positive; otherwise an 'IOError' is thrown.
Without this check, a zero chunk size would make the chunk loop never terminate.
-}
runChunky ::
   (Storable.C a, Tuple.ValueOf a ~ value) =>
   T value -> SVL.ChunkSize -> IO (SVL.Vector a)
runChunky _ (SVL.ChunkSize size)
   | size <= 0 =
      ioError $ userError $
         "Simple.Signal.runChunky: chunk size must be positive, but is " ++
         show size
runChunky (Cons next alloca start createIOContext deleteIOContext)
      (SVL.ChunkSize size) = do
   (ioContext, params) <- createIOContext
   -- Register the clean-up of the I/O context immediately,
   -- so that it is not leaked if compiling or starting the signal throws.
   ioContextPtr <- ForeignPtr.newAux (deleteIOContext ioContext)
   (startFunc, stopFunc, fill) <-
      compileChunky
         (next $ Tuple.valueOf params) alloca (start $ Tuple.valueOf params)

   statePtr <- ForeignPtr.newInit stopFunc startFunc

   let go =
         Unsafe.interleaveIO $ do
            chunk <-
               ForeignPtr.with statePtr $ \sptr ->
               SVB.createAndTrim size $
               fmap (fromIntegral :: Word -> Int) .
               fill sptr (fromIntegral size)
            -- the compiled code may refer to the I/O context,
            -- so keep it alive at least until this chunk is filled
            touchForeignPtr ioContextPtr
            -- a short chunk means the signal has ended
            rest <-
               if SV.length chunk < size
                 then return []
                 else go
            return $ if SV.null chunk then rest else chunk : rest
   fmap SVL.fromChunks go

{- |
Pure variant of 'runChunky'.
It renders a signal into a lazy storable vector using chunks of the given size.
-}
renderChunky ::
   (Storable.C a, Tuple.ValueOf a ~ value) =>
   SVL.ChunkSize -> T value -> SVL.Vector a
renderChunky size = Unsafe.performIO . flip runChunky size