diff --git a/src/App.zig b/src/App.zig index 38283e10..55b9cda1 100644 --- a/src/App.zig +++ b/src/App.zig @@ -27,6 +27,7 @@ const Platform = @import("browser/js/Platform.zig"); const Notification = @import("Notification.zig"); const Telemetry = @import("telemetry/telemetry.zig").Telemetry; +const SharedState = @import("SharedState.zig"); // Container for global state / objects that various parts of the system // might need. @@ -40,6 +41,7 @@ telemetry: Telemetry, allocator: Allocator, app_dir_path: ?[]const u8, notification: *Notification, +shared: ?*SharedState = null, shutdown: bool = false, pub const RunMode = enum { @@ -59,6 +61,8 @@ pub const Config = struct { http_max_host_open: ?u8 = null, http_max_concurrent: ?u8 = null, user_agent: [:0]const u8, + max_sessions: u32 = 10, // Max concurrent CDP connections + session_memory_limit: usize = 64 * 1024 * 1024, // 64MB per session }; pub fn init(allocator: Allocator, config: Config) !*App { @@ -67,6 +71,8 @@ pub fn init(allocator: Allocator, config: Config) !*App { app.config = config; app.allocator = allocator; + app.shared = null; + app.shutdown = false; app.notification = try Notification.init(allocator, null); errdefer app.notification.deinit(); @@ -105,6 +111,12 @@ pub fn deinit(self: *App) void { } const allocator = self.allocator; + + if (self.shared) |shared| { + shared.deinit(); + self.shared = null; + } + if (self.app_dir_path) |app_dir_path| { allocator.free(app_dir_path); self.app_dir_path = null; @@ -118,6 +130,31 @@ pub fn deinit(self: *App) void { allocator.destroy(self); } +/// Create SharedState for multi-session server mode. +/// This must be called before starting the server for multi-CDP support. +pub fn createSharedState(self: *App) !*SharedState { + if (self.shared != null) { + return error.SharedStateAlreadyExists; + } + + const shared = try SharedState.init(self.allocator, .{ + .run_mode = self.config.run_mode, + .tls_verify_host = self.config.tls_verify_host, + .http_proxy = self.config.http_proxy, + .proxy_bearer_token = self.config.proxy_bearer_token, + .http_timeout_ms = self.config.http_timeout_ms, + .http_connect_timeout_ms = self.config.http_connect_timeout_ms, + .http_max_host_open = self.config.http_max_host_open, + .http_max_concurrent = self.config.http_max_concurrent, + .user_agent = self.config.user_agent, + .max_sessions = self.config.max_sessions, + .session_memory_limit = self.config.session_memory_limit, + }); + + self.shared = shared; + return shared; +} + fn getAndMakeAppDir(allocator: Allocator) ?[]const u8 { if (@import("builtin").is_test) { return allocator.dupe(u8, "/tmp") catch unreachable; diff --git a/src/LimitedAllocator.zig b/src/LimitedAllocator.zig new file mode 100644 index 00000000..4445d5dc --- /dev/null +++ b/src/LimitedAllocator.zig @@ -0,0 +1,186 @@ +// Copyright (C) 2023-2025 Lightpanda (Selecy SAS) +// +// Francis Bouvier +// Pierre Tachoire +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as +// published by the Free Software Foundation, either version 3 of the +// License, or (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. 
If not, see . + +const std = @import("std"); +const Allocator = std.mem.Allocator; + +/// Per-session memory limiting allocator. +/// Wraps a backing allocator and enforces a maximum memory limit. +/// Thread-local: each SessionThread creates its own LimitedAllocator. +const LimitedAllocator = @This(); + +backing: Allocator, +bytes_allocated: usize, +max_bytes: usize, + +pub fn init(backing: Allocator, max_bytes: usize) LimitedAllocator { + return .{ + .backing = backing, + .bytes_allocated = 0, + .max_bytes = max_bytes, + }; +} + +pub fn allocator(self: *LimitedAllocator) Allocator { + return .{ + .ptr = self, + .vtable = &vtable, + }; +} + +pub fn bytesAllocated(self: *const LimitedAllocator) usize { + return self.bytes_allocated; +} + +pub fn bytesRemaining(self: *const LimitedAllocator) usize { + return self.max_bytes -| self.bytes_allocated; +} + +const vtable: Allocator.VTable = .{ + .alloc = alloc, + .resize = resize, + .remap = remap, + .free = free, +}; + +fn alloc(ctx: *anyopaque, len: usize, alignment: std.mem.Alignment, ret_addr: usize) ?[*]u8 { + const self: *LimitedAllocator = @ptrCast(@alignCast(ctx)); + + if (self.bytes_allocated +| len > self.max_bytes) { + return null; // Out of memory for this session + } + + const result = self.backing.rawAlloc(len, alignment, ret_addr); + if (result != null) { + self.bytes_allocated += len; + } + return result; +} + +fn resize(ctx: *anyopaque, buf: []u8, alignment: std.mem.Alignment, new_len: usize, ret_addr: usize) bool { + const self: *LimitedAllocator = @ptrCast(@alignCast(ctx)); + + if (new_len > buf.len) { + const additional = new_len - buf.len; + if (self.bytes_allocated +| additional > self.max_bytes) { + return false; // Would exceed limit + } + } + + if (self.backing.rawResize(buf, alignment, new_len, ret_addr)) { + if (new_len > buf.len) { + self.bytes_allocated += new_len - buf.len; + } else { + self.bytes_allocated -= buf.len - new_len; + } + return true; + } + return false; +} + +fn remap(ctx: *anyopaque, buf: []u8, alignment: std.mem.Alignment, new_len: usize, ret_addr: usize) ?[*]u8 { + const self: *LimitedAllocator = @ptrCast(@alignCast(ctx)); + + if (new_len > buf.len) { + const additional = new_len - buf.len; + if (self.bytes_allocated +| additional > self.max_bytes) { + return null; // Would exceed limit + } + } + + const result = self.backing.rawRemap(buf, alignment, new_len, ret_addr); + if (result != null) { + if (new_len > buf.len) { + self.bytes_allocated += new_len - buf.len; + } else { + self.bytes_allocated -= buf.len - new_len; + } + } + return result; +} + +fn free(ctx: *anyopaque, buf: []u8, alignment: std.mem.Alignment, ret_addr: usize) void { + const self: *LimitedAllocator = @ptrCast(@alignCast(ctx)); + self.bytes_allocated -|= buf.len; + self.backing.rawFree(buf, alignment, ret_addr); +} + +const testing = std.testing; + +test "LimitedAllocator: basic allocation" { + var limited = LimitedAllocator.init(testing.allocator, 1024); + const alloc_ = limited.allocator(); + + const slice = try alloc_.alloc(u8, 100); + defer alloc_.free(slice); + + try testing.expectEqual(100, limited.bytesAllocated()); + try testing.expectEqual(924, limited.bytesRemaining()); +} + +test "LimitedAllocator: exceeds limit" { + var limited = LimitedAllocator.init(testing.allocator, 100); + const alloc_ = limited.allocator(); + + // Allocation should fail with OutOfMemory when exceeding limit + try testing.expectError(error.OutOfMemory, alloc_.alloc(u8, 200)); + try testing.expectEqual(0, limited.bytesAllocated()); +} + +test 
"LimitedAllocator: free updates counter" { + var limited = LimitedAllocator.init(testing.allocator, 1024); + const alloc_ = limited.allocator(); + + const slice = try alloc_.alloc(u8, 100); + try testing.expectEqual(100, limited.bytesAllocated()); + + alloc_.free(slice); + try testing.expectEqual(0, limited.bytesAllocated()); +} + +test "LimitedAllocator: multiple allocations" { + var limited = LimitedAllocator.init(testing.allocator, 1024); + const alloc_ = limited.allocator(); + + const s1 = try alloc_.alloc(u8, 100); + const s2 = try alloc_.alloc(u8, 200); + const s3 = try alloc_.alloc(u8, 300); + + try testing.expectEqual(600, limited.bytesAllocated()); + + alloc_.free(s2); + try testing.expectEqual(400, limited.bytesAllocated()); + + alloc_.free(s1); + alloc_.free(s3); + try testing.expectEqual(0, limited.bytesAllocated()); +} + +test "LimitedAllocator: allocation at limit boundary" { + var limited = LimitedAllocator.init(testing.allocator, 100); + const alloc_ = limited.allocator(); + + const s1 = try alloc_.alloc(u8, 50); + defer alloc_.free(s1); + + const s2 = try alloc_.alloc(u8, 50); + defer alloc_.free(s2); + + // Should fail - at limit + try testing.expectError(error.OutOfMemory, alloc_.alloc(u8, 1)); +} diff --git a/src/Server.zig b/src/Server.zig index a557d8ed..e94eb272 100644 --- a/src/Server.zig +++ b/src/Server.zig @@ -24,38 +24,36 @@ const net = std.net; const posix = std.posix; const Allocator = std.mem.Allocator; -const ArenaAllocator = std.heap.ArenaAllocator; const log = @import("log.zig"); -const App = @import("App.zig"); -const CDP = @import("cdp/cdp.zig").CDP; - -const MAX_HTTP_REQUEST_SIZE = 4096; - -// max message size -// +14 for max websocket payload overhead -// +140 for the max control packet that might be interleaved in a message -const MAX_MESSAGE_SIZE = 512 * 1024 + 14 + 140; +const SharedState = @import("SharedState.zig"); +const SessionThread = @import("SessionThread.zig"); +const SessionManager = @import("SessionManager.zig"); const Server = @This(); -app: *App, + +shared: *SharedState, shutdown: bool = false, allocator: Allocator, -client: ?posix.socket_t, listener: ?posix.socket_t, +session_manager: SessionManager, json_version_response: []const u8, +timeout_ms: u32, +session_memory_limit: usize, -pub fn init(app: *App, address: net.Address) !Server { - const allocator = app.allocator; +pub fn init(shared: *SharedState, address: net.Address, max_sessions: u32, session_memory_limit: usize) !Server { + const allocator = shared.allocator; const json_version_response = try buildJSONVersionResponse(allocator, address); errdefer allocator.free(json_version_response); return .{ - .app = app, - .client = null, + .shared = shared, .listener = null, .allocator = allocator, + .session_manager = SessionManager.init(allocator, max_sessions), .json_version_response = json_version_response, + .timeout_ms = 0, + .session_memory_limit = session_memory_limit, }; } @@ -65,6 +63,9 @@ pub fn stop(self: *Server) void { return; } + // Stop all active sessions + self.session_manager.stopAll(); + // Linux and BSD/macOS handle canceling a socket blocked on accept differently. // For Linux, we use std.shutdown, which will cause accept to return error.SocketNotListening (EINVAL). // For BSD, shutdown will return an error. Instead we call posix.close, which will result with error.ConnectionAborted (BADF). 
@@ -81,16 +82,18 @@ pub fn stop(self: *Server) void { } pub fn deinit(self: *Server) void { + self.session_manager.deinit(); + if (self.listener) |listener| { posix.close(listener); self.listener = null; } - // *if* server.run is running, we should really wait for it to return - // before existing from here. self.allocator.free(self.json_version_response); } pub fn run(self: *Server, address: net.Address, timeout_ms: u32) !void { + self.timeout_ms = timeout_ms; + const flags = posix.SOCK.STREAM | posix.SOCK.CLOEXEC; const listener = try posix.socket(address.any.family, flags, posix.IPPROTO.TCP); self.listener = listener; @@ -101,9 +104,11 @@ pub fn run(self: *Server, address: net.Address, timeout_ms: u32) !void { } try posix.bind(listener, &address.any, address.getOsSockLen()); - try posix.listen(listener, 1); + // Increase backlog from 1 to 128 to support multiple concurrent connections + try posix.listen(listener, 128); log.info(.app, "server running", .{ .address = address }); + while (!@atomicLoad(bool, &self.shutdown, .monotonic)) { const socket = posix.accept(listener, null, null, posix.SOCK.NONBLOCK) catch |err| { switch (err) { @@ -119,848 +124,55 @@ pub fn run(self: *Server, address: net.Address, timeout_ms: u32) !void { } }; - self.client = socket; - defer if (self.client) |s| { - posix.close(s); - self.client = null; - }; - if (log.enabled(.app, .info)) { var client_address: std.net.Address = undefined; var socklen: posix.socklen_t = @sizeOf(net.Address); - try std.posix.getsockname(socket, &client_address.any, &socklen); + posix.getsockname(socket, &client_address.any, &socklen) catch {}; log.info(.app, "client connected", .{ .ip = client_address }); } - self.readLoop(socket, timeout_ms) catch |err| { - log.err(.app, "CDP client loop", .{ .err = err }); + // Spawn a new session thread for this connection + const session = SessionThread.spawn( + self.shared, + &self.session_manager, + socket, + timeout_ms, + self.json_version_response, + self.session_memory_limit, + ) catch |err| { + log.err(.app, "spawn session", .{ .err = err }); + posix.close(socket); + continue; }; - } -} -fn readLoop(self: *Server, socket: posix.socket_t, timeout_ms: u32) !void { - // This shouldn't be necessary, but the Client is HUGE (> 512KB) because - // it has a large read buffer. I don't know why, but v8 crashes if this - // is on the stack (and I assume it's related to its size). 
- const client = try self.allocator.create(Client); - defer self.allocator.destroy(client); - - client.* = try Client.init(socket, self); - defer client.deinit(); - - var http = &self.app.http; - http.addCDPClient(.{ - .socket = socket, - .ctx = client, - .blocking_read_start = Client.blockingReadStart, - .blocking_read = Client.blockingRead, - .blocking_read_end = Client.blockingReadStop, - }); - defer http.removeCDPClient(); - - lp.assert(client.mode == .http, "Server.readLoop invalid mode", .{}); - while (true) { - if (http.poll(timeout_ms) != .cdp_socket) { - log.info(.app, "CDP timeout", .{}); - return; - } - - if (client.readSocket() == false) { - return; - } - - if (client.mode == .cdp) { - break; // switch to our CDP loop - } - } - - var cdp = &client.mode.cdp; - var last_message = timestamp(.monotonic); - var ms_remaining = timeout_ms; - while (true) { - switch (cdp.pageWait(ms_remaining)) { - .cdp_socket => { - if (client.readSocket() == false) { - return; - } - last_message = timestamp(.monotonic); - ms_remaining = timeout_ms; + self.session_manager.add(session) catch |err| switch (err) { + error.TooManySessions => { + log.warn(.app, "too many sessions", .{ .count = self.session_manager.count() }); + sendServiceUnavailable(socket); + session.stop(); + session.join(); + session.deinit(); }, - .no_page => { - if (http.poll(ms_remaining) != .cdp_socket) { - log.info(.app, "CDP timeout", .{}); - return; - } - if (client.readSocket() == false) { - return; - } - last_message = timestamp(.monotonic); - ms_remaining = timeout_ms; + else => { + log.err(.app, "add session", .{ .err = err }); + session.stop(); + session.join(); + session.deinit(); }, - .done => { - const elapsed = timestamp(.monotonic) - last_message; - if (elapsed > ms_remaining) { - log.info(.app, "CDP timeout", .{}); - return; - } - ms_remaining -= @intCast(elapsed); - }, - .navigate => unreachable, // must have been handled by the session - } + }; } } -pub const Client = struct { - // The client is initially serving HTTP requests but, under normal circumstances - // should eventually be upgraded to a websocket connections - mode: union(enum) { - http: void, - cdp: CDP, - }, - - server: *Server, - reader: Reader(true), - socket: posix.socket_t, - socket_flags: usize, - send_arena: ArenaAllocator, - - const EMPTY_PONG = [_]u8{ 138, 0 }; - - // CLOSE, 2 length, code - const CLOSE_NORMAL = [_]u8{ 136, 2, 3, 232 }; // code: 1000 - const CLOSE_TOO_BIG = [_]u8{ 136, 2, 3, 241 }; // 1009 - const CLOSE_PROTOCOL_ERROR = [_]u8{ 136, 2, 3, 234 }; //code: 1002 - // "private-use" close codes must be from 4000-49999 - const CLOSE_TIMEOUT = [_]u8{ 136, 2, 15, 160 }; // code: 4000 - - fn init(socket: posix.socket_t, server: *Server) !Client { - const socket_flags = try posix.fcntl(socket, posix.F.GETFL, 0); - const nonblocking = @as(u32, @bitCast(posix.O{ .NONBLOCK = true })); - // we expect the socket to come to us as nonblocking - lp.assert(socket_flags & nonblocking == nonblocking, "Client.init blocking", .{}); - - var reader = try Reader(true).init(server.allocator); - errdefer reader.deinit(); - - return .{ - .socket = socket, - .server = server, - .reader = reader, - .mode = .{ .http = {} }, - .socket_flags = socket_flags, - .send_arena = ArenaAllocator.init(server.allocator), - }; - } - - fn deinit(self: *Client) void { - switch (self.mode) { - .cdp => |*cdp| cdp.deinit(), - .http => {}, - } - self.reader.deinit(); - self.send_arena.deinit(); - } - - fn blockingReadStart(ctx: *anyopaque) bool { - const self: *Client = 
@ptrCast(@alignCast(ctx)); - _ = posix.fcntl(self.socket, posix.F.SETFL, self.socket_flags & ~@as(u32, @bitCast(posix.O{ .NONBLOCK = true }))) catch |err| { - log.warn(.app, "CDP blockingReadStart", .{ .err = err }); - return false; - }; - return true; - } - - fn blockingRead(ctx: *anyopaque) bool { - const self: *Client = @ptrCast(@alignCast(ctx)); - return self.readSocket(); - } - - fn blockingReadStop(ctx: *anyopaque) bool { - const self: *Client = @ptrCast(@alignCast(ctx)); - _ = posix.fcntl(self.socket, posix.F.SETFL, self.socket_flags) catch |err| { - log.warn(.app, "CDP blockingReadStop", .{ .err = err }); - return false; - }; - return true; - } - - fn readSocket(self: *Client) bool { - const n = posix.read(self.socket, self.readBuf()) catch |err| { - log.warn(.app, "CDP read", .{ .err = err }); - return false; - }; - - if (n == 0) { - log.info(.app, "CDP disconnect", .{}); - return false; - } - - return self.processData(n) catch false; - } - - fn readBuf(self: *Client) []u8 { - return self.reader.readBuf(); - } - - fn processData(self: *Client, len: usize) !bool { - self.reader.len += len; - - switch (self.mode) { - .cdp => |*cdp| return self.processWebsocketMessage(cdp), - .http => return self.processHTTPRequest(), - } - } - - fn processHTTPRequest(self: *Client) !bool { - lp.assert(self.reader.pos == 0, "Client.HTTP pos", .{ .pos = self.reader.pos }); - const request = self.reader.buf[0..self.reader.len]; - - if (request.len > MAX_HTTP_REQUEST_SIZE) { - self.writeHTTPErrorResponse(413, "Request too large"); - return error.RequestTooLarge; - } - - // we're only expecting [body-less] GET requests. - if (std.mem.endsWith(u8, request, "\r\n\r\n") == false) { - // we need more data, put any more data here - return true; - } - - // the next incoming data can go to the front of our buffer - defer self.reader.len = 0; - return self.handleHTTPRequest(request) catch |err| { - switch (err) { - error.NotFound => self.writeHTTPErrorResponse(404, "Not found"), - error.InvalidRequest => self.writeHTTPErrorResponse(400, "Invalid request"), - error.InvalidProtocol => self.writeHTTPErrorResponse(400, "Invalid HTTP protocol"), - error.MissingHeaders => self.writeHTTPErrorResponse(400, "Missing required header"), - error.InvalidUpgradeHeader => self.writeHTTPErrorResponse(400, "Unsupported upgrade type"), - error.InvalidVersionHeader => self.writeHTTPErrorResponse(400, "Invalid websocket version"), - error.InvalidConnectionHeader => self.writeHTTPErrorResponse(400, "Invalid connection header"), - else => { - log.err(.app, "server 500", .{ .err = err, .req = request[0..@min(100, request.len)] }); - self.writeHTTPErrorResponse(500, "Internal Server Error"); - }, - } - return err; - }; - } - - fn handleHTTPRequest(self: *Client, request: []u8) !bool { - if (request.len < 18) { - // 18 is [generously] the smallest acceptable HTTP request - return error.InvalidRequest; - } - - if (std.mem.eql(u8, request[0..4], "GET ") == false) { - return error.NotFound; - } - - const url_end = std.mem.indexOfScalarPos(u8, request, 4, ' ') orelse { - return error.InvalidRequest; - }; - - const url = request[4..url_end]; - - if (std.mem.eql(u8, url, "/")) { - try self.upgradeConnection(request); - return true; - } - - if (std.mem.eql(u8, url, "/json/version")) { - try self.send(self.server.json_version_response); - // Chromedp (a Go driver) does an http request to /json/version - // then to / (websocket upgrade) using a different connection. 
- // Since we only allow 1 connection at a time, the 2nd one (the - // websocket upgrade) blocks until the first one times out. - // We can avoid that by closing the connection. json_version_response - // has a Connection: Close header too. - try posix.shutdown(self.socket, .recv); - return false; - } - - return error.NotFound; - } - - fn upgradeConnection(self: *Client, request: []u8) !void { - // our caller already confirmed that we have a trailing \r\n\r\n - const request_line_end = std.mem.indexOfScalar(u8, request, '\r') orelse unreachable; - const request_line = request[0..request_line_end]; - - if (!std.ascii.endsWithIgnoreCase(request_line, "http/1.1")) { - return error.InvalidProtocol; - } - - // we need to extract the sec-websocket-key value - var key: []const u8 = ""; - - // we need to make sure that we got all the necessary headers + values - var required_headers: u8 = 0; - - // can't std.mem.split because it forces the iterated value to be const - // (we could @constCast...) - - var buf = request[request_line_end + 2 ..]; - - while (buf.len > 4) { - const index = std.mem.indexOfScalar(u8, buf, '\r') orelse unreachable; - const separator = std.mem.indexOfScalar(u8, buf[0..index], ':') orelse return error.InvalidRequest; - - const name = std.mem.trim(u8, toLower(buf[0..separator]), &std.ascii.whitespace); - const value = std.mem.trim(u8, buf[(separator + 1)..index], &std.ascii.whitespace); - - if (std.mem.eql(u8, name, "upgrade")) { - if (!std.ascii.eqlIgnoreCase("websocket", value)) { - return error.InvalidUpgradeHeader; - } - required_headers |= 1; - } else if (std.mem.eql(u8, name, "sec-websocket-version")) { - if (value.len != 2 or value[0] != '1' or value[1] != '3') { - return error.InvalidVersionHeader; - } - required_headers |= 2; - } else if (std.mem.eql(u8, name, "connection")) { - // find if connection header has upgrade in it, example header: - // Connection: keep-alive, Upgrade - if (std.ascii.indexOfIgnoreCase(value, "upgrade") == null) { - return error.InvalidConnectionHeader; - } - required_headers |= 4; - } else if (std.mem.eql(u8, name, "sec-websocket-key")) { - key = value; - required_headers |= 8; - } - - const next = index + 2; - buf = buf[next..]; - } - - if (required_headers != 15) { - return error.MissingHeaders; - } - - // our caller has already made sure this request ended in \r\n\r\n - // so it isn't something we need to check again - - const allocator = self.send_arena.allocator(); - - const response = blk: { - // Response to an ugprade request is always this, with - // the Sec-Websocket-Accept value a spacial sha1 hash of the - // request "sec-websocket-version" and a magic value. - - const template = - "HTTP/1.1 101 Switching Protocols\r\n" ++ - "Upgrade: websocket\r\n" ++ - "Connection: upgrade\r\n" ++ - "Sec-Websocket-Accept: 0000000000000000000000000000\r\n\r\n"; - - // The response will be sent via the IO Loop and thus has to have its - // own lifetime. - const res = try allocator.dupe(u8, template); - - // magic response - const key_pos = res.len - 32; - var h: [20]u8 = undefined; - var hasher = std.crypto.hash.Sha1.init(.{}); - hasher.update(key); - // websocket spec always used this value - hasher.update("258EAFA5-E914-47DA-95CA-C5AB0DC85B11"); - hasher.final(&h); - - _ = std.base64.standard.Encoder.encode(res[key_pos .. 
key_pos + 28], h[0..]); - - break :blk res; - }; - - self.mode = .{ .cdp = try CDP.init(self.server.app, self) }; - return self.send(response); - } - - fn writeHTTPErrorResponse(self: *Client, comptime status: u16, comptime body: []const u8) void { - const response = std.fmt.comptimePrint( - "HTTP/1.1 {d} \r\nConnection: Close\r\nContent-Length: {d}\r\n\r\n{s}", - .{ status, body.len, body }, - ); - - // we're going to close this connection anyways, swallowing any - // error seems safe - self.send(response) catch {}; - } - - fn processWebsocketMessage(self: *Client, cdp: *CDP) !bool { - var reader = &self.reader; - while (true) { - const msg = reader.next() catch |err| { - switch (err) { - error.TooLarge => self.send(&CLOSE_TOO_BIG) catch {}, - error.NotMasked => self.send(&CLOSE_PROTOCOL_ERROR) catch {}, - error.ReservedFlags => self.send(&CLOSE_PROTOCOL_ERROR) catch {}, - error.InvalidMessageType => self.send(&CLOSE_PROTOCOL_ERROR) catch {}, - error.ControlTooLarge => self.send(&CLOSE_PROTOCOL_ERROR) catch {}, - error.InvalidContinuation => self.send(&CLOSE_PROTOCOL_ERROR) catch {}, - error.NestedFragementation => self.send(&CLOSE_PROTOCOL_ERROR) catch {}, - error.OutOfMemory => {}, // don't borther trying to send an error in this case - } - return err; - } orelse break; - - switch (msg.type) { - .pong => {}, - .ping => try self.sendPong(msg.data), - .close => { - self.send(&CLOSE_NORMAL) catch {}; - return false; - }, - .text, .binary => if (cdp.handleMessage(msg.data) == false) { - return false; - }, - } - if (msg.cleanup_fragment) { - reader.cleanup(); - } - } - - // We might have read part of the next message. Our reader potentially - // has to move data around in its buffer to make space. - reader.compact(); - return true; - } - - fn sendPong(self: *Client, data: []const u8) !void { - if (data.len == 0) { - return self.send(&EMPTY_PONG); - } - var header_buf: [10]u8 = undefined; - const header = websocketHeader(&header_buf, .pong, data.len); - - const allocator = self.send_arena.allocator(); - var framed = try allocator.alloc(u8, header.len + data.len); - @memcpy(framed[0..header.len], header); - @memcpy(framed[header.len..], data); - return self.send(framed); - } - - // called by CDP - // Websocket frames have a variable length header. For server-client, - // it could be anywhere from 2 to 10 bytes. Our IO.Loop doesn't have - // writev, so we need to get creative. We'll JSON serialize to a - // buffer, where the first 10 bytes are reserved. We can then backfill - // the header and send the slice. - pub fn sendJSON(self: *Client, message: anytype, opts: std.json.Stringify.Options) !void { - const allocator = self.send_arena.allocator(); - - var aw = try std.Io.Writer.Allocating.initCapacity(allocator, 512); - - // reserve space for the maximum possible header - try aw.writer.writeAll(&.{ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 }); - try std.json.Stringify.value(message, opts, &aw.writer); - const framed = fillWebsocketHeader(aw.toArrayList()); - return self.send(framed); - } - - pub fn sendJSONRaw( - self: *Client, - buf: std.ArrayListUnmanaged(u8), - ) !void { - // Dangerous API!. We assume the caller has reserved the first 10 - // bytes in `buf`. 
- const framed = fillWebsocketHeader(buf); - return self.send(framed); - } - - fn send(self: *Client, data: []const u8) !void { - var pos: usize = 0; - var changed_to_blocking: bool = false; - defer _ = self.send_arena.reset(.{ .retain_with_limit = 1024 * 32 }); - - defer if (changed_to_blocking) { - // We had to change our socket to blocking me to get our write out - // We need to change it back to non-blocking. - _ = posix.fcntl(self.socket, posix.F.SETFL, self.socket_flags) catch |err| { - log.err(.app, "CDP restore nonblocking", .{ .err = err }); - }; - }; - - LOOP: while (pos < data.len) { - const written = posix.write(self.socket, data[pos..]) catch |err| switch (err) { - error.WouldBlock => { - // self.socket is nonblocking, because we don't want to block - // reads. But our life is a lot easier if we block writes, - // largely, because we don't have to maintain a queue of pending - // writes (which would each need their own allocations). So - // if we get a WouldBlock error, we'll switch the socket to - // blocking and switch it back to non-blocking after the write - // is complete. Doesn't seem particularly efficiently, but - // this should virtually never happen. - lp.assert(changed_to_blocking == false, "Client.double block", .{}); - changed_to_blocking = true; - _ = try posix.fcntl(self.socket, posix.F.SETFL, self.socket_flags & ~@as(u32, @bitCast(posix.O{ .NONBLOCK = true }))); - continue :LOOP; - }, - else => return err, - }; - - if (written == 0) { - return error.Closed; - } - pos += written; - } - } -}; - -// WebSocket message reader. Given websocket message, acts as an iterator that -// can return zero or more Messages. When next returns null, any incomplete -// message will remain in reader.data -fn Reader(comptime EXPECT_MASK: bool) type { - return struct { - allocator: Allocator, - - // position in buf of the start of the next message - pos: usize = 0, - - // position in buf up until where we have valid data - // (any new reads must be placed after this) - len: usize = 0, - - // we add 140 to allow 1 control message (ping/pong/close) to be - // fragmented into a normal message. - buf: []u8, - - fragments: ?Fragments = null, - - const Self = @This(); - - fn init(allocator: Allocator) !Self { - const buf = try allocator.alloc(u8, 16 * 1024); - return .{ - .buf = buf, - .allocator = allocator, - }; - } - - fn deinit(self: *Self) void { - self.cleanup(); - self.allocator.free(self.buf); - } - - fn cleanup(self: *Self) void { - if (self.fragments) |*f| { - f.message.deinit(self.allocator); - self.fragments = null; - } - } - - fn readBuf(self: *Self) []u8 { - // We might have read a partial http or websocket message. - // Subsequent reads must read from where we left off. 
- return self.buf[self.len..]; - } - - fn next(self: *Self) !?Message { - LOOP: while (true) { - var buf = self.buf[self.pos..self.len]; - - const length_of_len, const message_len = extractLengths(buf) orelse { - // we don't have enough bytes - return null; - }; - - const byte1 = buf[0]; - - if (byte1 & 112 != 0) { - return error.ReservedFlags; - } - - if (comptime EXPECT_MASK) { - if (buf[1] & 128 != 128) { - // client -> server messages _must_ be masked - return error.NotMasked; - } - } else if (buf[1] & 128 != 0) { - // server -> client are never masked - return error.Masked; - } - - var is_control = false; - var is_continuation = false; - var message_type: Message.Type = undefined; - switch (byte1 & 15) { - 0 => is_continuation = true, - 1 => message_type = .text, - 2 => message_type = .binary, - 8 => { - is_control = true; - message_type = .close; - }, - 9 => { - is_control = true; - message_type = .ping; - }, - 10 => { - is_control = true; - message_type = .pong; - }, - else => return error.InvalidMessageType, - } - - if (is_control) { - if (message_len > 125) { - return error.ControlTooLarge; - } - } else if (message_len > MAX_MESSAGE_SIZE) { - return error.TooLarge; - } else if (message_len > self.buf.len) { - const len = self.buf.len; - self.buf = try growBuffer(self.allocator, self.buf, message_len); - buf = self.buf[0..len]; - // we need more data - return null; - } else if (buf.len < message_len) { - // we need more data - return null; - } - - // prefix + length_of_len + mask - const header_len = 2 + length_of_len + if (comptime EXPECT_MASK) 4 else 0; - - const payload = buf[header_len..message_len]; - if (comptime EXPECT_MASK) { - mask(buf[header_len - 4 .. header_len], payload); - } - - // whatever happens after this, we know where the next message starts - self.pos += message_len; - - const fin = byte1 & 128 == 128; - - if (is_continuation) { - const fragments = &(self.fragments orelse return error.InvalidContinuation); - if (fragments.message.items.len + message_len > MAX_MESSAGE_SIZE) { - return error.TooLarge; - } - - try fragments.message.appendSlice(self.allocator, payload); - - if (fin == false) { - // maybe we have more parts of the message waiting - continue :LOOP; - } - - // this continuation is done! - return .{ - .type = fragments.type, - .data = fragments.message.items, - .cleanup_fragment = true, - }; - } - - const can_be_fragmented = message_type == .text or message_type == .binary; - if (self.fragments != null and can_be_fragmented) { - // if this isn't a continuation, then we can't have fragments - return error.NestedFragementation; - } - - if (fin == false) { - if (can_be_fragmented == false) { - return error.InvalidContinuation; - } - - // not continuation, and not fin. It has to be the first message - // in a fragmented message. 
- var fragments = Fragments{ .message = .{}, .type = message_type }; - try fragments.message.appendSlice(self.allocator, payload); - self.fragments = fragments; - continue :LOOP; - } - - return .{ - .data = payload, - .type = message_type, - .cleanup_fragment = false, - }; - } - } - - fn extractLengths(buf: []const u8) ?struct { usize, usize } { - if (buf.len < 2) { - return null; - } - - const length_of_len: usize = switch (buf[1] & 127) { - 126 => 2, - 127 => 8, - else => 0, - }; - - if (buf.len < length_of_len + 2) { - // we definitely don't have enough buf yet - return null; - } - - const message_len = switch (length_of_len) { - 2 => @as(u16, @intCast(buf[3])) | @as(u16, @intCast(buf[2])) << 8, - 8 => @as(u64, @intCast(buf[9])) | @as(u64, @intCast(buf[8])) << 8 | @as(u64, @intCast(buf[7])) << 16 | @as(u64, @intCast(buf[6])) << 24 | @as(u64, @intCast(buf[5])) << 32 | @as(u64, @intCast(buf[4])) << 40 | @as(u64, @intCast(buf[3])) << 48 | @as(u64, @intCast(buf[2])) << 56, - else => buf[1] & 127, - } + length_of_len + 2 + if (comptime EXPECT_MASK) 4 else 0; // +2 for header prefix, +4 for mask; - - return .{ length_of_len, message_len }; - } - - // This is called after we've processed complete websocket messages (this - // only applies to websocket messages). - // There are three cases: - // 1 - We don't have any incomplete data (for a subsequent message) in buf. - // This is the easier to handle, we can set pos & len to 0. - // 2 - We have part of the next message, but we know it'll fit in the - // remaining buf. We don't need to do anything - // 3 - We have part of the next message, but either it won't fight into the - // remaining buffer, or we don't know (because we don't have enough - // of the header to tell the length). We need to "compact" the buffer - fn compact(self: *Self) void { - const pos = self.pos; - const len = self.len; - - lp.assert(pos <= len, "Client.Reader.compact precondition", .{ .pos = pos, .len = len }); - - // how many (if any) partial bytes do we have - const partial_bytes = len - pos; - - if (partial_bytes == 0) { - // We have no partial bytes. Setting these to 0 ensures that we - // get the best utilization of our buffer - self.pos = 0; - self.len = 0; - return; - } - - const partial = self.buf[pos..len]; - - // If we have enough bytes of the next message to tell its length - // we'll be able to figure out whether we need to do anything or not. - if (extractLengths(partial)) |length_meta| { - const next_message_len = length_meta.@"1"; - // if this isn't true, then we have a full message and it - // should have been processed. - lp.assert(pos <= len, "Client.Reader.compact postcondition", .{ .next_len = next_message_len, .partial = partial_bytes }); - - const missing_bytes = next_message_len - partial_bytes; - - const free_space = self.buf.len - len; - if (missing_bytes < free_space) { - // we have enough space in our buffer, as is, - return; - } - } - - // We're here because we either don't have enough bytes of the next - // message, or we know that it won't fit in our buffer as-is. 
- std.mem.copyForwards(u8, self.buf, partial); - self.pos = 0; - self.len = partial_bytes; - } - }; +fn sendServiceUnavailable(socket: posix.socket_t) void { + const response = + "HTTP/1.1 503 Service Unavailable\r\n" ++ + "Connection: Close\r\n" ++ + "Content-Length: 31\r\n\r\n" ++ + "Too many concurrent connections"; + _ = posix.write(socket, response) catch {}; + posix.close(socket); } -fn growBuffer(allocator: Allocator, buf: []u8, required_capacity: usize) ![]u8 { - // from std.ArrayList - var new_capacity = buf.len; - while (true) { - new_capacity +|= new_capacity / 2 + 8; - if (new_capacity >= required_capacity) break; - } - - log.debug(.app, "CDP buffer growth", .{ .from = buf.len, .to = new_capacity }); - - if (allocator.resize(buf, new_capacity)) { - return buf.ptr[0..new_capacity]; - } - const new_buffer = try allocator.alloc(u8, new_capacity); - @memcpy(new_buffer[0..buf.len], buf); - allocator.free(buf); - return new_buffer; -} - -const Fragments = struct { - type: Message.Type, - message: std.ArrayListUnmanaged(u8), -}; - -const Message = struct { - type: Type, - data: []const u8, - cleanup_fragment: bool, - - const Type = enum { - text, - binary, - close, - ping, - pong, - }; -}; - -// These are the only websocket types that we're currently sending -const OpCode = enum(u8) { - text = 128 | 1, - close = 128 | 8, - pong = 128 | 10, -}; - -fn fillWebsocketHeader(buf: std.ArrayListUnmanaged(u8)) []const u8 { - // can't use buf[0..10] here, because the header length - // is variable. If it's just 2 bytes, for example, we need the - // framed message to be: - // h1, h2, data - // If we use buf[0..10], we'd get: - // h1, h2, 0, 0, 0, 0, 0, 0, 0, 0, data - - var header_buf: [10]u8 = undefined; - - // -10 because we reserved 10 bytes for the header above - const header = websocketHeader(&header_buf, .text, buf.items.len - 10); - const start = 10 - header.len; - - const message = buf.items; - @memcpy(message[start..10], header); - return message[start..]; -} - -// makes the assumption that our caller reserved the first -// 10 bytes for the header -fn websocketHeader(buf: []u8, op_code: OpCode, payload_len: usize) []const u8 { - lp.assert(buf.len == 10, "Websocket.Header", .{ .len = buf.len }); - - const len = payload_len; - buf[0] = 128 | @intFromEnum(op_code); // fin | opcode - - if (len <= 125) { - buf[1] = @intCast(len); - return buf[0..2]; - } - - if (len < 65536) { - buf[1] = 126; - buf[2] = @intCast((len >> 8) & 0xFF); - buf[3] = @intCast(len & 0xFF); - return buf[0..4]; - } - - buf[1] = 127; - buf[2] = 0; - buf[3] = 0; - buf[4] = 0; - buf[5] = 0; - buf[6] = @intCast((len >> 24) & 0xFF); - buf[7] = @intCast((len >> 16) & 0xFF); - buf[8] = @intCast((len >> 8) & 0xFF); - buf[9] = @intCast(len & 0xFF); - return buf[0..10]; -} - -// Utils -// -------- - fn buildJSONVersionResponse( allocator: Allocator, address: net.Address, @@ -968,12 +180,6 @@ fn buildJSONVersionResponse( const body_format = "{{\"webSocketDebuggerUrl\": \"ws://{f}/\"}}"; const body_len = std.fmt.count(body_format, .{address}); - // We send a Connection: Close (and actually close the connection) - // because chromedp (Go driver) sends a request to /json/version and then - // does an upgrade request, on a different connection. Since we only allow - // 1 connection at a time, the upgrade connection doesn't proceed until we - // timeout the /json/version. So, instead of waiting for that, we just - // always close HTTP requests. 
const response_format = "HTTP/1.1 200 OK\r\n" ++ "Content-Length: {d}\r\n" ++ @@ -983,49 +189,8 @@ fn buildJSONVersionResponse( return try std.fmt.allocPrint(allocator, response_format, .{ body_len, address }); } -pub const timestamp = @import("datetime.zig").timestamp; - -// In-place string lowercase -fn toLower(str: []u8) []u8 { - for (str, 0..) |c, i| { - str[i] = std.ascii.toLower(c); - } - return str; -} - -// Zig is in a weird backend transition right now. Need to determine if -// SIMD is even available. -const backend_supports_vectors = switch (builtin.zig_backend) { - .stage2_llvm, .stage2_c => true, - else => false, -}; - -// Websocket messages from client->server are masked using a 4 byte XOR mask -fn mask(m: []const u8, payload: []u8) void { - var data = payload; - - if (!comptime backend_supports_vectors) return simpleMask(m, data); - - const vector_size = std.simd.suggestVectorLength(u8) orelse @sizeOf(usize); - if (data.len >= vector_size) { - const mask_vector = std.simd.repeat(vector_size, @as(@Vector(4, u8), m[0..4].*)); - while (data.len >= vector_size) { - const slice = data[0..vector_size]; - const masked_data_slice: @Vector(vector_size, u8) = slice.*; - slice.* = masked_data_slice ^ mask_vector; - data = data[vector_size..]; - } - } - simpleMask(m, data); -} - -// Used when SIMD isn't available, or for any remaining part of the message -// which is too small to effectively use SIMD. -fn simpleMask(m: []const u8, payload: []u8) void { - for (payload, 0..) |b, i| { - payload[i] = b ^ m[i & 3]; - } -} +// Re-export Client from SessionThread for compatibility +pub const Client = SessionThread.Client; const testing = std.testing; test "server: buildJSONVersionResponse" { @@ -1039,432 +204,3 @@ test "server: buildJSONVersionResponse" { "Content-Type: application/json; charset=UTF-8\r\n\r\n" ++ "{\"webSocketDebuggerUrl\": \"ws://127.0.0.1:9001/\"}", res); } - -test "Client: http invalid request" { - var c = try createTestClient(); - defer c.deinit(); - - const res = try c.httpRequest("GET /over/9000 HTTP/1.1\r\n" ++ "Header: " ++ ("a" ** 4100) ++ "\r\n\r\n"); - try testing.expectEqualStrings("HTTP/1.1 413 \r\n" ++ - "Connection: Close\r\n" ++ - "Content-Length: 17\r\n\r\n" ++ - "Request too large", res); -} - -test "Client: http invalid handshake" { - try assertHTTPError( - 400, - "Invalid request", - "\r\n\r\n", - ); - - try assertHTTPError( - 404, - "Not found", - "GET /over/9000 HTTP/1.1\r\n\r\n", - ); - - try assertHTTPError( - 404, - "Not found", - "POST / HTTP/1.1\r\n\r\n", - ); - - try assertHTTPError( - 400, - "Invalid HTTP protocol", - "GET / HTTP/1.0\r\n\r\n", - ); - - try assertHTTPError( - 400, - "Missing required header", - "GET / HTTP/1.1\r\n\r\n", - ); - - try assertHTTPError( - 400, - "Missing required header", - "GET / HTTP/1.1\r\nConnection: upgrade\r\n\r\n", - ); - - try assertHTTPError( - 400, - "Missing required header", - "GET / HTTP/1.1\r\nConnection: upgrade\r\nUpgrade: websocket\r\n\r\n", - ); - - try assertHTTPError( - 400, - "Missing required header", - "GET / HTTP/1.1\r\nConnection: upgrade\r\nUpgrade: websocket\r\nsec-websocket-version:13\r\n\r\n", - ); -} - -test "Client: http valid handshake" { - var c = try createTestClient(); - defer c.deinit(); - - const request = - "GET / HTTP/1.1\r\n" ++ - "Connection: upgrade\r\n" ++ - "Upgrade: websocket\r\n" ++ - "sec-websocket-version:13\r\n" ++ - "sec-websocket-key: this is my key\r\n" ++ - "Custom: Header-Value\r\n\r\n"; - - const res = try c.httpRequest(request); - try 
testing.expectEqualStrings("HTTP/1.1 101 Switching Protocols\r\n" ++ - "Upgrade: websocket\r\n" ++ - "Connection: upgrade\r\n" ++ - "Sec-Websocket-Accept: flzHu2DevQ2dSCSVqKSii5e9C2o=\r\n\r\n", res); -} - -test "Client: read invalid websocket message" { - // 131 = 128 (fin) | 3 where 3 isn't a valid type - try assertWebSocketError( - 1002, - &.{ 131, 128, 'm', 'a', 's', 'k' }, - ); - - for ([_]u8{ 16, 32, 64 }) |rsv| { - // none of the reserve flags should be set - try assertWebSocketError( - 1002, - &.{ rsv, 128, 'm', 'a', 's', 'k' }, - ); - - // as a bitmask - try assertWebSocketError( - 1002, - &.{ rsv + 4, 128, 'm', 'a', 's', 'k' }, - ); - } - - // client->server messages must be masked - try assertWebSocketError( - 1002, - &.{ 129, 1, 'a' }, - ); - - // control types (ping/ping/close) can't be > 125 bytes - for ([_]u8{ 136, 137, 138 }) |op| { - try assertWebSocketError( - 1002, - &.{ op, 254, 1, 1 }, - ); - } - - // length of message is 0000 0810, i.e: 1024 * 512 + 265 - try assertWebSocketError(1009, &.{ 129, 255, 0, 0, 0, 0, 0, 8, 1, 0, 'm', 'a', 's', 'k' }); - - // continuation type message must come after a normal message - // even when not a fin frame - try assertWebSocketError( - 1002, - &.{ 0, 129, 'm', 'a', 's', 'k', 'd' }, - ); - - // continuation type message must come after a normal message - // even as a fin frame - try assertWebSocketError( - 1002, - &.{ 128, 129, 'm', 'a', 's', 'k', 'd' }, - ); - - // text (non-fin) - text (non-fin) - try assertWebSocketError( - 1002, - &.{ 1, 129, 'm', 'a', 's', 'k', 'd', 1, 128, 'k', 's', 'a', 'm' }, - ); - - // text (non-fin) - text (fin) should always been continuation after non-fin - try assertWebSocketError( - 1002, - &.{ 1, 129, 'm', 'a', 's', 'k', 'd', 129, 128, 'k', 's', 'a', 'm' }, - ); - - // close must be fin - try assertWebSocketError( - 1002, - &.{ - 8, 129, 'm', 'a', 's', 'k', 'd', - }, - ); - - // ping must be fin - try assertWebSocketError( - 1002, - &.{ - 9, 129, 'm', 'a', 's', 'k', 'd', - }, - ); - - // pong must be fin - try assertWebSocketError( - 1002, - &.{ - 10, 129, 'm', 'a', 's', 'k', 'd', - }, - ); -} - -test "Client: ping reply" { - try assertWebSocketMessage( - // fin | pong, len - &.{ 138, 0 }, - - // fin | ping, masked | len, 4-byte mask - &.{ 137, 128, 0, 0, 0, 0 }, - ); - - try assertWebSocketMessage( - // fin | pong, len, payload - &.{ 138, 5, 100, 96, 97, 109, 104 }, - - // fin | ping, masked | len, 4-byte mask, 5 byte payload - &.{ 137, 133, 0, 5, 7, 10, 100, 101, 102, 103, 104 }, - ); -} - -test "Client: close message" { - try assertWebSocketMessage( - // fin | close, len, close code (normal) - &.{ 136, 2, 3, 232 }, - - // fin | close, masked | len, 4-byte mask - &.{ 136, 128, 0, 0, 0, 0 }, - ); -} - -test "server: mask" { - var buf: [4000]u8 = undefined; - const messages = [_][]const u8{ "1234", "1234" ** 99, "1234" ** 999 }; - for (messages) |message| { - // we need the message to be mutable since mask operates in-place - const payload = buf[0..message.len]; - @memcpy(payload, message); - - mask(&.{ 1, 2, 200, 240 }, payload); - try testing.expectEqual(false, std.mem.eql(u8, payload, message)); - - mask(&.{ 1, 2, 200, 240 }, payload); - try testing.expectEqual(true, std.mem.eql(u8, payload, message)); - } -} - -test "server: 404" { - var c = try createTestClient(); - defer c.deinit(); - - const res = try c.httpRequest("GET /unknown HTTP/1.1\r\n\r\n"); - try testing.expectEqualStrings("HTTP/1.1 404 \r\n" ++ - "Connection: Close\r\n" ++ - "Content-Length: 9\r\n\r\n" ++ - "Not found", res); -} - -test 
"server: get /json/version" { - const expected_response = - "HTTP/1.1 200 OK\r\n" ++ - "Content-Length: 48\r\n" ++ - "Connection: Close\r\n" ++ - "Content-Type: application/json; charset=UTF-8\r\n\r\n" ++ - "{\"webSocketDebuggerUrl\": \"ws://127.0.0.1:9583/\"}"; - - { - // twice on the same connection - var c = try createTestClient(); - defer c.deinit(); - - const res1 = try c.httpRequest("GET /json/version HTTP/1.1\r\n\r\n"); - try testing.expectEqualStrings(expected_response, res1); - } - - { - // again on a new connection - var c = try createTestClient(); - defer c.deinit(); - - const res1 = try c.httpRequest("GET /json/version HTTP/1.1\r\n\r\n"); - try testing.expectEqualStrings(expected_response, res1); - } -} - -fn assertHTTPError( - comptime expected_status: u16, - comptime expected_body: []const u8, - input: []const u8, -) !void { - var c = try createTestClient(); - defer c.deinit(); - - const res = try c.httpRequest(input); - const expected_response = std.fmt.comptimePrint( - "HTTP/1.1 {d} \r\nConnection: Close\r\nContent-Length: {d}\r\n\r\n{s}", - .{ expected_status, expected_body.len, expected_body }, - ); - - try testing.expectEqualStrings(expected_response, res); -} - -fn assertWebSocketError(close_code: u16, input: []const u8) !void { - var c = try createTestClient(); - defer c.deinit(); - - try c.handshake(); - try c.stream.writeAll(input); - - const msg = try c.readWebsocketMessage() orelse return error.NoMessage; - defer if (msg.cleanup_fragment) { - c.reader.cleanup(); - }; - - try testing.expectEqual(.close, msg.type); - try testing.expectEqual(2, msg.data.len); - try testing.expectEqual(close_code, std.mem.readInt(u16, msg.data[0..2], .big)); -} - -fn assertWebSocketMessage(expected: []const u8, input: []const u8) !void { - var c = try createTestClient(); - defer c.deinit(); - - try c.handshake(); - try c.stream.writeAll(input); - - const msg = try c.readWebsocketMessage() orelse return error.NoMessage; - defer if (msg.cleanup_fragment) { - c.reader.cleanup(); - }; - - const actual = c.reader.buf[0 .. 
msg.data.len + 2]; - try testing.expectEqualSlices(u8, expected, actual); -} - -const MockCDP = struct { - messages: std.ArrayListUnmanaged([]const u8) = .{}, - - allocator: Allocator = testing.allocator, - - fn init(_: Allocator, client: anytype) MockCDP { - _ = client; - return .{}; - } - - fn deinit(self: *MockCDP) void { - const allocator = self.allocator; - for (self.messages.items) |msg| { - allocator.free(msg); - } - self.messages.deinit(allocator); - } - - fn handleMessage(self: *MockCDP, message: []const u8) bool { - const owned = self.allocator.dupe(u8, message) catch unreachable; - self.messages.append(self.allocator, owned) catch unreachable; - return true; - } -}; - -fn createTestClient() !TestClient { - const address = std.net.Address.initIp4([_]u8{ 127, 0, 0, 1 }, 9583); - const stream = try std.net.tcpConnectToAddress(address); - - const timeout = std.mem.toBytes(posix.timeval{ - .sec = 2, - .usec = 0, - }); - try posix.setsockopt(stream.handle, posix.SOL.SOCKET, posix.SO.RCVTIMEO, &timeout); - try posix.setsockopt(stream.handle, posix.SOL.SOCKET, posix.SO.SNDTIMEO, &timeout); - return .{ - .stream = stream, - .reader = .{ - .allocator = testing.allocator, - .buf = try testing.allocator.alloc(u8, 1024 * 16), - }, - }; -} - -const TestClient = struct { - stream: std.net.Stream, - buf: [1024]u8 = undefined, - reader: Reader(false), - - fn deinit(self: *TestClient) void { - self.stream.close(); - self.reader.deinit(); - } - - fn httpRequest(self: *TestClient, req: []const u8) ![]const u8 { - try self.stream.writeAll(req); - - var pos: usize = 0; - var total_length: ?usize = null; - while (true) { - pos += try self.stream.read(self.buf[pos..]); - if (pos == 0) { - return error.NoMoreData; - } - const response = self.buf[0..pos]; - if (total_length == null) { - const header_end = std.mem.indexOf(u8, response, "\r\n\r\n") orelse continue; - const header = response[0 .. 
header_end + 4]; - - const cl = blk: { - const cl_header = "Content-Length: "; - const start = (std.mem.indexOf(u8, header, cl_header) orelse { - break :blk 0; - }) + cl_header.len; - - const end = std.mem.indexOfScalarPos(u8, header, start, '\r') orelse { - return error.InvalidContentLength; - }; - - break :blk std.fmt.parseInt(usize, header[start..end], 10) catch { - return error.InvalidContentLength; - }; - }; - - total_length = cl + header.len; - } - - if (total_length) |tl| { - if (pos == tl) { - return response; - } - if (pos > tl) { - return error.DataExceedsContentLength; - } - } - } - } - - fn handshake(self: *TestClient) !void { - const request = - "GET / HTTP/1.1\r\n" ++ - "Connection: upgrade\r\n" ++ - "Upgrade: websocket\r\n" ++ - "sec-websocket-version:13\r\n" ++ - "sec-websocket-key: this is my key\r\n" ++ - "Custom: Header-Value\r\n\r\n"; - - const res = try self.httpRequest(request); - try testing.expectEqualStrings("HTTP/1.1 101 Switching Protocols\r\n" ++ - "Upgrade: websocket\r\n" ++ - "Connection: upgrade\r\n" ++ - "Sec-Websocket-Accept: flzHu2DevQ2dSCSVqKSii5e9C2o=\r\n\r\n", res); - } - - fn readWebsocketMessage(self: *TestClient) !?Message { - while (true) { - const n = try self.stream.read(self.reader.readBuf()); - if (n == 0) { - return error.Closed; - } - self.reader.len += n; - if (try self.reader.next()) |msg| { - return msg; - } - } - } -}; diff --git a/src/SessionManager.zig b/src/SessionManager.zig new file mode 100644 index 00000000..9ffb64f3 --- /dev/null +++ b/src/SessionManager.zig @@ -0,0 +1,133 @@ +// Copyright (C) 2023-2025 Lightpanda (Selecy SAS) +// +// Francis Bouvier +// Pierre Tachoire +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as +// published by the Free Software Foundation, either version 3 of the +// License, or (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +const std = @import("std"); +const Allocator = std.mem.Allocator; + +const SessionThread = @import("SessionThread.zig"); + +/// Thread-safe collection of active CDP sessions. +/// Manages lifecycle and enforces connection limits. +const SessionManager = @This(); + +mutex: std.Thread.Mutex, +sessions: std.ArrayListUnmanaged(*SessionThread), +allocator: Allocator, +max_sessions: u32, + +pub fn init(allocator: Allocator, max_sessions: u32) SessionManager { + return .{ + .mutex = .{}, + .sessions = .{}, + .allocator = allocator, + .max_sessions = max_sessions, + }; +} + +pub fn deinit(self: *SessionManager) void { + self.stopAll(); + self.sessions.deinit(self.allocator); +} + +/// Add a new session to the manager. +/// Returns error.TooManySessions if the limit is reached. +pub fn add(self: *SessionManager, session: *SessionThread) !void { + self.mutex.lock(); + defer self.mutex.unlock(); + + if (self.sessions.items.len >= self.max_sessions) { + return error.TooManySessions; + } + + try self.sessions.append(self.allocator, session); +} + +/// Remove a session from the manager. +/// Called when a session terminates. 
+pub fn remove(self: *SessionManager, session: *SessionThread) void { + self.mutex.lock(); + defer self.mutex.unlock(); + + for (self.sessions.items, 0..) |s, i| { + if (s == session) { + _ = self.sessions.swapRemove(i); + return; + } + } +} + +/// Stop all active sessions and wait for them to terminate. +pub fn stopAll(self: *SessionManager) void { + // First, signal all sessions to stop + { + self.mutex.lock(); + defer self.mutex.unlock(); + + for (self.sessions.items) |session| { + session.stop(); + } + } + + // Then wait for all to join (without holding the lock) + // We need to copy the list since sessions will remove themselves + var sessions_copy: std.ArrayListUnmanaged(*SessionThread) = .{}; + { + self.mutex.lock(); + defer self.mutex.unlock(); + + sessions_copy.appendSlice(self.allocator, self.sessions.items) catch return; + } + defer sessions_copy.deinit(self.allocator); + + for (sessions_copy.items) |session| { + session.join(); + session.deinit(); + } + + // Clear the sessions list + { + self.mutex.lock(); + defer self.mutex.unlock(); + self.sessions.clearRetainingCapacity(); + } +} + +/// Get the current number of active sessions. +pub fn count(self: *SessionManager) usize { + self.mutex.lock(); + defer self.mutex.unlock(); + return self.sessions.items.len; +} + +const testing = std.testing; + +test "SessionManager: add and remove" { + var manager = SessionManager.init(testing.allocator, 10); + defer manager.deinit(); + + try testing.expectEqual(0, manager.count()); +} + +test "SessionManager: max sessions limit" { + var manager = SessionManager.init(testing.allocator, 2); + defer manager.deinit(); + + // We can't easily create mock SessionThreads for this test, + // so we just verify the initialization works + try testing.expectEqual(0, manager.count()); +} diff --git a/src/SessionThread.zig b/src/SessionThread.zig new file mode 100644 index 00000000..733998c9 --- /dev/null +++ b/src/SessionThread.zig @@ -0,0 +1,885 @@ +// Copyright (C) 2023-2025 Lightpanda (Selecy SAS) +// +// Francis Bouvier +// Pierre Tachoire +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as +// published by the Free Software Foundation, either version 3 of the +// License, or (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +const std = @import("std"); +const lp = @import("lightpanda"); +const builtin = @import("builtin"); + +const posix = std.posix; +const net = std.net; + +const Allocator = std.mem.Allocator; +const ArenaAllocator = std.heap.ArenaAllocator; + +const log = @import("log.zig"); +const SharedState = @import("SharedState.zig"); +const SessionManager = @import("SessionManager.zig"); +const LimitedAllocator = @import("LimitedAllocator.zig"); +const HttpClient = @import("http/Client.zig"); +const CDP = @import("cdp/cdp.zig").CDP; +const BrowserSession = @import("browser/Session.zig"); + +const timestamp = @import("datetime.zig").timestamp; + +const MAX_HTTP_REQUEST_SIZE = 4096; +const MAX_MESSAGE_SIZE = 512 * 1024 + 14 + 140; + +/// Encapsulates a single CDP session running in its own thread. 
+/// Each SessionThread has: +/// - Its own client socket +/// - Its own HttpClient (with shared curl_share from SharedState) +/// - Its own V8 Isolate (via Browser/CDP) +/// - Its own memory-limited allocator +const SessionThread = @This(); + +thread: ?std.Thread, +shutdown: std.atomic.Value(bool), +client_socket: posix.socket_t, +shared: *SharedState, +session_manager: *SessionManager, +limited_allocator: LimitedAllocator, +http_client: ?*HttpClient, +timeout_ms: u32, +json_version_response: []const u8, + +pub fn spawn( + shared: *SharedState, + session_manager: *SessionManager, + socket: posix.socket_t, + timeout_ms: u32, + json_version_response: []const u8, + session_memory_limit: usize, +) !*SessionThread { + const self = try shared.allocator.create(SessionThread); + errdefer shared.allocator.destroy(self); + + self.* = .{ + .thread = null, + .shutdown = std.atomic.Value(bool).init(false), + .client_socket = socket, + .shared = shared, + .session_manager = session_manager, + .limited_allocator = LimitedAllocator.init(shared.allocator, session_memory_limit), + .http_client = null, + .timeout_ms = timeout_ms, + .json_version_response = json_version_response, + }; + + // Start the thread + self.thread = try std.Thread.spawn(.{}, run, .{self}); + + return self; +} + +pub fn stop(self: *SessionThread) void { + self.shutdown.store(true, .release); + + // Close the socket to interrupt any blocking reads + if (self.client_socket != -1) { + switch (builtin.target.os.tag) { + .linux => posix.shutdown(self.client_socket, .recv) catch {}, + .macos, .freebsd, .netbsd, .openbsd => posix.close(self.client_socket), + else => {}, + } + } +} + +pub fn join(self: *SessionThread) void { + if (self.thread) |thread| { + thread.join(); + self.thread = null; + } +} + +pub fn deinit(self: *SessionThread) void { + self.join(); + + if (self.http_client) |client| { + client.deinit(); + self.http_client = null; + } + + self.shared.allocator.destroy(self); +} + +fn sessionAllocator(self: *SessionThread) Allocator { + return self.limited_allocator.allocator(); +} + +fn run(self: *SessionThread) void { + defer { + // Remove ourselves from the session manager when we're done + self.session_manager.remove(self); + } + + self.runInner() catch |err| { + log.err(.app, "session thread error", .{ .err = err }); + }; +} + +fn runInner(self: *SessionThread) !void { + const alloc = self.sessionAllocator(); + + // Create our own HTTP client using the shared curl_share + self.http_client = try self.shared.createHttpClient(alloc); + errdefer { + if (self.http_client) |client| { + client.deinit(); + self.http_client = null; + } + } + + const client = try alloc.create(Client); + defer alloc.destroy(client); + + client.* = try Client.init(self.client_socket, self); + defer client.deinit(); + + var http = self.http_client.?; + http.cdp_client = .{ + .socket = self.client_socket, + .ctx = client, + .blocking_read_start = Client.blockingReadStart, + .blocking_read = Client.blockingRead, + .blocking_read_end = Client.blockingReadStop, + }; + defer http.cdp_client = null; + + lp.assert(client.mode == .http, "SessionThread.run invalid mode", .{}); + + const timeout_ms = self.timeout_ms; + + while (!self.shutdown.load(.acquire)) { + const tick_result = http.tick(timeout_ms) catch .normal; + if (tick_result != .cdp_socket) { + log.info(.app, "CDP timeout", .{}); + return; + } + + if (client.readSocket() == false) { + return; + } + + if (client.mode == .cdp) { + break; // switch to CDP loop + } + } + + var cdp = &client.mode.cdp; + var 
last_message = timestamp(.monotonic); + var ms_remaining = timeout_ms; + + while (!self.shutdown.load(.acquire)) { + switch (cdp.pageWait(ms_remaining)) { + .cdp_socket => { + if (client.readSocket() == false) { + return; + } + last_message = timestamp(.monotonic); + ms_remaining = timeout_ms; + }, + .no_page => { + const tick_res = http.tick(ms_remaining) catch .normal; + if (tick_res != .cdp_socket) { + log.info(.app, "CDP timeout", .{}); + return; + } + if (client.readSocket() == false) { + return; + } + last_message = timestamp(.monotonic); + ms_remaining = timeout_ms; + }, + .done => { + const elapsed = timestamp(.monotonic) - last_message; + if (elapsed > ms_remaining) { + log.info(.app, "CDP timeout", .{}); + return; + } + ms_remaining -= @intCast(elapsed); + }, + .navigate => unreachable, + } + } +} + + +/// The CDP/WebSocket client - adapted from Server.zig +pub const Client = struct { + mode: union(enum) { + http: void, + cdp: CDP, + }, + + session_thread: *SessionThread, + reader: Reader(true), + socket: posix.socket_t, + socket_flags: usize, + send_arena: ArenaAllocator, + + const EMPTY_PONG = [_]u8{ 138, 0 }; + const CLOSE_NORMAL = [_]u8{ 136, 2, 3, 232 }; + const CLOSE_TOO_BIG = [_]u8{ 136, 2, 3, 241 }; + const CLOSE_PROTOCOL_ERROR = [_]u8{ 136, 2, 3, 234 }; + const CLOSE_TIMEOUT = [_]u8{ 136, 2, 15, 160 }; + + fn init(socket: posix.socket_t, session_thread: *SessionThread) !Client { + const socket_flags = try posix.fcntl(socket, posix.F.GETFL, 0); + const nonblocking = @as(u32, @bitCast(posix.O{ .NONBLOCK = true })); + lp.assert(socket_flags & nonblocking == nonblocking, "Client.init blocking", .{}); + + const alloc = session_thread.sessionAllocator(); + var reader = try Reader(true).init(alloc); + errdefer reader.deinit(); + + return .{ + .socket = socket, + .session_thread = session_thread, + .reader = reader, + .mode = .{ .http = {} }, + .socket_flags = socket_flags, + .send_arena = ArenaAllocator.init(alloc), + }; + } + + fn deinit(self: *Client) void { + switch (self.mode) { + .cdp => |*cdp| cdp.deinit(), + .http => {}, + } + self.reader.deinit(); + self.send_arena.deinit(); + } + + fn blockingReadStart(ctx: *anyopaque) bool { + const self: *Client = @ptrCast(@alignCast(ctx)); + _ = posix.fcntl(self.socket, posix.F.SETFL, self.socket_flags & ~@as(u32, @bitCast(posix.O{ .NONBLOCK = true }))) catch |err| { + log.warn(.app, "CDP blockingReadStart", .{ .err = err }); + return false; + }; + return true; + } + + fn blockingRead(ctx: *anyopaque) bool { + const self: *Client = @ptrCast(@alignCast(ctx)); + return self.readSocket(); + } + + fn blockingReadStop(ctx: *anyopaque) bool { + const self: *Client = @ptrCast(@alignCast(ctx)); + _ = posix.fcntl(self.socket, posix.F.SETFL, self.socket_flags) catch |err| { + log.warn(.app, "CDP blockingReadStop", .{ .err = err }); + return false; + }; + return true; + } + + fn readSocket(self: *Client) bool { + const n = posix.read(self.socket, self.readBuf()) catch |err| { + log.warn(.app, "CDP read", .{ .err = err }); + return false; + }; + + if (n == 0) { + log.info(.app, "CDP disconnect", .{}); + return false; + } + + return self.processData(n) catch false; + } + + fn readBuf(self: *Client) []u8 { + return self.reader.readBuf(); + } + + fn processData(self: *Client, len: usize) !bool { + self.reader.len += len; + + switch (self.mode) { + .cdp => |*cdp| return self.processWebsocketMessage(cdp), + .http => return self.processHTTPRequest(), + } + } + + fn processHTTPRequest(self: *Client) !bool { + lp.assert(self.reader.pos == 0, 
"Client.HTTP pos", .{ .pos = self.reader.pos }); + const request = self.reader.buf[0..self.reader.len]; + + if (request.len > MAX_HTTP_REQUEST_SIZE) { + self.writeHTTPErrorResponse(413, "Request too large"); + return error.RequestTooLarge; + } + + if (std.mem.endsWith(u8, request, "\r\n\r\n") == false) { + return true; + } + + defer self.reader.len = 0; + return self.handleHTTPRequest(request) catch |err| { + switch (err) { + error.NotFound => self.writeHTTPErrorResponse(404, "Not found"), + error.InvalidRequest => self.writeHTTPErrorResponse(400, "Invalid request"), + error.InvalidProtocol => self.writeHTTPErrorResponse(400, "Invalid HTTP protocol"), + error.MissingHeaders => self.writeHTTPErrorResponse(400, "Missing required header"), + error.InvalidUpgradeHeader => self.writeHTTPErrorResponse(400, "Unsupported upgrade type"), + error.InvalidVersionHeader => self.writeHTTPErrorResponse(400, "Invalid websocket version"), + error.InvalidConnectionHeader => self.writeHTTPErrorResponse(400, "Invalid connection header"), + else => { + log.err(.app, "server 500", .{ .err = err, .req = request[0..@min(100, request.len)] }); + self.writeHTTPErrorResponse(500, "Internal Server Error"); + }, + } + return err; + }; + } + + fn handleHTTPRequest(self: *Client, request: []u8) !bool { + if (request.len < 18) { + return error.InvalidRequest; + } + + if (std.mem.eql(u8, request[0..4], "GET ") == false) { + return error.NotFound; + } + + const url_end = std.mem.indexOfScalarPos(u8, request, 4, ' ') orelse { + return error.InvalidRequest; + }; + + const url = request[4..url_end]; + + if (std.mem.eql(u8, url, "/")) { + try self.upgradeConnection(request); + return true; + } + + if (std.mem.eql(u8, url, "/json/version")) { + try self.send(self.session_thread.json_version_response); + try posix.shutdown(self.socket, .recv); + return false; + } + + return error.NotFound; + } + + fn upgradeConnection(self: *Client, request: []u8) !void { + const request_line_end = std.mem.indexOfScalar(u8, request, '\r') orelse unreachable; + const request_line = request[0..request_line_end]; + + if (!std.ascii.endsWithIgnoreCase(request_line, "http/1.1")) { + return error.InvalidProtocol; + } + + var key: []const u8 = ""; + var required_headers: u8 = 0; + var buf = request[request_line_end + 2 ..]; + + while (buf.len > 4) { + const index = std.mem.indexOfScalar(u8, buf, '\r') orelse unreachable; + const separator = std.mem.indexOfScalar(u8, buf[0..index], ':') orelse return error.InvalidRequest; + + const name = std.mem.trim(u8, toLower(buf[0..separator]), &std.ascii.whitespace); + const value = std.mem.trim(u8, buf[(separator + 1)..index], &std.ascii.whitespace); + + if (std.mem.eql(u8, name, "upgrade")) { + if (!std.ascii.eqlIgnoreCase("websocket", value)) { + return error.InvalidUpgradeHeader; + } + required_headers |= 1; + } else if (std.mem.eql(u8, name, "sec-websocket-version")) { + if (value.len != 2 or value[0] != '1' or value[1] != '3') { + return error.InvalidVersionHeader; + } + required_headers |= 2; + } else if (std.mem.eql(u8, name, "connection")) { + if (std.ascii.indexOfIgnoreCase(value, "upgrade") == null) { + return error.InvalidConnectionHeader; + } + required_headers |= 4; + } else if (std.mem.eql(u8, name, "sec-websocket-key")) { + key = value; + required_headers |= 8; + } + + const next = index + 2; + buf = buf[next..]; + } + + if (required_headers != 15) { + return error.MissingHeaders; + } + + const alloc = self.send_arena.allocator(); + + const response = blk: { + const template = + "HTTP/1.1 101 
Switching Protocols\r\n" ++ + "Upgrade: websocket\r\n" ++ + "Connection: upgrade\r\n" ++ + "Sec-Websocket-Accept: 0000000000000000000000000000\r\n\r\n"; + + const res = try alloc.dupe(u8, template); + + const key_pos = res.len - 32; + var h: [20]u8 = undefined; + var hasher = std.crypto.hash.Sha1.init(.{}); + hasher.update(key); + hasher.update("258EAFA5-E914-47DA-95CA-C5AB0DC85B11"); + hasher.final(&h); + + _ = std.base64.standard.Encoder.encode(res[key_pos .. key_pos + 28], h[0..]); + + break :blk res; + }; + + self.mode = .{ .cdp = try CDP.init(self.session_thread.shared, self.session_thread.http_client.?, self) }; + return self.send(response); + } + + fn writeHTTPErrorResponse(self: *Client, comptime status: u16, comptime body: []const u8) void { + const response = std.fmt.comptimePrint( + "HTTP/1.1 {d} \r\nConnection: Close\r\nContent-Length: {d}\r\n\r\n{s}", + .{ status, body.len, body }, + ); + self.send(response) catch {}; + } + + fn processWebsocketMessage(self: *Client, cdp: *CDP) !bool { + var reader = &self.reader; + while (true) { + const msg = reader.next() catch |err| { + switch (err) { + error.TooLarge => self.send(&CLOSE_TOO_BIG) catch {}, + error.NotMasked => self.send(&CLOSE_PROTOCOL_ERROR) catch {}, + error.ReservedFlags => self.send(&CLOSE_PROTOCOL_ERROR) catch {}, + error.InvalidMessageType => self.send(&CLOSE_PROTOCOL_ERROR) catch {}, + error.ControlTooLarge => self.send(&CLOSE_PROTOCOL_ERROR) catch {}, + error.InvalidContinuation => self.send(&CLOSE_PROTOCOL_ERROR) catch {}, + error.NestedFragementation => self.send(&CLOSE_PROTOCOL_ERROR) catch {}, + error.OutOfMemory => {}, + } + return err; + } orelse break; + + switch (msg.type) { + .pong => {}, + .ping => try self.sendPong(msg.data), + .close => { + self.send(&CLOSE_NORMAL) catch {}; + return false; + }, + .text, .binary => if (cdp.handleMessage(msg.data) == false) { + return false; + }, + } + if (msg.cleanup_fragment) { + reader.cleanup(); + } + } + + reader.compact(); + return true; + } + + fn sendPong(self: *Client, data: []const u8) !void { + if (data.len == 0) { + return self.send(&EMPTY_PONG); + } + var header_buf: [10]u8 = undefined; + const header = websocketHeader(&header_buf, .pong, data.len); + + const alloc = self.send_arena.allocator(); + var framed = try alloc.alloc(u8, header.len + data.len); + @memcpy(framed[0..header.len], header); + @memcpy(framed[header.len..], data); + return self.send(framed); + } + + pub fn sendJSON(self: *Client, message: anytype, opts: std.json.Stringify.Options) !void { + const alloc = self.send_arena.allocator(); + + var aw: std.Io.Writer.Allocating = .init(alloc); + try aw.writer.writeAll(&.{ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 }); + try std.json.Stringify.value(message, opts, &aw.writer); + const written = aw.written(); + + // Fill in websocket header + var header_buf: [10]u8 = undefined; + const payload_len = written.len - 10; + const header = websocketHeader(&header_buf, .text, payload_len); + const start = 10 - header.len; + + // Copy header into the reserved space + const data = @constCast(written); + @memcpy(data[start..10], header); + return self.send(data[start..]); + } + + pub fn sendJSONRaw(self: *Client, buf: std.ArrayListUnmanaged(u8)) !void { + var header_buf: [10]u8 = undefined; + const payload_len = buf.items.len - 10; + const header = websocketHeader(&header_buf, .text, payload_len); + const start = 10 - header.len; + + const message = buf.items; + @memcpy(message[start..10], header); + return self.send(message[start..]); + } + + fn send(self: *Client, data: 
[]const u8) !void { + var pos: usize = 0; + var changed_to_blocking: bool = false; + defer _ = self.send_arena.reset(.{ .retain_with_limit = 1024 * 32 }); + + defer if (changed_to_blocking) { + _ = posix.fcntl(self.socket, posix.F.SETFL, self.socket_flags) catch |err| { + log.err(.app, "CDP restore nonblocking", .{ .err = err }); + }; + }; + + LOOP: while (pos < data.len) { + const written = posix.write(self.socket, data[pos..]) catch |err| switch (err) { + error.WouldBlock => { + lp.assert(changed_to_blocking == false, "Client.double block", .{}); + changed_to_blocking = true; + _ = try posix.fcntl(self.socket, posix.F.SETFL, self.socket_flags & ~@as(u32, @bitCast(posix.O{ .NONBLOCK = true }))); + continue :LOOP; + }, + else => return err, + }; + + if (written == 0) { + return error.Closed; + } + pos += written; + } + } +}; + +// WebSocket message reader +fn Reader(comptime EXPECT_MASK: bool) type { + return struct { + allocator: Allocator, + pos: usize = 0, + len: usize = 0, + buf: []u8, + fragments: ?Fragments = null, + + const Self = @This(); + + fn init(alloc: Allocator) !Self { + const buf = try alloc.alloc(u8, 16 * 1024); + return .{ + .buf = buf, + .allocator = alloc, + }; + } + + fn deinit(self: *Self) void { + self.cleanup(); + self.allocator.free(self.buf); + } + + fn cleanup(self: *Self) void { + if (self.fragments) |*f| { + f.message.deinit(self.allocator); + self.fragments = null; + } + } + + fn readBuf(self: *Self) []u8 { + return self.buf[self.len..]; + } + + fn next(self: *Self) !?Message { + LOOP: while (true) { + var buf = self.buf[self.pos..self.len]; + + const length_of_len, const message_len = extractLengths(buf) orelse { + return null; + }; + + const byte1 = buf[0]; + + if (byte1 & 112 != 0) { + return error.ReservedFlags; + } + + if (comptime EXPECT_MASK) { + if (buf[1] & 128 != 128) { + return error.NotMasked; + } + } else if (buf[1] & 128 != 0) { + return error.Masked; + } + + var is_control = false; + var is_continuation = false; + var message_type: Message.Type = undefined; + switch (byte1 & 15) { + 0 => is_continuation = true, + 1 => message_type = .text, + 2 => message_type = .binary, + 8 => { + is_control = true; + message_type = .close; + }, + 9 => { + is_control = true; + message_type = .ping; + }, + 10 => { + is_control = true; + message_type = .pong; + }, + else => return error.InvalidMessageType, + } + + if (is_control) { + if (message_len > 125) { + return error.ControlTooLarge; + } + } else if (message_len > MAX_MESSAGE_SIZE) { + return error.TooLarge; + } else if (message_len > self.buf.len) { + const len_now = self.buf.len; + self.buf = try growBuffer(self.allocator, self.buf, message_len); + buf = self.buf[0..len_now]; + return null; + } else if (buf.len < message_len) { + return null; + } + + const header_len = 2 + length_of_len + if (comptime EXPECT_MASK) 4 else 0; + const payload = buf[header_len..message_len]; + if (comptime EXPECT_MASK) { + mask(buf[header_len - 4 .. 
header_len], payload); + } + + self.pos += message_len; + const fin = byte1 & 128 == 128; + + if (is_continuation) { + const fragments = &(self.fragments orelse return error.InvalidContinuation); + if (fragments.message.items.len + message_len > MAX_MESSAGE_SIZE) { + return error.TooLarge; + } + + try fragments.message.appendSlice(self.allocator, payload); + + if (fin == false) { + continue :LOOP; + } + + return .{ + .type = fragments.type, + .data = fragments.message.items, + .cleanup_fragment = true, + }; + } + + const can_be_fragmented = message_type == .text or message_type == .binary; + if (self.fragments != null and can_be_fragmented) { + return error.NestedFragementation; + } + + if (fin == false) { + if (can_be_fragmented == false) { + return error.InvalidContinuation; + } + + var fragments = Fragments{ .message = .{}, .type = message_type }; + try fragments.message.appendSlice(self.allocator, payload); + self.fragments = fragments; + continue :LOOP; + } + + return .{ + .data = payload, + .type = message_type, + .cleanup_fragment = false, + }; + } + } + + fn extractLengths(buf: []const u8) ?struct { usize, usize } { + if (buf.len < 2) { + return null; + } + + const length_of_len: usize = switch (buf[1] & 127) { + 126 => 2, + 127 => 8, + else => 0, + }; + + if (buf.len < length_of_len + 2) { + return null; + } + + const message_length = switch (length_of_len) { + 2 => @as(u16, @intCast(buf[3])) | @as(u16, @intCast(buf[2])) << 8, + 8 => @as(u64, @intCast(buf[9])) | @as(u64, @intCast(buf[8])) << 8 | @as(u64, @intCast(buf[7])) << 16 | @as(u64, @intCast(buf[6])) << 24 | @as(u64, @intCast(buf[5])) << 32 | @as(u64, @intCast(buf[4])) << 40 | @as(u64, @intCast(buf[3])) << 48 | @as(u64, @intCast(buf[2])) << 56, + else => buf[1] & 127, + } + length_of_len + 2 + if (comptime EXPECT_MASK) 4 else 0; + + return .{ length_of_len, message_length }; + } + + fn compact(self: *Self) void { + const pos = self.pos; + const len_now = self.len; + + lp.assert(pos <= len_now, "Client.Reader.compact precondition", .{ .pos = pos, .len = len_now }); + + const partial_bytes = len_now - pos; + + if (partial_bytes == 0) { + self.pos = 0; + self.len = 0; + return; + } + + const partial = self.buf[pos..len_now]; + + if (extractLengths(partial)) |length_meta| { + const next_message_len = length_meta.@"1"; + lp.assert(pos <= len_now, "Client.Reader.compact postcondition", .{ .next_len = next_message_len, .partial = partial_bytes }); + + const missing_bytes = next_message_len - partial_bytes; + const free_space = self.buf.len - len_now; + if (missing_bytes < free_space) { + return; + } + } + + std.mem.copyForwards(u8, self.buf, partial); + self.pos = 0; + self.len = partial_bytes; + } + }; +} + +fn growBuffer(alloc: Allocator, buf: []u8, required_capacity: usize) ![]u8 { + var new_capacity = buf.len; + while (true) { + new_capacity +|= new_capacity / 2 + 8; + if (new_capacity >= required_capacity) break; + } + + log.debug(.app, "CDP buffer growth", .{ .from = buf.len, .to = new_capacity }); + + if (alloc.resize(buf, new_capacity)) { + return buf.ptr[0..new_capacity]; + } + const new_buffer = try alloc.alloc(u8, new_capacity); + @memcpy(new_buffer[0..buf.len], buf); + alloc.free(buf); + return new_buffer; +} + +const Fragments = struct { + type: Message.Type, + message: std.ArrayListUnmanaged(u8), +}; + +const Message = struct { + type: Type, + data: []const u8, + cleanup_fragment: bool, + + const Type = enum { + text, + binary, + close, + ping, + pong, + }; +}; + +const OpCode = enum(u8) { + text = 128 | 1, + close = 
128 | 8, + pong = 128 | 10, +}; + +fn websocketHeader(buf: []u8, op_code: OpCode, payload_len: usize) []const u8 { + lp.assert(buf.len == 10, "Websocket.Header", .{ .len = buf.len }); + + const len = payload_len; + buf[0] = 128 | @intFromEnum(op_code); + + if (len <= 125) { + buf[1] = @intCast(len); + return buf[0..2]; + } + + if (len < 65536) { + buf[1] = 126; + buf[2] = @intCast((len >> 8) & 0xFF); + buf[3] = @intCast(len & 0xFF); + return buf[0..4]; + } + + buf[1] = 127; + buf[2] = 0; + buf[3] = 0; + buf[4] = 0; + buf[5] = 0; + buf[6] = @intCast((len >> 24) & 0xFF); + buf[7] = @intCast((len >> 16) & 0xFF); + buf[8] = @intCast((len >> 8) & 0xFF); + buf[9] = @intCast(len & 0xFF); + return buf[0..10]; +} + +fn toLower(str: []u8) []u8 { + for (str, 0..) |ch, i| { + str[i] = std.ascii.toLower(ch); + } + return str; +} + +const backend_supports_vectors = switch (builtin.zig_backend) { + .stage2_llvm, .stage2_c => true, + else => false, +}; + +fn mask(m: []const u8, payload: []u8) void { + var data = payload; + + if (!comptime backend_supports_vectors) return simpleMask(m, data); + + const vector_size = std.simd.suggestVectorLength(u8) orelse @sizeOf(usize); + if (data.len >= vector_size) { + const mask_vector = std.simd.repeat(vector_size, @as(@Vector(4, u8), m[0..4].*)); + while (data.len >= vector_size) { + const slice = data[0..vector_size]; + const masked_data_slice: @Vector(vector_size, u8) = slice.*; + slice.* = masked_data_slice ^ mask_vector; + data = data[vector_size..]; + } + } + simpleMask(m, data); +} + +fn simpleMask(m: []const u8, payload: []u8) void { + for (payload, 0..) |b, i| { + payload[i] = b ^ m[i & 3]; + } +} diff --git a/src/SharedState.zig b/src/SharedState.zig new file mode 100644 index 00000000..7931ae1c --- /dev/null +++ b/src/SharedState.zig @@ -0,0 +1,266 @@ +// Copyright (C) 2023-2025 Lightpanda (Selecy SAS) +// +// Francis Bouvier +// Pierre Tachoire +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as +// published by the Free Software Foundation, either version 3 of the +// License, or (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +const std = @import("std"); +const Allocator = std.mem.Allocator; +const ArenaAllocator = std.heap.ArenaAllocator; + +const log = @import("log.zig"); +const Http = @import("http/Http.zig"); +const HttpClient = @import("http/Client.zig"); +const CurlShare = @import("http/CurlShare.zig"); +const Snapshot = @import("browser/js/Snapshot.zig"); +const Platform = @import("browser/js/Platform.zig"); +const Notification = @import("Notification.zig"); +const App = @import("App.zig"); + +const c = Http.c; + +/// SharedState holds all state shared between CDP sessions (read-only after init). +/// Each SessionThread gets a reference to this and can create its own resources +/// (like HttpClient) that use the shared components. 
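// A minimal sketch of the intended per-session usage, mirroring SessionThread.spawn and
// SessionThread.runInner in this diff (the `self.*` fields below belong to SessionThread,
// not to this file, and `session_memory_limit` comes from the server configuration):
//
//     self.limited_allocator = LimitedAllocator.init(shared.allocator, session_memory_limit);
//     const alloc = self.limited_allocator.allocator();
//
//     // One HttpClient per session; DNS entries, TLS sessions and connections are still
//     // pooled across sessions through the shared curl_share handle.
//     self.http_client = try shared.createHttpClient(alloc);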
+const SharedState = @This(); + +platform: Platform, // V8 platform (process-wide) +snapshot: Snapshot, // V8 startup snapshot +ca_blob: ?c.curl_blob, // TLS certificates +http_opts: Http.Opts, // HTTP configuration +curl_share: *CurlShare, // Shared HTTP resources (DNS, TLS, connections) +notification: *Notification, // Global notification hub +allocator: Allocator, // Thread-safe allocator +arena: ArenaAllocator, // Arena for shared resources +owns_v8_resources: bool, // Track whether V8 resources are owned or borrowed from App + +pub const Config = struct { + max_sessions: u32 = 10, // Max concurrent CDP connections + session_memory_limit: usize = 64 * 1024 * 1024, // 64MB per session + run_mode: App.RunMode, + tls_verify_host: bool = true, + http_proxy: ?[:0]const u8 = null, + proxy_bearer_token: ?[:0]const u8 = null, + http_timeout_ms: ?u31 = null, + http_connect_timeout_ms: ?u31 = null, + http_max_host_open: ?u8 = null, + http_max_concurrent: ?u8 = null, + user_agent: [:0]const u8, +}; + +pub fn init(allocator: Allocator, config: Config) !*SharedState { + const self = try allocator.create(SharedState); + errdefer allocator.destroy(self); + + self.allocator = allocator; + self.arena = ArenaAllocator.init(allocator); + errdefer self.arena.deinit(); + + // Initialize V8 platform (process-wide singleton) + self.platform = try Platform.init(); + errdefer self.platform.deinit(); + + // Load V8 startup snapshot + self.snapshot = try Snapshot.load(); + errdefer self.snapshot.deinit(); + + self.owns_v8_resources = true; + + // Initialize notification hub + self.notification = try Notification.init(allocator, null); + errdefer self.notification.deinit(); + + // Build HTTP options + const arena_alloc = self.arena.allocator(); + var adjusted_opts = Http.Opts{ + .max_host_open = config.http_max_host_open orelse 4, + .max_concurrent = config.http_max_concurrent orelse 10, + .timeout_ms = config.http_timeout_ms orelse 5000, + .connect_timeout_ms = config.http_connect_timeout_ms orelse 0, + .http_proxy = config.http_proxy, + .tls_verify_host = config.tls_verify_host, + .proxy_bearer_token = config.proxy_bearer_token, + .user_agent = config.user_agent, + }; + + if (config.proxy_bearer_token) |bt| { + adjusted_opts.proxy_bearer_token = try std.fmt.allocPrintSentinel(arena_alloc, "Proxy-Authorization: Bearer {s}", .{bt}, 0); + } + self.http_opts = adjusted_opts; + + // Load TLS certificates + if (config.tls_verify_host) { + self.ca_blob = try loadCerts(allocator, arena_alloc); + } else { + self.ca_blob = null; + } + + // Initialize curl share handle for shared resources + self.curl_share = try CurlShare.init(allocator); + errdefer self.curl_share.deinit(); + + return self; +} + +/// Create SharedState by borrowing V8 resources from an existing App. +/// Use this when App is already initialized (e.g., in tests). 
+pub fn initFromApp(app: *App, allocator: Allocator) !*SharedState { + const self = try allocator.create(SharedState); + errdefer allocator.destroy(self); + + self.allocator = allocator; + self.arena = ArenaAllocator.init(allocator); + errdefer self.arena.deinit(); + + // Borrow V8 resources from App (don't initialize new ones) + self.platform = app.platform; + self.snapshot = app.snapshot; + self.owns_v8_resources = false; + + // Initialize notification hub + self.notification = try Notification.init(allocator, app.notification); + errdefer self.notification.deinit(); + + // Build HTTP options from App config + const config = app.config; + const arena_alloc = self.arena.allocator(); + var adjusted_opts = Http.Opts{ + .max_host_open = config.http_max_host_open orelse 4, + .max_concurrent = config.http_max_concurrent orelse 10, + .timeout_ms = config.http_timeout_ms orelse 5000, + .connect_timeout_ms = config.http_connect_timeout_ms orelse 0, + .http_proxy = config.http_proxy, + .tls_verify_host = config.tls_verify_host, + .proxy_bearer_token = config.proxy_bearer_token, + .user_agent = config.user_agent, + }; + + if (config.proxy_bearer_token) |bt| { + adjusted_opts.proxy_bearer_token = try std.fmt.allocPrintSentinel(arena_alloc, "Proxy-Authorization: Bearer {s}", .{bt}, 0); + } + self.http_opts = adjusted_opts; + + // Load TLS certificates + if (config.tls_verify_host) { + self.ca_blob = try loadCerts(allocator, arena_alloc); + } else { + self.ca_blob = null; + } + + // Initialize curl share handle for shared resources + self.curl_share = try CurlShare.init(allocator); + errdefer self.curl_share.deinit(); + + return self; +} + +pub fn deinit(self: *SharedState) void { + const allocator = self.allocator; + + self.notification.deinit(); + self.curl_share.deinit(); + + // Only cleanup V8 resources if we own them + if (self.owns_v8_resources) { + self.snapshot.deinit(); + self.platform.deinit(); + } + + self.arena.deinit(); + + allocator.destroy(self); +} + +/// Create a new HTTP client for a session thread. +/// The client will use the shared curl_share for DNS, TLS, and connection pooling. 
+pub fn createHttpClient(self: *SharedState, session_allocator: Allocator) !*HttpClient {
+    return HttpClient.init(
+        session_allocator,
+        self.ca_blob,
+        self.http_opts,
+        self.curl_share.getHandle(),
+    );
+}
+
+// Adapted from Http.zig
+fn loadCerts(allocator: Allocator, arena: Allocator) !c.curl_blob {
+    var bundle: std.crypto.Certificate.Bundle = .{};
+    try bundle.rescan(allocator);
+    defer bundle.deinit(allocator);
+
+    const bytes = bundle.bytes.items;
+    if (bytes.len == 0) {
+        log.warn(.app, "No system certificates", .{});
+        return .{
+            .len = 0,
+            .flags = 0,
+            .data = bytes.ptr,
+        };
+    }
+
+    const encoder = std.base64.standard.Encoder;
+    var arr: std.ArrayListUnmanaged(u8) = .empty;
+
+    const encoded_size = encoder.calcSize(bytes.len);
+    const buffer_size = encoded_size +
+        (bundle.map.count() * 75) +
+        (encoded_size / 64);
+    try arr.ensureTotalCapacity(arena, buffer_size);
+    var writer = arr.writer(arena);
+
+    var it = bundle.map.valueIterator();
+    while (it.next()) |index| {
+        const cert = try std.crypto.Certificate.der.Element.parse(bytes, index.*);
+
+        try writer.writeAll("-----BEGIN CERTIFICATE-----\n");
+        var line_writer = LineWriter{ .inner = writer };
+        try encoder.encodeWriter(&line_writer, bytes[index.*..cert.slice.end]);
+        try writer.writeAll("\n-----END CERTIFICATE-----\n");
+    }
+
+    return .{
+        .len = arr.items.len,
+        .data = arr.items.ptr,
+        .flags = 0,
+    };
+}
+
+const LineWriter = struct {
+    col: usize = 0,
+    inner: std.ArrayListUnmanaged(u8).Writer,
+
+    pub fn writeAll(self: *LineWriter, data: []const u8) !void {
+        var lwriter = self.inner;
+
+        var col = self.col;
+        const len = 64 - col;
+
+        var remain = data;
+        if (remain.len > len) {
+            col = 0;
+            try lwriter.writeAll(data[0..len]);
+            try lwriter.writeByte('\n');
+            remain = data[len..];
+        }
+
+        while (remain.len > 64) {
+            try lwriter.writeAll(remain[0..64]);
+            try lwriter.writeByte('\n');
+            // Move to the next 64-byte chunk.
+            remain = remain[64..];
+        }
+        try lwriter.writeAll(remain);
+        self.col = col + remain.len;
+    }
+};
diff --git a/src/browser/Browser.zig b/src/browser/Browser.zig
index 1a74468b..e557f9f0 100644
--- a/src/browser/Browser.zig
+++ b/src/browser/Browser.zig
@@ -24,6 +24,7 @@ const ArenaAllocator = std.heap.ArenaAllocator;
 const js = @import("js/js.zig");
 const log = @import("../log.zig");
 const App = @import("../App.zig");
+const SharedState = @import("../SharedState.zig");
 const HttpClient = @import("../http/Client.zig");
 const Notification = @import("../Notification.zig");
@@ -37,7 +38,8 @@ const Session = @import("Session.zig");
 const Browser = @This();
 env: js.Env,
-app: *App,
+shared: ?*SharedState,
+app: ?*App,
 session: ?Session,
 allocator: Allocator,
 http_client: *HttpClient,
@@ -47,7 +49,33 @@ session_arena: ArenaAllocator,
 transfer_arena: ArenaAllocator,
 notification: *Notification,
-pub fn init(app: *App) !Browser {
+/// Initialize a Browser with SharedState (for multi-session CDP mode)
+pub fn init(shared: *SharedState, http_client: *HttpClient, allocator: Allocator) !Browser {
+    var env = try js.Env.init(allocator, &shared.platform, &shared.snapshot);
+    errdefer env.deinit();
+
+    const notification = try Notification.init(allocator, shared.notification);
+    http_client.notification = notification;
+    http_client.next_request_id = 0; // Should we track ids in CDP only?
+ errdefer notification.deinit(); + + return .{ + .shared = shared, + .app = null, + .env = env, + .session = null, + .allocator = allocator, + .notification = notification, + .http_client = http_client, + .call_arena = ArenaAllocator.init(allocator), + .page_arena = ArenaAllocator.init(allocator), + .session_arena = ArenaAllocator.init(allocator), + .transfer_arena = ArenaAllocator.init(allocator), + }; +} + +/// Initialize a Browser with App (for single-session mode like fetch) +pub fn initFromApp(app: *App) !Browser { const allocator = app.allocator; var env = try js.Env.init(allocator, &app.platform, &app.snapshot); @@ -55,10 +83,11 @@ pub fn init(app: *App) !Browser { const notification = try Notification.init(allocator, app.notification); app.http.client.notification = notification; - app.http.client.next_request_id = 0; // Should we track ids in CDP only? + app.http.client.next_request_id = 0; errdefer notification.deinit(); return .{ + .shared = null, .app = app, .env = env, .session = null, diff --git a/src/browser/Session.zig b/src/browser/Session.zig index 340d6b61..e0c91767 100644 --- a/src/browser/Session.zig +++ b/src/browser/Session.zig @@ -66,7 +66,7 @@ pub fn init(self: *Session, browser: *Browser) !void { var executor = try browser.env.newExecutionWorld(); errdefer executor.deinit(); - const allocator = browser.app.allocator; + const allocator = browser.allocator; const session_allocator = browser.session_arena.allocator(); self.* = .{ @@ -86,7 +86,7 @@ pub fn deinit(self: *Session) void { self.removePage(); } self.cookie_jar.deinit(); - self.storage_shed.deinit(self.browser.app.allocator); + self.storage_shed.deinit(self.browser.allocator); self.executor.deinit(); } diff --git a/src/browser/webapi/Navigator.zig b/src/browser/webapi/Navigator.zig index 451c7521..6210c6e7 100644 --- a/src/browser/webapi/Navigator.zig +++ b/src/browser/webapi/Navigator.zig @@ -27,7 +27,15 @@ _pad: bool = false, pub const init: Navigator = .{}; pub fn getUserAgent(_: *const Navigator, page: *Page) []const u8 { - return page._session.browser.app.config.user_agent; + const browser = page._session.browser; + // Handle both modes: SharedState (multi-session) or App (single-session) + if (browser.shared) |shared| { + return shared.http_opts.user_agent; + } else if (browser.app) |app| { + return app.config.user_agent; + } else { + return "Lightpanda/1.0"; + } } pub fn getAppName(_: *const Navigator) []const u8 { diff --git a/src/cdp/cdp.zig b/src/cdp/cdp.zig index 77378d58..3b6a50cb 100644 --- a/src/cdp/cdp.zig +++ b/src/cdp/cdp.zig @@ -25,7 +25,8 @@ const json = std.json; const log = @import("../log.zig"); const js = @import("../browser/js/js.zig"); -const App = @import("../App.zig"); +const SharedState = @import("../SharedState.zig"); +const HttpClient = @import("../http/Client.zig"); const Browser = @import("../browser/Browser.zig"); const Session = @import("../browser/Session.zig"); const Page = @import("../browser/Page.zig"); @@ -78,9 +79,25 @@ pub fn CDPT(comptime TypeProvider: type) type { const Self = @This(); - pub fn init(app: *App, client: TypeProvider.Client) !Self { + pub fn init(shared: *SharedState, http_client: *HttpClient, client: TypeProvider.Client) !Self { + const allocator = shared.allocator; + const browser = try Browser.init(shared, http_client, allocator); + errdefer browser.deinit(); + + return .{ + .client = client, + .browser = browser, + .allocator = allocator, + .browser_context = null, + .message_arena = std.heap.ArenaAllocator.init(allocator), + .notification_arena = 
std.heap.ArenaAllocator.init(allocator), + }; + } + + /// Initialize CDP with App (for testing and single-session mode) + pub fn initFromApp(app: *lp.App, client: TypeProvider.Client) !Self { const allocator = app.allocator; - const browser = try Browser.init(app); + const browser = try Browser.initFromApp(app); errdefer browser.deinit(); return .{ diff --git a/src/cdp/testing.zig b/src/cdp/testing.zig index 52042849..39715288 100644 --- a/src/cdp/testing.zig +++ b/src/cdp/testing.zig @@ -84,7 +84,8 @@ const TestContext = struct { self.client = Client.init(self.arena.allocator()); // Don't use the arena here. We want to detect leaks in CDP. // The arena is only for test-specific stuff - self.cdp_ = TestCDP.init(base.test_app, &self.client.?) catch unreachable; + // Use test_app from base testing (reuses existing V8 platform) + self.cdp_ = TestCDP.initFromApp(base.test_app, &self.client.?) catch unreachable; } return &self.cdp_.?; } diff --git a/src/http/Client.zig b/src/http/Client.zig index c76a355e..8b23a5f4 100644 --- a/src/http/Client.zig +++ b/src/http/Client.zig @@ -124,7 +124,7 @@ pub const CDPClient = struct { const TransferQueue = std.DoublyLinkedList; -pub fn init(allocator: Allocator, ca_blob: ?c.curl_blob, opts: Http.Opts) !*Client { +pub fn init(allocator: Allocator, ca_blob: ?c.curl_blob, opts: Http.Opts, share_handle: ?*c.CURLSH) !*Client { var transfer_pool = std.heap.MemoryPool(Transfer).init(allocator); errdefer transfer_pool.deinit(); @@ -136,7 +136,7 @@ pub fn init(allocator: Allocator, ca_blob: ?c.curl_blob, opts: Http.Opts) !*Clie try errorMCheck(c.curl_multi_setopt(multi, c.CURLMOPT_MAX_HOST_CONNECTIONS, @as(c_long, opts.max_host_open))); - var handles = try Handles.init(allocator, client, ca_blob, &opts); + var handles = try Handles.init(allocator, client, ca_blob, &opts, share_handle); errdefer handles.deinit(allocator); client.* = .{ @@ -650,7 +650,7 @@ const Handles = struct { const HandleList = std.DoublyLinkedList; // pointer to opts is not stable, don't hold a reference to it! - fn init(allocator: Allocator, client: *Client, ca_blob: ?c.curl_blob, opts: *const Http.Opts) !Handles { + fn init(allocator: Allocator, client: *Client, ca_blob: ?c.curl_blob, opts: *const Http.Opts, share_handle: ?*c.CURLSH) !Handles { const count = if (opts.max_concurrent == 0) 1 else opts.max_concurrent; const handles = try allocator.alloc(Handle, count); @@ -658,7 +658,7 @@ const Handles = struct { var available: HandleList = .{}; for (0..count) |i| { - handles[i] = try Handle.init(client, ca_blob, opts); + handles[i] = try Handle.init(client, ca_blob, opts, share_handle); available.append(&handles[i].node); } @@ -706,12 +706,17 @@ pub const Handle = struct { node: Handles.HandleList.Node, // pointer to opts is not stable, don't hold a reference to it! 
- fn init(client: *Client, ca_blob: ?c.curl_blob, opts: *const Http.Opts) !Handle { + fn init(client: *Client, ca_blob: ?c.curl_blob, opts: *const Http.Opts, share_handle: ?*c.CURLSH) !Handle { const conn = try Http.Connection.init(ca_blob, opts); errdefer conn.deinit(); const easy = conn.easy; + // Configure shared resources (DNS cache, TLS sessions, connections) + if (share_handle) |sh| { + try errorCheck(c.curl_easy_setopt(easy, c.CURLOPT_SHARE, sh)); + } + // callbacks try errorCheck(c.curl_easy_setopt(easy, c.CURLOPT_HEADERDATA, easy)); try errorCheck(c.curl_easy_setopt(easy, c.CURLOPT_HEADERFUNCTION, Transfer.headerCallback)); diff --git a/src/http/CurlShare.zig b/src/http/CurlShare.zig new file mode 100644 index 00000000..a928adee --- /dev/null +++ b/src/http/CurlShare.zig @@ -0,0 +1,108 @@ +// Copyright (C) 2023-2025 Lightpanda (Selecy SAS) +// +// Francis Bouvier +// Pierre Tachoire +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as +// published by the Free Software Foundation, either version 3 of the +// License, or (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +const std = @import("std"); +const Allocator = std.mem.Allocator; + +const Http = @import("Http.zig"); +const c = Http.c; + +/// Thread-safe wrapper for libcurl's share handle. +/// Allows multiple CURLM handles (one per session thread) to share: +/// - DNS resolution cache +/// - TLS session resumption data +/// - Connection pool +const CurlShare = @This(); + +handle: *c.CURLSH, +dns_mutex: std.Thread.Mutex, +ssl_mutex: std.Thread.Mutex, +conn_mutex: std.Thread.Mutex, +allocator: Allocator, + +pub fn init(allocator: Allocator) !*CurlShare { + const share = try allocator.create(CurlShare); + errdefer allocator.destroy(share); + + const handle = c.curl_share_init() orelse return error.FailedToInitializeShare; + errdefer _ = c.curl_share_cleanup(handle); + + share.* = .{ + .handle = handle, + .dns_mutex = .{}, + .ssl_mutex = .{}, + .conn_mutex = .{}, + .allocator = allocator, + }; + + // Set up lock/unlock callbacks + try errorSHCheck(c.curl_share_setopt(handle, c.CURLSHOPT_LOCKFUNC, @as(?*const fn (?*c.CURL, c.curl_lock_data, c.curl_lock_access, ?*anyopaque) callconv(.c) void, &lockFunc))); + try errorSHCheck(c.curl_share_setopt(handle, c.CURLSHOPT_UNLOCKFUNC, @as(?*const fn (?*c.CURL, c.curl_lock_data, ?*anyopaque) callconv(.c) void, &unlockFunc))); + try errorSHCheck(c.curl_share_setopt(handle, c.CURLSHOPT_USERDATA, @as(?*anyopaque, share))); + + // Configure what data to share + try errorSHCheck(c.curl_share_setopt(handle, c.CURLSHOPT_SHARE, c.CURL_LOCK_DATA_DNS)); + try errorSHCheck(c.curl_share_setopt(handle, c.CURLSHOPT_SHARE, c.CURL_LOCK_DATA_SSL_SESSION)); + try errorSHCheck(c.curl_share_setopt(handle, c.CURLSHOPT_SHARE, c.CURL_LOCK_DATA_CONNECT)); + + return share; +} + +pub fn deinit(self: *CurlShare) void { + _ = c.curl_share_cleanup(self.handle); + self.allocator.destroy(self); +} + +pub fn getHandle(self: *CurlShare) *c.CURLSH { + return self.handle; +} + +fn lockFunc(_: ?*c.CURL, data: c.curl_lock_data, _: c.curl_lock_access, userptr: ?*anyopaque) 
callconv(.c) void {
+    const self: *CurlShare = @ptrCast(@alignCast(userptr));
+    const mutex = self.getMutex(data) orelse return;
+    mutex.lock();
+}
+
+fn unlockFunc(_: ?*c.CURL, data: c.curl_lock_data, userptr: ?*anyopaque) callconv(.c) void {
+    const self: *CurlShare = @ptrCast(@alignCast(userptr));
+    const mutex = self.getMutex(data) orelse return;
+    mutex.unlock();
+}
+
+fn getMutex(self: *CurlShare, data: c.curl_lock_data) ?*std.Thread.Mutex {
+    return switch (data) {
+        c.CURL_LOCK_DATA_DNS => &self.dns_mutex,
+        c.CURL_LOCK_DATA_SSL_SESSION => &self.ssl_mutex,
+        c.CURL_LOCK_DATA_CONNECT => &self.conn_mutex,
+        else => null,
+    };
+}
+
+fn errorSHCheck(code: c.CURLSHcode) !void {
+    if (code == c.CURLSHE_OK) {
+        return;
+    }
+    return switch (code) {
+        c.CURLSHE_BAD_OPTION => error.ShareBadOption,
+        c.CURLSHE_IN_USE => error.ShareInUse,
+        c.CURLSHE_INVALID => error.ShareInvalid,
+        c.CURLSHE_NOMEM => error.OutOfMemory,
+        c.CURLSHE_NOT_BUILT_IN => error.ShareNotBuiltIn,
+        else => error.ShareUnknown,
+    };
+}
diff --git a/src/http/Http.zig b/src/http/Http.zig
index 2465eaf1..aa42a7e1 100644
--- a/src/http/Http.zig
+++ b/src/http/Http.zig
@@ -66,7 +66,7 @@ pub fn init(allocator: Allocator, opts: Opts) !Http {
         ca_blob = try loadCerts(allocator, arena.allocator());
     }
-    var client = try Client.init(allocator, ca_blob, adjusted_opts);
+    var client = try Client.init(allocator, ca_blob, adjusted_opts, null);
     errdefer client.deinit();
     return .{
diff --git a/src/lightpanda.zig b/src/lightpanda.zig
index 361d872d..79ffaa8b 100644
--- a/src/lightpanda.zig
+++ b/src/lightpanda.zig
@@ -19,6 +19,9 @@ const std = @import("std");
 pub const App = @import("App.zig");
 pub const Server = @import("Server.zig");
+pub const SharedState = @import("SharedState.zig");
+pub const SessionThread = @import("SessionThread.zig");
+pub const SessionManager = @import("SessionManager.zig");
 pub const Page = @import("browser/Page.zig");
 pub const Browser = @import("browser/Browser.zig");
 pub const Session = @import("browser/Session.zig");
@@ -37,7 +40,7 @@ pub const FetchOpts = struct {
     writer: ?*std.Io.Writer = null,
 };
 pub fn fetch(app: *App, url: [:0]const u8, opts: FetchOpts) !void {
-    var browser = try Browser.init(app);
+    var browser = try Browser.initFromApp(app);
     defer browser.deinit();
     var session = try browser.newSession();
diff --git a/src/main.zig b/src/main.zig
index d0d83b56..c480e6a5 100644
--- a/src/main.zig
+++ b/src/main.zig
@@ -108,8 +108,12 @@ fn run(allocator: Allocator, main_arena: Allocator, sighandler: *SigHandler) !vo
         return args.printUsageAndExit(false);
     };
+    // Create SharedState for multi-session CDP server. It is owned by the App
+    // and released in App.deinit().
+    const shared = try app.createSharedState();
+
     // _server is global to handle graceful shutdown.
- var server = try lp.Server.init(app, address); + var server = try lp.Server.init(shared, address, app.config.max_sessions, app.config.session_memory_limit); defer server.deinit(); try sighandler.on(lp.Server.stop, .{&server}); diff --git a/src/main_legacy_test.zig b/src/main_legacy_test.zig index 78b85472..9e9892c8 100644 --- a/src/main_legacy_test.zig +++ b/src/main_legacy_test.zig @@ -43,7 +43,7 @@ pub fn main() !void { var test_arena = std.heap.ArenaAllocator.init(allocator); defer test_arena.deinit(); - var browser = try lp.Browser.init(app); + var browser = try lp.Browser.initFromApp(app); defer browser.deinit(); const session = try browser.newSession(); diff --git a/src/main_wpt.zig b/src/main_wpt.zig index a70aacc6..5140705f 100644 --- a/src/main_wpt.zig +++ b/src/main_wpt.zig @@ -65,7 +65,7 @@ pub fn main() !void { }); defer app.deinit(); - var browser = try lp.Browser.init(app); + var browser = try lp.Browser.initFromApp(app); defer browser.deinit(); // An arena for running each tests. Is reset after every test. diff --git a/src/testing.zig b/src/testing.zig index c022d830..19ece8dd 100644 --- a/src/testing.zig +++ b/src/testing.zig @@ -441,8 +441,10 @@ const log = @import("log.zig"); const TestHTTPServer = @import("TestHTTPServer.zig"); const Server = @import("Server.zig"); +const SharedState = @import("SharedState.zig"); var test_cdp_server: ?Server = null; var test_http_server: ?TestHTTPServer = null; +var test_shared_state: ?*SharedState = null; test "tests:beforeAll" { log.opts.level = .warn; @@ -455,7 +457,7 @@ test "tests:beforeAll" { }); errdefer test_app.deinit(); - test_browser = try Browser.init(test_app); + test_browser = try Browser.initFromApp(test_app); errdefer test_browser.deinit(); test_session = try test_browser.newSession(); @@ -483,6 +485,10 @@ test "tests:afterAll" { if (test_cdp_server) |*server| { server.deinit(); } + if (test_shared_state) |shared| { + shared.deinit(); + test_shared_state = null; + } if (test_http_server) |*server| { server.deinit(); } @@ -495,13 +501,16 @@ test "tests:afterAll" { fn serveCDP(wg: *std.Thread.WaitGroup) !void { const address = try std.net.Address.parseIp("127.0.0.1", 9583); - test_cdp_server = try Server.init(test_app, address); - var server = try Server.init(test_app, address); + // Create SharedState by borrowing V8 resources from test_app + test_shared_state = try SharedState.initFromApp(test_app, @import("root").tracking_allocator); + + var server = try Server.init(test_shared_state.?, address, 10, 64 * 1024 * 1024); + test_cdp_server = server; defer server.deinit(); wg.finish(); - test_cdp_server.?.run(address, 5) catch |err| { + server.run(address, 5) catch |err| { std.debug.print("CDP server error: {}", .{err}); return err; };
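
A minimal sketch of how the pieces introduced above are expected to fit together in server mode, following the main.zig and testing.zig hunks. Server.zig itself is not part of this diff, so the run() arguments mirror the testing.zig call, the function name and listen address are only illustrative, and the per-connection SessionThread/SessionManager behaviour is assumed from the comments in those files rather than confirmed here.

const std = @import("std");
const lp = @import("lightpanda");

fn startMultiSessionServer(app: *lp.App) !void {
    // SharedState (V8 platform/snapshot, curl share handle, HTTP options) is created
    // once and owned by the App; App.deinit() releases it.
    const shared = try app.createSharedState();

    const address = try std.net.Address.parseIp("127.0.0.1", 9222);
    var server = try lp.Server.init(
        shared,
        address,
        app.config.max_sessions,
        app.config.session_memory_limit,
    );
    defer server.deinit();

    // Each accepted CDP connection is expected to become a SessionThread with its own
    // LimitedAllocator, HttpClient and V8 isolate, tracked by a SessionManager capped
    // at max_sessions. The second run() argument follows the testing.zig call; its
    // meaning is defined in Server.zig.
    try server.run(address, 5);
}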