-- | Translate a parsed HESQL module into a Haskell module (a @haskell-src@
-- AST) whose functions run the module's SQL statements through HDBC.
module HESQL.Translator (generateCode) where

import Control.Exception (bracket)
import Data.Char (isAlpha)

import Database.HDBC (disconnect)
import Database.HDBC.PostgreSQL
import Language.Haskell.Syntax
import System.FilePath

import HESQL.Syntax
import HESQL.Verifier

-- | Name of the generated record type that holds all prepared statements.
stmtRec :: HsName
stmtRec = HsIdent "Stmts"

-- | Variable pattern for a single function argument.
funarg :: String -> HsPat
funarg = HsPVar . HsIdent

-- | Variable patterns for a list of function arguments.
funargs :: [String] -> [HsPat]
funargs = map funarg

-- | Dummy source location attached to every generated node.
srcloc :: SrcLoc
srcloc = SrcLoc "" 1 1

-- | A single-equation function binding: @name vars = body where wheres@.
hsSimpleFun :: String -> [String] -> HsRhs -> [HsDecl] -> HsDecl
hsSimpleFun name vars body wheres =
  HsFunBind [HsMatch srcloc (HsIdent name) (funargs vars) body wheres]

-- | Check every SQL declaration of a HESQL module with 'verifySql' over a
-- PostgreSQL connection, then translate the module into a Haskell module.
--
-- The generated module imports "Database.HDBC" and contains:
--
-- * a @Stmts@ record with one strict @Statement@ field per declaration,
-- * an @init@ function that prepares every statement on a connection,
-- * one function per declaration that executes its statement.
--
-- Arguments: the source file path (currently unused; see the @declLOC@
-- TODO in 'declFun'), the PostgreSQL connection string, and the module.
generateCode :: FilePath -> String -> HesqlModule -> IO HsModule
generateCode _fn db (HesqlModule modName decls) = do
  -- bracket guarantees the connection is closed, even when verifySql throws.
  bracket (connectPostgreSQL db) disconnect $ \conn ->
    mapM_ (verifySql conn . declSQL) decls'
  return $ HsModule srcloc (Module modName) Nothing imports hsDecls
  where
    imports =
      [HsImportDecl srcloc (Module "Database.HDBC") False Nothing Nothing]
    hsDecls = [dataType, initFun decls'] ++ map declFun decls'
    dataType =
      HsDataDecl srcloc [] stmtRec [] [HsRecDecl srcloc stmtRec recFields] []
    -- One @stmt_<name> :: !Statement@ field per declaration.
    mkRecField decl =
      ( [HsIdent $ stmtRecName decl]
      , HsBangedTy (HsTyCon $ UnQual $ HsIdent "Statement")
      )
    recFields = map mkRecField decls
    -- Declarations whose SQL has its Haskell variables turned into
    -- placeholders (see 'substituteVarsSql').
    decls' = map updateDeclSql decls
    updateDeclSql decl =
      decl { declSQL = substituteVarsSql (declVars decl) (declSQL decl) }

-- | Unqualified name for a generated reference. Names starting with a letter
-- or underscore are identifiers ('HsIdent', printed bare); anything else is
-- an operator ('HsSymbol', which the pretty-printer wraps in parentheses).
hsname :: String -> HsName
hsname s@(c:_) | isAlpha c || c == '_' = HsIdent s
hsname s = HsSymbol s

-- | Reference to an unqualified variable or function.
hsvar :: String -> HsExp
hsvar = HsVar . UnQual . hsname

-- | Apply the named function to one argument. This helper adds no
-- parentheses; wrap compound arguments in 'HsParen' (as 'resultStmt' does).
hsapp :: String -> HsExp -> HsExp
hsapp fun = HsApp (hsvar fun)

-- | Apply the named function to several arguments, left to right.
hsapps :: String -> [HsExp] -> HsExp
hsapps fun = foldl HsApp (hsvar fun)

-- | Name of the @Stmts@ field holding a declaration's prepared statement.
stmtRecName decl = "stmt_" ++ declName decl

-- | Generated @init conn@: runs @prepare conn "<sql>"@ for every declaration
-- and returns the results bundled in a @Stmts@ record.
initFun decls = hsSimpleFun "init" ["conn"] body []
  where
    body = HsUnGuardedRhs $ HsDo (mkStmts ++ [returnRec])
    mkStmts = map mkStmt decls
    -- <name> <- prepare conn "<sql>"
    mkStmt decl =
      HsGenerator srcloc (HsPVar $ HsIdent $ declName decl)
        (hsapps "prepare" [hsvar "conn", HsLit $ HsString sqlText])
      where
        -- The SQL text is the statement's Show rendering.
        sqlText = show $ declSQL decl
    -- return Stmts { stmt_<name> = <name>, ... }
    returnRec =
      HsQualifier $ hsapp "return" $ HsRecConstr (UnQual stmtRec) fields
    fields = map mkField decls
    mkField decl =
      HsFieldUpdate (UnQual $ HsIdent $ stmtRecName decl)
        (HsVar $ UnQual $ HsIdent $ declName decl)

-- | Generated function for one declaration, roughly:
--
-- > <name> h <vars...> = do
-- >     execute stmt [toSql <placeholder>, ...]
-- >     fmap (fmap toTuple) (<fetch> stmt)      -- SELECT only
-- >   where
-- >     stmt = stmt_<name> h
-- >     toTuple ...                             -- SELECT only
--
-- For non-SELECT statements the function's result is that of @execute@.
declFun decl = -- TODO declLOC
  hsSimpleFun (declName decl) ("h" : declVars decl) body (stmtDef : maybeTupleFun)
  where
    sql = declSQL decl
    selectFlag = isSelect sql
    body = HsUnGuardedRhs $ HsDo (bindAndExecute : result)
    -- Placeholder values are passed positionally, in placeHoldersSql order.
    bindAndExecute = HsQualifier $ hsapps "execute" [stmt, valList]
    result
      | selectFlag = [resultStmt (selectOpts sql) stmt]
      | otherwise = []
    maybeTupleFun
      | selectFlag = [tupleFun $ selectColumnLength $ selectColumns sql]
      | otherwise = []
    stmt = hsvar "stmt"
    valList = HsList $ map mkSqlVal $ placeHoldersSql sql
    mkSqlVal v = hsapp "toSql" (hsvar v)
    -- where stmt = stmt_<name> h
    stmtDef =
      HsPatBind srcloc (HsPVar $ HsIdent "stmt")
        (HsUnGuardedRhs $
          HsApp (HsVar $ UnQual $ HsIdent $ stmtRecName decl) (hsvar "h"))
        []

-- | Final statement of a generated SELECT function: fetch the rows of
-- @stmt@ and convert each with the local @toTuple@.
--
-- * 'ReturnMaybe' in @opts@ fetches one row with @fetchRow@.
-- * Otherwise all rows are fetched with @fetchAllRows@, or with the strict
--   @fetchAllRows'@ when 'Strict' is among the options.
resultStmt opts stmt =
  HsQualifier $ hsapps "fmap"
    [ HsParen (hsapp "fmap" (hsvar "toTuple"))
    , HsParen (hsapp (prefix ++ suffix) stmt)
    ]
  where
    returnMaybe = ReturnMaybe `elem` opts
    -- HDBC has no fetchRow', so strictness only applies to fetchAllRows.
    suffix
      | Strict `elem` opts && not returnMaybe = "'"
      | otherwise = ""
    prefix
      | returnMaybe = "fetchRow"
      | otherwise = "fetchAllRows"

-- | Local @toTuple@ for an @n@-column SELECT: a row of exactly @n@ values is
-- converted element-wise with @fromSql@; any other row shape calls @error@.
tupleFun :: Int -> HsDecl
tupleFun n = HsFunBind
  [ HsMatch srcloc (HsIdent "toTuple") [HsPList $ map HsPVar lvars]
      (HsUnGuardedRhs (HsTuple (map convVar lvars))) []
  , HsMatch srcloc (HsIdent "toTuple") [HsPWildCard]
      (HsUnGuardedRhs (hsapp "error" $ HsLit $ HsString "hesql internal error")) []
  ]
  where
    lvars = [HsIdent ("v" ++ show i) | i <- [1 .. n]]
    convVar v = hsapp "fromSql" (HsVar (UnQual v))

-- | Placeholders occurring in explicitly listed SELECT column expressions.
placeHolderCols AllColumns = []
placeHolderCols (ExplicitColumns cols) = concatMap (placeHoldersExp . fst) cols

-- | Placeholder names of a statement. The order matters: 'declFun' passes
-- the corresponding values positionally to @execute@.
placeHoldersSql (SELECT _ cols _ wh _ _) =
  placeHolderCols cols ++ maybe [] placeHoldersExp wh
placeHoldersSql (INSERT _ _ vals) = concatMap placeHoldersExp vals
placeHoldersSql (UPDATE _ updates wh) =
  concatMap (placeHoldersExp . snd) updates ++ maybe [] placeHoldersExp wh

-- | Placeholder names in an expression, left to right.
placeHoldersExp :: SqlExp -> [String]
placeHoldersExp (SqlPlaceHolder ph) = [ph]
placeHoldersExp (SqlInfixApp e1 _ e2) = placeHoldersExp e1 ++ placeHoldersExp e2
placeHoldersExp (SqlColumn _) = []
placeHoldersExp (SqlFunApp _ args) = concatMap placeHoldersExp args
placeHoldersExp (SqlLiteral _) = []
placeHoldersExp (SqlNot e) = placeHoldersExp e

-- | Replace every column reference whose name is in @vars@ with a
-- placeholder of the same name.
substituteVarsExp :: [String] -> SqlExp -> SqlExp
substituteVarsExp vars (SqlInfixApp e1 op e2) = SqlInfixApp e1' op e2'
  where
    e1' = substituteVarsExp vars e1
    e2' = substituteVarsExp vars e2
substituteVarsExp _ sql@(SqlPlaceHolder _) = sql
substituteVarsExp _ sql@(SqlLiteral _) = sql
substituteVarsExp vars sql@(SqlColumn col)
  | col `elem` vars = SqlPlaceHolder col
  | otherwise = sql
substituteVarsExp vars (SqlNot e) = SqlNot (substituteVarsExp vars e)
substituteVarsExp vars (SqlFunApp n args) =
  SqlFunApp n (map (substituteVarsExp vars) args)

-- | Apply 'substituteVarsExp' to the value expressions of a statement:
-- the WHERE clause of a SELECT, the values of an INSERT, and the assigned
-- values plus WHERE clause of an UPDATE.
--
-- NOTE(review): SELECT column expressions are not substituted, although
-- 'placeHoldersSql' collects placeholders from them — confirm intent.
substituteVarsSql vars (SELECT opt cols from wh ord grp) =
  SELECT opt cols from wh' ord grp
  where
    wh' :: Maybe SqlExp
    wh' = fmap (substituteVarsExp vars) wh
substituteVarsSql vars (INSERT tab spec vals) =
  INSERT tab spec $ map (substituteVarsExp vars) vals
substituteVarsSql vars (UPDATE tab updates wh) =
  UPDATE tab updates' $ fmap (substituteVarsExp vars) wh
  where
    updates' = map (mapSnd $ substituteVarsExp vars) updates

-- | Apply a function to the second component of a pair.
mapSnd :: (b -> c) -> (a, b) -> (a, c)
mapSnd f (a, b) = (a, f b)

selectColumnLength (ExplicitColumns cols) = length cols