{-# LANGUAGE AllowAmbiguousTypes #-}
{-# LANGUAGE OverloadedStrings #-}
module Symantic.XML.RelaxNG.Language where

import Data.Bool
import Data.Eq (Eq(..))
import Data.Function (($), (.))
import Data.Maybe (Maybe(..))
import Data.Semigroup (Semigroup(..))
import Data.String (String, IsString(..))
import Prelude (error)
import Text.Show (Show(..))
import qualified Data.List as List
import qualified Data.HashMap.Strict as HM

import Symantic.Base.Fixity
import Symantic.XML.Language

-- * Class 'RelaxNG'
-- | Extends 'XML', 'Permutable' and 'Definable' with combinators
-- whose names are matched against a 'NameClass'
-- rather than against a single fixed 'QName'.
class
 ( XML repr
 , Permutable repr
 , Definable repr
 ) => RelaxNG repr where
  -- | Like 'element' but with a matching pattern
  -- instead of a specific 'QName'.
  --
  -- The result's type gains a leading 'QName' argument
  -- (presumably the name that actually matched — TODO confirm).
  elementMatch :: NameClass -> repr a k -> repr (QName -> a) k
  default elementMatch ::
                  Transformable repr =>
                  RelaxNG (UnTrans repr) =>
                  NameClass -> repr a k -> repr (QName -> a) k
  elementMatch nameClass body = noTrans (elementMatch nameClass (unTrans body))

  -- | Like 'attribute' but with a matching pattern
  -- instead of a specific 'QName'.
  --
  -- The result's type gains a leading 'QName' argument
  -- (presumably the name that actually matched — TODO confirm).
  attributeMatch :: NameClass -> repr a k -> repr (QName -> a) k
  default attributeMatch ::
                  Transformable repr =>
                  RelaxNG (UnTrans repr) =>
                  NameClass -> repr a k -> repr (QName -> a) k
  attributeMatch nameClass body = noTrans (attributeMatch nameClass (unTrans body))

-- * Class 'Definable'
-- | Symantics able to give a name to a sub-schema, see 'define'.
class Definable repr where
  -- | @(define name expr)@ declares a rule named @(name)@
  -- and matching the 'RelaxNG' schema @(expr)@.
  --
  -- Useful for rendering the 'RelaxNG' schema,
  -- and necessary to avoid infinite recursion when
  -- printing a 'RelaxNG' schema calling itself recursively.
  --
  -- WARNING: 'DefineName's must be unique inside
  -- a whole 'RelaxNG' schema.
  define :: DefineName -> repr a k -> repr a k
  -- The default only calls 'define' on the untransformed interpreter,
  -- so it requires 'Definable' there. 'RelaxNG' has 'Definable' as a
  -- superclass, so instances relying on the former constraint still work.
  default define ::
   Transformable repr => Definable (UnTrans repr) =>
   DefineName -> repr a k -> repr a k
  define n = noTrans . define n . unTrans

-- ** Type 'DefineName'
-- | Name given to a rule by 'define'.
-- It must be unique inside a whole schema (see 'define').
type DefineName = String

-- * Type 'NameClass'
-- | A pattern over 'QName's, as used by 'elementMatch' and 'attributeMatch'.
-- See 'matchNameClass' for how each case is tested against a 'QName'.
data NameClass
 =  NameClass_Any
    -- ^ Matches every 'QName'.
 |  Namespace ::: NCName
    -- ^ Matches the 'QName' with exactly this namespace and local name.
 |  (:*) Namespace
    -- ^ Matches every 'QName' in this namespace.
 |  NameClass :-: NameClass
    -- ^ Matches what the left side matches, except what the right side matches.
 |  NameClass :|: NameClass
    -- ^ Matches what either side matches.

infixr 2 :|:
infixl 6 :-:
infix  9 :::

-- | Operator spelling of 'NameClass_Any': matches any 'QName'.
(*:*) :: NameClass
(*:*) = NameClass_Any

-- | @('matchNameClass' nc q)@ returns 'True' iff the 'NameClass' @(nc)@
-- matches the 'QName' @(q)@.
matchNameClass :: NameClass -> QName -> Bool
matchNameClass nameClass qname = go nameClass
  where
    go NameClass_Any = True
    go (ns ::: nl) = qNameSpace qname == ns && qNameLocal qname == nl
    go ((:*) ns) = qNameSpace qname == ns
    go (x :|: y) = go x || go y
    -- Difference: accepted by the left side but rejected by the right side.
    go (x :-: y) = go x && not (go y)

-- | Collect every 'Namespace' mentioned in the given 'NameClass',
-- each mapped to 'Nothing'. Both sides of ':|:' and ':-:' are visited,
-- and the results are combined with the 'HM.HashMap' semigroup.
namespacesNameClass :: NameClass -> HM.HashMap Namespace (Maybe NCName)
namespacesNameClass nameClass =
  case nameClass of
    NameClass_Any -> HM.empty
    ns ::: _local -> mention ns
    (:*) ns -> mention ns
    x :|: y -> namespacesNameClass x <> namespacesNameClass y
    x :-: y -> namespacesNameClass x <> namespacesNameClass y
  where
    mention ns = HM.singleton ns Nothing

-- | Parses a 'NameClass' from:
--
-- * @"*"@, giving 'NameClass_Any';
-- * @"{some-namespace}*"@, giving every name in that namespace;
-- * @"{some-namespace}some-localname"@, giving that exact name;
-- * any other string, which is handed to the 'IsString' instance of 'QName'.
--
-- An opening brace without a closing one is rejected with 'error'.
instance IsString NameClass where
  fromString = \case
   "*" -> NameClass_Any
   full@('{':rest) ->
    -- 'List.break' leaves the closing brace at the head of the suffix,
    -- so it must appear in the patterns below (otherwise "{ns}*" would
    -- be misread as the exact local name "*").
    case List.break (== '}') rest of
     (ns, "}*") -> (:*) (fromString ns)
     (ns, '}':local) -> fromString ns ::: fromString local
     _ -> error $ "Invalid XML Clark notation: "<>show full
   s -> let QName ns nl = fromString s in ns:::nl
-- | Render a 'NameClass'. The 'Namespaces' are used to prefix names, and
-- the operator context decides (through 'pairIfNeeded') whether the
-- ' | ' and ' - ' combinations must be parenthesized.
instance Textify (Namespaces NCName, (Infix,Side), NameClass) where
  textify (nss, po, nc) =
    case nc of
      NameClass_Any -> textify '*'
      ns ::: nl -> textify (prefixifyQName nss (QName ns nl))
      (:*) ns ->
        -- Use a known prefix when there is one, otherwise Clark notation.
        case HM.lookup ns (namespaces_prefixes nss) of
          Just p -> textify p <> ":*"
          Nothing -> "{" <> textify ns <> "}*"
      x :|: y ->
        let op = infixR 2 in
        pairIfNeeded pairParen po op $
          textify (nss, (op, SideL), x) <> " | " <> textify (nss, (op, SideR), y)
      x :-: y ->
        let op = infixL 6 in
        pairIfNeeded pairParen po op $
          textify (nss, (op, SideL), x) <> " - " <> textify (nss, (op, SideR), y)