--
-- Copyright (c) 2009-2010, ERICSSON AB All rights reserved.
-- 
-- Redistribution and use in source and binary forms, with or without
-- modification, are permitted provided that the following conditions are met:
-- 
--     * Redistributions of source code must retain the above copyright notice,
--       this list of conditions and the following disclaimer.
--     * Redistributions in binary form must reproduce the above copyright
--       notice, this list of conditions and the following disclaimer in the
--       documentation and/or other materials provided with the distribution.
--     * Neither the name of the ERICSSON AB nor the names of its contributors
--       may be used to endorse or promote products derived from this software
--       without specific prior written permission.
-- 
-- THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
-- AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
-- IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
-- ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS
-- BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY,
-- OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
-- SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
-- INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
-- CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
-- ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
-- THE POSSIBILITY OF SUCH DAMAGE.
--

module Feldspar.Stream 
    (Stream
    ,head
    ,tail
    ,map
    ,intersperse
    ,interleave
    ,scan
    ,mapAccum
    ,iterate
    ,repeat
    ,unfold
    ,drop
    ,dropWhile
    ,filter
    ,partition
    ,zip
    ,zipWith
    ,unzip
    ,take
    ,splitAt
    ,cycle
    ,recurrence
    ,recurrenceI
    ,iir
    ,fir
    )
    where

import Feldspar.Core
import qualified Prelude
import Feldspar.Range
import Feldspar.Prelude hiding (filter,repeat,iterate,cycle)
import Control.Arrow

import Feldspar.Vector (Vector, DVector
                       ,vector
                       ,freezeVector,indexed
                       ,sum,length,replicate)

-- | Infinite streams.
--
-- A stream packages a step function together with the state that the step
-- function starts from. The state type is existentially quantified, so it is
-- not visible in the type of the stream; both the element type and the
-- hidden state type must be 'Computable'.
data Stream a = forall state . (Computable a, Computable state) =>
                Stream (StepFunction state a) state

-- A step function advances a stream state by one step.
--
-- 'Continuous': every step returns an element together with the next state.
--
-- 'Stuttering': every step additionally returns a 'Data Bool' flag. The
-- helper 'step' keeps stepping until the flag is true, and 'stuttering'
-- reports true for every continuous step, so the flag tells whether the
-- returned element is a real element of the stream or should be skipped.
data StepFunction state a 
    = Continuous (state -> (a,state))
    | Stuttering (state -> (a,Data Bool, state))

-- Runs a step function as though it were continuous: the result is always
-- an element together with the state that follows it.
-- Use with care! For a stuttering step function this introduces an extra
-- while loop that keeps stepping until the validity flag is true.
step :: (Computable state, Computable a) =>
        StepFunction state a -> (state -> (a,state))
step (Continuous next) st0 = next st0
step (Stuttering next) st0 = (el,st1)
  where
    (el,_,st1)        = while notValid retry (next st0)
    -- Keep looping for as long as the last step was flagged as not valid.
    notValid (_,ok,_) = not ok
    -- Step again from the state the previous attempt left behind.
    retry (_,_,st)    = next st

-- Views any step function as a stuttering one. This is for the places where
-- the continuous case cannot be optimized separately, so that only the
-- stuttering case has to be written. A continuous step is reported as valid
-- every time.
stuttering :: StepFunction state a -> (state -> (a, Data Bool, state))
stuttering (Stuttering next) = next
stuttering (Continuous next) = alwaysValid
  where
    alwaysValid st = (el,true,st')
      where (el,st') = next st

-- Transforms a step function using a transformer that is written for the
-- stuttering representation only, while still propagating the information
-- that the input was continuous: a stuttering input gives a stuttering
-- result, a continuous input gives a continuous result.
-- Helps writing less code.
mapStep :: ((stateA -> (a,Data Bool, stateA)) -> 
                (stateB -> (b,Data Bool, stateB))) 
        -> StepFunction stateA a -> StepFunction stateB b
mapStep mkStep (Stuttering next) = Stuttering (mkStep next)
mapStep mkStep (Continuous next) = Continuous dropFlag
  where
    -- The continuous step dressed up as a stuttering one that is always valid.
    lifted s   = let (x,s') = next s
                 in (x,true,s')
    -- Run the transformed step and throw its validity flag away again.
    dropFlag s = (y,s')
      where (y,_,s') = mkStep lifted s

-- Helper functions for working on triplets: projections of each component
-- and functions that map over a single component.
fst3 :: (a,b,c) -> a
fst3 (x,_,_) = x

snd3 :: (a,b,c) -> b
snd3 (_,y,_) = y

thd3 :: (a,b,c) -> c
thd3 (_,_,z) = z

first3 :: (a -> d) -> (a,b,c) -> (d,b,c)
first3 g (x,y,z) = (g x,y,z)

second3 :: (b -> d) -> (a,b,c) -> (a,d,c)
second3 g (x,y,z) = (x,g y,z)

third3 :: (c -> d) -> (a,b,c) -> (a,b,d)
third3 g (x,y,z) = (x,y,g z)

-- | The first element of a stream, obtained by taking one step from the
--   initial state.
head :: Computable a => Stream a -> a
head (Stream next init) = firstElem
  where (firstElem,_) = step next init

-- | The stream without its first element: the same step function, started
--   from the state reached after one step.
tail :: Computable a => Stream a -> Stream a
tail (Stream next init) = Stream next rest
  where (_,rest) = step next init

-- | @map f str@ applies the function @f@ to every element of the stream
--   @str@. The state and the validity flags are left untouched.
map :: (Computable a, Computable b) =>
       (a -> b) -> Stream a -> Stream b
map f (Stream next init) = Stream (mapStep applyF next) init
  where applyF nxt st = first3 f (nxt st)

-- | @intersperse a str@ inserts an @a@ between each element of the stream
--   @str@.
--
--   The state is extended with a flag telling whether the next element comes
--   from @str@ (true) or is the separator (false). The flag only flips to
--   the separator after a valid element has been produced.
intersperse :: a -> Stream a -> Stream a
intersperse sep (Stream next init) =
    Stream (mapStep weave next) (true,init)
  where
    weave inner (fromSource,st) = fromSource ? (pull, (sep,true,(true,st)))
      where
        (e,isValid,st') = inner st
        pull = isValid ? ( (e,true, (false,st'))
                         , (e,false,(true, st'))
                         )

-- | Create a new stream by alternating between the elements from
--   the two input streams, starting with the first one.
--
--   When both inputs are continuous the result is continuous as well. In
--   every other case the result is stuttering, and the side that is being
--   read only changes once that side has delivered a valid element.
interleave :: Stream a -> Stream a -> Stream a
interleave (Stream (Continuous next1) init1) (Stream (Continuous next2) init2)
    = Stream (Continuous alternate) (true,init1,init2)
  where
    alternate (useFirst,st1,st2) = useFirst ? (fromFirst, fromSecond)
      where
        fromFirst  = let (x,st1') = next1 st1
                     in (x,(false,st1',st2))
        fromSecond = let (y,st2') = next2 st2
                     in (y,(true,st1,st2'))
interleave (Stream next1 init1) (Stream next2 init2)
    = Stream (Stuttering alternate) (true,init1,init2)
  where
    alternate (useFirst,st1,st2) = useFirst ? (fromFirst, fromSecond)
      where
        (x,ok1,st1') = stuttering next1 st1
        (y,ok2,st2') = stuttering next2 st2
        fromFirst    = ok1 ? ( (x,true, (false,st1',st2))
                             , (x,false,(true, st1',st2))
                             )
        fromSecond   = ok2 ? ( (y,true, (true, st1,st2'))
                             , (y,false,(false,st1,st2'))
                             )

-- | @scan f a str@ produces a stream by successively applying @f@ to
--   the previous element of the output stream and each element of the input
--   stream @str@. The first output element is @a@ itself.
scan :: Computable a => (a -> b -> a) -> a -> Stream b -> Stream a
scan f a (Stream next init)
    = Stream (mapStep accumulate next) (a,init)
  where
    accumulate inner (acc,st) = isValid ? (advance, hold)
      where
        (x,isValid,st') = inner st
        -- Emit the current accumulator and fold the new element into it.
        advance = (acc,true, (f acc x,st'))
        -- Invalid input element: emit nothing and keep the accumulator.
        hold    = (acc,false,(acc,    st'))

{- This function is problematic to define for the same reason the index
   function is problematic, plus that it has the same quirk as correctScan.
-}

-- | A scan but without an initial element: the first element of the input
--   stream is used as the starting value and is also the first output.
scan1 :: Computable a => (a -> a -> a) -> Stream a -> Stream a
scan1 f (Stream next init)
    = Stream (mapStep fold1 next) (seed,true,rest)
  where
    (seed,rest) = step next init
    fold1 inner (prev,isFirst,st)
        = isFirst ? (emitSeed, isValid ? (combine, skip))
      where
        (b,isValid,st') = inner st
        -- Very first step: output the seed without consuming any input.
        emitSeed = (prev, true, (prev,false,st))
        merged   = f prev b
        combine  = (merged, true, (merged,false,st'))
        skip     = (prev,false, (prev,false,st'))

-- mapAccum goes through 'step', which creates a nested loop when the input
-- stream is stuttering. It's either that or recomputing the function even
-- for non-valid elements in the input stream.

-- | Maps a function over a stream while threading an accumulator through
--   the successive applications.
mapAccum :: (Computable acc, Computable b) => 
            (acc -> a -> (acc,b)) -> acc -> Stream a -> Stream b
mapAccum f acc (Stream next init)
    = Stream (Continuous thread) (init,acc)
  where
    thread (st,accIn) = (out, (st',accOut))
      where
        (x,st')      = step next st
        (accOut,out) = f accIn x

-- | Builds a stream from a starting element by applying a function over and
--   over. The state of the stream is the element that is emitted next.
--
-- @iterate f a == [a, f a, f (f a), f (f (f a)) ...]@
iterate :: Computable a => (a -> a) -> a -> Stream a
iterate f init = Stream (Continuous (\x -> (x, f x))) init

-- | A stream consisting of the same element forever. The stream carries no
--   information in its state ('unit').
--
-- @repeat a = [a, a, a, ...]@
repeat :: Computable a => a -> Stream a
repeat a = Stream (Continuous (\_ -> (a,unit))) unit

-- | @unfold f acc@ creates a new stream by successively applying @f@ to
--   the accumulator @acc@. Each application yields one element and the
--   accumulator for the following application.
unfold :: (Computable a, Computable c) => (c -> (a,c)) -> c -> Stream a
unfold stepFn seed = Stream (Continuous stepFn) seed

-- | Drop a number of elements from the front of a stream
drop :: Data Unsigned32 -> Stream a -> Stream a
{- This version creates a conditional inside the loop
   The output stream is always stuttering
drop i (Stream next init) = Stream (Stuttering newNext) (i,init)
  where newNext (i,st) = i == 0 ? (let (a,isValid,st') = stuttering next st
                                   in isValid ? ( (a,true,  (0,st'))
                                                , (a,false, (0,st')) 
                                                )
                                  ,let (a,isValid,st') = stuttering next st
                                   in isValid ? ( (a,false, (i-1,st'))
                                                , (a,false, (i,  st'))
                                                )
                                  )
-}
-- This version runs a while loop up front to compute the initial state of
-- the result. The step function is reused as it is, so the output stream is
-- continuous if the input stream is.
drop i (Stream next init) = Stream next skipped
  where
    (skipped,_)     = while remaining advance (init,i)
    remaining (_,k) = k > 0
    -- Only valid elements count towards the number of dropped elements.
    advance (st,k)  = isValid ? ( (st',k-1)
                                , (st',k)
                                )
      where (_,isValid,st') = stuttering next st

-- | @dropWhile p str@ drops elements from the front of the stream @str@ for
--   as long as the elements fulfill the predicate @p@. The first element
--   that does not fulfill @p@ is kept and is the first element of the result.
--
--   The initial state of the result is computed by a while loop. Since
--   'step' is used to fetch elements, a stuttering input stream gives rise
--   to a nested loop.
dropWhile :: (a -> Data Bool) -> Stream a -> Stream a
dropWhile p (Stream next init) = Stream next newState
  where
    (a0,st0)       = step next init
    -- Loop state: (candidate element, state that produces the candidate,
    -- state after the candidate). The result has to restart from the state
    -- that *produces* the first rejected candidate; restarting from the
    -- state after it would silently drop that element as well.
    (_,newState,_) = while cond body (a0,init,st0)
    cond (a,_,_)   = p a
    body (_,_,st)  = let (a,st') = step next st
                     in (a,st,st')

-- | @filter p str@ removes the elements of the stream @str@ for which the
--   predicate @p@ is false. Removed elements are marked as not valid, so
--   the resulting stream is always stuttering.
filter :: (a -> Data Bool) -> Stream a -> Stream a
filter p (Stream next init) = Stream (Stuttering keep) init
  where
    keep st = (isValid && p x) ? ( (x,true, st')
                                 , (x,false,st')
                                 )
      where (x,isValid,st') = stuttering next st

-- | Splits a stream in two according to the predicate function. The first
--   stream holds the elements for which the predicate is true, the second
--   stream holds the rest. Both results are filters over the same input.
partition :: (a -> Data Bool) -> Stream a -> (Stream a, Stream a)
partition p stream = (accepted, rejected)
  where
    accepted = filter p stream
    rejected = filter (\x -> not (p x)) stream

-- Both inputs are advanced with 'step', so stuttering input streams will
-- introduce nested loops here.

-- | Pairs together two streams into one, element by element.
zip :: Stream a -> Stream b -> Stream (a,b)
zip (Stream next1 init1) (Stream next2 init2)
    = Stream (Continuous both) (init1,init2)
  where
    both (st1,st2) =
        let (x,st1') = step next1 st1
            (y,st2') = step next2 st2
        in ( (x,y), (st1',st2') )

-- Just like zip, this function advances both inputs with 'step' and can
-- therefore introduce nested loops for stuttering input streams.

-- | Combines two streams into one by applying a function to the
--   corresponding elements.
zipWith :: Computable c => (a -> b -> c) -> Stream a -> Stream b -> Stream c
zipWith f (Stream next1 init1) (Stream next2 init2)
    = Stream (Continuous both) (init1,init2)
  where
    both (st1,st2) =
        let (x,st1') = step next1 st1
            (y,st2') = step next2 st2
        in ( f x y, (st1',st2') )

-- | Given a stream of pairs, split it into two streams: one with the first
--   components and one with the second components.
unzip :: (Computable a, Computable b) => Stream (a,b) -> (Stream a, Stream b)
unzip stream = (firsts, seconds)
  where
    firsts  = map fst stream
    seconds = map snd stream

-- Indexing into a stream. The first element is fetched up front and a while
-- loop then advances the stream until the remaining count reaches zero.
-- Only valid elements decrease the count.
instance RandomAccess (Stream a) where
  type Element (Stream a) = a
  (Stream next init) ! n = result
      where
        (result,_,_)     = while pending advance (firstElem,afterFirst,n)
        -- I would like to get rid of this one
        (firstElem,afterFirst) = step next init
        pending (_,_,i)  = i /= 0
        advance (_,st,i) = isValid ? ( (x,st',i-1)
                                     , (x,st',i)
                                     )
          where (x,isValid,st') = stuttering next st

-- | @take n str@ allocates @n@ elements from the stream @str@ into a
--   core array. Elements that are flagged as not valid are skipped and do
--   not count towards @n@.
take :: Storable a => Data Int -> Stream (Data a) -> Data [a]
take n (Stream next init) = filled
  where
    (_,filled,_) =
        while notFull store
        (0,array (mapMonotonic fromIntegral (dataSize n) :> universal) [],init)
    notFull (i,_,_)  = i < n
    store (i,arr,st) = isValid ? ( (i+1,setIx arr i x,st')
                                 , (i,  arr,          st')
                                 )
      where (x,isValid,st') = stuttering next st

-- | @splitAt n str@ allocates @n@ elements from the stream @str@ into a
--   core array and returns the rest of the stream, continuing from
--   element @n+1@. Elements that are flagged as not valid are skipped and
--   do not count towards @n@.
splitAt :: Storable a => 
           Data Int -> Stream (Data a) -> (Data [a], Stream (Data a))
splitAt n (Stream next init) = (filled,Stream next rest)
  where
    (_,filled,rest) =
        while notFull store
        (0,array (mapMonotonic fromIntegral (dataSize n) :> universal) [],init)
    notFull (i,_,_)  = i < n
    store (i,arr,st) = isValid ? ( (i+1,setIx arr i x,st')
                                 , (i,  arr,          st')
                                 )
      where (x,isValid,st') = stuttering next st

-- | Loops through a vector indefinitely to produce a stream. The state is
--   the index of the next element, which wraps around at the length of the
--   vector.
cycle :: Computable a => Vector a -> Stream a
cycle vec = Stream (Continuous advance) 0
  where
    len       = length vec
    advance i = (vec ! i, (i + 1) `rem` len)


-- | A combinator for describing recurrence equations, or feedback loops.
--   It uses memory proportional to the input vector, which is frozen into a
--   buffer holding the most recent elements.
--
-- For example one can define the fibonacci sequence as follows:
--
-- > fib = recurrence (vector [0,1]) (\fib -> fib 1 + fib 2)
--
-- With the two-element initial vector used here, the expressions @fib 1@
-- and @fib 2@ refer to the elements defined one step back and two steps
-- back respectively. In general the argument @i@ of the accessor selects
-- buffer slot @(i + ix) \`rem\` len@, where @ix@ is the number of elements
-- produced so far and @len@ is the length of the initial vector.
--
-- Each step emits the element in slot @ix \`rem\` len@ and overwrites that
-- slot with the newly computed element.
recurrence :: Storable a => 
              DVector a -> ((Int -> Data a) -> Data a) -> Stream (Data a)
recurrence init mkExpr = Stream (Continuous advance) (frozen,0)
  where
    frozen = freezeVector init
    len    = length init
    advance (buf,ix) = (oldest, (setIx buf slot new, ix + 1))
      where
        slot       = ix `rem` len
        oldest     = getIx buf slot
        lookBack i = getIx buf ((value i + ix) `rem` len)
        new        = mkExpr lookBack

-- | A recurrence combinator with input
--
-- The first vector initialises a buffer of recent input elements, the
-- second vector initialises a buffer of recent output elements. The function
-- argument receives one accessor for each buffer and computes the next
-- output element. A buffer of length zero is never read or written by the
-- combinator itself; in that case the computed element is emitted directly.
--
-- The sliding average of a stream can easily be implemented using
-- 'recurrenceI'.
--
-- > slidingAvg :: Data Int -> Stream (Data Int) -> Stream (Data Int)
-- > slidingAvg n str = recurrenceI (replicate n 0) str (vector [])
-- >                    (\input _ -> sum (indexed n input) `quot` n)
recurrenceI :: (Storable a, Storable b) => 
               DVector a -> Stream (Data a) -> DVector b ->
               ((Data Int -> Data a) -> (Data Int -> Data b) -> Data b) ->
               Stream (Data b)
recurrenceI ii (Stream (Continuous next) s0) io mkExpr
    = Stream (Continuous advance) (inBuf0,outBuf0,s0,0)
  where
    inBuf0  = freezeVector ii
    outBuf0 = freezeVector io
    p       = length ii
    q       = length io
    advance (inBuf,outBuf,s,ix) = (out, (inBuf',outBuf',s',ix + 1))
      where
        (x,s')    = next s
        -- Store the fresh input element before the expression is evaluated.
        inBuf'    = (p /= 0) ? (setIx inBuf (ix `rem` p) x, inBuf)
        readIn  i = getIx inBuf' ((i + ix)     `rem` p)
        readOut i = getIx outBuf ((i + ix - 1) `rem` q)
        y         = mkExpr readIn readOut
        -- Emit the element that is about to be overwritten in the output
        -- buffer, or the new element itself when there is no output buffer.
        out       = (q /= 0) ? (getIx outBuf (ix `rem` q), y)
        outBuf'   = (q /= 0) ? (setIx outBuf (ix `rem` q) y, outBuf)
recurrenceI ii (Stream (Stuttering next) s0) io mkExpr
    = Stream (Stuttering advance) (inBuf0,outBuf0,s0,0)
  where
    inBuf0  = freezeVector ii
    outBuf0 = freezeVector io
    p       = length ii
    q       = length io
    advance (inBuf,outBuf,s,ix) = isValid ? (accept, skip)
      where
        (x,isValid,s') = next s
        inBuf'    = (p /= 0) ? (setIx inBuf (ix `rem` p) x, inBuf)
        readIn  i = getIx inBuf' ((i + ix)     `rem` p)
        readOut i = getIx outBuf ((i + ix - 1) `rem` q)
        y         = mkExpr readIn readOut
        out       = (q /= 0) ? (getIx outBuf (ix `rem` q), y)
        outBuf'   = (q /= 0) ? (setIx outBuf (ix `rem` q) y, outBuf)
        -- Valid input element: commit both buffers and advance the index.
        accept    = (out, true,  (inBuf',outBuf',s',ix + 1))
        -- Invalid input element: only the state of the input stream moves.
        skip      = (out, false, (inBuf, outBuf, s',ix))

-- Sliding average over a window of @n@ input elements, written with
-- 'recurrenceI'. The input history starts out as @n@ zeros and no output
-- history is kept.
slidingAvg :: Data Int -> Stream (Data Int) -> Stream (Data Int)
slidingAvg n str =
    recurrenceI (replicate n 0) str (vector []) $ \window _ ->
        sum (indexed n window) `quot` n

-- | A fir filter on streams. The vector @b@ holds the filter coefficients.
--   The input history has the same length as @b@ and starts out as zeros;
--   no output history is kept.
fir :: DVector Float -> 
       Stream (Data Float) -> Stream (Data Float)
fir b input =
    recurrenceI (replicate n 0) input (vector []) $ \past _ ->
        sum (indexed n (\k -> b!k * past!(n-k)))
  where
    n = length b

-- | An iir filter on streams
--
-- @iir a0 a b input@ filters @input@ using the feed-forward coefficients
-- @b@, the feedback coefficients @a@ and the normalisation factor @a0@.
-- Every new element is computed as
--
-- > 1 / a0 * (sum of b!i * input!(p-i) - sum of a!j * output!(q-j))
--
-- where @p = length b@ and @q = length a@.
--
-- The input history must hold @p@ elements and the output history @q@
-- elements, because 'recurrenceI' resolves the accessors modulo the length
-- of the respective history buffer: the offsets @p-i@ only reach the
-- intended input elements in a buffer of length @p@, and the offsets @q-j@
-- only reach the intended output elements in a buffer of length @q@.
-- Both histories start out as zeros.
iir :: Data Float -> DVector Float -> DVector Float -> 
       Stream (Data Float) -> Stream (Data Float)
iir a0 a b input = 
    recurrenceI (replicate p 0) input 
                (replicate q 0)
      (\input output -> 1 / a0 * 
                        ( sum (indexed p (\i -> b!i *  input!(p-i)))
                        - sum (indexed q (\j -> a!j * output!(q-j))))
      )
  where p = length b
        q = length a

-- A nice instance to have when using the recurrence functions: it lets the
-- accessor functions handed out by those combinators be indexed with '!'.
-- Indexing is plain function application.
instance RandomAccess (Data Int -> Data a) where
  type Element (Data Int -> Data a) = Data a
  f ! i = f i

-- Predicate to be used with filter for debugging purposes: holds when the
-- remainder of the argument divided by two is zero.
even x = (x `rem` 2) == 0