{-|
Module: Squeal.PostgreSQL.Query.From.Set
Description: set returning functions
Copyright: (c) Eitan Chatav, 2019
Maintainer: eitan@morphism.tech
Stability: experimental

Set returning functions, which turn argument expressions into @FROM@ clauses.
-}

{-# LANGUAGE
    ConstraintKinds
  , DeriveGeneric
  , DerivingStrategies
  , FlexibleContexts
  , FlexibleInstances
  , GADTs
  , GeneralizedNewtypeDeriving
  , LambdaCase
  , MultiParamTypeClasses
  , OverloadedLabels
  , OverloadedStrings
  , QuantifiedConstraints
  , ScopedTypeVariables
  , StandaloneDeriving
  , TypeApplications
  , TypeFamilies
  , TypeInType
  , TypeOperators
  , RankNTypes
  , UndecidableInstances
  #-}

module Squeal.PostgreSQL.Query.From.Set
  ( -- * Set Functions
    type (-|->)
  , type (--|->)
  , SetFun
  , SetFunN
  , generateSeries
  , generateSeriesStep
  , generateSeriesTimestamp
  , unsafeSetFunction
  , setFunction
  , unsafeSetFunctionN
  , setFunctionN
  ) where

import Data.ByteString (ByteString)
import Generics.SOP hiding (from)
import GHC.TypeLits

import qualified Generics.SOP as SOP

import Squeal.PostgreSQL.Type.Alias
import Squeal.PostgreSQL.Expression
import Squeal.PostgreSQL.Query.From
import Squeal.PostgreSQL.Render
import Squeal.PostgreSQL.Type.List
import Squeal.PostgreSQL.Type.Schema

{- |
A set returning function of exactly one argument.

The database schemas @db@ are universally quantified (a @RankNTypes@
@forall@), so a value of this type works against any database.
Use `SetFun` when the function depends on a particular @db@.
-}
type (-|->) arg row = forall db. SetFun db arg row

{- |
A set returning function of multiple arguments, passed as an 'NP' list
of expressions.

The database schemas @db@ are universally quantified (a @RankNTypes@
@forall@), so a value of this type works against any database.
Use `SetFunN` when the function depends on a particular @db@.
-}
type (--|->) arg set = forall db. SetFunN db arg set

{- |
The database-specific form of `-|->`.

The schemas @db@ are fixed by the caller rather than quantified.
A single ungrouped argument expression is mapped to a @FROM@ clause
holding exactly one entry, @row@.
-}
type SetFun db arg row =
  forall lat with params.
       Expression 'Ungrouped lat with db params '[] arg
       -- ^ input
    -> FromClause lat with db params '[row]
       -- ^ output

{- |
The database-specific form of `--|->`.

The schemas @db@ are fixed by the caller rather than quantified.
An 'NP' list of ungrouped argument expressions is mapped to a @FROM@
clause holding exactly one entry, @set@.
-}
type SetFunN db args set =
  forall lat with params.
       NP (Expression 'Ungrouped lat with db params '[]) args
       -- ^ input
    -> FromClause lat with db params '[set]
       -- ^ output

-- $setup
-- >>> import Squeal.PostgreSQL

-- | Escape hatch for a set returning function of a single variable.
--
-- The given name is spliced into the SQL verbatim and applied to the
-- rendered argument in parentheses, producing @fun(x)@.
-- The name is neither quoted nor escaped, so it must be trusted, valid
-- SQL. Prefer 'setFunction' for functions declared in the schema.
unsafeSetFunction
  :: forall fun ty row. KnownSymbol fun
  => ByteString
  -> ty -|-> (fun ::: row) -- ^ set returning function
unsafeSetFunction fun x = UnsafeFromClause $
  fun <> parenthesized (renderSQL x)

{- | Call a user defined set returning function of a single variable

The function is looked up by its qualified alias. Its single argument
type and returned row must match its @'Function@ declaration in the
schema.

>>> type Fn = '[ 'Null 'PGbool] :=> 'ReturnsTable '["ret" ::: 'NotNull 'PGnumeric]
>>> type Schema = '["fn" ::: 'Function Fn]
>>> :{
let
  fn :: SetFun (Public Schema) ('Null 'PGbool) ("fn" ::: '["ret" ::: 'NotNull 'PGnumeric])
  fn = setFunction #fn
in
  printSQL (fn true)
:}
"fn"(TRUE)
-}
setFunction
  :: ( Has sch db schema
     , Has fun schema ('Function ('[ty] :=> 'ReturnsTable row)) )
  => QualifiedAlias sch fun -- ^ function alias
  -> SetFun db ty (fun ::: row)
setFunction fun = unsafeSetFunction (renderSQL fun)

{- | Escape hatch for a multivariable set returning function.

The given name is spliced into the SQL verbatim and applied to the
rendered arguments, separated by commas and wrapped in parentheses,
producing @fun(x1, x2, ...)@.
The name is neither quoted nor escaped, so it must be trusted, valid SQL.
Prefer 'setFunctionN' for functions declared in the schema.
-}
unsafeSetFunctionN
  :: forall fun tys row. (SOP.SListI tys, KnownSymbol fun)
  => ByteString
  -> tys --|-> (fun ::: row) -- ^ set returning function
unsafeSetFunctionN fun xs = UnsafeFromClause $
  fun <> parenthesized (renderCommaSeparated renderSQL xs)

{- | Call a user defined multivariable set returning function

The function is looked up by its qualified alias. Its argument types
@tys@ and returned row must match its @'Function@ declaration in the
schema.

>>> type Fn = '[ 'Null 'PGbool, 'Null 'PGtext] :=> 'ReturnsTable '["ret" ::: 'NotNull 'PGnumeric]
>>> type Schema = '["fn" ::: 'Function Fn]
>>> :{
let
  fn :: SetFunN (Public Schema)
    '[ 'Null 'PGbool, 'Null 'PGtext]
    ("fn" ::: '["ret" ::: 'NotNull 'PGnumeric])
  fn = setFunctionN #fn
in
  printSQL (fn (true *: "hi"))
:}
"fn"(TRUE, (E'hi' :: text))
-}
setFunctionN
  :: ( Has sch db schema
     , Has fun schema ('Function (tys :=> 'ReturnsTable row))
     , SOP.SListI tys )
  => QualifiedAlias sch fun -- ^ function alias
  -> SetFunN db tys (fun ::: row)
setFunctionN fun = unsafeSetFunctionN (renderSQL fun)

{- | @generateSeries (start *: stop)@

Generate a series of values,
from @start@ to @stop@ with a step size of one.

Renders as a call to @generate_series@. The result is a @FROM@ clause
with a single column, also named @generate_series@, of the same type
as the arguments.

>>> printSQL (generateSeries @'PGint4 (1 *: 10))
generate_series((1 :: int4), (10 :: int4))
-}
generateSeries
  :: ty `In` '[ 'PGint4, 'PGint8, 'PGnumeric]
  => '[ null ty, null ty] --|->
    ("generate_series" ::: '["generate_series" ::: null ty])
    -- ^ set returning function
generateSeries = unsafeSetFunctionN "generate_series"

{- | @generateSeriesStep (start :* stop *: step)@

Generate a series of values,
from @start@ to @stop@ with a step size of @step@.

Renders as a three-argument call to @generate_series@. The result is a
@FROM@ clause with a single column, also named @generate_series@, of the
same type as the arguments.

>>> printSQL (generateSeriesStep @'PGint8 (2 :* 100 *: 2))
generate_series((2 :: int8), (100 :: int8), (2 :: int8))
-}
generateSeriesStep
  :: ty `In` '[ 'PGint4, 'PGint8, 'PGnumeric]
  => '[null ty, null ty, null ty] --|->
    ("generate_series" ::: '["generate_series" ::: null ty])
    -- ^ set returning function
generateSeriesStep = unsafeSetFunctionN "generate_series"

{- | @generateSeriesTimestamp (start :* stop *: step)@

Generate a series of timestamps,
from @start@ to @stop@ with a step size of @step@.

@start@ and @stop@ are timestamps (with or without time zone) and
@step@ is an interval. The result is a @FROM@ clause with a single
column, named @generate_series@, of the timestamp type.

>>> :{
let
  start = now
  stop = now !+ interval_ 10 Years
  step = interval_ 1 Months
in printSQL (generateSeriesTimestamp (start :* stop *: step))
:}
generate_series(now(), (now() + (INTERVAL '10.000 years')), (INTERVAL '1.000 months'))
-}
generateSeriesTimestamp
  :: ty `In` '[ 'PGtimestamp, 'PGtimestamptz]
  => '[null ty, null ty, null 'PGinterval] --|->
    ("generate_series"  ::: '["generate_series" ::: null ty])
    -- ^ set returning function
generateSeriesTimestamp = unsafeSetFunctionN "generate_series"