module Feldspar.Stream
    (Stream
    ,head
    ,tail
    ,map,mapNth
    ,maps
    ,intersperse
    ,interleave
    ,downsample
    ,duplicate
    ,scan
    ,mapAccum
    ,iterate
    ,repeat
    ,unfold
    ,drop
    ,zip,zipWith
    ,unzip
    ,take
    ,splitAt
    ,cycle
    ,streamAsVector, streamAsVectorSize
    ,recurrenceO,recurrenceI,recurrenceIO
    ,iir,fir
    )
    where

import Feldspar.Core
import qualified Prelude
import Feldspar.Range
import Feldspar.Prelude hiding (filter,repeat,iterate,cycle)
import Control.Arrow

import Feldspar.Vector (Vector, DVector
                       ,vector,freezeVector,unfreezeVector,indexed
                       ,sum,length,replicate,reverse,scalarProd)

-- | Infinite streams.
--
--   A stream is a step function paired with an initial state. Applying the
--   step function to a state yields the next element together with the
--   state to continue from. The state type is hidden by the constructor;
--   the only requirement on it is that it is 'Syntactic'.
data Stream a where
  Stream :: Syntactic state => (state -> (a,state)) -> state -> Stream a

-- | The first element of a stream, obtained by running the step function
--   once on the initial state.
head :: Syntactic a => Stream a -> a
head (Stream step s0) = x
  where (x, _) = step s0

-- | The stream that remains after discarding the first element: same step
--   function, starting from the state reached after one step.
tail :: Syntactic a => Stream a -> Stream a
tail (Stream step s0) = Stream step s1
  where (_, s1) = step s0

-- | 'map f str' applies 'f' to every element of the stream 'str'.
--   The underlying state is left untouched.
map :: (Syntactic a, Syntactic b) =>
       (a -> b) -> Stream a -> Stream b
map f (Stream step s0) = Stream step' s0
  where step' s = (f x, s')
          where (x, s') = step s

-- | 'mapNth f n k str' applies 'f' to every 'n'th element of 'str',
--   starting at offset 'k': the element at (0-based) position @j@ is
--   transformed exactly when @j `mod` n == k@. All other elements are
--   passed through unchanged.
mapNth :: (Syntactic a) =>
          (a -> a) -> Data Index -> Data Index -> Stream a -> Stream a
mapNth f n k (Stream step s0) = Stream step' (s0, 0)
  where
    -- 'cnt' is the position of the current element modulo 'n'.
    step' (s, cnt) = (cnt == k ? (f x, x), (s', (cnt + 1) `mod` n))
      where (x, s') = step s

-- | 'maps fs str' applies the functions in 'fs' to the elements of 'str'
--   in round-robin order: the element at (0-based) position @j@ is
--   transformed by the function at position @j `mod` length fs@ in 'fs'.
--   An empty list of functions leaves the stream unchanged.
maps :: (Syntactic a) =>
        [(a -> a)] -> Stream a -> Stream a
-- Without this case the counter below would be reduced modulo zero.
maps [] str = str
maps fs (Stream next init) = Stream newNext (init,0 :: Data Index)
  where n = fromIntegral (Prelude.length fs)
        -- The counter 'i' ranges over [0, n). Positions in 'fs' are
        -- numbered from 0 as well, so every function is selected in turn.
        -- The trailing 'a' only closes the conditional chain; it is never
        -- the selected branch.
        newNext (st,i) =
            let (a,st') = next st in
            (Prelude.foldr (\ (k,f) r ->
                                i==(fromIntegral k)?(f a,r))
                           a (Prelude.zip [0..] fs)
            ,(st',(i+1) `mod` n)
            )

-- | 'intersperse a str' inserts an 'a' between consecutive elements of
--   the stream 'str'. The first output element comes from 'str'.
intersperse :: Syntactic a => a -> Stream a -> Stream a
intersperse a (Stream step s0) = Stream step' (true, s0)
  where
    -- The flag is true when the next output is read from the input stream
    -- and false when it is the separator; it flips after every output.
    step' (fromInput, s) = fromInput ? (takeInput, emitSep)
      where takeInput = let (x, s') = step s in (x, (false, s'))
            emitSep   = (a, (true, s))

-- | Create a new stream by alternating between the elements of the two
--   input streams, starting with the first one.
interleave :: Syntactic a => Stream a -> Stream a -> Stream a
interleave (Stream stepL sL) (Stream stepR sR)
    = Stream step (true, sL, sR)
  where
    -- The flag selects which stream supplies the next element and flips
    -- after every element; only the stream that was read is advanced.
    step (leftTurn, s1, s2) = leftTurn ? (fromLeft, fromRight)
      where fromLeft  = let (x, s1') = stepL s1 in (x, (false, s1', s2))
            fromRight = let (y, s2') = stepR s2 in (y, (true, s1, s2'))

-- | 'downsample n str' keeps every 'n'th element of the input stream: the
--   'n'th, the '2n'th, and so on (counting from 1). Each output step reads
--   'n' input elements and keeps only the last of them. 'n' is expected to
--   be at least 1.
downsample :: Syntactic a => Data Index -> Stream a -> Stream a
downsample n (Stream step s0) = Stream step' s0
  where
    step' s = forLoop (n - 1) (step s) skip
    -- Discard the element read so far and read the next one.
    skip _ (_, s) = step s

-- | 'duplicate n str' stretches the stream by emitting every element of
--   'str' exactly 'n' times in a row. 'n' is expected to be at least 1.
duplicate :: Syntactic a => Data Index -> Stream a -> Stream a
duplicate n (Stream next init) = Stream newNext (next init,0)
  where -- The state holds the current element together with the input
        -- state following it ('p'), and the number of times that element
        -- has already been emitted ('i'). The current element is always
        -- emitted. Once it has been emitted 'n' times, the next input
        -- element is fetched and the count restarts at 0.
        newNext (p@(a,st),i) = (a, (i+1 == n) ? ((next st,0), (p,i+1)))

-- | 'scan f a str' produces a stream by successively applying 'f' to
--   each element of the input stream 'str' and the previous element of
--   the output stream. The first output element is 'a' itself.
scan :: Syntactic a => (a -> b -> a) -> a -> Stream b -> Stream a
scan f z (Stream step s0) = Stream step' (z, s0)
  where
    -- Emit the accumulator, then fold in the next input element.
    step' (acc, s) = (acc, (f acc x, s'))
      where (x, s') = step s

{- This function is problematic to define for the same reason the index
   function is problematic, and it also has the same quirk as correctScan.
-}

-- | A scan but without an initial element: the first element of the input
--   stream is emitted as is, and every following output is 'f' applied to
--   the previous output and the next input element.
scan1 :: Syntactic a => (a -> a -> a) -> Stream a -> Stream a
scan1 f (Stream next init) = Stream newNext (next init)
  where -- The state pairs the element to emit next with the input state
        -- following the elements consumed so far. Seeding it with the first
        -- input element removes the need for an "is first" flag and the
        -- conditional it required on every step.
        newNext (acc,st) = let (b,st') = next st
                           in (acc, (f acc b, st'))

-- | Maps a function over a stream while threading an accumulator.
--   For every input element, 'f' receives the current accumulator and the
--   element, and returns the new accumulator and the output element.
mapAccum :: (Syntactic acc, Syntactic b) =>
            (acc -> a -> (acc,b)) -> acc -> Stream a -> Stream b
mapAccum f acc0 (Stream step s0) = Stream step' (s0, acc0)
  where
    step' (s, acc) = (y, (s', acc'))
      where (x, s')   = step s
            (acc', y) = f acc x

-- | Iteratively applies a function to a starting element; the starting
--   element and all successive results form the stream.
--
-- @iterate f a == [a, f a, f (f a), f (f (f a)) ...]@
iterate :: Syntactic a => (a -> a) -> a -> Stream a
iterate f x0 = Stream (\x -> (x, f x)) x0

-- | Repeat an element indefinitely. The stream carries a unit state
--   because there is nothing to remember between steps.
--
-- @repeat a = [a, a, a, ...]@
repeat :: Syntactic a => a -> Stream a
repeat x = Stream (const (x, u0)) u0
  where u0 = value ()

-- | @unfold f acc@ creates a new stream by successively applying 'f' to
--   the accumulator 'acc'. Each application yields an output element and
--   the accumulator for the next step.
unfold :: (Syntactic a, Syntactic c) => (c -> (a,c)) -> c -> Stream a
unfold = Stream

-- | Drop a number of elements from the front of a stream by advancing the
--   initial state that many steps with a loop.
drop :: Data Length -> Stream a -> Stream a
drop i (Stream step s0) = Stream step (forLoop i s0 (\_ s -> snd (step s)))

-- | Pairs together two streams into one, element by element. Both input
--   streams advance in lock step.
zip :: Stream a -> Stream b -> Stream (a,b)
zip (Stream stepA sA) (Stream stepB sB) = Stream step (sA, sB)
  where
    step (s1, s2) =
      let (x, s1') = stepA s1
          (y, s2') = stepB s2
      in ((x, y), (s1', s2'))

-- | Combines two streams element by element using 'f'. Both input streams
--   advance in lock step.
zipWith :: Syntactic c => (a -> b -> c) -> Stream a -> Stream b -> Stream c
zipWith f (Stream stepA sA) (Stream stepB sB) = Stream step (sA, sB)
  where
    step (s1, s2) =
      let (x, s1') = stepA s1
          (y, s2') = stepB s2
      in (f x y, (s1', s2'))

-- | Given a stream of pairs, split it into two streams: one of the first
--   components and one of the second components.
unzip :: (Syntactic a, Syntactic b) => Stream (a,b) -> (Stream a, Stream b)
unzip str = (firsts, seconds)
  where firsts  = map fst str
        seconds = map snd str

instance Syntactic a => RandomAccess (Stream a) where
  type Element (Stream a) = a
  -- 0-based indexing: take the first step, then 'n' more, and return the
  -- element produced by the last step.
  (Stream step s0) ! n = fst (forLoop n (step s0) advance)
    where advance _ (_, s) = step s

-- | 'take n str' allocates the first 'n' elements of the stream 'str' into
--   a core array, built with 'sequential' from the stream's step function.
take :: (Type a) => Data Length -> Stream (Data a) -> Data [a]
take n (Stream step s0) = sequential n s0 (\_ s -> step s) (\_ -> value [])

-- | 'splitAt n str' allocates the first 'n' elements of the stream 'str'
--   into a core array and returns the rest of the stream, continuing from
--   element 'n+1'.
splitAt :: (Type a) =>
           Data Length -> Stream (Data a) -> (Data [a], Stream (Data a))
splitAt n str = (prefix, rest)
  where prefix = take n str
        rest   = drop n str

-- | Loops through a vector indefinitely to produce a stream. The position
--   wraps back to 0 after the last element.
cycle :: Syntactic a => Vector a -> Stream a
cycle vec = Stream step 0
  where
    len    = length vec
    step i = (vec ! i, (i + 1) `rem` len)

-- | Reads the elements of a vector in order as a stream. Unlike 'cycle',
--   the position is never wrapped, so this is only safe when no more than
--   @length vec@ elements are consumed; further steps index past the end
--   of the vector.
unsafeVectorToStream :: Syntactic a => Vector a -> Stream a
unsafeVectorToStream vec = Stream (\i -> (vec ! i, i + 1)) 0

-- | A convenience function for translating an algorithm on streams to an
--   algorithm on vectors. The result vector has the same length as the
--   input vector. The stream function must not drop any elements of the
--   input stream, because the input is read without wrap-around.
--
--   This function allocates memory for the output vector.
streamAsVector :: (Type a, Type b) =>
                  (Stream (Data a) -> Stream (Data b))
               -> (Vector (Data a) -> Vector (Data b))
streamAsVector f v = unfreezeVector (take outLen outStream)
  where outLen    = length v
        outStream = f (unsafeVectorToStream v)

-- | Similar to 'streamAsVector', except that the size of the output array
--   is computed by the second argument, which is given the size of the
--   input vector as an argument. The input vector is read cyclically (via
--   'cycle'), so the output may be longer than the input.
streamAsVectorSize :: (Type a, Type b) =>
                      (Stream (Data a) -> Stream (Data b)) -> (Data Length -> Data Length)
                   -> (Vector (Data a) -> Vector (Data b))
streamAsVectorSize f s v = unfreezeVector (take outLen (f (cycle v)))
  where outLen = s (length v)

-- | A combinator for describing recurrence equations, or feedback loops.
--   The recurrence equation may refer to previous outputs of the stream,
--   but only as many as the length of the initial vector.
--   It uses memory proportional to the initial vector.
--
--   The stream first emits the elements of the initial vector, followed by
--   the computed elements. Let @len@ be the length of the initial vector.
--   When the equation is evaluated, element @i@ of its argument is the
--   output @len - i@ steps before the element being defined. The last
--   element is therefore the most recent output and element 0 the oldest.
--
-- For example, one can define the Fibonacci sequence as follows:
--
-- > fib = recurrenceO (vector [0,1]) (\fib -> fib!0 + fib!1)
--
-- The expressions @fib!1@ and @fib!0@ refer to the elements of the stream
-- one step back and two steps back respectively.
recurrenceO :: Type a =>
               DVector a ->
               (DVector a -> Data a) ->
               Stream (Data a)
recurrenceO init mkExpr = Stream next (buf,0)
  where buf            = freezeVector init
        len            = getLength buf
        -- The position counter 'ix' is kept in [0, len) so that it cannot
        -- overflow however long the stream is consumed. It indexes the slot
        -- holding the element emitted at this step, which is also the slot
        -- that receives the newly computed element.
        next (buf,ix)  =
            let a = mkExpr (indexed len (\i -> getIx buf ((i + ix) `rem` len)))
            in (getIx buf ix, (setIx buf ix a, (ix + 1) `rem` len))


-- | A recurrence combinator with input. The function 'recurrenceI' is
--   similar to 'recurrenceO'. The difference is that it has an input
--   stream, and that the recurrence equation may only refer to previous
--   inputs; it may not refer to previous outputs.
--
-- The sliding average of a stream can easily be implemented using
-- 'recurrenceI'.
--
-- > slidingAvg :: Data DefaultWord -> Stream (Data DefaultWord) -> Stream (Data DefaultWord)
-- > slidingAvg n str = recurrenceI (replicate n 0) str
-- >                    (\input -> sum input `quot` n)
recurrenceI :: (Type a, Type b) =>
               DVector a -> Stream (Data a) ->
               (DVector a -> Data b) ->
               Stream (Data b)
-- Delegates to 'recurrenceIO' with an empty output buffer; the equation
-- ignores the (empty) output window.
recurrenceI ii stream mkExpr
    = recurrenceIO ii stream (vector []) (const . mkExpr)

-- | 'recurrenceIO' is a combination of 'recurrenceO' and 'recurrenceI'. It
--   has an input stream and the recurrence equation may refer both to
--   previous inputs and outputs.
--
--   The equation receives two windows:
--
--   * The input window has the length of @ii@. Element 0 is the current
--     input element, and elements 1 onwards are earlier inputs from oldest
--     to newest (initially taken from @ii@).
--
--   * The output window has the length of @io@. Element 0 is the most
--     recently computed value, and elements 1 onwards are earlier values
--     from oldest to newest (initially taken from @io@).
--
--   When @io@ is non-empty, the stream first emits the elements of @io@,
--   and each computed value is emitted @length io@ steps after it is
--   computed. When @io@ is empty, each computed value is emitted directly.
--
--   'recurrenceIO' is used when defining the 'iir' filter.
recurrenceIO :: (Type a, Type b) =>
                DVector a -> Stream (Data a) -> DVector b ->
                (DVector a -> DVector b -> Data b) ->
                Stream (Data b)
recurrenceIO ii (Stream st s) io mkExpr
    = Stream step (ibuf,obuf,s,0)
  where ibuf = freezeVector ii
        obuf = freezeVector io
        p    = getLength ibuf
        q    = getLength obuf
        step (ibuf,obuf,s,ix) =
            let (a,s') = st s
                ibuf'  = p /= 0 ? (setIx ibuf (ix `rem` p) a, ibuf)
                -- 'q' is added before subtracting 1 so the operand of 'rem'
                -- stays non-negative at ix = 0, i = 0. Otherwise an unsigned
                -- index would wrap around and pick the wrong slot (unless
                -- 'q' divides the word size), and a signed one would give -1.
                b = mkExpr
                    (indexed p (\i -> getIx ibuf' ((i + ix)         `rem` p)))
                    (indexed q (\i -> getIx obuf  ((i + ix + q - 1) `rem` q)))
            in (q /= 0 ? (getIx obuf (ix `rem` q),b),
                          (ibuf'
                          ,q /= 0 ? (setIx obuf (ix `rem` q) b,obuf)
                          ,s'
                          ,ix + 1))

-- | A recurrence combinator with two input streams.
--
--   The equation receives a window over each input stream (with the
--   lengths of @i1@ and @i2@) and a window over previously computed values
--   (with the length of @io@):
--
--   * In each input window, element 0 is the current input element and
--     elements 1 onwards are earlier inputs from oldest to newest.
--
--   * In the output window, element 0 is the most recently computed value
--     and elements 1 onwards are earlier values from oldest to newest.
--
--   All windows start out with the contents of the corresponding vector
--   argument. When @io@ is non-empty, the stream first emits the elements
--   of @io@; otherwise each computed value is emitted directly.
recurrenceIIO :: (Type a, Type b, Type c) =>
                 DVector a -> Stream (Data a) -> DVector b -> Stream (Data b) ->
                 DVector c ->
                 (DVector a -> DVector b -> DVector c -> Data c) ->
                 Stream (Data c)
recurrenceIIO i1 (Stream next1 init1) i2 (Stream next2 init2) io mkExpr
    = Stream next ((ibuf1,init1),(ibuf2,init2),obuf,0)
  where ibuf1 = freezeVector i1
        ibuf2 = freezeVector i2
        obuf  = freezeVector io
        l1    = getLength ibuf1
        l2    = getLength ibuf2
        lo    = getLength obuf
        next ((ibuf1,st1),(ibuf2,st2),obuf,ix) =
            let (a,st1') = next1 st1
                (b,st2') = next2 st2
                ibuf1'  = l1 /= 0 ? (setIx ibuf1 (ix `rem` l1) a, ibuf1)
                ibuf2'  = l2 /= 0 ? (setIx ibuf2 (ix `rem` l2) b, ibuf2)
                -- 'lo' is added before subtracting 1 so the operand of 'rem'
                -- stays non-negative at ix = 0, i = 0. Otherwise an unsigned
                -- index would wrap around and pick the wrong slot (unless
                -- 'lo' divides the word size), and a signed one would give -1.
                c = mkExpr (indexed l1 (\i -> getIx ibuf1' ((i + ix)          `rem` l1)))
                           (indexed l2 (\i -> getIx ibuf2' ((i + ix)          `rem` l2)))
                           (indexed lo (\i -> getIx obuf   ((i + ix + lo - 1) `rem` lo)))
            in (lo /= 0 ? (getIx obuf (ix `rem` lo),c),
                          ((ibuf1',st1')
                          ,(ibuf2',st2')
                          ,lo /= 0 ? (setIx obuf (ix `rem` lo) c,obuf)
                          ,ix + 1))

-- | Sliding average over the last 'n' input elements, using integer
--   division. The window starts out filled with zeros.
slidingAvg :: Data DefaultWord -> Stream (Data DefaultWord) -> Stream (Data DefaultWord)
slidingAvg n str = recurrenceI (replicate n 0) str average
  where average xs = sum xs `quot` n

-- | A fir filter on streams. Each output element is the scalar product of
--   the coefficient vector 'b' with a window of the input that holds as
--   many elements as 'b' has coefficients; the window starts out filled
--   with zeros.
fir :: DVector Float ->
       Stream (Data Float) -> Stream (Data Float)
fir b input = recurrenceI zeros input (scalarProd b)
  where zeros = replicate (length b) 0

-- | An iir filter on streams.
--
--   Each computed value is @1 / a0@ times the difference of two scalar
--   products: 'b' with a window of inputs, and 'a' with a window of
--   previously computed values. The windows have the lengths of 'b' and
--   'a' respectively and start out filled with zeros. When 'a' is
--   non-empty, the output stream begins with @length a@ zeros.
iir :: Data Float -> DVector Float -> DVector Float ->
       Stream (Data Float) -> Stream (Data Float)
iir a0 a b input = recurrenceIO inWindow input outWindow equation
  where
    inWindow  = replicate (length b) 0
    outWindow = replicate (length a) 0
    equation xs ys = 1 / a0 * (scalarProd b xs - scalarProd a ys)