-- | Helpers for running the two ends of a conduit pipeline concurrently.
--
-- 'buffer' and '$$&' connect a producer and a consumer through a bounded
-- in-memory STM queue; 'bufferToFile' additionally spills queued elements
-- to temporary files; 'gatherFrom' and 'drainTo' bridge a conduit with a
-- worker action that talks to a 'TBQueue' directly.
module Data.Conduit.Async ( buffer
, ($$&)
, bufferToFile
, gatherFrom
, drainTo
) where
import Control.Applicative
import Control.Concurrent.Async.Lifted
import Control.Concurrent.STM
import Control.Concurrent.STM.TBChan
import Control.Exception.Lifted
import Control.Monad hiding (forM_)
import Control.Monad.IO.Class
import Control.Monad.Loops
import Control.Monad.Trans.Class
import Control.Monad.Trans.Control
import Control.Monad.Trans.Resource
import Data.Conduit
import qualified Data.Conduit.Binary as CB
import qualified Data.Conduit.Cereal as C
import qualified Data.Conduit.List as CL
import Data.Foldable (forM_)
import Data.Serialize as Cereal
import System.Directory (removeFile)
import System.IO
-- | Run a producer and a consumer in two separate asyncs, handing elements
-- from one to the other through a bounded 'TBQueue'.
--
-- The first argument is the capacity passed to 'newTBQueueIO'. The two
-- asyncs are tied together with 'link2', and the value returned is the
-- result of the consumer side.
buffer :: (MonadBaseControl IO m, MonadIO m)
       => Int -> Producer m a -> Consumer a m b -> m b
buffer size input output = do
    queue <- liftIO (newTBQueueIO size)
    control $ \runInIO ->
        withAsync (runInIO (produce queue)) $ \producerThread ->
            withAsync (runInIO (drain queue $$ output)) $ \consumerThread -> do
                link2 producerThread consumerThread
                wait consumerThread
  where
    -- Blocking write of a single message onto the queue.
    push queue msg = liftIO (atomically (writeTBQueue queue msg))

    -- Forward every upstream element wrapped in 'Just', then send a single
    -- 'Nothing' as the end-of-stream marker.
    produce queue = do
        input $$ CL.mapM_ (push queue . Just)
        push queue Nothing

    -- Yield queued elements downstream until the 'Nothing' marker arrives.
    drain queue =
        liftIO (atomically (readTBQueue queue))
            >>= maybe (return ()) (\x -> yield x >> drain queue)
-- | Operator form of 'buffer' with the queue capacity fixed at 64.
($$&) :: (MonadIO m, MonadBaseControl IO m)
      => Producer m a -> Consumer a m b -> m b
input $$& output = buffer 64 input output
-- | State shared between the sending and receiving sides of 'bufferToFile'.
data BufferContext m a = BufferContext
{ chan :: TBChan a
-- ^ Bounded in-memory channel; created with the @memorySize@ capacity.
, restore :: TChan (Source m a)
-- ^ Sources queued by the sender that replay elements previously written
-- to a temporary file; the receiver runs these before reading 'chan'.
, slotsFree :: TVar (Maybe Int)
-- ^ Optional budget, initialised from @fileMax@, that the sender checks
-- before spilling a batch to disk. 'Nothing' means no check is made.
, done :: TVar Bool
-- ^ Set to 'True' by the sender once its input is exhausted; read by the
-- receiver to decide when to stop looping.
}
-- | Like 'buffer', but when the in-memory channel is full the queued
-- elements are serialised (via their 'Serialize' instance) to a temporary
-- file, to be replayed later by the consumer side.
--
-- The producer and the consumer run in two separate asyncs tied together
-- with 'link2'; the value returned is the result of the consumer side.
--
-- When a disk limit is given, the producer blocks before spilling a batch
-- that would not fit in the remaining budget. Space is reserved when a
-- batch is handed off for writing and given back once that batch has been
-- replayed to the consumer.
bufferToFile :: (MonadBaseControl IO m, MonadIO m, MonadResource m, Serialize a)
             => Int            -- ^ Capacity of the in-memory channel.
             -> Maybe Int      -- ^ Limit on the number of elements held on
                               --   disk at once; 'Nothing' for no limit.
             -> FilePath       -- ^ Directory in which temporary files are
                               --   created.
             -> Producer m a   -- ^ Upstream producer.
             -> Consumer a m b -- ^ Downstream consumer.
             -> m b
bufferToFile memorySize fileMax tempDir input output = do
    context <- liftIO $ BufferContext
        <$> newTBChanIO memorySize
        <*> newTChanIO
        <*> newTVarIO fileMax
        <*> newTVarIO False
    control $ \runInIO ->
        withAsync (runInIO $ sender context) $ \input' ->
        withAsync (runInIO $ recv context $$ output) $ \output' -> do
            link2 input' output'
            wait output'
  where
    -- Producer side: push each element into the in-memory channel. If the
    -- channel is full, spill its contents to disk first. The STM
    -- transaction returns the (non-STM) file-writing action, which 'join'
    -- then runs outside the transaction.
    sender BufferContext {..} = do
        input $$ awaitForever $ \x -> join $ liftIO $ atomically $ do
            written <- tryWriteTBChan chan x
            if written
                then return $ return ()
                else do
                    action <- persistChan
                    writeTBChan chan x
                    return action
        -- Tell the receiver that no more elements will arrive.
        liftIO $ atomically $ writeTVar done True
      where
        persistChan = do
            xs <- exhaust chan
            mslots <- readTVar slotsFree
            let len = length xs
            -- With a disk limit in place, block (STM retry) until the
            -- batch fits in the remaining budget.
            forM_ mslots $ \slots -> check (len < slots)

            -- Queue a source for the receiver that waits for the file to
            -- be fully written, streams its contents back, returns the
            -- reserved budget and finally releases (deletes) the file.
            filePath <- newEmptyTMVar
            writeTChan restore $ do
                (path, key) <- liftIO $ atomically $ takeTMVar filePath
                CB.sourceFile path $= do
                    C.conduitGet Cereal.get
                    liftIO $ atomically $
                        modifyTVar slotsFree (fmap (+ len))
                    release key

            case xs of
                [] -> return $ return ()
                _  -> do
                    -- Reserve disk budget for this batch. It must be
                    -- subtracted here (and added back after replay);
                    -- adding it would make the budget grow without bound
                    -- and the limit would never be enforced.
                    modifyTVar slotsFree (fmap (subtract len))
                    return $ do
                        (key, (path, h)) <- allocate
                            (openTempFile tempDir "conduit.bin")
                            (\(path, h) -> hClose h >> removeFile path)
                        liftIO $ do
                            CL.sourceList xs $= C.conduitPut put
                                $$ CB.sinkHandle h
                            hClose h
                            -- Publishing the path unblocks the restore
                            -- source queued above.
                            atomically $ putTMVar filePath (path, key)

    -- Consumer side: prefer pending restore sources (older, spilled
    -- elements) over the in-memory channel; stop once the channel has been
    -- drained after the sender has marked itself as done.
    recv BufferContext {..} = loop where
        loop = do
            (src, exit) <- liftIO $ atomically $ do
                maction <- tryReadTChan restore
                case maction of
                    Just action -> return (action, False)
                    Nothing -> do
                        xs <- exhaust chan
                        isDone <- readTVar done
                        return (CL.sourceList xs, isDone)
            src
            unless exit loop

    -- Read everything currently in the channel without blocking.
    exhaust chan = whileM (not <$> isEmptyTBChan chan) (readTBChan chan)
-- | Turn a worker action that writes to a 'TBQueue' into a 'Producer'.
--
-- A queue with the given capacity is created and the worker is started in
-- an async. The producer then repeatedly empties the queue, yields what it
-- found, and checks on the worker: it stops when the worker has finished
-- and re-throws the worker's exception if it failed.
gatherFrom :: (MonadIO m, MonadBaseControl IO m)
           => Int                   -- ^ Capacity of the queue.
           -> (TBQueue o -> m ())   -- ^ Worker that fills the queue.
           -> Producer m o
gatherFrom size scatter = do
    queue  <- liftIO (newTBQueueIO size)
    worker <- lift (async (scatter queue))
    collect worker queue >>= lift . restoreM
  where
    collect worker queue = do
        -- Drain the queue and poll the worker in one transaction, so any
        -- element written before the worker finished is seen here.
        (pending, status) <- liftIO . atomically $ do
            items   <- whileM (not <$> isEmptyTBQueue queue) (readTBQueue queue)
            outcome <- pollSTM worker
            return (items, outcome)
        Prelude.mapM_ yield pending
        case status of
            Nothing          -> collect worker queue
            Just (Right st)  -> return st
            Just (Left err)  -> liftIO (throwIO (err :: SomeException))
-- | Turn a worker action that reads from a 'TBQueue' into a 'Consumer'.
--
-- A queue with the given capacity is created and the worker is started in
-- an async. Each result of 'await' (including the 'Nothing' seen once
-- upstream is exhausted) is written to the queue for as long as the worker
-- is still running. When the worker finishes, its result becomes the
-- result of the consumer; if it failed, its exception is re-thrown.
drainTo :: (MonadIO m, MonadBaseControl IO m)
        => Int                          -- ^ Capacity of the queue.
        -> (TBQueue (Maybe i) -> m r)   -- ^ Worker that empties the queue.
        -> Consumer i m r
drainTo size gather = do
    queue  <- liftIO (newTBQueueIO size)
    worker <- lift (async (gather queue))
    feed worker queue >>= lift . restoreM
  where
    feed worker queue = do
        next <- await
        -- The transaction decides what happens next and returns that
        -- continuation; it is run only after the transaction commits.
        continue <- liftIO . atomically $ do
            status <- pollSTM worker
            case status of
                Nothing -> do
                    writeTBQueue queue next
                    return (feed worker queue)
                Just (Right st) ->
                    return (return st)
                Just (Left err) ->
                    return (liftIO (throwIO (err :: SomeException)))
        continue