From 9f036905319d0628746ee47b556a1a65b0eca109 Mon Sep 17 00:00:00 2001
From: Jari Vetoniemi
Date: Wed, 3 Jul 2024 14:09:46 +0900
Subject: [PATCH] coro/aio: get rid of data races

---
 build.zig                         |  2 ++
 examples/coro_wttr.zig            | 16 +++++++--------
 src/coro/ThreadPool.zig           |  8 ++++----
 src/minilib/DynamicThreadPool.zig | 33 +++++++++++--------------------
 4 files changed, 26 insertions(+), 33 deletions(-)

diff --git a/build.zig b/build.zig
index c8aefb8..55e15af 100644
--- a/build.zig
+++ b/build.zig
@@ -59,6 +59,7 @@ pub fn build(b: *std.Build) void {
             .root_source_file = b.path("examples/" ++ @tagName(example) ++ ".zig"),
             .target = target,
             .optimize = optimize,
+            .sanitize_thread = true,
         });
         exe.root_module.addImport("aio", aio);
         exe.root_module.addImport("coro", coro);
@@ -80,6 +81,7 @@
             .filters = &.{test_filter},
             .link_libc = aio.link_libc,
             .single_threaded = aio.single_threaded,
+            .sanitize_thread = true,
         });
         if (mod != .minilib) tst.root_module.addImport("minilib", minilib);
         if (mod == .aio) tst.root_module.addImport("build_options", opts.createModule());
diff --git a/examples/coro_wttr.zig b/examples/coro_wttr.zig
index 38002c2..3759b1e 100644
--- a/examples/coro_wttr.zig
+++ b/examples/coro_wttr.zig
@@ -6,8 +6,8 @@
 const log = std.log.scoped(.coro_wttr);
 
 // Just for fun, try returning a error from one of these tasks
-fn getWeather(completed: *u32, allocator: std.mem.Allocator, city: []const u8, lang: []const u8) anyerror![]const u8 {
-    defer completed.* += 1;
+fn getWeather(completed: *std.atomic.Value(u32), allocator: std.mem.Allocator, city: []const u8, lang: []const u8) anyerror![]const u8 {
+    defer _ = completed.fetchAdd(1, .monotonic);
     var url: std.BoundedArray(u8, 256) = .{};
     if (builtin.target.os.tag == .windows) {
         try url.writer().print("https://wttr.in/{s}?AFT&lang={s}", .{ city, lang });
@@ -24,8 +24,8 @@ fn getWeather(completed: *u32, allocator: std.mem.Allocator, city: []const u8, l
     return body.toOwnedSlice();
 }
 
-fn getLatestZig(completed: *u32, allocator: std.mem.Allocator) anyerror![]const u8 {
-    defer completed.* += 1;
+fn getLatestZig(completed: *std.atomic.Value(u32), allocator: std.mem.Allocator) anyerror![]const u8 {
+    defer _ = completed.fetchAdd(1, .monotonic);
     var body = std.ArrayList(u8).init(allocator);
     defer body.deinit();
     var client: std.http.Client = .{ .allocator = allocator };
@@ -42,7 +42,7 @@ fn getLatestZig(completed: *u32, allocator: std.mem.Allocator) anyerror![]const
     return allocator.dupe(u8, parsed.value.master.version);
 }
 
-fn loader(completed: *u32, max: *const u32) !void {
+fn loader(completed: *std.atomic.Value(u32), max: *const u32) !void {
     const frames: []const []const u8 = &.{
         "▰▱▱▱▱▱▱",
         "▰▰▱▱▱▱▱",
@@ -58,7 +58,7 @@ fn loader(completed: *u32, max: *const u32) !void {
 
     var idx: usize = 0;
     while (true) : (idx +%= 1) {
         try coro.io.single(aio.Timeout{ .ns = 80 * std.time.ns_per_ms });
-        std.debug.print(" {s} {}/{} loading that juicy info\r", .{ frames[idx % frames.len], completed.*, max.* });
+        std.debug.print(" {s} {}/{} loading that juicy info\r", .{ frames[idx % frames.len], completed.load(.acquire), max.* });
     }
 }
@@ -76,7 +76,7 @@ pub fn main() !void {
     defer scheduler.deinit();
 
     var max: u32 = 0;
-    var completed: u32 = 0;
+    var completed = std.atomic.Value(u32).init(0);
     const ltask = try scheduler.spawn(loader, .{ &completed, &max }, .{});
 
     var tpool: coro.ThreadPool = try coro.ThreadPool.init(gpa.allocator(), .{});
@@ -91,7 +91,7 @@ pub fn main() !void {
     try tasks.append(try tpool.spawnForCompletition(&scheduler, getLatestZig, .{ &completed, allocator }, .{}));
 
     max = @intCast(tasks.items.len);
-    while (completed < tasks.items.len) {
+    while (completed.load(.acquire) < tasks.items.len) {
         _ = try scheduler.tick(.blocking);
     }
 
diff --git a/src/coro/ThreadPool.zig b/src/coro/ThreadPool.zig
index 9f29a47..3daeade 100644
--- a/src/coro/ThreadPool.zig
+++ b/src/coro/ThreadPool.zig
@@ -33,14 +33,14 @@ pub const CancellationToken = struct {
     canceled: bool = false,
 };
 
-inline fn entrypoint(self: *@This(), completed: *bool, token: *CancellationToken, comptime func: anytype, res: anytype, args: anytype) void {
+inline fn entrypoint(self: *@This(), completed: *std.atomic.Value(bool), token: *CancellationToken, comptime func: anytype, res: anytype, args: anytype) void {
     const fun_info = @typeInfo(@TypeOf(func)).Fn;
     if (fun_info.params.len > 0 and fun_info.params[0].type.? == *const CancellationToken) {
         res.* = @call(.auto, func, .{token} ++ args);
     } else {
         res.* = @call(.auto, func, args);
     }
-    completed.* = true;
+    completed.store(true, .release);
     const n = self.num_tasks.load(.acquire);
     for (0..n) |_| self.source.notify();
 }
@@ -49,13 +49,13 @@ pub const YieldError = DynamicThreadPool.SpawnError;
 
 /// Yield until `func` finishes on another thread
 pub fn yieldForCompletition(self: *@This(), func: anytype, args: anytype) ReturnTypeMixedWithErrorSet(func, YieldError) {
-    var completed: bool = false;
+    var completed = std.atomic.Value(bool).init(false);
     var res: ReturnType(func) = undefined;
     _ = self.num_tasks.fetchAdd(1, .monotonic);
     defer _ = self.num_tasks.fetchSub(1, .release);
     var token: CancellationToken = .{};
     try self.pool.spawn(entrypoint, .{ self, &completed, &token, func, &res, args });
-    while (!completed) {
+    while (!completed.load(.acquire)) {
         const nerr = io.do(.{
             aio.WaitEventSource{ .source = &self.source, .link = .soft },
         }, if (token.canceled) .io_cancel else .io) catch 1;
diff --git a/src/minilib/DynamicThreadPool.zig b/src/minilib/DynamicThreadPool.zig
index 2508216..600eda4 100644
--- a/src/minilib/DynamicThreadPool.zig
+++ b/src/minilib/DynamicThreadPool.zig
@@ -20,7 +20,7 @@ idling_threads: u32 = 0,
 active_threads: u32 = 0,
 timeout: u64,
 // used to serialize the acquisition order
-serial: std.DynamicBitSetUnmanaged align(std.atomic.cache_line),
+serial: std.DynamicBitSetUnmanaged,
 
 const RunQueue = std.SinglyLinkedList(Runnable);
 const Runnable = struct { runFn: RunProto };
@@ -133,6 +133,9 @@ pub fn spawn(self: *@This(), comptime func: anytype, args: anytype) SpawnError!v
 }
 
 fn worker(self: *@This(), thread: *DynamicThread, id: u32, timeout: u64) void {
+    self.mutex.lock();
+    defer self.mutex.unlock();
+
     var timer = std.time.Timer.start() catch unreachable;
     main: while (thread.active) {
         // Serialize the acquisition order here so that threads will always pop the run queue in order
@@ -141,16 +144,14 @@ fn worker(self: *@This(), thread: *DynamicThread, id: u32, timeout: u64) void {
         // If a thread keeps getting out done by the earlier threads, it will time out
         const can_work: bool = blk: {
             outer: while (id > 0 and thread.active) {
-                {
-                    self.mutex.lock();
-                    defer self.mutex.unlock();
-                    if (self.run_queue.first == null) {
-                        // We were outraced, go back to sleep
-                        break :blk false;
-                    }
+                if (self.run_queue.first == null) {
+                    // We were outraced, go back to sleep
+                    break :blk false;
                 }
                 if (timer.read() >= timeout) break :main;
                 for (0..id) |idx| if (!self.serial.isSet(idx)) {
+                    self.mutex.unlock();
+                    defer self.mutex.lock();
                     std.Thread.yield() catch {};
                     continue :outer;
                 };
@@ -163,15 +164,9 @@ fn worker(self: *@This(), thread: *DynamicThread, id: u32, timeout: u64) void {
             self.serial.set(id);
             defer self.serial.unset(id);
             while (thread.active) {
-                // Get the node
-                const node = blk: {
-                    self.mutex.lock();
-                    defer self.mutex.unlock();
-                    break :blk self.run_queue.popFirst();
-                };
-
-                // Do the work
-                if (node) |run_node| {
+                if (self.run_queue.popFirst()) |run_node| {
+                    self.mutex.unlock();
+                    defer self.mutex.lock();
                     const runFn = run_node.data.runFn;
                     runFn(self, &run_node.data);
                     timer.reset();
@@ -182,8 +177,6 @@ fn worker(self: *@This(), thread: *DynamicThread, id: u32, timeout: u64) void {
         if (thread.active) {
            const now = timer.read();
            if (now >= timeout) break :main;
-           self.mutex.lock();
-           defer self.mutex.unlock();
            if (self.run_queue.first == null) {
                self.idling_threads += 1;
                defer self.idling_threads -= 1;
@@ -192,8 +185,6 @@ fn worker(self: *@This(), thread: *DynamicThread, id: u32, timeout: u64) void {
        }
    }
 
-    self.mutex.lock();
-    defer self.mutex.unlock();
    self.active_threads -= 1;
 
    // This thread won't partipicate in the acquisition order anymore
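
Notes (illustrative sketches, not part of the diff; all names in them are made up unless they appear in the patch):

The ThreadPool fix hinges on a release/acquire pairing: the worker thread writes `res` with a plain store and then publishes it with a release store to `completed`; any thread that observes `true` through an acquire load is guaranteed to also see the write to `res`. A minimal standalone sketch of that handoff:

const std = @import("std");

fn work(completed: *std.atomic.Value(bool), res: *u32) void {
    res.* = 42; // plain write, published by the release store below
    completed.store(true, .release);
}

pub fn main() !void {
    var completed = std.atomic.Value(bool).init(false);
    var res: u32 = undefined;
    const thread = try std.Thread.spawn(.{}, work, .{ &completed, &res });
    defer thread.join();
    // .acquire pairs with the .release store above: once this reads
    // true, the write to `res` is visible to this thread as well
    while (!completed.load(.acquire)) std.Thread.yield() catch {};
    std.debug.print("res = {}\n", .{res});
}

The real yieldForCompletition parks on an aio.WaitEventSource instead of spinning, but the memory-ordering contract is the same; likewise the example's `completed` counter uses a monotonic fetchAdd because only the count itself, not any associated data, is communicated.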
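The DynamicThreadPool change inverts the locking discipline: `worker` now takes `self.mutex` once at entry and holds it for the whole loop, releasing it only around operations that block or run user code. With the bitset only ever touched under the mutex, the `align(std.atomic.cache_line)` on `serial` no longer buys anything, which is why the patch drops it. The unlock-then-defer-relock idiom, in a standalone sketch (the `Queue` type here is invented for illustration):

const std = @import("std");

const Queue = struct {
    mutex: std.Thread.Mutex = .{},
    list: std.SinglyLinkedList(u32) = .{},

    fn drain(self: *@This()) void {
        self.mutex.lock();
        defer self.mutex.unlock();
        // The lock is held while the queue is inspected and popped...
        while (self.list.popFirst()) |node| {
            // ...and dropped only while the item is processed, so other
            // threads may touch the queue meanwhile. The defer re-locks
            // at the end of each loop iteration, before the next pop.
            self.mutex.unlock();
            defer self.mutex.lock();
            std.debug.print("processing {}\n", .{node.data});
        }
    }
};

pub fn main() void {
    var q: Queue = .{};
    var nodes = [_]std.SinglyLinkedList(u32).Node{ .{ .data = 1 }, .{ .data = 2 } };
    for (&nodes) |*node| q.list.prepend(node);
    q.drain();
}

This mirrors the patched worker: popFirst happens under the mutex, while the task's runFn (and std.Thread.yield in the acquisition-order loop) runs outside it.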
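Finally, `sanitize_thread = true` enables ThreadSanitizer for the examples and tests, so a reintroduced race fails loudly at runtime rather than silently. In a generic build script the option looks like this (artifact name and paths are placeholders):

const std = @import("std");

pub fn build(b: *std.Build) void {
    const exe = b.addExecutable(.{
        .name = "demo", // placeholder artifact name
        .root_source_file = b.path("src/main.zig"),
        .target = b.standardTargetOptions(.{}),
        .optimize = b.standardOptimizeOption(.{}),
        // instrument loads and stores so TSan can report unsynchronized access
        .sanitize_thread = true,
    });
    b.installArtifact(exe);
}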