Skip to content

Commit

Permalink
Allow RangeQueryEngine to define Max/MinUnsignedValue
Browse files Browse the repository at this point in the history
These functions were only defined in the base query-engine and were implemented assuming each bit is independent of one another. This is only really true for the ternary query engine however. By making the function virtual and providing an override in range-query engine we can get better bounds.

PiperOrigin-RevId: 687029040
  • Loading branch information
allight authored and copybara-github committed Oct 17, 2024
1 parent acf176f commit af19b00
Show file tree
Hide file tree
Showing 8 changed files with 137 additions and 2 deletions.
1 change: 1 addition & 0 deletions xls/passes/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -836,6 +836,7 @@ cc_library(
"//xls/data_structures:leaf_type_tree",
"//xls/ir",
"//xls/ir:bits",
"//xls/ir:bits_ops",
"//xls/ir:interval_set",
"//xls/ir:ternary",
"//xls/ir:type",
Expand Down
4 changes: 2 additions & 2 deletions xls/passes/query_engine.h
Original file line number Diff line number Diff line change
Expand Up @@ -214,10 +214,10 @@ class QueryEngine {
}

// Returns the maximum unsigned value that the node can be.
Bits MaxUnsignedValue(Node* node) const;
virtual Bits MaxUnsignedValue(Node* node) const;

// Returns the minimum unsigned value that the node can be.
Bits MinUnsignedValue(Node* node) const;
virtual Bits MinUnsignedValue(Node* node) const;

// Returns true if the values of the two nodes are known to be equal when
// interpreted as unsigned numbers. The nodes can be of different widths.
Expand Down
24 changes: 24 additions & 0 deletions xls/passes/range_query_engine.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1420,6 +1420,30 @@ static void IntervalSetTreeToStream(const IntervalSetTree& tree, Type* type,
}
}

Bits RangeQueryEngine::MaxUnsignedValue(Node* n) const {
CHECK(n->GetType()->IsBits()) << n;
// If modifications to the function mean we don't have a tree we just fall
// back to the bit-based version.
if (!HasExplicitIntervals(n)) {
return QueryEngine::MaxUnsignedValue(n);
}
std::optional<Interval> hull =
GetIntervalSetTreeView(n)->Get({}).ConvexHull();
return hull ? hull->UpperBound() : Bits::AllOnes(n->BitCountOrDie());
}

Bits RangeQueryEngine::MinUnsignedValue(Node* n) const {
CHECK(n->GetType()->IsBits()) << n;
// If modifications to the function mean we don't have a tree we just fall
// back to the bit-based version.
if (!HasExplicitIntervals(n)) {
return QueryEngine::MinUnsignedValue(n);
}
std::optional<Interval> hull =
GetIntervalSetTreeView(n)->Get({}).ConvexHull();
return hull ? hull->LowerBound() : Bits(n->BitCountOrDie());
}

std::string IntervalSetTreeToString(const IntervalSetTree& tree) {
std::stringstream ss;
IntervalSetTreeToStream(tree, tree.type(), {}, ss);
Expand Down
3 changes: 3 additions & 0 deletions xls/passes/range_query_engine.h
Original file line number Diff line number Diff line change
Expand Up @@ -235,6 +235,9 @@ class RangeQueryEngine : public QueryEngine {
// This must be called before `SetIntervalSetTree`.
void InitializeNode(Node* node);

Bits MaxUnsignedValue(Node* n) const override;
Bits MinUnsignedValue(Node* n) const override;

private:
friend class RangeQueryVisitor;

Expand Down
31 changes: 31 additions & 0 deletions xls/passes/range_query_engine_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2546,5 +2546,36 @@ TEST_F(RangeQueryEngineTest, MultipleRangeGivenValue) {
BitsLTT(ltxyz.node(), {Interval::Precise(UBits(1, 1))}));
}

TEST_F(RangeQueryEngineTest, MaxMinUnsignedValue) {
auto p = CreatePackage();
FunctionBuilder fb(TestName(), p.get());
// We will have a known range of [3, 12] for this
BValue x = fb.Param("x", p->GetBitsType(8));
// We will have a known range of [0, 3] for this
BValue y = fb.Param("y", p->GetBitsType(8));
// We will have a known range of [1, 4] for this.
BValue z = fb.Param("z", p->GetBitsType(8));

// [3, 12] + [0, 3] == [3, 15]
BValue xy = fb.Add(x, y);
// [3, 15] + [1,4] == [4,19]
BValue xyz = fb.Add(xy, z);

XLS_ASSERT_OK_AND_ASSIGN(Function * f, fb.Build());
RangeQueryEngine engine;
engine.SetIntervalSetTree(
x.node(), BitsLTT(x.node(), {Interval(UBits(3, 8), UBits(12, 8))}));
engine.SetIntervalSetTree(
y.node(), BitsLTT(y.node(), {Interval(UBits(0, 8), UBits(3, 8))}));
engine.SetIntervalSetTree(
z.node(), BitsLTT(z.node(), {Interval(UBits(1, 8), UBits(4, 8))}));
XLS_ASSERT_OK(engine.Populate(f));

EXPECT_EQ(engine.MaxUnsignedValue(xyz.node()), UBits(19, 8));
EXPECT_EQ(engine.MaxUnsignedValue(xy.node()), UBits(15, 8));
EXPECT_EQ(engine.MinUnsignedValue(xyz.node()), UBits(4, 8));
EXPECT_EQ(engine.MinUnsignedValue(xy.node()), UBits(3, 8));
}

} // namespace
} // namespace xls
26 changes: 26 additions & 0 deletions xls/passes/union_query_engine.cc
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
#include "xls/common/status/status_macros.h"
#include "xls/data_structures/leaf_type_tree.h"
#include "xls/ir/bits.h"
#include "xls/ir/bits_ops.h"
#include "xls/ir/interval_set.h"
#include "xls/ir/node.h"
#include "xls/ir/ternary.h"
Expand Down Expand Up @@ -239,4 +240,29 @@ std::optional<TernaryVector> UnownedUnionQueryEngine::ImpliedNodeTernary(
return result;
}

Bits UnownedUnionQueryEngine::MaxUnsignedValue(Node* node) const {
CHECK(node->GetType()->IsBits()) << node;
Bits result = engines_.front()->MaxUnsignedValue(node);
for (const auto& engine : absl::MakeConstSpan(engines_).subspan(1)) {
Bits eng_res = engine->MaxUnsignedValue(node);
if (bits_ops::ULessThan(eng_res, result)) {
result = std::move(eng_res);
}
}
return result;
}

Bits UnownedUnionQueryEngine::MinUnsignedValue(Node* node) const {
CHECK(node->GetType()->IsBits()) << node;
CHECK(node->GetType()->IsBits()) << node;
Bits result = engines_.front()->MinUnsignedValue(node);
for (const auto& engine : absl::MakeConstSpan(engines_).subspan(1)) {
Bits eng_res = engine->MinUnsignedValue(node);
if (bits_ops::UGreaterThan(eng_res, result)) {
result = std::move(eng_res);
}
}
return result;
}

} // namespace xls
3 changes: 3 additions & 0 deletions xls/passes/union_query_engine.h
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,9 @@ class UnownedUnionQueryEngine : public QueryEngine {
bool IsAllZeros(Node* n) const override;
bool IsAllOnes(Node* n) const override;

Bits MaxUnsignedValue(Node* node) const override;
Bits MinUnsignedValue(Node* node) const override;

private:
absl::flat_hash_map<Node*, Bits> known_bits_;
absl::flat_hash_map<Node*, Bits> known_bit_values_;
Expand Down
47 changes: 47 additions & 0 deletions xls/passes/union_query_engine_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,23 @@ class FakeQueryEngine : public QueryEngine {
intervals_[node] = intervals;
}

Bits MaxUnsignedValue(Node* n) const override {
CHECK(n->GetType()->IsBits()) << n;
if (intervals_.contains(n)) {
return intervals_.at(n).Get({}).ConvexHull()->UpperBound();
} else {
return Bits::AllOnes(n->BitCountOrDie());
}
}
Bits MinUnsignedValue(Node* n) const override {
CHECK(n->GetType()->IsBits()) << n;
if (intervals_.contains(n)) {
return intervals_.at(n).Get({}).ConvexHull()->LowerBound();
} else {
return Bits(n->BitCountOrDie());
}
}

void AddKnownBit(const TreeBitLocation& location, bool value) {
Node* node = location.node();
AddTracked(node);
Expand Down Expand Up @@ -339,5 +356,35 @@ TEST_F(UnionQueryEngineTest, Intervals) {
EXPECT_EQ(tuple_intervals.elements()[1], y_b);
}

TEST_F(UnionQueryEngineTest, MaxMinUnsignedValue) {
auto p = CreatePackage();
FunctionBuilder fb(TestName(), p.get());

BValue x = fb.Param("x", fb.package()->GetBitsType(8));
BValue y = fb.Param("y", fb.package()->GetBitsType(8));
fb.Tuple({x, y});
XLS_ASSERT_OK(fb.Build());

FakeQueryEngine query_engine_a;
IntervalSet x_a = IntervalSet::Of({
Interval(UBits(20, 8), UBits(40, 8)),
});
query_engine_a.AddIntervals(
x.node(), LeafTypeTree<IntervalSet>(x.node()->GetType(), {x_a}));
FakeQueryEngine query_engine_b;
IntervalSet y_b = IntervalSet::Of({Interval(UBits(200, 8), UBits(240, 8))});
query_engine_b.AddIntervals(
y.node(), LeafTypeTree<IntervalSet>(y.node()->GetType(), {y_b}));
std::vector<std::unique_ptr<QueryEngine>> engines;
engines.push_back(std::make_unique<FakeQueryEngine>(query_engine_a));
engines.push_back(std::make_unique<FakeQueryEngine>(query_engine_b));
UnionQueryEngine union_query_engine(std::move(engines));
// No need to Populate, since FakeQueryEngine doesn't use that interface
EXPECT_EQ(union_query_engine.MinUnsignedValue(x.node()), UBits(20, 8));
EXPECT_EQ(union_query_engine.MinUnsignedValue(y.node()), UBits(200, 8));
EXPECT_EQ(union_query_engine.MaxUnsignedValue(x.node()), UBits(40, 8));
EXPECT_EQ(union_query_engine.MaxUnsignedValue(y.node()), UBits(240, 8));
}

} // namespace
} // namespace xls

0 comments on commit af19b00

Please sign in to comment.