diff --git a/.gitmodules b/.gitmodules index 184dd202..5743ca29 100644 --- a/.gitmodules +++ b/.gitmodules @@ -28,7 +28,3 @@ [submodule "vendor/zig-async-io"] path = vendor/zig-async-io url = https://github.com/lightpanda-io/zig-async-io.git/ -[submodule "vendor/websocket.zig"] - path = vendor/websocket.zig - url = https://github.com/lightpanda-io/websocket.zig.git/ - branch = lightpanda diff --git a/build.zig b/build.zig index 44e99222..adf4e26f 100644 --- a/build.zig +++ b/build.zig @@ -189,11 +189,6 @@ fn common( .root_source_file = b.path("vendor/tls.zig/src/main.zig"), }); step.root_module.addImport("tls", tlsmod); - - const wsmod = b.addModule("websocket", .{ - .root_source_file = b.path("vendor/websocket.zig/src/websocket.zig"), - }); - step.root_module.addImport("websocket", wsmod); } fn moduleNetSurf(b: *std.Build, target: std.Build.ResolvedTarget) !*std.Build.Module { diff --git a/src/cdp/runtime.zig b/src/cdp/runtime.zig index 44c1a907..054d5a78 100644 --- a/src/cdp/runtime.zig +++ b/src/cdp/runtime.zig @@ -131,12 +131,12 @@ fn sendInspector( const buf = try alloc.alloc(u8, msg.json.len + 1); defer alloc.free(buf); _ = std.mem.replace(u8, msg.json, "\"awaitPromise\":true", "\"awaitPromise\":false", buf); - ctx.sendInspector(buf); + try ctx.sendInspector(buf); return ""; } } - ctx.sendInspector(msg.json); + try ctx.sendInspector(msg.json); if (msg.id == null) return ""; diff --git a/src/handler.zig b/src/handler.zig deleted file mode 100644 index 0decb3f7..00000000 --- a/src/handler.zig +++ /dev/null @@ -1,95 +0,0 @@ -// Copyright (C) 2023-2024 Lightpanda (Selecy SAS) -// -// Francis Bouvier -// Pierre Tachoire -// -// This program is free software: you can redistribute it and/or modify -// it under the terms of the GNU Affero General Public License as -// published by the Free Software Foundation, either version 3 of the -// License, or (at your option) any later version. 
-// -// This program is distributed in the hope that it will be useful, -// but WITHOUT ANY WARRANTY; without even the implied warranty of -// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the -// GNU Affero General Public License for more details. -// -// You should have received a copy of the GNU Affero General Public License -// along with this program. If not, see . - -const std = @import("std"); - -const ws = @import("websocket"); -const Msg = @import("msg.zig").Msg; - -const log = std.log.scoped(.handler); - -pub const Stream = struct { - addr: std.net.Address, - socket: std.posix.socket_t = undefined, - - ws_host: []const u8, - ws_port: u16, - ws_conn: *ws.Conn = undefined, - - fn connectCDP(self: *Stream) !void { - const flags: u32 = std.posix.SOCK.STREAM; - const proto = blk: { - if (self.addr.any.family == std.posix.AF.UNIX) break :blk @as(u32, 0); - break :blk std.posix.IPPROTO.TCP; - }; - const socket = try std.posix.socket(self.addr.any.family, flags, proto); - - try std.posix.connect( - socket, - &self.addr.any, - self.addr.getOsSockLen(), - ); - log.debug("connected to Stream server", .{}); - self.socket = socket; - } - - fn closeCDP(self: *const Stream) void { - const close_msg: []const u8 = .{ 5, 0, 0, 0 } ++ "close"; - self.recv(close_msg) catch |err| { - log.err("stream close error: {any}", .{err}); - }; - std.posix.close(self.socket); - } - - fn start(self: *Stream, ws_conn: *ws.Conn) !void { - try self.connectCDP(); - self.ws_conn = ws_conn; - } - - pub fn recv(self: *const Stream, data: []const u8) !void { - var pos: usize = 0; - while (pos < data.len) { - const len = try std.posix.write(self.socket, data[pos..]); - pos += len; - } - } - - pub fn send(self: *const Stream, data: []const u8) !void { - return self.ws_conn.write(data); - } -}; - -pub const Handler = struct { - stream: *Stream, - - pub fn init(_: ws.Handshake, ws_conn: *ws.Conn, stream: *Stream) !Handler { - try stream.start(ws_conn); - return .{ .stream = stream }; - } 
- - pub fn close(self: *Handler) void { - self.stream.closeCDP(); - } - - pub fn clientMessage(self: *Handler, data: []const u8) !void { - var header: [4]u8 = undefined; - Msg.setSize(data.len, &header); - try self.stream.recv(&header); - try self.stream.recv(data); - } -}; diff --git a/src/main.zig b/src/main.zig index e4da1df2..c5c04996 100644 --- a/src/main.zig +++ b/src/main.zig @@ -20,12 +20,9 @@ const std = @import("std"); const builtin = @import("builtin"); const jsruntime = @import("jsruntime"); -const websocket = @import("websocket"); const Browser = @import("browser/browser.zig").Browser; const server = @import("server.zig"); -const handler = @import("handler.zig"); -const MaxSize = @import("msg.zig").MaxSize; const parser = @import("netsurf"); const apiweb = @import("apiweb.zig"); @@ -86,11 +83,9 @@ const CliMode = union(CliModeTag) { const Server = struct { execname: []const u8 = undefined, args: *std.process.ArgIterator = undefined, - addr: std.net.Address = undefined, host: []const u8 = Host, port: u16 = Port, timeout: u8 = Timeout, - tcp: bool = false, // undocumented TCP mode // default options const Host = "127.0.0.1"; @@ -160,10 +155,6 @@ const CliMode = union(CliModeTag) { return printUsageExit(execname, 1); } } - if (std.mem.eql(u8, "--tcp", opt)) { - _server.tcp = true; - continue; - } // unknown option if (std.mem.startsWith(u8, opt, "--")) { @@ -186,10 +177,6 @@ const CliMode = union(CliModeTag) { if (default_mode == .server) { // server mode - _server.addr = std.net.Address.parseIp4(_server.host, _server.port) catch |err| { - log.err("address (host:port) {any}\n", .{err}); - return printUsageExit(execname, 1); - }; _server.execname = execname; _server.args = args; return CliMode{ .server = _server }; @@ -247,65 +234,19 @@ pub fn main() !void { switch (cli_mode) { .server => |opts| { - - // Stream server - const addr = blk: { - if (opts.tcp) { - break :blk opts.addr; - } else { - const unix_path = "/tmp/lightpanda"; - 
std.fs.deleteFileAbsolute(unix_path) catch {}; // file could not exists - break :blk try std.net.Address.initUnix(unix_path); - } - }; - const socket = server.listen(addr) catch |err| { - log.err("Server listen error: {any}\n", .{err}); + const address = std.net.Address.parseIp4(opts.host, opts.port) catch |err| { + log.err("address (host:port) {any}\n", .{err}); return printUsageExit(opts.execname, 1); }; - defer std.posix.close(socket); - log.debug("Server opts: listening internally on {any}...", .{addr}); - const timeout = std.time.ns_per_s * @as(u64, opts.timeout); - - // loop var loop = try jsruntime.Loop.init(alloc); defer loop.deinit(); - // TCP server mode - if (opts.tcp) { - return server.handle(alloc, &loop, socket, null, timeout); - } - - // start stream server in separate thread - var stream = handler.Stream{ - .ws_host = opts.host, - .ws_port = opts.port, - .addr = addr, + const timeout = std.time.ns_per_s * @as(u64, opts.timeout); + server.run(alloc, address, timeout, &loop) catch |err| { + log.err("Server error", .{}); + return err; }; - const cdp_thread = try std.Thread.spawn( - .{ .allocator = alloc }, - server.handle, - .{ alloc, &loop, socket, &stream, timeout }, - ); - - // Websocket server - var ws = try websocket.Server(handler.Handler).init(alloc, .{ - .port = opts.port, - .address = opts.host, - .max_message_size = MaxSize + 14, // overhead websocket - .max_conn = 1, - .handshake = .{ - .timeout = 3, - .max_size = 1024, - // since we aren't using hanshake.headers - // we can set this to 0 to save a few bytes. 
- .max_headers = 0, - }, - }); - defer ws.deinit(); - - try ws.listen(&stream); - cdp_thread.join(); }, .fetch => |opts| { diff --git a/src/main_tests.zig b/src/main_tests.zig index 3544cc6a..88acd1a5 100644 --- a/src/main_tests.zig +++ b/src/main_tests.zig @@ -314,9 +314,6 @@ const kb = 1024; const ms = std.time.ns_per_ms; test { - const msgTest = @import("msg.zig"); - std.testing.refAllDecls(msgTest); - const dumpTest = @import("browser/dump.zig"); std.testing.refAllDecls(dumpTest); @@ -340,6 +337,7 @@ test { std.testing.refAllDecls(@import("generate.zig")); std.testing.refAllDecls(@import("cdp/msg.zig")); + std.testing.refAllDecls(@import("server.zig")); } fn testJSRuntime(alloc: std.mem.Allocator) !void { diff --git a/src/msg.zig b/src/msg.zig deleted file mode 100644 index 13b7a62e..00000000 --- a/src/msg.zig +++ /dev/null @@ -1,166 +0,0 @@ -// Copyright (C) 2023-2024 Lightpanda (Selecy SAS) -// -// Francis Bouvier -// Pierre Tachoire -// -// This program is free software: you can redistribute it and/or modify -// it under the terms of the GNU Affero General Public License as -// published by the Free Software Foundation, either version 3 of the -// License, or (at your option) any later version. -// -// This program is distributed in the hope that it will be useful, -// but WITHOUT ANY WARRANTY; without even the implied warranty of -// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the -// GNU Affero General Public License for more details. -// -// You should have received a copy of the GNU Affero General Public License -// along with this program. If not, see . - -const std = @import("std"); - -pub const HeaderSize = 4; -pub const MsgSize = 256 * 1204; // 256KB -// NOTE: Theorically we could go up to 4GB with a 4 bytes binary encoding -// but we prefer to put a lower hard limit for obvious memory size reasons. 
- -pub const MaxSize = HeaderSize + MsgSize; - -pub const Msg = struct { - pub fn getSize(data: []const u8) usize { - return std.mem.readInt(u32, data[0..HeaderSize], .little); - } - - pub fn setSize(len: usize, header: *[4]u8) void { - std.mem.writeInt(u32, header, @intCast(len), .little); - } -}; - -/// Buffer returns messages from a raw text read stream, -/// with the message size being encoded on the 2 first bytes (little endian) -/// It handles both: -/// - combined messages in one read -/// - single message in several reads (multipart) -/// It's safe (and a good practice) to reuse the same Buffer -/// on several reads of the same stream. -pub const Buffer = struct { - buf: []u8, - size: usize = 0, - pos: usize = 0, - - fn isFinished(self: *const Buffer) bool { - return self.pos >= self.size; - } - - fn isEmpty(self: *const Buffer) bool { - return self.size == 0 and self.pos == 0; - } - - fn reset(self: *Buffer) void { - self.size = 0; - self.pos = 0; - } - - // read input - pub fn read(self: *Buffer, input: []const u8) !struct { - msg: []const u8, - left: []const u8, - } { - var _input = input; // make input writable - - // msg size - var msg_size: usize = undefined; - if (self.isEmpty()) { - // decode msg size header - msg_size = Msg.getSize(_input); - _input = _input[HeaderSize..]; - } else { - msg_size = self.size; - } - - // multipart - const is_multipart = !self.isEmpty() or _input.len < msg_size; - if (is_multipart) { - - // set msg size on empty Buffer - if (self.isEmpty()) { - self.size = msg_size; - } - - // get the new position of the cursor - const new_pos = self.pos + _input.len; - - // check max limit size - if (new_pos > MaxSize) { - return error.MsgTooBig; - } - - // copy the current input into Buffer - // NOTE: we could use @memcpy but it's not Thread-safe (alias problem) - // see https://www.openmymind.net/Zigs-memcpy-copyForwards-and-copyBackwards/ - // Intead we just use std.mem.copyForwards - std.mem.copyForwards(u8, 
self.buf[self.pos..new_pos], _input[0..]); - - // set the new cursor position - self.pos = new_pos; - - // if multipart is not finished, go fetch the next input - if (!self.isFinished()) return error.MsgMultipart; - - // otherwhise multipart is finished, use its buffer as input - _input = self.buf[0..self.pos]; - self.reset(); - } - - // handle several JSON msg in 1 read - return .{ .msg = _input[0..msg_size], .left = _input[msg_size..] }; - } -}; - -test "Buffer" { - const Case = struct { - input: []const u8, - nb: u8, - }; - - const cases = [_]Case{ - // simple - .{ .input = .{ 2, 0, 0, 0 } ++ "ok", .nb = 1 }, - // combined - .{ .input = .{ 2, 0, 0, 0 } ++ "ok" ++ .{ 3, 0, 0, 0 } ++ "foo", .nb = 2 }, - // multipart - .{ .input = .{ 9, 0, 0, 0 } ++ "multi", .nb = 0 }, - .{ .input = "part", .nb = 1 }, - // multipart & combined - .{ .input = .{ 9, 0, 0, 0 } ++ "multi", .nb = 0 }, - .{ .input = "part" ++ .{ 2, 0, 0, 0 } ++ "ok", .nb = 2 }, - // multipart & combined with other multipart - .{ .input = .{ 9, 0, 0, 0 } ++ "multi", .nb = 0 }, - .{ .input = "part" ++ .{ 8, 0, 0, 0 } ++ "co", .nb = 1 }, - .{ .input = "mbined", .nb = 1 }, - // several multipart - .{ .input = .{ 23, 0, 0, 0 } ++ "multi", .nb = 0 }, - .{ .input = "several", .nb = 0 }, - .{ .input = "complex", .nb = 0 }, - .{ .input = "part", .nb = 1 }, - // combined & multipart - .{ .input = .{ 2, 0, 0, 0 } ++ "ok" ++ .{ 9, 0, 0, 0 } ++ "multi", .nb = 1 }, - .{ .input = "part", .nb = 1 }, - }; - - var b: [MaxSize]u8 = undefined; - var buf = Buffer{ .buf = &b }; - - for (cases) |case| { - var nb: u8 = 0; - var input = case.input; - while (input.len > 0) { - const parts = buf.read(input) catch |err| { - if (err == error.MsgMultipart) break; // go to the next case input - return err; - }; - nb += 1; - input = parts.left; - } - try std.testing.expect(nb == case.nb); - } -} diff --git a/src/server.zig b/src/server.zig index eb7c0a18..9fbb26a4 100644 --- a/src/server.zig +++ b/src/server.zig @@ -19,7 +19,10 @@ const 
std = @import("std"); const builtin = @import("builtin"); -const Stream = @import("handler.zig").Stream; +const net = std.net; +const posix = std.posix; + +const Allocator = std.mem.Allocator; const jsruntime = @import("jsruntime"); const Completion = jsruntime.IO.Completion; @@ -30,241 +33,233 @@ const CloseError = jsruntime.IO.CloseError; const CancelError = jsruntime.IO.CancelOneError; const TimeoutError = jsruntime.IO.TimeoutError; -const MsgBuffer = @import("msg.zig").Buffer; -const MaxSize = @import("msg.zig").MaxSize; const Browser = @import("browser/browser.zig").Browser; const cdp = @import("cdp/cdp.zig"); -const NoError = error{NoError}; const IOError = AcceptError || RecvError || SendError || CloseError || TimeoutError || CancelError; -const Error = IOError || std.fmt.ParseIntError || cdp.Error || NoError; +const HTTPError = error{ + OutOfMemory, + RequestTooLarge, + NotFound, + InvalidRequest, + MissingHeaders, + InvalidProtocol, + InvalidUpgradeHeader, + InvalidVersionHeader, + InvalidConnectionHeader, +}; +const WebSocketError = error{ + OutOfMemory, + ReservedFlags, + NotMasked, + TooLarge, + InvalidMessageType, + ContinuationNotSupported, +}; +const Error = IOError || cdp.Error || HTTPError || WebSocketError; const TimeoutCheck = std.time.ns_per_ms * 100; const log = std.log.scoped(.server); -const isLinux = builtin.target.os.tag == .linux; -// I/O Main -// -------- +const MAX_HTTP_REQUEST_SIZE = 2048; -const BufReadSize = 1024; // 1KB -const MaxStdOutSize = 512; // ensure debug msg are not too long +// max message size, +14 for max websocket payload overhead +const MAX_MESSAGE_SIZE = 256 * 1024 + 14; -pub const Ctx = struct { +// For now, cdp does @import("server.zig").Ctx. Could change cdp to use "Server" +// but I rather try to decouple the CDP code from the server, so a quick +// stopgap is fine. 
TODO: Decouple cdp from the server +pub const Ctx = Server; + +const Server = struct { + allocator: Allocator, loop: *jsruntime.Loop, - stream: ?*Stream, // internal fields - accept_socket: std.posix.socket_t, - conn_socket: std.posix.socket_t = undefined, - read_buf: []u8, // only for read operations - msg_buf: *MsgBuffer, - err: ?Error = null, + listener: posix.socket_t, + client: ?Client(*Server) = null, + timeout: u64, + + // a memory poor for our Send objects + send_pool: std.heap.MemoryPool(Send), // I/O fields - accept_completion: *Completion, - conn_completion: *Completion, - timeout_completion: *Completion, - timeout: u64, - last_active: ?std.time.Instant = null, + conn_completion: Completion, + close_completion: Completion, + accept_completion: Completion, + timeout_completion: Completion, + + // used when gluing the session id to the inspector message + scrap: std.ArrayListUnmanaged(u8) = .{}, + + // The response to send on a GET /json/version request + json_version_response: []const u8, // CDP state: cdp.State = undefined, // JS fields browser: *Browser, // TODO: is pointer mandatory here? 
- sessionNew: bool, - // try_catch: jsruntime.TryCatch, // TODO pub fn deinit(self: *Ctx) void { self.state.deinit(); + self.send_pool.deinit(); + self.allocator.free(self.json_version_response); } - // callbacks - // --------- - - fn acceptCbk( - self: *Ctx, - completion: *Completion, - result: AcceptError!std.posix.socket_t, - ) void { - std.debug.assert(completion == self.acceptCompletion()); - - self.conn_socket = result catch |err| { - log.err("accept error: {any}", .{err}); - self.err = err; - return; - }; - log.info("client connected", .{}); - - // set connection timestamp and timeout - self.last_active = std.time.Instant.now() catch |err| { - log.err("accept timestamp error: {any}", .{err}); - return; - }; - self.loop.io.timeout( - *Ctx, + fn queueAccept(self: *Server) void { + log.info("accepting new conn...", .{}); + self.loop.io.accept( + *Server, self, - Ctx.timeoutCbk, - self.timeout_completion, + callbackAccept, + &self.accept_completion, + self.listener, + ); + } + + fn callbackAccept( + self: *Server, + completion: *Completion, + result: AcceptError!posix.socket_t, + ) void { + std.debug.assert(completion == &self.accept_completion); + + const socket = result catch |err| { + log.err("accept error: {any}", .{err}); + self.queueAccept(); + return; + }; + + self.newSession() catch |err| { + log.err("new session error: {any}", .{err}); + self.queueClose(socket); + return; + }; + + log.info("client connected", .{}); + self.client = Client(*Server).init(socket, self); + self.queueRead(); + self.queueTimeout(); + } + + fn queueTimeout(self: *Server) void { + self.loop.io.timeout( + *Server, + self, + callbackTimeout, + &self.timeout_completion, TimeoutCheck, ); - - // receving incomming messages asynchronously - self.loop.io.recv( - *Ctx, - self, - Ctx.readCbk, - self.conn_completion, - self.conn_socket, - self.read_buf, - ); } - fn readCbk(self: *Ctx, completion: *Completion, result: RecvError!usize) void { - std.debug.assert(completion == 
self.conn_completion); + fn callbackTimeout( + self: *Server, + completion: *Completion, + result: TimeoutError!void, + ) void { + std.debug.assert(completion == &self.timeout_completion); - const size = result catch |err| { - if (self.isClosed() and err == error.FileDescriptorInvalid) { - log.debug("read has been canceled", .{}); + const client = &(self.client orelse return); + + if (result) |_| { + if (now().since(client.last_active) > self.timeout) { + // close current connection + log.debug("conn timeout, closing...", .{}); + client.close(.timeout); return; } - log.err("read error: {any}", .{err}); - self.err = err; - return; - }; - - if (size == 0) { - // continue receving incomming messages asynchronously - self.loop.io.recv( - *Ctx, - self, - Ctx.readCbk, - self.conn_completion, - self.conn_socket, - self.read_buf, - ); - return; - } - - // set connection timestamp - self.last_active = std.time.Instant.now() catch |err| { - log.err("read timestamp error: {any}", .{err}); - return; - }; - - // continue receving incomming messages asynchronously - self.loop.io.recv( - *Ctx, - self, - Ctx.readCbk, - self.conn_completion, - self.conn_socket, - self.read_buf, - ); - - // read and execute input - var input: []const u8 = self.read_buf[0..size]; - while (input.len > 0) { - const parts = self.msg_buf.read(input) catch |err| { - if (err == error.MsgMultipart) { - return; - } else { - log.err("msg read error: {any}", .{err}); - return; - } - }; - input = parts.left; - // execute - self.do(parts.msg) catch |err| { - if (err != error.Closed) { - log.err("do error: {any}", .{err}); - log.debug("last msg: {s}", .{parts.msg}); - } - }; - } - } - - fn timeoutCbk(self: *Ctx, completion: *Completion, result: TimeoutError!void) void { - std.debug.assert(completion == self.timeout_completion); - - _ = result catch |err| { + } else |err| { log.err("timeout error: {any}", .{err}); - self.err = err; + } + + // We re-queue this if the timeout hasn't been exceeded or on some + // 
very unlikely IO timeout error. + // AKA: we don't requeue this if the connection timed out and we + // closed the connection.s + self.queueTimeout(); + } + + fn queueRead(self: *Server) void { + if (self.client) |*client| { + self.loop.io.recv( + *Server, + self, + callbackRead, + &self.conn_completion, + client.socket, + client.readBuf(), + ); + } + } + + fn callbackRead( + self: *Server, + completion: *Completion, + result: RecvError!usize, + ) void { + std.debug.assert(completion == &self.conn_completion); + + var client = &(self.client orelse return); + + const size = result catch |err| { + log.err("read error: {any}", .{err}); + self.queueClose(client.socket); return; }; - if (self.isClosed()) { - // conn is already closed, ignore timeout - return; - } - - // check time since last read - const now = std.time.Instant.now() catch |err| { - log.err("timeout timestamp error: {any}", .{err}); + const more = client.processData(size) catch |err| { + std.debug.print("Client Processing Error: {}\n", .{err}); return; }; - if (now.since(self.last_active.?) 
> self.timeout) { - // close current connection - log.debug("conn timeout, closing...", .{}); - self.close(); - return; + // if more == false, the client is disconnecting + if (more) { + self.queueRead(); } + } - // continue checking timeout - self.loop.io.timeout( - *Ctx, + fn queueSend( + self: *Server, + socket: posix.socket_t, + data: []const u8, + free_when_done: bool, + ) !void { + const sd = try self.send_pool.create(); + errdefer self.send_pool.destroy(sd); + + sd.* = .{ + .data = data, + .unsent = data, + .server = self, + .socket = socket, + .completion = undefined, + .free_when_done = free_when_done, + }; + sd.queueSend(); + } + + fn queueClose(self: *Server, socket: posix.socket_t) void { + self.loop.io.close( + *Server, self, - Ctx.timeoutCbk, - self.timeout_completion, - TimeoutCheck, + callbackClose, + &self.close_completion, + socket, ); } - // shortcuts - // --------- - - inline fn isClosed(self: *Ctx) bool { - // last_active is first saved on acceptCbk - return self.last_active == null; - } - - // allocator of the current session - inline fn alloc(self: *Ctx) std.mem.Allocator { - return self.browser.session.alloc; - } - - // JS env of the current session - inline fn env(self: Ctx) jsruntime.Env { - return self.browser.session.env; - } - - inline fn acceptCompletion(self: *Ctx) *Completion { - // NOTE: the logical completion to use here is the accept_completion - // as the pipe_connection can be used simulteanously by a recv I/O operation. 
- // But on MacOS (kqueue) the recv I/O operation on a closed socket leads to a panic - // so we use the pipe_connection to avoid this problem - if (isLinux) return self.accept_completion; - return self.conn_completion; - } - - // actions - // ------- - - fn do(self: *Ctx, cmd: []const u8) anyerror!void { - - // close cmd - if (std.mem.eql(u8, cmd, "close")) { - // close connection - log.info("close cmd, closing conn...", .{}); - self.close(); - return error.Closed; + fn callbackClose(self: *Server, completion: *Completion, _: CloseError!void) void { + std.debug.assert(completion == &self.close_completion); + if (self.client != null) { + self.client = null; } + self.queueAccept(); + } - if (self.sessionNew) self.sessionNew = false; - - const res = cdp.do(self.alloc(), cmd, self) catch |err| { + fn handleCDP(self: *Server, cmd: []const u8) !void { + const res = cdp.do(self.allocator, cmd, self) catch |err| { // cdp end cmd if (err == error.DisposeBrowserContext) { @@ -278,106 +273,106 @@ pub const Ctx = struct { }; // send result - if (!std.mem.eql(u8, res, "")) { + if (res.len != 0) { return self.send(res); } } - pub fn send(self: *Ctx, msg: []const u8) !void { - if (self.stream) |stream| { - // if we have a stream connection, just write on it - defer self.alloc().free(msg); - try stream.send(msg); - } else { - // otherwise write asynchronously on the socket connection - return sendAsync(self, msg); + // called from CDP + pub fn send(self: *Server, data: []const u8) !void { + if (self.client) |*client| { + try client.sendWS(data); } } - fn close(self: *Ctx) void { - - // conn is closed - self.last_active = null; - std.posix.close(self.conn_socket); - log.debug("connection closed", .{}); - - // restart a new browser session in case of re-connect - if (!self.sessionNew) { - self.newSession() catch |err| { - log.err("new session error: {any}", .{err}); - return; - }; - } - - log.info("accepting new conn...", .{}); - - // continue accepting incoming requests - 
self.loop.io.accept( - *Ctx, - self, - Ctx.acceptCbk, - self.acceptCompletion(), - self.accept_socket, - ); - } - - fn newSession(self: *Ctx) !void { - try self.browser.newSession(self.alloc(), self.loop); + fn newSession(self: *Server) !void { + try self.browser.newSession(self.allocator, self.loop); try self.browser.session.initInspector( self, - Ctx.onInspectorResp, - Ctx.onInspectorNotif, + inspectorResponse, + inspectorEvent, ); - self.sessionNew = true; } - // inspector - // --------- + // // inspector + // // --------- - pub fn sendInspector(self: *Ctx, msg: []const u8) void { - if (self.env().getInspector()) |inspector| { - inspector.send(self.env(), msg); - } else @panic("Inspector has not been set"); + // called by cdp + pub fn sendInspector(self: *Server, msg: []const u8) !void { + const env = self.browser.session.env; + if (env.getInspector()) |inspector| { + inspector.send(env, msg); + return; + } + return error.InspectNotSet; } - inline fn inspectorCtx(ctx_opaque: *anyopaque) *Ctx { - const aligned = @as(*align(@alignOf(Ctx)) anyopaque, @alignCast(ctx_opaque)); - return @as(*Ctx, @ptrCast(aligned)); - } - - fn inspectorMsg(allocator: std.mem.Allocator, ctx: *Ctx, msg: []const u8) !void { - // inject sessionID in cdp msg - const tpl = "{s},\"sessionId\":\"{s}\"}}"; - const msg_open = msg[0 .. msg.len - 1]; // remove closing bracket - const s = try std.fmt.allocPrint( - allocator, - tpl, - .{ msg_open, @tagName(ctx.state.sessionID) }, - ); - - try ctx.send(s); - } - - pub fn onInspectorResp(ctx_opaque: *anyopaque, _: u32, msg: []const u8) void { + fn inspectorResponse(ctx: *anyopaque, _: u32, msg: []const u8) void { if (std.log.defaultLogEnabled(.debug)) { // msg should be {"id":,... 
- const id_end = std.mem.indexOfScalar(u8, msg, ',') orelse unreachable; + std.debug.assert(std.mem.startsWith(u8, msg, "{\"id\":")); + + const id_end = std.mem.indexOfScalar(u8, msg, ',') orelse { + log.warn("invalid inspector response message: {s}", .{msg}); + return; + }; + const id = msg[6..id_end]; std.log.scoped(.cdp).debug("Res (inspector) > id {s}", .{id}); } - const ctx = inspectorCtx(ctx_opaque); - inspectorMsg(ctx.alloc(), ctx, msg) catch unreachable; + sendInspectorMessage(@alignCast(@ptrCast(ctx)), msg); } - pub fn onInspectorNotif(ctx_opaque: *anyopaque, msg: []const u8) void { + fn inspectorEvent(ctx: *anyopaque, msg: []const u8) void { if (std.log.defaultLogEnabled(.debug)) { // msg should be {"method":,... - const method_end = std.mem.indexOfScalar(u8, msg, ',') orelse unreachable; + std.debug.assert(std.mem.startsWith(u8, msg, "{\"method\":")); + const method_end = std.mem.indexOfScalar(u8, msg, ',') orelse { + log.warn("invalid inspector event message: {s}", .{msg}); + return; + }; const method = msg[10..method_end]; std.log.scoped(.cdp).debug("Event (inspector) > method {s}", .{method}); } - const ctx = inspectorCtx(ctx_opaque); - inspectorMsg(ctx.alloc(), ctx, msg) catch unreachable; + + sendInspectorMessage(@alignCast(@ptrCast(ctx)), msg); + } + + fn sendInspectorMessage(self: *Server, msg: []const u8) void { + var client = &(self.client orelse return); + + var scrap = &self.scrap; + scrap.clearRetainingCapacity(); + + const field = ",\"sessionId\":"; + const sessionID = @tagName(self.state.sessionID); + + // + 2 for the quotes around the session + const message_len = msg.len + sessionID.len + 2 + field.len; + + scrap.ensureTotalCapacity(self.allocator, message_len) catch |err| { + log.err("Failed to expand inspector buffer: {}", .{err}); + return; + }; + + // -1 because we dont' want the closing brace '}' + scrap.appendSliceAssumeCapacity(msg[0 .. 
msg.len - 1]); + scrap.appendSliceAssumeCapacity(field); + scrap.appendAssumeCapacity('"'); + scrap.appendSliceAssumeCapacity(sessionID); + scrap.appendSliceAssumeCapacity("\"}"); + std.debug.assert(scrap.items.len == message_len); + + // TODO: Remove when we clean up ownership of messages between + // CDD and sending. + const owned = self.allocator.dupe(u8, scrap.items) catch return; + + client.sendWS(owned) catch |err| { + log.debug("Failed to write inspector message to client: {}", .{err}); + // don't bother trying to cleanly close the client, if sendWS fails + // we're almost certainly in a non-recoverable state (i.e. OOM) + self.queueClose(client.socket); + }; } }; @@ -387,47 +382,568 @@ pub const Ctx = struct { // NOTE: to allow concurrent send we create each time a dedicated context // (with its own completion), allocated on the heap. // After the send (on the sendCbk) the dedicated context will be destroy -// and the msg slice will be free. +// and the data slice will be free. const Send = struct { - ctx: *Ctx, - msg: []const u8, - completion: Completion = undefined, + // The full data to be sent + data: []const u8, - fn init(ctx: *Ctx, msg: []const u8) !*Send { - const sd = try ctx.alloc().create(Send); - sd.* = .{ .ctx = ctx, .msg = msg }; - return sd; - } + // Whether or not to free the data once the message is sent (or fails to) + // send. This is false in cases where the message is comptime known + free_when_done: bool, + + // Any unsent data we have. 
Initially unsent == data, but as part of the + // message is succesfully sent, unsent becomes a smaller and smaller slice + // of data + unsent: []const u8, + + server: *Server, + completion: Completion, + socket: posix.socket_t, fn deinit(self: *Send) void { - self.ctx.alloc().free(self.msg); - self.ctx.alloc().destroy(self); + var server = self.server; + if (self.free_when_done) { + server.allocator.free(self.data); + } + server.send_pool.destroy(self); } - fn asyncCbk(self: *Send, _: *Completion, result: SendError!usize) void { - _ = result catch |err| { + fn queueSend(self: *Send) void { + self.server.loop.io.send( + *Send, + self, + sendCallback, + &self.completion, + self.socket, + self.unsent, + ); + } + + fn sendCallback( + self: *Send, + _: *Completion, + result: SendError!usize, + ) void { + const sent = result catch |err| { log.err("send error: {any}", .{err}); - self.ctx.err = err; + if (self.server.client) |*client| { + self.server.queueClose(client.socket); + } + self.deinit(); + return; }; - self.deinit(); + + if (sent == self.unsent.len) { + self.deinit(); + return; + } + + // partial send, re-queue a send for whatever we have left + self.unsent = self.unsent[sent..]; + self.queueSend(); } }; -pub fn sendAsync(ctx: *Ctx, msg: []const u8) !void { - const sd = try Send.init(ctx, msg); - ctx.loop.io.send(*Send, sd, Send.asyncCbk, &sd.completion, ctx.conn_socket, sd.msg); +// Client +// -------- + +// This is a generic only so that it can be unit tested. Normally, S == Server +// and when we send a message, we'll use server.send(...) to send via the server's +// IO loop. 
During tests, we can inject a simple mock to record (and then verify) +// the send message +fn Client(comptime S: type) type { + const EMPTY_PONG = [_]u8{ 138, 0 }; + + // CLOSE, 2 length, code + const CLOSE_NORMAL = [_]u8{ 136, 2, 3, 232 }; // code: 1000 + const CLOSE_TOO_BIG = [_]u8{ 136, 2, 3, 241 }; // 1009 + const CLOSE_PROTOCOL_ERROR = [_]u8{ 136, 2, 3, 234 }; //code: 1002 + // This should be removed once we support continuation frames + const CLOSE_UNSUPPORTED_ERROR = [_]u8{ 136, 2, 3, 235 }; //code: 1003 + const CLOSE_TIMEOUT = [_]u8{ 136, 2, 15, 160 }; // code: 4000 + + return struct { + // The client is initially serving HTTP requests but, under normal circumstances + // should eventually be upgraded to a websocket connections + mode: Mode, + server: S, + socket: posix.socket_t, + last_active: std.time.Instant, + + // the start of the message in our read_buf + read_pos: usize = 0, + // up to where do we have data in our read_buf + read_len: usize = 0, + read_buf: [MAX_MESSAGE_SIZE]u8 = undefined, + + const Mode = enum { + http, + websocket, + }; + + const Self = @This(); + + fn init(socket: posix.socket_t, server: S) Self { + return .{ + .mode = .http, + .socket = socket, + .server = server, + .last_active = now(), + }; + } + + fn close(self: *Self, close_code: CloseCode) void { + if (self.mode == .websocket) { + switch (close_code) { + .timeout => self.send(&CLOSE_TIMEOUT, false) catch {}, + } + } + self.server.queueClose(self.socket); + } + + fn readBuf(self: *Self) []u8 { + // We might have read a partial http or websocket message. + // Subsequent reads must read from where we left off. 
+ std.debug.assert(self.read_pos < self.read_buf.len); + return self.read_buf[self.read_len..]; + } + + fn processData(self: *Self, len: usize) !bool { + const end = self.read_len + len; + std.debug.assert(end >= self.read_pos); + + self.last_active = now(); + const data = self.read_buf[self.read_pos..end]; + + switch (self.mode) { + .http => { + try self.processHTTPRequest(data); + return true; + }, + .websocket => return self.processWebsocketMessage(data), + } + } + + fn processHTTPRequest(self: *Self, request: []u8) HTTPError!void { + // We should never get pipelined HTTP requests + std.debug.assert(self.read_pos == 0); + + errdefer self.server.queueClose(self.socket); + + // we're only expecting [body-less] GET requests. + if (std.mem.endsWith(u8, request, "\r\n\r\n") == false) { + if (request.len > MAX_HTTP_REQUEST_SIZE) { + self.writeHTTPErrorResponse(413, "Request too large"); + return error.RequestTooLarge; + } + // we need more data, put any more data here + self.read_len = request.len; + return; + } + + self.handleHTTPRequest(request) catch |err| { + switch (err) { + error.NotFound => self.writeHTTPErrorResponse(404, "Not found"), + error.InvalidRequest => self.writeHTTPErrorResponse(400, "Invalid request"), + error.InvalidProtocol => self.writeHTTPErrorResponse(400, "Invalid HTTP protocol"), + error.MissingHeaders => self.writeHTTPErrorResponse(400, "Missing required header"), + error.InvalidUpgradeHeader => self.writeHTTPErrorResponse(400, "Unsupported upgrade type"), + error.InvalidVersionHeader => self.writeHTTPErrorResponse(400, "Invalid websocket version"), + error.InvalidConnectionHeader => self.writeHTTPErrorResponse(400, "Invalid connection header"), + else => { + log.err("error processing HTTP request: {}", .{err}); + self.writeHTTPErrorResponse(500, "Internal Server Error"); + }, + } + return err; + }; + + // the next incoming data can go to the front of our buffer + self.read_len = 0; + } + + fn handleHTTPRequest(self: *Self, request: []u8) 
!void { + if (request.len < 18) { + // 18 is [generously] the smallest acceptable HTTP request + return error.InvalidRequest; + } + + if (std.mem.eql(u8, request[0..4], "GET ") == false) { + return error.NotFound; + } + + const url_end = std.mem.indexOfScalarPos(u8, request, 4, ' ') orelse { + return error.InvalidRequest; + }; + + const url = request[4..url_end]; + + if (std.mem.eql(u8, url, "/")) { + return self.upgradeConnection(request); + } + + if (std.mem.eql(u8, url, "/json/version")) { + return self.send(self.server.json_version_response, false); + } + + return error.NotFound; + } + + fn upgradeConnection(self: *Self, request: []u8) !void { + // our caller already confirmed that we have a trailing \r\n\r\n + const request_line_end = std.mem.indexOfScalar(u8, request, '\r') orelse unreachable; + const request_line = request[0..request_line_end]; + + if (!std.ascii.endsWithIgnoreCase(request_line, "http/1.1")) { + return error.InvalidProtocol; + } + + // we need to extract the sec-websocket-key value + var key: []const u8 = ""; + + // we need to make sure that we got all the necessary headers + values + var required_headers: u8 = 0; + + // can't std.mem.split because it forces the iterated value to be const + // (we could @constCast...) 
+ + var buf = request[request_line_end + 2 ..]; + + while (buf.len > 4) { + const index = std.mem.indexOfScalar(u8, buf, '\r') orelse unreachable; + const separator = std.mem.indexOfScalar(u8, buf[0..index], ':') orelse return error.InvalidRequest; + + const name = std.mem.trim(u8, toLower(buf[0..separator]), &std.ascii.whitespace); + const value = std.mem.trim(u8, buf[(separator + 1)..index], &std.ascii.whitespace); + + if (std.mem.eql(u8, name, "upgrade")) { + if (!std.ascii.eqlIgnoreCase("websocket", value)) { + return error.InvalidUpgradeHeader; + } + required_headers |= 1; + } else if (std.mem.eql(u8, name, "sec-websocket-version")) { + if (value.len != 2 or value[0] != '1' or value[1] != '3') { + return error.InvalidVersionHeader; + } + required_headers |= 2; + } else if (std.mem.eql(u8, name, "connection")) { + // find if connection header has upgrade in it, example header: + // Connection: keep-alive, Upgrade + if (std.ascii.indexOfIgnoreCase(value, "upgrade") == null) { + return error.InvalidConnectionHeader; + } + required_headers |= 4; + } else if (std.mem.eql(u8, name, "sec-websocket-key")) { + key = value; + required_headers |= 8; + } + + const next = index + 2; + buf = buf[next..]; + } + + if (required_headers != 15) { + return error.MissingHeaders; + } + + // our caller has already made sure this request ended in \r\n\r\n + // so it isn't something we need to check again + + const response = blk: { + // Response to an upgrade request is always this, with + // the Sec-Websocket-Accept value a special sha1 hash of the + // request "sec-websocket-key" and a magic value. + + const template = + "HTTP/1.1 101 Switching Protocols\r\n" ++ + "Upgrade: websocket\r\n" ++ + "Connection: upgrade\r\n" ++ + "Sec-Websocket-Accept: 0000000000000000000000000000\r\n\r\n"; + + // The response will be sent via the IO Loop and thus has to have its + // own lifetime.
+ const res = try self.server.allocator.dupe(u8, template); + errdefer self.server.allocator.free(res); + + // magic response + const key_pos = res.len - 32; + var h: [20]u8 = undefined; + var hasher = std.crypto.hash.Sha1.init(.{}); + hasher.update(key); + // websocket spec always used this value + hasher.update("258EAFA5-E914-47DA-95CA-C5AB0DC85B11"); + hasher.final(&h); + + _ = std.base64.standard.Encoder.encode(res[key_pos .. key_pos + 28], h[0..]); + + break :blk res; + }; + + self.mode = .websocket; + return self.send(response, true); + } + + fn processWebsocketMessage(self: *Self, data: []u8) !bool { + errdefer self.server.queueClose(self.socket); + + var reader = Reader{ .data = data }; + while (true) { + const msg = reader.next() catch |err| { + switch (err) { + error.TooLarge => self.send(&CLOSE_TOO_BIG, false) catch {}, + error.NotMasked => self.send(&CLOSE_PROTOCOL_ERROR, false) catch {}, + error.ReservedFlags => self.send(&CLOSE_PROTOCOL_ERROR, false) catch {}, + error.InvalidMessageType => self.send(&CLOSE_PROTOCOL_ERROR, false) catch {}, + error.ContinuationNotSupported => self.send(&CLOSE_UNSUPPORTED_ERROR, false) catch {}, + } + return err; + } orelse break; + + switch (msg.type) { + .pong => {}, + .ping => try self.sendPong(msg.data), + .close => { + self.send(&CLOSE_NORMAL, false) catch {}; + self.server.queueClose(self.socket); + return false; + }, + .text, .binary => try self.server.handleCDP(msg.data), + } + } + + const incomplete = reader.data; + self.read_len = incomplete.len; + if (incomplete.len > 0) { + // we have part of the data for the next message + + // can't use @memcpy because incomplete is a slice of read_buf, + // so they could overlap + + // TODO: this can be skipped if we know that the next message will + // fit into whatever remaining space we have.
+ std.mem.copyForwards(u8, self.read_buf[0..incomplete.len], incomplete); + } + return true; + } + + fn sendPong(self: *Self, data: []const u8) !void { + if (data.len == 0) { + return self.send(&EMPTY_PONG, false); + } + + return self.sendFrame(data, .pong); + } + + fn sendWS(self: *Self, data: []const u8) !void { + std.debug.assert(data.len < 4294967296); + + // for now, we're going to dupe this before we send it, so we don't need + // to keep this around. + defer self.server.allocator.free(data); + return self.sendFrame(data, .text); + } + + // We need to append the websocket header to data. If our IO loop supported + // a writev call, this would be simple. + // For now, we'll just have to dupe data into a larger message. + // TODO: Remove this awful allocation (probably by passing a websocket-aware + // Writer into CDP) + fn sendFrame(self: *Self, data: []const u8, op_code: OpCode) !void { + if (comptime builtin.is_test == false) { + std.debug.assert(self.mode == .websocket); + } + + // 10 is the max possible length of our header + // server->client has no mask, so it's 4 fewer bytes than the reader overhead + var header_buf: [10]u8 = undefined; + + const header: []const u8 = blk: { + const len = data.len; + header_buf[0] = 128 | @intFromEnum(op_code); // fin | opcode + + if (len <= 125) { + header_buf[1] = @intCast(len); + break :blk header_buf[0..2]; + } + + if (len < 65536) { + header_buf[1] = 126; + header_buf[2] = @intCast((len >> 8) & 0xFF); + header_buf[3] = @intCast(len & 0xFF); + break :blk header_buf[0..4]; + } + + header_buf[1] = 127; + header_buf[2] = 0; + header_buf[3] = 0; + header_buf[4] = 0; + header_buf[5] = 0; + header_buf[6] = @intCast((len >> 24) & 0xFF); + header_buf[7] = @intCast((len >> 16) & 0xFF); + header_buf[8] = @intCast((len >> 8) & 0xFF); + header_buf[9] = @intCast(len & 0xFF); + break :blk header_buf[0..10]; + }; + + const allocator = self.server.allocator; + const full = try allocator.alloc(u8, header.len + data.len); + errdefer 
allocator.free(full); + @memcpy(full[0..header.len], header); + @memcpy(full[header.len..], data); + try self.send(full, true); + } + + fn writeHTTPErrorResponse(self: *Self, comptime status: u16, comptime body: []const u8) void { + const response = std.fmt.comptimePrint( + "HTTP/1.1 {d} \r\nConnection: Close\r\nContent-Length: {d}\r\n\r\n{s}", + .{ status, body.len, body }, + ); + + // we're going to close this connection anyways, swallowing any + // error seems safe + self.send(response, false) catch {}; + } + + fn send(self: *Self, data: []const u8, free_when_done: bool) !void { + return self.server.queueSend(self.socket, data, free_when_done); + } + }; } -// Listener and handler -// -------------------- +// WebSocket message reader. Given websocket message, acts as an iterator that +// can return zero or more Messages. When next returns null, any incomplete +// message will remain in reader.data +const Reader = struct { + data: []u8, -pub fn handle( - alloc: std.mem.Allocator, - loop: *jsruntime.Loop, - server_socket: std.posix.socket_t, - stream: ?*Stream, + fn next(self: *Reader) !?Message { + var data = self.data; + if (data.len < 2) { + return null; + } + + const byte1 = data[0]; + + if (byte1 & 112 != 0) { + return error.ReservedFlags; + } + + var message_type: Message.Type = undefined; + switch (byte1 & 15) { + 0 => return error.ContinuationNotSupported, // TODO?? + 1 => message_type = .text, + 2 => message_type = .binary, + 8 => message_type = .close, + 9 => message_type = .ping, + 10 => message_type = .pong, + else => return error.InvalidMessageType, + } + + if (byte1 & 128 != 128) { + // TODO?? 
+ return error.ContinuationNotSupported; + } + + const byte2 = data[1]; + if (byte2 & 128 != 128) { + // client -> server messages _must_ be masked + return error.NotMasked; + } + + const length_of_len: usize = switch (byte2 & 127) { + 126 => 2, + 127 => 8, + else => 0, + }; + + if (data.len < length_of_len + 2) { + // we definitely don't have enough data yet + return null; + } + + const message_len = switch (length_of_len) { + 2 => @as(u16, @intCast(data[3])) | @as(u16, @intCast(data[2])) << 8, + 8 => @as(u64, @intCast(data[9])) | @as(u64, @intCast(data[8])) << 8 | @as(u64, @intCast(data[7])) << 16 | @as(u64, @intCast(data[6])) << 24 | @as(u64, @intCast(data[5])) << 32 | @as(u64, @intCast(data[4])) << 40 | @as(u64, @intCast(data[3])) << 48 | @as(u64, @intCast(data[2])) << 56, + else => data[1] & 127, + } + length_of_len + 2 + 4; // +2 for header prefix, +4 for mask + + if (message_len > MAX_MESSAGE_SIZE) { + return error.TooLarge; + } + + if (data.len < message_len) { + return null; + } + + // prefix + length_of_len + mask + const header_len = 2 + length_of_len + 4; + + const payload = data[header_len..message_len]; + mask(data[header_len - 4 .. header_len], payload); + + self.data = data[message_len..]; + return .{ + .type = message_type, + .data = payload, + }; + } +}; + +const Message = struct { + type: Type, + data: []const u8, + + const Type = enum { + text, + binary, + close, + ping, + pong, + }; +}; + +// These are the only websocket types that we're currently sending +const OpCode = enum(u8) { + text = 128 | 1, + close = 128 | 8, + pong = 128 | 10, +}; + +// "private-use" close codes must be from 4000-4999 +const CloseCode = enum { + timeout, +}; + +pub fn run( + allocator: Allocator, + address: net.Address, timeout: u64, -) anyerror!void { + loop: *jsruntime.Loop, +) !void { + if (comptime builtin.is_test) { + // There's a bunch of code that won't compile in a test build (because + // it relies on a global root.Types).
So we fight the compiler and make + // sure it doesn't include any of that code. Hopefully one day we can + // remove all this. + return; + } + + // create socket + const flags = posix.SOCK.STREAM | posix.SOCK.CLOEXEC | posix.SOCK.NONBLOCK; + const listener = try posix.socket(address.any.family, flags, posix.IPPROTO.TCP); + defer posix.close(listener); + + try posix.setsockopt(listener, posix.SOL.SOCKET, posix.SO.REUSEADDR, &std.mem.toBytes(@as(c_int, 1))); + // TODO: Broken on darwin + // https://github.com/ziglang/zig/issues/17260 (fixed in Zig 0.14) + // if (@hasDecl(os.TCP, "NODELAY")) { + // try os.setsockopt(socket.sockfd.?, os.IPPROTO.TCP, os.TCP.NODELAY, &std.mem.toBytes(@as(c_int, 1))); + // } + try posix.setsockopt(listener, posix.IPPROTO.TCP, 1, &std.mem.toBytes(@as(c_int, 1))); + + // bind & listen + try posix.bind(listener, &address.any, address.getOsSockLen()); + try posix.listen(listener, 1); // create v8 vm const vm = jsruntime.VM.init(); @@ -435,46 +951,31 @@ pub fn handle( // browser var browser: Browser = undefined; - try Browser.init(&browser, alloc, loop, vm); + try Browser.init(&browser, allocator, loop, vm); defer browser.deinit(); - // create buffers - var read_buf: [BufReadSize]u8 = undefined; - var buf: [MaxSize]u8 = undefined; - var msg_buf = MsgBuffer{ .buf = &buf }; + const json_version_response = try buildJSONVersionResponse(allocator, address); - // create I/O completions - var accept_completion: Completion = undefined; - var conn_completion: Completion = undefined; - var timeout_completion: Completion = undefined; - - // create I/O contexts and callbacks - // for accepting connections and receving messages - var ctx = Ctx{ + var server = Server{ .loop = loop, - .stream = stream, - .browser = &browser, - .sessionNew = true, - .read_buf = &read_buf, - .msg_buf = &msg_buf, - .accept_socket = server_socket, .timeout = timeout, - .accept_completion = &accept_completion, - .conn_completion = &conn_completion, - .timeout_completion = 
&timeout_completion, + .browser = &browser, + .listener = listener, + .allocator = allocator, + .conn_completion = undefined, + .close_completion = undefined, + .accept_completion = undefined, + .timeout_completion = undefined, .state = cdp.State.init(browser.session.alloc), + .json_version_response = json_version_response, + .send_pool = std.heap.MemoryPool(Send).init(allocator), }; - defer ctx.deinit(); + defer server.deinit(); - try browser.session.initInspector( - &ctx, - Ctx.onInspectorResp, - Ctx.onInspectorNotif, - ); + try browser.session.initInspector(&server, Server.inspectorResponse, Server.inspectorEvent); - // accepting connection asynchronously on internal server - log.info("accepting new conn...", .{}); - loop.io.accept(*Ctx, &ctx, Ctx.acceptCbk, ctx.acceptCompletion(), ctx.accept_socket); + // accept an connection + server.queueAccept(); // infinite loop on I/O events, either: // - cmd from incoming connection on server socket @@ -483,58 +984,565 @@ pub fn handle( try loop.io.run_for_ns(10 * std.time.ns_per_ms); if (loop.cbk_error) { log.err("JS error", .{}); - // if (try try_catch.exception(alloc, js_env.*)) |msg| { - // std.debug.print("\n\rUncaught {s}\n\r", .{msg}); - // alloc.free(msg); - // } - // loop.cbk_error = false; - } - if (ctx.err) |err| { - if (err != error.NoError) log.err("Server error: {any}", .{err}); - break; } } } -fn setSockOpt(fd: std.posix.socket_t, level: i32, option: u32, value: c_int) !void { - try std.posix.setsockopt(fd, level, option, &std.mem.toBytes(value)); +// Utils +// -------- + +fn buildJSONVersionResponse( + allocator: Allocator, + address: net.Address, +) ![]const u8 { + const body_format = "{{\"webSocketDebuggerUrl\": \"ws://{}/\"}}"; + const body_len = std.fmt.count(body_format, .{address}); + + const response_format = + "HTTP/1.1 200 OK\r\n" ++ + "Content-Length: {d}\r\n" ++ + "Content-Type: application/json; charset=UTF-8\r\n\r\n" ++ + body_format; + return try std.fmt.allocPrint(allocator, response_format, 
.{ body_len, address }); } -fn isUnixSocket(addr: std.net.Address) bool { - return addr.any.family == std.posix.AF.UNIX; +fn now() std.time.Instant { + // can only fail on platforms we don't support + return std.time.Instant.now() catch unreachable; } -pub fn listen(address: std.net.Address) !std.posix.socket_t { - const isunixsock = isUnixSocket(address); - - // create socket - const flags = std.posix.SOCK.STREAM | std.posix.SOCK.CLOEXEC | std.posix.SOCK.NONBLOCK; - const proto = if (isunixsock) @as(u32, 0) else std.posix.IPPROTO.TCP; - const sockfd = try std.posix.socket(address.any.family, flags, proto); - errdefer std.posix.close(sockfd); - - // socket options - // - // REUSEPORT can't be set on unix socket anymore. - // see https://github.com/torvalds/linux/commit/5b0af621c3f6ef9261cf6067812f2fd9943acb4b - if (@hasDecl(std.posix.SO, "REUSEPORT") and !isunixsock) { - try setSockOpt(sockfd, std.posix.SOL.SOCKET, std.posix.SO.REUSEPORT, 1); +// In-place string lowercase +fn toLower(str: []u8) []u8 { + for (str, 0..) |c, i| { + str[i] = std.ascii.toLower(c); } - try setSockOpt(sockfd, std.posix.SOL.SOCKET, std.posix.SO.REUSEADDR, 1); - if (!isUnixSocket(address)) { - if (builtin.target.os.tag == .linux) { // posix.TCP not available on MacOS - // WARNING: disable Nagle's alogrithm to avoid latency issues - try setSockOpt(sockfd, std.posix.IPPROTO.TCP, std.posix.TCP.NODELAY, 1); + return str; +} + +// Zig is in a weird backend transition right now. Need to determine if +// SIMD is even available. 
+const backend_supports_vectors = switch (builtin.zig_backend) { + .stage2_llvm, .stage2_c => true, + else => false, +}; + +// Websocket messages from client->server are masked using a 4 byte XOR mask +fn mask(m: []const u8, payload: []u8) void { + var data = payload; + + if (!comptime backend_supports_vectors) return simpleMask(m, data); + + const vector_size = std.simd.suggestVectorLength(u8) orelse @sizeOf(usize); + if (data.len >= vector_size) { + const mask_vector = std.simd.repeat(vector_size, @as(@Vector(4, u8), m[0..4].*)); + while (data.len >= vector_size) { + const slice = data[0..vector_size]; + const masked_data_slice: @Vector(vector_size, u8) = slice.*; + slice.* = masked_data_slice ^ mask_vector; + data = data[vector_size..]; } } - - // bind & listen - var socklen = address.getOsSockLen(); - try std.posix.bind(sockfd, &address.any, socklen); - const kernel_backlog = 1; // default value is 128. Here we just want 1 connection - try std.posix.listen(sockfd, kernel_backlog); - var listen_address: std.net.Address = undefined; - try std.posix.getsockname(sockfd, &listen_address.any, &socklen); - - return sockfd; + simpleMask(m, data); } + +// Used when SIMD isn't available, or for any remaining part of the message +// which is too small to effectively use SIMD. +fn simpleMask(m: []const u8, payload: []u8) void { + for (payload, 0..) 
|b, i| { + payload[i] = b ^ m[i & 3]; + } +} + +const testing = std.testing; +test "server: buildJSONVersionResponse" { + const address = try net.Address.parseIp4("127.0.0.1", 9001); + const res = try buildJSONVersionResponse(testing.allocator, address); + defer testing.allocator.free(res); + + try testing.expectEqualStrings("HTTP/1.1 200 OK\r\n" ++ + "Content-Length: 48\r\n" ++ + "Content-Type: application/json; charset=UTF-8\r\n\r\n" ++ + "{\"webSocketDebuggerUrl\": \"ws://127.0.0.1:9001/\"}", res); +} + +test "Client: http invalid handshake" { + try assertHTTPError( + error.InvalidRequest, + 400, + "Invalid request", + "\r\n\r\n", + ); + + try assertHTTPError( + error.NotFound, + 404, + "Not found", + "GET /over/9000 HTTP/1.1\r\n\r\n", + ); + + try assertHTTPError( + error.NotFound, + 404, + "Not found", + "POST / HTTP/1.1\r\n\r\n", + ); + + try assertHTTPError( + error.InvalidProtocol, + 400, + "Invalid HTTP protocol", + "GET / HTTP/1.0\r\n\r\n", + ); + + try assertHTTPError( + error.MissingHeaders, + 400, + "Missing required header", + "GET / HTTP/1.1\r\n\r\n", + ); + + try assertHTTPError( + error.MissingHeaders, + 400, + "Missing required header", + "GET / HTTP/1.1\r\nConnection: upgrade\r\n\r\n", + ); + + try assertHTTPError( + error.MissingHeaders, + 400, + "Missing required header", + "GET / HTTP/1.1\r\nConnection: upgrade\r\nUpgrade: websocket\r\n\r\n", + ); + + try assertHTTPError( + error.MissingHeaders, + 400, + "Missing required header", + "GET / HTTP/1.1\r\nConnection: upgrade\r\nUpgrade: websocket\r\nsec-websocket-version:13\r\n\r\n", + ); +} + +test "Client: http valid handshake" { + var ms = MockServer{}; + defer ms.deinit(); + + var client = Client(*MockServer).init(0, &ms); + + const request = + "GET / HTTP/1.1\r\n" ++ + "Connection: upgrade\r\n" ++ + "Upgrade: websocket\r\n" ++ + "sec-websocket-version:13\r\n" ++ + "sec-websocket-key: this is my key\r\n" ++ + "Custom: Header-Value\r\n\r\n"; + + @memcpy(client.read_buf[0..request.len], 
request); + try testing.expectEqual(true, try client.processData(request.len)); + + try testing.expectEqual(.websocket, client.mode); + try testing.expectEqualStrings( + "HTTP/1.1 101 Switching Protocols\r\n" ++ + "Upgrade: websocket\r\n" ++ + "Connection: upgrade\r\n" ++ + "Sec-Websocket-Accept: flzHu2DevQ2dSCSVqKSii5e9C2o=\r\n\r\n", + ms.sent.items[0], + ); +} + +test "Client: http get json version" { + var ms = MockServer{}; + defer ms.deinit(); + + var client = Client(*MockServer).init(0, &ms); + + const request = "GET /json/version HTTP/1.1\r\n\r\n"; + + @memcpy(client.read_buf[0..request.len], request); + try testing.expectEqual(true, try client.processData(request.len)); + + try testing.expectEqual(.http, client.mode); + + // this is the hardcoded string in our MockServer + try testing.expectEqualStrings("the json version response", ms.sent.items[0]); +} + +test "Client: write websocket message" { + const cases = [_]struct { expected: []const u8, message: []const u8 }{ + .{ .expected = &.{ 129, 0 }, .message = "" }, + .{ .expected = [_]u8{ 129, 12 } ++ "hello world!", .message = "hello world!" 
}, + .{ .expected = [_]u8{ 129, 126, 0, 130 } ++ ("A" ** 130), .message = "A" ** 130 }, + }; + + for (cases) |c| { + var ms = MockServer{}; + defer ms.deinit(); + + var client = Client(*MockServer).init(0, &ms); + + try client.sendWS(try testing.allocator.dupe(u8, c.message)); + try testing.expectEqual(1, ms.sent.items.len); + try testing.expectEqualSlices(u8, c.expected, ms.sent.items[0]); + } +} + +test "Client: read invalid websocket message" { + try assertWebSocketError( + error.InvalidMessageType, + 1002, + "", + &.{ 131, 1 }, // 128 (fin) | 3 where 3 isn't a valid type + ); + + try assertWebSocketError( + error.ContinuationNotSupported, + 1003, + "", + &.{ 128, 1 }, // 128 (fin) | 0 where 0 is a continuation frame + ); + + try assertWebSocketError( + error.ContinuationNotSupported, + 1003, + "", + &.{ 1, 1 }, // 0 (non-fin) | 1 non-fin (continuation) not supported + ); + + for ([_]u8{ 16, 32, 64 }) |rsv| { + // none of the reserved flags should be set + try assertWebSocketError( + error.ReservedFlags, + 1002, + "", + &.{ rsv, 0 }, + ); + + // as a bitmask + try assertWebSocketError( + error.ReservedFlags, + 1002, + "", + &.{ rsv + 4, 0 }, + ); + } + + try assertWebSocketError( + error.NotMasked, + 1002, + "", + &.{ 129, 127 }, // client->server messages must be masked + ); + + try assertWebSocketError( + error.TooLarge, + 1009, + "", + &.{ 129, 255, 0, 0, 0, 0, 0, 4, 0, 1 }, // 1024 * 256 + 1 + ); +} + +test "Client: ping reply" { + try assertWebSocketMessage( + // fin | pong, len + &.{ 138, 0 }, + + // fin | ping, masked | len, 4-byte mask + &.{ 137, 128, 0, 0, 0, 0 }, + ); + + try assertWebSocketMessage( + // fin | pong, len, payload + &.{ 138, 5, 100, 96, 97, 109, 104 }, + + // fin | ping, masked | len, 4-byte mask, 5 byte payload + &.{ 137, 133, 0, 5, 7, 10, 100, 101, 102, 103, 104 }, + ); +} + +test "Client: close message" { + try assertWebSocketMessage( + // fin | close, len, close code (normal) + &.{ 136, 2, 3, 232 }, + + // fin | close, masked | len,
4-byte mask + &.{ 136, 128, 0, 0, 0, 0 }, + ); +} + +// Testing both HTTP and websocket messages broken up across multiple reads. +// We need to fuzz HTTP messages differently than websocket. HTTP are strictly +// req -> res with no pipelining. So there should only be 1 message at a time. +// So we can only "fuzz" on a per-message basis. +// But for websocket, we can fuzz _all_ the messages together. +test "Client: fuzz" { + var prng = std.rand.DefaultPrng.init(blk: { + var seed: u64 = undefined; + try std.posix.getrandom(std.mem.asBytes(&seed)); + break :blk seed; + }); + const random = prng.random(); + + const allocator = testing.allocator; + var websocket_messages: std.ArrayListUnmanaged(u8) = .{}; + defer websocket_messages.deinit(allocator); + + // ping with no payload + try websocket_messages.appendSlice( + allocator, + &.{ 137, 128, 0, 0, 0, 0 }, + ); + + // // 10 byte text message with a 0,0,0,0 mask + try websocket_messages.appendSlice( + allocator, + &.{ 129, 138, 0, 0, 0, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10 }, + ); + + // ping with a payload + try websocket_messages.appendSlice( + allocator, + &.{ 137, 133, 0, 5, 7, 10, 100, 101, 102, 103, 104 }, + ); + + // pong with no payload (noop in the server) + try websocket_messages.appendSlice( + allocator, + &.{ 138, 128, 10, 10, 10, 10 }, + ); + + // 687 long message, with a mask + try websocket_messages.appendSlice( + allocator, + [_]u8{ 129, 254, 2, 175, 1, 2, 3, 4 } ++ "A" ** 687, + ); + + // close + try websocket_messages.appendSlice( + allocator, + &.{ 136, 130, 200, 103, 34, 22, 0, 1 }, + ); + + const SendRandom = struct { + fn send(c: anytype, r: std.Random, data: []const u8) !void { + var buf = data; + while (buf.len > 0) { + const to_send = r.intRangeAtMost(usize, 1, buf.len); + @memcpy(c.readBuf()[0..to_send], buf[0..to_send]); + if (try c.processData(to_send) == false) { + return; + } + buf = buf[to_send..]; + } + } + }; + + for (0..1) |_| { + var ms = MockServer{}; + defer ms.deinit(); + + var client 
= Client(*MockServer).init(0, &ms); + + try SendRandom.send(&client, random, "GET /json/version HTTP/1.1\r\nContent-Length: 0\r\n\r\n"); + try SendRandom.send(&client, random, "GET / HTTP/1.1\r\n" ++ + "Connection: upgrade\r\n" ++ + "Upgrade: websocket\r\n" ++ + "sec-websocket-version:13\r\n" ++ + "sec-websocket-key: 1234aa93\r\n" ++ + "Custom: Header-Value\r\n\r\n"); + + // fuzz over all websocket messages + try SendRandom.send(&client, random, websocket_messages.items); + + try testing.expectEqual(5, ms.sent.items.len); + + try testing.expectEqualStrings( + "the json version response", + ms.sent.items[0], + ); + + try testing.expectEqualStrings( + "HTTP/1.1 101 Switching Protocols\r\n" ++ + "Upgrade: websocket\r\n" ++ + "Connection: upgrade\r\n" ++ + "Sec-Websocket-Accept: KnOKWrrjHS0nGFmtfmYFQoPIGKQ=\r\n\r\n", + ms.sent.items[1], + ); + + try testing.expectEqualSlices(u8, &.{ 138, 0 }, ms.sent.items[2]); + + try testing.expectEqualSlices( + u8, + &.{ 138, 5, 100, 96, 97, 109, 104 }, + ms.sent.items[3], + ); + + try testing.expectEqualSlices( + u8, + &.{ 136, 2, 3, 232 }, + ms.sent.items[4], + ); + + try testing.expectEqual(2, ms.cdp.items.len); + try testing.expectEqualSlices( + u8, + &.{ 1, 2, 3, 4, 5, 6, 7, 8, 9, 10 }, + ms.cdp.items[0], + ); + + try testing.expectEqualSlices( + u8, + &([_]u8{ 64, 67, 66, 69 } ** 171 ++ [_]u8{ 64, 67, 66 }), + ms.cdp.items[1], + ); + + try testing.expectEqual(true, ms.closed); + } +} + +test "server: mask" { + var buf: [4000]u8 = undefined; + const messages = [_][]const u8{ "1234", "1234" ** 99, "1234" ** 999 }; + for (messages) |message| { + // we need the message to be mutable since mask operates in-place + const payload = buf[0..message.len]; + @memcpy(payload, message); + + mask(&.{ 1, 2, 200, 240 }, payload); + try testing.expectEqual(false, std.mem.eql(u8, payload, message)); + + mask(&.{ 1, 2, 200, 240 }, payload); + try testing.expectEqual(true, std.mem.eql(u8, payload, message)); + } +} + +fn assertHTTPError( + 
expected_error: HTTPError, + comptime expected_status: u16, + comptime expected_body: []const u8, + input: []const u8, +) !void { + var ms = MockServer{}; + defer ms.deinit(); + + var client = Client(*MockServer).init(0, &ms); + @memcpy(client.read_buf[0..input.len], input); + try testing.expectError(expected_error, client.processData(input.len)); + + const expected_response = std.fmt.comptimePrint( + "HTTP/1.1 {d} \r\nConnection: Close\r\nContent-Length: {d}\r\n\r\n{s}", + .{ expected_status, expected_body.len, expected_body }, + ); + + try testing.expectEqual(1, ms.sent.items.len); + try testing.expectEqualStrings(expected_response, ms.sent.items[0]); +} + +fn assertWebSocketError( + expected_error: WebSocketError, + close_code: u16, + close_payload: []const u8, + input: []const u8, +) !void { + var ms = MockServer{}; + defer ms.deinit(); + + var client = Client(*MockServer).init(0, &ms); + client.mode = .websocket; // force websocket message processing + + @memcpy(client.read_buf[0..input.len], input); + try testing.expectError(expected_error, client.processData(input.len)); + + try testing.expectEqual(1, ms.sent.items.len); + + const actual = ms.sent.items[0]; + + // fin | close opcode + try testing.expectEqual(136, actual[0]); + + // message length (code + payload) + try testing.expectEqual(2 + close_payload.len, actual[1]); + + // close code + try testing.expectEqual(close_code, std.mem.readInt(u16, actual[2..4], .big)); + + // close payload (if any) + try testing.expectEqualStrings(close_payload, actual[4..]); +} + +fn assertWebSocketMessage( + expected: []const u8, + input: []const u8, +) !void { + var ms = MockServer{}; + defer ms.deinit(); + + var client = Client(*MockServer).init(0, &ms); + client.mode = .websocket; // force websocket message processing + + @memcpy(client.read_buf[0..input.len], input); + const more = try client.processData(input.len); + + try testing.expectEqual(1, ms.sent.items.len); + try testing.expectEqualSlices(u8, expected, 
ms.sent.items[0]); + + // if we sent a close message, then the server should have been told + // to close the connection + if (expected[0] == 136) { + try testing.expectEqual(true, ms.closed); + try testing.expectEqual(false, more); + } else { + try testing.expectEqual(false, ms.closed); + try testing.expectEqual(true, more); + } +} + +const MockServer = struct { + closed: bool = false, + + // record the messages we sent to the client + sent: std.ArrayListUnmanaged([]const u8) = .{}, + + // record the CDP messages we need to process + cdp: std.ArrayListUnmanaged([]const u8) = .{}, + + allocator: Allocator = testing.allocator, + + json_version_response: []const u8 = "the json version response", + + fn deinit(self: *MockServer) void { + const allocator = self.allocator; + + for (self.sent.items) |msg| { + allocator.free(msg); + } + self.sent.deinit(allocator); + + for (self.cdp.items) |msg| { + allocator.free(msg); + } + self.cdp.deinit(allocator); + } + + fn queueClose(self: *MockServer, _: anytype) void { + self.closed = true; + } + + fn handleCDP(self: *MockServer, message: []const u8) !void { + const owned = try self.allocator.dupe(u8, message); + try self.cdp.append(self.allocator, owned); + } + + fn queueSend( + self: *MockServer, + socket: posix.socket_t, + data: []const u8, + free_when_done: bool, + ) !void { + _ = socket; + const owned = try self.allocator.dupe(u8, data); + try self.sent.append(self.allocator, owned); + if (free_when_done) { + testing.allocator.free(data); + } + } +}; diff --git a/src/unit_tests.zig b/src/unit_tests.zig index 2ab87f9a..7508821e 100644 --- a/src/unit_tests.zig +++ b/src/unit_tests.zig @@ -341,7 +341,7 @@ test { std.testing.refAllDecls(@import("css/parser.zig")); std.testing.refAllDecls(@import("generate.zig")); std.testing.refAllDecls(@import("http/Client.zig")); - std.testing.refAllDecls(@import("msg.zig")); std.testing.refAllDecls(@import("storage/storage.zig")); std.testing.refAllDecls(@import("iterator/iterator.zig")); +
std.testing.refAllDecls(@import("server.zig")); } diff --git a/vendor/zig-js-runtime b/vendor/zig-js-runtime index 61c71e5e..f40f4914 160000 --- a/vendor/zig-js-runtime +++ b/vendor/zig-js-runtime @@ -1 +1 @@ -Subproject commit 61c71e5e390316786a0c780d9135a45890bda846 +Subproject commit f40f4914667f4fc7cd14ee0df0e76a2fd8d835b4