{-# LANGUAGE FlexibleContexts #-}
module Preql.Wire.Query where

import Preql.FromSql
import Preql.Wire.Errors
import Preql.Wire.Decode
import Preql.Wire.Internal
import Preql.Wire.ToSql

import Control.Monad
import GHC.TypeNats
import Preql.Imports

import qualified Data.Text as T
import qualified Data.Vector as V
import qualified Database.PostgreSQL.LibPQ as PQ

-- | Run a query, encoding its parameters with the given 'RowEncoder' and
-- decoding the returned rows with the given 'RowDecoder'.
--
-- The statement is sent via 'execParams'; any failure there is returned
-- unchanged as a 'Left'.  On success the raw result is handed to
-- 'decodeVector', which is given @'lookupType' conn@ so it can map a
-- 'PgType' to an OID on the same connection.
queryWith :: KnownNat (Width r) =>
  RowEncoder p -> RowDecoder (Width r) r -> PQ.Connection ->
  Query (Width r) -> p -> IO (Either QueryError (Vector r))
queryWith enc dec conn (Query q) params = do
    -- TODO safer Connection type
    -- withMVar (connectionHandle conn) $ \connRaw -> do
    e_result <- execParams enc conn q params
    case e_result of
        Left err   -> pure (Left err)
        Right rows -> decodeVector (lookupType conn) dec rows

-- | Run a statement whose result rows are not needed, so no 'RowDecoder'
-- is required.  Only success, or the 'QueryError' produced by
-- 'execParams', is reported.
queryWith_ :: RowEncoder p -> PQ.Connection -> Query n -> p -> IO (Either QueryError ())
queryWith_ enc conn (Query q) params = void <$> execParams enc conn q params

-- | 'queryWith' specialised to the class-provided codecs: parameters are
-- encoded with 'toSql' and each result row is decoded with 'fromSql'.
query :: (ToSql p, FromSql r, KnownNat (Width r)) =>
    PQ.Connection -> Query (Width r) -> p -> IO (Either QueryError (Vector r))
query = queryWith toSql fromSql

-- | 'queryWith_' specialised to the 'ToSql' encoder; any result rows are
-- discarded.
query_ :: ToSql p => PQ.Connection -> Query n -> p -> IO (Either QueryError ())
query_ = queryWith_ toSql

-- | Send a query and its encoded parameters to the server, requesting
-- results in binary format ('PQ.Binary').
--
-- Returns the raw 'PQ.Result' when its status is 'PQ.CommandOk' or
-- 'PQ.TuplesOk'.  Otherwise returns a 'ConnectionError' carrying one of:
--
--   * the message from 'connectionError', when libpq produced no result;
--   * the result's error message, for any other status.
--
-- libpq's @PQresultErrorMessage@ yields an empty string (not NULL) when the
-- result carries no error, e.g. for an empty query string.  In that case the
-- 'show'n status is used, so the caller never receives a blank error.
execParams :: RowEncoder p -> PQ.Connection -> ByteString -> p -> IO (Either QueryError PQ.Result)
execParams enc conn q params = do
    e_result <- connectionError conn
        =<< PQ.execParams conn q (runEncoder enc params) PQ.Binary
    case e_result of
        Left err -> pure (Left (ConnectionError err))
        Right res -> do
            status <- PQ.resultStatus res
            if status == PQ.CommandOk || status == PQ.TuplesOk
                then pure (Right res)
                else Left . ConnectionError . describe status
                        <$> PQ.resultErrorMessage res
  where
    -- Prefer libpq's own message; fall back to the status when the message
    -- is absent or empty.
    describe status m_msg =
        let msg = maybe T.empty (decodeUtf8With lenientDecode) m_msg
        in if T.null msg then T.pack (show status) else msg

-- | Turn the 'Maybe' returned by a libpq call into an 'Either'.
--
-- @Just a@ becomes @Right a@.  @Nothing@ means libpq returned no value, so
-- the connection's current error message is returned as @Left@.  If libpq
-- has no message, or an empty one, a fixed placeholder is used instead so
-- the error text is never blank.
connectionError :: PQ.Connection -> Maybe a -> IO (Either Text a)
connectionError _conn (Just a) = pure (Right a)
connectionError conn Nothing = do
    m_msg <- PQ.errorMessage conn
    let msg = maybe T.empty (decodeUtf8With lenientDecode) m_msg
    pure . Left $
        if T.null msg then "No error message available" else msg

-- | Resolve a 'PgType' to its OID.
--
-- An 'Oid' is returned as-is without touching the connection.  A
-- 'TypeName' is looked up in @pg_type@ by @typname@.  The first returned
-- row is used, and a 'ConnectionError' is produced when no row matches.
-- Query failures are passed through unchanged.
--
-- NOTE(review): @typname@ is not unique across schemas, so if several
-- schemas define the same name the chosen OID is whichever row the server
-- returns first — confirm callers only look up unambiguous names.
lookupType :: PQ.Connection -> PgType -> IO (Either QueryError PQ.Oid)
lookupType _ (Oid oid) = pure (Right oid)
lookupType conn (TypeName name) = do
    e_rows <- query conn "SELECT oid FROM pg_type WHERE typname = $1" name
    pure $ case e_rows of
        Left e     -> Left e
        Right rows -> maybe notFound Right (rows V.!? 0)
  where
    notFound = Left (ConnectionError ("No oid for: " <> name))

-- | Transaction isolation levels accepted by 'begin'.
--
-- Constructors are declared as @ReadCommitted@, @RepeatableRead@,
-- @Serializable@ (increasing strictness in PostgreSQL).  The derived 'Ord',
-- 'Enum' and 'Bounded' instances follow that declaration order.
data IsolationLevel
    = ReadCommitted
    | RepeatableRead
    | Serializable
    deriving (Show, Read, Eq, Ord, Enum, Bounded)

-- | Start a transaction at the given isolation level by sending the matching
-- @BEGIN ISOLATION LEVEL ...@ statement with no parameters.
begin :: PQ.Connection -> IsolationLevel -> IO (Either QueryError ())
begin conn level = query_ conn q ()
  where
    q = case level of
        ReadCommitted  -> "BEGIN ISOLATION LEVEL READ COMMITTED"
        RepeatableRead -> "BEGIN ISOLATION LEVEL REPEATABLE READ"
        Serializable   -> "BEGIN ISOLATION LEVEL SERIALIZABLE"

-- | Commit the current transaction by sending @COMMIT@ with no parameters.
commit :: PQ.Connection -> IO (Either QueryError ())
commit conn = query_ conn "COMMIT" ()

-- | Abort the current transaction by sending @ROLLBACK@ with no parameters.
rollback :: PQ.Connection -> IO (Either QueryError ())
rollback conn = query_ conn "ROLLBACK" ()