support continuation frames

This commit is contained in:
Karl Seguin
2025-02-07 18:18:53 +08:00
parent bdb70444d6
commit 14fe4f65e1

View File

@@ -54,7 +54,8 @@ const WebSocketError = error{
NotMasked,
TooLarge,
InvalidMessageType,
ContinuationNotSupported,
InvalidContinuation,
NestedFragementation,
};
const Error = IOError || cdp.Error || HTTPError || WebSocketError;
@@ -211,7 +212,7 @@ const Server = struct {
};
const more = client.processData(size) catch |err| {
std.debug.print("Client Processing Error: {}\n", .{err});
log.err("Client Processing Error: {}\n", .{err});
return;
};
@@ -459,8 +460,6 @@ fn Client(comptime S: type) type {
const CLOSE_NORMAL = [_]u8{ 136, 2, 3, 232 }; // code: 1000
const CLOSE_TOO_BIG = [_]u8{ 136, 2, 3, 241 }; // 1009
const CLOSE_PROTOCOL_ERROR = [_]u8{ 136, 2, 3, 234 }; //code: 1002
// This should be removed once we support continuation frames
const CLOSE_UNSUPPORTED_ERROR = [_]u8{ 136, 2, 3, 235 }; //code: 1003
const CLOSE_TIMEOUT = [_]u8{ 136, 2, 15, 160 }; // code: 4000
return struct {
@@ -482,10 +481,10 @@ fn Client(comptime S: type) type {
fn init(socket: posix.socket_t, server: S) Self {
return .{
.mode = .http,
.reader = .{},
.socket = socket,
.server = server,
.last_active = now(),
.reader = .{ .allocator = server.allocator },
};
}
@@ -496,6 +495,7 @@ fn Client(comptime S: type) type {
}
}
self.server.queueClose(self.socket);
self.reader.deinit();
}
fn readBuf(self: *Self) []u8 {
@@ -685,7 +685,9 @@ fn Client(comptime S: type) type {
error.NotMasked => self.send(&CLOSE_PROTOCOL_ERROR, false) catch {},
error.ReservedFlags => self.send(&CLOSE_PROTOCOL_ERROR, false) catch {},
error.InvalidMessageType => self.send(&CLOSE_PROTOCOL_ERROR, false) catch {},
error.ContinuationNotSupported => self.send(&CLOSE_UNSUPPORTED_ERROR, false) catch {},
error.InvalidContinuation => self.send(&CLOSE_PROTOCOL_ERROR, false) catch {},
error.NestedFragementation => self.send(&CLOSE_PROTOCOL_ERROR, false) catch {},
error.OutOfMemory => {}, // don't borther trying to send an error in this case
}
return err;
} orelse break;
@@ -700,6 +702,9 @@ fn Client(comptime S: type) type {
},
.text, .binary => try self.server.handleCDP(msg.data),
}
if (msg.cleanup_fragment) {
reader.cleanup();
}
}
// We might have read part of the next message. Our reader potentially
@@ -796,6 +801,8 @@ fn Client(comptime S: type) type {
// can return zero or more Messages. When next returns null, any incomplete
// message will remain in reader.data
const Reader = struct {
allocator: Allocator,
// position in buf of the start of the next message
pos: usize = 0,
@@ -807,6 +814,19 @@ const Reader = struct {
// fragmented into a normal message.
buf: [MAX_MESSAGE_SIZE + 140]u8 = undefined,
fragments: ?Fragments = null,
fn deinit(self: *Reader) void {
self.cleanup();
}
fn cleanup(self: *Reader) void {
if (self.fragments) |*f| {
f.message.deinit(self.allocator);
self.fragments = null;
}
}
fn readBuf(self: *Reader) []u8 {
// We might have read a partial http or websocket message.
// Subsequent reads must read from where we left off.
@@ -814,65 +834,105 @@ const Reader = struct {
}
fn next(self: *Reader) !?Message {
var buf = self.buf[self.pos..self.len];
LOOP: while (true) {
var buf = self.buf[self.pos..self.len];
const length_of_len, const message_len = extractLengths(buf) orelse {
// we don't have enough bytes
return null;
};
const length_of_len, const message_len = extractLengths(buf) orelse {
// we don't have enough bytes
return null;
};
const byte1 = buf[0];
const byte1 = buf[0];
if (byte1 & 112 != 0) {
return error.ReservedFlags;
if (byte1 & 112 != 0) {
return error.ReservedFlags;
}
if (buf[1] & 128 != 128) {
// client -> server messages _must_ be masked
return error.NotMasked;
}
var is_continuation = false;
var message_type: Message.Type = undefined;
switch (byte1 & 15) {
0 => is_continuation = true,
1 => message_type = .text,
2 => message_type = .binary,
8 => message_type = .close,
9 => message_type = .ping,
10 => message_type = .pong,
else => return error.InvalidMessageType,
}
if (message_len > MAX_MESSAGE_SIZE) {
return error.TooLarge;
}
if (buf.len < message_len) {
return null;
}
// prefix + length_of_len + mask
const header_len = 2 + length_of_len + 4;
const payload = buf[header_len..message_len];
mask(buf[header_len - 4 .. header_len], payload);
// whatever happens after this, we know where the next message starts
self.pos += message_len;
const fin = byte1 & 128 == 128;
if (is_continuation) {
const fragments = &(self.fragments orelse return error.InvalidContinuation);
if (fragments.message.items.len + message_len > MAX_MESSAGE_SIZE) {
return error.TooLarge;
}
try fragments.message.appendSlice(self.allocator, payload);
if (fin == false) {
// maybe we have more parts of the message waiting
continue :LOOP;
}
// this continuation is done!
return .{
.type = fragments.type,
.data = fragments.message.items,
.cleanup_fragment = true,
};
}
const can_be_fragmented = message_type == .text or message_type == .binary;
if (self.fragments != null and can_be_fragmented) {
// if this isn't a continuation, then we can't have fragements
return error.NestedFragementation;
}
if (fin == false) {
if (can_be_fragmented == false) {
return error.InvalidContinuation;
}
// not continuation, and not fin. It has to be the first message
// in a fragemented message.
var fragments = Fragments{ .message = .{}, .type = message_type };
try fragments.message.appendSlice(self.allocator, payload);
self.fragments = fragments;
continue :LOOP;
}
return .{
.data = payload,
.type = message_type,
.cleanup_fragment = false,
};
}
const fin = byte1 & 128 == 128;
if (!fin) {
return error.ContinuationNotSupported;
}
if (buf[1] & 128 != 128) {
// client -> server messages _must_ be masked
return error.NotMasked;
}
// var is_continuation = false;
var message_type: Message.Type = undefined;
switch (byte1 & 15) {
// 0 => is_continuation = true,
0 => return error.ContinuationNotSupported,
1 => message_type = .text,
2 => message_type = .binary,
8 => message_type = .close,
9 => message_type = .ping,
10 => message_type = .pong,
else => return error.InvalidMessageType,
}
if (message_len > MAX_MESSAGE_SIZE) {
return error.TooLarge;
}
if (buf.len < message_len) {
return null;
}
// prefix + length_of_len + mask
const header_len = 2 + length_of_len + 4;
const payload = buf[header_len..message_len];
mask(buf[header_len - 4 .. header_len], payload);
self.pos += message_len;
return .{
.type = message_type,
.data = payload,
};
}
fn extractLengths(buf: []const u8) ?struct{usize, usize} {
fn extractLengths(buf: []const u8) ?struct { usize, usize } {
if (buf.len < 2) {
return null;
}
@@ -894,7 +954,7 @@ const Reader = struct {
else => buf[1] & 127,
} + length_of_len + 2 + 4; // +2 for header prefix, +4 for mask;
return .{length_of_len, message_len};
return .{ length_of_len, message_len };
}
// This is called after we've processed complete websocket messages (this
@@ -951,9 +1011,15 @@ const Reader = struct {
}
};
const Fragments = struct {
type: Message.Type,
message: std.ArrayListUnmanaged(u8),
};
const Message = struct {
type: Type,
data: []const u8,
cleanup_fragment: bool,
const Type = enum {
text,
@@ -1125,7 +1191,6 @@ test "Client: http invalid request" {
"Request too large",
"GET /over/9000 HTTP/1.1\r\n" ++ "Header: " ++ ("a" ** 2050) ++ "\r\n\r\n",
);
}
test "Client: http invalid handshake" {
@@ -1251,25 +1316,12 @@ test "Client: write websocket message" {
}
test "Client: read invalid websocket message" {
// 131 = 128 (fin) | 3 where 3 isn't a valid type
try assertWebSocketError(
error.InvalidMessageType,
1002,
"",
&.{ 131, 128, 'm', 'a', 's', 'k' }, // 128 (fin) | 3 where 3 isn't a valid type
);
try assertWebSocketError(
error.ContinuationNotSupported,
1003,
"",
&.{ 128, 128, 'm', 'a', 's', 'k' }, // 128 (fin) | 0 where 0 is a continuation frame
);
try assertWebSocketError(
error.ContinuationNotSupported,
1003,
"",
&.{ 1, 128, 'm', 'a', 's', 'k' }, // 0 (non-fin) | 1 non-fin (contination) not supported
&.{ 131, 128, 'm', 'a', 's', 'k' },
);
for ([_]u8{ 16, 32, 64 }) |rsv| {
@@ -1290,18 +1342,84 @@ test "Client: read invalid websocket message" {
);
}
// client->server messages must be masked
try assertWebSocketError(
error.NotMasked,
1002,
"",
&.{ 129, 1, 'a' }, // client->server messages must be masked
&.{ 129, 1, 'a' },
);
// length of message is 0000 0401, i.e: 1024 * 256 + 1
try assertWebSocketError(
error.TooLarge,
1009,
"",
&.{ 129, 255, 0, 0, 0, 0, 0, 4, 0, 1, 'm', 'a', 's', 'k' }, // 1024 * 256 + 1
&.{ 129, 255, 0, 0, 0, 0, 0, 4, 0, 1, 'm', 'a', 's', 'k' },
);
// continuation type message must come after a normal message
// even when not a fin frame
try assertWebSocketError(
error.InvalidContinuation,
1002,
"",
&.{ 0, 129, 'm', 'a', 's', 'k', 'd' },
);
// continuation type message must come after a normal message
// even as a fin frame
try assertWebSocketError(
error.InvalidContinuation,
1002,
"",
&.{ 128, 129, 'm', 'a', 's', 'k', 'd' },
);
// text (non-fin) - text (non-fin)
try assertWebSocketError(
error.NestedFragementation,
1002,
"",
&.{ 1, 129, 'm', 'a', 's', 'k', 'd', 1, 128, 'k', 's', 'a', 'm' },
);
// text (non-fin) - text (fin) should always been continuation after non-fin
try assertWebSocketError(
error.NestedFragementation,
1002,
"",
&.{ 1, 129, 'm', 'a', 's', 'k', 'd', 129, 128, 'k', 's', 'a', 'm' },
);
// close must be fin
try assertWebSocketError(
error.InvalidContinuation,
1002,
"",
&.{
8, 129, 'm', 'a', 's', 'k', 'd',
},
);
// ping must be fin
try assertWebSocketError(
error.InvalidContinuation,
1002,
"",
&.{
9, 129, 'm', 'a', 's', 'k', 'd',
},
);
// pong must be fin
try assertWebSocketError(
error.InvalidContinuation,
1002,
"",
&.{
10, 129, 'm', 'a', 's', 'k', 'd',
},
);
}
@@ -1380,6 +1498,21 @@ test "Client: fuzz" {
[_]u8{ 129, 254, 2, 175, 1, 2, 3, 4 } ++ "A" ** 687,
);
// non-fin text message
try websocket_messages.appendSlice(allocator, &.{ 1, 130, 0, 0, 0, 0, 1, 2 });
// continuation
try websocket_messages.appendSlice(allocator, &.{ 0, 131, 0, 0, 0, 0, 3, 4, 5 });
// pong happening in fragement
try websocket_messages.appendSlice(allocator, &.{ 138, 128, 0, 0, 0, 0 });
// more continuation
try websocket_messages.appendSlice(allocator, &.{ 0, 130, 0, 0, 0, 0, 6, 7 });
// fin
try websocket_messages.appendSlice(allocator, &.{ 128, 133, 0, 0, 0, 0, 8, 9, 10, 11, 12 });
// close
try websocket_messages.appendSlice(
allocator,
@@ -1446,7 +1579,7 @@ test "Client: fuzz" {
ms.sent.items[4],
);
try testing.expectEqual(2, ms.cdp.items.len);
try testing.expectEqual(3, ms.cdp.items.len);
try testing.expectEqualSlices(
u8,
&.{ 1, 2, 3, 4, 5, 6, 7, 8, 9, 10 },
@@ -1459,6 +1592,12 @@ test "Client: fuzz" {
ms.cdp.items[1],
);
try testing.expectEqualSlices(
u8,
&.{ 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12 },
ms.cdp.items[2],
);
try testing.expectEqual(true, ms.closed);
}
}