-- | Interpretation of libpq results: status checking, error reporting,
-- affected-row counts and extraction of rows as vectors of nullable cells,
-- all inside a small short-circuiting monad over a 'PQ.Connection'.
module Hasql.Postgres.Session.ResultProcessing where

import Hasql.Postgres.Prelude
import qualified Database.PostgreSQL.LibPQ as PQ
import qualified Hasql.Postgres.ErrorCode as ErrorCode
import qualified Data.Text as T
import qualified Data.Text.Encoding as TE
import qualified Data.Attoparsec.ByteString.Char8 as Atto
import qualified Data.Vector as Vector
import qualified Data.Vector.Mutable as MVector
import qualified ListT


-- | Result-processing monad: a computation may abort with an 'Error' and
-- can read the 'PQ.Connection' the processed results belong to.
newtype M r =
  M (EitherT Error (ReaderT PQ.Connection IO) r)
  deriving (Functor, Applicative, Monad, MonadIO)

-- | Reasons processing a result can fail.
data Error
  = UnexpectedResult Text
    -- ^ The result had an unexpected shape: a status the caller did not
    -- accept, or a missing / unparsable affected-row count.
  | ErroneousResult Text
    -- ^ Sending the command failed, or the server reported an error status;
    -- the text describes the failure and its diagnostic fields.
  | TransactionConflict
    -- ^ The server reported one of the transaction-conflict SQLSTATEs
    -- checked in 'checkStatus' (e.g. serialization failure, deadlock).

-- | Execute a processing computation against the given connection.
run :: PQ.Connection -> M r -> IO (Either Error r)
run connection (M m) = runReaderT (runEitherT m) connection

-- | Unwrap the result of a libpq send/exec call.
--
-- 'Nothing' is turned into an 'ErroneousResult' that includes the
-- connection's last error message when libpq provides one.
just :: Maybe PQ.Result -> M PQ.Result
just = maybe failure return
  where
    failure = M $ do
      m <- lift $ ask >>= liftIO . PQ.errorMessage
      left $ ErroneousResult $ case m of
        Nothing -> "Sending a command to the server failed"
        Just m' -> "Sending a command to the server failed due to: " <> TE.decodeLatin1 m'

-- | Succeed if the predicate accepts the result's status; otherwise fail.
--
-- * 'PQ.BadResponse', 'PQ.NonfatalError' and 'PQ.FatalError' fail with
--   'TransactionConflict' if the SQLSTATE is one of the conflict codes
--   below, and with 'ErroneousResult' (carrying the diagnostics) otherwise.
-- * Any other rejected status fails with 'UnexpectedResult'.
checkStatus :: (PQ.ExecStatus -> Bool) -> PQ.Result -> M ()
checkStatus g r = do
  s <- liftIO $ PQ.resultStatus r
  unless (g s) $
    case s of
      PQ.BadResponse   -> failWithErroneousResult "Bad response"
      PQ.NonfatalError -> failWithErroneousResult "Non-fatal error"
      PQ.FatalError    -> failWithErroneousResult "Fatal error"
      _ -> M $ left $ UnexpectedResult $ "Unexpected result status: " <> fromString (show s)
  where
    -- The status is a human-readable description and is used verbatim.
    -- Previously it was passed through 'show', which wrapped it in an extra
    -- pair of quotes (e.g. @Status: ""Fatal error""@).
    failWithErroneousResult :: Text -> M ()
    failWithErroneousResult status = do
      code <- liftIO $ PQ.resultErrorField r PQ.DiagSqlstate
      -- Conflicts get their own constructor so callers can tell them apart
      -- from other errors (presumably so the transaction can be retried).
      when (maybe False (`elem` transactionConflictCodes) code) $
        M $ left TransactionConflict
      message <- liftIO $ PQ.resultErrorField r PQ.DiagMessagePrimary
      detail  <- liftIO $ PQ.resultErrorField r PQ.DiagMessageDetail
      hint    <- liftIO $ PQ.resultErrorField r PQ.DiagMessageHint
      M $ left $ ErroneousResult $
        erroneousResultMessage status code message detail hint

    -- SQLSTATEs reported as 'TransactionConflict' rather than 'ErroneousResult'.
    transactionConflictCodes =
      [ ErrorCode.transaction_rollback
      , ErrorCode.transaction_integrity_constraint_violation
      , ErrorCode.serialization_failure
      , ErrorCode.statement_completion_unknown
      , ErrorCode.deadlock_detected
      ]

    -- Renders the fields that are present as @Name: "value"@ pairs,
    -- joined with "; " and terminated by ".".
    erroneousResultMessage
      :: Text
      -> Maybe ByteString
      -> Maybe ByteString
      -> Maybe ByteString
      -> Maybe ByteString
      -> Text
    erroneousResultMessage status code message details hint =
      T.intercalate "; " (map formatField fields) <> "."
      where
        formatField (n, v) = n <> ": \"" <> v <> "\""
        -- NOTE(review): diagnostics are decoded as Latin-1; non-ASCII text
        -- sent in another client encoding (e.g. UTF-8) will be mis-decoded.
        -- Confirm the connection's client_encoding.
        fields =
          catMaybes
            [ Just ("Status", status)
            , fmap (("Code",) . TE.decodeLatin1) code
            , fmap (("Message",) . TE.decodeLatin1) message
            , fmap (("Details",) . TE.decodeLatin1) details
            , fmap (("Hint",) . TE.decodeLatin1) hint
            ]

-- | Accept a result whose status is 'PQ.CommandOk' or 'PQ.TuplesOk',
-- discarding its contents.
unit :: PQ.Result -> M ()
unit r =
  checkStatus
    (\case
       PQ.CommandOk -> True
       PQ.TuplesOk  -> True
       _            -> False)
    r

-- | Number of rows affected by a command whose status is 'PQ.CommandOk'.
--
-- Fails with 'UnexpectedResult' when libpq provides no count or the count
-- is not a plain decimal.
-- NOTE(review): libpq documents an empty string for commands whose tag
-- carries no row count; that surfaces here as a parse failure.
count :: PQ.Result -> M Word64
count r = do
  checkStatus (\case PQ.CommandOk -> True; _ -> False) r
  liftIO (PQ.cmdTuples r)
    >>= maybe (M $ left $ UnexpectedResult "No number of affected rows") parseWord64

-- | Parse the whole input as a non-negative decimal 'Word64', failing with
-- 'UnexpectedResult' on malformed or trailing input.
parseWord64 :: ByteString -> M Word64
parseWord64 b =
  either
    (\m -> M $ left $ UnexpectedResult $ "Couldn't parse Word64: " <> fromString m)
    return
    (Atto.parseOnly (Atto.decimal <* Atto.endOfInput) b)

-- | Extract a result whose status is 'PQ.TuplesOk' as a vector of rows,
-- each a vector of cells as returned by 'PQ.getvalue' ('Nothing' for NULL).
vector :: PQ.Result -> M (Vector (Vector (Maybe ByteString)))
vector r = do
  checkStatus (\case PQ.TuplesOk -> True; _ -> False) r
  liftIO $ do
    nr <- PQ.ntuples r
    nc <- PQ.nfields r
    mvx <- MVector.new (rowInt nr)
    -- @pred 0@ is @-1@, so an empty result yields an empty range.
    forM_ [0 .. pred nr] $ \ir -> do
      mvy <- MVector.new (colInt nc)
      forM_ [0 .. pred nc] $ \ic ->
        MVector.write mvy (colInt ic) =<< PQ.getvalue r ir ic
      -- Safe: 'mvy' is fully written and never touched after freezing.
      vy <- Vector.unsafeFreeze mvy
      MVector.write mvx (rowInt ir) vy
    Vector.unsafeFreeze mvx

-- | Convert a libpq column index to an 'Int'.
{-# INLINE colInt #-}
colInt :: PQ.Column -> Int
colInt (PQ.Col n) = fromIntegral n

-- | Convert a libpq row index to an 'Int'.
{-# INLINE rowInt #-}
rowInt :: PQ.Row -> Int
rowInt (PQ.Row n) = fromIntegral n