{-# LANGUAGE NoImplicitPrelude #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE FunctionalDependencies #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE UndecidableInstances #-}
{-# LANGUAGE ExistentialQuantification #-}
{-# LANGUAGE Rank2Types #-}
{-# LANGUAGE ForeignFunctionInterface #-}
module Synthesizer.LLVM.Simple.Signal where

import qualified LLVM.Extra.Representation as Rep
import qualified Synthesizer.LLVM.Wave as Wave
import qualified Synthesizer.LLVM.Sample as Sample
import qualified Synthesizer.LLVM.Execution as Exec
import qualified LLVM.Extra.ScalarOrVector as SoV
import qualified LLVM.Extra.MaybeContinuation as Maybe

import qualified Synthesizer.LLVM.Storable.ChunkIterator as ChunkIt
import qualified Data.StorableVector.Lazy as SVL
import qualified Data.StorableVector as SV
import qualified Data.StorableVector.Base as SVB

import qualified Synthesizer.LLVM.Frame.Stereo as Stereo

import qualified LLVM.Extra.Arithmetic as A
import LLVM.Extra.Arithmetic (advanceArrayElementPtr, )
import LLVM.Extra.Control (whileLoop, ifThen, )

import LLVM.Core
import LLVM.Util.Loop (Phi, )

import Control.Monad (liftM2, liftM3, )

import qualified Algebra.Transcendental as Trans
import qualified Algebra.Field as Field
import qualified Algebra.Ring as Ring

import Data.Word (Word32, )
import Foreign.Storable.Tuple ()
import Foreign.Storable (Storable, )
import Foreign.Marshal.Array (advancePtr, )
import qualified Foreign.Marshal.Array as Array
import qualified Foreign.Marshal.Alloc as Alloc
import Foreign.ForeignPtr
          (unsafeForeignPtrToPtr, touchForeignPtr, withForeignPtr, )
import Foreign.Ptr (FunPtr, nullPtr, )
import Control.Exception (bracket, bracketOnError, finally, )
import System.IO.Unsafe (unsafePerformIO, unsafeInterleaveIO, )

import NumericPrelude.Numeric
import NumericPrelude.Base hiding (and, iterate, map, zip, zipWith, )


{-
We need the forall quantification for 'CodeGenFunction's @r@ parameter.
This type parameter will be unified with the result type of the final function.
Since one piece of code can be used in multiple functions
we cannot yet fix the type @r@ here.

We might avoid code duplication by defining

> newtype T a = Cons (Causal.T () a)
-}
{- |
A signal generator that emits values of type @a@.
It bundles the LLVM code generators for the signal
with the IO actions that manage the context this code refers to.
The types of the state and of the context are hidden.
-}
data T a =
   forall s struct sz context.
      (Rep.Memory s struct, IsSized struct sz) =>
      Cons
         (forall r c.
          (Phi c) =>
          context -> s -> Maybe.T r c (a, s))
            -- code for one step: next value and updated state
         (forall r.
          context -> CodeGenFunction r s)
            -- code that yields the initial state
         (IO context)
            {- creation of the context in the IO monad.
            This will be run within unsafePerformIO,
            so no observable In/Out actions please!
            -}
         (context -> IO ())
            -- release of the context in the IO monad,
            -- also run within unsafePerformIO

{- |
Build a signal from a step function and an initial state
when no IO context is needed.
The context is the unit value and both context actions do nothing.
-}
simple ::
   (Rep.Memory state packed, IsSized packed size) =>
   (forall r c.
    state -> Maybe.T r c (a, state)) ->
   (forall r. CodeGenFunction r state) ->
   T a
simple next start =
   Cons
      (\_ state -> next state)
      (\_ -> start)
      (return ())
      (\_ -> return ())


{- |
Apply a code generating function to every value of a signal.
State handling and context management of the input are reused unchanged.
-}
map ::
   (forall r. a -> CodeGenFunction r b) -> T a -> T b
map f (Cons next start create delete) =
   Cons
      (\context s0 ->
         next context s0 >>= \(a,s1) ->
         Maybe.lift (f a) >>= \b ->
         return (b, s1))
      start
      create delete

{- |
Like 'map', but the mapping function additionally threads
an accumulator of type @s@ through the signal.
The accumulator starts with the value generated by the second argument
and is stored next to the state of the input signal.
-}
mapAccum ::
   (Rep.Memory s struct, IsSized struct sa) =>
   (forall r. a -> s -> CodeGenFunction r (b,s)) ->
   (forall r. CodeGenFunction r s) ->
   T a -> T b
mapAccum f startS (Cons next start create delete) =
   Cons
      (\context (sigState0, acc0) ->
         next context sigState0 >>= \(a, sigState1) ->
         Maybe.lift (f a acc0) >>= \(b, acc1) ->
         return (b, (sigState1, acc1)))
      (\context -> do
         sigState <- start context
         acc <- startS
         return (sigState, acc))
      create delete


{- |
Combine two signals value by value with a code generating function.

The state of the result is the pair of the states of both inputs
and the IO context is the pair of both IO contexts.
In every step first the left and then the right input is advanced.
Both steps are sequenced in 'Maybe.T',
so presumably the result ends as soon as one of the inputs ends.

The context management is exception safe:
If the creation of the right context fails,
the already created left context is released.
When releasing, the right context is released
even if releasing the left one throws an exception.
-}
zipWith ::
   (forall r. a -> b -> CodeGenFunction r c) -> T a -> T b -> T c
zipWith f
      (Cons nextA startA createIOContextA deleteIOContextA)
      (Cons nextB startB createIOContextB deleteIOContextB) =
   Cons
      (\(ioContextA, ioContextB) (sa0,sb0) -> do
         (a,sa1) <- nextA ioContextA sa0
         (b,sb1) <- nextB ioContextB sb0
         c <- Maybe.lift $ f a b
         return (c, (sa1,sb1)))
      (\(ioContextA, ioContextB) ->
         liftM2 (,)
            (startA ioContextA)
            (startB ioContextB))
      -- do not leak the first context if the second one cannot be created
      (bracketOnError createIOContextA deleteIOContextA $ \ca ->
         fmap ((,) ca) createIOContextB)
      -- release the second context even if releasing the first one fails
      (\(ca,cb) ->
         deleteIOContextA ca `finally` deleteIOContextB cb)

{- |
Pair the values of two signals using 'zipWith'.
-}
zip ::
   T a -> T b -> T (a,b)
zip sigA sigB =
   zipWith (\x y -> return (x, y)) sigA sigB


{- |
Stretch signal in time by a certain factor.

The state consists of the most recently fetched input value
together with the state of the input signal, and a counter @ss@.
In every step, input values are fetched as long as the counter
is at most zero, and each fetch increases the counter by @k@.
After that the counter is decreased by one
and the most recently fetched value is emitted.
If the input signal cannot deliver a value, the step fails.
-}
interpolateConstant ::
   (Rep.Memory a struct, IsSized struct size,
    Ring.C b,
    IsFloating b, CmpRet b Bool,
    IsConst b, IsFirstClass b, IsSized b sb) =>
   b -> T a -> T a
interpolateConstant k
      (Cons next start createIOContext deleteIOContext) =
   Cons
      (\ioContext ((y0,state0),ss0) ->
         do ((y1,state1), ss1) <-
               Maybe.fromBool $
               whileLoop
                  -- loop state: success flag, (held value, input state), counter
                  (valueOf True, ((y0,state0), ss0))
                  -- continue while the last fetch succeeded and counter <= 0
                  (\(cont1, (_, ss1)) ->
                     and cont1 =<< A.fcmp FPOLE ss1 (valueOf 0))
                  -- fetch the next input value and add k to the counter
                  (\(_, ((_,state01), ss1)) ->
                     Maybe.toBool $ liftM2 (,)
                        (next ioContext state01)
                        (Maybe.lift $ A.add ss1 (valueOf k)))

            -- emitting one value costs one unit of the counter
            ss2 <- Maybe.lift $ A.sub ss1 (valueOf Ring.one)
            return (y1, ((y1,state1),ss2)))

{- using this initialization code we would not need undefined values
      (do sa <- start
          (a,_) <- next sa
          return (sa, a, valueOf 0))
-}
      -- The held value is undefined initially.
      -- The counter starts at zero, so the first step fetches a value
      -- before anything is emitted.
      (fmap (\sa -> ((undefTuple, sa), valueOf 0)) . start)
      createIOContext deleteIOContext


{- |
Combine two mono signals value by value using 'Sample.mixMono'.
-}
mix ::
   (IsArithmetic a) =>
   T (Value a) -> T (Value a) -> T (Value a)
mix xs ys =
   zipWith Sample.mixMono xs ys

{- |
Combine two stereo signals value by value using 'Sample.mixStereo'.
-}
mixStereo ::
   (IsArithmetic a) =>
   T (Stereo.T (Value a)) -> T (Stereo.T (Value a)) -> T (Stereo.T (Value a))
mixStereo xs ys =
   zipWith Sample.mixStereo xs ys


{- |
Combine an envelope signal (first argument) with a mono signal
value by value using 'Sample.amplifyMono'.
-}
envelope ::
   (IsArithmetic a) =>
   T (Value a) -> T (Value a) -> T (Value a)
envelope env sig =
   zipWith Sample.amplifyMono env sig

{- |
Combine a mono envelope signal (first argument) with a stereo signal
value by value using 'Sample.amplifyStereo'.
-}
envelopeStereo ::
   (IsArithmetic a) =>
   T (Value a) -> T (Stereo.T (Value a)) -> T (Stereo.T (Value a))
envelopeStereo env sig =
   zipWith Sample.amplifyStereo env sig

{- |
Apply 'Sample.amplifyMono' with a constant first argument
to every value of a mono signal.
-}
amplify ::
   (IsArithmetic a, IsConst a) =>
   a -> T (Value a) -> T (Value a)
amplify x sig =
   let gain = valueOf x
   in  map (\y -> Sample.amplifyMono gain y) sig

{- |
Apply 'Sample.amplifyStereo' with a constant first argument
to every value of a stereo signal.
-}
amplifyStereo ::
   (IsArithmetic a, IsConst a) =>
   a -> T (Stereo.T (Value a)) -> T (Stereo.T (Value a))
amplifyStereo x sig =
   let gain = valueOf x
   in  map (\y -> Sample.amplifyStereo gain y) sig



{- |
Signal of the values obtained by repeatedly applying
a code generating function, beginning with the initial value.
In every step the current state is emitted
and the function result becomes the new state.
-}
iterate ::
   (IsFirstClass a, IsSized a s, IsConst a) =>
   (forall r. Value a -> CodeGenFunction r (Value a)) ->
   Value a -> T (Value a)
iterate f initial =
   simple
      (\current -> Maybe.lift $ do
          following <- f current
          return (current, following))
      (return initial)

{- |
Signal that begins with the second argument
and is multiplied by @0.5 ** recip halfLife@ in every step.
-}
exponential2 ::
   (Trans.C a,
    IsFirstClass a, IsSized a s, IsArithmetic a, IsConst a) =>
   a -> a -> T (Value a)
exponential2 halfLife initial =
   iterate (\y -> A.mul y factor) (valueOf initial)
   where
      -- constant factor applied from one value to the next
      factor = valueOf (0.5 ** recip halfLife)


{- |
Oscillator with phase and frequency given as LLVM values.
The phase is iterated with 'SoV.incPhase'
and the wave function is applied to every phase value.
-}
osciPlain ::
   (IsFirstClass t, IsSized t size,
    SoV.Fraction t, IsConst t) =>
   (forall r. Value t -> CodeGenFunction r y) ->
   Value t -> Value t -> T y
osciPlain wave phase freq =
   map wave (iterate (SoV.incPhase freq) phase)

{- |
Like 'osciPlain', but phase and frequency are Haskell values
that are turned into LLVM constants.
-}
osci ::
   (IsFirstClass t, IsSized t size,
    SoV.Fraction t, IsConst t) =>
   (forall r. Value t -> CodeGenFunction r y) ->
   t -> t -> T y
osci wave phase freq =
   let phaseValue = valueOf phase
       freqValue = valueOf freq
   in  osciPlain wave phaseValue freqValue

{- |
Oscillator ('osci') with the wave function 'Wave.saw'.
-}
osciSaw ::
   (Ring.C a0, IsConst a0, SoV.Replicate a0 a,
    IsFirstClass a, IsSized a size,
    SoV.Fraction a, IsConst a) =>
   a -> a -> T (Value a)
osciSaw phase freq =
   osci Wave.saw phase freq



{- |
Signal that reads its values from a strict storable vector.
The state is a pointer to the next element
and the number of remaining elements.
The signal ends when no elements remain.
-}
fromStorableVector ::
   (Storable a, MakeValueTuple a value, Rep.Memory value struct) =>
   SV.Vector a ->
   T value
fromStorableVector xs =
   Cons
      (\_ (ptr0, remaining0) -> do
         nonEmpty <- Maybe.lift $ A.icmp IntUGT remaining0 (valueOf 0)
         Maybe.withBool nonEmpty $ do
            y <- Rep.load ptr0
            ptr1 <- advanceArrayElementPtr ptr0
            remaining1 <- A.dec remaining0
            return (y, (ptr1, remaining1)))
      (\_ ->
         return
            (valueOf
                (Rep.castStorablePtr
                    (unsafeForeignPtrToPtr fp `advancePtr` offset)),
             valueOf (fromIntegral len :: Word32)))
      -- the foreign pointer serves as context in order to keep it alive
      (return fp)
      touchForeignPtr
   where
      (fp, offset, len) = SVB.toForeignPtr xs

{- |
Signal that reads its values from a chunky storable vector.

The generated code calls back into the Haskell function
'ChunkIt.nextCallBack' that returns a pointer to the data of the next chunk
and advances to the next chunk in the sequence.
The length of that chunk is read from the memory cell @lenPtr@
after the call.

The state is a pointer to the next element
and the number of elements remaining in the current chunk.
The IO context is the chunk iterator created by 'ChunkIt.new'
together with the allocated memory cell for the chunk length.
-}
fromStorableVectorLazy ::
   (Storable a, MakeValueTuple a value, Rep.Memory value struct) =>
   SVL.Vector a ->
   T value
fromStorableVectorLazy sig =
   Cons
      (\(stable, lenPtr) (buffer0,length0) -> do
         (buffer1,length1) <- Maybe.lift $ do
            nextChunkFn <- staticFunction ChunkIt.nextCallBack
            -- request a new chunk only if the current one is used up
            needNext <- A.icmp IntEQ length0 (valueOf 0)
            ifThen needNext (buffer0,length0)
               (liftM2 (,)
                   (call nextChunkFn (valueOf stable) (valueOf lenPtr))
                   (load (valueOf lenPtr)))
         -- a null pointer is treated as the end of the signal
         valid <- Maybe.lift $ A.icmp IntNE buffer1 (valueOf nullPtr)
         Maybe.withBool valid $ do
            x <- Rep.load buffer1
            buffer2 <- advanceArrayElementPtr buffer1
            length2 <- A.dec length1
            return (x, (buffer2,length2)))
      -- start with length zero, such that the first step requests a chunk
      (const $ return (valueOf nullPtr, valueOf 0))
      (liftM2 (,) (ChunkIt.new sig) Alloc.malloc)
      (\(stable,lenPtr) -> do
          ChunkIt.dispose stable
          Alloc.free lenPtr)


{-
compile ::
   (Rep.Memory value struct) =>
   T value ->
   CodeGenModule (Function (Word32 -> Ptr struct -> IO Word32))
-}

{-
We could also implement that in terms of getPointerToFunction
as done in Parameterized.Signal.
However, since the 'fill' function will be called only once,
it does not matter whether we use the Just-In-Time compiler
or compile once.
-}
{- |
Render at most @len@ values of a signal into a strict storable vector.
The code is compiled and run within 'unsafePerformIO',
the IO context is managed by 'bracket'.
The vector is trimmed to the number reported by the generated function.
-}
render ::
   (Storable a, MakeValueTuple a value, Rep.Memory value struct) =>
   Int -> T value -> SV.Vector a
render len (Cons next start create delete) =
   unsafePerformIO $
   bracket create delete $ \context ->
   SVB.createAndTrim len $ \buffer -> do
      fill <-
         Exec.runFunction $
         createFunction ExternalLinkage $ \count out -> do
            s <- start context
            (written,_) <-
               Maybe.arrayLoop count out s $ \cell s0 -> do
                  (y,s1) <- next context s0
                  Maybe.lift $ Rep.store y cell
                  return s1
            ret (written :: Value Word32)
      filled <- fill (fromIntegral len) (Rep.castStorablePtr buffer)
      return (fromIntegral (filled :: Word32) :: Int)


{-
Dynamic import for the chunk filling function.
'runChunky' applies it to the third function pointer
returned by 'compileChunky' in order to call the compiled code
with the state pointer, the chunk size and the target buffer.
NOTE(review): 'Exec.Importer' presumably maps a 'FunPtr' of the given type
to a Haskell function of that type - confirm in Synthesizer.LLVM.Execution.
-}
foreign import ccall safe "dynamic" derefChunkPtr ::
   Exec.Importer (Ptr stateStruct -> Word32 -> Ptr struct -> IO Word32)


{- |
Compile a module with the three functions
needed for chunk-wise rendering of a signal:

1. a function that allocates memory for the signal state,
   stores the initial state there and returns the pointer,

2. a function that frees the memory of the state,

3. a function that loads the state from the given pointer,
   runs 'Maybe.arrayLoop' over the given number of elements
   of the given buffer, stores every generated value in the buffer,
   writes the final state back and returns the first result
   of 'Maybe.arrayLoop'.
   'runChunky' uses this number as the length of the chunk.

The arguments are the code for one step and the code for the initial state.
-}
compileChunky ::
   (Rep.Memory value struct,
    Rep.Memory state stateStruct,
    IsSized stateStruct stateSize) =>
   (forall r.
    state -> Maybe.T r (Value Bool, state) (value, state)) ->
   (forall r.
    CodeGenFunction r state) ->
   IO (FunPtr (IO (Ptr stateStruct)),
       FunPtr (Ptr stateStruct -> IO ()),
       FunPtr (Ptr stateStruct -> Word32 -> Ptr struct -> IO Word32))
compileChunky next start =
   Exec.compileModule $
      liftM3 (,,)
         -- allocation and initialization of the state
         (createFunction ExternalLinkage $
          do
             -- FIXME: size computation in LLVM currently does not work for structs!
             pptr <- Rep.malloc
             flip Rep.store pptr =<< start
             ret pptr)
{- for debugging: allocation with initialization makes type inference difficult
         (createFunction ExternalLinkage $
          do
             pptr <- malloc
             let retn :: CodeGenFunction r state -> Value (Ptr state) -> CodeGenFunction (Ptr state) ()
                 retn _ ptr = ret ptr
             retn undefined pptr)
-}
         -- deallocation of the state
         (createFunction ExternalLinkage $
          \ pptr -> Rep.free pptr >> ret ())
         -- filling of one chunk
         (createFunction ExternalLinkage $
          \ sptr loopLen ptr -> do
             sInit <- Rep.load sptr
             (pos,sExit) <- Maybe.arrayLoop loopLen ptr sInit $
              \ ptri s0 -> do
                (y,s1) <- next s0
                Maybe.lift $ Rep.store y ptri
                return s1
             -- keep the state for the next chunk
             Rep.store sExit sptr
             ret (pos :: Value Word32))


{- |
Render a signal lazily into a chunky storable vector
with chunks of the given size.

The IO context is created and the code is compiled immediately,
whereas the chunks are computed on demand via 'unsafeInterleaveIO'.
Rendering stops after the first chunk that is shorter than the chunk size.
Empty chunks are not added to the result.
-}
runChunky ::
   (Storable a, MakeValueTuple a value, Rep.Memory value struct) =>
   SVL.ChunkSize -> T value -> IO (SVL.Vector a)
runChunky (SVL.ChunkSize size)
     (Cons next start createIOContext deleteIOContext) = do
   ioContext <- createIOContext
   (startFunc, stopFunc, fill) <-
      compileChunky (next ioContext) (start ioContext)

   -- NOTE(review): presumably this runs startFunc in order to obtain the state
   -- and registers stopFunc as finalizer - confirm in LLVM.Extra.Representation.
   statePtr <- Rep.newForeignPtrInit stopFunc startFunc
   -- for explanation see Causal.Process
   -- NOTE(review): presumably deleteIOContext is run as finalizer
   -- of this foreign pointer - confirm in LLVM.Extra.Representation.
   ioContextPtr <- Rep.newForeignPtr (deleteIOContext ioContext) False

   let go =
         unsafeInterleaveIO $ do
            -- fill a buffer of the full chunk size
            -- and trim it to the length reported by the compiled function
            v <-
               withForeignPtr statePtr $ \sptr ->
               SVB.createAndTrim size $
               fmap (fromIntegral :: Word32 -> Int) .
               derefChunkPtr fill sptr (fromIntegral size) .
               Rep.castStorablePtr
            -- the context must stay alive at least until the chunk is computed
            touchForeignPtr ioContextPtr
            -- omit an empty chunk
            (if SV.length v > 0
               then fmap (v:)
               else id) $
               -- an incomplete chunk marks the end of the signal
               (if SV.length v < size
                  then return []
                  else go)
   fmap SVL.fromChunks go

{- |
Pure interface to 'runChunky' by means of 'unsafePerformIO'.
-}
renderChunky ::
   (Storable a, MakeValueTuple a value, Rep.Memory value struct) =>
   SVL.ChunkSize -> T value -> SVL.Vector a
renderChunky size =
   unsafePerformIO . runChunky size