module HESQL.Translator (generateCode) where

import Control.Exception (bracket)
import Data.List (elemIndex)
import System.FilePath

import Database.HDBC (disconnect)
import Database.HDBC.PostgreSQL
import Database.HsSqlPpp.Ast.Ast
import Database.HsSqlPpp.PrettyPrinter.PrettyPrinter
import Language.Haskell.Syntax

import HESQL.Syntax
import HESQL.Verifier

-- | Name used for both the generated record type and its constructor.
stmtRec :: HsName
stmtRec = HsIdent "Stmts"

-- | A variable pattern binding the given identifier.
funarg :: String -> HsPat
funarg name = HsPVar (HsIdent name)

-- | One variable pattern per identifier, in order.
funargs :: [String] -> [HsPat]
funargs = map funarg

-- | Placeholder source location attached to every generated AST node.
srcloc :: SrcLoc
srcloc = SrcLoc "<dummy>" 1 1

-- | A single-clause function binding @name vars = body where wheres@, where
-- every argument is a plain variable pattern.
hsSimpleFun :: String -> [String] -> HsRhs -> [HsDecl] -> HsDecl
hsSimpleFun name vars body wheres = HsFunBind [clause]
  where
    clause = HsMatch srcloc (HsIdent name) (funargs vars) body wheres
    

-- | Translate a parsed hesql module into a Haskell module AST.
--
-- The host variables in every declaration's SQL are first replaced by
-- positional parameters ('updateDeclSql'). Each resulting statement is then
-- checked with 'verifySql' against the PostgreSQL database named by the
-- connection string @db@. The connection is released when verification
-- finishes, including when verification throws.
--
-- The generated module imports "Database.HDBC" and contains:
--
-- * a @Stmts@ record with one 'Statement' field per declaration,
-- * an @init@ function that prepares all statements ('initFun'),
-- * one function per declaration ('declFun').
--
-- NOTE(review): the @fn@ argument is currently unused.
generateCode :: FilePath -> String -> HesqlModule -> IO HsModule
generateCode fn db (HesqlModule modName decls) = do
    -- Close the verification connection deterministically instead of leaving
    -- it open for the rest of the program.
    bracket (connectPostgreSQL db) disconnect $ \conn ->
        mapM_ (verifySql conn . declSQL) decls'
    return $ HsModule srcloc (Module modName) Nothing imports hsDecls
  where imports   = [HsImportDecl srcloc (Module "Database.HDBC")
                       False Nothing Nothing]
        hsDecls   = [dataType, initFun decls'] ++ map declFun decls'
        -- data Stmts = Stmts { stmt_<name> :: !Statement, ... }
        dataType  = HsDataDecl srcloc [] stmtRec []
                    [HsRecDecl srcloc stmtRec recFields] []
        mkRecField decl = ( [HsIdent $ stmtRecName decl]
                          , HsBangedTy (HsTyCon $ UnQual $ HsIdent "Statement") )
        recFields = map mkRecField decls'
        decls'    = map updateDeclSql decls


-- | Rewrite a declaration's SQL so that its host variables ('declVars') are
-- replaced by positional parameters; all other fields are kept.
updateDeclSql decl = decl { declSQL = rewritten }
  where
    rewritten = substituteVarsSql (declVars decl) (declSQL decl)

-- | Reference to an unqualified variable.
--
-- The name is built with 'HsIdent': every name used here is an ordinary
-- identifier (@return@, @prepare@, @toSql@, user variables, ...).
-- 'HsSymbol' is meant for operator names, and the pretty printer renders it
-- in operator form (parenthesised).
hsvar :: String -> HsExp
hsvar v = HsVar $ UnQual $ HsIdent v

-- | Apply the named function to one argument.
hsapp :: String -> HsExp -> HsExp
hsapp fun = HsApp (hsvar fun)

-- | Apply the named function to a list of arguments, left to right.
hsapps :: String -> [HsExp] -> HsExp
hsapps fun = foldl HsApp (hsvar fun)

-- | Name of the @Stmts@ record field holding a declaration's prepared
-- statement: the declaration name prefixed with @stmt_@.
stmtRecName decl = concat ["stmt_", declName decl]

-- | Generate the @init@ function of the output module:
--
-- > init conn = do
-- >     foo <- prepare conn "<sql of foo>"
-- >     ...
-- >     return Stmts { stmt_foo = foo, ... }
--
-- Each declaration's SQL is rendered with 'printSql' and prepared on @conn@.
-- The prepared statements are then packed into the @Stmts@ record.
initFun decls =
    hsSimpleFun "init" ["conn"] body []
  where body    = HsUnGuardedRhs $ HsDo (mkStmts ++ [returnRec])
        mkStmts = map mkStmt decls
        -- <name> <- prepare conn "<sql>"
        mkStmt decl = HsGenerator srcloc (HsPVar $ HsIdent $ declName decl)
                      (hsapps "prepare" [hsvar "conn", HsLit $ HsString sqlText])
             where sqlText = printSql [declSQL decl]
        returnRec = HsQualifier $ hsapp "return" $ HsRecConstr (UnQual stmtRec) fields
        fields    = map mkField decls
        -- The record's fields are declared with 'stmtRecName' (see the data
        -- declaration built in 'generateCode'), so the field label must use
        -- it as well. The value is the local variable bound by 'mkStmt'.
        mkField decl = HsFieldUpdate
                         (UnQual $ HsIdent $ stmtRecName decl)
                         (HsVar $ UnQual $ HsIdent $ declName decl)


-- | Generate the Haskell function for one declaration:
--
-- > name h v1 ... vn = do
-- >     execute stmt [toSql v1, ..., toSql vn]
-- >     fmap (fmap toTuple) (<fetch> stmt)     -- SELECT statements only
-- >   where stmt = stmt_name h
-- >         toTuple [...] = (...)             -- SELECT statements only
--
-- @h@ is expected to be the @Stmts@ record built by the generated @init@.
-- @stmt@ selects this declaration's prepared statement from it.
declFun decl =  -- TODO declLOC
    hsSimpleFun (declName decl) ("h" : declVars decl) body (stmtDef : maybeTupleFun)
  where
    body           = HsUnGuardedRhs $ HsDo (bindAndExecute : result)
    bindAndExecute = HsQualifier $ hsapps "execute" [stmt, valList]
    selectFlag     = isSelect $ declSQL decl
    -- only SELECT statements fetch and convert result rows
    result
        | selectFlag = [resultStmt decl stmt]
        | otherwise  = []
    maybeTupleFun
        | selectFlag = [tupleFun $ selectColumnLength $ declSQL decl]
        | otherwise  = []
    -- Reference to the local `stmt` binding. It is an identifier, so use
    -- HsIdent, matching the pattern bound in stmtDef.
    stmt           = HsVar $ UnQual $ HsIdent "stmt"
    valList        = HsList $ map mkSqlVal $ declVars decl
    mkSqlVal v     = hsapp "toSql" (hsvar v)
    stmtDef        = HsPatBind srcloc (HsPVar $ HsIdent "stmt")
                         (HsUnGuardedRhs $ HsApp (HsVar $ UnQual $ HsIdent $ stmtRecName decl) (hsvar "h"))
                         []

-- | Final do-statement of a SELECT function. It fetches from @stmt@ with the
-- HDBC fetch function selected by the declaration's query option, then
-- converts the fetched row(s) with the local @toTuple@:
--
-- > fmap (fmap toTuple) (<fetch> stmt)
resultStmt decl stmt =
    HsQualifier $ hsapps "fmap" [HsParen convertRows, HsParen fetchRows]
  where
    convertRows = hsapp "fmap" (hsvar "toTuple")
    fetchRows   = hsapp fetcher stmt
    fetcher     = case declOpt decl of
                    MaybeQuery  -> "fetchRow"
                    StrictQuery -> "fetchAllRows'"
                    LazyQuery   -> "fetchAllRows"


-- | Generate the local row-conversion helper for an @n@-column result:
--
-- > toTuple [v1, ..., vn] = (fromSql v1, ..., fromSql vn)
-- > toTuple _             = error "hesql internal error"
tupleFun n = HsFunBind [rowClause, fallbackClause]
  where
    names = [HsIdent ('v' : show i) | i <- [1 .. n]]
    toTupleName = HsIdent "toTuple"
    converted = [hsapp "fromSql" (HsVar (UnQual name)) | name <- names]
    rowClause =
        HsMatch srcloc toTupleName [HsPList (map HsPVar names)]
                (HsUnGuardedRhs (HsTuple converted))
                []
    -- any row whose length differs from n
    fallbackClause =
        HsMatch srcloc toTupleName [HsPWildCard]
                (HsUnGuardedRhs (hsapp "error" (HsLit (HsString "hesql internal error"))))
                []


-- | Replace host-variable identifiers in a SQL statement with positional
-- parameters. A variable is numbered from 1 by its index in @vars@.
--
-- SELECT, UPDATE, INSERT and DELETE are supported. Any other statement is
-- rejected with an error that shows the statement.
substituteVarsSql :: [String] -> Statement -> Statement
substituteVarsSql vars sql =
  case sql of
    SelectStatement ann se -> SelectStatement ann $ substituteSelectExpression vars se
    Update ann table sets wh sl ->
        Update ann table
               (substituteSetClauselList vars sets)
               (fmap substE wh)
               (substMSL sl)
    -- The trailing optional select list (presumably RETURNING) is substituted
    -- just as it is for UPDATE and DELETE.
    Insert ann tab cols e sl ->
        Insert ann tab cols (substituteSelectExpression vars e)
               (substMSL sl)
    Delete ann s wh msl ->
        Delete ann s (fmap substE wh) (substMSL msl)
    _ -> error $ "substituteVarsSql: " ++ show sql
 where substE = substituteExpression vars
       substMSL = fmap $ substituteSelectList vars


-- | Substitute host variables in every SET clause of an UPDATE.
substituteSetClauselList vars sets = [substituteSetClausel vars c | c <- sets]

-- | Substitute host variables in the assigned value(s) of one SET clause.
-- The target column name(s) are left untouched.
substituteSetClausel vars clause =
    case clause of
      SetClause ann col e      -> SetClause ann col (subst e)
      RowSetClause ann cols es -> RowSetClause ann cols (map subst es)
  where subst = substituteExpression vars



-- | Substitute host variables inside a select expression.
--
-- * SELECT: covers the select list, WHERE, GROUP BY, HAVING, ORDER BY and
--   the two trailing optional expressions. Table references (@tabref@) are
--   passed through unchanged.
-- * VALUES: every expression is substituted.
--
-- Any other form is rejected with an error showing the expression, in the
-- same style as 'substituteVarsSql'. This replaces a bare pattern-match
-- failure.
substituteSelectExpression vars sql =
    case sql of
      -- NOTE(review): the last two fields (a, b) are presumably LIMIT and
      -- OFFSET; confirm against the HsSqlPpp Select constructor.
      Select ann dist sl tabref wh groupby having order a b ->
          Select ann dist
            (substituteSelectList vars sl)
            tabref   -- TODO substitute inside table references
            (se' wh)
            (map se groupby)
            (se' having)
            (map sde order)
            (se' a)
            (se' b)
      Values ann el ->
          Values ann (map (map se) el)
      _ -> error $ "substituteSelectExpression: " ++ show sql
   where se = substituteExpression vars
         se' = fmap se
         sde (e, dir) = (se e, dir)

           
-- | Replace identifiers naming host variables with positional parameters.
--
-- An 'Identifier' found at 0-based index @i@ in @vars@ becomes
-- @PositionalArg (i + 1)@. Other identifiers and literals are returned
-- unchanged. Compound expressions (function calls, CASE, casts, subqueries,
-- window functions) are traversed recursively.
--
-- A 'PositionalArg' already present in the input is rejected, as is any
-- expression form not handled below.
substituteExpression :: [String] -> Expression -> Expression
substituteExpression _ l@(BooleanLit _ _) = l
substituteExpression _ l@(FloatLit _ _) = l
substituteExpression vars (FunCall ann s el) =
    FunCall ann s (fmap (substituteExpression vars) el)
substituteExpression _ l@(StringLit _ _ _) = l
substituteExpression _ l@(IntegerLit _ _) = l
substituteExpression vars ident@(Identifier ann v) =
    case v `elemIndex` vars of
      Nothing -> ident
      -- positional parameters are numbered from 1
      Just i  -> PositionalArg ann (fromIntegral i + 1)
substituteExpression vars (Case ann l e) =
    Case ann l' e'
  where e' = fmap substE e
        l' = map substOneCase l
        substE = substituteExpression vars
        substOneCase (cs, r) = (map substE cs, substE r)

substituteExpression vars (CaseSimple ann e1 l e2) =
    CaseSimple ann e1' l' e2'
  where e1' = substE e1
        e2' = fmap substE e2
        l' = map substOneCase l
        substE = substituteExpression vars
        substOneCase (cs, r) = (map substE cs, substE r)

substituteExpression vars (Cast ann e tn) =
    Cast ann (substituteExpression vars e) tn

substituteExpression vars (Exists ann e) =
    Exists ann (substituteSelectExpression vars e)

substituteExpression _ l@(NullLit _) = l

-- Positional arguments are produced by this pass, so they must not appear in
-- the input.
substituteExpression _ (PositionalArg _ _)
    = error "don't use PositionalArg"

substituteExpression vars (ScalarSubQuery ann se) =
    ScalarSubQuery ann (substituteSelectExpression vars se)


substituteExpression vars (WindowFn ann e el1 el2 dir) =
    WindowFn ann (substE e) (substEL el1) (substEL el2) dir
  where substE = substituteExpression vars
        substEL = map substE

-- Any other expression form is not supported yet. Fail with the offending
-- expression rather than a bare non-exhaustive-pattern error.
substituteExpression _ e =
    error $ "substituteExpression: unsupported expression " ++ show e


-- | Substitute host variables in every item of a select list. The list's
-- third component is passed through unchanged.
substituteSelectList vars (SelectList ann items rest) =
    SelectList ann [substituteSelectItem vars item | item <- items] rest

-- | Substitute host variables in the expression of a single select item.
-- The item's remaining fields are kept as-is.
substituteSelectItem vars (SelExp ann e) =
    SelExp ann (substituteExpression vars e)
substituteSelectItem vars (SelectItem ann e s) =
    SelectItem ann (substituteExpression vars e) s


--substituteExpression _ _ = todo "substituteExpression" -- TODO print Exp
