From 098f26f409c22ec1a03cd969f61249f8088e10a0 Mon Sep 17 00:00:00 2001 From: Karl Seguin Date: Tue, 31 Mar 2026 20:37:28 +0800 Subject: [PATCH] WebSocket WebAPI Uses libcurl's websocket capabilities to add support for WebSocket. Depends on https://github.com/lightpanda-io/zig-v8-fork/pull/167 Issue: https://github.com/lightpanda-io/browser/issues/1952 This is a WIP because it currently uses the same connection pool used for all HTTP requests. It would be pretty easy for a page to starve the pool and block any progress. We previously stored the *Transfer inside of the easy's private data. We now store the *Connection, and a Connection now has a `transport` field which is a union for `http: *Transfer` or `websocket: *Websocket`. --- build.zig | 1 + build.zig.zon | 4 +- src/TestWSServer.zig | 371 ++++++++++++ src/browser/HttpClient.zig | 149 ++--- src/browser/js/bridge.zig | 2 + src/browser/tests/net/websocket.html | 240 ++++++++ src/browser/tests/net/websocket2.html | 233 ++++++++ src/browser/tests/net/websocket3.html | 77 +++ src/browser/webapi/Event.zig | 2 + src/browser/webapi/EventTarget.zig | 3 + src/browser/webapi/MessagePort.zig | 2 +- src/browser/webapi/Window.zig | 2 +- src/browser/webapi/event/CloseEvent.zig | 102 ++++ src/browser/webapi/event/MessageEvent.zig | 17 +- src/browser/webapi/net/WebSocket.zig | 687 ++++++++++++++++++++++ src/network/http.zig | 90 ++- src/sys/libcurl.zig | 122 ++++ src/testing.zig | 40 +- 18 files changed, 2039 insertions(+), 105 deletions(-) create mode 100644 src/TestWSServer.zig create mode 100644 src/browser/tests/net/websocket.html create mode 100644 src/browser/tests/net/websocket2.html create mode 100644 src/browser/tests/net/websocket3.html create mode 100644 src/browser/webapi/event/CloseEvent.zig create mode 100644 src/browser/webapi/net/WebSocket.zig diff --git a/build.zig b/build.zig index 4fba3dc9..0f640756 100644 --- a/build.zig +++ b/build.zig @@ -462,6 +462,7 @@ fn buildCurl( .CURL_DISABLE_SMTP = true, .CURL_DISABLE_TELNET = true, .CURL_DISABLE_TFTP = true, + .CURL_DISABLE_WEBSOCKETS = false, // Enable WebSocket support .ssize_t = null, ._FILE_OFFSET_BITS = 64, diff --git a/build.zig.zon b/build.zig.zon index f6c231bb..0c6096a4 100644 --- a/build.zig.zon +++ b/build.zig.zon @@ -5,8 +5,8 @@ .minimum_zig_version = "0.15.2", .dependencies = .{ .v8 = .{ - .url = "https://github.com/lightpanda-io/zig-v8-fork/archive/refs/tags/v0.3.7.tar.gz", - .hash = "v8-0.0.0-xddH67uBBAD95hWsPQz3Ni1PlZjdywtPXrGUAp8rSKco", + .url = "https://github.com/lightpanda-io/zig-v8-fork/archive/99c1ddf2d0b15f141e92ea09abdfc8e0e5f441e6.tar.gz", + .hash = "v8-0.0.0-xddH63-BBABP05dni8oMrs9qQwuczHhNhXHbXXlPb95s", }, // .v8 = .{ .path = "../zig-v8-fork" }, .brotli = .{ diff --git a/src/TestWSServer.zig b/src/TestWSServer.zig new file mode 100644 index 00000000..4d7ddd54 --- /dev/null +++ b/src/TestWSServer.zig @@ -0,0 +1,371 @@ +// Copyright (C) 2023-2026 Lightpanda (Selecy SAS) +// +// Francis Bouvier +// Pierre Tachoire +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as +// published by the Free Software Foundation, either version 3 of the +// License, or (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +const std = @import("std"); +const posix = std.posix; + +const TestWSServer = @This(); + +shutdown: std.atomic.Value(bool), +listener: ?posix.socket_t, + +pub fn init() TestWSServer { + return .{ + .shutdown = .init(true), + .listener = null, + }; +} + +pub fn deinit(self: *TestWSServer) void { + if (self.listener) |socket| { + posix.close(socket); + self.listener = null; + } +} + +pub fn stop(self: *TestWSServer) void { + self.shutdown.store(true, .release); + if (self.listener) |socket| { + posix.close(socket); + self.listener = null; + } +} + +pub fn run(self: *TestWSServer, wg: *std.Thread.WaitGroup) void { + self.runImpl(wg) catch |err| { + std.debug.print("WebSocket echo server error: {}\n", .{err}); + }; +} + +fn runImpl(self: *TestWSServer, wg: *std.Thread.WaitGroup) !void { + const socket = try posix.socket(posix.AF.INET, posix.SOCK.STREAM, 0); + errdefer posix.close(socket); + + const addr = std.net.Address.initIp4(.{ 127, 0, 0, 1 }, 9584); + + try posix.setsockopt(socket, posix.SOL.SOCKET, posix.SO.REUSEADDR, &std.mem.toBytes(@as(c_int, 1))); + try posix.bind(socket, &addr.any, addr.getOsSockLen()); + try posix.listen(socket, 8); + + self.listener = socket; + self.shutdown.store(false, .release); + wg.finish(); + + while (!self.shutdown.load(.acquire)) { + var client_addr: posix.sockaddr = undefined; + var addr_len: posix.socklen_t = @sizeOf(posix.sockaddr); + + const client = posix.accept(socket, &client_addr, &addr_len, 0) catch |err| { + if (self.shutdown.load(.acquire)) return; + std.debug.print("[WS Server] Accept error: {}\n", .{err}); + continue; + }; + + const thread = std.Thread.spawn(.{}, handleClient, .{client}) catch |err| { + std.debug.print("[WS Server] Thread spawn error: {}\n", .{err}); + posix.close(client); + continue; + }; + thread.detach(); + } +} + +fn handleClient(client: posix.socket_t) void { + defer posix.close(client); + + var buf: [4096]u8 = undefined; + const n = posix.read(client, &buf) catch return; + + const request = buf[0..n]; + + // Find Sec-WebSocket-Key + const key_header = "Sec-WebSocket-Key: "; + const key_start = std.mem.indexOf(u8, request, key_header) orelse return; + const key_line_start = key_start + key_header.len; + const key_end = std.mem.indexOfScalarPos(u8, request, key_line_start, '\r') orelse return; + const key = request[key_line_start..key_end]; + + // Compute accept key + var hasher = std.crypto.hash.Sha1.init(.{}); + hasher.update(key); + hasher.update("258EAFA5-E914-47DA-95CA-C5AB0DC85B11"); + var hash: [20]u8 = undefined; + hasher.final(&hash); + + var accept_key: [28]u8 = undefined; + _ = std.base64.standard.Encoder.encode(&accept_key, &hash); + + // Send upgrade response + var resp_buf: [256]u8 = undefined; + const resp = std.fmt.bufPrint(&resp_buf, "HTTP/1.1 101 Switching Protocols\r\n" ++ + "Upgrade: websocket\r\n" ++ + "Connection: Upgrade\r\n" ++ + "Sec-WebSocket-Accept: {s}\r\n\r\n", .{accept_key}) catch return; + _ = posix.write(client, resp) catch return; + + // Message loop with larger buffer for big messages + var msg_buf: [128 * 1024]u8 = undefined; + var recv_buf = RecvBuffer{ .buf = &msg_buf }; + + while (true) { + const frame = recv_buf.readFrame(client) orelse break; + + // Close frame - echo it back before closing + if (frame.opcode == 8) { + sendFrame(client, 8, "", frame.payload) catch {}; + break; + } + + // Handle commands or echo + if (frame.opcode == 1) { // Text + handleTextMessage(client, frame.payload) catch break; + } else if (frame.opcode == 2) { // Binary + handleBinaryMessage(client, frame.payload) catch break; + } + } +} + +const Frame = struct { + opcode: u8, + payload: []u8, +}; + +const RecvBuffer = struct { + buf: []u8, + start: usize = 0, + end: usize = 0, + + fn available(self: *RecvBuffer) []u8 { + return self.buf[self.start..self.end]; + } + + fn consume(self: *RecvBuffer, n: usize) void { + self.start += n; + if (self.start >= self.end) { + self.start = 0; + self.end = 0; + } + } + + fn ensureBytes(self: *RecvBuffer, client: posix.socket_t, needed: usize) bool { + while (self.end - self.start < needed) { + // Compact buffer if needed + if (self.end >= self.buf.len - 1024) { + const avail = self.end - self.start; + std.mem.copyForwards(u8, self.buf[0..avail], self.buf[self.start..self.end]); + self.start = 0; + self.end = avail; + } + + const n = posix.read(client, self.buf[self.end..]) catch return false; + if (n == 0) return false; + self.end += n; + } + return true; + } + + fn readFrame(self: *RecvBuffer, client: posix.socket_t) ?Frame { + // Need at least 2 bytes for basic header + if (!self.ensureBytes(client, 2)) return null; + + const data = self.available(); + const opcode = data[0] & 0x0F; + const masked = (data[1] & 0x80) != 0; + var payload_len: usize = data[1] & 0x7F; + var header_size: usize = 2; + + // Extended payload length + if (payload_len == 126) { + if (!self.ensureBytes(client, 4)) return null; + const d = self.available(); + payload_len = @as(usize, d[2]) << 8 | d[3]; + header_size = 4; + } else if (payload_len == 127) { + if (!self.ensureBytes(client, 10)) return null; + const d = self.available(); + payload_len = @as(usize, d[2]) << 56 | + @as(usize, d[3]) << 48 | + @as(usize, d[4]) << 40 | + @as(usize, d[5]) << 32 | + @as(usize, d[6]) << 24 | + @as(usize, d[7]) << 16 | + @as(usize, d[8]) << 8 | + d[9]; + header_size = 10; + } + + const mask_size: usize = if (masked) 4 else 0; + const total_frame_size = header_size + mask_size + payload_len; + + if (!self.ensureBytes(client, total_frame_size)) return null; + + const frame_data = self.available(); + + // Get mask key if present + var mask_key: [4]u8 = undefined; + if (masked) { + @memcpy(&mask_key, frame_data[header_size..][0..4]); + } + + // Get payload and unmask + const payload_start = header_size + mask_size; + const payload = frame_data[payload_start..][0..payload_len]; + + if (masked) { + for (payload, 0..) |*b, i| { + b.* ^= mask_key[i % 4]; + } + } + + self.consume(total_frame_size); + + return .{ .opcode = opcode, .payload = payload }; + } +}; + +fn handleTextMessage(client: posix.socket_t, payload: []const u8) !void { + // Command: force-close - close socket immediately without close frame + if (std.mem.eql(u8, payload, "force-close")) { + return error.ForceClose; + } + + // Command: send-large:N - send a message of N bytes + if (std.mem.startsWith(u8, payload, "send-large:")) { + const size_str = payload["send-large:".len..]; + const size = std.fmt.parseInt(usize, size_str, 10) catch return error.InvalidCommand; + try sendLargeMessage(client, size); + return; + } + + // Command: close:CODE:REASON - send close frame with specific code/reason + if (std.mem.startsWith(u8, payload, "close:")) { + const rest = payload["close:".len..]; + if (std.mem.indexOf(u8, rest, ":")) |sep| { + const code = std.fmt.parseInt(u16, rest[0..sep], 10) catch 1000; + const reason = rest[sep + 1 ..]; + try sendCloseFrame(client, code, reason); + } + return; + } + + // Default: echo with "echo-" prefix + const prefix = "echo-"; + try sendFrame(client, 1, prefix, payload); +} + +fn handleBinaryMessage(client: posix.socket_t, payload: []const u8) !void { + // Echo binary data back with byte 0xEE prepended as marker + const marker = [_]u8{0xEE}; + try sendFrame(client, 2, &marker, payload); +} + +fn sendFrame(client: posix.socket_t, opcode: u8, prefix: []const u8, payload: []const u8) !void { + const total_len = prefix.len + payload.len; + + // Build header + var header: [10]u8 = undefined; + var header_len: usize = 2; + + header[0] = 0x80 | opcode; // FIN + opcode + + if (total_len <= 125) { + header[1] = @intCast(total_len); + } else if (total_len <= 65535) { + header[1] = 126; + header[2] = @intCast((total_len >> 8) & 0xFF); + header[3] = @intCast(total_len & 0xFF); + header_len = 4; + } else { + header[1] = 127; + header[2] = @intCast((total_len >> 56) & 0xFF); + header[3] = @intCast((total_len >> 48) & 0xFF); + header[4] = @intCast((total_len >> 40) & 0xFF); + header[5] = @intCast((total_len >> 32) & 0xFF); + header[6] = @intCast((total_len >> 24) & 0xFF); + header[7] = @intCast((total_len >> 16) & 0xFF); + header[8] = @intCast((total_len >> 8) & 0xFF); + header[9] = @intCast(total_len & 0xFF); + header_len = 10; + } + + _ = try posix.write(client, header[0..header_len]); + if (prefix.len > 0) { + _ = try posix.write(client, prefix); + } + if (payload.len > 0) { + _ = try posix.write(client, payload); + } +} + +fn sendLargeMessage(client: posix.socket_t, size: usize) !void { + // Build header + var header: [10]u8 = undefined; + var header_len: usize = 2; + + header[0] = 0x81; // FIN + text + + if (size <= 125) { + header[1] = @intCast(size); + } else if (size <= 65535) { + header[1] = 126; + header[2] = @intCast((size >> 8) & 0xFF); + header[3] = @intCast(size & 0xFF); + header_len = 4; + } else { + header[1] = 127; + header[2] = @intCast((size >> 56) & 0xFF); + header[3] = @intCast((size >> 48) & 0xFF); + header[4] = @intCast((size >> 40) & 0xFF); + header[5] = @intCast((size >> 32) & 0xFF); + header[6] = @intCast((size >> 24) & 0xFF); + header[7] = @intCast((size >> 16) & 0xFF); + header[8] = @intCast((size >> 8) & 0xFF); + header[9] = @intCast(size & 0xFF); + header_len = 10; + } + + _ = try posix.write(client, header[0..header_len]); + + // Send payload in chunks - pattern of 'A'-'Z' repeating + var sent: usize = 0; + var chunk: [4096]u8 = undefined; + while (sent < size) { + const to_send = @min(chunk.len, size - sent); + for (chunk[0..to_send], 0..) |*b, i| { + b.* = @intCast('A' + ((sent + i) % 26)); + } + _ = try posix.write(client, chunk[0..to_send]); + sent += to_send; + } +} + +fn sendCloseFrame(client: posix.socket_t, code: u16, reason: []const u8) !void { + const reason_len = @min(reason.len, 123); // Max 123 bytes for reason + const payload_len = 2 + reason_len; + + var frame: [129]u8 = undefined; // 2 header + 2 code + 123 reason + 2 padding + frame[0] = 0x88; // FIN + close + frame[1] = @intCast(payload_len); + frame[2] = @intCast((code >> 8) & 0xFF); + frame[3] = @intCast(code & 0xFF); + if (reason_len > 0) { + @memcpy(frame[4..][0..reason_len], reason[0..reason_len]); + } + + _ = try posix.write(client, frame[0 .. 4 + reason_len]); +} diff --git a/src/browser/HttpClient.zig b/src/browser/HttpClient.zig index b845d4f8..b8b0abc9 100644 --- a/src/browser/HttpClient.zig +++ b/src/browser/HttpClient.zig @@ -28,6 +28,7 @@ const URL = @import("URL.zig"); const Config = @import("../Config.zig"); const Notification = @import("../Notification.zig"); const CookieJar = @import("webapi/storage/Cookie.zig").Jar; +const WebSocket = @import("webapi/net/WebSocket.zig"); const http = @import("../network/http.zig"); const Network = @import("../network/Network.zig"); @@ -113,6 +114,8 @@ obey_robots: bool, cdp_client: ?CDPClient = null, +max_response_size: usize, + // libcurl can monitor arbitrary sockets, this lets us use libcurl to poll // both HTTP data as well as messages from an CDP connection. // Furthermore, we have some tension between blocking scripts and request @@ -153,6 +156,7 @@ pub fn init(allocator: Allocator, network: *Network) !*Client { .http_proxy = http_proxy, .tls_verify = network.config.tlsVerifyHost(), .obey_robots = network.config.obeyRobots(), + .max_response_size = network.config.httpMaxResponseSize() orelse std.math.maxInt(u32), }; return client; @@ -221,16 +225,18 @@ fn _abort(self: *Client, comptime abort_all: bool, frame_id: u32) void { while (n) |node| { n = node.next; const conn: *http.Connection = @fieldParentPtr("node", node); - var transfer = Transfer.fromConnection(conn) catch |err| { - // Let's cleanup what we can - self.removeConn(conn); - log.err(.http, "get private info", .{ .err = err, .source = "abort" }); - continue; - }; - if (comptime abort_all) { - transfer.kill(); - } else if (transfer.req.frame_id == frame_id) { - transfer.kill(); + switch (conn.transport) { + .http => |transfer| { + if ((comptime abort_all) or transfer.req.frame_id == frame_id) { + transfer.kill(); + } + }, + .websocket => |ws| { + if ((comptime abort_all) or ws._page._frame_id == frame_id) { + ws.kill(); + } + }, + .none => unreachable, } } } @@ -636,7 +642,6 @@ fn makeTransfer(self: *Client, req: Request) !*Transfer { .req = req, .ctx = req.ctx, .client = self, - .max_response_size = self.network.config.httpMaxResponseSize(), }; return transfer; } @@ -663,15 +668,11 @@ fn makeRequest(self: *Client, conn: *http.Connection, transfer: *Transfer) anyer // fails BEFORE `curl_multi_add_handle` succeeds, the we still need to do // cleanup. But if things fail after `curl_multi_add_handle`, we expect // perfom to pickup the failure and cleanup. - self.in_use.append(&conn.node); - self.handles.add(conn) catch |err| { + self.trackConn(conn) catch |err| { transfer._conn = null; transfer.deinit(); - self.in_use.remove(&conn.node); - self.releaseConn(conn); return err; }; - self.active += 1; if (transfer.req.start_callback) |cb| { cb(transfer) catch |err| { @@ -735,7 +736,7 @@ fn processOneMessage(self: *Client, msg: http.Handles.MultiMessage, transfer: *T // Also check on RecvError: proxy may send 407 with headers before // closing the connection (CONNECT tunnel not yet established). if (msg.err == null or msg.err.? == error.RecvError) { - transfer.detectAuthChallenge(&msg.conn); + transfer.detectAuthChallenge(msg.conn); } // In case of auth challenge @@ -834,7 +835,7 @@ fn processOneMessage(self: *Client, msg: http.Handles.MultiMessage, transfer: *T if (!transfer._header_done_called) { // In case of request w/o data, we need to call the header done // callback now. - const proceed = try transfer.headerDoneCallback(&msg.conn); + const proceed = try transfer.headerDoneCallback(msg.conn); if (!proceed) { transfer.requestFailed(error.Abort, true); return true; @@ -871,30 +872,63 @@ fn processOneMessage(self: *Client, msg: http.Handles.MultiMessage, transfer: *T fn processMessages(self: *Client) !bool { var processed = false; - while (self.handles.readMessage()) |msg| { - const transfer = try Transfer.fromConnection(&msg.conn); - const done = self.processOneMessage(msg, transfer) catch |err| blk: { - log.err(.http, "process_messages", .{ .err = err, .req = transfer }); - transfer.requestFailed(err, true); - if (transfer._detached_conn) |c| { - // Conn was removed from handles during redirect reconfiguration - // but not re-added. Release it directly to avoid double-remove. - self.in_use.remove(&c.node); - self.active -= 1; - self.releaseConn(c); - transfer._detached_conn = null; - } - break :blk true; - }; - if (done) { - transfer.deinit(); - processed = true; + while (try self.handles.readMessage()) |msg| { + switch (msg.conn.transport) { + .http => |transfer| { + const done = self.processOneMessage(msg, transfer) catch |err| blk: { + log.err(.http, "process_messages", .{ .err = err, .req = transfer }); + transfer.requestFailed(err, true); + if (transfer._detached_conn) |c| { + // Conn was removed from handles during redirect reconfiguration + // but not re-added. Release it directly to avoid double-remove. + self.in_use.remove(&c.node); + self.active -= 1; + self.releaseConn(c); + transfer._detached_conn = null; + } + break :blk true; + }; + if (done) { + transfer.deinit(); + processed = true; + } + }, + .websocket => |ws| { + if (msg.err) |err| switch (err) { + error.GotNothing => ws.disconnected(null), + else => ws.disconnected(err), + } else { + // Clean close - no error + ws.disconnected(null); + } + + return true; + }, + .none => unreachable, } } return processed; } -fn removeConn(self: *Client, conn: *http.Connection) void { +pub fn trackConn(self: *Client, conn: *http.Connection) !void { + self.in_use.append(&conn.node); + // Set private pointer so readMessage can find the Connection. + // Must be done each time since curl_easy_reset clears it when + // connections are returned to pool. + conn.setPrivate(conn) catch |err| { + self.in_use.remove(&conn.node); + self.releaseConn(conn); + return err; + }; + self.handles.add(conn) catch |err| { + self.in_use.remove(&conn.node); + self.releaseConn(conn); + return err; + }; + self.active += 1; +} + +pub fn removeConn(self: *Client, conn: *http.Connection) void { self.in_use.remove(&conn.node); self.active -= 1; if (self.handles.remove(conn)) { @@ -927,7 +961,6 @@ pub const Request = struct { resource_type: ResourceType, credentials: ?[:0]const u8 = null, notification: *Notification, - max_response_size: ?usize = null, // This is only relevant for intercepted requests. If a request is flagged // as blocking AND is intercepted, then it'll be up to us to wait until @@ -980,8 +1013,6 @@ pub const Transfer = struct { aborted: bool = false, - max_response_size: ?usize = null, - // We'll store the response header here response_header: ?ResponseHead = null, @@ -1112,7 +1143,7 @@ pub const Transfer = struct { const req = &self.req; // Set callbacks and per-client settings on the pooled connection. - try conn.setCallbacks(Transfer.dataCallback); + try conn.setWriteCallback(Transfer.dataCallback); try conn.setFollowLocation(false); try conn.setProxy(client.http_proxy); try conn.setTlsVerify(client.tls_verify, client.use_proxy); @@ -1140,7 +1171,7 @@ pub const Transfer = struct { try conn.setCookies(@ptrCast(cookies.ptr)); } - try conn.setPrivate(self); + conn.transport = .{ .http = self }; // add credentials if (req.credentials) |creds| { @@ -1340,11 +1371,9 @@ pub const Transfer = struct { } } - if (transfer.max_response_size) |max_size| { - if (transfer.getContentLength()) |cl| { - if (cl > max_size) { - return error.ResponseTooLarge; - } + if (transfer.getContentLength()) |cl| { + if (cl > transfer.client.max_response_size) { + return error.ResponseTooLarge; } } @@ -1367,10 +1396,7 @@ pub const Transfer = struct { } const conn: *http.Connection = @ptrCast(@alignCast(data)); - var transfer = fromConnection(conn) catch |err| { - log.err(.http, "get private info", .{ .err = err, .source = "body callback" }); - return http.writefunc_error; - }; + var transfer = conn.transport.http; if (!transfer._first_data_received) { transfer._first_data_received = true; @@ -1387,11 +1413,9 @@ pub const Transfer = struct { // Pre-size buffer from Content-Length. if (transfer.getContentLength()) |cl| { - if (transfer.max_response_size) |max_size| { - if (cl > max_size) { - transfer._callback_error = error.ResponseTooLarge; - return http.writefunc_error; - } + if (cl > transfer.client.max_response_size) { + transfer._callback_error = error.ResponseTooLarge; + return http.writefunc_error; } transfer._stream_buffer.ensureTotalCapacity(transfer.arena.allocator(), cl) catch {}; } @@ -1400,11 +1424,9 @@ pub const Transfer = struct { if (transfer._skip_body) return @intCast(chunk_len); transfer.bytes_received += chunk_len; - if (transfer.max_response_size) |max_size| { - if (transfer.bytes_received > max_size) { - transfer._callback_error = error.ResponseTooLarge; - return http.writefunc_error; - } + if (transfer.bytes_received > transfer.client.max_response_size) { + transfer._callback_error = error.ResponseTooLarge; + return http.writefunc_error; } const chunk = buffer[0..chunk_len]; @@ -1433,11 +1455,6 @@ pub const Transfer = struct { return .{ .list = .{ .list = self.response_header.?._injected_headers } }; } - fn fromConnection(conn: *const http.Connection) !*Transfer { - const private = try conn.getPrivate(); - return @ptrCast(@alignCast(private)); - } - pub fn fulfill(transfer: *Transfer, status: u16, headers: []const http.Header, body: ?[]const u8) !void { if (transfer._conn != null) { // should never happen, should have been intercepted/paused, and then diff --git a/src/browser/js/bridge.zig b/src/browser/js/bridge.zig index 0a51327e..8fbdc315 100644 --- a/src/browser/js/bridge.zig +++ b/src/browser/js/bridge.zig @@ -829,6 +829,8 @@ pub const JsApis = flattenTypes(&.{ @import("../webapi/net/URLSearchParams.zig"), @import("../webapi/net/XMLHttpRequest.zig"), @import("../webapi/net/XMLHttpRequestEventTarget.zig"), + @import("../webapi/net/WebSocket.zig"), + @import("../webapi/event/CloseEvent.zig"), @import("../webapi/streams/ReadableStream.zig"), @import("../webapi/streams/ReadableStreamDefaultReader.zig"), @import("../webapi/streams/ReadableStreamDefaultController.zig"), diff --git a/src/browser/tests/net/websocket.html b/src/browser/tests/net/websocket.html new file mode 100644 index 00000000..8ad03a70 --- /dev/null +++ b/src/browser/tests/net/websocket.html @@ -0,0 +1,240 @@ + + + + + + + + + + + + + + + + + + diff --git a/src/browser/tests/net/websocket2.html b/src/browser/tests/net/websocket2.html new file mode 100644 index 00000000..d421867e --- /dev/null +++ b/src/browser/tests/net/websocket2.html @@ -0,0 +1,233 @@ + + + + + + + + + + + + + + + + + + diff --git a/src/browser/tests/net/websocket3.html b/src/browser/tests/net/websocket3.html new file mode 100644 index 00000000..12dc19cc --- /dev/null +++ b/src/browser/tests/net/websocket3.html @@ -0,0 +1,77 @@ + + + + + + + + diff --git a/src/browser/webapi/Event.zig b/src/browser/webapi/Event.zig index b48bc059..b573bfc7 100644 --- a/src/browser/webapi/Event.zig +++ b/src/browser/webapi/Event.zig @@ -80,6 +80,7 @@ pub const Type = union(enum) { promise_rejection_event: *@import("event/PromiseRejectionEvent.zig"), submit_event: *@import("event/SubmitEvent.zig"), form_data_event: *@import("event/FormDataEvent.zig"), + close_event: *@import("event/CloseEvent.zig"), }; pub const Options = struct { @@ -171,6 +172,7 @@ pub fn is(self: *Event, comptime T: type) ?*T { .promise_rejection_event => |e| return if (T == @import("event/PromiseRejectionEvent.zig")) e else null, .submit_event => |e| return if (T == @import("event/SubmitEvent.zig")) e else null, .form_data_event => |e| return if (T == @import("event/FormDataEvent.zig")) e else null, + .close_event => |e| return if (T == @import("event/CloseEvent.zig")) e else null, .ui_event => |e| { if (T == @import("event/UIEvent.zig")) { return e; diff --git a/src/browser/webapi/EventTarget.zig b/src/browser/webapi/EventTarget.zig index 704efeb3..60dfbf11 100644 --- a/src/browser/webapi/EventTarget.zig +++ b/src/browser/webapi/EventTarget.zig @@ -45,6 +45,7 @@ pub const Type = union(enum) { visual_viewport: *@import("VisualViewport.zig"), file_reader: *@import("FileReader.zig"), font_face_set: *@import("css/FontFaceSet.zig"), + websocket: *@import("net/WebSocket.zig"), }; pub fn init(page: *Page) !*EventTarget { @@ -141,6 +142,7 @@ pub fn format(self: *EventTarget, writer: *std.Io.Writer) !void { .visual_viewport => writer.writeAll(""), .file_reader => writer.writeAll(""), .font_face_set => writer.writeAll(""), + .websocket => writer.writeAll(""), }; } @@ -160,6 +162,7 @@ pub fn toString(self: *EventTarget) []const u8 { .visual_viewport => return "[object VisualViewport]", .file_reader => return "[object FileReader]", .font_face_set => return "[object FontFaceSet]", + .websocket => return "[object WebSocket]", }; } diff --git a/src/browser/webapi/MessagePort.zig b/src/browser/webapi/MessagePort.zig index 51d208b0..a7bb9bfc 100644 --- a/src/browser/webapi/MessagePort.zig +++ b/src/browser/webapi/MessagePort.zig @@ -125,7 +125,7 @@ const PostMessageCallback = struct { const target = self.port.asEventTarget(); if (page._event_manager.hasDirectListeners(target, "message", self.port._on_message)) { const event = (MessageEvent.initTrusted(comptime .wrap("message"), .{ - .data = self.message, + .data = .{ .value = self.message }, .origin = "", .source = null, }, page) catch |err| { diff --git a/src/browser/webapi/Window.zig b/src/browser/webapi/Window.zig index fb3ec8f8..621e1e3a 100644 --- a/src/browser/webapi/Window.zig +++ b/src/browser/webapi/Window.zig @@ -791,7 +791,7 @@ const PostMessageCallback = struct { const event_target = window.asEventTarget(); if (page._event_manager.hasDirectListeners(event_target, "message", window._on_message)) { const event = (try MessageEvent.initTrusted(comptime .wrap("message"), .{ - .data = self.message, + .data = .{ .value = self.message }, .origin = self.origin, .source = self.source, .bubbles = false, diff --git a/src/browser/webapi/event/CloseEvent.zig b/src/browser/webapi/event/CloseEvent.zig new file mode 100644 index 00000000..aa9f1d2b --- /dev/null +++ b/src/browser/webapi/event/CloseEvent.zig @@ -0,0 +1,102 @@ +// Copyright (C) 2023-2026 Lightpanda (Selecy SAS) +// +// Francis Bouvier +// Pierre Tachoire +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as +// published by the Free Software Foundation, either version 3 of the +// License, or (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +const std = @import("std"); +const String = @import("../../../string.zig").String; + +const Page = @import("../../Page.zig"); +const Session = @import("../../Session.zig"); +const Event = @import("../Event.zig"); +const Allocator = std.mem.Allocator; + +const CloseEvent = @This(); +_proto: *Event, +_code: u16 = 1000, +_reason: []const u8 = "", +_was_clean: bool = true, + +const CloseEventOptions = struct { + code: u16 = 1000, + reason: []const u8 = "", + wasClean: bool = true, +}; + +const Options = Event.inheritOptions(CloseEvent, CloseEventOptions); + +pub fn init(typ: []const u8, _opts: ?Options, page: *Page) !*CloseEvent { + const arena = try page.getArena(.{ .debug = "CloseEvent" }); + errdefer page.releaseArena(arena); + const type_string = try String.init(arena, typ, .{}); + return initWithTrusted(arena, type_string, _opts, false, page); +} + +pub fn initTrusted(typ: String, _opts: ?Options, page: *Page) !*CloseEvent { + const arena = try page.getArena(.{ .debug = "CloseEvent.trusted" }); + errdefer page.releaseArena(arena); + return initWithTrusted(arena, typ, _opts, true, page); +} + +fn initWithTrusted(arena: Allocator, typ: String, _opts: ?Options, trusted: bool, page: *Page) !*CloseEvent { + const opts = _opts orelse Options{}; + + const event = try page._factory.event( + arena, + typ, + CloseEvent{ + ._proto = undefined, + ._code = opts.code, + ._reason = if (opts.reason.len > 0) try arena.dupe(u8, opts.reason) else "", + ._was_clean = opts.wasClean, + }, + ); + + Event.populatePrototypes(event, opts, trusted); + return event; +} + +pub fn asEvent(self: *CloseEvent) *Event { + return self._proto; +} + +pub fn getCode(self: *const CloseEvent) u16 { + return self._code; +} + +pub fn getReason(self: *const CloseEvent) []const u8 { + return self._reason; +} + +pub fn getWasClean(self: *const CloseEvent) bool { + return self._was_clean; +} + +pub const JsApi = struct { + const js = @import("../../js/js.zig"); + pub const bridge = js.Bridge(CloseEvent); + + pub const Meta = struct { + pub const name = "CloseEvent"; + pub const prototype_chain = bridge.prototypeChain(); + pub var class_id: bridge.ClassId = undefined; + }; + + pub const constructor = bridge.constructor(CloseEvent.init, .{}); + pub const code = bridge.accessor(CloseEvent.getCode, null, .{}); + pub const reason = bridge.accessor(CloseEvent.getReason, null, .{}); + pub const wasClean = bridge.accessor(CloseEvent.getWasClean, null, .{}); +}; diff --git a/src/browser/webapi/event/MessageEvent.zig b/src/browser/webapi/event/MessageEvent.zig index 03530400..66ffd8c6 100644 --- a/src/browser/webapi/event/MessageEvent.zig +++ b/src/browser/webapi/event/MessageEvent.zig @@ -30,16 +30,22 @@ const Allocator = std.mem.Allocator; const MessageEvent = @This(); _proto: *Event, -_data: ?js.Value.Temp = null, +_data: ?Data = null, _origin: []const u8 = "", _source: ?*Window = null, const MessageEventOptions = struct { - data: ?js.Value.Temp = null, + data: ?Data = null, origin: ?[]const u8 = null, source: ?*Window = null, }; +pub const Data = union(enum) { + value: js.Value.Temp, + string: []const u8, + arraybuffer: js.ArrayBuffer, +}; + const Options = Event.inheritOptions(MessageEvent, MessageEventOptions); pub fn init(typ: []const u8, opts_: ?Options, page: *Page) !*MessageEvent { @@ -75,7 +81,10 @@ fn initWithTrusted(arena: Allocator, typ: String, opts_: ?Options, trusted: bool pub fn deinit(self: *MessageEvent, session: *Session) void { if (self._data) |d| { - d.release(); + switch (d) { + .value => |js_val| js_val.release(), + .string, .arraybuffer => {}, + } } self._proto.deinit(session); } @@ -92,7 +101,7 @@ pub fn asEvent(self: *MessageEvent) *Event { return self._proto; } -pub fn getData(self: *const MessageEvent) ?js.Value.Temp { +pub fn getData(self: *const MessageEvent) ?Data { return self._data; } diff --git a/src/browser/webapi/net/WebSocket.zig b/src/browser/webapi/net/WebSocket.zig new file mode 100644 index 00000000..1a07dcf6 --- /dev/null +++ b/src/browser/webapi/net/WebSocket.zig @@ -0,0 +1,687 @@ +// Copyright (C) 2023-2026 Lightpanda (Selecy SAS) +// +// Francis Bouvier +// Pierre Tachoire +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as +// published by the Free Software Foundation, either version 3 of the +// License, or (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +const std = @import("std"); +const lp = @import("lightpanda"); + +const log = @import("../../../log.zig"); +const http = @import("../../../network/http.zig"); + +const js = @import("../../js/js.zig"); +const Blob = @import("../Blob.zig"); +const URL = @import("../../URL.zig"); +const Page = @import("../../Page.zig"); +const Session = @import("../../Session.zig"); +const HttpClient = @import("../../HttpClient.zig"); + +const Event = @import("../Event.zig"); +const EventTarget = @import("../EventTarget.zig"); +const MessageEvent = @import("../event/MessageEvent.zig"); +const CloseEvent = @import("../event/CloseEvent.zig"); + +const Allocator = std.mem.Allocator; +const IS_DEBUG = @import("builtin").mode == .Debug; + +const WebSocket = @This(); +_rc: lp.RC(u8) = .{}, +_page: *Page, +_proto: *EventTarget, +_arena: Allocator, + +// Connection state +_ready_state: ReadyState = .connecting, +_url: [:0]const u8 = "", +_binary_type: BinaryType = .blob, + +// Handshake tracking +_got_101: bool = false, +_got_upgrade: bool = false, + +_conn: ?*http.Connection, +_http_client: *HttpClient, + +// buffered outgoing messages +_send_queue: std.ArrayList(Message) = .empty, +_send_offset: usize = 0, + +// buffered incoming frame +_recv_buffer: std.ArrayList(u8) = .empty, + +// close info for event dispatch +_close_code: u16 = 1000, +_close_reason: []const u8 = "", + +// Event handlers +_on_open: ?js.Function.Temp = null, +_on_message: ?js.Function.Temp = null, +_on_error: ?js.Function.Temp = null, +_on_close: ?js.Function.Temp = null, + +pub const ReadyState = enum(u8) { + connecting = 0, + open = 1, + closing = 2, + closed = 3, +}; + +pub const BinaryType = enum { + blob, + arraybuffer, +}; + +pub fn init(url: []const u8, protocols_: ?[]const u8, page: *Page) !*WebSocket { + if (protocols_) |protocols| { + if (protocols.len > 0) { + log.warn(.not_implemented, "WS protocols", .{ .protocols = protocols }); + } + } + + { + if (url.len < 6) { + return error.SyntaxError; + } + const normalized_start = std.ascii.lowerString(&page.buf, url[0..6]); + if (!std.mem.startsWith(u8, normalized_start, "ws://") and !std.mem.startsWith(u8, normalized_start, "wss://")) { + return error.SyntaxError; + } + } + + const arena = try page.getArena(.{ .debug = "WebSocket" }); + errdefer page.releaseArena(arena); + + const resolved_url = try URL.resolve(arena, page.base(), url, .{ .always_dupe = true, .encode = true }); + + const http_client = page._session.browser.http_client; + const conn = http_client.network.getConnection() orelse { + // TODO: figure out how/where we actually want to get WebSocket connections + // from. I feel like sharing this with the HTTP Connection Pool is a + // mistake. + return error.NoFreeConnection; + }; + + errdefer http_client.network.releaseConnection(conn); + + try conn.setURL(resolved_url); + try conn.setConnectOnly(false); + + try conn.setReadCallback(sendDataCallback, true); + try conn.setWriteCallback(receivedDataCallback); + try conn.setHeaderCallback(receivedHeaderCalllback); + + const self = try page._factory.eventTargetWithAllocator(arena, WebSocket{ + ._page = page, + ._conn = conn, + ._arena = arena, + ._proto = undefined, + ._url = resolved_url, + ._http_client = http_client, + }); + conn.transport = .{ .websocket = self }; + try http_client.trackConn(conn); + + if (comptime IS_DEBUG) { + log.info(.http, "WS connecting", .{ .url = url }); + } + + // Unlike an XHR object where we only selectively reference the instance + // while the request is actually inflight, WS connection is "inflight" from + // the moment it's created. + self.acquireRef(); + + return self; +} + +pub fn deinit(self: *WebSocket, session: *Session) void { + self.cleanup(); + + if (self._on_open) |func| { + func.release(); + } + if (self._on_message) |func| { + func.release(); + } + if (self._on_error) |func| { + func.release(); + } + if (self._on_close) |func| { + func.release(); + } + + for (self._send_queue.items) |msg| { + msg.deinit(session); + } + + session.releaseArena(self._arena); +} + +// we're being aborted internally (e.g. page shutting down) +pub fn kill(self: *WebSocket) void { + self.cleanup(); +} + +pub fn disconnected(self: *WebSocket, err_: ?anyerror) void { + const was_clean = self._ready_state == .closing and err_ == null; + self._ready_state = .closed; + + if (err_) |err| { + log.warn(.http, "WS disconnected", .{ .err = err, .url = self._url }); + } else { + log.info(.http, "WS disconnected", .{ .url = self._url, .reason = "closed" }); + } + + self.cleanup(); + + // Use 1006 (abnormal closure) if connection wasn't cleanly closed + const code = if (was_clean) self._close_code else 1006; + const reason = if (was_clean) self._close_reason else ""; + + self.dispatchCloseEvent(code, reason, was_clean) catch |err| { + log.err(.http, "WS close event dispatch failed", .{ .err = err }); + }; +} + +fn cleanup(self: *WebSocket) void { + if (self._conn) |conn| { + self._http_client.removeConn(conn); + self._conn = null; + self.releaseRef(self._page._session); + } +} + +pub fn releaseRef(self: *WebSocket, session: *Session) void { + self._rc.release(self, session); +} + +pub fn acquireRef(self: *WebSocket) void { + self._rc.acquire(); +} + +fn asEventTarget(self: *WebSocket) *EventTarget { + return self._proto; +} + +fn queueMessage(self: *WebSocket, msg: Message) !void { + const was_empty = self._send_queue.items.len == 0; + try self._send_queue.append(self._arena, msg); + + if (was_empty) { + // Unpause the send callback so libcurl will request data + if (self._conn) |conn| { + try conn.pause(.{ .cont = true }); + } + } +} + +/// WebSocket send() accepts string, Blob, ArrayBuffer, or TypedArray +const SendData = union(enum) { + blob: *Blob, + js_val: js.Value, +}; + +/// Union for extracting bytes from ArrayBuffer/TypedArray +const BinaryData = union(enum) { + int8: []i8, + uint8: []u8, + int16: []i16, + uint16: []u16, + int32: []i32, + uint32: []u32, + int64: []i64, + uint64: []u64, + + fn asBuffer(self: BinaryData) []u8 { + return switch (self) { + .int8 => |b| @as([*]u8, @ptrCast(b.ptr))[0..b.len], + .uint8 => |b| b, + .int16 => |b| @as([*]u8, @ptrCast(b.ptr))[0 .. b.len * 2], + .uint16 => |b| @as([*]u8, @ptrCast(b.ptr))[0 .. b.len * 2], + .int32 => |b| @as([*]u8, @ptrCast(b.ptr))[0 .. b.len * 4], + .uint32 => |b| @as([*]u8, @ptrCast(b.ptr))[0 .. b.len * 4], + .int64 => |b| @as([*]u8, @ptrCast(b.ptr))[0 .. b.len * 8], + .uint64 => |b| @as([*]u8, @ptrCast(b.ptr))[0 .. b.len * 8], + }; + } +}; + +pub fn send(self: *WebSocket, data: SendData) !void { + if (self._ready_state != .open) { + return error.InvalidStateError; + } + + // Get a dedicated arena for this message + const arena = try self._page._session.getArena(.{ .debug = "WebSocket message" }); + errdefer self._page._session.releaseArena(arena); + + switch (data) { + .blob => |blob| { + try self.queueMessage(.{ .binary = .{ + .arena = arena, + .data = try arena.dupe(u8, blob._slice), + } }); + }, + .js_val => |js_val| { + if (js_val.isString()) |str| { + try self.queueMessage(.{ .text = .{ + .arena = arena, + .data = try str.toSliceWithAlloc(arena), + } }); + } else { + const binary = try js_val.toZig(BinaryData); + try self.queueMessage(.{ .binary = .{ + .arena = arena, + .data = try arena.dupe(u8, binary.asBuffer()), + } }); + } + }, + } +} + +pub fn close(self: *WebSocket, code_: ?u16, reason_: ?[]const u8) !void { + if (self._ready_state == .closing or self._ready_state == .closed) { + return; + } + + const code = code_ orelse 1000; + const reason = reason_ orelse ""; + + if (self._ready_state == .connecting) { + // Connection not yet established - fail it + self._ready_state = .closed; + self.cleanup(); + try self.dispatchCloseEvent(code, reason, false); + return; + } + + self._ready_state = .closing; + self._close_code = code; + self._close_reason = try self._arena.dupe(u8, reason); + try self.queueMessage(.close); +} + +pub fn getUrl(self: *const WebSocket) []const u8 { + return self._url; +} + +pub fn getReadyState(self: *const WebSocket) u16 { + return @intFromEnum(self._ready_state); +} + +pub fn getBufferedAmount(self: *const WebSocket) u32 { + var buffered: u32 = 0; + for (self._send_queue.items) |msg| { + switch (msg) { + .text, .binary => |byte_msg| buffered += @intCast(byte_msg.data.len), + .close => buffered += @intCast(2 + self._close_reason.len), + } + } + return buffered; +} + +pub fn getProtocol(self: *const WebSocket) []const u8 { + return self._protocol; +} + +pub fn getExtensions(self: *const WebSocket) []const u8 { + return self._extensions; +} + +pub fn getBinaryType(self: *const WebSocket) []const u8 { + return @tagName(self._binary_type); +} + +pub fn setBinaryType(self: *WebSocket, value: []const u8) void { + if (std.meta.stringToEnum(BinaryType, value)) |bt| { + self._binary_type = bt; + } +} + +pub fn getOnOpen(self: *const WebSocket) ?js.Function.Temp { + return self._on_open; +} + +pub fn setOnOpen(self: *WebSocket, cb_: ?js.Function) !void { + if (self._on_open) |old| old.release(); + if (cb_) |cb| { + self._on_open = try cb.tempWithThis(self); + } else { + self._on_open = null; + } +} + +pub fn getOnMessage(self: *const WebSocket) ?js.Function.Temp { + return self._on_message; +} + +pub fn setOnMessage(self: *WebSocket, cb_: ?js.Function) !void { + if (self._on_message) |old| old.release(); + if (cb_) |cb| { + self._on_message = try cb.tempWithThis(self); + } else { + self._on_message = null; + } +} + +pub fn getOnError(self: *const WebSocket) ?js.Function.Temp { + return self._on_error; +} + +pub fn setOnError(self: *WebSocket, cb_: ?js.Function) !void { + if (self._on_error) |old| old.release(); + if (cb_) |cb| { + self._on_error = try cb.tempWithThis(self); + } else { + self._on_error = null; + } +} + +pub fn getOnClose(self: *const WebSocket) ?js.Function.Temp { + return self._on_close; +} + +pub fn setOnClose(self: *WebSocket, cb_: ?js.Function) !void { + if (self._on_close) |old| old.release(); + if (cb_) |cb| { + self._on_close = try cb.tempWithThis(self); + } else { + self._on_close = null; + } +} + +fn dispatchOpenEvent(self: *WebSocket) !void { + const page = self._page; + const target = self.asEventTarget(); + + if (page._event_manager.hasDirectListeners(target, "open", self._on_open)) { + const event = try Event.initTrusted(.wrap("open"), .{}, page); + try page._event_manager.dispatchDirect(target, event, self._on_open, .{ .context = "WebSocket open" }); + } +} + +fn dispatchMessageEvent(self: *WebSocket, data: []const u8, frame_type: http.WsFrameType) !void { + const page = self._page; + const target = self.asEventTarget(); + + if (page._event_manager.hasDirectListeners(target, "message", self._on_message)) { + const msg_data: MessageEvent.Data = if (frame_type == .binary and self._binary_type == .arraybuffer) + .{ .arraybuffer = .{ .values = data } } + else + .{ .string = data }; + + const event = try MessageEvent.initTrusted(.wrap("message"), .{ + .data = msg_data, + .origin = "", + }, page); + try page._event_manager.dispatchDirect(target, event.asEvent(), self._on_message, .{ .context = "WebSocket message" }); + } +} + +fn dispatchCloseEvent(self: *WebSocket, code: u16, reason: []const u8, was_clean: bool) !void { + const page = self._page; + const target = self.asEventTarget(); + + if (page._event_manager.hasDirectListeners(target, "close", self._on_close)) { + const event = try CloseEvent.initTrusted(.wrap("close"), .{ + .code = code, + .reason = reason, + .wasClean = was_clean, + }, page); + try page._event_manager.dispatchDirect(target, event.asEvent(), self._on_close, .{ .context = "WebSocket close" }); + } +} + +fn sendDataCallback(buffer: [*]u8, buf_count: usize, buf_len: usize, data: *anyopaque) usize { + if (comptime IS_DEBUG) { + std.debug.assert(buf_count == 1); + } + const conn: *http.Connection = @ptrCast(@alignCast(data)); + return _sendDataCallback(conn, buffer[0..buf_len]) catch |err| { + log.warn(.http, "WS send callback", .{ .err = err }); + return http.readfunc_pause; + }; +} + +fn _sendDataCallback(conn: *http.Connection, buf: []u8) !usize { + lp.assert(buf.len >= 2, "WS short buffer", .{ .len = buf.len }); + + const self = conn.transport.websocket; + + if (self._send_queue.items.len == 0) { + // No data to send - pause until queueMessage is called + return http.readfunc_pause; + } + + const msg = &self._send_queue.items[0]; + + switch (msg.*) { + .close => { + const code = self._close_code; + const reason = self._close_reason; + + // Close frame: 2 bytes for code (big-endian) + optional reason + // Truncate reason to fit in buf (max 123 bytes per spec) + const reason_len: usize = @min(reason.len, 123, buf.len -| 2); + const frame_len = 2 + reason_len; + const to_copy = @min(buf.len, frame_len); + + var close_payload: [125]u8 = undefined; + close_payload[0] = @intCast((code >> 8) & 0xFF); + close_payload[1] = @intCast(code & 0xFF); + if (reason_len > 0) { + @memcpy(close_payload[2..][0..reason_len], reason[0..reason_len]); + } + + try conn.wsStartFrame(.close, to_copy); + @memcpy(buf[0..to_copy], close_payload[0..to_copy]); + + _ = self._send_queue.orderedRemove(0); + return to_copy; + }, + .text => |content| return self.writeContent(conn, buf, content, .text), + .binary => |content| return self.writeContent(conn, buf, content, .binary), + } +} + +fn writeContent(self: *WebSocket, conn: *http.Connection, buf: []u8, byte_msg: Message.Content, frame_type: http.WsFrameType) !usize { + if (self._send_offset == 0) { + // start of the message + try conn.wsStartFrame(frame_type, byte_msg.data.len); + } + + const remaining = byte_msg.data[self._send_offset..]; + const to_copy = @min(remaining.len, buf.len); + @memcpy(buf[0..to_copy], remaining[0..to_copy]); + + self._send_offset += to_copy; + + if (self._send_offset >= byte_msg.data.len) { + const removed = self._send_queue.orderedRemove(0); + removed.deinit(self._page._session); + self._send_offset = 0; + } + + return to_copy; +} + +fn receivedDataCallback(buffer: [*]const u8, buf_count: usize, buf_len: usize, data: *anyopaque) usize { + if (comptime IS_DEBUG) { + std.debug.assert(buf_count == 1); + } + const conn: *http.Connection = @ptrCast(@alignCast(data)); + _receivedDataCallback(conn, buffer[0..buf_len]) catch |err| { + log.warn(.http, "WS receive callback", .{ .err = err }); + // TODO: are there errors, like an invalid frame, that we shouldn't treat + // as an error? + return http.writefunc_error; + }; + + return buf_len; +} + +fn _receivedDataCallback(conn: *http.Connection, data: []const u8) !void { + const self = conn.transport.websocket; + const meta = conn.wsMeta() orelse { + log.err(.http, "WS missing meta", .{ .url = self._url }); + return error.NoFrameMeta; + }; + + if (meta.offset == 0) { + // Start of new frame. Pre-allocate buffer + self._recv_buffer.clearRetainingCapacity(); + if (meta.len > self._http_client.max_response_size) { + return error.MessageTooLarge; + } + try self._recv_buffer.ensureTotalCapacity(self._arena, meta.len); + } + + try self._recv_buffer.appendSlice(self._arena, data); + + if (meta.bytes_left > 0) { + // still more data waiting for this frame + return; + } + + const message = self._recv_buffer.items; + switch (meta.frame_type) { + .text, .binary => try self.dispatchMessageEvent(message, meta.frame_type), + .close => { + // Parse close frame: 2-byte code (big-endian) + optional reason + self._close_code = if (message.len >= 2) + @as(u16, message[0]) << 8 | message[1] + else + 1005; // No status code received + if (message.len > 2) { + self._close_reason = try self._arena.dupe(u8, message[2..]); + } + self._ready_state = .closing; + self.disconnected(null); + }, + .ping, .pong, .cont => {}, + } +} + +// libcurl has no mechanism to signal that the connection is established. The +// best option I could come up with was looking for an upgrade header response. +fn receivedHeaderCalllback(buffer: [*]const u8, header_count: usize, buf_len: usize, data: *anyopaque) usize { + if (comptime IS_DEBUG) { + std.debug.assert(header_count == 1); + } + const conn: *http.Connection = @ptrCast(@alignCast(data)); + const self = conn.transport.websocket; + const header = buffer[0..buf_len]; + + if (self._got_101 == false and std.mem.startsWith(u8, header, "HTTP/")) { + if (std.mem.indexOf(u8, header, " 101 ")) |_| { + self._got_101 = true; + } + return buf_len; + } + + // Empty line = end of headers + if (buf_len <= 2) { + if (!self._got_101 or !self._got_upgrade) { + return 0; + } + + self._ready_state = .open; + log.info(.http, "WS connected", .{ .url = self._url }); + + self.dispatchOpenEvent() catch |err| { + log.err(.http, "WS open event fail", .{ .err = err }); + }; + return buf_len; + } + + if (self._got_upgrade) { + // dont' care about headers once we've gotten the upgrade header + return buf_len; + } + + const colon = std.mem.indexOfScalarPos(u8, header, 0, ':') orelse { + // weird, continue... + return buf_len; + }; + + if (std.ascii.eqlIgnoreCase(header[0..colon], "upgrade") == false) { + return buf_len; + } + + const value = std.mem.trim(u8, header[colon + 1 ..], " \t\r\n"); + if (std.ascii.eqlIgnoreCase(value, "websocket")) { + self._got_upgrade = true; + } + + return buf_len; +} + +const Message = union(enum) { + close, + text: Content, + binary: Content, + + const Content = struct { + arena: Allocator, + data: []const u8, + }; + fn deinit(self: Message, session: *Session) void { + switch (self) { + .text, .binary => |msg| session.releaseArena(msg.arena), + .close => {}, + } + } +}; + +pub const JsApi = struct { + pub const bridge = js.Bridge(WebSocket); + + pub const Meta = struct { + pub const name = "WebSocket"; + pub const prototype_chain = bridge.prototypeChain(); + pub var class_id: bridge.ClassId = undefined; + }; + + pub const constructor = bridge.constructor(WebSocket.init, .{ .dom_exception = true }); + + pub const CONNECTING = bridge.property(@intFromEnum(ReadyState.connecting), .{ .template = true }); + pub const OPEN = bridge.property(@intFromEnum(ReadyState.open), .{ .template = true }); + pub const CLOSING = bridge.property(@intFromEnum(ReadyState.closing), .{ .template = true }); + pub const CLOSED = bridge.property(@intFromEnum(ReadyState.closed), .{ .template = true }); + + pub const url = bridge.accessor(WebSocket.getUrl, null, .{}); + pub const readyState = bridge.accessor(WebSocket.getReadyState, null, .{}); + pub const bufferedAmount = bridge.accessor(WebSocket.getBufferedAmount, null, .{}); + pub const binaryType = bridge.accessor(WebSocket.getBinaryType, WebSocket.setBinaryType, .{}); + + pub const protocol = bridge.property("", .{ .template = false }); + pub const extensions = bridge.property("", .{ .template = false }); + + pub const onopen = bridge.accessor(WebSocket.getOnOpen, WebSocket.setOnOpen, .{}); + pub const onmessage = bridge.accessor(WebSocket.getOnMessage, WebSocket.setOnMessage, .{}); + pub const onerror = bridge.accessor(WebSocket.getOnError, WebSocket.setOnError, .{}); + pub const onclose = bridge.accessor(WebSocket.getOnClose, WebSocket.setOnClose, .{}); + + pub const send = bridge.function(WebSocket.send, .{ .dom_exception = true }); + pub const close = bridge.function(WebSocket.close, .{}); +}; + +const testing = @import("../../../testing.zig"); +test "WebApi: WebSocket" { + // TEMP since we're currently limited to 10 concurrent connections + try testing.htmlRunner("net/websocket.html", .{}); + try testing.htmlRunner("net/websocket2.html", .{}); + try testing.htmlRunner("net/websocket3.html", .{}); +} diff --git a/src/network/http.zig b/src/network/http.zig index 2bfabac0..f5adccb7 100644 --- a/src/network/http.zig +++ b/src/network/http.zig @@ -28,7 +28,9 @@ pub const ENABLE_DEBUG = false; pub const Blob = libcurl.CurlBlob; pub const WaitFd = libcurl.CurlWaitFd; +pub const readfunc_pause = libcurl.curl_readfunc_pause; pub const writefunc_error = libcurl.curl_writefunc_error; +pub const WsFrameType = libcurl.WsFrameType; const Error = libcurl.Error; @@ -211,15 +213,19 @@ pub const ResponseHead = struct { pub const Connection = struct { _easy: *libcurl.Curl, + transport: Transport, node: std.DoublyLinkedList.Node = .{}, - pub fn init( - ca_blob: ?libcurl.CurlBlob, - config: *const Config, - ) !Connection { + pub const Transport = union(enum) { + none, // used for cases that manage their own connection, e.g. telemetry + http: *@import("../browser/HttpClient.zig").Transfer, + websocket: *@import("../browser/webapi/net/WebSocket.zig"), + }; + + pub fn init(ca_blob: ?libcurl.CurlBlob, config: *const Config) !Connection { const easy = libcurl.curl_easy_init() orelse return error.FailedToInitializeEasy; - const self = Connection{ ._easy = easy }; + var self = Connection{ ._easy = easy, .transport = .none }; errdefer self.deinit(); try self.reset(config, ca_blob); @@ -299,7 +305,12 @@ pub const Connection = struct { try libcurl.curl_easy_setopt(self._easy, .user_pwd, creds.ptr); } - pub fn setCallbacks( + pub fn setConnectOnly(self: *const Connection, connect_only: bool) !void { + const value: c_long = if (connect_only) 2 else 0; + try libcurl.curl_easy_setopt(self._easy, .connect_only, value); + } + + pub fn setWriteCallback( self: *Connection, comptime data_cb: libcurl.CurlWriteFunction, ) !void { @@ -307,12 +318,49 @@ pub const Connection = struct { try libcurl.curl_easy_setopt(self._easy, .write_function, data_cb); } + pub fn setReadCallback( + self: *Connection, + comptime data_cb: libcurl.CurlReadFunction, + upload: bool, + ) !void { + try libcurl.curl_easy_setopt(self._easy, .read_data, self); + try libcurl.curl_easy_setopt(self._easy, .read_function, data_cb); + if (upload) { + try libcurl.curl_easy_setopt(self._easy, .upload, true); + } + } + + pub fn setHeaderCallback( + self: *Connection, + comptime data_cb: libcurl.CurlHeaderFunction, + ) !void { + try libcurl.curl_easy_setopt(self._easy, .header_data, self); + try libcurl.curl_easy_setopt(self._easy, .header_function, data_cb); + } + + pub const PauseFlags = packed struct { + red: bool = false, + green: bool = false, + blue: bool = false, + alpha: bool = false, + // Optional padding to match a specific size, e.g., a u32 + _padding: u28 = 0, + }; + + pub fn pause( + self: *Connection, + flags: libcurl.CurlPauseFlags, + ) !void { + try libcurl.curl_easy_pause(self._easy, flags); + } + pub fn reset( - self: *const Connection, + self: *Connection, config: *const Config, ca_blob: ?libcurl.CurlBlob, ) !void { libcurl.curl_easy_reset(self._easy); + self.transport = .none; // timeouts try libcurl.curl_easy_setopt(self._easy, .timeout_ms, config.httpTimeout()); @@ -449,12 +497,6 @@ pub const Connection = struct { }; } - pub fn getPrivate(self: *const Connection) !*anyopaque { - var private: *anyopaque = undefined; - try libcurl.curl_easy_getinfo(self._easy, .private, &private); - return private; - } - // These are headers that may not be send to the users for inteception. pub fn secretHeaders(_: *const Connection, headers: *Headers, http_headers: *const Config.HttpHeaders) !void { if (http_headers.proxy_bearer_header) |hdr| { @@ -471,6 +513,14 @@ pub const Connection = struct { try libcurl.curl_easy_perform(self._easy); return self.getResponseCode(); } + + pub fn wsStartFrame(self: *const Connection, frame_type: libcurl.WsFrameType, size: usize) !void { + try libcurl.curl_ws_start_frame(self._easy, frame_type, @intCast(size)); + } + + pub fn wsMeta(self: *const Connection) ?libcurl.WsFrameMeta { + return libcurl.curl_ws_meta(self._easy); + } }; pub const Handles = struct { @@ -508,17 +558,21 @@ pub const Handles = struct { } pub const MultiMessage = struct { - conn: Connection, + conn: *Connection, err: ?Error, }; - pub fn readMessage(self: *Handles) ?MultiMessage { + pub fn readMessage(self: *Handles) !?MultiMessage { var messages_count: c_int = 0; const msg = libcurl.curl_multi_info_read(self.multi, &messages_count) orelse return null; return switch (msg.data) { - .done => |err| .{ - .conn = .{ ._easy = msg.easy_handle }, - .err = err, + .done => |err| { + var private: *anyopaque = undefined; + try libcurl.curl_easy_getinfo(msg.easy_handle, .private, &private); + return .{ + .conn = @ptrCast(@alignCast(private)), + .err = err, + }; }, else => unreachable, }; diff --git a/src/sys/libcurl.zig b/src/sys/libcurl.zig index 0e2defe3..31587823 100644 --- a/src/sys/libcurl.zig +++ b/src/sys/libcurl.zig @@ -40,6 +40,8 @@ pub const CurlDebugFunction = fn (*Curl, CurlInfoType, [*c]u8, usize, *anyopaque pub const CurlHeaderFunction = fn ([*]const u8, usize, usize, *anyopaque) usize; pub const CurlWriteFunction = fn ([*]const u8, usize, usize, *anyopaque) usize; pub const curl_writefunc_error: usize = c.CURL_WRITEFUNC_ERROR; +pub const curl_readfunc_pause: usize = c.CURL_READFUNC_PAUSE; +pub const CurlReadFunction = fn ([*]u8, usize, usize, *anyopaque) usize; pub const FreeCallback = fn (ptr: ?*anyopaque) void; pub const StrdupCallback = fn (str: [*:0]const u8) ?[*:0]u8; @@ -98,6 +100,23 @@ pub const CurlWaitFd = extern struct { revents: CurlWaitEvents, }; +pub const CurlPauseFlags = packed struct(c_short) { + recv: bool = false, + send: bool = false, + all: bool = false, + cont: bool = false, + _reserved: u12 = 0, + + pub fn to_c(self: @This()) c_int { + var flags: c_int = 0; + if (self.recv) flags |= c.CURLPAUSE_RECV; + if (self.send) flags |= c.CURLPAUSE_SEND; + if (self.all) flags |= c.CURLPAUSE_ALL; + if (self.cont) flags |= c.CURLPAUSE_CONT; + return flags; + } +}; + comptime { const debug_cb_check: c.curl_debug_callback = struct { fn cb(handle: ?*Curl, msg_type: c.curl_infotype, raw: [*c]u8, len: usize, user: ?*anyopaque) callconv(.c) c_int { @@ -167,6 +186,10 @@ pub const CurlOption = enum(c.CURLoption) { header_function = c.CURLOPT_HEADERFUNCTION, write_data = c.CURLOPT_WRITEDATA, write_function = c.CURLOPT_WRITEFUNCTION, + read_data = c.CURLOPT_READDATA, + read_function = c.CURLOPT_READFUNCTION, + connect_only = c.CURLOPT_CONNECT_ONLY, + upload = c.CURLOPT_UPLOAD, }; pub const CurlMOption = enum(c.CURLMoption) { @@ -530,6 +553,7 @@ pub fn curl_easy_setopt(easy: *Curl, comptime option: CurlOption, value: anytype const code = switch (option) { .verbose, .post, + .upload, .http_get, .ssl_verify_host, .ssl_verify_peer, @@ -551,6 +575,7 @@ pub fn curl_easy_setopt(easy: *Curl, comptime option: CurlOption, value: anytype .max_redirs, .follow_location, .post_field_size, + .connect_only, => blk: { const n: c_long = switch (@typeInfo(@TypeOf(value))) { .comptime_int, .int => @intCast(value), @@ -593,6 +618,7 @@ pub fn curl_easy_setopt(easy: *Curl, comptime option: CurlOption, value: anytype .private, .header_data, + .read_data, .write_data, => blk: { const ptr: ?*anyopaque = switch (@typeInfo(@TypeOf(value))) { @@ -631,6 +657,22 @@ pub fn curl_easy_setopt(easy: *Curl, comptime option: CurlOption, value: anytype break :blk c.curl_easy_setopt(easy, opt, cb); }, + .read_function => blk: { + const cb: c.curl_write_callback = switch (@typeInfo(@TypeOf(value))) { + .null => null, + .@"fn" => |info| struct { + fn cb(buffer: [*c]u8, count: usize, len: usize, user: ?*anyopaque) callconv(.c) usize { + const user_arg = if (@typeInfo(info.params[3].type.?) == .optional) + user + else + user orelse unreachable; + return value(@ptrCast(buffer), count, len, user_arg); + } + }.cb, + else => @compileError("expected Zig function or null for " ++ @tagName(option) ++ ", got " ++ @typeName(@TypeOf(value))), + }; + break :blk c.curl_easy_setopt(easy, opt, cb); + }, .write_function => blk: { const cb: c.curl_write_callback = switch (@typeInfo(@TypeOf(value))) { .null => null, @@ -677,6 +719,10 @@ pub fn curl_easy_getinfo(easy: *Curl, comptime info: CurlInfo, out: anytype) Err try errorCheck(code); } +pub fn curl_easy_pause(easy: *Curl, flags: CurlPauseFlags) Error!void { + try errorCheck(c.curl_easy_pause(easy, flags.to_c())); +} + pub fn curl_easy_header( easy: *Curl, name: [*:0]const u8, @@ -804,3 +850,79 @@ pub fn curl_slist_free_all(list: ?*CurlSList) void { c.curl_slist_free_all(ptr); } } + +// WebSocket support (requires libcurl 7.86.0+) +pub const WsFrameType = enum { + text, + binary, + cont, + close, + ping, + pong, + + fn toInt(self: WsFrameType) c_uint { + return switch (self) { + .text => c.CURLWS_TEXT, + .binary => c.CURLWS_BINARY, + .cont => c.CURLWS_CONT, + .close => c.CURLWS_CLOSE, + .ping => c.CURLWS_PING, + .pong => c.CURLWS_PONG, + }; + } + + fn fromFlags(flags: c_int) WsFrameType { + const f: c_uint = @bitCast(flags); + if (f & c.CURLWS_TEXT != 0) return .text; + if (f & c.CURLWS_BINARY != 0) return .binary; + if (f & c.CURLWS_CLOSE != 0) return .close; + if (f & c.CURLWS_PING != 0) return .ping; + if (f & c.CURLWS_PONG != 0) return .pong; + if (f & c.CURLWS_CONT != 0) return .cont; + return .binary; // default fallback + } +}; + +pub const WsFrameMeta = struct { + frame_type: WsFrameType, + offset: usize, + bytes_left: usize, + len: usize, + + fn from(frame: *const c.curl_ws_frame) WsFrameMeta { + return .{ + .frame_type = WsFrameType.fromFlags(frame.flags), + .offset = @intCast(frame.offset), + .bytes_left = @intCast(frame.bytesleft), + .len = if (frame.len < 0) + std.math.maxInt(usize) + else + @intCast(frame.len), + }; + } +}; + +pub fn curl_ws_send(easy: *Curl, buffer: []const u8, sent: *usize, fragsize: CurlOffT, frame_type: WsFrameType) Error!void { + try errorCheck(c.curl_ws_send(easy, buffer.ptr, buffer.len, sent, fragsize, frame_type.toInt())); +} + +pub fn curl_ws_recv(easy: *Curl, buffer: []u8, recv: *usize, meta: *?WsFrameMeta) Error!void { + var c_meta: [*c]const c.curl_ws_frame = null; + const code = c.curl_ws_recv(easy, buffer.ptr, buffer.len, recv, &c_meta); + if (c_meta) |m| { + meta.* = WsFrameMeta.from(m); + } else { + meta.* = null; + } + try errorCheck(code); +} + +pub fn curl_ws_meta(easy: *Curl) ?WsFrameMeta { + const ptr = c.curl_ws_meta(easy); + if (ptr == null) return null; + return WsFrameMeta.from(ptr); +} + +pub fn curl_ws_start_frame(easy: *Curl, frame_type: WsFrameType, size: CurlOffT) Error!void { + try errorCheck(c.curl_ws_start_frame(easy, frame_type.toInt(), size)); +} diff --git a/src/testing.zig b/src/testing.zig index 8ff59751..8dd2eb88 100644 --- a/src/testing.zig +++ b/src/testing.zig @@ -436,19 +436,17 @@ fn runWebApiTest(test_file: [:0]const u8) !void { if (js_val.isTrue()) { return; } - switch (try runner.tick(.{ .ms = 20 })) { - .done => return error.TestNeverSignaledCompletion, - .ok => |next_ms| { - const ms_elapsed = timer.lap() / 1_000_000; - if (ms_elapsed >= wait_ms) { - return error.TestTimedOut; - } - wait_ms -= @intCast(ms_elapsed); - if (next_ms > 0) { - std.Thread.sleep(std.time.ns_per_ms * next_ms); - } - }, + const sleep_ms: usize = switch (try runner.tick(.{ .ms = 20 })) { + .done => 20, + .ok => |next_ms| @min(next_ms, 20), + }; + + const ms_elapsed = timer.lap() / 1_000_000; + if (ms_elapsed >= wait_ms) { + return error.TestTimedOut; } + wait_ms -= @intCast(ms_elapsed); + std.Thread.sleep(std.time.ns_per_ms * sleep_ms); } } @@ -476,12 +474,15 @@ pub fn pageTest(comptime test_file: []const u8, opts: PageTestOpts) !*Page { const log = @import("log.zig"); const TestHTTPServer = @import("TestHTTPServer.zig"); +const TestWSServer = @import("TestWSServer.zig"); const Server = @import("Server.zig"); var test_cdp_server: ?*Server = null; var test_cdp_server_thread: ?std.Thread = null; var test_http_server: ?TestHTTPServer = null; var test_http_server_thread: ?std.Thread = null; +var test_ws_server: ?TestWSServer = null; +var test_ws_server_thread: ?std.Thread = null; var test_config: Config = undefined; @@ -514,13 +515,16 @@ test "tests:beforeAll" { test_session = try test_browser.newSession(test_notification); var wg: std.Thread.WaitGroup = .{}; - wg.startMany(2); + wg.startMany(3); test_cdp_server_thread = try std.Thread.spawn(.{}, serveCDP, .{&wg}); test_http_server = TestHTTPServer.init(testHTTPHandler); test_http_server_thread = try std.Thread.spawn(.{}, TestHTTPServer.run, .{ &test_http_server.?, &wg }); + test_ws_server = TestWSServer.init(); + test_ws_server_thread = try std.Thread.spawn(.{}, TestWSServer.run, .{ &test_ws_server.?, &wg }); + // need to wait for the servers to be listening, else tests will fail because // they aren't able to connect. wg.wait(); @@ -545,6 +549,16 @@ test "tests:afterAll" { server.deinit(); } + if (test_ws_server) |*server| { + server.stop(); + } + if (test_ws_server_thread) |thread| { + thread.join(); + } + if (test_ws_server) |*server| { + server.deinit(); + } + @import("root").v8_peak_memory = test_browser.env.isolate.getHeapStatistics().total_physical_size; test_notification.deinit();