Skip to content

Commit

Permalink
Use AOT compiled JIT code for several proc tests
Browse files Browse the repository at this point in the history
This speeds up proc-testing significantly.

Fixes: #1403
PiperOrigin-RevId: 638080318
  • Loading branch information
allight authored and copybara-github committed May 29, 2024
1 parent dabefa3 commit 70de428
Show file tree
Hide file tree
Showing 5 changed files with 86 additions and 69 deletions.
5 changes: 5 additions & 0 deletions xls/jit/jit_proc_wrapper_h.tmpl
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,11 @@ namespace {{ wrapped.namespace }} {

class {{ wrapped.class_name }} final : public xls::BaseProcJitWrapper {
public:
static std::tuple<std::unique_ptr<Package>, std::unique_ptr<ProcRuntime>>
TakeRuntime(std::unique_ptr<{{ wrapped.class_name }}> r) {
return BaseProcJitWrapper::TakeRuntimeBase(std::move(r));
}

static absl::StatusOr<std::unique_ptr<{{ wrapped.class_name }}>> Create();

{% for chan in wrapped.incoming_channels %}
Expand Down
14 changes: 13 additions & 1 deletion xls/jit/proc_base_jit_wrapper.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
#include <optional>
#include <string>
#include <string_view>
#include <tuple>
#include <type_traits>
#include <utility>
#include <vector>
Expand Down Expand Up @@ -77,7 +78,6 @@ class BaseProcJitWrapper {

XLS_ASSIGN_OR_RETURN(auto* man, aot->GetJitChannelQueueManager());
JitRuntime& runtime = man->runtime();
aot->ResetState();

return std::unique_ptr<RealType>(
new RealType(std::move(package), proc, std::move(aot), runtime));
Expand Down Expand Up @@ -146,6 +146,12 @@ class BaseProcJitWrapper {
runtime_(std::move(runtime)),
jit_runtime_(jit_runtime) {}

static std::tuple<std::unique_ptr<Package>, std::unique_ptr<ProcRuntime>>
TakeRuntimeBase(std::unique_ptr<BaseProcJitWrapper> w) {
return w->DoTakeRuntime();
}


template <typename PackedView>
absl::Status SendToChannelPacked(std::string_view chan_name,
PackedView view) {
Expand All @@ -171,6 +177,12 @@ class BaseProcJitWrapper {
Proc* proc_;
std::unique_ptr<ProcRuntime> runtime_;
JitRuntime& jit_runtime_;

private:
std::tuple<std::unique_ptr<Package>, std::unique_ptr<ProcRuntime>>
DoTakeRuntime() {
return {std::move(package_), std::move(runtime_)};
}
};

} // namespace xls
Expand Down
46 changes: 30 additions & 16 deletions xls/modules/aes/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@

load(
"//xls/build_rules:xls_build_defs.bzl",
"PROC_WRAPPER_TYPE",
"cc_xls_ir_jit_wrapper",
"xls_benchmark_ir",
"xls_dslx_ir",
"xls_dslx_library",
Expand Down Expand Up @@ -67,23 +69,30 @@ xls_dslx_ir(
library = ":aes_ctr_dslx",
)

cc_xls_ir_jit_wrapper(
name = "aes_ctr_wrapper",
src = ":aes_ctr.ir",
jit_wrapper_args = {
"class_name": "AesCtr",
"namespace": "xls::aes::wrapped",
},
wrapper_type = PROC_WRAPPER_TYPE,
)

cc_test(
name = "aes_ctr_test",
srcs = ["aes_ctr_test.cc"],
data = [
":aes_ctr.ir",
],
deps = [
":aes_ctr_wrapper",
":aes_test_common",
"//xls/common:exit_status",
"//xls/common:init_xls",
"//xls/common/file:filesystem",
"//xls/common/file:get_runfile_path",
"//xls/common/status:status_macros",
"//xls/interpreter:proc_runtime",
"//xls/ir",
"//xls/ir:channel",
"//xls/ir:events",
"//xls/ir:ir_parser",
"//xls/jit:jit_channel_queue",
"//xls/jit:jit_proc_runtime",
"@boringssl//:crypto",
"@com_google_absl//absl/flags:flag",
"@com_google_absl//absl/log",
Expand All @@ -92,7 +101,6 @@ cc_test(
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/strings:str_format",
"@com_google_absl//absl/time",
],
)
Expand Down Expand Up @@ -156,25 +164,31 @@ xls_dslx_ir(
library = ":aes_gcm_dslx",
)

cc_xls_ir_jit_wrapper(
name = "aes_gcm_wrapper",
src = ":aes_gcm.ir",
jit_wrapper_args = {
"class_name": "AesGcm",
"namespace": "xls::aes::wrapped",
},
wrapper_type = PROC_WRAPPER_TYPE,
)

cc_test(
name = "aes_gcm_test",
srcs = ["aes_gcm_test.cc"],
data = [
":aes_gcm.ir",
],
tags = ["optonly"],
deps = [
":aes_gcm_wrapper",
":aes_test_common",
"//xls/common:exit_status",
"//xls/common:init_xls",
"//xls/common/file:filesystem",
"//xls/common/file:get_runfile_path",
"//xls/common/status:status_macros",
"//xls/interpreter:serial_proc_runtime",
"//xls/interpreter:proc_runtime",
"//xls/ir",
"//xls/ir:bits",
"//xls/ir:ir_parser",
"//xls/ir:channel",
"//xls/ir:value",
"//xls/jit:jit_proc_runtime",
"@boringssl//:crypto",
"@com_google_absl//absl/cleanup",
"@com_google_absl//absl/flags:flag",
Expand Down
48 changes: 20 additions & 28 deletions xls/modules/aes/aes_ctr_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,12 @@

#include <array>
#include <cstdint>
#include <filesystem> // NOLINT
#include <cstring>
#include <iostream>
#include <memory>
#include <string>
#include <string_view>
#include <tuple>
#include <utility>
#include <vector>

#include "absl/flags/flag.h"
Expand All @@ -31,21 +32,21 @@
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/str_cat.h"
#include "absl/time/clock.h"
#include "absl/time/time.h"
#include "openssl/aes.h"
#include "xls/common/exit_status.h"
#include "xls/common/file/filesystem.h"
#include "xls/common/file/get_runfile_path.h"
#include "xls/common/init_xls.h"
#include "xls/common/status/status_macros.h"
#include "xls/interpreter/proc_runtime.h"
#include "xls/ir/channel.h"
#include "xls/ir/events.h"
#include "xls/ir/ir_parser.h"
#include "xls/ir/package.h"
#include "xls/ir/proc.h"
#include "xls/jit/jit_channel_queue.h"
#include "xls/jit/jit_proc_runtime.h"
#include "xls/modules/aes/aes_ctr_wrapper.h"
#include "xls/modules/aes/aes_test_common.h"

constexpr std::string_view kEncrypterIrPath = "xls/modules/aes/aes_ctr.ir";

ABSL_FLAG(int32_t, num_samples, 1000,
"The number of (randomly-generated) blocks to test.");
ABSL_FLAG(bool, print_traces, false,
Expand All @@ -65,7 +66,7 @@ struct SampleData {
struct JitData {
std::unique_ptr<Package> package;
Proc* proc;
std::unique_ptr<SerialProcRuntime> proc_runtime;
std::unique_ptr<ProcRuntime> proc_runtime;
};

// In the DSLX, the IV is treated as a uN[96], so we potentially need to swap
Expand Down Expand Up @@ -267,23 +268,15 @@ static absl::StatusOr<bool> RunSample(JitData* jit_data,
return true;
}

static absl::StatusOr<JitData> CreateProcJit(std::string_view ir_path) {
JitData jit_data;

XLS_ASSIGN_OR_RETURN(std::filesystem::path full_ir_path,
GetXlsRunfilePath(ir_path));
XLS_ASSIGN_OR_RETURN(std::string ir_text, GetFileContents(full_ir_path));
VLOG(1) << "Parsing IR.";
XLS_ASSIGN_OR_RETURN(jit_data.package, Parser::ParsePackage(ir_text));
XLS_ASSIGN_OR_RETURN(jit_data.proc,
jit_data.package->GetProc("__aes_ctr__aes_ctr_0_next"));

VLOG(1) << "JIT compiling.";
XLS_ASSIGN_OR_RETURN(jit_data.proc_runtime,
CreateJitSerialProcRuntime(jit_data.package.get()));
VLOG(1) << "Created JIT!";

return jit_data;
static absl::StatusOr<JitData> CreateProcJit() {
XLS_ASSIGN_OR_RETURN(std::unique_ptr<wrapped::AesCtr> ctr,
wrapped::AesCtr::Create());
auto [package, runtime] = wrapped::AesCtr::TakeRuntime(std::move(ctr));
XLS_ASSIGN_OR_RETURN(Proc * proc,
package->GetProc("__aes_ctr__aes_ctr_0_next"));
return JitData{.package = std::move(package),
.proc = proc,
.proc_runtime = std::move(runtime)};
}

static absl::Status RunTest(int32_t num_samples, int32_t key_bits) {
Expand All @@ -292,8 +285,7 @@ static absl::Status RunTest(int32_t num_samples, int32_t key_bits) {
sample_data.key_bytes = key_bytes;
memset(sample_data.iv.data(), 0, sizeof(sample_data.iv));

XLS_ASSIGN_OR_RETURN(JitData encrypt_jit_data,
CreateProcJit(kEncrypterIrPath));
XLS_ASSIGN_OR_RETURN(JitData encrypt_jit_data, CreateProcJit());

absl::BitGen bitgen;
absl::Duration xls_encrypt_dur;
Expand Down
42 changes: 18 additions & 24 deletions xls/modules/aes/aes_gcm_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,12 @@

// Test of the XLS GCM mode implementation against a reference (in this
// case, BoringSSL's implementation).
#include <cstddef>
#include <cstdint>
#include <filesystem>
#include <cstring>
#include <iostream>
#include <memory>
#include <optional>
#include <string>
#include <string_view>
#include <utility>
#include <vector>
Expand All @@ -32,16 +32,16 @@
#include "absl/status/statusor.h"
#include "absl/strings/str_cat.h"
#include "openssl/aead.h"
#include "openssl/base.h"
#include "xls/common/exit_status.h"
#include "xls/common/file/filesystem.h"
#include "xls/common/file/get_runfile_path.h"
#include "xls/common/init_xls.h"
#include "xls/common/status/status_macros.h"
#include "xls/interpreter/serial_proc_runtime.h"
#include "xls/interpreter/proc_runtime.h"
#include "xls/ir/bits.h"
#include "xls/ir/ir_parser.h"
#include "xls/ir/channel.h"
#include "xls/ir/package.h"
#include "xls/ir/value.h"
#include "xls/jit/jit_proc_runtime.h"
#include "xls/modules/aes/aes_gcm_wrapper.h"
#include "xls/modules/aes/aes_test_common.h"

// TODO(rspringer): This is a bit slow. Seems like we should be able to compute
Expand All @@ -58,14 +58,13 @@ constexpr int kMaxPtxtBlocks = 128;
constexpr int kTagBits = 128;
constexpr int kTagBytes = kTagBits / 8;

constexpr std::string_view kIrPath = "xls/modules/aes/aes_gcm.ir";
constexpr std::string_view kCmdChannelName = "aes_gcm__command_in";
constexpr std::string_view kDataInChannelName = "aes_gcm__data_r";
constexpr std::string_view kDataOutChannelName = "aes_gcm__data_s";

struct JitData {
std::unique_ptr<Package> package;
std::unique_ptr<SerialProcRuntime> runtime;
std::unique_ptr<ProcRuntime> runtime;
};

struct SampleData {
Expand Down Expand Up @@ -107,7 +106,7 @@ static absl::StatusOr<Result> XlsEncrypt(JitData* jit_data,
bool encrypt) {
// Create (and send) the initial command.
Package* package = jit_data->package.get();
SerialProcRuntime* runtime = jit_data->runtime.get();
ProcRuntime* runtime = jit_data->runtime.get();
XLS_ASSIGN_OR_RETURN(Channel * cmd_channel,
package->GetChannel(kCmdChannelName));
XLS_ASSIGN_OR_RETURN(Value command, CreateCommandValue(sample_data, encrypt));
Expand Down Expand Up @@ -161,11 +160,11 @@ static absl::StatusOr<Result> ReferenceEncrypt(const SampleData& sample) {
int num_aad_blocks = sample.aad.size();
size_t max_result_size;
if (sample.key_bits == 128) {
max_result_size = num_ptxt_blocks * kBlockBytes +
EVP_AEAD_max_overhead(EVP_aead_aes_128_gcm());
max_result_size = num_ptxt_blocks * kBlockBytes +
EVP_AEAD_max_overhead(EVP_aead_aes_128_gcm());
} else {
max_result_size = num_ptxt_blocks * kBlockBytes +
EVP_AEAD_max_overhead(EVP_aead_aes_256_gcm());
max_result_size = num_ptxt_blocks * kBlockBytes +
EVP_AEAD_max_overhead(EVP_aead_aes_256_gcm());
}

auto ptxt_buffer = std::make_unique<uint8_t[]>(num_ptxt_blocks * kBlockBytes);
Expand Down Expand Up @@ -302,18 +301,13 @@ static absl::StatusOr<bool> RunSample(JitData* jit_data,
}

static absl::StatusOr<JitData> CreateJitData() {
XLS_ASSIGN_OR_RETURN(std::filesystem::path full_ir_path,
GetXlsRunfilePath(kIrPath));
XLS_ASSIGN_OR_RETURN(std::string ir_text, GetFileContents(full_ir_path));
XLS_ASSIGN_OR_RETURN(std::unique_ptr<Package> package,
Parser::ParsePackage(ir_text));
XLS_ASSIGN_OR_RETURN((std::unique_ptr<wrapped::AesGcm> aes_gcm),
wrapped::AesGcm::Create());
auto [package, runtime] = wrapped::AesGcm::TakeRuntime(std::move(aes_gcm));
return JitData{.package = std::move(package), .runtime = std::move(runtime)};
}

XLS_ASSIGN_OR_RETURN(std::unique_ptr<SerialProcRuntime> runtime,
CreateJitSerialProcRuntime(package.get()));

JitData jit_data{std::move(package), std::move(runtime)};
return jit_data;
}

static absl::Status RunTest(int num_samples, int key_bits) {
int key_bytes = key_bits / 8;
Expand Down

0 comments on commit 70de428

Please sign in to comment.