-----------------------------------------------------------------------------
-- |
-- Module: Codec.Netstring.Enumerator
-- Copyright: 2010 John Millikin
-- License: GPL-3
--
-- Maintainer: jmillikin@gmail.com
-- Portability: portable
--
-----------------------------------------------------------------------------
module Codec.Netstring.Enumerator (decode) where
import Control.Exception (ErrorCall(..))
import qualified Data.ByteString as B
import qualified Data.ByteString.Char8 as B8
import qualified Data.Enumerator as E
import Data.Enumerator ((>>==))

-- | Decode one netstring (@<length>:<body>,@) from the parent stream.
--
-- The decimal length prefix, the ':' separator and the trailing ',' are
-- read and checked here; only the body bytes are handed to the inner
-- iteratee. The result is the inner iteratee's step after the body has
-- been fed to it.
decode :: Monad m => E.Enumeratee B.ByteString B.ByteString m b
decode inner = iterCount >>= \len ->
	dropChar ':' >> decodeImpl len inner >>= \result ->
	dropChar ',' >> return result

-- | Read the decimal length prefix of a netstring. Reading stops at, but
-- does not consume, the first non-digit byte.
--
-- A prefix that begins with '0' is read as exactly zero, and whatever
-- follows it is left for the caller to check. A prefix that does not begin
-- with an ASCII digit is rejected with an error instead of being parsed.
-- Reaching EOF before the prefix ends is also an error (via 'loopBytes').
iterCount :: Monad m => E.Iteratee B.ByteString m Integer
iterCount = loopBytes parseFirst where
	
	-- Digit checking is only performed for ASCII digits ('0' .. '9').
	isDigit w = w >= 0x30 && w <= 0x39
	
	-- A netstring count, if it begins with '0', must /be/ zero. A count
	-- that does not begin with a digit at all is malformed; rejecting it
	-- here guarantees that 'finishParse' always sees at least one digit.
	parseFirst x xs
		| B8.head x == '0' = E.yield 0 (E.Chunks ((B.drop 1 x):xs))
		| isDigit (B.head x) = parseLoop [] (x:xs)
		| otherwise = err ("unexpected " ++ show (B8.head x) ++ "; expecting digit")
	
	-- Read bytes until a non-digit character is encountered
	parseLoop acc [] = loopBytesL (parseLoop acc)
	parseLoop acc (x:xs) = case B.span isDigit x of
		(pre, post) | B.null post -> parseLoop (pre:acc) xs
		(pre, post) -> finishParse (reverse (pre:acc)) (post:xs)
	
	-- With all bytes read, combine them and parse the count. The bytes
	-- are non-empty and composed solely of ASCII digits, so 'readInteger'
	-- consumes all of them; the 'Nothing' branch only keeps the match
	-- total, so a malformed prefix can never escape as a pattern failure.
	finishParse chunks extra = case B8.readInteger (B.concat chunks) of
		Just (count, _) -> E.yield count (E.Chunks extra)
		Nothing -> err "invalid length prefix"

-- | Stream exactly @count@ body bytes from the parent stream into the
-- child iteratee, then yield the child's step. Bytes after the body that
-- arrived in the same chunks are returned as leftover input.
--
-- If the child iteratee finishes (or fails) before all @count@ bytes have
-- been read, the rest of the body is still consumed and discarded. This
-- leaves the parent stream positioned immediately after the body.
decodeImpl :: Monad m => Integer -> E.Enumeratee B.ByteString B.ByteString m b
decodeImpl count = loop 0 where
	feed k acc next = k (E.Chunks (reverse acc)) >>== next
	len = toInteger . B.length
	
	-- 'n' is the number of body bytes read from the parent so far.
	-- While the child still accepts input, keep feeding it; otherwise
	-- skip whatever is left of the body.
	loop n (E.Continue k) = loopBytesL (parseLoop n k [])
	loop n step = skipRest n step
	
	-- Read bytes from the parent enumerator until the full 'count'
	-- is reached. If the end of the received chunks is reached before
	-- the count, feed all bytes to the child iteratee and continue
	-- looping.
	parseLoop n k acc [] = feed k acc (loop n)
	parseLoop n k acc (x:xs) = parse where
		n' = n + len x
		parse = if n' >= count then finish else keepLooping
		keepLooping = parseLoop n' k (x:acc) xs
		
		-- If this chunk of bytes finishes the netstring, feed the
		-- accumulator to the child iteratee and yield back to the
		-- parent enumerator.
		(inData, extra) = B.splitAt (fromInteger (count - n)) x
		finish = feed k (inData:acc) $ E.checkDoneEx (E.Chunks (extra:xs))
			(\k' -> E.yield (E.Continue k') (E.Chunks (extra:xs)))
	
	-- The child no longer accepts input, but 'count - n' body bytes may
	-- still be unread. Discard them so the caller sees the byte after the
	-- body, then yield the child's step unchanged.
	skipRest n step = if n >= count
		then E.yield step (E.Chunks [])
		else loopBytesL (skipChunks n step)
	
	skipChunks n step [] = skipRest n step
	skipChunks n step (x:xs) = if n' >= count
		then E.yield step (E.Chunks (B.drop (fromInteger (count - n)) x : xs))
		else skipChunks n' step xs
		where n' = n + len x

-- | Read a single byte from the stream as a character and check that it
-- equals @expected@. On a match the character is discarded and the
-- remaining received bytes are handed back as leftover input. Otherwise
-- the iteratee fails with an error naming both characters.
dropChar :: Monad m => Char -> E.Iteratee B.ByteString m ()
dropChar expected = loopBytes check where
	check bytes rest
		| found /= expected = err (concat ["unexpected ", show found, "; expecting ", show expected])
		| otherwise = E.yield () (E.Chunks (B.drop 1 bytes : rest))
		where found = B8.head bytes

-- | Wait until at least one non-empty chunk arrives, then pass the first
-- non-empty chunk and the remaining non-empty chunks to the continuation.
-- Empty chunks are ignored; reaching EOF first is an error.
loopBytes :: Monad m
          => (B.ByteString -> [B.ByteString] -> E.Iteratee B.ByteString m b)
          -> E.Iteratee B.ByteString m b
loopBytes onBytes = E.continue await where
	await E.EOF = err "unexpected EOF"
	await (E.Chunks chunks) = dispatch (filter (not . B.null) chunks)
	
	-- Nothing usable arrived yet: keep waiting for more input.
	dispatch [] = E.continue await
	dispatch (first:rest) = onBytes first rest

-- | Like 'loopBytes', but hands the continuation all received non-empty
-- chunks as one list, which is never empty.
loopBytesL :: Monad m
          => ([B.ByteString] -> E.Iteratee B.ByteString m b)
          -> E.Iteratee B.ByteString m b
loopBytesL onChunks = loopBytes rejoin where
	rejoin first rest = onChunks (first : rest)

-- | Fail the current iteratee with an 'ErrorCall'. The message is
-- prefixed with the qualified name of the public entry point, 'decode'.
err :: Monad m => String -> E.Iteratee a m b
err = E.throwError . ErrorCall . ("Codec.Netstring.Enumerator.decode: " ++)