diff --git a/src/server.zig b/src/server.zig index 7ed8dc53..13a585cf 100644 --- a/src/server.zig +++ b/src/server.zig @@ -52,6 +52,7 @@ pub const Client = ClientT(*Server, CDP); const Server = struct { allocator: Allocator, loop: *jsruntime.Loop, + current_client_id: usize = 0, // internal fields listener: posix.socket_t, @@ -64,10 +65,9 @@ const Server = struct { // a memory poor for our Clietns client_pool: std.heap.MemoryPool(Client), - timeout_completion_pool: std.heap.MemoryPool(Completion), + completion_state_pool: std.heap.MemoryPool(CompletionState), // I/O fields - conn_completion: Completion, close_completion: Completion, accept_completion: Completion, @@ -77,7 +77,7 @@ const Server = struct { fn deinit(self: *Server) void { self.send_pool.deinit(); self.client_pool.deinit(); - self.timeout_completion_pool.deinit(); + self.completion_state_pool.deinit(); self.allocator.free(self.json_version_response); } @@ -99,40 +99,37 @@ const Server = struct { ) void { std.debug.assert(self.client == null); std.debug.assert(completion == &self.accept_completion); - - const socket = result catch |err| { + self.doCallbackAccept(result) catch |err| { log.err("accept error: {any}", .{err}); self.queueAccept(); - return; }; + } - const client = self.client_pool.create() catch |err| { - log.err("failed to create client: {any}", .{err}); - posix.close(socket); - return; - }; + fn doCallbackAccept( + self: *Server, + result: AcceptError!posix.socket_t, + ) !void { + const socket = try result; + const client = try self.client_pool.create(); errdefer self.client_pool.destroy(client); + self.current_client_id += 1; client.* = Client.init(socket, self); self.client = client; log.info("client connected", .{}); - self.queueRead(); - self.queueTimeout(); + try self.queueRead(); + try self.queueTimeout(); } - fn queueTimeout(self: *Server) void { - const completion = self.timeout_completion_pool.create() catch |err| { - log.err("failed to create timeout completion: {any}", .{err}); - return; - }; - + fn queueTimeout(self: *Server) !void { + const cs = try self.createCompletionState(); self.loop.io.timeout( *Server, self, callbackTimeout, - completion, + &cs.completion, TimeoutCheck, ); } @@ -142,7 +139,16 @@ const Server = struct { completion: *Completion, result: TimeoutError!void, ) void { - self.timeout_completion_pool.destroy(completion); + const cs: *CompletionState = @alignCast( + @fieldParentPtr("completion", completion), + ); + defer self.completion_state_pool.destroy(cs); + + if (cs.client_id != self.current_client_id) { + // completion for a previously-connected client + return; + } + const client = self.client orelse return; if (result) |_| { @@ -160,20 +166,23 @@ const Server = struct { // very unlikely IO timeout error. // AKA: we don't requeue this if the connection timed out and we // closed the connection.s - self.queueTimeout(); + self.queueTimeout() catch |err| { + log.err("queueTimeout error: {any}", .{err}); + }; } - fn queueRead(self: *Server) void { - if (self.client) |client| { - self.loop.io.recv( - *Server, - self, - callbackRead, - &self.conn_completion, - client.socket, - client.readBuf(), - ); - } + fn queueRead(self: *Server) !void { + var client = self.client orelse return; + + const cs = try self.createCompletionState(); + self.loop.io.recv( + *Server, + self, + callbackRead, + &cs.completion, + client.socket, + client.readBuf(), + ); } fn callbackRead( @@ -181,7 +190,15 @@ const Server = struct { completion: *Completion, result: RecvError!usize, ) void { - std.debug.assert(completion == &self.conn_completion); + const cs: *CompletionState = @alignCast( + @fieldParentPtr("completion", completion), + ); + defer self.completion_state_pool.destroy(cs); + + if (cs.client_id != self.current_client_id) { + // completion for a previously-connected client + return; + } var client = self.client orelse return; @@ -205,7 +222,10 @@ const Server = struct { // if more == false, the client is disconnecting if (more) { - self.queueRead(); + self.queueRead() catch |err| { + log.err("queueRead error: {any}", .{err}); + client.close(null); + }; } } @@ -218,12 +238,15 @@ const Server = struct { const sd = try self.send_pool.create(); errdefer self.send_pool.destroy(sd); + const cs = try self.createCompletionState(); + errdefer self.completion_state_pool.destroy(cs); + sd.* = .{ .unsent = data, .server = self, .socket = socket, - .completion = undefined, .arena = arena, + .completion_state = cs, }; sd.queueSend(); } @@ -246,6 +269,18 @@ const Server = struct { std.debug.assert(completion == &self.close_completion); self.queueAccept(); } + + fn createCompletionState(self: *Server) !*CompletionState { + var cs = try self.completion_state_pool.create(); + cs.client_id = self.current_client_id; + cs.completion = undefined; + return cs; + } +}; + +const CompletionState = struct { + client_id: usize, + completion: Completion, }; // I/O Send @@ -259,17 +294,19 @@ const Send = struct { // Any unsent data we have. unsent: []const u8, server: *Server, - completion: Completion, socket: posix.socket_t, + completion_state: *CompletionState, // If we need to free anything when we're done arena: ?ArenaAllocator, fn deinit(self: *Send) void { - var server = self.server; if (self.arena) |arena| { arena.deinit(); } + + var server = self.server; + server.completion_state_pool.destroy(self.completion_state); server.send_pool.destroy(self); } @@ -278,16 +315,25 @@ const Send = struct { // Any unsent data we have. *Send, self, sendCallback, - &self.completion, + &self.completion_state.completion, self.socket, self.unsent, ); } fn sendCallback(self: *Send, _: *Completion, result: SendError!usize) void { + const server = self.server; + const cs = self.completion_state; + + if (cs.client_id != server.current_client_id) { + // completion for a previously-connected client + self.deinit(); + return; + } + const sent = result catch |err| { log.info("send error: {any}", .{err}); - if (self.server.client) |client| { + if (server.client) |client| { client.close(null); } self.deinit(); @@ -1011,13 +1057,12 @@ pub fn run( .timeout = timeout, .listener = listener, .allocator = allocator, - .conn_completion = undefined, .close_completion = undefined, .accept_completion = undefined, .json_version_response = json_version_response, .send_pool = std.heap.MemoryPool(Send).init(allocator), .client_pool = std.heap.MemoryPool(Client).init(allocator), - .timeout_completion_pool = std.heap.MemoryPool(Completion).init(allocator), + .completion_state_pool = std.heap.MemoryPool(CompletionState).init(allocator), }; defer server.deinit();