diff --git a/.gitmodules b/.gitmodules index 1dcd0195..2ceea970 100644 --- a/.gitmodules +++ b/.gitmodules @@ -22,3 +22,6 @@ [submodule "vendor/mimalloc"] path = vendor/mimalloc url = git@github.com:microsoft/mimalloc.git +[submodule "vendor/tls.zig"] + path = vendor/tls.zig + url = git@github.com:ianic/tls.zig.git diff --git a/build.zig b/build.zig index c6b65029..75c6b396 100644 --- a/build.zig +++ b/build.zig @@ -179,6 +179,11 @@ fn common( const netsurf = moduleNetSurf(b); netsurf.addImport("jsruntime", jsruntimemod); step.root_module.addImport("netsurf", netsurf); + + const tlsmod = b.addModule("tls", .{ + .root_source_file = b.path("vendor/tls.zig/src/main.zig"), + }); + step.root_module.addImport("tls", tlsmod); } fn moduleNetSurf(b: *std.Build) *std.Build.Module { diff --git a/src/async/Client.zig b/src/async/Client.zig index b6fe6f04..91748b9a 100644 --- a/src/async/Client.zig +++ b/src/async/Client.zig @@ -35,7 +35,9 @@ const assert = std.debug.assert; const use_vectors = builtin.zig_backend != .stage2_x86_64; const Client = @This(); -const proto = http.protocol; +const proto = std.http.protocol; + +const tls23 = @import("tls"); const Loop = @import("jsruntime").Loop; const tcp = @import("tcp.zig"); @@ -217,7 +219,7 @@ pub const ConnectionPool = struct { pub const Connection = struct { stream: Stream, /// undefined unless protocol is tls. - tls_client: if (!disable_tls) *std.crypto.tls.Client else void, + tls_client: if (!disable_tls) *tls23.Connection(Stream) else void, /// The protocol that this connection is using. protocol: Protocol, @@ -246,12 +248,12 @@ pub const Connection = struct { pub const Protocol = enum { plain, tls }; pub fn readvDirectTls(conn: *Connection, buffers: []std.posix.iovec) ReadError!usize { - return conn.tls_client.readv(conn.stream, buffers) catch |err| { + return conn.tls_client.readv(buffers) catch |err| { // https://github.com/ziglang/zig/issues/2473 if (mem.startsWith(u8, @errorName(err), "TlsAlert")) return error.TlsAlert; switch (err) { - error.TlsConnectionTruncated, error.TlsRecordOverflow, error.TlsDecodeError, error.TlsBadRecordMac, error.TlsBadLength, error.TlsIllegalParameter, error.TlsUnexpectedMessage => return error.TlsFailure, + error.TlsRecordOverflow, error.TlsBadRecordMac, error.TlsUnexpectedMessage => return error.TlsFailure, error.ConnectionTimedOut => return error.ConnectionTimedOut, error.ConnectionResetByPeer, error.BrokenPipe => return error.ConnectionResetByPeer, else => return error.UnexpectedReadFailure, @@ -344,7 +346,7 @@ pub const Connection = struct { } pub fn writeAllDirectTls(conn: *Connection, buffer: []const u8) WriteError!void { - return conn.tls_client.writeAll(conn.stream, buffer) catch |err| switch (err) { + return conn.tls_client.writeAll(buffer) catch |err| switch (err) { error.BrokenPipe, error.ConnectionResetByPeer => return error.ConnectionResetByPeer, else => return error.UnexpectedWriteFailure, }; @@ -412,7 +414,7 @@ pub const Connection = struct { if (disable_tls) unreachable; // try to cleanly close the TLS connection, for any server that cares. 
-            _ = conn.tls_client.writeEnd(conn.stream, "", true) catch {};
+            conn.tls_client.close() catch {};
             allocator.destroy(conn.tls_client);
         }
 
@@ -1376,13 +1378,13 @@ pub fn connectTcp(client: *Client, host: []const u8, port: u16, protocol: Connec
     if (protocol == .tls) {
         if (disable_tls) unreachable;
 
-        conn.data.tls_client = try client.allocator.create(std.crypto.tls.Client);
+        conn.data.tls_client = try client.allocator.create(tls23.Connection(Stream));
         errdefer client.allocator.destroy(conn.data.tls_client);
 
-        conn.data.tls_client.* = std.crypto.tls.Client.init(stream, client.ca_bundle, host) catch return error.TlsInitializationFailed;
-        // This is appropriate for HTTPS because the HTTP headers contain
-        // the content length which is used to detect truncation attacks.
-        conn.data.tls_client.allow_truncation_attacks = true;
+        conn.data.tls_client.* = tls23.client(stream, .{
+            .host = host,
+            .root_ca = client.ca_bundle,
+        }) catch return error.TlsInitializationFailed;
     }
 
     client.connection_pool.addUsed(conn);
diff --git a/src/browser/browser.zig b/src/browser/browser.zig
index 9e16a480..cc545ff1 100644
--- a/src/browser/browser.zig
+++ b/src/browser/browser.zig
@@ -37,7 +37,7 @@ const Walker = @import("../dom/walker.zig").WalkerDepthFirst;
 const storage = @import("../storage/storage.zig");
 
-const FetchResult = std.http.Client.FetchResult;
+const FetchResult = @import("../http/Client.zig").FetchResult;
 const UserContext = @import("../user_context.zig").UserContext;
 const HttpClient = @import("../async/Client.zig");
diff --git a/src/browser/loader.zig b/src/browser/loader.zig
index c7dc1ea9..535e9c87 100644
--- a/src/browser/loader.zig
+++ b/src/browser/loader.zig
@@ -17,17 +17,18 @@
 // along with this program. If not, see .
 
 const std = @import("std");
+const Client = @import("../http/Client.zig");
 
 const user_agent = "Lightpanda.io/1.0";
 
 pub const Loader = struct {
-    client: std.http.Client,
+    client: Client,
     // use 16KB for headers buffer size.
     server_header_buffer: [1024 * 16]u8 = undefined,
 
     pub const Response = struct {
         alloc: std.mem.Allocator,
-        req: *std.http.Client.Request,
+        req: *Client.Request,
 
         pub fn deinit(self: *Response) void {
             self.req.deinit();
@@ -37,7 +38,7 @@ pub const Loader = struct {
     pub fn init(alloc: std.mem.Allocator) Loader {
         return Loader{
-            .client = std.http.Client{
+            .client = Client{
                 .allocator = alloc,
             },
         };
@@ -54,7 +55,7 @@ pub const Loader = struct {
     pub fn get(self: *Loader, alloc: std.mem.Allocator, uri: std.Uri) !Response {
         var resp = Response{
             .alloc = alloc,
-            .req = try alloc.create(std.http.Client.Request),
+            .req = try alloc.create(Client.Request),
         };
         errdefer alloc.destroy(resp.req);
diff --git a/src/http/Client.zig b/src/http/Client.zig
new file mode 100644
index 00000000..eeac1663
--- /dev/null
+++ b/src/http/Client.zig
@@ -0,0 +1,1794 @@
+// Copyright (C) 2023-2024 Lightpanda (Selecy SAS)
+//
+// Francis Bouvier
+// Pierre Tachoire
+//
+// This program is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Affero General Public License as
+// published by the Free Software Foundation, either version 3 of the
+// License, or (at your option) any later version.
+//
+// This program is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Affero General Public License for more details.
+// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +//! HTTP(S) Client implementation. +//! +//! Connections are opened in a thread-safe manner, but individual Requests are not. +//! +//! TLS support may be disabled via `std.options.http_disable_tls`. + +const std = @import("std"); +const builtin = @import("builtin"); +const testing = std.testing; +const http = std.http; +const mem = std.mem; +const net = std.net; +const Uri = std.Uri; +const Allocator = mem.Allocator; +const assert = std.debug.assert; +const use_vectors = builtin.zig_backend != .stage2_x86_64; + +const Client = @This(); +const proto = std.http.protocol; + +const tls23 = @import("tls"); + +pub const disable_tls = std.options.http_disable_tls; + +/// Used for all client allocations. Must be thread-safe. +allocator: Allocator, + +ca_bundle: if (disable_tls) void else std.crypto.Certificate.Bundle = if (disable_tls) {} else .{}, +ca_bundle_mutex: std.Thread.Mutex = .{}, + +/// When this is `true`, the next time this client performs an HTTPS request, +/// it will first rescan the system for root certificates. +next_https_rescan_certs: bool = true, + +/// The pool of connections that can be reused (and currently in use). +connection_pool: ConnectionPool = .{}, + +/// If populated, all http traffic travels through this third party. +/// This field cannot be modified while the client has active connections. +/// Pointer to externally-owned memory. +http_proxy: ?*Proxy = null, +/// If populated, all https traffic travels through this third party. +/// This field cannot be modified while the client has active connections. +/// Pointer to externally-owned memory. +https_proxy: ?*Proxy = null, + +/// A set of linked lists of connections that can be reused. +pub const ConnectionPool = struct { + mutex: std.Thread.Mutex = .{}, + /// Open connections that are currently in use. + used: Queue = .{}, + /// Open connections that are not currently in use. + free: Queue = .{}, + free_len: usize = 0, + free_size: usize = 32, + + /// The criteria for a connection to be considered a match. + pub const Criteria = struct { + host: []const u8, + port: u16, + protocol: Connection.Protocol, + }; + + const Queue = std.DoublyLinkedList(Connection); + pub const Node = Queue.Node; + + /// Finds and acquires a connection from the connection pool matching the criteria. This function is threadsafe. + /// If no connection is found, null is returned. + pub fn findConnection(pool: *ConnectionPool, criteria: Criteria) ?*Connection { + pool.mutex.lock(); + defer pool.mutex.unlock(); + + var next = pool.free.last; + while (next) |node| : (next = node.prev) { + if (node.data.protocol != criteria.protocol) continue; + if (node.data.port != criteria.port) continue; + + // Domain names are case-insensitive (RFC 5890, Section 2.3.2.4) + if (!std.ascii.eqlIgnoreCase(node.data.host, criteria.host)) continue; + + pool.acquireUnsafe(node); + return &node.data; + } + + return null; + } + + /// Acquires an existing connection from the connection pool. This function is not threadsafe. + pub fn acquireUnsafe(pool: *ConnectionPool, node: *Node) void { + pool.free.remove(node); + pool.free_len -= 1; + + pool.used.append(node); + } + + /// Acquires an existing connection from the connection pool. This function is threadsafe. 
+ pub fn acquire(pool: *ConnectionPool, node: *Node) void { + pool.mutex.lock(); + defer pool.mutex.unlock(); + + return pool.acquireUnsafe(node); + } + + /// Tries to release a connection back to the connection pool. This function is threadsafe. + /// If the connection is marked as closing, it will be closed instead. + /// + /// The allocator must be the owner of all nodes in this pool. + /// The allocator must be the owner of all resources associated with the connection. + pub fn release(pool: *ConnectionPool, allocator: Allocator, connection: *Connection) void { + pool.mutex.lock(); + defer pool.mutex.unlock(); + + const node: *Node = @fieldParentPtr("data", connection); + + pool.used.remove(node); + + if (node.data.closing or pool.free_size == 0) { + node.data.close(allocator); + return allocator.destroy(node); + } + + if (pool.free_len >= pool.free_size) { + const popped = pool.free.popFirst() orelse unreachable; + pool.free_len -= 1; + + popped.data.close(allocator); + allocator.destroy(popped); + } + + if (node.data.proxied) { + pool.free.prepend(node); // proxied connections go to the end of the queue, always try direct connections first + } else { + pool.free.append(node); + } + + pool.free_len += 1; + } + + /// Adds a newly created node to the pool of used connections. This function is threadsafe. + pub fn addUsed(pool: *ConnectionPool, node: *Node) void { + pool.mutex.lock(); + defer pool.mutex.unlock(); + + pool.used.append(node); + } + + /// Resizes the connection pool. This function is threadsafe. + /// + /// If the new size is smaller than the current size, then idle connections will be closed until the pool is the new size. + pub fn resize(pool: *ConnectionPool, allocator: Allocator, new_size: usize) void { + pool.mutex.lock(); + defer pool.mutex.unlock(); + + const next = pool.free.first; + _ = next; + while (pool.free_len > new_size) { + const popped = pool.free.popFirst() orelse unreachable; + pool.free_len -= 1; + + popped.data.close(allocator); + allocator.destroy(popped); + } + + pool.free_size = new_size; + } + + /// Frees the connection pool and closes all connections within. This function is threadsafe. + /// + /// All future operations on the connection pool will deadlock. + pub fn deinit(pool: *ConnectionPool, allocator: Allocator) void { + pool.mutex.lock(); + + var next = pool.free.first; + while (next) |node| { + defer allocator.destroy(node); + next = node.next; + + node.data.close(allocator); + } + + next = pool.used.first; + while (next) |node| { + defer allocator.destroy(node); + next = node.next; + + node.data.close(allocator); + } + + pool.* = undefined; + } +}; + +/// An interface to either a plain or TLS connection. +pub const Connection = struct { + stream: net.Stream, + /// undefined unless protocol is tls. + tls_client: if (!disable_tls) *tls23.Connection(net.Stream) else void, + + /// The protocol that this connection is using. + protocol: Protocol, + + /// The host that this connection is connected to. + host: []u8, + + /// The port that this connection is connected to. + port: u16, + + /// Whether this connection is proxied and is not directly connected. + proxied: bool = false, + + /// Whether this connection is closing when we're done with it. 
+ closing: bool = false, + + read_start: BufferSize = 0, + read_end: BufferSize = 0, + write_end: BufferSize = 0, + read_buf: [buffer_size]u8 = undefined, + write_buf: [buffer_size]u8 = undefined, + + pub const buffer_size = std.crypto.tls.max_ciphertext_record_len; + const BufferSize = std.math.IntFittingRange(0, buffer_size); + + pub const Protocol = enum { plain, tls }; + + pub fn readvDirectTls(conn: *Connection, buffers: []std.posix.iovec) ReadError!usize { + return conn.tls_client.readv(buffers) catch |err| { + // https://github.com/ziglang/zig/issues/2473 + if (mem.startsWith(u8, @errorName(err), "TlsAlert")) return error.TlsAlert; + + switch (err) { + error.TlsRecordOverflow, error.TlsBadRecordMac, error.TlsUnexpectedMessage => return error.TlsFailure, + error.ConnectionTimedOut => return error.ConnectionTimedOut, + error.ConnectionResetByPeer, error.BrokenPipe => return error.ConnectionResetByPeer, + else => return error.UnexpectedReadFailure, + } + }; + } + + pub fn readvDirect(conn: *Connection, buffers: []std.posix.iovec) ReadError!usize { + if (conn.protocol == .tls) { + if (disable_tls) unreachable; + + return conn.readvDirectTls(buffers); + } + + return conn.stream.readv(buffers) catch |err| switch (err) { + error.ConnectionTimedOut => return error.ConnectionTimedOut, + error.ConnectionResetByPeer, error.BrokenPipe => return error.ConnectionResetByPeer, + else => return error.UnexpectedReadFailure, + }; + } + + /// Refills the read buffer with data from the connection. + pub fn fill(conn: *Connection) ReadError!void { + if (conn.read_end != conn.read_start) return; + + var iovecs = [1]std.posix.iovec{ + .{ .base = &conn.read_buf, .len = conn.read_buf.len }, + }; + const nread = try conn.readvDirect(&iovecs); + if (nread == 0) return error.EndOfStream; + conn.read_start = 0; + conn.read_end = @intCast(nread); + } + + /// Returns the current slice of buffered data. + pub fn peek(conn: *Connection) []const u8 { + return conn.read_buf[conn.read_start..conn.read_end]; + } + + /// Discards the given number of bytes from the read buffer. + pub fn drop(conn: *Connection, num: BufferSize) void { + conn.read_start += num; + } + + /// Reads data from the connection into the given buffer. 
+ pub fn read(conn: *Connection, buffer: []u8) ReadError!usize { + const available_read = conn.read_end - conn.read_start; + const available_buffer = buffer.len; + + if (available_read > available_buffer) { // partially read buffered data + @memcpy(buffer[0..available_buffer], conn.read_buf[conn.read_start..conn.read_end][0..available_buffer]); + conn.read_start += @intCast(available_buffer); + + return available_buffer; + } else if (available_read > 0) { // fully read buffered data + @memcpy(buffer[0..available_read], conn.read_buf[conn.read_start..conn.read_end]); + conn.read_start += available_read; + + return available_read; + } + + var iovecs = [2]std.posix.iovec{ + .{ .base = buffer.ptr, .len = buffer.len }, + .{ .base = &conn.read_buf, .len = conn.read_buf.len }, + }; + const nread = try conn.readvDirect(&iovecs); + + if (nread > buffer.len) { + conn.read_start = 0; + conn.read_end = @intCast(nread - buffer.len); + return buffer.len; + } + + return nread; + } + + pub const ReadError = error{ + TlsFailure, + TlsAlert, + ConnectionTimedOut, + ConnectionResetByPeer, + UnexpectedReadFailure, + EndOfStream, + }; + + pub const Reader = std.io.Reader(*Connection, ReadError, read); + + pub fn reader(conn: *Connection) Reader { + return Reader{ .context = conn }; + } + + pub fn writeAllDirectTls(conn: *Connection, buffer: []const u8) WriteError!void { + return conn.tls_client.writeAll(buffer) catch |err| switch (err) { + error.BrokenPipe, error.ConnectionResetByPeer => return error.ConnectionResetByPeer, + else => return error.UnexpectedWriteFailure, + }; + } + + pub fn writeAllDirect(conn: *Connection, buffer: []const u8) WriteError!void { + if (conn.protocol == .tls) { + if (disable_tls) unreachable; + + return conn.writeAllDirectTls(buffer); + } + + return conn.stream.writeAll(buffer) catch |err| switch (err) { + error.BrokenPipe, error.ConnectionResetByPeer => return error.ConnectionResetByPeer, + else => return error.UnexpectedWriteFailure, + }; + } + + /// Writes the given buffer to the connection. + pub fn write(conn: *Connection, buffer: []const u8) WriteError!usize { + if (conn.write_buf.len - conn.write_end < buffer.len) { + try conn.flush(); + + if (buffer.len > conn.write_buf.len) { + try conn.writeAllDirect(buffer); + return buffer.len; + } + } + + @memcpy(conn.write_buf[conn.write_end..][0..buffer.len], buffer); + conn.write_end += @intCast(buffer.len); + + return buffer.len; + } + + /// Returns a buffer to be filled with exactly len bytes to write to the connection. + pub fn allocWriteBuffer(conn: *Connection, len: BufferSize) WriteError![]u8 { + if (conn.write_buf.len - conn.write_end < len) try conn.flush(); + defer conn.write_end += len; + return conn.write_buf[conn.write_end..][0..len]; + } + + /// Flushes the write buffer to the connection. + pub fn flush(conn: *Connection) WriteError!void { + if (conn.write_end == 0) return; + + try conn.writeAllDirect(conn.write_buf[0..conn.write_end]); + conn.write_end = 0; + } + + pub const WriteError = error{ + ConnectionResetByPeer, + UnexpectedWriteFailure, + }; + + pub const Writer = std.io.Writer(*Connection, WriteError, write); + + pub fn writer(conn: *Connection) Writer { + return Writer{ .context = conn }; + } + + /// Closes the connection. + pub fn close(conn: *Connection, allocator: Allocator) void { + if (conn.protocol == .tls) { + if (disable_tls) unreachable; + + // try to cleanly close the TLS connection, for any server that cares. 
+ conn.tls_client.close() catch {}; + allocator.destroy(conn.tls_client); + } + + conn.stream.close(); + allocator.free(conn.host); + } +}; + +/// The mode of transport for requests. +pub const RequestTransfer = union(enum) { + content_length: u64, + chunked: void, + none: void, +}; + +/// The decompressor for response messages. +pub const Compression = union(enum) { + pub const DeflateDecompressor = std.compress.zlib.Decompressor(Request.TransferReader); + pub const GzipDecompressor = std.compress.gzip.Decompressor(Request.TransferReader); + // https://github.com/ziglang/zig/issues/18937 + //pub const ZstdDecompressor = std.compress.zstd.DecompressStream(Request.TransferReader, .{}); + + deflate: DeflateDecompressor, + gzip: GzipDecompressor, + // https://github.com/ziglang/zig/issues/18937 + //zstd: ZstdDecompressor, + none: void, +}; + +/// A HTTP response originating from a server. +pub const Response = struct { + version: http.Version, + status: http.Status, + reason: []const u8, + + /// Points into the user-provided `server_header_buffer`. + location: ?[]const u8 = null, + /// Points into the user-provided `server_header_buffer`. + content_type: ?[]const u8 = null, + /// Points into the user-provided `server_header_buffer`. + content_disposition: ?[]const u8 = null, + + keep_alive: bool, + + /// If present, the number of bytes in the response body. + content_length: ?u64 = null, + + /// If present, the transfer encoding of the response body, otherwise none. + transfer_encoding: http.TransferEncoding = .none, + + /// If present, the compression of the response body, otherwise identity (no compression). + transfer_compression: http.ContentEncoding = .identity, + + parser: proto.HeadersParser, + compression: Compression = .none, + + /// Whether the response body should be skipped. Any data read from the + /// response body will be discarded. 
+ skip: bool = false, + + pub const ParseError = error{ + HttpHeadersInvalid, + HttpHeaderContinuationsUnsupported, + HttpTransferEncodingUnsupported, + HttpConnectionHeaderUnsupported, + InvalidContentLength, + CompressionUnsupported, + }; + + pub fn parse(res: *Response, bytes: []const u8) ParseError!void { + var it = mem.splitSequence(u8, bytes, "\r\n"); + + const first_line = it.next().?; + if (first_line.len < 12) { + return error.HttpHeadersInvalid; + } + + const version: http.Version = switch (int64(first_line[0..8])) { + int64("HTTP/1.0") => .@"HTTP/1.0", + int64("HTTP/1.1") => .@"HTTP/1.1", + else => return error.HttpHeadersInvalid, + }; + if (first_line[8] != ' ') return error.HttpHeadersInvalid; + const status: http.Status = @enumFromInt(parseInt3(first_line[9..12])); + const reason = mem.trimLeft(u8, first_line[12..], " "); + + res.version = version; + res.status = status; + res.reason = reason; + res.keep_alive = switch (version) { + .@"HTTP/1.0" => false, + .@"HTTP/1.1" => true, + }; + + while (it.next()) |line| { + if (line.len == 0) return; + switch (line[0]) { + ' ', '\t' => return error.HttpHeaderContinuationsUnsupported, + else => {}, + } + + var line_it = mem.splitScalar(u8, line, ':'); + const header_name = line_it.next().?; + const header_value = mem.trim(u8, line_it.rest(), " \t"); + if (header_name.len == 0) return error.HttpHeadersInvalid; + + if (std.ascii.eqlIgnoreCase(header_name, "connection")) { + res.keep_alive = !std.ascii.eqlIgnoreCase(header_value, "close"); + } else if (std.ascii.eqlIgnoreCase(header_name, "content-type")) { + res.content_type = header_value; + } else if (std.ascii.eqlIgnoreCase(header_name, "location")) { + res.location = header_value; + } else if (std.ascii.eqlIgnoreCase(header_name, "content-disposition")) { + res.content_disposition = header_value; + } else if (std.ascii.eqlIgnoreCase(header_name, "transfer-encoding")) { + // Transfer-Encoding: second, first + // Transfer-Encoding: deflate, chunked + var iter = mem.splitBackwardsScalar(u8, header_value, ','); + + const first = iter.first(); + const trimmed_first = mem.trim(u8, first, " "); + + var next: ?[]const u8 = first; + if (std.meta.stringToEnum(http.TransferEncoding, trimmed_first)) |transfer| { + if (res.transfer_encoding != .none) return error.HttpHeadersInvalid; // we already have a transfer encoding + res.transfer_encoding = transfer; + + next = iter.next(); + } + + if (next) |second| { + const trimmed_second = mem.trim(u8, second, " "); + + if (std.meta.stringToEnum(http.ContentEncoding, trimmed_second)) |transfer| { + if (res.transfer_compression != .identity) return error.HttpHeadersInvalid; // double compression is not supported + res.transfer_compression = transfer; + } else { + return error.HttpTransferEncodingUnsupported; + } + } + + if (iter.next()) |_| return error.HttpTransferEncodingUnsupported; + } else if (std.ascii.eqlIgnoreCase(header_name, "content-length")) { + const content_length = std.fmt.parseInt(u64, header_value, 10) catch return error.InvalidContentLength; + + if (res.content_length != null and res.content_length != content_length) return error.HttpHeadersInvalid; + + res.content_length = content_length; + } else if (std.ascii.eqlIgnoreCase(header_name, "content-encoding")) { + if (res.transfer_compression != .identity) return error.HttpHeadersInvalid; + + const trimmed = mem.trim(u8, header_value, " "); + + if (std.meta.stringToEnum(http.ContentEncoding, trimmed)) |ce| { + res.transfer_compression = ce; + } else { + return 
error.HttpTransferEncodingUnsupported; + } + } + } + return error.HttpHeadersInvalid; // missing empty line + } + + test parse { + const response_bytes = "HTTP/1.1 200 OK\r\n" ++ + "LOcation:url\r\n" ++ + "content-tYpe: text/plain\r\n" ++ + "content-disposition:attachment; filename=example.txt \r\n" ++ + "content-Length:10\r\n" ++ + "TRansfer-encoding:\tdeflate, chunked \r\n" ++ + "connectioN:\t keep-alive \r\n\r\n"; + + var header_buffer: [1024]u8 = undefined; + var res = Response{ + .status = undefined, + .reason = undefined, + .version = undefined, + .keep_alive = false, + .parser = proto.HeadersParser.init(&header_buffer), + }; + + @memcpy(header_buffer[0..response_bytes.len], response_bytes); + res.parser.header_bytes_len = response_bytes.len; + + try res.parse(response_bytes); + + try testing.expectEqual(.@"HTTP/1.1", res.version); + try testing.expectEqualStrings("OK", res.reason); + try testing.expectEqual(.ok, res.status); + + try testing.expectEqualStrings("url", res.location.?); + try testing.expectEqualStrings("text/plain", res.content_type.?); + try testing.expectEqualStrings("attachment; filename=example.txt", res.content_disposition.?); + + try testing.expectEqual(true, res.keep_alive); + try testing.expectEqual(10, res.content_length.?); + try testing.expectEqual(.chunked, res.transfer_encoding); + try testing.expectEqual(.deflate, res.transfer_compression); + } + + inline fn int64(array: *const [8]u8) u64 { + return @bitCast(array.*); + } + + fn parseInt3(text: *const [3]u8) u10 { + if (use_vectors) { + const nnn: @Vector(3, u8) = text.*; + const zero: @Vector(3, u8) = .{ '0', '0', '0' }; + const mmm: @Vector(3, u10) = .{ 100, 10, 1 }; + return @reduce(.Add, @as(@Vector(3, u10), nnn -% zero) *% mmm); + } + return std.fmt.parseInt(u10, text, 10) catch unreachable; + } + + test parseInt3 { + const expectEqual = testing.expectEqual; + try expectEqual(@as(u10, 0), parseInt3("000")); + try expectEqual(@as(u10, 418), parseInt3("418")); + try expectEqual(@as(u10, 999), parseInt3("999")); + } + + pub fn iterateHeaders(r: Response) http.HeaderIterator { + return http.HeaderIterator.init(r.parser.get()); + } + + test iterateHeaders { + const response_bytes = "HTTP/1.1 200 OK\r\n" ++ + "LOcation:url\r\n" ++ + "content-tYpe: text/plain\r\n" ++ + "content-disposition:attachment; filename=example.txt \r\n" ++ + "content-Length:10\r\n" ++ + "TRansfer-encoding:\tdeflate, chunked \r\n" ++ + "connectioN:\t keep-alive \r\n\r\n"; + + var header_buffer: [1024]u8 = undefined; + var res = Response{ + .status = undefined, + .reason = undefined, + .version = undefined, + .keep_alive = false, + .parser = proto.HeadersParser.init(&header_buffer), + }; + + @memcpy(header_buffer[0..response_bytes.len], response_bytes); + res.parser.header_bytes_len = response_bytes.len; + + var it = res.iterateHeaders(); + { + const header = it.next().?; + try testing.expectEqualStrings("LOcation", header.name); + try testing.expectEqualStrings("url", header.value); + try testing.expect(!it.is_trailer); + } + { + const header = it.next().?; + try testing.expectEqualStrings("content-tYpe", header.name); + try testing.expectEqualStrings("text/plain", header.value); + try testing.expect(!it.is_trailer); + } + { + const header = it.next().?; + try testing.expectEqualStrings("content-disposition", header.name); + try testing.expectEqualStrings("attachment; filename=example.txt", header.value); + try testing.expect(!it.is_trailer); + } + { + const header = it.next().?; + try testing.expectEqualStrings("content-Length", 
header.name); + try testing.expectEqualStrings("10", header.value); + try testing.expect(!it.is_trailer); + } + { + const header = it.next().?; + try testing.expectEqualStrings("TRansfer-encoding", header.name); + try testing.expectEqualStrings("deflate, chunked", header.value); + try testing.expect(!it.is_trailer); + } + { + const header = it.next().?; + try testing.expectEqualStrings("connectioN", header.name); + try testing.expectEqualStrings("keep-alive", header.value); + try testing.expect(!it.is_trailer); + } + try testing.expectEqual(null, it.next()); + } +}; + +/// A HTTP request that has been sent. +/// +/// Order of operations: open -> send[ -> write -> finish] -> wait -> read +pub const Request = struct { + uri: Uri, + client: *Client, + /// This is null when the connection is released. + connection: ?*Connection, + keep_alive: bool, + + method: http.Method, + version: http.Version = .@"HTTP/1.1", + transfer_encoding: RequestTransfer, + redirect_behavior: RedirectBehavior, + + /// Whether the request should handle a 100-continue response before sending the request body. + handle_continue: bool, + + /// The response associated with this request. + /// + /// This field is undefined until `wait` is called. + response: Response, + + /// Standard headers that have default, but overridable, behavior. + headers: Headers, + + /// These headers are kept including when following a redirect to a + /// different domain. + /// Externally-owned; must outlive the Request. + extra_headers: []const http.Header, + + /// These headers are stripped when following a redirect to a different + /// domain. + /// Externally-owned; must outlive the Request. + privileged_headers: []const http.Header, + + pub const Headers = struct { + host: Value = .default, + authorization: Value = .default, + user_agent: Value = .default, + connection: Value = .default, + accept_encoding: Value = .default, + content_type: Value = .default, + + pub const Value = union(enum) { + default, + omit, + override: []const u8, + }; + }; + + /// Any value other than `not_allowed` or `unhandled` means that integer represents + /// how many remaining redirects are allowed. + pub const RedirectBehavior = enum(u16) { + /// The next redirect will cause an error. + not_allowed = 0, + /// Redirects are passed to the client to analyze the redirect response + /// directly. + unhandled = std.math.maxInt(u16), + _, + + pub fn subtractOne(rb: *RedirectBehavior) void { + switch (rb.*) { + .not_allowed => unreachable, + .unhandled => unreachable, + _ => rb.* = @enumFromInt(@intFromEnum(rb.*) - 1), + } + } + + pub fn remaining(rb: RedirectBehavior) u16 { + assert(rb != .unhandled); + return @intFromEnum(rb); + } + }; + + /// Frees all resources associated with the request. + pub fn deinit(req: *Request) void { + if (req.connection) |connection| { + if (!req.response.parser.done) { + // If the response wasn't fully read, then we need to close the connection. + connection.closing = true; + } + req.client.connection_pool.release(req.client.allocator, connection); + } + req.* = undefined; + } + + // This function must deallocate all resources associated with the request, + // or keep those which will be used. + // This needs to be kept in sync with deinit and request. 
+ fn redirect(req: *Request, uri: Uri) !void { + assert(req.response.parser.done); + + req.client.connection_pool.release(req.client.allocator, req.connection.?); + req.connection = null; + + var server_header = std.heap.FixedBufferAllocator.init(req.response.parser.header_bytes_buffer); + defer req.response.parser.header_bytes_buffer = server_header.buffer[server_header.end_index..]; + const protocol, const valid_uri = try validateUri(uri, server_header.allocator()); + + const new_host = valid_uri.host.?.raw; + const prev_host = req.uri.host.?.raw; + const keep_privileged_headers = + std.ascii.eqlIgnoreCase(valid_uri.scheme, req.uri.scheme) and + std.ascii.endsWithIgnoreCase(new_host, prev_host) and + (new_host.len == prev_host.len or new_host[new_host.len - prev_host.len - 1] == '.'); + if (!keep_privileged_headers) { + // When redirecting to a different domain, strip privileged headers. + req.privileged_headers = &.{}; + } + + if (switch (req.response.status) { + .see_other => true, + .moved_permanently, .found => req.method == .POST, + else => false, + }) { + // A redirect to a GET must change the method and remove the body. + req.method = .GET; + req.transfer_encoding = .none; + req.headers.content_type = .omit; + } + + if (req.transfer_encoding != .none) { + // The request body has already been sent. The request is + // still in a valid state, but the redirect must be handled + // manually. + return error.RedirectRequiresResend; + } + + req.uri = valid_uri; + req.connection = try req.client.connect(new_host, uriPort(valid_uri, protocol), protocol); + req.redirect_behavior.subtractOne(); + req.response.parser.reset(); + + req.response = .{ + .version = undefined, + .status = undefined, + .reason = undefined, + .keep_alive = undefined, + .parser = req.response.parser, + }; + } + + pub const SendError = Connection.WriteError || error{ InvalidContentLength, UnsupportedTransferEncoding }; + + /// Send the HTTP request headers to the server. 
+ pub fn send(req: *Request) SendError!void { + if (!req.method.requestHasBody() and req.transfer_encoding != .none) + return error.UnsupportedTransferEncoding; + + const connection = req.connection.?; + const w = connection.writer(); + + try req.method.write(w); + try w.writeByte(' '); + + if (req.method == .CONNECT) { + try req.uri.writeToStream(.{ .authority = true }, w); + } else { + try req.uri.writeToStream(.{ + .scheme = connection.proxied, + .authentication = connection.proxied, + .authority = connection.proxied, + .path = true, + .query = true, + }, w); + } + try w.writeByte(' '); + try w.writeAll(@tagName(req.version)); + try w.writeAll("\r\n"); + + if (try emitOverridableHeader("host: ", req.headers.host, w)) { + try w.writeAll("host: "); + try req.uri.writeToStream(.{ .authority = true }, w); + try w.writeAll("\r\n"); + } + + if (try emitOverridableHeader("authorization: ", req.headers.authorization, w)) { + if (req.uri.user != null or req.uri.password != null) { + try w.writeAll("authorization: "); + const authorization = try connection.allocWriteBuffer( + @intCast(basic_authorization.valueLengthFromUri(req.uri)), + ); + assert(basic_authorization.value(req.uri, authorization).len == authorization.len); + try w.writeAll("\r\n"); + } + } + + if (try emitOverridableHeader("user-agent: ", req.headers.user_agent, w)) { + try w.writeAll("user-agent: zig/"); + try w.writeAll(builtin.zig_version_string); + try w.writeAll(" (std.http)\r\n"); + } + + if (try emitOverridableHeader("connection: ", req.headers.connection, w)) { + if (req.keep_alive) { + try w.writeAll("connection: keep-alive\r\n"); + } else { + try w.writeAll("connection: close\r\n"); + } + } + + if (try emitOverridableHeader("accept-encoding: ", req.headers.accept_encoding, w)) { + // https://github.com/ziglang/zig/issues/18937 + //try w.writeAll("accept-encoding: gzip, deflate, zstd\r\n"); + try w.writeAll("accept-encoding: gzip, deflate\r\n"); + } + + switch (req.transfer_encoding) { + .chunked => try w.writeAll("transfer-encoding: chunked\r\n"), + .content_length => |len| try w.print("content-length: {d}\r\n", .{len}), + .none => {}, + } + + if (try emitOverridableHeader("content-type: ", req.headers.content_type, w)) { + // The default is to omit content-type if not provided because + // "application/octet-stream" is redundant. + } + + for (req.extra_headers) |header| { + assert(header.name.len != 0); + + try w.writeAll(header.name); + try w.writeAll(": "); + try w.writeAll(header.value); + try w.writeAll("\r\n"); + } + + if (connection.proxied) proxy: { + const proxy = switch (connection.protocol) { + .plain => req.client.http_proxy, + .tls => req.client.https_proxy, + } orelse break :proxy; + + const authorization = proxy.authorization orelse break :proxy; + try w.writeAll("proxy-authorization: "); + try w.writeAll(authorization); + try w.writeAll("\r\n"); + } + + try w.writeAll("\r\n"); + + try connection.flush(); + } + + /// Returns true if the default behavior is required, otherwise handles + /// writing (or not writing) the header. 
+ fn emitOverridableHeader(prefix: []const u8, v: Headers.Value, w: anytype) !bool { + switch (v) { + .default => return true, + .omit => return false, + .override => |x| { + try w.writeAll(prefix); + try w.writeAll(x); + try w.writeAll("\r\n"); + return false; + }, + } + } + + const TransferReadError = Connection.ReadError || proto.HeadersParser.ReadError; + + const TransferReader = std.io.Reader(*Request, TransferReadError, transferRead); + + fn transferReader(req: *Request) TransferReader { + return .{ .context = req }; + } + + fn transferRead(req: *Request, buf: []u8) TransferReadError!usize { + if (req.response.parser.done) return 0; + + var index: usize = 0; + while (index == 0) { + const amt = try req.response.parser.read(req.connection.?, buf[index..], req.response.skip); + if (amt == 0 and req.response.parser.done) break; + index += amt; + } + + return index; + } + + pub const WaitError = RequestError || SendError || TransferReadError || + proto.HeadersParser.CheckCompleteHeadError || Response.ParseError || + error{ // TODO: file zig fmt issue for this bad indentation + TooManyHttpRedirects, + RedirectRequiresResend, + HttpRedirectLocationMissing, + HttpRedirectLocationInvalid, + CompressionInitializationFailed, + CompressionUnsupported, + }; + + /// Waits for a response from the server and parses any headers that are sent. + /// This function will block until the final response is received. + /// + /// If handling redirects and the request has no payload, then this + /// function will automatically follow redirects. If a request payload is + /// present, then this function will error with + /// error.RedirectRequiresResend. + /// + /// Must be called after `send` and, if any data was written to the request + /// body, then also after `finish`. + pub fn wait(req: *Request) WaitError!void { + while (true) { + // This while loop is for handling redirects, which means the request's + // connection may be different than the previous iteration. However, it + // is still guaranteed to be non-null with each iteration of this loop. + const connection = req.connection.?; + + while (true) { // read headers + try connection.fill(); + + const nchecked = try req.response.parser.checkCompleteHead(connection.peek()); + connection.drop(@intCast(nchecked)); + + if (req.response.parser.state.isContent()) break; + } + + try req.response.parse(req.response.parser.get()); + + if (req.response.status == .@"continue") { + // We're done parsing the continue response; reset to prepare + // for the real response. + req.response.parser.done = true; + req.response.parser.reset(); + + if (req.handle_continue) + continue; + + return; // we're not handling the 100-continue + } + + // we're switching protocols, so this connection is no longer doing http + if (req.method == .CONNECT and req.response.status.class() == .success) { + connection.closing = false; + req.response.parser.done = true; + return; // the connection is not HTTP past this point + } + + connection.closing = !req.response.keep_alive or !req.keep_alive; + + // Any response to a HEAD request and any response with a 1xx + // (Informational), 204 (No Content), or 304 (Not Modified) status + // code is always terminated by the first empty line after the + // header fields, regardless of the header fields present in the + // message. 
+ if (req.method == .HEAD or req.response.status.class() == .informational or + req.response.status == .no_content or req.response.status == .not_modified) + { + req.response.parser.done = true; + return; // The response is empty; no further setup or redirection is necessary. + } + + switch (req.response.transfer_encoding) { + .none => { + if (req.response.content_length) |cl| { + req.response.parser.next_chunk_length = cl; + + if (cl == 0) req.response.parser.done = true; + } else { + // read until the connection is closed + req.response.parser.next_chunk_length = std.math.maxInt(u64); + } + }, + .chunked => { + req.response.parser.next_chunk_length = 0; + req.response.parser.state = .chunk_head_size; + }, + } + + if (req.response.status.class() == .redirect and req.redirect_behavior != .unhandled) { + // skip the body of the redirect response, this will at least + // leave the connection in a known good state. + req.response.skip = true; + assert(try req.transferRead(&.{}) == 0); // we're skipping, no buffer is necessary + + if (req.redirect_behavior == .not_allowed) return error.TooManyHttpRedirects; + + const location = req.response.location orelse + return error.HttpRedirectLocationMissing; + + // This mutates the beginning of header_bytes_buffer and uses that + // for the backing memory of the returned Uri. + try req.redirect(req.uri.resolve_inplace( + location, + &req.response.parser.header_bytes_buffer, + ) catch |err| switch (err) { + error.UnexpectedCharacter, + error.InvalidFormat, + error.InvalidPort, + => return error.HttpRedirectLocationInvalid, + error.NoSpaceLeft => return error.HttpHeadersOversize, + }); + try req.send(); + } else { + req.response.skip = false; + if (!req.response.parser.done) { + switch (req.response.transfer_compression) { + .identity => req.response.compression = .none, + .compress, .@"x-compress" => return error.CompressionUnsupported, + .deflate => req.response.compression = .{ + .deflate = std.compress.zlib.decompressor(req.transferReader()), + }, + .gzip, .@"x-gzip" => req.response.compression = .{ + .gzip = std.compress.gzip.decompressor(req.transferReader()), + }, + // https://github.com/ziglang/zig/issues/18937 + //.zstd => req.response.compression = .{ + // .zstd = std.compress.zstd.decompressStream(req.client.allocator, req.transferReader()), + //}, + .zstd => return error.CompressionUnsupported, + } + } + + break; + } + } + } + + pub const ReadError = TransferReadError || proto.HeadersParser.CheckCompleteHeadError || + error{ DecompressionFailure, InvalidTrailers }; + + pub const Reader = std.io.Reader(*Request, ReadError, read); + + pub fn reader(req: *Request) Reader { + return .{ .context = req }; + } + + /// Reads data from the response body. Must be called after `wait`. 
+ pub fn read(req: *Request, buffer: []u8) ReadError!usize { + const out_index = switch (req.response.compression) { + .deflate => |*deflate| deflate.read(buffer) catch return error.DecompressionFailure, + .gzip => |*gzip| gzip.read(buffer) catch return error.DecompressionFailure, + // https://github.com/ziglang/zig/issues/18937 + //.zstd => |*zstd| zstd.read(buffer) catch return error.DecompressionFailure, + else => try req.transferRead(buffer), + }; + if (out_index > 0) return out_index; + + while (!req.response.parser.state.isContent()) { // read trailing headers + try req.connection.?.fill(); + + const nchecked = try req.response.parser.checkCompleteHead(req.connection.?.peek()); + req.connection.?.drop(@intCast(nchecked)); + } + + return 0; + } + + /// Reads data from the response body. Must be called after `wait`. + pub fn readAll(req: *Request, buffer: []u8) !usize { + var index: usize = 0; + while (index < buffer.len) { + const amt = try read(req, buffer[index..]); + if (amt == 0) break; + index += amt; + } + return index; + } + + pub const WriteError = Connection.WriteError || error{ NotWriteable, MessageTooLong }; + + pub const Writer = std.io.Writer(*Request, WriteError, write); + + pub fn writer(req: *Request) Writer { + return .{ .context = req }; + } + + /// Write `bytes` to the server. The `transfer_encoding` field determines how data will be sent. + /// Must be called after `send` and before `finish`. + pub fn write(req: *Request, bytes: []const u8) WriteError!usize { + switch (req.transfer_encoding) { + .chunked => { + if (bytes.len > 0) { + try req.connection.?.writer().print("{x}\r\n", .{bytes.len}); + try req.connection.?.writer().writeAll(bytes); + try req.connection.?.writer().writeAll("\r\n"); + } + + return bytes.len; + }, + .content_length => |*len| { + if (len.* < bytes.len) return error.MessageTooLong; + + const amt = try req.connection.?.write(bytes); + len.* -= amt; + return amt; + }, + .none => return error.NotWriteable, + } + } + + /// Write `bytes` to the server. The `transfer_encoding` field determines how data will be sent. + /// Must be called after `send` and before `finish`. + pub fn writeAll(req: *Request, bytes: []const u8) WriteError!void { + var index: usize = 0; + while (index < bytes.len) { + index += try write(req, bytes[index..]); + } + } + + pub const FinishError = WriteError || error{MessageNotCompleted}; + + /// Finish the body of a request. This notifies the server that you have no more data to send. + /// Must be called after `send`. + pub fn finish(req: *Request) FinishError!void { + switch (req.transfer_encoding) { + .chunked => try req.connection.?.writer().writeAll("0\r\n\r\n"), + .content_length => |len| if (len != 0) return error.MessageNotCompleted, + .none => {}, + } + + try req.connection.?.flush(); + } +}; + +pub const Proxy = struct { + protocol: Connection.Protocol, + host: []const u8, + authorization: ?[]const u8, + port: u16, + supports_connect: bool, +}; + +/// Release all associated resources with the client. +/// +/// All pending requests must be de-initialized and all active connections released +/// before calling this function. +pub fn deinit(client: *Client) void { + assert(client.connection_pool.used.first == null); // There are still active requests. + + client.connection_pool.deinit(client.allocator); + + if (!disable_tls) + client.ca_bundle.deinit(client.allocator); + + client.* = undefined; +} + +/// Populates `http_proxy` and `https_proxy` via standard proxy environment variables. 
+/// Asserts the client has no active connections. +/// Uses `arena` for a few small allocations that must outlive the client, or +/// at least until those fields are set to different values. +pub fn initDefaultProxies(client: *Client, arena: Allocator) !void { + // Prevent any new connections from being created. + client.connection_pool.mutex.lock(); + defer client.connection_pool.mutex.unlock(); + + assert(client.connection_pool.used.first == null); // There are active requests. + + if (client.http_proxy == null) { + client.http_proxy = try createProxyFromEnvVar(arena, &.{ + "http_proxy", "HTTP_PROXY", "all_proxy", "ALL_PROXY", + }); + } + + if (client.https_proxy == null) { + client.https_proxy = try createProxyFromEnvVar(arena, &.{ + "https_proxy", "HTTPS_PROXY", "all_proxy", "ALL_PROXY", + }); + } +} + +fn createProxyFromEnvVar(arena: Allocator, env_var_names: []const []const u8) !?*Proxy { + const content = for (env_var_names) |name| { + break std.process.getEnvVarOwned(arena, name) catch |err| switch (err) { + error.EnvironmentVariableNotFound => continue, + else => |e| return e, + }; + } else return null; + + const uri = Uri.parse(content) catch try Uri.parseAfterScheme("http", content); + const protocol, const valid_uri = validateUri(uri, arena) catch |err| switch (err) { + error.UnsupportedUriScheme => return null, + error.UriMissingHost => return error.HttpProxyMissingHost, + error.OutOfMemory => |e| return e, + }; + + const authorization: ?[]const u8 = if (valid_uri.user != null or valid_uri.password != null) a: { + const authorization = try arena.alloc(u8, basic_authorization.valueLengthFromUri(valid_uri)); + assert(basic_authorization.value(valid_uri, authorization).len == authorization.len); + break :a authorization; + } else null; + + const proxy = try arena.create(Proxy); + proxy.* = .{ + .protocol = protocol, + .host = valid_uri.host.?.raw, + .authorization = authorization, + .port = uriPort(valid_uri, protocol), + .supports_connect = true, + }; + return proxy; +} + +pub const basic_authorization = struct { + pub const max_user_len = 255; + pub const max_password_len = 255; + pub const max_value_len = valueLength(max_user_len, max_password_len); + + const prefix = "Basic "; + + pub fn valueLength(user_len: usize, password_len: usize) usize { + return prefix.len + std.base64.standard.Encoder.calcSize(user_len + 1 + password_len); + } + + pub fn valueLengthFromUri(uri: Uri) usize { + var stream = std.io.countingWriter(std.io.null_writer); + try stream.writer().print("{user}", .{uri.user orelse Uri.Component.empty}); + const user_len = stream.bytes_written; + stream.bytes_written = 0; + try stream.writer().print("{password}", .{uri.password orelse Uri.Component.empty}); + const password_len = stream.bytes_written; + return valueLength(@intCast(user_len), @intCast(password_len)); + } + + pub fn value(uri: Uri, out: []u8) []u8 { + var buf: [max_user_len + ":".len + max_password_len]u8 = undefined; + var stream = std.io.fixedBufferStream(&buf); + stream.writer().print("{user}", .{uri.user orelse Uri.Component.empty}) catch + unreachable; + assert(stream.pos <= max_user_len); + stream.writer().print(":{password}", .{uri.password orelse Uri.Component.empty}) catch + unreachable; + + @memcpy(out[0..prefix.len], prefix); + const base64 = std.base64.standard.Encoder.encode(out[prefix.len..], stream.getWritten()); + return out[0 .. 
prefix.len + base64.len]; + } +}; + +pub const ConnectTcpError = Allocator.Error || error{ ConnectionRefused, NetworkUnreachable, ConnectionTimedOut, ConnectionResetByPeer, TemporaryNameServerFailure, NameServerFailure, UnknownHostName, HostLacksNetworkAddresses, UnexpectedConnectFailure, TlsInitializationFailed }; + +/// Connect to `host:port` using the specified protocol. This will reuse a connection if one is already open. +/// +/// This function is threadsafe. +pub fn connectTcp(client: *Client, host: []const u8, port: u16, protocol: Connection.Protocol) ConnectTcpError!*Connection { + if (client.connection_pool.findConnection(.{ + .host = host, + .port = port, + .protocol = protocol, + })) |node| return node; + + if (disable_tls and protocol == .tls) + return error.TlsInitializationFailed; + + const conn = try client.allocator.create(ConnectionPool.Node); + errdefer client.allocator.destroy(conn); + conn.* = .{ .data = undefined }; + + const stream = net.tcpConnectToHost(client.allocator, host, port) catch |err| switch (err) { + error.ConnectionRefused => return error.ConnectionRefused, + error.NetworkUnreachable => return error.NetworkUnreachable, + error.ConnectionTimedOut => return error.ConnectionTimedOut, + error.ConnectionResetByPeer => return error.ConnectionResetByPeer, + error.TemporaryNameServerFailure => return error.TemporaryNameServerFailure, + error.NameServerFailure => return error.NameServerFailure, + error.UnknownHostName => return error.UnknownHostName, + error.HostLacksNetworkAddresses => return error.HostLacksNetworkAddresses, + else => return error.UnexpectedConnectFailure, + }; + errdefer stream.close(); + + conn.data = .{ + .stream = stream, + .tls_client = undefined, + + .protocol = protocol, + .host = try client.allocator.dupe(u8, host), + .port = port, + }; + errdefer client.allocator.free(conn.data.host); + + if (protocol == .tls) { + if (disable_tls) unreachable; + + conn.data.tls_client = try client.allocator.create(tls23.Connection(net.Stream)); + errdefer client.allocator.destroy(conn.data.tls_client); + + conn.data.tls_client.* = tls23.client(stream, .{ + .host = host, + .root_ca = client.ca_bundle, + }) catch return error.TlsInitializationFailed; + } + + client.connection_pool.addUsed(conn); + + return &conn.data; +} + +pub const ConnectUnixError = Allocator.Error || std.posix.SocketError || error{NameTooLong} || std.posix.ConnectError; + +/// Connect to `path` as a unix domain socket. This will reuse a connection if one is already open. +/// +/// This function is threadsafe. +pub fn connectUnix(client: *Client, path: []const u8) ConnectUnixError!*Connection { + if (client.connection_pool.findConnection(.{ + .host = path, + .port = 0, + .protocol = .plain, + })) |node| + return node; + + const conn = try client.allocator.create(ConnectionPool.Node); + errdefer client.allocator.destroy(conn); + conn.* = .{ .data = undefined }; + + const stream = try std.net.connectUnixSocket(path); + errdefer stream.close(); + + conn.data = .{ + .stream = stream, + .tls_client = undefined, + .protocol = .plain, + + .host = try client.allocator.dupe(u8, path), + .port = 0, + }; + errdefer client.allocator.free(conn.data.host); + + client.connection_pool.addUsed(conn); + + return &conn.data; +} + +/// Connect to `tunnel_host:tunnel_port` using the specified proxy with HTTP +/// CONNECT. This will reuse a connection if one is already open. +/// +/// This function is threadsafe. 
+pub fn connectTunnel( + client: *Client, + proxy: *Proxy, + tunnel_host: []const u8, + tunnel_port: u16, +) !*Connection { + if (!proxy.supports_connect) return error.TunnelNotSupported; + + if (client.connection_pool.findConnection(.{ + .host = tunnel_host, + .port = tunnel_port, + .protocol = proxy.protocol, + })) |node| + return node; + + var maybe_valid = false; + (tunnel: { + const conn = try client.connectTcp(proxy.host, proxy.port, proxy.protocol); + errdefer { + conn.closing = true; + client.connection_pool.release(client.allocator, conn); + } + + var buffer: [8096]u8 = undefined; + var req = client.open(.CONNECT, .{ + .scheme = "http", + .host = .{ .raw = tunnel_host }, + .port = tunnel_port, + }, .{ + .redirect_behavior = .unhandled, + .connection = conn, + .server_header_buffer = &buffer, + }) catch |err| { + std.log.debug("err {}", .{err}); + break :tunnel err; + }; + defer req.deinit(); + + req.send() catch |err| break :tunnel err; + req.wait() catch |err| break :tunnel err; + + if (req.response.status.class() == .server_error) { + maybe_valid = true; + break :tunnel error.ServerError; + } + + if (req.response.status != .ok) break :tunnel error.ConnectionRefused; + + // this connection is now a tunnel, so we can't use it for anything else, it will only be released when the client is de-initialized. + req.connection = null; + + client.allocator.free(conn.host); + conn.host = try client.allocator.dupe(u8, tunnel_host); + errdefer client.allocator.free(conn.host); + + conn.port = tunnel_port; + conn.closing = false; + + return conn; + }) catch { + // something went wrong with the tunnel + proxy.supports_connect = maybe_valid; + return error.TunnelNotSupported; + }; +} + +// Prevents a dependency loop in open() +const ConnectErrorPartial = ConnectTcpError || error{ UnsupportedUriScheme, ConnectionRefused }; +pub const ConnectError = ConnectErrorPartial || RequestError; + +/// Connect to `host:port` using the specified protocol. This will reuse a +/// connection if one is already open. +/// If a proxy is configured for the client, then the proxy will be used to +/// connect to the host. +/// +/// This function is threadsafe. +pub fn connect( + client: *Client, + host: []const u8, + port: u16, + protocol: Connection.Protocol, +) ConnectError!*Connection { + const proxy = switch (protocol) { + .plain => client.http_proxy, + .tls => client.https_proxy, + } orelse return client.connectTcp(host, port, protocol); + + // Prevent proxying through itself. 
+    if (std.ascii.eqlIgnoreCase(proxy.host, host) and
+        proxy.port == port and proxy.protocol == protocol)
+    {
+        return client.connectTcp(host, port, protocol);
+    }
+
+    if (proxy.supports_connect) tunnel: {
+        return connectTunnel(client, proxy, host, port) catch |err| switch (err) {
+            error.TunnelNotSupported => break :tunnel,
+            else => |e| return e,
+        };
+    }
+
+    // fall back to using the proxy as a normal http proxy
+    const conn = try client.connectTcp(proxy.host, proxy.port, proxy.protocol);
+    errdefer {
+        conn.closing = true;
+        client.connection_pool.release(client.allocator, conn);
+    }
+
+    conn.proxied = true;
+    return conn;
+}
+
+pub const RequestError = ConnectTcpError || ConnectErrorPartial || Request.SendError ||
+    std.fmt.ParseIntError || Connection.WriteError ||
+    error{ // TODO: file a zig fmt issue for this bad indentation
+    UnsupportedUriScheme,
+    UriMissingHost,
+
+    CertificateBundleLoadFailure,
+    UnsupportedTransferEncoding,
+};
+
+pub const RequestOptions = struct {
+    version: http.Version = .@"HTTP/1.1",
+
+    /// Automatically ignore 100 Continue responses. This assumes you don't
+    /// care, and will have sent the body before you wait for the response.
+    ///
+    /// If this is not the case AND you know the server will send a 100
+    /// Continue, set this to false and wait for a response before sending the
+    /// body. If you wait AND the server does not send a 100 Continue before
+    /// you finish the request, then the request *will* deadlock.
+    handle_continue: bool = true,
+
+    /// If false, close the connection after the one request. If true,
+    /// participate in the client connection pool.
+    keep_alive: bool = true,
+
+    /// This field specifies whether to automatically follow redirects, and if
+    /// so, how many redirects to follow before returning an error.
+    ///
+    /// This will only follow redirects for repeatable requests (ie. with no
+    /// payload or the server has acknowledged the payload).
+    redirect_behavior: Request.RedirectBehavior = @enumFromInt(3),
+
+    /// Externally-owned memory used to store the server's entire HTTP header.
+    /// `error.HttpHeadersOversize` is returned from read() when a
+    /// client sends too many bytes of HTTP headers.
+    server_header_buffer: []u8,
+
+    /// Must be an already acquired connection.
+    connection: ?*Connection = null,
+
+    /// Standard headers that have default, but overridable, behavior.
+    headers: Request.Headers = .{},
+    /// These headers are kept including when following a redirect to a
+    /// different domain.
+    /// Externally-owned; must outlive the Request.
+    extra_headers: []const http.Header = &.{},
+    /// These headers are stripped when following a redirect to a different
+    /// domain.
+    /// Externally-owned; must outlive the Request.
+    privileged_headers: []const http.Header = &.{},
+};
+
+fn validateUri(uri: Uri, arena: Allocator) !struct { Connection.Protocol, Uri } {
+    const protocol_map = std.StaticStringMap(Connection.Protocol).initComptime(.{
+        .{ "http", .plain },
+        .{ "ws", .plain },
+        .{ "https", .tls },
+        .{ "wss", .tls },
+    });
+    const protocol = protocol_map.get(uri.scheme) orelse return error.UnsupportedUriScheme;
+    var valid_uri = uri;
+    // The host is always going to be needed as a raw string for hostname resolution anyway.
+    valid_uri.host = .{
+        .raw = try (uri.host orelse return error.UriMissingHost).toRawMaybeAlloc(arena),
+    };
+    return .{ protocol, valid_uri };
+}
+
+fn uriPort(uri: Uri, protocol: Connection.Protocol) u16 {
+    return uri.port orelse switch (protocol) {
+        .plain => 80,
+        .tls => 443,
+    };
+}
+
+/// Open a connection to the host specified by `uri` and prepare to send an HTTP request.
+///
+/// `uri` must remain alive during the entire request.
+///
+/// The caller is responsible for calling `deinit()` on the `Request`.
+/// This function is threadsafe.
+///
+/// Asserts that "\r\n" does not occur in any header name or value.
+pub fn open(
+    client: *Client,
+    method: http.Method,
+    uri: Uri,
+    options: RequestOptions,
+) RequestError!Request {
+    if (std.debug.runtime_safety) {
+        for (options.extra_headers) |header| {
+            assert(header.name.len != 0);
+            assert(std.mem.indexOfScalar(u8, header.name, ':') == null);
+            assert(std.mem.indexOfPosLinear(u8, header.name, 0, "\r\n") == null);
+            assert(std.mem.indexOfPosLinear(u8, header.value, 0, "\r\n") == null);
+        }
+        for (options.privileged_headers) |header| {
+            assert(header.name.len != 0);
+            assert(std.mem.indexOfPosLinear(u8, header.name, 0, "\r\n") == null);
+            assert(std.mem.indexOfPosLinear(u8, header.value, 0, "\r\n") == null);
+        }
+    }
+
+    var server_header = std.heap.FixedBufferAllocator.init(options.server_header_buffer);
+    const protocol, const valid_uri = try validateUri(uri, server_header.allocator());
+
+    if (protocol == .tls and @atomicLoad(bool, &client.next_https_rescan_certs, .acquire)) {
+        if (disable_tls) unreachable;
+
+        client.ca_bundle_mutex.lock();
+        defer client.ca_bundle_mutex.unlock();
+
+        if (client.next_https_rescan_certs) {
+            client.ca_bundle.rescan(client.allocator) catch
+                return error.CertificateBundleLoadFailure;
+            @atomicStore(bool, &client.next_https_rescan_certs, false, .release);
+        }
+    }
+
+    const conn = options.connection orelse
+        try client.connect(valid_uri.host.?.raw, uriPort(valid_uri, protocol), protocol);
+
+    var req: Request = .{
+        .uri = valid_uri,
+        .client = client,
+        .connection = conn,
+        .keep_alive = options.keep_alive,
+        .method = method,
+        .version = options.version,
+        .transfer_encoding = .none,
+        .redirect_behavior = options.redirect_behavior,
+        .handle_continue = options.handle_continue,
+        .response = .{
+            .version = undefined,
+            .status = undefined,
+            .reason = undefined,
+            .keep_alive = undefined,
+            .parser = proto.HeadersParser.init(server_header.buffer[server_header.end_index..]),
+        },
+        .headers = options.headers,
+        .extra_headers = options.extra_headers,
+        .privileged_headers = options.privileged_headers,
+    };
+    errdefer req.deinit();
+
+    return req;
+}
+
+pub const FetchOptions = struct {
+    server_header_buffer: ?[]u8 = null,
+    redirect_behavior: ?Request.RedirectBehavior = null,
+
+    /// If the server sends a body, it will be appended to this ArrayList.
+    /// `max_append_size` provides an upper limit for how much they can grow.
+    response_storage: ResponseStorage = .ignore,
+    max_append_size: ?usize = null,
+
+    location: Location,
+    method: ?http.Method = null,
+    payload: ?[]const u8 = null,
+    raw_uri: bool = false,
+    keep_alive: bool = true,
+
+    /// Standard headers that have default, but overridable, behavior.
+    headers: Request.Headers = .{},
+    /// These headers are kept including when following a redirect to a
+    /// different domain.
+    /// Externally-owned; must outlive the Request.
+    extra_headers: []const http.Header = &.{},
+    /// These headers are stripped when following a redirect to a different
+    /// domain.
+    /// Externally-owned; must outlive the Request.
+    privileged_headers: []const http.Header = &.{},
+
+    pub const Location = union(enum) {
+        url: []const u8,
+        uri: Uri,
+    };
+
+    pub const ResponseStorage = union(enum) {
+        ignore,
+        /// Only the existing capacity will be used.
+        static: *std.ArrayListUnmanaged(u8),
+        dynamic: *std.ArrayList(u8),
+    };
+};
+
+pub const FetchResult = struct {
+    status: http.Status,
+};
+
+/// Perform a one-shot HTTP request with the provided options.
+///
+/// This function is threadsafe.
+pub fn fetch(client: *Client, options: FetchOptions) !FetchResult {
+    const uri = switch (options.location) {
+        .url => |u| try Uri.parse(u),
+        .uri => |u| u,
+    };
+    var server_header_buffer: [16 * 1024]u8 = undefined;
+
+    const method: http.Method = options.method orelse
+        if (options.payload != null) .POST else .GET;
+
+    var req = try open(client, method, uri, .{
+        .server_header_buffer = options.server_header_buffer orelse &server_header_buffer,
+        .redirect_behavior = options.redirect_behavior orelse
+            if (options.payload == null) @enumFromInt(3) else .unhandled,
+        .headers = options.headers,
+        .extra_headers = options.extra_headers,
+        .privileged_headers = options.privileged_headers,
+        .keep_alive = options.keep_alive,
+    });
+    defer req.deinit();
+
+    if (options.payload) |payload| req.transfer_encoding = .{ .content_length = payload.len };
+
+    try req.send();
+
+    if (options.payload) |payload| try req.writeAll(payload);
+
+    try req.finish();
+    try req.wait();
+
+    switch (options.response_storage) {
+        .ignore => {
+            // Take advantage of request internals to discard the response body
+            // and make the connection available for another request.
+            req.response.skip = true;
+            assert(try req.transferRead(&.{}) == 0); // No buffer is necessary when skipping.
+        },
+        .dynamic => |list| {
+            const max_append_size = options.max_append_size orelse 2 * 1024 * 1024;
+            try req.reader().readAllArrayList(list, max_append_size);
+        },
+        .static => |list| {
+            const buf = b: {
+                const buf = list.unusedCapacitySlice();
+                if (options.max_append_size) |len| {
+                    if (len < buf.len) break :b buf[0..len];
+                }
+                break :b buf;
+            };
+            list.items.len += try req.reader().readAll(buf);
+        },
+    }
+
+    return .{
+        .status = req.response.status,
+    };
+}
+
+test {
+    _ = &initDefaultProxies;
+}
diff --git a/vendor/tls.zig b/vendor/tls.zig
new file mode 160000
index 00000000..0ea9e6d7
--- /dev/null
+++ b/vendor/tls.zig
@@ -0,0 +1 @@
+Subproject commit 0ea9e6d769a74946d6554edef4f05850734a48d2
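
For reference, a minimal usage sketch of the fetch() API added by this patch (not part of the diff itself). The import path, the allocator wiring, and the deinit() call are assumptions for illustration; only the FetchOptions/FetchResult fields come from the code above.

// Minimal sketch, assuming the vendored client is reachable at this path and
// exposes deinit() like the std.http.Client it is derived from.
const std = @import("std");
const Client = @import("http/Client.zig"); // hypothetical import path

pub fn main() !void {
    var gpa = std.heap.GeneralPurposeAllocator(.{}){};
    defer _ = gpa.deinit();
    const alloc = gpa.allocator();

    var client = Client{ .allocator = alloc };
    defer client.deinit();

    // Collect the response body into a dynamic list, capped at 1 MiB.
    var body = std.ArrayList(u8).init(alloc);
    defer body.deinit();

    const result = try client.fetch(.{
        .location = .{ .url = "https://example.com/" },
        .response_storage = .{ .dynamic = &body },
        .max_append_size = 1024 * 1024,
    });

    std.debug.print("status={d} body={d} bytes\n", .{ @intFromEnum(result.status), body.items.len });
}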