{-------------------------------------------------------------------------------------
-
- Database connectivity using HDBC
- Programmer: Leonidas Fegaras
- Email: fegaras@cse.uta.edu
- Web: http://lambda.uta.edu/
- Creation: 05/12/08, last update: 06/12/08
- 
- Copyright (c) 2008 by Leonidas Fegaras, the University of Texas at Arlington. All rights reserved.
- This material is provided as is, with absolutely no warranty expressed or implied.
- Any use is at your own risk. Permission is hereby granted to use or copy this program
- for any purpose, provided the above notices are retained on all copies.
-
--------------------------------------------------------------------------------------}


module XML.HXQ.DB where

import Char(isSpace,toLower)
import Control.Monad.State
import Database.HDBC
import XML.HXQ.XTree
import XMLParse(XMLEvent(..),parseDocument)
import HXML(AttList)
import XML.HXQ.Parser


-- | Convert an HDBC value fetched from the database into an XML value.
-- Unsigned words and unbounded integers become 'XInt', booleans 'XBool',
-- NULL becomes the empty text, strings and characters are taken verbatim,
-- and every other value is rendered with 'show'.
-- NOTE(review): SqlInt32/SqlInt64 are rendered as text rather than XInt,
-- unlike SqlWord32/SqlWord64/SqlInteger; confirm this is intended.
sql2xml :: SqlValue -> XTree
sql2xml (SqlString str)    = XText str
sql2xml (SqlByteString bs) = XText (show bs)
sql2xml (SqlWord32 w)      = XInt (fromEnum w)
sql2xml (SqlWord64 w)      = XInt (fromEnum w)
sql2xml (SqlInt32 k)       = XText (show k)
sql2xml (SqlInt64 k)       = XText (show k)
sql2xml (SqlInteger k)     = XInt (fromEnum k)
sql2xml (SqlChar ch)       = XText [ch]
sql2xml (SqlBool flag)     = XBool flag
sql2xml (SqlDouble d)      = XText (show d)
sql2xml (SqlRational r)    = XText (show r)
sql2xml (SqlEpochTime t)   = XText (show t)
sql2xml (SqlTimeDiff t)    = XText (show t)
sql2xml SqlNull            = XText ""


-- | Convert an XML value into an HDBC parameter value.  Text becomes
-- SqlString, integers SqlInteger, floats their shown text, booleans SqlBool.
-- An element with exactly one child is converted through that child.
-- Anything else is rejected with an error.
xml2sql :: XTree -> SqlValue
xml2sql (XText str)              = SqlString str
xml2sql (XInt k)                 = SqlInteger (toInteger k)
xml2sql (XFloat d)               = SqlString (show d)
xml2sql (XBool flag)             = SqlBool flag
xml2sql (XElem _ _ _ _ [child])  = xml2sql child
xml2sql other                    = error ("Cannot convert "++show other++" into sql")


-- | Placeholder for the parent field of elements built from SQL results.
-- Such elements have no parent, so forcing this field raises an error.
perror :: a
perror = error "constructed elements have no parent"


-- | Execute a prepared statement with the given XML values as its parameters.
-- Every result row is returned as a "row" element.  Its children are one
-- element per result column, named after the column and holding its value.
executeSQL :: Statement -> XSeq -> IO XSeq
executeSQL stmt args
    = do _ <- handleSqlError (execute stmt (map xml2sql args))
         rows <- handleSqlError (fetchAllRowsAL stmt)
         return (map toRow rows)
    where toRow cols = XElem "row" [] 0 perror [ toColumn col | col <- cols ]
          toColumn (colName,value) = XElem colName [] 0 perror [sql2xml value]


-- | Prepare an SQL statement on the connection, wrapped in 'handleSqlError'.
prepareSQL :: (IConnection conn) => conn -> String -> IO Statement
prepareSQL db = handleSqlError . prepare db


{---------------------------------------------------------------------------------------
-- Extract the structural summary of an XML file: the tree of its distinct tag paths, annotated with occurrence statistics
----------------------------------------------------------------------------------------}


-- | Structural summary node: one node per distinct child tag of its parent.
-- Fields, in order:
--   tag      element name; attributes are stored as "@name"
--   id       unique node number, used to name tables and columns
--   max#     largest number of occurrences seen under one parent instance
--   count    occurrences under the parent instance currently being scanned;
--            reset to 0 when that parent closes, and set to -1 by fixSS to
--            mark mixed content
--   hasText  True if some occurrence contains non-whitespace text
--   children summary nodes of the child elements and attributes
data SSnode = SSnode String !Int !Int !Int !Bool   [SSnode]
            deriving (Eq,Show)


-- | Record one more occurrence of child 'tag' among the given sibling nodes.
-- If a sibling with that tag exists, its per-instance count is incremented.
-- It is returned together with the remaining siblings, in their original
-- order.  Otherwise a new node is created with the next id taken from the
-- state counter, and the sibling list is returned unchanged as the remainder.
-- Result: (node id, the node, siblings without the node).
insertSS :: String -> [SSnode] -> State Int (Int,SSnode,[SSnode])
insertSS tag nodes
    = case break sameTag nodes of
        (before, SSnode n i j l b ts : after)
            -> return (i, SSnode n i j (l+1) b ts, before ++ after)
        (_, [])
            -> do lastId <- get
                  let fresh = lastId + 1
                  put fresh
                  return (fresh, SSnode tag fresh 1 1 False [], nodes)
    where sameTag (SSnode n _ _ _ _ _) = n == tag


-- | Like 'insertSS', but put the (new or updated) node back at the front of
-- the sibling list instead of returning it separately.
insSS :: String -> [SSnode] -> State Int [SSnode]
insSS tag siblings = fmap (\(_, node, rest) -> node : rest) (insertSS tag siblings)


-- | Fold a stream of XML parse events into a structural summary.
-- The second argument is a stack of summary nodes for the currently open
-- elements, innermost first; its bottom entry is the summary root.
-- A node on the stack has been taken out of its parent's children list
-- (by 'insertSS') and is put back when its EndEvent is seen.
-- The Int state is the last node id handed out.
-- The EndEvent tag name is not checked against the open element.
getSS :: [XMLEvent] -> [SSnode] -> State Int [SSnode]
-- an empty element is a start tag immediately followed by its end tag
getSS ((EmptyEvent n atts):xs) rs
    = getSS ((StartEvent n atts):(EndEvent n):xs) rs
-- a start tag: find or create the child node for n under the current node,
-- record its attributes as "@name" children, and push it on the stack
getSS ((StartEvent n atts):xs) ((SSnode m i j l b ns):rs)
    = do (k,SSnode m' i' j' l' b' ks,ts) <- insertSS n ns
         as <- foldM (\r (a,_) -> insSS ('@':a) r) ks atts
         -- reset zeroes the per-instance counts of all children of the new
         -- node, including the attributes just recorded; so attributes keep
         -- max# = 1
         getSS xs (reset(SSnode m' i' j' l' b' as):(SSnode m i j l b ts):rs)
    where r (SSnode m i j _ b ts) = SSnode m i j 0 b ts
          reset (SSnode m i j l b ts) = SSnode m i j l b (map r ts)
-- an end tag: fold each child's per-instance count into its max# and reset
-- the count, then pop the node and put it back among its parent's children
getSS ((EndEvent n):xs) (t:(SSnode m i j l b ns):rs)
    = getSS xs ((SSnode m i j l b (set t:ns):rs))
    where s (SSnode m i j l b ts) = SSnode m i (max j l) 0 b ts
          set (SSnode m i j l b ts) = SSnode m i j l b (map s ts)
-- non-whitespace text marks the current node as having text
getSS ((TextEvent t):xs) ((SSnode m i j l False ns):rs)
    | any (not . isSpace) t
    = getSS xs ((SSnode m i j l True ns):rs)
-- every other event (including whitespace-only text) is ignored
getSS (_:xs) rs = getSS xs rs
getSS [] rs = return rs


{---------------------------------------------------------------------------------------
-- Derive a good relational schema based on the structural summary (using hybrid inlining)
----------------------------------------------------------------------------------------}


-- | A tag path, stored innermost tag first ('printPath' reverses it for
-- display, and 'schema' builds it by consing each descendant tag in front).
type Path = [Tag]


-- | Derived relational schema, serialized with Show/Read into HXQCatalog.
--   Table name path mixed entries: a relation.  'path' is the tag path it
--     stores, 'mixed' marks mixed content, and 'entries' holds its columns
--     and the tables nested in it (which get a foreign key to this one).
--   Column name path: a varchar column of the enclosing table, storing the
--     text found at 'path' relative to that table.
data Table = Table String Path Bool [Table]
           | Column String Path
           deriving (Show,Read)


-- | Render a path (stored innermost tag first) as "outer/.../inner".
printPath :: Path -> String
printPath path = drop 1 (concatMap ('/' :) (reverse path))


-- | Prepend a tag to a path, except the synthetic "root" tag of the summary.
pathCons :: String -> [String] -> [String]
pathCons p ps
    | p == "root" = ps
    | otherwise   = p : ps


-- | Map a structural summary node (after 'fixSS') to relational schema
-- entries using hybrid inlining.  Table and column names are 'prefix'
-- followed by the node id.  'path' holds the tags of the inlined ancestors
-- since the nearest enclosing table, innermost first ("root" is never added,
-- see 'pathCons').  The equations are tried in order:
--   1. mixed content (count = -1, set by 'fixSS'): a table flagged as mixed,
--      holding the node's remaining (attribute) children plus a "value" column;
--   2. a leaf that occurs at most once per parent (max# = 1), or any
--      attribute: a column of the enclosing table;
--   3. an inner node that occurs at most once per parent: no table of its
--      own; its children are inlined into the enclosing table with this tag
--      added to their paths;
--   4. any other node: a table of its own, with a "value" column when it has
--      text and all of its children are attributes.
-- The children's entries are reversed, presumably to restore document order
-- ('getSS' puts the most recently closed child first); confirm.
-- NOTE(review): equation 3 ignores the hasText flag, so the text of an
-- inlined element that also has attributes gets no column.
schema :: SSnode -> String -> [String] -> [Table]
schema (SSnode n i _ (-1) _ ts) prefix path
    = [ Table (prefix++show i) (pathCons n path) True
              ((reverse (concatMap (\t -> schema t prefix []) ts))
               ++[ Column "value" [] ]) ]
schema (SSnode n i j _ _ []) prefix path
    | j == 1 || head n == '@'
    = [ Column (prefix++show i) (pathCons n path) ]
schema (SSnode n i 1 _ _ ts) prefix path
    = concatMap (\t -> schema t prefix (pathCons n path)) ts
schema (SSnode n i _ _ b ts) prefix path
    = [ Table (prefix++show i) (pathCons n path) False
              ((reverse (concatMap (\t -> schema t prefix []) ts))
              ++(if b && all (\(SSnode x _ _ _ _ _)-> head x == '@') ts
                 then [ Column "value" [] ] else [])) ]


-- | Mark mixed-content nodes.  A node that has text and at least one
-- non-attribute child gets count -1 and keeps only its attribute children;
-- its whole content is later stored as a single value.  Every other node
-- is processed recursively.
fixSS :: SSnode -> SSnode
fixSS (SSnode n i j l hasText kids)
    | hasText && not (all isAttr kids) = SSnode n i j (-1) True (filter isAttr kids)
    | otherwise                        = SSnode n i j l hasText (map fixSS kids)
    where isAttr (SSnode tag _ _ _ _ _) = head tag == '@'


-- | Read an XML file and derive its relational schema with hybrid inlining.
-- The result is a Table named 'prefix' (with an empty path) whose entries are
-- the schema of the document element.  The document element always gets a
-- table of its own, because its max# is forced to 2 before calling 'schema'.
-- Fails with a descriptive error if the summary does not consist of exactly
-- one document element (for example, unbalanced tags).
deriveSchema :: String -> String -> IO Table
deriveSchema file prefix
    = do doc <- readFile file
         let events = parseDocument doc
             summary = evalState (getSS events [SSnode "root" 1 1 1 False []]) 1
             SSnode m i _ l b s
                 = case summary of
                     [SSnode _ _ _ _ _ [top]] -> fixSS top
                     _ -> error ("deriveSchema: cannot derive a structural summary of "++file
                                 ++" (expected a single, well-formed document element)")
         return (Table prefix [] False (reverse (schema (SSnode m i 2 l b s) prefix [])))


-- | SQL "create table" statements for a table and, recursively, for every
-- table nested in it, parents before children.  Each table gets:
--   * an <name>_id primary key;
--   * an <name>_parent column referencing the enclosing table (omitted when
--     'parent' is "");
--   * one varchar column per Column entry.
-- SQL comments record the stored tag paths.
relationalSchema :: Table -> String -> [String]
relationalSchema (Table n path b ts) parent
    = createTable : concatMap (\child -> relationalSchema child n) nested
    where nested = [ child | child@(Table _ _ _ _) <- ts ]
          createTable
              = concat [ "create table ", n, " (      /* ", printPath path
                       , if b then " (mixed content)" else ""
                       , " */\n"
                       , n, "_id int,\n"
                       , parentColumn
                       , concat [ m++" varchar,    /* "++printPath p++" */\n" | Column m p <- ts ]
                       , "primary key (", n, "_id))\n" ]
          parentColumn
              | null parent = ""
              | otherwise   = n++"_parent int references "++parent++"("++parent++"_id),\n"


-- | Names of a table and all tables nested in it, in pre-order
-- (every table appears before the tables nested in it).
getTableNames :: Table -> [String]
getTableNames (Column _ _) = []
getTableNames (Table n _ _ children) = n : concatMap getTableNames children


-- | Make sure the catalog table HXQCatalog exists, creating it if necessary.
-- Each row holds a schema name, the source file path, and the serialized
-- schema.
initializeDB :: (IConnection conn) => conn -> IO ()
initializeDB db
    = do tables <- getTables db
         -- Compare case-insensitively.  The table is created with an unquoted
         -- name, which some databases fold to lower case (PostgreSQL does).
         -- An exact match would then miss it, and the create statement below
         -- would fail on every later call.
         if any (\t -> map toLower t == "hxqcatalog") tables
            then return ()
            else do let s = "create table HXQCatalog ( name varchar primary key, path varchar, summary varchar )"
                    _ <- handleSqlError (run db s [])
                    commit db


-- | Derive the relational schema of an XML file and create it under 'name'.
-- Any schema already registered under 'name' is dropped first.  The new
-- schema is then recorded in HXQCatalog and its tables are created.
-- Returns the derived schema.
createSchema :: (IConnection conn) => conn -> String -> String -> IO Table
createSchema db file name
    = do initializeDB db
         stmt <- handleSqlError (prepare db "select summary from HXQCatalog where name = ?")
         _ <- handleSqlError (execute stmt  [SqlString name])
         result <- handleSqlError (fetchAllRowsAL stmt)
         if null result
            then return ()
            else do let [[(_,SqlString s)]] = result
                        summary = (read s)::Table
                        -- Drop nested tables before the tables they reference.
                        -- getTableNames lists parents before children, and
                        -- each child table has a foreign key to its parent.
                        tables = reverse (getTableNames summary)
                    mapM_ (\t -> handleSqlError (run db ("drop table if exists "++t) [])) tables
                    _ <- handleSqlError (run db "delete from HXQCatalog where name = ?" [SqlString name])
                    commit db
         t <- deriveSchema file name
         let ddl = relationalSchema t ""
         _ <- handleSqlError (run db "insert into HXQCatalog values (?,?,?)"
                                      [SqlString name, SqlString file, SqlString (show t)])
         mapM_ (\s -> handleSqlError (run db s [])) ddl
         commit db
         return t


-- | Look up the schema registered under 'name' in HXQCatalog, creating the
-- catalog first if needed.  Fails if no such schema exists.
findSchema :: (IConnection conn) => conn -> String -> IO Table
findSchema db name
    = do initializeDB db
         query <- handleSqlError (prepare db "select summary from HXQCatalog where name = ?")
         _ <- handleSqlError (execute query [SqlString name])
         rows <- handleSqlError (fetchAllRowsAL query)
         case rows of
           [_] -> let [[(_,SqlString summary)]] = rows
                  in return ((read summary)::Table)
           _ -> error ("Schema "++name++" doesn't exist")


{---------------------------------------------------------------------------------------
-- Populate the database from the XML file and its derived structural summary
----------------------------------------------------------------------------------------}


-- | Look up the entry whose path equals 'path' among the entries of one table.
-- A matching Column is returned with its column index.  The index counts
-- Column entries only and starts from the given number.  A matching nested
-- Table is returned with the index of its own last entry (for a
-- mixed-content table, its "value" column).
findPath :: [Table] -> [String] -> Int -> Maybe (Int,Table)
findPath entries path = go entries
    where go [] _ = Nothing
          go (tbl@(Table _ p _ kids) : rest) k
              | p == path = Just (length kids - 1, tbl)
              | otherwise = go rest k
          go (col@(Column _ p) : rest) k
              | p == path = Just (k, col)
              | otherwise = go rest (k + 1)


-- | Translate the XML events of a document into the flat stream of tuple
-- fragments consumed by 'insert':
--   (-1, table)  opens a new tuple of 'table',
--   (-2, table)  closes the innermost open tuple,
--   (k, text)    supplies column k of the innermost open tuple, where k
--                counts only the Column entries of that table ('findPath').
-- Arguments:
--   * the events;
--   * a stack of open schema entries, innermost first: a Table, or the
--     Column currently being filled;
--   * the column index that text is currently attributed to;
--   * a stack of tag paths, one per open table, holding the tags entered
--     since that table was opened (innermost first).
-- Elements of mixed-content tables are re-serialized into one string by
-- showXTree and stored in the table's last column.
-- NOTE(review): text is appended verbatim, without re-escaping.  If the
-- parser has already decoded entities, the stored markup may be ill-formed;
-- confirm.
-- NOTE(review): attributes are emitted (popAtts) only for elements that map
-- to a table; attributes of inlined elements are dropped.
populate :: [XMLEvent] -> [Table] -> Int -> [[String]] -> [(Int,String)]
-- an empty element is a start tag immediately followed by its end tag
populate ((EmptyEvent tag atts):xs) ts n ps
    = populate ((StartEvent tag atts):(EndEvent tag):xs) ts n ps
-- a start tag while a table is open: look up its path relative to that table
populate (x@(StartEvent tag atts):xs) ((t@(Table n path _ s)):ts) _ (p:ps)
    = case findPath s (tag:p) 0 of
        -- mixed-content table: open a tuple and store its attributes.  Then
        -- serialize the element's content (not its own tag), up to the
        -- matching end tag, into column n (the table's last column).
        Just (n,nt@(Table m _ True as))
            -> (-1,m):(popAtts atts as ++ showXTree xs 1 "")
               -- i is the nesting depth (1 = inside the mixed element itself),
               -- s the markup accumulated so far
               where showXTree ((EmptyEvent tag atts):xs) i s
                         = showXTree xs i (s++"<"++tag++showAL atts++"/>")
                     showXTree ((StartEvent tag atts):xs) i s
                         = showXTree xs (i+1) (s++"<"++tag++showAL atts++">")
                     showXTree ((EndEvent tag):xs) i s
                         = if i==1 then (n,s):(-2,m):(populate xs (t:ts) n (p:ps))
                           else showXTree xs (i-1) (s++"</"++tag++">")
                     showXTree ((TextEvent text):xs) i s = showXTree xs i (s++text)
                     showXTree (_:xs) i s = showXTree xs i s
        -- nested table: open a tuple, store its attributes, and push the
        -- table together with a fresh, empty tag path
        Just (n,nt@(Table m _ _ as))
            -> (-1,m):((popAtts atts as)++(populate xs (nt:t:ts) n ([]:p:ps)))
        -- column: push it so that the text that follows goes to column n
        Just (n,nt)
            -> populate xs (nt:t:ts) n ((tag:p):ps)
        -- no entry of its own (an inlined element): only extend the path.
        -- NOTE(review): text seen here is reported against column 0; confirm
        -- this cannot store data in the wrong column.
        Nothing -> populate xs (t:ts) 0 ((tag:p):ps)
      -- attribute a of a table element goes to the column whose path is ["@a"]
      where popAtts ((a,v):as) ks
                = let Just(m,_) = findPath ks ['@':a] 0
                  in (m,v):(popAtts as ks)
            popAtts [] _ = []
-- end of the element that opened the current table: close its tuple
populate ((EndEvent tag):xs) ((t@(Table n path _ s)):ts) _ ([]:ps)
    = (-2,n):populate xs ts 0 ps
-- end of a column element: pop the column and its tag
populate ((EndEvent tag):xs) ((Column m path):ts) n (p:ps)
    = populate xs ts 0 (tail p:ps)
-- end of an inlined element: drop its tag from the current path
populate ((EndEvent text):xs) ts _ (p:ps)
    = populate xs ts 0 (tail p:ps)
-- non-whitespace text is a value for the current column
populate ((TextEvent text):xs) ts n ps
    | any (not . isSpace) text
    = (n,text):populate xs ts n ps
-- every other event (including whitespace-only text) is ignored
populate (x:xs) ts n ps
    = populate xs ts n ps
populate [] ts n ps = []


-- | Insert the tuple stream produced by 'populate' into the database.
--
-- The stream encodes a tree of tuples:
--   * (-1,table) opens a tuple of 'table';
--   * (-2,_) closes the innermost open tuple;
--   * (k,value) with k >= 0 sets column k of the innermost open tuple.
-- Tuples are numbered in pre-order starting at 0.  Tuple 0 is the root and
-- is inserted without a parent column.  Every other tuple also stores the
-- id of its enclosing tuple as its parent.
-- 'stmts' maps each table name to (index of its last column, prepared
-- insert statement).  Columns that receive no value are inserted as NULL;
-- if a column receives several values, the first one wins.
-- A commit is issued after each tuple whose id is 99 modulo 100.
insert :: (IConnection conn) => conn -> [(Int,String)] -> [(String,Int,Statement)] -> IO ()
insert db xs stmts = let (action,_,_,_) = item xs 0 0 in action
    where
      -- item: one stream entry, either a whole nested tuple or a column value.
      -- Returns (insert actions, column values contributed to the enclosing
      -- tuple, remaining stream, next free tuple id).
      item ((-1,table):rest) i p
          = let (children,cols,rest',i') = contents rest (i+1) i
            -- insert the tuple before its children, so that the children's
            -- <table>_parent foreign keys already reference an existing row
            in (insertTuple table cols i p >> children,[],rest',i')
      item ((k,v):rest) i _ = (return (),[(k,v)],rest,i)
      -- contents: the entries of one tuple, up to its (-2) terminator
      contents [] i _ = (return (),[],[],i)
      contents ((-2,_):rest) i _ = (return (),[],rest,i)
      contents rest i p
          = let (s,cols,rest',i') = item rest i p
                (s',cols',rest'',i'') = contents rest' i' p
            in (s >> s',cols++cols',rest'',i'')
      insertTuple table cols i p
          = let (len,stmt) = foldr (\(a,l,s) r -> if table==a then (l,s) else r)
                                   (error ("insert: no prepared statement for table "++table))
                                   stmts
                tuple = map (\c -> maybe SqlNull SqlString (lookup c cols)) [0..len]
                ids = if i==0 then [SqlInteger i] else [SqlInteger i,SqlInteger p]
            in do _ <- handleSqlError (execute stmt (ids++tuple))
                  if mod i 100 == 99 then commit db else return ()


-- | Store an XML document into the database under the given name.
-- The relational schema is derived from the document, replacing any schema
-- already stored under the same name (see 'createSchema').  Then one tuple
-- is inserted per table-mapped element.  Tables and the catalog entry are
-- named after the lower-cased 'name'.
shred :: (IConnection conn) => conn -> String -> String -> IO ()
shred db file name
    = do t <- createSchema db file prefix
         stmts <- prepareInserts t
         doc <- readFile file
         let events = parseDocument doc
             tuples = (-1,prefix):(populate events [t] 0 [[]] ++ [(-2,prefix)])
         insert db tuples stmts
         commit db
    where prefix = map toLower name
          -- one prepared insert per table: its id, the parent id (except for
          -- the top-level table), and one placeholder per column
          prepareInserts (Table n _ _ ts)
              = do let len = length [ () | Column _ _ <- ts ] - 1
                       placeholders = (if n == prefix then "" else "?,") ++ "?"
                                      ++ concatMap (const ",?") [0..len]
                   stmt <- handleSqlError (prepare db ("insert into "++n++" values ("++placeholders++")"))
                   nested <- mapM prepareInserts ts
                   return ((n,len,stmt) : concat nested)
          prepareInserts _ = return []


-- | Create a secondary index on tagname for the shredded document under the given name.
-- Every column whose innermost tag equals 'tagname' is indexed (attributes
-- are written "@name").  Fails if no such column exists.
createIndex :: (IConnection conn) => conn -> String -> String -> IO ()
createIndex db name tagname
    = do -- 'shred' registers the schema under the lower-cased name
         let prefix = map toLower name
         table <- findSchema db prefix
         let indexes = getIndexes "" table
         if null indexes
            then error ("there is no tagname: "++tagname)
            else mapM_ makeIndex indexes
         commit db
    where makeIndex (t,c)
              = do stmt <- handleSqlError (prepare db ("create index "++t++"_"++c++" on "++t++" ("++c++")"))
                   handleSqlError (execute stmt [])
          -- (table, column) pairs for the columns whose innermost tag is tagname
          getIndexes _ (Table n _ _ ts) = concatMap (getIndexes n) ts
          -- columns with an empty path (a table's "value" column) never match
          getIndexes table (Column n (tag:_)) | tag == tagname = [(table,n)]
          getIndexes _ _ = []


{----------------------------------------------------------------------------------------------------
--  Convert XQuery to SQL
----------------------------------------------------------------------------------------------------}


-- | Wrap XQuery expressions in constructors for a tag path, outermost tag
-- first.  Tags starting with '@' become computed attribute constructors;
-- other tags become direct element constructors.  The innermost content is
-- the comma-separated sequence of the given expressions.  When there are
-- none (for example, a table without entries), it is the empty sequence "()".
publishES :: [String] -> [String] -> String
publishES (('@':attr):ps) xs
    = "attribute "++attr++" {"++publishES ps xs++"}"
publishES (p:ps) xs
    = "<"++p++">{"++publishES ps xs++"}</"++p++">"
publishES [] [] = "()"
publishES [] [x] = x
publishES [] (x:xs) = x++","++publishES [] xs


-- | XQuery expression that rebuilds the XML stored in a table, and
-- recursively in its nested tables, from SQL queries.  'parent' is the name
-- of the enclosing table, or "error" for the top-level table, which is
-- queried unconditionally.  Nested tables are joined through
-- <name>_parent eq <parent>_id.  A Column becomes the text of that column
-- of the parent row, wrapped in its tag path.
publishS :: Table -> String -> String
publishS (Table n path b ts) parent
    = "for $"++n++" in SQL(select(),from($"++n++"),"++condition++") return "
      ++publishES (reverse path) (map (\t -> publishS t n) ts)
    where condition
              | parent == "error" = "true()"
              | otherwise = "$"++n++"/"++n++"_parent eq $"++parent++"/"++parent++"_id"
publishS (Column n path) parent
    = publishES (reverse path) ["$"++parent++"/"++n++"/text()"]


-- | XQuery expression that rebuilds a whole shredded document from its
-- tables, wrapped in a root element.
publishTable :: Table -> String
publishTable table = concat ["<root>{", publishS table "error", "}</root>"]


-- | XQuery comparison operators (general and value comparisons), each
-- paired with the SQL operator it is translated to.
sqlComparisson :: [(String,String)]
sqlComparisson =
    [ ("=","="),   ("eq","=")
    , ("<=","<="), (">=",">=")
    , ("!=","!="), (">",">")
    , ("<","<"),   ("ne","!=")
    , ("gt",">"),  ("lt","<")
    , ("ge",">="), ("le","<=") ]

-- | XQuery boolean connectives, each paired with its SQL counterpart.
sqlBoolean :: [(String,String)]
sqlBoolean = [ (op, op) | op <- ["and","or"] ]


-- Is this an SQL predicate?
-- True when the predicate can be evaluated by SQL over the given table
-- variables.  Accepted forms:
--   * existence tests on a child of a table variable;
--   * comparisons whose operands are string or integer literals, or children
--     of table variables;
--   * and/or combinations of such predicates.
-- Single-item element constructors and text() calls are looked through.
sqlPredicate :: [String] -> Ast -> Bool
sqlPredicate tables = isPred
    where isPred (Ast "child_step" [Astring _,Avar v]) = v `elem` tables
          isPred (Ast "construction" [_,_,Ast "append" [x]]) = isPred x
          isPred (Ast "call" [Avar "text",x]) = isPred x
          isPred (Ast "call" [Avar op,x,y])
              | op `isKeyOf` sqlComparisson = isOperand x && isOperand y
              | op `isKeyOf` sqlBoolean = isPred x && isPred y
          isPred _ = False
          isOperand (Astring _) = True
          isOperand (Aint _) = True
          isOperand (Ast "child_step" [Astring _,Avar v]) = v `elem` tables
          isOperand (Ast "construction" [_,_,Ast "append" [x]]) = isOperand x
          isOperand (Ast "call" [Avar "text",x]) = isOperand x
          isOperand _ = False
          isKeyOf k assoc = any ((== k) . fst) assoc


-- Convert a predicate AST to an SQL predicate that uses the tables.
-- Returns the SQL condition text and the ASTs to bind, in order, to its "?"
-- placeholders.  An existence test on a child of a table variable is always
-- true for a row.  It yields the empty condition "", which means "no
-- restriction".
predToSQL :: [String] -> Ast -> (String,[Ast])
predToSQL tables e
    = case e of
        Ast "child_step" [Astring c,Avar v]
            -> if elem v tables
               then ("",[])
               else error ("Cannot convert to an SQL predicate: "++show e)
        Ast "construction" [_,_,Ast "append" [x]]
            -> predToSQL tables x
        Ast "call" [Avar "text",x]
            -> predToSQL tables x
        Ast "call" [Avar cmp,x,y]
            | any (\(f,_) -> f==cmp) sqlComparisson
            -> let (nx,vx) = expToSQL x
                   (ny,vy) = expToSQL y
               in (nx ++ " " ++ sqlOperator cmp sqlComparisson ++ " " ++ ny,vx++vy)
        Ast "call" [Avar cmp,x,y]
            | any (\(f,_) -> f==cmp) sqlBoolean
            -> let (nx,vx) = predToSQL tables x
                   (ny,vy) = predToSQL tables y
                   op = sqlOperator cmp sqlBoolean
               in if null nx || null ny
                  then (if op == "or"
                        then ("",[])                                -- true or p  = true
                        else if null nx then (ny,vy) else (nx,vx))  -- true and p = p
                  -- Parenthesize the operands.  SQL's "and" binds tighter
                  -- than "or", so "(a or b) and c" must not be emitted as
                  -- "a or b and c".
                  else ("(" ++ nx ++ ") " ++ op ++ " (" ++ ny ++ ")",vx++vy)
        _ -> error ("Cannot convert to an SQL predicate: "++show e)
      where -- an operand: a literal, a column of a table variable, or else a
            -- "?" placeholder bound to the expression itself
            expToSQL x
                = case x of
                    Astring s -> ("'"++concatMap escapeQuote s++"'",[])
                    Aint n -> (show n,[])
                    Ast "child_step" [Astring c,Avar v]
                        -> if elem v tables
                           then (v++"."++c,[])
                           else ("?",[x])
                    Ast "construction" [_,_,Ast "append" [y]]
                        -> expToSQL y
                    Ast "call" [Avar "text",y]
                        -> expToSQL y
                    _ -> ("?",[x])
            -- double embedded quotes, so that a literal like O'Brien stays
            -- one well-formed SQL string literal
            escapeQuote '\'' = "''"
            escapeQuote ch = [ch]
            sqlOperator f ops = case lookup f ops of
                                  Just sqlOp -> sqlOp
                                  Nothing -> error ("Unknown SQL operator: "++f)


-- Convert an AST to an SQL query
-- 'tables' and 'cols' are the table and column variables (only Avar entries
-- are used).  'pred' is the where-predicate; the call true() means none.
-- Returns the query text and the ASTs to bind to its "?" placeholders.
-- The where clause is omitted when the predicate reduces to no condition.
makeSQL :: [Ast] -> Ast -> [Ast] -> (String,[Ast])
makeSQL tables pred cols
    = (selectPart ++ wherePart,args)
    where tnames = [ x | Avar x <- tables ]
          cs = commaSep [ x | Avar x <- cols ]
          selectPart = "select " ++ (if null cs then "*" else cs) ++ " from " ++ commaSep tnames
          (condition,args)
              | pred == Ast "call" [Avar "true"] = ("",[])
              | otherwise = predToSQL tnames pred
          -- An empty condition (for example, a bare existence test) means no
          -- restriction; "where " with nothing after it would be invalid SQL.
          wherePart = if null condition then "" else " where " ++ condition
          commaSep [] = ""
          commaSep [x] = x
          commaSep (x:xs) = x ++ ", " ++ commaSep xs