Cleanup synchronous connection for tls and non-tls.

Drain response prior to redirect.
This commit is contained in:
Karl Seguin
2025-03-19 12:47:32 +08:00
parent 226c18cb56
commit de160d9170

View File

@@ -21,9 +21,10 @@ const MAX_HEADER_LINE_LEN = 4096;
// tls.max_ciphertext_record_len which isn't exposed // tls.max_ciphertext_record_len which isn't exposed
const BUFFER_LEN = (1 << 14) + 256 + 5; const BUFFER_LEN = (1 << 14) + 256 + 5;
const TLSConnection = tls.Connection(std.net.Stream);
const HeaderList = std.ArrayListUnmanaged(std.http.Header); const HeaderList = std.ArrayListUnmanaged(std.http.Header);
// Thread-safe. Holds our root certificate, connection pool and state pool
// Used to create Requests.
pub const Client = struct { pub const Client = struct {
allocator: Allocator, allocator: Allocator,
state_pool: StatePool, state_pool: StatePool,
@@ -61,17 +62,47 @@ pub const Client = struct {
} }
}; };
// Represents a request. Can be used to make either a synchronous or an
// asynchronous request. When a synchronous request is made, `request.deinit()`
// should be called once the response is no longer needed.
// When an asychronous request is made, the request is automatically cleaned up
// (but request.deinit() should still be called to discard the request
// before the `sendAsync` is called).
pub const Request = struct { pub const Request = struct {
// Whether or not TLS is being used.
secure: bool, secure: bool,
// The HTTP Method to use
method: Method, method: Method,
// The URI we're requested
uri: std.Uri, uri: std.Uri,
// Optional body
body: ?[]const u8, body: ?[]const u8,
// Arena used for the lifetime of the request. Most large allocations are
// either done through the state (pre-allocated on startup + pooled) or
// by the TLS library.
arena: Allocator, arena: Allocator,
// List of request headers
headers: HeaderList, headers: HeaderList,
// Used to limit the # of redirects we'll follow
_redirect_count: u16, _redirect_count: u16,
// The underlying socket
_socket: ?posix.socket_t, _socket: ?posix.socket_t,
// Pooled buffers and arena
_state: *State, _state: *State,
// The parent client. Used to get the root certificates, to interact
// with the connection pool, and to return _state to the state pool when done
_client: *Client, _client: *Client,
// Whether the Host header has been set via `request.addHeader()`. If not
// we'll set it based on `uri` before issuing the request.
_has_host_header: bool, _has_host_header: bool,
pub const Method = enum { pub const Method = enum {
@@ -87,6 +118,7 @@ pub const Request = struct {
} }
}; };
// url can either be a `[]const u8`, in which case we'll clone + parse, or a std.Uri
fn init(client: *Client, state: *State, method: Method, url: anytype) !Request { fn init(client: *Client, state: *State, method: Method, url: anytype) !Request {
var arena = state.arena.allocator(); var arena = state.arena.allocator();
@@ -103,15 +135,8 @@ pub const Request = struct {
return error.UriMissingHost; return error.UriMissingHost;
} }
var secure: bool = false;
if (std.ascii.eqlIgnoreCase(uri.scheme, "https")) {
secure = true;
} else if (std.ascii.eqlIgnoreCase(uri.scheme, "http") == false) {
return error.UnsupportedUriScheme;
}
return .{ return .{
.secure = secure, .secure = true,
.uri = uri, .uri = uri,
.method = method, .method = method,
.body = null, .body = null,
@@ -160,8 +185,9 @@ pub const Request = struct {
// TODO timeout // TODO timeout
const SendSyncOpts = struct {}; const SendSyncOpts = struct {};
// Makes an synchronous request
pub fn sendSync(self: *Request, _: SendSyncOpts) anyerror!Response { pub fn sendSync(self: *Request, _: SendSyncOpts) anyerror!Response {
try self.prepareToSend(); try self.prepareInitialSend();
const socket, const address = try self.createSocket(true); const socket, const address = try self.createSocket(true);
var handler = SyncHandler{ .request = self }; var handler = SyncHandler{ .request = self };
return handler.send(socket, address) catch |err| { return handler.send(socket, address) catch |err| {
@@ -170,6 +196,7 @@ pub const Request = struct {
}; };
} }
// Called internally, follows a redirect.
fn redirectSync(self: *Request, redirect: Reader.Redirect) anyerror!Response { fn redirectSync(self: *Request, redirect: Reader.Redirect) anyerror!Response {
posix.close(self._socket.?); posix.close(self._socket.?);
self._socket = null; self._socket = null;
@@ -184,13 +211,15 @@ pub const Request = struct {
} }
const SendAsyncOpts = struct {}; const SendAsyncOpts = struct {};
// Makes an asynchronous request
pub fn sendAsync(self: *Request, loop: anytype, handler: anytype, _: SendAsyncOpts) !void { pub fn sendAsync(self: *Request, loop: anytype, handler: anytype, _: SendAsyncOpts) !void {
try self.prepareToSend(); try self.prepareInitialSend();
// TODO: change this to nonblocking (false) when we have promise resolution // TODO: change this to nonblocking (false) when we have promise resolution
const socket, const address = try self.createSocket(true); const socket, const address = try self.createSocket(true);
const AsyncHandlerT = AsyncHandler(@TypeOf(handler), @TypeOf(loop)); const AsyncHandlerT = AsyncHandler(@TypeOf(handler), @TypeOf(loop));
const async_handler = try self.arena.create(AsyncHandlerT); const async_handler = try self.arena.create(AsyncHandlerT);
async_handler.* = .{ async_handler.* = .{
.loop = loop, .loop = loop,
.socket = socket, .socket = socket,
@@ -210,9 +239,12 @@ pub const Request = struct {
loop.connect(AsyncHandlerT, async_handler, &async_handler.read_completion, AsyncHandlerT.connected, socket, address); loop.connect(AsyncHandlerT, async_handler, &async_handler.read_completion, AsyncHandlerT.connected, socket, address);
} }
fn prepareToSend(self: *Request) !void { // Does additional setup of the request for the firsts (i.e. non-redirect)
const arena = self.arena; // call.
fn prepareInitialSend(self: *Request) !void {
try self.verifyUri();
const arena = self.arena;
if (self.body) |body| { if (self.body) |body| {
const cl = try std.fmt.allocPrint(arena, "{d}", .{body.len}); const cl = try std.fmt.allocPrint(arena, "{d}", .{body.len});
try self.headers.append(arena, .{ .name = "Content-Length", .value = cl }); try self.headers.append(arena, .{ .name = "Content-Length", .value = cl });
@@ -223,6 +255,7 @@ pub const Request = struct {
} }
} }
// Sets up the request for redirecting.
fn prepareToRedirect(self: *Request, redirect: Reader.Redirect) !void { fn prepareToRedirect(self: *Request, redirect: Reader.Redirect) !void {
const redirect_count = self._redirect_count; const redirect_count = self._redirect_count;
if (redirect_count == 10) { if (redirect_count == 10) {
@@ -231,33 +264,57 @@ pub const Request = struct {
self._redirect_count = redirect_count + 1; self._redirect_count = redirect_count + 1;
var buf = try self.arena.alloc(u8, 1024); var buf = try self.arena.alloc(u8, 1024);
const previous_host = self.host();
self.uri = try self.uri.resolve_inplace(redirect.location, &buf); self.uri = try self.uri.resolve_inplace(redirect.location, &buf);
try self.verifyUri();
if (redirect.use_get) { if (redirect.use_get) {
// Some redirect status codes _require_ that we switch the method
// to a GET.
self.method = .GET; self.method = .GET;
} }
log.info("redirecting to: {any} {any}", .{ self.method, self.uri }); log.info("redirecting to: {any} {any}", .{ self.method, self.uri });
if (self.body != null and self.method == .GET) { if (self.body != null and self.method == .GET) {
// Some redirects _must_ be switched to a GET. If we have a body // If we have a body and the method is a GET, then we must be following
// we need to remove it // a redirect which switched the method. Remove the body.
// Reset the Content-Length
self.body = null; self.body = null;
for (self.headers.items, 0..) |hdr, i| { for (self.headers.items) |*hdr| {
if (std.mem.eql(u8, hdr.name, "Content-Length")) { if (std.mem.eql(u8, hdr.name, "Content-Length")) {
_ = self.headers.swapRemove(i); hdr.value = "0";
break; break;
} }
} }
} }
for (self.headers.items) |*hdr| { const new_host = self.host();
if (std.mem.eql(u8, hdr.name, "Host")) { if (std.mem.eql(u8, previous_host, new_host) == false) {
hdr.value = self.host(); for (self.headers.items) |*hdr| {
break; if (std.mem.eql(u8, hdr.name, "Host")) {
hdr.value = new_host;
break;
}
} }
} }
} }
// extracted because we re-verify this on redirect
fn verifyUri(self: *Request) !void {
const scheme = self.uri.scheme;
if (std.ascii.eqlIgnoreCase(scheme, "https")) {
self.secure = true;
return;
}
if (std.ascii.eqlIgnoreCase(scheme, "http")) {
self.secure = false;
return;
}
return error.UnsupportedUriScheme;
}
fn createSocket(self: *Request, blocking: bool) !struct { posix.socket_t, std.net.Address } { fn createSocket(self: *Request, blocking: bool) !struct { posix.socket_t, std.net.Address } {
const host_ = self.host(); const host_ = self.host();
const port: u16 = self.uri.port orelse if (self.secure) 443 else 80; const port: u16 = self.uri.port orelse if (self.secure) 443 else 80;
@@ -307,6 +364,7 @@ pub const Request = struct {
} }
}; };
// Handles asynchronous requests
fn AsyncHandler(comptime H: type, comptime L: type) type { fn AsyncHandler(comptime H: type, comptime L: type) type {
return struct { return struct {
loop: L, loop: L,
@@ -601,6 +659,7 @@ fn AsyncHandler(comptime H: type, comptime L: type) type {
}; };
} }
// Handles synchronous requests
const SyncHandler = struct { const SyncHandler = struct {
request: *Request, request: *Request,
@@ -609,54 +668,36 @@ const SyncHandler = struct {
var request = self.request; var request = self.request;
try posix.connect(socket, &address.any, address.getOsSockLen()); try posix.connect(socket, &address.any, address.getOsSockLen());
const header = try request.buildHeader(); var connection: Connection = undefined;
var stream = std.net.Stream{ .handle = socket };
var tls_conn: ?TLSConnection = null;
if (request.secure) { if (request.secure) {
var conn = try tls.client(stream, .{ connection = .{ .tls = try tls.client(std.net.Stream{ .handle = socket }, .{
.host = request.host(), .host = request.host(),
.root_ca = request._client.root_ca, .root_ca = request._client.root_ca,
}); }) };
try conn.writeAll(header);
if (request.body) |body| {
try conn.writeAll(body);
}
tls_conn = conn;
} else if (request.body) |body| {
var vec = [2]posix.iovec_const{
.{ .len = header.len, .base = header.ptr },
.{ .len = body.len, .base = body.ptr },
};
try writeAllIOVec(socket, &vec);
} else { } else {
try stream.writeAll(header); connection = .{ .plain = socket };
} }
const header = try request.buildHeader();
try connection.sendRequest(header, request.body);
const state = request._state; const state = request._state;
var buf = state.buf; var buf = state.buf;
var reader = Reader.init(state); var reader = Reader.init(state);
while (true) { while (true) {
// We keep going until we have the header const n = try connection.read(buf);
var n: usize = 0;
if (tls_conn) |*conn| {
n = try conn.read(buf);
} else {
n = try stream.read(buf);
}
if (n == 0) {
return error.ConnectionResetByPeer;
}
const result = try reader.process(buf[0..n]); const result = try reader.process(buf[0..n]);
if (reader.hasHeader() == false) { if (reader.header_done == false) {
continue; continue;
} }
if (reader.redirect()) |redirect| { if (try reader.redirect()) |redirect| {
if (result.done == false) {
try self.drain(&reader, &connection, result.unprocessed);
}
return request.redirectSync(redirect); return request.redirectSync(redirect);
} }
@@ -669,29 +710,91 @@ const SyncHandler = struct {
._request = request, ._request = request,
._reader = reader, ._reader = reader,
._done = result.done, ._done = result.done,
._tls_conn = tls_conn, ._connection = connection,
._data = result.unprocessed, ._data = result.unprocessed,
._socket = socket,
.header = reader.response, .header = reader.response,
}; };
} }
} }
fn writeAllIOVec(socket: posix.socket_t, vec: []posix.iovec_const) !void { fn drain(self: SyncHandler, reader: *Reader, connection: *Connection, unprocessed: ?[]u8) !void {
var i: usize = 0; if (unprocessed) |data| {
while (true) { const result = try reader.process(data);
var n = try posix.writev(socket, vec[i..]); if (result.done) {
while (n >= vec[i].len) { return;
n -= vec[i].len; }
i += 1; }
if (i >= vec.len) {
return; var buf = self.request._state.buf;
} while (true) {
const n = try connection.read(buf);
const result = try reader.process(buf[0..n]);
if (result.done) {
return;
} }
vec[i].base += n;
vec[i].len -= n;
} }
} }
const Connection = union(enum) {
tls: tls.Connection(std.net.Stream),
plain: posix.socket_t,
fn sendRequest(self: *Connection, header: []const u8, body: ?[]const u8) !void {
switch (self.*) {
.tls => |*tls_conn| {
try tls_conn.writeAll(header);
if (body) |b| {
try tls_conn.writeAll(b);
}
},
.plain => |socket| {
if (body) |b| {
var vec = [2]posix.iovec_const{
.{ .len = header.len, .base = header.ptr },
.{ .len = b.len, .base = b.ptr },
};
return writeAllIOVec(socket, &vec);
}
return writeAll(socket, header);
},
}
}
fn read(self: *Connection, buf: []u8) !usize {
const n = switch (self.*) {
.tls => |*tls_conn| try tls_conn.read(buf),
.plain => |socket| try posix.read(socket, buf),
};
if (n == 0) {
return error.ConnectionResetByPeer;
}
return n;
}
fn writeAllIOVec(socket: posix.socket_t, vec: []posix.iovec_const) !void {
var i: usize = 0;
while (true) {
var n = try posix.writev(socket, vec[i..]);
while (n >= vec[i].len) {
n -= vec[i].len;
i += 1;
if (i >= vec.len) {
return;
}
}
vec[i].base += n;
vec[i].len -= n;
}
}
fn writeAll(socket: posix.socket_t, data: []const u8) !void {
var i: usize = 0;
while (i < data.len) {
i += try posix.write(socket, data[i..]);
}
}
};
}; };
// Used for reading the response (both the header and the body) // Used for reading the response (both the header and the body)
@@ -709,21 +812,21 @@ const Reader = struct {
body_reader: ?BodyReader, body_reader: ?BodyReader,
header_done: bool,
fn init(state: *State) Reader { fn init(state: *State) Reader {
return .{ return .{
.pos = 0, .pos = 0,
.response = .{}, .response = .{},
.body_reader = null, .body_reader = null,
.header_done = false,
.header_buf = state.header_buf, .header_buf = state.header_buf,
.arena = state.arena.allocator(), .arena = state.arena.allocator(),
}; };
} }
fn hasHeader(self: *const Reader) bool { // Determines if we need to redirect
return self.response.status > 0; fn redirect(self: *const Reader) !?Redirect {
}
fn redirect(self: *const Reader) ?Redirect {
const use_get = switch (self.response.status) { const use_get = switch (self.response.status) {
201, 301, 302, 303 => true, 201, 301, 302, 303 => true,
307, 308 => false, 307, 308 => false,
@@ -814,7 +917,6 @@ const Reader = struct {
return .{ .done = false, .data = null, .unprocessed = null }; return .{ .done = false, .data = null, .unprocessed = null };
} }
} }
var result = try self.prepareForBody(); var result = try self.prepareForBody();
if (unprocessed.len > 0) { if (unprocessed.len > 0) {
if (result.done == true) { if (result.done == true) {
@@ -831,6 +933,8 @@ const Reader = struct {
// We're done parsing the header, and we need to (maybe) setup the BodyReader // We're done parsing the header, and we need to (maybe) setup the BodyReader
fn prepareForBody(self: *Reader) !Result { fn prepareForBody(self: *Reader) !Result {
self.header_done = true;
const response = &self.response; const response = &self.response;
if (response.get("transfer-encoding")) |te| { if (response.get("transfer-encoding")) |te| {
if (std.ascii.indexOfIgnoreCase(te, "chunked") != null) { if (std.ascii.indexOfIgnoreCase(te, "chunked") != null) {
@@ -863,6 +967,11 @@ const Reader = struct {
return .{ .done = false, .data = null, .unprocessed = null }; return .{ .done = false, .data = null, .unprocessed = null };
} }
// returns true when done
// returns any remaining unprocessed data
// When done == true, the remaining data must belong to the body
// When done == false, at least part of the remaining data must belong to
// the header.
fn parseHeader(self: *Reader, data: []u8) !struct { bool, []u8 } { fn parseHeader(self: *Reader, data: []u8) !struct { bool, []u8 } {
var pos: usize = 0; var pos: usize = 0;
const arena = self.arena; const arena = self.arena;
@@ -1099,6 +1208,10 @@ const Reader = struct {
const Result = struct { const Result = struct {
done: bool, done: bool,
data: ?[]u8, data: ?[]u8,
// Any unprocessed data we have from the last call to "process".
// We can have unprocessed data when transitioning from parsing the
// header to parsing the body. When using Chunked encoding, we'll also
// have unprocessed data between chunks.
unprocessed: ?[]u8 = null, unprocessed: ?[]u8 = null,
}; };
@@ -1150,8 +1263,7 @@ pub const Response = struct {
_request: *Request, _request: *Request,
_buf: []u8, _buf: []u8,
_socket: posix.socket_t, _connection: SyncHandler.Connection,
_tls_conn: ?TLSConnection,
_done: bool, _done: bool,
@@ -1170,16 +1282,7 @@ pub const Response = struct {
return null; return null;
} }
var n: usize = 0; const n = try self._connection.read(buf);
if (self._tls_conn) |*tls_conn| {
n = try tls_conn.read(buf);
} else {
n = try posix.read(self._socket, buf);
}
if (n == 0) {
self._done = true;
return null;
}
self._data = buf[0..n]; self._data = buf[0..n];
} }
} }
@@ -1404,9 +1507,6 @@ test "HttpClient Reader: fuzz" {
test "HttpClient: invalid url" { test "HttpClient: invalid url" {
var client = try Client.init(testing.allocator, 1); var client = try Client.init(testing.allocator, 1);
defer client.deinit(); defer client.deinit();
try testing.expectError(error.UnsupportedUriScheme, client.request(.GET, "://localhost"));
try testing.expectError(error.UnsupportedUriScheme, client.request(.GET, "ftp://localhost"));
try testing.expectError(error.UriMissingHost, client.request(.GET, "http:///")); try testing.expectError(error.UriMissingHost, client.request(.GET, "http:///"));
} }