{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TemplateHaskell #-}
{-# LANGUAGE TypeSynonymInstances #-}
module Database.Persist.HsSqlPpp
  ( Query
  , persistSql
  , persistSqlFile
  , parseEntityFromFile
  , parseEntity
  , selectFromQuery
  , selectFromQuery'
  , checkSQL
  ) where

import Control.Applicative
import Control.Monad.Reader
import Control.Monad.Trans
import Control.Monad.Error
import Control.Monad.Instances
import Control.Monad.IO.Control
import Data.Char
import Data.List (intercalate)
import qualified Data.ByteString as B
import qualified Data.Text as Text
import Language.Haskell.TH
import qualified Language.Haskell.TH.Lift as THL
import qualified Language.Haskell.TH.Syntax as TH
import Language.Haskell.TH.Quote
import Database.Persist
import Database.Persist.Base
import Database.Persist.GenericSql.Raw
import qualified Database.Persist.GenericSql.Internal as I
import Database.Persist.TH
import Database.HsSqlPpp.Types as Types
import Database.HsSqlPpp.Ast as AST
import Database.HsSqlPpp.Annotation
import Database.HsSqlPpp.Catalog
import Database.HsSqlPpp.Parser
import Database.HsSqlPpp.Pretty

-- | Pairs (column name, column type)
type Types = [(String, String)]

-- | (Parsed SQL query, entity definition)
type Query = (QueryExpr, EntityDef)

-- | Values that can be rendered as a plain name string.
-- For a list of name components (a possibly qualified name) only the last
-- component is used.
class IsName x where
  name2string :: x -> String

instance IsName String where
  name2string = id

instance IsName AST.Name where
  name2string (AST.Name _ list) = name2string list

instance IsName NameComponent where
  name2string n = ncStr n

instance IsName [NameComponent] where
  -- Total: an empty component list renders as "" instead of crashing in 'last'.
  name2string [] = ""
  name2string list = ncStr (last list)

-- | Convert a 'B.ByteString' to a 'String' byte by byte: each byte becomes
-- the character with the same code point (no UTF-8 decoding is done).
toString :: B.ByteString -> String
toString bs = map (chr . fromIntegral) (B.unpack bs)

-- Lift instances for the HsSqlPpp AST, so that parsed queries can be
-- embedded into generated code with 'TH.lift' (see persistSql and
-- persistSqlFile).  These splices must stay above every definition that
-- relies on the generated instances.
THL.deriveLift ''SelectItem
THL.deriveLift ''SelectList
THL.deriveLift ''Distinct
THL.deriveLift ''FrameClause
THL.deriveLift ''TypeName
THL.deriveLift ''LiftFlavour
THL.deriveLift ''IntervalField
THL.deriveLift ''InList
THL.deriveLift ''ExtractField
THL.deriveLift ''ScalarExpr
THL.deriveLift ''Direction
THL.deriveLift ''TableAlias
THL.deriveLift ''NameComponent
THL.deriveLift ''AST.Name
THL.deriveLift ''JoinExpr
THL.deriveLift ''JoinType
THL.deriveLift ''Natural
THL.deriveLift ''TableRef
THL.deriveLift ''CombineType
THL.deriveLift ''WithQuery
THL.deriveLift ''TypeError
THL.deriveLift ''FunFlav
THL.deriveLift ''CastContext
THL.deriveLift ''PseudoType
THL.deriveLift ''Types.Type
THL.deriveLift ''CatalogUpdate
THL.deriveLift ''Annotation
THL.deriveLift ''QueryExpr

-- | Apply a check to every list element, stopping at the first 'Left'.
-- When every check succeeds the given value @b@ is returned in 'Right'.
allE :: (a -> Either e b) -> b -> [a] -> Either e b
allE _ b [] = Right b
allE fn b (x:xs) = case fn x of
  Left e -> Left e
  Right _ -> allE fn b xs

-- | Combine two check results.  If both succeeded, the result wraps its
-- argument in 'Right'.  Otherwise it ignores the argument and returns the
-- error; when both failed, the two messages are joined with @"; "@.
(<&>) :: Either String a -> Either String a -> (a -> Either String a)
(Right _) <&> (Right _) = Right
(Right _) <&> (Left e) = \_ -> Left e
(Left e) <&> (Right _) = \_ -> Left e
(Left e) <&> (Left e') = \_ -> Left (e ++ "; " ++ e')

-- | Annotation with every field empty, used for AST nodes built here.
emptyAnn :: Annotation
emptyAnn = Annotation Nothing Nothing [] Nothing [] Nothing Nothing

-- | Add every filter to the WHERE clause of the query (see 'addWhere').
makeQuery :: (PersistEntity t) => I.Connection -> QueryExpr -> [Filter t] -> Either String QueryExpr
makeQuery conn query filts = foldM (addWhere conn) query filts

-- | Add one filter to the WHERE clause of a SELECT query, AND-ing it with
-- any existing condition.  Other query forms are rejected with 'Left'.
addWhere :: (PersistEntity t) => I.Connection -> QueryExpr -> Filter t -> Either String QueryExpr
addWhere conn (Select ann dist fields trefs old a b c d e) filt =
  Right $ Select ann dist fields trefs (Just $ convertCond conn filt old) a b c d e
addWhere _ x _ = Left $ "Unsupported query expr in addWhere: " ++ show x

-- | Build an unqualified AST name from a string.
name :: String -> AST.Name
name s = Name emptyAnn [Nmc s]

-- | Condition for a filter, combined via the "!and" operator with the
-- existing WHERE condition, if there is one.
convertCond :: (PersistEntity t) => I.Connection -> Filter t -> Maybe ScalarExpr -> ScalarExpr
convertCond conn f Nothing = convertFilter conn f
convertCond conn f (Just old) = FunCall emptyAnn (name "!and") [convertFilter conn f, old]

-- | Translate a Persistent filter on a single field into an HsSqlPpp
-- expression over the escaped column name.
convertFilter :: (PersistEntity t) => I.Connection -> Filter t -> ScalarExpr
convertFilter conn (Filter field value op) =
    case op of
      -- IN / NOT IN need a real InPredicate node.  Encoded as a FunCall
      -- named " IN " they would be treated as a call to a function with
      -- that name instead of "column IN (...)".
      In -> inPredicate True
      NotIn -> inPredicate False
      _ -> FunCall emptyAnn (name $ showSqlFilter op) (column : allVals)
  where
    column = Identifier emptyAnn (Nmc fname)

    -- "column IN ()" is not valid SQL.  An empty IN list matches no row;
    -- an empty NOT IN list matches every row.
    inPredicate isIn
      | null allVals = BooleanLit emptyAnn (not isIn)
      | otherwise = InPredicate emptyAnn column isIn (InList emptyAnn allVals)

    allVals = map persist2expr $ filterValueToPersistValues value

    persist2expr (PersistText txt) = StringLit emptyAnn (Text.unpack txt)
    persist2expr (PersistByteString bstr) = StringLit emptyAnn (toString bstr)
    persist2expr (PersistInt64 i) = NumberLit emptyAnn (show i)
    persist2expr (PersistDouble x) = NumberLit emptyAnn (show x)
    persist2expr (PersistBool b) = BooleanLit emptyAnn b
    persist2expr (PersistDay d) = StringLit emptyAnn (show d)
    persist2expr (PersistTimeOfDay tod) = StringLit emptyAnn (show tod)
    persist2expr (PersistUTCTime time) = StringLit emptyAnn (show time)
    persist2expr (PersistNull) = NullLit emptyAnn
    persist2expr x = error $ "persist2expr: unsupported Persistent value in filter: " ++ show x

    -- A single value or a list of values, as persistent values.
    filterValueToPersistValues :: forall a. PersistField a => Either a [a] -> [PersistValue]
    filterValueToPersistValues v = map toPersistValue $ either return id v

    fname = I.escapeName conn $ getFieldName t $ columnName $ persistColumnDef field
    -- 'dummyFromFilts' is never evaluated; it only fixes the entity type.
    t = entityDef $ dummyFromFilts [Filter field value op]

getFieldName :: EntityDef -> String -> I.RawName
getFieldName t s = I.rawFieldName $ tableColumn t s

-- This is copy-paste from Database.Persist.GenericSql.Internal
showSqlFilter Eq = "="
showSqlFilter Ne = "<>"
showSqlFilter Gt = ">"
showSqlFilter Lt = "<"
showSqlFilter Ge = ">="
showSqlFilter Le = "<="
showSqlFilter In = " IN "
showSqlFilter NotIn = " NOT IN "
showSqlFilter (BackendSpecificFilter s) = s

dummyFromFilts :: [Filter v] -> v
dummyFromFilts _ = error "dummyFromFilts"

tableColumn :: EntityDef -> String -> ColumnDef
tableColumn t s | s == id_ = ColumnDef id_ "Int64" []
  where id_ = I.unRawName $ I.rawTableIdName t
tableColumn t s = go $ entityColumns t
  where
    go [] = error $ "Unknown table column: " ++ s
    go (ColumnDef x y z:rest)
      | x == s = ColumnDef x y z
      | otherwise = go rest
-- End of copy-paste
-- | Build the entity definition for a SELECT query: one column per
-- select-list item, with column types taken from the type declarations.
-- The entity is named \"Undefined\".
convertSelect :: Types -> QueryExpr -> Either String EntityDef
convertSelect types (Select _ _ (SelectList _ items) _ _ _ _ _ _ _) =
    EntityDef <$> return "Undefined"
              <*> return []
              <*> mapM (convertColumn types) items
              <*> return []
              <*> return ["Eq", "Show"]
convertSelect _ expr = Left $ "Unsupported SQL query syntax in " ++ printQueryExpr expr

-- | Names of the select-list columns of a SELECT query.
getColumnNames :: QueryExpr -> Either String [ColumnName]
getColumnNames (Select _ _ (SelectList _ items) _ _ _ _ _ _ _) = mapM getColumnName items
getColumnNames expr = Left $ "Unsupported SQL query syntax in " ++ printQueryExpr expr

-- | Check that a query is a plain SELECT that is safe to run.
--
-- * Only tables from the given list may be referenced.  Schema-qualified
--   names must appear in the list in their dotted form.
--
-- * Function calls in the select list are limited to count, max, min, avg
--   and sum.
--
-- * WHERE, GROUP BY, HAVING, ORDER BY, LIMIT, OFFSET and JOIN ... ON
--   expressions may also use (unquoted) operators.  Otherwise they follow
--   the same rules.
--
-- * CASE, CAST, EXTRACT and window functions are rejected.  Sub-queries
--   are checked recursively.
--
-- Returns the query unchanged on success.
checkSelect :: [String] -> QueryExpr -> Either String QueryExpr
checkSelect tables s@(Select _ _ (SelectList _ items) trefs whr grp hav ord lim off) =
    (allE goodTref s trefs <&> allE goodCol s items) s
      >> allE goodCond s clauseExprs
  where
    -- Expressions outside the select list.  These must be checked too:
    -- otherwise a WHERE sub-query could read a forbidden table, or call an
    -- arbitrary function.
    clauseExprs = [e | Just e <- [whr, hav, lim, off]] ++ grp ++ map fst ord

    goodTref (FunTref _ _ a) = Left $ "Function call as table reference is not allowed: " ++ show a
    goodTref (JoinTref _ t1 _ _ t2 joinCond _) =
      (goodTref t1 <&> goodTref t2) s >> goodJoin joinCond
    goodTref (SubTref _ q _) = checkSelect tables q
    goodTref (Tref _ (Name _ list) _) =
      -- Compare the full (dotted) name: using only the last component would
      -- let "other_schema.allowed_table" pass.
      let tname = map toLower (intercalate "." (map ncStr list))
      in if tname `elem` tables
           then Right s
           else Left $ "Table reference is not allowed: " ++ tname

    goodJoin (Just (JoinOn _ e)) = goodCond e
    goodJoin _ = Right s

    goodCol (SelExp _ expr) = goodExpr expr
    goodCol (SelectItem _ expr _) = goodExpr expr

    -- Conditions may combine operands with operators; the operands are
    -- still checked.
    goodCond (FunCall _ fn args)
      | isOperator fn = allE goodCond s args
    goodCond expr = goodExpr expr

    goodExpr (FunCall _ fn args) =
      let fname = map toLower (name2string fn)
      in if fname `elem` ["count", "max", "min", "avg", "sum"]
           then allE goodExpr s args
           else Left $ "Function call not allowed: " ++ name2string fn
    goodExpr (ScalarSubQuery _ q) = checkSelect tables q
    goodExpr (Case {}) = Left "CASE expressions are not allowed"
    goodExpr (CaseSimple {}) = Left "CASE expressions are not allowed"
    goodExpr (Cast {}) = Left "CAST expressions are not allowed"
    goodExpr (Exists _ q) = checkSelect tables q
    goodExpr (Extract {}) = Left "EXTRACT expressions are not allowed"
    goodExpr (InPredicate _ e _ (InList _ es)) = (goodExpr e <&> allE goodExpr s es) s
    goodExpr (InPredicate _ e _ (InQueryExpr _ q)) = (goodExpr e <&> checkSelect tables q) s
    goodExpr (LiftOperator _ _ _ es) = allE goodExpr s es
    goodExpr (QIdentifier _ _) = Right s
    goodExpr (WindowFn {}) = Left "WINDOW functions are not allowed"
    goodExpr _ = Right s

    -- Only unquoted single-component names can denote operators.  A quoted
    -- identifier such as "!f" or "=" is an ordinary (user-defined) function.
    isOperator (Name _ [Nmc op]) = isOperatorName op
    isOperator _ = False

    -- Keyword operators ("!and", "!not", ...), symbolic operators, and
    -- symbolic operators with a "u" prefix.
    isOperatorName ('!':_) = True
    isOperatorName ('u':op@(_:_)) = all isOperatorChar op
    isOperatorName op = not (null op) && all isOperatorChar op

    isOperatorChar c = c `elem` "+-*/<>=~!@#%^&|`?"
checkSelect _ q = Left $ "Non-SELECT query forbidden: " ++ printQueryExpr q

-- | The connection of the current SQL session.
getConnection :: (MonadControlIO m) => SqlPersist m I.Connection
getConnection = SqlPersist ask

-- | Select list of records from DB using given SQL query.
-- The filters are AND-ed into the query's WHERE clause.  A query that
-- cannot take the filters, or a row that cannot be decoded, is reported
-- as 'Left'.
selectFromQuery :: forall m a. (MonadControlIO m, PersistEntity a)
                => QueryExpr   -- ^ SQL query
                -> [Filter a]  -- ^ Filters
                -> SqlPersist m (Either String [a])
selectFromQuery query filts = do
    conn <- getConnection
    case makeQuery conn query filts of
      -- Report the problem through the result type, like every other
      -- failure here, instead of calling 'fail' in the underlying monad.
      Left err -> return (Left err)
      Right expr -> withStmt (Text.pack $ printQueryExpr expr) [] worker
  where
    worker :: I.RowPopper (SqlPersist m) -> SqlPersist m (Either String [a])
    worker popper = runErrorT (go popper)

    -- Read rows until the popper is exhausted; a row that does not decode
    -- aborts with its error message.
    go popper = do
      row <- lift popper
      case row of
        Nothing -> return []
        Just list -> case fromPersistValues list of
          Left err -> fail err
          Right rec -> do
            next <- go popper
            return (rec : next)

-- | Check if given SQL query is safe SELECT query
checkSQL :: [String]  -- ^ Allowed table names
         -> String    -- ^ SQL query
         -> Either String QueryExpr
checkSQL tables query = do
    (expr, _) <- parseSelect "" query
    checkSelect tables expr

-- | Select list of records from DB using given SQL SELECT query.
-- Query is checked for safety (arbitrary function calls,
-- complex expressions, etc. are not permitted).
-- Returns the column names together with the rows;
-- each row is represented as [PersistValue].
selectFromQuery' :: (MonadControlIO m)
                 => [String]  -- ^ Names of allowed tables
                 -> String    -- ^ SQL query
                 -> SqlPersist m (Either String ([ColumnName], [[PersistValue]]))
selectFromQuery' tables query =
    case parseSelect "" query of
      Left err -> return (Left err)
      Right (expr, names) ->
        case checkSelect tables expr of
          Left err -> return $ Left $ "Forbidden query: " ++ err
          Right _ -> do
            rows <- withStmt (Text.pack query) [] worker
            return $ Right (names, rows)
  where
    -- Collect every remaining row from the row popper.
    worker popper = do
      row <- popper
      case row of
        Nothing -> return []
        Just rec -> do
          next <- worker popper
          return (rec : next)

-- | Quasi-quoter which parses SQL SELECT queries.
-- Example input:
--
-- @
-- SELECT family, salary FROM employee;
-- --------------------
-- family String
-- salary Int
-- @
--
-- The quote produces a 'Query', i.e. a (parsed query, entity definition)
-- pair.  NB: entity name will be \"Undefined\", so
-- you'll need to use record update syntax on the entity to
-- set the name you want, e.g. entity {entityName = \"Query\"}.
--
-- Only expression quotes are supported.
persistSql :: QuasiQuoter
persistSql = QuasiQuoter
  { quoteExp = \str -> case parseEntity "" str of
      Right expr -> TH.lift expr
      Left err -> fail err
  , quotePat = unsupported "patterns"
  , quoteType = unsupported "types"
  , quoteDec = unsupported "declarations"
  }
  where
    -- Fail at compile time with a readable message instead of 'undefined'.
    unsupported :: String -> String -> Q b
    unsupported what _ = fail $ "persistSql: quasi-quoting " ++ what ++ " is not supported"

-- | Load entity declaration from file containing SQL query.
-- TH version: the file is read at compile time, and the result
-- ('Either' String 'Query') is spliced in as a single expression.
persistSqlFile :: FilePath -> Q [Exp]
persistSqlFile path = do
    entity <- runIO $ parseEntityFromFile path
    x <- TH.lift entity
    return [x]

-- | Load entity declaration from file containing SQL query
parseEntityFromFile :: FilePath -> IO (Either String Query)
parseEntityFromFile path = do
    text <- readFile path
    return (parseEntity path text)

-- | Parse SQL entity declaration.
--
-- The text consists of the SQL query, then a delimiter line (only @-@,
-- @=@ and blanks, with at least one @-@ or @=@), then one
-- @name type@ declaration per line.  Blank lines in the query are dropped.
-- Without a delimiter line the whole text is the query and no types are
-- declared.
parseEntity :: FilePath  -- ^ File name to use in error messages
            -> String    -- ^ Declaration text
            -> Either String Query
parseEntity path text = parseSelectT path types query
  where
    delimiter str = not (null str)
                 && all (`elem` "-= \t") str
                 && any (`elem` "-=") str
    -- Total split: previously a missing delimiter crashed with an
    -- irrefutable-pattern failure instead of returning a result.
    (queryLines, rest) = break delimiter (lines text)
    typeLines = drop 1 rest
    types = parseTypes typeLines
    blank line = all (`elem` " \t") line
    query = unlines $ filter (not . blank) queryLines

-- | Parse a query and derive its entity definition from the column types.
parseSelectT :: String -> Types -> String -> Either String Query
parseSelectT path types query =
    case parseQueryExpr path query of
      Left err -> Left $ "Cannot parse query expression: " ++ show err
      Right expr -> do
        res <- convertSelect types expr
        return (expr, res)

-- | Parse a query and return it together with its column names.
parseSelect :: String -> String -> Either String (QueryExpr, [ColumnName])
parseSelect path query =
    case parseQueryExpr path query of
      Left err -> Left (show err)
      Right expr -> do
        cols <- getColumnNames expr
        return (expr, cols)

-- | Parse @name type words...@ lines into (name, type) pairs; blank lines
-- are skipped and the type is the rest of the line.
parseTypes :: [String] -> Types
parseTypes = concatMap parseLine
  where
    parseLine line = case words line of
      [] -> []
      (name:other) -> [(name, unwords other)]

-- | Name rendered in lower case.
lowerName :: IsName name => name -> String
lowerName n = map toLower (name2string n)

-- | Look up the declared type of a column.
lookupType :: (IsName name) => name -> Types -> Either String ColumnType
lookupType name ts =
    case lookup (name2string name) ts of
      Nothing -> Left $ "Do not know type of " ++ name2string name
      Just typ -> return typ

-- | Type of a select-list item.  An aliased item uses the type declared
-- for its alias, if any, and otherwise the type of its expression.
getColumnType :: Types -> SelectItem -> Either String ColumnType
getColumnType t (SelExp _ (Identifier _ name)) = lookupType name t
getColumnType t (SelExp _ (QIdentifier _ name)) = lookupType name t
getColumnType t (SelExp _ (FunCall _ fn args)) = getFuncType t (lowerName fn) args
getColumnType _ (SelExp _ expr) = Left $ "Cannot guess type of " ++ printScalarExpr expr
getColumnType t (SelectItem _ expr name) =
    case lookup (name2string name) t of
      Just typ -> return typ
      Nothing -> getType t expr

-- | Column name of a select-list item: the alias, identifier, function
-- name, or a star form.
getColumnName :: SelectItem -> Either String ColumnName
getColumnName (SelExp _ (Identifier _ name)) = return (name2string name)
getColumnName (SelExp _ (QIdentifier _ name)) = return (name2string name)
getColumnName (SelExp _ (FunCall _ fn _)) = return (lowerName fn)
getColumnName (SelExp _ (QStar _ name)) = return (name2string name ++ ".*")
getColumnName (SelExp _ (Star _)) = return "*"
getColumnName (SelExp _ expr) = Left $ "Do not want invent the name for " ++ printScalarExpr expr
getColumnName (SelectItem _ _ name) = return (name2string name)

-- | Column definition for a select-list item.
convertColumn :: Types -> SelectItem -> Either String ColumnDef
convertColumn t s = ColumnDef <$> getColumnName s <*> getColumnType t s <*> return []

-- | Guess the type of a scalar expression.
getType :: Types -> ScalarExpr -> Either String ColumnType
getType _ (BooleanLit _ _) = return "Bool"
getType _ (Exists _ _) = return "Bool"
getType t (FunCall _ fn args) = getFuncType t (lowerName fn) args
getType _ (InPredicate _ _ _ _) = return "Bool"
getType _ (NumberLit _ n)
  -- A literal such as 1.5 or 1e3 is not an integer.
  | any (`elem` ".eE") n = return "Double"
  | otherwise = return "Int"
getType t (Identifier _ name) = lookupType name t
getType t (QIdentifier _ name) = lookupType name t
getType _ (StringLit _ _) = return "String"
getType _ x = Left $ "Unsupported scalar expression type: " ++ show x

-- | Result type of a function call: aggregates max/min/sum/avg take the
-- type of their argument, count is Int, anything else needs a declaration
-- under the (lower-cased) function name.
getFuncType :: Types -> String -> [ScalarExpr] -> Either String ColumnType
getFuncType t "max" [e] = getType t e
getFuncType t "min" [e] = getType t e
getFuncType t "sum" [e] = getType t e
getFuncType t "avg" [e] = getType t e
getFuncType _ "count" _ = return "Int"
getFuncType t fn args =
    case lookup fn t of
      Just typ -> return typ
      Nothing -> Left $ "Cannot invent type for " ++ fn ++ show args