handle redirects on asynchronous calls

This commit is contained in:
Karl Seguin
2025-03-19 17:33:09 +08:00
parent de160d9170
commit 2f362f2aa2
3 changed files with 280 additions and 122 deletions

View File

@@ -4,6 +4,6 @@
.version = "0.0.0", .version = "0.0.0",
.fingerprint = 0xda130f3af836cea0, .fingerprint = 0xda130f3af836cea0,
.dependencies = .{ .dependencies = .{
.tls = .{ .url = "https://github.com/ianic/tls.zig/archive/6c36f8c39aeefa9e469b7eaf55a40b39a04d18c3.tar.gz", .hash = "122039cd3abe387b69d23930bf12154c2c84fc894874e10129a1fc5e8ac75ca0ddc0" }, .tls = .{ .url = "https://github.com/ianic/tls.zig/archive/96b923fcdaa6371617154857cef7b8337778cbe2.tar.gz", .hash = "122031f94565d7420a155b6eaec65aaa02acc80e75e6f0947899be2106bc3055b1ec" },
}, },
} }

View File

@@ -188,24 +188,20 @@ pub const Request = struct {
// Makes a synchronous request // Makes a synchronous request
pub fn sendSync(self: *Request, _: SendSyncOpts) anyerror!Response { pub fn sendSync(self: *Request, _: SendSyncOpts) anyerror!Response {
try self.prepareInitialSend(); try self.prepareInitialSend();
const socket, const address = try self.createSocket(true); return self.doSendSync();
var handler = SyncHandler{ .request = self };
return handler.send(socket, address) catch |err| {
log.warn("HTTP error: {any} ({any} {any})", .{ err, self.method, self.uri });
return err;
};
} }
// Called internally, follows a redirect. // Called internally, follows a redirect.
fn redirectSync(self: *Request, redirect: Reader.Redirect) anyerror!Response { fn redirectSync(self: *Request, redirect: Reader.Redirect) anyerror!Response {
posix.close(self._socket.?);
self._socket = null;
try self.prepareToRedirect(redirect); try self.prepareToRedirect(redirect);
return self.doSendSync();
}
fn doSendSync(self: *Request) anyerror!Response {
const socket, const address = try self.createSocket(true); const socket, const address = try self.createSocket(true);
var handler = SyncHandler{ .request = self }; var handler = SyncHandler{ .request = self };
return handler.send(socket, address) catch |err| { return handler.send(socket, address) catch |err| {
log.warn("HTTP error: {any} ({any} {any} redirect)", .{ err, self.method, self.uri }); log.warn("HTTP error: {any} ({any} {any} {d})", .{ err, self.method, self.uri, self._redirect_count });
return err; return err;
}; };
} }
@@ -214,9 +210,16 @@ pub const Request = struct {
// Makes an asynchronous request // Makes an asynchronous request
pub fn sendAsync(self: *Request, loop: anytype, handler: anytype, _: SendAsyncOpts) !void { pub fn sendAsync(self: *Request, loop: anytype, handler: anytype, _: SendAsyncOpts) !void {
try self.prepareInitialSend(); try self.prepareInitialSend();
return self.doSendAsync(loop, handler);
}
// Follows a redirect for an asynchronous request. Prepares the request for
// the redirect target via prepareToRedirect, then starts a new asynchronous
// send via doSendAsync using the given loop and handler. Any error from
// either step is returned to the caller.
pub fn redirectAsync(self: *Request, redirect: Reader.Redirect, loop: anytype, handler: anytype) !void {
try self.prepareToRedirect(redirect);
return self.doSendAsync(loop, handler);
}
fn doSendAsync(self: *Request, loop: anytype, handler: anytype) !void {
// TODO: change this to nonblocking (false) when we have promise resolution // TODO: change this to nonblocking (false) when we have promise resolution
const socket, const address = try self.createSocket(true); const socket, const address = try self.createSocket(true);
const AsyncHandlerT = AsyncHandler(@TypeOf(handler), @TypeOf(loop)); const AsyncHandlerT = AsyncHandler(@TypeOf(handler), @TypeOf(loop));
const async_handler = try self.arena.create(AsyncHandlerT); const async_handler = try self.arena.create(AsyncHandlerT);
@@ -225,22 +228,22 @@ pub const Request = struct {
.socket = socket, .socket = socket,
.request = self, .request = self,
.handler = handler, .handler = handler,
.tls_conn = null,
.read_buf = self._state.buf, .read_buf = self._state.buf,
.reader = Reader.init(self._state), .reader = Reader.init(self._state),
.connection = .{ .handler = async_handler, .protocol = .{ .plain = {} } },
}; };
if (self.secure) { if (self.secure) {
async_handler.tls_conn = try tls.asyn.Client(AsyncHandlerT.TLSHandler).init(self.arena, .{ .handler = async_handler }, .{ async_handler.connection.protocol = .{ .tls_client = try tls.asyn.Client(AsyncHandlerT.TLSHandler).init(self.arena, .{ .handler = async_handler }, .{
.host = self.host(), .host = self.host(),
.root_ca = self._client.root_ca, .root_ca = self._client.root_ca,
}); }) };
} }
loop.connect(AsyncHandlerT, async_handler, &async_handler.read_completion, AsyncHandlerT.connected, socket, address); loop.connect(AsyncHandlerT, async_handler, &async_handler.read_completion, AsyncHandlerT.connected, socket, address);
} }
// Does additional setup of the request for the first (i.e. non-redirect) // Does additional setup of the request for the first (i.e. non-redirect) call.
// call.
fn prepareInitialSend(self: *Request) !void { fn prepareInitialSend(self: *Request) !void {
try self.verifyUri(); try self.verifyUri();
@@ -257,6 +260,13 @@ pub const Request = struct {
// Sets up the request for redirecting. // Sets up the request for redirecting.
fn prepareToRedirect(self: *Request, redirect: Reader.Redirect) !void { fn prepareToRedirect(self: *Request, redirect: Reader.Redirect) !void {
posix.close(self._socket.?);
self._socket = null;
// CANNOT reset the arena (╥﹏╥)
// We need it for self.uri (which we're about to use to resolve
// redirect.location, and it might own some/all headers)
const redirect_count = self._redirect_count; const redirect_count = self._redirect_count;
if (redirect_count == 10) { if (redirect_count == 10) {
return error.TooManyRedirects; return error.TooManyRedirects;
@@ -394,7 +404,14 @@ fn AsyncHandler(comptime H: type, comptime L: type) type {
// Used to help us know if we're writing the header or the body; // Used to help us know if we're writing the header or the body;
state: SendState = .handshake, state: SendState = .handshake,
tls_conn: ?tls.asyn.Client(TLSHandler) = null, // Abstraction over TLS and plain text socket
connection: Connection,
// This will be != null when we're supposed to redirect AND we've
// drained the response body. We need this as a field, because we'll
// detect this inside our TLS onRecv callback (which is executed
// inside the TLS client, and so we can't deinitialize the tls_client)
redirect: ?Reader.Redirect = null,
const Self = @This(); const Self = @This();
const SendQueue = std.DoublyLinkedList([]const u8); const SendQueue = std.DoublyLinkedList([]const u8);
@@ -405,30 +422,22 @@ fn AsyncHandler(comptime H: type, comptime L: type) type {
body, body,
}; };
const ProcessStatus = enum {
done,
need_more,
};
fn deinit(self: *Self) void { fn deinit(self: *Self) void {
if (self.tls_conn) |*tls_conn| { self.connection.deinit();
tls_conn.deinit();
}
self.request.deinit(); self.request.deinit();
} }
fn connected(self: *Self, _: *IO.Completion, result: IO.ConnectError!void) void { fn connected(self: *Self, _: *IO.Completion, result: IO.ConnectError!void) void {
self.loop.onConnect(result); self.loop.onConnect(result);
result catch |err| return self.handleError("Connection failed", err); result catch |err| return self.handleError("Connection failed", err);
self.connection.connected() catch |err| {
if (self.tls_conn) |*tls_conn| { self.handleError("connected handler error", err);
tls_conn.onConnect() catch |err| {
self.handleError("TLS handshake error", err);
}; };
self.receive();
return;
}
self.state = .header;
const header = self.request.buildHeader() catch |err| {
return self.handleError("out of memory", err);
};
self.send(header);
} }
fn send(self: *Self, data: []const u8) void { fn send(self: *Self, data: []const u8) void {
@@ -483,43 +492,12 @@ fn AsyncHandler(comptime H: type, comptime L: type) type {
return; return;
} }
if (self.state == .handshake) {} self.connection.sent(self.state) catch |err| {
self.handleError("Processing sent data", err);
switch (self.state) {
.handshake => {
// We're still doing our handshake. We need to wait until
// that's finished before sending the header. We might have
// more to send until then, but it'll be triggered by the
// TLS layer.
std.debug.assert(self.tls_conn != null);
},
.body => {
// We've finished sending the body.
if (self.tls_conn == null) {
// if we aren't using TLS, then we need to start the recive loop
self.receive();
}
},
.header => {
// We've sent the header, we should send the body.
self.state = .body;
if (self.request.body) |body| {
if (self.tls_conn) |*tls_conn| {
tls_conn.send(body) catch |err| {
self.handleError("TLS send", err);
}; };
} else {
self.send(body);
}
} else if (self.tls_conn == null) {
// start receiving the reply
self.receive();
}
},
}
} }
// Normally, you'd thin of HTTP as being a straight up request-response // Normally, you'd think of HTTP as being a straight up request-response
// and that we can send, and then receive. But with TLS, we need to receive // and that we can send, and then receive. But with TLS, we need to receive
// while handshaking and potentially while sending data. So we're always // while handshaking and potentially while sending data. So we're always
// receiving. // receiving.
@@ -544,51 +522,65 @@ fn AsyncHandler(comptime H: type, comptime L: type) type {
return self.handleError("Connection closed", error.ConnectionResetByPeer); return self.handleError("Connection closed", error.ConnectionResetByPeer);
} }
if (self.tls_conn) |*tls_conn| { const status = self.connection.received(n) catch |err| {
const pos = self.read_pos; self.handleError("data processing", err);
const end = pos + n;
const used = tls_conn.onRecv(self.read_buf[0..end]) catch |err| switch (err) {
error.Done => return self.deinit(),
else => {
self.handleError("TLS decrypt", err);
return; return;
},
}; };
if (used == end) {
self.read_pos = 0;
} else if (used == 0) {
self.read_pos = end;
} else {
const extra = end - used;
std.mem.copyForwards(u8, self.read_buf, self.read_buf[extra..end]);
self.read_pos = extra;
}
self.receive();
return;
}
if (self.processData(self.read_buf[0..n]) == false) { switch (status) {
// we're done .need_more => self.receive(),
.done => {
const redirect = self.redirect orelse {
self.deinit(); self.deinit();
} else { return;
// we're not done, need more data };
self.receive(); self.request.redirectAsync(redirect, self.loop, self.handler) catch |err| {
self.handleError("Setup async redirect", err);
return;
};
// redirectAsync has given up any claim to the request,
// including the socket. We just need to clean up our
// tls_client.
self.connection.deinit();
},
} }
} }
fn processData(self: *Self, d: []u8) bool { fn processData(self: *Self, d: []u8) ProcessStatus {
const reader = &self.reader; const reader = &self.reader;
var data = d; var data = d;
while (true) { while (true) {
const would_be_first = reader.response.status == 0; const would_be_first = reader.header_done == false;
const result = reader.process(data) catch |err| { const result = reader.process(data) catch |err| {
self.handleError("Invalid server response", err); self.handleError("Invalid server response", err);
return false; return .done;
}; };
const done = result.done;
if (result.data != null or done or (would_be_first and reader.response.status > 0)) { if (reader.header_done == false) {
// need more data
return .need_more;
}
// at this point, if `would_be_first == true`, then
// `would_be_first` should be thought of as `is_first`, because
// the header was not complete before this call to `process` and it is now.
if (reader.redirect()) |redirect| {
// We don't redirect until we've drained the body (because,
// if we ever add keepalive, we'll re-use the connection).
// Calling `reader.redirect()` over and over again might not
// be the most efficient (it's a very simple function though),
// but for a redirect response, chances are we slurped up
// the header and body in a single go.
if (result.done == false) {
return .need_more;
}
self.redirect = redirect;
return .done;
}
const done = result.done;
if (result.data != null or done or would_be_first) {
// If we have data. Or if the request is done. Or if this is the // If we have data. Or if the request is done. Or if this is the
// first time we have a complete header. Emit the chunk. // first time we have a complete header. Emit the chunk.
self.handler.onHttpResponse(.{ self.handler.onHttpResponse(.{
@@ -596,28 +588,164 @@ fn AsyncHandler(comptime H: type, comptime L: type) type {
.data = result.data, .data = result.data,
.first = would_be_first, .first = would_be_first,
.header = reader.response, .header = reader.response,
}) catch return false; }) catch return .done;
} }
if (done == true) { if (done == true) {
return false; return .need_more;
} }
// With chunked-encoding, it's possible that we've only // With chunked-encoding, it's possible that we've only
// partially processed the data. So we need to keep processing // partially processed the data. So we need to keep processing
// any unprocessed data. It would be nice if we could just glue // any unprocessed data. It would be nice if we could just glue
// this all together, but that would require copying bytes around // this all together, but that would require copying bytes around
data = result.unprocessed orelse break; data = result.unprocessed orelse return .need_more;
} }
return true;
} }
fn handleError(self: *Self, comptime msg: []const u8, err: anyerror) void { fn handleError(self: *Self, comptime msg: []const u8, err: anyerror) void {
log.warn(msg ++ ": {any} ({any} {any})", .{ err, self.request.method, self.request.uri }); log.warn(msg ++ ": {any} ({any} {any})", .{ err, self.request.method, self.request.uri });
self.handler.onHttpResponse(error.Failed) catch {}; self.handler.onHttpResponse(err) catch {};
self.deinit(); self.deinit();
} }
// Abstraction over the two transports an AsyncHandler can use: a plain
// socket or a TLS client. The AsyncHandler forwards its connect / sent /
// received events here, and the protocol-specific behavior lives in the
// switch arms below.
const Connection = struct {
    // The AsyncHandler this connection belongs to. We call back into it to
    // send and receive, and to read/update its send state and read buffer.
    handler: *Self,
    protocol: Protocol,

    const Protocol = union(enum) {
        plain: void,
        tls_client: tls.asyn.Client(TLSHandler),
    };

    // Releases the TLS client, if there is one. Plain connections hold
    // nothing that needs releasing here.
    fn deinit(self: *Connection) void {
        switch (self.protocol) {
            .tls_client => |*tls_client| tls_client.deinit(),
            .plain => {},
        }
    }

    // Called once the socket is connected. For TLS, this kicks off the
    // handshake and the receive loop. For plain, this builds and sends the
    // request header. Returns any error from the TLS client or from
    // building the header.
    fn connected(self: *Connection) !void {
        const handler = self.handler;

        switch (self.protocol) {
            .tls_client => |*tls_client| {
                try tls_client.onConnect();
                // when TLS is active, from a network point of view
                // it's no longer a strict REQ->RES. We pretty much
                // have to constantly receive data (e.g. to process
                // the handshake)
                handler.receive();
            },
            .plain => {
                handler.state = .header;
                const header = try handler.request.buildHeader();
                return handler.send(header);
            },
        }
    }

    // Called when a pending send has completed. `state` is what we were
    // sending (and must match handler.state). Advances the send state and
    // triggers the next step: sending the body, or (plain only) starting to
    // read the response.
    fn sent(self: *Connection, state: SendState) !void {
        const handler = self.handler;
        std.debug.assert(handler.state == state);

        switch (self.protocol) {
            .tls_client => |*tls_client| {
                switch (state) {
                    .handshake => {
                        // Our send is complete, but it was part of the
                        // TLS handshake. This isn't data we need to
                        // worry about.
                    },
                    .header => {
                        // we WERE sending the header, but that's done
                        handler.state = .body;
                        if (handler.request.body) |body| {
                            try tls_client.send(body);
                        }
                    },
                    .body => {
                        // We've finished sending the body. For non TLS
                        // we'll start receiving. But here, for TLS,
                        // we started a receive loop as soon as the
                        // connection was established.
                    },
                }
            },
            .plain => {
                switch (state) {
                    // A plain connection goes straight to .header in
                    // connected(), so it never sends in .handshake.
                    .handshake => unreachable,
                    .header => {
                        // we WERE sending the header, but that's done
                        handler.state = .body;
                        if (handler.request.body) |body| {
                            handler.send(body);
                        } else {
                            // No body? time to start reading the response
                            handler.receive();
                        }
                    },
                    .body => {
                        // we're done sending the body, time to start
                        // reading the response
                        handler.receive();
                    },
                }
            },
        }
    }

    // Called when `n` new bytes have been read into the handler's read
    // buffer. Returns .done when the response has been fully handled (or the
    // TLS session ended), and .need_more when the caller should keep
    // receiving. Unexpected TLS errors are returned to the caller.
    fn received(self: *Connection, n: usize) !ProcessStatus {
        const handler = self.handler;
        const read_buf = handler.read_buf;

        switch (self.protocol) {
            .tls_client => |*tls_client| {
                // The read on TLS is stateful, since we need a full
                // TLS record to get cleartext data.
                const pos = handler.read_pos;
                const end = pos + n;

                const used = tls_client.onRecv(read_buf[0..end]) catch |err| switch (err) {
                    // https://github.com/ianic/tls.zig/pull/9
                    // we currently have no way to break out of the TLS handling
                    // loop, except for returning an error.
                    error.TLSHandlerDone => return .done,
                    error.EndOfFile => return .done, // TLS close
                    else => return err,
                };

                // When we tell our TLS client that we've received data
                // there are three possibilities:
                if (used == end) {
                    // 1 - It used up all the data that we gave it
                    handler.read_pos = 0;
                } else if (used == 0) {
                    // 2 - It didn't use any of the data (i.e there
                    // wasn't a full record)
                    handler.read_pos = end;
                } else {
                    // 3 - It used some of the data, but had leftover
                    // (i.e. there was 1+ full records AND an incomplete
                    // record). We need to maintain the "leftover" data
                    // for subsequent reads.
                    //
                    // The leftover is everything after the consumed
                    // prefix, i.e. read_buf[used..end], and it is
                    // `extra` bytes long. Move it to the start of the
                    // buffer so the next read appends right after it.
                    const extra = end - used;
                    std.mem.copyForwards(u8, read_buf, read_buf[used..end]);
                    handler.read_pos = extra;
                }

                // Remember that our read_buf is the MAX possible TLS
                // record size. So as long as we make sure that the start
                // of a record is at read_buf[0], we know that we'll
                // always have enough space for 1 record.
                return .need_more;
            },
            .plain => return handler.processData(read_buf[0..n]),
        }
    }
};
// Separate struct just to keep it a bit cleaner. tls.zig requires // Separate struct just to keep it a bit cleaner. tls.zig requires
// callbacks like "onConnect" and "send" which is a bit generic and // callbacks like "onConnect" and "send" which is a bit generic and
// is confusing with the AsyncHandler which has similar concepts. // is confusing with the AsyncHandler which has similar concepts.
@@ -632,7 +760,7 @@ fn AsyncHandler(comptime H: type, comptime L: type) type {
return handler.handleError("out of memory", err); return handler.handleError("out of memory", err);
}; };
handler.state = .header; handler.state = .header;
handler.tls_conn.?.send(header) catch |err| { handler.connection.protocol.tls_client.send(header) catch |err| {
return handler.handleError("TLS send", err); return handler.handleError("TLS send", err);
}; };
} }
@@ -644,15 +772,17 @@ fn AsyncHandler(comptime H: type, comptime L: type) type {
// tls.zig received data, it's giving it to us in plaintext // tls.zig received data, it's giving it to us in plaintext
pub fn onRecv(self: TLSHandler, data: []u8) !void { pub fn onRecv(self: TLSHandler, data: []u8) !void {
if (self.handler.state != .body) { const handler = self.handler;
if (handler.state != .body) {
// We should not receive application-level data (which is the // We should not receive application-level data (which is the
// only data tls.zig will give us), if our handler hasn't sent // only data tls.zig will give us), if our handler hasn't sent
// the body. // the body.
self.handler.handleError("Premature server response", error.InvalidServerResonse); handler.handleError("Premature server response", error.InvalidServerResonse);
return error.InvalidServerResonse; return error.InvalidServerResonse;
} }
if (self.handler.processData(data) == false) { switch (handler.processData(data)) {
return error.Done; .need_more => {},
.done => return error.TLSHandlerDone, // https://github.com/ianic/tls.zig/pull/9
} }
} }
}; };
@@ -694,7 +824,7 @@ const SyncHandler = struct {
continue; continue;
} }
if (try reader.redirect()) |redirect| { if (reader.redirect()) |redirect| {
if (result.done == false) { if (result.done == false) {
try self.drain(&reader, &connection, result.unprocessed); try self.drain(&reader, &connection, result.unprocessed);
} }
@@ -826,7 +956,7 @@ const Reader = struct {
} }
// Determines if we need to redirect // Determines if we need to redirect
fn redirect(self: *const Reader) !?Redirect { fn redirect(self: *const Reader) ?Redirect {
const use_get = switch (self.response.status) { const use_get = switch (self.response.status) {
201, 301, 302, 303 => true, 201, 301, 302, 303 => true,
307, 308 => false, 307, 308 => false,
@@ -1336,10 +1466,6 @@ const State = struct {
} }
}; };
pub const AsyncError = error{
Failed,
};
const StatePool = struct { const StatePool = struct {
states: []*State, states: []*State,
available: usize, available: usize,
@@ -1570,9 +1696,9 @@ test "HttpClient: async connect error" {
const Handler = struct { const Handler = struct {
reset: *Thread.ResetEvent, reset: *Thread.ResetEvent,
fn onHttpResponse(self: *@This(), res: AsyncError!Progress) !void { fn onHttpResponse(self: *@This(), res: anyerror!Progress) !void {
_ = res catch |err| { _ = res catch |err| {
if (err == error.Failed) { if (err == error.ConnectionRefused) {
self.reset.set(); self.reset.set();
return; return;
} }
@@ -1637,6 +1763,38 @@ test "HttpClient: async with body" {
}); });
} }
// Sends an asynchronous GET to the test server's redirect endpoint and checks
// that the captured response (status, headers and body) is the one from the
// redirect target.
test "HttpClient: async redirect" {
var client = try Client.init(testing.allocator, 2);
defer client.deinit();
var handler = try CaptureHandler.init();
defer handler.deinit();
var loop = try jsruntime.Loop.init(testing.allocator);
defer loop.deinit();
var req = try client.request(.GET, "HTTP://127.0.0.1:9582/http_client/redirect");
try req.sendAsync(&handler.loop, &handler, .{});
// Called twice on purpose. The initial GET results in the # of pending
// events to reach 0. This causes our `run_for_ns` to return. But we then
// start to requeue events (from the redirected request), so we need the
// loop to process those also.
try handler.loop.io.run_for_ns(std.time.ns_per_ms);
try handler.loop.io.run_for_ns(std.time.ns_per_ms);
try handler.reset.timedWait(std.time.ns_per_s);
const res = handler.response;
try testing.expectEqual("over 9000!", res.body.items);
try testing.expectEqual(201, res.status);
try res.assertHeaders(&.{
"connection", "close",
"content-length", "10",
"_host", "127.0.0.1",
"_connection", "Close",
});
}
const TestResponse = struct { const TestResponse = struct {
status: u16, status: u16,
keepalive: ?bool, keepalive: ?bool,
@@ -1704,13 +1862,13 @@ const CaptureHandler = struct {
self.loop.deinit(); self.loop.deinit();
} }
fn onHttpResponse(self: *CaptureHandler, progress_: AsyncError!Progress) !void { fn onHttpResponse(self: *CaptureHandler, progress_: anyerror!Progress) !void {
self.process(progress_) catch |err| { self.process(progress_) catch |err| {
std.debug.print("error: {}\n", .{err}); std.debug.print("error: {}\n", .{err});
}; };
} }
fn process(self: *CaptureHandler, progress_: AsyncError!Progress) !void { fn process(self: *CaptureHandler, progress_: anyerror!Progress) !void {
const progress = try progress_; const progress = try progress_;
const allocator = self.response.arena.allocator(); const allocator = self.response.arena.allocator();
try self.response.body.appendSlice(allocator, progress.data orelse ""); try self.response.body.appendSlice(allocator, progress.data orelse "");

View File

@@ -510,7 +510,7 @@ pub const XMLHttpRequest = struct {
try request.sendAsync(loop, self, .{}); try request.sendAsync(loop, self, .{});
} }
pub fn onHttpResponse(self: *XMLHttpRequest, progress_: http.AsyncError!http.Progress) !void { pub fn onHttpResponse(self: *XMLHttpRequest, progress_: anyerror!http.Progress) !void {
const progress = progress_ catch |err| { const progress = progress_ catch |err| {
self.onErr(err); self.onErr(err);
return err; return err;