module Codec.Netstring.Enumerator (decode) where
import Control.Exception (ErrorCall(..))
import qualified Data.ByteString as B
import qualified Data.ByteString.Char8 as B8
import qualified Data.Enumerator as E
import Data.Enumerator ((>>==))
-- | Decode one netstring (@<length>:<payload>,@) from the outer stream,
-- handing the payload bytes to the inner step.
--
-- The decimal length, the @:@ separator and the trailing @,@ are all
-- consumed; any bytes after the comma remain in the outer stream. Delimiter
-- mismatches and premature EOF are reported as 'ErrorCall' errors through
-- the iteratee.
decode :: Monad m => E.Enumeratee B.ByteString B.ByteString m b
decode inner =
  iterCount >>= \size ->
  dropChar ':' >>
  (decodeImpl size inner >>= \result ->
   dropChar ',' >>
   return result)
-- | Parse the decimal length prefix of a netstring.
--
-- A leading @0@ is taken as the complete length: any digit that follows it
-- is left in the stream, where the caller's @:@ check rejects it (netstrings
-- forbid leading zeros). Otherwise ASCII digits are accumulated across
-- chunks until the first non-digit byte, which is returned as leftover
-- input.
--
-- Fails through 'err' on EOF, or when the prefix does not begin with a
-- digit. Previously the second case escaped as a lazy pattern-match
-- exception outside the iteratee's error channel.
iterCount :: Monad m => E.Iteratee B.ByteString m Integer
iterCount = loopBytes parseFirst
  where
    isDigit w = w >= 0x30 && w <= 0x39

    -- x is non-empty: loopBytes filters out empty chunks.
    parseFirst x xs
      | B8.head x == '0' = E.yield 0 (E.Chunks (B.drop 1 x : xs))
      | otherwise = parseLoop [] (x : xs)

    -- Collect digit runs (newest first) until a chunk contains a non-digit.
    parseLoop acc [] = loopBytesL (parseLoop acc)
    parseLoop acc (x : xs) = case B.span isDigit x of
      (pre, post)
        | B.null post -> parseLoop (pre : acc) xs
        | otherwise -> finishParse (B.concat (reverse (pre : acc))) post xs

    -- digits holds only ASCII digits, so readInteger can only fail when no
    -- digit was seen at all; post is non-empty (guarded above).
    finishParse digits post rest = case B8.readInteger digits of
      Just (count, _) -> E.yield count (E.Chunks (post : rest))
      Nothing -> err $ concat [ "unexpected ", show (B8.head post)
                              , "; expecting a digit"
                              ]
-- | Pass exactly @count@ payload bytes from the outer stream to the inner
-- step, then yield the resulting step. Bytes beyond the payload are left as
-- leftover input.
--
-- If the inner iteratee stops (yields or errors) before all @count@ bytes
-- have been seen, the rest of the payload is still read and discarded. The
-- outer stream therefore always ends up just past the payload, and the
-- caller's trailing-delimiter check sees the right byte.
decodeImpl :: Monad m => Integer -> E.Enumeratee B.ByteString B.ByteString m b
decodeImpl count = loop 0
  where
    len = toInteger . B.length

    -- n: number of payload bytes already taken from the outer stream.
    loop n (E.Continue k) = loopBytesL (parseLoop n k [])
    loop n step
      | n >= count = E.yield step (E.Chunks [])
      | otherwise = loopBytesL (skipLoop (count - n) step)

    -- Hand the accumulated chunks (kept newest-first) to the inner step.
    feed k acc next = k (E.Chunks (reverse acc)) >>== next

    parseLoop n k acc [] = feed k acc (loop n)
    parseLoop n k acc (x : xs)
      | n' >= count =
          -- x completes the payload; its tail belongs to whatever follows.
          let (inData, extra) = B.splitAt (fromInteger (count - n)) x
              leftover = E.Chunks (extra : xs)
          in feed k (inData : acc) $
               E.checkDoneEx leftover (\k' -> E.yield (E.Continue k') leftover)
      | otherwise = parseLoop n' k (x : acc) xs
      where
        n' = n + len x

    -- The inner step finished early: drop r more payload bytes, then yield
    -- it with the post-payload bytes as leftover.
    skipLoop r step [] = loopBytesL (skipLoop r step)
    skipLoop r step (x : xs)
      | len x >= r = E.yield step (E.Chunks (B.drop (fromInteger r) x : xs))
      | otherwise = skipLoop (r - len x) step xs
-- | Consume a single byte, failing through 'err' unless it equals @c@.
-- Any bytes after it are returned to the stream as leftover input.
dropChar :: Monad m => Char -> E.Iteratee B.ByteString m ()
dropChar c = loopBytes expect
  where
    -- chunk is non-empty: loopBytes filters out empty chunks.
    expect chunk rest
      | found == c = E.yield () (E.Chunks (B.drop 1 chunk : rest))
      | otherwise = err ("unexpected " ++ show found ++ "; expecting " ++ show c)
      where
        found = B8.head chunk
-- | Wait for input until at least one non-empty chunk arrives, then call
-- the continuation with the first non-empty chunk and the remaining
-- non-empty chunks of that batch. Empty chunks are skipped, and EOF is an
-- error.
loopBytes :: Monad m
          => (B.ByteString -> [B.ByteString] -> E.Iteratee B.ByteString m b)
          -> E.Iteratee B.ByteString m b
loopBytes k = E.continue await
  where
    await (E.Chunks chunks)
      | first : others <- filter (not . B.null) chunks = k first others
      | otherwise = E.continue await
    await E.EOF = err "unexpected EOF"
-- | Like 'loopBytes', but passes the non-empty chunks to the continuation as
-- one list, which is therefore guaranteed to be non-empty.
loopBytesL :: Monad m
           => ([B.ByteString] -> E.Iteratee B.ByteString m b)
           -> E.Iteratee B.ByteString m b
loopBytesL k = loopBytes whole
  where
    whole first rest = k (first : rest)
-- | Abort the iteratee with an 'ErrorCall' whose message is prefixed with
-- this decoder's qualified name.
err :: Monad m => String -> E.Iteratee a m b
err msg = E.throwError (ErrorCall prefixed)
  where
    prefixed = "Codec.Netstring.Enumerator.decode: " ++ msg