module Database.Persist.HsSqlPpp
(Query,
persistSql,
persistSqlFile,
parseEntityFromFile,
parseEntity,
selectFromQuery,
selectFromQuery',
checkSQL
) where
import Control.Applicative
import Control.Monad.Reader
import Control.Monad.Trans
import Control.Monad.Error
import Control.Monad.Instances
import Control.Monad.IO.Control
import Data.Char
import Data.List (intercalate)
import qualified Data.ByteString as B
import qualified Data.Text as Text
import Language.Haskell.TH
import qualified Language.Haskell.TH.Lift as THL
import qualified Language.Haskell.TH.Syntax as TH
import Language.Haskell.TH.Quote
import Database.Persist
import Database.Persist.Base
import Database.Persist.GenericSql.Raw
import qualified Database.Persist.GenericSql.Internal as I
import Database.Persist.TH
import Database.HsSqlPpp.Types as Types
import Database.HsSqlPpp.Ast as AST
import Database.HsSqlPpp.Annotation
import Database.HsSqlPpp.Catalog
import Database.HsSqlPpp.Parser
import Database.HsSqlPpp.Pretty
-- | Declared result types, parsed from the section below the delimiter line
-- of a query file: pairs of (column alias / column / function name,
-- Haskell type text).
type Types = [(String, String)]
-- | A parsed SELECT statement together with the entity definition derived
-- from its result columns.
type Query = (QueryExpr, EntityDef)
-- | Things that can be rendered as a plain SQL identifier string.
class IsName x where
  -- | The identifier text. For multi-component (qualified) names this is
  -- the last component only.
  name2string :: x -> String

instance IsName String where
  name2string = id

instance IsName AST.Name where
  name2string (AST.Name _ list) = name2string list

instance IsName NameComponent where
  name2string = ncStr

-- | Uses the last component of a possibly qualified name (e.g. the column
-- in @table.column@). An empty component list yields @""@ instead of
-- crashing in the partial 'last'.
instance IsName [NameComponent] where
  name2string [] = ""
  name2string list = ncStr (last list)
-- | Reinterpret every byte of a 'B.ByteString' as one 'Char' with the same
-- code point (0-255). No character decoding is performed.
toString :: B.ByteString -> String
toString = B.foldr (\byte acc -> chr (fromIntegral byte) : acc) []
-- Template Haskell 'Lift' instances for the HsSqlPpp AST (and the types it
-- embeds), generated with th-lift. They are what allow a parsed 'Query'
-- (which contains a 'QueryExpr') to be turned into a compile-time
-- expression via 'TH.lift' in 'persistSql' and 'persistSqlFile'.
-- These are orphan instances: neither the lifted types nor the 'Lift'
-- class are defined in this module.
THL.deriveLift ''SelectItem
THL.deriveLift ''SelectList
THL.deriveLift ''Distinct
THL.deriveLift ''FrameClause
THL.deriveLift ''TypeName
THL.deriveLift ''LiftFlavour
THL.deriveLift ''IntervalField
THL.deriveLift ''InList
THL.deriveLift ''ExtractField
THL.deriveLift ''ScalarExpr
THL.deriveLift ''Direction
THL.deriveLift ''TableAlias
THL.deriveLift ''NameComponent
THL.deriveLift ''AST.Name
THL.deriveLift ''JoinExpr
THL.deriveLift ''JoinType
THL.deriveLift ''Natural
THL.deriveLift ''TableRef
THL.deriveLift ''CombineType
THL.deriveLift ''WithQuery
THL.deriveLift ''TypeError
THL.deriveLift ''FunFlav
THL.deriveLift ''CastContext
THL.deriveLift ''PseudoType
THL.deriveLift ''Types.Type
THL.deriveLift ''CatalogUpdate
THL.deriveLift ''Annotation
THL.deriveLift ''QueryExpr
-- | Run a validation over every element, left to right.
--
-- Stops at, and returns, the first 'Left'. If every element validates, the
-- individual results are discarded and @Right b@ is returned.
allE :: (a -> Either e b) -> b -> [a] -> Either e b
allE fn b = foldr step (Right b)
  where
    step x rest = either Left (const rest) (fn x)
-- | Combine two validation results into a continuation.
--
-- When both sides succeeded, the result is 'Right' (so @(l <&> r) x@ is
-- @Right x@). Otherwise it ignores its argument and returns the error(s);
-- two errors are joined with @"; "@, left one first.
(<&>) :: Either String a -> Either String a -> (a -> Either String a)
l <&> r = case (l, r) of
  (Right _, Right _) -> Right
  (Left e, Left e') -> const (Left (e ++ "; " ++ e'))
  (Left e, _) -> const (Left e)
  (_, Left e) -> const (Left e)
-- | An 'Annotation' whose every field is empty ('Nothing' or @[]@). Used for
-- the AST nodes this module synthesizes itself (names, literals and the
-- conditions added to WHERE clauses).
emptyAnn :: Annotation
emptyAnn = Annotation Nothing Nothing [] Nothing [] Nothing Nothing
-- | Add every Persistent filter to the query's WHERE clause, one after
-- another. Fails with the first error reported by 'addWhere'.
makeQuery :: (PersistEntity t) => I.Connection -> QueryExpr -> [Filter t] -> Either String QueryExpr
makeQuery conn = foldM (addWhere conn)
-- | Add one filter to a plain SELECT's WHERE clause. It becomes the whole
-- condition if there was none, otherwise it is AND-ed with the existing one.
-- Any other kind of query expression is rejected.
addWhere :: (PersistEntity t) => I.Connection -> QueryExpr -> Filter t -> Either String QueryExpr
addWhere conn expr filt = case expr of
  Select ann dist fields trefs whereClause a b c d e ->
    let newWhere = Just (convertCond conn filt whereClause)
    in Right (Select ann dist fields trefs newWhere a b c d e)
  other -> Left ("Unsupported query expr in addWhere: " ++ show other)
-- | Build a single-component, unannotated AST name from a string.
name :: String -> AST.Name
name = Name emptyAnn . (: []) . Nmc
-- | Produce the WHERE condition for a filter. If a previous condition exists,
-- the result is a @"!and"@ call of the new condition and the old one.
convertCond :: (PersistEntity t) => I.Connection -> Filter t -> Maybe ScalarExpr -> ScalarExpr
convertCond conn f previous = maybe fresh conjoin previous
  where
    fresh = convertFilter conn f
    conjoin old = FunCall emptyAnn (name "!and") [fresh, old]
-- | Translate one Persistent 'Filter' into an HsSqlPpp scalar expression for
-- a WHERE clause.
--
-- The column is an identifier whose text was already escaped with the
-- connection's 'I.escapeName'. Comparison filters become a call of the
-- operator named by 'showSqlFilter' on the column and the filter value(s).
-- 'In' / 'NotIn' become real @IN (...)@ / @NOT IN (...)@ predicates; an
-- empty value list turns into constant FALSE / TRUE, since @col IN ()@ is
-- not valid SQL.
convertFilter :: (PersistEntity t) => I.Connection -> Filter t -> ScalarExpr
convertFilter conn (Filter field value op) =
    case op of
      In -> inPredicate True
      NotIn -> inPredicate False
      _ -> FunCall emptyAnn (name $ showSqlFilter op) (column : allVals)
  where
    column = Identifier emptyAnn (Nmc fname)
    -- assumes HsSqlPpp's InPredicate Bool is True for IN and False for
    -- NOT IN (the pretty printer adds "not" when it is False) — TODO confirm
    inPredicate isIn
      | null allVals = BooleanLit emptyAnn (not isIn)
      | otherwise = InPredicate emptyAnn column isIn (InList emptyAnn allVals)
    allVals = map persist2expr $ filterValueToPersistValues value
    persist2expr (PersistText t) = StringLit emptyAnn (Text.unpack t)
    persist2expr (PersistByteString bstr) = StringLit emptyAnn (toString bstr)
    persist2expr (PersistInt64 i) = NumberLit emptyAnn (show i)
    persist2expr (PersistDouble x) = NumberLit emptyAnn (show x)
    persist2expr (PersistBool b) = BooleanLit emptyAnn b
    persist2expr (PersistDay d) = StringLit emptyAnn (show d)
    persist2expr (PersistTimeOfDay tod) = StringLit emptyAnn (show tod)
    persist2expr (PersistUTCTime t) = StringLit emptyAnn (show t)
    persist2expr (PersistNull) = NullLit emptyAnn
    persist2expr x = error $ "persist2expr: unsupported Persistent value in filter: " ++ show x
    -- A single value (Left) or a list of values (Right), all serialized.
    filterValueToPersistValues :: forall a. PersistField a => Either a [a] -> [PersistValue]
    filterValueToPersistValues v = map toPersistValue $ either return id v
    fname = I.escapeName conn $ getFieldName t $ columnName $ persistColumnDef field
    -- Only the entity's type matters here; the dummy value is never forced.
    t = entityDef $ dummyFromFilts [Filter field value op]
    getFieldName :: EntityDef -> String -> I.RawName
    getFieldName t s = I.rawFieldName $ tableColumn t s
    showSqlFilter Eq = "="
    showSqlFilter Ne = "<>"
    showSqlFilter Gt = ">"
    showSqlFilter Lt = "<"
    showSqlFilter Ge = ">="
    showSqlFilter Le = "<="
    showSqlFilter In = " IN "
    showSqlFilter NotIn = " NOT IN "
    showSqlFilter (BackendSpecificFilter s) = s
-- | A type-level witness: yields a value of the filters' entity type solely so
-- class methods such as 'entityDef' can be selected. Forcing it is an error.
dummyFromFilts :: [Filter v] -> v
dummyFromFilts = const (error "dummyFromFilts")
-- | Look up the column called @s@ in entity @t@.
--
-- The entity's id column name (from 'I.rawTableIdName') is special-cased
-- and described as an "Int64" column with no attributes. Any other name
-- must appear in 'entityColumns'; otherwise this calls 'error'.
tableColumn :: EntityDef -> String -> ColumnDef
tableColumn t s
  | s == idCol = ColumnDef idCol "Int64" []
  | otherwise =
      case [c | c@(ColumnDef cname _ _) <- entityColumns t, cname == s] of
        (c : _) -> c
        [] -> error $ "Unknown table column: " ++ s
  where
    idCol = I.unRawName (I.rawTableIdName t)
-- | Derive an 'EntityDef' describing the result rows of a plain SELECT.
--
-- Each select-list item becomes one column, typed via the declared @types@.
-- The entity is named "Undefined", has no attributes or uniques, and derives
-- Eq and Show. Non-SELECT expressions are rejected.
convertSelect :: Types -> QueryExpr -> Either String EntityDef
convertSelect types expr = case expr of
  Select _ _ (SelectList _ items) _ _ _ _ _ _ _ -> do
    columns <- mapM (convertColumn types) items
    return (EntityDef "Undefined" [] columns [] ["Eq", "Show"])
  _ -> Left $ "Unsupported SQL query syntax in " ++ printQueryExpr expr
-- | Names of the result columns of a plain SELECT, in select-list order.
-- Fails on the first item that has no derivable name, or on non-SELECTs.
getColumnNames :: QueryExpr -> Either String [ColumnName]
getColumnNames expr = case expr of
  Select _ _ (SelectList _ items) _ _ _ _ _ _ _ -> mapM getColumnName items
  _ -> Left $ "Unsupported SQL query syntax in " ++ printQueryExpr expr
-- | Whitelist check for a user-supplied SELECT.
--
-- Accepts the query only if:
--
-- * every table reference in the FROM list (recursing into joins and
--   sub-queries) names one of the allowed @tables@, compared after
--   lower-casing the reference (the allowed names themselves are not
--   lower-cased);
-- * every select-list expression calls only count/max/min/avg/sum and uses
--   no CASE, CAST, EXTRACT or window functions (sub-queries are checked
--   recursively).
--
-- Returns the query unchanged on success. Otherwise it returns a 'Left' with
-- the first problem in the FROM list and/or the first problem in the select
-- list; when both exist they are joined with "; ".
--
-- NOTE(review): only the select list and the FROM list are inspected. The
-- remaining 'Select' fields (matched by the wildcards below; presumably
-- WHERE / GROUP BY / HAVING / ORDER BY / LIMIT / OFFSET) and JOIN conditions
-- are never checked, so sub-queries or function calls placed there bypass
-- this whitelist. 'selectFromQuery'' runs the original query text after this
-- check — confirm this is acceptable.
-- NOTE(review): operators appear to be represented as 'FunCall' (compare
-- 'convertFilter'), so any operator in the select list (e.g. @a + b@) is
-- presumably rejected as a disallowed function call.
checkSelect :: [String] -> QueryExpr -> Either String QueryExpr
checkSelect tables s@(Select _ _ (SelectList _ items) trefs _ _ _ _ _ _) =
    (allE goodTref s trefs <&> allE goodCol s items) s
  where
    -- Table references: plain tables must be whitelisted; table functions
    -- are forbidden; joins and sub-queries are checked recursively.
    goodTref (FunTref _ _ a) = Left $ "Function call as table reference is not allowed: " ++ show a
    goodTref (JoinTref _ t1 _ _ t2 _ _) = (goodTref t1 <&> goodTref t2) s
    goodTref (SubTref _ q _) = checkSelect tables q
    goodTref (Tref _ n@(Name _ list) _) =
      let tname = lowerName n
      in if tname `elem` tables
           then Right s
           else Left $ "Table reference is not allowed: " ++ tname
    -- Select-list items, with or without an alias.
    goodCol (SelExp _ expr) = goodExpr expr
    goodCol (SelectItem _ expr _) = goodExpr expr
    -- Only the five aggregates may be called; their arguments are checked too.
    goodExpr (FunCall _ fn args) =
      let name = map toLower (name2string fn)
      in if name `elem` ["count", "max", "min", "avg", "sum"]
           then allE goodExpr s args
           else Left $ "Function call not allowed: " ++ name2string fn
    goodExpr (ScalarSubQuery _ q) = checkSelect tables q
    goodExpr (Case {}) = Left "CASE expressions are not allowed"
    goodExpr (CaseSimple {}) = Left "CASE expressions are not allowed"
    goodExpr (Cast {}) = Left "CAST expressions are not allowed"
    goodExpr (Exists _ q) = checkSelect tables q
    goodExpr (Extract {}) = Left "EXTRACT expressions are not allowed"
    goodExpr (InPredicate _ e _ (InList _ es)) = (goodExpr e <&> allE goodExpr s es) s
    goodExpr (InPredicate _ e _ (InQueryExpr _ q)) = (goodExpr e <&> checkSelect tables q) s
    goodExpr (LiftOperator _ _ _ es) = allE goodExpr s es
    goodExpr (QIdentifier _ _) = Right s
    goodExpr (WindowFn {}) = Left "WINDOW functions are not allowed"
    -- Everything else (identifiers, literals, ...) is accepted as-is.
    goodExpr _ = Right s
-- Set operations and every other non-plain-SELECT expression are refused.
checkSelect _ q = Left $ "Non-SELECT query forbidden: " ++ printQueryExpr q
-- | The database connection carried by the 'SqlPersist' reader environment.
getConnection :: (MonadControlIO m) => SqlPersist m I.Connection
getConnection = SqlPersist (asks id)
-- | Run a pre-parsed SELECT, narrowed by Persistent filters, and decode every
-- returned row into an entity.
--
-- The filters are AND-ed into the query's WHERE clause (see 'makeQuery').
-- A failure to build the query, or to decode a row with
-- 'fromPersistValues', is reported as 'Left'. Query-construction errors
-- used to be raised with 'fail' in the monad despite the 'Either' result
-- type. Decoding stops at the first bad row.
selectFromQuery :: forall m a. (MonadControlIO m, PersistEntity a)
                => QueryExpr
                -> [Filter a]
                -> SqlPersist m (Either String [a])
selectFromQuery query filts = do
    conn <- getConnection
    case makeQuery conn query filts of
      Left err -> return (Left err)
      Right expr -> withStmt (Text.pack $ printQueryExpr expr) [] worker
  where
    worker :: I.RowPopper (SqlPersist m) -> SqlPersist m (Either String [a])
    worker popper = runErrorT (go popper)
    -- Pop rows until exhausted, aborting the ErrorT on the first decode error.
    go popper = do
      row <- lift popper
      case row of
        Nothing -> return []
        Just list -> case fromPersistValues list of
          Left err -> throwError err
          Right rec -> do
            next <- go popper
            return (rec : next)
-- | Parse a SELECT statement and validate it against the table whitelist
-- with 'checkSelect'. Parse errors and check failures both come back as
-- 'Left'.
checkSQL :: [String]
         -> String
         -> Either String QueryExpr
checkSQL tables query =
  parseSelect "<no file>" query >>= checkSelect tables . fst
-- | Parse, whitelist-check and run an ad-hoc SELECT given as text.
--
-- The result holds the column names and the raw values of every row. Parse
-- errors are returned as they are; whitelist failures are prefixed with
-- "Forbidden query: ". The original query text is what gets executed.
selectFromQuery' :: (MonadControlIO m)
                 => [String]
                 -> String
                 -> SqlPersist m (Either String ([ColumnName], [[PersistValue]]))
selectFromQuery' tables query =
    either (return . Left) runChecked (parseSelect "<inline>" query)
  where
    runChecked (expr, names) =
      case checkSelect tables expr of
        Left err -> return $ Left $ "Forbidden query: " ++ err
        Right _ -> do
          rows <- withStmt (Text.pack query) [] collectRows
          return $ Right (names, rows)
    -- Drain the row popper into a list.
    collectRows popper = do
      next <- popper
      maybe (return []) (\row -> collectRows popper >>= return . (row :)) next
-- | Expression quasi-quoter that parses a SELECT (optionally followed by a
-- delimiter line and a column-type section, see 'parseEntity') at compile
-- time and splices in the resulting 'Query'. Parse errors become compile
-- errors.
--
-- Only expression context is supported. Pattern, type and declaration
-- contexts now fail with a descriptive compile error instead of evaluating
-- 'undefined'.
persistSql :: QuasiQuoter
persistSql = QuasiQuoter
    { quoteExp = \str -> case parseEntity "<inline>" str of
        Right expr -> TH.lift expr
        Left err -> fail err
    , quotePat = unsupported "pattern"
    , quoteType = unsupported "type"
    , quoteDec = unsupported "declaration"
    }
  where
    unsupported :: String -> String -> Q b
    unsupported ctx _ =
      fail $ "persistSql: quasi-quoter cannot be used in " ++ ctx ++ " context"
-- | Read and parse a query file at compile time, then splice the result in
-- as a single expression. The spliced value is the whole
-- @Either String Query@, so parse errors surface as a 'Left' in the
-- generated code rather than at compile time.
persistSqlFile :: FilePath -> Q [Exp]
persistSqlFile path =
  fmap (: []) (runIO (parseEntityFromFile path) >>= TH.lift)
-- | Read a query file (with 'readFile') and parse it with 'parseEntity',
-- using the file path as the source name in error messages.
parseEntityFromFile :: FilePath -> IO (Either String Query)
parseEntityFromFile path = fmap (parseEntity path) (readFile path)
-- | Parse query-file text into a 'Query'.
--
-- The text is a SELECT statement, optionally followed by a delimiter line
-- (made only of '-', '=', spaces and tabs, with at least one '-' or '=').
-- The lines after the delimiter declare column / function types (see
-- 'parseTypes'). Blank lines in the query part are dropped before parsing.
-- Without a delimiter the type section is simply empty; previously a
-- missing delimiter crashed on an irrefutable pattern.
parseEntity :: FilePath
            -> String
            -> Either String Query
parseEntity path text =
    parseSelectT path (parseTypes typeLines) query
  where
    isDelimiter str =
      not (null str) && all (`elem` "-= \t") str && any (`elem` "-=") str
    (queryLines, rest) = break isDelimiter (lines text)
    -- Skip the delimiter line itself, if there is one.
    typeLines = drop 1 rest
    isBlank = all (`elem` " \t")
    query = unlines $ filter (not . isBlank) queryLines
-- | Parse a query expression and pair it with the entity derived from its
-- select list using the declared @types@.
parseSelectT :: String -> Types -> String -> Either String Query
parseSelectT path types query =
    either parseFailure withEntity (parseQueryExpr path query)
  where
    parseFailure err = Left $ "Cannot parse query expression: " ++ show err
    withEntity expr = (,) expr <$> convertSelect types expr
-- | Parse a query expression and list its result column names. The parser
-- error is rendered with 'show'.
parseSelect :: String -> String -> Either String (QueryExpr, [ColumnName])
parseSelect path query = do
  expr <- either (Left . show) Right (parseQueryExpr path query)
  columnNames <- getColumnNames expr
  return (expr, columnNames)
-- | Parse type-declaration lines of the form @name type words...@.
-- The first word is the name and the remaining words, re-joined with
-- single spaces, are the type. Blank lines are ignored.
parseTypes :: [String] -> Types
parseTypes ls = [(n, unwords ws) | (n : ws) <- map words ls]
-- | The identifier text of a name, lower-cased.
lowerName :: IsName name => name -> String
lowerName = map toLower . name2string
-- | Find the declared type for a name (matched exactly, case-sensitively).
lookupType :: (IsName name) => name -> Types -> Either String ColumnType
lookupType n ts =
    maybe (Left $ "Do not know type of " ++ key) Right (lookup key ts)
  where
    key = name2string n
-- | Haskell type of one select-list item.
--
-- Bare column references use their declared type and function calls go
-- through 'getFuncType'. Other unaliased expressions are rejected. An
-- aliased item uses the alias's declared type if there is one, and otherwise
-- infers a type from its expression with 'getType'.
getColumnType :: Types -> SelectItem -> Either String ColumnType
getColumnType t item = case item of
  SelExp _ (Identifier _ n) -> lookupType n t
  SelExp _ (QIdentifier _ n) -> lookupType n t
  SelExp _ (FunCall _ fn args) -> getFuncType t (lowerName fn) args
  SelExp _ expr -> Left $ "Cannot guess type of " ++ printScalarExpr expr
  SelectItem _ expr alias ->
    maybe (getType t expr) Right (lookup (name2string alias) t)
-- | Result-column name of one select-list item.
--
-- Aliased items use their alias. Column references use the column name,
-- function calls the lower-cased function name, and stars "*" or
-- "table.*". Any other unaliased expression is rejected.
getColumnName :: SelectItem -> Either String ColumnName
getColumnName item = case item of
  SelExp _ (Identifier _ n) -> return (name2string n)
  SelExp _ (QIdentifier _ n) -> return (name2string n)
  SelExp _ (FunCall _ fn _) -> return (lowerName fn)
  SelExp _ (QStar _ n) -> return (name2string n ++ ".*")
  SelExp _ (Star _) -> return "*"
  SelExp _ expr -> Left $ "Do not want invent the name for " ++ printScalarExpr expr
  SelectItem _ _ alias -> return (name2string alias)
-- | Describe a select-list item as a 'ColumnDef' (name, type, no
-- attributes). The name is resolved before the type, so a naming error
-- takes precedence.
convertColumn :: Types -> SelectItem -> Either String ColumnDef
convertColumn t s = do
  colName <- getColumnName s
  colType <- getColumnType t s
  return (ColumnDef colName colType [])
-- | Infer the Haskell type of a scalar expression.
--
-- Boolean literals and EXISTS / IN predicates are "Bool", and string
-- literals are "String". Numeric literals are "Double" when written with a
-- decimal point or an exponent (e.g. @1.5@, @2e3@) and "Int" otherwise;
-- previously every numeric literal was typed "Int". Identifiers use the
-- declared types and function calls go through 'getFuncType'. Anything else
-- is rejected.
getType :: Types -> ScalarExpr -> Either String ColumnType
getType _ (BooleanLit _ _) = return "Bool"
getType _ (Exists _ _) = return "Bool"
getType t (FunCall _ fn args) = getFuncType t (lowerName fn) args
getType _ (InPredicate _ _ _ _) = return "Bool"
getType _ (NumberLit _ digits)
  | any (`elem` ".eE") digits = return "Double"
  | otherwise = return "Int"
getType t (Identifier _ name) = lookupType name t
getType t (QIdentifier _ name) = lookupType name t
getType _ (StringLit _ _) = return "String"
getType _ x = Left $ "Unsupported scalar expression type: " ++ show x
-- | Result type of calling function @fn@ (expected already lower-cased) on
-- @args@.
--
-- Single-argument max/min/sum keep their argument's type. Single-argument
-- avg is "Double", because SQL AVG yields a fractional value (numeric/real)
-- even over integer columns; previously it inherited e.g. "Int". count is
-- always "Int". Any other call must have its result type declared under the
-- function's lower-cased name.
getFuncType :: Types -> String -> [ScalarExpr] -> Either String ColumnType
getFuncType t "max" [e] = getType t e
getFuncType t "min" [e] = getType t e
getFuncType t "sum" [e] = getType t e
getFuncType _ "avg" [_] = return "Double"
getFuncType _ "count" _ = return "Int"
getFuncType t fn args =
  case lookup fn t of
    Just typ -> return typ
    Nothing -> Left $ "Cannot invent type for " ++ fn ++ show args