Fix server hang on client disconnect

https://github.com/lightpanda-io/browser/issues/425

Add a few integration tests for the TCP server which are fast enough to be run
as part of the unit tests (one of the new tests covers the above issue).
This commit is contained in:
Karl Seguin
2025-02-19 15:01:12 +08:00
parent 73a2fa3f9c
commit 39a9efb73b
4 changed files with 163 additions and 27 deletions

View File

@@ -242,7 +242,12 @@ pub const Page = struct {
// add global objects // add global objects
log.debug("setup global env", .{}); log.debug("setup global env", .{});
try self.session.env.bindGlobal(&self.session.window);
if (comptime builtin.is_test == false) {
// By not loading this during tests, we aren't required to load
// all of the interfaces into zig-js-runtime.
try self.session.env.bindGlobal(&self.session.window);
}
// load polyfills // load polyfills
try polyfill.load(self.arena.allocator(), self.session.env); try polyfill.load(self.arena.allocator(), self.session.env);

View File

@@ -337,12 +337,6 @@ test {
std.testing.refAllDecls(@import("generate.zig")); std.testing.refAllDecls(@import("generate.zig"));
std.testing.refAllDecls(@import("cdp/msg.zig")); std.testing.refAllDecls(@import("cdp/msg.zig"));
// Don't use refAllDecls, as this will pull in the entire project
// and break the test build.
// We should fix this. See this branch & the commit message for details:
// https://github.com/karlseguin/browser/commit/193ab5ceab3d3758ea06db04f7690460d79eb79e
_ = @import("server.zig");
} }
fn testJSRuntime(alloc: std.mem.Allocator) !void { fn testJSRuntime(alloc: std.mem.Allocator) !void {

View File

@@ -211,6 +211,13 @@ const Server = struct {
self.queueClose(client.socket); self.queueClose(client.socket);
return; return;
}; };
if (size == 0) {
if (self.client != null) {
self.client = null;
}
self.queueAccept();
return;
}
const more = client.processData(size) catch |err| { const more = client.processData(size) catch |err| {
log.err("Client Processing Error: {}\n", .{err}); log.err("Client Processing Error: {}\n", .{err});
@@ -1053,14 +1060,6 @@ pub fn run(
timeout: u64, timeout: u64,
loop: *jsruntime.Loop, loop: *jsruntime.Loop,
) !void { ) !void {
if (comptime builtin.is_test) {
// There's bunch of code that won't compiler in a test build (because
// it relies on a global root.Types). So we fight the compiler and make
// sure it doesn't include any of that code. Hopefully one day we can
// remove all this.
return;
}
// create socket // create socket
const flags = posix.SOCK.STREAM | posix.SOCK.CLOEXEC | posix.SOCK.NONBLOCK; const flags = posix.SOCK.STREAM | posix.SOCK.CLOEXEC | posix.SOCK.NONBLOCK;
const listener = try posix.socket(address.any.family, flags, posix.IPPROTO.TCP); const listener = try posix.socket(address.any.family, flags, posix.IPPROTO.TCP);
@@ -1631,6 +1630,49 @@ test "server: mask" {
} }
} }
test "server: 404" {
var c = try createTestClient();
defer c.deinit();
const res = try c.httpRequest("GET /unknown HTTP/1.1\r\n\r\n");
try testing.expectEqualStrings("HTTP/1.1 404 \r\n" ++
"Connection: Close\r\n" ++
"Content-Length: 9\r\n\r\n" ++
"Not found", res);
}
test "server: get /json/version" {
const expected_response =
"HTTP/1.1 200 OK\r\n" ++
"Content-Length: 48\r\n" ++
"Content-Type: application/json; charset=UTF-8\r\n\r\n" ++
"{\"webSocketDebuggerUrl\": \"ws://127.0.0.1:9583/\"}";
{
// twice on the same connection
var c = try createTestClient();
defer c.deinit();
const res1 = try c.httpRequest("GET /json/version HTTP/1.1\r\n\r\n");
try testing.expectEqualStrings(expected_response, res1);
const res2 = try c.httpRequest("GET /json/version HTTP/1.1\r\n\r\n");
try testing.expectEqualStrings(expected_response, res2);
}
{
// again on a new connection
var c = try createTestClient();
defer c.deinit();
const res1 = try c.httpRequest("GET /json/version HTTP/1.1\r\n\r\n");
try testing.expectEqualStrings(expected_response, res1);
const res2 = try c.httpRequest("GET /json/version HTTP/1.1\r\n\r\n");
try testing.expectEqualStrings(expected_response, res2);
}
}
fn assertHTTPError( fn assertHTTPError(
expected_error: HTTPError, expected_error: HTTPError,
comptime expected_status: u16, comptime expected_status: u16,
@@ -1762,3 +1804,63 @@ const MockServer = struct {
} }
} }
}; };
fn createTestClient() !TestClient {
const address = std.net.Address.initIp4([_]u8{ 127, 0, 0, 1 }, 9583);
const stream = try std.net.tcpConnectToAddress(address);
const timeout = std.mem.toBytes(posix.timeval{
.tv_sec = 2,
.tv_usec = 0,
});
try posix.setsockopt(stream.handle, posix.SOL.SOCKET, posix.SO.RCVTIMEO, &timeout);
try posix.setsockopt(stream.handle, posix.SOL.SOCKET, posix.SO.SNDTIMEO, &timeout);
return .{ .stream = stream };
}
const TestClient = struct {
stream: std.net.Stream,
buf: [1024]u8 = undefined,
fn deinit(self: *TestClient) void {
self.stream.close();
}
fn httpRequest(self: *TestClient, req: []const u8) ![]const u8 {
try self.stream.writeAll(req);
var pos: usize = 0;
var total_length: ?usize = null;
while (true) {
pos += try self.stream.read(self.buf[pos..]);
const response = self.buf[0..pos];
if (total_length == null) {
const header_end = std.mem.indexOf(u8, response, "\r\n\r\n") orelse continue;
const header = response[0 .. header_end + 4];
const cl_header = "Content-Length: ";
const start = (std.mem.indexOf(u8, header, cl_header) orelse {
return error.MissingContentLength;
}) + cl_header.len;
const end = std.mem.indexOfScalarPos(u8, header, start, '\r') orelse {
return error.InvalidContentLength;
};
const cl = std.fmt.parseInt(usize, header[start..end], 10) catch {
return error.InvalidContentLength;
};
total_length = cl + header.len;
}
if (total_length) |tl| {
if (pos == tl) {
return response;
}
if (pos > tl) {
return error.DataExceedsContentLength;
}
}
}
}
};

View File

@@ -18,10 +18,17 @@
const std = @import("std"); const std = @import("std");
const builtin = @import("builtin"); const builtin = @import("builtin");
const parser = @import("netsurf");
const Allocator = std.mem.Allocator; const Allocator = std.mem.Allocator;
const jsruntime = @import("jsruntime");
pub const Types = jsruntime.reflect(@import("generate.zig").Tuple(.{}){});
pub const UserContext = @import("user_context.zig").UserContext;
// pub const IO = @import("asyncio").Wrapper(jsruntime.Loop);
pub const std_options = std.Options{ pub const std_options = std.Options{
.log_level = .err,
.http_disable_tls = true, .http_disable_tls = true,
}; };
@@ -31,11 +38,16 @@ const BORDER = "=" ** 80;
var current_test: ?[]const u8 = null; var current_test: ?[]const u8 = null;
pub fn main() !void { pub fn main() !void {
try parser.init();
defer parser.deinit();
var mem: [8192]u8 = undefined; var mem: [8192]u8 = undefined;
var fba = std.heap.FixedBufferAllocator.init(&mem); var fba = std.heap.FixedBufferAllocator.init(&mem);
const allocator = fba.allocator(); const allocator = fba.allocator();
var loop = try jsruntime.Loop.init(allocator);
defer loop.deinit();
const env = Env.init(allocator); const env = Env.init(allocator);
defer env.deinit(allocator); defer env.deinit(allocator);
@@ -47,12 +59,20 @@ pub fn main() !void {
var skip: usize = 0; var skip: usize = 0;
var leak: usize = 0; var leak: usize = 0;
const address = try std.net.Address.parseIp("127.0.0.1", 9582); const http_thread = blk: {
var listener = try address.listen(.{ .reuse_address = true }); const address = try std.net.Address.parseIp("127.0.0.1", 9582);
defer listener.deinit(); const thread = try std.Thread.spawn(.{}, serveHTTP, .{address});
const http_thread = try std.Thread.spawn(.{}, serverHTTP, .{&listener}); break :blk thread;
};
defer http_thread.join(); defer http_thread.join();
const cdp_thread = blk: {
const address = try std.net.Address.parseIp("127.0.0.1", 9583);
const thread = try std.Thread.spawn(.{}, serveCDP, .{ allocator, address, &loop });
break :blk thread;
};
defer cdp_thread.join();
const printer = Printer.init(); const printer = Printer.init();
printer.fmt("\r\x1b[0K", .{}); // beginning of line and clear to end of line printer.fmt("\r\x1b[0K", .{}); // beginning of line and clear to end of line
@@ -98,7 +118,9 @@ pub fn main() !void {
} }
if (result) |_| { if (result) |_| {
pass += 1; if (is_unnamed_test == false) {
pass += 1;
}
} else |err| switch (err) { } else |err| switch (err) {
error.SkipZigTest => { error.SkipZigTest => {
skip += 1; skip += 1;
@@ -117,11 +139,13 @@ pub fn main() !void {
}, },
} }
if (env.verbose) { if (is_unnamed_test == false) {
const ms = @as(f64, @floatFromInt(ns_taken)) / 1_000_000.0; if (env.verbose) {
printer.status(status, "{s} ({d:.2}ms)\n", .{ friendly_name, ms }); const ms = @as(f64, @floatFromInt(ns_taken)) / 1_000_000.0;
} else { printer.status(status, "{s} ({d:.2}ms)\n", .{ friendly_name, ms });
printer.status(status, ".", .{}); } else {
printer.status(status, ".", .{});
}
} }
} }
@@ -294,7 +318,10 @@ fn isUnnamed(t: std.builtin.TestFn) bool {
return true; return true;
} }
fn serverHTTP(listener: *std.net.Server) !void { fn serveHTTP(address: std.net.Address) !void {
var listener = try address.listen(.{ .reuse_address = true });
defer listener.deinit();
var read_buffer: [1024]u8 = undefined; var read_buffer: [1024]u8 = undefined;
ACCEPT: while (true) { ACCEPT: while (true) {
var conn = try listener.accept(); var conn = try listener.accept();
@@ -320,6 +347,14 @@ fn serverHTTP(listener: *std.net.Server) !void {
} }
} }
fn serveCDP(allocator: Allocator, address: std.net.Address, loop: *jsruntime.Loop) !void {
const server = @import("server.zig");
server.run(allocator, address, std.time.ns_per_s * 2, loop) catch |err| {
std.debug.print("CDP server error: {}", .{err});
return err;
};
}
const Response = struct { const Response = struct {
body: []const u8 = "", body: []const u8 = "",
status: std.http.Status = .ok, status: std.http.Status = .ok,