{-# language GADTs #-}
{-# language LambdaCase #-}
{-# language NamedFieldPuns #-}
{-# language RankNTypes #-}
{-# language ScopedTypeVariables #-}
{-# language StandaloneKindSignatures #-}
{-# language StrictData #-}
{-# language TypeApplications #-}

module Rel8.Statement.Returning
  ( Returning( NumberOfRowsAffected, Projection )
  , decodeReturning
  , ppReturning
  )
where

-- base
import Control.Applicative ( liftA2 )
import Data.Foldable ( toList )
import Data.Int ( Int64 )
import Data.Kind ( Type )
import Data.List.NonEmpty ( NonEmpty )
import Prelude

-- hasql
import qualified Hasql.Decoders as Hasql

-- opaleye
import qualified Opaleye.Internal.HaskellDB.PrimQuery as Opaleye
import qualified Opaleye.Internal.HaskellDB.Sql.Print as Opaleye
import qualified Opaleye.Internal.Sql as Opaleye

-- pretty
import Text.PrettyPrint ( Doc, (<+>), text )

-- rel8
import Rel8.Schema.Name ( Selects )
import Rel8.Schema.Table ( TableSchema(..) )
import Rel8.Table.Opaleye ( castTable, exprs, view )
import Rel8.Table.Serialize ( Serializable, parse )

-- semigroupoids
import Data.Functor.Apply ( Apply, (<.>) )


-- | Describes what a 'Rel8.Insert', 'Rel8.Update' or 'Rel8.Delete' statement
-- reports back once it has run: either a count of the rows it affected, or
-- values projected out of those rows. Because 'Returning' is an
-- 'Applicative', several such results can be requested from one statement.
type Returning :: Type -> Type -> Type
data Returning names a where
  -- Internal (not exported): a constant result that needs nothing from the
  -- database.
  Pure :: r -> Returning names r
  -- Internal (not exported): feeds the result of the right-hand side into the
  -- function produced by the left-hand side.
  Ap
    :: Returning names (x -> y)
    -> Returning names x
    -> Returning names y

  -- | The number of rows the statement affected.
  NumberOfRowsAffected :: Returning names Int64

  -- | 'Projection' builds, from the table's columns, an expression to return
  -- for every affected row. Handy for logging exactly which rows a delete
  -- removed, or for reading back a generated id (for example one produced by
  -- an auto-incrementing counter via 'Rel8.nextval').
  Projection
    :: (Selects names exprs, Serializable returning a)
    => (exprs -> returning)
    -> Returning names [a]


-- | Mapping over a 'Returning' only post-processes the decoded result; the
-- constructors that request data ('NumberOfRowsAffected', 'Projection') are
-- kept as they are.
instance Functor (Returning names) where
  fmap f = \case
    -- A constant result: apply @f@ straight away.
    Pure a -> Pure (f a)
    -- Fuse @f@ into the function side of an existing 'Ap' instead of stacking
    -- another 'Ap' node on top of it.
    Ap g a -> Ap (fmap (f .) g) a
    -- Any other constructor is wrapped as @'Pure' f@ applied to it.
    m -> Ap (Pure f) m


-- | '<.>' just records both sides in an 'Ap' node. The combination is
-- interpreted later, when the @RETURNING@ clause is rendered and the result
-- is decoded.
instance Apply (Returning names) where
  (<.>) = Ap


-- | 'pure' and '<*>' map directly onto the 'Pure' and 'Ap' constructors, so
-- '<*>' coincides with the 'Apply' instance's '<.>'.
instance Applicative (Returning names) where
  pure = Pure
  (<*>) = Ap


-- | Collect the expressions that must appear in the @RETURNING@ clause for a
-- 'Returning', in left-to-right order.
--
-- Returns 'Nothing' when no 'Projection' is involved (only 'Pure' and
-- 'NumberOfRowsAffected'). In that case no @RETURNING@ clause is needed.
--
-- For 'Ap', the function side's expressions come first. This matches
-- 'runReturning', which pairs row decoders as
-- @liftA2 (,) functionSideDecoder argumentSideDecoder@.
projections :: ()
  => TableSchema names -> Returning names a -> Maybe (NonEmpty Opaleye.PrimExpr)
projections schema@TableSchema {columns} = \case
  Pure _ -> Nothing
  Ap f a -> projections schema f <> projections schema a
  NumberOfRowsAffected -> Nothing
  -- Apply the user's projection to the table's columns, pass the result
  -- through 'castTable', and flatten it into primitive expressions.
  Projection f -> Just (exprs (castTable (f (view columns))))


-- | Interpret a 'Returning' in continuation-passing style.
--
-- Exactly one of the two continuations produces the result:
--
-- * @rowCount k@ when only the affected-row count is needed. Here @k@ builds
--   the final value from that count.
-- * @rowList decoder k@ when the returned rows are needed. Here @decoder@
--   decodes a single row and @k@ builds the final value from all decoded rows.
--
-- When an 'Ap' combines a count with rows, the count is taken to be the number
-- of decoded rows (@length64@). When both sides need rows, their row decoders
-- are paired with @liftA2 (,)@, function side first. This matches the column
-- order produced by 'projections'.
runReturning :: ()
  => ((Int64 -> a) -> r)
  -> (forall x. Hasql.Row x -> ([x] -> a) -> r)
  -> Returning names a
  -> r
runReturning rowCount rowList = \case
  Pure a -> rowCount (const a)
  Ap fs as ->
    runReturning
      -- The function side only needs the row count ...
      (\withCount ->
         runReturning
           -- ... and so does the argument side: both read the same count.
           (\withCount' -> rowCount (withCount <*> withCount'))
           -- ... but the argument side needs rows: derive the count from
           -- the number of decoded rows.
           (\decoder -> rowList decoder . liftA2 withCount length64)
           as)
      -- The function side needs rows ...
      (\decoder withRows ->
         runReturning
           -- ... and the argument side only needs the count: derive it from
           -- the number of decoded rows.
           (\withCount -> rowList decoder $ withRows <*> withCount . length64)
           -- ... and so does the argument side: decode both from each row,
           -- then split the pairs back apart.
           (\decoder' withRows' ->
             rowList (liftA2 (,) decoder decoder') $
               withRows <$> fmap fst <*> withRows' . fmap snd)
           as)
      fs
  NumberOfRowsAffected -> rowCount id
  Projection (_ :: exprs -> returning) -> rowList decoder' id
    where
      -- The pattern signature above brings @returning@ into scope so that the
      -- matching 'Serializable' row parser can be selected.
      decoder' = parse @returning
  where
    length64 :: Foldable f => f x -> Int64
    length64 = fromIntegral . length


-- | Build the Hasql result decoder for a 'Returning'.
--
-- If only a row count is needed, this reads 'Hasql.rowsAffected'. Otherwise it
-- decodes every returned row with 'Hasql.rowList' and passes the list to the
-- continuation that 'runReturning' supplies.
decodeReturning :: Returning names a -> Hasql.Result a
decodeReturning = runReturning
  (<$> Hasql.rowsAffected)
  (\decoder withRows -> withRows <$> Hasql.rowList decoder)


-- | Render the @RETURNING@ clause for a 'Returning'.
--
-- Produces an empty 'Doc' when nothing has to be returned (no 'Projection'
-- involved). Otherwise it renders @RETURNING@ followed by the expressions
-- collected by 'projections', separated by commas via 'Opaleye.commaV'.
ppReturning :: TableSchema names -> Returning names a -> Doc
ppReturning schema returning = case projections schema returning of
  Nothing -> mempty
  Just primExprs ->
    text "RETURNING" <+> Opaleye.commaV Opaleye.ppSqlExpr (toList sqlExprs)
    where
      -- Named so as not to shadow the imported 'columns' field selector.
      sqlExprs = Opaleye.sqlExpr <$> primExprs