{-# LANGUAGE NoImplicitPrelude #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE FunctionalDependencies #-}
{-# LANGUAGE ForeignFunctionInterface #-}
{- |
Functions on lazy storable vectors that are implemented using LLVM.
-}
module Synthesizer.LLVM.Storable.Signal (
   unpackStrict, unpack,
   makeUnpackGenericStrict, makeUnpackGeneric,
   makeReversePackedStrict, makeReversePacked,
   continue, continuePacked, continuePackedGeneric,
   makeMixer,
   makeArranger, arrange,
   ) where

import qualified Synthesizer.LLVM.Parameterized.Signal as SigP
import qualified Synthesizer.LLVM.Parameterized.SignalPacked as SigPS

import qualified Synthesizer.LLVM.Execution as Exec
import qualified Synthesizer.LLVM.Sample as Sample
import qualified LLVM.Extra.Representation as Rep
import qualified LLVM.Extra.Vector as Vector
import LLVM.Extra.Control (arrayLoop, )

import qualified Data.StorableVector.Lazy as SVL
import qualified Data.StorableVector as SV
import qualified Data.StorableVector.Base as SVB

import qualified Data.EventList.Relative.TimeBody  as EventList
import qualified Data.EventList.Relative.TimeMixed as EventListTM
import qualified Data.EventList.Absolute.TimeBody  as AbsEventList
import qualified Number.NonNegative as NonNeg

import qualified Algebra.Additive as Additive

import LLVM.Extra.Arithmetic (advanceArrayElementPtr, )

import LLVM.Core
   (Linkage(ExternalLinkage), createFunction, ret,
    MakeValueTuple, IsSized, IsPrimitive, getElementPtr,
    Vector, IsPowerOf2, )
import qualified Data.TypeLevel.Num as TypeNum

import qualified Control.Category as Cat

import qualified Data.List.HT as ListHT
import Data.Word (Word32, )
import Data.Int (Int32, )
import Foreign.Ptr (Ptr, )
import Foreign.ForeignPtr (castForeignPtr, )
import Foreign.Storable (Storable, )
import Foreign.Marshal.Array (advancePtr, )
import qualified Foreign.Marshal.Array as Array

import System.IO.Unsafe (unsafePerformIO, )

import NumericPrelude.Numeric
import NumericPrelude.Base


{- |
Reinterpret a strict vector of packed LLVM vectors
as a strict vector of their scalar elements.
No data is copied: the result shares the underlying buffer,
only start offset and length are scaled by the vector size.
Thus this function needs only constant time
in contrast to 'Synthesizer.LLVM.Parameterized.SignalPacked.unpack'.

We cannot provide a 'pack' function
since the array size may not line up.
It would also need copying since the source data may not be aligned properly.
-}
unpackStrict ::
   (Storable a, IsPrimitive a, IsPowerOf2 n) =>
   SV.Vector (Vector n a) -> SV.Vector a
unpackStrict vec =
   SVB.SV (castForeignPtr fptr) (start * factor) (size * factor)
   where
      (fptr, start, size) = SVB.toForeignPtr vec
      factor = vectorDim vec undefined
      -- the second argument only carries the type-level vector size
      vectorDim :: (TypeNum.Nat n) => SV.Vector (Vector n a) -> n -> Int
      vectorDim _ = TypeNum.toInt

{- |
Lazy counterpart of 'unpackStrict':
every chunk is reinterpreted separately, in constant time and without copying.
-}
unpack ::
   (Storable a, IsPrimitive a, IsPowerOf2 n) =>
   SVL.Vector (Vector n a) -> SVL.Vector a
unpack packed =
   SVL.fromChunks [unpackStrict chunk | chunk <- SVL.chunks packed]

{- |
This is similar to 'unpackStrict' but performs rearrangement of data.
This is for instance necessary for stereo signals
where the data layout of packed and unpacked data is different,
thus simple casting of the data is not possible.

The data is run through a compiled LLVM signal ('SigP.run')
whose requested length is the number of packed elements
multiplied by the vector size.
-}
makeUnpackGenericStrict ::
   (Vector.Access n va vv,
    Storable a, MakeValueTuple a va, Rep.Memory va as, IsSized as asize,
    Storable v, MakeValueTuple v vv, Rep.Memory vv vs, IsSized vs vsize) =>
   IO (SV.Vector v -> SV.Vector a)
makeUnpackGenericStrict = do
   unpacker <- SigP.run (SigPS.unpack $ SigP.fromStorableVector Cat.id)
   return $ \packed ->
      unpacker (TypeNum.toInt (vectorSize packed) * SV.length packed) packed
   where
      -- only the result type matters, the result itself is never evaluated
      vectorSize ::
         (Vector.Access n al vl, Storable v, MakeValueTuple v vl) =>
         SV.Vector v -> n
      vectorSize _ = undefined

{- |
Lazy counterpart of 'makeUnpackGenericStrict':
the generated unpacker is applied to every chunk separately.
-}
makeUnpackGeneric ::
   (Vector.Access n va vv,
    Storable a, MakeValueTuple a va, Rep.Memory va as, IsSized as asize,
    Storable v, MakeValueTuple v vv, Rep.Memory vv vs, IsSized vs vsize) =>
   IO (SVL.Vector v -> SVL.Vector a)
makeUnpackGeneric = do
   unpackChunk <- makeUnpackGenericStrict
   return (SVL.fromChunks . map unpackChunk . SVL.chunks)


{- |
Compile an LLVM function that copies the elements of a source array
into a destination array in reverse order,
additionally applying 'Vector.reverse' to every element.
The returned action takes the number of elements,
the source pointer and the destination pointer.

The @dummy@ argument is only used via 'asTypeOf'
in order to fix the LLVM value type; it is never evaluated.
The generated function has the same type as the mixer,
thus 'derefMixPtr' is reused for importing it.
-}
makeReverser ::
   (Storable a, Vector.ShuffleMatch n value,
    MakeValueTuple a value, Rep.Memory value struct) =>
   value -> IO (Word32 -> Ptr a -> Ptr a -> IO ())
-- Commented-out alternative signature:
--   (Rep.Memory a struct, Vector.ShuffleMatch n a) =>
--   IO (Word32 -> Ptr struct -> Ptr struct -> IO ())
makeReverser dummy =
   fmap (\f len srcPtr dstPtr ->
      f len (Rep.castStorablePtr srcPtr) (Rep.castStorablePtr dstPtr)) $
   fmap derefMixPtr $
   Exec.compileModule $
   createFunction ExternalLinkage $ \ size ptrA ptrB -> do
      -- ptrA is the source, ptrB the destination;
      -- reading starts just behind the last source element
      ptrAEnd <- getElementPtr ptrA (size, ())
      -- presumably 'arrayLoop' steps ptrBi forward itself,
      -- while the source pointer is threaded backwards as loop state
      arrayLoop size ptrB ptrAEnd $ \ ptrBi ptrAj0 -> do
         ptrAj1 <- getElementPtr ptrAj0 (-1 :: Int32, ())
         flip Rep.store ptrBi
            =<< Vector.reverse
            . flip asTypeOf dummy
            =<< Rep.load ptrAj1
         return ptrAj1
      ret ()

{- |
Compile a reverser for strict vectors of packed values.
The packed elements appear in reverse order in the result,
and each packed element is itself passed through 'Vector.reverse'
by the function compiled in 'makeReverser'.
-}
makeReversePackedStrict ::
   (Storable v, Vector.Access n va vv,
    MakeValueTuple v vv, Rep.Memory vv vs, IsSized vs vsize) =>
   IO (SV.Vector v -> SV.Vector v)
makeReversePackedStrict =
   fmap
      (\reverser vec ->
         unsafePerformIO $
         SVB.withStartPtr vec $ \srcPtr n ->
         SVB.create n $ \dstPtr ->
         reverser (fromIntegral n) srcPtr dstPtr)
      (makeReverser undefined)

{- |
Lazy counterpart of 'makeReversePackedStrict'.
The chunk list is reversed as a whole and every chunk is reversed separately.
Because the chunk list must be traversed completely before the first result
chunk is available, the input must be finite.
-}
makeReversePacked ::
   (Storable v, Vector.Access n va vv,
    MakeValueTuple v vv, Rep.Memory vv vs, IsSized vs vsize) =>
   IO (SVL.Vector v -> SVL.Vector v)
makeReversePacked = do
   reverseChunk <- makeReversePackedStrict
   return $ \xs ->
      SVL.fromChunks (reverse (map reverseChunk (SVL.chunks xs)))


{- |
Append two signals where the second signal
gets the last value of the first signal as parameter.
If the first signal is empty
then there is no parameter for the second signal
and thus we simply return an empty signal in that case.

Empty chunks of the first signal are skipped when looking for the last value,
so a trailing empty chunk does not suppress the second signal.
-}
continue ::
   (Storable a) =>
   SVL.Vector a -> (a -> SVL.Vector a) -> SVL.Vector a
continue x y =
   SVL.fromChunks $
   withLast SV.empty
      -- an empty chunk has no last element,
      -- thus only non-empty chunks may provide the parameter
      (filter (not . SV.null) $ SVL.chunks x)
      (SV.switchR [] $ \_ -> SVL.chunks . y)

{- |
Variant of 'continue' where the last value of the first signal
is removed from the output and only serves as parameter for the second signal.
An empty first signal yields an empty result.
-}
_continueNeglectLast ::
   (Storable a) =>
   SVL.Vector a -> (a -> SVL.Vector a) -> SVL.Vector a
_continueNeglectLast x y =
   SVL.switchR SVL.empty appendContinuation x
   where
      appendContinuation body end = SVL.append body (y end)

{- |
Like 'continue' but for signals of packed vectors:
the parameter for the second signal is the last scalar element
of the last packed vector of the first signal.
If the first signal is empty, the result is empty.
Empty chunks are skipped when looking for the last value,
so a trailing empty chunk does not suppress the second signal.
-}
continuePacked ::
   (IsPowerOf2 n, Storable a, IsPrimitive a) =>
   SVL.Vector (Vector n a) ->
   (a -> SVL.Vector (Vector n a)) ->
   SVL.Vector (Vector n a)
continuePacked x y =
   SVL.fromChunks $
   withLast SV.empty
      -- an empty chunk has no last element,
      -- thus only non-empty chunks may provide the parameter
      (filter (not . SV.null) $ SVL.chunks x)
      (SV.switchR [] (\_ -> SVL.chunks . y) .
       unpackStrict)

{-
This function reduces the last chunk to size one, unpacks that
and takes the last value.
It would be certainly more efficient to use
a single @Rep.load@, @extractelement@ and @store@
instead of a loop of count 1.
However, this implementation is the simplest one, so far.
-}
{- |
Like 'continuePacked', but for packed types that cannot simply be cast
to their scalar elements (e.g. stereo signals).
The caller provides the unpacking function,
typically the one obtained from 'makeUnpackGenericStrict'.
Empty chunks of the first signal are skipped when looking for the last value,
so a trailing empty chunk does not suppress the second signal.

Use this like

> do unpackGeneric <- makeUnpackGenericStrict
>    return (continuePackedGeneric unpackGeneric x y)
-}
continuePackedGeneric ::
{-
   (Storable v, Vector.Access n a v,
    MakeValueTuple v vv, Rep.Memory vv vs, IsSized vs vsize) =>
-}
   (Storable v, Storable a) =>
   (SV.Vector v -> SV.Vector a) ->
   SVL.Vector v -> (a -> SVL.Vector v) -> SVL.Vector v
continuePackedGeneric unpackGeneric x y =
   SVL.fromChunks $
   withLast SV.empty
      -- an empty chunk has no last element,
      -- thus only non-empty chunks may provide the parameter
      (filter (not . SV.null) $ SVL.chunks x)
      (\lastChunk ->
         SV.switchR [] (\_ -> SVL.chunks . y) $ unpackGeneric $
         -- only the last packed element needs to be unpacked
         SV.drop (SV.length lastChunk - 1) $ lastChunk)


-- candidate for utility-ht
{- |
@withLast deflt xs f@ yields the elements of @xs@
followed by @f@ applied to the last element of @xs@,
or to @deflt@ if @xs@ is empty.
Every element is emitted before the rest of the list is inspected,
thus the function also works on infinite lists.
-}
withLast :: a -> [a] -> (a -> [a]) -> [a]
withLast deflt x y = go deflt x
   where
      go lastElem [] = y lastElem
      go _ (a:rest) = a : go a rest

{-
Simpler variant of 'withLast' without a default value:
an empty list yields an empty result.
This version is too strict, since it looks one element ahead.
-}
_withLast :: [a] -> (a -> [a]) -> [a]
_withLast x y =
   ListHT.switchR [] appendTail x
   where
      appendTail body end = body ++ end : y end



-- | Turn a pointer to a JIT-compiled function,
-- which takes an element count and a buffer pointer,
-- into a callable Haskell function ("dynamic" import).
-- Used by 'fillBuffer'.
foreign import ccall safe "dynamic" derefFillPtr ::
   Exec.Importer (Word32 -> Ptr a -> IO ())

{- |
Compile an LLVM function that stores the value @x@
into the elements of a buffer.
The returned action takes the number of elements and the buffer pointer.

'fillBuffer' is not only more general than filling with zeros,
it also simplifies type inference.
-}
fillBuffer ::
   (MakeValueTuple a value, Rep.Memory value struct) =>
   value -> IO (Word32 -> Ptr a -> IO ())
fillBuffer x =
   -- adapt the Haskell pointer to the pointer type of the LLVM function
   fmap (\f len ptr -> f len (Rep.castStorablePtr ptr)) $
   fmap derefFillPtr $
   Exec.compileModule $
   createFunction ExternalLinkage $ \ size ptr -> do
      -- presumably visits @size@ consecutive elements starting at @ptr@
      arrayLoop size ptr () $ \ ptri () -> do
         Rep.store x ptri
         return ()
      ret ()


-- | Turn a pointer to a JIT-compiled function,
-- which takes an element count, a source pointer and a destination pointer,
-- into a callable Haskell function ("dynamic" import).
-- Used by 'makeMixer' and also by 'makeReverser',
-- since both generated functions have this type.
foreign import ccall safe "dynamic" derefMixPtr ::
   Exec.Importer (Word32 -> Ptr a -> Ptr a -> IO ())

{- |
Compile an LLVM function that adds a source buffer element-wise
onto a destination buffer using 'Sample.add'.
The returned action takes the number of elements,
the source pointer and the destination pointer;
the destination is modified in place.

The @dummy@ argument is only used via 'asTypeOf'
in order to fix the LLVM value type; it is never evaluated.
-}
makeMixer ::
   (Storable a, Sample.Additive value,
    MakeValueTuple a value, Rep.Memory value struct) =>
   value -> IO (Word32 -> Ptr a -> Ptr a -> IO ())
makeMixer dummy =
   fmap (\f len srcPtr dstPtr ->
      f len (Rep.castStorablePtr srcPtr) (Rep.castStorablePtr dstPtr)) $
   fmap derefMixPtr $
   Exec.compileModule $
   createFunction ExternalLinkage $ \ size srcPtr dstPtr -> do
      -- presumably 'arrayLoop' steps the source pointer itself;
      -- the destination pointer is loop state and advanced explicitly
      arrayLoop size srcPtr dstPtr $ \ srcPtri dstPtri -> do
         y <- Rep.load srcPtri
         Rep.modify (Sample.add (y `asTypeOf` dummy)) dstPtri
         advanceArrayElementPtr dstPtri
      ret ()


{- |
@addToBuffer mix len buf start xs@ mixes as much of the signal @xs@
into the buffer @buf@ of length @len@ as fits,
beginning at index @start@ and using @mix@
(for instance a function compiled by 'makeMixer').
Returns the index just behind the last mixed element
together with the part of @xs@ that did not fit.
-}
addToBuffer ::
   (Storable a) =>
   (Word32 -> Ptr a -> Ptr a -> IO ()) ->
   Int -> Ptr a -> Int -> SVL.Vector a -> IO (Int, SVL.Vector a)
addToBuffer addChunkToBuffer len v start xs = do
   end <- mixChunks start (SVL.chunks now)
   return (end, future)
   where
      (now, future) = SVL.splitAt (len - start) xs
      mixChunks pos [] = return pos
      mixChunks pos (chunk:rest) = do
         SVB.withStartPtr chunk $ \src n ->
            addChunkToBuffer (fromIntegral n) src (advancePtr v pos)
         mixChunks (pos + SV.length chunk) rest


{- |
Compile a function that mixes signals starting at given times
into one lazy signal.
Same algorithm as in Synthesizer.Storable.Cut.arrangeEquidist.

The output is produced chunk by chunk.
For every chunk, a buffer of the given chunk size is filled with 'Sample.zero',
then the remainders of signals that started in earlier chunks
and the signals of events falling into the current chunk are mixed into it.
The parts that do not fit are carried over to the next chunk.
While further events follow, every chunk has full size (padded with zeros);
after the last event, a chunk is trimmed to the longest mixed signal part.
An empty chunk terminates the output.
-}
makeArranger ::
   (Storable a, Sample.Additive value,
    MakeValueTuple a value, Rep.Memory value struct) =>
   IO (SVL.ChunkSize ->
       EventList.T NonNeg.Int (SVL.Vector a) ->
       SVL.Vector a)
makeArranger = do
   mixer <- makeMixer undefined
   fill <- fillBuffer Sample.zero
   return $ \ (SVL.ChunkSize sz) ->
      let sznn = NonNeg.fromNumberMsg "arrange" sz
          -- acc: remaining parts of signals that started in earlier chunks,
          -- evs: events that have not been processed yet
          go acc evs =
             let (now,future) = EventListTM.splitAtTime sznn evs
                 -- start offsets (relative to this chunk) and signals
                 -- of the events in the current chunk
                 xs =
                    AbsEventList.toPairList $
                    EventList.toAbsoluteEventList 0 $
                    EventListTM.switchTimeR const now
                 (chunk,newAcc) =
                    unsafePerformIO $
                    SVB.createAndTrim' sz $ \ptr -> do
                       -- start with silence, then mix all signals into it
                       fill (fromIntegral sz) ptr
                       newAcc0 <- flip mapM acc $ addToBuffer mixer sz ptr 0
                       newAcc1 <- flip mapM xs $ \(i,s) ->
                          addToBuffer mixer sz ptr (NonNeg.toNumber i) s
                       let (ends, suffixes) = unzip $ newAcc0++newAcc1
                           {- if there are more events to come,
                              we must pad with zeros -}
                           len =
                              if EventList.null future
                                then foldl max 0 ends
                                else sz
                       -- keep the first len elements of the buffer;
                       -- unfinished signals are carried to the next chunk
                       return (0, len,
                               filter (not . SVL.null) suffixes)
             in  if SV.null chunk
                   then []
                   else chunk : go newAcc future
      in  SVL.fromChunks . go []

{- |
Mix signals that start at given times into one signal,
using an arranger obtained from 'makeArranger'.

This is unsafe since it relies on the prior initialization of the LLVM JIT.
Better use 'makeArranger'.
-}
arrange ::
   (Storable a, Sample.Additive value,
    MakeValueTuple a value, Rep.Memory value struct) =>
      SVL.ChunkSize
   -> EventList.T NonNeg.Int (SVL.Vector a)
         {-^ A list of pairs: (relative start time, signal part),
             The start time is relative to the start time
             of the previous event. -}
   -> SVL.Vector a
         {-^ The mixed signal. -}
arrange =
   -- the JIT compilation done by 'makeArranger' is hidden behind
   -- 'unsafePerformIO'; keep it outside any eta-expansion,
   -- otherwise it could be repeated for every application
   unsafePerformIO makeArranger