{-# LANGUAGE OverloadedStrings, BangPatterns #-}
module Text.Roundtrip.Xml.Printer (

    XmlPrinter, runXmlPrinter,

    runXmlPrinterByteString, runXmlPrinterLazyByteString,
    runXmlPrinterText, runXmlPrinterLazyText,
    runXmlPrinterString

) where

import Control.Monad (mplus, liftM2)

import Data.XML.Types

import Control.Monad.State
import Control.Exception (SomeException)

import System.IO.Unsafe (unsafePerformIO)

import qualified Data.Text as T
import qualified Data.Text.Lazy as TL

import qualified Data.Enumerator as E
import qualified Data.Enumerator.List as EL
import qualified Data.Enumerator.Binary as EB
import qualified Data.Enumerator.Text as ET
import qualified Text.XML.Enumerator.Render as EX

import qualified Data.ByteString as BS
import qualified Data.ByteString.Lazy as BSL

import Control.Isomorphism.Partial
import Text.Roundtrip
import Text.Roundtrip.Printer

-- | Printer state: tracks whether an element's start tag is still pending.
-- A pending start tag has not been emitted as an 'EventBeginElement' yet,
-- so further attributes can still be attached to it.
data PxState
    = PxStateJust Name [Attribute]  -- ^ pending start tag: element name and its attributes, most recently added first
    | PxStateNothing                -- ^ no start tag is pending
    deriving (Show)

-- | A printer that turns values into a list of XML 'Event's. It threads a
-- 'PxState' through printing so that attribute values can be attached to
-- the start tag of the element currently being opened.
newtype XmlPrinter a = XmlPrinter
    { unXmlPrinter :: Printer (State PxState) [Event] a
    }

-- | Mapping a partial isomorphism over a printer delegates to 'printerApply'.
instance IsoFunctor XmlPrinter where
    (<$>) iso (XmlPrinter inner) = XmlPrinter (printerApply iso inner)

-- | Sequencing two printers delegates to 'printerConcat'.
instance ProductFunctor XmlPrinter where
    (<*>) (XmlPrinter lhs) (XmlPrinter rhs) = XmlPrinter (printerConcat lhs rhs)

-- | Choice between printers and the always-failing printer both delegate to
-- the generic 'Printer' combinators.
instance Alternative XmlPrinter where
    empty = XmlPrinter printerEmpty
    (<||>) (XmlPrinter lhs) (XmlPrinter rhs) =
        XmlPrinter (printerAlternative lhs rhs)

-- | 'pure' is implemented by the generic 'printerPure'.
instance Syntax XmlPrinter where
    pure = XmlPrinter . printerPure

-- Rendering a list of events into a String, Text or ByteString is done via
-- enumerators. This is not optimal because the result is too strict: all
-- rendered chunks are collected (via EL.consume) before any of them is
-- returned. However, no other function for such a conversion currently exists.
-- | Shared driver for the rendering entry points.
--
-- It runs the printer on @x@ and feeds the resulting events through
-- @render@ in chunks of at most 20 events. Every rendered chunk is
-- collected. @run@ escapes the pipeline's base monad @m@.
--
-- Returns 'Nothing' if the printer yields 'Nothing' or if the enumerator
-- pipeline reports an exception ('Left').
runXmlPrinterGen :: Monad m => XmlPrinter a -> a
                 -> (m (Either SomeException [c]) -> Either SomeException [c])
                 -> E.Enumeratee Event c m [c] -> Maybe [c]
runXmlPrinterGen p x run render =
    runXmlPrinter p x >>= renderEvents
    where
      -- Any exception raised by the pipeline is discarded and mapped to Nothing.
      renderEvents events =
          either (const Nothing) Just
                 (run (E.run (E.enumList 20 events E.$$
                              E.joinI (render E.$$ EL.consume))))

-- | Render the printer's output for @x@ as a strict 'BS.ByteString', or
-- 'Nothing' if printing or rendering fails.
--
-- NOTE(review): 'unsafePerformIO' runs the rendering pipeline, presumably
-- because 'EX.renderBytes' needs an IO-capable base monad. Rendering an
-- in-memory event list should have no observable effects — confirm.
runXmlPrinterByteString :: XmlPrinter a -> a -> Maybe BS.ByteString
runXmlPrinterByteString p x =
    fmap BS.concat (runXmlPrinterGen p x unsafePerformIO EX.renderBytes)

-- | Render the printer's output for @x@ as a lazy 'BSL.ByteString' built
-- from the rendered chunks, or 'Nothing' if printing or rendering fails.
--
-- NOTE(review): see the 'unsafePerformIO' caveat on the strict variant;
-- the same assumption (effect-free rendering) applies here.
runXmlPrinterLazyByteString :: XmlPrinter a -> a -> Maybe BSL.ByteString
runXmlPrinterLazyByteString p x =
    fmap BSL.fromChunks (runXmlPrinterGen p x unsafePerformIO EX.renderBytes)

-- | Render the printer's output for @x@ as a strict 'T.Text', or 'Nothing'
-- if printing or rendering fails.
--
-- NOTE(review): 'unsafePerformIO' escapes the pipeline's base monad; this
-- assumes 'EX.renderText' has no observable effects — confirm.
runXmlPrinterText :: XmlPrinter a -> a -> Maybe T.Text
runXmlPrinterText p x =
    fmap T.concat (runXmlPrinterGen p x unsafePerformIO EX.renderText)

-- | Render the printer's output for @x@ as a lazy 'TL.Text' built from the
-- rendered chunks, or 'Nothing' if printing or rendering fails.
--
-- NOTE(review): 'unsafePerformIO' escapes the pipeline's base monad; this
-- assumes 'EX.renderText' has no observable effects — confirm.
runXmlPrinterLazyText :: XmlPrinter a -> a -> Maybe TL.Text
runXmlPrinterLazyText p x =
    fmap TL.fromChunks (runXmlPrinterGen p x unsafePerformIO EX.renderText)

-- | Render the printer's output for @x@ as a 'String'. A leading XML
-- declaration (@\<?xml ... ?\>@) is stripped, and everything after its
-- closing @?>@ is kept. Returns 'Nothing' if printing or rendering fails.
runXmlPrinterString :: XmlPrinter a -> a -> Maybe String
runXmlPrinterString p x =
    fmap (stripXmlDecl . TL.unpack) (runXmlPrinterLazyText p x)
    where
      -- Drop a leading "<?xml ... ?>" declaration, if there is one.
      stripXmlDecl ('<':'?':'x':'m':'l':rest) = skipPastDeclEnd rest
      stripXmlDecl str = str
      -- Skip everything up to and including the first "?>".
      -- dropWhile leaves either "" or a string starting with '?'. The '?'
      -- must be matched together with the following '>'; the previous
      -- pattern ('>':xs) could never match, so the whole output was eaten.
      skipPastDeclEnd l =
          case dropWhile (/= '?') l of
            '?':'>':xs -> xs
            [] -> []
            _:xs -> skipPastDeclEnd xs

-- | Run a printer on a value and return the 'Maybe [Event]' produced by the
-- underlying printer. Printing always starts with no pending start tag.
runXmlPrinter :: XmlPrinter a -> a -> Maybe [Event]
runXmlPrinter (XmlPrinter (Printer render)) value =
    evalState action PxStateNothing
    where
      action = render value

-- | Wire each 'XmlSyntax' primitive to its printer implementation.
instance XmlSyntax XmlPrinter where
    -- element structure
    xmlBeginElem    = xmlPrinterBeginElem
    xmlEndElem      = xmlPrinterEndElem
    -- element content
    xmlAttrValue    = xmlPrinterAttrValue
    xmlTextNotEmpty = xmlPrinterTextNotEmpty
    -- document boundaries
    xmlBeginDoc     = xmlPrinterBeginDoc
    xmlEndDoc       = xmlPrinterEndDoc

-- | Wrap a stateful step function as an 'XmlPrinter'.
mkXmlPrinter :: (a -> State PxState (Maybe [Event])) -> XmlPrinter a
mkXmlPrinter step = XmlPrinter (Printer step)

-- | Emit 'EventBeginDocument'. The state is left unchanged.
xmlPrinterBeginDoc :: XmlPrinter ()
xmlPrinterBeginDoc = mkXmlPrinter emit
    where
      emit () = return (Just [EventBeginDocument])

-- | Emit 'EventEndDocument'. The state is left unchanged.
xmlPrinterEndDoc :: XmlPrinter ()
xmlPrinterEndDoc = mkXmlPrinter emit
    where
      emit () = return (Just [EventEndDocument])

-- | Begin a new element named @name@.
--
-- Any start tag that is still pending is flushed first; its
-- 'EventBeginElement' is the output of this step. The new element's start
-- tag is then held back in the state, so that following attribute printers
-- can add to it.
xmlPrinterBeginElem :: Name -> XmlPrinter ()
xmlPrinterBeginElem name = mkXmlPrinter $ \() ->
    do l <- possiblyCloseOpeningTag []
       -- possiblyCloseOpeningTag always leaves the state as PxStateNothing,
       -- so the new pending tag can be installed unconditionally. The former
       -- "expected state Nothing" error branch was unreachable.
       put (PxStateJust name [])
       return l

-- | End the element named @name@. If its start tag is still pending, the
-- 'EventBeginElement' is emitted first, then the 'EventEndElement'.
xmlPrinterEndElem :: Name -> XmlPrinter ()
xmlPrinterEndElem name = mkXmlPrinter emit
    where
      emit () = possiblyCloseOpeningTag [EventEndElement name]

-- | Print an attribute @aName@ by attaching the value to the pending start
-- tag. It emits no events itself.
--
-- Using it while no start tag is pending is an error. Because the new state
-- is stored lazily, that error only surfaces when the state is next
-- inspected.
xmlPrinterAttrValue :: Name -> XmlPrinter T.Text
xmlPrinterAttrValue aName = mkXmlPrinter addToPendingTag
    where
      addToPendingTag value =
          do get >>= put . withAttr value
             return (Just [])
      -- Attributes are consed on; possiblyCloseOpeningTag reverses them.
      withAttr value (PxStateJust elName attrs) =
          PxStateJust elName ((aName, [ContentText value]) : attrs)
      withAttr _ PxStateNothing = error "xmlAttribute: state is Nothing"

-- | Print a text node. Empty text produces no events and leaves any pending
-- start tag untouched. Non-empty text first flushes a pending start tag and
-- then emits the content event.
xmlPrinterTextNotEmpty :: XmlPrinter T.Text
xmlPrinterTextNotEmpty = mkXmlPrinter emitText
    where
      emitText value
          | T.null value = return (Just [])
          | otherwise    = possiblyCloseOpeningTag [EventContent (ContentText value)]

-- | Flush a pending start tag, if there is one, and append the given events.
--
-- If a start tag is pending, the result is its 'EventBeginElement' (with
-- attributes restored to insertion order) followed by @trailing@, and the
-- state is reset to 'PxStateNothing'. Otherwise the result is just
-- @trailing@. In both cases the state ends up as 'PxStateNothing'.
possiblyCloseOpeningTag :: [Event] -> State PxState (Maybe [Event])
possiblyCloseOpeningTag trailing = get >>= flush
    where
      flush (PxStateJust elName attrs) =
          do put PxStateNothing
             return (Just (EventBeginElement elName (reverse attrs) : trailing))
      flush PxStateNothing = return (Just trailing)