Skip to content

Commit

Permalink
src: Use a proper namespace for this project
Browse files Browse the repository at this point in the history
Additionally:
 - Remove the fs namespace alias from headers
 - Wrap most global functions in the namespace
  • Loading branch information
felipecrv committed Jul 18, 2024
1 parent c88ac35 commit bfe0471
Show file tree
Hide file tree
Showing 6 changed files with 134 additions and 131 deletions.
68 changes: 31 additions & 37 deletions src/flight_sql_client.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -33,21 +33,11 @@
#include "arrow/status.h"
#include "arrow/table.h"

using arrow::Result;
using arrow::Schema;
#include "library/include/flight_sql_fwd.h"

using arrow::Status;
using arrow::flight::ClientAuthHandler;
using arrow::flight::FlightCallOptions;
using arrow::flight::FlightClient;
using arrow::flight::FlightDescriptor;
using arrow::flight::FlightEndpoint;
using arrow::flight::FlightInfo;
using arrow::flight::FlightStreamChunk;
using arrow::flight::FlightStreamReader;
using arrow::flight::Location;
using arrow::flight::Ticket;
using arrow::flight::sql::FlightSqlClient;
using arrow::flight::sql::TableRef;

namespace sqlflite {

DEFINE_string(host, "localhost", "Host to connect to");
DEFINE_int32(port, 31337, "Port to connect to");
Expand All @@ -69,12 +59,12 @@ DEFINE_string(catalog, "", "Catalog");
DEFINE_string(schema, "", "Schema");
DEFINE_string(table, "", "Table");

Status PrintResultsForEndpoint(FlightSqlClient &client,
const FlightCallOptions &call_options,
const FlightEndpoint &endpoint) {
Status PrintResultsForEndpoint(flight::sql::FlightSqlClient &client,
const flight::FlightCallOptions &call_options,
const flight::FlightEndpoint &endpoint) {
ARROW_ASSIGN_OR_RAISE(auto stream, client.DoGet(call_options, endpoint.ticket));

const arrow::Result<std::shared_ptr<Schema>> &schema = stream->GetSchema();
const arrow::Result<std::shared_ptr<arrow::Schema>> &schema = stream->GetSchema();
ARROW_RETURN_NOT_OK(schema);

std::cout << "Schema:" << std::endl;
Expand All @@ -85,7 +75,7 @@ Status PrintResultsForEndpoint(FlightSqlClient &client,
int64_t num_rows = 0;

while (true) {
ARROW_ASSIGN_OR_RAISE(FlightStreamChunk chunk, stream->Next());
ARROW_ASSIGN_OR_RAISE(flight::FlightStreamChunk chunk, stream->Next());
if (chunk.data == nullptr) {
break;
}
Expand All @@ -98,9 +88,10 @@ Status PrintResultsForEndpoint(FlightSqlClient &client,
return Status::OK();
}

Status PrintResults(FlightSqlClient &client, const FlightCallOptions &call_options,
const std::unique_ptr<FlightInfo> &info) {
const std::vector<FlightEndpoint> &endpoints = info->endpoints();
Status PrintResults(flight::sql::FlightSqlClient &client,
const flight::FlightCallOptions &call_options,
const std::unique_ptr<flight::FlightInfo> &info) {
const std::vector<flight::FlightEndpoint> &endpoints = info->endpoints();

for (size_t i = 0; i < endpoints.size(); i++) {
std::cout << "Results from endpoint " << i + 1 << " of " << endpoints.size()
Expand All @@ -127,11 +118,12 @@ Status getPEMCertFileContents(const std::string &cert_file_path,

Status RunMain() {
ARROW_ASSIGN_OR_RAISE(auto location,
(FLAGS_use_tls) ? Location::ForGrpcTls(FLAGS_host, FLAGS_port)
: Location::ForGrpcTcp(FLAGS_host, FLAGS_port));
(FLAGS_use_tls)
? flight::Location::ForGrpcTls(FLAGS_host, FLAGS_port)
: flight::Location::ForGrpcTcp(FLAGS_host, FLAGS_port));

// Setup our options
arrow::flight::FlightClientOptions options;
flight::FlightClientOptions options;

if (!FLAGS_tls_roots.empty()) {
ARROW_RETURN_NOT_OK(getPEMCertFileContents(FLAGS_tls_roots, options.tls_root_certs));
Expand All @@ -152,19 +144,19 @@ Status RunMain() {
}
}

ARROW_ASSIGN_OR_RAISE(auto client, FlightClient::Connect(location, options));
ARROW_ASSIGN_OR_RAISE(auto client, flight::FlightClient::Connect(location, options));

FlightCallOptions call_options;
flight::FlightCallOptions call_options;

if (!FLAGS_username.empty() || !FLAGS_password.empty()) {
Result<std::pair<std::string, std::string>> bearer_result =
arrow::Result<std::pair<std::string, std::string>> bearer_result =
client->AuthenticateBasicToken({}, FLAGS_username, FLAGS_password);
ARROW_RETURN_NOT_OK(bearer_result);

call_options.headers.push_back(bearer_result.ValueOrDie());
}

FlightSqlClient sql_client(std::move(client));
flight::sql::FlightSqlClient sql_client(std::move(client));

if (FLAGS_command == "ExecuteUpdate") {
ARROW_ASSIGN_OR_RAISE(auto rows, sql_client.ExecuteUpdate(call_options, FLAGS_query));
Expand All @@ -174,7 +166,7 @@ Status RunMain() {
return Status::OK();
}

std::unique_ptr<FlightInfo> info;
std::unique_ptr<flight::FlightInfo> info;

if (FLAGS_command == "Execute") {
ARROW_ASSIGN_OR_RAISE(info, sql_client.Execute(call_options, FLAGS_query));
Expand Down Expand Up @@ -211,16 +203,16 @@ Status RunMain() {
info, sql_client.GetTables(call_options, &FLAGS_catalog, &FLAGS_schema,
&FLAGS_table, false, nullptr));
} else if (FLAGS_command == "GetExportedKeys") {
TableRef table_ref = {std::make_optional(FLAGS_catalog),
std::make_optional(FLAGS_schema), FLAGS_table};
flight::sql::TableRef table_ref = {std::make_optional(FLAGS_catalog),
std::make_optional(FLAGS_schema), FLAGS_table};
ARROW_ASSIGN_OR_RAISE(info, sql_client.GetExportedKeys(call_options, table_ref));
} else if (FLAGS_command == "GetImportedKeys") {
TableRef table_ref = {std::make_optional(FLAGS_catalog),
std::make_optional(FLAGS_schema), FLAGS_table};
flight::sql::TableRef table_ref = {std::make_optional(FLAGS_catalog),
std::make_optional(FLAGS_schema), FLAGS_table};
ARROW_ASSIGN_OR_RAISE(info, sql_client.GetImportedKeys(call_options, table_ref));
} else if (FLAGS_command == "GetPrimaryKeys") {
TableRef table_ref = {std::make_optional(FLAGS_catalog),
std::make_optional(FLAGS_schema), FLAGS_table};
flight::sql::TableRef table_ref = {std::make_optional(FLAGS_catalog),
std::make_optional(FLAGS_schema), FLAGS_table};
ARROW_ASSIGN_OR_RAISE(info, sql_client.GetPrimaryKeys(call_options, table_ref));
} else if (FLAGS_command == "GetSqlInfo") {
ARROW_ASSIGN_OR_RAISE(info, sql_client.GetSqlInfo(call_options, {}));
Expand All @@ -234,10 +226,12 @@ Status RunMain() {
return Status::OK();
}

} // namespace sqlflite

int main(int argc, char **argv) {
gflags::ParseCommandLineFlags(&argc, &argv, true);

Status st = RunMain();
Status st = sqlflite::RunMain();
if (!st.ok()) {
std::cerr << st << std::endl;
return 1;
Expand Down
1 change: 1 addition & 0 deletions src/flight_sql_server.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
#include <boost/program_options.hpp>

namespace po = boost::program_options;
namespace fs = std::filesystem;

int main(int argc, char **argv) {
std::vector<std::string> tls_token_values;
Expand Down
63 changes: 34 additions & 29 deletions src/library/flight_sql_library.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -34,9 +34,11 @@
#include "sqlite_server.h"
#include "duckdb_server.h"
#include "include/flight_sql_security.h"
#include "include/flight_sql_fwd.h"

namespace flight = arrow::flight;
namespace flightsql = arrow::flight::sql;
namespace fs = std::filesystem;

namespace sqlflite {

const int port = 31337;

Expand All @@ -58,49 +60,48 @@ const int port = 31337;
} \
} while (false)

arrow::Result<std::shared_ptr<arrow::flight::sql::FlightSqlServerBase>>
FlightSQLServerBuilder(const BackendType backend, const fs::path &database_filename,
const std::string &hostname, const int &port,
const std::string &username, const std::string &password,
const std::string &secret_key, const fs::path &tls_cert_path,
const fs::path &tls_key_path, const fs::path &mtls_ca_cert_path,
const std::string &init_sql_commands, const bool &print_queries) {
arrow::Result<std::shared_ptr<flight::sql::FlightSqlServerBase>> FlightSQLServerBuilder(
const BackendType backend, const fs::path &database_filename,
const std::string &hostname, const int &port, const std::string &username,
const std::string &password, const std::string &secret_key,
const fs::path &tls_cert_path, const fs::path &tls_key_path,
const fs::path &mtls_ca_cert_path, const std::string &init_sql_commands,
const bool &print_queries) {
ARROW_ASSIGN_OR_RAISE(auto location,
(!tls_cert_path.empty())
? arrow::flight::Location::ForGrpcTls(hostname, port)
: arrow::flight::Location::ForGrpcTcp(hostname, port));
? flight::Location::ForGrpcTls(hostname, port)
: flight::Location::ForGrpcTcp(hostname, port));

std::cout << "Apache Arrow version: " << ARROW_VERSION_STRING << std::endl;

arrow::flight::FlightServerOptions options(location);
flight::FlightServerOptions options(location);

if (!tls_cert_path.empty() && !tls_key_path.empty()) {
ARROW_CHECK_OK(arrow::flight::SecurityUtilities::FlightServerTlsCertificates(
ARROW_CHECK_OK(sqlflite::SecurityUtilities::FlightServerTlsCertificates(
tls_cert_path, tls_key_path, &options.tls_certificates));
} else {
std::cout << "WARNING - TLS is disabled for the Flight SQL server - this is insecure."
<< std::endl;
}

// Setup authentication middleware (using the same TLS certificate keypair)
auto header_middleware =
std::make_shared<arrow::flight::HeaderAuthServerMiddlewareFactory>(
username, password, secret_key);
auto header_middleware = std::make_shared<sqlflite::HeaderAuthServerMiddlewareFactory>(
username, password, secret_key);
auto bearer_middleware =
std::make_shared<arrow::flight::BearerAuthServerMiddlewareFactory>(secret_key);
std::make_shared<sqlflite::BearerAuthServerMiddlewareFactory>(secret_key);

options.auth_handler = std::make_unique<arrow::flight::NoOpAuthHandler>();
options.auth_handler = std::make_unique<flight::NoOpAuthHandler>();
options.middleware.push_back({"header-auth-server", header_middleware});
options.middleware.push_back({"bearer-auth-server", bearer_middleware});

if (!mtls_ca_cert_path.empty()) {
std::cout << "Using mTLS CA certificate: " << mtls_ca_cert_path << std::endl;
ARROW_CHECK_OK(arrow::flight::SecurityUtilities::FlightServerMtlsCACertificate(
ARROW_CHECK_OK(sqlflite::SecurityUtilities::FlightServerMtlsCACertificate(
mtls_ca_cert_path, &options.root_certificates));
options.verify_client = true;
}

std::shared_ptr<arrow::flight::sql::FlightSqlServerBase> server = nullptr;
std::shared_ptr<flight::sql::FlightSqlServerBase> server = nullptr;

std::string db_type = "";
if (backend == BackendType::sqlite) {
Expand Down Expand Up @@ -155,13 +156,12 @@ std::string SafeGetEnvVarValue(const std::string &env_var_name) {
}
}

arrow::Result<std::shared_ptr<arrow::flight::sql::FlightSqlServerBase>>
CreateFlightSQLServer(const BackendType backend, fs::path &database_filename,
std::string hostname, const int &port, std::string username,
std::string password, std::string secret_key,
fs::path tls_cert_path, fs::path tls_key_path,
fs::path mtls_ca_cert_path, std::string init_sql_commands,
fs::path init_sql_commands_file, const bool &print_queries) {
arrow::Result<std::shared_ptr<flight::sql::FlightSqlServerBase>> CreateFlightSQLServer(
const BackendType backend, fs::path &database_filename, std::string hostname,
const int &port, std::string username, std::string password, std::string secret_key,
fs::path tls_cert_path, fs::path tls_key_path, fs::path mtls_ca_cert_path,
std::string init_sql_commands, fs::path init_sql_commands_file,
const bool &print_queries) {
// Validate and default the arguments to env var values where applicable
if (database_filename.empty()) {
return arrow::Status::Invalid("The database filename was not provided!");
Expand Down Expand Up @@ -255,17 +255,21 @@ CreateFlightSQLServer(const BackendType backend, fs::path &database_filename,
}

arrow::Status StartFlightSQLServer(
std::shared_ptr<arrow::flight::sql::FlightSqlServerBase> server) {
std::shared_ptr<flight::sql::FlightSqlServerBase> server) {
return arrow::Status::OK();
}

} // namespace sqlflite

extern "C" {

int RunFlightSQLServer(const BackendType backend, fs::path &database_filename,
std::string hostname, const int &port, std::string username,
std::string password, std::string secret_key,
fs::path tls_cert_path, fs::path tls_key_path,
fs::path mtls_ca_cert_path, std::string init_sql_commands,
fs::path init_sql_commands_file, const bool &print_queries) {
auto create_server_result = CreateFlightSQLServer(
auto create_server_result = sqlflite::CreateFlightSQLServer(
backend, database_filename, hostname, port, username, password, secret_key,
tls_cert_path, tls_key_path, mtls_ca_cert_path, init_sql_commands,
init_sql_commands_file, print_queries);
Expand All @@ -281,3 +285,4 @@ int RunFlightSQLServer(const BackendType backend, fs::path &database_filename,
return EXIT_FAILURE;
}
}
}
Loading

0 comments on commit bfe0471

Please sign in to comment.