{-------------------------------------------------------------------------------------
-
- A Compiler from XQuery to Haskell
- Programmer: Leonidas Fegaras
- Email: fegaras@cse.uta.edu
- Web: http://lambda.uta.edu/
- Creation: 02/15/08, last update: 05/02/08
- 
- Copyright (c) 2008 by Leonidas Fegaras, the University of Texas at Arlington. All rights reserved.
- This material is provided as is, with absolutely no warranty expressed or implied.
- Any use is at your own risk. Permission is hereby granted to use or copy this program
- for any purpose, provided the above notices are retained on all copies.
-
--------------------------------------------------------------------------------------}


{-# OPTIONS_GHC -fth #-}

module XQueryCompiler where

import Char(isDigit)
import List(sortBy)
import Language.Haskell.TH
import XMLParse(parseDocument)
import HXML(AttList)
import XQueryParser
import XTree
import XQueryOptimizer


{--------------- XPath Steps ---------------------------------------------------------}


-- XPath step /tag or /*
-- Returns the element children of an element whose tag is equal to the
-- requested tag (every element child when the tag is "*").
-- Anything that is not an element yields the empty sequence.
child_step :: Tag -> XTree -> XSeq
child_step tag (XElem _ _ _ kids)
    = [ kid | kid@(XElem name _ _ _) <- kids, name == tag || tag == "*" ]
child_step _ _ = []


-- XPath step //tag or //*
-- Walks the tree in document order and collects every element whose tag
-- matches (or every element when the tag is "*"); the starting node itself
-- is included when it matches. Non-element nodes contribute nothing.
descendant_step :: Tag -> XTree -> XSeq
descendant_step tag node
    = case node of
        XElem name _ _ kids
            -> let below = concatMap (descendant_step tag) kids
               in if tag == name || tag == "*"
                     then node:below
                  else below
        _ -> []


-- It's like //* but keeps only the elements that have a child element for
-- each of the given tags (the tags are derived statically).
-- Only the first 100 children are examined before giving up: this avoids space leaks
descendant_any_with_tagged_children :: [Tag] -> XTree -> XSeq
descendant_any_with_tagged_children tags node
    = case node of
        XElem _ _ _ kids
            -> let below = concatMap (descendant_any_with_tagged_children tags) kids
                   probe = take 100 kids
                   -- is there an element child with this tag among the probed children?
                   hasChild tag = foldr (\kid found -> case kid of
                                                         (XElem k _ _ _) -> found || k == tag
                                                         _ -> found) False probe
               in if all hasChild tags
                     then node:below
                  else below
        _ -> []


-- XPath step /@attr or /@*
-- Returns, as text nodes, the values of the attributes of an element whose
-- name is equal to the requested one (all attributes when the name is "*").
-- Anything that is not an element yields the empty sequence.
attribute_step :: Tag -> XTree -> XSeq
attribute_step name (XElem _ attrs _ _)
    = [ XText value | (key,value) <- attrs, key == name || name == "*" ]
attribute_step _ _ = []


-- XPath step //@attr or //@*
-- Collects the matching attribute values of the node itself followed by
-- those of all its descendants; values are returned as text nodes.
attribute_descendant_step :: Tag -> XTree -> XSeq
attribute_descendant_step name node
    = case node of
        XElem _ attrs _ kids
            -> [ XText value | (key,value) <- attrs, key == name || name == "*" ]
               ++ concatMap (attribute_descendant_step name) kids
        _ -> []


{------------ Functions --------------------------------------------------------------}


-- look up the value bound to a variable in an association list;
-- the first binding wins and a missing variable is a fatal error
findV :: String -> [(String,b)] -> b
findV var env
  = maybe (error ("Undefined variable: "++var)) id (lookup var env)

-- does the association list contain a binding for the variable?
memV :: Eq a => a -> [(a,b)] -> Bool
memV var env = any (\(n,_) -> n==var) env


-- like foldr but the combining function also receives the index of each
-- element; the last argument is the index given to the first element
foldir :: (a -> Int -> b -> b) -> b -> [a] -> Int -> b
foldir c n = walk
    where walk [] _ = n
          walk (x:xs) i = c x i (walk xs (i+1))

-- the XTree value that the system functions below emit (as a singleton
-- sequence) to signal a true condition
trueXT = XBool True

-- Parse an unsigned number: a non-empty run of digits gives an XInt and
-- digits '.' digits gives an XFloat. Anything else gives Nothing.
-- Both digit runs must be non-empty: otherwise inputs such as "", ".5",
-- "5." or "." would be wrapped in Just around a 'read' that fails only
-- when the value is finally forced.
readNum :: String -> Maybe XTree
readNum cs = case span isDigit cs of
               (n@(_:_),[]) -> Just (XInt (read n))
               (n@(_:_),'.':rest) -> case span isDigit rest of
                                       (k@(_:_),[]) -> Just (XFloat (read (n++('.':k))))
                                       _ -> Nothing
               _ -> Nothing

-- Extract the atomic values of a sequence: an atomic value (text, integer,
-- float, boolean) is kept as is, an element whose only child is an atomic
-- value is replaced by that child, and everything else is dropped.
text :: XSeq -> XSeq
text xs = concatMap atomize xs
    where atomize (XElem _ _ _ [z]) | atomic z = [z]
          atomize x | atomic x = [x]
          atomize _ = []
          atomic (XText _) = True
          atomic (XInt _) = True
          atomic (XFloat _) = True
          atomic (XBool _) = True
          atomic _ = False

-- The string form of each atomic value of a sequence (see text):
-- text is returned unchanged, numbers and booleans are shown.
toString :: XSeq -> [String]
toString xs = [ render x | x <- text xs ]
    where render (XText t) = t
          render (XInt n) = show n
          render (XFloat n) = show n
          render (XBool n) = show n

-- The numeric values of a sequence: integers and floats are kept, text is
-- kept only if readNum can parse it, and everything else is dropped.
toNum :: XSeq -> XSeq
toNum xs = concatMap numeric (text xs)
    where numeric x@(XInt _) = [x]
          numeric x@(XFloat _) = [x]
          numeric (XText s) = maybe [] (\t -> [t]) (readNum s)
          numeric _ = []

-- Convert an atomic value to a Float; text is parsed with readNum.
-- Values that cannot be converted raise an error.
toFloat :: XTree -> Float
toFloat x
    = case x of
        XInt n -> fromIntegral n
        XFloat n -> n
        XText s -> case readNum s of
                     Just (XFloat n) -> n
                     Just (XInt n) -> fromIntegral n
                     _ -> error("Cannot convert to a float: "++s)
        _ -> error("Cannot convert to a float: "++(show x))

-- Is the second string a substring of the first?
-- Every suffix of the first string (including the empty one) is tested
-- for starting with the second string.
contains :: String -> String -> Bool
contains xs ys = any startsHere (suffixes xs)
    where width = length ys
          startsHere s = take width s == ys
          suffixes s = s : case s of
                             [] -> []
                             (_:rest) -> suffixes rest

-- Remove duplicates, keeping the first occurrence of each value in its
-- original position.
-- The result is produced lazily: an element is emitted as soon as it is
-- known to be new, so the consumer can start before the whole input has
-- been read (the previous left fold with r++[a] had to traverse the entire
-- input before returning anything).
distinct :: Eq a => [a] -> [a]
distinct = go []
    where go _ [] = []
          go seen (x:xs) | elem x seen = go seen xs
                         | otherwise = x : go (x:seen) xs

-- Apply a floating-point operation to two numeric values.
-- Two integers give an integer only when the result is a whole number;
-- otherwise the result is returned as a float. Unconditionally rounding the
-- result made 7 div 2 evaluate to 4 and turned a division by zero into an
-- arbitrary integer. Any operand combination that involves a float gives a
-- float. Only XInt and XFloat operands are supported (callers pass the
-- output of toNum).
arithmetic :: (Float -> Float -> Float) -> XTree -> XTree -> XTree
arithmetic op (XInt n) (XInt m)
    = let r = op (fromIntegral n) (fromIntegral m)
          k = round r
      in if fromIntegral k == r     -- false for fractions, infinities and NaN
            then XInt k
         else XFloat r
arithmetic op (XFloat n) (XFloat m) = XFloat (op n m)
arithmetic op (XFloat n) (XInt m) = XFloat (op n (fromIntegral m))
arithmetic op (XInt n) (XFloat m) = XFloat (op (fromIntegral n) m)

-- Compare two values for the general comparison operators.
-- An element on either side always compares as EQ; numbers are compared
-- numerically (integers are promoted when mixed with floats); two texts are
-- compared as strings; any other combination is compared after converting
-- both sides with toFloat.
compareXTrees :: XTree -> XTree -> Ordering
compareXTrees a b
    = case (a,b) of
        (XElem _ _ _ _, _) -> EQ
        (_, XElem _ _ _ _) -> EQ
        (XInt n, XInt m) -> compare n m
        (XFloat n, XInt m) -> compare n (fromIntegral m)
        (XInt n, XFloat m) -> compare (fromIntegral n) m
        (XFloat n, XFloat m) -> compare n m
        (XText n, XText m) -> compare n m
        _ -> compare (toFloat a) (toFloat b)

-- Compare two singleton sequences of atomic values (used by the value
-- comparison operators eq, lt, ...). Numbers are compared numerically and
-- texts as strings; any other pair of operands is an error.
strictCompareOne :: XSeq -> XSeq -> Ordering
strictCompareOne x y
    = case (x,y) of
        ([XInt n], [XInt m]) -> compare n m
        ([XFloat n], [XFloat m]) -> compare n m
        ([XFloat n], [XInt m]) -> compare n (fromIntegral m)
        ([XInt n], [XFloat m]) -> compare (fromIntegral n) m
        ([XText n], [XText m]) -> compare n m
        _ -> error ("Illegal operands in strict comparison: "++(show x)++" "++(show y))

-- Strict (value) comparison of two sequences: a sequence that consists of a
-- single element is replaced by the children of that element before the
-- operands are handed to strictCompareOne.
strictCompare :: XSeq -> XSeq -> Ordering
strictCompare x y = strictCompareOne (content x) (content y)
    where content [XElem _ _ _ kids] = kids
          content s = s

-- Compare two sort keys. All pairs of values from the two sequences are
-- compared: if every pair is LT (or every pair is GT) that is the outcome,
-- otherwise the keys are EQ. When the flag is False the outcome is reversed
-- (descending order).
compareXSeqs :: Bool -> XSeq -> XSeq -> Ordering
compareXSeqs ord xs ys
    | every LT = if ord then LT else GT
    | every GT = if ord then GT else LT
    | otherwise = EQ
    where comps = [ compareXTrees x y | x <- xs, y <- ys ]
          every o = all (\c -> c == o) comps

-- The boolean value of a sequence when it is used as a condition.
-- The empty sequence, the empty string, a numeric zero, a NaN and the
-- boolean false are false; every other sequence is true.
-- A float zero or NaN is treated the same way as the integer zero, as in
-- the XPath effective boolean value (it used to be true, unlike XInt 0).
conditionTest :: XSeq -> Bool
conditionTest [] = False
conditionTest [XText ""] = False
conditionTest [XInt 0] = False
conditionTest [XFloat f] | f == 0 || isNaN f = False
conditionTest [XBool False] = False
conditionTest _ = True


-- XPath steps: the name used for each step in the AST paired with the code
-- of the Haskell function that implements it
paths :: [(Tag,Q Exp)]
paths
    = [ ( "child_step", [| child_step |] )
      , ( "descendant_step", [| descendant_step |] )
      , ( "attribute_step", [| attribute_step |] )
      , ( "attribute_descendant_step", [| attribute_descendant_step |] )
      ]


-- A system function is a code generator: given the code of its arguments,
-- it builds the code of the function call
type Function = [Q Exp] -> Q Exp

-- System functions: they can also be defined as Haskell functions of type (XSeq,...,XSeq) -> XSeq
-- but here we make sure they are unfolded and fused with the rest of the query
-- Each entry gives the function name, its number of arguments, and the code
-- generator that receives the code of the arguments (in order).
-- Conditions are represented by sequences: a non-empty sequence of trueXT
-- for true and the empty sequence for false.
-- avg, min and max return the empty sequence when the argument contains no
-- numeric value (minimum/maximum of an empty list is a run-time error and
-- the average used to be NaN); avg divides by the number of numeric values
-- that were actually summed.
functions :: [(Tag,Int,Function)]
functions = [ ( "=", 2, \[xs,ys] -> [| [ trueXT | x <- text $xs, y <- text $ys, compareXTrees x y == EQ ] |] ),
              ( "!=", 2, \[xs,ys] -> [| if null [ trueXT | x <- text $xs, y <- text $ys, compareXTrees x y == EQ ] then [trueXT] else [] |] ),
              ( ">", 2, \[xs,ys] -> [| [ trueXT | x <- text $xs, y <- text $ys, compareXTrees x y == GT ] |] ),
              ( "<", 2, \[xs,ys] -> [| [ trueXT | x <- text $xs, y <- text $ys, compareXTrees x y == LT ] |] ),
              ( ">=", 2, \[xs,ys] -> [| [ trueXT | x <- text $xs, y <- text $ys, compareXTrees x y `elem` [GT,EQ] ] |] ),
              ( "<=", 2, \[xs,ys] -> [| [ trueXT | x <- text $xs, y <- text $ys, compareXTrees x y `elem` [LT,EQ] ] |] ),
              ( "eq", 2, \[xs,ys] -> [| if strictCompare $xs $ys == EQ then [trueXT] else [] |] ),
              ( "neq", 2, \[xs,ys] -> [| if strictCompare $xs $ys /= EQ then [trueXT] else [] |] ),
              ( "lt", 2, \[xs,ys] -> [| if strictCompare $xs $ys == LT then [trueXT] else [] |] ),
              ( "gt", 2, \[xs,ys] -> [| if strictCompare $xs $ys == GT then [trueXT] else [] |] ),
              ( "le", 2, \[xs,ys] -> [| if strictCompare $xs $ys `elem` [LT,EQ] then [trueXT] else [] |] ),
              ( "ge", 2, \[xs,ys] -> [| if strictCompare $xs $ys `elem` [GT,EQ] then [trueXT] else [] |] ),
              ( "<<", 2, \[xs,ys] -> [| [ trueXT | XElem _ _ ox _ <- $xs, XElem _ _ oy _  <- $ys, ox < oy ] |] ),
              ( ">>", 2, \[xs,ys] -> [| [ trueXT | XElem _ _ ox _ <- $xs, XElem _ _ oy _  <- $ys, ox > oy ] |] ),
              ( "is", 2, \[xs,ys] -> [| [ trueXT | XElem _ _ ox _ <- $xs, XElem _ _ oy _  <- $ys, ox == oy ] |] ),
              ( "+", 2, \[xs,ys] -> [| [ arithmetic (+) x y | x <- toNum $xs, y <- toNum $ys ] |] ),
              ( "-", 2, \[xs,ys] -> [| [ arithmetic (-) x y | x <- toNum $xs, y <- toNum $ys ] |] ),
              ( "*", 2, \[xs,ys] -> [| [ arithmetic (*) x y | x <- toNum $xs, y <- toNum $ys ] |] ),
              ( "div", 2, \[xs,ys] -> [| [ arithmetic (/) x y | x <- toNum $xs, y <- toNum $ys ] |] ),
              ( "idiv", 2, \[xs,ys] -> [| [ XInt (div x y) | (XInt x) <- toNum $xs, (XInt y) <- toNum $ys ] |] ),
              ( "mod", 2, \[xs,ys] -> [| [ XInt (mod x y) | (XInt x) <- toNum $xs, (XInt y) <- toNum $ys ] |] ),
              ( "uplus", 1, \[xs] -> [| [ x | x <- toNum $xs ] |] ),
              ( "uminus", 1, \[xs] -> [| [ case x of XInt n -> XInt (-n); XFloat n -> XFloat (-n) | x <- toNum $xs ] |] ),
              ( "and", 2, \[xs,ys] -> [| if (conditionTest $xs) && (conditionTest $ys) then [trueXT] else [] |] ),
              ( "or", 2, \[xs,ys] -> [| if (conditionTest $xs) || (conditionTest $ys) then [trueXT] else [] |] ),
              ( "not", 1, \[xs] -> [| if (conditionTest $xs) then [] else [trueXT] |] ),
              ( "some", 1, \[xs] -> [| if (conditionTest $xs) then [trueXT] else [] |] ),
              ( "count", 1, \[xs] -> [| [ XInt (length $xs) ] |] ),
              ( "sum", 1, \[xs] -> [| [ XFloat (sum [ toFloat x | x <- toNum $xs ]) ] |] ),
              ( "avg", 1, \[xs] -> [| let ns = [ toFloat x | x <- toNum $xs ]
                                      in if null ns then [] else [ XFloat ((sum ns) / (fromIntegral (length ns))) ] |] ),
              ( "min", 1, \[xs] -> [| let ns = [ toFloat x | x <- toNum $xs ]
                                      in if null ns then [] else [ XFloat (minimum ns) ] |] ),
              ( "max", 1, \[xs] -> [| let ns = [ toFloat x | x <- toNum $xs ]
                                      in if null ns then [] else [ XFloat (maximum ns) ] |] ),
              ( "to", 2, \[xs,ys] -> [| [ XInt i | XInt n <- toNum $xs, XInt m <- toNum $ys, i <- [n..m] ] |] ),
              ( "text", 1, \[xs] -> [| text $xs |] ),
              ( "string", 1, \[xs] -> [| text $xs |] ),
              ( "data", 1, \[xs] -> [| text $xs |] ),
              ( "node", 1, \[xs] -> [| $xs |] ),
              ( "empty", 0, \[] -> [| [] |] ),
              ( "true", 0, \[] -> [| [trueXT] |] ),
              ( "false", 0, \[] -> [| [] |] ),
              ( "if", 3, \[cs,ts,es] -> [| if conditionTest $cs then $ts else $es |] ),
              ( "element", 2, \[tags,xs] -> [| [ x | tag <- toString $tags, x@(XElem t _ _ _) <- $xs, (t==tag || tag=="*") ] |] ),
              ( "attribute", 2, \[tags,xs] -> [| [ z | tag <- toString $tags, x <- $xs, z <- attribute_step tag x ] |] ),
              ( "name", 1, \[xs] -> [| [ XText tag | XElem tag _ _ _ <- $xs ] |] ),
              ( "contains", 2, \[xs,text] -> [| [ trueXT | x <- toString $xs, t <- toString $text, contains x t ] |] ),
              ( "concatenate", 2, \[xs,ys] -> [| $xs ++ $ys |] ),
              ( "concat", 2, \[xs,ys] -> [| [ XText (showXS ($xs ++ $ys)) ] |] ),
              ( "distinct-values", 1, \[xs] -> [| distinct $xs |] ),
              ( "union", 2, \[xs,ys] -> [| distinct ($xs ++ $ys) |] ),
              ( "intersect", 2, \[xs,ys] -> [| filter (\x -> elem x $ys) $xs |] ),
              ( "except", 2, \[xs,ys] -> [| filter (\x -> not (elem x $ys)) $xs  |] ),
              ( "reverse", 1, \[xs] -> [| reverse $xs |] )
            ]


-- functions to be used by the interpreter
-- when evaluated, it gives [(String,Int,[XSeq]->XSeq)]
-- Every system function is wrapped in a lambda over a list pattern of
-- variables v_1 ... v_len, whose body is the code generated for the call.
iFunctions :: Q Exp
iFunctions = foldr consEntry [| [] |] functions
    where consEntry (fname,len,f) rest = [| $(entry fname len f) : $rest |]
          entry fname len f
              = let vars = [ mkName ("v_"++(show i)) | i <- [1..len] ]
                in tupE [ litE (StringL fname)
                        , litE (IntegerL (toInteger len))
                        , lamE [listP (map varP vars)] (f (map varE vars)) ]


-- XPath steps to be used by the interpreter
-- when evaluated, it gives [(String,Tag->XTree->XSeq)]
pFunctions :: Q Exp
pFunctions = foldr consStep [| [] |] paths
    where consStep (pname,p) rest = [| ($(litE (StringL pname)),$p) : $rest |]


-- make a function call
-- A name that matches a system function (with or without the "fn:" prefix)
-- is expanded in place by its code generator, after checking the number of
-- arguments. Any other name must be a Haskell function of type
-- (XSeq,...,XSeq) -> XSeq: the call is generated with that type signature
-- and with the arguments packed in a tuple.
callF :: Tag -> Function
callF fname args
    = case [ (len,f) | (n,len,f) <- functions, n == fname || ("fn:"++n)==fname ] of
        (len,f):_
            | arity == len -> f args
            | otherwise -> error ("wrong number of arguments in function call: " ++ fname)
        [] -> appE external (tupE args)
    where arity = length args
          xseqT = [t| XSeq |]
          -- the argument type of an external function: (), XSeq, or a tuple of XSeqs
          argT | arity == 0 = [t| () |]
               | arity == 1 = xseqT
               | otherwise = foldl appT (tupleT arity) (replicate arity xseqT)
          external = sigE (varE (mkName fname))
                          (appT (appT arrowT argT) xseqT)


{------------ Compiler ---------------------------------------------------------------}


-- Code used in place of the context node, position() and last() when they
-- are not available (see compileAst and the for-loops in compile); the
-- generated expression is a call to error with the message below
undef1 = [| error "Undefined XQuery context (.)" |]
undef2 = [| error "Undefined position()" |]
undef3 = [| error "Undefined last()" |]


-- does the expression contain a last()?
-- The search does not descend into let, for, predicate or any other step
-- expressions.
containsLast :: Ast -> Bool
containsLast e
    = case e of
        Ast "step" [Ast "call" [Avar "last"]] -> True
        Ast f args
            | f `elem` ["let","for","predicate","step"] -> False
            | otherwise -> any containsLast args
        _ -> False


-- calculate the maximum position value used in a predicate, if there is one
-- position: the AST that denotes the position (position() or an index variable)
-- The result is 0 when no bound can be derived. A conjunction takes the
-- smaller of the known bounds, a disjunction the larger one, and a let/for
-- that rebinds the position variable hides it.
maxPosition :: Ast -> Ast -> Int
maxPosition position e
    = case e of
        Ast "call" [Avar f,Ast "step" [p],Aint n]
            | f `elem` ["=","<","<=","eq","lt","le"] && p == position
            -> n
        Ast "call" [Avar f,Aint n,Ast "step" [p]]
            | f `elem` ["=",">",">=","eq","gt","ge"] && p == position
            -> n
        Ast "let" [Avar x,source,body]
            | position == Avar x -> 0
            | otherwise -> tighter (bound source) (bound body)
        Ast "for" [Avar x,Avar i,source,body]
            | position == Avar x || position == Avar i -> 0
            | otherwise -> tighter (bound source) (bound body)
        Ast "predicate" [cond,body]
            -> tighter (bound cond) (bound body)
        Ast "call" [Avar "and",x,y]
            -> tighter (bound x) (bound y)
        Ast "call" [Avar "or",x,y]
            -> max (bound x) (bound y)
        _ -> 0
    where bound = maxPosition position
          -- the smaller of two bounds, where 0 stands for "no bound"
          tighter 0 y = y
          tighter x 0 = x
          tighter x y = min x y


-- the AST of the XPath call position()
pathPosition = Ast "call" [Avar "position"]


-- Each XPath predicate must calculate position() and last() from its input XSeq
-- if last() is used, then the evaluation is blocking (need to store the whole input XSeq)
-- Arguments: the predicates (applied left to right), the code of the input
-- sequence, and a flag that allows the top-k rewrite for the next predicate
-- (it is cleared after the rewrite so that it is applied only once).
-- A predicate whose value is a single integer selects by position; any
-- other value is tested with conditionTest.
compilePredicates :: [Ast] -> Q Exp -> Bool -> Q Exp
compilePredicates [] xs _ = xs
compilePredicates ((Aint n):preds) xs _   -- shortcut that improves laziness
    -- An index beyond the end of the sequence (or smaller than 1) selects
    -- nothing; take/drop give the empty sequence where !! raises an error
    = compilePredicates preds
            (if n < 1
             then [| [] |]
             else [| take 1 (drop $(litE (IntegerL (toInteger (n-1)))) $xs) |]) True
compilePredicates (pred:preds) xs True    -- top-k like
    | maxPosition pathPosition pred > 0
    -- the predicate cannot hold beyond this position: truncate the input first
    = compilePredicates (pred:preds)
           [| take $(litE (IntegerL (toInteger (maxPosition pathPosition pred)))) $xs |] False
compilePredicates (pred:preds) xs _
    | containsLast pred         -- blocking: use only when last() is used in the predicate
    = compilePredicates preds
            [| let bl = $xs
                   len = length bl
               in foldir (\x i r -> if case $(compile pred [| x |] [| [XInt i] |] [| [XInt len] |] "") of
                                         [XInt k] -> k == i               -- indexing
                                         b -> conditionTest b
                                    then x:r else r) [] bl 1 |] True
compilePredicates (pred:preds) xs _
    -- streaming case: last() is not available inside the predicate
    = compilePredicates preds
            [| foldir (\x i r -> if case $(compile pred [| x |] [| [XInt i] |] undef3 "") of
                                      [XInt k] -> k == i               -- indexing
                                      b -> conditionTest b
                                 then x:r else r) [] $xs 1 |] True


-- extract the QName: the sequence must be a single text value,
-- anything else is an error
qName :: XSeq -> Tag
qName e = case e of
            [XText s] -> s
            _ -> error ("Invalid QName: "++(show e))


-- Compile the AST e into Haskell code
-- context: context node (XPath .)
-- position: the element position in the parent sequence (XPath position())
-- last: the length of the parent sequence (XPath last())
-- effective_axis: the XPath axis in /axis::tag(exp)
--        (eg, the effective axis of //(A | B) is "descendant_step")
-- context, position and last are pieces of code (not values) that are
-- spliced into the generated expression wherever they are needed.
-- An AST that matches none of the cases below is reported as an error.
compile :: Ast -> Q Exp -> Q Exp -> Q Exp -> String -> Q Exp
compile e context position last effective_axis
  = case e of
      -- the context node as a singleton sequence
      Avar "." -> [| [ $context :: XTree ] |]
      -- a reference to the Haskell variable with the same name
      Avar v -> let x = varE (mkName v)
                in [| $x :: XSeq |]
      -- literals become singleton sequences
      Aint n -> let x = litE (IntegerL (toInteger n))
                in [| [ XInt $x ] |]
      Afloat n -> let x = litE (RationalL (toRational n))
                  in [| [ XFloat $x ] |]
      Astring s -> let x = litE (StringL s)
                   in [| [ XText $x ] |]
      -- evaluate body once for every item of v, using the item as the
      -- context node and dp as the effective axis
      Ast "context" [v,Astring dp,body]
          -> [| foldr (\x r -> $(compile body [| x |] position last dp)++r)
                      [] $(compile v context position last effective_axis) |]
      -- a document is referred to by the variable _doc<n> (bound by compileQuery)
      Ast "doc" [Aint n] -> let d = varE (mkName ("_doc"++(show n))) in [| [ $d ] |]
      Ast "call" [Avar "position"]
          -> position
      Ast "call" [Avar "last"]
          -> last
      -- a child step from the context node is replaced by the effective axis,
      -- if there is one; the axis is cleared for the recursive call
      Ast "step" [Ast "child_step" [tag, Avar "."]]
          | effective_axis /= ""
          -> compile (Ast "step" [Ast effective_axis [tag, Avar "."]]) context position last ""
      -- the predicates are applied separately to the result of each input node
      Ast "step" ((Ast "descendant_any" (body:tags)):predicates)
          -> let bc = compile body context position last effective_axis
                 ts = listE (map (\(Avar tag) -> litE (stringL tag)) tags)
             in [| foldr (\x r -> $(compilePredicates predicates [| descendant_any_with_tagged_children $ts x |] True)++r)
                         [] $bc |]
      -- one of the XPath steps listed in paths
      Ast "step" ((Ast path_step [Astring tag,body]):predicates)
          |  memV path_step paths
          -> let bc = compile body context position last effective_axis
                 tc = litE (stringL tag)
             in [| foldr (\x r -> $(compilePredicates predicates [| $(findV path_step paths) $tc x |] True)++r)
                         [] $bc |]
      Ast "step" [exp]
          -> compile exp context position last effective_axis
      -- any other expression followed by predicates: the predicates are
      -- applied to the whole result
      Ast "step" (exp:predicates)
          -> compilePredicates predicates (compile exp context position last effective_axis) True
      Ast "predicate" [condition,body]
          -> compilePredicates [condition] (compile body context position last effective_axis) True
      -- function call (system function or external Haskell function, see callF)
      Ast "call" ((Avar f):args)
          -> callF f (map (\x -> compile x context position last effective_axis) args)
      -- element construction with a constant tag and no attributes
      Ast "construction" [Astring tag,Ast "attributes" [],body]
          -> let ct = litE (StringL tag)
                 bc = compile body context position last effective_axis
             in [| [ XElem $ct [] 0 $bc ] |]
      -- general element construction: the tag and the attribute names are
      -- computed and must evaluate to a QName (see qName)
      Ast "construction" [tag,Ast "attributes" al,body]
          -> let alc = foldr (\(Ast "pair" [a,v]) r
                                  -> let ac = compile a context position last effective_axis
                                         vc = compile v context position last effective_axis
                                     in [| (qName $ac,showXS (text $vc)) : $r |]) [| [] |] al
                 ct = compile tag context position last effective_axis
                 bc = compile body context position last effective_axis
             in [| [ XElem (qName $ct) $alc 0 $bc ] |]
      -- let-binding: a lambda over the variable applied to the source
      Ast "let" [Avar var,source,body]
          -> do s <- compile source context position last effective_axis
                b <- compile body context position last effective_axis
                return (AppE (LamE [VarP (mkName var)] b) s)
      -- in both for-loops the variable is bound to a singleton sequence and
      -- the context node of the body is the head of that sequence
      Ast "for" [Avar var,Avar "$",source,body]      -- a for-loop without an index
          -> let b = compile body [| head $(varE (mkName var)) |] undef2 undef3 ""
                 f = lamE [varP (mkName var)] [| \r -> $b ++ r |]
                 s = compile source context position last effective_axis
             in [| foldr (\x -> $f [x]) [] $s |]
      Ast "for" [Avar var,Avar ivar,source,body]     -- a for-loop with an index
          -> let b = compile body [| head $(varE (mkName var)) |]
                             [| $(varE (mkName ivar)) |] undef3 ""
                 f = lamE [varP (mkName var)] (lamE [varP (mkName ivar)] [| \r -> $b ++ r |])
                 p = maxPosition (Avar ivar) body
                 ns = if p > 0              -- there is a top-k like restriction
                      then Ast "step" [source,Ast "call" [Avar "<=",Ast "step" [pathPosition],Aint p]]
                      else source
                 s = compile ns context position last effective_axis
             in [| foldir (\x i -> $f [x] [XInt i]) [] $s 1 |]
      -- the result is a single tuple: the value of exp followed by the
      -- atomic values (text) of each order-by expression
      Ast "sortTuple" (exp:orderBys)             -- prepare each FLWOR tuple for sorting
          -> let res = foldl (\r a -> let ac = compile a context position last effective_axis
                                      in [| $r++[text $ac] |] )
                             [| [ $(compile exp context position last effective_axis) ] |] orderBys
             in [| [ $res ] |]
      -- the tuples are sorted on their keys (everything after the head), one
      -- key at a time in the direction given by ordList (anything other than
      -- "ascending" sorts in reverse); the heads of the sorted tuples are
      -- then concatenated
      Ast "sort" (exp:ordList)                   -- blocking
          -> let ce = compile exp context position last effective_axis
                 ordering = foldr (\(Avar ord) r
                                       -> let asc = if ord == "ascending"
                                                    then [| True |]
                                                    else [| False |]
                                          in [| \(x:xs) (y:ys) -> case compareXSeqs $asc x y of
                                                                    EQ -> $r xs ys
                                                                    o -> o |])
                                  [| \xs ys -> EQ |] ordList
             in [| concatMap head (sortBy (\(_:xs) (_:ys) -> $ordering xs ys) ($ce::[[XSeq]])) |]
      _ -> error ("Illegal XQuery: "++(show e))


-- collect all input documents and assign them a unique number
-- query: the AST to be rewritten
-- count: the first document number that is still unused
-- Returns the AST in which every call to doc/fn:doc is replaced by
-- Ast "doc" [Aint n], the next unused document number, and the list of
-- (document number, source). A source is either the AST of the argument of
-- the doc call or, when the same argument was already collected under the
-- number m, Aint m (an alias for document m).
getDocs :: Ast -> Int -> (Ast, Int, [(Int, Ast)])
getDocs query count =
    case query of
      Ast "call" [Avar "doc",file]
             -> (Ast "doc" [Aint count], count+1, [(count,file)])
      Ast "call" [Avar "fn:doc",file]
             -> (Ast "doc" [Aint count], count+1, [(count,file)])
      -- the fold builds a function of the counter, so that the counter is
      -- threaded through the arguments from left to right
      Ast n args -> let (s,c,ns) = foldr (\a r c -> let (e,c1,n1) = getDocs a c
                                                        (s,c2,n2) = r c1
                                                    in (e:s,c2,docUnion n1 n2))
                                         (\c -> ([],c,[])) args count
                    in (Ast n s,c,ns)
      _ -> (query,count,[])
    -- docUnion xs ys: the entries of ys come first, followed by xs; the
    -- source of an entry of ys that is equal to the source of an entry
    -- (m,_) of xs is replaced by Aint m
    where docUnion xs ((n,s):ys) = (n,foldr(\(m,d) r -> if s==d then Aint m else r) s xs):(docUnion xs ys)
          docUnion xs [] = xs


-- optimize an AST and compile it at the top level, where there is no
-- context node, no position() and no last()
compileAst :: Ast -> Q Exp
compileAst ast = compile optimized undef1 undef2 undef3 ""
    where optimized = optimize ast


-- compile an XQuery AST that reads XML documents
-- The input is a list of declarations followed by the query itself.
-- A function or variable declaration becomes a let-binding around the code
-- of the remaining list. The generated code for the query is an IO action.
-- NOTE(review): an empty list matches none of the equations below.
compileQuery :: [Ast] -> Q Exp
compileQuery ((Ast "function" ((Avar f):b:args)):xs)
    -- a single Astring parameter is bound directly; otherwise the parameters
    -- (Avar) are bound by a tuple pattern
    = let lvars = case args of
                    [Astring a] -> [varP (mkName a)]
                    _ -> [tupP (map (\(Avar a) -> varP (mkName a)) args)]
      in letE [valD (varP (mkName f)) (normalB (lamE lvars (compileAst b))) []]
              (compileQuery xs)
compileQuery ((Ast "variable" [Avar v,u]):xs)
    = letE [valD (varP (mkName v)) (normalB (compileAst u)) []]
           (compileQuery xs)
compileQuery [query]
    -- every document collected by getDocs is bound to the variable _doc<n>
    -- by a lambda around the code built so far (initially: return the
    -- compiled query)
    = let (ast,_,ns) = getDocs query 0
          code = compileAst ast
      in foldl (\r (n,file) -> let d = lamE [varP (mkName ("_doc"++(show n)))] r
                                       in case file of
                                            -- an alias: reuse the document bound to _doc<m>
                                            Aint m -> [| $d $(varE (mkName ("_doc"++(show m)))) |]
                                            -- the file expression must evaluate to a single text
                                            -- value: the file with that name is read and parsed
                                            _ -> [| do let [XText f] = $(compileAst file)
                                                       doc <- readFile f
                                                       $d (materialize (parseDocument doc)) |])
               [| return $code |] ns


-- Debugging: print the AST of an input XQuery, the optimized form of its
-- last component, and the Haskell code generated for it
cq :: String -> IO ()
cq query = do let ast = parse (scan query)
                  opt = optimize (last ast)
              putStrLn "Abstract Syntax Tree:"
              print ast
              putStrLn "Optimized AST:"
              print opt
              putStrLn "Haskell Code:"
              generated <- runQ (compileQuery ast)
              putStrLn (pprint generated)


-- Run an XQuery expression that does not read XML documents
-- Only the last component of the parsed query is compiled.
-- When evaluated, it returns XSeq
xe :: String -> Q Exp
xe = compileAst . last . parse . scan


-- Run an XQuery that reads XML documents
-- When evaluated, it returns IO XSeq
xq :: String -> Q Exp
xq = compileQuery . parse . scan