async: use std http client with loop

This commit is contained in:
Pierre Tachoire
2024-01-30 08:57:42 +01:00
parent 9d26a43aa8
commit 511e9b969a
5 changed files with 309 additions and 41 deletions

View File

@@ -3,9 +3,13 @@
//! Connections are opened in a thread-safe manner, but individual Requests are not.
//!
//! TLS support may be disabled via `std.options.http_disable_tls`.
//!
//! This file is a copy of the original std.http.Client with little changes to
//! handle non-blocking I/O with the jsruntime.Loop.
const std = @import("../std.zig");
const std = @import("std");
const builtin = @import("builtin");
const Stream = @import("stream.zig").Stream;
const testing = std.testing;
const http = std.http;
const mem = std.mem;
@@ -16,7 +20,10 @@ const assert = std.debug.assert;
const use_vectors = builtin.zig_backend != .stage2_x86_64;
const Client = @This();
const proto = @import("protocol.zig");
const proto = http.protocol;
const Loop = @import("jsruntime").Loop;
const tcp = @import("tcp.zig");
pub const disable_tls = std.options.http_disable_tls;
@@ -25,6 +32,9 @@ pub const disable_tls = std.options.http_disable_tls;
/// This allocator must be thread-safe.
allocator: Allocator,
// std.net.Stream implementation using jsruntime Loop
loop: *Loop,
ca_bundle: if (disable_tls) void else std.crypto.Certificate.Bundle = if (disable_tls) {} else .{},
ca_bundle_mutex: std.Thread.Mutex = .{},
@@ -194,7 +204,7 @@ pub const Connection = struct {
pub const Protocol = enum { plain, tls };
stream: net.Stream,
stream: Stream,
/// undefined unless protocol is tls.
tls_client: if (!disable_tls) *std.crypto.tls.Client else void,
@@ -1210,7 +1220,7 @@ pub fn connectTcp(client: *Client, host: []const u8, port: u16, protocol: Connec
errdefer client.allocator.destroy(conn);
conn.* = .{ .data = undefined };
const stream = net.tcpConnectToHost(client.allocator, host, port) catch |err| switch (err) {
const stream = tcp.tcpConnectToHost(client.allocator, client.loop, host, port) catch |err| switch (err) {
error.ConnectionRefused => return error.ConnectionRefused,
error.NetworkUnreachable => return error.NetworkUnreachable,
error.ConnectionTimedOut => return error.ConnectionTimedOut,
@@ -1250,43 +1260,6 @@ pub fn connectTcp(client: *Client, host: []const u8, port: u16, protocol: Connec
return &conn.data;
}
pub const ConnectUnixError = Allocator.Error || std.os.SocketError || error{ NameTooLong, Unsupported } || std.os.ConnectError;
/// Connect to `path` as a unix domain socket. This will reuse a connection if one is already open.
///
/// This function is threadsafe.
pub fn connectUnix(client: *Client, path: []const u8) ConnectUnixError!*Connection {
if (!net.has_unix_sockets) return error.Unsupported;
if (client.connection_pool.findConnection(.{
.host = path,
.port = 0,
.protocol = .plain,
})) |node|
return node;
const conn = try client.allocator.create(ConnectionPool.Node);
errdefer client.allocator.destroy(conn);
conn.* = .{ .data = undefined };
const stream = try std.net.connectUnixSocket(path);
errdefer stream.close();
conn.data = .{
.stream = stream,
.tls_client = undefined,
.protocol = .plain,
.host = try client.allocator.dupe(u8, path),
.port = 0,
};
errdefer client.allocator.free(conn.data.host);
client.connection_pool.addUsed(conn);
return &conn.data;
}
/// Connect to `tunnel_host:tunnel_port` using the specified proxy with HTTP CONNECT. This will reuse a connection if one is already open.
///
/// This function is threadsafe.

168
src/async/stream.zig Normal file
View File

@@ -0,0 +1,168 @@
const std = @import("std");
const builtin = @import("builtin");
const os = std.os;
const io = std.io;
const assert = std.debug.assert;
const Loop = @import("jsruntime").Loop;
const WriteCmd = struct {
const Self = @This();
stream: Stream,
done: bool = false,
res: usize = undefined,
err: ?anyerror = null,
fn run(self: *Self, buffer: []const u8) void {
self.stream.loop.send(*Self, self, callback, self.stream.handle, buffer);
}
fn callback(self: *Self, err: ?anyerror, res: usize) void {
self.res = res;
self.err = err;
self.done = true;
}
fn wait(self: *Self) !usize {
while (!self.done) try self.stream.loop.tick();
if (self.err) |err| return err;
return self.res;
}
};
const ReadCmd = struct {
const Self = @This();
stream: Stream,
done: bool = false,
res: usize = undefined,
err: ?anyerror = null,
fn run(self: *Self, buffer: []u8) void {
self.stream.loop.receive(*Self, self, callback, self.stream.handle, buffer);
}
fn callback(self: *Self, _: []const u8, err: ?anyerror, res: usize) void {
self.res = res;
self.err = err;
self.done = true;
}
fn wait(self: *Self) !usize {
while (!self.done) try self.stream.loop.tick();
if (self.err) |err| return err;
return self.res;
}
};
pub const Stream = struct {
loop: *Loop,
handle: std.os.socket_t,
pub fn close(self: Stream) void {
os.closeSocket(self.handle);
}
pub const ReadError = os.ReadError;
pub const WriteError = os.WriteError;
pub const Reader = io.Reader(Stream, ReadError, read);
pub const Writer = io.Writer(Stream, WriteError, write);
pub fn reader(self: Stream) Reader {
return .{ .context = self };
}
pub fn writer(self: Stream) Writer {
return .{ .context = self };
}
pub fn read(self: Stream, buffer: []u8) ReadError!usize {
var cmd = ReadCmd{ .stream = self };
cmd.run(buffer);
return cmd.wait() catch |err| switch (err) {
else => return error.Unexpected,
};
}
pub fn readv(s: Stream, iovecs: []const os.iovec) ReadError!usize {
return os.readv(s.handle, iovecs);
}
/// Returns the number of bytes read. If the number read is smaller than
/// `buffer.len`, it means the stream reached the end. Reaching the end of
/// a stream is not an error condition.
pub fn readAll(s: Stream, buffer: []u8) ReadError!usize {
return readAtLeast(s, buffer, buffer.len);
}
/// Returns the number of bytes read, calling the underlying read function
/// the minimal number of times until the buffer has at least `len` bytes
/// filled. If the number read is less than `len` it means the stream
/// reached the end. Reaching the end of the stream is not an error
/// condition.
pub fn readAtLeast(s: Stream, buffer: []u8, len: usize) ReadError!usize {
assert(len <= buffer.len);
var index: usize = 0;
while (index < len) {
const amt = try s.read(buffer[index..]);
if (amt == 0) break;
index += amt;
}
return index;
}
/// TODO in evented I/O mode, this implementation incorrectly uses the event loop's
/// file system thread instead of non-blocking. It needs to be reworked to properly
/// use non-blocking I/O.
pub fn write(self: Stream, buffer: []const u8) WriteError!usize {
var cmd = WriteCmd{ .stream = self };
cmd.run(buffer);
return cmd.wait() catch |err| switch (err) {
error.AccessDenied => error.AccessDenied,
error.WouldBlock => error.WouldBlock,
error.ConnectionResetByPeer => error.ConnectionResetByPeer,
error.MessageTooBig => error.FileTooBig,
error.BrokenPipe => error.BrokenPipe,
else => return error.Unexpected,
};
}
pub fn writeAll(self: Stream, bytes: []const u8) WriteError!void {
var index: usize = 0;
while (index < bytes.len) {
index += try self.write(bytes[index..]);
}
}
/// See https://github.com/ziglang/zig/issues/7699
/// See equivalent function: `std.fs.File.writev`.
pub fn writev(self: Stream, iovecs: []const os.iovec_const) WriteError!usize {
if (iovecs.len == 0) return 0;
const first_buffer = iovecs[0].iov_base[0..iovecs[0].iov_len];
return try self.write(first_buffer);
}
/// The `iovecs` parameter is mutable because this function needs to mutate the fields in
/// order to handle partial writes from the underlying OS layer.
/// See https://github.com/ziglang/zig/issues/7699
/// See equivalent function: `std.fs.File.writevAll`.
pub fn writevAll(self: Stream, iovecs: []os.iovec_const) WriteError!void {
if (iovecs.len == 0) return;
var i: usize = 0;
while (true) {
var amt = try self.writev(iovecs[i..]);
while (amt >= iovecs[i].iov_len) {
amt -= iovecs[i].iov_len;
i += 1;
if (i >= iovecs.len) return;
}
iovecs[i].iov_base += amt;
iovecs[i].iov_len -= amt;
}
}
};

62
src/async/tcp.zig Normal file
View File

@@ -0,0 +1,62 @@
const std = @import("std");
const net = std.net;
const Stream = @import("stream.zig").Stream;
const Loop = @import("jsruntime").Loop;
const ConnectCmd = struct {
const Self = @This();
loop: *Loop,
socket: std.os.socket_t,
err: ?anyerror = null,
done: bool = false,
fn run(self: *Self, addr: std.net.Address) !void {
self.loop.connect(*Self, self, callback, self.socket, addr);
}
fn callback(self: *Self, _: std.os.socket_t, err: ?anyerror) void {
self.err = err;
self.done = true;
}
fn wait(self: *Self) !void {
while (!self.done) try self.loop.tick();
if (self.err) |err| return err;
}
};
pub fn tcpConnectToHost(alloc: std.mem.Allocator, loop: *Loop, name: []const u8, port: u16) !Stream {
// TODO async resolve
const list = try net.getAddressList(alloc, name, port);
defer list.deinit();
if (list.addrs.len == 0) return error.UnknownHostName;
for (list.addrs) |addr| {
return tcpConnectToAddress(loop, addr) catch |err| switch (err) {
error.ConnectionRefused => {
continue;
},
else => return err,
};
}
return std.os.ConnectError.ConnectionRefused;
}
pub fn tcpConnectToAddress(loop: *Loop, addr: net.Address) !Stream {
const sockfd = try loop.open(addr.any.family, std.os.SOCK.STREAM, std.os.IPPROTO.TCP);
errdefer std.os.closeSocket(sockfd);
var cmd = ConnectCmd{
.loop = loop,
.socket = sockfd,
};
try cmd.run(addr);
try cmd.wait();
return Stream{
.loop = loop,
.handle = sockfd,
};
}

60
src/async/test.zig Normal file
View File

@@ -0,0 +1,60 @@
const std = @import("std");
const http = std.http;
const StdClient = @import("Client.zig");
// const hasync = @import("http.zig");
pub const Loop = @import("jsruntime").Loop;
const url = "https://www.w3.org/";
test "blocking mode fetch API" {
const alloc = std.testing.allocator;
var loop = try Loop.init(alloc);
defer loop.deinit();
var client: StdClient = .{
.allocator = alloc,
.loop = &loop,
};
defer client.deinit();
// force client's CA cert scan from system.
try client.ca_bundle.rescan(client.allocator);
var res = try client.fetch(alloc, .{
.location = .{ .uri = try std.Uri.parse(url) },
.payload = .none,
});
defer res.deinit();
try std.testing.expect(res.status == .ok);
}
test "blocking mode open/send/wait API" {
const alloc = std.testing.allocator;
var loop = try Loop.init(alloc);
defer loop.deinit();
var client: StdClient = .{
.allocator = alloc,
.loop = &loop,
};
defer client.deinit();
// force client's CA cert scan from system.
try client.ca_bundle.rescan(client.allocator);
var headers = try std.http.Headers.initList(alloc, &[_]std.http.Field{});
defer headers.deinit();
var req = try client.open(.GET, try std.Uri.parse(url), headers, .{});
defer req.deinit();
try req.send(.{});
try req.finish();
try req.wait();
try std.testing.expect(req.response.status == .ok);
}

View File

@@ -94,6 +94,11 @@ pub fn main() !void {
}
}
test {
const TestAsync = @import("async/test.zig");
std.testing.refAllDecls(TestAsync);
}
test "jsruntime" {
// generate tests
try generate.tests();