{-# LANGUAGE TypeFamilies #-}
{-# OPTIONS_GHC -fno-warn-orphans #-}
{- |
Re-export functions from "Synthesizer.Frame.Stereo"
and add (orphan) instances for various LLVM type classes.
If you want to use the Stereo datatype with synthesizer-llvm,
we recommend importing this module instead of
"Sound.Frame.Stereo" or "Sound.Frame.NumericPrelude.Stereo".
-}
module Synthesizer.LLVM.Frame.Stereo (
   Stereo.T, Stereo.cons, Stereo.left, Stereo.right,
   Stereo.Channel(Left, Right), Stereo.select,
   Stereo.arrowFromMono,
   Stereo.arrowFromMonoControlled,
   Stereo.arrowFromChannels,
   Stereo.interleave,
   Stereo.sequence,
   Stereo.liftApplicative,
   ) where

import qualified Synthesizer.Frame.Stereo as Stereo

import qualified LLVM.Extra.Arithmetic as A
import qualified LLVM.Extra.Class as Class
import qualified LLVM.Extra.Memory as Memory
import qualified LLVM.Extra.Control as C
import qualified LLVM.Extra.Vector as Vector
import qualified LLVM.Core as LLVM
import LLVM.Extra.Class
   (Undefined, undefTuple,
    MakeValueTuple, ValueTuple, valueTupleOf, )
import LLVM.Util.Loop (Phi, phis, addPhis, )

import Type.Data.Num.Decimal (d0, d1, )

import Control.Monad (liftM2, )
import Control.Applicative (liftA2, )
import qualified Data.Traversable as Trav

import Prelude hiding (Either(Left, Right), sequence, )


-- | The zero tuple of a stereo frame carries the channel's zero tuple
-- in both the left and the right channel.
instance (Class.Zero a) => Class.Zero (Stereo.T a) where
   zeroTuple = Stereo.cons channelZero channelZero
      where channelZero = Class.zeroTuple

-- | An undefined stereo frame uses the channel's 'undefTuple'
-- for both the left and the right channel.
instance (Undefined a) => Undefined (Stereo.T a) where
   undefTuple = Stereo.cons channelUndef channelUndef
      where channelUndef = undefTuple

-- | Selection between two stereo frames, delegated to the generic
-- 'C.selectTraversable'.
instance (C.Select a) => C.Select (Stereo.T a) where
   select cond = C.selectTraversable cond

{-
instance LLVM.CmpRet a, LLVM.CmpResult a ~ b => LLVM.CmpRet (Stereo.T a) (Stereo.T b) where
-}

-- | The value tuple of a stereo frame is a stereo frame of value tuples:
-- 'valueTupleOf' is applied to the left and the right channel separately.
instance (MakeValueTuple h) => MakeValueTuple (Stereo.T h) where
   type ValueTuple (Stereo.T h) = Stereo.T (ValueTuple h)
   valueTupleOf frame =
      let convert channel = valueTupleOf (channel frame)
      in  Stereo.cons (convert Stereo.left) (convert Stereo.right)

{-
instance ValueTuple a => ValueTuple (Stereo.T a) where
   buildTuple f =
      liftM2 Stereo.cons (buildTuple f) (buildTuple f)

instance IsTuple a => IsTuple (Stereo.T a) where
   tupleDesc s =
      tupleDesc (Stereo.left s) ++
      tupleDesc (Stereo.right s)
-}

-- | Phi nodes for a stereo frame are created per channel,
-- always handling the left channel before the right one.
instance (Phi a) => Phi (Stereo.T a) where
   phis bb v =
      liftM2 Stereo.cons (phiFor Stereo.left) (phiFor Stereo.right)
     where phiFor channel = phis bb (channel v)
   addPhis bb x y =
      -- pair up corresponding channels of both frames
      let addFor channel = addPhis bb (channel x) (channel y)
      in  addFor Stereo.left >> addFor Stereo.right


-- | A stereo frame of vectors is treated as a vector whose elements are
-- stereo frames of the underlying element type; its size is the size of
-- the underlying vector type.
instance (Vector.Simple v) => Vector.Simple (Stereo.T v) where
   type Size (Stereo.T v) = Vector.Size v
   type Element (Stereo.T v) = Stereo.T (Vector.Element v)
   extract k = Vector.extractTraversable k
   shuffleMatch m = Vector.shuffleMatchTraversable m

-- | Element insertion for stereo vectors, delegated to the generic
-- 'Vector.insertTraversable'.
instance (Vector.C v) => Vector.C (Stereo.T v) where
   insert k = Vector.insertTraversable k


-- | An LLVM structure with exactly two fields of the same type @s@.
-- In this module it serves as the memory layout of a stereo frame.
type Struct s = LLVM.Struct (s, (s, ()))

-- | Memory layout of a stereo frame: the left channel lives in field 0
-- and the right channel in field 1 of a two-field 'Struct'.
memory ::
   (Memory.C l) =>
   Memory.Record r (Struct (Memory.Struct l)) (Stereo.T l)
memory = liftA2 Stereo.cons leftField rightField
   where
      leftField  = Memory.element Stereo.left  d0
      rightField = Memory.element Stereo.right d1

-- | Stereo frames are loaded, stored, composed and decomposed
-- using the record layout described by 'memory'.
instance (Memory.C l) => Memory.C (Stereo.T l) where
   type Struct (Stereo.T l) = Struct (Memory.Struct l)
   load ptr = Memory.loadRecord memory ptr
   store frame = Memory.storeRecord memory frame
   decompose val = Memory.decomposeRecord memory val
   compose frame = Memory.composeRecord memory frame


{-
instance
      (Memory l s, LLVM.IsSized s ss) =>
      Memory (Stereo.T l) (LLVM.Struct (s, (s, ()))) where
   load ptr =
      liftM2 Stereo.cons
         (load =<< getElementPtr0 ptr (d0, ()))
         (load =<< getElementPtr0 ptr (d1, ()))
   store y ptr = do
      store (Stereo.left  y) =<< getElementPtr0 ptr (d0, ())
      store (Stereo.right y) =<< getElementPtr0 ptr (d1, ())
-}

-- | Additive arithmetic on stereo frames: every operation is applied to
-- the left and the right channel separately.
instance (A.Additive a) => A.Additive (Stereo.T a) where
   zero = Stereo.cons A.zero A.zero
   -- liftA2 combines corresponding channels of both operands into a frame
   -- of code-generation actions, which Trav.sequence then runs in order.
   add x y = Trav.sequence $ liftA2 A.add x y
   sub x y = Trav.sequence $ liftA2 A.sub x y
   -- mapM f is the idiomatic form of sequence . fmap f
   neg = Trav.mapM A.neg

-- The scalar type of a stereo frame is the scalar type of a single channel.
type instance A.Scalar (Stereo.T v) = A.Scalar v

-- | Scaling a stereo frame scales both channels by the same scalar.
instance (A.PseudoModule a) => A.PseudoModule (Stereo.T a) where
   -- mapM f is the idiomatic form of sequence . fmap f
   scale a = Trav.mapM (A.scale a)