diff --git a/src/browser/page.zig b/src/browser/page.zig index a9a60dc0..e8f29110 100644 --- a/src/browser/page.zig +++ b/src/browser/page.zig @@ -491,7 +491,7 @@ pub const Page = struct { return arr.items; } - fn newHTTPRequest(self: *const Page, method: http.Request.Method, url: *const URL, opts: storage.cookie.LookupOpts) !http.Request { + fn newHTTPRequest(self: *const Page, method: http.Request.Method, url: *const URL, opts: storage.cookie.LookupOpts) !*http.Request { // Don't use the state's request_factory here, since requests made by the // page (i.e. to load ) should not generate notifications. var request = try self.session.browser.http_client.request(method, &url.uri); diff --git a/src/browser/xhr/xhr.zig b/src/browser/xhr/xhr.zig index 594f1a72..04cecb72 100644 --- a/src/browser/xhr/xhr.zig +++ b/src/browser/xhr/xhr.zig @@ -79,7 +79,7 @@ const XMLHttpRequestBodyInit = union(enum) { pub const XMLHttpRequest = struct { proto: XMLHttpRequestEventTarget = XMLHttpRequestEventTarget{}, arena: Allocator, - request: ?http.Request = null, + request: ?*http.Request = null, method: http.Request.Method, state: State, @@ -252,6 +252,13 @@ pub const XMLHttpRequest = struct { }; } + pub fn destructor(self: *XMLHttpRequest) void { + if (self.request) |req| { + req.abort(); + self.request = null; + } + } + pub fn reset(self: *XMLHttpRequest) void { self.url = null; @@ -417,7 +424,7 @@ pub const XMLHttpRequest = struct { self.send_flag = true; self.request = try page.request_factory.create(self.method, &self.url.?.uri); - var request = &self.request.?; + var request = self.request.?; errdefer request.deinit(); for (self.headers.list.items) |hdr| { @@ -452,6 +459,9 @@ pub const XMLHttpRequest = struct { pub fn onHttpResponse(self: *XMLHttpRequest, progress_: anyerror!http.Progress) !void { const progress = progress_ catch |err| { + // The request has been closed internally by the client, it isn't safe + // for us to keep it around. 
+ self.request = null; self.onErr(err); return err; }; @@ -510,6 +520,10 @@ pub const XMLHttpRequest = struct { .status = progress.header.status, }); + // Now that the request is done, the http/client will free the request + // object. It isn't safe to keep it around. + self.request = null; + self.state = .done; self.send_flag = false; self.dispatchEvt("readystatechange"); @@ -532,6 +546,7 @@ pub const XMLHttpRequest = struct { pub fn _abort(self: *XMLHttpRequest) void { self.onErr(DOMError.Abort); + self.destructor(); } pub fn get_responseType(self: *XMLHttpRequest) []const u8 { diff --git a/src/http/client.zig b/src/http/client.zig index 71046c09..4aba5fb2 100644 --- a/src/http/client.zig +++ b/src/http/client.zig @@ -51,6 +51,7 @@ pub const Client = struct { root_ca: tls.config.CertBundle, tls_verify_host: bool = true, connection_manager: ConnectionManager, + request_pool: std.heap.MemoryPool(Request), const Opts = struct { tls_verify_host: bool = true, @@ -76,6 +77,7 @@ pub const Client = struct { .http_proxy = opts.http_proxy, .tls_verify_host = opts.tls_verify_host, .connection_manager = connection_manager, + .request_pool = std.heap.MemoryPool(Request).init(allocator), }; } @@ -86,9 +88,10 @@ pub const Client = struct { } self.state_pool.deinit(allocator); self.connection_manager.deinit(); + self.request_pool.deinit(); } - pub fn request(self: *Client, method: Request.Method, uri: *const Uri) !Request { + pub fn request(self: *Client, method: Request.Method, uri: *const Uri) !*Request { const state = self.state_pool.acquire(); errdefer { @@ -96,7 +99,18 @@ self.state_pool.release(state); } - return Request.init(self, state, method, uri); + // We need the request on the heap, because it can have a longer lifetime + // than the code making the request. That sounds odd, but consider the + // case of an XHR request: it can still be inflight (e.g. waiting for + // the response) when the page gets unloaded. 
Once the page is unloaded + // the page arena is reset and the XHR instance becomes invalid. If the + // XHR instance owns the `Request`, we'd crash once an async callback + // executes. + const req = try self.request_pool.create(); + errdefer self.request_pool.destroy(req); + + req.* = try Request.init(self, state, method, uri); + return req; } pub fn requestFactory(self: *Client, notification: ?*Notification) RequestFactory { @@ -112,7 +126,7 @@ pub const RequestFactory = struct { client: *Client, notification: ?*Notification, - pub fn create(self: RequestFactory, method: Request.Method, uri: *const Uri) !Request { + pub fn create(self: RequestFactory, method: Request.Method, uri: *const Uri) !*Request { var req = try self.client.request(method, uri); req.notification = self.notification; return req; @@ -244,6 +258,17 @@ pub const Request = struct { // The notifier that we emit request notifications to, if any. notification: ?*Notification, + // Aborting an async request is complicated, as we need to wait until all + // in-flight IO events are completed. Our AsyncHandler is a generic type + // that we don't have the necessary type information for in the Request, + // so we need to rely on anyopaque. 
+ _aborter: ?Aborter, + + const Aborter = struct { + ctx: *anyopaque, + func: *const fn (*anyopaque) void, + }; + pub const Method = enum { GET, PUT, @@ -282,6 +307,7 @@ pub const Request = struct { ._request_host = decomposed.request_host, ._state = state, ._client = client, + ._aborter = null, ._connection = null, ._keepalive = false, ._redirect_count = 0, @@ -297,6 +323,15 @@ pub const Request = struct { self.releaseConnection(); _ = self._state.reset(); self._client.state_pool.release(self._state); + self._client.request_pool.destroy(self); + } + + pub fn abort(self: *Request) void { + const aborter = self._aborter orelse { + self.deinit(); + return; + }; + aborter.func(aborter.ctx); } const DecomposedURL = struct { @@ -544,6 +579,11 @@ pub const Request = struct { return async_handler.conn.connected(); } + self._aborter = .{ + .ctx = async_handler, + .func = AsyncHandlerT.abort, + }; + return loop.connect( AsyncHandlerT, async_handler, @@ -732,13 +772,6 @@ fn AsyncHandler(comptime H: type, comptime L: type) type { // that we have valid, but unprocessed, data up to. read_pos: usize = 0, - // Depending on which version of TLS, there are different places during - // the handshake that we want to start receiving from. We can't have - // overlapping receives (works fine on MacOS (kqueue) but not Linux ( - // io_uring)). Using this boolean as a guard, to make sure we only have - // 1 in-flight receive is easier than trying to understand TLS. - is_receiving: bool = false, - // need a separate read and write buf because, with TLS, messages are // not strictly req->resp. write_buf: []u8, @@ -775,6 +808,13 @@ fn AsyncHandler(comptime H: type, comptime L: type) type { // gzipped responses *cough*) full_body: ?std.ArrayListUnmanaged(u8) = null, + // Shutting down an async request requires that we wait for all inflight + // IO to be completed. 
So we need to track what inflight requests we + // have and whether or not we're shutting down + shutdown: bool = false, + pending_write: bool = false, + pending_receive: bool = false, + const Self = @This(); const SendQueue = std.DoublyLinkedList([]const u8); @@ -794,6 +834,12 @@ fn AsyncHandler(comptime H: type, comptime L: type) type { self.request.deinit(); } + fn abort(ctx: *anyopaque) void { + var self: *Self = @alignCast(@ptrCast(ctx)); + self.shutdown = true; + self.maybeShutdown(); + } + fn connected(self: *Self, _: *IO.Completion, result: IO.ConnectError!void) void { result catch |err| return self.handleError("Connection failed", err); self.conn.connected() catch |err| { @@ -815,6 +861,7 @@ fn AsyncHandler(comptime H: type, comptime L: type) type { return; } + self.pending_write = true; self.loop.send( Self, self, @@ -828,6 +875,10 @@ fn AsyncHandler(comptime H: type, comptime L: type) type { } fn sent(self: *Self, _: *IO.Completion, n_: IO.SendError!usize) void { + self.pending_write = false; + if (self.shutdown) { + return self.maybeShutdown(); + } const n = n_ catch |err| { return self.handleError("Write error", err); }; @@ -845,6 +896,7 @@ fn AsyncHandler(comptime H: type, comptime L: type) type { } if (next) |next_| { + self.pending_write = true; // we still have data to send self.loop.send( Self, @@ -869,11 +921,11 @@ fn AsyncHandler(comptime H: type, comptime L: type) type { // while handshaking and potentially while sending data. So we're always // receiving. 
fn receive(self: *Self) void { - if (self.is_receiving) { + if (self.pending_receive) { return; } - self.is_receiving = true; + self.pending_receive = true; self.loop.recv( Self, self, @@ -887,7 +939,11 @@ fn AsyncHandler(comptime H: type, comptime L: type) type { } fn received(self: *Self, _: *IO.Completion, n_: IO.RecvError!usize) void { - self.is_receiving = false; + self.pending_receive = false; + if (self.shutdown) { + return self.maybeShutdown(); + } + const n = n_ catch |err| { return self.handleError("Read error", err); }; @@ -926,6 +982,17 @@ fn AsyncHandler(comptime H: type, comptime L: type) type { } } + fn maybeShutdown(self: *Self) void { + std.debug.assert(self.shutdown); + if (self.pending_write or self.pending_receive) { + return; + } + + // Who knows what state we're in, safer to not try to re-use the connection + self.request._keepalive = false; + self.request.deinit(); + } + // If our socket came from the connection pool, it's possible that we're // failing because it's since timed out. If fn maybeRetryRequest(self: *Self) bool {