{-# LANGUAGE DeriveDataTypeable #-}

-----------------------------------------------------------------------------
-- |
-- Module: Text.XML.Expat.Enumerator
-- Copyright: 2010 John Millikin
-- License: MIT
--
-- Maintainer: jmillikin@gmail.com
-- Portability: portable
--
-----------------------------------------------------------------------------
module Text.XML.Expat.Enumerator
	( Expat.Encoding (..)
	, ParseError (..)
	, parseBytesIO
	, parseTextIO
	) where
import qualified Data.ByteString as B
import qualified Data.Enumerator as E
import qualified Data.Enumerator.Text as ET
import Data.Enumerator ((>>==), ($$))
import qualified Data.Text as T
import qualified Data.Text.Encoding as TE
import qualified Data.Text.Lazy as TL
import qualified Data.XML.Types as X
import qualified Text.XML.Expat.Internal.IO as Expat

import qualified Control.Exception as Exc
import Control.Monad (forM)
import Control.Monad.IO.Class (MonadIO, liftIO)
import Data.Typeable (Typeable)
import qualified Data.IORef as IO
import Foreign.C (CString, CStringLen)

-- | A parse failure reported by Expat, carrying the error message and
-- the location Expat attached to it.
--
-- This mirrors 'Expat.XMLParseError' but is declared in this module, so
-- the 'Exc.Exception' instance below is not an orphan.
data ParseError = ParseError
	{ parseErrorMessage :: T.Text
	  -- ^ Message text taken from the Expat error.
	, parseErrorLocation :: Expat.XMLParseLocation
	  -- ^ Location value taken from the Expat error.
	}
	deriving (Typeable, Show)

-- | Uses the default 'Exc.toException' / 'Exc.fromException' methods.
instance Exc.Exception ParseError

-- | Parse a stream of 'T.Text' chunks as XML.
--
-- The text is re-encoded to UTF-8 and handed to 'parseBytesIO'. Expat's
-- encoding detection is overridden with 'Expat.UTF8' to match.
parseTextIO :: MonadIO m => E.Enumeratee T.Text X.Event m b
parseTextIO = E.joinI . (ET.encode ET.utf8 $$) . parseBytesIO (Just Expat.UTF8)

-- | Parse a stream of raw bytes as XML with Expat. Events are passed to
-- the inner iteratee as each chunk is parsed.
--
-- A fresh Expat parser, with its callbacks installed, is created each
-- time the enumeratee is run. Expat errors are raised through
-- 'E.throwError' as a 'ParseError'.
parseBytesIO :: MonadIO m
             => Maybe Expat.Encoding -- ^ If present, will override Expat's
                                     -- encoding detection.
             -> E.Enumeratee B.ByteString X.Event m b
parseBytesIO enc step = liftIO acquire >>= \(parser, ref) -> eneeParser parser ref step where
	-- Allocate the parser and install callbacks in one IO action.
	acquire = do
		parser <- Expat.newParser enc
		ref <- setCallbacks parser
		return (parser, ref)

-- | Register handlers on @p@ that convert each Expat callback into an
-- 'X.Event' and push it onto a freshly created 'IO.IORef'.
--
-- Each event is consed onto the head of the list, so the list is stored
-- newest-first. Readers must 'reverse' it to recover document order.
-- Element and attribute names are built with no namespace and no prefix.
-- Every handler returns 'True'.
setCallbacks :: Expat.Parser -> IO (IO.IORef [X.Event])
setCallbacks p = do
	ref <- IO.newIORef []
	let emit ev = IO.modifyIORef ref (ev :) >> return True
	let plainName n = X.Name n Nothing Nothing
	let peekName c = fmap plainName (peekUTF8 c)
	let peekAttr (cKey, cVal) = do
		key <- peekName cKey
		val <- peekUTF8 cVal
		return (X.Attribute key [X.ContentText val])

	Expat.setStartElementHandler p $ \_ cName cAttrs -> do
		name <- peekName cName
		attrs <- forM cAttrs peekAttr
		emit (X.EventBeginElement name attrs)

	Expat.setEndElementHandler p $ \_ cName ->
		peekName cName >>= emit . X.EventEndElement

	-- Character data arrives as a pointer+length pair, not NUL-terminated.
	Expat.setCharacterDataHandler p $ \_ cstr ->
		peekUTF8Len cstr >>= emit . X.EventContent . X.ContentText

	Expat.setProcessingInstructionHandler p $ \_ cTarget cData -> do
		target <- peekUTF8 cTarget
		payload <- peekUTF8 cData
		emit (X.EventInstruction (X.Instruction target payload))

	Expat.setCommentHandler p $ \_ cstr ->
		peekUTF8 cstr >>= emit . X.EventComment

	return ref

-- | The enumeratee that drives an Expat parser over incoming byte chunks.
--
-- Each non-empty batch of 'E.Chunks' is fed to Expat with the
-- \"final\" flag off. An empty batch is skipped without calling Expat.
-- On 'E.EOF', an empty chunk is fed with the \"final\" flag on, and the
-- inner step is then returned.
--
-- After every call into Expat, the events collected in @eventRef@ are
-- passed to the inner iteratee in document order. The ref is cleared
-- before each call and its list is reversed afterwards.
--
-- If Expat reports an error, those events are delivered first. The error
-- is then raised with 'E.throwError' as a 'ParseError', unless the inner
-- iteratee has already finished.
--
-- NOTE(review): assumes @eventRef@ is the ref 'setCallbacks' returned
-- for this same parser @p@.
eneeParser :: MonadIO m
           => Expat.Parser
           -> IO.IORef [X.Event]
           -> E.Enumeratee B.ByteString X.Event m b
eneeParser p eventRef = E.checkDone (E.continue . step) where
	-- End of input: flush Expat with a final, empty chunk.
	step k E.EOF = checkEvents k
		(\ptr -> Expat.parseChunk ptr B.empty True)
		(\k' -> E.yield (E.Continue k') E.EOF)

	step k (E.Chunks []) = E.continue (step k)
	step k (E.Chunks xs) = checkEvents k
		(\ptr -> parseChunks ptr xs)
		(\k' -> E.continue (step k'))

	-- Feed each chunk to Expat in order, stopping at the first error.
	-- Total: an empty list parses successfully (no error).
	parseChunks _ [] = return Nothing
	parseChunks ptr (x:xs) = do
		maybeErr <- Expat.parseChunk ptr x False
		case maybeErr of
			Nothing -> parseChunks ptr xs
			Just err -> return (Just err)

	-- Run one parse step, pass any produced events to the inner
	-- iteratee, then continue with @next@ or raise the parse error.
	checkEvents k runParse next = do
		(events, maybeErr) <- liftIO (getEvents runParse)
		let checkError k' = case maybeErr of
			Nothing -> next k'
			Just err -> throwError err
		if null events
			then checkError k
			else k (E.Chunks events) >>== E.checkDone checkError

	-- Plain IO: the single lift happens at the call site in checkEvents.
	getEvents runParse = do
		IO.writeIORef eventRef []
		err <- Expat.withParser p runParse
		events <- IO.readIORef eventRef
		-- Callbacks cons newest-first; restore document order.
		return (reverse events, err)

	throwError err = E.throwError (ParseError (T.pack msg) loc) where
		Expat.XMLParseError msg loc = err

-- | Copy a NUL-terminated C string into a Haskell 'B.ByteString' and
-- decode it as UTF-8. The result is a single-chunk lazy 'TL.Text'.
peekUTF8 :: CString -> IO TL.Text
peekUTF8 = fmap toLazy . B.packCString where
	toLazy bytes = TL.fromChunks [TE.decodeUtf8 bytes]

-- | Copy a pointer+length C string into a Haskell 'B.ByteString' and
-- decode it as UTF-8. The result is a single-chunk lazy 'TL.Text'.
-- Unlike 'peekUTF8', no NUL terminator is needed.
peekUTF8Len :: CStringLen -> IO TL.Text
peekUTF8Len cstrLen =
	B.packCStringLen cstrLen >>= \bytes ->
	return (TL.fromChunks [TE.decodeUtf8 bytes])