{-# OPTIONS_GHC -O -fglasgow-exts -fno-implicit-prelude #-}
{- glasgow-exts are for higher rank types, i.e. the existentially quantified state in the constructor of 'T' -}
module Synthesizer.State.Signal where

import qualified Synthesizer.Generic.Signal as SigG
import qualified Synthesizer.Generic.SampledValue as Sample

-- import qualified Synthesizer.Plain.Signal   as Sig
import qualified Synthesizer.Plain.Modifier as Modifier
import qualified Data.List as List

import qualified Algebra.Module   as Module
import qualified Algebra.Additive as Additive
import Algebra.Additive (zero)

import Algebra.Module ((*>))

import qualified Synthesizer.Format as Format

import Control.Monad.State
          (State, runState, StateT(StateT), runStateT, liftM2, )
import Control.Monad (Monad, mplus, msum,
           (>>), (>>=), fail, return, (=<<),
           Functor, fmap, )

import qualified Synthesizer.Storable.Signal as SigSt
import Foreign.Storable (Storable)

import Synthesizer.Utility
   (viewListL, mapFst, mapSnd, mapPair, fst3, snd3, thd3, nest, )

import NumericPrelude.Condition (toMaybe)
import NumericPrelude (fromInteger, )

import Text.Show (Show(showsPrec), showParen, showString, )
import Data.Maybe (Maybe(Just, Nothing), maybe, fromMaybe, )
import Prelude
   ((.), ($), ($!), id, const, flip, curry, uncurry, fst, snd, error,
    (>), (>=), max, Ord,
    succ, pred, Bool(True,False), not, Int,
--    fromInteger,
    )


{- |
A signal given by a state transition function and an initial state.
Cf. StreamFusion  Data.Stream

The state type @s@ is existentially quantified,
thus it is hidden from the users of a signal.
Both fields of the constructor are strict.
-}
data T a =
   forall s. -- Seq s =>
      Cons !(StateT s Maybe a)  -- compute next value, 'Nothing' terminates the signal
           !s                   -- initial state


-- | A signal is shown as an application of @StateSignal.fromList@ to the list of its elements.
instance (Show y) => Show (T y) where
   showsPrec prec sig =
      showParen (prec >= 10) $
         showString "StateSignal.fromList " . showsPrec 11 (toList sig)

-- | Formatting is the same as showing with 'showsPrec'.
instance Format.C T where
   format prec sig = showsPrec prec sig

-- | Element-wise mapping by means of 'map'.
instance Functor T where
   fmap f = map f


{- |
State signals as instance of the generic signal class.
Every method is implemented by the top-level function
of this module with the same name,
except 'SigG.unfoldR' which is implemented by 'generate'.
-}
instance SigG.C T where
   -- construction
   empty = empty
   cons = cons
   fromList = fromList
   repeat = repeat
   cycle = cycle
   replicate = replicate
   iterate = iterate
   iterateAssoc = iterateAssoc
   unfoldR = generate
   -- inspection and reduction
   null = null
   toList = toList
   viewL = viewL
   viewR = viewR
   foldL = foldL
   length = length
   -- element-wise and stateful processing
   map = map
   mix = mix
   zipWith = zipWith
   scanL = scanL
   crochetL = crochetL
{-
   mapAccumL = mapAccumL
   mapAccumR = mapAccumR
-}
   -- cutting and joining
   take = take
   drop = drop
   splitAt = splitAt
   dropMarginRem = dropMarginRem
   takeWhile = takeWhile
   dropWhile = dropWhile
   span = span
   append = append
   concat = concat
   reverse = reverse




{-# INLINE generate #-}
{- |
Construct a signal from a transition function and an initial state.
The transition function maps a state to the next sample and the next state,
or to 'Nothing' if the signal ends.
-}
generate :: (acc -> Maybe (y, acc)) -> acc -> T y
generate step = Cons transition
   where transition = StateT step

{-# INLINE unfoldR #-}
{- |
The same construction as 'generate':
The transition function is wrapped in 'StateT'
and stored together with the initial state.
-}
unfoldR :: (acc -> Maybe (y, acc)) -> acc -> T y
unfoldR = Cons . StateT

{-# INLINE generateInfinite #-}
{- |
Like 'generate' but the transition function always delivers a sample,
thus the signal never ends.
-}
generateInfinite :: (acc -> (y, acc)) -> acc -> T y
generateInfinite f = generate (\acc -> Just (f acc))

{-# INLINE fromList #-}
{- |
Turn a list into a signal.
The remaining list is the state and 'viewListL' is the transition function.
-}
fromList :: [y] -> T y
fromList = unfoldR viewListL

{-# INLINE toList #-}
{- |
Run the transition function from the initial state
and collect all samples in a list.
-}
toList :: T y -> [y]
toList (Cons next start) =
   List.unfoldr step start
   where step = runStateT next


{-# INLINE fromGenericSignal #-}
{- |
Convert from any generic signal type.
The generic signal is the state and 'SigG.viewL' is the transition function.
-}
fromGenericSignal ::
   (Sample.C a, SigG.C sig) =>
   sig a -> T a
fromGenericSignal =
   unfoldR SigG.viewL

{-# INLINE toGenericSignal #-}
{- |
Convert to any generic signal type
by unfolding the transition function with 'SigG.unfoldR'.
-}
toGenericSignal ::
   (Sample.C a, SigG.C sig) =>
   T a -> sig a
toGenericSignal (Cons next start) =
   SigG.unfoldR step start
   where step = runStateT next


{-# INLINE fromStorableSignal #-}
{- |
Convert from a storable signal.
The storable signal is the state and 'SigSt.viewL' is the transition function.
-}
fromStorableSignal ::
   (Storable a) =>
   SigSt.T a -> T a
fromStorableSignal =
   unfoldR SigSt.viewL

{-# INLINE toStorableSignal #-}
{- |
Convert to a storable signal
by unfolding the transition function with 'SigSt.unfoldr'
using the given chunk size.
-}
toStorableSignal ::
   (Storable a) =>
   SigSt.ChunkSize -> T a -> SigSt.T a
toStorableSignal size (Cons next start) =
   SigSt.unfoldr size step start
   where step = runStateT next



{-# INLINE iterate #-}
{- |
Infinite signal of the start value and the repeated applications of @f@ to it.
The state is the value that is emitted next.
-}
iterate :: (a -> a) -> a -> T a
iterate f = generateInfinite step
   where step x = (x, f x)

{-# INLINE iterateAssoc #-}
{- |
Iterate the operation with the fixed left operand @x@, starting at @x@.
-}
iterateAssoc :: (a -> a -> a) -> a -> T a
iterateAssoc op x = iterate step x -- should be optimized
   where step = op x

{-# INLINE repeat #-}
{- |
Infinite signal that consists of the same value throughout.
-}
repeat :: a -> T a
repeat = generateInfinite (\x -> (x, x))




{-# INLINE crochetL #-}
{- |
Process a signal from left to right with an own state.
In every step first the input signal is advanced,
then @g@ is applied to the input sample and the current state.
The output ends if the input ends or if @g@ returns 'Nothing'.
The state of the result is the pair of input state and processor state.
-}
crochetL :: (x -> acc -> Maybe (y, acc)) -> acc -> T x -> T y
crochetL g b (Cons f a) =
   Cons (StateT step) (a,b)
   where
      step (inState,procState) =
         runStateT f inState >>= \(x,inState1) ->
         g x procState >>= \(y,procState1) ->
         Just (y, (inState1,procState1))


{-# INLINE scanL #-}
{- |
Emit the start value followed by all intermediate results of a left fold.
The start value is prepended using 'cons'.
-}
scanL :: (acc -> x -> acc) -> acc -> T x -> T acc
scanL f start =
   cons start . crochetL step start
   where
      step x acc =
         let next = f acc x
         in  Just (next, next)


{-# INLINE scanLClip #-}
{- |
Like 'scanL' but the output starts with the start value
and omits the final accumulator.
Input and output have equal length, that's better for fusion.
-}
scanLClip :: (acc -> x -> acc) -> acc -> T x -> T acc
scanLClip f start =
   crochetL step start
   where step x acc = Just (acc, f acc x)

{-# INLINE map #-}
{- |
Apply a function to every sample.
Implemented as 'crochetL' with a unit state.
-}
map :: (a -> b) -> (T a -> T b)
map f = crochetL step ()
   where step x _ = Just (f x, ())


{-# INLINE unzip #-}
{- |
Split a signal of pairs into a pair of signals.
Each result is a separate 'map' over the input signal.
-}
unzip :: T (a,b) -> (T a, T b)
unzip pairs =
   let firsts  = map fst pairs
       seconds = map snd pairs
   in  (firsts, seconds)

{-# INLINE unzip3 #-}
{- |
Split a signal of triples into a triple of signals.
Each result is a separate 'map' over the input signal.
-}
unzip3 :: T (a,b,c) -> (T a, T b, T c)
unzip3 triples =
   let firsts  = map fst3 triples
       seconds = map snd3 triples
       thirds  = map thd3 triples
   in  (firsts, seconds, thirds)


{-# INLINE delay1 #-}
{- |
Delay by one sample in a fusion friendly way.
The given value is emitted first,
then every input sample is emitted one step late.
Since this is a 'crochetL',
the output is as long as the input,
that is, the last input sample is dropped - at least for finite input.
-}
delay1 :: a -> T a -> T a
delay1 = crochetL (\x held -> Just (held, x))

{-# INLINE delay #-}
{- |
Delay by @n@ samples by prepending @n@ copies of @z@.
-}
delay :: y -> Int -> T y -> T y
delay z n = append padding
   where padding = replicate n z

{-# INLINE take #-}
{- |
Keep at most the given number of leading samples.
The state of the 'crochetL' is the number of samples that are still allowed.
Since 'crochetL' advances the input before it consults the counter,
the input is asked for one more sample than is delivered.
-}
take :: Int -> T a -> T a
take = crochetL step
   where
      step x remaining = toMaybe (remaining > zero) (x, pred remaining)

{-# INLINE takeWhile #-}
{- |
Keep the leading samples as long as they satisfy the predicate.
-}
takeWhile :: (a -> Bool) -> T a -> T a
takeWhile p = crochetL step ()
   where step x _ = toMaybe (p x) (x, ())

{-# INLINE replicate #-}
{- |
Signal of @n@ copies of a value, obtained by clipping 'repeat' with 'take'.
-}
replicate :: Int -> a -> T a
replicate n = limit . repeat
   where limit = take n


{- * functions consuming multiple lists -}

{-# INLINE zipWith #-}
{- |
Combine two signals sample by sample.
The first signal is unpacked and advanced within a 'crochetL'
that runs over the second signal.
The result ends as soon as one of the signals ends.
-}
zipWith :: (a -> b -> c) -> (T a -> T b -> T c)
zipWith h (Cons next start) =
   crochetL step start
   where
      step y state =
         runStateT next state >>= \(x,state1) ->
         Just (h x y, state1)

{-# INLINE zipWith3 #-}
{- |
Combine three signals sample by sample.
The first two signals are paired with 'zip'
and then combined with the third one using 'zipWith'.
-}
zipWith3 :: (a -> b -> c -> d) -> (T a -> T b -> T c -> T d)
zipWith3 f s0 s1 =
   let pairs = zip s0 s1
   in  zipWith (uncurry f) pairs

{-# INLINE zipWith4 #-}
{- |
Combine four signals sample by sample.
The first two signals are paired with 'zip'
and then combined with the remaining ones using 'zipWith3'.
-}
zipWith4 :: (a -> b -> c -> d -> e) -> (T a -> T b -> T c -> T d -> T e)
zipWith4 f s0 s1 =
   let pairs = zip s0 s1
   in  zipWith3 (uncurry f) pairs


{-# INLINE zip #-}
{- |
Pair the samples of two signals.
-}
zip :: T a -> T b -> T (a,b)
zip = zipWith (\x y -> (x, y))

{-# INLINE zip3 #-}
{- |
Combine the samples of three signals to triples.
-}
zip3 :: T a -> T b -> T c -> T (a,b,c)
zip3 = zipWith3 (\x y z -> (x, y, z))

{-# INLINE zip4 #-}
{- |
Combine the samples of four signals to quadruples.
-}
zip4 :: T a -> T b -> T c -> T d -> T (a,b,c,d)
zip4 = zipWith4 (\w x y z -> (w, x, y, z))


{- * functions based on 'foldL' -}

{-# INLINE foldL' #-}
{- |
Left fold where the function takes the sample first and the accumulator second.
The accumulator is forced before the function is applied to it.
-}
foldL' :: (x -> acc -> acc) -> acc -> T x -> acc
foldL' g b =
   switchL b step
   where
      step x rest = foldL' g (g x $! b) rest

{-# INLINE foldL #-}
{- |
Left fold with the usual argument order of the function.
It is implemented by "foldL'" with swapped function arguments.
-}
foldL :: (acc -> x -> acc) -> acc -> T x -> acc
foldL f = foldL' (\x acc -> f acc x)

{-# INLINE length #-}
{- |
Number of samples, counted by a left fold.
-}
length :: T a -> Int
length = foldL' (\_ count -> succ count) zero


{- * functions based on 'foldR' -}

{- |
Right fold.
The fold of the remaining signal is passed unevaluated to @g@.
-}
foldR :: (x -> acc -> acc) -> acc -> T x -> acc
foldR g b =
   switchL b step
   where
      step x rest = g x (foldR g b rest)


{- * Other functions -}

{-# INLINE null #-}
{- |
Check whether the signal contains no sample.
-}
null :: T a -> Bool
null =
   switchL True (\_ _ -> False)
   -- foldR (const (const False)) True

{-# INLINE empty #-}
{- |
Signal without samples.
The transition function returns 'Nothing' for the unit state.
-}
empty :: T a
empty = Cons (StateT (\_ -> Nothing)) ()

{-# INLINE singleton #-}
{- |
Signal with exactly one sample.
The state holds the sample until it is emitted and is 'Nothing' afterwards.
-}
singleton :: a -> T a
singleton =
   generate emit . Just
   where
      emit pending = maybe Nothing (\x -> Just (x, Nothing)) pending

{-# INLINE cons #-}
{- |
Prepend a sample to a signal.
The state consists of the sample that is still to be emitted, if any,
and the remaining signal.

This is expensive and should not be used to construct lists iteratively!
-}
cons :: a -> T a -> T a
cons x xs =
   generate step (Just x, xs)
   where
      step (pending, rest) =
         fmap (mapSnd ((,) Nothing)) $
         case pending of
            Nothing -> viewL rest
            Just x0 -> Just (x0, rest)

{-# INLINE viewL #-}
{- |
Split off the first sample.
The remaining signal has the same transition function and the successor state.
The result is 'Nothing' for an empty signal.
-}
viewL :: T a -> Maybe (a, T a)
viewL (Cons next state) =
   fmap
      (mapSnd (\state1 -> Cons next state1))
      (runStateT next state)

{- iterated 'cons' is very inefficient
viewR :: T a -> Maybe (T a, a)
viewR =
   foldR (\x mxs -> Just (maybe (empty,x) (mapFst (cons x)) mxs)) Nothing
-}

{-# INLINE viewR #-}
{- |
Split off the last sample
using 'viewRSize' with the default chunk size.
-}
viewR :: Storable a => T a -> Maybe (T a, a)
viewR = viewRSize chunkSize
   where chunkSize = SigSt.defaultChunkSize

{-# INLINE viewRSize #-}
{- |
Split off the last sample.
The signal is converted to a storable signal with the given chunk size,
split there with 'SigSt.viewR'
and the prefix is converted back.
-}
viewRSize :: Storable a => SigSt.ChunkSize -> T a -> Maybe (T a, a)
viewRSize size =
   restore . SigSt.viewR . store
   where
      store   = toStorableSignal size
      restore = fmap (mapFst fromStorableSignal)


{-# INLINE switchL #-}
{- |
Case analysis on the left end of a signal:
@n@ is the result for an empty signal,
otherwise @j@ is applied to the first sample and the remaining signal.
-}
switchL :: b -> (a -> T a -> b) -> T a -> b
switchL n j =
   \sig ->
      case viewL sig of
         Nothing -> n
         Just headTail -> uncurry j headTail

{-# INLINE switchR #-}
{- |
Case analysis on the right end of a signal:
@n@ is the result for an empty signal,
otherwise @j@ is applied to the signal without its last sample
and to the last sample.
-}
switchR :: Storable a => b -> (T a -> a -> b) -> T a -> b
switchR n j =
   \sig ->
      case viewR sig of
         Nothing -> n
         Just initLast -> uncurry j initLast


{- |
Extend a signal infinitely by repeating its last sample.
An empty signal remains empty.
The state consists of the most recently emitted sample
and the remaining input signal.

After the end of the input
the input generator is asked in every step again whether it is finished.
-}
{-# INLINE extendConstant #-}
extendConstant :: T a -> T a
extendConstant xt0 =
   switchL empty (\x0 _ -> generate step (x0,xt0)) xt0
   where
      step state@(lastSample,remaining) =
         Just $
         switchL
            (lastSample,state)
            (\x xs -> (x, (x,xs)))
            remaining


{-
{-# INLINE tail #-}
tail :: T a -> T a
tail = Cons . List.tail . decons

{-# INLINE head #-}
head :: T a -> a
head = List.head . decons
-}

{-# INLINE drop #-}
{- |
Remove @n@ leading samples by splitting off the first sample @n@ times.
The result is 'empty' if one of the splits fails.
-}
drop :: Int -> T a -> T a
drop n =
   fromMaybe empty . nest n dropOne . Just
   where
      dropOne msig = msig >>= \sig -> fmap snd (viewL sig)

{-# INLINE dropMarginRem #-}
{- |
@dropMarginRem n m xs@ pairs every non-empty suffix of @xs@
with a counter that starts at @m@ and decreases by one per suffix,
removes leading pairs with @dropMargin n m@
and returns the first pair that is left.

This implementation expects that looking ahead is cheap.

NOTE(review): The error is raised whenever @dropMargin@ leaves no pair at all.
This also seems to happen for @n = 0@ and @m >= length xs@,
because the empty suffix is not among the pairs - confirm the intended result.
-}
dropMarginRem :: Int -> Int -> T a -> (Int, T a)
dropMarginRem n m =
   switchL (error "StateSignal.dropMarginRem: length xs < n") const .
   dropMargin n m .
   zipWithTails (,) (iterate pred m)

{-# INLINE dropMargin #-}
{- |
Remove as many leading samples from @xs@
as there are samples in @take m (drop n xs)@.
-}
dropMargin :: Int -> Int -> T a -> T a
dropMargin n m xs =
   dropMatch margin xs
   where margin = take m (drop n xs)


{- |
Remove one leading sample from the second signal
for every sample of the first signal.
The removal stops as soon as one of the signals is exhausted.
-}
dropMatch :: T b -> T a -> T a
dropMatch xs ys =
   let tailOf sig = fmap snd (viewL sig)
   in  fromMaybe ys (liftM2 dropMatch (tailOf xs) (tailOf ys))


{- |
Sample at position @n@, counted from zero.
Aborts with 'error' if no sample is left after dropping @n@ samples.
-}
index :: Int -> T a -> a
index n =
   switchL (error "State.Signal: index too large") (\x _ -> x) . drop n


{-
splitAt :: Int -> T a -> (T a, T a)
splitAt n = mapPair (Cons, Cons) . List.splitAt n . decons
-}

{-# INLINE splitAt #-}
{- |
Split a signal at the given position
using 'splitAtSize' with the default chunk size.
-}
splitAt :: Storable a =>
   Int -> T a -> (T a, T a)
splitAt = splitAtSize chunkSize
   where chunkSize = SigSt.defaultChunkSize

{-# INLINE splitAtSize #-}
{- |
Split a signal at the given position.
The signal is converted to a storable signal with the given chunk size,
split there with 'SigSt.splitAt'
and both parts are converted back.
-}
splitAtSize :: Storable a =>
   SigSt.ChunkSize -> Int -> T a -> (T a, T a)
splitAtSize size n =
   restore . SigSt.splitAt n . store
   where
      store   = toStorableSignal size
      restore = mapPair (fromStorableSignal, fromStorableSignal)


{-# INLINE dropWhile #-}
{- |
Remove the leading samples that satisfy the predicate.
The result starts at the first sample that violates the predicate.
-}
dropWhile :: (a -> Bool) -> T a -> T a
dropWhile p xt =
   switchL empty decide xt
   where
      decide x rest =
         if p x
           then dropWhile p rest
           else xt

{-
span :: (a -> Bool) -> T a -> (T a, T a)
span p = mapPair (Cons, Cons) . List.span p . decons
-}

{-# INLINE span #-}
{- |
Split a signal with respect to a predicate
using 'spanSize' with the default chunk size.
-}
span :: Storable a =>
   (a -> Bool) -> T a -> (T a, T a)
span = spanSize chunkSize
   where chunkSize = SigSt.defaultChunkSize

{-# INLINE spanSize #-}
{- |
Split a signal with respect to a predicate.
The signal is converted to a storable signal with the given chunk size,
split there with 'SigSt.span'
and both parts are converted back.
-}
spanSize :: Storable a =>
   SigSt.ChunkSize -> (a -> Bool) -> T a -> (T a, T a)
spanSize size p =
   restore . SigSt.span p . store
   where
      store   = toStorableSignal size
      restore = mapPair (fromStorableSignal, fromStorableSignal)


{-# INLINE cycle #-}
{- |
Repeat a non-empty signal infinitely.
The state is the remaining part of the current pass.
When it is exhausted, the first sample of the input is emitted
and the pass continues with the rest of the input.
Aborts with 'error' on an empty signal.
-}
cycle :: T a -> T a
cycle xs =
   switchL
      (error "StateSignal.cycle: empty input")
      (\y ys ->
          let restart = (y, ys)
              step sig = Just (fromMaybe restart (viewL sig))
          in  generate step xs)
      xs

{-# INLINE mix #-}
{- |
Add two signals sample by sample.
The pair of both signals is the state and 'mixStep' is the transition function,
thus the result is as long as the longer input.
-}
mix :: Additive.C a => T a -> T a -> T a
mix =
   curry (generate mixStep)


{- |
Transition function of 'mix'.
If both signals deliver a sample, the sum is emitted.
If only one signal delivers a sample, that sample is emitted unchanged.
The result is 'Nothing' if both signals are finished.
-}
mixStep :: (Additive.C a) =>
   (T a, T a) -> Maybe (a, (T a, T a))
mixStep (xt,yt) =
   case viewL xt of
      Just (x,xs) ->
         case viewL yt of
            Just (y,ys) -> Just (x Additive.+ y, (xs,ys))
            Nothing     -> Just (x, (xs,yt))
      Nothing ->
         case viewL yt of
            Just (y,ys) -> Just (y, (xt,ys))
            Nothing     -> Nothing


{-# INLINE sub #-}
{- |
Subtract the second signal from the first one
by mixing the first signal with the negated second signal.
-}
sub :: Additive.C a => T a -> T a -> T a
sub xs ys = mix xs negated
   where negated = neg ys

{-# INLINE neg #-}
{- |
Negate every sample.
-}
neg :: Additive.C a => T a -> T a
neg = map (\x -> Additive.negate x)

{- |
Signals are added and subtracted sample by sample using 'mix' and 'sub'.
The empty signal is the zero.
-}
instance Additive.C y => Additive.C (T y) where
   zero = empty
   negate xs = neg xs
   (+) xs ys = mix xs ys
   (-) xs ys = sub xs ys

-- | Scaling a signal means scaling every sample.
instance Module.C y yv => Module.C y (T yv) where
   (*>) scalar sig = map (\v -> scalar *> v) sig


infixr 5 `append`

{-# INLINE append #-}
{- |
Concatenate two signals.
The state consists of a flag, that tells whether the second signal was entered,
and the signal that is currently consumed.
If the current signal is exhausted and the flag is not set,
the second signal is entered and the flag is set.
-}
append :: T a -> T a -> T a
append xs ys =
   generate step (False,xs)
   where
      step (switched,current) =
         mplus
            (fmap (mapSnd ((,) switched)) (viewL current))
            (if switched
               then Nothing
               else fmap (mapSnd ((,) True)) (viewL ys))

{-# INLINE appendStored #-}
{- |
Concatenate two signals
using 'appendStoredSize' with the default chunk size.
-}
appendStored :: Storable a =>
   T a -> T a -> T a
appendStored = appendStoredSize chunkSize
   where chunkSize = SigSt.defaultChunkSize

{-# INLINE appendStoredSize #-}
{- |
Concatenate two signals.
Both signals are converted to storable signals with the given chunk size,
concatenated with 'SigSt.append'
and the result is converted back.
-}
appendStoredSize :: Storable a =>
   SigSt.ChunkSize -> T a -> T a -> T a
appendStoredSize size xs ys =
   fromStorableSignal (SigSt.append (store xs) (store ys))
   where
      store = toStorableSignal size

{-# INLINE concat #-}
{- |
Concatenate a list of signals.
The state is the list of signals that are not yet finished.
In every step the first signal of that list which still has a sample is sought,
all signals before it are discarded.
Certainly inefficient because of frequent list deconstruction.
-}
concat :: [T a] -> T a
concat =
   generate step
   where
      step sigs =
         msum (List.map firstSample (List.init (List.tails sigs)))
      firstSample pending =
         do (y,ys) <- viewListL pending
            (z,zs) <- viewL y
            Just (z,zs:ys)


{-# INLINE concatStored #-}
{- |
Concatenate a list of signals
using 'concatStoredSize' with the default chunk size.
-}
concatStored :: Storable a =>
   [T a] -> T a
concatStored = concatStoredSize chunkSize
   where chunkSize = SigSt.defaultChunkSize

{-# INLINE concatStoredSize #-}
{- |
Concatenate a list of signals.
All signals are converted to storable signals with the given chunk size,
concatenated with 'SigSt.concat'
and the result is converted back.
-}
concatStoredSize :: Storable a =>
   SigSt.ChunkSize -> [T a] -> T a
concatStoredSize size =
   fromStorableSignal . SigSt.concat . List.map store
   where
      store = toStorableSignal size

{-# INLINE reverse #-}
{- |
Reverse the order of the samples using an intermediate list.
-}
reverse ::
   T a -> T a
reverse =
   rebuild . toList
   where
      rebuild samples = fromList (List.reverse samples)

{-# INLINE reverseStored #-}
{- |
Reverse the order of the samples
using 'reverseStoredSize' with the default chunk size.
-}
reverseStored :: Storable a =>
   T a -> T a
reverseStored = reverseStoredSize chunkSize
   where chunkSize = SigSt.defaultChunkSize

{-# INLINE reverseStoredSize #-}
{- |
Reverse the order of the samples.
The signal is converted to a storable signal with the given chunk size,
reversed with 'SigSt.reverse'
and the result is converted back.
-}
reverseStoredSize :: Storable a =>
   SigSt.ChunkSize -> T a -> T a
reverseStoredSize size =
   fromStorableSignal . SigSt.reverse . store
   where
      store = toStorableSignal size


{-# INLINE sum #-}
{- |
Sum of all samples, computed by a left fold that starts with zero.
-}
sum :: (Additive.C a) => T a -> a
sum = foldL' (\x acc -> x Additive.+ acc) Additive.zero

{-# INLINE maximum #-}
{- |
Largest sample of a non-empty signal with respect to 'max'.
The first sample is the start value of the fold over the remaining samples.
Aborts with 'error' on an empty signal.
-}
maximum :: (Ord a) => T a -> a
maximum =
   switchL
      -- the message names this module, like the messages of the other functions
      (error "StateSignal.maximum: empty list")
      (foldL' max)

{-
{-# INLINE tails #-}
tails :: T y -> [T y]
tails = List.map Cons . List.tails . decons
-}

{-# INLINE init #-}
{- |
Signal without its last sample.
The first sample is the initial state of a 'crochetL' over the remaining samples,
which emits the sample from the state and stores the current one.
Aborts with 'error' on an empty signal.
-}
init :: T y -> T y
init =
   switchL
      -- the message names this module, like the messages of the other functions
      (error "StateSignal.init: empty list")
      (crochetL (\x acc -> Just (acc,x)))

{-# INLINE sliceVert #-}
{- |
Cut the signal into consecutive pieces of @n@ samples.
The last piece may be shorter.
Inefficient since it computes some things twice.
-}
sliceVert :: Int -> T y -> [T y]
sliceVert n =
--   map fromList . Sig.sliceVert n . toList
   List.map (take n) . nonEmptyPrefix . List.iterate (drop n)
   where
      nonEmptyPrefix = List.takeWhile (\sig -> not (null sig))

{-# INLINE zapWith #-}
{- |
Combine every sample with its successor.
The first sample is the initial state of a 'crochetL' over the remaining samples,
the state is always the preceding sample.
The result is one sample shorter than a non-empty input
and empty for an empty input.
-}
zapWith :: (a -> a -> b) -> T a -> T b
zapWith f =
   switchL empty (crochetL step)
   where
      step current previous = Just (f previous current, current)

{- |
Alternative to 'zapWith'
that zips the signal with the signal without its first sample.
-}
zapWithAlt :: (a -> a -> b) -> T a -> T b
zapWithAlt f xs =
   zipWith f xs successors
   where
      successors = switchL empty (\_ rest -> rest) xs

{-# INLINE modifyStatic #-}
{- |
Run a modifier with a constant control value over a signal.
The state starts at @Modifier.init modif@
and in every step the stateful action @Modifier.step modif control a@ is run on it.
-}
modifyStatic :: Modifier.Simple s ctrl a b -> ctrl -> T a -> T b
modifyStatic modif control x =
   crochetL step (Modifier.init modif) x
   where
      step a acc =
         Just (runState (Modifier.step modif control a) acc)

{-|
Run a modifier over a signal, where the control may vary over the time.
Control signal and input signal are zipped
and every pair is passed to @Modifier.step modif@.
-}
{-# INLINE modifyModulated #-}
modifyModulated :: Modifier.Simple s ctrl a b -> T ctrl -> T a -> T b
modifyModulated modif control x =
   crochetL step (Modifier.init modif) (zip control x)
   where
      step ca acc =
         Just (runState (uncurry (Modifier.step modif) ca) acc)


-- cf. Module.linearComb
{-# INLINE linearComb #-}
{- |
Scale every sample of the second signal
with the corresponding sample of the first signal
and sum up the products.
-}
linearComb ::
   (Module.C t y) =>
   T t -> T y -> y
linearComb ts ys =
   sum (zipWith (\t y -> t *> y) ts ys)


-- comonadic 'bind'
-- only non-empty suffixes are processed
{-# INLINE mapTails #-}
{- |
Apply a function to every non-empty suffix of the signal.
The state is the current suffix.
-}
mapTails ::
   (T y0 -> y1) -> T y0 -> T y1
mapTails f =
   generate step
   where
      step xs =
         viewL xs >>= \(_,ys) ->
         return (f xs, ys)

-- only non-empty suffixes are processed
{-# INLINE zipWithTails #-}
{- |
Combine every sample of the first signal
with the corresponding non-empty suffix of the second signal.
The result ends as soon as one of the signals ends.
-}
zipWithTails ::
   (y0 -> T y1 -> y2) -> T y0 -> T y1 -> T y2
zipWithTails f =
   curry (generate step)
   where
      step (xs0,ys0) =
         viewL xs0 >>= \(x,xs) ->
         viewL ys0 >>= \(_,ys) ->
         return (f x ys0, (xs,ys))

{- |
Feedback loop:
The output is the prefix followed by the processed output.
-}
delayLoop ::
      (T y -> T y)
            -- ^ processor that shall be run in a feedback loop
   -> T y   -- ^ prefix of the output, its length determines the delay
   -> T y
delayLoop proc prefix =
   output
   where
      -- the temporary list is needed for sharing the output
      output = fromList buffered
      buffered = toList prefix List.++ toList (proc output)

{- |
Feedback loop with input:
The output is the sum of the input
and the processed output delayed by @time@ samples of zero.
-}
delayLoopOverlap ::
   (Additive.C y) =>
      Int
   -> (T y -> T y)
            -- ^ processor that shall be run in a feedback loop
   -> T y   -- ^ input
   -> T y   -- ^ output has the same length as the input
delayLoopOverlap time proc xs =
   output
   where
      output = zipWith (Additive.+) xs feedback
      -- the temporary list is needed for sharing the output
      feedback = delay zero time (proc (fromList (toList output)))


{-
A traversable instance is hardly useful,
because 'cons' is so expensive.

instance Traversable T where
-}
{-# INLINE sequence_ #-}
{- |
Run all actions of the signal from left to right and discard their results.
-}
sequence_ :: Monad m => T (m a) -> m ()
sequence_ =
   switchL (return ()) runFirst
   where
      runFirst action rest = action >> sequence_ rest

{-# INLINE mapM_ #-}
{- |
Apply an action to every sample from left to right.
-}
mapM_ :: Monad m => (a -> m ()) -> T a -> m ()
mapM_ f = sequence_ . actions
   where actions = map f