{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE EmptyDataDecls #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE StandaloneDeriving #-}
{- |
A special vector type that represents a time-sequence of samples.
This way we can distinguish safely between LLVM vectors
used for parallel signals and pipelines and
those used for chunky processing of scalar signals.
For the chunky processing this data type allows us
to derive the factor from the type
that time constants have to be multiplied with.
-}
module Synthesizer.LLVM.Frame.SerialVector (
   Plain, Value,
   plain, value, constant,

   Read, Element, ReadIt, extract, readStart, readNext,
   C, WriteIt, insert, writeStart, writeNext, writeStop,
   Zero, writeZero,
   Iterator(Iterator), ReadIterator, WriteIterator, ReadMode, WriteMode,

   Sized, Size, size, sizeOfIterator, withSize,

   insertTraversable, extractTraversable,
   readStartTraversable, readNextTraversable,
   writeStartTraversable, writeNextTraversable, writeStopTraversable,

   extractAll, assemble, modify,
   upsample, subsample,
   cumulate, iterate, iteratePlain, reverse,
   shiftUp, shiftUpMultiZero, shiftDownMultiZero,
   replicate, replicate_, replicateOf, fromList, fromFixedList,
   mapPlain, mapV, zipV,
   ) where

import qualified Synthesizer.LLVM.Frame.Stereo as Stereo

import qualified LLVM.Extra.Vector as Vector
import qualified LLVM.Extra.Arithmetic as A
import qualified LLVM.Extra.Tuple as Tuple
import qualified LLVM.Extra.Storable as Storable
import qualified LLVM.Extra.Marshal as Marshal
import qualified LLVM.Extra.Memory as Memory

import qualified LLVM.Core as LLVM

import qualified Type.Data.Num.Decimal as TypeNum

import qualified Foreign.Storable as St
import Data.Word (Word32)

import Control.Monad (liftM2, liftM3, foldM, replicateM, (<=<))
import Control.Applicative (liftA2)
import qualified Control.Monad.Trans.State as MS
import qualified Control.Applicative as App
import qualified Data.Traversable as Trav

import qualified Data.NonEmpty.Class as NonEmptyC
import qualified Data.NonEmpty as NonEmpty
import qualified Data.List.HT as ListHT
import qualified Data.List as List
import Data.Tuple.HT (mapSnd, fst3, snd3, thd3)

import Prelude hiding (Read, replicate, reverse, iterate)

{- |
This datatype can be used for both Haskell vectors and LLVM.Value vectors.
It should not contain tuples of vectors,
since the interpretation is:
"Everything inside Cons will be virtually concatenated."

We tried to use distinct types (T n a) and (Value n a)
for Haskell and LLVM objects, respectively,
but then GHC-6.12.3 to GHC-7.4.1 could not perform the GeneralizedNewtypeDeriving,
because it was not able to add a (IsPositive n ~ True) constraint
to the instances.

The disadvantage of this approach is
that we cannot have a type that contains both parallel and serial data.
-}
newtype T v = Cons v
   deriving
      ( Eq
      , St.Storable
      , Tuple.Zero
      , Tuple.Undefined
      , A.IntegerConstant
      , A.RationalConstant
      , Num
      )
-- Not derived (kept for reference):
--      SoV.IntegerConstant, SoV.RationalConstant, SoV.TranscendentalConstant)

-- | Phi handling is delegated to the wrapped value; 'Cons' is transparent.
instance (Tuple.Phi v) => Tuple.Phi (T v) where
   phi bb = lift1 (Tuple.phi bb)
   addPhi bb (Cons lhs) (Cons rhs) = Tuple.addPhi bb lhs rhs

-- | Additive operations act on the wrapped value and re-wrap the result.
instance (A.Additive v) => A.Additive (T v) where
   zero = Cons A.zero
   neg (Cons x) = fmap Cons (A.neg x)
   add (Cons x) (Cons y) = fmap Cons (A.add x y)
   sub (Cons x) (Cons y) = fmap Cons (A.sub x y)

-- | Multiplication of the wrapped values.
instance (A.PseudoRing v) => A.PseudoRing (T v) where
   mul (Cons x) (Cons y) = fmap Cons (A.mul x y)

-- | Ordering-related operations on the wrapped value.
instance (A.Real v) => A.Real (T v) where
   abs (Cons x) = fmap Cons (A.abs x)
   signum (Cons x) = fmap Cons (A.signum x)
   min (Cons x) (Cons y) = fmap Cons (A.min x y)
   max (Cons x) (Cons y) = fmap Cons (A.max x y)

-- | Integer/fractional part splitting of the wrapped value.
instance (A.Fraction v) => A.Fraction (T v) where
   fraction (Cons x) = fmap Cons (A.fraction x)
   truncate (Cons x) = fmap Cons (A.truncate x)

-- | Division of the wrapped values.
instance (A.Field v) => A.Field (T v) where
   fdiv (Cons x) (Cons y) = fmap Cons (A.fdiv x y)

-- | Square root of the wrapped value.
instance (A.Algebraic v) => A.Algebraic (T v) where
   sqrt (Cons x) = fmap Cons (A.sqrt x)

-- | Transcendental functions of the wrapped value.
instance (A.Transcendental v) => A.Transcendental (T v) where
   pi = fmap Cons A.pi
   sin (Cons x) = fmap Cons (A.sin x)
   cos (Cons x) = fmap Cons (A.cos x)
   exp (Cons x) = fmap Cons (A.exp x)
   log (Cons x) = fmap Cons (A.log x)
   pow (Cons x) (Cons y) = fmap Cons (A.pow x y)

-- | Apply an effectful unary function to the wrapped value and re-wrap the result.
lift1 :: Functor f => (a -> f b) -> T a -> f (T b)
lift1 f = fmap Cons . f . (\(Cons x) -> x)

-- | Binary counterpart of 'lift1': unwrap both arguments, re-wrap the result.
lift2 :: Functor f => (a -> b -> f c) -> T a -> T b -> f (T c)
lift2 f (Cons x) = lift1 (f x)

-- The scalar type of a serial vector is that of the wrapped value.
type instance A.Scalar (T v) = A.Scalar v

-- | Scaling by a scalar, applied to the wrapped value.
instance (A.PseudoModule v) => A.PseudoModule (T v) where
   scale a = lift1 (A.scale a)

-- | Serial vector wrapping a Haskell-side 'LLVM.Vector' of @n@ elements.
type Plain n a = T (LLVM.Vector n a)
-- | Serial vector wrapping an 'LLVM.Value' of vector type (code-generation side).
type Value n a = T (LLVM.Value (LLVM.Vector n a))

-- | Mark a Haskell-side vector as a serial vector.
plain :: LLVM.Vector n a -> Plain n a
plain vec = Cons vec

-- | Mark an LLVM vector value as a serial vector.
value :: LLVM.Value (LLVM.Vector n a) -> Value n a
value vec = Cons vec

-- | Build a plain serial vector from a single element via the vector's 'App.pure'.
replicate :: (TypeNum.Positive n) => a -> Plain n a
replicate = Cons . App.pure

-- | Like 'replicate'; the singleton argument only fixes the size @n@ and is never inspected.
replicate_ :: (TypeNum.Positive n) => TypeNum.Singleton n -> a -> Plain n a
replicate_ _size x = replicate x

-- | Constant LLVM serial vector built from one element via 'App.pure' and 'LLVM.valueOf'.
replicateOf :: (TypeNum.Positive n, LLVM.IsPrimitive a, LLVM.IsConst a) => a -> Value n a
replicateOf = Cons . LLVM.valueOf . App.pure

-- | Build a plain serial vector from a non-empty list using 'LLVM.cyclicVector'.
fromList :: (TypeNum.Positive n) => NonEmpty.T [] a -> Plain n a
fromList xs = Cons (LLVM.cyclicVector xs)

-- | Build a plain serial vector from a list whose length is fixed by the type.
fromFixedList ::
   (TypeNum.Positive n) =>
   LLVM.FixedList (TypeNum.ToUnary n) a -> Plain n a
fromFixedList xs = Cons (LLVM.vector xs)

-- | Wrap 'Vector.constant' of the given element as a serial vector.
constant :: (TypeNum.Positive n) => a -> T (Vector.Constant n a)
constant x = Cons (Vector.constant x)

-- | Iterator for sequential reading or writing of a serial vector.
-- Only the state @it@ is stored; @mode@ ('ReadMode' or 'WriteMode')
-- and the vector type @v@ are phantom parameters that keep iterators
-- of different direction and target type apart at the type level.
newtype Iterator mode it v = Iterator {unIterator :: it}
   deriving (Tuple.Undefined)

-- | Phi handling is delegated to the iterator state.
instance Tuple.Phi it => Tuple.Phi (Iterator mode it v) where
   phi bb = fmap Iterator . Tuple.phi bb . unIterator
   addPhi bb a b = Tuple.addPhi bb (unIterator a) (unIterator b)

-- | Iterator returned by 'readStart' and advanced by 'readNext'.
type ReadIterator = Iterator ReadMode
-- | Iterator returned by 'writeStart' or 'writeZero' and finished by 'writeStop'.
type WriteIterator = Iterator WriteMode

-- | Phantom tag for read iterators (has no values).
data ReadMode
-- | Phantom tag for write iterators (has no values).
data WriteMode

-- | Memory layout of an iterator is the layout of its state.
instance (Memory.C it) => Memory.C (Iterator mode it v) where
   type Struct (Iterator mode it v) = Memory.Struct it
   decompose = Memory.decomposeNewtype Iterator
   compose = Memory.composeNewtype unIterator
   load = Memory.loadNewtype Iterator
   store = Memory.storeNewtype unIterator

-- | Map the iterator state.  The second function is never called;
-- it only determines the phantom vector type of the result.
fmapIt ::
   (ita -> itb) -> (va -> vb) ->
   Iterator mode ita va -> Iterator mode itb vb
fmapIt f _ = Iterator . f . unIterator

-- | Pair two iterators of the same mode into one iterator over a pair.
combineIt2 :: Iterator mode xa va -> Iterator mode xb vb -> Iterator mode (xa,xb) (va,vb)
combineIt2 a b = Iterator (unIterator a, unIterator b)

-- | Combine three iterators of the same mode into one iterator over a triple.
combineIt3 :: Iterator mode xa va -> Iterator mode xb vb -> Iterator mode xc vc -> Iterator mode (xa,xb,xc) (va,vb,vc)
combineIt3 a b c = Iterator (unIterator a, unIterator b, unIterator c)

-- | Turn a container of iterators into one iterator whose state is the container of states.
combineItFunctor ::
   (Functor f) =>
   f (Iterator mode x v) -> Iterator mode (f x) (f v)
combineItFunctor its = Iterator (fmap unIterator its)

-- | Inverse of 'combineItFunctor': split an iterator over a container into a container of iterators.
sequenceItFunctor ::
   (Functor f) =>
   Iterator mode (f it) (f v) ->
   f (Iterator mode it v)
sequenceItFunctor (Iterator states) = fmap Iterator states

   (TypeNum.Positive (Size v), Sized v,
    Tuple.Phi (ReadIt v), Tuple.Undefined (ReadIt v),
    Tuple.Phi v, Tuple.Undefined v) =>
      Read v where

   type Element v :: *
   type ReadIt v :: *

   extract :: LLVM.Value Word32 -> v -> LLVM.CodeGenFunction r (Element v)

   extractAll :: v -> LLVM.CodeGenFunction r [Element v]
   extractAll x =
         (flip extract x . LLVM.valueOf)
         (take (size x) [0..])

   readStart :: v -> LLVM.CodeGenFunction r (ReadIterator (ReadIt v) v)
   readNext ::
      ReadIterator (ReadIt v) v ->
      LLVM.CodeGenFunction r (Element v, ReadIterator (ReadIt v) v)

-- | Serial vectors whose elements can also be written.
class (Read v, Tuple.Phi (WriteIt v), Tuple.Undefined (WriteIt v)) => C v where
   -- | State carried by a 'WriteIterator' over @v@.
   type WriteIt v :: *

   -- | Replace the element at the given index.
   insert :: LLVM.Value Word32 -> Element v -> v -> LLVM.CodeGenFunction r v

   -- | Build a value from a list of elements, starting from an undefined
   -- value and inserting element number @k@ at index @k@.
   assemble :: [Element v] -> LLVM.CodeGenFunction r v
   assemble xs =
      foldM put Tuple.undef (zip [0..] xs)
     where
      put acc (k, x) = insert (LLVM.valueOf k) x acc

   -- | Begin sequential writing.
   writeStart :: LLVM.CodeGenFunction r (WriteIterator (WriteIt v) v)
   -- | Write the next element and return the advanced iterator.
   writeNext ::
      Element v -> WriteIterator (WriteIt v) v ->
      LLVM.CodeGenFunction r (WriteIterator (WriteIt v) v)
   -- | Finish writing and return the resulting value.
   writeStop :: WriteIterator (WriteIt v) v -> LLVM.CodeGenFunction r v

-- | Writable serial vectors whose write iterator can start from an all-zero state.
class (C v, Tuple.Phi (WriteIt v), Tuple.Zero (WriteIt v)) => Zero v where
   -- | Initializes the target with zeros.
   -- You may only call 'writeStop' on the result of 'writeZero'.
   writeZero :: LLVM.CodeGenFunction r (WriteIterator (WriteIt v) v)

-- | A serial vector has as many elements as the wrapped vector ('Vector.Size').
instance (Vector.Simple v) => Sized (T v) where
   type Size (T v) = Vector.Size v

{- |
This instance also allows wrapping tuples of vectors,
but you cannot reasonably use them,
because that would mean serializing vectors with different element types.
-}
instance (Vector.Simple v) => Read (T v) where

   type Element (T v) = Vector.Element v
   type ReadIt (T v) = v

   extract idx (Cons vec) = Vector.extract idx vec

   -- The iterator state is the vector itself.
   readStart (Cons vec) = return (Iterator vec)
   -- Take the element at index 0, then rotate with 'Vector.rotateDown'
   -- (presumably moving the following element into index 0).
   readNext (Iterator vec) = do
      x <- Vector.extract (LLVM.valueOf 0) vec
      fmap (\rest -> (x, Iterator rest)) (Vector.rotateDown vec)

-- | Writing a serial vector shifts each new element in with 'Vector.shiftDown'.
instance (Vector.C v) => C (T v) where
   type WriteIt (T v) = v

   insert idx x (Cons vec) = fmap Cons (Vector.insert idx x vec)

   writeStart = return $ Iterator Tuple.undef
   -- 'Vector.shiftDown' yields a pair; only its second component is kept
   -- as the new iterator state.
   writeNext x (Iterator vec) =
      fmap (Iterator . snd) (Vector.shiftDown x vec)
   writeStop = return . Cons . unIterator

-- | Start writing from an all-zero vector ('Tuple.zero').
instance (Vector.C v, Tuple.Zero v) => Zero (T v) where
   writeZero = return $ Iterator Tuple.zero

   (Read va, Read vb, Size va ~ Size vb) =>
      Read (va, vb) where

   type Element (va, vb) = (Element va, Element vb)
   type ReadIt (va, vb) = (ReadIt va, ReadIt vb)

   extract k (va,vb) =
      liftM2 (,)
         (extract k va)
         (extract k vb)

   readStart (va,vb) =
      liftM2 combineIt2 (readStart va) (readStart vb)
   readNext it = do
      (a, ita) <- readNext $ fmapIt fst fst it
      (b, itb) <- readNext $ fmapIt snd snd it
      return ((a,b), combineIt2 ita itb)

   (C va, C vb, Size va ~ Size vb) =>
      C (va, vb) where

   type WriteIt (va, vb) = (WriteIt va, WriteIt vb)

   insert k (a,b) (va,vb) =
      liftM2 (,)
         (insert k a va)
         (insert k b vb)

   writeStart =
      liftM2 combineIt2 writeStart writeStart
   writeNext (a,b) it =
      liftM2 combineIt2
         (writeNext a $ fmapIt fst fst it)
         (writeNext b $ fmapIt snd snd it)
   writeStop it =
      liftM2 (,)
         (writeStop (fmapIt fst fst it))
         (writeStop (fmapIt snd snd it))

   (Zero va, Zero vb, Size va ~ Size vb) =>
      Zero (va, vb) where

   writeZero =
      liftM2 combineIt2 writeZero writeZero

   (Read va, Read vb, Read vc,
    Size va ~ Size vb,
    Size vb ~ Size vc) =>
      Read (va, vb, vc) where

   type Element (va, vb, vc) = (Element va, Element vb, Element vc)
   type ReadIt (va, vb, vc) = (ReadIt va, ReadIt vb, ReadIt vc)

   extract k (va,vb,vc) =
      liftM3 (,,)
         (extract k va)
         (extract k vb)
         (extract k vc)

   readStart (va,vb,vc) =
      liftM3 combineIt3 (readStart va) (readStart vb) (readStart vc)
   readNext it = do
      (a, ita) <- readNext $ fmapIt fst3 fst3 it
      (b, itb) <- readNext $ fmapIt snd3 snd3 it
      (c, itc) <- readNext $ fmapIt thd3 thd3 it
      return ((a,b,c), combineIt3 ita itb itc)

   (C va, C vb, C vc,
    Size va ~ Size vb,
    Size vb ~ Size vc) =>
      C (va, vb, vc) where

   type WriteIt (va, vb, vc) = (WriteIt va, WriteIt vb, WriteIt vc)

   insert k (a,b,c) (va,vb,vc) =
      liftM3 (,,)
         (insert k a va)
         (insert k b vb)
         (insert k c vc)

   writeStart =
      liftM3 combineIt3 writeStart writeStart writeStart
   writeNext (a,b,c) it =
      liftM3 combineIt3
         (writeNext a $ fmapIt fst3 fst3 it)
         (writeNext b $ fmapIt snd3 snd3 it)
         (writeNext c $ fmapIt thd3 thd3 it)
   writeStop it =
      liftM3 (,,)
         (writeStop (fmapIt fst3 fst3 it))
         (writeStop (fmapIt snd3 snd3 it))
         (writeStop (fmapIt thd3 thd3 it))

   (Zero va, Zero vb, Zero vc,
    Size va ~ Size vb,
    Size vb ~ Size vc) =>
      Zero (va, vb, vc) where

   writeZero =
      liftM3 combineIt3 writeZero writeZero writeZero

-- | Stereo frames are read component-wise via the 'Trav.Traversable' helpers.
instance (Read v) => Read (Stereo.T v) where

   type Element (Stereo.T v) = Stereo.T (Element v)
   type ReadIt (Stereo.T v) = Stereo.T (ReadIt v)

   readStart = readStartTraversable
   readNext = readNextTraversable

   extract = extractTraversable

-- | Stereo frames are written component-wise via the 'Trav.Traversable' helpers.
instance (C v) => C (Stereo.T v) where

   type WriteIt (Stereo.T v) = Stereo.T (WriteIt v)

   writeStart = writeStartTraversable
   writeNext = writeNextTraversable
   writeStop = writeStopTraversable

   insert = insertTraversable

-- | Zero-initialized writing of a stereo frame, one 'writeZero' per component.
instance (Zero v) => Zero (Stereo.T v) where

   writeZero = writeZeroTraversable

-- | Replace the element at index @k@ by the result of applying @f@ to it.
modify ::
   (C v) =>
   LLVM.Value Word32 ->
   (Element v -> LLVM.CodeGenFunction r (Element v)) ->
   v -> LLVM.CodeGenFunction r v
modify k f v =
   extract k v >>= f >>= \x -> insert k x v

-- | Keep only the first element (index 0) of a serial vector.
subsample ::
   (Read v) =>
   v -> LLVM.CodeGenFunction r (Element v)
subsample = extract firstIndex
   where
      firstIndex :: LLVM.Value Word32
      firstIndex = A.zero

-- | Fill every position of a serial vector with the same element.
-- This will be translated to an efficient pshufd.
upsample ::
   (C v) =>
   Element v -> LLVM.CodeGenFunction r v
upsample x =
   withSize (assemble . flip List.replicate x)

-- | Serial-vector wrapper around 'Vector.cumulate'.
cumulate ::
   (Vector.Arithmetic a, TypeNum.Positive n) =>
   LLVM.Value a -> Value n a ->
   LLVM.CodeGenFunction r (LLVM.Value a, Value n a)
cumulate x (Cons v) = do
   (scalar, vec) <- Vector.cumulate x v
   return (scalar, Cons vec)

-- | Apply a pure function to every element of a plain serial vector.
mapPlain ::
   (TypeNum.Positive n) => (a -> b) -> Plain n a -> Plain n b
mapPlain f = plain . fmap f . (\(Cons v) -> v)

-- | Plain serial vector @[x, f x, f (f x), ...]@ built via 'fromList'.
iteratePlain ::
   (TypeNum.Positive n) => (a -> a) -> a -> Plain n a
iteratePlain f = fromList . NonEmptyC.iterate f

-- | Build @[x, f x, f (f x), ...]@ with an effectful step function.
-- @f@ is run once per element (so the result of the last call is discarded).
iterate ::
   (C v) =>
   (Element v -> LLVM.CodeGenFunction r (Element v)) ->
   Element v -> LLVM.CodeGenFunction r v
iterate f x =
   withSize $ \n ->
      assemble =<< MS.evalStateT (replicateM n step) x
  where
   -- emit the current state, advance it with f
   step = MS.StateT $ \cur -> do
      next <- f cur
      return (cur, next)

-- | Reverse the element order of a serial vector.
reverse ::
   (C v) =>
   v -> LLVM.CodeGenFunction r v
reverse v =
   extractAll v >>= assemble . List.reverse

-- | Shift all elements one position towards higher indices.
-- The given element enters at index 0; the previously last element
-- is returned together with the shifted value.
-- If 'extractAll' yields no elements, the inputs are returned unchanged.
shiftUp ::
   (C v) =>
   Element v -> v -> LLVM.CodeGenFunction r (Element v, v)
shiftUp x v =
   ListHT.switchR
      (return (x,v))
      (\ys0 y -> fmap ((,) y) $ assemble (x:ys0))
   =<< extractAll v

-- | Shift elements up by @n@ positions, filling the low positions with zero.
shiftUpMultiZero ::
   (C v, A.Additive (Element v)) =>
   Int -> v -> LLVM.CodeGenFunction r v
shiftUpMultiZero n v = do
   xs <- extractAll v
   assemble $ take (size v) $ List.replicate n A.zero ++ xs

-- | Shift elements down by @n@ positions, filling the high positions with zero.
shiftDownMultiZero ::
   (C v, A.Additive (Element v)) =>
   Int -> v -> LLVM.CodeGenFunction r v
shiftDownMultiZero n v = do
   xs <- extractAll v
   assemble $ take (size v) $ List.drop n xs ++ List.repeat A.zero

-- | 'insert' applied component-wise, pairing elements and targets with 'liftA2'.
insertTraversable ::
   (C v, Trav.Traversable f, App.Applicative f) =>
   LLVM.Value Word32 -> f (Element v) -> f v -> LLVM.CodeGenFunction r (f v)
insertTraversable n a v =
   let actions = liftA2 (insert n) a v
   in  Trav.sequence actions

-- | 'extract' applied to every component at the same index.
extractTraversable ::
   (Read v, Trav.Traversable f) =>
   LLVM.Value Word32 -> f v -> LLVM.CodeGenFunction r (f (Element v))
extractTraversable = Trav.mapM . extract

-- | 'readStart' for every component, combined into one iterator.
readStartTraversable ::
   (Trav.Traversable f, App.Applicative f, Read v) =>
   f v -> LLVM.CodeGenFunction r (ReadIterator (f (ReadIt v)) (f v))
readStartTraversable =
   fmap combineItFunctor . Trav.mapM readStart

-- | 'readNext' for every component; returns the elements and the combined iterator.
readNextTraversable ::
   (Trav.Traversable f, App.Applicative f, Read v) =>
   ReadIterator (f (ReadIt v)) (f v) ->
   LLVM.CodeGenFunction r (f (Element v), ReadIterator (f (ReadIt v)) (f v))
readNextTraversable it =
   fmap
      (\st -> (fmap fst st, combineItFunctor (fmap snd st)))
      (Trav.mapM readNext (sequenceItFunctor it))

-- | 'writeStart' for every component, combined into one iterator.
writeStartTraversable ::
   (Trav.Traversable f, App.Applicative f, C v) =>
   LLVM.CodeGenFunction r (WriteIterator (f (WriteIt v)) (f v))
writeStartTraversable =
   fmap combineItFunctor (Trav.sequence (App.pure writeStart))

-- | 'writeNext' for every component, pairing elements and iterators with 'liftA2'.
writeNextTraversable ::
   (Trav.Traversable f, App.Applicative f, C v) =>
   f (Element v) -> WriteIterator (f (WriteIt v)) (f v) ->
   LLVM.CodeGenFunction r (WriteIterator (f (WriteIt v)) (f v))
writeNextTraversable x it =
   fmap combineItFunctor
      (Trav.sequence (liftA2 writeNext x (sequenceItFunctor it)))

-- | 'writeStop' for every component.
writeStopTraversable ::
   (Trav.Traversable f, App.Applicative f, C v) =>
   WriteIterator (f (WriteIt v)) (f v) -> LLVM.CodeGenFunction r (f v)
writeStopTraversable it =
   Trav.mapM writeStop (sequenceItFunctor it)

-- | 'writeZero' for every component, combined into one iterator.
writeZeroTraversable ::
   (Trav.Traversable f, App.Applicative f, Zero v) =>
   LLVM.CodeGenFunction r (WriteIterator (f (WriteIt v)) (f v))
writeZeroTraversable =
   fmap combineItFunctor (Trav.sequence (App.pure writeZero))

-- | Lift a Haskell-side serial vector to its LLVM value representation.
instance (Tuple.Value v) => Tuple.Value (T v) where
   type ValueOf (T v) = T (Tuple.ValueOf v)
   valueOf = Cons . Tuple.valueOf . (\(Cons inner) -> inner)

-- | Memory layout of a serial vector is that of the wrapped value.
instance (Memory.C v) => Memory.C (T v) where
   type Struct (T v) = Memory.Struct v
   decompose = Memory.decomposeNewtype Cons
   compose = Memory.composeNewtype (\(Cons inner) -> inner)
   load = Memory.loadNewtype Cons
   store = Memory.storeNewtype (\(Cons inner) -> inner)

-- | Marshalling goes through the wrapped value.
instance (Marshal.C v) => Marshal.C (T v) where
   pack = Marshal.pack . (\(Cons inner) -> inner)
   unpack = Cons . Marshal.unpack

-- | Storable access goes through the wrapped value.
instance (Storable.C v) => Storable.C (T v) where
   store = Storable.storeNewtype Cons (\(Cons inner) -> inner)
   load = Storable.loadNewtype Cons Cons

-- | Apply an effectful vector transformation to the wrapped LLVM vector.
mapV :: (Functor m) =>
   (LLVM.Value (LLVM.Vector n a) -> m (LLVM.Value (LLVM.Vector n b))) ->
   Value n a -> m (Value n b)
mapV f = fmap Cons . f . (\(Cons x) -> x)

-- | Combine two wrapped LLVM vectors with @f@, then post-process the result with @g@.
zipV :: (Functor m) =>
   (c -> d) ->
   (LLVM.Value (LLVM.Vector n a) ->
    LLVM.Value (LLVM.Vector n b) ->
    m c) ->
   Value n a ->
   Value n b ->
   m d
zipV g f (Cons x) =
   fmap g . f x . (\(Cons y) -> y)

-- | Pass the type-level 'Size' of the result, as an 'Int', to the continuation.
withSize :: Sized v => (Int -> m v) -> m v
withSize = go TypeNum.singleton
   where
      go :: (Sized v) => TypeNum.Singleton (Size v) -> (Int -> m v) -> m v
      go n f = f (TypeNum.integralFromSingleton n)

-- | The type-level 'Size' of a value as an 'Int'; the argument is not inspected.
size :: Sized v => v -> Int
size = count TypeNum.singleton
   where
      count :: (Sized v) => TypeNum.Singleton (Size v) -> v -> Int
      count n _ = TypeNum.integralFromSingleton n

-- | The type-level 'Size' of the iterator's target as an 'Int'; the iterator is not inspected.
sizeOfIterator :: Sized v => Iterator mode it v -> Int
sizeOfIterator = count TypeNum.singleton
   where
      count :: Sized v => TypeNum.Singleton (Size v) -> Iterator mode it v -> Int
      count n _ = TypeNum.integralFromSingleton n

{- |
The type parameter @valueTuple@ shall be a virtual LLVM register
or a wrapper around one or more virtual LLVM registers.
-}
class (TypeNum.Positive (Size valueTuple)) => Sized valueTuple where
   -- | Type-level number of serial elements; read back by 'size' and 'withSize'.
   type Size valueTuple :: *

{- |
Basic LLVM types are all counted as scalar values, even LLVM Vectors.
This means that an LLVM Vector can be used for parallel handling of data.
-}
instance Sized (LLVM.Value a) where
   -- A single LLVM value always counts as one serial element.
   type Size (LLVM.Value a) = TypeNum.D1

-- | A stereo frame has the same serial size as its component type.
instance (Sized value) => Sized (Stereo.T value) where
   type Size (Stereo.T value) = Size value

   (Sized value0, Sized value1,
    Size value0 ~ Size value1) =>
      Sized (value0, value1) where
   type Size (value0, value1) = Size value0

   (Sized value0, Sized value1, Sized value2,
    Size value0 ~ Size value1,
    Size value1 ~ Size value2) =>
      Sized (value0, value1, value2) where
   type Size (value0, value1, value2) = Size value0