diff --git a/src/Config.zig b/src/Config.zig index 6788db1d..624fc63b 100644 --- a/src/Config.zig +++ b/src/Config.zig @@ -128,6 +128,13 @@ pub fn httpMaxResponseSize(self: *const Config) ?usize { }; } +pub fn wsMaxConcurrent(self: *const Config) u8 { + return switch (self.mode) { + inline .serve, .fetch, .mcp => |opts| opts.common.ws_max_concurrent orelse 8, + else => unreachable, + }; +} + pub fn logLevel(self: *const Config) ?log.Level { return switch (self.mode) { inline .serve, .fetch, .mcp => |opts| opts.common.log_level, @@ -275,6 +282,7 @@ pub const Common = struct { http_timeout: ?u31 = null, http_connect_timeout: ?u31 = null, http_max_response_size: ?usize = null, + ws_max_concurrent: ?u8 = null, tls_verify_host: bool = true, log_level: ?log.Level = null, log_format: ?log.Format = null, @@ -375,6 +383,10 @@ pub fn printUsageAndExit(self: *const Config, success: bool) void { \\ (e.g. XHR, fetch, script loading, ...). \\ Defaults to no limit. \\ + \\--ws-max-concurrent + \\ The maximum number of concurrent WebSocket connections. + \\ Defaults to 8. + \\ \\--log-level The log level: debug, info, warn, error or fatal. \\ Defaults to ++ (if (builtin.mode == .Debug) " info." else "warn.") ++ @@ -983,6 +995,19 @@ fn parseCommonArg( return true; } + if (std.mem.eql(u8, "--ws-max-concurrent", opt) or std.mem.eql(u8, "--ws_max_concurrent", opt)) { + const str = args.next() orelse { + log.fatal(.app, "missing argument value", .{ .arg = opt }); + return error.InvalidArgument; + }; + + common.ws_max_concurrent = std.fmt.parseInt(u8, str, 10) catch |err| { + log.fatal(.app, "invalid argument value", .{ .arg = opt, .err = err }); + return error.InvalidArgument; + }; + return true; + } + if (std.mem.eql(u8, "--log-level", opt) or std.mem.eql(u8, "--log_level", opt)) { const str = args.next() orelse { log.fatal(.app, "missing argument value", .{ .arg = opt }); diff --git a/src/browser/tests/net/websocket.html b/src/browser/tests/net/websocket.html index 8ad03a70..257c2136 100644 --- a/src/browser/tests/net/websocket.html +++ b/src/browser/tests/net/websocket.html @@ -238,3 +238,309 @@ }); } + + + + + + + + + + + + + + + + + + + + + + diff --git a/src/browser/webapi/net/WebSocket.zig b/src/browser/webapi/net/WebSocket.zig index 1a07dcf6..dc19728e 100644 --- a/src/browser/webapi/net/WebSocket.zig +++ b/src/browser/webapi/net/WebSocket.zig @@ -107,10 +107,7 @@ pub fn init(url: []const u8, protocols_: ?[]const u8, page: *Page) !*WebSocket { const resolved_url = try URL.resolve(arena, page.base(), url, .{ .always_dupe = true, .encode = true }); const http_client = page._session.browser.http_client; - const conn = http_client.network.getConnection() orelse { - // TODO: figure out how/where we actually want to get WebSocket connections - // from. I feel like sharing this with the HTTP Connection Pool is a - // mistake. + const conn = http_client.network.newConnection() orelse { return error.NoFreeConnection; }; @@ -135,7 +132,7 @@ pub fn init(url: []const u8, protocols_: ?[]const u8, page: *Page) !*WebSocket { try http_client.trackConn(conn); if (comptime IS_DEBUG) { - log.info(.http, "WS connecting", .{ .url = url }); + log.info(.websocket, "connecting", .{ .url = url }); } // Unlike an XHR object where we only selectively reference the instance @@ -179,9 +176,9 @@ pub fn disconnected(self: *WebSocket, err_: ?anyerror) void { self._ready_state = .closed; if (err_) |err| { - log.warn(.http, "WS disconnected", .{ .err = err, .url = self._url }); + log.warn(.websocket, "disconnected", .{ .err = err, .url = self._url }); } else { - log.info(.http, "WS disconnected", .{ .url = self._url, .reason = "closed" }); + log.info(.websocket, "disconnected", .{ .url = self._url, .reason = "closed" }); } self.cleanup(); @@ -191,7 +188,7 @@ pub fn disconnected(self: *WebSocket, err_: ?anyerror) void { const reason = if (was_clean) self._close_reason else ""; self.dispatchCloseEvent(code, reason, was_clean) catch |err| { - log.err(.http, "WS close event dispatch failed", .{ .err = err }); + log.err(.websocket, "close event dispatch failed", .{ .err = err }); }; } @@ -413,6 +410,7 @@ fn dispatchOpenEvent(self: *WebSocket) !void { } fn dispatchMessageEvent(self: *WebSocket, data: []const u8, frame_type: http.WsFrameType) !void { + std.debug.print("{any} {s}\n", .{ frame_type, data }); const page = self._page; const target = self.asEventTarget(); @@ -450,7 +448,7 @@ fn sendDataCallback(buffer: [*]u8, buf_count: usize, buf_len: usize, data: *anyo } const conn: *http.Connection = @ptrCast(@alignCast(data)); return _sendDataCallback(conn, buffer[0..buf_len]) catch |err| { - log.warn(.http, "WS send callback", .{ .err = err }); + log.warn(.websocket, "send callback", .{ .err = err }); return http.readfunc_pause; }; } @@ -499,6 +497,9 @@ fn _sendDataCallback(conn: *http.Connection, buf: []u8) !usize { fn writeContent(self: *WebSocket, conn: *http.Connection, buf: []u8, byte_msg: Message.Content, frame_type: http.WsFrameType) !usize { if (self._send_offset == 0) { // start of the message + if (comptime IS_DEBUG) { + log.debug(.websocket, "send start", .{ .url = self._url, .len = byte_msg.data.len }); + } try conn.wsStartFrame(frame_type, byte_msg.data.len); } @@ -511,6 +512,9 @@ fn writeContent(self: *WebSocket, conn: *http.Connection, buf: []u8, byte_msg: M if (self._send_offset >= byte_msg.data.len) { const removed = self._send_queue.orderedRemove(0); removed.deinit(self._page._session); + if (comptime IS_DEBUG) { + log.debug(.websocket, "send complete", .{ .url = self._url, .len = byte_msg.data.len, .queue = self._send_queue.items.len }); + } self._send_offset = 0; } @@ -523,7 +527,7 @@ fn receivedDataCallback(buffer: [*]const u8, buf_count: usize, buf_len: usize, d } const conn: *http.Connection = @ptrCast(@alignCast(data)); _receivedDataCallback(conn, buffer[0..buf_len]) catch |err| { - log.warn(.http, "WS receive callback", .{ .err = err }); + log.warn(.websocket, "receive callback", .{ .err = err }); // TODO: are there errors, like an invalid frame, that we shouldn't treat // as an error? return http.writefunc_error; @@ -535,11 +539,14 @@ fn receivedDataCallback(buffer: [*]const u8, buf_count: usize, buf_len: usize, d fn _receivedDataCallback(conn: *http.Connection, data: []const u8) !void { const self = conn.transport.websocket; const meta = conn.wsMeta() orelse { - log.err(.http, "WS missing meta", .{ .url = self._url }); + log.err(.websocket, "missing meta", .{ .url = self._url }); return error.NoFrameMeta; }; if (meta.offset == 0) { + if (comptime IS_DEBUG) { + log.debug(.websocket, "incoming message", .{ .url = self._url, .len = meta.len, .bytes_left = meta.bytes_left, .type = meta.frame_type }); + } // Start of new frame. Pre-allocate buffer self._recv_buffer.clearRetainingCapacity(); if (meta.len > self._http_client.max_response_size) { @@ -598,10 +605,10 @@ fn receivedHeaderCalllback(buffer: [*]const u8, header_count: usize, buf_len: us } self._ready_state = .open; - log.info(.http, "WS connected", .{ .url = self._url }); + log.info(.websocket, "connected", .{ .url = self._url }); self.dispatchOpenEvent() catch |err| { - log.err(.http, "WS open event fail", .{ .err = err }); + log.err(.websocket, "open event fail", .{ .err = err }); }; return buf_len; } diff --git a/src/log.zig b/src/log.zig index 3e1016c5..84ff1049 100644 --- a/src/log.zig +++ b/src/log.zig @@ -40,6 +40,7 @@ pub const Scope = enum { unknown_prop, mcp, cache, + websocket, }; const Opts = struct { diff --git a/src/network/Network.zig b/src/network/Network.zig index ab11e5ce..1fb8c8fb 100644 --- a/src/network/Network.zig +++ b/src/network/Network.zig @@ -61,6 +61,11 @@ connections: []http.Connection, available: std.DoublyLinkedList = .{}, conn_mutex: std.Thread.Mutex = .{}, +ws_pool: std.heap.MemoryPool(http.Connection), +ws_count: usize = 0, +ws_max: u8, +ws_mutex: std.Thread.Mutex = .{}, + pollfds: []posix.pollfd, listener: ?Listener = null, @@ -268,9 +273,13 @@ pub fn init(allocator: Allocator, app: *App, config: *const Config) !Network { .connections = connections, .app = app, + .robot_store = RobotStore.init(allocator), .web_bot_auth = web_bot_auth, .cache = cache, + + .ws_pool = .init(allocator), + .ws_max = config.wsMaxConcurrent(), }; } @@ -298,6 +307,8 @@ pub fn deinit(self: *Network) void { } self.allocator.free(self.connections); + self.ws_pool.deinit(); + self.robot_store.deinit(); if (self.web_bot_auth) |wba| { wba.deinit(self.allocator); @@ -592,18 +603,50 @@ pub fn getConnection(self: *Network) ?*http.Connection { } pub fn releaseConnection(self: *Network, conn: *http.Connection) void { - conn.reset(self.config, self.ca_blob) catch |err| { - lp.assert(false, "couldn't reset curl easy", .{ .err = err }); - }; - - self.conn_mutex.lock(); - defer self.conn_mutex.unlock(); - - self.available.append(&conn.node); + switch (conn.transport) { + .websocket => { + conn.deinit(); + self.ws_mutex.lock(); + defer self.ws_mutex.unlock(); + self.ws_pool.destroy(conn); + self.ws_count -= 1; + }, + else => { + conn.reset(self.config, self.ca_blob) catch |err| { + lp.assert(false, "couldn't reset curl easy", .{ .err = err }); + }; + self.conn_mutex.lock(); + defer self.conn_mutex.unlock(); + self.available.append(&conn.node); + }, + } } -pub fn newConnection(self: *Network) !http.Connection { - return http.Connection.init(self.ca_blob, self.config); +pub fn newConnection(self: *Network) ?*http.Connection { + const conn = blk: { + self.ws_mutex.lock(); + defer self.ws_mutex.unlock(); + + if (self.ws_count >= self.ws_max) { + return null; + } + + const c = self.ws_pool.create() catch return null; + self.ws_count += 1; + break :blk c; + }; + + // don't do this under lock + conn.* = http.Connection.init(self.ca_blob, self.config) catch { + self.ws_mutex.lock(); + defer self.ws_mutex.unlock(); + self.ws_pool.destroy(conn); + self.ws_count -= 1; + + return null; + }; + + return conn; } // Wraps lines @ 64 columns. A PEM is basically a base64 encoded DER (which is diff --git a/src/testing.zig b/src/testing.zig index 8dd2eb88..44abfea8 100644 --- a/src/testing.zig +++ b/src/testing.zig @@ -496,6 +496,7 @@ test "tests:beforeAll" { .common = .{ .tls_verify_host = false, .user_agent_suffix = "internal-tester", + .ws_max_concurrent = 50, }, } });