{-# LANGUAGE CPP, TypeFamilies #-}
#if __GLASGOW_HASKELL__ >= 800
{-# OPTIONS_GHC -Wno-orphans #-}
#else
{-# OPTIONS_GHC -fno-warn-orphans #-}
#endif
module Bio.Streaming
    ( MonadIO(..)
    , MonadMask
    , ByteStream

    , streamFile
    , streamHandle
    , streamInput
    , streamInputs
    , withOutputFile

    , UnwantedTerminal(..)
    , protectTerm
    , psequence
    , progressGen
    , progressNum
    , progressPos

    , mergeStreams
    , mergeStreamsBy
    , mergeStreamsOn

    , module Streaming
    , module Streaming.Prelude )
  where

import Bio.Bam.Header
import Bio.Prelude
import Bio.Streaming.Bytes
import Bio.Util.Numeric                     ( showNum )
import Streaming                     hiding ( (<>) )
import Streaming.Internal                   ( Stream(..) )
import Streaming.Prelude                    ( each )
import System.IO                            ( hIsTerminalDevice )

import qualified Streaming.Prelude      as Q

-- | Orphan instance (hence the orphan-warning pragma at the top of this
-- module): a 'Stream' over a 'PrimMonad' is itself a 'PrimMonad'.
-- Primitive actions are run in the base monad and lifted into the
-- stream, so the state token type is that of the base monad.
instance (Functor f, PrimMonad m) => PrimMonad (Stream f m) where
    type PrimState (Stream f m) = PrimState m
    primitive act = lift (primitive act)

{- | Default buffer size in elements; this is the size 'streamHandle'
passes to 'hGetContentsN'.

The value is 2*1024*1024.  We often want to merge many files, and in
that setting a single read should take longer than a seek.  Assuming a
rotating hard drive, that puts a sensible buffer size at somewhat more
than one MB.  SSDs would surely be happy with less, but a large buffer
does no harm there either.
-}
defaultBufSize :: Int
defaultBufSize = 2 * mebi
  where
    mebi = 1024 * 1024

-- | Opens the named file in binary read mode and hands its contents to
-- the continuation as a 'ByteStream' (see 'streamHandle').  Opening and
-- closing the handle are the acquire and release steps of a 'bracket'
-- around the continuation, so the stream must not be used after the
-- continuation has returned.
streamFile :: (MonadIO m, MonadMask m) => FilePath -> (ByteStream m () -> m r) -> m r
streamFile f k = bracket acquire release use
  where
    acquire   = liftIO (openBinaryFile f ReadMode)
    release h = liftIO (hClose h)
    use h     = k (streamHandle h)
{-# INLINE streamFile #-}

-- | Turns a 'Handle' into a 'ByteStream' by calling 'hGetContentsN'
-- with 'defaultBufSize' as its size argument.  Nothing in this function
-- closes the handle; 'streamFile' is the variant that manages it.
--
-- NOTE(review): whether 'hGetContentsN' itself closes the handle at end
-- of input is not visible from this module -- confirm before relying on
-- either behavior.
streamHandle :: MonadIO m => Handle -> ByteStream m ()
streamHandle = hGetContentsN defaultBufSize
{-# INLINE streamHandle #-}

-- | Streams one named input to the continuation.  The filename \"-\"
-- stands for 'stdin', which is streamed directly; any other name is
-- opened (and closed again) by 'streamFile'.
streamInput :: (MonadIO m, MonadMask m) => FilePath -> (ByteStream m () -> m r) -> m r
streamInput f k
    | f == "-"  = k (streamHandle stdin)
    | otherwise = streamFile f k
{-# INLINE streamInput #-}

{- | Reads multiple inputs in sequence.

The continuation receives a stream with one 'ByteStream' layer per
input.  A file is opened only when its layer is reached, so only one
file is open at a time and the layers must be consumed in order.  The
filename \"-\" refers to 'stdin'; if no filenames are given at all,
'stdin' is read.

A file's handle is closed after its layer has been run to the end.
There is no exception handling here, so a layer that is abandoned early
leaves its handle to whatever cleanup happens elsewhere.
-}
streamInputs :: MonadIO m => [FilePath] -> (Stream (ByteStream m) m () -> r) -> r
streamInputs fs k = k $ case fs of
    [] -> fromStdin
    _  -> mapM_ fromPath fs
  where
    -- a single layer fed from standard input
    fromStdin = yields (streamHandle stdin)

    -- a single layer per filename: open, stream, close
    fromPath "-"  = fromStdin
    fromPath path = yields $
        liftIO (openBinaryFile path ReadMode) >>= \h ->
        streamHandle h >> liftIO (hClose h)
{-# INLINE streamInputs #-}

-- | Exception thrown by 'protectTerm' when 'stdout' turns out to be a
-- terminal device.
data UnwantedTerminal = UnwantedTerminal
    deriving (Show, Typeable)

instance Exception UnwantedTerminal where
    displayException = const "cowardly refusing to write binary data to terminal"

{- | Protects the terminal from binary junk.

If @s@ is a 'Stream', then @protectTerm s@ first checks whether 'stdout'
is a terminal device.  If it is, 'UnwantedTerminal' is thrown; if it is
not, the result behaves exactly like @s@.  The check happens when the
resulting stream is run, not when it is constructed.  This is most
usefully composed with functions that might otherwise write binary data
to an interactive terminal.
-}
protectTerm :: (Functor f, MonadIO m) => Stream f m r -> Stream f m r
protectTerm str =
    liftIO (hIsTerminalDevice stdout) >>= \onTerminal ->
    when onTerminal (liftIO (throwM UnwantedTerminal)) >> str
{-# INLINE protectTerm #-}

{- | Like 'Streaming.sequence', but parallel.

This runs each element of a stream of actions.  Each action is forked
into its own thread as soon as it is read from the input, and its
result is yielded in input order.  At most @np@ actions are outstanding
at any time; once that many are queued, the oldest one is awaited before
more input is read.  A non-positive @np@ behaves like @1@, so the number
of outstanding actions is always bounded.

An exception raised by an action is caught in its thread and rethrown in
the consumer at the point where that action's result is due.
-}
psequence :: MonadIO m => Int -> Stream (Of (IO a)) m b -> Stream (Of a) m b
psequence np = go emptyQ
  where
    -- if the queue is full, wait for the head element to complete
    -- ('>=' rather than '==', so that np <= 0 cannot switch the bound off)
    go !qq s = case popQ qq of
        Just (a,qq') | lengthQ qq >= np -> reap a >>= wrap . (:> go qq' s)
        _                               -> lift (inspect s) >>= go' qq

    -- if we have room for input, we get input
    go' !qq (Right (k :> s)) = liftIO (spawn k) >>= \a -> go (pushQ a qq) s
    go' !qq (Left         r) = goE r qq

    -- input ended, empty the queue
    goE r !qq = case popQ qq of
        Nothing      -> pure r
        Just (a,qq') -> reap a >>= wrap . (:> goE r qq')

    -- fork the action; its result or exception ends up in the MVar
    spawn :: IO a -> IO (MVar (Either SomeException a))
    spawn k = newEmptyMVar                  >>= \mv ->
              forkIO (try k >>= putMVar mv) >>
              return mv

    -- block until the action has finished, rethrowing its exception if any
    reap mv = liftIO (takeMVar mv) >>= either (liftIO . throwM) return


-- A very simple FIFO queue: a cached length, a front list holding the
-- oldest elements (oldest first) and a back list holding the newest
-- elements (newest first).
-- Invariants, for q = QQ n front back:
--   n == length front + length back
--   null front ==> n == 0   (the front only runs dry when the queue is
--                            empty; 'popQ' relies on this)

data QQ a = QQ !Int [a] [a]

-- The queue without elements.
emptyQ :: QQ a
emptyQ = QQ 0 [] []

-- Number of elements, read from the cached count.
lengthQ :: QQ a -> Int
lengthQ (QQ n _ _) = n

-- Appends an element at the back.  If the front is empty, the element
-- is moved there right away to keep the second invariant.
pushQ :: a -> QQ a -> QQ a
pushQ x (QQ n front back)
    | null front = QQ (n+1) (reverse (x:back)) []
    | otherwise  = QQ (n+1) front (x:back)

-- Removes the oldest element, if any.  When the front is about to run
-- dry, the reversed back list becomes the new front.
popQ :: QQ a -> Maybe (a, QQ a)
popQ (QQ n front back) = case front of
    []     -> Nothing
    [x]    -> Just (x, QQ (n-1) (reverse back) [])
    x:rest -> Just (x, QQ (n-1) rest back)


-- | Merges two streams using the elements' own 'Ord' instance; this is
-- 'mergeStreamsBy' applied to 'compare'.  The result carries the return
-- values of both input streams as a pair.
mergeStreams :: (Monad m, Ord a)
             => Stream (Of a) m r -> Stream (Of a) m s -> Stream (Of a) m (r, s)
mergeStreams = mergeStreamsBy compare
{-# INLINE mergeStreams #-}

-- | Merges two streams, ordering elements by a key computed with the
-- supplied function; this is 'mergeStreamsBy' applied to a comparison
-- of the keys.  The result carries the return values of both input
-- streams as a pair.
mergeStreamsOn :: (Monad m, Ord b)
               => (a -> b) -> Stream (Of a) m r -> Stream (Of a) m s -> Stream (Of a) m (r, s)
mergeStreamsOn f = mergeStreamsBy byKey
  where
    byKey x y = compare (f x) (f y)
{-# INLINE mergeStreamsOn #-}

-- | Merges two streams according to a comparison function.
--
-- The left stream is inspected first: its effects are run until it
-- either presents an element or ends.  Only then is the right stream
-- inspected in the same way.  Once both present an element, the smaller
-- one according to @cmp@ is emitted; on a tie the element of the left
-- stream is emitted first.  When one stream ends, the remainder of the
-- other is passed through unchanged.  The return values of both streams
-- are returned as a pair.
mergeStreamsBy :: Monad m
               => (a -> a -> Ordering)
               -> Stream (Of a) m r -> Stream (Of a) m s -> Stream (Of a) m (r, s)
mergeStreamsBy cmp = merge
  where
    -- left stream is done or needs an effect run
    merge (Return r) right      = (,) r <$> right
    merge (Effect m) right      = Effect $ liftM (`merge` right) m
    -- left stream has an element ready; same questions for the right
    merge left       (Return s) = flip (,) s <$> left
    merge left       (Effect m) = Effect $ liftM (merge left) m
    -- both have an element ready: emit the smaller, left-biased on ties
    merge left@(Step (x :> xs)) right@(Step (y :> ys)) =
        case cmp x y of
            GT -> Step (y :> merge left ys)
            _  -> Step (x :> merge xs right)
{-# INLINABLE mergeStreamsBy #-}

-- | A general progress indicator that logs some message after a set
-- number of records have passed through.
--
-- The stream is passed through unchanged.  Records are counted starting
-- at 1, and whenever the count is a multiple of the interval @sz@, the
-- message function is applied to the count and the current record and
-- the result is logged with 'logString_'.  An interval of 0 produces no
-- periodic messages (instead of failing with a division by zero).  When
-- the input ends, an empty string is logged once.
progressGen :: MonadLog m => (Int -> a -> String) -> Int -> Q.Stream (Q.Of a) m r -> Q.Stream (Q.Of a) m r
progressGen msg sz = go 0
  where
    -- n is the number of records seen so far
    go   !n = lift . Q.next >=> either fin (step $ succ n)
    step !n (a,s) = do when (due n) . lift . logString_ $ msg n a
                       Q.cons a (go n s)
    fin r = r <$ lift (logString_ "")

    -- guard the modulus: `mod` by zero would throw
    due n = sz /= 0 && n `mod` sz == 0

-- | A simple progress indicator that logs the number of records seen so
-- far.  Each message is the given label, a space, and the count as
-- formatted by 'showNum'; the interval is handed on to 'progressGen'.
progressNum :: MonadLog m => String -> Int -> Q.Stream (Q.Of a) m r -> Q.Stream (Q.Of a) m r
progressNum msg = progressGen countLine
  where
    countLine n _ = msg ++ (' ' : showNum n)

-- | A simple progress indicator that logs a position every set number
-- of passed records.
--
-- The first argument extracts a reference index and a position from a
-- record.  The reference is looked up in the given 'Refs' with 'getRef'
-- and its 'sq_name' is used in the message, which reads
-- @label name:position@.  The interval and the stream are handed on to
-- 'progressGen'.
progressPos :: MonadLog m
            => (a -> (Refseq, Int)) -> String -> Refs -> Int
            -> Q.Stream (Q.Of a) m r -> Q.Stream (Q.Of a) m r
progressPos f msg refs = progressGen locate
  where
    -- the running count supplied by 'progressGen' is not used
    locate _ x =
        let (!ref, !pos) = f x
            !name        = unpack . sq_name $ getRef refs ref
        in  msg ++ " " ++ name ++ ":" ++ showNum pos