diff --git a/src/TestHTTPServer.zig b/src/TestHTTPServer.zig index 9750a396..19c5bbc9 100644 --- a/src/TestHTTPServer.zig +++ b/src/TestHTTPServer.zig @@ -20,7 +20,7 @@ const std = @import("std"); const TestHTTPServer = @This(); -shutdown: bool, +shutdown: std.atomic.Value(bool), listener: ?std.net.Server, handler: Handler, @@ -28,16 +28,22 @@ const Handler = *const fn (req: *std.http.Server.Request) anyerror!void; pub fn init(handler: Handler) TestHTTPServer { return .{ - .shutdown = true, + .shutdown = .init(true), .listener = null, .handler = handler, }; } pub fn deinit(self: *TestHTTPServer) void { - self.shutdown = true; + self.listener = null; +} + +pub fn stop(self: *TestHTTPServer) void { + self.shutdown.store(true, .release); if (self.listener) |*listener| { - listener.deinit(); + // Use shutdown to unblock accept(). On Linux this causes accept to + // return error.SocketNotListening. close() alone doesn't interrupt accept(). + std.posix.shutdown(listener.stream.handle, .recv) catch {}; } } @@ -46,12 +52,13 @@ pub fn run(self: *TestHTTPServer, wg: *std.Thread.WaitGroup) !void { self.listener = try address.listen(.{ .reuse_address = true }); var listener = &self.listener.?; + self.shutdown.store(false, .release); wg.finish(); while (true) { const conn = listener.accept() catch |err| { - if (self.shutdown) { + if (self.shutdown.load(.acquire) or err == error.SocketNotListening) { return; } return err; diff --git a/src/test_runner.zig b/src/test_runner.zig index 2461c82f..1b9a2324 100644 --- a/src/test_runner.zig +++ b/src/test_runner.zig @@ -442,6 +442,7 @@ pub const TrackingAllocator = struct { allocated_bytes: usize = 0, allocation_count: usize = 0, reallocation_count: usize = 0, + mutex: std.Thread.Mutex = .{}, const Stats = struct { allocated_bytes: usize, @@ -479,6 +480,9 @@ pub const TrackingAllocator = struct { return_address: usize, ) ?[*]u8 { const self: *TrackingAllocator = @ptrCast(@alignCast(ctx)); + self.mutex.lock(); + defer self.mutex.unlock(); + const result = self.parent_allocator.rawAlloc(len, alignment, return_address); self.allocation_count += 1; self.allocated_bytes += len; @@ -493,6 +497,9 @@ pub const TrackingAllocator = struct { ra: usize, ) bool { const self: *TrackingAllocator = @ptrCast(@alignCast(ctx)); + self.mutex.lock(); + defer self.mutex.unlock(); + const result = self.parent_allocator.rawResize(old_mem, alignment, new_len, ra); self.reallocation_count += 1; // TODO: only if result is not null? return result; @@ -505,6 +512,9 @@ pub const TrackingAllocator = struct { ra: usize, ) void { const self: *TrackingAllocator = @ptrCast(@alignCast(ctx)); + self.mutex.lock(); + defer self.mutex.unlock(); + self.parent_allocator.rawFree(old_mem, alignment, ra); self.free_count += 1; } @@ -517,6 +527,9 @@ pub const TrackingAllocator = struct { ret_addr: usize, ) ?[*]u8 { const self: *TrackingAllocator = @ptrCast(@alignCast(ctx)); + self.mutex.lock(); + defer self.mutex.unlock(); + const result = self.parent_allocator.rawRemap(memory, alignment, new_len, ret_addr); self.reallocation_count += 1; // TODO: only if result is not null? return result; diff --git a/src/testing.zig b/src/testing.zig index 05c7efbc..41f67e96 100644 --- a/src/testing.zig +++ b/src/testing.zig @@ -450,7 +450,9 @@ const TestHTTPServer = @import("TestHTTPServer.zig"); const Server = @import("Server.zig"); var test_cdp_server: ?Server = null; +var test_cdp_server_thread: ?std.Thread = null; var test_http_server: ?TestHTTPServer = null; +var test_http_server_thread: ?std.Thread = null; var test_config: Config = undefined; @@ -482,16 +484,10 @@ test "tests:beforeAll" { var wg: std.Thread.WaitGroup = .{}; wg.startMany(2); - { - const thread = try std.Thread.spawn(.{}, serveCDP, .{&wg}); - thread.detach(); - } + test_cdp_server_thread = try std.Thread.spawn(.{}, serveCDP, .{&wg}); test_http_server = TestHTTPServer.init(testHTTPHandler); - { - const thread = try std.Thread.spawn(.{}, TestHTTPServer.run, .{ &test_http_server.?, &wg }); - thread.detach(); - } + test_http_server_thread = try std.Thread.spawn(.{}, TestHTTPServer.run, .{ &test_http_server.?, &wg }); // need to wait for the servers to be listening, else tests will fail because // they aren't able to connect. @@ -499,9 +495,22 @@ test "tests:beforeAll" { } test "tests:afterAll" { + if (test_cdp_server) |*server| { + server.stop(); + } + if (test_cdp_server_thread) |thread| { + thread.join(); + } if (test_cdp_server) |*server| { server.deinit(); } + + if (test_http_server) |*server| { + server.stop(); + } + if (test_http_server_thread) |thread| { + thread.join(); + } if (test_http_server) |*server| { server.deinit(); } @@ -518,8 +527,6 @@ fn serveCDP(wg: *std.Thread.WaitGroup) !void { const address = try std.net.Address.parseIp("127.0.0.1", 9583); test_cdp_server = try Server.init(test_app, address); - var server = try Server.init(test_app, address); - defer server.deinit(); wg.finish(); test_cdp_server.?.run(address, 5) catch |err| {