diff --git a/src/server.zig b/src/server.zig index 1ef5d01e..d5bc4856 100644 --- a/src/server.zig +++ b/src/server.zig @@ -64,7 +64,9 @@ const log = std.log.scoped(.server); const MAX_HTTP_REQUEST_SIZE = 2048; -// max message size, +14 for max websocket payload overhead +// max message size +// +14 for max websocket payload overhead +// +140 for the max control packet that might be interleaved in a message const MAX_MESSAGE_SIZE = 256 * 1024 + 14; // For now, cdp does @import("server.zig").Ctx. Could change cdp to use "Server" @@ -466,15 +468,10 @@ fn Client(comptime S: type) type { // should eventually be upgraded to a websocket connections mode: Mode, server: S, + reader: Reader, socket: posix.socket_t, last_active: std.time.Instant, - // the start of the message in our read_buf - read_pos: usize = 0, - // up to where do we have data in our read_buf - read_len: usize = 0, - read_buf: [MAX_MESSAGE_SIZE]u8 = undefined, - const Mode = enum { http, websocket, @@ -485,6 +482,7 @@ fn Client(comptime S: type) type { fn init(socket: posix.socket_t, server: S) Self { return .{ .mode = .http, + .reader = .{}, .socket = socket, .server = server, .last_active = now(), @@ -501,31 +499,25 @@ fn Client(comptime S: type) type { } fn readBuf(self: *Self) []u8 { - // We might have read a partial http or websocket message. - // Subsequent reads must read from where we left off. - std.debug.assert(self.read_pos < self.read_buf.len); - return self.read_buf[self.read_len..]; + return self.reader.readBuf(); } fn processData(self: *Self, len: usize) !bool { - const end = self.read_len + len; - std.debug.assert(end >= self.read_pos); - self.last_active = now(); - const data = self.read_buf[self.read_pos..end]; + self.reader.len += len; switch (self.mode) { .http => { - try self.processHTTPRequest(data); + try self.processHTTPRequest(); return true; }, - .websocket => return self.processWebsocketMessage(data), + .websocket => return self.processWebsocketMessage(), } } - fn processHTTPRequest(self: *Self, request: []u8) HTTPError!void { - // We should never get pipelined HTTP requests - std.debug.assert(self.read_pos == 0); + fn processHTTPRequest(self: *Self) HTTPError!void { + std.debug.assert(self.reader.pos == 0); + const request = self.reader.buf[0..self.reader.len]; errdefer self.server.queueClose(self.socket); @@ -537,7 +529,6 @@ fn Client(comptime S: type) type { // we're only expecting [body-less] GET requests. if (std.mem.endsWith(u8, request, "\r\n\r\n") == false) { // we need more data, put any more data here - self.read_len = request.len; return; } @@ -559,7 +550,7 @@ fn Client(comptime S: type) type { }; // the next incoming data can go to the front of our buffer - self.read_len = 0; + self.reader.len = 0; } fn handleHTTPRequest(self: *Self, request: []u8) !void { @@ -683,10 +674,10 @@ fn Client(comptime S: type) type { return self.send(response, true); } - fn processWebsocketMessage(self: *Self, data: []u8) !bool { + fn processWebsocketMessage(self: *Self) !bool { errdefer self.server.queueClose(self.socket); - var reader = Reader{ .data = data }; + var reader = &self.reader; while (true) { const msg = reader.next() catch |err| { switch (err) { @@ -711,18 +702,9 @@ fn Client(comptime S: type) type { } } - const incomplete = reader.data; - self.read_len = incomplete.len; - if (incomplete.len > 0) { - // we have part of the data for the next message - - // can't use @memset because incomplete is a slice of read_buf, - // so they could overlap - - // TODO: this can be skipped if we know that the next message will - // fit into whatever reamining space we have. - std.mem.copyForwards(u8, self.read_buf[0..incomplete.len], incomplete); - } + // We might have read part of the next message. Our reader potentially + // has to move data around in its buffer to make space. + reader.compact(); return true; } @@ -814,23 +796,52 @@ fn Client(comptime S: type) type { // can return zero or more Messages. When next returns null, any incomplete // message will remain in reader.data const Reader = struct { - data: []u8, + // position in buf of the start of the next message + pos: usize = 0, + + // position in buf up until where we have valid data + // (any new reads must be placed after this) + len: usize = 0, + + // we add 140 to allow 1 control message (ping/pong/close) to be + // fragmented into a normal message. + buf: [MAX_MESSAGE_SIZE + 140]u8 = undefined, + + fn readBuf(self: *Reader) []u8 { + // We might have read a partial http or websocket message. + // Subsequent reads must read from where we left off. + return self.buf[self.len..]; + } fn next(self: *Reader) !?Message { - var data = self.data; - if (data.len < 2) { - return null; - } + var buf = self.buf[self.pos..self.len]; - const byte1 = data[0]; + const length_of_len, const message_len = extractLengths(buf) orelse { + // we don't have enough bytes + return null; + }; + + const byte1 = buf[0]; if (byte1 & 112 != 0) { return error.ReservedFlags; } + const fin = byte1 & 128 == 128; + if (!fin) { + return error.ContinuationNotSupported; + } + + if (buf[1] & 128 != 128) { + // client -> server messages _must_ be masked + return error.NotMasked; + } + + // var is_continuation = false; var message_type: Message.Type = undefined; switch (byte1 & 15) { - 0 => return error.ContinuationNotSupported, // TODO?? + // 0 => is_continuation = true, + 0 => return error.ContinuationNotSupported, 1 => message_type = .text, 2 => message_type = .binary, 8 => message_type = .close, @@ -839,54 +850,105 @@ const Reader = struct { else => return error.InvalidMessageType, } - if (byte1 & 128 != 128) { - // TODO?? - return error.ContinuationNotSupported; - } - - const byte2 = data[1]; - if (byte2 & 128 != 128) { - // client -> server messages _must_ be masked - return error.NotMasked; - } - - const length_of_len: usize = switch (byte2 & 127) { - 126 => 2, - 127 => 8, - else => 0, - }; - - if (data.len < length_of_len + 2) { - // we definitely don't have enough data yet - return null; - } - - const message_len = switch (length_of_len) { - 2 => @as(u16, @intCast(data[3])) | @as(u16, @intCast(data[2])) << 8, - 8 => @as(u64, @intCast(data[9])) | @as(u64, @intCast(data[8])) << 8 | @as(u64, @intCast(data[7])) << 16 | @as(u64, @intCast(data[6])) << 24 | @as(u64, @intCast(data[5])) << 32 | @as(u64, @intCast(data[4])) << 40 | @as(u64, @intCast(data[3])) << 48 | @as(u64, @intCast(data[2])) << 56, - else => data[1] & 127, - } + length_of_len + 2 + 4; // +2 for header prefix, +4 for mask - if (message_len > MAX_MESSAGE_SIZE) { return error.TooLarge; } - if (data.len < message_len) { + if (buf.len < message_len) { return null; } // prefix + length_of_len + mask const header_len = 2 + length_of_len + 4; - const payload = data[header_len..message_len]; - mask(data[header_len - 4 .. header_len], payload); + const payload = buf[header_len..message_len]; + mask(buf[header_len - 4 .. header_len], payload); + + self.pos += message_len; - self.data = data[message_len..]; return .{ .type = message_type, .data = payload, }; } + + fn extractLengths(buf: []const u8) ?struct{usize, usize} { + if (buf.len < 2) { + return null; + } + + const length_of_len: usize = switch (buf[1] & 127) { + 126 => 2, + 127 => 8, + else => 0, + }; + + if (buf.len < length_of_len + 2) { + // we definitely don't have enough buf yet + return null; + } + + const message_len = switch (length_of_len) { + 2 => @as(u16, @intCast(buf[3])) | @as(u16, @intCast(buf[2])) << 8, + 8 => @as(u64, @intCast(buf[9])) | @as(u64, @intCast(buf[8])) << 8 | @as(u64, @intCast(buf[7])) << 16 | @as(u64, @intCast(buf[6])) << 24 | @as(u64, @intCast(buf[5])) << 32 | @as(u64, @intCast(buf[4])) << 40 | @as(u64, @intCast(buf[3])) << 48 | @as(u64, @intCast(buf[2])) << 56, + else => buf[1] & 127, + } + length_of_len + 2 + 4; // +2 for header prefix, +4 for mask; + + return .{length_of_len, message_len}; + } + + // This is called after we've processed complete websocket messages (this + // only applies to websocket messages). + // There are three cases: + // 1 - We don't have any incomplete data (for a subsequent message) in buf. + // This is the easier to handle, we can set pos & len to 0. + // 2 - We have part of the next message, but we know it'll fit in the + // remaining buf. We don't need to do anything + // 3 - We have part of the next message, but either it won't fight into the + // remaining buffer, or we don't know (because we don't have enough + // of the header to tell the length). We need to "compact" the buffer + fn compact(self: *Reader) void { + const pos = self.pos; + const len = self.len; + + std.debug.assert(pos <= len); + + // how many (if any) partial bytes do we have + const partial_bytes = len - pos; + + if (partial_bytes == 0) { + // We have no partial bytes. Setting these to 0 ensures that we + // get the best utilization of our buffer + self.pos = 0; + self.len = 0; + return; + } + + const partial = self.buf[pos..len]; + + // If we have enough bytes of the next message to tell its length + // we'll be able to figure out whether we need to do anything or not. + if (extractLengths(partial)) |length_meta| { + const next_message_len = length_meta.@"1"; + // if this isn't true, then we have a full message and it + // should have been processed. + std.debug.assert(next_message_len > partial_bytes); + + const missing_bytes = next_message_len - partial_bytes; + + const free_space = self.buf.len - len; + if (missing_bytes < free_space) { + // we have enough space in our buffer, as is, + return; + } + } + + // We're here because we either don't have enough bytes of the next + // message, or we know that it won't fit in our buffer as-is. + std.mem.copyForwards(u8, &self.buf, partial); + self.pos = 0; + self.len = partial_bytes; + } }; const Message = struct { @@ -1138,7 +1200,7 @@ test "Client: http valid handshake" { "sec-websocket-key: this is my key\r\n" ++ "Custom: Header-Value\r\n\r\n"; - @memcpy(client.read_buf[0..request.len], request); + @memcpy(client.reader.buf[0..request.len], request); try testing.expectEqual(true, try client.processData(request.len)); try testing.expectEqual(.websocket, client.mode); @@ -1159,7 +1221,7 @@ test "Client: http get json version" { const request = "GET /json/version HTTP/1.1\r\n\r\n"; - @memcpy(client.read_buf[0..request.len], request); + @memcpy(client.reader.buf[0..request.len], request); try testing.expectEqual(true, try client.processData(request.len)); try testing.expectEqual(.http, client.mode); @@ -1193,21 +1255,21 @@ test "Client: read invalid websocket message" { error.InvalidMessageType, 1002, "", - &.{ 131, 1 }, // 128 (fin) | 3 where 3 isn't a valid type + &.{ 131, 128, 'm', 'a', 's', 'k' }, // 128 (fin) | 3 where 3 isn't a valid type ); try assertWebSocketError( error.ContinuationNotSupported, 1003, "", - &.{ 128, 1 }, // 128 (fin) | 0 where 0 is a continuation frame + &.{ 128, 128, 'm', 'a', 's', 'k' }, // 128 (fin) | 0 where 0 is a continuation frame ); try assertWebSocketError( error.ContinuationNotSupported, 1003, "", - &.{ 1, 1 }, // 0 (non-fin) | 1 non-fin (contination) not supported + &.{ 1, 128, 'm', 'a', 's', 'k' }, // 0 (non-fin) | 1 non-fin (contination) not supported ); for ([_]u8{ 16, 32, 64 }) |rsv| { @@ -1216,7 +1278,7 @@ test "Client: read invalid websocket message" { error.ReservedFlags, 1002, "", - &.{ rsv, 0 }, + &.{ rsv, 128, 'm', 'a', 's', 'k' }, ); // as a bitmask @@ -1224,7 +1286,7 @@ test "Client: read invalid websocket message" { error.ReservedFlags, 1002, "", - &.{ rsv + 4, 0 }, + &.{ rsv + 4, 128, 'm', 'a', 's', 'k' }, ); } @@ -1232,14 +1294,14 @@ test "Client: read invalid websocket message" { error.NotMasked, 1002, "", - &.{ 129, 127 }, // client->server messages must be masked + &.{ 129, 1, 'a' }, // client->server messages must be masked ); try assertWebSocketError( error.TooLarge, 1009, "", - &.{ 129, 255, 0, 0, 0, 0, 0, 4, 0, 1 }, // 1024 * 256 + 1 + &.{ 129, 255, 0, 0, 0, 0, 0, 4, 0, 1, 'm', 'a', 's', 'k' }, // 1024 * 256 + 1 ); } @@ -1427,7 +1489,7 @@ fn assertHTTPError( defer ms.deinit(); var client = Client(*MockServer).init(0, &ms); - @memcpy(client.read_buf[0..input.len], input); + @memcpy(client.reader.buf[0..input.len], input); try testing.expectError(expected_error, client.processData(input.len)); const expected_response = std.fmt.comptimePrint( @@ -1451,7 +1513,7 @@ fn assertWebSocketError( var client = Client(*MockServer).init(0, &ms); client.mode = .websocket; // force websocket message processing - @memcpy(client.read_buf[0..input.len], input); + @memcpy(client.reader.buf[0..input.len], input); try testing.expectError(expected_error, client.processData(input.len)); try testing.expectEqual(1, ms.sent.items.len); @@ -1481,7 +1543,7 @@ fn assertWebSocketMessage( var client = Client(*MockServer).init(0, &ms); client.mode = .websocket; // force websocket message processing - @memcpy(client.read_buf[0..input.len], input); + @memcpy(client.reader.buf[0..input.len], input); const more = try client.processData(input.len); try testing.expectEqual(1, ms.sent.items.len);