coro/aio: get rid of data races
Cloudef committed Jul 3, 2024
1 parent 10dbec0 commit 9f03690
Showing 4 changed files with 26 additions and 33 deletions.
build.zig: 2 additions & 0 deletions
@@ -59,6 +59,7 @@ pub fn build(b: *std.Build) void {
             .root_source_file = b.path("examples/" ++ @tagName(example) ++ ".zig"),
             .target = target,
             .optimize = optimize,
+            .sanitize_thread = true,
         });
         exe.root_module.addImport("aio", aio);
         exe.root_module.addImport("coro", coro);
@@ -80,6 +81,7 @@ pub fn build(b: *std.Build) void {
             .filters = &.{test_filter},
             .link_libc = aio.link_libc,
             .single_threaded = aio.single_threaded,
+            .sanitize_thread = true,
         });
         if (mod != .minilib) tst.root_module.addImport("minilib", minilib);
         if (mod == .aio) tst.root_module.addImport("build_options", opts.createModule());
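Both the examples and the unit tests are now compiled with ThreadSanitizer, which turns the races fixed below into loud runtime reports instead of silent corruption. For reference, a minimal sketch of flipping the same switch in a standalone build.zig; the "example" name and the src/main.zig path are illustrative, not taken from this repository:

const std = @import("std");

pub fn build(b: *std.Build) void {
    const target = b.standardTargetOptions(.{});
    const optimize = b.standardOptimizeOption(.{});
    const exe = b.addExecutable(.{
        .name = "example",
        .root_source_file = b.path("src/main.zig"),
        .target = target,
        .optimize = optimize,
        // TSan instruments loads and stores; a racy access produces a
        // report with stack traces for both conflicting threads.
        .sanitize_thread = true,
    });
    b.installArtifact(exe);
}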
examples/coro_wttr.zig: 8 additions & 8 deletions
@@ -6,8 +6,8 @@ const log = std.log.scoped(.coro_wttr);
 
 // Just for fun, try returning an error from one of these tasks
 
-fn getWeather(completed: *u32, allocator: std.mem.Allocator, city: []const u8, lang: []const u8) anyerror![]const u8 {
-    defer completed.* += 1;
+fn getWeather(completed: *std.atomic.Value(u32), allocator: std.mem.Allocator, city: []const u8, lang: []const u8) anyerror![]const u8 {
+    defer _ = completed.fetchAdd(1, .monotonic);
     var url: std.BoundedArray(u8, 256) = .{};
     if (builtin.target.os.tag == .windows) {
         try url.writer().print("https://wttr.in/{s}?AFT&lang={s}", .{ city, lang });
@@ -24,8 +24,8 @@ fn getWeather(completed: *u32, allocator: std.mem.Allocator, city: []const u8, lang: []const u8) anyerror![]const u8 {
     return body.toOwnedSlice();
 }
 
-fn getLatestZig(completed: *u32, allocator: std.mem.Allocator) anyerror![]const u8 {
-    defer completed.* += 1;
+fn getLatestZig(completed: *std.atomic.Value(u32), allocator: std.mem.Allocator) anyerror![]const u8 {
+    defer _ = completed.fetchAdd(1, .monotonic);
     var body = std.ArrayList(u8).init(allocator);
     defer body.deinit();
     var client: std.http.Client = .{ .allocator = allocator };
@@ -42,7 +42,7 @@ fn getLatestZig(completed: *u32, allocator: std.mem.Allocator) anyerror![]const u8 {
     return allocator.dupe(u8, parsed.value.master.version);
 }
 
-fn loader(completed: *u32, max: *const u32) !void {
+fn loader(completed: *std.atomic.Value(u32), max: *const u32) !void {
     const frames: []const []const u8 = &.{
         "▰▱▱▱▱▱▱",
         "▰▰▱▱▱▱▱",
@@ -58,7 +58,7 @@ fn loader(completed: *u32, max: *const u32) !void {
     var idx: usize = 0;
     while (true) : (idx +%= 1) {
         try coro.io.single(aio.Timeout{ .ns = 80 * std.time.ns_per_ms });
-        std.debug.print(" {s} {}/{} loading that juicy info\r", .{ frames[idx % frames.len], completed.*, max.* });
+        std.debug.print(" {s} {}/{} loading that juicy info\r", .{ frames[idx % frames.len], completed.load(.acquire), max.* });
     }
 }
 
@@ -76,7 +76,7 @@ pub fn main() !void {
     defer scheduler.deinit();
 
     var max: u32 = 0;
-    var completed: u32 = 0;
+    var completed = std.atomic.Value(u32).init(0);
     const ltask = try scheduler.spawn(loader, .{ &completed, &max }, .{});
 
     var tpool: coro.ThreadPool = try coro.ThreadPool.init(gpa.allocator(), .{});
@@ -91,7 +91,7 @@ pub fn main() !void {
     try tasks.append(try tpool.spawnForCompletition(&scheduler, getLatestZig, .{ &completed, allocator }, .{}));
 
     max = @intCast(tasks.items.len);
-    while (completed < tasks.items.len) {
+    while (completed.load(.acquire) < tasks.items.len) {
         _ = try scheduler.tick(.blocking);
     }
 
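The counter is shared between the loader task ticking on the scheduler thread and the fetch tasks running on pool threads, so the old unsynchronized completed.* += 1 was a genuine race. Below is a standalone sketch of the same pattern; the work function and the thread count are made up for illustration. Writers publish progress with fetchAdd, the reader polls with an acquire load:

const std = @import("std");

fn work(completed: *std.atomic.Value(u32)) void {
    // ... do something useful ...
    // The diff's replacement for the racy `completed.* += 1`.
    _ = completed.fetchAdd(1, .monotonic);
}

pub fn main() !void {
    var completed = std.atomic.Value(u32).init(0);
    var threads: [4]std.Thread = undefined;
    for (&threads) |*t| t.* = try std.Thread.spawn(.{}, work, .{&completed});
    // Poll like the loader does; the acquire load observes the updates.
    while (completed.load(.acquire) < threads.len) {
        std.Thread.yield() catch {};
    }
    for (threads) |t| t.join();
}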
src/coro/ThreadPool.zig: 4 additions & 4 deletions
@@ -33,14 +33,14 @@ pub const CancellationToken = struct {
     canceled: bool = false,
 };
 
-inline fn entrypoint(self: *@This(), completed: *bool, token: *CancellationToken, comptime func: anytype, res: anytype, args: anytype) void {
+inline fn entrypoint(self: *@This(), completed: *std.atomic.Value(bool), token: *CancellationToken, comptime func: anytype, res: anytype, args: anytype) void {
     const fun_info = @typeInfo(@TypeOf(func)).Fn;
     if (fun_info.params.len > 0 and fun_info.params[0].type.? == *const CancellationToken) {
         res.* = @call(.auto, func, .{token} ++ args);
     } else {
         res.* = @call(.auto, func, args);
     }
-    completed.* = true;
+    completed.store(true, .release);
     const n = self.num_tasks.load(.acquire);
     for (0..n) |_| self.source.notify();
 }
@@ -49,13 +49,13 @@ pub const YieldError = DynamicThreadPool.SpawnError;
 
 /// Yield until `func` finishes on another thread
 pub fn yieldForCompletition(self: *@This(), func: anytype, args: anytype) ReturnTypeMixedWithErrorSet(func, YieldError) {
-    var completed: bool = false;
+    var completed = std.atomic.Value(bool).init(false);
     var res: ReturnType(func) = undefined;
     _ = self.num_tasks.fetchAdd(1, .monotonic);
     defer _ = self.num_tasks.fetchSub(1, .release);
     var token: CancellationToken = .{};
     try self.pool.spawn(entrypoint, .{ self, &completed, &token, func, &res, args });
-    while (!completed) {
+    while (!completed.load(.acquire)) {
         const nerr = io.do(.{
             aio.WaitEventSource{ .source = &self.source, .link = .soft },
         }, if (token.canceled) .io_cancel else .io) catch 1;
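The completed flag is the handoff point between entrypoint on a pool thread and the polling loop in yieldForCompletition: res.* is written before the .release store, so any .acquire load that observes true also observes the finished result. A reduced sketch of that release/acquire pairing, with illustrative names:

const std = @import("std");

var res: u64 = 0; // plain, non-atomic payload
var done = std.atomic.Value(bool).init(false);

fn producer() void {
    res = 42; // ordinary write, published by the release store below
    done.store(true, .release);
}

pub fn main() !void {
    const t = try std.Thread.spawn(.{}, producer, .{});
    defer t.join();
    // Once the acquire load returns true, the write to res is visible.
    while (!done.load(.acquire)) std.Thread.yield() catch {};
    std.debug.print("res = {}\n", .{res});
}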
src/minilib/DynamicThreadPool.zig: 12 additions & 21 deletions
@@ -20,7 +20,7 @@ idling_threads: u32 = 0,
 active_threads: u32 = 0,
 timeout: u64,
 // used to serialize the acquisition order
-serial: std.DynamicBitSetUnmanaged align(std.atomic.cache_line),
+serial: std.DynamicBitSetUnmanaged,
 
 const RunQueue = std.SinglyLinkedList(Runnable);
 const Runnable = struct { runFn: RunProto };
@@ -133,6 +133,9 @@ pub fn spawn(self: *@This(), comptime func: anytype, args: anytype) SpawnError!void {
 }
 
 fn worker(self: *@This(), thread: *DynamicThread, id: u32, timeout: u64) void {
+    self.mutex.lock();
+    defer self.mutex.unlock();
+
     var timer = std.time.Timer.start() catch unreachable;
     main: while (thread.active) {
         // Serialize the acquisition order here so that threads will always pop the run queue in order
@@ -141,16 +144,14 @@ fn worker(self: *@This(), thread: *DynamicThread, id: u32, timeout: u64) void {
         // If a thread keeps getting outdone by the earlier threads, it will time out
         const can_work: bool = blk: {
             outer: while (id > 0 and thread.active) {
-                {
-                    self.mutex.lock();
-                    defer self.mutex.unlock();
-                    if (self.run_queue.first == null) {
-                        // We were outraced, go back to sleep
-                        break :blk false;
-                    }
+                if (self.run_queue.first == null) {
+                    // We were outraced, go back to sleep
+                    break :blk false;
                 }
                 if (timer.read() >= timeout) break :main;
                 for (0..id) |idx| if (!self.serial.isSet(idx)) {
+                    self.mutex.unlock();
+                    defer self.mutex.lock();
                     std.Thread.yield() catch {};
                     continue :outer;
                 };
@@ -163,15 +164,9 @@ fn worker(self: *@This(), thread: *DynamicThread, id: u32, timeout: u64) void {
             self.serial.set(id);
             defer self.serial.unset(id);
             while (thread.active) {
-                // Get the node
-                const node = blk: {
-                    self.mutex.lock();
-                    defer self.mutex.unlock();
-                    break :blk self.run_queue.popFirst();
-                };
-
-                // Do the work
-                if (node) |run_node| {
+                if (self.run_queue.popFirst()) |run_node| {
+                    self.mutex.unlock();
+                    defer self.mutex.lock();
                     const runFn = run_node.data.runFn;
                     runFn(self, &run_node.data);
                     timer.reset();
@@ -182,8 +177,6 @@ fn worker(self: *@This(), thread: *DynamicThread, id: u32, timeout: u64) void {
         if (thread.active) {
             const now = timer.read();
             if (now >= timeout) break :main;
-            self.mutex.lock();
-            defer self.mutex.unlock();
             if (self.run_queue.first == null) {
                 self.idling_threads += 1;
                 defer self.idling_threads -= 1;
@@ -192,8 +185,6 @@ fn worker(self: *@This(), thread: *DynamicThread, id: u32, timeout: u64) void {
         }
     }
 
-    self.mutex.lock();
-    defer self.mutex.unlock();
     self.active_threads -= 1;
 
     // This thread won't participate in the acquisition order anymore
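The worker now inverts its locking discipline: instead of taking the mutex around each individual access, it holds the lock by default for the whole loop and drops it only around blocking work (running a task, yielding to other threads), so every touch of run_queue, serial, idling_threads, and active_threads is serialized. That is presumably also why serial loses its align(std.atomic.cache_line): it is no longer read outside the mutex. A reduced sketch of the lock-by-default idiom; Pool, process, and the condition variable here are illustrative stand-ins, not this file's API:

const std = @import("std");

const Pool = struct {
    mutex: std.Thread.Mutex = .{},
    cond: std.Thread.Condition = .{},
    run_queue: std.SinglyLinkedList(u32) = .{},
    active: bool = true,
};

fn process(node: *std.SinglyLinkedList(u32).Node) void {
    std.debug.print("task {}\n", .{node.data});
}

fn workerLoop(self: *Pool) void {
    self.mutex.lock();
    defer self.mutex.unlock();
    while (self.active) {
        if (self.run_queue.popFirst()) |node| {
            // Drop the lock only while running user code; the defer
            // reacquires it before the next queue access.
            self.mutex.unlock();
            defer self.mutex.lock();
            process(node);
        } else {
            // wait() atomically releases the mutex while sleeping and
            // reacquires it before returning.
            self.cond.wait(&self.mutex);
        }
    }
}

The real worker layers its timeout handling and the serialized acquisition order shown in the diff on top of this basic shape.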
