From de160d9170a91bebea4868f4c06e1a9f8d89857a Mon Sep 17 00:00:00 2001 From: Karl Seguin Date: Wed, 19 Mar 2025 12:47:32 +0800 Subject: [PATCH] Cleanup synchronous connection for tls and non-tls. Drain response prior to redirect. --- src/http/client.zig | 274 ++++++++++++++++++++++++++++++-------------- 1 file changed, 187 insertions(+), 87 deletions(-) diff --git a/src/http/client.zig b/src/http/client.zig index 36fd178e..c165423f 100644 --- a/src/http/client.zig +++ b/src/http/client.zig @@ -21,9 +21,10 @@ const MAX_HEADER_LINE_LEN = 4096; // tls.max_ciphertext_record_len which isn't exposed const BUFFER_LEN = (1 << 14) + 256 + 5; -const TLSConnection = tls.Connection(std.net.Stream); const HeaderList = std.ArrayListUnmanaged(std.http.Header); +// Thread-safe. Holds our root certificate, connection pool and state pool +// Used to create Requests. pub const Client = struct { allocator: Allocator, state_pool: StatePool, @@ -61,17 +62,47 @@ pub const Client = struct { } }; +// Represents a request. Can be used to make either a synchronous or an +// asynchronous request. When a synchronous request is made, `request.deinit()` +// should be called once the response is no longer needed. +// When an asychronous request is made, the request is automatically cleaned up +// (but request.deinit() should still be called to discard the request +// before the `sendAsync` is called). pub const Request = struct { + // Whether or not TLS is being used. secure: bool, + // The HTTP Method to use method: Method, + + // The URI we're requested uri: std.Uri, + + // Optional body body: ?[]const u8, + + // Arena used for the lifetime of the request. Most large allocations are + // either done through the state (pre-allocated on startup + pooled) or + // by the TLS library. arena: Allocator, + + // List of request headers headers: HeaderList, + + // Used to limit the # of redirects we'll follow _redirect_count: u16, + + // The underlying socket _socket: ?posix.socket_t, + + // Pooled buffers and arena _state: *State, + + // The parent client. Used to get the root certificates, to interact + // with the connection pool, and to return _state to the state pool when done _client: *Client, + + // Whether the Host header has been set via `request.addHeader()`. If not + // we'll set it based on `uri` before issuing the request. _has_host_header: bool, pub const Method = enum { @@ -87,6 +118,7 @@ pub const Request = struct { } }; + // url can either be a `[]const u8`, in which case we'll clone + parse, or a std.Uri fn init(client: *Client, state: *State, method: Method, url: anytype) !Request { var arena = state.arena.allocator(); @@ -103,15 +135,8 @@ pub const Request = struct { return error.UriMissingHost; } - var secure: bool = false; - if (std.ascii.eqlIgnoreCase(uri.scheme, "https")) { - secure = true; - } else if (std.ascii.eqlIgnoreCase(uri.scheme, "http") == false) { - return error.UnsupportedUriScheme; - } - return .{ - .secure = secure, + .secure = true, .uri = uri, .method = method, .body = null, @@ -160,8 +185,9 @@ pub const Request = struct { // TODO timeout const SendSyncOpts = struct {}; + // Makes an synchronous request pub fn sendSync(self: *Request, _: SendSyncOpts) anyerror!Response { - try self.prepareToSend(); + try self.prepareInitialSend(); const socket, const address = try self.createSocket(true); var handler = SyncHandler{ .request = self }; return handler.send(socket, address) catch |err| { @@ -170,6 +196,7 @@ pub const Request = struct { }; } + // Called internally, follows a redirect. fn redirectSync(self: *Request, redirect: Reader.Redirect) anyerror!Response { posix.close(self._socket.?); self._socket = null; @@ -184,13 +211,15 @@ pub const Request = struct { } const SendAsyncOpts = struct {}; + // Makes an asynchronous request pub fn sendAsync(self: *Request, loop: anytype, handler: anytype, _: SendAsyncOpts) !void { - try self.prepareToSend(); + try self.prepareInitialSend(); // TODO: change this to nonblocking (false) when we have promise resolution const socket, const address = try self.createSocket(true); const AsyncHandlerT = AsyncHandler(@TypeOf(handler), @TypeOf(loop)); const async_handler = try self.arena.create(AsyncHandlerT); + async_handler.* = .{ .loop = loop, .socket = socket, @@ -210,9 +239,12 @@ pub const Request = struct { loop.connect(AsyncHandlerT, async_handler, &async_handler.read_completion, AsyncHandlerT.connected, socket, address); } - fn prepareToSend(self: *Request) !void { - const arena = self.arena; + // Does additional setup of the request for the firsts (i.e. non-redirect) + // call. + fn prepareInitialSend(self: *Request) !void { + try self.verifyUri(); + const arena = self.arena; if (self.body) |body| { const cl = try std.fmt.allocPrint(arena, "{d}", .{body.len}); try self.headers.append(arena, .{ .name = "Content-Length", .value = cl }); @@ -223,6 +255,7 @@ pub const Request = struct { } } + // Sets up the request for redirecting. fn prepareToRedirect(self: *Request, redirect: Reader.Redirect) !void { const redirect_count = self._redirect_count; if (redirect_count == 10) { @@ -231,33 +264,57 @@ pub const Request = struct { self._redirect_count = redirect_count + 1; var buf = try self.arena.alloc(u8, 1024); + + const previous_host = self.host(); self.uri = try self.uri.resolve_inplace(redirect.location, &buf); + try self.verifyUri(); if (redirect.use_get) { + // Some redirect status codes _require_ that we switch the method + // to a GET. self.method = .GET; } log.info("redirecting to: {any} {any}", .{ self.method, self.uri }); if (self.body != null and self.method == .GET) { - // Some redirects _must_ be switched to a GET. If we have a body - // we need to remove it + // If we have a body and the method is a GET, then we must be following + // a redirect which switched the method. Remove the body. + // Reset the Content-Length self.body = null; - for (self.headers.items, 0..) |hdr, i| { + for (self.headers.items) |*hdr| { if (std.mem.eql(u8, hdr.name, "Content-Length")) { - _ = self.headers.swapRemove(i); + hdr.value = "0"; break; } } } - for (self.headers.items) |*hdr| { - if (std.mem.eql(u8, hdr.name, "Host")) { - hdr.value = self.host(); - break; + const new_host = self.host(); + if (std.mem.eql(u8, previous_host, new_host) == false) { + for (self.headers.items) |*hdr| { + if (std.mem.eql(u8, hdr.name, "Host")) { + hdr.value = new_host; + break; + } } } } + // extracted because we re-verify this on redirect + fn verifyUri(self: *Request) !void { + const scheme = self.uri.scheme; + if (std.ascii.eqlIgnoreCase(scheme, "https")) { + self.secure = true; + return; + } + if (std.ascii.eqlIgnoreCase(scheme, "http")) { + self.secure = false; + return; + } + + return error.UnsupportedUriScheme; + } + fn createSocket(self: *Request, blocking: bool) !struct { posix.socket_t, std.net.Address } { const host_ = self.host(); const port: u16 = self.uri.port orelse if (self.secure) 443 else 80; @@ -307,6 +364,7 @@ pub const Request = struct { } }; +// Handles asynchronous requests fn AsyncHandler(comptime H: type, comptime L: type) type { return struct { loop: L, @@ -601,6 +659,7 @@ fn AsyncHandler(comptime H: type, comptime L: type) type { }; } +// Handles synchronous requests const SyncHandler = struct { request: *Request, @@ -609,54 +668,36 @@ const SyncHandler = struct { var request = self.request; try posix.connect(socket, &address.any, address.getOsSockLen()); - const header = try request.buildHeader(); - var stream = std.net.Stream{ .handle = socket }; - - var tls_conn: ?TLSConnection = null; + var connection: Connection = undefined; if (request.secure) { - var conn = try tls.client(stream, .{ + connection = .{ .tls = try tls.client(std.net.Stream{ .handle = socket }, .{ .host = request.host(), .root_ca = request._client.root_ca, - }); - - try conn.writeAll(header); - if (request.body) |body| { - try conn.writeAll(body); - } - tls_conn = conn; - } else if (request.body) |body| { - var vec = [2]posix.iovec_const{ - .{ .len = header.len, .base = header.ptr }, - .{ .len = body.len, .base = body.ptr }, - }; - try writeAllIOVec(socket, &vec); + }) }; } else { - try stream.writeAll(header); + connection = .{ .plain = socket }; } + const header = try request.buildHeader(); + try connection.sendRequest(header, request.body); + const state = request._state; var buf = state.buf; var reader = Reader.init(state); while (true) { - // We keep going until we have the header - var n: usize = 0; - if (tls_conn) |*conn| { - n = try conn.read(buf); - } else { - n = try stream.read(buf); - } - if (n == 0) { - return error.ConnectionResetByPeer; - } + const n = try connection.read(buf); const result = try reader.process(buf[0..n]); - if (reader.hasHeader() == false) { + if (reader.header_done == false) { continue; } - if (reader.redirect()) |redirect| { + if (try reader.redirect()) |redirect| { + if (result.done == false) { + try self.drain(&reader, &connection, result.unprocessed); + } return request.redirectSync(redirect); } @@ -669,29 +710,91 @@ const SyncHandler = struct { ._request = request, ._reader = reader, ._done = result.done, - ._tls_conn = tls_conn, + ._connection = connection, ._data = result.unprocessed, - ._socket = socket, .header = reader.response, }; } } - fn writeAllIOVec(socket: posix.socket_t, vec: []posix.iovec_const) !void { - var i: usize = 0; - while (true) { - var n = try posix.writev(socket, vec[i..]); - while (n >= vec[i].len) { - n -= vec[i].len; - i += 1; - if (i >= vec.len) { - return; - } + fn drain(self: SyncHandler, reader: *Reader, connection: *Connection, unprocessed: ?[]u8) !void { + if (unprocessed) |data| { + const result = try reader.process(data); + if (result.done) { + return; + } + } + + var buf = self.request._state.buf; + while (true) { + const n = try connection.read(buf); + const result = try reader.process(buf[0..n]); + if (result.done) { + return; } - vec[i].base += n; - vec[i].len -= n; } } + + const Connection = union(enum) { + tls: tls.Connection(std.net.Stream), + plain: posix.socket_t, + + fn sendRequest(self: *Connection, header: []const u8, body: ?[]const u8) !void { + switch (self.*) { + .tls => |*tls_conn| { + try tls_conn.writeAll(header); + if (body) |b| { + try tls_conn.writeAll(b); + } + }, + .plain => |socket| { + if (body) |b| { + var vec = [2]posix.iovec_const{ + .{ .len = header.len, .base = header.ptr }, + .{ .len = b.len, .base = b.ptr }, + }; + return writeAllIOVec(socket, &vec); + } + + return writeAll(socket, header); + }, + } + } + + fn read(self: *Connection, buf: []u8) !usize { + const n = switch (self.*) { + .tls => |*tls_conn| try tls_conn.read(buf), + .plain => |socket| try posix.read(socket, buf), + }; + if (n == 0) { + return error.ConnectionResetByPeer; + } + return n; + } + + fn writeAllIOVec(socket: posix.socket_t, vec: []posix.iovec_const) !void { + var i: usize = 0; + while (true) { + var n = try posix.writev(socket, vec[i..]); + while (n >= vec[i].len) { + n -= vec[i].len; + i += 1; + if (i >= vec.len) { + return; + } + } + vec[i].base += n; + vec[i].len -= n; + } + } + + fn writeAll(socket: posix.socket_t, data: []const u8) !void { + var i: usize = 0; + while (i < data.len) { + i += try posix.write(socket, data[i..]); + } + } + }; }; // Used for reading the response (both the header and the body) @@ -709,21 +812,21 @@ const Reader = struct { body_reader: ?BodyReader, + header_done: bool, + fn init(state: *State) Reader { return .{ .pos = 0, .response = .{}, .body_reader = null, + .header_done = false, .header_buf = state.header_buf, .arena = state.arena.allocator(), }; } - fn hasHeader(self: *const Reader) bool { - return self.response.status > 0; - } - - fn redirect(self: *const Reader) ?Redirect { + // Determines if we need to redirect + fn redirect(self: *const Reader) !?Redirect { const use_get = switch (self.response.status) { 201, 301, 302, 303 => true, 307, 308 => false, @@ -814,7 +917,6 @@ const Reader = struct { return .{ .done = false, .data = null, .unprocessed = null }; } } - var result = try self.prepareForBody(); if (unprocessed.len > 0) { if (result.done == true) { @@ -831,6 +933,8 @@ const Reader = struct { // We're done parsing the header, and we need to (maybe) setup the BodyReader fn prepareForBody(self: *Reader) !Result { + self.header_done = true; + const response = &self.response; if (response.get("transfer-encoding")) |te| { if (std.ascii.indexOfIgnoreCase(te, "chunked") != null) { @@ -863,6 +967,11 @@ const Reader = struct { return .{ .done = false, .data = null, .unprocessed = null }; } + // returns true when done + // returns any remaining unprocessed data + // When done == true, the remaining data must belong to the body + // When done == false, at least part of the remaining data must belong to + // the header. fn parseHeader(self: *Reader, data: []u8) !struct { bool, []u8 } { var pos: usize = 0; const arena = self.arena; @@ -1099,6 +1208,10 @@ const Reader = struct { const Result = struct { done: bool, data: ?[]u8, + // Any unprocessed data we have from the last call to "process". + // We can have unprocessed data when transitioning from parsing the + // header to parsing the body. When using Chunked encoding, we'll also + // have unprocessed data between chunks. unprocessed: ?[]u8 = null, }; @@ -1150,8 +1263,7 @@ pub const Response = struct { _request: *Request, _buf: []u8, - _socket: posix.socket_t, - _tls_conn: ?TLSConnection, + _connection: SyncHandler.Connection, _done: bool, @@ -1170,16 +1282,7 @@ pub const Response = struct { return null; } - var n: usize = 0; - if (self._tls_conn) |*tls_conn| { - n = try tls_conn.read(buf); - } else { - n = try posix.read(self._socket, buf); - } - if (n == 0) { - self._done = true; - return null; - } + const n = try self._connection.read(buf); self._data = buf[0..n]; } } @@ -1404,9 +1507,6 @@ test "HttpClient Reader: fuzz" { test "HttpClient: invalid url" { var client = try Client.init(testing.allocator, 1); defer client.deinit(); - - try testing.expectError(error.UnsupportedUriScheme, client.request(.GET, "://localhost")); - try testing.expectError(error.UnsupportedUriScheme, client.request(.GET, "ftp://localhost")); try testing.expectError(error.UriMissingHost, client.request(.GET, "http:///")); }