From 226c18cb566280d2d16aff9fef0e50dd77688cd3 Mon Sep 17 00:00:00 2001
From: Karl Seguin
Date: Tue, 18 Mar 2025 19:39:22 +0800
Subject: [PATCH] handle redirects on synchronous calls

---
Reviewer revision: Reader.redirect no longer follows 201 responses,
prepareToRedirect sizes its resolve buffer from the Location header, and
comments were added. The client.zig index line still names the original
post-image blob.

 src/browser/browser.zig |   1 -
 src/http/client.zig     | 323 ++++++++++++++++++++++++++++++----------
 src/unit_tests.zig      |   7 +-
 src/xhr/xhr.zig         |   2 +-
 4 files changed, 248 insertions(+), 85 deletions(-)

diff --git a/src/browser/browser.zig b/src/browser/browser.zig
index 2f751991..ed0d9af0 100644
--- a/src/browser/browser.zig
+++ b/src/browser/browser.zig
@@ -606,7 +606,6 @@ pub const Page = struct {
                 res_src = try std.fs.path.resolve(arena, &.{ _dir, src });
             }
         }
-
         const u = try std.Uri.resolve_inplace(self.uri, res_src, &b);
 
         var request = try self.session.http_client.request(.GET, u);
diff --git a/src/http/client.zig b/src/http/client.zig
index c2de5317..36fd178e 100644
--- a/src/http/client.zig
+++ b/src/http/client.zig
@@ -68,7 +68,7 @@ pub const Request = struct {
     body: ?[]const u8,
     arena: Allocator,
     headers: HeaderList,
-    _buf: []u8,
+    _redirect_count: u16,
     _socket: ?posix.socket_t,
     _state: *State,
     _client: *Client,
@@ -117,10 +117,10 @@ pub const Request = struct {
             .body = null,
             .headers = .{},
             .arena = arena,
-            ._buf = state.buf,
             ._socket = null,
             ._state = state,
             ._client = client,
+            ._redirect_count = 0,
             ._has_host_header = false,
         };
     }
@@ -160,67 +160,33 @@ pub const Request = struct {
 
     // TODO timeout
     const SendSyncOpts = struct {};
-    pub fn sendSync(self: *Request, _: SendSyncOpts) !Response {
+    // Sends the request and blocks until the response header has been read,
+    // following redirects along the way. The error set is spelled out as
+    // anyerror because SyncHandler.send and redirectSync call each other.
+    pub fn sendSync(self: *Request, _: SendSyncOpts) anyerror!Response {
         try self.prepareToSend();
         const socket, const address = try self.createSocket(true);
-        try posix.connect(socket, &address.any, address.getOsSockLen());
+        var handler = SyncHandler{ .request = self };
+        return handler.send(socket, address) catch |err| {
+            log.warn("HTTP error: {any} ({any} {any})", .{ err, self.method, self.uri });
+            return err;
+        };
+    }
 
-        const header = try self.buildHeader();
-        var stream = std.net.Stream{ .handle = socket };
+    // Called by SyncHandler.send when the response is a redirect. Closes the
+    // current socket, points the request at the new location and sends it
+    // again on a fresh socket.
+    fn redirectSync(self: *Request, redirect: Reader.Redirect) anyerror!Response {
+        posix.close(self._socket.?);
+        self._socket = null;
 
-        var tls_conn: ?TLSConnection = null;
-        if (self.secure) {
-            var conn = try tls.client(stream, .{
-                .host = self.host(),
-                .root_ca = self._client.root_ca,
-            });
-
-            try conn.writeAll(header);
-            if (self.body) |body| {
-                try conn.writeAll(body);
-            }
-            tls_conn = conn;
-        } else if (self.body) |body| {
-            var vec = [2]posix.iovec_const{
-                .{ .len = header.len, .base = header.ptr },
-                .{ .len = body.len, .base = body.ptr },
-            };
-            try writeAllIOVec(socket, &vec);
-        } else {
-            try stream.writeAll(header);
-        }
-
-        var buf = self._state.buf;
-        var reader = Reader.init(self._state);
-
-        while (true) {
-            var n: usize = 0;
-            if (tls_conn) |*conn| {
-                n = try conn.read(buf);
-            } else {
-                n = try stream.read(buf);
-            }
-
-            if (n == 0) {
-                return error.ConnectionResetByPeer;
-            }
-            const result = try reader.process(buf[0..n]);
-            const response = reader.response;
-            if (response.status > 0) {
-                std.debug.assert(result.done or reader.body_reader != null);
-                std.debug.assert(result.data == null);
-                return .{
-                    ._buf = buf,
-                    ._request = self,
-                    ._reader = reader,
-                    ._done = result.done,
-                    ._tls_conn = tls_conn,
-                    ._data = result.unprocessed,
-                    ._socket = self._socket.?,
-                    .header = response,
-                };
-            }
-        }
+        try self.prepareToRedirect(redirect);
+        const socket, const address = try self.createSocket(true);
+        var handler = SyncHandler{ .request = self };
+        return handler.send(socket, address) catch |err| {
+            log.warn("HTTP error: {any} ({any} {any} redirect)", .{ err, self.method, self.uri });
+            return err;
+        };
     }
 
     const SendAsyncOpts = struct {};
@@ -259,7 +225,54 @@ pub const Request = struct {
         }
 
         if (!self._has_host_header) {
-            try self.headers.append(arena, .{ .name = "Host", .value = self.uri.host.?.percent_encoded });
+            try self.headers.append(arena, .{ .name = "Host", .value = self.host() });
+        }
+    }
+
+    // Rewrites the request in place so that it can be re-sent to
+    // redirect.location. Fails with error.TooManyRedirects once 10 redirects
+    // have already been followed. Resolves the location against the current
+    // uri, switches to GET and drops the body when the redirect asks for it,
+    // and refreshes the Host header.
+    // NOTE(review): self.secure is not updated here; confirm that a redirect
+    // which changes scheme (http <-> https) is handled elsewhere.
+    fn prepareToRedirect(self: *Request, redirect: Reader.Redirect) !void {
+        const redirect_count = self._redirect_count;
+        if (redirect_count == 10) {
+            return error.TooManyRedirects;
+        }
+        self._redirect_count = redirect_count + 1;
+
+        // resolve_inplace copies the location to the start of the buffer and
+        // writes the merged path right after it, so the buffer must grow with
+        // the location. A fixed 1024 bytes is overrun by a long Location
+        // header. The extra 1024 leaves room for the base path.
+        const location = redirect.location;
+        var buf = try self.arena.alloc(u8, location.len * 2 + 1024);
+        self.uri = try self.uri.resolve_inplace(location, &buf);
+
+        if (redirect.use_get) {
+            self.method = .GET;
+        }
+        log.info("redirecting to: {any} {any}", .{ self.method, self.uri });
+
+        if (self.body != null and self.method == .GET) {
+            // Some redirects _must_ be switched to a GET. If we have a body
+            // we need to remove it
+            self.body = null;
+            for (self.headers.items, 0..) |hdr, i| {
+                if (std.mem.eql(u8, hdr.name, "Content-Length")) {
+                    _ = self.headers.swapRemove(i);
+                    break;
+                }
+            }
+        }
+
+        for (self.headers.items) |*hdr| {
+            if (std.mem.eql(u8, hdr.name, "Host")) {
+                hdr.value = self.host();
+                break;
+            }
         }
     }
 
@@ -560,7 +573,7 @@ fn AsyncHandler(comptime H: type, comptime L: type) type {
         }
 
         fn handleError(self: *Self, comptime msg: []const u8, err: anyerror) void {
-            log.warn(msg ++ ": {any} ({any})", .{ err, self.request.uri });
+            log.warn(msg ++ ": {any} ({any} {any})", .{ err, self.request.method, self.request.uri });
             self.handler.onHttpResponse(error.Failed) catch {};
             self.deinit();
         }
@@ -606,6 +619,107 @@ fn AsyncHandler(comptime H: type, comptime L: type) type {
     };
 }
 
+// Performs one blocking request/response exchange for a Request. Created by
+// Request.sendSync and Request.redirectSync, once per attempt.
+const SyncHandler = struct {
+    request: *Request,
+
+    // Connects the socket, writes the header (plus the body, if there is
+    // one) and reads until the response header is available. A redirect is
+    // handed back to the Request to follow; anything else is returned as a
+    // Response whose body can then be iterated.
+    // The Request owns the socket, we shouldn't close it in here.
+    fn send(self: *SyncHandler, socket: posix.socket_t, address: std.net.Address) !Response {
+        var request = self.request;
+        try posix.connect(socket, &address.any, address.getOsSockLen());
+
+        const header = try request.buildHeader();
+        var stream = std.net.Stream{ .handle = socket };
+
+        var tls_conn: ?TLSConnection = null;
+        if (request.secure) {
+            var conn = try tls.client(stream, .{
+                .host = request.host(),
+                .root_ca = request._client.root_ca,
+            });
+
+            try conn.writeAll(header);
+            if (request.body) |body| {
+                try conn.writeAll(body);
+            }
+            tls_conn = conn;
+        } else if (request.body) |body| {
+            var vec = [2]posix.iovec_const{
+                .{ .len = header.len, .base = header.ptr },
+                .{ .len = body.len, .base = body.ptr },
+            };
+            try writeAllIOVec(socket, &vec);
+        } else {
+            try stream.writeAll(header);
+        }
+
+        const state = request._state;
+
+        var buf = state.buf;
+        var reader = Reader.init(state);
+
+        while (true) {
+            // We keep going until we have the header
+            var n: usize = 0;
+            if (tls_conn) |*conn| {
+                n = try conn.read(buf);
+            } else {
+                n = try stream.read(buf);
+            }
+            if (n == 0) {
+                return error.ConnectionResetByPeer;
+            }
+            const result = try reader.process(buf[0..n]);
+
+            if (reader.hasHeader() == false) {
+                continue;
+            }
+
+            if (reader.redirect()) |redirect| {
+                return request.redirectSync(redirect);
+            }
+
+            // we have a header, and it isn't a redirect, we return our Response
+            // object which can be iterated to get the body.
+            std.debug.assert(result.done or reader.body_reader != null);
+            std.debug.assert(result.data == null);
+            return .{
+                ._buf = buf,
+                ._request = request,
+                ._reader = reader,
+                ._done = result.done,
+                ._tls_conn = tls_conn,
+                ._data = result.unprocessed,
+                ._socket = socket,
+                .header = reader.response,
+            };
+        }
+    }
+
+    // Writes every buffer in vec to the socket, carrying on after a partial
+    // write. vec is modified in place as bytes are consumed.
+    fn writeAllIOVec(socket: posix.socket_t, vec: []posix.iovec_const) !void {
+        var i: usize = 0;
+        while (true) {
+            var n = try posix.writev(socket, vec[i..]);
+            while (n >= vec[i].len) {
+                n -= vec[i].len;
+                i += 1;
+                if (i >= vec.len) {
+                    return;
+                }
+            }
+            vec[i].base += n;
+            vec[i].len -= n;
+        }
+    }
+};
+
 // Used for reading the response (both the header and the body)
 const Reader = struct {
     // always references state.header_buf
@@ -631,6 +745,27 @@ const Reader = struct {
         };
     }
 
+    // A non-zero status means the response header is available.
+    fn hasHeader(self: *const Reader) bool {
+        return self.response.status > 0;
+    }
+
+    // Returns the redirect to follow, or null when the response is not a
+    // redirect or has no location header. 301, 302 and 303 are re-sent as a
+    // GET; 307 and 308 keep the original method. 201 (Created) is not a
+    // redirect: its Location header names the resource that was created, so
+    // it is deliberately left out.
+    fn redirect(self: *const Reader) ?Redirect {
+        const use_get = switch (self.response.status) {
+            301, 302, 303 => true,
+            307, 308 => false,
+            else => return null,
+        };
+
+        const location = self.response.get("location") orelse return null;
+        return .{ .use_get = use_get, .location = location };
+    }
+
     fn process(self: *Reader, data: []u8) ProcessError!Result {
         if (self.body_reader) |*br| {
             const ok, const result = try br.process(data);
@@ -988,6 +1123,14 @@ const Reader = struct {
         };
     };
 
+    // A redirect extracted from a response header.
+    const Redirect = struct {
+        // whether the follow-up request must be sent as a GET
+        use_get: bool,
+        // the raw location header value; may be relative
+        location: []const u8,
+    };
+
     const Result = struct {
         done: bool,
         data: ?[]u8,
@@ -1125,7 +1268,7 @@ const State = struct {
     }
 };
 
-pub const Error = error{
+pub const AsyncError = error{
     Failed,
 };
 
@@ -1197,22 +1340,6 @@ const StatePool = struct {
     }
 };
 
-pub fn writeAllIOVec(socket: posix.socket_t, vec: []posix.iovec_const) !void {
-    var i: usize = 0;
-    while (true) {
-        var n = try posix.writev(socket, vec[i..]);
-        while (n >= vec[i].len) {
-            n -= vec[i].len;
-            i += 1;
-            if (i >= vec.len) {
-                return;
-            }
-        }
-        vec[i].base += n;
-        vec[i].len -= n;
-    }
-}
-
 const testing = @import("../testing.zig");
 test "HttpClient Reader: fuzz" {
     var state = try State.init(testing.allocator, 1024, 1024);
@@ -1340,13 +1467,45 @@ test "HttpClient: sync no body" {
     try testing.expectEqual("0", res.header.get("content-length"));
 }
 
+test "HttpClient: sync with body" {
+    var client = try Client.init(testing.allocator, 2);
+    defer client.deinit();
+
+    var req = try client.request(.GET, "http://127.0.0.1:9582/http_client/echo");
+    var res = try req.sendSync(.{});
+
+    try testing.expectEqual("over 9000!", try res.next());
+    try testing.expectEqual(201, res.header.status);
+    try testing.expectEqual(4, res.header.count());
+    try testing.expectEqual("close", res.header.get("connection"));
+    try testing.expectEqual("10", res.header.get("content-length"));
+    try testing.expectEqual("127.0.0.1", res.header.get("_host"));
+    try testing.expectEqual("Close", res.header.get("_connection"));
+}
+
+test "HttpClient: sync GET redirect" {
+    var client = try Client.init(testing.allocator, 2);
+    defer client.deinit();
+
+    var req = try client.request(.GET, "http://127.0.0.1:9582/http_client/redirect");
+    var res = try req.sendSync(.{});
+
+    try testing.expectEqual("over 9000!", try res.next());
+    try testing.expectEqual(201, res.header.status);
+    try testing.expectEqual(4, res.header.count());
+    try testing.expectEqual("close", res.header.get("connection"));
+    try testing.expectEqual("10", res.header.get("content-length"));
+    try testing.expectEqual("127.0.0.1", res.header.get("_host"));
+    try testing.expectEqual("Close", res.header.get("_connection"));
+}
+
 test "HttpClient: async connect error" {
     var loop = try jsruntime.Loop.init(testing.allocator);
     defer loop.deinit();
 
     const Handler = struct {
         reset: *Thread.ResetEvent,
-        fn onHttpResponse(self: *@This(), res: Error!Progress) !void {
+        fn onHttpResponse(self: *@This(), res: AsyncError!Progress) !void {
             _ = res catch |err| {
                 if (err == error.Failed) {
                     self.reset.set();
@@ -1397,7 +1556,7 @@ test "HttpClient: async with body" {
     var handler = try CaptureHandler.init();
     defer handler.deinit();
 
-    var req = try client.request(.GET, "HTTP://127.0.0.1:9582/http_client/body");
+    var req = try client.request(.GET, "HTTP://127.0.0.1:9582/http_client/echo");
     try req.sendAsync(&handler.loop, &handler, .{});
     try handler.loop.io.run_for_ns(std.time.ns_per_ms);
     try handler.reset.timedWait(std.time.ns_per_s);
@@ -1480,13 +1639,13 @@ const CaptureHandler = struct {
         self.loop.deinit();
     }
 
-    fn onHttpResponse(self: *CaptureHandler, progress_: Error!Progress) !void {
+    fn onHttpResponse(self: *CaptureHandler, progress_: AsyncError!Progress) !void {
         self.process(progress_) catch |err| {
             std.debug.print("error: {}\n", .{err});
         };
     }
 
-    fn process(self: *CaptureHandler, progress_: Error!Progress) !void {
+    fn process(self: *CaptureHandler, progress_: AsyncError!Progress) !void {
         const progress = try progress_;
         const allocator = self.response.arena.allocator();
         try self.response.body.appendSlice(allocator, progress.data orelse "");
diff --git a/src/unit_tests.zig b/src/unit_tests.zig
index 88b3af5a..f3c12ca1 100644
--- a/src/unit_tests.zig
+++ b/src/unit_tests.zig
@@ -353,7 +353,12 @@ fn serveHTTP(allocator: Allocator, address: std.net.Address) !void {
             try request.respond("Hello!", .{});
         } else if (std.mem.eql(u8, path, "/http_client/simple")) {
             try request.respond("", .{});
-        } else if (std.mem.eql(u8, path, "/http_client/body")) {
+        } else if (std.mem.eql(u8, path, "/http_client/redirect")) {
+            try request.respond("", .{
+                .status = .moved_permanently,
+                .extra_headers = &.{.{ .name = "LOCATION", .value = "../http_client/echo" }},
+            });
+        } else if (std.mem.eql(u8, path, "/http_client/echo")) {
             var headers: std.ArrayListUnmanaged(std.http.Header) = .{};
 
             var it = request.iterateHeaders();
diff --git a/src/xhr/xhr.zig b/src/xhr/xhr.zig
index 9b6f88d5..53e62bd0 100644
--- a/src/xhr/xhr.zig
+++ b/src/xhr/xhr.zig
@@ -510,7 +510,7 @@ pub const XMLHttpRequest = struct {
         try request.sendAsync(loop, self, .{});
     }
 
-    pub fn onHttpResponse(self: *XMLHttpRequest, progress_: http.Error!http.Progress) !void {
+    pub fn onHttpResponse(self: *XMLHttpRequest, progress_: http.AsyncError!http.Progress) !void {
         const progress = progress_ catch |err| {
             self.onErr(err);
             return err;