diff --git a/.gitmodules b/.gitmodules index 2ceea970..229d1a16 100644 --- a/.gitmodules +++ b/.gitmodules @@ -25,3 +25,6 @@ [submodule "vendor/tls.zig"] path = vendor/tls.zig url = git@github.com:ianic/tls.zig.git +[submodule "vendor/zig-async-io"] + path = vendor/zig-async-io + url = git@github.com:lightpanda-io/zig-async-io.git diff --git a/build.zig b/build.zig index 86ad4ef9..8c83d648 100644 --- a/build.zig +++ b/build.zig @@ -159,6 +159,11 @@ fn common( netsurf.addImport("jsruntime", jsruntimemod); step.root_module.addImport("netsurf", netsurf); + const asyncio = b.addModule("asyncio", .{ + .root_source_file = b.path("vendor/zig-async-io/src/lib.zig"), + }); + step.root_module.addImport("asyncio", asyncio); + const tlsmod = b.addModule("tls", .{ .root_source_file = b.path("vendor/tls.zig/src/main.zig"), }); diff --git a/src/browser/browser.zig b/src/browser/browser.zig index e6f646ef..0b58cbaa 100644 --- a/src/browser/browser.zig +++ b/src/browser/browser.zig @@ -40,7 +40,7 @@ const storage = @import("../storage/storage.zig"); const FetchResult = @import("../http/Client.zig").Client.FetchResult; const UserContext = @import("../user_context.zig").UserContext; -const HttpClient = @import("../http/async/main.zig").Client; +const HttpClient = @import("asyncio").Client; const log = std.log.scoped(.browser); diff --git a/src/http/async/io.zig b/src/http/async/io.zig deleted file mode 100644 index 4a11b5b6..00000000 --- a/src/http/async/io.zig +++ /dev/null @@ -1,132 +0,0 @@ -const std = @import("std"); - -const Ctx = @import("std/http/Client.zig").Ctx; -const Loop = @import("jsruntime").Loop; -const NetworkImpl = Loop.Network(SingleThreaded); - -pub const Blocking = struct { - pub fn connect( - _: *Blocking, - comptime CtxT: type, - ctx: *CtxT, - comptime cbk: fn (ctx: *CtxT, res: anyerror!void) anyerror!void, - socket: std.posix.socket_t, - address: std.net.Address, - ) void { - std.posix.connect(socket, &address.any, address.getOsSockLen()) catch |err| { - std.posix.close(socket); - cbk(ctx, err) catch |e| { - ctx.setErr(e); - }; - }; - cbk(ctx, {}) catch |e| ctx.setErr(e); - } - - pub fn send( - _: *Blocking, - comptime CtxT: type, - ctx: *CtxT, - comptime cbk: fn (ctx: *CtxT, res: anyerror!void) anyerror!void, - socket: std.posix.socket_t, - buf: []const u8, - ) void { - const len = std.posix.write(socket, buf) catch |err| { - cbk(ctx, err) catch |e| { - return ctx.setErr(e); - }; - return ctx.setErr(err); - }; - ctx.setLen(len); - cbk(ctx, {}) catch |e| ctx.setErr(e); - } - - pub fn recv( - _: *Blocking, - comptime CtxT: type, - ctx: *CtxT, - comptime cbk: fn (ctx: *CtxT, res: anyerror!void) anyerror!void, - socket: std.posix.socket_t, - buf: []u8, - ) void { - const len = std.posix.read(socket, buf) catch |err| { - cbk(ctx, err) catch |e| { - return ctx.setErr(e); - }; - return ctx.setErr(err); - }; - ctx.setLen(len); - cbk(ctx, {}) catch |e| ctx.setErr(e); - } -}; - -pub const SingleThreaded = struct { - impl: NetworkImpl, - cbk: Cbk, - ctx: *Ctx, - - const Cbk = *const fn (ctx: *Ctx, res: anyerror!void) anyerror!void; - - pub fn init(loop: *Loop) SingleThreaded { - return .{ - .impl = NetworkImpl.init(loop), - .cbk = undefined, - .ctx = undefined, - }; - } - - pub fn connect( - self: *SingleThreaded, - comptime _: type, - ctx: *Ctx, - comptime cbk: Cbk, - socket: std.posix.socket_t, - address: std.net.Address, - ) void { - self.cbk = cbk; - self.ctx = ctx; - self.impl.connect(self, socket, address); - } - - pub fn onConnect(self: *SingleThreaded, err: ?anyerror) void { - if (err) |e| return self.ctx.setErr(e); - self.cbk(self.ctx, {}) catch |e| self.ctx.setErr(e); - } - - pub fn send( - self: *SingleThreaded, - comptime _: type, - ctx: *Ctx, - comptime cbk: Cbk, - socket: std.posix.socket_t, - buf: []const u8, - ) void { - self.ctx = ctx; - self.cbk = cbk; - self.impl.send(self, socket, buf); - } - - pub fn onSend(self: *SingleThreaded, ln: usize, err: ?anyerror) void { - if (err) |e| return self.ctx.setErr(e); - self.ctx.setLen(ln); - self.cbk(self.ctx, {}) catch |e| self.ctx.setErr(e); - } - - pub fn recv( - self: *SingleThreaded, - comptime _: type, - ctx: *Ctx, - comptime cbk: Cbk, - socket: std.posix.socket_t, - buf: []u8, - ) void { - self.ctx = ctx; - self.cbk = cbk; - self.impl.receive(self, socket, buf); - } - - pub fn onReceive(self: *SingleThreaded, ln: usize, err: ?anyerror) void { - if (err) |e| return self.ctx.setErr(e); - self.ctx.setLen(ln); - self.cbk(self.ctx, {}) catch |e| self.ctx.setErr(e); - } -}; diff --git a/src/http/async/main.zig b/src/http/async/main.zig deleted file mode 100644 index c3ef9934..00000000 --- a/src/http/async/main.zig +++ /dev/null @@ -1,3 +0,0 @@ -const std = @import("std"); - -pub const Client = @import("std/http/Client.zig"); diff --git a/src/http/async/stack.zig b/src/http/async/stack.zig deleted file mode 100644 index d19a0c8f..00000000 --- a/src/http/async/stack.zig +++ /dev/null @@ -1,95 +0,0 @@ -const std = @import("std"); - -pub fn Stack(comptime T: type) type { - return struct { - const Self = @This(); - pub const Fn = *const T; - - next: ?*Self = null, - func: Fn, - - pub fn init(alloc: std.mem.Allocator, comptime func: Fn) !*Self { - const next = try alloc.create(Self); - next.* = .{ .func = func }; - return next; - } - - pub fn push(self: *Self, alloc: std.mem.Allocator, comptime func: Fn) !void { - if (self.next) |next| { - return next.push(alloc, func); - } - self.next = try Self.init(alloc, func); - } - - pub fn pop(self: *Self, alloc: std.mem.Allocator, prev: ?*Self) Fn { - if (self.next) |next| { - return next.pop(alloc, self); - } - defer { - if (prev) |p| { - self.deinit(alloc, p); - } - } - return self.func; - } - - pub fn deinit(self: *Self, alloc: std.mem.Allocator, prev: ?*Self) void { - if (self.next) |next| { - // recursivly deinit - next.deinit(alloc, self); - } - if (prev) |p| { - p.next = null; - } - alloc.destroy(self); - } - }; -} - -fn first() u8 { - return 1; -} - -fn second() u8 { - return 2; -} - -test "stack" { - const alloc = std.testing.allocator; - const TestStack = Stack(fn () u8); - - var stack = TestStack{ .func = first }; - try stack.push(alloc, second); - - const a = stack.pop(alloc, null); - try std.testing.expect(a() == 2); - - const b = stack.pop(alloc, null); - try std.testing.expect(b() == 1); -} - -fn first_op(arg: ?*anyopaque) u8 { - const val = @as(*u8, @ptrCast(arg)); - return val.* + @as(u8, 1); -} - -fn second_op(arg: ?*anyopaque) u8 { - const val = @as(*u8, @ptrCast(arg)); - return val.* + @as(u8, 2); -} - -test "opaque stack" { - const alloc = std.testing.allocator; - const TestStack = Stack(fn (?*anyopaque) u8); - - var stack = TestStack{ .func = first_op }; - try stack.push(alloc, second_op); - - const a = stack.pop(alloc, null); - var x: u8 = 5; - try std.testing.expect(a(@as(*anyopaque, @ptrCast(&x))) == 2 + x); - - const b = stack.pop(alloc, null); - var y: u8 = 3; - try std.testing.expect(b(@as(*anyopaque, @ptrCast(&y))) == 1 + y); -} diff --git a/src/http/async/std/http.zig b/src/http/async/std/http.zig deleted file mode 100644 index f027d440..00000000 --- a/src/http/async/std/http.zig +++ /dev/null @@ -1,318 +0,0 @@ -pub const Client = @import("http/Client.zig"); -pub const Server = @import("http/Server.zig"); -pub const protocol = @import("http/protocol.zig"); -pub const HeadParser = std.http.HeadParser; -pub const ChunkParser = std.http.ChunkParser; -pub const HeaderIterator = std.http.HeaderIterator; - -pub const Version = enum { - @"HTTP/1.0", - @"HTTP/1.1", -}; - -/// https://developer.mozilla.org/en-US/docs/Web/HTTP/Methods -/// -/// https://datatracker.ietf.org/doc/html/rfc7231#section-4 Initial definition -/// -/// https://datatracker.ietf.org/doc/html/rfc5789#section-2 PATCH -pub const Method = enum(u64) { - GET = parse("GET"), - HEAD = parse("HEAD"), - POST = parse("POST"), - PUT = parse("PUT"), - DELETE = parse("DELETE"), - CONNECT = parse("CONNECT"), - OPTIONS = parse("OPTIONS"), - TRACE = parse("TRACE"), - PATCH = parse("PATCH"), - - _, - - /// Converts `s` into a type that may be used as a `Method` field. - /// Asserts that `s` is 24 or fewer bytes. - pub fn parse(s: []const u8) u64 { - var x: u64 = 0; - const len = @min(s.len, @sizeOf(@TypeOf(x))); - @memcpy(std.mem.asBytes(&x)[0..len], s[0..len]); - return x; - } - - pub fn write(self: Method, w: anytype) !void { - const bytes = std.mem.asBytes(&@intFromEnum(self)); - const str = std.mem.sliceTo(bytes, 0); - try w.writeAll(str); - } - - /// Returns true if a request of this method is allowed to have a body - /// Actual behavior from servers may vary and should still be checked - pub fn requestHasBody(self: Method) bool { - return switch (self) { - .POST, .PUT, .PATCH => true, - .GET, .HEAD, .DELETE, .CONNECT, .OPTIONS, .TRACE => false, - else => true, - }; - } - - /// Returns true if a response to this method is allowed to have a body - /// Actual behavior from clients may vary and should still be checked - pub fn responseHasBody(self: Method) bool { - return switch (self) { - .GET, .POST, .DELETE, .CONNECT, .OPTIONS, .PATCH => true, - .HEAD, .PUT, .TRACE => false, - else => true, - }; - } - - /// An HTTP method is safe if it doesn't alter the state of the server. - /// - /// https://developer.mozilla.org/en-US/docs/Glossary/Safe/HTTP - /// - /// https://datatracker.ietf.org/doc/html/rfc7231#section-4.2.1 - pub fn safe(self: Method) bool { - return switch (self) { - .GET, .HEAD, .OPTIONS, .TRACE => true, - .POST, .PUT, .DELETE, .CONNECT, .PATCH => false, - else => false, - }; - } - - /// An HTTP method is idempotent if an identical request can be made once or several times in a row with the same effect while leaving the server in the same state. - /// - /// https://developer.mozilla.org/en-US/docs/Glossary/Idempotent - /// - /// https://datatracker.ietf.org/doc/html/rfc7231#section-4.2.2 - pub fn idempotent(self: Method) bool { - return switch (self) { - .GET, .HEAD, .PUT, .DELETE, .OPTIONS, .TRACE => true, - .CONNECT, .POST, .PATCH => false, - else => false, - }; - } - - /// A cacheable response is an HTTP response that can be cached, that is stored to be retrieved and used later, saving a new request to the server. - /// - /// https://developer.mozilla.org/en-US/docs/Glossary/cacheable - /// - /// https://datatracker.ietf.org/doc/html/rfc7231#section-4.2.3 - pub fn cacheable(self: Method) bool { - return switch (self) { - .GET, .HEAD => true, - .POST, .PUT, .DELETE, .CONNECT, .OPTIONS, .TRACE, .PATCH => false, - else => false, - }; - } -}; - -/// https://developer.mozilla.org/en-US/docs/Web/HTTP/Status -pub const Status = enum(u10) { - @"continue" = 100, // RFC7231, Section 6.2.1 - switching_protocols = 101, // RFC7231, Section 6.2.2 - processing = 102, // RFC2518 - early_hints = 103, // RFC8297 - - ok = 200, // RFC7231, Section 6.3.1 - created = 201, // RFC7231, Section 6.3.2 - accepted = 202, // RFC7231, Section 6.3.3 - non_authoritative_info = 203, // RFC7231, Section 6.3.4 - no_content = 204, // RFC7231, Section 6.3.5 - reset_content = 205, // RFC7231, Section 6.3.6 - partial_content = 206, // RFC7233, Section 4.1 - multi_status = 207, // RFC4918 - already_reported = 208, // RFC5842 - im_used = 226, // RFC3229 - - multiple_choice = 300, // RFC7231, Section 6.4.1 - moved_permanently = 301, // RFC7231, Section 6.4.2 - found = 302, // RFC7231, Section 6.4.3 - see_other = 303, // RFC7231, Section 6.4.4 - not_modified = 304, // RFC7232, Section 4.1 - use_proxy = 305, // RFC7231, Section 6.4.5 - temporary_redirect = 307, // RFC7231, Section 6.4.7 - permanent_redirect = 308, // RFC7538 - - bad_request = 400, // RFC7231, Section 6.5.1 - unauthorized = 401, // RFC7235, Section 3.1 - payment_required = 402, // RFC7231, Section 6.5.2 - forbidden = 403, // RFC7231, Section 6.5.3 - not_found = 404, // RFC7231, Section 6.5.4 - method_not_allowed = 405, // RFC7231, Section 6.5.5 - not_acceptable = 406, // RFC7231, Section 6.5.6 - proxy_auth_required = 407, // RFC7235, Section 3.2 - request_timeout = 408, // RFC7231, Section 6.5.7 - conflict = 409, // RFC7231, Section 6.5.8 - gone = 410, // RFC7231, Section 6.5.9 - length_required = 411, // RFC7231, Section 6.5.10 - precondition_failed = 412, // RFC7232, Section 4.2][RFC8144, Section 3.2 - payload_too_large = 413, // RFC7231, Section 6.5.11 - uri_too_long = 414, // RFC7231, Section 6.5.12 - unsupported_media_type = 415, // RFC7231, Section 6.5.13][RFC7694, Section 3 - range_not_satisfiable = 416, // RFC7233, Section 4.4 - expectation_failed = 417, // RFC7231, Section 6.5.14 - teapot = 418, // RFC 7168, 2.3.3 - misdirected_request = 421, // RFC7540, Section 9.1.2 - unprocessable_entity = 422, // RFC4918 - locked = 423, // RFC4918 - failed_dependency = 424, // RFC4918 - too_early = 425, // RFC8470 - upgrade_required = 426, // RFC7231, Section 6.5.15 - precondition_required = 428, // RFC6585 - too_many_requests = 429, // RFC6585 - request_header_fields_too_large = 431, // RFC6585 - unavailable_for_legal_reasons = 451, // RFC7725 - - internal_server_error = 500, // RFC7231, Section 6.6.1 - not_implemented = 501, // RFC7231, Section 6.6.2 - bad_gateway = 502, // RFC7231, Section 6.6.3 - service_unavailable = 503, // RFC7231, Section 6.6.4 - gateway_timeout = 504, // RFC7231, Section 6.6.5 - http_version_not_supported = 505, // RFC7231, Section 6.6.6 - variant_also_negotiates = 506, // RFC2295 - insufficient_storage = 507, // RFC4918 - loop_detected = 508, // RFC5842 - not_extended = 510, // RFC2774 - network_authentication_required = 511, // RFC6585 - - _, - - pub fn phrase(self: Status) ?[]const u8 { - return switch (self) { - // 1xx statuses - .@"continue" => "Continue", - .switching_protocols => "Switching Protocols", - .processing => "Processing", - .early_hints => "Early Hints", - - // 2xx statuses - .ok => "OK", - .created => "Created", - .accepted => "Accepted", - .non_authoritative_info => "Non-Authoritative Information", - .no_content => "No Content", - .reset_content => "Reset Content", - .partial_content => "Partial Content", - .multi_status => "Multi-Status", - .already_reported => "Already Reported", - .im_used => "IM Used", - - // 3xx statuses - .multiple_choice => "Multiple Choice", - .moved_permanently => "Moved Permanently", - .found => "Found", - .see_other => "See Other", - .not_modified => "Not Modified", - .use_proxy => "Use Proxy", - .temporary_redirect => "Temporary Redirect", - .permanent_redirect => "Permanent Redirect", - - // 4xx statuses - .bad_request => "Bad Request", - .unauthorized => "Unauthorized", - .payment_required => "Payment Required", - .forbidden => "Forbidden", - .not_found => "Not Found", - .method_not_allowed => "Method Not Allowed", - .not_acceptable => "Not Acceptable", - .proxy_auth_required => "Proxy Authentication Required", - .request_timeout => "Request Timeout", - .conflict => "Conflict", - .gone => "Gone", - .length_required => "Length Required", - .precondition_failed => "Precondition Failed", - .payload_too_large => "Payload Too Large", - .uri_too_long => "URI Too Long", - .unsupported_media_type => "Unsupported Media Type", - .range_not_satisfiable => "Range Not Satisfiable", - .expectation_failed => "Expectation Failed", - .teapot => "I'm a teapot", - .misdirected_request => "Misdirected Request", - .unprocessable_entity => "Unprocessable Entity", - .locked => "Locked", - .failed_dependency => "Failed Dependency", - .too_early => "Too Early", - .upgrade_required => "Upgrade Required", - .precondition_required => "Precondition Required", - .too_many_requests => "Too Many Requests", - .request_header_fields_too_large => "Request Header Fields Too Large", - .unavailable_for_legal_reasons => "Unavailable For Legal Reasons", - - // 5xx statuses - .internal_server_error => "Internal Server Error", - .not_implemented => "Not Implemented", - .bad_gateway => "Bad Gateway", - .service_unavailable => "Service Unavailable", - .gateway_timeout => "Gateway Timeout", - .http_version_not_supported => "HTTP Version Not Supported", - .variant_also_negotiates => "Variant Also Negotiates", - .insufficient_storage => "Insufficient Storage", - .loop_detected => "Loop Detected", - .not_extended => "Not Extended", - .network_authentication_required => "Network Authentication Required", - - else => return null, - }; - } - - pub const Class = enum { - informational, - success, - redirect, - client_error, - server_error, - }; - - pub fn class(self: Status) Class { - return switch (@intFromEnum(self)) { - 100...199 => .informational, - 200...299 => .success, - 300...399 => .redirect, - 400...499 => .client_error, - else => .server_error, - }; - } - - test { - try std.testing.expectEqualStrings("OK", Status.ok.phrase().?); - try std.testing.expectEqualStrings("Not Found", Status.not_found.phrase().?); - } - - test { - try std.testing.expectEqual(Status.Class.success, Status.ok.class()); - try std.testing.expectEqual(Status.Class.client_error, Status.not_found.class()); - } -}; - -pub const TransferEncoding = enum { - chunked, - none, - // compression is intentionally omitted here, as std.http.Client stores it as content-encoding -}; - -pub const ContentEncoding = enum { - identity, - compress, - @"x-compress", - deflate, - gzip, - @"x-gzip", - zstd, -}; - -pub const Connection = enum { - keep_alive, - close, -}; - -pub const Header = struct { - name: []const u8, - value: []const u8, -}; - -const builtin = @import("builtin"); -const std = @import("std"); - -test { - _ = Client; - _ = Method; - _ = Server; - _ = Status; -} diff --git a/src/http/async/std/http/Client.zig b/src/http/async/std/http/Client.zig deleted file mode 100644 index 2c866e6f..00000000 --- a/src/http/async/std/http/Client.zig +++ /dev/null @@ -1,2512 +0,0 @@ -//! HTTP(S) Client implementation. -//! -//! Connections are opened in a thread-safe manner, but individual Requests are not. -//! -//! TLS support may be disabled via `std.options.http_disable_tls`. - -const std = @import("std"); -const builtin = @import("builtin"); -const testing = std.testing; -const http = std.http; -const mem = std.mem; -const net = @import("../net.zig"); -const Uri = std.Uri; -const Allocator = mem.Allocator; -const assert = std.debug.assert; -const use_vectors = builtin.zig_backend != .stage2_x86_64; - -const Client = @This(); -const proto = @import("protocol.zig"); - -const tls23 = @import("../../tls.zig/main.zig"); -const VecPut = @import("../../tls.zig/connection.zig").VecPut; -const GenericStack = @import("../../stack.zig").Stack; -const async_io = @import("../../io.zig"); -pub const Loop = async_io.SingleThreaded; - -const cipher = @import("../../tls.zig/cipher.zig"); - -pub const disable_tls = std.options.http_disable_tls; - -/// Used for all client allocations. Must be thread-safe. -allocator: Allocator, - -ca_bundle: if (disable_tls) void else std.crypto.Certificate.Bundle = if (disable_tls) {} else .{}, -ca_bundle_mutex: std.Thread.Mutex = .{}, - -/// When this is `true`, the next time this client performs an HTTPS request, -/// it will first rescan the system for root certificates. -next_https_rescan_certs: bool = true, - -/// The pool of connections that can be reused (and currently in use). -connection_pool: ConnectionPool = .{}, - -/// If populated, all http traffic travels through this third party. -/// This field cannot be modified while the client has active connections. -/// Pointer to externally-owned memory. -http_proxy: ?*Proxy = null, -/// If populated, all https traffic travels through this third party. -/// This field cannot be modified while the client has active connections. -/// Pointer to externally-owned memory. -https_proxy: ?*Proxy = null, - -/// A set of linked lists of connections that can be reused. -pub const ConnectionPool = struct { - mutex: std.Thread.Mutex = .{}, - /// Open connections that are currently in use. - used: Queue = .{}, - /// Open connections that are not currently in use. - free: Queue = .{}, - free_len: usize = 0, - free_size: usize = 32, - - /// The criteria for a connection to be considered a match. - pub const Criteria = struct { - host: []const u8, - port: u16, - protocol: Connection.Protocol, - }; - - const Queue = std.DoublyLinkedList(Connection); - pub const Node = Queue.Node; - - /// Finds and acquires a connection from the connection pool matching the criteria. This function is threadsafe. - /// If no connection is found, null is returned. - pub fn findConnection(pool: *ConnectionPool, criteria: Criteria) ?*Connection { - pool.mutex.lock(); - defer pool.mutex.unlock(); - - var next = pool.free.last; - while (next) |node| : (next = node.prev) { - if (node.data.protocol != criteria.protocol) continue; - if (node.data.port != criteria.port) continue; - - // Domain names are case-insensitive (RFC 5890, Section 2.3.2.4) - if (!std.ascii.eqlIgnoreCase(node.data.host, criteria.host)) continue; - - pool.acquireUnsafe(node); - return &node.data; - } - - return null; - } - - /// Acquires an existing connection from the connection pool. This function is not threadsafe. - pub fn acquireUnsafe(pool: *ConnectionPool, node: *Node) void { - pool.free.remove(node); - pool.free_len -= 1; - - pool.used.append(node); - } - - /// Acquires an existing connection from the connection pool. This function is threadsafe. - pub fn acquire(pool: *ConnectionPool, node: *Node) void { - pool.mutex.lock(); - defer pool.mutex.unlock(); - - return pool.acquireUnsafe(node); - } - - /// Tries to release a connection back to the connection pool. This function is threadsafe. - /// If the connection is marked as closing, it will be closed instead. - /// - /// The allocator must be the owner of all nodes in this pool. - /// The allocator must be the owner of all resources associated with the connection. - pub fn release(pool: *ConnectionPool, allocator: Allocator, connection: *Connection) void { - pool.mutex.lock(); - defer pool.mutex.unlock(); - - const node: *Node = @fieldParentPtr("data", connection); - - pool.used.remove(node); - - if (node.data.closing or pool.free_size == 0) { - node.data.close(allocator); - return allocator.destroy(node); - } - - if (pool.free_len >= pool.free_size) { - const popped = pool.free.popFirst() orelse unreachable; - pool.free_len -= 1; - - popped.data.close(allocator); - allocator.destroy(popped); - } - - if (node.data.proxied) { - pool.free.prepend(node); // proxied connections go to the end of the queue, always try direct connections first - } else { - pool.free.append(node); - } - - pool.free_len += 1; - } - - /// Adds a newly created node to the pool of used connections. This function is threadsafe. - pub fn addUsed(pool: *ConnectionPool, node: *Node) void { - pool.mutex.lock(); - defer pool.mutex.unlock(); - - pool.used.append(node); - } - - /// Resizes the connection pool. This function is threadsafe. - /// - /// If the new size is smaller than the current size, then idle connections will be closed until the pool is the new size. - pub fn resize(pool: *ConnectionPool, allocator: Allocator, new_size: usize) void { - pool.mutex.lock(); - defer pool.mutex.unlock(); - - const next = pool.free.first; - _ = next; - while (pool.free_len > new_size) { - const popped = pool.free.popFirst() orelse unreachable; - pool.free_len -= 1; - - popped.data.close(allocator); - allocator.destroy(popped); - } - - pool.free_size = new_size; - } - - /// Frees the connection pool and closes all connections within. This function is threadsafe. - /// - /// All future operations on the connection pool will deadlock. - pub fn deinit(pool: *ConnectionPool, allocator: Allocator) void { - pool.mutex.lock(); - - var next = pool.free.first; - while (next) |node| { - defer allocator.destroy(node); - next = node.next; - - node.data.close(allocator); - } - - next = pool.used.first; - while (next) |node| { - defer allocator.destroy(node); - next = node.next; - - node.data.close(allocator); - } - - pool.* = undefined; - } -}; - -/// An interface to either a plain or TLS connection. -pub const Connection = struct { - stream: net.Stream, - /// undefined unless protocol is tls. - tls_client: if (!disable_tls) *tls23.Connection(net.Stream) else void, - - /// The protocol that this connection is using. - protocol: Protocol, - - /// The host that this connection is connected to. - host: []u8, - - /// The port that this connection is connected to. - port: u16, - - /// Whether this connection is proxied and is not directly connected. - proxied: bool = false, - - /// Whether this connection is closing when we're done with it. - closing: bool = false, - - read_start: BufferSize = 0, - read_end: BufferSize = 0, - write_end: BufferSize = 0, - read_buf: [buffer_size]u8 = undefined, - write_buf: [buffer_size]u8 = undefined, - - pub const buffer_size = std.crypto.tls.max_ciphertext_record_len; - const BufferSize = std.math.IntFittingRange(0, buffer_size); - - pub const Protocol = enum { plain, tls }; - - pub fn async_readvDirect( - conn: *Connection, - buffers: []std.posix.iovec, - ctx: *Ctx, - comptime cbk: Cbk, - ) !void { - _ = conn; - - if (ctx.conn().protocol == .tls) { - if (disable_tls) unreachable; - - return ctx.conn().tls_client.async_readv(ctx.conn().stream, buffers, ctx, cbk); - } - - return ctx.stream().async_readv(buffers, ctx, cbk); - } - - pub fn readvDirectTls(conn: *Connection, buffers: []std.posix.iovec) ReadError!usize { - return conn.tls_client.readv(buffers) catch |err| { - // https://github.com/ziglang/zig/issues/2473 - if (mem.startsWith(u8, @errorName(err), "TlsAlert")) return error.TlsAlert; - - switch (err) { - error.TlsRecordOverflow, error.TlsBadRecordMac, error.TlsUnexpectedMessage => return error.TlsFailure, - error.ConnectionTimedOut => return error.ConnectionTimedOut, - error.ConnectionResetByPeer, error.BrokenPipe => return error.ConnectionResetByPeer, - else => return error.UnexpectedReadFailure, - } - }; - } - - pub fn readvDirect(conn: *Connection, buffers: []std.posix.iovec) ReadError!usize { - if (conn.protocol == .tls) { - if (disable_tls) unreachable; - - return conn.readvDirectTls(buffers); - } - - return conn.stream.readv(buffers) catch |err| switch (err) { - error.ConnectionTimedOut => return error.ConnectionTimedOut, - error.ConnectionResetByPeer, error.BrokenPipe => return error.ConnectionResetByPeer, - else => return error.UnexpectedReadFailure, - }; - } - - fn onFill(ctx: *Ctx, res: anyerror!void) anyerror!void { - ctx.alloc().free(ctx._iovecs); - res catch |err| return ctx.pop(err); - - // EOF - const nread = ctx.len(); - if (nread == 0) return ctx.pop(error.EndOfStream); - - // finished - ctx.conn().read_start = 0; - ctx.conn().read_end = @intCast(nread); - return ctx.pop({}); - } - - pub fn async_fill(conn: *Connection, ctx: *Ctx, comptime cbk: Cbk) !void { - if (conn.read_end != conn.read_start) return; - - ctx._iovecs = try ctx.alloc().alloc(std.posix.iovec, 1); - errdefer ctx.alloc().free(ctx._iovecs); - const iovecs = [1]std.posix.iovec{ - .{ .base = &conn.read_buf, .len = conn.read_buf.len }, - }; - @memcpy(ctx._iovecs, &iovecs); - - try ctx.push(cbk); - return conn.async_readvDirect(ctx._iovecs, ctx, onFill); - } - - /// Refills the read buffer with data from the connection. - pub fn fill(conn: *Connection) ReadError!void { - if (conn.read_end != conn.read_start) return; - - var iovecs = [1]std.posix.iovec{ - .{ .base = &conn.read_buf, .len = conn.read_buf.len }, - }; - const nread = try conn.readvDirect(&iovecs); - if (nread == 0) return error.EndOfStream; - conn.read_start = 0; - conn.read_end = @intCast(nread); - } - - /// Returns the current slice of buffered data. - pub fn peek(conn: *Connection) []const u8 { - return conn.read_buf[conn.read_start..conn.read_end]; - } - - /// Discards the given number of bytes from the read buffer. - pub fn drop(conn: *Connection, num: BufferSize) void { - conn.read_start += num; - } - - /// Reads data from the connection into the given buffer. - pub fn read(conn: *Connection, buffer: []u8) ReadError!usize { - const available_read = conn.read_end - conn.read_start; - const available_buffer = buffer.len; - - if (available_read > available_buffer) { // partially read buffered data - @memcpy(buffer[0..available_buffer], conn.read_buf[conn.read_start..conn.read_end][0..available_buffer]); - conn.read_start += @intCast(available_buffer); - - return available_buffer; - } else if (available_read > 0) { // fully read buffered data - @memcpy(buffer[0..available_read], conn.read_buf[conn.read_start..conn.read_end]); - conn.read_start += available_read; - - return available_read; - } - - var iovecs = [2]std.posix.iovec{ - .{ .base = buffer.ptr, .len = buffer.len }, - .{ .base = &conn.read_buf, .len = conn.read_buf.len }, - }; - const nread = try conn.readvDirect(&iovecs); - - if (nread > buffer.len) { - conn.read_start = 0; - conn.read_end = @intCast(nread - buffer.len); - return buffer.len; - } - - return nread; - } - - pub const ReadError = error{ - TlsFailure, - TlsAlert, - ConnectionTimedOut, - ConnectionResetByPeer, - UnexpectedReadFailure, - EndOfStream, - }; - - pub const Reader = std.io.Reader(*Connection, ReadError, read); - - pub fn reader(conn: *Connection) Reader { - return Reader{ .context = conn }; - } - - pub fn writeAllDirectTls(conn: *Connection, buffer: []const u8) WriteError!void { - return conn.tls_client.writeAll(buffer) catch |err| switch (err) { - error.BrokenPipe, error.ConnectionResetByPeer => return error.ConnectionResetByPeer, - else => return error.UnexpectedWriteFailure, - }; - } - - fn onWriteAllDirect(ctx: *Ctx, res: anyerror!void) !void { - res catch |err| switch (err) { - error.BrokenPipe, - error.ConnectionResetByPeer, - => return ctx.pop(error.ConnectionResetByPeer), - else => return ctx.pop(error.UnexpectedWriteFailure), - }; - return ctx.pop({}); - } - - pub fn async_writeAllDirect( - conn: *Connection, - buffer: []const u8, - ctx: *Ctx, - comptime cbk: Cbk, - ) !void { - try ctx.push(cbk); - if (conn.protocol == .tls) { - if (disable_tls) unreachable; - - return conn.tls_client.async_writeAll(conn.stream, buffer, ctx, onWriteAllDirect); - } - - return conn.stream.async_writeAll(buffer, ctx, onWriteAllDirect); - } - - pub fn writeAllDirect(conn: *Connection, buffer: []const u8) WriteError!void { - if (conn.protocol == .tls) { - if (disable_tls) unreachable; - - return conn.writeAllDirectTls(buffer); - } - - return conn.stream.writeAll(buffer) catch |err| switch (err) { - error.BrokenPipe, error.ConnectionResetByPeer => return error.ConnectionResetByPeer, - else => return error.UnexpectedWriteFailure, - }; - } - - /// Writes the given buffer to the connection. - pub fn write(conn: *Connection, buffer: []const u8) WriteError!usize { - if (conn.write_buf.len - conn.write_end < buffer.len) { - try conn.flush(); - - if (buffer.len > conn.write_buf.len) { - try conn.writeAllDirect(buffer); - return buffer.len; - } - } - - @memcpy(conn.write_buf[conn.write_end..][0..buffer.len], buffer); - conn.write_end += @intCast(buffer.len); - - return buffer.len; - } - - /// Returns a buffer to be filled with exactly len bytes to write to the connection. - pub fn allocWriteBuffer(conn: *Connection, len: BufferSize) WriteError![]u8 { - if (conn.write_buf.len - conn.write_end < len) try conn.flush(); - defer conn.write_end += len; - return conn.write_buf[conn.write_end..][0..len]; - } - - fn onFlush(ctx: *Ctx, res: anyerror!void) !void { - res catch |err| return ctx.pop(err); - ctx.conn().write_end = 0; - return ctx.pop({}); - } - - pub fn async_flush(conn: *Connection, ctx: *Ctx, comptime cbk: Cbk) !void { - if (conn.write_end == 0) return error.WriteEmpty; - - try ctx.push(cbk); - try conn.async_writeAllDirect(conn.write_buf[0..conn.write_end], ctx, onFlush); - } - - /// Flushes the write buffer to the connection. - pub fn flush(conn: *Connection) WriteError!void { - if (conn.write_end == 0) return; - - try conn.writeAllDirect(conn.write_buf[0..conn.write_end]); - conn.write_end = 0; - } - - pub const WriteError = error{ - ConnectionResetByPeer, - UnexpectedWriteFailure, - }; - - pub const Writer = std.io.Writer(*Connection, WriteError, write); - - pub fn writer(conn: *Connection) Writer { - return Writer{ .context = conn }; - } - - /// Closes the connection. - pub fn close(conn: *Connection, allocator: Allocator) void { - if (conn.protocol == .tls) { - if (disable_tls) unreachable; - - // try to cleanly close the TLS connection, for any server that cares. - conn.tls_client.close() catch {}; - allocator.destroy(conn.tls_client); - } - - conn.stream.close(); - allocator.free(conn.host); - } -}; - -/// The mode of transport for requests. -pub const RequestTransfer = union(enum) { - content_length: u64, - chunked: void, - none: void, -}; - -/// The decompressor for response messages. -pub const Compression = union(enum) { - pub const DeflateDecompressor = std.compress.zlib.Decompressor(Request.TransferReader); - pub const GzipDecompressor = std.compress.gzip.Decompressor(Request.TransferReader); - // https://github.com/ziglang/zig/issues/18937 - //pub const ZstdDecompressor = std.compress.zstd.DecompressStream(Request.TransferReader, .{}); - - deflate: DeflateDecompressor, - gzip: GzipDecompressor, - // https://github.com/ziglang/zig/issues/18937 - //zstd: ZstdDecompressor, - none: void, -}; - -/// A HTTP response originating from a server. -pub const Response = struct { - version: http.Version, - status: http.Status, - reason: []const u8, - - /// Points into the user-provided `server_header_buffer`. - location: ?[]const u8 = null, - /// Points into the user-provided `server_header_buffer`. - content_type: ?[]const u8 = null, - /// Points into the user-provided `server_header_buffer`. - content_disposition: ?[]const u8 = null, - - keep_alive: bool, - - /// If present, the number of bytes in the response body. - content_length: ?u64 = null, - - /// If present, the transfer encoding of the response body, otherwise none. - transfer_encoding: http.TransferEncoding = .none, - - /// If present, the compression of the response body, otherwise identity (no compression). - transfer_compression: http.ContentEncoding = .identity, - - parser: proto.HeadersParser, - compression: Compression = .none, - - /// Whether the response body should be skipped. Any data read from the - /// response body will be discarded. - skip: bool = false, - - pub const ParseError = error{ - HttpHeadersInvalid, - HttpHeaderContinuationsUnsupported, - HttpTransferEncodingUnsupported, - HttpConnectionHeaderUnsupported, - InvalidContentLength, - CompressionUnsupported, - }; - - pub fn parse(res: *Response, bytes: []const u8) ParseError!void { - var it = mem.splitSequence(u8, bytes, "\r\n"); - - const first_line = it.next().?; - if (first_line.len < 12) { - return error.HttpHeadersInvalid; - } - - const version: http.Version = switch (int64(first_line[0..8])) { - int64("HTTP/1.0") => .@"HTTP/1.0", - int64("HTTP/1.1") => .@"HTTP/1.1", - else => return error.HttpHeadersInvalid, - }; - if (first_line[8] != ' ') return error.HttpHeadersInvalid; - const status: http.Status = @enumFromInt(parseInt3(first_line[9..12])); - const reason = mem.trimLeft(u8, first_line[12..], " "); - - res.version = version; - res.status = status; - res.reason = reason; - res.keep_alive = switch (version) { - .@"HTTP/1.0" => false, - .@"HTTP/1.1" => true, - }; - - while (it.next()) |line| { - if (line.len == 0) return; - switch (line[0]) { - ' ', '\t' => return error.HttpHeaderContinuationsUnsupported, - else => {}, - } - - var line_it = mem.splitScalar(u8, line, ':'); - const header_name = line_it.next().?; - const header_value = mem.trim(u8, line_it.rest(), " \t"); - if (header_name.len == 0) return error.HttpHeadersInvalid; - - if (std.ascii.eqlIgnoreCase(header_name, "connection")) { - res.keep_alive = !std.ascii.eqlIgnoreCase(header_value, "close"); - } else if (std.ascii.eqlIgnoreCase(header_name, "content-type")) { - res.content_type = header_value; - } else if (std.ascii.eqlIgnoreCase(header_name, "location")) { - res.location = header_value; - } else if (std.ascii.eqlIgnoreCase(header_name, "content-disposition")) { - res.content_disposition = header_value; - } else if (std.ascii.eqlIgnoreCase(header_name, "transfer-encoding")) { - // Transfer-Encoding: second, first - // Transfer-Encoding: deflate, chunked - var iter = mem.splitBackwardsScalar(u8, header_value, ','); - - const first = iter.first(); - const trimmed_first = mem.trim(u8, first, " "); - - var next: ?[]const u8 = first; - if (std.meta.stringToEnum(http.TransferEncoding, trimmed_first)) |transfer| { - if (res.transfer_encoding != .none) return error.HttpHeadersInvalid; // we already have a transfer encoding - res.transfer_encoding = transfer; - - next = iter.next(); - } - - if (next) |second| { - const trimmed_second = mem.trim(u8, second, " "); - - if (std.meta.stringToEnum(http.ContentEncoding, trimmed_second)) |transfer| { - if (res.transfer_compression != .identity) return error.HttpHeadersInvalid; // double compression is not supported - res.transfer_compression = transfer; - } else { - return error.HttpTransferEncodingUnsupported; - } - } - - if (iter.next()) |_| return error.HttpTransferEncodingUnsupported; - } else if (std.ascii.eqlIgnoreCase(header_name, "content-length")) { - const content_length = std.fmt.parseInt(u64, header_value, 10) catch return error.InvalidContentLength; - - if (res.content_length != null and res.content_length != content_length) return error.HttpHeadersInvalid; - - res.content_length = content_length; - } else if (std.ascii.eqlIgnoreCase(header_name, "content-encoding")) { - if (res.transfer_compression != .identity) return error.HttpHeadersInvalid; - - const trimmed = mem.trim(u8, header_value, " "); - - if (std.meta.stringToEnum(http.ContentEncoding, trimmed)) |ce| { - res.transfer_compression = ce; - } else { - return error.HttpTransferEncodingUnsupported; - } - } - } - return error.HttpHeadersInvalid; // missing empty line - } - - test parse { - const response_bytes = "HTTP/1.1 200 OK\r\n" ++ - "LOcation:url\r\n" ++ - "content-tYpe: text/plain\r\n" ++ - "content-disposition:attachment; filename=example.txt \r\n" ++ - "content-Length:10\r\n" ++ - "TRansfer-encoding:\tdeflate, chunked \r\n" ++ - "connectioN:\t keep-alive \r\n\r\n"; - - var header_buffer: [1024]u8 = undefined; - var res = Response{ - .status = undefined, - .reason = undefined, - .version = undefined, - .keep_alive = false, - .parser = proto.HeadersParser.init(&header_buffer), - }; - - @memcpy(header_buffer[0..response_bytes.len], response_bytes); - res.parser.header_bytes_len = response_bytes.len; - - try res.parse(response_bytes); - - try testing.expectEqual(.@"HTTP/1.1", res.version); - try testing.expectEqualStrings("OK", res.reason); - try testing.expectEqual(.ok, res.status); - - try testing.expectEqualStrings("url", res.location.?); - try testing.expectEqualStrings("text/plain", res.content_type.?); - try testing.expectEqualStrings("attachment; filename=example.txt", res.content_disposition.?); - - try testing.expectEqual(true, res.keep_alive); - try testing.expectEqual(10, res.content_length.?); - try testing.expectEqual(.chunked, res.transfer_encoding); - try testing.expectEqual(.deflate, res.transfer_compression); - } - - inline fn int64(array: *const [8]u8) u64 { - return @bitCast(array.*); - } - - fn parseInt3(text: *const [3]u8) u10 { - if (use_vectors) { - const nnn: @Vector(3, u8) = text.*; - const zero: @Vector(3, u8) = .{ '0', '0', '0' }; - const mmm: @Vector(3, u10) = .{ 100, 10, 1 }; - return @reduce(.Add, @as(@Vector(3, u10), nnn -% zero) *% mmm); - } - return std.fmt.parseInt(u10, text, 10) catch unreachable; - } - - test parseInt3 { - const expectEqual = testing.expectEqual; - try expectEqual(@as(u10, 0), parseInt3("000")); - try expectEqual(@as(u10, 418), parseInt3("418")); - try expectEqual(@as(u10, 999), parseInt3("999")); - } - - pub fn iterateHeaders(r: Response) http.HeaderIterator { - return http.HeaderIterator.init(r.parser.get()); - } - - test iterateHeaders { - const response_bytes = "HTTP/1.1 200 OK\r\n" ++ - "LOcation:url\r\n" ++ - "content-tYpe: text/plain\r\n" ++ - "content-disposition:attachment; filename=example.txt \r\n" ++ - "content-Length:10\r\n" ++ - "TRansfer-encoding:\tdeflate, chunked \r\n" ++ - "connectioN:\t keep-alive \r\n\r\n"; - - var header_buffer: [1024]u8 = undefined; - var res = Response{ - .status = undefined, - .reason = undefined, - .version = undefined, - .keep_alive = false, - .parser = proto.HeadersParser.init(&header_buffer), - }; - - @memcpy(header_buffer[0..response_bytes.len], response_bytes); - res.parser.header_bytes_len = response_bytes.len; - - var it = res.iterateHeaders(); - { - const header = it.next().?; - try testing.expectEqualStrings("LOcation", header.name); - try testing.expectEqualStrings("url", header.value); - try testing.expect(!it.is_trailer); - } - { - const header = it.next().?; - try testing.expectEqualStrings("content-tYpe", header.name); - try testing.expectEqualStrings("text/plain", header.value); - try testing.expect(!it.is_trailer); - } - { - const header = it.next().?; - try testing.expectEqualStrings("content-disposition", header.name); - try testing.expectEqualStrings("attachment; filename=example.txt", header.value); - try testing.expect(!it.is_trailer); - } - { - const header = it.next().?; - try testing.expectEqualStrings("content-Length", header.name); - try testing.expectEqualStrings("10", header.value); - try testing.expect(!it.is_trailer); - } - { - const header = it.next().?; - try testing.expectEqualStrings("TRansfer-encoding", header.name); - try testing.expectEqualStrings("deflate, chunked", header.value); - try testing.expect(!it.is_trailer); - } - { - const header = it.next().?; - try testing.expectEqualStrings("connectioN", header.name); - try testing.expectEqualStrings("keep-alive", header.value); - try testing.expect(!it.is_trailer); - } - try testing.expectEqual(null, it.next()); - } -}; - -/// A HTTP request that has been sent. -/// -/// Order of operations: open -> send[ -> write -> finish] -> wait -> read -pub const Request = struct { - uri: Uri = undefined, - client: *Client, - /// This is null when the connection is released. - connection: ?*Connection = null, - keep_alive: bool = undefined, - - method: http.Method = undefined, - version: http.Version = .@"HTTP/1.1", - transfer_encoding: RequestTransfer = undefined, - redirect_behavior: RedirectBehavior = undefined, - - /// Whether the request should handle a 100-continue response before sending the request body. - handle_continue: bool = undefined, - - /// The response associated with this request. - /// - /// This field is undefined until `wait` is called. - response: Response = undefined, - - /// Standard headers that have default, but overridable, behavior. - headers: Headers = undefined, - - /// These headers are kept including when following a redirect to a - /// different domain. - /// Externally-owned; must outlive the Request. - extra_headers: []const http.Header = undefined, - - /// These headers are stripped when following a redirect to a different - /// domain. - /// Externally-owned; must outlive the Request. - privileged_headers: []const http.Header = undefined, - - pub fn init(client: *Client) Request { - return .{ - .client = client, - }; - } - - pub const Headers = struct { - host: Value = .default, - authorization: Value = .default, - user_agent: Value = .default, - connection: Value = .default, - accept_encoding: Value = .default, - content_type: Value = .default, - - pub const Value = union(enum) { - default, - omit, - override: []const u8, - }; - }; - - /// Any value other than `not_allowed` or `unhandled` means that integer represents - /// how many remaining redirects are allowed. - pub const RedirectBehavior = enum(u16) { - /// The next redirect will cause an error. - not_allowed = 0, - /// Redirects are passed to the client to analyze the redirect response - /// directly. - unhandled = std.math.maxInt(u16), - _, - - pub fn subtractOne(rb: *RedirectBehavior) void { - switch (rb.*) { - .not_allowed => unreachable, - .unhandled => unreachable, - _ => rb.* = @enumFromInt(@intFromEnum(rb.*) - 1), - } - } - - pub fn remaining(rb: RedirectBehavior) u16 { - assert(rb != .unhandled); - return @intFromEnum(rb); - } - }; - - /// Frees all resources associated with the request. - pub fn deinit(req: *Request) void { - if (req.connection) |connection| { - if (!req.response.parser.done) { - // If the response wasn't fully read, then we need to close the connection. - connection.closing = true; - } - req.client.connection_pool.release(req.client.allocator, connection); - } - req.* = undefined; - } - - fn onRedirectSend(ctx: *Ctx, res: anyerror!void) !void { - res catch |err| return ctx.pop(err); - // go back on check headers - ctx.conn().async_fill(ctx, onResponseHeaders) catch |err| return ctx.pop(err); - } - - fn onRedirectConnect(ctx: *Ctx, res: anyerror!void) !void { - res catch |err| return ctx.pop(err); - // re-send request - ctx.req.prepareSend(.{}) catch |err| return ctx.pop(err); - ctx.req.connection.?.async_flush(ctx, onRedirectSend) catch |err| return ctx.pop(err); - } - - // async_redirect flow: - // connect -> setRequestConnection - // -> onRedirectConnect -> async_flush - // -> onRedirectSend -> async_fill - // -> go back on the wait workflow of the response - fn async_redirect(req: *Request, uri: Uri, ctx: *Ctx) !void { - try req.prepareRedirect(); - - var server_header = std.heap.FixedBufferAllocator.init(req.response.parser.header_bytes_buffer); - defer req.response.parser.header_bytes_buffer = server_header.buffer[server_header.end_index..]; - - const protocol, const valid_uri = try validateUri(uri, server_header.allocator()); - - const new_host = valid_uri.host.?.raw; - const prev_host = req.uri.host.?.raw; - const keep_privileged_headers = - std.ascii.eqlIgnoreCase(valid_uri.scheme, req.uri.scheme) and - std.ascii.endsWithIgnoreCase(new_host, prev_host) and - (new_host.len == prev_host.len or new_host[new_host.len - prev_host.len - 1] == '.'); - if (!keep_privileged_headers) { - // When redirecting to a different domain, strip privileged headers. - req.privileged_headers = &.{}; - } - - // create a new connection for the redirected URI - ctx.data.conn = try req.client.allocator.create(Connection); - ctx.data.conn.* = .{ - .stream = undefined, - .tls_client = undefined, - .protocol = undefined, - .host = undefined, - .port = undefined, - }; - req.uri = valid_uri; - return req.client.async_connect(new_host, uriPort(valid_uri, protocol), protocol, ctx, setRequestConnection); - } - - // This function must deallocate all resources associated with the request, - // or keep those which will be used. - // This needs to be kept in sync with deinit and request. - fn redirect(req: *Request, uri: Uri) !void { - try req.prepareRedirect(); - - var server_header = std.heap.FixedBufferAllocator.init(req.response.parser.header_bytes_buffer); - defer req.response.parser.header_bytes_buffer = server_header.buffer[server_header.end_index..]; - - const protocol, const valid_uri = try validateUri(uri, server_header.allocator()); - - const new_host = valid_uri.host.?.raw; - const prev_host = req.uri.host.?.raw; - const keep_privileged_headers = - std.ascii.eqlIgnoreCase(valid_uri.scheme, req.uri.scheme) and - std.ascii.endsWithIgnoreCase(new_host, prev_host) and - (new_host.len == prev_host.len or new_host[new_host.len - prev_host.len - 1] == '.'); - if (!keep_privileged_headers) { - // When redirecting to a different domain, strip privileged headers. - req.privileged_headers = &.{}; - } - - req.connection = try req.client.connect(new_host, uriPort(valid_uri, protocol), protocol); - req.uri = valid_uri; - } - fn prepareRedirect(req: *Request) !void { - assert(req.response.parser.done); - - req.client.connection_pool.release(req.client.allocator, req.connection.?); - req.connection = null; - - if (switch (req.response.status) { - .see_other => true, - .moved_permanently, .found => req.method == .POST, - else => false, - }) { - // A redirect to a GET must change the method and remove the body. - req.method = .GET; - req.transfer_encoding = .none; - req.headers.content_type = .omit; - } - - if (req.transfer_encoding != .none) { - // The request body has already been sent. The request is - // still in a valid state, but the redirect must be handled - // manually. - return error.RedirectRequiresResend; - } - - req.redirect_behavior.subtractOne(); - req.response.parser.reset(); - - req.response = .{ - .version = undefined, - .status = undefined, - .reason = undefined, - .keep_alive = undefined, - .parser = req.response.parser, - }; - } - - pub const SendError = Connection.WriteError || error{ InvalidContentLength, UnsupportedTransferEncoding }; - - pub fn async_send(req: *Request, ctx: *Ctx, comptime cbk: Cbk) !void { - try req.prepareSend(); - try req.connection.?.async_flush(ctx, cbk); - } - - /// Send the HTTP request headers to the server. - pub fn send(req: *Request) SendError!void { - try req.prepareSend(); - try req.connection.?.flush(); - } - - fn prepareSend(req: *Request) SendError!void { - if (!req.method.requestHasBody() and req.transfer_encoding != .none) - if (!req.method.requestHasBody() and req.transfer_encoding != .none) - return error.UnsupportedTransferEncoding; - - const connection = req.connection.?; - const w = connection.writer(); - - try req.method.write(w); - try w.writeByte(' '); - - if (req.method == .CONNECT) { - try req.uri.writeToStream(.{ .authority = true }, w); - } else { - try req.uri.writeToStream(.{ - .scheme = connection.proxied, - .authentication = connection.proxied, - .authority = connection.proxied, - .path = true, - .query = true, - }, w); - } - try w.writeByte(' '); - try w.writeAll(@tagName(req.version)); - try w.writeAll("\r\n"); - - if (try emitOverridableHeader("host: ", req.headers.host, w)) { - try w.writeAll("host: "); - try req.uri.writeToStream(.{ .authority = true }, w); - try w.writeAll("\r\n"); - } - - if (try emitOverridableHeader("authorization: ", req.headers.authorization, w)) { - if (req.uri.user != null or req.uri.password != null) { - try w.writeAll("authorization: "); - const authorization = try connection.allocWriteBuffer( - @intCast(basic_authorization.valueLengthFromUri(req.uri)), - ); - assert(basic_authorization.value(req.uri, authorization).len == authorization.len); - try w.writeAll("\r\n"); - } - } - - if (try emitOverridableHeader("user-agent: ", req.headers.user_agent, w)) { - try w.writeAll("user-agent: zig/"); - try w.writeAll(builtin.zig_version_string); - try w.writeAll(" (std.http)\r\n"); - } - - if (try emitOverridableHeader("connection: ", req.headers.connection, w)) { - if (req.keep_alive) { - try w.writeAll("connection: keep-alive\r\n"); - } else { - try w.writeAll("connection: close\r\n"); - } - } - - if (try emitOverridableHeader("accept-encoding: ", req.headers.accept_encoding, w)) { - // https://github.com/ziglang/zig/issues/18937 - //try w.writeAll("accept-encoding: gzip, deflate, zstd\r\n"); - try w.writeAll("accept-encoding: gzip, deflate\r\n"); - } - - switch (req.transfer_encoding) { - .chunked => try w.writeAll("transfer-encoding: chunked\r\n"), - .content_length => |len| try w.print("content-length: {d}\r\n", .{len}), - .none => {}, - } - - if (try emitOverridableHeader("content-type: ", req.headers.content_type, w)) { - // The default is to omit content-type if not provided because - // "application/octet-stream" is redundant. - } - - for (req.extra_headers) |header| { - assert(header.name.len != 0); - - try w.writeAll(header.name); - try w.writeAll(": "); - try w.writeAll(header.value); - try w.writeAll("\r\n"); - } - - if (connection.proxied) proxy: { - const proxy = switch (connection.protocol) { - .plain => req.client.http_proxy, - .tls => req.client.https_proxy, - } orelse break :proxy; - - const authorization = proxy.authorization orelse break :proxy; - try w.writeAll("proxy-authorization: "); - try w.writeAll(authorization); - try w.writeAll("\r\n"); - } - - try w.writeAll("\r\n"); - } - - /// Returns true if the default behavior is required, otherwise handles - /// writing (or not writing) the header. - fn emitOverridableHeader(prefix: []const u8, v: Headers.Value, w: anytype) !bool { - switch (v) { - .default => return true, - .omit => return false, - .override => |x| { - try w.writeAll(prefix); - try w.writeAll(x); - try w.writeAll("\r\n"); - return false; - }, - } - } - - const TransferReadError = Connection.ReadError || proto.HeadersParser.ReadError; - - const TransferReader = std.io.Reader(*Request, TransferReadError, transferRead); - - fn transferReader(req: *Request) TransferReader { - return .{ .context = req }; - } - - fn transferRead(req: *Request, buf: []u8) TransferReadError!usize { - if (req.response.parser.done) return 0; - - var index: usize = 0; - while (index == 0) { - const amt = try req.response.parser.read(req.connection.?, buf[index..], req.response.skip); - if (amt == 0 and req.response.parser.done) break; - index += amt; - } - - return index; - } - - pub const WaitError = RequestError || SendError || TransferReadError || - proto.HeadersParser.CheckCompleteHeadError || Response.ParseError || - error{ // TODO: file zig fmt issue for this bad indentation - TooManyHttpRedirects, - RedirectRequiresResend, - HttpRedirectLocationMissing, - HttpRedirectLocationInvalid, - CompressionInitializationFailed, - CompressionUnsupported, - }; - - pub fn async_wait(_: *Request, ctx: *Ctx, comptime cbk: Cbk) !void { - try ctx.push(cbk); - return ctx.conn().async_fill(ctx, onResponseHeaders); - } - - /// Waits for a response from the server and parses any headers that are sent. - /// This function will block until the final response is received. - /// - /// If handling redirects and the request has no payload, then this - /// function will automatically follow redirects. If a request payload is - /// present, then this function will error with - /// error.RedirectRequiresResend. - /// - /// Must be called after `send` and, if any data was written to the request - /// body, then also after `finish`. - pub fn wait(req: *Request) WaitError!void { - while (true) { - // This while loop is for handling redirects, which means the request's - // connection may be different than the previous iteration. However, it - // is still guaranteed to be non-null with each iteration of this loop. - const connection = req.connection.?; - - while (true) { // read headers - try connection.fill(); - - const nchecked = try req.response.parser.checkCompleteHead(connection.peek()); - connection.drop(@intCast(nchecked)); - - if (req.response.parser.state.isContent()) break; - } - - try req.response.parse(req.response.parser.get()); - - if (req.response.status == .@"continue") { - // We're done parsing the continue response; reset to prepare - // for the real response. - req.response.parser.done = true; - req.response.parser.reset(); - - if (req.handle_continue) - continue; - - return; // we're not handling the 100-continue - } - - // we're switching protocols, so this connection is no longer doing http - if (req.method == .CONNECT and req.response.status.class() == .success) { - connection.closing = false; - req.response.parser.done = true; - return; // the connection is not HTTP past this point - } - - connection.closing = !req.response.keep_alive or !req.keep_alive; - - // Any response to a HEAD request and any response with a 1xx - // (Informational), 204 (No Content), or 304 (Not Modified) status - // code is always terminated by the first empty line after the - // header fields, regardless of the header fields present in the - // message. - if (req.method == .HEAD or req.response.status.class() == .informational or - req.response.status == .no_content or req.response.status == .not_modified) - { - req.response.parser.done = true; - return; // The response is empty; no further setup or redirection is necessary. - } - - switch (req.response.transfer_encoding) { - .none => { - if (req.response.content_length) |cl| { - req.response.parser.next_chunk_length = cl; - - if (cl == 0) req.response.parser.done = true; - } else { - // read until the connection is closed - req.response.parser.next_chunk_length = std.math.maxInt(u64); - } - }, - .chunked => { - req.response.parser.next_chunk_length = 0; - req.response.parser.state = .chunk_head_size; - }, - } - - if (req.response.status.class() == .redirect and req.redirect_behavior != .unhandled) { - // skip the body of the redirect response, this will at least - // leave the connection in a known good state. - req.response.skip = true; - assert(try req.transferRead(&.{}) == 0); // we're skipping, no buffer is necessary - - if (req.redirect_behavior == .not_allowed) return error.TooManyHttpRedirects; - - const location = req.response.location orelse - return error.HttpRedirectLocationMissing; - - // This mutates the beginning of header_bytes_buffer and uses that - // for the backing memory of the returned Uri. - try req.redirect(req.uri.resolve_inplace( - location, - &req.response.parser.header_bytes_buffer, - ) catch |err| switch (err) { - error.UnexpectedCharacter, - error.InvalidFormat, - error.InvalidPort, - => return error.HttpRedirectLocationInvalid, - error.NoSpaceLeft => return error.HttpHeadersOversize, - }); - try req.send(); - } else { - req.response.skip = false; - if (!req.response.parser.done) { - switch (req.response.transfer_compression) { - .identity => req.response.compression = .none, - .compress, .@"x-compress" => return error.CompressionUnsupported, - .deflate => req.response.compression = .{ - .deflate = std.compress.zlib.decompressor(req.transferReader()), - }, - .gzip, .@"x-gzip" => req.response.compression = .{ - .gzip = std.compress.gzip.decompressor(req.transferReader()), - }, - // https://github.com/ziglang/zig/issues/18937 - //.zstd => req.response.compression = .{ - // .zstd = std.compress.zstd.decompressStream(req.client.allocator, req.transferReader()), - //}, - .zstd => return error.CompressionUnsupported, - } - } - - break; - } - } - } - - pub const ReadError = TransferReadError || proto.HeadersParser.CheckCompleteHeadError || - error{ DecompressionFailure, InvalidTrailers }; - - pub const Reader = std.io.Reader(*Request, ReadError, read); - - pub fn reader(req: *Request) Reader { - return .{ .context = req }; - } - - /// Reads data from the response body. Must be called after `wait`. - pub fn read(req: *Request, buffer: []u8) ReadError!usize { - const out_index = switch (req.response.compression) { - .deflate => |*deflate| deflate.read(buffer) catch return error.DecompressionFailure, - .gzip => |*gzip| gzip.read(buffer) catch return error.DecompressionFailure, - // https://github.com/ziglang/zig/issues/18937 - //.zstd => |*zstd| zstd.read(buffer) catch return error.DecompressionFailure, - else => try req.transferRead(buffer), - }; - if (out_index > 0) return out_index; - - while (!req.response.parser.state.isContent()) { // read trailing headers - try req.connection.?.fill(); - - const nchecked = try req.response.parser.checkCompleteHead(req.connection.?.peek()); - req.connection.?.drop(@intCast(nchecked)); - } - - return 0; - } - - /// Reads data from the response body. Must be called after `wait`. - pub fn readAll(req: *Request, buffer: []u8) !usize { - var index: usize = 0; - while (index < buffer.len) { - const amt = try read(req, buffer[index..]); - if (amt == 0) break; - index += amt; - } - return index; - } - - pub const WriteError = Connection.WriteError || error{ NotWriteable, MessageTooLong }; - - pub const Writer = std.io.Writer(*Request, WriteError, write); - - pub fn writer(req: *Request) Writer { - return .{ .context = req }; - } - - /// Write `bytes` to the server. The `transfer_encoding` field determines how data will be sent. - /// Must be called after `send` and before `finish`. - pub fn write(req: *Request, bytes: []const u8) WriteError!usize { - switch (req.transfer_encoding) { - .chunked => { - if (bytes.len > 0) { - try req.connection.?.writer().print("{x}\r\n", .{bytes.len}); - try req.connection.?.writer().writeAll(bytes); - try req.connection.?.writer().writeAll("\r\n"); - } - - return bytes.len; - }, - .content_length => |*len| { - if (len.* < bytes.len) return error.MessageTooLong; - - const amt = try req.connection.?.write(bytes); - len.* -= amt; - return amt; - }, - .none => return error.NotWriteable, - } - } - - /// Write `bytes` to the server. The `transfer_encoding` field determines how data will be sent. - /// Must be called after `send` and before `finish`. - pub fn writeAll(req: *Request, bytes: []const u8) WriteError!void { - var index: usize = 0; - while (index < bytes.len) { - index += try write(req, bytes[index..]); - } - } - - fn onWriteAll(ctx: *Ctx, res: anyerror!void) !void { - res catch |err| return ctx.pop(err); - switch (ctx.req.transfer_encoding) { - .chunked => unreachable, - .none => unreachable, - .content_length => |*len| { - len.* = 0; - }, - } - try ctx.pop({}); - } - - pub fn async_writeAll(req: *Request, buf: []const u8, ctx: *Ctx, comptime cbk: Cbk) !void { - switch (req.transfer_encoding) { - .chunked => return error.ChunkedNotImplemented, - .none => return error.NotWriteable, - .content_length => |len| { - try ctx.push(cbk); - if (len < buf.len) return error.MessageTooLong; - - try req.connection.?.async_writeAllDirect(buf, ctx, onWriteAll); - }, - } - } - - pub const FinishError = WriteError || error{MessageNotCompleted}; - - pub fn async_finish(req: *Request, ctx: *Ctx, comptime cbk: Cbk) !void { - try req.common_finish(); - req.connection.?.async_flush(ctx, cbk) catch |err| switch (err) { - error.WriteEmpty => return cbk(ctx, {}), - else => return cbk(ctx, err), - }; - } - - /// Finish the body of a request. This notifies the server that you have no more data to send. - /// Must be called after `send`. - pub fn finish(req: *Request) FinishError!void { - try req.common_finish(); - try req.connection.?.flush(); - } - - fn common_finish(req: *Request) FinishError!void { - switch (req.transfer_encoding) { - .chunked => try req.connection.?.writer().writeAll("0\r\n\r\n"), - .content_length => |len| if (len != 0) return error.MessageNotCompleted, - .none => {}, - } - } - - fn onResponseHeaders(ctx: *Ctx, res: anyerror!void) !void { - res catch |err| return ctx.pop(err); - const done = ctx.req.parseResponseHeaders() catch |err| return ctx.pop(err); - // if read of the headers is not done, continue - if (!done) return ctx.conn().async_fill(ctx, onResponseHeaders); - // if read of the headers is done, go read the reponse - return onResponse(ctx, {}); - } - - fn parseResponseHeaders(req: *Request) !bool { - const nchecked = try req.response.parser.checkCompleteHead(req.connection.?.peek()); - req.connection.?.drop(@intCast(nchecked)); - - if (req.response.parser.state.isContent()) return true; - return false; - } - - fn onResponse(ctx: *Ctx, res: anyerror!void) !void { - res catch |err| return ctx.pop(err); - const ret = ctx.req.parseResponse() catch |err| return ctx.pop(err); - if (ret.redirect_uri) |uri| { - ctx.req.async_redirect(uri, ctx) catch |err| return ctx.pop(err); - return; - } - // if read of the response is not done, continue - if (!ret.done) return ctx.conn().async_fill(ctx, onResponse); - // if read of the response is done, go execute the provided callback - return ctx.pop({}); - } - - const WaitRedirectsReturn = struct { - redirect_uri: ?Uri = null, - done: bool = true, - }; - - fn parseResponse(req: *Request) WaitError!WaitRedirectsReturn { - try req.response.parse(req.response.parser.get()); - - if (req.response.status == .@"continue") { - // We're done parsing the continue response; reset to prepare - // for the real response. - req.response.parser.done = true; - req.response.parser.reset(); - - if (req.handle_continue) return .{ .done = false }; - - return .{ .done = true }; - } - - // we're switching protocols, so this connection is no longer doing http - if (req.method == .CONNECT and req.response.status.class() == .success) { - req.connection.?.closing = false; - req.response.parser.done = true; - return .{ .done = true }; // the connection is not HTTP past this point - } - - req.connection.?.closing = !req.response.keep_alive or !req.keep_alive; - - // Any response to a HEAD request and any response with a 1xx - // (Informational), 204 (No Content), or 304 (Not Modified) status - // code is always terminated by the first empty line after the - // header fields, regardless of the header fields present in the - // message. - if (req.method == .HEAD or req.response.status.class() == .informational or - req.response.status == .no_content or req.response.status == .not_modified) - { - req.response.parser.done = true; - return .{ .done = true }; // The response is empty; no further setup or redirection is necessary. - } - - switch (req.response.transfer_encoding) { - .none => { - if (req.response.content_length) |cl| { - req.response.parser.next_chunk_length = cl; - - if (cl == 0) req.response.parser.done = true; - } else { - // read until the connection is closed - req.response.parser.next_chunk_length = std.math.maxInt(u64); - } - }, - .chunked => { - req.response.parser.next_chunk_length = 0; - req.response.parser.state = .chunk_head_size; - }, - } - - if (req.response.status.class() == .redirect and req.redirect_behavior != .unhandled) { - // skip the body of the redirect response, this will at least - // leave the connection in a known good state. - req.response.skip = true; - assert(try req.transferRead(&.{}) == 0); // we're skipping, no buffer is necessary - - if (req.redirect_behavior == .not_allowed) return error.TooManyHttpRedirects; - - const location = req.response.location orelse - return error.HttpRedirectLocationMissing; - - // This mutates the beginning of header_bytes_buffer and uses that - // for the backing memory of the returned Uri. - try req.redirect(req.uri.resolve_inplace( - location, - &req.response.parser.header_bytes_buffer, - ) catch |err| switch (err) { - error.UnexpectedCharacter, - error.InvalidFormat, - error.InvalidPort, - => return error.HttpRedirectLocationInvalid, - error.NoSpaceLeft => return error.HttpHeadersOversize, - }); - - return .{ .redirect_uri = req.uri }; - } else { - req.response.skip = false; - if (!req.response.parser.done) { - switch (req.response.transfer_compression) { - .identity => req.response.compression = .none, - .compress, .@"x-compress" => return error.CompressionUnsupported, - .deflate => req.response.compression = .{ - .deflate = std.compress.zlib.decompressor(req.transferReader()), - }, - .gzip, .@"x-gzip" => req.response.compression = .{ - .gzip = std.compress.gzip.decompressor(req.transferReader()), - }, - // https://github.com/ziglang/zig/issues/18937 - //.zstd => req.response.compression = .{ - // .zstd = std.compress.zstd.decompressStream(req.client.allocator, req.transferReader()), - //}, - .zstd => return error.CompressionUnsupported, - } - } - - return .{ .done = true }; - } - return .{ .done = false }; - } -}; - -pub const Proxy = struct { - protocol: Connection.Protocol, - host: []const u8, - authorization: ?[]const u8, - port: u16, - supports_connect: bool, -}; - -/// Release all associated resources with the client. -/// -/// All pending requests must be de-initialized and all active connections released -/// before calling this function. -pub fn deinit(client: *Client) void { - assert(client.connection_pool.used.first == null); // There are still active requests. - - client.connection_pool.deinit(client.allocator); - - if (!disable_tls) - client.ca_bundle.deinit(client.allocator); - - client.* = undefined; -} - -/// Populates `http_proxy` and `https_proxy` via standard proxy environment variables. -/// Asserts the client has no active connections. -/// Uses `arena` for a few small allocations that must outlive the client, or -/// at least until those fields are set to different values. -pub fn initDefaultProxies(client: *Client, arena: Allocator) !void { - // Prevent any new connections from being created. - client.connection_pool.mutex.lock(); - defer client.connection_pool.mutex.unlock(); - - assert(client.connection_pool.used.first == null); // There are active requests. - - if (client.http_proxy == null) { - client.http_proxy = try createProxyFromEnvVar(arena, &.{ - "http_proxy", "HTTP_PROXY", "all_proxy", "ALL_PROXY", - }); - } - - if (client.https_proxy == null) { - client.https_proxy = try createProxyFromEnvVar(arena, &.{ - "https_proxy", "HTTPS_PROXY", "all_proxy", "ALL_PROXY", - }); - } -} - -fn createProxyFromEnvVar(arena: Allocator, env_var_names: []const []const u8) !?*Proxy { - const content = for (env_var_names) |name| { - break std.process.getEnvVarOwned(arena, name) catch |err| switch (err) { - error.EnvironmentVariableNotFound => continue, - else => |e| return e, - }; - } else return null; - - const uri = Uri.parse(content) catch try Uri.parseAfterScheme("http", content); - const protocol, const valid_uri = validateUri(uri, arena) catch |err| switch (err) { - error.UnsupportedUriScheme => return null, - error.UriMissingHost => return error.HttpProxyMissingHost, - error.OutOfMemory => |e| return e, - }; - - const authorization: ?[]const u8 = if (valid_uri.user != null or valid_uri.password != null) a: { - const authorization = try arena.alloc(u8, basic_authorization.valueLengthFromUri(valid_uri)); - assert(basic_authorization.value(valid_uri, authorization).len == authorization.len); - break :a authorization; - } else null; - - const proxy = try arena.create(Proxy); - proxy.* = .{ - .protocol = protocol, - .host = valid_uri.host.?.raw, - .authorization = authorization, - .port = uriPort(valid_uri, protocol), - .supports_connect = true, - }; - return proxy; -} - -pub const basic_authorization = struct { - pub const max_user_len = 255; - pub const max_password_len = 255; - pub const max_value_len = valueLength(max_user_len, max_password_len); - - const prefix = "Basic "; - - pub fn valueLength(user_len: usize, password_len: usize) usize { - return prefix.len + std.base64.standard.Encoder.calcSize(user_len + 1 + password_len); - } - - pub fn valueLengthFromUri(uri: Uri) usize { - var stream = std.io.countingWriter(std.io.null_writer); - try stream.writer().print("{user}", .{uri.user orelse Uri.Component.empty}); - const user_len = stream.bytes_written; - stream.bytes_written = 0; - try stream.writer().print("{password}", .{uri.password orelse Uri.Component.empty}); - const password_len = stream.bytes_written; - return valueLength(@intCast(user_len), @intCast(password_len)); - } - - pub fn value(uri: Uri, out: []u8) []u8 { - var buf: [max_user_len + ":".len + max_password_len]u8 = undefined; - var stream = std.io.fixedBufferStream(&buf); - stream.writer().print("{user}", .{uri.user orelse Uri.Component.empty}) catch - unreachable; - assert(stream.pos <= max_user_len); - stream.writer().print(":{password}", .{uri.password orelse Uri.Component.empty}) catch - unreachable; - - @memcpy(out[0..prefix.len], prefix); - const base64 = std.base64.standard.Encoder.encode(out[prefix.len..], stream.getWritten()); - return out[0 .. prefix.len + base64.len]; - } -}; - -pub const ConnectTcpError = Allocator.Error || error{ ConnectionRefused, NetworkUnreachable, ConnectionTimedOut, ConnectionResetByPeer, TemporaryNameServerFailure, NameServerFailure, UnknownHostName, HostLacksNetworkAddresses, UnexpectedConnectFailure, TlsInitializationFailed }; - -// requires ctx.data.stream to be set -fn setConnection(ctx: *Ctx, res: anyerror!void) !void { - - // check stream - errdefer ctx.data.conn.stream.close(); - res catch |e| { - // ctx.data.conn.stream.close(); is it needed with errdefer? - switch (e) { - error.ConnectionRefused, - error.NetworkUnreachable, - error.ConnectionTimedOut, - error.ConnectionResetByPeer, - error.TemporaryNameServerFailure, - error.NameServerFailure, - error.UnknownHostName, - error.HostLacksNetworkAddresses, - => return ctx.pop(e), - else => return ctx.pop(error.UnexpectedConnectFailure), - } - }; - - if (ctx.data.conn.protocol == .tls) { - if (disable_tls) unreachable; - - ctx.data.conn.tls_client = try ctx.alloc().create(tls23.Connection(net.Stream)); - errdefer ctx.alloc().destroy(ctx.data.conn.tls_client); - - // TODO tls23.client does an handshake to pick a cipher. - ctx.data.conn.tls_client.* = tls23.client(ctx.data.conn.stream, .{ - .host = ctx.data.conn.host, - .root_ca = .{ .bundle = ctx.req.client.ca_bundle }, - }) catch return error.TlsInitializationFailed; - } - - // add connection node in pool - const node = ctx.req.client.allocator.create(ConnectionPool.Node) catch |e| return ctx.pop(e); - errdefer ctx.req.client.allocator.destroy(node); - // NOTE we can not use the ctx.data.conn pointer as a node connection data, - // we need to copy it's value and use this reference for the connection - node.* = .{ - .data = .{ - .stream = ctx.data.conn.stream, - .tls_client = ctx.data.conn.tls_client, - .protocol = ctx.data.conn.protocol, - .host = ctx.data.conn.host, - .port = ctx.data.conn.port, - }, - }; - // remove old pointer, now useless - const old_conn = ctx.data.conn; - defer ctx.req.client.allocator.destroy(old_conn); - - ctx.req.client.connection_pool.addUsed(node); - ctx.data.conn = &node.data; - - return ctx.pop({}); -} - -/// Connect to `host:port` using the specified protocol. This will reuse a connection if one is already open. -/// -/// This function is threadsafe. -pub fn connectTcp(client: *Client, host: []const u8, port: u16, protocol: Connection.Protocol) ConnectTcpError!*Connection { - if (client.connection_pool.findConnection(.{ - .host = host, - .port = port, - .protocol = protocol, - })) |node| return node; - - if (disable_tls and protocol == .tls) - return error.TlsInitializationFailed; - - const conn = try client.allocator.create(ConnectionPool.Node); - errdefer client.allocator.destroy(conn); - conn.* = .{ .data = undefined }; - - const stream = net.tcpConnectToHost(client.allocator, host, port) catch |err| switch (err) { - error.ConnectionRefused => return error.ConnectionRefused, - error.NetworkUnreachable => return error.NetworkUnreachable, - error.ConnectionTimedOut => return error.ConnectionTimedOut, - error.ConnectionResetByPeer => return error.ConnectionResetByPeer, - error.TemporaryNameServerFailure => return error.TemporaryNameServerFailure, - error.NameServerFailure => return error.NameServerFailure, - error.UnknownHostName => return error.UnknownHostName, - error.HostLacksNetworkAddresses => return error.HostLacksNetworkAddresses, - else => return error.UnexpectedConnectFailure, - }; - errdefer stream.close(); - - conn.data = .{ - .stream = stream, - .tls_client = undefined, - - .protocol = protocol, - .host = try client.allocator.dupe(u8, host), - .port = port, - }; - errdefer client.allocator.free(conn.data.host); - - if (protocol == .tls) { - if (disable_tls) unreachable; - - conn.data.tls_client = try client.allocator.create(tls23.Connection(net.Stream)); - errdefer client.allocator.destroy(conn.data.tls_client); - - // TODO tls23.client does an handshake to pick a cipher. - conn.data.tls_client.* = tls23.client(stream, .{ - .host = host, - .root_ca = .{ .bundle = client.ca_bundle }, - }) catch return error.TlsInitializationFailed; - } - - client.connection_pool.addUsed(conn); - - return &conn.data; -} - -pub fn async_connectTcp( - client: *Client, - host: []const u8, - port: u16, - protocol: Connection.Protocol, - ctx: *Ctx, - comptime cbk: Cbk, -) !void { - try ctx.push(cbk); - if (ctx.req.client.connection_pool.findConnection(.{ - .host = host, - .port = port, - .protocol = protocol, - })) |conn| { - ctx.data.conn = conn; - ctx.req.connection = conn; - return ctx.pop({}); - } - - if (disable_tls and protocol == .tls) - return error.TlsInitializationFailed; - - return net.async_tcpConnectToHost( - client.allocator, - host, - port, - ctx, - setConnection, - ); -} - -pub const ConnectUnixError = Allocator.Error || std.posix.SocketError || error{NameTooLong} || std.posix.ConnectError; - -// Connect to `path` as a unix domain socket. This will reuse a connection if one is already open. -// -// This function is threadsafe. -// pub fn connectUnix(client: *Client, path: []const u8) ConnectUnixError!*Connection { -// if (client.connection_pool.findConnection(.{ -// .host = path, -// .port = 0, -// .protocol = .plain, -// })) |node| -// return node; - -// const conn = try client.allocator.create(ConnectionPool.Node); -// errdefer client.allocator.destroy(conn); -// conn.* = .{ .data = undefined }; - -// const stream = try std.net.connectUnixSocket(path); -// errdefer stream.close(); - -// conn.data = .{ -// .stream = stream, -// .tls_client = undefined, -// .protocol = .plain, - -// .host = try client.allocator.dupe(u8, path), -// .port = 0, -// }; -// errdefer client.allocator.free(conn.data.host); - -// client.connection_pool.addUsed(conn); - -// return &conn.data; -//} - -/// Connect to `tunnel_host:tunnel_port` using the specified proxy with HTTP -/// CONNECT. This will reuse a connection if one is already open. -/// -/// This function is threadsafe. -pub fn connectTunnel( - client: *Client, - proxy: *Proxy, - tunnel_host: []const u8, - tunnel_port: u16, -) !*Connection { - if (!proxy.supports_connect) return error.TunnelNotSupported; - - if (client.connection_pool.findConnection(.{ - .host = tunnel_host, - .port = tunnel_port, - .protocol = proxy.protocol, - })) |node| - return node; - - var maybe_valid = false; - (tunnel: { - const conn = try client.connectTcp(proxy.host, proxy.port, proxy.protocol); - errdefer { - conn.closing = true; - client.connection_pool.release(client.allocator, conn); - } - - var buffer: [8096]u8 = undefined; - var req = client.open(.CONNECT, .{ - .scheme = "http", - .host = .{ .raw = tunnel_host }, - .port = tunnel_port, - }, .{ - .redirect_behavior = .unhandled, - .connection = conn, - .server_header_buffer = &buffer, - }) catch |err| { - break :tunnel err; - }; - defer req.deinit(); - - req.send() catch |err| break :tunnel err; - req.wait() catch |err| break :tunnel err; - - if (req.response.status.class() == .server_error) { - maybe_valid = true; - break :tunnel error.ServerError; - } - - if (req.response.status != .ok) break :tunnel error.ConnectionRefused; - - // this connection is now a tunnel, so we can't use it for anything else, it will only be released when the client is de-initialized. - req.connection = null; - - client.allocator.free(conn.host); - conn.host = try client.allocator.dupe(u8, tunnel_host); - errdefer client.allocator.free(conn.host); - - conn.port = tunnel_port; - conn.closing = false; - - return conn; - }) catch { - // something went wrong with the tunnel - proxy.supports_connect = maybe_valid; - return error.TunnelNotSupported; - }; -} - -// Prevents a dependency loop in open() -const ConnectErrorPartial = ConnectTcpError || error{ UnsupportedUriScheme, ConnectionRefused }; -pub const ConnectError = ConnectErrorPartial || RequestError; - -fn onConnectProxy(ctx: *Ctx, res: anyerror!void) anyerror!void { - res catch |e| { - ctx.data.conn.closing = true; - ctx.req.client.connection_pool.release(ctx.req.client.allocator, ctx.data.conn); - return ctx.pop(e); - }; - ctx.data.conn.proxied = true; - return ctx.pop({}); -} - -/// Connect to `host:port` using the specified protocol. This will reuse a -/// connection if one is already open. -/// If a proxy is configured for the client, then the proxy will be used to -/// connect to the host. -/// -/// This function is threadsafe. -pub fn connect( - client: *Client, - host: []const u8, - port: u16, - protocol: Connection.Protocol, -) ConnectError!*Connection { - const proxy = switch (protocol) { - .plain => client.http_proxy, - .tls => client.https_proxy, - } orelse return client.connectTcp(host, port, protocol); - - // Prevent proxying through itself. - if (std.ascii.eqlIgnoreCase(proxy.host, host) and - proxy.port == port and proxy.protocol == protocol) - { - return client.connectTcp(host, port, protocol); - } - - if (proxy.supports_connect) tunnel: { - return connectTunnel(client, proxy, host, port) catch |err| switch (err) { - error.TunnelNotSupported => break :tunnel, - else => |e| return e, - }; - } - - // fall back to using the proxy as a normal http proxy - const conn = try client.connectTcp(proxy.host, proxy.port, proxy.protocol); - errdefer { - conn.closing = true; - client.connection_pool.release(conn); - } - - conn.proxied = true; - return conn; -} - -pub fn async_connect( - client: *Client, - host: []const u8, - port: u16, - protocol: Connection.Protocol, - ctx: *Ctx, - comptime cbk: Cbk, -) !void { - const proxy = switch (protocol) { - .plain => client.http_proxy, - .tls => client.https_proxy, - } orelse return client.async_connectTcp(host, port, protocol, ctx, cbk); - - // Prevent proxying through itself. - if (std.ascii.eqlIgnoreCase(proxy.host, host) and - proxy.port == port and proxy.protocol == protocol) - { - return client.async_connectTcp(host, port, protocol, ctx, cbk); - } - - // TODO: enable async_connectTunnel - // if (proxy.supports_connect) tunnel: { - // return connectTunnel(client, proxy, host, port) catch |err| switch (err) { - // error.TunnelNotSupported => break :tunnel, - // else => |e| return e, - // }; - // } - - // fall back to using the proxy as a normal http proxy - try ctx.push(cbk); - return client.async_connectTcp(proxy.host, proxy.port, proxy.protocol, ctx, onConnectProxy); -} - -pub const RequestError = ConnectTcpError || ConnectErrorPartial || Request.SendError || - std.fmt.ParseIntError || Connection.WriteError || - error{ // TODO: file a zig fmt issue for this bad indentation - UnsupportedUriScheme, - UriMissingHost, - - CertificateBundleLoadFailure, - UnsupportedTransferEncoding, -}; - -pub const RequestOptions = struct { - version: http.Version = .@"HTTP/1.1", - - /// Automatically ignore 100 Continue responses. This assumes you don't - /// care, and will have sent the body before you wait for the response. - /// - /// If this is not the case AND you know the server will send a 100 - /// Continue, set this to false and wait for a response before sending the - /// body. If you wait AND the server does not send a 100 Continue before - /// you finish the request, then the request *will* deadlock. - handle_continue: bool = true, - - /// If false, close the connection after the one request. If true, - /// participate in the client connection pool. - keep_alive: bool = true, - - /// This field specifies whether to automatically follow redirects, and if - /// so, how many redirects to follow before returning an error. - /// - /// This will only follow redirects for repeatable requests (ie. with no - /// payload or the server has acknowledged the payload). - redirect_behavior: Request.RedirectBehavior = @enumFromInt(3), - - /// Externally-owned memory used to store the server's entire HTTP header. - /// `error.HttpHeadersOversize` is returned from read() when a - /// client sends too many bytes of HTTP headers. - server_header_buffer: []u8, - - /// Must be an already acquired connection. - connection: ?*Connection = null, - - /// Standard headers that have default, but overridable, behavior. - headers: Request.Headers = .{}, - /// These headers are kept including when following a redirect to a - /// different domain. - /// Externally-owned; must outlive the Request. - extra_headers: []const http.Header = &.{}, - /// These headers are stripped when following a redirect to a different - /// domain. - /// Externally-owned; must outlive the Request. - privileged_headers: []const http.Header = &.{}, -}; - -const protocol_map = std.StaticStringMap(Connection.Protocol).initComptime(.{ - .{ "http", .plain }, - .{ "ws", .plain }, - .{ "https", .tls }, - .{ "wss", .tls }, -}); - -fn validateUri(uri: Uri, arena: Allocator) !struct { Connection.Protocol, Uri } { - const protocol = protocol_map.get(uri.scheme) orelse return error.UnsupportedUriScheme; - var valid_uri = uri; - // The host is always going to be needed as a raw string for hostname resolution anyway. - valid_uri.host = .{ - .raw = try (uri.host orelse return error.UriMissingHost).toRawMaybeAlloc(arena), - }; - return .{ protocol, valid_uri }; -} - -fn uriPort(uri: Uri, protocol: Connection.Protocol) u16 { - return uri.port orelse switch (protocol) { - .plain => 80, - .tls => 443, - }; -} - -pub fn create( - client: *Client, - method: http.Method, - uri: Uri, - options: RequestOptions, -) RequestError!Request { - if (std.debug.runtime_safety) { - for (options.extra_headers) |header| { - assert(header.name.len != 0); - assert(std.mem.indexOfScalar(u8, header.name, ':') == null); - assert(std.mem.indexOfPosLinear(u8, header.name, 0, "\r\n") == null); - assert(std.mem.indexOfPosLinear(u8, header.value, 0, "\r\n") == null); - } - for (options.privileged_headers) |header| { - assert(header.name.len != 0); - assert(std.mem.indexOfPosLinear(u8, header.name, 0, "\r\n") == null); - assert(std.mem.indexOfPosLinear(u8, header.value, 0, "\r\n") == null); - } - } - - var server_header = std.heap.FixedBufferAllocator.init(options.server_header_buffer); - _, const valid_uri = try validateUri(uri, server_header.allocator()); - - var req: Request = .{ - .uri = valid_uri, - .client = client, - .keep_alive = options.keep_alive, - .method = method, - .version = options.version, - .transfer_encoding = .none, - .redirect_behavior = options.redirect_behavior, - .handle_continue = options.handle_continue, - .response = .{ - .version = undefined, - .status = undefined, - .reason = undefined, - .keep_alive = undefined, - .parser = proto.HeadersParser.init(server_header.buffer[server_header.end_index..]), - }, - .headers = options.headers, - .extra_headers = options.extra_headers, - .privileged_headers = options.privileged_headers, - }; - errdefer req.deinit(); - - return req; -} - -/// Open a connection to the host specified by `uri` and prepare to send a HTTP request. -/// -/// `uri` must remain alive during the entire request. -/// -/// The caller is responsible for calling `deinit()` on the `Request`. -/// This function is threadsafe. -/// -/// Asserts that "\r\n" does not occur in any header name or value. -pub fn open( - client: *Client, - method: http.Method, - uri: Uri, - options: RequestOptions, -) RequestError!Request { - if (std.debug.runtime_safety) { - for (options.extra_headers) |header| { - assert(header.name.len != 0); - assert(std.mem.indexOfScalar(u8, header.name, ':') == null); - assert(std.mem.indexOfPosLinear(u8, header.name, 0, "\r\n") == null); - assert(std.mem.indexOfPosLinear(u8, header.value, 0, "\r\n") == null); - } - for (options.privileged_headers) |header| { - assert(header.name.len != 0); - assert(std.mem.indexOfPosLinear(u8, header.name, 0, "\r\n") == null); - assert(std.mem.indexOfPosLinear(u8, header.value, 0, "\r\n") == null); - } - } - - var server_header = std.heap.FixedBufferAllocator.init(options.server_header_buffer); - const protocol, const valid_uri = try validateUri(uri, server_header.allocator()); - - if (protocol == .tls and @atomicLoad(bool, &client.next_https_rescan_certs, .acquire)) { - if (disable_tls) unreachable; - - client.ca_bundle_mutex.lock(); - defer client.ca_bundle_mutex.unlock(); - - if (client.next_https_rescan_certs) { - client.ca_bundle.rescan(client.allocator) catch - return error.CertificateBundleLoadFailure; - @atomicStore(bool, &client.next_https_rescan_certs, false, .release); - } - } - - const conn = options.connection orelse - try client.connect(valid_uri.host.?.raw, uriPort(valid_uri, protocol), protocol); - - var req: Request = .{ - .uri = valid_uri, - .client = client, - .connection = conn, - .keep_alive = options.keep_alive, - .method = method, - .version = options.version, - .transfer_encoding = .none, - .redirect_behavior = options.redirect_behavior, - .handle_continue = options.handle_continue, - .response = .{ - .version = undefined, - .status = undefined, - .reason = undefined, - .keep_alive = undefined, - .parser = proto.HeadersParser.init(server_header.buffer[server_header.end_index..]), - }, - .headers = options.headers, - .extra_headers = options.extra_headers, - .privileged_headers = options.privileged_headers, - }; - errdefer req.deinit(); - - return req; -} - -pub fn async_open( - client: *Client, - method: http.Method, - uri: Uri, - options: RequestOptions, - ctx: *Ctx, - comptime cbk: Cbk, -) !void { - if (std.debug.runtime_safety) { - for (options.extra_headers) |header| { - assert(header.name.len != 0); - assert(std.mem.indexOfScalar(u8, header.name, ':') == null); - assert(std.mem.indexOfPosLinear(u8, header.name, 0, "\r\n") == null); - assert(std.mem.indexOfPosLinear(u8, header.value, 0, "\r\n") == null); - } - for (options.privileged_headers) |header| { - assert(header.name.len != 0); - assert(std.mem.indexOfPosLinear(u8, header.name, 0, "\r\n") == null); - assert(std.mem.indexOfPosLinear(u8, header.value, 0, "\r\n") == null); - } - } - - var server_header = std.heap.FixedBufferAllocator.init(options.server_header_buffer); - const protocol, const valid_uri = try validateUri(uri, server_header.allocator()); - - if (protocol == .tls and @atomicLoad(bool, &client.next_https_rescan_certs, .acquire)) { - if (disable_tls) unreachable; - - client.ca_bundle_mutex.lock(); - defer client.ca_bundle_mutex.unlock(); - - if (client.next_https_rescan_certs) { - client.ca_bundle.rescan(client.allocator) catch return error.CertificateBundleLoadFailure; - @atomicStore(bool, &client.next_https_rescan_certs, false, .release); - } - } - - // add fields to request - ctx.req.uri = valid_uri; - ctx.req.keep_alive = options.keep_alive; - ctx.req.method = method; - ctx.req.transfer_encoding = .none; - ctx.req.redirect_behavior = options.redirect_behavior; - ctx.req.handle_continue = options.handle_continue; - ctx.req.headers = options.headers; - ctx.req.extra_headers = options.extra_headers; - ctx.req.privileged_headers = options.privileged_headers; - ctx.req.response = .{ - .version = undefined, - .status = undefined, - .reason = undefined, - .keep_alive = undefined, - .parser = proto.HeadersParser.init(server_header.buffer[server_header.end_index..]), - }; - - // we already have the connection, - // set it and call directly the callback - if (options.connection) |conn| { - ctx.req.connection = conn; - return cbk(ctx, {}); - } - - // push callback function - try ctx.push(cbk); - - const host = valid_uri.host orelse return error.UriMissingHost; - const port = uriPort(valid_uri, protocol); - - // add fields to connection - ctx.data.conn.protocol = protocol; - ctx.data.conn.host = try client.allocator.dupe(u8, host.raw); - ctx.data.conn.port = port; - - return client.async_connect(host.raw, port, protocol, ctx, setRequestConnection); -} - -pub const FetchOptions = struct { - server_header_buffer: ?[]u8 = null, - redirect_behavior: ?Request.RedirectBehavior = null, - - /// If the server sends a body, it will be appended to this ArrayList. - /// `max_append_size` provides an upper limit for how much they can grow. - response_storage: ResponseStorage = .ignore, - max_append_size: ?usize = null, - - location: Location, - method: ?http.Method = null, - payload: ?[]const u8 = null, - raw_uri: bool = false, - keep_alive: bool = true, - - /// Standard headers that have default, but overridable, behavior. - headers: Request.Headers = .{}, - /// These headers are kept including when following a redirect to a - /// different domain. - /// Externally-owned; must outlive the Request. - extra_headers: []const http.Header = &.{}, - /// These headers are stripped when following a redirect to a different - /// domain. - /// Externally-owned; must outlive the Request. - privileged_headers: []const http.Header = &.{}, - - pub const Location = union(enum) { - url: []const u8, - uri: Uri, - }; - - pub const ResponseStorage = union(enum) { - ignore, - /// Only the existing capacity will be used. - static: *std.ArrayListUnmanaged(u8), - dynamic: *std.ArrayList(u8), - }; -}; - -pub const FetchResult = struct { - status: http.Status, -}; - -// TODO: enable async_fetch -/// Perform a one-shot HTTP request with the provided options. -/// -/// This function is threadsafe. -pub fn fetch(client: *Client, options: FetchOptions) !FetchResult { - const uri = switch (options.location) { - .url => |u| try Uri.parse(u), - .uri => |u| u, - }; - var server_header_buffer: [16 * 1024]u8 = undefined; - - const method: http.Method = options.method orelse - if (options.payload != null) .POST else .GET; - - var req = try open(client, method, uri, .{ - .server_header_buffer = options.server_header_buffer orelse &server_header_buffer, - .redirect_behavior = options.redirect_behavior orelse - if (options.payload == null) @enumFromInt(3) else .unhandled, - .headers = options.headers, - .extra_headers = options.extra_headers, - .privileged_headers = options.privileged_headers, - .keep_alive = options.keep_alive, - }); - defer req.deinit(); - - if (options.payload) |payload| req.transfer_encoding = .{ .content_length = payload.len }; - - try req.send(); - - if (options.payload) |payload| try req.writeAll(payload); - - try req.finish(); - try req.wait(); - - switch (options.response_storage) { - .ignore => { - // Take advantage of request internals to discard the response body - // and make the connection available for another request. - req.response.skip = true; - assert(try req.transferRead(&.{}) == 0); // No buffer is necessary when skipping. - }, - .dynamic => |list| { - const max_append_size = options.max_append_size orelse 2 * 1024 * 1024; - try req.reader().readAllArrayList(list, max_append_size); - }, - .static => |list| { - const buf = b: { - const buf = list.unusedCapacitySlice(); - if (options.max_append_size) |len| { - if (len < buf.len) break :b buf[0..len]; - } - break :b buf; - }; - list.items.len += try req.reader().readAll(buf); - }, - } - - return .{ - .status = req.response.status, - }; -} - -pub const Cbk = fn (ctx: *Ctx, res: anyerror!void) anyerror!void; - -pub const Ctx = struct { - const Stack = GenericStack(Cbk); - - // temporary Data we need to store on the heap - // because of the callback execution model - const Data = struct { - list: *std.net.AddressList = undefined, - addr_current: usize = undefined, - socket: std.posix.socket_t = undefined, - - // TODO: we could remove this field as it is already set in ctx.req - // but we do not know for now what will be the impact to set those directly - // on the request, especially in case of error/cancellation - conn: *Connection, - }; - - req: *Request = undefined, - - userData: *anyopaque = undefined, - - loop: *Loop, - data: Data, - stack: ?*Stack = null, - err: ?anyerror = null, - - _buffer: ?[]const u8 = null, - _len: ?usize = null, - - _iovecs: []std.posix.iovec = undefined, - - // TLS readvAtLeast - // _off_i: usize = 0, - // _vec_i: usize = 0, - // _tls_len: usize = 0, - - // TLS readv - _vp: VecPut = undefined, - // _tls_read_buf contains the next decrypted buffer - _tls_read_buf: ?[]u8 = undefined, - _tls_read_content_type: tls23.proto.ContentType = undefined, - - // _tls_read_record contains the crypted record - _tls_read_record: ?tls23.record.Record = null, - - // TLS writeAll - _tls_write_bytes: []const u8 = undefined, - _tls_write_index: usize = 0, - _tls_write_buf: [cipher.max_ciphertext_record_len]u8 = undefined, - - pub fn init(loop: *Loop, req: *Request) !Ctx { - const connection = try req.client.allocator.create(Connection); - connection.* = .{ - .stream = undefined, - .tls_client = undefined, - .protocol = undefined, - .host = undefined, - .port = undefined, - }; - return .{ - .req = req, - .loop = loop, - .data = .{ .conn = connection }, - }; - } - - pub fn setErr(self: *Ctx, err: anyerror) void { - self.err = err; - } - - pub fn push(self: *Ctx, comptime func: Stack.Fn) !void { - if (self.stack) |stack| { - return try stack.push(self.alloc(), func); - } - self.stack = try Stack.init(self.alloc(), func); - } - - pub fn pop(self: *Ctx, res: anyerror!void) !void { - if (self.stack) |stack| { - const allocator = self.alloc(); - const func = stack.pop(allocator, null); - - defer { - if (stack.next == null) { - allocator.destroy(stack); - self.stack = null; - } - } - - return @call(.auto, func, .{ self, res }); - } - unreachable; - } - - pub fn deinit(self: Ctx) void { - if (self.stack) |stack| { - stack.deinit(self.alloc(), null); - } - } - - // not sure about those - - pub fn len(self: Ctx) usize { - if (self._len == null) unreachable; - return self._len.?; - } - - pub fn setLen(self: *Ctx, nb: ?usize) void { - self._len = nb; - } - - pub fn buf(self: Ctx) []const u8 { - if (self._buffer == null) unreachable; - return self._buffer.?; - } - - pub fn setBuf(self: *Ctx, bytes: ?[]const u8) void { - self._buffer = bytes; - } - - // ctx Request aliases - - pub fn alloc(self: Ctx) std.mem.Allocator { - return self.req.client.allocator; - } - - pub fn conn(self: Ctx) *Connection { - return self.req.connection.?; - } - - pub fn stream(self: Ctx) net.Stream { - return self.conn().stream; - } -}; - -// requires ctx.data.conn to be set -fn setRequestConnection(ctx: *Ctx, res: anyerror!void) anyerror!void { - res catch |e| return ctx.pop(e); - - ctx.req.connection = ctx.data.conn; - return ctx.pop({}); -} diff --git a/src/http/async/std/http/Server.zig b/src/http/async/std/http/Server.zig deleted file mode 100644 index 38d3f133..00000000 --- a/src/http/async/std/http/Server.zig +++ /dev/null @@ -1,1148 +0,0 @@ -//! Blocking HTTP server implementation. -//! Handles a single connection's lifecycle. - -connection: net.Server.Connection, -/// Keeps track of whether the Server is ready to accept a new request on the -/// same connection, and makes invalid API usage cause assertion failures -/// rather than HTTP protocol violations. -state: State, -/// User-provided buffer that must outlive this Server. -/// Used to store the client's entire HTTP header. -read_buffer: []u8, -/// Amount of available data inside read_buffer. -read_buffer_len: usize, -/// Index into `read_buffer` of the first byte of the next HTTP request. -next_request_start: usize, - -pub const State = enum { - /// The connection is available to be used for the first time, or reused. - ready, - /// An error occurred in `receiveHead`. - receiving_head, - /// A Request object has been obtained and from there a Response can be - /// opened. - received_head, - /// The client is uploading something to this Server. - receiving_body, - /// The connection is eligible for another HTTP request, however the client - /// and server did not negotiate a persistent connection. - closing, -}; - -/// Initialize an HTTP server that can respond to multiple requests on the same -/// connection. -/// The returned `Server` is ready for `receiveHead` to be called. -pub fn init(connection: net.Server.Connection, read_buffer: []u8) Server { - return .{ - .connection = connection, - .state = .ready, - .read_buffer = read_buffer, - .read_buffer_len = 0, - .next_request_start = 0, - }; -} - -pub const ReceiveHeadError = error{ - /// Client sent too many bytes of HTTP headers. - /// The HTTP specification suggests to respond with a 431 status code - /// before closing the connection. - HttpHeadersOversize, - /// Client sent headers that did not conform to the HTTP protocol. - HttpHeadersInvalid, - /// A low level I/O error occurred trying to read the headers. - HttpHeadersUnreadable, - /// Partial HTTP request was received but the connection was closed before - /// fully receiving the headers. - HttpRequestTruncated, - /// The client sent 0 bytes of headers before closing the stream. - /// In other words, a keep-alive connection was finally closed. - HttpConnectionClosing, -}; - -/// The header bytes reference the read buffer that Server was initialized with -/// and remain alive until the next call to receiveHead. -pub fn receiveHead(s: *Server) ReceiveHeadError!Request { - assert(s.state == .ready); - s.state = .received_head; - errdefer s.state = .receiving_head; - - // In case of a reused connection, move the next request's bytes to the - // beginning of the buffer. - if (s.next_request_start > 0) { - if (s.read_buffer_len > s.next_request_start) { - rebase(s, 0); - } else { - s.read_buffer_len = 0; - } - } - - var hp: http.HeadParser = .{}; - - if (s.read_buffer_len > 0) { - const bytes = s.read_buffer[0..s.read_buffer_len]; - const end = hp.feed(bytes); - if (hp.state == .finished) - return finishReceivingHead(s, end); - } - - while (true) { - const buf = s.read_buffer[s.read_buffer_len..]; - if (buf.len == 0) - return error.HttpHeadersOversize; - const read_n = s.connection.stream.read(buf) catch - return error.HttpHeadersUnreadable; - if (read_n == 0) { - if (s.read_buffer_len > 0) { - return error.HttpRequestTruncated; - } else { - return error.HttpConnectionClosing; - } - } - s.read_buffer_len += read_n; - const bytes = buf[0..read_n]; - const end = hp.feed(bytes); - if (hp.state == .finished) - return finishReceivingHead(s, s.read_buffer_len - bytes.len + end); - } -} - -fn finishReceivingHead(s: *Server, head_end: usize) ReceiveHeadError!Request { - return .{ - .server = s, - .head_end = head_end, - .head = Request.Head.parse(s.read_buffer[0..head_end]) catch - return error.HttpHeadersInvalid, - .reader_state = undefined, - }; -} - -pub const Request = struct { - server: *Server, - /// Index into Server's read_buffer. - head_end: usize, - head: Head, - reader_state: union { - remaining_content_length: u64, - chunk_parser: http.ChunkParser, - }, - - pub const Compression = union(enum) { - pub const DeflateDecompressor = std.compress.zlib.Decompressor(std.io.AnyReader); - pub const GzipDecompressor = std.compress.gzip.Decompressor(std.io.AnyReader); - pub const ZstdDecompressor = std.compress.zstd.Decompressor(std.io.AnyReader); - - deflate: DeflateDecompressor, - gzip: GzipDecompressor, - zstd: ZstdDecompressor, - none: void, - }; - - pub const Head = struct { - method: http.Method, - target: []const u8, - version: http.Version, - expect: ?[]const u8, - content_type: ?[]const u8, - content_length: ?u64, - transfer_encoding: http.TransferEncoding, - transfer_compression: http.ContentEncoding, - keep_alive: bool, - compression: Compression, - - pub const ParseError = error{ - UnknownHttpMethod, - HttpHeadersInvalid, - HttpHeaderContinuationsUnsupported, - HttpTransferEncodingUnsupported, - HttpConnectionHeaderUnsupported, - InvalidContentLength, - CompressionUnsupported, - MissingFinalNewline, - }; - - pub fn parse(bytes: []const u8) ParseError!Head { - var it = mem.splitSequence(u8, bytes, "\r\n"); - - const first_line = it.next().?; - if (first_line.len < 10) - return error.HttpHeadersInvalid; - - const method_end = mem.indexOfScalar(u8, first_line, ' ') orelse - return error.HttpHeadersInvalid; - if (method_end > 24) return error.HttpHeadersInvalid; - - const method_str = first_line[0..method_end]; - const method: http.Method = @enumFromInt(http.Method.parse(method_str)); - - const version_start = mem.lastIndexOfScalar(u8, first_line, ' ') orelse - return error.HttpHeadersInvalid; - if (version_start == method_end) return error.HttpHeadersInvalid; - - const version_str = first_line[version_start + 1 ..]; - if (version_str.len != 8) return error.HttpHeadersInvalid; - const version: http.Version = switch (int64(version_str[0..8])) { - int64("HTTP/1.0") => .@"HTTP/1.0", - int64("HTTP/1.1") => .@"HTTP/1.1", - else => return error.HttpHeadersInvalid, - }; - - const target = first_line[method_end + 1 .. version_start]; - - var head: Head = .{ - .method = method, - .target = target, - .version = version, - .expect = null, - .content_type = null, - .content_length = null, - .transfer_encoding = .none, - .transfer_compression = .identity, - .keep_alive = switch (version) { - .@"HTTP/1.0" => false, - .@"HTTP/1.1" => true, - }, - .compression = .none, - }; - - while (it.next()) |line| { - if (line.len == 0) return head; - switch (line[0]) { - ' ', '\t' => return error.HttpHeaderContinuationsUnsupported, - else => {}, - } - - var line_it = mem.splitScalar(u8, line, ':'); - const header_name = line_it.next().?; - const header_value = mem.trim(u8, line_it.rest(), " \t"); - if (header_name.len == 0) return error.HttpHeadersInvalid; - - if (std.ascii.eqlIgnoreCase(header_name, "connection")) { - head.keep_alive = !std.ascii.eqlIgnoreCase(header_value, "close"); - } else if (std.ascii.eqlIgnoreCase(header_name, "expect")) { - head.expect = header_value; - } else if (std.ascii.eqlIgnoreCase(header_name, "content-type")) { - head.content_type = header_value; - } else if (std.ascii.eqlIgnoreCase(header_name, "content-length")) { - if (head.content_length != null) return error.HttpHeadersInvalid; - head.content_length = std.fmt.parseInt(u64, header_value, 10) catch - return error.InvalidContentLength; - } else if (std.ascii.eqlIgnoreCase(header_name, "content-encoding")) { - if (head.transfer_compression != .identity) return error.HttpHeadersInvalid; - - const trimmed = mem.trim(u8, header_value, " "); - - if (std.meta.stringToEnum(http.ContentEncoding, trimmed)) |ce| { - head.transfer_compression = ce; - } else { - return error.HttpTransferEncodingUnsupported; - } - } else if (std.ascii.eqlIgnoreCase(header_name, "transfer-encoding")) { - // Transfer-Encoding: second, first - // Transfer-Encoding: deflate, chunked - var iter = mem.splitBackwardsScalar(u8, header_value, ','); - - const first = iter.first(); - const trimmed_first = mem.trim(u8, first, " "); - - var next: ?[]const u8 = first; - if (std.meta.stringToEnum(http.TransferEncoding, trimmed_first)) |transfer| { - if (head.transfer_encoding != .none) - return error.HttpHeadersInvalid; // we already have a transfer encoding - head.transfer_encoding = transfer; - - next = iter.next(); - } - - if (next) |second| { - const trimmed_second = mem.trim(u8, second, " "); - - if (std.meta.stringToEnum(http.ContentEncoding, trimmed_second)) |transfer| { - if (head.transfer_compression != .identity) - return error.HttpHeadersInvalid; // double compression is not supported - head.transfer_compression = transfer; - } else { - return error.HttpTransferEncodingUnsupported; - } - } - - if (iter.next()) |_| return error.HttpTransferEncodingUnsupported; - } - } - return error.MissingFinalNewline; - } - - test parse { - const request_bytes = "GET /hi HTTP/1.0\r\n" ++ - "content-tYpe: text/plain\r\n" ++ - "content-Length:10\r\n" ++ - "expeCt: 100-continue \r\n" ++ - "TRansfer-encoding:\tdeflate, chunked \r\n" ++ - "connectioN:\t keep-alive \r\n\r\n"; - - const req = try parse(request_bytes); - - try testing.expectEqual(.GET, req.method); - try testing.expectEqual(.@"HTTP/1.0", req.version); - try testing.expectEqualStrings("/hi", req.target); - - try testing.expectEqualStrings("text/plain", req.content_type.?); - try testing.expectEqualStrings("100-continue", req.expect.?); - - try testing.expectEqual(true, req.keep_alive); - try testing.expectEqual(10, req.content_length.?); - try testing.expectEqual(.chunked, req.transfer_encoding); - try testing.expectEqual(.deflate, req.transfer_compression); - } - - inline fn int64(array: *const [8]u8) u64 { - return @bitCast(array.*); - } - }; - - pub fn iterateHeaders(r: *Request) http.HeaderIterator { - return http.HeaderIterator.init(r.server.read_buffer[0..r.head_end]); - } - - test iterateHeaders { - const request_bytes = "GET /hi HTTP/1.0\r\n" ++ - "content-tYpe: text/plain\r\n" ++ - "content-Length:10\r\n" ++ - "expeCt: 100-continue \r\n" ++ - "TRansfer-encoding:\tdeflate, chunked \r\n" ++ - "connectioN:\t keep-alive \r\n\r\n"; - - var read_buffer: [500]u8 = undefined; - @memcpy(read_buffer[0..request_bytes.len], request_bytes); - - var server: Server = .{ - .connection = undefined, - .state = .ready, - .read_buffer = &read_buffer, - .read_buffer_len = request_bytes.len, - .next_request_start = 0, - }; - - var request: Request = .{ - .server = &server, - .head_end = request_bytes.len, - .head = undefined, - .reader_state = undefined, - }; - - var it = request.iterateHeaders(); - { - const header = it.next().?; - try testing.expectEqualStrings("content-tYpe", header.name); - try testing.expectEqualStrings("text/plain", header.value); - try testing.expect(!it.is_trailer); - } - { - const header = it.next().?; - try testing.expectEqualStrings("content-Length", header.name); - try testing.expectEqualStrings("10", header.value); - try testing.expect(!it.is_trailer); - } - { - const header = it.next().?; - try testing.expectEqualStrings("expeCt", header.name); - try testing.expectEqualStrings("100-continue", header.value); - try testing.expect(!it.is_trailer); - } - { - const header = it.next().?; - try testing.expectEqualStrings("TRansfer-encoding", header.name); - try testing.expectEqualStrings("deflate, chunked", header.value); - try testing.expect(!it.is_trailer); - } - { - const header = it.next().?; - try testing.expectEqualStrings("connectioN", header.name); - try testing.expectEqualStrings("keep-alive", header.value); - try testing.expect(!it.is_trailer); - } - try testing.expectEqual(null, it.next()); - } - - pub const RespondOptions = struct { - version: http.Version = .@"HTTP/1.1", - status: http.Status = .ok, - reason: ?[]const u8 = null, - keep_alive: bool = true, - extra_headers: []const http.Header = &.{}, - transfer_encoding: ?http.TransferEncoding = null, - }; - - /// Send an entire HTTP response to the client, including headers and body. - /// - /// Automatically handles HEAD requests by omitting the body. - /// - /// Unless `transfer_encoding` is specified, uses the "content-length" - /// header. - /// - /// If the request contains a body and the connection is to be reused, - /// discards the request body, leaving the Server in the `ready` state. If - /// this discarding fails, the connection is marked as not to be reused and - /// no error is surfaced. - /// - /// Asserts status is not `continue`. - /// Asserts there are at most 25 extra_headers. - /// Asserts that "\r\n" does not occur in any header name or value. - pub fn respond( - request: *Request, - content: []const u8, - options: RespondOptions, - ) Response.WriteError!void { - const max_extra_headers = 25; - assert(options.status != .@"continue"); - assert(options.extra_headers.len <= max_extra_headers); - if (std.debug.runtime_safety) { - for (options.extra_headers) |header| { - assert(header.name.len != 0); - assert(std.mem.indexOfScalar(u8, header.name, ':') == null); - assert(std.mem.indexOfPosLinear(u8, header.name, 0, "\r\n") == null); - assert(std.mem.indexOfPosLinear(u8, header.value, 0, "\r\n") == null); - } - } - - const transfer_encoding_none = (options.transfer_encoding orelse .chunked) == .none; - const server_keep_alive = !transfer_encoding_none and options.keep_alive; - const keep_alive = request.discardBody(server_keep_alive); - - const phrase = options.reason orelse options.status.phrase() orelse ""; - - var first_buffer: [500]u8 = undefined; - var h = std.ArrayListUnmanaged(u8).initBuffer(&first_buffer); - if (request.head.expect != null) { - // reader() and hence discardBody() above sets expect to null if it - // is handled. So the fact that it is not null here means unhandled. - h.appendSliceAssumeCapacity("HTTP/1.1 417 Expectation Failed\r\n"); - if (!keep_alive) h.appendSliceAssumeCapacity("connection: close\r\n"); - h.appendSliceAssumeCapacity("content-length: 0\r\n\r\n"); - try request.server.connection.stream.writeAll(h.items); - return; - } - h.fixedWriter().print("{s} {d} {s}\r\n", .{ - @tagName(options.version), @intFromEnum(options.status), phrase, - }) catch unreachable; - - switch (options.version) { - .@"HTTP/1.0" => if (keep_alive) h.appendSliceAssumeCapacity("connection: keep-alive\r\n"), - .@"HTTP/1.1" => if (!keep_alive) h.appendSliceAssumeCapacity("connection: close\r\n"), - } - - if (options.transfer_encoding) |transfer_encoding| switch (transfer_encoding) { - .none => {}, - .chunked => h.appendSliceAssumeCapacity("transfer-encoding: chunked\r\n"), - } else { - h.fixedWriter().print("content-length: {d}\r\n", .{content.len}) catch unreachable; - } - - var chunk_header_buffer: [18]u8 = undefined; - var iovecs: [max_extra_headers * 4 + 3]std.posix.iovec_const = undefined; - var iovecs_len: usize = 0; - - iovecs[iovecs_len] = .{ - .base = h.items.ptr, - .len = h.items.len, - }; - iovecs_len += 1; - - for (options.extra_headers) |header| { - iovecs[iovecs_len] = .{ - .base = header.name.ptr, - .len = header.name.len, - }; - iovecs_len += 1; - - iovecs[iovecs_len] = .{ - .base = ": ", - .len = 2, - }; - iovecs_len += 1; - - if (header.value.len != 0) { - iovecs[iovecs_len] = .{ - .base = header.value.ptr, - .len = header.value.len, - }; - iovecs_len += 1; - } - - iovecs[iovecs_len] = .{ - .base = "\r\n", - .len = 2, - }; - iovecs_len += 1; - } - - iovecs[iovecs_len] = .{ - .base = "\r\n", - .len = 2, - }; - iovecs_len += 1; - - if (request.head.method != .HEAD) { - const is_chunked = (options.transfer_encoding orelse .none) == .chunked; - if (is_chunked) { - if (content.len > 0) { - const chunk_header = std.fmt.bufPrint( - &chunk_header_buffer, - "{x}\r\n", - .{content.len}, - ) catch unreachable; - - iovecs[iovecs_len] = .{ - .base = chunk_header.ptr, - .len = chunk_header.len, - }; - iovecs_len += 1; - - iovecs[iovecs_len] = .{ - .base = content.ptr, - .len = content.len, - }; - iovecs_len += 1; - - iovecs[iovecs_len] = .{ - .base = "\r\n", - .len = 2, - }; - iovecs_len += 1; - } - - iovecs[iovecs_len] = .{ - .base = "0\r\n\r\n", - .len = 5, - }; - iovecs_len += 1; - } else if (content.len > 0) { - iovecs[iovecs_len] = .{ - .base = content.ptr, - .len = content.len, - }; - iovecs_len += 1; - } - } - - try request.server.connection.stream.writevAll(iovecs[0..iovecs_len]); - } - - pub const RespondStreamingOptions = struct { - /// An externally managed slice of memory used to batch bytes before - /// sending. `respondStreaming` asserts this is large enough to store - /// the full HTTP response head. - /// - /// Must outlive the returned Response. - send_buffer: []u8, - /// If provided, the response will use the content-length header; - /// otherwise it will use transfer-encoding: chunked. - content_length: ?u64 = null, - /// Options that are shared with the `respond` method. - respond_options: RespondOptions = .{}, - }; - - /// The header is buffered but not sent until Response.flush is called. - /// - /// If the request contains a body and the connection is to be reused, - /// discards the request body, leaving the Server in the `ready` state. If - /// this discarding fails, the connection is marked as not to be reused and - /// no error is surfaced. - /// - /// HEAD requests are handled transparently by setting a flag on the - /// returned Response to omit the body. However it may be worth noticing - /// that flag and skipping any expensive work that would otherwise need to - /// be done to satisfy the request. - /// - /// Asserts `send_buffer` is large enough to store the entire response header. - /// Asserts status is not `continue`. - pub fn respondStreaming(request: *Request, options: RespondStreamingOptions) Response { - const o = options.respond_options; - assert(o.status != .@"continue"); - const transfer_encoding_none = (o.transfer_encoding orelse .chunked) == .none; - const server_keep_alive = !transfer_encoding_none and o.keep_alive; - const keep_alive = request.discardBody(server_keep_alive); - const phrase = o.reason orelse o.status.phrase() orelse ""; - - var h = std.ArrayListUnmanaged(u8).initBuffer(options.send_buffer); - - const elide_body = if (request.head.expect != null) eb: { - // reader() and hence discardBody() above sets expect to null if it - // is handled. So the fact that it is not null here means unhandled. - h.appendSliceAssumeCapacity("HTTP/1.1 417 Expectation Failed\r\n"); - if (!keep_alive) h.appendSliceAssumeCapacity("connection: close\r\n"); - h.appendSliceAssumeCapacity("content-length: 0\r\n\r\n"); - break :eb true; - } else eb: { - h.fixedWriter().print("{s} {d} {s}\r\n", .{ - @tagName(o.version), @intFromEnum(o.status), phrase, - }) catch unreachable; - - switch (o.version) { - .@"HTTP/1.0" => if (keep_alive) h.appendSliceAssumeCapacity("connection: keep-alive\r\n"), - .@"HTTP/1.1" => if (!keep_alive) h.appendSliceAssumeCapacity("connection: close\r\n"), - } - - if (o.transfer_encoding) |transfer_encoding| switch (transfer_encoding) { - .chunked => h.appendSliceAssumeCapacity("transfer-encoding: chunked\r\n"), - .none => {}, - } else if (options.content_length) |len| { - h.fixedWriter().print("content-length: {d}\r\n", .{len}) catch unreachable; - } else { - h.appendSliceAssumeCapacity("transfer-encoding: chunked\r\n"); - } - - for (o.extra_headers) |header| { - assert(header.name.len != 0); - h.appendSliceAssumeCapacity(header.name); - h.appendSliceAssumeCapacity(": "); - h.appendSliceAssumeCapacity(header.value); - h.appendSliceAssumeCapacity("\r\n"); - } - - h.appendSliceAssumeCapacity("\r\n"); - break :eb request.head.method == .HEAD; - }; - - return .{ - .stream = request.server.connection.stream, - .send_buffer = options.send_buffer, - .send_buffer_start = 0, - .send_buffer_end = h.items.len, - .transfer_encoding = if (o.transfer_encoding) |te| switch (te) { - .chunked => .chunked, - .none => .none, - } else if (options.content_length) |len| .{ - .content_length = len, - } else .chunked, - .elide_body = elide_body, - .chunk_len = 0, - }; - } - - pub const ReadError = net.Stream.ReadError || error{ - HttpChunkInvalid, - HttpHeadersOversize, - }; - - fn read_cl(context: *const anyopaque, buffer: []u8) ReadError!usize { - const request: *Request = @constCast(@alignCast(@ptrCast(context))); - const s = request.server; - - const remaining_content_length = &request.reader_state.remaining_content_length; - if (remaining_content_length.* == 0) { - s.state = .ready; - return 0; - } - assert(s.state == .receiving_body); - const available = try fill(s, request.head_end); - const len = @min(remaining_content_length.*, available.len, buffer.len); - @memcpy(buffer[0..len], available[0..len]); - remaining_content_length.* -= len; - s.next_request_start += len; - if (remaining_content_length.* == 0) - s.state = .ready; - return len; - } - - fn fill(s: *Server, head_end: usize) ReadError![]u8 { - const available = s.read_buffer[s.next_request_start..s.read_buffer_len]; - if (available.len > 0) return available; - s.next_request_start = head_end; - s.read_buffer_len = head_end + try s.connection.stream.read(s.read_buffer[head_end..]); - return s.read_buffer[head_end..s.read_buffer_len]; - } - - fn read_chunked(context: *const anyopaque, buffer: []u8) ReadError!usize { - const request: *Request = @constCast(@alignCast(@ptrCast(context))); - const s = request.server; - - const cp = &request.reader_state.chunk_parser; - const head_end = request.head_end; - - // Protect against returning 0 before the end of stream. - var out_end: usize = 0; - while (out_end == 0) { - switch (cp.state) { - .invalid => return 0, - .data => { - assert(s.state == .receiving_body); - const available = try fill(s, head_end); - const len = @min(cp.chunk_len, available.len, buffer.len); - @memcpy(buffer[0..len], available[0..len]); - cp.chunk_len -= len; - if (cp.chunk_len == 0) - cp.state = .data_suffix; - out_end += len; - s.next_request_start += len; - continue; - }, - else => { - assert(s.state == .receiving_body); - const available = try fill(s, head_end); - const n = cp.feed(available); - switch (cp.state) { - .invalid => return error.HttpChunkInvalid, - .data => { - if (cp.chunk_len == 0) { - // The next bytes in the stream are trailers, - // or \r\n to indicate end of chunked body. - // - // This function must append the trailers at - // head_end so that headers and trailers are - // together. - // - // Since returning 0 would indicate end of - // stream, this function must read all the - // trailers before returning. - if (s.next_request_start > head_end) rebase(s, head_end); - var hp: http.HeadParser = .{}; - { - const bytes = s.read_buffer[head_end..s.read_buffer_len]; - const end = hp.feed(bytes); - if (hp.state == .finished) { - cp.state = .invalid; - s.state = .ready; - s.next_request_start = s.read_buffer_len - bytes.len + end; - return out_end; - } - } - while (true) { - const buf = s.read_buffer[s.read_buffer_len..]; - if (buf.len == 0) - return error.HttpHeadersOversize; - const read_n = try s.connection.stream.read(buf); - s.read_buffer_len += read_n; - const bytes = buf[0..read_n]; - const end = hp.feed(bytes); - if (hp.state == .finished) { - cp.state = .invalid; - s.state = .ready; - s.next_request_start = s.read_buffer_len - bytes.len + end; - return out_end; - } - } - } - const data = available[n..]; - const len = @min(cp.chunk_len, data.len, buffer.len); - @memcpy(buffer[0..len], data[0..len]); - cp.chunk_len -= len; - if (cp.chunk_len == 0) - cp.state = .data_suffix; - out_end += len; - s.next_request_start += n + len; - continue; - }, - else => continue, - } - }, - } - } - return out_end; - } - - pub const ReaderError = Response.WriteError || error{ - /// The client sent an expect HTTP header value other than - /// "100-continue". - HttpExpectationFailed, - }; - - /// In the case that the request contains "expect: 100-continue", this - /// function writes the continuation header, which means it can fail with a - /// write error. After sending the continuation header, it sets the - /// request's expect field to `null`. - /// - /// Asserts that this function is only called once. - pub fn reader(request: *Request) ReaderError!std.io.AnyReader { - const s = request.server; - assert(s.state == .received_head); - s.state = .receiving_body; - s.next_request_start = request.head_end; - - if (request.head.expect) |expect| { - if (mem.eql(u8, expect, "100-continue")) { - try request.server.connection.stream.writeAll("HTTP/1.1 100 Continue\r\n\r\n"); - request.head.expect = null; - } else { - return error.HttpExpectationFailed; - } - } - - switch (request.head.transfer_encoding) { - .chunked => { - request.reader_state = .{ .chunk_parser = http.ChunkParser.init }; - return .{ - .readFn = read_chunked, - .context = request, - }; - }, - .none => { - request.reader_state = .{ - .remaining_content_length = request.head.content_length orelse 0, - }; - return .{ - .readFn = read_cl, - .context = request, - }; - }, - } - } - - /// Returns whether the connection should remain persistent. - /// If it would fail, it instead sets the Server state to `receiving_body` - /// and returns false. - fn discardBody(request: *Request, keep_alive: bool) bool { - // Prepare to receive another request on the same connection. - // There are two factors to consider: - // * Any body the client sent must be discarded. - // * The Server's read_buffer may already have some bytes in it from - // whatever came after the head, which may be the next HTTP request - // or the request body. - // If the connection won't be kept alive, then none of this matters - // because the connection will be severed after the response is sent. - const s = request.server; - if (keep_alive and request.head.keep_alive) switch (s.state) { - .received_head => { - const r = request.reader() catch return false; - _ = r.discard() catch return false; - assert(s.state == .ready); - return true; - }, - .receiving_body, .ready => return true, - else => unreachable, - }; - - // Avoid clobbering the state in case a reading stream already exists. - switch (s.state) { - .received_head => s.state = .closing, - else => {}, - } - return false; - } -}; - -pub const Response = struct { - stream: net.Stream, - send_buffer: []u8, - /// Index of the first byte in `send_buffer`. - /// This is 0 unless a short write happens in `write`. - send_buffer_start: usize, - /// Index of the last byte + 1 in `send_buffer`. - send_buffer_end: usize, - /// `null` means transfer-encoding: chunked. - /// As a debugging utility, counts down to zero as bytes are written. - transfer_encoding: TransferEncoding, - elide_body: bool, - /// Indicates how much of the end of the `send_buffer` corresponds to a - /// chunk. This amount of data will be wrapped by an HTTP chunk header. - chunk_len: usize, - - pub const TransferEncoding = union(enum) { - /// End of connection signals the end of the stream. - none, - /// As a debugging utility, counts down to zero as bytes are written. - content_length: u64, - /// Each chunk is wrapped in a header and trailer. - chunked, - }; - - pub const WriteError = net.Stream.WriteError; - - /// When using content-length, asserts that the amount of data sent matches - /// the value sent in the header, then calls `flush`. - /// Otherwise, transfer-encoding: chunked is being used, and it writes the - /// end-of-stream message, then flushes the stream to the system. - /// Respects the value of `elide_body` to omit all data after the headers. - pub fn end(r: *Response) WriteError!void { - switch (r.transfer_encoding) { - .content_length => |len| { - assert(len == 0); // Trips when end() called before all bytes written. - try flush_cl(r); - }, - .none => { - try flush_cl(r); - }, - .chunked => { - try flush_chunked(r, &.{}); - }, - } - r.* = undefined; - } - - pub const EndChunkedOptions = struct { - trailers: []const http.Header = &.{}, - }; - - /// Asserts that the Response is using transfer-encoding: chunked. - /// Writes the end-of-stream message and any optional trailers, then - /// flushes the stream to the system. - /// Respects the value of `elide_body` to omit all data after the headers. - /// Asserts there are at most 25 trailers. - pub fn endChunked(r: *Response, options: EndChunkedOptions) WriteError!void { - assert(r.transfer_encoding == .chunked); - try flush_chunked(r, options.trailers); - r.* = undefined; - } - - /// If using content-length, asserts that writing these bytes to the client - /// would not exceed the content-length value sent in the HTTP header. - /// May return 0, which does not indicate end of stream. The caller decides - /// when the end of stream occurs by calling `end`. - pub fn write(r: *Response, bytes: []const u8) WriteError!usize { - switch (r.transfer_encoding) { - .content_length, .none => return write_cl(r, bytes), - .chunked => return write_chunked(r, bytes), - } - } - - fn write_cl(context: *const anyopaque, bytes: []const u8) WriteError!usize { - const r: *Response = @constCast(@alignCast(@ptrCast(context))); - - var trash: u64 = std.math.maxInt(u64); - const len = switch (r.transfer_encoding) { - .content_length => |*len| len, - else => &trash, - }; - - if (r.elide_body) { - len.* -= bytes.len; - return bytes.len; - } - - if (bytes.len + r.send_buffer_end > r.send_buffer.len) { - const send_buffer_len = r.send_buffer_end - r.send_buffer_start; - var iovecs: [2]std.posix.iovec_const = .{ - .{ - .base = r.send_buffer.ptr + r.send_buffer_start, - .len = send_buffer_len, - }, - .{ - .base = bytes.ptr, - .len = bytes.len, - }, - }; - const n = try r.stream.writev(&iovecs); - - if (n >= send_buffer_len) { - // It was enough to reset the buffer. - r.send_buffer_start = 0; - r.send_buffer_end = 0; - const bytes_n = n - send_buffer_len; - len.* -= bytes_n; - return bytes_n; - } - - // It didn't even make it through the existing buffer, let - // alone the new bytes provided. - r.send_buffer_start += n; - return 0; - } - - // All bytes can be stored in the remaining space of the buffer. - @memcpy(r.send_buffer[r.send_buffer_end..][0..bytes.len], bytes); - r.send_buffer_end += bytes.len; - len.* -= bytes.len; - return bytes.len; - } - - fn write_chunked(context: *const anyopaque, bytes: []const u8) WriteError!usize { - const r: *Response = @constCast(@alignCast(@ptrCast(context))); - assert(r.transfer_encoding == .chunked); - - if (r.elide_body) - return bytes.len; - - if (bytes.len + r.send_buffer_end > r.send_buffer.len) { - const send_buffer_len = r.send_buffer_end - r.send_buffer_start; - const chunk_len = r.chunk_len + bytes.len; - var header_buf: [18]u8 = undefined; - const chunk_header = std.fmt.bufPrint(&header_buf, "{x}\r\n", .{chunk_len}) catch unreachable; - - var iovecs: [5]std.posix.iovec_const = .{ - .{ - .base = r.send_buffer.ptr + r.send_buffer_start, - .len = send_buffer_len - r.chunk_len, - }, - .{ - .base = chunk_header.ptr, - .len = chunk_header.len, - }, - .{ - .base = r.send_buffer.ptr + r.send_buffer_end - r.chunk_len, - .len = r.chunk_len, - }, - .{ - .base = bytes.ptr, - .len = bytes.len, - }, - .{ - .base = "\r\n", - .len = 2, - }, - }; - // TODO make this writev instead of writevAll, which involves - // complicating the logic of this function. - try r.stream.writevAll(&iovecs); - r.send_buffer_start = 0; - r.send_buffer_end = 0; - r.chunk_len = 0; - return bytes.len; - } - - // All bytes can be stored in the remaining space of the buffer. - @memcpy(r.send_buffer[r.send_buffer_end..][0..bytes.len], bytes); - r.send_buffer_end += bytes.len; - r.chunk_len += bytes.len; - return bytes.len; - } - - /// If using content-length, asserts that writing these bytes to the client - /// would not exceed the content-length value sent in the HTTP header. - pub fn writeAll(r: *Response, bytes: []const u8) WriteError!void { - var index: usize = 0; - while (index < bytes.len) { - index += try write(r, bytes[index..]); - } - } - - /// Sends all buffered data to the client. - /// This is redundant after calling `end`. - /// Respects the value of `elide_body` to omit all data after the headers. - pub fn flush(r: *Response) WriteError!void { - switch (r.transfer_encoding) { - .none, .content_length => return flush_cl(r), - .chunked => return flush_chunked(r, null), - } - } - - fn flush_cl(r: *Response) WriteError!void { - try r.stream.writeAll(r.send_buffer[r.send_buffer_start..r.send_buffer_end]); - r.send_buffer_start = 0; - r.send_buffer_end = 0; - } - - fn flush_chunked(r: *Response, end_trailers: ?[]const http.Header) WriteError!void { - const max_trailers = 25; - if (end_trailers) |trailers| assert(trailers.len <= max_trailers); - assert(r.transfer_encoding == .chunked); - - const http_headers = r.send_buffer[r.send_buffer_start .. r.send_buffer_end - r.chunk_len]; - - if (r.elide_body) { - try r.stream.writeAll(http_headers); - r.send_buffer_start = 0; - r.send_buffer_end = 0; - r.chunk_len = 0; - return; - } - - var header_buf: [18]u8 = undefined; - const chunk_header = std.fmt.bufPrint(&header_buf, "{x}\r\n", .{r.chunk_len}) catch unreachable; - - var iovecs: [max_trailers * 4 + 5]std.posix.iovec_const = undefined; - var iovecs_len: usize = 0; - - iovecs[iovecs_len] = .{ - .base = http_headers.ptr, - .len = http_headers.len, - }; - iovecs_len += 1; - - if (r.chunk_len > 0) { - iovecs[iovecs_len] = .{ - .base = chunk_header.ptr, - .len = chunk_header.len, - }; - iovecs_len += 1; - - iovecs[iovecs_len] = .{ - .base = r.send_buffer.ptr + r.send_buffer_end - r.chunk_len, - .len = r.chunk_len, - }; - iovecs_len += 1; - - iovecs[iovecs_len] = .{ - .base = "\r\n", - .len = 2, - }; - iovecs_len += 1; - } - - if (end_trailers) |trailers| { - iovecs[iovecs_len] = .{ - .base = "0\r\n", - .len = 3, - }; - iovecs_len += 1; - - for (trailers) |trailer| { - iovecs[iovecs_len] = .{ - .base = trailer.name.ptr, - .len = trailer.name.len, - }; - iovecs_len += 1; - - iovecs[iovecs_len] = .{ - .base = ": ", - .len = 2, - }; - iovecs_len += 1; - - if (trailer.value.len != 0) { - iovecs[iovecs_len] = .{ - .base = trailer.value.ptr, - .len = trailer.value.len, - }; - iovecs_len += 1; - } - - iovecs[iovecs_len] = .{ - .base = "\r\n", - .len = 2, - }; - iovecs_len += 1; - } - - iovecs[iovecs_len] = .{ - .base = "\r\n", - .len = 2, - }; - iovecs_len += 1; - } - - try r.stream.writevAll(iovecs[0..iovecs_len]); - r.send_buffer_start = 0; - r.send_buffer_end = 0; - r.chunk_len = 0; - } - - pub fn writer(r: *Response) std.io.AnyWriter { - return .{ - .writeFn = switch (r.transfer_encoding) { - .none, .content_length => write_cl, - .chunked => write_chunked, - }, - .context = r, - }; - } -}; - -fn rebase(s: *Server, index: usize) void { - const leftover = s.read_buffer[s.next_request_start..s.read_buffer_len]; - const dest = s.read_buffer[index..][0..leftover.len]; - if (leftover.len <= s.next_request_start - index) { - @memcpy(dest, leftover); - } else { - mem.copyBackwards(u8, dest, leftover); - } - s.read_buffer_len = index + leftover.len; -} - -const std = @import("std"); -const http = std.http; -const mem = std.mem; -const net = std.net; -const Uri = std.Uri; -const assert = std.debug.assert; -const testing = std.testing; - -const Server = @This(); diff --git a/src/http/async/std/http/protocol.zig b/src/http/async/std/http/protocol.zig deleted file mode 100644 index 389e1e4f..00000000 --- a/src/http/async/std/http/protocol.zig +++ /dev/null @@ -1,447 +0,0 @@ -const std = @import("std"); -const builtin = @import("builtin"); -const testing = std.testing; -const mem = std.mem; - -const assert = std.debug.assert; -const use_vectors = builtin.zig_backend != .stage2_x86_64; - -pub const State = enum { - invalid, - - // Begin header and trailer parsing states. - - start, - seen_n, - seen_r, - seen_rn, - seen_rnr, - finished, - - // Begin transfer-encoding: chunked parsing states. - - chunk_head_size, - chunk_head_ext, - chunk_head_r, - chunk_data, - chunk_data_suffix, - chunk_data_suffix_r, - - /// Returns true if the parser is in a content state (ie. not waiting for more headers). - pub fn isContent(self: State) bool { - return switch (self) { - .invalid, .start, .seen_n, .seen_r, .seen_rn, .seen_rnr => false, - .finished, .chunk_head_size, .chunk_head_ext, .chunk_head_r, .chunk_data, .chunk_data_suffix, .chunk_data_suffix_r => true, - }; - } -}; - -pub const HeadersParser = struct { - state: State = .start, - /// A fixed buffer of len `max_header_bytes`. - /// Pointers into this buffer are not stable until after a message is complete. - header_bytes_buffer: []u8, - header_bytes_len: u32, - next_chunk_length: u64, - /// `false`: headers. `true`: trailers. - done: bool, - - /// Initializes the parser with a provided buffer `buf`. - pub fn init(buf: []u8) HeadersParser { - return .{ - .header_bytes_buffer = buf, - .header_bytes_len = 0, - .done = false, - .next_chunk_length = 0, - }; - } - - /// Reinitialize the parser. - /// Asserts the parser is in the "done" state. - pub fn reset(hp: *HeadersParser) void { - assert(hp.done); - hp.* = .{ - .state = .start, - .header_bytes_buffer = hp.header_bytes_buffer, - .header_bytes_len = 0, - .done = false, - .next_chunk_length = 0, - }; - } - - pub fn get(hp: HeadersParser) []u8 { - return hp.header_bytes_buffer[0..hp.header_bytes_len]; - } - - pub fn findHeadersEnd(r: *HeadersParser, bytes: []const u8) u32 { - var hp: std.http.HeadParser = .{ - .state = switch (r.state) { - .start => .start, - .seen_n => .seen_n, - .seen_r => .seen_r, - .seen_rn => .seen_rn, - .seen_rnr => .seen_rnr, - .finished => .finished, - else => unreachable, - }, - }; - const result = hp.feed(bytes); - r.state = switch (hp.state) { - .start => .start, - .seen_n => .seen_n, - .seen_r => .seen_r, - .seen_rn => .seen_rn, - .seen_rnr => .seen_rnr, - .finished => .finished, - }; - return @intCast(result); - } - - pub fn findChunkedLen(r: *HeadersParser, bytes: []const u8) u32 { - var cp: std.http.ChunkParser = .{ - .state = switch (r.state) { - .chunk_head_size => .head_size, - .chunk_head_ext => .head_ext, - .chunk_head_r => .head_r, - .chunk_data => .data, - .chunk_data_suffix => .data_suffix, - .chunk_data_suffix_r => .data_suffix_r, - .invalid => .invalid, - else => unreachable, - }, - .chunk_len = r.next_chunk_length, - }; - const result = cp.feed(bytes); - r.state = switch (cp.state) { - .head_size => .chunk_head_size, - .head_ext => .chunk_head_ext, - .head_r => .chunk_head_r, - .data => .chunk_data, - .data_suffix => .chunk_data_suffix, - .data_suffix_r => .chunk_data_suffix_r, - .invalid => .invalid, - }; - r.next_chunk_length = cp.chunk_len; - return @intCast(result); - } - - /// Returns whether or not the parser has finished parsing a complete - /// message. A message is only complete after the entire body has been read - /// and any trailing headers have been parsed. - pub fn isComplete(r: *HeadersParser) bool { - return r.done and r.state == .finished; - } - - pub const CheckCompleteHeadError = error{HttpHeadersOversize}; - - /// Pushes `in` into the parser. Returns the number of bytes consumed by - /// the header. Any header bytes are appended to `header_bytes_buffer`. - pub fn checkCompleteHead(hp: *HeadersParser, in: []const u8) CheckCompleteHeadError!u32 { - if (hp.state.isContent()) return 0; - - const i = hp.findHeadersEnd(in); - const data = in[0..i]; - if (hp.header_bytes_len + data.len > hp.header_bytes_buffer.len) - return error.HttpHeadersOversize; - - @memcpy(hp.header_bytes_buffer[hp.header_bytes_len..][0..data.len], data); - hp.header_bytes_len += @intCast(data.len); - - return i; - } - - pub const ReadError = error{ - HttpChunkInvalid, - }; - - /// Reads the body of the message into `buffer`. Returns the number of - /// bytes placed in the buffer. - /// - /// If `skip` is true, the buffer will be unused and the body will be skipped. - /// - /// See `std.http.Client.Connection for an example of `conn`. - pub fn read(r: *HeadersParser, conn: anytype, buffer: []u8, skip: bool) !usize { - assert(r.state.isContent()); - if (r.done) return 0; - - var out_index: usize = 0; - while (true) { - switch (r.state) { - .invalid, .start, .seen_n, .seen_r, .seen_rn, .seen_rnr => unreachable, - .finished => { - const data_avail = r.next_chunk_length; - - if (skip) { - try conn.fill(); - - const nread = @min(conn.peek().len, data_avail); - conn.drop(@intCast(nread)); - r.next_chunk_length -= nread; - - if (r.next_chunk_length == 0 or nread == 0) r.done = true; - - return out_index; - } else if (out_index < buffer.len) { - const out_avail = buffer.len - out_index; - - const can_read = @as(usize, @intCast(@min(data_avail, out_avail))); - const nread = try conn.read(buffer[0..can_read]); - r.next_chunk_length -= nread; - - if (r.next_chunk_length == 0 or nread == 0) r.done = true; - - return nread; - } else { - return out_index; - } - }, - .chunk_data_suffix, .chunk_data_suffix_r, .chunk_head_size, .chunk_head_ext, .chunk_head_r => { - try conn.fill(); - - const i = r.findChunkedLen(conn.peek()); - conn.drop(@intCast(i)); - - switch (r.state) { - .invalid => return error.HttpChunkInvalid, - .chunk_data => if (r.next_chunk_length == 0) { - if (std.mem.eql(u8, conn.peek(), "\r\n")) { - r.state = .finished; - conn.drop(2); - } else { - // The trailer section is formatted identically - // to the header section. - r.state = .seen_rn; - } - r.done = true; - - return out_index; - }, - else => return out_index, - } - - continue; - }, - .chunk_data => { - const data_avail = r.next_chunk_length; - const out_avail = buffer.len - out_index; - - if (skip) { - try conn.fill(); - - const nread = @min(conn.peek().len, data_avail); - conn.drop(@intCast(nread)); - r.next_chunk_length -= nread; - } else if (out_avail > 0) { - const can_read: usize = @intCast(@min(data_avail, out_avail)); - const nread = try conn.read(buffer[out_index..][0..can_read]); - r.next_chunk_length -= nread; - out_index += nread; - } - - if (r.next_chunk_length == 0) { - r.state = .chunk_data_suffix; - continue; - } - - return out_index; - }, - } - } - } -}; - -inline fn int16(array: *const [2]u8) u16 { - return @as(u16, @bitCast(array.*)); -} - -inline fn int24(array: *const [3]u8) u24 { - return @as(u24, @bitCast(array.*)); -} - -inline fn int32(array: *const [4]u8) u32 { - return @as(u32, @bitCast(array.*)); -} - -inline fn intShift(comptime T: type, x: anytype) T { - switch (@import("builtin").cpu.arch.endian()) { - .little => return @as(T, @truncate(x >> (@bitSizeOf(@TypeOf(x)) - @bitSizeOf(T)))), - .big => return @as(T, @truncate(x)), - } -} - -/// A buffered (and peekable) Connection. -const MockBufferedConnection = struct { - pub const buffer_size = 0x2000; - - conn: std.io.FixedBufferStream([]const u8), - buf: [buffer_size]u8 = undefined, - start: u16 = 0, - end: u16 = 0, - - pub fn fill(conn: *MockBufferedConnection) ReadError!void { - if (conn.end != conn.start) return; - - const nread = try conn.conn.read(conn.buf[0..]); - if (nread == 0) return error.EndOfStream; - conn.start = 0; - conn.end = @as(u16, @truncate(nread)); - } - - pub fn peek(conn: *MockBufferedConnection) []const u8 { - return conn.buf[conn.start..conn.end]; - } - - pub fn drop(conn: *MockBufferedConnection, num: u16) void { - conn.start += num; - } - - pub fn readAtLeast(conn: *MockBufferedConnection, buffer: []u8, len: usize) ReadError!usize { - var out_index: u16 = 0; - while (out_index < len) { - const available = conn.end - conn.start; - const left = buffer.len - out_index; - - if (available > 0) { - const can_read = @as(u16, @truncate(@min(available, left))); - - @memcpy(buffer[out_index..][0..can_read], conn.buf[conn.start..][0..can_read]); - out_index += can_read; - conn.start += can_read; - - continue; - } - - if (left > conn.buf.len) { - // skip the buffer if the output is large enough - return conn.conn.read(buffer[out_index..]); - } - - try conn.fill(); - } - - return out_index; - } - - pub fn read(conn: *MockBufferedConnection, buffer: []u8) ReadError!usize { - return conn.readAtLeast(buffer, 1); - } - - pub const ReadError = std.io.FixedBufferStream([]const u8).ReadError || error{EndOfStream}; - pub const Reader = std.io.Reader(*MockBufferedConnection, ReadError, read); - - pub fn reader(conn: *MockBufferedConnection) Reader { - return Reader{ .context = conn }; - } - - pub fn writeAll(conn: *MockBufferedConnection, buffer: []const u8) WriteError!void { - return conn.conn.writeAll(buffer); - } - - pub fn write(conn: *MockBufferedConnection, buffer: []const u8) WriteError!usize { - return conn.conn.write(buffer); - } - - pub const WriteError = std.io.FixedBufferStream([]const u8).WriteError; - pub const Writer = std.io.Writer(*MockBufferedConnection, WriteError, write); - - pub fn writer(conn: *MockBufferedConnection) Writer { - return Writer{ .context = conn }; - } -}; - -test "HeadersParser.read length" { - // mock BufferedConnection for read - var headers_buf: [256]u8 = undefined; - - var r = HeadersParser.init(&headers_buf); - const data = "GET / HTTP/1.1\r\nHost: localhost\r\nContent-Length: 5\r\n\r\nHello"; - - var conn: MockBufferedConnection = .{ - .conn = std.io.fixedBufferStream(data), - }; - - while (true) { // read headers - try conn.fill(); - - const nchecked = try r.checkCompleteHead(conn.peek()); - conn.drop(@intCast(nchecked)); - - if (r.state.isContent()) break; - } - - var buf: [8]u8 = undefined; - - r.next_chunk_length = 5; - const len = try r.read(&conn, &buf, false); - try std.testing.expectEqual(@as(usize, 5), len); - try std.testing.expectEqualStrings("Hello", buf[0..len]); - - try std.testing.expectEqualStrings("GET / HTTP/1.1\r\nHost: localhost\r\nContent-Length: 5\r\n\r\n", r.get()); -} - -test "HeadersParser.read chunked" { - // mock BufferedConnection for read - - var headers_buf: [256]u8 = undefined; - var r = HeadersParser.init(&headers_buf); - const data = "GET / HTTP/1.1\r\nHost: localhost\r\n\r\n2\r\nHe\r\n2\r\nll\r\n1\r\no\r\n0\r\n\r\n"; - - var conn: MockBufferedConnection = .{ - .conn = std.io.fixedBufferStream(data), - }; - - while (true) { // read headers - try conn.fill(); - - const nchecked = try r.checkCompleteHead(conn.peek()); - conn.drop(@intCast(nchecked)); - - if (r.state.isContent()) break; - } - var buf: [8]u8 = undefined; - - r.state = .chunk_head_size; - const len = try r.read(&conn, &buf, false); - try std.testing.expectEqual(@as(usize, 5), len); - try std.testing.expectEqualStrings("Hello", buf[0..len]); - - try std.testing.expectEqualStrings("GET / HTTP/1.1\r\nHost: localhost\r\n\r\n", r.get()); -} - -test "HeadersParser.read chunked trailer" { - // mock BufferedConnection for read - - var headers_buf: [256]u8 = undefined; - var r = HeadersParser.init(&headers_buf); - const data = "GET / HTTP/1.1\r\nHost: localhost\r\n\r\n2\r\nHe\r\n2\r\nll\r\n1\r\no\r\n0\r\nContent-Type: text/plain\r\n\r\n"; - - var conn: MockBufferedConnection = .{ - .conn = std.io.fixedBufferStream(data), - }; - - while (true) { // read headers - try conn.fill(); - - const nchecked = try r.checkCompleteHead(conn.peek()); - conn.drop(@intCast(nchecked)); - - if (r.state.isContent()) break; - } - var buf: [8]u8 = undefined; - - r.state = .chunk_head_size; - const len = try r.read(&conn, &buf, false); - try std.testing.expectEqual(@as(usize, 5), len); - try std.testing.expectEqualStrings("Hello", buf[0..len]); - - while (true) { // read headers - try conn.fill(); - - const nchecked = try r.checkCompleteHead(conn.peek()); - conn.drop(@intCast(nchecked)); - - if (r.state.isContent()) break; - } - - try std.testing.expectEqualStrings("GET / HTTP/1.1\r\nHost: localhost\r\n\r\nContent-Type: text/plain\r\n\r\n", r.get()); -} diff --git a/src/http/async/std/net.zig b/src/http/async/std/net.zig deleted file mode 100644 index 86863deb..00000000 --- a/src/http/async/std/net.zig +++ /dev/null @@ -1,2050 +0,0 @@ -//! Cross-platform networking abstractions. - -const std = @import("std"); -const builtin = @import("builtin"); -const assert = std.debug.assert; -const net = @This(); -const mem = std.mem; -const posix = std.posix; -const fs = std.fs; -const io = std.io; -const native_endian = builtin.target.cpu.arch.endian(); -const native_os = builtin.os.tag; -const windows = std.os.windows; - -const Ctx = @import("http/Client.zig").Ctx; -const Cbk = @import("http/Client.zig").Cbk; - -// Windows 10 added support for unix sockets in build 17063, redstone 4 is the -// first release to support them. -pub const has_unix_sockets = switch (native_os) { - .windows => builtin.os.version_range.windows.isAtLeast(.win10_rs4) orelse false, - else => true, -}; - -pub const IPParseError = error{ - Overflow, - InvalidEnd, - InvalidCharacter, - Incomplete, -}; - -pub const IPv4ParseError = IPParseError || error{NonCanonical}; - -pub const IPv6ParseError = IPParseError || error{InvalidIpv4Mapping}; -pub const IPv6InterfaceError = posix.SocketError || posix.IoCtl_SIOCGIFINDEX_Error || error{NameTooLong}; -pub const IPv6ResolveError = IPv6ParseError || IPv6InterfaceError; - -pub const Address = extern union { - any: posix.sockaddr, - in: Ip4Address, - in6: Ip6Address, - un: if (has_unix_sockets) posix.sockaddr.un else void, - - /// Parse the given IP address string into an Address value. - /// It is recommended to use `resolveIp` instead, to handle - /// IPv6 link-local unix addresses. - pub fn parseIp(name: []const u8, port: u16) !Address { - if (parseIp4(name, port)) |ip4| return ip4 else |err| switch (err) { - error.Overflow, - error.InvalidEnd, - error.InvalidCharacter, - error.Incomplete, - error.NonCanonical, - => {}, - } - - if (parseIp6(name, port)) |ip6| return ip6 else |err| switch (err) { - error.Overflow, - error.InvalidEnd, - error.InvalidCharacter, - error.Incomplete, - error.InvalidIpv4Mapping, - => {}, - } - - return error.InvalidIPAddressFormat; - } - - pub fn resolveIp(name: []const u8, port: u16) !Address { - if (parseIp4(name, port)) |ip4| return ip4 else |err| switch (err) { - error.Overflow, - error.InvalidEnd, - error.InvalidCharacter, - error.Incomplete, - error.NonCanonical, - => {}, - } - - if (resolveIp6(name, port)) |ip6| return ip6 else |err| switch (err) { - error.Overflow, - error.InvalidEnd, - error.InvalidCharacter, - error.Incomplete, - error.InvalidIpv4Mapping, - => {}, - else => return err, - } - - return error.InvalidIPAddressFormat; - } - - pub fn parseExpectingFamily(name: []const u8, family: posix.sa_family_t, port: u16) !Address { - switch (family) { - posix.AF.INET => return parseIp4(name, port), - posix.AF.INET6 => return parseIp6(name, port), - posix.AF.UNSPEC => return parseIp(name, port), - else => unreachable, - } - } - - pub fn parseIp6(buf: []const u8, port: u16) IPv6ParseError!Address { - return .{ .in6 = try Ip6Address.parse(buf, port) }; - } - - pub fn resolveIp6(buf: []const u8, port: u16) IPv6ResolveError!Address { - return .{ .in6 = try Ip6Address.resolve(buf, port) }; - } - - pub fn parseIp4(buf: []const u8, port: u16) IPv4ParseError!Address { - return .{ .in = try Ip4Address.parse(buf, port) }; - } - - pub fn initIp4(addr: [4]u8, port: u16) Address { - return .{ .in = Ip4Address.init(addr, port) }; - } - - pub fn initIp6(addr: [16]u8, port: u16, flowinfo: u32, scope_id: u32) Address { - return .{ .in6 = Ip6Address.init(addr, port, flowinfo, scope_id) }; - } - - pub fn initUnix(path: []const u8) !Address { - var sock_addr = posix.sockaddr.un{ - .family = posix.AF.UNIX, - .path = undefined, - }; - - // Add 1 to ensure a terminating 0 is present in the path array for maximum portability. - if (path.len + 1 > sock_addr.path.len) return error.NameTooLong; - - @memset(&sock_addr.path, 0); - @memcpy(sock_addr.path[0..path.len], path); - - return .{ .un = sock_addr }; - } - - /// Returns the port in native endian. - /// Asserts that the address is ip4 or ip6. - pub fn getPort(self: Address) u16 { - return switch (self.any.family) { - posix.AF.INET => self.in.getPort(), - posix.AF.INET6 => self.in6.getPort(), - else => unreachable, - }; - } - - /// `port` is native-endian. - /// Asserts that the address is ip4 or ip6. - pub fn setPort(self: *Address, port: u16) void { - switch (self.any.family) { - posix.AF.INET => self.in.setPort(port), - posix.AF.INET6 => self.in6.setPort(port), - else => unreachable, - } - } - - /// Asserts that `addr` is an IP address. - /// This function will read past the end of the pointer, with a size depending - /// on the address family. - pub fn initPosix(addr: *align(4) const posix.sockaddr) Address { - switch (addr.family) { - posix.AF.INET => return Address{ .in = Ip4Address{ .sa = @as(*const posix.sockaddr.in, @ptrCast(addr)).* } }, - posix.AF.INET6 => return Address{ .in6 = Ip6Address{ .sa = @as(*const posix.sockaddr.in6, @ptrCast(addr)).* } }, - else => unreachable, - } - } - - pub fn format( - self: Address, - comptime fmt: []const u8, - options: std.fmt.FormatOptions, - out_stream: anytype, - ) !void { - if (fmt.len != 0) std.fmt.invalidFmtError(fmt, self); - switch (self.any.family) { - posix.AF.INET => try self.in.format(fmt, options, out_stream), - posix.AF.INET6 => try self.in6.format(fmt, options, out_stream), - posix.AF.UNIX => { - if (!has_unix_sockets) { - unreachable; - } - - try std.fmt.format(out_stream, "{s}", .{std.mem.sliceTo(&self.un.path, 0)}); - }, - else => unreachable, - } - } - - pub fn eql(a: Address, b: Address) bool { - const a_bytes = @as([*]const u8, @ptrCast(&a.any))[0..a.getOsSockLen()]; - const b_bytes = @as([*]const u8, @ptrCast(&b.any))[0..b.getOsSockLen()]; - return mem.eql(u8, a_bytes, b_bytes); - } - - pub fn getOsSockLen(self: Address) posix.socklen_t { - switch (self.any.family) { - posix.AF.INET => return self.in.getOsSockLen(), - posix.AF.INET6 => return self.in6.getOsSockLen(), - posix.AF.UNIX => { - if (!has_unix_sockets) { - unreachable; - } - - // Using the full length of the structure here is more portable than returning - // the number of bytes actually used by the currently stored path. - // This also is correct regardless if we are passing a socket address to the kernel - // (e.g. in bind, connect, sendto) since we ensure the path is 0 terminated in - // initUnix() or if we are receiving a socket address from the kernel and must - // provide the full buffer size (e.g. getsockname, getpeername, recvfrom, accept). - // - // To access the path, std.mem.sliceTo(&address.un.path, 0) should be used. - return @as(posix.socklen_t, @intCast(@sizeOf(posix.sockaddr.un))); - }, - - else => unreachable, - } - } - - pub const ListenError = posix.SocketError || posix.BindError || posix.ListenError || - posix.SetSockOptError || posix.GetSockNameError; - - pub const ListenOptions = struct { - /// How many connections the kernel will accept on the application's behalf. - /// If more than this many connections pool in the kernel, clients will start - /// seeing "Connection refused". - kernel_backlog: u31 = 128, - /// Sets SO_REUSEADDR and SO_REUSEPORT on POSIX. - /// Sets SO_REUSEADDR on Windows, which is roughly equivalent. - reuse_address: bool = false, - /// Deprecated. Does the same thing as reuse_address. - reuse_port: bool = false, - force_nonblocking: bool = false, - }; - - /// The returned `Server` has an open `stream`. - pub fn listen(address: Address, options: ListenOptions) ListenError!Server { - const nonblock: u32 = if (options.force_nonblocking) posix.SOCK.NONBLOCK else 0; - const sock_flags = posix.SOCK.STREAM | posix.SOCK.CLOEXEC | nonblock; - const proto: u32 = if (address.any.family == posix.AF.UNIX) 0 else posix.IPPROTO.TCP; - - const sockfd = try posix.socket(address.any.family, sock_flags, proto); - var s: Server = .{ - .listen_address = undefined, - .stream = .{ .handle = sockfd }, - }; - errdefer s.stream.close(); - - if (options.reuse_address or options.reuse_port) { - try posix.setsockopt( - sockfd, - posix.SOL.SOCKET, - posix.SO.REUSEADDR, - &mem.toBytes(@as(c_int, 1)), - ); - switch (native_os) { - .windows => {}, - else => try posix.setsockopt( - sockfd, - posix.SOL.SOCKET, - posix.SO.REUSEPORT, - &mem.toBytes(@as(c_int, 1)), - ), - } - } - - var socklen = address.getOsSockLen(); - try posix.bind(sockfd, &address.any, socklen); - try posix.listen(sockfd, options.kernel_backlog); - try posix.getsockname(sockfd, &s.listen_address.any, &socklen); - return s; - } -}; - -pub const Ip4Address = extern struct { - sa: posix.sockaddr.in, - - pub fn parse(buf: []const u8, port: u16) IPv4ParseError!Ip4Address { - var result: Ip4Address = .{ - .sa = .{ - .port = mem.nativeToBig(u16, port), - .addr = undefined, - }, - }; - const out_ptr = mem.asBytes(&result.sa.addr); - - var x: u8 = 0; - var index: u8 = 0; - var saw_any_digits = false; - var has_zero_prefix = false; - for (buf) |c| { - if (c == '.') { - if (!saw_any_digits) { - return error.InvalidCharacter; - } - if (index == 3) { - return error.InvalidEnd; - } - out_ptr[index] = x; - index += 1; - x = 0; - saw_any_digits = false; - has_zero_prefix = false; - } else if (c >= '0' and c <= '9') { - if (c == '0' and !saw_any_digits) { - has_zero_prefix = true; - } else if (has_zero_prefix) { - return error.NonCanonical; - } - saw_any_digits = true; - x = try std.math.mul(u8, x, 10); - x = try std.math.add(u8, x, c - '0'); - } else { - return error.InvalidCharacter; - } - } - if (index == 3 and saw_any_digits) { - out_ptr[index] = x; - return result; - } - - return error.Incomplete; - } - - pub fn resolveIp(name: []const u8, port: u16) !Ip4Address { - if (parse(name, port)) |ip4| return ip4 else |err| switch (err) { - error.Overflow, - error.InvalidEnd, - error.InvalidCharacter, - error.Incomplete, - error.NonCanonical, - => {}, - } - return error.InvalidIPAddressFormat; - } - - pub fn init(addr: [4]u8, port: u16) Ip4Address { - return Ip4Address{ - .sa = posix.sockaddr.in{ - .port = mem.nativeToBig(u16, port), - .addr = @as(*align(1) const u32, @ptrCast(&addr)).*, - }, - }; - } - - /// Returns the port in native endian. - /// Asserts that the address is ip4 or ip6. - pub fn getPort(self: Ip4Address) u16 { - return mem.bigToNative(u16, self.sa.port); - } - - /// `port` is native-endian. - /// Asserts that the address is ip4 or ip6. - pub fn setPort(self: *Ip4Address, port: u16) void { - self.sa.port = mem.nativeToBig(u16, port); - } - - pub fn format( - self: Ip4Address, - comptime fmt: []const u8, - options: std.fmt.FormatOptions, - out_stream: anytype, - ) !void { - if (fmt.len != 0) std.fmt.invalidFmtError(fmt, self); - _ = options; - const bytes = @as(*const [4]u8, @ptrCast(&self.sa.addr)); - try std.fmt.format(out_stream, "{}.{}.{}.{}:{}", .{ - bytes[0], - bytes[1], - bytes[2], - bytes[3], - self.getPort(), - }); - } - - pub fn getOsSockLen(self: Ip4Address) posix.socklen_t { - _ = self; - return @sizeOf(posix.sockaddr.in); - } -}; - -pub const Ip6Address = extern struct { - sa: posix.sockaddr.in6, - - /// Parse a given IPv6 address string into an Address. - /// Assumes the Scope ID of the address is fully numeric. - /// For non-numeric addresses, see `resolveIp6`. - pub fn parse(buf: []const u8, port: u16) IPv6ParseError!Ip6Address { - var result = Ip6Address{ - .sa = posix.sockaddr.in6{ - .scope_id = 0, - .port = mem.nativeToBig(u16, port), - .flowinfo = 0, - .addr = undefined, - }, - }; - var ip_slice: *[16]u8 = result.sa.addr[0..]; - - var tail: [16]u8 = undefined; - - var x: u16 = 0; - var saw_any_digits = false; - var index: u8 = 0; - var scope_id = false; - var abbrv = false; - for (buf, 0..) |c, i| { - if (scope_id) { - if (c >= '0' and c <= '9') { - const digit = c - '0'; - { - const ov = @mulWithOverflow(result.sa.scope_id, 10); - if (ov[1] != 0) return error.Overflow; - result.sa.scope_id = ov[0]; - } - { - const ov = @addWithOverflow(result.sa.scope_id, digit); - if (ov[1] != 0) return error.Overflow; - result.sa.scope_id = ov[0]; - } - } else { - return error.InvalidCharacter; - } - } else if (c == ':') { - if (!saw_any_digits) { - if (abbrv) return error.InvalidCharacter; // ':::' - if (i != 0) abbrv = true; - @memset(ip_slice[index..], 0); - ip_slice = tail[0..]; - index = 0; - continue; - } - if (index == 14) { - return error.InvalidEnd; - } - ip_slice[index] = @as(u8, @truncate(x >> 8)); - index += 1; - ip_slice[index] = @as(u8, @truncate(x)); - index += 1; - - x = 0; - saw_any_digits = false; - } else if (c == '%') { - if (!saw_any_digits) { - return error.InvalidCharacter; - } - scope_id = true; - saw_any_digits = false; - } else if (c == '.') { - if (!abbrv or ip_slice[0] != 0xff or ip_slice[1] != 0xff) { - // must start with '::ffff:' - return error.InvalidIpv4Mapping; - } - const start_index = mem.lastIndexOfScalar(u8, buf[0..i], ':').? + 1; - const addr = (Ip4Address.parse(buf[start_index..], 0) catch { - return error.InvalidIpv4Mapping; - }).sa.addr; - ip_slice = result.sa.addr[0..]; - ip_slice[10] = 0xff; - ip_slice[11] = 0xff; - - const ptr = mem.sliceAsBytes(@as(*const [1]u32, &addr)[0..]); - - ip_slice[12] = ptr[0]; - ip_slice[13] = ptr[1]; - ip_slice[14] = ptr[2]; - ip_slice[15] = ptr[3]; - return result; - } else { - const digit = try std.fmt.charToDigit(c, 16); - { - const ov = @mulWithOverflow(x, 16); - if (ov[1] != 0) return error.Overflow; - x = ov[0]; - } - { - const ov = @addWithOverflow(x, digit); - if (ov[1] != 0) return error.Overflow; - x = ov[0]; - } - saw_any_digits = true; - } - } - - if (!saw_any_digits and !abbrv) { - return error.Incomplete; - } - if (!abbrv and index < 14) { - return error.Incomplete; - } - - if (index == 14) { - ip_slice[14] = @as(u8, @truncate(x >> 8)); - ip_slice[15] = @as(u8, @truncate(x)); - return result; - } else { - ip_slice[index] = @as(u8, @truncate(x >> 8)); - index += 1; - ip_slice[index] = @as(u8, @truncate(x)); - index += 1; - @memcpy(result.sa.addr[16 - index ..][0..index], ip_slice[0..index]); - return result; - } - } - - pub fn resolve(buf: []const u8, port: u16) IPv6ResolveError!Ip6Address { - // TODO: Unify the implementations of resolveIp6 and parseIp6. - var result = Ip6Address{ - .sa = posix.sockaddr.in6{ - .scope_id = 0, - .port = mem.nativeToBig(u16, port), - .flowinfo = 0, - .addr = undefined, - }, - }; - var ip_slice: *[16]u8 = result.sa.addr[0..]; - - var tail: [16]u8 = undefined; - - var x: u16 = 0; - var saw_any_digits = false; - var index: u8 = 0; - var abbrv = false; - - var scope_id = false; - var scope_id_value: [posix.IFNAMESIZE - 1]u8 = undefined; - var scope_id_index: usize = 0; - - for (buf, 0..) |c, i| { - if (scope_id) { - // Handling of percent-encoding should be for an URI library. - if ((c >= '0' and c <= '9') or - (c >= 'A' and c <= 'Z') or - (c >= 'a' and c <= 'z') or - (c == '-') or (c == '.') or (c == '_') or (c == '~')) - { - if (scope_id_index >= scope_id_value.len) { - return error.Overflow; - } - - scope_id_value[scope_id_index] = c; - scope_id_index += 1; - } else { - return error.InvalidCharacter; - } - } else if (c == ':') { - if (!saw_any_digits) { - if (abbrv) return error.InvalidCharacter; // ':::' - if (i != 0) abbrv = true; - @memset(ip_slice[index..], 0); - ip_slice = tail[0..]; - index = 0; - continue; - } - if (index == 14) { - return error.InvalidEnd; - } - ip_slice[index] = @as(u8, @truncate(x >> 8)); - index += 1; - ip_slice[index] = @as(u8, @truncate(x)); - index += 1; - - x = 0; - saw_any_digits = false; - } else if (c == '%') { - if (!saw_any_digits) { - return error.InvalidCharacter; - } - scope_id = true; - saw_any_digits = false; - } else if (c == '.') { - if (!abbrv or ip_slice[0] != 0xff or ip_slice[1] != 0xff) { - // must start with '::ffff:' - return error.InvalidIpv4Mapping; - } - const start_index = mem.lastIndexOfScalar(u8, buf[0..i], ':').? + 1; - const addr = (Ip4Address.parse(buf[start_index..], 0) catch { - return error.InvalidIpv4Mapping; - }).sa.addr; - ip_slice = result.sa.addr[0..]; - ip_slice[10] = 0xff; - ip_slice[11] = 0xff; - - const ptr = mem.sliceAsBytes(@as(*const [1]u32, &addr)[0..]); - - ip_slice[12] = ptr[0]; - ip_slice[13] = ptr[1]; - ip_slice[14] = ptr[2]; - ip_slice[15] = ptr[3]; - return result; - } else { - const digit = try std.fmt.charToDigit(c, 16); - { - const ov = @mulWithOverflow(x, 16); - if (ov[1] != 0) return error.Overflow; - x = ov[0]; - } - { - const ov = @addWithOverflow(x, digit); - if (ov[1] != 0) return error.Overflow; - x = ov[0]; - } - saw_any_digits = true; - } - } - - if (!saw_any_digits and !abbrv) { - return error.Incomplete; - } - - if (scope_id and scope_id_index == 0) { - return error.Incomplete; - } - - var resolved_scope_id: u32 = 0; - if (scope_id_index > 0) { - const scope_id_str = scope_id_value[0..scope_id_index]; - resolved_scope_id = std.fmt.parseInt(u32, scope_id_str, 10) catch |err| blk: { - if (err != error.InvalidCharacter) return err; - break :blk try if_nametoindex(scope_id_str); - }; - } - - result.sa.scope_id = resolved_scope_id; - - if (index == 14) { - ip_slice[14] = @as(u8, @truncate(x >> 8)); - ip_slice[15] = @as(u8, @truncate(x)); - return result; - } else { - ip_slice[index] = @as(u8, @truncate(x >> 8)); - index += 1; - ip_slice[index] = @as(u8, @truncate(x)); - index += 1; - @memcpy(result.sa.addr[16 - index ..][0..index], ip_slice[0..index]); - return result; - } - } - - pub fn init(addr: [16]u8, port: u16, flowinfo: u32, scope_id: u32) Ip6Address { - return Ip6Address{ - .sa = posix.sockaddr.in6{ - .addr = addr, - .port = mem.nativeToBig(u16, port), - .flowinfo = flowinfo, - .scope_id = scope_id, - }, - }; - } - - /// Returns the port in native endian. - /// Asserts that the address is ip4 or ip6. - pub fn getPort(self: Ip6Address) u16 { - return mem.bigToNative(u16, self.sa.port); - } - - /// `port` is native-endian. - /// Asserts that the address is ip4 or ip6. - pub fn setPort(self: *Ip6Address, port: u16) void { - self.sa.port = mem.nativeToBig(u16, port); - } - - pub fn format( - self: Ip6Address, - comptime fmt: []const u8, - options: std.fmt.FormatOptions, - out_stream: anytype, - ) !void { - if (fmt.len != 0) std.fmt.invalidFmtError(fmt, self); - _ = options; - const port = mem.bigToNative(u16, self.sa.port); - if (mem.eql(u8, self.sa.addr[0..12], &[_]u8{ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0xff, 0xff })) { - try std.fmt.format(out_stream, "[::ffff:{}.{}.{}.{}]:{}", .{ - self.sa.addr[12], - self.sa.addr[13], - self.sa.addr[14], - self.sa.addr[15], - port, - }); - return; - } - const big_endian_parts = @as(*align(1) const [8]u16, @ptrCast(&self.sa.addr)); - const native_endian_parts = switch (native_endian) { - .big => big_endian_parts.*, - .little => blk: { - var buf: [8]u16 = undefined; - for (big_endian_parts, 0..) |part, i| { - buf[i] = mem.bigToNative(u16, part); - } - break :blk buf; - }, - }; - try out_stream.writeAll("["); - var i: usize = 0; - var abbrv = false; - while (i < native_endian_parts.len) : (i += 1) { - if (native_endian_parts[i] == 0) { - if (!abbrv) { - try out_stream.writeAll(if (i == 0) "::" else ":"); - abbrv = true; - } - continue; - } - try std.fmt.format(out_stream, "{x}", .{native_endian_parts[i]}); - if (i != native_endian_parts.len - 1) { - try out_stream.writeAll(":"); - } - } - try std.fmt.format(out_stream, "]:{}", .{port}); - } - - pub fn getOsSockLen(self: Ip6Address) posix.socklen_t { - _ = self; - return @sizeOf(posix.sockaddr.in6); - } -}; - -pub fn connectUnixSocket(path: []const u8) !Stream { - const opt_non_block = 0; - const sockfd = try posix.socket( - posix.AF.UNIX, - posix.SOCK.STREAM | posix.SOCK.CLOEXEC | opt_non_block, - 0, - ); - errdefer Stream.close(.{ .handle = sockfd }); - - var addr = try std.net.Address.initUnix(path); - try posix.connect(sockfd, &addr.any, addr.getOsSockLen()); - - return .{ .handle = sockfd }; -} - -fn if_nametoindex(name: []const u8) IPv6InterfaceError!u32 { - if (native_os == .linux) { - var ifr: posix.ifreq = undefined; - const sockfd = try posix.socket(posix.AF.UNIX, posix.SOCK.DGRAM | posix.SOCK.CLOEXEC, 0); - defer Stream.close(.{ .handle = sockfd }); - - @memcpy(ifr.ifrn.name[0..name.len], name); - ifr.ifrn.name[name.len] = 0; - - // TODO investigate if this needs to be integrated with evented I/O. - try posix.ioctl_SIOCGIFINDEX(sockfd, &ifr); - - return @bitCast(ifr.ifru.ivalue); - } - - if (native_os.isDarwin()) { - if (name.len >= posix.IFNAMESIZE) - return error.NameTooLong; - - var if_name: [posix.IFNAMESIZE:0]u8 = undefined; - @memcpy(if_name[0..name.len], name); - if_name[name.len] = 0; - const if_slice = if_name[0..name.len :0]; - const index = std.c.if_nametoindex(if_slice); - if (index == 0) - return error.InterfaceNotFound; - return @as(u32, @bitCast(index)); - } - - @compileError("std.net.if_nametoindex unimplemented for this OS"); -} - -pub const AddressList = struct { - arena: std.heap.ArenaAllocator, - addrs: []Address, - canon_name: ?[]u8, - - pub fn deinit(self: *AddressList) void { - // Here we copy the arena allocator into stack memory, because - // otherwise it would destroy itself while it was still working. - var arena = self.arena; - arena.deinit(); - // self is destroyed - } -}; - -pub const TcpConnectToHostError = GetAddressListError || TcpConnectToAddressError; - -/// All memory allocated with `allocator` will be freed before this function returns. -pub fn tcpConnectToHost(allocator: mem.Allocator, name: []const u8, port: u16) TcpConnectToHostError!Stream { - const list = try getAddressList(allocator, name, port); - defer list.deinit(); - - if (list.addrs.len == 0) return error.UnknownHostName; - - for (list.addrs) |addr| { - return tcpConnectToAddress(addr) catch |err| switch (err) { - error.ConnectionRefused => { - continue; - }, - else => return err, - }; - } - return posix.ConnectError.ConnectionRefused; -} - -pub const TcpConnectToAddressError = posix.SocketError || posix.ConnectError; - -pub fn tcpConnectToAddress(address: Address) TcpConnectToAddressError!Stream { - const nonblock = 0; - const sock_flags = posix.SOCK.STREAM | nonblock | - (if (native_os == .windows) 0 else posix.SOCK.CLOEXEC); - const sockfd = try posix.socket(address.any.family, sock_flags, posix.IPPROTO.TCP); - errdefer Stream.close(.{ .handle = sockfd }); - - try posix.connect(sockfd, &address.any, address.getOsSockLen()); - - return Stream{ .handle = sockfd }; -} - -const GetAddressListError = std.mem.Allocator.Error || std.fs.File.OpenError || std.fs.File.ReadError || posix.SocketError || posix.BindError || posix.SetSockOptError || error{ - // TODO: break this up into error sets from the various underlying functions - - TemporaryNameServerFailure, - NameServerFailure, - AddressFamilyNotSupported, - UnknownHostName, - ServiceUnavailable, - Unexpected, - - HostLacksNetworkAddresses, - - InvalidCharacter, - InvalidEnd, - NonCanonical, - Overflow, - Incomplete, - InvalidIpv4Mapping, - InvalidIPAddressFormat, - - InterfaceNotFound, - FileSystem, -}; - -/// Call `AddressList.deinit` on the result. -pub fn getAddressList(allocator: mem.Allocator, name: []const u8, port: u16) GetAddressListError!*AddressList { - const result = blk: { - var arena = std.heap.ArenaAllocator.init(allocator); - errdefer arena.deinit(); - - const result = try arena.allocator().create(AddressList); - result.* = AddressList{ - .arena = arena, - .addrs = undefined, - .canon_name = null, - }; - break :blk result; - }; - const arena = result.arena.allocator(); - errdefer result.deinit(); - - if (native_os == .windows) { - const name_c = try allocator.dupeZ(u8, name); - defer allocator.free(name_c); - - const port_c = try std.fmt.allocPrintZ(allocator, "{}", .{port}); - defer allocator.free(port_c); - - const ws2_32 = windows.ws2_32; - const hints = posix.addrinfo{ - .flags = ws2_32.AI.NUMERICSERV, - .family = posix.AF.UNSPEC, - .socktype = posix.SOCK.STREAM, - .protocol = posix.IPPROTO.TCP, - .canonname = null, - .addr = null, - .addrlen = 0, - .next = null, - }; - var res: ?*posix.addrinfo = null; - var first = true; - while (true) { - const rc = ws2_32.getaddrinfo(name_c.ptr, port_c.ptr, &hints, &res); - switch (@as(windows.ws2_32.WinsockError, @enumFromInt(@as(u16, @intCast(rc))))) { - @as(windows.ws2_32.WinsockError, @enumFromInt(0)) => break, - .WSATRY_AGAIN => return error.TemporaryNameServerFailure, - .WSANO_RECOVERY => return error.NameServerFailure, - .WSAEAFNOSUPPORT => return error.AddressFamilyNotSupported, - .WSA_NOT_ENOUGH_MEMORY => return error.OutOfMemory, - .WSAHOST_NOT_FOUND => return error.UnknownHostName, - .WSATYPE_NOT_FOUND => return error.ServiceUnavailable, - .WSAEINVAL => unreachable, - .WSAESOCKTNOSUPPORT => unreachable, - .WSANOTINITIALISED => { - if (!first) return error.Unexpected; - first = false; - try windows.callWSAStartup(); - continue; - }, - else => |err| return windows.unexpectedWSAError(err), - } - } - defer ws2_32.freeaddrinfo(res); - - const addr_count = blk: { - var count: usize = 0; - var it = res; - while (it) |info| : (it = info.next) { - if (info.addr != null) { - count += 1; - } - } - break :blk count; - }; - result.addrs = try arena.alloc(Address, addr_count); - - var it = res; - var i: usize = 0; - while (it) |info| : (it = info.next) { - const addr = info.addr orelse continue; - result.addrs[i] = Address.initPosix(@alignCast(addr)); - - if (info.canonname) |n| { - if (result.canon_name == null) { - result.canon_name = try arena.dupe(u8, mem.sliceTo(n, 0)); - } - } - i += 1; - } - - return result; - } - - if (builtin.link_libc) { - const name_c = try allocator.dupeZ(u8, name); - defer allocator.free(name_c); - - const port_c = try std.fmt.allocPrintZ(allocator, "{}", .{port}); - defer allocator.free(port_c); - - const sys = if (native_os == .windows) windows.ws2_32 else posix.system; - const hints = posix.addrinfo{ - .flags = sys.AI.NUMERICSERV, - .family = posix.AF.UNSPEC, - .socktype = posix.SOCK.STREAM, - .protocol = posix.IPPROTO.TCP, - .canonname = null, - .addr = null, - .addrlen = 0, - .next = null, - }; - var res: ?*posix.addrinfo = null; - switch (sys.getaddrinfo(name_c.ptr, port_c.ptr, &hints, &res)) { - @as(sys.EAI, @enumFromInt(0)) => {}, - .ADDRFAMILY => return error.HostLacksNetworkAddresses, - .AGAIN => return error.TemporaryNameServerFailure, - .BADFLAGS => unreachable, // Invalid hints - .FAIL => return error.NameServerFailure, - .FAMILY => return error.AddressFamilyNotSupported, - .MEMORY => return error.OutOfMemory, - .NODATA => return error.HostLacksNetworkAddresses, - .NONAME => return error.UnknownHostName, - .SERVICE => return error.ServiceUnavailable, - .SOCKTYPE => unreachable, // Invalid socket type requested in hints - .SYSTEM => switch (posix.errno(-1)) { - else => |e| return posix.unexpectedErrno(e), - }, - else => unreachable, - } - defer if (res) |some| sys.freeaddrinfo(some); - - const addr_count = blk: { - var count: usize = 0; - var it = res; - while (it) |info| : (it = info.next) { - if (info.addr != null) { - count += 1; - } - } - break :blk count; - }; - result.addrs = try arena.alloc(Address, addr_count); - - var it = res; - var i: usize = 0; - while (it) |info| : (it = info.next) { - const addr = info.addr orelse continue; - result.addrs[i] = Address.initPosix(@alignCast(addr)); - - if (info.canonname) |n| { - if (result.canon_name == null) { - result.canon_name = try arena.dupe(u8, mem.sliceTo(n, 0)); - } - } - i += 1; - } - - return result; - } - - if (native_os == .linux) { - const flags = std.c.AI.NUMERICSERV; - const family = posix.AF.UNSPEC; - var lookup_addrs = std.ArrayList(LookupAddr).init(allocator); - defer lookup_addrs.deinit(); - - var canon = std.ArrayList(u8).init(arena); - defer canon.deinit(); - - try linuxLookupName(&lookup_addrs, &canon, name, family, flags, port); - - result.addrs = try arena.alloc(Address, lookup_addrs.items.len); - if (canon.items.len != 0) { - result.canon_name = try canon.toOwnedSlice(); - } - - for (lookup_addrs.items, 0..) |lookup_addr, i| { - result.addrs[i] = lookup_addr.addr; - assert(result.addrs[i].getPort() == port); - } - - return result; - } - @compileError("std.net.getAddressList unimplemented for this OS"); -} - -const LookupAddr = struct { - addr: Address, - sortkey: i32 = 0, -}; - -const DAS_USABLE = 0x40000000; -const DAS_MATCHINGSCOPE = 0x20000000; -const DAS_MATCHINGLABEL = 0x10000000; -const DAS_PREC_SHIFT = 20; -const DAS_SCOPE_SHIFT = 16; -const DAS_PREFIX_SHIFT = 8; -const DAS_ORDER_SHIFT = 0; - -fn linuxLookupName( - addrs: *std.ArrayList(LookupAddr), - canon: *std.ArrayList(u8), - opt_name: ?[]const u8, - family: posix.sa_family_t, - flags: u32, - port: u16, -) !void { - if (opt_name) |name| { - // reject empty name and check len so it fits into temp bufs - canon.items.len = 0; - try canon.appendSlice(name); - if (Address.parseExpectingFamily(name, family, port)) |addr| { - try addrs.append(LookupAddr{ .addr = addr }); - } else |name_err| if ((flags & std.c.AI.NUMERICHOST) != 0) { - return name_err; - } else { - try linuxLookupNameFromHosts(addrs, canon, name, family, port); - if (addrs.items.len == 0) { - // RFC 6761 Section 6.3.3 - // Name resolution APIs and libraries SHOULD recognize localhost - // names as special and SHOULD always return the IP loopback address - // for address queries and negative responses for all other query - // types. - - // Check for equal to "localhost(.)" or ends in ".localhost(.)" - const localhost = if (name[name.len - 1] == '.') "localhost." else "localhost"; - if (mem.endsWith(u8, name, localhost) and (name.len == localhost.len or name[name.len - localhost.len] == '.')) { - try addrs.append(LookupAddr{ .addr = .{ .in = Ip4Address.parse("127.0.0.1", port) catch unreachable } }); - try addrs.append(LookupAddr{ .addr = .{ .in6 = Ip6Address.parse("::1", port) catch unreachable } }); - return; - } - - try linuxLookupNameFromDnsSearch(addrs, canon, name, family, port); - } - } - } else { - try canon.resize(0); - try linuxLookupNameFromNull(addrs, family, flags, port); - } - if (addrs.items.len == 0) return error.UnknownHostName; - - // No further processing is needed if there are fewer than 2 - // results or if there are only IPv4 results. - if (addrs.items.len == 1 or family == posix.AF.INET) return; - const all_ip4 = for (addrs.items) |addr| { - if (addr.addr.any.family != posix.AF.INET) break false; - } else true; - if (all_ip4) return; - - // The following implements a subset of RFC 3484/6724 destination - // address selection by generating a single 31-bit sort key for - // each address. Rules 3, 4, and 7 are omitted for having - // excessive runtime and code size cost and dubious benefit. - // So far the label/precedence table cannot be customized. - // This implementation is ported from musl libc. - // A more idiomatic "ziggy" implementation would be welcome. - for (addrs.items, 0..) |*addr, i| { - var key: i32 = 0; - var sa6: posix.sockaddr.in6 = undefined; - @memset(@as([*]u8, @ptrCast(&sa6))[0..@sizeOf(posix.sockaddr.in6)], 0); - var da6 = posix.sockaddr.in6{ - .family = posix.AF.INET6, - .scope_id = addr.addr.in6.sa.scope_id, - .port = 65535, - .flowinfo = 0, - .addr = [1]u8{0} ** 16, - }; - var sa4: posix.sockaddr.in = undefined; - @memset(@as([*]u8, @ptrCast(&sa4))[0..@sizeOf(posix.sockaddr.in)], 0); - var da4 = posix.sockaddr.in{ - .family = posix.AF.INET, - .port = 65535, - .addr = 0, - .zero = [1]u8{0} ** 8, - }; - var sa: *align(4) posix.sockaddr = undefined; - var da: *align(4) posix.sockaddr = undefined; - var salen: posix.socklen_t = undefined; - var dalen: posix.socklen_t = undefined; - if (addr.addr.any.family == posix.AF.INET6) { - da6.addr = addr.addr.in6.sa.addr; - da = @ptrCast(&da6); - dalen = @sizeOf(posix.sockaddr.in6); - sa = @ptrCast(&sa6); - salen = @sizeOf(posix.sockaddr.in6); - } else { - sa6.addr[0..12].* = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xff\xff".*; - da6.addr[0..12].* = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xff\xff".*; - mem.writeInt(u32, da6.addr[12..], addr.addr.in.sa.addr, native_endian); - da4.addr = addr.addr.in.sa.addr; - da = @ptrCast(&da4); - dalen = @sizeOf(posix.sockaddr.in); - sa = @ptrCast(&sa4); - salen = @sizeOf(posix.sockaddr.in); - } - const dpolicy = policyOf(da6.addr); - const dscope: i32 = scopeOf(da6.addr); - const dlabel = dpolicy.label; - const dprec: i32 = dpolicy.prec; - const MAXADDRS = 3; - var prefixlen: i32 = 0; - const sock_flags = posix.SOCK.DGRAM | posix.SOCK.CLOEXEC; - if (posix.socket(addr.addr.any.family, sock_flags, posix.IPPROTO.UDP)) |fd| syscalls: { - defer Stream.close(.{ .handle = fd }); - posix.connect(fd, da, dalen) catch break :syscalls; - key |= DAS_USABLE; - posix.getsockname(fd, sa, &salen) catch break :syscalls; - if (addr.addr.any.family == posix.AF.INET) { - mem.writeInt(u32, sa6.addr[12..16], sa4.addr, native_endian); - } - if (dscope == @as(i32, scopeOf(sa6.addr))) key |= DAS_MATCHINGSCOPE; - if (dlabel == labelOf(sa6.addr)) key |= DAS_MATCHINGLABEL; - prefixlen = prefixMatch(sa6.addr, da6.addr); - } else |_| {} - key |= dprec << DAS_PREC_SHIFT; - key |= (15 - dscope) << DAS_SCOPE_SHIFT; - key |= prefixlen << DAS_PREFIX_SHIFT; - key |= (MAXADDRS - @as(i32, @intCast(i))) << DAS_ORDER_SHIFT; - addr.sortkey = key; - } - mem.sort(LookupAddr, addrs.items, {}, addrCmpLessThan); -} - -const Policy = struct { - addr: [16]u8, - len: u8, - mask: u8, - prec: u8, - label: u8, -}; - -const defined_policies = [_]Policy{ - Policy{ - .addr = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01".*, - .len = 15, - .mask = 0xff, - .prec = 50, - .label = 0, - }, - Policy{ - .addr = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xff\xff\x00\x00\x00\x00".*, - .len = 11, - .mask = 0xff, - .prec = 35, - .label = 4, - }, - Policy{ - .addr = "\x20\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00".*, - .len = 1, - .mask = 0xff, - .prec = 30, - .label = 2, - }, - Policy{ - .addr = "\x20\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00".*, - .len = 3, - .mask = 0xff, - .prec = 5, - .label = 5, - }, - Policy{ - .addr = "\xfc\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00".*, - .len = 0, - .mask = 0xfe, - .prec = 3, - .label = 13, - }, - // These are deprecated and/or returned to the address - // pool, so despite the RFC, treating them as special - // is probably wrong. - // { "", 11, 0xff, 1, 3 }, - // { "\xfe\xc0", 1, 0xc0, 1, 11 }, - // { "\x3f\xfe", 1, 0xff, 1, 12 }, - // Last rule must match all addresses to stop loop. - Policy{ - .addr = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00".*, - .len = 0, - .mask = 0, - .prec = 40, - .label = 1, - }, -}; - -fn policyOf(a: [16]u8) *const Policy { - for (&defined_policies) |*policy| { - if (!mem.eql(u8, a[0..policy.len], policy.addr[0..policy.len])) continue; - if ((a[policy.len] & policy.mask) != policy.addr[policy.len]) continue; - return policy; - } - unreachable; -} - -fn scopeOf(a: [16]u8) u8 { - if (IN6_IS_ADDR_MULTICAST(a)) return a[1] & 15; - if (IN6_IS_ADDR_LINKLOCAL(a)) return 2; - if (IN6_IS_ADDR_LOOPBACK(a)) return 2; - if (IN6_IS_ADDR_SITELOCAL(a)) return 5; - return 14; -} - -fn prefixMatch(s: [16]u8, d: [16]u8) u8 { - // TODO: This FIXME inherited from porting from musl libc. - // I don't want this to go into zig std lib 1.0.0. - - // FIXME: The common prefix length should be limited to no greater - // than the nominal length of the prefix portion of the source - // address. However the definition of the source prefix length is - // not clear and thus this limiting is not yet implemented. - var i: u8 = 0; - while (i < 128 and ((s[i / 8] ^ d[i / 8]) & (@as(u8, 128) >> @as(u3, @intCast(i % 8)))) == 0) : (i += 1) {} - return i; -} - -fn labelOf(a: [16]u8) u8 { - return policyOf(a).label; -} - -fn IN6_IS_ADDR_MULTICAST(a: [16]u8) bool { - return a[0] == 0xff; -} - -fn IN6_IS_ADDR_LINKLOCAL(a: [16]u8) bool { - return a[0] == 0xfe and (a[1] & 0xc0) == 0x80; -} - -fn IN6_IS_ADDR_LOOPBACK(a: [16]u8) bool { - return a[0] == 0 and a[1] == 0 and - a[2] == 0 and - a[12] == 0 and a[13] == 0 and - a[14] == 0 and a[15] == 1; -} - -fn IN6_IS_ADDR_SITELOCAL(a: [16]u8) bool { - return a[0] == 0xfe and (a[1] & 0xc0) == 0xc0; -} - -// Parameters `b` and `a` swapped to make this descending. -fn addrCmpLessThan(context: void, b: LookupAddr, a: LookupAddr) bool { - _ = context; - return a.sortkey < b.sortkey; -} - -fn linuxLookupNameFromNull( - addrs: *std.ArrayList(LookupAddr), - family: posix.sa_family_t, - flags: u32, - port: u16, -) !void { - if ((flags & std.c.AI.PASSIVE) != 0) { - if (family != posix.AF.INET6) { - (try addrs.addOne()).* = LookupAddr{ - .addr = Address.initIp4([1]u8{0} ** 4, port), - }; - } - if (family != posix.AF.INET) { - (try addrs.addOne()).* = LookupAddr{ - .addr = Address.initIp6([1]u8{0} ** 16, port, 0, 0), - }; - } - } else { - if (family != posix.AF.INET6) { - (try addrs.addOne()).* = LookupAddr{ - .addr = Address.initIp4([4]u8{ 127, 0, 0, 1 }, port), - }; - } - if (family != posix.AF.INET) { - (try addrs.addOne()).* = LookupAddr{ - .addr = Address.initIp6(([1]u8{0} ** 15) ++ [1]u8{1}, port, 0, 0), - }; - } - } -} - -fn linuxLookupNameFromHosts( - addrs: *std.ArrayList(LookupAddr), - canon: *std.ArrayList(u8), - name: []const u8, - family: posix.sa_family_t, - port: u16, -) !void { - const file = fs.openFileAbsoluteZ("/etc/hosts", .{}) catch |err| switch (err) { - error.FileNotFound, - error.NotDir, - error.AccessDenied, - => return, - else => |e| return e, - }; - defer file.close(); - - var buffered_reader = std.io.bufferedReader(file.reader()); - const reader = buffered_reader.reader(); - var line_buf: [512]u8 = undefined; - while (reader.readUntilDelimiterOrEof(&line_buf, '\n') catch |err| switch (err) { - error.StreamTooLong => blk: { - // Skip to the delimiter in the reader, to fix parsing - try reader.skipUntilDelimiterOrEof('\n'); - // Use the truncated line. A truncated comment or hostname will be handled correctly. - break :blk &line_buf; - }, - else => |e| return e, - }) |line| { - var split_it = mem.splitScalar(u8, line, '#'); - const no_comment_line = split_it.first(); - - var line_it = mem.tokenizeAny(u8, no_comment_line, " \t"); - const ip_text = line_it.next() orelse continue; - var first_name_text: ?[]const u8 = null; - while (line_it.next()) |name_text| { - if (first_name_text == null) first_name_text = name_text; - if (mem.eql(u8, name_text, name)) { - break; - } - } else continue; - - const addr = Address.parseExpectingFamily(ip_text, family, port) catch |err| switch (err) { - error.Overflow, - error.InvalidEnd, - error.InvalidCharacter, - error.Incomplete, - error.InvalidIPAddressFormat, - error.InvalidIpv4Mapping, - error.NonCanonical, - => continue, - }; - try addrs.append(LookupAddr{ .addr = addr }); - - // first name is canonical name - const name_text = first_name_text.?; - if (isValidHostName(name_text)) { - canon.items.len = 0; - try canon.appendSlice(name_text); - } - } -} - -pub fn isValidHostName(hostname: []const u8) bool { - if (hostname.len >= 254) return false; - if (!std.unicode.utf8ValidateSlice(hostname)) return false; - for (hostname) |byte| { - if (!std.ascii.isASCII(byte) or byte == '.' or byte == '-' or std.ascii.isAlphanumeric(byte)) { - continue; - } - return false; - } - return true; -} - -fn linuxLookupNameFromDnsSearch( - addrs: *std.ArrayList(LookupAddr), - canon: *std.ArrayList(u8), - name: []const u8, - family: posix.sa_family_t, - port: u16, -) !void { - var rc: ResolvConf = undefined; - try getResolvConf(addrs.allocator, &rc); - defer rc.deinit(); - - // Count dots, suppress search when >=ndots or name ends in - // a dot, which is an explicit request for global scope. - var dots: usize = 0; - for (name) |byte| { - if (byte == '.') dots += 1; - } - - const search = if (dots >= rc.ndots or mem.endsWith(u8, name, ".")) - "" - else - rc.search.items; - - var canon_name = name; - - // Strip final dot for canon, fail if multiple trailing dots. - if (mem.endsWith(u8, canon_name, ".")) canon_name.len -= 1; - if (mem.endsWith(u8, canon_name, ".")) return error.UnknownHostName; - - // Name with search domain appended is setup in canon[]. This both - // provides the desired default canonical name (if the requested - // name is not a CNAME record) and serves as a buffer for passing - // the full requested name to name_from_dns. - try canon.resize(canon_name.len); - @memcpy(canon.items, canon_name); - try canon.append('.'); - - var tok_it = mem.tokenizeAny(u8, search, " \t"); - while (tok_it.next()) |tok| { - canon.shrinkRetainingCapacity(canon_name.len + 1); - try canon.appendSlice(tok); - try linuxLookupNameFromDns(addrs, canon, canon.items, family, rc, port); - if (addrs.items.len != 0) return; - } - - canon.shrinkRetainingCapacity(canon_name.len); - return linuxLookupNameFromDns(addrs, canon, name, family, rc, port); -} - -const dpc_ctx = struct { - addrs: *std.ArrayList(LookupAddr), - canon: *std.ArrayList(u8), - port: u16, -}; - -fn linuxLookupNameFromDns( - addrs: *std.ArrayList(LookupAddr), - canon: *std.ArrayList(u8), - name: []const u8, - family: posix.sa_family_t, - rc: ResolvConf, - port: u16, -) !void { - const ctx = dpc_ctx{ - .addrs = addrs, - .canon = canon, - .port = port, - }; - const AfRr = struct { - af: posix.sa_family_t, - rr: u8, - }; - const afrrs = [_]AfRr{ - AfRr{ .af = posix.AF.INET6, .rr = posix.RR.A }, - AfRr{ .af = posix.AF.INET, .rr = posix.RR.AAAA }, - }; - var qbuf: [2][280]u8 = undefined; - var abuf: [2][512]u8 = undefined; - var qp: [2][]const u8 = undefined; - const apbuf = [2][]u8{ &abuf[0], &abuf[1] }; - var nq: usize = 0; - - for (afrrs) |afrr| { - if (family != afrr.af) { - const len = posix.res_mkquery(0, name, 1, afrr.rr, &[_]u8{}, null, &qbuf[nq]); - qp[nq] = qbuf[nq][0..len]; - nq += 1; - } - } - - var ap = [2][]u8{ apbuf[0], apbuf[1] }; - ap[0].len = 0; - ap[1].len = 0; - - try resMSendRc(qp[0..nq], ap[0..nq], apbuf[0..nq], rc); - - var i: usize = 0; - while (i < nq) : (i += 1) { - dnsParse(ap[i], ctx, dnsParseCallback) catch {}; - } - - if (addrs.items.len != 0) return; - if (ap[0].len < 4 or (ap[0][3] & 15) == 2) return error.TemporaryNameServerFailure; - if ((ap[0][3] & 15) == 0) return error.UnknownHostName; - if ((ap[0][3] & 15) == 3) return; - return error.NameServerFailure; -} - -const ResolvConf = struct { - attempts: u32, - ndots: u32, - timeout: u32, - search: std.ArrayList(u8), - ns: std.ArrayList(LookupAddr), - - fn deinit(rc: *ResolvConf) void { - rc.ns.deinit(); - rc.search.deinit(); - rc.* = undefined; - } -}; - -/// Ignores lines longer than 512 bytes. -/// TODO: https://github.com/ziglang/zig/issues/2765 and https://github.com/ziglang/zig/issues/2761 -fn getResolvConf(allocator: mem.Allocator, rc: *ResolvConf) !void { - rc.* = ResolvConf{ - .ns = std.ArrayList(LookupAddr).init(allocator), - .search = std.ArrayList(u8).init(allocator), - .ndots = 1, - .timeout = 5, - .attempts = 2, - }; - errdefer rc.deinit(); - - const file = fs.openFileAbsoluteZ("/etc/resolv.conf", .{}) catch |err| switch (err) { - error.FileNotFound, - error.NotDir, - error.AccessDenied, - => return linuxLookupNameFromNumericUnspec(&rc.ns, "127.0.0.1", 53), - else => |e| return e, - }; - defer file.close(); - - var buf_reader = std.io.bufferedReader(file.reader()); - const stream = buf_reader.reader(); - var line_buf: [512]u8 = undefined; - while (stream.readUntilDelimiterOrEof(&line_buf, '\n') catch |err| switch (err) { - error.StreamTooLong => blk: { - // Skip to the delimiter in the stream, to fix parsing - try stream.skipUntilDelimiterOrEof('\n'); - // Give an empty line to the while loop, which will be skipped. - break :blk line_buf[0..0]; - }, - else => |e| return e, - }) |line| { - const no_comment_line = no_comment_line: { - var split = mem.splitScalar(u8, line, '#'); - break :no_comment_line split.first(); - }; - var line_it = mem.tokenizeAny(u8, no_comment_line, " \t"); - - const token = line_it.next() orelse continue; - if (mem.eql(u8, token, "options")) { - while (line_it.next()) |sub_tok| { - var colon_it = mem.splitScalar(u8, sub_tok, ':'); - const name = colon_it.first(); - const value_txt = colon_it.next() orelse continue; - const value = std.fmt.parseInt(u8, value_txt, 10) catch |err| switch (err) { - // TODO https://github.com/ziglang/zig/issues/11812 - error.Overflow => @as(u8, 255), - error.InvalidCharacter => continue, - }; - if (mem.eql(u8, name, "ndots")) { - rc.ndots = @min(value, 15); - } else if (mem.eql(u8, name, "attempts")) { - rc.attempts = @min(value, 10); - } else if (mem.eql(u8, name, "timeout")) { - rc.timeout = @min(value, 60); - } - } - } else if (mem.eql(u8, token, "nameserver")) { - const ip_txt = line_it.next() orelse continue; - try linuxLookupNameFromNumericUnspec(&rc.ns, ip_txt, 53); - } else if (mem.eql(u8, token, "domain") or mem.eql(u8, token, "search")) { - rc.search.items.len = 0; - try rc.search.appendSlice(line_it.rest()); - } - } - - if (rc.ns.items.len == 0) { - return linuxLookupNameFromNumericUnspec(&rc.ns, "127.0.0.1", 53); - } -} - -fn linuxLookupNameFromNumericUnspec( - addrs: *std.ArrayList(LookupAddr), - name: []const u8, - port: u16, -) !void { - const addr = try Address.resolveIp(name, port); - (try addrs.addOne()).* = LookupAddr{ .addr = addr }; -} - -fn resMSendRc( - queries: []const []const u8, - answers: [][]u8, - answer_bufs: []const []u8, - rc: ResolvConf, -) !void { - const timeout = 1000 * rc.timeout; - const attempts = rc.attempts; - - var sl: posix.socklen_t = @sizeOf(posix.sockaddr.in); - var family: posix.sa_family_t = posix.AF.INET; - - var ns_list = std.ArrayList(Address).init(rc.ns.allocator); - defer ns_list.deinit(); - - try ns_list.resize(rc.ns.items.len); - const ns = ns_list.items; - - for (rc.ns.items, 0..) |iplit, i| { - ns[i] = iplit.addr; - assert(ns[i].getPort() == 53); - if (iplit.addr.any.family != posix.AF.INET) { - family = posix.AF.INET6; - } - } - - const flags = posix.SOCK.DGRAM | posix.SOCK.CLOEXEC | posix.SOCK.NONBLOCK; - const fd = posix.socket(family, flags, 0) catch |err| switch (err) { - error.AddressFamilyNotSupported => blk: { - // Handle case where system lacks IPv6 support - if (family == posix.AF.INET6) { - family = posix.AF.INET; - break :blk try posix.socket(posix.AF.INET, flags, 0); - } - return err; - }, - else => |e| return e, - }; - defer Stream.close(.{ .handle = fd }); - - // Past this point, there are no errors. Each individual query will - // yield either no reply (indicated by zero length) or an answer - // packet which is up to the caller to interpret. - - // Convert any IPv4 addresses in a mixed environment to v4-mapped - if (family == posix.AF.INET6) { - try posix.setsockopt( - fd, - posix.SOL.IPV6, - std.os.linux.IPV6.V6ONLY, - &mem.toBytes(@as(c_int, 0)), - ); - for (0..ns.len) |i| { - if (ns[i].any.family != posix.AF.INET) continue; - mem.writeInt(u32, ns[i].in6.sa.addr[12..], ns[i].in.sa.addr, native_endian); - ns[i].in6.sa.addr[0..12].* = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xff\xff".*; - ns[i].any.family = posix.AF.INET6; - ns[i].in6.sa.flowinfo = 0; - ns[i].in6.sa.scope_id = 0; - } - sl = @sizeOf(posix.sockaddr.in6); - } - - // Get local address and open/bind a socket - var sa: Address = undefined; - @memset(@as([*]u8, @ptrCast(&sa))[0..@sizeOf(Address)], 0); - sa.any.family = family; - try posix.bind(fd, &sa.any, sl); - - var pfd = [1]posix.pollfd{posix.pollfd{ - .fd = fd, - .events = posix.POLL.IN, - .revents = undefined, - }}; - const retry_interval = timeout / attempts; - var next: u32 = 0; - var t2: u64 = @bitCast(std.time.milliTimestamp()); - const t0 = t2; - var t1 = t2 - retry_interval; - - var servfail_retry: usize = undefined; - - outer: while (t2 - t0 < timeout) : (t2 = @as(u64, @bitCast(std.time.milliTimestamp()))) { - if (t2 - t1 >= retry_interval) { - // Query all configured nameservers in parallel - var i: usize = 0; - while (i < queries.len) : (i += 1) { - if (answers[i].len == 0) { - var j: usize = 0; - while (j < ns.len) : (j += 1) { - _ = posix.sendto(fd, queries[i], posix.MSG.NOSIGNAL, &ns[j].any, sl) catch undefined; - } - } - } - t1 = t2; - servfail_retry = 2 * queries.len; - } - - // Wait for a response, or until time to retry - const clamped_timeout = @min(@as(u31, std.math.maxInt(u31)), t1 + retry_interval - t2); - const nevents = posix.poll(&pfd, clamped_timeout) catch 0; - if (nevents == 0) continue; - - while (true) { - var sl_copy = sl; - const rlen = posix.recvfrom(fd, answer_bufs[next], 0, &sa.any, &sl_copy) catch break; - - // Ignore non-identifiable packets - if (rlen < 4) continue; - - // Ignore replies from addresses we didn't send to - var j: usize = 0; - while (j < ns.len and !ns[j].eql(sa)) : (j += 1) {} - if (j == ns.len) continue; - - // Find which query this answer goes with, if any - var i: usize = next; - while (i < queries.len and (answer_bufs[next][0] != queries[i][0] or - answer_bufs[next][1] != queries[i][1])) : (i += 1) - {} - - if (i == queries.len) continue; - if (answers[i].len != 0) continue; - - // Only accept positive or negative responses; - // retry immediately on server failure, and ignore - // all other codes such as refusal. - switch (answer_bufs[next][3] & 15) { - 0, 3 => {}, - 2 => if (servfail_retry != 0) { - servfail_retry -= 1; - _ = posix.sendto(fd, queries[i], posix.MSG.NOSIGNAL, &ns[j].any, sl) catch undefined; - }, - else => continue, - } - - // Store answer in the right slot, or update next - // available temp slot if it's already in place. - answers[i].len = rlen; - if (i == next) { - while (next < queries.len and answers[next].len != 0) : (next += 1) {} - } else { - @memcpy(answer_bufs[i][0..rlen], answer_bufs[next][0..rlen]); - } - - if (next == queries.len) break :outer; - } - } -} - -fn dnsParse( - r: []const u8, - ctx: anytype, - comptime callback: anytype, -) !void { - // This implementation is ported from musl libc. - // A more idiomatic "ziggy" implementation would be welcome. - if (r.len < 12) return error.InvalidDnsPacket; - if ((r[3] & 15) != 0) return; - var p = r.ptr + 12; - var qdcount = r[4] * @as(usize, 256) + r[5]; - var ancount = r[6] * @as(usize, 256) + r[7]; - if (qdcount + ancount > 64) return error.InvalidDnsPacket; - while (qdcount != 0) { - qdcount -= 1; - while (@intFromPtr(p) - @intFromPtr(r.ptr) < r.len and p[0] -% 1 < 127) p += 1; - if (p[0] > 193 or (p[0] == 193 and p[1] > 254) or @intFromPtr(p) > @intFromPtr(r.ptr) + r.len - 6) - return error.InvalidDnsPacket; - p += @as(usize, 5) + @intFromBool(p[0] != 0); - } - while (ancount != 0) { - ancount -= 1; - while (@intFromPtr(p) - @intFromPtr(r.ptr) < r.len and p[0] -% 1 < 127) p += 1; - if (p[0] > 193 or (p[0] == 193 and p[1] > 254) or @intFromPtr(p) > @intFromPtr(r.ptr) + r.len - 6) - return error.InvalidDnsPacket; - p += @as(usize, 1) + @intFromBool(p[0] != 0); - const len = p[8] * @as(usize, 256) + p[9]; - if (@intFromPtr(p) + len > @intFromPtr(r.ptr) + r.len) return error.InvalidDnsPacket; - try callback(ctx, p[1], p[10..][0..len], r); - p += 10 + len; - } -} - -fn dnsParseCallback(ctx: dpc_ctx, rr: u8, data: []const u8, packet: []const u8) !void { - switch (rr) { - posix.RR.A => { - if (data.len != 4) return error.InvalidDnsARecord; - const new_addr = try ctx.addrs.addOne(); - new_addr.* = LookupAddr{ - .addr = Address.initIp4(data[0..4].*, ctx.port), - }; - }, - posix.RR.AAAA => { - if (data.len != 16) return error.InvalidDnsAAAARecord; - const new_addr = try ctx.addrs.addOne(); - new_addr.* = LookupAddr{ - .addr = Address.initIp6(data[0..16].*, ctx.port, 0, 0), - }; - }, - posix.RR.CNAME => { - var tmp: [256]u8 = undefined; - // Returns len of compressed name. strlen to get canon name. - _ = try posix.dn_expand(packet, data, &tmp); - const canon_name = mem.sliceTo(&tmp, 0); - if (isValidHostName(canon_name)) { - ctx.canon.items.len = 0; - try ctx.canon.appendSlice(canon_name); - } - }, - else => return, - } -} - -pub const Stream = struct { - /// Underlying platform-defined type which may or may not be - /// interchangeable with a file system file descriptor. - handle: posix.socket_t, - - pub fn close(s: Stream) void { - switch (native_os) { - .windows => windows.closesocket(s.handle) catch unreachable, - else => posix.close(s.handle), - } - } - - pub const ReadError = posix.ReadError; - pub const WriteError = posix.WriteError; - - pub const Reader = io.Reader(Stream, ReadError, read); - pub const Writer = io.Writer(Stream, WriteError, write); - - pub fn reader(self: Stream) Reader { - return .{ .context = self }; - } - - pub fn writer(self: Stream) Writer { - return .{ .context = self }; - } - - pub fn read(self: Stream, buffer: []u8) ReadError!usize { - if (native_os == .windows) { - return windows.ReadFile(self.handle, buffer, null); - } - - return posix.read(self.handle, buffer); - } - - pub fn readv(s: Stream, iovecs: []const posix.iovec) ReadError!usize { - if (native_os == .windows) { - // TODO improve this to use ReadFileScatter - if (iovecs.len == 0) return @as(usize, 0); - const first = iovecs[0]; - return windows.ReadFile(s.handle, first.base[0..first.len], null); - } - - return posix.readv(s.handle, iovecs); - } - - /// Returns the number of bytes read. If the number read is smaller than - /// `buffer.len`, it means the stream reached the end. Reaching the end of - /// a stream is not an error condition. - pub fn readAll(s: Stream, buffer: []u8) ReadError!usize { - return readAtLeast(s, buffer, buffer.len); - } - - /// Returns the number of bytes read, calling the underlying read function - /// the minimal number of times until the buffer has at least `len` bytes - /// filled. If the number read is less than `len` it means the stream - /// reached the end. Reaching the end of the stream is not an error - /// condition. - pub fn readAtLeast(s: Stream, buffer: []u8, len: usize) ReadError!usize { - assert(len <= buffer.len); - var index: usize = 0; - while (index < len) { - const amt = try s.read(buffer[index..]); - if (amt == 0) break; - index += amt; - } - return index; - } - - /// TODO in evented I/O mode, this implementation incorrectly uses the event loop's - /// file system thread instead of non-blocking. It needs to be reworked to properly - /// use non-blocking I/O. - pub fn write(self: Stream, buffer: []const u8) WriteError!usize { - if (native_os == .windows) { - return windows.WriteFile(self.handle, buffer, null); - } - - return posix.write(self.handle, buffer); - } - - pub fn writeAll(self: Stream, bytes: []const u8) WriteError!void { - var index: usize = 0; - while (index < bytes.len) { - index += try self.write(bytes[index..]); - } - } - - /// See https://github.com/ziglang/zig/issues/7699 - /// See equivalent function: `std.fs.File.writev`. - pub fn writev(self: Stream, iovecs: []const posix.iovec_const) WriteError!usize { - return posix.writev(self.handle, iovecs); - } - - /// The `iovecs` parameter is mutable because this function needs to mutate the fields in - /// order to handle partial writes from the underlying OS layer. - /// See https://github.com/ziglang/zig/issues/7699 - /// See equivalent function: `std.fs.File.writevAll`. - pub fn writevAll(self: Stream, iovecs: []posix.iovec_const) WriteError!void { - if (iovecs.len == 0) return; - - var i: usize = 0; - while (true) { - var amt = try self.writev(iovecs[i..]); - while (amt >= iovecs[i].len) { - amt -= iovecs[i].len; - i += 1; - if (i >= iovecs.len) return; - } - iovecs[i].base += amt; - iovecs[i].len -= amt; - } - } - - pub fn async_read( - self: Stream, - buffer: []u8, - ctx: *Ctx, - comptime cbk: Cbk, - ) !void { - return ctx.loop.recv(Ctx, ctx, cbk, self.handle, buffer); - } - - pub fn async_readv( - s: Stream, - iovecs: []const posix.iovec, - ctx: *Ctx, - comptime cbk: Cbk, - ) ReadError!void { - if (iovecs.len == 0) return; - const first_buffer = iovecs[0].base[0..iovecs[0].len]; - return s.async_read(first_buffer, ctx, cbk); - } - - // TODO: why not take a buffer here? - pub fn async_write(self: Stream, buffer: []const u8, ctx: *Ctx, comptime cbk: Cbk) void { - return ctx.loop.send(Ctx, ctx, cbk, self.handle, buffer); - } - - fn onWriteAll(ctx: *Ctx, res: anyerror!void) anyerror!void { - res catch |err| return ctx.pop(err); - if (ctx.len() < ctx.buf().len) { - const new_buf = ctx.buf()[ctx.len()..]; - ctx.setBuf(new_buf); - return ctx.stream().async_write(new_buf, ctx, onWriteAll); - } - ctx.setBuf(null); - return ctx.pop({}); - } - - pub fn async_writeAll(self: Stream, bytes: []const u8, ctx: *Ctx, comptime cbk: Cbk) !void { - ctx.setBuf(bytes); - try ctx.push(cbk); - self.async_write(bytes, ctx, onWriteAll); - } -}; - -pub const Server = struct { - listen_address: Address, - stream: std.net.Stream, - - pub const Connection = struct { - stream: std.net.Stream, - address: Address, - }; - - pub fn deinit(s: *Server) void { - s.stream.close(); - s.* = undefined; - } - - pub const AcceptError = posix.AcceptError; - - /// Blocks until a client connects to the server. The returned `Connection` has - /// an open stream. - pub fn accept(s: *Server) AcceptError!Connection { - var accepted_addr: Address = undefined; - var addr_len: posix.socklen_t = @sizeOf(Address); - const fd = try posix.accept(s.stream.handle, &accepted_addr.any, &addr_len, posix.SOCK.CLOEXEC); - return .{ - .stream = .{ .handle = fd }, - .address = accepted_addr, - }; - } -}; - -test { - _ = @import("net/test.zig"); - _ = Server; - _ = Stream; - _ = Address; -} - -fn onTcpConnectToHost(ctx: *Ctx, res: anyerror!void) anyerror!void { - res catch |e| switch (e) { - error.ConnectionRefused => { - if (ctx.data.addr_current < ctx.data.list.addrs.len) { - // next iteration of addr - ctx.push(onTcpConnectToHost) catch |er| return ctx.pop(er); - ctx.data.addr_current += 1; - return async_tcpConnectToAddress( - ctx.data.list.addrs[ctx.data.addr_current], - ctx, - onTcpConnectToHost, - ); - } - // end of iteration of addr - ctx.data.list.deinit(); - return ctx.pop(e); - }, - else => { - ctx.data.list.deinit(); - return ctx.pop(std.posix.ConnectError.ConnectionRefused); - }, - }; - // success - ctx.data.list.deinit(); - return ctx.pop({}); -} - -pub fn async_tcpConnectToHost( - allocator: mem.Allocator, - name: []const u8, - port: u16, - ctx: *Ctx, - comptime cbk: Cbk, -) !void { - const list = std.net.getAddressList(allocator, name, port) catch |e| return ctx.pop(e); - if (list.addrs.len == 0) return ctx.pop(error.UnknownHostName); - - ctx.push(cbk) catch |e| return ctx.pop(e); - ctx.data.list = list; - ctx.data.addr_current = 0; - return async_tcpConnectToAddress(list.addrs[0], ctx, onTcpConnectToHost); -} - -pub fn async_tcpConnectToAddress(address: std.net.Address, ctx: *Ctx, comptime cbk: Cbk) !void { - const nonblock = 0; - const sock_flags = posix.SOCK.STREAM | nonblock | - (if (native_os == .windows) 0 else posix.SOCK.CLOEXEC); - const sockfd = try posix.socket(address.any.family, sock_flags, posix.IPPROTO.TCP); - - ctx.data.socket = sockfd; - ctx.push(cbk) catch |e| return ctx.pop(e); - - ctx.loop.connect( - Ctx, - ctx, - setStream, - sockfd, - address, - ); -} - -// requires client.data.socket to be set -fn setStream(ctx: *Ctx, res: anyerror!void) anyerror!void { - res catch |e| return ctx.pop(e); - ctx.data.conn.stream = .{ .handle = ctx.data.socket }; - return ctx.pop({}); -} diff --git a/src/http/async/std/net/test.zig b/src/http/async/std/net/test.zig deleted file mode 100644 index 3e316c54..00000000 --- a/src/http/async/std/net/test.zig +++ /dev/null @@ -1,335 +0,0 @@ -const std = @import("std"); -const builtin = @import("builtin"); -const net = std.net; -const mem = std.mem; -const testing = std.testing; - -test "parse and render IP addresses at comptime" { - if (builtin.os.tag == .wasi) return error.SkipZigTest; - comptime { - var ipAddrBuffer: [16]u8 = undefined; - // Parses IPv6 at comptime - const ipv6addr = net.Address.parseIp("::1", 0) catch unreachable; - var ipv6 = std.fmt.bufPrint(ipAddrBuffer[0..], "{}", .{ipv6addr}) catch unreachable; - try std.testing.expect(std.mem.eql(u8, "::1", ipv6[1 .. ipv6.len - 3])); - - // Parses IPv4 at comptime - const ipv4addr = net.Address.parseIp("127.0.0.1", 0) catch unreachable; - var ipv4 = std.fmt.bufPrint(ipAddrBuffer[0..], "{}", .{ipv4addr}) catch unreachable; - try std.testing.expect(std.mem.eql(u8, "127.0.0.1", ipv4[0 .. ipv4.len - 2])); - - // Returns error for invalid IP addresses at comptime - try testing.expectError(error.InvalidIPAddressFormat, net.Address.parseIp("::123.123.123.123", 0)); - try testing.expectError(error.InvalidIPAddressFormat, net.Address.parseIp("127.01.0.1", 0)); - try testing.expectError(error.InvalidIPAddressFormat, net.Address.resolveIp("::123.123.123.123", 0)); - try testing.expectError(error.InvalidIPAddressFormat, net.Address.resolveIp("127.01.0.1", 0)); - } -} - -test "parse and render IPv6 addresses" { - if (builtin.os.tag == .wasi) return error.SkipZigTest; - - var buffer: [100]u8 = undefined; - const ips = [_][]const u8{ - "FF01:0:0:0:0:0:0:FB", - "FF01::Fb", - "::1", - "::", - "1::", - "2001:db8::", - "::1234:5678", - "2001:db8::1234:5678", - "FF01::FB%1234", - "::ffff:123.5.123.5", - }; - const printed = [_][]const u8{ - "ff01::fb", - "ff01::fb", - "::1", - "::", - "1::", - "2001:db8::", - "::1234:5678", - "2001:db8::1234:5678", - "ff01::fb", - "::ffff:123.5.123.5", - }; - for (ips, 0..) |ip, i| { - const addr = net.Address.parseIp6(ip, 0) catch unreachable; - var newIp = std.fmt.bufPrint(buffer[0..], "{}", .{addr}) catch unreachable; - try std.testing.expect(std.mem.eql(u8, printed[i], newIp[1 .. newIp.len - 3])); - - if (builtin.os.tag == .linux) { - const addr_via_resolve = net.Address.resolveIp6(ip, 0) catch unreachable; - var newResolvedIp = std.fmt.bufPrint(buffer[0..], "{}", .{addr_via_resolve}) catch unreachable; - try std.testing.expect(std.mem.eql(u8, printed[i], newResolvedIp[1 .. newResolvedIp.len - 3])); - } - } - - try testing.expectError(error.InvalidCharacter, net.Address.parseIp6(":::", 0)); - try testing.expectError(error.Overflow, net.Address.parseIp6("FF001::FB", 0)); - try testing.expectError(error.InvalidCharacter, net.Address.parseIp6("FF01::Fb:zig", 0)); - try testing.expectError(error.InvalidEnd, net.Address.parseIp6("FF01:0:0:0:0:0:0:FB:", 0)); - try testing.expectError(error.Incomplete, net.Address.parseIp6("FF01:", 0)); - try testing.expectError(error.InvalidIpv4Mapping, net.Address.parseIp6("::123.123.123.123", 0)); - try testing.expectError(error.Incomplete, net.Address.parseIp6("1", 0)); - // TODO Make this test pass on other operating systems. - if (builtin.os.tag == .linux or comptime builtin.os.tag.isDarwin()) { - try testing.expectError(error.Incomplete, net.Address.resolveIp6("ff01::fb%", 0)); - try testing.expectError(error.Overflow, net.Address.resolveIp6("ff01::fb%wlp3s0s0s0s0s0s0s0s0", 0)); - try testing.expectError(error.Overflow, net.Address.resolveIp6("ff01::fb%12345678901234", 0)); - } -} - -test "invalid but parseable IPv6 scope ids" { - if (builtin.os.tag != .linux and comptime !builtin.os.tag.isDarwin()) { - // Currently, resolveIp6 with alphanumerical scope IDs only works on Linux. - // TODO Make this test pass on other operating systems. - return error.SkipZigTest; - } - - try testing.expectError(error.InterfaceNotFound, net.Address.resolveIp6("ff01::fb%123s45678901234", 0)); -} - -test "parse and render IPv4 addresses" { - if (builtin.os.tag == .wasi) return error.SkipZigTest; - - var buffer: [18]u8 = undefined; - for ([_][]const u8{ - "0.0.0.0", - "255.255.255.255", - "1.2.3.4", - "123.255.0.91", - "127.0.0.1", - }) |ip| { - const addr = net.Address.parseIp4(ip, 0) catch unreachable; - var newIp = std.fmt.bufPrint(buffer[0..], "{}", .{addr}) catch unreachable; - try std.testing.expect(std.mem.eql(u8, ip, newIp[0 .. newIp.len - 2])); - } - - try testing.expectError(error.Overflow, net.Address.parseIp4("256.0.0.1", 0)); - try testing.expectError(error.InvalidCharacter, net.Address.parseIp4("x.0.0.1", 0)); - try testing.expectError(error.InvalidEnd, net.Address.parseIp4("127.0.0.1.1", 0)); - try testing.expectError(error.Incomplete, net.Address.parseIp4("127.0.0.", 0)); - try testing.expectError(error.InvalidCharacter, net.Address.parseIp4("100..0.1", 0)); - try testing.expectError(error.NonCanonical, net.Address.parseIp4("127.01.0.1", 0)); -} - -test "parse and render UNIX addresses" { - if (builtin.os.tag == .wasi) return error.SkipZigTest; - if (!net.has_unix_sockets) return error.SkipZigTest; - - var buffer: [14]u8 = undefined; - const addr = net.Address.initUnix("/tmp/testpath") catch unreachable; - const fmt_addr = std.fmt.bufPrint(buffer[0..], "{}", .{addr}) catch unreachable; - try std.testing.expectEqualSlices(u8, "/tmp/testpath", fmt_addr); - - const too_long = [_]u8{'a'} ** 200; - try testing.expectError(error.NameTooLong, net.Address.initUnix(too_long[0..])); -} - -test "resolve DNS" { - if (builtin.os.tag == .wasi) return error.SkipZigTest; - - if (builtin.os.tag == .windows) { - _ = try std.os.windows.WSAStartup(2, 2); - } - defer { - if (builtin.os.tag == .windows) { - std.os.windows.WSACleanup() catch unreachable; - } - } - - // Resolve localhost, this should not fail. - { - const localhost_v4 = try net.Address.parseIp("127.0.0.1", 80); - const localhost_v6 = try net.Address.parseIp("::2", 80); - - const result = try net.getAddressList(testing.allocator, "localhost", 80); - defer result.deinit(); - for (result.addrs) |addr| { - if (addr.eql(localhost_v4) or addr.eql(localhost_v6)) break; - } else @panic("unexpected address for localhost"); - } - - { - // The tests are required to work even when there is no Internet connection, - // so some of these errors we must accept and skip the test. - const result = net.getAddressList(testing.allocator, "example.com", 80) catch |err| switch (err) { - error.UnknownHostName => return error.SkipZigTest, - error.TemporaryNameServerFailure => return error.SkipZigTest, - else => return err, - }; - result.deinit(); - } -} - -test "listen on a port, send bytes, receive bytes" { - if (builtin.single_threaded) return error.SkipZigTest; - if (builtin.os.tag == .wasi) return error.SkipZigTest; - - if (builtin.os.tag == .windows) { - _ = try std.os.windows.WSAStartup(2, 2); - } - defer { - if (builtin.os.tag == .windows) { - std.os.windows.WSACleanup() catch unreachable; - } - } - - // Try only the IPv4 variant as some CI builders have no IPv6 localhost - // configured. - const localhost = try net.Address.parseIp("127.0.0.1", 0); - - var server = try localhost.listen(.{}); - defer server.deinit(); - - const S = struct { - fn clientFn(server_address: net.Address) !void { - const socket = try net.tcpConnectToAddress(server_address); - defer socket.close(); - - _ = try socket.writer().writeAll("Hello world!"); - } - }; - - const t = try std.Thread.spawn(.{}, S.clientFn, .{server.listen_address}); - defer t.join(); - - var client = try server.accept(); - defer client.stream.close(); - var buf: [16]u8 = undefined; - const n = try client.stream.reader().read(&buf); - - try testing.expectEqual(@as(usize, 12), n); - try testing.expectEqualSlices(u8, "Hello world!", buf[0..n]); -} - -test "listen on an in use port" { - if (builtin.os.tag != .linux and comptime !builtin.os.tag.isDarwin()) { - // TODO build abstractions for other operating systems - return error.SkipZigTest; - } - - const localhost = try net.Address.parseIp("127.0.0.1", 0); - - var server1 = try localhost.listen(.{ .reuse_port = true }); - defer server1.deinit(); - - var server2 = try server1.listen_address.listen(.{ .reuse_port = true }); - defer server2.deinit(); -} - -fn testClientToHost(allocator: mem.Allocator, name: []const u8, port: u16) anyerror!void { - if (builtin.os.tag == .wasi) return error.SkipZigTest; - - const connection = try net.tcpConnectToHost(allocator, name, port); - defer connection.close(); - - var buf: [100]u8 = undefined; - const len = try connection.read(&buf); - const msg = buf[0..len]; - try testing.expect(mem.eql(u8, msg, "hello from server\n")); -} - -fn testClient(addr: net.Address) anyerror!void { - if (builtin.os.tag == .wasi) return error.SkipZigTest; - - const socket_file = try net.tcpConnectToAddress(addr); - defer socket_file.close(); - - var buf: [100]u8 = undefined; - const len = try socket_file.read(&buf); - const msg = buf[0..len]; - try testing.expect(mem.eql(u8, msg, "hello from server\n")); -} - -fn testServer(server: *net.Server) anyerror!void { - if (builtin.os.tag == .wasi) return error.SkipZigTest; - - var client = try server.accept(); - - const stream = client.stream.writer(); - try stream.print("hello from server\n", .{}); -} - -test "listen on a unix socket, send bytes, receive bytes" { - if (builtin.single_threaded) return error.SkipZigTest; - if (!net.has_unix_sockets) return error.SkipZigTest; - - if (builtin.os.tag == .windows) { - _ = try std.os.windows.WSAStartup(2, 2); - } - defer { - if (builtin.os.tag == .windows) { - std.os.windows.WSACleanup() catch unreachable; - } - } - - const socket_path = try generateFileName("socket.unix"); - defer testing.allocator.free(socket_path); - - const socket_addr = try net.Address.initUnix(socket_path); - defer std.fs.cwd().deleteFile(socket_path) catch {}; - - var server = try socket_addr.listen(.{}); - defer server.deinit(); - - const S = struct { - fn clientFn(path: []const u8) !void { - const socket = try net.connectUnixSocket(path); - defer socket.close(); - - _ = try socket.writer().writeAll("Hello world!"); - } - }; - - const t = try std.Thread.spawn(.{}, S.clientFn, .{socket_path}); - defer t.join(); - - var client = try server.accept(); - defer client.stream.close(); - var buf: [16]u8 = undefined; - const n = try client.stream.reader().read(&buf); - - try testing.expectEqual(@as(usize, 12), n); - try testing.expectEqualSlices(u8, "Hello world!", buf[0..n]); -} - -fn generateFileName(base_name: []const u8) ![]const u8 { - const random_bytes_count = 12; - const sub_path_len = comptime std.fs.base64_encoder.calcSize(random_bytes_count); - var random_bytes: [12]u8 = undefined; - std.crypto.random.bytes(&random_bytes); - var sub_path: [sub_path_len]u8 = undefined; - _ = std.fs.base64_encoder.encode(&sub_path, &random_bytes); - return std.fmt.allocPrint(testing.allocator, "{s}-{s}", .{ sub_path[0..], base_name }); -} - -test "non-blocking tcp server" { - if (builtin.os.tag == .wasi) return error.SkipZigTest; - if (true) { - // https://github.com/ziglang/zig/issues/18315 - return error.SkipZigTest; - } - - const localhost = try net.Address.parseIp("127.0.0.1", 0); - var server = localhost.listen(.{ .force_nonblocking = true }); - defer server.deinit(); - - const accept_err = server.accept(); - try testing.expectError(error.WouldBlock, accept_err); - - const socket_file = try net.tcpConnectToAddress(server.listen_address); - defer socket_file.close(); - - var client = try server.accept(); - defer client.stream.close(); - const stream = client.stream.writer(); - try stream.print("hello from server\n", .{}); - - var buf: [100]u8 = undefined; - const len = try socket_file.read(&buf); - const msg = buf[0..len]; - try testing.expect(mem.eql(u8, msg, "hello from server\n")); -} diff --git a/src/http/async/tls.zig/PrivateKey.zig b/src/http/async/tls.zig/PrivateKey.zig deleted file mode 100644 index 0e2b944d..00000000 --- a/src/http/async/tls.zig/PrivateKey.zig +++ /dev/null @@ -1,260 +0,0 @@ -const std = @import("std"); -const Allocator = std.mem.Allocator; -const Certificate = std.crypto.Certificate; -const der = Certificate.der; -const rsa = @import("rsa/rsa.zig"); -const base64 = std.base64.standard.decoderWithIgnore(" \t\r\n"); -const proto = @import("protocol.zig"); - -const max_ecdsa_key_len = 66; - -signature_scheme: proto.SignatureScheme, - -key: union { - rsa: rsa.KeyPair, - ecdsa: [max_ecdsa_key_len]u8, -}, - -const PrivateKey = @This(); - -pub fn fromFile(gpa: Allocator, file: std.fs.File) !PrivateKey { - const buf = try file.readToEndAlloc(gpa, 1024 * 1024); - defer gpa.free(buf); - return try parsePem(buf); -} - -pub fn parsePem(buf: []const u8) !PrivateKey { - const key_start, const key_end, const marker_version = try findKey(buf); - const encoded = std.mem.trim(u8, buf[key_start..key_end], " \t\r\n"); - - // required bytes: - // 2412, 1821, 1236 for rsa 4096, 3072, 2048 bits size keys - var decoded: [4096]u8 = undefined; - const n = try base64.decode(&decoded, encoded); - - if (marker_version == 2) { - return try parseEcDer(decoded[0..n]); - } - return try parseDer(decoded[0..n]); -} - -fn findKey(buf: []const u8) !struct { usize, usize, usize } { - const markers = [_]struct { - begin: []const u8, - end: []const u8, - }{ - .{ .begin = "-----BEGIN PRIVATE KEY-----", .end = "-----END PRIVATE KEY-----" }, - .{ .begin = "-----BEGIN EC PRIVATE KEY-----", .end = "-----END EC PRIVATE KEY-----" }, - }; - - for (markers, 1..) |marker, ver| { - const begin_marker_start = std.mem.indexOfPos(u8, buf, 0, marker.begin) orelse continue; - const key_start = begin_marker_start + marker.begin.len; - const key_end = std.mem.indexOfPos(u8, buf, key_start, marker.end) orelse continue; - - return .{ key_start, key_end, ver }; - } - return error.MissingEndMarker; -} - -// ref: https://asn1js.eu/#MIG2AgEAMBAGByqGSM49AgEGBSuBBAAiBIGeMIGbAgEBBDBKFkVJCtU9FR6egz3yNxKBwXd86cFzMYqyGb8hRc1zVvLdw-So_2FBtITp6jzYmFShZANiAAQ-CH3a1R0V6dFlTK8Rs4M4egrpPtdta0osysO0Zl8mkBiDsTlvJNqeAp7L2ItHgFW8k_CfhgQT6iLDacNMhKC4XOV07r_ePD-mmkvqvRmzfOowHUoVRhCKrOTmF_J9Syc -pub fn parseDer(buf: []const u8) !PrivateKey { - const info = try der.Element.parse(buf, 0); - const version = try der.Element.parse(buf, info.slice.start); - - const algo_seq = try der.Element.parse(buf, version.slice.end); - const algo_cat = try der.Element.parse(buf, algo_seq.slice.start); - - const key_str = try der.Element.parse(buf, algo_seq.slice.end); - const key_seq = try der.Element.parse(buf, key_str.slice.start); - const key_int = try der.Element.parse(buf, key_seq.slice.start); - - const category = try Certificate.parseAlgorithmCategory(buf, algo_cat); - switch (category) { - .rsaEncryption => { - const modulus = try der.Element.parse(buf, key_int.slice.end); - const public_exponent = try der.Element.parse(buf, modulus.slice.end); - const private_exponent = try der.Element.parse(buf, public_exponent.slice.end); - - const public_key = try rsa.PublicKey.fromBytes(content(buf, modulus), content(buf, public_exponent)); - const secret_key = try rsa.SecretKey.fromBytes(public_key.modulus, content(buf, private_exponent)); - const key_pair = rsa.KeyPair{ .public = public_key, .secret = secret_key }; - - return .{ - .signature_scheme = switch (key_pair.public.modulus.bits()) { - 4096 => .rsa_pss_rsae_sha512, - 3072 => .rsa_pss_rsae_sha384, - else => .rsa_pss_rsae_sha256, - }, - .key = .{ .rsa = key_pair }, - }; - }, - .X9_62_id_ecPublicKey => { - const key = try der.Element.parse(buf, key_int.slice.end); - const algo_param = try der.Element.parse(buf, algo_cat.slice.end); - const named_curve = try Certificate.parseNamedCurve(buf, algo_param); - return .{ - .signature_scheme = signatureScheme(named_curve), - .key = .{ .ecdsa = ecdsaKey(buf, key) }, - }; - }, - else => unreachable, - } -} - -// References: -// https://asn1js.eu/#MHcCAQEEINJSRKv8kSKEzLHptfAlg-LGh4_pHHlq0XLf30Q9pcztoAoGCCqGSM49AwEHoUQDQgAEJpmLyp8aGCgyMcFIJaIq_-4V1K6nPpeoih3bT2npeplF9eyXj7rm8eW9Ua6VLhq71mqtMC-YLm-IkORBVq1cuA -// https://www.rfc-editor.org/rfc/rfc5915 -pub fn parseEcDer(bytes: []const u8) !PrivateKey { - const pki_msg = try der.Element.parse(bytes, 0); - const version = try der.Element.parse(bytes, pki_msg.slice.start); - const key = try der.Element.parse(bytes, version.slice.end); - const parameters = try der.Element.parse(bytes, key.slice.end); - const curve = try der.Element.parse(bytes, parameters.slice.start); - const named_curve = try Certificate.parseNamedCurve(bytes, curve); - return .{ - .signature_scheme = signatureScheme(named_curve), - .key = .{ .ecdsa = ecdsaKey(bytes, key) }, - }; -} - -fn signatureScheme(named_curve: Certificate.NamedCurve) proto.SignatureScheme { - return switch (named_curve) { - .X9_62_prime256v1 => .ecdsa_secp256r1_sha256, - .secp384r1 => .ecdsa_secp384r1_sha384, - .secp521r1 => .ecdsa_secp521r1_sha512, - }; -} - -fn ecdsaKey(bytes: []const u8, e: der.Element) [max_ecdsa_key_len]u8 { - const data = content(bytes, e); - var ecdsa_key: [max_ecdsa_key_len]u8 = undefined; - @memcpy(ecdsa_key[0..data.len], data); - return ecdsa_key; -} - -fn content(bytes: []const u8, e: der.Element) []const u8 { - return bytes[e.slice.start..e.slice.end]; -} - -const testing = std.testing; -const testu = @import("testu.zig"); - -test "parse ec pem" { - const data = @embedFile("testdata/ec_private_key.pem"); - var pk = try parsePem(data); - const priv_key = &testu.hexToBytes( - \\ 10 35 3d ca 1b 15 1d 06 aa 71 b8 ef f3 19 22 - \\ 43 78 f3 20 98 1e b1 2f 2b 64 7e 71 d0 30 2a - \\ 90 aa e5 eb 99 c3 90 65 3d c1 26 19 be 3f 08 - \\ 20 9b 01 - ); - try testing.expectEqualSlices(u8, priv_key, pk.key.ecdsa[0..priv_key.len]); - try testing.expectEqual(.ecdsa_secp384r1_sha384, pk.signature_scheme); -} - -test "parse ec prime256v1" { - const data = @embedFile("testdata/ec_prime256v1_private_key.pem"); - var pk = try parsePem(data); - const priv_key = &testu.hexToBytes( - \\ d2 52 44 ab fc 91 22 84 cc b1 e9 b5 f0 25 83 - \\ e2 c6 87 8f e9 1c 79 6a d1 72 df df 44 3d a5 - \\ cc ed - ); - try testing.expectEqualSlices(u8, priv_key, pk.key.ecdsa[0..priv_key.len]); - try testing.expectEqual(.ecdsa_secp256r1_sha256, pk.signature_scheme); -} - -test "parse ec secp384r1" { - const data = @embedFile("testdata/ec_secp384r1_private_key.pem"); - var pk = try parsePem(data); - const priv_key = &testu.hexToBytes( - \\ ee 6d 8a 5e 0d d3 b0 c6 4b 32 40 80 e2 3a de - \\ 8b 1e dd e2 92 db 36 1c db 91 ea ba a1 06 0d - \\ 42 2d d9 a9 dc 05 43 29 f1 78 7c f9 08 af c5 - \\ 03 1f 6d - ); - try testing.expectEqualSlices(u8, priv_key, pk.key.ecdsa[0..priv_key.len]); - try testing.expectEqual(.ecdsa_secp384r1_sha384, pk.signature_scheme); -} - -test "parse ec secp521r1" { - const data = @embedFile("testdata/ec_secp521r1_private_key.pem"); - var pk = try parsePem(data); - const priv_key = &testu.hexToBytes( - \\ 01 f0 2f 5a c7 24 18 ea 68 23 8c 2e a1 b4 b8 - \\ dc f2 11 b2 96 b0 ec 87 80 42 bf de ba f4 96 - \\ 83 8f 9b db c6 60 a7 4c d9 60 3a e4 ba 0b df - \\ ae 24 d3 1b c2 6e 82 a0 88 c1 ed 17 20 0d 3a - \\ f1 c5 7e e8 0b 27 - ); - try testing.expectEqualSlices(u8, priv_key, pk.key.ecdsa[0..priv_key.len]); - try testing.expectEqual(.ecdsa_secp521r1_sha512, pk.signature_scheme); -} - -test "parse rsa pem" { - const data = @embedFile("testdata/rsa_private_key.pem"); - const pk = try parsePem(data); - - // expected results from: - // $ openssl pkey -in testdata/rsa_private_key.pem -text -noout - const modulus = &testu.hexToBytes( - \\ 00 de f7 23 e6 75 cc 6f dd d5 6e 0f 8c 09 f8 - \\ 62 e3 60 1b c0 7d 8c d5 04 50 2c 36 e2 3b f7 - \\ 33 9f a1 14 af be cf 1a 0f 4c f5 cb 39 70 0e - \\ 3b 97 d6 21 f7 48 91 79 ca 7c 68 fc ea 62 a1 - \\ 5a 72 4f 78 57 0e cc f2 a3 50 05 f1 4c ca 51 - \\ 73 10 9a 18 8e 71 f5 b4 c7 3e be 4c ef 37 d4 - \\ 84 4b 82 1c ec 08 a3 cc 07 3d 5c 0b e5 85 3f - \\ fe b6 44 77 8f 3c 6a 2f 33 c3 5d f6 f2 29 46 - \\ 04 25 7e 05 d9 f8 3b 2d a4 40 66 9f 0d 6d 1a - \\ fa bc 0a c5 8b 86 43 30 ef 14 20 41 9d b5 cc - \\ 3e 63 b5 48 04 27 c9 5c d3 62 28 5f f5 b6 e4 - \\ 77 49 99 ac 84 4a a6 67 a5 9a 1a 37 c7 60 4c - \\ ba c1 70 cf 57 64 4a 21 ea 05 53 10 ec 94 71 - \\ 4a 43 04 83 00 aa 5a 28 bc f2 8c 58 14 92 d2 - \\ 83 17 f4 7b 29 0f e7 87 a2 47 b2 53 19 12 23 - \\ fb 4b ce 5a f8 a1 84 f9 b1 f3 bf e3 fa 10 f8 - \\ ad af 87 ce 03 0e a0 2c 13 71 57 c4 55 44 48 - \\ 44 cb - ); - const public_exponent = &testu.hexToBytes("01 00 01"); - const private_exponent = &testu.hexToBytes( - \\ 50 3b 80 98 aa a5 11 50 33 40 32 aa 02 e0 75 - \\ bd 3a 55 62 34 0b 9c 8f bb c5 dd 4e 15 a4 03 - \\ d8 9a 5f 56 4a 84 3d ed 69 95 3d 37 03 02 ac - \\ 21 1c 36 06 c4 ff 4c 63 37 d7 93 c3 48 10 a5 - \\ fa 62 6c 7c 6f 60 02 a4 0f e4 c3 8b 0d 76 b7 - \\ c0 2e a3 4d 86 e6 92 d1 eb db 10 d6 38 31 ea - \\ 15 3d d1 e8 81 c7 67 60 e7 8c 9a df 51 ce d0 - \\ 7a 88 32 b9 c1 54 b8 7d 98 fc d4 23 1a 05 0e - \\ f2 ea e1 72 29 28 2a 68 b7 90 18 80 1c 21 d6 - \\ 36 a8 6b 4a 9c dd 14 b8 9f 85 ee 95 0b f4 c6 - \\ 17 02 aa 4d ea 4d f9 39 d7 dd 9d b4 1d d2 f8 - \\ 92 46 0f 18 41 80 f4 ea 27 55 29 f8 37 59 bf - \\ 43 ec a3 eb 19 ba bc 13 06 95 3d 25 4b c9 72 - \\ cf 41 0a 6f aa cb 79 d4 7b fa b1 09 7c e2 2f - \\ 85 51 44 8b c6 97 8e 46 f9 6b ac 08 87 92 ce - \\ af 0b bf 8c bd 27 51 8f 09 e4 d3 f9 04 ac fa - \\ f2 04 70 3e d9 a6 28 17 c2 2d 74 e9 25 40 02 - \\ 49 - ); - - try testing.expectEqual(.rsa_pss_rsae_sha256, pk.signature_scheme); - const kp = pk.key.rsa; - { - var bytes: [modulus.len]u8 = undefined; - try kp.public.modulus.toBytes(&bytes, .big); - try testing.expectEqualSlices(u8, modulus, &bytes); - } - { - var bytes: [private_exponent.len]u8 = undefined; - try kp.public.public_exponent.toBytes(&bytes, .big); - try testing.expectEqualSlices(u8, public_exponent, bytes[bytes.len - public_exponent.len .. bytes.len]); - } - { - var btytes: [private_exponent.len]u8 = undefined; - try kp.secret.private_exponent.toBytes(&btytes, .big); - try testing.expectEqualSlices(u8, private_exponent, &btytes); - } -} diff --git a/src/http/async/tls.zig/cbc/main.zig b/src/http/async/tls.zig/cbc/main.zig deleted file mode 100644 index 25038445..00000000 --- a/src/http/async/tls.zig/cbc/main.zig +++ /dev/null @@ -1,148 +0,0 @@ -// This file is originally copied from: https://github.com/jedisct1/zig-cbc. -// -// It is modified then to have TLS padding insead of PKCS#7 padding. -// Reference: -// https://datatracker.ietf.org/doc/html/rfc5246/#section-6.2.3.2 -// https://crypto.stackexchange.com/questions/98917/on-the-correctness-of-the-padding-example-of-rfc-5246 -// -// If required padding i n bytes -// PKCS#7 padding is (n...n) -// TLS padding is (n-1...n-1) - n times of n-1 value -// -const std = @import("std"); -const aes = std.crypto.core.aes; -const mem = std.mem; -const debug = std.debug; - -/// CBC mode with TLS 1.2 padding -/// -/// Important: the counter mode doesn't provide authenticated encryption: the ciphertext can be trivially modified without this being detected. -/// If you need authenticated encryption, use anything from `std.crypto.aead` instead. -/// If you really need to use CBC mode, make sure to use a MAC to authenticate the ciphertext. -pub fn CBC(comptime BlockCipher: anytype) type { - const EncryptCtx = aes.AesEncryptCtx(BlockCipher); - const DecryptCtx = aes.AesDecryptCtx(BlockCipher); - - return struct { - const Self = @This(); - - enc_ctx: EncryptCtx, - dec_ctx: DecryptCtx, - - /// Initialize the CBC context with the given key. - pub fn init(key: [BlockCipher.key_bits / 8]u8) Self { - const enc_ctx = BlockCipher.initEnc(key); - const dec_ctx = DecryptCtx.initFromEnc(enc_ctx); - - return Self{ .enc_ctx = enc_ctx, .dec_ctx = dec_ctx }; - } - - /// Return the length of the ciphertext given the length of the plaintext. - pub fn paddedLength(length: usize) usize { - return (std.math.divCeil(usize, length + 1, EncryptCtx.block_length) catch unreachable) * EncryptCtx.block_length; - } - - /// Encrypt the given plaintext for the given IV. - /// The destination buffer must be large enough to hold the padded plaintext. - /// Use the `paddedLength()` function to compute the ciphertext size. - /// IV must be secret and unpredictable. - pub fn encrypt(self: Self, dst: []u8, src: []const u8, iv: [EncryptCtx.block_length]u8) void { - // Note: encryption *could* be parallelized, see https://research.kudelskisecurity.com/2022/11/17/some-aes-cbc-encryption-myth-busting/ - const block_length = EncryptCtx.block_length; - const padded_length = paddedLength(src.len); - debug.assert(dst.len == padded_length); // destination buffer must hold the padded plaintext - var cv = iv; - var i: usize = 0; - while (i + block_length <= src.len) : (i += block_length) { - const in = src[i..][0..block_length]; - for (cv[0..], in) |*x, y| x.* ^= y; - self.enc_ctx.encrypt(&cv, &cv); - @memcpy(dst[i..][0..block_length], &cv); - } - // Last block - var in = [_]u8{0} ** block_length; - const padding_length: u8 = @intCast(padded_length - src.len - 1); - @memset(&in, padding_length); - @memcpy(in[0 .. src.len - i], src[i..]); - for (cv[0..], in) |*x, y| x.* ^= y; - self.enc_ctx.encrypt(&cv, &cv); - @memcpy(dst[i..], cv[0 .. dst.len - i]); - } - - /// Decrypt the given ciphertext for the given IV. - /// The destination buffer must be large enough to hold the plaintext. - /// IV must be secret, unpredictable and match the one used for encryption. - pub fn decrypt(self: Self, dst: []u8, src: []const u8, iv: [DecryptCtx.block_length]u8) !void { - const block_length = DecryptCtx.block_length; - if (src.len != dst.len) { - return error.EncodingError; - } - debug.assert(src.len % block_length == 0); - var i: usize = 0; - var cv = iv; - var out: [block_length]u8 = undefined; - // Decryption could be parallelized - while (i + block_length <= dst.len) : (i += block_length) { - const in = src[i..][0..block_length]; - self.dec_ctx.decrypt(&out, in); - for (&out, cv) |*x, y| x.* ^= y; - cv = in.*; - @memcpy(dst[i..][0..block_length], &out); - } - // Last block - We intentionally don't check the padding to mitigate timing attacks - if (i < dst.len) { - const in = src[i..][0..block_length]; - @memset(&out, 0); - self.dec_ctx.decrypt(&out, in); - for (&out, cv) |*x, y| x.* ^= y; - @memcpy(dst[i..], out[0 .. dst.len - i]); - } - } - }; -} - -test "CBC mode" { - const M = CBC(aes.Aes128); - const key = [_]u8{ 0x2b, 0x7e, 0x15, 0x16, 0x28, 0xae, 0xd2, 0xa6, 0xab, 0xf7, 0x15, 0x88, 0x09, 0xcf, 0x4f, 0x3c }; - const iv = [_]u8{ 0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0a, 0x0b, 0x0c, 0x0d, 0x0e, 0x0f }; - const src_ = "This is a test of AES-CBC that goes on longer than a couple blocks. It is a somewhat long test case to type out!"; - const expected = "\xA0\x8C\x09\x7D\xFF\x42\xB6\x65\x4D\x4B\xC6\x90\x90\x39\xDE\x3D\xC7\xCA\xEB\xF6\x9A\x4F\x09\x97\xC9\x32\xAB\x75\x88\xB7\x57\x17"; - var res: [32]u8 = undefined; - - try comptime std.testing.expect(src_.len / M.paddedLength(1) >= 3); // Ensure that we have at least 3 blocks - - const z = M.init(key); - - // Test encryption and decryption with distinct buffers - var h = std.crypto.hash.sha2.Sha256.init(.{}); - inline for (0..src_.len) |len| { - const src = src_[0..len]; - var dst = [_]u8{0} ** M.paddedLength(src.len); - - z.encrypt(&dst, src, iv); - h.update(&dst); - - var decrypted = [_]u8{0} ** dst.len; - try z.decrypt(&decrypted, &dst, iv); - - const padding = decrypted[decrypted.len - 1] + 1; - try std.testing.expectEqualSlices(u8, src, decrypted[0 .. decrypted.len - padding]); - } - h.final(&res); - try std.testing.expectEqualSlices(u8, expected, &res); - - // Test encryption and decryption with the same buffer - h = std.crypto.hash.sha2.Sha256.init(.{}); - inline for (0..src_.len) |len| { - var buf = [_]u8{0} ** M.paddedLength(len); - @memcpy(buf[0..len], src_[0..len]); - z.encrypt(&buf, buf[0..len], iv); - h.update(&buf); - - try z.decrypt(&buf, &buf, iv); - - try std.testing.expectEqualSlices(u8, src_[0..len], buf[0..len]); - } - h.final(&res); - try std.testing.expectEqualSlices(u8, expected, &res); -} diff --git a/src/http/async/tls.zig/cipher.zig b/src/http/async/tls.zig/cipher.zig deleted file mode 100644 index dbf4a07a..00000000 --- a/src/http/async/tls.zig/cipher.zig +++ /dev/null @@ -1,1004 +0,0 @@ -const std = @import("std"); -const crypto = std.crypto; -const hkdfExpandLabel = crypto.tls.hkdfExpandLabel; - -const Sha1 = crypto.hash.Sha1; -const Sha256 = crypto.hash.sha2.Sha256; -const Sha384 = crypto.hash.sha2.Sha384; - -const record = @import("record.zig"); -const Record = record.Record; -const Transcript = @import("transcript.zig").Transcript; -const proto = @import("protocol.zig"); - -// tls 1.2 cbc cipher types -const CbcAes128Sha1 = CbcType(crypto.core.aes.Aes128, Sha1); -const CbcAes128Sha256 = CbcType(crypto.core.aes.Aes128, Sha256); -const CbcAes256Sha256 = CbcType(crypto.core.aes.Aes256, Sha256); -const CbcAes256Sha384 = CbcType(crypto.core.aes.Aes256, Sha384); -// tls 1.2 gcm cipher types -const Aead12Aes128Gcm = Aead12Type(crypto.aead.aes_gcm.Aes128Gcm); -const Aead12Aes256Gcm = Aead12Type(crypto.aead.aes_gcm.Aes256Gcm); -// tls 1.2 chacha cipher type -const Aead12ChaCha = Aead12ChaChaType(crypto.aead.chacha_poly.ChaCha20Poly1305); -// tls 1.3 cipher types -const Aes128GcmSha256 = Aead13Type(crypto.aead.aes_gcm.Aes128Gcm, Sha256); -const Aes256GcmSha384 = Aead13Type(crypto.aead.aes_gcm.Aes256Gcm, Sha384); -const ChaChaSha256 = Aead13Type(crypto.aead.chacha_poly.ChaCha20Poly1305, Sha256); -const Aegis128Sha256 = Aead13Type(crypto.aead.aegis.Aegis128L, Sha256); - -pub const encrypt_overhead_tls_12: comptime_int = @max( - CbcAes128Sha1.encrypt_overhead, - CbcAes128Sha256.encrypt_overhead, - CbcAes256Sha256.encrypt_overhead, - CbcAes256Sha384.encrypt_overhead, - Aead12Aes128Gcm.encrypt_overhead, - Aead12Aes256Gcm.encrypt_overhead, - Aead12ChaCha.encrypt_overhead, -); -pub const encrypt_overhead_tls_13: comptime_int = @max( - Aes128GcmSha256.encrypt_overhead, - Aes256GcmSha384.encrypt_overhead, - ChaChaSha256.encrypt_overhead, - Aegis128Sha256.encrypt_overhead, -); - -// ref (length): https://www.rfc-editor.org/rfc/rfc8446#section-5.1 -pub const max_cleartext_len = 1 << 14; -// ref (length): https://www.rfc-editor.org/rfc/rfc8446#section-5.2 -// The sum of the lengths of the content and the padding, plus one for the inner -// content type, plus any expansion added by the AEAD algorithm. -pub const max_ciphertext_len = max_cleartext_len + 256; -pub const max_ciphertext_record_len = record.header_len + max_ciphertext_len; - -/// Returns type for cipher suite tag. -fn CipherType(comptime tag: CipherSuite) type { - return switch (tag) { - // tls 1.2 cbc - .ECDHE_ECDSA_WITH_AES_128_CBC_SHA, - .ECDHE_RSA_WITH_AES_128_CBC_SHA, - .RSA_WITH_AES_128_CBC_SHA, - => CbcAes128Sha1, - .ECDHE_ECDSA_WITH_AES_128_CBC_SHA256, - .ECDHE_RSA_WITH_AES_128_CBC_SHA256, - .RSA_WITH_AES_128_CBC_SHA256, - => CbcAes128Sha256, - .ECDHE_ECDSA_WITH_AES_256_CBC_SHA384, - .ECDHE_RSA_WITH_AES_256_CBC_SHA384, - => CbcAes256Sha384, - - // tls 1.2 gcm - .ECDHE_ECDSA_WITH_AES_128_GCM_SHA256, - .ECDHE_RSA_WITH_AES_128_GCM_SHA256, - => Aead12Aes128Gcm, - .ECDHE_ECDSA_WITH_AES_256_GCM_SHA384, - .ECDHE_RSA_WITH_AES_256_GCM_SHA384, - => Aead12Aes256Gcm, - - // tls 1.2 chacha - .ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256, - .ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256, - => Aead12ChaCha, - - // tls 1.3 - .AES_128_GCM_SHA256 => Aes128GcmSha256, - .AES_256_GCM_SHA384 => Aes256GcmSha384, - .CHACHA20_POLY1305_SHA256 => ChaChaSha256, - .AEGIS_128L_SHA256 => Aegis128Sha256, - - else => unreachable, - }; -} - -/// Provides initialization and common encrypt/decrypt methods for all supported -/// ciphers. Tls 1.2 has only application cipher, tls 1.3 has separate cipher -/// for handshake and application. -pub const Cipher = union(CipherSuite) { - // tls 1.2 cbc - ECDHE_ECDSA_WITH_AES_128_CBC_SHA: CipherType(.ECDHE_ECDSA_WITH_AES_128_CBC_SHA), - ECDHE_RSA_WITH_AES_128_CBC_SHA: CipherType(.ECDHE_RSA_WITH_AES_128_CBC_SHA), - RSA_WITH_AES_128_CBC_SHA: CipherType(.RSA_WITH_AES_128_CBC_SHA), - - ECDHE_ECDSA_WITH_AES_128_CBC_SHA256: CipherType(.ECDHE_ECDSA_WITH_AES_128_CBC_SHA256), - ECDHE_RSA_WITH_AES_128_CBC_SHA256: CipherType(.ECDHE_RSA_WITH_AES_128_CBC_SHA256), - RSA_WITH_AES_128_CBC_SHA256: CipherType(.RSA_WITH_AES_128_CBC_SHA256), - - ECDHE_ECDSA_WITH_AES_256_CBC_SHA384: CipherType(.ECDHE_ECDSA_WITH_AES_256_CBC_SHA384), - ECDHE_RSA_WITH_AES_256_CBC_SHA384: CipherType(.ECDHE_RSA_WITH_AES_256_CBC_SHA384), - // tls 1.2 gcm - ECDHE_ECDSA_WITH_AES_128_GCM_SHA256: CipherType(.ECDHE_ECDSA_WITH_AES_128_GCM_SHA256), - ECDHE_ECDSA_WITH_AES_256_GCM_SHA384: CipherType(.ECDHE_RSA_WITH_AES_256_GCM_SHA384), - ECDHE_RSA_WITH_AES_128_GCM_SHA256: CipherType(.ECDHE_RSA_WITH_AES_128_GCM_SHA256), - ECDHE_RSA_WITH_AES_256_GCM_SHA384: CipherType(.ECDHE_RSA_WITH_AES_256_GCM_SHA384), - // tls 1.2 chacha - ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256: CipherType(.ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256), - ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256: CipherType(.ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256), - // tls 1.3 - AES_128_GCM_SHA256: CipherType(.AES_128_GCM_SHA256), - AES_256_GCM_SHA384: CipherType(.AES_256_GCM_SHA384), - CHACHA20_POLY1305_SHA256: CipherType(.CHACHA20_POLY1305_SHA256), - AEGIS_128L_SHA256: CipherType(.AEGIS_128L_SHA256), - - // tls 1.2 application cipher - pub fn initTls12(tag: CipherSuite, key_material: []const u8, side: proto.Side) !Cipher { - switch (tag) { - inline .ECDHE_ECDSA_WITH_AES_128_CBC_SHA, - .ECDHE_RSA_WITH_AES_128_CBC_SHA, - .RSA_WITH_AES_128_CBC_SHA, - .ECDHE_ECDSA_WITH_AES_128_CBC_SHA256, - .ECDHE_RSA_WITH_AES_128_CBC_SHA256, - .RSA_WITH_AES_128_CBC_SHA256, - .ECDHE_ECDSA_WITH_AES_256_CBC_SHA384, - .ECDHE_RSA_WITH_AES_256_CBC_SHA384, - .ECDHE_ECDSA_WITH_AES_128_GCM_SHA256, - .ECDHE_ECDSA_WITH_AES_256_GCM_SHA384, - .ECDHE_RSA_WITH_AES_128_GCM_SHA256, - .ECDHE_RSA_WITH_AES_256_GCM_SHA384, - .ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256, - .ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256, - => |comptime_tag| { - return @unionInit(Cipher, @tagName(comptime_tag), CipherType(comptime_tag).init(key_material, side)); - }, - else => return error.TlsIllegalParameter, - } - } - - // tls 1.3 handshake or application cipher - pub fn initTls13(tag: CipherSuite, secret: Transcript.Secret, side: proto.Side) !Cipher { - return switch (tag) { - inline .AES_128_GCM_SHA256, - .AES_256_GCM_SHA384, - .CHACHA20_POLY1305_SHA256, - .AEGIS_128L_SHA256, - => |comptime_tag| { - return @unionInit(Cipher, @tagName(comptime_tag), CipherType(comptime_tag).init(secret, side)); - }, - else => return error.TlsIllegalParameter, - }; - } - - pub fn encrypt( - c: *Cipher, - buf: []u8, - content_type: proto.ContentType, - cleartext: []const u8, - ) ![]const u8 { - return switch (c.*) { - inline else => |*f| try f.encrypt(buf, content_type, cleartext), - }; - } - - pub fn decrypt( - c: *Cipher, - buf: []u8, - rec: Record, - ) !struct { proto.ContentType, []u8 } { - return switch (c.*) { - inline else => |*f| { - const content_type, const cleartext = try f.decrypt(buf, rec); - if (cleartext.len > max_cleartext_len) return error.TlsRecordOverflow; - return .{ content_type, cleartext }; - }, - }; - } - - pub fn encryptSeq(c: Cipher) u64 { - return switch (c) { - inline else => |f| f.encrypt_seq, - }; - } - - pub fn keyUpdateEncrypt(c: *Cipher) !void { - return switch (c.*) { - inline .AES_128_GCM_SHA256, - .AES_256_GCM_SHA384, - .CHACHA20_POLY1305_SHA256, - .AEGIS_128L_SHA256, - => |*f| f.keyUpdateEncrypt(), - // can't happen on tls 1.2 - else => return error.TlsUnexpectedMessage, - }; - } - pub fn keyUpdateDecrypt(c: *Cipher) !void { - return switch (c.*) { - inline .AES_128_GCM_SHA256, - .AES_256_GCM_SHA384, - .CHACHA20_POLY1305_SHA256, - .AEGIS_128L_SHA256, - => |*f| f.keyUpdateDecrypt(), - // can't happen on tls 1.2 - else => return error.TlsUnexpectedMessage, - }; - } -}; - -fn Aead12Type(comptime AeadType: type) type { - return struct { - const explicit_iv_len = 8; - const key_len = AeadType.key_length; - const auth_tag_len = AeadType.tag_length; - const nonce_len = AeadType.nonce_length; - const iv_len = AeadType.nonce_length - explicit_iv_len; - const encrypt_overhead = record.header_len + explicit_iv_len + auth_tag_len; - - encrypt_key: [key_len]u8, - decrypt_key: [key_len]u8, - encrypt_iv: [iv_len]u8, - decrypt_iv: [iv_len]u8, - encrypt_seq: u64 = 0, - decrypt_seq: u64 = 0, - rnd: std.Random = crypto.random, - - const Self = @This(); - - fn init(key_material: []const u8, side: proto.Side) Self { - const client_key = key_material[0..key_len].*; - const server_key = key_material[key_len..][0..key_len].*; - const client_iv = key_material[2 * key_len ..][0..iv_len].*; - const server_iv = key_material[2 * key_len + iv_len ..][0..iv_len].*; - return .{ - .encrypt_key = if (side == .client) client_key else server_key, - .decrypt_key = if (side == .client) server_key else client_key, - .encrypt_iv = if (side == .client) client_iv else server_iv, - .decrypt_iv = if (side == .client) server_iv else client_iv, - }; - } - - /// Returns encrypted tls record in format: - /// ----------------- buf ---------------------- - /// header | explicit_iv | ciphertext | auth_tag - /// - /// tls record header: 5 bytes - /// explicit_iv: 8 bytes - /// ciphertext: same length as cleartext - /// auth_tag: 16 bytes - pub fn encrypt( - self: *Self, - buf: []u8, - content_type: proto.ContentType, - cleartext: []const u8, - ) ![]const u8 { - const record_len = record.header_len + explicit_iv_len + cleartext.len + auth_tag_len; - if (buf.len < record_len) return error.BufferOverflow; - - const header = buf[0..record.header_len]; - const explicit_iv = buf[record.header_len..][0..explicit_iv_len]; - const ciphertext = buf[record.header_len + explicit_iv_len ..][0..cleartext.len]; - const auth_tag = buf[record.header_len + explicit_iv_len + cleartext.len ..][0..auth_tag_len]; - - header.* = record.header(content_type, explicit_iv_len + cleartext.len + auth_tag_len); - self.rnd.bytes(explicit_iv); - const iv = self.encrypt_iv ++ explicit_iv.*; - const ad = additionalData(self.encrypt_seq, content_type, cleartext.len); - AeadType.encrypt(ciphertext, auth_tag, cleartext, &ad, iv, self.encrypt_key); - self.encrypt_seq +%= 1; - - return buf[0..record_len]; - } - - /// Decrypts payload into cleartext. Returns tls record content type and - /// cleartext. - /// Accepts tls record header and payload: - /// header | ----------- payload --------------- - /// header | explicit_iv | ciphertext | auth_tag - pub fn decrypt( - self: *Self, - buf: []u8, - rec: Record, - ) !struct { proto.ContentType, []u8 } { - const overhead = explicit_iv_len + auth_tag_len; - if (rec.payload.len < overhead) return error.TlsDecryptError; - const cleartext_len = rec.payload.len - overhead; - if (buf.len < cleartext_len) return error.BufferOverflow; - - const explicit_iv = rec.payload[0..explicit_iv_len]; - const ciphertext = rec.payload[explicit_iv_len..][0..cleartext_len]; - const auth_tag = rec.payload[explicit_iv_len + cleartext_len ..][0..auth_tag_len]; - - const iv = self.decrypt_iv ++ explicit_iv.*; - const ad = additionalData(self.decrypt_seq, rec.content_type, cleartext_len); - const cleartext = buf[0..cleartext_len]; - AeadType.decrypt(cleartext, ciphertext, auth_tag.*, &ad, iv, self.decrypt_key) catch return error.TlsDecryptError; - self.decrypt_seq +%= 1; - return .{ rec.content_type, cleartext }; - } - }; -} - -fn Aead12ChaChaType(comptime AeadType: type) type { - return struct { - const key_len = AeadType.key_length; - const auth_tag_len = AeadType.tag_length; - const nonce_len = AeadType.nonce_length; - const encrypt_overhead = record.header_len + auth_tag_len; - - encrypt_key: [key_len]u8, - decrypt_key: [key_len]u8, - encrypt_iv: [nonce_len]u8, - decrypt_iv: [nonce_len]u8, - encrypt_seq: u64 = 0, - decrypt_seq: u64 = 0, - - const Self = @This(); - - fn init(key_material: []const u8, side: proto.Side) Self { - const client_key = key_material[0..key_len].*; - const server_key = key_material[key_len..][0..key_len].*; - const client_iv = key_material[2 * key_len ..][0..nonce_len].*; - const server_iv = key_material[2 * key_len + nonce_len ..][0..nonce_len].*; - return .{ - .encrypt_key = if (side == .client) client_key else server_key, - .decrypt_key = if (side == .client) server_key else client_key, - .encrypt_iv = if (side == .client) client_iv else server_iv, - .decrypt_iv = if (side == .client) server_iv else client_iv, - }; - } - - /// Returns encrypted tls record in format: - /// ------------ buf ------------- - /// header | ciphertext | auth_tag - /// - /// tls record header: 5 bytes - /// ciphertext: same length as cleartext - /// auth_tag: 16 bytes - pub fn encrypt( - self: *Self, - buf: []u8, - content_type: proto.ContentType, - cleartext: []const u8, - ) ![]const u8 { - const record_len = record.header_len + cleartext.len + auth_tag_len; - if (buf.len < record_len) return error.BufferOverflow; - - const ciphertext = buf[record.header_len..][0..cleartext.len]; - const auth_tag = buf[record.header_len + ciphertext.len ..][0..auth_tag_len]; - - const ad = additionalData(self.encrypt_seq, content_type, cleartext.len); - const iv = ivWithSeq(nonce_len, self.encrypt_iv, self.encrypt_seq); - AeadType.encrypt(ciphertext, auth_tag, cleartext, &ad, iv, self.encrypt_key); - self.encrypt_seq +%= 1; - - buf[0..record.header_len].* = record.header(content_type, ciphertext.len + auth_tag.len); - return buf[0..record_len]; - } - - /// Decrypts payload into cleartext. Returns tls record content type and - /// cleartext. - /// Accepts tls record header and payload: - /// header | ----- payload ------- - /// header | ciphertext | auth_tag - pub fn decrypt( - self: *Self, - buf: []u8, - rec: Record, - ) !struct { proto.ContentType, []u8 } { - const overhead = auth_tag_len; - if (rec.payload.len < overhead) return error.TlsDecryptError; - const cleartext_len = rec.payload.len - overhead; - if (buf.len < cleartext_len) return error.BufferOverflow; - - const ciphertext = rec.payload[0..cleartext_len]; - const auth_tag = rec.payload[cleartext_len..][0..auth_tag_len]; - const cleartext = buf[0..cleartext_len]; - - const ad = additionalData(self.decrypt_seq, rec.content_type, cleartext_len); - const iv = ivWithSeq(nonce_len, self.decrypt_iv, self.decrypt_seq); - AeadType.decrypt(cleartext, ciphertext, auth_tag.*, &ad, iv, self.decrypt_key) catch return error.TlsDecryptError; - self.decrypt_seq +%= 1; - return .{ rec.content_type, cleartext }; - } - }; -} - -fn Aead13Type(comptime AeadType: type, comptime Hash: type) type { - return struct { - const Hmac = crypto.auth.hmac.Hmac(Hash); - const Hkdf = crypto.kdf.hkdf.Hkdf(Hmac); - - const key_len = AeadType.key_length; - const auth_tag_len = AeadType.tag_length; - const nonce_len = AeadType.nonce_length; - const digest_len = Hash.digest_length; - const encrypt_overhead = record.header_len + 1 + auth_tag_len; - - encrypt_secret: [digest_len]u8, - decrypt_secret: [digest_len]u8, - encrypt_key: [key_len]u8, - decrypt_key: [key_len]u8, - encrypt_iv: [nonce_len]u8, - decrypt_iv: [nonce_len]u8, - encrypt_seq: u64 = 0, - decrypt_seq: u64 = 0, - - const Self = @This(); - - pub fn init(secret: Transcript.Secret, side: proto.Side) Self { - var self = Self{ - .encrypt_secret = if (side == .client) secret.client[0..digest_len].* else secret.server[0..digest_len].*, - .decrypt_secret = if (side == .server) secret.client[0..digest_len].* else secret.server[0..digest_len].*, - .encrypt_key = undefined, - .decrypt_key = undefined, - .encrypt_iv = undefined, - .decrypt_iv = undefined, - }; - self.keyGenerate(); - return self; - } - - fn keyGenerate(self: *Self) void { - self.encrypt_key = hkdfExpandLabel(Hkdf, self.encrypt_secret, "key", "", key_len); - self.decrypt_key = hkdfExpandLabel(Hkdf, self.decrypt_secret, "key", "", key_len); - self.encrypt_iv = hkdfExpandLabel(Hkdf, self.encrypt_secret, "iv", "", nonce_len); - self.decrypt_iv = hkdfExpandLabel(Hkdf, self.decrypt_secret, "iv", "", nonce_len); - } - - pub fn keyUpdateEncrypt(self: *Self) void { - self.encrypt_secret = hkdfExpandLabel(Hkdf, self.encrypt_secret, "traffic upd", "", digest_len); - self.encrypt_seq = 0; - self.keyGenerate(); - } - - pub fn keyUpdateDecrypt(self: *Self) void { - self.decrypt_secret = hkdfExpandLabel(Hkdf, self.decrypt_secret, "traffic upd", "", digest_len); - self.decrypt_seq = 0; - self.keyGenerate(); - } - - /// Returns encrypted tls record in format: - /// ------------ buf ------------- - /// header | ciphertext | auth_tag - /// - /// tls record header: 5 bytes - /// ciphertext: cleartext len + 1 byte content type - /// auth_tag: 16 bytes - pub fn encrypt( - self: *Self, - buf: []u8, - content_type: proto.ContentType, - cleartext: []const u8, - ) ![]const u8 { - const payload_len = cleartext.len + 1 + auth_tag_len; - const record_len = record.header_len + payload_len; - if (buf.len < record_len) return error.BufferOverflow; - - const header = buf[0..record.header_len]; - header.* = record.header(.application_data, payload_len); - - // Skip @memcpy if cleartext is already part of the buf at right position - if (&cleartext[0] != &buf[record.header_len]) { - @memcpy(buf[record.header_len..][0..cleartext.len], cleartext); - } - buf[record.header_len + cleartext.len] = @intFromEnum(content_type); - const ciphertext = buf[record.header_len..][0 .. cleartext.len + 1]; - const auth_tag = buf[record.header_len + ciphertext.len ..][0..auth_tag_len]; - - const iv = ivWithSeq(nonce_len, self.encrypt_iv, self.encrypt_seq); - AeadType.encrypt(ciphertext, auth_tag, ciphertext, header, iv, self.encrypt_key); - self.encrypt_seq +%= 1; - return buf[0..record_len]; - } - - /// Decrypts payload into cleartext. Returns tls record content type and - /// cleartext. - /// Accepts tls record header and payload: - /// header | ------- payload --------- - /// header | ciphertext | auth_tag - /// header | cleartext + ct | auth_tag - /// Ciphertext after decryption contains cleartext and content type (1 byte). - pub fn decrypt( - self: *Self, - buf: []u8, - rec: Record, - ) !struct { proto.ContentType, []u8 } { - const overhead = auth_tag_len + 1; - if (rec.payload.len < overhead) return error.TlsDecryptError; - const ciphertext_len = rec.payload.len - auth_tag_len; - if (buf.len < ciphertext_len) return error.BufferOverflow; - - const ciphertext = rec.payload[0..ciphertext_len]; - const auth_tag = rec.payload[ciphertext_len..][0..auth_tag_len]; - - const iv = ivWithSeq(nonce_len, self.decrypt_iv, self.decrypt_seq); - AeadType.decrypt(buf[0..ciphertext_len], ciphertext, auth_tag.*, rec.header, iv, self.decrypt_key) catch return error.TlsBadRecordMac; - - // Remove zero bytes padding - var content_type_idx: usize = ciphertext_len - 1; - while (buf[content_type_idx] == 0 and content_type_idx > 0) : (content_type_idx -= 1) {} - - const cleartext = buf[0..content_type_idx]; - const content_type: proto.ContentType = @enumFromInt(buf[content_type_idx]); - self.decrypt_seq +%= 1; - return .{ content_type, cleartext }; - } - }; -} - -fn CbcType(comptime BlockCipher: type, comptime HashType: type) type { - const CBC = @import("cbc/main.zig").CBC(BlockCipher); - return struct { - const mac_len = HashType.digest_length; // 20, 32, 48 bytes for sha1, sha256, sha384 - const key_len = BlockCipher.key_bits / 8; // 16, 32 for Aes128, Aes256 - const iv_len = 16; - const encrypt_overhead = record.header_len + iv_len + mac_len + max_padding; - - pub const Hmac = crypto.auth.hmac.Hmac(HashType); - const paddedLength = CBC.paddedLength; - const max_padding = 16; - - encrypt_secret: [mac_len]u8, - decrypt_secret: [mac_len]u8, - encrypt_key: [key_len]u8, - decrypt_key: [key_len]u8, - encrypt_seq: u64 = 0, - decrypt_seq: u64 = 0, - rnd: std.Random = crypto.random, - - const Self = @This(); - - fn init(key_material: []const u8, side: proto.Side) Self { - const client_secret = key_material[0..mac_len].*; - const server_secret = key_material[mac_len..][0..mac_len].*; - const client_key = key_material[2 * mac_len ..][0..key_len].*; - const server_key = key_material[2 * mac_len + key_len ..][0..key_len].*; - return .{ - .encrypt_secret = if (side == .client) client_secret else server_secret, - .decrypt_secret = if (side == .client) server_secret else client_secret, - .encrypt_key = if (side == .client) client_key else server_key, - .decrypt_key = if (side == .client) server_key else client_key, - }; - } - - /// Returns encrypted tls record in format: - /// ----------------- buf ----------------- - /// header | iv | ------ ciphertext ------- - /// header | iv | cleartext | mac | padding - /// - /// tls record header: 5 bytes - /// iv: 16 bytes - /// ciphertext: cleartext length + mac + padding - /// mac: 20, 32 or 48 (sha1, sha256, sha384) - /// padding: 1-16 bytes - /// - /// Max encrypt buf overhead = iv + mac + padding (1-16) - /// aes_128_cbc_sha => 16 + 20 + 16 = 52 - /// aes_128_cbc_sha256 => 16 + 32 + 16 = 64 - /// aes_256_cbc_sha384 => 16 + 48 + 16 = 80 - pub fn encrypt( - self: *Self, - buf: []u8, - content_type: proto.ContentType, - cleartext: []const u8, - ) ![]const u8 { - const max_record_len = record.header_len + iv_len + cleartext.len + mac_len + max_padding; - if (buf.len < max_record_len) return error.BufferOverflow; - const cleartext_idx = record.header_len + iv_len; // position of cleartext in buf - @memcpy(buf[cleartext_idx..][0..cleartext.len], cleartext); - - { // calculate mac from (ad + cleartext) - // ... | ad | cleartext | mac | ... - // | -- mac msg -- | mac | - const ad = additionalData(self.encrypt_seq, content_type, cleartext.len); - const mac_msg = buf[cleartext_idx - ad.len ..][0 .. ad.len + cleartext.len]; - @memcpy(mac_msg[0..ad.len], &ad); - const mac = buf[cleartext_idx + cleartext.len ..][0..mac_len]; - Hmac.create(mac, mac_msg, &self.encrypt_secret); - self.encrypt_seq +%= 1; - } - - // ... | cleartext | mac | ... - // ... | -- plaintext --- ... - // ... | cleartext | mac | padding - // ... | ------ ciphertext ------- - const unpadded_len = cleartext.len + mac_len; - const padded_len = paddedLength(unpadded_len); - const plaintext = buf[cleartext_idx..][0..unpadded_len]; - const ciphertext = buf[cleartext_idx..][0..padded_len]; - - // Add header and iv at the buf start - // header | iv | ... - buf[0..record.header_len].* = record.header(content_type, iv_len + ciphertext.len); - const iv = buf[record.header_len..][0..iv_len]; - self.rnd.bytes(iv); - - // encrypt plaintext into ciphertext - CBC.init(self.encrypt_key).encrypt(ciphertext, plaintext, iv[0..iv_len].*); - - // header | iv | ------ ciphertext ------- - return buf[0 .. record.header_len + iv_len + ciphertext.len]; - } - - /// Decrypts payload into cleartext. Returns tls record content type and - /// cleartext. - pub fn decrypt( - self: *Self, - buf: []u8, - rec: Record, - ) !struct { proto.ContentType, []u8 } { - if (rec.payload.len < iv_len + mac_len + 1) return error.TlsDecryptError; - - // --------- payload ------------ - // iv | ------ ciphertext ------- - // iv | cleartext | mac | padding - const iv = rec.payload[0..iv_len]; - const ciphertext = rec.payload[iv_len..]; - - if (buf.len < ciphertext.len + additional_data_len) return error.BufferOverflow; - // ---------- buf --------------- - // ad | ------ plaintext -------- - // ad | cleartext | mac | padding - const plaintext = buf[additional_data_len..][0..ciphertext.len]; - // decrypt ciphertext -> plaintext - CBC.init(self.decrypt_key).decrypt(plaintext, ciphertext, iv[0..iv_len].*) catch return error.TlsDecryptError; - - // get padding len from last padding byte - const padding_len = plaintext[plaintext.len - 1] + 1; - if (plaintext.len < mac_len + padding_len) return error.TlsDecryptError; - // split plaintext into cleartext and mac - const cleartext_len = plaintext.len - mac_len - padding_len; - const cleartext = plaintext[0..cleartext_len]; - const mac = plaintext[cleartext_len..][0..mac_len]; - - // write ad to the buf - var ad = additionalData(self.decrypt_seq, rec.content_type, cleartext_len); - @memcpy(buf[0..ad.len], &ad); - const mac_msg = buf[0 .. ad.len + cleartext_len]; - self.decrypt_seq +%= 1; - - // calculate expected mac and compare with received mac - var expected_mac: [mac_len]u8 = undefined; - Hmac.create(&expected_mac, mac_msg, &self.decrypt_secret); - if (!std.mem.eql(u8, &expected_mac, mac)) - return error.TlsBadRecordMac; - - return .{ rec.content_type, cleartext }; - } - }; -} - -// xor lower 8 iv bytes with seq -fn ivWithSeq(comptime nonce_len: usize, iv: [nonce_len]u8, seq: u64) [nonce_len]u8 { - var res = iv; - const buf = res[nonce_len - 8 ..]; - const operand = std.mem.readInt(u64, buf, .big); - std.mem.writeInt(u64, buf, operand ^ seq, .big); - return res; -} - -pub const additional_data_len = record.header_len + @sizeOf(u64); - -fn additionalData(seq: u64, content_type: proto.ContentType, payload_len: usize) [additional_data_len]u8 { - const header = record.header(content_type, payload_len); - var seq_buf: [8]u8 = undefined; - std.mem.writeInt(u64, &seq_buf, seq, .big); - return seq_buf ++ header; -} - -// Cipher suites lists. In the order of preference. -// For the preference using grades priority and rules from Go project. -// https://ciphersuite.info/page/faq/ -// https://github.com/golang/go/blob/73186ba00251b3ed8baaab36e4f5278c7681155b/src/crypto/tls/cipher_suites.go#L226 -pub const cipher_suites = struct { - const tls12_secure = if (crypto.core.aes.has_hardware_support) [_]CipherSuite{ - // recommended - .ECDHE_ECDSA_WITH_AES_128_GCM_SHA256, - .ECDHE_ECDSA_WITH_AES_256_GCM_SHA384, - .ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256, - // secure - .ECDHE_RSA_WITH_AES_128_GCM_SHA256, - .ECDHE_RSA_WITH_AES_256_GCM_SHA384, - .ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256, - } else [_]CipherSuite{ - // recommended - .ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256, - .ECDHE_ECDSA_WITH_AES_128_GCM_SHA256, - .ECDHE_ECDSA_WITH_AES_256_GCM_SHA384, - - // secure - .ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256, - .ECDHE_RSA_WITH_AES_128_GCM_SHA256, - .ECDHE_RSA_WITH_AES_256_GCM_SHA384, - }; - const tls12_week = [_]CipherSuite{ - // week - .ECDHE_ECDSA_WITH_AES_128_CBC_SHA256, - .ECDHE_ECDSA_WITH_AES_256_CBC_SHA384, - .ECDHE_ECDSA_WITH_AES_128_CBC_SHA, - .ECDHE_RSA_WITH_AES_128_CBC_SHA256, - .ECDHE_RSA_WITH_AES_256_CBC_SHA384, - .ECDHE_RSA_WITH_AES_128_CBC_SHA, - - .RSA_WITH_AES_128_CBC_SHA256, - .RSA_WITH_AES_128_CBC_SHA, - }; - pub const tls13_ = if (crypto.core.aes.has_hardware_support) [_]CipherSuite{ - .AES_128_GCM_SHA256, - .AES_256_GCM_SHA384, - .CHACHA20_POLY1305_SHA256, - // Excluded because didn't find server which supports it to test - // .AEGIS_128L_SHA256 - } else [_]CipherSuite{ - .CHACHA20_POLY1305_SHA256, - .AES_128_GCM_SHA256, - .AES_256_GCM_SHA384, - }; - - pub const tls13 = &tls13_; - pub const tls12 = &(tls12_secure ++ tls12_week); - pub const secure = &(tls13_ ++ tls12_secure); - pub const all = &(tls13_ ++ tls12_secure ++ tls12_week); - - pub fn includes(list: []const CipherSuite, cs: CipherSuite) bool { - for (list) |s| { - if (cs == s) return true; - } - return false; - } -}; - -// Week, secure, recommended grades are from https://ciphersuite.info/page/faq/ -pub const CipherSuite = enum(u16) { - // tls 1.2 cbc sha1 - ECDHE_ECDSA_WITH_AES_128_CBC_SHA = 0xc009, // week - ECDHE_RSA_WITH_AES_128_CBC_SHA = 0xc013, // week - RSA_WITH_AES_128_CBC_SHA = 0x002F, // week - // tls 1.2 cbc sha256 - ECDHE_ECDSA_WITH_AES_128_CBC_SHA256 = 0xc023, // week - ECDHE_RSA_WITH_AES_128_CBC_SHA256 = 0xc027, // week - RSA_WITH_AES_128_CBC_SHA256 = 0x003c, // week - // tls 1.2 cbc sha384 - ECDHE_ECDSA_WITH_AES_256_CBC_SHA384 = 0xc024, // week - ECDHE_RSA_WITH_AES_256_CBC_SHA384 = 0xc028, // week - // tls 1.2 gcm - ECDHE_ECDSA_WITH_AES_128_GCM_SHA256 = 0xc02b, // recommended - ECDHE_ECDSA_WITH_AES_256_GCM_SHA384 = 0xc02c, // recommended - ECDHE_RSA_WITH_AES_128_GCM_SHA256 = 0xc02f, // secure - ECDHE_RSA_WITH_AES_256_GCM_SHA384 = 0xc030, // secure - // tls 1.2 chacha - ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256 = 0xcca9, // recommended - ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256 = 0xcca8, // secure - // tls 1.3 (all are recommended) - AES_128_GCM_SHA256 = 0x1301, - AES_256_GCM_SHA384 = 0x1302, - CHACHA20_POLY1305_SHA256 = 0x1303, - AEGIS_128L_SHA256 = 0x1307, - // AEGIS_256_SHA512 = 0x1306, - _, - - pub fn validate(cs: CipherSuite) !void { - if (cipher_suites.includes(cipher_suites.tls12, cs)) return; - if (cipher_suites.includes(cipher_suites.tls13, cs)) return; - return error.TlsIllegalParameter; - } - - pub const Versions = enum { - both, - tls_1_3, - tls_1_2, - }; - - // get tls versions from list of cipher suites - pub fn versions(list: []const CipherSuite) !Versions { - var has_12 = false; - var has_13 = false; - for (list) |cs| { - if (cipher_suites.includes(cipher_suites.tls12, cs)) { - has_12 = true; - } else { - if (cipher_suites.includes(cipher_suites.tls13, cs)) has_13 = true; - } - } - if (has_12 and has_13) return .both; - if (has_12) return .tls_1_2; - if (has_13) return .tls_1_3; - return error.TlsIllegalParameter; - } - - pub const KeyExchangeAlgorithm = enum { - ecdhe, - rsa, - }; - - pub fn keyExchange(s: CipherSuite) KeyExchangeAlgorithm { - return switch (s) { - // Random premaster secret, encrypted with publich key from certificate. - // No server key exchange message. - .RSA_WITH_AES_128_CBC_SHA, - .RSA_WITH_AES_128_CBC_SHA256, - => .rsa, - else => .ecdhe, - }; - } - - pub const HashTag = enum { - sha256, - sha384, - sha512, - }; - - pub inline fn hash(cs: CipherSuite) HashTag { - return switch (cs) { - .ECDHE_RSA_WITH_AES_256_CBC_SHA384, - .ECDHE_RSA_WITH_AES_256_GCM_SHA384, - .ECDHE_ECDSA_WITH_AES_256_CBC_SHA384, - .ECDHE_ECDSA_WITH_AES_256_GCM_SHA384, - .AES_256_GCM_SHA384, - => .sha384, - else => .sha256, - }; - } -}; - -const testing = std.testing; -const testu = @import("testu.zig"); - -test "CipherSuite validate" { - { - const cs: CipherSuite = .AES_256_GCM_SHA384; - try cs.validate(); - try testing.expectEqual(cs.hash(), .sha384); - try testing.expectEqual(cs.keyExchange(), .ecdhe); - } - { - const cs: CipherSuite = .AES_128_GCM_SHA256; - try cs.validate(); - try testing.expectEqual(.sha256, cs.hash()); - try testing.expectEqual(.ecdhe, cs.keyExchange()); - } - for (cipher_suites.tls12) |cs| { - try cs.validate(); - _ = cs.hash(); - _ = cs.keyExchange(); - } -} - -test "CipherSuite versions" { - try testing.expectEqual(.tls_1_3, CipherSuite.versions(&[_]CipherSuite{.AES_128_GCM_SHA256})); - try testing.expectEqual(.both, CipherSuite.versions(&[_]CipherSuite{ .AES_128_GCM_SHA256, .ECDHE_ECDSA_WITH_AES_128_CBC_SHA })); - try testing.expectEqual(.tls_1_2, CipherSuite.versions(&[_]CipherSuite{.RSA_WITH_AES_128_CBC_SHA})); -} - -test "gcm 1.2 encrypt overhead" { - inline for ([_]type{ - Aead12Aes128Gcm, - Aead12Aes256Gcm, - }) |T| { - { - const expected_key_len = switch (T) { - Aead12Aes128Gcm => 16, - Aead12Aes256Gcm => 32, - else => unreachable, - }; - try testing.expectEqual(expected_key_len, T.key_len); - try testing.expectEqual(16, T.auth_tag_len); - try testing.expectEqual(12, T.nonce_len); - try testing.expectEqual(4, T.iv_len); - try testing.expectEqual(29, T.encrypt_overhead); - } - } -} - -test "cbc 1.2 encrypt overhead" { - try testing.expectEqual(85, encrypt_overhead_tls_12); - - inline for ([_]type{ - CbcAes128Sha1, - CbcAes128Sha256, - CbcAes256Sha384, - }) |T| { - switch (T) { - CbcAes128Sha1 => { - try testing.expectEqual(20, T.mac_len); - try testing.expectEqual(16, T.key_len); - try testing.expectEqual(57, T.encrypt_overhead); - }, - CbcAes128Sha256 => { - try testing.expectEqual(32, T.mac_len); - try testing.expectEqual(16, T.key_len); - try testing.expectEqual(69, T.encrypt_overhead); - }, - CbcAes256Sha384 => { - try testing.expectEqual(48, T.mac_len); - try testing.expectEqual(32, T.key_len); - try testing.expectEqual(85, T.encrypt_overhead); - }, - else => unreachable, - } - try testing.expectEqual(16, T.paddedLength(1)); // cbc block padding - try testing.expectEqual(16, T.iv_len); - } -} - -test "overhead tls 1.3" { - try testing.expectEqual(22, encrypt_overhead_tls_13); - try expectOverhead(Aes128GcmSha256, 16, 16, 12, 22); - try expectOverhead(Aes256GcmSha384, 32, 16, 12, 22); - try expectOverhead(ChaChaSha256, 32, 16, 12, 22); - try expectOverhead(Aegis128Sha256, 16, 16, 16, 22); - // and tls 1.2 chacha - try expectOverhead(Aead12ChaCha, 32, 16, 12, 21); -} - -fn expectOverhead(T: type, key_len: usize, auth_tag_len: usize, nonce_len: usize, overhead: usize) !void { - try testing.expectEqual(key_len, T.key_len); - try testing.expectEqual(auth_tag_len, T.auth_tag_len); - try testing.expectEqual(nonce_len, T.nonce_len); - try testing.expectEqual(overhead, T.encrypt_overhead); -} - -test "client/server encryption tls 1.3" { - inline for (cipher_suites.tls13) |cs| { - var buf: [256]u8 = undefined; - testu.fill(&buf); - const secret = Transcript.Secret{ - .client = buf[0..128], - .server = buf[128..], - }; - var client_cipher = try Cipher.initTls13(cs, secret, .client); - var server_cipher = try Cipher.initTls13(cs, secret, .server); - try encryptDecrypt(&client_cipher, &server_cipher); - - try client_cipher.keyUpdateEncrypt(); - try server_cipher.keyUpdateDecrypt(); - try encryptDecrypt(&client_cipher, &server_cipher); - - try client_cipher.keyUpdateDecrypt(); - try server_cipher.keyUpdateEncrypt(); - try encryptDecrypt(&client_cipher, &server_cipher); - } -} - -test "client/server encryption tls 1.2" { - inline for (cipher_suites.tls12) |cs| { - var key_material: [256]u8 = undefined; - testu.fill(&key_material); - var client_cipher = try Cipher.initTls12(cs, &key_material, .client); - var server_cipher = try Cipher.initTls12(cs, &key_material, .server); - try encryptDecrypt(&client_cipher, &server_cipher); - } -} - -fn encryptDecrypt(client_cipher: *Cipher, server_cipher: *Cipher) !void { - const cleartext = - \\ Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do - \\ eiusmod tempor incididunt ut labore et dolore magna aliqua. - ; - var buf: [256]u8 = undefined; - - { // client to server - // encrypt - const encrypted = try client_cipher.encrypt(&buf, .application_data, cleartext); - const expected_encrypted_len = switch (client_cipher.*) { - inline else => |f| brk: { - const T = @TypeOf(f); - break :brk switch (T) { - CbcAes128Sha1, - CbcAes128Sha256, - CbcAes256Sha256, - CbcAes256Sha384, - => record.header_len + T.paddedLength(T.iv_len + cleartext.len + T.mac_len), - Aead12Aes128Gcm, - Aead12Aes256Gcm, - Aead12ChaCha, - Aes128GcmSha256, - Aes256GcmSha384, - ChaChaSha256, - Aegis128Sha256, - => cleartext.len + T.encrypt_overhead, - else => unreachable, - }; - }, - }; - try testing.expectEqual(expected_encrypted_len, encrypted.len); - // decrypt - const content_type, const decrypted = try server_cipher.decrypt(&buf, Record.init(encrypted)); - try testing.expectEqualSlices(u8, cleartext, decrypted); - try testing.expectEqual(.application_data, content_type); - } - // server to client - { - const encrypted = try server_cipher.encrypt(&buf, .application_data, cleartext); - const content_type, const decrypted = try client_cipher.decrypt(&buf, Record.init(encrypted)); - try testing.expectEqualSlices(u8, cleartext, decrypted); - try testing.expectEqual(.application_data, content_type); - } -} diff --git a/src/http/async/tls.zig/connection.zig b/src/http/async/tls.zig/connection.zig deleted file mode 100644 index 7a6afcbe..00000000 --- a/src/http/async/tls.zig/connection.zig +++ /dev/null @@ -1,665 +0,0 @@ -const std = @import("std"); -const assert = std.debug.assert; - -const proto = @import("protocol.zig"); -const record = @import("record.zig"); -const cipher = @import("cipher.zig"); -const Cipher = cipher.Cipher; - -const async_io = @import("../std/http/Client.zig"); -const Cbk = async_io.Cbk; -const Ctx = async_io.Ctx; - -pub fn connection(stream: anytype) Connection(@TypeOf(stream)) { - return .{ - .stream = stream, - .rec_rdr = record.reader(stream), - }; -} - -pub fn Connection(comptime Stream: type) type { - return struct { - stream: Stream, // underlying stream - rec_rdr: record.Reader(Stream), - cipher: Cipher = undefined, - - max_encrypt_seq: u64 = std.math.maxInt(u64) - 1, - key_update_requested: bool = false, - - read_buf: []const u8 = "", - received_close_notify: bool = false, - - const Self = @This(); - - /// Encrypts and writes single tls record to the stream. - fn writeRecord(c: *Self, content_type: proto.ContentType, bytes: []const u8) !void { - assert(bytes.len <= cipher.max_cleartext_len); - var write_buf: [cipher.max_ciphertext_record_len]u8 = undefined; - // If key update is requested send key update message and update - // my encryption keys. - if (c.cipher.encryptSeq() >= c.max_encrypt_seq or @atomicLoad(bool, &c.key_update_requested, .monotonic)) { - @atomicStore(bool, &c.key_update_requested, false, .monotonic); - - // If the request_update field is set to "update_requested", - // then the receiver MUST send a KeyUpdate of its own with - // request_update set to "update_not_requested" prior to sending - // its next Application Data record. This mechanism allows - // either side to force an update to the entire connection, but - // causes an implementation which receives multiple KeyUpdates - // while it is silent to respond with a single update. - // - // rfc: https://datatracker.ietf.org/doc/html/rfc8446#autoid-57 - const key_update = &record.handshakeHeader(.key_update, 1) ++ [_]u8{0}; - const rec = try c.cipher.encrypt(&write_buf, .handshake, key_update); - try c.stream.writeAll(rec); - try c.cipher.keyUpdateEncrypt(); - } - const rec = try c.cipher.encrypt(&write_buf, content_type, bytes); - try c.stream.writeAll(rec); - } - - fn writeAlert(c: *Self, err: anyerror) !void { - const cleartext = proto.alertFromError(err); - var buf: [128]u8 = undefined; - const ciphertext = try c.cipher.encrypt(&buf, .alert, &cleartext); - c.stream.writeAll(ciphertext) catch {}; - } - - /// Returns next record of cleartext data. - /// Can be used in iterator like loop without memcpy to another buffer: - /// while (try client.next()) |buf| { ... } - pub fn next(c: *Self) ReadError!?[]const u8 { - const content_type, const data = c.nextRecord() catch |err| { - try c.writeAlert(err); - return err; - } orelse return null; - if (content_type != .application_data) return error.TlsUnexpectedMessage; - return data; - } - - fn nextRecord(c: *Self) ReadError!?struct { proto.ContentType, []const u8 } { - if (c.eof()) return null; - while (true) { - const content_type, const cleartext = try c.rec_rdr.nextDecrypt(&c.cipher) orelse return null; - - switch (content_type) { - .application_data => {}, - .handshake => { - const handshake_type: proto.Handshake = @enumFromInt(cleartext[0]); - switch (handshake_type) { - // skip new session ticket and read next record - .new_session_ticket => continue, - .key_update => { - if (cleartext.len != 5) return error.TlsDecodeError; - // rfc: Upon receiving a KeyUpdate, the receiver MUST - // update its receiving keys. - try c.cipher.keyUpdateDecrypt(); - const key: proto.KeyUpdateRequest = @enumFromInt(cleartext[4]); - switch (key) { - .update_requested => { - @atomicStore(bool, &c.key_update_requested, true, .monotonic); - }, - .update_not_requested => {}, - else => return error.TlsIllegalParameter, - } - // this record is handled read next - continue; - }, - else => {}, - } - }, - .alert => { - if (cleartext.len < 2) return error.TlsUnexpectedMessage; - try proto.Alert.parse(cleartext[0..2].*).toError(); - // server side clean shutdown - c.received_close_notify = true; - return null; - }, - else => return error.TlsUnexpectedMessage, - } - return .{ content_type, cleartext }; - } - } - - pub fn eof(c: *Self) bool { - return c.received_close_notify and c.read_buf.len == 0; - } - - pub fn close(c: *Self) !void { - if (c.received_close_notify) return; - try c.writeRecord(.alert, &proto.Alert.closeNotify()); - } - - // read, write interface - - pub const ReadError = Stream.ReadError || proto.Alert.Error || - error{ - TlsBadVersion, - TlsUnexpectedMessage, - TlsRecordOverflow, - TlsDecryptError, - TlsDecodeError, - TlsBadRecordMac, - TlsIllegalParameter, - BufferOverflow, - }; - pub const WriteError = Stream.WriteError || - error{ - BufferOverflow, - TlsUnexpectedMessage, - }; - - pub const Reader = std.io.Reader(*Self, ReadError, read); - pub const Writer = std.io.Writer(*Self, WriteError, write); - - pub fn reader(c: *Self) Reader { - return .{ .context = c }; - } - - pub fn writer(c: *Self) Writer { - return .{ .context = c }; - } - - /// Encrypts cleartext and writes it to the underlying stream as single - /// tls record. Max single tls record payload length is 1<<14 (16K) - /// bytes. - pub fn write(c: *Self, bytes: []const u8) WriteError!usize { - const n = @min(bytes.len, cipher.max_cleartext_len); - try c.writeRecord(.application_data, bytes[0..n]); - return n; - } - - /// Encrypts cleartext and writes it to the underlying stream. If needed - /// splits cleartext into multiple tls record. - pub fn writeAll(c: *Self, bytes: []const u8) WriteError!void { - var index: usize = 0; - while (index < bytes.len) { - index += try c.write(bytes[index..]); - } - } - - pub fn read(c: *Self, buffer: []u8) ReadError!usize { - if (c.read_buf.len == 0) { - c.read_buf = try c.next() orelse return 0; - } - const n = @min(c.read_buf.len, buffer.len); - @memcpy(buffer[0..n], c.read_buf[0..n]); - c.read_buf = c.read_buf[n..]; - return n; - } - - /// Returns the number of bytes read. If the number read is smaller than - /// `buffer.len`, it means the stream reached the end. - pub fn readAll(c: *Self, buffer: []u8) ReadError!usize { - return c.readAtLeast(buffer, buffer.len); - } - - /// Returns the number of bytes read, calling the underlying read function - /// the minimal number of times until the buffer has at least `len` bytes - /// filled. If the number read is less than `len` it means the stream - /// reached the end. - pub fn readAtLeast(c: *Self, buffer: []u8, len: usize) ReadError!usize { - assert(len <= buffer.len); - var index: usize = 0; - while (index < len) { - const amt = try c.read(buffer[index..]); - if (amt == 0) break; - index += amt; - } - return index; - } - - /// Returns the number of bytes read. If the number read is less than - /// the space provided it means the stream reached the end. - pub fn readv(c: *Self, iovecs: []std.posix.iovec) !usize { - var vp: VecPut = .{ .iovecs = iovecs }; - while (true) { - if (c.read_buf.len == 0) { - c.read_buf = try c.next() orelse break; - } - const n = vp.put(c.read_buf); - const read_buf_len = c.read_buf.len; - c.read_buf = c.read_buf[n..]; - if ((n < read_buf_len) or - (n == read_buf_len and !c.rec_rdr.hasMore())) - break; - } - return vp.total; - } - - fn onWriteAll(ctx: *Ctx, res: anyerror!void) anyerror!void { - res catch |err| return ctx.pop(err); - - if (ctx._tls_write_bytes.len - ctx._tls_write_index > 0) { - const rec = ctx.conn().tls_client.prepareRecord(ctx.stream(), ctx) catch |err| return ctx.pop(err); - return ctx.stream().async_writeAll(rec, ctx, onWriteAll) catch |err| return ctx.pop(err); - } - - return ctx.pop({}); - } - - pub fn async_writeAll(c: *Self, stream: anytype, bytes: []const u8, ctx: *Ctx, comptime cbk: Cbk) !void { - assert(bytes.len <= cipher.max_cleartext_len); - - ctx._tls_write_bytes = bytes; - ctx._tls_write_index = 0; - const rec = try c.prepareRecord(stream, ctx); - - try ctx.push(cbk); - return stream.async_writeAll(rec, ctx, onWriteAll); - } - - fn prepareRecord(c: *Self, stream: anytype, ctx: *Ctx) ![]const u8 { - const len = @min(ctx._tls_write_bytes.len - ctx._tls_write_index, cipher.max_cleartext_len); - - // If key update is requested send key update message and update - // my encryption keys. - if (c.cipher.encryptSeq() >= c.max_encrypt_seq or @atomicLoad(bool, &c.key_update_requested, .monotonic)) { - @atomicStore(bool, &c.key_update_requested, false, .monotonic); - - // If the request_update field is set to "update_requested", - // then the receiver MUST send a KeyUpdate of its own with - // request_update set to "update_not_requested" prior to sending - // its next Application Data record. This mechanism allows - // either side to force an update to the entire connection, but - // causes an implementation which receives multiple KeyUpdates - // while it is silent to respond with a single update. - // - // rfc: https://datatracker.ietf.org/doc/html/rfc8446#autoid-57 - const key_update = &record.handshakeHeader(.key_update, 1) ++ [_]u8{0}; - const rec = try c.cipher.encrypt(&ctx._tls_write_buf, .handshake, key_update); - try stream.writeAll(rec); // TODO async - try c.cipher.keyUpdateEncrypt(); - } - - defer ctx._tls_write_index += len; - return c.cipher.encrypt(&ctx._tls_write_buf, .application_data, ctx._tls_write_bytes[ctx._tls_write_index..len]); - } - - fn onReadv(ctx: *Ctx, res: anyerror!void) anyerror!void { - res catch |err| return ctx.pop(err); - - if (ctx._tls_read_buf == null) { - // end of read - ctx.setLen(ctx._vp.total); - return ctx.pop({}); - } - - while (true) { - const n = ctx._vp.put(ctx._tls_read_buf.?); - const read_buf_len = ctx._tls_read_buf.?.len; - const c = ctx.conn().tls_client; - - if (read_buf_len == 0) { - // read another buffer - return c.async_next(ctx.stream(), ctx, onReadv) catch |err| return ctx.pop(err); - } - - ctx._tls_read_buf = ctx._tls_read_buf.?[n..]; - - if ((n < read_buf_len) or (n == read_buf_len and !c.rec_rdr.hasMore())) { - // end of read - ctx.setLen(ctx._vp.total); - return ctx.pop({}); - } - } - } - - pub fn async_readv(c: *Self, stream: anytype, iovecs: []std.posix.iovec, ctx: *Ctx, comptime cbk: Cbk) !void { - try ctx.push(cbk); - ctx._vp = .{ .iovecs = iovecs }; - - return c.async_next(stream, ctx, onReadv); - } - - fn onNext(ctx: *Ctx, res: anyerror!void) anyerror!void { - res catch |err| { - ctx.conn().tls_client.writeAlert(err) catch |e| std.log.err("onNext: write alert: {any}", .{e}); // TODO async - return ctx.pop(err); - }; - - if (ctx._tls_read_content_type != .application_data) { - return ctx.pop(error.TlsUnexpectedMessage); - } - - return ctx.pop({}); - } - - pub fn async_next(c: *Self, stream: anytype, ctx: *Ctx, comptime cbk: Cbk) !void { - try ctx.push(cbk); - - return c.async_next_decrypt(stream, ctx, onNext); - } - - pub fn onNextDecrypt(ctx: *Ctx, res: anyerror!void) anyerror!void { - res catch |err| return ctx.pop(err); - - const c = ctx.conn().tls_client; - // TOOD not sure if this works in my async case... - if (c.eof()) { - ctx._tls_read_buf = null; - return ctx.pop({}); - } - - const content_type = ctx._tls_read_content_type; - - switch (content_type) { - .application_data => {}, - .handshake => { - const handshake_type: proto.Handshake = @enumFromInt(ctx._tls_read_buf.?[0]); - switch (handshake_type) { - // skip new session ticket and read next record - .new_session_ticket => return c.async_next_record(ctx.stream(), ctx, onNextDecrypt) catch |err| return ctx.pop(err), - .key_update => { - if (ctx._tls_read_buf.?.len != 5) return ctx.pop(error.TlsDecodeError); - // rfc: Upon receiving a KeyUpdate, the receiver MUST - // update its receiving keys. - try c.cipher.keyUpdateDecrypt(); - const key: proto.KeyUpdateRequest = @enumFromInt(ctx._tls_read_buf.?[4]); - switch (key) { - .update_requested => { - @atomicStore(bool, &c.key_update_requested, true, .monotonic); - }, - .update_not_requested => {}, - else => return ctx.pop(error.TlsIllegalParameter), - } - // this record is handled read next - c.async_next_record(ctx.stream(), ctx, onNextDecrypt) catch |err| return ctx.pop(err); - }, - else => {}, - } - }, - .alert => { - if (ctx._tls_read_buf.?.len < 2) return ctx.pop(error.TlsUnexpectedMessage); - try proto.Alert.parse(ctx._tls_read_buf.?[0..2].*).toError(); - // server side clean shutdown - c.received_close_notify = true; - ctx._tls_read_buf = null; - return ctx.pop({}); - }, - else => return ctx.pop(error.TlsUnexpectedMessage), - } - - return ctx.pop({}); - } - - pub fn async_next_decrypt(c: *Self, stream: anytype, ctx: *Ctx, comptime cbk: Cbk) !void { - try ctx.push(cbk); - - return c.async_next_record(stream, ctx, onNextDecrypt) catch |err| return ctx.pop(err); - } - - pub fn onNextRecord(ctx: *Ctx, res: anyerror!void) anyerror!void { - res catch |err| return ctx.pop(err); - - const rec = ctx._tls_read_record orelse { - ctx._tls_read_buf = null; - return ctx.pop({}); - }; - - if (rec.protocol_version != .tls_1_2) return error.TlsBadVersion; - - const c = ctx.conn().tls_client; - const cph = &c.cipher; - - ctx._tls_read_content_type, ctx._tls_read_buf = cph.decrypt( - // Reuse reader buffer for cleartext. `rec.header` and - // `rec.payload`(ciphertext) are also pointing somewhere in - // this buffer. Decrypter is first reading then writing a - // block, cleartext has less length then ciphertext, - // cleartext starts from the beginning of the buffer, so - // ciphertext is always ahead of cleartext. - c.rec_rdr.buffer[0..c.rec_rdr.start], - rec, - ) catch |err| return ctx.pop(err); - - return ctx.pop({}); - } - - pub fn async_next_record(c: *Self, stream: anytype, ctx: *Ctx, comptime cbk: Cbk) !void { - try ctx.push(cbk); - - return c.async_reader_next(stream, ctx, onNextRecord); - } - - pub fn onReaderNext(ctx: *Ctx, res: anyerror!void) anyerror!void { - res catch |err| return ctx.pop(err); - - const c = ctx.conn().tls_client; - - const n = ctx.len(); - if (n == 0) { - ctx._tls_read_record = null; - return ctx.pop({}); - } - c.rec_rdr.end += n; - - return c.readNext(ctx); - } - - pub fn readNext(c: *Self, ctx: *Ctx) anyerror!void { - const buffer = c.rec_rdr.buffer[c.rec_rdr.start..c.rec_rdr.end]; - // If we have 5 bytes header. - if (buffer.len >= record.header_len) { - const record_header = buffer[0..record.header_len]; - const payload_len = std.mem.readInt(u16, record_header[3..5], .big); - if (payload_len > cipher.max_ciphertext_len) - return error.TlsRecordOverflow; - const record_len = record.header_len + payload_len; - // If we have whole record - if (buffer.len >= record_len) { - c.rec_rdr.start += record_len; - ctx._tls_read_record = record.Record.init(buffer[0..record_len]); - return ctx.pop({}); - } - } - { // Move dirty part to the start of the buffer. - const n = c.rec_rdr.end - c.rec_rdr.start; - if (n > 0 and c.rec_rdr.start > 0) { - if (c.rec_rdr.start > n) { - @memcpy(c.rec_rdr.buffer[0..n], c.rec_rdr.buffer[c.rec_rdr.start..][0..n]); - } else { - std.mem.copyForwards(u8, c.rec_rdr.buffer[0..n], c.rec_rdr.buffer[c.rec_rdr.start..][0..n]); - } - } - c.rec_rdr.start = 0; - c.rec_rdr.end = n; - } - // Read more from inner_reader. - return ctx.stream() - .async_read(c.rec_rdr.buffer[c.rec_rdr.end..], ctx, onReaderNext) catch |err| return ctx.pop(err); - } - - pub fn async_reader_next(c: *Self, _: anytype, ctx: *Ctx, comptime cbk: Cbk) !void { - try ctx.push(cbk); - return c.readNext(ctx); - } - }; -} - -const testing = std.testing; -const data12 = @import("testdata/tls12.zig"); -const testu = @import("testu.zig"); - -test "encrypt decrypt" { - var output_buf: [1024]u8 = undefined; - const stream = testu.Stream.init(&(data12.server_pong ** 3), &output_buf); - var conn: Connection(@TypeOf(stream)) = .{ .stream = stream, .rec_rdr = record.reader(stream) }; - conn.cipher = try Cipher.initTls12(.ECDHE_RSA_WITH_AES_128_CBC_SHA, &data12.key_material, .client); - conn.cipher.ECDHE_RSA_WITH_AES_128_CBC_SHA.rnd = testu.random(0); // use fixed rng - - conn.stream.output.reset(); - { // encrypt verify data from example - _ = testu.random(0x40); // sets iv to 40, 41, ... 4f - try conn.writeRecord(.handshake, &data12.client_finished); - try testing.expectEqualSlices(u8, &data12.verify_data_encrypted_msg, conn.stream.output.getWritten()); - } - - conn.stream.output.reset(); - { // encrypt ping - const cleartext = "ping"; - _ = testu.random(0); // sets iv to 00, 01, ... 0f - //conn.encrypt_seq = 1; - - try conn.writeAll(cleartext); - try testing.expectEqualSlices(u8, &data12.encrypted_ping_msg, conn.stream.output.getWritten()); - } - { // decrypt server pong message - conn.cipher.ECDHE_RSA_WITH_AES_128_CBC_SHA.decrypt_seq = 1; - try testing.expectEqualStrings("pong", (try conn.next()).?); - } - { // test reader interface - conn.cipher.ECDHE_RSA_WITH_AES_128_CBC_SHA.decrypt_seq = 1; - var rdr = conn.reader(); - var buffer: [4]u8 = undefined; - const n = try rdr.readAll(&buffer); - try testing.expectEqualStrings("pong", buffer[0..n]); - } - { // test readv interface - conn.cipher.ECDHE_RSA_WITH_AES_128_CBC_SHA.decrypt_seq = 1; - var buffer: [9]u8 = undefined; - var iovecs = [_]std.posix.iovec{ - .{ .base = &buffer, .len = 3 }, - .{ .base = buffer[3..], .len = 3 }, - .{ .base = buffer[6..], .len = 3 }, - }; - const n = try conn.readv(iovecs[0..]); - try testing.expectEqual(4, n); - try testing.expectEqualStrings("pong", buffer[0..n]); - } -} - -// Copied from: https://github.com/ziglang/zig/blob/455899668b620dfda40252501c748c0a983555bd/lib/std/crypto/tls/Client.zig#L1354 -/// Abstraction for sending multiple byte buffers to a slice of iovecs. -pub const VecPut = struct { - iovecs: []const std.posix.iovec, - idx: usize = 0, - off: usize = 0, - total: usize = 0, - - /// Returns the amount actually put which is always equal to bytes.len - /// unless the vectors ran out of space. - pub fn put(vp: *VecPut, bytes: []const u8) usize { - if (vp.idx >= vp.iovecs.len) return 0; - var bytes_i: usize = 0; - while (true) { - const v = vp.iovecs[vp.idx]; - const dest = v.base[vp.off..v.len]; - const src = bytes[bytes_i..][0..@min(dest.len, bytes.len - bytes_i)]; - @memcpy(dest[0..src.len], src); - bytes_i += src.len; - vp.off += src.len; - if (vp.off >= v.len) { - vp.off = 0; - vp.idx += 1; - if (vp.idx >= vp.iovecs.len) { - vp.total += bytes_i; - return bytes_i; - } - } - if (bytes_i >= bytes.len) { - vp.total += bytes_i; - return bytes_i; - } - } - } -}; - -test "client/server connection" { - const BufReaderWriter = struct { - buf: []u8, - wp: usize = 0, - rp: usize = 0, - - const Self = @This(); - - pub fn write(self: *Self, bytes: []const u8) !usize { - if (self.wp == self.buf.len) return error.NoSpaceLeft; - - const n = @min(bytes.len, self.buf.len - self.wp); - @memcpy(self.buf[self.wp..][0..n], bytes[0..n]); - self.wp += n; - return n; - } - - pub fn writeAll(self: *Self, bytes: []const u8) !void { - var n: usize = 0; - while (n < bytes.len) { - n += try self.write(bytes[n..]); - } - } - - pub fn read(self: *Self, bytes: []u8) !usize { - const n = @min(bytes.len, self.wp - self.rp); - if (n == 0) return 0; - @memcpy(bytes[0..n], self.buf[self.rp..][0..n]); - self.rp += n; - if (self.rp == self.wp) { - self.wp = 0; - self.rp = 0; - } - return n; - } - }; - - const TestStream = struct { - inner_stream: *BufReaderWriter, - const Self = @This(); - pub const ReadError = error{}; - pub const WriteError = error{NoSpaceLeft}; - pub fn read(self: *Self, bytes: []u8) !usize { - return try self.inner_stream.read(bytes); - } - pub fn writeAll(self: *Self, bytes: []const u8) !void { - return try self.inner_stream.writeAll(bytes); - } - }; - - const buf_len = 32 * 1024; - const tls_records_in_buf = (std.math.divCeil(comptime_int, buf_len, cipher.max_cleartext_len) catch unreachable); - const overhead: usize = tls_records_in_buf * @import("cipher.zig").encrypt_overhead_tls_13; - var buf: [buf_len + overhead]u8 = undefined; - var inner_stream = BufReaderWriter{ .buf = &buf }; - - const cipher_client, const cipher_server = brk: { - const Transcript = @import("transcript.zig").Transcript; - const CipherSuite = @import("cipher.zig").CipherSuite; - const cipher_suite: CipherSuite = .AES_256_GCM_SHA384; - - var rnd: [128]u8 = undefined; - std.crypto.random.bytes(&rnd); - const secret = Transcript.Secret{ - .client = rnd[0..64], - .server = rnd[64..], - }; - - break :brk .{ - try Cipher.initTls13(cipher_suite, secret, .client), - try Cipher.initTls13(cipher_suite, secret, .server), - }; - }; - - var conn1 = connection(TestStream{ .inner_stream = &inner_stream }); - conn1.cipher = cipher_client; - - var conn2 = connection(TestStream{ .inner_stream = &inner_stream }); - conn2.cipher = cipher_server; - - var prng = std.Random.DefaultPrng.init(0); - const random = prng.random(); - var send_buf: [buf_len]u8 = undefined; - var recv_buf: [buf_len]u8 = undefined; - random.bytes(&send_buf); // fill send buffer with random bytes - - for (0..16) |_| { - const n = buf_len; //random.uintLessThan(usize, buf_len); - - const sent = send_buf[0..n]; - try conn1.writeAll(sent); - const r = try conn2.readAll(&recv_buf); - const received = recv_buf[0..r]; - - try testing.expectEqual(n, r); - try testing.expectEqualSlices(u8, sent, received); - } -} diff --git a/src/http/async/tls.zig/handshake_client.zig b/src/http/async/tls.zig/handshake_client.zig deleted file mode 100644 index e7b48cf6..00000000 --- a/src/http/async/tls.zig/handshake_client.zig +++ /dev/null @@ -1,955 +0,0 @@ -const std = @import("std"); -const assert = std.debug.assert; -const crypto = std.crypto; -const mem = std.mem; -const Certificate = crypto.Certificate; - -const cipher = @import("cipher.zig"); -const Cipher = cipher.Cipher; -const CipherSuite = cipher.CipherSuite; -const cipher_suites = cipher.cipher_suites; -const Transcript = @import("transcript.zig").Transcript; -const record = @import("record.zig"); -const rsa = @import("rsa/rsa.zig"); -const key_log = @import("key_log.zig"); -const PrivateKey = @import("PrivateKey.zig"); -const proto = @import("protocol.zig"); - -const common = @import("handshake_common.zig"); -const dupe = common.dupe; -const CertificateBuilder = common.CertificateBuilder; -const CertificateParser = common.CertificateParser; -const DhKeyPair = common.DhKeyPair; -const CertBundle = common.CertBundle; -const CertKeyPair = common.CertKeyPair; - -pub const Options = struct { - host: []const u8, - /// Set of root certificate authorities that clients use when verifying - /// server certificates. - root_ca: CertBundle, - - /// Controls whether a client verifies the server's certificate chain and - /// host name. - insecure_skip_verify: bool = false, - - /// List of cipher suites to use. - /// To use just tls 1.3 cipher suites: - /// .cipher_suites = &tls.CipherSuite.tls13, - /// To select particular cipher suite: - /// .cipher_suites = &[_]tls.CipherSuite{tls.CipherSuite.CHACHA20_POLY1305_SHA256}, - cipher_suites: []const CipherSuite = cipher_suites.all, - - /// List of named groups to use. - /// To use specific named group: - /// .named_groups = &[_]tls.NamedGroup{.secp384r1}, - named_groups: []const proto.NamedGroup = supported_named_groups, - - /// Client authentication certificates and private key. - auth: ?CertKeyPair = null, - - /// If this structure is provided it will be filled with handshake attributes - /// at the end of the handshake process. - diagnostic: ?*Diagnostic = null, - - /// For logging current connection tls keys, so we can share them with - /// Wireshark and analyze decrypted traffic there. - key_log_callback: ?key_log.Callback = null, - - pub const Diagnostic = struct { - tls_version: proto.Version = @enumFromInt(0), - cipher_suite_tag: CipherSuite = @enumFromInt(0), - named_group: proto.NamedGroup = @enumFromInt(0), - signature_scheme: proto.SignatureScheme = @enumFromInt(0), - client_signature_scheme: proto.SignatureScheme = @enumFromInt(0), - }; -}; - -const supported_named_groups = &[_]proto.NamedGroup{ - .x25519, - .secp256r1, - .secp384r1, - .x25519_kyber768d00, -}; - -/// Handshake parses tls server message and creates client messages. Collects -/// tls attributes: server random, cipher suite and so on. Client messages are -/// created using provided buffer. Provided record reader is used to get tls -/// record when needed. -pub fn Handshake(comptime Stream: type) type { - const RecordReaderT = record.Reader(Stream); - return struct { - client_random: [32]u8, - server_random: [32]u8 = undefined, - master_secret: [48]u8 = undefined, - key_material: [48 * 4]u8 = undefined, // for sha256 32 * 4 is filled, for sha384 48 * 4 - - transcript: Transcript = .{}, - cipher_suite: CipherSuite = @enumFromInt(0), - named_group: ?proto.NamedGroup = null, - dh_kp: DhKeyPair, - rsa_secret: RsaSecret, - tls_version: proto.Version = .tls_1_2, - cipher: Cipher = undefined, - cert: CertificateParser = undefined, - client_certificate_requested: bool = false, - // public key len: x25519 = 32, secp256r1 = 65, secp384r1 = 97, x25519_kyber768d00 = 1120 - server_pub_key_buf: [2048]u8 = undefined, - server_pub_key: []const u8 = undefined, - - rec_rdr: *RecordReaderT, // tls record reader - buffer: []u8, // scratch buffer used in all messages creation - - const HandshakeT = @This(); - - pub fn init(buf: []u8, rec_rdr: *RecordReaderT) HandshakeT { - return .{ - .client_random = undefined, - .dh_kp = undefined, - .rsa_secret = undefined, - //.now_sec = std.time.timestamp(), - .buffer = buf, - .rec_rdr = rec_rdr, - }; - } - - fn initKeys( - h: *HandshakeT, - named_groups: []const proto.NamedGroup, - ) !void { - const init_keys_buf_len = 32 + 46 + DhKeyPair.seed_len; - var buf: [init_keys_buf_len]u8 = undefined; - crypto.random.bytes(&buf); - - h.client_random = buf[0..32].*; - h.rsa_secret = RsaSecret.init(buf[32..][0..46].*); - h.dh_kp = try DhKeyPair.init(buf[32 + 46 ..][0..DhKeyPair.seed_len].*, named_groups); - } - - /// Handshake exchanges messages with server to get agreement about - /// cryptographic parameters. That upgrades existing client-server - /// connection to TLS connection. Returns cipher used in application for - /// encrypted message exchange. - /// - /// Handles TLS 1.2 and TLS 1.3 connections. After initial client hello - /// server chooses in its server hello which TLS version will be used. - /// - /// TLS 1.2 handshake messages exchange: - /// Client Server - /// -------------------------------------------------------------- - /// ClientHello client flight 1 ---> - /// ServerHello - /// Certificate - /// ServerKeyExchange - /// CertificateRequest* - /// <--- server flight 1 ServerHelloDone - /// Certificate* - /// ClientKeyExchange - /// CertificateVerify* - /// ChangeCipherSpec - /// Finished client flight 2 ---> - /// ChangeCipherSpec - /// <--- server flight 2 Finished - /// - /// TLS 1.3 handshake messages exchange: - /// Client Server - /// -------------------------------------------------------------- - /// ClientHello client flight 1 ---> - /// ServerHello - /// {EncryptedExtensions} - /// {CertificateRequest*} - /// {Certificate} - /// {CertificateVerify} - /// <--- server flight 1 {Finished} - /// ChangeCipherSpec - /// {Certificate*} - /// {CertificateVerify*} - /// Finished client flight 2 ---> - /// - /// * - optional - /// {} - encrypted - /// - /// References: - /// https://datatracker.ietf.org/doc/html/rfc5246#section-7.3 - /// https://datatracker.ietf.org/doc/html/rfc8446#section-2 - /// - pub fn handshake(h: *HandshakeT, w: Stream, opt: Options) !Cipher { - defer h.updateDiagnostic(opt); - try h.initKeys(opt.named_groups); - h.cert = .{ - .host = opt.host, - .root_ca = opt.root_ca.bundle, - .skip_verify = opt.insecure_skip_verify, - }; - - try w.writeAll(try h.makeClientHello(opt)); // client flight 1 - try h.readServerFlight1(); // server flight 1 - h.transcript.use(h.cipher_suite.hash()); - - // tls 1.3 specific handshake part - if (h.tls_version == .tls_1_3) { - try h.generateHandshakeCipher(opt.key_log_callback); - try h.readEncryptedServerFlight1(); // server flight 1 - const app_cipher = try h.generateApplicationCipher(opt.key_log_callback); - try w.writeAll(try h.makeClientFlight2Tls13(opt.auth)); // client flight 2 - return app_cipher; - } - - // tls 1.2 specific handshake part - try h.generateCipher(opt.key_log_callback); - try w.writeAll(try h.makeClientFlight2Tls12(opt.auth)); // client flight 2 - try h.readServerFlight2(); // server flight 2 - return h.cipher; - } - - /// Prepare key material and generate cipher for TLS 1.2 - fn generateCipher(h: *HandshakeT, key_log_callback: ?key_log.Callback) !void { - try h.verifyCertificateSignatureTls12(); - try h.generateKeyMaterial(key_log_callback); - h.cipher = try Cipher.initTls12(h.cipher_suite, &h.key_material, .client); - } - - /// Generate TLS 1.2 pre master secret, master secret and key material. - fn generateKeyMaterial(h: *HandshakeT, key_log_callback: ?key_log.Callback) !void { - const pre_master_secret = if (h.named_group) |named_group| - try h.dh_kp.sharedKey(named_group, h.server_pub_key) - else - &h.rsa_secret.secret; - - _ = dupe( - &h.master_secret, - h.transcript.masterSecret(pre_master_secret, h.client_random, h.server_random), - ); - _ = dupe( - &h.key_material, - h.transcript.keyMaterial(&h.master_secret, h.client_random, h.server_random), - ); - if (key_log_callback) |cb| { - cb(key_log.label.client_random, &h.client_random, &h.master_secret); - } - } - - /// TLS 1.3 cipher used during handshake - fn generateHandshakeCipher(h: *HandshakeT, key_log_callback: ?key_log.Callback) !void { - const shared_key = try h.dh_kp.sharedKey(h.named_group.?, h.server_pub_key); - const handshake_secret = h.transcript.handshakeSecret(shared_key); - if (key_log_callback) |cb| { - cb(key_log.label.server_handshake_traffic_secret, &h.client_random, handshake_secret.server); - cb(key_log.label.client_handshake_traffic_secret, &h.client_random, handshake_secret.client); - } - h.cipher = try Cipher.initTls13(h.cipher_suite, handshake_secret, .client); - } - - /// TLS 1.3 application (client) cipher - fn generateApplicationCipher(h: *HandshakeT, key_log_callback: ?key_log.Callback) !Cipher { - const application_secret = h.transcript.applicationSecret(); - if (key_log_callback) |cb| { - cb(key_log.label.server_traffic_secret_0, &h.client_random, application_secret.server); - cb(key_log.label.client_traffic_secret_0, &h.client_random, application_secret.client); - } - return try Cipher.initTls13(h.cipher_suite, application_secret, .client); - } - - fn makeClientHello(h: *HandshakeT, opt: Options) ![]const u8 { - // Buffer will have this parts: - // | header | payload | extensions | - // - // Header will be written last because we need to know length of - // payload and extensions when creating it. Payload has - // extensions length (u16) as last element. - // - var buffer = h.buffer; - const header_len = 9; // tls record header (5 bytes) and handshake header (4 bytes) - const tls_versions = try CipherSuite.versions(opt.cipher_suites); - // Payload writer, preserve header_len bytes for handshake header. - var payload = record.Writer{ .buf = buffer[header_len..] }; - try payload.writeEnum(proto.Version.tls_1_2); - try payload.write(&h.client_random); - try payload.writeByte(0); // no session id - try payload.writeEnumArray(CipherSuite, opt.cipher_suites); - try payload.write(&[_]u8{ 0x01, 0x00 }); // no compression - - // Extensions writer starts after payload and preserves 2 more - // bytes for extension len in payload. - var ext = record.Writer{ .buf = buffer[header_len + payload.pos + 2 ..] }; - try ext.writeExtension(.supported_versions, switch (tls_versions) { - .both => &[_]proto.Version{ .tls_1_3, .tls_1_2 }, - .tls_1_3 => &[_]proto.Version{.tls_1_3}, - .tls_1_2 => &[_]proto.Version{.tls_1_2}, - }); - try ext.writeExtension(.signature_algorithms, common.supported_signature_algorithms); - - try ext.writeExtension(.supported_groups, opt.named_groups); - if (tls_versions != .tls_1_2) { - var keys: [supported_named_groups.len][]const u8 = undefined; - for (opt.named_groups, 0..) |ng, i| { - keys[i] = try h.dh_kp.publicKey(ng); - } - try ext.writeKeyShare(opt.named_groups, keys[0..opt.named_groups.len]); - } - try ext.writeServerName(opt.host); - - // Extensions length at the end of the payload. - try payload.writeInt(@as(u16, @intCast(ext.pos))); - - // Header at the start of the buffer. - const body_len = payload.pos + ext.pos; - buffer[0..header_len].* = record.header(.handshake, 4 + body_len) ++ - record.handshakeHeader(.client_hello, body_len); - - const msg = buffer[0 .. header_len + body_len]; - h.transcript.update(msg[record.header_len..]); - return msg; - } - - /// Process first flight of the messages from the server. - /// Read server hello message. If TLS 1.3 is chosen in server hello - /// return. For TLS 1.2 continue and read certificate, key_exchange - /// eventual certificate request and hello done messages. - fn readServerFlight1(h: *HandshakeT) !void { - var handshake_states: []const proto.Handshake = &.{.server_hello}; - - while (true) { - var d = try h.rec_rdr.nextDecoder(); - try d.expectContentType(.handshake); - - h.transcript.update(d.payload); - - // Multiple handshake messages can be packed in single tls record. - while (!d.eof()) { - const handshake_type = try d.decode(proto.Handshake); - - const length = try d.decode(u24); - if (length > cipher.max_cleartext_len) - return error.TlsUnsupportedFragmentedHandshakeMessage; - - brk: { - for (handshake_states) |state| - if (state == handshake_type) break :brk; - return error.TlsUnexpectedMessage; - } - switch (handshake_type) { - .server_hello => { // server hello, ref: https://datatracker.ietf.org/doc/html/rfc5246#section-7.4.1.3 - try h.parseServerHello(&d, length); - if (h.tls_version == .tls_1_3) { - if (!d.eof()) return error.TlsIllegalParameter; - return; // end of tls 1.3 server flight 1 - } - handshake_states = if (h.cert.skip_verify) - &.{ .certificate, .server_key_exchange, .server_hello_done } - else - &.{.certificate}; - }, - .certificate => { - try h.cert.parseCertificate(&d, h.tls_version); - handshake_states = if (h.cipher_suite.keyExchange() == .rsa) - &.{.server_hello_done} - else - &.{.server_key_exchange}; - }, - .server_key_exchange => { - try h.parseServerKeyExchange(&d); - handshake_states = &.{ .certificate_request, .server_hello_done }; - }, - .certificate_request => { - h.client_certificate_requested = true; - try d.skip(length); - handshake_states = &.{.server_hello_done}; - }, - .server_hello_done => { - if (length != 0) return error.TlsIllegalParameter; - return; - }, - else => return error.TlsUnexpectedMessage, - } - } - } - } - - /// Parse server hello message. - fn parseServerHello(h: *HandshakeT, d: *record.Decoder, length: u24) !void { - if (try d.decode(proto.Version) != proto.Version.tls_1_2) - return error.TlsBadVersion; - h.server_random = try d.array(32); - if (isServerHelloRetryRequest(&h.server_random)) - return error.TlsServerHelloRetryRequest; - - const session_id_len = try d.decode(u8); - if (session_id_len > 32) return error.TlsIllegalParameter; - try d.skip(session_id_len); - - h.cipher_suite = try d.decode(CipherSuite); - try h.cipher_suite.validate(); - try d.skip(1); // skip compression method - - const extensions_present = length > 2 + 32 + 1 + session_id_len + 2 + 1; - if (extensions_present) { - const exs_len = try d.decode(u16); - var l: usize = 0; - while (l < exs_len) { - const typ = try d.decode(proto.Extension); - const len = try d.decode(u16); - defer l += len + 4; - - switch (typ) { - .supported_versions => { - switch (try d.decode(proto.Version)) { - .tls_1_2, .tls_1_3 => |v| h.tls_version = v, - else => return error.TlsIllegalParameter, - } - if (len != 2) return error.TlsIllegalParameter; - }, - .key_share => { - h.named_group = try d.decode(proto.NamedGroup); - h.server_pub_key = dupe(&h.server_pub_key_buf, try d.slice(try d.decode(u16))); - if (len != h.server_pub_key.len + 4) return error.TlsIllegalParameter; - }, - else => { - try d.skip(len); - }, - } - } - } - } - - fn isServerHelloRetryRequest(server_random: []const u8) bool { - // Ref: https://datatracker.ietf.org/doc/html/rfc8446#section-4.1.3 - const hello_retry_request_magic = [32]u8{ - 0xCF, 0x21, 0xAD, 0x74, 0xE5, 0x9A, 0x61, 0x11, 0xBE, 0x1D, 0x8C, 0x02, 0x1E, 0x65, 0xB8, 0x91, - 0xC2, 0xA2, 0x11, 0x16, 0x7A, 0xBB, 0x8C, 0x5E, 0x07, 0x9E, 0x09, 0xE2, 0xC8, 0xA8, 0x33, 0x9C, - }; - return std.mem.eql(u8, server_random, &hello_retry_request_magic); - } - - fn parseServerKeyExchange(h: *HandshakeT, d: *record.Decoder) !void { - const curve_type = try d.decode(proto.Curve); - h.named_group = try d.decode(proto.NamedGroup); - h.server_pub_key = dupe(&h.server_pub_key_buf, try d.slice(try d.decode(u8))); - h.cert.signature_scheme = try d.decode(proto.SignatureScheme); - h.cert.signature = dupe(&h.cert.signature_buf, try d.slice(try d.decode(u16))); - if (curve_type != .named_curve) return error.TlsIllegalParameter; - } - - /// Read encrypted part (after server hello) of the server first flight - /// for TLS 1.3: change cipher spec, eventual certificate request, - /// certificate, certificate verify and handshake finished messages. - fn readEncryptedServerFlight1(h: *HandshakeT) !void { - var cleartext_buf = h.buffer; - var cleartext_buf_head: usize = 0; - var cleartext_buf_tail: usize = 0; - var handshake_states: []const proto.Handshake = &.{.encrypted_extensions}; - - outer: while (true) { - // wrapped record decoder - const rec = (try h.rec_rdr.next() orelse return error.EndOfStream); - if (rec.protocol_version != .tls_1_2) return error.TlsBadVersion; - switch (rec.content_type) { - .change_cipher_spec => {}, - .application_data => { - const content_type, const cleartext = try h.cipher.decrypt( - cleartext_buf[cleartext_buf_tail..], - rec, - ); - cleartext_buf_tail += cleartext.len; - if (cleartext_buf_tail > cleartext_buf.len) return error.TlsRecordOverflow; - - var d = record.Decoder.init(content_type, cleartext_buf[cleartext_buf_head..cleartext_buf_tail]); - try d.expectContentType(.handshake); - while (!d.eof()) { - const start_idx = d.idx; - const handshake_type = try d.decode(proto.Handshake); - const length = try d.decode(u24); - - // std.debug.print("handshake loop: {} {} {} {}\n", .{ handshake_type, length, d.payload.len, d.idx }); - if (length > cipher.max_cleartext_len) - return error.TlsUnsupportedFragmentedHandshakeMessage; - if (length > d.rest().len) - continue :outer; // fragmented handshake into multiple records - - defer { - const handshake_payload = d.payload[start_idx..d.idx]; - h.transcript.update(handshake_payload); - cleartext_buf_head += handshake_payload.len; - } - - brk: { - for (handshake_states) |state| - if (state == handshake_type) break :brk; - return error.TlsUnexpectedMessage; - } - switch (handshake_type) { - .encrypted_extensions => { - try d.skip(length); - handshake_states = if (h.cert.skip_verify) - &.{ .certificate_request, .certificate, .finished } - else - &.{ .certificate_request, .certificate }; - }, - .certificate_request => { - h.client_certificate_requested = true; - try d.skip(length); - handshake_states = if (h.cert.skip_verify) - &.{ .certificate, .finished } - else - &.{.certificate}; - }, - .certificate => { - try h.cert.parseCertificate(&d, h.tls_version); - handshake_states = &.{.certificate_verify}; - }, - .certificate_verify => { - try h.cert.parseCertificateVerify(&d); - try h.cert.verifySignature(h.transcript.serverCertificateVerify()); - handshake_states = &.{.finished}; - }, - .finished => { - const actual = try d.slice(length); - var buf: [Transcript.max_mac_length]u8 = undefined; - const expected = h.transcript.serverFinishedTls13(&buf); - if (!mem.eql(u8, expected, actual)) - return error.TlsDecryptError; - return; - }, - else => return error.TlsUnexpectedMessage, - } - } - cleartext_buf_head = 0; - cleartext_buf_tail = 0; - }, - else => return error.TlsUnexpectedMessage, - } - } - } - - fn verifyCertificateSignatureTls12(h: *HandshakeT) !void { - if (h.cipher_suite.keyExchange() != .ecdhe) return; - const verify_bytes = brk: { - var w = record.Writer{ .buf = h.buffer }; - try w.write(&h.client_random); - try w.write(&h.server_random); - try w.writeEnum(proto.Curve.named_curve); - try w.writeEnum(h.named_group.?); - try w.writeInt(@as(u8, @intCast(h.server_pub_key.len))); - try w.write(h.server_pub_key); - break :brk w.getWritten(); - }; - try h.cert.verifySignature(verify_bytes); - } - - /// Create client key exchange, change cipher spec and handshake - /// finished messages for tls 1.2. - /// If client certificate is requested also adds client certificate and - /// certificate verify messages. - fn makeClientFlight2Tls12(h: *HandshakeT, auth: ?CertKeyPair) ![]const u8 { - var w = record.Writer{ .buf = h.buffer }; - var cert_builder: ?CertificateBuilder = null; - - // Client certificate message - if (h.client_certificate_requested) { - if (auth) |a| { - const cb = h.certificateBuilder(a); - cert_builder = cb; - const client_certificate = try cb.makeCertificate(w.getPayload()); - h.transcript.update(client_certificate); - try w.advanceRecord(.handshake, client_certificate.len); - } else { - const empty_certificate = &record.handshakeHeader(.certificate, 3) ++ [_]u8{ 0, 0, 0 }; - h.transcript.update(empty_certificate); - try w.writeRecord(.handshake, empty_certificate); - } - } - - // Client key exchange message - { - const key_exchange = try h.makeClientKeyExchange(w.getPayload()); - h.transcript.update(key_exchange); - try w.advanceRecord(.handshake, key_exchange.len); - } - - // Client certificate verify message - if (cert_builder) |cb| { - const certificate_verify = try cb.makeCertificateVerify(w.getPayload()); - h.transcript.update(certificate_verify); - try w.advanceRecord(.handshake, certificate_verify.len); - } - - // Client change cipher spec message - try w.writeRecord(.change_cipher_spec, &[_]u8{1}); - - // Client handshake finished message - { - const client_finished = &record.handshakeHeader(.finished, 12) ++ - h.transcript.clientFinishedTls12(&h.master_secret); - h.transcript.update(client_finished); - try h.writeEncrypted(&w, client_finished); - } - - return w.getWritten(); - } - - /// Create client change cipher spec and handshake finished messages for - /// tls 1.3. - /// If the client certificate is requested by the server and client is - /// configured with certificates and private key then client certificate - /// and client certificate verify messages are also created. If the - /// server has requested certificate but the client is not configured - /// empty certificate message is sent, as is required by rfc. - fn makeClientFlight2Tls13(h: *HandshakeT, auth: ?CertKeyPair) ![]const u8 { - var w = record.Writer{ .buf = h.buffer }; - - // Client change cipher spec message - try w.writeRecord(.change_cipher_spec, &[_]u8{1}); - - if (h.client_certificate_requested) { - if (auth) |a| { - const cb = h.certificateBuilder(a); - { - const certificate = try cb.makeCertificate(w.getPayload()); - h.transcript.update(certificate); - try h.writeEncrypted(&w, certificate); - } - { - const certificate_verify = try cb.makeCertificateVerify(w.getPayload()); - h.transcript.update(certificate_verify); - try h.writeEncrypted(&w, certificate_verify); - } - } else { - // Empty certificate message and no certificate verify message - const empty_certificate = &record.handshakeHeader(.certificate, 4) ++ [_]u8{ 0, 0, 0, 0 }; - h.transcript.update(empty_certificate); - try h.writeEncrypted(&w, empty_certificate); - } - } - - // Client handshake finished message - { - const client_finished = try h.makeClientFinishedTls13(w.getPayload()); - h.transcript.update(client_finished); - try h.writeEncrypted(&w, client_finished); - } - - return w.getWritten(); - } - - fn certificateBuilder(h: *HandshakeT, auth: CertKeyPair) CertificateBuilder { - return .{ - .bundle = auth.bundle, - .key = auth.key, - .transcript = &h.transcript, - .tls_version = h.tls_version, - .side = .client, - }; - } - - fn makeClientFinishedTls13(h: *HandshakeT, buf: []u8) ![]const u8 { - var w = record.Writer{ .buf = buf }; - const verify_data = h.transcript.clientFinishedTls13(w.getHandshakePayload()); - try w.advanceHandshake(.finished, verify_data.len); - return w.getWritten(); - } - - fn makeClientKeyExchange(h: *HandshakeT, buf: []u8) ![]const u8 { - var w = record.Writer{ .buf = buf }; - if (h.named_group) |named_group| { - const key = try h.dh_kp.publicKey(named_group); - try w.writeHandshakeHeader(.client_key_exchange, 1 + key.len); - try w.writeInt(@as(u8, @intCast(key.len))); - try w.write(key); - } else { - const key = try h.rsa_secret.encrypted(h.cert.pub_key_algo, h.cert.pub_key); - try w.writeHandshakeHeader(.client_key_exchange, 2 + key.len); - try w.writeInt(@as(u16, @intCast(key.len))); - try w.write(key); - } - return w.getWritten(); - } - - fn readServerFlight2(h: *HandshakeT) !void { - // Read server change cipher spec message. - { - var d = try h.rec_rdr.nextDecoder(); - try d.expectContentType(.change_cipher_spec); - } - // Read encrypted server handshake finished message. Verify that - // content of the server finished message is based on transcript - // hash and master secret. - { - const content_type, const server_finished = - try h.rec_rdr.nextDecrypt(&h.cipher) orelse return error.EndOfStream; - if (content_type != .handshake) - return error.TlsUnexpectedMessage; - const expected = record.handshakeHeader(.finished, 12) ++ h.transcript.serverFinishedTls12(&h.master_secret); - if (!mem.eql(u8, server_finished, &expected)) - return error.TlsBadRecordMac; - } - } - - /// Write encrypted handshake message into `w` - fn writeEncrypted(h: *HandshakeT, w: *record.Writer, cleartext: []const u8) !void { - const ciphertext = try h.cipher.encrypt(w.getFree(), .handshake, cleartext); - w.pos += ciphertext.len; - } - - // Copy handshake parameters to opt.diagnostic - fn updateDiagnostic(h: *HandshakeT, opt: Options) void { - if (opt.diagnostic) |d| { - d.tls_version = h.tls_version; - d.cipher_suite_tag = h.cipher_suite; - d.named_group = h.named_group orelse @as(proto.NamedGroup, @enumFromInt(0x0000)); - d.signature_scheme = h.cert.signature_scheme; - if (opt.auth) |a| - d.client_signature_scheme = a.key.signature_scheme; - } - } - }; -} - -const RsaSecret = struct { - secret: [48]u8, - - fn init(rand: [46]u8) RsaSecret { - return .{ .secret = [_]u8{ 0x03, 0x03 } ++ rand }; - } - - // Pre master secret encrypted with certificate public key. - inline fn encrypted( - self: RsaSecret, - cert_pub_key_algo: Certificate.Parsed.PubKeyAlgo, - cert_pub_key: []const u8, - ) ![]const u8 { - if (cert_pub_key_algo != .rsaEncryption) return error.TlsBadSignatureScheme; - const pk = try rsa.PublicKey.fromDer(cert_pub_key); - var out: [512]u8 = undefined; - return try pk.encryptPkcsv1_5(&self.secret, &out); - } -}; - -const testing = std.testing; -const data12 = @import("testdata/tls12.zig"); -const data13 = @import("testdata/tls13.zig"); -const testu = @import("testu.zig"); - -fn testReader(data: []const u8) record.Reader(std.io.FixedBufferStream([]const u8)) { - return record.reader(std.io.fixedBufferStream(data)); -} -const TestHandshake = Handshake(std.io.FixedBufferStream([]const u8)); - -test "parse tls 1.2 server hello" { - var h = brk: { - var buffer: [1024]u8 = undefined; - var rec_rdr = testReader(&data12.server_hello_responses); - break :brk TestHandshake.init(&buffer, &rec_rdr); - }; - - // Set to known instead of random - h.client_random = data12.client_random; - h.dh_kp.x25519_kp.secret_key = data12.client_secret; - - // Parse server hello, certificate and key exchange messages. - // Read cipher suite, named group, signature scheme, server random certificate public key - // Verify host name, signature - // Calculate key material - h.cert = .{ .host = "example.ulfheim.net", .skip_verify = true, .root_ca = .{} }; - try h.readServerFlight1(); - try testing.expectEqual(.ECDHE_RSA_WITH_AES_128_CBC_SHA, h.cipher_suite); - try testing.expectEqual(.x25519, h.named_group.?); - try testing.expectEqual(.rsa_pkcs1_sha256, h.cert.signature_scheme); - try testing.expectEqualSlices(u8, &data12.server_random, &h.server_random); - try testing.expectEqualSlices(u8, &data12.server_pub_key, h.server_pub_key); - try testing.expectEqualSlices(u8, &data12.signature, h.cert.signature); - try testing.expectEqualSlices(u8, &data12.cert_pub_key, h.cert.pub_key); - - try h.verifyCertificateSignatureTls12(); - try h.generateKeyMaterial(null); - - try testing.expectEqualSlices(u8, &data12.key_material, h.key_material[0..data12.key_material.len]); -} - -test "verify google.com certificate" { - var h = brk: { - var buffer: [1024]u8 = undefined; - var rec_rdr = testReader(@embedFile("testdata/google.com/server_hello")); - break :brk TestHandshake.init(&buffer, &rec_rdr); - }; - h.client_random = @embedFile("testdata/google.com/client_random").*; - - var ca_bundle: Certificate.Bundle = .{}; - try ca_bundle.rescan(testing.allocator); - defer ca_bundle.deinit(testing.allocator); - - h.cert = .{ .host = "google.com", .skip_verify = true, .root_ca = .{}, .now_sec = 1714846451 }; - try h.readServerFlight1(); - try h.verifyCertificateSignatureTls12(); -} - -test "parse tls 1.3 server hello" { - var rec_rdr = testReader(&data13.server_hello); - var d = (try rec_rdr.nextDecoder()); - - const handshake_type = try d.decode(proto.Handshake); - const length = try d.decode(u24); - try testing.expectEqual(0x000076, length); - try testing.expectEqual(.server_hello, handshake_type); - - var h = TestHandshake.init(undefined, undefined); - try h.parseServerHello(&d, length); - - try testing.expectEqual(.AES_256_GCM_SHA384, h.cipher_suite); - try testing.expectEqualSlices(u8, &data13.server_random, &h.server_random); - try testing.expectEqual(.tls_1_3, h.tls_version); - try testing.expectEqual(.x25519, h.named_group); - try testing.expectEqualSlices(u8, &data13.server_pub_key, h.server_pub_key); -} - -test "init tls 1.3 handshake cipher" { - const cipher_suite_tag: CipherSuite = .AES_256_GCM_SHA384; - - var transcript = Transcript{}; - transcript.use(cipher_suite_tag.hash()); - transcript.update(data13.client_hello[record.header_len..]); - transcript.update(data13.server_hello[record.header_len..]); - - var dh_kp = DhKeyPair{ - .x25519_kp = .{ - .public_key = data13.client_public_key, - .secret_key = data13.client_private_key, - }, - }; - const shared_key = try dh_kp.sharedKey(.x25519, &data13.server_pub_key); - try testing.expectEqualSlices(u8, &data13.shared_key, shared_key); - - const cph = try Cipher.initTls13(cipher_suite_tag, transcript.handshakeSecret(shared_key), .client); - - const c = &cph.AES_256_GCM_SHA384; - try testing.expectEqualSlices(u8, &data13.server_handshake_key, &c.decrypt_key); - try testing.expectEqualSlices(u8, &data13.client_handshake_key, &c.encrypt_key); - try testing.expectEqualSlices(u8, &data13.server_handshake_iv, &c.decrypt_iv); - try testing.expectEqualSlices(u8, &data13.client_handshake_iv, &c.encrypt_iv); -} - -fn initExampleHandshake(h: *TestHandshake) !void { - h.cipher_suite = .AES_256_GCM_SHA384; - h.transcript.use(h.cipher_suite.hash()); - h.transcript.update(data13.client_hello[record.header_len..]); - h.transcript.update(data13.server_hello[record.header_len..]); - h.cipher = try Cipher.initTls13(h.cipher_suite, h.transcript.handshakeSecret(&data13.shared_key), .client); - h.tls_version = .tls_1_3; - h.cert.now_sec = 1714846451; - h.server_pub_key = &data13.server_pub_key; -} - -test "tls 1.3 decrypt wrapped record" { - var cph = brk: { - var h = TestHandshake.init(undefined, undefined); - try initExampleHandshake(&h); - break :brk h.cipher; - }; - - var cleartext_buf: [1024]u8 = undefined; - { - const rec = record.Record.init(&data13.server_encrypted_extensions_wrapped); - - const content_type, const cleartext = try cph.decrypt(&cleartext_buf, rec); - try testing.expectEqual(.handshake, content_type); - try testing.expectEqualSlices(u8, &data13.server_encrypted_extensions, cleartext); - } - { - const rec = record.Record.init(&data13.server_certificate_wrapped); - - const content_type, const cleartext = try cph.decrypt(&cleartext_buf, rec); - try testing.expectEqual(.handshake, content_type); - try testing.expectEqualSlices(u8, &data13.server_certificate, cleartext); - } -} - -test "tls 1.3 process server flight" { - var buffer: [1024]u8 = undefined; - var h = brk: { - var rec_rdr = testReader(&data13.server_flight); - break :brk TestHandshake.init(&buffer, &rec_rdr); - }; - - try initExampleHandshake(&h); - h.cert = .{ .host = "example.ulfheim.net", .skip_verify = true, .root_ca = .{} }; - try h.readEncryptedServerFlight1(); - - { // application cipher keys calculation - try testing.expectEqualSlices(u8, &data13.handshake_hash, &h.transcript.sha384.hash.peek()); - - var cph = try Cipher.initTls13(h.cipher_suite, h.transcript.applicationSecret(), .client); - const c = &cph.AES_256_GCM_SHA384; - try testing.expectEqualSlices(u8, &data13.server_application_key, &c.decrypt_key); - try testing.expectEqualSlices(u8, &data13.client_application_key, &c.encrypt_key); - try testing.expectEqualSlices(u8, &data13.server_application_iv, &c.decrypt_iv); - try testing.expectEqualSlices(u8, &data13.client_application_iv, &c.encrypt_iv); - - const encrypted = try cph.encrypt(&buffer, .application_data, "ping"); - try testing.expectEqualSlices(u8, &data13.client_ping_wrapped, encrypted); - } - { // client finished message - var buf: [4 + Transcript.max_mac_length]u8 = undefined; - const client_finished = try h.makeClientFinishedTls13(&buf); - try testing.expectEqualSlices(u8, &data13.client_finished_verify_data, client_finished[4..]); - const encrypted = try h.cipher.encrypt(&buffer, .handshake, client_finished); - try testing.expectEqualSlices(u8, &data13.client_finished_wrapped, encrypted); - } -} - -test "create client hello" { - var h = brk: { - var buffer: [1024]u8 = undefined; - var h = TestHandshake.init(&buffer, undefined); - h.client_random = testu.hexToBytes( - \\ 00 01 02 03 04 05 06 07 08 09 0a 0b 0c 0d 0e 0f 10 11 12 13 14 15 16 17 18 19 1a 1b 1c 1d 1e 1f - ); - break :brk h; - }; - - const actual = try h.makeClientHello(.{ - .host = "google.com", - .root_ca = .{}, - .cipher_suites = &[_]CipherSuite{CipherSuite.ECDHE_ECDSA_WITH_AES_128_GCM_SHA256}, - .named_groups = &[_]proto.NamedGroup{ .x25519, .secp256r1, .secp384r1 }, - }); - - const expected = testu.hexToBytes( - "16 03 03 00 6d " ++ // record header - "01 00 00 69 " ++ // handshake header - "03 03 " ++ // protocol version - "00 01 02 03 04 05 06 07 08 09 0a 0b 0c 0d 0e 0f 10 11 12 13 14 15 16 17 18 19 1a 1b 1c 1d 1e 1f " ++ // client random - "00 " ++ // no session id - "00 02 c0 2b " ++ // cipher suites - "01 00 " ++ // compression methods - "00 3e " ++ // extensions length - "00 2b 00 03 02 03 03 " ++ // supported versions extension - "00 0d 00 14 00 12 04 03 05 03 08 04 08 05 08 06 08 07 02 01 04 01 05 01 " ++ // signature algorithms extension - "00 0a 00 08 00 06 00 1d 00 17 00 18 " ++ // named groups extension - "00 00 00 0f 00 0d 00 00 0a 67 6f 6f 67 6c 65 2e 63 6f 6d ", // server name extension - ); - try testing.expectEqualSlices(u8, &expected, actual); -} - -test "handshake verify server finished message" { - var buffer: [1024]u8 = undefined; - var rec_rdr = testReader(&data12.server_handshake_finished_msgs); - var h = TestHandshake.init(&buffer, &rec_rdr); - - h.cipher_suite = .ECDHE_ECDSA_WITH_AES_128_CBC_SHA; - h.master_secret = data12.master_secret; - - // add handshake messages to the transcript - for (data12.handshake_messages) |msg| { - h.transcript.update(msg[record.header_len..]); - } - - // expect verify data - const client_finished = h.transcript.clientFinishedTls12(&h.master_secret); - try testing.expectEqualSlices(u8, &data12.client_finished, &record.handshakeHeader(.finished, 12) ++ client_finished); - - // init client with prepared key_material - h.cipher = try Cipher.initTls12(.ECDHE_RSA_WITH_AES_128_CBC_SHA, &data12.key_material, .client); - - // check that server verify data matches calculates from hashes of all handshake messages - h.transcript.update(&data12.client_finished); - try h.readServerFlight2(); -} diff --git a/src/http/async/tls.zig/handshake_common.zig b/src/http/async/tls.zig/handshake_common.zig deleted file mode 100644 index 178a3cea..00000000 --- a/src/http/async/tls.zig/handshake_common.zig +++ /dev/null @@ -1,448 +0,0 @@ -const std = @import("std"); -const assert = std.debug.assert; -const mem = std.mem; -const crypto = std.crypto; -const Certificate = crypto.Certificate; - -const Transcript = @import("transcript.zig").Transcript; -const PrivateKey = @import("PrivateKey.zig"); -const record = @import("record.zig"); -const rsa = @import("rsa/rsa.zig"); -const proto = @import("protocol.zig"); - -const X25519 = crypto.dh.X25519; -const EcdsaP256Sha256 = crypto.sign.ecdsa.EcdsaP256Sha256; -const EcdsaP384Sha384 = crypto.sign.ecdsa.EcdsaP384Sha384; -const Kyber768 = crypto.kem.kyber_d00.Kyber768; - -pub const supported_signature_algorithms = &[_]proto.SignatureScheme{ - .ecdsa_secp256r1_sha256, - .ecdsa_secp384r1_sha384, - .rsa_pss_rsae_sha256, - .rsa_pss_rsae_sha384, - .rsa_pss_rsae_sha512, - .ed25519, - .rsa_pkcs1_sha1, - .rsa_pkcs1_sha256, - .rsa_pkcs1_sha384, -}; - -pub const CertKeyPair = struct { - /// A chain of one or more certificates, leaf first. - /// - /// Each X.509 certificate contains the public key of a key pair, extra - /// information (the name of the holder, the name of an issuer of the - /// certificate, validity time spans) and a signature generated using the - /// private key of the issuer of the certificate. - /// - /// All certificates from the bundle are sent to the other side when creating - /// Certificate tls message. - /// - /// Leaf certificate and private key are used to create signature for - /// CertifyVerify tls message. - bundle: Certificate.Bundle, - - /// Private key corresponding to the public key in leaf certificate from the - /// bundle. - key: PrivateKey, - - pub fn load( - allocator: std.mem.Allocator, - dir: std.fs.Dir, - cert_path: []const u8, - key_path: []const u8, - ) !CertKeyPair { - var bundle: Certificate.Bundle = .{}; - try bundle.addCertsFromFilePath(allocator, dir, cert_path); - - const key_file = try dir.openFile(key_path, .{}); - defer key_file.close(); - const key = try PrivateKey.fromFile(allocator, key_file); - - return .{ .bundle = bundle, .key = key }; - } - - pub fn deinit(c: *CertKeyPair, allocator: std.mem.Allocator) void { - c.bundle.deinit(allocator); - } -}; - -pub const CertBundle = struct { - // A chain of one or more certificates. - // - // They are used to verify that certificate chain sent by the other side - // forms valid trust chain. - bundle: Certificate.Bundle = .{}, - - pub fn fromFile(allocator: std.mem.Allocator, dir: std.fs.Dir, path: []const u8) !CertBundle { - var bundle: Certificate.Bundle = .{}; - try bundle.addCertsFromFilePath(allocator, dir, path); - return .{ .bundle = bundle }; - } - - pub fn fromSystem(allocator: std.mem.Allocator) !CertBundle { - var bundle: Certificate.Bundle = .{}; - try bundle.rescan(allocator); - return .{ .bundle = bundle }; - } - - pub fn deinit(cb: *CertBundle, allocator: std.mem.Allocator) void { - cb.bundle.deinit(allocator); - } -}; - -pub const CertificateBuilder = struct { - bundle: Certificate.Bundle, - key: PrivateKey, - transcript: *Transcript, - tls_version: proto.Version = .tls_1_3, - side: proto.Side = .client, - - pub fn makeCertificate(h: CertificateBuilder, buf: []u8) ![]const u8 { - var w = record.Writer{ .buf = buf }; - const certs = h.bundle.bytes.items; - const certs_count = h.bundle.map.size; - - // Differences between tls 1.3 and 1.2 - // TLS 1.3 has request context in header and extensions for each certificate. - // Here we use empty length for each field. - // TLS 1.2 don't have these two fields. - const request_context, const extensions = if (h.tls_version == .tls_1_3) - .{ &[_]u8{0}, &[_]u8{ 0, 0 } } - else - .{ &[_]u8{}, &[_]u8{} }; - const certs_len = certs.len + (3 + extensions.len) * certs_count; - - // Write handshake header - try w.writeHandshakeHeader(.certificate, certs_len + request_context.len + 3); - try w.write(request_context); - try w.writeInt(@as(u24, @intCast(certs_len))); - - // Write each certificate - var index: u32 = 0; - while (index < certs.len) { - const e = try Certificate.der.Element.parse(certs, index); - const cert = certs[index..e.slice.end]; - try w.writeInt(@as(u24, @intCast(cert.len))); // certificate length - try w.write(cert); // certificate - try w.write(extensions); // certificate extensions - index = e.slice.end; - } - return w.getWritten(); - } - - pub fn makeCertificateVerify(h: CertificateBuilder, buf: []u8) ![]const u8 { - var w = record.Writer{ .buf = buf }; - const signature, const signature_scheme = try h.createSignature(); - try w.writeHandshakeHeader(.certificate_verify, signature.len + 4); - try w.writeEnum(signature_scheme); - try w.writeInt(@as(u16, @intCast(signature.len))); - try w.write(signature); - return w.getWritten(); - } - - /// Creates signature for client certificate signature message. - /// Returns signature bytes and signature scheme. - inline fn createSignature(h: CertificateBuilder) !struct { []const u8, proto.SignatureScheme } { - switch (h.key.signature_scheme) { - inline .ecdsa_secp256r1_sha256, - .ecdsa_secp384r1_sha384, - => |comptime_scheme| { - const Ecdsa = SchemeEcdsa(comptime_scheme); - const key = h.key.key.ecdsa; - const key_len = Ecdsa.SecretKey.encoded_length; - if (key.len < key_len) return error.InvalidEncoding; - const secret_key = try Ecdsa.SecretKey.fromBytes(key[0..key_len].*); - const key_pair = try Ecdsa.KeyPair.fromSecretKey(secret_key); - var signer = try key_pair.signer(null); - h.setSignatureVerifyBytes(&signer); - const signature = try signer.finalize(); - var buf: [Ecdsa.Signature.der_encoded_length_max]u8 = undefined; - return .{ signature.toDer(&buf), comptime_scheme }; - }, - inline .rsa_pss_rsae_sha256, - .rsa_pss_rsae_sha384, - .rsa_pss_rsae_sha512, - => |comptime_scheme| { - const Hash = SchemeHash(comptime_scheme); - var signer = try h.key.key.rsa.signerOaep(Hash, null); - h.setSignatureVerifyBytes(&signer); - var buf: [512]u8 = undefined; - const signature = try signer.finalize(&buf); - return .{ signature.bytes, comptime_scheme }; - }, - else => return error.TlsUnknownSignatureScheme, - } - } - - fn setSignatureVerifyBytes(h: CertificateBuilder, signer: anytype) void { - if (h.tls_version == .tls_1_2) { - // tls 1.2 signature uses current transcript hash value. - // ref: https://datatracker.ietf.org/doc/html/rfc5246.html#section-7.4.8 - const Hash = @TypeOf(signer.h); - signer.h = h.transcript.hash(Hash); - } else { - // tls 1.3 signature is computed over concatenation of 64 spaces, - // context, separator and content. - // ref: https://datatracker.ietf.org/doc/html/rfc8446#section-4.4.3 - if (h.side == .server) { - signer.update(h.transcript.serverCertificateVerify()); - } else { - signer.update(h.transcript.clientCertificateVerify()); - } - } - } - - fn SchemeEcdsa(comptime scheme: proto.SignatureScheme) type { - return switch (scheme) { - .ecdsa_secp256r1_sha256 => EcdsaP256Sha256, - .ecdsa_secp384r1_sha384 => EcdsaP384Sha384, - else => unreachable, - }; - } -}; - -pub const CertificateParser = struct { - pub_key_algo: Certificate.Parsed.PubKeyAlgo = undefined, - pub_key_buf: [600]u8 = undefined, - pub_key: []const u8 = undefined, - - signature_scheme: proto.SignatureScheme = @enumFromInt(0), - signature_buf: [1024]u8 = undefined, - signature: []const u8 = undefined, - - root_ca: Certificate.Bundle, - host: []const u8, - skip_verify: bool = false, - now_sec: i64 = 0, - - pub fn parseCertificate(h: *CertificateParser, d: *record.Decoder, tls_version: proto.Version) !void { - if (h.now_sec == 0) { - h.now_sec = std.time.timestamp(); - } - if (tls_version == .tls_1_3) { - const request_context = try d.decode(u8); - if (request_context != 0) return error.TlsIllegalParameter; - } - - var trust_chain_established = false; - var last_cert: ?Certificate.Parsed = null; - const certs_len = try d.decode(u24); - const start_idx = d.idx; - while (d.idx - start_idx < certs_len) { - const cert_len = try d.decode(u24); - // std.debug.print("=> {} {} {} {}\n", .{ certs_len, d.idx, cert_len, d.payload.len }); - const cert = try d.slice(cert_len); - if (tls_version == .tls_1_3) { - // certificate extensions present in tls 1.3 - try d.skip(try d.decode(u16)); - } - if (trust_chain_established) - continue; - - const subject = try (Certificate{ .buffer = cert, .index = 0 }).parse(); - if (last_cert) |pc| { - if (pc.verify(subject, h.now_sec)) { - last_cert = subject; - } else |err| switch (err) { - error.CertificateIssuerMismatch => { - // skip certificate which is not part of the chain - continue; - }, - else => return err, - } - } else { // first certificate - if (!h.skip_verify and h.host.len > 0) { - try subject.verifyHostName(h.host); - } - h.pub_key = dupe(&h.pub_key_buf, subject.pubKey()); - h.pub_key_algo = subject.pub_key_algo; - last_cert = subject; - } - if (!h.skip_verify) { - if (h.root_ca.verify(last_cert.?, h.now_sec)) |_| { - trust_chain_established = true; - } else |err| switch (err) { - error.CertificateIssuerNotFound => {}, - else => return err, - } - } - } - if (!h.skip_verify and !trust_chain_established) { - return error.CertificateIssuerNotFound; - } - } - - pub fn parseCertificateVerify(h: *CertificateParser, d: *record.Decoder) !void { - h.signature_scheme = try d.decode(proto.SignatureScheme); - h.signature = dupe(&h.signature_buf, try d.slice(try d.decode(u16))); - } - - pub fn verifySignature(h: *CertificateParser, verify_bytes: []const u8) !void { - switch (h.signature_scheme) { - inline .ecdsa_secp256r1_sha256, - .ecdsa_secp384r1_sha384, - => |comptime_scheme| { - if (h.pub_key_algo != .X9_62_id_ecPublicKey) return error.TlsBadSignatureScheme; - const cert_named_curve = h.pub_key_algo.X9_62_id_ecPublicKey; - switch (cert_named_curve) { - inline .secp384r1, .X9_62_prime256v1 => |comptime_cert_named_curve| { - const Ecdsa = SchemeEcdsaCert(comptime_scheme, comptime_cert_named_curve); - const key = try Ecdsa.PublicKey.fromSec1(h.pub_key); - const sig = try Ecdsa.Signature.fromDer(h.signature); - try sig.verify(verify_bytes, key); - }, - else => return error.TlsUnknownSignatureScheme, - } - }, - .ed25519 => { - if (h.pub_key_algo != .curveEd25519) return error.TlsBadSignatureScheme; - const Eddsa = crypto.sign.Ed25519; - if (h.signature.len != Eddsa.Signature.encoded_length) return error.InvalidEncoding; - const sig = Eddsa.Signature.fromBytes(h.signature[0..Eddsa.Signature.encoded_length].*); - if (h.pub_key.len != Eddsa.PublicKey.encoded_length) return error.InvalidEncoding; - const key = try Eddsa.PublicKey.fromBytes(h.pub_key[0..Eddsa.PublicKey.encoded_length].*); - try sig.verify(verify_bytes, key); - }, - inline .rsa_pss_rsae_sha256, - .rsa_pss_rsae_sha384, - .rsa_pss_rsae_sha512, - => |comptime_scheme| { - if (h.pub_key_algo != .rsaEncryption) return error.TlsBadSignatureScheme; - const Hash = SchemeHash(comptime_scheme); - const pk = try rsa.PublicKey.fromDer(h.pub_key); - const sig = rsa.Pss(Hash).Signature{ .bytes = h.signature }; - try sig.verify(verify_bytes, pk, null); - }, - inline .rsa_pkcs1_sha1, - .rsa_pkcs1_sha256, - .rsa_pkcs1_sha384, - .rsa_pkcs1_sha512, - => |comptime_scheme| { - if (h.pub_key_algo != .rsaEncryption) return error.TlsBadSignatureScheme; - const Hash = SchemeHash(comptime_scheme); - const pk = try rsa.PublicKey.fromDer(h.pub_key); - const sig = rsa.PKCS1v1_5(Hash).Signature{ .bytes = h.signature }; - try sig.verify(verify_bytes, pk); - }, - else => return error.TlsUnknownSignatureScheme, - } - } - - fn SchemeEcdsaCert(comptime scheme: proto.SignatureScheme, comptime cert_named_curve: Certificate.NamedCurve) type { - const Sha256 = crypto.hash.sha2.Sha256; - const Sha384 = crypto.hash.sha2.Sha384; - const Ecdsa = crypto.sign.ecdsa.Ecdsa; - - return switch (scheme) { - .ecdsa_secp256r1_sha256 => Ecdsa(cert_named_curve.Curve(), Sha256), - .ecdsa_secp384r1_sha384 => Ecdsa(cert_named_curve.Curve(), Sha384), - else => @compileError("bad scheme"), - }; - } -}; - -fn SchemeHash(comptime scheme: proto.SignatureScheme) type { - const Sha256 = crypto.hash.sha2.Sha256; - const Sha384 = crypto.hash.sha2.Sha384; - const Sha512 = crypto.hash.sha2.Sha512; - - return switch (scheme) { - .rsa_pkcs1_sha1 => crypto.hash.Sha1, - .rsa_pss_rsae_sha256, .rsa_pkcs1_sha256 => Sha256, - .rsa_pss_rsae_sha384, .rsa_pkcs1_sha384 => Sha384, - .rsa_pss_rsae_sha512, .rsa_pkcs1_sha512 => Sha512, - else => @compileError("bad scheme"), - }; -} - -pub fn dupe(buf: []u8, data: []const u8) []u8 { - const n = @min(data.len, buf.len); - @memcpy(buf[0..n], data[0..n]); - return buf[0..n]; -} - -pub const DhKeyPair = struct { - x25519_kp: X25519.KeyPair = undefined, - secp256r1_kp: EcdsaP256Sha256.KeyPair = undefined, - secp384r1_kp: EcdsaP384Sha384.KeyPair = undefined, - kyber768_kp: Kyber768.KeyPair = undefined, - - pub const seed_len = 32 + 32 + 48 + 64; - - pub fn init(seed: [seed_len]u8, named_groups: []const proto.NamedGroup) !DhKeyPair { - var kp: DhKeyPair = .{}; - for (named_groups) |ng| - switch (ng) { - .x25519 => kp.x25519_kp = try X25519.KeyPair.create(seed[0..][0..X25519.seed_length].*), - .secp256r1 => kp.secp256r1_kp = try EcdsaP256Sha256.KeyPair.create(seed[32..][0..EcdsaP256Sha256.KeyPair.seed_length].*), - .secp384r1 => kp.secp384r1_kp = try EcdsaP384Sha384.KeyPair.create(seed[32 + 32 ..][0..EcdsaP384Sha384.KeyPair.seed_length].*), - .x25519_kyber768d00 => kp.kyber768_kp = try Kyber768.KeyPair.create(seed[32 + 32 + 48 ..][0..Kyber768.seed_length].*), - else => return error.TlsIllegalParameter, - }; - return kp; - } - - pub inline fn sharedKey(self: DhKeyPair, named_group: proto.NamedGroup, server_pub_key: []const u8) ![]const u8 { - return switch (named_group) { - .x25519 => brk: { - if (server_pub_key.len != X25519.public_length) - return error.TlsIllegalParameter; - break :brk &(try X25519.scalarmult( - self.x25519_kp.secret_key, - server_pub_key[0..X25519.public_length].*, - )); - }, - .secp256r1 => brk: { - const pk = try EcdsaP256Sha256.PublicKey.fromSec1(server_pub_key); - const mul = try pk.p.mulPublic(self.secp256r1_kp.secret_key.bytes, .big); - break :brk &mul.affineCoordinates().x.toBytes(.big); - }, - .secp384r1 => brk: { - const pk = try EcdsaP384Sha384.PublicKey.fromSec1(server_pub_key); - const mul = try pk.p.mulPublic(self.secp384r1_kp.secret_key.bytes, .big); - break :brk &mul.affineCoordinates().x.toBytes(.big); - }, - .x25519_kyber768d00 => brk: { - const xksl = crypto.dh.X25519.public_length; - const hksl = xksl + Kyber768.ciphertext_length; - if (server_pub_key.len != hksl) - return error.TlsIllegalParameter; - - break :brk &((crypto.dh.X25519.scalarmult( - self.x25519_kp.secret_key, - server_pub_key[0..xksl].*, - ) catch return error.TlsDecryptFailure) ++ (self.kyber768_kp.secret_key.decaps( - server_pub_key[xksl..hksl], - ) catch return error.TlsDecryptFailure)); - }, - else => return error.TlsIllegalParameter, - }; - } - - // Returns 32, 65, 97 or 1216 bytes - pub inline fn publicKey(self: DhKeyPair, named_group: proto.NamedGroup) ![]const u8 { - return switch (named_group) { - .x25519 => &self.x25519_kp.public_key, - .secp256r1 => &self.secp256r1_kp.public_key.toUncompressedSec1(), - .secp384r1 => &self.secp384r1_kp.public_key.toUncompressedSec1(), - .x25519_kyber768d00 => &self.x25519_kp.public_key ++ self.kyber768_kp.public_key.toBytes(), - else => return error.TlsIllegalParameter, - }; - } -}; - -const testing = std.testing; -const testu = @import("testu.zig"); - -test "DhKeyPair.x25519" { - var seed: [DhKeyPair.seed_len]u8 = undefined; - testu.fill(&seed); - const server_pub_key = &testu.hexToBytes("3303486548531f08d91e675caf666c2dc924ac16f47a861a7f4d05919d143637"); - const expected = &testu.hexToBytes( - \\ F1 67 FB 4A 49 B2 91 77 08 29 45 A1 F7 08 5A 21 - \\ AF FE 9E 78 C2 03 9B 81 92 40 72 73 74 7A 46 1E - ); - const kp = try DhKeyPair.init(seed, &.{.x25519}); - try testing.expectEqualSlices(u8, expected, try kp.sharedKey(.x25519, server_pub_key)); -} diff --git a/src/http/async/tls.zig/handshake_server.zig b/src/http/async/tls.zig/handshake_server.zig deleted file mode 100644 index c26e8c69..00000000 --- a/src/http/async/tls.zig/handshake_server.zig +++ /dev/null @@ -1,520 +0,0 @@ -const std = @import("std"); -const assert = std.debug.assert; -const crypto = std.crypto; -const mem = std.mem; -const Certificate = crypto.Certificate; - -const cipher = @import("cipher.zig"); -const Cipher = cipher.Cipher; -const CipherSuite = @import("cipher.zig").CipherSuite; -const cipher_suites = @import("cipher.zig").cipher_suites; -const Transcript = @import("transcript.zig").Transcript; -const record = @import("record.zig"); -const PrivateKey = @import("PrivateKey.zig"); -const proto = @import("protocol.zig"); - -const common = @import("handshake_common.zig"); -const dupe = common.dupe; -const CertificateBuilder = common.CertificateBuilder; -const CertificateParser = common.CertificateParser; -const DhKeyPair = common.DhKeyPair; -const CertBundle = common.CertBundle; -const CertKeyPair = common.CertKeyPair; - -pub const Options = struct { - /// Server authentication. If null server will not send Certificate and - /// CertificateVerify message. - auth: ?CertKeyPair, - - /// If not null server will request client certificate. If auth_type is - /// .request empty client certificate message will be accepted. - /// Client certificate will be verified with root_ca certificates. - client_auth: ?ClientAuth = null, -}; - -pub const ClientAuth = struct { - /// Set of root certificate authorities that server use when verifying - /// client certificates. - root_ca: CertBundle, - - auth_type: Type = .require, - - pub const Type = enum { - /// Client certificate will be requested during the handshake, but does - /// not require that the client send any certificates. - request, - /// Client certificate will be requested during the handshake, and client - /// has to send valid certificate. - require, - }; -}; - -pub fn Handshake(comptime Stream: type) type { - const RecordReaderT = record.Reader(Stream); - return struct { - // public key len: x25519 = 32, secp256r1 = 65, secp384r1 = 97 - const max_pub_key_len = 98; - const supported_named_groups = &[_]proto.NamedGroup{ .x25519, .secp256r1, .secp384r1 }; - - server_random: [32]u8 = undefined, - client_random: [32]u8 = undefined, - legacy_session_id_buf: [32]u8 = undefined, - legacy_session_id: []u8 = "", - cipher_suite: CipherSuite = @enumFromInt(0), - signature_scheme: proto.SignatureScheme = @enumFromInt(0), - named_group: proto.NamedGroup = @enumFromInt(0), - client_pub_key_buf: [max_pub_key_len]u8 = undefined, - client_pub_key: []u8 = "", - server_pub_key_buf: [max_pub_key_len]u8 = undefined, - server_pub_key: []u8 = "", - - cipher: Cipher = undefined, - transcript: Transcript = .{}, - rec_rdr: *RecordReaderT, - buffer: []u8, - - const HandshakeT = @This(); - - pub fn init(buf: []u8, rec_rdr: *RecordReaderT) HandshakeT { - return .{ - .rec_rdr = rec_rdr, - .buffer = buf, - }; - } - - fn writeAlert(h: *HandshakeT, stream: Stream, cph: ?*Cipher, err: anyerror) !void { - if (cph) |c| { - const cleartext = proto.alertFromError(err); - const ciphertext = try c.encrypt(h.buffer, .alert, &cleartext); - stream.writeAll(ciphertext) catch {}; - } else { - const alert = record.header(.alert, 2) ++ proto.alertFromError(err); - stream.writeAll(&alert) catch {}; - } - } - - pub fn handshake(h: *HandshakeT, stream: Stream, opt: Options) !Cipher { - crypto.random.bytes(&h.server_random); - if (opt.auth) |a| { - // required signature scheme in client hello - h.signature_scheme = a.key.signature_scheme; - } - - h.readClientHello() catch |err| { - try h.writeAlert(stream, null, err); - return err; - }; - h.transcript.use(h.cipher_suite.hash()); - - const server_flight = brk: { - var w = record.Writer{ .buf = h.buffer }; - - const shared_key = h.sharedKey() catch |err| { - try h.writeAlert(stream, null, err); - return err; - }; - { - const hello = try h.makeServerHello(w.getFree()); - h.transcript.update(hello[record.header_len..]); - w.pos += hello.len; - } - { - const handshake_secret = h.transcript.handshakeSecret(shared_key); - h.cipher = try Cipher.initTls13(h.cipher_suite, handshake_secret, .server); - } - try w.writeRecord(.change_cipher_spec, &[_]u8{1}); - { - const encrypted_extensions = &record.handshakeHeader(.encrypted_extensions, 2) ++ [_]u8{ 0, 0 }; - h.transcript.update(encrypted_extensions); - try h.writeEncrypted(&w, encrypted_extensions); - } - if (opt.client_auth) |_| { - const certificate_request = try makeCertificateRequest(w.getPayload()); - h.transcript.update(certificate_request); - try h.writeEncrypted(&w, certificate_request); - } - if (opt.auth) |a| { - const cm = CertificateBuilder{ - .bundle = a.bundle, - .key = a.key, - .transcript = &h.transcript, - .side = .server, - }; - { - const certificate = try cm.makeCertificate(w.getPayload()); - h.transcript.update(certificate); - try h.writeEncrypted(&w, certificate); - } - { - const certificate_verify = try cm.makeCertificateVerify(w.getPayload()); - h.transcript.update(certificate_verify); - try h.writeEncrypted(&w, certificate_verify); - } - } - { - const finished = try h.makeFinished(w.getPayload()); - h.transcript.update(finished); - try h.writeEncrypted(&w, finished); - } - break :brk w.getWritten(); - }; - try stream.writeAll(server_flight); - - var app_cipher = brk: { - const application_secret = h.transcript.applicationSecret(); - break :brk try Cipher.initTls13(h.cipher_suite, application_secret, .server); - }; - - h.readClientFlight2(opt) catch |err| { - // Alert received from client - if (!mem.startsWith(u8, @errorName(err), "TlsAlert")) { - try h.writeAlert(stream, &app_cipher, err); - } - return err; - }; - return app_cipher; - } - - inline fn sharedKey(h: *HandshakeT) ![]const u8 { - var seed: [DhKeyPair.seed_len]u8 = undefined; - crypto.random.bytes(&seed); - var kp = try DhKeyPair.init(seed, supported_named_groups); - h.server_pub_key = dupe(&h.server_pub_key_buf, try kp.publicKey(h.named_group)); - return try kp.sharedKey(h.named_group, h.client_pub_key); - } - - fn readClientFlight2(h: *HandshakeT, opt: Options) !void { - var cleartext_buf = h.buffer; - var cleartext_buf_head: usize = 0; - var cleartext_buf_tail: usize = 0; - var handshake_state: proto.Handshake = .finished; - var cert: CertificateParser = undefined; - if (opt.client_auth) |client_auth| { - cert = .{ .root_ca = client_auth.root_ca.bundle, .host = "" }; - handshake_state = .certificate; - } - - outer: while (true) { - const rec = (try h.rec_rdr.next() orelse return error.EndOfStream); - if (rec.protocol_version != .tls_1_2 and rec.content_type != .alert) - return error.TlsProtocolVersion; - - switch (rec.content_type) { - .change_cipher_spec => { - if (rec.payload.len != 1) return error.TlsUnexpectedMessage; - }, - .application_data => { - const content_type, const cleartext = try h.cipher.decrypt( - cleartext_buf[cleartext_buf_tail..], - rec, - ); - cleartext_buf_tail += cleartext.len; - if (cleartext_buf_tail > cleartext_buf.len) return error.TlsRecordOverflow; - - var d = record.Decoder.init(content_type, cleartext_buf[cleartext_buf_head..cleartext_buf_tail]); - try d.expectContentType(.handshake); - while (!d.eof()) { - const start_idx = d.idx; - const handshake_type = try d.decode(proto.Handshake); - const length = try d.decode(u24); - - if (length > cipher.max_cleartext_len) - return error.TlsRecordOverflow; - if (length > d.rest().len) - continue :outer; // fragmented handshake into multiple records - - defer { - const handshake_payload = d.payload[start_idx..d.idx]; - h.transcript.update(handshake_payload); - cleartext_buf_head += handshake_payload.len; - } - - if (handshake_state != handshake_type) - return error.TlsUnexpectedMessage; - - switch (handshake_type) { - .certificate => { - if (length == 4) { - // got empty certificate message - if (opt.client_auth.?.auth_type == .require) - return error.TlsCertificateRequired; - try d.skip(length); - handshake_state = .finished; - } else { - try cert.parseCertificate(&d, .tls_1_3); - handshake_state = .certificate_verify; - } - }, - .certificate_verify => { - try cert.parseCertificateVerify(&d); - cert.verifySignature(h.transcript.clientCertificateVerify()) catch |err| return switch (err) { - error.TlsUnknownSignatureScheme => error.TlsIllegalParameter, - else => error.TlsDecryptError, - }; - handshake_state = .finished; - }, - .finished => { - const actual = try d.slice(length); - var buf: [Transcript.max_mac_length]u8 = undefined; - const expected = h.transcript.clientFinishedTls13(&buf); - if (!mem.eql(u8, expected, actual)) - return if (expected.len == actual.len) - error.TlsDecryptError - else - error.TlsDecodeError; - return; - }, - else => return error.TlsUnexpectedMessage, - } - } - cleartext_buf_head = 0; - cleartext_buf_tail = 0; - }, - .alert => { - var d = rec.decoder(); - return d.raiseAlert(); - }, - else => return error.TlsUnexpectedMessage, - } - } - } - - fn makeFinished(h: *HandshakeT, buf: []u8) ![]const u8 { - var w = record.Writer{ .buf = buf }; - const verify_data = h.transcript.serverFinishedTls13(w.getHandshakePayload()); - try w.advanceHandshake(.finished, verify_data.len); - return w.getWritten(); - } - - /// Write encrypted handshake message into `w` - fn writeEncrypted(h: *HandshakeT, w: *record.Writer, cleartext: []const u8) !void { - const ciphertext = try h.cipher.encrypt(w.getFree(), .handshake, cleartext); - w.pos += ciphertext.len; - } - - fn makeServerHello(h: *HandshakeT, buf: []u8) ![]const u8 { - const header_len = 9; // tls record header (5 bytes) and handshake header (4 bytes) - var w = record.Writer{ .buf = buf[header_len..] }; - - try w.writeEnum(proto.Version.tls_1_2); - try w.write(&h.server_random); - { - try w.writeInt(@as(u8, @intCast(h.legacy_session_id.len))); - if (h.legacy_session_id.len > 0) try w.write(h.legacy_session_id); - } - try w.writeEnum(h.cipher_suite); - try w.write(&[_]u8{0}); // compression method - - var e = record.Writer{ .buf = buf[header_len + w.pos + 2 ..] }; - { // supported versions extension - try e.writeEnum(proto.Extension.supported_versions); - try e.writeInt(@as(u16, 2)); - try e.writeEnum(proto.Version.tls_1_3); - } - { // key share extension - const key_len: u16 = @intCast(h.server_pub_key.len); - try e.writeEnum(proto.Extension.key_share); - try e.writeInt(key_len + 4); - try e.writeEnum(h.named_group); - try e.writeInt(key_len); - try e.write(h.server_pub_key); - } - try w.writeInt(@as(u16, @intCast(e.pos))); // extensions length - - const payload_len = w.pos + e.pos; - buf[0..header_len].* = record.header(.handshake, 4 + payload_len) ++ - record.handshakeHeader(.server_hello, payload_len); - - return buf[0 .. header_len + payload_len]; - } - - fn makeCertificateRequest(buf: []u8) ![]const u8 { - // handshake header + context length + extensions length - const header_len = 4 + 1 + 2; - - // First write extensions, leave space for header. - var ext = record.Writer{ .buf = buf[header_len..] }; - try ext.writeExtension(.signature_algorithms, common.supported_signature_algorithms); - - var w = record.Writer{ .buf = buf }; - try w.writeHandshakeHeader(.certificate_request, ext.pos + 3); - try w.writeInt(@as(u8, 0)); // certificate request context length = 0 - try w.writeInt(@as(u16, @intCast(ext.pos))); // extensions length - assert(w.pos == header_len); - w.pos += ext.pos; - - return w.getWritten(); - } - - fn readClientHello(h: *HandshakeT) !void { - var d = try h.rec_rdr.nextDecoder(); - try d.expectContentType(.handshake); - h.transcript.update(d.payload); - - const handshake_type = try d.decode(proto.Handshake); - if (handshake_type != .client_hello) return error.TlsUnexpectedMessage; - _ = try d.decode(u24); // handshake length - if (try d.decode(proto.Version) != .tls_1_2) return error.TlsProtocolVersion; - - h.client_random = try d.array(32); - { // legacy session id - const len = try d.decode(u8); - h.legacy_session_id = dupe(&h.legacy_session_id_buf, try d.slice(len)); - } - { // cipher suites - const end_idx = try d.decode(u16) + d.idx; - - while (d.idx < end_idx) { - const cipher_suite = try d.decode(CipherSuite); - if (cipher_suites.includes(cipher_suites.tls13, cipher_suite) and - @intFromEnum(h.cipher_suite) == 0) - { - h.cipher_suite = cipher_suite; - } - } - if (@intFromEnum(h.cipher_suite) == 0) - return error.TlsHandshakeFailure; - } - try d.skip(2); // compression methods - - var key_share_received = false; - // extensions - const extensions_end_idx = try d.decode(u16) + d.idx; - while (d.idx < extensions_end_idx) { - const extension_type = try d.decode(proto.Extension); - const extension_len = try d.decode(u16); - - switch (extension_type) { - .supported_versions => { - var tls_1_3_supported = false; - const end_idx = try d.decode(u8) + d.idx; - while (d.idx < end_idx) { - if (try d.decode(proto.Version) == proto.Version.tls_1_3) { - tls_1_3_supported = true; - } - } - if (!tls_1_3_supported) return error.TlsProtocolVersion; - }, - .key_share => { - if (extension_len == 0) return error.TlsDecodeError; - key_share_received = true; - var selected_named_group_idx = supported_named_groups.len; - const end_idx = try d.decode(u16) + d.idx; - while (d.idx < end_idx) { - const named_group = try d.decode(proto.NamedGroup); - switch (@intFromEnum(named_group)) { - 0x0001...0x0016, - 0x001a...0x001c, - 0xff01...0xff02, - => return error.TlsIllegalParameter, - else => {}, - } - const client_pub_key = try d.slice(try d.decode(u16)); - for (supported_named_groups, 0..) |supported, idx| { - if (named_group == supported and idx < selected_named_group_idx) { - h.named_group = named_group; - h.client_pub_key = dupe(&h.client_pub_key_buf, client_pub_key); - selected_named_group_idx = idx; - } - } - } - if (@intFromEnum(h.named_group) == 0) - return error.TlsIllegalParameter; - }, - .supported_groups => { - const end_idx = try d.decode(u16) + d.idx; - while (d.idx < end_idx) { - const named_group = try d.decode(proto.NamedGroup); - switch (@intFromEnum(named_group)) { - 0x0001...0x0016, - 0x001a...0x001c, - 0xff01...0xff02, - => return error.TlsIllegalParameter, - else => {}, - } - } - }, - .signature_algorithms => { - if (@intFromEnum(h.signature_scheme) == 0) { - try d.skip(extension_len); - } else { - var found = false; - const list_len = try d.decode(u16); - if (list_len == 0) return error.TlsDecodeError; - const end_idx = list_len + d.idx; - while (d.idx < end_idx) { - const signature_scheme = try d.decode(proto.SignatureScheme); - if (signature_scheme == h.signature_scheme) found = true; - } - if (!found) return error.TlsHandshakeFailure; - } - }, - else => { - try d.skip(extension_len); - }, - } - } - if (!key_share_received) return error.TlsMissingExtension; - if (@intFromEnum(h.named_group) == 0) return error.TlsIllegalParameter; - } - }; -} - -const testing = std.testing; -const data13 = @import("testdata/tls13.zig"); -const testu = @import("testu.zig"); - -fn testReader(data: []const u8) record.Reader(std.io.FixedBufferStream([]const u8)) { - return record.reader(std.io.fixedBufferStream(data)); -} -const TestHandshake = Handshake(std.io.FixedBufferStream([]const u8)); - -test "read client hello" { - var buffer: [1024]u8 = undefined; - var rec_rdr = testReader(&data13.client_hello); - var h = TestHandshake.init(&buffer, &rec_rdr); - h.signature_scheme = .ecdsa_secp521r1_sha512; // this must be supported in signature_algorithms extension - try h.readClientHello(); - - try testing.expectEqual(CipherSuite.AES_256_GCM_SHA384, h.cipher_suite); - try testing.expectEqual(.x25519, h.named_group); - try testing.expectEqualSlices(u8, &data13.client_random, &h.client_random); - try testing.expectEqualSlices(u8, &data13.client_public_key, h.client_pub_key); -} - -test "make server hello" { - var buffer: [128]u8 = undefined; - var h = TestHandshake.init(&buffer, undefined); - h.cipher_suite = .AES_256_GCM_SHA384; - testu.fillFrom(&h.server_random, 0); - testu.fillFrom(&h.server_pub_key_buf, 0x20); - h.named_group = .x25519; - h.server_pub_key = h.server_pub_key_buf[0..32]; - - const actual = try h.makeServerHello(&buffer); - const expected = &testu.hexToBytes( - \\ 16 03 03 00 5a 02 00 00 56 - \\ 03 03 - \\ 00 01 02 03 04 05 06 07 08 09 0a 0b 0c 0d 0e 0f 10 11 12 13 14 15 16 17 18 19 1a 1b 1c 1d 1e 1f - \\ 00 - \\ 13 02 00 - \\ 00 2e 00 2b 00 02 03 04 - \\ 00 33 00 24 00 1d 00 20 - \\ 20 21 22 23 24 25 26 27 28 29 2a 2b 2c 2d 2e 2f 30 31 32 33 34 35 36 37 38 39 3a 3b 3c 3d 3e 3f - ); - try testing.expectEqualSlices(u8, expected, actual); -} - -test "make certificate request" { - var buffer: [32]u8 = undefined; - - const expected = testu.hexToBytes("0d 00 00 1b" ++ // handshake header - "00 00 18" ++ // extension length - "00 0d" ++ // signature algorithms extension - "00 14" ++ // extension length - "00 12" ++ // list length 6 * 2 bytes - "04 03 05 03 08 04 08 05 08 06 08 07 02 01 04 01 05 01" // signature schemes - ); - const actual = try TestHandshake.makeCertificateRequest(&buffer); - try testing.expectEqualSlices(u8, &expected, actual); -} diff --git a/src/http/async/tls.zig/key_log.zig b/src/http/async/tls.zig/key_log.zig deleted file mode 100644 index 2da83f42..00000000 --- a/src/http/async/tls.zig/key_log.zig +++ /dev/null @@ -1,60 +0,0 @@ -//! Exporting tls key so we can share them with Wireshark and analyze decrypted -//! traffic in Wireshark. -//! To configure Wireshark to use exprted keys see curl reference. -//! -//! References: -//! curl: https://everything.curl.dev/usingcurl/tls/sslkeylogfile.html -//! openssl: https://www.openssl.org/docs/manmaster/man3/SSL_CTX_set_keylog_callback.html -//! https://udn.realityripple.com/docs/Mozilla/Projects/NSS/Key_Log_Format - -const std = @import("std"); - -const key_log_file_env = "SSLKEYLOGFILE"; - -pub const label = struct { - // tls 1.3 - pub const client_handshake_traffic_secret: []const u8 = "CLIENT_HANDSHAKE_TRAFFIC_SECRET"; - pub const server_handshake_traffic_secret: []const u8 = "SERVER_HANDSHAKE_TRAFFIC_SECRET"; - pub const client_traffic_secret_0: []const u8 = "CLIENT_TRAFFIC_SECRET_0"; - pub const server_traffic_secret_0: []const u8 = "SERVER_TRAFFIC_SECRET_0"; - // tls 1.2 - pub const client_random: []const u8 = "CLIENT_RANDOM"; -}; - -pub const Callback = *const fn (label: []const u8, client_random: []const u8, secret: []const u8) void; - -/// Writes tls keys to the file pointed by SSLKEYLOGFILE environment variable. -pub fn callback(label_: []const u8, client_random: []const u8, secret: []const u8) void { - if (std.posix.getenv(key_log_file_env)) |file_name| { - fileAppend(file_name, label_, client_random, secret) catch return; - } -} - -pub fn fileAppend(file_name: []const u8, label_: []const u8, client_random: []const u8, secret: []const u8) !void { - var buf: [1024]u8 = undefined; - const line = try formatLine(&buf, label_, client_random, secret); - try fileWrite(file_name, line); -} - -fn fileWrite(file_name: []const u8, line: []const u8) !void { - var file = try std.fs.createFileAbsolute(file_name, .{ .truncate = false }); - defer file.close(); - const stat = try file.stat(); - try file.seekTo(stat.size); - try file.writeAll(line); -} - -pub fn formatLine(buf: []u8, label_: []const u8, client_random: []const u8, secret: []const u8) ![]const u8 { - var fbs = std.io.fixedBufferStream(buf); - const w = fbs.writer(); - try w.print("{s} ", .{label_}); - for (client_random) |b| { - try std.fmt.formatInt(b, 16, .lower, .{ .width = 2, .fill = '0' }, w); - } - try w.writeByte(' '); - for (secret) |b| { - try std.fmt.formatInt(b, 16, .lower, .{ .width = 2, .fill = '0' }, w); - } - try w.writeByte('\n'); - return fbs.getWritten(); -} diff --git a/src/http/async/tls.zig/main.zig b/src/http/async/tls.zig/main.zig deleted file mode 100644 index b974377b..00000000 --- a/src/http/async/tls.zig/main.zig +++ /dev/null @@ -1,51 +0,0 @@ -const std = @import("std"); - -pub const CipherSuite = @import("cipher.zig").CipherSuite; -pub const cipher_suites = @import("cipher.zig").cipher_suites; -pub const PrivateKey = @import("PrivateKey.zig"); -pub const Connection = @import("connection.zig").Connection; -pub const ClientOptions = @import("handshake_client.zig").Options; -pub const ServerOptions = @import("handshake_server.zig").Options; -pub const key_log = @import("key_log.zig"); -pub const proto = @import("protocol.zig"); -pub const NamedGroup = proto.NamedGroup; -pub const Version = proto.Version; -const common = @import("handshake_common.zig"); -pub const CertBundle = common.CertBundle; -pub const CertKeyPair = common.CertKeyPair; - -pub const record = @import("record.zig"); -const connection = @import("connection.zig").connection; -const max_ciphertext_record_len = @import("cipher.zig").max_ciphertext_record_len; -const HandshakeServer = @import("handshake_server.zig").Handshake; -const HandshakeClient = @import("handshake_client.zig").Handshake; - -pub fn client(stream: anytype, opt: ClientOptions) !Connection(@TypeOf(stream)) { - const Stream = @TypeOf(stream); - var conn = connection(stream); - var write_buf: [max_ciphertext_record_len]u8 = undefined; - var h = HandshakeClient(Stream).init(&write_buf, &conn.rec_rdr); - conn.cipher = try h.handshake(conn.stream, opt); - return conn; -} - -pub fn server(stream: anytype, opt: ServerOptions) !Connection(@TypeOf(stream)) { - const Stream = @TypeOf(stream); - var conn = connection(stream); - var write_buf: [max_ciphertext_record_len]u8 = undefined; - var h = HandshakeServer(Stream).init(&write_buf, &conn.rec_rdr); - conn.cipher = try h.handshake(conn.stream, opt); - return conn; -} - -test { - _ = @import("handshake_common.zig"); - _ = @import("handshake_server.zig"); - _ = @import("handshake_client.zig"); - - _ = @import("connection.zig"); - _ = @import("cipher.zig"); - _ = @import("record.zig"); - _ = @import("transcript.zig"); - _ = @import("PrivateKey.zig"); -} diff --git a/src/http/async/tls.zig/protocol.zig b/src/http/async/tls.zig/protocol.zig deleted file mode 100644 index e3bb07ac..00000000 --- a/src/http/async/tls.zig/protocol.zig +++ /dev/null @@ -1,302 +0,0 @@ -pub const Version = enum(u16) { - tls_1_2 = 0x0303, - tls_1_3 = 0x0304, - _, -}; - -pub const ContentType = enum(u8) { - invalid = 0, - change_cipher_spec = 20, - alert = 21, - handshake = 22, - application_data = 23, - _, -}; - -pub const Handshake = enum(u8) { - client_hello = 1, - server_hello = 2, - new_session_ticket = 4, - end_of_early_data = 5, - encrypted_extensions = 8, - certificate = 11, - server_key_exchange = 12, - certificate_request = 13, - server_hello_done = 14, - certificate_verify = 15, - client_key_exchange = 16, - finished = 20, - key_update = 24, - message_hash = 254, - _, -}; - -pub const Curve = enum(u8) { - named_curve = 0x03, - _, -}; - -pub const Extension = enum(u16) { - /// RFC 6066 - server_name = 0, - /// RFC 6066 - max_fragment_length = 1, - /// RFC 6066 - status_request = 5, - /// RFC 8422, 7919 - supported_groups = 10, - /// RFC 8446 - signature_algorithms = 13, - /// RFC 5764 - use_srtp = 14, - /// RFC 6520 - heartbeat = 15, - /// RFC 7301 - application_layer_protocol_negotiation = 16, - /// RFC 6962 - signed_certificate_timestamp = 18, - /// RFC 7250 - client_certificate_type = 19, - /// RFC 7250 - server_certificate_type = 20, - /// RFC 7685 - padding = 21, - /// RFC 8446 - pre_shared_key = 41, - /// RFC 8446 - early_data = 42, - /// RFC 8446 - supported_versions = 43, - /// RFC 8446 - cookie = 44, - /// RFC 8446 - psk_key_exchange_modes = 45, - /// RFC 8446 - certificate_authorities = 47, - /// RFC 8446 - oid_filters = 48, - /// RFC 8446 - post_handshake_auth = 49, - /// RFC 8446 - signature_algorithms_cert = 50, - /// RFC 8446 - key_share = 51, - - _, -}; - -pub fn alertFromError(err: anyerror) [2]u8 { - return [2]u8{ @intFromEnum(Alert.Level.fatal), @intFromEnum(Alert.fromError(err)) }; -} - -pub const Alert = enum(u8) { - pub const Level = enum(u8) { - warning = 1, - fatal = 2, - _, - }; - - pub const Error = error{ - TlsAlertUnexpectedMessage, - TlsAlertBadRecordMac, - TlsAlertRecordOverflow, - TlsAlertHandshakeFailure, - TlsAlertBadCertificate, - TlsAlertUnsupportedCertificate, - TlsAlertCertificateRevoked, - TlsAlertCertificateExpired, - TlsAlertCertificateUnknown, - TlsAlertIllegalParameter, - TlsAlertUnknownCa, - TlsAlertAccessDenied, - TlsAlertDecodeError, - TlsAlertDecryptError, - TlsAlertProtocolVersion, - TlsAlertInsufficientSecurity, - TlsAlertInternalError, - TlsAlertInappropriateFallback, - TlsAlertMissingExtension, - TlsAlertUnsupportedExtension, - TlsAlertUnrecognizedName, - TlsAlertBadCertificateStatusResponse, - TlsAlertUnknownPskIdentity, - TlsAlertCertificateRequired, - TlsAlertNoApplicationProtocol, - TlsAlertUnknown, - }; - - close_notify = 0, - unexpected_message = 10, - bad_record_mac = 20, - record_overflow = 22, - handshake_failure = 40, - bad_certificate = 42, - unsupported_certificate = 43, - certificate_revoked = 44, - certificate_expired = 45, - certificate_unknown = 46, - illegal_parameter = 47, - unknown_ca = 48, - access_denied = 49, - decode_error = 50, - decrypt_error = 51, - protocol_version = 70, - insufficient_security = 71, - internal_error = 80, - inappropriate_fallback = 86, - user_canceled = 90, - missing_extension = 109, - unsupported_extension = 110, - unrecognized_name = 112, - bad_certificate_status_response = 113, - unknown_psk_identity = 115, - certificate_required = 116, - no_application_protocol = 120, - _, - - pub fn toError(alert: Alert) Error!void { - return switch (alert) { - .close_notify => {}, // not an error - .unexpected_message => error.TlsAlertUnexpectedMessage, - .bad_record_mac => error.TlsAlertBadRecordMac, - .record_overflow => error.TlsAlertRecordOverflow, - .handshake_failure => error.TlsAlertHandshakeFailure, - .bad_certificate => error.TlsAlertBadCertificate, - .unsupported_certificate => error.TlsAlertUnsupportedCertificate, - .certificate_revoked => error.TlsAlertCertificateRevoked, - .certificate_expired => error.TlsAlertCertificateExpired, - .certificate_unknown => error.TlsAlertCertificateUnknown, - .illegal_parameter => error.TlsAlertIllegalParameter, - .unknown_ca => error.TlsAlertUnknownCa, - .access_denied => error.TlsAlertAccessDenied, - .decode_error => error.TlsAlertDecodeError, - .decrypt_error => error.TlsAlertDecryptError, - .protocol_version => error.TlsAlertProtocolVersion, - .insufficient_security => error.TlsAlertInsufficientSecurity, - .internal_error => error.TlsAlertInternalError, - .inappropriate_fallback => error.TlsAlertInappropriateFallback, - .user_canceled => {}, // not an error - .missing_extension => error.TlsAlertMissingExtension, - .unsupported_extension => error.TlsAlertUnsupportedExtension, - .unrecognized_name => error.TlsAlertUnrecognizedName, - .bad_certificate_status_response => error.TlsAlertBadCertificateStatusResponse, - .unknown_psk_identity => error.TlsAlertUnknownPskIdentity, - .certificate_required => error.TlsAlertCertificateRequired, - .no_application_protocol => error.TlsAlertNoApplicationProtocol, - _ => error.TlsAlertUnknown, - }; - } - - pub fn fromError(err: anyerror) Alert { - return switch (err) { - error.TlsUnexpectedMessage => .unexpected_message, - error.TlsBadRecordMac => .bad_record_mac, - error.TlsRecordOverflow => .record_overflow, - error.TlsHandshakeFailure => .handshake_failure, - error.TlsBadCertificate => .bad_certificate, - error.TlsUnsupportedCertificate => .unsupported_certificate, - error.TlsCertificateRevoked => .certificate_revoked, - error.TlsCertificateExpired => .certificate_expired, - error.TlsCertificateUnknown => .certificate_unknown, - error.TlsIllegalParameter, - error.IdentityElement, - error.InvalidEncoding, - => .illegal_parameter, - error.TlsUnknownCa => .unknown_ca, - error.TlsAccessDenied => .access_denied, - error.TlsDecodeError => .decode_error, - error.TlsDecryptError => .decrypt_error, - error.TlsProtocolVersion => .protocol_version, - error.TlsInsufficientSecurity => .insufficient_security, - error.TlsInternalError => .internal_error, - error.TlsInappropriateFallback => .inappropriate_fallback, - error.TlsMissingExtension => .missing_extension, - error.TlsUnsupportedExtension => .unsupported_extension, - error.TlsUnrecognizedName => .unrecognized_name, - error.TlsBadCertificateStatusResponse => .bad_certificate_status_response, - error.TlsUnknownPskIdentity => .unknown_psk_identity, - error.TlsCertificateRequired => .certificate_required, - error.TlsNoApplicationProtocol => .no_application_protocol, - else => .internal_error, - }; - } - - pub fn parse(buf: [2]u8) Alert { - const level: Alert.Level = @enumFromInt(buf[0]); - const alert: Alert = @enumFromInt(buf[1]); - _ = level; - return alert; - } - - pub fn closeNotify() [2]u8 { - return [2]u8{ - @intFromEnum(Alert.Level.warning), - @intFromEnum(Alert.close_notify), - }; - } -}; - -pub const SignatureScheme = enum(u16) { - // RSASSA-PKCS1-v1_5 algorithms - rsa_pkcs1_sha256 = 0x0401, - rsa_pkcs1_sha384 = 0x0501, - rsa_pkcs1_sha512 = 0x0601, - - // ECDSA algorithms - ecdsa_secp256r1_sha256 = 0x0403, - ecdsa_secp384r1_sha384 = 0x0503, - ecdsa_secp521r1_sha512 = 0x0603, - - // RSASSA-PSS algorithms with public key OID rsaEncryption - rsa_pss_rsae_sha256 = 0x0804, - rsa_pss_rsae_sha384 = 0x0805, - rsa_pss_rsae_sha512 = 0x0806, - - // EdDSA algorithms - ed25519 = 0x0807, - ed448 = 0x0808, - - // RSASSA-PSS algorithms with public key OID RSASSA-PSS - rsa_pss_pss_sha256 = 0x0809, - rsa_pss_pss_sha384 = 0x080a, - rsa_pss_pss_sha512 = 0x080b, - - // Legacy algorithms - rsa_pkcs1_sha1 = 0x0201, - ecdsa_sha1 = 0x0203, - - _, -}; - -pub const NamedGroup = enum(u16) { - // Elliptic Curve Groups (ECDHE) - secp256r1 = 0x0017, - secp384r1 = 0x0018, - secp521r1 = 0x0019, - x25519 = 0x001D, - x448 = 0x001E, - - // Finite Field Groups (DHE) - ffdhe2048 = 0x0100, - ffdhe3072 = 0x0101, - ffdhe4096 = 0x0102, - ffdhe6144 = 0x0103, - ffdhe8192 = 0x0104, - - // Hybrid post-quantum key agreements - x25519_kyber512d00 = 0xFE30, - x25519_kyber768d00 = 0x6399, - - _, -}; - -pub const KeyUpdateRequest = enum(u8) { - update_not_requested = 0, - update_requested = 1, - _, -}; - -pub const Side = enum { - client, - server, -}; diff --git a/src/http/async/tls.zig/record.zig b/src/http/async/tls.zig/record.zig deleted file mode 100644 index 6c4df328..00000000 --- a/src/http/async/tls.zig/record.zig +++ /dev/null @@ -1,405 +0,0 @@ -const std = @import("std"); -const assert = std.debug.assert; -const mem = std.mem; - -const proto = @import("protocol.zig"); -const cipher = @import("cipher.zig"); -const Cipher = cipher.Cipher; -const record = @import("record.zig"); - -pub const header_len = 5; - -pub fn header(content_type: proto.ContentType, payload_len: usize) [header_len]u8 { - const int2 = std.crypto.tls.int2; - return [1]u8{@intFromEnum(content_type)} ++ - int2(@intFromEnum(proto.Version.tls_1_2)) ++ - int2(@intCast(payload_len)); -} - -pub fn handshakeHeader(handshake_type: proto.Handshake, payload_len: usize) [4]u8 { - const int3 = std.crypto.tls.int3; - return [1]u8{@intFromEnum(handshake_type)} ++ int3(@intCast(payload_len)); -} - -pub fn reader(inner_reader: anytype) Reader(@TypeOf(inner_reader)) { - return .{ .inner_reader = inner_reader }; -} - -pub fn Reader(comptime InnerReader: type) type { - return struct { - inner_reader: InnerReader, - - buffer: [cipher.max_ciphertext_record_len]u8 = undefined, - start: usize = 0, - end: usize = 0, - - const ReaderT = @This(); - - pub fn nextDecoder(r: *ReaderT) !Decoder { - const rec = (try r.next()) orelse return error.EndOfStream; - if (@intFromEnum(rec.protocol_version) != 0x0300 and - @intFromEnum(rec.protocol_version) != 0x0301 and - rec.protocol_version != .tls_1_2) - return error.TlsBadVersion; - return .{ - .content_type = rec.content_type, - .payload = rec.payload, - }; - } - - pub fn contentType(buf: []const u8) proto.ContentType { - return @enumFromInt(buf[0]); - } - - pub fn protocolVersion(buf: []const u8) proto.Version { - return @enumFromInt(mem.readInt(u16, buf[1..3], .big)); - } - - pub fn next(r: *ReaderT) !?Record { - while (true) { - const buffer = r.buffer[r.start..r.end]; - // If we have 5 bytes header. - if (buffer.len >= record.header_len) { - const record_header = buffer[0..record.header_len]; - const payload_len = mem.readInt(u16, record_header[3..5], .big); - if (payload_len > cipher.max_ciphertext_len) - return error.TlsRecordOverflow; - const record_len = record.header_len + payload_len; - // If we have whole record - if (buffer.len >= record_len) { - r.start += record_len; - return Record.init(buffer[0..record_len]); - } - } - { // Move dirty part to the start of the buffer. - const n = r.end - r.start; - if (n > 0 and r.start > 0) { - if (r.start > n) { - @memcpy(r.buffer[0..n], r.buffer[r.start..][0..n]); - } else { - mem.copyForwards(u8, r.buffer[0..n], r.buffer[r.start..][0..n]); - } - } - r.start = 0; - r.end = n; - } - { // Read more from inner_reader. - const n = try r.inner_reader.read(r.buffer[r.end..]); - if (n == 0) return null; - r.end += n; - } - } - } - - pub fn nextDecrypt(r: *ReaderT, cph: *Cipher) !?struct { proto.ContentType, []const u8 } { - const rec = (try r.next()) orelse return null; - if (rec.protocol_version != .tls_1_2) return error.TlsBadVersion; - - return try cph.decrypt( - // Reuse reader buffer for cleartext. `rec.header` and - // `rec.payload`(ciphertext) are also pointing somewhere in - // this buffer. Decrypter is first reading then writing a - // block, cleartext has less length then ciphertext, - // cleartext starts from the beginning of the buffer, so - // ciphertext is always ahead of cleartext. - r.buffer[0..r.start], - rec, - ); - } - - pub fn hasMore(r: *ReaderT) bool { - return r.end > r.start; - } - }; -} - -pub const Record = struct { - content_type: proto.ContentType, - protocol_version: proto.Version = .tls_1_2, - header: []const u8, - payload: []const u8, - - pub fn init(buffer: []const u8) Record { - return .{ - .content_type = @enumFromInt(buffer[0]), - .protocol_version = @enumFromInt(mem.readInt(u16, buffer[1..3], .big)), - .header = buffer[0..record.header_len], - .payload = buffer[record.header_len..], - }; - } - - pub fn decoder(r: @This()) Decoder { - return Decoder.init(r.content_type, @constCast(r.payload)); - } -}; - -pub const Decoder = struct { - content_type: proto.ContentType, - payload: []const u8, - idx: usize = 0, - - pub fn init(content_type: proto.ContentType, payload: []u8) Decoder { - return .{ - .content_type = content_type, - .payload = payload, - }; - } - - pub fn decode(d: *Decoder, comptime T: type) !T { - switch (@typeInfo(T)) { - .Int => |info| switch (info.bits) { - 8 => { - try skip(d, 1); - return d.payload[d.idx - 1]; - }, - 16 => { - try skip(d, 2); - const b0: u16 = d.payload[d.idx - 2]; - const b1: u16 = d.payload[d.idx - 1]; - return (b0 << 8) | b1; - }, - 24 => { - try skip(d, 3); - const b0: u24 = d.payload[d.idx - 3]; - const b1: u24 = d.payload[d.idx - 2]; - const b2: u24 = d.payload[d.idx - 1]; - return (b0 << 16) | (b1 << 8) | b2; - }, - else => @compileError("unsupported int type: " ++ @typeName(T)), - }, - .Enum => |info| { - const int = try d.decode(info.tag_type); - if (info.is_exhaustive) @compileError("exhaustive enum cannot be used"); - return @as(T, @enumFromInt(int)); - }, - else => @compileError("unsupported type: " ++ @typeName(T)), - } - } - - pub fn array(d: *Decoder, comptime len: usize) ![len]u8 { - try d.skip(len); - return d.payload[d.idx - len ..][0..len].*; - } - - pub fn slice(d: *Decoder, len: usize) ![]const u8 { - try d.skip(len); - return d.payload[d.idx - len ..][0..len]; - } - - pub fn skip(d: *Decoder, amt: usize) !void { - if (d.idx + amt > d.payload.len) return error.TlsDecodeError; - d.idx += amt; - } - - pub fn rest(d: Decoder) []const u8 { - return d.payload[d.idx..]; - } - - pub fn eof(d: Decoder) bool { - return d.idx == d.payload.len; - } - - pub fn expectContentType(d: *Decoder, content_type: proto.ContentType) !void { - if (d.content_type == content_type) return; - - switch (d.content_type) { - .alert => try d.raiseAlert(), - else => return error.TlsUnexpectedMessage, - } - } - - pub fn raiseAlert(d: *Decoder) !void { - if (d.payload.len < 2) return error.TlsUnexpectedMessage; - try proto.Alert.parse(try d.array(2)).toError(); - return error.TlsAlertCloseNotify; - } -}; - -const testing = std.testing; -const data12 = @import("testdata/tls12.zig"); -const testu = @import("testu.zig"); -const CipherSuite = @import("cipher.zig").CipherSuite; - -test Reader { - var fbs = std.io.fixedBufferStream(&data12.server_responses); - var rdr = reader(fbs.reader()); - - const expected = [_]struct { - content_type: proto.ContentType, - payload_len: usize, - }{ - .{ .content_type = .handshake, .payload_len = 49 }, - .{ .content_type = .handshake, .payload_len = 815 }, - .{ .content_type = .handshake, .payload_len = 300 }, - .{ .content_type = .handshake, .payload_len = 4 }, - .{ .content_type = .change_cipher_spec, .payload_len = 1 }, - .{ .content_type = .handshake, .payload_len = 64 }, - }; - for (expected) |e| { - const rec = (try rdr.next()).?; - try testing.expectEqual(e.content_type, rec.content_type); - try testing.expectEqual(e.payload_len, rec.payload.len); - try testing.expectEqual(.tls_1_2, rec.protocol_version); - } -} - -test Decoder { - var fbs = std.io.fixedBufferStream(&data12.server_responses); - var rdr = reader(fbs.reader()); - - var d = (try rdr.nextDecoder()); - try testing.expectEqual(.handshake, d.content_type); - - try testing.expectEqual(.server_hello, try d.decode(proto.Handshake)); - try testing.expectEqual(45, try d.decode(u24)); // length - try testing.expectEqual(.tls_1_2, try d.decode(proto.Version)); - try testing.expectEqualStrings( - &testu.hexToBytes("707172737475767778797a7b7c7d7e7f808182838485868788898a8b8c8d8e8f"), - try d.slice(32), - ); // server random - try testing.expectEqual(0, try d.decode(u8)); // session id len - try testing.expectEqual(.ECDHE_RSA_WITH_AES_128_CBC_SHA, try d.decode(CipherSuite)); - try testing.expectEqual(0, try d.decode(u8)); // compression method - try testing.expectEqual(5, try d.decode(u16)); // extension length - try testing.expectEqual(5, d.rest().len); - try d.skip(5); - try testing.expect(d.eof()); -} - -pub const Writer = struct { - buf: []u8, - pos: usize = 0, - - pub fn write(self: *Writer, data: []const u8) !void { - defer self.pos += data.len; - if (self.pos + data.len > self.buf.len) return error.BufferOverflow; - @memcpy(self.buf[self.pos..][0..data.len], data); - } - - pub fn writeByte(self: *Writer, b: u8) !void { - defer self.pos += 1; - if (self.pos == self.buf.len) return error.BufferOverflow; - self.buf[self.pos] = b; - } - - pub fn writeEnum(self: *Writer, value: anytype) !void { - try self.writeInt(@intFromEnum(value)); - } - - pub fn writeInt(self: *Writer, value: anytype) !void { - const IntT = @TypeOf(value); - const bytes = @divExact(@typeInfo(IntT).Int.bits, 8); - const free = self.buf[self.pos..]; - if (free.len < bytes) return error.BufferOverflow; - mem.writeInt(IntT, free[0..bytes], value, .big); - self.pos += bytes; - } - - pub fn writeHandshakeHeader(self: *Writer, handshake_type: proto.Handshake, payload_len: usize) !void { - try self.write(&record.handshakeHeader(handshake_type, payload_len)); - } - - /// Should be used after writing handshake payload in buffer provided by `getHandshakePayload`. - pub fn advanceHandshake(self: *Writer, handshake_type: proto.Handshake, payload_len: usize) !void { - try self.write(&record.handshakeHeader(handshake_type, payload_len)); - self.pos += payload_len; - } - - /// Record payload is already written by using buffer space from `getPayload`. - /// Now when we know payload len we can write record header and advance over payload. - pub fn advanceRecord(self: *Writer, content_type: proto.ContentType, payload_len: usize) !void { - try self.write(&record.header(content_type, payload_len)); - self.pos += payload_len; - } - - pub fn writeRecord(self: *Writer, content_type: proto.ContentType, payload: []const u8) !void { - try self.write(&record.header(content_type, payload.len)); - try self.write(payload); - } - - /// Preserves space for record header and returns buffer free space. - pub fn getPayload(self: *Writer) []u8 { - return self.buf[self.pos + record.header_len ..]; - } - - /// Preserves space for handshake header and returns buffer free space. - pub fn getHandshakePayload(self: *Writer) []u8 { - return self.buf[self.pos + 4 ..]; - } - - pub fn getWritten(self: *Writer) []const u8 { - return self.buf[0..self.pos]; - } - - pub fn getFree(self: *Writer) []u8 { - return self.buf[self.pos..]; - } - - pub fn writeEnumArray(self: *Writer, comptime E: type, tags: []const E) !void { - assert(@sizeOf(E) == 2); - try self.writeInt(@as(u16, @intCast(tags.len * 2))); - for (tags) |t| { - try self.writeEnum(t); - } - } - - pub fn writeExtension( - self: *Writer, - comptime et: proto.Extension, - tags: anytype, - ) !void { - try self.writeEnum(et); - if (et == .supported_versions) { - try self.writeInt(@as(u16, @intCast(tags.len * 2 + 1))); - try self.writeInt(@as(u8, @intCast(tags.len * 2))); - } else { - try self.writeInt(@as(u16, @intCast(tags.len * 2 + 2))); - try self.writeInt(@as(u16, @intCast(tags.len * 2))); - } - for (tags) |t| { - try self.writeEnum(t); - } - } - - pub fn writeKeyShare( - self: *Writer, - named_groups: []const proto.NamedGroup, - keys: []const []const u8, - ) !void { - assert(named_groups.len == keys.len); - try self.writeEnum(proto.Extension.key_share); - var l: usize = 0; - for (keys) |key| { - l += key.len + 4; - } - try self.writeInt(@as(u16, @intCast(l + 2))); - try self.writeInt(@as(u16, @intCast(l))); - for (named_groups, 0..) |ng, i| { - const key = keys[i]; - try self.writeEnum(ng); - try self.writeInt(@as(u16, @intCast(key.len))); - try self.write(key); - } - } - - pub fn writeServerName(self: *Writer, host: []const u8) !void { - const host_len: u16 = @intCast(host.len); - try self.writeEnum(proto.Extension.server_name); - try self.writeInt(host_len + 5); // byte length of extension payload - try self.writeInt(host_len + 3); // server_name_list byte count - try self.writeByte(0); // name type - try self.writeInt(host_len); - try self.write(host); - } -}; - -test "Writer" { - var buf: [16]u8 = undefined; - var w = Writer{ .buf = &buf }; - - try w.write("ab"); - try w.writeEnum(proto.Curve.named_curve); - try w.writeEnum(proto.NamedGroup.x25519); - try w.writeInt(@as(u16, 0x1234)); - try testing.expectEqualSlices(u8, &[_]u8{ 'a', 'b', 0x03, 0x00, 0x1d, 0x12, 0x34 }, w.getWritten()); -} diff --git a/src/http/async/tls.zig/rsa/der.zig b/src/http/async/tls.zig/rsa/der.zig deleted file mode 100644 index 743a65ad..00000000 --- a/src/http/async/tls.zig/rsa/der.zig +++ /dev/null @@ -1,467 +0,0 @@ -//! An encoding of ASN.1. -//! -//! Distinguised Encoding Rules as defined in X.690 and X.691. -//! -//! A version of Basic Encoding Rules (BER) where there is exactly ONE way to -//! represent non-constructed elements. This is useful for cryptographic signatures. -//! -//! Currently an implementation detail of the standard library not fit for public -//! use since it's missing an encoder. - -const std = @import("std"); -const builtin = @import("builtin"); - -pub const Index = usize; -const log = std.log.scoped(.der); - -/// A secure DER parser that: -/// - Does NOT read memory outside `bytes`. -/// - Does NOT return elements with slices outside `bytes`. -/// - Errors on values that do NOT follow DER rules. -/// - Lengths that could be represented in a shorter form. -/// - Booleans that are not 0xff or 0x00. -pub const Parser = struct { - bytes: []const u8, - index: Index = 0, - - pub const Error = Element.Error || error{ - UnexpectedElement, - InvalidIntegerEncoding, - Overflow, - NonCanonical, - }; - - pub fn expectBool(self: *Parser) Error!bool { - const ele = try self.expect(.universal, false, .boolean); - if (ele.slice.len() != 1) return error.InvalidBool; - - return switch (self.view(ele)[0]) { - 0x00 => false, - 0xff => true, - else => error.InvalidBool, - }; - } - - pub fn expectBitstring(self: *Parser) Error!BitString { - const ele = try self.expect(.universal, false, .bitstring); - const bytes = self.view(ele); - const right_padding = bytes[0]; - if (right_padding >= 8) return error.InvalidBitString; - return .{ - .bytes = bytes[1..], - .right_padding = @intCast(right_padding), - }; - } - - // TODO: return high resolution date time type instead of epoch seconds - pub fn expectDateTime(self: *Parser) Error!i64 { - const ele = try self.expect(.universal, false, null); - const bytes = self.view(ele); - switch (ele.identifier.tag) { - .utc_time => { - // Example: "YYMMDD000000Z" - if (bytes.len != 13) - return error.InvalidDateTime; - if (bytes[12] != 'Z') - return error.InvalidDateTime; - - var date: Date = undefined; - date.year = try parseTimeDigits(bytes[0..2], 0, 99); - date.year += if (date.year >= 50) 1900 else 2000; - date.month = try parseTimeDigits(bytes[2..4], 1, 12); - date.day = try parseTimeDigits(bytes[4..6], 1, 31); - const time = try parseTime(bytes[6..12]); - - return date.toEpochSeconds() + time.toSec(); - }, - .generalized_time => { - // Examples: - // "19920622123421Z" - // "19920722132100.3Z" - if (bytes.len < 15) - return error.InvalidDateTime; - - var date: Date = undefined; - date.year = try parseYear4(bytes[0..4]); - date.month = try parseTimeDigits(bytes[4..6], 1, 12); - date.day = try parseTimeDigits(bytes[6..8], 1, 31); - const time = try parseTime(bytes[8..14]); - - return date.toEpochSeconds() + time.toSec(); - }, - else => return error.InvalidDateTime, - } - } - - pub fn expectOid(self: *Parser) Error![]const u8 { - const oid = try self.expect(.universal, false, .object_identifier); - return self.view(oid); - } - - pub fn expectEnum(self: *Parser, comptime Enum: type) Error!Enum { - const oid = try self.expectOid(); - return Enum.oids.get(oid) orelse { - if (builtin.mode == .Debug) { - var buf: [256]u8 = undefined; - var stream = std.io.fixedBufferStream(&buf); - try @import("./oid.zig").decode(oid, stream.writer()); - log.warn("unknown oid {s} for enum {s}\n", .{ stream.getWritten(), @typeName(Enum) }); - } - return error.UnknownObjectId; - }; - } - - pub fn expectInt(self: *Parser, comptime T: type) Error!T { - const ele = try self.expectPrimitive(.integer); - const bytes = self.view(ele); - - const info = @typeInfo(T); - if (info != .Int) @compileError(@typeName(T) ++ " is not an int type"); - const Shift = std.math.Log2Int(u8); - - var result: std.meta.Int(.unsigned, info.Int.bits) = 0; - for (bytes, 0..) |b, index| { - const shifted = @shlWithOverflow(b, @as(Shift, @intCast(index * 8))); - if (shifted[1] == 1) return error.Overflow; - - result |= shifted[0]; - } - - return @bitCast(result); - } - - pub fn expectString(self: *Parser, allowed: std.EnumSet(String.Tag)) Error!String { - const ele = try self.expect(.universal, false, null); - switch (ele.identifier.tag) { - inline .string_utf8, - .string_numeric, - .string_printable, - .string_teletex, - .string_videotex, - .string_ia5, - .string_visible, - .string_universal, - .string_bmp, - => |t| { - const tagname = @tagName(t)["string_".len..]; - const tag = std.meta.stringToEnum(String.Tag, tagname) orelse unreachable; - if (allowed.contains(tag)) { - return String{ .tag = tag, .data = self.view(ele) }; - } - }, - else => {}, - } - return error.UnexpectedElement; - } - - pub fn expectPrimitive(self: *Parser, tag: ?Identifier.Tag) Error!Element { - var elem = try self.expect(.universal, false, tag); - if (tag == .integer and elem.slice.len() > 0) { - if (self.view(elem)[0] == 0) elem.slice.start += 1; - if (elem.slice.len() > 0 and self.view(elem)[0] == 0) return error.InvalidIntegerEncoding; - } - return elem; - } - - /// Remember to call `expectEnd` - pub fn expectSequence(self: *Parser) Error!Element { - return try self.expect(.universal, true, .sequence); - } - - /// Remember to call `expectEnd` - pub fn expectSequenceOf(self: *Parser) Error!Element { - return try self.expect(.universal, true, .sequence_of); - } - - pub fn expectEnd(self: *Parser, val: usize) Error!void { - if (self.index != val) return error.NonCanonical; // either forgot to parse end OR an attacker - } - - pub fn expect( - self: *Parser, - class: ?Identifier.Class, - constructed: ?bool, - tag: ?Identifier.Tag, - ) Error!Element { - if (self.index >= self.bytes.len) return error.EndOfStream; - - const res = try Element.init(self.bytes, self.index); - if (tag) |e| { - if (res.identifier.tag != e) return error.UnexpectedElement; - } - if (constructed) |e| { - if (res.identifier.constructed != e) return error.UnexpectedElement; - } - if (class) |e| { - if (res.identifier.class != e) return error.UnexpectedElement; - } - self.index = if (res.identifier.constructed) res.slice.start else res.slice.end; - return res; - } - - pub fn view(self: Parser, elem: Element) []const u8 { - return elem.slice.view(self.bytes); - } - - pub fn seek(self: *Parser, index: usize) void { - self.index = index; - } - - pub fn eof(self: *Parser) bool { - return self.index == self.bytes.len; - } -}; - -pub const Element = struct { - identifier: Identifier, - slice: Slice, - - pub const Slice = struct { - start: Index, - end: Index, - - pub fn len(self: Slice) Index { - return self.end - self.start; - } - - pub fn view(self: Slice, bytes: []const u8) []const u8 { - return bytes[self.start..self.end]; - } - }; - - pub const Error = error{ InvalidLength, EndOfStream }; - - pub fn init(bytes: []const u8, index: Index) Error!Element { - var stream = std.io.fixedBufferStream(bytes[index..]); - var reader = stream.reader(); - - const identifier = @as(Identifier, @bitCast(try reader.readByte())); - const size_or_len_size = try reader.readByte(); - - var start = index + 2; - // short form between 0-127 - if (size_or_len_size < 128) { - const end = start + size_or_len_size; - if (end > bytes.len) return error.InvalidLength; - - return .{ .identifier = identifier, .slice = .{ .start = start, .end = end } }; - } - - // long form between 0 and std.math.maxInt(u1024) - const len_size: u7 = @truncate(size_or_len_size); - start += len_size; - if (len_size > @sizeOf(Index)) return error.InvalidLength; - const len = try reader.readVarInt(Index, .big, len_size); - if (len < 128) return error.InvalidLength; // should have used short form - - const end = std.math.add(Index, start, len) catch return error.InvalidLength; - if (end > bytes.len) return error.InvalidLength; - - return .{ .identifier = identifier, .slice = .{ .start = start, .end = end } }; - } -}; - -test Element { - const short_form = [_]u8{ 0x30, 0x03, 0x02, 0x01, 0x09 }; - try std.testing.expectEqual(Element{ - .identifier = Identifier{ .tag = .sequence, .constructed = true, .class = .universal }, - .slice = .{ .start = 2, .end = short_form.len }, - }, Element.init(&short_form, 0)); - - const long_form = [_]u8{ 0x30, 129, 129 } ++ [_]u8{0} ** 129; - try std.testing.expectEqual(Element{ - .identifier = Identifier{ .tag = .sequence, .constructed = true, .class = .universal }, - .slice = .{ .start = 3, .end = long_form.len }, - }, Element.init(&long_form, 0)); -} - -test "parser.expectInt" { - const one = [_]u8{ 2, 1, 1 }; - var parser = Parser{ .bytes = &one }; - try std.testing.expectEqual(@as(u8, 1), try parser.expectInt(u8)); -} - -pub const Identifier = packed struct(u8) { - tag: Tag, - constructed: bool, - class: Class, - - pub const Class = enum(u2) { - universal, - application, - context_specific, - private, - }; - - // https://www.oss.com/asn1/resources/asn1-made-simple/asn1-quick-reference/asn1-tags.html - pub const Tag = enum(u5) { - boolean = 1, - integer = 2, - bitstring = 3, - octetstring = 4, - null = 5, - object_identifier = 6, - real = 9, - enumerated = 10, - string_utf8 = 12, - sequence = 16, - sequence_of = 17, - string_numeric = 18, - string_printable = 19, - string_teletex = 20, - string_videotex = 21, - string_ia5 = 22, - utc_time = 23, - generalized_time = 24, - string_visible = 26, - string_universal = 28, - string_bmp = 30, - _, - }; -}; - -pub const BitString = struct { - bytes: []const u8, - right_padding: u3, - - pub fn bitLen(self: BitString) usize { - return self.bytes.len * 8 + self.right_padding; - } -}; - -pub const String = struct { - tag: Tag, - data: []const u8, - - pub const Tag = enum { - /// Blessed. - utf8, - /// us-ascii ([-][0-9][eE][.])* - numeric, - /// us-ascii ([A-Z][a-z][0-9][.?!,][ \t])* - printable, - /// iso-8859-1 with escaping into different character sets. - /// Cursed. - teletex, - /// iso-8859-1 - videotex, - /// us-ascii first 128 characters. - ia5, - /// us-ascii without control characters. - visible, - /// utf-32-be - universal, - /// utf-16-be - bmp, - }; - - pub const all = [_]Tag{ - .utf8, - .numeric, - .printable, - .teletex, - .videotex, - .ia5, - .visible, - .universal, - .bmp, - }; -}; - -const Date = struct { - year: Year, - month: u8, - day: u8, - - const Year = std.time.epoch.Year; - - fn toEpochSeconds(date: Date) i64 { - // Euclidean Affine Transform by Cassio and Neri. - // Shift and correction constants for 1970-01-01. - const s = 82; - const K = 719468 + 146097 * s; - const L = 400 * s; - - const Y_G: u32 = date.year; - const M_G: u32 = date.month; - const D_G: u32 = date.day; - // Map to computational calendar. - const J: u32 = if (M_G <= 2) 1 else 0; - const Y: u32 = Y_G + L - J; - const M: u32 = if (J != 0) M_G + 12 else M_G; - const D: u32 = D_G - 1; - const C: u32 = Y / 100; - - // Rata die. - const y_star: u32 = 1461 * Y / 4 - C + C / 4; - const m_star: u32 = (979 * M - 2919) / 32; - const N: u32 = y_star + m_star + D; - const days: i32 = @intCast(N - K); - - return @as(i64, days) * std.time.epoch.secs_per_day; - } -}; - -const Time = struct { - hour: std.math.IntFittingRange(0, 24), - minute: std.math.IntFittingRange(0, 60), - second: std.math.IntFittingRange(0, 60), - - fn toSec(t: Time) i64 { - var sec: i64 = 0; - sec += @as(i64, t.hour) * 60 * 60; - sec += @as(i64, t.minute) * 60; - sec += t.second; - return sec; - } -}; - -fn parseTimeDigits( - text: *const [2]u8, - min: comptime_int, - max: comptime_int, -) !std.math.IntFittingRange(min, max) { - const result = std.fmt.parseInt(std.math.IntFittingRange(min, max), text, 10) catch - return error.InvalidTime; - if (result < min) return error.InvalidTime; - if (result > max) return error.InvalidTime; - return result; -} - -test parseTimeDigits { - const expectEqual = std.testing.expectEqual; - try expectEqual(@as(u8, 0), try parseTimeDigits("00", 0, 99)); - try expectEqual(@as(u8, 99), try parseTimeDigits("99", 0, 99)); - try expectEqual(@as(u8, 42), try parseTimeDigits("42", 0, 99)); - - const expectError = std.testing.expectError; - try expectError(error.InvalidTime, parseTimeDigits("13", 1, 12)); - try expectError(error.InvalidTime, parseTimeDigits("00", 1, 12)); - try expectError(error.InvalidTime, parseTimeDigits("Di", 0, 99)); -} - -fn parseYear4(text: *const [4]u8) !Date.Year { - const result = std.fmt.parseInt(Date.Year, text, 10) catch return error.InvalidYear; - if (result > 9999) return error.InvalidYear; - return result; -} - -test parseYear4 { - const expectEqual = std.testing.expectEqual; - try expectEqual(@as(Date.Year, 0), try parseYear4("0000")); - try expectEqual(@as(Date.Year, 9999), try parseYear4("9999")); - try expectEqual(@as(Date.Year, 1988), try parseYear4("1988")); - - const expectError = std.testing.expectError; - try expectError(error.InvalidYear, parseYear4("999b")); - try expectError(error.InvalidYear, parseYear4("crap")); - try expectError(error.InvalidYear, parseYear4("r:bQ")); -} - -fn parseTime(bytes: *const [6]u8) !Time { - return .{ - .hour = try parseTimeDigits(bytes[0..2], 0, 23), - .minute = try parseTimeDigits(bytes[2..4], 0, 59), - .second = try parseTimeDigits(bytes[4..6], 0, 59), - }; -} diff --git a/src/http/async/tls.zig/rsa/oid.zig b/src/http/async/tls.zig/rsa/oid.zig deleted file mode 100644 index fd360c3f..00000000 --- a/src/http/async/tls.zig/rsa/oid.zig +++ /dev/null @@ -1,132 +0,0 @@ -//! Developed by ITU-U and ISO/IEC for naming objects. Used in DER. -//! -//! This implementation supports any number of `u32` arcs. - -const Arc = u32; -const encoding_base = 128; - -/// Returns encoded length. -pub fn encodeLen(dot_notation: []const u8) !usize { - var split = std.mem.splitScalar(u8, dot_notation, '.'); - if (split.next() == null) return 0; - if (split.next() == null) return 1; - - var res: usize = 1; - while (split.next()) |s| { - const parsed = try std.fmt.parseUnsigned(Arc, s, 10); - const n_bytes = if (parsed == 0) 0 else std.math.log(Arc, encoding_base, parsed); - - res += n_bytes; - res += 1; - } - - return res; -} - -pub const EncodeError = std.fmt.ParseIntError || error{ - MissingPrefix, - BufferTooSmall, -}; - -pub fn encode(dot_notation: []const u8, buf: []u8) EncodeError![]const u8 { - if (buf.len < try encodeLen(dot_notation)) return error.BufferTooSmall; - - var split = std.mem.splitScalar(u8, dot_notation, '.'); - const first_str = split.next() orelse return error.MissingPrefix; - const second_str = split.next() orelse return error.MissingPrefix; - - const first = try std.fmt.parseInt(u8, first_str, 10); - const second = try std.fmt.parseInt(u8, second_str, 10); - - buf[0] = first * 40 + second; - - var i: usize = 1; - while (split.next()) |s| { - var parsed = try std.fmt.parseUnsigned(Arc, s, 10); - const n_bytes = if (parsed == 0) 0 else std.math.log(Arc, encoding_base, parsed); - - for (0..n_bytes) |j| { - const place = std.math.pow(Arc, encoding_base, n_bytes - @as(Arc, @intCast(j))); - const digit: u8 = @intCast(@divFloor(parsed, place)); - - buf[i] = digit | 0x80; - parsed -= digit * place; - - i += 1; - } - buf[i] = @intCast(parsed); - i += 1; - } - - return buf[0..i]; -} - -pub fn decode(encoded: []const u8, writer: anytype) @TypeOf(writer).Error!void { - const first = @divTrunc(encoded[0], 40); - const second = encoded[0] - first * 40; - try writer.print("{d}.{d}", .{ first, second }); - - var i: usize = 1; - while (i != encoded.len) { - const n_bytes: usize = brk: { - var res: usize = 1; - var j: usize = i; - while (encoded[j] & 0x80 != 0) { - res += 1; - j += 1; - } - break :brk res; - }; - - var n: usize = 0; - for (0..n_bytes) |j| { - const place = std.math.pow(usize, encoding_base, n_bytes - j - 1); - n += place * (encoded[i] & 0b01111111); - i += 1; - } - try writer.print(".{d}", .{n}); - } -} - -pub fn encodeComptime(comptime dot_notation: []const u8) [encodeLen(dot_notation) catch unreachable]u8 { - @setEvalBranchQuota(10_000); - var buf: [encodeLen(dot_notation) catch unreachable]u8 = undefined; - _ = encode(dot_notation, &buf) catch unreachable; - return buf; -} - -const std = @import("std"); - -fn testOid(expected_encoded: []const u8, expected_dot_notation: []const u8) !void { - var buf: [256]u8 = undefined; - const encoded = try encode(expected_dot_notation, &buf); - try std.testing.expectEqualSlices(u8, expected_encoded, encoded); - - var stream = std.io.fixedBufferStream(&buf); - try decode(expected_encoded, stream.writer()); - try std.testing.expectEqualStrings(expected_dot_notation, stream.getWritten()); -} - -test "encode and decode" { - // https://learn.microsoft.com/en-us/windows/win32/seccertenroll/about-object-identifier - try testOid( - &[_]u8{ 0x2b, 0x06, 0x01, 0x04, 0x01, 0x82, 0x37, 0x15, 0x14 }, - "1.3.6.1.4.1.311.21.20", - ); - // https://luca.ntop.org/Teaching/Appunti/asn1.html - try testOid(&[_]u8{ 0x2a, 0x86, 0x48, 0x86, 0xf7, 0x0d }, "1.2.840.113549"); - // https://www.sysadmins.lv/blog-en/how-to-encode-object-identifier-to-an-asn1-der-encoded-string.aspx - try testOid(&[_]u8{ 0x2a, 0x86, 0x8d, 0x20 }, "1.2.100000"); - try testOid( - &[_]u8{ 0x2a, 0x86, 0x48, 0x86, 0xf7, 0x0d, 0x01, 0x01, 0x0b }, - "1.2.840.113549.1.1.11", - ); - try testOid(&[_]u8{ 0x2b, 0x65, 0x70 }, "1.3.101.112"); -} - -test encodeComptime { - try std.testing.expectEqual( - [_]u8{ 0x2b, 0x06, 0x01, 0x04, 0x01, 0x82, 0x37, 0x15, 0x14 }, - encodeComptime("1.3.6.1.4.1.311.21.20"), - ); -} diff --git a/src/http/async/tls.zig/rsa/rsa.zig b/src/http/async/tls.zig/rsa/rsa.zig deleted file mode 100644 index 5e5f42fe..00000000 --- a/src/http/async/tls.zig/rsa/rsa.zig +++ /dev/null @@ -1,880 +0,0 @@ -//! RFC8017: Public Key Cryptography Standards #1 v2.2 (PKCS1) -const std = @import("std"); -const der = @import("der.zig"); -const ff = std.crypto.ff; - -pub const max_modulus_bits = 4096; -const max_modulus_len = max_modulus_bits / 8; - -const Modulus = std.crypto.ff.Modulus(max_modulus_bits); -const Fe = Modulus.Fe; - -pub const ValueError = error{ - Modulus, - Exponent, -}; - -pub const PublicKey = struct { - /// `n` - modulus: Modulus, - /// `e` - public_exponent: Fe, - - pub const FromBytesError = ValueError || ff.OverflowError || ff.FieldElementError || ff.InvalidModulusError || error{InsecureBitCount}; - - pub fn fromBytes(mod: []const u8, exp: []const u8) FromBytesError!PublicKey { - const modulus = try Modulus.fromBytes(mod, .big); - if (modulus.bits() <= 512) return error.InsecureBitCount; - const public_exponent = try Fe.fromBytes(modulus, exp, .big); - - if (std.debug.runtime_safety) { - // > the RSA public exponent e is an integer between 3 and n - 1 satisfying - // > GCD(e,\lambda(n)) = 1, where \lambda(n) = LCM(r_1 - 1, ..., r_u - 1) - const e_v = public_exponent.toPrimitive(u32) catch return error.Exponent; - if (!public_exponent.isOdd()) return error.Exponent; - if (e_v < 3) return error.Exponent; - if (modulus.v.compare(public_exponent.v) == .lt) return error.Exponent; - } - - return .{ .modulus = modulus, .public_exponent = public_exponent }; - } - - pub fn fromDer(bytes: []const u8) (der.Parser.Error || FromBytesError)!PublicKey { - var parser = der.Parser{ .bytes = bytes }; - - const seq = try parser.expectSequence(); - defer parser.seek(seq.slice.end); - - const modulus = try parser.expectPrimitive(.integer); - const pub_exp = try parser.expectPrimitive(.integer); - - try parser.expectEnd(seq.slice.end); - try parser.expectEnd(bytes.len); - - return try fromBytes(parser.view(modulus), parser.view(pub_exp)); - } - - /// Deprecated. - /// - /// Encrypt a short message using RSAES-PKCS1-v1_5. - /// The use of this scheme for encrypting an arbitrary message, as opposed to a - /// randomly generated key, is NOT RECOMMENDED. - pub fn encryptPkcsv1_5(pk: PublicKey, msg: []const u8, out: []u8) ![]const u8 { - // align variable names with spec - const k = byteLen(pk.modulus.bits()); - if (out.len < k) return error.BufferTooSmall; - if (msg.len > k - 11) return error.MessageTooLong; - - // EM = 0x00 || 0x02 || PS || 0x00 || M. - var em = out[0..k]; - em[0] = 0; - em[1] = 2; - - const ps = em[2..][0 .. k - msg.len - 3]; - // Section: 7.2.1 - // PS consists of pseudo-randomly generated nonzero octets. - for (ps) |*v| { - v.* = std.crypto.random.uintLessThan(u8, 0xff) + 1; - } - - em[em.len - msg.len - 1] = 0; - @memcpy(em[em.len - msg.len ..][0..msg.len], msg); - - const m = try Fe.fromBytes(pk.modulus, em, .big); - const e = try pk.modulus.powPublic(m, pk.public_exponent); - try e.toBytes(em, .big); - return em; - } - - /// Encrypt a short message using Optimal Asymmetric Encryption Padding (RSAES-OAEP). - pub fn encryptOaep( - pk: PublicKey, - comptime Hash: type, - msg: []const u8, - label: []const u8, - out: []u8, - ) ![]const u8 { - // align variable names with spec - const k = byteLen(pk.modulus.bits()); - if (out.len < k) return error.BufferTooSmall; - - if (msg.len > k - 2 * Hash.digest_length - 2) return error.MessageTooLong; - - // EM = 0x00 || maskedSeed || maskedDB. - var em = out[0..k]; - em[0] = 0; - const seed = em[1..][0..Hash.digest_length]; - std.crypto.random.bytes(seed); - - // DB = lHash || PS || 0x01 || M. - var db = em[1 + seed.len ..]; - const lHash = labelHash(Hash, label); - @memcpy(db[0..lHash.len], &lHash); - @memset(db[lHash.len .. db.len - msg.len - 2], 0); - db[db.len - msg.len - 1] = 1; - @memcpy(db[db.len - msg.len ..], msg); - - var mgf_buf: [max_modulus_len]u8 = undefined; - - const db_mask = mgf1(Hash, seed, mgf_buf[0..db.len]); - for (db, db_mask) |*v, m| v.* ^= m; - - const seed_mask = mgf1(Hash, db, mgf_buf[0..seed.len]); - for (seed, seed_mask) |*v, m| v.* ^= m; - - const m = try Fe.fromBytes(pk.modulus, em, .big); - const e = try pk.modulus.powPublic(m, pk.public_exponent); - try e.toBytes(em, .big); - return em; - } -}; - -pub fn byteLen(bits: usize) usize { - return std.math.divCeil(usize, bits, 8) catch unreachable; -} - -pub const SecretKey = struct { - /// `d` - private_exponent: Fe, - - pub const FromBytesError = ValueError || ff.OverflowError || ff.FieldElementError; - - pub fn fromBytes(n: Modulus, exp: []const u8) FromBytesError!SecretKey { - const d = try Fe.fromBytes(n, exp, .big); - if (std.debug.runtime_safety) { - // > The RSA private exponent d is a positive integer less than n - // > satisfying e * d == 1 (mod \lambda(n)), - if (!d.isOdd()) return error.Exponent; - if (d.v.compare(n.v) != .lt) return error.Exponent; - } - - return .{ .private_exponent = d }; - } -}; - -pub const KeyPair = struct { - public: PublicKey, - secret: SecretKey, - - pub const FromDerError = PublicKey.FromBytesError || SecretKey.FromBytesError || der.Parser.Error || error{ KeyMismatch, InvalidVersion }; - - pub fn fromDer(bytes: []const u8) FromDerError!KeyPair { - var parser = der.Parser{ .bytes = bytes }; - const seq = try parser.expectSequence(); - const version = try parser.expectInt(u8); - - const mod = try parser.expectPrimitive(.integer); - const pub_exp = try parser.expectPrimitive(.integer); - const sec_exp = try parser.expectPrimitive(.integer); - - const public = try PublicKey.fromBytes(parser.view(mod), parser.view(pub_exp)); - const secret = try SecretKey.fromBytes(public.modulus, parser.view(sec_exp)); - - const prime1 = try parser.expectPrimitive(.integer); - const prime2 = try parser.expectPrimitive(.integer); - const exp1 = try parser.expectPrimitive(.integer); - const exp2 = try parser.expectPrimitive(.integer); - const coeff = try parser.expectPrimitive(.integer); - _ = .{ exp1, exp2, coeff }; - - switch (version) { - 0 => {}, - 1 => { - _ = try parser.expectSequenceOf(); - while (!parser.eof()) { - _ = try parser.expectSequence(); - const ri = try parser.expectPrimitive(.integer); - const di = try parser.expectPrimitive(.integer); - const ti = try parser.expectPrimitive(.integer); - _ = .{ ri, di, ti }; - } - }, - else => return error.InvalidVersion, - } - - try parser.expectEnd(seq.slice.end); - try parser.expectEnd(bytes.len); - - if (std.debug.runtime_safety) { - const p = try Fe.fromBytes(public.modulus, parser.view(prime1), .big); - const q = try Fe.fromBytes(public.modulus, parser.view(prime2), .big); - - // check that n = p * q - const expected_zero = public.modulus.mul(p, q); - if (!expected_zero.isZero()) return error.KeyMismatch; - - // TODO: check that d * e is one mod p-1 and mod q-1. Note d and e were bound - // const de = secret.private_exponent.mul(public.public_exponent); - // const one = public.modulus.one(); - - // if (public.modulus.mul(de, p).compare(one) != .eq) return error.KeyMismatch; - // if (public.modulus.mul(de, q).compare(one) != .eq) return error.KeyMismatch; - } - - return .{ .public = public, .secret = secret }; - } - - /// Deprecated. - pub fn signPkcsv1_5(kp: KeyPair, comptime Hash: type, msg: []const u8, out: []u8) !PKCS1v1_5(Hash).Signature { - var st = try signerPkcsv1_5(kp, Hash); - st.update(msg); - return try st.finalize(out); - } - - /// Deprecated. - pub fn signerPkcsv1_5(kp: KeyPair, comptime Hash: type) !PKCS1v1_5(Hash).Signer { - return PKCS1v1_5(Hash).Signer.init(kp); - } - - /// Deprecated. - pub fn decryptPkcsv1_5(kp: KeyPair, ciphertext: []const u8, out: []u8) ![]const u8 { - const k = byteLen(kp.public.modulus.bits()); - if (out.len < k) return error.BufferTooSmall; - - const em = out[0..k]; - - const m = try Fe.fromBytes(kp.public.modulus, ciphertext, .big); - const e = try kp.public.modulus.pow(m, kp.secret.private_exponent); - try e.toBytes(em, .big); - - // Care shall be taken to ensure that an opponent cannot - // distinguish these error conditions, whether by error - // message or timing. - const msg_start = ct.lastIndexOfScalar(em, 0) orelse em.len; - const ps_len = em.len - msg_start; - if (ct.@"or"(em[0] != 0, ct.@"or"(em[1] != 2, ps_len < 8))) { - return error.Inconsistent; - } - - return em[msg_start + 1 ..]; - } - - pub fn signOaep( - kp: KeyPair, - comptime Hash: type, - msg: []const u8, - salt: ?[]const u8, - out: []u8, - ) !Pss(Hash).Signature { - var st = try signerOaep(kp, Hash, salt); - st.update(msg); - return try st.finalize(out); - } - - /// Salt must outlive returned `PSS.Signer`. - pub fn signerOaep(kp: KeyPair, comptime Hash: type, salt: ?[]const u8) !Pss(Hash).Signer { - return Pss(Hash).Signer.init(kp, salt); - } - - pub fn decryptOaep( - kp: KeyPair, - comptime Hash: type, - ciphertext: []const u8, - label: []const u8, - out: []u8, - ) ![]u8 { - // align variable names with spec - const k = byteLen(kp.public.modulus.bits()); - if (out.len < k) return error.BufferTooSmall; - - const mod = try Fe.fromBytes(kp.public.modulus, ciphertext, .big); - const exp = kp.public.modulus.pow(mod, kp.secret.private_exponent) catch unreachable; - const em = out[0..k]; - try exp.toBytes(em, .big); - - const y = em[0]; - const seed = em[1..][0..Hash.digest_length]; - const db = em[1 + Hash.digest_length ..]; - - var mgf_buf: [max_modulus_len]u8 = undefined; - - const seed_mask = mgf1(Hash, db, mgf_buf[0..seed.len]); - for (seed, seed_mask) |*v, m| v.* ^= m; - - const db_mask = mgf1(Hash, seed, mgf_buf[0..db.len]); - for (db, db_mask) |*v, m| v.* ^= m; - - const expected_hash = labelHash(Hash, label); - const actual_hash = db[0..expected_hash.len]; - - // Care shall be taken to ensure that an opponent cannot - // distinguish these error conditions, whether by error - // message or timing. - const msg_start = ct.indexOfScalarPos(em, expected_hash.len + 1, 1) orelse 0; - if (ct.@"or"(y != 0, ct.@"or"(msg_start == 0, !ct.memEql(&expected_hash, actual_hash)))) { - return error.Inconsistent; - } - - return em[msg_start + 1 ..]; - } - - /// Encrypt short plaintext with secret key. - pub fn encrypt(kp: KeyPair, plaintext: []const u8, out: []u8) !void { - const n = kp.public.modulus; - const k = byteLen(n.bits()); - if (plaintext.len > k) return error.MessageTooLong; - - const msg_as_int = try Fe.fromBytes(n, plaintext, .big); - const enc_as_int = try n.pow(msg_as_int, kp.secret.private_exponent); - try enc_as_int.toBytes(out, .big); - } -}; - -/// Deprecated. -/// -/// Signature Scheme with Appendix v1.5 (RSASSA-PKCS1-v1_5) -/// -/// This standard has been superceded by PSS which is formally proven secure -/// and has fewer footguns. -pub fn PKCS1v1_5(comptime Hash: type) type { - return struct { - const PkcsT = @This(); - pub const Signature = struct { - bytes: []const u8, - - const Self = @This(); - - pub fn verifier(self: Self, public_key: PublicKey) !Verifier { - return Verifier.init(self, public_key); - } - - pub fn verify(self: Self, msg: []const u8, public_key: PublicKey) !void { - var st = Verifier.init(self, public_key); - st.update(msg); - return st.verify(); - } - }; - - pub const Signer = struct { - h: Hash, - key_pair: KeyPair, - - fn init(key_pair: KeyPair) Signer { - return .{ - .h = Hash.init(.{}), - .key_pair = key_pair, - }; - } - - pub fn update(self: *Signer, data: []const u8) void { - self.h.update(data); - } - - pub fn finalize(self: *Signer, out: []u8) !PkcsT.Signature { - const k = byteLen(self.key_pair.public.modulus.bits()); - if (out.len < k) return error.BufferTooSmall; - - var hash: [Hash.digest_length]u8 = undefined; - self.h.final(&hash); - - const em = try emsaEncode(hash, out[0..k]); - try self.key_pair.encrypt(em, em); - return .{ .bytes = em }; - } - }; - - pub const Verifier = struct { - h: Hash, - sig: PkcsT.Signature, - public_key: PublicKey, - - fn init(sig: PkcsT.Signature, public_key: PublicKey) Verifier { - return Verifier{ - .h = Hash.init(.{}), - .sig = sig, - .public_key = public_key, - }; - } - - pub fn update(self: *Verifier, data: []const u8) void { - self.h.update(data); - } - - pub fn verify(self: *Verifier) !void { - const pk = self.public_key; - const s = try Fe.fromBytes(pk.modulus, self.sig.bytes, .big); - const emm = try pk.modulus.powPublic(s, pk.public_exponent); - - var em_buf: [max_modulus_len]u8 = undefined; - const em = em_buf[0..byteLen(pk.modulus.bits())]; - try emm.toBytes(em, .big); - - var hash: [Hash.digest_length]u8 = undefined; - self.h.final(&hash); - - // TODO: compare hash values instead of emsa values - const expected = try emsaEncode(hash, em); - - if (!std.mem.eql(u8, expected, em)) return error.Inconsistent; - } - }; - - /// PKCS Encrypted Message Signature Appendix - fn emsaEncode(hash: [Hash.digest_length]u8, out: []u8) ![]u8 { - const digest_header = comptime digestHeader(); - const tLen = digest_header.len + Hash.digest_length; - const emLen = out.len; - if (emLen < tLen + 11) return error.ModulusTooShort; - if (out.len < emLen) return error.BufferTooSmall; - - var res = out[0..emLen]; - res[0] = 0; - res[1] = 1; - const padding_len = emLen - tLen - 3; - @memset(res[2..][0..padding_len], 0xff); - res[2 + padding_len] = 0; - @memcpy(res[2 + padding_len + 1 ..][0..digest_header.len], digest_header); - @memcpy(res[res.len - hash.len ..], &hash); - - return res; - } - - /// DER encoded header. Sequence of digest algo + digest. - /// TODO: use a DER encoder instead - fn digestHeader() []const u8 { - const sha2 = std.crypto.hash.sha2; - // Section 9.2 Notes 1. - return switch (Hash) { - std.crypto.hash.Sha1 => &hexToBytes( - \\30 21 30 09 06 05 2b 0e 03 02 1a 05 00 04 14 - ), - sha2.Sha224 => &hexToBytes( - \\30 2d 30 0d 06 09 60 86 48 01 65 03 04 02 04 - \\05 00 04 1c - ), - sha2.Sha256 => &hexToBytes( - \\30 31 30 0d 06 09 60 86 48 01 65 03 04 02 01 05 00 - \\04 20 - ), - sha2.Sha384 => &hexToBytes( - \\30 41 30 0d 06 09 60 86 48 01 65 03 04 02 02 05 00 - \\04 30 - ), - sha2.Sha512 => &hexToBytes( - \\30 51 30 0d 06 09 60 86 48 01 65 03 04 02 03 05 00 - \\04 40 - ), - // sha2.Sha512224 => &hexToBytes( - // \\30 2d 30 0d 06 09 60 86 48 01 65 03 04 02 05 - // \\05 00 04 1c - // ), - // sha2.Sha512256 => &hexToBytes( - // \\30 31 30 0d 06 09 60 86 48 01 65 03 04 02 06 - // \\05 00 04 20 - // ), - else => @compileError("unknown Hash " ++ @typeName(Hash)), - }; - } - }; -} - -/// Probabilistic Signature Scheme (RSASSA-PSS) -pub fn Pss(comptime Hash: type) type { - // RFC 4055 S3.1 - const default_salt_len = Hash.digest_length; - return struct { - pub const Signature = struct { - bytes: []const u8, - - const Self = @This(); - - pub fn verifier(self: Self, public_key: PublicKey) !Verifier { - return Verifier.init(self, public_key); - } - - pub fn verify(self: Self, msg: []const u8, public_key: PublicKey, salt_len: ?usize) !void { - var st = Verifier.init(self, public_key, salt_len orelse default_salt_len); - st.update(msg); - return st.verify(); - } - }; - - const PssT = @This(); - - pub const Signer = struct { - h: Hash, - key_pair: KeyPair, - salt: ?[]const u8, - - fn init(key_pair: KeyPair, salt: ?[]const u8) Signer { - return .{ - .h = Hash.init(.{}), - .key_pair = key_pair, - .salt = salt, - }; - } - - pub fn update(self: *Signer, data: []const u8) void { - self.h.update(data); - } - - pub fn finalize(self: *Signer, out: []u8) !PssT.Signature { - var hashed: [Hash.digest_length]u8 = undefined; - self.h.final(&hashed); - - const salt = if (self.salt) |s| s else brk: { - var res: [default_salt_len]u8 = undefined; - std.crypto.random.bytes(&res); - break :brk &res; - }; - - const em_bits = self.key_pair.public.modulus.bits() - 1; - const em = try emsaEncode(hashed, salt, em_bits, out); - try self.key_pair.encrypt(em, em); - return .{ .bytes = em }; - } - }; - - pub const Verifier = struct { - h: Hash, - sig: PssT.Signature, - public_key: PublicKey, - salt_len: usize, - - fn init(sig: PssT.Signature, public_key: PublicKey, salt_len: usize) Verifier { - return Verifier{ - .h = Hash.init(.{}), - .sig = sig, - .public_key = public_key, - .salt_len = salt_len, - }; - } - - pub fn update(self: *Verifier, data: []const u8) void { - self.h.update(data); - } - - pub fn verify(self: *Verifier) !void { - const pk = self.public_key; - const s = try Fe.fromBytes(pk.modulus, self.sig.bytes, .big); - const emm = try pk.modulus.powPublic(s, pk.public_exponent); - - var em_buf: [max_modulus_len]u8 = undefined; - const em_bits = pk.modulus.bits() - 1; - const em_len = std.math.divCeil(usize, em_bits, 8) catch unreachable; - var em = em_buf[0..em_len]; - try emm.toBytes(em, .big); - - if (em.len < Hash.digest_length + self.salt_len + 2) return error.Inconsistent; - if (em[em.len - 1] != 0xbc) return error.Inconsistent; - - const db = em[0 .. em.len - Hash.digest_length - 1]; - if (@clz(db[0]) < em.len * 8 - em_bits) return error.Inconsistent; - - const expected_hash = em[db.len..][0..Hash.digest_length]; - var mgf_buf: [max_modulus_len]u8 = undefined; - const db_mask = mgf1(Hash, expected_hash, mgf_buf[0..db.len]); - for (db, db_mask) |*v, m| v.* ^= m; - - for (1..db.len - self.salt_len - 1) |i| { - if (db[i] != 0) return error.Inconsistent; - } - if (db[db.len - self.salt_len - 1] != 1) return error.Inconsistent; - const salt = db[db.len - self.salt_len ..]; - var mp_buf: [max_modulus_len]u8 = undefined; - var mp = mp_buf[0 .. 8 + Hash.digest_length + self.salt_len]; - @memset(mp[0..8], 0); - self.h.final(mp[8..][0..Hash.digest_length]); - @memcpy(mp[8 + Hash.digest_length ..][0..salt.len], salt); - - var actual_hash: [Hash.digest_length]u8 = undefined; - Hash.hash(mp, &actual_hash, .{}); - - if (!std.mem.eql(u8, expected_hash, &actual_hash)) return error.Inconsistent; - } - }; - - /// PSS Encrypted Message Signature Appendix - fn emsaEncode(msg_hash: [Hash.digest_length]u8, salt: []const u8, em_bits: usize, out: []u8) ![]u8 { - const em_len = std.math.divCeil(usize, em_bits, 8) catch unreachable; - - if (em_len < Hash.digest_length + salt.len + 2) return error.Encoding; - - // EM = maskedDB || H || 0xbc - var em = out[0..em_len]; - em[em.len - 1] = 0xbc; - - var mp_buf: [max_modulus_len]u8 = undefined; - // M' = (0x)00 00 00 00 00 00 00 00 || mHash || salt; - const mp = mp_buf[0 .. 8 + Hash.digest_length + salt.len]; - @memset(mp[0..8], 0); - @memcpy(mp[8..][0..Hash.digest_length], &msg_hash); - @memcpy(mp[8 + Hash.digest_length ..][0..salt.len], salt); - - // H = Hash(M') - const hash = em[em.len - 1 - Hash.digest_length ..][0..Hash.digest_length]; - Hash.hash(mp, hash, .{}); - - // DB = PS || 0x01 || salt - var db = em[0 .. em_len - Hash.digest_length - 1]; - @memset(db[0 .. db.len - salt.len - 1], 0); - db[db.len - salt.len - 1] = 1; - @memcpy(db[db.len - salt.len ..], salt); - - var mgf_buf: [max_modulus_len]u8 = undefined; - const db_mask = mgf1(Hash, hash, mgf_buf[0..db.len]); - for (db, db_mask) |*v, m| v.* ^= m; - - // Set the leftmost 8emLen - emBits bits of the leftmost octet - // in maskedDB to zero. - const shift = std.math.comptimeMod(8 * em_len - em_bits, 8); - const mask = @as(u8, 0xff) >> shift; - db[0] &= mask; - - return em; - } - }; -} - -/// Mask generation function. Currently the only one defined. -fn mgf1(comptime Hash: type, seed: []const u8, out: []u8) []u8 { - var c: [@sizeOf(u32)]u8 = undefined; - var tmp: [Hash.digest_length]u8 = undefined; - - var i: usize = 0; - var counter: u32 = 0; - while (i < out.len) : (counter += 1) { - var hasher = Hash.init(.{}); - hasher.update(seed); - std.mem.writeInt(u32, &c, counter, .big); - hasher.update(&c); - - const left = out.len - i; - if (left >= Hash.digest_length) { - // optimization: write straight to `out` - hasher.final(out[i..][0..Hash.digest_length]); - i += Hash.digest_length; - } else { - hasher.final(&tmp); - @memcpy(out[i..][0..left], tmp[0..left]); - i += left; - } - } - - return out; -} - -test mgf1 { - const Hash = std.crypto.hash.sha2.Sha256; - var out: [Hash.digest_length * 2 + 1]u8 = undefined; - try std.testing.expectEqualSlices( - u8, - &hexToBytes( - \\ed 1b 84 6b b9 26 39 00 c8 17 82 ad 08 eb 17 01 - \\fa 8c 72 21 c6 57 63 77 31 7f 5c e8 09 89 9f - ), - mgf1(Hash, "asdf", out[0 .. Hash.digest_length - 1]), - ); - try std.testing.expectEqualSlices( - u8, - &hexToBytes( - \\ed 1b 84 6b b9 26 39 00 c8 17 82 ad 08 eb 17 01 - \\fa 8c 72 21 c6 57 63 77 31 7f 5c e8 09 89 9f 5a - \\22 F2 80 D5 28 08 F4 93 83 76 00 DE 09 E4 EC 92 - \\4A 2C 7C EF 0D F7 7B BE 8F 7F 12 CB 8F 33 A6 65 - \\AB - ), - mgf1(Hash, "asdf", &out), - ); -} - -/// For OAEP. -inline fn labelHash(comptime Hash: type, label: []const u8) [Hash.digest_length]u8 { - if (label.len == 0) { - // magic constants from NIST - const sha2 = std.crypto.hash.sha2; - switch (Hash) { - std.crypto.hash.Sha1 => return hexToBytes( - \\da39a3ee 5e6b4b0d 3255bfef 95601890 - \\afd80709 - ), - sha2.Sha256 => return hexToBytes( - \\e3b0c442 98fc1c14 9afbf4c8 996fb924 - \\27ae41e4 649b934c a495991b 7852b855 - ), - sha2.Sha384 => return hexToBytes( - \\38b060a7 51ac9638 4cd9327e b1b1e36a - \\21fdb711 14be0743 4c0cc7bf 63f6e1da - \\274edebf e76f65fb d51ad2f1 4898b95b - ), - sha2.Sha512 => return hexToBytes( - \\cf83e135 7eefb8bd f1542850 d66d8007 - \\d620e405 0b5715dc 83f4a921 d36ce9ce - \\47d0d13c 5d85f2b0 ff8318d2 877eec2f - \\63b931bd 47417a81 a538327a f927da3e - ), - // just use the empty hash... - else => {}, - } - } - var res: [Hash.digest_length]u8 = undefined; - Hash.hash(label, &res, .{}); - return res; -} - -const ct = if (std.options.side_channels_mitigations == .none) ct_unprotected else ct_protected; - -const ct_unprotected = struct { - fn lastIndexOfScalar(slice: []const u8, value: u8) ?usize { - return std.mem.lastIndexOfScalar(u8, slice, value); - } - - fn indexOfScalarPos(slice: []const u8, start_index: usize, value: u8) ?usize { - return std.mem.indexOfScalarPos(u8, slice, start_index, value); - } - - fn memEql(a: []const u8, b: []const u8) bool { - return std.mem.eql(u8, a, b); - } - - fn @"and"(a: bool, b: bool) bool { - return a and b; - } - - fn @"or"(a: bool, b: bool) bool { - return a or b; - } -}; - -const ct_protected = struct { - fn lastIndexOfScalar(slice: []const u8, value: u8) ?usize { - var res: ?usize = null; - var i: usize = slice.len; - while (i != 0) { - i -= 1; - if (@intFromBool(res == null) & @intFromBool(slice[i] == value) == 1) res = i; - } - return res; - } - - fn indexOfScalarPos(slice: []const u8, start_index: usize, value: u8) ?usize { - var res: ?usize = null; - for (slice[start_index..], start_index..) |c, j| { - if (c == value) res = j; - } - return res; - } - - fn memEql(a: []const u8, b: []const u8) bool { - var res: u1 = 1; - for (a, b) |a_elem, b_elem| { - res &= @intFromBool(a_elem == b_elem); - } - return res == 1; - } - - fn @"and"(a: bool, b: bool) bool { - return (@intFromBool(a) & @intFromBool(b)) == 1; - } - - fn @"or"(a: bool, b: bool) bool { - return (@intFromBool(a) | @intFromBool(b)) == 1; - } -}; - -test ct { - const c = ct_unprotected; - try std.testing.expectEqual(true, c.@"or"(true, false)); - try std.testing.expectEqual(true, c.@"and"(true, true)); - try std.testing.expectEqual(true, c.memEql("Asdf", "Asdf")); - try std.testing.expectEqual(false, c.memEql("asdf", "Asdf")); - try std.testing.expectEqual(3, c.indexOfScalarPos("asdff", 1, 'f')); - try std.testing.expectEqual(4, c.lastIndexOfScalar("asdff", 'f')); -} - -fn removeNonHex(comptime hex: []const u8) []const u8 { - var res: [hex.len]u8 = undefined; - var i: usize = 0; - for (hex) |c| { - if (std.ascii.isHex(c)) { - res[i] = c; - i += 1; - } - } - return res[0..i]; -} - -/// For readable copy/pasting from hex viewers. -fn hexToBytes(comptime hex: []const u8) [removeNonHex(hex).len / 2]u8 { - const hex2 = comptime removeNonHex(hex); - comptime var res: [hex2.len / 2]u8 = undefined; - _ = comptime std.fmt.hexToBytes(&res, hex2) catch unreachable; - return res; -} - -test hexToBytes { - const hex = - \\e3b0c442 98fc1c14 9afbf4c8 996fb924 - \\27ae41e4 649b934c a495991b 7852b855 - ; - try std.testing.expectEqual( - [_]u8{ - 0xe3, 0xb0, 0xc4, 0x42, 0x98, 0xfc, 0x1c, 0x14, - 0x9a, 0xfb, 0xf4, 0xc8, 0x99, 0x6f, 0xb9, 0x24, - 0x27, 0xae, 0x41, 0xe4, 0x64, 0x9b, 0x93, 0x4c, - 0xa4, 0x95, 0x99, 0x1b, 0x78, 0x52, 0xb8, 0x55, - }, - hexToBytes(hex), - ); -} - -const TestHash = std.crypto.hash.sha2.Sha256; -fn testKeypair() !KeyPair { - const keypair_bytes = @embedFile("testdata/id_rsa.der"); - const kp = try KeyPair.fromDer(keypair_bytes); - try std.testing.expectEqual(2048, kp.public.modulus.bits()); - return kp; -} - -test "rsa PKCS1-v1_5 encrypt and decrypt" { - const kp = try testKeypair(); - - const msg = "rsa PKCS1-v1_5 encrypt and decrypt"; - var out: [max_modulus_len]u8 = undefined; - const enc = try kp.public.encryptPkcsv1_5(msg, &out); - - var out2: [max_modulus_len]u8 = undefined; - const dec = try kp.decryptPkcsv1_5(enc, &out2); - - try std.testing.expectEqualSlices(u8, msg, dec); -} - -test "rsa OAEP encrypt and decrypt" { - const kp = try testKeypair(); - - const msg = "rsa OAEP encrypt and decrypt"; - const label = ""; - var out: [max_modulus_len]u8 = undefined; - const enc = try kp.public.encryptOaep(TestHash, msg, label, &out); - - var out2: [max_modulus_len]u8 = undefined; - const dec = try kp.decryptOaep(TestHash, enc, label, &out2); - - try std.testing.expectEqualSlices(u8, msg, dec); -} - -test "rsa PKCS1-v1_5 signature" { - const kp = try testKeypair(); - - const msg = "rsa PKCS1-v1_5 signature"; - var out: [max_modulus_len]u8 = undefined; - - const signature = try kp.signPkcsv1_5(TestHash, msg, &out); - try signature.verify(msg, kp.public); -} - -test "rsa PSS signature" { - const kp = try testKeypair(); - - const msg = "rsa PSS signature"; - var out: [max_modulus_len]u8 = undefined; - - const salts = [_][]const u8{ "asdf", "" }; - for (salts) |salt| { - const signature = try kp.signOaep(TestHash, msg, salt, &out); - try signature.verify(msg, kp.public, salt.len); - } - - const signature = try kp.signOaep(TestHash, msg, null, &out); // random salt - try signature.verify(msg, kp.public, null); -} diff --git a/src/http/async/tls.zig/rsa/testdata/id_rsa.der b/src/http/async/tls.zig/rsa/testdata/id_rsa.der deleted file mode 100644 index 9e4f1334..00000000 Binary files a/src/http/async/tls.zig/rsa/testdata/id_rsa.der and /dev/null differ diff --git a/src/http/async/tls.zig/testdata/ec_prime256v1_private_key.pem b/src/http/async/tls.zig/testdata/ec_prime256v1_private_key.pem deleted file mode 100644 index 67ebf388..00000000 --- a/src/http/async/tls.zig/testdata/ec_prime256v1_private_key.pem +++ /dev/null @@ -1,5 +0,0 @@ ------BEGIN EC PRIVATE KEY----- -MHcCAQEEINJSRKv8kSKEzLHptfAlg+LGh4/pHHlq0XLf30Q9pcztoAoGCCqGSM49 -AwEHoUQDQgAEJpmLyp8aGCgyMcFIJaIq/+4V1K6nPpeoih3bT2npeplF9eyXj7rm -8eW9Ua6VLhq71mqtMC+YLm+IkORBVq1cuA== ------END EC PRIVATE KEY----- diff --git a/src/http/async/tls.zig/testdata/ec_private_key.pem b/src/http/async/tls.zig/testdata/ec_private_key.pem deleted file mode 100644 index 95048aaa..00000000 --- a/src/http/async/tls.zig/testdata/ec_private_key.pem +++ /dev/null @@ -1,6 +0,0 @@ ------BEGIN PRIVATE KEY----- -MIG2AgEAMBAGByqGSM49AgEGBSuBBAAiBIGeMIGbAgEBBDAQNT3KGxUdBqpxuO/z -GSJDePMgmB6xLytkfnHQMCqQquXrmcOQZT3BJhm+PwggmwGhZANiAATKxBc6kfqA -piA+Z0rIjVwaZaBNGnP4UZ5TqVewQ/dP9/BQCca2SJpsXauGLcUPmK4sKFQxGe6d -fzq9O50lo7qHEOIpwDBdhRp+oqB6sN2hMtCPbp6eyzsUlm3FUyhN9D0= ------END PRIVATE KEY----- diff --git a/src/http/async/tls.zig/testdata/ec_secp384r1_private_key.pem b/src/http/async/tls.zig/testdata/ec_secp384r1_private_key.pem deleted file mode 100644 index 62eac9ee..00000000 --- a/src/http/async/tls.zig/testdata/ec_secp384r1_private_key.pem +++ /dev/null @@ -1,6 +0,0 @@ ------BEGIN EC PRIVATE KEY----- -MIGkAgEBBDDubYpeDdOwxksyQIDiOt6LHt3ikts2HNuR6rqhBg1CLdmp3AVDKfF4 -fPkIr8UDH22gBwYFK4EEACKhZANiAARcVFUVv3bIHS6BEfLt98rtps7XP1y26m2n -v5x/5ecbDH2p7AXBYerJERKFi7ZFE1DSrSAj+KK8otjdEG44ZA2Mtl5AHwDVrKde -RgtavVoreHhLN80jJOun8JnFXQjdNsA= ------END EC PRIVATE KEY----- diff --git a/src/http/async/tls.zig/testdata/ec_secp521r1_private_key.pem b/src/http/async/tls.zig/testdata/ec_secp521r1_private_key.pem deleted file mode 100644 index 5b7f9321..00000000 --- a/src/http/async/tls.zig/testdata/ec_secp521r1_private_key.pem +++ /dev/null @@ -1,7 +0,0 @@ ------BEGIN EC PRIVATE KEY----- -MIHcAgEBBEIB8C9axyQY6mgjjC6htLjc8hGylrDsh4BCv9669JaDj5vbxmCnTNlg -OuS6C9+uJNMbwm6CoIjB7RcgDTrxxX7oCyegBwYFK4EEACOhgYkDgYYABABAT5Q8 -aOj9U0iuJE5tXfKnYTgPuvD6keHZAGJ5veM9uR6jr3BhfGubD6bnlD+cIBQzYWo0 -y/BNMzCRJ55PDCNU5gGLw+vkwhJ1lGF5OS6l2oG5WN3fe6cYo+uJD7+PB3WYNIuX -Ls0oidsEM0Q4WLblQOEP6VLGf4qTcZyhoFWYfkjWiw== ------END EC PRIVATE KEY----- diff --git a/src/http/async/tls.zig/testdata/google.com/client_random b/src/http/async/tls.zig/testdata/google.com/client_random deleted file mode 100644 index e817c906..00000000 --- a/src/http/async/tls.zig/testdata/google.com/client_random +++ /dev/null @@ -1 +0,0 @@ -'”’ßqp0x­0)ì©–Ã~Ì+Œ`‡¬tY4•©D_ \ No newline at end of file diff --git a/src/http/async/tls.zig/testdata/google.com/server_hello b/src/http/async/tls.zig/testdata/google.com/server_hello deleted file mode 100644 index 57a80765..00000000 Binary files a/src/http/async/tls.zig/testdata/google.com/server_hello and /dev/null differ diff --git a/src/http/async/tls.zig/testdata/rsa_private_key.pem b/src/http/async/tls.zig/testdata/rsa_private_key.pem deleted file mode 100644 index b8cc7a69..00000000 --- a/src/http/async/tls.zig/testdata/rsa_private_key.pem +++ /dev/null @@ -1,28 +0,0 @@ ------BEGIN PRIVATE KEY----- -MIIEvQIBADANBgkqhkiG9w0BAQEFAASCBKcwggSjAgEAAoIBAQDe9yPmdcxv3dVu -D4wJ+GLjYBvAfYzVBFAsNuI79zOfoRSvvs8aD0z1yzlwDjuX1iH3SJF5ynxo/Opi -oVpyT3hXDszyo1AF8UzKUXMQmhiOcfW0xz6+TO831IRLghzsCKPMBz1cC+WFP/62 -RHePPGovM8Nd9vIpRgQlfgXZ+DstpEBmnw1tGvq8CsWLhkMw7xQgQZ21zD5jtUgE -J8lc02IoX/W25HdJmayESqZnpZoaN8dgTLrBcM9XZEoh6gVTEOyUcUpDBIMAqloo -vPKMWBSS0oMX9HspD+eHokeyUxkSI/tLzlr4oYT5sfO/4/oQ+K2vh84DDqAsE3FX -xFVESETLAgMBAAECggEAUDuAmKqlEVAzQDKqAuB1vTpVYjQLnI+7xd1OFaQD2Jpf -VkqEPe1plT03AwKsIRw2BsT/TGM315PDSBCl+mJsfG9gAqQP5MOLDXa3wC6jTYbm -ktHr2xDWODHqFT3R6IHHZ2DnjJrfUc7QeogyucFUuH2Y/NQjGgUO8urhcikoKmi3 -kBiAHCHWNqhrSpzdFLifhe6VC/TGFwKqTepN+TnX3Z20HdL4kkYPGEGA9OonVSn4 -N1m/Q+yj6xm6vBMGlT0lS8lyz0EKb6rLedR7+rEJfOIvhVFEi8aXjkb5a6wIh5LO -rwu/jL0nUY8J5NP5BKz68gRwPtmmKBfCLXTpJUACSQKBgQDmPHEZkBC5wl8plx16 -hrwwSdJuQy0b6BYZO06gpYBOIIENULijKwZzoMYaL8zivGT3KIluEelA7+NXnCuk -NUx7LieeZ+ChIUuRLvT02H9lH11d1Va2PmBgRmUKgul26YyaxeIy3UzjPbbgUFJv -t970IRfgS8qGD9KuhdlovZlCzwKBgQD36mo4BxgO1xmr4Qq0WSgQi2QBMAP9lpE4 -Lc59UP5qvNrGXLGPsirdzz6VSeMGrrxGDyof+fGG9d0Wt0+8OMRysSuVua+SRiJ4 -ugoaCzLbsq6pzDWPXf/wzVevjKTIGh4ZXk6Qa7IHqyEmvOnvxdDsL3iZXgPcQoIF -HybqHU9NRQKBgC6tnGSJX8q5jJ+bAp//xxGnNeGi/vdEc46EBqntQ/kS//caIYT7 -SSCSPPe8Lzbc6T9u2YYWXYsL17TAddyh7bKfpeqottMUNAToV0N4zUNMO5q1kRH7 -zYBXZU7fQcQZD6elbPnRAjCkJ3qM7lm2Fp66QuP3mcTaWmWFv5FLt1HjAoGANVaF -y9Aa6PZ2W3hraSnVaNnUhjziXujKDaAtUODgG+7N0ueWfCgE+PvhpxTid0mY0Cnr -Ej4gLL0w9/YwfXppKZPcoLX2hC36tKayDbBjHMlwsq9wxoueyRwkxWwo97RGzYZw -uLmy79ttonv6iM+yh14fQD/t7LGSb6+oG656pVECgYEA0oya1vG0WL3K8ip8io4c -ovB2K1Uf7EyFzxJHJt6QpmXlPDKkwc6JzpKGJdCi09Pz49U63HodxahtB831rbAY -EduOUQ5scTKf66qA9/kEyClnwl14ZCds7/mu9ioZ7D0VNmWPFsYHaGKAUxsq97nb -xw9Y4zAdgbDcl1bzN9XCDKs= ------END PRIVATE KEY----- diff --git a/src/http/async/tls.zig/testdata/tls12.zig b/src/http/async/tls.zig/testdata/tls12.zig deleted file mode 100644 index e5bd1c49..00000000 --- a/src/http/async/tls.zig/testdata/tls12.zig +++ /dev/null @@ -1,244 +0,0 @@ -/// Messages from The Illustrated TLS 1.2 Connection -/// https://tls12.xargs.org/ -const hexToBytes = @import("../testu.zig").hexToBytes; - -pub const client_hello = hexToBytes( - \\ 16 03 01 00 a5 01 00 00 a1 03 03 00 01 02 03 04 - \\ 05 06 07 08 09 0a 0b 0c 0d 0e 0f 10 11 12 13 14 - \\ 15 16 17 18 19 1a 1b 1c 1d 1e 1f 00 00 20 cc a8 - \\ cc a9 c0 2f c0 30 c0 2b c0 2c c0 13 c0 09 c0 14 - \\ c0 0a 00 9c 00 9d 00 2f 00 35 c0 12 00 0a 01 00 - \\ 00 58 00 00 00 18 00 16 00 00 13 65 78 61 6d 70 - \\ 6c 65 2e 75 6c 66 68 65 69 6d 2e 6e 65 74 00 05 - \\ 00 05 01 00 00 00 00 00 0a 00 0a 00 08 00 1d 00 - \\ 17 00 18 00 19 00 0b 00 02 01 00 00 0d 00 12 00 - \\ 10 04 01 04 03 05 01 05 03 06 01 06 03 02 01 02 - \\ 03 ff 01 00 01 00 00 12 00 00 -); -pub const server_hello = hexToBytes( - \\ 16 03 03 00 31 02 00 00 2d 03 03 70 71 72 73 74 - \\ 75 76 77 78 79 7a 7b 7c 7d 7e 7f 80 81 82 83 84 - \\ 85 86 87 88 89 8a 8b 8c 8d 8e 8f 00 c0 13 00 00 - \\ 05 ff 01 00 01 00 -); -pub const server_certificate = hexToBytes( - \\ 16 03 03 03 2f 0b 00 03 2b 00 03 28 00 03 25 30 - \\ 82 03 21 30 82 02 09 a0 03 02 01 02 02 08 15 5a - \\ 92 ad c2 04 8f 90 30 0d 06 09 2a 86 48 86 f7 0d - \\ 01 01 0b 05 00 30 22 31 0b 30 09 06 03 55 04 06 - \\ 13 02 55 53 31 13 30 11 06 03 55 04 0a 13 0a 45 - \\ 78 61 6d 70 6c 65 20 43 41 30 1e 17 0d 31 38 31 - \\ 30 30 35 30 31 33 38 31 37 5a 17 0d 31 39 31 30 - \\ 30 35 30 31 33 38 31 37 5a 30 2b 31 0b 30 09 06 - \\ 03 55 04 06 13 02 55 53 31 1c 30 1a 06 03 55 04 - \\ 03 13 13 65 78 61 6d 70 6c 65 2e 75 6c 66 68 65 - \\ 69 6d 2e 6e 65 74 30 82 01 22 30 0d 06 09 2a 86 - \\ 48 86 f7 0d 01 01 01 05 00 03 82 01 0f 00 30 82 - \\ 01 0a 02 82 01 01 00 c4 80 36 06 ba e7 47 6b 08 - \\ 94 04 ec a7 b6 91 04 3f f7 92 bc 19 ee fb 7d 74 - \\ d7 a8 0d 00 1e 7b 4b 3a 4a e6 0f e8 c0 71 fc 73 - \\ e7 02 4c 0d bc f4 bd d1 1d 39 6b ba 70 46 4a 13 - \\ e9 4a f8 3d f3 e1 09 59 54 7b c9 55 fb 41 2d a3 - \\ 76 52 11 e1 f3 dc 77 6c aa 53 37 6e ca 3a ec be - \\ c3 aa b7 3b 31 d5 6c b6 52 9c 80 98 bc c9 e0 28 - \\ 18 e2 0b f7 f8 a0 3a fd 17 04 50 9e ce 79 bd 9f - \\ 39 f1 ea 69 ec 47 97 2e 83 0f b5 ca 95 de 95 a1 - \\ e6 04 22 d5 ee be 52 79 54 a1 e7 bf 8a 86 f6 46 - \\ 6d 0d 9f 16 95 1a 4c f7 a0 46 92 59 5c 13 52 f2 - \\ 54 9e 5a fb 4e bf d7 7a 37 95 01 44 e4 c0 26 87 - \\ 4c 65 3e 40 7d 7d 23 07 44 01 f4 84 ff d0 8f 7a - \\ 1f a0 52 10 d1 f4 f0 d5 ce 79 70 29 32 e2 ca be - \\ 70 1f df ad 6b 4b b7 11 01 f4 4b ad 66 6a 11 13 - \\ 0f e2 ee 82 9e 4d 02 9d c9 1c dd 67 16 db b9 06 - \\ 18 86 ed c1 ba 94 21 02 03 01 00 01 a3 52 30 50 - \\ 30 0e 06 03 55 1d 0f 01 01 ff 04 04 03 02 05 a0 - \\ 30 1d 06 03 55 1d 25 04 16 30 14 06 08 2b 06 01 - \\ 05 05 07 03 02 06 08 2b 06 01 05 05 07 03 01 30 - \\ 1f 06 03 55 1d 23 04 18 30 16 80 14 89 4f de 5b - \\ cc 69 e2 52 cf 3e a3 00 df b1 97 b8 1d e1 c1 46 - \\ 30 0d 06 09 2a 86 48 86 f7 0d 01 01 0b 05 00 03 - \\ 82 01 01 00 59 16 45 a6 9a 2e 37 79 e4 f6 dd 27 - \\ 1a ba 1c 0b fd 6c d7 55 99 b5 e7 c3 6e 53 3e ff - \\ 36 59 08 43 24 c9 e7 a5 04 07 9d 39 e0 d4 29 87 - \\ ff e3 eb dd 09 c1 cf 1d 91 44 55 87 0b 57 1d d1 - \\ 9b df 1d 24 f8 bb 9a 11 fe 80 fd 59 2b a0 39 8c - \\ de 11 e2 65 1e 61 8c e5 98 fa 96 e5 37 2e ef 3d - \\ 24 8a fd e1 74 63 eb bf ab b8 e4 d1 ab 50 2a 54 - \\ ec 00 64 e9 2f 78 19 66 0d 3f 27 cf 20 9e 66 7f - \\ ce 5a e2 e4 ac 99 c7 c9 38 18 f8 b2 51 07 22 df - \\ ed 97 f3 2e 3e 93 49 d4 c6 6c 9e a6 39 6d 74 44 - \\ 62 a0 6b 42 c6 d5 ba 68 8e ac 3a 01 7b dd fc 8e - \\ 2c fc ad 27 cb 69 d3 cc dc a2 80 41 44 65 d3 ae - \\ 34 8c e0 f3 4a b2 fb 9c 61 83 71 31 2b 19 10 41 - \\ 64 1c 23 7f 11 a5 d6 5c 84 4f 04 04 84 99 38 71 - \\ 2b 95 9e d6 85 bc 5c 5d d6 45 ed 19 90 94 73 40 - \\ 29 26 dc b4 0e 34 69 a1 59 41 e8 e2 cc a8 4b b6 - \\ 08 46 36 a0 -); -pub const server_key_exchange = hexToBytes( - \\ 16 03 03 01 2c 0c 00 01 28 03 00 1d 20 9f d7 ad - \\ 6d cf f4 29 8d d3 f9 6d 5b 1b 2a f9 10 a0 53 5b - \\ 14 88 d7 f8 fa bb 34 9a 98 28 80 b6 15 04 01 01 - \\ 00 04 02 b6 61 f7 c1 91 ee 59 be 45 37 66 39 bd - \\ c3 d4 bb 81 e1 15 ca 73 c8 34 8b 52 5b 0d 23 38 - \\ aa 14 46 67 ed 94 31 02 14 12 cd 9b 84 4c ba 29 - \\ 93 4a aa cc e8 73 41 4e c1 1c b0 2e 27 2d 0a d8 - \\ 1f 76 7d 33 07 67 21 f1 3b f3 60 20 cf 0b 1f d0 - \\ ec b0 78 de 11 28 be ba 09 49 eb ec e1 a1 f9 6e - \\ 20 9d c3 6e 4f ff d3 6b 67 3a 7d dc 15 97 ad 44 - \\ 08 e4 85 c4 ad b2 c8 73 84 12 49 37 25 23 80 9e - \\ 43 12 d0 c7 b3 52 2e f9 83 ca c1 e0 39 35 ff 13 - \\ a8 e9 6b a6 81 a6 2e 40 d3 e7 0a 7f f3 58 66 d3 - \\ d9 99 3f 9e 26 a6 34 c8 1b 4e 71 38 0f cd d6 f4 - \\ e8 35 f7 5a 64 09 c7 dc 2c 07 41 0e 6f 87 85 8c - \\ 7b 94 c0 1c 2e 32 f2 91 76 9e ac ca 71 64 3b 8b - \\ 98 a9 63 df 0a 32 9b ea 4e d6 39 7e 8c d0 1a 11 - \\ 0a b3 61 ac 5b ad 1c cd 84 0a 6c 8a 6e aa 00 1a - \\ 9d 7d 87 dc 33 18 64 35 71 22 6c 4d d2 c2 ac 41 - \\ fb -); -pub const server_hello_done = hexToBytes("16 03 03 00 04 0e 00 00 00 "); -pub const server_change_cipher_spec = hexToBytes("14 03 03 00 01 01 "); - -pub const server_handshake_finished = hexToBytes( - \\ 16 03 03 00 40 51 52 53 54 55 56 57 58 59 5a 5b - \\ 5c 5d 5e 5f 60 18 e0 75 31 7b 10 03 15 f6 08 1f - \\ cb f3 13 78 1a ac 73 ef e1 9f e2 5b a1 af 59 c2 - \\ 0b e9 4f c0 1b da 2d 68 00 29 8b 73 a7 e8 49 d7 - \\ 4b d4 94 cf 7d -); -pub const client_key_exchange_for_transcript = hexToBytes( - \\ 16 03 03 00 25 10 00 00 21 20 35 80 72 d6 36 58 - \\ 80 d1 ae ea 32 9a df 91 21 38 38 51 ed 21 a2 8e - \\ 3b 75 e9 65 d0 d2 cd 16 62 54 -); - -pub const server_hello_responses = server_hello ++ server_certificate ++ server_key_exchange ++ server_hello_done; - -pub const server_responses = server_hello_responses ++ server_change_cipher_spec ++ server_handshake_finished; - -pub const server_handshake_finished_msgs = server_change_cipher_spec ++ server_handshake_finished; - -pub const master_secret = hexToBytes( - \\ 91 6a bf 9d a5 59 73 e1 36 14 ae 0a 3f 5d 3f 37 - \\ b0 23 ba 12 9a ee 02 cc 91 34 33 81 27 cd 70 49 - \\ 78 1c 8e 19 fc 1e b2 a7 38 7a c0 6a e2 37 34 4c -); - -pub const client_key_exchange = hexToBytes( - \\ 16 03 03 00 25 10 00 00 21 20 00 01 02 03 04 05 - \\ 06 07 08 09 0a 0b 0c 0d 0e 0f 10 11 12 13 14 15 - \\ 16 17 18 19 1a 1b 1c 1d 1e 1f -); -pub const client_change_cyper_spec = hexToBytes("14 03 03 00 01 01 "); - -pub const client_handshake_finished = hexToBytes( - \\ 16 03 03 00 40 20 21 22 23 24 25 26 27 28 29 2a - \\ 2b 2c 2d 2e 2f a9 ac f5 5a f3 7a 90 17 63 ff 91 - \\ 68 9a b7 ee a0 d4 0c 1c ca 62 44 ef f3 0b a3 6d - \\ d0 df 86 3f 7d e3 98 d3 1a cc 37 6a e6 7a 00 6d - \\ 8c 08 bc 8a 5a -); - -pub const handshake_messages = [_][]const u8{ - &client_hello, - &server_hello, - &server_certificate, - &server_key_exchange, - &server_hello_done, - &client_key_exchange_for_transcript, -}; - -pub const client_finished = hexToBytes("14 00 00 0c cf 91 96 26 f1 36 0c 53 6a aa d7 3a "); - -// with iv 40 " ++ 41 ... 4f -// client_sequence = 0 -pub const verify_data_encrypted_msg = hexToBytes( - \\ 16 03 03 00 40 40 41 42 43 44 45 46 47 48 49 4a - \\ 4b 4c 4d 4e 4f 22 7b c9 ba 81 ef 30 f2 a8 a7 8f - \\ f1 df 50 84 4d 58 04 b7 ee b2 e2 14 c3 2b 68 92 - \\ ac a3 db 7b 78 07 7f dd 90 06 7c 51 6b ac b3 ba - \\ 90 de df 72 0f -); - -// with iv 00 " ++ 01 ... 1f -// client_sequence = 1 -pub const encrypted_ping_msg = hexToBytes( - \\ 17 03 03 00 30 00 01 02 03 04 05 06 07 08 09 0a - \\ 0b 0c 0d 0e 0f 6c 42 1c 71 c4 2b 18 3b fa 06 19 - \\ 5d 13 3d 0a 09 d0 0f c7 cb 4e 0f 5d 1c da 59 d1 - \\ 47 ec 79 0c 99 -); - -pub const key_material = hexToBytes( - \\ 1b 7d 11 7c 7d 5f 69 0b c2 63 ca e8 ef 60 af 0f - \\ 18 78 ac c2 2a d8 bd d8 c6 01 a6 17 12 6f 63 54 - \\ 0e b2 09 06 f7 81 fa d2 f6 56 d0 37 b1 73 ef 3e - \\ 11 16 9f 27 23 1a 84 b6 75 2a 18 e7 a9 fc b7 cb - \\ cd d8 f9 8d d8 f7 69 eb a0 d2 55 0c 92 38 ee bf - \\ ef 5c 32 25 1a bb 67 d6 43 45 28 db 49 37 d5 40 - \\ d3 93 13 5e 06 a1 1b b8 0e 45 ea eb e3 2c ac 72 - \\ 75 74 38 fb b3 df 64 5c bd a4 06 7c df a0 f8 48 -); - -pub const server_pong = hexToBytes( - \\ 17 03 03 00 30 61 62 63 64 65 66 67 68 69 6a 6b - \\ 6c 6d 6e 6f 70 97 83 48 8a f5 fa 20 bf 7a 2e f6 - \\ 9d eb b5 34 db 9f b0 7a 8c 27 21 de e5 40 9f 77 - \\ af 0c 3d de 56 -); - -pub const client_random = hexToBytes( - \\ 00 01 02 03 04 05 06 07 08 09 0a 0b 0c 0d 0e 0f 10 11 12 13 14 15 16 17 18 19 1a 1b 1c 1d 1e 1f -); - -pub const server_random = hexToBytes( - \\ 70 71 72 73 74 75 76 77 78 79 7a 7b 7c 7d 7e 7f 80 81 82 83 84 85 86 87 88 89 8a 8b 8c 8d 8e 8f -); - -pub const client_secret = hexToBytes( - \\ 20 21 22 23 24 25 26 27 28 29 2a 2b 2c 2d 2e 2f 30 31 32 33 34 35 36 37 38 39 3a 3b 3c 3d 3e 3f -); - -pub const server_pub_key = hexToBytes( - \\ 9f d7 ad 6d cf f4 29 8d d3 f9 6d 5b 1b 2a f9 10 a0 53 5b 14 88 d7 f8 fa bb 34 9a 98 28 80 b6 15 -); - -pub const signature = hexToBytes( - \\ 04 02 b6 61 f7 c1 91 ee 59 be 45 37 66 39 bd c3 - \\ d4 bb 81 e1 15 ca 73 c8 34 8b 52 5b 0d 23 38 aa - \\ 14 46 67 ed 94 31 02 14 12 cd 9b 84 4c ba 29 93 - \\ 4a aa cc e8 73 41 4e c1 1c b0 2e 27 2d 0a d8 1f - \\ 76 7d 33 07 67 21 f1 3b f3 60 20 cf 0b 1f d0 ec - \\ b0 78 de 11 28 be ba 09 49 eb ec e1 a1 f9 6e 20 - \\ 9d c3 6e 4f ff d3 6b 67 3a 7d dc 15 97 ad 44 08 - \\ e4 85 c4 ad b2 c8 73 84 12 49 37 25 23 80 9e 43 - \\ 12 d0 c7 b3 52 2e f9 83 ca c1 e0 39 35 ff 13 a8 - \\ e9 6b a6 81 a6 2e 40 d3 e7 0a 7f f3 58 66 d3 d9 - \\ 99 3f 9e 26 a6 34 c8 1b 4e 71 38 0f cd d6 f4 e8 - \\ 35 f7 5a 64 09 c7 dc 2c 07 41 0e 6f 87 85 8c 7b - \\ 94 c0 1c 2e 32 f2 91 76 9e ac ca 71 64 3b 8b 98 - \\ a9 63 df 0a 32 9b ea 4e d6 39 7e 8c d0 1a 11 0a - \\ b3 61 ac 5b ad 1c cd 84 0a 6c 8a 6e aa 00 1a 9d - \\ 7d 87 dc 33 18 64 35 71 22 6c 4d d2 c2 ac 41 fb -); - -pub const cert_pub_key = hexToBytes( - \\ 30 82 01 0a 02 82 01 01 00 c4 80 36 06 ba e7 47 - \\ 6b 08 94 04 ec a7 b6 91 04 3f f7 92 bc 19 ee fb - \\ 7d 74 d7 a8 0d 00 1e 7b 4b 3a 4a e6 0f e8 c0 71 - \\ fc 73 e7 02 4c 0d bc f4 bd d1 1d 39 6b ba 70 46 - \\ 4a 13 e9 4a f8 3d f3 e1 09 59 54 7b c9 55 fb 41 - \\ 2d a3 76 52 11 e1 f3 dc 77 6c aa 53 37 6e ca 3a - \\ ec be c3 aa b7 3b 31 d5 6c b6 52 9c 80 98 bc c9 - \\ e0 28 18 e2 0b f7 f8 a0 3a fd 17 04 50 9e ce 79 - \\ bd 9f 39 f1 ea 69 ec 47 97 2e 83 0f b5 ca 95 de - \\ 95 a1 e6 04 22 d5 ee be 52 79 54 a1 e7 bf 8a 86 - \\ f6 46 6d 0d 9f 16 95 1a 4c f7 a0 46 92 59 5c 13 - \\ 52 f2 54 9e 5a fb 4e bf d7 7a 37 95 01 44 e4 c0 - \\ 26 87 4c 65 3e 40 7d 7d 23 07 44 01 f4 84 ff d0 - \\ 8f 7a 1f a0 52 10 d1 f4 f0 d5 ce 79 70 29 32 e2 - \\ ca be 70 1f df ad 6b 4b b7 11 01 f4 4b ad 66 6a - \\ 11 13 0f e2 ee 82 9e 4d 02 9d c9 1c dd 67 16 db - \\ b9 06 18 86 ed c1 ba 94 21 02 03 01 00 01 -); diff --git a/src/http/async/tls.zig/testdata/tls13.zig b/src/http/async/tls.zig/testdata/tls13.zig deleted file mode 100644 index f98f9ff5..00000000 --- a/src/http/async/tls.zig/testdata/tls13.zig +++ /dev/null @@ -1,64 +0,0 @@ -const hexToBytes = @import("../testu.zig").hexToBytes; - -pub const client_hello = - hexToBytes("16030100f8010000f40303000102030405060708090a0b0c0d0e0f101112131415161718191a1b1c1d1e1f20e0e1e2e3e4e5e6e7e8e9eaebecedeeeff0f1f2f3f4f5f6f7f8f9fafbfcfdfeff000813021303130100ff010000a30000001800160000136578616d706c652e756c666865696d2e6e6574000b000403000102000a00160014001d0017001e0019001801000101010201030104002300000016000000170000000d001e001c040305030603080708080809080a080b080408050806040105010601002b0003020304002d00020101003300260024001d0020358072d6365880d1aeea329adf9121383851ed21a28e3b75e965d0d2cd166254"); - -pub const server_hello = - hexToBytes("160303007a") ++ // record header - hexToBytes("020000760303") ++ // handshake header, server version - hexToBytes("707172737475767778797a7b7c7d7e7f808182838485868788898a8b8c8d8e8f") ++ // server_random - hexToBytes("20e0e1e2e3e4e5e6e7e8e9eaebecedeeeff0f1f2f3f4f5f6f7f8f9fafbfcfdfeff") ++ // session id - hexToBytes("130200") ++ // cipher suite, compression method - hexToBytes("002e002b00020304") ++ // extensions, supported version - hexToBytes("00330024001d00209fd7ad6dcff4298dd3f96d5b1b2af910a0535b1488d7f8fabb349a982880b615"); // extension key share - -pub const client_random = hexToBytes( - \\ 00 01 02 03 04 05 06 07 08 09 0a 0b 0c 0d 0e 0f 10 11 12 13 14 15 16 17 18 19 1a 1b 1c 1d 1e 1f -); - -pub const server_random = - hexToBytes("707172737475767778797a7b7c7d7e7f808182838485868788898a8b8c8d8e8f"); -pub const server_pub_key = - hexToBytes("9fd7ad6dcff4298dd3f96d5b1b2af910a0535b1488d7f8fabb349a982880b615"); -pub const client_private_key = - hexToBytes("202122232425262728292a2b2c2d2e2f303132333435363738393a3b3c3d3e3f"); -pub const client_public_key = - hexToBytes("358072d6365880d1aeea329adf9121383851ed21a28e3b75e965d0d2cd166254"); - -pub const shared_key = hexToBytes("df4a291baa1eb7cfa6934b29b474baad2697e29f1f920dcc77c8a0a088447624"); -pub const server_handshake_key = hexToBytes("9f13575ce3f8cfc1df64a77ceaffe89700b492ad31b4fab01c4792be1b266b7f"); -pub const server_handshake_iv = hexToBytes("9563bc8b590f671f488d2da3"); -pub const client_handshake_key = hexToBytes("1135b4826a9a70257e5a391ad93093dfd7c4214812f493b3e3daae1eb2b1ac69"); -pub const client_handshake_iv = hexToBytes("4256d2e0e88babdd05eb2f27"); - -pub const server_application_key = hexToBytes("01f78623f17e3edcc09e944027ba3218d57c8e0db93cd3ac419309274700ac27"); -pub const server_application_iv = hexToBytes("196a750b0c5049c0cc51a541"); -pub const client_application_key = hexToBytes("de2f4c7672723a692319873e5c227606691a32d1c59d8b9f51dbb9352e9ca9cc"); -pub const client_application_iv = hexToBytes("bb007956f474b25de902432f"); - -pub const server_encrypted_extensions_wrapped = - hexToBytes("17030300176be02f9da7c2dc9ddef56f2468b90adfa25101ab0344ae"); -pub const server_encrypted_extensions = - hexToBytes("080000020000"); - -pub const server_certificate_wrapped = - hexToBytes("1703030343baf00a9be50f3f2307e726edcbdacbe4b18616449d46c6207af6e9953ee5d2411ba65d31feaf4f78764f2d693987186cc01329c187a5e4608e8d27b318e98dd94769f7739ce6768392caca8dcc597d77ec0d1272233785f6e69d6f43effa8e7905edfdc4037eee5933e990a7972f206913a31e8d04931366d3d8bcd6a4a4d647dd4bd80b0ff863ce3554833d744cf0e0b9c07cae726dd23f9953df1f1ce3aceb3b7230871e92310cfb2b098486f43538f8e82d8404e5c6c25f66a62ebe3c5f26232640e20a769175ef83483cd81e6cb16e78dfad4c1b714b04b45f6ac8d1065ad18c13451c9055c47da300f93536ea56f531986d6492775393c4ccb095467092a0ec0b43ed7a0687cb470ce350917b0ac30c6e5c24725a78c45f9f5f29b6626867f6f79ce054273547b36df030bd24af10d632dba54fc4e890bd0586928c0206ca2e28e44e227a2d5063195935df38da8936092eef01e84cad2e49d62e470a6c7745f625ec39e4fc23329c79d1172876807c36d736ba42bb69b004ff55f93850dc33c1f98abb92858324c76ff1eb085db3c1fc50f74ec04442e622973ea70743418794c388140bb492d6294a0540e5a59cfae60ba0f14899fca71333315ea083a68e1d7c1e4cdc2f56bcd6119681a4adbc1bbf42afd806c3cbd42a076f545dee4e118d0b396754be2b042a685dd4727e89c0386a94d3cd6ecb9820e9d49afeed66c47e6fc243eabebbcb0b02453877f5ac5dbfbdf8db1052a3c994b224cd9aaaf56b026bb9efa2e01302b36401ab6494e7018d6e5b573bd38bcef023b1fc92946bbca0209ca5fa926b4970b1009103645cb1fcfe552311ff730558984370038fd2cce2a91fc74d6f3e3ea9f843eed356f6f82d35d03bc24b81b58ceb1a43ec9437e6f1e50eb6f555e321fd67c8332eb1b832aa8d795a27d479c6e27d5a61034683891903f66421d094e1b00a9a138d861e6f78a20ad3e1580054d2e305253c713a02fe1e28deee7336246f6ae34331806b46b47b833c39b9d31cd300c2a6ed831399776d07f570eaf0059a2c68a5f3ae16b617404af7b7231a4d942758fc020b3f23ee8c15e36044cfd67cd640993b16207597fbf385ea7a4d99e8d456ff83d41f7b8b4f069b028a2a63a919a70e3a10e3084158faa5bafa30186c6b2f238eb530c73e"); -pub const server_certificate = - hexToBytes("0b00032e0000032a0003253082032130820209a0030201020208155a92adc2048f90300d06092a864886f70d01010b05003022310b300906035504061302555331133011060355040a130a4578616d706c65204341301e170d3138313030353031333831375a170d3139313030353031333831375a302b310b3009060355040613025553311c301a060355040313136578616d706c652e756c666865696d2e6e657430820122300d06092a864886f70d01010105000382010f003082010a0282010100c4803606bae7476b089404eca7b691043ff792bc19eefb7d74d7a80d001e7b4b3a4ae60fe8c071fc73e7024c0dbcf4bdd11d396bba70464a13e94af83df3e10959547bc955fb412da3765211e1f3dc776caa53376eca3aecbec3aab73b31d56cb6529c8098bcc9e02818e20bf7f8a03afd1704509ece79bd9f39f1ea69ec47972e830fb5ca95de95a1e60422d5eebe527954a1e7bf8a86f6466d0d9f16951a4cf7a04692595c1352f2549e5afb4ebfd77a37950144e4c026874c653e407d7d23074401f484ffd08f7a1fa05210d1f4f0d5ce79702932e2cabe701fdfad6b4bb71101f44bad666a11130fe2ee829e4d029dc91cdd6716dbb9061886edc1ba94210203010001a3523050300e0603551d0f0101ff0404030205a0301d0603551d250416301406082b0601050507030206082b06010505070301301f0603551d23041830168014894fde5bcc69e252cf3ea300dfb197b81de1c146300d06092a864886f70d01010b05000382010100591645a69a2e3779e4f6dd271aba1c0bfd6cd75599b5e7c36e533eff3659084324c9e7a504079d39e0d42987ffe3ebdd09c1cf1d914455870b571dd19bdf1d24f8bb9a11fe80fd592ba0398cde11e2651e618ce598fa96e5372eef3d248afde17463ebbfabb8e4d1ab502a54ec0064e92f7819660d3f27cf209e667fce5ae2e4ac99c7c93818f8b2510722dfed97f32e3e9349d4c66c9ea6396d744462a06b42c6d5ba688eac3a017bddfc8e2cfcad27cb69d3ccdca280414465d3ae348ce0f34ab2fb9c618371312b191041641c237f11a5d65c844f0404849938712b959ed685bc5c5dd645ed19909473402926dcb40e3469a15941e8e2cca84bb6084636a00000"); - -pub const server_certificate_verify_wrapped = hexToBytes("170303011973719fce07ec2f6d3bba0292a0d40b2770c06a271799a53314f6f77fc95c5fe7b9a4329fd9548c670ebeea2f2d5c351dd9356ef2dcd52eb137bd3a676522f8cd0fb7560789ad7b0e3caba2e37e6b4199c6793b3346ed46cf740a9fa1fec414dc715c415c60e575703ce6a34b70b5191aa6a61a18faff216c687ad8d17e12a7e99915a611bfc1a2befc15e6e94d784642e682fd17382a348c301056b940c9847200408bec56c81ea3d7217ab8e85a88715395899c90587f72e8ddd74b26d8edc1c7c837d9f2ebbc260962219038b05654a63a0b12999b4a8306a3ddcc0e17c53ba8f9c80363f7841354d291b4ace0c0f330c0fcd5aa9deef969ae8ab2d98da88ebb6ea80a3a11f00ea296a3232367ff075e1c66dd9cbedc4713"); -pub const server_finished_wrapped = hexToBytes("17030300451061de27e51c2c9f342911806f282b710c10632ca5006755880dbf7006002d0e84fed9adf27a43b5192303e4df5c285d58e3c76224078440c0742374744aecf28cf3182fd0"); - -pub const handshake_hash = hexToBytes("fa6800169a6baac19159524fa7b9721b41be3c9db6f3f93fa5ff7e3db3ece204d2b456c51046e40ec5312c55a86126f5"); - -pub const client_finished_verify_data = hexToBytes("bff56a671b6c659d0a7c5dd18428f58bdd38b184a3ce342d9fde95cbd5056f7da7918ee320eab7a93abd8f1c02454d27"); - -pub const client_finished_wrapped = hexToBytes("17030300459ff9b063175177322a46dd9896f3c3bb820ab51743ebc25fdadd53454b73deb54cc7248d411a18bccf657a960824e9a19364837c350a69a88d4bf635c85eb874aebc9dfde8"); - -pub const client_ping_wrapped = hexToBytes("1703030015828139cb7b73aaabf5b82fbf9a2961bcde10038a32"); -pub const server_flight = - hexToBytes("140303000101") ++ - server_encrypted_extensions_wrapped ++ - server_certificate_wrapped ++ - server_certificate_verify_wrapped ++ - server_finished_wrapped; diff --git a/src/http/async/tls.zig/testu.zig b/src/http/async/tls.zig/testu.zig deleted file mode 100644 index 255fe6d3..00000000 --- a/src/http/async/tls.zig/testu.zig +++ /dev/null @@ -1,117 +0,0 @@ -const std = @import("std"); - -pub fn bufPrint(var_name: []const u8, buf: []const u8) void { - // std.debug.print("\nconst {s} = [_]u8{{\n", .{var_name}); - // for (buf, 1..) |b, i| { - // std.debug.print("0x{x:0>2}, ", .{b}); - // if (i % 16 == 0) - // std.debug.print("\n", .{}); - // } - // std.debug.print("}};\n", .{}); - - std.debug.print("const {s} = \"", .{var_name}); - const charset = "0123456789abcdef"; - for (buf) |b| { - const x = charset[b >> 4]; - const y = charset[b & 15]; - std.debug.print("{c}{c} ", .{ x, y }); - } - std.debug.print("\"\n", .{}); -} - -const random_instance = std.Random{ .ptr = undefined, .fillFn = randomFillFn }; -var random_seed: u8 = 0; - -pub fn randomFillFn(_: *anyopaque, buf: []u8) void { - for (buf) |*v| { - v.* = random_seed; - random_seed +%= 1; - } -} - -pub fn random(seed: u8) std.Random { - random_seed = seed; - return random_instance; -} - -// Fill buf with 0,1,..ff,0,... -pub fn fill(buf: []u8) void { - fillFrom(buf, 0); -} - -pub fn fillFrom(buf: []u8, start: u8) void { - var i: u8 = start; - for (buf) |*v| { - v.* = i; - i +%= 1; - } -} - -pub const Stream = struct { - output: std.io.FixedBufferStream([]u8) = undefined, - input: std.io.FixedBufferStream([]const u8) = undefined, - - pub fn init(input: []const u8, output: []u8) Stream { - return .{ - .input = std.io.fixedBufferStream(input), - .output = std.io.fixedBufferStream(output), - }; - } - - pub const ReadError = error{}; - pub const WriteError = error{NoSpaceLeft}; - - pub fn write(self: *Stream, buf: []const u8) !usize { - return try self.output.writer().write(buf); - } - - pub fn writeAll(self: *Stream, buffer: []const u8) !void { - var n: usize = 0; - while (n < buffer.len) { - n += try self.write(buffer[n..]); - } - } - - pub fn read(self: *Stream, buffer: []u8) !usize { - return self.input.read(buffer); - } -}; - -// Copied from: https://github.com/clickingbuttons/zig/blob/f1cea91624fd2deae28bfb2414a4fd9c7e246883/lib/std/crypto/rsa.zig#L791 -/// For readable copy/pasting from hex viewers. -pub fn hexToBytes(comptime hex: []const u8) [removeNonHex(hex).len / 2]u8 { - @setEvalBranchQuota(1000 * 100); - const hex2 = comptime removeNonHex(hex); - comptime var res: [hex2.len / 2]u8 = undefined; - _ = comptime std.fmt.hexToBytes(&res, hex2) catch unreachable; - return res; -} - -fn removeNonHex(comptime hex: []const u8) []const u8 { - @setEvalBranchQuota(1000 * 100); - var res: [hex.len]u8 = undefined; - var i: usize = 0; - for (hex) |c| { - if (std.ascii.isHex(c)) { - res[i] = c; - i += 1; - } - } - return res[0..i]; -} - -test hexToBytes { - const hex = - \\e3b0c442 98fc1c14 9afbf4c8 996fb924 - \\27ae41e4 649b934c a495991b 7852b855 - ; - try std.testing.expectEqual( - [_]u8{ - 0xe3, 0xb0, 0xc4, 0x42, 0x98, 0xfc, 0x1c, 0x14, - 0x9a, 0xfb, 0xf4, 0xc8, 0x99, 0x6f, 0xb9, 0x24, - 0x27, 0xae, 0x41, 0xe4, 0x64, 0x9b, 0x93, 0x4c, - 0xa4, 0x95, 0x99, 0x1b, 0x78, 0x52, 0xb8, 0x55, - }, - hexToBytes(hex), - ); -} diff --git a/src/http/async/tls.zig/transcript.zig b/src/http/async/tls.zig/transcript.zig deleted file mode 100644 index 59c94986..00000000 --- a/src/http/async/tls.zig/transcript.zig +++ /dev/null @@ -1,297 +0,0 @@ -const std = @import("std"); -const crypto = std.crypto; -const tls = crypto.tls; -const hkdfExpandLabel = tls.hkdfExpandLabel; - -const Sha256 = crypto.hash.sha2.Sha256; -const Sha384 = crypto.hash.sha2.Sha384; -const Sha512 = crypto.hash.sha2.Sha512; - -const HashTag = @import("cipher.zig").CipherSuite.HashTag; - -// Transcript holds hash of all handshake message. -// -// Until the server hello is parsed we don't know which hash (sha256, sha384, -// sha512) will be used so we update all of them. Handshake process will set -// `selected` field once cipher suite is known. Other function will use that -// selected hash. We continue to calculate all hashes because client certificate -// message could use different hash than the other part of the handshake. -// Handshake hash is dictated by the server selected cipher. Client certificate -// hash is dictated by the private key used. -// -// Most of the functions are inlined because they are returning pointers. -// -pub const Transcript = struct { - sha256: Type(.sha256) = .{ .hash = Sha256.init(.{}) }, - sha384: Type(.sha384) = .{ .hash = Sha384.init(.{}) }, - sha512: Type(.sha512) = .{ .hash = Sha512.init(.{}) }, - - tag: HashTag = .sha256, - - pub const max_mac_length = Type(.sha512).mac_length; - - // Transcript Type from hash tag - fn Type(h: HashTag) type { - return switch (h) { - .sha256 => TranscriptT(Sha256), - .sha384 => TranscriptT(Sha384), - .sha512 => TranscriptT(Sha512), - }; - } - - /// Set hash to use in all following function calls. - pub fn use(t: *Transcript, tag: HashTag) void { - t.tag = tag; - } - - pub fn update(t: *Transcript, buf: []const u8) void { - t.sha256.hash.update(buf); - t.sha384.hash.update(buf); - t.sha512.hash.update(buf); - } - - // tls 1.2 handshake specific - - pub inline fn masterSecret( - t: *Transcript, - pre_master_secret: []const u8, - client_random: [32]u8, - server_random: [32]u8, - ) []const u8 { - return switch (t.tag) { - inline else => |h| &@field(t, @tagName(h)).masterSecret( - pre_master_secret, - client_random, - server_random, - ), - }; - } - - pub inline fn keyMaterial( - t: *Transcript, - master_secret: []const u8, - client_random: [32]u8, - server_random: [32]u8, - ) []const u8 { - return switch (t.tag) { - inline else => |h| &@field(t, @tagName(h)).keyExpansion( - master_secret, - client_random, - server_random, - ), - }; - } - - pub fn clientFinishedTls12(t: *Transcript, master_secret: []const u8) [12]u8 { - return switch (t.tag) { - inline else => |h| @field(t, @tagName(h)).clientFinishedTls12(master_secret), - }; - } - - pub fn serverFinishedTls12(t: *Transcript, master_secret: []const u8) [12]u8 { - return switch (t.tag) { - inline else => |h| @field(t, @tagName(h)).serverFinishedTls12(master_secret), - }; - } - - // tls 1.3 handshake specific - - pub inline fn serverCertificateVerify(t: *Transcript) []const u8 { - return switch (t.tag) { - inline else => |h| &@field(t, @tagName(h)).serverCertificateVerify(), - }; - } - - pub inline fn clientCertificateVerify(t: *Transcript) []const u8 { - return switch (t.tag) { - inline else => |h| &@field(t, @tagName(h)).clientCertificateVerify(), - }; - } - - pub fn serverFinishedTls13(t: *Transcript, buf: []u8) []const u8 { - return switch (t.tag) { - inline else => |h| @field(t, @tagName(h)).serverFinishedTls13(buf), - }; - } - - pub fn clientFinishedTls13(t: *Transcript, buf: []u8) []const u8 { - return switch (t.tag) { - inline else => |h| @field(t, @tagName(h)).clientFinishedTls13(buf), - }; - } - - pub const Secret = struct { - client: []const u8, - server: []const u8, - }; - - pub inline fn handshakeSecret(t: *Transcript, shared_key: []const u8) Secret { - return switch (t.tag) { - inline else => |h| @field(t, @tagName(h)).handshakeSecret(shared_key), - }; - } - - pub inline fn applicationSecret(t: *Transcript) Secret { - return switch (t.tag) { - inline else => |h| @field(t, @tagName(h)).applicationSecret(), - }; - } - - // other - - pub fn Hkdf(h: HashTag) type { - return Type(h).Hkdf; - } - - /// Copy of the current hash value - pub inline fn hash(t: *Transcript, comptime Hash: type) Hash { - return switch (Hash) { - Sha256 => t.sha256.hash, - Sha384 => t.sha384.hash, - Sha512 => t.sha512.hash, - else => @compileError("unimplemented"), - }; - } -}; - -fn TranscriptT(comptime Hash: type) type { - return struct { - const Hmac = crypto.auth.hmac.Hmac(Hash); - const Hkdf = crypto.kdf.hkdf.Hkdf(Hmac); - const mac_length = Hmac.mac_length; - - hash: Hash, - handshake_secret: [Hmac.mac_length]u8 = undefined, - server_finished_key: [Hmac.key_length]u8 = undefined, - client_finished_key: [Hmac.key_length]u8 = undefined, - - const Self = @This(); - - fn init(transcript: Hash) Self { - return .{ .transcript = transcript }; - } - - fn serverCertificateVerify(c: *Self) [64 + 34 + Hash.digest_length]u8 { - return ([1]u8{0x20} ** 64) ++ - "TLS 1.3, server CertificateVerify\x00".* ++ - c.hash.peek(); - } - - // ref: https://www.rfc-editor.org/rfc/rfc8446#section-4.4.3 - fn clientCertificateVerify(c: *Self) [64 + 34 + Hash.digest_length]u8 { - return ([1]u8{0x20} ** 64) ++ - "TLS 1.3, client CertificateVerify\x00".* ++ - c.hash.peek(); - } - - fn masterSecret( - _: *Self, - pre_master_secret: []const u8, - client_random: [32]u8, - server_random: [32]u8, - ) [mac_length * 2]u8 { - const seed = "master secret" ++ client_random ++ server_random; - - var a1: [mac_length]u8 = undefined; - var a2: [mac_length]u8 = undefined; - Hmac.create(&a1, seed, pre_master_secret); - Hmac.create(&a2, &a1, pre_master_secret); - - var p1: [mac_length]u8 = undefined; - var p2: [mac_length]u8 = undefined; - Hmac.create(&p1, a1 ++ seed, pre_master_secret); - Hmac.create(&p2, a2 ++ seed, pre_master_secret); - - return p1 ++ p2; - } - - fn keyExpansion( - _: *Self, - master_secret: []const u8, - client_random: [32]u8, - server_random: [32]u8, - ) [mac_length * 4]u8 { - const seed = "key expansion" ++ server_random ++ client_random; - - const a0 = seed; - var a1: [mac_length]u8 = undefined; - var a2: [mac_length]u8 = undefined; - var a3: [mac_length]u8 = undefined; - var a4: [mac_length]u8 = undefined; - Hmac.create(&a1, a0, master_secret); - Hmac.create(&a2, &a1, master_secret); - Hmac.create(&a3, &a2, master_secret); - Hmac.create(&a4, &a3, master_secret); - - var key_material: [mac_length * 4]u8 = undefined; - Hmac.create(key_material[0..mac_length], a1 ++ seed, master_secret); - Hmac.create(key_material[mac_length .. mac_length * 2], a2 ++ seed, master_secret); - Hmac.create(key_material[mac_length * 2 .. mac_length * 3], a3 ++ seed, master_secret); - Hmac.create(key_material[mac_length * 3 ..], a4 ++ seed, master_secret); - return key_material; - } - - fn clientFinishedTls12(self: *Self, master_secret: []const u8) [12]u8 { - const seed = "client finished" ++ self.hash.peek(); - var a1: [mac_length]u8 = undefined; - var p1: [mac_length]u8 = undefined; - Hmac.create(&a1, seed, master_secret); - Hmac.create(&p1, a1 ++ seed, master_secret); - return p1[0..12].*; - } - - fn serverFinishedTls12(self: *Self, master_secret: []const u8) [12]u8 { - const seed = "server finished" ++ self.hash.peek(); - var a1: [mac_length]u8 = undefined; - var p1: [mac_length]u8 = undefined; - Hmac.create(&a1, seed, master_secret); - Hmac.create(&p1, a1 ++ seed, master_secret); - return p1[0..12].*; - } - - // tls 1.3 - - inline fn handshakeSecret(self: *Self, shared_key: []const u8) Transcript.Secret { - const hello_hash = self.hash.peek(); - - const zeroes = [1]u8{0} ** Hash.digest_length; - const early_secret = Hkdf.extract(&[1]u8{0}, &zeroes); - const empty_hash = tls.emptyHash(Hash); - const hs_derived_secret = hkdfExpandLabel(Hkdf, early_secret, "derived", &empty_hash, Hash.digest_length); - - self.handshake_secret = Hkdf.extract(&hs_derived_secret, shared_key); - const client_secret = hkdfExpandLabel(Hkdf, self.handshake_secret, "c hs traffic", &hello_hash, Hash.digest_length); - const server_secret = hkdfExpandLabel(Hkdf, self.handshake_secret, "s hs traffic", &hello_hash, Hash.digest_length); - - self.server_finished_key = hkdfExpandLabel(Hkdf, server_secret, "finished", "", Hmac.key_length); - self.client_finished_key = hkdfExpandLabel(Hkdf, client_secret, "finished", "", Hmac.key_length); - - return .{ .client = &client_secret, .server = &server_secret }; - } - - inline fn applicationSecret(self: *Self) Transcript.Secret { - const handshake_hash = self.hash.peek(); - - const empty_hash = tls.emptyHash(Hash); - const zeroes = [1]u8{0} ** Hash.digest_length; - const ap_derived_secret = hkdfExpandLabel(Hkdf, self.handshake_secret, "derived", &empty_hash, Hash.digest_length); - const master_secret = Hkdf.extract(&ap_derived_secret, &zeroes); - - const client_secret = hkdfExpandLabel(Hkdf, master_secret, "c ap traffic", &handshake_hash, Hash.digest_length); - const server_secret = hkdfExpandLabel(Hkdf, master_secret, "s ap traffic", &handshake_hash, Hash.digest_length); - - return .{ .client = &client_secret, .server = &server_secret }; - } - - fn serverFinishedTls13(self: *Self, buf: []u8) []const u8 { - Hmac.create(buf[0..mac_length], &self.hash.peek(), &self.server_finished_key); - return buf[0..mac_length]; - } - - // client finished message with header - fn clientFinishedTls13(self: *Self, buf: []u8) []const u8 { - Hmac.create(buf[0..mac_length], &self.hash.peek(), &self.client_finished_key); - return buf[0..mac_length]; - } - }; -} diff --git a/src/main.zig b/src/main.zig index d7c3143d..b0e5e4f4 100644 --- a/src/main.zig +++ b/src/main.zig @@ -30,6 +30,7 @@ const apiweb = @import("apiweb.zig"); pub const Types = jsruntime.reflect(apiweb.Interfaces); pub const UserContext = apiweb.UserContext; +pub const IO = @import("asyncio").Wrapper(jsruntime.Loop); // Default options const Host = "127.0.0.1"; diff --git a/src/main_shell.zig b/src/main_shell.zig index ac803ae5..eb88ab50 100644 --- a/src/main_shell.zig +++ b/src/main_shell.zig @@ -24,12 +24,13 @@ const parser = @import("netsurf"); const apiweb = @import("apiweb.zig"); const Window = @import("html/window.zig").Window; const storage = @import("storage/storage.zig"); +const Client = @import("asyncio").Client; const html_test = @import("html_test.zig").html; pub const Types = jsruntime.reflect(apiweb.Interfaces); pub const UserContext = apiweb.UserContext; -const Client = @import("http/async/main.zig").Client; +pub const IO = @import("asyncio").Wrapper(jsruntime.Loop); var doc: *parser.DocumentHTML = undefined; diff --git a/src/main_wpt.zig b/src/main_wpt.zig index 49b7ba23..7cf2f077 100644 --- a/src/main_wpt.zig +++ b/src/main_wpt.zig @@ -50,6 +50,7 @@ const Out = enum { pub const Types = jsruntime.reflect(apiweb.Interfaces); pub const GlobalType = apiweb.GlobalType; pub const UserContext = apiweb.UserContext; +pub const IO = @import("asyncio").Wrapper(jsruntime.Loop); // TODO For now the WPT tests run is specific to WPT. // It manually load js framwork libs, and run the first script w/ js content in diff --git a/src/run_tests.zig b/src/run_tests.zig index 8e285840..a9073d4e 100644 --- a/src/run_tests.zig +++ b/src/run_tests.zig @@ -30,7 +30,7 @@ const xhr = @import("xhr/xhr.zig"); const storage = @import("storage/storage.zig"); const url = @import("url/url.zig"); const urlquery = @import("url/query.zig"); -const Client = @import("http/async/main.zig").Client; +const Client = @import("asyncio").Client; const documentTestExecFn = @import("dom/document.zig").testExecFn; const HTMLDocumentTestExecFn = @import("html/document.zig").testExecFn; @@ -59,6 +59,7 @@ const MutationObserverTestExecFn = @import("dom/mutation_observer.zig").testExec pub const Types = jsruntime.reflect(apiweb.Interfaces); pub const UserContext = @import("user_context.zig").UserContext; +pub const IO = @import("asyncio").Wrapper(jsruntime.Loop); var doc: *parser.DocumentHTML = undefined; @@ -298,9 +299,6 @@ test { const msgTest = @import("msg.zig"); std.testing.refAllDecls(msgTest); - std.testing.refAllDecls(@import("http/async/std/http.zig")); - std.testing.refAllDecls(@import("http/async/stack.zig")); - const dumpTest = @import("browser/dump.zig"); std.testing.refAllDecls(dumpTest); diff --git a/src/test_runner.zig b/src/test_runner.zig index 8b138d0b..8358b66c 100644 --- a/src/test_runner.zig +++ b/src/test_runner.zig @@ -22,6 +22,7 @@ const tests = @import("run_tests.zig"); pub const Types = tests.Types; pub const UserContext = tests.UserContext; +pub const IO = tests.IO; pub fn main() !void { try tests.main(); diff --git a/src/user_context.zig b/src/user_context.zig index 644893c8..3bed0108 100644 --- a/src/user_context.zig +++ b/src/user_context.zig @@ -1,6 +1,6 @@ const std = @import("std"); const parser = @import("netsurf"); -const Client = @import("http/async/main.zig").Client; +const Client = @import("asyncio").Client; pub const UserContext = struct { document: *parser.DocumentHTML, diff --git a/src/wpt/run.zig b/src/wpt/run.zig index ec1d3397..a44b12d7 100644 --- a/src/wpt/run.zig +++ b/src/wpt/run.zig @@ -28,10 +28,10 @@ const Loop = jsruntime.Loop; const Env = jsruntime.Env; const Window = @import("../html/window.zig").Window; const storage = @import("../storage/storage.zig"); +const Client = @import("asyncio").Client; const Types = @import("../main_wpt.zig").Types; const UserContext = @import("../main_wpt.zig").UserContext; -const Client = @import("../http/async/main.zig").Client; // runWPT parses the given HTML file, starts a js env and run the first script // tags containing javascript sources. diff --git a/src/xhr/xhr.zig b/src/xhr/xhr.zig index 660e449a..ab936b43 100644 --- a/src/xhr/xhr.zig +++ b/src/xhr/xhr.zig @@ -32,7 +32,7 @@ const XMLHttpRequestEventTarget = @import("event_target.zig").XMLHttpRequestEven const Mime = @import("../browser/mime.zig"); const Loop = jsruntime.Loop; -const Client = @import("../http/async/main.zig").Client; +const Client = @import("asyncio").Client; const parser = @import("netsurf"); @@ -97,7 +97,7 @@ pub const XMLHttpRequest = struct { proto: XMLHttpRequestEventTarget = XMLHttpRequestEventTarget{}, alloc: std.mem.Allocator, cli: *Client, - loop: Client.Loop, + io: Client.IO, priv_state: PrivState = .new, req: ?Client.Request = null, @@ -294,7 +294,7 @@ pub const XMLHttpRequest = struct { .alloc = alloc, .headers = Headers.init(alloc), .response_headers = Headers.init(alloc), - .loop = Client.Loop.init(loop), + .io = Client.IO.init(loop), .method = undefined, .url = null, .uri = undefined, @@ -513,7 +513,7 @@ pub const XMLHttpRequest = struct { self.req = null; } - self.ctx = try Client.Ctx.init(&self.loop, &self.req.?); + self.ctx = try Client.Ctx.init(&self.io, &self.req.?); errdefer { self.ctx.?.deinit(); self.ctx = null; diff --git a/vendor/zig-async-io b/vendor/zig-async-io new file mode 160000 index 00000000..d996742c --- /dev/null +++ b/vendor/zig-async-io @@ -0,0 +1 @@ +Subproject commit d996742c00f518be4f088af69d81912b8df94d58 diff --git a/vendor/zig-js-runtime b/vendor/zig-js-runtime index f434b3cf..d0e006be 160000 --- a/vendor/zig-js-runtime +++ b/vendor/zig-js-runtime @@ -1 +1 @@ -Subproject commit f434b3cfa1938277a6cd2e225974bb8d33d578c2 +Subproject commit d0e006becd9ddac8a4e0ac9890c7b4087e237bd7