mirror of
https://github.com/lightpanda-io/browser.git
synced 2025-10-29 15:13:28 +00:00
Make websocket client reader stateful
Move more logic into the reader. Avoid copying partial messages in cases where we know that the buffer is large enough. This is mostly groundwork for trying to add support for continuation frames.
This commit is contained in:
236
src/server.zig
236
src/server.zig
@@ -64,7 +64,9 @@ const log = std.log.scoped(.server);
|
||||
|
||||
const MAX_HTTP_REQUEST_SIZE = 2048;
|
||||
|
||||
// max message size, +14 for max websocket payload overhead
|
||||
// max message size
|
||||
// +14 for max websocket payload overhead
|
||||
// +140 for the max control packet that might be interleaved in a message
|
||||
const MAX_MESSAGE_SIZE = 256 * 1024 + 14;
|
||||
|
||||
// For now, cdp does @import("server.zig").Ctx. Could change cdp to use "Server"
|
||||
@@ -466,15 +468,10 @@ fn Client(comptime S: type) type {
|
||||
// should eventually be upgraded to a websocket connections
|
||||
mode: Mode,
|
||||
server: S,
|
||||
reader: Reader,
|
||||
socket: posix.socket_t,
|
||||
last_active: std.time.Instant,
|
||||
|
||||
// the start of the message in our read_buf
|
||||
read_pos: usize = 0,
|
||||
// up to where do we have data in our read_buf
|
||||
read_len: usize = 0,
|
||||
read_buf: [MAX_MESSAGE_SIZE]u8 = undefined,
|
||||
|
||||
const Mode = enum {
|
||||
http,
|
||||
websocket,
|
||||
@@ -485,6 +482,7 @@ fn Client(comptime S: type) type {
|
||||
fn init(socket: posix.socket_t, server: S) Self {
|
||||
return .{
|
||||
.mode = .http,
|
||||
.reader = .{},
|
||||
.socket = socket,
|
||||
.server = server,
|
||||
.last_active = now(),
|
||||
@@ -501,31 +499,25 @@ fn Client(comptime S: type) type {
|
||||
}
|
||||
|
||||
fn readBuf(self: *Self) []u8 {
|
||||
// We might have read a partial http or websocket message.
|
||||
// Subsequent reads must read from where we left off.
|
||||
std.debug.assert(self.read_pos < self.read_buf.len);
|
||||
return self.read_buf[self.read_len..];
|
||||
return self.reader.readBuf();
|
||||
}
|
||||
|
||||
fn processData(self: *Self, len: usize) !bool {
|
||||
const end = self.read_len + len;
|
||||
std.debug.assert(end >= self.read_pos);
|
||||
|
||||
self.last_active = now();
|
||||
const data = self.read_buf[self.read_pos..end];
|
||||
self.reader.len += len;
|
||||
|
||||
switch (self.mode) {
|
||||
.http => {
|
||||
try self.processHTTPRequest(data);
|
||||
try self.processHTTPRequest();
|
||||
return true;
|
||||
},
|
||||
.websocket => return self.processWebsocketMessage(data),
|
||||
.websocket => return self.processWebsocketMessage(),
|
||||
}
|
||||
}
|
||||
|
||||
fn processHTTPRequest(self: *Self, request: []u8) HTTPError!void {
|
||||
// We should never get pipelined HTTP requests
|
||||
std.debug.assert(self.read_pos == 0);
|
||||
fn processHTTPRequest(self: *Self) HTTPError!void {
|
||||
std.debug.assert(self.reader.pos == 0);
|
||||
const request = self.reader.buf[0..self.reader.len];
|
||||
|
||||
errdefer self.server.queueClose(self.socket);
|
||||
|
||||
@@ -537,7 +529,6 @@ fn Client(comptime S: type) type {
|
||||
// we're only expecting [body-less] GET requests.
|
||||
if (std.mem.endsWith(u8, request, "\r\n\r\n") == false) {
|
||||
// we need more data, put any more data here
|
||||
self.read_len = request.len;
|
||||
return;
|
||||
}
|
||||
|
||||
@@ -559,7 +550,7 @@ fn Client(comptime S: type) type {
|
||||
};
|
||||
|
||||
// the next incoming data can go to the front of our buffer
|
||||
self.read_len = 0;
|
||||
self.reader.len = 0;
|
||||
}
|
||||
|
||||
fn handleHTTPRequest(self: *Self, request: []u8) !void {
|
||||
@@ -683,10 +674,10 @@ fn Client(comptime S: type) type {
|
||||
return self.send(response, true);
|
||||
}
|
||||
|
||||
fn processWebsocketMessage(self: *Self, data: []u8) !bool {
|
||||
fn processWebsocketMessage(self: *Self) !bool {
|
||||
errdefer self.server.queueClose(self.socket);
|
||||
|
||||
var reader = Reader{ .data = data };
|
||||
var reader = &self.reader;
|
||||
while (true) {
|
||||
const msg = reader.next() catch |err| {
|
||||
switch (err) {
|
||||
@@ -711,18 +702,9 @@ fn Client(comptime S: type) type {
|
||||
}
|
||||
}
|
||||
|
||||
const incomplete = reader.data;
|
||||
self.read_len = incomplete.len;
|
||||
if (incomplete.len > 0) {
|
||||
// we have part of the data for the next message
|
||||
|
||||
// can't use @memset because incomplete is a slice of read_buf,
|
||||
// so they could overlap
|
||||
|
||||
// TODO: this can be skipped if we know that the next message will
|
||||
// fit into whatever reamining space we have.
|
||||
std.mem.copyForwards(u8, self.read_buf[0..incomplete.len], incomplete);
|
||||
}
|
||||
// We might have read part of the next message. Our reader potentially
|
||||
// has to move data around in its buffer to make space.
|
||||
reader.compact();
|
||||
return true;
|
||||
}
|
||||
|
||||
@@ -814,23 +796,52 @@ fn Client(comptime S: type) type {
|
||||
// can return zero or more Messages. When next returns null, any incomplete
|
||||
// message will remain in reader.data
|
||||
const Reader = struct {
|
||||
data: []u8,
|
||||
// position in buf of the start of the next message
|
||||
pos: usize = 0,
|
||||
|
||||
fn next(self: *Reader) !?Message {
|
||||
var data = self.data;
|
||||
if (data.len < 2) {
|
||||
return null;
|
||||
// position in buf up until where we have valid data
|
||||
// (any new reads must be placed after this)
|
||||
len: usize = 0,
|
||||
|
||||
// we add 140 to allow 1 control message (ping/pong/close) to be
|
||||
// fragmented into a normal message.
|
||||
buf: [MAX_MESSAGE_SIZE + 140]u8 = undefined,
|
||||
|
||||
fn readBuf(self: *Reader) []u8 {
|
||||
// We might have read a partial http or websocket message.
|
||||
// Subsequent reads must read from where we left off.
|
||||
return self.buf[self.len..];
|
||||
}
|
||||
|
||||
const byte1 = data[0];
|
||||
fn next(self: *Reader) !?Message {
|
||||
var buf = self.buf[self.pos..self.len];
|
||||
|
||||
const length_of_len, const message_len = extractLengths(buf) orelse {
|
||||
// we don't have enough bytes
|
||||
return null;
|
||||
};
|
||||
|
||||
const byte1 = buf[0];
|
||||
|
||||
if (byte1 & 112 != 0) {
|
||||
return error.ReservedFlags;
|
||||
}
|
||||
|
||||
const fin = byte1 & 128 == 128;
|
||||
if (!fin) {
|
||||
return error.ContinuationNotSupported;
|
||||
}
|
||||
|
||||
if (buf[1] & 128 != 128) {
|
||||
// client -> server messages _must_ be masked
|
||||
return error.NotMasked;
|
||||
}
|
||||
|
||||
// var is_continuation = false;
|
||||
var message_type: Message.Type = undefined;
|
||||
switch (byte1 & 15) {
|
||||
0 => return error.ContinuationNotSupported, // TODO??
|
||||
// 0 => is_continuation = true,
|
||||
0 => return error.ContinuationNotSupported,
|
||||
1 => message_type = .text,
|
||||
2 => message_type = .binary,
|
||||
8 => message_type = .close,
|
||||
@@ -839,54 +850,105 @@ const Reader = struct {
|
||||
else => return error.InvalidMessageType,
|
||||
}
|
||||
|
||||
if (byte1 & 128 != 128) {
|
||||
// TODO??
|
||||
return error.ContinuationNotSupported;
|
||||
}
|
||||
|
||||
const byte2 = data[1];
|
||||
if (byte2 & 128 != 128) {
|
||||
// client -> server messages _must_ be masked
|
||||
return error.NotMasked;
|
||||
}
|
||||
|
||||
const length_of_len: usize = switch (byte2 & 127) {
|
||||
126 => 2,
|
||||
127 => 8,
|
||||
else => 0,
|
||||
};
|
||||
|
||||
if (data.len < length_of_len + 2) {
|
||||
// we definitely don't have enough data yet
|
||||
return null;
|
||||
}
|
||||
|
||||
const message_len = switch (length_of_len) {
|
||||
2 => @as(u16, @intCast(data[3])) | @as(u16, @intCast(data[2])) << 8,
|
||||
8 => @as(u64, @intCast(data[9])) | @as(u64, @intCast(data[8])) << 8 | @as(u64, @intCast(data[7])) << 16 | @as(u64, @intCast(data[6])) << 24 | @as(u64, @intCast(data[5])) << 32 | @as(u64, @intCast(data[4])) << 40 | @as(u64, @intCast(data[3])) << 48 | @as(u64, @intCast(data[2])) << 56,
|
||||
else => data[1] & 127,
|
||||
} + length_of_len + 2 + 4; // +2 for header prefix, +4 for mask
|
||||
|
||||
if (message_len > MAX_MESSAGE_SIZE) {
|
||||
return error.TooLarge;
|
||||
}
|
||||
|
||||
if (data.len < message_len) {
|
||||
if (buf.len < message_len) {
|
||||
return null;
|
||||
}
|
||||
|
||||
// prefix + length_of_len + mask
|
||||
const header_len = 2 + length_of_len + 4;
|
||||
|
||||
const payload = data[header_len..message_len];
|
||||
mask(data[header_len - 4 .. header_len], payload);
|
||||
const payload = buf[header_len..message_len];
|
||||
mask(buf[header_len - 4 .. header_len], payload);
|
||||
|
||||
self.pos += message_len;
|
||||
|
||||
self.data = data[message_len..];
|
||||
return .{
|
||||
.type = message_type,
|
||||
.data = payload,
|
||||
};
|
||||
}
|
||||
|
||||
fn extractLengths(buf: []const u8) ?struct{usize, usize} {
|
||||
if (buf.len < 2) {
|
||||
return null;
|
||||
}
|
||||
|
||||
const length_of_len: usize = switch (buf[1] & 127) {
|
||||
126 => 2,
|
||||
127 => 8,
|
||||
else => 0,
|
||||
};
|
||||
|
||||
if (buf.len < length_of_len + 2) {
|
||||
// we definitely don't have enough buf yet
|
||||
return null;
|
||||
}
|
||||
|
||||
const message_len = switch (length_of_len) {
|
||||
2 => @as(u16, @intCast(buf[3])) | @as(u16, @intCast(buf[2])) << 8,
|
||||
8 => @as(u64, @intCast(buf[9])) | @as(u64, @intCast(buf[8])) << 8 | @as(u64, @intCast(buf[7])) << 16 | @as(u64, @intCast(buf[6])) << 24 | @as(u64, @intCast(buf[5])) << 32 | @as(u64, @intCast(buf[4])) << 40 | @as(u64, @intCast(buf[3])) << 48 | @as(u64, @intCast(buf[2])) << 56,
|
||||
else => buf[1] & 127,
|
||||
} + length_of_len + 2 + 4; // +2 for header prefix, +4 for mask;
|
||||
|
||||
return .{length_of_len, message_len};
|
||||
}
|
||||
|
||||
// This is called after we've processed complete websocket messages (this
|
||||
// only applies to websocket messages).
|
||||
// There are three cases:
|
||||
// 1 - We don't have any incomplete data (for a subsequent message) in buf.
|
||||
// This is the easier to handle, we can set pos & len to 0.
|
||||
// 2 - We have part of the next message, but we know it'll fit in the
|
||||
// remaining buf. We don't need to do anything
|
||||
// 3 - We have part of the next message, but either it won't fight into the
|
||||
// remaining buffer, or we don't know (because we don't have enough
|
||||
// of the header to tell the length). We need to "compact" the buffer
|
||||
fn compact(self: *Reader) void {
|
||||
const pos = self.pos;
|
||||
const len = self.len;
|
||||
|
||||
std.debug.assert(pos <= len);
|
||||
|
||||
// how many (if any) partial bytes do we have
|
||||
const partial_bytes = len - pos;
|
||||
|
||||
if (partial_bytes == 0) {
|
||||
// We have no partial bytes. Setting these to 0 ensures that we
|
||||
// get the best utilization of our buffer
|
||||
self.pos = 0;
|
||||
self.len = 0;
|
||||
return;
|
||||
}
|
||||
|
||||
const partial = self.buf[pos..len];
|
||||
|
||||
// If we have enough bytes of the next message to tell its length
|
||||
// we'll be able to figure out whether we need to do anything or not.
|
||||
if (extractLengths(partial)) |length_meta| {
|
||||
const next_message_len = length_meta.@"1";
|
||||
// if this isn't true, then we have a full message and it
|
||||
// should have been processed.
|
||||
std.debug.assert(next_message_len > partial_bytes);
|
||||
|
||||
const missing_bytes = next_message_len - partial_bytes;
|
||||
|
||||
const free_space = self.buf.len - len;
|
||||
if (missing_bytes < free_space) {
|
||||
// we have enough space in our buffer, as is,
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
// We're here because we either don't have enough bytes of the next
|
||||
// message, or we know that it won't fit in our buffer as-is.
|
||||
std.mem.copyForwards(u8, &self.buf, partial);
|
||||
self.pos = 0;
|
||||
self.len = partial_bytes;
|
||||
}
|
||||
};
|
||||
|
||||
const Message = struct {
|
||||
@@ -1138,7 +1200,7 @@ test "Client: http valid handshake" {
|
||||
"sec-websocket-key: this is my key\r\n" ++
|
||||
"Custom: Header-Value\r\n\r\n";
|
||||
|
||||
@memcpy(client.read_buf[0..request.len], request);
|
||||
@memcpy(client.reader.buf[0..request.len], request);
|
||||
try testing.expectEqual(true, try client.processData(request.len));
|
||||
|
||||
try testing.expectEqual(.websocket, client.mode);
|
||||
@@ -1159,7 +1221,7 @@ test "Client: http get json version" {
|
||||
|
||||
const request = "GET /json/version HTTP/1.1\r\n\r\n";
|
||||
|
||||
@memcpy(client.read_buf[0..request.len], request);
|
||||
@memcpy(client.reader.buf[0..request.len], request);
|
||||
try testing.expectEqual(true, try client.processData(request.len));
|
||||
|
||||
try testing.expectEqual(.http, client.mode);
|
||||
@@ -1193,21 +1255,21 @@ test "Client: read invalid websocket message" {
|
||||
error.InvalidMessageType,
|
||||
1002,
|
||||
"",
|
||||
&.{ 131, 1 }, // 128 (fin) | 3 where 3 isn't a valid type
|
||||
&.{ 131, 128, 'm', 'a', 's', 'k' }, // 128 (fin) | 3 where 3 isn't a valid type
|
||||
);
|
||||
|
||||
try assertWebSocketError(
|
||||
error.ContinuationNotSupported,
|
||||
1003,
|
||||
"",
|
||||
&.{ 128, 1 }, // 128 (fin) | 0 where 0 is a continuation frame
|
||||
&.{ 128, 128, 'm', 'a', 's', 'k' }, // 128 (fin) | 0 where 0 is a continuation frame
|
||||
);
|
||||
|
||||
try assertWebSocketError(
|
||||
error.ContinuationNotSupported,
|
||||
1003,
|
||||
"",
|
||||
&.{ 1, 1 }, // 0 (non-fin) | 1 non-fin (contination) not supported
|
||||
&.{ 1, 128, 'm', 'a', 's', 'k' }, // 0 (non-fin) | 1 non-fin (contination) not supported
|
||||
);
|
||||
|
||||
for ([_]u8{ 16, 32, 64 }) |rsv| {
|
||||
@@ -1216,7 +1278,7 @@ test "Client: read invalid websocket message" {
|
||||
error.ReservedFlags,
|
||||
1002,
|
||||
"",
|
||||
&.{ rsv, 0 },
|
||||
&.{ rsv, 128, 'm', 'a', 's', 'k' },
|
||||
);
|
||||
|
||||
// as a bitmask
|
||||
@@ -1224,7 +1286,7 @@ test "Client: read invalid websocket message" {
|
||||
error.ReservedFlags,
|
||||
1002,
|
||||
"",
|
||||
&.{ rsv + 4, 0 },
|
||||
&.{ rsv + 4, 128, 'm', 'a', 's', 'k' },
|
||||
);
|
||||
}
|
||||
|
||||
@@ -1232,14 +1294,14 @@ test "Client: read invalid websocket message" {
|
||||
error.NotMasked,
|
||||
1002,
|
||||
"",
|
||||
&.{ 129, 127 }, // client->server messages must be masked
|
||||
&.{ 129, 1, 'a' }, // client->server messages must be masked
|
||||
);
|
||||
|
||||
try assertWebSocketError(
|
||||
error.TooLarge,
|
||||
1009,
|
||||
"",
|
||||
&.{ 129, 255, 0, 0, 0, 0, 0, 4, 0, 1 }, // 1024 * 256 + 1
|
||||
&.{ 129, 255, 0, 0, 0, 0, 0, 4, 0, 1, 'm', 'a', 's', 'k' }, // 1024 * 256 + 1
|
||||
);
|
||||
}
|
||||
|
||||
@@ -1427,7 +1489,7 @@ fn assertHTTPError(
|
||||
defer ms.deinit();
|
||||
|
||||
var client = Client(*MockServer).init(0, &ms);
|
||||
@memcpy(client.read_buf[0..input.len], input);
|
||||
@memcpy(client.reader.buf[0..input.len], input);
|
||||
try testing.expectError(expected_error, client.processData(input.len));
|
||||
|
||||
const expected_response = std.fmt.comptimePrint(
|
||||
@@ -1451,7 +1513,7 @@ fn assertWebSocketError(
|
||||
var client = Client(*MockServer).init(0, &ms);
|
||||
client.mode = .websocket; // force websocket message processing
|
||||
|
||||
@memcpy(client.read_buf[0..input.len], input);
|
||||
@memcpy(client.reader.buf[0..input.len], input);
|
||||
try testing.expectError(expected_error, client.processData(input.len));
|
||||
|
||||
try testing.expectEqual(1, ms.sent.items.len);
|
||||
@@ -1481,7 +1543,7 @@ fn assertWebSocketMessage(
|
||||
var client = Client(*MockServer).init(0, &ms);
|
||||
client.mode = .websocket; // force websocket message processing
|
||||
|
||||
@memcpy(client.read_buf[0..input.len], input);
|
||||
@memcpy(client.reader.buf[0..input.len], input);
|
||||
const more = try client.processData(input.len);
|
||||
|
||||
try testing.expectEqual(1, ms.sent.items.len);
|
||||
|
||||
Reference in New Issue
Block a user