{-# LANGUAGE CPP, ExistentialQuantification, Rank2Types #-}

-- Inlining-pragma helper macros, resolved by the C preprocessor.
-- With USE_PRAGMAS defined, INPRAG, INPRAG0 and INPRAG1 applied to a name f
-- expand to an INLINE, INLINE [0] and INLINE [1] pragma for f respectively.
-- Without it, all three expand to nothing, so no pragmas are emitted.
#ifdef USE_PRAGMAS
#define INPRAG(f)  {-# INLINE f #-}
#define INPRAG0(f) {-# INLINE [0] f #-}
#define INPRAG1(f) {-# INLINE [1] f #-}
#else
#define INPRAG(f)
#define INPRAG0(f)
#define INPRAG1(f)
#endif

module Data.Vector.Fusion.Stream.Monadic where

import Control.Monad  (liftM)
import Prelude        (Monad(..), Int, Ord(..), ($), Maybe(..))

-- | Outcome of running a stream step function once, for a stream with state
-- type @s@ and element type @a@.
data Step s a = Yield a s  -- ^ An element together with a state to continue from
              | Skip    s  -- ^ No element, only a state to continue from
              | Done       -- ^ Carries neither an element nor a state

-- | A stream of @a@ values whose steps run in the monad @m@. The state type
-- @s@ is existentially quantified, so it is not visible in the type of the
-- stream. The fields are, in order: a step function from a state to a
-- monadic Step, a state value (the zips in this file put their starting
-- state here), and a Size.
data Stream m a = forall s. Stream (s -> m (Step s a)) s Size

-- | Zip two streams with a monadic combining function.
--
-- The state of the result is a triple: the state of the first stream, the
-- state of the second stream, and an optional buffered element of the first
-- stream. An element is first obtained from the first stream; only then is
-- the second stream stepped. The combining function runs once an element of
-- each is available. The result is Done as soon as either input is Done, in
-- which case a buffered element of the first stream is dropped. The size of
-- the result is @smaller@ applied to the sizes of the inputs.
zipWithM :: Monad m => (a -> b -> m c) -> Stream m a -> Stream m b -> Stream m c
INPRAG1(zipWithM)
zipWithM f (Stream nextL seedL sizeL) (Stream nextR seedR sizeR)
  = Stream go (seedL, seedR, Nothing) (smaller sizeL sizeR)
  where
    INPRAG0(go)
    -- Nothing buffered yet: step the first stream.
    go (l, r, Nothing) = do
        res <- nextL l
        return $ case res of
            Yield x l2 -> Skip (l2, r, Just x)
            Skip    l2 -> Skip (l2, r, Nothing)
            Done       -> Done

    -- An element of the first stream is buffered: step the second stream.
    go (l, r, Just x) = nextR r >>= \res ->
        case res of
            Yield y r2 -> liftM (\z -> Yield z (l, r2, Nothing)) (f x y)
            Skip    r2 -> return (Skip (l, r2, Just x))
            Done       -> return Done

-- | Zip three streams with a monadic combining function.
--
-- The state of the result holds the three input states plus a buffer that
-- is empty, holds an element of the first stream, or holds an element of
-- the first and of the second stream. The inputs are stepped in that order,
-- and the combining function runs once an element of the third stream
-- arrives. The result is Done as soon as the input currently being stepped
-- is Done. The size of the result is @smaller@ folded over the three input
-- sizes.
zipWith3M :: Monad m => (a -> b -> c -> m d) -> Stream m a -> Stream m b -> Stream m c -> Stream m d
INPRAG1(zipWith3M)
zipWith3M f (Stream next1 seed1 size1) (Stream next2 seed2 size2) (Stream next3 seed3 size3)
  = Stream go (seed1, seed2, seed3, Nothing) (smaller size1 (smaller size2 size3))
  where
    INPRAG0(go)
    -- Empty buffer: step the first stream.
    go (s1, s2, s3, Nothing) = next1 s1 >>= \r ->
        return $ case r of
            Yield x t1 -> Skip (t1, s2, s3, Just (x, Nothing))
            Skip    t1 -> Skip (t1, s2, s3, Nothing)
            Done       -> Done

    -- First element buffered: step the second stream.
    go (s1, s2, s3, Just (x, Nothing)) = next2 s2 >>= \r ->
        return $ case r of
            Yield y t2 -> Skip (s1, t2, s3, Just (x, Just y))
            Skip    t2 -> Skip (s1, t2, s3, Just (x, Nothing))
            Done       -> Done

    -- First and second element buffered: step the third stream and combine.
    go (s1, s2, s3, Just (x, Just y)) = next3 s3 >>= \r ->
        case r of
            Yield z t3 -> do
                out <- f x y z
                return (Yield out (s1, s2, t3, Nothing))
            Skip    t3 -> return (Skip (s1, s2, t3, Just (x, Just y)))
            Done       -> return Done

-- | Zip four streams with a monadic combining function. Implemented by
-- pairing the first two and the last two streams with @zip@ and then
-- combining the two streams of pairs with @zipWithM@.
zipWith4M :: Monad m => (a -> b -> c -> d -> m e)
                     -> Stream m a -> Stream m b -> Stream m c -> Stream m d
                     -> Stream m e
INPRAG(zipWith4M)
zipWith4M f sa sb sc sd = zipWithM combine (zip sa sb) (zip sc sd)
  where
    -- Unpack both pairs and hand the four components to f.
    combine (a, b) (c, d) = f a b c d

-- | Zip five streams with a monadic combining function. Implemented by
-- grouping the first three streams with @zip3@ and the last two with @zip@,
-- then combining the two grouped streams with @zipWithM@.
zipWith5M :: Monad m => (a -> b -> c -> d -> e -> m f)
                     -> Stream m a -> Stream m b -> Stream m c -> Stream m d
                     -> Stream m e -> Stream m f
INPRAG(zipWith5M)
zipWith5M f sa sb sc sd se = zipWithM combine (zip3 sa sb sc) (zip sd se)
  where
    -- Unpack the triple and the pair and hand the five components to f.
    combine (a, b, c) (d, e) = f a b c d e

-- | Zip six streams with a monadic combining function. Implemented by
-- grouping the first three and the last three streams with @zip3@, then
-- combining the two streams of triples with @zipWithM@.
zipWith6M :: Monad m => (a -> b -> c -> d -> e -> f -> m g)
                     -> Stream m a -> Stream m b -> Stream m c -> Stream m d
                     -> Stream m e -> Stream m f -> Stream m g
INPRAG(zipWith6M)
zipWith6M fn sa sb sc sd se sf = zipWithM combine front back
  where
    front = zip3 sa sb sc
    back  = zip3 sd se sf
    -- Unpack both triples and hand the six components to fn.
    combine (a, b, c) (d, e, f) = fn a b c d e f

-- | Zip two streams with a pure combining function. The function is lifted
-- with @return@ and passed to @zipWithM@.
zipWith :: Monad m => (a -> b -> c) -> Stream m a -> Stream m b -> Stream m c
INPRAG(zipWith)
zipWith f = zipWithM lifted
  where
    lifted x y = return $ f x y

-- | Zip three streams with a pure combining function. The function is
-- lifted with @return@ and passed to @zipWith3M@.
zipWith3 :: Monad m => (a -> b -> c -> d)
                    -> Stream m a -> Stream m b -> Stream m c -> Stream m d
INPRAG(zipWith3)
zipWith3 f = zipWith3M lifted
  where
    lifted x y z = return $ f x y z

-- | Zip four streams with a pure combining function. The function is
-- lifted with @return@ and passed to @zipWith4M@.
zipWith4 :: Monad m => (a -> b -> c -> d -> e)
                    -> Stream m a -> Stream m b -> Stream m c -> Stream m d
                    -> Stream m e
INPRAG(zipWith4)
zipWith4 f = zipWith4M lifted
  where
    lifted w x y z = return $ f w x y z

-- | Zip five streams with a pure combining function. The function is
-- lifted with @return@ and passed to @zipWith5M@.
zipWith5 :: Monad m => (a -> b -> c -> d -> e -> f)
                    -> Stream m a -> Stream m b -> Stream m c -> Stream m d
                    -> Stream m e -> Stream m f
INPRAG(zipWith5)
zipWith5 f = zipWith5M lifted
  where
    lifted v w x y z = return $ f v w x y z

-- | Zip six streams with a pure combining function. The function is
-- lifted with @return@ and passed to @zipWith6M@.
zipWith6 :: Monad m => (a -> b -> c -> d -> e -> f -> g)
                    -> Stream m a -> Stream m b -> Stream m c -> Stream m d
                    -> Stream m e -> Stream m f -> Stream m g
INPRAG(zipWith6)
zipWith6 fn = zipWith6M lifted
  where
    lifted u v w x y z = return $ fn u v w x y z

-- | Pair up the elements of two streams; @zipWith@ with a pair-building
-- function.
zip :: Monad m => Stream m a -> Stream m b -> Stream m (a,b)
INPRAG(zip)
zip = zipWith (\x y -> (x, y))

-- | Group the elements of three streams into triples; @zipWith3@ with a
-- triple-building function.
zip3 :: Monad m => Stream m a -> Stream m b -> Stream m c -> Stream m (a,b,c)
INPRAG(zip3)
zip3 = zipWith3 (\x y z -> (x, y, z))

-- | Group the elements of four streams into 4-tuples; @zipWith4@ with a
-- tuple-building function.
zip4 :: Monad m => Stream m a -> Stream m b -> Stream m c -> Stream m d
                -> Stream m (a,b,c,d)
INPRAG(zip4)
zip4 = zipWith4 (\w x y z -> (w, x, y, z))

-- | Group the elements of five streams into 5-tuples; @zipWith5@ with a
-- tuple-building function.
zip5 :: Monad m => Stream m a -> Stream m b -> Stream m c -> Stream m d
                -> Stream m e -> Stream m (a,b,c,d,e)
INPRAG(zip5)
zip5 = zipWith5 (\v w x y z -> (v, w, x, y, z))

-- | Group the elements of six streams into 6-tuples; @zipWith6@ with a
-- tuple-building function.
zip6 :: Monad m => Stream m a -> Stream m b -> Stream m c -> Stream m d
                -> Stream m e -> Stream m f -> Stream m (a,b,c,d,e,f)
INPRAG(zip6)
zip6 = zipWith6 (\u v w x y z -> (u, v, w, x, y, z))

-- | Identity restricted to functions: returns the function it is given,
-- unchanged.
-- NOTE(review): the name suggests this exists to influence when GHC inlines
-- the wrapped function, but no pragma is attached to it here and nothing in
-- this part of the file calls it. Confirm intent before relying on it.
delay_inline :: (a -> b) -> a -> b
delay_inline f = f

-- | Size information attached to a stream.
-- NOTE(review): the constructor meanings below are read off the names and
-- the way @smaller@ combines them. Confirm against the intended semantics.
data Size = Exact Int  -- ^ presumably: exactly this many elements
          | Max   Int  -- ^ presumably: at most this many elements
          | Unknown    -- ^ presumably: nothing is known about the size

-- | Combine two sizes into the size used for a zip of the two streams.
--
-- Two Exact sizes give an Exact minimum. In every other case the result is
-- a Max: the minimum of the two numbers when both sizes carry one, the
-- single number when only one side carries one, and Unknown when neither
-- side does.
smaller :: Size -> Size -> Size
INPRAG(smaller)
smaller (Exact m) (Exact n) = Exact (min m n)
smaller lhs rhs = case (limit lhs, limit rhs) of
    (Just m,  Just n)  -> Max (min m n)
    (Just m,  Nothing) -> Max m
    (Nothing, Just n)  -> Max n
    (Nothing, Nothing) -> Unknown
  where
    -- The number carried by a size, if it carries one.
    limit (Exact k) = Just k
    limit (Max   k) = Just k
    limit Unknown   = Nothing

