{-# language DataKinds #-}
{-# language FlexibleContexts #-}
{-# language GADTs #-}
{-# language LambdaCase #-}
{-# language NamedFieldPuns #-}
{-# language RankNTypes #-}
{-# language ScopedTypeVariables #-}
{-# language StandaloneKindSignatures #-}
{-# language StrictData #-}
{-# language TypeApplications #-}

module Rel8.Statement.Returning
  ( Returning( NoReturning, Returning )
  , runReturning
  , ppReturning
  )
where

-- base
import Data.Foldable ( toList )
import Data.Kind ( Type )
import Data.List.NonEmpty ( NonEmpty )
import Prelude

-- opaleye
import qualified Opaleye.Internal.HaskellDB.PrimQuery as Opaleye
import qualified Opaleye.Internal.HaskellDB.Sql.Print as Opaleye
import qualified Opaleye.Internal.Sql as Opaleye
import qualified Opaleye.Internal.Tag as Opaleye

-- pretty
import Text.PrettyPrint ( Doc, (<+>), text )

-- rel8
import Rel8.Expr (Expr)
import Rel8.Query (Query)
import Rel8.Schema.Name ( Selects )
import Rel8.Schema.Table ( TableSchema(..) )
import Rel8.Statement (Statement, statementNoReturning, statementReturning)
import Rel8.Table (Table)
import Rel8.Table.Opaleye ( castTable, exprs, view )

-- transformers
import Control.Monad.Trans.State.Strict (State)


-- | An optional @RETURNING@ clause, accepted by each of 'Rel8.Insert',
-- 'Rel8.Update' and 'Rel8.Delete'.
--
-- @names@ matches the @names@ parameter of the target table's
-- 'TableSchema'; @a@ is the result type the statement ends up producing.
type Returning :: Type -> Type -> Type
data Returning names a where
  -- | Omit the @RETURNING@ clause altogether; the statement yields @()@.
  NoReturning
    :: Returning names ()

  -- | Project a value out of every affected row. Handy, for instance, for
  -- logging precisely which rows a delete removed, or for reading back a
  -- generated id (such as one drawn from an autoincrementing counter via
  -- 'Rel8.nextval').
  Returning
    :: (Selects names exprs, Table Expr a)
    => (exprs -> a)
    -> Returning names (Query a)


-- | The column expressions projected by a @RETURNING@ clause, or 'Nothing'
-- when there is no @RETURNING@ clause at all.
--
-- For @'Returning' f@, the schema's 'columns' are turned into expressions
-- with 'view', @f@ is applied to them, the result goes through 'castTable',
-- and 'exprs' flattens it into a non-empty list of Opaleye primitive
-- expressions.
projections :: ()
  => TableSchema names -> Returning names a -> Maybe (NonEmpty Opaleye.PrimExpr)
projections TableSchema {columns} = \case
  NoReturning -> Nothing
  -- NOTE(review): 'castTable' presumably annotates each projected column with
  -- an explicit SQL type cast so the RETURNING types are unambiguous — confirm.
  Returning f -> Just (exprs (castTable (f (view columns))))


-- | Build a 'Statement' from a SQL printer, picking the statement shape from
-- the 'Returning' clause.
--
-- @pp@ is a 'State' computation over Opaleye's 'Opaleye.Tag' that yields the
-- statement's SQL 'Doc'. Matching on the GADT refines the result type:
--
-- * 'NoReturning' produces a @'Statement' ()@ via 'statementNoReturning'.
-- * @'Returning' _@ produces a @'Statement' ('Query' a)@ via
--   'statementReturning'.
runReturning ::
  State Opaleye.Tag Doc ->
  Returning names a ->
  Statement a
runReturning pp = \case
  NoReturning -> statementNoReturning pp
  -- The projection function itself is unused here. Matching the constructor
  -- is still required: it brings the @'Table' 'Expr' a@ constraint into
  -- scope, which 'statementReturning' needs.
  Returning _ -> statementReturning pp


-- | Render the @RETURNING@ clause for a statement against the given table.
--
-- Yields 'mempty' (an empty 'Doc') for 'NoReturning'. Otherwise it renders
-- the keyword @RETURNING@ followed by the projected column expressions,
-- printed with Opaleye's SQL printer and laid out with 'Opaleye.commaV'
-- (comma-separated).
ppReturning :: TableSchema names -> Returning names a -> Doc
ppReturning schema returning = case projections schema returning of
  Nothing -> mempty
  Just columns ->
    text "RETURNING" <+> Opaleye.commaV Opaleye.ppSqlExpr (toList sqlExprs)
    where
      -- Lower Opaleye's primitive expressions to its SQL AST for printing.
      sqlExprs = Opaleye.sqlExpr <$> columns