From c0c0694fcc9d56101347d7ba9b8f4b66d1d1d176 Mon Sep 17 00:00:00 2001 From: Karl Seguin Date: Thu, 6 Feb 2025 22:05:01 +0800 Subject: [PATCH] Make TCP server websocket-aware Adding HTTP & websocket awareness to the TCP server. HTTP server handles `GET /json/version` and websocket upgrade requests. Conceptually, websocket handling is the same code as before, but receiving data will parse the websocket frames and writing data will wrap it in a websocket frame. The previous `Ctx` was split into a `Server` and a `Client`. This was largely done to make it easy to write unit tests, since the `Client` is a generic, all its dependencies (i.e. the server) can be mocked out. This also makes it a bit nicer to know if there is or isn't a client (via the server's client optional). Added a MemoryPool for the Send object (I thought that was a nice touch!) Removed MacOS hack on accept/conn completion usage. Known issues: - When framing an outgoing message, the entire message has to be duped. This is no worse than how it was before, but it should be possible to eliminate this in the future. Probably not part of this PR. - Websocket parsing will reject continuation frames. I don't know of a single client that will send a fragmented message (websocket has its own message fragmentation), but we should probably still support this just in case. - I don't think the receive, timeout and close completions can safely be re-used like we're doing. I believe they need to be associated with a specific client socket. - A new connection creates a new browser session. I think this is right (??), but for the very first, we're throwing out a perfectly usable session. I'm thinking this might be a change to how Browser/Sessions work. - zig build test won't compile. This branch reproduces the issue with none of these changes: https://github.com/karlseguin/browser/tree/broken_test_build (or, as a diff to main): https://github.com/lightpanda-io/browser/compare/main...karlseguin:broken_test_build --- .gitmodules | 4 - src/cdp/runtime.zig | 4 +- src/handler.zig | 95 --- src/main.zig | 71 +- src/main_tests.zig | 4 +- src/msg.zig | 166 ----- src/server.zig | 1722 ++++++++++++++++++++++++++++++++++--------- 7 files changed, 1373 insertions(+), 693 deletions(-) delete mode 100644 src/handler.zig delete mode 100644 src/msg.zig diff --git a/.gitmodules b/.gitmodules index 184dd202..5743ca29 100644 --- a/.gitmodules +++ b/.gitmodules @@ -28,7 +28,3 @@ [submodule "vendor/zig-async-io"] path = vendor/zig-async-io url = https://github.com/lightpanda-io/zig-async-io.git/ -[submodule "vendor/websocket.zig"] - path = vendor/websocket.zig - url = https://github.com/lightpanda-io/websocket.zig.git/ - branch = lightpanda diff --git a/src/cdp/runtime.zig b/src/cdp/runtime.zig index 44c1a907..054d5a78 100644 --- a/src/cdp/runtime.zig +++ b/src/cdp/runtime.zig @@ -131,12 +131,12 @@ fn sendInspector( const buf = try alloc.alloc(u8, msg.json.len + 1); defer alloc.free(buf); _ = std.mem.replace(u8, msg.json, "\"awaitPromise\":true", "\"awaitPromise\":false", buf); - ctx.sendInspector(buf); + try ctx.sendInspector(buf); return ""; } } - ctx.sendInspector(msg.json); + try ctx.sendInspector(msg.json); if (msg.id == null) return ""; diff --git a/src/handler.zig b/src/handler.zig deleted file mode 100644 index 0decb3f7..00000000 --- a/src/handler.zig +++ /dev/null @@ -1,95 +0,0 @@ -// Copyright (C) 2023-2024 Lightpanda (Selecy SAS) -// -// Francis Bouvier -// Pierre Tachoire -// -// This program is free software: you can redistribute it and/or modify -// it under the terms of the GNU Affero General Public License as -// published by the Free Software Foundation, either version 3 of the -// License, or (at your option) any later version. -// -// This program is distributed in the hope that it will be useful, -// but WITHOUT ANY WARRANTY; without even the implied warranty of -// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the -// GNU Affero General Public License for more details. -// -// You should have received a copy of the GNU Affero General Public License -// along with this program. If not, see . - -const std = @import("std"); - -const ws = @import("websocket"); -const Msg = @import("msg.zig").Msg; - -const log = std.log.scoped(.handler); - -pub const Stream = struct { - addr: std.net.Address, - socket: std.posix.socket_t = undefined, - - ws_host: []const u8, - ws_port: u16, - ws_conn: *ws.Conn = undefined, - - fn connectCDP(self: *Stream) !void { - const flags: u32 = std.posix.SOCK.STREAM; - const proto = blk: { - if (self.addr.any.family == std.posix.AF.UNIX) break :blk @as(u32, 0); - break :blk std.posix.IPPROTO.TCP; - }; - const socket = try std.posix.socket(self.addr.any.family, flags, proto); - - try std.posix.connect( - socket, - &self.addr.any, - self.addr.getOsSockLen(), - ); - log.debug("connected to Stream server", .{}); - self.socket = socket; - } - - fn closeCDP(self: *const Stream) void { - const close_msg: []const u8 = .{ 5, 0, 0, 0 } ++ "close"; - self.recv(close_msg) catch |err| { - log.err("stream close error: {any}", .{err}); - }; - std.posix.close(self.socket); - } - - fn start(self: *Stream, ws_conn: *ws.Conn) !void { - try self.connectCDP(); - self.ws_conn = ws_conn; - } - - pub fn recv(self: *const Stream, data: []const u8) !void { - var pos: usize = 0; - while (pos < data.len) { - const len = try std.posix.write(self.socket, data[pos..]); - pos += len; - } - } - - pub fn send(self: *const Stream, data: []const u8) !void { - return self.ws_conn.write(data); - } -}; - -pub const Handler = struct { - stream: *Stream, - - pub fn init(_: ws.Handshake, ws_conn: *ws.Conn, stream: *Stream) !Handler { - try stream.start(ws_conn); - return .{ .stream = stream }; - } - - pub fn close(self: *Handler) void { - self.stream.closeCDP(); - } - - pub fn clientMessage(self: *Handler, data: []const u8) !void { - var header: [4]u8 = undefined; - Msg.setSize(data.len, &header); - try self.stream.recv(&header); - try self.stream.recv(data); - } -}; diff --git a/src/main.zig b/src/main.zig index e4da1df2..c5c04996 100644 --- a/src/main.zig +++ b/src/main.zig @@ -20,12 +20,9 @@ const std = @import("std"); const builtin = @import("builtin"); const jsruntime = @import("jsruntime"); -const websocket = @import("websocket"); const Browser = @import("browser/browser.zig").Browser; const server = @import("server.zig"); -const handler = @import("handler.zig"); -const MaxSize = @import("msg.zig").MaxSize; const parser = @import("netsurf"); const apiweb = @import("apiweb.zig"); @@ -86,11 +83,9 @@ const CliMode = union(CliModeTag) { const Server = struct { execname: []const u8 = undefined, args: *std.process.ArgIterator = undefined, - addr: std.net.Address = undefined, host: []const u8 = Host, port: u16 = Port, timeout: u8 = Timeout, - tcp: bool = false, // undocumented TCP mode // default options const Host = "127.0.0.1"; @@ -160,10 +155,6 @@ const CliMode = union(CliModeTag) { return printUsageExit(execname, 1); } } - if (std.mem.eql(u8, "--tcp", opt)) { - _server.tcp = true; - continue; - } // unknown option if (std.mem.startsWith(u8, opt, "--")) { @@ -186,10 +177,6 @@ const CliMode = union(CliModeTag) { if (default_mode == .server) { // server mode - _server.addr = std.net.Address.parseIp4(_server.host, _server.port) catch |err| { - log.err("address (host:port) {any}\n", .{err}); - return printUsageExit(execname, 1); - }; _server.execname = execname; _server.args = args; return CliMode{ .server = _server }; @@ -247,65 +234,19 @@ pub fn main() !void { switch (cli_mode) { .server => |opts| { - - // Stream server - const addr = blk: { - if (opts.tcp) { - break :blk opts.addr; - } else { - const unix_path = "/tmp/lightpanda"; - std.fs.deleteFileAbsolute(unix_path) catch {}; // file could not exists - break :blk try std.net.Address.initUnix(unix_path); - } - }; - const socket = server.listen(addr) catch |err| { - log.err("Server listen error: {any}\n", .{err}); + const address = std.net.Address.parseIp4(opts.host, opts.port) catch |err| { + log.err("address (host:port) {any}\n", .{err}); return printUsageExit(opts.execname, 1); }; - defer std.posix.close(socket); - log.debug("Server opts: listening internally on {any}...", .{addr}); - const timeout = std.time.ns_per_s * @as(u64, opts.timeout); - - // loop var loop = try jsruntime.Loop.init(alloc); defer loop.deinit(); - // TCP server mode - if (opts.tcp) { - return server.handle(alloc, &loop, socket, null, timeout); - } - - // start stream server in separate thread - var stream = handler.Stream{ - .ws_host = opts.host, - .ws_port = opts.port, - .addr = addr, + const timeout = std.time.ns_per_s * @as(u64, opts.timeout); + server.run(alloc, address, timeout, &loop) catch |err| { + log.err("Server error", .{}); + return err; }; - const cdp_thread = try std.Thread.spawn( - .{ .allocator = alloc }, - server.handle, - .{ alloc, &loop, socket, &stream, timeout }, - ); - - // Websocket server - var ws = try websocket.Server(handler.Handler).init(alloc, .{ - .port = opts.port, - .address = opts.host, - .max_message_size = MaxSize + 14, // overhead websocket - .max_conn = 1, - .handshake = .{ - .timeout = 3, - .max_size = 1024, - // since we aren't using hanshake.headers - // we can set this to 0 to save a few bytes. - .max_headers = 0, - }, - }); - defer ws.deinit(); - - try ws.listen(&stream); - cdp_thread.join(); }, .fetch => |opts| { diff --git a/src/main_tests.zig b/src/main_tests.zig index 3544cc6a..88acd1a5 100644 --- a/src/main_tests.zig +++ b/src/main_tests.zig @@ -314,9 +314,6 @@ const kb = 1024; const ms = std.time.ns_per_ms; test { - const msgTest = @import("msg.zig"); - std.testing.refAllDecls(msgTest); - const dumpTest = @import("browser/dump.zig"); std.testing.refAllDecls(dumpTest); @@ -340,6 +337,7 @@ test { std.testing.refAllDecls(@import("generate.zig")); std.testing.refAllDecls(@import("cdp/msg.zig")); + std.testing.refAllDecls(@import("server.zig")); } fn testJSRuntime(alloc: std.mem.Allocator) !void { diff --git a/src/msg.zig b/src/msg.zig deleted file mode 100644 index 13b7a62e..00000000 --- a/src/msg.zig +++ /dev/null @@ -1,166 +0,0 @@ -// Copyright (C) 2023-2024 Lightpanda (Selecy SAS) -// -// Francis Bouvier -// Pierre Tachoire -// -// This program is free software: you can redistribute it and/or modify -// it under the terms of the GNU Affero General Public License as -// published by the Free Software Foundation, either version 3 of the -// License, or (at your option) any later version. -// -// This program is distributed in the hope that it will be useful, -// but WITHOUT ANY WARRANTY; without even the implied warranty of -// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the -// GNU Affero General Public License for more details. -// -// You should have received a copy of the GNU Affero General Public License -// along with this program. If not, see . - -const std = @import("std"); - -pub const HeaderSize = 4; -pub const MsgSize = 256 * 1204; // 256KB -// NOTE: Theorically we could go up to 4GB with a 4 bytes binary encoding -// but we prefer to put a lower hard limit for obvious memory size reasons. - -pub const MaxSize = HeaderSize + MsgSize; - -pub const Msg = struct { - pub fn getSize(data: []const u8) usize { - return std.mem.readInt(u32, data[0..HeaderSize], .little); - } - - pub fn setSize(len: usize, header: *[4]u8) void { - std.mem.writeInt(u32, header, @intCast(len), .little); - } -}; - -/// Buffer returns messages from a raw text read stream, -/// with the message size being encoded on the 2 first bytes (little endian) -/// It handles both: -/// - combined messages in one read -/// - single message in several reads (multipart) -/// It's safe (and a good practice) to reuse the same Buffer -/// on several reads of the same stream. -pub const Buffer = struct { - buf: []u8, - size: usize = 0, - pos: usize = 0, - - fn isFinished(self: *const Buffer) bool { - return self.pos >= self.size; - } - - fn isEmpty(self: *const Buffer) bool { - return self.size == 0 and self.pos == 0; - } - - fn reset(self: *Buffer) void { - self.size = 0; - self.pos = 0; - } - - // read input - pub fn read(self: *Buffer, input: []const u8) !struct { - msg: []const u8, - left: []const u8, - } { - var _input = input; // make input writable - - // msg size - var msg_size: usize = undefined; - if (self.isEmpty()) { - // decode msg size header - msg_size = Msg.getSize(_input); - _input = _input[HeaderSize..]; - } else { - msg_size = self.size; - } - - // multipart - const is_multipart = !self.isEmpty() or _input.len < msg_size; - if (is_multipart) { - - // set msg size on empty Buffer - if (self.isEmpty()) { - self.size = msg_size; - } - - // get the new position of the cursor - const new_pos = self.pos + _input.len; - - // check max limit size - if (new_pos > MaxSize) { - return error.MsgTooBig; - } - - // copy the current input into Buffer - // NOTE: we could use @memcpy but it's not Thread-safe (alias problem) - // see https://www.openmymind.net/Zigs-memcpy-copyForwards-and-copyBackwards/ - // Intead we just use std.mem.copyForwards - std.mem.copyForwards(u8, self.buf[self.pos..new_pos], _input[0..]); - - // set the new cursor position - self.pos = new_pos; - - // if multipart is not finished, go fetch the next input - if (!self.isFinished()) return error.MsgMultipart; - - // otherwhise multipart is finished, use its buffer as input - _input = self.buf[0..self.pos]; - self.reset(); - } - - // handle several JSON msg in 1 read - return .{ .msg = _input[0..msg_size], .left = _input[msg_size..] }; - } -}; - -test "Buffer" { - const Case = struct { - input: []const u8, - nb: u8, - }; - - const cases = [_]Case{ - // simple - .{ .input = .{ 2, 0, 0, 0 } ++ "ok", .nb = 1 }, - // combined - .{ .input = .{ 2, 0, 0, 0 } ++ "ok" ++ .{ 3, 0, 0, 0 } ++ "foo", .nb = 2 }, - // multipart - .{ .input = .{ 9, 0, 0, 0 } ++ "multi", .nb = 0 }, - .{ .input = "part", .nb = 1 }, - // multipart & combined - .{ .input = .{ 9, 0, 0, 0 } ++ "multi", .nb = 0 }, - .{ .input = "part" ++ .{ 2, 0, 0, 0 } ++ "ok", .nb = 2 }, - // multipart & combined with other multipart - .{ .input = .{ 9, 0, 0, 0 } ++ "multi", .nb = 0 }, - .{ .input = "part" ++ .{ 8, 0, 0, 0 } ++ "co", .nb = 1 }, - .{ .input = "mbined", .nb = 1 }, - // several multipart - .{ .input = .{ 23, 0, 0, 0 } ++ "multi", .nb = 0 }, - .{ .input = "several", .nb = 0 }, - .{ .input = "complex", .nb = 0 }, - .{ .input = "part", .nb = 1 }, - // combined & multipart - .{ .input = .{ 2, 0, 0, 0 } ++ "ok" ++ .{ 9, 0, 0, 0 } ++ "multi", .nb = 1 }, - .{ .input = "part", .nb = 1 }, - }; - - var b: [MaxSize]u8 = undefined; - var buf = Buffer{ .buf = &b }; - - for (cases) |case| { - var nb: u8 = 0; - var input = case.input; - while (input.len > 0) { - const parts = buf.read(input) catch |err| { - if (err == error.MsgMultipart) break; // go to the next case input - return err; - }; - nb += 1; - input = parts.left; - } - try std.testing.expect(nb == case.nb); - } -} diff --git a/src/server.zig b/src/server.zig index cbe16e12..5f18b6e9 100644 --- a/src/server.zig +++ b/src/server.zig @@ -19,7 +19,10 @@ const std = @import("std"); const builtin = @import("builtin"); -const Stream = @import("handler.zig").Stream; +const net = std.net; +const posix = std.posix; + +const Allocator = std.mem.Allocator; const jsruntime = @import("jsruntime"); const Completion = jsruntime.IO.Completion; @@ -30,237 +33,232 @@ const CloseError = jsruntime.IO.CloseError; const CancelError = jsruntime.IO.CancelError; const TimeoutError = jsruntime.IO.TimeoutError; -const MsgBuffer = @import("msg.zig").Buffer; -const MaxSize = @import("msg.zig").MaxSize; const Browser = @import("browser/browser.zig").Browser; const cdp = @import("cdp/cdp.zig"); -const NoError = error{NoError}; const IOError = AcceptError || RecvError || SendError || CloseError || TimeoutError || CancelError; -const Error = IOError || std.fmt.ParseIntError || cdp.Error || NoError; +const HTTPError = error{ + OutOfMemory, + RequestTooLarge, + NotFound, + InvalidRequest, + MissingHeaders, + InvalidProtocol, + InvalidUpgradeHeader, + InvalidVersionHeader, + InvalidConnectionHeader, +}; +const WebSocketError = error{ + OutOfMemory, + ReservedFlags, + NotMasked, + TooLarge, + InvalidMessageType, + ContinuationNotSupported, +}; +const Error = IOError || cdp.Error || HTTPError || WebSocketError; const TimeoutCheck = std.time.ns_per_ms * 100; const log = std.log.scoped(.server); -const isLinux = builtin.target.os.tag == .linux; -// I/O Main -// -------- +const MAX_HTTP_REQUEST_SIZE = 2048; -const BufReadSize = 1024; // 1KB -const MaxStdOutSize = 512; // ensure debug msg are not too long +// max message size, +14 for max websocket payload overhead +const MAX_MESSAGE_SIZE = 256 * 1024 + 14; -pub const Ctx = struct { +// For now, cdp does @import("server.zig").Ctx. Could change cdp to use "Server" +// but I rather try to decouple the CDP code from the server, so a quick +// stopgap is fine. TODO: Decouple cdp from the server +pub const Ctx = Server; + +const Server = struct { + allocator: Allocator, loop: *jsruntime.Loop, - stream: ?*Stream, // internal fields - accept_socket: std.posix.socket_t, - conn_socket: std.posix.socket_t = undefined, - read_buf: []u8, // only for read operations - msg_buf: *MsgBuffer, - err: ?Error = null, + listener: posix.socket_t, + client: ?Client(*Server) = null, + timeout: u64, + + // a memory poor for our Send objects + send_pool: std.heap.MemoryPool(Send), // I/O fields - accept_completion: *Completion, - conn_completion: *Completion, - timeout_completion: *Completion, - timeout: u64, - last_active: ?std.time.Instant = null, + conn_completion: Completion, + close_completion: Completion, + accept_completion: Completion, + timeout_completion: Completion, + + // used when gluing the session id to the inspector message + scrap: std.ArrayListUnmanaged(u8) = .{}, + + // The response to send on a GET /json/version request + json_version_response: []const u8, // CDP state: cdp.State = .{}, // JS fields browser: *Browser, // TODO: is pointer mandatory here? - sessionNew: bool, - // try_catch: jsruntime.TryCatch, // TODO - // callbacks - // --------- + fn deinit(self: *Server) void { + self.send_pool.deinit(); + self.allocator.free(self.json_version_response); + } - fn acceptCbk( - self: *Ctx, - completion: *Completion, - result: AcceptError!std.posix.socket_t, - ) void { - std.debug.assert(completion == self.acceptCompletion()); - - self.conn_socket = result catch |err| { - log.err("accept error: {any}", .{err}); - self.err = err; - return; - }; - log.info("client connected", .{}); - - // set connection timestamp and timeout - self.last_active = std.time.Instant.now() catch |err| { - log.err("accept timestamp error: {any}", .{err}); - return; - }; - self.loop.io.timeout( - *Ctx, + fn queueAccept(self: *Server) void { + log.info("accepting new conn...", .{}); + self.loop.io.accept( + *Server, self, - Ctx.timeoutCbk, - self.timeout_completion, - TimeoutCheck, - ); - - // receving incomming messages asynchronously - self.loop.io.recv( - *Ctx, - self, - Ctx.readCbk, - self.conn_completion, - self.conn_socket, - self.read_buf, + callbackAccept, + &self.accept_completion, + self.listener, ); } - fn readCbk(self: *Ctx, completion: *Completion, result: RecvError!usize) void { - std.debug.assert(completion == self.conn_completion); + fn callbackAccept( + self: *Server, + completion: *Completion, + result: AcceptError!posix.socket_t, + ) void { + std.debug.assert(completion == &self.accept_completion); - const size = result catch |err| { - if (self.isClosed() and err == error.FileDescriptorInvalid) { - log.debug("read has been canceled", .{}); + const socket = result catch |err| { + log.err("accept error: {any}", .{err}); + self.queueAccept(); + return; + }; + + self.newSession() catch |err| { + log.err("new session error: {any}", .{err}); + self.queueClose(socket); + return; + }; + + log.info("client connected", .{}); + self.client = Client(*Server).init(socket, self); + self.queueRead(); + self.queueTimeout(); + } + + fn queueTimeout(self: *Server) void { + self.loop.io.timeout( + *Server, + self, + callbackTimeout, + &self.timeout_completion, + TimeoutCheck, + ); + } + + fn callbackTimeout( + self: *Server, + completion: *Completion, + result: TimeoutError!void, + ) void { + std.debug.assert(completion == &self.timeout_completion); + + const client = &(self.client orelse return); + + if (result) |_| { + if (now().since(client.last_active) > self.timeout) { + // close current connection + log.debug("conn timeout, closing...", .{}); + client.close(.timeout); return; } - log.err("read error: {any}", .{err}); - self.err = err; - return; - }; - - if (size == 0) { - // continue receving incomming messages asynchronously - self.loop.io.recv( - *Ctx, - self, - Ctx.readCbk, - self.conn_completion, - self.conn_socket, - self.read_buf, - ); - return; - } - - // set connection timestamp - self.last_active = std.time.Instant.now() catch |err| { - log.err("read timestamp error: {any}", .{err}); - return; - }; - - // continue receving incomming messages asynchronously - self.loop.io.recv( - *Ctx, - self, - Ctx.readCbk, - self.conn_completion, - self.conn_socket, - self.read_buf, - ); - - // read and execute input - var input: []const u8 = self.read_buf[0..size]; - while (input.len > 0) { - const parts = self.msg_buf.read(input) catch |err| { - if (err == error.MsgMultipart) { - return; - } else { - log.err("msg read error: {any}", .{err}); - return; - } - }; - input = parts.left; - // execute - self.do(parts.msg) catch |err| { - if (err != error.Closed) { - log.err("do error: {any}", .{err}); - log.debug("last msg: {s}", .{parts.msg}); - } - }; - } - } - - fn timeoutCbk(self: *Ctx, completion: *Completion, result: TimeoutError!void) void { - std.debug.assert(completion == self.timeout_completion); - - _ = result catch |err| { + } else |err| { log.err("timeout error: {any}", .{err}); - self.err = err; + } + + // We re-queue this if the timeout hasn't been exceeded or on some + // very unlikely IO timeout error. + // AKA: we don't requeue this if the connection timed out and we + // closed the connection.s + self.queueTimeout(); + } + + fn queueRead(self: *Server) void { + if (self.client) |*client| { + self.loop.io.recv( + *Server, + self, + callbackRead, + &self.conn_completion, + client.socket, + client.readBuf(), + ); + } + } + + fn callbackRead( + self: *Server, + completion: *Completion, + result: RecvError!usize, + ) void { + std.debug.assert(completion == &self.conn_completion); + + var client = &(self.client orelse return); + + const size = result catch |err| { + log.err("read error: {any}", .{err}); + self.queueClose(client.socket); return; }; - if (self.isClosed()) { - // conn is already closed, ignore timeout - return; - } - - // check time since last read - const now = std.time.Instant.now() catch |err| { - log.err("timeout timestamp error: {any}", .{err}); + const more = client.processData(size) catch |err| { + std.debug.print("Client Processing Error: {}\n", .{err}); return; }; - if (now.since(self.last_active.?) > self.timeout) { - // close current connection - log.debug("conn timeout, closing...", .{}); - self.close(); - return; + // if more == false, the client is disconnecting + if (more) { + self.queueRead(); } + } - // continue checking timeout - self.loop.io.timeout( - *Ctx, + fn queueSend( + self: *Server, + socket: posix.socket_t, + data: []const u8, + free_when_done: bool, + ) !void { + const sd = try self.send_pool.create(); + errdefer self.send_pool.destroy(sd); + + sd.* = .{ + .data = data, + .unsent = data, + .server = self, + .socket = socket, + .completion = undefined, + .free_when_done = free_when_done, + }; + sd.queueSend(); + } + + fn queueClose(self: *Server, socket: posix.socket_t) void { + self.loop.io.close( + *Server, self, - Ctx.timeoutCbk, - self.timeout_completion, - TimeoutCheck, + callbackClose, + &self.close_completion, + socket, ); } - // shortcuts - // --------- - - inline fn isClosed(self: *Ctx) bool { - // last_active is first saved on acceptCbk - return self.last_active == null; - } - - // allocator of the current session - inline fn alloc(self: *Ctx) std.mem.Allocator { - return self.browser.session.alloc; - } - - // JS env of the current session - inline fn env(self: Ctx) jsruntime.Env { - return self.browser.session.env; - } - - inline fn acceptCompletion(self: *Ctx) *Completion { - // NOTE: the logical completion to use here is the accept_completion - // as the pipe_connection can be used simulteanously by a recv I/O operation. - // But on MacOS (kqueue) the recv I/O operation on a closed socket leads to a panic - // so we use the pipe_connection to avoid this problem - if (isLinux) return self.accept_completion; - return self.conn_completion; - } - - // actions - // ------- - - fn do(self: *Ctx, cmd: []const u8) anyerror!void { - - // close cmd - if (std.mem.eql(u8, cmd, "close")) { - // close connection - log.info("close cmd, closing conn...", .{}); - self.close(); - return error.Closed; + fn callbackClose(self: *Server, completion: *Completion, _: CloseError!void) void { + std.debug.assert(completion == &self.close_completion); + if (self.client != null) { + self.client = null; } + self.queueAccept(); + } - if (self.sessionNew) self.sessionNew = false; - - const res = cdp.do(self.alloc(), cmd, self) catch |err| { + fn handleCDP(self: *Server, cmd: []const u8) !void { + const res = cdp.do(self.allocator, cmd, self) catch |err| { // cdp end cmd if (err == error.DisposeBrowserContext) { @@ -274,106 +272,106 @@ pub const Ctx = struct { }; // send result - if (!std.mem.eql(u8, res, "")) { + if (res.len != 0) { return self.send(res); } } - pub fn send(self: *Ctx, msg: []const u8) !void { - if (self.stream) |stream| { - // if we have a stream connection, just write on it - defer self.alloc().free(msg); - try stream.send(msg); - } else { - // otherwise write asynchronously on the socket connection - return sendAsync(self, msg); + // called from CDP + pub fn send(self: *Server, data: []const u8) !void { + if (self.client) |*client| { + try client.sendWS(data); } } - fn close(self: *Ctx) void { - - // conn is closed - self.last_active = null; - std.posix.close(self.conn_socket); - log.debug("connection closed", .{}); - - // restart a new browser session in case of re-connect - if (!self.sessionNew) { - self.newSession() catch |err| { - log.err("new session error: {any}", .{err}); - return; - }; - } - - log.info("accepting new conn...", .{}); - - // continue accepting incoming requests - self.loop.io.accept( - *Ctx, - self, - Ctx.acceptCbk, - self.acceptCompletion(), - self.accept_socket, - ); - } - - fn newSession(self: *Ctx) !void { - try self.browser.newSession(self.alloc(), self.loop); + fn newSession(self: *Server) !void { + try self.browser.newSession(self.allocator, self.loop); try self.browser.session.initInspector( self, - Ctx.onInspectorResp, - Ctx.onInspectorNotif, + inspectorResponse, + inspectorEvent, ); - self.sessionNew = true; } - // inspector - // --------- + // // inspector + // // --------- - pub fn sendInspector(self: *Ctx, msg: []const u8) void { - if (self.env().getInspector()) |inspector| { - inspector.send(self.env(), msg); - } else @panic("Inspector has not been set"); + // called by cdp + pub fn sendInspector(self: *Server, msg: []const u8) !void { + const env = self.browser.session.env; + if (env.getInspector()) |inspector| { + inspector.send(env, msg); + return; + } + return error.InspectNotSet; } - inline fn inspectorCtx(ctx_opaque: *anyopaque) *Ctx { - const aligned = @as(*align(@alignOf(Ctx)) anyopaque, @alignCast(ctx_opaque)); - return @as(*Ctx, @ptrCast(aligned)); - } - - fn inspectorMsg(allocator: std.mem.Allocator, ctx: *Ctx, msg: []const u8) !void { - // inject sessionID in cdp msg - const tpl = "{s},\"sessionId\":\"{s}\"}}"; - const msg_open = msg[0 .. msg.len - 1]; // remove closing bracket - const s = try std.fmt.allocPrint( - allocator, - tpl, - .{ msg_open, @tagName(ctx.state.sessionID) }, - ); - - try ctx.send(s); - } - - pub fn onInspectorResp(ctx_opaque: *anyopaque, _: u32, msg: []const u8) void { + fn inspectorResponse(ctx: *anyopaque, _: u32, msg: []const u8) void { if (std.log.defaultLogEnabled(.debug)) { // msg should be {"id":,... - const id_end = std.mem.indexOfScalar(u8, msg, ',') orelse unreachable; + std.debug.assert(std.mem.startsWith(u8, msg, "{\"id\":")); + + const id_end = std.mem.indexOfScalar(u8, msg, ',') orelse { + log.warn("invalid inspector response message: {s}", .{msg}); + return; + }; + const id = msg[6..id_end]; std.log.scoped(.cdp).debug("Res (inspector) > id {s}", .{id}); } - const ctx = inspectorCtx(ctx_opaque); - inspectorMsg(ctx.alloc(), ctx, msg) catch unreachable; + sendInspectorMessage(@alignCast(@ptrCast(ctx)), msg); } - pub fn onInspectorNotif(ctx_opaque: *anyopaque, msg: []const u8) void { + fn inspectorEvent(ctx: *anyopaque, msg: []const u8) void { if (std.log.defaultLogEnabled(.debug)) { // msg should be {"method":,... - const method_end = std.mem.indexOfScalar(u8, msg, ',') orelse unreachable; + std.debug.assert(std.mem.startsWith(u8, msg, "{\"method\":")); + const method_end = std.mem.indexOfScalar(u8, msg, ',') orelse { + log.warn("invalid inspector event message: {s}", .{msg}); + return; + }; const method = msg[10..method_end]; std.log.scoped(.cdp).debug("Event (inspector) > method {s}", .{method}); } - const ctx = inspectorCtx(ctx_opaque); - inspectorMsg(ctx.alloc(), ctx, msg) catch unreachable; + + sendInspectorMessage(@alignCast(@ptrCast(ctx)), msg); + } + + fn sendInspectorMessage(self: *Server, msg: []const u8) void { + var client = &(self.client orelse return); + + var scrap = &self.scrap; + scrap.clearRetainingCapacity(); + + const field = ",\"sessionId\":"; + const sessionID = @tagName(self.state.sessionID); + + // + 2 for the quotes around the session + const message_len = msg.len + sessionID.len + 2 + field.len; + + scrap.ensureTotalCapacity(self.allocator, message_len) catch |err| { + log.err("Failed to expand inspector buffer: {}", .{err}); + return; + }; + + // -1 because we dont' want the closing brace '}' + scrap.appendSliceAssumeCapacity(msg[0 .. msg.len - 1]); + scrap.appendSliceAssumeCapacity(field); + scrap.appendAssumeCapacity('"'); + scrap.appendSliceAssumeCapacity(sessionID); + scrap.appendSliceAssumeCapacity("\"}"); + std.debug.assert(scrap.items.len == message_len); + + // TODO: Remove when we clean up ownership of messages between + // CDD and sending. + const owned = self.allocator.dupe(u8, scrap.items) catch return; + + client.sendWS(owned) catch |err| { + log.debug("Failed to write inspector message to client: {}", .{err}); + // don't bother trying to cleanly close the client, if sendWS fails + // we're almost certainly in a non-recoverable state (i.e. OOM) + self.queueClose(client.socket); + }; } }; @@ -383,47 +381,560 @@ pub const Ctx = struct { // NOTE: to allow concurrent send we create each time a dedicated context // (with its own completion), allocated on the heap. // After the send (on the sendCbk) the dedicated context will be destroy -// and the msg slice will be free. +// and the data slice will be free. const Send = struct { - ctx: *Ctx, - msg: []const u8, - completion: Completion = undefined, + // The full data to be sent + data: []const u8, - fn init(ctx: *Ctx, msg: []const u8) !*Send { - const sd = try ctx.alloc().create(Send); - sd.* = .{ .ctx = ctx, .msg = msg }; - return sd; - } + // Whether or not to free the data once the message is sent (or fails to) + // send. This is false in cases where the message is comptime known + free_when_done: bool, + + // Any unsent data we have. Initially unsent == data, but as part of the + // message is succesfully sent, unsent becomes a smaller and smaller slice + // of data + unsent: []const u8, + + server: *Server, + completion: Completion, + socket: posix.socket_t, fn deinit(self: *Send) void { - self.ctx.alloc().free(self.msg); - self.ctx.alloc().destroy(self); + var server = self.server; + if (self.free_when_done) { + server.allocator.free(self.data); + } + server.send_pool.destroy(self); } - fn asyncCbk(self: *Send, _: *Completion, result: SendError!usize) void { - _ = result catch |err| { + fn queueSend(self: *Send) void { + self.server.loop.io.send( + *Send, + self, + sendCallback, + &self.completion, + self.socket, + self.unsent, + ); + } + + fn sendCallback( + self: *Send, + _: *Completion, + result: SendError!usize, + ) void { + const sent = result catch |err| { log.err("send error: {any}", .{err}); - self.ctx.err = err; + if (self.server.client) |*client| { + self.server.queueClose(client.socket); + } + self.deinit(); + return; }; - self.deinit(); + + if (sent == self.unsent.len) { + self.deinit(); + return; + } + + // partial send, re-queue a send for whatever we have left + self.unsent = self.unsent[sent..]; + self.queueSend(); } }; -pub fn sendAsync(ctx: *Ctx, msg: []const u8) !void { - const sd = try Send.init(ctx, msg); - ctx.loop.io.send(*Send, sd, Send.asyncCbk, &sd.completion, ctx.conn_socket, sd.msg); +// Client +// -------- + +// This is a generic only so that it can be unit tested. Normally, S == Server +// and when we send a message, we'll use server.send(...) to send via the server's +// IO loop. During tests, we can inject a simple mock to record (and then verify) +// the send message +fn Client(comptime S: type) type { + const EMPTY_PONG = [_]u8{ 138, 0 }; + + // CLOSE, 2 length, code + const CLOSE_NORMAL = [_]u8{ 136, 2, 3, 232 }; // code: 1000 + const CLOSE_TOO_BIG = [_]u8{ 136, 2, 3, 241 }; // 1009 + const CLOSE_PROTOCOL_ERROR = [_]u8{ 136, 2, 3, 234 }; //code: 1002 + // This should be removed once we support continuation frames + const CLOSE_UNSUPPORTED_ERROR = [_]u8{ 136, 2, 3, 235 }; //code: 1003 + const CLOSE_TIMEOUT = [_]u8{ 136, 2, 15, 160 }; // code: 4000 + + return struct { + // The client is initially serving HTTP requests but, under normal circumstances + // should eventually be upgraded to a websocket connections + mode: Mode, + server: S, + socket: posix.socket_t, + last_active: std.time.Instant, + + // the start of the message in our read_buf + read_pos: usize = 0, + // up to where do we have data in our read_buf + read_len: usize = 0, + read_buf: [MAX_MESSAGE_SIZE]u8 = undefined, + + const Mode = enum { + http, + websocket, + }; + + const Self = @This(); + + fn init(socket: posix.socket_t, server: S) Self { + return .{ + .mode = .http, + .socket = socket, + .server = server, + .last_active = now(), + }; + } + + fn close(self: *Self, close_code: CloseCode) void { + if (self.mode == .websocket) { + switch (close_code) { + .timeout => self.send(&CLOSE_TIMEOUT, false) catch {}, + } + } + self.server.queueClose(self.socket); + } + + fn readBuf(self: *Self) []u8 { + // We might have read a partial http or websocket message. + // Subsequent reads must read from where we left off. + std.debug.assert(self.read_pos < self.read_buf.len); + return self.read_buf[self.read_len..]; + } + + fn processData(self: *Self, len: usize) !bool { + const end = self.read_len + len; + std.debug.assert(end >= self.read_pos); + + self.last_active = now(); + const data = self.read_buf[self.read_pos..end]; + + switch (self.mode) { + .http => { + try self.processHTTPRequest(data); + return true; + }, + .websocket => return self.processWebsocketMessage(data), + } + } + + fn processHTTPRequest(self: *Self, request: []u8) HTTPError!void { + // We should never get pipelined HTTP requests + std.debug.assert(self.read_pos == 0); + + errdefer self.server.queueClose(self.socket); + + // we're only expecting [body-less] GET requests. + if (std.mem.endsWith(u8, request, "\r\n\r\n") == false) { + if (request.len > MAX_HTTP_REQUEST_SIZE) { + self.writeHTTPErrorResponse(413, "Request too large"); + return error.RequestTooLarge; + } + // we need more data, put any more data here + self.read_len = request.len; + return; + } + + self.handleHTTPRequest(request) catch |err| { + switch (err) { + error.NotFound => self.writeHTTPErrorResponse(404, "Not found"), + error.InvalidRequest => self.writeHTTPErrorResponse(400, "Invalid request"), + error.InvalidProtocol => self.writeHTTPErrorResponse(400, "Invalid HTTP protocol"), + error.MissingHeaders => self.writeHTTPErrorResponse(400, "Missing required header"), + error.InvalidUpgradeHeader => self.writeHTTPErrorResponse(400, "Unsupported upgrade type"), + error.InvalidVersionHeader => self.writeHTTPErrorResponse(400, "Invalid websocket version"), + error.InvalidConnectionHeader => self.writeHTTPErrorResponse(400, "Invalid connection header"), + else => { + log.err("error processing HTTP request: {}", .{err}); + self.writeHTTPErrorResponse(500, "Internal Server Error"); + }, + } + return err; + }; + + // the next incoming data can go to the front of our buffer + self.read_len = 0; + } + + fn handleHTTPRequest(self: *Self, request: []u8) !void { + if (request.len < 18) { + // 18 is [generously] the smallest acceptable HTTP request + return error.InvalidRequest; + } + + if (std.mem.eql(u8, request[0..4], "GET ") == false) { + return error.NotFound; + } + + const url_end = std.mem.indexOfScalarPos(u8, request, 4, ' ') orelse { + return error.InvalidRequest; + }; + + const url = request[4..url_end]; + + if (std.mem.eql(u8, url, "/")) { + return self.upgradeConnection(request); + } + + if (std.mem.eql(u8, url, "/json/version")) { + return self.send(self.server.json_version_response, false); + } + + return error.NotFound; + } + + fn upgradeConnection(self: *Self, request: []u8) !void { + // our caller already confirmed that we have a trailing \r\n\r\n + const request_line_end = std.mem.indexOfScalar(u8, request, '\r') orelse unreachable; + const request_line = request[0..request_line_end]; + + if (!std.ascii.endsWithIgnoreCase(request_line, "http/1.1")) { + return error.InvalidProtocol; + } + + // we need to extract the sec-websocket-key value + var key: []const u8 = ""; + + // we need to make sure that we got all the necessary headers + values + var required_headers: u8 = 0; + + // can't std.mem.split because it forces the iterated value to be const + // (we could @constCast...) + + var buf = request[request_line_end + 2 ..]; + + while (buf.len > 4) { + const index = std.mem.indexOfScalar(u8, buf, '\r') orelse unreachable; + const separator = std.mem.indexOfScalar(u8, buf[0..index], ':') orelse return error.InvalidRequest; + + const name = std.mem.trim(u8, toLower(buf[0..separator]), &std.ascii.whitespace); + const value = std.mem.trim(u8, buf[(separator + 1)..index], &std.ascii.whitespace); + + if (std.mem.eql(u8, name, "upgrade")) { + if (!std.ascii.eqlIgnoreCase("websocket", value)) { + return error.InvalidUpgradeHeader; + } + required_headers |= 1; + } else if (std.mem.eql(u8, name, "sec-websocket-version")) { + if (value.len != 2 or value[0] != '1' or value[1] != '3') { + return error.InvalidVersionHeader; + } + required_headers |= 2; + } else if (std.mem.eql(u8, name, "connection")) { + // find if connection header has upgrade in it, example header: + // Connection: keep-alive, Upgrade + if (std.ascii.indexOfIgnoreCase(value, "upgrade") == null) { + return error.InvalidConnectionHeader; + } + required_headers |= 4; + } else if (std.mem.eql(u8, name, "sec-websocket-key")) { + key = value; + required_headers |= 8; + } + + const next = index + 2; + buf = buf[next..]; + } + + if (required_headers != 15) { + return error.MissingHeaders; + } + + // our caller has already made sure this request ended in \r\n\r\n + // so it isn't something we need to check again + + const response = blk: { + // Response to an ugprade request is always this, with + // the Sec-Websocket-Accept value a spacial sha1 hash of the + // request "sec-websocket-version" and a magic value. + + const template = + "HTTP/1.1 101 Switching Protocols\r\n" ++ + "Upgrade: websocket\r\n" ++ + "Connection: upgrade\r\n" ++ + "Sec-Websocket-Accept: 0000000000000000000000000000\r\n\r\n"; + + // The response will be sent via the IO Loop and thus has to have its + // own lifetime. + const res = try self.server.allocator.dupe(u8, template); + errdefer self.server.allocator.free(res); + + // magic response + const key_pos = res.len - 32; + var h: [20]u8 = undefined; + var hasher = std.crypto.hash.Sha1.init(.{}); + hasher.update(key); + // websocket spec always used this value + hasher.update("258EAFA5-E914-47DA-95CA-C5AB0DC85B11"); + hasher.final(&h); + + _ = std.base64.standard.Encoder.encode(res[key_pos .. key_pos + 28], h[0..]); + + break :blk res; + }; + + self.mode = .websocket; + return self.send(response, true); + } + + fn processWebsocketMessage(self: *Self, data: []u8) !bool { + errdefer self.server.queueClose(self.socket); + + var reader = Reader{ .data = data }; + while (true) { + const msg = reader.next() catch |err| { + switch (err) { + error.TooLarge => self.send(&CLOSE_TOO_BIG, false) catch {}, + error.NotMasked => self.send(&CLOSE_PROTOCOL_ERROR, false) catch {}, + error.ReservedFlags => self.send(&CLOSE_PROTOCOL_ERROR, false) catch {}, + error.InvalidMessageType => self.send(&CLOSE_PROTOCOL_ERROR, false) catch {}, + error.ContinuationNotSupported => self.send(&CLOSE_UNSUPPORTED_ERROR, false) catch {}, + } + return err; + } orelse break; + + switch (msg.type) { + .pong => {}, + .ping => try self.sendPong(msg.data), + .close => { + self.send(&CLOSE_NORMAL, false) catch {}; + self.server.queueClose(self.socket); + return false; + }, + .text, .binary => try self.server.handleCDP(msg.data), + } + } + + const incomplete = reader.data; + self.read_len = incomplete.len; + if (incomplete.len > 0) { + // we have part of the data for the next message + + // can't use @memset because incomplete is a slice of read_buf, + // so they could overlap + + // TODO: this can be skipped if we know that the next message will + // fit into whatever reamining space we have. + std.mem.copyForwards(u8, self.read_buf[0..incomplete.len], incomplete); + } + return true; + } + + fn sendPong(self: *Self, data: []const u8) !void { + if (data.len == 0) { + return self.send(&EMPTY_PONG, false); + } + + return self.sendFrame(data, .pong); + } + + fn sendWS(self: *Self, data: []const u8) !void { + std.debug.assert(data.len < 4294967296); + + // for now, we're going to dupe this before we send it, so we don't need + // to keep this around. + defer self.server.allocator.free(data); + return self.sendFrame(data, .text); + } + + // We need to append the websocket header to data. If our IO loop supported + // a writev call, this would be simple. + // For now, we'll just have to dupe data into a larger message. + // TODO: Remove this awful allocation (probably by passing a websocket-aware + // Writer into CDP) + fn sendFrame(self: *Self, data: []const u8, op_code: OpCode) !void { + if (comptime builtin.is_test == false) { + std.debug.assert(self.mode == .websocket); + } + + // 10 is the max possible length of our header + // server->client has no mask, so it's 4 fewer bytes than the reader overhead + var header_buf: [10]u8 = undefined; + + const header: []const u8 = blk: { + const len = data.len; + header_buf[0] = 128 | @intFromEnum(op_code); // fin | opcode + + if (len <= 125) { + header_buf[1] = @intCast(len); + break :blk header_buf[0..2]; + } + + if (len < 65536) { + header_buf[1] = 126; + header_buf[2] = @intCast((len >> 8) & 0xFF); + header_buf[3] = @intCast(len & 0xFF); + break :blk header_buf[0..4]; + } + + header_buf[1] = 127; + header_buf[2] = 0; + header_buf[3] = 0; + header_buf[4] = 0; + header_buf[5] = 0; + header_buf[6] = @intCast((len >> 24) & 0xFF); + header_buf[7] = @intCast((len >> 16) & 0xFF); + header_buf[8] = @intCast((len >> 8) & 0xFF); + header_buf[9] = @intCast(len & 0xFF); + break :blk header_buf[0..10]; + }; + + const allocator = self.server.allocator; + const full = try allocator.alloc(u8, header.len + data.len); + errdefer allocator.free(full); + @memcpy(full[0..header.len], header); + @memcpy(full[header.len..], data); + try self.send(full, true); + } + + fn writeHTTPErrorResponse(self: *Self, comptime status: u16, comptime body: []const u8) void { + const response = std.fmt.comptimePrint( + "HTTP/1.1 {d} \r\nConnection: Close\r\nContent-Length: {d}\r\n\r\n{s}", + .{ status, body.len, body }, + ); + + // we're going to close this connection anyways, swallowing any + // error seems safe + self.send(response, false) catch {}; + } + + fn send(self: *Self, data: []const u8, free_when_done: bool) !void { + return self.server.queueSend(self.socket, data, free_when_done); + } + }; } -// Listener and handler -// -------------------- +// WebSocket message reader. Given websocket message, acts as an iterator that +// can return zero or more Messages. When next returns null, any incomplete +// message will remain in reader.data +const Reader = struct { + data: []u8, -pub fn handle( - alloc: std.mem.Allocator, - loop: *jsruntime.Loop, - server_socket: std.posix.socket_t, - stream: ?*Stream, + fn next(self: *Reader) !?Message { + var data = self.data; + if (data.len < 2) { + return null; + } + + const byte1 = data[0]; + + if (byte1 & 112 != 0) { + return error.ReservedFlags; + } + + var message_type: Message.Type = undefined; + switch (byte1 & 15) { + 0 => return error.ContinuationNotSupported, // TODO?? + 1 => message_type = .text, + 2 => message_type = .binary, + 8 => message_type = .close, + 9 => message_type = .ping, + 10 => message_type = .pong, + else => return error.InvalidMessageType, + } + + if (byte1 & 128 != 128) { + // TODO?? + return error.ContinuationNotSupported; + } + + const byte2 = data[1]; + if (byte2 & 128 != 128) { + // client -> server messages _must_ be masked + return error.NotMasked; + } + + const length_of_len: usize = switch (byte2 & 127) { + 126 => 2, + 127 => 8, + else => 0, + }; + + if (data.len < length_of_len + 2) { + // we definitely don't have enough data yet + return null; + } + + const message_len = switch (length_of_len) { + 2 => @as(u16, @intCast(data[3])) | @as(u16, @intCast(data[2])) << 8, + 8 => @as(u64, @intCast(data[9])) | @as(u64, @intCast(data[8])) << 8 | @as(u64, @intCast(data[7])) << 16 | @as(u64, @intCast(data[6])) << 24 | @as(u64, @intCast(data[5])) << 32 | @as(u64, @intCast(data[4])) << 40 | @as(u64, @intCast(data[3])) << 48 | @as(u64, @intCast(data[2])) << 56, + else => data[1] & 127, + } + length_of_len + 2 + 4; // +2 for header prefix, +4 for mask + + if (message_len > MAX_MESSAGE_SIZE) { + return error.TooLarge; + } + + if (data.len < message_len) { + return null; + } + + // prefix + length_of_len + mask + const header_len = 2 + length_of_len + 4; + + const payload = data[header_len..message_len]; + mask(data[header_len - 4 .. header_len], payload); + + self.data = data[message_len..]; + return .{ + .type = message_type, + .data = payload, + }; + } +}; + +const Message = struct { + type: Type, + data: []const u8, + + const Type = enum { + text, + binary, + close, + ping, + pong, + }; +}; + +// These are the only websocket types that we're currently sending +const OpCode = enum(u8) { + text = 128 | 1, + close = 128 | 8, + pong = 128 | 10, +}; + +// "private-use" close codes must be from 4000-49999 +const CloseCode = enum { + timeout, +}; + +pub fn run( + allocator: Allocator, + address: net.Address, timeout: u64, -) anyerror!void { + loop: *jsruntime.Loop, +) !void { + // create socket + const flags = posix.SOCK.STREAM | posix.SOCK.CLOEXEC | posix.SOCK.NONBLOCK; + const listener = try posix.socket(address.any.family, flags, posix.IPPROTO.TCP); + defer posix.close(listener); + + try posix.setsockopt(listener, posix.SOL.SOCKET, posix.SO.REUSEADDR, &std.mem.toBytes(@as(c_int, 1))); + // TODO: Broken on darwin + // https://github.com/ziglang/zig/issues/17260 (fixed in Zig 0.14) + // if (@hasDecl(os.TCP, "NODELAY")) { + // try os.setsockopt(socket.sockfd.?, os.IPPROTO.TCP, os.TCP.NODELAY, &std.mem.toBytes(@as(c_int, 1))); + // } + try posix.setsockopt(listener, posix.IPPROTO.TCP, 1, &std.mem.toBytes(@as(c_int, 1))); + + // bind & listen + try posix.bind(listener, &address.any, address.getOsSockLen()); + try posix.listen(listener, 1); // create v8 vm const vm = jsruntime.VM.init(); @@ -431,43 +942,30 @@ pub fn handle( // browser var browser: Browser = undefined; - try Browser.init(&browser, alloc, loop, vm); + try Browser.init(&browser, allocator, loop, vm); defer browser.deinit(); - // create buffers - var read_buf: [BufReadSize]u8 = undefined; - var buf: [MaxSize]u8 = undefined; - var msg_buf = MsgBuffer{ .buf = &buf }; + const json_version_response = try buildJSONVersionResponse(allocator, address); - // create I/O completions - var accept_completion: Completion = undefined; - var conn_completion: Completion = undefined; - var timeout_completion: Completion = undefined; - - // create I/O contexts and callbacks - // for accepting connections and receving messages - var ctx = Ctx{ + var server = Server{ .loop = loop, - .stream = stream, - .browser = &browser, - .sessionNew = true, - .read_buf = &read_buf, - .msg_buf = &msg_buf, - .accept_socket = server_socket, .timeout = timeout, - .accept_completion = &accept_completion, - .conn_completion = &conn_completion, - .timeout_completion = &timeout_completion, + .browser = &browser, + .listener = listener, + .allocator = allocator, + .conn_completion = undefined, + .close_completion = undefined, + .accept_completion = undefined, + .timeout_completion = undefined, + .json_version_response = json_version_response, + .send_pool = std.heap.MemoryPool(Send).init(allocator), }; - try browser.session.initInspector( - &ctx, - Ctx.onInspectorResp, - Ctx.onInspectorNotif, - ); + defer server.deinit(); - // accepting connection asynchronously on internal server - log.info("accepting new conn...", .{}); - loop.io.accept(*Ctx, &ctx, Ctx.acceptCbk, ctx.acceptCompletion(), ctx.accept_socket); + try browser.session.initInspector(&server, Server.inspectorResponse, Server.inspectorEvent); + + // accept an connection + server.queueAccept(); // infinite loop on I/O events, either: // - cmd from incoming connection on server socket @@ -476,58 +974,566 @@ pub fn handle( try loop.io.run_for_ns(10 * std.time.ns_per_ms); if (loop.cbk_error) { log.err("JS error", .{}); - // if (try try_catch.exception(alloc, js_env.*)) |msg| { - // std.debug.print("\n\rUncaught {s}\n\r", .{msg}); - // alloc.free(msg); - // } - // loop.cbk_error = false; - } - if (ctx.err) |err| { - if (err != error.NoError) log.err("Server error: {any}", .{err}); - break; } } } -fn setSockOpt(fd: std.posix.socket_t, level: i32, option: u32, value: c_int) !void { - try std.posix.setsockopt(fd, level, option, &std.mem.toBytes(value)); +// Utils +// -------- + +fn buildJSONVersionResponse( + allocator: Allocator, + address: net.Address, +) ![]const u8 { + const body_format = "{{\"webSocketDebuggerUrl\": \"ws://{}/\"}}"; + const body_len = std.fmt.count(body_format, .{address}); + + const response_format = + "HTTP/1.1 200 OK\r\n" ++ + "Content-Length: {d}\r\n" ++ + "Content-Type: application/json; charset=UTF-8\r\n\r\n" ++ + body_format; + return try std.fmt.allocPrint(allocator, response_format, .{ body_len, address }); } -fn isUnixSocket(addr: std.net.Address) bool { - return addr.any.family == std.posix.AF.UNIX; +fn now() std.time.Instant { + // can only fail on platforms we don't support + return std.time.Instant.now() catch unreachable; } -pub fn listen(address: std.net.Address) !std.posix.socket_t { - const isunixsock = isUnixSocket(address); - - // create socket - const flags = std.posix.SOCK.STREAM | std.posix.SOCK.CLOEXEC | std.posix.SOCK.NONBLOCK; - const proto = if (isunixsock) @as(u32, 0) else std.posix.IPPROTO.TCP; - const sockfd = try std.posix.socket(address.any.family, flags, proto); - errdefer std.posix.close(sockfd); - - // socket options - // - // REUSEPORT can't be set on unix socket anymore. - // see https://github.com/torvalds/linux/commit/5b0af621c3f6ef9261cf6067812f2fd9943acb4b - if (@hasDecl(std.posix.SO, "REUSEPORT") and !isunixsock) { - try setSockOpt(sockfd, std.posix.SOL.SOCKET, std.posix.SO.REUSEPORT, 1); +// In-place string lowercase +fn toLower(str: []u8) []u8 { + for (str, 0..) |c, i| { + str[i] = std.ascii.toLower(c); } - try setSockOpt(sockfd, std.posix.SOL.SOCKET, std.posix.SO.REUSEADDR, 1); - if (!isUnixSocket(address)) { - if (builtin.target.os.tag == .linux) { // posix.TCP not available on MacOS - // WARNING: disable Nagle's alogrithm to avoid latency issues - try setSockOpt(sockfd, std.posix.IPPROTO.TCP, std.posix.TCP.NODELAY, 1); + return str; +} + +// Zig is in a weird backend transition right now. Need to determine if +// SIMD is even available. +const backend_supports_vectors = switch (builtin.zig_backend) { + .stage2_llvm, .stage2_c => true, + else => false, +}; + +// Websocket messages from client->server are masked using a 4 byte XOR mask +fn mask(m: []const u8, payload: []u8) void { + var data = payload; + + if (!comptime backend_supports_vectors) return simpleMask(m, data); + + const vector_size = std.simd.suggestVectorLength(u8) orelse @sizeOf(usize); + if (data.len >= vector_size) { + const mask_vector = std.simd.repeat(vector_size, @as(@Vector(4, u8), m[0..4].*)); + while (data.len >= vector_size) { + const slice = data[0..vector_size]; + const masked_data_slice: @Vector(vector_size, u8) = slice.*; + slice.* = masked_data_slice ^ mask_vector; + data = data[vector_size..]; } } - - // bind & listen - var socklen = address.getOsSockLen(); - try std.posix.bind(sockfd, &address.any, socklen); - const kernel_backlog = 1; // default value is 128. Here we just want 1 connection - try std.posix.listen(sockfd, kernel_backlog); - var listen_address: std.net.Address = undefined; - try std.posix.getsockname(sockfd, &listen_address.any, &socklen); - - return sockfd; + simpleMask(m, data); } + +// Used when SIMD isn't available, or for any remaining part of the message +// which is too small to effectively use SIMD. +fn simpleMask(m: []const u8, payload: []u8) void { + for (payload, 0..) |b, i| { + payload[i] = b ^ m[i & 3]; + } +} + +const testing = std.testing; +test "server: buildJSONVersionResponse" { + const address = try net.Address.parseIp4("127.0.0.1", 9001); + const res = try buildJSONVersionResponse(testing.allocator, address); + defer testing.allocator.free(res); + + try testing.expectEqualStrings("HTTP/1.1 200 OK\r\n" ++ + "Content-Length: 48\r\n" ++ + "Content-Type: application/json; charset=UTF-8\r\n\r\n" ++ + "{\"webSocketDebuggerUrl\": \"ws://127.0.0.1:9001/\"}", res); +} + +test "Client: http invalid handshake" { + try assertHTTPError( + error.InvalidRequest, + 400, + "Invalid request", + "\r\n\r\n", + ); + + try assertHTTPError( + error.NotFound, + 404, + "Not found", + "GET /over/9000 HTTP/1.1\r\n\r\n", + ); + + try assertHTTPError( + error.NotFound, + 404, + "Not found", + "POST / HTTP/1.1\r\n\r\n", + ); + + try assertHTTPError( + error.InvalidProtocol, + 400, + "Invalid HTTP protocol", + "GET / HTTP/1.0\r\n\r\n", + ); + + try assertHTTPError( + error.MissingHeaders, + 400, + "Missing required header", + "GET / HTTP/1.1\r\n\r\n", + ); + + try assertHTTPError( + error.MissingHeaders, + 400, + "Missing required header", + "GET / HTTP/1.1\r\nConnection: upgrade\r\n\r\n", + ); + + try assertHTTPError( + error.MissingHeaders, + 400, + "Missing required header", + "GET / HTTP/1.1\r\nConnection: upgrade\r\nUpgrade: websocket\r\n\r\n", + ); + + try assertHTTPError( + error.MissingHeaders, + 400, + "Missing required header", + "GET / HTTP/1.1\r\nConnection: upgrade\r\nUpgrade: websocket\r\nsec-websocket-version:13\r\n\r\n", + ); +} + +test "Client: http valid handshake" { + var ms = MockServer{}; + defer ms.deinit(); + + var client = Client(*MockServer).init(0, &ms); + + const request = + "GET / HTTP/1.1\r\n" ++ + "Connection: upgrade\r\n" ++ + "Upgrade: websocket\r\n" ++ + "sec-websocket-version:13\r\n" ++ + "sec-websocket-key: this is my key\r\n" ++ + "Custom: Header-Value\r\n\r\n"; + + @memcpy(client.read_buf[0..request.len], request); + try testing.expectEqual(true, try client.processData(request.len)); + + try testing.expectEqual(.websocket, client.mode); + try testing.expectEqualStrings( + "HTTP/1.1 101 Switching Protocols\r\n" ++ + "Upgrade: websocket\r\n" ++ + "Connection: upgrade\r\n" ++ + "Sec-Websocket-Accept: flzHu2DevQ2dSCSVqKSii5e9C2o=\r\n\r\n", + ms.sent.items[0], + ); +} + +test "Client: http get json version" { + var ms = MockServer{}; + defer ms.deinit(); + + var client = Client(*MockServer).init(0, &ms); + + const request = "GET /json/version HTTP/1.1\r\n\r\n"; + + @memcpy(client.read_buf[0..request.len], request); + try testing.expectEqual(true, try client.processData(request.len)); + + try testing.expectEqual(.http, client.mode); + + // this is the hardcoded string in our MockServer + try testing.expectEqualStrings("the json version response", ms.sent.items[0]); +} + +test "Client: write websocket message" { + var ms = MockServer{}; + defer ms.deinit(); + + var client = Client(*MockServer).init(0, &ms); + + const cases = [_]struct { expected: []const u8, message: []const u8 }{ + .{ .expected = &.{ 129, 0 }, .message = "" }, + .{ .expected = [_]u8{ 129, 12 } ++ "hello world!", .message = "hello world!" }, + .{ .expected = [_]u8{ 129, 126, 0, 130 } ++ ("A" ** 130), .message = "A" ** 130 }, + }; + + for (cases) |c| { + ms.sent.clearRetainingCapacity(); + try client.sendWS(try testing.allocator.dupe(u8, c.message)); + try testing.expectEqual(1, ms.sent.items.len); + try testing.expectEqualSlices(u8, c.expected, ms.sent.items[0]); + } +} + +test "Client: read invalid websocket message" { + try assertWebSocketError( + error.InvalidMessageType, + 1002, + "", + &.{ 131, 1 }, // 128 (fin) | 3 where 3 isn't a valid type + ); + + try assertWebSocketError( + error.ContinuationNotSupported, + 1003, + "", + &.{ 128, 1 }, // 128 (fin) | 0 where 0 is a continuation frame + ); + + try assertWebSocketError( + error.ContinuationNotSupported, + 1003, + "", + &.{ 1, 1 }, // 0 (non-fin) | 1 non-fin (contination) not supported + ); + + for ([_]u8{ 16, 32, 64 }) |rsv| { + // none of the reserve flags should be set + try assertWebSocketError( + error.ReservedFlags, + 1002, + "", + &.{ rsv, 0 }, + ); + + // as a bitmask + try assertWebSocketError( + error.ReservedFlags, + 1002, + "", + &.{ rsv + 4, 0 }, + ); + } + + try assertWebSocketError( + error.NotMasked, + 1002, + "", + &.{ 129, 127 }, // client->server messages must be masked + ); + + try assertWebSocketError( + error.TooLarge, + 1009, + "", + &.{ 129, 255, 0, 0, 0, 0, 0, 4, 0, 1 }, // 1024 * 256 + 1 + ); +} + +test "Client: ping reply" { + try assertWebSocketMessage( + // fin | pong, len + &.{ 138, 0 }, + + // fin | ping, masked | len, 4-byte mask + &.{ 137, 128, 0, 0, 0, 0 }, + ); + + try assertWebSocketMessage( + // fin | pong, len, payload + &.{ 138, 5, 100, 96, 97, 109, 104 }, + + // fin | ping, masked | len, 4-byte mask, 5 byte payload + &.{ 137, 133, 0, 5, 7, 10, 100, 101, 102, 103, 104 }, + ); +} + +test "Client: close message" { + try assertWebSocketMessage( + // fin | close, len, close code (normal) + &.{ 136, 2, 3, 232 }, + + // fin | close, masked | len, 4-byte mask + &.{ 136, 128, 0, 0, 0, 0 }, + ); +} + +// Testing both HTTP and websocket messages broken up across multiple reads. +// We need to fuzz HTTP messages differently than websocket. HTTP are strictly +// req -> res with no pipelining. So there should only be 1 message at a time. +// So we can only "fuzz" on a per-message basis. +// But for websocket, we can fuzz _all_ the messages together. +test "Client: fuzz" { + var prng = std.rand.DefaultPrng.init(blk: { + var seed: u64 = undefined; + try std.posix.getrandom(std.mem.asBytes(&seed)); + break :blk seed; + }); + const random = prng.random(); + + const allocator = testing.allocator; + var websocket_messages: std.ArrayListUnmanaged(u8) = .{}; + defer websocket_messages.deinit(allocator); + + // ping with no payload + try websocket_messages.appendSlice( + allocator, + &.{ 137, 128, 0, 0, 0, 0 }, + ); + + // // 10 byte text message with a 0,0,0,0 mask + try websocket_messages.appendSlice( + allocator, + &.{ 129, 138, 0, 0, 0, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10 }, + ); + + // ping with a payload + try websocket_messages.appendSlice( + allocator, + &.{ 137, 133, 0, 5, 7, 10, 100, 101, 102, 103, 104 }, + ); + + // pong with no payload (noop in the server) + try websocket_messages.appendSlice( + allocator, + &.{ 138, 128, 10, 10, 10, 10 }, + ); + + // 687 long message, with a mask + try websocket_messages.appendSlice( + allocator, + [_]u8{ 129, 254, 2, 175, 1, 2, 3, 4 } ++ "A" ** 687, + ); + + // close + try websocket_messages.appendSlice( + allocator, + &.{ 136, 130, 200, 103, 34, 22, 0, 1 }, + ); + + const SendRandom = struct { + fn send(c: anytype, r: std.Random, data: []const u8) !void { + var buf = data; + while (buf.len > 0) { + const to_send = r.intRangeAtMost(usize, 1, buf.len); + @memcpy(c.readBuf()[0..to_send], buf[0..to_send]); + if (try c.processData(to_send) == false) { + return; + } + buf = buf[to_send..]; + } + } + }; + + for (0..1) |_| { + var ms = MockServer{}; + defer ms.deinit(); + + var client = Client(*MockServer).init(0, &ms); + + try SendRandom.send(&client, random, "GET /json/version HTTP/1.1\r\nContent-Length: 0\r\n\r\n"); + try SendRandom.send(&client, random, "GET / HTTP/1.1\r\n" ++ + "Connection: upgrade\r\n" ++ + "Upgrade: websocket\r\n" ++ + "sec-websocket-version:13\r\n" ++ + "sec-websocket-key: 1234aa93\r\n" ++ + "Custom: Header-Value\r\n\r\n"); + + // fuzz over all websocket messages + try SendRandom.send(&client, random, websocket_messages.items); + + try testing.expectEqual(5, ms.sent.items.len); + + try testing.expectEqualStrings( + "the json version response", + ms.sent.items[0], + ); + + try testing.expectEqualStrings( + "HTTP/1.1 101 Switching Protocols\r\n" ++ + "Upgrade: websocket\r\n" ++ + "Connection: upgrade\r\n" ++ + "Sec-Websocket-Accept: KnOKWrrjHS0nGFmtfmYFQoPIGKQ=\r\n\r\n", + ms.sent.items[1], + ); + + try testing.expectEqualSlices(u8, &.{ 138, 0 }, ms.sent.items[2]); + + try testing.expectEqualSlices( + u8, + &.{ 138, 5, 100, 96, 97, 109, 104 }, + ms.sent.items[3], + ); + + try testing.expectEqualSlices( + u8, + &.{ 136, 2, 3, 232 }, + ms.sent.items[4], + ); + + try testing.expectEqual(2, ms.cdp.items.len); + try testing.expectEqualSlices( + u8, + &.{ 1, 2, 3, 4, 5, 6, 7, 8, 9, 10 }, + ms.cdp.items[0], + ); + + try testing.expectEqualSlices( + u8, + &([_]u8{ 64, 67, 66, 69 } ** 171 ++ [_]u8{ 64, 67, 66 }), + ms.cdp.items[1], + ); + + try testing.expectEqual(true, ms.closed); + } +} + +test "server: mask" { + var buf: [4000]u8 = undefined; + const messages = [_][]const u8{ "1234", "1234" ** 99, "1234" ** 999 }; + for (messages) |message| { + // we need the message to be mutable since mask operates in-place + const payload = buf[0..message.len]; + @memcpy(payload, message); + + mask(&.{ 1, 2, 200, 240 }, payload); + try testing.expectEqual(false, std.mem.eql(u8, payload, message)); + + mask(&.{ 1, 2, 200, 240 }, payload); + try testing.expectEqual(true, std.mem.eql(u8, payload, message)); + } +} + +fn assertHTTPError( + expected_error: HTTPError, + comptime expected_status: u16, + comptime expected_body: []const u8, + input: []const u8, +) !void { + var ms = MockServer{}; + defer ms.deinit(); + + var client = Client(*MockServer).init(0, &ms); + @memcpy(client.read_buf[0..input.len], input); + try testing.expectError(expected_error, client.processData(input.len)); + + const expected_response = std.fmt.comptimePrint( + "HTTP/1.1 {d} \r\nConnection: Close\r\nContent-Length: {d}\r\n\r\n{s}", + .{ expected_status, expected_body.len, expected_body }, + ); + + try testing.expectEqual(1, ms.sent.items.len); + try testing.expectEqualStrings(expected_response, ms.sent.items[0]); +} + +fn assertWebSocketError( + expected_error: WebSocketError, + close_code: u16, + close_payload: []const u8, + input: []const u8, +) !void { + var ms = MockServer{}; + defer ms.deinit(); + + var client = Client(*MockServer).init(0, &ms); + client.mode = .websocket; // force websocket message processing + + @memcpy(client.read_buf[0..input.len], input); + try testing.expectError(expected_error, client.processData(input.len)); + + try testing.expectEqual(1, ms.sent.items.len); + + const actual = ms.sent.items[0]; + + // fin | close opcode + try testing.expectEqual(136, actual[0]); + + // message length (code + payload) + try testing.expectEqual(2 + close_payload.len, actual[1]); + + // close code + try testing.expectEqual(close_code, std.mem.readInt(u16, actual[2..4], .big)); + + // close payload (if any) + try testing.expectEqualStrings(close_payload, actual[4..]); +} + +fn assertWebSocketMessage( + expected: []const u8, + input: []const u8, +) !void { + var ms = MockServer{}; + defer ms.deinit(); + + var client = Client(*MockServer).init(0, &ms); + client.mode = .websocket; // force websocket message processing + + @memcpy(client.read_buf[0..input.len], input); + const more = try client.processData(input.len); + + try testing.expectEqual(1, ms.sent.items.len); + try testing.expectEqualSlices(u8, expected, ms.sent.items[0]); + + // if we sent a close message, then the serve should have been told + // to close the connection + if (expected[0] == 136) { + try testing.expectEqual(true, ms.closed); + try testing.expectEqual(false, more); + } else { + try testing.expectEqual(false, ms.closed); + try testing.expectEqual(true, more); + } +} + +const MockServer = struct { + closed: bool = false, + + // record the messages we sent to the client + sent: std.ArrayListUnmanaged([]const u8) = .{}, + + // record the CDP messages we need to process + cdp: std.ArrayListUnmanaged([]const u8) = .{}, + + allocator: Allocator = testing.allocator, + + json_version_response: []const u8 = "the json version response", + + fn deinit(self: *MockServer) void { + const allocator = self.allocator; + + for (self.sent.items) |msg| { + allocator.free(msg); + } + self.sent.deinit(allocator); + + for (self.cdp.items) |msg| { + allocator.free(msg); + } + self.cdp.deinit(allocator); + } + + fn queueClose(self: *MockServer, _: anytype) void { + self.closed = true; + } + + fn handleCDP(self: *MockServer, message: []const u8) !void { + const owned = try self.allocator.dupe(u8, message); + try self.cdp.append(self.allocator, owned); + } + + fn queueSend( + self: *MockServer, + socket: posix.socket_t, + data: []const u8, + free_when_done: bool, + ) !void { + _ = socket; + const owned = try self.allocator.dupe(u8, data); + try self.sent.append(self.allocator, owned); + if (free_when_done) { + testing.allocator.free(data); + } + } +};