From 2f362f2aa25695556578933fde05b8bd0fe7690d Mon Sep 17 00:00:00 2001 From: Karl Seguin Date: Wed, 19 Mar 2025 17:33:09 +0800 Subject: [PATCH] handle redirects on asynchronous calls --- build.zig.zon | 2 +- src/http/client.zig | 398 +++++++++++++++++++++++++++++++------------- src/xhr/xhr.zig | 2 +- 3 files changed, 280 insertions(+), 122 deletions(-) diff --git a/build.zig.zon b/build.zig.zon index 400c304e..7540c493 100644 --- a/build.zig.zon +++ b/build.zig.zon @@ -4,6 +4,6 @@ .version = "0.0.0", .fingerprint = 0xda130f3af836cea0, .dependencies = .{ - .tls = .{ .url = "https://github.com/ianic/tls.zig/archive/6c36f8c39aeefa9e469b7eaf55a40b39a04d18c3.tar.gz", .hash = "122039cd3abe387b69d23930bf12154c2c84fc894874e10129a1fc5e8ac75ca0ddc0" }, + .tls = .{ .url = "https://github.com/ianic/tls.zig/archive/96b923fcdaa6371617154857cef7b8337778cbe2.tar.gz", .hash = "122031f94565d7420a155b6eaec65aaa02acc80e75e6f0947899be2106bc3055b1ec" }, }, } diff --git a/src/http/client.zig b/src/http/client.zig index c165423f..0a80142d 100644 --- a/src/http/client.zig +++ b/src/http/client.zig @@ -188,24 +188,20 @@ pub const Request = struct { // Makes an synchronous request pub fn sendSync(self: *Request, _: SendSyncOpts) anyerror!Response { try self.prepareInitialSend(); - const socket, const address = try self.createSocket(true); - var handler = SyncHandler{ .request = self }; - return handler.send(socket, address) catch |err| { - log.warn("HTTP error: {any} ({any} {any})", .{ err, self.method, self.uri }); - return err; - }; + return self.doSendSync(); } // Called internally, follows a redirect. fn redirectSync(self: *Request, redirect: Reader.Redirect) anyerror!Response { - posix.close(self._socket.?); - self._socket = null; - try self.prepareToRedirect(redirect); + return self.doSendSync(); + } + + fn doSendSync(self: *Request) anyerror!Response { const socket, const address = try self.createSocket(true); var handler = SyncHandler{ .request = self }; return handler.send(socket, address) catch |err| { - log.warn("HTTP error: {any} ({any} {any} redirect)", .{ err, self.method, self.uri }); + log.warn("HTTP error: {any} ({any} {any} {d})", .{ err, self.method, self.uri, self._redirect_count }); return err; }; } @@ -214,9 +210,16 @@ pub const Request = struct { // Makes an asynchronous request pub fn sendAsync(self: *Request, loop: anytype, handler: anytype, _: SendAsyncOpts) !void { try self.prepareInitialSend(); + return self.doSendAsync(loop, handler); + } + pub fn redirectAsync(self: *Request, redirect: Reader.Redirect, loop: anytype, handler: anytype) !void { + try self.prepareToRedirect(redirect); + return self.doSendAsync(loop, handler); + } + + fn doSendAsync(self: *Request, loop: anytype, handler: anytype) !void { // TODO: change this to nonblocking (false) when we have promise resolution const socket, const address = try self.createSocket(true); - const AsyncHandlerT = AsyncHandler(@TypeOf(handler), @TypeOf(loop)); const async_handler = try self.arena.create(AsyncHandlerT); @@ -225,22 +228,22 @@ pub const Request = struct { .socket = socket, .request = self, .handler = handler, - .tls_conn = null, .read_buf = self._state.buf, .reader = Reader.init(self._state), + .connection = .{ .handler = async_handler, .protocol = .{ .plain = {} } }, }; + if (self.secure) { - async_handler.tls_conn = try tls.asyn.Client(AsyncHandlerT.TLSHandler).init(self.arena, .{ .handler = async_handler }, .{ + async_handler.connection.protocol = .{ .tls_client = try tls.asyn.Client(AsyncHandlerT.TLSHandler).init(self.arena, .{ .handler = async_handler }, .{ .host = self.host(), .root_ca = self._client.root_ca, - }); + }) }; } loop.connect(AsyncHandlerT, async_handler, &async_handler.read_completion, AsyncHandlerT.connected, socket, address); } - // Does additional setup of the request for the firsts (i.e. non-redirect) - // call. + // Does additional setup of the request for the firsts (i.e. non-redirect) call. fn prepareInitialSend(self: *Request) !void { try self.verifyUri(); @@ -257,6 +260,13 @@ pub const Request = struct { // Sets up the request for redirecting. fn prepareToRedirect(self: *Request, redirect: Reader.Redirect) !void { + posix.close(self._socket.?); + self._socket = null; + + // CANNOT reset the arena (╥﹏╥) + // We need it for self.uri (which we're about to use to resolve + // redirect.location, and it might own some/all headers) + const redirect_count = self._redirect_count; if (redirect_count == 10) { return error.TooManyRedirects; @@ -394,7 +404,14 @@ fn AsyncHandler(comptime H: type, comptime L: type) type { // Used to help us know if we're writing the header or the body; state: SendState = .handshake, - tls_conn: ?tls.asyn.Client(TLSHandler) = null, + // Abstraction over TLS and plain text socket + connection: Connection, + + // This will be != null when we're supposed to redirect AND we've + // drained the response body. We need this as a field, because we'll + // detect this inside our TLS onRecv callback (which is executed + // inside the TLS client, and so we can't deinitialize the tls_client) + redirect: ?Reader.Redirect = null, const Self = @This(); const SendQueue = std.DoublyLinkedList([]const u8); @@ -405,30 +422,22 @@ fn AsyncHandler(comptime H: type, comptime L: type) type { body, }; + const ProcessStatus = enum { + done, + need_more, + }; + fn deinit(self: *Self) void { - if (self.tls_conn) |*tls_conn| { - tls_conn.deinit(); - } + self.connection.deinit(); self.request.deinit(); } fn connected(self: *Self, _: *IO.Completion, result: IO.ConnectError!void) void { self.loop.onConnect(result); result catch |err| return self.handleError("Connection failed", err); - - if (self.tls_conn) |*tls_conn| { - tls_conn.onConnect() catch |err| { - self.handleError("TLS handshake error", err); - }; - self.receive(); - return; - } - - self.state = .header; - const header = self.request.buildHeader() catch |err| { - return self.handleError("out of memory", err); + self.connection.connected() catch |err| { + self.handleError("connected handler error", err); }; - self.send(header); } fn send(self: *Self, data: []const u8) void { @@ -483,43 +492,12 @@ fn AsyncHandler(comptime H: type, comptime L: type) type { return; } - if (self.state == .handshake) {} - - switch (self.state) { - .handshake => { - // We're still doing our handshake. We need to wait until - // that's finished before sending the header. We might have - // more to send until then, but it'll be triggered by the - // TLS layer. - std.debug.assert(self.tls_conn != null); - }, - .body => { - // We've finished sending the body. - if (self.tls_conn == null) { - // if we aren't using TLS, then we need to start the recive loop - self.receive(); - } - }, - .header => { - // We've sent the header, we should send the body. - self.state = .body; - if (self.request.body) |body| { - if (self.tls_conn) |*tls_conn| { - tls_conn.send(body) catch |err| { - self.handleError("TLS send", err); - }; - } else { - self.send(body); - } - } else if (self.tls_conn == null) { - // start receiving the reply - self.receive(); - } - }, - } + self.connection.sent(self.state) catch |err| { + self.handleError("Processing sent data", err); + }; } - // Normally, you'd thin of HTTP as being a straight up request-response + // Normally, you'd think of HTTP as being a straight up request-response // and that we can send, and then receive. But with TLS, we need to receive // while handshaking and potentially while sending data. So we're always // receiving. @@ -544,51 +522,65 @@ fn AsyncHandler(comptime H: type, comptime L: type) type { return self.handleError("Connection closed", error.ConnectionResetByPeer); } - if (self.tls_conn) |*tls_conn| { - const pos = self.read_pos; - const end = pos + n; - const used = tls_conn.onRecv(self.read_buf[0..end]) catch |err| switch (err) { - error.Done => return self.deinit(), - else => { - self.handleError("TLS decrypt", err); - return; - }, - }; - if (used == end) { - self.read_pos = 0; - } else if (used == 0) { - self.read_pos = end; - } else { - const extra = end - used; - std.mem.copyForwards(u8, self.read_buf, self.read_buf[extra..end]); - self.read_pos = extra; - } - self.receive(); + const status = self.connection.received(n) catch |err| { + self.handleError("data processing", err); return; - } + }; - if (self.processData(self.read_buf[0..n]) == false) { - // we're done - self.deinit(); - } else { - // we're not done, need more data - self.receive(); + switch (status) { + .need_more => self.receive(), + .done => { + const redirect = self.redirect orelse { + self.deinit(); + return; + }; + self.request.redirectAsync(redirect, self.loop, self.handler) catch |err| { + self.handleError("Setup async redirect", err); + return; + }; + // redirectAsync has given up any claim to the request, + // including the socket. We just need to clean up our + // tls_client. + self.connection.deinit(); + }, } } - fn processData(self: *Self, d: []u8) bool { + fn processData(self: *Self, d: []u8) ProcessStatus { const reader = &self.reader; var data = d; while (true) { - const would_be_first = reader.response.status == 0; + const would_be_first = reader.header_done == false; const result = reader.process(data) catch |err| { self.handleError("Invalid server response", err); - return false; + return .done; }; - const done = result.done; - if (result.data != null or done or (would_be_first and reader.response.status > 0)) { + if (reader.header_done == false) { + // need more data + return .need_more; + } + + // at this point, If `would_be_first == true`, then + // `would_be_first` should be thought of as `is_first` because + + if (reader.redirect()) |redirect| { + // We don't redirect until we've drained the body (because, + // if we ever add keepalive, we'll re-use the connection). + // Calling `reader.redirect()` over and over again might not + // be the most efficient (it's a very simple function though), + // but for a redirect resposne, chances are we slurped up + // the header and body in a single go. + if (result.done == false) { + return .need_more; + } + self.redirect = redirect; + return .done; + } + + const done = result.done; + if (result.data != null or done or would_be_first) { // If we have data. Or if the request is done. Or if this is the // first time we have a complete header. Emit the chunk. self.handler.onHttpResponse(.{ @@ -596,28 +588,164 @@ fn AsyncHandler(comptime H: type, comptime L: type) type { .data = result.data, .first = would_be_first, .header = reader.response, - }) catch return false; + }) catch return .done; } if (done == true) { - return false; + return .need_more; } // With chunked-encoding, it's possible that we we've only // partially processed the data. So we need to keep processing // any unprocessed data. It would be nice if we could just glue // this all together, but that would require copying bytes around - data = result.unprocessed orelse break; + data = result.unprocessed orelse return .need_more; } - return true; } fn handleError(self: *Self, comptime msg: []const u8, err: anyerror) void { log.warn(msg ++ ": {any} ({any} {any})", .{ err, self.request.method, self.request.uri }); - self.handler.onHttpResponse(error.Failed) catch {}; + self.handler.onHttpResponse(err) catch {}; self.deinit(); } + const Connection = struct { + handler: *Self, + protocol: Protocol, + + const Protocol = union(enum) { + plain: void, + tls_client: tls.asyn.Client(TLSHandler), + }; + + fn deinit(self: *Connection) void { + switch (self.protocol) { + .tls_client => |*tls_client| tls_client.deinit(), + .plain => {}, + } + } + + fn connected(self: *Connection) !void { + const handler = self.handler; + + switch (self.protocol) { + .tls_client => |*tls_client| { + try tls_client.onConnect(); + // when TLS is active, from a network point of view + // it's no longer a strict REQ->RES. We pretty much + // have to constantly receive data (e.g. to process + // the handshake) + handler.receive(); + }, + .plain => { + handler.state = .header; + const header = try handler.request.buildHeader(); + return handler.send(header); + }, + } + } + + fn sent(self: *Connection, state: SendState) !void { + const handler = self.handler; + std.debug.assert(handler.state == state); + + switch (self.protocol) { + .tls_client => |*tls_client| { + switch (state) { + .handshake => { + // Our send is complete, but it was part of the + // TLS handshake. This isn't data we need to + // worry about. + }, + .header => { + // we WERE sending the header, but that's done + handler.state = .body; + if (handler.request.body) |body| { + try tls_client.send(body); + } + }, + .body => { + // We've finished sending the body. For non TLS + // we'll start receiving. But here, for TLS, + // we started a receive loop as soon as the c + // connection was established. + }, + } + }, + .plain => { + switch (state) { + .handshake => unreachable, + .header => { + // we WERE sending the header, but that's done + handler.state = .body; + if (handler.request.body) |body| { + handler.send(body); + } else { + // No body? time to start reading the response + handler.receive(); + } + }, + .body => { + // we're done sending the body, time to start + // reading the response + handler.receive(); + }, + } + }, + } + } + + fn received(self: *Connection, n: usize) !ProcessStatus { + const handler = self.handler; + const read_buf = handler.read_buf; + switch (self.protocol) { + .tls_client => |*tls_client| { + // The read on TLS is stateful, since we need a full + // TLS record to get cleartext data. + const pos = handler.read_pos; + const end = pos + n; + + const used = tls_client.onRecv(read_buf[0..end]) catch |err| switch (err) { + // https://github.com/ianic/tls.zig/pull/9 + // we currently have no way to break out of the TLS handling + // loop, except for returning an error. + error.TLSHandlerDone => return .done, + error.EndOfFile => return .done, // TLS close + else => return err, + }; + + // When we tell our TLS client that we've received data + // there are three possibilities: + + if (used == end) { + // 1 - It used up all the data that we gave it + handler.read_pos = 0; + } else if (used == 0) { + // 2 - It didn't use any of the data (i.e there + // wasn't a full record) + handler.read_pos = end; + } else { + // 3 - It used some of the data, but had leftover + // (i.e. there was 1+ full records AND an incomplete + // record). We need to maintain the "leftover" data + // for subsequent reads. + const extra = end - used; + std.mem.copyForwards(u8, read_buf, read_buf[extra..end]); + handler.read_pos = extra; + } + + // Remember that our read_buf is the MAX possible TLS + // record size. So as long as we make sure that the start + // of a record is at read_buf[0], we know that we'll + // always have enough space for 1 record. + + return .need_more; + }, + .plain => return handler.processData(read_buf[0..n]), + } + } + }; + // Separate struct just to keep it a bit cleaner. tls.zig requires // callbacks like "onConnect" and "send" which is a bit generic and // is confusing with the AsyncHandler which has similar concepts. @@ -632,7 +760,7 @@ fn AsyncHandler(comptime H: type, comptime L: type) type { return handler.handleError("out of memory", err); }; handler.state = .header; - handler.tls_conn.?.send(header) catch |err| { + handler.connection.protocol.tls_client.send(header) catch |err| { return handler.handleError("TLS send", err); }; } @@ -644,15 +772,17 @@ fn AsyncHandler(comptime H: type, comptime L: type) type { // tls.zig received data, it's giving it to us in plaintext pub fn onRecv(self: TLSHandler, data: []u8) !void { - if (self.handler.state != .body) { + const handler = self.handler; + if (handler.state != .body) { // We should not receive application-level data (which is the // only data tls.zig will give us), if our handler hasn't sent // the body. - self.handler.handleError("Premature server response", error.InvalidServerResonse); + handler.handleError("Premature server response", error.InvalidServerResonse); return error.InvalidServerResonse; } - if (self.handler.processData(data) == false) { - return error.Done; + switch (handler.processData(data)) { + .need_more => {}, + .done => return error.TLSHandlerDone, // https://github.com/ianic/tls.zig/pull/9 } } }; @@ -694,7 +824,7 @@ const SyncHandler = struct { continue; } - if (try reader.redirect()) |redirect| { + if (reader.redirect()) |redirect| { if (result.done == false) { try self.drain(&reader, &connection, result.unprocessed); } @@ -826,7 +956,7 @@ const Reader = struct { } // Determines if we need to redirect - fn redirect(self: *const Reader) !?Redirect { + fn redirect(self: *const Reader) ?Redirect { const use_get = switch (self.response.status) { 201, 301, 302, 303 => true, 307, 308 => false, @@ -1336,10 +1466,6 @@ const State = struct { } }; -pub const AsyncError = error{ - Failed, -}; - const StatePool = struct { states: []*State, available: usize, @@ -1570,9 +1696,9 @@ test "HttpClient: async connect error" { const Handler = struct { reset: *Thread.ResetEvent, - fn onHttpResponse(self: *@This(), res: AsyncError!Progress) !void { + fn onHttpResponse(self: *@This(), res: anyerror!Progress) !void { _ = res catch |err| { - if (err == error.Failed) { + if (err == error.ConnectionRefused) { self.reset.set(); return; } @@ -1637,6 +1763,38 @@ test "HttpClient: async with body" { }); } +test "HttpClient: async redirect" { + var client = try Client.init(testing.allocator, 2); + defer client.deinit(); + + var handler = try CaptureHandler.init(); + defer handler.deinit(); + + var loop = try jsruntime.Loop.init(testing.allocator); + defer loop.deinit(); + + var req = try client.request(.GET, "HTTP://127.0.0.1:9582/http_client/redirect"); + try req.sendAsync(&handler.loop, &handler, .{}); + + // Called twice on purpose. The initial GET resutls in the # of pending + // events to reach 0. This causes our `run_for_ns` to return. But we then + // start to requeue events (from the redirected request), so we need the + //loop to process those also. + try handler.loop.io.run_for_ns(std.time.ns_per_ms); + try handler.loop.io.run_for_ns(std.time.ns_per_ms); + try handler.reset.timedWait(std.time.ns_per_s); + + const res = handler.response; + try testing.expectEqual("over 9000!", res.body.items); + try testing.expectEqual(201, res.status); + try res.assertHeaders(&.{ + "connection", "close", + "content-length", "10", + "_host", "127.0.0.1", + "_connection", "Close", + }); +} + const TestResponse = struct { status: u16, keepalive: ?bool, @@ -1704,13 +1862,13 @@ const CaptureHandler = struct { self.loop.deinit(); } - fn onHttpResponse(self: *CaptureHandler, progress_: AsyncError!Progress) !void { + fn onHttpResponse(self: *CaptureHandler, progress_: anyerror!Progress) !void { self.process(progress_) catch |err| { std.debug.print("error: {}\n", .{err}); }; } - fn process(self: *CaptureHandler, progress_: AsyncError!Progress) !void { + fn process(self: *CaptureHandler, progress_: anyerror!Progress) !void { const progress = try progress_; const allocator = self.response.arena.allocator(); try self.response.body.appendSlice(allocator, progress.data orelse ""); diff --git a/src/xhr/xhr.zig b/src/xhr/xhr.zig index 53e62bd0..f4e01030 100644 --- a/src/xhr/xhr.zig +++ b/src/xhr/xhr.zig @@ -510,7 +510,7 @@ pub const XMLHttpRequest = struct { try request.sendAsync(loop, self, .{}); } - pub fn onHttpResponse(self: *XMLHttpRequest, progress_: http.AsyncError!http.Progress) !void { + pub fn onHttpResponse(self: *XMLHttpRequest, progress_: anyerror!http.Progress) !void { const progress = progress_ catch |err| { self.onErr(err); return err;