diff --git a/xls/ir/BUILD b/xls/ir/BUILD index 77600309ce..082f21152b 100644 --- a/xls/ir/BUILD +++ b/xls/ir/BUILD @@ -1583,6 +1583,7 @@ cc_library( "//xls/common/status:ret_check", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/log", + "@com_google_absl//absl/log:check", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", @@ -1896,6 +1897,7 @@ cc_library( "//xls/common/status:status_macros", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/log", "@com_google_absl//absl/log:check", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", diff --git a/xls/ir/channel.cc b/xls/ir/channel.cc index 5462a1a14a..14007fd0a0 100644 --- a/xls/ir/channel.cc +++ b/xls/ir/channel.cc @@ -22,6 +22,7 @@ #include #include +#include "absl/log/check.h" #include "absl/log/log.h" #include "absl/status/status.h" #include "absl/status/statusor.h" @@ -229,6 +230,38 @@ std::ostream& operator<<(std::ostream& os, Direction direction) { return os; } +ChannelRef AsChannelRef(SendChannelRef ref) { + if (std::holds_alternative(ref)) { + return std::get(ref); + } + return std::get(ref); +} + +ChannelRef AsChannelRef(ReceiveChannelRef ref) { + if (std::holds_alternative(ref)) { + return std::get(ref); + } + return std::get(ref); +} + +SendChannelRef AsSendChannelRefOrDie(ChannelRef ref) { + if (std::holds_alternative(ref)) { + ChannelReference* cref = std::get(ref); + CHECK_EQ(cref->direction(), Direction::kSend); + return down_cast(cref); + } + return std::get(ref); +} + +ReceiveChannelRef AsReceiveChannelRefOrDie(ChannelRef ref) { + if (std::holds_alternative(ref)) { + ChannelReference* cref = std::get(ref); + CHECK_EQ(cref->direction(), Direction::kReceive); + return down_cast(cref); + } + return std::get(ref); +} + std::string_view ChannelRefName(ChannelRef ref) { return absl::visit([](const auto& ch) { return ch->name(); }, ref); } @@ -247,6 +280,17 @@ ChannelKind ChannelRefKind(ChannelRef ref) { return std::get(ref)->kind(); } +std::optional ChannelRefStrictness(ChannelRef ref) { + if (std::holds_alternative(ref)) { + return std::get(ref)->strictness(); + } + if (auto streaming_channel = + down_cast(std::get(ref))) { + return streaming_channel->GetStrictness(); + } + return std::nullopt; +} + std::string ChannelReference::ToString() const { std::vector keyword_strs; keyword_strs.push_back( diff --git a/xls/ir/channel.h b/xls/ir/channel.h index 48c1cc9ca2..be97eaea7d 100644 --- a/xls/ir/channel.h +++ b/xls/ir/channel.h @@ -420,10 +420,20 @@ using ChannelRef = std::variant; using SendChannelRef = std::variant; using ReceiveChannelRef = std::variant; +// Converts a send/receive ChannelRef into a generic ChannelRef. +ChannelRef AsChannelRef(SendChannelRef ref); +ChannelRef AsChannelRef(ReceiveChannelRef ref); + +// Converts a base ChannelRef into a send/receive form. CHECK fails if the +// ChannelRef is not of the appropriate direction. +SendChannelRef AsSendChannelRefOrDie(ChannelRef ref); +ReceiveChannelRef AsReceiveChannelRefOrDie(ChannelRef ref); + // Return the name/type/kind of a channel reference. std::string_view ChannelRefName(ChannelRef ref); Type* ChannelRefType(ChannelRef ref); ChannelKind ChannelRefKind(ChannelRef ref); +std::optional ChannelRefStrictness(ChannelRef ref); } // namespace xls diff --git a/xls/ir/package.cc b/xls/ir/package.cc index 18bca4c07b..9a686b1221 100644 --- a/xls/ir/package.cc +++ b/xls/ir/package.cc @@ -768,6 +768,7 @@ absl::Status Package::AddChannel(std::unique_ptr channel, Proc* proc) { } absl::StatusOr Package::GetChannel(int64_t id) const { + XLS_RET_CHECK(!ChannelsAreProcScoped()); for (Channel* ch : channels()) { if (ch->id() == id) { return ch; @@ -777,6 +778,7 @@ absl::StatusOr Package::GetChannel(int64_t id) const { } std::vector Package::GetChannelNames() const { + CHECK(!ChannelsAreProcScoped()); std::vector names; names.reserve(channels().size()); for (Channel* ch : channels()) { @@ -786,6 +788,7 @@ std::vector Package::GetChannelNames() const { } absl::StatusOr Package::GetChannel(std::string_view name) const { + XLS_RET_CHECK(!ChannelsAreProcScoped()); auto it = channels_.find(name); if (it != channels_.end()) { return it->second.get(); diff --git a/xls/ir/proc_elaboration.cc b/xls/ir/proc_elaboration.cc index ea6852a284..fd2727ccd8 100644 --- a/xls/ir/proc_elaboration.cc +++ b/xls/ir/proc_elaboration.cc @@ -25,6 +25,7 @@ #include "absl/container/flat_hash_map.h" #include "absl/container/flat_hash_set.h" +#include "absl/log/log.h" #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/str_cat.h" diff --git a/xls/passes/BUILD b/xls/passes/BUILD index 4939b9ea52..ac2714ffba 100644 --- a/xls/passes/BUILD +++ b/xls/passes/BUILD @@ -2909,6 +2909,7 @@ cc_library( "//xls/ir:name_uniquer", "//xls/ir:node_util", "//xls/ir:op", + "//xls/ir:proc_elaboration", "//xls/ir:source_location", "//xls/ir:value", "@com_google_absl//absl/container:btree", @@ -2916,6 +2917,7 @@ cc_library( "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/log", "@com_google_absl//absl/log:check", + "@com_google_absl//absl/memory", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", @@ -2993,6 +2995,7 @@ cc_test( "//xls/ir:ir_matcher", "//xls/ir:ir_parser", "//xls/ir:ir_test_base", + "//xls/ir:proc_elaboration", "//xls/ir:value", "//xls/ir:verifier", "@com_google_absl//absl/container:flat_hash_map", diff --git a/xls/passes/channel_legalization_pass.cc b/xls/passes/channel_legalization_pass.cc index 2622c2a509..eee7885d5d 100644 --- a/xls/passes/channel_legalization_pass.cc +++ b/xls/passes/channel_legalization_pass.cc @@ -17,6 +17,7 @@ #include #include #include +#include #include #include #include @@ -30,6 +31,7 @@ #include "absl/container/flat_hash_set.h" #include "absl/log/check.h" #include "absl/log/log.h" +#include "absl/memory/memory.h" #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/str_cat.h" @@ -50,6 +52,7 @@ #include "xls/ir/nodes.h" #include "xls/ir/op.h" #include "xls/ir/package.h" +#include "xls/ir/proc_elaboration.h" #include "xls/ir/source_location.h" #include "xls/ir/value.h" #include "xls/passes/optimization_pass.h" @@ -60,6 +63,227 @@ namespace xls { namespace { + +// Data structure describing a channel for sending data to the adapter. +struct AdapterInputChannel { + StreamingChannel* channel; + // ChannelRef for sending data on the channel outside the adapter. + SendChannelRef parent_send_channel_ref; + // ChannelRef for receiving data on the channel inside the adapter. + ReceiveChannelRef adapter_receive_channel_ref; +}; + +// Data structure describing a channel for receiving data from the adapter. +struct AdapterOutputChannel { + StreamingChannel* channel; + // ChannelRef for receiving data on the channel outside the adapter. + ReceiveChannelRef parent_receive_channel_ref; + // ChannelRef for sending data on the channel inside the adapter. + SendChannelRef adapter_send_channel_ref; +}; + +// Helper class for building an adapter. Abstracts away proc-scoped vs global +// channel behind a uniform interface. +class AdapterBuilder { + public: + static absl::StatusOr> + CreateForProcScopedChannels(ChannelReference* adapted_channel, + std::string_view adapter_name, + Proc* parent_proc) { + XLS_RET_CHECK(parent_proc->package()->ChannelsAreProcScoped()); + std::unique_ptr proc_builder = std::make_unique( + NewStyleProc(), adapter_name, parent_proc->package()); + + // Add input/output for adapted channel. + ChannelReference* adapted_channel_ref_in_adapter; + if (adapted_channel->direction() == Direction::kSend) { + XLS_ASSIGN_OR_RETURN( + adapted_channel_ref_in_adapter, + proc_builder->AddOutputChannel( + adapted_channel->name(), adapted_channel->type(), + ChannelKind::kStreaming, adapted_channel->strictness())); + } else { + XLS_ASSIGN_OR_RETURN( + adapted_channel_ref_in_adapter, + proc_builder->AddInputChannel( + adapted_channel->name(), adapted_channel->type(), + ChannelKind::kStreaming, adapted_channel->strictness())); + } + + // Prime the name uniquer with the current channels in the proc. + auto name_uniquer = std::make_unique("__"); + for (Channel* ch : parent_proc->channels()) { + name_uniquer->GetSanitizedUniqueName(ch->name()); + } + std::unique_ptr builder = + absl::WrapUnique(new AdapterBuilder( + adapted_channel, adapted_channel_ref_in_adapter, + std::move(proc_builder), parent_proc, + /*global_channel_name_uniquer=*/nullptr, + /*proc_scoped_channel_name_uniquer=*/std::move(name_uniquer))); + + builder->adapter_instantiation_args_.push_back(adapted_channel); + return builder; + } + + static absl::StatusOr> + CreateForGlobalChannels(Channel* adapted_channel, + std::string_view adapter_name, Package* package, + NameUniquer& channel_name_uniquer) { + XLS_RET_CHECK(!package->ChannelsAreProcScoped()); + std::unique_ptr proc_builder = + std::make_unique(adapter_name, package); + return absl::WrapUnique(new AdapterBuilder( + adapted_channel, adapted_channel, std::move(proc_builder), + /*parent_proc=*/nullptr, + /*global_channel_name_uniquer=*/&channel_name_uniquer, + /*proc_scoped_channel_name_uniquer=*/nullptr)); + } + + // Create a channel for sending to the adapter. + absl::StatusOr AddAdapterInputChannel( + std::string_view name, Type* type, const FifoConfig& fifo_config, + std::optional strictness = std::nullopt) { + AdapterInputChannel result; + if (ChannelsAreProcScoped()) { + std::string uniquified_name = + proc_scoped_channel_name_uniquer_->GetSanitizedUniqueName(name); + XLS_ASSIGN_OR_RETURN(result.channel, + package()->CreateStreamingChannelInProc( + uniquified_name, ChannelOps::kSendReceive, type, + parent_proc().value(), + /*initial_values=*/{}, fifo_config)); + XLS_ASSIGN_OR_RETURN(result.parent_send_channel_ref, + parent_proc().value()->GetSendChannelReference( + result.channel->name())); + XLS_ASSIGN_OR_RETURN(ReceiveChannelReference * parent_receive_channel_ref, + parent_proc().value()->GetReceiveChannelReference( + result.channel->name())); + XLS_ASSIGN_OR_RETURN( + result.adapter_receive_channel_ref, + adapter_builder().AddInputChannel(name, type, ChannelKind::kStreaming, + strictness)); + adapter_instantiation_args_.push_back(parent_receive_channel_ref); + } else { + std::string uniquified_name = + global_channel_name_uniquer_->GetSanitizedUniqueName(name); + XLS_ASSIGN_OR_RETURN(result.channel, + package()->CreateStreamingChannel( + uniquified_name, ChannelOps::kSendReceive, type, + /*initial_values=*/{}, fifo_config)); + result.parent_send_channel_ref = result.channel; + result.adapter_receive_channel_ref = result.channel; + } + return result; + } + + // Create a channel for receiving from the adapter. + absl::StatusOr AddAdapterOutputChannel( + std::string_view name, Type* type, const FifoConfig& fifo_config, + std::optional strictness = std::nullopt) { + AdapterOutputChannel result; + if (ChannelsAreProcScoped()) { + std::string uniquified_name = + proc_scoped_channel_name_uniquer_->GetSanitizedUniqueName(name); + XLS_ASSIGN_OR_RETURN(result.channel, + package()->CreateStreamingChannelInProc( + uniquified_name, ChannelOps::kSendReceive, type, + parent_proc().value(), + /*initial_values=*/{}, fifo_config)); + XLS_ASSIGN_OR_RETURN(result.parent_receive_channel_ref, + parent_proc().value()->GetReceiveChannelReference( + result.channel->name())); + XLS_ASSIGN_OR_RETURN(SendChannelReference * parent_send_channel_ref, + parent_proc().value()->GetSendChannelReference( + result.channel->name())); + XLS_ASSIGN_OR_RETURN( + result.adapter_send_channel_ref, + adapter_builder().AddOutputChannel( + name, type, ChannelKind::kStreaming, strictness)); + adapter_instantiation_args_.push_back(parent_send_channel_ref); + } else { + std::string uniquified_name = + global_channel_name_uniquer_->GetSanitizedUniqueName(name); + XLS_ASSIGN_OR_RETURN(result.channel, + package()->CreateStreamingChannel( + uniquified_name, ChannelOps::kSendReceive, type, + /*initial_values=*/{}, fifo_config)); + result.parent_receive_channel_ref = result.channel; + result.adapter_send_channel_ref = result.channel; + } + return result; + } + + // Build and return the adapter proc. + absl::StatusOr Build(absl::Span next_state) { + if (ChannelsAreProcScoped()) { + XLS_RETURN_IF_ERROR( + parent_proc() + .value() + ->AddProcInstantiation( + absl::StrFormat("%s_adapter_inst", adapted_channel_name()), + adapter_instantiation_args_, adapter_proc()) + .status()); + } + return adapter_builder().Build(next_state); + } + + bool ChannelsAreProcScoped() { return adapter_proc()->is_new_style_proc(); } + Package* package() { return adapter_proc()->package(); } + + // For proc-scoped channels, returns the proc which instantates and + // communicates with the adapter. For global channels, returns nullopt. + std::optional parent_proc() { return parent_proc_; } + + Proc* adapter_proc() { return adapter_builder().proc(); } + ProcBuilder& adapter_builder() { return *adapter_builder_; } + + // Returns the ChannelRef of original (adapted) channel for sending/receiving + // outside the adapter. + ChannelRef adapted_channel_ref_in_parent() { + return adapted_channel_ref_in_parent_; + } + // Returns the ChannelRef of original (adapted) channel for sending/receiving + // inside the adapter. + ChannelRef adapted_channel_ref_in_adapter() { + return adapted_channel_ref_in_adapter_; + } + + // Returns various properties of the original (adapted) channel. + ChannelStrictness adapted_channel_strictness() { + return ChannelRefStrictness(adapted_channel_ref_in_parent()).value(); + } + std::string_view adapted_channel_name() { + return ChannelRefName(adapted_channel_ref_in_parent()); + } + Type* adapted_channel_type() { + return ChannelRefType(adapted_channel_ref_in_parent()); + } + + private: + AdapterBuilder(ChannelRef adapted_channel_ref_in_parent, + ChannelRef adapted_channel_ref_in_adapter, + std::unique_ptr adapter_builder, + std::optional parent_proc, + NameUniquer* global_channel_name_uniquer, + std::unique_ptr proc_scoped_channel_name_uniquer) + : adapted_channel_ref_in_parent_(adapted_channel_ref_in_parent), + adapted_channel_ref_in_adapter_(adapted_channel_ref_in_adapter), + adapter_builder_(std::move(adapter_builder)), + parent_proc_(parent_proc), + global_channel_name_uniquer_(global_channel_name_uniquer), + proc_scoped_channel_name_uniquer_( + std::move(proc_scoped_channel_name_uniquer)) {} + + ChannelRef adapted_channel_ref_in_parent_; + ChannelRef adapted_channel_ref_in_adapter_; + std::unique_ptr adapter_builder_; + std::optional parent_proc_; + NameUniquer* global_channel_name_uniquer_; + std::unique_ptr proc_scoped_channel_name_uniquer_; + std::vector adapter_instantiation_args_; +}; + // Get the token operand number for a given channel op. absl::StatusOr TokenOperandNumberForChannelOp(Node* node) { switch (node->op()) { @@ -72,45 +296,119 @@ absl::StatusOr TokenOperandNumberForChannelOp(Node* node) { absl::StrFormat("Expected channel op, got %s.", node->ToString())); } } + +struct ChannelSends { + Channel* channel; + std::vector sends; +}; + +struct ChannelReceives { + Channel* channel; + std::vector receives; +}; + struct MultipleChannelOps { - absl::flat_hash_map> multiple_sends; - absl::flat_hash_map> - multiple_receives; + std::vector sends; + std::vector receives; }; +absl::StatusOr ElaboratePackage(Package* package) { + std::optional top = package->GetTop(); + if (!top.has_value() || !(*top)->IsProc() || + !(*top)->AsProcOrDie()->is_new_style_proc()) { + return ProcElaboration::ElaborateOldStylePackage(package); + } + return ProcElaboration::Elaborate((*top)->AsProcOrDie()); +} + // Find instances of multiple sends/recvs on a channel. -MultipleChannelOps FindMultipleChannelOps(Package* p) { - MultipleChannelOps result; - for (FunctionBase* fb : p->GetFunctionBases()) { - for (Node* node : fb->nodes()) { +absl::StatusOr FindMultipleChannelOps( + const ProcElaboration& elab) { + // Create a map from channel to the set of send/receive nodes on the channel + // and vice-versa. + absl::flat_hash_map> channel_sends; + absl::flat_hash_map> channel_receives; + absl::flat_hash_map> send_channels; + absl::flat_hash_map> receive_channels; + for (ProcInstance* proc_instance : elab.proc_instances()) { + for (Node* node : proc_instance->proc()->nodes()) { if (node->Is()) { Send* send = node->As(); - VLOG(4) << "Found send " << send->ToString(); - result.multiple_sends[send->channel_name()].insert(send); + XLS_ASSIGN_OR_RETURN( + ChannelInstance * channel_instance, + proc_instance->GetChannelInstance(send->channel_name())); + channel_sends[channel_instance->channel].insert(send); + send_channels[send].insert(channel_instance->channel); } if (node->Is()) { - Receive* recv = node->As(); - result.multiple_receives[recv->channel_name()].insert(recv); - VLOG(4) << "Found recv " << recv->ToString(); + Receive* receive = node->As(); + XLS_ASSIGN_OR_RETURN( + ChannelInstance * channel_instance, + proc_instance->GetChannelInstance(receive->channel_name())); + channel_receives[channel_instance->channel].insert(receive); + receive_channels[receive].insert(channel_instance->channel); + } + } + } + + MultipleChannelOps result; + + // Identify channels which have multiple sends/receives. Return an error if + // there is a send/receive which can send on different channels *AND* is also + // part of a set multiple such ops on the same channel. For now in channel + // legalization we require sends/receives to be uniquely mapped to a single + // channel. + for (auto [channel, sends] : channel_sends) { + if (sends.size() <= 1) { + continue; + } + ChannelSends element; + element.channel = channel; + element.sends = std::vector(sends.begin(), sends.end()); + std::sort(element.sends.begin(), element.sends.end(), NodeIdLessThan); + + for (Send* send : element.sends) { + if (send_channels.at(send).size() > 1) { + return absl::UnimplementedError( + absl::StrFormat("Send `%s` can send on different channels and is " + "one of multiple sends on channel `%s`", + send->GetName(), channel->name())); + } + } + result.sends.push_back(std::move(element)); + } + std::sort(result.sends.begin(), result.sends.end(), + [](const ChannelSends& a, const ChannelSends& b) { + return a.channel->id() < b.channel->id(); + }); + + for (auto [channel, receives] : channel_receives) { + if (receives.size() <= 1) { + continue; + } + ChannelReceives element; + element.channel = channel; + element.receives = std::vector(receives.begin(), receives.end()); + std::sort(element.receives.begin(), element.receives.end(), NodeIdLessThan); + + for (Receive* receive : element.receives) { + if (receive_channels.at(receive).size() > 1) { + return absl::UnimplementedError(absl::StrFormat( + "Receive `%s` can receive on different channels and is " + "one of multiple receives on channel `%s`", + receive->GetName(), channel->name())); } } + result.receives.push_back(std::move(element)); } + std::sort(result.receives.begin(), result.receives.end(), + [](const ChannelReceives& a, const ChannelReceives& b) { + return a.channel->id() < b.channel->id(); + }); - // Erase cases where there's only one send or receive. - absl::erase_if( - result.multiple_sends, - [](const std::pair>& elt) { - return elt.second.size() < 2; - }); - absl::erase_if( - result.multiple_receives, - [](const std::pair>& elt) { - return elt.second.size() < 2; - }); - - VLOG(4) << "After erasing single accesses, found " - << result.multiple_sends.size() << " multiple send channels and " - << result.multiple_receives.size() << " multiple receive channels."; + VLOG(4) << "After erasing single accesses, found " << result.receives.size() + << " multiple send channels and " << result.receives.size() + << " multiple receive channels."; return result; } @@ -223,7 +521,7 @@ struct FunctionBaseNameLess { // Note that token params are not included in predecessor lists. template >> absl::StatusOr> GetProjectedTokenDAG( - const absl::flat_hash_set& operations, ChannelStrictness strictness) { + absl::Span operations, ChannelStrictness strictness) { // We return the result_vector, but also build a result_map to track // transitive dependencies. std::vector result_vector; @@ -260,6 +558,8 @@ absl::StatusOr> GetProjectedTokenDAG( for (const T* operation : operations) { fbs.insert(operation->function_base()); } + + absl::flat_hash_set operation_set(operations.begin(), operations.end()); for (FunctionBase* fb : fbs) { XLS_ASSIGN_OR_RETURN(std::vector fbs_result, ComputeTopoSortedTokenDAG(fb)); @@ -279,7 +579,7 @@ absl::StatusOr> GetProjectedTokenDAG( for (Node* predecessor : fb_result.predecessors) { // If a predecessor is not type T, resolve its predecessors of type T. if (!predecessor->Is() || - !operations.contains(predecessor->As())) { + !operation_set.contains(predecessor->As())) { absl::flat_hash_set& resolved = resolved_ops.at(predecessor); resolved_predecessors.insert(resolved.begin(), resolved.end()); continue; @@ -289,7 +589,7 @@ absl::StatusOr> GetProjectedTokenDAG( // If this entry in the DAG is not type T, save its resolved predecessors // for future resolution. if (!fb_result.node->Is() || - !operations.contains(fb_result.node->As())) { + !operation_set.contains(fb_result.node->As())) { resolved_ops[fb_result.node] = std::move(resolved_predecessors); continue; } @@ -299,7 +599,7 @@ absl::StatusOr> GetProjectedTokenDAG( if (strictness == ChannelStrictness::kArbitraryStaticOrder) { if (prev_node.has_value()) { if (!prev_node.value()->Is() || - !operations.contains(prev_node.value()->As())) { + !operation_set.contains(prev_node.value()->As())) { resolved_predecessors.insert(resolved_ops.at(*prev_node).begin(), resolved_ops.at(*prev_node).end()); } else { @@ -345,31 +645,24 @@ absl::Status CheckIsBlocking(Node* n) { // original channel operation. // 3) Updates the token of the original channel operation to come after the new // predicate send. -absl::StatusOr MakePredicateChannel( - Node* operation, NameUniquer& channel_name_uniquer) { +absl::StatusOr MakePredicateChannel( + Node* operation, ChannelRef channel_ref, int64_t instance_number, + AdapterBuilder& ab) { Package* package = operation->package(); - - XLS_ASSIGN_OR_RETURN(Channel * operation_channel, - GetChannelUsedByNode(operation)); - StreamingChannel* channel = down_cast(operation_channel); Proc* proc = operation->function_base()->AsProcOrDie(); XLS_ASSIGN_OR_RETURN( - StreamingChannel * pred_channel, - package->CreateStreamingChannel( - channel_name_uniquer.GetSanitizedUniqueName( - absl::StrCat(channel->name(), "__pred")), - // This is an internal channel, so override to kSendReceive - ChannelOps::kSendReceive, + AdapterInputChannel pred_input_channel, + ab.AddAdapterInputChannel( + absl::StrFormat("%s__pred_%d", ChannelRefName(channel_ref), + instance_number), // This channel is used to forward the predicate to the adapter, which // takes 1 bit. package->GetBitsType(1), - /*initial_values=*/{}, // This is an internal channel that may be inlined during proc // inlining, set FIFO depth to 1. Break cycles by registering push // outputs. // TODO: github/xls#1509 - revisit this if we have better ways of // avoiding cycles in adapters. - /*fifo_config=*/ FifoConfig(/*depth=*/1, /*bypass=*/false, /*register_push_outputs=*/true, /*register_pop_outputs=*/false))); @@ -381,24 +674,27 @@ absl::StatusOr MakePredicateChannel( // If predicate nullopt, that means it's an unconditional send/receive. Make a // true literal to send to the adapter. if (!predicate.has_value()) { - XLS_ASSIGN_OR_RETURN( - predicate, - proc->MakeNodeWithName( - SourceInfo(), Value(UBits(1, /*bit_count=*/1)), - absl::StrFormat("true_predicate_for_chan_%s", channel->name()))); + XLS_ASSIGN_OR_RETURN(predicate, + proc->MakeNodeWithName( + SourceInfo(), Value(UBits(1, /*bit_count=*/1)), + absl::StrFormat("true_predicate_for_chan_%s", + ChannelRefName(channel_ref)))); } XLS_ASSIGN_OR_RETURN(int64_t operand_number, TokenOperandNumberForChannelOp(operation)); - XLS_ASSIGN_OR_RETURN( - Send * send_pred, - proc->MakeNodeWithName( - SourceInfo(), operation->operand(operand_number), predicate.value(), - std::nullopt, pred_channel->name(), - absl::StrFormat("send_predicate_for_chan_%s", channel->name()))); + XLS_ASSIGN_OR_RETURN(Send * send_pred, + proc->MakeNodeWithName( + SourceInfo(), operation->operand(operand_number), + predicate.value(), std::nullopt, + ChannelRefName(AsChannelRef( + pred_input_channel.parent_send_channel_ref)), + absl::StrFormat("send_predicate_for_chan_%s", + ChannelRefName(channel_ref)))); // Replace the op's original input token with the predicate send's token. XLS_RETURN_IF_ERROR(operation->ReplaceOperandNumber( operand_number, send_pred, /*type_must_match=*/true)); - return pred_channel; + + return pred_input_channel; } // Add a new channel to communicate a channel operation's completion to the @@ -410,26 +706,20 @@ absl::StatusOr MakePredicateChannel( // 2) Adds a receive of the completion on this channel at the token level of the // original channel operation. // 3) Updates the token of the original channel operation to come after the new -// completion receive. -absl::StatusOr MakeCompletionChannel( - Node* operation, NameUniquer& channel_name_uniquer) { - Package* package = operation->package(); - - XLS_ASSIGN_OR_RETURN(Channel * operation_channel, - GetChannelUsedByNode(operation)); - StreamingChannel* channel = down_cast(operation_channel); +// completion_42` receive. +absl::StatusOr MakeCompletionChannel( + Node* operation, int64_t instance_number, AdapterBuilder& ab) { + Package* package = ab.package(); Proc* proc = operation->function_base()->AsProcOrDie(); + XLS_ASSIGN_OR_RETURN( - StreamingChannel * completion_channel, - package->CreateStreamingChannel( - channel_name_uniquer.GetSanitizedUniqueName( - absl::StrCat(channel->name(), "__completion")), - // This is an internal channel, so override to kSendReceive - ChannelOps::kSendReceive, + AdapterOutputChannel completion_channel, + ab.AddAdapterOutputChannel( + absl::StrFormat("%s__completion_%i", ab.adapted_channel_name(), + instance_number), // This channel is used to mark the completion of the requested // operation and doesn't carry any data. Use an empty tuple type. package->GetTupleType({}), - /*initial_values=*/{}, // This is an internal channel that may be inlined during proc // inlining, set FIFO depth to 0. Completion channels seem to not // cause cycles when they have timing paths between push and pop. @@ -448,10 +738,10 @@ absl::StatusOr MakeCompletionChannel( // true literal to send to the adapter. if (!predicate.has_value()) { XLS_ASSIGN_OR_RETURN( - predicate, - proc->MakeNodeWithName( - SourceInfo(), Value(UBits(1, /*bit_count=*/1)), - absl::StrFormat("true_predicate_for_chan_%s", channel->name()))); + predicate, proc->MakeNodeWithName( + SourceInfo(), Value(UBits(1, /*bit_count=*/1)), + absl::StrFormat("true_predicate_for_chan_%s", + completion_channel.channel->name()))); } XLS_ASSIGN_OR_RETURN(Node * free_token, proc->MakeNode(SourceInfo(), Value::Token())); @@ -459,13 +749,15 @@ absl::StatusOr MakeCompletionChannel( Receive * recv_completion, proc->MakeNodeWithName( SourceInfo(), free_token, predicate.value(), - completion_channel->name(), /*is_blocking=*/true, - absl::StrFormat("recv_completion_for_chan_%s", channel->name()))); - XLS_ASSIGN_OR_RETURN(Node * recv_completion_token, - proc->MakeNodeWithName( - SourceInfo(), recv_completion, 0, - absl::StrFormat("recv_completion_token_for_chan_%s", - channel->name()))); + completion_channel.channel->name(), /*is_blocking=*/true, + absl::StrFormat("recv_completion_for_chan_%s", + completion_channel.channel->name()))); + XLS_ASSIGN_OR_RETURN( + Node * recv_completion_token, + proc->MakeNodeWithName( + SourceInfo(), recv_completion, 0, + absl::StrFormat("recv_completion_token_for_chan_%s", + completion_channel.channel->name()))); // Replace usages of the token from the send/recv operation with the token // from the completion recv. switch (operation->op()) { @@ -499,11 +791,6 @@ absl::StatusOr MakeCompletionChannel( return completion_channel; } -struct ClonedChannelWithPredicate { - StreamingChannel* cloned_channel; - StreamingChannel* predicate_channel; -}; - // The adapter orders multiple channel operations with state that tracks if // there is an outstanding channel operation waiting to complete. struct ActivationState { @@ -742,35 +1029,40 @@ void MakeDebugTrace(BValue condition, // Makes activation network and adds asserts. absl::StatusOr MakeActivationNetwork( - ProcBuilder& pb, absl::Span token_dag, - ChannelStrictness strictness, NameUniquer& channel_name_uniquer) { + AdapterBuilder& ab, absl::Span token_dag) { ActivationNetwork activations; // First, make new predicate channels. The adapter will non-blocking receive // on each of these channels to get each operation's predicate. - absl::flat_hash_map pred_channels; + absl::flat_hash_map pred_channels; pred_channels.reserve(token_dag.size()); - for (const auto& [node, _] : token_dag) { - XLS_ASSIGN_OR_RETURN(StreamingChannel * pred_channel, - MakePredicateChannel(node, channel_name_uniquer)); - pred_channels.insert({node, pred_channel}); + + for (int64_t instance_number = 0; instance_number < token_dag.size(); + ++instance_number) { + Node* node = token_dag[instance_number].node; + XLS_ASSIGN_OR_RETURN(ChannelRef channel_ref, GetChannelRefUsedByNode(node)); + XLS_ASSIGN_OR_RETURN( + AdapterInputChannel pred_input_channel, + MakePredicateChannel(node, channel_ref, instance_number, ab)); + pred_channels.insert({node, pred_input_channel}); } // Now make state. We need to know the state in order to know the predicate on // the predicate receive. int64_t state_idx = 0; - for (const auto& [node, _] : token_dag) { - int64_t pred_channel_id = pred_channels.at(node)->id(); + for (int64_t instance_number = 0; instance_number < token_dag.size(); + ++instance_number) { + Node* node = token_dag[instance_number].node; ActivationNode activation; - activation.predicate_state.state = - pb.StateElement(absl::StrFormat("pred_%d", pred_channel_id), - Value(UBits(0, /*bit_count=*/1))); - activation.valid_state.state = - pb.StateElement(absl::StrFormat("pred_%d_valid", pred_channel_id), - Value(UBits(0, /*bit_count=*/1))); - activation.done_state.state = - pb.StateElement(absl::StrFormat("pred_%d_done", pred_channel_id), - Value(UBits(0, /*bit_count=*/1))); + activation.predicate_state.state = ab.adapter_builder().StateElement( + absl::StrFormat("pred_%d", instance_number), + Value(UBits(0, /*bit_count=*/1))); + activation.valid_state.state = ab.adapter_builder().StateElement( + absl::StrFormat("pred_%d_valid", instance_number), + Value(UBits(0, /*bit_count=*/1))); + activation.done_state.state = ab.adapter_builder().StateElement( + absl::StrFormat("pred_%d_done", instance_number), + Value(UBits(0, /*bit_count=*/1))); activation.predicate_state.state_idx = state_idx++; activation.valid_state.state_idx = state_idx++; @@ -780,8 +1072,13 @@ absl::StatusOr MakeActivationNetwork( // Make a non-blocking receive on the predicate channel for each operation. // Compute predicate values and activations for this tick. - for (const auto& [node, predecessors] : token_dag) { - StreamingChannel* pred_channel = pred_channels.at(node); + for (int64_t instance_number = 0; instance_number < token_dag.size(); + ++instance_number) { + Node* node = token_dag[instance_number].node; + const absl::flat_hash_set& predecessors = + token_dag[instance_number].predecessors; + + const AdapterInputChannel& pred_input_channel = pred_channels.at(node); ActivationNode& activation = activations.at(node); std::vector sorted_predecessors(predecessors.begin(), @@ -793,8 +1090,8 @@ absl::StatusOr MakeActivationNetwork( BValue predecessors_pred_recv_token = PredRecvTokensForActivations( activations, sorted_predecessors, absl::StrFormat("chan_%d_recv_pred_predeccesors_token", - pred_channel->id()), - pb); + instance_number), + ab.adapter_builder()); // Do a non-blocking receive on the predicate channel. // @@ -804,37 +1101,42 @@ absl::StatusOr MakeActivationNetwork( // successfully received and 'valid' indicates that it has been successfully // received. After the adapter completes all operations, 'valid' will be // reset to 0 and everything starts over again. - BValue do_pred_recv = pb.Not(activation.valid_state.state); - BValue recv = pb.ReceiveIfNonBlocking( - pred_channel, predecessors_pred_recv_token, do_pred_recv, SourceInfo(), - absl::StrFormat("recv_pred_%d", pred_channel->id())); - activation.pred_recv_token = pb.TupleIndex( + BValue do_pred_recv = + ab.adapter_builder().Not(activation.valid_state.state); + BValue recv = ab.adapter_builder().ReceiveIfNonBlocking( + pred_input_channel.adapter_receive_channel_ref, + predecessors_pred_recv_token, do_pred_recv, SourceInfo(), + absl::StrFormat("recv_pred_%d", instance_number)); + activation.pred_recv_token = ab.adapter_builder().TupleIndex( recv, 0, SourceInfo(), - absl::StrFormat("recv_pred_%d_token", pred_channel->id())); - BValue pred_recv_predicate = - pb.TupleIndex(recv, 1, SourceInfo(), - absl::StrFormat("recv_pred_%d_data", pred_channel->id())); - BValue pred_recv_valid = pb.TupleIndex( + absl::StrFormat("recv_pred_%d_token", instance_number)); + BValue pred_recv_predicate = ab.adapter_builder().TupleIndex( + recv, 1, SourceInfo(), + absl::StrFormat("recv_pred_%d_data", instance_number)); + BValue pred_recv_valid = ab.adapter_builder().TupleIndex( recv, 2, SourceInfo(), - absl::StrFormat("recv_pred_%d_valid", pred_channel->id())); - activation.predicate = pb.Or( + absl::StrFormat("recv_pred_%d_valid", instance_number)); + activation.predicate = ab.adapter_builder().Or( pred_recv_predicate, activation.predicate_state.state, SourceInfo(), - absl::StrFormat("pred_%d_updated", pred_channel->id())); - activation.valid = - pb.Or(pred_recv_valid, activation.valid_state.state, SourceInfo(), - absl::StrFormat("pred_%d_valid_updated", pred_channel->id())); - - BValue all_predecessors_done = AllPredecessorsDone( - activations, sorted_predecessors, - absl::StrFormat("%v_all_predecessors_done", *node), pb); - activation.activate = - pb.And({activation.predicate, activation.valid, - pb.Not(activation.done_state.state), all_predecessors_done}, - SourceInfo(), absl::StrFormat("%v_activate", *node)); - activation.done = - pb.Or({activation.activate, - pb.And(pb.Not(activation.predicate), activation.valid), - activation.done_state.state}); + absl::StrFormat("pred_%d_updated", instance_number)); + activation.valid = ab.adapter_builder().Or( + pred_recv_valid, activation.valid_state.state, SourceInfo(), + absl::StrFormat("pred_%d_valid_updated", instance_number)); + + BValue all_predecessors_done = + AllPredecessorsDone(activations, sorted_predecessors, + absl::StrFormat("%v_all_predecessors_done", *node), + ab.adapter_builder()); + activation.activate = ab.adapter_builder().And( + {activation.predicate, activation.valid, + ab.adapter_builder().Not(activation.done_state.state), + all_predecessors_done}, + SourceInfo(), absl::StrFormat("%v_activate", *node)); + activation.done = ab.adapter_builder().Or( + {activation.activate, + ab.adapter_builder().And( + ab.adapter_builder().Not(activation.predicate), activation.valid), + activation.done_state.state}); } // Make assertions that operations w/ no token ordering are mutually @@ -844,42 +1146,36 @@ absl::StatusOr MakeActivationNetwork( // are empty, so every operation should be mutually exclusive. In the // arbitrary_static_order case, nodes are linearized and have an added token // relationship with every other operation, and there should be no assertions. - MakeMutualExclusionAssertions(token_dag, activations, pb); + MakeMutualExclusionAssertions(token_dag, activations, ab.adapter_builder()); // Compute next-state signals. // Builds not_all_done signal which indicates if every operation is done. If // so, start the next set of operations by setting next signals for done and // valid to 0. - BValue not_all_done = - NextStateAndReturnNotAllDone(token_dag, activations, pb); + BValue not_all_done = NextStateAndReturnNotAllDone(token_dag, activations, + ab.adapter_builder()); // Make a trace to aid debugging. It only fires when all ops are done. - MakeDebugTrace(pb.Not(not_all_done), token_dag, activations, pb); + MakeDebugTrace(ab.adapter_builder().Not(not_all_done), token_dag, activations, + ab.adapter_builder()); return std::move(activations); } -absl::Status AddAdapterForMultipleReceives( - Package* p, StreamingChannel* channel, - const absl::flat_hash_set& ops, - NameUniquer& channel_name_uniquer) { +absl::Status AddAdapterForMultipleReceives(absl::Span ops, + AdapterBuilder& ab) { XLS_RET_CHECK_GT(ops.size(), 1); - XLS_ASSIGN_OR_RETURN(auto token_dags, - GetProjectedTokenDAG(ops, channel->GetStrictness())); - - std::string adapter_name = - absl::StrFormat("chan_%s_io_receive_adapter", channel->name()); + XLS_ASSIGN_OR_RETURN( + auto token_dags, + GetProjectedTokenDAG(ops, ab.adapted_channel_strictness())); - VLOG(4) << absl::StreamFormat("Channel %s has token dag %s.", channel->name(), + VLOG(4) << absl::StreamFormat("Channel %s has token dag %s.", + ab.adapted_channel_name(), absl::StrJoin(token_dags, ", ")); - ProcBuilder pb(adapter_name, p); - - XLS_ASSIGN_OR_RETURN( - ActivationNetwork activations, - MakeActivationNetwork(pb, token_dags, channel->GetStrictness(), - channel_name_uniquer)); + XLS_ASSIGN_OR_RETURN(ActivationNetwork activations, + MakeActivationNetwork(ab, token_dags)); BValue any_active; BValue external_recv_input_token; @@ -891,64 +1187,63 @@ absl::Status AddAdapterForMultipleReceives( all_activations.push_back(activations.at(node).activate); all_tokens.push_back(activations.at(node).pred_recv_token); } - any_active = pb.Or(all_activations, SourceInfo(), "any_active"); - external_recv_input_token = pb.AfterAll(all_tokens); + any_active = + ab.adapter_builder().Or(all_activations, SourceInfo(), "any_active"); + external_recv_input_token = ab.adapter_builder().AfterAll(all_tokens); } - BValue recv = pb.ReceiveIf(channel, external_recv_input_token, any_active, - SourceInfo(), "external_receive"); - BValue recv_token = - pb.TupleIndex(recv, 0, SourceInfo(), "external_receive_token"); - BValue recv_data = - pb.TupleIndex(recv, 1, SourceInfo(), "external_receive_data"); - - for (const auto& [node, _] : token_dags) { - const ActivationNode& activation = activations.at(node); + BValue recv = ab.adapter_builder().ReceiveIf( + AsReceiveChannelRefOrDie(ab.adapted_channel_ref_in_adapter()), + external_recv_input_token, any_active, SourceInfo(), "external_receive"); + BValue recv_token = ab.adapter_builder().TupleIndex(recv, 0, SourceInfo(), + "external_receive_token"); + BValue recv_data = ab.adapter_builder().TupleIndex(recv, 1, SourceInfo(), + "external_receive_data"); + + for (int64_t i = 0; i < token_dags.size(); ++i) { + Node* node = token_dags[i].node; + const ActivationNode& activation = activations.at(node); XLS_ASSIGN_OR_RETURN( - Channel * new_data_channel, - p->CloneChannel( - channel, - channel_name_uniquer.GetSanitizedUniqueName(channel->name()), - Package::CloneChannelOverrides() - .OverrideSupportedOps(ChannelOps::kSendReceive) - .OverrideFifoConfig( - // This is an internal channel that may be inlined during - // proc inlining, set FIFO depth to 1. Break cycles by - // registering push outputs. - // TODO: github/xls#1509 - revisit this if we have better - // ways of avoiding cycles in adapters. - FifoConfig(/*depth=*/1, /*bypass=*/false, - /*register_push_outputs=*/true, - /*register_pop_outputs=*/false)))); - XLS_RETURN_IF_ERROR( - ReplaceChannelUsedByNode(node, new_data_channel->name())); - BValue send_token = pb.AfterAll({activation.pred_recv_token, recv_token}); - pb.SendIf(new_data_channel, send_token, activation.activate, recv_data); + AdapterOutputChannel output_data_channel, + ab.AddAdapterOutputChannel( + absl::StrFormat("%s__data_%d", ab.adapted_channel_name(), i), + ab.adapted_channel_type(), + // This is an internal channel that may be inlined during proc + // inlining, set FIFO depth to 1. Break cycles by registering push + // outputs. + // TODO: github/xls#1509 - revisit this if we have better ways of + // avoiding cycles in adapters. + FifoConfig(/*depth=*/1, /*bypass=*/false, + /*register_push_outputs=*/true, + /*register_pop_outputs=*/false))); + XLS_RETURN_IF_ERROR(ReplaceChannelUsedByNode( + node, ChannelRefName(AsChannelRef( + output_data_channel.parent_receive_channel_ref)))); + BValue send_token = + ab.adapter_builder().AfterAll({activation.pred_recv_token, recv_token}); + ab.adapter_builder().SendIf(output_data_channel.adapter_send_channel_ref, + send_token, activation.activate, recv_data); } - return pb.Build(NextState(activations)).status(); + return ab.Build(NextState(activations)).status(); } -absl::Status AddAdapterForMultipleSends(Package* p, StreamingChannel* channel, - const absl::flat_hash_set& ops, - NameUniquer& channel_name_uniquer) { +absl::Status AddAdapterForMultipleSends(absl::Span ops, + AdapterBuilder& ab) { XLS_RET_CHECK_GT(ops.size(), 1); - XLS_ASSIGN_OR_RETURN(auto token_dags, - GetProjectedTokenDAG(ops, channel->GetStrictness())); - - std::string adapter_name = - absl::StrFormat("chan_%s_io_send_adapter", channel->name()); + XLS_ASSIGN_OR_RETURN( + auto token_dags, + GetProjectedTokenDAG(ops, ab.adapted_channel_strictness())); - VLOG(4) << absl::StreamFormat("Channel %s has token dag %s.", channel->name(), + VLOG(4) << absl::StreamFormat("Channel %s has token dag %s.", + ab.adapted_channel_name(), absl::StrJoin(token_dags, ", ")); - ProcBuilder pb(adapter_name, p); + XLS_ASSIGN_OR_RETURN(ActivationNetwork activations, + MakeActivationNetwork(ab, token_dags)); - XLS_ASSIGN_OR_RETURN( - ActivationNetwork activations, - MakeActivationNetwork(pb, token_dags, channel->GetStrictness(), - channel_name_uniquer)); + absl::flat_hash_map data_channels; BValue recv_after_all; BValue recv_data; @@ -961,52 +1256,60 @@ absl::Status AddAdapterForMultipleSends(Package* p, StreamingChannel* channel, std::vector recv_data_valids; recv_data_valids.reserve(token_dags.size()); - for (const auto& [node, _] : token_dags) { + for (int64_t i = 0; i < token_dags.size(); ++i) { + Node* node = token_dags[i].node; const ActivationNode& activation = activations.at(node); XLS_ASSIGN_OR_RETURN( - Channel * new_data_channel, - p->CloneChannel( - channel, - channel_name_uniquer.GetSanitizedUniqueName(channel->name()), - Package::CloneChannelOverrides() - .OverrideSupportedOps(ChannelOps::kSendReceive) - .OverrideFifoConfig( - // This is an internal channel that may be inlined during - // proc inlining, set FIFO depth to 1. Break cycles by - // registering push outputs. - // TODO: github/xls#1509 - revisit this if we have better - // ways of avoiding cycles in adapters. - FifoConfig(/*depth=*/1, /*bypass=*/false, - /*register_push_outputs=*/true, - /*register_pop_outputs=*/false)))); - XLS_RETURN_IF_ERROR( - ReplaceChannelUsedByNode(node, new_data_channel->name())); - BValue recv = pb.ReceiveIf(new_data_channel, activation.pred_recv_token, - activation.activate); - recv_tokens.push_back(pb.TupleIndex(recv, 0)); - recv_datas.push_back(pb.TupleIndex(recv, 1)); + AdapterInputChannel input_data_channel, + ab.AddAdapterInputChannel( + absl::StrFormat("%s__data_%d", ab.adapted_channel_name(), i), + ab.adapted_channel_type(), + // This is an internal channel that may be inlined during proc + // inlining, set FIFO depth to 1. Break cycles by registering push + // outputs. + // TODO: github/xls#1509 - revisit this if we have better ways of + // avoiding cycles in adapters. + FifoConfig(/*depth=*/1, /*bypass=*/false, + /*register_push_outputs=*/true, + /*register_pop_outputs=*/false))); + XLS_RETURN_IF_ERROR(ReplaceChannelUsedByNode( + node, ChannelRefName( + AsChannelRef(input_data_channel.parent_send_channel_ref)))); + BValue recv = ab.adapter_builder().ReceiveIf( + input_data_channel.adapter_receive_channel_ref, + activation.pred_recv_token, activation.activate); + recv_tokens.push_back(ab.adapter_builder().TupleIndex(recv, 0)); + recv_datas.push_back(ab.adapter_builder().TupleIndex(recv, 1)); recv_data_valids.push_back(activation.activate); + + data_channels[node] = input_data_channel; } - recv_after_all = pb.AfterAll(recv_tokens); + recv_after_all = ab.adapter_builder().AfterAll(recv_tokens); // Reverse for one hot select order. std::reverse(recv_data_valids.begin(), recv_data_valids.end()); - recv_data = pb.OneHotSelect(pb.Concat(recv_data_valids), recv_datas); - recv_data_valid = pb.Or(recv_data_valids); + recv_data = ab.adapter_builder().OneHotSelect( + ab.adapter_builder().Concat(recv_data_valids), recv_datas); + recv_data_valid = ab.adapter_builder().Or(recv_data_valids); } - BValue send_token = pb.SendIf(channel, recv_after_all, recv_data_valid, - recv_data, SourceInfo(), "external_send"); - BValue empty_tuple_literal = pb.Literal(Value::Tuple({})); - - for (const auto& [node, _] : token_dags) { - XLS_ASSIGN_OR_RETURN(StreamingChannel * completion_channel, - MakeCompletionChannel(node, channel_name_uniquer)); - pb.SendIf(completion_channel, send_token, activations.at(node).activate, - empty_tuple_literal); + BValue send_token = ab.adapter_builder().SendIf( + AsSendChannelRefOrDie(ab.adapted_channel_ref_in_adapter()), + recv_after_all, recv_data_valid, recv_data, SourceInfo(), + "external_send"); + BValue empty_tuple_literal = ab.adapter_builder().Literal(Value::Tuple({})); + + for (int64_t i = 0; i < token_dags.size(); ++i) { + Node* node = token_dags[i].node; + XLS_ASSIGN_OR_RETURN(AdapterOutputChannel completion_channel, + MakeCompletionChannel(node, i, ab)); + ab.adapter_builder().SendIf(completion_channel.adapter_send_channel_ref, + send_token, activations.at(node).activate, + empty_tuple_literal); } - return pb.Build(NextState(activations)).status(); + return ab.Build(NextState(activations)).status(); } + } // namespace absl::StatusOr ChannelLegalizationPass::RunInternal( @@ -1014,19 +1317,22 @@ absl::StatusOr ChannelLegalizationPass::RunInternal( PassResults* results) const { VLOG(3) << "Running channel legalization pass."; bool changed = false; - MultipleChannelOps multiple_ops = FindMultipleChannelOps(p); + XLS_ASSIGN_OR_RETURN(ProcElaboration elab, ElaboratePackage(p)); + XLS_ASSIGN_OR_RETURN(MultipleChannelOps multiple_ops, + FindMultipleChannelOps(elab)); - if (multiple_ops.multiple_receives.empty() && - multiple_ops.multiple_sends.empty()) { + if (multiple_ops.receives.empty() && multiple_ops.sends.empty()) { return false; } - NameUniquer channel_name_uniquer("__"); - for (Channel* channel : p->channels()) { - channel_name_uniquer.GetSanitizedUniqueName(channel->name()); + NameUniquer global_channel_name_uniquer("__"); + if (!p->ChannelsAreProcScoped()) { + for (Channel* channel : p->channels()) { + global_channel_name_uniquer.GetSanitizedUniqueName(channel->name()); + } } - for (const auto& [channel_name, ops] : multiple_ops.multiple_receives) { + for (const auto& [channel, ops] : multiple_ops.receives) { for (Receive* recv : ops) { if (!recv->is_blocking()) { return absl::InvalidArgumentError(absl::StrFormat( @@ -1035,9 +1341,10 @@ absl::StatusOr ChannelLegalizationPass::RunInternal( recv->GetName())); } } - XLS_ASSIGN_OR_RETURN(Channel * channel, p->GetChannel(channel_name)); if (channel->kind() != ChannelKind::kStreaming) { // Don't make adapters for non-streaming channels. + VLOG(4) << absl::StreamFormat( + "Multiple receives on non-streaming channel `%s`", channel->name()); continue; } StreamingChannel* streaming_channel = down_cast(channel); @@ -1045,19 +1352,50 @@ absl::StatusOr ChannelLegalizationPass::RunInternal( ChannelStrictness::kProvenMutuallyExclusive) { // Don't make adapters for channels that must be proven to be mutually // exclusive- they will be handled during scheduling. + VLOG(3) << absl::StreamFormat( + "Multiple receives on proven mutually exclusive channel `%s`", + channel->name()); continue; } + + std::unique_ptr adapter_builder; + std::string adapter_name = + absl::StrFormat("chan_%s_io_receive_adapter", channel->name()); + if (p->ChannelsAreProcScoped()) { + // For proc-scoped channels all receives are in the same proc. + Proc* proc = ops.front()->function_base()->AsProcOrDie(); + for (Receive* op : ops) { + if (op->function_base()->AsProcOrDie() != proc) { + return absl::UnimplementedError(absl::StrFormat( + "Channel `%s` has multiple receives in different procs: `%s` and " + "`%s`", + channel->name(), proc->name(), op->function_base()->name())); + } + } + XLS_ASSIGN_OR_RETURN(ReceiveChannelReference * channel_ref, + proc->GetReceiveChannelReference(channel->name())); + XLS_ASSIGN_OR_RETURN( + adapter_builder, + AdapterBuilder::CreateForProcScopedChannels(channel_ref, adapter_name, + /*parent_proc=*/proc)); + } else { + XLS_ASSIGN_OR_RETURN( + adapter_builder, + AdapterBuilder::CreateForGlobalChannels(channel, adapter_name, p, + global_channel_name_uniquer)); + } VLOG(3) << absl::StreamFormat( "Making receive channel adapter for channel `%s`, has receives (%s).", - channel_name, absl::StrJoin(ops, ", ")); - XLS_RETURN_IF_ERROR(AddAdapterForMultipleReceives(p, streaming_channel, ops, - channel_name_uniquer)); + channel->name(), absl::StrJoin(ops, ", ")); + XLS_RETURN_IF_ERROR(AddAdapterForMultipleReceives(ops, *adapter_builder)); changed = true; } - for (const auto& [channel_name, ops] : multiple_ops.multiple_sends) { - XLS_ASSIGN_OR_RETURN(Channel * channel, p->GetChannel(channel_name)); + + for (const auto& [channel, ops] : multiple_ops.sends) { if (channel->kind() != ChannelKind::kStreaming) { // Don't make adapters for non-streaming channels. + VLOG(4) << absl::StreamFormat( + "Multiple receives on non-streaming channel `%s`", channel->name()); continue; } StreamingChannel* streaming_channel = down_cast(channel); @@ -1065,15 +1403,46 @@ absl::StatusOr ChannelLegalizationPass::RunInternal( ChannelStrictness::kProvenMutuallyExclusive) { // Don't make adapters for channels that must be proven to be mutually // exclusive- they will be handled during scheduling. + VLOG(4) << absl::StreamFormat( + "Multiple sends on proven mutually exclusive channel `%s`", + channel->name()); continue; } - VLOG(3) << absl::StreamFormat( + + std::unique_ptr adapter_builder; + std::string adapter_name = + absl::StrFormat("chan_%s_io_send_adapter", channel->name()); + if (p->ChannelsAreProcScoped()) { + // For proc-scoped channels all sends are in the same proc. + Proc* proc = ops.front()->function_base()->AsProcOrDie(); + for (Send* op : ops) { + if (op->function_base()->AsProcOrDie() != proc) { + return absl::UnimplementedError(absl::StrFormat( + "Channel `%s` has multiple sends in different procs: `%s` and " + "`%s`", + channel->name(), proc->name(), op->function_base()->name())); + } + } + XLS_ASSIGN_OR_RETURN(SendChannelReference * channel_ref, + proc->GetSendChannelReference(channel->name())); + XLS_ASSIGN_OR_RETURN( + adapter_builder, + AdapterBuilder::CreateForProcScopedChannels(channel_ref, adapter_name, + /*parent_proc=*/proc)); + } else { + XLS_ASSIGN_OR_RETURN( + adapter_builder, + AdapterBuilder::CreateForGlobalChannels(channel, adapter_name, p, + global_channel_name_uniquer)); + } + + VLOG(4) << absl::StreamFormat( "Making send channel adapter for channel `%s`, has sends (%s).", - channel_name, absl::StrJoin(ops, ", ")); - XLS_RETURN_IF_ERROR(AddAdapterForMultipleSends(p, streaming_channel, ops, - channel_name_uniquer)); + channel->name(), absl::StrJoin(ops, ", ")); + XLS_RETURN_IF_ERROR(AddAdapterForMultipleSends(ops, *adapter_builder)); changed = true; } + return changed; } diff --git a/xls/passes/channel_legalization_pass_test.cc b/xls/passes/channel_legalization_pass_test.cc index acc2490f85..69eee799bf 100644 --- a/xls/passes/channel_legalization_pass_test.cc +++ b/xls/passes/channel_legalization_pass_test.cc @@ -43,6 +43,7 @@ #include "xls/ir/ir_parser.h" #include "xls/ir/ir_test_base.h" #include "xls/ir/package.h" +#include "xls/ir/proc_elaboration.h" #include "xls/ir/value.h" #include "xls/ir/verifier.h" #include "xls/passes/optimization_pass.h" @@ -68,6 +69,20 @@ using ::testing::UnorderedElementsAre; using ::testing::Values; using ::testing::ValuesIn; +// Create an interpreter runtime for evaluating procs. Automatically handles new +// and old style procs. +absl::StatusOr> CreateRuntime( + Package* package) { + if (!package->HasTop()) { + return CreateInterpreterSerialProcRuntime(package); + } + XLS_ASSIGN_OR_RETURN(Proc * top, package->GetTopAsProc()); + if (top->is_new_style_proc()) { + return CreateInterpreterSerialProcRuntime(top); + } + return CreateInterpreterSerialProcRuntime(package); +} + struct TestParam { using evaluation_function = std::function)>; @@ -89,7 +104,10 @@ class ChannelLegalizationPassTest protected: absl::StatusOr Run(Package* package) { PassResults results; - return ChannelLegalizationPass().Run(package, {}, &results); + XLS_ASSIGN_OR_RETURN(bool changed, + ChannelLegalizationPass().Run(package, {}, &results)); + XLS_RETURN_IF_ERROR(VerifyPackage(package)); + return changed; } }; @@ -857,6 +875,88 @@ top proc test_proc(state:(), init={()}) { EXPECT_THAT(outq->Read(), Optional(Value(UBits(1, /*bit_count=*/32)))); + return absl::OkStatus(); + }, + }, + TestParam{ + .test_name = "SingleNewStyleProc", + .ir_text = R"(package test + +top proc my_proc< + in: bits[32] in kind=streaming strictness=$0, + out: bits[32] out kind=streaming strictness=$0>() { + tok: token = literal(value=token) + recv0: (token, bits[32]) = receive(tok, channel=in) + recv0_tok: token = tuple_index(recv0, index=0) + recv0_data: bits[32] = tuple_index(recv0, index=1) + recv1: (token, bits[32]) = receive(recv0_tok, channel=in) + recv1_tok: token = tuple_index(recv1, index=0) + recv1_data: bits[32] = tuple_index(recv1, index=1) + send0: token = send(recv1_tok, recv1_data, channel=out) + send1: token = send(send0, recv0_data, channel=out) +} + )", + .builder_matcher = + { + // For mutually exclusive channels, channel legalization does + // not change the IR, but other passes do. Just check OK. + {ChannelStrictness::kProvenMutuallyExclusive, IsOk()}, + {ChannelStrictness::kTotalOrder, IsOkAndHolds(true)}, + {ChannelStrictness::kRuntimeOrdered, IsOkAndHolds(true)}, + // Build should be OK, but will fail at runtime. + {ChannelStrictness::kRuntimeMutuallyExclusive, + IsOkAndHolds(true)}, + {ChannelStrictness::kArbitraryStaticOrder, IsOkAndHolds(true)}, + }, + .evaluate = + [](SerialProcRuntime* interpreter, + std::optional strictness) -> absl::Status { + constexpr int64_t kMaxTicks = 1000; + constexpr int64_t kNumInputs = 32; + + const ProcElaboration& elab = + interpreter->queue_manager().elaboration(); + + XLS_ASSIGN_OR_RETURN(ChannelInstance * in_instance, + elab.GetChannelInstance("in", "my_proc")); + XLS_ASSIGN_OR_RETURN(ChannelInstance * out_instance, + elab.GetChannelInstance("out", "my_proc")); + ChannelQueue& inq = + interpreter->queue_manager().GetQueue(in_instance); + ChannelQueue& outq = + interpreter->queue_manager().GetQueue(out_instance); + + for (int64_t i = 0; i < kNumInputs; ++i) { + XLS_RETURN_IF_ERROR(inq.Write(Value(UBits(i, /*bit_count=*/32)))); + } + absl::flat_hash_map output_count{ + {outq.channel_instance(), kNumInputs}}; + absl::Status interpreter_status = + interpreter->TickUntilOutput(output_count, kMaxTicks).status(); + if (strictness.has_value() && + strictness.value() == + ChannelStrictness::kRuntimeMutuallyExclusive) { + EXPECT_THAT( + interpreter_status, + StatusIs(absl::StatusCode::kAborted, + HasSubstr("predicate was not mutually exclusive"))); + // Return early, we have no output to check. + return absl::OkStatus(); + } + XLS_EXPECT_OK(interpreter_status); + for (int64_t i = 0; i < kNumInputs; ++i) { + EXPECT_FALSE(outq.IsEmpty()); + int64_t flip_evens_and_odds = i; + if (i % 2 == 0) { + flip_evens_and_odds++; + } else { + flip_evens_and_odds--; + } + EXPECT_THAT(outq.Read(), + Optional(Eq(Value( + UBits(flip_evens_and_odds, /*bit_count=*/32))))); + } + return absl::OkStatus(); }, }, @@ -890,19 +990,19 @@ TEST_P(ChannelLegalizationPassTest, EvaluatesCorrectly) { std::get<0>(GetParam()).ir_text, ChannelStrictnessToString(strictness)))); XLS_ASSERT_OK_AND_ASSIGN(std::unique_ptr interpreter, - CreateInterpreterSerialProcRuntime(p.get())); + CreateRuntime(p.get())); // Don't pass in strictness because the pass hasn't been run yet. XLS_EXPECT_OK(std::get<0>(GetParam()) .evaluate(interpreter.get(), /*strictness=*/std::nullopt)); absl::StatusOr run_status = Run(p.get()); + if (!run_status.ok()) { GTEST_SKIP(); } - XLS_ASSERT_OK_AND_ASSIGN(interpreter, - CreateInterpreterSerialProcRuntime(p.get())); + XLS_ASSERT_OK_AND_ASSIGN(interpreter, CreateRuntime(p.get())); XLS_EXPECT_OK(std::get<0>(GetParam()) .evaluate(interpreter.get(), std::get<1>(GetParam()))); } @@ -928,31 +1028,31 @@ TEST_F(ChannelLegalizationPassTest, NamesAreUniquified) { // makes new channels. XLS_ASSERT_OK_AND_ASSIGN( StreamingChannel * c_1, - p.CreateStreamingChannel("c__1", ChannelOps::kSendOnly, + p.CreateStreamingChannel("c__data_1", ChannelOps::kSendOnly, p.GetBitsType(1))); XLS_ASSERT_OK_AND_ASSIGN( StreamingChannel * c_pred, - p.CreateStreamingChannel("c__pred", ChannelOps::kSendOnly, + p.CreateStreamingChannel("c__pred_0", ChannelOps::kSendOnly, p.GetBitsType(1))); XLS_ASSERT_OK_AND_ASSIGN( StreamingChannel * c_2_completion, - p.CreateStreamingChannel("c__2__completion", ChannelOps::kSendOnly, + p.CreateStreamingChannel("c__completion_1", ChannelOps::kSendOnly, p.GetBitsType(1))); XLS_ASSERT_OK_AND_ASSIGN( StreamingChannel * c_3_completion, - p.CreateStreamingChannel("c__3__completion", ChannelOps::kSendOnly, + p.CreateStreamingChannel("c__completion_2", ChannelOps::kSendOnly, p.GetBitsType(1))); XLS_ASSERT_OK_AND_ASSIGN( StreamingChannel * d_1, - p.CreateStreamingChannel("d__1", ChannelOps::kSendOnly, + p.CreateStreamingChannel("d__data_1", ChannelOps::kSendOnly, p.GetBitsType(1))); XLS_ASSERT_OK_AND_ASSIGN( StreamingChannel * d_pred, - p.CreateStreamingChannel("d__pred", ChannelOps::kSendOnly, + p.CreateStreamingChannel("d__pred_0", ChannelOps::kSendOnly, p.GetBitsType(1))); XLS_ASSERT_OK_AND_ASSIGN( StreamingChannel * d_completion, - p.CreateStreamingChannel("d__completion", ChannelOps::kSendOnly, + p.CreateStreamingChannel("d__completion_0", ChannelOps::kSendOnly, p.GetBitsType(1))); ProcBuilder b("proc0", &p); // Do multiple sends/receives on c/d to insert adapters and new channels. @@ -973,19 +1073,21 @@ TEST_F(ChannelLegalizationPassTest, NamesAreUniquified) { EXPECT_THAT(Run(&p), IsOkAndHolds(true)); - EXPECT_THAT(p.channels(), - UnorderedElementsAre( - m::Channel("c"), m::Channel("d"), - // Original "extra" names - m::Channel("c__1"), m::Channel("d__1"), m::Channel("c__pred"), - m::Channel("d__pred"), m::Channel("c__2__completion"), - m::Channel("c__3__completion"), m::Channel("d__completion"), - // New colliding names - m::Channel("c__2"), m::Channel("c__3"), m::Channel("d__2"), - m::Channel("d__3"), m::Channel("c__pred__1"), - m::Channel("c__pred__2"), m::Channel("d__pred__1"), - m::Channel("d__pred__2"), m::Channel("c__2__completion__1"), - m::Channel("c__3__completion__1"))); + EXPECT_THAT( + p.channels(), + UnorderedElementsAre( + m::Channel("c"), m::Channel("d"), + // Original "extra" names + m::Channel("c__data_1"), m::Channel("c__pred_0"), + m::Channel("c__completion_0"), m::Channel("c__completion_1"), + m::Channel("d__data_1"), m::Channel("d__pred_0"), + m::Channel("d__completion_0"), + // New colliding names + m::Channel("d__pred_0__1"), m::Channel("d__pred_1"), + m::Channel("d__data_0"), m::Channel("d__data_1__1"), + m::Channel("c__pred_0__1"), m::Channel("c__pred_1"), + m::Channel("c__data_0"), m::Channel("c__data_1__1"), + m::Channel("c__completion_1__1"), m::Channel("c__completion_2"))); } INSTANTIATE_TEST_SUITE_P( @@ -1039,14 +1141,19 @@ class SingleValueChannelLegalizationPassTest : public TestWithParam { absl::StatusOr Run() { // Replace all streaming channels with single_value channels. std::string substituted_ir_text = absl::StrReplaceAll( - GetParam().ir_text, { - {"kind=streaming, ops=send_only, " - "flow_control=ready_valid, strictness=$0, ", - "kind=single_value, ops=send_only, "}, - {"kind=streaming, ops=receive_only, " - "flow_control=ready_valid, strictness=$0, ", - "kind=single_value, ops=receive_only, "}, - }); + GetParam().ir_text, + { + // Global channel form. + {"kind=streaming, ops=send_only, " + "flow_control=ready_valid, strictness=$0, ", + "kind=single_value, ops=send_only, "}, + {"kind=streaming, ops=receive_only, " + "flow_control=ready_valid, strictness=$0, ", + "kind=single_value, ops=receive_only, "}, + // Proc-scoped channel form. + {"kind=streaming strictness=$0", "kind=single_value"}, + {"kind=streaming strictness=$0", "kind=single_value"}, + }); XLS_ASSIGN_OR_RETURN(std::unique_ptr p, Parser::ParsePackage(substituted_ir_text));