Skip to content

Commit

Permalink
Merge pull request #48
Browse files Browse the repository at this point in the history
Thread-safe Postgres connections
  • Loading branch information
Uditha Atukorala authored Jul 31, 2023
2 parents 672b57e + 80e69c1 commit 094ac8a
Show file tree
Hide file tree
Showing 10 changed files with 128 additions and 21 deletions.
1 change: 1 addition & 0 deletions src/datastore/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ if (GATEKEEPER_ENABLE_TESTING)
access-policies_test.cpp
collections_test.cpp
identities_test.cpp
pg_test.cpp
rbac-policies_test.cpp
redis_test.cpp
roles_test.cpp
Expand Down
1 change: 1 addition & 0 deletions src/datastore/config.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ struct config {

struct pg_t {
std::string opts;
duration_t timeout = 1000ms;
};

struct redis_t {
Expand Down
39 changes: 30 additions & 9 deletions src/datastore/pg.cpp
Original file line number Diff line number Diff line change
@@ -1,22 +1,43 @@
#include "pg.h"

static std::shared_ptr<datastore::pg::conn_t> _conn = nullptr;
#include "err/errors.h"

static datastore::config::pg_t _conf;
static datastore::pg::conn_t _conn = nullptr;

namespace datastore {
namespace pg {
std::shared_ptr<conn_t> conn() {
// FIXME: check if initialised
return _conn;
conn_t::element_type &connection::conn() const {
return *_conn;
}

result_t exec(std::string_view qry) {
nontxn_t tx(*conn());
return tx.exec(qry);
conn_t::element_type &connection::reconnect() {
_conn = connect();
return *_conn;
}

void init(const config::pg_t &c) {
connection conn() {
if (!_conn) {
throw err::DatastorePgConnectionUnavailable();
}

static std::timed_mutex mutex;
if (!mutex.try_lock_for(_conf.timeout)) {
throw err::DatastorePgTimeout();
}

return connection(_conn, connection::lock_t(mutex, std::adopt_lock));
}

conn_t connect() {
// Ref: https://www.postgresql.org/docs/current/libpq-envars.html
_conn = std::make_shared<conn_t>(c.opts);
_conn = std::make_shared<conn_t::element_type>(_conf.opts);
return _conn;
}

void init(const config::pg_t &c) {
_conf = c;
connect();
}
} // namespace pg
} // namespace datastore
43 changes: 37 additions & 6 deletions src/datastore/pg.h
Original file line number Diff line number Diff line change
@@ -1,26 +1,57 @@
#pragma once

#include <memory>
#include <mutex>

#include <pqxx/pqxx>

#include "config.h"

namespace datastore {
namespace pg {
using conn_t = pqxx::connection;
using conn_t = std::shared_ptr<pqxx::connection>;
using row_t = pqxx::row;
using result_t = pqxx::result;
using nontxn_t = pqxx::nontransaction;

using fkey_violation_t = pqxx::foreign_key_violation;
using unique_violation_t = pqxx::unique_violation;

std::shared_ptr<conn_t> conn();
class connection {
public:
using lock_t = std::unique_lock<std::timed_mutex>;

connection(const conn_t &conn, lock_t &&lock) noexcept : _conn(conn), _lock(std::move(lock)) {}

auto exec(std::string_view qry, auto &&...args) {
try {
return nontxn_exec(qry, std::forward<decltype(args)>(args)...);
} catch (const pqxx::broken_connection &e) {
// Try to reconnect, if it fails will throw an error
reconnect();
}

return nontxn_exec(qry, std::forward<decltype(args)>(args)...);
}

private:
result_t nontxn_exec(std::string_view qry, auto &&...args) const {
nontxn_t tx(conn());
return tx.exec_params(pqxx::zview(qry), std::forward<decltype(args)>(args)...);
}

conn_t::element_type &conn() const;
conn_t::element_type &reconnect();

conn_t _conn;
lock_t _lock;
};

result_t exec(std::string_view qry);
connection conn();
conn_t connect();

template <typename... Args> inline result_t exec(std::string_view qry, Args &&...args) {
nontxn_t tx(*conn());
return tx.exec_params(pqxx::zview(qry), args...);
inline auto exec(std::string_view qry, auto &&...args) {
return conn().exec(qry, std::forward<decltype(args)>(args)...);
}

void init(const config::pg_t &c);
Expand Down
50 changes: 50 additions & 0 deletions src/datastore/pg_test.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
#include <thread>

#include <gtest/gtest.h>

#include "pg.h"
#include "testing.h"

TEST(pg, concurrency) {
if (std::thread::hardware_concurrency() < 2) {
GTEST_SKIP() << "Not enough hardware support to run concurrency tests";
}

auto conf = datastore::testing::conf();
conf.pg.timeout = 50ms;
ASSERT_NO_THROW(datastore::pg::init(conf.pg));

// Success: timeout while waiting for connection lock
{
std::thread t1([conf]() {
auto conn = datastore::pg::conn();
std::this_thread::sleep_for(conf.pg.timeout * 5);
});

std::thread t2([]() {
// Connection is locked into t1 scope, expect a timeout
EXPECT_THROW(datastore::pg::conn(), err::DatastorePgTimeout);
});

t1.join();
t2.join();
}
}

TEST(pg, conn) {
// Error: connection unavailable
{ EXPECT_THROW(datastore::pg::conn(), err::DatastorePgConnectionUnavailable); }
}

TEST(pg, reconnect) {
auto conf = datastore::testing::conf();
ASSERT_NO_THROW(datastore::pg::init(conf.pg));

// Success: reconnect
{
auto c = datastore::pg::connect();
c->close();

EXPECT_NO_THROW(datastore::pg::exec("select 'ping';"));
}
}
2 changes: 1 addition & 1 deletion src/datastore/redis.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ class connection {

connection(const context_t &ptr, lock_t &&lock) noexcept : _ctx(ptr), _lock(std::move(lock)) {}

template <typename... Args> inline reply_t cmd(const std::string_view str, Args &&...args) {
template <typename... Args> reply_t cmd(const std::string_view str, Args &&...args) {
reply_t reply(
static_cast<reply_t::element_type *>(redisCommand(ctx(), str.data(), args...)),
freeReplyObject);
Expand Down
2 changes: 1 addition & 1 deletion src/err/basic_error.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ template <fixed_string C, fixed_string M> struct basic_error : public std::runti
std::strcat(_err, M.c_str());
}

inline std::string_view str() const noexcept { return _err; }
std::string_view str() const noexcept { return _err; }

friend std::ostream &operator<<(std::ostream &os, const basic_error &err) {
return os << err.str();
Expand Down
3 changes: 3 additions & 0 deletions src/err/errors.h
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,9 @@
#include "basic_error.h"

namespace err {
using DatastorePgConnectionUnavailable = basic_error<"gk:1.0.5.503", "Unavailable">;
using DatastorePgTimeout = basic_error<"gk:1.0.6.503", "Operation timed out">;

using DatastoreRedisCommandError = basic_error<"gk:1.0.4.503", "Unavailable">;
using DatastoreRedisConnectionFailure = basic_error<"gk:1.0.1.503", "Unavailable">;
using DatastoreRedisConnectionUnavailable = basic_error<"gk:1.0.3.503", "Unavailable">;
Expand Down
6 changes: 3 additions & 3 deletions src/err/fixed_string.h
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,11 @@ namespace err {
template <std::size_t N> struct fixed_string {
constexpr fixed_string(const char (&str)[N]) { std::copy_n(str, N, value); }

inline const char *c_str() const noexcept { return value; }
const char *c_str() const noexcept { return value; }

constexpr inline std::size_t size() const noexcept { return N; }
constexpr std::size_t size() const noexcept { return N; }

inline std::string_view str() const noexcept { return value; }
std::string_view str() const noexcept { return value; }

char value[N];
};
Expand Down
2 changes: 1 addition & 1 deletion src/logger/logger.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ inline void log(std::string_view severity, std::string_view source, Args &&...ar
std::cout << glz::write_json(obj) << std::endl;
}

void critical(std::string_view source, auto &&...args) {
inline void critical(std::string_view source, auto &&...args) {
log("critical", source, std::forward<decltype(args)>(args)...);
std::exit(EXIT_FAILURE);
}
Expand Down

0 comments on commit 094ac8a

Please sign in to comment.