{-------------------------------------------------------------------------------------
-
- The XQuery polymorphic type inference system
- Programmer: Leonidas Fegaras
- Based on P. Wadler's work on XQuery type checking
-   and on "Typing Haskell in Haskell" by Mark P. Jones
- Email: fegaras@cse.uta.edu
- Web: http://lambda.uta.edu/
- Creation: 08/10/09, last update: 09/29/09
- 
- Copyright (c) 2009 by Leonidas Fegaras, the University of Texas at Arlington. All rights reserved.
- This material is provided as is, with absolutely no warranty expressed or implied.
- Any use is at your own risk. Permission is hereby granted to use or copy this program
- for any purpose, provided the above notices are retained on all copies.
-
--------------------------------------------------------------------------------------}


module Text.XML.HXQ.TypeInference where

import Control.Monad
import Text.XML.HXQ.Parser
import Text.XML.HXQ.XTree
import Text.XML.HXQ.Functions
import Text.XML.HXQ.Types
import Debug.Trace
import Data.List

-- A substitution: an association list that binds type variables to the
-- types that replace them (see apply, which looks variables up in it).
type Subst = [(TVar,Type)]


-- Apply a substitution to a type.  Every type variable bound in the
-- substitution is replaced by its binding; unbound variables and all leaf
-- types are returned unchanged.  The binding is inserted as-is (it is not
-- itself rewritten by the substitution).
apply :: Subst -> Type -> Type
apply subst ty
    = case ty of
        TVariable v -> maybe ty id (lookup v subst)
        TElement n body -> TElement n (go body)
        TAttribute n body -> TAttribute n (go body)
        TSequence l r -> TSequence (go l) (go r)
        TInterleaving l r -> TInterleaving (go l) (go r)
        TChoice l r -> TChoice (go l) (go r)
        TQualified body c -> TQualified (go body) c
        _ -> ty
      where go = apply subst


-- Collect the type variables that occur in a type, without duplicates
-- (the results of the two sides of a binary type are merged with union).
tv :: Type -> [TVar]
tv ty
    = case ty of
        TVariable v -> [v]
        TElement _ body -> tv body
        TAttribute _ body -> tv body
        TSequence l r -> tv l `union` tv r
        TInterleaving l r -> tv l `union` tv r
        TChoice l r -> tv l `union` tv r
        TQualified body _ -> tv body
        _ -> []


-- Compose two substitutions: the bindings of the second one are rewritten
-- by the first one and placed in front of the first one's own bindings.
compose :: Subst -> Subst -> Subst
compose outer inner = map rewrite inner ++ outer
    where rewrite (v,t) = (v,apply outer t)


-- Merge two substitutions by concatenation.  The merge succeeds only when
-- both substitutions agree on every variable that they both bind.
merge :: Subst -> Subst -> Maybe Subst
merge s1 s2
    | all agree shared = Just (s1++s2)
    | otherwise = Nothing
    where shared = intersect (map fst s1) (map fst s2)
          agree v = apply s1 (TVariable v) == apply s2 (TVariable v)


-- Bind a type variable to a type.  Binding a variable to itself yields the
-- empty substitution; binding it to a type that mentions the variable
-- (occurs check) fails.
varBind :: TVar -> Type -> Maybe Subst
varBind v t
    = case t of
        TVariable w | v == w -> Just []
        _ | v `elem` tv t -> Nothing
          | otherwise -> Just [(v,t)]


-- Build the (simplified) sequence type of a list of types: the empty list
-- gives TEmpty, a singleton gives its element, and longer lists are nested
-- to the right with TSequence.
makeSequence :: [Type] -> Type
makeSequence ts
    = simplifyType (chain ts)
      where chain [] = TEmpty
            chain items = foldr1 TSequence items


-- Simplify a type by removing redundant structure:
--   * TEmpty is dropped from sequences and interleavings,
--   * a qualified TEmpty is TEmpty,
--   * nested qualifiers are collapsed into one,
--   * a choice with TEmpty becomes an optional ('?') type,
--   * a choice between two equal types is that type.
-- The rules are applied top-down; after the components of a sequence,
-- interleaving, choice or qualified type have been simplified, the rebuilt
-- node is not examined again.
simplifyType :: Type -> Type
simplifyType t
    = st t
      where st t
                = case t of
                    TSequence TEmpty t -> st t
                    TSequence t TEmpty -> st t
                    TInterleaving TEmpty t -> st t
                    TInterleaving t TEmpty -> st t
                    TQualified TEmpty _ -> TEmpty
                    TQualified (TQualified t q1) q2
                        -> st (TQualified t (nested q1 q2))
                    TChoice TEmpty t -> st (TQualified t '?')
                    TChoice t TEmpty -> st (TQualified t '?')
                    TChoice t1 t2 | t1 == t2 -> st t1
                    TSequence t1 t2 -> TSequence (st t1) (st t2)
                    TInterleaving t1 t2 -> TInterleaving (st t1) (st t2)
                    TChoice t1 t2 -> TChoice (st t1) (st t2)
                    TQualified t q -> TQualified (st t) q
                    _ -> t
            -- Qualifier of a doubly qualified type.  Mixing '+' with '?'
            -- must give '*', because both (t?)+ and (t+)? admit the empty
            -- sequence; this agrees with dotQual.  Any other combination
            -- without '*' or '+' is treated as '?'.
            nested q1 q2
                | q1 == '*' || q2 == '*' = '*'
                | q1 == '+' || q2 == '+' = if q1 == '?' || q2 == '?' then '*' else '+'
                | otherwise = '?'


-- Compute the type that results from applying an XPath step to a type.
--   t     the type of the step's input
--   step  the name of the axis/step (the strings matched below)
--   tag   the node-name test; "*" matches any name
--   ns    the namespace environment, only used when the input is tItem
-- The step is distributed over sequences, interleavings, choices and
-- qualified types; anything not handled below yields TEmpty.
xpathStep :: Type -> String -> String -> NS -> Type
xpathStep t step tag ns
    = xps t
      where xps t@(TElement n t')
                = case step of
                    -- the element itself, if its name passes the test
                    "self" -> if n==tag || tag=="*" then t else TEmpty
                    -- children: a self step over the element's content
                    "child" -> xpathStep t' "self" tag ns
                    "descendant" -> xpathStep t' "descendant-or-self" tag ns
                    "descendant-or-self"
                        -> if n==tag || tag=="*"
                           then makeSequence [t,xpathStep t' "descendant-or-self" tag ns]
                           else xpathStep t' "descendant-or-self" tag ns
                    -- NOTE(review): this restarts on the content with the empty
                    -- step name, for which the clauses here give tNode for an
                    -- element and TEmpty for an attribute -- confirm this is
                    -- the intended behavior.
                    "attribute-self" -> xpathStep t' "" tag ns
                    "attribute-descendant" -> xps t'
                    -- any other step on an element is typed as tNode
                    _ -> tNode
            xps t@(TAttribute n t')
                = case step of
                    "attribute-self" -> if n==tag || tag=="*" then t else TEmpty
                    "attribute-descendant"
                        -> if n==tag || tag=="*" then makeSequence [t,xps t'] else xps t'
                    _ -> TEmpty
            xps (TSequence t1 t2)
                = makeSequence [xps t1,xps t2]
            xps (TInterleaving t1 t2)
                = simplifyType $ TInterleaving (xps t1) (xps t2)
            xps (TChoice t1 t2)
                = simplifyType $ TChoice (xps t1) (xps t2)
            xps (TQualified t c)
                = simplifyType $ TQualified (xps t) c
            -- Input equal to tItem: the tag is expanded as a QName in the
            -- default element namespace; a non-element expansion is wrapped
            -- in an element named after the tag.  This happens for every
            -- step name.
            xps t
                | t == tItem
                = case expandQName (QName "" (defaultElementNS ns) tag) ns of
                    t'@(TElement _ _) -> t'
                    t' ->  TElement tag t'
            xps t = TEmpty


-- Separate the attribute types from a content type.  Returns the content
-- type with every attribute replaced by TEmpty, paired with the removed
-- attribute types in left-to-right order.  Attributes are searched through
-- sequences, choices, interleavings and qualified types only.
collectAttributes :: Type -> (Type,[Type])
collectAttributes ty
    = case ty of
        TSequence l r -> split TSequence l r
        TChoice l r -> split TChoice l r
        TInterleaving l r -> split TInterleaving l r
        TQualified body c
            -> let (body',attrs) = collectAttributes body
               in (TQualified body' c,attrs)
        TAttribute _ _ -> (TEmpty,[ty])
        _ -> (ty,[])
      where split rebuild l r
                = let (l',attrsL) = collectAttributes l
                      (r',attrsR) = collectAttributes r
                  in (rebuild l' r',attrsL++attrsR)


-- Expand a type into a list of alternatives, each alternative being a flat
-- list of types: sequences concatenate alternatives, choices merge them,
-- interleavings produce both orders, optional ('?') types add the empty
-- alternative, and element content is expanded recursively.
-- NOTE(review): TEmpty yields no alternatives at all (rather than the
-- single empty alternative that '?' adds) -- confirm this is intended.
normalizeType :: Type -> [[Type]]
normalizeType ty
    = case ty of
        TSequence l r -> cross (normalizeType l) (normalizeType r)
        TChoice l r -> normalizeType l `union` normalizeType r
        TInterleaving l r
            -> let alts1 = normalizeType l
                   alts2 = normalizeType r
               in cross alts1 alts2 `union` [ b++a | a <- alts1, b <- alts2 ]
        TElement n body
            -> map (\alt -> [TElement n (makeSequence alt)]) (normalizeType body)
        TQualified body '?' -> normalizeType body `union` [[]]
        TEmpty -> []
        _ -> [[ty]]
      where cross xs ys = [ x++y | x <- xs, y <- ys ]


-- Forget the structure of a type: collect the distinct item types found
-- under sequences, interleavings, choices and qualifiers, and return them
-- as a single choice (TEmpty when there are none).
strip :: Type -> Type
strip ty
    = case items ty of
        [] -> TEmpty
        [single] -> single
        several -> foldr1 TChoice several
      where items TEmpty = []
            items (TSequence l r) = items l `union` items r
            items (TInterleaving l r) = items l `union` items r
            items (TChoice l r) = items l `union` items r
            items (TQualified body _) = items body
            items other = [other]


-- Multiply two occurrence qualifiers.  The qualifiers handled are
-- '0' (empty), '-' (exactly one), '?', '+' and '*'.
dotQual :: Char -> Char -> Char
dotQual q1 q2
    | q2 == '0' = '0'
    | otherwise
    = case q1 of
        '0' -> '0'
        '-' -> q2
        '?' | q2 `elem` "-?" -> '?'
            | otherwise -> '*'
        '+' | q2 `elem` "-+" -> '+'
            | otherwise -> '*'
        '*' -> '*'


-- Combine a type with an occurrence qualifier: the type's own qualifier is
-- multiplied with q (dotQual) and the result is attached to the stripped
-- type.  '0' gives TEmpty and '-' gives the stripped type unqualified.
makeDot :: Type -> Char -> Type
makeDot t q
    = attach (dotQual (qualifier t) q)
      where attach '0' = TEmpty
            attach '-' = strip t
            attach c = TQualified (strip t) c


-- Ordering on occurrence qualifiers: leqQual q1 q2 holds when q2 is one of
-- the qualifiers listed below as admissible for q1.
leqQual :: Char -> Char -> Bool
leqQual q1 q2
    = q2 `elem` admissible
      where admissible
                = case q1 of
                    '0' -> "0?*"
                    '-' -> "-?+*"
                    '?' -> "?*"
                    '+' -> "+*"
                    '*' -> "*"


-- Compute the occurrence qualifier of a type: '0' for TEmpty, '-' for a
-- single item, and otherwise the combination of the qualifiers of the
-- components.  Interleavings are treated exactly like sequences.
qualifier :: Type -> Char
qualifier ty
    = case ty of
        TEmpty -> '0'
        TSequence l r -> inSequence (qualifier l) (qualifier r)
        TInterleaving l r -> inSequence (qualifier l) (qualifier r)
        TChoice l r -> inChoice (qualifier l) (qualifier r)
        TQualified body c -> dotQual (qualifier body) c
        _ -> '-'
      where -- qualifier of one type followed by another
            inSequence a b
                | b == '0' = a
                | otherwise
                = case a of
                    '0' -> b
                    '-' -> '+'
                    '?' -> atLeastOne b
                    '+' -> '+'
                    '*' -> atLeastOne b
            -- an optional/starred prefix followed by b
            atLeastOne b = if elem b "-+" then '+' else '*'
            -- qualifier of an alternative between two types
            inChoice a b
                = case a of
                    '0' | b == '-' -> '?'
                        | b == '+' -> '*'
                        | otherwise -> b
                    '-' | b == '0' -> '?'
                        | otherwise -> b
                    '?' | elem b "+*" -> '*'
                        | otherwise -> '?'
                    '+' | elem b "-+" -> '+'
                        | otherwise -> '*'
                    '*' -> '*'


-- Compatibility test between two base type names n and m.  It holds when
-- the names are equal, when either one is "string", when one is "numeric"
-- and the other a concrete numeric type, or when the type that
-- buildInTypes associates with n is in turn compatible with m.
subtype :: String -> String -> Bool
subtype n m
    = or [ n == m
         , m == "string"
         , n == "string"
         , m == "numeric" && n `elem` numericTypes
         , n == "numeric" && m `elem` numericTypes
         , n /= "anyAtomicType" && n /= "numeric" && isBuildInType n
             && subtype (findA n buildInTypes) m ]
      where numericTypes = ["integer", "decimal", "float", "double"]


-- Unify two lists of types position by position and compose the resulting
-- substitutions.  Fails when the lists have different lengths or when any
-- pair fails to unify.
mguL :: [Type] -> [Type] -> Maybe Subst
mguL ts1 ts2
    = case (ts1,ts2) of
        ([],[]) -> Just []
        (a:as,b:bs) -> do s1 <- mgu a b
                          s2 <- mguL as bs
                          return (compose s1 s2)
        _ -> Nothing


-- most general unifier with subtyping t1<=t2
-- Returns the substitution that makes t1 acceptable where t2 is required,
-- or Nothing when there is none.  The clauses are tried in order, so their
-- order is significant.
mgu :: Type -> Type -> Maybe Subst
-- a type variable on either side is simply bound to the other type
mgu (TVariable v) t = varBind v t
mgu t (TVariable v) = varBind v t
-- base types: both names must be in the XML Schema namespace and their
-- local names must be related by subtype
mgu (TBase n) (TBase m)
    = if uri n==xsNamespace && uri m==xsNamespace
         && (subtype (localName n) (localName m))
      then Just []
      else Nothing
-- an element whose content is a base type is accepted for a base type
mgu (TElement n1 t1@(TBase n)) t2@(TBase m)
    = mgu t1 t2
-- TAny on either side and item on the right unify with anything
mgu _ TAny = Just []
mgu TAny _ = Just []
mgu _ (TItem "item") = Just []
mgu (TItem n) (TItem "node") = Just []
-- elements/attributes: the names must agree (or the required name is "*")
mgu (TElement n1 t1) (TElement n2 t2)
    | n1 == n2 || n2 == "*"
    = mgu t1 t2
mgu (TAttribute n1 t1) (TAttribute n2 t2)
    | n1 == n2 || n2 == "*"
    = mgu t1 t2
-- qualified types: the qualifiers must be ordered by leqQual
mgu (TQualified t1 c1) (TQualified t2 c2)
    = if leqQual c1 c2
      then mgu t1 t2
      else Nothing
-- a sequence against a starred type: both parts must fit the starred type
mgu (TSequence t1 t2) t3@(TQualified _ '*')
    = liftM2 compose (mgu t1 t3) (mgu t2 t3)
-- any other type against a qualified type: the qualifier is ignored
mgu t1 (TQualified t2 c)
    = mgu t1 t2
-- Structural case: compare the normalized alternatives.  The guard requires
-- more than two alternatives in total, which skips the case where both
-- sides normalize to a single alternative.  The result is that of the first
-- alternative p1 of t1 that unifies with every alternative p2 of t2.
-- NOTE(review): "some p1 against all p2" looks unusual for t1<=t2 --
-- confirm the intended quantification.
mgu t1 t2
    | length nt1 + length nt2 > 2
    = msum [ foldr (liftM2 compose) (Just [])
                   [ mguL p1 p2 | p2 <- nt2 ]
           | p1 <- nt1 ]
    where nt1 = normalizeType t1
          nt2 = normalizeType t2
mgu t1 t2 = Nothing


-- The type inference monad: a state monad that threads the current
-- substitution and the counter used to generate fresh type variables.
-- A computation maps (substitution, counter) to the updated pair together
-- with its result.  The state is threaded lazily: the let patterns below
-- are only forced when one of their components is demanded.
newtype TI a = TI (Subst -> Int -> (Subst,Int,a))

-- The instances are written in canonical form (fmap and pure defined
-- directly, return left to its default of pure) so that they do not
-- trigger GHC's -Wnoncanonical-monad-instances.
instance Functor TI where
    fmap f (TI c) = TI(\s n -> let (s',m,x) = c s n in (s',m,f x))

instance Applicative TI where
    pure x = TI(\s n -> (s,n,x))
    (<*>) = ap

instance Monad TI where
    -- run c, then run the continuation on the state that c produced
    TI c >>= f = TI(\s n -> let (s',m,x) = c s n; TI fx = f x in fx s' m)


-- Run a type inference computation from the empty substitution and
-- variable counter 0, returning only its result.
runTI :: TI a -> a
runTI (TI c)
    = case c [] 0 of
        (_,_,result) -> result


-- Return the current substitution, leaving the state unchanged.
getSubst :: TI Subst
getSubst = TI current
    where current s n = (s,n,s)


-- Extend the current substitution by composing the given one onto it.
extSubst :: Subst -> TI()
extSubst s' = TI extend
    where extend s n = (compose s' s,n,())


-- Unify t1 with t2 (t1<=t2) under the current substitution.  On success the
-- substitution is extended with the unifier; on failure the fallback
-- computation c is run instead.  The expression e is not used here.
unifyT :: Ast -> Type -> Type -> TI() -> TI()
unifyT e t1 t2 c
    = getSubst >>= \s -> maybe c extSubst (mgu (apply s t1) (apply s t2))


-- Unify t1 with t2 (t1<=t2), aborting with an error that mentions the
-- expression e when they are incompatible.  t1 is the type that was found
-- for the expression and t2 the type it is required to fit, so the message
-- reports t2 as expected and t1 as found.
unify :: Ast -> Type -> Type -> TI()
unify e t1 t2
    = unifyT e t1 t2 (error $ "Incompatible type in "++show e++" (expected: "
                              ++show t2++", found: "++show t1++")")


-- Unify two lists of types position by position (ts1<=ts2) under the
-- current substitution, extending it on success and aborting with an error
-- that mentions the expression e on failure.  ts1 are the types that were
-- found and ts2 the types they are required to fit, so the message reports
-- ts2 as expected and ts1 as found.
unifyL :: Ast -> [Type] -> [Type] -> TI()
unifyL e ts1 ts2
    = do s <- getSubst
         case mguL (map (apply s) ts1) (map (apply s) ts2) of
           Just s' -> extSubst s'
           Nothing -> error $ "Incompatible types in "++show e++" (expected: "
                              ++show ts2++", found: "++show ts1++")"


-- Generate a fresh type variable from the counter and advance the counter.
newTVar :: TI Type
newTVar = TI fresh
    where fresh s n = (s,n+1,TVariable n)


-- type assumptions for the XQuery variables:
-- an association list from variable name to the type assumed for it
type Assumptions = [(String,Type)]


-- function signatures:
-- an association list from function name to (parameter types, result type)
type Signatures = [(String,([Type],Type))]


-- The signatures of the system functions.  In each systemFunctions entry
-- the third component lists the result type followed by the parameter
-- types.  The namespace argument is not used.
functionSignatures :: NS -> Signatures
functionSignatures ns
    = map signature systemFunctions
      where signature (fn,len,otp:tps,_,_) = (fn,(tps,otp))


-- XQuery Type Inference
-- Infer the type of the expression e.
--   assumptions  the types of the XQuery variables in scope
--   context      the type of the context item (".")
--   fncs         the signatures of the user-defined functions
--   ns           the namespace environment
-- Type errors, unknown functions and unrecognized expressions abort with
-- error.  The order of the alternatives matters: the specific "call"
-- patterns must come before the generic one.
typeInf :: Ast -> Assumptions -> Type -> Signatures -> NS -> TI Type
typeInf e assumptions context fncs ns
  = do t <- ti e assumptions context
       s <- getSubst
       --trace (show e++" "++show assumptions++" "++show s++" -> "++show t) (return t)
       return t
  where
    ti e assumptions context =
     case e of
      -- variables and literals
      Avar "." -> return context
      Avar v -> return (findV v assumptions)
      Ast "global" [Avar v] -> return (findV v assumptions)
      Aint n -> return tInt
      Afloat n -> return tFloat
      Astring s -> return tString
      Ast "nonIO" [u] -> typeInf u assumptions context fncs ns
      -- the body is typed with the type of v as its context
      Ast "context" [v,Astring dp,body]
          -> do vt <- typeInf v assumptions context fncs ns
                typeInf body assumptions vt fncs ns
      -- calls that are typed specially
      Ast "call" [Avar "position"] -> return tInt
      Ast "call" [Avar "last"] -> return tInt
      Ast "call" [Avar f,Astring file]
          | elem f ["doc","fn:doc"]
          -> return (TItem "item")
      Ast "call" [Avar "debug",c]
          -> typeInf c assumptions context fncs ns
      -- eval takes a string; its result type is a fresh type variable
      Ast "call" [Avar "eval",x]
          -> do tx <- typeInf x assumptions context fncs ns
                () <- unify x tx tString
                newTVar
      -- an XPath step: each predicate is typed with the step result as its
      -- context and must be boolean or, failing that, numeric
      Ast "step" (Avar step:Astring tag:e:preds)
          -> do te <- typeInf e assumptions context fncs ns
                let te' = xpathStep te step tag ns
                _ <- mapM (\p -> do tp <- typeInf p assumptions te' fncs ns
                                    unifyT p tp tBool (unify p tp tNumeric)) preds
                return te'
      Ast "filter" (e:preds)
          -> do te <- typeInf e assumptions context fncs ns
                _ <- mapM (\p -> do tp <- typeInf p assumptions te fncs ns
                                    unifyT p tp tBool (unify p tp tNumeric)) preds
                return te
      -- the condition is typed with the type of the body as its context
      Ast "predicate" [condition,body]
          -> do tb <- typeInf body assumptions context fncs ns
                tc <- typeInf condition assumptions tb fncs ns
                () <- unify condition tc tBool
                return tb
      Ast "append" args
          -> do as <- mapM (\x -> typeInf x assumptions context fncs ns) args
                return $ makeSequence as
      -- the condition must be boolean and the then-branch must unify with
      -- the else-branch; the type of the then-branch is returned
      Ast "if" [c,x,y]
          -> do tc <- typeInf c assumptions context fncs ns
                () <- unify c tc tBool
                tt <- typeInf x assumptions context fncs ns
                te <- typeInf y assumptions context fncs ns
                () <- unify e tt te
                return tt
      Ast "validate" [e]
          -> typeInf e assumptions context fncs ns
      -- updates: the operands are typed, the result is the empty type
      Ast "insert" [e1,e2]
          -> do t1 <- typeInf e1 assumptions context fncs ns
                t2 <- typeInf e2 assumptions context fncs ns
                return TEmpty
      Ast "delete" [e]
          -> do te <- typeInf e assumptions context fncs ns
                return TEmpty
      Ast "replace" [e1,e2]
          -> do t1 <- typeInf e1 assumptions context fncs ns
                t2 <- typeInf e2 assumptions context fncs ns
                return TEmpty
      Ast "call" ((Avar "concatenate"):args)
          -> do ts <- mapM (\a -> typeInf a assumptions context fncs ns) args
                return $ makeSequence ts
      Ast "call" ((Avar "concat"):args)
          -> do ts <- mapM (\a -> typeInf a assumptions context fncs ns) args
                return $ tString
      -- Generic function call.  The system functions are searched first by
      -- local name.  If there is none, a one-argument call whose name is a
      -- build-in type in the XML Schema namespace has that base type;
      -- otherwise the user-defined signatures in fncs are searched.
      -- For system functions, the candidates are narrowed down by arity and
      -- then by whether the argument types unify with the parameter types;
      -- exactly one candidate must remain.
      Ast "call" (v@(Avar fname):args)
          -> do ts <- mapM (\a -> typeInf a assumptions context fncs ns) args
                let fn = functionTag fname ns
                    t = tag fname ns
                case filter (\(n,_) -> n == localName fn) (functionSignatures ns) of
                  [] -> if uri t == xsNamespace && isBuildInType (localName t) && length ts == 1
                        then return $ TBase t
                        else case filter (\(n,_) -> n == localName fn) fncs of
                               (_,(params,out)):_
                                   -> if (length params) == (length args)
                                      then do () <- unifyL e ts params
                                              return out
                                      else error ("Wrong number of arguments in function call: "++fname)
                               _ -> error ("Undefined function: "++fname)
                  fs -> case filter (\(_,(pts,_)) -> length args == length pts) fs of
                          [] -> error ("wrong number of arguments in function call: " ++ fname)
                          fs' -> case filter (\(_,(pts,ot)) -> mguL ts pts /= Nothing) fs' of
                                   [(_,(pts,ot))] -> do () <- unifyL e ts pts
                                                        return ot
                                   _ -> error ("Incompatible arguments in function call: "++fname++show ts
                                               ++"\n(expected "++concatMap (\(_,(pts,_)) -> show pts++" ") fs'++")")
      -- Element construction: the tag must be a string.  The attributes
      -- found in the body are moved to the front of the content, after the
      -- explicitly listed attributes.  A tag or attribute name that is not
      -- a string literal is recorded as "*".
      Ast "construction" [tag,id,parent,Ast "attributes" al,body]
             -> do tt <- typeInf tag assumptions context fncs ns
                   () <- unify e tt tString
                   tb <- typeInf body assumptions context fncs ns
                   let (tb',ats) = collectAttributes tb
                   alc <- mapM (\(Ast "pair" [a,v])
                                    -> do ta <- typeInf a assumptions context fncs ns
                                          () <- unify a ta tString
                                          tv <- typeInf v assumptions context fncs ns
                                          return $ TAttribute (case a of Astring n -> n; _ -> "*") tv) al
                   return $ TElement (case tag of Astring n -> n; _ -> "*")
                                     (makeSequence (alc++ats++[tb']))
      Ast "attribute_construction" [name,value]
          -> do tn <- typeInf name assumptions context fncs ns
                () <- unify e tn tString
                tv <- typeInf value assumptions context fncs ns
                return $ TAttribute (case name of Astring n -> n; _ -> "*") tv
      -- the let variable is bound to the type of its source
      Ast "let" [Avar var,source,body]
          -> do ts <- typeInf source assumptions context fncs ns
                typeInf body ((var,ts):assumptions) context fncs ns
      -- The for variable is bound to the stripped (item) type of the
      -- source; the body type is combined with the qualifier of the source.
      Ast "for" [Avar var,Avar "$",source,body]      -- a for-loop without an index
          -> do ts <- typeInf source assumptions context fncs ns
                let te = strip ts
                ot <- typeInf body ((var,te):assumptions) context fncs ns
                return $ makeDot ot (qualifier ts)
      Ast "for" [Avar var,Avar ivar,source,body]     -- a for-loop with an index
          -> do ts <- typeInf source assumptions context fncs ns
                let te = strip ts
                ot <- typeInf body ((var,te):(ivar,tInt):assumptions) context fncs ns
                return $ makeDot ot (qualifier ts)
      -- the order-by expressions are typed, but only the type of exp is kept
      Ast "sortTuple" (exp:orderBys)             -- prepare each FLWOR tuple for sorting
          -> do te <- typeInf exp assumptions context fncs ns
                _ <- mapM (\a -> typeInf a assumptions context fncs ns) orderBys
                return te
      Ast "sort" (exp:ordList)
          -> typeInf exp assumptions context fncs ns
      -- an explicit type expression is converted with toType
      Ast "type" [e]
          -> return $ toType e ns
      _ -> error ("Illegal XQuery: "++(show e))


-- Infer the type of a top-level expression.  There is no context item at
-- the top level, so using "." aborts with "Unspecified context".
typeInference :: Ast -> Assumptions -> Signatures -> NS -> Type
typeInference e as fncs ns
    = runTI (typeInf e as noContext fncs ns)
      where noContext = error "Unspecified context"


-- Check that the expression e has the type t.  The type t' inferred for e
-- must fit the required type t, i.e. t'<=t, so the inferred type is passed
-- to unify first, as at every other call site.  (With the arguments the
-- other way round, an inferred xs:integer would be rejected where
-- xs:integer* is required.)
-- Returns True on success; a mismatch aborts with the error raised by unify.
typeCheck :: Ast -> Type -> Assumptions -> Signatures -> NS -> Bool
typeCheck e t as fncs ns
    = runTI $ do t' <- typeInf e as (error "Unspecified context") fncs ns
                 -- matching on () forces the unification to be performed
                 () <- unify e t' t
                 return True