From dabefa3250731fbd577eb0700b7a793844c12040 Mon Sep 17 00:00:00 2001 From: Alex Light Date: Tue, 28 May 2024 17:17:54 -0700 Subject: [PATCH] Use the AOT data layout on AOT code We were always using the data-layout for the ORC-jit when running AOT code. These are currently the same but we really should be sure to use the layout the code was actually compiled with PiperOrigin-RevId: 638079658 --- xls/interpreter/BUILD | 2 ++ xls/interpreter/serial_proc_runtime_test.cc | 7 +++- xls/jit/BUILD | 18 +++++++--- xls/jit/aot_compiler_main.cc | 32 +++++++++-------- xls/jit/aot_entrypoint.proto | 3 ++ xls/jit/block_jit.cc | 9 +++-- xls/jit/function_base_jit.h | 2 ++ xls/jit/function_base_jit_wrapper.h | 17 +++++---- xls/jit/function_jit.cc | 17 ++++++--- xls/jit/function_jit.h | 2 +- xls/jit/function_jit_aot_test.cc | 23 +++++++------ xls/jit/jit_channel_queue.cc | 20 +++++------ xls/jit/jit_channel_queue.h | 14 ++++---- xls/jit/jit_channel_queue_benchmark.cc | 5 ++- xls/jit/jit_function_wrapper_cc.tmpl | 6 ++-- xls/jit/jit_proc_runtime.cc | 38 ++++++++++++++++----- xls/jit/jit_runtime.cc | 13 +++---- xls/jit/jit_runtime.h | 5 +-- xls/jit/proc_jit_test.cc | 9 +++-- xls/tools/BUILD | 1 + xls/tools/benchmark_main.cc | 18 ++++++---- 21 files changed, 163 insertions(+), 98 deletions(-) diff --git a/xls/interpreter/BUILD b/xls/interpreter/BUILD index fe23a3630d..73d3a13e20 100644 --- a/xls/interpreter/BUILD +++ b/xls/interpreter/BUILD @@ -448,6 +448,8 @@ cc_test( "//xls/ir:value", "//xls/jit:jit_channel_queue", "//xls/jit:jit_proc_runtime", + "//xls/jit:jit_runtime", + "//xls/jit:orc_jit", "//xls/jit:proc_jit", "@com_google_googletest//:gtest", ], diff --git a/xls/interpreter/serial_proc_runtime_test.cc b/xls/interpreter/serial_proc_runtime_test.cc index e372b19df0..c75d96464b 100644 --- a/xls/interpreter/serial_proc_runtime_test.cc +++ b/xls/interpreter/serial_proc_runtime_test.cc @@ -40,6 +40,8 @@ #include "xls/ir/value.h" #include "xls/jit/jit_channel_queue.h" #include "xls/jit/jit_proc_runtime.h" +#include "xls/jit/jit_runtime.h" +#include "xls/jit/orc_jit.h" #include "xls/jit/proc_jit.h" namespace xls { @@ -51,11 +53,14 @@ constexpr const char kIrAssertPath[] = "xls/interpreter/force_assert.ir"; // ProcJits. absl::StatusOr> CreateMixedSerialProcRuntime( ProcElaboration elaboration) { + XLS_ASSIGN_OR_RETURN(std::unique_ptr orc, OrcJit::Create()); + XLS_ASSIGN_OR_RETURN(auto data_layout, orc->CreateDataLayout()); // Create a queue manager for the queues. This factory verifies that there an // receive only queue for every receive only channel. XLS_ASSIGN_OR_RETURN( std::unique_ptr queue_manager, - JitChannelQueueManager::CreateThreadSafe(std::move(elaboration))); + JitChannelQueueManager::CreateThreadSafe( + std::move(elaboration), std::make_unique(data_layout))); // Create a ProcJit or a ProcInterpreter for each Proc. Alternate between the // two options. diff --git a/xls/jit/BUILD b/xls/jit/BUILD index e75cc0acde..e260608493 100644 --- a/xls/jit/BUILD +++ b/xls/jit/BUILD @@ -84,7 +84,6 @@ cc_binary( name = "aot_compiler_main", srcs = ["aot_compiler_main.cc"], deps = [ - ":aot_compiler", ":aot_entrypoint_cc_proto", ":function_base_jit", ":function_jit", @@ -450,6 +449,7 @@ cc_library( "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/types:span", + "@llvm-project//llvm:Support", "@llvm-project//llvm:ir_headers", ], ) @@ -590,7 +590,6 @@ cc_library( "//xls/interpreter:channel_queue", "//xls/ir", "//xls/ir:channel", - "//xls/ir:elaboration", "//xls/ir:proc_elaboration", "//xls/ir:value", "@com_google_absl//absl/base:core_headers", @@ -630,17 +629,15 @@ cc_library( hdrs = ["jit_runtime.h"], deps = [ ":llvm_type_converter", - ":orc_jit", "//xls/common:bits_util", "//xls/common:math_util", - "//xls/common/status:status_macros", "//xls/ir:bits", "//xls/ir:type", "//xls/ir:value", + "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/log", "@com_google_absl//absl/log:check", "@com_google_absl//absl/status", - "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/synchronization", "@com_google_absl//absl/types:span", @@ -765,6 +762,7 @@ cc_library( "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/types:span", + "@llvm-project//llvm:ir_headers", ], ) @@ -774,8 +772,11 @@ cc_test( deps = [ ":block_jit", ":jit_runtime", + ":orc_jit", "//xls/common:xls_gunit_main", "//xls/common/status:matchers", + "//xls/common/status:status_macros", + "//xls/interpreter:block_evaluator", "//xls/interpreter:block_evaluator_test_base", "//xls/ir", "//xls/ir:bits", @@ -783,8 +784,12 @@ cc_test( "//xls/ir:ir_test_base", "//xls/ir:value", "//xls/ir:value_view", + "@com_google_absl//absl/base", + "@com_google_absl//absl/base:no_destructor", + "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/types:span", "@com_google_googletest//:gtest", + "@llvm-project//llvm:ir_headers", ], ) @@ -883,6 +888,7 @@ cc_library( ":aot_entrypoint_cc_proto", ":function_base_jit", ":jit_channel_queue", + ":jit_runtime", ":llvm_compiler", ":proc_jit", "//xls/common/status:ret_check", @@ -900,6 +906,7 @@ cc_library( "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/types:span", + "@llvm-project//llvm:Support", "@llvm-project//llvm:Target", "@llvm-project//llvm:ir_headers", ], @@ -926,6 +933,7 @@ cc_binary( deps = [ ":jit_channel_queue", ":jit_runtime", + ":orc_jit", "//xls/ir", "//xls/ir:channel", "//xls/ir:channel_ops", diff --git a/xls/jit/aot_compiler_main.cc b/xls/jit/aot_compiler_main.cc index 205799e859..0070a99444 100644 --- a/xls/jit/aot_compiler_main.cc +++ b/xls/jit/aot_compiler_main.cc @@ -21,6 +21,7 @@ #include // NOLINT #include #include +#include #include #include #include @@ -32,6 +33,7 @@ #include "absl/status/statusor.h" #include "absl/strings/str_cat.h" #include "llvm/include/llvm/IR/DataLayout.h" +#include "llvm/include/llvm/IR/LLVMContext.h" #include "google/protobuf/text_format.h" #include "xls/common/file/filesystem.h" #include "xls/common/init_xls.h" @@ -42,7 +44,6 @@ #include "xls/ir/nodes.h" #include "xls/ir/package.h" #include "xls/ir/type.h" -#include "xls/jit/aot_compiler.h" #include "xls/jit/aot_entrypoint.pb.h" #include "xls/jit/function_base_jit.h" #include "xls/jit/function_jit.h" @@ -76,14 +77,8 @@ namespace { absl::StatusOr GenerateEntrypointProto( Package* package, FunctionBase* func, const JittedFunctionBase& object_code, - bool include_msan) { + bool include_msan, LlvmTypeConverter& type_converter) { AotEntrypointProto proto; - XLS_ASSIGN_OR_RETURN( - std::unique_ptr aot_compiler, - AotCompiler::Create(/*emit_msan=*/include_msan, /*opt_level=*/3)); - XLS_ASSIGN_OR_RETURN(llvm::DataLayout data_layout, - aot_compiler->CreateDataLayout()); - LlvmTypeConverter type_converter(aot_compiler->GetContext(), data_layout); proto.set_has_msan(include_msan); if (func->IsFunction()) { proto.set_type(AotEntrypointProto::FUNCTION); @@ -171,7 +166,7 @@ absl::Status RealMain(const std::string& input_ir_path, const std::string& top, } } - JitObjectCode object_code; + std::optional object_code; if (f->IsFunction()) { XLS_ASSIGN_OR_RETURN(object_code, FunctionJit::CreateObjectCode( f->AsFunctionOrDie(), @@ -191,12 +186,19 @@ absl::Status RealMain(const std::string& input_ir_path, const std::string& top, } AotPackageEntrypointsProto all_entrypoints; XLS_RETURN_IF_ERROR(SetFileContents( - output_object_path, std::string(object_code.object_code.begin(), - object_code.object_code.end()))); - for (const FunctionEntrypoint& oc : object_code.entrypoints) { - XLS_ASSIGN_OR_RETURN(*all_entrypoints.add_entrypoint(), - GenerateEntrypointProto(package.get(), oc.function, - oc.jit_info, include_msan)); + output_object_path, std::string(object_code->object_code.begin(), + object_code->object_code.end()))); + + *all_entrypoints.mutable_data_layout() = + object_code->data_layout.getStringRepresentation(); + + auto context = std::make_unique(); + LlvmTypeConverter type_converter(context.get(), object_code->data_layout); + for (const FunctionEntrypoint& oc : object_code->entrypoints) { + XLS_ASSIGN_OR_RETURN( + *all_entrypoints.add_entrypoint(), + GenerateEntrypointProto(package.get(), oc.function, oc.jit_info, + include_msan, type_converter)); } if (generate_textproto) { std::string text; diff --git a/xls/jit/aot_entrypoint.proto b/xls/jit/aot_entrypoint.proto index 8330260aed..cbc613ccfc 100644 --- a/xls/jit/aot_entrypoint.proto +++ b/xls/jit/aot_entrypoint.proto @@ -80,4 +80,7 @@ message AotEntrypointProto { // a list of all of the targets contained. message AotPackageEntrypointsProto { repeated AotEntrypointProto entrypoint = 1; + + // The LLVM DataLayout used in this compile. + optional string data_layout = 2; } diff --git a/xls/jit/block_jit.cc b/xls/jit/block_jit.cc index e5515e0781..b79b7882a8 100644 --- a/xls/jit/block_jit.cc +++ b/xls/jit/block_jit.cc @@ -58,10 +58,10 @@ absl::StatusOr> BlockJit::Create(Block* block) { return absl::UnimplementedError( "Jitting of blocks with instantiations is not yet supported."); } - XLS_ASSIGN_OR_RETURN(std::unique_ptr runtime, - JitRuntime::Create()); - return std::unique_ptr(new BlockJit( - block, std::move(runtime), std::move(orc_jit), std::move(function))); + XLS_ASSIGN_OR_RETURN(auto data_layout, orc_jit->CreateDataLayout()); + return std::unique_ptr( + new BlockJit(block, std::make_unique(data_layout), + std::move(orc_jit), std::move(function))); } std::unique_ptr BlockJit::NewContinuation() { @@ -351,7 +351,6 @@ absl::StatusOr JitBlockEvaluator::EvaluateBlock( XLS_RET_CHECK_EQ(elaboration.instances().size(), 1) << "StreamingJitBlockEvaluator does not support instantiations"; - XLS_ASSIGN_OR_RETURN(auto runtime, JitRuntime::Create()); XLS_ASSIGN_OR_RETURN(auto jit, BlockJit::Create(top_block)); auto continuation = jit->NewContinuation(); XLS_RETURN_IF_ERROR(continuation->SetInputPorts(inputs)); diff --git a/xls/jit/function_base_jit.h b/xls/jit/function_base_jit.h index a795b4a55c..91f0ac840a 100644 --- a/xls/jit/function_base_jit.h +++ b/xls/jit/function_base_jit.h @@ -26,6 +26,7 @@ #include "absl/container/flat_hash_map.h" #include "absl/status/statusor.h" #include "absl/types/span.h" +#include "llvm/include/llvm/IR/DataLayout.h" #include "xls/ir/events.h" #include "xls/ir/function.h" #include "xls/ir/function_base.h" @@ -301,6 +302,7 @@ struct FunctionEntrypoint { struct JitObjectCode { std::vector object_code; std::vector entrypoints; + llvm::DataLayout data_layout; }; } // namespace xls diff --git a/xls/jit/function_base_jit_wrapper.h b/xls/jit/function_base_jit_wrapper.h index 074a229ced..4f86270a36 100644 --- a/xls/jit/function_base_jit_wrapper.h +++ b/xls/jit/function_base_jit_wrapper.h @@ -62,21 +62,24 @@ class BaseFunctionJitWrapper { template static absl::StatusOr> Create( std::string_view ir_text, std::string_view function_name, - absl::Span aot_entrypoint_proto_bin, + absl::Span aot_entrypoints_proto_bin, JitFunctionType unpacked_entrypoint, JitFunctionType packed_entrypoint) requires(std::is_base_of_v) { XLS_ASSIGN_OR_RETURN(auto package, ParsePackage(ir_text, /*filename=*/std::nullopt)); XLS_ASSIGN_OR_RETURN(auto function, package->GetFunction(function_name)); - AotEntrypointProto proto; + AotPackageEntrypointsProto proto; // NB We could fallback to real jit here maybe? - XLS_RET_CHECK(proto.ParseFromArray(aot_entrypoint_proto_bin.data(), - aot_entrypoint_proto_bin.size())) + XLS_RET_CHECK(proto.ParseFromArray(aot_entrypoints_proto_bin.data(), + aot_entrypoints_proto_bin.size())) << "Unable to parse aot information."; - XLS_ASSIGN_OR_RETURN( - auto jit, FunctionJit::CreateFromAot( - function, proto, unpacked_entrypoint, packed_entrypoint)); + XLS_RET_CHECK_EQ(proto.entrypoint_size(), 1) + << "FunctionWrapper should only have a single XLS function compiled."; + XLS_ASSIGN_OR_RETURN(auto jit, + FunctionJit::CreateFromAot( + function, proto.entrypoint(0), proto.data_layout(), + unpacked_entrypoint, packed_entrypoint)); return std::unique_ptr( new RealType(std::move(package), std::move(jit), MatchesImplicitToken(function->GetType()->parameters()))); diff --git a/xls/jit/function_jit.cc b/xls/jit/function_jit.cc index 9a76efd154..326fb1c61b 100644 --- a/xls/jit/function_jit.cc +++ b/xls/jit/function_jit.cc @@ -16,7 +16,9 @@ #include #include +#include #include +#include #include #include @@ -28,6 +30,8 @@ #include "absl/strings/str_format.h" #include "absl/types/span.h" #include "llvm/include/llvm/IR/DataLayout.h" +#include "llvm/include/llvm/Support/Error.h" +#include "xls/common/status/ret_check.h" #include "xls/common/status/status_macros.h" #include "xls/ir/events.h" #include "xls/ir/function.h" @@ -38,7 +42,6 @@ #include "xls/ir/value_utils.h" #include "xls/jit/aot_compiler.h" #include "xls/jit/function_base_jit.h" -#include "xls/jit/jit_callbacks.h" #include "xls/jit/jit_runtime.h" #include "xls/jit/observer.h" #include "xls/jit/orc_jit.h" @@ -55,13 +58,17 @@ absl::StatusOr> FunctionJit::Create( /* static */ absl::StatusOr> FunctionJit::CreateFromAot(Function* xls_function, const AotEntrypointProto& entrypoint, + std::string_view data_layout, JitFunctionType function_unpacked, std::optional function_packed) { XLS_ASSIGN_OR_RETURN( JittedFunctionBase jfb, JittedFunctionBase::BuildFromAot(xls_function, entrypoint, function_unpacked, function_packed)); - XLS_ASSIGN_OR_RETURN(auto runtime, JitRuntime::Create()); + llvm::Expected layout = + llvm::DataLayout::parse(data_layout); + XLS_RET_CHECK(layout) << "Unable to parse '" << data_layout + << "' to an llvm data-layout."; // OrcJit is simply the arena that holds the JITed code. Since we are already // compiled theres no need to create and initialize it. // TODO(allight): Ideally we wouldn't even need to link in the llvm stuff if @@ -69,13 +76,14 @@ FunctionJit::CreateFromAot(Function* xls_function, // around some extra .so's isn't a huge deal. return std::unique_ptr( new FunctionJit(xls_function, std::unique_ptr(nullptr), - std::move(jfb), std::move(runtime))); + std::move(jfb), std::make_unique(*layout))); } absl::StatusOr FunctionJit::CreateObjectCode( Function* xls_function, int64_t opt_level, bool include_msan) { XLS_ASSIGN_OR_RETURN(std::unique_ptr comp, AotCompiler::Create(include_msan, opt_level)); + XLS_ASSIGN_OR_RETURN(llvm::DataLayout data_layout, comp->CreateDataLayout()); XLS_ASSIGN_OR_RETURN(JittedFunctionBase jfb, JittedFunctionBase::Build(xls_function, *comp)); XLS_ASSIGN_OR_RETURN(auto obj_code, std::move(comp)->GetObjectCode()); @@ -85,7 +93,8 @@ absl::StatusOr FunctionJit::CreateObjectCode( .function = xls_function, .jit_info = std::move(jfb), }, - }}; + }, + .data_layout = data_layout}; } absl::StatusOr> FunctionJit::CreateInternal( diff --git a/xls/jit/function_jit.h b/xls/jit/function_jit.h index 076d5f06c8..555f3022a6 100644 --- a/xls/jit/function_jit.h +++ b/xls/jit/function_jit.h @@ -21,7 +21,6 @@ #include #include #include -#include #include "absl/container/flat_hash_map.h" #include "absl/status/status.h" @@ -58,6 +57,7 @@ class FunctionJit { // function. static absl::StatusOr> CreateFromAot( Function* xls_function, const AotEntrypointProto& entrypoint, + std::string_view data_layout, JitFunctionType function_unpacked, std::optional function_packed = std::nullopt); diff --git a/xls/jit/function_jit_aot_test.cc b/xls/jit/function_jit_aot_test.cc index dfe81a1337..f119c93e47 100644 --- a/xls/jit/function_jit_aot_test.cc +++ b/xls/jit/function_jit_aot_test.cc @@ -79,24 +79,24 @@ static constexpr std::string_view kGoldTopName = static constexpr std::string_view kTestAotEntrypointsProto = "xls/jit/multi_function_aot.pb"; -absl::StatusOr GetEntrypointsProto() { +absl::StatusOr GetEntrypointsProto() { AotPackageEntrypointsProto proto; XLS_ASSIGN_OR_RETURN(std::filesystem::path path, GetXlsRunfilePath(kTestAotEntrypointsProto)); XLS_ASSIGN_OR_RETURN(std::string bin, GetFileContents(path)); XLS_RET_CHECK(proto.ParseFromString(bin)); XLS_RET_CHECK_EQ(proto.entrypoint_size(), 1); - return proto.entrypoint()[0]; + return proto; } bool AreSymbolsAsExpected() { auto v = GetEntrypointsProto(); if (!v.ok()) { return false; } - return v->has_function_symbol() && - v->function_symbol() == kExpectedSymbolNameUnpacked && - v->has_packed_function_symbol() && - v->packed_function_symbol() == kExpectedSymbolNamePacked; + return v->entrypoint(0).has_function_symbol() && + v->entrypoint(0).function_symbol() == kExpectedSymbolNameUnpacked && + v->entrypoint(0).has_packed_function_symbol() && + v->entrypoint(0).packed_function_symbol() == kExpectedSymbolNamePacked; } // Not really a test just to make sure that if all other tests are disabled due @@ -118,7 +118,8 @@ class FunctionJitAotTest : public testing::Test { }; TEST_F(FunctionJitAotTest, CallAot) { - XLS_ASSERT_OK_AND_ASSIGN(AotEntrypointProto proto, GetEntrypointsProto()); + XLS_ASSERT_OK_AND_ASSIGN(AotPackageEntrypointsProto proto, + GetEntrypointsProto()); XLS_ASSERT_OK_AND_ASSIGN(auto gold_file, GetXlsRunfilePath(kGoldIr)); XLS_ASSERT_OK_AND_ASSIGN(std::string pkg_text, GetFileContents(gold_file)); XLS_ASSERT_OK_AND_ASSIGN(auto p, ParsePackage(pkg_text, kGoldIr)); @@ -126,7 +127,8 @@ TEST_F(FunctionJitAotTest, CallAot) { XLS_ASSERT_OK_AND_ASSIGN( auto test_aot, FunctionJit::CreateFromAot( - f, proto, __multi_func_with_trace__multi_function_one, + f, proto.entrypoint(0), proto.data_layout(), + __multi_func_with_trace__multi_function_one, __multi_func_with_trace__multi_function_one_packed)); // Value { @@ -159,7 +161,8 @@ TEST_F(FunctionJitAotTest, CallAot) { } TEST_F(FunctionJitAotTest, InterceptCallAot) { - XLS_ASSERT_OK_AND_ASSIGN(AotEntrypointProto proto, GetEntrypointsProto()); + XLS_ASSERT_OK_AND_ASSIGN(AotPackageEntrypointsProto proto, + GetEntrypointsProto()); XLS_ASSERT_OK_AND_ASSIGN(auto gold_file, GetXlsRunfilePath(kGoldIr)); XLS_ASSERT_OK_AND_ASSIGN(std::string pkg_text, GetFileContents(gold_file)); XLS_ASSERT_OK_AND_ASSIGN(auto p, ParsePackage(pkg_text, kGoldIr)); @@ -174,7 +177,7 @@ TEST_F(FunctionJitAotTest, InterceptCallAot) { XLS_ASSERT_OK_AND_ASSIGN( auto test_aot, FunctionJit::CreateFromAot( - f, proto, + f, proto.entrypoint(0), proto.data_layout(), [](const uint8_t* const* inputs, uint8_t* const* outputs, void* temp_buffer, xls::InterpreterEvents* events, xls::InstanceContext* instance_context, diff --git a/xls/jit/jit_channel_queue.cc b/xls/jit/jit_channel_queue.cc index 8026e1196d..b690ef7f51 100644 --- a/xls/jit/jit_channel_queue.cc +++ b/xls/jit/jit_channel_queue.cc @@ -126,16 +126,16 @@ std::optional ThreadUnsafeJitChannelQueue::ReadInternal() { } /* static */ absl::StatusOr> -JitChannelQueueManager::CreateThreadSafe(Package* package) { +JitChannelQueueManager::CreateThreadSafe(Package* package, + std::unique_ptr runtime) { XLS_ASSIGN_OR_RETURN(ProcElaboration elaboration, ProcElaboration::ElaborateOldStylePackage(package)); - return CreateThreadSafe(std::move(elaboration)); + return CreateThreadSafe(std::move(elaboration), std::move(runtime)); } /* static */ absl::StatusOr> -JitChannelQueueManager::CreateThreadSafe(ProcElaboration&& elaboration) { - XLS_ASSIGN_OR_RETURN(std::unique_ptr runtime, - JitRuntime::Create()); +JitChannelQueueManager::CreateThreadSafe(ProcElaboration&& elaboration, + std::unique_ptr runtime) { std::vector> queues; for (ChannelInstance* channel_instance : elaboration.channel_instances()) { queues.push_back(std::make_unique( @@ -146,16 +146,16 @@ JitChannelQueueManager::CreateThreadSafe(ProcElaboration&& elaboration) { } /* static */ absl::StatusOr> -JitChannelQueueManager::CreateThreadUnsafe(Package* package) { +JitChannelQueueManager::CreateThreadUnsafe( + Package* package, std::unique_ptr runtime) { XLS_ASSIGN_OR_RETURN(ProcElaboration elaboration, ProcElaboration::ElaborateOldStylePackage(package)); - return CreateThreadUnsafe(std::move(elaboration)); + return CreateThreadUnsafe(std::move(elaboration), std::move(runtime)); } /* static */ absl::StatusOr> -JitChannelQueueManager::CreateThreadUnsafe(ProcElaboration&& elaboration) { - XLS_ASSIGN_OR_RETURN(std::unique_ptr runtime, - JitRuntime::Create()); +JitChannelQueueManager::CreateThreadUnsafe( + ProcElaboration&& elaboration, std::unique_ptr runtime) { std::vector> queues; for (ChannelInstance* channel_instance : elaboration.channel_instances()) { queues.push_back(std::make_unique( diff --git a/xls/jit/jit_channel_queue.h b/xls/jit/jit_channel_queue.h index 0832ccb4c1..699f498da2 100644 --- a/xls/jit/jit_channel_queue.h +++ b/xls/jit/jit_channel_queue.h @@ -27,8 +27,8 @@ #include "absl/synchronization/mutex.h" #include "xls/interpreter/channel_queue.h" #include "xls/ir/channel.h" -#include "xls/ir/elaboration.h" #include "xls/ir/package.h" +#include "xls/ir/proc_elaboration.h" #include "xls/ir/value.h" #include "xls/jit/jit_runtime.h" @@ -37,7 +37,7 @@ namespace xls { // A queue from which raw bytes may be written or read. class ByteQueue { public: - // `channel_element_size` is the granuality of the queue access. Each read or + // `channel_element_size` is the granularity of the queue access. Each read or // write to the queue handles this many bytes at a time. `is_single_value` // indicates whether this queue follows single-value channel semantics where // the queue only holds a single value; writes overwrite the value in the @@ -207,13 +207,15 @@ class JitChannelQueueManager : public ChannelQueueManager { // Factories which create a queue manager with exclusively ThreadSafe/Unsafe // queues. static absl::StatusOr> - CreateThreadSafe(Package* package); + CreateThreadSafe(Package* package, std::unique_ptr runtime); static absl::StatusOr> - CreateThreadSafe(ProcElaboration&& elaboration); + CreateThreadSafe(ProcElaboration&& elaboration, + std::unique_ptr runtime); static absl::StatusOr> - CreateThreadUnsafe(Package* package); + CreateThreadUnsafe(Package* package, std::unique_ptr runtime); static absl::StatusOr> - CreateThreadUnsafe(ProcElaboration&& elaboration); + CreateThreadUnsafe(ProcElaboration&& elaboration, + std::unique_ptr runtime); JitChannelQueue& GetJitQueue(Channel* channel); JitChannelQueue& GetJitQueue(ChannelInstance* channel_instance); diff --git a/xls/jit/jit_channel_queue_benchmark.cc b/xls/jit/jit_channel_queue_benchmark.cc index 37c047f59e..2ec26b266d 100644 --- a/xls/jit/jit_channel_queue_benchmark.cc +++ b/xls/jit/jit_channel_queue_benchmark.cc @@ -26,6 +26,7 @@ #include "xls/ir/proc_elaboration.h" #include "xls/jit/jit_channel_queue.h" #include "xls/jit/jit_runtime.h" +#include "xls/jit/orc_jit.h" namespace xls { namespace { @@ -40,7 +41,9 @@ static void BM_QueueWriteThenRead(benchmark::State& state) { int64_t element_size_bytes = state.range(0); Package package("benchmark"); - std::unique_ptr jit_runtime = JitRuntime::Create().value(); + auto orc_jit = OrcJit::Create().value(); + auto jit_runtime = + std::make_unique(orc_jit->CreateDataLayout().value()); Channel* channel = package .CreateStreamingChannel("my_channel", ChannelOps::kSendReceive, diff --git a/xls/jit/jit_function_wrapper_cc.tmpl b/xls/jit/jit_function_wrapper_cc.tmpl index 6c6c709f3c..553af86709 100644 --- a/xls/jit/jit_function_wrapper_cc.tmpl +++ b/xls/jit/jit_function_wrapper_cc.tmpl @@ -45,8 +45,8 @@ static constexpr char kIrText[] = // Bytes of the AOT entrypoint message: {{ str(wrapped.aot_entrypoint.entrypoint[0]).split("\n") | prefix_each("// ") | join("\n") }} -static constexpr std::array kAotEntrypointProtoBin = { - {{wrapped.aot_entrypoint.entrypoint[0].SerializeToString() | list | join(", ")}} +static constexpr std::array kAotEntrypointsProtoBin = { + {{wrapped.aot_entrypoint.SerializeToString() | list | join(", ")}} }; } // namespace @@ -55,7 +55,7 @@ absl::StatusOr> return xls::BaseFunctionJitWrapper::Create<{{wrapped.class_name}}>( kIrText, kFunctionName, - kAotEntrypointProtoBin, + kAotEntrypointsProtoBin, {{wrapped.aot_entrypoint.entrypoint[0].function_symbol}}, {{wrapped.aot_entrypoint.entrypoint[0].packed_function_symbol}}); } diff --git a/xls/jit/jit_proc_runtime.cc b/xls/jit/jit_proc_runtime.cc index df1a5353b3..2ab4268682 100644 --- a/xls/jit/jit_proc_runtime.cc +++ b/xls/jit/jit_proc_runtime.cc @@ -14,6 +14,7 @@ #include "xls/jit/jit_proc_runtime.h" +#include #include #include #include @@ -31,6 +32,7 @@ #include "llvm/include/llvm/IR/DataLayout.h" #include "llvm/include/llvm/IR/LLVMContext.h" #include "llvm/include/llvm/IR/Module.h" +#include "llvm/include/llvm/Support/Error.h" #include "llvm/include/llvm/Target/TargetMachine.h" #include "xls/common/status/ret_check.h" #include "xls/common/status/status_macros.h" @@ -45,6 +47,7 @@ #include "xls/jit/aot_entrypoint.pb.h" #include "xls/jit/function_base_jit.h" #include "xls/jit/jit_channel_queue.h" +#include "xls/jit/jit_runtime.h" #include "xls/jit/llvm_compiler.h" #include "xls/jit/proc_jit.h" @@ -70,7 +73,7 @@ class SharedCompiler final : public LlvmCompiler { public: explicit SharedCompiler(std::string_view name, AotCompiler* underlying, std::unique_ptr target, - llvm::DataLayout&& data_layout) + llvm::DataLayout data_layout) : LlvmCompiler(std::move(target), std::move(data_layout), underlying->opt_level(), underlying->include_msan()), underlying_(underlying), @@ -127,16 +130,22 @@ absl::StatusOr GetAotObjectCode(ProcElaboration elaboration, SharedCompiler sc(elaboration.top() ? elaboration.top()->GetName() : elaboration.procs().front()->package()->name(), - compiler.get(), std::move(target), std::move(layout)); - JitObjectCode joc; + compiler.get(), std::move(target), layout); + std::vector entrypoints; + entrypoints.reserve(elaboration.procs().size()); for (Proc* p : elaboration.procs()) { - joc.entrypoints.push_back({.function = p}); - XLS_ASSIGN_OR_RETURN(joc.entrypoints.back().jit_info, + entrypoints.push_back({.function = p}); + XLS_ASSIGN_OR_RETURN(entrypoints.back().jit_info, JittedFunctionBase::Build(p, sc)); } XLS_RETURN_IF_ERROR(compiler->CompileModule(std::move(sc).TakeModule())); - XLS_ASSIGN_OR_RETURN(joc.object_code, std::move(compiler)->GetObjectCode()); - return joc; + XLS_ASSIGN_OR_RETURN(std::vector object_code, + std::move(compiler)->GetObjectCode()); + return JitObjectCode{ + .object_code = std::move(object_code), + .entrypoints = std::move(entrypoints), + .data_layout = layout, + }; } namespace { @@ -177,11 +186,18 @@ absl::StatusOr> CreateAotRuntime( return procs_by_name.contains(p->name()) && procs_by_name[p->name()].proc == p; })) << "Elaboration has unknown procs"; + XLS_RET_CHECK(entrypoints.has_data_layout()) + << "Data layout required to create an aot runtime"; + llvm::Expected layout = + llvm::DataLayout::parse(entrypoints.data_layout()); + XLS_RET_CHECK(layout) << "Unable to parse '" << entrypoints.data_layout() + << "' to an llvm data-layout."; // Create a queue manager for the queues. This factory verifies that there an // receive only queue for every receive only channel. XLS_ASSIGN_OR_RETURN( std::unique_ptr queue_manager, - JitChannelQueueManager::CreateThreadSafe(std::move(elaboration))); + JitChannelQueueManager::CreateThreadSafe( + std::move(elaboration), std::make_unique(*layout))); // Create a ProcJit for each Proc. std::vector> proc_jits; for (const auto& [_, jit_args] : procs_by_name) { @@ -205,11 +221,15 @@ absl::StatusOr> CreateAotRuntime( absl::StatusOr> CreateRuntime( ProcElaboration elaboration) { + // We use the compiler to know the data layout. + XLS_ASSIGN_OR_RETURN(std::unique_ptr comp, OrcJit::Create()); + XLS_ASSIGN_OR_RETURN(llvm::DataLayout layout, comp->CreateDataLayout()); // Create a queue manager for the queues. This factory verifies that there an // receive only queue for every receive only channel. XLS_ASSIGN_OR_RETURN( std::unique_ptr queue_manager, - JitChannelQueueManager::CreateThreadSafe(std::move(elaboration))); + JitChannelQueueManager::CreateThreadSafe( + std::move(elaboration), std::make_unique(layout))); // Create a ProcJit for each Proc. std::vector> proc_jits; diff --git a/xls/jit/jit_runtime.cc b/xls/jit/jit_runtime.cc index 7d0b775cd2..5a57bd5ae6 100644 --- a/xls/jit/jit_runtime.cc +++ b/xls/jit/jit_runtime.cc @@ -26,16 +26,18 @@ #include "absl/strings/str_format.h" #include "absl/synchronization/mutex.h" #include "absl/types/span.h" +#include "llvm/include/llvm/IR/Constant.h" #include "llvm/include/llvm/IR/DataLayout.h" +#include "llvm/include/llvm/IR/DerivedTypes.h" +#include "llvm/include/llvm/IR/Type.h" #include "llvm/include/llvm/Support/Alignment.h" +#include "llvm/include/llvm/Support/Casting.h" #include "xls/common/bits_util.h" #include "xls/common/math_util.h" -#include "xls/common/status/status_macros.h" #include "xls/ir/bits.h" #include "xls/ir/type.h" #include "xls/ir/value.h" #include "xls/jit/llvm_type_converter.h" -#include "xls/jit/orc_jit.h" namespace xls { @@ -45,13 +47,6 @@ JitRuntime::JitRuntime(llvm::DataLayout data_layout) type_converter_( std::make_unique(context_.get(), data_layout_)) {} -/* static */ absl::StatusOr> JitRuntime::Create() { - XLS_ASSIGN_OR_RETURN(auto orc_jit, OrcJit::Create()); - XLS_ASSIGN_OR_RETURN(llvm::DataLayout data_layout, - orc_jit->CreateDataLayout()); - return std::make_unique(data_layout); -} - absl::Status JitRuntime::PackArgs(absl::Span args, absl::Span arg_types, absl::Span arg_buffers) { diff --git a/xls/jit/jit_runtime.h b/xls/jit/jit_runtime.h index e533d6b281..a938627103 100644 --- a/xls/jit/jit_runtime.h +++ b/xls/jit/jit_runtime.h @@ -15,11 +15,13 @@ #ifndef XLS_JIT_JIT_RUNTIME_H_ #define XLS_JIT_JIT_RUNTIME_H_ +#include #include #include +#include "absl/base/thread_annotations.h" #include "absl/status/status.h" -#include "absl/status/statusor.h" +#include "absl/synchronization/mutex.h" #include "absl/types/span.h" #include "llvm/include/llvm/IR/DataLayout.h" #include "xls/ir/type.h" @@ -34,7 +36,6 @@ namespace xls { class JitRuntime { public: explicit JitRuntime(llvm::DataLayout data_layout); - static absl::StatusOr> Create(); // Packs the specified values into a flat buffer with the data layout // expected by LLVM. diff --git a/xls/jit/proc_jit_test.cc b/xls/jit/proc_jit_test.cc index c685fd0b7c..1b338320c2 100644 --- a/xls/jit/proc_jit_test.cc +++ b/xls/jit/proc_jit_test.cc @@ -32,8 +32,8 @@ namespace { JitRuntime* GetJitRuntime() { static auto orc_jit = OrcJit::Create().value(); - static auto jit_runtime = std::make_unique( - orc_jit->CreateDataLayout().value()); + static auto jit_runtime = + std::make_unique(orc_jit->CreateDataLayout().value()); return jit_runtime.get(); } @@ -50,7 +50,10 @@ INSTANTIATE_TEST_SUITE_P( .value(); }, [](Package* package) -> std::unique_ptr { - return JitChannelQueueManager::CreateThreadSafe(package).value(); + return JitChannelQueueManager::CreateThreadSafe( + package, std::make_unique( + GetJitRuntime()->data_layout())) + .value(); }))); } // namespace diff --git a/xls/tools/BUILD b/xls/tools/BUILD index c2fea81714..7077bdfc40 100644 --- a/xls/tools/BUILD +++ b/xls/tools/BUILD @@ -1229,6 +1229,7 @@ cc_binary( "//xls/jit:function_jit", "//xls/jit:jit_channel_queue", "//xls/jit:jit_runtime", + "//xls/jit:orc_jit", "//xls/jit:proc_jit", "//xls/passes:bdd_function", "//xls/passes:bdd_query_engine", diff --git a/xls/tools/benchmark_main.cc b/xls/tools/benchmark_main.cc index 4f1927e204..0e06b2bc33 100644 --- a/xls/tools/benchmark_main.cc +++ b/xls/tools/benchmark_main.cc @@ -72,6 +72,7 @@ #include "xls/jit/function_jit.h" #include "xls/jit/jit_channel_queue.h" #include "xls/jit/jit_runtime.h" +#include "xls/jit/orc_jit.h" #include "xls/jit/proc_jit.h" #include "xls/passes/bdd_function.h" #include "xls/passes/bdd_query_engine.h" @@ -605,10 +606,7 @@ absl::Status RunBlockInterpreterAndJit(Block* block, std::string_view description, Rng& rng_engine) { absl::Time start_jit_compile = absl::Now(); - XLS_ASSIGN_OR_RETURN(std::unique_ptr runtime, - JitRuntime::Create()); - XLS_ASSIGN_OR_RETURN(std::unique_ptr jit, - BlockJit::Create(block)); + XLS_ASSIGN_OR_RETURN(std::unique_ptr jit, BlockJit::Create(block)); std::cout << absl::StreamFormat( "JIT compile time (%s): %dms\n", description, DurationToMs(absl::Now() - start_jit_compile)); @@ -634,8 +632,9 @@ absl::Status RunBlockInterpreterAndJit(Block* block, input_types.reserve(block->GetInputPorts().size()); absl::c_transform(block->GetInputPorts(), std::back_inserter(input_types), [](InputPort* v) { return v->GetType(); }); - XLS_ASSIGN_OR_RETURN(auto jit_args, ConvertToJitArguments( - arg_set, input_types, runtime.get())); + XLS_ASSIGN_OR_RETURN( + auto jit_args, + ConvertToJitArguments(arg_set, input_types, jit->runtime())); auto [jit_arg_buffers, jit_arg_pointers] = std::move(jit_args); // The JIT is much faster so run many times. @@ -680,9 +679,14 @@ template absl::Status RunProcInterpreterAndJit(Proc* proc, std::string_view description, Rng& rng_engine) { absl::Time start_jit_compile = absl::Now(); + // Technically the creation cost to get the data layout is amortized over all + // the procs in the elaboration. + XLS_ASSIGN_OR_RETURN(std::unique_ptr layout_source, OrcJit::Create()); + XLS_ASSIGN_OR_RETURN(auto data_layout, layout_source->CreateDataLayout()); XLS_ASSIGN_OR_RETURN( std::unique_ptr queue_manager, - JitChannelQueueManager::CreateThreadSafe(proc->package())); + JitChannelQueueManager::CreateThreadSafe( + proc->package(), std::make_unique(data_layout))); XLS_ASSIGN_OR_RETURN( std::unique_ptr jit, ProcJit::Create(proc, &queue_manager->runtime(), queue_manager.get()));