-----------------------------------------------------------------------------
-- |
-- Module: Codec.Netstring.Enumerator
-- Copyright: 2010 John Millikin
-- License: GPL-3
--
-- Maintainer: jmillikin@gmail.com
-- Portability: portable
--
-----------------------------------------------------------------------------
module Codec.Netstring.Enumerator
    ( decode
    ) where

import Control.Exception (ErrorCall(..))
import qualified Data.ByteString as B
import qualified Data.ByteString.Char8 as B8
import qualified Data.Enumerator as E
import Data.Enumerator ((>>==))

-- | Decode a single netstring (@\<length\>:\<payload\>,@) from the outer
-- stream, feeding the payload bytes to the inner iteratee.
--
-- If the inner iteratee stops accepting input before the whole payload has
-- been read, the rest of the payload is discarded so that the trailing
-- @','@ can still be checked.
--
-- Malformed input (a non-numeric length, a wrong delimiter, or EOF before
-- the netstring is complete) is reported via 'E.throwError' with an
-- 'ErrorCall'.
decode :: Monad m => E.Enumeratee B.ByteString B.ByteString m b
decode s = do
    count <- iterCount
    dropChar ':'
    b <- decodeImpl count s
    dropChar ','
    return b

-- | Parse the decimal length prefix of a netstring. The first non-digit
-- byte is left in the stream.
iterCount :: Monad m => E.Iteratee B.ByteString m Integer
iterCount = loopBytes parseFirst where
    -- Digit checking is only performed for ASCII digits
    isDigit w = w >= 0x30 && w <= 0x39

    -- A netstring count, if it begins with '0', must /be/ zero. Any other
    -- count must begin with a digit, so that 'finishParse' always has at
    -- least one digit to read. ('B.head' is safe: 'loopBytes' only passes
    -- non-empty chunks.)
    parseFirst x xs
        | B8.head x == '0' = E.yield 0 (E.Chunks (B.drop 1 x : xs))
        | isDigit (B.head x) = parseLoop [] (x:xs)
        | otherwise = err $ concat
            [ "unexpected ", show (B8.head x)
            , "; expecting a digit"
            ]

    -- Read bytes until a non-digit character is encountered. 'acc' holds
    -- the digit chunks read so far, most recent first.
    parseLoop acc [] = loopBytesL (parseLoop acc)
    parseLoop acc (x:xs) = case B.span isDigit x of
        (pre, post)
            | B.null post -> parseLoop (pre:acc) xs
            | otherwise -> finishParse (reverse (pre:acc)) (post:xs)

    -- With all digits read, combine them and parse the count; the
    -- non-digit remainder is pushed back into the stream.
    finishParse chunks extra = case B8.readInteger (B.concat chunks) of
        Just (count, _) -> E.yield count (E.Chunks extra)
        Nothing -> err "invalid length"

-- | Feed exactly @count@ bytes from the outer stream to the inner iteratee,
-- then yield the inner iteratee's step. Any bytes past the payload are left
-- in the outer stream.
decodeImpl :: Monad m => Integer -> E.Enumeratee B.ByteString B.ByteString m b
decodeImpl count = loop 0 where
    -- 'n' is the number of payload bytes consumed from the outer stream.
    loop n (E.Continue k) = loopBytesL (parseLoop n k [])
    -- The inner iteratee no longer accepts input (it yielded or failed):
    -- skip the unread payload so the caller can find the trailing ','.
    loop n step = drain (count - n) step

    feed k acc next = k (E.Chunks (reverse acc)) >>== next
    len = toInteger . B.length

    -- Read bytes from the parent enumerator until the full 'count' is
    -- reached. If the end of this batch of chunks is reached before the
    -- count, feed all bytes to the child iteratee and continue looping.
    parseLoop n k acc [] = feed k acc (loop n)
    parseLoop n k acc (x:xs)
        | n' >= count = finish
        | otherwise = parseLoop n' k (x:acc) xs
      where
        n' = n + len x
        -- This chunk finishes the netstring: feed the payload part to the
        -- child and hand the rest back to the parent enumerator, whether or
        -- not the child is still accepting input.
        (inData, extra) = B.splitAt (fromInteger (count - n)) x
        finish = feed k (inData:acc) (\step -> E.yield step (E.Chunks (extra:xs)))

    -- Discard 'remaining' payload bytes from the outer stream, then yield
    -- 'step' unchanged. EOF before that is an error (via 'loopBytes').
    drain remaining step
        | remaining <= 0 = E.yield step (E.Chunks [])
        | otherwise = loopBytes (dropBytes remaining step)
    dropBytes remaining step x xs
        | lenX >= remaining =
            E.yield step (E.Chunks (B.drop (fromInteger remaining) x : xs))
        | otherwise = case xs of
            [] -> drain (remaining - lenX) step
            (y:ys) -> dropBytes (remaining - lenX) step y ys
      where
        lenX = len x

-- | Read a single ASCII character from the enumerator and check that it
-- matches the expected value. If so, it is discarded and the rest of the
-- bytes are left in the stream; otherwise an error is thrown.
dropChar :: Monad m => Char -> E.Iteratee B.ByteString m ()
dropChar c = loopBytes step where
    -- 'B8.head' is safe: 'loopBytes' only passes non-empty chunks.
    step x xs
        | B8.head x == c = E.yield () (E.Chunks (B.drop 1 x : xs))
        | otherwise = err $ concat
            [ "unexpected ", show (B8.head x)
            , "; expecting ", show c
            ]

-- | Wait until at least one non-empty chunk arrives, then pass the first
-- non-empty chunk and the remaining non-empty chunks to @k@. Empty chunks
-- are dropped; EOF is an error.
loopBytes :: Monad m
          => (B.ByteString -> [B.ByteString] -> E.Iteratee B.ByteString m b)
          -> E.Iteratee B.ByteString m b
loopBytes k = E.continue step where
    step E.EOF = err "unexpected EOF"
    step (E.Chunks chunks) = case filter (not . B.null) chunks of
        [] -> E.continue step
        (x:xs) -> k x xs

-- | Like 'loopBytes', but passes all non-empty chunks as a single list.
loopBytesL :: Monad m
           => ([B.ByteString] -> E.Iteratee B.ByteString m b)
           -> E.Iteratee B.ByteString m b
loopBytesL k = loopBytes (\x xs -> k (x:xs))

-- | Abort the iteratee with an 'ErrorCall' prefixed by this module's
-- decoder name.
err :: Monad m => String -> E.Iteratee a m b
err msg = E.throwError (ErrorCall ("Codec.Netstring.Enumerator.decode: " ++ msg))