-- ------------------------------------------------------------

{- |
   Module     : Text.XML.HXT.RelaxNG.Validation
   Copyright  : Copyright (C) 2008 Torben Kuseler, Uwe Schmidt
   License    : MIT

   Maintainer : Uwe Schmidt (uwe@fh-wedel.de)
   Stability  : stable
   Portability: portable

   Validation of an XML document with respect to a valid Relax NG schema in simple form.
   Copied and modified from \"An algorithm for RELAX NG validation\" by James Clark
   (<http://www.thaiopensource.com/relaxng/derivative.html>).

-}

-- ------------------------------------------------------------

module Text.XML.HXT.RelaxNG.Validation
    ( validateWithRelaxAndHandleErrors
    , validateDocWithRelax
    , validateRelax
    , validateXMLDoc
    , readForRelax
    , normalizeForRelaxValidation
    , contains
    )
where

import Control.Arrow.ListArrows

import           Text.XML.HXT.DOM.Interface
import qualified Text.XML.HXT.DOM.XmlNode as XN
import           Text.XML.HXT.DOM.Unicode
    ( isXmlSpaceChar
    )


import Text.XML.HXT.Arrow.XmlArrow
import Text.XML.HXT.Arrow.XmlIOStateArrow

import Text.XML.HXT.Arrow.Edit
    ( canonicalizeAllNodes
    , collapseAllXText
    )

import Text.XML.HXT.Arrow.ProcessDocument
    ( propagateAndValidateNamespaces
    , getDocumentContents
    , parseXmlDocument
    )

import Text.XML.HXT.RelaxNG.DataTypes
import Text.XML.HXT.RelaxNG.CreatePattern
import Text.XML.HXT.RelaxNG.PatternToString
import Text.XML.HXT.RelaxNG.DataTypeLibraries
import Text.XML.HXT.RelaxNG.Utils
    ( formatStringListQuot
    , compareURI
    )

import Data.Maybe
    ( fromJust
    )

{-
import qualified Debug.Trace as T
-}

-- ------------------------------------------------------------

-- | Validate the input document against the schema computed by @theSchema@
-- (see 'validateWithRelax') and feed the result into 'handleErrors'.
validateWithRelaxAndHandleErrors        :: IOSArrow XmlTree XmlTree -> IOSArrow XmlTree XmlTree
validateWithRelaxAndHandleErrors theSchema
    = validateWithRelax theSchema >>> handleErrors

-- | Validate the input document against a Relax NG schema.
--
-- The element selected from the normalized input is paired with the result
-- of @theSchema@, and both are handed to 'validateRelax'.
validateWithRelax       :: IOSArrow XmlTree XmlTree -> IOSArrow XmlTree XmlTree
validateWithRelax theSchema
    = traceMsg 2 "validate with Relax NG schema"
      >>>
      ( rootElement &&& theSchema )
      >>>
      arr2A validateRelax                       -- compute validation errors as a document
    where
    -- prepare the document for validation and select the root element
    rootElement
        = normalizeForRelaxValidation >>> getChildren >>> isElem

-- | Convert a validation result into error messages.
--
-- The text of every child of the input is prefixed with
-- @"Relax NG validation: "@, wrapped with 'mkError' at level 'c_err'
-- and passed to 'filterErrorMsg'.
handleErrors    :: IOSArrow XmlTree XmlTree
handleErrors
    = traceDoc "error found when validating with Relax NG schema"
      >>>
      toErrorNode
      >>>
      filterErrorMsg                            -- issue errors and set system status
    where
    -- prepare error format
    toErrorNode
        = getChildren >>> getText >>> arr ("Relax NG validation: " ++) >>> mkError c_err


{- |
   Normalize a document for validation with Relax NG.

   Three steps are performed: attributes whose namespace URI is the xmlns namespace
   are dropped, processing instructions are dropped, and finally 'collapseAllXText'
   merges sequences of text nodes into a single text node.
-}

normalizeForRelaxValidation :: ArrowXml a => a XmlTree XmlTree
normalizeForRelaxValidation
  = processTopDownWithAttrl (dropNamespaceDecls >>> dropPis)
    >>>
    collapseAllXText
  where
  -- an attribute living in the xmlns namespace
  isNamespaceDecl
    = isAttr >>> getNamespaceUri >>> isA (compareURI xmlnsNamespace)

  -- remove all namespace attributes
  dropNamespaceDecls
    = none `when` isNamespaceDecl

  -- remove all processing instructions
  dropPis
    = none `when` isPi

-- ------------------------------------------------------------

{- | Validates an XML document with respect to a Relax NG schema

   * 1.parameter  :  the arrow for computing the Relax NG schema

   - 2.parameter  :  list of options for reading and validating

   - 3.parameter  :  XML document URI; for an empty string an empty root node
                     is validated instead of reading a document

   - arrow-input  :  ignored

   - arrow-output :  list of errors or 'none'

   'handleErrors' is run via 'perform' on the validation result.
-}

validateDocWithRelax :: IOSArrow XmlTree XmlTree -> Attributes -> String -> IOSArrow XmlTree XmlTree
validateDocWithRelax theSchema al doc
  = docSource
    >>>
    validateWithRelax theSchema
    >>>
    perform handleErrors
  where
  -- no URI given: start from an empty root node, else read the document
  docSource
    | null doc  = root [] []
    | otherwise = readForRelax al doc


{- | Validates an XML document with respect to a Relax NG schema

   * 1.parameter  :  XML document

   - arrow-input  :  Relax NG schema

   - arrow-output :  list of errors or 'none'

   The schema is turned into a pattern by 'createPatternFromXmlTree', the derivative of
   that pattern with respect to the document is computed with 'childDeriv', and a
   non-nullable result is reported as a root node containing its string representation.
-}

validateRelax :: XmlTree -> IOSArrow XmlTree XmlTree
validateRelax xmlDoc
  = fromLA
    ( createPatternFromXmlTree >>> arr derive >>> reportIfInvalid )
  where
  -- derivative of the schema pattern with respect to the document, starting with an empty context
  derive schemaPattern
    = childDeriv ("",[]) schemaPattern xmlDoc

  -- pattern may be recursive, so the string representation
  -- is truncated to 1024 chars to assure termination
  reportIfInvalid
    = (not . nullable)
      `guardsP`
      root [] [ (take 1024 . show) ^>> mkText ]

-- ------------------------------------------------------------

-- | Read a document for Relax NG processing.
--
-- The contents are fetched with 'getDocumentContents' using the given options and
-- source string, parsed with 'parseXmlDocument' (called with @False@), and then
-- run through 'canonicalizeAllNodes' and 'propagateAndValidateNamespaces'.
readForRelax    :: Attributes -> String -> IOSArrow b XmlTree
readForRelax options schema
    = getDocumentContents options schema >>> parseXmlDocument False >>> prepare
    where
    prepare
        = canonicalizeAllNodes >>> propagateAndValidateNamespaces

-- ------------------------------------------------------------

{- old stuff -}

-- | Read the document @xmlDoc@ with the options @al@, normalize it and run
-- 'validateRelax' for the children of the resulting document; the arrow
-- input is the Relax NG schema expected by 'validateRelax'.
validateXMLDoc :: Attributes -> String -> IOSArrow XmlTree XmlTree
validateXMLDoc al xmlDoc
  = validateRelax $< docChildren
  where
  docChildren
    = readForRelax al xmlDoc >>> normalizeForRelaxValidation >>> getChildren


-- ------------------------------------------------------------
--
-- | tests whether a 'NameClass' contains a particular 'QName'
--
-- Namespace URI and local part of the name are compared with the values stored
-- in the name class; an erroneous name class ('NCError') contains no name.

contains :: NameClass -> QName -> Bool
contains nameClass qn
    = case nameClass of
        AnyName
            -> True
        AnyNameExcept excluded
            -> not (excluded `contains` qn)
        NsName uri
            -> uri == namespaceUri qn
        NsNameExcept uri excluded
            -> uri == namespaceUri qn && not (excluded `contains` qn)
        Name uri local
            -> (uri == namespaceUri qn) && (local == localPart qn)
        NameClassChoice left right
            -> (left `contains` qn) || (right `contains` qn)
        NCError _
            -> False


-- ------------------------------------------------------------
--
-- | tests whether a pattern matches the empty sequence
--
-- Only 'Empty' and 'Text' are nullable by themselves; the combinators
-- inherit nullability from their parts.
nullable:: Pattern -> Bool
nullable pat
    = case pat of
        -- patterns that accept the empty sequence directly
        Empty                   -> True
        Text                    -> True
        -- combinators
        Choice left right       -> nullable left || nullable right
        Group left right        -> nullable left && nullable right
        Interleave left right   -> nullable left && nullable right
        OneOrMore inner         -> nullable inner
        -- patterns that always require some input
        Element _ _             -> False
        Attribute _ _           -> False
        List _                  -> False
        Value _ _ _             -> False
        Data _ _                -> False
        DataExcept _ _ _        -> False
        NotAllowed _            -> False
        After _ _               -> False


-- ------------------------------------------------------------
--
-- | computes the derivative of a pattern with respect to a XML-Child and a 'Context'
--
-- A text node is handled by 'textDeriv' with the given context. For an element
-- the derivative is built step by step: start tag open, attributes, start tag
-- close, children and finally the end tag. The element steps use an empty
-- context instead of the one passed in. Any other node kind is not allowed.

childDeriv :: Context -> Pattern -> XmlTree -> Pattern

childDeriv cx p t
    | XN.isText t
        = textDeriv cx p (fromJust (XN.getText t))
    | XN.isElem t
        = endTagDeriv afterChildren
    | otherwise
        = notAllowed "Call to childDeriv with wrong arguments"
    where
    -- the following bindings are only demanded in the element case
    emptyCx         = ("",[])
    elemName        = fromJust (XN.getElemName t)
    elemAttrs       = fromJust (XN.getAttrl t)
    afterOpen       = startTagOpenDeriv p elemName
    afterAttrs      = attsDeriv emptyCx afterOpen elemAttrs
    afterClose      = startTagCloseDeriv afterAttrs
    afterChildren   = childrenDeriv emptyCx afterClose (XN.getChildren t)

-- ------------------------------------------------------------
--
-- | computes the derivative of a pattern with respect to a text node
--
-- The first argument is the context handed to the datatype checks, the second
-- the pattern to be matched and the third the text content. The result is the
-- residual pattern; a mismatch is reported as a not-allowed pattern carrying an
-- error message.

textDeriv :: Context -> Pattern -> String -> Pattern

-- the text may be matched by either alternative
textDeriv cx (Choice p1 p2) s
    = choice (textDeriv cx p1 s) (textDeriv cx p2 s)

-- the text may be consumed by either side of the interleave
textDeriv cx (Interleave p1 p2) s
    = choice
      (interleave (textDeriv cx p1 s) p2)
      (interleave p1 (textDeriv cx p2 s))

-- the text belongs to the first part, or to the second one if the first may be empty
textDeriv cx (Group p1 p2) s
    = let
      p = group (textDeriv cx p1 s) p2
      in
      if nullable p1
      then choice p (textDeriv cx p2 s)
      else p

textDeriv cx (After p1 p2) s
    = after (textDeriv cx p1 s) p2

-- one occurrence consumes the text, further occurrences are optional
textDeriv cx (OneOrMore p) s
    = group (textDeriv cx p s) (choice (OneOrMore p) Empty)

textDeriv _ Text _
    = Text

-- 'datatypeEqual' yields 'Nothing' on success and an error message otherwise
textDeriv cx1 (Value (uri, s) value cx2) s1
    = case datatypeEqual uri s value cx2 s1 cx1
      of
      Nothing     -> Empty
      Just errStr -> notAllowed errStr

-- 'datatypeAllows' yields 'Nothing' on success and an error message otherwise
textDeriv cx (Data (uri, s) params) s1
    = case datatypeAllows uri s params s1 cx
      of
      Nothing     -> Empty
      Just errStr -> notAllowed2 errStr

-- the value must be allowed by the datatype and must not match the except pattern
textDeriv cx (DataExcept (uri, s) params p) s1
    = case (datatypeAllows uri s params s1 cx)
      of
      Nothing     -> if not $ nullable $ textDeriv cx p s1
                     then Empty
                     else notAllowed
                              ( "Any value except " ++
                                show (show p) ++
                                " expected, but value " ++
                                show s1 ++          -- s1 is a String: a single show quotes it
                                " found"
                              )
      Just errStr -> notAllowed errStr

-- the text is split into words, which are matched one after the other
textDeriv cx (List p) s
    = if nullable (listDeriv cx p (words s))
      then Empty
      else notAllowed
               ( "List with value(s) " ++
                 show p ++
                 " expected, but value(s) " ++
                 formatStringListQuot (words s) ++
                 " found"
               )

-- an error found earlier is passed on unchanged
textDeriv _ n@(NotAllowed _) _
    = n

-- all remaining patterns do not accept text
textDeriv _ p s
    = notAllowed
      ( "Pattern " ++ show (getPatternName p) ++
        " expected, but text " ++ show s ++ " found"
      )


-- ------------------------------------------------------------
--
-- | To compute the derivative of a pattern with respect to a list of strings,
-- the derivative with respect to each member of the list is computed in turn,
-- starting with the given pattern. For an empty list the pattern is returned unchanged.

listDeriv :: Context -> Pattern -> [String] -> Pattern

listDeriv cx p tokens
    = foldl (textDeriv cx) p tokens     -- lazy left fold, same evaluation as a hand written recursion


-- ------------------------------------------------------------
--
-- | computes the derivative of a pattern with respect to a start tag open
--
-- A matching 'Element' pattern yields an after pattern, whose first part is the
-- content pattern of the element. For the combinators the remaining pattern is
-- pushed into the second part of that after pattern with 'applyAfter'.

startTagOpenDeriv :: Pattern -> QName -> Pattern

startTagOpenDeriv pat qn
    = case pat of
        Choice left right
            -> choice (startTagOpenDeriv left qn) (startTagOpenDeriv right qn)

        Element nc content
            | contains nc qn
                -> after content Empty
            | otherwise
                -> notAllowed $
                   "Element with name " ++ nameClassToString nc ++
                     " expected, but " ++ universalName qn ++ " found"

        Interleave left right
            -> choice
               (applyAfter (\ rest -> interleave rest right) (startTagOpenDeriv left qn))
               (applyAfter (\ rest -> interleave left rest)  (startTagOpenDeriv right qn))

        OneOrMore inner
            -> applyAfter
               (\ rest -> group rest (choice (OneOrMore inner) Empty))
               (startTagOpenDeriv inner qn)

        Group first second
            -> let
               viaFirst = applyAfter (\ rest -> group rest second) (startTagOpenDeriv first qn)
               in
               -- the element may also start the second part, if the first one may be empty
               if nullable first
               then choice viaFirst (startTagOpenDeriv second qn)
               else viaFirst

        After first second
            -> applyAfter (\ rest -> after rest second) (startTagOpenDeriv first qn)

        -- an error found earlier is passed on unchanged
        NotAllowed _
            -> pat

        _   -> notAllowed ( show pat ++ " expected, but Element " ++ universalName qn ++ " found" )

-- ------------------------------------------------------------

-- auxiliary functions for tracing
{-
attsDeriv' cx p ts
    = T.trace ("attsDeriv: p=" ++ (take 1000 . show) p ++ ", t=" ++ showXts ts) $
      T.trace ("res= " ++ (take 1000 . show) res) res
    where
    res = attsDeriv cx p ts

attDeriv' cx p t
    = T.trace ("attDeriv: p=" ++ (take 1000 . show) p ++ ", t=" ++ showXts [t]) $
      T.trace ("res= " ++ (take 1000 . show) res) res
    where
    res = attDeriv cx p t
-}

-- | To compute the derivative of a pattern with respect to a sequence of attributes,
-- the derivative with respect to each attribute is computed in turn with 'attDeriv'.
-- The first node that is not an attribute node stops the computation with an error pattern.

attsDeriv :: Context -> Pattern -> XmlTrees -> Pattern

attsDeriv cx
    = go
    where
    go p []
        = p
    go p (att : rest)
        | XN.isAttr att
            = go (attDeriv cx p att) rest
        | otherwise
            = notAllowed "Call to attsDeriv with wrong arguments"

-- | computes the derivative of a pattern with respect to a single attribute node
--
-- The first argument is the context used for matching the attribute value, the
-- second the pattern and the third the attribute node. An 'Attribute' pattern
-- with a matching name and value yields 'Empty', mismatches are reported
-- as not-allowed patterns carrying an error message.

attDeriv :: Context -> Pattern -> XmlTree -> Pattern

attDeriv cx (After p1 p2) att
    = after (attDeriv cx p1 att) p2

attDeriv cx (Choice p1 p2) att
    = choice (attDeriv cx p1 att) (attDeriv cx p2 att)

-- the attribute may be matched by either part, the position within a group does not matter
attDeriv cx (Group p1 p2) att
    = choice
      (group (attDeriv cx p1 att) p2)
      (group p1 (attDeriv cx p2 att))

attDeriv cx (Interleave p1 p2) att
    = choice
      (interleave (attDeriv cx p1 att) p2)
      (interleave p1 (attDeriv cx p2 att))

attDeriv cx (OneOrMore p) att
    = group
      (attDeriv cx p att)
      (choice (OneOrMore p) Empty)

-- all guards require an attribute node, for other nodes the clauses below are tried
attDeriv cx (Attribute nc p) att
    | isa
      &&
      not (contains nc qn)
        = notAllowed1 $
          "Attribute with name " ++ nameClassToString nc
          ++ " expected, but " ++ universalName qn ++ " found"
    | isa
      &&
      ( ( nullable p                    -- a value consisting of whitespace matches a nullable pattern
          &&
          whitespace val
        )
        || nullable p'
      )
        = Empty
    | isa
        = err' p'
    where
    isa =            XN.isAttr      $ att
    qn  = fromJust . XN.getAttrName $ att
    av  =            XN.getChildren $ att
    val = showXts av
    p'  = textDeriv cx p val

    -- append the first message of the value derivative, if there is one;
    -- matching on (e : _) instead of calling head keeps this total for an empty message list
    err' (NotAllowed (ErrMsg _l (e : _)))
        = err'' (": " ++ e)
    err' _
        = err'' ""
    err'' e
        = notAllowed2 $
          "Attribute value \"" ++ val ++
          "\" does not match datatype spec " ++ show p ++ e

-- an error found earlier is passed on unchanged
attDeriv _ n@(NotAllowed _) _
    = n

attDeriv _ _p att
    = notAllowed $
      "No matching pattern for attribute '" ++  showXts [att] ++ "' found"

-- ------------------------------------------------------------
--
-- | computes the derivative of a pattern with respect to a start tag close
--
-- When the start tag is closed, no further attributes can follow, so every
-- 'Attribute' pattern still contained in the pattern is replaced by an error
-- pattern. All patterns other than the combinators handled here are returned unchanged.

startTagCloseDeriv :: Pattern -> Pattern

-- only the first part of an after pattern is processed
startTagCloseDeriv (After p1 p2)
    = after (startTagCloseDeriv p1) p2

startTagCloseDeriv (Choice p1 p2)
    = choice
      (startTagCloseDeriv p1)
      (startTagCloseDeriv p2)

startTagCloseDeriv (Group p1 p2)
    = group
      (startTagCloseDeriv p1)
      (startTagCloseDeriv p2)

startTagCloseDeriv (Interleave p1 p2)
    = interleave
      (startTagCloseDeriv p1)
      (startTagCloseDeriv p2)

startTagCloseDeriv (OneOrMore p)
    = oneOrMore (startTagCloseDeriv p)

-- a still expected attribute is missing in the start tag
startTagCloseDeriv (Attribute nc _)
    = notAllowed1 $
      "Attribute with name " ++ show nc ++
      " expected, but no more attributes found"

startTagCloseDeriv p
    = p


-- ------------------------------------------------------------
--
-- | Computing the derivative of a pattern with respect to a list of children involves
-- computing the derivative with respect to each child in turn, except
-- that whitespace requires special treatment:
-- an empty list of children is handled like a single empty text node, and for a
-- single whitespace text node the pattern itself remains a valid alternative.

childrenDeriv :: Context -> Pattern -> XmlTrees -> Pattern

-- an error found earlier is passed on unchanged
childrenDeriv _cx p@(NotAllowed _) _
    = p

childrenDeriv cx p []
    = childrenDeriv cx p [XN.mkText ""]

-- a single child, which is not a text node, is handled by the general case below
childrenDeriv cx p [child]
    | XN.isText child
        = if whitespace content
          then choice p derived
          else derived
    where
    content = fromJust (XN.getText child)
    derived = childDeriv cx p child

childrenDeriv cx p children
    = stripChildrenDeriv cx p children

-- | computes the derivative of a pattern with respect to a list of children,
-- skipping all children for which 'strip' holds
stripChildrenDeriv :: Context -> Pattern -> XmlTrees -> Pattern
stripChildrenDeriv _ p []
    = p

stripChildrenDeriv cx p (child : rest)
    | strip child
        = stripChildrenDeriv cx p rest
    | otherwise
        = stripChildrenDeriv cx (childDeriv cx p child) rest


-- ------------------------------------------------------------
--
-- | computes the derivative of a pattern with respect to an end tag
--
-- For an after pattern the first part must be nullable, the result is then the
-- second part. Patterns other than choice, after and not-allowed are reported as
-- a wrong call.

endTagDeriv :: Pattern -> Pattern
endTagDeriv pat
    = case pat of
        Choice left right
            -> choice (endTagDeriv left) (endTagDeriv right)

        After content rest
            | nullable content
                -> rest
            | otherwise
                -> notAllowed $
                   show content ++ " expected"

        -- an error found earlier is passed on unchanged
        NotAllowed _
            -> pat

        _   -> notAllowed "Call to endTagDeriv with wrong arguments"

-- ------------------------------------------------------------
--
-- | applies a function (first parameter) to the second part of an after pattern
--
-- The function is distributed over both alternatives of a choice, an error
-- pattern is passed on unchanged, all other patterns are reported as a wrong call.

applyAfter :: (Pattern -> Pattern) -> Pattern -> Pattern

applyAfter f pat
    = case pat of
        After first second  -> after first (f second)
        Choice left right   -> choice (applyAfter f left) (applyAfter f right)
        NotAllowed _        -> pat
        _                   -> notAllowed "Call to applyAfter with wrong arguments"

-- --------------------

-- mothers little helpers

-- | tests whether a node has text content consisting of whitespace only;
-- nodes for which 'XN.getText' yields nothing are never stripped
strip           :: XmlTree -> Bool
strip node
    = case XN.getText node of
        Just content    -> whitespace content
        Nothing         -> False

-- | tests whether every char of a string is an XML space char,
-- this holds for the empty string, too
whitespace      :: String -> Bool
whitespace str  = all isXmlSpaceChar str

-- | converts a list of trees into a single string by running 'xshow'
-- over the list and concatenating the results
showXts         :: XmlTrees -> String
showXts trees   = concat (runLA (xshow (arrL id)) trees)

-- ------------------------------------------------------------