Skip to content

Commit

Permalink
Extract a single proc in ir_minimizer
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 637665333
  • Loading branch information
allight authored and copybara-github committed May 27, 2024
1 parent 66262cd commit ec4df61
Show file tree
Hide file tree
Showing 2 changed files with 189 additions and 2 deletions.
77 changes: 77 additions & 0 deletions xls/tools/ir_minimizer_main.cc
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@
#include "xls/ir/channel.h"
#include "xls/ir/channel_ops.h"
#include "xls/ir/events.h"
#include "xls/ir/function.h"
#include "xls/ir/function_base.h"
#include "xls/ir/function_builder.h"
#include "xls/ir/ir_parser.h"
Expand Down Expand Up @@ -128,6 +129,10 @@ ABSL_FLAG(int64_t, failed_attempt_limit, 256,
"done reducing.");
ABSL_FLAG(int64_t, total_attempt_limit, 16384,
"Limit on total number of attempts to try before bailing.");
ABSL_FLAG(
bool, can_extract_single_proc, false,
"Whether to extract a single proc from a network. Selected proc might not "
"be top. All internal channels are changed to be recv/send only.");
ABSL_FLAG(bool, can_extract_segments, false,
"Whether to allow the minimizer to extract segments of the IR. This "
"transform entirely removes some segment of logic and makes a new "
Expand Down Expand Up @@ -949,6 +954,71 @@ absl::StatusOr<SimplificationResult> SimplifyNode(
return SimplificationResult::kDidChange;
}

absl::StatusOr<SimplifiedIr> ExtractSingleProc(FunctionBase* f,
std::string* which_transform) {
XLS_RET_CHECK(f->IsProc());
Package* p = f->package();
Proc* proc = f->AsProcOrDie();
VLOG(3) << "Extracting proc " << proc->name() << " from network;";
if (proc->is_new_style_proc()) {
VLOG(3) << "new style procs not yet supported.";
return SimplifiedIr{.result = SimplificationResult::kDidNotChange,
.ir_data = f->package(),
.node_count = f->node_count()};
}
*which_transform = absl::StrFormat(
"Extracting proc %s from network into singleton proc.", proc->name());
absl::flat_hash_set<Channel*> send_chans;
absl::flat_hash_set<Channel*> recv_chans;
for (Node* n : proc->nodes()) {
if (n->Is<Send>()) {
XLS_ASSIGN_OR_RETURN(Channel * c,
p->GetChannel(n->As<Send>()->channel_name()));
send_chans.insert(c);
} else if (n->Is<Receive>()) {
XLS_ASSIGN_OR_RETURN(Channel * c,
p->GetChannel(n->As<Receive>()->channel_name()));
recv_chans.insert(c);
}
}
Package new_pkg = Package(f->package()->name());
absl::flat_hash_map<const Function*, Function*> function_remap;
absl::flat_hash_map<const FunctionBase*, FunctionBase*> function_base_remap;
// Keep all the subroutines.
for (FunctionBase* f : FunctionsInPostOrder(p)) {
if (f->IsFunction()) {
XLS_ASSIGN_OR_RETURN(
function_remap[f->AsFunctionOrDie()],
f->AsFunctionOrDie()->Clone(f->name(), &new_pkg, function_remap));
function_base_remap[f] = function_remap[f->AsFunctionOrDie()];
}
}
for (Channel* c : send_chans) {
XLS_RETURN_IF_ERROR(
new_pkg
.CloneChannel(c, c->name(),
Package::CloneChannelOverrides().OverrideSupportedOps(
ChannelOps::kSendOnly))
.status());
}
for (Channel* c : recv_chans) {
XLS_RETURN_IF_ERROR(
new_pkg
.CloneChannel(c, c->name(),
Package::CloneChannelOverrides().OverrideSupportedOps(
ChannelOps::kReceiveOnly))
.status());
}
XLS_ASSIGN_OR_RETURN(
Proc * new_proc,
proc->Clone(proc->name(), &new_pkg, /*channel_remapping=*/{},
/*call_remapping=*/function_base_remap));
XLS_RETURN_IF_ERROR(new_pkg.SetTop(new_proc));
return SimplifiedIr{.result = SimplificationResult::kDidChange,
.ir_data = new_pkg.DumpIr(),
.node_count = new_proc->node_count()};
}

// Picks a random node in the function 'f' and (if possible) generates a new
// package&function that only contains that node and its antecedents.
absl::StatusOr<SimplifiedIr> ExtractRandomNodeSubset(
Expand Down Expand Up @@ -1128,6 +1198,13 @@ absl::StatusOr<SimplifiedIr> Simplify(FunctionBase* f,
}
}

if (absl::GetFlag(FLAGS_can_extract_single_proc) && f->IsProc() &&
absl::Bernoulli(rng, 0.1) && f->package()->procs().size() > 1) {
XLS_ASSIGN_OR_RETURN(auto result, ExtractSingleProc(f, which_transform));
if (result.result != SimplificationResult::kDidNotChange) {
return result;
}
}
if (absl::Bernoulli(rng, 0.2)) {
XLS_ASSIGN_OR_RETURN(SimplificationResult result,
SimplifyReturnValue(f, rng, which_transform));
Expand Down
114 changes: 112 additions & 2 deletions xls/tools/ir_minimizer_main_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,84 @@
}
'''

PROC_2 = '''package multi_proc
file_number 0 "xls/jit/multi_proc.x"
chan multi_proc__bytes_src(bits[32], id=0, kind=streaming, ops=receive_only, flow_control=ready_valid, strictness=proven_mutually_exclusive, metadata="""""")
chan multi_proc__bytes_result(bits[32], id=1, kind=streaming, ops=send_only, flow_control=ready_valid, strictness=proven_mutually_exclusive, metadata="""""")
chan multi_proc__send_double_pipe(bits[32], id=2, kind=streaming, ops=send_receive, flow_control=ready_valid, strictness=proven_mutually_exclusive, metadata="""""")
chan multi_proc__send_quad_pipe(bits[32], id=3, kind=streaming, ops=send_receive, flow_control=ready_valid, strictness=proven_mutually_exclusive, metadata="""""")
chan multi_proc__recv_double_pipe(bits[32], id=4, kind=streaming, ops=send_receive, flow_control=ready_valid, strictness=proven_mutually_exclusive, metadata="""""")
chan multi_proc__recv_quad_pipe(bits[32], id=5, kind=streaming, ops=send_receive, flow_control=ready_valid, strictness=proven_mutually_exclusive, metadata="""""")
fn __multi_proc__double_it(n: bits[32]) -> bits[32] {
ret add.2: bits[32] = add(n, n, id=2, pos=[(0,17,32)])
}
fn __multi_proc__proc_double.init() -> () {
ret tuple.3: () = tuple(id=3, pos=[(0,25,11)])
}
fn __multi_proc__proc_quad.init() -> () {
ret tuple.4: () = tuple(id=4, pos=[(0,40,11)])
}
top proc __multi_proc__proc_ten_0_next(__state: (), init={()}) {
tok: token = after_all(id=8)
receive.9: (token, bits[32]) = receive(tok, channel=multi_proc__bytes_src, id=9)
tok__1: token = tuple_index(receive.9, index=0, id=11, pos=[(0,74,13)])
v: bits[32] = tuple_index(receive.9, index=1, id=12, pos=[(0,74,18)])
tok__2: token = send(tok__1, v, channel=multi_proc__send_quad_pipe, id=13)
receive.14: (token, bits[32]) = receive(tok__2, channel=multi_proc__recv_quad_pipe, id=14)
tok__3: token = tuple_index(receive.14, index=0, id=16, pos=[(0,76,13)])
tok__4: token = send(tok__3, v, channel=multi_proc__send_double_pipe, id=19)
qv: bits[32] = tuple_index(receive.14, index=1, id=17, pos=[(0,76,18)])
receive.20: (token, bits[32]) = receive(tok__4, channel=multi_proc__recv_double_pipe, id=20)
ev: bits[32] = invoke(qv, to_apply=__multi_proc__double_it, id=18, pos=[(0,77,26)])
dv: bits[32] = tuple_index(receive.20, index=1, id=23, pos=[(0,79,18)])
tok__5: token = tuple_index(receive.20, index=0, id=22, pos=[(0,79,13)])
add.24: bits[32] = add(ev, dv, id=24, pos=[(0,81,35)])
__token: token = literal(value=token, id=5)
literal.7: bits[1] = literal(value=1, id=7)
tuple_index.10: token = tuple_index(receive.9, index=0, id=10)
tuple_index.15: token = tuple_index(receive.14, index=0, id=15)
tuple_index.21: token = tuple_index(receive.20, index=0, id=21)
send.25: token = send(tok__5, add.24, channel=multi_proc__bytes_result, id=25)
tuple.26: () = tuple(id=26, pos=[(0,72,15)])
next (tuple.26)
}
proc __multi_proc__proc_ten__proc_double_0_next(__state: (), init={()}) {
tok: token = after_all(id=30)
receive.31: (token, bits[32]) = receive(tok, channel=multi_proc__send_double_pipe, id=31)
v: bits[32] = tuple_index(receive.31, index=1, id=34, pos=[(0,29,18)])
tok__1: token = tuple_index(receive.31, index=0, id=33, pos=[(0,29,13)])
invoke.35: bits[32] = invoke(v, to_apply=__multi_proc__double_it, id=35, pos=[(0,30,41)])
__token: token = literal(value=token, id=27)
literal.29: bits[1] = literal(value=1, id=29)
tuple_index.32: token = tuple_index(receive.31, index=0, id=32)
send.36: token = send(tok__1, invoke.35, channel=multi_proc__recv_double_pipe, id=36)
tuple.37: () = tuple(id=37, pos=[(0,27,15)])
next (tuple.37)
}
proc __multi_proc__proc_ten__proc_quad_0_next(__state: (), init={()}) {
tok: token = after_all(id=41)
receive.42: (token, bits[32]) = receive(tok, channel=multi_proc__send_quad_pipe, id=42)
v: bits[32] = tuple_index(receive.42, index=1, id=45, pos=[(0,44,18)])
invoke.46: bits[32] = invoke(v, to_apply=__multi_proc__double_it, id=46, pos=[(0,45,51)])
tok__1: token = tuple_index(receive.42, index=0, id=44, pos=[(0,44,13)])
invoke.47: bits[32] = invoke(invoke.46, to_apply=__multi_proc__double_it, id=47, pos=[(0,45,41)])
__token: token = literal(value=token, id=38)
literal.40: bits[1] = literal(value=1, id=40)
tuple_index.43: token = tuple_index(receive.42, index=0, id=43)
send.48: token = send(tok__1, invoke.47, channel=multi_proc__recv_quad_pipe, id=48)
tuple.49: () = tuple(id=49, pos=[(0,42,15)])
next (tuple.49)
}
'''


class IrMinimizerMainTest(absltest.TestCase):

Expand Down Expand Up @@ -168,6 +246,35 @@ def test_minimize_extract_things(self):
""",
)

def test_minimize_extract_single_proc(self):
ir_file = self.create_tempfile(content=PROC_2)
test_sh_file = self.create_tempfile()
self._write_sh_script(
test_sh_file.full_path,
["/usr/bin/env grep 'multi_proc__bytes_src' $1"],
)
minimized_ir = subprocess.check_output([
IR_MINIMIZER_MAIN_PATH,
f'--test_executable={test_sh_file.full_path}',
'--can_remove_params=false',
'--can_remove_sends=true',
'--can_remove_receives=true',
'--can_extract_single_proc=true',
ir_file.full_path,
])
self._maybe_record_property('output', minimized_ir.decode('utf-8'))
self.assertRegex(
minimized_ir.decode('utf-8'),
r'''package multi_proc
chan multi_proc__bytes_src\(bits\[32\], id=[0-9]+, kind=streaming, ops=receive_only, flow_control=ready_valid, strictness=proven_mutually_exclusive, metadata=""""""\)
top proc __multi_proc__proc_ten_0_next\(\) \{
tok: token = after_all\(id=1\)
receive_9: \(token, bits\[32\]\) = receive\(tok, channel=multi_proc__bytes_src, id=2\)
\}
''')

def test_minimize_inline_one_can_inline_other_invokes(self):
ir_file = self.create_tempfile(content=INVOKE_TWO_DEEP)
test_sh_file = self.create_tempfile()
Expand Down Expand Up @@ -209,7 +316,9 @@ def test_minimize_no_change_subroutine_type(self):
ir_file.full_path,
])
self._maybe_record_property('output', minimized_ir.decode('utf-8'))
self.assertEqual(minimized_ir.decode('utf-8'), """package foo
self.assertEqual(
minimized_ir.decode('utf-8'),
"""package foo
fn bar(x: bits[8][8]) -> bits[8][4] {
ret literal.57: bits[8][4] = literal(value=[0, 0, 0, 0], id=57)
Expand All @@ -221,7 +330,8 @@ def test_minimize_no_change_subroutine_type(self):
literal.64: bits[1] = literal(value=0, id=64)
ret array_slice.65: bits[8][2] = array_slice(invoke.26, literal.64, width=2, id=65)
}
""")
""",
)

def test_minimize_add_no_remove_params(self):
ir_file = self.create_tempfile(content=ADD_IR)
Expand Down

0 comments on commit ec4df61

Please sign in to comment.