Make websocket client reader stateful

Move more logic into the reader. Avoid copying partial messages in
cases where we know that the buffer is large enough.

This is mostly groundwork for trying to add support for continuation
frames.
This commit is contained in:
Karl Seguin
2025-02-07 15:57:02 +08:00
parent 4d9cc55a87
commit bdb70444d6

View File

@@ -64,7 +64,9 @@ const log = std.log.scoped(.server);
const MAX_HTTP_REQUEST_SIZE = 2048;
// max message size, +14 for max websocket payload overhead
// max message size
// +14 for max websocket payload overhead
// +140 for the max control packet that might be interleaved in a message
const MAX_MESSAGE_SIZE = 256 * 1024 + 14;
// For now, cdp does @import("server.zig").Ctx. Could change cdp to use "Server"
@@ -466,15 +468,10 @@ fn Client(comptime S: type) type {
// should eventually be upgraded to a websocket connections
mode: Mode,
server: S,
reader: Reader,
socket: posix.socket_t,
last_active: std.time.Instant,
// the start of the message in our read_buf
read_pos: usize = 0,
// up to where do we have data in our read_buf
read_len: usize = 0,
read_buf: [MAX_MESSAGE_SIZE]u8 = undefined,
const Mode = enum {
http,
websocket,
@@ -485,6 +482,7 @@ fn Client(comptime S: type) type {
fn init(socket: posix.socket_t, server: S) Self {
return .{
.mode = .http,
.reader = .{},
.socket = socket,
.server = server,
.last_active = now(),
@@ -501,31 +499,25 @@ fn Client(comptime S: type) type {
}
fn readBuf(self: *Self) []u8 {
// We might have read a partial http or websocket message.
// Subsequent reads must read from where we left off.
std.debug.assert(self.read_pos < self.read_buf.len);
return self.read_buf[self.read_len..];
return self.reader.readBuf();
}
fn processData(self: *Self, len: usize) !bool {
const end = self.read_len + len;
std.debug.assert(end >= self.read_pos);
self.last_active = now();
const data = self.read_buf[self.read_pos..end];
self.reader.len += len;
switch (self.mode) {
.http => {
try self.processHTTPRequest(data);
try self.processHTTPRequest();
return true;
},
.websocket => return self.processWebsocketMessage(data),
.websocket => return self.processWebsocketMessage(),
}
}
fn processHTTPRequest(self: *Self, request: []u8) HTTPError!void {
// We should never get pipelined HTTP requests
std.debug.assert(self.read_pos == 0);
fn processHTTPRequest(self: *Self) HTTPError!void {
std.debug.assert(self.reader.pos == 0);
const request = self.reader.buf[0..self.reader.len];
errdefer self.server.queueClose(self.socket);
@@ -537,7 +529,6 @@ fn Client(comptime S: type) type {
// we're only expecting [body-less] GET requests.
if (std.mem.endsWith(u8, request, "\r\n\r\n") == false) {
// we need more data, put any more data here
self.read_len = request.len;
return;
}
@@ -559,7 +550,7 @@ fn Client(comptime S: type) type {
};
// the next incoming data can go to the front of our buffer
self.read_len = 0;
self.reader.len = 0;
}
fn handleHTTPRequest(self: *Self, request: []u8) !void {
@@ -683,10 +674,10 @@ fn Client(comptime S: type) type {
return self.send(response, true);
}
fn processWebsocketMessage(self: *Self, data: []u8) !bool {
fn processWebsocketMessage(self: *Self) !bool {
errdefer self.server.queueClose(self.socket);
var reader = Reader{ .data = data };
var reader = &self.reader;
while (true) {
const msg = reader.next() catch |err| {
switch (err) {
@@ -711,18 +702,9 @@ fn Client(comptime S: type) type {
}
}
const incomplete = reader.data;
self.read_len = incomplete.len;
if (incomplete.len > 0) {
// we have part of the data for the next message
// can't use @memset because incomplete is a slice of read_buf,
// so they could overlap
// TODO: this can be skipped if we know that the next message will
// fit into whatever reamining space we have.
std.mem.copyForwards(u8, self.read_buf[0..incomplete.len], incomplete);
}
// We might have read part of the next message. Our reader potentially
// has to move data around in its buffer to make space.
reader.compact();
return true;
}
@@ -814,23 +796,52 @@ fn Client(comptime S: type) type {
// can return zero or more Messages. When next returns null, any incomplete
// message will remain in reader.data
const Reader = struct {
data: []u8,
// position in buf of the start of the next message
pos: usize = 0,
fn next(self: *Reader) !?Message {
var data = self.data;
if (data.len < 2) {
return null;
// position in buf up until where we have valid data
// (any new reads must be placed after this)
len: usize = 0,
// we add 140 to allow 1 control message (ping/pong/close) to be
// fragmented into a normal message.
buf: [MAX_MESSAGE_SIZE + 140]u8 = undefined,
fn readBuf(self: *Reader) []u8 {
// We might have read a partial http or websocket message.
// Subsequent reads must read from where we left off.
return self.buf[self.len..];
}
const byte1 = data[0];
fn next(self: *Reader) !?Message {
var buf = self.buf[self.pos..self.len];
const length_of_len, const message_len = extractLengths(buf) orelse {
// we don't have enough bytes
return null;
};
const byte1 = buf[0];
if (byte1 & 112 != 0) {
return error.ReservedFlags;
}
const fin = byte1 & 128 == 128;
if (!fin) {
return error.ContinuationNotSupported;
}
if (buf[1] & 128 != 128) {
// client -> server messages _must_ be masked
return error.NotMasked;
}
// var is_continuation = false;
var message_type: Message.Type = undefined;
switch (byte1 & 15) {
0 => return error.ContinuationNotSupported, // TODO??
// 0 => is_continuation = true,
0 => return error.ContinuationNotSupported,
1 => message_type = .text,
2 => message_type = .binary,
8 => message_type = .close,
@@ -839,54 +850,105 @@ const Reader = struct {
else => return error.InvalidMessageType,
}
if (byte1 & 128 != 128) {
// TODO??
return error.ContinuationNotSupported;
}
const byte2 = data[1];
if (byte2 & 128 != 128) {
// client -> server messages _must_ be masked
return error.NotMasked;
}
const length_of_len: usize = switch (byte2 & 127) {
126 => 2,
127 => 8,
else => 0,
};
if (data.len < length_of_len + 2) {
// we definitely don't have enough data yet
return null;
}
const message_len = switch (length_of_len) {
2 => @as(u16, @intCast(data[3])) | @as(u16, @intCast(data[2])) << 8,
8 => @as(u64, @intCast(data[9])) | @as(u64, @intCast(data[8])) << 8 | @as(u64, @intCast(data[7])) << 16 | @as(u64, @intCast(data[6])) << 24 | @as(u64, @intCast(data[5])) << 32 | @as(u64, @intCast(data[4])) << 40 | @as(u64, @intCast(data[3])) << 48 | @as(u64, @intCast(data[2])) << 56,
else => data[1] & 127,
} + length_of_len + 2 + 4; // +2 for header prefix, +4 for mask
if (message_len > MAX_MESSAGE_SIZE) {
return error.TooLarge;
}
if (data.len < message_len) {
if (buf.len < message_len) {
return null;
}
// prefix + length_of_len + mask
const header_len = 2 + length_of_len + 4;
const payload = data[header_len..message_len];
mask(data[header_len - 4 .. header_len], payload);
const payload = buf[header_len..message_len];
mask(buf[header_len - 4 .. header_len], payload);
self.pos += message_len;
self.data = data[message_len..];
return .{
.type = message_type,
.data = payload,
};
}
fn extractLengths(buf: []const u8) ?struct{usize, usize} {
if (buf.len < 2) {
return null;
}
const length_of_len: usize = switch (buf[1] & 127) {
126 => 2,
127 => 8,
else => 0,
};
if (buf.len < length_of_len + 2) {
// we definitely don't have enough buf yet
return null;
}
const message_len = switch (length_of_len) {
2 => @as(u16, @intCast(buf[3])) | @as(u16, @intCast(buf[2])) << 8,
8 => @as(u64, @intCast(buf[9])) | @as(u64, @intCast(buf[8])) << 8 | @as(u64, @intCast(buf[7])) << 16 | @as(u64, @intCast(buf[6])) << 24 | @as(u64, @intCast(buf[5])) << 32 | @as(u64, @intCast(buf[4])) << 40 | @as(u64, @intCast(buf[3])) << 48 | @as(u64, @intCast(buf[2])) << 56,
else => buf[1] & 127,
} + length_of_len + 2 + 4; // +2 for header prefix, +4 for mask;
return .{length_of_len, message_len};
}
// This is called after we've processed complete websocket messages (this
// only applies to websocket messages).
// There are three cases:
// 1 - We don't have any incomplete data (for a subsequent message) in buf.
// This is the easier to handle, we can set pos & len to 0.
// 2 - We have part of the next message, but we know it'll fit in the
// remaining buf. We don't need to do anything
// 3 - We have part of the next message, but either it won't fight into the
// remaining buffer, or we don't know (because we don't have enough
// of the header to tell the length). We need to "compact" the buffer
fn compact(self: *Reader) void {
const pos = self.pos;
const len = self.len;
std.debug.assert(pos <= len);
// how many (if any) partial bytes do we have
const partial_bytes = len - pos;
if (partial_bytes == 0) {
// We have no partial bytes. Setting these to 0 ensures that we
// get the best utilization of our buffer
self.pos = 0;
self.len = 0;
return;
}
const partial = self.buf[pos..len];
// If we have enough bytes of the next message to tell its length
// we'll be able to figure out whether we need to do anything or not.
if (extractLengths(partial)) |length_meta| {
const next_message_len = length_meta.@"1";
// if this isn't true, then we have a full message and it
// should have been processed.
std.debug.assert(next_message_len > partial_bytes);
const missing_bytes = next_message_len - partial_bytes;
const free_space = self.buf.len - len;
if (missing_bytes < free_space) {
// we have enough space in our buffer, as is,
return;
}
}
// We're here because we either don't have enough bytes of the next
// message, or we know that it won't fit in our buffer as-is.
std.mem.copyForwards(u8, &self.buf, partial);
self.pos = 0;
self.len = partial_bytes;
}
};
const Message = struct {
@@ -1138,7 +1200,7 @@ test "Client: http valid handshake" {
"sec-websocket-key: this is my key\r\n" ++
"Custom: Header-Value\r\n\r\n";
@memcpy(client.read_buf[0..request.len], request);
@memcpy(client.reader.buf[0..request.len], request);
try testing.expectEqual(true, try client.processData(request.len));
try testing.expectEqual(.websocket, client.mode);
@@ -1159,7 +1221,7 @@ test "Client: http get json version" {
const request = "GET /json/version HTTP/1.1\r\n\r\n";
@memcpy(client.read_buf[0..request.len], request);
@memcpy(client.reader.buf[0..request.len], request);
try testing.expectEqual(true, try client.processData(request.len));
try testing.expectEqual(.http, client.mode);
@@ -1193,21 +1255,21 @@ test "Client: read invalid websocket message" {
error.InvalidMessageType,
1002,
"",
&.{ 131, 1 }, // 128 (fin) | 3 where 3 isn't a valid type
&.{ 131, 128, 'm', 'a', 's', 'k' }, // 128 (fin) | 3 where 3 isn't a valid type
);
try assertWebSocketError(
error.ContinuationNotSupported,
1003,
"",
&.{ 128, 1 }, // 128 (fin) | 0 where 0 is a continuation frame
&.{ 128, 128, 'm', 'a', 's', 'k' }, // 128 (fin) | 0 where 0 is a continuation frame
);
try assertWebSocketError(
error.ContinuationNotSupported,
1003,
"",
&.{ 1, 1 }, // 0 (non-fin) | 1 non-fin (contination) not supported
&.{ 1, 128, 'm', 'a', 's', 'k' }, // 0 (non-fin) | 1 non-fin (contination) not supported
);
for ([_]u8{ 16, 32, 64 }) |rsv| {
@@ -1216,7 +1278,7 @@ test "Client: read invalid websocket message" {
error.ReservedFlags,
1002,
"",
&.{ rsv, 0 },
&.{ rsv, 128, 'm', 'a', 's', 'k' },
);
// as a bitmask
@@ -1224,7 +1286,7 @@ test "Client: read invalid websocket message" {
error.ReservedFlags,
1002,
"",
&.{ rsv + 4, 0 },
&.{ rsv + 4, 128, 'm', 'a', 's', 'k' },
);
}
@@ -1232,14 +1294,14 @@ test "Client: read invalid websocket message" {
error.NotMasked,
1002,
"",
&.{ 129, 127 }, // client->server messages must be masked
&.{ 129, 1, 'a' }, // client->server messages must be masked
);
try assertWebSocketError(
error.TooLarge,
1009,
"",
&.{ 129, 255, 0, 0, 0, 0, 0, 4, 0, 1 }, // 1024 * 256 + 1
&.{ 129, 255, 0, 0, 0, 0, 0, 4, 0, 1, 'm', 'a', 's', 'k' }, // 1024 * 256 + 1
);
}
@@ -1427,7 +1489,7 @@ fn assertHTTPError(
defer ms.deinit();
var client = Client(*MockServer).init(0, &ms);
@memcpy(client.read_buf[0..input.len], input);
@memcpy(client.reader.buf[0..input.len], input);
try testing.expectError(expected_error, client.processData(input.len));
const expected_response = std.fmt.comptimePrint(
@@ -1451,7 +1513,7 @@ fn assertWebSocketError(
var client = Client(*MockServer).init(0, &ms);
client.mode = .websocket; // force websocket message processing
@memcpy(client.read_buf[0..input.len], input);
@memcpy(client.reader.buf[0..input.len], input);
try testing.expectError(expected_error, client.processData(input.len));
try testing.expectEqual(1, ms.sent.items.len);
@@ -1481,7 +1543,7 @@ fn assertWebSocketMessage(
var client = Client(*MockServer).init(0, &ms);
client.mode = .websocket; // force websocket message processing
@memcpy(client.read_buf[0..input.len], input);
@memcpy(client.reader.buf[0..input.len], input);
const more = try client.processData(input.len);
try testing.expectEqual(1, ms.sent.items.len);