{-# LANGUAGE 
     FlexibleInstances
    ,DeriveFunctor
    ,RankNTypes
    ,ScopedTypeVariables #-}

module Data.MutableIter (
  MIteratee (..)
  ,IOBuffer (..)
  ,IB.createIOBuffer
  ,MEnumerator
  ,MEnumeratee
  ,joinIob
  ,joinIM
  ,wrapEnum
  ,liftI
  ,idone
  ,icont
  ,guardNull
  ,head
  ,heads
  ,peek
  ,drop
  ,dropWhile
  ,foldl'
  ,mapStream
  ,mapAccum
  ,convStream
  ,takeUpTo
  ,enumHandleRandom
  ,fileDriverRandom
  ,newFp
  )

where

import Prelude hiding (head, null, drop, dropWhile, catch)
import qualified Prelude as P

import qualified Data.MutableIter.IOBuffer as IB
import Data.MutableIter.IOBuffer (IOBuffer, null, empty, pop)

import qualified Data.Iteratee as I
import Data.Maybe

import Control.Exception (SomeException, IOException)
import Control.Monad
import Control.Monad.CatchIO as CIO
import Control.Monad.IO.Class
import Control.Monad.Trans.Class
import Foreign.ForeignPtr
import Foreign.Storable (Storable, sizeOf, poke)
import Foreign.Marshal.Array

import System.IO

-- | Report whether a stream element carries no data.
--
-- An 'I.EOF' marker is never treated as null; for an 'I.Chunk' the answer is
-- whatever 'null' reports for the wrapped buffer.
isNullChunk :: (MonadIO s) => I.Stream (IOBuffer r el) -> s Bool
isNullChunk stream = case stream of
  I.Chunk buf -> liftIO (null buf)
  I.EOF _     -> return False

-- | A thin newtype around 'I.Iteratee'.  Use 'unwrap' to get the underlying
-- iteratee back.
newtype MIteratee s m a = MIteratee
  { unwrap :: I.Iteratee s m a
  } deriving (Functor)

-- | Monad instance for iteratees over 'IOBuffer' streams.
--
-- 'return' finishes immediately, reporting an empty chunk as the leftover
-- stream.  '>>=' runs the first iteratee and, once it is done, passes any
-- leftover stream that is not a null chunk on to the iteratee produced by
-- the continuation.
instance (MonadIO m, Storable el) => Monad (MIteratee (IOBuffer r el) m) where
  {-# INLINE return #-}
  return x = MIteratee (I.Iteratee $ \onDone _ -> onDone x (I.Chunk empty))
  -- {-# INLINE (>>=) #-} -- this inline makes things a bit slower with 6.12.  Seems fixed in ghc-7
  m >>= f = MIteratee (I.Iteratee $ \onDone onCont ->
    -- mDone: @m@ finished with result @a@ and leftover stream @str@.
    let mDone a str = do
          isNull <- isNullChunk str
          if isNull
            -- Null leftover chunk: run @f a@ directly with the outer handlers.
            then I.runIter (unwrap $ f a) onDone onCont
            -- Otherwise (non-null chunk, or EOF): if @f a@ is already done,
            -- report @str@ as the leftover in place of whatever stream
            -- @f a@ reported; if it wants input, 'fCont' decides what to do.
            else I.runIter (unwrap $ f a) (const . flip onDone str) (fCont str)
        -- @f a@ wants input and carries no error: feed it the leftover.
        fCont str k Nothing = I.runIter (k str) onDone onCont
        -- @f a@ carries an error: hand it to the outer continuation handler.
        fCont _   k e       = onCont k e
    -- While @m@ itself still wants input, re-attach @>>= f@ to whatever its
    -- continuation produces.
    in I.runIter (unwrap m) mDone (\k -> onCont (unwrap . (>>= f) . MIteratee . k)))

-- | Lift an action of the base monad: run it and finish with its result,
-- reporting an empty chunk as the leftover stream.
instance (Storable el) => MonadTrans (MIteratee (IOBuffer s el)) where
  lift action = MIteratee (I.Iteratee $ \onDone _ ->
    action >>= \result -> onDone result (I.Chunk empty))

-- | 'IO' actions are lifted into the base monad first and from there into
-- the iteratee via 'lift'.
instance (MonadIO m, Storable el)
  => MonadIO (MIteratee (IOBuffer r el) m) where
    liftIO action = lift (liftIO action)

-- | Exception handling for iteratees.
--
-- 'catch' runs the guarded iteratee with the caller's handlers inside the
-- base monad's 'catch'; if an exception is caught, the iteratee produced by
-- the handler is run with the same handlers instead.  'block' and 'unblock'
-- are applied through 'mapIteratee'.
instance (MonadCatchIO m, Storable el) =>
  MonadCatchIO (MIteratee (IOBuffer r el) m) where
    catch action handler = MIteratee $ I.Iteratee $ \od oc ->
      CIO.catch (I.runIter (unwrap action) od oc)
                (\exc -> I.runIter (unwrap (handler exc)) od oc)
    block it   = mapIteratee block it
    unblock it = mapIteratee unblock it

-- | Run the iteratee to completion with 'I.run', transform the resulting
-- base-monad action with the supplied function, and lift the outcome back
-- into an iteratee.
mapIteratee :: (Monad m, Storable el) =>
  (m a -> m b)
  -> MIteratee (IOBuffer r el) m a
  -> MIteratee (IOBuffer r el) m b
mapIteratee f iter = lift (f (I.run (unwrap iter)))

-- | Collapse a nested iteratee whose inner stream may have a different
-- element type.
--
-- The outer iteratee is run first to obtain the inner one.  The inner
-- iteratee is then inspected:
--
-- * if it is done, its result is returned with an empty leftover chunk;
--
-- * if it wants more input, it is sent @EOF Nothing@; should it still want
--   input after that, an error is thrown (the continuation's own error if
--   present, otherwise 'excDiv');
--
-- * if it is in an error state, that error is re-thrown.
joinIob :: (MonadCatchIO m, Storable el) =>
  MIteratee (IOBuffer r el) m (MIteratee (IOBuffer r el') m a)
  -> MIteratee (IOBuffer r el) m a
joinIob outer = do
  inner <- outer
  MIteratee $ I.Iteratee $ \od oc ->
    let finish x _ = od x (I.Chunk empty)
        firstCont k mErr = case mErr of
          Nothing -> I.runIter (k (I.EOF Nothing)) finish secondCont
          Just e  -> I.runIter (I.throwErr e) od oc
        secondCont _ mErr =
          I.runIter (I.throwErr (fromMaybe excDiv mErr)) od oc
    in I.runIter (unwrap inner) finish firstCont

-- | Turn a base-monad action that produces an iteratee into an iteratee:
-- the action is executed when the result is run, and the iteratee it yields
-- is then run with the same handlers.
joinIM :: (Monad m) =>
  m (MIteratee (IOBuffer r el) m a)
  -> MIteratee (IOBuffer r el) m a
joinIM mIter = MIteratee $ I.Iteratee $ \od oc -> do
  iter <- mIter
  I.runIter (unwrap iter) od oc

-- | 'I.DivergentException' wrapped as a 'SomeException'.  Used by 'joinIob'
-- as the fallback error when an inner iteratee still wants input after being
-- sent EOF and supplies no error of its own.
excDiv :: SomeException
excDiv = toException I.DivergentException

-- | An enumerator for 'MIteratee's: takes an iteratee and, in the base
-- monad, produces an iteratee of the same type.
type MEnumerator s m a = MIteratee s m a -> m (MIteratee s m a)

-- | Adapt a plain 'I.Enumerator' to work on 'MIteratee's by unwrapping the
-- argument and re-wrapping the result.
wrapEnum :: (Monad m) => I.Enumerator s m a -> MEnumerator s m a
wrapEnum enum iter = liftM MIteratee (enum (unwrap iter))

-- | 'I.liftI' for 'MIteratee': build an iteratee from a function that
-- consumes the next stream element.
liftI :: (Monad m) => (I.Stream s -> MIteratee s m a) -> MIteratee s m a
liftI k = MIteratee (I.liftI (unwrap . k))

-- | 'I.idone' for 'MIteratee': a finished iteratee with the given result
-- and leftover stream.
idone x = MIteratee . I.idone x

-- | 'I.icont' for 'MIteratee': an iteratee in the continue state, built from
-- a stream-consuming function and an optional error.
icont k = MIteratee . I.icont (unwrap . k)

-- | Choose between two iteratees depending on whether the buffer is null.
--
-- The second argument is used when 'null' reports the buffer as empty, the
-- third otherwise.
guardNull :: (MonadCatchIO m, Storable el) =>
  IOBuffer r el
  -> MIteratee (IOBuffer r el) m a
  -> MIteratee (IOBuffer r el) m a
  -> MIteratee (IOBuffer r el) m a
guardNull buf onEmpty onFull =
  liftIO (null buf) >>= \isEmpty -> if isEmpty then onEmpty else onFull

-- | Take the first element of the stream.
--
-- Null chunks are skipped.  For a non-null chunk the element is removed
-- from the buffer with 'pop' and the same buffer is reported as leftover.
-- Any non-chunk stream element results in an 'I.EofException' being thrown.
head :: (MonadCatchIO m, Storable el) => MIteratee (IOBuffer r el) m el
head = liftI go
  where
    go stream = case stream of
      I.Chunk buf -> guardNull buf (icont go Nothing) $
        liftIO (pop buf) >>= \x -> idone x (I.Chunk buf)
      _ -> MIteratee (I.throwErr (toException I.EofException))

-- | Look at the front of the stream without this function removing
-- anything from it.
--
-- Null chunks are skipped.  For a non-null chunk the result of
-- 'IB.lookAtHead' is returned and the chunk is handed back as leftover.  For
-- any non-chunk stream element the result is 'Nothing'.
peek :: (MonadCatchIO m, Storable el) => MIteratee (IOBuffer r el) m (Maybe el)
peek = liftI go
  where
    go stream = case stream of
      I.Chunk buf -> guardNull buf (icont go Nothing) $
        liftIO (IB.lookAtHead buf) >>= \front -> idone front stream
      _ -> idone Nothing stream

-- | Match the stream against a list of expected elements, removing each
-- element that matches, and return how many were matched.
--
-- Matching stops at the first mismatch (the mismatching element is left in
-- the buffer), when the expected list is exhausted, or when a non-chunk
-- stream element arrives.  When 'IB.lookAtHead' yields 'Nothing' more input
-- is requested.
heads :: (MonadCatchIO m, Storable el, Eq el) =>
  [el]
  -> MIteratee (IOBuffer r el) m Int
heads st = case st of
    [] -> return 0
    _  -> liftI (go 0 st)
  where
    go matched pending@(want:rest) (I.Chunk buf) = do
      mHead <- liftIO (IB.lookAtHead buf)
      case mHead of
        Nothing -> icont (go matched pending) Nothing
        Just got
          | got == want ->
              liftIO (pop buf) >> go (succ matched) rest (I.Chunk buf)
          | otherwise -> idone matched (I.Chunk buf)
    go matched _ str = idone matched str

-- | Discard the given number of elements from the stream.
--
-- Null chunks are skipped.  If the current buffer holds more elements than
-- are still to be dropped, only that many are removed and the buffer is
-- returned as leftover; otherwise the whole buffer is consumed and the
-- remainder is dropped from subsequent input.  A non-chunk stream element
-- ends the drop early.
drop :: (MonadCatchIO m, Storable el) => Int -> MIteratee (IOBuffer r el) m ()
drop n = liftI (go n)
  where
    go 0 str = idone () str
    go remaining (I.Chunk buf) =
      guardNull buf (icont (go remaining) Nothing) $ do
        available <- liftIO (IB.length buf)
        if remaining >= available
          then liftIO (IB.drop available buf)
                 >> liftI (go (remaining - available))
          else liftIO (IB.drop remaining buf) >> idone () (I.Chunk buf)
    go _ str = idone () str

-- | Discard leading elements for as long as the predicate is applied by
-- 'IB.dropWhile'.
--
-- Null chunks are skipped.  If a buffer has length zero after dropping, more
-- input is requested; otherwise the iteratee finishes with the buffer as
-- leftover.  A non-chunk stream element finishes it immediately.
dropWhile :: (MonadCatchIO m, Storable el) =>
  (el -> Bool)
  -> MIteratee (IOBuffer r el) m ()
dropWhile shouldDrop = liftI go
  where
    go stream = case stream of
      I.Chunk buf -> guardNull buf (icont go Nothing) $ do
        liftIO (IB.dropWhile shouldDrop buf)
        remaining <- liftIO (IB.length buf)
        if remaining /= 0
          then idone () (I.Chunk buf)
          else icont go Nothing
      _ -> idone () stream

-- | Left fold over every element of the stream.
--
-- Each non-null chunk is folded into the accumulator with 'IB.foldl'', after
-- which more input is requested.  Null chunks are skipped.  When a non-chunk
-- stream element arrives, the accumulator is returned together with that
-- stream element.
--
-- The accumulator type is unconstrained: nothing in this function needs to
-- show it, so the former @Show a@ requirement has been dropped.
-- NOTE(review): this assumes 'IB.foldl'' itself places no 'Show' constraint
-- on the accumulator -- confirm against "Data.MutableIter.IOBuffer".
foldl' :: (MonadCatchIO m, Storable el) =>
  (a -> el -> a)
  -> a
  -> MIteratee (IOBuffer r el) m a
foldl' f acc' = liftI (step acc')
  where
    step acc (I.Chunk buf) = guardNull buf (liftI (step acc)) $ do
      newacc <- liftIO $ IB.foldl' f acc buf
      icont (step newacc) Nothing
    step acc str = idone acc str


-- -----------------------------------------------------------
-- Enumeratees

-- | An enumeratee: takes an inner iteratee over stream type @sTo@ and
-- produces an outer iteratee over stream type @sFrom@ whose result is an
-- inner iteratee.
type MEnumeratee sFrom sTo m a = MIteratee sTo m a -> MIteratee sFrom m (MIteratee sTo m a)

-- | Build an enumeratee from a function that handles the case where the
-- inner iteratee still wants input.
--
-- * If the inner iteratee is done, it is returned (rebuilt with 'idone')
--   with an empty leftover chunk on the outer stream.
--
-- * If it wants input and carries no error, its continuation is passed to
--   the supplied function.
--
-- * If it carries an error, that error is thrown in the outer iteratee.
eneeCheckIfDone ::
  (Monad m, Storable elo) =>
  ((I.Stream eli -> MIteratee eli m a)
    -> MIteratee (IOBuffer r elo) m (MIteratee eli m a))
  -> MEnumeratee (IOBuffer r elo) eli m a
eneeCheckIfDone f inner = MIteratee $ I.Iteratee $ \od oc ->
  let finished x leftover = od (idone x leftover) (I.Chunk empty)
      continue k mErr = case mErr of
        Nothing -> I.runIter (unwrap (f (MIteratee . k))) od oc
        Just e  -> I.runIter (I.throwErr e) od oc
  in I.runIter (unwrap inner) finished continue

-- | Enumeratee that applies a function to every element of the stream.
--
-- An offset cell (initialised to 0) and an output array of @maxlen@
-- elements are allocated once and passed to 'IB.mapBuffer' for every
-- non-null chunk; the buffer it returns is fed to the inner iteratee.  Null
-- chunks are skipped, and a non-chunk stream element ends the enumeratee,
-- returning the inner iteratee still waiting for input.
mapStream :: (MonadCatchIO pr, Storable elo, Storable eli) =>
  Int
  -> (eli -> elo)
  -> MEnumeratee (IOBuffer r eli)
                 (IOBuffer r elo)
                 pr
                 a
mapStream maxlen f inner = do
  offsetPtr <- liftIO (newFp 0)
  outPtr    <- liftIO (mallocForeignPtrArray maxlen)
  transform offsetPtr outPtr inner
  where
    transform op bp = eneeCheckIfDone (\k -> liftI (feed op bp k))
    feed op bp k stream = case stream of
      I.Chunk buf -> guardNull buf (liftI (feed op bp k)) $ do
        mapped <- liftIO (IB.mapBuffer f op bp buf)
        transform op bp (k (I.Chunk mapped))
      _ -> idone (liftI k) stream

-- | Enumeratee that maps over the stream while threading an accumulator.
--
-- An offset cell (initialised to 0) and an output array of @maxlen@
-- elements are allocated once.  Every non-null chunk is passed to
-- 'IB.mapAccumBuffer' together with the current accumulator; the returned
-- buffer is fed to the inner iteratee and the returned accumulator is used
-- for the next chunk.  Null chunks are skipped, and a non-chunk stream
-- element ends the enumeratee, returning the inner iteratee still waiting
-- for input.
mapAccum :: (MonadCatchIO pr, Storable eli, Storable elo) =>
  Int
  -> (b -> eli -> (b,elo))
  -> b
  -> MEnumeratee (IOBuffer r eli) (IOBuffer r elo) pr a
mapAccum maxlen f acc inner = do
  offsetPtr <- liftIO (newFp 0)
  outPtr    <- liftIO (mallocForeignPtrArray maxlen)
  transform offsetPtr outPtr acc inner
  where
    transform op bp st = eneeCheckIfDone (\k -> liftI (feed op bp st k))
    feed op bp st k stream = case stream of
      I.Chunk buf -> guardNull buf (liftI (feed op bp st k)) $ do
        (st', mapped) <- liftIO (IB.mapAccumBuffer f op bp st buf)
        transform op bp st' (k (I.Chunk mapped))
      _ -> idone (liftI k) stream

-- | Enumeratee that converts the stream by repeatedly running a converter
-- iteratee.
--
-- While the inner iteratee wants input, 'I.isStreamFinished' is consulted:
-- if it reports an exception, the enumeratee finishes with the inner
-- iteratee and an EOF carrying that exception; otherwise the converter is
-- run, its output buffer is fed to the inner iteratee as a chunk, and the
-- process repeats.
convStream :: (MonadCatchIO pr, Storable elo, Storable eli) =>
  MIteratee (IOBuffer r eli) pr (IOBuffer r elo)
  -> MEnumeratee (IOBuffer r eli)
                 (IOBuffer r elo)
                 pr
                 a
convStream fi = eneeCheckIfDone check
  where
    check k = do
      finished <- MIteratee I.isStreamFinished
      case finished of
        Nothing  -> convertNext k
        Just err -> idone (liftI k) (I.EOF (Just err))
    convertNext k = do
      converted <- fi
      convStream fi (k (I.Chunk converted))

-- | Enumeratee that passes at most the given number of elements to the
-- inner iteratee.
--
-- A count of zero or less passes nothing and returns the inner iteratee
-- unchanged.  (Previously only an exact zero was handled; a negative count
-- reached 'IB.splitAt' with a negative index.)
--
-- For a positive count:
--
-- * if the inner iteratee is already done, its result is returned with an
--   empty leftover chunk;
--
-- * if it carries an error, that error is thrown;
--
-- * otherwise chunks are forwarded.  A buffer no longer than the remaining
--   count is forwarded whole and the count reduced by its length; a longer
--   buffer is split with 'IB.splitAt', the first part forwarded and the
--   second part reported as leftover.  Null chunks are skipped, and a
--   non-chunk stream element is forwarded to the inner iteratee and ends
--   the enumeratee.
takeUpTo :: (MonadCatchIO pr, Storable el) =>
  Int
  -> MEnumeratee (IOBuffer r el) (IOBuffer r el) pr a
takeUpTo i iter
  | i <= 0    = return iter
  | otherwise = MIteratee $ I.Iteratee $ \od oc ->
      I.runIter (unwrap iter) (onDone od oc) (onCont od oc)
  where
    onDone od _ x _ = od (return x) (I.Chunk empty)
    -- i is strictly positive here, so the inner iteratee is always fed.
    onCont od oc k Nothing =
      I.runIter (unwrap . liftI $ step i (MIteratee . k)) od oc
    onCont od oc _ (Just e) = I.runIter (I.throwErr e) od oc
    step n k (I.Chunk buf) = guardNull buf (liftI (step n k)) $ do
      blen <- liftIO $ IB.length buf
      if blen <= n then takeUpTo (n - blen) $ k (I.Chunk buf)
         else do
           (s1, s2) <- liftIO $ IB.splitAt buf n
           idone (k (I.Chunk s1)) (I.Chunk s2)
    step _ k str = idone (k str) str

-- ---------------------------------------------
-- drivers and enumerators

-- | Callback that reads one block from a 'Handle' into a reusable buffer.
--
-- Arguments: the offset cell shared with the resulting 'IOBuffer', the
-- destination array, the number of /bytes/ to request from 'hGetBuf', the
-- handle, and a unit state.
--
-- Result:
--
-- * @Left e@ if the read raised an exception;
--
-- * @Right ((False, ()), empty)@ if zero bytes were read;
--
-- * otherwise @Right ((True, ()), buf)@, where the offset cell has been
--   reset to 0 and @buf@ is created with 'IB.createIOBuffer' over the
--   destination array.
--
-- 'hGetBuf' reports the number of bytes read, so that value is converted to
-- a number of whole elements before it is used as the buffer length; any
-- trailing bytes that do not make up a complete element are not included.
-- NOTE(review): this assumes 'IB.createIOBuffer' takes its length in
-- elements, which is how the rest of this module sizes buffers -- confirm
-- against "Data.MutableIter.IOBuffer".
makeHandleCallback :: forall r el m. (MonadCatchIO m, Storable el) =>
  ForeignPtr Int
  -> ForeignPtr el
  -> Int
  -> Handle
  -> ()
  -> m (Either SomeException ((Bool, ()), (IOBuffer r el)))
makeHandleCallback offp fp bsize h () = liftIO $ do
  n' <- withForeignPtr fp $ \p -> (CIO.try $ (hGetBuf h p bsize) ::
            IO (Either SomeException Int))
  case n' of
    Left e -> return $ Left e
    Right 0 -> return $ Right ((False, ()), empty)
    Right n -> do
      withForeignPtr offp $ flip poke 0
      -- Guard against a zero-sized element type to avoid dividing by zero.
      let elSize    = max 1 (sizeOf (undefined :: el))
          elemCount = n `div` elSize
      return . (\s -> Right ((True, ()), s)) $
                 IB.createIOBuffer (fromIntegral elemCount) offp fp

-- | Enumerate the contents of a 'Handle' with a user-supplied exception
-- handler.
--
-- An array of @bs@ elements and an offset cell (initialised to 0) are
-- allocated once; each read requests @bs * sizeOf el@ bytes through
-- 'makeHandleCallback'.  The callback, the handler and a unit state are
-- passed to 'I.enumFromCallbackCatch'.
enumHandleCatch
 ::
 forall e r el m a.(I.IException e, MonadCatchIO m, Storable el) =>
  Int
  -> Handle
  -> (e -> m (Maybe I.EnumException))
  -> MIteratee (IOBuffer r el) m a
  -> m (MIteratee (IOBuffer r el) m a)
enumHandleCatch bs h handler i = do
  bufp <- liftIO (mallocForeignPtrArray bs)
  offp <- liftIO (newFp 0)
  liftM MIteratee
    (I.enumFromCallbackCatch
       (makeHandleCallback offp bufp bytesPerRead h)
       handler
       ()
       (unwrap i))
  where
    bytesPerRead = bs * sizeOf (undefined :: el)

-- | Enumerate a 'Handle', servicing seek requests from the iteratee.
--
-- An 'I.SeekException' is handled by seeking the handle to the requested
-- absolute offset.  If the seek raises an 'IOException', it is returned
-- wrapped in 'I.EnumException'; otherwise the handler returns 'Nothing'.
enumHandleRandom :: forall r el m a.(MonadCatchIO m, Storable el) =>
  Int -- ^ Buffer size (number of elements per read)
  -> Handle
  -> MIteratee (IOBuffer r el) m a
  -> m (MIteratee (IOBuffer r el) m a)
enumHandleRandom bs h i = enumHandleCatch bs h seekHandler i
  where
    seekHandler (I.SeekException off) = do
      outcome <- liftIO (CIO.try (hSeek h AbsoluteSeek (fromIntegral off)))
      return $ case outcome of
        Left ioErr -> Just (I.EnumException (ioErr :: IOException))
        Right _    -> Nothing

-- | Run an iteratee over the contents of a file.
--
-- The file is opened in binary read mode, enumerated with
-- 'enumHandleRandom' using the given buffer size, and the resulting
-- iteratee is run with 'I.run'.  'CIO.bracket' ensures the handle is closed
-- afterwards.
fileDriverRandom :: (MonadCatchIO m, Storable el) =>
  Int
  -> (forall r. MIteratee (IOBuffer r el) m a)
  -> FilePath
  -> m a
fileDriverRandom bufsize iter filepath =
  CIO.bracket
    (liftIO (openBinaryFile filepath ReadMode))
    (\h -> liftIO (hClose h))
    (\h -> enumHandleRandom bufsize h iter >>= I.run . unwrap)

-- | Allocate a 'ForeignPtr' for a single value and store the given value
-- in it.
newFp :: Storable a => a -> IO (ForeignPtr a)
newFp a = do
  fp <- mallocForeignPtr
  withForeignPtr fp (\ptr -> poke ptr a)
  return fp