From e2cc40457177d9da2bc1cb3423ce9daa5d949d14 Mon Sep 17 00:00:00 2001 From: Francis Bouvier Date: Sun, 29 Jun 2025 17:58:19 -0700 Subject: [PATCH] Handle TLS proxy, both for HTTP and HTTPS (tls in tls) endpoints --- src/http/client.zig | 150 +++++++++++++++++++++++++------------------- 1 file changed, 87 insertions(+), 63 deletions(-) diff --git a/src/http/client.zig b/src/http/client.zig index 73aef555..b3935ff9 100644 --- a/src/http/client.zig +++ b/src/http/client.zig @@ -240,6 +240,11 @@ pub const Client = struct { const proxy_type = self.proxy_type orelse return false; return proxy_type == .forward; } + + fn isProxyTLS(self: *const Client) bool { + const proxy = self.http_proxy orelse return false; + return std.mem.eql(u8, proxy.scheme, "https"); + } }; const RequestOpts = struct { @@ -331,10 +336,10 @@ const Connection = struct { fn close(self: *TLSClient) void { switch (self.*) { .blocking => |*tls_client| tls_client.close() catch {}, - .blocking_tls_in_tls => {}, // |*tls_in_tls| { - // tls_in_tls.destination.close() catch {}; // Crashes - // tls_in_tls.proxy.close() catch {}; - // }, + .blocking_tls_in_tls => |*tls_in_tls| { + tls_in_tls.destination.close() catch {}; + tls_in_tls.proxy.close() catch {}; + }, .nonblocking => {}, } } @@ -535,6 +540,7 @@ pub const Request = struct { } const is_connect_proxy = client.isConnectProxy(); + const is_proxy_tls = client.isProxyTLS(); var secure: bool = undefined; const scheme = if (is_connect_proxy) uri.scheme else connect_uri.scheme; @@ -546,7 +552,11 @@ pub const Request = struct { return error.UnsupportedUriScheme; } const request_port: u16 = uri.port orelse if (secure) 443 else 80; - const connect_port: u16 = connect_uri.port orelse (if (is_connect_proxy) 80 else request_port); + const connect_port: u16 = connect_uri.port orelse blk: { + if (is_connect_proxy) { + if (is_proxy_tls) break :blk 443 else break :blk 80; + } else break :blk request_port; + }; return .{ .secure = secure, @@ -663,36 +673,58 @@ pub const Request = struct { }; self._connection = connection; + const tls_config = tls.config.Client{ + .host = self._request_host, + .root_ca = self._client.root_ca, + .insecure_skip_verify = self._tls_verify_host == false, + // .key_log_callback = tls.config.key_log.callback, + }; + + // proxy const is_connect_proxy = self._client.isConnectProxy(); + const is_proxy_tls = self._client.isProxyTLS(); + if (is_connect_proxy) { - var connect_connection = try SyncHandler.connect(self); - if (self._secure) { // TODO separate _secure for proxy and desination - const tls_in_tls = try tls.client(&connect_connection, .{ - .host = self._request_host, - .root_ca = self._client.root_ca, - .insecure_skip_verify = self._tls_verify_host == false, - // .key_log_callback = tls.config.key_log.callback, - }); - self._connection.?.tls = .{ - .blocking_tls_in_tls = .{ - .proxy = connect_connection, - .destination = tls_in_tls, - }, - }; + var proxy_conn: SyncHandler.Conn = .{ .plain = self._connection.?.socket }; + + if (is_proxy_tls) { + + // create an underlying TLS stream with the proxy + var proxy_tls_config = tls_config; + proxy_tls_config.host = self._connect_host; + var proxy_conn_tls = try tls.client(std.net.Stream{ .handle = socket }, proxy_tls_config); + proxy_conn = .{ .tls = &proxy_conn_tls }; } - } else { - if (self._secure) { - self._connection.?.tls = .{ - .blocking = try tls.client(std.net.Stream{ .handle = socket }, .{ - .host = if (is_connect_proxy) self._request_host else self._connect_host, - .root_ca = self._client.root_ca, - .insecure_skip_verify = self._tls_verify_host == false, - // .key_log_callback = tls.config.key_log.callback, - }), - }; + + // connect to the proxy + try SyncHandler.connect(self, &proxy_conn); + + if (is_proxy_tls) { + if (self._secure) { + + // if secure endpoint, create the main TLS stream + // encapsulated into the TLS stream proxy + const tls_in_tls = try tls.client(proxy_conn.tls, tls_config); + self._connection.?.tls = .{ + .blocking_tls_in_tls = .{ + .proxy = proxy_conn.tls.*, + .destination = tls_in_tls, + }, + }; + } else { + + // otherwise, just use the TLS stream proxy + self._connection.?.tls = .{ .blocking = proxy_conn.tls.* }; + } } } + if (self._secure and !is_proxy_tls) { + self._connection.?.tls = .{ + .blocking = try tls.client(std.net.Stream{ .handle = socket }, tls_config), + }; + } + self._connection_from_keepalive = false; } @@ -1836,18 +1868,9 @@ const SyncHandler = struct { // Unfortunately, this is called from the Request doSendSync since we need // to do this before setting up our TLS connection. - fn connect(request: *Request) !tls.Connection(std.net.Stream) { - const socket = request._connection.?.socket; - + fn connect(request: *Request, conn: *Conn) !void { const header = try request.buildConnectHeader(); - // try Conn.writeAll(socket, header); - var tls_client = try tls.client(std.net.Stream{ .handle = socket }, .{ - .host = request._connect_host, - .root_ca = request._client.root_ca, - .insecure_skip_verify = request._tls_verify_host == false, - .key_log_callback = tls.config.key_log.callback, - }); - try tls_client.writeAll(header); + try conn.writeAll(header); var pos: usize = 0; var reader = request.newReader(); @@ -1858,24 +1881,19 @@ const SyncHandler = struct { // we only send CONNECT requests on newly established connections // and maybeRetryOrErr is only for connections that might have been // closed while being kept-alive - // const n = try posix.read(socket, read_buf[pos..]); - // const n = switch (self.*) { - // .tls => |tls_client| try tls_client.read(buf), - // .plain => |socket| try posix.read(socket, buf), - // }; - const n = try tls_client.read(read_buf[pos..]); + const n = try conn.read(read_buf[pos..]); if (n == 0) { return error.ConnectionResetByPeer; } pos += n; if (try reader.connectResponse(read_buf[0..pos])) { // returns true if we have a successful connect response - return tls_client; + return; } // we don't have enough data yet. } - return tls_client; + return; } fn maybeRetryOrErr(self: *SyncHandler, err: anyerror) !Response { @@ -1931,16 +1949,16 @@ const SyncHandler = struct { fn sendRequest(self: *Conn, header: []const u8, body: ?[]const u8) !void { switch (self.*) { - .tls_in_tls => |tls_client| { - try tls_client.writeAll(header); + .tls => |_| { + try self.writeAll(header); if (body) |b| { - try tls_client.writeAll(b); + try self.writeAll(b); } }, - .tls => |tls_client| { - try tls_client.writeAll(header); + .tls_in_tls => |_| { + try self.writeAll(header); if (body) |b| { - try tls_client.writeAll(b); + try self.writeAll(b); } }, .plain => |socket| { @@ -1951,15 +1969,15 @@ const SyncHandler = struct { }; return writeAllIOVec(socket, &vec); } - return writeAll(socket, header); + return self.writeAll(header); }, } } fn read(self: *Conn, buf: []u8) !usize { const n = switch (self.*) { - .tls_in_tls => |tls_client| try tls_client.read(buf), .tls => |tls_client| try tls_client.read(buf), + .tls_in_tls => |tls_client| try tls_client.read(buf), .plain => |socket| try posix.read(socket, buf), }; if (n == 0) { @@ -1968,6 +1986,19 @@ const SyncHandler = struct { return n; } + fn writeAll(self: *Conn, data: []const u8) !void { + switch (self.*) { + .tls => |tls_client| try tls_client.writeAll(data), + .tls_in_tls => |tls_client| try tls_client.writeAll(data), + .plain => |socket| { + var i: usize = 0; + while (i < data.len) { + i += try posix.write(socket, data[i..]); + } + }, + } + } + fn writeAllIOVec(socket: posix.socket_t, vec: []posix.iovec_const) !void { var i: usize = 0; while (true) { @@ -1983,13 +2014,6 @@ const SyncHandler = struct { vec[i].len -= n; } } - - fn writeAll(socket: posix.socket_t, data: []const u8) !void { - var i: usize = 0; - while (i < data.len) { - i += try posix.write(socket, data[i..]); - } - } }; // We don't ask for encoding, but some providers (CloudFront!!)