    ,ScopedTypeVariables #-}

module Data.MutableIter (
  MIteratee (..)
  ,IOBuffer (..)


import Prelude hiding (head, null, drop, dropWhile, catch)
import qualified Prelude as P

import qualified Data.MutableIter.IOBuffer as IB
import Data.MutableIter.IOBuffer (IOBuffer, null, empty, pop)

import qualified Data.Iteratee as I
import Data.Maybe
import qualified Data.Vector.Unboxed as U

import Control.Exception (SomeException, IOException, toException)
import Control.Monad
import Control.Monad.CatchIO as CIO
import Control.Monad.IO.Class
import Control.Monad.Trans.Class
import Foreign.ForeignPtr
import Foreign.Storable (Storable, sizeOf, poke)
import Foreign.Marshal.Array

import System.IO

-- | True when the stream element is a chunk whose buffer holds no data.
-- End-of-stream is never considered a null chunk.
isNullChunk :: (MonadIO s) => I.Stream (IOBuffer r el) -> s Bool
isNullChunk stream = case stream of
  I.Chunk buf -> liftIO (null buf)
  I.EOF _     -> return False

-- | A thin wrapper around an iteratee-package 'I.Iteratee', so that class
-- instances can be specialised to streams of mutable 'IOBuffer' chunks.
newtype MIteratee s m a = MIteratee
  { unwrap :: I.Iteratee s m a
  } deriving (Functor)

-- | Monad instance for iteratees over 'IOBuffer' streams.
--
-- 'return' finishes immediately and reports an empty chunk as leftover.
-- @m >>= f@ runs @m@. When @m@ finishes with result @a@ and leftover @str@,
-- it runs @f a@:
--
--   * if @str@ is a null chunk (see 'isNullChunk'), @f a@ runs with the
--     caller's continuations unchanged;
--   * otherwise (a non-empty chunk or EOF), @str@ is fed to @f a@ the first
--     time it asks for input. If @f a@ finishes without asking, @str@ is
--     reported as the leftover in place of @f a@'s own.
--
-- NOTE(review): GHC >= 7.10 requires an 'Applicative' superclass instance;
-- none is visible in this chunk of the file -- confirm it exists elsewhere.
instance (MonadIO m, Storable el) => Monad (MIteratee (IOBuffer r el) m) where
  {-# INLINE return #-}
  return x = MIteratee (I.Iteratee $ \onDone _ -> onDone x (I.Chunk empty))
  -- {-# INLINE (>>=) #-} -- this inline makes things a bit slower with 6.12.  Seems fixed in ghc-7
  m >>= f = MIteratee (I.Iteratee $ \onDone onCont ->
    -- Continuation for when @m@ finishes with result @a@ and leftover @str@.
    let mDone a str = do
          isNull <- isNullChunk str
          if isNull
            then I.runIter (unwrap $ f a) onDone onCont
            -- @f a@ has not seen @str@: if it finishes at once, report @str@
            -- as the leftover; if it asks for input, give it @str@ (fCont).
            else I.runIter (unwrap $ f a) (const . flip onDone str) (fCont str)
        -- A plain request for input is answered with the pending leftover.
        fCont str k Nothing = I.runIter (k str) onDone onCont
        -- An error from @f a@ is passed on to the caller unchanged.
        fCont _   k e       = onCont k e
    -- If @m@ itself needs more input, suspend and rebuild the bind around
    -- each future continuation of @m@.
    in I.runIter (unwrap m) mDone (\k -> onCont (unwrap . (>>= f) . MIteratee . k)))

-- | Lifting runs the base action and finishes with its result, consuming
-- nothing and reporting an empty chunk as leftover.
instance (Storable el) => MonadTrans (MIteratee (IOBuffer s el)) where
  lift action = MIteratee (I.Iteratee (\onDone _ ->
    action >>= \x -> onDone x (I.Chunk empty)))

-- | IO actions are lifted into the base monad first, then into the iteratee.
instance (MonadIO m, Storable el)
  => MonadIO (MIteratee (IOBuffer r el) m) where
    liftIO act = lift (liftIO act)

-- | Exception handling for iteratees over 'IOBuffer' streams.
--
-- 'catch' wraps the base-monad execution of the iteratee. If it throws, the
-- handler's iteratee runs with the same continuations. 'block' and 'unblock'
-- are applied through 'mapIteratee'.
instance (MonadCatchIO m, Storable el) =>
  MonadCatchIO (MIteratee (IOBuffer r el) m) where
    catch action handler = MIteratee (I.Iteratee (\od oc ->
      let runWith it = I.runIter (unwrap it) od oc
      in  runWith action `catch` (runWith . handler)))
    block   = mapIteratee block
    unblock = mapIteratee unblock

-- | Transform the base-monad computation behind an iteratee.
--
-- The wrapped iteratee is run to completion with 'I.run' in the base monad,
-- @f@ is applied to that action, and the result is lifted back. The returned
-- iteratee therefore takes no input from the surrounding stream.
mapIteratee :: (Monad m, Storable el) =>
  (m a -> m b)
  -> MIteratee (IOBuffer r el) m a
  -> MIteratee (IOBuffer r el) m b
mapIteratee f iter = lift (f (I.run (unwrap iter)))

-- | Flatten an outer iteratee that produces an inner iteratee (over any
-- stream type @s@).
--
-- After the outer iteratee finishes, the inner one is sent end-of-stream.
-- Its result becomes the overall result, with an empty chunk as leftover.
-- If the inner iteratee still asks for input after EOF, 'excDiv' is thrown.
-- An error reported by the inner iteratee is rethrown in the outer context.
joinIob :: (MonadCatchIO m, Storable el) =>
  MIteratee (IOBuffer r el) m (MIteratee s m a)
  -> MIteratee (IOBuffer r el) m a
joinIob outer = outer >>= \inner -> MIteratee $ I.Iteratee $ \od oc ->
  -- The inner iteratee's leftover is of type @s@ and cannot be returned to
  -- the outer stream, so report an empty chunk instead.
  let onDone  x _        = od x (I.Chunk empty)
      -- First request for input: answer with EOF.
      onCont  k Nothing  = I.runIter (k (I.EOF Nothing)) onDone onCont'
      onCont  _ (Just e) = I.runIter (I.throwErr e) od oc
      -- Still continuing after EOF: rethrow its error, or 'excDiv' if none.
      onCont' _ e        = I.runIter (I.throwErr (fromMaybe excDiv e)) od oc
  in I.runIter (unwrap inner) onDone onCont

-- | Turn a base-monad action that yields an iteratee into an iteratee.
-- The action runs when the resulting iteratee is first run.
joinIM :: (Monad m) =>
  m (MIteratee (IOBuffer r el) m a)
  -> MIteratee (IOBuffer r el) m a
joinIM mIter = MIteratee (I.Iteratee (\onDone onCont -> do
  inner <- mIter
  I.runIter (unwrap inner) onDone onCont))

-- | The exception used when an iteratee keeps asking for input after
-- end-of-stream (see 'joinIob').
excDiv :: SomeException
excDiv = toException divergent
  where divergent = I.DivergentException

-- | An enumerator for 'MIteratee's: feeds input to an iteratee inside the
-- base monad and returns the resulting (possibly unfinished) iteratee.
type MEnumerator s m a = MIteratee s m a -> m (MIteratee s m a)

-- | Adapt an iteratee-package enumerator so it works on 'MIteratee's.
wrapEnum :: (Monad m) => I.Enumerator s m a -> MEnumerator s m a
wrapEnum enum iter = enum (unwrap iter) >>= return . MIteratee

-- | Build an iteratee that waits for the next stream element and
-- continues with the given step function.
liftI :: (Monad m) => (I.Stream s -> MIteratee s m a) -> MIteratee s m a
liftI k = MIteratee (I.liftI (\str -> unwrap (k str)))

-- | An iteratee that has already finished with result @x@, leaving @str@ as
-- the unconsumed remainder of the stream.
idone :: (Monad m) => a -> I.Stream s -> MIteratee s m a
idone x str = MIteratee $ I.idone x str

-- | An iteratee that needs more input. It continues with @k@ and optionally
-- carries an error (@mErr@) for the enumerator to handle.
icont :: (Monad m) =>
  (I.Stream s -> MIteratee s m a)
  -> Maybe SomeException
  -> MIteratee s m a
icont k mErr = MIteratee $ I.icont (unwrap . k) mErr

-- | Choose a continuation based on whether @buf@ currently holds data.
-- Runs @onEmpty@ for an empty buffer and @onFull@ otherwise.
guardNull :: (MonadCatchIO m, Storable el) =>
  IOBuffer r el
  -> MIteratee (IOBuffer r el) m a
  -> MIteratee (IOBuffer r el) m a
  -> MIteratee (IOBuffer r el) m a
guardNull buf onEmpty onFull =
  liftIO (null buf) >>= \isEmpty -> if isEmpty then onEmpty else onFull
{-# INLINE guardNull #-}

-- | Check, without consuming anything, whether the stream has ended.
--
-- Empty chunks are skipped. A non-empty chunk yields 'Nothing' and is left
-- in place. EOF yields its attached exception, or 'I.EofException' if
-- there is none.
isStreamFinished ::
 (MonadCatchIO m, Storable el) =>
  MIteratee (IOBuffer r el) m (Maybe SomeException)
isStreamFinished = liftI step
  where
    step s@(I.Chunk buf) = guardNull buf (liftI step) (idone Nothing s)
    step s@(I.EOF e) =
      idone (Just $ fromMaybe (toException I.EofException) e) s
{-# INLINE isStreamFinished #-}

-- | Remove and return the first element of the stream (via 'pop').
-- Empty chunks are skipped. Throws 'I.EofException' at end of stream.
head :: (MonadCatchIO m, Storable el) => MIteratee (IOBuffer r el) m el
head = liftI step
  where
    step (I.Chunk buf) = guardNull buf (liftI step) $ do
      x <- liftIO $ pop buf
      -- The buffer was shortened in place; hand the rest back as leftover.
      idone x (I.Chunk buf)
    step _ = MIteratee . I.throwErr . toException $ I.EofException
{-# INLINE head #-}

-- | Look at the first element of the stream without consuming it.
-- Empty chunks are skipped. Returns 'Nothing' at end of stream.
peek :: (MonadCatchIO m, Storable el) => MIteratee (IOBuffer r el) m (Maybe el)
peek = liftI step
  where
    step c@(I.Chunk buf) = guardNull buf (liftI step) $ do
      x <- liftIO $ IB.lookAtHead buf
      idone x c
    step str = idone Nothing str
{-# INLINE peek #-}

-- | Match the given elements, in order, against the front of the stream.
--
-- Each matching element is removed from the stream. Returns how many
-- elements matched before the first mismatch, the end of the list, or the
-- end of the stream.
heads :: (MonadCatchIO m, Storable el, Eq el) =>
  [el]
  -> MIteratee (IOBuffer r el) m Int
heads st = loop 0 st
  where
    loop cnt [] = return cnt
    loop cnt str = liftI (step cnt str)
    step cnt s@(x:xs) (I.Chunk buf) = do
      -- NOTE(review): assumes lookAtHead yields Nothing for an empty
      -- buffer, in which case we wait for the next chunk.
      mx <- liftIO $ IB.lookAtHead buf
      maybe (liftI (step cnt s)) (\h -> if h == x
        then liftIO (pop buf) >> step (succ cnt) xs (I.Chunk buf)
        else idone cnt (I.Chunk buf)) mx
    step cnt _ str = idone cnt str
{-# INLINE heads #-}

-- | Take the whole next non-empty buffer from the stream.
-- An empty chunk is reported as leftover. Throws 'I.EofException' at end of
-- stream.
chunk ::
  (MonadCatchIO m, Storable el)
  => MIteratee (IOBuffer r el) m (IOBuffer r el)
chunk = liftI step
  where
    step (I.Chunk buf) = guardNull buf (liftI step) $ idone buf (I.Chunk empty)
    step _ = MIteratee . I.throwErr . toException $ I.EofException

-- | Skip @n@ elements of the stream, across chunk boundaries if needed.
-- Buffers are shortened in place. Stops early at end of stream. A
-- non-positive count skips nothing.
drop :: (MonadCatchIO m, Storable el) => Int -> MIteratee (IOBuffer r el) m ()
drop n = liftI (step n)
  where
    -- Nothing (left) to skip: finish, handing the stream back untouched.
    step n' str | n' <= 0 = idone () str
    step n' (I.Chunk buf) = guardNull buf (liftI (step n')) $ do
      l <- liftIO $ IB.length buf
      -- When the remaining count fits in this buffer (including exactly
      -- filling it), finish now instead of waiting for another chunk.
      if n' <= l
        then liftIO (IB.drop n' buf) >> idone () (I.Chunk buf)
        else liftIO (IB.drop l buf) >> liftI (step (n' - l))
    step _ str = idone () str
{-# INLINE drop #-}

-- | Drop elements from the stream while they satisfy @p@.
-- Continues into later chunks while the current buffer is exhausted, and
-- stops at the first element that fails @p@ or at end of stream.
dropWhile :: (MonadCatchIO m, Storable el) =>
  (el -> Bool)
  -> MIteratee (IOBuffer r el) m ()
dropWhile p = liftI step
  where
    step (I.Chunk buf) = guardNull buf (liftI step) $ do
      liftIO $ IB.dropWhile p buf
      l <- liftIO $ IB.length buf
      -- Whole buffer matched: keep going with the next chunk.
      if l == 0 then liftI step else idone () (I.Chunk buf)
    step str = idone () str
{-# INLINE dropWhile #-}

-- | Strict left fold over every element of the stream, until end of stream.
--
-- NOTE(review): the 'Show' constraint is not used here. It is kept in case
-- 'IB.foldl'' requires it; remove it if not.
foldl' :: (MonadCatchIO m, Storable el, Show a) =>
  (a -> el -> a)
  -> a
  -> MIteratee (IOBuffer r el) m a
foldl' f acc' = liftI (step acc')
  where
    step acc (I.Chunk buf) = guardNull buf (liftI (step acc)) $ do
      newacc <- liftIO $ IB.foldl' f acc buf
      liftI (step newacc)
    step acc str = idone acc str
{-# INLINE foldl' #-}

-- | Strict left fold over every @hop@'th element of the stream (elements
-- 0, hop, 2*hop, ... of the whole stream), continuing across chunks.
--
-- Assumes 'IB.hopfoldl'' folds buffer indices 0, hop, 2*hop, ... (which is
-- what the offset bookkeeping below relies on) -- TODO confirm. @hop@ must
-- be positive.
--
-- NOTE(review): the 'Show' constraint is not used here. It is kept in case
-- 'IB.hopfoldl'' requires it.
hopfoldl' :: (MonadCatchIO m, Storable el, Show a) =>
  Int
  -> (a -> el -> a)
  -> a
  -> MIteratee (IOBuffer r el) m a
hopfoldl' hop f acc' = liftI (step 0 acc')
  where
    -- @ihop@ is how many elements at the start of the next chunk must be
    -- skipped before the next folded element.
    step ihop acc (I.Chunk buf) = guardNull buf (liftI (step ihop acc)) $ do
      buflen0 <- liftIO $ IB.length buf
      if ihop >= buflen0
        then do
          -- The whole chunk falls inside the skip; carry the rest forward.
          liftIO $ IB.drop buflen0 buf
          liftI (step (ihop - buflen0) acc)
        else do
          (newacc, numExtra) <- liftIO $ do
            IB.drop ihop buf
            buflen <- IB.length buf
            let numExtra = buflen `rem` hop
            res <- IB.hopfoldl' hop f acc buf
            IB.drop numExtra buf
            return (res, numExtra)
          -- The next folded element is (hop - numExtra) elements into the
          -- following chunk, or right at its start when numExtra == 0.
          liftI (step ((hop - numExtra) `rem` hop) newacc)
    step _ acc str = idone acc str
{-# INLINE hopfoldl' #-}

-- -----------------------------------------------------------
-- Enumeratees

-- | An enumeratee for 'MIteratee's: turns an inner iteratee over stream
-- @sTo@ into an outer iteratee over stream @sFrom@ that returns the inner
-- iteratee when it stops.
type MEnumeratee sFrom sTo m a = MIteratee sTo m a -> MIteratee sFrom m (MIteratee sTo m a)

-- | Enumeratee helper: inspect the inner iteratee before feeding it.
--
-- If it has finished, return it wrapped with 'idone', with an empty outer
-- leftover. If it needs input, continue with @f@ applied to its
-- continuation. If it carries an error, rethrow that error in the outer
-- iteratee.
eneeCheckIfDone ::
  (Monad m, Storable elo) =>
  ((I.Stream eli -> MIteratee eli m a)
    -> MIteratee (IOBuffer r elo) m (MIteratee eli m a))
  -> MEnumeratee (IOBuffer r elo) eli m a
eneeCheckIfDone f inner = MIteratee (I.Iteratee (\od oc ->
  let finished x rest       = od (idone x rest) (I.Chunk empty)
      needMore k Nothing    = I.runIter (unwrap (f (MIteratee . k))) od oc
      needMore _ (Just err) = I.runIter (I.throwErr err) od oc
  in I.runIter (unwrap inner) finished needMore))
{-# INLINE eneeCheckIfDone #-}

-- | Map @f@ over every element, converting an @eli@ stream into an @elo@
-- stream.
--
-- One output buffer of @maxlen@ elements (plus an offset cell) is allocated
-- up front and reused for every chunk, via 'IB.mapBuffer'.
-- NOTE(review): presumably input chunks must not exceed @maxlen@ elements;
-- confirm how 'IB.mapBuffer' handles larger chunks.
mapStream :: (MonadCatchIO pr, Storable elo, Storable eli) =>
  Int
  -> (eli -> elo)
  -> MEnumeratee (IOBuffer r eli) (IOBuffer r elo) pr a
mapStream maxlen f i = do
  offp <- liftIO $ newFp 0
  bufp <- liftIO $ mallocForeignPtrArray maxlen
  goMap offp bufp i
  where
    goMap offp bufp = eneeCheckIfDone (liftI . step offp bufp)
    step offp bufp k (I.Chunk buf) =
      guardNull buf (liftI (step offp bufp k)) $ do
        newIOBuf <- liftIO $ IB.mapBuffer f offp bufp buf
        goMap offp bufp $ k (I.Chunk newIOBuf)
    -- End of stream: return the inner iteratee, leaving EOF for the outer one.
    step _    _    k s             = idone (liftI k) s

-- | Convert each chunk with a base-monad action @f@ and feed the result to
-- the inner iteratee. At end of stream the inner iteratee is returned and
-- the EOF is left for the outer stream.
mapChunk :: (Storable el, MonadCatchIO m) =>
  (IOBuffer r el -> m s2)
   -> MEnumeratee (IOBuffer r el) s2 m a
mapChunk f i = go i
  where
    go = eneeCheckIfDone (liftI . step)
    step k (I.Chunk buf) = lift (f buf) >>= go . k . I.Chunk
    step k s             = idone (liftI k) s
{-# INLINE mapChunk #-}

-- | Stateful map: thread an accumulator @b@ through the elements while
-- producing one @elo@ per @eli@, via 'IB.mapAccumBuffer'.
--
-- One output buffer of @maxlen@ elements is allocated up front and reused
-- for every chunk.
-- NOTE(review): presumably input chunks must not exceed @maxlen@; confirm.
mapAccum :: (MonadCatchIO pr, Storable eli, Storable elo) =>
  Int
  -> (b -> eli -> (b,elo))
  -> b
  -> MEnumeratee (IOBuffer r eli) (IOBuffer r elo) pr a
mapAccum maxlen f acc i = do
  offp <- liftIO $ newFp 0
  bufp <- liftIO $ mallocForeignPtrArray maxlen
  goMap offp bufp acc i
  where
    goMap offp bufp acc = eneeCheckIfDone (liftI . step offp bufp acc)
    step offp bufp acc k (I.Chunk buf) =
      guardNull buf (liftI (step offp bufp acc k)) $ do
        (newAcc, newBuf) <- liftIO $ IB.mapAccumBuffer f offp bufp acc buf
        goMap offp bufp newAcc $ k (I.Chunk newBuf)
    step _    _    _   k s             = idone (liftI k) s

-- | Convert the stream by repeatedly running @fi@, feeding each buffer it
-- produces to the inner iteratee. Runs until the inner iteratee finishes or
-- the outer stream ends.
convStream :: (MonadCatchIO pr, Storable elo, Storable eli) =>
  MIteratee (IOBuffer r eli) pr (IOBuffer r elo)
  -> MEnumeratee (IOBuffer r eli) (IOBuffer r elo) pr a
convStream fi = eneeCheckIfDone check
  where
    -- Forward end-of-stream (with its exception) instead of running fi.
    check k = isStreamFinished >>=
              maybe (step k) (idone (liftI k) . I.EOF . Just)
    step k = fi >>= convStream fi . k . I.Chunk

-- |The most general stream converter.  Given a function to produce iteratee
-- transformers and an initial state, convert the stream using iteratees
-- generated by the function while continually updating the internal state.
unfoldConvStream ::
  (MonadCatchIO m, Storable eli, Storable elo)
  => (acc -> MIteratee (IOBuffer r eli) m (acc, IOBuffer r elo))
  -> acc
  -> MEnumeratee (IOBuffer r eli) (IOBuffer r elo) m a
unfoldConvStream f acc0 = eneeCheckIfDone (check acc0)
  where
    -- Forward end-of-stream (with its exception) instead of running f.
    check acc k = isStreamFinished >>=
                    maybe (step acc k) (idone (liftI k) . I.EOF . Just)
    step acc k = f acc >>= \(acc', s') ->
                    eneeCheckIfDone (check acc') . k . I.Chunk $ s'

-- | Decimate a stream by taking every n'th element, starting at element "m".
--
-- Typical use is extracting channel @m@ from @n@ interleaved channels.
-- Assumes 'IB.decimate' keeps buffer indices 0, n, 2n, ... -- TODO confirm.
getChannel ::
  (MonadCatchIO m, Storable el)
  => Int
  -> Int
  -> MEnumeratee (IOBuffer r el) (IOBuffer r el) m a
getChannel 1        _   = convStream chunk
getChannel numChans chn = unfoldConvStream mkIter chn
  where
    -- State: how many leading elements to skip before the next kept element.
    mkIter drp = do
      drop drp
      buf <- chunk
      tlen <- liftIO $ IB.length buf
      newbuf <- liftIO $ IB.decimate numChans buf
      -- The last kept index is the largest multiple of numChans below tlen.
      -- The next one lies (numChans - tlen `rem` numChans) `rem` numChans
      -- elements into the following chunk. This is the same as
      -- tlen `rem` numChans only when numChans == 2.
      return ((numChans - tlen `rem` numChans) `rem` numChans, newbuf)

-- | Feed at most @i@ elements of the outer stream to the inner iteratee.
--
-- A chunk that crosses the limit is split with 'IB.splitAt'. The inner
-- iteratee, finished or not, is returned once the limit is reached, the
-- inner iteratee finishes (its leftover is dropped), or the stream ends.
-- A non-positive limit returns the inner iteratee without reading input.
takeUpTo :: (MonadCatchIO pr, Storable el) =>
  Int
  -> MEnumeratee (IOBuffer r el) (IOBuffer r el) pr a
takeUpTo i iter
  | i <= 0    = return iter
  | otherwise = MIteratee $ I.Iteratee $ \od oc ->
      I.runIter (unwrap iter) (onDone od) (onCont od oc)
  where
    onDone od x _ = od (return x) (I.Chunk empty)
    onCont od oc k Nothing  =
      I.runIter (unwrap . liftI $ step i (MIteratee . k)) od oc
    onCont od oc _ (Just e) = I.runIter (I.throwErr e) od oc
    step n k (I.Chunk buf) = guardNull buf (liftI (step n k)) $ do
      blen <- liftIO $ IB.length buf
      if blen <= n
        then takeUpTo (n - blen) $ k (I.Chunk buf)
        else do
          -- Give the first n elements to the inner iteratee; keep the rest.
          (s1, s2) <- liftIO $ IB.splitAt buf n
          idone (k (I.Chunk s1)) (I.Chunk s2)
    step _ k str = idone (k str) str

-- ---------------------------------------------
-- drivers and enumerators

-- | Convert a Vector iteratee to an MIteratee.  Slower but convenient.
--
-- Each 'IOBuffer' chunk is frozen into an unboxed 'U.Vector' ('IB.freeze')
-- before being fed to the wrapped iteratee, which is then flattened with
-- 'joinIob'.
fromUVector :: (U.Unbox el, Storable el, MonadCatchIO m) =>
  I.Iteratee (U.Vector el) m a
  -> MIteratee (IOBuffer r el) m a
fromUVector iter =
  joinIob (mapChunk (\buf -> liftIO (IB.freeze buf)) (MIteratee iter))
{-# INLINE fromUVector #-}

-- | Callback for 'I.enumFromCallbackCatch' that reads from a handle.
--
-- Each call reads up to @bsize@ bytes from @h@ into @fp@ with 'hGetBuf'. An
-- exception from the read is returned as 'Left'. A zero-byte read reports
-- "no more data" with an empty buffer. Otherwise the offset cell @offp@ is
-- reset to 0 and an 'IOBuffer' over @fp@ is built from the count read.
-- NOTE(review): the count passed to 'IB.createIOBuffer' is in bytes; confirm
-- it expects bytes rather than elements.
makeHandleCallback :: (MonadCatchIO m, Storable el) =>
  ForeignPtr Int
  -> ForeignPtr el
  -> Int
  -> Handle
  -> ()
  -> m (Either SomeException ((Bool, ()), (IOBuffer r el)))
makeHandleCallback offp fp bsize h () = liftIO $ do
  readResult <- withForeignPtr fp $ \ptr ->
    (CIO.try (hGetBuf h ptr bsize) :: IO (Either SomeException Int))
  case readResult of
    Left err -> return (Left err)
    Right 0  -> return (Right ((False, ()), empty))
    Right n  -> do
      withForeignPtr offp (\op -> poke op 0)
      return (Right ((True, ()), IB.createIOBuffer (fromIntegral n) offp fp))

 forall e r el m a.(I.IException e, MonadCatchIO m, Storable el) =>
  -> Handle
  -> (e -> m (Maybe I.EnumException))
  -> MIteratee (IOBuffer r el) m a
  -> m (MIteratee (IOBuffer r el) m a)
enumHandleCatch bs h handler i = do
  let numbytes = bs * sizeOf (undefined :: el)
  bufp <- liftIO $ mallocForeignPtrArray bs
  offp <- liftIO $ newFp 0
  liftM MIteratee $ I.enumFromCallbackCatch
                      (makeHandleCallback offp bufp numbytes h)
                      (unwrap i)

-- | Like 'enumHandleCatch', but handles 'I.SeekException' requests.
-- The handler seeks to the requested absolute offset with 'hSeek'. An
-- 'IOException' from the seek is reported as an 'I.EnumException';
-- success reports 'Nothing'.
enumHandleRandom :: forall r el m a.(MonadCatchIO m, Storable el) =>
  Int -- ^ Buffer size (number of elements per read)
  -> Handle
  -> MIteratee (IOBuffer r el) m a
  -> m (MIteratee (IOBuffer r el) m a)
enumHandleRandom bs h i = enumHandleCatch bs h handler i
  where
    handler (I.SeekException off) =
       liftM (either
              (Just . I.EnumException :: IOException -> Maybe I.EnumException)
              (const Nothing))
             . liftIO . CIO.try $ hSeek h AbsoluteSeek $ fromIntegral off

-- | Run an iteratee over a file opened in binary read mode, with seek
-- support ('enumHandleRandom'). The handle is closed by 'CIO.bracket' even
-- if an exception is raised.
fileDriverRandom :: (MonadCatchIO m, Storable el) =>
  Int -- ^ Buffer size (number of elements per read)
  -> (forall r. MIteratee (IOBuffer r el) m a)
  -> FilePath
  -> m a
fileDriverRandom bufsize iter filepath = CIO.bracket
  (liftIO $ openBinaryFile filepath ReadMode)
  (liftIO . hClose)
  ((I.run . unwrap) <=< flip (enumHandleRandom bufsize) iter)

-- | Allocate a managed 'ForeignPtr' holding a single value initialised to @a@.
newFp :: Storable a => a -> IO (ForeignPtr a)
newFp a = do
  fp <- mallocForeignPtr
  withForeignPtr fp (\ptr -> poke ptr a)
  return fp