{-# LANGUAGE GADTs, ScopedTypeVariables, RankNTypes, DataKinds, AllowAmbiguousTypes #-}
-- {-# OPTIONS_GHC -fwarn-incomplete-uni-patterns #-}
module Internal.Data.Basic.Compiler where

import Internal.Interlude

import Internal.Data.Basic.Types as Basic
import Internal.Data.Basic.Sql.Types as Sql
import GHC.TypeLits (Symbol, KnownSymbol)
import Database.PostgreSQL.Simple.ToField (ToField(..))
import Database.PostgreSQL.Simple.ToRow (ToRow(..))

-- | Translate a typed database expression into an untyped SQL value
-- expression.
--
-- A 'Field' read from a table 'Var' becomes a column reference qualified by
-- the identifier carried in that 'Var'; the column name is derived from the
-- type-level table and field names via 'CapsName'. A 'Literal' is embedded
-- using its 'ToField' rendering.
expToSql :: DbExp k a -> SqlValueExp
expToSql dbExp = case dbExp of
    Field (_ :: tableProxy table) (_ :: fieldProxy name) (Var tableRef) ->
        SimpleName (QualifiedField tableRef (nameText @(CapsName table name)))
    Literal value ->
        SqlLiteral (toField value)

-- | Render a table variable as a whole-record SQL value ('TableRecord').
varToSql :: Var k a -> SqlValueExp
varToSql (Var tableRef) = TableRecord $ TableValue tableRef

-- | Compile a type-erased expression. It dispatches on whether the wrapper
-- holds a full 'DbExp' (compiled with 'expToSql') or a bare table 'Var'
-- (compiled with 'varToSql').
someDbExpToSqlExp :: SomeDbExp -> SqlValueExp
someDbExpToSqlExp wrapped = case wrapped of
    SomeVar var  -> varToSql var
    SomeDbExp ex -> expToSql ex

-- | Compile every element of a literal collection into a SQL value
-- expression, preserving element order. It is used for the right-hand side
-- of an @IN@ condition.
literalCollectionToSql :: LiteralCollection collection a => collection -> [SqlValueExp]
literalCollectionToSql collection =
    fmap someDbExpToSqlExp (getLiteralCollection collection)

-- | Compile a boolean condition tree into a SQL 'Condition'.
--
-- The translation is purely structural. Comparisons keep their operator,
-- @And@ / @Or@ nodes recurse into both sub-conditions, and leaf operands are
-- compiled with 'expToSql'. The right-hand side of @IN@ is compiled with
-- 'literalCollectionToSql'.
boolExpToSql :: ConditionExp -> Condition
boolExpToSql condition = case condition of
    Compare cmp lhs rhs           -> SqlOperator cmp (expToSql lhs) (expToSql rhs)
    BoolOp And left right         -> SqlAnd (boolExpToSql left) (boolExpToSql right)
    BoolOp Or left right          -> SqlOr (boolExpToSql left) (boolExpToSql right)
    Basic.IsNull ex               -> Sql.IsNull (expToSql ex)
    Basic.IsNotNull ex            -> Sql.IsNotNull (expToSql ex)
    Basic.In needle haystack      -> Sql.In (expToSql needle) (literalCollectionToSql haystack)
    Basic.BoolLit flag            -> Sql.BoolLit flag
    Basic.Like likeMode ex likePat -> Sql.Like likeMode (expToSql ex) likePat

-- | Build a SQL 'Condition' from a predicate over the table variables of a
-- filtered query.
--
-- The variables are produced by 'makeVars' for the 'Filtering' phase.
-- Callers pick @tables@ with a type application, e.g.
-- @conditionToSql \@tables cond@.
conditionToSql :: forall tables. TableSetVars 'Filtering tables
               => (Variables 'Filtering tables -> ConditionExp) -> Condition
conditionToSql predicate = boolExpToSql $ predicate $ makeVars @'Filtering @tables

-- | Renumber a list of tables so each gets a distinct number in its second
-- field. Numbers count up from 0 in list order, and any existing numbers are
-- discarded.
--
-- Joins need this because each side numbers its own tables starting at 0
-- (see 'compileTable').
uniqueNames :: [QualifiedTable] -> [QualifiedTable]
uniqueNames = renumber 0
    where renumber _ [] = []
          renumber next (QualifiedTable tbl _ : rest) =
              QualifiedTable tbl next : renumber (next + 1) rest

-- | Compile a bare table reference into a 'SelectEverything' query over just
-- that table. The query has no condition, ordering, limit or grouping. The
-- table's name comes from the type-level symbol, and the table is given
-- number 0.
compileTable :: forall name proxy. KnownSymbol name => proxy (name :: Symbol) -> SqlExp
compileTable _ =
    Select SelectEverything Nothing [table] [] (Limit Nothing) (Sql.Grouping [])
    where table = QualifiedTable (nameText @name) 0

-- | Flatten an update expression into two parallel lists: the column names
-- being assigned and the SQL values assigned to them.
--
-- Each 'SetField' adds exactly one entry to both lists, so position @i@ of
-- the names list pairs with position @i@ of the values list. 'NoUpdate' ends
-- the chain. Column names are derived from the type-level table and field
-- names via 'CapsName'.
updatedExpToSql :: forall fields table. UpdateExp fields table -> ([Text], [SqlValueExp])
updatedExpToSql = go
    where
        -- 'go' has its own signature that is polymorphic in the field index,
        -- because the recursive call on the inner update may be at a
        -- different index. @table@ is the scoped variable from the outer
        -- signature.
        go :: UpdateExp fields' table -> ([Text], [SqlValueExp])
        go (NoUpdate _) = ([], [])
        go (SetField (_ :: proxy field) rest val) =
            (nameText @(CapsName table field) : names, expToSql val : vals)
            where (names, vals) = go rest

-- | Compile an update function into column names and the values assigned to
-- them. The function is written against the variable of the single table
-- being updated.
updateToSql :: forall table fields. (Var 'Updating table -> UpdateExp fields table)
             -> ([Text], [SqlValueExp])
updateToSql mkUpdate = updatedExpToSql $ mkUpdate $ makeVars @'Updating @'[table]

-- | Compile a sort selector into SQL sort keys, each paired with its
-- direction, in the order the selector yields them.
orderingToSql :: forall tables ord. (Sortable ord, TableSetVars 'Sorting tables)
              => (Variables 'Sorting tables -> ord) -> [(SqlValueExp, SortDirection)]
orderingToSql selector = [ (someDbExpToSqlExp key, dir) | ~(key, dir) <- keys ]
    where keys = getOrdering (selector (makeVars @'Sorting @tables))

-- | Apply a projection function to a query's table variables (produced for
-- the 'Mapping' phase) and compile the result into the expressions to select.
mappingToSql ::
    forall tables map.
    ( Mappable map
    , TableSetVars 'Mapping tables )
    => (Variables 'Mapping tables -> map) -> [SqlValueExp]
mappingToSql projection = mapToSql $ projection $ makeVars @'Mapping @tables

-- | Compile an already-applied projection into SQL value expressions, one
-- per projected item, in order.
mapToSql :: Mappable map => map -> [SqlValueExp]
mapToSql projected = [ someDbExpToSqlExp item | item <- getMapping projected ]

-- | Compile a per-group projection into aggregate-function applications, one
-- per projected item. Each item pairs an aggregate function with the
-- expression it is applied to.
groupMapToSql :: GroupMappable map => map -> [SqlValueExp]
groupMapToSql grouped = fmap aggregate (getGroupMapping grouped)
    where aggregate (fn, ex) = AggregateFunction fn (someDbExpToSqlExp ex)


-- | Apply a grouping function to a query's table variables (produced for the
-- 'Basic.Grouping' phase) and compile the resulting keys into the SQL
-- grouping expressions.
grouppingToSql ::
    forall tables group.
    ( Groupable group
    , TableSetVars 'Basic.Grouping tables )
    => (Variables 'Basic.Grouping tables -> group) -> [SqlValueExp]
grouppingToSql grouper = [ someDbExpToSqlExp key | key <- keys ]
    where keys = getGrouping (grouper (makeVars @'Basic.Grouping @tables))

-- | Compile a grouped statement. The inner statement is compiled first, then
-- the grouping keys are attached to the resulting 'Select'.
--
-- The inner query must be a 'SelectEverything' with no grouping yet. The
-- @where@ pattern binding below is lazy and irrefutable, so any other shape
-- fails at runtime when a field is demanded.
groupStatementToSql :: forall tables group. GroupStatement group tables -> SqlExp
groupStatementToSql (GroupOn grouper stmt) =
    Select SelectEverything cond tabs sorts limit (Sql.Grouping groupKeys)
    where Select SelectEverything cond tabs sorts limit (Sql.Grouping []) = compileToSql stmt
          groupKeys = grouppingToSql @tables grouper

-- | Apply an aggregation function to a query's table variables (produced for
-- the 'Folding' phase) and compile each resulting (aggregate function,
-- expression) pair into an 'AggregateFunction' application.
foldingToSql ::
    forall tables aggr.
    ( Aggregatable aggr
    , TableSetVars 'Folding tables )
    => (Variables 'Folding tables -> aggr) -> [SqlValueExp]
foldingToSql folder = fmap aggregate aggregations
    where aggregations = getAggregating (folder (makeVars @'Folding @tables))
          aggregate (fn, ex) = AggregateFunction fn (someDbExpToSqlExp ex)


-- | Compile an aggregation over a statement. The inner statement is compiled
-- first, and its 'SelectEverything' selection is replaced with the aggregate
-- expressions.
--
-- The pattern signature on the inner statement brings its @tables@ index
-- into scope for the type application. The @where@ pattern binding is lazy
-- and irrefutable: an inner query that is not a 'SelectEverything' without
-- grouping fails at runtime when a field is demanded.
aggregateStatementToSql :: AggregateStatement aggr 'AM -> SqlExp
aggregateStatementToSql (Aggregate folder (stmt :: DbStatement f tables)) =
    Select (SelectExpressions aggregates) cond tabs sorts limit (Sql.Grouping [])
    where Select SelectEverything cond tabs sorts limit (Sql.Grouping []) = compileToSql stmt
          aggregates = foldingToSql @tables folder

-- | Compile a typed 'DbStatement' into the untyped 'SqlExp' AST.
--
-- Most equations compile their inner statement recursively and then patch the
-- resulting 'Select' by destructuring it in a @where@ pattern binding. Those
-- bindings are lazy and irrefutable. If the inner statement compiled to a
-- shape the pattern does not expect, the failure surfaces only at runtime, as
-- an irrefutable-pattern error when a bound field is demanded. Examples are
-- an existing ordering under 'Filter', or an existing limit under 'Take'.
-- NOTE(review): presumably the 'DbStatement' type indices rule most of these
-- shapes out; confirm. The @-fwarn-incomplete-uni-patterns@ pragma at the top
-- of the file is commented out, so GHC will not flag them.
compileToSql :: DbStatement f ts -> SqlExp
-- A bare table reference; see 'compileTable'.
compileToSql (Table p) = compileTable p
-- Filtering: the existing condition (if any) is combined with the new one
-- using '<>' on 'Maybe'. The inner query must have no ordering, no limit and
-- no grouping; its selection is kept as-is.
compileToSql (Filter cond (t :: DbStatement f tables)) =
    Select sel (conditions <> Just newConditions) tables [] (Limit Nothing) (Sql.Grouping [])
    where Select sel conditions tables [] (Limit Nothing) (Sql.Grouping []) = compileToSql t
          newConditions = conditionToSql @tables cond
-- Join: both sides must be unconditioned, unordered, unlimited,
-- ungrouped 'SelectEverything' queries. Their table lists are concatenated
-- and renumbered with 'uniqueNames' so the numbers do not collide.
compileToSql (Join t1 t2) =
    Select SelectEverything Nothing (uniqueNames $ tab1 ++ tab2) [] (Limit Nothing) (Sql.Grouping [])
    where Select SelectEverything Nothing tab1 [] (Limit Nothing) (Sql.Grouping []) = compileToSql t1
          Select SelectEverything Nothing tab2 [] (Limit Nothing) (Sql.Grouping []) = compileToSql t2
-- Raw SQL text, either queried or executed, is passed through unchanged.
-- Its parameters are rendered with 'toRow'.
compileToSql (Raw q pars) = RawQuery q (toRow pars)
compileToSql (Execute q pars) = RawQuery q (toRow pars)
-- Insert: the table name comes from 'TableName'. The column list is built
-- by mapping 'capsFieldName' over the type-level list
-- @SetFields (MissingFields entKind) table@. The values are produced by
-- 'mapFields' over the same field set, rendering each with 'toField'.
-- NOTE(review): presumably that field set is "fields populated on an entity
-- of this kind"; confirm against the definitions in the Types module.
compileToSql (Basic.Insert (a :: Entity entKind table)) =
    Sql.Insert (nameText @(TableName table))
               (mapTypeList (Proxy @(HasCapsFieldName table)) (capsFieldName @table)
                            (Proxy @(SetFields (MissingFields entKind) table)))
               (mapFields @(TypeSatisfies ToField) @table @(SetFields (MissingFields entKind) table) (const toField) a)
-- Delete: the inner statement must compile to a single-table
-- 'SelectEverything' with no ordering, limit or grouping. Only its table and
-- condition are kept.
compileToSql (Basic.Delete (t :: DbStatement f '[table])) =
    Sql.Delete table conditions
    where Select SelectEverything conditions [table] [] (Limit Nothing) (Sql.Grouping []) = compileToSql t
-- Update: same single-table requirement as Delete. The assigned columns and
-- values come from 'updateToSql'.
compileToSql (Basic.Update update t) = Sql.Update updateFields updateVals conditions table
    where Select SelectEverything conditions [table] [] (Limit Nothing) (Sql.Grouping []) = compileToSql t
          (updateFields, updateVals) = updateToSql update
-- Sorting: the inner query must not already be ordered, limited or grouped.
-- The new ordering replaces the empty one.
compileToSql (SortOn selector (t :: DbStatement f tables)) =
    Select sel conditions tables orderings (Limit Nothing) (Sql.Grouping [])
    where Select sel conditions tables [] (Limit Nothing) (Sql.Grouping []) = compileToSql t
          orderings = orderingToSql @tables selector
-- Take: sets the limit. The inner query must not already have a limit or
-- grouping; any existing ordering is kept.
compileToSql (Take n t) =
    Select sel conditions tables ordering (Limit (Just n)) (Sql.Grouping [])
    where Select sel conditions tables ordering (Limit Nothing) (Sql.Grouping []) = compileToSql t
-- Map: replaces a 'SelectEverything' selection with the projected
-- expressions. Condition, ordering and limit are kept.
compileToSql (Map f (t :: DbStatement f tables)) =
    Select (SelectExpressions (mappingToSql @tables f)) conditions tables ordering lim (Sql.Grouping [])
    where Select SelectEverything conditions tables ordering lim (Sql.Grouping []) = compileToSql t
-- AsGroup compiles to exactly what its inner statement compiles to.
compileToSql (AsGroup t) = compileToSql t
-- GroupMap: compile the grouped statement (see 'groupStatementToSql'), then
-- replace its 'SelectEverything' selection with per-group aggregates. The
-- aggregates are obtained by applying @f@ to the grouping key (via
-- 'asAggregate') paired with the inner statement wrapped in 'AsGroup'.
compileToSql (GroupMap f t@(GroupOn gf (gt :: DbStatement f tables))) =
    Select what conditions tables ordering lim grouping
    where Select SelectEverything conditions tables ordering lim grouping = groupStatementToSql t
          g = asAggregate (gf (makeVars @'Basic.Grouping @tables))
          what = SelectExpressions (groupMapToSql (f (g, AsGroup gt)))