{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE UndecidableInstances #-}
{-# LANGUAGE FlexibleContexts #-}
{-# OPTIONS_GHC -fno-warn-orphans #-}
{- |
Represent a vector of stereo values as two vectors
that store the left and right samples in an interleaved way.
That is:

> vector0[0] = left[0]
> vector0[1] = right[0]
> vector0[2] = left[1]
> vector0[3] = right[1]
> vector1[0] = left[2]
> vector1[1] = right[2]
> vector1[2] = left[3]
> vector1[3] = right[3]

This representation is not very useful for computation,
but it is needed as an intermediate representation when interfacing with memory.
SSE/SSE2 provide the instructions UNPCK(L|H)P(S|D), which perform this interleaving efficiently.
-}
module Synthesizer.LLVM.Frame.StereoInterleaved (
   T,
   Value(Value),
   interleave,
   deinterleave,
   fromMono,
   assemble, extractAll,
   zero,
   amplify,
   envelope,
   ) where

import qualified Synthesizer.LLVM.Frame.Stereo as Stereo
import qualified Synthesizer.LLVM.Frame.SerialVector as Serial
import qualified Synthesizer.LLVM.CausalParameterized.Functional as F

import qualified LLVM.Extra.Arithmetic as A
import qualified LLVM.Extra.Control as C
import qualified LLVM.Extra.Class as Class
import qualified LLVM.Extra.Memory as Memory
import qualified LLVM.Extra.ScalarOrVector as SoV
import qualified LLVM.Extra.Vector as Vector
import qualified LLVM.Core as LLVM
import LLVM.Extra.Class
   (Undefined, undefTuple,
    MakeValueTuple, valueTupleOf, )
import LLVM.Core
   (Vector, IsSized, SizeOf, )
import LLVM.Util.Loop (Phi, phis, addPhis, )

import qualified Type.Data.Num.Decimal as TypeNum

import Foreign.Ptr (castPtr, )
import qualified Foreign.Storable as St
-- import Data.Word (Word32, )

import qualified Data.Foldable as Fold
import Control.Monad (liftM2, )
import Control.Applicative (liftA2, pure, )

import Data.Tuple.HT (mapPair, )

import qualified Algebra.Additive as Additive


-- | Haskell-side storage of @n@ stereo frames in interleaved form.
-- Reading the first vector followed by the second yields
-- @left0, right0, left1, right1, ...@.
data T n a =
   Cons
      (Vector n a) -- ^ first @n@ interleaved samples
      (Vector n a) -- ^ remaining @n@ interleaved samples

-- | LLVM-side counterpart of 'T': two LLVM vector values
-- holding the interleaved samples in the same order.
data Value n a =
   Value
      (LLVM.Value (Vector n a)) -- ^ first @n@ interleaved samples
      (LLVM.Value (Vector n a)) -- ^ remaining @n@ interleaved samples


type instance F.Arguments f (Value n a) = f (Value n a)
instance F.MakeArguments (Value n a) where
   makeArgs = id


-- | Run a continuation with the type-level vector size @n@ as an 'Int'.
withSize :: (TypeNum.Natural n) => (Int -> m (Value n a)) -> m (Value n a)
withSize = applySize TypeNum.singleton
   where
      -- The singleton's type index is fixed by the result type of withSize.
      applySize ::
         (TypeNum.Natural n) =>
         TypeNum.Singleton n -> (Int -> m (Value n a)) -> m (Value n a)
      applySize sing k = k (TypeNum.integralFromSingleton sing)


-- | Convert a stereo pair of serial vectors into the interleaved representation.
interleave ::
   (LLVM.IsPrimitive a, TypeNum.Positive n) =>
   Stereo.T (Serial.Value n a) ->
   LLVM.CodeGenFunction r (Value n a)
interleave x =
   Serial.extractAll x >>= assemble

-- | Split an interleaved value back into a stereo pair of serial vectors.
deinterleave ::
   (LLVM.IsPrimitive a, TypeNum.Positive n) =>
   Value n a ->
   LLVM.CodeGenFunction r (Stereo.T (Serial.Value n a))
deinterleave v =
   extractAll v >>= Serial.assemble

-- | Build an interleaved value from a mono serial vector
-- by using each mono sample for both the left and the right channel.
fromMono ::
   (LLVM.IsPrimitive a, TypeNum.Positive n) =>
   Serial.Value n a ->
   LLVM.CodeGenFunction r (Value n a)
fromMono x = do
   samples <- Serial.extractAll x
   -- 'pure' on a 'Stereo.T' duplicates the sample into both channels.
   assemble (map pure samples)

-- | Pack a list of stereo frames into the interleaved representation.
-- The frames are flattened to @l0, r0, l1, r1, ...@; the first @n@ scalars
-- go into the first vector and the rest into the second one.
assemble ::
   (LLVM.IsPrimitive a, TypeNum.Positive n) =>
   [Stereo.T (LLVM.Value a)] -> LLVM.CodeGenFunction r (Value n a)
assemble x =
   withSize $ \n -> do
      let (firstChunk, secondChunk) = splitAt n (concatMap Fold.toList x)
      v0 <- Vector.assemble firstChunk
      v1 <- Vector.assemble secondChunk
      return (Value v0 v1)

-- | Unpack an interleaved value into its list of stereo frames.
-- The scalars of both vectors are concatenated and then paired up
-- as consecutive (left, right) samples.
extractAll ::
   (LLVM.IsPrimitive a, TypeNum.Positive n) =>
   Value n a -> LLVM.CodeGenFunction r [Stereo.T (LLVM.Value a)]
extractAll (Value v0 v1) =
   do xs0 <- Vector.extractAll v0
      xs1 <- Vector.extractAll v1
      return (pairUp (xs0 ++ xs1))
   where
      pairUp (l:r:rest) = Stereo.cons l r : pairUp rest
      pairUp [] = []
      pairUp _ = error "odd number of stereo elements"


-- Memory layout: the two vectors are stored back to back,
-- as elements 0 and 1 of a vector array starting at the given pointer.
instance
   (TypeNum.Positive n, LLVM.IsPrimitive a, St.Storable a) =>
      St.Storable (T n a) where
   sizeOf ~(Cons v0 v1) = St.sizeOf v0 + St.sizeOf v1
   alignment ~(Cons v0 _) = St.alignment v0
   peek ptr = do
      let vp = castPtr ptr
      v0 <- St.peekElemOff vp 0
      v1 <- St.peekElemOff vp 1
      return (Cons v0 v1)
   poke ptr (Cons v0 v1) = do
      let vp = castPtr ptr
      St.pokeElemOff vp 0 v0
      St.pokeElemOff vp 1 v1

-- Both vectors are set to the zero tuple of the vector type.
instance (TypeNum.Positive n, LLVM.IsPrimitive a) => Class.Zero (Value n a) where
   zeroTuple = Value zeroVec zeroVec
      where zeroVec = Class.zeroTuple

-- Both vectors are LLVM 'undef' vector constants.
instance (TypeNum.Positive n, LLVM.IsPrimitive a) => Undefined (Value n a) where
   undefTuple = Value undefVec undefVec
      where undefVec = LLVM.value LLVM.undef

{-
Can only be implemented via ifThenElse,
since the atomic 'select' instruction expects a vector of booleans.

instance (TypeNum.Positive n, LLVM.IsPrimitive a, Phi a) => C.Select (Value n a) where
   select b (Value x0 x1) (Value y0 y1) =
      liftM2 Value
         (C.select b x0 y0)
         (C.select b x1 y1)

instance LLVM.CmpRet a, LLVM.CmpResult a ~ b => LLVM.CmpRet (Stereo.T a) (Stereo.T b) where
-}

-- Lift a Haskell-side 'T' into LLVM constants, one vector at a time.
instance (TypeNum.Positive n, LLVM.IsPrimitive a, LLVM.IsConst a) =>
      MakeValueTuple (T n a) where
   type ValueTuple (T n a) = Value n a
   valueTupleOf (Cons lo hi) = Value (LLVM.valueOf lo) (LLVM.valueOf hi)

-- Phi nodes are created and extended separately for each of the two vectors,
-- first vector first.
instance (TypeNum.Positive n, LLVM.IsPrimitive a) => Phi (Value n a) where
   phis bb (Value x0 x1) =
      liftM2 Value (phis bb x0) (phis bb x1)
   addPhis bb (Value x0 x1) (Value y0 y1) =
      addPhis bb x0 y0 >> addPhis bb x1 y1


-- The serial size is the type-level @n@: a 'Value' holds @n@ stereo frames
-- (2*n scalars) spread over its two @n@-element vectors.
instance (TypeNum.Positive n) => Serial.Sized (Value n a) where
   type Size (Value n a) = n

{- |
Reading a single frame via 'extract' is comparatively slow.
The interleaved positions of the frame are computed at run-time,
and a conditional picks the vector that contains each sample.
'extractAll' and the iterator functions avoid this arithmetic.
-}
instance (TypeNum.Positive n, LLVM.IsPrimitive a, LLVM.IsFirstClass a) => Serial.Read (Value n a) where
   type Element (Value n a) = Stereo.T (LLVM.Value a)
   type ReadIt (Value n a) = Value n a

   extract k (Value v0 v1) = do
      -- 'limit' is the number of scalars per vector.
      -- The scalar at interleaved position j lives in v0 when j < limit,
      -- and in v1 at position j - limit otherwise.
      let limit = LLVM.valueOf $ fromIntegral $ Vector.sizeInTuple v0
          fetch j = do
             inFirst <- A.cmp LLVM.CmpLT j limit
             C.ifThenElse inFirst
                (Vector.extract j v0)
                (A.sub j limit >>= \j1 -> Vector.extract j1 v1)
      -- Frame k occupies positions 2k (left) and 2k+1 (right).
      leftPos <- A.add k k
      rightPos <- A.inc leftPos
      l <- fetch leftPos
      r <- fetch rightPos
      return (Stereo.cons l r)

   extractAll = extractAll

   readStart v = return (Serial.Iterator v)
   -- Pop the first frame and repack the remaining ones.
   readNext (Serial.Iterator v) = do
      frames <- extractAll v
      case frames of
         [] -> error "StereoInterleaved.readNext: size zero"
         x:rest -> do
            remaining <- assemble rest
            return (x, Serial.Iterator remaining)


{- |
Writing a single frame via 'insert' may need to perform
arithmetic at run-time and is thus somewhat inefficient.
The vector that receives each sample is selected by a run-time comparison.
-}
instance (TypeNum.Positive n, LLVM.IsPrimitive a, LLVM.IsFirstClass a) => Serial.C (Value n a) where
   type WriteIt (Value n a) = Value n a

   -- Stereo frame k occupies the interleaved positions 2k (left) and 2k+1 (right).
   -- Position j lives in the first vector when j < size,
   -- and in the second vector at position j - size otherwise.
   insert k x v =
      let size = LLVM.valueOf $ fromIntegral $ Serial.size v
          ins j c (Value v0 v1) = do
             b <- A.cmp LLVM.CmpLT j size
             C.ifThenElse b
                (do w0 <- Vector.insert j c v0
                    return $ Value w0 v1)
                (do j1 <- A.sub j size
                    w1 <- Vector.insert j1 c v1
                    return $ Value v0 w1)
      in  do
             k20 <- A.add k k
             k21 <- A.inc k20
             ins k21 (Stereo.right x) =<< ins k20 (Stereo.left x) v

   assemble = assemble

   writeStart = return (Serial.Iterator Class.undefTuple)
   -- Shift out the oldest frame and append the new one at the end.
   -- 'drop 1' is used instead of the partial 'tail'.
   -- The list always has n >= 1 elements here, so the result is unchanged.
   writeNext x (Serial.Iterator v) = do
      xs <- extractAll v
      fmap Serial.Iterator $ assemble $ drop 1 xs ++ [x]
   writeStop (Serial.Iterator v) = return v


-- | LLVM struct type used to store a 'Value' in memory:
-- two @n@-element vectors, encoded as nested pairs terminated by @()@.
type Struct n a = LLVM.Struct (Vector n a, (Vector n a, ()))

-- | Memory record mapping the two vectors of a 'Value'
-- to struct fields 0 and 1.
memory ::
   (TypeNum.Positive n, LLVM.IsPrimitive a, LLVM.IsPrimitive am,
    Memory.FirstClass a, Memory.Stored a ~ am,
    IsSized am, TypeNum.Positive (n TypeNum.:*: SizeOf am)) =>
   Memory.Record r (Struct n am) (Value n a)
memory =
   liftA2 Value
      (Memory.element firstVector TypeNum.d0)
      (Memory.element secondVector TypeNum.d1)
   where
      firstVector (Value v _) = v
      secondVector (Value _ v) = v

-- Memory access goes through the 'memory' record:
-- the two vectors are the two fields of a 'Struct' over the stored element type.
instance
      (Memory.FirstClass a, Memory.Stored a ~ am,
       TypeNum.Positive n,
       LLVM.IsPrimitive a, IsSized a,
       LLVM.IsPrimitive am, IsSized am,
       TypeNum.Positive (n TypeNum.:*: SizeOf a),
       TypeNum.Positive (n TypeNum.:*: SizeOf am)) =>
      Memory.C (Value n a) where
   type Struct (Value n a) = Struct n (Memory.Stored a)
   compose = Memory.composeRecord memory
   decompose = Memory.decomposeRecord memory
   store = Memory.storeRecord memory
   load = Memory.loadRecord memory

{- |
Component-wise addition, subtraction and negation of both vectors.
This instance makes it possible to run @arrange@ on interleaved stereo vectors.
-}
instance
   (TypeNum.Positive n, LLVM.IsPrimitive a, LLVM.IsArithmetic a) =>
      A.Additive (Value n a) where
   zero = Value A.zero A.zero
   add (Value x0 x1) (Value y0 y1) =
      liftM2 Value (A.add x0 y0) (A.add x1 y1)
   sub (Value x0 x1) (Value y0 y1) =
      liftM2 Value (A.sub x0 y0) (A.sub x1 y1)
   neg (Value x0 x1) =
      liftM2 Value (A.neg x0) (A.neg x1)


-- | Haskell-side interleaved frame with all samples set to 'Additive.zero'.
zero :: (TypeNum.Positive n, Additive.C a) => T n a
zero = Cons zeros zeros
   where zeros = pure Additive.zero


-- | Multiply both vectors by a run-time scalar.
-- The scalar is first replicated into a vector via 'SoV.replicate'.
scale ::
   (TypeNum.Positive n, LLVM.IsPrimitive a, LLVM.IsArithmetic a) =>
   LLVM.Value a -> Value n a -> LLVM.CodeGenFunction r (Value n a)
scale a v =
   SoV.replicate a >>= \factor -> mapV (A.mul factor) v

-- | Multiply an interleaved value by a constant Haskell-side factor.
amplify ::
   (TypeNum.Positive n, LLVM.IsPrimitive a, LLVM.IsArithmetic a, LLVM.IsConst a) =>
   a -> Value n a -> LLVM.CodeGenFunction r (Value n a)
amplify a v = scale (LLVM.valueOf a) v

-- | Apply a mono envelope to an interleaved stereo value.
-- The envelope is duplicated to both channels via 'fromMono'
-- and then multiplied element-wise.
envelope ::
   (TypeNum.Positive n, LLVM.IsPrimitive a, LLVM.IsArithmetic a, LLVM.IsConst a) =>
   Serial.Value n a -> Value n a -> LLVM.CodeGenFunction r (Value n a)
envelope e a = do
   stereoEnv <- fromMono e
   -- envelope vector is the first operand of each multiplication
   zipV Value (\sig env -> A.mul env sig) a stereoEnv


-- | Apply a monadic action to each of the two vectors,
-- first vector first, and rebuild the 'Value'.
mapV :: (Monad m) =>
   (LLVM.Value (Vector n a) -> m (LLVM.Value (Vector n a))) ->
   Value n a -> m (Value n a)
mapV f (Value x0 x1) = do
   y0 <- f x0
   y1 <- f x1
   return (Value y0 y1)

-- | Combine corresponding vectors of two values with a monadic action,
-- first vectors first, and merge both results with a pure function.
zipV :: (Monad m) =>
   (c -> c -> d) ->
   (LLVM.Value (Vector n a) ->
    LLVM.Value (Vector n b) ->
    m c) ->
   Value n a ->
   Value n b ->
   m d
zipV g f (Value x0 x1) (Value y0 y1) = do
   r0 <- f x0 y0
   r1 <- f x1 y1
   return (g r0 r1)