{-# LANGUAGE OverloadedStrings, RecordWildCards, GADTs, CPP #-}
-- | PostgreSQL backend for Selda.
module Database.Selda.PostgreSQL
  ( PG, PGConnectInfo (..)
  , withPostgreSQL, on, auth
  , pgOpen, pgOpen', seldaClose
  , pgConnString, pgPPConfig
  ) where
#if !MIN_VERSION_base(4, 11, 0)
import Data.Monoid
#endif
import Data.ByteString (ByteString)
import Data.String (IsString (..))
import qualified Data.Text as T
import Database.Selda.Backend hiding (toText)
import Database.Selda.JSON
import Database.Selda.Unsafe as Selda (cast, operator)
import Control.Monad.Catch
import Control.Monad.IO.Class

#ifndef __HASTE__
import Control.Monad (void)
import qualified Data.ByteString as BS (foldl')
import qualified Data.ByteString.Char8 as BS (pack, unpack)
import Data.Dynamic
import Data.Foldable (for_)
import Data.Int (Int64)
import Data.Text.Encoding
import Database.Selda.PostgreSQL.Encoding
import Database.PostgreSQL.LibPQ hiding (user, pass, db, host)
#endif

-- | Type-level tag identifying the PostgreSQL backend.
data PG

-- | JSON support for PostgreSQL: field lookup is compiled to the native @->@
--   operator, and JSON values are turned into text with a plain cast.
instance JSONBackend PG where
  (~>) = operator "->"
  jsonToText = Selda.cast

-- | PostgreSQL connection information.
--
--   Either a set of individual connection parameters ('PGConnectInfo'), or a
--   raw connection string handed to libpq as-is ('PGConnectionString').
--   'pgSchema' exists on both constructors; all other fields exist on only
--   one of them, so using them as accessors on the other constructor is a
--   runtime error.
data PGConnectInfo = PGConnectInfo
  { -- | Host to connect to.
    pgHost     :: T.Text
    -- | Port to connect to.
  , pgPort     :: Int
    -- | Name of database to use.
  , pgDatabase :: T.Text
    -- | Schema to use upon connection.
  , pgSchema   :: Maybe T.Text
    -- | Username for authentication, if necessary.
  , pgUsername :: Maybe T.Text
    -- | Password for authentication, if necessary.
  , pgPassword :: Maybe T.Text
  }
  | PGConnectionString
  { -- | Custom PostgreSQL connection string.
    pgConnectionString :: T.Text
    -- | Schema to use upon connection.
  , pgSchema :: Maybe T.Text
  }

-- | A string literal is taken to be a raw connection string, with no schema
--   override.
instance IsString PGConnectInfo where
  fromString str = PGConnectionString
    { pgConnectionString = T.pack str
    , pgSchema           = Nothing
    }

-- | Connect to the given database on the given host, on the default PostgreSQL
--   port (5432):
--
-- > withPostgreSQL ("my_db" `on` "example.com") $ do
-- >   ...
--
--   No schema, username or password is set; use 'auth' to add credentials.
on :: T.Text  -- ^ Name of the database to use.
   -> T.Text  -- ^ Host to connect to.
   -> PGConnectInfo
on db host = PGConnectInfo
  { pgHost     = host
  , pgPort     = 5432
  , pgDatabase = db
  , pgSchema   = Nothing
  , pgUsername = Nothing
  , pgPassword = Nothing
  }
infixl 7 `on`

-- | Add the given username and password to the given connection information:
--
-- > withPostgreSQL ("my_db" `on` "example.com" `auth` ("user", "pass")) $ do
-- >   ...
--
--   For more precise control over the connection options, you should modify
--   the 'PGConnectInfo' directly.
--
--   Only 'PGConnectInfo' values can take credentials this way. For a
--   'PGConnectionString', put the credentials in the string itself; calling
--   'auth' on one is an error.
auth :: PGConnectInfo -> (T.Text, T.Text) -> PGConnectInfo
auth ci@PGConnectInfo{} (user, pass) = ci
  { pgUsername = Just user
  , pgPassword = Just pass
  }
auth PGConnectionString{} _ =
  -- A record update would fail here anyway; fail with a clearer message.
  error "auth: cannot add credentials to a PGConnectionString; include them in the connection string instead"
infixl 4 `auth`

-- | Convert `PGConnectInfo` into a libpq connection string.
--
--   For 'PGConnectInfo', every value is emitted as @key='value'@. Backslashes
--   and single quotes inside the value are backslash-escaped, so hosts,
--   database names, usernames and passwords containing spaces or quotes
--   survive intact. A 'PGConnectionString' is passed through verbatim, UTF-8
--   encoded.
pgConnString :: PGConnectInfo -> ByteString
#ifdef __HASTE__
pgConnString _ = error "pgConnString called in JS context"
#else
pgConnString PGConnectInfo{..} = mconcat
  [ param "host" pgHost
  , param "port" (T.pack (show pgPort))
  , param "dbname" pgDatabase
  , maybe "" (param "user") pgUsername
  , maybe "" (param "password") pgPassword
  , "connect_timeout=10", " "
  , "client_encoding=UTF8"
  ]
  where
    -- Render a single quoted keyword/value pair, including a trailing space.
    param :: ByteString -> T.Text -> ByteString
    param key val = key <> "='" <> encodeUtf8 (T.concatMap escape val) <> "' "

    -- libpq requires single quotes and backslashes inside a quoted value
    -- to be escaped with a backslash.
    escape :: Char -> T.Text
    escape '\\' = "\\\\"
    escape '\'' = "\\'"
    escape ch   = T.singleton ch
pgConnString PGConnectionString{..} = encodeUtf8 pgConnectionString
#endif

-- | Perform the given computation over a PostgreSQL database.
--   The database connection is guaranteed to be closed when the computation
--   terminates, whether it returns normally or throws an exception.
withPostgreSQL :: (MonadIO m, MonadMask m)
               => PGConnectInfo
               -> SeldaT PG m a
               -> m a
#ifdef __HASTE__
withPostgreSQL _ _ = return $ error "withPostgreSQL called in JS context"
#else
withPostgreSQL ci m = bracket (pgOpen ci) seldaClose (runSeldaT m)
#endif

-- | Open a new PostgreSQL connection. The connection will persist across
--   calls to 'runSeldaT', and must be explicitly closed using 'seldaClose'
--   when no longer needed.
--
--   The schema (if any) and the connection string are both derived from the
--   given 'PGConnectInfo'.
pgOpen :: (MonadIO m, MonadMask m) => PGConnectInfo -> m (SeldaConnection PG)
pgOpen ci = pgOpen' (pgSchema ci) (pgConnString ci)

-- | Pretty-printing configuration used to generate PostgreSQL-flavoured SQL:
--   column types, primary key types, column attributes, the value used for
--   auto-increment inserts, and index methods.
pgPPConfig :: PPConfig

-- | Open a new PostgreSQL connection from a raw libpq connection string,
--   optionally setting the given schema as the connection's @search_path@.
--   As with 'pgOpen', the connection must be explicitly closed using
--   'seldaClose' when no longer needed.
pgOpen' :: (MonadIO m, MonadMask m)
        => Maybe T.Text
        -> ByteString
        -> m (SeldaConnection PG)
#ifdef __HASTE__
-- Stubs for JavaScript (Haste) builds, where the libpq bindings are not imported.
pgOpen' _ _ = return $ error "pgOpen' called in JS context"
pgPPConfig = error "pgPPConfig evaluated in JS context"
#else
-- Connect with libpq and, on success, wrap the connection in a Selda backend.
-- If anything fails after the libpq connection is established,
-- 'bracketOnError' makes sure the connection is finished again.
pgOpen' schema connStr =
  bracketOnError (liftIO $ connectdb connStr) (liftIO . finish) $ \conn -> do
    st <- liftIO $ status conn
    case st of
      ConnectionOk -> do
        let backend = pgBackend conn

        -- Only forward server messages of WARNING level and above.
        _ <- liftIO $ runStmt backend "SET client_min_messages TO WARNING;" []

        -- The schema is spliced into a single-quoted SQL literal, so any
        -- single quotes it contains must be doubled.
        for_ schema $ \schema' ->
          liftIO $ runStmt backend
            ("SET search_path TO '" <> T.replace "'" "''" schema' <> "';") []

        newConnection backend (decodeUtf8 connStr)
      failure ->
        throwM $ DbError $ "unable to connect to postgres server: " ++ show failure

-- PostgreSQL overrides on top of Selda's default pretty-printing config.
-- (The type signature is given together with pgOpen's above.)
pgPPConfig = defPPConfig
    { ppType            = pgColType defPPConfig
    , ppTypeHook        = pgTypeHook
    , ppTypePK          = pgColTypePK defPPConfig
    , ppAutoIncInsert   = "DEFAULT"
    , ppColAttrs        = T.unwords . map pgColAttr
    , ppColAttrsHook    = pgColAttrsHook
    , ppIndexMethodHook = (" USING " <>) . compileIndexMethod
    }
  where
    -- Name of each index method in a CREATE INDEX ... USING clause.
    compileIndexMethod :: IndexMethod -> T.Text
    compileIndexMethod BTreeIndex = "btree"
    compileIndexMethod HashIndex  = "hash"

    -- Generic auto-incrementing Int primary keys get the row ID type;
    -- everything else goes through the type-renaming hook.
    pgTypeHook :: SqlTypeRep -> [ColAttr] -> (SqlTypeRep -> T.Text) -> T.Text
    pgTypeHook ty attrs fun
      | isGenericIntPrimaryKey ty attrs = pgColTypePK pgPPConfig TRowID
      | otherwise                       = pgTypeRenameHook fun ty

    -- Date/time types are stored with time zone information.
    pgTypeRenameHook :: (SqlTypeRep -> T.Text) -> SqlTypeRep -> T.Text
    pgTypeRenameHook _ TDateTime = "timestamp with time zone"
    pgTypeRenameHook _ TTime     = "time with time zone"
    pgTypeRenameHook f ty        = f ty

    -- Generic auto-incrementing Int primary keys are reduced to a single
    -- strong auto-primary attribute; other attribute lists pass through.
    pgColAttrsHook :: SqlTypeRep -> [ColAttr] -> ([ColAttr] -> T.Text) -> T.Text
    pgColAttrsHook ty attrs fun
      | isGenericIntPrimaryKey ty attrs = fun [AutoPrimary Strong]
      | otherwise                       = fun attrs

    bigserialQue :: [ColAttr]
    bigserialQue = [AutoPrimary Strong, Required]

    -- For when we use 'autoPrimaryGen' on 'Int' field
    isGenericIntPrimaryKey :: SqlTypeRep -> [ColAttr] -> Bool
    isGenericIntPrimaryKey ty attrs =
      ty == TInt64 && all (`elem` attrs) bigserialQue

-- | Create a `SeldaBackend` for PostgreSQL `Connection`
pgBackend :: Connection   -- ^ PostgreSQL connection object.
          -> SeldaBackend PG
pgBackend c = SeldaBackend
  { runStmt            = \q ps -> fromRows <$> pgQueryRunner c False q ps
  , runStmtWithPK      = \q ps -> fromLastId <$> pgQueryRunner c True q ps
  , prepareStmt        = pgPrepare c
  , runPrepared        = pgRun c
  , getTableInfo       = pgGetTableInfo c . rawTableName
  , backendId          = PostgreSQL
  , ppConfig           = pgPPConfig
  , closeConnection    = \_ -> finish c
  , disableForeignKeys = disableFKs c
  }
  where
    -- pgQueryRunner returns 'Left' exactly when asked for the last inserted
    -- ID, and 'Right' otherwise, so the other case cannot happen here.
    fromLastId :: Either a b -> a
    fromLastId = either id (const (error "impossible"))

    fromRows :: Either a b -> b
    fromRows = either (const (error "impossible")) id

-- | Drop ('True') or restore ('False') every foreign key constraint in the
--   current schema.
--
--   Dropping opens a transaction, saves an @ALTER TABLE ... ADD CONSTRAINT@
--   statement for each foreign key into the @__selda_dropped_fks@ table, and
--   then drops the constraint. Restoring replays the saved statements in
--   order, drops the bookkeeping table, and commits. The two calls are
--   therefore expected to be paired on the same connection.
--
--   Solution adapted from
--   <https://dba.stackexchange.com/questions/96961/how-to-temporarily-disable-foreign-keys-in-amazon-rds-postgresql>
disableFKs :: Connection -> Bool -> IO ()
disableFKs c disable =
    mapM_ run $ if disable
      then ["BEGIN TRANSACTION;", create, dropFKs]
      else [restore, "DROP TABLE __selda_dropped_fks;", "COMMIT;"]
  where
    run :: T.Text -> IO ()
    run q = void $ pgQueryRunner c False q []

    create :: T.Text
    create = mconcat
      [ "create table if not exists __selda_dropped_fks ("
      , "        seq bigserial primary key,"
      , "        sql text"
      , ");"
      ]

    -- The text of a regclass value is already double-quoted whenever the
    -- table name needs it, so it is spliced in as-is. Running it through
    -- quote_ident again would quote mixed-case names twice.
    dropFKs :: T.Text
    dropFKs = mconcat
      [ "do $$ declare t record;"
      , "begin"
      , "    for t in select conrelid::regclass::varchar table_name, conname constraint_name,"
      , "            pg_catalog.pg_get_constraintdef(r.oid, true) constraint_definition"
      , "            from pg_catalog.pg_constraint r"
      , "            where r.contype = 'f'"
      , "            and r.connamespace = (select n.oid from pg_namespace n where n.nspname = current_schema())"
      , "        loop"
      , "        insert into __selda_dropped_fks (sql) values ("
      , "            format('alter table if exists %s add constraint %s %s',"
      , "                t.table_name, quote_ident(t.constraint_name), t.constraint_definition));"
      , "        execute format('alter table %s drop constraint %s', t.table_name, quote_ident(t.constraint_name));"
      , "    end loop;"
      , "end $$;"
      ]

    restore :: T.Text
    restore = mconcat
      [ "do $$ declare t record;"
      , "begin"
      , "    for t in select * from __selda_dropped_fks order by seq loop"
      , "        execute t.sql;"
      , "        delete from __selda_dropped_fks where seq = t.seq;"
      , "    end loop;"
      , "end $$;"
      ]

-- | Look up schema information for the given table: its columns, unique
--   column groups, primary key, foreign keys and non-unique indexed columns.
--   If the columns query returns no rows, an empty 'TableInfo' is returned.
pgGetTableInfo :: Connection -> T.Text -> IO TableInfo
pgGetTableInfo c tbl = do
    Right (_, vals) <- pgQueryRunner c False tableinfo []
    if null vals
      then pure $ TableInfo (mkTableName tbl) [] [] []
      else do
        Right (_, pkInfo) <- pgQueryRunner c False pkquery []
        Right (_, us)     <- pgQueryRunner c False uniquequery []
        Right (_, fks)    <- pgQueryRunner c False fkquery []
        Right (_, ixs)    <- pgQueryRunner c False ixquery []
        colInfos <- mapM (describe fks (map toText ixs)) vals
        pure TableInfo
          { tableInfoName     = mkTableName tbl
          , tableColumnInfos  = colInfos
          , tableUniqueGroups = map (map mkColName . splitNames) us
          , tablePrimaryKey   = [mkColName pk | [SqlString pk] <- pkInfo]
          }
  where
    -- The table name as a single-quoted SQL string literal.
    tblLit :: T.Text
    tblLit = "'" <> T.replace "'" "''" tbl <> "'"

    -- The table name as a quoted identifier inside a string literal, cast
    -- to regclass: double quotes are doubled for the identifier, then
    -- single quotes are doubled for the literal.
    tblRegclass :: T.Text
    tblRegclass =
      "'\"" <> T.replace "'" "''" (T.replace "\"" "\"\"" tbl) <> "\"'::regclass"

    -- uniquequery joins the column names of each unique index with '"'.
    -- NOTE(review): a column name that itself contains '"' would be split
    -- incorrectly.
    splitNames :: [SqlValue] -> [T.Text]
    splitNames = filter (not . T.null) . T.splitOn "\"" . toText

    toText :: [SqlValue] -> T.Text
    toText [SqlString s] = s
    toText _             = error "toText: unreachable"

    tableinfo :: T.Text
    tableinfo = mconcat
      [ "SELECT column_name, data_type, is_nullable, column_default LIKE 'nextval(%' "
      , "FROM information_schema.columns "
      , "WHERE table_name = ", tblLit, ";"
      ]

    pkquery :: T.Text
    pkquery = mconcat
      [ "SELECT a.attname "
      , "FROM pg_index i "
      , "JOIN pg_attribute a ON a.attrelid = i.indrelid "
      , "  AND a.attnum = ANY(i.indkey) "
      , "WHERE i.indrelid = ", tblRegclass, " "
      , "  AND i.indisprimary;"
      ]

    uniquequery :: T.Text
    uniquequery = mconcat
      [ "SELECT string_agg(a.attname, '\"') "
      , "FROM pg_index i "
      , "JOIN pg_attribute a ON a.attrelid = i.indrelid "
      , "  AND a.attnum = ANY(i.indkey) "
      , "WHERE i.indrelid = ", tblRegclass, " "
      , "  AND i.indisunique "
      , "  AND NOT i.indisprimary "
      , "GROUP BY i.indkey;"
      ]

    fkquery :: T.Text
    fkquery = mconcat
      [ "SELECT kcu.column_name, ccu.table_name, ccu.column_name "
      , "FROM information_schema.table_constraints AS tc "
      , "JOIN information_schema.key_column_usage AS kcu "
      , "  ON tc.constraint_name = kcu.constraint_name "
      , "JOIN information_schema.constraint_column_usage AS ccu "
      , "  ON ccu.constraint_name = tc.constraint_name "
      , "WHERE constraint_type = 'FOREIGN KEY' AND tc.table_name=", tblLit, ";"
      ]

    ixquery :: T.Text
    ixquery = mconcat
      [ "select a.attname as column_name "
      , "from pg_class t, pg_class i, pg_index ix, pg_attribute a "
      , "where "
      , "t.oid = ix.indrelid "
      , "and i.oid = ix.indexrelid "
      , "and a.attrelid = t.oid "
      , "and a.attnum = ANY(ix.indkey) "
      , "and t.relkind = 'r' "
      , "and not ix.indisunique "
      , "and not ix.indisprimary "
      , "and t.relname = ", tblLit, ";"
      ]

    -- Build a 'ColumnInfo' from one row of the columns query, attaching the
    -- foreign keys and index membership that belong to that column.
    describe :: [[SqlValue]] -> [T.Text] -> [SqlValue] -> IO ColumnInfo
    describe fks ixs [SqlString name, SqlString ty, SqlString nullable, auto] =
      pure ColumnInfo
        { colName          = mkColName name
        , colType          = mkTypeRep (T.toLower ty)
        , colIsAutoPrimary = isAuto auto
        , colIsNullable    = readBool nullable
        , colHasIndex      = name `elem` ixs
        , colFKs           =
            [ (mkTableName tblname, mkColName col)
            | [SqlString cname, SqlString tblname, SqlString col] <- fks
            , name == cname
            ]
        }
      where
        -- The LIKE 'nextval(%' test yields NULL when there is no default.
        isAuto (SqlBool x) = x
        isAuto _           = False
    describe _ _ results =
      throwM $ SqlError $ "bad result from table info query: " ++ show results

-- | Execute a query with the given parameters, requesting binary results.
--
--   When the flag is 'True', @ RETURNING LASTVAL();@ is appended to the
--   query, and the value in the first row and first column is returned as
--   'Left'. If no value is available there, the result is 0. Otherwise the
--   affected row count and all result rows are returned as 'Right'.
pgQueryRunner :: Connection -> Bool -> T.Text -> [Param] -> IO (Either Int64 (Int, [[SqlValue]]))
pgQueryRunner c return_lastid q ps = do
    mres <- execParams c (encodeUtf8 q') [fromSqlValue p | Param p <- ps] Binary
    unlessError c errmsg mres $ \res ->
      if return_lastid
        then Left <$> getLastId res
        else Right <$> getRows res
  where
    errmsg :: String
    errmsg = "error executing query `" ++ T.unpack q' ++ "'"

    q' :: T.Text
    q' | return_lastid = q <> " RETURNING LASTVAL();"
       | otherwise     = q

    getLastId :: Result -> IO Int64
    getLastId res = maybe 0 readInt64 <$> getvalue res 0 0

-- | Execute a statement prepared by 'pgPrepare' with the given parameters.
--   The handle must wrap a 'StmtID'; any other handle raises a 'DbError'
--   instead of a pattern-match failure.
pgRun :: Connection -> Dynamic -> [Param] -> IO (Int, [[SqlValue]])
pgRun c hdl ps =
    case fromDynamic hdl :: Maybe StmtID of
      Nothing  -> throwM $ DbError "pgRun: invalid prepared statement handle"
      Just sid -> do
        mres <- execPrepared c (BS.pack $ show sid) (map mkParam ps) Binary
        unlessError c errmsg mres getRows
  where
    errmsg :: String
    errmsg = "error executing prepared statement"

    -- Prepared statements already know their parameter types, so the
    -- encoded OID is dropped.
    mkParam :: Param -> Maybe (ByteString, Format)
    mkParam (Param p) = dropOid <$> fromSqlValue p

    dropOid (_, val, fmt) = (val, fmt)

-- | Get all rows from a result, together with the affected-row count the
--   server reported (0 if it reported none).
getRows :: Result -> IO (Int, [[SqlValue]])
getRows res = do
    rows     <- ntuples res
    cols     <- nfields res
    types    <- mapM (ftype res) [0 .. cols - 1]
    affected <- cmdTuples res
    result   <- mapM (getRow res types cols) [0 .. rows - 1]
    -- An empty count string parses as 0, so it needs no special case.
    pure (maybe 0 bsToPositiveInt affected, result)
  where
    -- Parse an unsigned ASCII decimal number.
    bsToPositiveInt :: ByteString -> Int
    bsToPositiveInt = BS.foldl' (\acc w -> acc * 10 + fromIntegral w - 48) 0


-- | Get all columns for the given row, decoding each column using the
--   corresponding type OID.
getRow :: Result -> [Oid] -> Column -> Row -> IO [SqlValue]
getRow res types cols row =
  sequence $ zipWith (getCol res row) [0 .. cols - 1] types

-- | Get the given column, decoded according to its type OID. A missing
--   value becomes 'SqlNull'.
getCol :: Result -> Row -> Column -> Oid -> IO SqlValue
getCol res row col t = maybe SqlNull (toSqlValue t) <$> getvalue res row col

-- | Prepare a statement on the server, named after the shown 'StmtID' and
--   with the given parameter types. The returned handle wraps the 'StmtID'
--   and is meant to be passed to 'pgRun'.
pgPrepare :: Connection -> StmtID -> [SqlTypeRep] -> T.Text -> IO Dynamic
pgPrepare c sid types q = do
    mres <- prepare c stmtName (encodeUtf8 q) (Just $ map fromSqlType types)
    unlessError c errmsg mres $ \_ -> pure (toDyn sid)
  where
    stmtName = BS.pack (show sid)

    errmsg :: String
    errmsg = "error preparing query `" ++ T.unpack q ++ "'"

-- | Run a continuation on the result of a libpq call, provided the call
--   succeeded.
--
--   * 'Nothing' (libpq produced no result) throws
--     @DbError "unable to submit query to server"@.
--   * A result whose status is 'BadResponse', 'FatalError' or
--     'NonfatalError' is reported via 'doError', which throws an 'SqlError'
--     built from @msg@ and the connection's error message.
--   * A result with any other status is passed to the continuation @m@.
unlessError :: Connection -> String -> Maybe Result -> (Result -> IO a) -> IO a
unlessError _ _ Nothing _ =
  throwM $ DbError "unable to submit query to server"
unlessError c msg (Just res) m = do
  st <- resultStatus res
  if isFailure st then doError c msg else m res
  where
    isFailure BadResponse   = True
    isFailure FatalError    = True
    isFailure NonfatalError = True
    isFailure _             = False

-- | Throw an 'SqlError' whose text is @msg@. If libpq has an error message
--   for the connection, it is appended after @": "@.
--
--   libpq hands back the message as a raw 'ByteString'. It is decoded as
--   UTF-8 so that non-ASCII text (for example identifiers quoted in the
--   error) is not split into one 'Char' per byte. If the bytes are not
--   valid UTF-8, the previous byte-per-character decoding is used instead.
doError :: Connection -> String -> IO a
doError c msg = do
  me <- errorMessage c
  throwM $ SqlError $ concat
    [ msg
    , maybe "" ((": " ++) . decodeMessage) me
    ]
  where
    -- NOTE(review): assumes the connection's client encoding is UTF-8;
    -- confirm against how connections are opened. The fallback keeps the
    -- old behaviour for anything that does not decode.
    decodeMessage bs = either (const (BS.unpack bs)) T.unpack (decodeUtf8' bs)

-- | Map a PostgreSQL type name back to Selda's 'SqlTypeRep'. Names that are
--   not recognised are returned unchanged in 'Left'.
--
--   Both PostgreSQL's short internal names (@int4@, @int8@, @float8@, ...)
--   and the long SQL spellings (@integer@, @bigint@, @double precision@,
--   ...) are accepted.
--
--   @integer@ and @timestamp@ / @timestamp without time zone@ keep this
--   mapping consistent with 'pgColType':
--
--   * 'pgColType' creates 'TInt32' columns as @INT4@, which PostgreSQL
--     also reports as @integer@.
--   * 'pgColType' creates 'TDateTime' columns as @TIMESTAMP@, which
--     PostgreSQL reports as @timestamp@ / @timestamp without time zone@.
--
--   NOTE(review): which catalog column supplies the name is decided by
--   the caller; confirm there.
mkTypeRep :: T.Text -> Either T.Text SqlTypeRep
mkTypeRep "bigserial"                   = Right TRowID
mkTypeRep "int4"                        = Right TInt32
mkTypeRep "integer"                     = Right TInt32
mkTypeRep "int8"                        = Right TInt64
mkTypeRep "bigint"                      = Right TInt64
mkTypeRep "float8"                      = Right TFloat
mkTypeRep "double precision"            = Right TFloat
mkTypeRep "timestamp"                   = Right TDateTime
mkTypeRep "timestamp with time zone"    = Right TDateTime
mkTypeRep "timestamp without time zone" = Right TDateTime
mkTypeRep "bytea"                       = Right TBlob
mkTypeRep "text"                        = Right TText
mkTypeRep "boolean"                     = Right TBool
mkTypeRep "date"                        = Right TDate
mkTypeRep "time with time zone"         = Right TTime
mkTypeRep "uuid"                        = Right TUUID
mkTypeRep "jsonb"                       = Right TJSON
mkTypeRep typ                           = Left typ

-- | Render a Selda column type as PostgreSQL DDL text.
--
--   Row IDs, 64/32-bit integers, floats, timestamps, blobs, UUIDs and JSON
--   get PostgreSQL-specific names. Every other type falls back to the
--   generic 'ppType' rendering for the given configuration.
pgColType :: PPConfig -> SqlTypeRep -> T.Text
pgColType cfg t = case t of
  TRowID    -> "BIGINT"
  TInt64    -> "INT8"
  TInt32    -> "INT4"
  TFloat    -> "FLOAT8"
  TDateTime -> "TIMESTAMP"
  TBlob     -> "BYTEA"
  TUUID     -> "UUID"
  TJSON     -> "JSONB"
  _         -> ppType cfg t

-- | Render a column attribute as inline PostgreSQL DDL text.
--
--   'Primary' and 'Indexed' produce no inline text here.
--   NOTE(review): presumably these are emitted elsewhere (table-level
--   constraint / separate index statement); confirm at the DDL printer.
pgColAttr :: ColAttr -> T.Text
pgColAttr attr = case attr of
  Primary       -> ""
  AutoPrimary _ -> "PRIMARY KEY"
  Required      -> "NOT NULL"
  Optional      -> "NULL"
  Unique        -> "UNIQUE"
  Indexed _     -> ""

-- | Column type used for a column in primary-key position. 'TRowID'
--   becomes @BIGSERIAL@. Any other type is rendered exactly as
--   'pgColType' would render it.
pgColTypePK :: PPConfig -> SqlTypeRep -> T.Text
pgColTypePK cfg t = case t of
  TRowID -> "BIGSERIAL"
  _      -> pgColType cfg t
#endif