{-# OPTIONS_GHC -Wno-orphans -Wno-deprecations #-}
{-# LANGUAGE ExistentialQuantification, RankNTypes, FlexibleContexts #-}
module Internal.Data.Basic.Sql.Types where

import Internal.Interlude hiding (Sum)

import Data.String (IsString(..))
import Database.PostgreSQL.Simple hiding (In, Only)
import Database.PostgreSQL.Simple.Types (Query(..))
import Database.PostgreSQL.Simple.ToField (Action(Escape))

-- | A fragment of SQL text together with the parameter values for the @?@
-- placeholders it contains, kept in the same left-to-right order as the
-- placeholders appear in the text.
data QuerySegment = QuerySegment Query [Action]
    deriving (Show)

-- | Orphan instance: '<>' on 'Query' delegates to its 'Monoid' 'mappend'.
--
-- NOTE(review): postgresql-simple releases that ship their own
-- @Semigroup Query@ instance would make this a duplicate, and if 'Query''s
-- 'mappend' were itself defined via '<>' the two would loop -- confirm
-- against the pinned postgresql-simple / base versions.
instance Semigroup Query where
    (<>) = mappend

-- | Segments combine by concatenating their SQL text and their parameter
-- lists, so each @?@ placeholder stays aligned with its 'Action'.
-- The empty segment has no text and no parameters.
instance Monoid QuerySegment where
    mempty = QuerySegment mempty mempty
    QuerySegment q1 as1 `mappend` QuerySegment q2 as2
        = QuerySegment (q1 <> q2) (as1 <> as2)
-- NOTE(review): this instance declares no methods, so '<>' comes from the
-- Semigroup class default. Older base versions supplied a default
-- @(<>) = mappend@; on base versions without that default this instance is
-- incomplete -- confirm against the GHC/base version the project builds with.
instance Semigroup QuerySegment

-- | String literals become segments of raw SQL text that carry no parameters.
instance IsString QuerySegment where
    fromString str = QuerySegment (fromString str) mempty

-- | Binary comparison operators usable in a 'SqlOperator' condition.
data Comparison = LessThan | LessOrEqual | GreaterThan | GreaterOrEqual | Equal | NotEqual
                  deriving (Eq, Ord, Read, Show)

-- | Direction of an order-by term, rendered as @asc@ or @desc@.
data SortDirection = Ascending | Descending deriving (Eq, Ord, Read, Show)

-- | Name of a SQL function; emitted verbatim in front of its parenthesised
-- argument.
newtype SqlFunctionName = SqlFunctionName Text deriving (Eq, Ord, Read, Show)

-- | A column of an aliased table: the alias index and the column name.
-- Rendered as the alias (@t@ followed by the index), a dot, and the
-- double-quoted column name.
data QualifiedField = QualifiedField Int Text deriving (Eq, Ord, Read, Show)
-- | A whole-row reference to an aliased table, rendered as the bare alias
-- (@t@ followed by the index).
newtype TableValue = TableValue Int deriving (Eq, Ord, Read, Show)
-- | A table name and the index of its alias; rendered as the double-quoted
-- name followed by @as t@ and the index.
data QualifiedTable = QualifiedTable Text Int deriving (Eq, Ord, Read, Show)

-- | Boolean conditions, rendered into @where@ clauses.
data Condition = SqlAnd Condition Condition                        -- conjunction
               | SqlOr Condition Condition                         -- disjunction
               | SqlOperator Comparison SqlValueExp SqlValueExp    -- binary comparison
               | IsNull SqlValueExp
               | IsNotNull SqlValueExp
               | In SqlValueExp [SqlValueExp]                      -- membership; an empty list renders as an always-false test
               | BoolLit Bool                                      -- literal true / false
               | Like Bool SqlValueExp Text                        -- True selects ilike, False like; the Text pattern is sent as an escaped parameter
               deriving (Show)

-- | Aggregate applied around a value expression. 'Only' emits no function
-- name, leaving just the parenthesised expression; 'ArrayAgg' is @array_agg@.
data AggregateFunction = Avg | Max | Min | Count | Sum | Only | ArrayAgg deriving (Show)

-- | Scalar-valued SQL expressions.
data SqlValueExp = SimpleName QualifiedField                            -- column reference
                 | TableRecord TableValue                               -- bare table alias (whole row)
                 | SqlFunctionApplication SqlFunctionName SqlValueExp   -- single-argument function call
                 | SqlLiteral Action                                    -- bound through a ? placeholder
                 | AggregateFunction AggregateFunction SqlValueExp      -- aggregate over an expression
                 deriving (Show)

-- | Optional row limit; 'Nothing' omits the @limit@ clause entirely.
newtype Limit = Limit (Maybe Int) deriving (Eq, Ord, Read, Show)

-- | Select list: either @*@ or a comma-separated list of expressions.
data Selection = SelectEverything | SelectExpressions [SqlValueExp] deriving (Show)

-- | Group-by expressions; an empty list omits the @group by@ clause.
newtype Grouping = Grouping [SqlValueExp] deriving (Show)

-- | A complete SQL statement.
data SqlExp =
    Select
        Selection                         -- select list
        (Maybe Condition)                 -- optional where condition
        [QualifiedTable]                  -- from list
        [(SqlValueExp, SortDirection)]    -- order-by terms; empty means no order by
        Limit
        Grouping
  | Insert Text [Text] [Action]           -- table name, column names, values; rendered with "returning *"
  | RawQuery Text [Action]                -- SQL text used verbatim, with its parameters
  | Delete QualifiedTable (Maybe Condition)
  | Update [Text] [SqlValueExp] (Maybe Condition) QualifiedTable  -- column names and their new values in matching order, optional where, target table
  deriving (Show)

-- | Lift raw SQL text (any string-like value convertible to 'ByteString') into
-- a segment with no parameters. The text is neither escaped nor quoted.
sToQuery :: StringConv a ByteString => a -> QuerySegment
sToQuery bs = QuerySegment query []
  where
    query = Query (toS bs)

-- | A single @?@ placeholder (plus a trailing space) bound to the given
-- parameter value.
actionToQuery :: Action -> QuerySegment
actionToQuery = QuerySegment "? " . (: [])

-- | Render a from-list entry: the double-quoted table name, then @as t@
-- followed by the alias index, then a trailing space.
tableToQuery :: QualifiedTable -> QuerySegment
tableToQuery (QualifiedTable name index) = quotedName <> " as t" <> show index <> " "
  where
    quotedName :: QuerySegment
    quotedName = "\"" <> sToQuery name <> "\""

-- | Render a comparison operator as SQL, followed by a single space.
comparisonToQuery :: Comparison -> QuerySegment
comparisonToQuery comparison =
    case comparison of
        LessThan       -> "< "
        LessOrEqual    -> "<= "
        GreaterThan    -> "> "
        GreaterOrEqual -> ">= "
        Equal          -> "= "
        NotEqual       -> "!= "

-- | Render a column reference: the table alias (@t@ followed by the index),
-- a dot, the double-quoted column name, and a trailing space.
fieldToQuery :: QualifiedField -> QuerySegment
fieldToQuery (QualifiedField index name) = alias <> "." <> quotedName <> " "
  where
    alias :: QuerySegment
    alias = "t" <> show index

    quotedName :: QuerySegment
    quotedName = "\"" <> sToQuery name <> "\""

-- | Render a whole-row reference as the bare table alias (@t@ followed by the
-- index) plus a trailing space.
tableValueToQuery :: TableValue -> QuerySegment
tableValueToQuery (TableValue index) = alias <> " "
  where
    alias :: QuerySegment
    alias = "t" <> show index

-- | SQL name of an aggregate followed by a space. 'Only' has no name, so the
-- wrapped expression ends up merely parenthesised.
aggregateFunctionToQuery :: AggregateFunction -> QuerySegment
aggregateFunctionToQuery fn =
    case fn of
        Avg      -> "avg "
        Max      -> "max "
        Min      -> "min "
        Sum      -> "sum "
        Count    -> "count "
        ArrayAgg -> "array_agg "
        Only     -> ""

-- | Render a scalar expression. Literals become @?@ placeholders whose values
-- travel in the segment's parameter list; function and aggregate
-- applications wrap their argument in parentheses.
valueToQuery :: SqlValueExp -> QuerySegment
valueToQuery expr =
    case expr of
        SimpleName field                                  -> fieldToQuery field
        TableRecord tv                                    -> tableValueToQuery tv
        SqlLiteral action                                 -> actionToQuery action
        SqlFunctionApplication (SqlFunctionName name) arg -> sToQuery name <> parenthesised arg
        AggregateFunction fn arg                          -> aggregateFunctionToQuery fn <> parenthesised arg
  where
    parenthesised inner = "( " <> valueToQuery inner <> " ) "

-- | Render a boolean 'Condition' as SQL.
--
-- Comparisons, null tests and and/or junctions are wrapped in parentheses.
-- 'Like' patterns are passed as an escaped parameter rather than spliced
-- into the SQL text.
conditionToQuery :: Condition -> QuerySegment
conditionToQuery condition =
    case condition of
        SqlOperator comp lhs rhs ->
            "( " <> valueToQuery lhs <> comparisonToQuery comp <> valueToQuery rhs <> ") "
        SqlAnd lhs rhs -> junction "and " lhs rhs
        SqlOr lhs rhs  -> junction "or " lhs rhs
        IsNull v       -> "( " <> valueToQuery v <> ") is null "
        IsNotNull v    -> "( " <> valueToQuery v <> ") IS NOT NULL "
        In needle candidates
            -- "x in ()" is a syntax error, so an empty list becomes an always-false test
            | null candidates -> "1!=1 "
            | otherwise       -> valueToQuery needle <> " in " <> sqlList candidates
        BoolLit True     -> "true "
        BoolLit False    -> "false "
        Like False e pat -> likeWith "like ?" e pat
        Like True e pat  -> likeWith "ilike ?" e pat
  where
    junction keyword lhs rhs =
        "( " <> conditionToQuery lhs <> keyword <> conditionToQuery rhs <> ") "
    sqlList xs =
        "( " <> foldl' (<>) mempty (intersperse ", " (valueToQuery <$> xs)) <> " )"
    -- The pattern is attached as the parameter for the "?" in the operator text.
    likeWith operator e pat =
        valueToQuery e <> operator <> QuerySegment "" [Escape (toS pat)]

-- | Render one order-by term: the expression followed by @asc@ or @desc@.
orderingToQuery :: (SqlValueExp, SortDirection) -> QuerySegment
orderingToQuery (e, dir) = valueToQuery e <> keyword dir
  where
    keyword Ascending  = "asc "
    keyword Descending = "desc "

-- | Render the @limit@ clause; no limit produces no text at all.
limitToQuery :: Limit -> QuerySegment
limitToQuery (Limit bound) = maybe "" renderLimit bound
  where
    renderLimit n = "limit " <> sToQuery (show n :: Text) <> " "

-- | Render the select list: @*@, or the expressions separated by commas.
selectionToQuery :: Selection -> QuerySegment
selectionToQuery selection =
    case selection of
        SelectEverything     -> "* "
        SelectExpressions es -> separateBy ", " (valueToQuery <$> es)

-- | Render the @group by@ clause; an empty grouping produces no text.
groupToQuery :: Grouping -> QuerySegment
groupToQuery (Grouping exps)
    | null exps = ""
    | otherwise = "group by " <> separateBy ", " (valueToQuery <$> exps)

-- | Join segments with @", "@ and wrap them in parentheses, followed by a
-- trailing space. An empty list renders as @() @.
listToTuple :: [QuerySegment] -> QuerySegment
listToTuple segments = "(" <> commaJoined <> ") "
  where
    commaJoined = foldl1Def (\acc next -> acc <> ", " <> next) "" segments

-- | Join the elements with the given separator and append a single trailing
-- space; an empty list yields just the space.
separateBy :: (Monoid a, Semigroup a, IsString a) => a -> [a] -> a
separateBy sep l = joined <> " "
  where
    joined = foldl1Def (\acc next -> acc <> sep <> next) mempty l

-- | Render a complete 'SqlExp' statement as SQL text plus its positional
-- parameters.
--
-- * 'Select' emits its optional clauses in the order SQL requires:
--   @where@, @group by@, @order by@, @limit@.
-- * 'Insert', 'Delete' and 'Update' all append @returning *@.
-- * 'RawQuery' passes its text and parameters through unchanged.
sqlExpToQuery :: SqlExp -> QuerySegment
sqlExpToQuery (Select selection cond tables ordering limit grouping) =
       "select "
    <> selectionToQuery selection
    <> "from "
    <> tableAliases
    <> maybe "" (("where " <>) . conditionToQuery) cond
    -- GROUP BY must precede ORDER BY and LIMIT; any other order is a syntax
    -- error whenever grouping is combined with either of them.
    <> groupToQuery grouping
    <> (if null ordering then ""
       else "order by " <> separateBy ", " (map orderingToQuery ordering))
    <> limitToQuery limit
    where tableAliases = foldl1Def (\x y -> x <> ", " <> y) "" (map tableToQuery tables)

sqlExpToQuery (Insert table fields values) =
      "insert into \"" <> sToQuery table <> "\" "
    <> listToTuple (fmap sToQuery fields)
    <> "values " <> listToTuple (fmap actionToQuery values)
    <> "returning * "
sqlExpToQuery (RawQuery q as) = QuerySegment (Query (toS q)) as
sqlExpToQuery (Delete table cond) =
       "delete from " <> tableToQuery table
    <> maybe " " (("where " <>) . conditionToQuery) cond
    <> "returning * "

sqlExpToQuery (Update fields values cond table) =
       "update " <> tableToQuery table
    <> " set " <> assignments
    <> maybe "" (("where " <>) . conditionToQuery) cond
    <> "returning * "
    where
      -- PostgreSQL 10+ rejects "set (col) = (expr)" for a single column: the
      -- parenthesised right-hand side is then a plain expression rather than
      -- a row constructor. A lone assignment is therefore written as
      -- "col = expr"; multi-column updates keep the tuple form.
      assignments = case (fields, values) of
          ([field], [value]) -> sToQuery field <> " = " <> valueToQuery value
          _ -> listToTuple (fmap sToQuery fields)
               <> " = "
               <> listToTuple (fmap valueToQuery values)

-- | Combining two conditions conjoins them with 'SqlAnd'.
instance Semigroup Condition where
    lhs <> rhs = SqlAnd lhs rhs

-- | A list of decoded rows whose concrete row type is hidden; all that is
-- known about it is that it has a 'FromRow' instance.
data SqlResult = forall a. FromRow a => SqlResult [a]

-- | A 'Proxy' for some row type with a 'FromRow' instance, with the type
-- itself hidden.
-- NOTE(review): presumably used to pick a row decoder at runtime -- confirm at
-- the call sites.
data SomeFromRowProxy = forall a. FromRow a => SomeFromRowProxy (Proxy a)