{-# language GADTs #-}
{-# language LambdaCase #-}
{-# language NamedFieldPuns #-}
{-# language RankNTypes #-}
{-# language ScopedTypeVariables #-}
{-# language StandaloneKindSignatures #-}
{-# language StrictData #-}
{-# language TypeApplications #-}

module Rel8.Statement.Returning
  ( Returning( NumberOfRowsAffected, Projection )
  , decodeReturning
  , ppReturning
  )
where

-- base
import Control.Applicative ( liftA2 )
import Data.Foldable ( toList )
import Data.Int ( Int64 )
import Data.Kind ( Type )
import Data.List.NonEmpty ( NonEmpty )
import Prelude

-- hasql
import qualified Hasql.Decoders as Hasql

-- opaleye
import qualified Opaleye.Internal.HaskellDB.PrimQuery as Opaleye
import qualified Opaleye.Internal.HaskellDB.Sql.Print as Opaleye
import qualified Opaleye.Internal.Sql as Opaleye

-- pretty
import Text.PrettyPrint ( Doc, (<+>), text )

-- rel8
import Rel8.Schema.Name ( Selects )
import Rel8.Schema.Table ( TableSchema(..) )
import Rel8.Table.Opaleye ( castTable, exprs, view )
import Rel8.Table.Serialize ( Serializable, parse )

-- semigroupoids
import Data.Functor.Apply ( Apply, (<.>) )


-- | Describes what a data-modifying statement gives back.
--
-- 'Rel8.Insert', 'Rel8.Update' and 'Rel8.Delete' all support returning either
-- the number of rows affected, or the actual rows modified. Because
-- 'Returning' is an 'Applicative', several of these can be combined into a
-- single result.
type Returning :: Type -> Type -> Type
data Returning names a where
  -- A constant result; it contributes no columns to the @RETURNING@ clause.
  Pure :: x -> Returning names x

  -- Application of one 'Returning' to another (the 'Applicative' structure).
  Ap
    :: Returning names (x -> y)
    -> Returning names x
    -> Returning names y

  -- | Return the number of rows affected.
  NumberOfRowsAffected :: Returning names Int64

  -- | 'Projection' allows you to project out of the affected rows, which can
  -- be useful if you want to log exactly which rows were deleted, or to view
  -- a generated id (for example, if using a column with an autoincrementing
  -- counter via 'Rel8.nextval').
  Projection
    :: (Selects names exprs, Serializable returning a)
    => (exprs -> returning)
    -> Returning names [a]


-- | Mapping over a 'Returning' only post-processes the decoded result; it
-- never changes which columns 'projections' reports.
instance Functor (Returning names) where
  fmap f = \case
    -- Apply the function directly to a constant.
    Pure a -> Pure (f a)
    -- Fuse the function into the function side of an existing application
    -- instead of adding another 'Ap' layer.
    Ap g a -> Ap (fmap (f .) g) a
    -- Leaf constructors are wrapped as an application of a pure function.
    m -> Ap (Pure f) m


-- | Combining two 'Returning's simply records an 'Ap' node; the work of
-- merging their projections and decoders happens when the tree is
-- interpreted.
instance Apply (Returning names) where
  (<.>) = Ap


-- | 'pure' embeds a constant (no columns, no row count needed) and '<*>'
-- records an 'Ap' node.
instance Applicative (Returning names) where
  pure = Pure
  (<*>) = Ap


-- | Collect the expressions that must appear in the @RETURNING@ clause for a
-- 'Returning', in left-to-right order of the 'Projection's it contains.
--
-- Returns 'Nothing' when there is no 'Projection' at all, i.e. the
-- 'Returning' is built only from 'Pure' values and 'NumberOfRowsAffected'.
projections :: ()
  => TableSchema names -> Returning names a -> Maybe (NonEmpty Opaleye.PrimExpr)
projections schema@TableSchema {columns} = \case
  Pure _ -> Nothing
  -- The function side's columns come first. 'runReturning' relies on this
  -- order: it pairs the function side's row decoder with the argument's.
  Ap f a -> projections schema f <> projections schema a
  NumberOfRowsAffected -> Nothing
  Projection f -> Just (exprs (castTable (f (view columns))))


-- | Interpret a 'Returning' in continuation-passing style.
--
-- Exactly one of the two continuations produces the final result:
--
-- * @rowCount@ is used when the 'Returning' contains no 'Projection'. It
--   receives a function from the affected-row count to the final result.
--
-- * @rowList@ is used as soon as any 'Projection' is present. It receives a
--   row decoder covering every projected column (in the order produced by
--   'projections') and a function from the decoded rows to the final result.
--   Any 'NumberOfRowsAffected' is then answered by the number of rows
--   returned.
runReturning :: ()
  => ((Int64 -> a) -> r)
  -> (forall x. Hasql.Row x -> ([x] -> a) -> r)
  -> Returning names a
  -> r
runReturning rowCount rowList = \case
  Pure a -> rowCount (const a)
  Ap fs as ->
    runReturning
      -- The function side only needs the row count...
      (\withCount ->
         runReturning
           -- ...and so does the argument: combine the two count readers.
           (\withCount' -> rowCount (withCount <*> withCount'))
           -- ...but the argument decodes rows: derive the count from them.
           (\decoder -> rowList decoder . liftA2 withCount length64)
           as)
      -- The function side decodes rows...
      (\decoder withRows ->
         runReturning
           -- ...and the argument only needs the count: count the rows.
           (\withCount -> rowList decoder $ withRows <*> withCount . length64)
           -- ...and so does the argument: decode both column groups from the
           -- same row (function side first), then split the pairs.
           (\decoder' withRows' ->
             rowList (liftA2 (,) decoder decoder') $
               withRows <$> fmap fst <*> withRows' . fmap snd)
           as)
      fs
  NumberOfRowsAffected -> rowCount id
  -- Bind the projected type so its 'Serializable' decoder can be selected.
  Projection (_ :: exprs -> returning) -> rowList (parse @returning) id
  where
    length64 :: Foldable f => f x -> Int64
    length64 = fromIntegral . length


-- | Build the Hasql result decoder for a 'Returning'.
--
-- Without any 'Projection', the statement's affected-row count is decoded.
-- Otherwise the projected columns are decoded from every returned row.
decodeReturning :: Returning names a -> Hasql.Result a
decodeReturning = runReturning
  (<$> Hasql.rowsAffected)
  (\decoder withRows -> withRows <$> Hasql.rowList decoder)


-- | Render the @RETURNING@ clause for a 'Returning'.
--
-- Renders nothing when the 'Returning' contains no 'Projection'. Otherwise
-- renders @RETURNING@ followed by the comma-separated projected expressions,
-- in the order given by 'projections'.
ppReturning :: TableSchema names -> Returning names a -> Doc
ppReturning schema returning = case projections schema returning of
  Nothing -> mempty
  Just columns ->
    text "RETURNING" <+> Opaleye.commaV Opaleye.ppSqlExpr (toList sqlExprs)
    where
      sqlExprs = Opaleye.sqlExpr <$> columns