{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE EmptyDataDecls #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE StandaloneDeriving #-}
{- |
A special vector type that represents a time-sequence of samples.
This way we can distinguish safely between LLVM vectors
used for parallel signals and pipelines and
those used for chunky processing of scalar signals.
For the chunky processing this data type allows us
to derive the factor from the type
that time constants have to be multiplied with.
-}
module Synthesizer.LLVM.Frame.SerialVector.Class (
   Constant(Constant), constant,

   Read, Element, ReadIt, extract, readStart, readNext,
   Write, WriteIt, insert, writeStart, writeNext, writeStop,
   Zero, writeZero,
   Iterator(Iterator), ReadIterator, WriteIterator, ReadMode, WriteMode,

   Sized, Size, size, sizeOfIterator, withSize,

   insertTraversable, extractTraversable,
   readStartTraversable, readNextTraversable,
   writeStartTraversable, writeNextTraversable, writeStopTraversable,
   writeZeroTraversable,

   dissect, assemble, modify,
   upsample, subsample, last,
   iterate, reverse,
   shiftUp, shiftUpMultiZero, shiftDownMultiZero,
   ) where

import qualified Synthesizer.LLVM.Frame.SerialVector.Code as SerialCode
import qualified Synthesizer.LLVM.Frame.Stereo as Stereo

import qualified LLVM.Extra.Multi.Vector as MultiVector
import qualified LLVM.Extra.Multi.Value as MultiValue
import qualified LLVM.Extra.Memory as Memory
import qualified LLVM.Extra.Arithmetic as A
import qualified LLVM.Extra.Tuple as Tuple

import qualified LLVM.Core as LLVM

import qualified Type.Data.Num.Decimal as TypeNum

import Data.Word (Word32)

import qualified Control.Monad.Trans.State as MS
import qualified Control.Applicative as App
import Control.Monad (foldM, replicateM, (<=<))
import Control.Applicative (liftA2, liftA3, (<$>))

import qualified Data.Traversable as Trav
import qualified Data.List.HT as ListHT
import qualified Data.List as List
import Data.Tuple.HT (mapSnd, fst3, snd3, thd3)

import Prelude hiding (Read, replicate, reverse, iterate, last)



newtype Constant n a = Constant a

constant :: (TypeNum.Positive n) => a -> Constant n a
constant = Constant

instance Functor (Constant n) where
   fmap f (Constant a) = Constant (f a)

instance App.Applicative (Constant n) where
   pure = Constant
   Constant f <*> Constant a = Constant (f a)

instance (Tuple.Phi a) => Tuple.Phi (Constant n a) where
   phi bb (Constant a) = Constant <$> Tuple.phi bb a
   addPhi bb (Constant a) (Constant b) = Tuple.addPhi bb a b

instance (Tuple.Undefined a) => Tuple.Undefined (Constant n a) where
   undef = Tuple.undefPointed



instance (TypeNum.Positive n) => Sized (Constant n a) where
   type Size (Constant n a) = n

instance
   (TypeNum.Positive n, Tuple.Phi a, Tuple.Undefined a) =>
      Read (Constant n a) where

   type Element (Constant n a) = a
   type ReadIt (Constant n a) = a

   extract _k (Constant a) = return a

   readStart (Constant a) = return $ Iterator a
   readNext it@(Iterator a) = return (a, it)



newtype Iterator mode it v = Iterator {unIterator :: it}
   deriving (Tuple.Undefined)

instance Tuple.Phi it => Tuple.Phi (Iterator mode it v) where
   phi bb (Iterator x) = fmap Iterator $ Tuple.phi bb x
   addPhi bb (Iterator x) (Iterator y) = Tuple.addPhi bb x y


type ReadIterator = Iterator ReadMode
type WriteIterator = Iterator WriteMode

data ReadMode
data WriteMode


instance (Memory.C it) => Memory.C (Iterator mode it v) where
   type Struct (Iterator mode it v) = Memory.Struct it
   load = Memory.loadNewtype Iterator
   store = Memory.storeNewtype (\(Iterator v) -> v)
   decompose = Memory.decomposeNewtype Iterator
   compose = Memory.composeNewtype (\(Iterator v) -> v)


fmapIt ::
   (ita -> itb) -> (va -> vb) ->
   Iterator mode ita va -> Iterator mode itb vb
fmapIt f _ (Iterator a) = Iterator (f a)


combineIt2 ::
   Iterator mode xa va -> Iterator mode xb vb ->
   Iterator mode (xa,xb) (va,vb)
combineIt2 (Iterator va) (Iterator vb) = Iterator (va,vb)

combineIt3 ::
   Iterator mode xa va -> Iterator mode xb vb -> Iterator mode xc vc ->
   Iterator mode (xa,xb,xc) (va,vb,vc)
combineIt3 (Iterator va) (Iterator vb) (Iterator vc) = Iterator (va,vb,vc)

combineItFunctor ::
   (Functor f) => f (Iterator mode x v) -> Iterator mode (f x) (f v)
combineItFunctor = Iterator . fmap unIterator

sequenceItFunctor ::
   (Functor f) => Iterator mode (f it) (f v) -> f (Iterator mode it v)
sequenceItFunctor = fmap Iterator . unIterator


withSize :: Sized v => (Int -> m v) -> m v
withSize =
   let sz :: (Sized v) => TypeNum.Singleton (Size v) -> (Int -> m v) -> m v
       sz n f = f (TypeNum.integralFromSingleton n)
   in  sz TypeNum.singleton

size :: (Sized v, Integral i) => v -> i
size =
   let sz :: (Sized v, Integral i) => TypeNum.Singleton (Size v) -> v -> i
       sz n _ = TypeNum.integralFromSingleton n
   in  sz TypeNum.singleton

sizeOfIterator :: (Sized v, Integral i) => Iterator mode it v -> i
sizeOfIterator =
   let sz :: (Sized v, Integral i) =>
               TypeNum.Singleton (Size v) -> Iterator mode it v -> i
       sz n _ = TypeNum.integralFromSingleton n
   in  sz TypeNum.singleton


{- |
The type parameter @v@ shall be a @MultiVector@ or @MultiValue Serial@
or a wrapper around one or more such things sharing the same size.
-}
class (TypeNum.Positive (Size v)) => Sized v where
   type Size v

class
   (Sized v,
    Tuple.Phi (ReadIt v), Tuple.Undefined (ReadIt v),
    Tuple.Phi v, Tuple.Undefined v) =>
      Read v where

   type Element v
   type ReadIt v

   extract :: LLVM.Value Word32 -> v -> LLVM.CodeGenFunction r (Element v)

   dissect :: v -> LLVM.CodeGenFunction r [Element v]
   dissect x = mapM (flip extract x . LLVM.valueOf) (take (size x) [0..])

   readStart :: v -> LLVM.CodeGenFunction r (ReadIterator (ReadIt v) v)
   readNext ::
      ReadIterator (ReadIt v) v ->
      LLVM.CodeGenFunction r (Element v, ReadIterator (ReadIt v) v)

class
   (Read v, Tuple.Phi (WriteIt v), Tuple.Undefined (WriteIt v)) =>
      Write v where
   type WriteIt v

   insert :: LLVM.Value Word32 -> Element v -> v -> LLVM.CodeGenFunction r v

   assemble :: [Element v] -> LLVM.CodeGenFunction r v
   assemble =
      foldM (\v (k,x) -> insert (LLVM.valueOf k) x v) Tuple.undef . zip [0..]

   writeStart :: LLVM.CodeGenFunction r (WriteIterator (WriteIt v) v)
   writeNext ::
      Element v -> WriteIterator (WriteIt v) v ->
      LLVM.CodeGenFunction r (WriteIterator (WriteIt v) v)
   writeStop :: WriteIterator (WriteIt v) v -> LLVM.CodeGenFunction r v

class (Write v, Tuple.Phi (WriteIt v), Tuple.Zero (WriteIt v)) => Zero v where
   -- initializes the target with zeros
   -- you may only call 'writeStop' on the result of 'writeZero'
   writeZero :: LLVM.CodeGenFunction r (WriteIterator (WriteIt v) v)



instance (TypeNum.Positive n) => Sized (MultiVector.T n a) where
   type Size (MultiVector.T n a) = n

instance (TypeNum.Positive n, MultiVector.C a) => Read (MultiVector.T n a) where

   type Element (MultiVector.T n a) = MultiValue.T a
   type ReadIt (MultiVector.T n a) = MultiVector.T n a

   extract = MultiVector.extract

   readStart v = return $ Iterator v
   readNext (Iterator v) =
      mapSnd Iterator <$> MultiVector.shiftDown MultiValue.undef v

instance
      (TypeNum.Positive n, MultiVector.C a) => Write (MultiVector.T n a) where

   type WriteIt (MultiVector.T n a) = MultiVector.T n a

   insert = MultiVector.insert

   writeStart = return (Iterator MultiVector.undef)
   writeNext x (Iterator v) = Iterator . snd <$> MultiVector.shiftDown x v
   writeStop (Iterator v) = return v

instance (TypeNum.Positive n, MultiVector.C a) => Zero (MultiVector.T n a) where
   writeZero = return (Iterator Tuple.zero)



type Serial n a = SerialCode.Value n a

instance (TypeNum.Positive n) => Sized (Serial n a) where
   type Size (Serial n a) = n

instance (TypeNum.Positive n, MultiVector.C a) => Read (Serial n a) where

   type Element (Serial n a) = MultiValue.T a
   type ReadIt (Serial n a) = Serial n a

   extract = SerialCode.extract

   readStart v = return $ Iterator v
   readNext (Iterator v) =
      mapSnd Iterator <$> SerialCode.shiftDown MultiValue.undef v

instance (TypeNum.Positive n, MultiVector.C a) => Write (Serial n a) where

   type WriteIt (Serial n a) = Serial n a

   insert = SerialCode.insert

   writeStart = return (Iterator Tuple.undef)
   writeNext x (Iterator v) = Iterator . snd <$> SerialCode.shiftDown x v
   writeStop (Iterator v) = return v

instance (TypeNum.Positive n, MultiVector.C a) => Zero (Serial n a) where
   writeZero = return (Iterator Tuple.zero)



instance (Sized va, Sized vb, Size va ~ Size vb) => Sized (va, vb) where
   type Size (va, vb) = Size va

instance (Read va, Read vb, Size va ~ Size vb) => Read (va, vb) where

   type Element (va, vb) = (Element va, Element vb)
   type ReadIt (va, vb) = (ReadIt va, ReadIt vb)

   extract k (va,vb) = liftA2 (,) (extract k va) (extract k vb)

   readStart (va,vb) = liftA2 combineIt2 (readStart va) (readStart vb)
   readNext it = do
      (a, ita) <- readNext $ fmapIt fst fst it
      (b, itb) <- readNext $ fmapIt snd snd it
      return ((a,b), combineIt2 ita itb)

instance (Write va, Write vb, Size va ~ Size vb) => Write (va, vb) where

   type WriteIt (va, vb) = (WriteIt va, WriteIt vb)

   insert k (a,b) (va,vb) =
      liftA2 (,)
         (insert k a va)
         (insert k b vb)

   writeStart = liftA2 combineIt2 writeStart writeStart
   writeNext (a,b) it =
      liftA2 combineIt2
         (writeNext a $ fmapIt fst fst it)
         (writeNext b $ fmapIt snd snd it)
   writeStop it =
      liftA2 (,)
         (writeStop (fmapIt fst fst it))
         (writeStop (fmapIt snd snd it))

instance (Zero va, Zero vb, Size va ~ Size vb) => Zero (va, vb) where
   writeZero = liftA2 combineIt2 writeZero writeZero


instance
   (Sized va, Sized vb, Sized vc, Size va ~ Size vb, Size vb ~ Size vc) =>
      Sized (va, vb, vc) where
   type Size (va, vb, vc) = Size va

instance
   (Read va, Read vb, Read vc, Size va ~ Size vb, Size vb ~ Size vc) =>
      Read (va, vb, vc) where

   type Element (va, vb, vc) = (Element va, Element vb, Element vc)
   type ReadIt (va, vb, vc) = (ReadIt va, ReadIt vb, ReadIt vc)

   extract k (va,vb,vc) =
      liftA3 (,,)
         (extract k va)
         (extract k vb)
         (extract k vc)

   readStart (va,vb,vc) =
      liftA3 combineIt3 (readStart va) (readStart vb) (readStart vc)
   readNext it = do
      (a, ita) <- readNext $ fmapIt fst3 fst3 it
      (b, itb) <- readNext $ fmapIt snd3 snd3 it
      (c, itc) <- readNext $ fmapIt thd3 thd3 it
      return ((a,b,c), combineIt3 ita itb itc)


instance
   (Write va, Write vb, Write vc, Size va ~ Size vb, Size vb ~ Size vc) =>
      Write (va, vb, vc) where

   type WriteIt (va, vb, vc) = (WriteIt va, WriteIt vb, WriteIt vc)

   insert k (a,b,c) (va,vb,vc) =
      liftA3 (,,)
         (insert k a va)
         (insert k b vb)
         (insert k c vc)

   writeStart = liftA3 combineIt3 writeStart writeStart writeStart
   writeNext (a,b,c) it =
      liftA3 combineIt3
         (writeNext a $ fmapIt fst3 fst3 it)
         (writeNext b $ fmapIt snd3 snd3 it)
         (writeNext c $ fmapIt thd3 thd3 it)
   writeStop it =
      liftA3 (,,)
         (writeStop (fmapIt fst3 fst3 it))
         (writeStop (fmapIt snd3 snd3 it))
         (writeStop (fmapIt thd3 thd3 it))

instance
   (Zero va, Zero vb, Zero vc, Size va ~ Size vb, Size vb ~ Size vc) =>
      Zero (va, vb, vc) where

   writeZero = liftA3 combineIt3 writeZero writeZero writeZero


instance (Sized value) => Sized (Stereo.T value) where
   type Size (Stereo.T value) = Size value

instance (Read v) => Read (Stereo.T v) where

   type Element (Stereo.T v) = Stereo.T (Element v)
   type ReadIt (Stereo.T v) = Stereo.T (ReadIt v)

   extract = extractTraversable

   readStart = readStartTraversable
   readNext = readNextTraversable

instance (Write v) => Write (Stereo.T v) where

   type WriteIt (Stereo.T v) = Stereo.T (WriteIt v)

   insert = insertTraversable

   writeStart = writeStartTraversable
   writeNext = writeNextTraversable
   writeStop = writeStopTraversable

instance (Zero v) => Zero (Stereo.T v) where

   writeZero = writeZeroTraversable


insertTraversable ::
   (Write v, Trav.Traversable f, App.Applicative f) =>
   LLVM.Value Word32 -> f (Element v) -> f v -> LLVM.CodeGenFunction r (f v)
insertTraversable n a v =
   Trav.sequence (liftA2 (insert n) a v)

extractTraversable ::
   (Read v, Trav.Traversable f) =>
   LLVM.Value Word32 -> f v -> LLVM.CodeGenFunction r (f (Element v))
extractTraversable n v =
   Trav.mapM (extract n) v


readStartTraversable ::
   (Trav.Traversable f, App.Applicative f, Read v) =>
   f v -> LLVM.CodeGenFunction r (ReadIterator (f (ReadIt v)) (f v))
readNextTraversable ::
   (Trav.Traversable f, App.Applicative f, Read v) =>
   ReadIterator (f (ReadIt v)) (f v) ->
   LLVM.CodeGenFunction r (f (Element v), ReadIterator (f (ReadIt v)) (f v))

readStartTraversable v =
   fmap combineItFunctor $ Trav.mapM readStart v

readNextTraversable it = do
   st <- Trav.mapM readNext $ sequenceItFunctor it
   return (fmap fst st, combineItFunctor $ fmap snd st)


writeStartTraversable ::
   (Trav.Traversable f, App.Applicative f, Write v) =>
   LLVM.CodeGenFunction r (WriteIterator (f (WriteIt v)) (f v))
writeNextTraversable ::
   (Trav.Traversable f, App.Applicative f, Write v) =>
   f (Element v) -> WriteIterator (f (WriteIt v)) (f v) ->
   LLVM.CodeGenFunction r (WriteIterator (f (WriteIt v)) (f v))
writeStopTraversable ::
   (Trav.Traversable f, App.Applicative f, Write v) =>
   WriteIterator (f (WriteIt v)) (f v) -> LLVM.CodeGenFunction r (f v)
writeZeroTraversable ::
   (Trav.Traversable f, App.Applicative f, Zero v) =>
   LLVM.CodeGenFunction r (WriteIterator (f (WriteIt v)) (f v))

writeStartTraversable =
   fmap combineItFunctor $ Trav.sequence $ App.pure writeStart

writeNextTraversable x it =
   fmap combineItFunctor $ Trav.sequence $
   liftA2 writeNext x $ sequenceItFunctor it

writeStopTraversable = Trav.mapM writeStop . sequenceItFunctor

writeZeroTraversable =
   fmap combineItFunctor $ Trav.sequence $ App.pure writeZero


modify ::
   (Write v, Element v ~ a) =>
   LLVM.Value Word32 ->
   (a -> LLVM.CodeGenFunction r a) ->
   v -> LLVM.CodeGenFunction r v
modify k f v = flip (insert k) v =<< f =<< extract k v


last :: (Read v) => v -> LLVM.CodeGenFunction r (Element v)
last v = extract (LLVM.valueOf (size v - 1 :: Word32)) v

subsample :: (Read v) => v -> LLVM.CodeGenFunction r (Element v)
subsample v = extract (A.zero :: LLVM.Value Word32) v

-- this will be translated to an efficient pshufd
upsample :: (Write v) => Element v -> LLVM.CodeGenFunction r v
upsample x = withSize $ \n -> assemble $ List.replicate n x


iterate ::
   (Write v) =>
   (Element v -> LLVM.CodeGenFunction r (Element v)) ->
   Element v -> LLVM.CodeGenFunction r v
iterate f x =
   withSize $ \n ->
      assemble =<<
      (flip MS.evalStateT x $
       replicateM n $
       MS.StateT $ \x0 -> do x1 <- f x0; return (x0,x1))

reverse ::
   (Write v) =>
   v -> LLVM.CodeGenFunction r v
reverse =
   assemble . List.reverse <=< dissect

shiftUp ::
   (Write v) =>
   Element v -> v -> LLVM.CodeGenFunction r (Element v, v)
shiftUp x v =
   ListHT.switchR
      (return (x,v))
      (\ys0 y -> fmap ((,) y) $ assemble (x:ys0))
   =<<
   dissect v


shiftUpMultiZero ::
   (Write v, A.Additive (Element v)) =>
   Int -> v -> LLVM.CodeGenFunction r v
shiftUpMultiZero n v =
   assemble . take (size v) . (List.replicate n A.zero ++) =<< dissect v

shiftDownMultiZero ::
   (Write v, A.Additive (Element v)) =>
   Int -> v -> LLVM.CodeGenFunction r v
shiftDownMultiZero n v =
   assemble . take (size v) . (++ List.repeat A.zero) . List.drop n
      =<< dissect v