-----------------------------------------------------------------------------
-- |
-- Module      :  Database.TxtSushi.SQLParser
-- Copyright   :  (c) Keith Sheppard 2009
-- License     :  GPL3 or greater
-- Maintainer  :  keithshep@gmail.com
-- Stability   :  experimental
-- Portability :  portable
--
-- SQL Expressions
--
-----------------------------------------------------------------------------

module Database.TxtSushi.SQLExpression (
    allMaybeTableNames,
    SelectStatement(..),
    TableExpression(..),
    ColumnIdentifier(..),
    ColumnSelection(..),
    Expression(..),
    SQLFunction(..),
    OrderByItem(..),
    ColumnRange(..),
    isAggregate,
    selectStatementContainsAggregates,
    expressionToString,
    columnToString) where

import Data.Char
import Data.List

import Database.TxtSushi.EvaluatedExpression

--------------------------------------------------------------------------------
-- The data definition for select statements
--------------------------------------------------------------------------------

-- | represents a select statement
--   TODO this should be moved inside the TableExpression type
data SelectStatement = SelectStatement {
    columnSelections :: [ColumnSelection],
    maybeFromTable :: Maybe TableExpression,
    maybeWhereFilter :: Maybe Expression,
    maybeGroupByHaving :: Maybe ([Expression], Maybe Expression),
    orderByItems :: [OrderByItem]}

data TableExpression =
    TableIdentifier {
        tableName :: String,
        maybeTableAlias :: Maybe String} |
    InnerJoin {
        leftJoinTable :: TableExpression,
        rightJoinTable :: TableExpression,
        onCondition :: Expression,
        maybeTableAlias :: Maybe String} |
    CrossJoin {
        leftJoinTable :: TableExpression,
        rightJoinTable :: TableExpression,
        maybeTableAlias :: Maybe String} |
    SelectExpression {
        selectStatement :: SelectStatement,
        maybeTableAlias :: Maybe String}

-- | convenience function for extracting all of the table names used by the
--   given table expression
allMaybeTableNames :: (Maybe TableExpression) -> [String]
allMaybeTableNames Nothing = []
allMaybeTableNames (Just tblExp) = allTableNames tblExp

allTableNames :: TableExpression -> [String]
allTableNames (TableIdentifier tblName _) = [tblName]
allTableNames (InnerJoin lftTbl rtTbl _ _) =
    (allTableNames lftTbl) ++ (allTableNames rtTbl)
allTableNames (CrossJoin lftTbl rtTbl _) =
    (allTableNames lftTbl) ++ (allTableNames rtTbl)
allTableNames (SelectExpression selectStmt _) =
    allMaybeTableNames $ maybeFromTable selectStmt

data ColumnSelection =
    AllColumns |
    AllColumnsFrom {sourceTableName :: String} |
    ExpressionColumn {
        expression :: Expression,
        maybeColumnAlias :: Maybe String} |
    ExpressionColumnRange {
        binding :: ColumnIdentifier,
        range :: ColumnRange,
        expression :: Expression}

data ColumnRange = ColumnRange {
    maybeStart :: Maybe ColumnIdentifier,
    maybeEnd :: Maybe ColumnIdentifier}

data ColumnIdentifier =
    ColumnIdentifier {
        maybeTableName :: Maybe String,
        columnId :: String} deriving Eq

data Expression =
    FunctionExpression {
        sqlFunction :: SQLFunction,
        functionArguments :: [Expression],
        stringRepresentation :: String} |
    ColumnExpression {
        column :: ColumnIdentifier,
        stringRepresentation :: String} |
    StringConstantExpression {
        stringConstant :: String,
        stringRepresentation :: String} |
    IntConstantExpression {
        intConstant :: Int,
        stringRepresentation :: String} |
    RealConstantExpression {
        realConstant :: Double,
        stringRepresentation :: String} |
    BoolConstantExpression {
        boolConstant :: Bool,
        stringRepresentation :: String}

data SQLFunction = SQLFunction {
    functionName :: String,
    minArgCount :: Int,
    argCountIsFixed :: Bool,
    functionGrammar :: String,
    functionDescription :: String,
    applyFunction :: [EvaluatedExpression] -> EvaluatedExpression}

-- | an aggregate function is one whose min function count is 1 and whose
--   arg count is not fixed
isAggregate :: SQLFunction -> Bool
isAggregate = not . argCountIsFixed

containsAggregates :: Expression -> Bool
containsAggregates (FunctionExpression sqlFun args _) =
    isAggregate sqlFun || any containsAggregates args
containsAggregates _ = False

selectionContainsAggregates :: ColumnSelection -> Bool
selectionContainsAggregates (ExpressionColumn expr _) =
    containsAggregates expr
selectionContainsAggregates _ = False

orderByItemContainsAggregates :: OrderByItem -> Bool
orderByItemContainsAggregates (OrderByItem expr _) =
    containsAggregates expr

selectStatementContainsAggregates :: SelectStatement -> Bool
selectStatementContainsAggregates select =
    any selectionContainsAggregates (columnSelections select) ||
    any orderByItemContainsAggregates (orderByItems select)

expressionToString :: Expression -> String
expressionToString (FunctionExpression _ _ strRep) = strRep
expressionToString (ColumnExpression _ strRep) = strRep
expressionToString (StringConstantExpression _ strRep) = strRep
expressionToString (IntConstantExpression _ strRep) = strRep
expressionToString (RealConstantExpression _ strRep) = strRep
expressionToString (BoolConstantExpression _ strRep) = strRep

columnToString :: ColumnIdentifier -> String
columnToString (ColumnIdentifier (Just tblName) colId) = tblName ++ "." ++ colId
columnToString (ColumnIdentifier (Nothing) colId) = colId

data OrderByItem = OrderByItem {
    orderExpression :: Expression,
    orderAscending :: Bool}