Skip to content

Commit

Permalink
Improve performance of node-coverage
Browse files Browse the repository at this point in the history
This brings performance with node-coverage to only 50% worse than no coverage.

PiperOrigin-RevId: 680784331
  • Loading branch information
allight authored and copybara-github committed Oct 1, 2024
1 parent 6321ea6 commit dcc56d7
Show file tree
Hide file tree
Showing 10 changed files with 124 additions and 15 deletions.
7 changes: 7 additions & 0 deletions xls/interpreter/observer.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@
#ifndef XLS_INTERPRETER_OBSERVER_H_
#define XLS_INTERPRETER_OBSERVER_H_

#include <cstdint>
#include <optional>
#include <vector>

#include "absl/container/flat_hash_map.h"
Expand All @@ -23,11 +25,16 @@

namespace xls {

class RuntimeObserver;
// An observer which can be called for each node evaluated.
class EvaluationObserver {
public:
virtual ~EvaluationObserver() = default;
virtual void NodeEvaluated(Node* n, const Value& v) = 0;
// Convert this to an observer capable of accepting jit values if possible.
virtual std::optional<RuntimeObserver*> AsRawObserver() {
return std::nullopt;
}
};

// Test observer that just collects every node value.
Expand Down
1 change: 1 addition & 0 deletions xls/jit/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -713,6 +713,7 @@ cc_library(
"//xls/codegen:codegen_options",
"//xls/codegen:codegen_pass",
"//xls/codegen:materialize_fifos_pass",
"//xls/common:casts",
"//xls/common/status:ret_check",
"//xls/common/status:status_macros",
"//xls/interpreter:block_evaluator",
Expand Down
16 changes: 16 additions & 0 deletions xls/jit/block_jit.cc
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
#include "xls/codegen/codegen_options.h"
#include "xls/codegen/codegen_pass.h"
#include "xls/codegen/materialize_fifos_pass.h"
#include "xls/common/casts.h"
#include "xls/common/status/ret_check.h"
#include "xls/common/status/status_macros.h"
#include "xls/interpreter/block_evaluator.h"
Expand Down Expand Up @@ -572,6 +573,7 @@ class BlockContinuationJitWrapper final : public BlockContinuation {
BlockContinuationJitWrapper(std::unique_ptr<BlockJitContinuation>&& cont,
std::unique_ptr<BlockJit>&& jit)
: continuation_(std::move(cont)), jit_(std::move(jit)) {}
JitRuntime* runtime() const { return jit_->runtime(); }
const absl::flat_hash_map<std::string, Value>& output_ports() final {
if (!temporary_outputs_) {
temporary_outputs_.emplace(continuation_->GetOutputPortsMap());
Expand Down Expand Up @@ -604,6 +606,10 @@ class BlockContinuationJitWrapper final : public BlockContinuation {
}
absl::Status SetObserver(EvaluationObserver* obs) override {
ClearObserver();
std::optional<RuntimeObserver*> run = obs->AsRawObserver();
if (run) {
return continuation_->SetObserver(*run);
}
eval_observer_.emplace(
obs,
[](int64_t ptr) -> Node* {
Expand Down Expand Up @@ -642,4 +648,14 @@ JitBlockEvaluator::MakeNewContinuation(
std::move(jit));
}

absl::StatusOr<JitRuntime*> JitBlockEvaluator::GetRuntime(
BlockContinuation* cont) const {
BlockContinuationJitWrapper* cont_wrap =
dynamic_cast<BlockContinuationJitWrapper*>(cont);
if (cont_wrap == nullptr) {
return absl::InvalidArgumentError("Not a jit continuation");
}
return cont_wrap->runtime();
}

} // namespace xls
1 change: 1 addition & 0 deletions xls/jit/block_jit.h
Original file line number Diff line number Diff line change
Expand Up @@ -285,6 +285,7 @@ class JitBlockEvaluator : public BlockEvaluator {
explicit constexpr JitBlockEvaluator(bool supports_observer = false)
: BlockEvaluator(supports_observer ? "ObservableJit" : "Jit"),
supports_observer_(supports_observer) {}
absl::StatusOr<JitRuntime*> GetRuntime(BlockContinuation* cont) const;

protected:
absl::StatusOr<std::unique_ptr<BlockContinuation>> MakeNewContinuation(
Expand Down
7 changes: 6 additions & 1 deletion xls/jit/proc_jit.cc
Original file line number Diff line number Diff line change
Expand Up @@ -188,7 +188,12 @@ absl::Status ProcJitContinuation::SetObserver(EvaluationObserver* obs) {
"Observers are not supported on this compilation.");
}
XLS_RETURN_IF_ERROR(ProcContinuation::SetObserver(obs));
instance_context_.observer = &observer_shim_;
auto runtime_obs = obs->AsRawObserver();
if (runtime_obs) {
instance_context_.observer = *runtime_obs;
} else {
instance_context_.observer = &observer_shim_;
}
return absl::OkStatus();
}

Expand Down
4 changes: 4 additions & 0 deletions xls/tools/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -307,6 +307,7 @@ cc_binary(
"//xls/ir:value_utils",
"//xls/jit:block_jit",
"//xls/jit:jit_proc_runtime",
"//xls/jit:jit_runtime",
"@com_google_absl//absl/algorithm:container",
"@com_google_absl//absl/container:btree",
"@com_google_absl//absl/container:flat_hash_map",
Expand Down Expand Up @@ -947,9 +948,12 @@ cc_library(
"//xls/ir:type",
"//xls/ir:value",
"//xls/ir:value_utils",
"//xls/jit:jit_runtime",
"//xls/jit:observer",
"@com_google_absl//absl/algorithm:container",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/log",
"@com_google_absl//absl/log:check",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/types:span",
Expand Down
1 change: 1 addition & 0 deletions xls/tools/eval_ir_main.cc
Original file line number Diff line number Diff line change
Expand Up @@ -469,6 +469,7 @@ class EvalInvariantChecker : public OptimizationInvariantChecker {
// after optimizations.
absl::Status Run(Package* package, absl::Span<const ArgSet> arg_sets_in) {
XLS_ASSIGN_OR_RETURN(Function * f, package->GetTopAsFunction());
// TODO(allight): Use the specialized jit-abi coverage observer.
ScopedRecordNodeCoverage cov(
absl::GetFlag(FLAGS_output_node_coverage_stats_proto),
absl::GetFlag(FLAGS_output_node_coverage_stats_textproto));
Expand Down
35 changes: 24 additions & 11 deletions xls/tools/eval_proc_main.cc
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@
#include "xls/ir/value_utils.h"
#include "xls/jit/block_jit.h"
#include "xls/jit/jit_proc_runtime.h"
#include "xls/jit/jit_runtime.h"
#include "xls/tools/eval_utils.h"
#include "xls/tools/node_coverage_utils.h"

Expand Down Expand Up @@ -234,19 +235,25 @@ static absl::Status EvaluateProcs(
expected_outputs_for_channels,
const EvaluateProcsOptions& options = {}) {
std::unique_ptr<SerialProcRuntime> runtime;
ScopedRecordNodeCoverage cov(
absl::GetFlag(FLAGS_output_node_coverage_stats_proto),
absl::GetFlag(FLAGS_output_node_coverage_stats_textproto));
std::optional<JitRuntime*> jit;
EvaluatorOptions evaluator_options;
evaluator_options.set_trace_channels(absl::GetFlag(FLAGS_trace_channels))
.set_support_observers(cov.observer().has_value());
evaluator_options.set_trace_channels(absl::GetFlag(FLAGS_trace_channels));
bool uses_observers =
absl::GetFlag(FLAGS_output_node_coverage_stats_proto).has_value() ||
absl::GetFlag(FLAGS_output_node_coverage_stats_textproto).has_value();
evaluator_options.set_support_observers(uses_observers);
if (options.use_jit) {
XLS_ASSIGN_OR_RETURN(
runtime, CreateJitSerialProcRuntime(package, evaluator_options));
XLS_ASSIGN_OR_RETURN(auto jit_queue, runtime->GetJitChannelQueueManager());
jit = &jit_queue->runtime();
} else {
XLS_ASSIGN_OR_RETURN(runtime, CreateInterpreterSerialProcRuntime(
package, evaluator_options));
}
ScopedRecordNodeCoverage cov(
absl::GetFlag(FLAGS_output_node_coverage_stats_proto),
absl::GetFlag(FLAGS_output_node_coverage_stats_textproto), jit);
if (cov.observer()) {
XLS_RETURN_IF_ERROR(runtime->SetObserver(*cov.observer()));
LOG(ERROR) << "Set observer!";
Expand Down Expand Up @@ -730,10 +737,6 @@ static absl::Status RunBlock(
"Input IR should contain exactly one block");
}

ScopedRecordNodeCoverage cov(
absl::GetFlag(FLAGS_output_node_coverage_stats_proto),
absl::GetFlag(FLAGS_output_node_coverage_stats_textproto));

std::mt19937_64 bit_gen(options.random_seed);

Block* block = package->blocks()[0].get();
Expand Down Expand Up @@ -785,15 +788,25 @@ static absl::Status RunBlock(
reg_state[reg->name()] = XsOfType(reg->type());
}

bool needs_observer =
absl::GetFlag(FLAGS_output_node_coverage_stats_proto).has_value() ||
absl::GetFlag(FLAGS_output_node_coverage_stats_textproto).has_value();
const BlockEvaluator& continuation_factory =
options.use_jit
? reinterpret_cast<const BlockEvaluator&>(
cov.observer() ? kObservableJitBlockEvaluator
needs_observer ? kObservableJitBlockEvaluator
: kJitBlockEvaluator)
: reinterpret_cast<const BlockEvaluator&>(kInterpreterBlockEvaluator);

XLS_ASSIGN_OR_RETURN(auto continuation,
continuation_factory.NewContinuation(block, reg_state));
std::optional<JitRuntime*> jit;
if (options.use_jit) {
XLS_ASSIGN_OR_RETURN(jit,
kJitBlockEvaluator.GetRuntime(continuation.get()));
}
ScopedRecordNodeCoverage cov(
absl::GetFlag(FLAGS_output_node_coverage_stats_proto),
absl::GetFlag(FLAGS_output_node_coverage_stats_textproto), jit);

if (cov.observer()) {
XLS_RETURN_IF_ERROR(continuation->SetObserver(*cov.observer()));
Expand Down
37 changes: 37 additions & 0 deletions xls/tools/node_coverage_utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
#include <vector>

#include "absl/algorithm/container.h"
#include "absl/log/check.h"
#include "absl/log/log.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
Expand All @@ -36,6 +37,7 @@
#include "xls/ir/type.h"
#include "xls/ir/value.h"
#include "xls/ir/value_utils.h"
#include "xls/jit/jit_runtime.h"
#include "xls/tools/node_coverage_stats.pb.h"

namespace xls {
Expand Down Expand Up @@ -101,7 +103,21 @@ void CoverageEvalObserver::NodeEvaluated(Node* n, const Value& v) {
coverage_[n] = bitmap.value();
}

absl::Status CoverageEvalObserver::Finalize() {
if (!jit_) {
XLS_RET_CHECK(raw_coverage_.empty()) << "no jit but raw data was present.";
return absl::OkStatus();
}
for (const auto& [node, data] : raw_coverage_) {
NodeEvaluated(node,
jit_.value()->UnpackBuffer(data.data(), node->GetType()));
}
raw_coverage_.clear();
return absl::OkStatus();
}

absl::StatusOr<NodeCoverageStatsProto> CoverageEvalObserver::proto() const {
XLS_RET_CHECK(raw_coverage_.empty()) << "Need to call finalize first";
NodeCoverageStatsProto res;
if (coverage_.empty()) {
LOG(WARNING) << "No coverage information collected.";
Expand Down Expand Up @@ -145,11 +161,32 @@ absl::StatusOr<NodeCoverageStatsProto> CoverageEvalObserver::proto() const {

return res;
}
void CoverageEvalObserver::RecordNodeValue(int64_t node_ptr,
const uint8_t* data) {
CHECK(jit_);
if (paused_) {
VLOG(2) << "Ignoring " << node_ptr << " due to pause.";
return;
}
Node* node = reinterpret_cast<Node*>(static_cast<intptr_t>(node_ptr));
if (!raw_coverage_.contains(node)) {
raw_coverage_[node] =
std::vector<uint8_t>(jit_.value()->GetTypeByteSize(node->GetType()), 0);
}
if (node->GetType()->GetFlatBitCount() == 0) {
return;
}
std::vector<uint8_t>& bits = raw_coverage_[node];
for (int64_t i = 0; i < bits.size(); ++i) {
bits[i] = bits[i] | data[i];
}
}

ScopedRecordNodeCoverage::~ScopedRecordNodeCoverage() {
if (!txtproto_ && !binproto_) {
return;
}
CHECK_OK(obs_.Finalize());
absl::StatusOr<NodeCoverageStatsProto> proto = obs_.proto();
if (!proto.ok()) {
LOG(ERROR) << "Unable to turn coverage stats to proto: " << proto.status();
Expand Down
30 changes: 27 additions & 3 deletions xls/tools/node_coverage_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,37 +15,61 @@
#ifndef XLS_TOOLS_NODE_COVERAGE_UTILS_H_
#define XLS_TOOLS_NODE_COVERAGE_UTILS_H_

#include <cstdint>
#include <optional>
#include <string>
#include <utility>
#include <vector>

#include "absl/container/flat_hash_map.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "xls/data_structures/inline_bitmap.h"
#include "xls/data_structures/leaf_type_tree.h"
#include "xls/interpreter/observer.h"
#include "xls/ir/node.h"
#include "xls/ir/value.h"
#include "xls/jit/jit_runtime.h"
#include "xls/jit/observer.h"
#include "xls/tools/node_coverage_stats.pb.h"

namespace xls {

class CoverageEvalObserver final : public EvaluationObserver {
class CoverageEvalObserver final : public EvaluationObserver,
public RuntimeObserver {
public:
explicit CoverageEvalObserver(std::optional<JitRuntime*> jit = std::nullopt)
: jit_(jit) {}
void NodeEvaluated(Node* n, const Value& v) override;
std::optional<RuntimeObserver*> AsRawObserver() override {
if (jit_) {
return this;
}
return std::nullopt;
}
void RecordNodeValue(int64_t node_ptr, const uint8_t* data) override;

// Prepare for proto conversion.
absl::Status Finalize();

absl::StatusOr<NodeCoverageStatsProto> proto() const;
void SetPaused(bool v) { paused_ = v; }

private:
absl::flat_hash_map<Node*, LeafTypeTree<InlineBitmap>> coverage_;
absl::flat_hash_map<Node*, std::vector<uint8_t>> raw_coverage_;
std::optional<JitRuntime*> jit_;
bool paused_ = false;
};

class ScopedRecordNodeCoverage {
public:
ScopedRecordNodeCoverage(std::optional<std::string> binproto,
std::optional<std::string> txtproto)
: binproto_(std::move(binproto)), txtproto_(std::move(txtproto)) {}
std::optional<std::string> txtproto,
std::optional<JitRuntime*> jit = std::nullopt)
: binproto_(std::move(binproto)),
txtproto_(std::move(txtproto)),
obs_(jit) {}
~ScopedRecordNodeCoverage();
std::optional<EvaluationObserver*> observer() {
if (binproto_ || txtproto_) {
Expand Down

0 comments on commit dcc56d7

Please sign in to comment.