{-# LANGUAGE NoImplicitPrelude #-}
{-# LANGUAGE ExistentialQuantification #-}
{-# LANGUAGE Rank2Types #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE UndecidableInstances #-}
module Synthesizer.LLVM.Plug.Output where

import qualified Synthesizer.Zip as Zip

import qualified LLVM.Extra.Memory as Memory
import qualified LLVM.Extra.Class as Class
import qualified LLVM.Extra.Arithmetic as A

import qualified LLVM.Core as LLVM

import Control.Monad (liftM2, )

import qualified Synthesizer.LLVM.Storable.Vector as SVU
import qualified Data.StorableVector as SV
import qualified Data.StorableVector.Base as SVB

import qualified Foreign.ForeignPtr as FPtr
import Foreign.Storable.Tuple ()
import Foreign.Storable (Storable, )

import NumericPrelude.Numeric
import NumericPrelude.Base hiding (and, iterate, map, zip, zipWith, take, takeWhile, )


-- | An output plug.
-- It describes how elements of type @a@, produced one at a time by generated
-- LLVM code, are consumed and finally turned into a Haskell result of type @b@.
--
-- Three types are hidden existentially.
-- The @state@ is threaded from one step of the generated code to the next.
-- The @paramTuple@ is created in 'IO' by the initializer, and its
-- 'Class.ValueTuple' counterpart is what the two code generators receive.
-- The @ioContext@ is produced by the initializer and consumed by the finalizer.
data T a b =
   forall state ioContext paramTuple.
      (Storable paramTuple,
       Class.MakeValueTuple paramTuple,
       Memory.C (Class.ValueTuple paramTuple),
       Memory.C state) =>
   Cons
      (forall r.
       Class.ValueTuple paramTuple ->
       a -> state -> LLVM.CodeGenFunction r state)
          -- compute next value:
          -- given the parameters, generate code that consumes one element
          -- of type 'a' and turns the current state into the next state
      (forall r.
       Class.ValueTuple paramTuple ->
       LLVM.CodeGenFunction r state)
          -- initial state, generated from the parameters
      (Int -> IO (ioContext, paramTuple))
          {- initialization from the IO monad.
          This is called once per output chunk
          with the number of input samples.
          This number is also the maximum possible number of output samples,
          so any output storage must be allocated for that size.
          This will be run within Unsafe.performIO,
          so no observable input/output actions, please!
          -}
      (Int -> ioContext -> IO b)
          {-
          finalization from the IO monad, also run within Unsafe.performIO.
          The integer argument is the number of samples actually produced.
          Output that was allocated for the maximum size must be clipped
          to this length (storableVector does this with SV.take).
          -}


-- | Output types @b@ that have a canonical output plug.
class Default b where
   -- | The type of the per-sample values that the default plug for @b@ consumes.
   type Element b :: *
   -- | The canonical plug that collects 'Element's into a @b@.
   deflt :: T (Element b) b

-- | A zipped output is filled by running the default plugs of both components
-- side by side; each sample is a pair of the component elements.
instance (Default c, Default d) => Default (Zip.T c d) where
   type Element (Zip.T c d) = (Element c, Element d)
   deflt = deflt `split` deflt

-- | A storable vector is filled sample by sample via 'storableVector'.
instance
   (Storable a,
    Memory.C (Class.ValueTuple a),
    Class.MakeValueTuple a) =>
      Default (SV.Vector a) where
   type Element (SV.Vector a) = Class.ValueTuple a
   deflt = storableVector


-- | Combine two output plugs into one that consumes pairs.
-- The first component of every sample goes to the first plug and the second
-- component to the second plug.
-- Every phase (step, start, initialization, finalization) runs the first
-- plug's action before the second plug's.
split :: T a c -> T b d -> T (a,b) (Zip.T c d)
split (Cons nextA startA createA deleteA)
      (Cons nextB startB createB deleteB) =
   Cons
      -- one step: advance both states with their respective sample components
      (\(pa, pb) (xa, xb) (stateA, stateB) ->
         liftM2 (,)
            (nextA pa xa stateA)
            (nextB pb xb stateB))
      -- initial states of both plugs
      (\(pa, pb) -> do
         initA <- startA pa
         initB <- startB pb
         return (initA, initB))
      -- allocate both plugs' contexts and parameters for the same chunk size
      (\n -> do
         (contextA, paramA) <- createA n
         (contextB, paramB) <- createB n
         return ((contextA, contextB), (paramA, paramB)))
      -- finalize both plugs with the same produced length and zip the results
      (\n (contextA, contextB) -> do
         resultA <- deleteA n contextA
         resultB <- deleteB n contextB
         return (Zip.Cons resultA resultB))


-- | Output plug that writes samples consecutively into a freshly allocated
-- storable vector.
--
-- The parameter passed to the generated code is a pointer to the start of the
-- vector's data, and that pointer also serves as the initial state.
-- Each step stores one sample at the current pointer and advances the pointer
-- by one array element.
-- On finalization the vector is truncated to the number of samples actually
-- written.
storableVector ::
   (Class.MakeValueTuple a, value ~ Class.ValueTuple a,
    Memory.C value, Storable a) =>
   T value (SV.Vector a)
storableVector =
   Cons
      (\ _ x ptr -> do
         Memory.store x ptr
         A.advanceArrayElementPtr ptr)
      return
      (\size -> do
         -- contents are left uninitialized; only the written prefix is kept
         buffer <- SVB.create size (\_ -> return ())
         -- take the real data pointer instead of assuming a zero offset
         let (_, start, _) = SVU.unsafeToPointers buffer
         return (buffer, start))
      (\produced buffer -> do
         let (fptr, _, _) = SVB.toForeignPtr buffer
         -- keep the underlying memory alive until the generated code is done
         FPtr.touchForeignPtr fptr
         return (SV.take produced buffer))