From 511e9b969a254073a4197363cb2139926e59132d Mon Sep 17 00:00:00 2001 From: Pierre Tachoire Date: Tue, 30 Jan 2024 08:57:42 +0100 Subject: [PATCH] async: use std http client with loop --- src/async/Client.zig | 55 ++++---------- src/async/stream.zig | 168 +++++++++++++++++++++++++++++++++++++++++++ src/async/tcp.zig | 62 ++++++++++++++++ src/async/test.zig | 60 ++++++++++++++++ src/run_tests.zig | 5 ++ 5 files changed, 309 insertions(+), 41 deletions(-) create mode 100644 src/async/stream.zig create mode 100644 src/async/tcp.zig create mode 100644 src/async/test.zig diff --git a/src/async/Client.zig b/src/async/Client.zig index 8438a7bb..3af59134 100644 --- a/src/async/Client.zig +++ b/src/async/Client.zig @@ -3,9 +3,13 @@ //! Connections are opened in a thread-safe manner, but individual Requests are not. //! //! TLS support may be disabled via `std.options.http_disable_tls`. +//! +//! This file is a copy of the original std.http.Client with little changes to +//! handle non-blocking I/O with the jsruntime.Loop. -const std = @import("../std.zig"); +const std = @import("std"); const builtin = @import("builtin"); +const Stream = @import("stream.zig").Stream; const testing = std.testing; const http = std.http; const mem = std.mem; @@ -16,7 +20,10 @@ const assert = std.debug.assert; const use_vectors = builtin.zig_backend != .stage2_x86_64; const Client = @This(); -const proto = @import("protocol.zig"); +const proto = http.protocol; + +const Loop = @import("jsruntime").Loop; +const tcp = @import("tcp.zig"); pub const disable_tls = std.options.http_disable_tls; @@ -25,6 +32,9 @@ pub const disable_tls = std.options.http_disable_tls; /// This allocator must be thread-safe. allocator: Allocator, +// std.net.Stream implementation using jsruntime Loop +loop: *Loop, + ca_bundle: if (disable_tls) void else std.crypto.Certificate.Bundle = if (disable_tls) {} else .{}, ca_bundle_mutex: std.Thread.Mutex = .{}, @@ -194,7 +204,7 @@ pub const Connection = struct { pub const Protocol = enum { plain, tls }; - stream: net.Stream, + stream: Stream, /// undefined unless protocol is tls. tls_client: if (!disable_tls) *std.crypto.tls.Client else void, @@ -1210,7 +1220,7 @@ pub fn connectTcp(client: *Client, host: []const u8, port: u16, protocol: Connec errdefer client.allocator.destroy(conn); conn.* = .{ .data = undefined }; - const stream = net.tcpConnectToHost(client.allocator, host, port) catch |err| switch (err) { + const stream = tcp.tcpConnectToHost(client.allocator, client.loop, host, port) catch |err| switch (err) { error.ConnectionRefused => return error.ConnectionRefused, error.NetworkUnreachable => return error.NetworkUnreachable, error.ConnectionTimedOut => return error.ConnectionTimedOut, @@ -1250,43 +1260,6 @@ pub fn connectTcp(client: *Client, host: []const u8, port: u16, protocol: Connec return &conn.data; } -pub const ConnectUnixError = Allocator.Error || std.os.SocketError || error{ NameTooLong, Unsupported } || std.os.ConnectError; - -/// Connect to `path` as a unix domain socket. This will reuse a connection if one is already open. -/// -/// This function is threadsafe. -pub fn connectUnix(client: *Client, path: []const u8) ConnectUnixError!*Connection { - if (!net.has_unix_sockets) return error.Unsupported; - - if (client.connection_pool.findConnection(.{ - .host = path, - .port = 0, - .protocol = .plain, - })) |node| - return node; - - const conn = try client.allocator.create(ConnectionPool.Node); - errdefer client.allocator.destroy(conn); - conn.* = .{ .data = undefined }; - - const stream = try std.net.connectUnixSocket(path); - errdefer stream.close(); - - conn.data = .{ - .stream = stream, - .tls_client = undefined, - .protocol = .plain, - - .host = try client.allocator.dupe(u8, path), - .port = 0, - }; - errdefer client.allocator.free(conn.data.host); - - client.connection_pool.addUsed(conn); - - return &conn.data; -} - /// Connect to `tunnel_host:tunnel_port` using the specified proxy with HTTP CONNECT. This will reuse a connection if one is already open. /// /// This function is threadsafe. diff --git a/src/async/stream.zig b/src/async/stream.zig new file mode 100644 index 00000000..e1f2537f --- /dev/null +++ b/src/async/stream.zig @@ -0,0 +1,168 @@ +const std = @import("std"); +const builtin = @import("builtin"); +const os = std.os; +const io = std.io; +const assert = std.debug.assert; + +const Loop = @import("jsruntime").Loop; + +const WriteCmd = struct { + const Self = @This(); + + stream: Stream, + done: bool = false, + res: usize = undefined, + err: ?anyerror = null, + + fn run(self: *Self, buffer: []const u8) void { + self.stream.loop.send(*Self, self, callback, self.stream.handle, buffer); + } + + fn callback(self: *Self, err: ?anyerror, res: usize) void { + self.res = res; + self.err = err; + self.done = true; + } + + fn wait(self: *Self) !usize { + while (!self.done) try self.stream.loop.tick(); + if (self.err) |err| return err; + return self.res; + } +}; + +const ReadCmd = struct { + const Self = @This(); + + stream: Stream, + done: bool = false, + res: usize = undefined, + err: ?anyerror = null, + + fn run(self: *Self, buffer: []u8) void { + self.stream.loop.receive(*Self, self, callback, self.stream.handle, buffer); + } + + fn callback(self: *Self, _: []const u8, err: ?anyerror, res: usize) void { + self.res = res; + self.err = err; + self.done = true; + } + + fn wait(self: *Self) !usize { + while (!self.done) try self.stream.loop.tick(); + if (self.err) |err| return err; + return self.res; + } +}; + +pub const Stream = struct { + loop: *Loop, + + handle: std.os.socket_t, + + pub fn close(self: Stream) void { + os.closeSocket(self.handle); + } + + pub const ReadError = os.ReadError; + pub const WriteError = os.WriteError; + + pub const Reader = io.Reader(Stream, ReadError, read); + pub const Writer = io.Writer(Stream, WriteError, write); + + pub fn reader(self: Stream) Reader { + return .{ .context = self }; + } + + pub fn writer(self: Stream) Writer { + return .{ .context = self }; + } + + pub fn read(self: Stream, buffer: []u8) ReadError!usize { + var cmd = ReadCmd{ .stream = self }; + cmd.run(buffer); + return cmd.wait() catch |err| switch (err) { + else => return error.Unexpected, + }; + } + + pub fn readv(s: Stream, iovecs: []const os.iovec) ReadError!usize { + return os.readv(s.handle, iovecs); + } + + /// Returns the number of bytes read. If the number read is smaller than + /// `buffer.len`, it means the stream reached the end. Reaching the end of + /// a stream is not an error condition. + pub fn readAll(s: Stream, buffer: []u8) ReadError!usize { + return readAtLeast(s, buffer, buffer.len); + } + + /// Returns the number of bytes read, calling the underlying read function + /// the minimal number of times until the buffer has at least `len` bytes + /// filled. If the number read is less than `len` it means the stream + /// reached the end. Reaching the end of the stream is not an error + /// condition. + pub fn readAtLeast(s: Stream, buffer: []u8, len: usize) ReadError!usize { + assert(len <= buffer.len); + var index: usize = 0; + while (index < len) { + const amt = try s.read(buffer[index..]); + if (amt == 0) break; + index += amt; + } + return index; + } + + /// TODO in evented I/O mode, this implementation incorrectly uses the event loop's + /// file system thread instead of non-blocking. It needs to be reworked to properly + /// use non-blocking I/O. + pub fn write(self: Stream, buffer: []const u8) WriteError!usize { + var cmd = WriteCmd{ .stream = self }; + cmd.run(buffer); + + return cmd.wait() catch |err| switch (err) { + error.AccessDenied => error.AccessDenied, + error.WouldBlock => error.WouldBlock, + error.ConnectionResetByPeer => error.ConnectionResetByPeer, + error.MessageTooBig => error.FileTooBig, + error.BrokenPipe => error.BrokenPipe, + else => return error.Unexpected, + }; + } + + pub fn writeAll(self: Stream, bytes: []const u8) WriteError!void { + var index: usize = 0; + while (index < bytes.len) { + index += try self.write(bytes[index..]); + } + } + + /// See https://github.com/ziglang/zig/issues/7699 + /// See equivalent function: `std.fs.File.writev`. + pub fn writev(self: Stream, iovecs: []const os.iovec_const) WriteError!usize { + if (iovecs.len == 0) return 0; + const first_buffer = iovecs[0].iov_base[0..iovecs[0].iov_len]; + return try self.write(first_buffer); + } + + /// The `iovecs` parameter is mutable because this function needs to mutate the fields in + /// order to handle partial writes from the underlying OS layer. + /// See https://github.com/ziglang/zig/issues/7699 + /// See equivalent function: `std.fs.File.writevAll`. + pub fn writevAll(self: Stream, iovecs: []os.iovec_const) WriteError!void { + if (iovecs.len == 0) return; + + var i: usize = 0; + while (true) { + var amt = try self.writev(iovecs[i..]); + while (amt >= iovecs[i].iov_len) { + amt -= iovecs[i].iov_len; + i += 1; + if (i >= iovecs.len) return; + } + iovecs[i].iov_base += amt; + iovecs[i].iov_len -= amt; + } + } +}; diff --git a/src/async/tcp.zig b/src/async/tcp.zig new file mode 100644 index 00000000..a1a7f5b6 --- /dev/null +++ b/src/async/tcp.zig @@ -0,0 +1,62 @@ +const std = @import("std"); +const net = std.net; +const Stream = @import("stream.zig").Stream; +const Loop = @import("jsruntime").Loop; + +const ConnectCmd = struct { + const Self = @This(); + + loop: *Loop, + socket: std.os.socket_t, + err: ?anyerror = null, + done: bool = false, + + fn run(self: *Self, addr: std.net.Address) !void { + self.loop.connect(*Self, self, callback, self.socket, addr); + } + + fn callback(self: *Self, _: std.os.socket_t, err: ?anyerror) void { + self.err = err; + self.done = true; + } + + fn wait(self: *Self) !void { + while (!self.done) try self.loop.tick(); + if (self.err) |err| return err; + } +}; + +pub fn tcpConnectToHost(alloc: std.mem.Allocator, loop: *Loop, name: []const u8, port: u16) !Stream { + // TODO async resolve + const list = try net.getAddressList(alloc, name, port); + defer list.deinit(); + + if (list.addrs.len == 0) return error.UnknownHostName; + + for (list.addrs) |addr| { + return tcpConnectToAddress(loop, addr) catch |err| switch (err) { + error.ConnectionRefused => { + continue; + }, + else => return err, + }; + } + return std.os.ConnectError.ConnectionRefused; +} + +pub fn tcpConnectToAddress(loop: *Loop, addr: net.Address) !Stream { + const sockfd = try loop.open(addr.any.family, std.os.SOCK.STREAM, std.os.IPPROTO.TCP); + errdefer std.os.closeSocket(sockfd); + + var cmd = ConnectCmd{ + .loop = loop, + .socket = sockfd, + }; + try cmd.run(addr); + try cmd.wait(); + + return Stream{ + .loop = loop, + .handle = sockfd, + }; +} diff --git a/src/async/test.zig b/src/async/test.zig new file mode 100644 index 00000000..10081e9e --- /dev/null +++ b/src/async/test.zig @@ -0,0 +1,60 @@ +const std = @import("std"); +const http = std.http; +const StdClient = @import("Client.zig"); +// const hasync = @import("http.zig"); + +pub const Loop = @import("jsruntime").Loop; + +const url = "https://www.w3.org/"; + +test "blocking mode fetch API" { + const alloc = std.testing.allocator; + + var loop = try Loop.init(alloc); + defer loop.deinit(); + + var client: StdClient = .{ + .allocator = alloc, + .loop = &loop, + }; + defer client.deinit(); + + // force client's CA cert scan from system. + try client.ca_bundle.rescan(client.allocator); + + var res = try client.fetch(alloc, .{ + .location = .{ .uri = try std.Uri.parse(url) }, + .payload = .none, + }); + defer res.deinit(); + + try std.testing.expect(res.status == .ok); +} + +test "blocking mode open/send/wait API" { + const alloc = std.testing.allocator; + + var loop = try Loop.init(alloc); + defer loop.deinit(); + + var client: StdClient = .{ + .allocator = alloc, + .loop = &loop, + }; + defer client.deinit(); + + // force client's CA cert scan from system. + try client.ca_bundle.rescan(client.allocator); + + var headers = try std.http.Headers.initList(alloc, &[_]std.http.Field{}); + defer headers.deinit(); + + var req = try client.open(.GET, try std.Uri.parse(url), headers, .{}); + defer req.deinit(); + + try req.send(.{}); + try req.finish(); + try req.wait(); + + try std.testing.expect(req.response.status == .ok); +} diff --git a/src/run_tests.zig b/src/run_tests.zig index dea89d6f..592bc08e 100644 --- a/src/run_tests.zig +++ b/src/run_tests.zig @@ -94,6 +94,11 @@ pub fn main() !void { } } +test { + const TestAsync = @import("async/test.zig"); + std.testing.refAllDecls(TestAsync); +} + test "jsruntime" { // generate tests try generate.tests();