Skip to content

Commit

Permalink
Simplify block-jit creation
Browse files Browse the repository at this point in the history
Make block-jit match other jits and keep the JitRuntime as an internal implementation detail specific to each jit.

PiperOrigin-RevId: 638072265
  • Loading branch information
allight authored and copybara-github committed May 28, 2024
1 parent aeb9b52 commit 73c5545
Show file tree
Hide file tree
Showing 4 changed files with 35 additions and 49 deletions.
48 changes: 20 additions & 28 deletions xls/jit/block_jit.cc
Original file line number Diff line number Diff line change
Expand Up @@ -50,30 +50,31 @@

namespace xls {

absl::StatusOr<std::unique_ptr<BlockJit>> BlockJit::Create(
Block* block, JitRuntime* runtime) {
absl::StatusOr<std::unique_ptr<BlockJit>> BlockJit::Create(Block* block) {
XLS_ASSIGN_OR_RETURN(std::unique_ptr<OrcJit> orc_jit, OrcJit::Create());
XLS_ASSIGN_OR_RETURN(auto function,
JittedFunctionBase::Build(block, *orc_jit));
if (!block->GetInstantiations().empty()) {
return absl::UnimplementedError(
"Jitting of blocks with instantiations is not yet supported.");
}
return std::unique_ptr<BlockJit>(
new BlockJit(block, runtime, std::move(orc_jit), std::move(function)));
XLS_ASSIGN_OR_RETURN(std::unique_ptr<JitRuntime> runtime,
JitRuntime::Create());
return std::unique_ptr<BlockJit>(new BlockJit(
block, std::move(runtime), std::move(orc_jit), std::move(function)));
}

std::unique_ptr<BlockJitContinuation> BlockJit::NewContinuation() {
return std::unique_ptr<BlockJitContinuation>(
new BlockJitContinuation(block_, this, runtime_, function_));
new BlockJitContinuation(block_, this, function_));
}

absl::Status BlockJit::RunOneCycle(BlockJitContinuation& continuation) {
function_.RunJittedFunction(
continuation.input_buffers_.current(),
continuation.output_buffers_.current(), continuation.temp_buffer_,
&continuation.GetEvents(), /*instance_context=*/&continuation.callbacks_,
runtime_,
runtime_.get(),
/*continuation_point=*/0);
continuation.SwapRegisters();
return absl::OkStatus();
Expand Down Expand Up @@ -139,11 +140,9 @@ BlockJitContinuation::IOSpace BlockJitContinuation::MakeCombinedBuffers(
}

BlockJitContinuation::BlockJitContinuation(Block* block, BlockJit* jit,
JitRuntime* runtime,
const JittedFunctionBase& jit_func)
: block_(block),
block_jit_(jit),
runtime_(runtime),
register_buffers_memory_{jit_func.CreateInputBuffer(),
jit_func.CreateInputBuffer()},
input_port_buffers_memory_(jit_func.CreateInputBuffer()),
Expand Down Expand Up @@ -178,7 +177,7 @@ absl::Status BlockJitContinuation::SetInputPorts(
<< ip->GetType()->ToString();
++it;
}
return runtime_->PackArgs(values, types, input_port_pointers());
return block_jit_->runtime()->PackArgs(values, types, input_port_pointers());
}

absl::Status BlockJitContinuation::SetInputPorts(
Expand Down Expand Up @@ -231,7 +230,7 @@ absl::Status BlockJitContinuation::SetRegisters(
<< reg->type()->ToString();
++it;
}
return runtime_->PackArgs(values, types, register_pointers());
return block_jit_->runtime()->PackArgs(values, types, register_pointers());
}

absl::Status BlockJitContinuation::SetRegisters(
Expand Down Expand Up @@ -275,7 +274,7 @@ std::vector<Value> BlockJitContinuation::GetOutputPorts() const {
result.reserve(output_port_pointers().size());
int i = 0;
for (auto ptr : output_port_pointers()) {
result.push_back(runtime_->UnpackBuffer(
result.push_back(block_jit_->runtime()->UnpackBuffer(
ptr, block_->GetOutputPorts()[i++]->operand(0)->GetType()));
}
return result;
Expand Down Expand Up @@ -327,8 +326,8 @@ std::vector<Value> BlockJitContinuation::GetRegisters() const {
result.reserve(register_pointers().size());
int i = 0;
for (auto ptr : register_pointers()) {
result.push_back(
runtime_->UnpackBuffer(ptr, block_->GetRegisters()[i++]->type()));
result.push_back(block_jit_->runtime()->UnpackBuffer(
ptr, block_->GetRegisters()[i++]->type()));
}
return result;
}
Expand All @@ -353,7 +352,7 @@ absl::StatusOr<BlockRunResult> JitBlockEvaluator::EvaluateBlock(
<< "StreamingJitBlockEvaluator does not support instantiations";

XLS_ASSIGN_OR_RETURN(auto runtime, JitRuntime::Create());
XLS_ASSIGN_OR_RETURN(auto jit, BlockJit::Create(top_block, runtime.get()));
XLS_ASSIGN_OR_RETURN(auto jit, BlockJit::Create(top_block));
auto continuation = jit->NewContinuation();
XLS_RETURN_IF_ERROR(continuation->SetInputPorts(inputs));
XLS_RETURN_IF_ERROR(continuation->SetRegisters(reg_state));
Expand All @@ -374,8 +373,7 @@ StreamingJitBlockEvaluator::EvaluateSequentialBlock(
for (Register* reg : block->GetRegisters()) {
reg_state[reg->name()] = ZeroOfType(reg->type());
}
XLS_ASSIGN_OR_RETURN(auto runtime, JitRuntime::Create());
XLS_ASSIGN_OR_RETURN(auto jit, BlockJit::Create(block, runtime.get()));
XLS_ASSIGN_OR_RETURN(auto jit, BlockJit::Create(block));
auto continuation = jit->NewContinuation();
XLS_RETURN_IF_ERROR(continuation->SetRegisters(reg_state));

Expand Down Expand Up @@ -403,8 +401,7 @@ StreamingJitBlockEvaluator::EvaluateChannelizedSequentialBlock(
reg_state[reg->name()] = ZeroOfType(reg->type());
}

XLS_ASSIGN_OR_RETURN(auto runtime, JitRuntime::Create());
XLS_ASSIGN_OR_RETURN(auto jit, BlockJit::Create(block, runtime.get()));
XLS_ASSIGN_OR_RETURN(auto jit, BlockJit::Create(block));
auto continuation = jit->NewContinuation();
XLS_RETURN_IF_ERROR(continuation->SetRegisters(reg_state));

Expand Down Expand Up @@ -468,11 +465,8 @@ namespace {
class BlockContinuationJitWrapper final : public BlockContinuation {
public:
BlockContinuationJitWrapper(std::unique_ptr<BlockJitContinuation>&& cont,
std::unique_ptr<BlockJit>&& jit,
std::unique_ptr<JitRuntime>&& runtime)
: continuation_(std::move(cont)),
jit_(std::move(jit)),
runtime_(std::move(runtime)) {}
std::unique_ptr<BlockJit>&& jit)
: continuation_(std::move(cont)), jit_(std::move(jit)) {}
const absl::flat_hash_map<std::string, Value>& output_ports() final {
if (!temporary_outputs_) {
temporary_outputs_.emplace(continuation_->GetOutputPortsMap());
Expand Down Expand Up @@ -502,7 +496,6 @@ class BlockContinuationJitWrapper final : public BlockContinuation {
private:
std::unique_ptr<BlockJitContinuation> continuation_;
std::unique_ptr<BlockJit> jit_;
std::unique_ptr<JitRuntime> runtime_;
// Holder for the data we return out of output_ports so that we can reduce
// copying.
std::optional<absl::flat_hash_map<std::string, Value>> temporary_outputs_;
Expand All @@ -519,12 +512,11 @@ StreamingJitBlockEvaluator::NewContinuation(
Block* top_block = *elaboration.top()->block();
XLS_RET_CHECK_EQ(elaboration.instances().size(), 1)
<< "StreamingJitBlockEvaluator does not support instantiations";
XLS_ASSIGN_OR_RETURN(auto runtime, JitRuntime::Create());
XLS_ASSIGN_OR_RETURN(auto jit, BlockJit::Create(top_block, runtime.get()));
XLS_ASSIGN_OR_RETURN(auto jit, BlockJit::Create(top_block));
auto jit_cont = jit->NewContinuation();
XLS_RETURN_IF_ERROR(jit_cont->SetRegisters(initial_registers));
return std::make_unique<BlockContinuationJitWrapper>(
std::move(jit_cont), std::move(jit), std::move(runtime));
return std::make_unique<BlockContinuationJitWrapper>(std::move(jit_cont),
std::move(jit));
}

} // namespace xls
16 changes: 8 additions & 8 deletions xls/jit/block_jit.h
Original file line number Diff line number Diff line change
Expand Up @@ -49,8 +49,7 @@ namespace xls {
class BlockJitContinuation;
class BlockJit {
public:
static absl::StatusOr<std::unique_ptr<BlockJit>> Create(Block* block,
JitRuntime* runtime);
static absl::StatusOr<std::unique_ptr<BlockJit>> Create(Block* block);

// Create a new blank block with no registers or ports set. Can be cycled
// independently of other blocks/continuations.
Expand All @@ -61,6 +60,8 @@ class BlockJit {

OrcJit& orc_jit() const { return *jit_; }

JitRuntime* runtime() const { return runtime_.get(); }

// Get how large each pointer buffer for the input ports are.
absl::Span<const int64_t> input_port_sizes() const {
return absl::MakeConstSpan(function_.input_buffer_sizes())
Expand All @@ -74,15 +75,15 @@ class BlockJit {
}

private:
BlockJit(Block* block, JitRuntime* runtime, std::unique_ptr<OrcJit> jit,
JittedFunctionBase function)
BlockJit(Block* block, std::unique_ptr<JitRuntime> runtime,
std::unique_ptr<OrcJit> jit, JittedFunctionBase function)
: block_(block),
runtime_(runtime),
runtime_(std::move(runtime)),
jit_(std::move(jit)),
function_(std::move(function)) {}

Block* block_;
JitRuntime* runtime_;
std::unique_ptr<JitRuntime> runtime_;
std::unique_ptr<OrcJit> jit_;
JittedFunctionBase function_;
};
Expand Down Expand Up @@ -194,7 +195,7 @@ class BlockJitContinuation {

private:
using BufferPair = std::array<JitArgumentSet, 2>;
BlockJitContinuation(Block* block, BlockJit* jit, JitRuntime* runtime,
BlockJitContinuation(Block* block, BlockJit* jit,
const JittedFunctionBase& jit_func);
static IOSpace MakeCombinedBuffers(const JittedFunctionBase& jit_func,
const Block* block,
Expand Down Expand Up @@ -227,7 +228,6 @@ class BlockJitContinuation {

const Block* block_;
BlockJit* block_jit_;
JitRuntime* runtime_;

// Buffers for the registers. Note this includes (unused) space for the input
// ports.
Expand Down
18 changes: 6 additions & 12 deletions xls/jit/block_jit_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -43,8 +43,7 @@ TEST_F(BlockJitTest, ConstantToPort) {
bb.OutputPort("answer", input);

XLS_ASSERT_OK_AND_ASSIGN(Block * b, bb.Build());
XLS_ASSERT_OK_AND_ASSIGN(auto runtime, JitRuntime::Create());
XLS_ASSERT_OK_AND_ASSIGN(auto jit, BlockJit::Create(b, runtime.get()));
XLS_ASSERT_OK_AND_ASSIGN(auto jit, BlockJit::Create(b));
auto cont = jit->NewContinuation();
XLS_ASSERT_OK(jit->RunOneCycle(*cont));

Expand All @@ -58,8 +57,7 @@ TEST_F(BlockJitTest, AddTwoPort) {
bb.OutputPort("answer", bb.Add(input1, input2));

XLS_ASSERT_OK_AND_ASSIGN(Block * b, bb.Build());
XLS_ASSERT_OK_AND_ASSIGN(auto runtime, JitRuntime::Create());
XLS_ASSERT_OK_AND_ASSIGN(auto jit, BlockJit::Create(b, runtime.get()));
XLS_ASSERT_OK_AND_ASSIGN(auto jit, BlockJit::Create(b));
auto cont = jit->NewContinuation();
XLS_ASSERT_OK(
cont->SetInputPorts({Value(UBits(12, 8)), Value(UBits(30, 8))}));
Expand All @@ -78,8 +76,7 @@ TEST_F(BlockJitTest, ConstantToReg) {
bb.RegisterRead(r);

XLS_ASSERT_OK_AND_ASSIGN(Block * b, bb.Build());
XLS_ASSERT_OK_AND_ASSIGN(auto runtime, JitRuntime::Create());
XLS_ASSERT_OK_AND_ASSIGN(auto jit, BlockJit::Create(b, runtime.get()));
XLS_ASSERT_OK_AND_ASSIGN(auto jit, BlockJit::Create(b));
auto cont = jit->NewContinuation();
XLS_ASSERT_OK(cont->SetRegisters({Value(UBits(2, 8))}));
XLS_ASSERT_OK(jit->RunOneCycle(*cont));
Expand All @@ -98,8 +95,7 @@ TEST_F(BlockJitTest, DelaySlot) {
bb.OutputPort("output", read);

XLS_ASSERT_OK_AND_ASSIGN(Block * b, bb.Build());
XLS_ASSERT_OK_AND_ASSIGN(auto runtime, JitRuntime::Create());
XLS_ASSERT_OK_AND_ASSIGN(auto jit, BlockJit::Create(b, runtime.get()));
XLS_ASSERT_OK_AND_ASSIGN(auto jit, BlockJit::Create(b));
auto cont = jit->NewContinuation();

XLS_ASSERT_OK(cont->SetRegisters({Value(UBits(2, 8))}));
Expand All @@ -123,8 +119,7 @@ TEST_F(BlockJitTest, SetInputsWithViews) {
bb.OutputPort("output", bb.Add(input1, input2));

XLS_ASSERT_OK_AND_ASSIGN(Block * b, bb.Build());
XLS_ASSERT_OK_AND_ASSIGN(auto runtime, JitRuntime::Create());
XLS_ASSERT_OK_AND_ASSIGN(auto jit, BlockJit::Create(b, runtime.get()));
XLS_ASSERT_OK_AND_ASSIGN(auto jit, BlockJit::Create(b));
auto cont = jit->NewContinuation();

int16_t input_bits1 = -12;
Expand Down Expand Up @@ -154,8 +149,7 @@ TEST_F(BlockJitTest, SetRegistersWithViews) {
bb.OutputPort("output", bb.Add(read1, read2));

XLS_ASSERT_OK_AND_ASSIGN(Block * b, bb.Build());
XLS_ASSERT_OK_AND_ASSIGN(auto runtime, JitRuntime::Create());
XLS_ASSERT_OK_AND_ASSIGN(auto jit, BlockJit::Create(b, runtime.get()));
XLS_ASSERT_OK_AND_ASSIGN(auto jit, BlockJit::Create(b));
auto cont = jit->NewContinuation();

int16_t input_bits1 = -12;
Expand Down
2 changes: 1 addition & 1 deletion xls/tools/benchmark_main.cc
Original file line number Diff line number Diff line change
Expand Up @@ -608,7 +608,7 @@ absl::Status RunBlockInterpreterAndJit(Block* block,
XLS_ASSIGN_OR_RETURN(std::unique_ptr<JitRuntime> runtime,
JitRuntime::Create());
XLS_ASSIGN_OR_RETURN(std::unique_ptr<BlockJit> jit,
BlockJit::Create(block, runtime.get()));
BlockJit::Create(block));
std::cout << absl::StreamFormat(
"JIT compile time (%s): %dms\n", description,
DurationToMs(absl::Now() - start_jit_compile));
Expand Down

0 comments on commit 73c5545

Please sign in to comment.