Ensure completions are executed on the currently connected client

For the time being, given that we only allow 1 client at a time, I took a
shortcut to implement this. The server has an incrementing "current_client_id"
which is part of every completion. On completion callback, we just check if
its client_id is still equal to the server's current_client_id.
This commit is contained in:
Karl Seguin
2025-02-21 09:30:45 +08:00
parent 09505dba09
commit 756d6620cc

View File

@@ -52,6 +52,7 @@ pub const Client = ClientT(*Server, CDP);
const Server = struct { const Server = struct {
allocator: Allocator, allocator: Allocator,
loop: *jsruntime.Loop, loop: *jsruntime.Loop,
current_client_id: usize = 0,
// internal fields // internal fields
listener: posix.socket_t, listener: posix.socket_t,
@@ -64,10 +65,9 @@ const Server = struct {
// a memory poor for our Clietns // a memory poor for our Clietns
client_pool: std.heap.MemoryPool(Client), client_pool: std.heap.MemoryPool(Client),
timeout_completion_pool: std.heap.MemoryPool(Completion), completion_state_pool: std.heap.MemoryPool(CompletionState),
// I/O fields // I/O fields
conn_completion: Completion,
close_completion: Completion, close_completion: Completion,
accept_completion: Completion, accept_completion: Completion,
@@ -77,7 +77,7 @@ const Server = struct {
fn deinit(self: *Server) void { fn deinit(self: *Server) void {
self.send_pool.deinit(); self.send_pool.deinit();
self.client_pool.deinit(); self.client_pool.deinit();
self.timeout_completion_pool.deinit(); self.completion_state_pool.deinit();
self.allocator.free(self.json_version_response); self.allocator.free(self.json_version_response);
} }
@@ -99,40 +99,37 @@ const Server = struct {
) void { ) void {
std.debug.assert(self.client == null); std.debug.assert(self.client == null);
std.debug.assert(completion == &self.accept_completion); std.debug.assert(completion == &self.accept_completion);
self.doCallbackAccept(result) catch |err| {
const socket = result catch |err| {
log.err("accept error: {any}", .{err}); log.err("accept error: {any}", .{err});
self.queueAccept(); self.queueAccept();
return;
}; };
}
const client = self.client_pool.create() catch |err| { fn doCallbackAccept(
log.err("failed to create client: {any}", .{err}); self: *Server,
posix.close(socket); result: AcceptError!posix.socket_t,
return; ) !void {
}; const socket = try result;
const client = try self.client_pool.create();
errdefer self.client_pool.destroy(client); errdefer self.client_pool.destroy(client);
self.current_client_id += 1;
client.* = Client.init(socket, self); client.* = Client.init(socket, self);
self.client = client; self.client = client;
log.info("client connected", .{}); log.info("client connected", .{});
self.queueRead(); try self.queueRead();
self.queueTimeout(); try self.queueTimeout();
} }
fn queueTimeout(self: *Server) void { fn queueTimeout(self: *Server) !void {
const completion = self.timeout_completion_pool.create() catch |err| { const cs = try self.createCompletionState();
log.err("failed to create timeout completion: {any}", .{err});
return;
};
self.loop.io.timeout( self.loop.io.timeout(
*Server, *Server,
self, self,
callbackTimeout, callbackTimeout,
completion, &cs.completion,
TimeoutCheck, TimeoutCheck,
); );
} }
@@ -142,7 +139,16 @@ const Server = struct {
completion: *Completion, completion: *Completion,
result: TimeoutError!void, result: TimeoutError!void,
) void { ) void {
self.timeout_completion_pool.destroy(completion); const cs: *CompletionState = @alignCast(
@fieldParentPtr("completion", completion),
);
defer self.completion_state_pool.destroy(cs);
if (cs.client_id != self.current_client_id) {
// completion for a previously-connected client
return;
}
const client = self.client orelse return; const client = self.client orelse return;
if (result) |_| { if (result) |_| {
@@ -160,20 +166,23 @@ const Server = struct {
// very unlikely IO timeout error. // very unlikely IO timeout error.
// AKA: we don't requeue this if the connection timed out and we // AKA: we don't requeue this if the connection timed out and we
// closed the connection.s // closed the connection.s
self.queueTimeout(); self.queueTimeout() catch |err| {
log.err("queueTimeout error: {any}", .{err});
};
} }
fn queueRead(self: *Server) void { fn queueRead(self: *Server) !void {
if (self.client) |client| { var client = self.client orelse return;
self.loop.io.recv(
*Server, const cs = try self.createCompletionState();
self, self.loop.io.recv(
callbackRead, *Server,
&self.conn_completion, self,
client.socket, callbackRead,
client.readBuf(), &cs.completion,
); client.socket,
} client.readBuf(),
);
} }
fn callbackRead( fn callbackRead(
@@ -181,7 +190,15 @@ const Server = struct {
completion: *Completion, completion: *Completion,
result: RecvError!usize, result: RecvError!usize,
) void { ) void {
std.debug.assert(completion == &self.conn_completion); const cs: *CompletionState = @alignCast(
@fieldParentPtr("completion", completion),
);
defer self.completion_state_pool.destroy(cs);
if (cs.client_id != self.current_client_id) {
// completion for a previously-connected client
return;
}
var client = self.client orelse return; var client = self.client orelse return;
@@ -205,7 +222,10 @@ const Server = struct {
// if more == false, the client is disconnecting // if more == false, the client is disconnecting
if (more) { if (more) {
self.queueRead(); self.queueRead() catch |err| {
log.err("queueRead error: {any}", .{err});
client.close(null);
};
} }
} }
@@ -218,12 +238,15 @@ const Server = struct {
const sd = try self.send_pool.create(); const sd = try self.send_pool.create();
errdefer self.send_pool.destroy(sd); errdefer self.send_pool.destroy(sd);
const cs = try self.createCompletionState();
errdefer self.completion_state_pool.destroy(cs);
sd.* = .{ sd.* = .{
.unsent = data, .unsent = data,
.server = self, .server = self,
.socket = socket, .socket = socket,
.completion = undefined,
.arena = arena, .arena = arena,
.completion_state = cs,
}; };
sd.queueSend(); sd.queueSend();
} }
@@ -246,6 +269,18 @@ const Server = struct {
std.debug.assert(completion == &self.close_completion); std.debug.assert(completion == &self.close_completion);
self.queueAccept(); self.queueAccept();
} }
fn createCompletionState(self: *Server) !*CompletionState {
var cs = try self.completion_state_pool.create();
cs.client_id = self.current_client_id;
cs.completion = undefined;
return cs;
}
};
const CompletionState = struct {
client_id: usize,
completion: Completion,
}; };
// I/O Send // I/O Send
@@ -259,17 +294,19 @@ const Send = struct { // Any unsent data we have.
unsent: []const u8, unsent: []const u8,
server: *Server, server: *Server,
completion: Completion,
socket: posix.socket_t, socket: posix.socket_t,
completion_state: *CompletionState,
// If we need to free anything when we're done // If we need to free anything when we're done
arena: ?ArenaAllocator, arena: ?ArenaAllocator,
fn deinit(self: *Send) void { fn deinit(self: *Send) void {
var server = self.server;
if (self.arena) |arena| { if (self.arena) |arena| {
arena.deinit(); arena.deinit();
} }
var server = self.server;
server.completion_state_pool.destroy(self.completion_state);
server.send_pool.destroy(self); server.send_pool.destroy(self);
} }
@@ -278,16 +315,25 @@ const Send = struct { // Any unsent data we have.
*Send, *Send,
self, self,
sendCallback, sendCallback,
&self.completion, &self.completion_state.completion,
self.socket, self.socket,
self.unsent, self.unsent,
); );
} }
fn sendCallback(self: *Send, _: *Completion, result: SendError!usize) void { fn sendCallback(self: *Send, _: *Completion, result: SendError!usize) void {
const server = self.server;
const cs = self.completion_state;
if (cs.client_id != server.current_client_id) {
// completion for a previously-connected client
self.deinit();
return;
}
const sent = result catch |err| { const sent = result catch |err| {
log.info("send error: {any}", .{err}); log.info("send error: {any}", .{err});
if (self.server.client) |client| { if (server.client) |client| {
client.close(null); client.close(null);
} }
self.deinit(); self.deinit();
@@ -1011,13 +1057,12 @@ pub fn run(
.timeout = timeout, .timeout = timeout,
.listener = listener, .listener = listener,
.allocator = allocator, .allocator = allocator,
.conn_completion = undefined,
.close_completion = undefined, .close_completion = undefined,
.accept_completion = undefined, .accept_completion = undefined,
.json_version_response = json_version_response, .json_version_response = json_version_response,
.send_pool = std.heap.MemoryPool(Send).init(allocator), .send_pool = std.heap.MemoryPool(Send).init(allocator),
.client_pool = std.heap.MemoryPool(Client).init(allocator), .client_pool = std.heap.MemoryPool(Client).init(allocator),
.timeout_completion_pool = std.heap.MemoryPool(Completion).init(allocator), .completion_state_pool = std.heap.MemoryPool(CompletionState).init(allocator),
}; };
defer server.deinit(); defer server.deinit();