{-# LANGUAGE 
     FlexibleInstances
    ,DeriveFunctor
    ,RankNTypes
    ,ScopedTypeVariables #-}

module Data.MutableIter (
  MIteratee (..)
  ,IOBuffer (..)
  ,IB.createIOBuffer
  ,MEnumerator
  ,MEnumeratee
  ,joinIob
  ,joinIM
  ,wrapEnum
  ,liftI
  ,idone
  ,icont
  ,guardNull
  ,isStreamFinished
  ,head
  ,heads
  ,chunk
  ,peek
  ,drop
  ,dropWhile
  ,foldl'
  ,hopfoldl'
  ,mapStream
  ,mapChunk
  ,mapAccum
  ,convStream
  ,unfoldConvStream
  ,getChannel
  ,takeUpTo
  ,fromUVector
  ,enumHandleRandom
  ,fileDriverRandom
  ,newFp
  )

where

import Prelude hiding (head, null, drop, dropWhile, catch)
import qualified Prelude as P

import qualified Data.MutableIter.IOBuffer as IB
import Data.MutableIter.IOBuffer (IOBuffer, null, empty, pop)

import qualified Data.Iteratee as I
import Data.Maybe
import qualified Data.Vector.Unboxed as U

import Control.Exception (SomeException, IOException, toException)
import Control.Monad
import Control.Monad.CatchIO as CIO
import Control.Monad.IO.Class
import Control.Monad.Trans.Class
import Foreign.ForeignPtr
import Foreign.Storable (Storable, sizeOf, poke)
import Foreign.Marshal.Array

import System.IO

-- | Report whether a stream element is a chunk whose buffer currently holds
-- no data.  An 'I.EOF' marker never counts as a null chunk.
isNullChunk :: (MonadIO s) => I.Stream (IOBuffer r el) -> s Bool
isNullChunk stream =
  case stream of
    I.Chunk buf -> liftIO (null buf)
    I.EOF _     -> return False

-- | An iteratee over stream type @s@ in monad @m@.  The newtype wraps a
-- plain "Data.Iteratee" iteratee; the 'Monad', 'MonadIO' and 'MonadCatchIO'
-- instances in this module are given for 'IOBuffer' streams.  'unwrap'
-- recovers the underlying iteratee.
newtype MIteratee s m a = MIteratee
  { unwrap :: I.Iteratee s m a
  } deriving (Functor)

-- | Sequencing for 'IOBuffer'-stream iteratees.  This follows the CPS '>>='
-- of "Data.Iteratee", except that deciding whether the first iteratee's
-- leftover chunk is empty needs the underlying monad ('isNullChunk' reads
-- the mutable buffer via 'liftIO').
instance (MonadIO m, Storable el) => Monad (MIteratee (IOBuffer r el) m) where
  {-# INLINE return #-}
  -- Finish at once with result x and an empty leftover chunk.
  return x = MIteratee (I.Iteratee $ \onDone _ -> onDone x (I.Chunk empty))
  -- {-# INLINE (>>=) #-} -- this inline makes things a bit slower with 6.12.  Seems fixed in ghc-7
  m >>= f = MIteratee (I.Iteratee $ \onDone onCont ->
    -- mDone: m finished with result a and leftover stream str.
    let mDone a str = do
          isNull <- isNullChunk str
          if isNull
            -- Nothing left over: f a continues directly with our continuations.
            then I.runIter (unwrap $ f a) onDone onCont
            -- Leftover input exists.  If f a finishes without asking for
            -- input, its own leftover is ignored and str is passed on
            -- instead; if it asks for input, fCont feeds it str.
            else I.runIter (unwrap $ f a) (const . flip onDone str) (fCont str)
        -- f a wants input and carries no error: resume it on the leftover.
        fCont str k Nothing = I.runIter (k str) onDone onCont
        -- f a carries an error: pass its continuation and error outward.
        fCont _   k e       = onCont k e
    -- m suspended: suspend as well, binding f onto m's continuation; the
    -- error argument (if any) is passed through to onCont unchanged.
    in I.runIter (unwrap m) mDone (\k -> onCont (unwrap . (>>= f) . MIteratee . k)))

-- | Lifting runs the action in the underlying monad and finishes at once with
-- its result, leaving an empty chunk as the unconsumed stream.
instance (Storable el) => MonadTrans (MIteratee (IOBuffer s el)) where
  lift action = MIteratee (I.Iteratee (\onDone _ ->
    action >>= \x -> onDone x (I.Chunk empty)))

-- | IO actions run through the underlying monad's own 'liftIO' and are then
-- brought into the iteratee with 'lift'.
instance (MonadIO m, Storable el)
  => MonadIO (MIteratee (IOBuffer r el) m) where
    liftIO io = lift (liftIO io)

-- | Exception support.
--
-- 'catch' wraps the underlying-monad action obtained by running the
-- iteratee with its continuations.  If the handler fires, the handler's
-- iteratee is run with those same continuations instead.
--
-- 'block' and 'unblock' go through 'mapIteratee'.  It drives the iteratee
-- to a result with 'I.run' inside the underlying monad and lifts that
-- result back.
instance (MonadCatchIO m, Storable el) =>
  MonadCatchIO (MIteratee (IOBuffer r el) m) where
    catch iter handler = MIteratee $ I.Iteratee $ \onDone onCont ->
      let attempt     = I.runIter (unwrap iter) onDone onCont
          recover exc = I.runIter (unwrap (handler exc)) onDone onCont
      in attempt `catch` recover
    block act   = mapIteratee block act
    unblock act = mapIteratee unblock act

-- | Transform the underlying-monad computation of an iteratee.  The iteratee
-- is first driven to a result with 'I.run'; @f@ is applied to that action,
-- and the outcome is lifted back into an 'MIteratee'.
mapIteratee :: (Monad m, Storable el) =>
  (m a -> m b)
  -> MIteratee (IOBuffer r el) m a
  -> MIteratee (IOBuffer r el) m b
mapIteratee f iter = lift (f (I.run (unwrap iter)))

-- | Flatten the result of an enumeratee.  The outer iteratee is run to obtain
-- an inner iteratee; the inner one is then sent 'I.EOF' and its result is
-- returned.  Mirrors @joinI@ from "Data.Iteratee".
joinIob :: (MonadCatchIO m, Storable el) =>
  MIteratee (IOBuffer r el) m (MIteratee s m a)
  -> MIteratee (IOBuffer r el) m a
joinIob outer = outer >>= \inner -> MIteratee $ I.Iteratee $ \od oc ->
  -- inner finished: return its result; its leftover (inner-stream) input is
  -- discarded and the outer stream is left with an empty chunk
  let onDone  x _        = od x (I.Chunk empty)
      -- inner still wants input: terminate it with EOF
      onCont  k Nothing  = I.runIter (k (I.EOF Nothing)) onDone onCont'
      -- inner carries an error: rethrow it in the outer iteratee
      onCont  _ (Just e) = I.runIter (I.throwErr e) od oc
      -- inner still not done after EOF: rethrow its error, or excDiv if none
      onCont' _ e        = I.runIter (I.throwErr (fromMaybe excDiv e)) od oc
  in I.runIter (unwrap inner) onDone onCont

-- | Collapse an underlying-monad action that yields an iteratee.  The action
-- runs when the resulting iteratee is run, and the iteratee it produces
-- then continues with the same continuations.
joinIM :: (Monad m) =>
  m (MIteratee (IOBuffer r el) m a)
  -> MIteratee (IOBuffer r el) m a
joinIM mIter = MIteratee $ I.Iteratee $ \onDone onCont -> do
  iter <- mIter
  I.runIter (unwrap iter) onDone onCont

-- | Error used by 'joinIob' when an inner iteratee is still unfinished after
-- being sent EOF and did not report an error of its own.
excDiv :: SomeException
excDiv = toException $ I.DivergentException

-- | A function that takes an iteratee and, in the underlying monad, returns
-- an updated iteratee (e.g. after feeding it input, as 'enumHandleRandom'
-- does).
type MEnumerator s m a =
  MIteratee s m a -> m (MIteratee s m a)

-- | Adapt a plain "Data.Iteratee" enumerator so it operates on 'MIteratee's:
-- unwrap the argument, run the enumerator, and re-wrap its result.
wrapEnum :: (Monad m) => I.Enumerator s m a -> MEnumerator s m a
wrapEnum enum iter = liftM MIteratee (enum (unwrap iter))

-- | 'MIteratee' version of 'I.liftI': build a continuation iteratee that
-- passes the next stream it receives to @step@.
liftI :: (Monad m) => (I.Stream s -> MIteratee s m a) -> MIteratee s m a
liftI step = MIteratee (I.liftI (\str -> unwrap (step str)))

-- | 'MIteratee' version of 'I.idone': an iteratee that has already finished
-- with result @x@, leaving @str@ as the unconsumed stream.
idone :: (Monad m) => a -> I.Stream s -> MIteratee s m a
idone x str = MIteratee $ I.idone x str

-- | 'MIteratee' version of 'I.icont': a continuation iteratee whose step
-- function @k@ receives the next stream, carrying the optional error
-- @mErr@.
icont :: (Monad m) =>
  (I.Stream s -> MIteratee s m a)
  -> Maybe SomeException
  -> MIteratee s m a
icont k mErr = MIteratee $ I.icont (unwrap . k) mErr

-- | Choose between two iteratees depending on whether @buf@ currently holds
-- data: @onEmpty@ if it is empty, @onFull@ otherwise.  The check happens
-- when the resulting iteratee runs.
guardNull :: (MonadCatchIO m, Storable el) =>
  IOBuffer r el
  -> MIteratee (IOBuffer r el) m a
  -> MIteratee (IOBuffer r el) m a
  -> MIteratee (IOBuffer r el) m a
guardNull buf onEmpty onFull =
  liftIO (null buf) >>= \isEmpty -> if isEmpty then onEmpty else onFull
{-# INLINE guardNull #-}

-- | Check for end of stream without consuming anything.  Empty chunks are
-- skipped.  Returns 'Nothing' once a chunk with data is available.  At EOF
-- it returns @Just@ the EOF's error, or 'I.EofException' if the EOF
-- carried none.
isStreamFinished ::
  (MonadCatchIO m, Storable el) =>
  MIteratee (IOBuffer r el) m (Maybe SomeException)
isStreamFinished = liftI check
  where
    check stream = case stream of
      I.Chunk buf -> guardNull buf (liftI check) (idone Nothing stream)
      I.EOF mErr  ->
        idone (Just (fromMaybe (toException I.EofException) mErr)) stream
{-# INLINE isStreamFinished #-}

-- | Remove and return the first element of the stream, skipping empty
-- chunks.  Throws 'I.EofException' at end of stream.
head :: (MonadCatchIO m, Storable el) => MIteratee (IOBuffer r el) m el
head = liftI next
  where
    next (I.Chunk buf) =
      guardNull buf (liftI next) $
        liftIO (pop buf) >>= \x -> idone x (I.Chunk buf)
    next _ = MIteratee . I.throwErr . toException $ I.EofException
{-# INLINE head #-}

-- | Look at the first element without consuming it, skipping empty chunks.
-- Returns 'Nothing' at end of stream; otherwise returns whatever
-- 'IB.lookAtHead' reports for the (non-empty) buffer.
peek :: (MonadCatchIO m, Storable el) => MIteratee (IOBuffer r el) m (Maybe el)
peek = liftI look
  where
    look stream = case stream of
      I.Chunk buf -> guardNull buf (liftI look) $
        liftIO (IB.lookAtHead buf) >>= \front -> idone front stream
      I.EOF _ -> idone Nothing stream
{-# INLINE peek #-}

-- | Consume the longest prefix of the stream that matches the given list.
-- Returns how many elements matched.  Stops at the first mismatch (leaving
-- the mismatching element in the stream), when the list is exhausted, or
-- at end of stream.
heads :: (MonadCatchIO m, Storable el, Eq el) =>
  [el]
  -> MIteratee (IOBuffer r el) m Int
heads [] = return 0
heads expected = liftI (match 0 expected)
  where
    match n want@(w:ws) stream@(I.Chunk buf) = do
      front <- liftIO (IB.lookAtHead buf)
      case front of
        -- buffer exhausted: wait for the next chunk
        Nothing -> liftI (match n want)
        Just e
          | e == w    -> liftIO (pop buf) >> match (succ n) ws stream
          | otherwise -> idone n stream
    -- list exhausted, or end of stream
    match n _ stream = idone n stream
{-# INLINE heads #-}

-- | Take the whole current chunk, skipping empty ones.  The buffer itself
-- (not a copy) is returned, and an empty chunk is left as the remaining
-- stream.  Throws 'I.EofException' at end of stream.
chunk ::
  (MonadCatchIO m, Storable el)
  => MIteratee (IOBuffer r el) m (IOBuffer r el)
chunk = liftI grab
  where
    grab stream = case stream of
      I.Chunk buf -> guardNull buf (liftI grab) (idone buf (I.Chunk empty))
      I.EOF _     -> MIteratee . I.throwErr . toException $ I.EofException

-- | Discard the next @n@ elements of the stream, spanning chunks as needed.
-- A non-positive count discards nothing and returns immediately.  If the
-- stream ends first, it stops without error.
drop :: (MonadCatchIO m, Storable el) => Int -> MIteratee (IOBuffer r el) m ()
drop n
  -- Guard non-positive counts up front.  Otherwise a negative count would
  -- reach IB.drop as a negative argument, and 'drop 0' would needlessly
  -- suspend waiting for a chunk.
  | n <= 0    = return ()
  | otherwise = liftI (step n)
  where
    -- invariant: n' > 0
    step n' (I.Chunk buf) = guardNull buf (liftI (step n')) $ do
      l <- liftIO $ IB.length buf
      if n' <= l
        -- the remaining drop fits in this chunk; whatever is left (possibly
        -- nothing) stays as the unconsumed stream
        then liftIO (IB.drop n' buf) >> idone () (I.Chunk buf)
        -- consume the whole chunk and keep dropping from the next one
        else liftIO (IB.drop l buf) >> liftI (step (n' - l))
    -- end of stream: nothing more to drop
    step _ str = idone () str
{-# INLINE drop #-}

-- | Discard leading elements while the predicate holds, spanning chunks.
-- Finishes as soon as a chunk still has data left after dropping, or at
-- end of stream.
dropWhile :: (MonadCatchIO m, Storable el) =>
  (el -> Bool)
  -> MIteratee (IOBuffer r el) m ()
dropWhile p = liftI go
  where
    go stream = case stream of
      I.Chunk buf -> guardNull buf (liftI go) $ do
        liftIO (IB.dropWhile p buf)
        remaining <- liftIO (IB.length buf)
        if remaining == 0
          then liftI go
          else idone () stream
      I.EOF _ -> idone () stream
{-# INLINE dropWhile #-}

-- | Left fold over every element of the remaining stream.  Each non-empty
-- chunk is folded with 'IB.foldl''.  The final accumulator is returned at
-- end of stream.
--
-- The accumulator type needs no class constraints (a former @Show a@
-- constraint was unused and has been dropped).
foldl' :: (MonadCatchIO m, Storable el) =>
  (a -> el -> a)
  -> a
  -> MIteratee (IOBuffer r el) m a
foldl' f acc' = liftI (step acc')
  where
    -- fold this chunk into the accumulator, then wait for the next one
    step acc (I.Chunk buf) = guardNull buf (liftI (step acc)) $ do
      newacc <- liftIO $ IB.foldl' f acc buf
      liftI (step newacc)
    -- end of stream: return the accumulator
    step acc str = idone acc str
{-# INLINE foldl' #-}

-- | Left fold over a strided subset of the stream.  Each non-empty chunk is
-- folded with 'IB.hopfoldl'' using stride @hop@.  A skip count (initially
-- 0) is carried from one chunk to the next so the stride continues across
-- chunk boundaries.  The accumulator is returned at end of stream.
--
-- NOTE(review): @hop@ must be non-zero, since the code computes
-- @buflen `rem` hop@.  Two further points need checking:
--
-- * The carried skip is @hop - buflen `rem` hop@, which is @hop@ (not 0)
--   when a chunk's length is an exact multiple of @hop@.
-- * When a chunk is shorter than the pending skip, the unused part of the
--   skip is not carried forward.
--
-- Whether either is a bug depends on exactly which elements
-- 'IB.hopfoldl'' visits and consumes — confirm against IOBuffer.
hopfoldl' :: (MonadCatchIO m, Storable el) =>
  Int
  -> (a -> el -> a)
  -> a
  -> MIteratee (IOBuffer r el) m a
hopfoldl' hop f acc' = liftI (step 0 acc')
  where
    step ihop acc (I.Chunk buf) = guardNull buf (liftI (step ihop acc)) $ do
      (newacc, numExtra) <- liftIO $ do
        -- skip the elements still owed from the previous chunk
        IB.drop ihop buf
        buflen <- IB.length buf
        -- elements beyond the last whole multiple of hop in this chunk
        let numExtra = buflen `rem` hop
        res <- IB.hopfoldl' hop f acc buf
        IB.drop numExtra buf
        return (res,numExtra)
      liftI (step (hop-numExtra) newacc)
    -- end of stream: return the accumulator
    step _ acc str = idone acc str
{-# INLINE hopfoldl' #-}


-- -----------------------------------------------------------
-- Enumeratees

-- | A stream transformer from outer stream @sFrom@ to inner stream @sTo@.
-- Given an inner iteratee, it yields an outer iteratee whose result is the
-- (possibly unfinished) inner iteratee.
type MEnumeratee sFrom sTo m a =
  MIteratee sTo m a -> MIteratee sFrom m (MIteratee sTo m a)

-- | Core enumeratee combinator: run the inner iteratee just far enough to
-- learn its state.
--
-- * Finished: it is returned (rebuilt with 'idone', keeping its own
--   leftover) without consuming outer input.
-- * Wants input: its continuation, re-wrapped as an 'MIteratee', is handed
--   to @f@, which decides how to feed it.
-- * Error state: the error is rethrown in the outer iteratee.
eneeCheckIfDone ::
  (Monad m, Storable elo) =>
  ((I.Stream eli -> MIteratee eli m a)
    -> MIteratee (IOBuffer r elo) m (MIteratee eli m a))
  -> MEnumeratee (IOBuffer r elo) eli m a
eneeCheckIfDone f inner = MIteratee $ I.Iteratee $ \od oc ->
  -- inner done: return it; the outer stream is left with an empty chunk
  let onDone x s = od (idone x s) (I.Chunk empty)
      -- inner wants input: let f drive its continuation
      onCont k Nothing  = I.runIter (unwrap $ f (MIteratee . k)) od oc
      -- inner carries an error: rethrow it in the outer iteratee
      onCont _ (Just e) = I.runIter (I.throwErr e) od oc
  in I.runIter (unwrap inner) onDone (onCont)
{-# INLINE eneeCheckIfDone #-}

-- | Map a pure function over every element of the stream.
--
-- @maxlen@ is the element count passed to 'mallocForeignPtrArray' for the
-- output buffer.  That buffer and an offset cell are allocated once, when
-- this enumeratee starts running.  The same pair is then passed to
-- 'IB.mapBuffer' for every input chunk.
--
-- NOTE(review): presumably each mapped chunk is written into that one
-- buffer.  If so, an inner iteratee that keeps a reference to an earlier
-- chunk (e.g. via 'chunk') would see it overwritten.  Also, how
-- 'IB.mapBuffer' handles input chunks longer than @maxlen@ is not visible
-- here.  Confirm both.
mapStream :: (MonadCatchIO pr, Storable elo, Storable eli) =>
  Int
  -> (eli -> elo)
  -> MEnumeratee (IOBuffer r eli) (IOBuffer r elo) pr a
mapStream maxlen f i = do
  offp <- liftIO $ newFp 0
  bufp <- liftIO $ mallocForeignPtrArray maxlen
  goMap offp bufp i
   where
    -- check the inner iteratee's state; if it wants input, read a chunk
    goMap offp bufp = eneeCheckIfDone (liftI . step offp bufp)
    -- map a non-empty chunk and feed the result to the inner iteratee,
    -- then loop; empty chunks are skipped
    step offp bufp k (I.Chunk buf) =
      guardNull buf (liftI (step offp bufp k)) $ do
        newIOBuf <- liftIO $ IB.mapBuffer f offp bufp buf
        goMap offp bufp $ k (I.Chunk newIOBuf)
    -- end of stream: return the inner iteratee still waiting for input and
    -- leave the EOF as the outer stream
    step _    _    k s             = idone (liftI k) s

-- | Convert the stream chunk by chunk.  Each input buffer is passed to the
-- monadic function @f@, and its result is fed to the inner iteratee as one
-- chunk.  At end of stream the inner iteratee is returned still waiting
-- for input, and the EOF is left as the outer stream.
mapChunk :: (Storable el, MonadCatchIO m) =>
  (IOBuffer r el -> m s2)
   -> MEnumeratee (IOBuffer r el) s2 m a
mapChunk f = convert
  where
    convert = eneeCheckIfDone (\k -> liftI (feed k))
    feed k stream = case stream of
      I.Chunk buf -> do
        converted <- lift (f buf)
        convert (k (I.Chunk converted))
      I.EOF _ -> idone (liftI k) stream
{-# INLINE mapChunk #-}


-- | Stateful element-wise map.  @f@ takes the running state and an input
-- element and returns the new state plus an output element.  The state is
-- threaded across chunk boundaries.
--
-- As in 'mapStream', the output buffer (@maxlen@ elements, via
-- 'mallocForeignPtrArray') and an offset cell are allocated once when the
-- enumeratee starts.  The same pair is then passed to 'IB.mapAccumBuffer'
-- for every input chunk.
--
-- NOTE(review): presumably every output chunk shares that one buffer, and
-- how chunks longer than @maxlen@ are handled is not visible here.
-- Confirm both.
mapAccum :: (MonadCatchIO pr, Storable eli, Storable elo) =>
  Int
  -> (b -> eli -> (b,elo))
  -> b
  -> MEnumeratee (IOBuffer r eli) (IOBuffer r elo) pr a
mapAccum maxlen f acc i = do
  offp <- liftIO $ newFp 0
  bufp <- liftIO $ mallocForeignPtrArray maxlen
  goMap offp bufp acc i
   where
    -- check the inner iteratee's state; if it wants input, read a chunk
    goMap offp bufp acc = eneeCheckIfDone (liftI . step offp bufp acc)
    -- map a non-empty chunk (updating the state), feed the output to the
    -- inner iteratee, and loop with the new state; empty chunks are skipped
    step offp bufp acc k (I.Chunk buf) =
      guardNull buf (liftI (step offp bufp acc k)) $ do
        (newAcc, newBuf) <- liftIO $ IB.mapAccumBuffer f offp bufp acc buf
        goMap offp bufp newAcc $ k (I.Chunk newBuf)
    -- end of stream: return the inner iteratee still waiting for input and
    -- leave the EOF as the outer stream
    step _    _    _   k s             = idone (liftI k) s

-- | Convert the stream by repeatedly running @fi@.  Each run produces one
-- output chunk for the inner iteratee.  Once the outer stream ends, the
-- inner iteratee is returned still waiting for input, and the EOF (with
-- the error reported by 'isStreamFinished') is left as the outer stream.
convStream :: (MonadCatchIO pr, Storable elo, Storable eli) =>
  MIteratee (IOBuffer r eli) pr (IOBuffer r elo)
  -> MEnumeratee (IOBuffer r eli) (IOBuffer r elo) pr a
convStream fi = eneeCheckIfDone check
  where
    check k = do
      finished <- isStreamFinished
      case finished of
        Nothing -> fi >>= convStream fi . k . I.Chunk
        Just e  -> idone (liftI k) (I.EOF (Just e))

-- | The most general stream converter.  Starting from @acc0@, it repeatedly
-- runs @f acc@.  Each run reads from the outer stream and yields the next
-- state together with one output buffer, which is fed to the inner
-- iteratee.  This continues until the inner iteratee finishes or the outer
-- stream ends.  At EOF the inner iteratee is returned still waiting for
-- input.
unfoldConvStream ::
  (MonadCatchIO m, Storable eli, Storable elo)
  => (acc -> MIteratee (IOBuffer r eli) m (acc, IOBuffer r elo))
  -> acc
  -> MEnumeratee (IOBuffer r eli) (IOBuffer r elo) m a
unfoldConvStream f acc0 = eneeCheckIfDone (check acc0)
  where
    check acc k = do
      finished <- isStreamFinished
      case finished of
        Just e  -> idone (liftI k) (I.EOF (Just e))
        Nothing -> do
          (acc', out) <- f acc
          eneeCheckIfDone (check acc') (k (I.Chunk out))

-- | Decimate a stream by taking every n'th element, starting at element "m".
-- Here n is @numChans@ and m is @chn@: the first @chn@ elements are
-- skipped, then every @numChans@-th element is kept.  For example,
-- @getChannel 3 1@ on an interleaved 3-channel stream yields channel 1.
--
-- Assumes 'IB.decimate' keeps indices 0, n, 2n, ... of the buffer it is
-- given (NOTE(review): confirm against IOBuffer).  @numChans@ must be
-- positive.
getChannel ::
  (MonadCatchIO m, Storable el)
  => Int
  -> Int
  -> MEnumeratee (IOBuffer r el) (IOBuffer r el) m a
getChannel 1        _   = convStream chunk
getChannel numChans chn = unfoldConvStream mkIter chn
 where
  -- drp: elements to skip at the start of this chunk so that its first
  -- remaining element belongs to the selected channel
  mkIter drp = do
    drop drp
    buf <- chunk
    tlen <- liftIO $ IB.length buf
    newbuf <- liftIO $ IB.decimate numChans buf
    -- Keeping indices 0, n, 2n, ... of a chunk of length tlen means the
    -- next wanted element lies (n - tlen `rem` n) `rem` n elements into the
    -- following chunk.  Carrying just tlen `rem` n is only right for
    -- n <= 2, and would lose channel alignment for 3+ channels.
    return ((numChans - tlen `rem` numChans) `rem` numChans, newbuf)

-- | Feed at most @i@ elements of the stream to the inner iteratee, then
-- return the (possibly unfinished) inner iteratee.  Elements beyond the
-- limit stay in the outer stream.  A non-positive limit returns the inner
-- iteratee untouched without consuming anything; previously a negative
-- limit reached 'IB.splitAt' with a negative index.
takeUpTo :: (MonadCatchIO pr, Storable el) =>
  Int
  -> MEnumeratee (IOBuffer r el) (IOBuffer r el) pr a
takeUpTo i iter
  | i <= 0    = return iter
  | otherwise = MIteratee $ I.Iteratee $ \od oc ->
      I.runIter (unwrap iter) (onDone od) (onCont od oc)
  where
    -- inner already finished: hand it back; the outer stream is left with
    -- an empty chunk
    onDone od x _ = od (return x) (I.Chunk empty)
    -- inner wants input: start feeding it, at most i elements in total
    onCont od oc k Nothing =
      I.runIter (unwrap . liftI $ step i (MIteratee . k)) od oc
    -- inner carries an error: rethrow it in the outer iteratee
    onCont od oc _ (Just e) = I.runIter (I.throwErr e) od oc
    step n k (I.Chunk buf) = guardNull buf (liftI (step n k)) $ do
      blen <- liftIO $ IB.length buf
      if blen <= n
        -- the whole chunk fits under the limit: feed it and continue
        then takeUpTo (n - blen) $ k (I.Chunk buf)
        -- split: feed the first n elements, leave the rest outside
        else do
          (s1, s2) <- liftIO $ IB.splitAt buf n
          idone (k (I.Chunk s1)) (I.Chunk s2)
    -- end of stream: pass it to the inner iteratee and leave it outside too
    step _ k str = idone (k str) str

-- ---------------------------------------------
-- drivers and enumerators

-- | Convert a Vector iteratee to an MIteratee.  Slower but convenient: each
-- 'IOBuffer' chunk is turned into a 'U.Vector' with 'IB.freeze' before it
-- is fed to the inner iteratee.  Through 'joinIob', the inner iteratee is
-- sent EOF when the conversion ends.
fromUVector :: (U.Unbox el, Storable el, MonadCatchIO m) =>
  I.Iteratee (U.Vector el) m a
  -> MIteratee (IOBuffer r el) m a
fromUVector iter = joinIob (mapChunk (liftIO . IB.freeze) (MIteratee iter))
{-# INLINE fromUVector #-}

-- | Callback for 'I.enumFromCallbackCatch': read up to @bsize@ bytes from
-- @h@ into the buffer behind @fp@ and wrap them as an 'IOBuffer'.
--
-- * read error   -> @Left@ the exception
-- * 0 bytes read -> @Right ((False, ()), empty)@
-- * n bytes read -> the offset cell @offp@ is reset to 0 and
--   @Right ((True, ()), buffer)@ is returned
--
-- NOTE(review): the 'Bool' presumably tells the enumerator whether to keep
-- reading; confirm against 'I.enumFromCallbackCatch'.
--
-- NOTE(review): @bsize@ and the count given to 'IB.createIOBuffer' are byte
-- counts ('hGetBuf' works in bytes, and 'enumHandleCatch' passes
-- @bs * sizeOf el@).  If 'IB.createIOBuffer' expects an element count,
-- this over-reports the length for elements wider than one byte.  Confirm.
makeHandleCallback :: (MonadCatchIO m, Storable el) =>
  ForeignPtr Int
  -> ForeignPtr el
  -> Int
  -> Handle
  -> ()
  -> m (Either SomeException ((Bool, ()), (IOBuffer r el)))
makeHandleCallback offp fp bsize h () = liftIO $ do
  readResult <- withForeignPtr fp $ \p ->
    (CIO.try (hGetBuf h p bsize) :: IO (Either SomeException Int))
  case readResult of
    Left err -> return (Left err)
    Right bytesRead
      | bytesRead == 0 -> return (Right ((False, ()), empty))
      | otherwise -> do
          -- reset the shared offset so the new buffer is read from its start
          withForeignPtr offp (\p -> poke p 0)
          return (Right ((True, ()),
                         IB.createIOBuffer (fromIntegral bytesRead) offp fp))

-- | Enumerate a handle in reads of @bs@ elements.  One buffer of @bs@
-- elements and one offset cell are allocated up front and reused for every
-- read.  Each read asks 'makeHandleCallback' for @bs * sizeOf el@ bytes.
-- @handler@ is passed to 'I.enumFromCallbackCatch' to deal with exceptions
-- of type @e@ ('enumHandleRandom' uses this to service seek requests).
enumHandleCatch
 ::
 forall e r el m a.(I.IException e, MonadCatchIO m, Storable el) =>
  Int
  -> Handle
  -> (e -> m (Maybe I.EnumException))
  -> MIteratee (IOBuffer r el) m a
  -> m (MIteratee (IOBuffer r el) m a)
enumHandleCatch bs h handler i = do
  bufp <- liftIO (mallocForeignPtrArray bs)
  offp <- liftIO (newFp 0)
  let bytesPerRead = bs * sizeOf (undefined :: el)
      readChunk    = makeHandleCallback offp bufp bytesPerRead h
  iter' <- I.enumFromCallbackCatch readChunk handler () (unwrap i)
  return (MIteratee iter')

-- | Enumerate a handle with support for seeking.  An 'I.SeekException'
-- raised during enumeration is handled by seeking the handle to the
-- requested absolute offset.  An 'IOException' from the seek is reported as
-- an 'I.EnumException'; a successful seek reports no error.
enumHandleRandom :: forall r el m a.(MonadCatchIO m, Storable el) =>
  Int -- ^ Buffer size (number of elements per read)
  -> Handle
  -> MIteratee (IOBuffer r el) m a
  -> m (MIteratee (IOBuffer r el) m a)
enumHandleRandom bs h i = enumHandleCatch bs h onSeek i
  where
    onSeek (I.SeekException off) = do
      outcome <- liftIO . CIO.try $ hSeek h AbsoluteSeek (fromIntegral off)
      return $ case outcome of
        Left (err :: IOException) -> Just (I.EnumException err)
        Right _                   -> Nothing

-- | Open @filepath@ in binary read mode, run the iteratee over its contents
-- with 'enumHandleRandom' (reads of @bufsize@ elements), and return the
-- result.  The handle is closed via 'CIO.bracket' even if an exception
-- escapes.
fileDriverRandom :: (MonadCatchIO m, Storable el) =>
  Int
  -> (forall r. MIteratee (IOBuffer r el) m a)
  -> FilePath
  -> m a
fileDriverRandom bufsize iter filepath =
  CIO.bracket (liftIO (openBinaryFile filepath ReadMode))
              (\h -> liftIO (hClose h))
              (\h -> do
                  iter' <- enumHandleRandom bufsize h iter
                  I.run (unwrap iter'))

-- | Allocate a single-element 'ForeignPtr' with 'mallocForeignPtr' and
-- initialise it to @a@.
newFp :: Storable a => a -> IO (ForeignPtr a)
newFp a = do
  fp <- mallocForeignPtr
  withForeignPtr fp (\p -> poke p a)
  return fp