Skip to content

Commit

Permalink
[XLS] Simplify array indexing into an array operation
Browse files Browse the repository at this point in the history
Even if the index isn't literal, we can convert this into a Select, creating more opportunities for conditional specialization. On the other hand, if the index is literal and is out-of-bounds, then we can always select the last entry due to the clamping behavior of ArrayIndex.

PiperOrigin-RevId: 674644265
  • Loading branch information
ericastor authored and copybara-github committed Sep 14, 2024
1 parent d1b26f2 commit e9aa42c
Show file tree
Hide file tree
Showing 2 changed files with 144 additions and 15 deletions.
90 changes: 75 additions & 15 deletions xls/passes/array_simplification_pass.cc
Original file line number Diff line number Diff line change
Expand Up @@ -202,22 +202,82 @@ absl::StatusOr<SimplifyResult> SimplifyArrayIndex(
if (array_index->array()->Is<Array>() && !array_index->indices().empty() &&
query_engine.IsFullyKnown(array_index->indices().front())) {
Array* array = array_index->array()->As<Array>();
Node* first_index = array_index->indices().front();
if (IndexIsDefinitelyInBounds(first_index, array->GetType()->AsArrayOrDie(),
query_engine)) {
// Indices are always interpreted as unsigned numbers.
XLS_ASSIGN_OR_RETURN(
uint64_t operand_no,
query_engine.KnownValueAsBits(first_index)->ToUint64());
VLOG(2) << absl::StrFormat(
"Array-index of array operation with constant index: %s",
array_index->ToString());
XLS_ASSIGN_OR_RETURN(
ArrayIndex * new_array_index,
array_index->ReplaceUsesWithNew<ArrayIndex>(
array->operand(operand_no), array_index->indices().subspan(1)));
return SimplifyResult::Changed({new_array_index});
XLS_RET_CHECK(!array->operands().empty());

VLOG(2) << absl::StrFormat(
"Array-index of array operation with constant index: %s",
array_index->ToString());

Bits first_index =
*query_engine.KnownValueAsBits(array_index->indices().front());

// Indices are always interpreted as unsigned numbers, and if past the end
// of the array, are clamped to the last value.
uint64_t operand_no;
if (bits_ops::UGreaterThan(first_index,
UBits(array->operand_count() - 1, 64))) {
operand_no = array->operand_count() - 1;
} else {
XLS_ASSIGN_OR_RETURN(operand_no, first_index.ToUint64());
}
XLS_ASSIGN_OR_RETURN(
ArrayIndex * new_array_index,
array_index->ReplaceUsesWithNew<ArrayIndex>(
array->operand(operand_no), array_index->indices().subspan(1)));
return SimplifyResult::Changed({new_array_index});
}

// An array index which indexes into a kArray operation can be replaced with a
// select between the values, using the last entry as the default value to
// reproduce the array index's clamping behavior if necessary:
//
// array_index(array(a, b, c), {i, j, k, ...}
// => select(i, cases=[array_index(a, {j, k, ...}),
// array_index(b, {j, k, ...})],
// default_value=array_index(c, {j, k, ...}))
//
if (array_index->array()->Is<Array>() && !array_index->indices().empty()) {
Array* array = array_index->array()->As<Array>();
XLS_RET_CHECK(!array->operands().empty());

VLOG(2) << absl::StrFormat("Array-index of array operation: %s",
array_index->ToString());

absl::Span<Node* const> indices = array_index->indices();
Node* selector = array_index->indices().front();
indices.remove_prefix(1);

absl::Span<Node* const> cases;
std::vector<Node*> new_array_indexes;
uint64_t reachable_size =
uint64_t{1} << std::min(int64_t{63}, selector->BitCountOrDie());
absl::Span<Node* const> reachable_operands =
array->operands().subspan(0, reachable_size);
if (indices.empty()) {
cases = reachable_operands;
} else {
new_array_indexes.reserve(reachable_operands.size());
for (Node* entry : reachable_operands) {
XLS_ASSIGN_OR_RETURN(ArrayIndex * subindex,
array_index->function_base()->MakeNode<ArrayIndex>(
array_index->loc(), entry, indices));
new_array_indexes.push_back(subindex);
}
cases = absl::MakeConstSpan(new_array_indexes);
}

std::optional<Node*> default_value;
if (selector->BitCountOrDie() >= Bits::MinBitCountUnsigned(cases.size())) {
// The selector can represent values that are past the end of the array,
// so move the last case to a default value to provide clamping.
default_value = cases.back();
cases.remove_suffix(1);
}

XLS_RETURN_IF_ERROR(
array_index->ReplaceUsesWithNew<Select>(selector, cases, default_value)
.status());
return SimplifyResult::Changed(new_array_indexes);
}

// An array index which indexes into a kArrayConcat operation and whose first
Expand Down
69 changes: 69 additions & 0 deletions xls/passes/array_simplification_pass_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,60 @@ TEST_F(ArraySimplificationPassTest, LiteralArrayWithNonLiteralIndex) {
}

TEST_F(ArraySimplificationPassTest, IndexingArrayOperation) {
auto p = CreatePackage();
FunctionBuilder fb(TestName(), p.get());
Type* u16 = p->GetBitsType(16);
Type* u32 = p->GetBitsType(32);
BValue a = fb.Array(
{fb.Param("x", u32), fb.Param("y", u32), fb.Param("z", u32)}, u32);
BValue index = fb.Param("i", u16);
fb.ArrayIndex(a, {index});
XLS_ASSERT_OK_AND_ASSIGN(Function * f, fb.Build());
solvers::z3::ScopedVerifyEquivalence stays_equivalent(f);
ASSERT_THAT(Run(f), IsOkAndHolds(true));
EXPECT_THAT(
f->return_value(),
m::Select(m::Param("i"), {m::Param("x"), m::Param("y")}, m::Param("z")));
}

TEST_F(ArraySimplificationPassTest, IndexingArrayOperationExactFit) {
auto p = CreatePackage();
FunctionBuilder fb(TestName(), p.get());
Type* u2 = p->GetBitsType(2);
Type* u32 = p->GetBitsType(32);
BValue a = fb.Array({fb.Param("x", u32), fb.Param("y", u32),
fb.Param("z", u32), fb.Param("w", u32)},
u32);
BValue index = fb.Param("i", u2);
fb.ArrayIndex(a, {index});
XLS_ASSERT_OK_AND_ASSIGN(Function * f, fb.Build());
solvers::z3::ScopedVerifyEquivalence stays_equivalent(f);
ASSERT_THAT(Run(f), IsOkAndHolds(true));
EXPECT_THAT(f->return_value(),
m::Select(m::Param("i"), {m::Param("x"), m::Param("y"),
m::Param("z"), m::Param("w")}));
}

TEST_F(ArraySimplificationPassTest, IndexingArrayOperationUndersizedIndex) {
auto p = CreatePackage();
FunctionBuilder fb(TestName(), p.get());
Type* u2 = p->GetBitsType(2);
Type* u32 = p->GetBitsType(32);
BValue a =
fb.Array({fb.Param("x", u32), fb.Param("y", u32), fb.Param("z", u32),
fb.Param("w", u32), fb.Param("q", u32)},
u32);
BValue index = fb.Param("i", u2);
fb.ArrayIndex(a, {index});
XLS_ASSERT_OK_AND_ASSIGN(Function * f, fb.Build());
solvers::z3::ScopedVerifyEquivalence stays_equivalent(f);
ASSERT_THAT(Run(f), IsOkAndHolds(true));
EXPECT_THAT(f->return_value(),
m::Select(m::Param("i"), {m::Param("x"), m::Param("y"),
m::Param("z"), m::Param("w")}));
}

TEST_F(ArraySimplificationPassTest, IndexingArrayOperationWithLiteralIndex) {
auto p = CreatePackage();
FunctionBuilder fb(TestName(), p.get());
Type* u32 = p->GetBitsType(32);
Expand All @@ -124,6 +178,21 @@ TEST_F(ArraySimplificationPassTest, IndexingArrayOperation) {
BValue index = fb.Literal(Value(UBits(2, 16)));
fb.ArrayIndex(a, {index});
XLS_ASSERT_OK_AND_ASSIGN(Function * f, fb.Build());
solvers::z3::ScopedVerifyEquivalence stays_equivalent(f);
ASSERT_THAT(Run(f), IsOkAndHolds(true));
EXPECT_THAT(f->return_value(), m::Param("z"));
}

TEST_F(ArraySimplificationPassTest, IndexingArrayOperationWithOobLiteralIndex) {
auto p = CreatePackage();
FunctionBuilder fb(TestName(), p.get());
Type* u32 = p->GetBitsType(32);
BValue a = fb.Array(
{fb.Param("x", u32), fb.Param("y", u32), fb.Param("z", u32)}, u32);
BValue index = fb.Literal(Value(UBits(5, 16)));
fb.ArrayIndex(a, {index});
XLS_ASSERT_OK_AND_ASSIGN(Function * f, fb.Build());
solvers::z3::ScopedVerifyEquivalence stays_equivalent(f);
ASSERT_THAT(Run(f), IsOkAndHolds(true));
EXPECT_THAT(f->return_value(), m::Param("z"));
}
Expand Down

0 comments on commit e9aa42c

Please sign in to comment.