From 701e8277d6bce1703fd35f9c3df2a677f1cd2dd3 Mon Sep 17 00:00:00 2001 From: Karl Seguin Date: Fri, 7 Feb 2025 18:18:53 +0800 Subject: [PATCH] support continuation frames --- src/server.zig | 297 ++++++++++++++++++++++++++++++++++++------------- 1 file changed, 218 insertions(+), 79 deletions(-) diff --git a/src/server.zig b/src/server.zig index 980127c3..43a79acd 100644 --- a/src/server.zig +++ b/src/server.zig @@ -54,7 +54,8 @@ const WebSocketError = error{ NotMasked, TooLarge, InvalidMessageType, - ContinuationNotSupported, + InvalidContinuation, + NestedFragementation, }; const Error = IOError || cdp.Error || HTTPError || WebSocketError; @@ -212,7 +213,7 @@ const Server = struct { }; const more = client.processData(size) catch |err| { - std.debug.print("Client Processing Error: {}\n", .{err}); + log.err("Client Processing Error: {}\n", .{err}); return; }; @@ -460,8 +461,6 @@ fn Client(comptime S: type) type { const CLOSE_NORMAL = [_]u8{ 136, 2, 3, 232 }; // code: 1000 const CLOSE_TOO_BIG = [_]u8{ 136, 2, 3, 241 }; // 1009 const CLOSE_PROTOCOL_ERROR = [_]u8{ 136, 2, 3, 234 }; //code: 1002 - // This should be removed once we support continuation frames - const CLOSE_UNSUPPORTED_ERROR = [_]u8{ 136, 2, 3, 235 }; //code: 1003 const CLOSE_TIMEOUT = [_]u8{ 136, 2, 15, 160 }; // code: 4000 return struct { @@ -483,10 +482,10 @@ fn Client(comptime S: type) type { fn init(socket: posix.socket_t, server: S) Self { return .{ .mode = .http, - .reader = .{}, .socket = socket, .server = server, .last_active = now(), + .reader = .{ .allocator = server.allocator }, }; } @@ -497,6 +496,7 @@ fn Client(comptime S: type) type { } } self.server.queueClose(self.socket); + self.reader.deinit(); } fn readBuf(self: *Self) []u8 { @@ -686,7 +686,9 @@ fn Client(comptime S: type) type { error.NotMasked => self.send(&CLOSE_PROTOCOL_ERROR, false) catch {}, error.ReservedFlags => self.send(&CLOSE_PROTOCOL_ERROR, false) catch {}, error.InvalidMessageType => self.send(&CLOSE_PROTOCOL_ERROR, false) catch {}, - error.ContinuationNotSupported => self.send(&CLOSE_UNSUPPORTED_ERROR, false) catch {}, + error.InvalidContinuation => self.send(&CLOSE_PROTOCOL_ERROR, false) catch {}, + error.NestedFragementation => self.send(&CLOSE_PROTOCOL_ERROR, false) catch {}, + error.OutOfMemory => {}, // don't borther trying to send an error in this case } return err; } orelse break; @@ -701,6 +703,9 @@ fn Client(comptime S: type) type { }, .text, .binary => try self.server.handleCDP(msg.data), } + if (msg.cleanup_fragment) { + reader.cleanup(); + } } // We might have read part of the next message. Our reader potentially @@ -797,6 +802,8 @@ fn Client(comptime S: type) type { // can return zero or more Messages. When next returns null, any incomplete // message will remain in reader.data const Reader = struct { + allocator: Allocator, + // position in buf of the start of the next message pos: usize = 0, @@ -808,6 +815,19 @@ const Reader = struct { // fragmented into a normal message. buf: [MAX_MESSAGE_SIZE + 140]u8 = undefined, + fragments: ?Fragments = null, + + fn deinit(self: *Reader) void { + self.cleanup(); + } + + fn cleanup(self: *Reader) void { + if (self.fragments) |*f| { + f.message.deinit(self.allocator); + self.fragments = null; + } + } + fn readBuf(self: *Reader) []u8 { // We might have read a partial http or websocket message. // Subsequent reads must read from where we left off. @@ -815,65 +835,105 @@ const Reader = struct { } fn next(self: *Reader) !?Message { - var buf = self.buf[self.pos..self.len]; + LOOP: while (true) { + var buf = self.buf[self.pos..self.len]; - const length_of_len, const message_len = extractLengths(buf) orelse { - // we don't have enough bytes - return null; - }; + const length_of_len, const message_len = extractLengths(buf) orelse { + // we don't have enough bytes + return null; + }; - const byte1 = buf[0]; + const byte1 = buf[0]; - if (byte1 & 112 != 0) { - return error.ReservedFlags; + if (byte1 & 112 != 0) { + return error.ReservedFlags; + } + + if (buf[1] & 128 != 128) { + // client -> server messages _must_ be masked + return error.NotMasked; + } + + var is_continuation = false; + var message_type: Message.Type = undefined; + switch (byte1 & 15) { + 0 => is_continuation = true, + 1 => message_type = .text, + 2 => message_type = .binary, + 8 => message_type = .close, + 9 => message_type = .ping, + 10 => message_type = .pong, + else => return error.InvalidMessageType, + } + + if (message_len > MAX_MESSAGE_SIZE) { + return error.TooLarge; + } + + if (buf.len < message_len) { + return null; + } + + // prefix + length_of_len + mask + const header_len = 2 + length_of_len + 4; + + const payload = buf[header_len..message_len]; + mask(buf[header_len - 4 .. header_len], payload); + + // whatever happens after this, we know where the next message starts + self.pos += message_len; + + const fin = byte1 & 128 == 128; + + if (is_continuation) { + const fragments = &(self.fragments orelse return error.InvalidContinuation); + if (fragments.message.items.len + message_len > MAX_MESSAGE_SIZE) { + return error.TooLarge; + } + + try fragments.message.appendSlice(self.allocator, payload); + + if (fin == false) { + // maybe we have more parts of the message waiting + continue :LOOP; + } + + // this continuation is done! + return .{ + .type = fragments.type, + .data = fragments.message.items, + .cleanup_fragment = true, + }; + } + + const can_be_fragmented = message_type == .text or message_type == .binary; + if (self.fragments != null and can_be_fragmented) { + // if this isn't a continuation, then we can't have fragements + return error.NestedFragementation; + } + + if (fin == false) { + if (can_be_fragmented == false) { + return error.InvalidContinuation; + } + + // not continuation, and not fin. It has to be the first message + // in a fragemented message. + var fragments = Fragments{ .message = .{}, .type = message_type }; + try fragments.message.appendSlice(self.allocator, payload); + self.fragments = fragments; + continue :LOOP; + } + + return .{ + .data = payload, + .type = message_type, + .cleanup_fragment = false, + }; } - - const fin = byte1 & 128 == 128; - if (!fin) { - return error.ContinuationNotSupported; - } - - if (buf[1] & 128 != 128) { - // client -> server messages _must_ be masked - return error.NotMasked; - } - - // var is_continuation = false; - var message_type: Message.Type = undefined; - switch (byte1 & 15) { - // 0 => is_continuation = true, - 0 => return error.ContinuationNotSupported, - 1 => message_type = .text, - 2 => message_type = .binary, - 8 => message_type = .close, - 9 => message_type = .ping, - 10 => message_type = .pong, - else => return error.InvalidMessageType, - } - - if (message_len > MAX_MESSAGE_SIZE) { - return error.TooLarge; - } - - if (buf.len < message_len) { - return null; - } - - // prefix + length_of_len + mask - const header_len = 2 + length_of_len + 4; - - const payload = buf[header_len..message_len]; - mask(buf[header_len - 4 .. header_len], payload); - - self.pos += message_len; - - return .{ - .type = message_type, - .data = payload, - }; } - fn extractLengths(buf: []const u8) ?struct{usize, usize} { + fn extractLengths(buf: []const u8) ?struct { usize, usize } { if (buf.len < 2) { return null; } @@ -895,7 +955,7 @@ const Reader = struct { else => buf[1] & 127, } + length_of_len + 2 + 4; // +2 for header prefix, +4 for mask; - return .{length_of_len, message_len}; + return .{ length_of_len, message_len }; } // This is called after we've processed complete websocket messages (this @@ -952,9 +1012,15 @@ const Reader = struct { } }; +const Fragments = struct { + type: Message.Type, + message: std.ArrayListUnmanaged(u8), +}; + const Message = struct { type: Type, data: []const u8, + cleanup_fragment: bool, const Type = enum { text, @@ -1135,7 +1201,6 @@ test "Client: http invalid request" { "Request too large", "GET /over/9000 HTTP/1.1\r\n" ++ "Header: " ++ ("a" ** 2050) ++ "\r\n\r\n", ); - } test "Client: http invalid handshake" { @@ -1260,25 +1325,12 @@ test "Client: write websocket message" { } test "Client: read invalid websocket message" { + // 131 = 128 (fin) | 3 where 3 isn't a valid type try assertWebSocketError( error.InvalidMessageType, 1002, "", - &.{ 131, 128, 'm', 'a', 's', 'k' }, // 128 (fin) | 3 where 3 isn't a valid type - ); - - try assertWebSocketError( - error.ContinuationNotSupported, - 1003, - "", - &.{ 128, 128, 'm', 'a', 's', 'k' }, // 128 (fin) | 0 where 0 is a continuation frame - ); - - try assertWebSocketError( - error.ContinuationNotSupported, - 1003, - "", - &.{ 1, 128, 'm', 'a', 's', 'k' }, // 0 (non-fin) | 1 non-fin (contination) not supported + &.{ 131, 128, 'm', 'a', 's', 'k' }, ); for ([_]u8{ 16, 32, 64 }) |rsv| { @@ -1299,18 +1351,84 @@ test "Client: read invalid websocket message" { ); } + // client->server messages must be masked try assertWebSocketError( error.NotMasked, 1002, "", - &.{ 129, 1, 'a' }, // client->server messages must be masked + &.{ 129, 1, 'a' }, ); + // length of message is 0000 0401, i.e: 1024 * 256 + 1 try assertWebSocketError( error.TooLarge, 1009, "", - &.{ 129, 255, 0, 0, 0, 0, 0, 4, 0, 1, 'm', 'a', 's', 'k' }, // 1024 * 256 + 1 + &.{ 129, 255, 0, 0, 0, 0, 0, 4, 0, 1, 'm', 'a', 's', 'k' }, + ); + + // continuation type message must come after a normal message + // even when not a fin frame + try assertWebSocketError( + error.InvalidContinuation, + 1002, + "", + &.{ 0, 129, 'm', 'a', 's', 'k', 'd' }, + ); + + // continuation type message must come after a normal message + // even as a fin frame + try assertWebSocketError( + error.InvalidContinuation, + 1002, + "", + &.{ 128, 129, 'm', 'a', 's', 'k', 'd' }, + ); + + // text (non-fin) - text (non-fin) + try assertWebSocketError( + error.NestedFragementation, + 1002, + "", + &.{ 1, 129, 'm', 'a', 's', 'k', 'd', 1, 128, 'k', 's', 'a', 'm' }, + ); + + // text (non-fin) - text (fin) should always been continuation after non-fin + try assertWebSocketError( + error.NestedFragementation, + 1002, + "", + &.{ 1, 129, 'm', 'a', 's', 'k', 'd', 129, 128, 'k', 's', 'a', 'm' }, + ); + + // close must be fin + try assertWebSocketError( + error.InvalidContinuation, + 1002, + "", + &.{ + 8, 129, 'm', 'a', 's', 'k', 'd', + }, + ); + + // ping must be fin + try assertWebSocketError( + error.InvalidContinuation, + 1002, + "", + &.{ + 9, 129, 'm', 'a', 's', 'k', 'd', + }, + ); + + // pong must be fin + try assertWebSocketError( + error.InvalidContinuation, + 1002, + "", + &.{ + 10, 129, 'm', 'a', 's', 'k', 'd', + }, ); } @@ -1389,6 +1507,21 @@ test "Client: fuzz" { [_]u8{ 129, 254, 2, 175, 1, 2, 3, 4 } ++ "A" ** 687, ); + // non-fin text message + try websocket_messages.appendSlice(allocator, &.{ 1, 130, 0, 0, 0, 0, 1, 2 }); + + // continuation + try websocket_messages.appendSlice(allocator, &.{ 0, 131, 0, 0, 0, 0, 3, 4, 5 }); + + // pong happening in fragement + try websocket_messages.appendSlice(allocator, &.{ 138, 128, 0, 0, 0, 0 }); + + // more continuation + try websocket_messages.appendSlice(allocator, &.{ 0, 130, 0, 0, 0, 0, 6, 7 }); + + // fin + try websocket_messages.appendSlice(allocator, &.{ 128, 133, 0, 0, 0, 0, 8, 9, 10, 11, 12 }); + // close try websocket_messages.appendSlice( allocator, @@ -1455,7 +1588,7 @@ test "Client: fuzz" { ms.sent.items[4], ); - try testing.expectEqual(2, ms.cdp.items.len); + try testing.expectEqual(3, ms.cdp.items.len); try testing.expectEqualSlices( u8, &.{ 1, 2, 3, 4, 5, 6, 7, 8, 9, 10 }, @@ -1468,6 +1601,12 @@ test "Client: fuzz" { ms.cdp.items[1], ); + try testing.expectEqualSlices( + u8, + &.{ 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12 }, + ms.cdp.items[2], + ); + try testing.expectEqual(true, ms.closed); } }