diff --git a/src/server.zig b/src/server.zig index 617be02b..39fd6907 100644 --- a/src/server.zig +++ b/src/server.zig @@ -36,7 +36,7 @@ const MAX_HTTP_REQUEST_SIZE = 4096; // max message size // +14 for max websocket payload overhead // +140 for the max control packet that might be interleaved in a message -const MAX_MESSAGE_SIZE = 512 * 1024 + 14; +const MAX_MESSAGE_SIZE = 512 * 1024 + 14 + 140; pub const Server = struct { app: *App, @@ -188,12 +188,15 @@ pub const Client = struct { // we expect the socket to come to us as nonblocking std.debug.assert(socket_flags & nonblocking == nonblocking); + var reader = try Reader(true).init(server.allocator); + errdefer reader.deinit(); + return .{ .socket = socket, .server = server, + .reader = reader, .mode = .{ .http = {} }, .socket_flags = socket_flags, - .reader = .{ .allocator = server.allocator }, .send_arena = ArenaAllocator.init(server.allocator), }; } @@ -537,14 +540,23 @@ fn Reader(comptime EXPECT_MASK: bool) type { // we add 140 to allow 1 control message (ping/pong/close) to be // fragmented into a normal message. - buf: [MAX_MESSAGE_SIZE + 140]u8 = undefined, + buf: []u8, fragments: ?Fragments = null, const Self = @This(); + fn init(allocator: Allocator) !Self { + const buf = try allocator.alloc(u8, 16 * 1024); + return .{ + .buf = buf, + .allocator = allocator, + }; + } + fn deinit(self: *Self) void { self.cleanup(); + self.allocator.free(self.buf); } fn cleanup(self: *Self) void { @@ -613,9 +625,14 @@ fn Reader(comptime EXPECT_MASK: bool) type { } } else if (message_len > MAX_MESSAGE_SIZE) { return error.TooLarge; - } - - if (buf.len < message_len) { + } else if (message_len > self.buf.len) { + const len = self.buf.len; + self.buf = try growBuffer(self.allocator, self.buf, message_len); + buf = self.buf[0..len]; + // we need more data + return null; + } else if (buf.len < message_len) { + // we need more data return null; } @@ -753,13 +770,32 @@ fn Reader(comptime EXPECT_MASK: bool) type { // We're here because we either don't have enough bytes of the next // message, or we know that it won't fit in our buffer as-is. - std.mem.copyForwards(u8, &self.buf, partial); + std.mem.copyForwards(u8, self.buf, partial); self.pos = 0; self.len = partial_bytes; } }; } +fn growBuffer(allocator: Allocator, buf: []u8, required_capacity: usize) ![]u8 { + // from std.ArrayList + var new_capacity = buf.len; + while (true) { + new_capacity +|= new_capacity / 2 + 8; + if (new_capacity >= required_capacity) break; + } + + log.debug(.app, "CDP buffer growth", .{ .from = buf.len, .to = new_capacity }); + + if (allocator.resize(buf, new_capacity)) { + return buf.ptr[0..new_capacity]; + } + const new_buffer = try allocator.alloc(u8, new_capacity); + @memcpy(new_buffer[0..buf.len], buf); + allocator.free(buf); + return new_buffer; +} + const Fragments = struct { type: Message.Type, message: std.ArrayListUnmanaged(u8), @@ -1037,8 +1073,8 @@ test "Client: read invalid websocket message" { ); } - // length of message is 0000 0401, i.e: 1024 * 512 + 1 - try assertWebSocketError(1009, &.{ 129, 255, 0, 0, 0, 0, 0, 8, 0, 1, 'm', 'a', 's', 'k' }); + // length of message is 0000 0810, i.e: 1024 * 512 + 265 + try assertWebSocketError(1009, &.{ 129, 255, 0, 0, 0, 0, 0, 8, 1, 0, 'm', 'a', 's', 'k' }); // continuation type message must come after a normal message // even when not a fin frame @@ -1260,7 +1296,10 @@ fn createTestClient() !TestClient { try posix.setsockopt(stream.handle, posix.SOL.SOCKET, posix.SO.SNDTIMEO, &timeout); return .{ .stream = stream, - .reader = .{ .allocator = testing.allocator }, + .reader = .{ + .allocator = testing.allocator, + .buf = try testing.allocator.alloc(u8, 1024 * 16), + }, }; }