Skip to content

Commit

Permalink
Use the AOT data layout on AOT code
Browse files Browse the repository at this point in the history
We were always using the data-layout for the ORC-jit when running AOT code. These are currently the same but we really should be sure to use the layout the code was actually compiled with

PiperOrigin-RevId: 638079658
  • Loading branch information
allight authored and copybara-github committed May 29, 2024
1 parent 73c5545 commit dabefa3
Show file tree
Hide file tree
Showing 21 changed files with 163 additions and 98 deletions.
2 changes: 2 additions & 0 deletions xls/interpreter/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -448,6 +448,8 @@ cc_test(
"//xls/ir:value",
"//xls/jit:jit_channel_queue",
"//xls/jit:jit_proc_runtime",
"//xls/jit:jit_runtime",
"//xls/jit:orc_jit",
"//xls/jit:proc_jit",
"@com_google_googletest//:gtest",
],
Expand Down
7 changes: 6 additions & 1 deletion xls/interpreter/serial_proc_runtime_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,8 @@
#include "xls/ir/value.h"
#include "xls/jit/jit_channel_queue.h"
#include "xls/jit/jit_proc_runtime.h"
#include "xls/jit/jit_runtime.h"
#include "xls/jit/orc_jit.h"
#include "xls/jit/proc_jit.h"

namespace xls {
Expand All @@ -51,11 +53,14 @@ constexpr const char kIrAssertPath[] = "xls/interpreter/force_assert.ir";
// ProcJits.
absl::StatusOr<std::unique_ptr<SerialProcRuntime>> CreateMixedSerialProcRuntime(
ProcElaboration elaboration) {
XLS_ASSIGN_OR_RETURN(std::unique_ptr<OrcJit> orc, OrcJit::Create());
XLS_ASSIGN_OR_RETURN(auto data_layout, orc->CreateDataLayout());
// Create a queue manager for the queues. This factory verifies that there an
// receive only queue for every receive only channel.
XLS_ASSIGN_OR_RETURN(
std::unique_ptr<JitChannelQueueManager> queue_manager,
JitChannelQueueManager::CreateThreadSafe(std::move(elaboration)));
JitChannelQueueManager::CreateThreadSafe(
std::move(elaboration), std::make_unique<JitRuntime>(data_layout)));

// Create a ProcJit or a ProcInterpreter for each Proc. Alternate between the
// two options.
Expand Down
18 changes: 13 additions & 5 deletions xls/jit/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,6 @@ cc_binary(
name = "aot_compiler_main",
srcs = ["aot_compiler_main.cc"],
deps = [
":aot_compiler",
":aot_entrypoint_cc_proto",
":function_base_jit",
":function_jit",
Expand Down Expand Up @@ -450,6 +449,7 @@ cc_library(
"@com_google_absl//absl/strings",
"@com_google_absl//absl/strings:str_format",
"@com_google_absl//absl/types:span",
"@llvm-project//llvm:Support",
"@llvm-project//llvm:ir_headers",
],
)
Expand Down Expand Up @@ -590,7 +590,6 @@ cc_library(
"//xls/interpreter:channel_queue",
"//xls/ir",
"//xls/ir:channel",
"//xls/ir:elaboration",
"//xls/ir:proc_elaboration",
"//xls/ir:value",
"@com_google_absl//absl/base:core_headers",
Expand Down Expand Up @@ -630,17 +629,15 @@ cc_library(
hdrs = ["jit_runtime.h"],
deps = [
":llvm_type_converter",
":orc_jit",
"//xls/common:bits_util",
"//xls/common:math_util",
"//xls/common/status:status_macros",
"//xls/ir:bits",
"//xls/ir:type",
"//xls/ir:value",
"@com_google_absl//absl/base:core_headers",
"@com_google_absl//absl/log",
"@com_google_absl//absl/log:check",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings:str_format",
"@com_google_absl//absl/synchronization",
"@com_google_absl//absl/types:span",
Expand Down Expand Up @@ -765,6 +762,7 @@ cc_library(
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings:str_format",
"@com_google_absl//absl/types:span",
"@llvm-project//llvm:ir_headers",
],
)

Expand All @@ -774,17 +772,24 @@ cc_test(
deps = [
":block_jit",
":jit_runtime",
":orc_jit",
"//xls/common:xls_gunit_main",
"//xls/common/status:matchers",
"//xls/common/status:status_macros",
"//xls/interpreter:block_evaluator",
"//xls/interpreter:block_evaluator_test_base",
"//xls/ir",
"//xls/ir:bits",
"//xls/ir:function_builder",
"//xls/ir:ir_test_base",
"//xls/ir:value",
"//xls/ir:value_view",
"@com_google_absl//absl/base",
"@com_google_absl//absl/base:no_destructor",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/types:span",
"@com_google_googletest//:gtest",
"@llvm-project//llvm:ir_headers",
],
)

Expand Down Expand Up @@ -883,6 +888,7 @@ cc_library(
":aot_entrypoint_cc_proto",
":function_base_jit",
":jit_channel_queue",
":jit_runtime",
":llvm_compiler",
":proc_jit",
"//xls/common/status:ret_check",
Expand All @@ -900,6 +906,7 @@ cc_library(
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings:str_format",
"@com_google_absl//absl/types:span",
"@llvm-project//llvm:Support",
"@llvm-project//llvm:Target",
"@llvm-project//llvm:ir_headers",
],
Expand All @@ -926,6 +933,7 @@ cc_binary(
deps = [
":jit_channel_queue",
":jit_runtime",
":orc_jit",
"//xls/ir",
"//xls/ir:channel",
"//xls/ir:channel_ops",
Expand Down
32 changes: 17 additions & 15 deletions xls/jit/aot_compiler_main.cc
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
#include <filesystem> // NOLINT
#include <iostream>
#include <memory>
#include <optional>
#include <string>
#include <string_view>
#include <vector>
Expand All @@ -32,6 +33,7 @@
#include "absl/status/statusor.h"
#include "absl/strings/str_cat.h"
#include "llvm/include/llvm/IR/DataLayout.h"
#include "llvm/include/llvm/IR/LLVMContext.h"
#include "google/protobuf/text_format.h"
#include "xls/common/file/filesystem.h"
#include "xls/common/init_xls.h"
Expand All @@ -42,7 +44,6 @@
#include "xls/ir/nodes.h"
#include "xls/ir/package.h"
#include "xls/ir/type.h"
#include "xls/jit/aot_compiler.h"
#include "xls/jit/aot_entrypoint.pb.h"
#include "xls/jit/function_base_jit.h"
#include "xls/jit/function_jit.h"
Expand Down Expand Up @@ -76,14 +77,8 @@ namespace {

absl::StatusOr<AotEntrypointProto> GenerateEntrypointProto(
Package* package, FunctionBase* func, const JittedFunctionBase& object_code,
bool include_msan) {
bool include_msan, LlvmTypeConverter& type_converter) {
AotEntrypointProto proto;
XLS_ASSIGN_OR_RETURN(
std::unique_ptr<AotCompiler> aot_compiler,
AotCompiler::Create(/*emit_msan=*/include_msan, /*opt_level=*/3));
XLS_ASSIGN_OR_RETURN(llvm::DataLayout data_layout,
aot_compiler->CreateDataLayout());
LlvmTypeConverter type_converter(aot_compiler->GetContext(), data_layout);
proto.set_has_msan(include_msan);
if (func->IsFunction()) {
proto.set_type(AotEntrypointProto::FUNCTION);
Expand Down Expand Up @@ -171,7 +166,7 @@ absl::Status RealMain(const std::string& input_ir_path, const std::string& top,
}
}

JitObjectCode object_code;
std::optional<JitObjectCode> object_code;
if (f->IsFunction()) {
XLS_ASSIGN_OR_RETURN(object_code, FunctionJit::CreateObjectCode(
f->AsFunctionOrDie(),
Expand All @@ -191,12 +186,19 @@ absl::Status RealMain(const std::string& input_ir_path, const std::string& top,
}
AotPackageEntrypointsProto all_entrypoints;
XLS_RETURN_IF_ERROR(SetFileContents(
output_object_path, std::string(object_code.object_code.begin(),
object_code.object_code.end())));
for (const FunctionEntrypoint& oc : object_code.entrypoints) {
XLS_ASSIGN_OR_RETURN(*all_entrypoints.add_entrypoint(),
GenerateEntrypointProto(package.get(), oc.function,
oc.jit_info, include_msan));
output_object_path, std::string(object_code->object_code.begin(),
object_code->object_code.end())));

*all_entrypoints.mutable_data_layout() =
object_code->data_layout.getStringRepresentation();

auto context = std::make_unique<llvm::LLVMContext>();
LlvmTypeConverter type_converter(context.get(), object_code->data_layout);
for (const FunctionEntrypoint& oc : object_code->entrypoints) {
XLS_ASSIGN_OR_RETURN(
*all_entrypoints.add_entrypoint(),
GenerateEntrypointProto(package.get(), oc.function, oc.jit_info,
include_msan, type_converter));
}
if (generate_textproto) {
std::string text;
Expand Down
3 changes: 3 additions & 0 deletions xls/jit/aot_entrypoint.proto
Original file line number Diff line number Diff line change
Expand Up @@ -80,4 +80,7 @@ message AotEntrypointProto {
// a list of all of the targets contained.
message AotPackageEntrypointsProto {
repeated AotEntrypointProto entrypoint = 1;

// The LLVM DataLayout used in this compile.
optional string data_layout = 2;
}
9 changes: 4 additions & 5 deletions xls/jit/block_jit.cc
Original file line number Diff line number Diff line change
Expand Up @@ -58,10 +58,10 @@ absl::StatusOr<std::unique_ptr<BlockJit>> BlockJit::Create(Block* block) {
return absl::UnimplementedError(
"Jitting of blocks with instantiations is not yet supported.");
}
XLS_ASSIGN_OR_RETURN(std::unique_ptr<JitRuntime> runtime,
JitRuntime::Create());
return std::unique_ptr<BlockJit>(new BlockJit(
block, std::move(runtime), std::move(orc_jit), std::move(function)));
XLS_ASSIGN_OR_RETURN(auto data_layout, orc_jit->CreateDataLayout());
return std::unique_ptr<BlockJit>(
new BlockJit(block, std::make_unique<JitRuntime>(data_layout),
std::move(orc_jit), std::move(function)));
}

std::unique_ptr<BlockJitContinuation> BlockJit::NewContinuation() {
Expand Down Expand Up @@ -351,7 +351,6 @@ absl::StatusOr<BlockRunResult> JitBlockEvaluator::EvaluateBlock(
XLS_RET_CHECK_EQ(elaboration.instances().size(), 1)
<< "StreamingJitBlockEvaluator does not support instantiations";

XLS_ASSIGN_OR_RETURN(auto runtime, JitRuntime::Create());
XLS_ASSIGN_OR_RETURN(auto jit, BlockJit::Create(top_block));
auto continuation = jit->NewContinuation();
XLS_RETURN_IF_ERROR(continuation->SetInputPorts(inputs));
Expand Down
2 changes: 2 additions & 0 deletions xls/jit/function_base_jit.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
#include "absl/container/flat_hash_map.h"
#include "absl/status/statusor.h"
#include "absl/types/span.h"
#include "llvm/include/llvm/IR/DataLayout.h"
#include "xls/ir/events.h"
#include "xls/ir/function.h"
#include "xls/ir/function_base.h"
Expand Down Expand Up @@ -301,6 +302,7 @@ struct FunctionEntrypoint {
struct JitObjectCode {
std::vector<uint8_t> object_code;
std::vector<FunctionEntrypoint> entrypoints;
llvm::DataLayout data_layout;
};

} // namespace xls
Expand Down
17 changes: 10 additions & 7 deletions xls/jit/function_base_jit_wrapper.h
Original file line number Diff line number Diff line change
Expand Up @@ -62,21 +62,24 @@ class BaseFunctionJitWrapper {
template <typename RealType>
static absl::StatusOr<std::unique_ptr<RealType>> Create(
std::string_view ir_text, std::string_view function_name,
absl::Span<uint8_t const> aot_entrypoint_proto_bin,
absl::Span<uint8_t const> aot_entrypoints_proto_bin,
JitFunctionType unpacked_entrypoint, JitFunctionType packed_entrypoint)
requires(std::is_base_of_v<BaseFunctionJitWrapper, RealType>)
{
XLS_ASSIGN_OR_RETURN(auto package,
ParsePackage(ir_text, /*filename=*/std::nullopt));
XLS_ASSIGN_OR_RETURN(auto function, package->GetFunction(function_name));
AotEntrypointProto proto;
AotPackageEntrypointsProto proto;
// NB We could fallback to real jit here maybe?
XLS_RET_CHECK(proto.ParseFromArray(aot_entrypoint_proto_bin.data(),
aot_entrypoint_proto_bin.size()))
XLS_RET_CHECK(proto.ParseFromArray(aot_entrypoints_proto_bin.data(),
aot_entrypoints_proto_bin.size()))
<< "Unable to parse aot information.";
XLS_ASSIGN_OR_RETURN(
auto jit, FunctionJit::CreateFromAot(
function, proto, unpacked_entrypoint, packed_entrypoint));
XLS_RET_CHECK_EQ(proto.entrypoint_size(), 1)
<< "FunctionWrapper should only have a single XLS function compiled.";
XLS_ASSIGN_OR_RETURN(auto jit,
FunctionJit::CreateFromAot(
function, proto.entrypoint(0), proto.data_layout(),
unpacked_entrypoint, packed_entrypoint));
return std::unique_ptr<RealType>(
new RealType(std::move(package), std::move(jit),
MatchesImplicitToken(function->GetType()->parameters())));
Expand Down
17 changes: 13 additions & 4 deletions xls/jit/function_jit.cc
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,9 @@

#include <cstdint>
#include <memory>
#include <optional>
#include <string>
#include <string_view>
#include <utility>
#include <vector>

Expand All @@ -28,6 +30,8 @@
#include "absl/strings/str_format.h"
#include "absl/types/span.h"
#include "llvm/include/llvm/IR/DataLayout.h"
#include "llvm/include/llvm/Support/Error.h"
#include "xls/common/status/ret_check.h"
#include "xls/common/status/status_macros.h"
#include "xls/ir/events.h"
#include "xls/ir/function.h"
Expand All @@ -38,7 +42,6 @@
#include "xls/ir/value_utils.h"
#include "xls/jit/aot_compiler.h"
#include "xls/jit/function_base_jit.h"
#include "xls/jit/jit_callbacks.h"
#include "xls/jit/jit_runtime.h"
#include "xls/jit/observer.h"
#include "xls/jit/orc_jit.h"
Expand All @@ -55,27 +58,32 @@ absl::StatusOr<std::unique_ptr<FunctionJit>> FunctionJit::Create(
/* static */ absl::StatusOr<std::unique_ptr<FunctionJit>>
FunctionJit::CreateFromAot(Function* xls_function,
const AotEntrypointProto& entrypoint,
std::string_view data_layout,
JitFunctionType function_unpacked,
std::optional<JitFunctionType> function_packed) {
XLS_ASSIGN_OR_RETURN(
JittedFunctionBase jfb,
JittedFunctionBase::BuildFromAot(xls_function, entrypoint,
function_unpacked, function_packed));
XLS_ASSIGN_OR_RETURN(auto runtime, JitRuntime::Create());
llvm::Expected<llvm::DataLayout> layout =
llvm::DataLayout::parse(data_layout);
XLS_RET_CHECK(layout) << "Unable to parse '" << data_layout
<< "' to an llvm data-layout.";
// OrcJit is simply the arena that holds the JITed code. Since we are already
// compiled theres no need to create and initialize it.
// TODO(allight): Ideally we wouldn't even need to link in the llvm stuff if
// we go down this path, that's a larger refactor however and just carrying
// around some extra .so's isn't a huge deal.
return std::unique_ptr<FunctionJit>(
new FunctionJit(xls_function, std::unique_ptr<OrcJit>(nullptr),
std::move(jfb), std::move(runtime)));
std::move(jfb), std::make_unique<JitRuntime>(*layout)));
}

absl::StatusOr<JitObjectCode> FunctionJit::CreateObjectCode(
Function* xls_function, int64_t opt_level, bool include_msan) {
XLS_ASSIGN_OR_RETURN(std::unique_ptr<AotCompiler> comp,
AotCompiler::Create(include_msan, opt_level));
XLS_ASSIGN_OR_RETURN(llvm::DataLayout data_layout, comp->CreateDataLayout());
XLS_ASSIGN_OR_RETURN(JittedFunctionBase jfb,
JittedFunctionBase::Build(xls_function, *comp));
XLS_ASSIGN_OR_RETURN(auto obj_code, std::move(comp)->GetObjectCode());
Expand All @@ -85,7 +93,8 @@ absl::StatusOr<JitObjectCode> FunctionJit::CreateObjectCode(
.function = xls_function,
.jit_info = std::move(jfb),
},
}};
},
.data_layout = data_layout};
}

absl::StatusOr<std::unique_ptr<FunctionJit>> FunctionJit::CreateInternal(
Expand Down
2 changes: 1 addition & 1 deletion xls/jit/function_jit.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@
#include <string>
#include <string_view>
#include <utility>
#include <vector>

#include "absl/container/flat_hash_map.h"
#include "absl/status/status.h"
Expand Down Expand Up @@ -58,6 +57,7 @@ class FunctionJit {
// function.
static absl::StatusOr<std::unique_ptr<FunctionJit>> CreateFromAot(
Function* xls_function, const AotEntrypointProto& entrypoint,
std::string_view data_layout,
JitFunctionType function_unpacked,
std::optional<JitFunctionType> function_packed = std::nullopt);

Expand Down
Loading

0 comments on commit dabefa3

Please sign in to comment.