-----------------------------------------------------------------------------
-- |
-- Module: Text.JSON.YAJL.Enumerator
-- Copyright: 2010 John Millikin
-- License: GPL-3
--
-- Maintainer: jmillikin@gmail.com
-- Portability: portable
--
-----------------------------------------------------------------------------
module Text.JSON.YAJL.Enumerator
	(
	-- * Parsing
	  parseBytesIO
	, parseBytesST
	
	-- * Generating
	, Y.GeneratorConfig (..)
	, Y.GeneratorError (..)
	, generateBytesIO
	, generateTextIO
	, generateBytesST
	, generateTextST
	) where
import Prelude hiding (null)
import qualified Prelude as Prelude
import qualified Data.ByteString as B
import qualified Data.Text as T
import qualified Data.Text.Encoding as TE
import qualified Data.Text.Lazy as TL
import qualified Data.Enumerator as E
import Data.Enumerator ((>>==))
import qualified Text.JSON.YAJL as Y
import qualified Data.JSON.Types as J

import Control.Exception as Exc
import Control.Monad.IO.Class (MonadIO, liftIO)
import Control.Monad.Trans.Class (lift)
import Control.Monad.ST (ST, stToIO, unsafeIOToST, unsafeSTToIO, RealWorld)
import qualified Data.STRef as ST
import qualified Data.IORef as IO

-- Parser {{{

-- | Incrementally parse a stream of strict 'B.ByteString' chunks into
-- 'J.Event's, driving a YAJL parser that runs in 'IO'.
--
-- Requires input to be in UTF-8
parseBytesIO :: MonadIO m => E.Enumeratee B.ByteString J.Event m b
parseBytesIO inner = do
	(parser, ref) <- liftIO newParserIO
	let
		-- Empty the event accumulator, run one parser action, then hand
		-- back the accumulated events (reversed) together with the
		-- status the action reported.
		collecting action = liftIO $ do
			IO.writeIORef ref []
			status <- action
			collected <- IO.readIORef ref
			return (reverse collected, status)
		feed chunk = collecting (Y.parseBytes parser chunk)
		finish = collecting (Y.parseComplete parser)
		consumed = liftIO (Y.getBytesConsumed parser)
	eneeParser consumed feed finish inner

-- | Incrementally parse a stream of strict 'B.ByteString' chunks into
-- 'J.Event's, driving a YAJL parser that runs in 'ST'.
--
-- Requires input to be in UTF-8
parseBytesST :: E.Enumeratee B.ByteString J.Event (ST s) b
parseBytesST inner = do
	(parser, ref) <- lift newParserST
	let
		-- Empty the event accumulator, run one parser action, then hand
		-- back the accumulated events (reversed) together with the
		-- status the action reported.
		collecting action = do
			ST.writeSTRef ref []
			status <- action
			collected <- ST.readSTRef ref
			return (reverse collected, status)
		feed chunk = collecting (Y.parseBytes parser chunk)
		finish = collecting (Y.parseComplete parser)
	eneeParser (Y.getBytesConsumed parser) feed finish inner

-- | Create an 'IO' parser together with the reference its callbacks
-- accumulate events into. Events are prepended, so the list held by the
-- reference is in reverse order of arrival.
newParserIO :: IO (Y.Parser IO, IO.IORef [J.Event])
newParserIO = do
	parser <- Y.newParserIO
	ref <- IO.newIORef []
	-- Every callback records its event and answers True
	-- (presumably "keep parsing" -- confirm against the yajl binding).
	setCallbacks parser $ \event -> do
		IO.modifyIORef ref (event :)
		return True
	return (parser, ref)

-- | Create an 'ST' parser together with the reference its callbacks
-- accumulate events into. Events are prepended, so the list held by the
-- reference is in reverse order of arrival.
newParserST :: ST s (Y.Parser (ST s), ST.STRef s [J.Event])
newParserST = do
	parser <- Y.newParserST
	ref <- ST.newSTRef []
	-- Every callback records its event and answers True
	-- (presumably "keep parsing" -- confirm against the yajl binding).
	setCallbacks parser $ \event -> do
		ST.modifySTRef ref (event :)
		return True
	return (parser, ref)

-- | Install a callback for every parser notification used by this module,
-- translating each one into the corresponding 'J.Event' and passing it to
-- @addEvent@.
setCallbacks :: Monad m => Y.Parser m -> (J.Event -> m Bool) -> m ()
setCallbacks p addEvent = do
	register Y.parsedBeginArray (addEvent J.EventBeginArray)
	register Y.parsedEndArray (addEvent J.EventEndArray)
	register Y.parsedBeginObject (addEvent J.EventBeginObject)
	register Y.parsedEndObject (addEvent J.EventEndObject)
	register Y.parsedNull (emitAtom J.AtomNull)
	register Y.parsedBoolean (emitAtom . J.AtomBoolean)
	register Y.parsedAttributeText (addEvent . J.EventAttributeName . TL.fromStrict)
	register Y.parsedText (emitAtom . J.AtomText . TL.fromStrict)
	
	-- NOTE(review): integers and doubles are both funnelled into the single
	-- 'J.AtomNumber' representation, so the distinction between the two is
	-- not preserved in the event stream. The original author flagged this
	-- conversion as a likely source of future trouble.
	register Y.parsedInteger (emitAtom . J.AtomNumber . fromInteger)
	register Y.parsedDouble (emitAtom . J.AtomNumber . toRational)
	where
	register = Y.setCallback p
	-- Wrap an atom in an event before recording it.
	emitAtom = addEvent . J.EventAtom

-- | Shared driver behind 'parseBytesIO' and 'parseBytesST'.
--
-- Arguments, in order:
--
-- * an action whose result is used as the offset into the chunk that was
--   just parsed at which unparsed input begins;
--
-- * an action that parses one chunk;
--
-- * an action that tells the parser the input has ended.
--
-- Both parsing actions return the events produced by that call together
-- with the parser status. Events are forwarded to the inner iteratee.
-- A 'Y.ParseError' or 'Y.ParseCancelled' status is raised as an
-- 'Exc.ErrorCall', as is an EOF that arrives while the parser still
-- reports 'Y.ParseContinue'.
eneeParser :: Monad m
           => m Integer
           -> (B.ByteString -> m ([J.Event], Y.ParseStatus))
           -> m ([J.Event], Y.ParseStatus)
           -> E.Enumeratee B.ByteString J.Event m b
eneeParser getBytesConsumed parseChunk parseComplete = E.checkDone (E.continue . step) where
	step k (E.Chunks xs) = parseLoop k xs
	step k E.EOF = do
		(events, status) <- lift parseComplete
		checkStatus status events E.EOF k
			(\_ -> throwError "yajl-enumerator: Unexpected EOF while parsing")
	
	-- Parse the pending chunks one at a time; ask for more input once they
	-- are exhausted.
	parseLoop k [] = E.continue (step k)
	parseLoop k (x:xs) = do
		(events, status) <- lift (parseChunk x)
		extra <- getExtra x xs status
		checkStatus status events extra k
			(\k' -> parseLoop k' xs)
	
	-- Compute the input to hand back as leftover if processing stops after
	-- chunk x. When the parser has finished, that is the unconsumed tail of
	-- x followed by the chunks not yet looked at.
	getExtra x xs Y.ParseFinished = do
		consumed <- lift getBytesConsumed
		let extraX = B.drop (fromInteger consumed) x
		return . E.Chunks $ if null extraX
			then xs
			else extraX:xs
	-- The parser wants more input, but the inner iteratee may finish after
	-- receiving this chunk's events; the chunks not yet parsed must then be
	-- returned as leftover instead of being silently dropped.
	getExtra _ xs Y.ParseContinue = return (E.Chunks xs)
	getExtra _ _ _ = return (E.Chunks [])
	
	-- Deliver events (if any) to the inner iteratee, then act on the status.
	-- If the inner iteratee is no longer accepting input after the events,
	-- E.checkDoneEx yields it immediately with extra as leftover and the
	-- status is not inspected.
	checkStatus status events extra k onContinue = iter where
		checkError k' = case status of
			Y.ParseError err -> throwError (T.unpack err)
			Y.ParseFinished -> E.yield (E.Continue k') extra
			Y.ParseContinue -> onContinue k'
			Y.ParseCancelled -> throwError "Parse cancelled"
		
		iter = if null events
			then checkError k
			else k (E.Chunks events) >>== E.checkDoneEx extra checkError
	
	throwError = E.throwError . Exc.ErrorCall

-- }}}

-- Generator {{{

-- | Containers that can be tested for emptiness. This module hides the
-- Prelude's @null@ and uses this overloaded version for lists, strict
-- byte strings and strict text alike.
class Nullable a where
	-- | 'True' when the value contains no elements.
	null :: a -> Bool

instance Nullable [a] where
	null [] = True
	null _ = False

instance Nullable B.ByteString where
	null bytes = B.null bytes

instance Nullable T.Text where
	null text = T.null text

-- | Render a stream of 'J.Event's as strict 'T.Text' chunks, from any
-- 'MonadIO' monad. The generator's byte buffer is decoded as UTF-8.
generateTextIO :: MonadIO m
               => Y.GeneratorConfig
               -> E.Enumeratee J.Event T.Text m b
generateTextIO config = generateIO bufferAsText config where
	bufferAsText gen = fmap TE.decodeUtf8 (Y.getBuffer gen)

-- | Render a stream of 'J.Event's as strict 'B.ByteString' chunks, from
-- any 'MonadIO' monad. Chunks are the generator's buffer contents as-is.
generateBytesIO :: MonadIO m
                => Y.GeneratorConfig
                -> E.Enumeratee J.Event B.ByteString m b
generateBytesIO config = generateIO (\gen -> Y.getBuffer gen) config

-- | Render a stream of 'J.Event's as strict 'T.Text' chunks inside 'ST'.
-- The generator's byte buffer is decoded as UTF-8.
generateTextST :: Y.GeneratorConfig
               -> E.Enumeratee J.Event T.Text (ST s) b
generateTextST config = generateST bufferAsText config where
	bufferAsText gen = fmap TE.decodeUtf8 (Y.getBuffer gen)

-- | Render a stream of 'J.Event's as strict 'B.ByteString' chunks inside
-- 'ST'. Chunks are the generator's buffer contents as-is.
generateBytesST :: Y.GeneratorConfig
                -> E.Enumeratee J.Event B.ByteString (ST s) b
generateBytesST config = generateST (\gen -> Y.getBuffer gen) config

-- | Build a generating enumeratee for any 'MonadIO' monad.
--
-- The first argument extracts the generator's pending output in the
-- desired chunk type; the second configures the new generator.
generateIO :: (Nullable a, MonadIO m)
           => (Y.Generator RealWorld -> ST RealWorld a)
           -> Y.GeneratorConfig
           -> E.Enumeratee J.Event a m b
generateIO getBuf config inner = do
	gen <- liftIO (stToIO (Y.newGenerator config))
	let
		-- Read the generator's pending output, then clear its buffer.
		drain = liftIO . stToIO $ do
			pending <- getBuf gen
			Y.clearBuffer gen
			return pending
		-- Feed one event to the generator, turning a thrown generator
		-- error into a Just result and success into Nothing.
		emit event = liftIO $ do
			outcome <- Exc.try (stToIO (genEventImpl gen event))
			return (either Just (const Nothing) outcome)
	eneeGenerator emit drain inner

-- | Build a generating enumeratee that runs in 'ST'.
--
-- The first argument extracts the generator's pending output in the
-- desired chunk type; the second configures the new generator.
generateST :: Nullable a
           => (Y.Generator s -> ST s a)
           -> Y.GeneratorConfig
           -> E.Enumeratee J.Event a (ST s) b
generateST getBuf config inner = do
	gen <- lift (Y.newGenerator config)
	let
		-- Read the generator's pending output, then clear its buffer.
		drain = do
			pending <- getBuf gen
			Y.clearBuffer gen
			return pending
		-- Feed one event to the generator, turning a thrown generator
		-- error into a Just result and success into Nothing. Catching
		-- requires IO, hence the round trip through the unsafe ST/IO
		-- conversions.
		emit event = unsafeIOToST $ do
			outcome <- Exc.try (unsafeSTToIO (genEventImpl gen event))
			return (either Just (const Nothing) outcome)
	eneeGenerator emit drain inner

-- | Shared driver behind the @generate*@ enumeratees.
--
-- The first argument feeds one event to the generator and reports a
-- generator error, if any. The second returns the output produced so far
-- (and is expected to reset it, as the callers in this module do).
--
-- Output is forwarded to the inner iteratee after each batch of events.
-- Generation stops when the generator reports 'Y.GenerationComplete'; any
-- other generator error is rethrown. EOF before the generator reports
-- completion raises an 'Exc.ErrorCall'.
eneeGenerator :: (Nullable a, Monad m)
              => (J.Event -> m (Maybe Y.GeneratorError))
              -> m a
              -> E.Enumeratee J.Event a m b
eneeGenerator genEvent takeBuf = E.checkDone (E.continue . step) where
	step k (E.Chunks []) = E.continue (step k)
	step k (E.Chunks xs) = parseLoop k xs
	step k E.EOF = do
		-- Probe with a throwaway null atom: the document is complete only
		-- if the generator answers with GenerationComplete.
		maybeError <- lift (genEvent (J.EventAtom J.AtomNull))
		case maybeError of
			Just Y.GenerationComplete -> E.yield (E.Continue k) E.EOF
			_ -> E.throwError (Exc.ErrorCall "yajl-enumerator: Unexpected EOF while generating")
	
	-- Feed events one at a time; flush the buffer once the batch is done.
	parseLoop k [] = checkBuf k [] (E.continue . step)
	parseLoop k events@(x:xs) = do
		maybeError <- lift (genEvent x)
		case maybeError of
			-- GenerationComplete is reported for x itself, i.e. the
			-- generator did not accept x (presumably it belongs to whatever
			-- follows the finished document -- confirm against yajl). Hand
			-- x back as leftover together with the rest instead of
			-- discarding it.
			Just Y.GenerationComplete -> checkBuf k events
				(\k' -> E.yield (E.Continue k') (E.Chunks events))
			Just err -> checkBuf k xs (\_ -> E.throwError err)
			Nothing -> parseLoop k xs
	
	-- Pass any pending output to the inner iteratee, then continue with
	-- next. If the inner iteratee stops accepting input, extra is yielded
	-- as the unconsumed events.
	checkBuf k extra next = do
		buf <- lift takeBuf
		if null buf
			then next k
			else k (E.Chunks [buf]) >>== E.checkDoneEx (E.Chunks extra) next

-- | Feed a single 'J.Event' to the generator by calling the matching
-- @Y.generate*@ action. Attribute names are written with the same call as
-- text atoms. Numbers are converted with 'fromRational' and written via
-- 'Y.generateDouble'.
genEventImpl :: Y.Generator s -> J.Event -> ST s ()
genEventImpl gen J.EventBeginObject = Y.generateBeginObject gen
genEventImpl gen J.EventEndObject = Y.generateEndObject gen
genEventImpl gen J.EventBeginArray = Y.generateBeginArray gen
genEventImpl gen J.EventEndArray = Y.generateEndArray gen
genEventImpl gen (J.EventAttributeName name) = Y.generateText gen (TL.toStrict name)
genEventImpl gen (J.EventAtom atom) = genAtom atom where
	genAtom J.AtomNull = Y.generateNull gen
	genAtom (J.AtomBoolean flag) = Y.generateBoolean gen flag
	genAtom (J.AtomNumber num) = Y.generateDouble gen (fromRational num)
	genAtom (J.AtomText text) = Y.generateText gen (TL.toStrict text)

-- }}}