{-# LANGUAGE BangPatterns #-}

-- ------------------------------------------------------------

{- |
   Module     : Text.XML.HXT.RelaxNG.Validation
   Copyright  : Copyright (C) 2008 Torben Kuseler, Uwe Schmidt
   License    : MIT

   Maintainer : Uwe Schmidt (uwe@fh-wedel.de)
   Stability  : stable
   Portability: portable

   Validation of an XML document with respect to a valid Relax NG schema in simple form.
   Copied and modified from \"An algorithm for RELAX NG validation\" by James Clark
   (<http://www.thaiopensource.com/relaxng/derivative.html>).

-}

-- ------------------------------------------------------------

module Text.XML.HXT.RelaxNG.Validation
    ( validateWithRelax
    , validateDocWithRelax
    , validateRelax
    , validateRelax'
    , readForRelax
    , normalizeForRelaxValidation
    , contains
    )
where

import           Control.Arrow.ListArrows

import           Data.Char.Properties.XMLCharProps      (isXmlSpaceChar)
import           Data.Maybe                             (fromJust)

import           Text.XML.HXT.DOM.Interface
import qualified Text.XML.HXT.DOM.XmlNode               as XN

import           Text.XML.HXT.Arrow.Edit                (canonicalizeAllNodes,
                                                         collapseAllXText)
import           Text.XML.HXT.Arrow.XmlArrow

import           Text.XML.HXT.Arrow.ProcessDocument     (getDocumentContents,
                                                         parseXmlDocument, propagateAndValidateNamespaces)
import           Text.XML.HXT.Arrow.XmlState
import           Text.XML.HXT.Arrow.XmlState.TypeDefs

import           Text.XML.HXT.RelaxNG.CreatePattern
import           Text.XML.HXT.RelaxNG.DataTypeLibraries
import           Text.XML.HXT.RelaxNG.DataTypes
import           Text.XML.HXT.RelaxNG.PatternToString
import           Text.XML.HXT.RelaxNG.Utils             (compareURI,
                                                         formatStringListQuot)

{-
import qualified Debug.Trace                            as T
-- -}

-- ------------------------------------------------------------

-- | Normalizes the input document and validates it with respect to a Relax NG schema.
--
--   * 1.parameter  :  the arrow for computing the Relax NG schema
--
--   - arrow-input  :  XML document
--
--   - arrow-output :  the result of 'validateRelax' for the normalized document

validateWithRelax       :: IOSArrow XmlTree XmlTree -> IOSArrow XmlTree XmlTree
validateWithRelax theSchema
    = prepare >>> check
    where
    -- prepare the document for validation
    prepare
        = traceMsg 2 "normalize document for validation"
          >>>
          normalizeForRelaxValidation

    -- feed the schema produced by 'theSchema' into 'validateRelax'
    -- to compute and issue validation errors
    check
        = traceMsg 2 "start validation"
          >>>
          ( validateRelax $< theSchema )

{- |
   Normalizes a document for validation with Relax NG: all attributes whose
   namespace URI matches the xmlns namespace (namespace declarations) and all
   processing instructions are removed, afterwards 'collapseAllXText' merges
   all sequences of text nodes into a single text node.
-}

normalizeForRelaxValidation :: ArrowXml a => a XmlTree XmlTree
normalizeForRelaxValidation
  = processTopDownWithAttrl (dropNamespaceDecls >>> dropPis)
    >>>
    collapseAllXText
  where
  -- succeeds for attributes living in the xmlns namespace
  isNamespaceDecl
      = isAttr
        >>>
        getNamespaceUri
        >>>
        isA (compareURI xmlnsNamespace)

  -- remove all namespace attributes
  dropNamespaceDecls
      = none `when` isNamespaceDecl

  -- remove all processing instructions
  dropPis
      = none `when` isPi

-- ------------------------------------------------------------

{- | Reads an XML document and validates it with respect to a Relax NG schema

   * 1.parameter  :  the arrow for computing the Relax NG schema

   - 2.parameter  :  list of configuration options for reading and validating,
                     installed with 'configSysVars' inside a local system environment

   - 3.parameter  :  XML document URI

   - arrow-input  :  not inspected by 'readForRelax', the document is read from the URI

   - arrow-output :  the result of 'validateWithRelax' for the document read
-}

validateDocWithRelax :: IOSArrow XmlTree XmlTree -> SysConfigList -> String -> IOSArrow XmlTree XmlTree
validateDocWithRelax theSchema config doc
    = localSysEnv configureReadAndValidate
    where
    configureReadAndValidate
        = configSysVars config
          >>>
          ( readForRelax doc
            >>>
            validateWithRelax theSchema
          )

-- ------------------------------------------------------------

{- | Validates an XML document with respect to a Relax NG schema
   and issues error messages.

   The pure arrow 'validateRelax'' does the validation, its result
   is passed through 'filterErrorMsg'.

   * 1.parameter  :  Relax NG schema

   - arrow-input  :  XML document

   - arrow-output :  the document or in case of errors none
-}

validateRelax :: XmlTree -> IOSArrow XmlTree XmlTree
validateRelax rngSchema
    = pureValidation >>> filterErrorMsg
    where
    -- the pure validation arrow lifted out of the LA arrow
    pureValidation
        = fromLA (validateRelax' rngSchema)

{- | Validates an XML document with respect to a Relax NG schema
   This arrow is pure. It does not need IO or any configuration parameters.

   The pattern created from the schema is derived with respect to the element
   children of the input node. If the resulting pattern is not nullable, it is
   turned into an error node with level 'c_err'; in every other case the
   alternative branch 'this' is taken.

   * 1.parameter  :  Relax NG schema

   - arrow-input  :  XML document

   - arrow-output :  the unchanged document or an error message
-}

validateRelax' :: XmlTree -> LA XmlTree XmlTree
validateRelax' rngSchema
    = validationError `orElse` this
    where
    -- the pattern representation of the schema
    schemaPattern
        = constA rngSchema
          >>>
          createPatternFromXmlTree

    -- remove the root node and select the root element
    rootElement
        = getChildren
          >>>
          isElem

    -- both arguments are forced before the derivative is computed
    derive !pat !root
        = childDeriv ("", []) pat root

    -- pattern may be recursive, so the string representation
    -- is truncated to 1024 chars to assure termination
    errorText
        = take 1024
          . ("when validating with Relax NG schema: " ++)
          . show

    -- only a pattern which is not nullable is reported as an error
    validationError
        = ( schemaPattern &&& rootElement )
          >>>
          arr2 derive
          >>>
          isA (not . nullable)
          >>>
          arr errorText
          >>>
          mkError c_err

-- ------------------------------------------------------------

-- | Reads the document with the given URI and prepares it for Relax NG processing:
-- the contents are parsed, all nodes are canonicalized and namespaces are
-- propagated and validated.
--
-- The four Boolean flags are passed positionally to 'parseXmlDocument';
-- their meaning is not visible here, see the documentation of that arrow.

readForRelax    :: String -> IOSArrow b XmlTree
readForRelax src
    = load >>> postProcess
    where
    load
        = getDocumentContents src
          >>>
          parseXmlDocument False True False True

    postProcess
        = canonicalizeAllNodes
          >>>
          propagateAndValidateNamespaces

-- ------------------------------------------------------------
--
-- | tests whether a 'NameClass' contains a particular 'QName'
--
-- An error name class ('NCError') never contains a name.

contains :: NameClass -> QName -> Bool
contains nameClass qn
    = case nameClass of
        AnyName                 -> True
        AnyNameExcept nc        -> not (nc `contains` qn)
        NsName ns               -> sameNamespace ns
        NsNameExcept ns nc      -> sameNamespace ns && not (nc `contains` qn)
        Name ns ln              -> sameNamespace ns && (ln == localPart qn)
        NameClassChoice nc1 nc2 -> (nc1 `contains` qn) || (nc2 `contains` qn)
        NCError _               -> False
    where
    -- compares a namespace URI with the namespace URI of the name under test
    sameNamespace ns
        = ns == namespaceUri qn


-- ------------------------------------------------------------
--
-- | tests whether a pattern matches the empty sequence
--
-- Only 'Empty' and 'Text' are nullable on their own, the combinators
-- delegate to their sub-patterns.
nullable :: Pattern -> Bool
nullable pat
    = case pat of
        Empty                   -> True
        Text                    -> True
        Choice p1 p2            -> nullable p1 || nullable p2
        Group p1 p2             -> bothNullable p1 p2
        Interleave p1 p2        -> bothNullable p1 p2
        OneOrMore p             -> nullable p
        Element _ _             -> False
        Attribute _ _           -> False
        List _                  -> False
        Value _ _ _             -> False
        Data _ _                -> False
        DataExcept _ _ _        -> False
        NotAllowed _            -> False
        After _ _               -> False
    where
    bothNullable l r
        = nullable l && nullable r


-- ------------------------------------------------------------
--
-- | computes the derivative of a pattern with respect to an XML child node and a 'Context'
--
-- A text node is handled by 'textDeriv' with the given context. For an element
-- node the derivative is computed step by step: start tag open, attributes,
-- start tag close, children and end tag. The attribute and children steps use
-- an empty context, not the context given as argument.
-- Every other kind of node results in a not allowed pattern.

childDeriv :: Context -> Pattern -> XmlTree -> Pattern

childDeriv cx p t
    | XN.isText t       = textDeriv{- ' -} cx p (fromJust (XN.getText t))
    | XN.isElem t       = endTagDeriv{- ' -} afterChildren
    | otherwise         = notAllowed "Call to childDeriv with wrong arguments"
    where
    -- the following values are only demanded in the element case
    elemCx              = ("", [])
    elemName            = fromJust (XN.getElemName t)
    elemAttrs           = fromJust (XN.getAttrl t)
    elemChildren        = XN.getChildren t

    afterOpen           = startTagOpenDeriv p elemName
    afterAttrs          = attsDeriv{- ' -} elemCx afterOpen elemAttrs
    afterClose          = startTagCloseDeriv afterAttrs
    afterChildren       = childrenDeriv elemCx afterClose elemChildren

-- ------------------------------------------------------------
--
-- | computes the derivative of a pattern with respect to a text node

{-
textDeriv' cx p t
    = T.trace ("textDeriv: p=\n" ++ (take 10000 . show) p ++ ", t=\n" ++ t) $
      T.trace ("res=\n" ++ (take 10000 . show) res) res
    where
    res = textDeriv cx p t
-- -}


-- The alternatives are tried from top to bottom, every pattern without a rule
-- of its own ends up in the last alternative and is reported as not allowed.

textDeriv :: Context -> Pattern -> String -> Pattern

textDeriv cx pat str
    = case pat of
        Choice p1 p2
            -> choice (derive p1) (derive p2)

        Interleave p1 p2
            -> choice
               (interleave (derive p1) p2)
               (interleave p1 (derive p2))

        -- the text may also belong to the second part, if the first part may be empty
        Group p1 p2
            | nullable p1
                -> choice viaFirst (derive p2)
            | otherwise
                -> viaFirst
            where
            viaFirst = group (derive p1) p2

        After p1 p2
            -> after (derive p1) p2

        OneOrMore p
            -> group (derive p) (choice (OneOrMore p) Empty)

        Text
            -> Text

        -- the datatype library functions deliver Nothing on success
        -- and an error message otherwise
        Value (uri, name) value cx2
            -> maybe Empty notAllowed
               (datatypeEqual uri name value cx2 str cx)

        Data (uri, name) params
            -> maybe Empty notAllowed2
               (datatypeAllows uri name params str cx)

        DataExcept (uri, name) params p
            -> case datatypeAllows uri name params str cx of
                 Just errStr
                     -> notAllowed errStr
                 Nothing
                     -- the value must not match the except pattern
                     | nullable (derive p)
                         -> notAllowed
                            ( "Any value except " ++
                              show (show p) ++
                              " expected, but value " ++
                              show (show str) ++
                              " found"
                            )
                     | otherwise
                         -> Empty

        -- the text is split into words and the list pattern
        -- is derived with respect to every word
        List p
            | nullable (listDeriv cx p tokens)
                -> Empty
            | otherwise
                -> notAllowed
                   ( "List with value(s) " ++
                     show p ++
                     " expected, but value(s) " ++
                     formatStringListQuot tokens ++
                     " found"
                   )
            where
            tokens = words str

        NotAllowed _
            -> pat

        _
            -> notAllowed
               ( "Pattern " ++ show (getPatternName pat) ++
                 " expected, but text " ++ show str ++ " found"
               )
    where
    -- derivative of a sub-pattern with the same context and text
    derive p = textDeriv cx p str


-- ------------------------------------------------------------
--
-- | The derivative of a pattern with respect to a list of strings is computed
-- by taking the derivative with respect to each member of the list in turn.
-- The intermediate pattern is forced in every step.

listDeriv :: Context -> Pattern -> [String] -> Pattern

listDeriv cx !pat tokens
    = case tokens of
        []              -> pat
        (tok : rest)    -> listDeriv cx (textDeriv cx pat tok) rest


-- ------------------------------------------------------------
--
-- | computes the derivative of a pattern with respect to a start tag open
--
-- A matching 'Element' pattern becomes an after pattern of its content and 'Empty',
-- the combinators push the rest of the pattern into the second part of that
-- after pattern by means of 'applyAfter'.

startTagOpenDeriv :: Pattern -> QName -> Pattern

startTagOpenDeriv pat qn
    = case pat of
        Choice p1 p2
            -> choice (open p1) (open p2)

        Element nc p
            | contains nc qn
                -> after p Empty
            | otherwise
                -> notAllowed $
                   "Element with name " ++ nameClassToString nc ++
                     " expected, but " ++ universalName qn ++ " found"

        Interleave p1 p2
            -> choice
               (applyAfter (`interleave` p2) (open p1))
               (applyAfter (interleave p1) (open p2))

        OneOrMore p
            -> applyAfter (`group` moreOrNothing) (open p)
            where
            moreOrNothing = choice (OneOrMore p) Empty

        -- the element may also belong to the second part, if the first part may be empty
        Group p1 p2
            | nullable p1
                -> choice viaFirst (open p2)
            | otherwise
                -> viaFirst
            where
            viaFirst = applyAfter (`group` p2) (open p1)

        After p1 p2
            -> applyAfter (`after` p2) (open p1)

        NotAllowed _
            -> pat

        _
            -> notAllowed ( show pat ++ " expected, but Element " ++ universalName qn ++ " found" )
    where
    -- derivative of a sub-pattern with respect to the same element name
    open p = startTagOpenDeriv p qn

-- ------------------------------------------------------------

-- auxiliary functions for tracing

{-
attsDeriv' cx p ts
    = T.trace ("attsDeriv: p=" ++ (take 1000 . show) p ++ ", t=" ++ showXts ts) $
      T.trace ("res= " ++ (take 1000 . show) res) res
    where
    res = attsDeriv cx p ts
-- -}

{-
attDeriv' cx p t
    = T.trace ("attDeriv: p=\n" ++ (take 10000 . show) p ++ ", t=\n" ++ showXts [t]) $
      T.trace ("res=\n" ++ (take 1000 . show) res) res
    where
    res = attDeriv cx p t
-- -}

-- | The derivative of a pattern with respect to a sequence of attributes is computed
-- by taking the derivative with respect to each attribute in turn.
-- A node in the list which is not an attribute stops the computation
-- with a not allowed pattern.

attsDeriv :: Context -> Pattern -> XmlTrees -> Pattern

attsDeriv cx !pat attrs
    = case attrs of
        []
            -> pat
        (a : rest)
            | XN.isAttr a
                -> attsDeriv cx (attDeriv{- ' -} cx pat a) rest
            | otherwise
                -> notAllowed "Call to attsDeriv with wrong arguments"

-- | computes the derivative of a pattern with respect to a single attribute node
--
-- An 'Attribute' pattern matches, if the attribute name is contained in the
-- name class and either the content pattern is nullable and the value consists
-- of whitespace only, or the derivative of the content pattern with respect to
-- the attribute value is nullable. A successful match results in 'Empty',
-- every mismatch in a not allowed pattern with an error message.

attDeriv :: Context -> Pattern -> XmlTree -> Pattern

attDeriv cx (After p1 p2) att
    = after (attDeriv cx p1 att) p2

attDeriv cx (Choice p1 p2) att
    = choice (attDeriv cx p1 att) (attDeriv cx p2 att)

-- attributes are not ordered, so the attribute may belong to either part
attDeriv cx (Group p1 p2) att
    = choice
      (group (attDeriv cx p1 att) p2)
      (group p1 (attDeriv cx p2 att))

attDeriv cx (Interleave p1 p2) att
    = choice
      (interleave (attDeriv cx p1 att) p2)
      (interleave p1 (attDeriv cx p2 att))

attDeriv cx (OneOrMore p) att
    = group
      (attDeriv cx p att)
      (choice (OneOrMore p) Empty)

-- if att is not an attribute node, none of the guards succeeds
-- and the last equation of attDeriv reports the error
attDeriv cx (Attribute nc p) att
    | isa
      &&
      not (contains nc qn)
        = notAllowed1 $
          "Attribute with name " ++ nameClassToString nc
          ++ " expected, but " ++ universalName qn ++ " found"
    | isa
      &&
      ( ( nullable p
          &&
          whitespace val
        )
        || nullable p'
      )
        = Empty
    | isa
        = err' p'
    where
    isa =            XN.isAttr      $ att
    qn  = fromJust . XN.getAttrName $ att
    av  =            XN.getChildren $ att
    val = showXts av                    -- the attribute value as string
    p'  = textDeriv cx p val

    -- the first message of a failed text derivative is appended as detail,
    -- an empty message list is handled like a pattern without any message
    err' (NotAllowed (ErrMsg _l (e : _)))
        = err'' (": " ++ e)
    err' _
        = err'' ""
    err'' e
        = notAllowed2 $
          "Attribute value \"" ++ val ++
          "\" does not match datatype spec " ++ show p ++ e

attDeriv _ n@(NotAllowed _) _
    = n

attDeriv _ _p att
    = notAllowed $
      "No matching pattern for attribute '" ++  showXts [att] ++ "' found"

-- ------------------------------------------------------------
--
-- | computes the derivative of a pattern with respect to a start tag close
--
-- An 'Attribute' pattern still present at this point becomes a not allowed
-- pattern, because no more attributes follow. The combinators are traversed,
-- all other patterns remain unchanged.

startTagCloseDeriv :: Pattern -> Pattern

startTagCloseDeriv pat
    = case pat of
        After p1 p2
            -> after (startTagCloseDeriv p1) p2

        Choice p1 p2
            -> closeBoth choice p1 p2

        Group p1 p2
            -> closeBoth group p1 p2

        Interleave p1 p2
            -> closeBoth interleave p1 p2

        OneOrMore p
            -> oneOrMore (startTagCloseDeriv p)

        Attribute nc _
            -> notAllowed1 $
               "Attribut with name, " ++ show nc ++
               " expected, but no more attributes found"

        _
            -> pat
    where
    -- derive both parts and combine them again with the same combinator
    closeBoth combine l r
        = combine
          (startTagCloseDeriv l)
          (startTagCloseDeriv r)


-- ------------------------------------------------------------
--
-- | The derivative of a pattern with respect to a list of children is computed
-- by taking the derivative with respect to each child in turn, except
-- that whitespace requires special treatment:
-- an empty list of children is handled like a single empty text node, and for
-- a single text node consisting of whitespace only the pattern itself remains
-- an alternative to its derivative.

childrenDeriv :: Context -> Pattern -> XmlTrees -> Pattern
childrenDeriv cx pat nodes
    = case (pat, nodes) of
        -- an error is passed through unchanged
        (NotAllowed _, _)
            -> pat

        (_, [])
            -> childrenDeriv cx pat [XN.mkText ""]

        -- a single node which is not a text node is handled by the last alternative
        (_, [node])
            | isTextNode
              &&
              whitespace content
                -> choice pat derived
            | isTextNode
                -> derived
            where
            isTextNode  = XN.isText node
            content     = fromJust (XN.getText node)
            derived     = childDeriv cx pat node

        _
            -> stripChildrenDeriv cx pat nodes

-- | computes the derivative of a pattern with respect to a list of children,
-- text nodes consisting of whitespace only are skipped.
-- The intermediate pattern is forced in every step.

stripChildrenDeriv :: Context -> Pattern -> XmlTrees -> Pattern
stripChildrenDeriv cx !pat nodes
    = case nodes of
        []
            -> pat
        (node : rest)
            -> stripChildrenDeriv cx (advance node) rest
    where
    -- whitespace text nodes do not change the pattern
    advance node
        | strip node    = pat
        | otherwise     = childDeriv cx pat node


-- ------------------------------------------------------------
--
-- | computes the derivative of a pattern with respect to an end tag

{-
endTagDeriv' p
    = T.trace ("endTagDeriv: p=\n" ++ (take 10000 . show) p) $
      T.trace ("res=\n" ++ (take 10000 . show) res) res
    where
    res = endTagDeriv p
-- -}

-- An after pattern is finished, if its first part is nullable, the result
-- is then the second part. Only choice, after and not allowed patterns
-- are expected as argument.

endTagDeriv :: Pattern -> Pattern
endTagDeriv pat
    = case pat of
        Choice p1 p2
            -> choice (endTagDeriv p1) (endTagDeriv p2)

        After p1 p2
            | nullable p1
                -> p2
            | otherwise
                -> notAllowed $
                   show p1 ++ " expected"

        NotAllowed _
            -> pat

        _
            -> notAllowed "Call to endTagDeriv with wrong arguments"

-- ------------------------------------------------------------
--
-- | applies a function (first parameter) to the second part of an after pattern
--
-- Choices are traversed, a not allowed pattern is passed through unchanged,
-- every other pattern is reported as a wrong argument.

applyAfter :: (Pattern -> Pattern) -> Pattern -> Pattern

applyAfter f pat
    = case pat of
        After p1 p2     -> after p1 (f p2)
        Choice p1 p2    -> choice (applyAfter f p1) (applyAfter f p2)
        NotAllowed _    -> pat
        _               -> notAllowed "Call to applyAfter with wrong arguments"

-- --------------------

-- mother's little helpers

-- | tests whether a node delivers a text which consists of whitespace only,
-- a node without text is never stripped

strip           :: XmlTree -> Bool
strip t
    = case XN.getText t of
        Just s          -> whitespace s
        Nothing         -> False

-- | tests whether a string consists of XML space chars only,
-- this holds for the empty string, too

whitespace      :: String -> Bool
whitespace str  = not (any (not . isXmlSpaceChar) str)

-- | converts a list of XML trees into a string with 'xshow'
-- and concatenates all results

showXts         :: XmlTrees -> String
showXts ts      = concat (runLA (xshow (arrL id)) ts)

-- ------------------------------------------------------------