From 825f277996c6d6c1635ea81d4976c4c8c3117b3a Mon Sep 17 00:00:00 2001 From: Mark Heffernan Date: Tue, 8 Oct 2024 12:21:49 -0700 Subject: [PATCH] Add proc-scoped channel support to channel legalization pass. Incidentally change the way in which channels are named. The data, predicate and completion channels are now numbered consistently. For example, the channels associated with the 42nd op instance of channel `foo` are: `foo__data_42`, `foo__pred_42` and `foo__condition_42`, PiperOrigin-RevId: 683721323 --- xls/ir/BUILD | 2 + xls/ir/channel.cc | 44 + xls/ir/channel.h | 10 + xls/ir/package.cc | 3 + xls/ir/proc_elaboration.cc | 1 + xls/passes/BUILD | 3 + xls/passes/channel_legalization_pass.cc | 879 +++++++++++++------ xls/passes/channel_legalization_pass_test.cc | 171 +++- 8 files changed, 826 insertions(+), 287 deletions(-) diff --git a/xls/ir/BUILD b/xls/ir/BUILD index 77600309ce..082f21152b 100644 --- a/xls/ir/BUILD +++ b/xls/ir/BUILD @@ -1583,6 +1583,7 @@ cc_library( "//xls/common/status:ret_check", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/log", + "@com_google_absl//absl/log:check", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", @@ -1896,6 +1897,7 @@ cc_library( "//xls/common/status:status_macros", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/log", "@com_google_absl//absl/log:check", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", diff --git a/xls/ir/channel.cc b/xls/ir/channel.cc index 5462a1a14a..14007fd0a0 100644 --- a/xls/ir/channel.cc +++ b/xls/ir/channel.cc @@ -22,6 +22,7 @@ #include #include +#include "absl/log/check.h" #include "absl/log/log.h" #include "absl/status/status.h" #include "absl/status/statusor.h" @@ -229,6 +230,38 @@ std::ostream& operator<<(std::ostream& os, Direction direction) { return os; } +ChannelRef AsChannelRef(SendChannelRef ref) { + if (std::holds_alternative(ref)) { + return std::get(ref); + } + return std::get(ref); +} + +ChannelRef AsChannelRef(ReceiveChannelRef ref) { + if (std::holds_alternative(ref)) { + return std::get(ref); + } + return std::get(ref); +} + +SendChannelRef AsSendChannelRefOrDie(ChannelRef ref) { + if (std::holds_alternative(ref)) { + ChannelReference* cref = std::get(ref); + CHECK_EQ(cref->direction(), Direction::kSend); + return down_cast(cref); + } + return std::get(ref); +} + +ReceiveChannelRef AsReceiveChannelRefOrDie(ChannelRef ref) { + if (std::holds_alternative(ref)) { + ChannelReference* cref = std::get(ref); + CHECK_EQ(cref->direction(), Direction::kReceive); + return down_cast(cref); + } + return std::get(ref); +} + std::string_view ChannelRefName(ChannelRef ref) { return absl::visit([](const auto& ch) { return ch->name(); }, ref); } @@ -247,6 +280,17 @@ ChannelKind ChannelRefKind(ChannelRef ref) { return std::get(ref)->kind(); } +std::optional ChannelRefStrictness(ChannelRef ref) { + if (std::holds_alternative(ref)) { + return std::get(ref)->strictness(); + } + if (auto streaming_channel = + down_cast(std::get(ref))) { + return streaming_channel->GetStrictness(); + } + return std::nullopt; +} + std::string ChannelReference::ToString() const { std::vector keyword_strs; keyword_strs.push_back( diff --git a/xls/ir/channel.h b/xls/ir/channel.h index 48c1cc9ca2..be97eaea7d 100644 --- a/xls/ir/channel.h +++ b/xls/ir/channel.h @@ -420,10 +420,20 @@ using ChannelRef = std::variant; using SendChannelRef = std::variant; using ReceiveChannelRef = std::variant; +// Converts a send/receive ChannelRef into a generic ChannelRef. +ChannelRef AsChannelRef(SendChannelRef ref); +ChannelRef AsChannelRef(ReceiveChannelRef ref); + +// Converts a base ChannelRef into a send/receive form. CHECK fails if the +// ChannelRef is not of the appropriate direction. +SendChannelRef AsSendChannelRefOrDie(ChannelRef ref); +ReceiveChannelRef AsReceiveChannelRefOrDie(ChannelRef ref); + // Return the name/type/kind of a channel reference. std::string_view ChannelRefName(ChannelRef ref); Type* ChannelRefType(ChannelRef ref); ChannelKind ChannelRefKind(ChannelRef ref); +std::optional ChannelRefStrictness(ChannelRef ref); } // namespace xls diff --git a/xls/ir/package.cc b/xls/ir/package.cc index 18bca4c07b..9a686b1221 100644 --- a/xls/ir/package.cc +++ b/xls/ir/package.cc @@ -768,6 +768,7 @@ absl::Status Package::AddChannel(std::unique_ptr channel, Proc* proc) { } absl::StatusOr Package::GetChannel(int64_t id) const { + XLS_RET_CHECK(!ChannelsAreProcScoped()); for (Channel* ch : channels()) { if (ch->id() == id) { return ch; @@ -777,6 +778,7 @@ absl::StatusOr Package::GetChannel(int64_t id) const { } std::vector Package::GetChannelNames() const { + CHECK(!ChannelsAreProcScoped()); std::vector names; names.reserve(channels().size()); for (Channel* ch : channels()) { @@ -786,6 +788,7 @@ std::vector Package::GetChannelNames() const { } absl::StatusOr Package::GetChannel(std::string_view name) const { + XLS_RET_CHECK(!ChannelsAreProcScoped()); auto it = channels_.find(name); if (it != channels_.end()) { return it->second.get(); diff --git a/xls/ir/proc_elaboration.cc b/xls/ir/proc_elaboration.cc index ea6852a284..fd2727ccd8 100644 --- a/xls/ir/proc_elaboration.cc +++ b/xls/ir/proc_elaboration.cc @@ -25,6 +25,7 @@ #include "absl/container/flat_hash_map.h" #include "absl/container/flat_hash_set.h" +#include "absl/log/log.h" #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/str_cat.h" diff --git a/xls/passes/BUILD b/xls/passes/BUILD index 4939b9ea52..ac2714ffba 100644 --- a/xls/passes/BUILD +++ b/xls/passes/BUILD @@ -2909,6 +2909,7 @@ cc_library( "//xls/ir:name_uniquer", "//xls/ir:node_util", "//xls/ir:op", + "//xls/ir:proc_elaboration", "//xls/ir:source_location", "//xls/ir:value", "@com_google_absl//absl/container:btree", @@ -2916,6 +2917,7 @@ cc_library( "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/log", "@com_google_absl//absl/log:check", + "@com_google_absl//absl/memory", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", @@ -2993,6 +2995,7 @@ cc_test( "//xls/ir:ir_matcher", "//xls/ir:ir_parser", "//xls/ir:ir_test_base", + "//xls/ir:proc_elaboration", "//xls/ir:value", "//xls/ir:verifier", "@com_google_absl//absl/container:flat_hash_map", diff --git a/xls/passes/channel_legalization_pass.cc b/xls/passes/channel_legalization_pass.cc index 2622c2a509..eee7885d5d 100644 --- a/xls/passes/channel_legalization_pass.cc +++ b/xls/passes/channel_legalization_pass.cc @@ -17,6 +17,7 @@ #include #include #include +#include #include #include #include @@ -30,6 +31,7 @@ #include "absl/container/flat_hash_set.h" #include "absl/log/check.h" #include "absl/log/log.h" +#include "absl/memory/memory.h" #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/str_cat.h" @@ -50,6 +52,7 @@ #include "xls/ir/nodes.h" #include "xls/ir/op.h" #include "xls/ir/package.h" +#include "xls/ir/proc_elaboration.h" #include "xls/ir/source_location.h" #include "xls/ir/value.h" #include "xls/passes/optimization_pass.h" @@ -60,6 +63,227 @@ namespace xls { namespace { + +// Data structure describing a channel for sending data to the adapter. +struct AdapterInputChannel { + StreamingChannel* channel; + // ChannelRef for sending data on the channel outside the adapter. + SendChannelRef parent_send_channel_ref; + // ChannelRef for receiving data on the channel inside the adapter. + ReceiveChannelRef adapter_receive_channel_ref; +}; + +// Data structure describing a channel for receiving data from the adapter. +struct AdapterOutputChannel { + StreamingChannel* channel; + // ChannelRef for receiving data on the channel outside the adapter. + ReceiveChannelRef parent_receive_channel_ref; + // ChannelRef for sending data on the channel inside the adapter. + SendChannelRef adapter_send_channel_ref; +}; + +// Helper class for building an adapter. Abstracts away proc-scoped vs global +// channel behind a uniform interface. +class AdapterBuilder { + public: + static absl::StatusOr> + CreateForProcScopedChannels(ChannelReference* adapted_channel, + std::string_view adapter_name, + Proc* parent_proc) { + XLS_RET_CHECK(parent_proc->package()->ChannelsAreProcScoped()); + std::unique_ptr proc_builder = std::make_unique( + NewStyleProc(), adapter_name, parent_proc->package()); + + // Add input/output for adapted channel. + ChannelReference* adapted_channel_ref_in_adapter; + if (adapted_channel->direction() == Direction::kSend) { + XLS_ASSIGN_OR_RETURN( + adapted_channel_ref_in_adapter, + proc_builder->AddOutputChannel( + adapted_channel->name(), adapted_channel->type(), + ChannelKind::kStreaming, adapted_channel->strictness())); + } else { + XLS_ASSIGN_OR_RETURN( + adapted_channel_ref_in_adapter, + proc_builder->AddInputChannel( + adapted_channel->name(), adapted_channel->type(), + ChannelKind::kStreaming, adapted_channel->strictness())); + } + + // Prime the name uniquer with the current channels in the proc. + auto name_uniquer = std::make_unique("__"); + for (Channel* ch : parent_proc->channels()) { + name_uniquer->GetSanitizedUniqueName(ch->name()); + } + std::unique_ptr builder = + absl::WrapUnique(new AdapterBuilder( + adapted_channel, adapted_channel_ref_in_adapter, + std::move(proc_builder), parent_proc, + /*global_channel_name_uniquer=*/nullptr, + /*proc_scoped_channel_name_uniquer=*/std::move(name_uniquer))); + + builder->adapter_instantiation_args_.push_back(adapted_channel); + return builder; + } + + static absl::StatusOr> + CreateForGlobalChannels(Channel* adapted_channel, + std::string_view adapter_name, Package* package, + NameUniquer& channel_name_uniquer) { + XLS_RET_CHECK(!package->ChannelsAreProcScoped()); + std::unique_ptr proc_builder = + std::make_unique(adapter_name, package); + return absl::WrapUnique(new AdapterBuilder( + adapted_channel, adapted_channel, std::move(proc_builder), + /*parent_proc=*/nullptr, + /*global_channel_name_uniquer=*/&channel_name_uniquer, + /*proc_scoped_channel_name_uniquer=*/nullptr)); + } + + // Create a channel for sending to the adapter. + absl::StatusOr AddAdapterInputChannel( + std::string_view name, Type* type, const FifoConfig& fifo_config, + std::optional strictness = std::nullopt) { + AdapterInputChannel result; + if (ChannelsAreProcScoped()) { + std::string uniquified_name = + proc_scoped_channel_name_uniquer_->GetSanitizedUniqueName(name); + XLS_ASSIGN_OR_RETURN(result.channel, + package()->CreateStreamingChannelInProc( + uniquified_name, ChannelOps::kSendReceive, type, + parent_proc().value(), + /*initial_values=*/{}, fifo_config)); + XLS_ASSIGN_OR_RETURN(result.parent_send_channel_ref, + parent_proc().value()->GetSendChannelReference( + result.channel->name())); + XLS_ASSIGN_OR_RETURN(ReceiveChannelReference * parent_receive_channel_ref, + parent_proc().value()->GetReceiveChannelReference( + result.channel->name())); + XLS_ASSIGN_OR_RETURN( + result.adapter_receive_channel_ref, + adapter_builder().AddInputChannel(name, type, ChannelKind::kStreaming, + strictness)); + adapter_instantiation_args_.push_back(parent_receive_channel_ref); + } else { + std::string uniquified_name = + global_channel_name_uniquer_->GetSanitizedUniqueName(name); + XLS_ASSIGN_OR_RETURN(result.channel, + package()->CreateStreamingChannel( + uniquified_name, ChannelOps::kSendReceive, type, + /*initial_values=*/{}, fifo_config)); + result.parent_send_channel_ref = result.channel; + result.adapter_receive_channel_ref = result.channel; + } + return result; + } + + // Create a channel for receiving from the adapter. + absl::StatusOr AddAdapterOutputChannel( + std::string_view name, Type* type, const FifoConfig& fifo_config, + std::optional strictness = std::nullopt) { + AdapterOutputChannel result; + if (ChannelsAreProcScoped()) { + std::string uniquified_name = + proc_scoped_channel_name_uniquer_->GetSanitizedUniqueName(name); + XLS_ASSIGN_OR_RETURN(result.channel, + package()->CreateStreamingChannelInProc( + uniquified_name, ChannelOps::kSendReceive, type, + parent_proc().value(), + /*initial_values=*/{}, fifo_config)); + XLS_ASSIGN_OR_RETURN(result.parent_receive_channel_ref, + parent_proc().value()->GetReceiveChannelReference( + result.channel->name())); + XLS_ASSIGN_OR_RETURN(SendChannelReference * parent_send_channel_ref, + parent_proc().value()->GetSendChannelReference( + result.channel->name())); + XLS_ASSIGN_OR_RETURN( + result.adapter_send_channel_ref, + adapter_builder().AddOutputChannel( + name, type, ChannelKind::kStreaming, strictness)); + adapter_instantiation_args_.push_back(parent_send_channel_ref); + } else { + std::string uniquified_name = + global_channel_name_uniquer_->GetSanitizedUniqueName(name); + XLS_ASSIGN_OR_RETURN(result.channel, + package()->CreateStreamingChannel( + uniquified_name, ChannelOps::kSendReceive, type, + /*initial_values=*/{}, fifo_config)); + result.parent_receive_channel_ref = result.channel; + result.adapter_send_channel_ref = result.channel; + } + return result; + } + + // Build and return the adapter proc. + absl::StatusOr Build(absl::Span next_state) { + if (ChannelsAreProcScoped()) { + XLS_RETURN_IF_ERROR( + parent_proc() + .value() + ->AddProcInstantiation( + absl::StrFormat("%s_adapter_inst", adapted_channel_name()), + adapter_instantiation_args_, adapter_proc()) + .status()); + } + return adapter_builder().Build(next_state); + } + + bool ChannelsAreProcScoped() { return adapter_proc()->is_new_style_proc(); } + Package* package() { return adapter_proc()->package(); } + + // For proc-scoped channels, returns the proc which instantates and + // communicates with the adapter. For global channels, returns nullopt. + std::optional parent_proc() { return parent_proc_; } + + Proc* adapter_proc() { return adapter_builder().proc(); } + ProcBuilder& adapter_builder() { return *adapter_builder_; } + + // Returns the ChannelRef of original (adapted) channel for sending/receiving + // outside the adapter. + ChannelRef adapted_channel_ref_in_parent() { + return adapted_channel_ref_in_parent_; + } + // Returns the ChannelRef of original (adapted) channel for sending/receiving + // inside the adapter. + ChannelRef adapted_channel_ref_in_adapter() { + return adapted_channel_ref_in_adapter_; + } + + // Returns various properties of the original (adapted) channel. + ChannelStrictness adapted_channel_strictness() { + return ChannelRefStrictness(adapted_channel_ref_in_parent()).value(); + } + std::string_view adapted_channel_name() { + return ChannelRefName(adapted_channel_ref_in_parent()); + } + Type* adapted_channel_type() { + return ChannelRefType(adapted_channel_ref_in_parent()); + } + + private: + AdapterBuilder(ChannelRef adapted_channel_ref_in_parent, + ChannelRef adapted_channel_ref_in_adapter, + std::unique_ptr adapter_builder, + std::optional parent_proc, + NameUniquer* global_channel_name_uniquer, + std::unique_ptr proc_scoped_channel_name_uniquer) + : adapted_channel_ref_in_parent_(adapted_channel_ref_in_parent), + adapted_channel_ref_in_adapter_(adapted_channel_ref_in_adapter), + adapter_builder_(std::move(adapter_builder)), + parent_proc_(parent_proc), + global_channel_name_uniquer_(global_channel_name_uniquer), + proc_scoped_channel_name_uniquer_( + std::move(proc_scoped_channel_name_uniquer)) {} + + ChannelRef adapted_channel_ref_in_parent_; + ChannelRef adapted_channel_ref_in_adapter_; + std::unique_ptr adapter_builder_; + std::optional parent_proc_; + NameUniquer* global_channel_name_uniquer_; + std::unique_ptr proc_scoped_channel_name_uniquer_; + std::vector adapter_instantiation_args_; +}; + // Get the token operand number for a given channel op. absl::StatusOr TokenOperandNumberForChannelOp(Node* node) { switch (node->op()) { @@ -72,45 +296,119 @@ absl::StatusOr TokenOperandNumberForChannelOp(Node* node) { absl::StrFormat("Expected channel op, got %s.", node->ToString())); } } + +struct ChannelSends { + Channel* channel; + std::vector sends; +}; + +struct ChannelReceives { + Channel* channel; + std::vector receives; +}; + struct MultipleChannelOps { - absl::flat_hash_map> multiple_sends; - absl::flat_hash_map> - multiple_receives; + std::vector sends; + std::vector receives; }; +absl::StatusOr ElaboratePackage(Package* package) { + std::optional top = package->GetTop(); + if (!top.has_value() || !(*top)->IsProc() || + !(*top)->AsProcOrDie()->is_new_style_proc()) { + return ProcElaboration::ElaborateOldStylePackage(package); + } + return ProcElaboration::Elaborate((*top)->AsProcOrDie()); +} + // Find instances of multiple sends/recvs on a channel. -MultipleChannelOps FindMultipleChannelOps(Package* p) { - MultipleChannelOps result; - for (FunctionBase* fb : p->GetFunctionBases()) { - for (Node* node : fb->nodes()) { +absl::StatusOr FindMultipleChannelOps( + const ProcElaboration& elab) { + // Create a map from channel to the set of send/receive nodes on the channel + // and vice-versa. + absl::flat_hash_map> channel_sends; + absl::flat_hash_map> channel_receives; + absl::flat_hash_map> send_channels; + absl::flat_hash_map> receive_channels; + for (ProcInstance* proc_instance : elab.proc_instances()) { + for (Node* node : proc_instance->proc()->nodes()) { if (node->Is()) { Send* send = node->As(); - VLOG(4) << "Found send " << send->ToString(); - result.multiple_sends[send->channel_name()].insert(send); + XLS_ASSIGN_OR_RETURN( + ChannelInstance * channel_instance, + proc_instance->GetChannelInstance(send->channel_name())); + channel_sends[channel_instance->channel].insert(send); + send_channels[send].insert(channel_instance->channel); } if (node->Is()) { - Receive* recv = node->As(); - result.multiple_receives[recv->channel_name()].insert(recv); - VLOG(4) << "Found recv " << recv->ToString(); + Receive* receive = node->As(); + XLS_ASSIGN_OR_RETURN( + ChannelInstance * channel_instance, + proc_instance->GetChannelInstance(receive->channel_name())); + channel_receives[channel_instance->channel].insert(receive); + receive_channels[receive].insert(channel_instance->channel); + } + } + } + + MultipleChannelOps result; + + // Identify channels which have multiple sends/receives. Return an error if + // there is a send/receive which can send on different channels *AND* is also + // part of a set multiple such ops on the same channel. For now in channel + // legalization we require sends/receives to be uniquely mapped to a single + // channel. + for (auto [channel, sends] : channel_sends) { + if (sends.size() <= 1) { + continue; + } + ChannelSends element; + element.channel = channel; + element.sends = std::vector(sends.begin(), sends.end()); + std::sort(element.sends.begin(), element.sends.end(), NodeIdLessThan); + + for (Send* send : element.sends) { + if (send_channels.at(send).size() > 1) { + return absl::UnimplementedError( + absl::StrFormat("Send `%s` can send on different channels and is " + "one of multiple sends on channel `%s`", + send->GetName(), channel->name())); + } + } + result.sends.push_back(std::move(element)); + } + std::sort(result.sends.begin(), result.sends.end(), + [](const ChannelSends& a, const ChannelSends& b) { + return a.channel->id() < b.channel->id(); + }); + + for (auto [channel, receives] : channel_receives) { + if (receives.size() <= 1) { + continue; + } + ChannelReceives element; + element.channel = channel; + element.receives = std::vector(receives.begin(), receives.end()); + std::sort(element.receives.begin(), element.receives.end(), NodeIdLessThan); + + for (Receive* receive : element.receives) { + if (receive_channels.at(receive).size() > 1) { + return absl::UnimplementedError(absl::StrFormat( + "Receive `%s` can receive on different channels and is " + "one of multiple receives on channel `%s`", + receive->GetName(), channel->name())); } } + result.receives.push_back(std::move(element)); } + std::sort(result.receives.begin(), result.receives.end(), + [](const ChannelReceives& a, const ChannelReceives& b) { + return a.channel->id() < b.channel->id(); + }); - // Erase cases where there's only one send or receive. - absl::erase_if( - result.multiple_sends, - [](const std::pair>& elt) { - return elt.second.size() < 2; - }); - absl::erase_if( - result.multiple_receives, - [](const std::pair>& elt) { - return elt.second.size() < 2; - }); - - VLOG(4) << "After erasing single accesses, found " - << result.multiple_sends.size() << " multiple send channels and " - << result.multiple_receives.size() << " multiple receive channels."; + VLOG(4) << "After erasing single accesses, found " << result.receives.size() + << " multiple send channels and " << result.receives.size() + << " multiple receive channels."; return result; } @@ -223,7 +521,7 @@ struct FunctionBaseNameLess { // Note that token params are not included in predecessor lists. template >> absl::StatusOr> GetProjectedTokenDAG( - const absl::flat_hash_set& operations, ChannelStrictness strictness) { + absl::Span operations, ChannelStrictness strictness) { // We return the result_vector, but also build a result_map to track // transitive dependencies. std::vector result_vector; @@ -260,6 +558,8 @@ absl::StatusOr> GetProjectedTokenDAG( for (const T* operation : operations) { fbs.insert(operation->function_base()); } + + absl::flat_hash_set operation_set(operations.begin(), operations.end()); for (FunctionBase* fb : fbs) { XLS_ASSIGN_OR_RETURN(std::vector fbs_result, ComputeTopoSortedTokenDAG(fb)); @@ -279,7 +579,7 @@ absl::StatusOr> GetProjectedTokenDAG( for (Node* predecessor : fb_result.predecessors) { // If a predecessor is not type T, resolve its predecessors of type T. if (!predecessor->Is() || - !operations.contains(predecessor->As())) { + !operation_set.contains(predecessor->As())) { absl::flat_hash_set& resolved = resolved_ops.at(predecessor); resolved_predecessors.insert(resolved.begin(), resolved.end()); continue; @@ -289,7 +589,7 @@ absl::StatusOr> GetProjectedTokenDAG( // If this entry in the DAG is not type T, save its resolved predecessors // for future resolution. if (!fb_result.node->Is() || - !operations.contains(fb_result.node->As())) { + !operation_set.contains(fb_result.node->As())) { resolved_ops[fb_result.node] = std::move(resolved_predecessors); continue; } @@ -299,7 +599,7 @@ absl::StatusOr> GetProjectedTokenDAG( if (strictness == ChannelStrictness::kArbitraryStaticOrder) { if (prev_node.has_value()) { if (!prev_node.value()->Is() || - !operations.contains(prev_node.value()->As())) { + !operation_set.contains(prev_node.value()->As())) { resolved_predecessors.insert(resolved_ops.at(*prev_node).begin(), resolved_ops.at(*prev_node).end()); } else { @@ -345,31 +645,24 @@ absl::Status CheckIsBlocking(Node* n) { // original channel operation. // 3) Updates the token of the original channel operation to come after the new // predicate send. -absl::StatusOr MakePredicateChannel( - Node* operation, NameUniquer& channel_name_uniquer) { +absl::StatusOr MakePredicateChannel( + Node* operation, ChannelRef channel_ref, int64_t instance_number, + AdapterBuilder& ab) { Package* package = operation->package(); - - XLS_ASSIGN_OR_RETURN(Channel * operation_channel, - GetChannelUsedByNode(operation)); - StreamingChannel* channel = down_cast(operation_channel); Proc* proc = operation->function_base()->AsProcOrDie(); XLS_ASSIGN_OR_RETURN( - StreamingChannel * pred_channel, - package->CreateStreamingChannel( - channel_name_uniquer.GetSanitizedUniqueName( - absl::StrCat(channel->name(), "__pred")), - // This is an internal channel, so override to kSendReceive - ChannelOps::kSendReceive, + AdapterInputChannel pred_input_channel, + ab.AddAdapterInputChannel( + absl::StrFormat("%s__pred_%d", ChannelRefName(channel_ref), + instance_number), // This channel is used to forward the predicate to the adapter, which // takes 1 bit. package->GetBitsType(1), - /*initial_values=*/{}, // This is an internal channel that may be inlined during proc // inlining, set FIFO depth to 1. Break cycles by registering push // outputs. // TODO: github/xls#1509 - revisit this if we have better ways of // avoiding cycles in adapters. - /*fifo_config=*/ FifoConfig(/*depth=*/1, /*bypass=*/false, /*register_push_outputs=*/true, /*register_pop_outputs=*/false))); @@ -381,24 +674,27 @@ absl::StatusOr MakePredicateChannel( // If predicate nullopt, that means it's an unconditional send/receive. Make a // true literal to send to the adapter. if (!predicate.has_value()) { - XLS_ASSIGN_OR_RETURN( - predicate, - proc->MakeNodeWithName( - SourceInfo(), Value(UBits(1, /*bit_count=*/1)), - absl::StrFormat("true_predicate_for_chan_%s", channel->name()))); + XLS_ASSIGN_OR_RETURN(predicate, + proc->MakeNodeWithName( + SourceInfo(), Value(UBits(1, /*bit_count=*/1)), + absl::StrFormat("true_predicate_for_chan_%s", + ChannelRefName(channel_ref)))); } XLS_ASSIGN_OR_RETURN(int64_t operand_number, TokenOperandNumberForChannelOp(operation)); - XLS_ASSIGN_OR_RETURN( - Send * send_pred, - proc->MakeNodeWithName( - SourceInfo(), operation->operand(operand_number), predicate.value(), - std::nullopt, pred_channel->name(), - absl::StrFormat("send_predicate_for_chan_%s", channel->name()))); + XLS_ASSIGN_OR_RETURN(Send * send_pred, + proc->MakeNodeWithName( + SourceInfo(), operation->operand(operand_number), + predicate.value(), std::nullopt, + ChannelRefName(AsChannelRef( + pred_input_channel.parent_send_channel_ref)), + absl::StrFormat("send_predicate_for_chan_%s", + ChannelRefName(channel_ref)))); // Replace the op's original input token with the predicate send's token. XLS_RETURN_IF_ERROR(operation->ReplaceOperandNumber( operand_number, send_pred, /*type_must_match=*/true)); - return pred_channel; + + return pred_input_channel; } // Add a new channel to communicate a channel operation's completion to the @@ -410,26 +706,20 @@ absl::StatusOr MakePredicateChannel( // 2) Adds a receive of the completion on this channel at the token level of the // original channel operation. // 3) Updates the token of the original channel operation to come after the new -// completion receive. -absl::StatusOr MakeCompletionChannel( - Node* operation, NameUniquer& channel_name_uniquer) { - Package* package = operation->package(); - - XLS_ASSIGN_OR_RETURN(Channel * operation_channel, - GetChannelUsedByNode(operation)); - StreamingChannel* channel = down_cast(operation_channel); +// completion_42` receive. +absl::StatusOr MakeCompletionChannel( + Node* operation, int64_t instance_number, AdapterBuilder& ab) { + Package* package = ab.package(); Proc* proc = operation->function_base()->AsProcOrDie(); + XLS_ASSIGN_OR_RETURN( - StreamingChannel * completion_channel, - package->CreateStreamingChannel( - channel_name_uniquer.GetSanitizedUniqueName( - absl::StrCat(channel->name(), "__completion")), - // This is an internal channel, so override to kSendReceive - ChannelOps::kSendReceive, + AdapterOutputChannel completion_channel, + ab.AddAdapterOutputChannel( + absl::StrFormat("%s__completion_%i", ab.adapted_channel_name(), + instance_number), // This channel is used to mark the completion of the requested // operation and doesn't carry any data. Use an empty tuple type. package->GetTupleType({}), - /*initial_values=*/{}, // This is an internal channel that may be inlined during proc // inlining, set FIFO depth to 0. Completion channels seem to not // cause cycles when they have timing paths between push and pop. @@ -448,10 +738,10 @@ absl::StatusOr MakeCompletionChannel( // true literal to send to the adapter. if (!predicate.has_value()) { XLS_ASSIGN_OR_RETURN( - predicate, - proc->MakeNodeWithName( - SourceInfo(), Value(UBits(1, /*bit_count=*/1)), - absl::StrFormat("true_predicate_for_chan_%s", channel->name()))); + predicate, proc->MakeNodeWithName( + SourceInfo(), Value(UBits(1, /*bit_count=*/1)), + absl::StrFormat("true_predicate_for_chan_%s", + completion_channel.channel->name()))); } XLS_ASSIGN_OR_RETURN(Node * free_token, proc->MakeNode(SourceInfo(), Value::Token())); @@ -459,13 +749,15 @@ absl::StatusOr MakeCompletionChannel( Receive * recv_completion, proc->MakeNodeWithName( SourceInfo(), free_token, predicate.value(), - completion_channel->name(), /*is_blocking=*/true, - absl::StrFormat("recv_completion_for_chan_%s", channel->name()))); - XLS_ASSIGN_OR_RETURN(Node * recv_completion_token, - proc->MakeNodeWithName( - SourceInfo(), recv_completion, 0, - absl::StrFormat("recv_completion_token_for_chan_%s", - channel->name()))); + completion_channel.channel->name(), /*is_blocking=*/true, + absl::StrFormat("recv_completion_for_chan_%s", + completion_channel.channel->name()))); + XLS_ASSIGN_OR_RETURN( + Node * recv_completion_token, + proc->MakeNodeWithName( + SourceInfo(), recv_completion, 0, + absl::StrFormat("recv_completion_token_for_chan_%s", + completion_channel.channel->name()))); // Replace usages of the token from the send/recv operation with the token // from the completion recv. switch (operation->op()) { @@ -499,11 +791,6 @@ absl::StatusOr MakeCompletionChannel( return completion_channel; } -struct ClonedChannelWithPredicate { - StreamingChannel* cloned_channel; - StreamingChannel* predicate_channel; -}; - // The adapter orders multiple channel operations with state that tracks if // there is an outstanding channel operation waiting to complete. struct ActivationState { @@ -742,35 +1029,40 @@ void MakeDebugTrace(BValue condition, // Makes activation network and adds asserts. absl::StatusOr MakeActivationNetwork( - ProcBuilder& pb, absl::Span token_dag, - ChannelStrictness strictness, NameUniquer& channel_name_uniquer) { + AdapterBuilder& ab, absl::Span token_dag) { ActivationNetwork activations; // First, make new predicate channels. The adapter will non-blocking receive // on each of these channels to get each operation's predicate. - absl::flat_hash_map pred_channels; + absl::flat_hash_map pred_channels; pred_channels.reserve(token_dag.size()); - for (const auto& [node, _] : token_dag) { - XLS_ASSIGN_OR_RETURN(StreamingChannel * pred_channel, - MakePredicateChannel(node, channel_name_uniquer)); - pred_channels.insert({node, pred_channel}); + + for (int64_t instance_number = 0; instance_number < token_dag.size(); + ++instance_number) { + Node* node = token_dag[instance_number].node; + XLS_ASSIGN_OR_RETURN(ChannelRef channel_ref, GetChannelRefUsedByNode(node)); + XLS_ASSIGN_OR_RETURN( + AdapterInputChannel pred_input_channel, + MakePredicateChannel(node, channel_ref, instance_number, ab)); + pred_channels.insert({node, pred_input_channel}); } // Now make state. We need to know the state in order to know the predicate on // the predicate receive. int64_t state_idx = 0; - for (const auto& [node, _] : token_dag) { - int64_t pred_channel_id = pred_channels.at(node)->id(); + for (int64_t instance_number = 0; instance_number < token_dag.size(); + ++instance_number) { + Node* node = token_dag[instance_number].node; ActivationNode activation; - activation.predicate_state.state = - pb.StateElement(absl::StrFormat("pred_%d", pred_channel_id), - Value(UBits(0, /*bit_count=*/1))); - activation.valid_state.state = - pb.StateElement(absl::StrFormat("pred_%d_valid", pred_channel_id), - Value(UBits(0, /*bit_count=*/1))); - activation.done_state.state = - pb.StateElement(absl::StrFormat("pred_%d_done", pred_channel_id), - Value(UBits(0, /*bit_count=*/1))); + activation.predicate_state.state = ab.adapter_builder().StateElement( + absl::StrFormat("pred_%d", instance_number), + Value(UBits(0, /*bit_count=*/1))); + activation.valid_state.state = ab.adapter_builder().StateElement( + absl::StrFormat("pred_%d_valid", instance_number), + Value(UBits(0, /*bit_count=*/1))); + activation.done_state.state = ab.adapter_builder().StateElement( + absl::StrFormat("pred_%d_done", instance_number), + Value(UBits(0, /*bit_count=*/1))); activation.predicate_state.state_idx = state_idx++; activation.valid_state.state_idx = state_idx++; @@ -780,8 +1072,13 @@ absl::StatusOr MakeActivationNetwork( // Make a non-blocking receive on the predicate channel for each operation. // Compute predicate values and activations for this tick. - for (const auto& [node, predecessors] : token_dag) { - StreamingChannel* pred_channel = pred_channels.at(node); + for (int64_t instance_number = 0; instance_number < token_dag.size(); + ++instance_number) { + Node* node = token_dag[instance_number].node; + const absl::flat_hash_set& predecessors = + token_dag[instance_number].predecessors; + + const AdapterInputChannel& pred_input_channel = pred_channels.at(node); ActivationNode& activation = activations.at(node); std::vector sorted_predecessors(predecessors.begin(), @@ -793,8 +1090,8 @@ absl::StatusOr MakeActivationNetwork( BValue predecessors_pred_recv_token = PredRecvTokensForActivations( activations, sorted_predecessors, absl::StrFormat("chan_%d_recv_pred_predeccesors_token", - pred_channel->id()), - pb); + instance_number), + ab.adapter_builder()); // Do a non-blocking receive on the predicate channel. // @@ -804,37 +1101,42 @@ absl::StatusOr MakeActivationNetwork( // successfully received and 'valid' indicates that it has been successfully // received. After the adapter completes all operations, 'valid' will be // reset to 0 and everything starts over again. - BValue do_pred_recv = pb.Not(activation.valid_state.state); - BValue recv = pb.ReceiveIfNonBlocking( - pred_channel, predecessors_pred_recv_token, do_pred_recv, SourceInfo(), - absl::StrFormat("recv_pred_%d", pred_channel->id())); - activation.pred_recv_token = pb.TupleIndex( + BValue do_pred_recv = + ab.adapter_builder().Not(activation.valid_state.state); + BValue recv = ab.adapter_builder().ReceiveIfNonBlocking( + pred_input_channel.adapter_receive_channel_ref, + predecessors_pred_recv_token, do_pred_recv, SourceInfo(), + absl::StrFormat("recv_pred_%d", instance_number)); + activation.pred_recv_token = ab.adapter_builder().TupleIndex( recv, 0, SourceInfo(), - absl::StrFormat("recv_pred_%d_token", pred_channel->id())); - BValue pred_recv_predicate = - pb.TupleIndex(recv, 1, SourceInfo(), - absl::StrFormat("recv_pred_%d_data", pred_channel->id())); - BValue pred_recv_valid = pb.TupleIndex( + absl::StrFormat("recv_pred_%d_token", instance_number)); + BValue pred_recv_predicate = ab.adapter_builder().TupleIndex( + recv, 1, SourceInfo(), + absl::StrFormat("recv_pred_%d_data", instance_number)); + BValue pred_recv_valid = ab.adapter_builder().TupleIndex( recv, 2, SourceInfo(), - absl::StrFormat("recv_pred_%d_valid", pred_channel->id())); - activation.predicate = pb.Or( + absl::StrFormat("recv_pred_%d_valid", instance_number)); + activation.predicate = ab.adapter_builder().Or( pred_recv_predicate, activation.predicate_state.state, SourceInfo(), - absl::StrFormat("pred_%d_updated", pred_channel->id())); - activation.valid = - pb.Or(pred_recv_valid, activation.valid_state.state, SourceInfo(), - absl::StrFormat("pred_%d_valid_updated", pred_channel->id())); - - BValue all_predecessors_done = AllPredecessorsDone( - activations, sorted_predecessors, - absl::StrFormat("%v_all_predecessors_done", *node), pb); - activation.activate = - pb.And({activation.predicate, activation.valid, - pb.Not(activation.done_state.state), all_predecessors_done}, - SourceInfo(), absl::StrFormat("%v_activate", *node)); - activation.done = - pb.Or({activation.activate, - pb.And(pb.Not(activation.predicate), activation.valid), - activation.done_state.state}); + absl::StrFormat("pred_%d_updated", instance_number)); + activation.valid = ab.adapter_builder().Or( + pred_recv_valid, activation.valid_state.state, SourceInfo(), + absl::StrFormat("pred_%d_valid_updated", instance_number)); + + BValue all_predecessors_done = + AllPredecessorsDone(activations, sorted_predecessors, + absl::StrFormat("%v_all_predecessors_done", *node), + ab.adapter_builder()); + activation.activate = ab.adapter_builder().And( + {activation.predicate, activation.valid, + ab.adapter_builder().Not(activation.done_state.state), + all_predecessors_done}, + SourceInfo(), absl::StrFormat("%v_activate", *node)); + activation.done = ab.adapter_builder().Or( + {activation.activate, + ab.adapter_builder().And( + ab.adapter_builder().Not(activation.predicate), activation.valid), + activation.done_state.state}); } // Make assertions that operations w/ no token ordering are mutually @@ -844,42 +1146,36 @@ absl::StatusOr MakeActivationNetwork( // are empty, so every operation should be mutually exclusive. In the // arbitrary_static_order case, nodes are linearized and have an added token // relationship with every other operation, and there should be no assertions. - MakeMutualExclusionAssertions(token_dag, activations, pb); + MakeMutualExclusionAssertions(token_dag, activations, ab.adapter_builder()); // Compute next-state signals. // Builds not_all_done signal which indicates if every operation is done. If // so, start the next set of operations by setting next signals for done and // valid to 0. - BValue not_all_done = - NextStateAndReturnNotAllDone(token_dag, activations, pb); + BValue not_all_done = NextStateAndReturnNotAllDone(token_dag, activations, + ab.adapter_builder()); // Make a trace to aid debugging. It only fires when all ops are done. - MakeDebugTrace(pb.Not(not_all_done), token_dag, activations, pb); + MakeDebugTrace(ab.adapter_builder().Not(not_all_done), token_dag, activations, + ab.adapter_builder()); return std::move(activations); } -absl::Status AddAdapterForMultipleReceives( - Package* p, StreamingChannel* channel, - const absl::flat_hash_set& ops, - NameUniquer& channel_name_uniquer) { +absl::Status AddAdapterForMultipleReceives(absl::Span ops, + AdapterBuilder& ab) { XLS_RET_CHECK_GT(ops.size(), 1); - XLS_ASSIGN_OR_RETURN(auto token_dags, - GetProjectedTokenDAG(ops, channel->GetStrictness())); - - std::string adapter_name = - absl::StrFormat("chan_%s_io_receive_adapter", channel->name()); + XLS_ASSIGN_OR_RETURN( + auto token_dags, + GetProjectedTokenDAG(ops, ab.adapted_channel_strictness())); - VLOG(4) << absl::StreamFormat("Channel %s has token dag %s.", channel->name(), + VLOG(4) << absl::StreamFormat("Channel %s has token dag %s.", + ab.adapted_channel_name(), absl::StrJoin(token_dags, ", ")); - ProcBuilder pb(adapter_name, p); - - XLS_ASSIGN_OR_RETURN( - ActivationNetwork activations, - MakeActivationNetwork(pb, token_dags, channel->GetStrictness(), - channel_name_uniquer)); + XLS_ASSIGN_OR_RETURN(ActivationNetwork activations, + MakeActivationNetwork(ab, token_dags)); BValue any_active; BValue external_recv_input_token; @@ -891,64 +1187,63 @@ absl::Status AddAdapterForMultipleReceives( all_activations.push_back(activations.at(node).activate); all_tokens.push_back(activations.at(node).pred_recv_token); } - any_active = pb.Or(all_activations, SourceInfo(), "any_active"); - external_recv_input_token = pb.AfterAll(all_tokens); + any_active = + ab.adapter_builder().Or(all_activations, SourceInfo(), "any_active"); + external_recv_input_token = ab.adapter_builder().AfterAll(all_tokens); } - BValue recv = pb.ReceiveIf(channel, external_recv_input_token, any_active, - SourceInfo(), "external_receive"); - BValue recv_token = - pb.TupleIndex(recv, 0, SourceInfo(), "external_receive_token"); - BValue recv_data = - pb.TupleIndex(recv, 1, SourceInfo(), "external_receive_data"); - - for (const auto& [node, _] : token_dags) { - const ActivationNode& activation = activations.at(node); + BValue recv = ab.adapter_builder().ReceiveIf( + AsReceiveChannelRefOrDie(ab.adapted_channel_ref_in_adapter()), + external_recv_input_token, any_active, SourceInfo(), "external_receive"); + BValue recv_token = ab.adapter_builder().TupleIndex(recv, 0, SourceInfo(), + "external_receive_token"); + BValue recv_data = ab.adapter_builder().TupleIndex(recv, 1, SourceInfo(), + "external_receive_data"); + + for (int64_t i = 0; i < token_dags.size(); ++i) { + Node* node = token_dags[i].node; + const ActivationNode& activation = activations.at(node); XLS_ASSIGN_OR_RETURN( - Channel * new_data_channel, - p->CloneChannel( - channel, - channel_name_uniquer.GetSanitizedUniqueName(channel->name()), - Package::CloneChannelOverrides() - .OverrideSupportedOps(ChannelOps::kSendReceive) - .OverrideFifoConfig( - // This is an internal channel that may be inlined during - // proc inlining, set FIFO depth to 1. Break cycles by - // registering push outputs. - // TODO: github/xls#1509 - revisit this if we have better - // ways of avoiding cycles in adapters. - FifoConfig(/*depth=*/1, /*bypass=*/false, - /*register_push_outputs=*/true, - /*register_pop_outputs=*/false)))); - XLS_RETURN_IF_ERROR( - ReplaceChannelUsedByNode(node, new_data_channel->name())); - BValue send_token = pb.AfterAll({activation.pred_recv_token, recv_token}); - pb.SendIf(new_data_channel, send_token, activation.activate, recv_data); + AdapterOutputChannel output_data_channel, + ab.AddAdapterOutputChannel( + absl::StrFormat("%s__data_%d", ab.adapted_channel_name(), i), + ab.adapted_channel_type(), + // This is an internal channel that may be inlined during proc + // inlining, set FIFO depth to 1. Break cycles by registering push + // outputs. + // TODO: github/xls#1509 - revisit this if we have better ways of + // avoiding cycles in adapters. + FifoConfig(/*depth=*/1, /*bypass=*/false, + /*register_push_outputs=*/true, + /*register_pop_outputs=*/false))); + XLS_RETURN_IF_ERROR(ReplaceChannelUsedByNode( + node, ChannelRefName(AsChannelRef( + output_data_channel.parent_receive_channel_ref)))); + BValue send_token = + ab.adapter_builder().AfterAll({activation.pred_recv_token, recv_token}); + ab.adapter_builder().SendIf(output_data_channel.adapter_send_channel_ref, + send_token, activation.activate, recv_data); } - return pb.Build(NextState(activations)).status(); + return ab.Build(NextState(activations)).status(); } -absl::Status AddAdapterForMultipleSends(Package* p, StreamingChannel* channel, - const absl::flat_hash_set& ops, - NameUniquer& channel_name_uniquer) { +absl::Status AddAdapterForMultipleSends(absl::Span ops, + AdapterBuilder& ab) { XLS_RET_CHECK_GT(ops.size(), 1); - XLS_ASSIGN_OR_RETURN(auto token_dags, - GetProjectedTokenDAG(ops, channel->GetStrictness())); - - std::string adapter_name = - absl::StrFormat("chan_%s_io_send_adapter", channel->name()); + XLS_ASSIGN_OR_RETURN( + auto token_dags, + GetProjectedTokenDAG(ops, ab.adapted_channel_strictness())); - VLOG(4) << absl::StreamFormat("Channel %s has token dag %s.", channel->name(), + VLOG(4) << absl::StreamFormat("Channel %s has token dag %s.", + ab.adapted_channel_name(), absl::StrJoin(token_dags, ", ")); - ProcBuilder pb(adapter_name, p); + XLS_ASSIGN_OR_RETURN(ActivationNetwork activations, + MakeActivationNetwork(ab, token_dags)); - XLS_ASSIGN_OR_RETURN( - ActivationNetwork activations, - MakeActivationNetwork(pb, token_dags, channel->GetStrictness(), - channel_name_uniquer)); + absl::flat_hash_map data_channels; BValue recv_after_all; BValue recv_data; @@ -961,52 +1256,60 @@ absl::Status AddAdapterForMultipleSends(Package* p, StreamingChannel* channel, std::vector recv_data_valids; recv_data_valids.reserve(token_dags.size()); - for (const auto& [node, _] : token_dags) { + for (int64_t i = 0; i < token_dags.size(); ++i) { + Node* node = token_dags[i].node; const ActivationNode& activation = activations.at(node); XLS_ASSIGN_OR_RETURN( - Channel * new_data_channel, - p->CloneChannel( - channel, - channel_name_uniquer.GetSanitizedUniqueName(channel->name()), - Package::CloneChannelOverrides() - .OverrideSupportedOps(ChannelOps::kSendReceive) - .OverrideFifoConfig( - // This is an internal channel that may be inlined during - // proc inlining, set FIFO depth to 1. Break cycles by - // registering push outputs. - // TODO: github/xls#1509 - revisit this if we have better - // ways of avoiding cycles in adapters. - FifoConfig(/*depth=*/1, /*bypass=*/false, - /*register_push_outputs=*/true, - /*register_pop_outputs=*/false)))); - XLS_RETURN_IF_ERROR( - ReplaceChannelUsedByNode(node, new_data_channel->name())); - BValue recv = pb.ReceiveIf(new_data_channel, activation.pred_recv_token, - activation.activate); - recv_tokens.push_back(pb.TupleIndex(recv, 0)); - recv_datas.push_back(pb.TupleIndex(recv, 1)); + AdapterInputChannel input_data_channel, + ab.AddAdapterInputChannel( + absl::StrFormat("%s__data_%d", ab.adapted_channel_name(), i), + ab.adapted_channel_type(), + // This is an internal channel that may be inlined during proc + // inlining, set FIFO depth to 1. Break cycles by registering push + // outputs. + // TODO: github/xls#1509 - revisit this if we have better ways of + // avoiding cycles in adapters. + FifoConfig(/*depth=*/1, /*bypass=*/false, + /*register_push_outputs=*/true, + /*register_pop_outputs=*/false))); + XLS_RETURN_IF_ERROR(ReplaceChannelUsedByNode( + node, ChannelRefName( + AsChannelRef(input_data_channel.parent_send_channel_ref)))); + BValue recv = ab.adapter_builder().ReceiveIf( + input_data_channel.adapter_receive_channel_ref, + activation.pred_recv_token, activation.activate); + recv_tokens.push_back(ab.adapter_builder().TupleIndex(recv, 0)); + recv_datas.push_back(ab.adapter_builder().TupleIndex(recv, 1)); recv_data_valids.push_back(activation.activate); + + data_channels[node] = input_data_channel; } - recv_after_all = pb.AfterAll(recv_tokens); + recv_after_all = ab.adapter_builder().AfterAll(recv_tokens); // Reverse for one hot select order. std::reverse(recv_data_valids.begin(), recv_data_valids.end()); - recv_data = pb.OneHotSelect(pb.Concat(recv_data_valids), recv_datas); - recv_data_valid = pb.Or(recv_data_valids); + recv_data = ab.adapter_builder().OneHotSelect( + ab.adapter_builder().Concat(recv_data_valids), recv_datas); + recv_data_valid = ab.adapter_builder().Or(recv_data_valids); } - BValue send_token = pb.SendIf(channel, recv_after_all, recv_data_valid, - recv_data, SourceInfo(), "external_send"); - BValue empty_tuple_literal = pb.Literal(Value::Tuple({})); - - for (const auto& [node, _] : token_dags) { - XLS_ASSIGN_OR_RETURN(StreamingChannel * completion_channel, - MakeCompletionChannel(node, channel_name_uniquer)); - pb.SendIf(completion_channel, send_token, activations.at(node).activate, - empty_tuple_literal); + BValue send_token = ab.adapter_builder().SendIf( + AsSendChannelRefOrDie(ab.adapted_channel_ref_in_adapter()), + recv_after_all, recv_data_valid, recv_data, SourceInfo(), + "external_send"); + BValue empty_tuple_literal = ab.adapter_builder().Literal(Value::Tuple({})); + + for (int64_t i = 0; i < token_dags.size(); ++i) { + Node* node = token_dags[i].node; + XLS_ASSIGN_OR_RETURN(AdapterOutputChannel completion_channel, + MakeCompletionChannel(node, i, ab)); + ab.adapter_builder().SendIf(completion_channel.adapter_send_channel_ref, + send_token, activations.at(node).activate, + empty_tuple_literal); } - return pb.Build(NextState(activations)).status(); + return ab.Build(NextState(activations)).status(); } + } // namespace absl::StatusOr ChannelLegalizationPass::RunInternal( @@ -1014,19 +1317,22 @@ absl::StatusOr ChannelLegalizationPass::RunInternal( PassResults* results) const { VLOG(3) << "Running channel legalization pass."; bool changed = false; - MultipleChannelOps multiple_ops = FindMultipleChannelOps(p); + XLS_ASSIGN_OR_RETURN(ProcElaboration elab, ElaboratePackage(p)); + XLS_ASSIGN_OR_RETURN(MultipleChannelOps multiple_ops, + FindMultipleChannelOps(elab)); - if (multiple_ops.multiple_receives.empty() && - multiple_ops.multiple_sends.empty()) { + if (multiple_ops.receives.empty() && multiple_ops.sends.empty()) { return false; } - NameUniquer channel_name_uniquer("__"); - for (Channel* channel : p->channels()) { - channel_name_uniquer.GetSanitizedUniqueName(channel->name()); + NameUniquer global_channel_name_uniquer("__"); + if (!p->ChannelsAreProcScoped()) { + for (Channel* channel : p->channels()) { + global_channel_name_uniquer.GetSanitizedUniqueName(channel->name()); + } } - for (const auto& [channel_name, ops] : multiple_ops.multiple_receives) { + for (const auto& [channel, ops] : multiple_ops.receives) { for (Receive* recv : ops) { if (!recv->is_blocking()) { return absl::InvalidArgumentError(absl::StrFormat( @@ -1035,9 +1341,10 @@ absl::StatusOr ChannelLegalizationPass::RunInternal( recv->GetName())); } } - XLS_ASSIGN_OR_RETURN(Channel * channel, p->GetChannel(channel_name)); if (channel->kind() != ChannelKind::kStreaming) { // Don't make adapters for non-streaming channels. + VLOG(4) << absl::StreamFormat( + "Multiple receives on non-streaming channel `%s`", channel->name()); continue; } StreamingChannel* streaming_channel = down_cast(channel); @@ -1045,19 +1352,50 @@ absl::StatusOr ChannelLegalizationPass::RunInternal( ChannelStrictness::kProvenMutuallyExclusive) { // Don't make adapters for channels that must be proven to be mutually // exclusive- they will be handled during scheduling. + VLOG(3) << absl::StreamFormat( + "Multiple receives on proven mutually exclusive channel `%s`", + channel->name()); continue; } + + std::unique_ptr adapter_builder; + std::string adapter_name = + absl::StrFormat("chan_%s_io_receive_adapter", channel->name()); + if (p->ChannelsAreProcScoped()) { + // For proc-scoped channels all receives are in the same proc. + Proc* proc = ops.front()->function_base()->AsProcOrDie(); + for (Receive* op : ops) { + if (op->function_base()->AsProcOrDie() != proc) { + return absl::UnimplementedError(absl::StrFormat( + "Channel `%s` has multiple receives in different procs: `%s` and " + "`%s`", + channel->name(), proc->name(), op->function_base()->name())); + } + } + XLS_ASSIGN_OR_RETURN(ReceiveChannelReference * channel_ref, + proc->GetReceiveChannelReference(channel->name())); + XLS_ASSIGN_OR_RETURN( + adapter_builder, + AdapterBuilder::CreateForProcScopedChannels(channel_ref, adapter_name, + /*parent_proc=*/proc)); + } else { + XLS_ASSIGN_OR_RETURN( + adapter_builder, + AdapterBuilder::CreateForGlobalChannels(channel, adapter_name, p, + global_channel_name_uniquer)); + } VLOG(3) << absl::StreamFormat( "Making receive channel adapter for channel `%s`, has receives (%s).", - channel_name, absl::StrJoin(ops, ", ")); - XLS_RETURN_IF_ERROR(AddAdapterForMultipleReceives(p, streaming_channel, ops, - channel_name_uniquer)); + channel->name(), absl::StrJoin(ops, ", ")); + XLS_RETURN_IF_ERROR(AddAdapterForMultipleReceives(ops, *adapter_builder)); changed = true; } - for (const auto& [channel_name, ops] : multiple_ops.multiple_sends) { - XLS_ASSIGN_OR_RETURN(Channel * channel, p->GetChannel(channel_name)); + + for (const auto& [channel, ops] : multiple_ops.sends) { if (channel->kind() != ChannelKind::kStreaming) { // Don't make adapters for non-streaming channels. + VLOG(4) << absl::StreamFormat( + "Multiple receives on non-streaming channel `%s`", channel->name()); continue; } StreamingChannel* streaming_channel = down_cast(channel); @@ -1065,15 +1403,46 @@ absl::StatusOr ChannelLegalizationPass::RunInternal( ChannelStrictness::kProvenMutuallyExclusive) { // Don't make adapters for channels that must be proven to be mutually // exclusive- they will be handled during scheduling. + VLOG(4) << absl::StreamFormat( + "Multiple sends on proven mutually exclusive channel `%s`", + channel->name()); continue; } - VLOG(3) << absl::StreamFormat( + + std::unique_ptr adapter_builder; + std::string adapter_name = + absl::StrFormat("chan_%s_io_send_adapter", channel->name()); + if (p->ChannelsAreProcScoped()) { + // For proc-scoped channels all sends are in the same proc. + Proc* proc = ops.front()->function_base()->AsProcOrDie(); + for (Send* op : ops) { + if (op->function_base()->AsProcOrDie() != proc) { + return absl::UnimplementedError(absl::StrFormat( + "Channel `%s` has multiple sends in different procs: `%s` and " + "`%s`", + channel->name(), proc->name(), op->function_base()->name())); + } + } + XLS_ASSIGN_OR_RETURN(SendChannelReference * channel_ref, + proc->GetSendChannelReference(channel->name())); + XLS_ASSIGN_OR_RETURN( + adapter_builder, + AdapterBuilder::CreateForProcScopedChannels(channel_ref, adapter_name, + /*parent_proc=*/proc)); + } else { + XLS_ASSIGN_OR_RETURN( + adapter_builder, + AdapterBuilder::CreateForGlobalChannels(channel, adapter_name, p, + global_channel_name_uniquer)); + } + + VLOG(4) << absl::StreamFormat( "Making send channel adapter for channel `%s`, has sends (%s).", - channel_name, absl::StrJoin(ops, ", ")); - XLS_RETURN_IF_ERROR(AddAdapterForMultipleSends(p, streaming_channel, ops, - channel_name_uniquer)); + channel->name(), absl::StrJoin(ops, ", ")); + XLS_RETURN_IF_ERROR(AddAdapterForMultipleSends(ops, *adapter_builder)); changed = true; } + return changed; } diff --git a/xls/passes/channel_legalization_pass_test.cc b/xls/passes/channel_legalization_pass_test.cc index acc2490f85..69eee799bf 100644 --- a/xls/passes/channel_legalization_pass_test.cc +++ b/xls/passes/channel_legalization_pass_test.cc @@ -43,6 +43,7 @@ #include "xls/ir/ir_parser.h" #include "xls/ir/ir_test_base.h" #include "xls/ir/package.h" +#include "xls/ir/proc_elaboration.h" #include "xls/ir/value.h" #include "xls/ir/verifier.h" #include "xls/passes/optimization_pass.h" @@ -68,6 +69,20 @@ using ::testing::UnorderedElementsAre; using ::testing::Values; using ::testing::ValuesIn; +// Create an interpreter runtime for evaluating procs. Automatically handles new +// and old style procs. +absl::StatusOr> CreateRuntime( + Package* package) { + if (!package->HasTop()) { + return CreateInterpreterSerialProcRuntime(package); + } + XLS_ASSIGN_OR_RETURN(Proc * top, package->GetTopAsProc()); + if (top->is_new_style_proc()) { + return CreateInterpreterSerialProcRuntime(top); + } + return CreateInterpreterSerialProcRuntime(package); +} + struct TestParam { using evaluation_function = std::function)>; @@ -89,7 +104,10 @@ class ChannelLegalizationPassTest protected: absl::StatusOr Run(Package* package) { PassResults results; - return ChannelLegalizationPass().Run(package, {}, &results); + XLS_ASSIGN_OR_RETURN(bool changed, + ChannelLegalizationPass().Run(package, {}, &results)); + XLS_RETURN_IF_ERROR(VerifyPackage(package)); + return changed; } }; @@ -857,6 +875,88 @@ top proc test_proc(state:(), init={()}) { EXPECT_THAT(outq->Read(), Optional(Value(UBits(1, /*bit_count=*/32)))); + return absl::OkStatus(); + }, + }, + TestParam{ + .test_name = "SingleNewStyleProc", + .ir_text = R"(package test + +top proc my_proc< + in: bits[32] in kind=streaming strictness=$0, + out: bits[32] out kind=streaming strictness=$0>() { + tok: token = literal(value=token) + recv0: (token, bits[32]) = receive(tok, channel=in) + recv0_tok: token = tuple_index(recv0, index=0) + recv0_data: bits[32] = tuple_index(recv0, index=1) + recv1: (token, bits[32]) = receive(recv0_tok, channel=in) + recv1_tok: token = tuple_index(recv1, index=0) + recv1_data: bits[32] = tuple_index(recv1, index=1) + send0: token = send(recv1_tok, recv1_data, channel=out) + send1: token = send(send0, recv0_data, channel=out) +} + )", + .builder_matcher = + { + // For mutually exclusive channels, channel legalization does + // not change the IR, but other passes do. Just check OK. + {ChannelStrictness::kProvenMutuallyExclusive, IsOk()}, + {ChannelStrictness::kTotalOrder, IsOkAndHolds(true)}, + {ChannelStrictness::kRuntimeOrdered, IsOkAndHolds(true)}, + // Build should be OK, but will fail at runtime. + {ChannelStrictness::kRuntimeMutuallyExclusive, + IsOkAndHolds(true)}, + {ChannelStrictness::kArbitraryStaticOrder, IsOkAndHolds(true)}, + }, + .evaluate = + [](SerialProcRuntime* interpreter, + std::optional strictness) -> absl::Status { + constexpr int64_t kMaxTicks = 1000; + constexpr int64_t kNumInputs = 32; + + const ProcElaboration& elab = + interpreter->queue_manager().elaboration(); + + XLS_ASSIGN_OR_RETURN(ChannelInstance * in_instance, + elab.GetChannelInstance("in", "my_proc")); + XLS_ASSIGN_OR_RETURN(ChannelInstance * out_instance, + elab.GetChannelInstance("out", "my_proc")); + ChannelQueue& inq = + interpreter->queue_manager().GetQueue(in_instance); + ChannelQueue& outq = + interpreter->queue_manager().GetQueue(out_instance); + + for (int64_t i = 0; i < kNumInputs; ++i) { + XLS_RETURN_IF_ERROR(inq.Write(Value(UBits(i, /*bit_count=*/32)))); + } + absl::flat_hash_map output_count{ + {outq.channel_instance(), kNumInputs}}; + absl::Status interpreter_status = + interpreter->TickUntilOutput(output_count, kMaxTicks).status(); + if (strictness.has_value() && + strictness.value() == + ChannelStrictness::kRuntimeMutuallyExclusive) { + EXPECT_THAT( + interpreter_status, + StatusIs(absl::StatusCode::kAborted, + HasSubstr("predicate was not mutually exclusive"))); + // Return early, we have no output to check. + return absl::OkStatus(); + } + XLS_EXPECT_OK(interpreter_status); + for (int64_t i = 0; i < kNumInputs; ++i) { + EXPECT_FALSE(outq.IsEmpty()); + int64_t flip_evens_and_odds = i; + if (i % 2 == 0) { + flip_evens_and_odds++; + } else { + flip_evens_and_odds--; + } + EXPECT_THAT(outq.Read(), + Optional(Eq(Value( + UBits(flip_evens_and_odds, /*bit_count=*/32))))); + } + return absl::OkStatus(); }, }, @@ -890,19 +990,19 @@ TEST_P(ChannelLegalizationPassTest, EvaluatesCorrectly) { std::get<0>(GetParam()).ir_text, ChannelStrictnessToString(strictness)))); XLS_ASSERT_OK_AND_ASSIGN(std::unique_ptr interpreter, - CreateInterpreterSerialProcRuntime(p.get())); + CreateRuntime(p.get())); // Don't pass in strictness because the pass hasn't been run yet. XLS_EXPECT_OK(std::get<0>(GetParam()) .evaluate(interpreter.get(), /*strictness=*/std::nullopt)); absl::StatusOr run_status = Run(p.get()); + if (!run_status.ok()) { GTEST_SKIP(); } - XLS_ASSERT_OK_AND_ASSIGN(interpreter, - CreateInterpreterSerialProcRuntime(p.get())); + XLS_ASSERT_OK_AND_ASSIGN(interpreter, CreateRuntime(p.get())); XLS_EXPECT_OK(std::get<0>(GetParam()) .evaluate(interpreter.get(), std::get<1>(GetParam()))); } @@ -928,31 +1028,31 @@ TEST_F(ChannelLegalizationPassTest, NamesAreUniquified) { // makes new channels. XLS_ASSERT_OK_AND_ASSIGN( StreamingChannel * c_1, - p.CreateStreamingChannel("c__1", ChannelOps::kSendOnly, + p.CreateStreamingChannel("c__data_1", ChannelOps::kSendOnly, p.GetBitsType(1))); XLS_ASSERT_OK_AND_ASSIGN( StreamingChannel * c_pred, - p.CreateStreamingChannel("c__pred", ChannelOps::kSendOnly, + p.CreateStreamingChannel("c__pred_0", ChannelOps::kSendOnly, p.GetBitsType(1))); XLS_ASSERT_OK_AND_ASSIGN( StreamingChannel * c_2_completion, - p.CreateStreamingChannel("c__2__completion", ChannelOps::kSendOnly, + p.CreateStreamingChannel("c__completion_1", ChannelOps::kSendOnly, p.GetBitsType(1))); XLS_ASSERT_OK_AND_ASSIGN( StreamingChannel * c_3_completion, - p.CreateStreamingChannel("c__3__completion", ChannelOps::kSendOnly, + p.CreateStreamingChannel("c__completion_2", ChannelOps::kSendOnly, p.GetBitsType(1))); XLS_ASSERT_OK_AND_ASSIGN( StreamingChannel * d_1, - p.CreateStreamingChannel("d__1", ChannelOps::kSendOnly, + p.CreateStreamingChannel("d__data_1", ChannelOps::kSendOnly, p.GetBitsType(1))); XLS_ASSERT_OK_AND_ASSIGN( StreamingChannel * d_pred, - p.CreateStreamingChannel("d__pred", ChannelOps::kSendOnly, + p.CreateStreamingChannel("d__pred_0", ChannelOps::kSendOnly, p.GetBitsType(1))); XLS_ASSERT_OK_AND_ASSIGN( StreamingChannel * d_completion, - p.CreateStreamingChannel("d__completion", ChannelOps::kSendOnly, + p.CreateStreamingChannel("d__completion_0", ChannelOps::kSendOnly, p.GetBitsType(1))); ProcBuilder b("proc0", &p); // Do multiple sends/receives on c/d to insert adapters and new channels. @@ -973,19 +1073,21 @@ TEST_F(ChannelLegalizationPassTest, NamesAreUniquified) { EXPECT_THAT(Run(&p), IsOkAndHolds(true)); - EXPECT_THAT(p.channels(), - UnorderedElementsAre( - m::Channel("c"), m::Channel("d"), - // Original "extra" names - m::Channel("c__1"), m::Channel("d__1"), m::Channel("c__pred"), - m::Channel("d__pred"), m::Channel("c__2__completion"), - m::Channel("c__3__completion"), m::Channel("d__completion"), - // New colliding names - m::Channel("c__2"), m::Channel("c__3"), m::Channel("d__2"), - m::Channel("d__3"), m::Channel("c__pred__1"), - m::Channel("c__pred__2"), m::Channel("d__pred__1"), - m::Channel("d__pred__2"), m::Channel("c__2__completion__1"), - m::Channel("c__3__completion__1"))); + EXPECT_THAT( + p.channels(), + UnorderedElementsAre( + m::Channel("c"), m::Channel("d"), + // Original "extra" names + m::Channel("c__data_1"), m::Channel("c__pred_0"), + m::Channel("c__completion_0"), m::Channel("c__completion_1"), + m::Channel("d__data_1"), m::Channel("d__pred_0"), + m::Channel("d__completion_0"), + // New colliding names + m::Channel("d__pred_0__1"), m::Channel("d__pred_1"), + m::Channel("d__data_0"), m::Channel("d__data_1__1"), + m::Channel("c__pred_0__1"), m::Channel("c__pred_1"), + m::Channel("c__data_0"), m::Channel("c__data_1__1"), + m::Channel("c__completion_1__1"), m::Channel("c__completion_2"))); } INSTANTIATE_TEST_SUITE_P( @@ -1039,14 +1141,19 @@ class SingleValueChannelLegalizationPassTest : public TestWithParam { absl::StatusOr Run() { // Replace all streaming channels with single_value channels. std::string substituted_ir_text = absl::StrReplaceAll( - GetParam().ir_text, { - {"kind=streaming, ops=send_only, " - "flow_control=ready_valid, strictness=$0, ", - "kind=single_value, ops=send_only, "}, - {"kind=streaming, ops=receive_only, " - "flow_control=ready_valid, strictness=$0, ", - "kind=single_value, ops=receive_only, "}, - }); + GetParam().ir_text, + { + // Global channel form. + {"kind=streaming, ops=send_only, " + "flow_control=ready_valid, strictness=$0, ", + "kind=single_value, ops=send_only, "}, + {"kind=streaming, ops=receive_only, " + "flow_control=ready_valid, strictness=$0, ", + "kind=single_value, ops=receive_only, "}, + // Proc-scoped channel form. + {"kind=streaming strictness=$0", "kind=single_value"}, + {"kind=streaming strictness=$0", "kind=single_value"}, + }); XLS_ASSIGN_OR_RETURN(std::unique_ptr p, Parser::ParsePackage(substituted_ir_text));