Skip to content

Commit

Permalink
Support (non-parametric) impl in dslx.
Browse files Browse the repository at this point in the history
#1277

PiperOrigin-RevId: 684055362
  • Loading branch information
erinzmoore authored and copybara-github committed Oct 9, 2024
1 parent 5947d5a commit 53f30d0
Show file tree
Hide file tree
Showing 15 changed files with 208 additions and 20 deletions.
10 changes: 10 additions & 0 deletions xls/dslx/bytecode/bytecode_emitter.cc
Original file line number Diff line number Diff line change
Expand Up @@ -855,6 +855,16 @@ absl::StatusOr<InterpValue> BytecodeEmitter::HandleColonRefInternal(
[&](Module* module) -> absl::StatusOr<InterpValue> {
return HandleColonRefToValue(module, node);
},
[&](StructDef* struct_def) -> absl::StatusOr<InterpValue> {
std::optional<ConstantDef*> constant_def =
struct_def->GetImplConstant(node->attr());
if (!constant_def.has_value()) {
return absl::NotFoundError(absl::StrFormat(
"No impl with constant '%s' defined for struct '%s'",
node->attr(), struct_def->identifier()));
}
return type_info_->GetConstExpr(constant_def.value());
},
},
resolved_subject);
}
Expand Down
35 changes: 35 additions & 0 deletions xls/dslx/bytecode/bytecode_emitter_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -808,6 +808,41 @@ fn imported_enum_ref() -> import_0::ImportedEnum {
IsOkAndHolds(InterpValue::MakeSBits(4, 2)));
}

TEST(BytecodeEmitterTest, StructImplConstant) {
constexpr std::string_view kBaseProgram = R"(
struct Empty {}
impl Empty {
const MY_CONST = u4:7;
}
#[test]
fn struct_const_ref() -> u4 {
Empty::MY_CONST
}
)";

auto import_data = CreateImportDataForTest();
XLS_ASSERT_OK_AND_ASSIGN(
TypecheckedModule tm,
ParseAndTypecheck(kBaseProgram, "test.x", "test", &import_data));

XLS_ASSERT_OK_AND_ASSIGN(TestFunction * tf,
tm.module->GetTest("struct_const_ref"));
XLS_ASSERT_OK_AND_ASSIGN(std::unique_ptr<BytecodeFunction> bf,
BytecodeEmitter::Emit(&import_data, tm.type_info,
tf->fn(), ParametricEnv()));

const std::vector<Bytecode>& bytecodes = bf->bytecodes();
ASSERT_EQ(bytecodes.size(), 1);

const Bytecode* bc = bytecodes.data();
ASSERT_EQ(bc->op(), Bytecode::Op::kLiteral);
ASSERT_TRUE(bc->has_data());
EXPECT_THAT(bytecodes.at(0).value_data(),
IsOkAndHolds(InterpValue::MakeSBits(4, 7)));
}

TEST(BytecodeEmitterTest, ImportedConstant) {
constexpr std::string_view kImportedProgram = R"(pub const MY_CONST = u3:2;)";
constexpr std::string_view kBaseProgram = R"(
Expand Down
18 changes: 18 additions & 0 deletions xls/dslx/constexpr_evaluator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -389,6 +389,24 @@ absl::Status ConstexprEvaluator::HandleColonRef(const ColonRef* expr) {
expr, type_info->GetConstExpr(constant_def->value()).value());
return absl::OkStatus();
},
[&](StructDef* struct_def) -> absl::Status {
std::optional<ConstantDef*> constant_def =
struct_def->GetImplConstant(expr->attr());
if (!constant_def.has_value()) {
return absl::NotFoundError(absl::StrFormat(
"No impl with constant '%s' defined for struct '%s'",
expr->attr(), struct_def->identifier()));
}
XLS_RETURN_IF_ERROR(Evaluate(import_data_, type_info_,
warning_collector_, bindings_,
constant_def.value()->value()));
XLS_RET_CHECK(
type_info_->IsKnownConstExpr(constant_def.value()->value()));
type_info_->NoteConstExpr(
expr, type_info_->GetConstExpr(constant_def.value()->value())
.value());
return absl::OkStatus();
},
},
subject);
}
Expand Down
60 changes: 60 additions & 0 deletions xls/dslx/constexpr_evaluator_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -619,5 +619,65 @@ fn main() -> MyArray {
EXPECT_THAT(value.GetLength(), IsOkAndHolds(4));
}

TEST(ConstexprEvaluatorTest, ImplWithConstantSimple) {
constexpr std::string_view kProgram = R"(
struct MyStruct {}
impl MyStruct {
const STRUCT_CONST = u32:7;
}
fn main() -> u32 {
MyStruct::STRUCT_CONST
}
)";

ImportData import_data(CreateImportDataForTest());
XLS_ASSERT_OK_AND_ASSIGN(
TypecheckedModule tm,
ParseAndTypecheck(kProgram, "test.x", "test", &import_data));

XLS_ASSERT_OK_AND_ASSIGN(Function * f,
tm.module->GetMemberOrError<Function>("main"));
WarningCollector warnings(kAllWarningsSet);
XLS_ASSERT_OK(ConstexprEvaluator::Evaluate(&import_data, tm.type_info,
&warnings, ParametricEnv(),
f->body(), nullptr));
XLS_ASSERT_OK_AND_ASSIGN(InterpValue value,
tm.type_info->GetConstExpr(f->body()));
EXPECT_EQ(value.GetBitValueViaSign().value(), 7);
}

TEST(ConstexprEvaluatorTest, ImplWithConstantRefGlobal) {
constexpr std::string_view kProgram = R"(
const SIZE = u32:4;
struct MyStruct {}
impl MyStruct {
const STRUCT_CONST = u32:2 * SIZE;
}
fn main() -> u32 {
MyStruct::STRUCT_CONST
}
)";

ImportData import_data(CreateImportDataForTest());
XLS_ASSERT_OK_AND_ASSIGN(
TypecheckedModule tm,
ParseAndTypecheck(kProgram, "test.x", "test", &import_data));

XLS_ASSERT_OK_AND_ASSIGN(Function * f,
tm.module->GetMemberOrError<Function>("main"));
WarningCollector warnings(kAllWarningsSet);
XLS_ASSERT_OK(ConstexprEvaluator::Evaluate(&import_data, tm.type_info,
&warnings, ParametricEnv(),
f->body(), nullptr));
XLS_ASSERT_OK_AND_ASSIGN(InterpValue value,
tm.type_info->GetConstExpr(f->body()));
EXPECT_EQ(value.GetBitValueViaSign().value(), 8);
}

} // namespace
} // namespace xls::dslx
8 changes: 8 additions & 0 deletions xls/dslx/frontend/ast.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1296,6 +1296,14 @@ std::vector<std::string> StructDef::GetMemberNames() const {
return names;
}

std::optional<ConstantDef*> StructDef::GetImplConstant(
std::string_view constant_name) {
if (!impl_.has_value()) {
return std::nullopt;
}
return impl_.value()->GetConstant(constant_name);
}

// -- class Impl

Impl::Impl(Module* owner, Span span, TypeAnnotation* struct_ref,
Expand Down
2 changes: 2 additions & 0 deletions xls/dslx/frontend/ast.h
Original file line number Diff line number Diff line change
Expand Up @@ -2311,6 +2311,8 @@ class StructDef : public AstNode {

std::optional<Impl*> impl() const { return impl_; }

std::optional<ConstantDef*> GetImplConstant(std::string_view constant_name);

private:
Span span_;
NameDef* name_def_;
Expand Down
2 changes: 2 additions & 0 deletions xls/dslx/frontend/module.cc
Original file line number Diff line number Diff line change
Expand Up @@ -257,6 +257,8 @@ std::optional<ModuleMember*> Module::FindMemberWithName(
}
} else if (std::holds_alternative<ConstAssert*>(member)) {
continue; // These have no name / binding.
} else if (std::holds_alternative<Impl*>(member)) {
continue; // These have no name / binding.
} else {
LOG(FATAL) << "Unhandled module member variant: "
<< ToAstNode(member)->GetNodeTypeName();
Expand Down
16 changes: 15 additions & 1 deletion xls/dslx/ir_convert/function_converter.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2801,7 +2801,21 @@ absl::Status FunctionConverter::HandleColonRef(const ColonRef* node) {
DefConst(node, value);
return absl::OkStatus();
},
},
[&](StructDef* struct_def) -> absl::Status {
std::optional<ConstantDef*> constant_def =
struct_def->GetImplConstant(node->attr());
if (!constant_def.has_value()) {
return absl::NotFoundError(absl::StrFormat(
"No impl with constant '%s' defined for struct '%s'",
node->attr(), struct_def->identifier()));
}
XLS_ASSIGN_OR_RETURN(
InterpValue iv,
current_type_info_->GetConstExpr(constant_def.value()));
XLS_ASSIGN_OR_RETURN(Value value, InterpValueToValue(iv));
DefConst(node, value);
return absl::OkStatus();
}},
subject);
}

Expand Down
5 changes: 3 additions & 2 deletions xls/dslx/lsp/find_definition.cc
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,8 @@ namespace xls::dslx {
namespace {

const NameDef* GetNameDef(
const std::variant<Module*, EnumDef*, BuiltinNameDef*,
ArrayTypeAnnotation*>& colon_ref_subject,
const std::variant<Module*, EnumDef*, BuiltinNameDef*, ArrayTypeAnnotation*,
StructDef*>& colon_ref_subject,
std::string_view attr) {
return absl::visit(
Visitor{
Expand All @@ -46,6 +46,7 @@ const NameDef* GetNameDef(
return ModuleMemberGetNameDef(*member.value());
},
[&](EnumDef* e) -> const NameDef* { return e->GetNameDef(attr); },
[&](StructDef* s) -> const NameDef* { return s->name_def(); },
[](BuiltinNameDef*) -> const NameDef* { return nullptr; },
[](ArrayTypeAnnotation*) -> const NameDef* { return nullptr; },
},
Expand Down
2 changes: 2 additions & 0 deletions xls/dslx/tests/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -444,6 +444,8 @@ dslx_lang_test(

dslx_lang_test(name = "struct_as_parametric")

dslx_lang_test(name = "impl")

dslx_lang_test(name = "subtract_to_negative")

dslx_lang_test(name = "trace")
Expand Down
30 changes: 30 additions & 0 deletions xls/dslx/tests/impl.x
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
// Copyright 2024 The XLS Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

struct Point { x: u32, y: u32 }

impl Point {
const MY_CONST = u32:5;
}

fn main() -> u32 {
let p = Point { x: u32:7, y: u32:8 };
p::MY_CONST
}

#[test]
fn use_impl_const() {
let p = Point { x: u32:7, y: u32:8 };
assert_eq(p::MY_CONST, u32:5);
}
2 changes: 1 addition & 1 deletion xls/dslx/type_system/deduce.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1044,7 +1044,7 @@ static absl::StatusOr<std::unique_ptr<Type>> DeduceColonRefToImpl(
node->attr(), struct_def->identifier()),
ctx->file_table());
}
return DeduceConstantDef(constant.value(), ctx);
return ctx->Deduce(constant.value());
}

absl::StatusOr<std::unique_ptr<Type>> DeduceColonRef(const ColonRef* node,
Expand Down
14 changes: 6 additions & 8 deletions xls/dslx/type_system/deduce_utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -437,25 +437,23 @@ absl::StatusOr<ColonRefSubjectT> ResolveColonRefSubjectForTypeChecking(
td.value());
}

absl::StatusOr<
std::variant<Module*, EnumDef*, BuiltinNameDef*, ArrayTypeAnnotation*>>
absl::StatusOr<std::variant<Module*, EnumDef*, BuiltinNameDef*,
ArrayTypeAnnotation*, StructDef*>>
ResolveColonRefSubjectAfterTypeChecking(ImportData* import_data,
const TypeInfo* type_info,
const ColonRef* colon_ref) {
XLS_ASSIGN_OR_RETURN(auto result, ResolveColonRefSubjectForTypeChecking(
import_data, type_info, colon_ref));
using ReturnT = absl::StatusOr<
std::variant<Module*, EnumDef*, BuiltinNameDef*, ArrayTypeAnnotation*>>;
using ReturnT =
absl::StatusOr<std::variant<Module*, EnumDef*, BuiltinNameDef*,
ArrayTypeAnnotation*, StructDef*>>;
return absl::visit(
Visitor{
[](Module* x) -> ReturnT { return x; },
[](EnumDef* x) -> ReturnT { return x; },
[](BuiltinNameDef* x) -> ReturnT { return x; },
[](ArrayTypeAnnotation* x) -> ReturnT { return x; },
[](StructDef*) -> ReturnT {
return absl::InternalError(
"After type checking colon-ref subject cannot be a StructDef");
},
[](StructDef* x) -> ReturnT { return x; },
[](ColonRef*) -> ReturnT {
return absl::InternalError(
"After type checking colon-ref subject cannot be a StructDef");
Expand Down
8 changes: 3 additions & 5 deletions xls/dslx/type_system/deduce_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -66,9 +66,7 @@ absl::Status ValidateNumber(const Number& number, const Type& type);
// * a module
// * an enum definition
// * a builtin type (with a constant item on it, a la `u7::MAX`)
//
// Struct definitions cannot currently have constant items on them, so this will
// have to be flagged by the type checker.
// * a constant defined via `impl` on a `StructDef`.
absl::StatusOr<std::variant<Module*, EnumDef*, BuiltinNameDef*,
ArrayTypeAnnotation*, StructDef*, ColonRef*>>
ResolveColonRefSubjectForTypeChecking(ImportData* import_data,
Expand All @@ -78,8 +76,8 @@ ResolveColonRefSubjectForTypeChecking(ImportData* import_data,
// Implementation of the above that can be called after type checking has been
// performed, in which case we can eliminate some of the (invalid) possibilities
// so they no longer need to be handled.
absl::StatusOr<
std::variant<Module*, EnumDef*, BuiltinNameDef*, ArrayTypeAnnotation*>>
absl::StatusOr<std::variant<Module*, EnumDef*, BuiltinNameDef*,
ArrayTypeAnnotation*, StructDef*>>
ResolveColonRefSubjectAfterTypeChecking(ImportData* import_data,
const TypeInfo* type_info,
const ColonRef* colon_ref);
Expand Down
16 changes: 13 additions & 3 deletions xls/dslx/type_system/impl_typecheck_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -71,9 +71,7 @@ const GLOBAL_DIMS = NUM_DIMS;
HasSubstr("Cannot find a definition")));
}

// TODO: https://github.com/google/xls/issues/1277 - Support assignment in
// expressions.
TEST(TypecheckTest, DISABLED_ImplConstantExtracted) {
TEST(TypecheckTest, ImplConstantExtracted) {
constexpr std::string_view kProgram = R"(
struct Point { x: u32, y: u32 }
Expand All @@ -86,6 +84,18 @@ const GLOBAL_DIMS = Point::NUM_DIMS;
XLS_EXPECT_OK(Typecheck(kProgram));
}

TEST(TypecheckErrorTest, ConstantExtractionWithoutImpl) {
constexpr std::string_view kProgram = R"(
struct Point { x: u32, y: u32 }
const GLOBAL_DIMS = Point::NUM_DIMS;
)";
EXPECT_THAT(
Typecheck(kProgram),
StatusIs(absl::StatusCode::kInvalidArgument,
HasSubstr("Struct 'Point' has no impl defining 'NUM_DIMS'")));
}

TEST(TypecheckErrorTest, ConstantAccessWithoutImplDef) {
constexpr std::string_view kProgram = R"(
struct Point { x: u32, y: u32 }
Expand Down

0 comments on commit 53f30d0

Please sign in to comment.