module Database.PostgreSQL.PQTypes.Internal.Utils (
    MkConstraint
  , mread
  , safePeekCString
  , safePeekCString'
  , cStringLenToBytea
  , byteaToCStringLen
  , textToCString
  , verifyPQTRes
  , withPGparam
  , throwLibPQError
  , throwLibPQTypesError
  , rethrowWithArrayError
  , hpqTypesError
  , unexpectedNULL
  ) where

import Control.Monad
import Data.ByteString.Unsafe
import Data.Kind (Type)
import Foreign.C
import Foreign.ForeignPtr
import Foreign.Marshal.Alloc
import Foreign.Marshal.Utils
import Foreign.Ptr
import Foreign.Storable
import GHC.Exts
import qualified Control.Exception as E
import qualified Data.Text as T
import qualified Data.Text.Encoding as T

import Database.PostgreSQL.PQTypes.Internal.C.Interface
import Database.PostgreSQL.PQTypes.Internal.C.Types
import Database.PostgreSQL.PQTypes.Internal.Error

-- | Fold a type-level list of constraint constructors into one constraint by
-- applying each of them to the monad @m@.
--
-- For example, @MkConstraint m '[MonadIO, MonadFail]@ reduces to
-- @(MonadIO m, (MonadFail m, ()))@.
type family MkConstraint (m :: Type -> Type)
                         (cs :: [(Type -> Type) -> Constraint]) :: Constraint where
  MkConstraint mon '[]          = ()
  MkConstraint mon (hd ': rest) = (hd mon, MkConstraint mon rest)

-- | Safely read a value.
--
-- Succeeds only when 'reads' yields exactly one parse that consumes the whole
-- input. Leading whitespace is accepted (as by 'reads'); any leftover input,
-- including trailing whitespace, or an ambiguous parse gives 'Nothing'.
mread :: Read a => String -> Maybe a
mread s = case reads s of
  [(a, "")] -> Just a
  _         -> Nothing

-- | Safely peek a C string.
--
-- Returns 'Nothing' for a NULL pointer. Otherwise it returns the string read
-- with 'peekCString'.
safePeekCString :: CString -> IO (Maybe String)
safePeekCString cs =
  if cs == nullPtr
    then pure Nothing
    else fmap Just (peekCString cs)

-- | Like 'safePeekCString', but a NULL pointer yields the empty string
-- instead of 'Nothing'.
safePeekCString' :: CString -> IO String
safePeekCString' cs = do
  mstr <- safePeekCString cs
  pure $ case mstr of
    Just str -> str
    Nothing  -> ""

-- | Convert a C string with explicit length into 'PGbytea'.
--
-- The pointer is stored as-is (no bytes are copied), so the result refers to
-- the same buffer as the input. The 'Int' length is narrowed to 'CInt' with
-- 'fromIntegral', which wraps silently if the value does not fit.
cStringLenToBytea :: CStringLen -> PGbytea
cStringLenToBytea (ptr, size) =
  PGbytea { pgByteaData = ptr, pgByteaLen = fromIntegral size }

-- | Convert 'PGbytea' back into a C string with explicit length.
--
-- The data pointer is returned as-is; no bytes are copied.
byteaToCStringLen :: PGbytea -> CStringLen
byteaToCStringLen PGbytea { pgByteaLen = size, pgByteaData = ptr } =
  (ptr, fromIntegral size)

-- | Convert 'T.Text' to a UTF-8 encoded, NUL-terminated C string held in a
-- 'ForeignPtr' allocated with 'mallocForeignPtrBytes'.
--
-- The encoded bytes are copied into a fresh buffer one byte longer than the
-- encoding, and the extra byte is set to 0. Using 'unsafeUseAsCStringLen' is
-- fine here because the source bytes are only read, and only inside the
-- callback. If the text contains U+0000, C code that stops at the first NUL
-- will see a truncated string.
textToCString :: T.Text -> IO (ForeignPtr CChar)
textToCString txt = unsafeUseAsCStringLen (T.encodeUtf8 txt) copyWithNul
  where
    copyWithNul :: CStringLen -> IO (ForeignPtr CChar)
    copyWithNul (src, n) = do
      buf <- mallocForeignPtrBytes (n + 1)
      withForeignPtr buf $ \dst -> do
        copyBytes dst src n
        pokeByteOff dst n (0 :: CChar)
      return buf

-- | Check the return value of a libpqtypes function.
--
-- A result of @0@ signals failure. In that case the error stored at @err@ is
-- thrown via 'throwLibPQTypesError', with @ctx@ as context. Any other value
-- counts as success.
verifyPQTRes :: Ptr PGerror -> String -> CInt -> IO ()
verifyPQTRes err ctx res = when (res == 0) $ throwLibPQTypesError err ctx

-- | Bracket-style helper for 'PGparam'.
--
-- Creates a parameter object for @conn@ and passes it to the supplied action.
-- The object is always released with 'c_PQparamClear' afterwards, even if the
-- action throws. If 'c_PQparamCreate' returns NULL, the libpqtypes error is
-- thrown via 'throwLibPQTypesError' with context @"withPGparam.create"@.
withPGparam :: Ptr PGconn -> (Ptr PGparam -> IO r) -> IO r
withPGparam conn action = E.bracket acquire c_PQparamClear action
  where
    acquire :: IO (Ptr PGparam)
    acquire = alloca $ \errPtr -> do
      param <- c_PQparamCreate conn errPtr
      if param == nullPtr
        -- The error must be read while errPtr is still allocated.
        then throwLibPQTypesError errPtr "withPGparam.create"
        else pure param

----------------------------------------

-- | Throw libpq specific error.
--
-- The message is the connection's current error message as returned by
-- 'c_PQerrorMessage', or the empty string if that is NULL. It is prefixed
-- with @ctx ++ ": "@ unless @ctx@ is empty, and thrown as 'LibPQError'.
throwLibPQError :: Ptr PGconn -> String -> IO a
throwLibPQError conn ctx =
  throwLibPQErrorWithContext ctx =<< safePeekCString' =<< c_PQerrorMessage conn

-- | Throw libpqtypes specific error.
--
-- The message is taken from the 'PGerror' structure at @err@ (via
-- 'pgErrorMsg'). It is prefixed with @ctx ++ ": "@ unless @ctx@ is empty, and
-- thrown as 'LibPQError'.
throwLibPQTypesError :: Ptr PGerror -> String -> IO a
throwLibPQTypesError err ctx =
  throwLibPQErrorWithContext ctx . pgErrorMsg =<< peek err

-- | Shared tail of 'throwLibPQError' and 'throwLibPQTypesError'. Throws a
-- 'LibPQError' whose text is @msg@, prefixed with @ctx ++ ": "@ when @ctx@ is
-- non-empty. Keeping this in one place keeps both error paths consistent.
throwLibPQErrorWithContext :: String -> String -> IO a
throwLibPQErrorWithContext ctx msg = E.throwIO . LibPQError $
  if null ctx then msg else ctx ++ ": " ++ msg

-- | Rethrow the supplied exception wrapped in 'ArrayItemError', recording
-- which array item failed.
--
-- The stored index is @i + 1@. NOTE(review): this presumably turns a 0-based
-- index into PostgreSQL's default 1-based array indexing; confirm against
-- callers.
--
-- Asynchronous exceptions (anything wrapped in 'E.SomeAsyncException', e.g.
-- 'E.ThreadKilled' or the exception thrown by @System.Timeout.timeout@) are
-- rethrown unchanged. Wrapping them would turn them into a synchronous
-- 'ArrayItemError' that the code waiting to catch them by type no longer
-- recognises.
rethrowWithArrayError :: CInt -> E.SomeException -> IO a
rethrowWithArrayError i se@(E.SomeException e)
  | isAsync   = E.throwIO se
  | otherwise = E.throwIO ArrayItemError
      { arrItemIndex = fromIntegral i + 1
      , arrItemError = e
      }
  where
    isAsync = case E.fromException se of
      Just (E.SomeAsyncException _) -> True
      Nothing                       -> False

-- | Throw an 'HPQTypesError' exception carrying the given message.
hpqTypesError :: String -> IO a
hpqTypesError msg = E.throwIO (HPQTypesError msg)

-- | Throw an 'HPQTypesError' reporting an unexpected NULL value
-- (message: @"unexpected NULL"@).
unexpectedNULL :: IO a
unexpectedNULL = hpqTypesError message
  where
    message = "unexpected NULL"