Fix race condition

This commit is contained in:
Nikolay Govorov
2026-02-04 13:48:07 +00:00
parent 7672b42fbc
commit 6ccd3f277b
3 changed files with 42 additions and 15 deletions

View File

@@ -20,7 +20,7 @@ const std = @import("std");
const TestHTTPServer = @This(); const TestHTTPServer = @This();
shutdown: bool, shutdown: std.atomic.Value(bool),
listener: ?std.net.Server, listener: ?std.net.Server,
handler: Handler, handler: Handler,
@@ -28,16 +28,22 @@ const Handler = *const fn (req: *std.http.Server.Request) anyerror!void;
pub fn init(handler: Handler) TestHTTPServer { pub fn init(handler: Handler) TestHTTPServer {
return .{ return .{
.shutdown = true, .shutdown = .init(true),
.listener = null, .listener = null,
.handler = handler, .handler = handler,
}; };
} }
pub fn deinit(self: *TestHTTPServer) void { pub fn deinit(self: *TestHTTPServer) void {
self.shutdown = true; self.listener = null;
}
pub fn stop(self: *TestHTTPServer) void {
self.shutdown.store(true, .release);
if (self.listener) |*listener| { if (self.listener) |*listener| {
listener.deinit(); // Use shutdown to unblock accept(). On Linux this causes accept to
// return error.SocketNotListening. close() alone doesn't interrupt accept().
std.posix.shutdown(listener.stream.handle, .recv) catch {};
} }
} }
@@ -46,12 +52,13 @@ pub fn run(self: *TestHTTPServer, wg: *std.Thread.WaitGroup) !void {
self.listener = try address.listen(.{ .reuse_address = true }); self.listener = try address.listen(.{ .reuse_address = true });
var listener = &self.listener.?; var listener = &self.listener.?;
self.shutdown.store(false, .release);
wg.finish(); wg.finish();
while (true) { while (true) {
const conn = listener.accept() catch |err| { const conn = listener.accept() catch |err| {
if (self.shutdown) { if (self.shutdown.load(.acquire) or err == error.SocketNotListening) {
return; return;
} }
return err; return err;

View File

@@ -442,6 +442,7 @@ pub const TrackingAllocator = struct {
allocated_bytes: usize = 0, allocated_bytes: usize = 0,
allocation_count: usize = 0, allocation_count: usize = 0,
reallocation_count: usize = 0, reallocation_count: usize = 0,
mutex: std.Thread.Mutex = .{},
const Stats = struct { const Stats = struct {
allocated_bytes: usize, allocated_bytes: usize,
@@ -479,6 +480,9 @@ pub const TrackingAllocator = struct {
return_address: usize, return_address: usize,
) ?[*]u8 { ) ?[*]u8 {
const self: *TrackingAllocator = @ptrCast(@alignCast(ctx)); const self: *TrackingAllocator = @ptrCast(@alignCast(ctx));
self.mutex.lock();
defer self.mutex.unlock();
const result = self.parent_allocator.rawAlloc(len, alignment, return_address); const result = self.parent_allocator.rawAlloc(len, alignment, return_address);
self.allocation_count += 1; self.allocation_count += 1;
self.allocated_bytes += len; self.allocated_bytes += len;
@@ -493,6 +497,9 @@ pub const TrackingAllocator = struct {
ra: usize, ra: usize,
) bool { ) bool {
const self: *TrackingAllocator = @ptrCast(@alignCast(ctx)); const self: *TrackingAllocator = @ptrCast(@alignCast(ctx));
self.mutex.lock();
defer self.mutex.unlock();
const result = self.parent_allocator.rawResize(old_mem, alignment, new_len, ra); const result = self.parent_allocator.rawResize(old_mem, alignment, new_len, ra);
self.reallocation_count += 1; // TODO: only if result is not null? self.reallocation_count += 1; // TODO: only if result is not null?
return result; return result;
@@ -505,6 +512,9 @@ pub const TrackingAllocator = struct {
ra: usize, ra: usize,
) void { ) void {
const self: *TrackingAllocator = @ptrCast(@alignCast(ctx)); const self: *TrackingAllocator = @ptrCast(@alignCast(ctx));
self.mutex.lock();
defer self.mutex.unlock();
self.parent_allocator.rawFree(old_mem, alignment, ra); self.parent_allocator.rawFree(old_mem, alignment, ra);
self.free_count += 1; self.free_count += 1;
} }
@@ -517,6 +527,9 @@ pub const TrackingAllocator = struct {
ret_addr: usize, ret_addr: usize,
) ?[*]u8 { ) ?[*]u8 {
const self: *TrackingAllocator = @ptrCast(@alignCast(ctx)); const self: *TrackingAllocator = @ptrCast(@alignCast(ctx));
self.mutex.lock();
defer self.mutex.unlock();
const result = self.parent_allocator.rawRemap(memory, alignment, new_len, ret_addr); const result = self.parent_allocator.rawRemap(memory, alignment, new_len, ret_addr);
self.reallocation_count += 1; // TODO: only if result is not null? self.reallocation_count += 1; // TODO: only if result is not null?
return result; return result;

View File

@@ -450,7 +450,9 @@ const TestHTTPServer = @import("TestHTTPServer.zig");
const Server = @import("Server.zig"); const Server = @import("Server.zig");
var test_cdp_server: ?Server = null; var test_cdp_server: ?Server = null;
var test_cdp_server_thread: ?std.Thread = null;
var test_http_server: ?TestHTTPServer = null; var test_http_server: ?TestHTTPServer = null;
var test_http_server_thread: ?std.Thread = null;
var test_config: Config = undefined; var test_config: Config = undefined;
@@ -482,16 +484,10 @@ test "tests:beforeAll" {
var wg: std.Thread.WaitGroup = .{}; var wg: std.Thread.WaitGroup = .{};
wg.startMany(2); wg.startMany(2);
{ test_cdp_server_thread = try std.Thread.spawn(.{}, serveCDP, .{&wg});
const thread = try std.Thread.spawn(.{}, serveCDP, .{&wg});
thread.detach();
}
test_http_server = TestHTTPServer.init(testHTTPHandler); test_http_server = TestHTTPServer.init(testHTTPHandler);
{ test_http_server_thread = try std.Thread.spawn(.{}, TestHTTPServer.run, .{ &test_http_server.?, &wg });
const thread = try std.Thread.spawn(.{}, TestHTTPServer.run, .{ &test_http_server.?, &wg });
thread.detach();
}
// need to wait for the servers to be listening, else tests will fail because // need to wait for the servers to be listening, else tests will fail because
// they aren't able to connect. // they aren't able to connect.
@@ -499,9 +495,22 @@ test "tests:beforeAll" {
} }
test "tests:afterAll" { test "tests:afterAll" {
if (test_cdp_server) |*server| {
server.stop();
}
if (test_cdp_server_thread) |thread| {
thread.join();
}
if (test_cdp_server) |*server| { if (test_cdp_server) |*server| {
server.deinit(); server.deinit();
} }
if (test_http_server) |*server| {
server.stop();
}
if (test_http_server_thread) |thread| {
thread.join();
}
if (test_http_server) |*server| { if (test_http_server) |*server| {
server.deinit(); server.deinit();
} }
@@ -518,8 +527,6 @@ fn serveCDP(wg: *std.Thread.WaitGroup) !void {
const address = try std.net.Address.parseIp("127.0.0.1", 9583); const address = try std.net.Address.parseIp("127.0.0.1", 9583);
test_cdp_server = try Server.init(test_app, address); test_cdp_server = try Server.init(test_app, address);
var server = try Server.init(test_app, address);
defer server.deinit();
wg.finish(); wg.finish();
test_cdp_server.?.run(address, 5) catch |err| { test_cdp_server.?.run(address, 5) catch |err| {