-----------------------------------------------------------------------------
-- |
-- Module: Text.XML.LibXML.Enumerator
-- Copyright: 2010 John Millikin
-- License: MIT
--
-- Maintainer: jmillikin@gmail.com
-- Portability: portable
--
-----------------------------------------------------------------------------
module Text.XML.LibXML.Enumerator
	( Event (..)
	, parseBytesIO
	, parseTextIO
	, parseBytesST
	, parseTextST
	) where
import qualified Data.ByteString as B
import qualified Data.Enumerator as E
import Data.Enumerator ((>>==))
import qualified Data.Text as T
import qualified Data.XML.Types as X
import qualified Text.XML.LibXML.SAX as SAX

import Control.Exception (ErrorCall(..))
import Control.Monad.IO.Class (MonadIO, liftIO)
import Control.Monad.Trans.Class (lift)
import Control.Monad.ST (ST)
import qualified Data.STRef as ST
import qualified Data.IORef as IO

-- | An XML parsing event. Each constructor is produced by one of the SAX
-- callbacks registered in 'setCallbacks'. Within one parse step, events
-- are delivered in the order the parser reported them.
data Event
	= EventBeginDocument
	-- ^ Reported by the parser's begin-document callback.
	| EventEndDocument
	-- ^ Reported by the parser's end-document callback.
	| EventBeginElement X.Name [X.Attribute]
	-- ^ Start of an element: its name and its attributes.
	| EventEndElement X.Name
	-- ^ End of an element, carrying the element name.
	| EventCharacters T.Text
	-- ^ Character data. NOTE(review): a single text node may be split
	-- across several consecutive events; confirm against the SAX layer.
	| EventComment T.Text
	-- ^ The text of a comment.
	| EventInstruction X.Instruction
	-- ^ A processing instruction.
	deriving (Show, Eq)

-- | Create a SAX parser running in 'IO' together with the two refs its
-- callbacks write to:
--
-- * an 'IO.IORef' collecting every 'Event', stored newest-first (callers
--   'reverse' it after each parse step);
--
-- * an 'IO.IORef' holding the error message last reported by the parser.
--   Each report overwrites the previous one, so only the most recent
--   error survives until the ref is read.
--
-- The 'Maybe' 'T.Text' argument is passed unchanged to 'SAX.newParserIO'.
-- NOTE(review): presumably a document name used in messages; confirm.
newParserIO :: Maybe T.Text -> IO (SAX.Parser IO, IO.IORef [Event], IO.IORef (Maybe T.Text))
newParserIO name = do
	errRef <- IO.newIORef Nothing
	p <- SAX.newParserIO (\msg -> IO.writeIORef errRef (Just msg)) name
	eventRef <- IO.newIORef []
	-- Store each new cons cell already evaluated, instead of letting the
	-- lazy modifyIORef pile up a chain of (e:) thunks between reads.
	-- Every callback reports True back to the parser.
	let addEvent e = do
		events <- IO.readIORef eventRef
		IO.writeIORef eventRef $! e : events
		return True
	setCallbacks p addEvent
	return (p, eventRef, errRef)

-- | Create a SAX parser running in 'ST' together with the two refs its
-- callbacks write to:
--
-- * an 'ST.STRef' collecting every 'Event', stored newest-first (callers
--   'reverse' it after each parse step);
--
-- * an 'ST.STRef' holding the error message last reported by the parser.
--   Each report overwrites the previous one, so only the most recent
--   error survives until the ref is read.
--
-- The 'Maybe' 'T.Text' argument is passed unchanged to 'SAX.newParserST'.
-- NOTE(review): presumably a document name used in messages; confirm.
newParserST :: Maybe T.Text -> ST s (SAX.Parser (ST s), ST.STRef s [Event], ST.STRef s (Maybe T.Text))
newParserST name = do
	errRef <- ST.newSTRef Nothing
	p <- SAX.newParserST (\msg -> ST.writeSTRef errRef (Just msg)) name
	eventRef <- ST.newSTRef []
	-- Store each new cons cell already evaluated, instead of letting the
	-- lazy modifySTRef pile up a chain of (e:) thunks between reads.
	-- Every callback reports True back to the parser.
	let addEvent e = do
		events <- ST.readSTRef eventRef
		ST.writeSTRef eventRef $! e : events
		return True
	setCallbacks p addEvent
	return (p, eventRef, errRef)

-- | Install one callback on @p@ for every SAX event kind this module
-- exposes. Each callback wraps its arguments in the matching 'Event'
-- constructor and hands it to @addEvent@; the 'Bool' that @addEvent@
-- returns becomes the callback's result.
setCallbacks :: Monad m => SAX.Parser m -> (Event -> m Bool) -> m ()
setCallbacks p addEvent = do
	SAX.setCallback p SAX.parsedBeginDocument (addEvent EventBeginDocument)
	SAX.setCallback p SAX.parsedEndDocument (addEvent EventEndDocument)
	SAX.setCallback p SAX.parsedBeginElement (\elemName attrs -> addEvent (EventBeginElement elemName attrs))
	SAX.setCallback p SAX.parsedEndElement (\elemName -> addEvent (EventEndElement elemName))
	SAX.setCallback p SAX.parsedCharacters (\chars -> addEvent (EventCharacters chars))
	SAX.setCallback p SAX.parsedComment (\commentText -> addEvent (EventComment commentText))
	SAX.setCallback p SAX.parsedInstruction (\instr -> addEvent (EventInstruction instr))

-- | Enumeratee that parses a stream of 'B.ByteString' chunks as XML with a
-- SAX parser run in 'IO' (lifted through 'MonadIO'), emitting 'Event's.
-- An error reported by the parser is raised in the iteratee as an
-- 'ErrorCall'.
--
-- The 'Maybe' 'T.Text' is forwarded to 'SAX.newParserIO'. NOTE(review):
-- presumably a document name used in messages; confirm.
parseBytesIO :: MonadIO m => Maybe T.Text -> E.Enumeratee B.ByteString Event m b
parseBytesIO name = parseIO SAX.parseBytes name

-- | Enumeratee that parses a stream of 'T.Text' chunks as XML with a SAX
-- parser run in 'IO' (lifted through 'MonadIO'), emitting 'Event's. An
-- error reported by the parser is raised in the iteratee as an
-- 'ErrorCall'.
--
-- The 'Maybe' 'T.Text' is forwarded to 'SAX.newParserIO'. NOTE(review):
-- presumably a document name used in messages; confirm.
parseTextIO :: MonadIO m => Maybe T.Text -> E.Enumeratee T.Text Event m b
parseTextIO name = parseIO SAX.parseText name

-- | Shared implementation of 'parseBytesIO' and 'parseTextIO'.
--
-- A fresh parser is built with 'newParserIO' each time the returned
-- iteratee is run. Each input chunk is fed with @feed@, and end of input
-- with 'SAX.parseComplete'. Before every step the event and error refs are
-- cleared, and afterwards they are read back. As a result, 'eneeParser'
-- sees only the events (in the order reported) and the error produced by
-- that one step.
parseIO :: MonadIO m
        => (SAX.Parser IO -> a -> IO ())
        -> Maybe T.Text
        -> E.Enumeratee a Event m b
parseIO feed name step = E.Iteratee (liftIO (newParserIO name) >>= run) where
	run (parser, eventsRef, errorRef) = E.runIteratee (eneeParser onChunk onEOF step) where
		onChunk chunk = collect (feed parser chunk)
		onEOF = collect (SAX.parseComplete parser)
		-- Reset both refs, run one parser action, then report what it produced.
		collect action = liftIO $ do
			IO.writeIORef eventsRef []
			IO.writeIORef errorRef Nothing
			action
			newestFirst <- IO.readIORef eventsRef
			lastErr <- IO.readIORef errorRef
			return (reverse newestFirst, lastErr)

-- | Enumeratee that parses a stream of 'B.ByteString' chunks as XML with a
-- SAX parser run in 'ST', emitting 'Event's. An error reported by the
-- parser is raised in the iteratee as an 'ErrorCall'.
--
-- The 'Maybe' 'T.Text' is forwarded to 'SAX.newParserST'. NOTE(review):
-- presumably a document name used in messages; confirm.
parseBytesST :: Maybe T.Text -> E.Enumeratee B.ByteString Event (ST s) b
parseBytesST name = parseST SAX.parseBytes name

-- | Enumeratee that parses a stream of 'T.Text' chunks as XML with a SAX
-- parser run in 'ST', emitting 'Event's. An error reported by the parser
-- is raised in the iteratee as an 'ErrorCall'.
--
-- The 'Maybe' 'T.Text' is forwarded to 'SAX.newParserST'. NOTE(review):
-- presumably a document name used in messages; confirm.
parseTextST :: Maybe T.Text -> E.Enumeratee T.Text Event (ST s) b
parseTextST name = parseST SAX.parseText name

-- | Shared implementation of 'parseBytesST' and 'parseTextST'.
--
-- A fresh parser is built with 'newParserST' each time the returned
-- iteratee is run. Each input chunk is fed with @feed@, and end of input
-- with 'SAX.parseComplete'. Before every step the event and error refs are
-- cleared, and afterwards they are read back. As a result, 'eneeParser'
-- sees only the events (in the order reported) and the error produced by
-- that one step.
parseST :: (SAX.Parser (ST s) -> a -> ST s ())
        -> Maybe T.Text
        -> E.Enumeratee a Event (ST s) b
parseST feed name step = E.Iteratee (newParserST name >>= run) where
	run (parser, eventsRef, errorRef) = E.runIteratee (eneeParser onChunk onEOF step) where
		onChunk chunk = collect (feed parser chunk)
		onEOF = collect (SAX.parseComplete parser)
		-- Reset both refs, run one parser action, then report what it produced.
		collect action = do
			ST.writeSTRef eventsRef []
			ST.writeSTRef errorRef Nothing
			action
			newestFirst <- ST.readSTRef eventsRef
			lastErr <- ST.readSTRef errorRef
			return (reverse newestFirst, lastErr)

-- | Core enumeratee shared by the 'IO' and 'ST' front ends.
--
-- Each outer input chunk is passed to @feed@, and end of input to
-- @finish@. Both return the events produced by that step and the error
-- message, if any, that the parser reported.
--
-- * Non-empty events are sent to the inner iteratee as one 'E.Chunks'.
--
-- * An error is raised with 'E.throwError' as an 'ErrorCall'. When there
--   were events, this happens only after they were delivered and only if
--   the inner iteratee is still continuing. NOTE(review): per
--   'E.checkDone', a finished inner iteratee ends the stream, so neither
--   the error nor later chunks are processed.
--
-- * On 'E.EOF' the still-running inner step is yielded with 'E.EOF' left
--   over.
eneeParser :: Monad m
           => (a -> m ([Event], Maybe T.Text))
           -> m ([Event], Maybe T.Text)
           -> E.Enumeratee a Event m b
eneeParser feed finish = E.checkDone (E.continue . onInput) where
	onInput k E.EOF = runParser k finish atEOF
	onInput k (E.Chunks chunks) = feedAll k chunks
	
	-- After the final flush, return the inner step and leave the EOF.
	atEOF k = E.yield (E.Continue k) E.EOF
	
	feedAll k [] = E.continue (onInput k)
	feedAll k (c:cs) = runParser k (feed c) (\k' -> feedAll k' cs)
	
	-- Run one parser step, forward its events, then either fail or go on.
	runParser k action andThen = do
		(events, maybeErr) <- lift action
		let afterEvents k' = maybe (andThen k') failWith maybeErr
		case events of
			[] -> afterEvents k
			_ -> k (E.Chunks events) >>== E.checkDone afterEvents
	
	failWith err = E.throwError (ErrorCall (T.unpack err))