{-|
Module: Squeal.PostgreSQL.Render
Description: render functions
Copyright: (c) Eitan Chatav, 2019
Maintainer: eitan@morphism.tech
Stability: experimental

render functions
-}

{-# LANGUAGE
    AllowAmbiguousTypes
  , ConstraintKinds
  , FlexibleContexts
  , LambdaCase
  , MagicHash
  , OverloadedStrings
  , PolyKinds
  , RankNTypes
  , ScopedTypeVariables
  , TypeApplications
#-}

module Squeal.PostgreSQL.Render
  ( -- * Render
    RenderSQL (..)
  , printSQL
  , escape
  , parenthesized
  , bracketed
  , (<+>)
  , commaSeparated
  , doubleQuoted
  , singleQuotedText
  , singleQuotedUtf8
  , escapeQuotedString
  , escapeQuotedText
  , renderCommaSeparated
  , renderCommaSeparatedConstraint
  , renderCommaSeparatedMaybe
  , renderNat
  , renderSymbol
  ) where

import Control.Monad.IO.Class (MonadIO (..))
import Data.ByteString (ByteString)
import Data.Maybe (catMaybes)
import Data.Text (Text)
import Generics.SOP
import GHC.Exts
import GHC.TypeLits hiding (Text)

import qualified Data.Text as Text
import qualified Data.Text.Encoding as Text
import qualified Data.ByteString as ByteString
import qualified Data.ByteString.Char8 as Char8

-- | Parenthesize a `ByteString`: the argument is wrapped in @(@ and @)@
-- and is otherwise left untouched.
parenthesized :: ByteString -> ByteString
parenthesized str = "(" <> str <> ")"

-- | Square bracket a `ByteString`: the argument is wrapped in @[@ and @]@
-- and is otherwise left untouched.
bracketed :: ByteString -> ByteString
bracketed str = "[" <> str <> "]"

-- | Concatenate two `ByteString`s with a single space between them.
--
-- The space is inserted unconditionally, even when either side is empty.
(<+>) :: ByteString -> ByteString -> ByteString
infixr 7 <+>
str1 <+> str2 = str1 <> " " <> str2

-- | Comma separate a list of `ByteString`s, joining them with @", "@.
--
-- An empty list renders as the empty `ByteString`; a singleton list renders
-- as its only element.
commaSeparated :: [ByteString] -> ByteString
commaSeparated = ByteString.intercalate ", "

-- | Add double quotes around a `ByteString`.
--
-- The argument is copied verbatim: double quotes already present inside it
-- are /not/ escaped.
doubleQuoted :: ByteString -> ByteString
doubleQuoted str = "\"" <> str <> "\""

-- | Add single quotes around a `Text` and escape single quotes within it.
--
-- Every @'@ in the input is doubled to @''@, and the result is encoded as
-- UTF-8. No other character is altered.
--
-- NOTE(review): backslashes are passed through unchanged, so this presumably
-- relies on the server treating them literally in ordinary string constants
-- (@standard_conforming_strings = on@) -- confirm against deployment.
singleQuotedText :: Text -> ByteString
singleQuotedText str =
  "'" <> Text.encodeUtf8 (Text.replace "'" "''" str) <> "'"

-- | Add single quotes around a `ByteString` and escape single quotes within it.
--
-- Every @'@ byte in the input is doubled to @''@. The doubling is done
-- directly on the bytes rather than by decoding to `Text` and re-encoding:
-- in UTF-8 the byte @0x27@ only ever encodes the ASCII single quote, so the
-- result is the same for well-formed input, and malformed input no longer
-- raises a decoding exception from pure code.
singleQuotedUtf8 :: ByteString -> ByteString
singleQuotedUtf8 str =
  "'" <> ByteString.intercalate "''" (Char8.split '\'' str) <> "'"

-- | Escape quote a `String`.
--
-- Each character is passed through `escape`, the results are concatenated,
-- encoded as UTF-8 and wrapped in @E'@ ... @'@.
escapeQuotedString :: String -> ByteString
escapeQuotedString x =
  "E\'" <> Text.encodeUtf8 (fromString (escape =<< x)) <> "\'"

-- | Escape quote a `Text`.
--
-- Each character is passed through `escape`, the results are concatenated,
-- encoded as UTF-8 and wrapped in @E'@ ... @'@.
escapeQuotedText :: Text -> ByteString
escapeQuotedText x =
  "E\'" <> Text.encodeUtf8 (Text.concatMap (fromString . escape) x) <> "\'"

-- | Comma separate the renderings of a heterogeneous list.
--
-- The rendering function is applied to every element of the `NP`, the
-- results are collected in order, and then joined with `commaSeparated`.
renderCommaSeparated
  :: SListI xs
  => (forall x. expression x -> ByteString)
  -> NP expression xs -> ByteString
renderCommaSeparated render
  = commaSeparated
  . hcollapse
  . hmap (K . render)

-- | Comma separate the renderings of a heterogeneous list, where the
-- rendering function may use the constraint @c@ on each element's index.
--
-- The constraint is chosen by the caller, typically with a type
-- application, since @c@ is the first quantified type variable.
renderCommaSeparatedConstraint
  :: forall c xs expression. (All c xs, SListI xs)
  => (forall x. c x => expression x -> ByteString)
  -> NP expression xs -> ByteString
renderCommaSeparatedConstraint render
  = commaSeparated
  . hcollapse
  . hcmap (Proxy @c) (K . render)

-- | Comma separate the `Maybe` renderings of a heterogeneous list, dropping
-- `Nothing`s.
--
-- Elements whose rendering is `Nothing` contribute neither text nor a
-- separator; if every element renders to `Nothing` the result is empty.
renderCommaSeparatedMaybe
  :: SListI xs
  => (forall x. expression x -> Maybe ByteString)
  -> NP expression xs -> ByteString
renderCommaSeparatedMaybe render
  = commaSeparated
  . catMaybes
  . hcollapse
  . hmap (K . render)

-- | Render a promoted `Nat` as its decimal digits.
--
-- The type-level number is supplied by type application, since @n@ does not
-- occur in the result type.
renderNat :: forall n. KnownNat n => ByteString
renderNat = fromString (show (natVal' (proxy# :: Proxy# n)))

-- | Render a promoted `Symbol`, encoded as UTF-8.
--
-- The type-level string is supplied by type application, since @s@ does not
-- occur in the result type.
--
-- The symbol is encoded explicitly rather than through the `IsString`
-- instance of `ByteString`, which truncates every character to 8 bits and
-- would therefore corrupt any non-ASCII symbol.
renderSymbol :: forall s. KnownSymbol s => ByteString
renderSymbol = Text.encodeUtf8 (Text.pack (symbolVal' (proxy# :: Proxy# s)))

-- | Types whose values can be rendered as SQL in a `ByteString`.
class RenderSQL sql where
  -- | Produce the SQL rendering of a value.
  renderSQL :: sql -> ByteString

-- | Print SQL.
--
-- Renders the value with `renderSQL` and writes it, followed by a newline,
-- to standard output in any `MonadIO`.
printSQL :: (RenderSQL sql, MonadIO io) => sql -> io ()
printSQL = liftIO . Char8.putStrLn . renderSQL

-- | `escape` a character to prevent injection.
--
-- The output is meant to be placed inside an @E'...'@ string, as done by
-- `escapeQuotedString` and `escapeQuotedText`:
--
-- * the NUL character is dropped entirely;
-- * a single quote is doubled;
-- * a double quote and a backslash are prefixed with a backslash;
-- * backspace, newline, carriage return and tab become their two-character
--   backslash escapes;
-- * every other character is returned unchanged.
escape :: Char -> String
escape = \case
  '\NUL' -> ""
  '\'' -> "''"
  '"' -> "\\\""
  '\b' -> "\\b"
  '\n' -> "\\n"
  '\r' -> "\\r"
  '\t' -> "\\t"
  '\\' -> "\\\\"
  c -> [c]