{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE EmptyDataDecls #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE StandaloneDeriving #-}
{- |
A special vector type that represents a time-sequence of samples.
This way we can distinguish safely between LLVM vectors
used for parallel signals and pipelines and
those used for chunky processing of scalar signals.
For the chunky processing this data type allows us
to derive the factor from the type
that time constants have to be multiplied with.
-}
module Synthesizer.LLVM.Frame.SerialVector (
   T(Cons),
   Plain, Value,
   plain, value, constant,

   Read, Element, ReadIt, extract, readStart, readNext,
   C, WriteIt, insert, writeStart, writeNext, writeStop,
   Zero, writeZero,
   Iterator(Iterator), ReadIterator, WriteIterator, ReadMode, WriteMode,

   Sized, Size, size, sizeOfIterator, withSize,

   insertTraversable, extractTraversable,
   readStartTraversable, readNextTraversable,
   writeStartTraversable, writeNextTraversable, writeStopTraversable,
   writeZeroTraversable,

   extractAll, assemble, modify,
   upsample, subsample,
   cumulate, iterate, iteratePlain, reverse,
   shiftUp, shiftUpMultiZero, shiftDownMultiZero,
   replicate, replicateOf, fromList, fromFixedList,
   mapPlain, mapV, zipV,
   ) where

import qualified Synthesizer.LLVM.Frame.Stereo as Stereo

import qualified LLVM.Extra.Vector as Vector
import qualified LLVM.Extra.Arithmetic as A
import qualified LLVM.Extra.Class as Class
import qualified LLVM.Extra.Memory as Memory
import LLVM.Extra.Class (MakeValueTuple, valueTupleOf, )

import qualified LLVM.Util.Loop as Loop
import qualified LLVM.Core as LLVM
import LLVM.Util.Loop (Phi, )

import qualified Type.Data.Num.Decimal as TypeNum

import qualified Foreign.Storable as St
import Data.Word (Word32, )

import Control.Monad (liftM2, liftM3, foldM, replicateM, (<=<), )
import Control.Applicative (liftA2, )
import qualified Control.Monad.Trans.State as MS
import qualified Control.Applicative as App
import qualified Data.Traversable as Trav

import qualified Data.NonEmpty.Class as NonEmptyC
import qualified Data.NonEmpty as NonEmpty
import qualified Data.List.HT as ListHT
import qualified Data.List as List
import Data.Tuple.HT (mapSnd, fst3, snd3, thd3, )

import Prelude hiding (Read, replicate, reverse, iterate, )


{-
This datatype can be used for both Haskell vector and LLVM.Value Vector.
It should not contain tuples of vectors,
since the interpretation is:
"Everything inside Cons will be virtually concatenated."

We tried to use distinct types (T n a) and (Value n a)
for Haskell and LLVM objects, respectively,
but then GHC-6.12.3 to GHC-7.4.1 could not perform the GeneralizedNewtypeDeriving,
because it was not able to add a (IsPositive n ~ True) constraint
to the instances.

The disadvantage of this approach is
that we cannot have a type that contains both parallel and serial data.
-}
{- |
Wrapper that marks the wrapped value @v@ as serial data,
i.e. a chunk of consecutive samples of one signal.
The instances in the deriving clause are taken over from @v@.
-}
newtype T v = Cons v
   deriving (
      Eq, St.Storable,
      Class.Zero, Class.Undefined,
      A.IntegerConstant, A.RationalConstant, Num)
--      SoV.IntegerConstant, SoV.RationalConstant, SoV.TranscendentalConstant)

-- | Phi node handling is forwarded to the wrapped value.
instance (Phi v) => Phi (T v) where
   phis block (Cons vec) = fmap Cons (Loop.phis block vec)
   addPhis block (Cons a) (Cons b) = Loop.addPhis block a b

-- | Additive operations are forwarded to the wrapped value.
instance (A.Additive v) => A.Additive (T v) where
   zero = Cons A.zero
   add (Cons x) (Cons y) = fmap Cons (A.add x y)
   sub (Cons x) (Cons y) = fmap Cons (A.sub x y)
   neg (Cons x) = fmap Cons (A.neg x)

-- | Multiplication is forwarded to the wrapped value.
instance (A.PseudoRing v) => A.PseudoRing (T v) where
   mul (Cons x) (Cons y) = fmap Cons (A.mul x y)

-- | Ordering related operations are forwarded to the wrapped value.
instance (A.Real v) => A.Real (T v) where
   min x y = lift2 A.min x y
   max x y = lift2 A.max x y
   abs x = lift1 A.abs x
   signum x = lift1 A.signum x

-- | Rounding operations are forwarded to the wrapped value.
instance (A.Fraction v) => A.Fraction (T v) where
   truncate (Cons x) = fmap Cons (A.truncate x)
   fraction (Cons x) = fmap Cons (A.fraction x)

-- | Division is forwarded to the wrapped value.
instance (A.Field v) => A.Field (T v) where
   fdiv x y = lift2 A.fdiv x y

-- | The square root is forwarded to the wrapped value.
instance (A.Algebraic v) => A.Algebraic (T v) where
   sqrt (Cons x) = fmap Cons (A.sqrt x)

-- | Transcendental functions are forwarded to the wrapped value.
instance (A.Transcendental v) => A.Transcendental (T v) where
   pi = fmap Cons A.pi
   sin x = lift1 A.sin x
   log x = lift1 A.log x
   exp x = lift1 A.exp x
   cos x = lift1 A.cos x
   pow x y = lift2 A.pow x y


-- | Apply a unary action to the content of a 'T' and wrap the result again.
lift1 :: Functor f => (a -> f b) -> T a -> f (T b)
lift1 op (Cons a) = fmap Cons (op a)

-- | Apply a binary action to the contents of two 'T's and wrap the result again.
lift2 :: Functor f => (a -> b -> f c) -> T a -> T b -> f (T c)
lift2 op (Cons a) (Cons b) = fmap Cons (op a b)


-- | A serial vector is scaled by the same scalar type as the wrapped value.
type instance A.Scalar (T v) = A.Scalar v
instance (A.PseudoModule v) => A.PseudoModule (T v) where
   scale factor = lift1 (A.scale factor)


-- | Serial vector wrapping a Haskell-side 'LLVM.Vector' of @n@ elements.
type Plain n a = T (LLVM.Vector n a)
-- | Serial vector wrapping an 'LLVM.Value' of vector type with @n@ elements.
type Value n a = T (LLVM.Value (LLVM.Vector n a))


-- | Declare a Haskell-side vector to be serial data.
plain :: LLVM.Vector n a -> Plain n a
plain vec = Cons vec

-- | Declare an LLVM vector value to be serial data.
value :: LLVM.Value (LLVM.Vector n a) -> Value n a
value vec = Cons vec


-- | Serial vector built from a single value by means of 'App.pure'.
replicate :: (TypeNum.Positive n) => a -> Plain n a
replicate = Cons . App.pure

-- | Like 'replicate' but the result is turned into an LLVM constant value.
replicateOf :: (TypeNum.Positive n, LLVM.IsPrimitive a, LLVM.IsConst a) => a -> Value n a
replicateOf = Cons . LLVM.valueOf . App.pure

-- | Build a serial vector from a non-empty list using 'LLVM.cyclicVector'.
fromList :: (TypeNum.Positive n) => NonEmpty.T [] a -> Plain n a
fromList xs = Cons (LLVM.cyclicVector xs)

-- | Build a serial vector from a list whose length is fixed by its type.
fromFixedList ::
   (TypeNum.Positive n) =>
   LLVM.FixedList (TypeNum.ToUnary n) a -> Plain n a
fromFixedList xs = Cons (LLVM.vector xs)

-- | Wrap the result of 'Vector.constant' as serial data.
constant :: (TypeNum.Positive n) => a -> T (Vector.Constant n a)
constant x = Cons (Vector.constant x)


{- |
Iterator state @it@ for traversing a serial value of type @v@.
The parameters @mode@ and @v@ are phantom types:
only the state @it@ is stored.
-}
newtype Iterator mode it v = Iterator {unIterator :: it}
   deriving (Class.Undefined)

-- | Phi node handling is forwarded to the iterator state.
instance Phi it => Phi (Iterator mode it v) where
   phis block = fmap Iterator . Loop.phis block . unIterator
   addPhis block a b = Loop.addPhis block (unIterator a) (unIterator b)


-- | Iterator tagged with 'ReadMode'.
type ReadIterator = Iterator ReadMode
-- | Iterator tagged with 'WriteMode'.
type WriteIterator = Iterator WriteMode

-- | Type-level tag for read iterators; it has no values.
data ReadMode
-- | Type-level tag for write iterators; it has no values.
data WriteMode


-- | An iterator has the same memory layout as its state.
instance (Memory.C it) => Memory.C (Iterator mode it v) where
   type Struct (Iterator mode it v) = Memory.Struct it
   load = Memory.loadNewtype Iterator
   store = Memory.storeNewtype unIterator
   decompose = Memory.decomposeNewtype Iterator
   compose = Memory.composeNewtype unIterator


{- |
Map the iterator state.
The second function is never applied,
it only determines the new phantom vector type.
-}
fmapIt ::
   (ita -> itb) -> (va -> vb) ->
   Iterator mode ita va -> Iterator mode itb vb
fmapIt onState _ = Iterator . onState . unIterator


-- | Pair the states of two iterators of the same mode.
combineIt2 :: Iterator mode xa va -> Iterator mode xb vb -> Iterator mode (xa,xb) (va,vb)
combineIt2 ita itb = Iterator (unIterator ita, unIterator itb)

-- | Combine the states of three iterators of the same mode into a triple.
combineIt3 :: Iterator mode xa va -> Iterator mode xb vb -> Iterator mode xc vc -> Iterator mode (xa,xb,xc) (va,vb,vc)
combineIt3 ita itb itc =
   Iterator (unIterator ita, unIterator itb, unIterator itc)

-- | Turn a container of iterators into one iterator over a container of states.
combineItFunctor ::
   (Functor f) =>
   f (Iterator mode x v) -> Iterator mode (f x) (f v)
combineItFunctor its = Iterator (fmap unIterator its)

-- | Inverse of 'combineItFunctor': split an iterator into a container of iterators.
sequenceItFunctor ::
   (Functor f) =>
   Iterator mode (f it) (f v) ->
   f (Iterator mode it v)
sequenceItFunctor (Iterator states) = fmap Iterator states


{- |
Serial values from which elements can be read,
either by index or one after another through a 'ReadIterator'.
-}
class
   (TypeNum.Positive (Size v), Sized v,
    Phi (ReadIt v), Class.Undefined (ReadIt v),
    Phi v, Class.Undefined v) =>
      Read v where

   -- | Type of a single element of @v@.
   type Element v :: *
   -- | State of a read iterator over @v@.
   type ReadIt v :: *

   -- | Fetch the element at the given index.
   extract :: LLVM.Value Word32 -> v -> LLVM.CodeGenFunction r (Element v)

   -- | Fetch the elements at the indices @0@ to @size - 1@ in ascending order.
   extractAll :: v -> LLVM.CodeGenFunction r [Element v]
   extractAll x =
      mapM (\k -> extract (LLVM.valueOf k) x) (take (size x) [0..])

   -- | Begin to read the value sequentially.
   readStart :: v -> LLVM.CodeGenFunction r (ReadIterator (ReadIt v) v)
   -- | Yield one element together with the advanced iterator.
   readNext ::
      ReadIterator (ReadIt v) v ->
      LLVM.CodeGenFunction r (Element v, ReadIterator (ReadIt v) v)

{- |
Serial values that can additionally be written,
either by index or one element after another through a 'WriteIterator'.
-}
class (Read v, Phi (WriteIt v), Class.Undefined (WriteIt v)) => C v where
   -- | State of a write iterator over @v@.
   type WriteIt v :: *

   -- | Replace the element at the given index.
   insert :: LLVM.Value Word32 -> Element v -> v -> LLVM.CodeGenFunction r v

   {- |
   Insert the list elements at the indices 0, 1, 2, ...
   into an initially undefined value.
   -}
   assemble :: [Element v] -> LLVM.CodeGenFunction r v
   assemble xs =
      foldM
         (\acc (k,x) -> insert (LLVM.valueOf k) x acc)
         Class.undefTuple
         (zip [0..] xs)

   -- | Begin to write a value sequentially.
   writeStart :: LLVM.CodeGenFunction r (WriteIterator (WriteIt v) v)
   -- | Feed one element to the iterator.
   writeNext ::
      Element v -> WriteIterator (WriteIt v) v ->
      LLVM.CodeGenFunction r (WriteIterator (WriteIt v) v)
   -- | Finish writing and obtain the written value.
   writeStop :: WriteIterator (WriteIt v) v -> LLVM.CodeGenFunction r v

-- | Serial values whose write iterator can be initialized with zeros.
class (C v, Phi (WriteIt v), Class.Zero (WriteIt v)) => Zero v where
   {- |
   Initializes the target with zeros.
   You may only call 'writeStop' on the result of 'writeZero'.
   -}
   writeZero :: LLVM.CodeGenFunction r (WriteIterator (WriteIt v) v)

-- | The serial size of a wrapped vector is the size of that vector.
instance (Vector.Simple v) => Sized (T v) where
   type Size (T v) = Vector.Size v

{- |
This instance also allows wrapping tuples of vectors,
but you cannot reasonably use them,
because it would mean serializing vectors with different element types.
-}
instance (Vector.Simple v) => Read (T v) where

   type Element (T v) = Vector.Element v
   type ReadIt (T v) = v

   extract k (Cons vec) = Vector.extract k vec

   -- The iterator state is the vector itself.
   readStart (Cons vec) = return (Iterator vec)
   -- Take the element at index 0, then rotate the vector down.
   readNext (Iterator current) = do
      x <- Vector.extract (LLVM.valueOf 0) current
      rotated <- Vector.rotateDown current
      return (x, Iterator rotated)


-- | Writing to a wrapped vector; the iterator state is the vector being built.
instance (Vector.C v) => C (T v) where
   type WriteIt (T v) = v

   insert k a = lift1 (Vector.insert k a)

   -- Start from an undefined vector.
   writeStart = return $ Iterator Class.undefTuple
   -- Shift the new element in; the first component of the result is dropped.
   writeNext x (Iterator current) =
      fmap (Iterator . snd) (Vector.shiftDown x current)
   writeStop = return . Cons . unIterator

-- | The zero initialized write iterator is backed by 'Class.zeroTuple'.
instance (Vector.C v, Class.Zero v) => Zero (T v) where
   writeZero = return $ Iterator Class.zeroTuple


-- | Pairs of serial values of equal size are read component by component.
instance
   (Read va, Read vb, Size va ~ Size vb) =>
      Read (va, vb) where

   type Element (va, vb) = (Element va, Element vb)
   type ReadIt (va, vb) = (ReadIt va, ReadIt vb)

   extract k (va,vb) = do
      a <- extract k va
      b <- extract k vb
      return (a,b)

   readStart (va,vb) = do
      ita <- readStart va
      itb <- readStart vb
      return (combineIt2 ita itb)

   readNext it = do
      (a, nextA) <- readNext (fmapIt fst fst it)
      (b, nextB) <- readNext (fmapIt snd snd it)
      return ((a,b), combineIt2 nextA nextB)

-- | Pairs of serial values of equal size are written component by component.
instance
   (C va, C vb, Size va ~ Size vb) =>
      C (va, vb) where

   type WriteIt (va, vb) = (WriteIt va, WriteIt vb)

   insert k (a,b) (va,vb) = do
      wa <- insert k a va
      wb <- insert k b vb
      return (wa,wb)

   writeStart = do
      ita <- writeStart
      itb <- writeStart
      return (combineIt2 ita itb)

   writeNext (a,b) it = do
      ita <- writeNext a (fmapIt fst fst it)
      itb <- writeNext b (fmapIt snd snd it)
      return (combineIt2 ita itb)

   writeStop it = do
      va <- writeStop (fmapIt fst fst it)
      vb <- writeStop (fmapIt snd snd it)
      return (va,vb)

-- | The zero write iterator of a pair consists of the zero iterators of the components.
instance
   (Zero va, Zero vb, Size va ~ Size vb) =>
      Zero (va, vb) where

   writeZero = do
      ita <- writeZero
      itb <- writeZero
      return (combineIt2 ita itb)


-- | Triples of serial values of equal size are read component by component.
instance
   (Read va, Read vb, Read vc,
    Size va ~ Size vb,
    Size vb ~ Size vc) =>
      Read (va, vb, vc) where

   type Element (va, vb, vc) = (Element va, Element vb, Element vc)
   type ReadIt (va, vb, vc) = (ReadIt va, ReadIt vb, ReadIt vc)

   extract k (va,vb,vc) = do
      a <- extract k va
      b <- extract k vb
      c <- extract k vc
      return (a,b,c)

   readStart (va,vb,vc) = do
      ita <- readStart va
      itb <- readStart vb
      itc <- readStart vc
      return (combineIt3 ita itb itc)

   readNext it = do
      (a, nextA) <- readNext (fmapIt fst3 fst3 it)
      (b, nextB) <- readNext (fmapIt snd3 snd3 it)
      (c, nextC) <- readNext (fmapIt thd3 thd3 it)
      return ((a,b,c), combineIt3 nextA nextB nextC)



-- | Triples of serial values of equal size are written component by component.
instance
   (C va, C vb, C vc,
    Size va ~ Size vb,
    Size vb ~ Size vc) =>
      C (va, vb, vc) where

   type WriteIt (va, vb, vc) = (WriteIt va, WriteIt vb, WriteIt vc)

   insert k (a,b,c) (va,vb,vc) = do
      wa <- insert k a va
      wb <- insert k b vb
      wc <- insert k c vc
      return (wa,wb,wc)

   writeStart = do
      ita <- writeStart
      itb <- writeStart
      itc <- writeStart
      return (combineIt3 ita itb itc)

   writeNext (a,b,c) it = do
      ita <- writeNext a (fmapIt fst3 fst3 it)
      itb <- writeNext b (fmapIt snd3 snd3 it)
      itc <- writeNext c (fmapIt thd3 thd3 it)
      return (combineIt3 ita itb itc)

   writeStop it = do
      va <- writeStop (fmapIt fst3 fst3 it)
      vb <- writeStop (fmapIt snd3 snd3 it)
      vc <- writeStop (fmapIt thd3 thd3 it)
      return (va,vb,vc)

-- | The zero write iterator of a triple consists of the zero iterators of the components.
instance
   (Zero va, Zero vb, Zero vc,
    Size va ~ Size vb,
    Size vb ~ Size vc) =>
      Zero (va, vb, vc) where

   writeZero = do
      ita <- writeZero
      itb <- writeZero
      itc <- writeZero
      return (combineIt3 ita itb itc)


-- | Stereo frames of serial values are read via the generic traversable helpers.
instance (Read v) => Read (Stereo.T v) where

   type Element (Stereo.T v) = Stereo.T (Element v)
   type ReadIt (Stereo.T v) = Stereo.T (ReadIt v)

   extract k frame = extractTraversable k frame

   readStart frame = readStartTraversable frame
   readNext it = readNextTraversable it

-- | Stereo frames of serial values are written via the generic traversable helpers.
instance (C v) => C (Stereo.T v) where

   type WriteIt (Stereo.T v) = Stereo.T (WriteIt v)

   insert k x frame = insertTraversable k x frame

   writeStart = writeStartTraversable
   writeNext x it = writeNextTraversable x it
   writeStop it = writeStopTraversable it

-- | The zero write iterator for stereo frames is 'writeZeroTraversable'.
instance (Zero v) => Zero (Stereo.T v) where

   writeZero = writeZeroTraversable


{- |
Extract the element at index @k@, transform it with @f@
and insert the result at the same index.
-}
modify ::
   (C v) =>
   LLVM.Value Word32 ->
   (Element v -> LLVM.CodeGenFunction r (Element v)) ->
   v -> LLVM.CodeGenFunction r v
modify k f v = do
   old <- extract k v
   new <- f old
   insert k new v


-- | Return the element at index zero.
subsample ::
   (Read v) =>
   v -> LLVM.CodeGenFunction r (Element v)
subsample = extract (A.zero :: LLVM.Value Word32)

{- |
Fill all positions of a serial value with the same element.
-}
-- this will be translated to an efficient pshufd
upsample ::
   (C v) =>
   Element v -> LLVM.CodeGenFunction r v
upsample x =
   withSize (assemble . flip List.replicate x)


{- |
Serial vector variant of 'Vector.cumulate':
the vector is unwrapped, passed on together with the scalar,
and the vector part of the result is wrapped again.
-}
cumulate ::
   (Vector.Arithmetic a, TypeNum.Positive n) =>
   LLVM.Value a -> Value n a ->
   LLVM.CodeGenFunction r (LLVM.Value a, Value n a)
cumulate start (Cons vec) =
   mapSnd Cons `fmap` Vector.cumulate start vec


-- | Apply a function to every element of a Haskell-side serial vector.
mapPlain ::
   (TypeNum.Positive n) => (a -> b) -> Plain n a -> Plain n b
mapPlain f (Cons vec) = Cons (fmap f vec)

-- | Build a Haskell-side serial vector from @x@, @f x@, @f (f x)@, ... via 'fromList'.
iteratePlain ::
   (TypeNum.Positive n) => (a -> a) -> a -> Plain n a
iteratePlain f = fromList . NonEmptyC.iterate f

{- |
Assemble a serial value from @x@, @f x@, @f (f x)@, ...
The action @f@ is run once per element,
so the result of its last run is not stored in the vector.
-}
iterate ::
   (C v) =>
   (Element v -> LLVM.CodeGenFunction r (Element v)) ->
   Element v -> LLVM.CodeGenFunction r v
iterate f x =
   withSize $ \n ->
      let step = MS.StateT $ \current -> do
             next <- f current
             return (current, next)
      in  MS.evalStateT (replicateM n step) x >>= assemble

-- | Reverse the order of the elements.
reverse ::
   (C v) =>
   v -> LLVM.CodeGenFunction r v
reverse v = do
   xs <- extractAll v
   assemble (List.reverse xs)

{- |
Put @x@ at index 0 and move every element to the next higher index.
The former last element is returned together with the new value.
-}
shiftUp ::
   (C v) =>
   Element v -> v -> LLVM.CodeGenFunction r (Element v, v)
shiftUp x v = do
   xs <- extractAll v
   ListHT.switchR
      -- fallback for an empty element list: nothing is shifted
      (return (x,v))
      (\front final -> fmap ((,) final) $ assemble (x:front))
      xs


{- |
Prepend @n@ zeros to the elements
and keep as many leading elements as fit into the value.
-}
shiftUpMultiZero ::
   (C v, A.Additive (Element v)) =>
   Int -> v -> LLVM.CodeGenFunction r v
shiftUpMultiZero n v = do
   xs <- extractAll v
   assemble $ take (size v) $ List.replicate n A.zero ++ xs

{- |
Drop the first @n@ elements
and fill the freed positions at the end with zeros.
-}
shiftDownMultiZero ::
   (C v, A.Additive (Element v)) =>
   Int -> v -> LLVM.CodeGenFunction r v
shiftDownMultiZero n v = do
   xs <- extractAll v
   assemble $ take (size v) $ List.drop n xs ++ List.repeat A.zero


-- | 'insert' for every component of an applicative container.
insertTraversable ::
   (C v, Trav.Traversable f, App.Applicative f) =>
   LLVM.Value Word32 -> f (Element v) -> f v -> LLVM.CodeGenFunction r (f v)
insertTraversable k xs =
   Trav.sequence . liftA2 (insert k) xs

-- | 'extract' for every component of a traversable container.
extractTraversable ::
   (Read v, Trav.Traversable f) =>
   LLVM.Value Word32 -> f v -> LLVM.CodeGenFunction r (f (Element v))
extractTraversable k =
   Trav.mapM (extract k)


-- | 'readStart' for every component of a traversable container.
readStartTraversable ::
   (Trav.Traversable f, App.Applicative f, Read v) =>
   f v -> LLVM.CodeGenFunction r (ReadIterator (f (ReadIt v)) (f v))
readStartTraversable v = do
   its <- Trav.mapM readStart v
   return (combineItFunctor its)

-- | 'readNext' for every component of a traversable container.
readNextTraversable ::
   (Trav.Traversable f, App.Applicative f, Read v) =>
   ReadIterator (f (ReadIt v)) (f v) ->
   LLVM.CodeGenFunction r (f (Element v), ReadIterator (f (ReadIt v)) (f v))
readNextTraversable it = do
   steps <- Trav.mapM readNext (sequenceItFunctor it)
   let elems = fmap fst steps
       nextIts = fmap snd steps
   return (elems, combineItFunctor nextIts)


-- | 'writeStart' for every component of an applicative container.
writeStartTraversable ::
   (Trav.Traversable f, App.Applicative f, C v) =>
   LLVM.CodeGenFunction r (WriteIterator (f (WriteIt v)) (f v))
writeStartTraversable = do
   its <- Trav.sequence (App.pure writeStart)
   return (combineItFunctor its)

-- | 'writeNext' for every component of an applicative container.
writeNextTraversable ::
   (Trav.Traversable f, App.Applicative f, C v) =>
   f (Element v) -> WriteIterator (f (WriteIt v)) (f v) ->
   LLVM.CodeGenFunction r (WriteIterator (f (WriteIt v)) (f v))
writeNextTraversable x it = do
   its <- Trav.sequence (liftA2 writeNext x (sequenceItFunctor it))
   return (combineItFunctor its)

-- | 'writeStop' for every component of a traversable container.
writeStopTraversable ::
   (Trav.Traversable f, App.Applicative f, C v) =>
   WriteIterator (f (WriteIt v)) (f v) -> LLVM.CodeGenFunction r (f v)
writeStopTraversable it =
   Trav.mapM writeStop (sequenceItFunctor it)

-- | 'writeZero' for every component of an applicative container.
writeZeroTraversable ::
   (Trav.Traversable f, App.Applicative f, Zero v) =>
   LLVM.CodeGenFunction r (WriteIterator (f (WriteIt v)) (f v))
writeZeroTraversable = do
   its <- Trav.sequence (App.pure writeZero)
   return (combineItFunctor its)


-- | The value tuple of a serial vector is the wrapped value tuple of its content.
instance (MakeValueTuple v) => MakeValueTuple (T v) where
   type ValueTuple (T v) = T (Class.ValueTuple v)
   valueTupleOf (Cons content) = Cons $ valueTupleOf content

-- | A serial vector has the same memory layout as the wrapped value.
instance (Memory.C v) => Memory.C (T v) where
   type Struct (T v) = Memory.Struct v
   load = Memory.loadNewtype Cons
   store = Memory.storeNewtype (\(Cons content) -> content)
   decompose = Memory.decomposeNewtype Cons
   compose = Memory.composeNewtype (\(Cons content) -> content)


-- | Apply an operation on plain LLVM vector values to a serial vector.
mapV :: (Functor m) =>
   (LLVM.Value (LLVM.Vector n a) -> m (LLVM.Value (LLVM.Vector n b))) ->
   Value n a -> m (Value n b)
mapV = lift1

{- |
Apply a binary operation on plain LLVM vector values to two serial vectors
and convert its result with the first argument.
-}
zipV :: (Functor m) =>
   (c -> d) ->
   (LLVM.Value (LLVM.Vector n a) ->
    LLVM.Value (LLVM.Vector n b) ->
    m c) ->
   Value n a ->
   Value n b ->
   m d
zipV convert op (Cons x) (Cons y) =
   fmap convert $ op x y



{- |
Pass the serial size of the result type @v@ as an 'Int' to the given function.
-}
withSize :: Sized v => (Int -> m v) -> m v
withSize = sz TypeNum.singleton
   where
      -- The singleton argument ties the number to the type @Size v@.
      sz :: (Sized v) => TypeNum.Singleton (Size v) -> (Int -> m v) -> m v
      sz n f = f (TypeNum.integralFromSingleton n)

{- |
Serial size of a value as an 'Int'.
The argument itself is never inspected, only its type is used.
-}
size :: Sized v => v -> Int
size = sz TypeNum.singleton
   where
      sz :: (Sized v) => TypeNum.Singleton (Size v) -> v -> Int
      sz n _ = TypeNum.integralFromSingleton n

{- |
Serial size of the value type @v@ that an iterator belongs to.
The iterator itself is never inspected.
-}
sizeOfIterator :: Sized v => Iterator mode it v -> Int
sizeOfIterator = sz TypeNum.singleton
   where
      sz :: Sized v => TypeNum.Singleton (Size v) -> Iterator mode it v -> Int
      sz n _ = TypeNum.integralFromSingleton n


{- |
The type parameter @valueTuple@ shall be a virtual LLVM register
or a wrapper around one or more virtual LLVM registers.
'Size' is the serial size as a positive type-level number.
-}
class (TypeNum.Positive (Size valueTuple)) => Sized valueTuple where
   type Size valueTuple :: *
   -- Only the result type carries information;
   -- evaluating the result of the default implementation raises an error.
   serialSize :: valueTuple -> Size valueTuple
   serialSize _ = error "serial size is a type number and has no value"

{- |
Basic LLVM types are all counted as scalar values, even LLVM Vectors.
This means that an LLVM Vector can be used for parallel handling of data.
Consequently the serial size of every 'LLVM.Value' is one.
-}
instance Sized (LLVM.Value a) where
   type Size (LLVM.Value a) = TypeNum.D1

-- | A stereo frame has the serial size of its channel type.
instance (Sized value) => Sized (Stereo.T value) where
   type Size (Stereo.T value) = Size value

-- | Both components must have the same serial size, which is also the size of the pair.
instance
   (Sized value0, Sized value1,
    Size value0 ~ Size value1) =>
      Sized (value0, value1) where
   type Size (value0, value1) = Size value0

-- | All components must have the same serial size, which is also the size of the triple.
instance
   (Sized value0, Sized value1, Sized value2,
    Size value0 ~ Size value1,
    Size value1 ~ Size value2) =>
      Sized (value0, value1, value2) where
   type Size (value0, value1, value2) = Size value0