Cleaner merge

Switch to non-blocking sockets.

Fix TLS handshake/receive/send ordering
This commit is contained in:
Karl Seguin
2025-03-20 23:09:20 +08:00
parent feb2046549
commit 22aa126b29
12 changed files with 271 additions and 203 deletions

View File

@@ -252,10 +252,13 @@ pub const Request = struct {
};
if (self.secure) {
async_handler.connection.protocol = .{ .tls_client = try tls.asyn.Client(AsyncHandlerT.TLSHandler).init(self.arena, .{ .handler = async_handler }, .{
.host = self.host(),
.root_ca = self._client.root_ca,
}) };
async_handler.connection.protocol = .{
.tls_client = try tls.asyn.Client(AsyncHandlerT.TLSHandler).init(self.arena, .{ .handler = async_handler }, .{
.host = self.host(),
.root_ca = self._client.root_ca,
// .key_log_callback = tls.config.key_log.callback
}),
};
}
loop.connect(AsyncHandlerT, async_handler, &async_handler.read_completion, AsyncHandlerT.connected, socket, address);
@@ -274,6 +277,8 @@ pub const Request = struct {
if (!self._has_host_header) {
try self.headers.append(arena, .{ .name = "Host", .value = self.host() });
}
try self.headers.append(arena, .{ .name = "User-Agent", .value = "Lightpanda/1.0" });
}
// Sets up the request for redirecting.
@@ -442,6 +447,7 @@ fn AsyncHandler(comptime H: type, comptime L: type) type {
const ProcessStatus = enum {
done,
wait,
need_more,
};
@@ -466,7 +472,6 @@ fn AsyncHandler(comptime H: type, comptime L: type) type {
node.data = data;
self.send_queue.append(node);
if (self.send_queue.len > 1) {
// if we already had a message in the queue, then our send loop
// is already setup.
@@ -488,16 +493,20 @@ fn AsyncHandler(comptime H: type, comptime L: type) type {
const n = n_ catch |err| {
return self.handleError("Write error", err);
};
const node = self.send_queue.popFirst().?;
const node = self.send_queue.first.?;
const data = node.data;
if (n < data.len) {
var next: ?*SendQueue.Node = node;
if (n == data.len) {
_ = self.send_queue.popFirst();
next = node.next;
} else {
// didn't send all the data, we prematurely popped this off
// (because, in most cases, it _will_ send all the data)
node.data = data[n..];
self.send_queue.prepend(node);
}
if (self.send_queue.first) |next| {
if (next) |next_| {
// we still have data to send
self.loop.send(
Self,
@@ -505,12 +514,12 @@ fn AsyncHandler(comptime H: type, comptime L: type) type {
&self.send_completion,
sent,
self.socket,
next.data,
next_.data,
);
return;
}
self.connection.sent(self.state) catch |err| {
self.connection.sent() catch |err| {
self.handleError("Processing sent data", err);
};
}
@@ -546,6 +555,11 @@ fn AsyncHandler(comptime H: type, comptime L: type) type {
};
switch (status) {
.wait => {
// Happens when we're transitioning from handshaking to
// sending the request. Don't continue the read loop. Let
// the request get sent before we try to read again.
},
.need_more => self.receive(),
.done => {
const redirect = self.redirect orelse {
@@ -610,7 +624,7 @@ fn AsyncHandler(comptime H: type, comptime L: type) type {
}
if (done == true) {
return .need_more;
return .done;
}
// With chunked-encoding, it's possible that we we've only
@@ -638,8 +652,8 @@ fn AsyncHandler(comptime H: type, comptime L: type) type {
fn deinit(self: *Connection) void {
switch (self.protocol) {
.tls_client => |*tls_client| tls_client.deinit(),
.plain => {},
.tls_client => |*tls_client| tls_client.deinit(),
}
}
@@ -656,59 +670,14 @@ fn AsyncHandler(comptime H: type, comptime L: type) type {
handler.receive();
},
.plain => {
handler.state = .header;
// queue everything up
handler.state = .body;
const header = try handler.request.buildHeader();
return handler.send(header);
},
}
}
fn sent(self: *Connection, state: SendState) !void {
const handler = self.handler;
std.debug.assert(handler.state == state);
switch (self.protocol) {
.tls_client => |*tls_client| {
switch (state) {
.handshake => {
// Our send is complete, but it was part of the
// TLS handshake. This isn't data we need to
// worry about.
},
.header => {
// we WERE sending the header, but that's done
handler.state = .body;
if (handler.request.body) |body| {
try tls_client.send(body);
}
},
.body => {
// We've finished sending the body. For non TLS
// we'll start receiving. But here, for TLS,
// we started a receive loop as soon as the c
// connection was established.
},
}
},
.plain => {
switch (state) {
.handshake => unreachable,
.header => {
// we WERE sending the header, but that's done
handler.state = .body;
if (handler.request.body) |body| {
handler.send(body);
} else {
// No body? time to start reading the response
handler.receive();
}
},
.body => {
// we're done sending the body, time to start
// reading the response
handler.receive();
},
handler.send(header);
if (handler.request.body) |body| {
handler.send(body);
}
handler.receive();
},
}
}
@@ -723,6 +692,8 @@ fn AsyncHandler(comptime H: type, comptime L: type) type {
const pos = handler.read_pos;
const end = pos + n;
const is_handshaking = handler.state == .handshake;
const used = tls_client.onRecv(read_buf[0..end]) catch |err| switch (err) {
// https://github.com/ianic/tls.zig/pull/9
// we currently have no way to break out of the TLS handling
@@ -738,30 +709,79 @@ fn AsyncHandler(comptime H: type, comptime L: type) type {
if (used == end) {
// 1 - It used up all the data that we gave it
handler.read_pos = 0;
} else if (used == 0) {
if (is_handshaking and handler.state == .header) {
// we're transitioning from handshaking to
// sending the request. We should not be
// receiving data right now. This is particularly
// important becuase our socket is currently in
// blocking mode (until promise resolution is
// complete). If we try to receive now, we'll
// block the loop
return .wait;
}
// If we're here, we're either still handshaking
// (in which case we need more data), or we
// we're reading the response and we need more data
// (else we would have gotten TLSHandlerDone)
return .need_more;
}
if (used == 0) {
// 2 - It didn't use any of the data (i.e there
// wasn't a full record)
handler.read_pos = end;
} else {
// 3 - It used some of the data, but had leftover
// (i.e. there was 1+ full records AND an incomplete
// record). We need to maintain the "leftover" data
// for subsequent reads.
const extra = end - used;
std.mem.copyForwards(u8, read_buf, read_buf[extra..end]);
handler.read_pos = extra;
return .need_more;
}
// 3 - It used some of the data, but had leftover
// (i.e. there was 1+ full records AND an incomplete
// record). We need to maintain the "leftover" data
// for subsequent reads.
// Remember that our read_buf is the MAX possible TLS
// record size. So as long as we make sure that the start
// of a record is at read_buf[0], we know that we'll
// always have enough space for 1 record.
const unused = end - used;
std.mem.copyForwards(u8, read_buf, read_buf[unused..end]);
handler.read_pos = unused;
// an incomplete record means there must be more data
return .need_more;
},
.plain => return handler.processData(read_buf[0..n]),
}
}
fn sent(self: *Connection) !void {
switch (self.protocol) {
.tls_client => |*tls_client| {
const handler = self.handler;
switch (handler.state) {
.handshake => {
// Our send is complete, but it was part of the
// TLS handshake. This isn't data we need to
// worry about.
},
.header => {
// we WERE sending the header, but that's done
handler.state = .body;
if (handler.request.body) |body| {
try tls_client.send(body);
} else {
// no body to send, start receiving the response
handler.receive();
}
},
.body => handler.receive(),
}
},
.plain => {
// For plain, we already queued the header, the body
// and the reader!
},
}
}
};
// Separate struct just to keep it a bit cleaner. tls.zig requires
@@ -774,12 +794,15 @@ fn AsyncHandler(comptime H: type, comptime L: type) type {
// Callback from tls.zig indicating that the handshake is complete
pub fn onConnect(self: TLSHandler) void {
var handler = self.handler;
handler.state = .header;
const header = handler.request.buildHeader() catch |err| {
return handler.handleError("out of memory", err);
};
handler.state = .header;
handler.connection.protocol.tls_client.send(header) catch |err| {
return handler.handleError("TLS send", err);
const tls_client = &handler.connection.protocol.tls_client;
tls_client.send(header) catch |err| {
return handler.handleError("TLS send header", err);
};
}
@@ -798,7 +821,9 @@ fn AsyncHandler(comptime H: type, comptime L: type) type {
handler.handleError("Premature server response", error.InvalidServerResonse);
return error.InvalidServerResonse;
}
switch (handler.processData(data)) {
.wait => unreachable, // processData never returns this
.need_more => {},
.done => return error.TLSHandlerDone, // https://github.com/ianic/tls.zig/pull/9
}
@@ -818,10 +843,13 @@ const SyncHandler = struct {
var connection: Connection = undefined;
if (request.secure) {
connection = .{ .tls = try tls.client(std.net.Stream{ .handle = socket }, .{
.host = request.host(),
.root_ca = request._client.root_ca,
}) };
connection = .{
.tls = try tls.client(std.net.Stream{ .handle = socket }, .{
.host = request.host(),
.root_ca = request._client.root_ca,
// .key_log_callback = tls.config.key_log.callback,
}),
};
} else {
connection = .{ .plain = socket };
}
@@ -1700,11 +1728,12 @@ test "HttpClient: sync with body" {
try testing.expectEqual("over 9000!", try res.next());
try testing.expectEqual(201, res.header.status);
try testing.expectEqual(4, res.header.count());
try testing.expectEqual(5, res.header.count());
try testing.expectEqual("close", res.header.get("connection"));
try testing.expectEqual("10", res.header.get("content-length"));
try testing.expectEqual("127.0.0.1", res.header.get("_host"));
try testing.expectEqual("Close", res.header.get("_connection"));
try testing.expectEqual("Lightpanda/1.0", res.header.get("_user-agent"));
}
test "HttpClient: sync tls with body" {
@@ -1750,11 +1779,12 @@ test "HttpClient: sync redirect from TLS to Plaintext" {
arr.appendSliceAssumeCapacity(data);
}
try testing.expectEqual(201, res.header.status);
try testing.expectEqual(4, res.header.count());
try testing.expectEqual(5, res.header.count());
try testing.expectEqual("close", res.header.get("connection"));
try testing.expectEqual("10", res.header.get("content-length"));
try testing.expectEqual("127.0.0.1", res.header.get("_host"));
try testing.expectEqual("Close", res.header.get("_connection"));
try testing.expectEqual("Lightpanda/1.0", res.header.get("_user-agent"));
}
}
@@ -1792,11 +1822,12 @@ test "HttpClient: sync GET redirect" {
try testing.expectEqual("over 9000!", try res.next());
try testing.expectEqual(201, res.header.status);
try testing.expectEqual(4, res.header.count());
try testing.expectEqual(5, res.header.count());
try testing.expectEqual("close", res.header.get("connection"));
try testing.expectEqual("10", res.header.get("content-length"));
try testing.expectEqual("127.0.0.1", res.header.get("_host"));
try testing.expectEqual("Close", res.header.get("_connection"));
try testing.expectEqual("Lightpanda/1.0", res.header.get("_user-agent"));
}
test "HttpClient: async connect error" {
@@ -1886,6 +1917,7 @@ test "HttpClient: async with body" {
"connection", "close",
"content-length", "10",
"_host", "127.0.0.1",
"_user-agent", "Lightpanda/1.0",
"_connection", "Close",
});
}
@@ -1917,6 +1949,7 @@ test "HttpClient: async redirect" {
"connection", "close",
"content-length", "10",
"_host", "127.0.0.1",
"_user-agent", "Lightpanda/1.0",
"_connection", "Close",
});
}