{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE CPP #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE DeriveDataTypeable #-}
{-# LANGUAGE ImplicitParams #-}
{-# LANGUAGE RankNTypes #-}
module Z.IO.Buffered
(
Input(..), Output(..)
, BufferedInput
, newBufferedInput
, newBufferedInput'
, readBuffer
, unReadBuffer
, readParser
, readExactly, readExactly'
, readToMagic, readToMagic'
, readLine, readLine'
, readAll, readAll'
, BufferedOutput
, newBufferedOutput
, newBufferedOutput'
, writeBuffer
, writeBuilder
, flushBuffer
, Source, Sink
, sourceBuffer
, sinkBuffer
, sourceFromList
, (>+>)
, parseSource
, collectSource
, concatSource
, zipSource
, (>>>>=)
, BufferedException(..)
, V.defaultChunkSize
, V.smallChunkSize
) where
import Control.Monad
import Control.Monad.Primitive (ioToPrim, primToIO)
import Control.Monad.ST
import Data.IORef
import Data.Primitive.PrimArray
import Data.Typeable
import Data.Word
import Data.Bits (unsafeShiftR)
import Foreign.Ptr
import Z.Data.Array
import qualified Z.Data.Builder.Base as B
import qualified Z.Data.Parser as P
import qualified Z.Data.Vector as V
import qualified Z.Data.Vector.Base as V
import Z.Data.PrimRef.PrimIORef
import Z.Foreign
import Z.IO.Exception
-- | Devices that bytes can be read from.
--
-- 'readInput' receives the device, a destination pointer and an 'Int', and
-- produces an 'Int'. The only caller in this module, 'readBuffer', passes the
-- capacity of its buffer as the 'Int' and uses the result as the number of
-- bytes now present at the start of that buffer. A result of 0 therefore
-- yields an empty chunk, which the readers in this module treat as end of
-- input.
class Input i where
    readInput :: HasCallStack => i -> Ptr Word8 -> Int -> IO Int
-- | Devices that bytes can be written to.
--
-- 'writeOutput' receives the device, a source pointer and an 'Int'; the
-- callers in this module pass the number of bytes to write starting at that
-- pointer. They reset their buffer index to 0 right after the call returns,
-- so they rely on the whole range having been consumed by then.
class Output o where
    writeOutput :: HasCallStack => o -> Ptr Word8 -> Int -> IO ()
-- | An input device bundled with a receive buffer and a push-back slot.
data BufferedInput i = BufferedInput
    { bufInput :: i
      -- ^ The wrapped device; 'readBuffer' reads from it through 'readInput'.
    , bufPushBack :: {-# UNPACK #-} !(IORef V.Bytes)
      -- ^ Bytes handed back by 'unReadBuffer'. 'readBuffer' returns (and
      -- clears) them before it touches the device again.
    , inputBuffer :: {-# UNPACK #-} !(IORef (MutablePrimArray RealWorld Word8))
      -- ^ The pinned array the next device read fills. It lives in an 'IORef'
      -- because 'readBuffer' installs a fresh array whenever it freezes the
      -- current one into a result.
    }
-- | An output device bundled with a fixed write buffer.
data BufferedOutput o = BufferedOutput
    { bufOutput :: o
      -- ^ The wrapped device; buffered bytes reach it through 'writeOutput'.
    , bufIndex :: {-# UNPACK #-} !Counter
      -- ^ Number of bytes currently waiting in 'outputBuffer' (0 means empty);
      -- also the offset at which 'writeBuffer' appends.
    , outputBuffer :: {-# UNPACK #-} !(MutablePrimArray RealWorld Word8)
      -- ^ The pinned staging array, allocated once by 'newBufferedOutput''.
    }
-- | Wrap an input device in a 'BufferedInput' whose receive buffer holds
-- 'V.defaultChunkSize' bytes.
newBufferedInput :: i -> IO (BufferedInput i)
newBufferedInput dev = newBufferedInput' V.defaultChunkSize dev
-- | Wrap an output device in a 'BufferedOutput' whose write buffer holds
-- 'V.defaultChunkSize' bytes.
newBufferedOutput :: o -> IO (BufferedOutput o)
newBufferedOutput dev = newBufferedOutput' V.defaultChunkSize dev
-- | Wrap an output device in a 'BufferedOutput' with an explicit buffer size.
--
-- The buffer starts empty (index 0) and is a pinned array.
newBufferedOutput' :: Int    -- ^ buffer size in bytes; negative values are clamped to 0
                   -> o      -- ^ the device to wrap
                   -> IO (BufferedOutput o)
newBufferedOutput' bufSiz o =
    BufferedOutput o
        <$> newPrimIORef 0
        <*> newPinnedPrimArray (max bufSiz 0)
-- | Wrap an input device in a 'BufferedInput' with an explicit buffer size.
--
-- The push-back slot starts empty and the receive buffer is a pinned array.
newBufferedInput' :: Int    -- ^ buffer size in bytes; negative values are clamped to 0
                  -> i      -- ^ the device to wrap
                  -> IO (BufferedInput i)
newBufferedInput' bufSiz i = do
    pushBackRef <- newIORef V.empty
    bufRef <- newIORef =<< newPinnedPrimArray (max bufSiz 0)
    return (BufferedInput i pushBackRef bufRef)
-- | Fetch the next chunk of bytes.
--
-- * If bytes were pushed back with 'unReadBuffer', they are returned as-is and
--   the push-back slot is cleared; the device is not touched.
--
-- * Otherwise one 'readInput' call fills the current buffer. When the read
--   covers at least half of the buffer, the buffer itself is frozen into the
--   result and a fresh pinned buffer of the same size is installed for the
--   next read. When it covers less than half, the bytes are copied into a
--   right-sized array so the buffer can be reused.
--
-- An empty result means the device returned 0 bytes.
readBuffer :: (Input i, HasCallStack) => BufferedInput i -> IO V.Bytes
readBuffer BufferedInput{..} = do
    pending <- readIORef bufPushBack
    if not (V.null pending)
    then do
        writeIORef bufPushBack V.empty
        return pending
    else do
        cur <- readIORef inputBuffer
        cap <- getSizeofMutablePrimArray cur
        n <- readInput bufInput (mutablePrimArrayContents cur) cap
        if n >= cap `quot` 2
        then do
            -- The current buffer becomes the result, so it must not be
            -- written to again: swap in a replacement first. A zero-sized
            -- buffer needs no replacement.
            when (cap /= 0) $
                writeIORef inputBuffer =<< newPinnedPrimArray cap
            frozen <- unsafeFreezePrimArray cur
            return $! V.fromArr frozen 0 n
        else do
            -- Short read: copy out so a mostly-empty buffer is not retained.
            small <- newPrimArray n
            copyMutablePrimArray small 0 cur 0 n
            frozen <- unsafeFreezePrimArray small
            return $! V.fromArr frozen 0 n
-- | Read up to @n@ bytes.
--
-- Chunks are pulled with 'readBuffer' until @n@ bytes have been gathered; any
-- surplus from the last chunk is pushed back with 'unReadBuffer'. If the
-- input ends first (an empty chunk is read), whatever was gathered so far is
-- returned, so the result may be shorter than @n@. Use 'readExactly'' to turn
-- that case into an exception.
--
-- A request for zero (or fewer) bytes returns 'V.empty' immediately without
-- touching the device, so it can never block on a read.
readExactly :: (HasCallStack, Input i) => Int -> BufferedInput i -> IO V.Bytes
readExactly n0 h0
    | n0 <= 0   = return V.empty
    | otherwise = V.concat `fmap` go n0
  where
    -- @go n@ gathers chunks until @n@ more bytes are collected or input ends.
    go n = do
        chunk <- readBuffer h0
        let l = V.length chunk
        case compare l n of
            GT -> do
                -- Got more than needed: keep the prefix, hand back the rest.
                let (lastChunk, rest) = V.splitAt n chunk
                unReadBuffer rest h0
                return [lastChunk]
            EQ -> return [chunk]
            LT | l == 0    -> return []        -- end of input, stop short
               | otherwise -> (chunk :) `fmap` go (n - l)
-- | Like 'readExactly', but throws 'ShortReadException' when the result does
-- not contain exactly @n@ bytes.
readExactly' :: (HasCallStack, Input i) => Int -> BufferedInput i -> IO V.Bytes
readExactly' n h = do
    v <- readExactly n h
    when (V.length v /= n) $ throwIO (ShortReadException callStack)
    return v
-- | Read chunks with 'readBuffer' until an empty chunk is returned.
--
-- The chunks are returned in the order they were read; the terminating empty
-- chunk is included as the last element of the list.
readAll :: (HasCallStack, Input i) => BufferedInput i -> IO [V.Bytes]
readAll h = go []
  where
    -- @seen@ holds the chunks read so far, newest first.
    go seen = do
        chunk <- readBuffer h
        let seen' = chunk : seen
        if V.null chunk
        then return $! reverse seen'
        else go seen'
-- | Read everything with 'readAll' and concatenate the chunks into a single
-- 'V.Bytes'.
readAll' :: (HasCallStack, Input i) => BufferedInput i -> IO V.Bytes
readAll' i = fmap V.concat (readAll i)
-- | Exceptions thrown by this module.
--
-- * 'ParseException': thrown by 'parseSource' when the parser returns a
--   'P.ParseError'; carries that error and the call stack.
--
-- * 'ShortReadException': thrown by 'readExactly'' and 'readToMagic'' when the
--   input ends before the requested data was read; carries the call stack.
data BufferedException = ParseException P.ParseError CallStack
                       | ShortReadException CallStack deriving (Show, Typeable)

-- NOTE(review): both conversions come from "Z.IO.Exception"; presumably they
-- place this type under that module's IO exception hierarchy -- confirm there.
instance Exception BufferedException where
    toException = ioExceptionToException
    fromException = ioExceptionFromException
-- | Push bytes back so that the next 'readBuffer' returns them first.
--
-- The bytes are placed in front of anything already pushed back. Pushing back
-- an empty vector does nothing.
unReadBuffer :: (HasCallStack, Input i) => V.Bytes -> BufferedInput i -> IO ()
unReadBuffer pb' BufferedInput{..}
    | V.null pb' = return ()
    | otherwise  = modifyIORef' bufPushBack (pb' `V.append`)
-- | Run a parser against the buffered input.
--
-- The parser is started on one chunk from 'readBuffer' and is fed further
-- chunks from 'readBuffer' on demand. Whatever input the parser leaves
-- unconsumed is pushed back with 'unReadBuffer'. The parse outcome is
-- returned as an 'Either'; no exception is thrown for a parse failure.
readParser :: (HasCallStack, Input i) => P.Parser a -> BufferedInput i -> IO (Either P.ParseError a)
readParser p i = do
    (leftover, result) <- P.parseChunks p (readBuffer i) =<< readBuffer i
    unReadBuffer leftover i
    return result
-- | Read up to and including the first occurrence of the given byte.
--
-- Bytes following the magic byte in the same chunk are pushed back with
-- 'unReadBuffer'. If the input ends (an empty chunk is read) before the byte
-- is seen, everything read so far is returned, which may be empty.
readToMagic :: (HasCallStack, Input i) => Word8 -> BufferedInput i -> IO V.Bytes
readToMagic magic0 h0 = V.concat `fmap` go h0 magic0
  where
    go h magic = readBuffer h >>= step
      where
        step chunk
            -- End of input: stop with what we have.
            | V.null chunk = return []
            -- Found it: keep everything through the magic byte.
            | Just i <- V.elemIndex magic chunk = do
                let (lastChunk, rest) = V.splitAt (i+1) chunk
                unReadBuffer rest h
                return [lastChunk]
            -- Not in this chunk: keep it and continue.
            | otherwise = (chunk :) `fmap` go h magic
-- | Like 'readToMagic', but throws 'ShortReadException' if the input ends
-- (an empty chunk is read) before the magic byte is seen.
--
-- On success the result always ends with the magic byte.
readToMagic' :: (HasCallStack, Input i) => Word8 -> BufferedInput i -> IO V.Bytes
readToMagic' magic0 h0 = V.concat `fmap` go h0 magic0
  where
    go h magic = readBuffer h >>= step
      where
        step chunk
            | V.null chunk = throwIO (ShortReadException callStack)
            | Just i <- V.elemIndex magic chunk = do
                let (lastChunk, rest) = V.splitAt (i+1) chunk
                unReadBuffer rest h
                return [lastChunk]
            | otherwise = (chunk :) `fmap` go h magic
-- | Read one line, terminated by @\\n@ or @\\r\\n@.
--
-- * 'Nothing' means the input ended with no bytes left.
--
-- * @'Just' line@ is the line without its terminator (an empty vector for an
--   empty line).
--
-- * If the input ends before a line feed is seen, the partial last line is
--   returned unchanged. Nothing is stripped in that case, because
--   'readToMagic' then returns bytes that do not end in a line feed.
readLine :: (HasCallStack, Input i) => BufferedInput i -> Source V.Bytes
readLine i = do
    bs@(V.PrimVector arr s l) <- readToMagic 10 i
    return $
        if l == 0
        then Nothing
        else case bs `V.indexMaybe` (l-1) of
            -- Terminated line: drop "\n", or "\r\n" when a CR precedes it.
            Just 10 -> case bs `V.indexMaybe` (l-2) of
                Just 13 -> Just (V.PrimVector arr s (l-2))
                _       -> Just (V.PrimVector arr s (l-1))
            -- End of input reached without a line feed: keep every byte.
            _ -> Just bs
-- | Read one line, terminated by @\\n@ or @\\r\\n@, and return it without the
-- terminator.
--
-- This is built on 'readToMagic'', so 'ShortReadException' is thrown when the
-- input ends before a line feed is seen. Because a successful
-- 'readToMagic'' always ends with the line feed, the empty-result check below
-- is not expected to fire; it mirrors 'readLine'.
readLine' :: (HasCallStack, Input i) => BufferedInput i -> Source V.Bytes
readLine' i = do
    bs@(V.PrimVector arr s l) <- readToMagic' 10 i
    -- Number of terminator bytes to drop: 2 for "\r\n", otherwise 1.
    let trailer = case bs `V.indexMaybe` (l-2) of
            Just 13 -> 2
            _       -> 1
    return $ if l == 0
        then Nothing
        else Just (V.PrimVector arr s (l - trailer))
-- | Write bytes through the buffer.
--
-- * Buffer empty: payloads larger than a quarter of the buffer size
--   (@bufSiz \`unsafeShiftR\` 2@) are written straight to the device; smaller
--   ones are copied into the buffer.
--
-- * Buffer not empty: the payload is appended if it fits; otherwise the
--   buffered bytes are written to the device, the buffer is reset, and the
--   write is retried against the now-empty buffer.
writeBuffer :: (HasCallStack, Output o) => BufferedOutput o -> V.Bytes -> IO ()
writeBuffer o@BufferedOutput{..} v@(V.PrimVector ba s l) = do
    used <- readPrimIORef bufIndex
    cap <- getSizeofMutablePrimArray outputBuffer
    -- Copy the payload into the buffer at the given offset and advance the index.
    let stash at = do
            copyPrimArray outputBuffer at ba s l
            writePrimIORef bufIndex (at + l)
    if used == 0
    then if l <= cap `unsafeShiftR` 2
        then stash 0
        else withPrimVectorSafe v (writeOutput bufOutput)
    else if used + l > cap
        then do
            -- Not enough room: drain what is buffered, then try again.
            withMutablePrimArrayContents outputBuffer $ \ ptr -> writeOutput bufOutput ptr used
            writePrimIORef bufIndex 0
            writeBuffer o v
        else stash used
-- | Run a 'B.Builder' against the output buffer.
--
-- The builder starts writing into 'outputBuffer' at the current index, using
-- 'B.OneShotAction' with a callback that writes the bytes it is given
-- directly to the device. When the builder finishes, the buffer it ended up
-- with is reconciled with 'outputBuffer':
--
-- * same array: just record the new index;
--
-- * a different array holding at least as many bytes as 'outputBuffer' can
--   hold: write it to the device and mark 'outputBuffer' empty;
--
-- * a different array holding fewer bytes: copy them to the start of
--   'outputBuffer' and record the index.
writeBuilder :: (HasCallStack, Output o) => BufferedOutput o -> B.Builder a -> IO ()
writeBuilder BufferedOutput{..} (B.Builder build) = do
    start <- readPrimIORef bufIndex
    cap <- getSizeofMutablePrimArray outputBuffer
    void $ primToIO (build (B.OneShotAction emit) (finish cap) (B.Buffer outputBuffer start))
  where
    -- Callback given to the builder: send the bytes to the device.
    emit :: V.Bytes -> ST RealWorld ()
    emit bytes = ioToPrim (withPrimVectorSafe bytes (writeOutput bufOutput))

    -- Final build step; @cap@ is the size of 'outputBuffer' and @end@ is the
    -- number of bytes the builder left in @buf@.
    finish :: Int -> a -> B.BuildStep RealWorld
    finish cap _ (B.Buffer buf end) = ioToPrim $ do
        if sameMutablePrimArray buf outputBuffer
        then writePrimIORef bufIndex end
        else if end >= cap
            then do
                withMutablePrimArrayContents buf $ \ ptr -> writeOutput bufOutput ptr end
                writePrimIORef bufIndex 0
            else do
                copyMutablePrimArray outputBuffer 0 buf 0 end
                writePrimIORef bufIndex end
        return []
-- | Write any buffered bytes to the device and mark the buffer empty.
--
-- Does nothing when the buffer is already empty.
flushBuffer :: (HasCallStack, Output o) => BufferedOutput o -> IO ()
flushBuffer BufferedOutput{..} = do
    pending <- readPrimIORef bufIndex
    unless (pending == 0) $ do
        withMutablePrimArrayContents outputBuffer $ \ ptr -> writeOutput bufOutput ptr pending
        writePrimIORef bufIndex 0
-- | A pull-based producer: each run of the action yields the next element,
-- or 'Nothing' once the producer is exhausted.
type Source a = IO (Maybe a)
-- | A consumer as a pair: the first component accepts one element, the second
-- is run when the producer is exhausted (see '>>>>='; 'sinkBuffer' uses it to
-- flush).
type Sink a = (a -> IO (), IO ())
-- | Turn a 'BufferedInput' into a 'Source' of chunks: each pull performs one
-- 'readBuffer', and an empty chunk is reported as 'Nothing'.
sourceBuffer :: (HasCallStack, Input i) => BufferedInput i -> Source V.Bytes
{-# INLINABLE sourceBuffer #-}
sourceBuffer i = do
    chunk <- readBuffer i
    return $! if V.null chunk then Nothing else Just chunk
-- | Turn a 'BufferedOutput' into a 'Sink': elements are written with
-- 'writeBuffer' and the finishing action is 'flushBuffer'.
sinkBuffer :: (HasCallStack, Output o) => BufferedOutput o -> Sink V.Bytes
{-# INLINABLE sinkBuffer #-}
sinkBuffer o = (push, done)
  where
    push = writeBuffer o
    done = flushBuffer o
-- | Build a 'Source' that yields the elements of a list in order.
--
-- The remaining elements are kept in an 'IORef'; each pull pops the head and
-- an empty list yields 'Nothing'.
sourceFromList :: [a] -> IO (Source a)
{-# INLINABLE sourceFromList #-}
sourceFromList xs0 = pop <$> newIORef xs0
  where
    pop ref = readIORef ref >>= \ remaining -> case remaining of
        []       -> return Nothing
        (y : ys) -> writeIORef ref ys >> return (Just y)
-- | Chain two sources: the result yields everything from the first source and
-- then everything from the second. Defined through 'concatSource'.
(>+>) :: Source a -> Source a -> IO (Source a)
{-# INLINABLE (>+>) #-}
(>+>) first second = concatSource (first : [second])
-- | Drain a 'Source' into a list, preserving the order in which elements
-- were produced. Pulls until the source yields 'Nothing'.
collectSource :: Source a -> IO [a]
{-# INLINABLE collectSource #-}
collectSource input = gather []
  where
    -- @acc@ holds the elements pulled so far, newest first.
    gather acc = input >>= maybe (return $! reverse acc) (\ x -> gather (x : acc))
-- | Turn a 'Source' of byte chunks into a 'Source' of parsed values.
--
-- Each pull runs the parser once. Input the parser leaves unconsumed is kept
-- in an 'IORef' and used as the starting input of the next pull. When there
-- is no leftover input, a fresh chunk is pulled first; if the underlying
-- source is exhausted at that point the result is 'Nothing'.
--
-- Throws 'ParseException' when the parser fails.
parseSource :: HasCallStack => P.Parser a -> Source V.Bytes -> IO (Source a)
{-# INLINABLE parseSource #-}
parseSource p source = do
    trailingRef <- newIORef V.empty
    return (next trailingRef)
  where
    next ref = do
        leftover <- readIORef ref
        if V.null leftover
        then source >>= maybe (return Nothing) (run ref)
        else run ref leftover

    -- Run the parser starting from @input@, remember what it left over, and
    -- deliver the value or throw on failure.
    run ref input = do
        (rest, r) <- P.parseChunks p more input
        writeIORef ref rest
        either (\ e -> throwIO (ParseException e callStack)) (return . Just) r

    -- Supplies further chunks to the parser; an exhausted source is presented
    -- to it as an empty chunk.
    more = maybe V.empty id `fmap` source
-- | Chain a list of sources into one.
--
-- The result pulls from the first remaining source; when that source yields
-- 'Nothing' it is dropped and the next one is tried. 'Nothing' is returned
-- once no sources remain.
concatSource :: [Source a] -> IO (Source a)
{-# INLINABLE concatSource #-}
concatSource ss0 = pull <$> newIORef ss0
  where
    pull ref = readIORef ref >>= \ pending -> case pending of
        []           -> return Nothing
        (cur : rest) -> cur >>= \ m -> case m of
            Nothing -> writeIORef ref rest >> pull ref
            Just _  -> return m
-- | Pair up the elements of two sources.
--
-- Each pull takes one element from the first source and then one from the
-- second, and yields 'Nothing' as soon as either is exhausted.
--
-- The second source is only pulled when the first one produced an element,
-- so an element of the second source is never consumed and thrown away just
-- because the first source has already ended. (If the second source ends
-- first, the element already taken from the first source is dropped; a
-- 'Source' has no way to push it back.)
zipSource :: Source a -> Source b -> Source (a,b)
{-# INLINABLE zipSource #-}
zipSource inputA inputB = do
    mA <- inputA
    case mA of
        Nothing -> return Nothing
        Just a  -> fmap (\ b -> (a, b)) `fmap` inputB
-- | Drive a 'Source' into a 'Sink': every element the source yields is passed
-- to the sink's write action, and once the source yields 'Nothing' the sink's
-- finishing action is run.
(>>>>=) :: Source a -> Sink a -> IO ()
{-# INLINABLE (>>>>=) #-}
input >>>>= (write, flush) = pump
  where
    pump = input >>= maybe flush (\ x -> write x >> pump)