-- | Analysis and transformation of SQL queries.
module Database.Selda.Transform where
import Database.Selda.Column
import Database.Selda.SQL
import Database.Selda.Query.Type
import Database.Selda.Types

-- | Recursively prune dead columns from a query and all of its subqueries.
--   The given list is assumed to contain every column name that is present
--   in the final result.
--
--   The query's own outputs are trimmed first (see 'keepCols'); each subquery
--   of a product or join is then pruned against all names that the trimmed
--   query still mentions.
removeDeadCols :: [ColName] -> SQL -> SQL
removeDeadCols live sql =
    case source pruned of
      EmptyTable         -> pruned
      TableName _        -> pruned
      Values _ _         -> pruned
      Product subqueries -> pruned {source = Product (map recurse subqueries)}
      Join jt on l r     -> pruned {source = Join jt on (recurse l) (recurse r)}
  where
    -- Names used by this query's own clauses stay alive in addition to the
    -- names requested from the outside.
    pruned = keepCols (allNonOutputColNames sql ++ live) sql
    recurse = removeDeadCols (allColNames pruned)

-- | All column names mentioned by the given top-level query: those of its
--   output columns followed by those of its other clauses.
--   Subqueries are not traversed.
allColNames :: SQL -> [ColName]
allColNames sql = outputNames ++ allNonOutputColNames sql
  where
    outputNames = colNames (cols sql)

-- | All column names mentioned by the given top-level query outside of its
--   output list ('cols'): restrictions, grouping, ordering and, when the
--   source is a join, the join condition. Subqueries are not traversed.
allNonOutputColNames :: SQL -> [ColName]
allNonOutputColNames sql =
    restrictNames ++ groupNames ++ orderNames ++ joinNames
  where
    restrictNames = concatMap allNamesIn (restricts sql)
    groupNames = colNames (groups sql)
    orderNames = colNames (map snd (ordering sql))
    joinNames =
      case source sql of
        Join _ on _ _ -> allNamesIn on
        _             -> []

-- | Collect every column name that occurs in the given list of (possibly
--   complex) columns: first the names inside unnamed expressions, then the
--   names inside named expressions, and finally the names of the named
--   columns themselves.
colNames :: [SomeCol SQL] -> [ColName]
colNames cs = unnamedRefs ++ namedRefs ++ ownNames
  where
    unnamedRefs = concat [allNamesIn c | Some c <- cs]
    namedRefs = concat [allNamesIn c | Named _ c <- cs]
    ownNames = [n | Named n _ <- cs]

-- | Drop every output column of a query except aggregates and the columns
--   whose names are in the given list.
--   A column that an outer query wants to refer to must have a name.
--   If it has none, then either no outer query refers to it, or the outer
--   query duplicates the expression and thereby refers directly to the names
--   of its components.
keepCols :: [ColName] -> SQL -> SQL
keepCols live sql = sql {cols = [c | c <- cols sql, isLive c]}
  where
    -- Aggregates always survive; plain column references and named columns
    -- survive only when their name is live; anything else is dropped.
    isLive (Some (AggrEx _ _))    = True
    isLive (Named _ (AggrEx _ _)) = True
    isLive (Some (Col n))         = n `elem` live
    isLive (Named n _)            = n `elem` live
    isLive _                      = False

-- | Turn the SQL generation state into the outermost query.
--   A state holding exactly one query reuses that query, with the state's
--   pending restrictions appended to its own; any other number of queries is
--   wrapped in a product over all of them.
--   Groups are ignored, as they are only used by 'aggregate'.
state2sql :: GenState -> SQL
state2sql (GenState sqls extraRestricts _ _ _) =
  case sqls of
    [single] -> single {restricts = restricts single ++ extraRestricts}
    _        ->
      SQL (allCols sqls) (Product sqls) extraRestricts [] [] Nothing False

-- | Gather the output columns of a list of SQL ASTs, in order.
--   A named column is replaced by a plain reference to its name; every other
--   column is passed through unchanged.
allCols :: [SQL] -> [SomeCol SQL]
allCols = concatMap (map toOutput . cols)
  where
    toOutput (Named n _) = Some (Col n)
    toOutput other       = other