diff --git a/src/Config.zig b/src/Config.zig
index 6788db1d..624fc63b 100644
--- a/src/Config.zig
+++ b/src/Config.zig
@@ -128,6 +128,13 @@ pub fn httpMaxResponseSize(self: *const Config) ?usize {
};
}
+pub fn wsMaxConcurrent(self: *const Config) u8 {
+ return switch (self.mode) {
+ inline .serve, .fetch, .mcp => |opts| opts.common.ws_max_concurrent orelse 8,
+ else => unreachable,
+ };
+}
+
pub fn logLevel(self: *const Config) ?log.Level {
return switch (self.mode) {
inline .serve, .fetch, .mcp => |opts| opts.common.log_level,
@@ -275,6 +282,7 @@ pub const Common = struct {
http_timeout: ?u31 = null,
http_connect_timeout: ?u31 = null,
http_max_response_size: ?usize = null,
+ ws_max_concurrent: ?u8 = null,
tls_verify_host: bool = true,
log_level: ?log.Level = null,
log_format: ?log.Format = null,
@@ -375,6 +383,10 @@ pub fn printUsageAndExit(self: *const Config, success: bool) void {
\\ (e.g. XHR, fetch, script loading, ...).
\\ Defaults to no limit.
\\
+ \\--ws-max-concurrent
+ \\ The maximum number of concurrent WebSocket connections.
+ \\ Defaults to 8.
+ \\
\\--log-level The log level: debug, info, warn, error or fatal.
\\ Defaults to
++ (if (builtin.mode == .Debug) " info." else "warn.") ++
@@ -983,6 +995,19 @@ fn parseCommonArg(
return true;
}
+ if (std.mem.eql(u8, "--ws-max-concurrent", opt) or std.mem.eql(u8, "--ws_max_concurrent", opt)) {
+ const str = args.next() orelse {
+ log.fatal(.app, "missing argument value", .{ .arg = opt });
+ return error.InvalidArgument;
+ };
+
+ common.ws_max_concurrent = std.fmt.parseInt(u8, str, 10) catch |err| {
+ log.fatal(.app, "invalid argument value", .{ .arg = opt, .err = err });
+ return error.InvalidArgument;
+ };
+ return true;
+ }
+
if (std.mem.eql(u8, "--log-level", opt) or std.mem.eql(u8, "--log_level", opt)) {
const str = args.next() orelse {
log.fatal(.app, "missing argument value", .{ .arg = opt });
diff --git a/src/browser/tests/net/websocket.html b/src/browser/tests/net/websocket.html
index 8ad03a70..257c2136 100644
--- a/src/browser/tests/net/websocket.html
+++ b/src/browser/tests/net/websocket.html
@@ -238,3 +238,309 @@
});
}
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
diff --git a/src/browser/webapi/net/WebSocket.zig b/src/browser/webapi/net/WebSocket.zig
index 1a07dcf6..dc19728e 100644
--- a/src/browser/webapi/net/WebSocket.zig
+++ b/src/browser/webapi/net/WebSocket.zig
@@ -107,10 +107,7 @@ pub fn init(url: []const u8, protocols_: ?[]const u8, page: *Page) !*WebSocket {
const resolved_url = try URL.resolve(arena, page.base(), url, .{ .always_dupe = true, .encode = true });
const http_client = page._session.browser.http_client;
- const conn = http_client.network.getConnection() orelse {
- // TODO: figure out how/where we actually want to get WebSocket connections
- // from. I feel like sharing this with the HTTP Connection Pool is a
- // mistake.
+ const conn = http_client.network.newConnection() orelse {
return error.NoFreeConnection;
};
@@ -135,7 +132,7 @@ pub fn init(url: []const u8, protocols_: ?[]const u8, page: *Page) !*WebSocket {
try http_client.trackConn(conn);
if (comptime IS_DEBUG) {
- log.info(.http, "WS connecting", .{ .url = url });
+ log.info(.websocket, "connecting", .{ .url = url });
}
// Unlike an XHR object where we only selectively reference the instance
@@ -179,9 +176,9 @@ pub fn disconnected(self: *WebSocket, err_: ?anyerror) void {
self._ready_state = .closed;
if (err_) |err| {
- log.warn(.http, "WS disconnected", .{ .err = err, .url = self._url });
+ log.warn(.websocket, "disconnected", .{ .err = err, .url = self._url });
} else {
- log.info(.http, "WS disconnected", .{ .url = self._url, .reason = "closed" });
+ log.info(.websocket, "disconnected", .{ .url = self._url, .reason = "closed" });
}
self.cleanup();
@@ -191,7 +188,7 @@ pub fn disconnected(self: *WebSocket, err_: ?anyerror) void {
const reason = if (was_clean) self._close_reason else "";
self.dispatchCloseEvent(code, reason, was_clean) catch |err| {
- log.err(.http, "WS close event dispatch failed", .{ .err = err });
+ log.err(.websocket, "close event dispatch failed", .{ .err = err });
};
}
@@ -413,6 +410,7 @@ fn dispatchOpenEvent(self: *WebSocket) !void {
}
fn dispatchMessageEvent(self: *WebSocket, data: []const u8, frame_type: http.WsFrameType) !void {
+ std.debug.print("{any} {s}\n", .{ frame_type, data });
const page = self._page;
const target = self.asEventTarget();
@@ -450,7 +448,7 @@ fn sendDataCallback(buffer: [*]u8, buf_count: usize, buf_len: usize, data: *anyo
}
const conn: *http.Connection = @ptrCast(@alignCast(data));
return _sendDataCallback(conn, buffer[0..buf_len]) catch |err| {
- log.warn(.http, "WS send callback", .{ .err = err });
+ log.warn(.websocket, "send callback", .{ .err = err });
return http.readfunc_pause;
};
}
@@ -499,6 +497,9 @@ fn _sendDataCallback(conn: *http.Connection, buf: []u8) !usize {
fn writeContent(self: *WebSocket, conn: *http.Connection, buf: []u8, byte_msg: Message.Content, frame_type: http.WsFrameType) !usize {
if (self._send_offset == 0) {
// start of the message
+ if (comptime IS_DEBUG) {
+ log.debug(.websocket, "send start", .{ .url = self._url, .len = byte_msg.data.len });
+ }
try conn.wsStartFrame(frame_type, byte_msg.data.len);
}
@@ -511,6 +512,9 @@ fn writeContent(self: *WebSocket, conn: *http.Connection, buf: []u8, byte_msg: M
if (self._send_offset >= byte_msg.data.len) {
const removed = self._send_queue.orderedRemove(0);
removed.deinit(self._page._session);
+ if (comptime IS_DEBUG) {
+ log.debug(.websocket, "send complete", .{ .url = self._url, .len = byte_msg.data.len, .queue = self._send_queue.items.len });
+ }
self._send_offset = 0;
}
@@ -523,7 +527,7 @@ fn receivedDataCallback(buffer: [*]const u8, buf_count: usize, buf_len: usize, d
}
const conn: *http.Connection = @ptrCast(@alignCast(data));
_receivedDataCallback(conn, buffer[0..buf_len]) catch |err| {
- log.warn(.http, "WS receive callback", .{ .err = err });
+ log.warn(.websocket, "receive callback", .{ .err = err });
// TODO: are there errors, like an invalid frame, that we shouldn't treat
// as an error?
return http.writefunc_error;
@@ -535,11 +539,14 @@ fn receivedDataCallback(buffer: [*]const u8, buf_count: usize, buf_len: usize, d
fn _receivedDataCallback(conn: *http.Connection, data: []const u8) !void {
const self = conn.transport.websocket;
const meta = conn.wsMeta() orelse {
- log.err(.http, "WS missing meta", .{ .url = self._url });
+ log.err(.websocket, "missing meta", .{ .url = self._url });
return error.NoFrameMeta;
};
if (meta.offset == 0) {
+ if (comptime IS_DEBUG) {
+ log.debug(.websocket, "incoming message", .{ .url = self._url, .len = meta.len, .bytes_left = meta.bytes_left, .type = meta.frame_type });
+ }
// Start of new frame. Pre-allocate buffer
self._recv_buffer.clearRetainingCapacity();
if (meta.len > self._http_client.max_response_size) {
@@ -598,10 +605,10 @@ fn receivedHeaderCalllback(buffer: [*]const u8, header_count: usize, buf_len: us
}
self._ready_state = .open;
- log.info(.http, "WS connected", .{ .url = self._url });
+ log.info(.websocket, "connected", .{ .url = self._url });
self.dispatchOpenEvent() catch |err| {
- log.err(.http, "WS open event fail", .{ .err = err });
+ log.err(.websocket, "open event fail", .{ .err = err });
};
return buf_len;
}
diff --git a/src/log.zig b/src/log.zig
index 3e1016c5..84ff1049 100644
--- a/src/log.zig
+++ b/src/log.zig
@@ -40,6 +40,7 @@ pub const Scope = enum {
unknown_prop,
mcp,
cache,
+ websocket,
};
const Opts = struct {
diff --git a/src/network/Network.zig b/src/network/Network.zig
index ab11e5ce..1fb8c8fb 100644
--- a/src/network/Network.zig
+++ b/src/network/Network.zig
@@ -61,6 +61,11 @@ connections: []http.Connection,
available: std.DoublyLinkedList = .{},
conn_mutex: std.Thread.Mutex = .{},
+ws_pool: std.heap.MemoryPool(http.Connection),
+ws_count: usize = 0,
+ws_max: u8,
+ws_mutex: std.Thread.Mutex = .{},
+
pollfds: []posix.pollfd,
listener: ?Listener = null,
@@ -268,9 +273,13 @@ pub fn init(allocator: Allocator, app: *App, config: *const Config) !Network {
.connections = connections,
.app = app,
+
.robot_store = RobotStore.init(allocator),
.web_bot_auth = web_bot_auth,
.cache = cache,
+
+ .ws_pool = .init(allocator),
+ .ws_max = config.wsMaxConcurrent(),
};
}
@@ -298,6 +307,8 @@ pub fn deinit(self: *Network) void {
}
self.allocator.free(self.connections);
+ self.ws_pool.deinit();
+
self.robot_store.deinit();
if (self.web_bot_auth) |wba| {
wba.deinit(self.allocator);
@@ -592,18 +603,50 @@ pub fn getConnection(self: *Network) ?*http.Connection {
}
pub fn releaseConnection(self: *Network, conn: *http.Connection) void {
- conn.reset(self.config, self.ca_blob) catch |err| {
- lp.assert(false, "couldn't reset curl easy", .{ .err = err });
- };
-
- self.conn_mutex.lock();
- defer self.conn_mutex.unlock();
-
- self.available.append(&conn.node);
+ switch (conn.transport) {
+ .websocket => {
+ conn.deinit();
+ self.ws_mutex.lock();
+ defer self.ws_mutex.unlock();
+ self.ws_pool.destroy(conn);
+ self.ws_count -= 1;
+ },
+ else => {
+ conn.reset(self.config, self.ca_blob) catch |err| {
+ lp.assert(false, "couldn't reset curl easy", .{ .err = err });
+ };
+ self.conn_mutex.lock();
+ defer self.conn_mutex.unlock();
+ self.available.append(&conn.node);
+ },
+ }
}
-pub fn newConnection(self: *Network) !http.Connection {
- return http.Connection.init(self.ca_blob, self.config);
+pub fn newConnection(self: *Network) ?*http.Connection {
+ const conn = blk: {
+ self.ws_mutex.lock();
+ defer self.ws_mutex.unlock();
+
+ if (self.ws_count >= self.ws_max) {
+ return null;
+ }
+
+ const c = self.ws_pool.create() catch return null;
+ self.ws_count += 1;
+ break :blk c;
+ };
+
+ // don't do this under lock
+ conn.* = http.Connection.init(self.ca_blob, self.config) catch {
+ self.ws_mutex.lock();
+ defer self.ws_mutex.unlock();
+ self.ws_pool.destroy(conn);
+ self.ws_count -= 1;
+
+ return null;
+ };
+
+ return conn;
}
// Wraps lines @ 64 columns. A PEM is basically a base64 encoded DER (which is
diff --git a/src/testing.zig b/src/testing.zig
index 8dd2eb88..44abfea8 100644
--- a/src/testing.zig
+++ b/src/testing.zig
@@ -496,6 +496,7 @@ test "tests:beforeAll" {
.common = .{
.tls_verify_host = false,
.user_agent_suffix = "internal-tester",
+ .ws_max_concurrent = 50,
},
} });