From fadf3f609a1ed5d221f696895e347f6c998179d4 Mon Sep 17 00:00:00 2001 From: Pierre Tachoire Date: Thu, 14 Nov 2024 16:07:27 +0100 Subject: [PATCH 01/11] http: add full async client --- src/http/async/io.zig | 148 + src/http/async/loop.zig | 75 + src/http/async/main.zig | 4 + src/http/async/stack.zig | 95 + src/http/async/std/http.zig | 318 ++ src/http/async/std/http/Client.zig | 2545 +++++++++++++++++ src/http/async/std/http/Server.zig | 1148 ++++++++ src/http/async/std/http/protocol.zig | 447 +++ src/http/async/std/net.zig | 2050 +++++++++++++ src/http/async/std/net/test.zig | 335 +++ src/http/async/tls.zig/PrivateKey.zig | 260 ++ src/http/async/tls.zig/cbc/main.zig | 148 + src/http/async/tls.zig/cipher.zig | 1004 +++++++ src/http/async/tls.zig/connection.zig | 665 +++++ src/http/async/tls.zig/handshake_client.zig | 955 +++++++ src/http/async/tls.zig/handshake_common.zig | 448 +++ src/http/async/tls.zig/handshake_server.zig | 520 ++++ src/http/async/tls.zig/key_log.zig | 60 + src/http/async/tls.zig/main.zig | 51 + src/http/async/tls.zig/protocol.zig | 302 ++ src/http/async/tls.zig/record.zig | 405 +++ src/http/async/tls.zig/rsa/der.zig | 467 +++ src/http/async/tls.zig/rsa/oid.zig | 132 + src/http/async/tls.zig/rsa/rsa.zig | 880 ++++++ .../async/tls.zig/rsa/testdata/id_rsa.der | Bin 0 -> 1191 bytes .../testdata/ec_prime256v1_private_key.pem | 5 + .../async/tls.zig/testdata/ec_private_key.pem | 6 + .../testdata/ec_secp384r1_private_key.pem | 6 + .../testdata/ec_secp521r1_private_key.pem | 7 + .../tls.zig/testdata/google.com/client_random | 1 + .../tls.zig/testdata/google.com/server_hello | Bin 0 -> 7158 bytes .../tls.zig/testdata/rsa_private_key.pem | 28 + src/http/async/tls.zig/testdata/tls12.zig | 244 ++ src/http/async/tls.zig/testdata/tls13.zig | 64 + src/http/async/tls.zig/testu.zig | 117 + src/http/async/tls.zig/transcript.zig | 297 ++ 36 files changed, 14237 insertions(+) create mode 100644 src/http/async/io.zig create mode 100644 src/http/async/loop.zig create mode 100644 src/http/async/main.zig create mode 100644 src/http/async/stack.zig create mode 100644 src/http/async/std/http.zig create mode 100644 src/http/async/std/http/Client.zig create mode 100644 src/http/async/std/http/Server.zig create mode 100644 src/http/async/std/http/protocol.zig create mode 100644 src/http/async/std/net.zig create mode 100644 src/http/async/std/net/test.zig create mode 100644 src/http/async/tls.zig/PrivateKey.zig create mode 100644 src/http/async/tls.zig/cbc/main.zig create mode 100644 src/http/async/tls.zig/cipher.zig create mode 100644 src/http/async/tls.zig/connection.zig create mode 100644 src/http/async/tls.zig/handshake_client.zig create mode 100644 src/http/async/tls.zig/handshake_common.zig create mode 100644 src/http/async/tls.zig/handshake_server.zig create mode 100644 src/http/async/tls.zig/key_log.zig create mode 100644 src/http/async/tls.zig/main.zig create mode 100644 src/http/async/tls.zig/protocol.zig create mode 100644 src/http/async/tls.zig/record.zig create mode 100644 src/http/async/tls.zig/rsa/der.zig create mode 100644 src/http/async/tls.zig/rsa/oid.zig create mode 100644 src/http/async/tls.zig/rsa/rsa.zig create mode 100644 src/http/async/tls.zig/rsa/testdata/id_rsa.der create mode 100644 src/http/async/tls.zig/testdata/ec_prime256v1_private_key.pem create mode 100644 src/http/async/tls.zig/testdata/ec_private_key.pem create mode 100644 src/http/async/tls.zig/testdata/ec_secp384r1_private_key.pem create mode 100644 src/http/async/tls.zig/testdata/ec_secp521r1_private_key.pem create mode 100644 src/http/async/tls.zig/testdata/google.com/client_random create mode 100644 src/http/async/tls.zig/testdata/google.com/server_hello create mode 100644 src/http/async/tls.zig/testdata/rsa_private_key.pem create mode 100644 src/http/async/tls.zig/testdata/tls12.zig create mode 100644 src/http/async/tls.zig/testdata/tls13.zig create mode 100644 src/http/async/tls.zig/testu.zig create mode 100644 src/http/async/tls.zig/transcript.zig diff --git a/src/http/async/io.zig b/src/http/async/io.zig new file mode 100644 index 00000000..a416c5fc --- /dev/null +++ b/src/http/async/io.zig @@ -0,0 +1,148 @@ +const std = @import("std"); + +pub const IO = @import("jsruntime").IO; + +pub const Blocking = struct { + pub fn connect( + _: *Blocking, + comptime CtxT: type, + ctx: *CtxT, + comptime cbk: fn (ctx: *CtxT, res: anyerror!void) anyerror!void, + socket: std.posix.socket_t, + address: std.net.Address, + ) void { + std.posix.connect(socket, &address.any, address.getOsSockLen()) catch |err| { + std.posix.close(socket); + cbk(ctx, err) catch |e| { + ctx.setErr(e); + }; + }; + cbk(ctx, {}) catch |e| ctx.setErr(e); + } + + pub fn send( + _: *Blocking, + comptime CtxT: type, + ctx: *CtxT, + comptime cbk: fn (ctx: *CtxT, res: anyerror!void) anyerror!void, + socket: std.posix.socket_t, + buf: []const u8, + ) void { + const len = std.posix.write(socket, buf) catch |err| { + cbk(ctx, err) catch |e| { + return ctx.setErr(e); + }; + return ctx.setErr(err); + }; + ctx.setLen(len); + cbk(ctx, {}) catch |e| ctx.setErr(e); + } + + pub fn recv( + _: *Blocking, + comptime CtxT: type, + ctx: *CtxT, + comptime cbk: fn (ctx: *CtxT, res: anyerror!void) anyerror!void, + socket: std.posix.socket_t, + buf: []u8, + ) void { + const len = std.posix.read(socket, buf) catch |err| { + cbk(ctx, err) catch |e| { + return ctx.setErr(e); + }; + return ctx.setErr(err); + }; + ctx.setLen(len); + cbk(ctx, {}) catch |e| ctx.setErr(e); + } +}; + +pub fn SingleThreaded(comptime CtxT: type) type { + return struct { + io: *IO, + completion: IO.Completion, + ctx: *CtxT, + cbk: CbkT, + + count: u32 = 0, + + const CbkT = *const fn (ctx: *CtxT, res: anyerror!void) anyerror!void; + + const Self = @This(); + + pub fn init(io: *IO) Self { + return .{ + .io = io, + .completion = undefined, + .ctx = undefined, + .cbk = undefined, + }; + } + + pub fn connect( + self: *Self, + comptime _: type, + ctx: *CtxT, + comptime cbk: CbkT, + socket: std.posix.socket_t, + address: std.net.Address, + ) void { + self.ctx = ctx; + self.cbk = cbk; + self.count += 1; + self.io.connect(*Self, self, Self.connectCbk, &self.completion, socket, address); + } + + fn connectCbk(self: *Self, _: *IO.Completion, result: IO.ConnectError!void) void { + defer self.count -= 1; + _ = result catch |e| return self.ctx.setErr(e); + self.cbk(self.ctx, {}) catch |e| self.ctx.setErr(e); + } + + pub fn send( + self: *Self, + comptime _: type, + ctx: *CtxT, + comptime cbk: CbkT, + socket: std.posix.socket_t, + buf: []const u8, + ) void { + self.ctx = ctx; + self.cbk = cbk; + self.count += 1; + self.io.send(*Self, self, Self.sendCbk, &self.completion, socket, buf); + } + + fn sendCbk(self: *Self, _: *IO.Completion, result: IO.SendError!usize) void { + defer self.count -= 1; + const ln = result catch |e| return self.ctx.setErr(e); + self.ctx.setLen(ln); + self.cbk(self.ctx, {}) catch |e| self.ctx.setErr(e); + } + + pub fn recv( + self: *Self, + comptime _: type, + ctx: *CtxT, + comptime cbk: CbkT, + socket: std.posix.socket_t, + buf: []u8, + ) void { + self.ctx = ctx; + self.cbk = cbk; + self.count += 1; + self.io.recv(*Self, self, Self.receiveCbk, &self.completion, socket, buf); + } + + fn receiveCbk(self: *Self, _: *IO.Completion, result: IO.RecvError!usize) void { + defer self.count -= 1; + const ln = result catch |e| return self.ctx.setErr(e); + self.ctx.setLen(ln); + self.cbk(self.ctx, {}) catch |e| self.ctx.setErr(e); + } + + pub fn isDone(self: *Self) bool { + return self.count == 0; + } + }; +} diff --git a/src/http/async/loop.zig b/src/http/async/loop.zig new file mode 100644 index 00000000..1b18d0f7 --- /dev/null +++ b/src/http/async/loop.zig @@ -0,0 +1,75 @@ +const std = @import("std"); +const Client = @import("std/http/Client.zig"); + +const Stack = @import("stack.zig"); + +const Res = fn (ctx: *Ctx, res: ?anyerror) anyerror!void; + +pub const Blocking = struct { + pub fn connect( + _: *Blocking, + comptime ctxT: type, + ctx: *ctxT, + comptime cbk: Res, + socket: std.os.socket_t, + address: std.net.Address, + ) void { + std.os.connect(socket, &address.any, address.getOsSockLen()) catch |err| { + std.os.closeSocket(socket); + _ = cbk(ctx, err); + return; + }; + ctx.socket = socket; + _ = cbk(ctx, null); + } +}; + +const CtxStack = Stack(Res); + +pub const Ctx = struct { + alloc: std.mem.Allocator, + stack: ?*CtxStack = null, + + // TCP ctx + client: *Client = undefined, + addr_current: usize = undefined, + list: *std.net.AddressList = undefined, + socket: std.os.socket_t = undefined, + Stream: std.net.Stream = undefined, + host: []const u8 = undefined, + port: u16 = undefined, + protocol: Client.Connection.Protocol = undefined, + conn: *Client.Connection = undefined, + uri: std.Uri = undefined, + headers: std.http.Headers = undefined, + method: std.http.Method = undefined, + options: Client.RequestOptions = undefined, + request: Client.Request = undefined, + + err: ?anyerror, + + pub fn init(alloc: std.mem.Allocator) Ctx { + return .{ .alloc = alloc }; + } + + pub fn push(self: *Ctx, function: CtxStack.Fn) !void { + if (self.stack) |stack| { + return try stack.push(self.alloc, function); + } + self.stack = try CtxStack.init(self.alloc, function); + } + + pub fn next(self: *Ctx, err: ?anyerror) !void { + if (self.stack) |stack| { + const last = stack.next == null; + const function = stack.pop(self.alloc, stack); + const res = @call(.auto, function, .{ self, err }); + if (last) { + self.stack = null; + self.alloc.destroy(stack); + } + return res; + } + self.err = err; + } +}; diff --git a/src/http/async/main.zig b/src/http/async/main.zig new file mode 100644 index 00000000..ea756e8b --- /dev/null +++ b/src/http/async/main.zig @@ -0,0 +1,4 @@ +const std = @import("std"); + +const stack = @import("stack.zig"); +pub const Client = @import("std/http/Client.zig"); diff --git a/src/http/async/stack.zig b/src/http/async/stack.zig new file mode 100644 index 00000000..d19a0c8f --- /dev/null +++ b/src/http/async/stack.zig @@ -0,0 +1,95 @@ +const std = @import("std"); + +pub fn Stack(comptime T: type) type { + return struct { + const Self = @This(); + pub const Fn = *const T; + + next: ?*Self = null, + func: Fn, + + pub fn init(alloc: std.mem.Allocator, comptime func: Fn) !*Self { + const next = try alloc.create(Self); + next.* = .{ .func = func }; + return next; + } + + pub fn push(self: *Self, alloc: std.mem.Allocator, comptime func: Fn) !void { + if (self.next) |next| { + return next.push(alloc, func); + } + self.next = try Self.init(alloc, func); + } + + pub fn pop(self: *Self, alloc: std.mem.Allocator, prev: ?*Self) Fn { + if (self.next) |next| { + return next.pop(alloc, self); + } + defer { + if (prev) |p| { + self.deinit(alloc, p); + } + } + return self.func; + } + + pub fn deinit(self: *Self, alloc: std.mem.Allocator, prev: ?*Self) void { + if (self.next) |next| { + // recursivly deinit + next.deinit(alloc, self); + } + if (prev) |p| { + p.next = null; + } + alloc.destroy(self); + } + }; +} + +fn first() u8 { + return 1; +} + +fn second() u8 { + return 2; +} + +test "stack" { + const alloc = std.testing.allocator; + const TestStack = Stack(fn () u8); + + var stack = TestStack{ .func = first }; + try stack.push(alloc, second); + + const a = stack.pop(alloc, null); + try std.testing.expect(a() == 2); + + const b = stack.pop(alloc, null); + try std.testing.expect(b() == 1); +} + +fn first_op(arg: ?*anyopaque) u8 { + const val = @as(*u8, @ptrCast(arg)); + return val.* + @as(u8, 1); +} + +fn second_op(arg: ?*anyopaque) u8 { + const val = @as(*u8, @ptrCast(arg)); + return val.* + @as(u8, 2); +} + +test "opaque stack" { + const alloc = std.testing.allocator; + const TestStack = Stack(fn (?*anyopaque) u8); + + var stack = TestStack{ .func = first_op }; + try stack.push(alloc, second_op); + + const a = stack.pop(alloc, null); + var x: u8 = 5; + try std.testing.expect(a(@as(*anyopaque, @ptrCast(&x))) == 2 + x); + + const b = stack.pop(alloc, null); + var y: u8 = 3; + try std.testing.expect(b(@as(*anyopaque, @ptrCast(&y))) == 1 + y); +} diff --git a/src/http/async/std/http.zig b/src/http/async/std/http.zig new file mode 100644 index 00000000..f027d440 --- /dev/null +++ b/src/http/async/std/http.zig @@ -0,0 +1,318 @@ +pub const Client = @import("http/Client.zig"); +pub const Server = @import("http/Server.zig"); +pub const protocol = @import("http/protocol.zig"); +pub const HeadParser = std.http.HeadParser; +pub const ChunkParser = std.http.ChunkParser; +pub const HeaderIterator = std.http.HeaderIterator; + +pub const Version = enum { + @"HTTP/1.0", + @"HTTP/1.1", +}; + +/// https://developer.mozilla.org/en-US/docs/Web/HTTP/Methods +/// +/// https://datatracker.ietf.org/doc/html/rfc7231#section-4 Initial definition +/// +/// https://datatracker.ietf.org/doc/html/rfc5789#section-2 PATCH +pub const Method = enum(u64) { + GET = parse("GET"), + HEAD = parse("HEAD"), + POST = parse("POST"), + PUT = parse("PUT"), + DELETE = parse("DELETE"), + CONNECT = parse("CONNECT"), + OPTIONS = parse("OPTIONS"), + TRACE = parse("TRACE"), + PATCH = parse("PATCH"), + + _, + + /// Converts `s` into a type that may be used as a `Method` field. + /// Asserts that `s` is 24 or fewer bytes. + pub fn parse(s: []const u8) u64 { + var x: u64 = 0; + const len = @min(s.len, @sizeOf(@TypeOf(x))); + @memcpy(std.mem.asBytes(&x)[0..len], s[0..len]); + return x; + } + + pub fn write(self: Method, w: anytype) !void { + const bytes = std.mem.asBytes(&@intFromEnum(self)); + const str = std.mem.sliceTo(bytes, 0); + try w.writeAll(str); + } + + /// Returns true if a request of this method is allowed to have a body + /// Actual behavior from servers may vary and should still be checked + pub fn requestHasBody(self: Method) bool { + return switch (self) { + .POST, .PUT, .PATCH => true, + .GET, .HEAD, .DELETE, .CONNECT, .OPTIONS, .TRACE => false, + else => true, + }; + } + + /// Returns true if a response to this method is allowed to have a body + /// Actual behavior from clients may vary and should still be checked + pub fn responseHasBody(self: Method) bool { + return switch (self) { + .GET, .POST, .DELETE, .CONNECT, .OPTIONS, .PATCH => true, + .HEAD, .PUT, .TRACE => false, + else => true, + }; + } + + /// An HTTP method is safe if it doesn't alter the state of the server. + /// + /// https://developer.mozilla.org/en-US/docs/Glossary/Safe/HTTP + /// + /// https://datatracker.ietf.org/doc/html/rfc7231#section-4.2.1 + pub fn safe(self: Method) bool { + return switch (self) { + .GET, .HEAD, .OPTIONS, .TRACE => true, + .POST, .PUT, .DELETE, .CONNECT, .PATCH => false, + else => false, + }; + } + + /// An HTTP method is idempotent if an identical request can be made once or several times in a row with the same effect while leaving the server in the same state. + /// + /// https://developer.mozilla.org/en-US/docs/Glossary/Idempotent + /// + /// https://datatracker.ietf.org/doc/html/rfc7231#section-4.2.2 + pub fn idempotent(self: Method) bool { + return switch (self) { + .GET, .HEAD, .PUT, .DELETE, .OPTIONS, .TRACE => true, + .CONNECT, .POST, .PATCH => false, + else => false, + }; + } + + /// A cacheable response is an HTTP response that can be cached, that is stored to be retrieved and used later, saving a new request to the server. + /// + /// https://developer.mozilla.org/en-US/docs/Glossary/cacheable + /// + /// https://datatracker.ietf.org/doc/html/rfc7231#section-4.2.3 + pub fn cacheable(self: Method) bool { + return switch (self) { + .GET, .HEAD => true, + .POST, .PUT, .DELETE, .CONNECT, .OPTIONS, .TRACE, .PATCH => false, + else => false, + }; + } +}; + +/// https://developer.mozilla.org/en-US/docs/Web/HTTP/Status +pub const Status = enum(u10) { + @"continue" = 100, // RFC7231, Section 6.2.1 + switching_protocols = 101, // RFC7231, Section 6.2.2 + processing = 102, // RFC2518 + early_hints = 103, // RFC8297 + + ok = 200, // RFC7231, Section 6.3.1 + created = 201, // RFC7231, Section 6.3.2 + accepted = 202, // RFC7231, Section 6.3.3 + non_authoritative_info = 203, // RFC7231, Section 6.3.4 + no_content = 204, // RFC7231, Section 6.3.5 + reset_content = 205, // RFC7231, Section 6.3.6 + partial_content = 206, // RFC7233, Section 4.1 + multi_status = 207, // RFC4918 + already_reported = 208, // RFC5842 + im_used = 226, // RFC3229 + + multiple_choice = 300, // RFC7231, Section 6.4.1 + moved_permanently = 301, // RFC7231, Section 6.4.2 + found = 302, // RFC7231, Section 6.4.3 + see_other = 303, // RFC7231, Section 6.4.4 + not_modified = 304, // RFC7232, Section 4.1 + use_proxy = 305, // RFC7231, Section 6.4.5 + temporary_redirect = 307, // RFC7231, Section 6.4.7 + permanent_redirect = 308, // RFC7538 + + bad_request = 400, // RFC7231, Section 6.5.1 + unauthorized = 401, // RFC7235, Section 3.1 + payment_required = 402, // RFC7231, Section 6.5.2 + forbidden = 403, // RFC7231, Section 6.5.3 + not_found = 404, // RFC7231, Section 6.5.4 + method_not_allowed = 405, // RFC7231, Section 6.5.5 + not_acceptable = 406, // RFC7231, Section 6.5.6 + proxy_auth_required = 407, // RFC7235, Section 3.2 + request_timeout = 408, // RFC7231, Section 6.5.7 + conflict = 409, // RFC7231, Section 6.5.8 + gone = 410, // RFC7231, Section 6.5.9 + length_required = 411, // RFC7231, Section 6.5.10 + precondition_failed = 412, // RFC7232, Section 4.2][RFC8144, Section 3.2 + payload_too_large = 413, // RFC7231, Section 6.5.11 + uri_too_long = 414, // RFC7231, Section 6.5.12 + unsupported_media_type = 415, // RFC7231, Section 6.5.13][RFC7694, Section 3 + range_not_satisfiable = 416, // RFC7233, Section 4.4 + expectation_failed = 417, // RFC7231, Section 6.5.14 + teapot = 418, // RFC 7168, 2.3.3 + misdirected_request = 421, // RFC7540, Section 9.1.2 + unprocessable_entity = 422, // RFC4918 + locked = 423, // RFC4918 + failed_dependency = 424, // RFC4918 + too_early = 425, // RFC8470 + upgrade_required = 426, // RFC7231, Section 6.5.15 + precondition_required = 428, // RFC6585 + too_many_requests = 429, // RFC6585 + request_header_fields_too_large = 431, // RFC6585 + unavailable_for_legal_reasons = 451, // RFC7725 + + internal_server_error = 500, // RFC7231, Section 6.6.1 + not_implemented = 501, // RFC7231, Section 6.6.2 + bad_gateway = 502, // RFC7231, Section 6.6.3 + service_unavailable = 503, // RFC7231, Section 6.6.4 + gateway_timeout = 504, // RFC7231, Section 6.6.5 + http_version_not_supported = 505, // RFC7231, Section 6.6.6 + variant_also_negotiates = 506, // RFC2295 + insufficient_storage = 507, // RFC4918 + loop_detected = 508, // RFC5842 + not_extended = 510, // RFC2774 + network_authentication_required = 511, // RFC6585 + + _, + + pub fn phrase(self: Status) ?[]const u8 { + return switch (self) { + // 1xx statuses + .@"continue" => "Continue", + .switching_protocols => "Switching Protocols", + .processing => "Processing", + .early_hints => "Early Hints", + + // 2xx statuses + .ok => "OK", + .created => "Created", + .accepted => "Accepted", + .non_authoritative_info => "Non-Authoritative Information", + .no_content => "No Content", + .reset_content => "Reset Content", + .partial_content => "Partial Content", + .multi_status => "Multi-Status", + .already_reported => "Already Reported", + .im_used => "IM Used", + + // 3xx statuses + .multiple_choice => "Multiple Choice", + .moved_permanently => "Moved Permanently", + .found => "Found", + .see_other => "See Other", + .not_modified => "Not Modified", + .use_proxy => "Use Proxy", + .temporary_redirect => "Temporary Redirect", + .permanent_redirect => "Permanent Redirect", + + // 4xx statuses + .bad_request => "Bad Request", + .unauthorized => "Unauthorized", + .payment_required => "Payment Required", + .forbidden => "Forbidden", + .not_found => "Not Found", + .method_not_allowed => "Method Not Allowed", + .not_acceptable => "Not Acceptable", + .proxy_auth_required => "Proxy Authentication Required", + .request_timeout => "Request Timeout", + .conflict => "Conflict", + .gone => "Gone", + .length_required => "Length Required", + .precondition_failed => "Precondition Failed", + .payload_too_large => "Payload Too Large", + .uri_too_long => "URI Too Long", + .unsupported_media_type => "Unsupported Media Type", + .range_not_satisfiable => "Range Not Satisfiable", + .expectation_failed => "Expectation Failed", + .teapot => "I'm a teapot", + .misdirected_request => "Misdirected Request", + .unprocessable_entity => "Unprocessable Entity", + .locked => "Locked", + .failed_dependency => "Failed Dependency", + .too_early => "Too Early", + .upgrade_required => "Upgrade Required", + .precondition_required => "Precondition Required", + .too_many_requests => "Too Many Requests", + .request_header_fields_too_large => "Request Header Fields Too Large", + .unavailable_for_legal_reasons => "Unavailable For Legal Reasons", + + // 5xx statuses + .internal_server_error => "Internal Server Error", + .not_implemented => "Not Implemented", + .bad_gateway => "Bad Gateway", + .service_unavailable => "Service Unavailable", + .gateway_timeout => "Gateway Timeout", + .http_version_not_supported => "HTTP Version Not Supported", + .variant_also_negotiates => "Variant Also Negotiates", + .insufficient_storage => "Insufficient Storage", + .loop_detected => "Loop Detected", + .not_extended => "Not Extended", + .network_authentication_required => "Network Authentication Required", + + else => return null, + }; + } + + pub const Class = enum { + informational, + success, + redirect, + client_error, + server_error, + }; + + pub fn class(self: Status) Class { + return switch (@intFromEnum(self)) { + 100...199 => .informational, + 200...299 => .success, + 300...399 => .redirect, + 400...499 => .client_error, + else => .server_error, + }; + } + + test { + try std.testing.expectEqualStrings("OK", Status.ok.phrase().?); + try std.testing.expectEqualStrings("Not Found", Status.not_found.phrase().?); + } + + test { + try std.testing.expectEqual(Status.Class.success, Status.ok.class()); + try std.testing.expectEqual(Status.Class.client_error, Status.not_found.class()); + } +}; + +pub const TransferEncoding = enum { + chunked, + none, + // compression is intentionally omitted here, as std.http.Client stores it as content-encoding +}; + +pub const ContentEncoding = enum { + identity, + compress, + @"x-compress", + deflate, + gzip, + @"x-gzip", + zstd, +}; + +pub const Connection = enum { + keep_alive, + close, +}; + +pub const Header = struct { + name: []const u8, + value: []const u8, +}; + +const builtin = @import("builtin"); +const std = @import("std"); + +test { + _ = Client; + _ = Method; + _ = Server; + _ = Status; +} diff --git a/src/http/async/std/http/Client.zig b/src/http/async/std/http/Client.zig new file mode 100644 index 00000000..f0c37b20 --- /dev/null +++ b/src/http/async/std/http/Client.zig @@ -0,0 +1,2545 @@ +//! HTTP(S) Client implementation. +//! +//! Connections are opened in a thread-safe manner, but individual Requests are not. +//! +//! TLS support may be disabled via `std.options.http_disable_tls`. + +const std = @import("std"); +const builtin = @import("builtin"); +const testing = std.testing; +const http = std.http; +const mem = std.mem; +const net = @import("../net.zig"); +const Uri = std.Uri; +const Allocator = mem.Allocator; +const assert = std.debug.assert; +const use_vectors = builtin.zig_backend != .stage2_x86_64; + +const Client = @This(); +const proto = @import("protocol.zig"); + +const tls23 = @import("../../tls.zig/main.zig"); +const VecPut = @import("../../tls.zig/connection.zig").VecPut; +const GenericStack = @import("../../stack.zig").Stack; +const async_io = @import("../../io.zig"); +pub const Loop = async_io.SingleThreaded(Ctx); + +const cipher = @import("../../tls.zig/cipher.zig"); + +pub const disable_tls = std.options.http_disable_tls; + +/// Used for all client allocations. Must be thread-safe. +allocator: Allocator, + +ca_bundle: if (disable_tls) void else std.crypto.Certificate.Bundle = if (disable_tls) {} else .{}, +ca_bundle_mutex: std.Thread.Mutex = .{}, + +/// When this is `true`, the next time this client performs an HTTPS request, +/// it will first rescan the system for root certificates. +next_https_rescan_certs: bool = true, + +/// The pool of connections that can be reused (and currently in use). +connection_pool: ConnectionPool = .{}, + +/// If populated, all http traffic travels through this third party. +/// This field cannot be modified while the client has active connections. +/// Pointer to externally-owned memory. +http_proxy: ?*Proxy = null, +/// If populated, all https traffic travels through this third party. +/// This field cannot be modified while the client has active connections. +/// Pointer to externally-owned memory. +https_proxy: ?*Proxy = null, + +/// A set of linked lists of connections that can be reused. +pub const ConnectionPool = struct { + mutex: std.Thread.Mutex = .{}, + /// Open connections that are currently in use. + used: Queue = .{}, + /// Open connections that are not currently in use. + free: Queue = .{}, + free_len: usize = 0, + free_size: usize = 32, + + /// The criteria for a connection to be considered a match. + pub const Criteria = struct { + host: []const u8, + port: u16, + protocol: Connection.Protocol, + }; + + const Queue = std.DoublyLinkedList(Connection); + pub const Node = Queue.Node; + + /// Finds and acquires a connection from the connection pool matching the criteria. This function is threadsafe. + /// If no connection is found, null is returned. + pub fn findConnection(pool: *ConnectionPool, criteria: Criteria) ?*Connection { + pool.mutex.lock(); + defer pool.mutex.unlock(); + + var next = pool.free.last; + while (next) |node| : (next = node.prev) { + if (node.data.protocol != criteria.protocol) continue; + if (node.data.port != criteria.port) continue; + + // Domain names are case-insensitive (RFC 5890, Section 2.3.2.4) + if (!std.ascii.eqlIgnoreCase(node.data.host, criteria.host)) continue; + + pool.acquireUnsafe(node); + return &node.data; + } + + return null; + } + + /// Acquires an existing connection from the connection pool. This function is not threadsafe. + pub fn acquireUnsafe(pool: *ConnectionPool, node: *Node) void { + pool.free.remove(node); + pool.free_len -= 1; + + pool.used.append(node); + } + + /// Acquires an existing connection from the connection pool. This function is threadsafe. + pub fn acquire(pool: *ConnectionPool, node: *Node) void { + pool.mutex.lock(); + defer pool.mutex.unlock(); + + return pool.acquireUnsafe(node); + } + + /// Tries to release a connection back to the connection pool. This function is threadsafe. + /// If the connection is marked as closing, it will be closed instead. + /// + /// The allocator must be the owner of all nodes in this pool. + /// The allocator must be the owner of all resources associated with the connection. + pub fn release(pool: *ConnectionPool, allocator: Allocator, connection: *Connection) void { + pool.mutex.lock(); + defer pool.mutex.unlock(); + + const node: *Node = @fieldParentPtr("data", connection); + + pool.used.remove(node); + + if (node.data.closing or pool.free_size == 0) { + node.data.close(allocator); + return allocator.destroy(node); + } + + if (pool.free_len >= pool.free_size) { + const popped = pool.free.popFirst() orelse unreachable; + pool.free_len -= 1; + + popped.data.close(allocator); + allocator.destroy(popped); + } + + if (node.data.proxied) { + pool.free.prepend(node); // proxied connections go to the end of the queue, always try direct connections first + } else { + pool.free.append(node); + } + + pool.free_len += 1; + } + + /// Adds a newly created node to the pool of used connections. This function is threadsafe. + pub fn addUsed(pool: *ConnectionPool, node: *Node) void { + pool.mutex.lock(); + defer pool.mutex.unlock(); + + pool.used.append(node); + } + + /// Resizes the connection pool. This function is threadsafe. + /// + /// If the new size is smaller than the current size, then idle connections will be closed until the pool is the new size. + pub fn resize(pool: *ConnectionPool, allocator: Allocator, new_size: usize) void { + pool.mutex.lock(); + defer pool.mutex.unlock(); + + const next = pool.free.first; + _ = next; + while (pool.free_len > new_size) { + const popped = pool.free.popFirst() orelse unreachable; + pool.free_len -= 1; + + popped.data.close(allocator); + allocator.destroy(popped); + } + + pool.free_size = new_size; + } + + /// Frees the connection pool and closes all connections within. This function is threadsafe. + /// + /// All future operations on the connection pool will deadlock. + pub fn deinit(pool: *ConnectionPool, allocator: Allocator) void { + pool.mutex.lock(); + + var next = pool.free.first; + while (next) |node| { + defer allocator.destroy(node); + next = node.next; + + node.data.close(allocator); + } + + next = pool.used.first; + while (next) |node| { + defer allocator.destroy(node); + next = node.next; + + node.data.close(allocator); + } + + pool.* = undefined; + } +}; + +/// An interface to either a plain or TLS connection. +pub const Connection = struct { + stream: net.Stream, + /// undefined unless protocol is tls. + tls_client: if (!disable_tls) *tls23.Connection(net.Stream) else void, + + /// The protocol that this connection is using. + protocol: Protocol, + + /// The host that this connection is connected to. + host: []u8, + + /// The port that this connection is connected to. + port: u16, + + /// Whether this connection is proxied and is not directly connected. + proxied: bool = false, + + /// Whether this connection is closing when we're done with it. + closing: bool = false, + + read_start: BufferSize = 0, + read_end: BufferSize = 0, + write_end: BufferSize = 0, + read_buf: [buffer_size]u8 = undefined, + write_buf: [buffer_size]u8 = undefined, + + pub const buffer_size = std.crypto.tls.max_ciphertext_record_len; + const BufferSize = std.math.IntFittingRange(0, buffer_size); + + pub const Protocol = enum { plain, tls }; + + pub fn async_readvDirect( + conn: *Connection, + buffers: []std.posix.iovec, + ctx: *Ctx, + comptime cbk: Cbk, + ) !void { + _ = conn; + + if (ctx.conn().protocol == .tls) { + if (disable_tls) unreachable; + + return ctx.conn().tls_client.async_readv(ctx.conn().stream, buffers, ctx, cbk); + } + + return ctx.stream().async_readv(buffers, ctx, cbk); + } + + pub fn readvDirectTls(conn: *Connection, buffers: []std.posix.iovec) ReadError!usize { + return conn.tls_client.readv(buffers) catch |err| { + // https://github.com/ziglang/zig/issues/2473 + if (mem.startsWith(u8, @errorName(err), "TlsAlert")) return error.TlsAlert; + + switch (err) { + error.TlsRecordOverflow, error.TlsBadRecordMac, error.TlsUnexpectedMessage => return error.TlsFailure, + error.ConnectionTimedOut => return error.ConnectionTimedOut, + error.ConnectionResetByPeer, error.BrokenPipe => return error.ConnectionResetByPeer, + else => return error.UnexpectedReadFailure, + } + }; + } + + pub fn readvDirect(conn: *Connection, buffers: []std.posix.iovec) ReadError!usize { + if (conn.protocol == .tls) { + if (disable_tls) unreachable; + + return conn.readvDirectTls(buffers); + } + + return conn.stream.readv(buffers) catch |err| switch (err) { + error.ConnectionTimedOut => return error.ConnectionTimedOut, + error.ConnectionResetByPeer, error.BrokenPipe => return error.ConnectionResetByPeer, + else => return error.UnexpectedReadFailure, + }; + } + + fn onFill(ctx: *Ctx, res: anyerror!void) anyerror!void { + ctx.alloc().free(ctx._iovecs); + res catch |err| return ctx.pop(err); + + // EOF + const nread = ctx.len(); + if (nread == 0) return ctx.pop(error.EndOfStream); + + // finished + ctx.conn().read_start = 0; + ctx.conn().read_end = @intCast(nread); + return ctx.pop({}); + } + + pub fn async_fill(conn: *Connection, ctx: *Ctx, comptime cbk: Cbk) !void { + if (conn.read_end != conn.read_start) return; + + ctx._iovecs = try ctx.alloc().alloc(std.posix.iovec, 1); + errdefer ctx.alloc().free(ctx._iovecs); + const iovecs = [1]std.posix.iovec{ + .{ .base = &conn.read_buf, .len = conn.read_buf.len }, + }; + @memcpy(ctx._iovecs, &iovecs); + + try ctx.push(cbk); + return conn.async_readvDirect(ctx._iovecs, ctx, onFill); + } + + /// Refills the read buffer with data from the connection. + pub fn fill(conn: *Connection) ReadError!void { + if (conn.read_end != conn.read_start) return; + + var iovecs = [1]std.posix.iovec{ + .{ .base = &conn.read_buf, .len = conn.read_buf.len }, + }; + const nread = try conn.readvDirect(&iovecs); + if (nread == 0) return error.EndOfStream; + conn.read_start = 0; + conn.read_end = @intCast(nread); + } + + /// Returns the current slice of buffered data. + pub fn peek(conn: *Connection) []const u8 { + return conn.read_buf[conn.read_start..conn.read_end]; + } + + /// Discards the given number of bytes from the read buffer. + pub fn drop(conn: *Connection, num: BufferSize) void { + conn.read_start += num; + } + + /// Reads data from the connection into the given buffer. + pub fn read(conn: *Connection, buffer: []u8) ReadError!usize { + const available_read = conn.read_end - conn.read_start; + const available_buffer = buffer.len; + + if (available_read > available_buffer) { // partially read buffered data + @memcpy(buffer[0..available_buffer], conn.read_buf[conn.read_start..conn.read_end][0..available_buffer]); + conn.read_start += @intCast(available_buffer); + + return available_buffer; + } else if (available_read > 0) { // fully read buffered data + @memcpy(buffer[0..available_read], conn.read_buf[conn.read_start..conn.read_end]); + conn.read_start += available_read; + + return available_read; + } + + var iovecs = [2]std.posix.iovec{ + .{ .base = buffer.ptr, .len = buffer.len }, + .{ .base = &conn.read_buf, .len = conn.read_buf.len }, + }; + const nread = try conn.readvDirect(&iovecs); + + if (nread > buffer.len) { + conn.read_start = 0; + conn.read_end = @intCast(nread - buffer.len); + return buffer.len; + } + + return nread; + } + + pub const ReadError = error{ + TlsFailure, + TlsAlert, + ConnectionTimedOut, + ConnectionResetByPeer, + UnexpectedReadFailure, + EndOfStream, + }; + + pub const Reader = std.io.Reader(*Connection, ReadError, read); + + pub fn reader(conn: *Connection) Reader { + return Reader{ .context = conn }; + } + + pub fn writeAllDirectTls(conn: *Connection, buffer: []const u8) WriteError!void { + return conn.tls_client.writeAll(buffer) catch |err| switch (err) { + error.BrokenPipe, error.ConnectionResetByPeer => return error.ConnectionResetByPeer, + else => return error.UnexpectedWriteFailure, + }; + } + + fn onWriteAllDirect(ctx: *Ctx, res: anyerror!void) !void { + res catch |err| switch (err) { + error.BrokenPipe, + error.ConnectionResetByPeer, + => return ctx.pop(error.ConnectionResetByPeer), + else => return ctx.pop(error.UnexpectedWriteFailure), + }; + return ctx.pop({}); + } + + pub fn async_writeAllDirect( + conn: *Connection, + buffer: []const u8, + ctx: *Ctx, + comptime cbk: Cbk, + ) !void { + try ctx.push(cbk); + if (conn.protocol == .tls) { + if (disable_tls) unreachable; + + return conn.tls_client.async_writeAll(conn.stream, buffer, ctx, onWriteAllDirect); + } + + return conn.stream.async_writeAll(buffer, ctx, onWriteAllDirect); + } + + pub fn writeAllDirect(conn: *Connection, buffer: []const u8) WriteError!void { + if (conn.protocol == .tls) { + if (disable_tls) unreachable; + + return conn.writeAllDirectTls(buffer); + } + + return conn.stream.writeAll(buffer) catch |err| switch (err) { + error.BrokenPipe, error.ConnectionResetByPeer => return error.ConnectionResetByPeer, + else => return error.UnexpectedWriteFailure, + }; + } + + /// Writes the given buffer to the connection. + pub fn write(conn: *Connection, buffer: []const u8) WriteError!usize { + if (conn.write_buf.len - conn.write_end < buffer.len) { + try conn.flush(); + + if (buffer.len > conn.write_buf.len) { + try conn.writeAllDirect(buffer); + return buffer.len; + } + } + + @memcpy(conn.write_buf[conn.write_end..][0..buffer.len], buffer); + conn.write_end += @intCast(buffer.len); + + return buffer.len; + } + + /// Returns a buffer to be filled with exactly len bytes to write to the connection. + pub fn allocWriteBuffer(conn: *Connection, len: BufferSize) WriteError![]u8 { + if (conn.write_buf.len - conn.write_end < len) try conn.flush(); + defer conn.write_end += len; + return conn.write_buf[conn.write_end..][0..len]; + } + + fn onFlush(ctx: *Ctx, res: anyerror!void) !void { + res catch |err| return ctx.pop(err); + ctx.conn().write_end = 0; + return ctx.pop({}); + } + + pub fn async_flush(conn: *Connection, ctx: *Ctx, comptime cbk: Cbk) !void { + if (conn.write_end == 0) return error.WriteEmpty; + + try ctx.push(cbk); + try conn.async_writeAllDirect(conn.write_buf[0..conn.write_end], ctx, onFlush); + } + + /// Flushes the write buffer to the connection. + pub fn flush(conn: *Connection) WriteError!void { + if (conn.write_end == 0) return; + + try conn.writeAllDirect(conn.write_buf[0..conn.write_end]); + conn.write_end = 0; + } + + pub const WriteError = error{ + ConnectionResetByPeer, + UnexpectedWriteFailure, + }; + + pub const Writer = std.io.Writer(*Connection, WriteError, write); + + pub fn writer(conn: *Connection) Writer { + return Writer{ .context = conn }; + } + + /// Closes the connection. + pub fn close(conn: *Connection, allocator: Allocator) void { + if (conn.protocol == .tls) { + if (disable_tls) unreachable; + + // try to cleanly close the TLS connection, for any server that cares. + conn.tls_client.close() catch {}; + allocator.destroy(conn.tls_client); + } + + conn.stream.close(); + allocator.free(conn.host); + } +}; + +/// The mode of transport for requests. +pub const RequestTransfer = union(enum) { + content_length: u64, + chunked: void, + none: void, +}; + +/// The decompressor for response messages. +pub const Compression = union(enum) { + pub const DeflateDecompressor = std.compress.zlib.Decompressor(Request.TransferReader); + pub const GzipDecompressor = std.compress.gzip.Decompressor(Request.TransferReader); + // https://github.com/ziglang/zig/issues/18937 + //pub const ZstdDecompressor = std.compress.zstd.DecompressStream(Request.TransferReader, .{}); + + deflate: DeflateDecompressor, + gzip: GzipDecompressor, + // https://github.com/ziglang/zig/issues/18937 + //zstd: ZstdDecompressor, + none: void, +}; + +/// A HTTP response originating from a server. +pub const Response = struct { + version: http.Version, + status: http.Status, + reason: []const u8, + + /// Points into the user-provided `server_header_buffer`. + location: ?[]const u8 = null, + /// Points into the user-provided `server_header_buffer`. + content_type: ?[]const u8 = null, + /// Points into the user-provided `server_header_buffer`. + content_disposition: ?[]const u8 = null, + + keep_alive: bool, + + /// If present, the number of bytes in the response body. + content_length: ?u64 = null, + + /// If present, the transfer encoding of the response body, otherwise none. + transfer_encoding: http.TransferEncoding = .none, + + /// If present, the compression of the response body, otherwise identity (no compression). + transfer_compression: http.ContentEncoding = .identity, + + parser: proto.HeadersParser, + compression: Compression = .none, + + /// Whether the response body should be skipped. Any data read from the + /// response body will be discarded. + skip: bool = false, + + pub const ParseError = error{ + HttpHeadersInvalid, + HttpHeaderContinuationsUnsupported, + HttpTransferEncodingUnsupported, + HttpConnectionHeaderUnsupported, + InvalidContentLength, + CompressionUnsupported, + }; + + pub fn parse(res: *Response, bytes: []const u8) ParseError!void { + var it = mem.splitSequence(u8, bytes, "\r\n"); + + const first_line = it.next().?; + if (first_line.len < 12) { + return error.HttpHeadersInvalid; + } + + const version: http.Version = switch (int64(first_line[0..8])) { + int64("HTTP/1.0") => .@"HTTP/1.0", + int64("HTTP/1.1") => .@"HTTP/1.1", + else => return error.HttpHeadersInvalid, + }; + if (first_line[8] != ' ') return error.HttpHeadersInvalid; + const status: http.Status = @enumFromInt(parseInt3(first_line[9..12])); + const reason = mem.trimLeft(u8, first_line[12..], " "); + + res.version = version; + res.status = status; + res.reason = reason; + res.keep_alive = switch (version) { + .@"HTTP/1.0" => false, + .@"HTTP/1.1" => true, + }; + + while (it.next()) |line| { + if (line.len == 0) return; + switch (line[0]) { + ' ', '\t' => return error.HttpHeaderContinuationsUnsupported, + else => {}, + } + + var line_it = mem.splitScalar(u8, line, ':'); + const header_name = line_it.next().?; + const header_value = mem.trim(u8, line_it.rest(), " \t"); + if (header_name.len == 0) return error.HttpHeadersInvalid; + + if (std.ascii.eqlIgnoreCase(header_name, "connection")) { + res.keep_alive = !std.ascii.eqlIgnoreCase(header_value, "close"); + } else if (std.ascii.eqlIgnoreCase(header_name, "content-type")) { + res.content_type = header_value; + } else if (std.ascii.eqlIgnoreCase(header_name, "location")) { + res.location = header_value; + } else if (std.ascii.eqlIgnoreCase(header_name, "content-disposition")) { + res.content_disposition = header_value; + } else if (std.ascii.eqlIgnoreCase(header_name, "transfer-encoding")) { + // Transfer-Encoding: second, first + // Transfer-Encoding: deflate, chunked + var iter = mem.splitBackwardsScalar(u8, header_value, ','); + + const first = iter.first(); + const trimmed_first = mem.trim(u8, first, " "); + + var next: ?[]const u8 = first; + if (std.meta.stringToEnum(http.TransferEncoding, trimmed_first)) |transfer| { + if (res.transfer_encoding != .none) return error.HttpHeadersInvalid; // we already have a transfer encoding + res.transfer_encoding = transfer; + + next = iter.next(); + } + + if (next) |second| { + const trimmed_second = mem.trim(u8, second, " "); + + if (std.meta.stringToEnum(http.ContentEncoding, trimmed_second)) |transfer| { + if (res.transfer_compression != .identity) return error.HttpHeadersInvalid; // double compression is not supported + res.transfer_compression = transfer; + } else { + return error.HttpTransferEncodingUnsupported; + } + } + + if (iter.next()) |_| return error.HttpTransferEncodingUnsupported; + } else if (std.ascii.eqlIgnoreCase(header_name, "content-length")) { + const content_length = std.fmt.parseInt(u64, header_value, 10) catch return error.InvalidContentLength; + + if (res.content_length != null and res.content_length != content_length) return error.HttpHeadersInvalid; + + res.content_length = content_length; + } else if (std.ascii.eqlIgnoreCase(header_name, "content-encoding")) { + if (res.transfer_compression != .identity) return error.HttpHeadersInvalid; + + const trimmed = mem.trim(u8, header_value, " "); + + if (std.meta.stringToEnum(http.ContentEncoding, trimmed)) |ce| { + res.transfer_compression = ce; + } else { + return error.HttpTransferEncodingUnsupported; + } + } + } + return error.HttpHeadersInvalid; // missing empty line + } + + test parse { + const response_bytes = "HTTP/1.1 200 OK\r\n" ++ + "LOcation:url\r\n" ++ + "content-tYpe: text/plain\r\n" ++ + "content-disposition:attachment; filename=example.txt \r\n" ++ + "content-Length:10\r\n" ++ + "TRansfer-encoding:\tdeflate, chunked \r\n" ++ + "connectioN:\t keep-alive \r\n\r\n"; + + var header_buffer: [1024]u8 = undefined; + var res = Response{ + .status = undefined, + .reason = undefined, + .version = undefined, + .keep_alive = false, + .parser = proto.HeadersParser.init(&header_buffer), + }; + + @memcpy(header_buffer[0..response_bytes.len], response_bytes); + res.parser.header_bytes_len = response_bytes.len; + + try res.parse(response_bytes); + + try testing.expectEqual(.@"HTTP/1.1", res.version); + try testing.expectEqualStrings("OK", res.reason); + try testing.expectEqual(.ok, res.status); + + try testing.expectEqualStrings("url", res.location.?); + try testing.expectEqualStrings("text/plain", res.content_type.?); + try testing.expectEqualStrings("attachment; filename=example.txt", res.content_disposition.?); + + try testing.expectEqual(true, res.keep_alive); + try testing.expectEqual(10, res.content_length.?); + try testing.expectEqual(.chunked, res.transfer_encoding); + try testing.expectEqual(.deflate, res.transfer_compression); + } + + inline fn int64(array: *const [8]u8) u64 { + return @bitCast(array.*); + } + + fn parseInt3(text: *const [3]u8) u10 { + if (use_vectors) { + const nnn: @Vector(3, u8) = text.*; + const zero: @Vector(3, u8) = .{ '0', '0', '0' }; + const mmm: @Vector(3, u10) = .{ 100, 10, 1 }; + return @reduce(.Add, @as(@Vector(3, u10), nnn -% zero) *% mmm); + } + return std.fmt.parseInt(u10, text, 10) catch unreachable; + } + + test parseInt3 { + const expectEqual = testing.expectEqual; + try expectEqual(@as(u10, 0), parseInt3("000")); + try expectEqual(@as(u10, 418), parseInt3("418")); + try expectEqual(@as(u10, 999), parseInt3("999")); + } + + pub fn iterateHeaders(r: Response) http.HeaderIterator { + return http.HeaderIterator.init(r.parser.get()); + } + + test iterateHeaders { + const response_bytes = "HTTP/1.1 200 OK\r\n" ++ + "LOcation:url\r\n" ++ + "content-tYpe: text/plain\r\n" ++ + "content-disposition:attachment; filename=example.txt \r\n" ++ + "content-Length:10\r\n" ++ + "TRansfer-encoding:\tdeflate, chunked \r\n" ++ + "connectioN:\t keep-alive \r\n\r\n"; + + var header_buffer: [1024]u8 = undefined; + var res = Response{ + .status = undefined, + .reason = undefined, + .version = undefined, + .keep_alive = false, + .parser = proto.HeadersParser.init(&header_buffer), + }; + + @memcpy(header_buffer[0..response_bytes.len], response_bytes); + res.parser.header_bytes_len = response_bytes.len; + + var it = res.iterateHeaders(); + { + const header = it.next().?; + try testing.expectEqualStrings("LOcation", header.name); + try testing.expectEqualStrings("url", header.value); + try testing.expect(!it.is_trailer); + } + { + const header = it.next().?; + try testing.expectEqualStrings("content-tYpe", header.name); + try testing.expectEqualStrings("text/plain", header.value); + try testing.expect(!it.is_trailer); + } + { + const header = it.next().?; + try testing.expectEqualStrings("content-disposition", header.name); + try testing.expectEqualStrings("attachment; filename=example.txt", header.value); + try testing.expect(!it.is_trailer); + } + { + const header = it.next().?; + try testing.expectEqualStrings("content-Length", header.name); + try testing.expectEqualStrings("10", header.value); + try testing.expect(!it.is_trailer); + } + { + const header = it.next().?; + try testing.expectEqualStrings("TRansfer-encoding", header.name); + try testing.expectEqualStrings("deflate, chunked", header.value); + try testing.expect(!it.is_trailer); + } + { + const header = it.next().?; + try testing.expectEqualStrings("connectioN", header.name); + try testing.expectEqualStrings("keep-alive", header.value); + try testing.expect(!it.is_trailer); + } + try testing.expectEqual(null, it.next()); + } +}; + +/// A HTTP request that has been sent. +/// +/// Order of operations: open -> send[ -> write -> finish] -> wait -> read +pub const Request = struct { + uri: Uri = undefined, + client: *Client, + /// This is null when the connection is released. + connection: ?*Connection = null, + keep_alive: bool = undefined, + + method: http.Method = undefined, + version: http.Version = .@"HTTP/1.1", + transfer_encoding: RequestTransfer = undefined, + redirect_behavior: RedirectBehavior = undefined, + + /// Whether the request should handle a 100-continue response before sending the request body. + handle_continue: bool = undefined, + + /// The response associated with this request. + /// + /// This field is undefined until `wait` is called. + response: Response = undefined, + + /// Standard headers that have default, but overridable, behavior. + headers: Headers = undefined, + + /// These headers are kept including when following a redirect to a + /// different domain. + /// Externally-owned; must outlive the Request. + extra_headers: []const http.Header = undefined, + + /// These headers are stripped when following a redirect to a different + /// domain. + /// Externally-owned; must outlive the Request. + privileged_headers: []const http.Header = undefined, + + pub fn init(client: *Client) Request { + return .{ + .client = client, + }; + } + + pub const Headers = struct { + host: Value = .default, + authorization: Value = .default, + user_agent: Value = .default, + connection: Value = .default, + accept_encoding: Value = .default, + content_type: Value = .default, + + pub const Value = union(enum) { + default, + omit, + override: []const u8, + }; + }; + + /// Any value other than `not_allowed` or `unhandled` means that integer represents + /// how many remaining redirects are allowed. + pub const RedirectBehavior = enum(u16) { + /// The next redirect will cause an error. + not_allowed = 0, + /// Redirects are passed to the client to analyze the redirect response + /// directly. + unhandled = std.math.maxInt(u16), + _, + + pub fn subtractOne(rb: *RedirectBehavior) void { + switch (rb.*) { + .not_allowed => unreachable, + .unhandled => unreachable, + _ => rb.* = @enumFromInt(@intFromEnum(rb.*) - 1), + } + } + + pub fn remaining(rb: RedirectBehavior) u16 { + assert(rb != .unhandled); + return @intFromEnum(rb); + } + }; + + /// Frees all resources associated with the request. + pub fn deinit(req: *Request) void { + if (req.connection) |connection| { + if (!req.response.parser.done) { + // If the response wasn't fully read, then we need to close the connection. + connection.closing = true; + } + req.client.connection_pool.release(req.client.allocator, connection); + } + req.* = undefined; + } + + fn onRedirectSend(ctx: *Ctx, res: anyerror!void) !void { + res catch |err| return ctx.pop(err); + // go back on check headers + ctx.conn().async_fill(ctx, onResponseHeaders) catch |err| return ctx.pop(err); + } + + fn onRedirectConnect(ctx: *Ctx, res: anyerror!void) !void { + res catch |err| return ctx.pop(err); + // re-send request + ctx.req.prepareSend(.{}) catch |err| return ctx.pop(err); + ctx.req.connection.?.async_flush(ctx, onRedirectSend) catch |err| return ctx.pop(err); + } + + // async_redirect flow: + // connect -> setRequestConnection + // -> onRedirectConnect -> async_flush + // -> onRedirectSend -> async_fill + // -> go back on the wait workflow of the response + fn async_redirect(req: *Request, uri: Uri, ctx: *Ctx) !void { + try req.prepareRedirect(); + + var server_header = std.heap.FixedBufferAllocator.init(req.response.parser.header_bytes_buffer); + defer req.response.parser.header_bytes_buffer = server_header.buffer[server_header.end_index..]; + + const protocol, const valid_uri = try validateUri(uri, server_header.allocator()); + + const new_host = valid_uri.host.?.raw; + const prev_host = req.uri.host.?.raw; + const keep_privileged_headers = + std.ascii.eqlIgnoreCase(valid_uri.scheme, req.uri.scheme) and + std.ascii.endsWithIgnoreCase(new_host, prev_host) and + (new_host.len == prev_host.len or new_host[new_host.len - prev_host.len - 1] == '.'); + if (!keep_privileged_headers) { + // When redirecting to a different domain, strip privileged headers. + req.privileged_headers = &.{}; + } + + // create a new connection for the redirected URI + ctx.data.conn = try req.client.allocator.create(Connection); + ctx.data.conn.* = .{ + .stream = undefined, + .tls_client = undefined, + .protocol = undefined, + .host = undefined, + .port = undefined, + }; + req.uri = valid_uri; + return req.client.async_connect(new_host, uriPort(valid_uri, protocol), protocol, ctx, setRequestConnection); + } + + // This function must deallocate all resources associated with the request, + // or keep those which will be used. + // This needs to be kept in sync with deinit and request. + fn redirect(req: *Request, uri: Uri) !void { + try req.prepareRedirect(); + + var server_header = std.heap.FixedBufferAllocator.init(req.response.parser.header_bytes_buffer); + defer req.response.parser.header_bytes_buffer = server_header.buffer[server_header.end_index..]; + + const protocol, const valid_uri = try validateUri(uri, server_header.allocator()); + + const new_host = valid_uri.host.?.raw; + const prev_host = req.uri.host.?.raw; + const keep_privileged_headers = + std.ascii.eqlIgnoreCase(valid_uri.scheme, req.uri.scheme) and + std.ascii.endsWithIgnoreCase(new_host, prev_host) and + (new_host.len == prev_host.len or new_host[new_host.len - prev_host.len - 1] == '.'); + if (!keep_privileged_headers) { + // When redirecting to a different domain, strip privileged headers. + req.privileged_headers = &.{}; + } + + req.connection = try req.client.connect(new_host, uriPort(valid_uri, protocol), protocol); + req.uri = valid_uri; + } + fn prepareRedirect(req: *Request) !void { + assert(req.response.parser.done); + + req.client.connection_pool.release(req.client.allocator, req.connection.?); + req.connection = null; + + if (switch (req.response.status) { + .see_other => true, + .moved_permanently, .found => req.method == .POST, + else => false, + }) { + // A redirect to a GET must change the method and remove the body. + req.method = .GET; + req.transfer_encoding = .none; + req.headers.content_type = .omit; + } + + if (req.transfer_encoding != .none) { + // The request body has already been sent. The request is + // still in a valid state, but the redirect must be handled + // manually. + return error.RedirectRequiresResend; + } + + req.redirect_behavior.subtractOne(); + req.response.parser.reset(); + + req.response = .{ + .version = undefined, + .status = undefined, + .reason = undefined, + .keep_alive = undefined, + .parser = req.response.parser, + }; + } + + pub const SendError = Connection.WriteError || error{ InvalidContentLength, UnsupportedTransferEncoding }; + + pub fn async_send(req: *Request, ctx: *Ctx, comptime cbk: Cbk) !void { + try req.prepareSend(); + try req.connection.?.async_flush(ctx, cbk); + } + + /// Send the HTTP request headers to the server. + pub fn send(req: *Request) SendError!void { + try req.prepareSend(); + try req.connection.?.flush(); + } + + fn prepareSend(req: *Request) SendError!void { + if (!req.method.requestHasBody() and req.transfer_encoding != .none) + if (!req.method.requestHasBody() and req.transfer_encoding != .none) + return error.UnsupportedTransferEncoding; + + const connection = req.connection.?; + const w = connection.writer(); + + try req.method.write(w); + try w.writeByte(' '); + + if (req.method == .CONNECT) { + try req.uri.writeToStream(.{ .authority = true }, w); + } else { + try req.uri.writeToStream(.{ + .scheme = connection.proxied, + .authentication = connection.proxied, + .authority = connection.proxied, + .path = true, + .query = true, + }, w); + } + try w.writeByte(' '); + try w.writeAll(@tagName(req.version)); + try w.writeAll("\r\n"); + + if (try emitOverridableHeader("host: ", req.headers.host, w)) { + try w.writeAll("host: "); + try req.uri.writeToStream(.{ .authority = true }, w); + try w.writeAll("\r\n"); + } + + if (try emitOverridableHeader("authorization: ", req.headers.authorization, w)) { + if (req.uri.user != null or req.uri.password != null) { + try w.writeAll("authorization: "); + const authorization = try connection.allocWriteBuffer( + @intCast(basic_authorization.valueLengthFromUri(req.uri)), + ); + assert(basic_authorization.value(req.uri, authorization).len == authorization.len); + try w.writeAll("\r\n"); + } + } + + if (try emitOverridableHeader("user-agent: ", req.headers.user_agent, w)) { + try w.writeAll("user-agent: zig/"); + try w.writeAll(builtin.zig_version_string); + try w.writeAll(" (std.http)\r\n"); + } + + if (try emitOverridableHeader("connection: ", req.headers.connection, w)) { + if (req.keep_alive) { + try w.writeAll("connection: keep-alive\r\n"); + } else { + try w.writeAll("connection: close\r\n"); + } + } + + if (try emitOverridableHeader("accept-encoding: ", req.headers.accept_encoding, w)) { + // https://github.com/ziglang/zig/issues/18937 + //try w.writeAll("accept-encoding: gzip, deflate, zstd\r\n"); + try w.writeAll("accept-encoding: gzip, deflate\r\n"); + } + + switch (req.transfer_encoding) { + .chunked => try w.writeAll("transfer-encoding: chunked\r\n"), + .content_length => |len| try w.print("content-length: {d}\r\n", .{len}), + .none => {}, + } + + if (try emitOverridableHeader("content-type: ", req.headers.content_type, w)) { + // The default is to omit content-type if not provided because + // "application/octet-stream" is redundant. + } + + for (req.extra_headers) |header| { + assert(header.name.len != 0); + + try w.writeAll(header.name); + try w.writeAll(": "); + try w.writeAll(header.value); + try w.writeAll("\r\n"); + } + + if (connection.proxied) proxy: { + const proxy = switch (connection.protocol) { + .plain => req.client.http_proxy, + .tls => req.client.https_proxy, + } orelse break :proxy; + + const authorization = proxy.authorization orelse break :proxy; + try w.writeAll("proxy-authorization: "); + try w.writeAll(authorization); + try w.writeAll("\r\n"); + } + + try w.writeAll("\r\n"); + } + + /// Returns true if the default behavior is required, otherwise handles + /// writing (or not writing) the header. + fn emitOverridableHeader(prefix: []const u8, v: Headers.Value, w: anytype) !bool { + switch (v) { + .default => return true, + .omit => return false, + .override => |x| { + try w.writeAll(prefix); + try w.writeAll(x); + try w.writeAll("\r\n"); + return false; + }, + } + } + + const TransferReadError = Connection.ReadError || proto.HeadersParser.ReadError; + + const TransferReader = std.io.Reader(*Request, TransferReadError, transferRead); + + fn transferReader(req: *Request) TransferReader { + return .{ .context = req }; + } + + fn transferRead(req: *Request, buf: []u8) TransferReadError!usize { + if (req.response.parser.done) return 0; + + var index: usize = 0; + while (index == 0) { + const amt = try req.response.parser.read(req.connection.?, buf[index..], req.response.skip); + if (amt == 0 and req.response.parser.done) break; + index += amt; + } + + return index; + } + + pub const WaitError = RequestError || SendError || TransferReadError || + proto.HeadersParser.CheckCompleteHeadError || Response.ParseError || + error{ // TODO: file zig fmt issue for this bad indentation + TooManyHttpRedirects, + RedirectRequiresResend, + HttpRedirectLocationMissing, + HttpRedirectLocationInvalid, + CompressionInitializationFailed, + CompressionUnsupported, + }; + + pub fn async_wait(_: *Request, ctx: *Ctx, comptime cbk: Cbk) !void { + try ctx.push(cbk); + return ctx.conn().async_fill(ctx, onResponseHeaders); + } + + /// Waits for a response from the server and parses any headers that are sent. + /// This function will block until the final response is received. + /// + /// If handling redirects and the request has no payload, then this + /// function will automatically follow redirects. If a request payload is + /// present, then this function will error with + /// error.RedirectRequiresResend. + /// + /// Must be called after `send` and, if any data was written to the request + /// body, then also after `finish`. + pub fn wait(req: *Request) WaitError!void { + while (true) { + // This while loop is for handling redirects, which means the request's + // connection may be different than the previous iteration. However, it + // is still guaranteed to be non-null with each iteration of this loop. + const connection = req.connection.?; + + while (true) { // read headers + try connection.fill(); + + const nchecked = try req.response.parser.checkCompleteHead(connection.peek()); + connection.drop(@intCast(nchecked)); + + if (req.response.parser.state.isContent()) break; + } + + try req.response.parse(req.response.parser.get()); + + if (req.response.status == .@"continue") { + // We're done parsing the continue response; reset to prepare + // for the real response. + req.response.parser.done = true; + req.response.parser.reset(); + + if (req.handle_continue) + continue; + + return; // we're not handling the 100-continue + } + + // we're switching protocols, so this connection is no longer doing http + if (req.method == .CONNECT and req.response.status.class() == .success) { + connection.closing = false; + req.response.parser.done = true; + return; // the connection is not HTTP past this point + } + + connection.closing = !req.response.keep_alive or !req.keep_alive; + + // Any response to a HEAD request and any response with a 1xx + // (Informational), 204 (No Content), or 304 (Not Modified) status + // code is always terminated by the first empty line after the + // header fields, regardless of the header fields present in the + // message. + if (req.method == .HEAD or req.response.status.class() == .informational or + req.response.status == .no_content or req.response.status == .not_modified) + { + req.response.parser.done = true; + return; // The response is empty; no further setup or redirection is necessary. + } + + switch (req.response.transfer_encoding) { + .none => { + if (req.response.content_length) |cl| { + req.response.parser.next_chunk_length = cl; + + if (cl == 0) req.response.parser.done = true; + } else { + // read until the connection is closed + req.response.parser.next_chunk_length = std.math.maxInt(u64); + } + }, + .chunked => { + req.response.parser.next_chunk_length = 0; + req.response.parser.state = .chunk_head_size; + }, + } + + if (req.response.status.class() == .redirect and req.redirect_behavior != .unhandled) { + // skip the body of the redirect response, this will at least + // leave the connection in a known good state. + req.response.skip = true; + assert(try req.transferRead(&.{}) == 0); // we're skipping, no buffer is necessary + + if (req.redirect_behavior == .not_allowed) return error.TooManyHttpRedirects; + + const location = req.response.location orelse + return error.HttpRedirectLocationMissing; + + // This mutates the beginning of header_bytes_buffer and uses that + // for the backing memory of the returned Uri. + try req.redirect(req.uri.resolve_inplace( + location, + &req.response.parser.header_bytes_buffer, + ) catch |err| switch (err) { + error.UnexpectedCharacter, + error.InvalidFormat, + error.InvalidPort, + => return error.HttpRedirectLocationInvalid, + error.NoSpaceLeft => return error.HttpHeadersOversize, + }); + try req.send(); + } else { + req.response.skip = false; + if (!req.response.parser.done) { + switch (req.response.transfer_compression) { + .identity => req.response.compression = .none, + .compress, .@"x-compress" => return error.CompressionUnsupported, + .deflate => req.response.compression = .{ + .deflate = std.compress.zlib.decompressor(req.transferReader()), + }, + .gzip, .@"x-gzip" => req.response.compression = .{ + .gzip = std.compress.gzip.decompressor(req.transferReader()), + }, + // https://github.com/ziglang/zig/issues/18937 + //.zstd => req.response.compression = .{ + // .zstd = std.compress.zstd.decompressStream(req.client.allocator, req.transferReader()), + //}, + .zstd => return error.CompressionUnsupported, + } + } + + break; + } + } + } + + pub const ReadError = TransferReadError || proto.HeadersParser.CheckCompleteHeadError || + error{ DecompressionFailure, InvalidTrailers }; + + pub const Reader = std.io.Reader(*Request, ReadError, read); + + pub fn reader(req: *Request) Reader { + return .{ .context = req }; + } + + /// Reads data from the response body. Must be called after `wait`. + pub fn read(req: *Request, buffer: []u8) ReadError!usize { + const out_index = switch (req.response.compression) { + .deflate => |*deflate| deflate.read(buffer) catch return error.DecompressionFailure, + .gzip => |*gzip| gzip.read(buffer) catch return error.DecompressionFailure, + // https://github.com/ziglang/zig/issues/18937 + //.zstd => |*zstd| zstd.read(buffer) catch return error.DecompressionFailure, + else => try req.transferRead(buffer), + }; + if (out_index > 0) return out_index; + + while (!req.response.parser.state.isContent()) { // read trailing headers + try req.connection.?.fill(); + + const nchecked = try req.response.parser.checkCompleteHead(req.connection.?.peek()); + req.connection.?.drop(@intCast(nchecked)); + } + + return 0; + } + + /// Reads data from the response body. Must be called after `wait`. + pub fn readAll(req: *Request, buffer: []u8) !usize { + var index: usize = 0; + while (index < buffer.len) { + const amt = try read(req, buffer[index..]); + if (amt == 0) break; + index += amt; + } + return index; + } + + pub const WriteError = Connection.WriteError || error{ NotWriteable, MessageTooLong }; + + pub const Writer = std.io.Writer(*Request, WriteError, write); + + pub fn writer(req: *Request) Writer { + return .{ .context = req }; + } + + /// Write `bytes` to the server. The `transfer_encoding` field determines how data will be sent. + /// Must be called after `send` and before `finish`. + pub fn write(req: *Request, bytes: []const u8) WriteError!usize { + switch (req.transfer_encoding) { + .chunked => { + if (bytes.len > 0) { + try req.connection.?.writer().print("{x}\r\n", .{bytes.len}); + try req.connection.?.writer().writeAll(bytes); + try req.connection.?.writer().writeAll("\r\n"); + } + + return bytes.len; + }, + .content_length => |*len| { + if (len.* < bytes.len) return error.MessageTooLong; + + const amt = try req.connection.?.write(bytes); + len.* -= amt; + return amt; + }, + .none => return error.NotWriteable, + } + } + + /// Write `bytes` to the server. The `transfer_encoding` field determines how data will be sent. + /// Must be called after `send` and before `finish`. + pub fn writeAll(req: *Request, bytes: []const u8) WriteError!void { + var index: usize = 0; + while (index < bytes.len) { + index += try write(req, bytes[index..]); + } + } + + pub fn async_writeAll(req: *Request, buf: []const u8, ctx: *Ctx, comptime cbk: Cbk) !void { + try req.connection.?.async_writeAllDirect(buf, ctx, cbk); + } + + pub const FinishError = WriteError || error{MessageNotCompleted}; + + pub fn async_finish(req: *Request, ctx: *Ctx, comptime cbk: Cbk) !void { + try req.common_finish(); + req.connection.?.async_flush(ctx, cbk) catch |err| switch (err) { + error.WriteEmpty => return cbk(ctx, {}), + else => return cbk(ctx, err), + }; + } + + /// Finish the body of a request. This notifies the server that you have no more data to send. + /// Must be called after `send`. + pub fn finish(req: *Request) FinishError!void { + try req.common_finish(); + try req.connection.?.flush(); + } + + fn common_finish(req: *Request) FinishError!void { + switch (req.transfer_encoding) { + .chunked => try req.connection.?.writer().writeAll("0\r\n\r\n"), + .content_length => |len| if (len != 0) return error.MessageNotCompleted, + .none => {}, + } + } + + fn onResponseHeaders(ctx: *Ctx, res: anyerror!void) !void { + res catch |err| return ctx.pop(err); + const done = ctx.req.parseResponseHeaders() catch |err| return ctx.pop(err); + // if read of the headers is not done, continue + if (!done) return ctx.conn().async_fill(ctx, onResponseHeaders); + // if read of the headers is done, go read the reponse + return onResponse(ctx, {}); + } + + fn parseResponseHeaders(req: *Request) !bool { + const nchecked = try req.response.parser.checkCompleteHead(req.connection.?.peek()); + req.connection.?.drop(@intCast(nchecked)); + + if (req.response.parser.state.isContent()) return true; + return false; + } + + fn onResponse(ctx: *Ctx, res: anyerror!void) !void { + res catch |err| return ctx.pop(err); + const ret = ctx.req.parseResponse() catch |err| return ctx.pop(err); + if (ret.redirect_uri) |uri| { + ctx.req.async_redirect(uri, ctx) catch |err| return ctx.pop(err); + return; + } + // if read of the response is not done, continue + if (!ret.done) return ctx.conn().async_fill(ctx, onResponse); + // if read of the response is done, go execute the provided callback + return ctx.pop({}); + } + + const WaitRedirectsReturn = struct { + redirect_uri: ?Uri = null, + done: bool = true, + }; + + fn parseResponse(req: *Request) WaitError!WaitRedirectsReturn { + try req.response.parse(req.response.parser.get()); + + if (req.response.status == .@"continue") { + // We're done parsing the continue response; reset to prepare + // for the real response. + req.response.parser.done = true; + req.response.parser.reset(); + + if (req.handle_continue) return .{ .done = false }; + + return .{ .done = true }; + } + + // we're switching protocols, so this connection is no longer doing http + if (req.method == .CONNECT and req.response.status.class() == .success) { + req.connection.?.closing = false; + req.response.parser.done = true; + return .{ .done = true }; // the connection is not HTTP past this point + } + + req.connection.?.closing = !req.response.keep_alive or !req.keep_alive; + + // Any response to a HEAD request and any response with a 1xx + // (Informational), 204 (No Content), or 304 (Not Modified) status + // code is always terminated by the first empty line after the + // header fields, regardless of the header fields present in the + // message. + if (req.method == .HEAD or req.response.status.class() == .informational or + req.response.status == .no_content or req.response.status == .not_modified) + { + req.response.parser.done = true; + return .{ .done = true }; // The response is empty; no further setup or redirection is necessary. + } + + switch (req.response.transfer_encoding) { + .none => { + if (req.response.content_length) |cl| { + req.response.parser.next_chunk_length = cl; + + if (cl == 0) req.response.parser.done = true; + } else { + // read until the connection is closed + req.response.parser.next_chunk_length = std.math.maxInt(u64); + } + }, + .chunked => { + req.response.parser.next_chunk_length = 0; + req.response.parser.state = .chunk_head_size; + }, + } + + if (req.response.status.class() == .redirect and req.redirect_behavior != .unhandled) { + // skip the body of the redirect response, this will at least + // leave the connection in a known good state. + req.response.skip = true; + assert(try req.transferRead(&.{}) == 0); // we're skipping, no buffer is necessary + + if (req.redirect_behavior == .not_allowed) return error.TooManyHttpRedirects; + + const location = req.response.location orelse + return error.HttpRedirectLocationMissing; + + // This mutates the beginning of header_bytes_buffer and uses that + // for the backing memory of the returned Uri. + try req.redirect(req.uri.resolve_inplace( + location, + &req.response.parser.header_bytes_buffer, + ) catch |err| switch (err) { + error.UnexpectedCharacter, + error.InvalidFormat, + error.InvalidPort, + => return error.HttpRedirectLocationInvalid, + error.NoSpaceLeft => return error.HttpHeadersOversize, + }); + + return .{ .redirect_uri = req.uri }; + } else { + req.response.skip = false; + if (!req.response.parser.done) { + switch (req.response.transfer_compression) { + .identity => req.response.compression = .none, + .compress, .@"x-compress" => return error.CompressionUnsupported, + .deflate => req.response.compression = .{ + .deflate = std.compress.zlib.decompressor(req.transferReader()), + }, + .gzip, .@"x-gzip" => req.response.compression = .{ + .gzip = std.compress.gzip.decompressor(req.transferReader()), + }, + // https://github.com/ziglang/zig/issues/18937 + //.zstd => req.response.compression = .{ + // .zstd = std.compress.zstd.decompressStream(req.client.allocator, req.transferReader()), + //}, + .zstd => return error.CompressionUnsupported, + } + } + + return .{ .done = true }; + } + return .{ .done = false }; + } +}; + +pub const Proxy = struct { + protocol: Connection.Protocol, + host: []const u8, + authorization: ?[]const u8, + port: u16, + supports_connect: bool, +}; + +/// Release all associated resources with the client. +/// +/// All pending requests must be de-initialized and all active connections released +/// before calling this function. +pub fn deinit(client: *Client) void { + assert(client.connection_pool.used.first == null); // There are still active requests. + + client.connection_pool.deinit(client.allocator); + + if (!disable_tls) + client.ca_bundle.deinit(client.allocator); + + client.* = undefined; +} + +/// Populates `http_proxy` and `https_proxy` via standard proxy environment variables. +/// Asserts the client has no active connections. +/// Uses `arena` for a few small allocations that must outlive the client, or +/// at least until those fields are set to different values. +pub fn initDefaultProxies(client: *Client, arena: Allocator) !void { + // Prevent any new connections from being created. + client.connection_pool.mutex.lock(); + defer client.connection_pool.mutex.unlock(); + + assert(client.connection_pool.used.first == null); // There are active requests. + + if (client.http_proxy == null) { + client.http_proxy = try createProxyFromEnvVar(arena, &.{ + "http_proxy", "HTTP_PROXY", "all_proxy", "ALL_PROXY", + }); + } + + if (client.https_proxy == null) { + client.https_proxy = try createProxyFromEnvVar(arena, &.{ + "https_proxy", "HTTPS_PROXY", "all_proxy", "ALL_PROXY", + }); + } +} + +fn createProxyFromEnvVar(arena: Allocator, env_var_names: []const []const u8) !?*Proxy { + const content = for (env_var_names) |name| { + break std.process.getEnvVarOwned(arena, name) catch |err| switch (err) { + error.EnvironmentVariableNotFound => continue, + else => |e| return e, + }; + } else return null; + + const uri = Uri.parse(content) catch try Uri.parseAfterScheme("http", content); + const protocol, const valid_uri = validateUri(uri, arena) catch |err| switch (err) { + error.UnsupportedUriScheme => return null, + error.UriMissingHost => return error.HttpProxyMissingHost, + error.OutOfMemory => |e| return e, + }; + + const authorization: ?[]const u8 = if (valid_uri.user != null or valid_uri.password != null) a: { + const authorization = try arena.alloc(u8, basic_authorization.valueLengthFromUri(valid_uri)); + assert(basic_authorization.value(valid_uri, authorization).len == authorization.len); + break :a authorization; + } else null; + + const proxy = try arena.create(Proxy); + proxy.* = .{ + .protocol = protocol, + .host = valid_uri.host.?.raw, + .authorization = authorization, + .port = uriPort(valid_uri, protocol), + .supports_connect = true, + }; + return proxy; +} + +pub const basic_authorization = struct { + pub const max_user_len = 255; + pub const max_password_len = 255; + pub const max_value_len = valueLength(max_user_len, max_password_len); + + const prefix = "Basic "; + + pub fn valueLength(user_len: usize, password_len: usize) usize { + return prefix.len + std.base64.standard.Encoder.calcSize(user_len + 1 + password_len); + } + + pub fn valueLengthFromUri(uri: Uri) usize { + var stream = std.io.countingWriter(std.io.null_writer); + try stream.writer().print("{user}", .{uri.user orelse Uri.Component.empty}); + const user_len = stream.bytes_written; + stream.bytes_written = 0; + try stream.writer().print("{password}", .{uri.password orelse Uri.Component.empty}); + const password_len = stream.bytes_written; + return valueLength(@intCast(user_len), @intCast(password_len)); + } + + pub fn value(uri: Uri, out: []u8) []u8 { + var buf: [max_user_len + ":".len + max_password_len]u8 = undefined; + var stream = std.io.fixedBufferStream(&buf); + stream.writer().print("{user}", .{uri.user orelse Uri.Component.empty}) catch + unreachable; + assert(stream.pos <= max_user_len); + stream.writer().print(":{password}", .{uri.password orelse Uri.Component.empty}) catch + unreachable; + + @memcpy(out[0..prefix.len], prefix); + const base64 = std.base64.standard.Encoder.encode(out[prefix.len..], stream.getWritten()); + return out[0 .. prefix.len + base64.len]; + } +}; + +pub const ConnectTcpError = Allocator.Error || error{ ConnectionRefused, NetworkUnreachable, ConnectionTimedOut, ConnectionResetByPeer, TemporaryNameServerFailure, NameServerFailure, UnknownHostName, HostLacksNetworkAddresses, UnexpectedConnectFailure, TlsInitializationFailed }; + +// requires ctx.data.stream to be set +fn setConnection(ctx: *Ctx, res: anyerror!void) !void { + + // check stream + errdefer ctx.data.conn.stream.close(); + res catch |e| { + // ctx.data.conn.stream.close(); is it needed with errdefer? + switch (e) { + error.ConnectionRefused, + error.NetworkUnreachable, + error.ConnectionTimedOut, + error.ConnectionResetByPeer, + error.TemporaryNameServerFailure, + error.NameServerFailure, + error.UnknownHostName, + error.HostLacksNetworkAddresses, + => return ctx.pop(e), + else => return ctx.pop(error.UnexpectedConnectFailure), + } + }; + + if (ctx.data.conn.protocol == .tls) { + if (disable_tls) unreachable; + + ctx.data.conn.tls_client = try ctx.alloc().create(tls23.Connection(net.Stream)); + errdefer ctx.alloc().destroy(ctx.data.conn.tls_client); + + // TODO tls23.client does an handshake to pick a cipher. + ctx.data.conn.tls_client.* = tls23.client(ctx.data.conn.stream, .{ + .host = ctx.data.conn.host, + .root_ca = .{ .bundle = ctx.req.client.ca_bundle }, + }) catch return error.TlsInitializationFailed; + } + + // add connection node in pool + const node = ctx.req.client.allocator.create(ConnectionPool.Node) catch |e| return ctx.pop(e); + errdefer ctx.req.client.allocator.destroy(node); + // NOTE we can not use the ctx.data.conn pointer as a node connection data, + // we need to copy it's value and use this reference for the connection + node.* = .{ + .data = .{ + .stream = ctx.data.conn.stream, + .tls_client = ctx.data.conn.tls_client, + .protocol = ctx.data.conn.protocol, + .host = ctx.data.conn.host, + .port = ctx.data.conn.port, + }, + }; + // remove old pointer, now useless + const old_conn = ctx.data.conn; + defer ctx.req.client.allocator.destroy(old_conn); + + ctx.req.client.connection_pool.addUsed(node); + ctx.data.conn = &node.data; + + return ctx.pop({}); +} + +/// Connect to `host:port` using the specified protocol. This will reuse a connection if one is already open. +/// +/// This function is threadsafe. +pub fn connectTcp(client: *Client, host: []const u8, port: u16, protocol: Connection.Protocol) ConnectTcpError!*Connection { + if (client.connection_pool.findConnection(.{ + .host = host, + .port = port, + .protocol = protocol, + })) |node| return node; + + if (disable_tls and protocol == .tls) + return error.TlsInitializationFailed; + + const conn = try client.allocator.create(ConnectionPool.Node); + errdefer client.allocator.destroy(conn); + conn.* = .{ .data = undefined }; + + const stream = net.tcpConnectToHost(client.allocator, host, port) catch |err| switch (err) { + error.ConnectionRefused => return error.ConnectionRefused, + error.NetworkUnreachable => return error.NetworkUnreachable, + error.ConnectionTimedOut => return error.ConnectionTimedOut, + error.ConnectionResetByPeer => return error.ConnectionResetByPeer, + error.TemporaryNameServerFailure => return error.TemporaryNameServerFailure, + error.NameServerFailure => return error.NameServerFailure, + error.UnknownHostName => return error.UnknownHostName, + error.HostLacksNetworkAddresses => return error.HostLacksNetworkAddresses, + else => return error.UnexpectedConnectFailure, + }; + errdefer stream.close(); + + conn.data = .{ + .stream = stream, + .tls_client = undefined, + + .protocol = protocol, + .host = try client.allocator.dupe(u8, host), + .port = port, + }; + errdefer client.allocator.free(conn.data.host); + + if (protocol == .tls) { + if (disable_tls) unreachable; + + conn.data.tls_client = try client.allocator.create(tls23.Connection(net.Stream)); + errdefer client.allocator.destroy(conn.data.tls_client); + + // TODO tls23.client does an handshake to pick a cipher. + conn.data.tls_client.* = tls23.client(stream, .{ + .host = host, + .root_ca = .{ .bundle = client.ca_bundle }, + }) catch return error.TlsInitializationFailed; + } + + client.connection_pool.addUsed(conn); + + return &conn.data; +} + +pub fn async_connectTcp( + client: *Client, + host: []const u8, + port: u16, + protocol: Connection.Protocol, + ctx: *Ctx, + comptime cbk: Cbk, +) !void { + try ctx.push(cbk); + if (ctx.req.client.connection_pool.findConnection(.{ + .host = host, + .port = port, + .protocol = protocol, + })) |conn| { + ctx.req.connection = conn; + return ctx.pop({}); + } + + if (disable_tls and protocol == .tls) + return error.TlsInitializationFailed; + + return net.async_tcpConnectToHost( + client.allocator, + host, + port, + ctx, + setConnection, + ); +} + +pub const ConnectUnixError = Allocator.Error || std.posix.SocketError || error{NameTooLong} || std.posix.ConnectError; + +// Connect to `path` as a unix domain socket. This will reuse a connection if one is already open. +// +// This function is threadsafe. +// pub fn connectUnix(client: *Client, path: []const u8) ConnectUnixError!*Connection { +// if (client.connection_pool.findConnection(.{ +// .host = path, +// .port = 0, +// .protocol = .plain, +// })) |node| +// return node; + +// const conn = try client.allocator.create(ConnectionPool.Node); +// errdefer client.allocator.destroy(conn); +// conn.* = .{ .data = undefined }; + +// const stream = try std.net.connectUnixSocket(path); +// errdefer stream.close(); + +// conn.data = .{ +// .stream = stream, +// .tls_client = undefined, +// .protocol = .plain, + +// .host = try client.allocator.dupe(u8, path), +// .port = 0, +// }; +// errdefer client.allocator.free(conn.data.host); + +// client.connection_pool.addUsed(conn); + +// return &conn.data; +//} + +/// Connect to `tunnel_host:tunnel_port` using the specified proxy with HTTP +/// CONNECT. This will reuse a connection if one is already open. +/// +/// This function is threadsafe. +pub fn connectTunnel( + client: *Client, + proxy: *Proxy, + tunnel_host: []const u8, + tunnel_port: u16, +) !*Connection { + if (!proxy.supports_connect) return error.TunnelNotSupported; + + if (client.connection_pool.findConnection(.{ + .host = tunnel_host, + .port = tunnel_port, + .protocol = proxy.protocol, + })) |node| + return node; + + var maybe_valid = false; + (tunnel: { + const conn = try client.connectTcp(proxy.host, proxy.port, proxy.protocol); + errdefer { + conn.closing = true; + client.connection_pool.release(client.allocator, conn); + } + + var buffer: [8096]u8 = undefined; + var req = client.open(.CONNECT, .{ + .scheme = "http", + .host = .{ .raw = tunnel_host }, + .port = tunnel_port, + }, .{ + .redirect_behavior = .unhandled, + .connection = conn, + .server_header_buffer = &buffer, + }) catch |err| { + std.log.debug("err {}", .{err}); + break :tunnel err; + }; + defer req.deinit(); + + req.send() catch |err| break :tunnel err; + req.wait() catch |err| break :tunnel err; + + if (req.response.status.class() == .server_error) { + maybe_valid = true; + break :tunnel error.ServerError; + } + + if (req.response.status != .ok) break :tunnel error.ConnectionRefused; + + // this connection is now a tunnel, so we can't use it for anything else, it will only be released when the client is de-initialized. + req.connection = null; + + client.allocator.free(conn.host); + conn.host = try client.allocator.dupe(u8, tunnel_host); + errdefer client.allocator.free(conn.host); + + conn.port = tunnel_port; + conn.closing = false; + + return conn; + }) catch { + // something went wrong with the tunnel + proxy.supports_connect = maybe_valid; + return error.TunnelNotSupported; + }; +} + +// Prevents a dependency loop in open() +const ConnectErrorPartial = ConnectTcpError || error{ UnsupportedUriScheme, ConnectionRefused }; +pub const ConnectError = ConnectErrorPartial || RequestError; + +fn onConnectProxy(ctx: *Ctx, res: anyerror!void) anyerror!void { + res catch |e| { + ctx.data.conn.closing = true; + ctx.req.client.connection_pool.release(ctx.req.client.allocator, ctx.data.conn); + return ctx.pop(e); + }; + ctx.data.conn.proxied = true; + return ctx.pop({}); +} + +/// Connect to `host:port` using the specified protocol. This will reuse a +/// connection if one is already open. +/// If a proxy is configured for the client, then the proxy will be used to +/// connect to the host. +/// +/// This function is threadsafe. +pub fn connect( + client: *Client, + host: []const u8, + port: u16, + protocol: Connection.Protocol, +) ConnectError!*Connection { + const proxy = switch (protocol) { + .plain => client.http_proxy, + .tls => client.https_proxy, + } orelse return client.connectTcp(host, port, protocol); + + // Prevent proxying through itself. + if (std.ascii.eqlIgnoreCase(proxy.host, host) and + proxy.port == port and proxy.protocol == protocol) + { + return client.connectTcp(host, port, protocol); + } + + if (proxy.supports_connect) tunnel: { + return connectTunnel(client, proxy, host, port) catch |err| switch (err) { + error.TunnelNotSupported => break :tunnel, + else => |e| return e, + }; + } + + // fall back to using the proxy as a normal http proxy + const conn = try client.connectTcp(proxy.host, proxy.port, proxy.protocol); + errdefer { + conn.closing = true; + client.connection_pool.release(conn); + } + + conn.proxied = true; + return conn; +} + +pub fn async_connect( + client: *Client, + host: []const u8, + port: u16, + protocol: Connection.Protocol, + ctx: *Ctx, + comptime cbk: Cbk, +) !void { + const proxy = switch (protocol) { + .plain => client.http_proxy, + .tls => client.https_proxy, + } orelse return client.async_connectTcp(host, port, protocol, ctx, cbk); + + // Prevent proxying through itself. + if (std.ascii.eqlIgnoreCase(proxy.host, host) and + proxy.port == port and proxy.protocol == protocol) + { + return client.async_connectTcp(host, port, protocol, ctx, cbk); + } + + // TODO: enable async_connectTunnel + // if (proxy.supports_connect) tunnel: { + // return connectTunnel(client, proxy, host, port) catch |err| switch (err) { + // error.TunnelNotSupported => break :tunnel, + // else => |e| return e, + // }; + // } + + // fall back to using the proxy as a normal http proxy + try ctx.push(cbk); + return client.async_connectTcp(proxy.host, proxy.port, proxy.protocol, ctx, onConnectProxy); +} + +pub const RequestError = ConnectTcpError || ConnectErrorPartial || Request.SendError || + std.fmt.ParseIntError || Connection.WriteError || + error{ // TODO: file a zig fmt issue for this bad indentation + UnsupportedUriScheme, + UriMissingHost, + + CertificateBundleLoadFailure, + UnsupportedTransferEncoding, +}; + +pub const RequestOptions = struct { + version: http.Version = .@"HTTP/1.1", + + /// Automatically ignore 100 Continue responses. This assumes you don't + /// care, and will have sent the body before you wait for the response. + /// + /// If this is not the case AND you know the server will send a 100 + /// Continue, set this to false and wait for a response before sending the + /// body. If you wait AND the server does not send a 100 Continue before + /// you finish the request, then the request *will* deadlock. + handle_continue: bool = true, + + /// If false, close the connection after the one request. If true, + /// participate in the client connection pool. + keep_alive: bool = true, + + /// This field specifies whether to automatically follow redirects, and if + /// so, how many redirects to follow before returning an error. + /// + /// This will only follow redirects for repeatable requests (ie. with no + /// payload or the server has acknowledged the payload). + redirect_behavior: Request.RedirectBehavior = @enumFromInt(3), + + /// Externally-owned memory used to store the server's entire HTTP header. + /// `error.HttpHeadersOversize` is returned from read() when a + /// client sends too many bytes of HTTP headers. + server_header_buffer: []u8, + + /// Must be an already acquired connection. + connection: ?*Connection = null, + + /// Standard headers that have default, but overridable, behavior. + headers: Request.Headers = .{}, + /// These headers are kept including when following a redirect to a + /// different domain. + /// Externally-owned; must outlive the Request. + extra_headers: []const http.Header = &.{}, + /// These headers are stripped when following a redirect to a different + /// domain. + /// Externally-owned; must outlive the Request. + privileged_headers: []const http.Header = &.{}, +}; + +const protocol_map = std.StaticStringMap(Connection.Protocol).initComptime(.{ + .{ "http", .plain }, + .{ "ws", .plain }, + .{ "https", .tls }, + .{ "wss", .tls }, +}); + +fn validateUri(uri: Uri, arena: Allocator) !struct { Connection.Protocol, Uri } { + const protocol = protocol_map.get(uri.scheme) orelse return error.UnsupportedUriScheme; + var valid_uri = uri; + // The host is always going to be needed as a raw string for hostname resolution anyway. + valid_uri.host = .{ + .raw = try (uri.host orelse return error.UriMissingHost).toRawMaybeAlloc(arena), + }; + return .{ protocol, valid_uri }; +} + +fn uriPort(uri: Uri, protocol: Connection.Protocol) u16 { + return uri.port orelse switch (protocol) { + .plain => 80, + .tls => 443, + }; +} + +pub fn create( + client: *Client, + method: http.Method, + uri: Uri, + options: RequestOptions, +) RequestError!Request { + if (std.debug.runtime_safety) { + for (options.extra_headers) |header| { + assert(header.name.len != 0); + assert(std.mem.indexOfScalar(u8, header.name, ':') == null); + assert(std.mem.indexOfPosLinear(u8, header.name, 0, "\r\n") == null); + assert(std.mem.indexOfPosLinear(u8, header.value, 0, "\r\n") == null); + } + for (options.privileged_headers) |header| { + assert(header.name.len != 0); + assert(std.mem.indexOfPosLinear(u8, header.name, 0, "\r\n") == null); + assert(std.mem.indexOfPosLinear(u8, header.value, 0, "\r\n") == null); + } + } + + var server_header = std.heap.FixedBufferAllocator.init(options.server_header_buffer); + _, const valid_uri = try validateUri(uri, server_header.allocator()); + + var req: Request = .{ + .uri = valid_uri, + .client = client, + .keep_alive = options.keep_alive, + .method = method, + .version = options.version, + .transfer_encoding = .none, + .redirect_behavior = options.redirect_behavior, + .handle_continue = options.handle_continue, + .response = .{ + .version = undefined, + .status = undefined, + .reason = undefined, + .keep_alive = undefined, + .parser = proto.HeadersParser.init(server_header.buffer[server_header.end_index..]), + }, + .headers = options.headers, + .extra_headers = options.extra_headers, + .privileged_headers = options.privileged_headers, + }; + errdefer req.deinit(); + + return req; +} + +/// Open a connection to the host specified by `uri` and prepare to send a HTTP request. +/// +/// `uri` must remain alive during the entire request. +/// +/// The caller is responsible for calling `deinit()` on the `Request`. +/// This function is threadsafe. +/// +/// Asserts that "\r\n" does not occur in any header name or value. +pub fn open( + client: *Client, + method: http.Method, + uri: Uri, + options: RequestOptions, +) RequestError!Request { + if (std.debug.runtime_safety) { + for (options.extra_headers) |header| { + assert(header.name.len != 0); + assert(std.mem.indexOfScalar(u8, header.name, ':') == null); + assert(std.mem.indexOfPosLinear(u8, header.name, 0, "\r\n") == null); + assert(std.mem.indexOfPosLinear(u8, header.value, 0, "\r\n") == null); + } + for (options.privileged_headers) |header| { + assert(header.name.len != 0); + assert(std.mem.indexOfPosLinear(u8, header.name, 0, "\r\n") == null); + assert(std.mem.indexOfPosLinear(u8, header.value, 0, "\r\n") == null); + } + } + + var server_header = std.heap.FixedBufferAllocator.init(options.server_header_buffer); + const protocol, const valid_uri = try validateUri(uri, server_header.allocator()); + + if (protocol == .tls and @atomicLoad(bool, &client.next_https_rescan_certs, .acquire)) { + if (disable_tls) unreachable; + + client.ca_bundle_mutex.lock(); + defer client.ca_bundle_mutex.unlock(); + + if (client.next_https_rescan_certs) { + client.ca_bundle.rescan(client.allocator) catch + return error.CertificateBundleLoadFailure; + @atomicStore(bool, &client.next_https_rescan_certs, false, .release); + } + } + + const conn = options.connection orelse + try client.connect(valid_uri.host.?.raw, uriPort(valid_uri, protocol), protocol); + + var req: Request = .{ + .uri = valid_uri, + .client = client, + .connection = conn, + .keep_alive = options.keep_alive, + .method = method, + .version = options.version, + .transfer_encoding = .none, + .redirect_behavior = options.redirect_behavior, + .handle_continue = options.handle_continue, + .response = .{ + .version = undefined, + .status = undefined, + .reason = undefined, + .keep_alive = undefined, + .parser = proto.HeadersParser.init(server_header.buffer[server_header.end_index..]), + }, + .headers = options.headers, + .extra_headers = options.extra_headers, + .privileged_headers = options.privileged_headers, + }; + errdefer req.deinit(); + + return req; +} + +pub fn async_open( + client: *Client, + method: http.Method, + uri: Uri, + options: RequestOptions, + ctx: *Ctx, + comptime cbk: Cbk, +) !void { + if (std.debug.runtime_safety) { + for (options.extra_headers) |header| { + assert(header.name.len != 0); + assert(std.mem.indexOfScalar(u8, header.name, ':') == null); + assert(std.mem.indexOfPosLinear(u8, header.name, 0, "\r\n") == null); + assert(std.mem.indexOfPosLinear(u8, header.value, 0, "\r\n") == null); + } + for (options.privileged_headers) |header| { + assert(header.name.len != 0); + assert(std.mem.indexOfPosLinear(u8, header.name, 0, "\r\n") == null); + assert(std.mem.indexOfPosLinear(u8, header.value, 0, "\r\n") == null); + } + } + + var server_header = std.heap.FixedBufferAllocator.init(options.server_header_buffer); + const protocol, const valid_uri = try validateUri(uri, server_header.allocator()); + + if (protocol == .tls and @atomicLoad(bool, &client.next_https_rescan_certs, .acquire)) { + if (disable_tls) unreachable; + + client.ca_bundle_mutex.lock(); + defer client.ca_bundle_mutex.unlock(); + + if (client.next_https_rescan_certs) { + client.ca_bundle.rescan(client.allocator) catch return error.CertificateBundleLoadFailure; + @atomicStore(bool, &client.next_https_rescan_certs, false, .release); + } + } + + // add fields to request + ctx.req.uri = valid_uri; + ctx.req.keep_alive = options.keep_alive; + ctx.req.method = method; + ctx.req.transfer_encoding = .none; + ctx.req.redirect_behavior = options.redirect_behavior; + ctx.req.handle_continue = options.handle_continue; + ctx.req.headers = options.headers; + ctx.req.extra_headers = options.extra_headers; + ctx.req.privileged_headers = options.privileged_headers; + ctx.req.response = .{ + .version = undefined, + .status = undefined, + .reason = undefined, + .keep_alive = undefined, + .parser = proto.HeadersParser.init(server_header.buffer[server_header.end_index..]), + }; + + // we already have the connection, + // set it and call directly the callback + if (options.connection) |conn| { + ctx.req.connection = conn; + return cbk(ctx, {}); + } + + // push callback function + try ctx.push(cbk); + + const host = valid_uri.host orelse return error.UriMissingHost; + const port = uriPort(valid_uri, protocol); + + // add fields to connection + ctx.data.conn.protocol = protocol; + ctx.data.conn.host = try client.allocator.dupe(u8, host.raw); + ctx.data.conn.port = port; + + return client.async_connect(host.raw, port, protocol, ctx, setRequestConnection); +} + +pub const FetchOptions = struct { + server_header_buffer: ?[]u8 = null, + redirect_behavior: ?Request.RedirectBehavior = null, + + /// If the server sends a body, it will be appended to this ArrayList. + /// `max_append_size` provides an upper limit for how much they can grow. + response_storage: ResponseStorage = .ignore, + max_append_size: ?usize = null, + + location: Location, + method: ?http.Method = null, + payload: ?[]const u8 = null, + raw_uri: bool = false, + keep_alive: bool = true, + + /// Standard headers that have default, but overridable, behavior. + headers: Request.Headers = .{}, + /// These headers are kept including when following a redirect to a + /// different domain. + /// Externally-owned; must outlive the Request. + extra_headers: []const http.Header = &.{}, + /// These headers are stripped when following a redirect to a different + /// domain. + /// Externally-owned; must outlive the Request. + privileged_headers: []const http.Header = &.{}, + + pub const Location = union(enum) { + url: []const u8, + uri: Uri, + }; + + pub const ResponseStorage = union(enum) { + ignore, + /// Only the existing capacity will be used. + static: *std.ArrayListUnmanaged(u8), + dynamic: *std.ArrayList(u8), + }; +}; + +pub const FetchResult = struct { + status: http.Status, +}; + +// TODO: enable async_fetch +/// Perform a one-shot HTTP request with the provided options. +/// +/// This function is threadsafe. +pub fn fetch(client: *Client, options: FetchOptions) !FetchResult { + const uri = switch (options.location) { + .url => |u| try Uri.parse(u), + .uri => |u| u, + }; + var server_header_buffer: [16 * 1024]u8 = undefined; + + const method: http.Method = options.method orelse + if (options.payload != null) .POST else .GET; + + var req = try open(client, method, uri, .{ + .server_header_buffer = options.server_header_buffer orelse &server_header_buffer, + .redirect_behavior = options.redirect_behavior orelse + if (options.payload == null) @enumFromInt(3) else .unhandled, + .headers = options.headers, + .extra_headers = options.extra_headers, + .privileged_headers = options.privileged_headers, + .keep_alive = options.keep_alive, + }); + defer req.deinit(); + + if (options.payload) |payload| req.transfer_encoding = .{ .content_length = payload.len }; + + try req.send(); + + if (options.payload) |payload| try req.writeAll(payload); + + try req.finish(); + try req.wait(); + + switch (options.response_storage) { + .ignore => { + // Take advantage of request internals to discard the response body + // and make the connection available for another request. + req.response.skip = true; + assert(try req.transferRead(&.{}) == 0); // No buffer is necessary when skipping. + }, + .dynamic => |list| { + const max_append_size = options.max_append_size orelse 2 * 1024 * 1024; + try req.reader().readAllArrayList(list, max_append_size); + }, + .static => |list| { + const buf = b: { + const buf = list.unusedCapacitySlice(); + if (options.max_append_size) |len| { + if (len < buf.len) break :b buf[0..len]; + } + break :b buf; + }; + list.items.len += try req.reader().readAll(buf); + }, + } + + return .{ + .status = req.response.status, + }; +} + +pub const Cbk = fn (ctx: *Ctx, res: anyerror!void) anyerror!void; + +pub const Ctx = struct { + const Stack = GenericStack(Cbk); + + // temporary Data we need to store on the heap + // because of the callback execution model + const Data = struct { + list: *std.net.AddressList = undefined, + addr_current: usize = undefined, + socket: std.posix.socket_t = undefined, + + // TODO: we could remove this field as it is already set in ctx.req + // but we do not know for now what will be the impact to set those directly + // on the request, especially in case of error/cancellation + conn: *Connection, + }; + + req: *Request = undefined, + + userData: *anyopaque = undefined, + + loop: *Loop, + data: Data, + stack: ?*Stack = null, + err: ?anyerror = null, + + _buffer: ?[]const u8 = null, + _len: ?usize = null, + + _iovecs: []std.posix.iovec = undefined, + + // TLS readvAtLeast + // _off_i: usize = 0, + // _vec_i: usize = 0, + // _tls_len: usize = 0, + + // TLS readv + _vp: VecPut = undefined, + // _tls_read_buf contains the next decrypted buffer + _tls_read_buf: ?[]u8 = undefined, + _tls_read_content_type: tls23.proto.ContentType = undefined, + + // _tls_read_record contains the crypted record + _tls_read_record: ?tls23.record.Record = null, + + // TLS writeAll + _tls_write_bytes: []const u8 = undefined, + _tls_write_index: usize = 0, + _tls_write_buf: [cipher.max_ciphertext_record_len]u8 = undefined, + + pub fn init(loop: *Loop, req: *Request) !Ctx { + const connection = try req.client.allocator.create(Connection); + connection.* = .{ + .stream = undefined, + .tls_client = undefined, + .protocol = undefined, + .host = undefined, + .port = undefined, + }; + return .{ + .req = req, + .loop = loop, + .data = .{ .conn = connection }, + }; + } + + pub fn setErr(self: *Ctx, err: anyerror) void { + self.err = err; + } + + pub fn push(self: *Ctx, comptime func: Stack.Fn) !void { + if (self.stack) |stack| { + return try stack.push(self.alloc(), func); + } + self.stack = try Stack.init(self.alloc(), func); + } + + pub fn pop(self: *Ctx, res: anyerror!void) !void { + if (self.stack) |stack| { + const func = stack.pop(self.alloc(), null); + const ret = @call(.auto, func, .{ self, res }); + if (stack.next == null) { + self.stack = null; + self.alloc().destroy(stack); + } + return ret; + } + } + + pub fn deinit(self: Ctx) void { + if (self.stack) |stack| { + stack.deinit(self.alloc(), null); + } + } + + // not sure about those + + pub fn len(self: Ctx) usize { + if (self._len == null) unreachable; + return self._len.?; + } + + pub fn setLen(self: *Ctx, nb: ?usize) void { + self._len = nb; + } + + pub fn buf(self: Ctx) []const u8 { + if (self._buffer == null) unreachable; + return self._buffer.?; + } + + pub fn setBuf(self: *Ctx, bytes: ?[]const u8) void { + self._buffer = bytes; + } + + // ctx Request aliases + + pub fn alloc(self: Ctx) std.mem.Allocator { + return self.req.client.allocator; + } + + pub fn conn(self: Ctx) *Connection { + return self.req.connection.?; + } + + pub fn stream(self: Ctx) net.Stream { + return self.conn().stream; + } +}; + +// requires ctx.data.conn to be set +fn setRequestConnection(ctx: *Ctx, res: anyerror!void) anyerror!void { + res catch |e| return ctx.pop(e); + + ctx.req.connection = ctx.data.conn; + return ctx.pop({}); +} + +fn onRequestWait(ctx: *Ctx, res: anyerror!void) !void { + res catch |e| { + std.debug.print("error: {any}\n", .{e}); + return e; + }; + std.log.debug("REQUEST WAITED", .{}); + std.log.debug("Status code: {any}", .{ctx.req.response.status}); + const body = try ctx.req.reader().readAllAlloc(ctx.alloc(), 1024 * 1024); + defer ctx.alloc().free(body); + std.log.debug("Body: \n{s}", .{body}); +} + +fn onRequestFinish(ctx: *Ctx, res: anyerror!void) !void { + res catch |err| return err; + std.log.debug("REQUEST FINISHED", .{}); + return ctx.req.async_wait(ctx, onRequestWait); +} + +fn onRequestSend(ctx: *Ctx, res: anyerror!void) !void { + res catch |err| return err; + std.log.debug("REQUEST SENT", .{}); + return ctx.req.async_finish(ctx, onRequestFinish); +} + +pub fn onRequestConnect(ctx: *Ctx, res: anyerror!void) anyerror!void { + res catch |err| return err; + std.log.debug("REQUEST CONNECTED", .{}); + return ctx.req.async_send(ctx, onRequestSend); +} + +test { + const alloc = std.testing.allocator; + + var loop = Loop{}; + + var client = Client{ .allocator = alloc }; + defer client.deinit(); + + var req = Request{ + .client = &client, + }; + defer req.deinit(); + + var ctx = try Ctx.init(&loop, &req); + defer ctx.deinit(); + + var server_header_buffer: [2048]u8 = undefined; + + const url = "http://www.example.com"; + // const url = "http://127.0.0.1:8000/zig"; + try client.async_open( + .GET, + try std.Uri.parse(url), + .{ .server_header_buffer = &server_header_buffer }, + &ctx, + onRequestConnect, + ); +} diff --git a/src/http/async/std/http/Server.zig b/src/http/async/std/http/Server.zig new file mode 100644 index 00000000..38d3f133 --- /dev/null +++ b/src/http/async/std/http/Server.zig @@ -0,0 +1,1148 @@ +//! Blocking HTTP server implementation. +//! Handles a single connection's lifecycle. + +connection: net.Server.Connection, +/// Keeps track of whether the Server is ready to accept a new request on the +/// same connection, and makes invalid API usage cause assertion failures +/// rather than HTTP protocol violations. +state: State, +/// User-provided buffer that must outlive this Server. +/// Used to store the client's entire HTTP header. +read_buffer: []u8, +/// Amount of available data inside read_buffer. +read_buffer_len: usize, +/// Index into `read_buffer` of the first byte of the next HTTP request. +next_request_start: usize, + +pub const State = enum { + /// The connection is available to be used for the first time, or reused. + ready, + /// An error occurred in `receiveHead`. + receiving_head, + /// A Request object has been obtained and from there a Response can be + /// opened. + received_head, + /// The client is uploading something to this Server. + receiving_body, + /// The connection is eligible for another HTTP request, however the client + /// and server did not negotiate a persistent connection. + closing, +}; + +/// Initialize an HTTP server that can respond to multiple requests on the same +/// connection. +/// The returned `Server` is ready for `receiveHead` to be called. +pub fn init(connection: net.Server.Connection, read_buffer: []u8) Server { + return .{ + .connection = connection, + .state = .ready, + .read_buffer = read_buffer, + .read_buffer_len = 0, + .next_request_start = 0, + }; +} + +pub const ReceiveHeadError = error{ + /// Client sent too many bytes of HTTP headers. + /// The HTTP specification suggests to respond with a 431 status code + /// before closing the connection. + HttpHeadersOversize, + /// Client sent headers that did not conform to the HTTP protocol. + HttpHeadersInvalid, + /// A low level I/O error occurred trying to read the headers. + HttpHeadersUnreadable, + /// Partial HTTP request was received but the connection was closed before + /// fully receiving the headers. + HttpRequestTruncated, + /// The client sent 0 bytes of headers before closing the stream. + /// In other words, a keep-alive connection was finally closed. + HttpConnectionClosing, +}; + +/// The header bytes reference the read buffer that Server was initialized with +/// and remain alive until the next call to receiveHead. +pub fn receiveHead(s: *Server) ReceiveHeadError!Request { + assert(s.state == .ready); + s.state = .received_head; + errdefer s.state = .receiving_head; + + // In case of a reused connection, move the next request's bytes to the + // beginning of the buffer. + if (s.next_request_start > 0) { + if (s.read_buffer_len > s.next_request_start) { + rebase(s, 0); + } else { + s.read_buffer_len = 0; + } + } + + var hp: http.HeadParser = .{}; + + if (s.read_buffer_len > 0) { + const bytes = s.read_buffer[0..s.read_buffer_len]; + const end = hp.feed(bytes); + if (hp.state == .finished) + return finishReceivingHead(s, end); + } + + while (true) { + const buf = s.read_buffer[s.read_buffer_len..]; + if (buf.len == 0) + return error.HttpHeadersOversize; + const read_n = s.connection.stream.read(buf) catch + return error.HttpHeadersUnreadable; + if (read_n == 0) { + if (s.read_buffer_len > 0) { + return error.HttpRequestTruncated; + } else { + return error.HttpConnectionClosing; + } + } + s.read_buffer_len += read_n; + const bytes = buf[0..read_n]; + const end = hp.feed(bytes); + if (hp.state == .finished) + return finishReceivingHead(s, s.read_buffer_len - bytes.len + end); + } +} + +fn finishReceivingHead(s: *Server, head_end: usize) ReceiveHeadError!Request { + return .{ + .server = s, + .head_end = head_end, + .head = Request.Head.parse(s.read_buffer[0..head_end]) catch + return error.HttpHeadersInvalid, + .reader_state = undefined, + }; +} + +pub const Request = struct { + server: *Server, + /// Index into Server's read_buffer. + head_end: usize, + head: Head, + reader_state: union { + remaining_content_length: u64, + chunk_parser: http.ChunkParser, + }, + + pub const Compression = union(enum) { + pub const DeflateDecompressor = std.compress.zlib.Decompressor(std.io.AnyReader); + pub const GzipDecompressor = std.compress.gzip.Decompressor(std.io.AnyReader); + pub const ZstdDecompressor = std.compress.zstd.Decompressor(std.io.AnyReader); + + deflate: DeflateDecompressor, + gzip: GzipDecompressor, + zstd: ZstdDecompressor, + none: void, + }; + + pub const Head = struct { + method: http.Method, + target: []const u8, + version: http.Version, + expect: ?[]const u8, + content_type: ?[]const u8, + content_length: ?u64, + transfer_encoding: http.TransferEncoding, + transfer_compression: http.ContentEncoding, + keep_alive: bool, + compression: Compression, + + pub const ParseError = error{ + UnknownHttpMethod, + HttpHeadersInvalid, + HttpHeaderContinuationsUnsupported, + HttpTransferEncodingUnsupported, + HttpConnectionHeaderUnsupported, + InvalidContentLength, + CompressionUnsupported, + MissingFinalNewline, + }; + + pub fn parse(bytes: []const u8) ParseError!Head { + var it = mem.splitSequence(u8, bytes, "\r\n"); + + const first_line = it.next().?; + if (first_line.len < 10) + return error.HttpHeadersInvalid; + + const method_end = mem.indexOfScalar(u8, first_line, ' ') orelse + return error.HttpHeadersInvalid; + if (method_end > 24) return error.HttpHeadersInvalid; + + const method_str = first_line[0..method_end]; + const method: http.Method = @enumFromInt(http.Method.parse(method_str)); + + const version_start = mem.lastIndexOfScalar(u8, first_line, ' ') orelse + return error.HttpHeadersInvalid; + if (version_start == method_end) return error.HttpHeadersInvalid; + + const version_str = first_line[version_start + 1 ..]; + if (version_str.len != 8) return error.HttpHeadersInvalid; + const version: http.Version = switch (int64(version_str[0..8])) { + int64("HTTP/1.0") => .@"HTTP/1.0", + int64("HTTP/1.1") => .@"HTTP/1.1", + else => return error.HttpHeadersInvalid, + }; + + const target = first_line[method_end + 1 .. version_start]; + + var head: Head = .{ + .method = method, + .target = target, + .version = version, + .expect = null, + .content_type = null, + .content_length = null, + .transfer_encoding = .none, + .transfer_compression = .identity, + .keep_alive = switch (version) { + .@"HTTP/1.0" => false, + .@"HTTP/1.1" => true, + }, + .compression = .none, + }; + + while (it.next()) |line| { + if (line.len == 0) return head; + switch (line[0]) { + ' ', '\t' => return error.HttpHeaderContinuationsUnsupported, + else => {}, + } + + var line_it = mem.splitScalar(u8, line, ':'); + const header_name = line_it.next().?; + const header_value = mem.trim(u8, line_it.rest(), " \t"); + if (header_name.len == 0) return error.HttpHeadersInvalid; + + if (std.ascii.eqlIgnoreCase(header_name, "connection")) { + head.keep_alive = !std.ascii.eqlIgnoreCase(header_value, "close"); + } else if (std.ascii.eqlIgnoreCase(header_name, "expect")) { + head.expect = header_value; + } else if (std.ascii.eqlIgnoreCase(header_name, "content-type")) { + head.content_type = header_value; + } else if (std.ascii.eqlIgnoreCase(header_name, "content-length")) { + if (head.content_length != null) return error.HttpHeadersInvalid; + head.content_length = std.fmt.parseInt(u64, header_value, 10) catch + return error.InvalidContentLength; + } else if (std.ascii.eqlIgnoreCase(header_name, "content-encoding")) { + if (head.transfer_compression != .identity) return error.HttpHeadersInvalid; + + const trimmed = mem.trim(u8, header_value, " "); + + if (std.meta.stringToEnum(http.ContentEncoding, trimmed)) |ce| { + head.transfer_compression = ce; + } else { + return error.HttpTransferEncodingUnsupported; + } + } else if (std.ascii.eqlIgnoreCase(header_name, "transfer-encoding")) { + // Transfer-Encoding: second, first + // Transfer-Encoding: deflate, chunked + var iter = mem.splitBackwardsScalar(u8, header_value, ','); + + const first = iter.first(); + const trimmed_first = mem.trim(u8, first, " "); + + var next: ?[]const u8 = first; + if (std.meta.stringToEnum(http.TransferEncoding, trimmed_first)) |transfer| { + if (head.transfer_encoding != .none) + return error.HttpHeadersInvalid; // we already have a transfer encoding + head.transfer_encoding = transfer; + + next = iter.next(); + } + + if (next) |second| { + const trimmed_second = mem.trim(u8, second, " "); + + if (std.meta.stringToEnum(http.ContentEncoding, trimmed_second)) |transfer| { + if (head.transfer_compression != .identity) + return error.HttpHeadersInvalid; // double compression is not supported + head.transfer_compression = transfer; + } else { + return error.HttpTransferEncodingUnsupported; + } + } + + if (iter.next()) |_| return error.HttpTransferEncodingUnsupported; + } + } + return error.MissingFinalNewline; + } + + test parse { + const request_bytes = "GET /hi HTTP/1.0\r\n" ++ + "content-tYpe: text/plain\r\n" ++ + "content-Length:10\r\n" ++ + "expeCt: 100-continue \r\n" ++ + "TRansfer-encoding:\tdeflate, chunked \r\n" ++ + "connectioN:\t keep-alive \r\n\r\n"; + + const req = try parse(request_bytes); + + try testing.expectEqual(.GET, req.method); + try testing.expectEqual(.@"HTTP/1.0", req.version); + try testing.expectEqualStrings("/hi", req.target); + + try testing.expectEqualStrings("text/plain", req.content_type.?); + try testing.expectEqualStrings("100-continue", req.expect.?); + + try testing.expectEqual(true, req.keep_alive); + try testing.expectEqual(10, req.content_length.?); + try testing.expectEqual(.chunked, req.transfer_encoding); + try testing.expectEqual(.deflate, req.transfer_compression); + } + + inline fn int64(array: *const [8]u8) u64 { + return @bitCast(array.*); + } + }; + + pub fn iterateHeaders(r: *Request) http.HeaderIterator { + return http.HeaderIterator.init(r.server.read_buffer[0..r.head_end]); + } + + test iterateHeaders { + const request_bytes = "GET /hi HTTP/1.0\r\n" ++ + "content-tYpe: text/plain\r\n" ++ + "content-Length:10\r\n" ++ + "expeCt: 100-continue \r\n" ++ + "TRansfer-encoding:\tdeflate, chunked \r\n" ++ + "connectioN:\t keep-alive \r\n\r\n"; + + var read_buffer: [500]u8 = undefined; + @memcpy(read_buffer[0..request_bytes.len], request_bytes); + + var server: Server = .{ + .connection = undefined, + .state = .ready, + .read_buffer = &read_buffer, + .read_buffer_len = request_bytes.len, + .next_request_start = 0, + }; + + var request: Request = .{ + .server = &server, + .head_end = request_bytes.len, + .head = undefined, + .reader_state = undefined, + }; + + var it = request.iterateHeaders(); + { + const header = it.next().?; + try testing.expectEqualStrings("content-tYpe", header.name); + try testing.expectEqualStrings("text/plain", header.value); + try testing.expect(!it.is_trailer); + } + { + const header = it.next().?; + try testing.expectEqualStrings("content-Length", header.name); + try testing.expectEqualStrings("10", header.value); + try testing.expect(!it.is_trailer); + } + { + const header = it.next().?; + try testing.expectEqualStrings("expeCt", header.name); + try testing.expectEqualStrings("100-continue", header.value); + try testing.expect(!it.is_trailer); + } + { + const header = it.next().?; + try testing.expectEqualStrings("TRansfer-encoding", header.name); + try testing.expectEqualStrings("deflate, chunked", header.value); + try testing.expect(!it.is_trailer); + } + { + const header = it.next().?; + try testing.expectEqualStrings("connectioN", header.name); + try testing.expectEqualStrings("keep-alive", header.value); + try testing.expect(!it.is_trailer); + } + try testing.expectEqual(null, it.next()); + } + + pub const RespondOptions = struct { + version: http.Version = .@"HTTP/1.1", + status: http.Status = .ok, + reason: ?[]const u8 = null, + keep_alive: bool = true, + extra_headers: []const http.Header = &.{}, + transfer_encoding: ?http.TransferEncoding = null, + }; + + /// Send an entire HTTP response to the client, including headers and body. + /// + /// Automatically handles HEAD requests by omitting the body. + /// + /// Unless `transfer_encoding` is specified, uses the "content-length" + /// header. + /// + /// If the request contains a body and the connection is to be reused, + /// discards the request body, leaving the Server in the `ready` state. If + /// this discarding fails, the connection is marked as not to be reused and + /// no error is surfaced. + /// + /// Asserts status is not `continue`. + /// Asserts there are at most 25 extra_headers. + /// Asserts that "\r\n" does not occur in any header name or value. + pub fn respond( + request: *Request, + content: []const u8, + options: RespondOptions, + ) Response.WriteError!void { + const max_extra_headers = 25; + assert(options.status != .@"continue"); + assert(options.extra_headers.len <= max_extra_headers); + if (std.debug.runtime_safety) { + for (options.extra_headers) |header| { + assert(header.name.len != 0); + assert(std.mem.indexOfScalar(u8, header.name, ':') == null); + assert(std.mem.indexOfPosLinear(u8, header.name, 0, "\r\n") == null); + assert(std.mem.indexOfPosLinear(u8, header.value, 0, "\r\n") == null); + } + } + + const transfer_encoding_none = (options.transfer_encoding orelse .chunked) == .none; + const server_keep_alive = !transfer_encoding_none and options.keep_alive; + const keep_alive = request.discardBody(server_keep_alive); + + const phrase = options.reason orelse options.status.phrase() orelse ""; + + var first_buffer: [500]u8 = undefined; + var h = std.ArrayListUnmanaged(u8).initBuffer(&first_buffer); + if (request.head.expect != null) { + // reader() and hence discardBody() above sets expect to null if it + // is handled. So the fact that it is not null here means unhandled. + h.appendSliceAssumeCapacity("HTTP/1.1 417 Expectation Failed\r\n"); + if (!keep_alive) h.appendSliceAssumeCapacity("connection: close\r\n"); + h.appendSliceAssumeCapacity("content-length: 0\r\n\r\n"); + try request.server.connection.stream.writeAll(h.items); + return; + } + h.fixedWriter().print("{s} {d} {s}\r\n", .{ + @tagName(options.version), @intFromEnum(options.status), phrase, + }) catch unreachable; + + switch (options.version) { + .@"HTTP/1.0" => if (keep_alive) h.appendSliceAssumeCapacity("connection: keep-alive\r\n"), + .@"HTTP/1.1" => if (!keep_alive) h.appendSliceAssumeCapacity("connection: close\r\n"), + } + + if (options.transfer_encoding) |transfer_encoding| switch (transfer_encoding) { + .none => {}, + .chunked => h.appendSliceAssumeCapacity("transfer-encoding: chunked\r\n"), + } else { + h.fixedWriter().print("content-length: {d}\r\n", .{content.len}) catch unreachable; + } + + var chunk_header_buffer: [18]u8 = undefined; + var iovecs: [max_extra_headers * 4 + 3]std.posix.iovec_const = undefined; + var iovecs_len: usize = 0; + + iovecs[iovecs_len] = .{ + .base = h.items.ptr, + .len = h.items.len, + }; + iovecs_len += 1; + + for (options.extra_headers) |header| { + iovecs[iovecs_len] = .{ + .base = header.name.ptr, + .len = header.name.len, + }; + iovecs_len += 1; + + iovecs[iovecs_len] = .{ + .base = ": ", + .len = 2, + }; + iovecs_len += 1; + + if (header.value.len != 0) { + iovecs[iovecs_len] = .{ + .base = header.value.ptr, + .len = header.value.len, + }; + iovecs_len += 1; + } + + iovecs[iovecs_len] = .{ + .base = "\r\n", + .len = 2, + }; + iovecs_len += 1; + } + + iovecs[iovecs_len] = .{ + .base = "\r\n", + .len = 2, + }; + iovecs_len += 1; + + if (request.head.method != .HEAD) { + const is_chunked = (options.transfer_encoding orelse .none) == .chunked; + if (is_chunked) { + if (content.len > 0) { + const chunk_header = std.fmt.bufPrint( + &chunk_header_buffer, + "{x}\r\n", + .{content.len}, + ) catch unreachable; + + iovecs[iovecs_len] = .{ + .base = chunk_header.ptr, + .len = chunk_header.len, + }; + iovecs_len += 1; + + iovecs[iovecs_len] = .{ + .base = content.ptr, + .len = content.len, + }; + iovecs_len += 1; + + iovecs[iovecs_len] = .{ + .base = "\r\n", + .len = 2, + }; + iovecs_len += 1; + } + + iovecs[iovecs_len] = .{ + .base = "0\r\n\r\n", + .len = 5, + }; + iovecs_len += 1; + } else if (content.len > 0) { + iovecs[iovecs_len] = .{ + .base = content.ptr, + .len = content.len, + }; + iovecs_len += 1; + } + } + + try request.server.connection.stream.writevAll(iovecs[0..iovecs_len]); + } + + pub const RespondStreamingOptions = struct { + /// An externally managed slice of memory used to batch bytes before + /// sending. `respondStreaming` asserts this is large enough to store + /// the full HTTP response head. + /// + /// Must outlive the returned Response. + send_buffer: []u8, + /// If provided, the response will use the content-length header; + /// otherwise it will use transfer-encoding: chunked. + content_length: ?u64 = null, + /// Options that are shared with the `respond` method. + respond_options: RespondOptions = .{}, + }; + + /// The header is buffered but not sent until Response.flush is called. + /// + /// If the request contains a body and the connection is to be reused, + /// discards the request body, leaving the Server in the `ready` state. If + /// this discarding fails, the connection is marked as not to be reused and + /// no error is surfaced. + /// + /// HEAD requests are handled transparently by setting a flag on the + /// returned Response to omit the body. However it may be worth noticing + /// that flag and skipping any expensive work that would otherwise need to + /// be done to satisfy the request. + /// + /// Asserts `send_buffer` is large enough to store the entire response header. + /// Asserts status is not `continue`. + pub fn respondStreaming(request: *Request, options: RespondStreamingOptions) Response { + const o = options.respond_options; + assert(o.status != .@"continue"); + const transfer_encoding_none = (o.transfer_encoding orelse .chunked) == .none; + const server_keep_alive = !transfer_encoding_none and o.keep_alive; + const keep_alive = request.discardBody(server_keep_alive); + const phrase = o.reason orelse o.status.phrase() orelse ""; + + var h = std.ArrayListUnmanaged(u8).initBuffer(options.send_buffer); + + const elide_body = if (request.head.expect != null) eb: { + // reader() and hence discardBody() above sets expect to null if it + // is handled. So the fact that it is not null here means unhandled. + h.appendSliceAssumeCapacity("HTTP/1.1 417 Expectation Failed\r\n"); + if (!keep_alive) h.appendSliceAssumeCapacity("connection: close\r\n"); + h.appendSliceAssumeCapacity("content-length: 0\r\n\r\n"); + break :eb true; + } else eb: { + h.fixedWriter().print("{s} {d} {s}\r\n", .{ + @tagName(o.version), @intFromEnum(o.status), phrase, + }) catch unreachable; + + switch (o.version) { + .@"HTTP/1.0" => if (keep_alive) h.appendSliceAssumeCapacity("connection: keep-alive\r\n"), + .@"HTTP/1.1" => if (!keep_alive) h.appendSliceAssumeCapacity("connection: close\r\n"), + } + + if (o.transfer_encoding) |transfer_encoding| switch (transfer_encoding) { + .chunked => h.appendSliceAssumeCapacity("transfer-encoding: chunked\r\n"), + .none => {}, + } else if (options.content_length) |len| { + h.fixedWriter().print("content-length: {d}\r\n", .{len}) catch unreachable; + } else { + h.appendSliceAssumeCapacity("transfer-encoding: chunked\r\n"); + } + + for (o.extra_headers) |header| { + assert(header.name.len != 0); + h.appendSliceAssumeCapacity(header.name); + h.appendSliceAssumeCapacity(": "); + h.appendSliceAssumeCapacity(header.value); + h.appendSliceAssumeCapacity("\r\n"); + } + + h.appendSliceAssumeCapacity("\r\n"); + break :eb request.head.method == .HEAD; + }; + + return .{ + .stream = request.server.connection.stream, + .send_buffer = options.send_buffer, + .send_buffer_start = 0, + .send_buffer_end = h.items.len, + .transfer_encoding = if (o.transfer_encoding) |te| switch (te) { + .chunked => .chunked, + .none => .none, + } else if (options.content_length) |len| .{ + .content_length = len, + } else .chunked, + .elide_body = elide_body, + .chunk_len = 0, + }; + } + + pub const ReadError = net.Stream.ReadError || error{ + HttpChunkInvalid, + HttpHeadersOversize, + }; + + fn read_cl(context: *const anyopaque, buffer: []u8) ReadError!usize { + const request: *Request = @constCast(@alignCast(@ptrCast(context))); + const s = request.server; + + const remaining_content_length = &request.reader_state.remaining_content_length; + if (remaining_content_length.* == 0) { + s.state = .ready; + return 0; + } + assert(s.state == .receiving_body); + const available = try fill(s, request.head_end); + const len = @min(remaining_content_length.*, available.len, buffer.len); + @memcpy(buffer[0..len], available[0..len]); + remaining_content_length.* -= len; + s.next_request_start += len; + if (remaining_content_length.* == 0) + s.state = .ready; + return len; + } + + fn fill(s: *Server, head_end: usize) ReadError![]u8 { + const available = s.read_buffer[s.next_request_start..s.read_buffer_len]; + if (available.len > 0) return available; + s.next_request_start = head_end; + s.read_buffer_len = head_end + try s.connection.stream.read(s.read_buffer[head_end..]); + return s.read_buffer[head_end..s.read_buffer_len]; + } + + fn read_chunked(context: *const anyopaque, buffer: []u8) ReadError!usize { + const request: *Request = @constCast(@alignCast(@ptrCast(context))); + const s = request.server; + + const cp = &request.reader_state.chunk_parser; + const head_end = request.head_end; + + // Protect against returning 0 before the end of stream. + var out_end: usize = 0; + while (out_end == 0) { + switch (cp.state) { + .invalid => return 0, + .data => { + assert(s.state == .receiving_body); + const available = try fill(s, head_end); + const len = @min(cp.chunk_len, available.len, buffer.len); + @memcpy(buffer[0..len], available[0..len]); + cp.chunk_len -= len; + if (cp.chunk_len == 0) + cp.state = .data_suffix; + out_end += len; + s.next_request_start += len; + continue; + }, + else => { + assert(s.state == .receiving_body); + const available = try fill(s, head_end); + const n = cp.feed(available); + switch (cp.state) { + .invalid => return error.HttpChunkInvalid, + .data => { + if (cp.chunk_len == 0) { + // The next bytes in the stream are trailers, + // or \r\n to indicate end of chunked body. + // + // This function must append the trailers at + // head_end so that headers and trailers are + // together. + // + // Since returning 0 would indicate end of + // stream, this function must read all the + // trailers before returning. + if (s.next_request_start > head_end) rebase(s, head_end); + var hp: http.HeadParser = .{}; + { + const bytes = s.read_buffer[head_end..s.read_buffer_len]; + const end = hp.feed(bytes); + if (hp.state == .finished) { + cp.state = .invalid; + s.state = .ready; + s.next_request_start = s.read_buffer_len - bytes.len + end; + return out_end; + } + } + while (true) { + const buf = s.read_buffer[s.read_buffer_len..]; + if (buf.len == 0) + return error.HttpHeadersOversize; + const read_n = try s.connection.stream.read(buf); + s.read_buffer_len += read_n; + const bytes = buf[0..read_n]; + const end = hp.feed(bytes); + if (hp.state == .finished) { + cp.state = .invalid; + s.state = .ready; + s.next_request_start = s.read_buffer_len - bytes.len + end; + return out_end; + } + } + } + const data = available[n..]; + const len = @min(cp.chunk_len, data.len, buffer.len); + @memcpy(buffer[0..len], data[0..len]); + cp.chunk_len -= len; + if (cp.chunk_len == 0) + cp.state = .data_suffix; + out_end += len; + s.next_request_start += n + len; + continue; + }, + else => continue, + } + }, + } + } + return out_end; + } + + pub const ReaderError = Response.WriteError || error{ + /// The client sent an expect HTTP header value other than + /// "100-continue". + HttpExpectationFailed, + }; + + /// In the case that the request contains "expect: 100-continue", this + /// function writes the continuation header, which means it can fail with a + /// write error. After sending the continuation header, it sets the + /// request's expect field to `null`. + /// + /// Asserts that this function is only called once. + pub fn reader(request: *Request) ReaderError!std.io.AnyReader { + const s = request.server; + assert(s.state == .received_head); + s.state = .receiving_body; + s.next_request_start = request.head_end; + + if (request.head.expect) |expect| { + if (mem.eql(u8, expect, "100-continue")) { + try request.server.connection.stream.writeAll("HTTP/1.1 100 Continue\r\n\r\n"); + request.head.expect = null; + } else { + return error.HttpExpectationFailed; + } + } + + switch (request.head.transfer_encoding) { + .chunked => { + request.reader_state = .{ .chunk_parser = http.ChunkParser.init }; + return .{ + .readFn = read_chunked, + .context = request, + }; + }, + .none => { + request.reader_state = .{ + .remaining_content_length = request.head.content_length orelse 0, + }; + return .{ + .readFn = read_cl, + .context = request, + }; + }, + } + } + + /// Returns whether the connection should remain persistent. + /// If it would fail, it instead sets the Server state to `receiving_body` + /// and returns false. + fn discardBody(request: *Request, keep_alive: bool) bool { + // Prepare to receive another request on the same connection. + // There are two factors to consider: + // * Any body the client sent must be discarded. + // * The Server's read_buffer may already have some bytes in it from + // whatever came after the head, which may be the next HTTP request + // or the request body. + // If the connection won't be kept alive, then none of this matters + // because the connection will be severed after the response is sent. + const s = request.server; + if (keep_alive and request.head.keep_alive) switch (s.state) { + .received_head => { + const r = request.reader() catch return false; + _ = r.discard() catch return false; + assert(s.state == .ready); + return true; + }, + .receiving_body, .ready => return true, + else => unreachable, + }; + + // Avoid clobbering the state in case a reading stream already exists. + switch (s.state) { + .received_head => s.state = .closing, + else => {}, + } + return false; + } +}; + +pub const Response = struct { + stream: net.Stream, + send_buffer: []u8, + /// Index of the first byte in `send_buffer`. + /// This is 0 unless a short write happens in `write`. + send_buffer_start: usize, + /// Index of the last byte + 1 in `send_buffer`. + send_buffer_end: usize, + /// `null` means transfer-encoding: chunked. + /// As a debugging utility, counts down to zero as bytes are written. + transfer_encoding: TransferEncoding, + elide_body: bool, + /// Indicates how much of the end of the `send_buffer` corresponds to a + /// chunk. This amount of data will be wrapped by an HTTP chunk header. + chunk_len: usize, + + pub const TransferEncoding = union(enum) { + /// End of connection signals the end of the stream. + none, + /// As a debugging utility, counts down to zero as bytes are written. + content_length: u64, + /// Each chunk is wrapped in a header and trailer. + chunked, + }; + + pub const WriteError = net.Stream.WriteError; + + /// When using content-length, asserts that the amount of data sent matches + /// the value sent in the header, then calls `flush`. + /// Otherwise, transfer-encoding: chunked is being used, and it writes the + /// end-of-stream message, then flushes the stream to the system. + /// Respects the value of `elide_body` to omit all data after the headers. + pub fn end(r: *Response) WriteError!void { + switch (r.transfer_encoding) { + .content_length => |len| { + assert(len == 0); // Trips when end() called before all bytes written. + try flush_cl(r); + }, + .none => { + try flush_cl(r); + }, + .chunked => { + try flush_chunked(r, &.{}); + }, + } + r.* = undefined; + } + + pub const EndChunkedOptions = struct { + trailers: []const http.Header = &.{}, + }; + + /// Asserts that the Response is using transfer-encoding: chunked. + /// Writes the end-of-stream message and any optional trailers, then + /// flushes the stream to the system. + /// Respects the value of `elide_body` to omit all data after the headers. + /// Asserts there are at most 25 trailers. + pub fn endChunked(r: *Response, options: EndChunkedOptions) WriteError!void { + assert(r.transfer_encoding == .chunked); + try flush_chunked(r, options.trailers); + r.* = undefined; + } + + /// If using content-length, asserts that writing these bytes to the client + /// would not exceed the content-length value sent in the HTTP header. + /// May return 0, which does not indicate end of stream. The caller decides + /// when the end of stream occurs by calling `end`. + pub fn write(r: *Response, bytes: []const u8) WriteError!usize { + switch (r.transfer_encoding) { + .content_length, .none => return write_cl(r, bytes), + .chunked => return write_chunked(r, bytes), + } + } + + fn write_cl(context: *const anyopaque, bytes: []const u8) WriteError!usize { + const r: *Response = @constCast(@alignCast(@ptrCast(context))); + + var trash: u64 = std.math.maxInt(u64); + const len = switch (r.transfer_encoding) { + .content_length => |*len| len, + else => &trash, + }; + + if (r.elide_body) { + len.* -= bytes.len; + return bytes.len; + } + + if (bytes.len + r.send_buffer_end > r.send_buffer.len) { + const send_buffer_len = r.send_buffer_end - r.send_buffer_start; + var iovecs: [2]std.posix.iovec_const = .{ + .{ + .base = r.send_buffer.ptr + r.send_buffer_start, + .len = send_buffer_len, + }, + .{ + .base = bytes.ptr, + .len = bytes.len, + }, + }; + const n = try r.stream.writev(&iovecs); + + if (n >= send_buffer_len) { + // It was enough to reset the buffer. + r.send_buffer_start = 0; + r.send_buffer_end = 0; + const bytes_n = n - send_buffer_len; + len.* -= bytes_n; + return bytes_n; + } + + // It didn't even make it through the existing buffer, let + // alone the new bytes provided. + r.send_buffer_start += n; + return 0; + } + + // All bytes can be stored in the remaining space of the buffer. + @memcpy(r.send_buffer[r.send_buffer_end..][0..bytes.len], bytes); + r.send_buffer_end += bytes.len; + len.* -= bytes.len; + return bytes.len; + } + + fn write_chunked(context: *const anyopaque, bytes: []const u8) WriteError!usize { + const r: *Response = @constCast(@alignCast(@ptrCast(context))); + assert(r.transfer_encoding == .chunked); + + if (r.elide_body) + return bytes.len; + + if (bytes.len + r.send_buffer_end > r.send_buffer.len) { + const send_buffer_len = r.send_buffer_end - r.send_buffer_start; + const chunk_len = r.chunk_len + bytes.len; + var header_buf: [18]u8 = undefined; + const chunk_header = std.fmt.bufPrint(&header_buf, "{x}\r\n", .{chunk_len}) catch unreachable; + + var iovecs: [5]std.posix.iovec_const = .{ + .{ + .base = r.send_buffer.ptr + r.send_buffer_start, + .len = send_buffer_len - r.chunk_len, + }, + .{ + .base = chunk_header.ptr, + .len = chunk_header.len, + }, + .{ + .base = r.send_buffer.ptr + r.send_buffer_end - r.chunk_len, + .len = r.chunk_len, + }, + .{ + .base = bytes.ptr, + .len = bytes.len, + }, + .{ + .base = "\r\n", + .len = 2, + }, + }; + // TODO make this writev instead of writevAll, which involves + // complicating the logic of this function. + try r.stream.writevAll(&iovecs); + r.send_buffer_start = 0; + r.send_buffer_end = 0; + r.chunk_len = 0; + return bytes.len; + } + + // All bytes can be stored in the remaining space of the buffer. + @memcpy(r.send_buffer[r.send_buffer_end..][0..bytes.len], bytes); + r.send_buffer_end += bytes.len; + r.chunk_len += bytes.len; + return bytes.len; + } + + /// If using content-length, asserts that writing these bytes to the client + /// would not exceed the content-length value sent in the HTTP header. + pub fn writeAll(r: *Response, bytes: []const u8) WriteError!void { + var index: usize = 0; + while (index < bytes.len) { + index += try write(r, bytes[index..]); + } + } + + /// Sends all buffered data to the client. + /// This is redundant after calling `end`. + /// Respects the value of `elide_body` to omit all data after the headers. + pub fn flush(r: *Response) WriteError!void { + switch (r.transfer_encoding) { + .none, .content_length => return flush_cl(r), + .chunked => return flush_chunked(r, null), + } + } + + fn flush_cl(r: *Response) WriteError!void { + try r.stream.writeAll(r.send_buffer[r.send_buffer_start..r.send_buffer_end]); + r.send_buffer_start = 0; + r.send_buffer_end = 0; + } + + fn flush_chunked(r: *Response, end_trailers: ?[]const http.Header) WriteError!void { + const max_trailers = 25; + if (end_trailers) |trailers| assert(trailers.len <= max_trailers); + assert(r.transfer_encoding == .chunked); + + const http_headers = r.send_buffer[r.send_buffer_start .. r.send_buffer_end - r.chunk_len]; + + if (r.elide_body) { + try r.stream.writeAll(http_headers); + r.send_buffer_start = 0; + r.send_buffer_end = 0; + r.chunk_len = 0; + return; + } + + var header_buf: [18]u8 = undefined; + const chunk_header = std.fmt.bufPrint(&header_buf, "{x}\r\n", .{r.chunk_len}) catch unreachable; + + var iovecs: [max_trailers * 4 + 5]std.posix.iovec_const = undefined; + var iovecs_len: usize = 0; + + iovecs[iovecs_len] = .{ + .base = http_headers.ptr, + .len = http_headers.len, + }; + iovecs_len += 1; + + if (r.chunk_len > 0) { + iovecs[iovecs_len] = .{ + .base = chunk_header.ptr, + .len = chunk_header.len, + }; + iovecs_len += 1; + + iovecs[iovecs_len] = .{ + .base = r.send_buffer.ptr + r.send_buffer_end - r.chunk_len, + .len = r.chunk_len, + }; + iovecs_len += 1; + + iovecs[iovecs_len] = .{ + .base = "\r\n", + .len = 2, + }; + iovecs_len += 1; + } + + if (end_trailers) |trailers| { + iovecs[iovecs_len] = .{ + .base = "0\r\n", + .len = 3, + }; + iovecs_len += 1; + + for (trailers) |trailer| { + iovecs[iovecs_len] = .{ + .base = trailer.name.ptr, + .len = trailer.name.len, + }; + iovecs_len += 1; + + iovecs[iovecs_len] = .{ + .base = ": ", + .len = 2, + }; + iovecs_len += 1; + + if (trailer.value.len != 0) { + iovecs[iovecs_len] = .{ + .base = trailer.value.ptr, + .len = trailer.value.len, + }; + iovecs_len += 1; + } + + iovecs[iovecs_len] = .{ + .base = "\r\n", + .len = 2, + }; + iovecs_len += 1; + } + + iovecs[iovecs_len] = .{ + .base = "\r\n", + .len = 2, + }; + iovecs_len += 1; + } + + try r.stream.writevAll(iovecs[0..iovecs_len]); + r.send_buffer_start = 0; + r.send_buffer_end = 0; + r.chunk_len = 0; + } + + pub fn writer(r: *Response) std.io.AnyWriter { + return .{ + .writeFn = switch (r.transfer_encoding) { + .none, .content_length => write_cl, + .chunked => write_chunked, + }, + .context = r, + }; + } +}; + +fn rebase(s: *Server, index: usize) void { + const leftover = s.read_buffer[s.next_request_start..s.read_buffer_len]; + const dest = s.read_buffer[index..][0..leftover.len]; + if (leftover.len <= s.next_request_start - index) { + @memcpy(dest, leftover); + } else { + mem.copyBackwards(u8, dest, leftover); + } + s.read_buffer_len = index + leftover.len; +} + +const std = @import("std"); +const http = std.http; +const mem = std.mem; +const net = std.net; +const Uri = std.Uri; +const assert = std.debug.assert; +const testing = std.testing; + +const Server = @This(); diff --git a/src/http/async/std/http/protocol.zig b/src/http/async/std/http/protocol.zig new file mode 100644 index 00000000..389e1e4f --- /dev/null +++ b/src/http/async/std/http/protocol.zig @@ -0,0 +1,447 @@ +const std = @import("std"); +const builtin = @import("builtin"); +const testing = std.testing; +const mem = std.mem; + +const assert = std.debug.assert; +const use_vectors = builtin.zig_backend != .stage2_x86_64; + +pub const State = enum { + invalid, + + // Begin header and trailer parsing states. + + start, + seen_n, + seen_r, + seen_rn, + seen_rnr, + finished, + + // Begin transfer-encoding: chunked parsing states. + + chunk_head_size, + chunk_head_ext, + chunk_head_r, + chunk_data, + chunk_data_suffix, + chunk_data_suffix_r, + + /// Returns true if the parser is in a content state (ie. not waiting for more headers). + pub fn isContent(self: State) bool { + return switch (self) { + .invalid, .start, .seen_n, .seen_r, .seen_rn, .seen_rnr => false, + .finished, .chunk_head_size, .chunk_head_ext, .chunk_head_r, .chunk_data, .chunk_data_suffix, .chunk_data_suffix_r => true, + }; + } +}; + +pub const HeadersParser = struct { + state: State = .start, + /// A fixed buffer of len `max_header_bytes`. + /// Pointers into this buffer are not stable until after a message is complete. + header_bytes_buffer: []u8, + header_bytes_len: u32, + next_chunk_length: u64, + /// `false`: headers. `true`: trailers. + done: bool, + + /// Initializes the parser with a provided buffer `buf`. + pub fn init(buf: []u8) HeadersParser { + return .{ + .header_bytes_buffer = buf, + .header_bytes_len = 0, + .done = false, + .next_chunk_length = 0, + }; + } + + /// Reinitialize the parser. + /// Asserts the parser is in the "done" state. + pub fn reset(hp: *HeadersParser) void { + assert(hp.done); + hp.* = .{ + .state = .start, + .header_bytes_buffer = hp.header_bytes_buffer, + .header_bytes_len = 0, + .done = false, + .next_chunk_length = 0, + }; + } + + pub fn get(hp: HeadersParser) []u8 { + return hp.header_bytes_buffer[0..hp.header_bytes_len]; + } + + pub fn findHeadersEnd(r: *HeadersParser, bytes: []const u8) u32 { + var hp: std.http.HeadParser = .{ + .state = switch (r.state) { + .start => .start, + .seen_n => .seen_n, + .seen_r => .seen_r, + .seen_rn => .seen_rn, + .seen_rnr => .seen_rnr, + .finished => .finished, + else => unreachable, + }, + }; + const result = hp.feed(bytes); + r.state = switch (hp.state) { + .start => .start, + .seen_n => .seen_n, + .seen_r => .seen_r, + .seen_rn => .seen_rn, + .seen_rnr => .seen_rnr, + .finished => .finished, + }; + return @intCast(result); + } + + pub fn findChunkedLen(r: *HeadersParser, bytes: []const u8) u32 { + var cp: std.http.ChunkParser = .{ + .state = switch (r.state) { + .chunk_head_size => .head_size, + .chunk_head_ext => .head_ext, + .chunk_head_r => .head_r, + .chunk_data => .data, + .chunk_data_suffix => .data_suffix, + .chunk_data_suffix_r => .data_suffix_r, + .invalid => .invalid, + else => unreachable, + }, + .chunk_len = r.next_chunk_length, + }; + const result = cp.feed(bytes); + r.state = switch (cp.state) { + .head_size => .chunk_head_size, + .head_ext => .chunk_head_ext, + .head_r => .chunk_head_r, + .data => .chunk_data, + .data_suffix => .chunk_data_suffix, + .data_suffix_r => .chunk_data_suffix_r, + .invalid => .invalid, + }; + r.next_chunk_length = cp.chunk_len; + return @intCast(result); + } + + /// Returns whether or not the parser has finished parsing a complete + /// message. A message is only complete after the entire body has been read + /// and any trailing headers have been parsed. + pub fn isComplete(r: *HeadersParser) bool { + return r.done and r.state == .finished; + } + + pub const CheckCompleteHeadError = error{HttpHeadersOversize}; + + /// Pushes `in` into the parser. Returns the number of bytes consumed by + /// the header. Any header bytes are appended to `header_bytes_buffer`. + pub fn checkCompleteHead(hp: *HeadersParser, in: []const u8) CheckCompleteHeadError!u32 { + if (hp.state.isContent()) return 0; + + const i = hp.findHeadersEnd(in); + const data = in[0..i]; + if (hp.header_bytes_len + data.len > hp.header_bytes_buffer.len) + return error.HttpHeadersOversize; + + @memcpy(hp.header_bytes_buffer[hp.header_bytes_len..][0..data.len], data); + hp.header_bytes_len += @intCast(data.len); + + return i; + } + + pub const ReadError = error{ + HttpChunkInvalid, + }; + + /// Reads the body of the message into `buffer`. Returns the number of + /// bytes placed in the buffer. + /// + /// If `skip` is true, the buffer will be unused and the body will be skipped. + /// + /// See `std.http.Client.Connection for an example of `conn`. + pub fn read(r: *HeadersParser, conn: anytype, buffer: []u8, skip: bool) !usize { + assert(r.state.isContent()); + if (r.done) return 0; + + var out_index: usize = 0; + while (true) { + switch (r.state) { + .invalid, .start, .seen_n, .seen_r, .seen_rn, .seen_rnr => unreachable, + .finished => { + const data_avail = r.next_chunk_length; + + if (skip) { + try conn.fill(); + + const nread = @min(conn.peek().len, data_avail); + conn.drop(@intCast(nread)); + r.next_chunk_length -= nread; + + if (r.next_chunk_length == 0 or nread == 0) r.done = true; + + return out_index; + } else if (out_index < buffer.len) { + const out_avail = buffer.len - out_index; + + const can_read = @as(usize, @intCast(@min(data_avail, out_avail))); + const nread = try conn.read(buffer[0..can_read]); + r.next_chunk_length -= nread; + + if (r.next_chunk_length == 0 or nread == 0) r.done = true; + + return nread; + } else { + return out_index; + } + }, + .chunk_data_suffix, .chunk_data_suffix_r, .chunk_head_size, .chunk_head_ext, .chunk_head_r => { + try conn.fill(); + + const i = r.findChunkedLen(conn.peek()); + conn.drop(@intCast(i)); + + switch (r.state) { + .invalid => return error.HttpChunkInvalid, + .chunk_data => if (r.next_chunk_length == 0) { + if (std.mem.eql(u8, conn.peek(), "\r\n")) { + r.state = .finished; + conn.drop(2); + } else { + // The trailer section is formatted identically + // to the header section. + r.state = .seen_rn; + } + r.done = true; + + return out_index; + }, + else => return out_index, + } + + continue; + }, + .chunk_data => { + const data_avail = r.next_chunk_length; + const out_avail = buffer.len - out_index; + + if (skip) { + try conn.fill(); + + const nread = @min(conn.peek().len, data_avail); + conn.drop(@intCast(nread)); + r.next_chunk_length -= nread; + } else if (out_avail > 0) { + const can_read: usize = @intCast(@min(data_avail, out_avail)); + const nread = try conn.read(buffer[out_index..][0..can_read]); + r.next_chunk_length -= nread; + out_index += nread; + } + + if (r.next_chunk_length == 0) { + r.state = .chunk_data_suffix; + continue; + } + + return out_index; + }, + } + } + } +}; + +inline fn int16(array: *const [2]u8) u16 { + return @as(u16, @bitCast(array.*)); +} + +inline fn int24(array: *const [3]u8) u24 { + return @as(u24, @bitCast(array.*)); +} + +inline fn int32(array: *const [4]u8) u32 { + return @as(u32, @bitCast(array.*)); +} + +inline fn intShift(comptime T: type, x: anytype) T { + switch (@import("builtin").cpu.arch.endian()) { + .little => return @as(T, @truncate(x >> (@bitSizeOf(@TypeOf(x)) - @bitSizeOf(T)))), + .big => return @as(T, @truncate(x)), + } +} + +/// A buffered (and peekable) Connection. +const MockBufferedConnection = struct { + pub const buffer_size = 0x2000; + + conn: std.io.FixedBufferStream([]const u8), + buf: [buffer_size]u8 = undefined, + start: u16 = 0, + end: u16 = 0, + + pub fn fill(conn: *MockBufferedConnection) ReadError!void { + if (conn.end != conn.start) return; + + const nread = try conn.conn.read(conn.buf[0..]); + if (nread == 0) return error.EndOfStream; + conn.start = 0; + conn.end = @as(u16, @truncate(nread)); + } + + pub fn peek(conn: *MockBufferedConnection) []const u8 { + return conn.buf[conn.start..conn.end]; + } + + pub fn drop(conn: *MockBufferedConnection, num: u16) void { + conn.start += num; + } + + pub fn readAtLeast(conn: *MockBufferedConnection, buffer: []u8, len: usize) ReadError!usize { + var out_index: u16 = 0; + while (out_index < len) { + const available = conn.end - conn.start; + const left = buffer.len - out_index; + + if (available > 0) { + const can_read = @as(u16, @truncate(@min(available, left))); + + @memcpy(buffer[out_index..][0..can_read], conn.buf[conn.start..][0..can_read]); + out_index += can_read; + conn.start += can_read; + + continue; + } + + if (left > conn.buf.len) { + // skip the buffer if the output is large enough + return conn.conn.read(buffer[out_index..]); + } + + try conn.fill(); + } + + return out_index; + } + + pub fn read(conn: *MockBufferedConnection, buffer: []u8) ReadError!usize { + return conn.readAtLeast(buffer, 1); + } + + pub const ReadError = std.io.FixedBufferStream([]const u8).ReadError || error{EndOfStream}; + pub const Reader = std.io.Reader(*MockBufferedConnection, ReadError, read); + + pub fn reader(conn: *MockBufferedConnection) Reader { + return Reader{ .context = conn }; + } + + pub fn writeAll(conn: *MockBufferedConnection, buffer: []const u8) WriteError!void { + return conn.conn.writeAll(buffer); + } + + pub fn write(conn: *MockBufferedConnection, buffer: []const u8) WriteError!usize { + return conn.conn.write(buffer); + } + + pub const WriteError = std.io.FixedBufferStream([]const u8).WriteError; + pub const Writer = std.io.Writer(*MockBufferedConnection, WriteError, write); + + pub fn writer(conn: *MockBufferedConnection) Writer { + return Writer{ .context = conn }; + } +}; + +test "HeadersParser.read length" { + // mock BufferedConnection for read + var headers_buf: [256]u8 = undefined; + + var r = HeadersParser.init(&headers_buf); + const data = "GET / HTTP/1.1\r\nHost: localhost\r\nContent-Length: 5\r\n\r\nHello"; + + var conn: MockBufferedConnection = .{ + .conn = std.io.fixedBufferStream(data), + }; + + while (true) { // read headers + try conn.fill(); + + const nchecked = try r.checkCompleteHead(conn.peek()); + conn.drop(@intCast(nchecked)); + + if (r.state.isContent()) break; + } + + var buf: [8]u8 = undefined; + + r.next_chunk_length = 5; + const len = try r.read(&conn, &buf, false); + try std.testing.expectEqual(@as(usize, 5), len); + try std.testing.expectEqualStrings("Hello", buf[0..len]); + + try std.testing.expectEqualStrings("GET / HTTP/1.1\r\nHost: localhost\r\nContent-Length: 5\r\n\r\n", r.get()); +} + +test "HeadersParser.read chunked" { + // mock BufferedConnection for read + + var headers_buf: [256]u8 = undefined; + var r = HeadersParser.init(&headers_buf); + const data = "GET / HTTP/1.1\r\nHost: localhost\r\n\r\n2\r\nHe\r\n2\r\nll\r\n1\r\no\r\n0\r\n\r\n"; + + var conn: MockBufferedConnection = .{ + .conn = std.io.fixedBufferStream(data), + }; + + while (true) { // read headers + try conn.fill(); + + const nchecked = try r.checkCompleteHead(conn.peek()); + conn.drop(@intCast(nchecked)); + + if (r.state.isContent()) break; + } + var buf: [8]u8 = undefined; + + r.state = .chunk_head_size; + const len = try r.read(&conn, &buf, false); + try std.testing.expectEqual(@as(usize, 5), len); + try std.testing.expectEqualStrings("Hello", buf[0..len]); + + try std.testing.expectEqualStrings("GET / HTTP/1.1\r\nHost: localhost\r\n\r\n", r.get()); +} + +test "HeadersParser.read chunked trailer" { + // mock BufferedConnection for read + + var headers_buf: [256]u8 = undefined; + var r = HeadersParser.init(&headers_buf); + const data = "GET / HTTP/1.1\r\nHost: localhost\r\n\r\n2\r\nHe\r\n2\r\nll\r\n1\r\no\r\n0\r\nContent-Type: text/plain\r\n\r\n"; + + var conn: MockBufferedConnection = .{ + .conn = std.io.fixedBufferStream(data), + }; + + while (true) { // read headers + try conn.fill(); + + const nchecked = try r.checkCompleteHead(conn.peek()); + conn.drop(@intCast(nchecked)); + + if (r.state.isContent()) break; + } + var buf: [8]u8 = undefined; + + r.state = .chunk_head_size; + const len = try r.read(&conn, &buf, false); + try std.testing.expectEqual(@as(usize, 5), len); + try std.testing.expectEqualStrings("Hello", buf[0..len]); + + while (true) { // read headers + try conn.fill(); + + const nchecked = try r.checkCompleteHead(conn.peek()); + conn.drop(@intCast(nchecked)); + + if (r.state.isContent()) break; + } + + try std.testing.expectEqualStrings("GET / HTTP/1.1\r\nHost: localhost\r\n\r\nContent-Type: text/plain\r\n\r\n", r.get()); +} diff --git a/src/http/async/std/net.zig b/src/http/async/std/net.zig new file mode 100644 index 00000000..86863deb --- /dev/null +++ b/src/http/async/std/net.zig @@ -0,0 +1,2050 @@ +//! Cross-platform networking abstractions. + +const std = @import("std"); +const builtin = @import("builtin"); +const assert = std.debug.assert; +const net = @This(); +const mem = std.mem; +const posix = std.posix; +const fs = std.fs; +const io = std.io; +const native_endian = builtin.target.cpu.arch.endian(); +const native_os = builtin.os.tag; +const windows = std.os.windows; + +const Ctx = @import("http/Client.zig").Ctx; +const Cbk = @import("http/Client.zig").Cbk; + +// Windows 10 added support for unix sockets in build 17063, redstone 4 is the +// first release to support them. +pub const has_unix_sockets = switch (native_os) { + .windows => builtin.os.version_range.windows.isAtLeast(.win10_rs4) orelse false, + else => true, +}; + +pub const IPParseError = error{ + Overflow, + InvalidEnd, + InvalidCharacter, + Incomplete, +}; + +pub const IPv4ParseError = IPParseError || error{NonCanonical}; + +pub const IPv6ParseError = IPParseError || error{InvalidIpv4Mapping}; +pub const IPv6InterfaceError = posix.SocketError || posix.IoCtl_SIOCGIFINDEX_Error || error{NameTooLong}; +pub const IPv6ResolveError = IPv6ParseError || IPv6InterfaceError; + +pub const Address = extern union { + any: posix.sockaddr, + in: Ip4Address, + in6: Ip6Address, + un: if (has_unix_sockets) posix.sockaddr.un else void, + + /// Parse the given IP address string into an Address value. + /// It is recommended to use `resolveIp` instead, to handle + /// IPv6 link-local unix addresses. + pub fn parseIp(name: []const u8, port: u16) !Address { + if (parseIp4(name, port)) |ip4| return ip4 else |err| switch (err) { + error.Overflow, + error.InvalidEnd, + error.InvalidCharacter, + error.Incomplete, + error.NonCanonical, + => {}, + } + + if (parseIp6(name, port)) |ip6| return ip6 else |err| switch (err) { + error.Overflow, + error.InvalidEnd, + error.InvalidCharacter, + error.Incomplete, + error.InvalidIpv4Mapping, + => {}, + } + + return error.InvalidIPAddressFormat; + } + + pub fn resolveIp(name: []const u8, port: u16) !Address { + if (parseIp4(name, port)) |ip4| return ip4 else |err| switch (err) { + error.Overflow, + error.InvalidEnd, + error.InvalidCharacter, + error.Incomplete, + error.NonCanonical, + => {}, + } + + if (resolveIp6(name, port)) |ip6| return ip6 else |err| switch (err) { + error.Overflow, + error.InvalidEnd, + error.InvalidCharacter, + error.Incomplete, + error.InvalidIpv4Mapping, + => {}, + else => return err, + } + + return error.InvalidIPAddressFormat; + } + + pub fn parseExpectingFamily(name: []const u8, family: posix.sa_family_t, port: u16) !Address { + switch (family) { + posix.AF.INET => return parseIp4(name, port), + posix.AF.INET6 => return parseIp6(name, port), + posix.AF.UNSPEC => return parseIp(name, port), + else => unreachable, + } + } + + pub fn parseIp6(buf: []const u8, port: u16) IPv6ParseError!Address { + return .{ .in6 = try Ip6Address.parse(buf, port) }; + } + + pub fn resolveIp6(buf: []const u8, port: u16) IPv6ResolveError!Address { + return .{ .in6 = try Ip6Address.resolve(buf, port) }; + } + + pub fn parseIp4(buf: []const u8, port: u16) IPv4ParseError!Address { + return .{ .in = try Ip4Address.parse(buf, port) }; + } + + pub fn initIp4(addr: [4]u8, port: u16) Address { + return .{ .in = Ip4Address.init(addr, port) }; + } + + pub fn initIp6(addr: [16]u8, port: u16, flowinfo: u32, scope_id: u32) Address { + return .{ .in6 = Ip6Address.init(addr, port, flowinfo, scope_id) }; + } + + pub fn initUnix(path: []const u8) !Address { + var sock_addr = posix.sockaddr.un{ + .family = posix.AF.UNIX, + .path = undefined, + }; + + // Add 1 to ensure a terminating 0 is present in the path array for maximum portability. + if (path.len + 1 > sock_addr.path.len) return error.NameTooLong; + + @memset(&sock_addr.path, 0); + @memcpy(sock_addr.path[0..path.len], path); + + return .{ .un = sock_addr }; + } + + /// Returns the port in native endian. + /// Asserts that the address is ip4 or ip6. + pub fn getPort(self: Address) u16 { + return switch (self.any.family) { + posix.AF.INET => self.in.getPort(), + posix.AF.INET6 => self.in6.getPort(), + else => unreachable, + }; + } + + /// `port` is native-endian. + /// Asserts that the address is ip4 or ip6. + pub fn setPort(self: *Address, port: u16) void { + switch (self.any.family) { + posix.AF.INET => self.in.setPort(port), + posix.AF.INET6 => self.in6.setPort(port), + else => unreachable, + } + } + + /// Asserts that `addr` is an IP address. + /// This function will read past the end of the pointer, with a size depending + /// on the address family. + pub fn initPosix(addr: *align(4) const posix.sockaddr) Address { + switch (addr.family) { + posix.AF.INET => return Address{ .in = Ip4Address{ .sa = @as(*const posix.sockaddr.in, @ptrCast(addr)).* } }, + posix.AF.INET6 => return Address{ .in6 = Ip6Address{ .sa = @as(*const posix.sockaddr.in6, @ptrCast(addr)).* } }, + else => unreachable, + } + } + + pub fn format( + self: Address, + comptime fmt: []const u8, + options: std.fmt.FormatOptions, + out_stream: anytype, + ) !void { + if (fmt.len != 0) std.fmt.invalidFmtError(fmt, self); + switch (self.any.family) { + posix.AF.INET => try self.in.format(fmt, options, out_stream), + posix.AF.INET6 => try self.in6.format(fmt, options, out_stream), + posix.AF.UNIX => { + if (!has_unix_sockets) { + unreachable; + } + + try std.fmt.format(out_stream, "{s}", .{std.mem.sliceTo(&self.un.path, 0)}); + }, + else => unreachable, + } + } + + pub fn eql(a: Address, b: Address) bool { + const a_bytes = @as([*]const u8, @ptrCast(&a.any))[0..a.getOsSockLen()]; + const b_bytes = @as([*]const u8, @ptrCast(&b.any))[0..b.getOsSockLen()]; + return mem.eql(u8, a_bytes, b_bytes); + } + + pub fn getOsSockLen(self: Address) posix.socklen_t { + switch (self.any.family) { + posix.AF.INET => return self.in.getOsSockLen(), + posix.AF.INET6 => return self.in6.getOsSockLen(), + posix.AF.UNIX => { + if (!has_unix_sockets) { + unreachable; + } + + // Using the full length of the structure here is more portable than returning + // the number of bytes actually used by the currently stored path. + // This also is correct regardless if we are passing a socket address to the kernel + // (e.g. in bind, connect, sendto) since we ensure the path is 0 terminated in + // initUnix() or if we are receiving a socket address from the kernel and must + // provide the full buffer size (e.g. getsockname, getpeername, recvfrom, accept). + // + // To access the path, std.mem.sliceTo(&address.un.path, 0) should be used. + return @as(posix.socklen_t, @intCast(@sizeOf(posix.sockaddr.un))); + }, + + else => unreachable, + } + } + + pub const ListenError = posix.SocketError || posix.BindError || posix.ListenError || + posix.SetSockOptError || posix.GetSockNameError; + + pub const ListenOptions = struct { + /// How many connections the kernel will accept on the application's behalf. + /// If more than this many connections pool in the kernel, clients will start + /// seeing "Connection refused". + kernel_backlog: u31 = 128, + /// Sets SO_REUSEADDR and SO_REUSEPORT on POSIX. + /// Sets SO_REUSEADDR on Windows, which is roughly equivalent. + reuse_address: bool = false, + /// Deprecated. Does the same thing as reuse_address. + reuse_port: bool = false, + force_nonblocking: bool = false, + }; + + /// The returned `Server` has an open `stream`. + pub fn listen(address: Address, options: ListenOptions) ListenError!Server { + const nonblock: u32 = if (options.force_nonblocking) posix.SOCK.NONBLOCK else 0; + const sock_flags = posix.SOCK.STREAM | posix.SOCK.CLOEXEC | nonblock; + const proto: u32 = if (address.any.family == posix.AF.UNIX) 0 else posix.IPPROTO.TCP; + + const sockfd = try posix.socket(address.any.family, sock_flags, proto); + var s: Server = .{ + .listen_address = undefined, + .stream = .{ .handle = sockfd }, + }; + errdefer s.stream.close(); + + if (options.reuse_address or options.reuse_port) { + try posix.setsockopt( + sockfd, + posix.SOL.SOCKET, + posix.SO.REUSEADDR, + &mem.toBytes(@as(c_int, 1)), + ); + switch (native_os) { + .windows => {}, + else => try posix.setsockopt( + sockfd, + posix.SOL.SOCKET, + posix.SO.REUSEPORT, + &mem.toBytes(@as(c_int, 1)), + ), + } + } + + var socklen = address.getOsSockLen(); + try posix.bind(sockfd, &address.any, socklen); + try posix.listen(sockfd, options.kernel_backlog); + try posix.getsockname(sockfd, &s.listen_address.any, &socklen); + return s; + } +}; + +pub const Ip4Address = extern struct { + sa: posix.sockaddr.in, + + pub fn parse(buf: []const u8, port: u16) IPv4ParseError!Ip4Address { + var result: Ip4Address = .{ + .sa = .{ + .port = mem.nativeToBig(u16, port), + .addr = undefined, + }, + }; + const out_ptr = mem.asBytes(&result.sa.addr); + + var x: u8 = 0; + var index: u8 = 0; + var saw_any_digits = false; + var has_zero_prefix = false; + for (buf) |c| { + if (c == '.') { + if (!saw_any_digits) { + return error.InvalidCharacter; + } + if (index == 3) { + return error.InvalidEnd; + } + out_ptr[index] = x; + index += 1; + x = 0; + saw_any_digits = false; + has_zero_prefix = false; + } else if (c >= '0' and c <= '9') { + if (c == '0' and !saw_any_digits) { + has_zero_prefix = true; + } else if (has_zero_prefix) { + return error.NonCanonical; + } + saw_any_digits = true; + x = try std.math.mul(u8, x, 10); + x = try std.math.add(u8, x, c - '0'); + } else { + return error.InvalidCharacter; + } + } + if (index == 3 and saw_any_digits) { + out_ptr[index] = x; + return result; + } + + return error.Incomplete; + } + + pub fn resolveIp(name: []const u8, port: u16) !Ip4Address { + if (parse(name, port)) |ip4| return ip4 else |err| switch (err) { + error.Overflow, + error.InvalidEnd, + error.InvalidCharacter, + error.Incomplete, + error.NonCanonical, + => {}, + } + return error.InvalidIPAddressFormat; + } + + pub fn init(addr: [4]u8, port: u16) Ip4Address { + return Ip4Address{ + .sa = posix.sockaddr.in{ + .port = mem.nativeToBig(u16, port), + .addr = @as(*align(1) const u32, @ptrCast(&addr)).*, + }, + }; + } + + /// Returns the port in native endian. + /// Asserts that the address is ip4 or ip6. + pub fn getPort(self: Ip4Address) u16 { + return mem.bigToNative(u16, self.sa.port); + } + + /// `port` is native-endian. + /// Asserts that the address is ip4 or ip6. + pub fn setPort(self: *Ip4Address, port: u16) void { + self.sa.port = mem.nativeToBig(u16, port); + } + + pub fn format( + self: Ip4Address, + comptime fmt: []const u8, + options: std.fmt.FormatOptions, + out_stream: anytype, + ) !void { + if (fmt.len != 0) std.fmt.invalidFmtError(fmt, self); + _ = options; + const bytes = @as(*const [4]u8, @ptrCast(&self.sa.addr)); + try std.fmt.format(out_stream, "{}.{}.{}.{}:{}", .{ + bytes[0], + bytes[1], + bytes[2], + bytes[3], + self.getPort(), + }); + } + + pub fn getOsSockLen(self: Ip4Address) posix.socklen_t { + _ = self; + return @sizeOf(posix.sockaddr.in); + } +}; + +pub const Ip6Address = extern struct { + sa: posix.sockaddr.in6, + + /// Parse a given IPv6 address string into an Address. + /// Assumes the Scope ID of the address is fully numeric. + /// For non-numeric addresses, see `resolveIp6`. + pub fn parse(buf: []const u8, port: u16) IPv6ParseError!Ip6Address { + var result = Ip6Address{ + .sa = posix.sockaddr.in6{ + .scope_id = 0, + .port = mem.nativeToBig(u16, port), + .flowinfo = 0, + .addr = undefined, + }, + }; + var ip_slice: *[16]u8 = result.sa.addr[0..]; + + var tail: [16]u8 = undefined; + + var x: u16 = 0; + var saw_any_digits = false; + var index: u8 = 0; + var scope_id = false; + var abbrv = false; + for (buf, 0..) |c, i| { + if (scope_id) { + if (c >= '0' and c <= '9') { + const digit = c - '0'; + { + const ov = @mulWithOverflow(result.sa.scope_id, 10); + if (ov[1] != 0) return error.Overflow; + result.sa.scope_id = ov[0]; + } + { + const ov = @addWithOverflow(result.sa.scope_id, digit); + if (ov[1] != 0) return error.Overflow; + result.sa.scope_id = ov[0]; + } + } else { + return error.InvalidCharacter; + } + } else if (c == ':') { + if (!saw_any_digits) { + if (abbrv) return error.InvalidCharacter; // ':::' + if (i != 0) abbrv = true; + @memset(ip_slice[index..], 0); + ip_slice = tail[0..]; + index = 0; + continue; + } + if (index == 14) { + return error.InvalidEnd; + } + ip_slice[index] = @as(u8, @truncate(x >> 8)); + index += 1; + ip_slice[index] = @as(u8, @truncate(x)); + index += 1; + + x = 0; + saw_any_digits = false; + } else if (c == '%') { + if (!saw_any_digits) { + return error.InvalidCharacter; + } + scope_id = true; + saw_any_digits = false; + } else if (c == '.') { + if (!abbrv or ip_slice[0] != 0xff or ip_slice[1] != 0xff) { + // must start with '::ffff:' + return error.InvalidIpv4Mapping; + } + const start_index = mem.lastIndexOfScalar(u8, buf[0..i], ':').? + 1; + const addr = (Ip4Address.parse(buf[start_index..], 0) catch { + return error.InvalidIpv4Mapping; + }).sa.addr; + ip_slice = result.sa.addr[0..]; + ip_slice[10] = 0xff; + ip_slice[11] = 0xff; + + const ptr = mem.sliceAsBytes(@as(*const [1]u32, &addr)[0..]); + + ip_slice[12] = ptr[0]; + ip_slice[13] = ptr[1]; + ip_slice[14] = ptr[2]; + ip_slice[15] = ptr[3]; + return result; + } else { + const digit = try std.fmt.charToDigit(c, 16); + { + const ov = @mulWithOverflow(x, 16); + if (ov[1] != 0) return error.Overflow; + x = ov[0]; + } + { + const ov = @addWithOverflow(x, digit); + if (ov[1] != 0) return error.Overflow; + x = ov[0]; + } + saw_any_digits = true; + } + } + + if (!saw_any_digits and !abbrv) { + return error.Incomplete; + } + if (!abbrv and index < 14) { + return error.Incomplete; + } + + if (index == 14) { + ip_slice[14] = @as(u8, @truncate(x >> 8)); + ip_slice[15] = @as(u8, @truncate(x)); + return result; + } else { + ip_slice[index] = @as(u8, @truncate(x >> 8)); + index += 1; + ip_slice[index] = @as(u8, @truncate(x)); + index += 1; + @memcpy(result.sa.addr[16 - index ..][0..index], ip_slice[0..index]); + return result; + } + } + + pub fn resolve(buf: []const u8, port: u16) IPv6ResolveError!Ip6Address { + // TODO: Unify the implementations of resolveIp6 and parseIp6. + var result = Ip6Address{ + .sa = posix.sockaddr.in6{ + .scope_id = 0, + .port = mem.nativeToBig(u16, port), + .flowinfo = 0, + .addr = undefined, + }, + }; + var ip_slice: *[16]u8 = result.sa.addr[0..]; + + var tail: [16]u8 = undefined; + + var x: u16 = 0; + var saw_any_digits = false; + var index: u8 = 0; + var abbrv = false; + + var scope_id = false; + var scope_id_value: [posix.IFNAMESIZE - 1]u8 = undefined; + var scope_id_index: usize = 0; + + for (buf, 0..) |c, i| { + if (scope_id) { + // Handling of percent-encoding should be for an URI library. + if ((c >= '0' and c <= '9') or + (c >= 'A' and c <= 'Z') or + (c >= 'a' and c <= 'z') or + (c == '-') or (c == '.') or (c == '_') or (c == '~')) + { + if (scope_id_index >= scope_id_value.len) { + return error.Overflow; + } + + scope_id_value[scope_id_index] = c; + scope_id_index += 1; + } else { + return error.InvalidCharacter; + } + } else if (c == ':') { + if (!saw_any_digits) { + if (abbrv) return error.InvalidCharacter; // ':::' + if (i != 0) abbrv = true; + @memset(ip_slice[index..], 0); + ip_slice = tail[0..]; + index = 0; + continue; + } + if (index == 14) { + return error.InvalidEnd; + } + ip_slice[index] = @as(u8, @truncate(x >> 8)); + index += 1; + ip_slice[index] = @as(u8, @truncate(x)); + index += 1; + + x = 0; + saw_any_digits = false; + } else if (c == '%') { + if (!saw_any_digits) { + return error.InvalidCharacter; + } + scope_id = true; + saw_any_digits = false; + } else if (c == '.') { + if (!abbrv or ip_slice[0] != 0xff or ip_slice[1] != 0xff) { + // must start with '::ffff:' + return error.InvalidIpv4Mapping; + } + const start_index = mem.lastIndexOfScalar(u8, buf[0..i], ':').? + 1; + const addr = (Ip4Address.parse(buf[start_index..], 0) catch { + return error.InvalidIpv4Mapping; + }).sa.addr; + ip_slice = result.sa.addr[0..]; + ip_slice[10] = 0xff; + ip_slice[11] = 0xff; + + const ptr = mem.sliceAsBytes(@as(*const [1]u32, &addr)[0..]); + + ip_slice[12] = ptr[0]; + ip_slice[13] = ptr[1]; + ip_slice[14] = ptr[2]; + ip_slice[15] = ptr[3]; + return result; + } else { + const digit = try std.fmt.charToDigit(c, 16); + { + const ov = @mulWithOverflow(x, 16); + if (ov[1] != 0) return error.Overflow; + x = ov[0]; + } + { + const ov = @addWithOverflow(x, digit); + if (ov[1] != 0) return error.Overflow; + x = ov[0]; + } + saw_any_digits = true; + } + } + + if (!saw_any_digits and !abbrv) { + return error.Incomplete; + } + + if (scope_id and scope_id_index == 0) { + return error.Incomplete; + } + + var resolved_scope_id: u32 = 0; + if (scope_id_index > 0) { + const scope_id_str = scope_id_value[0..scope_id_index]; + resolved_scope_id = std.fmt.parseInt(u32, scope_id_str, 10) catch |err| blk: { + if (err != error.InvalidCharacter) return err; + break :blk try if_nametoindex(scope_id_str); + }; + } + + result.sa.scope_id = resolved_scope_id; + + if (index == 14) { + ip_slice[14] = @as(u8, @truncate(x >> 8)); + ip_slice[15] = @as(u8, @truncate(x)); + return result; + } else { + ip_slice[index] = @as(u8, @truncate(x >> 8)); + index += 1; + ip_slice[index] = @as(u8, @truncate(x)); + index += 1; + @memcpy(result.sa.addr[16 - index ..][0..index], ip_slice[0..index]); + return result; + } + } + + pub fn init(addr: [16]u8, port: u16, flowinfo: u32, scope_id: u32) Ip6Address { + return Ip6Address{ + .sa = posix.sockaddr.in6{ + .addr = addr, + .port = mem.nativeToBig(u16, port), + .flowinfo = flowinfo, + .scope_id = scope_id, + }, + }; + } + + /// Returns the port in native endian. + /// Asserts that the address is ip4 or ip6. + pub fn getPort(self: Ip6Address) u16 { + return mem.bigToNative(u16, self.sa.port); + } + + /// `port` is native-endian. + /// Asserts that the address is ip4 or ip6. + pub fn setPort(self: *Ip6Address, port: u16) void { + self.sa.port = mem.nativeToBig(u16, port); + } + + pub fn format( + self: Ip6Address, + comptime fmt: []const u8, + options: std.fmt.FormatOptions, + out_stream: anytype, + ) !void { + if (fmt.len != 0) std.fmt.invalidFmtError(fmt, self); + _ = options; + const port = mem.bigToNative(u16, self.sa.port); + if (mem.eql(u8, self.sa.addr[0..12], &[_]u8{ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0xff, 0xff })) { + try std.fmt.format(out_stream, "[::ffff:{}.{}.{}.{}]:{}", .{ + self.sa.addr[12], + self.sa.addr[13], + self.sa.addr[14], + self.sa.addr[15], + port, + }); + return; + } + const big_endian_parts = @as(*align(1) const [8]u16, @ptrCast(&self.sa.addr)); + const native_endian_parts = switch (native_endian) { + .big => big_endian_parts.*, + .little => blk: { + var buf: [8]u16 = undefined; + for (big_endian_parts, 0..) |part, i| { + buf[i] = mem.bigToNative(u16, part); + } + break :blk buf; + }, + }; + try out_stream.writeAll("["); + var i: usize = 0; + var abbrv = false; + while (i < native_endian_parts.len) : (i += 1) { + if (native_endian_parts[i] == 0) { + if (!abbrv) { + try out_stream.writeAll(if (i == 0) "::" else ":"); + abbrv = true; + } + continue; + } + try std.fmt.format(out_stream, "{x}", .{native_endian_parts[i]}); + if (i != native_endian_parts.len - 1) { + try out_stream.writeAll(":"); + } + } + try std.fmt.format(out_stream, "]:{}", .{port}); + } + + pub fn getOsSockLen(self: Ip6Address) posix.socklen_t { + _ = self; + return @sizeOf(posix.sockaddr.in6); + } +}; + +pub fn connectUnixSocket(path: []const u8) !Stream { + const opt_non_block = 0; + const sockfd = try posix.socket( + posix.AF.UNIX, + posix.SOCK.STREAM | posix.SOCK.CLOEXEC | opt_non_block, + 0, + ); + errdefer Stream.close(.{ .handle = sockfd }); + + var addr = try std.net.Address.initUnix(path); + try posix.connect(sockfd, &addr.any, addr.getOsSockLen()); + + return .{ .handle = sockfd }; +} + +fn if_nametoindex(name: []const u8) IPv6InterfaceError!u32 { + if (native_os == .linux) { + var ifr: posix.ifreq = undefined; + const sockfd = try posix.socket(posix.AF.UNIX, posix.SOCK.DGRAM | posix.SOCK.CLOEXEC, 0); + defer Stream.close(.{ .handle = sockfd }); + + @memcpy(ifr.ifrn.name[0..name.len], name); + ifr.ifrn.name[name.len] = 0; + + // TODO investigate if this needs to be integrated with evented I/O. + try posix.ioctl_SIOCGIFINDEX(sockfd, &ifr); + + return @bitCast(ifr.ifru.ivalue); + } + + if (native_os.isDarwin()) { + if (name.len >= posix.IFNAMESIZE) + return error.NameTooLong; + + var if_name: [posix.IFNAMESIZE:0]u8 = undefined; + @memcpy(if_name[0..name.len], name); + if_name[name.len] = 0; + const if_slice = if_name[0..name.len :0]; + const index = std.c.if_nametoindex(if_slice); + if (index == 0) + return error.InterfaceNotFound; + return @as(u32, @bitCast(index)); + } + + @compileError("std.net.if_nametoindex unimplemented for this OS"); +} + +pub const AddressList = struct { + arena: std.heap.ArenaAllocator, + addrs: []Address, + canon_name: ?[]u8, + + pub fn deinit(self: *AddressList) void { + // Here we copy the arena allocator into stack memory, because + // otherwise it would destroy itself while it was still working. + var arena = self.arena; + arena.deinit(); + // self is destroyed + } +}; + +pub const TcpConnectToHostError = GetAddressListError || TcpConnectToAddressError; + +/// All memory allocated with `allocator` will be freed before this function returns. +pub fn tcpConnectToHost(allocator: mem.Allocator, name: []const u8, port: u16) TcpConnectToHostError!Stream { + const list = try getAddressList(allocator, name, port); + defer list.deinit(); + + if (list.addrs.len == 0) return error.UnknownHostName; + + for (list.addrs) |addr| { + return tcpConnectToAddress(addr) catch |err| switch (err) { + error.ConnectionRefused => { + continue; + }, + else => return err, + }; + } + return posix.ConnectError.ConnectionRefused; +} + +pub const TcpConnectToAddressError = posix.SocketError || posix.ConnectError; + +pub fn tcpConnectToAddress(address: Address) TcpConnectToAddressError!Stream { + const nonblock = 0; + const sock_flags = posix.SOCK.STREAM | nonblock | + (if (native_os == .windows) 0 else posix.SOCK.CLOEXEC); + const sockfd = try posix.socket(address.any.family, sock_flags, posix.IPPROTO.TCP); + errdefer Stream.close(.{ .handle = sockfd }); + + try posix.connect(sockfd, &address.any, address.getOsSockLen()); + + return Stream{ .handle = sockfd }; +} + +const GetAddressListError = std.mem.Allocator.Error || std.fs.File.OpenError || std.fs.File.ReadError || posix.SocketError || posix.BindError || posix.SetSockOptError || error{ + // TODO: break this up into error sets from the various underlying functions + + TemporaryNameServerFailure, + NameServerFailure, + AddressFamilyNotSupported, + UnknownHostName, + ServiceUnavailable, + Unexpected, + + HostLacksNetworkAddresses, + + InvalidCharacter, + InvalidEnd, + NonCanonical, + Overflow, + Incomplete, + InvalidIpv4Mapping, + InvalidIPAddressFormat, + + InterfaceNotFound, + FileSystem, +}; + +/// Call `AddressList.deinit` on the result. +pub fn getAddressList(allocator: mem.Allocator, name: []const u8, port: u16) GetAddressListError!*AddressList { + const result = blk: { + var arena = std.heap.ArenaAllocator.init(allocator); + errdefer arena.deinit(); + + const result = try arena.allocator().create(AddressList); + result.* = AddressList{ + .arena = arena, + .addrs = undefined, + .canon_name = null, + }; + break :blk result; + }; + const arena = result.arena.allocator(); + errdefer result.deinit(); + + if (native_os == .windows) { + const name_c = try allocator.dupeZ(u8, name); + defer allocator.free(name_c); + + const port_c = try std.fmt.allocPrintZ(allocator, "{}", .{port}); + defer allocator.free(port_c); + + const ws2_32 = windows.ws2_32; + const hints = posix.addrinfo{ + .flags = ws2_32.AI.NUMERICSERV, + .family = posix.AF.UNSPEC, + .socktype = posix.SOCK.STREAM, + .protocol = posix.IPPROTO.TCP, + .canonname = null, + .addr = null, + .addrlen = 0, + .next = null, + }; + var res: ?*posix.addrinfo = null; + var first = true; + while (true) { + const rc = ws2_32.getaddrinfo(name_c.ptr, port_c.ptr, &hints, &res); + switch (@as(windows.ws2_32.WinsockError, @enumFromInt(@as(u16, @intCast(rc))))) { + @as(windows.ws2_32.WinsockError, @enumFromInt(0)) => break, + .WSATRY_AGAIN => return error.TemporaryNameServerFailure, + .WSANO_RECOVERY => return error.NameServerFailure, + .WSAEAFNOSUPPORT => return error.AddressFamilyNotSupported, + .WSA_NOT_ENOUGH_MEMORY => return error.OutOfMemory, + .WSAHOST_NOT_FOUND => return error.UnknownHostName, + .WSATYPE_NOT_FOUND => return error.ServiceUnavailable, + .WSAEINVAL => unreachable, + .WSAESOCKTNOSUPPORT => unreachable, + .WSANOTINITIALISED => { + if (!first) return error.Unexpected; + first = false; + try windows.callWSAStartup(); + continue; + }, + else => |err| return windows.unexpectedWSAError(err), + } + } + defer ws2_32.freeaddrinfo(res); + + const addr_count = blk: { + var count: usize = 0; + var it = res; + while (it) |info| : (it = info.next) { + if (info.addr != null) { + count += 1; + } + } + break :blk count; + }; + result.addrs = try arena.alloc(Address, addr_count); + + var it = res; + var i: usize = 0; + while (it) |info| : (it = info.next) { + const addr = info.addr orelse continue; + result.addrs[i] = Address.initPosix(@alignCast(addr)); + + if (info.canonname) |n| { + if (result.canon_name == null) { + result.canon_name = try arena.dupe(u8, mem.sliceTo(n, 0)); + } + } + i += 1; + } + + return result; + } + + if (builtin.link_libc) { + const name_c = try allocator.dupeZ(u8, name); + defer allocator.free(name_c); + + const port_c = try std.fmt.allocPrintZ(allocator, "{}", .{port}); + defer allocator.free(port_c); + + const sys = if (native_os == .windows) windows.ws2_32 else posix.system; + const hints = posix.addrinfo{ + .flags = sys.AI.NUMERICSERV, + .family = posix.AF.UNSPEC, + .socktype = posix.SOCK.STREAM, + .protocol = posix.IPPROTO.TCP, + .canonname = null, + .addr = null, + .addrlen = 0, + .next = null, + }; + var res: ?*posix.addrinfo = null; + switch (sys.getaddrinfo(name_c.ptr, port_c.ptr, &hints, &res)) { + @as(sys.EAI, @enumFromInt(0)) => {}, + .ADDRFAMILY => return error.HostLacksNetworkAddresses, + .AGAIN => return error.TemporaryNameServerFailure, + .BADFLAGS => unreachable, // Invalid hints + .FAIL => return error.NameServerFailure, + .FAMILY => return error.AddressFamilyNotSupported, + .MEMORY => return error.OutOfMemory, + .NODATA => return error.HostLacksNetworkAddresses, + .NONAME => return error.UnknownHostName, + .SERVICE => return error.ServiceUnavailable, + .SOCKTYPE => unreachable, // Invalid socket type requested in hints + .SYSTEM => switch (posix.errno(-1)) { + else => |e| return posix.unexpectedErrno(e), + }, + else => unreachable, + } + defer if (res) |some| sys.freeaddrinfo(some); + + const addr_count = blk: { + var count: usize = 0; + var it = res; + while (it) |info| : (it = info.next) { + if (info.addr != null) { + count += 1; + } + } + break :blk count; + }; + result.addrs = try arena.alloc(Address, addr_count); + + var it = res; + var i: usize = 0; + while (it) |info| : (it = info.next) { + const addr = info.addr orelse continue; + result.addrs[i] = Address.initPosix(@alignCast(addr)); + + if (info.canonname) |n| { + if (result.canon_name == null) { + result.canon_name = try arena.dupe(u8, mem.sliceTo(n, 0)); + } + } + i += 1; + } + + return result; + } + + if (native_os == .linux) { + const flags = std.c.AI.NUMERICSERV; + const family = posix.AF.UNSPEC; + var lookup_addrs = std.ArrayList(LookupAddr).init(allocator); + defer lookup_addrs.deinit(); + + var canon = std.ArrayList(u8).init(arena); + defer canon.deinit(); + + try linuxLookupName(&lookup_addrs, &canon, name, family, flags, port); + + result.addrs = try arena.alloc(Address, lookup_addrs.items.len); + if (canon.items.len != 0) { + result.canon_name = try canon.toOwnedSlice(); + } + + for (lookup_addrs.items, 0..) |lookup_addr, i| { + result.addrs[i] = lookup_addr.addr; + assert(result.addrs[i].getPort() == port); + } + + return result; + } + @compileError("std.net.getAddressList unimplemented for this OS"); +} + +const LookupAddr = struct { + addr: Address, + sortkey: i32 = 0, +}; + +const DAS_USABLE = 0x40000000; +const DAS_MATCHINGSCOPE = 0x20000000; +const DAS_MATCHINGLABEL = 0x10000000; +const DAS_PREC_SHIFT = 20; +const DAS_SCOPE_SHIFT = 16; +const DAS_PREFIX_SHIFT = 8; +const DAS_ORDER_SHIFT = 0; + +fn linuxLookupName( + addrs: *std.ArrayList(LookupAddr), + canon: *std.ArrayList(u8), + opt_name: ?[]const u8, + family: posix.sa_family_t, + flags: u32, + port: u16, +) !void { + if (opt_name) |name| { + // reject empty name and check len so it fits into temp bufs + canon.items.len = 0; + try canon.appendSlice(name); + if (Address.parseExpectingFamily(name, family, port)) |addr| { + try addrs.append(LookupAddr{ .addr = addr }); + } else |name_err| if ((flags & std.c.AI.NUMERICHOST) != 0) { + return name_err; + } else { + try linuxLookupNameFromHosts(addrs, canon, name, family, port); + if (addrs.items.len == 0) { + // RFC 6761 Section 6.3.3 + // Name resolution APIs and libraries SHOULD recognize localhost + // names as special and SHOULD always return the IP loopback address + // for address queries and negative responses for all other query + // types. + + // Check for equal to "localhost(.)" or ends in ".localhost(.)" + const localhost = if (name[name.len - 1] == '.') "localhost." else "localhost"; + if (mem.endsWith(u8, name, localhost) and (name.len == localhost.len or name[name.len - localhost.len] == '.')) { + try addrs.append(LookupAddr{ .addr = .{ .in = Ip4Address.parse("127.0.0.1", port) catch unreachable } }); + try addrs.append(LookupAddr{ .addr = .{ .in6 = Ip6Address.parse("::1", port) catch unreachable } }); + return; + } + + try linuxLookupNameFromDnsSearch(addrs, canon, name, family, port); + } + } + } else { + try canon.resize(0); + try linuxLookupNameFromNull(addrs, family, flags, port); + } + if (addrs.items.len == 0) return error.UnknownHostName; + + // No further processing is needed if there are fewer than 2 + // results or if there are only IPv4 results. + if (addrs.items.len == 1 or family == posix.AF.INET) return; + const all_ip4 = for (addrs.items) |addr| { + if (addr.addr.any.family != posix.AF.INET) break false; + } else true; + if (all_ip4) return; + + // The following implements a subset of RFC 3484/6724 destination + // address selection by generating a single 31-bit sort key for + // each address. Rules 3, 4, and 7 are omitted for having + // excessive runtime and code size cost and dubious benefit. + // So far the label/precedence table cannot be customized. + // This implementation is ported from musl libc. + // A more idiomatic "ziggy" implementation would be welcome. + for (addrs.items, 0..) |*addr, i| { + var key: i32 = 0; + var sa6: posix.sockaddr.in6 = undefined; + @memset(@as([*]u8, @ptrCast(&sa6))[0..@sizeOf(posix.sockaddr.in6)], 0); + var da6 = posix.sockaddr.in6{ + .family = posix.AF.INET6, + .scope_id = addr.addr.in6.sa.scope_id, + .port = 65535, + .flowinfo = 0, + .addr = [1]u8{0} ** 16, + }; + var sa4: posix.sockaddr.in = undefined; + @memset(@as([*]u8, @ptrCast(&sa4))[0..@sizeOf(posix.sockaddr.in)], 0); + var da4 = posix.sockaddr.in{ + .family = posix.AF.INET, + .port = 65535, + .addr = 0, + .zero = [1]u8{0} ** 8, + }; + var sa: *align(4) posix.sockaddr = undefined; + var da: *align(4) posix.sockaddr = undefined; + var salen: posix.socklen_t = undefined; + var dalen: posix.socklen_t = undefined; + if (addr.addr.any.family == posix.AF.INET6) { + da6.addr = addr.addr.in6.sa.addr; + da = @ptrCast(&da6); + dalen = @sizeOf(posix.sockaddr.in6); + sa = @ptrCast(&sa6); + salen = @sizeOf(posix.sockaddr.in6); + } else { + sa6.addr[0..12].* = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xff\xff".*; + da6.addr[0..12].* = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xff\xff".*; + mem.writeInt(u32, da6.addr[12..], addr.addr.in.sa.addr, native_endian); + da4.addr = addr.addr.in.sa.addr; + da = @ptrCast(&da4); + dalen = @sizeOf(posix.sockaddr.in); + sa = @ptrCast(&sa4); + salen = @sizeOf(posix.sockaddr.in); + } + const dpolicy = policyOf(da6.addr); + const dscope: i32 = scopeOf(da6.addr); + const dlabel = dpolicy.label; + const dprec: i32 = dpolicy.prec; + const MAXADDRS = 3; + var prefixlen: i32 = 0; + const sock_flags = posix.SOCK.DGRAM | posix.SOCK.CLOEXEC; + if (posix.socket(addr.addr.any.family, sock_flags, posix.IPPROTO.UDP)) |fd| syscalls: { + defer Stream.close(.{ .handle = fd }); + posix.connect(fd, da, dalen) catch break :syscalls; + key |= DAS_USABLE; + posix.getsockname(fd, sa, &salen) catch break :syscalls; + if (addr.addr.any.family == posix.AF.INET) { + mem.writeInt(u32, sa6.addr[12..16], sa4.addr, native_endian); + } + if (dscope == @as(i32, scopeOf(sa6.addr))) key |= DAS_MATCHINGSCOPE; + if (dlabel == labelOf(sa6.addr)) key |= DAS_MATCHINGLABEL; + prefixlen = prefixMatch(sa6.addr, da6.addr); + } else |_| {} + key |= dprec << DAS_PREC_SHIFT; + key |= (15 - dscope) << DAS_SCOPE_SHIFT; + key |= prefixlen << DAS_PREFIX_SHIFT; + key |= (MAXADDRS - @as(i32, @intCast(i))) << DAS_ORDER_SHIFT; + addr.sortkey = key; + } + mem.sort(LookupAddr, addrs.items, {}, addrCmpLessThan); +} + +const Policy = struct { + addr: [16]u8, + len: u8, + mask: u8, + prec: u8, + label: u8, +}; + +const defined_policies = [_]Policy{ + Policy{ + .addr = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01".*, + .len = 15, + .mask = 0xff, + .prec = 50, + .label = 0, + }, + Policy{ + .addr = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xff\xff\x00\x00\x00\x00".*, + .len = 11, + .mask = 0xff, + .prec = 35, + .label = 4, + }, + Policy{ + .addr = "\x20\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00".*, + .len = 1, + .mask = 0xff, + .prec = 30, + .label = 2, + }, + Policy{ + .addr = "\x20\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00".*, + .len = 3, + .mask = 0xff, + .prec = 5, + .label = 5, + }, + Policy{ + .addr = "\xfc\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00".*, + .len = 0, + .mask = 0xfe, + .prec = 3, + .label = 13, + }, + // These are deprecated and/or returned to the address + // pool, so despite the RFC, treating them as special + // is probably wrong. + // { "", 11, 0xff, 1, 3 }, + // { "\xfe\xc0", 1, 0xc0, 1, 11 }, + // { "\x3f\xfe", 1, 0xff, 1, 12 }, + // Last rule must match all addresses to stop loop. + Policy{ + .addr = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00".*, + .len = 0, + .mask = 0, + .prec = 40, + .label = 1, + }, +}; + +fn policyOf(a: [16]u8) *const Policy { + for (&defined_policies) |*policy| { + if (!mem.eql(u8, a[0..policy.len], policy.addr[0..policy.len])) continue; + if ((a[policy.len] & policy.mask) != policy.addr[policy.len]) continue; + return policy; + } + unreachable; +} + +fn scopeOf(a: [16]u8) u8 { + if (IN6_IS_ADDR_MULTICAST(a)) return a[1] & 15; + if (IN6_IS_ADDR_LINKLOCAL(a)) return 2; + if (IN6_IS_ADDR_LOOPBACK(a)) return 2; + if (IN6_IS_ADDR_SITELOCAL(a)) return 5; + return 14; +} + +fn prefixMatch(s: [16]u8, d: [16]u8) u8 { + // TODO: This FIXME inherited from porting from musl libc. + // I don't want this to go into zig std lib 1.0.0. + + // FIXME: The common prefix length should be limited to no greater + // than the nominal length of the prefix portion of the source + // address. However the definition of the source prefix length is + // not clear and thus this limiting is not yet implemented. + var i: u8 = 0; + while (i < 128 and ((s[i / 8] ^ d[i / 8]) & (@as(u8, 128) >> @as(u3, @intCast(i % 8)))) == 0) : (i += 1) {} + return i; +} + +fn labelOf(a: [16]u8) u8 { + return policyOf(a).label; +} + +fn IN6_IS_ADDR_MULTICAST(a: [16]u8) bool { + return a[0] == 0xff; +} + +fn IN6_IS_ADDR_LINKLOCAL(a: [16]u8) bool { + return a[0] == 0xfe and (a[1] & 0xc0) == 0x80; +} + +fn IN6_IS_ADDR_LOOPBACK(a: [16]u8) bool { + return a[0] == 0 and a[1] == 0 and + a[2] == 0 and + a[12] == 0 and a[13] == 0 and + a[14] == 0 and a[15] == 1; +} + +fn IN6_IS_ADDR_SITELOCAL(a: [16]u8) bool { + return a[0] == 0xfe and (a[1] & 0xc0) == 0xc0; +} + +// Parameters `b` and `a` swapped to make this descending. +fn addrCmpLessThan(context: void, b: LookupAddr, a: LookupAddr) bool { + _ = context; + return a.sortkey < b.sortkey; +} + +fn linuxLookupNameFromNull( + addrs: *std.ArrayList(LookupAddr), + family: posix.sa_family_t, + flags: u32, + port: u16, +) !void { + if ((flags & std.c.AI.PASSIVE) != 0) { + if (family != posix.AF.INET6) { + (try addrs.addOne()).* = LookupAddr{ + .addr = Address.initIp4([1]u8{0} ** 4, port), + }; + } + if (family != posix.AF.INET) { + (try addrs.addOne()).* = LookupAddr{ + .addr = Address.initIp6([1]u8{0} ** 16, port, 0, 0), + }; + } + } else { + if (family != posix.AF.INET6) { + (try addrs.addOne()).* = LookupAddr{ + .addr = Address.initIp4([4]u8{ 127, 0, 0, 1 }, port), + }; + } + if (family != posix.AF.INET) { + (try addrs.addOne()).* = LookupAddr{ + .addr = Address.initIp6(([1]u8{0} ** 15) ++ [1]u8{1}, port, 0, 0), + }; + } + } +} + +fn linuxLookupNameFromHosts( + addrs: *std.ArrayList(LookupAddr), + canon: *std.ArrayList(u8), + name: []const u8, + family: posix.sa_family_t, + port: u16, +) !void { + const file = fs.openFileAbsoluteZ("/etc/hosts", .{}) catch |err| switch (err) { + error.FileNotFound, + error.NotDir, + error.AccessDenied, + => return, + else => |e| return e, + }; + defer file.close(); + + var buffered_reader = std.io.bufferedReader(file.reader()); + const reader = buffered_reader.reader(); + var line_buf: [512]u8 = undefined; + while (reader.readUntilDelimiterOrEof(&line_buf, '\n') catch |err| switch (err) { + error.StreamTooLong => blk: { + // Skip to the delimiter in the reader, to fix parsing + try reader.skipUntilDelimiterOrEof('\n'); + // Use the truncated line. A truncated comment or hostname will be handled correctly. + break :blk &line_buf; + }, + else => |e| return e, + }) |line| { + var split_it = mem.splitScalar(u8, line, '#'); + const no_comment_line = split_it.first(); + + var line_it = mem.tokenizeAny(u8, no_comment_line, " \t"); + const ip_text = line_it.next() orelse continue; + var first_name_text: ?[]const u8 = null; + while (line_it.next()) |name_text| { + if (first_name_text == null) first_name_text = name_text; + if (mem.eql(u8, name_text, name)) { + break; + } + } else continue; + + const addr = Address.parseExpectingFamily(ip_text, family, port) catch |err| switch (err) { + error.Overflow, + error.InvalidEnd, + error.InvalidCharacter, + error.Incomplete, + error.InvalidIPAddressFormat, + error.InvalidIpv4Mapping, + error.NonCanonical, + => continue, + }; + try addrs.append(LookupAddr{ .addr = addr }); + + // first name is canonical name + const name_text = first_name_text.?; + if (isValidHostName(name_text)) { + canon.items.len = 0; + try canon.appendSlice(name_text); + } + } +} + +pub fn isValidHostName(hostname: []const u8) bool { + if (hostname.len >= 254) return false; + if (!std.unicode.utf8ValidateSlice(hostname)) return false; + for (hostname) |byte| { + if (!std.ascii.isASCII(byte) or byte == '.' or byte == '-' or std.ascii.isAlphanumeric(byte)) { + continue; + } + return false; + } + return true; +} + +fn linuxLookupNameFromDnsSearch( + addrs: *std.ArrayList(LookupAddr), + canon: *std.ArrayList(u8), + name: []const u8, + family: posix.sa_family_t, + port: u16, +) !void { + var rc: ResolvConf = undefined; + try getResolvConf(addrs.allocator, &rc); + defer rc.deinit(); + + // Count dots, suppress search when >=ndots or name ends in + // a dot, which is an explicit request for global scope. + var dots: usize = 0; + for (name) |byte| { + if (byte == '.') dots += 1; + } + + const search = if (dots >= rc.ndots or mem.endsWith(u8, name, ".")) + "" + else + rc.search.items; + + var canon_name = name; + + // Strip final dot for canon, fail if multiple trailing dots. + if (mem.endsWith(u8, canon_name, ".")) canon_name.len -= 1; + if (mem.endsWith(u8, canon_name, ".")) return error.UnknownHostName; + + // Name with search domain appended is setup in canon[]. This both + // provides the desired default canonical name (if the requested + // name is not a CNAME record) and serves as a buffer for passing + // the full requested name to name_from_dns. + try canon.resize(canon_name.len); + @memcpy(canon.items, canon_name); + try canon.append('.'); + + var tok_it = mem.tokenizeAny(u8, search, " \t"); + while (tok_it.next()) |tok| { + canon.shrinkRetainingCapacity(canon_name.len + 1); + try canon.appendSlice(tok); + try linuxLookupNameFromDns(addrs, canon, canon.items, family, rc, port); + if (addrs.items.len != 0) return; + } + + canon.shrinkRetainingCapacity(canon_name.len); + return linuxLookupNameFromDns(addrs, canon, name, family, rc, port); +} + +const dpc_ctx = struct { + addrs: *std.ArrayList(LookupAddr), + canon: *std.ArrayList(u8), + port: u16, +}; + +fn linuxLookupNameFromDns( + addrs: *std.ArrayList(LookupAddr), + canon: *std.ArrayList(u8), + name: []const u8, + family: posix.sa_family_t, + rc: ResolvConf, + port: u16, +) !void { + const ctx = dpc_ctx{ + .addrs = addrs, + .canon = canon, + .port = port, + }; + const AfRr = struct { + af: posix.sa_family_t, + rr: u8, + }; + const afrrs = [_]AfRr{ + AfRr{ .af = posix.AF.INET6, .rr = posix.RR.A }, + AfRr{ .af = posix.AF.INET, .rr = posix.RR.AAAA }, + }; + var qbuf: [2][280]u8 = undefined; + var abuf: [2][512]u8 = undefined; + var qp: [2][]const u8 = undefined; + const apbuf = [2][]u8{ &abuf[0], &abuf[1] }; + var nq: usize = 0; + + for (afrrs) |afrr| { + if (family != afrr.af) { + const len = posix.res_mkquery(0, name, 1, afrr.rr, &[_]u8{}, null, &qbuf[nq]); + qp[nq] = qbuf[nq][0..len]; + nq += 1; + } + } + + var ap = [2][]u8{ apbuf[0], apbuf[1] }; + ap[0].len = 0; + ap[1].len = 0; + + try resMSendRc(qp[0..nq], ap[0..nq], apbuf[0..nq], rc); + + var i: usize = 0; + while (i < nq) : (i += 1) { + dnsParse(ap[i], ctx, dnsParseCallback) catch {}; + } + + if (addrs.items.len != 0) return; + if (ap[0].len < 4 or (ap[0][3] & 15) == 2) return error.TemporaryNameServerFailure; + if ((ap[0][3] & 15) == 0) return error.UnknownHostName; + if ((ap[0][3] & 15) == 3) return; + return error.NameServerFailure; +} + +const ResolvConf = struct { + attempts: u32, + ndots: u32, + timeout: u32, + search: std.ArrayList(u8), + ns: std.ArrayList(LookupAddr), + + fn deinit(rc: *ResolvConf) void { + rc.ns.deinit(); + rc.search.deinit(); + rc.* = undefined; + } +}; + +/// Ignores lines longer than 512 bytes. +/// TODO: https://github.com/ziglang/zig/issues/2765 and https://github.com/ziglang/zig/issues/2761 +fn getResolvConf(allocator: mem.Allocator, rc: *ResolvConf) !void { + rc.* = ResolvConf{ + .ns = std.ArrayList(LookupAddr).init(allocator), + .search = std.ArrayList(u8).init(allocator), + .ndots = 1, + .timeout = 5, + .attempts = 2, + }; + errdefer rc.deinit(); + + const file = fs.openFileAbsoluteZ("/etc/resolv.conf", .{}) catch |err| switch (err) { + error.FileNotFound, + error.NotDir, + error.AccessDenied, + => return linuxLookupNameFromNumericUnspec(&rc.ns, "127.0.0.1", 53), + else => |e| return e, + }; + defer file.close(); + + var buf_reader = std.io.bufferedReader(file.reader()); + const stream = buf_reader.reader(); + var line_buf: [512]u8 = undefined; + while (stream.readUntilDelimiterOrEof(&line_buf, '\n') catch |err| switch (err) { + error.StreamTooLong => blk: { + // Skip to the delimiter in the stream, to fix parsing + try stream.skipUntilDelimiterOrEof('\n'); + // Give an empty line to the while loop, which will be skipped. + break :blk line_buf[0..0]; + }, + else => |e| return e, + }) |line| { + const no_comment_line = no_comment_line: { + var split = mem.splitScalar(u8, line, '#'); + break :no_comment_line split.first(); + }; + var line_it = mem.tokenizeAny(u8, no_comment_line, " \t"); + + const token = line_it.next() orelse continue; + if (mem.eql(u8, token, "options")) { + while (line_it.next()) |sub_tok| { + var colon_it = mem.splitScalar(u8, sub_tok, ':'); + const name = colon_it.first(); + const value_txt = colon_it.next() orelse continue; + const value = std.fmt.parseInt(u8, value_txt, 10) catch |err| switch (err) { + // TODO https://github.com/ziglang/zig/issues/11812 + error.Overflow => @as(u8, 255), + error.InvalidCharacter => continue, + }; + if (mem.eql(u8, name, "ndots")) { + rc.ndots = @min(value, 15); + } else if (mem.eql(u8, name, "attempts")) { + rc.attempts = @min(value, 10); + } else if (mem.eql(u8, name, "timeout")) { + rc.timeout = @min(value, 60); + } + } + } else if (mem.eql(u8, token, "nameserver")) { + const ip_txt = line_it.next() orelse continue; + try linuxLookupNameFromNumericUnspec(&rc.ns, ip_txt, 53); + } else if (mem.eql(u8, token, "domain") or mem.eql(u8, token, "search")) { + rc.search.items.len = 0; + try rc.search.appendSlice(line_it.rest()); + } + } + + if (rc.ns.items.len == 0) { + return linuxLookupNameFromNumericUnspec(&rc.ns, "127.0.0.1", 53); + } +} + +fn linuxLookupNameFromNumericUnspec( + addrs: *std.ArrayList(LookupAddr), + name: []const u8, + port: u16, +) !void { + const addr = try Address.resolveIp(name, port); + (try addrs.addOne()).* = LookupAddr{ .addr = addr }; +} + +fn resMSendRc( + queries: []const []const u8, + answers: [][]u8, + answer_bufs: []const []u8, + rc: ResolvConf, +) !void { + const timeout = 1000 * rc.timeout; + const attempts = rc.attempts; + + var sl: posix.socklen_t = @sizeOf(posix.sockaddr.in); + var family: posix.sa_family_t = posix.AF.INET; + + var ns_list = std.ArrayList(Address).init(rc.ns.allocator); + defer ns_list.deinit(); + + try ns_list.resize(rc.ns.items.len); + const ns = ns_list.items; + + for (rc.ns.items, 0..) |iplit, i| { + ns[i] = iplit.addr; + assert(ns[i].getPort() == 53); + if (iplit.addr.any.family != posix.AF.INET) { + family = posix.AF.INET6; + } + } + + const flags = posix.SOCK.DGRAM | posix.SOCK.CLOEXEC | posix.SOCK.NONBLOCK; + const fd = posix.socket(family, flags, 0) catch |err| switch (err) { + error.AddressFamilyNotSupported => blk: { + // Handle case where system lacks IPv6 support + if (family == posix.AF.INET6) { + family = posix.AF.INET; + break :blk try posix.socket(posix.AF.INET, flags, 0); + } + return err; + }, + else => |e| return e, + }; + defer Stream.close(.{ .handle = fd }); + + // Past this point, there are no errors. Each individual query will + // yield either no reply (indicated by zero length) or an answer + // packet which is up to the caller to interpret. + + // Convert any IPv4 addresses in a mixed environment to v4-mapped + if (family == posix.AF.INET6) { + try posix.setsockopt( + fd, + posix.SOL.IPV6, + std.os.linux.IPV6.V6ONLY, + &mem.toBytes(@as(c_int, 0)), + ); + for (0..ns.len) |i| { + if (ns[i].any.family != posix.AF.INET) continue; + mem.writeInt(u32, ns[i].in6.sa.addr[12..], ns[i].in.sa.addr, native_endian); + ns[i].in6.sa.addr[0..12].* = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xff\xff".*; + ns[i].any.family = posix.AF.INET6; + ns[i].in6.sa.flowinfo = 0; + ns[i].in6.sa.scope_id = 0; + } + sl = @sizeOf(posix.sockaddr.in6); + } + + // Get local address and open/bind a socket + var sa: Address = undefined; + @memset(@as([*]u8, @ptrCast(&sa))[0..@sizeOf(Address)], 0); + sa.any.family = family; + try posix.bind(fd, &sa.any, sl); + + var pfd = [1]posix.pollfd{posix.pollfd{ + .fd = fd, + .events = posix.POLL.IN, + .revents = undefined, + }}; + const retry_interval = timeout / attempts; + var next: u32 = 0; + var t2: u64 = @bitCast(std.time.milliTimestamp()); + const t0 = t2; + var t1 = t2 - retry_interval; + + var servfail_retry: usize = undefined; + + outer: while (t2 - t0 < timeout) : (t2 = @as(u64, @bitCast(std.time.milliTimestamp()))) { + if (t2 - t1 >= retry_interval) { + // Query all configured nameservers in parallel + var i: usize = 0; + while (i < queries.len) : (i += 1) { + if (answers[i].len == 0) { + var j: usize = 0; + while (j < ns.len) : (j += 1) { + _ = posix.sendto(fd, queries[i], posix.MSG.NOSIGNAL, &ns[j].any, sl) catch undefined; + } + } + } + t1 = t2; + servfail_retry = 2 * queries.len; + } + + // Wait for a response, or until time to retry + const clamped_timeout = @min(@as(u31, std.math.maxInt(u31)), t1 + retry_interval - t2); + const nevents = posix.poll(&pfd, clamped_timeout) catch 0; + if (nevents == 0) continue; + + while (true) { + var sl_copy = sl; + const rlen = posix.recvfrom(fd, answer_bufs[next], 0, &sa.any, &sl_copy) catch break; + + // Ignore non-identifiable packets + if (rlen < 4) continue; + + // Ignore replies from addresses we didn't send to + var j: usize = 0; + while (j < ns.len and !ns[j].eql(sa)) : (j += 1) {} + if (j == ns.len) continue; + + // Find which query this answer goes with, if any + var i: usize = next; + while (i < queries.len and (answer_bufs[next][0] != queries[i][0] or + answer_bufs[next][1] != queries[i][1])) : (i += 1) + {} + + if (i == queries.len) continue; + if (answers[i].len != 0) continue; + + // Only accept positive or negative responses; + // retry immediately on server failure, and ignore + // all other codes such as refusal. + switch (answer_bufs[next][3] & 15) { + 0, 3 => {}, + 2 => if (servfail_retry != 0) { + servfail_retry -= 1; + _ = posix.sendto(fd, queries[i], posix.MSG.NOSIGNAL, &ns[j].any, sl) catch undefined; + }, + else => continue, + } + + // Store answer in the right slot, or update next + // available temp slot if it's already in place. + answers[i].len = rlen; + if (i == next) { + while (next < queries.len and answers[next].len != 0) : (next += 1) {} + } else { + @memcpy(answer_bufs[i][0..rlen], answer_bufs[next][0..rlen]); + } + + if (next == queries.len) break :outer; + } + } +} + +fn dnsParse( + r: []const u8, + ctx: anytype, + comptime callback: anytype, +) !void { + // This implementation is ported from musl libc. + // A more idiomatic "ziggy" implementation would be welcome. + if (r.len < 12) return error.InvalidDnsPacket; + if ((r[3] & 15) != 0) return; + var p = r.ptr + 12; + var qdcount = r[4] * @as(usize, 256) + r[5]; + var ancount = r[6] * @as(usize, 256) + r[7]; + if (qdcount + ancount > 64) return error.InvalidDnsPacket; + while (qdcount != 0) { + qdcount -= 1; + while (@intFromPtr(p) - @intFromPtr(r.ptr) < r.len and p[0] -% 1 < 127) p += 1; + if (p[0] > 193 or (p[0] == 193 and p[1] > 254) or @intFromPtr(p) > @intFromPtr(r.ptr) + r.len - 6) + return error.InvalidDnsPacket; + p += @as(usize, 5) + @intFromBool(p[0] != 0); + } + while (ancount != 0) { + ancount -= 1; + while (@intFromPtr(p) - @intFromPtr(r.ptr) < r.len and p[0] -% 1 < 127) p += 1; + if (p[0] > 193 or (p[0] == 193 and p[1] > 254) or @intFromPtr(p) > @intFromPtr(r.ptr) + r.len - 6) + return error.InvalidDnsPacket; + p += @as(usize, 1) + @intFromBool(p[0] != 0); + const len = p[8] * @as(usize, 256) + p[9]; + if (@intFromPtr(p) + len > @intFromPtr(r.ptr) + r.len) return error.InvalidDnsPacket; + try callback(ctx, p[1], p[10..][0..len], r); + p += 10 + len; + } +} + +fn dnsParseCallback(ctx: dpc_ctx, rr: u8, data: []const u8, packet: []const u8) !void { + switch (rr) { + posix.RR.A => { + if (data.len != 4) return error.InvalidDnsARecord; + const new_addr = try ctx.addrs.addOne(); + new_addr.* = LookupAddr{ + .addr = Address.initIp4(data[0..4].*, ctx.port), + }; + }, + posix.RR.AAAA => { + if (data.len != 16) return error.InvalidDnsAAAARecord; + const new_addr = try ctx.addrs.addOne(); + new_addr.* = LookupAddr{ + .addr = Address.initIp6(data[0..16].*, ctx.port, 0, 0), + }; + }, + posix.RR.CNAME => { + var tmp: [256]u8 = undefined; + // Returns len of compressed name. strlen to get canon name. + _ = try posix.dn_expand(packet, data, &tmp); + const canon_name = mem.sliceTo(&tmp, 0); + if (isValidHostName(canon_name)) { + ctx.canon.items.len = 0; + try ctx.canon.appendSlice(canon_name); + } + }, + else => return, + } +} + +pub const Stream = struct { + /// Underlying platform-defined type which may or may not be + /// interchangeable with a file system file descriptor. + handle: posix.socket_t, + + pub fn close(s: Stream) void { + switch (native_os) { + .windows => windows.closesocket(s.handle) catch unreachable, + else => posix.close(s.handle), + } + } + + pub const ReadError = posix.ReadError; + pub const WriteError = posix.WriteError; + + pub const Reader = io.Reader(Stream, ReadError, read); + pub const Writer = io.Writer(Stream, WriteError, write); + + pub fn reader(self: Stream) Reader { + return .{ .context = self }; + } + + pub fn writer(self: Stream) Writer { + return .{ .context = self }; + } + + pub fn read(self: Stream, buffer: []u8) ReadError!usize { + if (native_os == .windows) { + return windows.ReadFile(self.handle, buffer, null); + } + + return posix.read(self.handle, buffer); + } + + pub fn readv(s: Stream, iovecs: []const posix.iovec) ReadError!usize { + if (native_os == .windows) { + // TODO improve this to use ReadFileScatter + if (iovecs.len == 0) return @as(usize, 0); + const first = iovecs[0]; + return windows.ReadFile(s.handle, first.base[0..first.len], null); + } + + return posix.readv(s.handle, iovecs); + } + + /// Returns the number of bytes read. If the number read is smaller than + /// `buffer.len`, it means the stream reached the end. Reaching the end of + /// a stream is not an error condition. + pub fn readAll(s: Stream, buffer: []u8) ReadError!usize { + return readAtLeast(s, buffer, buffer.len); + } + + /// Returns the number of bytes read, calling the underlying read function + /// the minimal number of times until the buffer has at least `len` bytes + /// filled. If the number read is less than `len` it means the stream + /// reached the end. Reaching the end of the stream is not an error + /// condition. + pub fn readAtLeast(s: Stream, buffer: []u8, len: usize) ReadError!usize { + assert(len <= buffer.len); + var index: usize = 0; + while (index < len) { + const amt = try s.read(buffer[index..]); + if (amt == 0) break; + index += amt; + } + return index; + } + + /// TODO in evented I/O mode, this implementation incorrectly uses the event loop's + /// file system thread instead of non-blocking. It needs to be reworked to properly + /// use non-blocking I/O. + pub fn write(self: Stream, buffer: []const u8) WriteError!usize { + if (native_os == .windows) { + return windows.WriteFile(self.handle, buffer, null); + } + + return posix.write(self.handle, buffer); + } + + pub fn writeAll(self: Stream, bytes: []const u8) WriteError!void { + var index: usize = 0; + while (index < bytes.len) { + index += try self.write(bytes[index..]); + } + } + + /// See https://github.com/ziglang/zig/issues/7699 + /// See equivalent function: `std.fs.File.writev`. + pub fn writev(self: Stream, iovecs: []const posix.iovec_const) WriteError!usize { + return posix.writev(self.handle, iovecs); + } + + /// The `iovecs` parameter is mutable because this function needs to mutate the fields in + /// order to handle partial writes from the underlying OS layer. + /// See https://github.com/ziglang/zig/issues/7699 + /// See equivalent function: `std.fs.File.writevAll`. + pub fn writevAll(self: Stream, iovecs: []posix.iovec_const) WriteError!void { + if (iovecs.len == 0) return; + + var i: usize = 0; + while (true) { + var amt = try self.writev(iovecs[i..]); + while (amt >= iovecs[i].len) { + amt -= iovecs[i].len; + i += 1; + if (i >= iovecs.len) return; + } + iovecs[i].base += amt; + iovecs[i].len -= amt; + } + } + + pub fn async_read( + self: Stream, + buffer: []u8, + ctx: *Ctx, + comptime cbk: Cbk, + ) !void { + return ctx.loop.recv(Ctx, ctx, cbk, self.handle, buffer); + } + + pub fn async_readv( + s: Stream, + iovecs: []const posix.iovec, + ctx: *Ctx, + comptime cbk: Cbk, + ) ReadError!void { + if (iovecs.len == 0) return; + const first_buffer = iovecs[0].base[0..iovecs[0].len]; + return s.async_read(first_buffer, ctx, cbk); + } + + // TODO: why not take a buffer here? + pub fn async_write(self: Stream, buffer: []const u8, ctx: *Ctx, comptime cbk: Cbk) void { + return ctx.loop.send(Ctx, ctx, cbk, self.handle, buffer); + } + + fn onWriteAll(ctx: *Ctx, res: anyerror!void) anyerror!void { + res catch |err| return ctx.pop(err); + if (ctx.len() < ctx.buf().len) { + const new_buf = ctx.buf()[ctx.len()..]; + ctx.setBuf(new_buf); + return ctx.stream().async_write(new_buf, ctx, onWriteAll); + } + ctx.setBuf(null); + return ctx.pop({}); + } + + pub fn async_writeAll(self: Stream, bytes: []const u8, ctx: *Ctx, comptime cbk: Cbk) !void { + ctx.setBuf(bytes); + try ctx.push(cbk); + self.async_write(bytes, ctx, onWriteAll); + } +}; + +pub const Server = struct { + listen_address: Address, + stream: std.net.Stream, + + pub const Connection = struct { + stream: std.net.Stream, + address: Address, + }; + + pub fn deinit(s: *Server) void { + s.stream.close(); + s.* = undefined; + } + + pub const AcceptError = posix.AcceptError; + + /// Blocks until a client connects to the server. The returned `Connection` has + /// an open stream. + pub fn accept(s: *Server) AcceptError!Connection { + var accepted_addr: Address = undefined; + var addr_len: posix.socklen_t = @sizeOf(Address); + const fd = try posix.accept(s.stream.handle, &accepted_addr.any, &addr_len, posix.SOCK.CLOEXEC); + return .{ + .stream = .{ .handle = fd }, + .address = accepted_addr, + }; + } +}; + +test { + _ = @import("net/test.zig"); + _ = Server; + _ = Stream; + _ = Address; +} + +fn onTcpConnectToHost(ctx: *Ctx, res: anyerror!void) anyerror!void { + res catch |e| switch (e) { + error.ConnectionRefused => { + if (ctx.data.addr_current < ctx.data.list.addrs.len) { + // next iteration of addr + ctx.push(onTcpConnectToHost) catch |er| return ctx.pop(er); + ctx.data.addr_current += 1; + return async_tcpConnectToAddress( + ctx.data.list.addrs[ctx.data.addr_current], + ctx, + onTcpConnectToHost, + ); + } + // end of iteration of addr + ctx.data.list.deinit(); + return ctx.pop(e); + }, + else => { + ctx.data.list.deinit(); + return ctx.pop(std.posix.ConnectError.ConnectionRefused); + }, + }; + // success + ctx.data.list.deinit(); + return ctx.pop({}); +} + +pub fn async_tcpConnectToHost( + allocator: mem.Allocator, + name: []const u8, + port: u16, + ctx: *Ctx, + comptime cbk: Cbk, +) !void { + const list = std.net.getAddressList(allocator, name, port) catch |e| return ctx.pop(e); + if (list.addrs.len == 0) return ctx.pop(error.UnknownHostName); + + ctx.push(cbk) catch |e| return ctx.pop(e); + ctx.data.list = list; + ctx.data.addr_current = 0; + return async_tcpConnectToAddress(list.addrs[0], ctx, onTcpConnectToHost); +} + +pub fn async_tcpConnectToAddress(address: std.net.Address, ctx: *Ctx, comptime cbk: Cbk) !void { + const nonblock = 0; + const sock_flags = posix.SOCK.STREAM | nonblock | + (if (native_os == .windows) 0 else posix.SOCK.CLOEXEC); + const sockfd = try posix.socket(address.any.family, sock_flags, posix.IPPROTO.TCP); + + ctx.data.socket = sockfd; + ctx.push(cbk) catch |e| return ctx.pop(e); + + ctx.loop.connect( + Ctx, + ctx, + setStream, + sockfd, + address, + ); +} + +// requires client.data.socket to be set +fn setStream(ctx: *Ctx, res: anyerror!void) anyerror!void { + res catch |e| return ctx.pop(e); + ctx.data.conn.stream = .{ .handle = ctx.data.socket }; + return ctx.pop({}); +} diff --git a/src/http/async/std/net/test.zig b/src/http/async/std/net/test.zig new file mode 100644 index 00000000..3e316c54 --- /dev/null +++ b/src/http/async/std/net/test.zig @@ -0,0 +1,335 @@ +const std = @import("std"); +const builtin = @import("builtin"); +const net = std.net; +const mem = std.mem; +const testing = std.testing; + +test "parse and render IP addresses at comptime" { + if (builtin.os.tag == .wasi) return error.SkipZigTest; + comptime { + var ipAddrBuffer: [16]u8 = undefined; + // Parses IPv6 at comptime + const ipv6addr = net.Address.parseIp("::1", 0) catch unreachable; + var ipv6 = std.fmt.bufPrint(ipAddrBuffer[0..], "{}", .{ipv6addr}) catch unreachable; + try std.testing.expect(std.mem.eql(u8, "::1", ipv6[1 .. ipv6.len - 3])); + + // Parses IPv4 at comptime + const ipv4addr = net.Address.parseIp("127.0.0.1", 0) catch unreachable; + var ipv4 = std.fmt.bufPrint(ipAddrBuffer[0..], "{}", .{ipv4addr}) catch unreachable; + try std.testing.expect(std.mem.eql(u8, "127.0.0.1", ipv4[0 .. ipv4.len - 2])); + + // Returns error for invalid IP addresses at comptime + try testing.expectError(error.InvalidIPAddressFormat, net.Address.parseIp("::123.123.123.123", 0)); + try testing.expectError(error.InvalidIPAddressFormat, net.Address.parseIp("127.01.0.1", 0)); + try testing.expectError(error.InvalidIPAddressFormat, net.Address.resolveIp("::123.123.123.123", 0)); + try testing.expectError(error.InvalidIPAddressFormat, net.Address.resolveIp("127.01.0.1", 0)); + } +} + +test "parse and render IPv6 addresses" { + if (builtin.os.tag == .wasi) return error.SkipZigTest; + + var buffer: [100]u8 = undefined; + const ips = [_][]const u8{ + "FF01:0:0:0:0:0:0:FB", + "FF01::Fb", + "::1", + "::", + "1::", + "2001:db8::", + "::1234:5678", + "2001:db8::1234:5678", + "FF01::FB%1234", + "::ffff:123.5.123.5", + }; + const printed = [_][]const u8{ + "ff01::fb", + "ff01::fb", + "::1", + "::", + "1::", + "2001:db8::", + "::1234:5678", + "2001:db8::1234:5678", + "ff01::fb", + "::ffff:123.5.123.5", + }; + for (ips, 0..) |ip, i| { + const addr = net.Address.parseIp6(ip, 0) catch unreachable; + var newIp = std.fmt.bufPrint(buffer[0..], "{}", .{addr}) catch unreachable; + try std.testing.expect(std.mem.eql(u8, printed[i], newIp[1 .. newIp.len - 3])); + + if (builtin.os.tag == .linux) { + const addr_via_resolve = net.Address.resolveIp6(ip, 0) catch unreachable; + var newResolvedIp = std.fmt.bufPrint(buffer[0..], "{}", .{addr_via_resolve}) catch unreachable; + try std.testing.expect(std.mem.eql(u8, printed[i], newResolvedIp[1 .. newResolvedIp.len - 3])); + } + } + + try testing.expectError(error.InvalidCharacter, net.Address.parseIp6(":::", 0)); + try testing.expectError(error.Overflow, net.Address.parseIp6("FF001::FB", 0)); + try testing.expectError(error.InvalidCharacter, net.Address.parseIp6("FF01::Fb:zig", 0)); + try testing.expectError(error.InvalidEnd, net.Address.parseIp6("FF01:0:0:0:0:0:0:FB:", 0)); + try testing.expectError(error.Incomplete, net.Address.parseIp6("FF01:", 0)); + try testing.expectError(error.InvalidIpv4Mapping, net.Address.parseIp6("::123.123.123.123", 0)); + try testing.expectError(error.Incomplete, net.Address.parseIp6("1", 0)); + // TODO Make this test pass on other operating systems. + if (builtin.os.tag == .linux or comptime builtin.os.tag.isDarwin()) { + try testing.expectError(error.Incomplete, net.Address.resolveIp6("ff01::fb%", 0)); + try testing.expectError(error.Overflow, net.Address.resolveIp6("ff01::fb%wlp3s0s0s0s0s0s0s0s0", 0)); + try testing.expectError(error.Overflow, net.Address.resolveIp6("ff01::fb%12345678901234", 0)); + } +} + +test "invalid but parseable IPv6 scope ids" { + if (builtin.os.tag != .linux and comptime !builtin.os.tag.isDarwin()) { + // Currently, resolveIp6 with alphanumerical scope IDs only works on Linux. + // TODO Make this test pass on other operating systems. + return error.SkipZigTest; + } + + try testing.expectError(error.InterfaceNotFound, net.Address.resolveIp6("ff01::fb%123s45678901234", 0)); +} + +test "parse and render IPv4 addresses" { + if (builtin.os.tag == .wasi) return error.SkipZigTest; + + var buffer: [18]u8 = undefined; + for ([_][]const u8{ + "0.0.0.0", + "255.255.255.255", + "1.2.3.4", + "123.255.0.91", + "127.0.0.1", + }) |ip| { + const addr = net.Address.parseIp4(ip, 0) catch unreachable; + var newIp = std.fmt.bufPrint(buffer[0..], "{}", .{addr}) catch unreachable; + try std.testing.expect(std.mem.eql(u8, ip, newIp[0 .. newIp.len - 2])); + } + + try testing.expectError(error.Overflow, net.Address.parseIp4("256.0.0.1", 0)); + try testing.expectError(error.InvalidCharacter, net.Address.parseIp4("x.0.0.1", 0)); + try testing.expectError(error.InvalidEnd, net.Address.parseIp4("127.0.0.1.1", 0)); + try testing.expectError(error.Incomplete, net.Address.parseIp4("127.0.0.", 0)); + try testing.expectError(error.InvalidCharacter, net.Address.parseIp4("100..0.1", 0)); + try testing.expectError(error.NonCanonical, net.Address.parseIp4("127.01.0.1", 0)); +} + +test "parse and render UNIX addresses" { + if (builtin.os.tag == .wasi) return error.SkipZigTest; + if (!net.has_unix_sockets) return error.SkipZigTest; + + var buffer: [14]u8 = undefined; + const addr = net.Address.initUnix("/tmp/testpath") catch unreachable; + const fmt_addr = std.fmt.bufPrint(buffer[0..], "{}", .{addr}) catch unreachable; + try std.testing.expectEqualSlices(u8, "/tmp/testpath", fmt_addr); + + const too_long = [_]u8{'a'} ** 200; + try testing.expectError(error.NameTooLong, net.Address.initUnix(too_long[0..])); +} + +test "resolve DNS" { + if (builtin.os.tag == .wasi) return error.SkipZigTest; + + if (builtin.os.tag == .windows) { + _ = try std.os.windows.WSAStartup(2, 2); + } + defer { + if (builtin.os.tag == .windows) { + std.os.windows.WSACleanup() catch unreachable; + } + } + + // Resolve localhost, this should not fail. + { + const localhost_v4 = try net.Address.parseIp("127.0.0.1", 80); + const localhost_v6 = try net.Address.parseIp("::2", 80); + + const result = try net.getAddressList(testing.allocator, "localhost", 80); + defer result.deinit(); + for (result.addrs) |addr| { + if (addr.eql(localhost_v4) or addr.eql(localhost_v6)) break; + } else @panic("unexpected address for localhost"); + } + + { + // The tests are required to work even when there is no Internet connection, + // so some of these errors we must accept and skip the test. + const result = net.getAddressList(testing.allocator, "example.com", 80) catch |err| switch (err) { + error.UnknownHostName => return error.SkipZigTest, + error.TemporaryNameServerFailure => return error.SkipZigTest, + else => return err, + }; + result.deinit(); + } +} + +test "listen on a port, send bytes, receive bytes" { + if (builtin.single_threaded) return error.SkipZigTest; + if (builtin.os.tag == .wasi) return error.SkipZigTest; + + if (builtin.os.tag == .windows) { + _ = try std.os.windows.WSAStartup(2, 2); + } + defer { + if (builtin.os.tag == .windows) { + std.os.windows.WSACleanup() catch unreachable; + } + } + + // Try only the IPv4 variant as some CI builders have no IPv6 localhost + // configured. + const localhost = try net.Address.parseIp("127.0.0.1", 0); + + var server = try localhost.listen(.{}); + defer server.deinit(); + + const S = struct { + fn clientFn(server_address: net.Address) !void { + const socket = try net.tcpConnectToAddress(server_address); + defer socket.close(); + + _ = try socket.writer().writeAll("Hello world!"); + } + }; + + const t = try std.Thread.spawn(.{}, S.clientFn, .{server.listen_address}); + defer t.join(); + + var client = try server.accept(); + defer client.stream.close(); + var buf: [16]u8 = undefined; + const n = try client.stream.reader().read(&buf); + + try testing.expectEqual(@as(usize, 12), n); + try testing.expectEqualSlices(u8, "Hello world!", buf[0..n]); +} + +test "listen on an in use port" { + if (builtin.os.tag != .linux and comptime !builtin.os.tag.isDarwin()) { + // TODO build abstractions for other operating systems + return error.SkipZigTest; + } + + const localhost = try net.Address.parseIp("127.0.0.1", 0); + + var server1 = try localhost.listen(.{ .reuse_port = true }); + defer server1.deinit(); + + var server2 = try server1.listen_address.listen(.{ .reuse_port = true }); + defer server2.deinit(); +} + +fn testClientToHost(allocator: mem.Allocator, name: []const u8, port: u16) anyerror!void { + if (builtin.os.tag == .wasi) return error.SkipZigTest; + + const connection = try net.tcpConnectToHost(allocator, name, port); + defer connection.close(); + + var buf: [100]u8 = undefined; + const len = try connection.read(&buf); + const msg = buf[0..len]; + try testing.expect(mem.eql(u8, msg, "hello from server\n")); +} + +fn testClient(addr: net.Address) anyerror!void { + if (builtin.os.tag == .wasi) return error.SkipZigTest; + + const socket_file = try net.tcpConnectToAddress(addr); + defer socket_file.close(); + + var buf: [100]u8 = undefined; + const len = try socket_file.read(&buf); + const msg = buf[0..len]; + try testing.expect(mem.eql(u8, msg, "hello from server\n")); +} + +fn testServer(server: *net.Server) anyerror!void { + if (builtin.os.tag == .wasi) return error.SkipZigTest; + + var client = try server.accept(); + + const stream = client.stream.writer(); + try stream.print("hello from server\n", .{}); +} + +test "listen on a unix socket, send bytes, receive bytes" { + if (builtin.single_threaded) return error.SkipZigTest; + if (!net.has_unix_sockets) return error.SkipZigTest; + + if (builtin.os.tag == .windows) { + _ = try std.os.windows.WSAStartup(2, 2); + } + defer { + if (builtin.os.tag == .windows) { + std.os.windows.WSACleanup() catch unreachable; + } + } + + const socket_path = try generateFileName("socket.unix"); + defer testing.allocator.free(socket_path); + + const socket_addr = try net.Address.initUnix(socket_path); + defer std.fs.cwd().deleteFile(socket_path) catch {}; + + var server = try socket_addr.listen(.{}); + defer server.deinit(); + + const S = struct { + fn clientFn(path: []const u8) !void { + const socket = try net.connectUnixSocket(path); + defer socket.close(); + + _ = try socket.writer().writeAll("Hello world!"); + } + }; + + const t = try std.Thread.spawn(.{}, S.clientFn, .{socket_path}); + defer t.join(); + + var client = try server.accept(); + defer client.stream.close(); + var buf: [16]u8 = undefined; + const n = try client.stream.reader().read(&buf); + + try testing.expectEqual(@as(usize, 12), n); + try testing.expectEqualSlices(u8, "Hello world!", buf[0..n]); +} + +fn generateFileName(base_name: []const u8) ![]const u8 { + const random_bytes_count = 12; + const sub_path_len = comptime std.fs.base64_encoder.calcSize(random_bytes_count); + var random_bytes: [12]u8 = undefined; + std.crypto.random.bytes(&random_bytes); + var sub_path: [sub_path_len]u8 = undefined; + _ = std.fs.base64_encoder.encode(&sub_path, &random_bytes); + return std.fmt.allocPrint(testing.allocator, "{s}-{s}", .{ sub_path[0..], base_name }); +} + +test "non-blocking tcp server" { + if (builtin.os.tag == .wasi) return error.SkipZigTest; + if (true) { + // https://github.com/ziglang/zig/issues/18315 + return error.SkipZigTest; + } + + const localhost = try net.Address.parseIp("127.0.0.1", 0); + var server = localhost.listen(.{ .force_nonblocking = true }); + defer server.deinit(); + + const accept_err = server.accept(); + try testing.expectError(error.WouldBlock, accept_err); + + const socket_file = try net.tcpConnectToAddress(server.listen_address); + defer socket_file.close(); + + var client = try server.accept(); + defer client.stream.close(); + const stream = client.stream.writer(); + try stream.print("hello from server\n", .{}); + + var buf: [100]u8 = undefined; + const len = try socket_file.read(&buf); + const msg = buf[0..len]; + try testing.expect(mem.eql(u8, msg, "hello from server\n")); +} diff --git a/src/http/async/tls.zig/PrivateKey.zig b/src/http/async/tls.zig/PrivateKey.zig new file mode 100644 index 00000000..0e2b944d --- /dev/null +++ b/src/http/async/tls.zig/PrivateKey.zig @@ -0,0 +1,260 @@ +const std = @import("std"); +const Allocator = std.mem.Allocator; +const Certificate = std.crypto.Certificate; +const der = Certificate.der; +const rsa = @import("rsa/rsa.zig"); +const base64 = std.base64.standard.decoderWithIgnore(" \t\r\n"); +const proto = @import("protocol.zig"); + +const max_ecdsa_key_len = 66; + +signature_scheme: proto.SignatureScheme, + +key: union { + rsa: rsa.KeyPair, + ecdsa: [max_ecdsa_key_len]u8, +}, + +const PrivateKey = @This(); + +pub fn fromFile(gpa: Allocator, file: std.fs.File) !PrivateKey { + const buf = try file.readToEndAlloc(gpa, 1024 * 1024); + defer gpa.free(buf); + return try parsePem(buf); +} + +pub fn parsePem(buf: []const u8) !PrivateKey { + const key_start, const key_end, const marker_version = try findKey(buf); + const encoded = std.mem.trim(u8, buf[key_start..key_end], " \t\r\n"); + + // required bytes: + // 2412, 1821, 1236 for rsa 4096, 3072, 2048 bits size keys + var decoded: [4096]u8 = undefined; + const n = try base64.decode(&decoded, encoded); + + if (marker_version == 2) { + return try parseEcDer(decoded[0..n]); + } + return try parseDer(decoded[0..n]); +} + +fn findKey(buf: []const u8) !struct { usize, usize, usize } { + const markers = [_]struct { + begin: []const u8, + end: []const u8, + }{ + .{ .begin = "-----BEGIN PRIVATE KEY-----", .end = "-----END PRIVATE KEY-----" }, + .{ .begin = "-----BEGIN EC PRIVATE KEY-----", .end = "-----END EC PRIVATE KEY-----" }, + }; + + for (markers, 1..) |marker, ver| { + const begin_marker_start = std.mem.indexOfPos(u8, buf, 0, marker.begin) orelse continue; + const key_start = begin_marker_start + marker.begin.len; + const key_end = std.mem.indexOfPos(u8, buf, key_start, marker.end) orelse continue; + + return .{ key_start, key_end, ver }; + } + return error.MissingEndMarker; +} + +// ref: https://asn1js.eu/#MIG2AgEAMBAGByqGSM49AgEGBSuBBAAiBIGeMIGbAgEBBDBKFkVJCtU9FR6egz3yNxKBwXd86cFzMYqyGb8hRc1zVvLdw-So_2FBtITp6jzYmFShZANiAAQ-CH3a1R0V6dFlTK8Rs4M4egrpPtdta0osysO0Zl8mkBiDsTlvJNqeAp7L2ItHgFW8k_CfhgQT6iLDacNMhKC4XOV07r_ePD-mmkvqvRmzfOowHUoVRhCKrOTmF_J9Syc +pub fn parseDer(buf: []const u8) !PrivateKey { + const info = try der.Element.parse(buf, 0); + const version = try der.Element.parse(buf, info.slice.start); + + const algo_seq = try der.Element.parse(buf, version.slice.end); + const algo_cat = try der.Element.parse(buf, algo_seq.slice.start); + + const key_str = try der.Element.parse(buf, algo_seq.slice.end); + const key_seq = try der.Element.parse(buf, key_str.slice.start); + const key_int = try der.Element.parse(buf, key_seq.slice.start); + + const category = try Certificate.parseAlgorithmCategory(buf, algo_cat); + switch (category) { + .rsaEncryption => { + const modulus = try der.Element.parse(buf, key_int.slice.end); + const public_exponent = try der.Element.parse(buf, modulus.slice.end); + const private_exponent = try der.Element.parse(buf, public_exponent.slice.end); + + const public_key = try rsa.PublicKey.fromBytes(content(buf, modulus), content(buf, public_exponent)); + const secret_key = try rsa.SecretKey.fromBytes(public_key.modulus, content(buf, private_exponent)); + const key_pair = rsa.KeyPair{ .public = public_key, .secret = secret_key }; + + return .{ + .signature_scheme = switch (key_pair.public.modulus.bits()) { + 4096 => .rsa_pss_rsae_sha512, + 3072 => .rsa_pss_rsae_sha384, + else => .rsa_pss_rsae_sha256, + }, + .key = .{ .rsa = key_pair }, + }; + }, + .X9_62_id_ecPublicKey => { + const key = try der.Element.parse(buf, key_int.slice.end); + const algo_param = try der.Element.parse(buf, algo_cat.slice.end); + const named_curve = try Certificate.parseNamedCurve(buf, algo_param); + return .{ + .signature_scheme = signatureScheme(named_curve), + .key = .{ .ecdsa = ecdsaKey(buf, key) }, + }; + }, + else => unreachable, + } +} + +// References: +// https://asn1js.eu/#MHcCAQEEINJSRKv8kSKEzLHptfAlg-LGh4_pHHlq0XLf30Q9pcztoAoGCCqGSM49AwEHoUQDQgAEJpmLyp8aGCgyMcFIJaIq_-4V1K6nPpeoih3bT2npeplF9eyXj7rm8eW9Ua6VLhq71mqtMC-YLm-IkORBVq1cuA +// https://www.rfc-editor.org/rfc/rfc5915 +pub fn parseEcDer(bytes: []const u8) !PrivateKey { + const pki_msg = try der.Element.parse(bytes, 0); + const version = try der.Element.parse(bytes, pki_msg.slice.start); + const key = try der.Element.parse(bytes, version.slice.end); + const parameters = try der.Element.parse(bytes, key.slice.end); + const curve = try der.Element.parse(bytes, parameters.slice.start); + const named_curve = try Certificate.parseNamedCurve(bytes, curve); + return .{ + .signature_scheme = signatureScheme(named_curve), + .key = .{ .ecdsa = ecdsaKey(bytes, key) }, + }; +} + +fn signatureScheme(named_curve: Certificate.NamedCurve) proto.SignatureScheme { + return switch (named_curve) { + .X9_62_prime256v1 => .ecdsa_secp256r1_sha256, + .secp384r1 => .ecdsa_secp384r1_sha384, + .secp521r1 => .ecdsa_secp521r1_sha512, + }; +} + +fn ecdsaKey(bytes: []const u8, e: der.Element) [max_ecdsa_key_len]u8 { + const data = content(bytes, e); + var ecdsa_key: [max_ecdsa_key_len]u8 = undefined; + @memcpy(ecdsa_key[0..data.len], data); + return ecdsa_key; +} + +fn content(bytes: []const u8, e: der.Element) []const u8 { + return bytes[e.slice.start..e.slice.end]; +} + +const testing = std.testing; +const testu = @import("testu.zig"); + +test "parse ec pem" { + const data = @embedFile("testdata/ec_private_key.pem"); + var pk = try parsePem(data); + const priv_key = &testu.hexToBytes( + \\ 10 35 3d ca 1b 15 1d 06 aa 71 b8 ef f3 19 22 + \\ 43 78 f3 20 98 1e b1 2f 2b 64 7e 71 d0 30 2a + \\ 90 aa e5 eb 99 c3 90 65 3d c1 26 19 be 3f 08 + \\ 20 9b 01 + ); + try testing.expectEqualSlices(u8, priv_key, pk.key.ecdsa[0..priv_key.len]); + try testing.expectEqual(.ecdsa_secp384r1_sha384, pk.signature_scheme); +} + +test "parse ec prime256v1" { + const data = @embedFile("testdata/ec_prime256v1_private_key.pem"); + var pk = try parsePem(data); + const priv_key = &testu.hexToBytes( + \\ d2 52 44 ab fc 91 22 84 cc b1 e9 b5 f0 25 83 + \\ e2 c6 87 8f e9 1c 79 6a d1 72 df df 44 3d a5 + \\ cc ed + ); + try testing.expectEqualSlices(u8, priv_key, pk.key.ecdsa[0..priv_key.len]); + try testing.expectEqual(.ecdsa_secp256r1_sha256, pk.signature_scheme); +} + +test "parse ec secp384r1" { + const data = @embedFile("testdata/ec_secp384r1_private_key.pem"); + var pk = try parsePem(data); + const priv_key = &testu.hexToBytes( + \\ ee 6d 8a 5e 0d d3 b0 c6 4b 32 40 80 e2 3a de + \\ 8b 1e dd e2 92 db 36 1c db 91 ea ba a1 06 0d + \\ 42 2d d9 a9 dc 05 43 29 f1 78 7c f9 08 af c5 + \\ 03 1f 6d + ); + try testing.expectEqualSlices(u8, priv_key, pk.key.ecdsa[0..priv_key.len]); + try testing.expectEqual(.ecdsa_secp384r1_sha384, pk.signature_scheme); +} + +test "parse ec secp521r1" { + const data = @embedFile("testdata/ec_secp521r1_private_key.pem"); + var pk = try parsePem(data); + const priv_key = &testu.hexToBytes( + \\ 01 f0 2f 5a c7 24 18 ea 68 23 8c 2e a1 b4 b8 + \\ dc f2 11 b2 96 b0 ec 87 80 42 bf de ba f4 96 + \\ 83 8f 9b db c6 60 a7 4c d9 60 3a e4 ba 0b df + \\ ae 24 d3 1b c2 6e 82 a0 88 c1 ed 17 20 0d 3a + \\ f1 c5 7e e8 0b 27 + ); + try testing.expectEqualSlices(u8, priv_key, pk.key.ecdsa[0..priv_key.len]); + try testing.expectEqual(.ecdsa_secp521r1_sha512, pk.signature_scheme); +} + +test "parse rsa pem" { + const data = @embedFile("testdata/rsa_private_key.pem"); + const pk = try parsePem(data); + + // expected results from: + // $ openssl pkey -in testdata/rsa_private_key.pem -text -noout + const modulus = &testu.hexToBytes( + \\ 00 de f7 23 e6 75 cc 6f dd d5 6e 0f 8c 09 f8 + \\ 62 e3 60 1b c0 7d 8c d5 04 50 2c 36 e2 3b f7 + \\ 33 9f a1 14 af be cf 1a 0f 4c f5 cb 39 70 0e + \\ 3b 97 d6 21 f7 48 91 79 ca 7c 68 fc ea 62 a1 + \\ 5a 72 4f 78 57 0e cc f2 a3 50 05 f1 4c ca 51 + \\ 73 10 9a 18 8e 71 f5 b4 c7 3e be 4c ef 37 d4 + \\ 84 4b 82 1c ec 08 a3 cc 07 3d 5c 0b e5 85 3f + \\ fe b6 44 77 8f 3c 6a 2f 33 c3 5d f6 f2 29 46 + \\ 04 25 7e 05 d9 f8 3b 2d a4 40 66 9f 0d 6d 1a + \\ fa bc 0a c5 8b 86 43 30 ef 14 20 41 9d b5 cc + \\ 3e 63 b5 48 04 27 c9 5c d3 62 28 5f f5 b6 e4 + \\ 77 49 99 ac 84 4a a6 67 a5 9a 1a 37 c7 60 4c + \\ ba c1 70 cf 57 64 4a 21 ea 05 53 10 ec 94 71 + \\ 4a 43 04 83 00 aa 5a 28 bc f2 8c 58 14 92 d2 + \\ 83 17 f4 7b 29 0f e7 87 a2 47 b2 53 19 12 23 + \\ fb 4b ce 5a f8 a1 84 f9 b1 f3 bf e3 fa 10 f8 + \\ ad af 87 ce 03 0e a0 2c 13 71 57 c4 55 44 48 + \\ 44 cb + ); + const public_exponent = &testu.hexToBytes("01 00 01"); + const private_exponent = &testu.hexToBytes( + \\ 50 3b 80 98 aa a5 11 50 33 40 32 aa 02 e0 75 + \\ bd 3a 55 62 34 0b 9c 8f bb c5 dd 4e 15 a4 03 + \\ d8 9a 5f 56 4a 84 3d ed 69 95 3d 37 03 02 ac + \\ 21 1c 36 06 c4 ff 4c 63 37 d7 93 c3 48 10 a5 + \\ fa 62 6c 7c 6f 60 02 a4 0f e4 c3 8b 0d 76 b7 + \\ c0 2e a3 4d 86 e6 92 d1 eb db 10 d6 38 31 ea + \\ 15 3d d1 e8 81 c7 67 60 e7 8c 9a df 51 ce d0 + \\ 7a 88 32 b9 c1 54 b8 7d 98 fc d4 23 1a 05 0e + \\ f2 ea e1 72 29 28 2a 68 b7 90 18 80 1c 21 d6 + \\ 36 a8 6b 4a 9c dd 14 b8 9f 85 ee 95 0b f4 c6 + \\ 17 02 aa 4d ea 4d f9 39 d7 dd 9d b4 1d d2 f8 + \\ 92 46 0f 18 41 80 f4 ea 27 55 29 f8 37 59 bf + \\ 43 ec a3 eb 19 ba bc 13 06 95 3d 25 4b c9 72 + \\ cf 41 0a 6f aa cb 79 d4 7b fa b1 09 7c e2 2f + \\ 85 51 44 8b c6 97 8e 46 f9 6b ac 08 87 92 ce + \\ af 0b bf 8c bd 27 51 8f 09 e4 d3 f9 04 ac fa + \\ f2 04 70 3e d9 a6 28 17 c2 2d 74 e9 25 40 02 + \\ 49 + ); + + try testing.expectEqual(.rsa_pss_rsae_sha256, pk.signature_scheme); + const kp = pk.key.rsa; + { + var bytes: [modulus.len]u8 = undefined; + try kp.public.modulus.toBytes(&bytes, .big); + try testing.expectEqualSlices(u8, modulus, &bytes); + } + { + var bytes: [private_exponent.len]u8 = undefined; + try kp.public.public_exponent.toBytes(&bytes, .big); + try testing.expectEqualSlices(u8, public_exponent, bytes[bytes.len - public_exponent.len .. bytes.len]); + } + { + var btytes: [private_exponent.len]u8 = undefined; + try kp.secret.private_exponent.toBytes(&btytes, .big); + try testing.expectEqualSlices(u8, private_exponent, &btytes); + } +} diff --git a/src/http/async/tls.zig/cbc/main.zig b/src/http/async/tls.zig/cbc/main.zig new file mode 100644 index 00000000..25038445 --- /dev/null +++ b/src/http/async/tls.zig/cbc/main.zig @@ -0,0 +1,148 @@ +// This file is originally copied from: https://github.com/jedisct1/zig-cbc. +// +// It is modified then to have TLS padding insead of PKCS#7 padding. +// Reference: +// https://datatracker.ietf.org/doc/html/rfc5246/#section-6.2.3.2 +// https://crypto.stackexchange.com/questions/98917/on-the-correctness-of-the-padding-example-of-rfc-5246 +// +// If required padding i n bytes +// PKCS#7 padding is (n...n) +// TLS padding is (n-1...n-1) - n times of n-1 value +// +const std = @import("std"); +const aes = std.crypto.core.aes; +const mem = std.mem; +const debug = std.debug; + +/// CBC mode with TLS 1.2 padding +/// +/// Important: the counter mode doesn't provide authenticated encryption: the ciphertext can be trivially modified without this being detected. +/// If you need authenticated encryption, use anything from `std.crypto.aead` instead. +/// If you really need to use CBC mode, make sure to use a MAC to authenticate the ciphertext. +pub fn CBC(comptime BlockCipher: anytype) type { + const EncryptCtx = aes.AesEncryptCtx(BlockCipher); + const DecryptCtx = aes.AesDecryptCtx(BlockCipher); + + return struct { + const Self = @This(); + + enc_ctx: EncryptCtx, + dec_ctx: DecryptCtx, + + /// Initialize the CBC context with the given key. + pub fn init(key: [BlockCipher.key_bits / 8]u8) Self { + const enc_ctx = BlockCipher.initEnc(key); + const dec_ctx = DecryptCtx.initFromEnc(enc_ctx); + + return Self{ .enc_ctx = enc_ctx, .dec_ctx = dec_ctx }; + } + + /// Return the length of the ciphertext given the length of the plaintext. + pub fn paddedLength(length: usize) usize { + return (std.math.divCeil(usize, length + 1, EncryptCtx.block_length) catch unreachable) * EncryptCtx.block_length; + } + + /// Encrypt the given plaintext for the given IV. + /// The destination buffer must be large enough to hold the padded plaintext. + /// Use the `paddedLength()` function to compute the ciphertext size. + /// IV must be secret and unpredictable. + pub fn encrypt(self: Self, dst: []u8, src: []const u8, iv: [EncryptCtx.block_length]u8) void { + // Note: encryption *could* be parallelized, see https://research.kudelskisecurity.com/2022/11/17/some-aes-cbc-encryption-myth-busting/ + const block_length = EncryptCtx.block_length; + const padded_length = paddedLength(src.len); + debug.assert(dst.len == padded_length); // destination buffer must hold the padded plaintext + var cv = iv; + var i: usize = 0; + while (i + block_length <= src.len) : (i += block_length) { + const in = src[i..][0..block_length]; + for (cv[0..], in) |*x, y| x.* ^= y; + self.enc_ctx.encrypt(&cv, &cv); + @memcpy(dst[i..][0..block_length], &cv); + } + // Last block + var in = [_]u8{0} ** block_length; + const padding_length: u8 = @intCast(padded_length - src.len - 1); + @memset(&in, padding_length); + @memcpy(in[0 .. src.len - i], src[i..]); + for (cv[0..], in) |*x, y| x.* ^= y; + self.enc_ctx.encrypt(&cv, &cv); + @memcpy(dst[i..], cv[0 .. dst.len - i]); + } + + /// Decrypt the given ciphertext for the given IV. + /// The destination buffer must be large enough to hold the plaintext. + /// IV must be secret, unpredictable and match the one used for encryption. + pub fn decrypt(self: Self, dst: []u8, src: []const u8, iv: [DecryptCtx.block_length]u8) !void { + const block_length = DecryptCtx.block_length; + if (src.len != dst.len) { + return error.EncodingError; + } + debug.assert(src.len % block_length == 0); + var i: usize = 0; + var cv = iv; + var out: [block_length]u8 = undefined; + // Decryption could be parallelized + while (i + block_length <= dst.len) : (i += block_length) { + const in = src[i..][0..block_length]; + self.dec_ctx.decrypt(&out, in); + for (&out, cv) |*x, y| x.* ^= y; + cv = in.*; + @memcpy(dst[i..][0..block_length], &out); + } + // Last block - We intentionally don't check the padding to mitigate timing attacks + if (i < dst.len) { + const in = src[i..][0..block_length]; + @memset(&out, 0); + self.dec_ctx.decrypt(&out, in); + for (&out, cv) |*x, y| x.* ^= y; + @memcpy(dst[i..], out[0 .. dst.len - i]); + } + } + }; +} + +test "CBC mode" { + const M = CBC(aes.Aes128); + const key = [_]u8{ 0x2b, 0x7e, 0x15, 0x16, 0x28, 0xae, 0xd2, 0xa6, 0xab, 0xf7, 0x15, 0x88, 0x09, 0xcf, 0x4f, 0x3c }; + const iv = [_]u8{ 0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0a, 0x0b, 0x0c, 0x0d, 0x0e, 0x0f }; + const src_ = "This is a test of AES-CBC that goes on longer than a couple blocks. It is a somewhat long test case to type out!"; + const expected = "\xA0\x8C\x09\x7D\xFF\x42\xB6\x65\x4D\x4B\xC6\x90\x90\x39\xDE\x3D\xC7\xCA\xEB\xF6\x9A\x4F\x09\x97\xC9\x32\xAB\x75\x88\xB7\x57\x17"; + var res: [32]u8 = undefined; + + try comptime std.testing.expect(src_.len / M.paddedLength(1) >= 3); // Ensure that we have at least 3 blocks + + const z = M.init(key); + + // Test encryption and decryption with distinct buffers + var h = std.crypto.hash.sha2.Sha256.init(.{}); + inline for (0..src_.len) |len| { + const src = src_[0..len]; + var dst = [_]u8{0} ** M.paddedLength(src.len); + + z.encrypt(&dst, src, iv); + h.update(&dst); + + var decrypted = [_]u8{0} ** dst.len; + try z.decrypt(&decrypted, &dst, iv); + + const padding = decrypted[decrypted.len - 1] + 1; + try std.testing.expectEqualSlices(u8, src, decrypted[0 .. decrypted.len - padding]); + } + h.final(&res); + try std.testing.expectEqualSlices(u8, expected, &res); + + // Test encryption and decryption with the same buffer + h = std.crypto.hash.sha2.Sha256.init(.{}); + inline for (0..src_.len) |len| { + var buf = [_]u8{0} ** M.paddedLength(len); + @memcpy(buf[0..len], src_[0..len]); + z.encrypt(&buf, buf[0..len], iv); + h.update(&buf); + + try z.decrypt(&buf, &buf, iv); + + try std.testing.expectEqualSlices(u8, src_[0..len], buf[0..len]); + } + h.final(&res); + try std.testing.expectEqualSlices(u8, expected, &res); +} diff --git a/src/http/async/tls.zig/cipher.zig b/src/http/async/tls.zig/cipher.zig new file mode 100644 index 00000000..dbf4a07a --- /dev/null +++ b/src/http/async/tls.zig/cipher.zig @@ -0,0 +1,1004 @@ +const std = @import("std"); +const crypto = std.crypto; +const hkdfExpandLabel = crypto.tls.hkdfExpandLabel; + +const Sha1 = crypto.hash.Sha1; +const Sha256 = crypto.hash.sha2.Sha256; +const Sha384 = crypto.hash.sha2.Sha384; + +const record = @import("record.zig"); +const Record = record.Record; +const Transcript = @import("transcript.zig").Transcript; +const proto = @import("protocol.zig"); + +// tls 1.2 cbc cipher types +const CbcAes128Sha1 = CbcType(crypto.core.aes.Aes128, Sha1); +const CbcAes128Sha256 = CbcType(crypto.core.aes.Aes128, Sha256); +const CbcAes256Sha256 = CbcType(crypto.core.aes.Aes256, Sha256); +const CbcAes256Sha384 = CbcType(crypto.core.aes.Aes256, Sha384); +// tls 1.2 gcm cipher types +const Aead12Aes128Gcm = Aead12Type(crypto.aead.aes_gcm.Aes128Gcm); +const Aead12Aes256Gcm = Aead12Type(crypto.aead.aes_gcm.Aes256Gcm); +// tls 1.2 chacha cipher type +const Aead12ChaCha = Aead12ChaChaType(crypto.aead.chacha_poly.ChaCha20Poly1305); +// tls 1.3 cipher types +const Aes128GcmSha256 = Aead13Type(crypto.aead.aes_gcm.Aes128Gcm, Sha256); +const Aes256GcmSha384 = Aead13Type(crypto.aead.aes_gcm.Aes256Gcm, Sha384); +const ChaChaSha256 = Aead13Type(crypto.aead.chacha_poly.ChaCha20Poly1305, Sha256); +const Aegis128Sha256 = Aead13Type(crypto.aead.aegis.Aegis128L, Sha256); + +pub const encrypt_overhead_tls_12: comptime_int = @max( + CbcAes128Sha1.encrypt_overhead, + CbcAes128Sha256.encrypt_overhead, + CbcAes256Sha256.encrypt_overhead, + CbcAes256Sha384.encrypt_overhead, + Aead12Aes128Gcm.encrypt_overhead, + Aead12Aes256Gcm.encrypt_overhead, + Aead12ChaCha.encrypt_overhead, +); +pub const encrypt_overhead_tls_13: comptime_int = @max( + Aes128GcmSha256.encrypt_overhead, + Aes256GcmSha384.encrypt_overhead, + ChaChaSha256.encrypt_overhead, + Aegis128Sha256.encrypt_overhead, +); + +// ref (length): https://www.rfc-editor.org/rfc/rfc8446#section-5.1 +pub const max_cleartext_len = 1 << 14; +// ref (length): https://www.rfc-editor.org/rfc/rfc8446#section-5.2 +// The sum of the lengths of the content and the padding, plus one for the inner +// content type, plus any expansion added by the AEAD algorithm. +pub const max_ciphertext_len = max_cleartext_len + 256; +pub const max_ciphertext_record_len = record.header_len + max_ciphertext_len; + +/// Returns type for cipher suite tag. +fn CipherType(comptime tag: CipherSuite) type { + return switch (tag) { + // tls 1.2 cbc + .ECDHE_ECDSA_WITH_AES_128_CBC_SHA, + .ECDHE_RSA_WITH_AES_128_CBC_SHA, + .RSA_WITH_AES_128_CBC_SHA, + => CbcAes128Sha1, + .ECDHE_ECDSA_WITH_AES_128_CBC_SHA256, + .ECDHE_RSA_WITH_AES_128_CBC_SHA256, + .RSA_WITH_AES_128_CBC_SHA256, + => CbcAes128Sha256, + .ECDHE_ECDSA_WITH_AES_256_CBC_SHA384, + .ECDHE_RSA_WITH_AES_256_CBC_SHA384, + => CbcAes256Sha384, + + // tls 1.2 gcm + .ECDHE_ECDSA_WITH_AES_128_GCM_SHA256, + .ECDHE_RSA_WITH_AES_128_GCM_SHA256, + => Aead12Aes128Gcm, + .ECDHE_ECDSA_WITH_AES_256_GCM_SHA384, + .ECDHE_RSA_WITH_AES_256_GCM_SHA384, + => Aead12Aes256Gcm, + + // tls 1.2 chacha + .ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256, + .ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256, + => Aead12ChaCha, + + // tls 1.3 + .AES_128_GCM_SHA256 => Aes128GcmSha256, + .AES_256_GCM_SHA384 => Aes256GcmSha384, + .CHACHA20_POLY1305_SHA256 => ChaChaSha256, + .AEGIS_128L_SHA256 => Aegis128Sha256, + + else => unreachable, + }; +} + +/// Provides initialization and common encrypt/decrypt methods for all supported +/// ciphers. Tls 1.2 has only application cipher, tls 1.3 has separate cipher +/// for handshake and application. +pub const Cipher = union(CipherSuite) { + // tls 1.2 cbc + ECDHE_ECDSA_WITH_AES_128_CBC_SHA: CipherType(.ECDHE_ECDSA_WITH_AES_128_CBC_SHA), + ECDHE_RSA_WITH_AES_128_CBC_SHA: CipherType(.ECDHE_RSA_WITH_AES_128_CBC_SHA), + RSA_WITH_AES_128_CBC_SHA: CipherType(.RSA_WITH_AES_128_CBC_SHA), + + ECDHE_ECDSA_WITH_AES_128_CBC_SHA256: CipherType(.ECDHE_ECDSA_WITH_AES_128_CBC_SHA256), + ECDHE_RSA_WITH_AES_128_CBC_SHA256: CipherType(.ECDHE_RSA_WITH_AES_128_CBC_SHA256), + RSA_WITH_AES_128_CBC_SHA256: CipherType(.RSA_WITH_AES_128_CBC_SHA256), + + ECDHE_ECDSA_WITH_AES_256_CBC_SHA384: CipherType(.ECDHE_ECDSA_WITH_AES_256_CBC_SHA384), + ECDHE_RSA_WITH_AES_256_CBC_SHA384: CipherType(.ECDHE_RSA_WITH_AES_256_CBC_SHA384), + // tls 1.2 gcm + ECDHE_ECDSA_WITH_AES_128_GCM_SHA256: CipherType(.ECDHE_ECDSA_WITH_AES_128_GCM_SHA256), + ECDHE_ECDSA_WITH_AES_256_GCM_SHA384: CipherType(.ECDHE_RSA_WITH_AES_256_GCM_SHA384), + ECDHE_RSA_WITH_AES_128_GCM_SHA256: CipherType(.ECDHE_RSA_WITH_AES_128_GCM_SHA256), + ECDHE_RSA_WITH_AES_256_GCM_SHA384: CipherType(.ECDHE_RSA_WITH_AES_256_GCM_SHA384), + // tls 1.2 chacha + ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256: CipherType(.ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256), + ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256: CipherType(.ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256), + // tls 1.3 + AES_128_GCM_SHA256: CipherType(.AES_128_GCM_SHA256), + AES_256_GCM_SHA384: CipherType(.AES_256_GCM_SHA384), + CHACHA20_POLY1305_SHA256: CipherType(.CHACHA20_POLY1305_SHA256), + AEGIS_128L_SHA256: CipherType(.AEGIS_128L_SHA256), + + // tls 1.2 application cipher + pub fn initTls12(tag: CipherSuite, key_material: []const u8, side: proto.Side) !Cipher { + switch (tag) { + inline .ECDHE_ECDSA_WITH_AES_128_CBC_SHA, + .ECDHE_RSA_WITH_AES_128_CBC_SHA, + .RSA_WITH_AES_128_CBC_SHA, + .ECDHE_ECDSA_WITH_AES_128_CBC_SHA256, + .ECDHE_RSA_WITH_AES_128_CBC_SHA256, + .RSA_WITH_AES_128_CBC_SHA256, + .ECDHE_ECDSA_WITH_AES_256_CBC_SHA384, + .ECDHE_RSA_WITH_AES_256_CBC_SHA384, + .ECDHE_ECDSA_WITH_AES_128_GCM_SHA256, + .ECDHE_ECDSA_WITH_AES_256_GCM_SHA384, + .ECDHE_RSA_WITH_AES_128_GCM_SHA256, + .ECDHE_RSA_WITH_AES_256_GCM_SHA384, + .ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256, + .ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256, + => |comptime_tag| { + return @unionInit(Cipher, @tagName(comptime_tag), CipherType(comptime_tag).init(key_material, side)); + }, + else => return error.TlsIllegalParameter, + } + } + + // tls 1.3 handshake or application cipher + pub fn initTls13(tag: CipherSuite, secret: Transcript.Secret, side: proto.Side) !Cipher { + return switch (tag) { + inline .AES_128_GCM_SHA256, + .AES_256_GCM_SHA384, + .CHACHA20_POLY1305_SHA256, + .AEGIS_128L_SHA256, + => |comptime_tag| { + return @unionInit(Cipher, @tagName(comptime_tag), CipherType(comptime_tag).init(secret, side)); + }, + else => return error.TlsIllegalParameter, + }; + } + + pub fn encrypt( + c: *Cipher, + buf: []u8, + content_type: proto.ContentType, + cleartext: []const u8, + ) ![]const u8 { + return switch (c.*) { + inline else => |*f| try f.encrypt(buf, content_type, cleartext), + }; + } + + pub fn decrypt( + c: *Cipher, + buf: []u8, + rec: Record, + ) !struct { proto.ContentType, []u8 } { + return switch (c.*) { + inline else => |*f| { + const content_type, const cleartext = try f.decrypt(buf, rec); + if (cleartext.len > max_cleartext_len) return error.TlsRecordOverflow; + return .{ content_type, cleartext }; + }, + }; + } + + pub fn encryptSeq(c: Cipher) u64 { + return switch (c) { + inline else => |f| f.encrypt_seq, + }; + } + + pub fn keyUpdateEncrypt(c: *Cipher) !void { + return switch (c.*) { + inline .AES_128_GCM_SHA256, + .AES_256_GCM_SHA384, + .CHACHA20_POLY1305_SHA256, + .AEGIS_128L_SHA256, + => |*f| f.keyUpdateEncrypt(), + // can't happen on tls 1.2 + else => return error.TlsUnexpectedMessage, + }; + } + pub fn keyUpdateDecrypt(c: *Cipher) !void { + return switch (c.*) { + inline .AES_128_GCM_SHA256, + .AES_256_GCM_SHA384, + .CHACHA20_POLY1305_SHA256, + .AEGIS_128L_SHA256, + => |*f| f.keyUpdateDecrypt(), + // can't happen on tls 1.2 + else => return error.TlsUnexpectedMessage, + }; + } +}; + +fn Aead12Type(comptime AeadType: type) type { + return struct { + const explicit_iv_len = 8; + const key_len = AeadType.key_length; + const auth_tag_len = AeadType.tag_length; + const nonce_len = AeadType.nonce_length; + const iv_len = AeadType.nonce_length - explicit_iv_len; + const encrypt_overhead = record.header_len + explicit_iv_len + auth_tag_len; + + encrypt_key: [key_len]u8, + decrypt_key: [key_len]u8, + encrypt_iv: [iv_len]u8, + decrypt_iv: [iv_len]u8, + encrypt_seq: u64 = 0, + decrypt_seq: u64 = 0, + rnd: std.Random = crypto.random, + + const Self = @This(); + + fn init(key_material: []const u8, side: proto.Side) Self { + const client_key = key_material[0..key_len].*; + const server_key = key_material[key_len..][0..key_len].*; + const client_iv = key_material[2 * key_len ..][0..iv_len].*; + const server_iv = key_material[2 * key_len + iv_len ..][0..iv_len].*; + return .{ + .encrypt_key = if (side == .client) client_key else server_key, + .decrypt_key = if (side == .client) server_key else client_key, + .encrypt_iv = if (side == .client) client_iv else server_iv, + .decrypt_iv = if (side == .client) server_iv else client_iv, + }; + } + + /// Returns encrypted tls record in format: + /// ----------------- buf ---------------------- + /// header | explicit_iv | ciphertext | auth_tag + /// + /// tls record header: 5 bytes + /// explicit_iv: 8 bytes + /// ciphertext: same length as cleartext + /// auth_tag: 16 bytes + pub fn encrypt( + self: *Self, + buf: []u8, + content_type: proto.ContentType, + cleartext: []const u8, + ) ![]const u8 { + const record_len = record.header_len + explicit_iv_len + cleartext.len + auth_tag_len; + if (buf.len < record_len) return error.BufferOverflow; + + const header = buf[0..record.header_len]; + const explicit_iv = buf[record.header_len..][0..explicit_iv_len]; + const ciphertext = buf[record.header_len + explicit_iv_len ..][0..cleartext.len]; + const auth_tag = buf[record.header_len + explicit_iv_len + cleartext.len ..][0..auth_tag_len]; + + header.* = record.header(content_type, explicit_iv_len + cleartext.len + auth_tag_len); + self.rnd.bytes(explicit_iv); + const iv = self.encrypt_iv ++ explicit_iv.*; + const ad = additionalData(self.encrypt_seq, content_type, cleartext.len); + AeadType.encrypt(ciphertext, auth_tag, cleartext, &ad, iv, self.encrypt_key); + self.encrypt_seq +%= 1; + + return buf[0..record_len]; + } + + /// Decrypts payload into cleartext. Returns tls record content type and + /// cleartext. + /// Accepts tls record header and payload: + /// header | ----------- payload --------------- + /// header | explicit_iv | ciphertext | auth_tag + pub fn decrypt( + self: *Self, + buf: []u8, + rec: Record, + ) !struct { proto.ContentType, []u8 } { + const overhead = explicit_iv_len + auth_tag_len; + if (rec.payload.len < overhead) return error.TlsDecryptError; + const cleartext_len = rec.payload.len - overhead; + if (buf.len < cleartext_len) return error.BufferOverflow; + + const explicit_iv = rec.payload[0..explicit_iv_len]; + const ciphertext = rec.payload[explicit_iv_len..][0..cleartext_len]; + const auth_tag = rec.payload[explicit_iv_len + cleartext_len ..][0..auth_tag_len]; + + const iv = self.decrypt_iv ++ explicit_iv.*; + const ad = additionalData(self.decrypt_seq, rec.content_type, cleartext_len); + const cleartext = buf[0..cleartext_len]; + AeadType.decrypt(cleartext, ciphertext, auth_tag.*, &ad, iv, self.decrypt_key) catch return error.TlsDecryptError; + self.decrypt_seq +%= 1; + return .{ rec.content_type, cleartext }; + } + }; +} + +fn Aead12ChaChaType(comptime AeadType: type) type { + return struct { + const key_len = AeadType.key_length; + const auth_tag_len = AeadType.tag_length; + const nonce_len = AeadType.nonce_length; + const encrypt_overhead = record.header_len + auth_tag_len; + + encrypt_key: [key_len]u8, + decrypt_key: [key_len]u8, + encrypt_iv: [nonce_len]u8, + decrypt_iv: [nonce_len]u8, + encrypt_seq: u64 = 0, + decrypt_seq: u64 = 0, + + const Self = @This(); + + fn init(key_material: []const u8, side: proto.Side) Self { + const client_key = key_material[0..key_len].*; + const server_key = key_material[key_len..][0..key_len].*; + const client_iv = key_material[2 * key_len ..][0..nonce_len].*; + const server_iv = key_material[2 * key_len + nonce_len ..][0..nonce_len].*; + return .{ + .encrypt_key = if (side == .client) client_key else server_key, + .decrypt_key = if (side == .client) server_key else client_key, + .encrypt_iv = if (side == .client) client_iv else server_iv, + .decrypt_iv = if (side == .client) server_iv else client_iv, + }; + } + + /// Returns encrypted tls record in format: + /// ------------ buf ------------- + /// header | ciphertext | auth_tag + /// + /// tls record header: 5 bytes + /// ciphertext: same length as cleartext + /// auth_tag: 16 bytes + pub fn encrypt( + self: *Self, + buf: []u8, + content_type: proto.ContentType, + cleartext: []const u8, + ) ![]const u8 { + const record_len = record.header_len + cleartext.len + auth_tag_len; + if (buf.len < record_len) return error.BufferOverflow; + + const ciphertext = buf[record.header_len..][0..cleartext.len]; + const auth_tag = buf[record.header_len + ciphertext.len ..][0..auth_tag_len]; + + const ad = additionalData(self.encrypt_seq, content_type, cleartext.len); + const iv = ivWithSeq(nonce_len, self.encrypt_iv, self.encrypt_seq); + AeadType.encrypt(ciphertext, auth_tag, cleartext, &ad, iv, self.encrypt_key); + self.encrypt_seq +%= 1; + + buf[0..record.header_len].* = record.header(content_type, ciphertext.len + auth_tag.len); + return buf[0..record_len]; + } + + /// Decrypts payload into cleartext. Returns tls record content type and + /// cleartext. + /// Accepts tls record header and payload: + /// header | ----- payload ------- + /// header | ciphertext | auth_tag + pub fn decrypt( + self: *Self, + buf: []u8, + rec: Record, + ) !struct { proto.ContentType, []u8 } { + const overhead = auth_tag_len; + if (rec.payload.len < overhead) return error.TlsDecryptError; + const cleartext_len = rec.payload.len - overhead; + if (buf.len < cleartext_len) return error.BufferOverflow; + + const ciphertext = rec.payload[0..cleartext_len]; + const auth_tag = rec.payload[cleartext_len..][0..auth_tag_len]; + const cleartext = buf[0..cleartext_len]; + + const ad = additionalData(self.decrypt_seq, rec.content_type, cleartext_len); + const iv = ivWithSeq(nonce_len, self.decrypt_iv, self.decrypt_seq); + AeadType.decrypt(cleartext, ciphertext, auth_tag.*, &ad, iv, self.decrypt_key) catch return error.TlsDecryptError; + self.decrypt_seq +%= 1; + return .{ rec.content_type, cleartext }; + } + }; +} + +fn Aead13Type(comptime AeadType: type, comptime Hash: type) type { + return struct { + const Hmac = crypto.auth.hmac.Hmac(Hash); + const Hkdf = crypto.kdf.hkdf.Hkdf(Hmac); + + const key_len = AeadType.key_length; + const auth_tag_len = AeadType.tag_length; + const nonce_len = AeadType.nonce_length; + const digest_len = Hash.digest_length; + const encrypt_overhead = record.header_len + 1 + auth_tag_len; + + encrypt_secret: [digest_len]u8, + decrypt_secret: [digest_len]u8, + encrypt_key: [key_len]u8, + decrypt_key: [key_len]u8, + encrypt_iv: [nonce_len]u8, + decrypt_iv: [nonce_len]u8, + encrypt_seq: u64 = 0, + decrypt_seq: u64 = 0, + + const Self = @This(); + + pub fn init(secret: Transcript.Secret, side: proto.Side) Self { + var self = Self{ + .encrypt_secret = if (side == .client) secret.client[0..digest_len].* else secret.server[0..digest_len].*, + .decrypt_secret = if (side == .server) secret.client[0..digest_len].* else secret.server[0..digest_len].*, + .encrypt_key = undefined, + .decrypt_key = undefined, + .encrypt_iv = undefined, + .decrypt_iv = undefined, + }; + self.keyGenerate(); + return self; + } + + fn keyGenerate(self: *Self) void { + self.encrypt_key = hkdfExpandLabel(Hkdf, self.encrypt_secret, "key", "", key_len); + self.decrypt_key = hkdfExpandLabel(Hkdf, self.decrypt_secret, "key", "", key_len); + self.encrypt_iv = hkdfExpandLabel(Hkdf, self.encrypt_secret, "iv", "", nonce_len); + self.decrypt_iv = hkdfExpandLabel(Hkdf, self.decrypt_secret, "iv", "", nonce_len); + } + + pub fn keyUpdateEncrypt(self: *Self) void { + self.encrypt_secret = hkdfExpandLabel(Hkdf, self.encrypt_secret, "traffic upd", "", digest_len); + self.encrypt_seq = 0; + self.keyGenerate(); + } + + pub fn keyUpdateDecrypt(self: *Self) void { + self.decrypt_secret = hkdfExpandLabel(Hkdf, self.decrypt_secret, "traffic upd", "", digest_len); + self.decrypt_seq = 0; + self.keyGenerate(); + } + + /// Returns encrypted tls record in format: + /// ------------ buf ------------- + /// header | ciphertext | auth_tag + /// + /// tls record header: 5 bytes + /// ciphertext: cleartext len + 1 byte content type + /// auth_tag: 16 bytes + pub fn encrypt( + self: *Self, + buf: []u8, + content_type: proto.ContentType, + cleartext: []const u8, + ) ![]const u8 { + const payload_len = cleartext.len + 1 + auth_tag_len; + const record_len = record.header_len + payload_len; + if (buf.len < record_len) return error.BufferOverflow; + + const header = buf[0..record.header_len]; + header.* = record.header(.application_data, payload_len); + + // Skip @memcpy if cleartext is already part of the buf at right position + if (&cleartext[0] != &buf[record.header_len]) { + @memcpy(buf[record.header_len..][0..cleartext.len], cleartext); + } + buf[record.header_len + cleartext.len] = @intFromEnum(content_type); + const ciphertext = buf[record.header_len..][0 .. cleartext.len + 1]; + const auth_tag = buf[record.header_len + ciphertext.len ..][0..auth_tag_len]; + + const iv = ivWithSeq(nonce_len, self.encrypt_iv, self.encrypt_seq); + AeadType.encrypt(ciphertext, auth_tag, ciphertext, header, iv, self.encrypt_key); + self.encrypt_seq +%= 1; + return buf[0..record_len]; + } + + /// Decrypts payload into cleartext. Returns tls record content type and + /// cleartext. + /// Accepts tls record header and payload: + /// header | ------- payload --------- + /// header | ciphertext | auth_tag + /// header | cleartext + ct | auth_tag + /// Ciphertext after decryption contains cleartext and content type (1 byte). + pub fn decrypt( + self: *Self, + buf: []u8, + rec: Record, + ) !struct { proto.ContentType, []u8 } { + const overhead = auth_tag_len + 1; + if (rec.payload.len < overhead) return error.TlsDecryptError; + const ciphertext_len = rec.payload.len - auth_tag_len; + if (buf.len < ciphertext_len) return error.BufferOverflow; + + const ciphertext = rec.payload[0..ciphertext_len]; + const auth_tag = rec.payload[ciphertext_len..][0..auth_tag_len]; + + const iv = ivWithSeq(nonce_len, self.decrypt_iv, self.decrypt_seq); + AeadType.decrypt(buf[0..ciphertext_len], ciphertext, auth_tag.*, rec.header, iv, self.decrypt_key) catch return error.TlsBadRecordMac; + + // Remove zero bytes padding + var content_type_idx: usize = ciphertext_len - 1; + while (buf[content_type_idx] == 0 and content_type_idx > 0) : (content_type_idx -= 1) {} + + const cleartext = buf[0..content_type_idx]; + const content_type: proto.ContentType = @enumFromInt(buf[content_type_idx]); + self.decrypt_seq +%= 1; + return .{ content_type, cleartext }; + } + }; +} + +fn CbcType(comptime BlockCipher: type, comptime HashType: type) type { + const CBC = @import("cbc/main.zig").CBC(BlockCipher); + return struct { + const mac_len = HashType.digest_length; // 20, 32, 48 bytes for sha1, sha256, sha384 + const key_len = BlockCipher.key_bits / 8; // 16, 32 for Aes128, Aes256 + const iv_len = 16; + const encrypt_overhead = record.header_len + iv_len + mac_len + max_padding; + + pub const Hmac = crypto.auth.hmac.Hmac(HashType); + const paddedLength = CBC.paddedLength; + const max_padding = 16; + + encrypt_secret: [mac_len]u8, + decrypt_secret: [mac_len]u8, + encrypt_key: [key_len]u8, + decrypt_key: [key_len]u8, + encrypt_seq: u64 = 0, + decrypt_seq: u64 = 0, + rnd: std.Random = crypto.random, + + const Self = @This(); + + fn init(key_material: []const u8, side: proto.Side) Self { + const client_secret = key_material[0..mac_len].*; + const server_secret = key_material[mac_len..][0..mac_len].*; + const client_key = key_material[2 * mac_len ..][0..key_len].*; + const server_key = key_material[2 * mac_len + key_len ..][0..key_len].*; + return .{ + .encrypt_secret = if (side == .client) client_secret else server_secret, + .decrypt_secret = if (side == .client) server_secret else client_secret, + .encrypt_key = if (side == .client) client_key else server_key, + .decrypt_key = if (side == .client) server_key else client_key, + }; + } + + /// Returns encrypted tls record in format: + /// ----------------- buf ----------------- + /// header | iv | ------ ciphertext ------- + /// header | iv | cleartext | mac | padding + /// + /// tls record header: 5 bytes + /// iv: 16 bytes + /// ciphertext: cleartext length + mac + padding + /// mac: 20, 32 or 48 (sha1, sha256, sha384) + /// padding: 1-16 bytes + /// + /// Max encrypt buf overhead = iv + mac + padding (1-16) + /// aes_128_cbc_sha => 16 + 20 + 16 = 52 + /// aes_128_cbc_sha256 => 16 + 32 + 16 = 64 + /// aes_256_cbc_sha384 => 16 + 48 + 16 = 80 + pub fn encrypt( + self: *Self, + buf: []u8, + content_type: proto.ContentType, + cleartext: []const u8, + ) ![]const u8 { + const max_record_len = record.header_len + iv_len + cleartext.len + mac_len + max_padding; + if (buf.len < max_record_len) return error.BufferOverflow; + const cleartext_idx = record.header_len + iv_len; // position of cleartext in buf + @memcpy(buf[cleartext_idx..][0..cleartext.len], cleartext); + + { // calculate mac from (ad + cleartext) + // ... | ad | cleartext | mac | ... + // | -- mac msg -- | mac | + const ad = additionalData(self.encrypt_seq, content_type, cleartext.len); + const mac_msg = buf[cleartext_idx - ad.len ..][0 .. ad.len + cleartext.len]; + @memcpy(mac_msg[0..ad.len], &ad); + const mac = buf[cleartext_idx + cleartext.len ..][0..mac_len]; + Hmac.create(mac, mac_msg, &self.encrypt_secret); + self.encrypt_seq +%= 1; + } + + // ... | cleartext | mac | ... + // ... | -- plaintext --- ... + // ... | cleartext | mac | padding + // ... | ------ ciphertext ------- + const unpadded_len = cleartext.len + mac_len; + const padded_len = paddedLength(unpadded_len); + const plaintext = buf[cleartext_idx..][0..unpadded_len]; + const ciphertext = buf[cleartext_idx..][0..padded_len]; + + // Add header and iv at the buf start + // header | iv | ... + buf[0..record.header_len].* = record.header(content_type, iv_len + ciphertext.len); + const iv = buf[record.header_len..][0..iv_len]; + self.rnd.bytes(iv); + + // encrypt plaintext into ciphertext + CBC.init(self.encrypt_key).encrypt(ciphertext, plaintext, iv[0..iv_len].*); + + // header | iv | ------ ciphertext ------- + return buf[0 .. record.header_len + iv_len + ciphertext.len]; + } + + /// Decrypts payload into cleartext. Returns tls record content type and + /// cleartext. + pub fn decrypt( + self: *Self, + buf: []u8, + rec: Record, + ) !struct { proto.ContentType, []u8 } { + if (rec.payload.len < iv_len + mac_len + 1) return error.TlsDecryptError; + + // --------- payload ------------ + // iv | ------ ciphertext ------- + // iv | cleartext | mac | padding + const iv = rec.payload[0..iv_len]; + const ciphertext = rec.payload[iv_len..]; + + if (buf.len < ciphertext.len + additional_data_len) return error.BufferOverflow; + // ---------- buf --------------- + // ad | ------ plaintext -------- + // ad | cleartext | mac | padding + const plaintext = buf[additional_data_len..][0..ciphertext.len]; + // decrypt ciphertext -> plaintext + CBC.init(self.decrypt_key).decrypt(plaintext, ciphertext, iv[0..iv_len].*) catch return error.TlsDecryptError; + + // get padding len from last padding byte + const padding_len = plaintext[plaintext.len - 1] + 1; + if (plaintext.len < mac_len + padding_len) return error.TlsDecryptError; + // split plaintext into cleartext and mac + const cleartext_len = plaintext.len - mac_len - padding_len; + const cleartext = plaintext[0..cleartext_len]; + const mac = plaintext[cleartext_len..][0..mac_len]; + + // write ad to the buf + var ad = additionalData(self.decrypt_seq, rec.content_type, cleartext_len); + @memcpy(buf[0..ad.len], &ad); + const mac_msg = buf[0 .. ad.len + cleartext_len]; + self.decrypt_seq +%= 1; + + // calculate expected mac and compare with received mac + var expected_mac: [mac_len]u8 = undefined; + Hmac.create(&expected_mac, mac_msg, &self.decrypt_secret); + if (!std.mem.eql(u8, &expected_mac, mac)) + return error.TlsBadRecordMac; + + return .{ rec.content_type, cleartext }; + } + }; +} + +// xor lower 8 iv bytes with seq +fn ivWithSeq(comptime nonce_len: usize, iv: [nonce_len]u8, seq: u64) [nonce_len]u8 { + var res = iv; + const buf = res[nonce_len - 8 ..]; + const operand = std.mem.readInt(u64, buf, .big); + std.mem.writeInt(u64, buf, operand ^ seq, .big); + return res; +} + +pub const additional_data_len = record.header_len + @sizeOf(u64); + +fn additionalData(seq: u64, content_type: proto.ContentType, payload_len: usize) [additional_data_len]u8 { + const header = record.header(content_type, payload_len); + var seq_buf: [8]u8 = undefined; + std.mem.writeInt(u64, &seq_buf, seq, .big); + return seq_buf ++ header; +} + +// Cipher suites lists. In the order of preference. +// For the preference using grades priority and rules from Go project. +// https://ciphersuite.info/page/faq/ +// https://github.com/golang/go/blob/73186ba00251b3ed8baaab36e4f5278c7681155b/src/crypto/tls/cipher_suites.go#L226 +pub const cipher_suites = struct { + const tls12_secure = if (crypto.core.aes.has_hardware_support) [_]CipherSuite{ + // recommended + .ECDHE_ECDSA_WITH_AES_128_GCM_SHA256, + .ECDHE_ECDSA_WITH_AES_256_GCM_SHA384, + .ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256, + // secure + .ECDHE_RSA_WITH_AES_128_GCM_SHA256, + .ECDHE_RSA_WITH_AES_256_GCM_SHA384, + .ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256, + } else [_]CipherSuite{ + // recommended + .ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256, + .ECDHE_ECDSA_WITH_AES_128_GCM_SHA256, + .ECDHE_ECDSA_WITH_AES_256_GCM_SHA384, + + // secure + .ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256, + .ECDHE_RSA_WITH_AES_128_GCM_SHA256, + .ECDHE_RSA_WITH_AES_256_GCM_SHA384, + }; + const tls12_week = [_]CipherSuite{ + // week + .ECDHE_ECDSA_WITH_AES_128_CBC_SHA256, + .ECDHE_ECDSA_WITH_AES_256_CBC_SHA384, + .ECDHE_ECDSA_WITH_AES_128_CBC_SHA, + .ECDHE_RSA_WITH_AES_128_CBC_SHA256, + .ECDHE_RSA_WITH_AES_256_CBC_SHA384, + .ECDHE_RSA_WITH_AES_128_CBC_SHA, + + .RSA_WITH_AES_128_CBC_SHA256, + .RSA_WITH_AES_128_CBC_SHA, + }; + pub const tls13_ = if (crypto.core.aes.has_hardware_support) [_]CipherSuite{ + .AES_128_GCM_SHA256, + .AES_256_GCM_SHA384, + .CHACHA20_POLY1305_SHA256, + // Excluded because didn't find server which supports it to test + // .AEGIS_128L_SHA256 + } else [_]CipherSuite{ + .CHACHA20_POLY1305_SHA256, + .AES_128_GCM_SHA256, + .AES_256_GCM_SHA384, + }; + + pub const tls13 = &tls13_; + pub const tls12 = &(tls12_secure ++ tls12_week); + pub const secure = &(tls13_ ++ tls12_secure); + pub const all = &(tls13_ ++ tls12_secure ++ tls12_week); + + pub fn includes(list: []const CipherSuite, cs: CipherSuite) bool { + for (list) |s| { + if (cs == s) return true; + } + return false; + } +}; + +// Week, secure, recommended grades are from https://ciphersuite.info/page/faq/ +pub const CipherSuite = enum(u16) { + // tls 1.2 cbc sha1 + ECDHE_ECDSA_WITH_AES_128_CBC_SHA = 0xc009, // week + ECDHE_RSA_WITH_AES_128_CBC_SHA = 0xc013, // week + RSA_WITH_AES_128_CBC_SHA = 0x002F, // week + // tls 1.2 cbc sha256 + ECDHE_ECDSA_WITH_AES_128_CBC_SHA256 = 0xc023, // week + ECDHE_RSA_WITH_AES_128_CBC_SHA256 = 0xc027, // week + RSA_WITH_AES_128_CBC_SHA256 = 0x003c, // week + // tls 1.2 cbc sha384 + ECDHE_ECDSA_WITH_AES_256_CBC_SHA384 = 0xc024, // week + ECDHE_RSA_WITH_AES_256_CBC_SHA384 = 0xc028, // week + // tls 1.2 gcm + ECDHE_ECDSA_WITH_AES_128_GCM_SHA256 = 0xc02b, // recommended + ECDHE_ECDSA_WITH_AES_256_GCM_SHA384 = 0xc02c, // recommended + ECDHE_RSA_WITH_AES_128_GCM_SHA256 = 0xc02f, // secure + ECDHE_RSA_WITH_AES_256_GCM_SHA384 = 0xc030, // secure + // tls 1.2 chacha + ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256 = 0xcca9, // recommended + ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256 = 0xcca8, // secure + // tls 1.3 (all are recommended) + AES_128_GCM_SHA256 = 0x1301, + AES_256_GCM_SHA384 = 0x1302, + CHACHA20_POLY1305_SHA256 = 0x1303, + AEGIS_128L_SHA256 = 0x1307, + // AEGIS_256_SHA512 = 0x1306, + _, + + pub fn validate(cs: CipherSuite) !void { + if (cipher_suites.includes(cipher_suites.tls12, cs)) return; + if (cipher_suites.includes(cipher_suites.tls13, cs)) return; + return error.TlsIllegalParameter; + } + + pub const Versions = enum { + both, + tls_1_3, + tls_1_2, + }; + + // get tls versions from list of cipher suites + pub fn versions(list: []const CipherSuite) !Versions { + var has_12 = false; + var has_13 = false; + for (list) |cs| { + if (cipher_suites.includes(cipher_suites.tls12, cs)) { + has_12 = true; + } else { + if (cipher_suites.includes(cipher_suites.tls13, cs)) has_13 = true; + } + } + if (has_12 and has_13) return .both; + if (has_12) return .tls_1_2; + if (has_13) return .tls_1_3; + return error.TlsIllegalParameter; + } + + pub const KeyExchangeAlgorithm = enum { + ecdhe, + rsa, + }; + + pub fn keyExchange(s: CipherSuite) KeyExchangeAlgorithm { + return switch (s) { + // Random premaster secret, encrypted with publich key from certificate. + // No server key exchange message. + .RSA_WITH_AES_128_CBC_SHA, + .RSA_WITH_AES_128_CBC_SHA256, + => .rsa, + else => .ecdhe, + }; + } + + pub const HashTag = enum { + sha256, + sha384, + sha512, + }; + + pub inline fn hash(cs: CipherSuite) HashTag { + return switch (cs) { + .ECDHE_RSA_WITH_AES_256_CBC_SHA384, + .ECDHE_RSA_WITH_AES_256_GCM_SHA384, + .ECDHE_ECDSA_WITH_AES_256_CBC_SHA384, + .ECDHE_ECDSA_WITH_AES_256_GCM_SHA384, + .AES_256_GCM_SHA384, + => .sha384, + else => .sha256, + }; + } +}; + +const testing = std.testing; +const testu = @import("testu.zig"); + +test "CipherSuite validate" { + { + const cs: CipherSuite = .AES_256_GCM_SHA384; + try cs.validate(); + try testing.expectEqual(cs.hash(), .sha384); + try testing.expectEqual(cs.keyExchange(), .ecdhe); + } + { + const cs: CipherSuite = .AES_128_GCM_SHA256; + try cs.validate(); + try testing.expectEqual(.sha256, cs.hash()); + try testing.expectEqual(.ecdhe, cs.keyExchange()); + } + for (cipher_suites.tls12) |cs| { + try cs.validate(); + _ = cs.hash(); + _ = cs.keyExchange(); + } +} + +test "CipherSuite versions" { + try testing.expectEqual(.tls_1_3, CipherSuite.versions(&[_]CipherSuite{.AES_128_GCM_SHA256})); + try testing.expectEqual(.both, CipherSuite.versions(&[_]CipherSuite{ .AES_128_GCM_SHA256, .ECDHE_ECDSA_WITH_AES_128_CBC_SHA })); + try testing.expectEqual(.tls_1_2, CipherSuite.versions(&[_]CipherSuite{.RSA_WITH_AES_128_CBC_SHA})); +} + +test "gcm 1.2 encrypt overhead" { + inline for ([_]type{ + Aead12Aes128Gcm, + Aead12Aes256Gcm, + }) |T| { + { + const expected_key_len = switch (T) { + Aead12Aes128Gcm => 16, + Aead12Aes256Gcm => 32, + else => unreachable, + }; + try testing.expectEqual(expected_key_len, T.key_len); + try testing.expectEqual(16, T.auth_tag_len); + try testing.expectEqual(12, T.nonce_len); + try testing.expectEqual(4, T.iv_len); + try testing.expectEqual(29, T.encrypt_overhead); + } + } +} + +test "cbc 1.2 encrypt overhead" { + try testing.expectEqual(85, encrypt_overhead_tls_12); + + inline for ([_]type{ + CbcAes128Sha1, + CbcAes128Sha256, + CbcAes256Sha384, + }) |T| { + switch (T) { + CbcAes128Sha1 => { + try testing.expectEqual(20, T.mac_len); + try testing.expectEqual(16, T.key_len); + try testing.expectEqual(57, T.encrypt_overhead); + }, + CbcAes128Sha256 => { + try testing.expectEqual(32, T.mac_len); + try testing.expectEqual(16, T.key_len); + try testing.expectEqual(69, T.encrypt_overhead); + }, + CbcAes256Sha384 => { + try testing.expectEqual(48, T.mac_len); + try testing.expectEqual(32, T.key_len); + try testing.expectEqual(85, T.encrypt_overhead); + }, + else => unreachable, + } + try testing.expectEqual(16, T.paddedLength(1)); // cbc block padding + try testing.expectEqual(16, T.iv_len); + } +} + +test "overhead tls 1.3" { + try testing.expectEqual(22, encrypt_overhead_tls_13); + try expectOverhead(Aes128GcmSha256, 16, 16, 12, 22); + try expectOverhead(Aes256GcmSha384, 32, 16, 12, 22); + try expectOverhead(ChaChaSha256, 32, 16, 12, 22); + try expectOverhead(Aegis128Sha256, 16, 16, 16, 22); + // and tls 1.2 chacha + try expectOverhead(Aead12ChaCha, 32, 16, 12, 21); +} + +fn expectOverhead(T: type, key_len: usize, auth_tag_len: usize, nonce_len: usize, overhead: usize) !void { + try testing.expectEqual(key_len, T.key_len); + try testing.expectEqual(auth_tag_len, T.auth_tag_len); + try testing.expectEqual(nonce_len, T.nonce_len); + try testing.expectEqual(overhead, T.encrypt_overhead); +} + +test "client/server encryption tls 1.3" { + inline for (cipher_suites.tls13) |cs| { + var buf: [256]u8 = undefined; + testu.fill(&buf); + const secret = Transcript.Secret{ + .client = buf[0..128], + .server = buf[128..], + }; + var client_cipher = try Cipher.initTls13(cs, secret, .client); + var server_cipher = try Cipher.initTls13(cs, secret, .server); + try encryptDecrypt(&client_cipher, &server_cipher); + + try client_cipher.keyUpdateEncrypt(); + try server_cipher.keyUpdateDecrypt(); + try encryptDecrypt(&client_cipher, &server_cipher); + + try client_cipher.keyUpdateDecrypt(); + try server_cipher.keyUpdateEncrypt(); + try encryptDecrypt(&client_cipher, &server_cipher); + } +} + +test "client/server encryption tls 1.2" { + inline for (cipher_suites.tls12) |cs| { + var key_material: [256]u8 = undefined; + testu.fill(&key_material); + var client_cipher = try Cipher.initTls12(cs, &key_material, .client); + var server_cipher = try Cipher.initTls12(cs, &key_material, .server); + try encryptDecrypt(&client_cipher, &server_cipher); + } +} + +fn encryptDecrypt(client_cipher: *Cipher, server_cipher: *Cipher) !void { + const cleartext = + \\ Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do + \\ eiusmod tempor incididunt ut labore et dolore magna aliqua. + ; + var buf: [256]u8 = undefined; + + { // client to server + // encrypt + const encrypted = try client_cipher.encrypt(&buf, .application_data, cleartext); + const expected_encrypted_len = switch (client_cipher.*) { + inline else => |f| brk: { + const T = @TypeOf(f); + break :brk switch (T) { + CbcAes128Sha1, + CbcAes128Sha256, + CbcAes256Sha256, + CbcAes256Sha384, + => record.header_len + T.paddedLength(T.iv_len + cleartext.len + T.mac_len), + Aead12Aes128Gcm, + Aead12Aes256Gcm, + Aead12ChaCha, + Aes128GcmSha256, + Aes256GcmSha384, + ChaChaSha256, + Aegis128Sha256, + => cleartext.len + T.encrypt_overhead, + else => unreachable, + }; + }, + }; + try testing.expectEqual(expected_encrypted_len, encrypted.len); + // decrypt + const content_type, const decrypted = try server_cipher.decrypt(&buf, Record.init(encrypted)); + try testing.expectEqualSlices(u8, cleartext, decrypted); + try testing.expectEqual(.application_data, content_type); + } + // server to client + { + const encrypted = try server_cipher.encrypt(&buf, .application_data, cleartext); + const content_type, const decrypted = try client_cipher.decrypt(&buf, Record.init(encrypted)); + try testing.expectEqualSlices(u8, cleartext, decrypted); + try testing.expectEqual(.application_data, content_type); + } +} diff --git a/src/http/async/tls.zig/connection.zig b/src/http/async/tls.zig/connection.zig new file mode 100644 index 00000000..9ccd9c53 --- /dev/null +++ b/src/http/async/tls.zig/connection.zig @@ -0,0 +1,665 @@ +const std = @import("std"); +const assert = std.debug.assert; + +const proto = @import("protocol.zig"); +const record = @import("record.zig"); +const cipher = @import("cipher.zig"); +const Cipher = cipher.Cipher; + +const async_io = @import("../std/http/Client.zig"); +const Cbk = async_io.Cbk; +const Ctx = async_io.Ctx; + +pub fn connection(stream: anytype) Connection(@TypeOf(stream)) { + return .{ + .stream = stream, + .rec_rdr = record.reader(stream), + }; +} + +pub fn Connection(comptime Stream: type) type { + return struct { + stream: Stream, // underlying stream + rec_rdr: record.Reader(Stream), + cipher: Cipher = undefined, + + max_encrypt_seq: u64 = std.math.maxInt(u64) - 1, + key_update_requested: bool = false, + + read_buf: []const u8 = "", + received_close_notify: bool = false, + + const Self = @This(); + + /// Encrypts and writes single tls record to the stream. + fn writeRecord(c: *Self, content_type: proto.ContentType, bytes: []const u8) !void { + assert(bytes.len <= cipher.max_cleartext_len); + var write_buf: [cipher.max_ciphertext_record_len]u8 = undefined; + // If key update is requested send key update message and update + // my encryption keys. + if (c.cipher.encryptSeq() >= c.max_encrypt_seq or @atomicLoad(bool, &c.key_update_requested, .monotonic)) { + @atomicStore(bool, &c.key_update_requested, false, .monotonic); + + // If the request_update field is set to "update_requested", + // then the receiver MUST send a KeyUpdate of its own with + // request_update set to "update_not_requested" prior to sending + // its next Application Data record. This mechanism allows + // either side to force an update to the entire connection, but + // causes an implementation which receives multiple KeyUpdates + // while it is silent to respond with a single update. + // + // rfc: https://datatracker.ietf.org/doc/html/rfc8446#autoid-57 + const key_update = &record.handshakeHeader(.key_update, 1) ++ [_]u8{0}; + const rec = try c.cipher.encrypt(&write_buf, .handshake, key_update); + try c.stream.writeAll(rec); + try c.cipher.keyUpdateEncrypt(); + } + const rec = try c.cipher.encrypt(&write_buf, content_type, bytes); + try c.stream.writeAll(rec); + } + + fn writeAlert(c: *Self, err: anyerror) !void { + const cleartext = proto.alertFromError(err); + var buf: [128]u8 = undefined; + const ciphertext = try c.cipher.encrypt(&buf, .alert, &cleartext); + c.stream.writeAll(ciphertext) catch {}; + } + + /// Returns next record of cleartext data. + /// Can be used in iterator like loop without memcpy to another buffer: + /// while (try client.next()) |buf| { ... } + pub fn next(c: *Self) ReadError!?[]const u8 { + const content_type, const data = c.nextRecord() catch |err| { + try c.writeAlert(err); + return err; + } orelse return null; + if (content_type != .application_data) return error.TlsUnexpectedMessage; + return data; + } + + fn nextRecord(c: *Self) ReadError!?struct { proto.ContentType, []const u8 } { + if (c.eof()) return null; + while (true) { + const content_type, const cleartext = try c.rec_rdr.nextDecrypt(&c.cipher) orelse return null; + + switch (content_type) { + .application_data => {}, + .handshake => { + const handshake_type: proto.Handshake = @enumFromInt(cleartext[0]); + switch (handshake_type) { + // skip new session ticket and read next record + .new_session_ticket => continue, + .key_update => { + if (cleartext.len != 5) return error.TlsDecodeError; + // rfc: Upon receiving a KeyUpdate, the receiver MUST + // update its receiving keys. + try c.cipher.keyUpdateDecrypt(); + const key: proto.KeyUpdateRequest = @enumFromInt(cleartext[4]); + switch (key) { + .update_requested => { + @atomicStore(bool, &c.key_update_requested, true, .monotonic); + }, + .update_not_requested => {}, + else => return error.TlsIllegalParameter, + } + // this record is handled read next + continue; + }, + else => {}, + } + }, + .alert => { + if (cleartext.len < 2) return error.TlsUnexpectedMessage; + try proto.Alert.parse(cleartext[0..2].*).toError(); + // server side clean shutdown + c.received_close_notify = true; + return null; + }, + else => return error.TlsUnexpectedMessage, + } + return .{ content_type, cleartext }; + } + } + + pub fn eof(c: *Self) bool { + return c.received_close_notify and c.read_buf.len == 0; + } + + pub fn close(c: *Self) !void { + if (c.received_close_notify) return; + try c.writeRecord(.alert, &proto.Alert.closeNotify()); + } + + // read, write interface + + pub const ReadError = Stream.ReadError || proto.Alert.Error || + error{ + TlsBadVersion, + TlsUnexpectedMessage, + TlsRecordOverflow, + TlsDecryptError, + TlsDecodeError, + TlsBadRecordMac, + TlsIllegalParameter, + BufferOverflow, + }; + pub const WriteError = Stream.WriteError || + error{ + BufferOverflow, + TlsUnexpectedMessage, + }; + + pub const Reader = std.io.Reader(*Self, ReadError, read); + pub const Writer = std.io.Writer(*Self, WriteError, write); + + pub fn reader(c: *Self) Reader { + return .{ .context = c }; + } + + pub fn writer(c: *Self) Writer { + return .{ .context = c }; + } + + /// Encrypts cleartext and writes it to the underlying stream as single + /// tls record. Max single tls record payload length is 1<<14 (16K) + /// bytes. + pub fn write(c: *Self, bytes: []const u8) WriteError!usize { + const n = @min(bytes.len, cipher.max_cleartext_len); + try c.writeRecord(.application_data, bytes[0..n]); + return n; + } + + /// Encrypts cleartext and writes it to the underlying stream. If needed + /// splits cleartext into multiple tls record. + pub fn writeAll(c: *Self, bytes: []const u8) WriteError!void { + var index: usize = 0; + while (index < bytes.len) { + index += try c.write(bytes[index..]); + } + } + + pub fn read(c: *Self, buffer: []u8) ReadError!usize { + if (c.read_buf.len == 0) { + c.read_buf = try c.next() orelse return 0; + } + const n = @min(c.read_buf.len, buffer.len); + @memcpy(buffer[0..n], c.read_buf[0..n]); + c.read_buf = c.read_buf[n..]; + return n; + } + + /// Returns the number of bytes read. If the number read is smaller than + /// `buffer.len`, it means the stream reached the end. + pub fn readAll(c: *Self, buffer: []u8) ReadError!usize { + return c.readAtLeast(buffer, buffer.len); + } + + /// Returns the number of bytes read, calling the underlying read function + /// the minimal number of times until the buffer has at least `len` bytes + /// filled. If the number read is less than `len` it means the stream + /// reached the end. + pub fn readAtLeast(c: *Self, buffer: []u8, len: usize) ReadError!usize { + assert(len <= buffer.len); + var index: usize = 0; + while (index < len) { + const amt = try c.read(buffer[index..]); + if (amt == 0) break; + index += amt; + } + return index; + } + + /// Returns the number of bytes read. If the number read is less than + /// the space provided it means the stream reached the end. + pub fn readv(c: *Self, iovecs: []std.posix.iovec) !usize { + var vp: VecPut = .{ .iovecs = iovecs }; + while (true) { + if (c.read_buf.len == 0) { + c.read_buf = try c.next() orelse break; + } + const n = vp.put(c.read_buf); + const read_buf_len = c.read_buf.len; + c.read_buf = c.read_buf[n..]; + if ((n < read_buf_len) or + (n == read_buf_len and !c.rec_rdr.hasMore())) + break; + } + return vp.total; + } + + fn onWriteAll(ctx: *Ctx, res: anyerror!void) anyerror!void { + res catch |err| return ctx.pop(err); + + if (ctx._tls_write_bytes.len - ctx._tls_write_index > 0) { + const rec = ctx.conn().tls_client.prepareRecord(ctx.stream(), ctx) catch |err| return ctx.pop(err); + return ctx.stream().async_writeAll(rec, ctx, onWriteAll) catch |err| return ctx.pop(err); + } + + return ctx.pop({}); + } + + pub fn async_writeAll(c: *Self, stream: anytype, bytes: []const u8, ctx: *Ctx, comptime cbk: Cbk) !void { + assert(bytes.len <= cipher.max_cleartext_len); + + ctx._tls_write_bytes = bytes; + ctx._tls_write_index = 0; + const rec = try c.prepareRecord(stream, ctx); + + try ctx.push(cbk); + return stream.async_writeAll(rec, ctx, onWriteAll); + } + + fn prepareRecord(c: *Self, stream: anytype, ctx: *Ctx) ![]const u8 { + const len = @min(ctx._tls_write_bytes.len - ctx._tls_write_index, cipher.max_cleartext_len); + + // If key update is requested send key update message and update + // my encryption keys. + if (c.cipher.encryptSeq() >= c.max_encrypt_seq or @atomicLoad(bool, &c.key_update_requested, .monotonic)) { + @atomicStore(bool, &c.key_update_requested, false, .monotonic); + + // If the request_update field is set to "update_requested", + // then the receiver MUST send a KeyUpdate of its own with + // request_update set to "update_not_requested" prior to sending + // its next Application Data record. This mechanism allows + // either side to force an update to the entire connection, but + // causes an implementation which receives multiple KeyUpdates + // while it is silent to respond with a single update. + // + // rfc: https://datatracker.ietf.org/doc/html/rfc8446#autoid-57 + const key_update = &record.handshakeHeader(.key_update, 1) ++ [_]u8{0}; + const rec = try c.cipher.encrypt(&ctx._tls_write_buf, .handshake, key_update); + try stream.writeAll(rec); // TODO async + try c.cipher.keyUpdateEncrypt(); + } + + defer ctx._tls_write_index += len; + return c.cipher.encrypt(&ctx._tls_write_buf, .application_data, ctx._tls_write_bytes[ctx._tls_write_index..len]); + } + + fn onReadv(ctx: *Ctx, res: anyerror!void) anyerror!void { + res catch |err| return ctx.pop(err); + + if (ctx._tls_read_buf == null) { + // end of read + ctx.setLen(ctx._vp.total); + return ctx.pop({}); + } + + while (true) { + const n = ctx._vp.put(ctx._tls_read_buf.?); + const read_buf_len = ctx._tls_read_buf.?.len; + const c = ctx.conn().tls_client; + + if (read_buf_len == 0) { + // read another buffer + c.async_next(ctx.stream(), ctx, onReadv) catch |err| return ctx.pop(err); + } + + ctx._tls_read_buf = ctx._tls_read_buf.?[n..]; + + if ((n < read_buf_len) or (n == read_buf_len and !c.rec_rdr.hasMore())) { + // end of read + ctx.setLen(ctx._vp.total); + return ctx.pop({}); + } + } + } + + pub fn async_readv(c: *Self, stream: anytype, iovecs: []std.posix.iovec, ctx: *Ctx, comptime cbk: Cbk) !void { + try ctx.push(cbk); + ctx._vp = .{ .iovecs = iovecs }; + + return c.async_next(stream, ctx, onReadv); + } + + fn onNext(ctx: *Ctx, res: anyerror!void) anyerror!void { + res catch |err| { + ctx.conn().tls_client.writeAlert(err) catch |e| std.log.err("onNext: write alert: {any}", .{e}); // TODO async + return ctx.pop(err); + }; + + if (ctx._tls_read_content_type != .application_data) { + return ctx.pop(error.TlsUnexpectedMessage); + } + + return ctx.pop({}); + } + + pub fn async_next(c: *Self, stream: anytype, ctx: *Ctx, comptime cbk: Cbk) !void { + try ctx.push(cbk); + + return c.async_next_decrypt(stream, ctx, onNext); + } + + pub fn onNextDecrypt(ctx: *Ctx, res: anyerror!void) anyerror!void { + res catch |err| return ctx.pop(err); + + const c = ctx.conn().tls_client; + // TOOD not sure if this works in my async case... + if (c.eof()) { + ctx._tls_read_buf = null; + return ctx.pop({}); + } + + const content_type = ctx._tls_read_content_type; + + switch (content_type) { + .application_data => {}, + .handshake => { + const handshake_type: proto.Handshake = @enumFromInt(ctx._tls_read_buf.?[0]); + switch (handshake_type) { + // skip new session ticket and read next record + .new_session_ticket => return c.async_next_record(ctx.stream(), ctx, onNextDecrypt) catch |err| return ctx.pop(err), + .key_update => { + if (ctx._tls_read_buf.?.len != 5) return ctx.pop(error.TlsDecodeError); + // rfc: Upon receiving a KeyUpdate, the receiver MUST + // update its receiving keys. + try c.cipher.keyUpdateDecrypt(); + const key: proto.KeyUpdateRequest = @enumFromInt(ctx._tls_read_buf.?[4]); + switch (key) { + .update_requested => { + @atomicStore(bool, &c.key_update_requested, true, .monotonic); + }, + .update_not_requested => {}, + else => return ctx.pop(error.TlsIllegalParameter), + } + // this record is handled read next + c.async_next_record(ctx.stream(), ctx, onNextDecrypt) catch |err| return ctx.pop(err); + }, + else => {}, + } + }, + .alert => { + if (ctx._tls_read_buf.?.len < 2) return ctx.pop(error.TlsUnexpectedMessage); + try proto.Alert.parse(ctx._tls_read_buf.?[0..2].*).toError(); + // server side clean shutdown + c.received_close_notify = true; + ctx._tls_read_buf = null; + return ctx.pop({}); + }, + else => return ctx.pop(error.TlsUnexpectedMessage), + } + + return ctx.pop({}); + } + + pub fn async_next_decrypt(c: *Self, stream: anytype, ctx: *Ctx, comptime cbk: Cbk) !void { + try ctx.push(cbk); + + return c.async_next_record(stream, ctx, onNextDecrypt) catch |err| return ctx.pop(err); + } + + pub fn onNextRecord(ctx: *Ctx, res: anyerror!void) anyerror!void { + res catch |err| return ctx.pop(err); + + const rec = ctx._tls_read_record orelse { + ctx._tls_read_buf = null; + return ctx.pop({}); + }; + + if (rec.protocol_version != .tls_1_2) return error.TlsBadVersion; + + const c = ctx.conn().tls_client; + const cph = &c.cipher; + + ctx._tls_read_content_type, ctx._tls_read_buf = cph.decrypt( + // Reuse reader buffer for cleartext. `rec.header` and + // `rec.payload`(ciphertext) are also pointing somewhere in + // this buffer. Decrypter is first reading then writing a + // block, cleartext has less length then ciphertext, + // cleartext starts from the beginning of the buffer, so + // ciphertext is always ahead of cleartext. + c.rec_rdr.buffer[0..c.rec_rdr.start], + rec, + ) catch |err| return ctx.pop(err); + + return ctx.pop({}); + } + + pub fn async_next_record(c: *Self, stream: anytype, ctx: *Ctx, comptime cbk: Cbk) !void { + try ctx.push(cbk); + + return c.async_reader_next(stream, ctx, onNextRecord); + } + + pub fn onReaderNext(ctx: *Ctx, res: anyerror!void) anyerror!void { + res catch |err| return ctx.pop(err); + + const c = ctx.conn().tls_client; + + const n = ctx.len(); + if (n == 0) { + ctx._tls_read_record = null; + return ctx.pop({}); + } + c.rec_rdr.end += n; + + return c.readNext(ctx); + } + + pub fn readNext(c: *Self, ctx: *Ctx) anyerror!void { + const buffer = c.rec_rdr.buffer[c.rec_rdr.start..c.rec_rdr.end]; + // If we have 5 bytes header. + if (buffer.len >= record.header_len) { + const record_header = buffer[0..record.header_len]; + const payload_len = std.mem.readInt(u16, record_header[3..5], .big); + if (payload_len > cipher.max_ciphertext_len) + return error.TlsRecordOverflow; + const record_len = record.header_len + payload_len; + // If we have whole record + if (buffer.len >= record_len) { + c.rec_rdr.start += record_len; + ctx._tls_read_record = record.Record.init(buffer[0..record_len]); + return ctx.pop({}); + } + } + { // Move dirty part to the start of the buffer. + const n = c.rec_rdr.end - c.rec_rdr.start; + if (n > 0 and c.rec_rdr.start > 0) { + if (c.rec_rdr.start > n) { + @memcpy(c.rec_rdr.buffer[0..n], c.rec_rdr.buffer[c.rec_rdr.start..][0..n]); + } else { + std.mem.copyForwards(u8, c.rec_rdr.buffer[0..n], c.rec_rdr.buffer[c.rec_rdr.start..][0..n]); + } + } + c.rec_rdr.start = 0; + c.rec_rdr.end = n; + } + // Read more from inner_reader. + return ctx.stream() + .async_read(c.rec_rdr.buffer[c.rec_rdr.end..], ctx, onReaderNext) catch |err| return ctx.pop(err); + } + + pub fn async_reader_next(c: *Self, _: anytype, ctx: *Ctx, comptime cbk: Cbk) !void { + try ctx.push(cbk); + return c.readNext(ctx); + } + }; +} + +const testing = std.testing; +const data12 = @import("testdata/tls12.zig"); +const testu = @import("testu.zig"); + +test "encrypt decrypt" { + var output_buf: [1024]u8 = undefined; + const stream = testu.Stream.init(&(data12.server_pong ** 3), &output_buf); + var conn: Connection(@TypeOf(stream)) = .{ .stream = stream, .rec_rdr = record.reader(stream) }; + conn.cipher = try Cipher.initTls12(.ECDHE_RSA_WITH_AES_128_CBC_SHA, &data12.key_material, .client); + conn.cipher.ECDHE_RSA_WITH_AES_128_CBC_SHA.rnd = testu.random(0); // use fixed rng + + conn.stream.output.reset(); + { // encrypt verify data from example + _ = testu.random(0x40); // sets iv to 40, 41, ... 4f + try conn.writeRecord(.handshake, &data12.client_finished); + try testing.expectEqualSlices(u8, &data12.verify_data_encrypted_msg, conn.stream.output.getWritten()); + } + + conn.stream.output.reset(); + { // encrypt ping + const cleartext = "ping"; + _ = testu.random(0); // sets iv to 00, 01, ... 0f + //conn.encrypt_seq = 1; + + try conn.writeAll(cleartext); + try testing.expectEqualSlices(u8, &data12.encrypted_ping_msg, conn.stream.output.getWritten()); + } + { // decrypt server pong message + conn.cipher.ECDHE_RSA_WITH_AES_128_CBC_SHA.decrypt_seq = 1; + try testing.expectEqualStrings("pong", (try conn.next()).?); + } + { // test reader interface + conn.cipher.ECDHE_RSA_WITH_AES_128_CBC_SHA.decrypt_seq = 1; + var rdr = conn.reader(); + var buffer: [4]u8 = undefined; + const n = try rdr.readAll(&buffer); + try testing.expectEqualStrings("pong", buffer[0..n]); + } + { // test readv interface + conn.cipher.ECDHE_RSA_WITH_AES_128_CBC_SHA.decrypt_seq = 1; + var buffer: [9]u8 = undefined; + var iovecs = [_]std.posix.iovec{ + .{ .base = &buffer, .len = 3 }, + .{ .base = buffer[3..], .len = 3 }, + .{ .base = buffer[6..], .len = 3 }, + }; + const n = try conn.readv(iovecs[0..]); + try testing.expectEqual(4, n); + try testing.expectEqualStrings("pong", buffer[0..n]); + } +} + +// Copied from: https://github.com/ziglang/zig/blob/455899668b620dfda40252501c748c0a983555bd/lib/std/crypto/tls/Client.zig#L1354 +/// Abstraction for sending multiple byte buffers to a slice of iovecs. +pub const VecPut = struct { + iovecs: []const std.posix.iovec, + idx: usize = 0, + off: usize = 0, + total: usize = 0, + + /// Returns the amount actually put which is always equal to bytes.len + /// unless the vectors ran out of space. + pub fn put(vp: *VecPut, bytes: []const u8) usize { + if (vp.idx >= vp.iovecs.len) return 0; + var bytes_i: usize = 0; + while (true) { + const v = vp.iovecs[vp.idx]; + const dest = v.base[vp.off..v.len]; + const src = bytes[bytes_i..][0..@min(dest.len, bytes.len - bytes_i)]; + @memcpy(dest[0..src.len], src); + bytes_i += src.len; + vp.off += src.len; + if (vp.off >= v.len) { + vp.off = 0; + vp.idx += 1; + if (vp.idx >= vp.iovecs.len) { + vp.total += bytes_i; + return bytes_i; + } + } + if (bytes_i >= bytes.len) { + vp.total += bytes_i; + return bytes_i; + } + } + } +}; + +test "client/server connection" { + const BufReaderWriter = struct { + buf: []u8, + wp: usize = 0, + rp: usize = 0, + + const Self = @This(); + + pub fn write(self: *Self, bytes: []const u8) !usize { + if (self.wp == self.buf.len) return error.NoSpaceLeft; + + const n = @min(bytes.len, self.buf.len - self.wp); + @memcpy(self.buf[self.wp..][0..n], bytes[0..n]); + self.wp += n; + return n; + } + + pub fn writeAll(self: *Self, bytes: []const u8) !void { + var n: usize = 0; + while (n < bytes.len) { + n += try self.write(bytes[n..]); + } + } + + pub fn read(self: *Self, bytes: []u8) !usize { + const n = @min(bytes.len, self.wp - self.rp); + if (n == 0) return 0; + @memcpy(bytes[0..n], self.buf[self.rp..][0..n]); + self.rp += n; + if (self.rp == self.wp) { + self.wp = 0; + self.rp = 0; + } + return n; + } + }; + + const TestStream = struct { + inner_stream: *BufReaderWriter, + const Self = @This(); + pub const ReadError = error{}; + pub const WriteError = error{NoSpaceLeft}; + pub fn read(self: *Self, bytes: []u8) !usize { + return try self.inner_stream.read(bytes); + } + pub fn writeAll(self: *Self, bytes: []const u8) !void { + return try self.inner_stream.writeAll(bytes); + } + }; + + const buf_len = 32 * 1024; + const tls_records_in_buf = (std.math.divCeil(comptime_int, buf_len, cipher.max_cleartext_len) catch unreachable); + const overhead: usize = tls_records_in_buf * @import("cipher.zig").encrypt_overhead_tls_13; + var buf: [buf_len + overhead]u8 = undefined; + var inner_stream = BufReaderWriter{ .buf = &buf }; + + const cipher_client, const cipher_server = brk: { + const Transcript = @import("transcript.zig").Transcript; + const CipherSuite = @import("cipher.zig").CipherSuite; + const cipher_suite: CipherSuite = .AES_256_GCM_SHA384; + + var rnd: [128]u8 = undefined; + std.crypto.random.bytes(&rnd); + const secret = Transcript.Secret{ + .client = rnd[0..64], + .server = rnd[64..], + }; + + break :brk .{ + try Cipher.initTls13(cipher_suite, secret, .client), + try Cipher.initTls13(cipher_suite, secret, .server), + }; + }; + + var conn1 = connection(TestStream{ .inner_stream = &inner_stream }); + conn1.cipher = cipher_client; + + var conn2 = connection(TestStream{ .inner_stream = &inner_stream }); + conn2.cipher = cipher_server; + + var prng = std.Random.DefaultPrng.init(0); + const random = prng.random(); + var send_buf: [buf_len]u8 = undefined; + var recv_buf: [buf_len]u8 = undefined; + random.bytes(&send_buf); // fill send buffer with random bytes + + for (0..16) |_| { + const n = buf_len; //random.uintLessThan(usize, buf_len); + + const sent = send_buf[0..n]; + try conn1.writeAll(sent); + const r = try conn2.readAll(&recv_buf); + const received = recv_buf[0..r]; + + try testing.expectEqual(n, r); + try testing.expectEqualSlices(u8, sent, received); + } +} diff --git a/src/http/async/tls.zig/handshake_client.zig b/src/http/async/tls.zig/handshake_client.zig new file mode 100644 index 00000000..e7b48cf6 --- /dev/null +++ b/src/http/async/tls.zig/handshake_client.zig @@ -0,0 +1,955 @@ +const std = @import("std"); +const assert = std.debug.assert; +const crypto = std.crypto; +const mem = std.mem; +const Certificate = crypto.Certificate; + +const cipher = @import("cipher.zig"); +const Cipher = cipher.Cipher; +const CipherSuite = cipher.CipherSuite; +const cipher_suites = cipher.cipher_suites; +const Transcript = @import("transcript.zig").Transcript; +const record = @import("record.zig"); +const rsa = @import("rsa/rsa.zig"); +const key_log = @import("key_log.zig"); +const PrivateKey = @import("PrivateKey.zig"); +const proto = @import("protocol.zig"); + +const common = @import("handshake_common.zig"); +const dupe = common.dupe; +const CertificateBuilder = common.CertificateBuilder; +const CertificateParser = common.CertificateParser; +const DhKeyPair = common.DhKeyPair; +const CertBundle = common.CertBundle; +const CertKeyPair = common.CertKeyPair; + +pub const Options = struct { + host: []const u8, + /// Set of root certificate authorities that clients use when verifying + /// server certificates. + root_ca: CertBundle, + + /// Controls whether a client verifies the server's certificate chain and + /// host name. + insecure_skip_verify: bool = false, + + /// List of cipher suites to use. + /// To use just tls 1.3 cipher suites: + /// .cipher_suites = &tls.CipherSuite.tls13, + /// To select particular cipher suite: + /// .cipher_suites = &[_]tls.CipherSuite{tls.CipherSuite.CHACHA20_POLY1305_SHA256}, + cipher_suites: []const CipherSuite = cipher_suites.all, + + /// List of named groups to use. + /// To use specific named group: + /// .named_groups = &[_]tls.NamedGroup{.secp384r1}, + named_groups: []const proto.NamedGroup = supported_named_groups, + + /// Client authentication certificates and private key. + auth: ?CertKeyPair = null, + + /// If this structure is provided it will be filled with handshake attributes + /// at the end of the handshake process. + diagnostic: ?*Diagnostic = null, + + /// For logging current connection tls keys, so we can share them with + /// Wireshark and analyze decrypted traffic there. + key_log_callback: ?key_log.Callback = null, + + pub const Diagnostic = struct { + tls_version: proto.Version = @enumFromInt(0), + cipher_suite_tag: CipherSuite = @enumFromInt(0), + named_group: proto.NamedGroup = @enumFromInt(0), + signature_scheme: proto.SignatureScheme = @enumFromInt(0), + client_signature_scheme: proto.SignatureScheme = @enumFromInt(0), + }; +}; + +const supported_named_groups = &[_]proto.NamedGroup{ + .x25519, + .secp256r1, + .secp384r1, + .x25519_kyber768d00, +}; + +/// Handshake parses tls server message and creates client messages. Collects +/// tls attributes: server random, cipher suite and so on. Client messages are +/// created using provided buffer. Provided record reader is used to get tls +/// record when needed. +pub fn Handshake(comptime Stream: type) type { + const RecordReaderT = record.Reader(Stream); + return struct { + client_random: [32]u8, + server_random: [32]u8 = undefined, + master_secret: [48]u8 = undefined, + key_material: [48 * 4]u8 = undefined, // for sha256 32 * 4 is filled, for sha384 48 * 4 + + transcript: Transcript = .{}, + cipher_suite: CipherSuite = @enumFromInt(0), + named_group: ?proto.NamedGroup = null, + dh_kp: DhKeyPair, + rsa_secret: RsaSecret, + tls_version: proto.Version = .tls_1_2, + cipher: Cipher = undefined, + cert: CertificateParser = undefined, + client_certificate_requested: bool = false, + // public key len: x25519 = 32, secp256r1 = 65, secp384r1 = 97, x25519_kyber768d00 = 1120 + server_pub_key_buf: [2048]u8 = undefined, + server_pub_key: []const u8 = undefined, + + rec_rdr: *RecordReaderT, // tls record reader + buffer: []u8, // scratch buffer used in all messages creation + + const HandshakeT = @This(); + + pub fn init(buf: []u8, rec_rdr: *RecordReaderT) HandshakeT { + return .{ + .client_random = undefined, + .dh_kp = undefined, + .rsa_secret = undefined, + //.now_sec = std.time.timestamp(), + .buffer = buf, + .rec_rdr = rec_rdr, + }; + } + + fn initKeys( + h: *HandshakeT, + named_groups: []const proto.NamedGroup, + ) !void { + const init_keys_buf_len = 32 + 46 + DhKeyPair.seed_len; + var buf: [init_keys_buf_len]u8 = undefined; + crypto.random.bytes(&buf); + + h.client_random = buf[0..32].*; + h.rsa_secret = RsaSecret.init(buf[32..][0..46].*); + h.dh_kp = try DhKeyPair.init(buf[32 + 46 ..][0..DhKeyPair.seed_len].*, named_groups); + } + + /// Handshake exchanges messages with server to get agreement about + /// cryptographic parameters. That upgrades existing client-server + /// connection to TLS connection. Returns cipher used in application for + /// encrypted message exchange. + /// + /// Handles TLS 1.2 and TLS 1.3 connections. After initial client hello + /// server chooses in its server hello which TLS version will be used. + /// + /// TLS 1.2 handshake messages exchange: + /// Client Server + /// -------------------------------------------------------------- + /// ClientHello client flight 1 ---> + /// ServerHello + /// Certificate + /// ServerKeyExchange + /// CertificateRequest* + /// <--- server flight 1 ServerHelloDone + /// Certificate* + /// ClientKeyExchange + /// CertificateVerify* + /// ChangeCipherSpec + /// Finished client flight 2 ---> + /// ChangeCipherSpec + /// <--- server flight 2 Finished + /// + /// TLS 1.3 handshake messages exchange: + /// Client Server + /// -------------------------------------------------------------- + /// ClientHello client flight 1 ---> + /// ServerHello + /// {EncryptedExtensions} + /// {CertificateRequest*} + /// {Certificate} + /// {CertificateVerify} + /// <--- server flight 1 {Finished} + /// ChangeCipherSpec + /// {Certificate*} + /// {CertificateVerify*} + /// Finished client flight 2 ---> + /// + /// * - optional + /// {} - encrypted + /// + /// References: + /// https://datatracker.ietf.org/doc/html/rfc5246#section-7.3 + /// https://datatracker.ietf.org/doc/html/rfc8446#section-2 + /// + pub fn handshake(h: *HandshakeT, w: Stream, opt: Options) !Cipher { + defer h.updateDiagnostic(opt); + try h.initKeys(opt.named_groups); + h.cert = .{ + .host = opt.host, + .root_ca = opt.root_ca.bundle, + .skip_verify = opt.insecure_skip_verify, + }; + + try w.writeAll(try h.makeClientHello(opt)); // client flight 1 + try h.readServerFlight1(); // server flight 1 + h.transcript.use(h.cipher_suite.hash()); + + // tls 1.3 specific handshake part + if (h.tls_version == .tls_1_3) { + try h.generateHandshakeCipher(opt.key_log_callback); + try h.readEncryptedServerFlight1(); // server flight 1 + const app_cipher = try h.generateApplicationCipher(opt.key_log_callback); + try w.writeAll(try h.makeClientFlight2Tls13(opt.auth)); // client flight 2 + return app_cipher; + } + + // tls 1.2 specific handshake part + try h.generateCipher(opt.key_log_callback); + try w.writeAll(try h.makeClientFlight2Tls12(opt.auth)); // client flight 2 + try h.readServerFlight2(); // server flight 2 + return h.cipher; + } + + /// Prepare key material and generate cipher for TLS 1.2 + fn generateCipher(h: *HandshakeT, key_log_callback: ?key_log.Callback) !void { + try h.verifyCertificateSignatureTls12(); + try h.generateKeyMaterial(key_log_callback); + h.cipher = try Cipher.initTls12(h.cipher_suite, &h.key_material, .client); + } + + /// Generate TLS 1.2 pre master secret, master secret and key material. + fn generateKeyMaterial(h: *HandshakeT, key_log_callback: ?key_log.Callback) !void { + const pre_master_secret = if (h.named_group) |named_group| + try h.dh_kp.sharedKey(named_group, h.server_pub_key) + else + &h.rsa_secret.secret; + + _ = dupe( + &h.master_secret, + h.transcript.masterSecret(pre_master_secret, h.client_random, h.server_random), + ); + _ = dupe( + &h.key_material, + h.transcript.keyMaterial(&h.master_secret, h.client_random, h.server_random), + ); + if (key_log_callback) |cb| { + cb(key_log.label.client_random, &h.client_random, &h.master_secret); + } + } + + /// TLS 1.3 cipher used during handshake + fn generateHandshakeCipher(h: *HandshakeT, key_log_callback: ?key_log.Callback) !void { + const shared_key = try h.dh_kp.sharedKey(h.named_group.?, h.server_pub_key); + const handshake_secret = h.transcript.handshakeSecret(shared_key); + if (key_log_callback) |cb| { + cb(key_log.label.server_handshake_traffic_secret, &h.client_random, handshake_secret.server); + cb(key_log.label.client_handshake_traffic_secret, &h.client_random, handshake_secret.client); + } + h.cipher = try Cipher.initTls13(h.cipher_suite, handshake_secret, .client); + } + + /// TLS 1.3 application (client) cipher + fn generateApplicationCipher(h: *HandshakeT, key_log_callback: ?key_log.Callback) !Cipher { + const application_secret = h.transcript.applicationSecret(); + if (key_log_callback) |cb| { + cb(key_log.label.server_traffic_secret_0, &h.client_random, application_secret.server); + cb(key_log.label.client_traffic_secret_0, &h.client_random, application_secret.client); + } + return try Cipher.initTls13(h.cipher_suite, application_secret, .client); + } + + fn makeClientHello(h: *HandshakeT, opt: Options) ![]const u8 { + // Buffer will have this parts: + // | header | payload | extensions | + // + // Header will be written last because we need to know length of + // payload and extensions when creating it. Payload has + // extensions length (u16) as last element. + // + var buffer = h.buffer; + const header_len = 9; // tls record header (5 bytes) and handshake header (4 bytes) + const tls_versions = try CipherSuite.versions(opt.cipher_suites); + // Payload writer, preserve header_len bytes for handshake header. + var payload = record.Writer{ .buf = buffer[header_len..] }; + try payload.writeEnum(proto.Version.tls_1_2); + try payload.write(&h.client_random); + try payload.writeByte(0); // no session id + try payload.writeEnumArray(CipherSuite, opt.cipher_suites); + try payload.write(&[_]u8{ 0x01, 0x00 }); // no compression + + // Extensions writer starts after payload and preserves 2 more + // bytes for extension len in payload. + var ext = record.Writer{ .buf = buffer[header_len + payload.pos + 2 ..] }; + try ext.writeExtension(.supported_versions, switch (tls_versions) { + .both => &[_]proto.Version{ .tls_1_3, .tls_1_2 }, + .tls_1_3 => &[_]proto.Version{.tls_1_3}, + .tls_1_2 => &[_]proto.Version{.tls_1_2}, + }); + try ext.writeExtension(.signature_algorithms, common.supported_signature_algorithms); + + try ext.writeExtension(.supported_groups, opt.named_groups); + if (tls_versions != .tls_1_2) { + var keys: [supported_named_groups.len][]const u8 = undefined; + for (opt.named_groups, 0..) |ng, i| { + keys[i] = try h.dh_kp.publicKey(ng); + } + try ext.writeKeyShare(opt.named_groups, keys[0..opt.named_groups.len]); + } + try ext.writeServerName(opt.host); + + // Extensions length at the end of the payload. + try payload.writeInt(@as(u16, @intCast(ext.pos))); + + // Header at the start of the buffer. + const body_len = payload.pos + ext.pos; + buffer[0..header_len].* = record.header(.handshake, 4 + body_len) ++ + record.handshakeHeader(.client_hello, body_len); + + const msg = buffer[0 .. header_len + body_len]; + h.transcript.update(msg[record.header_len..]); + return msg; + } + + /// Process first flight of the messages from the server. + /// Read server hello message. If TLS 1.3 is chosen in server hello + /// return. For TLS 1.2 continue and read certificate, key_exchange + /// eventual certificate request and hello done messages. + fn readServerFlight1(h: *HandshakeT) !void { + var handshake_states: []const proto.Handshake = &.{.server_hello}; + + while (true) { + var d = try h.rec_rdr.nextDecoder(); + try d.expectContentType(.handshake); + + h.transcript.update(d.payload); + + // Multiple handshake messages can be packed in single tls record. + while (!d.eof()) { + const handshake_type = try d.decode(proto.Handshake); + + const length = try d.decode(u24); + if (length > cipher.max_cleartext_len) + return error.TlsUnsupportedFragmentedHandshakeMessage; + + brk: { + for (handshake_states) |state| + if (state == handshake_type) break :brk; + return error.TlsUnexpectedMessage; + } + switch (handshake_type) { + .server_hello => { // server hello, ref: https://datatracker.ietf.org/doc/html/rfc5246#section-7.4.1.3 + try h.parseServerHello(&d, length); + if (h.tls_version == .tls_1_3) { + if (!d.eof()) return error.TlsIllegalParameter; + return; // end of tls 1.3 server flight 1 + } + handshake_states = if (h.cert.skip_verify) + &.{ .certificate, .server_key_exchange, .server_hello_done } + else + &.{.certificate}; + }, + .certificate => { + try h.cert.parseCertificate(&d, h.tls_version); + handshake_states = if (h.cipher_suite.keyExchange() == .rsa) + &.{.server_hello_done} + else + &.{.server_key_exchange}; + }, + .server_key_exchange => { + try h.parseServerKeyExchange(&d); + handshake_states = &.{ .certificate_request, .server_hello_done }; + }, + .certificate_request => { + h.client_certificate_requested = true; + try d.skip(length); + handshake_states = &.{.server_hello_done}; + }, + .server_hello_done => { + if (length != 0) return error.TlsIllegalParameter; + return; + }, + else => return error.TlsUnexpectedMessage, + } + } + } + } + + /// Parse server hello message. + fn parseServerHello(h: *HandshakeT, d: *record.Decoder, length: u24) !void { + if (try d.decode(proto.Version) != proto.Version.tls_1_2) + return error.TlsBadVersion; + h.server_random = try d.array(32); + if (isServerHelloRetryRequest(&h.server_random)) + return error.TlsServerHelloRetryRequest; + + const session_id_len = try d.decode(u8); + if (session_id_len > 32) return error.TlsIllegalParameter; + try d.skip(session_id_len); + + h.cipher_suite = try d.decode(CipherSuite); + try h.cipher_suite.validate(); + try d.skip(1); // skip compression method + + const extensions_present = length > 2 + 32 + 1 + session_id_len + 2 + 1; + if (extensions_present) { + const exs_len = try d.decode(u16); + var l: usize = 0; + while (l < exs_len) { + const typ = try d.decode(proto.Extension); + const len = try d.decode(u16); + defer l += len + 4; + + switch (typ) { + .supported_versions => { + switch (try d.decode(proto.Version)) { + .tls_1_2, .tls_1_3 => |v| h.tls_version = v, + else => return error.TlsIllegalParameter, + } + if (len != 2) return error.TlsIllegalParameter; + }, + .key_share => { + h.named_group = try d.decode(proto.NamedGroup); + h.server_pub_key = dupe(&h.server_pub_key_buf, try d.slice(try d.decode(u16))); + if (len != h.server_pub_key.len + 4) return error.TlsIllegalParameter; + }, + else => { + try d.skip(len); + }, + } + } + } + } + + fn isServerHelloRetryRequest(server_random: []const u8) bool { + // Ref: https://datatracker.ietf.org/doc/html/rfc8446#section-4.1.3 + const hello_retry_request_magic = [32]u8{ + 0xCF, 0x21, 0xAD, 0x74, 0xE5, 0x9A, 0x61, 0x11, 0xBE, 0x1D, 0x8C, 0x02, 0x1E, 0x65, 0xB8, 0x91, + 0xC2, 0xA2, 0x11, 0x16, 0x7A, 0xBB, 0x8C, 0x5E, 0x07, 0x9E, 0x09, 0xE2, 0xC8, 0xA8, 0x33, 0x9C, + }; + return std.mem.eql(u8, server_random, &hello_retry_request_magic); + } + + fn parseServerKeyExchange(h: *HandshakeT, d: *record.Decoder) !void { + const curve_type = try d.decode(proto.Curve); + h.named_group = try d.decode(proto.NamedGroup); + h.server_pub_key = dupe(&h.server_pub_key_buf, try d.slice(try d.decode(u8))); + h.cert.signature_scheme = try d.decode(proto.SignatureScheme); + h.cert.signature = dupe(&h.cert.signature_buf, try d.slice(try d.decode(u16))); + if (curve_type != .named_curve) return error.TlsIllegalParameter; + } + + /// Read encrypted part (after server hello) of the server first flight + /// for TLS 1.3: change cipher spec, eventual certificate request, + /// certificate, certificate verify and handshake finished messages. + fn readEncryptedServerFlight1(h: *HandshakeT) !void { + var cleartext_buf = h.buffer; + var cleartext_buf_head: usize = 0; + var cleartext_buf_tail: usize = 0; + var handshake_states: []const proto.Handshake = &.{.encrypted_extensions}; + + outer: while (true) { + // wrapped record decoder + const rec = (try h.rec_rdr.next() orelse return error.EndOfStream); + if (rec.protocol_version != .tls_1_2) return error.TlsBadVersion; + switch (rec.content_type) { + .change_cipher_spec => {}, + .application_data => { + const content_type, const cleartext = try h.cipher.decrypt( + cleartext_buf[cleartext_buf_tail..], + rec, + ); + cleartext_buf_tail += cleartext.len; + if (cleartext_buf_tail > cleartext_buf.len) return error.TlsRecordOverflow; + + var d = record.Decoder.init(content_type, cleartext_buf[cleartext_buf_head..cleartext_buf_tail]); + try d.expectContentType(.handshake); + while (!d.eof()) { + const start_idx = d.idx; + const handshake_type = try d.decode(proto.Handshake); + const length = try d.decode(u24); + + // std.debug.print("handshake loop: {} {} {} {}\n", .{ handshake_type, length, d.payload.len, d.idx }); + if (length > cipher.max_cleartext_len) + return error.TlsUnsupportedFragmentedHandshakeMessage; + if (length > d.rest().len) + continue :outer; // fragmented handshake into multiple records + + defer { + const handshake_payload = d.payload[start_idx..d.idx]; + h.transcript.update(handshake_payload); + cleartext_buf_head += handshake_payload.len; + } + + brk: { + for (handshake_states) |state| + if (state == handshake_type) break :brk; + return error.TlsUnexpectedMessage; + } + switch (handshake_type) { + .encrypted_extensions => { + try d.skip(length); + handshake_states = if (h.cert.skip_verify) + &.{ .certificate_request, .certificate, .finished } + else + &.{ .certificate_request, .certificate }; + }, + .certificate_request => { + h.client_certificate_requested = true; + try d.skip(length); + handshake_states = if (h.cert.skip_verify) + &.{ .certificate, .finished } + else + &.{.certificate}; + }, + .certificate => { + try h.cert.parseCertificate(&d, h.tls_version); + handshake_states = &.{.certificate_verify}; + }, + .certificate_verify => { + try h.cert.parseCertificateVerify(&d); + try h.cert.verifySignature(h.transcript.serverCertificateVerify()); + handshake_states = &.{.finished}; + }, + .finished => { + const actual = try d.slice(length); + var buf: [Transcript.max_mac_length]u8 = undefined; + const expected = h.transcript.serverFinishedTls13(&buf); + if (!mem.eql(u8, expected, actual)) + return error.TlsDecryptError; + return; + }, + else => return error.TlsUnexpectedMessage, + } + } + cleartext_buf_head = 0; + cleartext_buf_tail = 0; + }, + else => return error.TlsUnexpectedMessage, + } + } + } + + fn verifyCertificateSignatureTls12(h: *HandshakeT) !void { + if (h.cipher_suite.keyExchange() != .ecdhe) return; + const verify_bytes = brk: { + var w = record.Writer{ .buf = h.buffer }; + try w.write(&h.client_random); + try w.write(&h.server_random); + try w.writeEnum(proto.Curve.named_curve); + try w.writeEnum(h.named_group.?); + try w.writeInt(@as(u8, @intCast(h.server_pub_key.len))); + try w.write(h.server_pub_key); + break :brk w.getWritten(); + }; + try h.cert.verifySignature(verify_bytes); + } + + /// Create client key exchange, change cipher spec and handshake + /// finished messages for tls 1.2. + /// If client certificate is requested also adds client certificate and + /// certificate verify messages. + fn makeClientFlight2Tls12(h: *HandshakeT, auth: ?CertKeyPair) ![]const u8 { + var w = record.Writer{ .buf = h.buffer }; + var cert_builder: ?CertificateBuilder = null; + + // Client certificate message + if (h.client_certificate_requested) { + if (auth) |a| { + const cb = h.certificateBuilder(a); + cert_builder = cb; + const client_certificate = try cb.makeCertificate(w.getPayload()); + h.transcript.update(client_certificate); + try w.advanceRecord(.handshake, client_certificate.len); + } else { + const empty_certificate = &record.handshakeHeader(.certificate, 3) ++ [_]u8{ 0, 0, 0 }; + h.transcript.update(empty_certificate); + try w.writeRecord(.handshake, empty_certificate); + } + } + + // Client key exchange message + { + const key_exchange = try h.makeClientKeyExchange(w.getPayload()); + h.transcript.update(key_exchange); + try w.advanceRecord(.handshake, key_exchange.len); + } + + // Client certificate verify message + if (cert_builder) |cb| { + const certificate_verify = try cb.makeCertificateVerify(w.getPayload()); + h.transcript.update(certificate_verify); + try w.advanceRecord(.handshake, certificate_verify.len); + } + + // Client change cipher spec message + try w.writeRecord(.change_cipher_spec, &[_]u8{1}); + + // Client handshake finished message + { + const client_finished = &record.handshakeHeader(.finished, 12) ++ + h.transcript.clientFinishedTls12(&h.master_secret); + h.transcript.update(client_finished); + try h.writeEncrypted(&w, client_finished); + } + + return w.getWritten(); + } + + /// Create client change cipher spec and handshake finished messages for + /// tls 1.3. + /// If the client certificate is requested by the server and client is + /// configured with certificates and private key then client certificate + /// and client certificate verify messages are also created. If the + /// server has requested certificate but the client is not configured + /// empty certificate message is sent, as is required by rfc. + fn makeClientFlight2Tls13(h: *HandshakeT, auth: ?CertKeyPair) ![]const u8 { + var w = record.Writer{ .buf = h.buffer }; + + // Client change cipher spec message + try w.writeRecord(.change_cipher_spec, &[_]u8{1}); + + if (h.client_certificate_requested) { + if (auth) |a| { + const cb = h.certificateBuilder(a); + { + const certificate = try cb.makeCertificate(w.getPayload()); + h.transcript.update(certificate); + try h.writeEncrypted(&w, certificate); + } + { + const certificate_verify = try cb.makeCertificateVerify(w.getPayload()); + h.transcript.update(certificate_verify); + try h.writeEncrypted(&w, certificate_verify); + } + } else { + // Empty certificate message and no certificate verify message + const empty_certificate = &record.handshakeHeader(.certificate, 4) ++ [_]u8{ 0, 0, 0, 0 }; + h.transcript.update(empty_certificate); + try h.writeEncrypted(&w, empty_certificate); + } + } + + // Client handshake finished message + { + const client_finished = try h.makeClientFinishedTls13(w.getPayload()); + h.transcript.update(client_finished); + try h.writeEncrypted(&w, client_finished); + } + + return w.getWritten(); + } + + fn certificateBuilder(h: *HandshakeT, auth: CertKeyPair) CertificateBuilder { + return .{ + .bundle = auth.bundle, + .key = auth.key, + .transcript = &h.transcript, + .tls_version = h.tls_version, + .side = .client, + }; + } + + fn makeClientFinishedTls13(h: *HandshakeT, buf: []u8) ![]const u8 { + var w = record.Writer{ .buf = buf }; + const verify_data = h.transcript.clientFinishedTls13(w.getHandshakePayload()); + try w.advanceHandshake(.finished, verify_data.len); + return w.getWritten(); + } + + fn makeClientKeyExchange(h: *HandshakeT, buf: []u8) ![]const u8 { + var w = record.Writer{ .buf = buf }; + if (h.named_group) |named_group| { + const key = try h.dh_kp.publicKey(named_group); + try w.writeHandshakeHeader(.client_key_exchange, 1 + key.len); + try w.writeInt(@as(u8, @intCast(key.len))); + try w.write(key); + } else { + const key = try h.rsa_secret.encrypted(h.cert.pub_key_algo, h.cert.pub_key); + try w.writeHandshakeHeader(.client_key_exchange, 2 + key.len); + try w.writeInt(@as(u16, @intCast(key.len))); + try w.write(key); + } + return w.getWritten(); + } + + fn readServerFlight2(h: *HandshakeT) !void { + // Read server change cipher spec message. + { + var d = try h.rec_rdr.nextDecoder(); + try d.expectContentType(.change_cipher_spec); + } + // Read encrypted server handshake finished message. Verify that + // content of the server finished message is based on transcript + // hash and master secret. + { + const content_type, const server_finished = + try h.rec_rdr.nextDecrypt(&h.cipher) orelse return error.EndOfStream; + if (content_type != .handshake) + return error.TlsUnexpectedMessage; + const expected = record.handshakeHeader(.finished, 12) ++ h.transcript.serverFinishedTls12(&h.master_secret); + if (!mem.eql(u8, server_finished, &expected)) + return error.TlsBadRecordMac; + } + } + + /// Write encrypted handshake message into `w` + fn writeEncrypted(h: *HandshakeT, w: *record.Writer, cleartext: []const u8) !void { + const ciphertext = try h.cipher.encrypt(w.getFree(), .handshake, cleartext); + w.pos += ciphertext.len; + } + + // Copy handshake parameters to opt.diagnostic + fn updateDiagnostic(h: *HandshakeT, opt: Options) void { + if (opt.diagnostic) |d| { + d.tls_version = h.tls_version; + d.cipher_suite_tag = h.cipher_suite; + d.named_group = h.named_group orelse @as(proto.NamedGroup, @enumFromInt(0x0000)); + d.signature_scheme = h.cert.signature_scheme; + if (opt.auth) |a| + d.client_signature_scheme = a.key.signature_scheme; + } + } + }; +} + +const RsaSecret = struct { + secret: [48]u8, + + fn init(rand: [46]u8) RsaSecret { + return .{ .secret = [_]u8{ 0x03, 0x03 } ++ rand }; + } + + // Pre master secret encrypted with certificate public key. + inline fn encrypted( + self: RsaSecret, + cert_pub_key_algo: Certificate.Parsed.PubKeyAlgo, + cert_pub_key: []const u8, + ) ![]const u8 { + if (cert_pub_key_algo != .rsaEncryption) return error.TlsBadSignatureScheme; + const pk = try rsa.PublicKey.fromDer(cert_pub_key); + var out: [512]u8 = undefined; + return try pk.encryptPkcsv1_5(&self.secret, &out); + } +}; + +const testing = std.testing; +const data12 = @import("testdata/tls12.zig"); +const data13 = @import("testdata/tls13.zig"); +const testu = @import("testu.zig"); + +fn testReader(data: []const u8) record.Reader(std.io.FixedBufferStream([]const u8)) { + return record.reader(std.io.fixedBufferStream(data)); +} +const TestHandshake = Handshake(std.io.FixedBufferStream([]const u8)); + +test "parse tls 1.2 server hello" { + var h = brk: { + var buffer: [1024]u8 = undefined; + var rec_rdr = testReader(&data12.server_hello_responses); + break :brk TestHandshake.init(&buffer, &rec_rdr); + }; + + // Set to known instead of random + h.client_random = data12.client_random; + h.dh_kp.x25519_kp.secret_key = data12.client_secret; + + // Parse server hello, certificate and key exchange messages. + // Read cipher suite, named group, signature scheme, server random certificate public key + // Verify host name, signature + // Calculate key material + h.cert = .{ .host = "example.ulfheim.net", .skip_verify = true, .root_ca = .{} }; + try h.readServerFlight1(); + try testing.expectEqual(.ECDHE_RSA_WITH_AES_128_CBC_SHA, h.cipher_suite); + try testing.expectEqual(.x25519, h.named_group.?); + try testing.expectEqual(.rsa_pkcs1_sha256, h.cert.signature_scheme); + try testing.expectEqualSlices(u8, &data12.server_random, &h.server_random); + try testing.expectEqualSlices(u8, &data12.server_pub_key, h.server_pub_key); + try testing.expectEqualSlices(u8, &data12.signature, h.cert.signature); + try testing.expectEqualSlices(u8, &data12.cert_pub_key, h.cert.pub_key); + + try h.verifyCertificateSignatureTls12(); + try h.generateKeyMaterial(null); + + try testing.expectEqualSlices(u8, &data12.key_material, h.key_material[0..data12.key_material.len]); +} + +test "verify google.com certificate" { + var h = brk: { + var buffer: [1024]u8 = undefined; + var rec_rdr = testReader(@embedFile("testdata/google.com/server_hello")); + break :brk TestHandshake.init(&buffer, &rec_rdr); + }; + h.client_random = @embedFile("testdata/google.com/client_random").*; + + var ca_bundle: Certificate.Bundle = .{}; + try ca_bundle.rescan(testing.allocator); + defer ca_bundle.deinit(testing.allocator); + + h.cert = .{ .host = "google.com", .skip_verify = true, .root_ca = .{}, .now_sec = 1714846451 }; + try h.readServerFlight1(); + try h.verifyCertificateSignatureTls12(); +} + +test "parse tls 1.3 server hello" { + var rec_rdr = testReader(&data13.server_hello); + var d = (try rec_rdr.nextDecoder()); + + const handshake_type = try d.decode(proto.Handshake); + const length = try d.decode(u24); + try testing.expectEqual(0x000076, length); + try testing.expectEqual(.server_hello, handshake_type); + + var h = TestHandshake.init(undefined, undefined); + try h.parseServerHello(&d, length); + + try testing.expectEqual(.AES_256_GCM_SHA384, h.cipher_suite); + try testing.expectEqualSlices(u8, &data13.server_random, &h.server_random); + try testing.expectEqual(.tls_1_3, h.tls_version); + try testing.expectEqual(.x25519, h.named_group); + try testing.expectEqualSlices(u8, &data13.server_pub_key, h.server_pub_key); +} + +test "init tls 1.3 handshake cipher" { + const cipher_suite_tag: CipherSuite = .AES_256_GCM_SHA384; + + var transcript = Transcript{}; + transcript.use(cipher_suite_tag.hash()); + transcript.update(data13.client_hello[record.header_len..]); + transcript.update(data13.server_hello[record.header_len..]); + + var dh_kp = DhKeyPair{ + .x25519_kp = .{ + .public_key = data13.client_public_key, + .secret_key = data13.client_private_key, + }, + }; + const shared_key = try dh_kp.sharedKey(.x25519, &data13.server_pub_key); + try testing.expectEqualSlices(u8, &data13.shared_key, shared_key); + + const cph = try Cipher.initTls13(cipher_suite_tag, transcript.handshakeSecret(shared_key), .client); + + const c = &cph.AES_256_GCM_SHA384; + try testing.expectEqualSlices(u8, &data13.server_handshake_key, &c.decrypt_key); + try testing.expectEqualSlices(u8, &data13.client_handshake_key, &c.encrypt_key); + try testing.expectEqualSlices(u8, &data13.server_handshake_iv, &c.decrypt_iv); + try testing.expectEqualSlices(u8, &data13.client_handshake_iv, &c.encrypt_iv); +} + +fn initExampleHandshake(h: *TestHandshake) !void { + h.cipher_suite = .AES_256_GCM_SHA384; + h.transcript.use(h.cipher_suite.hash()); + h.transcript.update(data13.client_hello[record.header_len..]); + h.transcript.update(data13.server_hello[record.header_len..]); + h.cipher = try Cipher.initTls13(h.cipher_suite, h.transcript.handshakeSecret(&data13.shared_key), .client); + h.tls_version = .tls_1_3; + h.cert.now_sec = 1714846451; + h.server_pub_key = &data13.server_pub_key; +} + +test "tls 1.3 decrypt wrapped record" { + var cph = brk: { + var h = TestHandshake.init(undefined, undefined); + try initExampleHandshake(&h); + break :brk h.cipher; + }; + + var cleartext_buf: [1024]u8 = undefined; + { + const rec = record.Record.init(&data13.server_encrypted_extensions_wrapped); + + const content_type, const cleartext = try cph.decrypt(&cleartext_buf, rec); + try testing.expectEqual(.handshake, content_type); + try testing.expectEqualSlices(u8, &data13.server_encrypted_extensions, cleartext); + } + { + const rec = record.Record.init(&data13.server_certificate_wrapped); + + const content_type, const cleartext = try cph.decrypt(&cleartext_buf, rec); + try testing.expectEqual(.handshake, content_type); + try testing.expectEqualSlices(u8, &data13.server_certificate, cleartext); + } +} + +test "tls 1.3 process server flight" { + var buffer: [1024]u8 = undefined; + var h = brk: { + var rec_rdr = testReader(&data13.server_flight); + break :brk TestHandshake.init(&buffer, &rec_rdr); + }; + + try initExampleHandshake(&h); + h.cert = .{ .host = "example.ulfheim.net", .skip_verify = true, .root_ca = .{} }; + try h.readEncryptedServerFlight1(); + + { // application cipher keys calculation + try testing.expectEqualSlices(u8, &data13.handshake_hash, &h.transcript.sha384.hash.peek()); + + var cph = try Cipher.initTls13(h.cipher_suite, h.transcript.applicationSecret(), .client); + const c = &cph.AES_256_GCM_SHA384; + try testing.expectEqualSlices(u8, &data13.server_application_key, &c.decrypt_key); + try testing.expectEqualSlices(u8, &data13.client_application_key, &c.encrypt_key); + try testing.expectEqualSlices(u8, &data13.server_application_iv, &c.decrypt_iv); + try testing.expectEqualSlices(u8, &data13.client_application_iv, &c.encrypt_iv); + + const encrypted = try cph.encrypt(&buffer, .application_data, "ping"); + try testing.expectEqualSlices(u8, &data13.client_ping_wrapped, encrypted); + } + { // client finished message + var buf: [4 + Transcript.max_mac_length]u8 = undefined; + const client_finished = try h.makeClientFinishedTls13(&buf); + try testing.expectEqualSlices(u8, &data13.client_finished_verify_data, client_finished[4..]); + const encrypted = try h.cipher.encrypt(&buffer, .handshake, client_finished); + try testing.expectEqualSlices(u8, &data13.client_finished_wrapped, encrypted); + } +} + +test "create client hello" { + var h = brk: { + var buffer: [1024]u8 = undefined; + var h = TestHandshake.init(&buffer, undefined); + h.client_random = testu.hexToBytes( + \\ 00 01 02 03 04 05 06 07 08 09 0a 0b 0c 0d 0e 0f 10 11 12 13 14 15 16 17 18 19 1a 1b 1c 1d 1e 1f + ); + break :brk h; + }; + + const actual = try h.makeClientHello(.{ + .host = "google.com", + .root_ca = .{}, + .cipher_suites = &[_]CipherSuite{CipherSuite.ECDHE_ECDSA_WITH_AES_128_GCM_SHA256}, + .named_groups = &[_]proto.NamedGroup{ .x25519, .secp256r1, .secp384r1 }, + }); + + const expected = testu.hexToBytes( + "16 03 03 00 6d " ++ // record header + "01 00 00 69 " ++ // handshake header + "03 03 " ++ // protocol version + "00 01 02 03 04 05 06 07 08 09 0a 0b 0c 0d 0e 0f 10 11 12 13 14 15 16 17 18 19 1a 1b 1c 1d 1e 1f " ++ // client random + "00 " ++ // no session id + "00 02 c0 2b " ++ // cipher suites + "01 00 " ++ // compression methods + "00 3e " ++ // extensions length + "00 2b 00 03 02 03 03 " ++ // supported versions extension + "00 0d 00 14 00 12 04 03 05 03 08 04 08 05 08 06 08 07 02 01 04 01 05 01 " ++ // signature algorithms extension + "00 0a 00 08 00 06 00 1d 00 17 00 18 " ++ // named groups extension + "00 00 00 0f 00 0d 00 00 0a 67 6f 6f 67 6c 65 2e 63 6f 6d ", // server name extension + ); + try testing.expectEqualSlices(u8, &expected, actual); +} + +test "handshake verify server finished message" { + var buffer: [1024]u8 = undefined; + var rec_rdr = testReader(&data12.server_handshake_finished_msgs); + var h = TestHandshake.init(&buffer, &rec_rdr); + + h.cipher_suite = .ECDHE_ECDSA_WITH_AES_128_CBC_SHA; + h.master_secret = data12.master_secret; + + // add handshake messages to the transcript + for (data12.handshake_messages) |msg| { + h.transcript.update(msg[record.header_len..]); + } + + // expect verify data + const client_finished = h.transcript.clientFinishedTls12(&h.master_secret); + try testing.expectEqualSlices(u8, &data12.client_finished, &record.handshakeHeader(.finished, 12) ++ client_finished); + + // init client with prepared key_material + h.cipher = try Cipher.initTls12(.ECDHE_RSA_WITH_AES_128_CBC_SHA, &data12.key_material, .client); + + // check that server verify data matches calculates from hashes of all handshake messages + h.transcript.update(&data12.client_finished); + try h.readServerFlight2(); +} diff --git a/src/http/async/tls.zig/handshake_common.zig b/src/http/async/tls.zig/handshake_common.zig new file mode 100644 index 00000000..178a3cea --- /dev/null +++ b/src/http/async/tls.zig/handshake_common.zig @@ -0,0 +1,448 @@ +const std = @import("std"); +const assert = std.debug.assert; +const mem = std.mem; +const crypto = std.crypto; +const Certificate = crypto.Certificate; + +const Transcript = @import("transcript.zig").Transcript; +const PrivateKey = @import("PrivateKey.zig"); +const record = @import("record.zig"); +const rsa = @import("rsa/rsa.zig"); +const proto = @import("protocol.zig"); + +const X25519 = crypto.dh.X25519; +const EcdsaP256Sha256 = crypto.sign.ecdsa.EcdsaP256Sha256; +const EcdsaP384Sha384 = crypto.sign.ecdsa.EcdsaP384Sha384; +const Kyber768 = crypto.kem.kyber_d00.Kyber768; + +pub const supported_signature_algorithms = &[_]proto.SignatureScheme{ + .ecdsa_secp256r1_sha256, + .ecdsa_secp384r1_sha384, + .rsa_pss_rsae_sha256, + .rsa_pss_rsae_sha384, + .rsa_pss_rsae_sha512, + .ed25519, + .rsa_pkcs1_sha1, + .rsa_pkcs1_sha256, + .rsa_pkcs1_sha384, +}; + +pub const CertKeyPair = struct { + /// A chain of one or more certificates, leaf first. + /// + /// Each X.509 certificate contains the public key of a key pair, extra + /// information (the name of the holder, the name of an issuer of the + /// certificate, validity time spans) and a signature generated using the + /// private key of the issuer of the certificate. + /// + /// All certificates from the bundle are sent to the other side when creating + /// Certificate tls message. + /// + /// Leaf certificate and private key are used to create signature for + /// CertifyVerify tls message. + bundle: Certificate.Bundle, + + /// Private key corresponding to the public key in leaf certificate from the + /// bundle. + key: PrivateKey, + + pub fn load( + allocator: std.mem.Allocator, + dir: std.fs.Dir, + cert_path: []const u8, + key_path: []const u8, + ) !CertKeyPair { + var bundle: Certificate.Bundle = .{}; + try bundle.addCertsFromFilePath(allocator, dir, cert_path); + + const key_file = try dir.openFile(key_path, .{}); + defer key_file.close(); + const key = try PrivateKey.fromFile(allocator, key_file); + + return .{ .bundle = bundle, .key = key }; + } + + pub fn deinit(c: *CertKeyPair, allocator: std.mem.Allocator) void { + c.bundle.deinit(allocator); + } +}; + +pub const CertBundle = struct { + // A chain of one or more certificates. + // + // They are used to verify that certificate chain sent by the other side + // forms valid trust chain. + bundle: Certificate.Bundle = .{}, + + pub fn fromFile(allocator: std.mem.Allocator, dir: std.fs.Dir, path: []const u8) !CertBundle { + var bundle: Certificate.Bundle = .{}; + try bundle.addCertsFromFilePath(allocator, dir, path); + return .{ .bundle = bundle }; + } + + pub fn fromSystem(allocator: std.mem.Allocator) !CertBundle { + var bundle: Certificate.Bundle = .{}; + try bundle.rescan(allocator); + return .{ .bundle = bundle }; + } + + pub fn deinit(cb: *CertBundle, allocator: std.mem.Allocator) void { + cb.bundle.deinit(allocator); + } +}; + +pub const CertificateBuilder = struct { + bundle: Certificate.Bundle, + key: PrivateKey, + transcript: *Transcript, + tls_version: proto.Version = .tls_1_3, + side: proto.Side = .client, + + pub fn makeCertificate(h: CertificateBuilder, buf: []u8) ![]const u8 { + var w = record.Writer{ .buf = buf }; + const certs = h.bundle.bytes.items; + const certs_count = h.bundle.map.size; + + // Differences between tls 1.3 and 1.2 + // TLS 1.3 has request context in header and extensions for each certificate. + // Here we use empty length for each field. + // TLS 1.2 don't have these two fields. + const request_context, const extensions = if (h.tls_version == .tls_1_3) + .{ &[_]u8{0}, &[_]u8{ 0, 0 } } + else + .{ &[_]u8{}, &[_]u8{} }; + const certs_len = certs.len + (3 + extensions.len) * certs_count; + + // Write handshake header + try w.writeHandshakeHeader(.certificate, certs_len + request_context.len + 3); + try w.write(request_context); + try w.writeInt(@as(u24, @intCast(certs_len))); + + // Write each certificate + var index: u32 = 0; + while (index < certs.len) { + const e = try Certificate.der.Element.parse(certs, index); + const cert = certs[index..e.slice.end]; + try w.writeInt(@as(u24, @intCast(cert.len))); // certificate length + try w.write(cert); // certificate + try w.write(extensions); // certificate extensions + index = e.slice.end; + } + return w.getWritten(); + } + + pub fn makeCertificateVerify(h: CertificateBuilder, buf: []u8) ![]const u8 { + var w = record.Writer{ .buf = buf }; + const signature, const signature_scheme = try h.createSignature(); + try w.writeHandshakeHeader(.certificate_verify, signature.len + 4); + try w.writeEnum(signature_scheme); + try w.writeInt(@as(u16, @intCast(signature.len))); + try w.write(signature); + return w.getWritten(); + } + + /// Creates signature for client certificate signature message. + /// Returns signature bytes and signature scheme. + inline fn createSignature(h: CertificateBuilder) !struct { []const u8, proto.SignatureScheme } { + switch (h.key.signature_scheme) { + inline .ecdsa_secp256r1_sha256, + .ecdsa_secp384r1_sha384, + => |comptime_scheme| { + const Ecdsa = SchemeEcdsa(comptime_scheme); + const key = h.key.key.ecdsa; + const key_len = Ecdsa.SecretKey.encoded_length; + if (key.len < key_len) return error.InvalidEncoding; + const secret_key = try Ecdsa.SecretKey.fromBytes(key[0..key_len].*); + const key_pair = try Ecdsa.KeyPair.fromSecretKey(secret_key); + var signer = try key_pair.signer(null); + h.setSignatureVerifyBytes(&signer); + const signature = try signer.finalize(); + var buf: [Ecdsa.Signature.der_encoded_length_max]u8 = undefined; + return .{ signature.toDer(&buf), comptime_scheme }; + }, + inline .rsa_pss_rsae_sha256, + .rsa_pss_rsae_sha384, + .rsa_pss_rsae_sha512, + => |comptime_scheme| { + const Hash = SchemeHash(comptime_scheme); + var signer = try h.key.key.rsa.signerOaep(Hash, null); + h.setSignatureVerifyBytes(&signer); + var buf: [512]u8 = undefined; + const signature = try signer.finalize(&buf); + return .{ signature.bytes, comptime_scheme }; + }, + else => return error.TlsUnknownSignatureScheme, + } + } + + fn setSignatureVerifyBytes(h: CertificateBuilder, signer: anytype) void { + if (h.tls_version == .tls_1_2) { + // tls 1.2 signature uses current transcript hash value. + // ref: https://datatracker.ietf.org/doc/html/rfc5246.html#section-7.4.8 + const Hash = @TypeOf(signer.h); + signer.h = h.transcript.hash(Hash); + } else { + // tls 1.3 signature is computed over concatenation of 64 spaces, + // context, separator and content. + // ref: https://datatracker.ietf.org/doc/html/rfc8446#section-4.4.3 + if (h.side == .server) { + signer.update(h.transcript.serverCertificateVerify()); + } else { + signer.update(h.transcript.clientCertificateVerify()); + } + } + } + + fn SchemeEcdsa(comptime scheme: proto.SignatureScheme) type { + return switch (scheme) { + .ecdsa_secp256r1_sha256 => EcdsaP256Sha256, + .ecdsa_secp384r1_sha384 => EcdsaP384Sha384, + else => unreachable, + }; + } +}; + +pub const CertificateParser = struct { + pub_key_algo: Certificate.Parsed.PubKeyAlgo = undefined, + pub_key_buf: [600]u8 = undefined, + pub_key: []const u8 = undefined, + + signature_scheme: proto.SignatureScheme = @enumFromInt(0), + signature_buf: [1024]u8 = undefined, + signature: []const u8 = undefined, + + root_ca: Certificate.Bundle, + host: []const u8, + skip_verify: bool = false, + now_sec: i64 = 0, + + pub fn parseCertificate(h: *CertificateParser, d: *record.Decoder, tls_version: proto.Version) !void { + if (h.now_sec == 0) { + h.now_sec = std.time.timestamp(); + } + if (tls_version == .tls_1_3) { + const request_context = try d.decode(u8); + if (request_context != 0) return error.TlsIllegalParameter; + } + + var trust_chain_established = false; + var last_cert: ?Certificate.Parsed = null; + const certs_len = try d.decode(u24); + const start_idx = d.idx; + while (d.idx - start_idx < certs_len) { + const cert_len = try d.decode(u24); + // std.debug.print("=> {} {} {} {}\n", .{ certs_len, d.idx, cert_len, d.payload.len }); + const cert = try d.slice(cert_len); + if (tls_version == .tls_1_3) { + // certificate extensions present in tls 1.3 + try d.skip(try d.decode(u16)); + } + if (trust_chain_established) + continue; + + const subject = try (Certificate{ .buffer = cert, .index = 0 }).parse(); + if (last_cert) |pc| { + if (pc.verify(subject, h.now_sec)) { + last_cert = subject; + } else |err| switch (err) { + error.CertificateIssuerMismatch => { + // skip certificate which is not part of the chain + continue; + }, + else => return err, + } + } else { // first certificate + if (!h.skip_verify and h.host.len > 0) { + try subject.verifyHostName(h.host); + } + h.pub_key = dupe(&h.pub_key_buf, subject.pubKey()); + h.pub_key_algo = subject.pub_key_algo; + last_cert = subject; + } + if (!h.skip_verify) { + if (h.root_ca.verify(last_cert.?, h.now_sec)) |_| { + trust_chain_established = true; + } else |err| switch (err) { + error.CertificateIssuerNotFound => {}, + else => return err, + } + } + } + if (!h.skip_verify and !trust_chain_established) { + return error.CertificateIssuerNotFound; + } + } + + pub fn parseCertificateVerify(h: *CertificateParser, d: *record.Decoder) !void { + h.signature_scheme = try d.decode(proto.SignatureScheme); + h.signature = dupe(&h.signature_buf, try d.slice(try d.decode(u16))); + } + + pub fn verifySignature(h: *CertificateParser, verify_bytes: []const u8) !void { + switch (h.signature_scheme) { + inline .ecdsa_secp256r1_sha256, + .ecdsa_secp384r1_sha384, + => |comptime_scheme| { + if (h.pub_key_algo != .X9_62_id_ecPublicKey) return error.TlsBadSignatureScheme; + const cert_named_curve = h.pub_key_algo.X9_62_id_ecPublicKey; + switch (cert_named_curve) { + inline .secp384r1, .X9_62_prime256v1 => |comptime_cert_named_curve| { + const Ecdsa = SchemeEcdsaCert(comptime_scheme, comptime_cert_named_curve); + const key = try Ecdsa.PublicKey.fromSec1(h.pub_key); + const sig = try Ecdsa.Signature.fromDer(h.signature); + try sig.verify(verify_bytes, key); + }, + else => return error.TlsUnknownSignatureScheme, + } + }, + .ed25519 => { + if (h.pub_key_algo != .curveEd25519) return error.TlsBadSignatureScheme; + const Eddsa = crypto.sign.Ed25519; + if (h.signature.len != Eddsa.Signature.encoded_length) return error.InvalidEncoding; + const sig = Eddsa.Signature.fromBytes(h.signature[0..Eddsa.Signature.encoded_length].*); + if (h.pub_key.len != Eddsa.PublicKey.encoded_length) return error.InvalidEncoding; + const key = try Eddsa.PublicKey.fromBytes(h.pub_key[0..Eddsa.PublicKey.encoded_length].*); + try sig.verify(verify_bytes, key); + }, + inline .rsa_pss_rsae_sha256, + .rsa_pss_rsae_sha384, + .rsa_pss_rsae_sha512, + => |comptime_scheme| { + if (h.pub_key_algo != .rsaEncryption) return error.TlsBadSignatureScheme; + const Hash = SchemeHash(comptime_scheme); + const pk = try rsa.PublicKey.fromDer(h.pub_key); + const sig = rsa.Pss(Hash).Signature{ .bytes = h.signature }; + try sig.verify(verify_bytes, pk, null); + }, + inline .rsa_pkcs1_sha1, + .rsa_pkcs1_sha256, + .rsa_pkcs1_sha384, + .rsa_pkcs1_sha512, + => |comptime_scheme| { + if (h.pub_key_algo != .rsaEncryption) return error.TlsBadSignatureScheme; + const Hash = SchemeHash(comptime_scheme); + const pk = try rsa.PublicKey.fromDer(h.pub_key); + const sig = rsa.PKCS1v1_5(Hash).Signature{ .bytes = h.signature }; + try sig.verify(verify_bytes, pk); + }, + else => return error.TlsUnknownSignatureScheme, + } + } + + fn SchemeEcdsaCert(comptime scheme: proto.SignatureScheme, comptime cert_named_curve: Certificate.NamedCurve) type { + const Sha256 = crypto.hash.sha2.Sha256; + const Sha384 = crypto.hash.sha2.Sha384; + const Ecdsa = crypto.sign.ecdsa.Ecdsa; + + return switch (scheme) { + .ecdsa_secp256r1_sha256 => Ecdsa(cert_named_curve.Curve(), Sha256), + .ecdsa_secp384r1_sha384 => Ecdsa(cert_named_curve.Curve(), Sha384), + else => @compileError("bad scheme"), + }; + } +}; + +fn SchemeHash(comptime scheme: proto.SignatureScheme) type { + const Sha256 = crypto.hash.sha2.Sha256; + const Sha384 = crypto.hash.sha2.Sha384; + const Sha512 = crypto.hash.sha2.Sha512; + + return switch (scheme) { + .rsa_pkcs1_sha1 => crypto.hash.Sha1, + .rsa_pss_rsae_sha256, .rsa_pkcs1_sha256 => Sha256, + .rsa_pss_rsae_sha384, .rsa_pkcs1_sha384 => Sha384, + .rsa_pss_rsae_sha512, .rsa_pkcs1_sha512 => Sha512, + else => @compileError("bad scheme"), + }; +} + +pub fn dupe(buf: []u8, data: []const u8) []u8 { + const n = @min(data.len, buf.len); + @memcpy(buf[0..n], data[0..n]); + return buf[0..n]; +} + +pub const DhKeyPair = struct { + x25519_kp: X25519.KeyPair = undefined, + secp256r1_kp: EcdsaP256Sha256.KeyPair = undefined, + secp384r1_kp: EcdsaP384Sha384.KeyPair = undefined, + kyber768_kp: Kyber768.KeyPair = undefined, + + pub const seed_len = 32 + 32 + 48 + 64; + + pub fn init(seed: [seed_len]u8, named_groups: []const proto.NamedGroup) !DhKeyPair { + var kp: DhKeyPair = .{}; + for (named_groups) |ng| + switch (ng) { + .x25519 => kp.x25519_kp = try X25519.KeyPair.create(seed[0..][0..X25519.seed_length].*), + .secp256r1 => kp.secp256r1_kp = try EcdsaP256Sha256.KeyPair.create(seed[32..][0..EcdsaP256Sha256.KeyPair.seed_length].*), + .secp384r1 => kp.secp384r1_kp = try EcdsaP384Sha384.KeyPair.create(seed[32 + 32 ..][0..EcdsaP384Sha384.KeyPair.seed_length].*), + .x25519_kyber768d00 => kp.kyber768_kp = try Kyber768.KeyPair.create(seed[32 + 32 + 48 ..][0..Kyber768.seed_length].*), + else => return error.TlsIllegalParameter, + }; + return kp; + } + + pub inline fn sharedKey(self: DhKeyPair, named_group: proto.NamedGroup, server_pub_key: []const u8) ![]const u8 { + return switch (named_group) { + .x25519 => brk: { + if (server_pub_key.len != X25519.public_length) + return error.TlsIllegalParameter; + break :brk &(try X25519.scalarmult( + self.x25519_kp.secret_key, + server_pub_key[0..X25519.public_length].*, + )); + }, + .secp256r1 => brk: { + const pk = try EcdsaP256Sha256.PublicKey.fromSec1(server_pub_key); + const mul = try pk.p.mulPublic(self.secp256r1_kp.secret_key.bytes, .big); + break :brk &mul.affineCoordinates().x.toBytes(.big); + }, + .secp384r1 => brk: { + const pk = try EcdsaP384Sha384.PublicKey.fromSec1(server_pub_key); + const mul = try pk.p.mulPublic(self.secp384r1_kp.secret_key.bytes, .big); + break :brk &mul.affineCoordinates().x.toBytes(.big); + }, + .x25519_kyber768d00 => brk: { + const xksl = crypto.dh.X25519.public_length; + const hksl = xksl + Kyber768.ciphertext_length; + if (server_pub_key.len != hksl) + return error.TlsIllegalParameter; + + break :brk &((crypto.dh.X25519.scalarmult( + self.x25519_kp.secret_key, + server_pub_key[0..xksl].*, + ) catch return error.TlsDecryptFailure) ++ (self.kyber768_kp.secret_key.decaps( + server_pub_key[xksl..hksl], + ) catch return error.TlsDecryptFailure)); + }, + else => return error.TlsIllegalParameter, + }; + } + + // Returns 32, 65, 97 or 1216 bytes + pub inline fn publicKey(self: DhKeyPair, named_group: proto.NamedGroup) ![]const u8 { + return switch (named_group) { + .x25519 => &self.x25519_kp.public_key, + .secp256r1 => &self.secp256r1_kp.public_key.toUncompressedSec1(), + .secp384r1 => &self.secp384r1_kp.public_key.toUncompressedSec1(), + .x25519_kyber768d00 => &self.x25519_kp.public_key ++ self.kyber768_kp.public_key.toBytes(), + else => return error.TlsIllegalParameter, + }; + } +}; + +const testing = std.testing; +const testu = @import("testu.zig"); + +test "DhKeyPair.x25519" { + var seed: [DhKeyPair.seed_len]u8 = undefined; + testu.fill(&seed); + const server_pub_key = &testu.hexToBytes("3303486548531f08d91e675caf666c2dc924ac16f47a861a7f4d05919d143637"); + const expected = &testu.hexToBytes( + \\ F1 67 FB 4A 49 B2 91 77 08 29 45 A1 F7 08 5A 21 + \\ AF FE 9E 78 C2 03 9B 81 92 40 72 73 74 7A 46 1E + ); + const kp = try DhKeyPair.init(seed, &.{.x25519}); + try testing.expectEqualSlices(u8, expected, try kp.sharedKey(.x25519, server_pub_key)); +} diff --git a/src/http/async/tls.zig/handshake_server.zig b/src/http/async/tls.zig/handshake_server.zig new file mode 100644 index 00000000..c26e8c69 --- /dev/null +++ b/src/http/async/tls.zig/handshake_server.zig @@ -0,0 +1,520 @@ +const std = @import("std"); +const assert = std.debug.assert; +const crypto = std.crypto; +const mem = std.mem; +const Certificate = crypto.Certificate; + +const cipher = @import("cipher.zig"); +const Cipher = cipher.Cipher; +const CipherSuite = @import("cipher.zig").CipherSuite; +const cipher_suites = @import("cipher.zig").cipher_suites; +const Transcript = @import("transcript.zig").Transcript; +const record = @import("record.zig"); +const PrivateKey = @import("PrivateKey.zig"); +const proto = @import("protocol.zig"); + +const common = @import("handshake_common.zig"); +const dupe = common.dupe; +const CertificateBuilder = common.CertificateBuilder; +const CertificateParser = common.CertificateParser; +const DhKeyPair = common.DhKeyPair; +const CertBundle = common.CertBundle; +const CertKeyPair = common.CertKeyPair; + +pub const Options = struct { + /// Server authentication. If null server will not send Certificate and + /// CertificateVerify message. + auth: ?CertKeyPair, + + /// If not null server will request client certificate. If auth_type is + /// .request empty client certificate message will be accepted. + /// Client certificate will be verified with root_ca certificates. + client_auth: ?ClientAuth = null, +}; + +pub const ClientAuth = struct { + /// Set of root certificate authorities that server use when verifying + /// client certificates. + root_ca: CertBundle, + + auth_type: Type = .require, + + pub const Type = enum { + /// Client certificate will be requested during the handshake, but does + /// not require that the client send any certificates. + request, + /// Client certificate will be requested during the handshake, and client + /// has to send valid certificate. + require, + }; +}; + +pub fn Handshake(comptime Stream: type) type { + const RecordReaderT = record.Reader(Stream); + return struct { + // public key len: x25519 = 32, secp256r1 = 65, secp384r1 = 97 + const max_pub_key_len = 98; + const supported_named_groups = &[_]proto.NamedGroup{ .x25519, .secp256r1, .secp384r1 }; + + server_random: [32]u8 = undefined, + client_random: [32]u8 = undefined, + legacy_session_id_buf: [32]u8 = undefined, + legacy_session_id: []u8 = "", + cipher_suite: CipherSuite = @enumFromInt(0), + signature_scheme: proto.SignatureScheme = @enumFromInt(0), + named_group: proto.NamedGroup = @enumFromInt(0), + client_pub_key_buf: [max_pub_key_len]u8 = undefined, + client_pub_key: []u8 = "", + server_pub_key_buf: [max_pub_key_len]u8 = undefined, + server_pub_key: []u8 = "", + + cipher: Cipher = undefined, + transcript: Transcript = .{}, + rec_rdr: *RecordReaderT, + buffer: []u8, + + const HandshakeT = @This(); + + pub fn init(buf: []u8, rec_rdr: *RecordReaderT) HandshakeT { + return .{ + .rec_rdr = rec_rdr, + .buffer = buf, + }; + } + + fn writeAlert(h: *HandshakeT, stream: Stream, cph: ?*Cipher, err: anyerror) !void { + if (cph) |c| { + const cleartext = proto.alertFromError(err); + const ciphertext = try c.encrypt(h.buffer, .alert, &cleartext); + stream.writeAll(ciphertext) catch {}; + } else { + const alert = record.header(.alert, 2) ++ proto.alertFromError(err); + stream.writeAll(&alert) catch {}; + } + } + + pub fn handshake(h: *HandshakeT, stream: Stream, opt: Options) !Cipher { + crypto.random.bytes(&h.server_random); + if (opt.auth) |a| { + // required signature scheme in client hello + h.signature_scheme = a.key.signature_scheme; + } + + h.readClientHello() catch |err| { + try h.writeAlert(stream, null, err); + return err; + }; + h.transcript.use(h.cipher_suite.hash()); + + const server_flight = brk: { + var w = record.Writer{ .buf = h.buffer }; + + const shared_key = h.sharedKey() catch |err| { + try h.writeAlert(stream, null, err); + return err; + }; + { + const hello = try h.makeServerHello(w.getFree()); + h.transcript.update(hello[record.header_len..]); + w.pos += hello.len; + } + { + const handshake_secret = h.transcript.handshakeSecret(shared_key); + h.cipher = try Cipher.initTls13(h.cipher_suite, handshake_secret, .server); + } + try w.writeRecord(.change_cipher_spec, &[_]u8{1}); + { + const encrypted_extensions = &record.handshakeHeader(.encrypted_extensions, 2) ++ [_]u8{ 0, 0 }; + h.transcript.update(encrypted_extensions); + try h.writeEncrypted(&w, encrypted_extensions); + } + if (opt.client_auth) |_| { + const certificate_request = try makeCertificateRequest(w.getPayload()); + h.transcript.update(certificate_request); + try h.writeEncrypted(&w, certificate_request); + } + if (opt.auth) |a| { + const cm = CertificateBuilder{ + .bundle = a.bundle, + .key = a.key, + .transcript = &h.transcript, + .side = .server, + }; + { + const certificate = try cm.makeCertificate(w.getPayload()); + h.transcript.update(certificate); + try h.writeEncrypted(&w, certificate); + } + { + const certificate_verify = try cm.makeCertificateVerify(w.getPayload()); + h.transcript.update(certificate_verify); + try h.writeEncrypted(&w, certificate_verify); + } + } + { + const finished = try h.makeFinished(w.getPayload()); + h.transcript.update(finished); + try h.writeEncrypted(&w, finished); + } + break :brk w.getWritten(); + }; + try stream.writeAll(server_flight); + + var app_cipher = brk: { + const application_secret = h.transcript.applicationSecret(); + break :brk try Cipher.initTls13(h.cipher_suite, application_secret, .server); + }; + + h.readClientFlight2(opt) catch |err| { + // Alert received from client + if (!mem.startsWith(u8, @errorName(err), "TlsAlert")) { + try h.writeAlert(stream, &app_cipher, err); + } + return err; + }; + return app_cipher; + } + + inline fn sharedKey(h: *HandshakeT) ![]const u8 { + var seed: [DhKeyPair.seed_len]u8 = undefined; + crypto.random.bytes(&seed); + var kp = try DhKeyPair.init(seed, supported_named_groups); + h.server_pub_key = dupe(&h.server_pub_key_buf, try kp.publicKey(h.named_group)); + return try kp.sharedKey(h.named_group, h.client_pub_key); + } + + fn readClientFlight2(h: *HandshakeT, opt: Options) !void { + var cleartext_buf = h.buffer; + var cleartext_buf_head: usize = 0; + var cleartext_buf_tail: usize = 0; + var handshake_state: proto.Handshake = .finished; + var cert: CertificateParser = undefined; + if (opt.client_auth) |client_auth| { + cert = .{ .root_ca = client_auth.root_ca.bundle, .host = "" }; + handshake_state = .certificate; + } + + outer: while (true) { + const rec = (try h.rec_rdr.next() orelse return error.EndOfStream); + if (rec.protocol_version != .tls_1_2 and rec.content_type != .alert) + return error.TlsProtocolVersion; + + switch (rec.content_type) { + .change_cipher_spec => { + if (rec.payload.len != 1) return error.TlsUnexpectedMessage; + }, + .application_data => { + const content_type, const cleartext = try h.cipher.decrypt( + cleartext_buf[cleartext_buf_tail..], + rec, + ); + cleartext_buf_tail += cleartext.len; + if (cleartext_buf_tail > cleartext_buf.len) return error.TlsRecordOverflow; + + var d = record.Decoder.init(content_type, cleartext_buf[cleartext_buf_head..cleartext_buf_tail]); + try d.expectContentType(.handshake); + while (!d.eof()) { + const start_idx = d.idx; + const handshake_type = try d.decode(proto.Handshake); + const length = try d.decode(u24); + + if (length > cipher.max_cleartext_len) + return error.TlsRecordOverflow; + if (length > d.rest().len) + continue :outer; // fragmented handshake into multiple records + + defer { + const handshake_payload = d.payload[start_idx..d.idx]; + h.transcript.update(handshake_payload); + cleartext_buf_head += handshake_payload.len; + } + + if (handshake_state != handshake_type) + return error.TlsUnexpectedMessage; + + switch (handshake_type) { + .certificate => { + if (length == 4) { + // got empty certificate message + if (opt.client_auth.?.auth_type == .require) + return error.TlsCertificateRequired; + try d.skip(length); + handshake_state = .finished; + } else { + try cert.parseCertificate(&d, .tls_1_3); + handshake_state = .certificate_verify; + } + }, + .certificate_verify => { + try cert.parseCertificateVerify(&d); + cert.verifySignature(h.transcript.clientCertificateVerify()) catch |err| return switch (err) { + error.TlsUnknownSignatureScheme => error.TlsIllegalParameter, + else => error.TlsDecryptError, + }; + handshake_state = .finished; + }, + .finished => { + const actual = try d.slice(length); + var buf: [Transcript.max_mac_length]u8 = undefined; + const expected = h.transcript.clientFinishedTls13(&buf); + if (!mem.eql(u8, expected, actual)) + return if (expected.len == actual.len) + error.TlsDecryptError + else + error.TlsDecodeError; + return; + }, + else => return error.TlsUnexpectedMessage, + } + } + cleartext_buf_head = 0; + cleartext_buf_tail = 0; + }, + .alert => { + var d = rec.decoder(); + return d.raiseAlert(); + }, + else => return error.TlsUnexpectedMessage, + } + } + } + + fn makeFinished(h: *HandshakeT, buf: []u8) ![]const u8 { + var w = record.Writer{ .buf = buf }; + const verify_data = h.transcript.serverFinishedTls13(w.getHandshakePayload()); + try w.advanceHandshake(.finished, verify_data.len); + return w.getWritten(); + } + + /// Write encrypted handshake message into `w` + fn writeEncrypted(h: *HandshakeT, w: *record.Writer, cleartext: []const u8) !void { + const ciphertext = try h.cipher.encrypt(w.getFree(), .handshake, cleartext); + w.pos += ciphertext.len; + } + + fn makeServerHello(h: *HandshakeT, buf: []u8) ![]const u8 { + const header_len = 9; // tls record header (5 bytes) and handshake header (4 bytes) + var w = record.Writer{ .buf = buf[header_len..] }; + + try w.writeEnum(proto.Version.tls_1_2); + try w.write(&h.server_random); + { + try w.writeInt(@as(u8, @intCast(h.legacy_session_id.len))); + if (h.legacy_session_id.len > 0) try w.write(h.legacy_session_id); + } + try w.writeEnum(h.cipher_suite); + try w.write(&[_]u8{0}); // compression method + + var e = record.Writer{ .buf = buf[header_len + w.pos + 2 ..] }; + { // supported versions extension + try e.writeEnum(proto.Extension.supported_versions); + try e.writeInt(@as(u16, 2)); + try e.writeEnum(proto.Version.tls_1_3); + } + { // key share extension + const key_len: u16 = @intCast(h.server_pub_key.len); + try e.writeEnum(proto.Extension.key_share); + try e.writeInt(key_len + 4); + try e.writeEnum(h.named_group); + try e.writeInt(key_len); + try e.write(h.server_pub_key); + } + try w.writeInt(@as(u16, @intCast(e.pos))); // extensions length + + const payload_len = w.pos + e.pos; + buf[0..header_len].* = record.header(.handshake, 4 + payload_len) ++ + record.handshakeHeader(.server_hello, payload_len); + + return buf[0 .. header_len + payload_len]; + } + + fn makeCertificateRequest(buf: []u8) ![]const u8 { + // handshake header + context length + extensions length + const header_len = 4 + 1 + 2; + + // First write extensions, leave space for header. + var ext = record.Writer{ .buf = buf[header_len..] }; + try ext.writeExtension(.signature_algorithms, common.supported_signature_algorithms); + + var w = record.Writer{ .buf = buf }; + try w.writeHandshakeHeader(.certificate_request, ext.pos + 3); + try w.writeInt(@as(u8, 0)); // certificate request context length = 0 + try w.writeInt(@as(u16, @intCast(ext.pos))); // extensions length + assert(w.pos == header_len); + w.pos += ext.pos; + + return w.getWritten(); + } + + fn readClientHello(h: *HandshakeT) !void { + var d = try h.rec_rdr.nextDecoder(); + try d.expectContentType(.handshake); + h.transcript.update(d.payload); + + const handshake_type = try d.decode(proto.Handshake); + if (handshake_type != .client_hello) return error.TlsUnexpectedMessage; + _ = try d.decode(u24); // handshake length + if (try d.decode(proto.Version) != .tls_1_2) return error.TlsProtocolVersion; + + h.client_random = try d.array(32); + { // legacy session id + const len = try d.decode(u8); + h.legacy_session_id = dupe(&h.legacy_session_id_buf, try d.slice(len)); + } + { // cipher suites + const end_idx = try d.decode(u16) + d.idx; + + while (d.idx < end_idx) { + const cipher_suite = try d.decode(CipherSuite); + if (cipher_suites.includes(cipher_suites.tls13, cipher_suite) and + @intFromEnum(h.cipher_suite) == 0) + { + h.cipher_suite = cipher_suite; + } + } + if (@intFromEnum(h.cipher_suite) == 0) + return error.TlsHandshakeFailure; + } + try d.skip(2); // compression methods + + var key_share_received = false; + // extensions + const extensions_end_idx = try d.decode(u16) + d.idx; + while (d.idx < extensions_end_idx) { + const extension_type = try d.decode(proto.Extension); + const extension_len = try d.decode(u16); + + switch (extension_type) { + .supported_versions => { + var tls_1_3_supported = false; + const end_idx = try d.decode(u8) + d.idx; + while (d.idx < end_idx) { + if (try d.decode(proto.Version) == proto.Version.tls_1_3) { + tls_1_3_supported = true; + } + } + if (!tls_1_3_supported) return error.TlsProtocolVersion; + }, + .key_share => { + if (extension_len == 0) return error.TlsDecodeError; + key_share_received = true; + var selected_named_group_idx = supported_named_groups.len; + const end_idx = try d.decode(u16) + d.idx; + while (d.idx < end_idx) { + const named_group = try d.decode(proto.NamedGroup); + switch (@intFromEnum(named_group)) { + 0x0001...0x0016, + 0x001a...0x001c, + 0xff01...0xff02, + => return error.TlsIllegalParameter, + else => {}, + } + const client_pub_key = try d.slice(try d.decode(u16)); + for (supported_named_groups, 0..) |supported, idx| { + if (named_group == supported and idx < selected_named_group_idx) { + h.named_group = named_group; + h.client_pub_key = dupe(&h.client_pub_key_buf, client_pub_key); + selected_named_group_idx = idx; + } + } + } + if (@intFromEnum(h.named_group) == 0) + return error.TlsIllegalParameter; + }, + .supported_groups => { + const end_idx = try d.decode(u16) + d.idx; + while (d.idx < end_idx) { + const named_group = try d.decode(proto.NamedGroup); + switch (@intFromEnum(named_group)) { + 0x0001...0x0016, + 0x001a...0x001c, + 0xff01...0xff02, + => return error.TlsIllegalParameter, + else => {}, + } + } + }, + .signature_algorithms => { + if (@intFromEnum(h.signature_scheme) == 0) { + try d.skip(extension_len); + } else { + var found = false; + const list_len = try d.decode(u16); + if (list_len == 0) return error.TlsDecodeError; + const end_idx = list_len + d.idx; + while (d.idx < end_idx) { + const signature_scheme = try d.decode(proto.SignatureScheme); + if (signature_scheme == h.signature_scheme) found = true; + } + if (!found) return error.TlsHandshakeFailure; + } + }, + else => { + try d.skip(extension_len); + }, + } + } + if (!key_share_received) return error.TlsMissingExtension; + if (@intFromEnum(h.named_group) == 0) return error.TlsIllegalParameter; + } + }; +} + +const testing = std.testing; +const data13 = @import("testdata/tls13.zig"); +const testu = @import("testu.zig"); + +fn testReader(data: []const u8) record.Reader(std.io.FixedBufferStream([]const u8)) { + return record.reader(std.io.fixedBufferStream(data)); +} +const TestHandshake = Handshake(std.io.FixedBufferStream([]const u8)); + +test "read client hello" { + var buffer: [1024]u8 = undefined; + var rec_rdr = testReader(&data13.client_hello); + var h = TestHandshake.init(&buffer, &rec_rdr); + h.signature_scheme = .ecdsa_secp521r1_sha512; // this must be supported in signature_algorithms extension + try h.readClientHello(); + + try testing.expectEqual(CipherSuite.AES_256_GCM_SHA384, h.cipher_suite); + try testing.expectEqual(.x25519, h.named_group); + try testing.expectEqualSlices(u8, &data13.client_random, &h.client_random); + try testing.expectEqualSlices(u8, &data13.client_public_key, h.client_pub_key); +} + +test "make server hello" { + var buffer: [128]u8 = undefined; + var h = TestHandshake.init(&buffer, undefined); + h.cipher_suite = .AES_256_GCM_SHA384; + testu.fillFrom(&h.server_random, 0); + testu.fillFrom(&h.server_pub_key_buf, 0x20); + h.named_group = .x25519; + h.server_pub_key = h.server_pub_key_buf[0..32]; + + const actual = try h.makeServerHello(&buffer); + const expected = &testu.hexToBytes( + \\ 16 03 03 00 5a 02 00 00 56 + \\ 03 03 + \\ 00 01 02 03 04 05 06 07 08 09 0a 0b 0c 0d 0e 0f 10 11 12 13 14 15 16 17 18 19 1a 1b 1c 1d 1e 1f + \\ 00 + \\ 13 02 00 + \\ 00 2e 00 2b 00 02 03 04 + \\ 00 33 00 24 00 1d 00 20 + \\ 20 21 22 23 24 25 26 27 28 29 2a 2b 2c 2d 2e 2f 30 31 32 33 34 35 36 37 38 39 3a 3b 3c 3d 3e 3f + ); + try testing.expectEqualSlices(u8, expected, actual); +} + +test "make certificate request" { + var buffer: [32]u8 = undefined; + + const expected = testu.hexToBytes("0d 00 00 1b" ++ // handshake header + "00 00 18" ++ // extension length + "00 0d" ++ // signature algorithms extension + "00 14" ++ // extension length + "00 12" ++ // list length 6 * 2 bytes + "04 03 05 03 08 04 08 05 08 06 08 07 02 01 04 01 05 01" // signature schemes + ); + const actual = try TestHandshake.makeCertificateRequest(&buffer); + try testing.expectEqualSlices(u8, &expected, actual); +} diff --git a/src/http/async/tls.zig/key_log.zig b/src/http/async/tls.zig/key_log.zig new file mode 100644 index 00000000..2da83f42 --- /dev/null +++ b/src/http/async/tls.zig/key_log.zig @@ -0,0 +1,60 @@ +//! Exporting tls key so we can share them with Wireshark and analyze decrypted +//! traffic in Wireshark. +//! To configure Wireshark to use exprted keys see curl reference. +//! +//! References: +//! curl: https://everything.curl.dev/usingcurl/tls/sslkeylogfile.html +//! openssl: https://www.openssl.org/docs/manmaster/man3/SSL_CTX_set_keylog_callback.html +//! https://udn.realityripple.com/docs/Mozilla/Projects/NSS/Key_Log_Format + +const std = @import("std"); + +const key_log_file_env = "SSLKEYLOGFILE"; + +pub const label = struct { + // tls 1.3 + pub const client_handshake_traffic_secret: []const u8 = "CLIENT_HANDSHAKE_TRAFFIC_SECRET"; + pub const server_handshake_traffic_secret: []const u8 = "SERVER_HANDSHAKE_TRAFFIC_SECRET"; + pub const client_traffic_secret_0: []const u8 = "CLIENT_TRAFFIC_SECRET_0"; + pub const server_traffic_secret_0: []const u8 = "SERVER_TRAFFIC_SECRET_0"; + // tls 1.2 + pub const client_random: []const u8 = "CLIENT_RANDOM"; +}; + +pub const Callback = *const fn (label: []const u8, client_random: []const u8, secret: []const u8) void; + +/// Writes tls keys to the file pointed by SSLKEYLOGFILE environment variable. +pub fn callback(label_: []const u8, client_random: []const u8, secret: []const u8) void { + if (std.posix.getenv(key_log_file_env)) |file_name| { + fileAppend(file_name, label_, client_random, secret) catch return; + } +} + +pub fn fileAppend(file_name: []const u8, label_: []const u8, client_random: []const u8, secret: []const u8) !void { + var buf: [1024]u8 = undefined; + const line = try formatLine(&buf, label_, client_random, secret); + try fileWrite(file_name, line); +} + +fn fileWrite(file_name: []const u8, line: []const u8) !void { + var file = try std.fs.createFileAbsolute(file_name, .{ .truncate = false }); + defer file.close(); + const stat = try file.stat(); + try file.seekTo(stat.size); + try file.writeAll(line); +} + +pub fn formatLine(buf: []u8, label_: []const u8, client_random: []const u8, secret: []const u8) ![]const u8 { + var fbs = std.io.fixedBufferStream(buf); + const w = fbs.writer(); + try w.print("{s} ", .{label_}); + for (client_random) |b| { + try std.fmt.formatInt(b, 16, .lower, .{ .width = 2, .fill = '0' }, w); + } + try w.writeByte(' '); + for (secret) |b| { + try std.fmt.formatInt(b, 16, .lower, .{ .width = 2, .fill = '0' }, w); + } + try w.writeByte('\n'); + return fbs.getWritten(); +} diff --git a/src/http/async/tls.zig/main.zig b/src/http/async/tls.zig/main.zig new file mode 100644 index 00000000..b974377b --- /dev/null +++ b/src/http/async/tls.zig/main.zig @@ -0,0 +1,51 @@ +const std = @import("std"); + +pub const CipherSuite = @import("cipher.zig").CipherSuite; +pub const cipher_suites = @import("cipher.zig").cipher_suites; +pub const PrivateKey = @import("PrivateKey.zig"); +pub const Connection = @import("connection.zig").Connection; +pub const ClientOptions = @import("handshake_client.zig").Options; +pub const ServerOptions = @import("handshake_server.zig").Options; +pub const key_log = @import("key_log.zig"); +pub const proto = @import("protocol.zig"); +pub const NamedGroup = proto.NamedGroup; +pub const Version = proto.Version; +const common = @import("handshake_common.zig"); +pub const CertBundle = common.CertBundle; +pub const CertKeyPair = common.CertKeyPair; + +pub const record = @import("record.zig"); +const connection = @import("connection.zig").connection; +const max_ciphertext_record_len = @import("cipher.zig").max_ciphertext_record_len; +const HandshakeServer = @import("handshake_server.zig").Handshake; +const HandshakeClient = @import("handshake_client.zig").Handshake; + +pub fn client(stream: anytype, opt: ClientOptions) !Connection(@TypeOf(stream)) { + const Stream = @TypeOf(stream); + var conn = connection(stream); + var write_buf: [max_ciphertext_record_len]u8 = undefined; + var h = HandshakeClient(Stream).init(&write_buf, &conn.rec_rdr); + conn.cipher = try h.handshake(conn.stream, opt); + return conn; +} + +pub fn server(stream: anytype, opt: ServerOptions) !Connection(@TypeOf(stream)) { + const Stream = @TypeOf(stream); + var conn = connection(stream); + var write_buf: [max_ciphertext_record_len]u8 = undefined; + var h = HandshakeServer(Stream).init(&write_buf, &conn.rec_rdr); + conn.cipher = try h.handshake(conn.stream, opt); + return conn; +} + +test { + _ = @import("handshake_common.zig"); + _ = @import("handshake_server.zig"); + _ = @import("handshake_client.zig"); + + _ = @import("connection.zig"); + _ = @import("cipher.zig"); + _ = @import("record.zig"); + _ = @import("transcript.zig"); + _ = @import("PrivateKey.zig"); +} diff --git a/src/http/async/tls.zig/protocol.zig b/src/http/async/tls.zig/protocol.zig new file mode 100644 index 00000000..e3bb07ac --- /dev/null +++ b/src/http/async/tls.zig/protocol.zig @@ -0,0 +1,302 @@ +pub const Version = enum(u16) { + tls_1_2 = 0x0303, + tls_1_3 = 0x0304, + _, +}; + +pub const ContentType = enum(u8) { + invalid = 0, + change_cipher_spec = 20, + alert = 21, + handshake = 22, + application_data = 23, + _, +}; + +pub const Handshake = enum(u8) { + client_hello = 1, + server_hello = 2, + new_session_ticket = 4, + end_of_early_data = 5, + encrypted_extensions = 8, + certificate = 11, + server_key_exchange = 12, + certificate_request = 13, + server_hello_done = 14, + certificate_verify = 15, + client_key_exchange = 16, + finished = 20, + key_update = 24, + message_hash = 254, + _, +}; + +pub const Curve = enum(u8) { + named_curve = 0x03, + _, +}; + +pub const Extension = enum(u16) { + /// RFC 6066 + server_name = 0, + /// RFC 6066 + max_fragment_length = 1, + /// RFC 6066 + status_request = 5, + /// RFC 8422, 7919 + supported_groups = 10, + /// RFC 8446 + signature_algorithms = 13, + /// RFC 5764 + use_srtp = 14, + /// RFC 6520 + heartbeat = 15, + /// RFC 7301 + application_layer_protocol_negotiation = 16, + /// RFC 6962 + signed_certificate_timestamp = 18, + /// RFC 7250 + client_certificate_type = 19, + /// RFC 7250 + server_certificate_type = 20, + /// RFC 7685 + padding = 21, + /// RFC 8446 + pre_shared_key = 41, + /// RFC 8446 + early_data = 42, + /// RFC 8446 + supported_versions = 43, + /// RFC 8446 + cookie = 44, + /// RFC 8446 + psk_key_exchange_modes = 45, + /// RFC 8446 + certificate_authorities = 47, + /// RFC 8446 + oid_filters = 48, + /// RFC 8446 + post_handshake_auth = 49, + /// RFC 8446 + signature_algorithms_cert = 50, + /// RFC 8446 + key_share = 51, + + _, +}; + +pub fn alertFromError(err: anyerror) [2]u8 { + return [2]u8{ @intFromEnum(Alert.Level.fatal), @intFromEnum(Alert.fromError(err)) }; +} + +pub const Alert = enum(u8) { + pub const Level = enum(u8) { + warning = 1, + fatal = 2, + _, + }; + + pub const Error = error{ + TlsAlertUnexpectedMessage, + TlsAlertBadRecordMac, + TlsAlertRecordOverflow, + TlsAlertHandshakeFailure, + TlsAlertBadCertificate, + TlsAlertUnsupportedCertificate, + TlsAlertCertificateRevoked, + TlsAlertCertificateExpired, + TlsAlertCertificateUnknown, + TlsAlertIllegalParameter, + TlsAlertUnknownCa, + TlsAlertAccessDenied, + TlsAlertDecodeError, + TlsAlertDecryptError, + TlsAlertProtocolVersion, + TlsAlertInsufficientSecurity, + TlsAlertInternalError, + TlsAlertInappropriateFallback, + TlsAlertMissingExtension, + TlsAlertUnsupportedExtension, + TlsAlertUnrecognizedName, + TlsAlertBadCertificateStatusResponse, + TlsAlertUnknownPskIdentity, + TlsAlertCertificateRequired, + TlsAlertNoApplicationProtocol, + TlsAlertUnknown, + }; + + close_notify = 0, + unexpected_message = 10, + bad_record_mac = 20, + record_overflow = 22, + handshake_failure = 40, + bad_certificate = 42, + unsupported_certificate = 43, + certificate_revoked = 44, + certificate_expired = 45, + certificate_unknown = 46, + illegal_parameter = 47, + unknown_ca = 48, + access_denied = 49, + decode_error = 50, + decrypt_error = 51, + protocol_version = 70, + insufficient_security = 71, + internal_error = 80, + inappropriate_fallback = 86, + user_canceled = 90, + missing_extension = 109, + unsupported_extension = 110, + unrecognized_name = 112, + bad_certificate_status_response = 113, + unknown_psk_identity = 115, + certificate_required = 116, + no_application_protocol = 120, + _, + + pub fn toError(alert: Alert) Error!void { + return switch (alert) { + .close_notify => {}, // not an error + .unexpected_message => error.TlsAlertUnexpectedMessage, + .bad_record_mac => error.TlsAlertBadRecordMac, + .record_overflow => error.TlsAlertRecordOverflow, + .handshake_failure => error.TlsAlertHandshakeFailure, + .bad_certificate => error.TlsAlertBadCertificate, + .unsupported_certificate => error.TlsAlertUnsupportedCertificate, + .certificate_revoked => error.TlsAlertCertificateRevoked, + .certificate_expired => error.TlsAlertCertificateExpired, + .certificate_unknown => error.TlsAlertCertificateUnknown, + .illegal_parameter => error.TlsAlertIllegalParameter, + .unknown_ca => error.TlsAlertUnknownCa, + .access_denied => error.TlsAlertAccessDenied, + .decode_error => error.TlsAlertDecodeError, + .decrypt_error => error.TlsAlertDecryptError, + .protocol_version => error.TlsAlertProtocolVersion, + .insufficient_security => error.TlsAlertInsufficientSecurity, + .internal_error => error.TlsAlertInternalError, + .inappropriate_fallback => error.TlsAlertInappropriateFallback, + .user_canceled => {}, // not an error + .missing_extension => error.TlsAlertMissingExtension, + .unsupported_extension => error.TlsAlertUnsupportedExtension, + .unrecognized_name => error.TlsAlertUnrecognizedName, + .bad_certificate_status_response => error.TlsAlertBadCertificateStatusResponse, + .unknown_psk_identity => error.TlsAlertUnknownPskIdentity, + .certificate_required => error.TlsAlertCertificateRequired, + .no_application_protocol => error.TlsAlertNoApplicationProtocol, + _ => error.TlsAlertUnknown, + }; + } + + pub fn fromError(err: anyerror) Alert { + return switch (err) { + error.TlsUnexpectedMessage => .unexpected_message, + error.TlsBadRecordMac => .bad_record_mac, + error.TlsRecordOverflow => .record_overflow, + error.TlsHandshakeFailure => .handshake_failure, + error.TlsBadCertificate => .bad_certificate, + error.TlsUnsupportedCertificate => .unsupported_certificate, + error.TlsCertificateRevoked => .certificate_revoked, + error.TlsCertificateExpired => .certificate_expired, + error.TlsCertificateUnknown => .certificate_unknown, + error.TlsIllegalParameter, + error.IdentityElement, + error.InvalidEncoding, + => .illegal_parameter, + error.TlsUnknownCa => .unknown_ca, + error.TlsAccessDenied => .access_denied, + error.TlsDecodeError => .decode_error, + error.TlsDecryptError => .decrypt_error, + error.TlsProtocolVersion => .protocol_version, + error.TlsInsufficientSecurity => .insufficient_security, + error.TlsInternalError => .internal_error, + error.TlsInappropriateFallback => .inappropriate_fallback, + error.TlsMissingExtension => .missing_extension, + error.TlsUnsupportedExtension => .unsupported_extension, + error.TlsUnrecognizedName => .unrecognized_name, + error.TlsBadCertificateStatusResponse => .bad_certificate_status_response, + error.TlsUnknownPskIdentity => .unknown_psk_identity, + error.TlsCertificateRequired => .certificate_required, + error.TlsNoApplicationProtocol => .no_application_protocol, + else => .internal_error, + }; + } + + pub fn parse(buf: [2]u8) Alert { + const level: Alert.Level = @enumFromInt(buf[0]); + const alert: Alert = @enumFromInt(buf[1]); + _ = level; + return alert; + } + + pub fn closeNotify() [2]u8 { + return [2]u8{ + @intFromEnum(Alert.Level.warning), + @intFromEnum(Alert.close_notify), + }; + } +}; + +pub const SignatureScheme = enum(u16) { + // RSASSA-PKCS1-v1_5 algorithms + rsa_pkcs1_sha256 = 0x0401, + rsa_pkcs1_sha384 = 0x0501, + rsa_pkcs1_sha512 = 0x0601, + + // ECDSA algorithms + ecdsa_secp256r1_sha256 = 0x0403, + ecdsa_secp384r1_sha384 = 0x0503, + ecdsa_secp521r1_sha512 = 0x0603, + + // RSASSA-PSS algorithms with public key OID rsaEncryption + rsa_pss_rsae_sha256 = 0x0804, + rsa_pss_rsae_sha384 = 0x0805, + rsa_pss_rsae_sha512 = 0x0806, + + // EdDSA algorithms + ed25519 = 0x0807, + ed448 = 0x0808, + + // RSASSA-PSS algorithms with public key OID RSASSA-PSS + rsa_pss_pss_sha256 = 0x0809, + rsa_pss_pss_sha384 = 0x080a, + rsa_pss_pss_sha512 = 0x080b, + + // Legacy algorithms + rsa_pkcs1_sha1 = 0x0201, + ecdsa_sha1 = 0x0203, + + _, +}; + +pub const NamedGroup = enum(u16) { + // Elliptic Curve Groups (ECDHE) + secp256r1 = 0x0017, + secp384r1 = 0x0018, + secp521r1 = 0x0019, + x25519 = 0x001D, + x448 = 0x001E, + + // Finite Field Groups (DHE) + ffdhe2048 = 0x0100, + ffdhe3072 = 0x0101, + ffdhe4096 = 0x0102, + ffdhe6144 = 0x0103, + ffdhe8192 = 0x0104, + + // Hybrid post-quantum key agreements + x25519_kyber512d00 = 0xFE30, + x25519_kyber768d00 = 0x6399, + + _, +}; + +pub const KeyUpdateRequest = enum(u8) { + update_not_requested = 0, + update_requested = 1, + _, +}; + +pub const Side = enum { + client, + server, +}; diff --git a/src/http/async/tls.zig/record.zig b/src/http/async/tls.zig/record.zig new file mode 100644 index 00000000..6c4df328 --- /dev/null +++ b/src/http/async/tls.zig/record.zig @@ -0,0 +1,405 @@ +const std = @import("std"); +const assert = std.debug.assert; +const mem = std.mem; + +const proto = @import("protocol.zig"); +const cipher = @import("cipher.zig"); +const Cipher = cipher.Cipher; +const record = @import("record.zig"); + +pub const header_len = 5; + +pub fn header(content_type: proto.ContentType, payload_len: usize) [header_len]u8 { + const int2 = std.crypto.tls.int2; + return [1]u8{@intFromEnum(content_type)} ++ + int2(@intFromEnum(proto.Version.tls_1_2)) ++ + int2(@intCast(payload_len)); +} + +pub fn handshakeHeader(handshake_type: proto.Handshake, payload_len: usize) [4]u8 { + const int3 = std.crypto.tls.int3; + return [1]u8{@intFromEnum(handshake_type)} ++ int3(@intCast(payload_len)); +} + +pub fn reader(inner_reader: anytype) Reader(@TypeOf(inner_reader)) { + return .{ .inner_reader = inner_reader }; +} + +pub fn Reader(comptime InnerReader: type) type { + return struct { + inner_reader: InnerReader, + + buffer: [cipher.max_ciphertext_record_len]u8 = undefined, + start: usize = 0, + end: usize = 0, + + const ReaderT = @This(); + + pub fn nextDecoder(r: *ReaderT) !Decoder { + const rec = (try r.next()) orelse return error.EndOfStream; + if (@intFromEnum(rec.protocol_version) != 0x0300 and + @intFromEnum(rec.protocol_version) != 0x0301 and + rec.protocol_version != .tls_1_2) + return error.TlsBadVersion; + return .{ + .content_type = rec.content_type, + .payload = rec.payload, + }; + } + + pub fn contentType(buf: []const u8) proto.ContentType { + return @enumFromInt(buf[0]); + } + + pub fn protocolVersion(buf: []const u8) proto.Version { + return @enumFromInt(mem.readInt(u16, buf[1..3], .big)); + } + + pub fn next(r: *ReaderT) !?Record { + while (true) { + const buffer = r.buffer[r.start..r.end]; + // If we have 5 bytes header. + if (buffer.len >= record.header_len) { + const record_header = buffer[0..record.header_len]; + const payload_len = mem.readInt(u16, record_header[3..5], .big); + if (payload_len > cipher.max_ciphertext_len) + return error.TlsRecordOverflow; + const record_len = record.header_len + payload_len; + // If we have whole record + if (buffer.len >= record_len) { + r.start += record_len; + return Record.init(buffer[0..record_len]); + } + } + { // Move dirty part to the start of the buffer. + const n = r.end - r.start; + if (n > 0 and r.start > 0) { + if (r.start > n) { + @memcpy(r.buffer[0..n], r.buffer[r.start..][0..n]); + } else { + mem.copyForwards(u8, r.buffer[0..n], r.buffer[r.start..][0..n]); + } + } + r.start = 0; + r.end = n; + } + { // Read more from inner_reader. + const n = try r.inner_reader.read(r.buffer[r.end..]); + if (n == 0) return null; + r.end += n; + } + } + } + + pub fn nextDecrypt(r: *ReaderT, cph: *Cipher) !?struct { proto.ContentType, []const u8 } { + const rec = (try r.next()) orelse return null; + if (rec.protocol_version != .tls_1_2) return error.TlsBadVersion; + + return try cph.decrypt( + // Reuse reader buffer for cleartext. `rec.header` and + // `rec.payload`(ciphertext) are also pointing somewhere in + // this buffer. Decrypter is first reading then writing a + // block, cleartext has less length then ciphertext, + // cleartext starts from the beginning of the buffer, so + // ciphertext is always ahead of cleartext. + r.buffer[0..r.start], + rec, + ); + } + + pub fn hasMore(r: *ReaderT) bool { + return r.end > r.start; + } + }; +} + +pub const Record = struct { + content_type: proto.ContentType, + protocol_version: proto.Version = .tls_1_2, + header: []const u8, + payload: []const u8, + + pub fn init(buffer: []const u8) Record { + return .{ + .content_type = @enumFromInt(buffer[0]), + .protocol_version = @enumFromInt(mem.readInt(u16, buffer[1..3], .big)), + .header = buffer[0..record.header_len], + .payload = buffer[record.header_len..], + }; + } + + pub fn decoder(r: @This()) Decoder { + return Decoder.init(r.content_type, @constCast(r.payload)); + } +}; + +pub const Decoder = struct { + content_type: proto.ContentType, + payload: []const u8, + idx: usize = 0, + + pub fn init(content_type: proto.ContentType, payload: []u8) Decoder { + return .{ + .content_type = content_type, + .payload = payload, + }; + } + + pub fn decode(d: *Decoder, comptime T: type) !T { + switch (@typeInfo(T)) { + .Int => |info| switch (info.bits) { + 8 => { + try skip(d, 1); + return d.payload[d.idx - 1]; + }, + 16 => { + try skip(d, 2); + const b0: u16 = d.payload[d.idx - 2]; + const b1: u16 = d.payload[d.idx - 1]; + return (b0 << 8) | b1; + }, + 24 => { + try skip(d, 3); + const b0: u24 = d.payload[d.idx - 3]; + const b1: u24 = d.payload[d.idx - 2]; + const b2: u24 = d.payload[d.idx - 1]; + return (b0 << 16) | (b1 << 8) | b2; + }, + else => @compileError("unsupported int type: " ++ @typeName(T)), + }, + .Enum => |info| { + const int = try d.decode(info.tag_type); + if (info.is_exhaustive) @compileError("exhaustive enum cannot be used"); + return @as(T, @enumFromInt(int)); + }, + else => @compileError("unsupported type: " ++ @typeName(T)), + } + } + + pub fn array(d: *Decoder, comptime len: usize) ![len]u8 { + try d.skip(len); + return d.payload[d.idx - len ..][0..len].*; + } + + pub fn slice(d: *Decoder, len: usize) ![]const u8 { + try d.skip(len); + return d.payload[d.idx - len ..][0..len]; + } + + pub fn skip(d: *Decoder, amt: usize) !void { + if (d.idx + amt > d.payload.len) return error.TlsDecodeError; + d.idx += amt; + } + + pub fn rest(d: Decoder) []const u8 { + return d.payload[d.idx..]; + } + + pub fn eof(d: Decoder) bool { + return d.idx == d.payload.len; + } + + pub fn expectContentType(d: *Decoder, content_type: proto.ContentType) !void { + if (d.content_type == content_type) return; + + switch (d.content_type) { + .alert => try d.raiseAlert(), + else => return error.TlsUnexpectedMessage, + } + } + + pub fn raiseAlert(d: *Decoder) !void { + if (d.payload.len < 2) return error.TlsUnexpectedMessage; + try proto.Alert.parse(try d.array(2)).toError(); + return error.TlsAlertCloseNotify; + } +}; + +const testing = std.testing; +const data12 = @import("testdata/tls12.zig"); +const testu = @import("testu.zig"); +const CipherSuite = @import("cipher.zig").CipherSuite; + +test Reader { + var fbs = std.io.fixedBufferStream(&data12.server_responses); + var rdr = reader(fbs.reader()); + + const expected = [_]struct { + content_type: proto.ContentType, + payload_len: usize, + }{ + .{ .content_type = .handshake, .payload_len = 49 }, + .{ .content_type = .handshake, .payload_len = 815 }, + .{ .content_type = .handshake, .payload_len = 300 }, + .{ .content_type = .handshake, .payload_len = 4 }, + .{ .content_type = .change_cipher_spec, .payload_len = 1 }, + .{ .content_type = .handshake, .payload_len = 64 }, + }; + for (expected) |e| { + const rec = (try rdr.next()).?; + try testing.expectEqual(e.content_type, rec.content_type); + try testing.expectEqual(e.payload_len, rec.payload.len); + try testing.expectEqual(.tls_1_2, rec.protocol_version); + } +} + +test Decoder { + var fbs = std.io.fixedBufferStream(&data12.server_responses); + var rdr = reader(fbs.reader()); + + var d = (try rdr.nextDecoder()); + try testing.expectEqual(.handshake, d.content_type); + + try testing.expectEqual(.server_hello, try d.decode(proto.Handshake)); + try testing.expectEqual(45, try d.decode(u24)); // length + try testing.expectEqual(.tls_1_2, try d.decode(proto.Version)); + try testing.expectEqualStrings( + &testu.hexToBytes("707172737475767778797a7b7c7d7e7f808182838485868788898a8b8c8d8e8f"), + try d.slice(32), + ); // server random + try testing.expectEqual(0, try d.decode(u8)); // session id len + try testing.expectEqual(.ECDHE_RSA_WITH_AES_128_CBC_SHA, try d.decode(CipherSuite)); + try testing.expectEqual(0, try d.decode(u8)); // compression method + try testing.expectEqual(5, try d.decode(u16)); // extension length + try testing.expectEqual(5, d.rest().len); + try d.skip(5); + try testing.expect(d.eof()); +} + +pub const Writer = struct { + buf: []u8, + pos: usize = 0, + + pub fn write(self: *Writer, data: []const u8) !void { + defer self.pos += data.len; + if (self.pos + data.len > self.buf.len) return error.BufferOverflow; + @memcpy(self.buf[self.pos..][0..data.len], data); + } + + pub fn writeByte(self: *Writer, b: u8) !void { + defer self.pos += 1; + if (self.pos == self.buf.len) return error.BufferOverflow; + self.buf[self.pos] = b; + } + + pub fn writeEnum(self: *Writer, value: anytype) !void { + try self.writeInt(@intFromEnum(value)); + } + + pub fn writeInt(self: *Writer, value: anytype) !void { + const IntT = @TypeOf(value); + const bytes = @divExact(@typeInfo(IntT).Int.bits, 8); + const free = self.buf[self.pos..]; + if (free.len < bytes) return error.BufferOverflow; + mem.writeInt(IntT, free[0..bytes], value, .big); + self.pos += bytes; + } + + pub fn writeHandshakeHeader(self: *Writer, handshake_type: proto.Handshake, payload_len: usize) !void { + try self.write(&record.handshakeHeader(handshake_type, payload_len)); + } + + /// Should be used after writing handshake payload in buffer provided by `getHandshakePayload`. + pub fn advanceHandshake(self: *Writer, handshake_type: proto.Handshake, payload_len: usize) !void { + try self.write(&record.handshakeHeader(handshake_type, payload_len)); + self.pos += payload_len; + } + + /// Record payload is already written by using buffer space from `getPayload`. + /// Now when we know payload len we can write record header and advance over payload. + pub fn advanceRecord(self: *Writer, content_type: proto.ContentType, payload_len: usize) !void { + try self.write(&record.header(content_type, payload_len)); + self.pos += payload_len; + } + + pub fn writeRecord(self: *Writer, content_type: proto.ContentType, payload: []const u8) !void { + try self.write(&record.header(content_type, payload.len)); + try self.write(payload); + } + + /// Preserves space for record header and returns buffer free space. + pub fn getPayload(self: *Writer) []u8 { + return self.buf[self.pos + record.header_len ..]; + } + + /// Preserves space for handshake header and returns buffer free space. + pub fn getHandshakePayload(self: *Writer) []u8 { + return self.buf[self.pos + 4 ..]; + } + + pub fn getWritten(self: *Writer) []const u8 { + return self.buf[0..self.pos]; + } + + pub fn getFree(self: *Writer) []u8 { + return self.buf[self.pos..]; + } + + pub fn writeEnumArray(self: *Writer, comptime E: type, tags: []const E) !void { + assert(@sizeOf(E) == 2); + try self.writeInt(@as(u16, @intCast(tags.len * 2))); + for (tags) |t| { + try self.writeEnum(t); + } + } + + pub fn writeExtension( + self: *Writer, + comptime et: proto.Extension, + tags: anytype, + ) !void { + try self.writeEnum(et); + if (et == .supported_versions) { + try self.writeInt(@as(u16, @intCast(tags.len * 2 + 1))); + try self.writeInt(@as(u8, @intCast(tags.len * 2))); + } else { + try self.writeInt(@as(u16, @intCast(tags.len * 2 + 2))); + try self.writeInt(@as(u16, @intCast(tags.len * 2))); + } + for (tags) |t| { + try self.writeEnum(t); + } + } + + pub fn writeKeyShare( + self: *Writer, + named_groups: []const proto.NamedGroup, + keys: []const []const u8, + ) !void { + assert(named_groups.len == keys.len); + try self.writeEnum(proto.Extension.key_share); + var l: usize = 0; + for (keys) |key| { + l += key.len + 4; + } + try self.writeInt(@as(u16, @intCast(l + 2))); + try self.writeInt(@as(u16, @intCast(l))); + for (named_groups, 0..) |ng, i| { + const key = keys[i]; + try self.writeEnum(ng); + try self.writeInt(@as(u16, @intCast(key.len))); + try self.write(key); + } + } + + pub fn writeServerName(self: *Writer, host: []const u8) !void { + const host_len: u16 = @intCast(host.len); + try self.writeEnum(proto.Extension.server_name); + try self.writeInt(host_len + 5); // byte length of extension payload + try self.writeInt(host_len + 3); // server_name_list byte count + try self.writeByte(0); // name type + try self.writeInt(host_len); + try self.write(host); + } +}; + +test "Writer" { + var buf: [16]u8 = undefined; + var w = Writer{ .buf = &buf }; + + try w.write("ab"); + try w.writeEnum(proto.Curve.named_curve); + try w.writeEnum(proto.NamedGroup.x25519); + try w.writeInt(@as(u16, 0x1234)); + try testing.expectEqualSlices(u8, &[_]u8{ 'a', 'b', 0x03, 0x00, 0x1d, 0x12, 0x34 }, w.getWritten()); +} diff --git a/src/http/async/tls.zig/rsa/der.zig b/src/http/async/tls.zig/rsa/der.zig new file mode 100644 index 00000000..743a65ad --- /dev/null +++ b/src/http/async/tls.zig/rsa/der.zig @@ -0,0 +1,467 @@ +//! An encoding of ASN.1. +//! +//! Distinguised Encoding Rules as defined in X.690 and X.691. +//! +//! A version of Basic Encoding Rules (BER) where there is exactly ONE way to +//! represent non-constructed elements. This is useful for cryptographic signatures. +//! +//! Currently an implementation detail of the standard library not fit for public +//! use since it's missing an encoder. + +const std = @import("std"); +const builtin = @import("builtin"); + +pub const Index = usize; +const log = std.log.scoped(.der); + +/// A secure DER parser that: +/// - Does NOT read memory outside `bytes`. +/// - Does NOT return elements with slices outside `bytes`. +/// - Errors on values that do NOT follow DER rules. +/// - Lengths that could be represented in a shorter form. +/// - Booleans that are not 0xff or 0x00. +pub const Parser = struct { + bytes: []const u8, + index: Index = 0, + + pub const Error = Element.Error || error{ + UnexpectedElement, + InvalidIntegerEncoding, + Overflow, + NonCanonical, + }; + + pub fn expectBool(self: *Parser) Error!bool { + const ele = try self.expect(.universal, false, .boolean); + if (ele.slice.len() != 1) return error.InvalidBool; + + return switch (self.view(ele)[0]) { + 0x00 => false, + 0xff => true, + else => error.InvalidBool, + }; + } + + pub fn expectBitstring(self: *Parser) Error!BitString { + const ele = try self.expect(.universal, false, .bitstring); + const bytes = self.view(ele); + const right_padding = bytes[0]; + if (right_padding >= 8) return error.InvalidBitString; + return .{ + .bytes = bytes[1..], + .right_padding = @intCast(right_padding), + }; + } + + // TODO: return high resolution date time type instead of epoch seconds + pub fn expectDateTime(self: *Parser) Error!i64 { + const ele = try self.expect(.universal, false, null); + const bytes = self.view(ele); + switch (ele.identifier.tag) { + .utc_time => { + // Example: "YYMMDD000000Z" + if (bytes.len != 13) + return error.InvalidDateTime; + if (bytes[12] != 'Z') + return error.InvalidDateTime; + + var date: Date = undefined; + date.year = try parseTimeDigits(bytes[0..2], 0, 99); + date.year += if (date.year >= 50) 1900 else 2000; + date.month = try parseTimeDigits(bytes[2..4], 1, 12); + date.day = try parseTimeDigits(bytes[4..6], 1, 31); + const time = try parseTime(bytes[6..12]); + + return date.toEpochSeconds() + time.toSec(); + }, + .generalized_time => { + // Examples: + // "19920622123421Z" + // "19920722132100.3Z" + if (bytes.len < 15) + return error.InvalidDateTime; + + var date: Date = undefined; + date.year = try parseYear4(bytes[0..4]); + date.month = try parseTimeDigits(bytes[4..6], 1, 12); + date.day = try parseTimeDigits(bytes[6..8], 1, 31); + const time = try parseTime(bytes[8..14]); + + return date.toEpochSeconds() + time.toSec(); + }, + else => return error.InvalidDateTime, + } + } + + pub fn expectOid(self: *Parser) Error![]const u8 { + const oid = try self.expect(.universal, false, .object_identifier); + return self.view(oid); + } + + pub fn expectEnum(self: *Parser, comptime Enum: type) Error!Enum { + const oid = try self.expectOid(); + return Enum.oids.get(oid) orelse { + if (builtin.mode == .Debug) { + var buf: [256]u8 = undefined; + var stream = std.io.fixedBufferStream(&buf); + try @import("./oid.zig").decode(oid, stream.writer()); + log.warn("unknown oid {s} for enum {s}\n", .{ stream.getWritten(), @typeName(Enum) }); + } + return error.UnknownObjectId; + }; + } + + pub fn expectInt(self: *Parser, comptime T: type) Error!T { + const ele = try self.expectPrimitive(.integer); + const bytes = self.view(ele); + + const info = @typeInfo(T); + if (info != .Int) @compileError(@typeName(T) ++ " is not an int type"); + const Shift = std.math.Log2Int(u8); + + var result: std.meta.Int(.unsigned, info.Int.bits) = 0; + for (bytes, 0..) |b, index| { + const shifted = @shlWithOverflow(b, @as(Shift, @intCast(index * 8))); + if (shifted[1] == 1) return error.Overflow; + + result |= shifted[0]; + } + + return @bitCast(result); + } + + pub fn expectString(self: *Parser, allowed: std.EnumSet(String.Tag)) Error!String { + const ele = try self.expect(.universal, false, null); + switch (ele.identifier.tag) { + inline .string_utf8, + .string_numeric, + .string_printable, + .string_teletex, + .string_videotex, + .string_ia5, + .string_visible, + .string_universal, + .string_bmp, + => |t| { + const tagname = @tagName(t)["string_".len..]; + const tag = std.meta.stringToEnum(String.Tag, tagname) orelse unreachable; + if (allowed.contains(tag)) { + return String{ .tag = tag, .data = self.view(ele) }; + } + }, + else => {}, + } + return error.UnexpectedElement; + } + + pub fn expectPrimitive(self: *Parser, tag: ?Identifier.Tag) Error!Element { + var elem = try self.expect(.universal, false, tag); + if (tag == .integer and elem.slice.len() > 0) { + if (self.view(elem)[0] == 0) elem.slice.start += 1; + if (elem.slice.len() > 0 and self.view(elem)[0] == 0) return error.InvalidIntegerEncoding; + } + return elem; + } + + /// Remember to call `expectEnd` + pub fn expectSequence(self: *Parser) Error!Element { + return try self.expect(.universal, true, .sequence); + } + + /// Remember to call `expectEnd` + pub fn expectSequenceOf(self: *Parser) Error!Element { + return try self.expect(.universal, true, .sequence_of); + } + + pub fn expectEnd(self: *Parser, val: usize) Error!void { + if (self.index != val) return error.NonCanonical; // either forgot to parse end OR an attacker + } + + pub fn expect( + self: *Parser, + class: ?Identifier.Class, + constructed: ?bool, + tag: ?Identifier.Tag, + ) Error!Element { + if (self.index >= self.bytes.len) return error.EndOfStream; + + const res = try Element.init(self.bytes, self.index); + if (tag) |e| { + if (res.identifier.tag != e) return error.UnexpectedElement; + } + if (constructed) |e| { + if (res.identifier.constructed != e) return error.UnexpectedElement; + } + if (class) |e| { + if (res.identifier.class != e) return error.UnexpectedElement; + } + self.index = if (res.identifier.constructed) res.slice.start else res.slice.end; + return res; + } + + pub fn view(self: Parser, elem: Element) []const u8 { + return elem.slice.view(self.bytes); + } + + pub fn seek(self: *Parser, index: usize) void { + self.index = index; + } + + pub fn eof(self: *Parser) bool { + return self.index == self.bytes.len; + } +}; + +pub const Element = struct { + identifier: Identifier, + slice: Slice, + + pub const Slice = struct { + start: Index, + end: Index, + + pub fn len(self: Slice) Index { + return self.end - self.start; + } + + pub fn view(self: Slice, bytes: []const u8) []const u8 { + return bytes[self.start..self.end]; + } + }; + + pub const Error = error{ InvalidLength, EndOfStream }; + + pub fn init(bytes: []const u8, index: Index) Error!Element { + var stream = std.io.fixedBufferStream(bytes[index..]); + var reader = stream.reader(); + + const identifier = @as(Identifier, @bitCast(try reader.readByte())); + const size_or_len_size = try reader.readByte(); + + var start = index + 2; + // short form between 0-127 + if (size_or_len_size < 128) { + const end = start + size_or_len_size; + if (end > bytes.len) return error.InvalidLength; + + return .{ .identifier = identifier, .slice = .{ .start = start, .end = end } }; + } + + // long form between 0 and std.math.maxInt(u1024) + const len_size: u7 = @truncate(size_or_len_size); + start += len_size; + if (len_size > @sizeOf(Index)) return error.InvalidLength; + const len = try reader.readVarInt(Index, .big, len_size); + if (len < 128) return error.InvalidLength; // should have used short form + + const end = std.math.add(Index, start, len) catch return error.InvalidLength; + if (end > bytes.len) return error.InvalidLength; + + return .{ .identifier = identifier, .slice = .{ .start = start, .end = end } }; + } +}; + +test Element { + const short_form = [_]u8{ 0x30, 0x03, 0x02, 0x01, 0x09 }; + try std.testing.expectEqual(Element{ + .identifier = Identifier{ .tag = .sequence, .constructed = true, .class = .universal }, + .slice = .{ .start = 2, .end = short_form.len }, + }, Element.init(&short_form, 0)); + + const long_form = [_]u8{ 0x30, 129, 129 } ++ [_]u8{0} ** 129; + try std.testing.expectEqual(Element{ + .identifier = Identifier{ .tag = .sequence, .constructed = true, .class = .universal }, + .slice = .{ .start = 3, .end = long_form.len }, + }, Element.init(&long_form, 0)); +} + +test "parser.expectInt" { + const one = [_]u8{ 2, 1, 1 }; + var parser = Parser{ .bytes = &one }; + try std.testing.expectEqual(@as(u8, 1), try parser.expectInt(u8)); +} + +pub const Identifier = packed struct(u8) { + tag: Tag, + constructed: bool, + class: Class, + + pub const Class = enum(u2) { + universal, + application, + context_specific, + private, + }; + + // https://www.oss.com/asn1/resources/asn1-made-simple/asn1-quick-reference/asn1-tags.html + pub const Tag = enum(u5) { + boolean = 1, + integer = 2, + bitstring = 3, + octetstring = 4, + null = 5, + object_identifier = 6, + real = 9, + enumerated = 10, + string_utf8 = 12, + sequence = 16, + sequence_of = 17, + string_numeric = 18, + string_printable = 19, + string_teletex = 20, + string_videotex = 21, + string_ia5 = 22, + utc_time = 23, + generalized_time = 24, + string_visible = 26, + string_universal = 28, + string_bmp = 30, + _, + }; +}; + +pub const BitString = struct { + bytes: []const u8, + right_padding: u3, + + pub fn bitLen(self: BitString) usize { + return self.bytes.len * 8 + self.right_padding; + } +}; + +pub const String = struct { + tag: Tag, + data: []const u8, + + pub const Tag = enum { + /// Blessed. + utf8, + /// us-ascii ([-][0-9][eE][.])* + numeric, + /// us-ascii ([A-Z][a-z][0-9][.?!,][ \t])* + printable, + /// iso-8859-1 with escaping into different character sets. + /// Cursed. + teletex, + /// iso-8859-1 + videotex, + /// us-ascii first 128 characters. + ia5, + /// us-ascii without control characters. + visible, + /// utf-32-be + universal, + /// utf-16-be + bmp, + }; + + pub const all = [_]Tag{ + .utf8, + .numeric, + .printable, + .teletex, + .videotex, + .ia5, + .visible, + .universal, + .bmp, + }; +}; + +const Date = struct { + year: Year, + month: u8, + day: u8, + + const Year = std.time.epoch.Year; + + fn toEpochSeconds(date: Date) i64 { + // Euclidean Affine Transform by Cassio and Neri. + // Shift and correction constants for 1970-01-01. + const s = 82; + const K = 719468 + 146097 * s; + const L = 400 * s; + + const Y_G: u32 = date.year; + const M_G: u32 = date.month; + const D_G: u32 = date.day; + // Map to computational calendar. + const J: u32 = if (M_G <= 2) 1 else 0; + const Y: u32 = Y_G + L - J; + const M: u32 = if (J != 0) M_G + 12 else M_G; + const D: u32 = D_G - 1; + const C: u32 = Y / 100; + + // Rata die. + const y_star: u32 = 1461 * Y / 4 - C + C / 4; + const m_star: u32 = (979 * M - 2919) / 32; + const N: u32 = y_star + m_star + D; + const days: i32 = @intCast(N - K); + + return @as(i64, days) * std.time.epoch.secs_per_day; + } +}; + +const Time = struct { + hour: std.math.IntFittingRange(0, 24), + minute: std.math.IntFittingRange(0, 60), + second: std.math.IntFittingRange(0, 60), + + fn toSec(t: Time) i64 { + var sec: i64 = 0; + sec += @as(i64, t.hour) * 60 * 60; + sec += @as(i64, t.minute) * 60; + sec += t.second; + return sec; + } +}; + +fn parseTimeDigits( + text: *const [2]u8, + min: comptime_int, + max: comptime_int, +) !std.math.IntFittingRange(min, max) { + const result = std.fmt.parseInt(std.math.IntFittingRange(min, max), text, 10) catch + return error.InvalidTime; + if (result < min) return error.InvalidTime; + if (result > max) return error.InvalidTime; + return result; +} + +test parseTimeDigits { + const expectEqual = std.testing.expectEqual; + try expectEqual(@as(u8, 0), try parseTimeDigits("00", 0, 99)); + try expectEqual(@as(u8, 99), try parseTimeDigits("99", 0, 99)); + try expectEqual(@as(u8, 42), try parseTimeDigits("42", 0, 99)); + + const expectError = std.testing.expectError; + try expectError(error.InvalidTime, parseTimeDigits("13", 1, 12)); + try expectError(error.InvalidTime, parseTimeDigits("00", 1, 12)); + try expectError(error.InvalidTime, parseTimeDigits("Di", 0, 99)); +} + +fn parseYear4(text: *const [4]u8) !Date.Year { + const result = std.fmt.parseInt(Date.Year, text, 10) catch return error.InvalidYear; + if (result > 9999) return error.InvalidYear; + return result; +} + +test parseYear4 { + const expectEqual = std.testing.expectEqual; + try expectEqual(@as(Date.Year, 0), try parseYear4("0000")); + try expectEqual(@as(Date.Year, 9999), try parseYear4("9999")); + try expectEqual(@as(Date.Year, 1988), try parseYear4("1988")); + + const expectError = std.testing.expectError; + try expectError(error.InvalidYear, parseYear4("999b")); + try expectError(error.InvalidYear, parseYear4("crap")); + try expectError(error.InvalidYear, parseYear4("r:bQ")); +} + +fn parseTime(bytes: *const [6]u8) !Time { + return .{ + .hour = try parseTimeDigits(bytes[0..2], 0, 23), + .minute = try parseTimeDigits(bytes[2..4], 0, 59), + .second = try parseTimeDigits(bytes[4..6], 0, 59), + }; +} diff --git a/src/http/async/tls.zig/rsa/oid.zig b/src/http/async/tls.zig/rsa/oid.zig new file mode 100644 index 00000000..fd360c3f --- /dev/null +++ b/src/http/async/tls.zig/rsa/oid.zig @@ -0,0 +1,132 @@ +//! Developed by ITU-U and ISO/IEC for naming objects. Used in DER. +//! +//! This implementation supports any number of `u32` arcs. + +const Arc = u32; +const encoding_base = 128; + +/// Returns encoded length. +pub fn encodeLen(dot_notation: []const u8) !usize { + var split = std.mem.splitScalar(u8, dot_notation, '.'); + if (split.next() == null) return 0; + if (split.next() == null) return 1; + + var res: usize = 1; + while (split.next()) |s| { + const parsed = try std.fmt.parseUnsigned(Arc, s, 10); + const n_bytes = if (parsed == 0) 0 else std.math.log(Arc, encoding_base, parsed); + + res += n_bytes; + res += 1; + } + + return res; +} + +pub const EncodeError = std.fmt.ParseIntError || error{ + MissingPrefix, + BufferTooSmall, +}; + +pub fn encode(dot_notation: []const u8, buf: []u8) EncodeError![]const u8 { + if (buf.len < try encodeLen(dot_notation)) return error.BufferTooSmall; + + var split = std.mem.splitScalar(u8, dot_notation, '.'); + const first_str = split.next() orelse return error.MissingPrefix; + const second_str = split.next() orelse return error.MissingPrefix; + + const first = try std.fmt.parseInt(u8, first_str, 10); + const second = try std.fmt.parseInt(u8, second_str, 10); + + buf[0] = first * 40 + second; + + var i: usize = 1; + while (split.next()) |s| { + var parsed = try std.fmt.parseUnsigned(Arc, s, 10); + const n_bytes = if (parsed == 0) 0 else std.math.log(Arc, encoding_base, parsed); + + for (0..n_bytes) |j| { + const place = std.math.pow(Arc, encoding_base, n_bytes - @as(Arc, @intCast(j))); + const digit: u8 = @intCast(@divFloor(parsed, place)); + + buf[i] = digit | 0x80; + parsed -= digit * place; + + i += 1; + } + buf[i] = @intCast(parsed); + i += 1; + } + + return buf[0..i]; +} + +pub fn decode(encoded: []const u8, writer: anytype) @TypeOf(writer).Error!void { + const first = @divTrunc(encoded[0], 40); + const second = encoded[0] - first * 40; + try writer.print("{d}.{d}", .{ first, second }); + + var i: usize = 1; + while (i != encoded.len) { + const n_bytes: usize = brk: { + var res: usize = 1; + var j: usize = i; + while (encoded[j] & 0x80 != 0) { + res += 1; + j += 1; + } + break :brk res; + }; + + var n: usize = 0; + for (0..n_bytes) |j| { + const place = std.math.pow(usize, encoding_base, n_bytes - j - 1); + n += place * (encoded[i] & 0b01111111); + i += 1; + } + try writer.print(".{d}", .{n}); + } +} + +pub fn encodeComptime(comptime dot_notation: []const u8) [encodeLen(dot_notation) catch unreachable]u8 { + @setEvalBranchQuota(10_000); + var buf: [encodeLen(dot_notation) catch unreachable]u8 = undefined; + _ = encode(dot_notation, &buf) catch unreachable; + return buf; +} + +const std = @import("std"); + +fn testOid(expected_encoded: []const u8, expected_dot_notation: []const u8) !void { + var buf: [256]u8 = undefined; + const encoded = try encode(expected_dot_notation, &buf); + try std.testing.expectEqualSlices(u8, expected_encoded, encoded); + + var stream = std.io.fixedBufferStream(&buf); + try decode(expected_encoded, stream.writer()); + try std.testing.expectEqualStrings(expected_dot_notation, stream.getWritten()); +} + +test "encode and decode" { + // https://learn.microsoft.com/en-us/windows/win32/seccertenroll/about-object-identifier + try testOid( + &[_]u8{ 0x2b, 0x06, 0x01, 0x04, 0x01, 0x82, 0x37, 0x15, 0x14 }, + "1.3.6.1.4.1.311.21.20", + ); + // https://luca.ntop.org/Teaching/Appunti/asn1.html + try testOid(&[_]u8{ 0x2a, 0x86, 0x48, 0x86, 0xf7, 0x0d }, "1.2.840.113549"); + // https://www.sysadmins.lv/blog-en/how-to-encode-object-identifier-to-an-asn1-der-encoded-string.aspx + try testOid(&[_]u8{ 0x2a, 0x86, 0x8d, 0x20 }, "1.2.100000"); + try testOid( + &[_]u8{ 0x2a, 0x86, 0x48, 0x86, 0xf7, 0x0d, 0x01, 0x01, 0x0b }, + "1.2.840.113549.1.1.11", + ); + try testOid(&[_]u8{ 0x2b, 0x65, 0x70 }, "1.3.101.112"); +} + +test encodeComptime { + try std.testing.expectEqual( + [_]u8{ 0x2b, 0x06, 0x01, 0x04, 0x01, 0x82, 0x37, 0x15, 0x14 }, + encodeComptime("1.3.6.1.4.1.311.21.20"), + ); +} diff --git a/src/http/async/tls.zig/rsa/rsa.zig b/src/http/async/tls.zig/rsa/rsa.zig new file mode 100644 index 00000000..5e5f42fe --- /dev/null +++ b/src/http/async/tls.zig/rsa/rsa.zig @@ -0,0 +1,880 @@ +//! RFC8017: Public Key Cryptography Standards #1 v2.2 (PKCS1) +const std = @import("std"); +const der = @import("der.zig"); +const ff = std.crypto.ff; + +pub const max_modulus_bits = 4096; +const max_modulus_len = max_modulus_bits / 8; + +const Modulus = std.crypto.ff.Modulus(max_modulus_bits); +const Fe = Modulus.Fe; + +pub const ValueError = error{ + Modulus, + Exponent, +}; + +pub const PublicKey = struct { + /// `n` + modulus: Modulus, + /// `e` + public_exponent: Fe, + + pub const FromBytesError = ValueError || ff.OverflowError || ff.FieldElementError || ff.InvalidModulusError || error{InsecureBitCount}; + + pub fn fromBytes(mod: []const u8, exp: []const u8) FromBytesError!PublicKey { + const modulus = try Modulus.fromBytes(mod, .big); + if (modulus.bits() <= 512) return error.InsecureBitCount; + const public_exponent = try Fe.fromBytes(modulus, exp, .big); + + if (std.debug.runtime_safety) { + // > the RSA public exponent e is an integer between 3 and n - 1 satisfying + // > GCD(e,\lambda(n)) = 1, where \lambda(n) = LCM(r_1 - 1, ..., r_u - 1) + const e_v = public_exponent.toPrimitive(u32) catch return error.Exponent; + if (!public_exponent.isOdd()) return error.Exponent; + if (e_v < 3) return error.Exponent; + if (modulus.v.compare(public_exponent.v) == .lt) return error.Exponent; + } + + return .{ .modulus = modulus, .public_exponent = public_exponent }; + } + + pub fn fromDer(bytes: []const u8) (der.Parser.Error || FromBytesError)!PublicKey { + var parser = der.Parser{ .bytes = bytes }; + + const seq = try parser.expectSequence(); + defer parser.seek(seq.slice.end); + + const modulus = try parser.expectPrimitive(.integer); + const pub_exp = try parser.expectPrimitive(.integer); + + try parser.expectEnd(seq.slice.end); + try parser.expectEnd(bytes.len); + + return try fromBytes(parser.view(modulus), parser.view(pub_exp)); + } + + /// Deprecated. + /// + /// Encrypt a short message using RSAES-PKCS1-v1_5. + /// The use of this scheme for encrypting an arbitrary message, as opposed to a + /// randomly generated key, is NOT RECOMMENDED. + pub fn encryptPkcsv1_5(pk: PublicKey, msg: []const u8, out: []u8) ![]const u8 { + // align variable names with spec + const k = byteLen(pk.modulus.bits()); + if (out.len < k) return error.BufferTooSmall; + if (msg.len > k - 11) return error.MessageTooLong; + + // EM = 0x00 || 0x02 || PS || 0x00 || M. + var em = out[0..k]; + em[0] = 0; + em[1] = 2; + + const ps = em[2..][0 .. k - msg.len - 3]; + // Section: 7.2.1 + // PS consists of pseudo-randomly generated nonzero octets. + for (ps) |*v| { + v.* = std.crypto.random.uintLessThan(u8, 0xff) + 1; + } + + em[em.len - msg.len - 1] = 0; + @memcpy(em[em.len - msg.len ..][0..msg.len], msg); + + const m = try Fe.fromBytes(pk.modulus, em, .big); + const e = try pk.modulus.powPublic(m, pk.public_exponent); + try e.toBytes(em, .big); + return em; + } + + /// Encrypt a short message using Optimal Asymmetric Encryption Padding (RSAES-OAEP). + pub fn encryptOaep( + pk: PublicKey, + comptime Hash: type, + msg: []const u8, + label: []const u8, + out: []u8, + ) ![]const u8 { + // align variable names with spec + const k = byteLen(pk.modulus.bits()); + if (out.len < k) return error.BufferTooSmall; + + if (msg.len > k - 2 * Hash.digest_length - 2) return error.MessageTooLong; + + // EM = 0x00 || maskedSeed || maskedDB. + var em = out[0..k]; + em[0] = 0; + const seed = em[1..][0..Hash.digest_length]; + std.crypto.random.bytes(seed); + + // DB = lHash || PS || 0x01 || M. + var db = em[1 + seed.len ..]; + const lHash = labelHash(Hash, label); + @memcpy(db[0..lHash.len], &lHash); + @memset(db[lHash.len .. db.len - msg.len - 2], 0); + db[db.len - msg.len - 1] = 1; + @memcpy(db[db.len - msg.len ..], msg); + + var mgf_buf: [max_modulus_len]u8 = undefined; + + const db_mask = mgf1(Hash, seed, mgf_buf[0..db.len]); + for (db, db_mask) |*v, m| v.* ^= m; + + const seed_mask = mgf1(Hash, db, mgf_buf[0..seed.len]); + for (seed, seed_mask) |*v, m| v.* ^= m; + + const m = try Fe.fromBytes(pk.modulus, em, .big); + const e = try pk.modulus.powPublic(m, pk.public_exponent); + try e.toBytes(em, .big); + return em; + } +}; + +pub fn byteLen(bits: usize) usize { + return std.math.divCeil(usize, bits, 8) catch unreachable; +} + +pub const SecretKey = struct { + /// `d` + private_exponent: Fe, + + pub const FromBytesError = ValueError || ff.OverflowError || ff.FieldElementError; + + pub fn fromBytes(n: Modulus, exp: []const u8) FromBytesError!SecretKey { + const d = try Fe.fromBytes(n, exp, .big); + if (std.debug.runtime_safety) { + // > The RSA private exponent d is a positive integer less than n + // > satisfying e * d == 1 (mod \lambda(n)), + if (!d.isOdd()) return error.Exponent; + if (d.v.compare(n.v) != .lt) return error.Exponent; + } + + return .{ .private_exponent = d }; + } +}; + +pub const KeyPair = struct { + public: PublicKey, + secret: SecretKey, + + pub const FromDerError = PublicKey.FromBytesError || SecretKey.FromBytesError || der.Parser.Error || error{ KeyMismatch, InvalidVersion }; + + pub fn fromDer(bytes: []const u8) FromDerError!KeyPair { + var parser = der.Parser{ .bytes = bytes }; + const seq = try parser.expectSequence(); + const version = try parser.expectInt(u8); + + const mod = try parser.expectPrimitive(.integer); + const pub_exp = try parser.expectPrimitive(.integer); + const sec_exp = try parser.expectPrimitive(.integer); + + const public = try PublicKey.fromBytes(parser.view(mod), parser.view(pub_exp)); + const secret = try SecretKey.fromBytes(public.modulus, parser.view(sec_exp)); + + const prime1 = try parser.expectPrimitive(.integer); + const prime2 = try parser.expectPrimitive(.integer); + const exp1 = try parser.expectPrimitive(.integer); + const exp2 = try parser.expectPrimitive(.integer); + const coeff = try parser.expectPrimitive(.integer); + _ = .{ exp1, exp2, coeff }; + + switch (version) { + 0 => {}, + 1 => { + _ = try parser.expectSequenceOf(); + while (!parser.eof()) { + _ = try parser.expectSequence(); + const ri = try parser.expectPrimitive(.integer); + const di = try parser.expectPrimitive(.integer); + const ti = try parser.expectPrimitive(.integer); + _ = .{ ri, di, ti }; + } + }, + else => return error.InvalidVersion, + } + + try parser.expectEnd(seq.slice.end); + try parser.expectEnd(bytes.len); + + if (std.debug.runtime_safety) { + const p = try Fe.fromBytes(public.modulus, parser.view(prime1), .big); + const q = try Fe.fromBytes(public.modulus, parser.view(prime2), .big); + + // check that n = p * q + const expected_zero = public.modulus.mul(p, q); + if (!expected_zero.isZero()) return error.KeyMismatch; + + // TODO: check that d * e is one mod p-1 and mod q-1. Note d and e were bound + // const de = secret.private_exponent.mul(public.public_exponent); + // const one = public.modulus.one(); + + // if (public.modulus.mul(de, p).compare(one) != .eq) return error.KeyMismatch; + // if (public.modulus.mul(de, q).compare(one) != .eq) return error.KeyMismatch; + } + + return .{ .public = public, .secret = secret }; + } + + /// Deprecated. + pub fn signPkcsv1_5(kp: KeyPair, comptime Hash: type, msg: []const u8, out: []u8) !PKCS1v1_5(Hash).Signature { + var st = try signerPkcsv1_5(kp, Hash); + st.update(msg); + return try st.finalize(out); + } + + /// Deprecated. + pub fn signerPkcsv1_5(kp: KeyPair, comptime Hash: type) !PKCS1v1_5(Hash).Signer { + return PKCS1v1_5(Hash).Signer.init(kp); + } + + /// Deprecated. + pub fn decryptPkcsv1_5(kp: KeyPair, ciphertext: []const u8, out: []u8) ![]const u8 { + const k = byteLen(kp.public.modulus.bits()); + if (out.len < k) return error.BufferTooSmall; + + const em = out[0..k]; + + const m = try Fe.fromBytes(kp.public.modulus, ciphertext, .big); + const e = try kp.public.modulus.pow(m, kp.secret.private_exponent); + try e.toBytes(em, .big); + + // Care shall be taken to ensure that an opponent cannot + // distinguish these error conditions, whether by error + // message or timing. + const msg_start = ct.lastIndexOfScalar(em, 0) orelse em.len; + const ps_len = em.len - msg_start; + if (ct.@"or"(em[0] != 0, ct.@"or"(em[1] != 2, ps_len < 8))) { + return error.Inconsistent; + } + + return em[msg_start + 1 ..]; + } + + pub fn signOaep( + kp: KeyPair, + comptime Hash: type, + msg: []const u8, + salt: ?[]const u8, + out: []u8, + ) !Pss(Hash).Signature { + var st = try signerOaep(kp, Hash, salt); + st.update(msg); + return try st.finalize(out); + } + + /// Salt must outlive returned `PSS.Signer`. + pub fn signerOaep(kp: KeyPair, comptime Hash: type, salt: ?[]const u8) !Pss(Hash).Signer { + return Pss(Hash).Signer.init(kp, salt); + } + + pub fn decryptOaep( + kp: KeyPair, + comptime Hash: type, + ciphertext: []const u8, + label: []const u8, + out: []u8, + ) ![]u8 { + // align variable names with spec + const k = byteLen(kp.public.modulus.bits()); + if (out.len < k) return error.BufferTooSmall; + + const mod = try Fe.fromBytes(kp.public.modulus, ciphertext, .big); + const exp = kp.public.modulus.pow(mod, kp.secret.private_exponent) catch unreachable; + const em = out[0..k]; + try exp.toBytes(em, .big); + + const y = em[0]; + const seed = em[1..][0..Hash.digest_length]; + const db = em[1 + Hash.digest_length ..]; + + var mgf_buf: [max_modulus_len]u8 = undefined; + + const seed_mask = mgf1(Hash, db, mgf_buf[0..seed.len]); + for (seed, seed_mask) |*v, m| v.* ^= m; + + const db_mask = mgf1(Hash, seed, mgf_buf[0..db.len]); + for (db, db_mask) |*v, m| v.* ^= m; + + const expected_hash = labelHash(Hash, label); + const actual_hash = db[0..expected_hash.len]; + + // Care shall be taken to ensure that an opponent cannot + // distinguish these error conditions, whether by error + // message or timing. + const msg_start = ct.indexOfScalarPos(em, expected_hash.len + 1, 1) orelse 0; + if (ct.@"or"(y != 0, ct.@"or"(msg_start == 0, !ct.memEql(&expected_hash, actual_hash)))) { + return error.Inconsistent; + } + + return em[msg_start + 1 ..]; + } + + /// Encrypt short plaintext with secret key. + pub fn encrypt(kp: KeyPair, plaintext: []const u8, out: []u8) !void { + const n = kp.public.modulus; + const k = byteLen(n.bits()); + if (plaintext.len > k) return error.MessageTooLong; + + const msg_as_int = try Fe.fromBytes(n, plaintext, .big); + const enc_as_int = try n.pow(msg_as_int, kp.secret.private_exponent); + try enc_as_int.toBytes(out, .big); + } +}; + +/// Deprecated. +/// +/// Signature Scheme with Appendix v1.5 (RSASSA-PKCS1-v1_5) +/// +/// This standard has been superceded by PSS which is formally proven secure +/// and has fewer footguns. +pub fn PKCS1v1_5(comptime Hash: type) type { + return struct { + const PkcsT = @This(); + pub const Signature = struct { + bytes: []const u8, + + const Self = @This(); + + pub fn verifier(self: Self, public_key: PublicKey) !Verifier { + return Verifier.init(self, public_key); + } + + pub fn verify(self: Self, msg: []const u8, public_key: PublicKey) !void { + var st = Verifier.init(self, public_key); + st.update(msg); + return st.verify(); + } + }; + + pub const Signer = struct { + h: Hash, + key_pair: KeyPair, + + fn init(key_pair: KeyPair) Signer { + return .{ + .h = Hash.init(.{}), + .key_pair = key_pair, + }; + } + + pub fn update(self: *Signer, data: []const u8) void { + self.h.update(data); + } + + pub fn finalize(self: *Signer, out: []u8) !PkcsT.Signature { + const k = byteLen(self.key_pair.public.modulus.bits()); + if (out.len < k) return error.BufferTooSmall; + + var hash: [Hash.digest_length]u8 = undefined; + self.h.final(&hash); + + const em = try emsaEncode(hash, out[0..k]); + try self.key_pair.encrypt(em, em); + return .{ .bytes = em }; + } + }; + + pub const Verifier = struct { + h: Hash, + sig: PkcsT.Signature, + public_key: PublicKey, + + fn init(sig: PkcsT.Signature, public_key: PublicKey) Verifier { + return Verifier{ + .h = Hash.init(.{}), + .sig = sig, + .public_key = public_key, + }; + } + + pub fn update(self: *Verifier, data: []const u8) void { + self.h.update(data); + } + + pub fn verify(self: *Verifier) !void { + const pk = self.public_key; + const s = try Fe.fromBytes(pk.modulus, self.sig.bytes, .big); + const emm = try pk.modulus.powPublic(s, pk.public_exponent); + + var em_buf: [max_modulus_len]u8 = undefined; + const em = em_buf[0..byteLen(pk.modulus.bits())]; + try emm.toBytes(em, .big); + + var hash: [Hash.digest_length]u8 = undefined; + self.h.final(&hash); + + // TODO: compare hash values instead of emsa values + const expected = try emsaEncode(hash, em); + + if (!std.mem.eql(u8, expected, em)) return error.Inconsistent; + } + }; + + /// PKCS Encrypted Message Signature Appendix + fn emsaEncode(hash: [Hash.digest_length]u8, out: []u8) ![]u8 { + const digest_header = comptime digestHeader(); + const tLen = digest_header.len + Hash.digest_length; + const emLen = out.len; + if (emLen < tLen + 11) return error.ModulusTooShort; + if (out.len < emLen) return error.BufferTooSmall; + + var res = out[0..emLen]; + res[0] = 0; + res[1] = 1; + const padding_len = emLen - tLen - 3; + @memset(res[2..][0..padding_len], 0xff); + res[2 + padding_len] = 0; + @memcpy(res[2 + padding_len + 1 ..][0..digest_header.len], digest_header); + @memcpy(res[res.len - hash.len ..], &hash); + + return res; + } + + /// DER encoded header. Sequence of digest algo + digest. + /// TODO: use a DER encoder instead + fn digestHeader() []const u8 { + const sha2 = std.crypto.hash.sha2; + // Section 9.2 Notes 1. + return switch (Hash) { + std.crypto.hash.Sha1 => &hexToBytes( + \\30 21 30 09 06 05 2b 0e 03 02 1a 05 00 04 14 + ), + sha2.Sha224 => &hexToBytes( + \\30 2d 30 0d 06 09 60 86 48 01 65 03 04 02 04 + \\05 00 04 1c + ), + sha2.Sha256 => &hexToBytes( + \\30 31 30 0d 06 09 60 86 48 01 65 03 04 02 01 05 00 + \\04 20 + ), + sha2.Sha384 => &hexToBytes( + \\30 41 30 0d 06 09 60 86 48 01 65 03 04 02 02 05 00 + \\04 30 + ), + sha2.Sha512 => &hexToBytes( + \\30 51 30 0d 06 09 60 86 48 01 65 03 04 02 03 05 00 + \\04 40 + ), + // sha2.Sha512224 => &hexToBytes( + // \\30 2d 30 0d 06 09 60 86 48 01 65 03 04 02 05 + // \\05 00 04 1c + // ), + // sha2.Sha512256 => &hexToBytes( + // \\30 31 30 0d 06 09 60 86 48 01 65 03 04 02 06 + // \\05 00 04 20 + // ), + else => @compileError("unknown Hash " ++ @typeName(Hash)), + }; + } + }; +} + +/// Probabilistic Signature Scheme (RSASSA-PSS) +pub fn Pss(comptime Hash: type) type { + // RFC 4055 S3.1 + const default_salt_len = Hash.digest_length; + return struct { + pub const Signature = struct { + bytes: []const u8, + + const Self = @This(); + + pub fn verifier(self: Self, public_key: PublicKey) !Verifier { + return Verifier.init(self, public_key); + } + + pub fn verify(self: Self, msg: []const u8, public_key: PublicKey, salt_len: ?usize) !void { + var st = Verifier.init(self, public_key, salt_len orelse default_salt_len); + st.update(msg); + return st.verify(); + } + }; + + const PssT = @This(); + + pub const Signer = struct { + h: Hash, + key_pair: KeyPair, + salt: ?[]const u8, + + fn init(key_pair: KeyPair, salt: ?[]const u8) Signer { + return .{ + .h = Hash.init(.{}), + .key_pair = key_pair, + .salt = salt, + }; + } + + pub fn update(self: *Signer, data: []const u8) void { + self.h.update(data); + } + + pub fn finalize(self: *Signer, out: []u8) !PssT.Signature { + var hashed: [Hash.digest_length]u8 = undefined; + self.h.final(&hashed); + + const salt = if (self.salt) |s| s else brk: { + var res: [default_salt_len]u8 = undefined; + std.crypto.random.bytes(&res); + break :brk &res; + }; + + const em_bits = self.key_pair.public.modulus.bits() - 1; + const em = try emsaEncode(hashed, salt, em_bits, out); + try self.key_pair.encrypt(em, em); + return .{ .bytes = em }; + } + }; + + pub const Verifier = struct { + h: Hash, + sig: PssT.Signature, + public_key: PublicKey, + salt_len: usize, + + fn init(sig: PssT.Signature, public_key: PublicKey, salt_len: usize) Verifier { + return Verifier{ + .h = Hash.init(.{}), + .sig = sig, + .public_key = public_key, + .salt_len = salt_len, + }; + } + + pub fn update(self: *Verifier, data: []const u8) void { + self.h.update(data); + } + + pub fn verify(self: *Verifier) !void { + const pk = self.public_key; + const s = try Fe.fromBytes(pk.modulus, self.sig.bytes, .big); + const emm = try pk.modulus.powPublic(s, pk.public_exponent); + + var em_buf: [max_modulus_len]u8 = undefined; + const em_bits = pk.modulus.bits() - 1; + const em_len = std.math.divCeil(usize, em_bits, 8) catch unreachable; + var em = em_buf[0..em_len]; + try emm.toBytes(em, .big); + + if (em.len < Hash.digest_length + self.salt_len + 2) return error.Inconsistent; + if (em[em.len - 1] != 0xbc) return error.Inconsistent; + + const db = em[0 .. em.len - Hash.digest_length - 1]; + if (@clz(db[0]) < em.len * 8 - em_bits) return error.Inconsistent; + + const expected_hash = em[db.len..][0..Hash.digest_length]; + var mgf_buf: [max_modulus_len]u8 = undefined; + const db_mask = mgf1(Hash, expected_hash, mgf_buf[0..db.len]); + for (db, db_mask) |*v, m| v.* ^= m; + + for (1..db.len - self.salt_len - 1) |i| { + if (db[i] != 0) return error.Inconsistent; + } + if (db[db.len - self.salt_len - 1] != 1) return error.Inconsistent; + const salt = db[db.len - self.salt_len ..]; + var mp_buf: [max_modulus_len]u8 = undefined; + var mp = mp_buf[0 .. 8 + Hash.digest_length + self.salt_len]; + @memset(mp[0..8], 0); + self.h.final(mp[8..][0..Hash.digest_length]); + @memcpy(mp[8 + Hash.digest_length ..][0..salt.len], salt); + + var actual_hash: [Hash.digest_length]u8 = undefined; + Hash.hash(mp, &actual_hash, .{}); + + if (!std.mem.eql(u8, expected_hash, &actual_hash)) return error.Inconsistent; + } + }; + + /// PSS Encrypted Message Signature Appendix + fn emsaEncode(msg_hash: [Hash.digest_length]u8, salt: []const u8, em_bits: usize, out: []u8) ![]u8 { + const em_len = std.math.divCeil(usize, em_bits, 8) catch unreachable; + + if (em_len < Hash.digest_length + salt.len + 2) return error.Encoding; + + // EM = maskedDB || H || 0xbc + var em = out[0..em_len]; + em[em.len - 1] = 0xbc; + + var mp_buf: [max_modulus_len]u8 = undefined; + // M' = (0x)00 00 00 00 00 00 00 00 || mHash || salt; + const mp = mp_buf[0 .. 8 + Hash.digest_length + salt.len]; + @memset(mp[0..8], 0); + @memcpy(mp[8..][0..Hash.digest_length], &msg_hash); + @memcpy(mp[8 + Hash.digest_length ..][0..salt.len], salt); + + // H = Hash(M') + const hash = em[em.len - 1 - Hash.digest_length ..][0..Hash.digest_length]; + Hash.hash(mp, hash, .{}); + + // DB = PS || 0x01 || salt + var db = em[0 .. em_len - Hash.digest_length - 1]; + @memset(db[0 .. db.len - salt.len - 1], 0); + db[db.len - salt.len - 1] = 1; + @memcpy(db[db.len - salt.len ..], salt); + + var mgf_buf: [max_modulus_len]u8 = undefined; + const db_mask = mgf1(Hash, hash, mgf_buf[0..db.len]); + for (db, db_mask) |*v, m| v.* ^= m; + + // Set the leftmost 8emLen - emBits bits of the leftmost octet + // in maskedDB to zero. + const shift = std.math.comptimeMod(8 * em_len - em_bits, 8); + const mask = @as(u8, 0xff) >> shift; + db[0] &= mask; + + return em; + } + }; +} + +/// Mask generation function. Currently the only one defined. +fn mgf1(comptime Hash: type, seed: []const u8, out: []u8) []u8 { + var c: [@sizeOf(u32)]u8 = undefined; + var tmp: [Hash.digest_length]u8 = undefined; + + var i: usize = 0; + var counter: u32 = 0; + while (i < out.len) : (counter += 1) { + var hasher = Hash.init(.{}); + hasher.update(seed); + std.mem.writeInt(u32, &c, counter, .big); + hasher.update(&c); + + const left = out.len - i; + if (left >= Hash.digest_length) { + // optimization: write straight to `out` + hasher.final(out[i..][0..Hash.digest_length]); + i += Hash.digest_length; + } else { + hasher.final(&tmp); + @memcpy(out[i..][0..left], tmp[0..left]); + i += left; + } + } + + return out; +} + +test mgf1 { + const Hash = std.crypto.hash.sha2.Sha256; + var out: [Hash.digest_length * 2 + 1]u8 = undefined; + try std.testing.expectEqualSlices( + u8, + &hexToBytes( + \\ed 1b 84 6b b9 26 39 00 c8 17 82 ad 08 eb 17 01 + \\fa 8c 72 21 c6 57 63 77 31 7f 5c e8 09 89 9f + ), + mgf1(Hash, "asdf", out[0 .. Hash.digest_length - 1]), + ); + try std.testing.expectEqualSlices( + u8, + &hexToBytes( + \\ed 1b 84 6b b9 26 39 00 c8 17 82 ad 08 eb 17 01 + \\fa 8c 72 21 c6 57 63 77 31 7f 5c e8 09 89 9f 5a + \\22 F2 80 D5 28 08 F4 93 83 76 00 DE 09 E4 EC 92 + \\4A 2C 7C EF 0D F7 7B BE 8F 7F 12 CB 8F 33 A6 65 + \\AB + ), + mgf1(Hash, "asdf", &out), + ); +} + +/// For OAEP. +inline fn labelHash(comptime Hash: type, label: []const u8) [Hash.digest_length]u8 { + if (label.len == 0) { + // magic constants from NIST + const sha2 = std.crypto.hash.sha2; + switch (Hash) { + std.crypto.hash.Sha1 => return hexToBytes( + \\da39a3ee 5e6b4b0d 3255bfef 95601890 + \\afd80709 + ), + sha2.Sha256 => return hexToBytes( + \\e3b0c442 98fc1c14 9afbf4c8 996fb924 + \\27ae41e4 649b934c a495991b 7852b855 + ), + sha2.Sha384 => return hexToBytes( + \\38b060a7 51ac9638 4cd9327e b1b1e36a + \\21fdb711 14be0743 4c0cc7bf 63f6e1da + \\274edebf e76f65fb d51ad2f1 4898b95b + ), + sha2.Sha512 => return hexToBytes( + \\cf83e135 7eefb8bd f1542850 d66d8007 + \\d620e405 0b5715dc 83f4a921 d36ce9ce + \\47d0d13c 5d85f2b0 ff8318d2 877eec2f + \\63b931bd 47417a81 a538327a f927da3e + ), + // just use the empty hash... + else => {}, + } + } + var res: [Hash.digest_length]u8 = undefined; + Hash.hash(label, &res, .{}); + return res; +} + +const ct = if (std.options.side_channels_mitigations == .none) ct_unprotected else ct_protected; + +const ct_unprotected = struct { + fn lastIndexOfScalar(slice: []const u8, value: u8) ?usize { + return std.mem.lastIndexOfScalar(u8, slice, value); + } + + fn indexOfScalarPos(slice: []const u8, start_index: usize, value: u8) ?usize { + return std.mem.indexOfScalarPos(u8, slice, start_index, value); + } + + fn memEql(a: []const u8, b: []const u8) bool { + return std.mem.eql(u8, a, b); + } + + fn @"and"(a: bool, b: bool) bool { + return a and b; + } + + fn @"or"(a: bool, b: bool) bool { + return a or b; + } +}; + +const ct_protected = struct { + fn lastIndexOfScalar(slice: []const u8, value: u8) ?usize { + var res: ?usize = null; + var i: usize = slice.len; + while (i != 0) { + i -= 1; + if (@intFromBool(res == null) & @intFromBool(slice[i] == value) == 1) res = i; + } + return res; + } + + fn indexOfScalarPos(slice: []const u8, start_index: usize, value: u8) ?usize { + var res: ?usize = null; + for (slice[start_index..], start_index..) |c, j| { + if (c == value) res = j; + } + return res; + } + + fn memEql(a: []const u8, b: []const u8) bool { + var res: u1 = 1; + for (a, b) |a_elem, b_elem| { + res &= @intFromBool(a_elem == b_elem); + } + return res == 1; + } + + fn @"and"(a: bool, b: bool) bool { + return (@intFromBool(a) & @intFromBool(b)) == 1; + } + + fn @"or"(a: bool, b: bool) bool { + return (@intFromBool(a) | @intFromBool(b)) == 1; + } +}; + +test ct { + const c = ct_unprotected; + try std.testing.expectEqual(true, c.@"or"(true, false)); + try std.testing.expectEqual(true, c.@"and"(true, true)); + try std.testing.expectEqual(true, c.memEql("Asdf", "Asdf")); + try std.testing.expectEqual(false, c.memEql("asdf", "Asdf")); + try std.testing.expectEqual(3, c.indexOfScalarPos("asdff", 1, 'f')); + try std.testing.expectEqual(4, c.lastIndexOfScalar("asdff", 'f')); +} + +fn removeNonHex(comptime hex: []const u8) []const u8 { + var res: [hex.len]u8 = undefined; + var i: usize = 0; + for (hex) |c| { + if (std.ascii.isHex(c)) { + res[i] = c; + i += 1; + } + } + return res[0..i]; +} + +/// For readable copy/pasting from hex viewers. +fn hexToBytes(comptime hex: []const u8) [removeNonHex(hex).len / 2]u8 { + const hex2 = comptime removeNonHex(hex); + comptime var res: [hex2.len / 2]u8 = undefined; + _ = comptime std.fmt.hexToBytes(&res, hex2) catch unreachable; + return res; +} + +test hexToBytes { + const hex = + \\e3b0c442 98fc1c14 9afbf4c8 996fb924 + \\27ae41e4 649b934c a495991b 7852b855 + ; + try std.testing.expectEqual( + [_]u8{ + 0xe3, 0xb0, 0xc4, 0x42, 0x98, 0xfc, 0x1c, 0x14, + 0x9a, 0xfb, 0xf4, 0xc8, 0x99, 0x6f, 0xb9, 0x24, + 0x27, 0xae, 0x41, 0xe4, 0x64, 0x9b, 0x93, 0x4c, + 0xa4, 0x95, 0x99, 0x1b, 0x78, 0x52, 0xb8, 0x55, + }, + hexToBytes(hex), + ); +} + +const TestHash = std.crypto.hash.sha2.Sha256; +fn testKeypair() !KeyPair { + const keypair_bytes = @embedFile("testdata/id_rsa.der"); + const kp = try KeyPair.fromDer(keypair_bytes); + try std.testing.expectEqual(2048, kp.public.modulus.bits()); + return kp; +} + +test "rsa PKCS1-v1_5 encrypt and decrypt" { + const kp = try testKeypair(); + + const msg = "rsa PKCS1-v1_5 encrypt and decrypt"; + var out: [max_modulus_len]u8 = undefined; + const enc = try kp.public.encryptPkcsv1_5(msg, &out); + + var out2: [max_modulus_len]u8 = undefined; + const dec = try kp.decryptPkcsv1_5(enc, &out2); + + try std.testing.expectEqualSlices(u8, msg, dec); +} + +test "rsa OAEP encrypt and decrypt" { + const kp = try testKeypair(); + + const msg = "rsa OAEP encrypt and decrypt"; + const label = ""; + var out: [max_modulus_len]u8 = undefined; + const enc = try kp.public.encryptOaep(TestHash, msg, label, &out); + + var out2: [max_modulus_len]u8 = undefined; + const dec = try kp.decryptOaep(TestHash, enc, label, &out2); + + try std.testing.expectEqualSlices(u8, msg, dec); +} + +test "rsa PKCS1-v1_5 signature" { + const kp = try testKeypair(); + + const msg = "rsa PKCS1-v1_5 signature"; + var out: [max_modulus_len]u8 = undefined; + + const signature = try kp.signPkcsv1_5(TestHash, msg, &out); + try signature.verify(msg, kp.public); +} + +test "rsa PSS signature" { + const kp = try testKeypair(); + + const msg = "rsa PSS signature"; + var out: [max_modulus_len]u8 = undefined; + + const salts = [_][]const u8{ "asdf", "" }; + for (salts) |salt| { + const signature = try kp.signOaep(TestHash, msg, salt, &out); + try signature.verify(msg, kp.public, salt.len); + } + + const signature = try kp.signOaep(TestHash, msg, null, &out); // random salt + try signature.verify(msg, kp.public, null); +} diff --git a/src/http/async/tls.zig/rsa/testdata/id_rsa.der b/src/http/async/tls.zig/rsa/testdata/id_rsa.der new file mode 100644 index 0000000000000000000000000000000000000000..9e4f1334d16264ca8accea1d9f7212da6a14554a GIT binary patch literal 1191 zcmV;Y1X%kpf&`-i0RRGm0RaHSfHnrY=SOP@lmzUjwvhxs_maFB?)!ao*QgBu9(zkV zO6Cvfz;XO@=K@R(y!5@%9XV^da7IcK=}P!L^Wh0uRC~!)`#~+Ec2W`H^W1lAs#7;^ z$~x@6!>YGCG1Y9gQk;O8yvg7w7~%`}_@Fxd7X(nA&Uw9`Iq~Xg>_?X_gAcXJmEM)1 z<^&?u?!HoaRH5g;iiY+^Z4I9ml^RU`K|Im+mAD~hY(e&`u&UtVtGUCd4b^x|AUT0|5X50)hbmVS71J9cG^gdO7qra{M0j{KnM;d)X4|Dh$$5 zpuK)<(|%N=kJIS5RL4bJb#zDd;>Vn{)X3fy;mX=(OjWQ=^8)r|>p{8`>I2BL7G&f4 zTJ9uTTbuWvdjpoOPq1k)g+g{=Rb*1Z;Uaedc>>;gWfTSv__%efoKYgOC)y9G5#U>) zX4nJos0`bTLHEODsw7TAd zZaNEJYbv|;_HfEL^^x7ZFckyx4;#r{Io*=6Z9}3uriSet?;rJ}1Dtq0BZU5Hi zQ2$bCHjw`G*=@EQ{sO+TQM3ZngpvmKZIi)B<^5nTWPodhAR8k0Hx zRF$oWcryh&arHHJi6s7g-`0Dx>>O0zV}G!2mkbOQ-yu{=4O|BV%p8Sr{rQ_^#~t`> zf?9e9bV^_}9}jMUVFxw}5)x&pGh02vD*60E<8!x0uv>GxJM5R&>2$QQ4Tt~%+};}5 z0)c=)%%UKzS~2V}a|gnxjvGVeIuV_4^!R_w=;DyFwofV-Q6uIo%d2X5fjbKeVWjqAX@$;2c6-1ooGbC%q?qgq z&ubOB)zcw4OL0tehSLK07lV!!1Rl;&ya=HJfq*-h@&u~`pTO^q@QyjGbrp)>GZQdH zcz7kX1Vj5y4#LIWXWtXaiHK{S62>>_jGa;(zih-YEkAhHYt}#X2GJeI?Z^wED){4A z(W%+sitXpW`9b2KDu2r)_+LLHh&*l-n*lh1Z!PcJ-#DyiQth7%Jy~5wD5a_T5y=1u F(f{FrO=AE6 literal 0 HcmV?d00001 diff --git a/src/http/async/tls.zig/testdata/ec_prime256v1_private_key.pem b/src/http/async/tls.zig/testdata/ec_prime256v1_private_key.pem new file mode 100644 index 00000000..67ebf388 --- /dev/null +++ b/src/http/async/tls.zig/testdata/ec_prime256v1_private_key.pem @@ -0,0 +1,5 @@ +-----BEGIN EC PRIVATE KEY----- +MHcCAQEEINJSRKv8kSKEzLHptfAlg+LGh4/pHHlq0XLf30Q9pcztoAoGCCqGSM49 +AwEHoUQDQgAEJpmLyp8aGCgyMcFIJaIq/+4V1K6nPpeoih3bT2npeplF9eyXj7rm +8eW9Ua6VLhq71mqtMC+YLm+IkORBVq1cuA== +-----END EC PRIVATE KEY----- diff --git a/src/http/async/tls.zig/testdata/ec_private_key.pem b/src/http/async/tls.zig/testdata/ec_private_key.pem new file mode 100644 index 00000000..95048aaa --- /dev/null +++ b/src/http/async/tls.zig/testdata/ec_private_key.pem @@ -0,0 +1,6 @@ +-----BEGIN PRIVATE KEY----- +MIG2AgEAMBAGByqGSM49AgEGBSuBBAAiBIGeMIGbAgEBBDAQNT3KGxUdBqpxuO/z +GSJDePMgmB6xLytkfnHQMCqQquXrmcOQZT3BJhm+PwggmwGhZANiAATKxBc6kfqA +piA+Z0rIjVwaZaBNGnP4UZ5TqVewQ/dP9/BQCca2SJpsXauGLcUPmK4sKFQxGe6d +fzq9O50lo7qHEOIpwDBdhRp+oqB6sN2hMtCPbp6eyzsUlm3FUyhN9D0= +-----END PRIVATE KEY----- diff --git a/src/http/async/tls.zig/testdata/ec_secp384r1_private_key.pem b/src/http/async/tls.zig/testdata/ec_secp384r1_private_key.pem new file mode 100644 index 00000000..62eac9ee --- /dev/null +++ b/src/http/async/tls.zig/testdata/ec_secp384r1_private_key.pem @@ -0,0 +1,6 @@ +-----BEGIN EC PRIVATE KEY----- +MIGkAgEBBDDubYpeDdOwxksyQIDiOt6LHt3ikts2HNuR6rqhBg1CLdmp3AVDKfF4 +fPkIr8UDH22gBwYFK4EEACKhZANiAARcVFUVv3bIHS6BEfLt98rtps7XP1y26m2n +v5x/5ecbDH2p7AXBYerJERKFi7ZFE1DSrSAj+KK8otjdEG44ZA2Mtl5AHwDVrKde +RgtavVoreHhLN80jJOun8JnFXQjdNsA= +-----END EC PRIVATE KEY----- diff --git a/src/http/async/tls.zig/testdata/ec_secp521r1_private_key.pem b/src/http/async/tls.zig/testdata/ec_secp521r1_private_key.pem new file mode 100644 index 00000000..5b7f9321 --- /dev/null +++ b/src/http/async/tls.zig/testdata/ec_secp521r1_private_key.pem @@ -0,0 +1,7 @@ +-----BEGIN EC PRIVATE KEY----- +MIHcAgEBBEIB8C9axyQY6mgjjC6htLjc8hGylrDsh4BCv9669JaDj5vbxmCnTNlg +OuS6C9+uJNMbwm6CoIjB7RcgDTrxxX7oCyegBwYFK4EEACOhgYkDgYYABABAT5Q8 +aOj9U0iuJE5tXfKnYTgPuvD6keHZAGJ5veM9uR6jr3BhfGubD6bnlD+cIBQzYWo0 +y/BNMzCRJ55PDCNU5gGLw+vkwhJ1lGF5OS6l2oG5WN3fe6cYo+uJD7+PB3WYNIuX +Ls0oidsEM0Q4WLblQOEP6VLGf4qTcZyhoFWYfkjWiw== +-----END EC PRIVATE KEY----- diff --git a/src/http/async/tls.zig/testdata/google.com/client_random b/src/http/async/tls.zig/testdata/google.com/client_random new file mode 100644 index 00000000..e817c906 --- /dev/null +++ b/src/http/async/tls.zig/testdata/google.com/client_random @@ -0,0 +1 @@ +'”’ßqp0x­0)ì©–Ã~Ì+Œ`‡¬tY4•©D_ \ No newline at end of file diff --git a/src/http/async/tls.zig/testdata/google.com/server_hello b/src/http/async/tls.zig/testdata/google.com/server_hello new file mode 100644 index 0000000000000000000000000000000000000000..57a807650723010f00d7b371701d3004a55bc685 GIT binary patch literal 7158 zcmb_hc|4Tu*S}{W8M0;#V~vsN9%~CJm8FoSlBH$LW-?~R%rGg1A+ogED)ppN+EH4y zsvc5NPy0`0i54npQHi|QJ!Zy`=Y2o#=kxOi_xGH0o$K7^x~}iJ&P@k{fkMy_6pX>p zorU$GtTg;h4}3p%T1cZp<_Yf6$8lfgEC})U3Yvq$XWpA1O~lMwl}*ZPGQWH*Bvy%1 zcYsxO+aM+TMt8uL#aAao5N-ekp}-#qje>MA7=v&eWDo)wEQHkN!y+{=STt^OF$Rr7 zqt!!v6l0IH*N@k1b7kp%(>{%wRkM zl(qD|I2;CxhF{2w;|uV?G+sQDLgV9oeP@%jU=73uqowS%Fc?337M?WQ0XEiEwReOa zNzSkXX^N9Wm>9aiQ^n9e4Av!$hqc7RR8$B=hS)ig!ij-JC^Pw(Pzn%6gi?cmTp2Aw zp`h)NwV|O4*X0FvH-8mgC|76qJvi8zw|Zyd0gKtY$^PnvCEEwKtPuWM;<5nUY;HZT zGQ+>+O?7>X+S=s)EIQuX)%D|^oGlNdPVCdYn30@h@VeReqxO%WgPjWfQ4y~ea9u2n z{V!J@iyl?^<^i^da(93B%G*a>KSflRDPAW~SIII8?8Js{(8gj7B!jz<~h4``@l z0U2yhZk%A41~;^K$s#$x5{(er0P5xh0GODqR2!xO@PD=w$>lAAuA*(U^vW5D#HmuD&mp7 z4BV&^i1JW_9K+@EI4fxs0oc$C!~(q6pF(*%NkGHxBtqT6H?$+O$*csBaap$XrLXy? zN-_rILqh-349mv+snQsnCmv{|*0S}G7fS4q-SLN=ym0)WKMq&GjA5>(No{qdQIAw; zQaN#vEEokxz56748fI^ZbV#n8z7I36vA z#ui9DgEjmY4UQ2nAW6-@*^-Bd&VIN~266T=RB_=#dwcGASTdunl8q z0v_O1d8zVX7R1zwurY;80Ywu@j94_NwiiDm{735I)g;nrCQ_R~ zs9^m@iWCySb%!PzgDe2@U?9sjiwTYlUw-urLEWLTI9wV}mVRk&*rEwCC4UsHp5(eD z^*%JtKsJs@iQ;i&=NVZ_Js~X;xt|SAIur|AWJKXm0kcH{CLTo%nTv>E0Ry`tCMH-X7+g7vCl}kL+e$8*6dvGr*^P7XRyuh#b=E=w zJa^S{&uDP{V}U*cJLCe1fe#yn;N)75EcV@!E@23&|$B}aO!tz+!;b0Q;+cDw|#T;$?sny$yu2% z20Ax>=||(WC`IjC7cHCT-x1s4v8m|fK)t1DavR0Df%2Ed(TeejU|#R?SJD4D=;f%^ zz7pNL8}6JCgu1=>#iW#HRacHqeVWFX%O0+rMmy5);_Sn-92(YTIp+t={KBT)phwE4D^Md z2SY*d6|#q6tWsWMjf?-s>-!znIvV0Dvsd}LNgM|0}I?$GN4q}WRO z`bm+tDu>(h=TG&2WtjM^y^PrP%6Z#6*VCG%+0}cRtS{bn;{UyCP0{_UGdz z>C|uK7~d+i$X#_XulVQ_ZJQ-gH3PP6;(`aS_g!z(2~*x%)&6Glx&_m4p{EX86_(<& z)`hQmk-&I&QD?TNDJH7=$(?IsTLnvCqwn`pA1pswGH*hyvXjv1?*+MgHX1v$*O}u< zGiUPp6Pg1!cJJEPymZC>jt}NRW#;eaA>q+jovp@0DVNtV6XpwE3FM{g3SwBX&ll@ z2XQz8d=N?8OT+LKk^>BW9PN=_8g$ZPguwq=&;Eb4-rqJfy@|{*cKPAru_cL>aqkyR zWZD1rc67-=%)2o2)*U?)-fhF|&`KPuoYm{`c-dl$BdZjM8-!`-$1uFg(*jOjT;cgWy=;Hd+M8x3xcT`rk))9sM zcF%EVqP0|V^lfzuwx9o$-BUcV>ddQ+ftxR#s41$*l(yceWIz&$!&`4u`e0veiPeF% zrF||2C3`}P7k01-SArEmFAl!c$hG35cG~i-cw3pRRjh%n_r&#+gXojGK4ogf&VZw& z{y{4q0f)g>QVd#ggMVwqc_bK?*`d=d{`XFtM-n&Uj)R1pGDy||us=Lyh*SYZ6JTp;2xyef1jH8~d^V7AGx9Zn&lGTY36te$UP}1r z5ho2J0?Eu=dki0;;(B@Kp^$vT$`idu5Ab3uo%8#fW@=!3t39==v@MQ)3WzmpJbr_n z)X=?>Q=4k|v_5lr1}o#-VU_(wdn5C5^P9p?(DiZXg1`Dszfot-Sbf{mXSrv~tcoYw z{b-I*vVr@8J**@jC?$AS$u?7w$e?@woNz*brk3BeVCOF$N@WXsHwl9csZ%EEj=i+x zmSN&lyYMgd4?h$(YGba94Yn|;R>3B0xB5)(4i4(~$T6hbx{u#mHPBP^DrF~1F*ei5 z+TumarA~dj7Vq0`p{aqDg1>hs{#}tCUUtXv#)CT#Q#A6SyBR)9isZYOTbE36oTI6&V27?+`w-;Unur1`f5AQL3n+>M|HPQ ztB;?Y*J)Da;$Dv?(98Ko_6O`)_T`(JZca@I{!52Fjc$pX~_fC5xWKsk>h8Nxg>p9 z7vWxIEp;yzCz8wxW-{1#cDx;*Oi~vgb}UY7jF^&V&mtMZ`h!HRDQ7HEO8DBDG%>QZIQ}MGt&8Xt%&BB*m=Rfa%cTqEQS?Q6Wqd&|i-Y%)w zMQMDpW8Ej^Wt~w6zy2NOrRt~d@XEhJYk#dCYySJbv0IvUQ*A8G$Ly{A$Bnq6x#6F& z8C7?^KlxQxuilVVb)g|t|8d#o@-gLqk-A0Kt~L=5#<29qoqoE`%3v@{>cyE?aS@k# zGEOGeG@@MzQZ`uWYLpPt$@w;}%Ohncqh z!D{8DkEW|uWbMPeBptsa?BJ}K@wjQLnda-uy-)4?)G6A5yq8NK>*QXtiE`+gMt?sh z%c@E1NeFG##Fxu_w4r@a(#}JsXAW27mmfHiX>N#My)pTcNYE8JIZUl}oOiS3RXUfK7SvK3cyE@f~1s5AfV z^uwoI-dvwkvhFGCNsZ^XZAEV$nO>%E)6ZMPdT4V}^%_0ARp{ew{bBbL0|#F4DUnk{ zWcDqK>;q?~s#L8b%(@cP(6_JPOtr`z{C^2GNd-b#V<02^cYnS{waq_!33G9uvYjru zEO-IYI+>@gPvbtx_xmC6Rl=emW5ZwqYagSg8a;cPzv6O@(7QZa^(DuWKhV?T>t8R4 z$lZ3Nc*glt4qxA6fTv#IX1w(%BQd1k&iGK;x7PEe zTDux_pzXi>4mm7$KIED>p*XAVqtC79)*Dy1nKcwVc`~)HCuzRrn=BZ?s zAKfdyOvQRkym~#=tsM6#!lCNX(jEnO?1lob&!rn2M7s>`B2}, ", .{b}); + // if (i % 16 == 0) + // std.debug.print("\n", .{}); + // } + // std.debug.print("}};\n", .{}); + + std.debug.print("const {s} = \"", .{var_name}); + const charset = "0123456789abcdef"; + for (buf) |b| { + const x = charset[b >> 4]; + const y = charset[b & 15]; + std.debug.print("{c}{c} ", .{ x, y }); + } + std.debug.print("\"\n", .{}); +} + +const random_instance = std.Random{ .ptr = undefined, .fillFn = randomFillFn }; +var random_seed: u8 = 0; + +pub fn randomFillFn(_: *anyopaque, buf: []u8) void { + for (buf) |*v| { + v.* = random_seed; + random_seed +%= 1; + } +} + +pub fn random(seed: u8) std.Random { + random_seed = seed; + return random_instance; +} + +// Fill buf with 0,1,..ff,0,... +pub fn fill(buf: []u8) void { + fillFrom(buf, 0); +} + +pub fn fillFrom(buf: []u8, start: u8) void { + var i: u8 = start; + for (buf) |*v| { + v.* = i; + i +%= 1; + } +} + +pub const Stream = struct { + output: std.io.FixedBufferStream([]u8) = undefined, + input: std.io.FixedBufferStream([]const u8) = undefined, + + pub fn init(input: []const u8, output: []u8) Stream { + return .{ + .input = std.io.fixedBufferStream(input), + .output = std.io.fixedBufferStream(output), + }; + } + + pub const ReadError = error{}; + pub const WriteError = error{NoSpaceLeft}; + + pub fn write(self: *Stream, buf: []const u8) !usize { + return try self.output.writer().write(buf); + } + + pub fn writeAll(self: *Stream, buffer: []const u8) !void { + var n: usize = 0; + while (n < buffer.len) { + n += try self.write(buffer[n..]); + } + } + + pub fn read(self: *Stream, buffer: []u8) !usize { + return self.input.read(buffer); + } +}; + +// Copied from: https://github.com/clickingbuttons/zig/blob/f1cea91624fd2deae28bfb2414a4fd9c7e246883/lib/std/crypto/rsa.zig#L791 +/// For readable copy/pasting from hex viewers. +pub fn hexToBytes(comptime hex: []const u8) [removeNonHex(hex).len / 2]u8 { + @setEvalBranchQuota(1000 * 100); + const hex2 = comptime removeNonHex(hex); + comptime var res: [hex2.len / 2]u8 = undefined; + _ = comptime std.fmt.hexToBytes(&res, hex2) catch unreachable; + return res; +} + +fn removeNonHex(comptime hex: []const u8) []const u8 { + @setEvalBranchQuota(1000 * 100); + var res: [hex.len]u8 = undefined; + var i: usize = 0; + for (hex) |c| { + if (std.ascii.isHex(c)) { + res[i] = c; + i += 1; + } + } + return res[0..i]; +} + +test hexToBytes { + const hex = + \\e3b0c442 98fc1c14 9afbf4c8 996fb924 + \\27ae41e4 649b934c a495991b 7852b855 + ; + try std.testing.expectEqual( + [_]u8{ + 0xe3, 0xb0, 0xc4, 0x42, 0x98, 0xfc, 0x1c, 0x14, + 0x9a, 0xfb, 0xf4, 0xc8, 0x99, 0x6f, 0xb9, 0x24, + 0x27, 0xae, 0x41, 0xe4, 0x64, 0x9b, 0x93, 0x4c, + 0xa4, 0x95, 0x99, 0x1b, 0x78, 0x52, 0xb8, 0x55, + }, + hexToBytes(hex), + ); +} diff --git a/src/http/async/tls.zig/transcript.zig b/src/http/async/tls.zig/transcript.zig new file mode 100644 index 00000000..59c94986 --- /dev/null +++ b/src/http/async/tls.zig/transcript.zig @@ -0,0 +1,297 @@ +const std = @import("std"); +const crypto = std.crypto; +const tls = crypto.tls; +const hkdfExpandLabel = tls.hkdfExpandLabel; + +const Sha256 = crypto.hash.sha2.Sha256; +const Sha384 = crypto.hash.sha2.Sha384; +const Sha512 = crypto.hash.sha2.Sha512; + +const HashTag = @import("cipher.zig").CipherSuite.HashTag; + +// Transcript holds hash of all handshake message. +// +// Until the server hello is parsed we don't know which hash (sha256, sha384, +// sha512) will be used so we update all of them. Handshake process will set +// `selected` field once cipher suite is known. Other function will use that +// selected hash. We continue to calculate all hashes because client certificate +// message could use different hash than the other part of the handshake. +// Handshake hash is dictated by the server selected cipher. Client certificate +// hash is dictated by the private key used. +// +// Most of the functions are inlined because they are returning pointers. +// +pub const Transcript = struct { + sha256: Type(.sha256) = .{ .hash = Sha256.init(.{}) }, + sha384: Type(.sha384) = .{ .hash = Sha384.init(.{}) }, + sha512: Type(.sha512) = .{ .hash = Sha512.init(.{}) }, + + tag: HashTag = .sha256, + + pub const max_mac_length = Type(.sha512).mac_length; + + // Transcript Type from hash tag + fn Type(h: HashTag) type { + return switch (h) { + .sha256 => TranscriptT(Sha256), + .sha384 => TranscriptT(Sha384), + .sha512 => TranscriptT(Sha512), + }; + } + + /// Set hash to use in all following function calls. + pub fn use(t: *Transcript, tag: HashTag) void { + t.tag = tag; + } + + pub fn update(t: *Transcript, buf: []const u8) void { + t.sha256.hash.update(buf); + t.sha384.hash.update(buf); + t.sha512.hash.update(buf); + } + + // tls 1.2 handshake specific + + pub inline fn masterSecret( + t: *Transcript, + pre_master_secret: []const u8, + client_random: [32]u8, + server_random: [32]u8, + ) []const u8 { + return switch (t.tag) { + inline else => |h| &@field(t, @tagName(h)).masterSecret( + pre_master_secret, + client_random, + server_random, + ), + }; + } + + pub inline fn keyMaterial( + t: *Transcript, + master_secret: []const u8, + client_random: [32]u8, + server_random: [32]u8, + ) []const u8 { + return switch (t.tag) { + inline else => |h| &@field(t, @tagName(h)).keyExpansion( + master_secret, + client_random, + server_random, + ), + }; + } + + pub fn clientFinishedTls12(t: *Transcript, master_secret: []const u8) [12]u8 { + return switch (t.tag) { + inline else => |h| @field(t, @tagName(h)).clientFinishedTls12(master_secret), + }; + } + + pub fn serverFinishedTls12(t: *Transcript, master_secret: []const u8) [12]u8 { + return switch (t.tag) { + inline else => |h| @field(t, @tagName(h)).serverFinishedTls12(master_secret), + }; + } + + // tls 1.3 handshake specific + + pub inline fn serverCertificateVerify(t: *Transcript) []const u8 { + return switch (t.tag) { + inline else => |h| &@field(t, @tagName(h)).serverCertificateVerify(), + }; + } + + pub inline fn clientCertificateVerify(t: *Transcript) []const u8 { + return switch (t.tag) { + inline else => |h| &@field(t, @tagName(h)).clientCertificateVerify(), + }; + } + + pub fn serverFinishedTls13(t: *Transcript, buf: []u8) []const u8 { + return switch (t.tag) { + inline else => |h| @field(t, @tagName(h)).serverFinishedTls13(buf), + }; + } + + pub fn clientFinishedTls13(t: *Transcript, buf: []u8) []const u8 { + return switch (t.tag) { + inline else => |h| @field(t, @tagName(h)).clientFinishedTls13(buf), + }; + } + + pub const Secret = struct { + client: []const u8, + server: []const u8, + }; + + pub inline fn handshakeSecret(t: *Transcript, shared_key: []const u8) Secret { + return switch (t.tag) { + inline else => |h| @field(t, @tagName(h)).handshakeSecret(shared_key), + }; + } + + pub inline fn applicationSecret(t: *Transcript) Secret { + return switch (t.tag) { + inline else => |h| @field(t, @tagName(h)).applicationSecret(), + }; + } + + // other + + pub fn Hkdf(h: HashTag) type { + return Type(h).Hkdf; + } + + /// Copy of the current hash value + pub inline fn hash(t: *Transcript, comptime Hash: type) Hash { + return switch (Hash) { + Sha256 => t.sha256.hash, + Sha384 => t.sha384.hash, + Sha512 => t.sha512.hash, + else => @compileError("unimplemented"), + }; + } +}; + +fn TranscriptT(comptime Hash: type) type { + return struct { + const Hmac = crypto.auth.hmac.Hmac(Hash); + const Hkdf = crypto.kdf.hkdf.Hkdf(Hmac); + const mac_length = Hmac.mac_length; + + hash: Hash, + handshake_secret: [Hmac.mac_length]u8 = undefined, + server_finished_key: [Hmac.key_length]u8 = undefined, + client_finished_key: [Hmac.key_length]u8 = undefined, + + const Self = @This(); + + fn init(transcript: Hash) Self { + return .{ .transcript = transcript }; + } + + fn serverCertificateVerify(c: *Self) [64 + 34 + Hash.digest_length]u8 { + return ([1]u8{0x20} ** 64) ++ + "TLS 1.3, server CertificateVerify\x00".* ++ + c.hash.peek(); + } + + // ref: https://www.rfc-editor.org/rfc/rfc8446#section-4.4.3 + fn clientCertificateVerify(c: *Self) [64 + 34 + Hash.digest_length]u8 { + return ([1]u8{0x20} ** 64) ++ + "TLS 1.3, client CertificateVerify\x00".* ++ + c.hash.peek(); + } + + fn masterSecret( + _: *Self, + pre_master_secret: []const u8, + client_random: [32]u8, + server_random: [32]u8, + ) [mac_length * 2]u8 { + const seed = "master secret" ++ client_random ++ server_random; + + var a1: [mac_length]u8 = undefined; + var a2: [mac_length]u8 = undefined; + Hmac.create(&a1, seed, pre_master_secret); + Hmac.create(&a2, &a1, pre_master_secret); + + var p1: [mac_length]u8 = undefined; + var p2: [mac_length]u8 = undefined; + Hmac.create(&p1, a1 ++ seed, pre_master_secret); + Hmac.create(&p2, a2 ++ seed, pre_master_secret); + + return p1 ++ p2; + } + + fn keyExpansion( + _: *Self, + master_secret: []const u8, + client_random: [32]u8, + server_random: [32]u8, + ) [mac_length * 4]u8 { + const seed = "key expansion" ++ server_random ++ client_random; + + const a0 = seed; + var a1: [mac_length]u8 = undefined; + var a2: [mac_length]u8 = undefined; + var a3: [mac_length]u8 = undefined; + var a4: [mac_length]u8 = undefined; + Hmac.create(&a1, a0, master_secret); + Hmac.create(&a2, &a1, master_secret); + Hmac.create(&a3, &a2, master_secret); + Hmac.create(&a4, &a3, master_secret); + + var key_material: [mac_length * 4]u8 = undefined; + Hmac.create(key_material[0..mac_length], a1 ++ seed, master_secret); + Hmac.create(key_material[mac_length .. mac_length * 2], a2 ++ seed, master_secret); + Hmac.create(key_material[mac_length * 2 .. mac_length * 3], a3 ++ seed, master_secret); + Hmac.create(key_material[mac_length * 3 ..], a4 ++ seed, master_secret); + return key_material; + } + + fn clientFinishedTls12(self: *Self, master_secret: []const u8) [12]u8 { + const seed = "client finished" ++ self.hash.peek(); + var a1: [mac_length]u8 = undefined; + var p1: [mac_length]u8 = undefined; + Hmac.create(&a1, seed, master_secret); + Hmac.create(&p1, a1 ++ seed, master_secret); + return p1[0..12].*; + } + + fn serverFinishedTls12(self: *Self, master_secret: []const u8) [12]u8 { + const seed = "server finished" ++ self.hash.peek(); + var a1: [mac_length]u8 = undefined; + var p1: [mac_length]u8 = undefined; + Hmac.create(&a1, seed, master_secret); + Hmac.create(&p1, a1 ++ seed, master_secret); + return p1[0..12].*; + } + + // tls 1.3 + + inline fn handshakeSecret(self: *Self, shared_key: []const u8) Transcript.Secret { + const hello_hash = self.hash.peek(); + + const zeroes = [1]u8{0} ** Hash.digest_length; + const early_secret = Hkdf.extract(&[1]u8{0}, &zeroes); + const empty_hash = tls.emptyHash(Hash); + const hs_derived_secret = hkdfExpandLabel(Hkdf, early_secret, "derived", &empty_hash, Hash.digest_length); + + self.handshake_secret = Hkdf.extract(&hs_derived_secret, shared_key); + const client_secret = hkdfExpandLabel(Hkdf, self.handshake_secret, "c hs traffic", &hello_hash, Hash.digest_length); + const server_secret = hkdfExpandLabel(Hkdf, self.handshake_secret, "s hs traffic", &hello_hash, Hash.digest_length); + + self.server_finished_key = hkdfExpandLabel(Hkdf, server_secret, "finished", "", Hmac.key_length); + self.client_finished_key = hkdfExpandLabel(Hkdf, client_secret, "finished", "", Hmac.key_length); + + return .{ .client = &client_secret, .server = &server_secret }; + } + + inline fn applicationSecret(self: *Self) Transcript.Secret { + const handshake_hash = self.hash.peek(); + + const empty_hash = tls.emptyHash(Hash); + const zeroes = [1]u8{0} ** Hash.digest_length; + const ap_derived_secret = hkdfExpandLabel(Hkdf, self.handshake_secret, "derived", &empty_hash, Hash.digest_length); + const master_secret = Hkdf.extract(&ap_derived_secret, &zeroes); + + const client_secret = hkdfExpandLabel(Hkdf, master_secret, "c ap traffic", &handshake_hash, Hash.digest_length); + const server_secret = hkdfExpandLabel(Hkdf, master_secret, "s ap traffic", &handshake_hash, Hash.digest_length); + + return .{ .client = &client_secret, .server = &server_secret }; + } + + fn serverFinishedTls13(self: *Self, buf: []u8) []const u8 { + Hmac.create(buf[0..mac_length], &self.hash.peek(), &self.server_finished_key); + return buf[0..mac_length]; + } + + // client finished message with header + fn clientFinishedTls13(self: *Self, buf: []u8) []const u8 { + Hmac.create(buf[0..mac_length], &self.hash.peek(), &self.client_finished_key); + return buf[0..mac_length]; + } + }; +} From 6809bb53931ba6c1d34d9b6f88613c08d235c65e Mon Sep 17 00:00:00 2001 From: Pierre Tachoire Date: Fri, 15 Nov 2024 14:37:19 +0100 Subject: [PATCH 02/11] async: adapt async cli --- src/browser/browser.zig | 4 +- src/http/async/io.zig | 149 +++++++------- src/http/async/std/http/Client.zig | 103 ++++------ src/run_tests.zig | 4 +- src/user_context.zig | 2 +- src/wpt/run.zig | 4 +- src/xhr/xhr.zig | 311 ++++++++++++++++------------- 7 files changed, 282 insertions(+), 295 deletions(-) diff --git a/src/browser/browser.zig b/src/browser/browser.zig index 9308ed82..e6f646ef 100644 --- a/src/browser/browser.zig +++ b/src/browser/browser.zig @@ -40,7 +40,7 @@ const storage = @import("../storage/storage.zig"); const FetchResult = @import("../http/Client.zig").Client.FetchResult; const UserContext = @import("../user_context.zig").UserContext; -const HttpClient = @import("../async/Client.zig"); +const HttpClient = @import("../http/async/main.zig").Client; const log = std.log.scoped(.browser); @@ -116,7 +116,7 @@ pub const Session = struct { }; Env.init(&self.env, self.arena.allocator(), loop, null); - self.httpClient = .{ .allocator = alloc, .loop = loop }; + self.httpClient = .{ .allocator = alloc }; try self.env.load(&self.jstypes); } diff --git a/src/http/async/io.zig b/src/http/async/io.zig index a416c5fc..7c2aad5b 100644 --- a/src/http/async/io.zig +++ b/src/http/async/io.zig @@ -1,6 +1,8 @@ const std = @import("std"); -pub const IO = @import("jsruntime").IO; +const Ctx = @import("std/http/Client.zig").Ctx; +const Loop = @import("jsruntime").Loop; +const NetworkImpl = Loop.Network(SingleThreaded); pub const Blocking = struct { pub fn connect( @@ -57,92 +59,75 @@ pub const Blocking = struct { } }; -pub fn SingleThreaded(comptime CtxT: type) type { - return struct { - io: *IO, - completion: IO.Completion, - ctx: *CtxT, - cbk: CbkT, +pub const SingleThreaded = struct { + impl: NetworkImpl, + cbk: Cbk, + ctx: *Ctx, - count: u32 = 0, + const Self = @This(); + const Cbk = *const fn (ctx: *Ctx, res: anyerror!void) anyerror!void; - const CbkT = *const fn (ctx: *CtxT, res: anyerror!void) anyerror!void; + pub fn init(loop: *Loop) Self { + return .{ + .impl = NetworkImpl.init(loop), + .cbk = undefined, + .ctx = undefined, + }; + } - const Self = @This(); + pub fn connect( + self: *Self, + comptime _: type, + ctx: *Ctx, + comptime cbk: Cbk, + socket: std.posix.socket_t, + address: std.net.Address, + ) void { + self.cbk = cbk; + self.ctx = ctx; + self.impl.connect(self, socket, address); + } - pub fn init(io: *IO) Self { - return .{ - .io = io, - .completion = undefined, - .ctx = undefined, - .cbk = undefined, - }; - } + pub fn onConnect(self: *Self, err: ?anyerror) void { + if (err) |e| return self.ctx.setErr(e); + self.cbk(self.ctx, {}) catch |e| self.ctx.setErr(e); + } - pub fn connect( - self: *Self, - comptime _: type, - ctx: *CtxT, - comptime cbk: CbkT, - socket: std.posix.socket_t, - address: std.net.Address, - ) void { - self.ctx = ctx; - self.cbk = cbk; - self.count += 1; - self.io.connect(*Self, self, Self.connectCbk, &self.completion, socket, address); - } + pub fn send( + self: *Self, + comptime _: type, + ctx: *Ctx, + comptime cbk: Cbk, + socket: std.posix.socket_t, + buf: []const u8, + ) void { + self.ctx = ctx; + self.cbk = cbk; + self.impl.send(self, socket, buf); + } - fn connectCbk(self: *Self, _: *IO.Completion, result: IO.ConnectError!void) void { - defer self.count -= 1; - _ = result catch |e| return self.ctx.setErr(e); - self.cbk(self.ctx, {}) catch |e| self.ctx.setErr(e); - } + pub fn onSend(self: *Self, ln: usize, err: ?anyerror) void { + if (err) |e| return self.ctx.setErr(e); + self.ctx.setLen(ln); + self.cbk(self.ctx, {}) catch |e| self.ctx.setErr(e); + } - pub fn send( - self: *Self, - comptime _: type, - ctx: *CtxT, - comptime cbk: CbkT, - socket: std.posix.socket_t, - buf: []const u8, - ) void { - self.ctx = ctx; - self.cbk = cbk; - self.count += 1; - self.io.send(*Self, self, Self.sendCbk, &self.completion, socket, buf); - } + pub fn recv( + self: *Self, + comptime _: type, + ctx: *Ctx, + comptime cbk: Cbk, + socket: std.posix.socket_t, + buf: []u8, + ) void { + self.ctx = ctx; + self.cbk = cbk; + self.impl.receive(self, socket, buf); + } - fn sendCbk(self: *Self, _: *IO.Completion, result: IO.SendError!usize) void { - defer self.count -= 1; - const ln = result catch |e| return self.ctx.setErr(e); - self.ctx.setLen(ln); - self.cbk(self.ctx, {}) catch |e| self.ctx.setErr(e); - } - - pub fn recv( - self: *Self, - comptime _: type, - ctx: *CtxT, - comptime cbk: CbkT, - socket: std.posix.socket_t, - buf: []u8, - ) void { - self.ctx = ctx; - self.cbk = cbk; - self.count += 1; - self.io.recv(*Self, self, Self.receiveCbk, &self.completion, socket, buf); - } - - fn receiveCbk(self: *Self, _: *IO.Completion, result: IO.RecvError!usize) void { - defer self.count -= 1; - const ln = result catch |e| return self.ctx.setErr(e); - self.ctx.setLen(ln); - self.cbk(self.ctx, {}) catch |e| self.ctx.setErr(e); - } - - pub fn isDone(self: *Self) bool { - return self.count == 0; - } - }; -} + pub fn onReceive(self: *Self, ln: usize, err: ?anyerror) void { + if (err) |e| return self.ctx.setErr(e); + self.ctx.setLen(ln); + self.cbk(self.ctx, {}) catch |e| self.ctx.setErr(e); + } +}; diff --git a/src/http/async/std/http/Client.zig b/src/http/async/std/http/Client.zig index f0c37b20..2c866e6f 100644 --- a/src/http/async/std/http/Client.zig +++ b/src/http/async/std/http/Client.zig @@ -22,7 +22,7 @@ const tls23 = @import("../../tls.zig/main.zig"); const VecPut = @import("../../tls.zig/connection.zig").VecPut; const GenericStack = @import("../../stack.zig").Stack; const async_io = @import("../../io.zig"); -pub const Loop = async_io.SingleThreaded(Ctx); +pub const Loop = async_io.SingleThreaded; const cipher = @import("../../tls.zig/cipher.zig"); @@ -1343,8 +1343,29 @@ pub const Request = struct { } } + fn onWriteAll(ctx: *Ctx, res: anyerror!void) !void { + res catch |err| return ctx.pop(err); + switch (ctx.req.transfer_encoding) { + .chunked => unreachable, + .none => unreachable, + .content_length => |*len| { + len.* = 0; + }, + } + try ctx.pop({}); + } + pub fn async_writeAll(req: *Request, buf: []const u8, ctx: *Ctx, comptime cbk: Cbk) !void { - try req.connection.?.async_writeAllDirect(buf, ctx, cbk); + switch (req.transfer_encoding) { + .chunked => return error.ChunkedNotImplemented, + .none => return error.NotWriteable, + .content_length => |len| { + try ctx.push(cbk); + if (len < buf.len) return error.MessageTooLong; + + try req.connection.?.async_writeAllDirect(buf, ctx, onWriteAll); + }, + } } pub const FinishError = WriteError || error{MessageNotCompleted}; @@ -1757,6 +1778,7 @@ pub fn async_connectTcp( .port = port, .protocol = protocol, })) |conn| { + ctx.data.conn = conn; ctx.req.connection = conn; return ctx.pop({}); } @@ -1845,7 +1867,6 @@ pub fn connectTunnel( .connection = conn, .server_header_buffer = &buffer, }) catch |err| { - std.log.debug("err {}", .{err}); break :tunnel err; }; defer req.deinit(); @@ -2426,14 +2447,19 @@ pub const Ctx = struct { pub fn pop(self: *Ctx, res: anyerror!void) !void { if (self.stack) |stack| { - const func = stack.pop(self.alloc(), null); - const ret = @call(.auto, func, .{ self, res }); - if (stack.next == null) { - self.stack = null; - self.alloc().destroy(stack); + const allocator = self.alloc(); + const func = stack.pop(allocator, null); + + defer { + if (stack.next == null) { + allocator.destroy(stack); + self.stack = null; + } } - return ret; + + return @call(.auto, func, .{ self, res }); } + unreachable; } pub fn deinit(self: Ctx) void { @@ -2484,62 +2510,3 @@ fn setRequestConnection(ctx: *Ctx, res: anyerror!void) anyerror!void { ctx.req.connection = ctx.data.conn; return ctx.pop({}); } - -fn onRequestWait(ctx: *Ctx, res: anyerror!void) !void { - res catch |e| { - std.debug.print("error: {any}\n", .{e}); - return e; - }; - std.log.debug("REQUEST WAITED", .{}); - std.log.debug("Status code: {any}", .{ctx.req.response.status}); - const body = try ctx.req.reader().readAllAlloc(ctx.alloc(), 1024 * 1024); - defer ctx.alloc().free(body); - std.log.debug("Body: \n{s}", .{body}); -} - -fn onRequestFinish(ctx: *Ctx, res: anyerror!void) !void { - res catch |err| return err; - std.log.debug("REQUEST FINISHED", .{}); - return ctx.req.async_wait(ctx, onRequestWait); -} - -fn onRequestSend(ctx: *Ctx, res: anyerror!void) !void { - res catch |err| return err; - std.log.debug("REQUEST SENT", .{}); - return ctx.req.async_finish(ctx, onRequestFinish); -} - -pub fn onRequestConnect(ctx: *Ctx, res: anyerror!void) anyerror!void { - res catch |err| return err; - std.log.debug("REQUEST CONNECTED", .{}); - return ctx.req.async_send(ctx, onRequestSend); -} - -test { - const alloc = std.testing.allocator; - - var loop = Loop{}; - - var client = Client{ .allocator = alloc }; - defer client.deinit(); - - var req = Request{ - .client = &client, - }; - defer req.deinit(); - - var ctx = try Ctx.init(&loop, &req); - defer ctx.deinit(); - - var server_header_buffer: [2048]u8 = undefined; - - const url = "http://www.example.com"; - // const url = "http://127.0.0.1:8000/zig"; - try client.async_open( - .GET, - try std.Uri.parse(url), - .{ .server_header_buffer = &server_header_buffer }, - &ctx, - onRequestConnect, - ); -} diff --git a/src/run_tests.zig b/src/run_tests.zig index 8a5be4a5..b50dbce8 100644 --- a/src/run_tests.zig +++ b/src/run_tests.zig @@ -30,7 +30,7 @@ const xhr = @import("xhr/xhr.zig"); const storage = @import("storage/storage.zig"); const url = @import("url/url.zig"); const urlquery = @import("url/query.zig"); -const Client = @import("async/Client.zig"); +const Client = @import("http/async/main.zig").Client; const documentTestExecFn = @import("dom/document.zig").testExecFn; const HTMLDocumentTestExecFn = @import("html/document.zig").testExecFn; @@ -86,7 +86,7 @@ fn testExecFn( std.debug.print("documentHTMLClose error: {s}\n", .{@errorName(err)}); }; - var cli = Client{ .allocator = alloc, .loop = js_env.nat_ctx.loop }; + var cli = Client{ .allocator = alloc }; defer cli.deinit(); try js_env.setUserContext(.{ diff --git a/src/user_context.zig b/src/user_context.zig index 23d85955..644893c8 100644 --- a/src/user_context.zig +++ b/src/user_context.zig @@ -1,6 +1,6 @@ const std = @import("std"); const parser = @import("netsurf"); -const Client = @import("async/Client.zig"); +const Client = @import("http/async/main.zig").Client; pub const UserContext = struct { document: *parser.DocumentHTML, diff --git a/src/wpt/run.zig b/src/wpt/run.zig index e2b58470..ec1d3397 100644 --- a/src/wpt/run.zig +++ b/src/wpt/run.zig @@ -31,7 +31,7 @@ const storage = @import("../storage/storage.zig"); const Types = @import("../main_wpt.zig").Types; const UserContext = @import("../main_wpt.zig").UserContext; -const Client = @import("../async/Client.zig"); +const Client = @import("../http/async/main.zig").Client; // runWPT parses the given HTML file, starts a js env and run the first script // tags containing javascript sources. @@ -53,7 +53,7 @@ pub fn run(arena: *std.heap.ArenaAllocator, comptime dir: []const u8, f: []const var loop = try Loop.init(alloc); defer loop.deinit(); - var cli = Client{ .allocator = alloc, .loop = &loop }; + var cli = Client{ .allocator = alloc }; defer cli.deinit(); var js_env: Env = undefined; diff --git a/src/xhr/xhr.zig b/src/xhr/xhr.zig index 6d6269e9..660e449a 100644 --- a/src/xhr/xhr.zig +++ b/src/xhr/xhr.zig @@ -32,8 +32,7 @@ const XMLHttpRequestEventTarget = @import("event_target.zig").XMLHttpRequestEven const Mime = @import("../browser/mime.zig"); const Loop = jsruntime.Loop; -const YieldImpl = Loop.Yield(XMLHttpRequest); -const Client = @import("../async/Client.zig"); +const Client = @import("../http/async/main.zig").Client; const parser = @import("netsurf"); @@ -98,10 +97,11 @@ pub const XMLHttpRequest = struct { proto: XMLHttpRequestEventTarget = XMLHttpRequestEventTarget{}, alloc: std.mem.Allocator, cli: *Client, - impl: YieldImpl, + loop: Client.Loop, priv_state: PrivState = .new, req: ?Client.Request = null, + ctx: ?Client.Ctx = null, method: std.http.Method, state: u16, @@ -135,7 +135,13 @@ pub const XMLHttpRequest = struct { response_header_buffer: [1024 * 16]u8 = undefined, response_status: u10 = 0, - response_override_mime_type: ?[]const u8 = null, + + // TODO uncomment this field causes casting issue with + // XMLHttpRequestEventTarget. I think it's dueto an alignement issue, but + // not sure. see + // https://lightpanda.slack.com/archives/C05TRU6RBM1/p1707819010681019 + // response_override_mime_type: ?[]const u8 = null, + response_mime: Mime = undefined, response_obj: ?ResponseObj = null, send_flag: bool = false, @@ -288,7 +294,7 @@ pub const XMLHttpRequest = struct { .alloc = alloc, .headers = Headers.init(alloc), .response_headers = Headers.init(alloc), - .impl = YieldImpl.init(loop), + .loop = Client.Loop.init(loop), .method = undefined, .url = null, .uri = undefined, @@ -320,10 +326,11 @@ pub const XMLHttpRequest = struct { self.priv_state = .new; - if (self.req) |*r| { - r.deinit(); - self.req = null; - } + if (self.ctx) |*c| c.deinit(); + self.ctx = null; + + if (self.req) |*r| r.deinit(); + self.req = null; } pub fn deinit(self: *XMLHttpRequest, alloc: std.mem.Allocator) void { @@ -494,138 +501,160 @@ pub const XMLHttpRequest = struct { log.debug("{any} {any}", .{ self.method, self.uri }); self.send_flag = true; - self.impl.yield(self); - } - // onYield is a callback called between each request's steps. - // Between each step, the code is blocking. - // Yielding allows pseudo-async and gives a chance to other async process - // to be called. - pub fn onYield(self: *XMLHttpRequest, err: ?anyerror) void { - if (err) |e| return self.onErr(e); + self.priv_state = .open; - switch (self.priv_state) { - .new => { - self.priv_state = .open; - self.req = self.cli.open(self.method, self.uri, .{ - .server_header_buffer = &self.response_header_buffer, - .extra_headers = self.headers.all(), - }) catch |e| return self.onErr(e); - }, - .open => { - // prepare payload transfert. - if (self.payload) |v| self.req.?.transfer_encoding = .{ .content_length = v.len }; - - self.priv_state = .send; - self.req.?.send() catch |e| return self.onErr(e); - }, - .send => { - if (self.payload) |payload| { - self.priv_state = .write; - self.req.?.writeAll(payload) catch |e| return self.onErr(e); - } else { - self.priv_state = .finish; - self.req.?.finish() catch |e| return self.onErr(e); - } - }, - .write => { - self.priv_state = .finish; - self.req.?.finish() catch |e| return self.onErr(e); - }, - .finish => { - self.priv_state = .wait; - self.req.?.wait() catch |e| return self.onErr(e); - }, - .wait => { - log.info("{any} {any} {d}", .{ self.method, self.uri, self.req.?.response.status }); - - self.priv_state = .done; - var it = self.req.?.response.iterateHeaders(); - self.response_headers.load(&it) catch |e| return self.onErr(e); - - // extract a mime type from headers. - const ct = self.response_headers.getFirstValue("Content-Type") orelse "text/xml"; - self.response_mime = Mime.parse(ct) catch |e| return self.onErr(e); - - // TODO handle override mime type - - self.state = HEADERS_RECEIVED; - self.dispatchEvt("readystatechange"); - - self.response_status = @intFromEnum(self.req.?.response.status); - - var buf: std.ArrayListUnmanaged(u8) = .{}; - - // TODO set correct length - const total = 0; - var loaded: u64 = 0; - - // dispatch a progress event loadstart. - self.dispatchProgressEvent("loadstart", .{ .loaded = loaded, .total = total }); - - const reader = self.req.?.reader(); - var buffer: [1024]u8 = undefined; - var ln = buffer.len; - var prev_dispatch: ?std.time.Instant = null; - while (ln > 0) { - ln = reader.read(&buffer) catch |e| { - buf.deinit(self.alloc); - return self.onErr(e); - }; - buf.appendSlice(self.alloc, buffer[0..ln]) catch |e| { - buf.deinit(self.alloc); - return self.onErr(e); - }; - loaded = loaded + ln; - - // Dispatch only if 50ms have passed. - const now = std.time.Instant.now() catch |e| { - buf.deinit(self.alloc); - return self.onErr(e); - }; - if (prev_dispatch != null and now.since(prev_dispatch.?) < min_delay) continue; - defer prev_dispatch = now; - - self.state = LOADING; - self.dispatchEvt("readystatechange"); - - // dispatch a progress event progress. - self.dispatchProgressEvent("progress", .{ - .loaded = loaded, - .total = total, - }); - } - self.response_bytes = buf.items; - self.send_flag = false; - - self.state = DONE; - self.dispatchEvt("readystatechange"); - - // dispatch a progress event load. - self.dispatchProgressEvent("load", .{ .loaded = loaded, .total = total }); - // dispatch a progress event loadend. - self.dispatchProgressEvent("loadend", .{ .loaded = loaded, .total = total }); - }, - .done => { - if (self.req) |*r| { - r.deinit(); - self.req = null; - } - - // finalize fetch process. - return; - }, + self.req = try self.cli.create(self.method, self.uri, .{ + .server_header_buffer = &self.response_header_buffer, + .extra_headers = self.headers.all(), + }); + errdefer { + self.req.?.deinit(); + self.req = null; } - self.impl.yield(self); + self.ctx = try Client.Ctx.init(&self.loop, &self.req.?); + errdefer { + self.ctx.?.deinit(); + self.ctx = null; + } + self.ctx.?.userData = self; + + try self.cli.async_open( + self.method, + self.uri, + .{ .server_header_buffer = &self.response_header_buffer }, + &self.ctx.?, + onRequestConnect, + ); + } + + fn onRequestWait(ctx: *Client.Ctx, res: anyerror!void) !void { + var self = selfCtx(ctx); + res catch |err| return self.onErr(err); + + log.info("{any} {any} {d}", .{ self.method, self.uri, self.req.?.response.status }); + + self.priv_state = .done; + var it = self.req.?.response.iterateHeaders(); + self.response_headers.load(&it) catch |e| return self.onErr(e); + + // extract a mime type from headers. + const ct = self.response_headers.getFirstValue("Content-Type") orelse "text/xml"; + self.response_mime = Mime.parse(ct) catch |e| return self.onErr(e); + + // TODO handle override mime type + + self.state = HEADERS_RECEIVED; + self.dispatchEvt("readystatechange"); + + self.response_status = @intFromEnum(self.req.?.response.status); + + var buf: std.ArrayListUnmanaged(u8) = .{}; + + // TODO set correct length + const total = 0; + var loaded: u64 = 0; + + // dispatch a progress event loadstart. + self.dispatchProgressEvent("loadstart", .{ .loaded = loaded, .total = total }); + + // TODO read async + const reader = self.req.?.reader(); + var buffer: [1024]u8 = undefined; + var ln = buffer.len; + var prev_dispatch: ?std.time.Instant = null; + while (ln > 0) { + ln = reader.read(&buffer) catch |e| { + buf.deinit(self.alloc); + return self.onErr(e); + }; + buf.appendSlice(self.alloc, buffer[0..ln]) catch |e| { + buf.deinit(self.alloc); + return self.onErr(e); + }; + loaded = loaded + ln; + + // Dispatch only if 50ms have passed. + const now = std.time.Instant.now() catch |e| { + buf.deinit(self.alloc); + return self.onErr(e); + }; + if (prev_dispatch != null and now.since(prev_dispatch.?) < min_delay) continue; + defer prev_dispatch = now; + + self.state = LOADING; + self.dispatchEvt("readystatechange"); + + // dispatch a progress event progress. + self.dispatchProgressEvent("progress", .{ + .loaded = loaded, + .total = total, + }); + } + self.response_bytes = buf.items; + self.send_flag = false; + + self.state = DONE; + self.dispatchEvt("readystatechange"); + + // dispatch a progress event load. + self.dispatchProgressEvent("load", .{ .loaded = loaded, .total = total }); + // dispatch a progress event loadend. + self.dispatchProgressEvent("loadend", .{ .loaded = loaded, .total = total }); + + if (self.ctx) |*c| c.deinit(); + self.ctx = null; + + if (self.req) |*r| r.deinit(); + self.req = null; + } + + fn onRequestFinish(ctx: *Client.Ctx, res: anyerror!void) !void { + var self = selfCtx(ctx); + res catch |err| return self.onErr(err); + + self.priv_state = .wait; + return ctx.req.async_wait(ctx, onRequestWait) catch |e| return self.onErr(e); + } + + fn onRequestSend(ctx: *Client.Ctx, res: anyerror!void) !void { + var self = selfCtx(ctx); + res catch |err| return self.onErr(err); + + if (self.payload) |payload| { + self.priv_state = .write; + return ctx.req.async_writeAll(payload, ctx, onRequestWrite) catch |e| return self.onErr(e); + } + + self.priv_state = .finish; + return ctx.req.async_finish(ctx, onRequestFinish) catch |e| return self.onErr(e); + } + + fn onRequestWrite(ctx: *Client.Ctx, res: anyerror!void) !void { + var self = selfCtx(ctx); + res catch |err| return self.onErr(err); + self.priv_state = .finish; + return ctx.req.async_finish(ctx, onRequestFinish) catch |e| return self.onErr(e); + } + + fn onRequestConnect(ctx: *Client.Ctx, res: anyerror!void) anyerror!void { + var self = selfCtx(ctx); + res catch |err| return self.onErr(err); + + // prepare payload transfert. + if (self.payload) |v| self.req.?.transfer_encoding = .{ .content_length = v.len }; + + self.priv_state = .send; + return ctx.req.async_send(ctx, onRequestSend) catch |err| return self.onErr(err); + } + + fn selfCtx(ctx: *Client.Ctx) *XMLHttpRequest { + return @ptrCast(@alignCast(ctx.userData)); } fn onErr(self: *XMLHttpRequest, err: anyerror) void { self.priv_state = .done; - if (self.req) |*r| { - r.deinit(); - self.req = null; - } self.err = err; self.state = DONE; @@ -635,6 +664,12 @@ pub const XMLHttpRequest = struct { self.dispatchProgressEvent("loadend", .{}); log.debug("{any} {any} {any}", .{ self.method, self.uri, self.err }); + + if (self.ctx) |*c| c.deinit(); + self.ctx = null; + + if (self.req) |*r| r.deinit(); + self.req = null; } pub fn _abort(self: *XMLHttpRequest) void { @@ -882,7 +917,7 @@ pub fn testExecFn( // .{ .src = "req.onload", .ex = "function cbk(event) { nb ++; evt = event; }" }, //.{ .src = "req.onload = cbk", .ex = "function cbk(event) { nb ++; evt = event; }" }, - .{ .src = "req.open('GET', 'http://httpbin.io/html')", .ex = "undefined" }, + .{ .src = "req.open('GET', 'https://httpbin.io/html')", .ex = "undefined" }, .{ .src = "req.setRequestHeader('User-Agent', 'lightpanda/1.0')", .ex = "undefined" }, // ensure open resets values @@ -912,7 +947,7 @@ pub fn testExecFn( var document = [_]Case{ .{ .src = "const req2 = new XMLHttpRequest()", .ex = "undefined" }, - .{ .src = "req2.open('GET', 'http://httpbin.io/html')", .ex = "undefined" }, + .{ .src = "req2.open('GET', 'https://httpbin.io/html')", .ex = "undefined" }, .{ .src = "req2.responseType = 'document'", .ex = "document" }, .{ .src = "req2.send()", .ex = "undefined" }, @@ -928,7 +963,7 @@ pub fn testExecFn( var json = [_]Case{ .{ .src = "const req3 = new XMLHttpRequest()", .ex = "undefined" }, - .{ .src = "req3.open('GET', 'http://httpbin.io/json')", .ex = "undefined" }, + .{ .src = "req3.open('GET', 'https://httpbin.io/json')", .ex = "undefined" }, .{ .src = "req3.responseType = 'json'", .ex = "json" }, .{ .src = "req3.send()", .ex = "undefined" }, @@ -943,7 +978,7 @@ pub fn testExecFn( var post = [_]Case{ .{ .src = "const req4 = new XMLHttpRequest()", .ex = "undefined" }, - .{ .src = "req4.open('POST', 'http://httpbin.io/post')", .ex = "undefined" }, + .{ .src = "req4.open('POST', 'https://httpbin.io/post')", .ex = "undefined" }, .{ .src = "req4.send('foo')", .ex = "undefined" }, // Each case executed waits for all loop callaback calls. @@ -956,7 +991,7 @@ pub fn testExecFn( var cbk = [_]Case{ .{ .src = "const req5 = new XMLHttpRequest()", .ex = "undefined" }, - .{ .src = "req5.open('GET', 'http://httpbin.io/json')", .ex = "undefined" }, + .{ .src = "req5.open('GET', 'https://httpbin.io/json')", .ex = "undefined" }, .{ .src = "var status = 0; req5.onload = function () { status = this.status };", .ex = "function () { status = this.status }" }, .{ .src = "req5.send()", .ex = "undefined" }, From 7fed1f3015b2e1f36c754ac84004f6c2a2337996 Mon Sep 17 00:00:00 2001 From: Pierre Tachoire Date: Fri, 15 Nov 2024 14:57:44 +0100 Subject: [PATCH 03/11] async: remove pseudo-async http client --- src/async/Client.zig | 1766 ------------------------------------------ src/async/stream.zig | 133 ---- src/async/tcp.zig | 112 --- src/async/test.zig | 189 ----- src/main_shell.zig | 4 +- src/run_tests.zig | 2 +- 6 files changed, 3 insertions(+), 2203 deletions(-) delete mode 100644 src/async/Client.zig delete mode 100644 src/async/stream.zig delete mode 100644 src/async/tcp.zig delete mode 100644 src/async/test.zig diff --git a/src/async/Client.zig b/src/async/Client.zig deleted file mode 100644 index 91748b9a..00000000 --- a/src/async/Client.zig +++ /dev/null @@ -1,1766 +0,0 @@ -// Copyright (C) 2023-2024 Lightpanda (Selecy SAS) -// -// Francis Bouvier -// Pierre Tachoire -// -// This program is free software: you can redistribute it and/or modify -// it under the terms of the GNU Affero General Public License as -// published by the Free Software Foundation, either version 3 of the -// License, or (at your option) any later version. -// -// This program is distributed in the hope that it will be useful, -// but WITHOUT ANY WARRANTY; without even the implied warranty of -// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the -// GNU Affero General Public License for more details. -// -// You should have received a copy of the GNU Affero General Public License -// along with this program. If not, see . - -//! HTTP(S) Client implementation. -//! -//! Connections are opened in a thread-safe manner, but individual Requests are not. -//! -//! TLS support may be disabled via `std.options.http_disable_tls`. - -const std = @import("std"); -const builtin = @import("builtin"); -const Stream = @import("stream.zig").Stream; -const testing = std.testing; -const http = std.http; -const mem = std.mem; -const net = std.net; -const Uri = std.Uri; -const Allocator = mem.Allocator; -const assert = std.debug.assert; -const use_vectors = builtin.zig_backend != .stage2_x86_64; - -const Client = @This(); -const proto = std.http.protocol; - -const tls23 = @import("tls"); - -const Loop = @import("jsruntime").Loop; -const tcp = @import("tcp.zig"); - -pub const disable_tls = std.options.http_disable_tls; - -/// Used for all client allocations. Must be thread-safe. -allocator: Allocator, - -// std.net.Stream implementation using jsruntime Loop -loop: *Loop, - -ca_bundle: if (disable_tls) void else std.crypto.Certificate.Bundle = if (disable_tls) {} else .{}, -ca_bundle_mutex: std.Thread.Mutex = .{}, - -/// When this is `true`, the next time this client performs an HTTPS request, -/// it will first rescan the system for root certificates. -next_https_rescan_certs: bool = true, - -/// The pool of connections that can be reused (and currently in use). -connection_pool: ConnectionPool = .{}, - -/// If populated, all http traffic travels through this third party. -/// This field cannot be modified while the client has active connections. -/// Pointer to externally-owned memory. -http_proxy: ?*Proxy = null, -/// If populated, all https traffic travels through this third party. -/// This field cannot be modified while the client has active connections. -/// Pointer to externally-owned memory. -https_proxy: ?*Proxy = null, - -/// A set of linked lists of connections that can be reused. -pub const ConnectionPool = struct { - mutex: std.Thread.Mutex = .{}, - /// Open connections that are currently in use. - used: Queue = .{}, - /// Open connections that are not currently in use. - free: Queue = .{}, - free_len: usize = 0, - free_size: usize = 32, - - /// The criteria for a connection to be considered a match. - pub const Criteria = struct { - host: []const u8, - port: u16, - protocol: Connection.Protocol, - }; - - const Queue = std.DoublyLinkedList(Connection); - pub const Node = Queue.Node; - - /// Finds and acquires a connection from the connection pool matching the criteria. This function is threadsafe. - /// If no connection is found, null is returned. - pub fn findConnection(pool: *ConnectionPool, criteria: Criteria) ?*Connection { - pool.mutex.lock(); - defer pool.mutex.unlock(); - - var next = pool.free.last; - while (next) |node| : (next = node.prev) { - if (node.data.protocol != criteria.protocol) continue; - if (node.data.port != criteria.port) continue; - - // Domain names are case-insensitive (RFC 5890, Section 2.3.2.4) - if (!std.ascii.eqlIgnoreCase(node.data.host, criteria.host)) continue; - - pool.acquireUnsafe(node); - return &node.data; - } - - return null; - } - - /// Acquires an existing connection from the connection pool. This function is not threadsafe. - pub fn acquireUnsafe(pool: *ConnectionPool, node: *Node) void { - pool.free.remove(node); - pool.free_len -= 1; - - pool.used.append(node); - } - - /// Acquires an existing connection from the connection pool. This function is threadsafe. - pub fn acquire(pool: *ConnectionPool, node: *Node) void { - pool.mutex.lock(); - defer pool.mutex.unlock(); - - return pool.acquireUnsafe(node); - } - - /// Tries to release a connection back to the connection pool. This function is threadsafe. - /// If the connection is marked as closing, it will be closed instead. - /// - /// The allocator must be the owner of all nodes in this pool. - /// The allocator must be the owner of all resources associated with the connection. - pub fn release(pool: *ConnectionPool, allocator: Allocator, connection: *Connection) void { - pool.mutex.lock(); - defer pool.mutex.unlock(); - - const node: *Node = @fieldParentPtr("data", connection); - - pool.used.remove(node); - - if (node.data.closing or pool.free_size == 0) { - node.data.close(allocator); - return allocator.destroy(node); - } - - if (pool.free_len >= pool.free_size) { - const popped = pool.free.popFirst() orelse unreachable; - pool.free_len -= 1; - - popped.data.close(allocator); - allocator.destroy(popped); - } - - if (node.data.proxied) { - pool.free.prepend(node); // proxied connections go to the end of the queue, always try direct connections first - } else { - pool.free.append(node); - } - - pool.free_len += 1; - } - - /// Adds a newly created node to the pool of used connections. This function is threadsafe. - pub fn addUsed(pool: *ConnectionPool, node: *Node) void { - pool.mutex.lock(); - defer pool.mutex.unlock(); - - pool.used.append(node); - } - - /// Resizes the connection pool. This function is threadsafe. - /// - /// If the new size is smaller than the current size, then idle connections will be closed until the pool is the new size. - pub fn resize(pool: *ConnectionPool, allocator: Allocator, new_size: usize) void { - pool.mutex.lock(); - defer pool.mutex.unlock(); - - const next = pool.free.first; - _ = next; - while (pool.free_len > new_size) { - const popped = pool.free.popFirst() orelse unreachable; - pool.free_len -= 1; - - popped.data.close(allocator); - allocator.destroy(popped); - } - - pool.free_size = new_size; - } - - /// Frees the connection pool and closes all connections within. This function is threadsafe. - /// - /// All future operations on the connection pool will deadlock. - pub fn deinit(pool: *ConnectionPool, allocator: Allocator) void { - pool.mutex.lock(); - - var next = pool.free.first; - while (next) |node| { - defer allocator.destroy(node); - next = node.next; - - node.data.close(allocator); - } - - next = pool.used.first; - while (next) |node| { - defer allocator.destroy(node); - next = node.next; - - node.data.close(allocator); - } - - pool.* = undefined; - } -}; - -/// An interface to either a plain or TLS connection. -pub const Connection = struct { - stream: Stream, - /// undefined unless protocol is tls. - tls_client: if (!disable_tls) *tls23.Connection(Stream) else void, - - /// The protocol that this connection is using. - protocol: Protocol, - - /// The host that this connection is connected to. - host: []u8, - - /// The port that this connection is connected to. - port: u16, - - /// Whether this connection is proxied and is not directly connected. - proxied: bool = false, - - /// Whether this connection is closing when we're done with it. - closing: bool = false, - - read_start: BufferSize = 0, - read_end: BufferSize = 0, - write_end: BufferSize = 0, - read_buf: [buffer_size]u8 = undefined, - write_buf: [buffer_size]u8 = undefined, - - pub const buffer_size = std.crypto.tls.max_ciphertext_record_len; - const BufferSize = std.math.IntFittingRange(0, buffer_size); - - pub const Protocol = enum { plain, tls }; - - pub fn readvDirectTls(conn: *Connection, buffers: []std.posix.iovec) ReadError!usize { - return conn.tls_client.readv(buffers) catch |err| { - // https://github.com/ziglang/zig/issues/2473 - if (mem.startsWith(u8, @errorName(err), "TlsAlert")) return error.TlsAlert; - - switch (err) { - error.TlsRecordOverflow, error.TlsBadRecordMac, error.TlsUnexpectedMessage => return error.TlsFailure, - error.ConnectionTimedOut => return error.ConnectionTimedOut, - error.ConnectionResetByPeer, error.BrokenPipe => return error.ConnectionResetByPeer, - else => return error.UnexpectedReadFailure, - } - }; - } - - pub fn readvDirect(conn: *Connection, buffers: []std.posix.iovec) ReadError!usize { - if (conn.protocol == .tls) { - if (disable_tls) unreachable; - - return conn.readvDirectTls(buffers); - } - - return conn.stream.readv(buffers) catch |err| switch (err) { - error.ConnectionTimedOut => return error.ConnectionTimedOut, - error.ConnectionResetByPeer, error.BrokenPipe => return error.ConnectionResetByPeer, - else => return error.UnexpectedReadFailure, - }; - } - - /// Refills the read buffer with data from the connection. - pub fn fill(conn: *Connection) ReadError!void { - if (conn.read_end != conn.read_start) return; - - var iovecs = [1]std.posix.iovec{ - .{ .base = &conn.read_buf, .len = conn.read_buf.len }, - }; - const nread = try conn.readvDirect(&iovecs); - if (nread == 0) return error.EndOfStream; - conn.read_start = 0; - conn.read_end = @intCast(nread); - } - - /// Returns the current slice of buffered data. - pub fn peek(conn: *Connection) []const u8 { - return conn.read_buf[conn.read_start..conn.read_end]; - } - - /// Discards the given number of bytes from the read buffer. - pub fn drop(conn: *Connection, num: BufferSize) void { - conn.read_start += num; - } - - /// Reads data from the connection into the given buffer. - pub fn read(conn: *Connection, buffer: []u8) ReadError!usize { - const available_read = conn.read_end - conn.read_start; - const available_buffer = buffer.len; - - if (available_read > available_buffer) { // partially read buffered data - @memcpy(buffer[0..available_buffer], conn.read_buf[conn.read_start..conn.read_end][0..available_buffer]); - conn.read_start += @intCast(available_buffer); - - return available_buffer; - } else if (available_read > 0) { // fully read buffered data - @memcpy(buffer[0..available_read], conn.read_buf[conn.read_start..conn.read_end]); - conn.read_start += available_read; - - return available_read; - } - - var iovecs = [2]std.posix.iovec{ - .{ .base = buffer.ptr, .len = buffer.len }, - .{ .base = &conn.read_buf, .len = conn.read_buf.len }, - }; - const nread = try conn.readvDirect(&iovecs); - - if (nread > buffer.len) { - conn.read_start = 0; - conn.read_end = @intCast(nread - buffer.len); - return buffer.len; - } - - return nread; - } - - pub const ReadError = error{ - TlsFailure, - TlsAlert, - ConnectionTimedOut, - ConnectionResetByPeer, - UnexpectedReadFailure, - EndOfStream, - }; - - pub const Reader = std.io.Reader(*Connection, ReadError, read); - - pub fn reader(conn: *Connection) Reader { - return Reader{ .context = conn }; - } - - pub fn writeAllDirectTls(conn: *Connection, buffer: []const u8) WriteError!void { - return conn.tls_client.writeAll(buffer) catch |err| switch (err) { - error.BrokenPipe, error.ConnectionResetByPeer => return error.ConnectionResetByPeer, - else => return error.UnexpectedWriteFailure, - }; - } - - pub fn writeAllDirect(conn: *Connection, buffer: []const u8) WriteError!void { - if (conn.protocol == .tls) { - if (disable_tls) unreachable; - - return conn.writeAllDirectTls(buffer); - } - - return conn.stream.writeAll(buffer) catch |err| switch (err) { - error.BrokenPipe, error.ConnectionResetByPeer => return error.ConnectionResetByPeer, - else => return error.UnexpectedWriteFailure, - }; - } - - /// Writes the given buffer to the connection. - pub fn write(conn: *Connection, buffer: []const u8) WriteError!usize { - if (conn.write_buf.len - conn.write_end < buffer.len) { - try conn.flush(); - - if (buffer.len > conn.write_buf.len) { - try conn.writeAllDirect(buffer); - return buffer.len; - } - } - - @memcpy(conn.write_buf[conn.write_end..][0..buffer.len], buffer); - conn.write_end += @intCast(buffer.len); - - return buffer.len; - } - - /// Returns a buffer to be filled with exactly len bytes to write to the connection. - pub fn allocWriteBuffer(conn: *Connection, len: BufferSize) WriteError![]u8 { - if (conn.write_buf.len - conn.write_end < len) try conn.flush(); - defer conn.write_end += len; - return conn.write_buf[conn.write_end..][0..len]; - } - - /// Flushes the write buffer to the connection. - pub fn flush(conn: *Connection) WriteError!void { - if (conn.write_end == 0) return; - - try conn.writeAllDirect(conn.write_buf[0..conn.write_end]); - conn.write_end = 0; - } - - pub const WriteError = error{ - ConnectionResetByPeer, - UnexpectedWriteFailure, - }; - - pub const Writer = std.io.Writer(*Connection, WriteError, write); - - pub fn writer(conn: *Connection) Writer { - return Writer{ .context = conn }; - } - - /// Closes the connection. - pub fn close(conn: *Connection, allocator: Allocator) void { - if (conn.protocol == .tls) { - if (disable_tls) unreachable; - - // try to cleanly close the TLS connection, for any server that cares. - conn.tls_client.close() catch {}; - allocator.destroy(conn.tls_client); - } - - conn.stream.close(); - allocator.free(conn.host); - } -}; - -/// The mode of transport for requests. -pub const RequestTransfer = union(enum) { - content_length: u64, - chunked: void, - none: void, -}; - -/// The decompressor for response messages. -pub const Compression = union(enum) { - pub const DeflateDecompressor = std.compress.zlib.Decompressor(Request.TransferReader); - pub const GzipDecompressor = std.compress.gzip.Decompressor(Request.TransferReader); - // https://github.com/ziglang/zig/issues/18937 - //pub const ZstdDecompressor = std.compress.zstd.DecompressStream(Request.TransferReader, .{}); - - deflate: DeflateDecompressor, - gzip: GzipDecompressor, - // https://github.com/ziglang/zig/issues/18937 - //zstd: ZstdDecompressor, - none: void, -}; - -/// A HTTP response originating from a server. -pub const Response = struct { - version: http.Version, - status: http.Status, - reason: []const u8, - - /// Points into the user-provided `server_header_buffer`. - location: ?[]const u8 = null, - /// Points into the user-provided `server_header_buffer`. - content_type: ?[]const u8 = null, - /// Points into the user-provided `server_header_buffer`. - content_disposition: ?[]const u8 = null, - - keep_alive: bool, - - /// If present, the number of bytes in the response body. - content_length: ?u64 = null, - - /// If present, the transfer encoding of the response body, otherwise none. - transfer_encoding: http.TransferEncoding = .none, - - /// If present, the compression of the response body, otherwise identity (no compression). - transfer_compression: http.ContentEncoding = .identity, - - parser: proto.HeadersParser, - compression: Compression = .none, - - /// Whether the response body should be skipped. Any data read from the - /// response body will be discarded. - skip: bool = false, - - pub const ParseError = error{ - HttpHeadersInvalid, - HttpHeaderContinuationsUnsupported, - HttpTransferEncodingUnsupported, - HttpConnectionHeaderUnsupported, - InvalidContentLength, - CompressionUnsupported, - }; - - pub fn parse(res: *Response, bytes: []const u8) ParseError!void { - var it = mem.splitSequence(u8, bytes, "\r\n"); - - const first_line = it.next().?; - if (first_line.len < 12) { - return error.HttpHeadersInvalid; - } - - const version: http.Version = switch (int64(first_line[0..8])) { - int64("HTTP/1.0") => .@"HTTP/1.0", - int64("HTTP/1.1") => .@"HTTP/1.1", - else => return error.HttpHeadersInvalid, - }; - if (first_line[8] != ' ') return error.HttpHeadersInvalid; - const status: http.Status = @enumFromInt(parseInt3(first_line[9..12])); - const reason = mem.trimLeft(u8, first_line[12..], " "); - - res.version = version; - res.status = status; - res.reason = reason; - res.keep_alive = switch (version) { - .@"HTTP/1.0" => false, - .@"HTTP/1.1" => true, - }; - - while (it.next()) |line| { - if (line.len == 0) return; - switch (line[0]) { - ' ', '\t' => return error.HttpHeaderContinuationsUnsupported, - else => {}, - } - - var line_it = mem.splitScalar(u8, line, ':'); - const header_name = line_it.next().?; - const header_value = mem.trim(u8, line_it.rest(), " \t"); - if (header_name.len == 0) return error.HttpHeadersInvalid; - - if (std.ascii.eqlIgnoreCase(header_name, "connection")) { - res.keep_alive = !std.ascii.eqlIgnoreCase(header_value, "close"); - } else if (std.ascii.eqlIgnoreCase(header_name, "content-type")) { - res.content_type = header_value; - } else if (std.ascii.eqlIgnoreCase(header_name, "location")) { - res.location = header_value; - } else if (std.ascii.eqlIgnoreCase(header_name, "content-disposition")) { - res.content_disposition = header_value; - } else if (std.ascii.eqlIgnoreCase(header_name, "transfer-encoding")) { - // Transfer-Encoding: second, first - // Transfer-Encoding: deflate, chunked - var iter = mem.splitBackwardsScalar(u8, header_value, ','); - - const first = iter.first(); - const trimmed_first = mem.trim(u8, first, " "); - - var next: ?[]const u8 = first; - if (std.meta.stringToEnum(http.TransferEncoding, trimmed_first)) |transfer| { - if (res.transfer_encoding != .none) return error.HttpHeadersInvalid; // we already have a transfer encoding - res.transfer_encoding = transfer; - - next = iter.next(); - } - - if (next) |second| { - const trimmed_second = mem.trim(u8, second, " "); - - if (std.meta.stringToEnum(http.ContentEncoding, trimmed_second)) |transfer| { - if (res.transfer_compression != .identity) return error.HttpHeadersInvalid; // double compression is not supported - res.transfer_compression = transfer; - } else { - return error.HttpTransferEncodingUnsupported; - } - } - - if (iter.next()) |_| return error.HttpTransferEncodingUnsupported; - } else if (std.ascii.eqlIgnoreCase(header_name, "content-length")) { - const content_length = std.fmt.parseInt(u64, header_value, 10) catch return error.InvalidContentLength; - - if (res.content_length != null and res.content_length != content_length) return error.HttpHeadersInvalid; - - res.content_length = content_length; - } else if (std.ascii.eqlIgnoreCase(header_name, "content-encoding")) { - if (res.transfer_compression != .identity) return error.HttpHeadersInvalid; - - const trimmed = mem.trim(u8, header_value, " "); - - if (std.meta.stringToEnum(http.ContentEncoding, trimmed)) |ce| { - res.transfer_compression = ce; - } else { - return error.HttpTransferEncodingUnsupported; - } - } - } - return error.HttpHeadersInvalid; // missing empty line - } - - test parse { - const response_bytes = "HTTP/1.1 200 OK\r\n" ++ - "LOcation:url\r\n" ++ - "content-tYpe: text/plain\r\n" ++ - "content-disposition:attachment; filename=example.txt \r\n" ++ - "content-Length:10\r\n" ++ - "TRansfer-encoding:\tdeflate, chunked \r\n" ++ - "connectioN:\t keep-alive \r\n\r\n"; - - var header_buffer: [1024]u8 = undefined; - var res = Response{ - .status = undefined, - .reason = undefined, - .version = undefined, - .keep_alive = false, - .parser = proto.HeadersParser.init(&header_buffer), - }; - - @memcpy(header_buffer[0..response_bytes.len], response_bytes); - res.parser.header_bytes_len = response_bytes.len; - - try res.parse(response_bytes); - - try testing.expectEqual(.@"HTTP/1.1", res.version); - try testing.expectEqualStrings("OK", res.reason); - try testing.expectEqual(.ok, res.status); - - try testing.expectEqualStrings("url", res.location.?); - try testing.expectEqualStrings("text/plain", res.content_type.?); - try testing.expectEqualStrings("attachment; filename=example.txt", res.content_disposition.?); - - try testing.expectEqual(true, res.keep_alive); - try testing.expectEqual(10, res.content_length.?); - try testing.expectEqual(.chunked, res.transfer_encoding); - try testing.expectEqual(.deflate, res.transfer_compression); - } - - inline fn int64(array: *const [8]u8) u64 { - return @bitCast(array.*); - } - - fn parseInt3(text: *const [3]u8) u10 { - if (use_vectors) { - const nnn: @Vector(3, u8) = text.*; - const zero: @Vector(3, u8) = .{ '0', '0', '0' }; - const mmm: @Vector(3, u10) = .{ 100, 10, 1 }; - return @reduce(.Add, @as(@Vector(3, u10), nnn -% zero) *% mmm); - } - return std.fmt.parseInt(u10, text, 10) catch unreachable; - } - - test parseInt3 { - const expectEqual = testing.expectEqual; - try expectEqual(@as(u10, 0), parseInt3("000")); - try expectEqual(@as(u10, 418), parseInt3("418")); - try expectEqual(@as(u10, 999), parseInt3("999")); - } - - pub fn iterateHeaders(r: Response) http.HeaderIterator { - return http.HeaderIterator.init(r.parser.get()); - } - - test iterateHeaders { - const response_bytes = "HTTP/1.1 200 OK\r\n" ++ - "LOcation:url\r\n" ++ - "content-tYpe: text/plain\r\n" ++ - "content-disposition:attachment; filename=example.txt \r\n" ++ - "content-Length:10\r\n" ++ - "TRansfer-encoding:\tdeflate, chunked \r\n" ++ - "connectioN:\t keep-alive \r\n\r\n"; - - var header_buffer: [1024]u8 = undefined; - var res = Response{ - .status = undefined, - .reason = undefined, - .version = undefined, - .keep_alive = false, - .parser = proto.HeadersParser.init(&header_buffer), - }; - - @memcpy(header_buffer[0..response_bytes.len], response_bytes); - res.parser.header_bytes_len = response_bytes.len; - - var it = res.iterateHeaders(); - { - const header = it.next().?; - try testing.expectEqualStrings("LOcation", header.name); - try testing.expectEqualStrings("url", header.value); - try testing.expect(!it.is_trailer); - } - { - const header = it.next().?; - try testing.expectEqualStrings("content-tYpe", header.name); - try testing.expectEqualStrings("text/plain", header.value); - try testing.expect(!it.is_trailer); - } - { - const header = it.next().?; - try testing.expectEqualStrings("content-disposition", header.name); - try testing.expectEqualStrings("attachment; filename=example.txt", header.value); - try testing.expect(!it.is_trailer); - } - { - const header = it.next().?; - try testing.expectEqualStrings("content-Length", header.name); - try testing.expectEqualStrings("10", header.value); - try testing.expect(!it.is_trailer); - } - { - const header = it.next().?; - try testing.expectEqualStrings("TRansfer-encoding", header.name); - try testing.expectEqualStrings("deflate, chunked", header.value); - try testing.expect(!it.is_trailer); - } - { - const header = it.next().?; - try testing.expectEqualStrings("connectioN", header.name); - try testing.expectEqualStrings("keep-alive", header.value); - try testing.expect(!it.is_trailer); - } - try testing.expectEqual(null, it.next()); - } -}; - -/// A HTTP request that has been sent. -/// -/// Order of operations: open -> send[ -> write -> finish] -> wait -> read -pub const Request = struct { - uri: Uri, - client: *Client, - /// This is null when the connection is released. - connection: ?*Connection, - keep_alive: bool, - - method: http.Method, - version: http.Version = .@"HTTP/1.1", - transfer_encoding: RequestTransfer, - redirect_behavior: RedirectBehavior, - - /// Whether the request should handle a 100-continue response before sending the request body. - handle_continue: bool, - - /// The response associated with this request. - /// - /// This field is undefined until `wait` is called. - response: Response, - - /// Standard headers that have default, but overridable, behavior. - headers: Headers, - - /// These headers are kept including when following a redirect to a - /// different domain. - /// Externally-owned; must outlive the Request. - extra_headers: []const http.Header, - - /// These headers are stripped when following a redirect to a different - /// domain. - /// Externally-owned; must outlive the Request. - privileged_headers: []const http.Header, - - pub const Headers = struct { - host: Value = .default, - authorization: Value = .default, - user_agent: Value = .default, - connection: Value = .default, - accept_encoding: Value = .default, - content_type: Value = .default, - - pub const Value = union(enum) { - default, - omit, - override: []const u8, - }; - }; - - /// Any value other than `not_allowed` or `unhandled` means that integer represents - /// how many remaining redirects are allowed. - pub const RedirectBehavior = enum(u16) { - /// The next redirect will cause an error. - not_allowed = 0, - /// Redirects are passed to the client to analyze the redirect response - /// directly. - unhandled = std.math.maxInt(u16), - _, - - pub fn subtractOne(rb: *RedirectBehavior) void { - switch (rb.*) { - .not_allowed => unreachable, - .unhandled => unreachable, - _ => rb.* = @enumFromInt(@intFromEnum(rb.*) - 1), - } - } - - pub fn remaining(rb: RedirectBehavior) u16 { - assert(rb != .unhandled); - return @intFromEnum(rb); - } - }; - - /// Frees all resources associated with the request. - pub fn deinit(req: *Request) void { - if (req.connection) |connection| { - if (!req.response.parser.done) { - // If the response wasn't fully read, then we need to close the connection. - connection.closing = true; - } - req.client.connection_pool.release(req.client.allocator, connection); - } - req.* = undefined; - } - - // This function must deallocate all resources associated with the request, - // or keep those which will be used. - // This needs to be kept in sync with deinit and request. - fn redirect(req: *Request, uri: Uri) !void { - assert(req.response.parser.done); - - req.client.connection_pool.release(req.client.allocator, req.connection.?); - req.connection = null; - - var server_header = std.heap.FixedBufferAllocator.init(req.response.parser.header_bytes_buffer); - defer req.response.parser.header_bytes_buffer = server_header.buffer[server_header.end_index..]; - const protocol, const valid_uri = try validateUri(uri, server_header.allocator()); - - const new_host = valid_uri.host.?.raw; - const prev_host = req.uri.host.?.raw; - const keep_privileged_headers = - std.ascii.eqlIgnoreCase(valid_uri.scheme, req.uri.scheme) and - std.ascii.endsWithIgnoreCase(new_host, prev_host) and - (new_host.len == prev_host.len or new_host[new_host.len - prev_host.len - 1] == '.'); - if (!keep_privileged_headers) { - // When redirecting to a different domain, strip privileged headers. - req.privileged_headers = &.{}; - } - - if (switch (req.response.status) { - .see_other => true, - .moved_permanently, .found => req.method == .POST, - else => false, - }) { - // A redirect to a GET must change the method and remove the body. - req.method = .GET; - req.transfer_encoding = .none; - req.headers.content_type = .omit; - } - - if (req.transfer_encoding != .none) { - // The request body has already been sent. The request is - // still in a valid state, but the redirect must be handled - // manually. - return error.RedirectRequiresResend; - } - - req.uri = valid_uri; - req.connection = try req.client.connect(new_host, uriPort(valid_uri, protocol), protocol); - req.redirect_behavior.subtractOne(); - req.response.parser.reset(); - - req.response = .{ - .version = undefined, - .status = undefined, - .reason = undefined, - .keep_alive = undefined, - .parser = req.response.parser, - }; - } - - pub const SendError = Connection.WriteError || error{ InvalidContentLength, UnsupportedTransferEncoding }; - - /// Send the HTTP request headers to the server. - pub fn send(req: *Request) SendError!void { - if (!req.method.requestHasBody() and req.transfer_encoding != .none) - return error.UnsupportedTransferEncoding; - - const connection = req.connection.?; - const w = connection.writer(); - - try req.method.write(w); - try w.writeByte(' '); - - if (req.method == .CONNECT) { - try req.uri.writeToStream(.{ .authority = true }, w); - } else { - try req.uri.writeToStream(.{ - .scheme = connection.proxied, - .authentication = connection.proxied, - .authority = connection.proxied, - .path = true, - .query = true, - }, w); - } - try w.writeByte(' '); - try w.writeAll(@tagName(req.version)); - try w.writeAll("\r\n"); - - if (try emitOverridableHeader("host: ", req.headers.host, w)) { - try w.writeAll("host: "); - try req.uri.writeToStream(.{ .authority = true }, w); - try w.writeAll("\r\n"); - } - - if (try emitOverridableHeader("authorization: ", req.headers.authorization, w)) { - if (req.uri.user != null or req.uri.password != null) { - try w.writeAll("authorization: "); - const authorization = try connection.allocWriteBuffer( - @intCast(basic_authorization.valueLengthFromUri(req.uri)), - ); - assert(basic_authorization.value(req.uri, authorization).len == authorization.len); - try w.writeAll("\r\n"); - } - } - - if (try emitOverridableHeader("user-agent: ", req.headers.user_agent, w)) { - try w.writeAll("user-agent: zig/"); - try w.writeAll(builtin.zig_version_string); - try w.writeAll(" (std.http)\r\n"); - } - - if (try emitOverridableHeader("connection: ", req.headers.connection, w)) { - if (req.keep_alive) { - try w.writeAll("connection: keep-alive\r\n"); - } else { - try w.writeAll("connection: close\r\n"); - } - } - - if (try emitOverridableHeader("accept-encoding: ", req.headers.accept_encoding, w)) { - // https://github.com/ziglang/zig/issues/18937 - //try w.writeAll("accept-encoding: gzip, deflate, zstd\r\n"); - try w.writeAll("accept-encoding: gzip, deflate\r\n"); - } - - switch (req.transfer_encoding) { - .chunked => try w.writeAll("transfer-encoding: chunked\r\n"), - .content_length => |len| try w.print("content-length: {d}\r\n", .{len}), - .none => {}, - } - - if (try emitOverridableHeader("content-type: ", req.headers.content_type, w)) { - // The default is to omit content-type if not provided because - // "application/octet-stream" is redundant. - } - - for (req.extra_headers) |header| { - assert(header.name.len != 0); - - try w.writeAll(header.name); - try w.writeAll(": "); - try w.writeAll(header.value); - try w.writeAll("\r\n"); - } - - if (connection.proxied) proxy: { - const proxy = switch (connection.protocol) { - .plain => req.client.http_proxy, - .tls => req.client.https_proxy, - } orelse break :proxy; - - const authorization = proxy.authorization orelse break :proxy; - try w.writeAll("proxy-authorization: "); - try w.writeAll(authorization); - try w.writeAll("\r\n"); - } - - try w.writeAll("\r\n"); - - try connection.flush(); - } - - /// Returns true if the default behavior is required, otherwise handles - /// writing (or not writing) the header. - fn emitOverridableHeader(prefix: []const u8, v: Headers.Value, w: anytype) !bool { - switch (v) { - .default => return true, - .omit => return false, - .override => |x| { - try w.writeAll(prefix); - try w.writeAll(x); - try w.writeAll("\r\n"); - return false; - }, - } - } - - const TransferReadError = Connection.ReadError || proto.HeadersParser.ReadError; - - const TransferReader = std.io.Reader(*Request, TransferReadError, transferRead); - - fn transferReader(req: *Request) TransferReader { - return .{ .context = req }; - } - - fn transferRead(req: *Request, buf: []u8) TransferReadError!usize { - if (req.response.parser.done) return 0; - - var index: usize = 0; - while (index == 0) { - const amt = try req.response.parser.read(req.connection.?, buf[index..], req.response.skip); - if (amt == 0 and req.response.parser.done) break; - index += amt; - } - - return index; - } - - pub const WaitError = RequestError || SendError || TransferReadError || - proto.HeadersParser.CheckCompleteHeadError || Response.ParseError || - error{ // TODO: file zig fmt issue for this bad indentation - TooManyHttpRedirects, - RedirectRequiresResend, - HttpRedirectLocationMissing, - HttpRedirectLocationInvalid, - CompressionInitializationFailed, - CompressionUnsupported, - }; - - /// Waits for a response from the server and parses any headers that are sent. - /// This function will block until the final response is received. - /// - /// If handling redirects and the request has no payload, then this - /// function will automatically follow redirects. If a request payload is - /// present, then this function will error with - /// error.RedirectRequiresResend. - /// - /// Must be called after `send` and, if any data was written to the request - /// body, then also after `finish`. - pub fn wait(req: *Request) WaitError!void { - while (true) { - // This while loop is for handling redirects, which means the request's - // connection may be different than the previous iteration. However, it - // is still guaranteed to be non-null with each iteration of this loop. - const connection = req.connection.?; - - while (true) { // read headers - try connection.fill(); - - const nchecked = try req.response.parser.checkCompleteHead(connection.peek()); - connection.drop(@intCast(nchecked)); - - if (req.response.parser.state.isContent()) break; - } - - try req.response.parse(req.response.parser.get()); - - if (req.response.status == .@"continue") { - // We're done parsing the continue response; reset to prepare - // for the real response. - req.response.parser.done = true; - req.response.parser.reset(); - - if (req.handle_continue) - continue; - - return; // we're not handling the 100-continue - } - - // we're switching protocols, so this connection is no longer doing http - if (req.method == .CONNECT and req.response.status.class() == .success) { - connection.closing = false; - req.response.parser.done = true; - return; // the connection is not HTTP past this point - } - - connection.closing = !req.response.keep_alive or !req.keep_alive; - - // Any response to a HEAD request and any response with a 1xx - // (Informational), 204 (No Content), or 304 (Not Modified) status - // code is always terminated by the first empty line after the - // header fields, regardless of the header fields present in the - // message. - if (req.method == .HEAD or req.response.status.class() == .informational or - req.response.status == .no_content or req.response.status == .not_modified) - { - req.response.parser.done = true; - return; // The response is empty; no further setup or redirection is necessary. - } - - switch (req.response.transfer_encoding) { - .none => { - if (req.response.content_length) |cl| { - req.response.parser.next_chunk_length = cl; - - if (cl == 0) req.response.parser.done = true; - } else { - // read until the connection is closed - req.response.parser.next_chunk_length = std.math.maxInt(u64); - } - }, - .chunked => { - req.response.parser.next_chunk_length = 0; - req.response.parser.state = .chunk_head_size; - }, - } - - if (req.response.status.class() == .redirect and req.redirect_behavior != .unhandled) { - // skip the body of the redirect response, this will at least - // leave the connection in a known good state. - req.response.skip = true; - assert(try req.transferRead(&.{}) == 0); // we're skipping, no buffer is necessary - - if (req.redirect_behavior == .not_allowed) return error.TooManyHttpRedirects; - - const location = req.response.location orelse - return error.HttpRedirectLocationMissing; - - // This mutates the beginning of header_bytes_buffer and uses that - // for the backing memory of the returned Uri. - try req.redirect(req.uri.resolve_inplace( - location, - &req.response.parser.header_bytes_buffer, - ) catch |err| switch (err) { - error.UnexpectedCharacter, - error.InvalidFormat, - error.InvalidPort, - => return error.HttpRedirectLocationInvalid, - error.NoSpaceLeft => return error.HttpHeadersOversize, - }); - try req.send(); - } else { - req.response.skip = false; - if (!req.response.parser.done) { - switch (req.response.transfer_compression) { - .identity => req.response.compression = .none, - .compress, .@"x-compress" => return error.CompressionUnsupported, - .deflate => req.response.compression = .{ - .deflate = std.compress.zlib.decompressor(req.transferReader()), - }, - .gzip, .@"x-gzip" => req.response.compression = .{ - .gzip = std.compress.gzip.decompressor(req.transferReader()), - }, - // https://github.com/ziglang/zig/issues/18937 - //.zstd => req.response.compression = .{ - // .zstd = std.compress.zstd.decompressStream(req.client.allocator, req.transferReader()), - //}, - .zstd => return error.CompressionUnsupported, - } - } - - break; - } - } - } - - pub const ReadError = TransferReadError || proto.HeadersParser.CheckCompleteHeadError || - error{ DecompressionFailure, InvalidTrailers }; - - pub const Reader = std.io.Reader(*Request, ReadError, read); - - pub fn reader(req: *Request) Reader { - return .{ .context = req }; - } - - /// Reads data from the response body. Must be called after `wait`. - pub fn read(req: *Request, buffer: []u8) ReadError!usize { - const out_index = switch (req.response.compression) { - .deflate => |*deflate| deflate.read(buffer) catch return error.DecompressionFailure, - .gzip => |*gzip| gzip.read(buffer) catch return error.DecompressionFailure, - // https://github.com/ziglang/zig/issues/18937 - //.zstd => |*zstd| zstd.read(buffer) catch return error.DecompressionFailure, - else => try req.transferRead(buffer), - }; - if (out_index > 0) return out_index; - - while (!req.response.parser.state.isContent()) { // read trailing headers - try req.connection.?.fill(); - - const nchecked = try req.response.parser.checkCompleteHead(req.connection.?.peek()); - req.connection.?.drop(@intCast(nchecked)); - } - - return 0; - } - - /// Reads data from the response body. Must be called after `wait`. - pub fn readAll(req: *Request, buffer: []u8) !usize { - var index: usize = 0; - while (index < buffer.len) { - const amt = try read(req, buffer[index..]); - if (amt == 0) break; - index += amt; - } - return index; - } - - pub const WriteError = Connection.WriteError || error{ NotWriteable, MessageTooLong }; - - pub const Writer = std.io.Writer(*Request, WriteError, write); - - pub fn writer(req: *Request) Writer { - return .{ .context = req }; - } - - /// Write `bytes` to the server. The `transfer_encoding` field determines how data will be sent. - /// Must be called after `send` and before `finish`. - pub fn write(req: *Request, bytes: []const u8) WriteError!usize { - switch (req.transfer_encoding) { - .chunked => { - if (bytes.len > 0) { - try req.connection.?.writer().print("{x}\r\n", .{bytes.len}); - try req.connection.?.writer().writeAll(bytes); - try req.connection.?.writer().writeAll("\r\n"); - } - - return bytes.len; - }, - .content_length => |*len| { - if (len.* < bytes.len) return error.MessageTooLong; - - const amt = try req.connection.?.write(bytes); - len.* -= amt; - return amt; - }, - .none => return error.NotWriteable, - } - } - - /// Write `bytes` to the server. The `transfer_encoding` field determines how data will be sent. - /// Must be called after `send` and before `finish`. - pub fn writeAll(req: *Request, bytes: []const u8) WriteError!void { - var index: usize = 0; - while (index < bytes.len) { - index += try write(req, bytes[index..]); - } - } - - pub const FinishError = WriteError || error{MessageNotCompleted}; - - /// Finish the body of a request. This notifies the server that you have no more data to send. - /// Must be called after `send`. - pub fn finish(req: *Request) FinishError!void { - switch (req.transfer_encoding) { - .chunked => try req.connection.?.writer().writeAll("0\r\n\r\n"), - .content_length => |len| if (len != 0) return error.MessageNotCompleted, - .none => {}, - } - - try req.connection.?.flush(); - } -}; - -pub const Proxy = struct { - protocol: Connection.Protocol, - host: []const u8, - authorization: ?[]const u8, - port: u16, - supports_connect: bool, -}; - -/// Release all associated resources with the client. -/// -/// All pending requests must be de-initialized and all active connections released -/// before calling this function. -pub fn deinit(client: *Client) void { - assert(client.connection_pool.used.first == null); // There are still active requests. - - client.connection_pool.deinit(client.allocator); - - if (!disable_tls) - client.ca_bundle.deinit(client.allocator); - - client.* = undefined; -} - -/// Populates `http_proxy` and `https_proxy` via standard proxy environment variables. -/// Asserts the client has no active connections. -/// Uses `arena` for a few small allocations that must outlive the client, or -/// at least until those fields are set to different values. -pub fn initDefaultProxies(client: *Client, arena: Allocator) !void { - // Prevent any new connections from being created. - client.connection_pool.mutex.lock(); - defer client.connection_pool.mutex.unlock(); - - assert(client.connection_pool.used.first == null); // There are active requests. - - if (client.http_proxy == null) { - client.http_proxy = try createProxyFromEnvVar(arena, &.{ - "http_proxy", "HTTP_PROXY", "all_proxy", "ALL_PROXY", - }); - } - - if (client.https_proxy == null) { - client.https_proxy = try createProxyFromEnvVar(arena, &.{ - "https_proxy", "HTTPS_PROXY", "all_proxy", "ALL_PROXY", - }); - } -} - -fn createProxyFromEnvVar(arena: Allocator, env_var_names: []const []const u8) !?*Proxy { - const content = for (env_var_names) |name| { - break std.process.getEnvVarOwned(arena, name) catch |err| switch (err) { - error.EnvironmentVariableNotFound => continue, - else => |e| return e, - }; - } else return null; - - const uri = Uri.parse(content) catch try Uri.parseAfterScheme("http", content); - const protocol, const valid_uri = validateUri(uri, arena) catch |err| switch (err) { - error.UnsupportedUriScheme => return null, - error.UriMissingHost => return error.HttpProxyMissingHost, - error.OutOfMemory => |e| return e, - }; - - const authorization: ?[]const u8 = if (valid_uri.user != null or valid_uri.password != null) a: { - const authorization = try arena.alloc(u8, basic_authorization.valueLengthFromUri(valid_uri)); - assert(basic_authorization.value(valid_uri, authorization).len == authorization.len); - break :a authorization; - } else null; - - const proxy = try arena.create(Proxy); - proxy.* = .{ - .protocol = protocol, - .host = valid_uri.host.?.raw, - .authorization = authorization, - .port = uriPort(valid_uri, protocol), - .supports_connect = true, - }; - return proxy; -} - -pub const basic_authorization = struct { - pub const max_user_len = 255; - pub const max_password_len = 255; - pub const max_value_len = valueLength(max_user_len, max_password_len); - - const prefix = "Basic "; - - pub fn valueLength(user_len: usize, password_len: usize) usize { - return prefix.len + std.base64.standard.Encoder.calcSize(user_len + 1 + password_len); - } - - pub fn valueLengthFromUri(uri: Uri) usize { - var stream = std.io.countingWriter(std.io.null_writer); - try stream.writer().print("{user}", .{uri.user orelse Uri.Component.empty}); - const user_len = stream.bytes_written; - stream.bytes_written = 0; - try stream.writer().print("{password}", .{uri.password orelse Uri.Component.empty}); - const password_len = stream.bytes_written; - return valueLength(@intCast(user_len), @intCast(password_len)); - } - - pub fn value(uri: Uri, out: []u8) []u8 { - var buf: [max_user_len + ":".len + max_password_len]u8 = undefined; - var stream = std.io.fixedBufferStream(&buf); - stream.writer().print("{user}", .{uri.user orelse Uri.Component.empty}) catch - unreachable; - assert(stream.pos <= max_user_len); - stream.writer().print(":{password}", .{uri.password orelse Uri.Component.empty}) catch - unreachable; - - @memcpy(out[0..prefix.len], prefix); - const base64 = std.base64.standard.Encoder.encode(out[prefix.len..], stream.getWritten()); - return out[0 .. prefix.len + base64.len]; - } -}; - -pub const ConnectTcpError = Allocator.Error || error{ ConnectionRefused, NetworkUnreachable, ConnectionTimedOut, ConnectionResetByPeer, TemporaryNameServerFailure, NameServerFailure, UnknownHostName, HostLacksNetworkAddresses, UnexpectedConnectFailure, TlsInitializationFailed }; - -/// Connect to `host:port` using the specified protocol. This will reuse a connection if one is already open. -/// -/// This function is threadsafe. -pub fn connectTcp(client: *Client, host: []const u8, port: u16, protocol: Connection.Protocol) ConnectTcpError!*Connection { - if (client.connection_pool.findConnection(.{ - .host = host, - .port = port, - .protocol = protocol, - })) |node| return node; - - if (disable_tls and protocol == .tls) - return error.TlsInitializationFailed; - - const conn = try client.allocator.create(ConnectionPool.Node); - errdefer client.allocator.destroy(conn); - conn.* = .{ .data = undefined }; - - const stream = tcp.tcpConnectToHost(client.allocator, client.loop, host, port) catch |err| switch (err) { - error.ConnectionRefused => return error.ConnectionRefused, - error.NetworkUnreachable => return error.NetworkUnreachable, - error.ConnectionTimedOut => return error.ConnectionTimedOut, - error.ConnectionResetByPeer => return error.ConnectionResetByPeer, - error.TemporaryNameServerFailure => return error.TemporaryNameServerFailure, - error.NameServerFailure => return error.NameServerFailure, - error.UnknownHostName => return error.UnknownHostName, - error.HostLacksNetworkAddresses => return error.HostLacksNetworkAddresses, - else => return error.UnexpectedConnectFailure, - }; - errdefer stream.close(); - - conn.data = .{ - .stream = stream, - .tls_client = undefined, - - .protocol = protocol, - .host = try client.allocator.dupe(u8, host), - .port = port, - }; - errdefer client.allocator.free(conn.data.host); - - if (protocol == .tls) { - if (disable_tls) unreachable; - - conn.data.tls_client = try client.allocator.create(tls23.Connection(Stream)); - errdefer client.allocator.destroy(conn.data.tls_client); - - conn.data.tls_client.* = tls23.client(stream, .{ - .host = host, - .root_ca = client.ca_bundle, - }) catch return error.TlsInitializationFailed; - } - - client.connection_pool.addUsed(conn); - - return &conn.data; -} - -/// Connect to `tunnel_host:tunnel_port` using the specified proxy with HTTP -/// CONNECT. This will reuse a connection if one is already open. -/// -/// This function is threadsafe. -pub fn connectTunnel( - client: *Client, - proxy: *Proxy, - tunnel_host: []const u8, - tunnel_port: u16, -) !*Connection { - if (!proxy.supports_connect) return error.TunnelNotSupported; - - if (client.connection_pool.findConnection(.{ - .host = tunnel_host, - .port = tunnel_port, - .protocol = proxy.protocol, - })) |node| - return node; - - var maybe_valid = false; - (tunnel: { - const conn = try client.connectTcp(proxy.host, proxy.port, proxy.protocol); - errdefer { - conn.closing = true; - client.connection_pool.release(client.allocator, conn); - } - - var buffer: [8096]u8 = undefined; - var req = client.open(.CONNECT, .{ - .scheme = "http", - .host = .{ .raw = tunnel_host }, - .port = tunnel_port, - }, .{ - .redirect_behavior = .unhandled, - .connection = conn, - .server_header_buffer = &buffer, - }) catch |err| { - std.log.debug("err {}", .{err}); - break :tunnel err; - }; - defer req.deinit(); - - req.send() catch |err| break :tunnel err; - req.wait() catch |err| break :tunnel err; - - if (req.response.status.class() == .server_error) { - maybe_valid = true; - break :tunnel error.ServerError; - } - - if (req.response.status != .ok) break :tunnel error.ConnectionRefused; - - // this connection is now a tunnel, so we can't use it for anything else, it will only be released when the client is de-initialized. - req.connection = null; - - client.allocator.free(conn.host); - conn.host = try client.allocator.dupe(u8, tunnel_host); - errdefer client.allocator.free(conn.host); - - conn.port = tunnel_port; - conn.closing = false; - - return conn; - }) catch { - // something went wrong with the tunnel - proxy.supports_connect = maybe_valid; - return error.TunnelNotSupported; - }; -} - -// Prevents a dependency loop in open() -const ConnectErrorPartial = ConnectTcpError || error{ UnsupportedUriScheme, ConnectionRefused }; -pub const ConnectError = ConnectErrorPartial || RequestError; - -/// Connect to `host:port` using the specified protocol. This will reuse a -/// connection if one is already open. -/// If a proxy is configured for the client, then the proxy will be used to -/// connect to the host. -/// -/// This function is threadsafe. -pub fn connect( - client: *Client, - host: []const u8, - port: u16, - protocol: Connection.Protocol, -) ConnectError!*Connection { - const proxy = switch (protocol) { - .plain => client.http_proxy, - .tls => client.https_proxy, - } orelse return client.connectTcp(host, port, protocol); - - // Prevent proxying through itself. - if (std.ascii.eqlIgnoreCase(proxy.host, host) and - proxy.port == port and proxy.protocol == protocol) - { - return client.connectTcp(host, port, protocol); - } - - if (proxy.supports_connect) tunnel: { - return connectTunnel(client, proxy, host, port) catch |err| switch (err) { - error.TunnelNotSupported => break :tunnel, - else => |e| return e, - }; - } - - // fall back to using the proxy as a normal http proxy - const conn = try client.connectTcp(proxy.host, proxy.port, proxy.protocol); - errdefer { - conn.closing = true; - client.connection_pool.release(conn); - } - - conn.proxied = true; - return conn; -} - -pub const RequestError = ConnectTcpError || ConnectErrorPartial || Request.SendError || - std.fmt.ParseIntError || Connection.WriteError || - error{ // TODO: file a zig fmt issue for this bad indentation - UnsupportedUriScheme, - UriMissingHost, - - CertificateBundleLoadFailure, - UnsupportedTransferEncoding, -}; - -pub const RequestOptions = struct { - version: http.Version = .@"HTTP/1.1", - - /// Automatically ignore 100 Continue responses. This assumes you don't - /// care, and will have sent the body before you wait for the response. - /// - /// If this is not the case AND you know the server will send a 100 - /// Continue, set this to false and wait for a response before sending the - /// body. If you wait AND the server does not send a 100 Continue before - /// you finish the request, then the request *will* deadlock. - handle_continue: bool = true, - - /// If false, close the connection after the one request. If true, - /// participate in the client connection pool. - keep_alive: bool = true, - - /// This field specifies whether to automatically follow redirects, and if - /// so, how many redirects to follow before returning an error. - /// - /// This will only follow redirects for repeatable requests (ie. with no - /// payload or the server has acknowledged the payload). - redirect_behavior: Request.RedirectBehavior = @enumFromInt(3), - - /// Externally-owned memory used to store the server's entire HTTP header. - /// `error.HttpHeadersOversize` is returned from read() when a - /// client sends too many bytes of HTTP headers. - server_header_buffer: []u8, - - /// Must be an already acquired connection. - connection: ?*Connection = null, - - /// Standard headers that have default, but overridable, behavior. - headers: Request.Headers = .{}, - /// These headers are kept including when following a redirect to a - /// different domain. - /// Externally-owned; must outlive the Request. - extra_headers: []const http.Header = &.{}, - /// These headers are stripped when following a redirect to a different - /// domain. - /// Externally-owned; must outlive the Request. - privileged_headers: []const http.Header = &.{}, -}; - -fn validateUri(uri: Uri, arena: Allocator) !struct { Connection.Protocol, Uri } { - const protocol_map = std.StaticStringMap(Connection.Protocol).initComptime(.{ - .{ "http", .plain }, - .{ "ws", .plain }, - .{ "https", .tls }, - .{ "wss", .tls }, - }); - const protocol = protocol_map.get(uri.scheme) orelse return error.UnsupportedUriScheme; - var valid_uri = uri; - // The host is always going to be needed as a raw string for hostname resolution anyway. - valid_uri.host = .{ - .raw = try (uri.host orelse return error.UriMissingHost).toRawMaybeAlloc(arena), - }; - return .{ protocol, valid_uri }; -} - -fn uriPort(uri: Uri, protocol: Connection.Protocol) u16 { - return uri.port orelse switch (protocol) { - .plain => 80, - .tls => 443, - }; -} - -/// Open a connection to the host specified by `uri` and prepare to send a HTTP request. -/// -/// `uri` must remain alive during the entire request. -/// -/// The caller is responsible for calling `deinit()` on the `Request`. -/// This function is threadsafe. -/// -/// Asserts that "\r\n" does not occur in any header name or value. -pub fn open( - client: *Client, - method: http.Method, - uri: Uri, - options: RequestOptions, -) RequestError!Request { - if (std.debug.runtime_safety) { - for (options.extra_headers) |header| { - assert(header.name.len != 0); - assert(std.mem.indexOfScalar(u8, header.name, ':') == null); - assert(std.mem.indexOfPosLinear(u8, header.name, 0, "\r\n") == null); - assert(std.mem.indexOfPosLinear(u8, header.value, 0, "\r\n") == null); - } - for (options.privileged_headers) |header| { - assert(header.name.len != 0); - assert(std.mem.indexOfPosLinear(u8, header.name, 0, "\r\n") == null); - assert(std.mem.indexOfPosLinear(u8, header.value, 0, "\r\n") == null); - } - } - - var server_header = std.heap.FixedBufferAllocator.init(options.server_header_buffer); - const protocol, const valid_uri = try validateUri(uri, server_header.allocator()); - - if (protocol == .tls and @atomicLoad(bool, &client.next_https_rescan_certs, .acquire)) { - if (disable_tls) unreachable; - - client.ca_bundle_mutex.lock(); - defer client.ca_bundle_mutex.unlock(); - - if (client.next_https_rescan_certs) { - client.ca_bundle.rescan(client.allocator) catch - return error.CertificateBundleLoadFailure; - @atomicStore(bool, &client.next_https_rescan_certs, false, .release); - } - } - - const conn = options.connection orelse - try client.connect(valid_uri.host.?.raw, uriPort(valid_uri, protocol), protocol); - - var req: Request = .{ - .uri = valid_uri, - .client = client, - .connection = conn, - .keep_alive = options.keep_alive, - .method = method, - .version = options.version, - .transfer_encoding = .none, - .redirect_behavior = options.redirect_behavior, - .handle_continue = options.handle_continue, - .response = .{ - .version = undefined, - .status = undefined, - .reason = undefined, - .keep_alive = undefined, - .parser = proto.HeadersParser.init(server_header.buffer[server_header.end_index..]), - }, - .headers = options.headers, - .extra_headers = options.extra_headers, - .privileged_headers = options.privileged_headers, - }; - errdefer req.deinit(); - - return req; -} - -pub const FetchOptions = struct { - server_header_buffer: ?[]u8 = null, - redirect_behavior: ?Request.RedirectBehavior = null, - - /// If the server sends a body, it will be appended to this ArrayList. - /// `max_append_size` provides an upper limit for how much they can grow. - response_storage: ResponseStorage = .ignore, - max_append_size: ?usize = null, - - location: Location, - method: ?http.Method = null, - payload: ?[]const u8 = null, - raw_uri: bool = false, - keep_alive: bool = true, - - /// Standard headers that have default, but overridable, behavior. - headers: Request.Headers = .{}, - /// These headers are kept including when following a redirect to a - /// different domain. - /// Externally-owned; must outlive the Request. - extra_headers: []const http.Header = &.{}, - /// These headers are stripped when following a redirect to a different - /// domain. - /// Externally-owned; must outlive the Request. - privileged_headers: []const http.Header = &.{}, - - pub const Location = union(enum) { - url: []const u8, - uri: Uri, - }; - - pub const ResponseStorage = union(enum) { - ignore, - /// Only the existing capacity will be used. - static: *std.ArrayListUnmanaged(u8), - dynamic: *std.ArrayList(u8), - }; -}; - -pub const FetchResult = struct { - status: http.Status, -}; - -/// Perform a one-shot HTTP request with the provided options. -/// -/// This function is threadsafe. -pub fn fetch(client: *Client, options: FetchOptions) !FetchResult { - const uri = switch (options.location) { - .url => |u| try Uri.parse(u), - .uri => |u| u, - }; - var server_header_buffer: [16 * 1024]u8 = undefined; - - const method: http.Method = options.method orelse - if (options.payload != null) .POST else .GET; - - var req = try open(client, method, uri, .{ - .server_header_buffer = options.server_header_buffer orelse &server_header_buffer, - .redirect_behavior = options.redirect_behavior orelse - if (options.payload == null) @enumFromInt(3) else .unhandled, - .headers = options.headers, - .extra_headers = options.extra_headers, - .privileged_headers = options.privileged_headers, - .keep_alive = options.keep_alive, - }); - defer req.deinit(); - - if (options.payload) |payload| req.transfer_encoding = .{ .content_length = payload.len }; - - try req.send(); - - if (options.payload) |payload| try req.writeAll(payload); - - try req.finish(); - try req.wait(); - - switch (options.response_storage) { - .ignore => { - // Take advantage of request internals to discard the response body - // and make the connection available for another request. - req.response.skip = true; - assert(try req.transferRead(&.{}) == 0); // No buffer is necessary when skipping. - }, - .dynamic => |list| { - const max_append_size = options.max_append_size orelse 2 * 1024 * 1024; - try req.reader().readAllArrayList(list, max_append_size); - }, - .static => |list| { - const buf = b: { - const buf = list.unusedCapacitySlice(); - if (options.max_append_size) |len| { - if (len < buf.len) break :b buf[0..len]; - } - break :b buf; - }; - list.items.len += try req.reader().readAll(buf); - }, - } - - return .{ - .status = req.response.status, - }; -} - -test { - _ = &initDefaultProxies; -} diff --git a/src/async/stream.zig b/src/async/stream.zig deleted file mode 100644 index 85b6cbb2..00000000 --- a/src/async/stream.zig +++ /dev/null @@ -1,133 +0,0 @@ -// Copyright (C) 2023-2024 Lightpanda (Selecy SAS) -// -// Francis Bouvier -// Pierre Tachoire -// -// This program is free software: you can redistribute it and/or modify -// it under the terms of the GNU Affero General Public License as -// published by the Free Software Foundation, either version 3 of the -// License, or (at your option) any later version. -// -// This program is distributed in the hope that it will be useful, -// but WITHOUT ANY WARRANTY; without even the implied warranty of -// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the -// GNU Affero General Public License for more details. -// -// You should have received a copy of the GNU Affero General Public License -// along with this program. If not, see . - -const std = @import("std"); -const builtin = @import("builtin"); -const posix = std.posix; -const io = std.io; -const assert = std.debug.assert; - -const tcp = @import("tcp.zig"); - -pub const Stream = struct { - alloc: std.mem.Allocator, - conn: *tcp.Conn, - - handle: posix.socket_t, - - pub fn close(self: Stream) void { - posix.close(self.handle); - self.alloc.destroy(self.conn); - } - - pub const ReadError = posix.ReadError; - pub const WriteError = posix.WriteError; - - pub const Reader = io.Reader(Stream, ReadError, read); - pub const Writer = io.Writer(Stream, WriteError, write); - - pub fn reader(self: Stream) Reader { - return .{ .context = self }; - } - - pub fn writer(self: Stream) Writer { - return .{ .context = self }; - } - - pub fn read(self: Stream, buffer: []u8) ReadError!usize { - return self.conn.receive(self.handle, buffer) catch |err| switch (err) { - else => return error.Unexpected, - }; - } - - pub fn readv(s: Stream, iovecs: []const posix.iovec) ReadError!usize { - return posix.readv(s.handle, iovecs); - } - - /// Returns the number of bytes read. If the number read is smaller than - /// `buffer.len`, it means the stream reached the end. Reaching the end of - /// a stream is not an error condition. - pub fn readAll(s: Stream, buffer: []u8) ReadError!usize { - return readAtLeast(s, buffer, buffer.len); - } - - /// Returns the number of bytes read, calling the underlying read function - /// the minimal number of times until the buffer has at least `len` bytes - /// filled. If the number read is less than `len` it means the stream - /// reached the end. Reaching the end of the stream is not an error - /// condition. - pub fn readAtLeast(s: Stream, buffer: []u8, len: usize) ReadError!usize { - assert(len <= buffer.len); - var index: usize = 0; - while (index < len) { - const amt = try s.read(buffer[index..]); - if (amt == 0) break; - index += amt; - } - return index; - } - - /// TODO in evented I/O mode, this implementation incorrectly uses the event loop's - /// file system thread instead of non-blocking. It needs to be reworked to properly - /// use non-blocking I/O. - pub fn write(self: Stream, buffer: []const u8) WriteError!usize { - return self.conn.send(self.handle, buffer) catch |err| switch (err) { - error.AccessDenied => error.AccessDenied, - error.WouldBlock => error.WouldBlock, - error.ConnectionResetByPeer => error.ConnectionResetByPeer, - error.MessageTooBig => error.FileTooBig, - error.BrokenPipe => error.BrokenPipe, - else => return error.Unexpected, - }; - } - - pub fn writeAll(self: Stream, bytes: []const u8) WriteError!void { - var index: usize = 0; - while (index < bytes.len) { - index += try self.write(bytes[index..]); - } - } - - /// See https://github.com/ziglang/zig/issues/7699 - /// See equivalent function: `std.fs.File.writev`. - pub fn writev(self: Stream, iovecs: []const posix.iovec_const) WriteError!usize { - if (iovecs.len == 0) return 0; - const first_buffer = iovecs[0].base[0..iovecs[0].len]; - return try self.write(first_buffer); - } - - /// The `iovecs` parameter is mutable because this function needs to mutate the fields in - /// order to handle partial writes from the underlying OS layer. - /// See https://github.com/ziglang/zig/issues/7699 - /// See equivalent function: `std.fs.File.writevAll`. - pub fn writevAll(self: Stream, iovecs: []posix.iovec_const) WriteError!void { - if (iovecs.len == 0) return; - - var i: usize = 0; - while (true) { - var amt = try self.writev(iovecs[i..]); - while (amt >= iovecs[i].len) { - amt -= iovecs[i].len; - i += 1; - if (i >= iovecs.len) return; - } - iovecs[i].base += amt; - iovecs[i].len -= amt; - } - } -}; diff --git a/src/async/tcp.zig b/src/async/tcp.zig deleted file mode 100644 index 61a49548..00000000 --- a/src/async/tcp.zig +++ /dev/null @@ -1,112 +0,0 @@ -// Copyright (C) 2023-2024 Lightpanda (Selecy SAS) -// -// Francis Bouvier -// Pierre Tachoire -// -// This program is free software: you can redistribute it and/or modify -// it under the terms of the GNU Affero General Public License as -// published by the Free Software Foundation, either version 3 of the -// License, or (at your option) any later version. -// -// This program is distributed in the hope that it will be useful, -// but WITHOUT ANY WARRANTY; without even the implied warranty of -// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the -// GNU Affero General Public License for more details. -// -// You should have received a copy of the GNU Affero General Public License -// along with this program. If not, see . - -const std = @import("std"); -const net = std.net; -const Stream = @import("stream.zig").Stream; -const Loop = @import("jsruntime").Loop; -const NetworkImpl = Loop.Network(Conn.Command); - -// Conn is a TCP connection using jsruntime Loop async I/O. -// connect, send and receive are blocking, but use async I/O in the background. -// Client doesn't own the socket used for the connection, the caller is -// responsible for closing it. -pub const Conn = struct { - const Command = struct { - impl: NetworkImpl, - - done: bool = false, - err: ?anyerror = null, - ln: usize = 0, - - fn ok(self: *Command, err: ?anyerror, ln: usize) void { - self.err = err; - self.ln = ln; - self.done = true; - } - - fn wait(self: *Command) !usize { - while (!self.done) try self.impl.tick(); - - if (self.err) |err| return err; - return self.ln; - } - pub fn onConnect(self: *Command, err: ?anyerror) void { - self.ok(err, 0); - } - pub fn onSend(self: *Command, ln: usize, err: ?anyerror) void { - self.ok(err, ln); - } - pub fn onReceive(self: *Command, ln: usize, err: ?anyerror) void { - self.ok(err, ln); - } - }; - - loop: *Loop, - - pub fn connect(self: *Conn, socket: std.posix.socket_t, address: std.net.Address) !void { - var cmd = Command{ .impl = NetworkImpl.init(self.loop) }; - cmd.impl.connect(&cmd, socket, address); - _ = try cmd.wait(); - } - - pub fn send(self: *Conn, socket: std.posix.socket_t, buffer: []const u8) !usize { - var cmd = Command{ .impl = NetworkImpl.init(self.loop) }; - cmd.impl.send(&cmd, socket, buffer); - return try cmd.wait(); - } - - pub fn receive(self: *Conn, socket: std.posix.socket_t, buffer: []u8) !usize { - var cmd = Command{ .impl = NetworkImpl.init(self.loop) }; - cmd.impl.receive(&cmd, socket, buffer); - return try cmd.wait(); - } -}; - -pub fn tcpConnectToHost(alloc: std.mem.Allocator, loop: *Loop, name: []const u8, port: u16) !Stream { - // TODO async resolve - const list = try net.getAddressList(alloc, name, port); - defer list.deinit(); - - if (list.addrs.len == 0) return error.UnknownHostName; - - for (list.addrs) |addr| { - return tcpConnectToAddress(alloc, loop, addr) catch |err| switch (err) { - error.ConnectionRefused => { - continue; - }, - else => return err, - }; - } - return std.posix.ConnectError.ConnectionRefused; -} - -pub fn tcpConnectToAddress(alloc: std.mem.Allocator, loop: *Loop, addr: net.Address) !Stream { - const sockfd = try std.posix.socket(addr.any.family, std.posix.SOCK.STREAM, std.posix.IPPROTO.TCP); - errdefer std.posix.close(sockfd); - - var conn = try alloc.create(Conn); - conn.* = Conn{ .loop = loop }; - try conn.connect(sockfd, addr); - - return Stream{ - .alloc = alloc, - .conn = conn, - .handle = sockfd, - }; -} diff --git a/src/async/test.zig b/src/async/test.zig deleted file mode 100644 index 27f86c6a..00000000 --- a/src/async/test.zig +++ /dev/null @@ -1,189 +0,0 @@ -// Copyright (C) 2023-2024 Lightpanda (Selecy SAS) -// -// Francis Bouvier -// Pierre Tachoire -// -// This program is free software: you can redistribute it and/or modify -// it under the terms of the GNU Affero General Public License as -// published by the Free Software Foundation, either version 3 of the -// License, or (at your option) any later version. -// -// This program is distributed in the hope that it will be useful, -// but WITHOUT ANY WARRANTY; without even the implied warranty of -// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the -// GNU Affero General Public License for more details. -// -// You should have received a copy of the GNU Affero General Public License -// along with this program. If not, see . - -const std = @import("std"); -const http = std.http; -const Client = @import("Client.zig"); -const Request = @import("Client.zig").Request; - -pub const Loop = @import("jsruntime").Loop; - -const url = "https://w3.org"; - -test "blocking mode fetch API" { - const alloc = std.testing.allocator; - - var loop = try Loop.init(alloc); - defer loop.deinit(); - - var client: Client = .{ - .allocator = alloc, - .loop = &loop, - }; - defer client.deinit(); - - // force client's CA cert scan from system. - try client.ca_bundle.rescan(client.allocator); - - const res = try client.fetch(.{ - .location = .{ .uri = try std.Uri.parse(url) }, - }); - - try std.testing.expect(res.status == .ok); -} - -test "blocking mode open/send/wait API" { - const alloc = std.testing.allocator; - - var loop = try Loop.init(alloc); - defer loop.deinit(); - - var client: Client = .{ - .allocator = alloc, - .loop = &loop, - }; - defer client.deinit(); - - // force client's CA cert scan from system. - try client.ca_bundle.rescan(client.allocator); - - var buf: [2014]u8 = undefined; - var req = try client.open(.GET, try std.Uri.parse(url), .{ - .server_header_buffer = &buf, - }); - defer req.deinit(); - - try req.send(); - try req.finish(); - try req.wait(); - - try std.testing.expect(req.response.status == .ok); -} - -// Example how to write an async http client using the modified standard client. -const AsyncClient = struct { - cli: Client, - - const YieldImpl = Loop.Yield(AsyncRequest); - const AsyncRequest = struct { - const State = enum { new, open, send, finish, wait, done }; - - cli: *Client, - uri: std.Uri, - - req: ?Request = undefined, - state: State = .new, - - impl: YieldImpl, - err: ?anyerror = null, - - buf: [2014]u8 = undefined, - - pub fn deinit(self: *AsyncRequest) void { - if (self.req) |*r| r.deinit(); - } - - pub fn fetch(self: *AsyncRequest) void { - self.state = .new; - return self.impl.yield(self); - } - - fn onerr(self: *AsyncRequest, err: anyerror) void { - self.state = .done; - self.err = err; - } - - pub fn onYield(self: *AsyncRequest, err: ?anyerror) void { - if (err) |e| return self.onerr(e); - - switch (self.state) { - .new => { - self.state = .open; - self.req = self.cli.open(.GET, self.uri, .{ - .server_header_buffer = &self.buf, - }) catch |e| return self.onerr(e); - }, - .open => { - self.state = .send; - self.req.?.send() catch |e| return self.onerr(e); - }, - .send => { - self.state = .finish; - self.req.?.finish() catch |e| return self.onerr(e); - }, - .finish => { - self.state = .wait; - self.req.?.wait() catch |e| return self.onerr(e); - }, - .wait => { - self.state = .done; - return; - }, - .done => return, - } - - return self.impl.yield(self); - } - - pub fn wait(self: *AsyncRequest) !void { - while (self.state != .done) try self.impl.tick(); - if (self.err) |err| return err; - } - }; - - pub fn init(alloc: std.mem.Allocator, loop: *Loop) AsyncClient { - return .{ - .cli = .{ - .allocator = alloc, - .loop = loop, - }, - }; - } - - pub fn deinit(self: *AsyncClient) void { - self.cli.deinit(); - } - - pub fn createRequest(self: *AsyncClient, uri: std.Uri) !AsyncRequest { - return .{ - .impl = YieldImpl.init(self.cli.loop), - .cli = &self.cli, - .uri = uri, - }; - } -}; - -test "non blocking client" { - const alloc = std.testing.allocator; - - var loop = try Loop.init(alloc); - defer loop.deinit(); - - var client = AsyncClient.init(alloc, &loop); - defer client.deinit(); - - var reqs: [3]AsyncClient.AsyncRequest = undefined; - for (0..reqs.len) |i| { - reqs[i] = try client.createRequest(try std.Uri.parse(url)); - reqs[i].fetch(); - } - for (0..reqs.len) |i| { - try reqs[i].wait(); - reqs[i].deinit(); - } -} diff --git a/src/main_shell.zig b/src/main_shell.zig index fbb23660..ac803ae5 100644 --- a/src/main_shell.zig +++ b/src/main_shell.zig @@ -29,7 +29,7 @@ const html_test = @import("html_test.zig").html; pub const Types = jsruntime.reflect(apiweb.Interfaces); pub const UserContext = apiweb.UserContext; -const Client = @import("async/Client.zig"); +const Client = @import("http/async/main.zig").Client; var doc: *parser.DocumentHTML = undefined; @@ -41,7 +41,7 @@ fn execJS( try js_env.start(); defer js_env.stop(); - var cli = Client{ .allocator = alloc, .loop = js_env.nat_ctx.loop }; + var cli = Client{ .allocator = alloc }; defer cli.deinit(); try js_env.setUserContext(UserContext{ diff --git a/src/run_tests.zig b/src/run_tests.zig index b50dbce8..bc0ca2db 100644 --- a/src/run_tests.zig +++ b/src/run_tests.zig @@ -298,7 +298,7 @@ test { const msgTest = @import("msg.zig"); std.testing.refAllDecls(msgTest); - const asyncTest = @import("async/test.zig"); + const asyncTest = @import("http/async/std/http.zig"); std.testing.refAllDecls(asyncTest); const dumpTest = @import("browser/dump.zig"); From 18ab0c8199b2f90a7743cef59f138822c272a899 Mon Sep 17 00:00:00 2001 From: Pierre Tachoire Date: Fri, 15 Nov 2024 15:05:35 +0100 Subject: [PATCH 04/11] cdp: replace tick by run_for_ns --- src/server.zig | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/server.zig b/src/server.zig index 505de7b3..9b34c602 100644 --- a/src/server.zig +++ b/src/server.zig @@ -482,7 +482,7 @@ pub fn listen( // - cmd from incoming connection on server socket // - JS callbacks events from scripts while (true) { - try loop.io.tick(); + try loop.io.run_for_ns(10 * std.time.ns_per_ms); if (loop.cbk_error) { log.err("JS error", .{}); // if (try try_catch.exception(alloc, js_env.*)) |msg| { From 9149b601365a512a5c31e681f941f1f08bcb5dd4 Mon Sep 17 00:00:00 2001 From: Pierre Tachoire Date: Mon, 18 Nov 2024 11:10:32 +0100 Subject: [PATCH 05/11] async: remove dead code --- src/http/async/loop.zig | 75 ----------------------------------------- 1 file changed, 75 deletions(-) delete mode 100644 src/http/async/loop.zig diff --git a/src/http/async/loop.zig b/src/http/async/loop.zig deleted file mode 100644 index 1b18d0f7..00000000 --- a/src/http/async/loop.zig +++ /dev/null @@ -1,75 +0,0 @@ -const std = @import("std"); -const Client = @import("std/http/Client.zig"); - -const Stack = @import("stack.zig"); - -const Res = fn (ctx: *Ctx, res: ?anyerror) anyerror!void; - -pub const Blocking = struct { - pub fn connect( - _: *Blocking, - comptime ctxT: type, - ctx: *ctxT, - comptime cbk: Res, - socket: std.os.socket_t, - address: std.net.Address, - ) void { - std.os.connect(socket, &address.any, address.getOsSockLen()) catch |err| { - std.os.closeSocket(socket); - _ = cbk(ctx, err); - return; - }; - ctx.socket = socket; - _ = cbk(ctx, null); - } -}; - -const CtxStack = Stack(Res); - -pub const Ctx = struct { - alloc: std.mem.Allocator, - stack: ?*CtxStack = null, - - // TCP ctx - client: *Client = undefined, - addr_current: usize = undefined, - list: *std.net.AddressList = undefined, - socket: std.os.socket_t = undefined, - Stream: std.net.Stream = undefined, - host: []const u8 = undefined, - port: u16 = undefined, - protocol: Client.Connection.Protocol = undefined, - conn: *Client.Connection = undefined, - uri: std.Uri = undefined, - headers: std.http.Headers = undefined, - method: std.http.Method = undefined, - options: Client.RequestOptions = undefined, - request: Client.Request = undefined, - - err: ?anyerror, - - pub fn init(alloc: std.mem.Allocator) Ctx { - return .{ .alloc = alloc }; - } - - pub fn push(self: *Ctx, function: CtxStack.Fn) !void { - if (self.stack) |stack| { - return try stack.push(self.alloc, function); - } - self.stack = try CtxStack.init(self.alloc, function); - } - - pub fn next(self: *Ctx, err: ?anyerror) !void { - if (self.stack) |stack| { - const last = stack.next == null; - const function = stack.pop(self.alloc, stack); - const res = @call(.auto, function, .{ self, err }); - if (last) { - self.stack = null; - self.alloc.destroy(stack); - } - return res; - } - self.err = err; - } -}; From d2d2e851b0778b6753e86257a4065bb49f58795b Mon Sep 17 00:00:00 2001 From: Pierre Tachoire Date: Mon, 18 Nov 2024 11:32:13 +0100 Subject: [PATCH 06/11] async: fix assync call pop error --- src/http/async/tls.zig/connection.zig | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/http/async/tls.zig/connection.zig b/src/http/async/tls.zig/connection.zig index 9ccd9c53..7a6afcbe 100644 --- a/src/http/async/tls.zig/connection.zig +++ b/src/http/async/tls.zig/connection.zig @@ -292,7 +292,7 @@ pub fn Connection(comptime Stream: type) type { if (read_buf_len == 0) { // read another buffer - c.async_next(ctx.stream(), ctx, onReadv) catch |err| return ctx.pop(err); + return c.async_next(ctx.stream(), ctx, onReadv) catch |err| return ctx.pop(err); } ctx._tls_read_buf = ctx._tls_read_buf.?[n..]; From e1137274fb6ca4a229fabfa33c404c06c23a0871 Mon Sep 17 00:00:00 2001 From: Pierre Tachoire Date: Tue, 19 Nov 2024 15:51:36 +0100 Subject: [PATCH 07/11] async: remove dead code --- src/http/async/main.zig | 1 - 1 file changed, 1 deletion(-) diff --git a/src/http/async/main.zig b/src/http/async/main.zig index ea756e8b..c3ef9934 100644 --- a/src/http/async/main.zig +++ b/src/http/async/main.zig @@ -1,4 +1,3 @@ const std = @import("std"); -const stack = @import("stack.zig"); pub const Client = @import("std/http/Client.zig"); From 395eb3e8add7e488fbacbb68eb66b56ecc167721 Mon Sep 17 00:00:00 2001 From: Pierre Tachoire Date: Tue, 19 Nov 2024 15:53:12 +0100 Subject: [PATCH 08/11] async: add missing tests execution --- src/run_tests.zig | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/run_tests.zig b/src/run_tests.zig index bc0ca2db..8e285840 100644 --- a/src/run_tests.zig +++ b/src/run_tests.zig @@ -298,8 +298,8 @@ test { const msgTest = @import("msg.zig"); std.testing.refAllDecls(msgTest); - const asyncTest = @import("http/async/std/http.zig"); - std.testing.refAllDecls(asyncTest); + std.testing.refAllDecls(@import("http/async/std/http.zig")); + std.testing.refAllDecls(@import("http/async/stack.zig")); const dumpTest = @import("browser/dump.zig"); std.testing.refAllDecls(dumpTest); From 70752027f1517f7d6088988d6832d409fcbcdae5 Mon Sep 17 00:00:00 2001 From: Pierre Tachoire Date: Tue, 19 Nov 2024 15:55:26 +0100 Subject: [PATCH 09/11] async: remove @This from SigleThreaded --- src/http/async/io.zig | 15 +++++++-------- 1 file changed, 7 insertions(+), 8 deletions(-) diff --git a/src/http/async/io.zig b/src/http/async/io.zig index 7c2aad5b..4a11b5b6 100644 --- a/src/http/async/io.zig +++ b/src/http/async/io.zig @@ -64,10 +64,9 @@ pub const SingleThreaded = struct { cbk: Cbk, ctx: *Ctx, - const Self = @This(); const Cbk = *const fn (ctx: *Ctx, res: anyerror!void) anyerror!void; - pub fn init(loop: *Loop) Self { + pub fn init(loop: *Loop) SingleThreaded { return .{ .impl = NetworkImpl.init(loop), .cbk = undefined, @@ -76,7 +75,7 @@ pub const SingleThreaded = struct { } pub fn connect( - self: *Self, + self: *SingleThreaded, comptime _: type, ctx: *Ctx, comptime cbk: Cbk, @@ -88,13 +87,13 @@ pub const SingleThreaded = struct { self.impl.connect(self, socket, address); } - pub fn onConnect(self: *Self, err: ?anyerror) void { + pub fn onConnect(self: *SingleThreaded, err: ?anyerror) void { if (err) |e| return self.ctx.setErr(e); self.cbk(self.ctx, {}) catch |e| self.ctx.setErr(e); } pub fn send( - self: *Self, + self: *SingleThreaded, comptime _: type, ctx: *Ctx, comptime cbk: Cbk, @@ -106,14 +105,14 @@ pub const SingleThreaded = struct { self.impl.send(self, socket, buf); } - pub fn onSend(self: *Self, ln: usize, err: ?anyerror) void { + pub fn onSend(self: *SingleThreaded, ln: usize, err: ?anyerror) void { if (err) |e| return self.ctx.setErr(e); self.ctx.setLen(ln); self.cbk(self.ctx, {}) catch |e| self.ctx.setErr(e); } pub fn recv( - self: *Self, + self: *SingleThreaded, comptime _: type, ctx: *Ctx, comptime cbk: Cbk, @@ -125,7 +124,7 @@ pub const SingleThreaded = struct { self.impl.receive(self, socket, buf); } - pub fn onReceive(self: *Self, ln: usize, err: ?anyerror) void { + pub fn onReceive(self: *SingleThreaded, ln: usize, err: ?anyerror) void { if (err) |e| return self.ctx.setErr(e); self.ctx.setLen(ln); self.cbk(self.ctx, {}) catch |e| self.ctx.setErr(e); From de286dd78efafddcb3b39251d302cd7c649cc0fe Mon Sep 17 00:00:00 2001 From: Francis Bouvier Date: Thu, 21 Nov 2024 12:27:00 +0100 Subject: [PATCH 10/11] async: use zig-async-io Signed-off-by: Francis Bouvier --- .gitmodules | 3 + build.zig | 5 + src/browser/browser.zig | 2 +- src/http/async/io.zig | 132 - src/http/async/main.zig | 3 - src/http/async/stack.zig | 95 - src/http/async/std/http.zig | 318 --- src/http/async/std/http/Client.zig | 2512 ----------------- src/http/async/std/http/Server.zig | 1148 -------- src/http/async/std/http/protocol.zig | 447 --- src/http/async/std/net.zig | 2050 -------------- src/http/async/std/net/test.zig | 335 --- src/http/async/tls.zig/PrivateKey.zig | 260 -- src/http/async/tls.zig/cbc/main.zig | 148 - src/http/async/tls.zig/cipher.zig | 1004 ------- src/http/async/tls.zig/connection.zig | 665 ----- src/http/async/tls.zig/handshake_client.zig | 955 ------- src/http/async/tls.zig/handshake_common.zig | 448 --- src/http/async/tls.zig/handshake_server.zig | 520 ---- src/http/async/tls.zig/key_log.zig | 60 - src/http/async/tls.zig/main.zig | 51 - src/http/async/tls.zig/protocol.zig | 302 -- src/http/async/tls.zig/record.zig | 405 --- src/http/async/tls.zig/rsa/der.zig | 467 --- src/http/async/tls.zig/rsa/oid.zig | 132 - src/http/async/tls.zig/rsa/rsa.zig | 880 ------ .../async/tls.zig/rsa/testdata/id_rsa.der | Bin 1191 -> 0 bytes .../testdata/ec_prime256v1_private_key.pem | 5 - .../async/tls.zig/testdata/ec_private_key.pem | 6 - .../testdata/ec_secp384r1_private_key.pem | 6 - .../testdata/ec_secp521r1_private_key.pem | 7 - .../tls.zig/testdata/google.com/client_random | 1 - .../tls.zig/testdata/google.com/server_hello | Bin 7158 -> 0 bytes .../tls.zig/testdata/rsa_private_key.pem | 28 - src/http/async/tls.zig/testdata/tls12.zig | 244 -- src/http/async/tls.zig/testdata/tls13.zig | 64 - src/http/async/tls.zig/testu.zig | 117 - src/http/async/tls.zig/transcript.zig | 297 -- src/main.zig | 1 + src/main_shell.zig | 3 +- src/main_wpt.zig | 1 + src/run_tests.zig | 6 +- src/test_runner.zig | 1 + src/user_context.zig | 2 +- src/wpt/run.zig | 2 +- src/xhr/xhr.zig | 8 +- vendor/zig-async-io | 1 + vendor/zig-js-runtime | 2 +- 48 files changed, 24 insertions(+), 14125 deletions(-) delete mode 100644 src/http/async/io.zig delete mode 100644 src/http/async/main.zig delete mode 100644 src/http/async/stack.zig delete mode 100644 src/http/async/std/http.zig delete mode 100644 src/http/async/std/http/Client.zig delete mode 100644 src/http/async/std/http/Server.zig delete mode 100644 src/http/async/std/http/protocol.zig delete mode 100644 src/http/async/std/net.zig delete mode 100644 src/http/async/std/net/test.zig delete mode 100644 src/http/async/tls.zig/PrivateKey.zig delete mode 100644 src/http/async/tls.zig/cbc/main.zig delete mode 100644 src/http/async/tls.zig/cipher.zig delete mode 100644 src/http/async/tls.zig/connection.zig delete mode 100644 src/http/async/tls.zig/handshake_client.zig delete mode 100644 src/http/async/tls.zig/handshake_common.zig delete mode 100644 src/http/async/tls.zig/handshake_server.zig delete mode 100644 src/http/async/tls.zig/key_log.zig delete mode 100644 src/http/async/tls.zig/main.zig delete mode 100644 src/http/async/tls.zig/protocol.zig delete mode 100644 src/http/async/tls.zig/record.zig delete mode 100644 src/http/async/tls.zig/rsa/der.zig delete mode 100644 src/http/async/tls.zig/rsa/oid.zig delete mode 100644 src/http/async/tls.zig/rsa/rsa.zig delete mode 100644 src/http/async/tls.zig/rsa/testdata/id_rsa.der delete mode 100644 src/http/async/tls.zig/testdata/ec_prime256v1_private_key.pem delete mode 100644 src/http/async/tls.zig/testdata/ec_private_key.pem delete mode 100644 src/http/async/tls.zig/testdata/ec_secp384r1_private_key.pem delete mode 100644 src/http/async/tls.zig/testdata/ec_secp521r1_private_key.pem delete mode 100644 src/http/async/tls.zig/testdata/google.com/client_random delete mode 100644 src/http/async/tls.zig/testdata/google.com/server_hello delete mode 100644 src/http/async/tls.zig/testdata/rsa_private_key.pem delete mode 100644 src/http/async/tls.zig/testdata/tls12.zig delete mode 100644 src/http/async/tls.zig/testdata/tls13.zig delete mode 100644 src/http/async/tls.zig/testu.zig delete mode 100644 src/http/async/tls.zig/transcript.zig create mode 160000 vendor/zig-async-io diff --git a/.gitmodules b/.gitmodules index 2ceea970..229d1a16 100644 --- a/.gitmodules +++ b/.gitmodules @@ -25,3 +25,6 @@ [submodule "vendor/tls.zig"] path = vendor/tls.zig url = git@github.com:ianic/tls.zig.git +[submodule "vendor/zig-async-io"] + path = vendor/zig-async-io + url = git@github.com:lightpanda-io/zig-async-io.git diff --git a/build.zig b/build.zig index 86ad4ef9..8c83d648 100644 --- a/build.zig +++ b/build.zig @@ -159,6 +159,11 @@ fn common( netsurf.addImport("jsruntime", jsruntimemod); step.root_module.addImport("netsurf", netsurf); + const asyncio = b.addModule("asyncio", .{ + .root_source_file = b.path("vendor/zig-async-io/src/lib.zig"), + }); + step.root_module.addImport("asyncio", asyncio); + const tlsmod = b.addModule("tls", .{ .root_source_file = b.path("vendor/tls.zig/src/main.zig"), }); diff --git a/src/browser/browser.zig b/src/browser/browser.zig index e6f646ef..0b58cbaa 100644 --- a/src/browser/browser.zig +++ b/src/browser/browser.zig @@ -40,7 +40,7 @@ const storage = @import("../storage/storage.zig"); const FetchResult = @import("../http/Client.zig").Client.FetchResult; const UserContext = @import("../user_context.zig").UserContext; -const HttpClient = @import("../http/async/main.zig").Client; +const HttpClient = @import("asyncio").Client; const log = std.log.scoped(.browser); diff --git a/src/http/async/io.zig b/src/http/async/io.zig deleted file mode 100644 index 4a11b5b6..00000000 --- a/src/http/async/io.zig +++ /dev/null @@ -1,132 +0,0 @@ -const std = @import("std"); - -const Ctx = @import("std/http/Client.zig").Ctx; -const Loop = @import("jsruntime").Loop; -const NetworkImpl = Loop.Network(SingleThreaded); - -pub const Blocking = struct { - pub fn connect( - _: *Blocking, - comptime CtxT: type, - ctx: *CtxT, - comptime cbk: fn (ctx: *CtxT, res: anyerror!void) anyerror!void, - socket: std.posix.socket_t, - address: std.net.Address, - ) void { - std.posix.connect(socket, &address.any, address.getOsSockLen()) catch |err| { - std.posix.close(socket); - cbk(ctx, err) catch |e| { - ctx.setErr(e); - }; - }; - cbk(ctx, {}) catch |e| ctx.setErr(e); - } - - pub fn send( - _: *Blocking, - comptime CtxT: type, - ctx: *CtxT, - comptime cbk: fn (ctx: *CtxT, res: anyerror!void) anyerror!void, - socket: std.posix.socket_t, - buf: []const u8, - ) void { - const len = std.posix.write(socket, buf) catch |err| { - cbk(ctx, err) catch |e| { - return ctx.setErr(e); - }; - return ctx.setErr(err); - }; - ctx.setLen(len); - cbk(ctx, {}) catch |e| ctx.setErr(e); - } - - pub fn recv( - _: *Blocking, - comptime CtxT: type, - ctx: *CtxT, - comptime cbk: fn (ctx: *CtxT, res: anyerror!void) anyerror!void, - socket: std.posix.socket_t, - buf: []u8, - ) void { - const len = std.posix.read(socket, buf) catch |err| { - cbk(ctx, err) catch |e| { - return ctx.setErr(e); - }; - return ctx.setErr(err); - }; - ctx.setLen(len); - cbk(ctx, {}) catch |e| ctx.setErr(e); - } -}; - -pub const SingleThreaded = struct { - impl: NetworkImpl, - cbk: Cbk, - ctx: *Ctx, - - const Cbk = *const fn (ctx: *Ctx, res: anyerror!void) anyerror!void; - - pub fn init(loop: *Loop) SingleThreaded { - return .{ - .impl = NetworkImpl.init(loop), - .cbk = undefined, - .ctx = undefined, - }; - } - - pub fn connect( - self: *SingleThreaded, - comptime _: type, - ctx: *Ctx, - comptime cbk: Cbk, - socket: std.posix.socket_t, - address: std.net.Address, - ) void { - self.cbk = cbk; - self.ctx = ctx; - self.impl.connect(self, socket, address); - } - - pub fn onConnect(self: *SingleThreaded, err: ?anyerror) void { - if (err) |e| return self.ctx.setErr(e); - self.cbk(self.ctx, {}) catch |e| self.ctx.setErr(e); - } - - pub fn send( - self: *SingleThreaded, - comptime _: type, - ctx: *Ctx, - comptime cbk: Cbk, - socket: std.posix.socket_t, - buf: []const u8, - ) void { - self.ctx = ctx; - self.cbk = cbk; - self.impl.send(self, socket, buf); - } - - pub fn onSend(self: *SingleThreaded, ln: usize, err: ?anyerror) void { - if (err) |e| return self.ctx.setErr(e); - self.ctx.setLen(ln); - self.cbk(self.ctx, {}) catch |e| self.ctx.setErr(e); - } - - pub fn recv( - self: *SingleThreaded, - comptime _: type, - ctx: *Ctx, - comptime cbk: Cbk, - socket: std.posix.socket_t, - buf: []u8, - ) void { - self.ctx = ctx; - self.cbk = cbk; - self.impl.receive(self, socket, buf); - } - - pub fn onReceive(self: *SingleThreaded, ln: usize, err: ?anyerror) void { - if (err) |e| return self.ctx.setErr(e); - self.ctx.setLen(ln); - self.cbk(self.ctx, {}) catch |e| self.ctx.setErr(e); - } -}; diff --git a/src/http/async/main.zig b/src/http/async/main.zig deleted file mode 100644 index c3ef9934..00000000 --- a/src/http/async/main.zig +++ /dev/null @@ -1,3 +0,0 @@ -const std = @import("std"); - -pub const Client = @import("std/http/Client.zig"); diff --git a/src/http/async/stack.zig b/src/http/async/stack.zig deleted file mode 100644 index d19a0c8f..00000000 --- a/src/http/async/stack.zig +++ /dev/null @@ -1,95 +0,0 @@ -const std = @import("std"); - -pub fn Stack(comptime T: type) type { - return struct { - const Self = @This(); - pub const Fn = *const T; - - next: ?*Self = null, - func: Fn, - - pub fn init(alloc: std.mem.Allocator, comptime func: Fn) !*Self { - const next = try alloc.create(Self); - next.* = .{ .func = func }; - return next; - } - - pub fn push(self: *Self, alloc: std.mem.Allocator, comptime func: Fn) !void { - if (self.next) |next| { - return next.push(alloc, func); - } - self.next = try Self.init(alloc, func); - } - - pub fn pop(self: *Self, alloc: std.mem.Allocator, prev: ?*Self) Fn { - if (self.next) |next| { - return next.pop(alloc, self); - } - defer { - if (prev) |p| { - self.deinit(alloc, p); - } - } - return self.func; - } - - pub fn deinit(self: *Self, alloc: std.mem.Allocator, prev: ?*Self) void { - if (self.next) |next| { - // recursivly deinit - next.deinit(alloc, self); - } - if (prev) |p| { - p.next = null; - } - alloc.destroy(self); - } - }; -} - -fn first() u8 { - return 1; -} - -fn second() u8 { - return 2; -} - -test "stack" { - const alloc = std.testing.allocator; - const TestStack = Stack(fn () u8); - - var stack = TestStack{ .func = first }; - try stack.push(alloc, second); - - const a = stack.pop(alloc, null); - try std.testing.expect(a() == 2); - - const b = stack.pop(alloc, null); - try std.testing.expect(b() == 1); -} - -fn first_op(arg: ?*anyopaque) u8 { - const val = @as(*u8, @ptrCast(arg)); - return val.* + @as(u8, 1); -} - -fn second_op(arg: ?*anyopaque) u8 { - const val = @as(*u8, @ptrCast(arg)); - return val.* + @as(u8, 2); -} - -test "opaque stack" { - const alloc = std.testing.allocator; - const TestStack = Stack(fn (?*anyopaque) u8); - - var stack = TestStack{ .func = first_op }; - try stack.push(alloc, second_op); - - const a = stack.pop(alloc, null); - var x: u8 = 5; - try std.testing.expect(a(@as(*anyopaque, @ptrCast(&x))) == 2 + x); - - const b = stack.pop(alloc, null); - var y: u8 = 3; - try std.testing.expect(b(@as(*anyopaque, @ptrCast(&y))) == 1 + y); -} diff --git a/src/http/async/std/http.zig b/src/http/async/std/http.zig deleted file mode 100644 index f027d440..00000000 --- a/src/http/async/std/http.zig +++ /dev/null @@ -1,318 +0,0 @@ -pub const Client = @import("http/Client.zig"); -pub const Server = @import("http/Server.zig"); -pub const protocol = @import("http/protocol.zig"); -pub const HeadParser = std.http.HeadParser; -pub const ChunkParser = std.http.ChunkParser; -pub const HeaderIterator = std.http.HeaderIterator; - -pub const Version = enum { - @"HTTP/1.0", - @"HTTP/1.1", -}; - -/// https://developer.mozilla.org/en-US/docs/Web/HTTP/Methods -/// -/// https://datatracker.ietf.org/doc/html/rfc7231#section-4 Initial definition -/// -/// https://datatracker.ietf.org/doc/html/rfc5789#section-2 PATCH -pub const Method = enum(u64) { - GET = parse("GET"), - HEAD = parse("HEAD"), - POST = parse("POST"), - PUT = parse("PUT"), - DELETE = parse("DELETE"), - CONNECT = parse("CONNECT"), - OPTIONS = parse("OPTIONS"), - TRACE = parse("TRACE"), - PATCH = parse("PATCH"), - - _, - - /// Converts `s` into a type that may be used as a `Method` field. - /// Asserts that `s` is 24 or fewer bytes. - pub fn parse(s: []const u8) u64 { - var x: u64 = 0; - const len = @min(s.len, @sizeOf(@TypeOf(x))); - @memcpy(std.mem.asBytes(&x)[0..len], s[0..len]); - return x; - } - - pub fn write(self: Method, w: anytype) !void { - const bytes = std.mem.asBytes(&@intFromEnum(self)); - const str = std.mem.sliceTo(bytes, 0); - try w.writeAll(str); - } - - /// Returns true if a request of this method is allowed to have a body - /// Actual behavior from servers may vary and should still be checked - pub fn requestHasBody(self: Method) bool { - return switch (self) { - .POST, .PUT, .PATCH => true, - .GET, .HEAD, .DELETE, .CONNECT, .OPTIONS, .TRACE => false, - else => true, - }; - } - - /// Returns true if a response to this method is allowed to have a body - /// Actual behavior from clients may vary and should still be checked - pub fn responseHasBody(self: Method) bool { - return switch (self) { - .GET, .POST, .DELETE, .CONNECT, .OPTIONS, .PATCH => true, - .HEAD, .PUT, .TRACE => false, - else => true, - }; - } - - /// An HTTP method is safe if it doesn't alter the state of the server. - /// - /// https://developer.mozilla.org/en-US/docs/Glossary/Safe/HTTP - /// - /// https://datatracker.ietf.org/doc/html/rfc7231#section-4.2.1 - pub fn safe(self: Method) bool { - return switch (self) { - .GET, .HEAD, .OPTIONS, .TRACE => true, - .POST, .PUT, .DELETE, .CONNECT, .PATCH => false, - else => false, - }; - } - - /// An HTTP method is idempotent if an identical request can be made once or several times in a row with the same effect while leaving the server in the same state. - /// - /// https://developer.mozilla.org/en-US/docs/Glossary/Idempotent - /// - /// https://datatracker.ietf.org/doc/html/rfc7231#section-4.2.2 - pub fn idempotent(self: Method) bool { - return switch (self) { - .GET, .HEAD, .PUT, .DELETE, .OPTIONS, .TRACE => true, - .CONNECT, .POST, .PATCH => false, - else => false, - }; - } - - /// A cacheable response is an HTTP response that can be cached, that is stored to be retrieved and used later, saving a new request to the server. - /// - /// https://developer.mozilla.org/en-US/docs/Glossary/cacheable - /// - /// https://datatracker.ietf.org/doc/html/rfc7231#section-4.2.3 - pub fn cacheable(self: Method) bool { - return switch (self) { - .GET, .HEAD => true, - .POST, .PUT, .DELETE, .CONNECT, .OPTIONS, .TRACE, .PATCH => false, - else => false, - }; - } -}; - -/// https://developer.mozilla.org/en-US/docs/Web/HTTP/Status -pub const Status = enum(u10) { - @"continue" = 100, // RFC7231, Section 6.2.1 - switching_protocols = 101, // RFC7231, Section 6.2.2 - processing = 102, // RFC2518 - early_hints = 103, // RFC8297 - - ok = 200, // RFC7231, Section 6.3.1 - created = 201, // RFC7231, Section 6.3.2 - accepted = 202, // RFC7231, Section 6.3.3 - non_authoritative_info = 203, // RFC7231, Section 6.3.4 - no_content = 204, // RFC7231, Section 6.3.5 - reset_content = 205, // RFC7231, Section 6.3.6 - partial_content = 206, // RFC7233, Section 4.1 - multi_status = 207, // RFC4918 - already_reported = 208, // RFC5842 - im_used = 226, // RFC3229 - - multiple_choice = 300, // RFC7231, Section 6.4.1 - moved_permanently = 301, // RFC7231, Section 6.4.2 - found = 302, // RFC7231, Section 6.4.3 - see_other = 303, // RFC7231, Section 6.4.4 - not_modified = 304, // RFC7232, Section 4.1 - use_proxy = 305, // RFC7231, Section 6.4.5 - temporary_redirect = 307, // RFC7231, Section 6.4.7 - permanent_redirect = 308, // RFC7538 - - bad_request = 400, // RFC7231, Section 6.5.1 - unauthorized = 401, // RFC7235, Section 3.1 - payment_required = 402, // RFC7231, Section 6.5.2 - forbidden = 403, // RFC7231, Section 6.5.3 - not_found = 404, // RFC7231, Section 6.5.4 - method_not_allowed = 405, // RFC7231, Section 6.5.5 - not_acceptable = 406, // RFC7231, Section 6.5.6 - proxy_auth_required = 407, // RFC7235, Section 3.2 - request_timeout = 408, // RFC7231, Section 6.5.7 - conflict = 409, // RFC7231, Section 6.5.8 - gone = 410, // RFC7231, Section 6.5.9 - length_required = 411, // RFC7231, Section 6.5.10 - precondition_failed = 412, // RFC7232, Section 4.2][RFC8144, Section 3.2 - payload_too_large = 413, // RFC7231, Section 6.5.11 - uri_too_long = 414, // RFC7231, Section 6.5.12 - unsupported_media_type = 415, // RFC7231, Section 6.5.13][RFC7694, Section 3 - range_not_satisfiable = 416, // RFC7233, Section 4.4 - expectation_failed = 417, // RFC7231, Section 6.5.14 - teapot = 418, // RFC 7168, 2.3.3 - misdirected_request = 421, // RFC7540, Section 9.1.2 - unprocessable_entity = 422, // RFC4918 - locked = 423, // RFC4918 - failed_dependency = 424, // RFC4918 - too_early = 425, // RFC8470 - upgrade_required = 426, // RFC7231, Section 6.5.15 - precondition_required = 428, // RFC6585 - too_many_requests = 429, // RFC6585 - request_header_fields_too_large = 431, // RFC6585 - unavailable_for_legal_reasons = 451, // RFC7725 - - internal_server_error = 500, // RFC7231, Section 6.6.1 - not_implemented = 501, // RFC7231, Section 6.6.2 - bad_gateway = 502, // RFC7231, Section 6.6.3 - service_unavailable = 503, // RFC7231, Section 6.6.4 - gateway_timeout = 504, // RFC7231, Section 6.6.5 - http_version_not_supported = 505, // RFC7231, Section 6.6.6 - variant_also_negotiates = 506, // RFC2295 - insufficient_storage = 507, // RFC4918 - loop_detected = 508, // RFC5842 - not_extended = 510, // RFC2774 - network_authentication_required = 511, // RFC6585 - - _, - - pub fn phrase(self: Status) ?[]const u8 { - return switch (self) { - // 1xx statuses - .@"continue" => "Continue", - .switching_protocols => "Switching Protocols", - .processing => "Processing", - .early_hints => "Early Hints", - - // 2xx statuses - .ok => "OK", - .created => "Created", - .accepted => "Accepted", - .non_authoritative_info => "Non-Authoritative Information", - .no_content => "No Content", - .reset_content => "Reset Content", - .partial_content => "Partial Content", - .multi_status => "Multi-Status", - .already_reported => "Already Reported", - .im_used => "IM Used", - - // 3xx statuses - .multiple_choice => "Multiple Choice", - .moved_permanently => "Moved Permanently", - .found => "Found", - .see_other => "See Other", - .not_modified => "Not Modified", - .use_proxy => "Use Proxy", - .temporary_redirect => "Temporary Redirect", - .permanent_redirect => "Permanent Redirect", - - // 4xx statuses - .bad_request => "Bad Request", - .unauthorized => "Unauthorized", - .payment_required => "Payment Required", - .forbidden => "Forbidden", - .not_found => "Not Found", - .method_not_allowed => "Method Not Allowed", - .not_acceptable => "Not Acceptable", - .proxy_auth_required => "Proxy Authentication Required", - .request_timeout => "Request Timeout", - .conflict => "Conflict", - .gone => "Gone", - .length_required => "Length Required", - .precondition_failed => "Precondition Failed", - .payload_too_large => "Payload Too Large", - .uri_too_long => "URI Too Long", - .unsupported_media_type => "Unsupported Media Type", - .range_not_satisfiable => "Range Not Satisfiable", - .expectation_failed => "Expectation Failed", - .teapot => "I'm a teapot", - .misdirected_request => "Misdirected Request", - .unprocessable_entity => "Unprocessable Entity", - .locked => "Locked", - .failed_dependency => "Failed Dependency", - .too_early => "Too Early", - .upgrade_required => "Upgrade Required", - .precondition_required => "Precondition Required", - .too_many_requests => "Too Many Requests", - .request_header_fields_too_large => "Request Header Fields Too Large", - .unavailable_for_legal_reasons => "Unavailable For Legal Reasons", - - // 5xx statuses - .internal_server_error => "Internal Server Error", - .not_implemented => "Not Implemented", - .bad_gateway => "Bad Gateway", - .service_unavailable => "Service Unavailable", - .gateway_timeout => "Gateway Timeout", - .http_version_not_supported => "HTTP Version Not Supported", - .variant_also_negotiates => "Variant Also Negotiates", - .insufficient_storage => "Insufficient Storage", - .loop_detected => "Loop Detected", - .not_extended => "Not Extended", - .network_authentication_required => "Network Authentication Required", - - else => return null, - }; - } - - pub const Class = enum { - informational, - success, - redirect, - client_error, - server_error, - }; - - pub fn class(self: Status) Class { - return switch (@intFromEnum(self)) { - 100...199 => .informational, - 200...299 => .success, - 300...399 => .redirect, - 400...499 => .client_error, - else => .server_error, - }; - } - - test { - try std.testing.expectEqualStrings("OK", Status.ok.phrase().?); - try std.testing.expectEqualStrings("Not Found", Status.not_found.phrase().?); - } - - test { - try std.testing.expectEqual(Status.Class.success, Status.ok.class()); - try std.testing.expectEqual(Status.Class.client_error, Status.not_found.class()); - } -}; - -pub const TransferEncoding = enum { - chunked, - none, - // compression is intentionally omitted here, as std.http.Client stores it as content-encoding -}; - -pub const ContentEncoding = enum { - identity, - compress, - @"x-compress", - deflate, - gzip, - @"x-gzip", - zstd, -}; - -pub const Connection = enum { - keep_alive, - close, -}; - -pub const Header = struct { - name: []const u8, - value: []const u8, -}; - -const builtin = @import("builtin"); -const std = @import("std"); - -test { - _ = Client; - _ = Method; - _ = Server; - _ = Status; -} diff --git a/src/http/async/std/http/Client.zig b/src/http/async/std/http/Client.zig deleted file mode 100644 index 2c866e6f..00000000 --- a/src/http/async/std/http/Client.zig +++ /dev/null @@ -1,2512 +0,0 @@ -//! HTTP(S) Client implementation. -//! -//! Connections are opened in a thread-safe manner, but individual Requests are not. -//! -//! TLS support may be disabled via `std.options.http_disable_tls`. - -const std = @import("std"); -const builtin = @import("builtin"); -const testing = std.testing; -const http = std.http; -const mem = std.mem; -const net = @import("../net.zig"); -const Uri = std.Uri; -const Allocator = mem.Allocator; -const assert = std.debug.assert; -const use_vectors = builtin.zig_backend != .stage2_x86_64; - -const Client = @This(); -const proto = @import("protocol.zig"); - -const tls23 = @import("../../tls.zig/main.zig"); -const VecPut = @import("../../tls.zig/connection.zig").VecPut; -const GenericStack = @import("../../stack.zig").Stack; -const async_io = @import("../../io.zig"); -pub const Loop = async_io.SingleThreaded; - -const cipher = @import("../../tls.zig/cipher.zig"); - -pub const disable_tls = std.options.http_disable_tls; - -/// Used for all client allocations. Must be thread-safe. -allocator: Allocator, - -ca_bundle: if (disable_tls) void else std.crypto.Certificate.Bundle = if (disable_tls) {} else .{}, -ca_bundle_mutex: std.Thread.Mutex = .{}, - -/// When this is `true`, the next time this client performs an HTTPS request, -/// it will first rescan the system for root certificates. -next_https_rescan_certs: bool = true, - -/// The pool of connections that can be reused (and currently in use). -connection_pool: ConnectionPool = .{}, - -/// If populated, all http traffic travels through this third party. -/// This field cannot be modified while the client has active connections. -/// Pointer to externally-owned memory. -http_proxy: ?*Proxy = null, -/// If populated, all https traffic travels through this third party. -/// This field cannot be modified while the client has active connections. -/// Pointer to externally-owned memory. -https_proxy: ?*Proxy = null, - -/// A set of linked lists of connections that can be reused. -pub const ConnectionPool = struct { - mutex: std.Thread.Mutex = .{}, - /// Open connections that are currently in use. - used: Queue = .{}, - /// Open connections that are not currently in use. - free: Queue = .{}, - free_len: usize = 0, - free_size: usize = 32, - - /// The criteria for a connection to be considered a match. - pub const Criteria = struct { - host: []const u8, - port: u16, - protocol: Connection.Protocol, - }; - - const Queue = std.DoublyLinkedList(Connection); - pub const Node = Queue.Node; - - /// Finds and acquires a connection from the connection pool matching the criteria. This function is threadsafe. - /// If no connection is found, null is returned. - pub fn findConnection(pool: *ConnectionPool, criteria: Criteria) ?*Connection { - pool.mutex.lock(); - defer pool.mutex.unlock(); - - var next = pool.free.last; - while (next) |node| : (next = node.prev) { - if (node.data.protocol != criteria.protocol) continue; - if (node.data.port != criteria.port) continue; - - // Domain names are case-insensitive (RFC 5890, Section 2.3.2.4) - if (!std.ascii.eqlIgnoreCase(node.data.host, criteria.host)) continue; - - pool.acquireUnsafe(node); - return &node.data; - } - - return null; - } - - /// Acquires an existing connection from the connection pool. This function is not threadsafe. - pub fn acquireUnsafe(pool: *ConnectionPool, node: *Node) void { - pool.free.remove(node); - pool.free_len -= 1; - - pool.used.append(node); - } - - /// Acquires an existing connection from the connection pool. This function is threadsafe. - pub fn acquire(pool: *ConnectionPool, node: *Node) void { - pool.mutex.lock(); - defer pool.mutex.unlock(); - - return pool.acquireUnsafe(node); - } - - /// Tries to release a connection back to the connection pool. This function is threadsafe. - /// If the connection is marked as closing, it will be closed instead. - /// - /// The allocator must be the owner of all nodes in this pool. - /// The allocator must be the owner of all resources associated with the connection. - pub fn release(pool: *ConnectionPool, allocator: Allocator, connection: *Connection) void { - pool.mutex.lock(); - defer pool.mutex.unlock(); - - const node: *Node = @fieldParentPtr("data", connection); - - pool.used.remove(node); - - if (node.data.closing or pool.free_size == 0) { - node.data.close(allocator); - return allocator.destroy(node); - } - - if (pool.free_len >= pool.free_size) { - const popped = pool.free.popFirst() orelse unreachable; - pool.free_len -= 1; - - popped.data.close(allocator); - allocator.destroy(popped); - } - - if (node.data.proxied) { - pool.free.prepend(node); // proxied connections go to the end of the queue, always try direct connections first - } else { - pool.free.append(node); - } - - pool.free_len += 1; - } - - /// Adds a newly created node to the pool of used connections. This function is threadsafe. - pub fn addUsed(pool: *ConnectionPool, node: *Node) void { - pool.mutex.lock(); - defer pool.mutex.unlock(); - - pool.used.append(node); - } - - /// Resizes the connection pool. This function is threadsafe. - /// - /// If the new size is smaller than the current size, then idle connections will be closed until the pool is the new size. - pub fn resize(pool: *ConnectionPool, allocator: Allocator, new_size: usize) void { - pool.mutex.lock(); - defer pool.mutex.unlock(); - - const next = pool.free.first; - _ = next; - while (pool.free_len > new_size) { - const popped = pool.free.popFirst() orelse unreachable; - pool.free_len -= 1; - - popped.data.close(allocator); - allocator.destroy(popped); - } - - pool.free_size = new_size; - } - - /// Frees the connection pool and closes all connections within. This function is threadsafe. - /// - /// All future operations on the connection pool will deadlock. - pub fn deinit(pool: *ConnectionPool, allocator: Allocator) void { - pool.mutex.lock(); - - var next = pool.free.first; - while (next) |node| { - defer allocator.destroy(node); - next = node.next; - - node.data.close(allocator); - } - - next = pool.used.first; - while (next) |node| { - defer allocator.destroy(node); - next = node.next; - - node.data.close(allocator); - } - - pool.* = undefined; - } -}; - -/// An interface to either a plain or TLS connection. -pub const Connection = struct { - stream: net.Stream, - /// undefined unless protocol is tls. - tls_client: if (!disable_tls) *tls23.Connection(net.Stream) else void, - - /// The protocol that this connection is using. - protocol: Protocol, - - /// The host that this connection is connected to. - host: []u8, - - /// The port that this connection is connected to. - port: u16, - - /// Whether this connection is proxied and is not directly connected. - proxied: bool = false, - - /// Whether this connection is closing when we're done with it. - closing: bool = false, - - read_start: BufferSize = 0, - read_end: BufferSize = 0, - write_end: BufferSize = 0, - read_buf: [buffer_size]u8 = undefined, - write_buf: [buffer_size]u8 = undefined, - - pub const buffer_size = std.crypto.tls.max_ciphertext_record_len; - const BufferSize = std.math.IntFittingRange(0, buffer_size); - - pub const Protocol = enum { plain, tls }; - - pub fn async_readvDirect( - conn: *Connection, - buffers: []std.posix.iovec, - ctx: *Ctx, - comptime cbk: Cbk, - ) !void { - _ = conn; - - if (ctx.conn().protocol == .tls) { - if (disable_tls) unreachable; - - return ctx.conn().tls_client.async_readv(ctx.conn().stream, buffers, ctx, cbk); - } - - return ctx.stream().async_readv(buffers, ctx, cbk); - } - - pub fn readvDirectTls(conn: *Connection, buffers: []std.posix.iovec) ReadError!usize { - return conn.tls_client.readv(buffers) catch |err| { - // https://github.com/ziglang/zig/issues/2473 - if (mem.startsWith(u8, @errorName(err), "TlsAlert")) return error.TlsAlert; - - switch (err) { - error.TlsRecordOverflow, error.TlsBadRecordMac, error.TlsUnexpectedMessage => return error.TlsFailure, - error.ConnectionTimedOut => return error.ConnectionTimedOut, - error.ConnectionResetByPeer, error.BrokenPipe => return error.ConnectionResetByPeer, - else => return error.UnexpectedReadFailure, - } - }; - } - - pub fn readvDirect(conn: *Connection, buffers: []std.posix.iovec) ReadError!usize { - if (conn.protocol == .tls) { - if (disable_tls) unreachable; - - return conn.readvDirectTls(buffers); - } - - return conn.stream.readv(buffers) catch |err| switch (err) { - error.ConnectionTimedOut => return error.ConnectionTimedOut, - error.ConnectionResetByPeer, error.BrokenPipe => return error.ConnectionResetByPeer, - else => return error.UnexpectedReadFailure, - }; - } - - fn onFill(ctx: *Ctx, res: anyerror!void) anyerror!void { - ctx.alloc().free(ctx._iovecs); - res catch |err| return ctx.pop(err); - - // EOF - const nread = ctx.len(); - if (nread == 0) return ctx.pop(error.EndOfStream); - - // finished - ctx.conn().read_start = 0; - ctx.conn().read_end = @intCast(nread); - return ctx.pop({}); - } - - pub fn async_fill(conn: *Connection, ctx: *Ctx, comptime cbk: Cbk) !void { - if (conn.read_end != conn.read_start) return; - - ctx._iovecs = try ctx.alloc().alloc(std.posix.iovec, 1); - errdefer ctx.alloc().free(ctx._iovecs); - const iovecs = [1]std.posix.iovec{ - .{ .base = &conn.read_buf, .len = conn.read_buf.len }, - }; - @memcpy(ctx._iovecs, &iovecs); - - try ctx.push(cbk); - return conn.async_readvDirect(ctx._iovecs, ctx, onFill); - } - - /// Refills the read buffer with data from the connection. - pub fn fill(conn: *Connection) ReadError!void { - if (conn.read_end != conn.read_start) return; - - var iovecs = [1]std.posix.iovec{ - .{ .base = &conn.read_buf, .len = conn.read_buf.len }, - }; - const nread = try conn.readvDirect(&iovecs); - if (nread == 0) return error.EndOfStream; - conn.read_start = 0; - conn.read_end = @intCast(nread); - } - - /// Returns the current slice of buffered data. - pub fn peek(conn: *Connection) []const u8 { - return conn.read_buf[conn.read_start..conn.read_end]; - } - - /// Discards the given number of bytes from the read buffer. - pub fn drop(conn: *Connection, num: BufferSize) void { - conn.read_start += num; - } - - /// Reads data from the connection into the given buffer. - pub fn read(conn: *Connection, buffer: []u8) ReadError!usize { - const available_read = conn.read_end - conn.read_start; - const available_buffer = buffer.len; - - if (available_read > available_buffer) { // partially read buffered data - @memcpy(buffer[0..available_buffer], conn.read_buf[conn.read_start..conn.read_end][0..available_buffer]); - conn.read_start += @intCast(available_buffer); - - return available_buffer; - } else if (available_read > 0) { // fully read buffered data - @memcpy(buffer[0..available_read], conn.read_buf[conn.read_start..conn.read_end]); - conn.read_start += available_read; - - return available_read; - } - - var iovecs = [2]std.posix.iovec{ - .{ .base = buffer.ptr, .len = buffer.len }, - .{ .base = &conn.read_buf, .len = conn.read_buf.len }, - }; - const nread = try conn.readvDirect(&iovecs); - - if (nread > buffer.len) { - conn.read_start = 0; - conn.read_end = @intCast(nread - buffer.len); - return buffer.len; - } - - return nread; - } - - pub const ReadError = error{ - TlsFailure, - TlsAlert, - ConnectionTimedOut, - ConnectionResetByPeer, - UnexpectedReadFailure, - EndOfStream, - }; - - pub const Reader = std.io.Reader(*Connection, ReadError, read); - - pub fn reader(conn: *Connection) Reader { - return Reader{ .context = conn }; - } - - pub fn writeAllDirectTls(conn: *Connection, buffer: []const u8) WriteError!void { - return conn.tls_client.writeAll(buffer) catch |err| switch (err) { - error.BrokenPipe, error.ConnectionResetByPeer => return error.ConnectionResetByPeer, - else => return error.UnexpectedWriteFailure, - }; - } - - fn onWriteAllDirect(ctx: *Ctx, res: anyerror!void) !void { - res catch |err| switch (err) { - error.BrokenPipe, - error.ConnectionResetByPeer, - => return ctx.pop(error.ConnectionResetByPeer), - else => return ctx.pop(error.UnexpectedWriteFailure), - }; - return ctx.pop({}); - } - - pub fn async_writeAllDirect( - conn: *Connection, - buffer: []const u8, - ctx: *Ctx, - comptime cbk: Cbk, - ) !void { - try ctx.push(cbk); - if (conn.protocol == .tls) { - if (disable_tls) unreachable; - - return conn.tls_client.async_writeAll(conn.stream, buffer, ctx, onWriteAllDirect); - } - - return conn.stream.async_writeAll(buffer, ctx, onWriteAllDirect); - } - - pub fn writeAllDirect(conn: *Connection, buffer: []const u8) WriteError!void { - if (conn.protocol == .tls) { - if (disable_tls) unreachable; - - return conn.writeAllDirectTls(buffer); - } - - return conn.stream.writeAll(buffer) catch |err| switch (err) { - error.BrokenPipe, error.ConnectionResetByPeer => return error.ConnectionResetByPeer, - else => return error.UnexpectedWriteFailure, - }; - } - - /// Writes the given buffer to the connection. - pub fn write(conn: *Connection, buffer: []const u8) WriteError!usize { - if (conn.write_buf.len - conn.write_end < buffer.len) { - try conn.flush(); - - if (buffer.len > conn.write_buf.len) { - try conn.writeAllDirect(buffer); - return buffer.len; - } - } - - @memcpy(conn.write_buf[conn.write_end..][0..buffer.len], buffer); - conn.write_end += @intCast(buffer.len); - - return buffer.len; - } - - /// Returns a buffer to be filled with exactly len bytes to write to the connection. - pub fn allocWriteBuffer(conn: *Connection, len: BufferSize) WriteError![]u8 { - if (conn.write_buf.len - conn.write_end < len) try conn.flush(); - defer conn.write_end += len; - return conn.write_buf[conn.write_end..][0..len]; - } - - fn onFlush(ctx: *Ctx, res: anyerror!void) !void { - res catch |err| return ctx.pop(err); - ctx.conn().write_end = 0; - return ctx.pop({}); - } - - pub fn async_flush(conn: *Connection, ctx: *Ctx, comptime cbk: Cbk) !void { - if (conn.write_end == 0) return error.WriteEmpty; - - try ctx.push(cbk); - try conn.async_writeAllDirect(conn.write_buf[0..conn.write_end], ctx, onFlush); - } - - /// Flushes the write buffer to the connection. - pub fn flush(conn: *Connection) WriteError!void { - if (conn.write_end == 0) return; - - try conn.writeAllDirect(conn.write_buf[0..conn.write_end]); - conn.write_end = 0; - } - - pub const WriteError = error{ - ConnectionResetByPeer, - UnexpectedWriteFailure, - }; - - pub const Writer = std.io.Writer(*Connection, WriteError, write); - - pub fn writer(conn: *Connection) Writer { - return Writer{ .context = conn }; - } - - /// Closes the connection. - pub fn close(conn: *Connection, allocator: Allocator) void { - if (conn.protocol == .tls) { - if (disable_tls) unreachable; - - // try to cleanly close the TLS connection, for any server that cares. - conn.tls_client.close() catch {}; - allocator.destroy(conn.tls_client); - } - - conn.stream.close(); - allocator.free(conn.host); - } -}; - -/// The mode of transport for requests. -pub const RequestTransfer = union(enum) { - content_length: u64, - chunked: void, - none: void, -}; - -/// The decompressor for response messages. -pub const Compression = union(enum) { - pub const DeflateDecompressor = std.compress.zlib.Decompressor(Request.TransferReader); - pub const GzipDecompressor = std.compress.gzip.Decompressor(Request.TransferReader); - // https://github.com/ziglang/zig/issues/18937 - //pub const ZstdDecompressor = std.compress.zstd.DecompressStream(Request.TransferReader, .{}); - - deflate: DeflateDecompressor, - gzip: GzipDecompressor, - // https://github.com/ziglang/zig/issues/18937 - //zstd: ZstdDecompressor, - none: void, -}; - -/// A HTTP response originating from a server. -pub const Response = struct { - version: http.Version, - status: http.Status, - reason: []const u8, - - /// Points into the user-provided `server_header_buffer`. - location: ?[]const u8 = null, - /// Points into the user-provided `server_header_buffer`. - content_type: ?[]const u8 = null, - /// Points into the user-provided `server_header_buffer`. - content_disposition: ?[]const u8 = null, - - keep_alive: bool, - - /// If present, the number of bytes in the response body. - content_length: ?u64 = null, - - /// If present, the transfer encoding of the response body, otherwise none. - transfer_encoding: http.TransferEncoding = .none, - - /// If present, the compression of the response body, otherwise identity (no compression). - transfer_compression: http.ContentEncoding = .identity, - - parser: proto.HeadersParser, - compression: Compression = .none, - - /// Whether the response body should be skipped. Any data read from the - /// response body will be discarded. - skip: bool = false, - - pub const ParseError = error{ - HttpHeadersInvalid, - HttpHeaderContinuationsUnsupported, - HttpTransferEncodingUnsupported, - HttpConnectionHeaderUnsupported, - InvalidContentLength, - CompressionUnsupported, - }; - - pub fn parse(res: *Response, bytes: []const u8) ParseError!void { - var it = mem.splitSequence(u8, bytes, "\r\n"); - - const first_line = it.next().?; - if (first_line.len < 12) { - return error.HttpHeadersInvalid; - } - - const version: http.Version = switch (int64(first_line[0..8])) { - int64("HTTP/1.0") => .@"HTTP/1.0", - int64("HTTP/1.1") => .@"HTTP/1.1", - else => return error.HttpHeadersInvalid, - }; - if (first_line[8] != ' ') return error.HttpHeadersInvalid; - const status: http.Status = @enumFromInt(parseInt3(first_line[9..12])); - const reason = mem.trimLeft(u8, first_line[12..], " "); - - res.version = version; - res.status = status; - res.reason = reason; - res.keep_alive = switch (version) { - .@"HTTP/1.0" => false, - .@"HTTP/1.1" => true, - }; - - while (it.next()) |line| { - if (line.len == 0) return; - switch (line[0]) { - ' ', '\t' => return error.HttpHeaderContinuationsUnsupported, - else => {}, - } - - var line_it = mem.splitScalar(u8, line, ':'); - const header_name = line_it.next().?; - const header_value = mem.trim(u8, line_it.rest(), " \t"); - if (header_name.len == 0) return error.HttpHeadersInvalid; - - if (std.ascii.eqlIgnoreCase(header_name, "connection")) { - res.keep_alive = !std.ascii.eqlIgnoreCase(header_value, "close"); - } else if (std.ascii.eqlIgnoreCase(header_name, "content-type")) { - res.content_type = header_value; - } else if (std.ascii.eqlIgnoreCase(header_name, "location")) { - res.location = header_value; - } else if (std.ascii.eqlIgnoreCase(header_name, "content-disposition")) { - res.content_disposition = header_value; - } else if (std.ascii.eqlIgnoreCase(header_name, "transfer-encoding")) { - // Transfer-Encoding: second, first - // Transfer-Encoding: deflate, chunked - var iter = mem.splitBackwardsScalar(u8, header_value, ','); - - const first = iter.first(); - const trimmed_first = mem.trim(u8, first, " "); - - var next: ?[]const u8 = first; - if (std.meta.stringToEnum(http.TransferEncoding, trimmed_first)) |transfer| { - if (res.transfer_encoding != .none) return error.HttpHeadersInvalid; // we already have a transfer encoding - res.transfer_encoding = transfer; - - next = iter.next(); - } - - if (next) |second| { - const trimmed_second = mem.trim(u8, second, " "); - - if (std.meta.stringToEnum(http.ContentEncoding, trimmed_second)) |transfer| { - if (res.transfer_compression != .identity) return error.HttpHeadersInvalid; // double compression is not supported - res.transfer_compression = transfer; - } else { - return error.HttpTransferEncodingUnsupported; - } - } - - if (iter.next()) |_| return error.HttpTransferEncodingUnsupported; - } else if (std.ascii.eqlIgnoreCase(header_name, "content-length")) { - const content_length = std.fmt.parseInt(u64, header_value, 10) catch return error.InvalidContentLength; - - if (res.content_length != null and res.content_length != content_length) return error.HttpHeadersInvalid; - - res.content_length = content_length; - } else if (std.ascii.eqlIgnoreCase(header_name, "content-encoding")) { - if (res.transfer_compression != .identity) return error.HttpHeadersInvalid; - - const trimmed = mem.trim(u8, header_value, " "); - - if (std.meta.stringToEnum(http.ContentEncoding, trimmed)) |ce| { - res.transfer_compression = ce; - } else { - return error.HttpTransferEncodingUnsupported; - } - } - } - return error.HttpHeadersInvalid; // missing empty line - } - - test parse { - const response_bytes = "HTTP/1.1 200 OK\r\n" ++ - "LOcation:url\r\n" ++ - "content-tYpe: text/plain\r\n" ++ - "content-disposition:attachment; filename=example.txt \r\n" ++ - "content-Length:10\r\n" ++ - "TRansfer-encoding:\tdeflate, chunked \r\n" ++ - "connectioN:\t keep-alive \r\n\r\n"; - - var header_buffer: [1024]u8 = undefined; - var res = Response{ - .status = undefined, - .reason = undefined, - .version = undefined, - .keep_alive = false, - .parser = proto.HeadersParser.init(&header_buffer), - }; - - @memcpy(header_buffer[0..response_bytes.len], response_bytes); - res.parser.header_bytes_len = response_bytes.len; - - try res.parse(response_bytes); - - try testing.expectEqual(.@"HTTP/1.1", res.version); - try testing.expectEqualStrings("OK", res.reason); - try testing.expectEqual(.ok, res.status); - - try testing.expectEqualStrings("url", res.location.?); - try testing.expectEqualStrings("text/plain", res.content_type.?); - try testing.expectEqualStrings("attachment; filename=example.txt", res.content_disposition.?); - - try testing.expectEqual(true, res.keep_alive); - try testing.expectEqual(10, res.content_length.?); - try testing.expectEqual(.chunked, res.transfer_encoding); - try testing.expectEqual(.deflate, res.transfer_compression); - } - - inline fn int64(array: *const [8]u8) u64 { - return @bitCast(array.*); - } - - fn parseInt3(text: *const [3]u8) u10 { - if (use_vectors) { - const nnn: @Vector(3, u8) = text.*; - const zero: @Vector(3, u8) = .{ '0', '0', '0' }; - const mmm: @Vector(3, u10) = .{ 100, 10, 1 }; - return @reduce(.Add, @as(@Vector(3, u10), nnn -% zero) *% mmm); - } - return std.fmt.parseInt(u10, text, 10) catch unreachable; - } - - test parseInt3 { - const expectEqual = testing.expectEqual; - try expectEqual(@as(u10, 0), parseInt3("000")); - try expectEqual(@as(u10, 418), parseInt3("418")); - try expectEqual(@as(u10, 999), parseInt3("999")); - } - - pub fn iterateHeaders(r: Response) http.HeaderIterator { - return http.HeaderIterator.init(r.parser.get()); - } - - test iterateHeaders { - const response_bytes = "HTTP/1.1 200 OK\r\n" ++ - "LOcation:url\r\n" ++ - "content-tYpe: text/plain\r\n" ++ - "content-disposition:attachment; filename=example.txt \r\n" ++ - "content-Length:10\r\n" ++ - "TRansfer-encoding:\tdeflate, chunked \r\n" ++ - "connectioN:\t keep-alive \r\n\r\n"; - - var header_buffer: [1024]u8 = undefined; - var res = Response{ - .status = undefined, - .reason = undefined, - .version = undefined, - .keep_alive = false, - .parser = proto.HeadersParser.init(&header_buffer), - }; - - @memcpy(header_buffer[0..response_bytes.len], response_bytes); - res.parser.header_bytes_len = response_bytes.len; - - var it = res.iterateHeaders(); - { - const header = it.next().?; - try testing.expectEqualStrings("LOcation", header.name); - try testing.expectEqualStrings("url", header.value); - try testing.expect(!it.is_trailer); - } - { - const header = it.next().?; - try testing.expectEqualStrings("content-tYpe", header.name); - try testing.expectEqualStrings("text/plain", header.value); - try testing.expect(!it.is_trailer); - } - { - const header = it.next().?; - try testing.expectEqualStrings("content-disposition", header.name); - try testing.expectEqualStrings("attachment; filename=example.txt", header.value); - try testing.expect(!it.is_trailer); - } - { - const header = it.next().?; - try testing.expectEqualStrings("content-Length", header.name); - try testing.expectEqualStrings("10", header.value); - try testing.expect(!it.is_trailer); - } - { - const header = it.next().?; - try testing.expectEqualStrings("TRansfer-encoding", header.name); - try testing.expectEqualStrings("deflate, chunked", header.value); - try testing.expect(!it.is_trailer); - } - { - const header = it.next().?; - try testing.expectEqualStrings("connectioN", header.name); - try testing.expectEqualStrings("keep-alive", header.value); - try testing.expect(!it.is_trailer); - } - try testing.expectEqual(null, it.next()); - } -}; - -/// A HTTP request that has been sent. -/// -/// Order of operations: open -> send[ -> write -> finish] -> wait -> read -pub const Request = struct { - uri: Uri = undefined, - client: *Client, - /// This is null when the connection is released. - connection: ?*Connection = null, - keep_alive: bool = undefined, - - method: http.Method = undefined, - version: http.Version = .@"HTTP/1.1", - transfer_encoding: RequestTransfer = undefined, - redirect_behavior: RedirectBehavior = undefined, - - /// Whether the request should handle a 100-continue response before sending the request body. - handle_continue: bool = undefined, - - /// The response associated with this request. - /// - /// This field is undefined until `wait` is called. - response: Response = undefined, - - /// Standard headers that have default, but overridable, behavior. - headers: Headers = undefined, - - /// These headers are kept including when following a redirect to a - /// different domain. - /// Externally-owned; must outlive the Request. - extra_headers: []const http.Header = undefined, - - /// These headers are stripped when following a redirect to a different - /// domain. - /// Externally-owned; must outlive the Request. - privileged_headers: []const http.Header = undefined, - - pub fn init(client: *Client) Request { - return .{ - .client = client, - }; - } - - pub const Headers = struct { - host: Value = .default, - authorization: Value = .default, - user_agent: Value = .default, - connection: Value = .default, - accept_encoding: Value = .default, - content_type: Value = .default, - - pub const Value = union(enum) { - default, - omit, - override: []const u8, - }; - }; - - /// Any value other than `not_allowed` or `unhandled` means that integer represents - /// how many remaining redirects are allowed. - pub const RedirectBehavior = enum(u16) { - /// The next redirect will cause an error. - not_allowed = 0, - /// Redirects are passed to the client to analyze the redirect response - /// directly. - unhandled = std.math.maxInt(u16), - _, - - pub fn subtractOne(rb: *RedirectBehavior) void { - switch (rb.*) { - .not_allowed => unreachable, - .unhandled => unreachable, - _ => rb.* = @enumFromInt(@intFromEnum(rb.*) - 1), - } - } - - pub fn remaining(rb: RedirectBehavior) u16 { - assert(rb != .unhandled); - return @intFromEnum(rb); - } - }; - - /// Frees all resources associated with the request. - pub fn deinit(req: *Request) void { - if (req.connection) |connection| { - if (!req.response.parser.done) { - // If the response wasn't fully read, then we need to close the connection. - connection.closing = true; - } - req.client.connection_pool.release(req.client.allocator, connection); - } - req.* = undefined; - } - - fn onRedirectSend(ctx: *Ctx, res: anyerror!void) !void { - res catch |err| return ctx.pop(err); - // go back on check headers - ctx.conn().async_fill(ctx, onResponseHeaders) catch |err| return ctx.pop(err); - } - - fn onRedirectConnect(ctx: *Ctx, res: anyerror!void) !void { - res catch |err| return ctx.pop(err); - // re-send request - ctx.req.prepareSend(.{}) catch |err| return ctx.pop(err); - ctx.req.connection.?.async_flush(ctx, onRedirectSend) catch |err| return ctx.pop(err); - } - - // async_redirect flow: - // connect -> setRequestConnection - // -> onRedirectConnect -> async_flush - // -> onRedirectSend -> async_fill - // -> go back on the wait workflow of the response - fn async_redirect(req: *Request, uri: Uri, ctx: *Ctx) !void { - try req.prepareRedirect(); - - var server_header = std.heap.FixedBufferAllocator.init(req.response.parser.header_bytes_buffer); - defer req.response.parser.header_bytes_buffer = server_header.buffer[server_header.end_index..]; - - const protocol, const valid_uri = try validateUri(uri, server_header.allocator()); - - const new_host = valid_uri.host.?.raw; - const prev_host = req.uri.host.?.raw; - const keep_privileged_headers = - std.ascii.eqlIgnoreCase(valid_uri.scheme, req.uri.scheme) and - std.ascii.endsWithIgnoreCase(new_host, prev_host) and - (new_host.len == prev_host.len or new_host[new_host.len - prev_host.len - 1] == '.'); - if (!keep_privileged_headers) { - // When redirecting to a different domain, strip privileged headers. - req.privileged_headers = &.{}; - } - - // create a new connection for the redirected URI - ctx.data.conn = try req.client.allocator.create(Connection); - ctx.data.conn.* = .{ - .stream = undefined, - .tls_client = undefined, - .protocol = undefined, - .host = undefined, - .port = undefined, - }; - req.uri = valid_uri; - return req.client.async_connect(new_host, uriPort(valid_uri, protocol), protocol, ctx, setRequestConnection); - } - - // This function must deallocate all resources associated with the request, - // or keep those which will be used. - // This needs to be kept in sync with deinit and request. - fn redirect(req: *Request, uri: Uri) !void { - try req.prepareRedirect(); - - var server_header = std.heap.FixedBufferAllocator.init(req.response.parser.header_bytes_buffer); - defer req.response.parser.header_bytes_buffer = server_header.buffer[server_header.end_index..]; - - const protocol, const valid_uri = try validateUri(uri, server_header.allocator()); - - const new_host = valid_uri.host.?.raw; - const prev_host = req.uri.host.?.raw; - const keep_privileged_headers = - std.ascii.eqlIgnoreCase(valid_uri.scheme, req.uri.scheme) and - std.ascii.endsWithIgnoreCase(new_host, prev_host) and - (new_host.len == prev_host.len or new_host[new_host.len - prev_host.len - 1] == '.'); - if (!keep_privileged_headers) { - // When redirecting to a different domain, strip privileged headers. - req.privileged_headers = &.{}; - } - - req.connection = try req.client.connect(new_host, uriPort(valid_uri, protocol), protocol); - req.uri = valid_uri; - } - fn prepareRedirect(req: *Request) !void { - assert(req.response.parser.done); - - req.client.connection_pool.release(req.client.allocator, req.connection.?); - req.connection = null; - - if (switch (req.response.status) { - .see_other => true, - .moved_permanently, .found => req.method == .POST, - else => false, - }) { - // A redirect to a GET must change the method and remove the body. - req.method = .GET; - req.transfer_encoding = .none; - req.headers.content_type = .omit; - } - - if (req.transfer_encoding != .none) { - // The request body has already been sent. The request is - // still in a valid state, but the redirect must be handled - // manually. - return error.RedirectRequiresResend; - } - - req.redirect_behavior.subtractOne(); - req.response.parser.reset(); - - req.response = .{ - .version = undefined, - .status = undefined, - .reason = undefined, - .keep_alive = undefined, - .parser = req.response.parser, - }; - } - - pub const SendError = Connection.WriteError || error{ InvalidContentLength, UnsupportedTransferEncoding }; - - pub fn async_send(req: *Request, ctx: *Ctx, comptime cbk: Cbk) !void { - try req.prepareSend(); - try req.connection.?.async_flush(ctx, cbk); - } - - /// Send the HTTP request headers to the server. - pub fn send(req: *Request) SendError!void { - try req.prepareSend(); - try req.connection.?.flush(); - } - - fn prepareSend(req: *Request) SendError!void { - if (!req.method.requestHasBody() and req.transfer_encoding != .none) - if (!req.method.requestHasBody() and req.transfer_encoding != .none) - return error.UnsupportedTransferEncoding; - - const connection = req.connection.?; - const w = connection.writer(); - - try req.method.write(w); - try w.writeByte(' '); - - if (req.method == .CONNECT) { - try req.uri.writeToStream(.{ .authority = true }, w); - } else { - try req.uri.writeToStream(.{ - .scheme = connection.proxied, - .authentication = connection.proxied, - .authority = connection.proxied, - .path = true, - .query = true, - }, w); - } - try w.writeByte(' '); - try w.writeAll(@tagName(req.version)); - try w.writeAll("\r\n"); - - if (try emitOverridableHeader("host: ", req.headers.host, w)) { - try w.writeAll("host: "); - try req.uri.writeToStream(.{ .authority = true }, w); - try w.writeAll("\r\n"); - } - - if (try emitOverridableHeader("authorization: ", req.headers.authorization, w)) { - if (req.uri.user != null or req.uri.password != null) { - try w.writeAll("authorization: "); - const authorization = try connection.allocWriteBuffer( - @intCast(basic_authorization.valueLengthFromUri(req.uri)), - ); - assert(basic_authorization.value(req.uri, authorization).len == authorization.len); - try w.writeAll("\r\n"); - } - } - - if (try emitOverridableHeader("user-agent: ", req.headers.user_agent, w)) { - try w.writeAll("user-agent: zig/"); - try w.writeAll(builtin.zig_version_string); - try w.writeAll(" (std.http)\r\n"); - } - - if (try emitOverridableHeader("connection: ", req.headers.connection, w)) { - if (req.keep_alive) { - try w.writeAll("connection: keep-alive\r\n"); - } else { - try w.writeAll("connection: close\r\n"); - } - } - - if (try emitOverridableHeader("accept-encoding: ", req.headers.accept_encoding, w)) { - // https://github.com/ziglang/zig/issues/18937 - //try w.writeAll("accept-encoding: gzip, deflate, zstd\r\n"); - try w.writeAll("accept-encoding: gzip, deflate\r\n"); - } - - switch (req.transfer_encoding) { - .chunked => try w.writeAll("transfer-encoding: chunked\r\n"), - .content_length => |len| try w.print("content-length: {d}\r\n", .{len}), - .none => {}, - } - - if (try emitOverridableHeader("content-type: ", req.headers.content_type, w)) { - // The default is to omit content-type if not provided because - // "application/octet-stream" is redundant. - } - - for (req.extra_headers) |header| { - assert(header.name.len != 0); - - try w.writeAll(header.name); - try w.writeAll(": "); - try w.writeAll(header.value); - try w.writeAll("\r\n"); - } - - if (connection.proxied) proxy: { - const proxy = switch (connection.protocol) { - .plain => req.client.http_proxy, - .tls => req.client.https_proxy, - } orelse break :proxy; - - const authorization = proxy.authorization orelse break :proxy; - try w.writeAll("proxy-authorization: "); - try w.writeAll(authorization); - try w.writeAll("\r\n"); - } - - try w.writeAll("\r\n"); - } - - /// Returns true if the default behavior is required, otherwise handles - /// writing (or not writing) the header. - fn emitOverridableHeader(prefix: []const u8, v: Headers.Value, w: anytype) !bool { - switch (v) { - .default => return true, - .omit => return false, - .override => |x| { - try w.writeAll(prefix); - try w.writeAll(x); - try w.writeAll("\r\n"); - return false; - }, - } - } - - const TransferReadError = Connection.ReadError || proto.HeadersParser.ReadError; - - const TransferReader = std.io.Reader(*Request, TransferReadError, transferRead); - - fn transferReader(req: *Request) TransferReader { - return .{ .context = req }; - } - - fn transferRead(req: *Request, buf: []u8) TransferReadError!usize { - if (req.response.parser.done) return 0; - - var index: usize = 0; - while (index == 0) { - const amt = try req.response.parser.read(req.connection.?, buf[index..], req.response.skip); - if (amt == 0 and req.response.parser.done) break; - index += amt; - } - - return index; - } - - pub const WaitError = RequestError || SendError || TransferReadError || - proto.HeadersParser.CheckCompleteHeadError || Response.ParseError || - error{ // TODO: file zig fmt issue for this bad indentation - TooManyHttpRedirects, - RedirectRequiresResend, - HttpRedirectLocationMissing, - HttpRedirectLocationInvalid, - CompressionInitializationFailed, - CompressionUnsupported, - }; - - pub fn async_wait(_: *Request, ctx: *Ctx, comptime cbk: Cbk) !void { - try ctx.push(cbk); - return ctx.conn().async_fill(ctx, onResponseHeaders); - } - - /// Waits for a response from the server and parses any headers that are sent. - /// This function will block until the final response is received. - /// - /// If handling redirects and the request has no payload, then this - /// function will automatically follow redirects. If a request payload is - /// present, then this function will error with - /// error.RedirectRequiresResend. - /// - /// Must be called after `send` and, if any data was written to the request - /// body, then also after `finish`. - pub fn wait(req: *Request) WaitError!void { - while (true) { - // This while loop is for handling redirects, which means the request's - // connection may be different than the previous iteration. However, it - // is still guaranteed to be non-null with each iteration of this loop. - const connection = req.connection.?; - - while (true) { // read headers - try connection.fill(); - - const nchecked = try req.response.parser.checkCompleteHead(connection.peek()); - connection.drop(@intCast(nchecked)); - - if (req.response.parser.state.isContent()) break; - } - - try req.response.parse(req.response.parser.get()); - - if (req.response.status == .@"continue") { - // We're done parsing the continue response; reset to prepare - // for the real response. - req.response.parser.done = true; - req.response.parser.reset(); - - if (req.handle_continue) - continue; - - return; // we're not handling the 100-continue - } - - // we're switching protocols, so this connection is no longer doing http - if (req.method == .CONNECT and req.response.status.class() == .success) { - connection.closing = false; - req.response.parser.done = true; - return; // the connection is not HTTP past this point - } - - connection.closing = !req.response.keep_alive or !req.keep_alive; - - // Any response to a HEAD request and any response with a 1xx - // (Informational), 204 (No Content), or 304 (Not Modified) status - // code is always terminated by the first empty line after the - // header fields, regardless of the header fields present in the - // message. - if (req.method == .HEAD or req.response.status.class() == .informational or - req.response.status == .no_content or req.response.status == .not_modified) - { - req.response.parser.done = true; - return; // The response is empty; no further setup or redirection is necessary. - } - - switch (req.response.transfer_encoding) { - .none => { - if (req.response.content_length) |cl| { - req.response.parser.next_chunk_length = cl; - - if (cl == 0) req.response.parser.done = true; - } else { - // read until the connection is closed - req.response.parser.next_chunk_length = std.math.maxInt(u64); - } - }, - .chunked => { - req.response.parser.next_chunk_length = 0; - req.response.parser.state = .chunk_head_size; - }, - } - - if (req.response.status.class() == .redirect and req.redirect_behavior != .unhandled) { - // skip the body of the redirect response, this will at least - // leave the connection in a known good state. - req.response.skip = true; - assert(try req.transferRead(&.{}) == 0); // we're skipping, no buffer is necessary - - if (req.redirect_behavior == .not_allowed) return error.TooManyHttpRedirects; - - const location = req.response.location orelse - return error.HttpRedirectLocationMissing; - - // This mutates the beginning of header_bytes_buffer and uses that - // for the backing memory of the returned Uri. - try req.redirect(req.uri.resolve_inplace( - location, - &req.response.parser.header_bytes_buffer, - ) catch |err| switch (err) { - error.UnexpectedCharacter, - error.InvalidFormat, - error.InvalidPort, - => return error.HttpRedirectLocationInvalid, - error.NoSpaceLeft => return error.HttpHeadersOversize, - }); - try req.send(); - } else { - req.response.skip = false; - if (!req.response.parser.done) { - switch (req.response.transfer_compression) { - .identity => req.response.compression = .none, - .compress, .@"x-compress" => return error.CompressionUnsupported, - .deflate => req.response.compression = .{ - .deflate = std.compress.zlib.decompressor(req.transferReader()), - }, - .gzip, .@"x-gzip" => req.response.compression = .{ - .gzip = std.compress.gzip.decompressor(req.transferReader()), - }, - // https://github.com/ziglang/zig/issues/18937 - //.zstd => req.response.compression = .{ - // .zstd = std.compress.zstd.decompressStream(req.client.allocator, req.transferReader()), - //}, - .zstd => return error.CompressionUnsupported, - } - } - - break; - } - } - } - - pub const ReadError = TransferReadError || proto.HeadersParser.CheckCompleteHeadError || - error{ DecompressionFailure, InvalidTrailers }; - - pub const Reader = std.io.Reader(*Request, ReadError, read); - - pub fn reader(req: *Request) Reader { - return .{ .context = req }; - } - - /// Reads data from the response body. Must be called after `wait`. - pub fn read(req: *Request, buffer: []u8) ReadError!usize { - const out_index = switch (req.response.compression) { - .deflate => |*deflate| deflate.read(buffer) catch return error.DecompressionFailure, - .gzip => |*gzip| gzip.read(buffer) catch return error.DecompressionFailure, - // https://github.com/ziglang/zig/issues/18937 - //.zstd => |*zstd| zstd.read(buffer) catch return error.DecompressionFailure, - else => try req.transferRead(buffer), - }; - if (out_index > 0) return out_index; - - while (!req.response.parser.state.isContent()) { // read trailing headers - try req.connection.?.fill(); - - const nchecked = try req.response.parser.checkCompleteHead(req.connection.?.peek()); - req.connection.?.drop(@intCast(nchecked)); - } - - return 0; - } - - /// Reads data from the response body. Must be called after `wait`. - pub fn readAll(req: *Request, buffer: []u8) !usize { - var index: usize = 0; - while (index < buffer.len) { - const amt = try read(req, buffer[index..]); - if (amt == 0) break; - index += amt; - } - return index; - } - - pub const WriteError = Connection.WriteError || error{ NotWriteable, MessageTooLong }; - - pub const Writer = std.io.Writer(*Request, WriteError, write); - - pub fn writer(req: *Request) Writer { - return .{ .context = req }; - } - - /// Write `bytes` to the server. The `transfer_encoding` field determines how data will be sent. - /// Must be called after `send` and before `finish`. - pub fn write(req: *Request, bytes: []const u8) WriteError!usize { - switch (req.transfer_encoding) { - .chunked => { - if (bytes.len > 0) { - try req.connection.?.writer().print("{x}\r\n", .{bytes.len}); - try req.connection.?.writer().writeAll(bytes); - try req.connection.?.writer().writeAll("\r\n"); - } - - return bytes.len; - }, - .content_length => |*len| { - if (len.* < bytes.len) return error.MessageTooLong; - - const amt = try req.connection.?.write(bytes); - len.* -= amt; - return amt; - }, - .none => return error.NotWriteable, - } - } - - /// Write `bytes` to the server. The `transfer_encoding` field determines how data will be sent. - /// Must be called after `send` and before `finish`. - pub fn writeAll(req: *Request, bytes: []const u8) WriteError!void { - var index: usize = 0; - while (index < bytes.len) { - index += try write(req, bytes[index..]); - } - } - - fn onWriteAll(ctx: *Ctx, res: anyerror!void) !void { - res catch |err| return ctx.pop(err); - switch (ctx.req.transfer_encoding) { - .chunked => unreachable, - .none => unreachable, - .content_length => |*len| { - len.* = 0; - }, - } - try ctx.pop({}); - } - - pub fn async_writeAll(req: *Request, buf: []const u8, ctx: *Ctx, comptime cbk: Cbk) !void { - switch (req.transfer_encoding) { - .chunked => return error.ChunkedNotImplemented, - .none => return error.NotWriteable, - .content_length => |len| { - try ctx.push(cbk); - if (len < buf.len) return error.MessageTooLong; - - try req.connection.?.async_writeAllDirect(buf, ctx, onWriteAll); - }, - } - } - - pub const FinishError = WriteError || error{MessageNotCompleted}; - - pub fn async_finish(req: *Request, ctx: *Ctx, comptime cbk: Cbk) !void { - try req.common_finish(); - req.connection.?.async_flush(ctx, cbk) catch |err| switch (err) { - error.WriteEmpty => return cbk(ctx, {}), - else => return cbk(ctx, err), - }; - } - - /// Finish the body of a request. This notifies the server that you have no more data to send. - /// Must be called after `send`. - pub fn finish(req: *Request) FinishError!void { - try req.common_finish(); - try req.connection.?.flush(); - } - - fn common_finish(req: *Request) FinishError!void { - switch (req.transfer_encoding) { - .chunked => try req.connection.?.writer().writeAll("0\r\n\r\n"), - .content_length => |len| if (len != 0) return error.MessageNotCompleted, - .none => {}, - } - } - - fn onResponseHeaders(ctx: *Ctx, res: anyerror!void) !void { - res catch |err| return ctx.pop(err); - const done = ctx.req.parseResponseHeaders() catch |err| return ctx.pop(err); - // if read of the headers is not done, continue - if (!done) return ctx.conn().async_fill(ctx, onResponseHeaders); - // if read of the headers is done, go read the reponse - return onResponse(ctx, {}); - } - - fn parseResponseHeaders(req: *Request) !bool { - const nchecked = try req.response.parser.checkCompleteHead(req.connection.?.peek()); - req.connection.?.drop(@intCast(nchecked)); - - if (req.response.parser.state.isContent()) return true; - return false; - } - - fn onResponse(ctx: *Ctx, res: anyerror!void) !void { - res catch |err| return ctx.pop(err); - const ret = ctx.req.parseResponse() catch |err| return ctx.pop(err); - if (ret.redirect_uri) |uri| { - ctx.req.async_redirect(uri, ctx) catch |err| return ctx.pop(err); - return; - } - // if read of the response is not done, continue - if (!ret.done) return ctx.conn().async_fill(ctx, onResponse); - // if read of the response is done, go execute the provided callback - return ctx.pop({}); - } - - const WaitRedirectsReturn = struct { - redirect_uri: ?Uri = null, - done: bool = true, - }; - - fn parseResponse(req: *Request) WaitError!WaitRedirectsReturn { - try req.response.parse(req.response.parser.get()); - - if (req.response.status == .@"continue") { - // We're done parsing the continue response; reset to prepare - // for the real response. - req.response.parser.done = true; - req.response.parser.reset(); - - if (req.handle_continue) return .{ .done = false }; - - return .{ .done = true }; - } - - // we're switching protocols, so this connection is no longer doing http - if (req.method == .CONNECT and req.response.status.class() == .success) { - req.connection.?.closing = false; - req.response.parser.done = true; - return .{ .done = true }; // the connection is not HTTP past this point - } - - req.connection.?.closing = !req.response.keep_alive or !req.keep_alive; - - // Any response to a HEAD request and any response with a 1xx - // (Informational), 204 (No Content), or 304 (Not Modified) status - // code is always terminated by the first empty line after the - // header fields, regardless of the header fields present in the - // message. - if (req.method == .HEAD or req.response.status.class() == .informational or - req.response.status == .no_content or req.response.status == .not_modified) - { - req.response.parser.done = true; - return .{ .done = true }; // The response is empty; no further setup or redirection is necessary. - } - - switch (req.response.transfer_encoding) { - .none => { - if (req.response.content_length) |cl| { - req.response.parser.next_chunk_length = cl; - - if (cl == 0) req.response.parser.done = true; - } else { - // read until the connection is closed - req.response.parser.next_chunk_length = std.math.maxInt(u64); - } - }, - .chunked => { - req.response.parser.next_chunk_length = 0; - req.response.parser.state = .chunk_head_size; - }, - } - - if (req.response.status.class() == .redirect and req.redirect_behavior != .unhandled) { - // skip the body of the redirect response, this will at least - // leave the connection in a known good state. - req.response.skip = true; - assert(try req.transferRead(&.{}) == 0); // we're skipping, no buffer is necessary - - if (req.redirect_behavior == .not_allowed) return error.TooManyHttpRedirects; - - const location = req.response.location orelse - return error.HttpRedirectLocationMissing; - - // This mutates the beginning of header_bytes_buffer and uses that - // for the backing memory of the returned Uri. - try req.redirect(req.uri.resolve_inplace( - location, - &req.response.parser.header_bytes_buffer, - ) catch |err| switch (err) { - error.UnexpectedCharacter, - error.InvalidFormat, - error.InvalidPort, - => return error.HttpRedirectLocationInvalid, - error.NoSpaceLeft => return error.HttpHeadersOversize, - }); - - return .{ .redirect_uri = req.uri }; - } else { - req.response.skip = false; - if (!req.response.parser.done) { - switch (req.response.transfer_compression) { - .identity => req.response.compression = .none, - .compress, .@"x-compress" => return error.CompressionUnsupported, - .deflate => req.response.compression = .{ - .deflate = std.compress.zlib.decompressor(req.transferReader()), - }, - .gzip, .@"x-gzip" => req.response.compression = .{ - .gzip = std.compress.gzip.decompressor(req.transferReader()), - }, - // https://github.com/ziglang/zig/issues/18937 - //.zstd => req.response.compression = .{ - // .zstd = std.compress.zstd.decompressStream(req.client.allocator, req.transferReader()), - //}, - .zstd => return error.CompressionUnsupported, - } - } - - return .{ .done = true }; - } - return .{ .done = false }; - } -}; - -pub const Proxy = struct { - protocol: Connection.Protocol, - host: []const u8, - authorization: ?[]const u8, - port: u16, - supports_connect: bool, -}; - -/// Release all associated resources with the client. -/// -/// All pending requests must be de-initialized and all active connections released -/// before calling this function. -pub fn deinit(client: *Client) void { - assert(client.connection_pool.used.first == null); // There are still active requests. - - client.connection_pool.deinit(client.allocator); - - if (!disable_tls) - client.ca_bundle.deinit(client.allocator); - - client.* = undefined; -} - -/// Populates `http_proxy` and `https_proxy` via standard proxy environment variables. -/// Asserts the client has no active connections. -/// Uses `arena` for a few small allocations that must outlive the client, or -/// at least until those fields are set to different values. -pub fn initDefaultProxies(client: *Client, arena: Allocator) !void { - // Prevent any new connections from being created. - client.connection_pool.mutex.lock(); - defer client.connection_pool.mutex.unlock(); - - assert(client.connection_pool.used.first == null); // There are active requests. - - if (client.http_proxy == null) { - client.http_proxy = try createProxyFromEnvVar(arena, &.{ - "http_proxy", "HTTP_PROXY", "all_proxy", "ALL_PROXY", - }); - } - - if (client.https_proxy == null) { - client.https_proxy = try createProxyFromEnvVar(arena, &.{ - "https_proxy", "HTTPS_PROXY", "all_proxy", "ALL_PROXY", - }); - } -} - -fn createProxyFromEnvVar(arena: Allocator, env_var_names: []const []const u8) !?*Proxy { - const content = for (env_var_names) |name| { - break std.process.getEnvVarOwned(arena, name) catch |err| switch (err) { - error.EnvironmentVariableNotFound => continue, - else => |e| return e, - }; - } else return null; - - const uri = Uri.parse(content) catch try Uri.parseAfterScheme("http", content); - const protocol, const valid_uri = validateUri(uri, arena) catch |err| switch (err) { - error.UnsupportedUriScheme => return null, - error.UriMissingHost => return error.HttpProxyMissingHost, - error.OutOfMemory => |e| return e, - }; - - const authorization: ?[]const u8 = if (valid_uri.user != null or valid_uri.password != null) a: { - const authorization = try arena.alloc(u8, basic_authorization.valueLengthFromUri(valid_uri)); - assert(basic_authorization.value(valid_uri, authorization).len == authorization.len); - break :a authorization; - } else null; - - const proxy = try arena.create(Proxy); - proxy.* = .{ - .protocol = protocol, - .host = valid_uri.host.?.raw, - .authorization = authorization, - .port = uriPort(valid_uri, protocol), - .supports_connect = true, - }; - return proxy; -} - -pub const basic_authorization = struct { - pub const max_user_len = 255; - pub const max_password_len = 255; - pub const max_value_len = valueLength(max_user_len, max_password_len); - - const prefix = "Basic "; - - pub fn valueLength(user_len: usize, password_len: usize) usize { - return prefix.len + std.base64.standard.Encoder.calcSize(user_len + 1 + password_len); - } - - pub fn valueLengthFromUri(uri: Uri) usize { - var stream = std.io.countingWriter(std.io.null_writer); - try stream.writer().print("{user}", .{uri.user orelse Uri.Component.empty}); - const user_len = stream.bytes_written; - stream.bytes_written = 0; - try stream.writer().print("{password}", .{uri.password orelse Uri.Component.empty}); - const password_len = stream.bytes_written; - return valueLength(@intCast(user_len), @intCast(password_len)); - } - - pub fn value(uri: Uri, out: []u8) []u8 { - var buf: [max_user_len + ":".len + max_password_len]u8 = undefined; - var stream = std.io.fixedBufferStream(&buf); - stream.writer().print("{user}", .{uri.user orelse Uri.Component.empty}) catch - unreachable; - assert(stream.pos <= max_user_len); - stream.writer().print(":{password}", .{uri.password orelse Uri.Component.empty}) catch - unreachable; - - @memcpy(out[0..prefix.len], prefix); - const base64 = std.base64.standard.Encoder.encode(out[prefix.len..], stream.getWritten()); - return out[0 .. prefix.len + base64.len]; - } -}; - -pub const ConnectTcpError = Allocator.Error || error{ ConnectionRefused, NetworkUnreachable, ConnectionTimedOut, ConnectionResetByPeer, TemporaryNameServerFailure, NameServerFailure, UnknownHostName, HostLacksNetworkAddresses, UnexpectedConnectFailure, TlsInitializationFailed }; - -// requires ctx.data.stream to be set -fn setConnection(ctx: *Ctx, res: anyerror!void) !void { - - // check stream - errdefer ctx.data.conn.stream.close(); - res catch |e| { - // ctx.data.conn.stream.close(); is it needed with errdefer? - switch (e) { - error.ConnectionRefused, - error.NetworkUnreachable, - error.ConnectionTimedOut, - error.ConnectionResetByPeer, - error.TemporaryNameServerFailure, - error.NameServerFailure, - error.UnknownHostName, - error.HostLacksNetworkAddresses, - => return ctx.pop(e), - else => return ctx.pop(error.UnexpectedConnectFailure), - } - }; - - if (ctx.data.conn.protocol == .tls) { - if (disable_tls) unreachable; - - ctx.data.conn.tls_client = try ctx.alloc().create(tls23.Connection(net.Stream)); - errdefer ctx.alloc().destroy(ctx.data.conn.tls_client); - - // TODO tls23.client does an handshake to pick a cipher. - ctx.data.conn.tls_client.* = tls23.client(ctx.data.conn.stream, .{ - .host = ctx.data.conn.host, - .root_ca = .{ .bundle = ctx.req.client.ca_bundle }, - }) catch return error.TlsInitializationFailed; - } - - // add connection node in pool - const node = ctx.req.client.allocator.create(ConnectionPool.Node) catch |e| return ctx.pop(e); - errdefer ctx.req.client.allocator.destroy(node); - // NOTE we can not use the ctx.data.conn pointer as a node connection data, - // we need to copy it's value and use this reference for the connection - node.* = .{ - .data = .{ - .stream = ctx.data.conn.stream, - .tls_client = ctx.data.conn.tls_client, - .protocol = ctx.data.conn.protocol, - .host = ctx.data.conn.host, - .port = ctx.data.conn.port, - }, - }; - // remove old pointer, now useless - const old_conn = ctx.data.conn; - defer ctx.req.client.allocator.destroy(old_conn); - - ctx.req.client.connection_pool.addUsed(node); - ctx.data.conn = &node.data; - - return ctx.pop({}); -} - -/// Connect to `host:port` using the specified protocol. This will reuse a connection if one is already open. -/// -/// This function is threadsafe. -pub fn connectTcp(client: *Client, host: []const u8, port: u16, protocol: Connection.Protocol) ConnectTcpError!*Connection { - if (client.connection_pool.findConnection(.{ - .host = host, - .port = port, - .protocol = protocol, - })) |node| return node; - - if (disable_tls and protocol == .tls) - return error.TlsInitializationFailed; - - const conn = try client.allocator.create(ConnectionPool.Node); - errdefer client.allocator.destroy(conn); - conn.* = .{ .data = undefined }; - - const stream = net.tcpConnectToHost(client.allocator, host, port) catch |err| switch (err) { - error.ConnectionRefused => return error.ConnectionRefused, - error.NetworkUnreachable => return error.NetworkUnreachable, - error.ConnectionTimedOut => return error.ConnectionTimedOut, - error.ConnectionResetByPeer => return error.ConnectionResetByPeer, - error.TemporaryNameServerFailure => return error.TemporaryNameServerFailure, - error.NameServerFailure => return error.NameServerFailure, - error.UnknownHostName => return error.UnknownHostName, - error.HostLacksNetworkAddresses => return error.HostLacksNetworkAddresses, - else => return error.UnexpectedConnectFailure, - }; - errdefer stream.close(); - - conn.data = .{ - .stream = stream, - .tls_client = undefined, - - .protocol = protocol, - .host = try client.allocator.dupe(u8, host), - .port = port, - }; - errdefer client.allocator.free(conn.data.host); - - if (protocol == .tls) { - if (disable_tls) unreachable; - - conn.data.tls_client = try client.allocator.create(tls23.Connection(net.Stream)); - errdefer client.allocator.destroy(conn.data.tls_client); - - // TODO tls23.client does an handshake to pick a cipher. - conn.data.tls_client.* = tls23.client(stream, .{ - .host = host, - .root_ca = .{ .bundle = client.ca_bundle }, - }) catch return error.TlsInitializationFailed; - } - - client.connection_pool.addUsed(conn); - - return &conn.data; -} - -pub fn async_connectTcp( - client: *Client, - host: []const u8, - port: u16, - protocol: Connection.Protocol, - ctx: *Ctx, - comptime cbk: Cbk, -) !void { - try ctx.push(cbk); - if (ctx.req.client.connection_pool.findConnection(.{ - .host = host, - .port = port, - .protocol = protocol, - })) |conn| { - ctx.data.conn = conn; - ctx.req.connection = conn; - return ctx.pop({}); - } - - if (disable_tls and protocol == .tls) - return error.TlsInitializationFailed; - - return net.async_tcpConnectToHost( - client.allocator, - host, - port, - ctx, - setConnection, - ); -} - -pub const ConnectUnixError = Allocator.Error || std.posix.SocketError || error{NameTooLong} || std.posix.ConnectError; - -// Connect to `path` as a unix domain socket. This will reuse a connection if one is already open. -// -// This function is threadsafe. -// pub fn connectUnix(client: *Client, path: []const u8) ConnectUnixError!*Connection { -// if (client.connection_pool.findConnection(.{ -// .host = path, -// .port = 0, -// .protocol = .plain, -// })) |node| -// return node; - -// const conn = try client.allocator.create(ConnectionPool.Node); -// errdefer client.allocator.destroy(conn); -// conn.* = .{ .data = undefined }; - -// const stream = try std.net.connectUnixSocket(path); -// errdefer stream.close(); - -// conn.data = .{ -// .stream = stream, -// .tls_client = undefined, -// .protocol = .plain, - -// .host = try client.allocator.dupe(u8, path), -// .port = 0, -// }; -// errdefer client.allocator.free(conn.data.host); - -// client.connection_pool.addUsed(conn); - -// return &conn.data; -//} - -/// Connect to `tunnel_host:tunnel_port` using the specified proxy with HTTP -/// CONNECT. This will reuse a connection if one is already open. -/// -/// This function is threadsafe. -pub fn connectTunnel( - client: *Client, - proxy: *Proxy, - tunnel_host: []const u8, - tunnel_port: u16, -) !*Connection { - if (!proxy.supports_connect) return error.TunnelNotSupported; - - if (client.connection_pool.findConnection(.{ - .host = tunnel_host, - .port = tunnel_port, - .protocol = proxy.protocol, - })) |node| - return node; - - var maybe_valid = false; - (tunnel: { - const conn = try client.connectTcp(proxy.host, proxy.port, proxy.protocol); - errdefer { - conn.closing = true; - client.connection_pool.release(client.allocator, conn); - } - - var buffer: [8096]u8 = undefined; - var req = client.open(.CONNECT, .{ - .scheme = "http", - .host = .{ .raw = tunnel_host }, - .port = tunnel_port, - }, .{ - .redirect_behavior = .unhandled, - .connection = conn, - .server_header_buffer = &buffer, - }) catch |err| { - break :tunnel err; - }; - defer req.deinit(); - - req.send() catch |err| break :tunnel err; - req.wait() catch |err| break :tunnel err; - - if (req.response.status.class() == .server_error) { - maybe_valid = true; - break :tunnel error.ServerError; - } - - if (req.response.status != .ok) break :tunnel error.ConnectionRefused; - - // this connection is now a tunnel, so we can't use it for anything else, it will only be released when the client is de-initialized. - req.connection = null; - - client.allocator.free(conn.host); - conn.host = try client.allocator.dupe(u8, tunnel_host); - errdefer client.allocator.free(conn.host); - - conn.port = tunnel_port; - conn.closing = false; - - return conn; - }) catch { - // something went wrong with the tunnel - proxy.supports_connect = maybe_valid; - return error.TunnelNotSupported; - }; -} - -// Prevents a dependency loop in open() -const ConnectErrorPartial = ConnectTcpError || error{ UnsupportedUriScheme, ConnectionRefused }; -pub const ConnectError = ConnectErrorPartial || RequestError; - -fn onConnectProxy(ctx: *Ctx, res: anyerror!void) anyerror!void { - res catch |e| { - ctx.data.conn.closing = true; - ctx.req.client.connection_pool.release(ctx.req.client.allocator, ctx.data.conn); - return ctx.pop(e); - }; - ctx.data.conn.proxied = true; - return ctx.pop({}); -} - -/// Connect to `host:port` using the specified protocol. This will reuse a -/// connection if one is already open. -/// If a proxy is configured for the client, then the proxy will be used to -/// connect to the host. -/// -/// This function is threadsafe. -pub fn connect( - client: *Client, - host: []const u8, - port: u16, - protocol: Connection.Protocol, -) ConnectError!*Connection { - const proxy = switch (protocol) { - .plain => client.http_proxy, - .tls => client.https_proxy, - } orelse return client.connectTcp(host, port, protocol); - - // Prevent proxying through itself. - if (std.ascii.eqlIgnoreCase(proxy.host, host) and - proxy.port == port and proxy.protocol == protocol) - { - return client.connectTcp(host, port, protocol); - } - - if (proxy.supports_connect) tunnel: { - return connectTunnel(client, proxy, host, port) catch |err| switch (err) { - error.TunnelNotSupported => break :tunnel, - else => |e| return e, - }; - } - - // fall back to using the proxy as a normal http proxy - const conn = try client.connectTcp(proxy.host, proxy.port, proxy.protocol); - errdefer { - conn.closing = true; - client.connection_pool.release(conn); - } - - conn.proxied = true; - return conn; -} - -pub fn async_connect( - client: *Client, - host: []const u8, - port: u16, - protocol: Connection.Protocol, - ctx: *Ctx, - comptime cbk: Cbk, -) !void { - const proxy = switch (protocol) { - .plain => client.http_proxy, - .tls => client.https_proxy, - } orelse return client.async_connectTcp(host, port, protocol, ctx, cbk); - - // Prevent proxying through itself. - if (std.ascii.eqlIgnoreCase(proxy.host, host) and - proxy.port == port and proxy.protocol == protocol) - { - return client.async_connectTcp(host, port, protocol, ctx, cbk); - } - - // TODO: enable async_connectTunnel - // if (proxy.supports_connect) tunnel: { - // return connectTunnel(client, proxy, host, port) catch |err| switch (err) { - // error.TunnelNotSupported => break :tunnel, - // else => |e| return e, - // }; - // } - - // fall back to using the proxy as a normal http proxy - try ctx.push(cbk); - return client.async_connectTcp(proxy.host, proxy.port, proxy.protocol, ctx, onConnectProxy); -} - -pub const RequestError = ConnectTcpError || ConnectErrorPartial || Request.SendError || - std.fmt.ParseIntError || Connection.WriteError || - error{ // TODO: file a zig fmt issue for this bad indentation - UnsupportedUriScheme, - UriMissingHost, - - CertificateBundleLoadFailure, - UnsupportedTransferEncoding, -}; - -pub const RequestOptions = struct { - version: http.Version = .@"HTTP/1.1", - - /// Automatically ignore 100 Continue responses. This assumes you don't - /// care, and will have sent the body before you wait for the response. - /// - /// If this is not the case AND you know the server will send a 100 - /// Continue, set this to false and wait for a response before sending the - /// body. If you wait AND the server does not send a 100 Continue before - /// you finish the request, then the request *will* deadlock. - handle_continue: bool = true, - - /// If false, close the connection after the one request. If true, - /// participate in the client connection pool. - keep_alive: bool = true, - - /// This field specifies whether to automatically follow redirects, and if - /// so, how many redirects to follow before returning an error. - /// - /// This will only follow redirects for repeatable requests (ie. with no - /// payload or the server has acknowledged the payload). - redirect_behavior: Request.RedirectBehavior = @enumFromInt(3), - - /// Externally-owned memory used to store the server's entire HTTP header. - /// `error.HttpHeadersOversize` is returned from read() when a - /// client sends too many bytes of HTTP headers. - server_header_buffer: []u8, - - /// Must be an already acquired connection. - connection: ?*Connection = null, - - /// Standard headers that have default, but overridable, behavior. - headers: Request.Headers = .{}, - /// These headers are kept including when following a redirect to a - /// different domain. - /// Externally-owned; must outlive the Request. - extra_headers: []const http.Header = &.{}, - /// These headers are stripped when following a redirect to a different - /// domain. - /// Externally-owned; must outlive the Request. - privileged_headers: []const http.Header = &.{}, -}; - -const protocol_map = std.StaticStringMap(Connection.Protocol).initComptime(.{ - .{ "http", .plain }, - .{ "ws", .plain }, - .{ "https", .tls }, - .{ "wss", .tls }, -}); - -fn validateUri(uri: Uri, arena: Allocator) !struct { Connection.Protocol, Uri } { - const protocol = protocol_map.get(uri.scheme) orelse return error.UnsupportedUriScheme; - var valid_uri = uri; - // The host is always going to be needed as a raw string for hostname resolution anyway. - valid_uri.host = .{ - .raw = try (uri.host orelse return error.UriMissingHost).toRawMaybeAlloc(arena), - }; - return .{ protocol, valid_uri }; -} - -fn uriPort(uri: Uri, protocol: Connection.Protocol) u16 { - return uri.port orelse switch (protocol) { - .plain => 80, - .tls => 443, - }; -} - -pub fn create( - client: *Client, - method: http.Method, - uri: Uri, - options: RequestOptions, -) RequestError!Request { - if (std.debug.runtime_safety) { - for (options.extra_headers) |header| { - assert(header.name.len != 0); - assert(std.mem.indexOfScalar(u8, header.name, ':') == null); - assert(std.mem.indexOfPosLinear(u8, header.name, 0, "\r\n") == null); - assert(std.mem.indexOfPosLinear(u8, header.value, 0, "\r\n") == null); - } - for (options.privileged_headers) |header| { - assert(header.name.len != 0); - assert(std.mem.indexOfPosLinear(u8, header.name, 0, "\r\n") == null); - assert(std.mem.indexOfPosLinear(u8, header.value, 0, "\r\n") == null); - } - } - - var server_header = std.heap.FixedBufferAllocator.init(options.server_header_buffer); - _, const valid_uri = try validateUri(uri, server_header.allocator()); - - var req: Request = .{ - .uri = valid_uri, - .client = client, - .keep_alive = options.keep_alive, - .method = method, - .version = options.version, - .transfer_encoding = .none, - .redirect_behavior = options.redirect_behavior, - .handle_continue = options.handle_continue, - .response = .{ - .version = undefined, - .status = undefined, - .reason = undefined, - .keep_alive = undefined, - .parser = proto.HeadersParser.init(server_header.buffer[server_header.end_index..]), - }, - .headers = options.headers, - .extra_headers = options.extra_headers, - .privileged_headers = options.privileged_headers, - }; - errdefer req.deinit(); - - return req; -} - -/// Open a connection to the host specified by `uri` and prepare to send a HTTP request. -/// -/// `uri` must remain alive during the entire request. -/// -/// The caller is responsible for calling `deinit()` on the `Request`. -/// This function is threadsafe. -/// -/// Asserts that "\r\n" does not occur in any header name or value. -pub fn open( - client: *Client, - method: http.Method, - uri: Uri, - options: RequestOptions, -) RequestError!Request { - if (std.debug.runtime_safety) { - for (options.extra_headers) |header| { - assert(header.name.len != 0); - assert(std.mem.indexOfScalar(u8, header.name, ':') == null); - assert(std.mem.indexOfPosLinear(u8, header.name, 0, "\r\n") == null); - assert(std.mem.indexOfPosLinear(u8, header.value, 0, "\r\n") == null); - } - for (options.privileged_headers) |header| { - assert(header.name.len != 0); - assert(std.mem.indexOfPosLinear(u8, header.name, 0, "\r\n") == null); - assert(std.mem.indexOfPosLinear(u8, header.value, 0, "\r\n") == null); - } - } - - var server_header = std.heap.FixedBufferAllocator.init(options.server_header_buffer); - const protocol, const valid_uri = try validateUri(uri, server_header.allocator()); - - if (protocol == .tls and @atomicLoad(bool, &client.next_https_rescan_certs, .acquire)) { - if (disable_tls) unreachable; - - client.ca_bundle_mutex.lock(); - defer client.ca_bundle_mutex.unlock(); - - if (client.next_https_rescan_certs) { - client.ca_bundle.rescan(client.allocator) catch - return error.CertificateBundleLoadFailure; - @atomicStore(bool, &client.next_https_rescan_certs, false, .release); - } - } - - const conn = options.connection orelse - try client.connect(valid_uri.host.?.raw, uriPort(valid_uri, protocol), protocol); - - var req: Request = .{ - .uri = valid_uri, - .client = client, - .connection = conn, - .keep_alive = options.keep_alive, - .method = method, - .version = options.version, - .transfer_encoding = .none, - .redirect_behavior = options.redirect_behavior, - .handle_continue = options.handle_continue, - .response = .{ - .version = undefined, - .status = undefined, - .reason = undefined, - .keep_alive = undefined, - .parser = proto.HeadersParser.init(server_header.buffer[server_header.end_index..]), - }, - .headers = options.headers, - .extra_headers = options.extra_headers, - .privileged_headers = options.privileged_headers, - }; - errdefer req.deinit(); - - return req; -} - -pub fn async_open( - client: *Client, - method: http.Method, - uri: Uri, - options: RequestOptions, - ctx: *Ctx, - comptime cbk: Cbk, -) !void { - if (std.debug.runtime_safety) { - for (options.extra_headers) |header| { - assert(header.name.len != 0); - assert(std.mem.indexOfScalar(u8, header.name, ':') == null); - assert(std.mem.indexOfPosLinear(u8, header.name, 0, "\r\n") == null); - assert(std.mem.indexOfPosLinear(u8, header.value, 0, "\r\n") == null); - } - for (options.privileged_headers) |header| { - assert(header.name.len != 0); - assert(std.mem.indexOfPosLinear(u8, header.name, 0, "\r\n") == null); - assert(std.mem.indexOfPosLinear(u8, header.value, 0, "\r\n") == null); - } - } - - var server_header = std.heap.FixedBufferAllocator.init(options.server_header_buffer); - const protocol, const valid_uri = try validateUri(uri, server_header.allocator()); - - if (protocol == .tls and @atomicLoad(bool, &client.next_https_rescan_certs, .acquire)) { - if (disable_tls) unreachable; - - client.ca_bundle_mutex.lock(); - defer client.ca_bundle_mutex.unlock(); - - if (client.next_https_rescan_certs) { - client.ca_bundle.rescan(client.allocator) catch return error.CertificateBundleLoadFailure; - @atomicStore(bool, &client.next_https_rescan_certs, false, .release); - } - } - - // add fields to request - ctx.req.uri = valid_uri; - ctx.req.keep_alive = options.keep_alive; - ctx.req.method = method; - ctx.req.transfer_encoding = .none; - ctx.req.redirect_behavior = options.redirect_behavior; - ctx.req.handle_continue = options.handle_continue; - ctx.req.headers = options.headers; - ctx.req.extra_headers = options.extra_headers; - ctx.req.privileged_headers = options.privileged_headers; - ctx.req.response = .{ - .version = undefined, - .status = undefined, - .reason = undefined, - .keep_alive = undefined, - .parser = proto.HeadersParser.init(server_header.buffer[server_header.end_index..]), - }; - - // we already have the connection, - // set it and call directly the callback - if (options.connection) |conn| { - ctx.req.connection = conn; - return cbk(ctx, {}); - } - - // push callback function - try ctx.push(cbk); - - const host = valid_uri.host orelse return error.UriMissingHost; - const port = uriPort(valid_uri, protocol); - - // add fields to connection - ctx.data.conn.protocol = protocol; - ctx.data.conn.host = try client.allocator.dupe(u8, host.raw); - ctx.data.conn.port = port; - - return client.async_connect(host.raw, port, protocol, ctx, setRequestConnection); -} - -pub const FetchOptions = struct { - server_header_buffer: ?[]u8 = null, - redirect_behavior: ?Request.RedirectBehavior = null, - - /// If the server sends a body, it will be appended to this ArrayList. - /// `max_append_size` provides an upper limit for how much they can grow. - response_storage: ResponseStorage = .ignore, - max_append_size: ?usize = null, - - location: Location, - method: ?http.Method = null, - payload: ?[]const u8 = null, - raw_uri: bool = false, - keep_alive: bool = true, - - /// Standard headers that have default, but overridable, behavior. - headers: Request.Headers = .{}, - /// These headers are kept including when following a redirect to a - /// different domain. - /// Externally-owned; must outlive the Request. - extra_headers: []const http.Header = &.{}, - /// These headers are stripped when following a redirect to a different - /// domain. - /// Externally-owned; must outlive the Request. - privileged_headers: []const http.Header = &.{}, - - pub const Location = union(enum) { - url: []const u8, - uri: Uri, - }; - - pub const ResponseStorage = union(enum) { - ignore, - /// Only the existing capacity will be used. - static: *std.ArrayListUnmanaged(u8), - dynamic: *std.ArrayList(u8), - }; -}; - -pub const FetchResult = struct { - status: http.Status, -}; - -// TODO: enable async_fetch -/// Perform a one-shot HTTP request with the provided options. -/// -/// This function is threadsafe. -pub fn fetch(client: *Client, options: FetchOptions) !FetchResult { - const uri = switch (options.location) { - .url => |u| try Uri.parse(u), - .uri => |u| u, - }; - var server_header_buffer: [16 * 1024]u8 = undefined; - - const method: http.Method = options.method orelse - if (options.payload != null) .POST else .GET; - - var req = try open(client, method, uri, .{ - .server_header_buffer = options.server_header_buffer orelse &server_header_buffer, - .redirect_behavior = options.redirect_behavior orelse - if (options.payload == null) @enumFromInt(3) else .unhandled, - .headers = options.headers, - .extra_headers = options.extra_headers, - .privileged_headers = options.privileged_headers, - .keep_alive = options.keep_alive, - }); - defer req.deinit(); - - if (options.payload) |payload| req.transfer_encoding = .{ .content_length = payload.len }; - - try req.send(); - - if (options.payload) |payload| try req.writeAll(payload); - - try req.finish(); - try req.wait(); - - switch (options.response_storage) { - .ignore => { - // Take advantage of request internals to discard the response body - // and make the connection available for another request. - req.response.skip = true; - assert(try req.transferRead(&.{}) == 0); // No buffer is necessary when skipping. - }, - .dynamic => |list| { - const max_append_size = options.max_append_size orelse 2 * 1024 * 1024; - try req.reader().readAllArrayList(list, max_append_size); - }, - .static => |list| { - const buf = b: { - const buf = list.unusedCapacitySlice(); - if (options.max_append_size) |len| { - if (len < buf.len) break :b buf[0..len]; - } - break :b buf; - }; - list.items.len += try req.reader().readAll(buf); - }, - } - - return .{ - .status = req.response.status, - }; -} - -pub const Cbk = fn (ctx: *Ctx, res: anyerror!void) anyerror!void; - -pub const Ctx = struct { - const Stack = GenericStack(Cbk); - - // temporary Data we need to store on the heap - // because of the callback execution model - const Data = struct { - list: *std.net.AddressList = undefined, - addr_current: usize = undefined, - socket: std.posix.socket_t = undefined, - - // TODO: we could remove this field as it is already set in ctx.req - // but we do not know for now what will be the impact to set those directly - // on the request, especially in case of error/cancellation - conn: *Connection, - }; - - req: *Request = undefined, - - userData: *anyopaque = undefined, - - loop: *Loop, - data: Data, - stack: ?*Stack = null, - err: ?anyerror = null, - - _buffer: ?[]const u8 = null, - _len: ?usize = null, - - _iovecs: []std.posix.iovec = undefined, - - // TLS readvAtLeast - // _off_i: usize = 0, - // _vec_i: usize = 0, - // _tls_len: usize = 0, - - // TLS readv - _vp: VecPut = undefined, - // _tls_read_buf contains the next decrypted buffer - _tls_read_buf: ?[]u8 = undefined, - _tls_read_content_type: tls23.proto.ContentType = undefined, - - // _tls_read_record contains the crypted record - _tls_read_record: ?tls23.record.Record = null, - - // TLS writeAll - _tls_write_bytes: []const u8 = undefined, - _tls_write_index: usize = 0, - _tls_write_buf: [cipher.max_ciphertext_record_len]u8 = undefined, - - pub fn init(loop: *Loop, req: *Request) !Ctx { - const connection = try req.client.allocator.create(Connection); - connection.* = .{ - .stream = undefined, - .tls_client = undefined, - .protocol = undefined, - .host = undefined, - .port = undefined, - }; - return .{ - .req = req, - .loop = loop, - .data = .{ .conn = connection }, - }; - } - - pub fn setErr(self: *Ctx, err: anyerror) void { - self.err = err; - } - - pub fn push(self: *Ctx, comptime func: Stack.Fn) !void { - if (self.stack) |stack| { - return try stack.push(self.alloc(), func); - } - self.stack = try Stack.init(self.alloc(), func); - } - - pub fn pop(self: *Ctx, res: anyerror!void) !void { - if (self.stack) |stack| { - const allocator = self.alloc(); - const func = stack.pop(allocator, null); - - defer { - if (stack.next == null) { - allocator.destroy(stack); - self.stack = null; - } - } - - return @call(.auto, func, .{ self, res }); - } - unreachable; - } - - pub fn deinit(self: Ctx) void { - if (self.stack) |stack| { - stack.deinit(self.alloc(), null); - } - } - - // not sure about those - - pub fn len(self: Ctx) usize { - if (self._len == null) unreachable; - return self._len.?; - } - - pub fn setLen(self: *Ctx, nb: ?usize) void { - self._len = nb; - } - - pub fn buf(self: Ctx) []const u8 { - if (self._buffer == null) unreachable; - return self._buffer.?; - } - - pub fn setBuf(self: *Ctx, bytes: ?[]const u8) void { - self._buffer = bytes; - } - - // ctx Request aliases - - pub fn alloc(self: Ctx) std.mem.Allocator { - return self.req.client.allocator; - } - - pub fn conn(self: Ctx) *Connection { - return self.req.connection.?; - } - - pub fn stream(self: Ctx) net.Stream { - return self.conn().stream; - } -}; - -// requires ctx.data.conn to be set -fn setRequestConnection(ctx: *Ctx, res: anyerror!void) anyerror!void { - res catch |e| return ctx.pop(e); - - ctx.req.connection = ctx.data.conn; - return ctx.pop({}); -} diff --git a/src/http/async/std/http/Server.zig b/src/http/async/std/http/Server.zig deleted file mode 100644 index 38d3f133..00000000 --- a/src/http/async/std/http/Server.zig +++ /dev/null @@ -1,1148 +0,0 @@ -//! Blocking HTTP server implementation. -//! Handles a single connection's lifecycle. - -connection: net.Server.Connection, -/// Keeps track of whether the Server is ready to accept a new request on the -/// same connection, and makes invalid API usage cause assertion failures -/// rather than HTTP protocol violations. -state: State, -/// User-provided buffer that must outlive this Server. -/// Used to store the client's entire HTTP header. -read_buffer: []u8, -/// Amount of available data inside read_buffer. -read_buffer_len: usize, -/// Index into `read_buffer` of the first byte of the next HTTP request. -next_request_start: usize, - -pub const State = enum { - /// The connection is available to be used for the first time, or reused. - ready, - /// An error occurred in `receiveHead`. - receiving_head, - /// A Request object has been obtained and from there a Response can be - /// opened. - received_head, - /// The client is uploading something to this Server. - receiving_body, - /// The connection is eligible for another HTTP request, however the client - /// and server did not negotiate a persistent connection. - closing, -}; - -/// Initialize an HTTP server that can respond to multiple requests on the same -/// connection. -/// The returned `Server` is ready for `receiveHead` to be called. -pub fn init(connection: net.Server.Connection, read_buffer: []u8) Server { - return .{ - .connection = connection, - .state = .ready, - .read_buffer = read_buffer, - .read_buffer_len = 0, - .next_request_start = 0, - }; -} - -pub const ReceiveHeadError = error{ - /// Client sent too many bytes of HTTP headers. - /// The HTTP specification suggests to respond with a 431 status code - /// before closing the connection. - HttpHeadersOversize, - /// Client sent headers that did not conform to the HTTP protocol. - HttpHeadersInvalid, - /// A low level I/O error occurred trying to read the headers. - HttpHeadersUnreadable, - /// Partial HTTP request was received but the connection was closed before - /// fully receiving the headers. - HttpRequestTruncated, - /// The client sent 0 bytes of headers before closing the stream. - /// In other words, a keep-alive connection was finally closed. - HttpConnectionClosing, -}; - -/// The header bytes reference the read buffer that Server was initialized with -/// and remain alive until the next call to receiveHead. -pub fn receiveHead(s: *Server) ReceiveHeadError!Request { - assert(s.state == .ready); - s.state = .received_head; - errdefer s.state = .receiving_head; - - // In case of a reused connection, move the next request's bytes to the - // beginning of the buffer. - if (s.next_request_start > 0) { - if (s.read_buffer_len > s.next_request_start) { - rebase(s, 0); - } else { - s.read_buffer_len = 0; - } - } - - var hp: http.HeadParser = .{}; - - if (s.read_buffer_len > 0) { - const bytes = s.read_buffer[0..s.read_buffer_len]; - const end = hp.feed(bytes); - if (hp.state == .finished) - return finishReceivingHead(s, end); - } - - while (true) { - const buf = s.read_buffer[s.read_buffer_len..]; - if (buf.len == 0) - return error.HttpHeadersOversize; - const read_n = s.connection.stream.read(buf) catch - return error.HttpHeadersUnreadable; - if (read_n == 0) { - if (s.read_buffer_len > 0) { - return error.HttpRequestTruncated; - } else { - return error.HttpConnectionClosing; - } - } - s.read_buffer_len += read_n; - const bytes = buf[0..read_n]; - const end = hp.feed(bytes); - if (hp.state == .finished) - return finishReceivingHead(s, s.read_buffer_len - bytes.len + end); - } -} - -fn finishReceivingHead(s: *Server, head_end: usize) ReceiveHeadError!Request { - return .{ - .server = s, - .head_end = head_end, - .head = Request.Head.parse(s.read_buffer[0..head_end]) catch - return error.HttpHeadersInvalid, - .reader_state = undefined, - }; -} - -pub const Request = struct { - server: *Server, - /// Index into Server's read_buffer. - head_end: usize, - head: Head, - reader_state: union { - remaining_content_length: u64, - chunk_parser: http.ChunkParser, - }, - - pub const Compression = union(enum) { - pub const DeflateDecompressor = std.compress.zlib.Decompressor(std.io.AnyReader); - pub const GzipDecompressor = std.compress.gzip.Decompressor(std.io.AnyReader); - pub const ZstdDecompressor = std.compress.zstd.Decompressor(std.io.AnyReader); - - deflate: DeflateDecompressor, - gzip: GzipDecompressor, - zstd: ZstdDecompressor, - none: void, - }; - - pub const Head = struct { - method: http.Method, - target: []const u8, - version: http.Version, - expect: ?[]const u8, - content_type: ?[]const u8, - content_length: ?u64, - transfer_encoding: http.TransferEncoding, - transfer_compression: http.ContentEncoding, - keep_alive: bool, - compression: Compression, - - pub const ParseError = error{ - UnknownHttpMethod, - HttpHeadersInvalid, - HttpHeaderContinuationsUnsupported, - HttpTransferEncodingUnsupported, - HttpConnectionHeaderUnsupported, - InvalidContentLength, - CompressionUnsupported, - MissingFinalNewline, - }; - - pub fn parse(bytes: []const u8) ParseError!Head { - var it = mem.splitSequence(u8, bytes, "\r\n"); - - const first_line = it.next().?; - if (first_line.len < 10) - return error.HttpHeadersInvalid; - - const method_end = mem.indexOfScalar(u8, first_line, ' ') orelse - return error.HttpHeadersInvalid; - if (method_end > 24) return error.HttpHeadersInvalid; - - const method_str = first_line[0..method_end]; - const method: http.Method = @enumFromInt(http.Method.parse(method_str)); - - const version_start = mem.lastIndexOfScalar(u8, first_line, ' ') orelse - return error.HttpHeadersInvalid; - if (version_start == method_end) return error.HttpHeadersInvalid; - - const version_str = first_line[version_start + 1 ..]; - if (version_str.len != 8) return error.HttpHeadersInvalid; - const version: http.Version = switch (int64(version_str[0..8])) { - int64("HTTP/1.0") => .@"HTTP/1.0", - int64("HTTP/1.1") => .@"HTTP/1.1", - else => return error.HttpHeadersInvalid, - }; - - const target = first_line[method_end + 1 .. version_start]; - - var head: Head = .{ - .method = method, - .target = target, - .version = version, - .expect = null, - .content_type = null, - .content_length = null, - .transfer_encoding = .none, - .transfer_compression = .identity, - .keep_alive = switch (version) { - .@"HTTP/1.0" => false, - .@"HTTP/1.1" => true, - }, - .compression = .none, - }; - - while (it.next()) |line| { - if (line.len == 0) return head; - switch (line[0]) { - ' ', '\t' => return error.HttpHeaderContinuationsUnsupported, - else => {}, - } - - var line_it = mem.splitScalar(u8, line, ':'); - const header_name = line_it.next().?; - const header_value = mem.trim(u8, line_it.rest(), " \t"); - if (header_name.len == 0) return error.HttpHeadersInvalid; - - if (std.ascii.eqlIgnoreCase(header_name, "connection")) { - head.keep_alive = !std.ascii.eqlIgnoreCase(header_value, "close"); - } else if (std.ascii.eqlIgnoreCase(header_name, "expect")) { - head.expect = header_value; - } else if (std.ascii.eqlIgnoreCase(header_name, "content-type")) { - head.content_type = header_value; - } else if (std.ascii.eqlIgnoreCase(header_name, "content-length")) { - if (head.content_length != null) return error.HttpHeadersInvalid; - head.content_length = std.fmt.parseInt(u64, header_value, 10) catch - return error.InvalidContentLength; - } else if (std.ascii.eqlIgnoreCase(header_name, "content-encoding")) { - if (head.transfer_compression != .identity) return error.HttpHeadersInvalid; - - const trimmed = mem.trim(u8, header_value, " "); - - if (std.meta.stringToEnum(http.ContentEncoding, trimmed)) |ce| { - head.transfer_compression = ce; - } else { - return error.HttpTransferEncodingUnsupported; - } - } else if (std.ascii.eqlIgnoreCase(header_name, "transfer-encoding")) { - // Transfer-Encoding: second, first - // Transfer-Encoding: deflate, chunked - var iter = mem.splitBackwardsScalar(u8, header_value, ','); - - const first = iter.first(); - const trimmed_first = mem.trim(u8, first, " "); - - var next: ?[]const u8 = first; - if (std.meta.stringToEnum(http.TransferEncoding, trimmed_first)) |transfer| { - if (head.transfer_encoding != .none) - return error.HttpHeadersInvalid; // we already have a transfer encoding - head.transfer_encoding = transfer; - - next = iter.next(); - } - - if (next) |second| { - const trimmed_second = mem.trim(u8, second, " "); - - if (std.meta.stringToEnum(http.ContentEncoding, trimmed_second)) |transfer| { - if (head.transfer_compression != .identity) - return error.HttpHeadersInvalid; // double compression is not supported - head.transfer_compression = transfer; - } else { - return error.HttpTransferEncodingUnsupported; - } - } - - if (iter.next()) |_| return error.HttpTransferEncodingUnsupported; - } - } - return error.MissingFinalNewline; - } - - test parse { - const request_bytes = "GET /hi HTTP/1.0\r\n" ++ - "content-tYpe: text/plain\r\n" ++ - "content-Length:10\r\n" ++ - "expeCt: 100-continue \r\n" ++ - "TRansfer-encoding:\tdeflate, chunked \r\n" ++ - "connectioN:\t keep-alive \r\n\r\n"; - - const req = try parse(request_bytes); - - try testing.expectEqual(.GET, req.method); - try testing.expectEqual(.@"HTTP/1.0", req.version); - try testing.expectEqualStrings("/hi", req.target); - - try testing.expectEqualStrings("text/plain", req.content_type.?); - try testing.expectEqualStrings("100-continue", req.expect.?); - - try testing.expectEqual(true, req.keep_alive); - try testing.expectEqual(10, req.content_length.?); - try testing.expectEqual(.chunked, req.transfer_encoding); - try testing.expectEqual(.deflate, req.transfer_compression); - } - - inline fn int64(array: *const [8]u8) u64 { - return @bitCast(array.*); - } - }; - - pub fn iterateHeaders(r: *Request) http.HeaderIterator { - return http.HeaderIterator.init(r.server.read_buffer[0..r.head_end]); - } - - test iterateHeaders { - const request_bytes = "GET /hi HTTP/1.0\r\n" ++ - "content-tYpe: text/plain\r\n" ++ - "content-Length:10\r\n" ++ - "expeCt: 100-continue \r\n" ++ - "TRansfer-encoding:\tdeflate, chunked \r\n" ++ - "connectioN:\t keep-alive \r\n\r\n"; - - var read_buffer: [500]u8 = undefined; - @memcpy(read_buffer[0..request_bytes.len], request_bytes); - - var server: Server = .{ - .connection = undefined, - .state = .ready, - .read_buffer = &read_buffer, - .read_buffer_len = request_bytes.len, - .next_request_start = 0, - }; - - var request: Request = .{ - .server = &server, - .head_end = request_bytes.len, - .head = undefined, - .reader_state = undefined, - }; - - var it = request.iterateHeaders(); - { - const header = it.next().?; - try testing.expectEqualStrings("content-tYpe", header.name); - try testing.expectEqualStrings("text/plain", header.value); - try testing.expect(!it.is_trailer); - } - { - const header = it.next().?; - try testing.expectEqualStrings("content-Length", header.name); - try testing.expectEqualStrings("10", header.value); - try testing.expect(!it.is_trailer); - } - { - const header = it.next().?; - try testing.expectEqualStrings("expeCt", header.name); - try testing.expectEqualStrings("100-continue", header.value); - try testing.expect(!it.is_trailer); - } - { - const header = it.next().?; - try testing.expectEqualStrings("TRansfer-encoding", header.name); - try testing.expectEqualStrings("deflate, chunked", header.value); - try testing.expect(!it.is_trailer); - } - { - const header = it.next().?; - try testing.expectEqualStrings("connectioN", header.name); - try testing.expectEqualStrings("keep-alive", header.value); - try testing.expect(!it.is_trailer); - } - try testing.expectEqual(null, it.next()); - } - - pub const RespondOptions = struct { - version: http.Version = .@"HTTP/1.1", - status: http.Status = .ok, - reason: ?[]const u8 = null, - keep_alive: bool = true, - extra_headers: []const http.Header = &.{}, - transfer_encoding: ?http.TransferEncoding = null, - }; - - /// Send an entire HTTP response to the client, including headers and body. - /// - /// Automatically handles HEAD requests by omitting the body. - /// - /// Unless `transfer_encoding` is specified, uses the "content-length" - /// header. - /// - /// If the request contains a body and the connection is to be reused, - /// discards the request body, leaving the Server in the `ready` state. If - /// this discarding fails, the connection is marked as not to be reused and - /// no error is surfaced. - /// - /// Asserts status is not `continue`. - /// Asserts there are at most 25 extra_headers. - /// Asserts that "\r\n" does not occur in any header name or value. - pub fn respond( - request: *Request, - content: []const u8, - options: RespondOptions, - ) Response.WriteError!void { - const max_extra_headers = 25; - assert(options.status != .@"continue"); - assert(options.extra_headers.len <= max_extra_headers); - if (std.debug.runtime_safety) { - for (options.extra_headers) |header| { - assert(header.name.len != 0); - assert(std.mem.indexOfScalar(u8, header.name, ':') == null); - assert(std.mem.indexOfPosLinear(u8, header.name, 0, "\r\n") == null); - assert(std.mem.indexOfPosLinear(u8, header.value, 0, "\r\n") == null); - } - } - - const transfer_encoding_none = (options.transfer_encoding orelse .chunked) == .none; - const server_keep_alive = !transfer_encoding_none and options.keep_alive; - const keep_alive = request.discardBody(server_keep_alive); - - const phrase = options.reason orelse options.status.phrase() orelse ""; - - var first_buffer: [500]u8 = undefined; - var h = std.ArrayListUnmanaged(u8).initBuffer(&first_buffer); - if (request.head.expect != null) { - // reader() and hence discardBody() above sets expect to null if it - // is handled. So the fact that it is not null here means unhandled. - h.appendSliceAssumeCapacity("HTTP/1.1 417 Expectation Failed\r\n"); - if (!keep_alive) h.appendSliceAssumeCapacity("connection: close\r\n"); - h.appendSliceAssumeCapacity("content-length: 0\r\n\r\n"); - try request.server.connection.stream.writeAll(h.items); - return; - } - h.fixedWriter().print("{s} {d} {s}\r\n", .{ - @tagName(options.version), @intFromEnum(options.status), phrase, - }) catch unreachable; - - switch (options.version) { - .@"HTTP/1.0" => if (keep_alive) h.appendSliceAssumeCapacity("connection: keep-alive\r\n"), - .@"HTTP/1.1" => if (!keep_alive) h.appendSliceAssumeCapacity("connection: close\r\n"), - } - - if (options.transfer_encoding) |transfer_encoding| switch (transfer_encoding) { - .none => {}, - .chunked => h.appendSliceAssumeCapacity("transfer-encoding: chunked\r\n"), - } else { - h.fixedWriter().print("content-length: {d}\r\n", .{content.len}) catch unreachable; - } - - var chunk_header_buffer: [18]u8 = undefined; - var iovecs: [max_extra_headers * 4 + 3]std.posix.iovec_const = undefined; - var iovecs_len: usize = 0; - - iovecs[iovecs_len] = .{ - .base = h.items.ptr, - .len = h.items.len, - }; - iovecs_len += 1; - - for (options.extra_headers) |header| { - iovecs[iovecs_len] = .{ - .base = header.name.ptr, - .len = header.name.len, - }; - iovecs_len += 1; - - iovecs[iovecs_len] = .{ - .base = ": ", - .len = 2, - }; - iovecs_len += 1; - - if (header.value.len != 0) { - iovecs[iovecs_len] = .{ - .base = header.value.ptr, - .len = header.value.len, - }; - iovecs_len += 1; - } - - iovecs[iovecs_len] = .{ - .base = "\r\n", - .len = 2, - }; - iovecs_len += 1; - } - - iovecs[iovecs_len] = .{ - .base = "\r\n", - .len = 2, - }; - iovecs_len += 1; - - if (request.head.method != .HEAD) { - const is_chunked = (options.transfer_encoding orelse .none) == .chunked; - if (is_chunked) { - if (content.len > 0) { - const chunk_header = std.fmt.bufPrint( - &chunk_header_buffer, - "{x}\r\n", - .{content.len}, - ) catch unreachable; - - iovecs[iovecs_len] = .{ - .base = chunk_header.ptr, - .len = chunk_header.len, - }; - iovecs_len += 1; - - iovecs[iovecs_len] = .{ - .base = content.ptr, - .len = content.len, - }; - iovecs_len += 1; - - iovecs[iovecs_len] = .{ - .base = "\r\n", - .len = 2, - }; - iovecs_len += 1; - } - - iovecs[iovecs_len] = .{ - .base = "0\r\n\r\n", - .len = 5, - }; - iovecs_len += 1; - } else if (content.len > 0) { - iovecs[iovecs_len] = .{ - .base = content.ptr, - .len = content.len, - }; - iovecs_len += 1; - } - } - - try request.server.connection.stream.writevAll(iovecs[0..iovecs_len]); - } - - pub const RespondStreamingOptions = struct { - /// An externally managed slice of memory used to batch bytes before - /// sending. `respondStreaming` asserts this is large enough to store - /// the full HTTP response head. - /// - /// Must outlive the returned Response. - send_buffer: []u8, - /// If provided, the response will use the content-length header; - /// otherwise it will use transfer-encoding: chunked. - content_length: ?u64 = null, - /// Options that are shared with the `respond` method. - respond_options: RespondOptions = .{}, - }; - - /// The header is buffered but not sent until Response.flush is called. - /// - /// If the request contains a body and the connection is to be reused, - /// discards the request body, leaving the Server in the `ready` state. If - /// this discarding fails, the connection is marked as not to be reused and - /// no error is surfaced. - /// - /// HEAD requests are handled transparently by setting a flag on the - /// returned Response to omit the body. However it may be worth noticing - /// that flag and skipping any expensive work that would otherwise need to - /// be done to satisfy the request. - /// - /// Asserts `send_buffer` is large enough to store the entire response header. - /// Asserts status is not `continue`. - pub fn respondStreaming(request: *Request, options: RespondStreamingOptions) Response { - const o = options.respond_options; - assert(o.status != .@"continue"); - const transfer_encoding_none = (o.transfer_encoding orelse .chunked) == .none; - const server_keep_alive = !transfer_encoding_none and o.keep_alive; - const keep_alive = request.discardBody(server_keep_alive); - const phrase = o.reason orelse o.status.phrase() orelse ""; - - var h = std.ArrayListUnmanaged(u8).initBuffer(options.send_buffer); - - const elide_body = if (request.head.expect != null) eb: { - // reader() and hence discardBody() above sets expect to null if it - // is handled. So the fact that it is not null here means unhandled. - h.appendSliceAssumeCapacity("HTTP/1.1 417 Expectation Failed\r\n"); - if (!keep_alive) h.appendSliceAssumeCapacity("connection: close\r\n"); - h.appendSliceAssumeCapacity("content-length: 0\r\n\r\n"); - break :eb true; - } else eb: { - h.fixedWriter().print("{s} {d} {s}\r\n", .{ - @tagName(o.version), @intFromEnum(o.status), phrase, - }) catch unreachable; - - switch (o.version) { - .@"HTTP/1.0" => if (keep_alive) h.appendSliceAssumeCapacity("connection: keep-alive\r\n"), - .@"HTTP/1.1" => if (!keep_alive) h.appendSliceAssumeCapacity("connection: close\r\n"), - } - - if (o.transfer_encoding) |transfer_encoding| switch (transfer_encoding) { - .chunked => h.appendSliceAssumeCapacity("transfer-encoding: chunked\r\n"), - .none => {}, - } else if (options.content_length) |len| { - h.fixedWriter().print("content-length: {d}\r\n", .{len}) catch unreachable; - } else { - h.appendSliceAssumeCapacity("transfer-encoding: chunked\r\n"); - } - - for (o.extra_headers) |header| { - assert(header.name.len != 0); - h.appendSliceAssumeCapacity(header.name); - h.appendSliceAssumeCapacity(": "); - h.appendSliceAssumeCapacity(header.value); - h.appendSliceAssumeCapacity("\r\n"); - } - - h.appendSliceAssumeCapacity("\r\n"); - break :eb request.head.method == .HEAD; - }; - - return .{ - .stream = request.server.connection.stream, - .send_buffer = options.send_buffer, - .send_buffer_start = 0, - .send_buffer_end = h.items.len, - .transfer_encoding = if (o.transfer_encoding) |te| switch (te) { - .chunked => .chunked, - .none => .none, - } else if (options.content_length) |len| .{ - .content_length = len, - } else .chunked, - .elide_body = elide_body, - .chunk_len = 0, - }; - } - - pub const ReadError = net.Stream.ReadError || error{ - HttpChunkInvalid, - HttpHeadersOversize, - }; - - fn read_cl(context: *const anyopaque, buffer: []u8) ReadError!usize { - const request: *Request = @constCast(@alignCast(@ptrCast(context))); - const s = request.server; - - const remaining_content_length = &request.reader_state.remaining_content_length; - if (remaining_content_length.* == 0) { - s.state = .ready; - return 0; - } - assert(s.state == .receiving_body); - const available = try fill(s, request.head_end); - const len = @min(remaining_content_length.*, available.len, buffer.len); - @memcpy(buffer[0..len], available[0..len]); - remaining_content_length.* -= len; - s.next_request_start += len; - if (remaining_content_length.* == 0) - s.state = .ready; - return len; - } - - fn fill(s: *Server, head_end: usize) ReadError![]u8 { - const available = s.read_buffer[s.next_request_start..s.read_buffer_len]; - if (available.len > 0) return available; - s.next_request_start = head_end; - s.read_buffer_len = head_end + try s.connection.stream.read(s.read_buffer[head_end..]); - return s.read_buffer[head_end..s.read_buffer_len]; - } - - fn read_chunked(context: *const anyopaque, buffer: []u8) ReadError!usize { - const request: *Request = @constCast(@alignCast(@ptrCast(context))); - const s = request.server; - - const cp = &request.reader_state.chunk_parser; - const head_end = request.head_end; - - // Protect against returning 0 before the end of stream. - var out_end: usize = 0; - while (out_end == 0) { - switch (cp.state) { - .invalid => return 0, - .data => { - assert(s.state == .receiving_body); - const available = try fill(s, head_end); - const len = @min(cp.chunk_len, available.len, buffer.len); - @memcpy(buffer[0..len], available[0..len]); - cp.chunk_len -= len; - if (cp.chunk_len == 0) - cp.state = .data_suffix; - out_end += len; - s.next_request_start += len; - continue; - }, - else => { - assert(s.state == .receiving_body); - const available = try fill(s, head_end); - const n = cp.feed(available); - switch (cp.state) { - .invalid => return error.HttpChunkInvalid, - .data => { - if (cp.chunk_len == 0) { - // The next bytes in the stream are trailers, - // or \r\n to indicate end of chunked body. - // - // This function must append the trailers at - // head_end so that headers and trailers are - // together. - // - // Since returning 0 would indicate end of - // stream, this function must read all the - // trailers before returning. - if (s.next_request_start > head_end) rebase(s, head_end); - var hp: http.HeadParser = .{}; - { - const bytes = s.read_buffer[head_end..s.read_buffer_len]; - const end = hp.feed(bytes); - if (hp.state == .finished) { - cp.state = .invalid; - s.state = .ready; - s.next_request_start = s.read_buffer_len - bytes.len + end; - return out_end; - } - } - while (true) { - const buf = s.read_buffer[s.read_buffer_len..]; - if (buf.len == 0) - return error.HttpHeadersOversize; - const read_n = try s.connection.stream.read(buf); - s.read_buffer_len += read_n; - const bytes = buf[0..read_n]; - const end = hp.feed(bytes); - if (hp.state == .finished) { - cp.state = .invalid; - s.state = .ready; - s.next_request_start = s.read_buffer_len - bytes.len + end; - return out_end; - } - } - } - const data = available[n..]; - const len = @min(cp.chunk_len, data.len, buffer.len); - @memcpy(buffer[0..len], data[0..len]); - cp.chunk_len -= len; - if (cp.chunk_len == 0) - cp.state = .data_suffix; - out_end += len; - s.next_request_start += n + len; - continue; - }, - else => continue, - } - }, - } - } - return out_end; - } - - pub const ReaderError = Response.WriteError || error{ - /// The client sent an expect HTTP header value other than - /// "100-continue". - HttpExpectationFailed, - }; - - /// In the case that the request contains "expect: 100-continue", this - /// function writes the continuation header, which means it can fail with a - /// write error. After sending the continuation header, it sets the - /// request's expect field to `null`. - /// - /// Asserts that this function is only called once. - pub fn reader(request: *Request) ReaderError!std.io.AnyReader { - const s = request.server; - assert(s.state == .received_head); - s.state = .receiving_body; - s.next_request_start = request.head_end; - - if (request.head.expect) |expect| { - if (mem.eql(u8, expect, "100-continue")) { - try request.server.connection.stream.writeAll("HTTP/1.1 100 Continue\r\n\r\n"); - request.head.expect = null; - } else { - return error.HttpExpectationFailed; - } - } - - switch (request.head.transfer_encoding) { - .chunked => { - request.reader_state = .{ .chunk_parser = http.ChunkParser.init }; - return .{ - .readFn = read_chunked, - .context = request, - }; - }, - .none => { - request.reader_state = .{ - .remaining_content_length = request.head.content_length orelse 0, - }; - return .{ - .readFn = read_cl, - .context = request, - }; - }, - } - } - - /// Returns whether the connection should remain persistent. - /// If it would fail, it instead sets the Server state to `receiving_body` - /// and returns false. - fn discardBody(request: *Request, keep_alive: bool) bool { - // Prepare to receive another request on the same connection. - // There are two factors to consider: - // * Any body the client sent must be discarded. - // * The Server's read_buffer may already have some bytes in it from - // whatever came after the head, which may be the next HTTP request - // or the request body. - // If the connection won't be kept alive, then none of this matters - // because the connection will be severed after the response is sent. - const s = request.server; - if (keep_alive and request.head.keep_alive) switch (s.state) { - .received_head => { - const r = request.reader() catch return false; - _ = r.discard() catch return false; - assert(s.state == .ready); - return true; - }, - .receiving_body, .ready => return true, - else => unreachable, - }; - - // Avoid clobbering the state in case a reading stream already exists. - switch (s.state) { - .received_head => s.state = .closing, - else => {}, - } - return false; - } -}; - -pub const Response = struct { - stream: net.Stream, - send_buffer: []u8, - /// Index of the first byte in `send_buffer`. - /// This is 0 unless a short write happens in `write`. - send_buffer_start: usize, - /// Index of the last byte + 1 in `send_buffer`. - send_buffer_end: usize, - /// `null` means transfer-encoding: chunked. - /// As a debugging utility, counts down to zero as bytes are written. - transfer_encoding: TransferEncoding, - elide_body: bool, - /// Indicates how much of the end of the `send_buffer` corresponds to a - /// chunk. This amount of data will be wrapped by an HTTP chunk header. - chunk_len: usize, - - pub const TransferEncoding = union(enum) { - /// End of connection signals the end of the stream. - none, - /// As a debugging utility, counts down to zero as bytes are written. - content_length: u64, - /// Each chunk is wrapped in a header and trailer. - chunked, - }; - - pub const WriteError = net.Stream.WriteError; - - /// When using content-length, asserts that the amount of data sent matches - /// the value sent in the header, then calls `flush`. - /// Otherwise, transfer-encoding: chunked is being used, and it writes the - /// end-of-stream message, then flushes the stream to the system. - /// Respects the value of `elide_body` to omit all data after the headers. - pub fn end(r: *Response) WriteError!void { - switch (r.transfer_encoding) { - .content_length => |len| { - assert(len == 0); // Trips when end() called before all bytes written. - try flush_cl(r); - }, - .none => { - try flush_cl(r); - }, - .chunked => { - try flush_chunked(r, &.{}); - }, - } - r.* = undefined; - } - - pub const EndChunkedOptions = struct { - trailers: []const http.Header = &.{}, - }; - - /// Asserts that the Response is using transfer-encoding: chunked. - /// Writes the end-of-stream message and any optional trailers, then - /// flushes the stream to the system. - /// Respects the value of `elide_body` to omit all data after the headers. - /// Asserts there are at most 25 trailers. - pub fn endChunked(r: *Response, options: EndChunkedOptions) WriteError!void { - assert(r.transfer_encoding == .chunked); - try flush_chunked(r, options.trailers); - r.* = undefined; - } - - /// If using content-length, asserts that writing these bytes to the client - /// would not exceed the content-length value sent in the HTTP header. - /// May return 0, which does not indicate end of stream. The caller decides - /// when the end of stream occurs by calling `end`. - pub fn write(r: *Response, bytes: []const u8) WriteError!usize { - switch (r.transfer_encoding) { - .content_length, .none => return write_cl(r, bytes), - .chunked => return write_chunked(r, bytes), - } - } - - fn write_cl(context: *const anyopaque, bytes: []const u8) WriteError!usize { - const r: *Response = @constCast(@alignCast(@ptrCast(context))); - - var trash: u64 = std.math.maxInt(u64); - const len = switch (r.transfer_encoding) { - .content_length => |*len| len, - else => &trash, - }; - - if (r.elide_body) { - len.* -= bytes.len; - return bytes.len; - } - - if (bytes.len + r.send_buffer_end > r.send_buffer.len) { - const send_buffer_len = r.send_buffer_end - r.send_buffer_start; - var iovecs: [2]std.posix.iovec_const = .{ - .{ - .base = r.send_buffer.ptr + r.send_buffer_start, - .len = send_buffer_len, - }, - .{ - .base = bytes.ptr, - .len = bytes.len, - }, - }; - const n = try r.stream.writev(&iovecs); - - if (n >= send_buffer_len) { - // It was enough to reset the buffer. - r.send_buffer_start = 0; - r.send_buffer_end = 0; - const bytes_n = n - send_buffer_len; - len.* -= bytes_n; - return bytes_n; - } - - // It didn't even make it through the existing buffer, let - // alone the new bytes provided. - r.send_buffer_start += n; - return 0; - } - - // All bytes can be stored in the remaining space of the buffer. - @memcpy(r.send_buffer[r.send_buffer_end..][0..bytes.len], bytes); - r.send_buffer_end += bytes.len; - len.* -= bytes.len; - return bytes.len; - } - - fn write_chunked(context: *const anyopaque, bytes: []const u8) WriteError!usize { - const r: *Response = @constCast(@alignCast(@ptrCast(context))); - assert(r.transfer_encoding == .chunked); - - if (r.elide_body) - return bytes.len; - - if (bytes.len + r.send_buffer_end > r.send_buffer.len) { - const send_buffer_len = r.send_buffer_end - r.send_buffer_start; - const chunk_len = r.chunk_len + bytes.len; - var header_buf: [18]u8 = undefined; - const chunk_header = std.fmt.bufPrint(&header_buf, "{x}\r\n", .{chunk_len}) catch unreachable; - - var iovecs: [5]std.posix.iovec_const = .{ - .{ - .base = r.send_buffer.ptr + r.send_buffer_start, - .len = send_buffer_len - r.chunk_len, - }, - .{ - .base = chunk_header.ptr, - .len = chunk_header.len, - }, - .{ - .base = r.send_buffer.ptr + r.send_buffer_end - r.chunk_len, - .len = r.chunk_len, - }, - .{ - .base = bytes.ptr, - .len = bytes.len, - }, - .{ - .base = "\r\n", - .len = 2, - }, - }; - // TODO make this writev instead of writevAll, which involves - // complicating the logic of this function. - try r.stream.writevAll(&iovecs); - r.send_buffer_start = 0; - r.send_buffer_end = 0; - r.chunk_len = 0; - return bytes.len; - } - - // All bytes can be stored in the remaining space of the buffer. - @memcpy(r.send_buffer[r.send_buffer_end..][0..bytes.len], bytes); - r.send_buffer_end += bytes.len; - r.chunk_len += bytes.len; - return bytes.len; - } - - /// If using content-length, asserts that writing these bytes to the client - /// would not exceed the content-length value sent in the HTTP header. - pub fn writeAll(r: *Response, bytes: []const u8) WriteError!void { - var index: usize = 0; - while (index < bytes.len) { - index += try write(r, bytes[index..]); - } - } - - /// Sends all buffered data to the client. - /// This is redundant after calling `end`. - /// Respects the value of `elide_body` to omit all data after the headers. - pub fn flush(r: *Response) WriteError!void { - switch (r.transfer_encoding) { - .none, .content_length => return flush_cl(r), - .chunked => return flush_chunked(r, null), - } - } - - fn flush_cl(r: *Response) WriteError!void { - try r.stream.writeAll(r.send_buffer[r.send_buffer_start..r.send_buffer_end]); - r.send_buffer_start = 0; - r.send_buffer_end = 0; - } - - fn flush_chunked(r: *Response, end_trailers: ?[]const http.Header) WriteError!void { - const max_trailers = 25; - if (end_trailers) |trailers| assert(trailers.len <= max_trailers); - assert(r.transfer_encoding == .chunked); - - const http_headers = r.send_buffer[r.send_buffer_start .. r.send_buffer_end - r.chunk_len]; - - if (r.elide_body) { - try r.stream.writeAll(http_headers); - r.send_buffer_start = 0; - r.send_buffer_end = 0; - r.chunk_len = 0; - return; - } - - var header_buf: [18]u8 = undefined; - const chunk_header = std.fmt.bufPrint(&header_buf, "{x}\r\n", .{r.chunk_len}) catch unreachable; - - var iovecs: [max_trailers * 4 + 5]std.posix.iovec_const = undefined; - var iovecs_len: usize = 0; - - iovecs[iovecs_len] = .{ - .base = http_headers.ptr, - .len = http_headers.len, - }; - iovecs_len += 1; - - if (r.chunk_len > 0) { - iovecs[iovecs_len] = .{ - .base = chunk_header.ptr, - .len = chunk_header.len, - }; - iovecs_len += 1; - - iovecs[iovecs_len] = .{ - .base = r.send_buffer.ptr + r.send_buffer_end - r.chunk_len, - .len = r.chunk_len, - }; - iovecs_len += 1; - - iovecs[iovecs_len] = .{ - .base = "\r\n", - .len = 2, - }; - iovecs_len += 1; - } - - if (end_trailers) |trailers| { - iovecs[iovecs_len] = .{ - .base = "0\r\n", - .len = 3, - }; - iovecs_len += 1; - - for (trailers) |trailer| { - iovecs[iovecs_len] = .{ - .base = trailer.name.ptr, - .len = trailer.name.len, - }; - iovecs_len += 1; - - iovecs[iovecs_len] = .{ - .base = ": ", - .len = 2, - }; - iovecs_len += 1; - - if (trailer.value.len != 0) { - iovecs[iovecs_len] = .{ - .base = trailer.value.ptr, - .len = trailer.value.len, - }; - iovecs_len += 1; - } - - iovecs[iovecs_len] = .{ - .base = "\r\n", - .len = 2, - }; - iovecs_len += 1; - } - - iovecs[iovecs_len] = .{ - .base = "\r\n", - .len = 2, - }; - iovecs_len += 1; - } - - try r.stream.writevAll(iovecs[0..iovecs_len]); - r.send_buffer_start = 0; - r.send_buffer_end = 0; - r.chunk_len = 0; - } - - pub fn writer(r: *Response) std.io.AnyWriter { - return .{ - .writeFn = switch (r.transfer_encoding) { - .none, .content_length => write_cl, - .chunked => write_chunked, - }, - .context = r, - }; - } -}; - -fn rebase(s: *Server, index: usize) void { - const leftover = s.read_buffer[s.next_request_start..s.read_buffer_len]; - const dest = s.read_buffer[index..][0..leftover.len]; - if (leftover.len <= s.next_request_start - index) { - @memcpy(dest, leftover); - } else { - mem.copyBackwards(u8, dest, leftover); - } - s.read_buffer_len = index + leftover.len; -} - -const std = @import("std"); -const http = std.http; -const mem = std.mem; -const net = std.net; -const Uri = std.Uri; -const assert = std.debug.assert; -const testing = std.testing; - -const Server = @This(); diff --git a/src/http/async/std/http/protocol.zig b/src/http/async/std/http/protocol.zig deleted file mode 100644 index 389e1e4f..00000000 --- a/src/http/async/std/http/protocol.zig +++ /dev/null @@ -1,447 +0,0 @@ -const std = @import("std"); -const builtin = @import("builtin"); -const testing = std.testing; -const mem = std.mem; - -const assert = std.debug.assert; -const use_vectors = builtin.zig_backend != .stage2_x86_64; - -pub const State = enum { - invalid, - - // Begin header and trailer parsing states. - - start, - seen_n, - seen_r, - seen_rn, - seen_rnr, - finished, - - // Begin transfer-encoding: chunked parsing states. - - chunk_head_size, - chunk_head_ext, - chunk_head_r, - chunk_data, - chunk_data_suffix, - chunk_data_suffix_r, - - /// Returns true if the parser is in a content state (ie. not waiting for more headers). - pub fn isContent(self: State) bool { - return switch (self) { - .invalid, .start, .seen_n, .seen_r, .seen_rn, .seen_rnr => false, - .finished, .chunk_head_size, .chunk_head_ext, .chunk_head_r, .chunk_data, .chunk_data_suffix, .chunk_data_suffix_r => true, - }; - } -}; - -pub const HeadersParser = struct { - state: State = .start, - /// A fixed buffer of len `max_header_bytes`. - /// Pointers into this buffer are not stable until after a message is complete. - header_bytes_buffer: []u8, - header_bytes_len: u32, - next_chunk_length: u64, - /// `false`: headers. `true`: trailers. - done: bool, - - /// Initializes the parser with a provided buffer `buf`. - pub fn init(buf: []u8) HeadersParser { - return .{ - .header_bytes_buffer = buf, - .header_bytes_len = 0, - .done = false, - .next_chunk_length = 0, - }; - } - - /// Reinitialize the parser. - /// Asserts the parser is in the "done" state. - pub fn reset(hp: *HeadersParser) void { - assert(hp.done); - hp.* = .{ - .state = .start, - .header_bytes_buffer = hp.header_bytes_buffer, - .header_bytes_len = 0, - .done = false, - .next_chunk_length = 0, - }; - } - - pub fn get(hp: HeadersParser) []u8 { - return hp.header_bytes_buffer[0..hp.header_bytes_len]; - } - - pub fn findHeadersEnd(r: *HeadersParser, bytes: []const u8) u32 { - var hp: std.http.HeadParser = .{ - .state = switch (r.state) { - .start => .start, - .seen_n => .seen_n, - .seen_r => .seen_r, - .seen_rn => .seen_rn, - .seen_rnr => .seen_rnr, - .finished => .finished, - else => unreachable, - }, - }; - const result = hp.feed(bytes); - r.state = switch (hp.state) { - .start => .start, - .seen_n => .seen_n, - .seen_r => .seen_r, - .seen_rn => .seen_rn, - .seen_rnr => .seen_rnr, - .finished => .finished, - }; - return @intCast(result); - } - - pub fn findChunkedLen(r: *HeadersParser, bytes: []const u8) u32 { - var cp: std.http.ChunkParser = .{ - .state = switch (r.state) { - .chunk_head_size => .head_size, - .chunk_head_ext => .head_ext, - .chunk_head_r => .head_r, - .chunk_data => .data, - .chunk_data_suffix => .data_suffix, - .chunk_data_suffix_r => .data_suffix_r, - .invalid => .invalid, - else => unreachable, - }, - .chunk_len = r.next_chunk_length, - }; - const result = cp.feed(bytes); - r.state = switch (cp.state) { - .head_size => .chunk_head_size, - .head_ext => .chunk_head_ext, - .head_r => .chunk_head_r, - .data => .chunk_data, - .data_suffix => .chunk_data_suffix, - .data_suffix_r => .chunk_data_suffix_r, - .invalid => .invalid, - }; - r.next_chunk_length = cp.chunk_len; - return @intCast(result); - } - - /// Returns whether or not the parser has finished parsing a complete - /// message. A message is only complete after the entire body has been read - /// and any trailing headers have been parsed. - pub fn isComplete(r: *HeadersParser) bool { - return r.done and r.state == .finished; - } - - pub const CheckCompleteHeadError = error{HttpHeadersOversize}; - - /// Pushes `in` into the parser. Returns the number of bytes consumed by - /// the header. Any header bytes are appended to `header_bytes_buffer`. - pub fn checkCompleteHead(hp: *HeadersParser, in: []const u8) CheckCompleteHeadError!u32 { - if (hp.state.isContent()) return 0; - - const i = hp.findHeadersEnd(in); - const data = in[0..i]; - if (hp.header_bytes_len + data.len > hp.header_bytes_buffer.len) - return error.HttpHeadersOversize; - - @memcpy(hp.header_bytes_buffer[hp.header_bytes_len..][0..data.len], data); - hp.header_bytes_len += @intCast(data.len); - - return i; - } - - pub const ReadError = error{ - HttpChunkInvalid, - }; - - /// Reads the body of the message into `buffer`. Returns the number of - /// bytes placed in the buffer. - /// - /// If `skip` is true, the buffer will be unused and the body will be skipped. - /// - /// See `std.http.Client.Connection for an example of `conn`. - pub fn read(r: *HeadersParser, conn: anytype, buffer: []u8, skip: bool) !usize { - assert(r.state.isContent()); - if (r.done) return 0; - - var out_index: usize = 0; - while (true) { - switch (r.state) { - .invalid, .start, .seen_n, .seen_r, .seen_rn, .seen_rnr => unreachable, - .finished => { - const data_avail = r.next_chunk_length; - - if (skip) { - try conn.fill(); - - const nread = @min(conn.peek().len, data_avail); - conn.drop(@intCast(nread)); - r.next_chunk_length -= nread; - - if (r.next_chunk_length == 0 or nread == 0) r.done = true; - - return out_index; - } else if (out_index < buffer.len) { - const out_avail = buffer.len - out_index; - - const can_read = @as(usize, @intCast(@min(data_avail, out_avail))); - const nread = try conn.read(buffer[0..can_read]); - r.next_chunk_length -= nread; - - if (r.next_chunk_length == 0 or nread == 0) r.done = true; - - return nread; - } else { - return out_index; - } - }, - .chunk_data_suffix, .chunk_data_suffix_r, .chunk_head_size, .chunk_head_ext, .chunk_head_r => { - try conn.fill(); - - const i = r.findChunkedLen(conn.peek()); - conn.drop(@intCast(i)); - - switch (r.state) { - .invalid => return error.HttpChunkInvalid, - .chunk_data => if (r.next_chunk_length == 0) { - if (std.mem.eql(u8, conn.peek(), "\r\n")) { - r.state = .finished; - conn.drop(2); - } else { - // The trailer section is formatted identically - // to the header section. - r.state = .seen_rn; - } - r.done = true; - - return out_index; - }, - else => return out_index, - } - - continue; - }, - .chunk_data => { - const data_avail = r.next_chunk_length; - const out_avail = buffer.len - out_index; - - if (skip) { - try conn.fill(); - - const nread = @min(conn.peek().len, data_avail); - conn.drop(@intCast(nread)); - r.next_chunk_length -= nread; - } else if (out_avail > 0) { - const can_read: usize = @intCast(@min(data_avail, out_avail)); - const nread = try conn.read(buffer[out_index..][0..can_read]); - r.next_chunk_length -= nread; - out_index += nread; - } - - if (r.next_chunk_length == 0) { - r.state = .chunk_data_suffix; - continue; - } - - return out_index; - }, - } - } - } -}; - -inline fn int16(array: *const [2]u8) u16 { - return @as(u16, @bitCast(array.*)); -} - -inline fn int24(array: *const [3]u8) u24 { - return @as(u24, @bitCast(array.*)); -} - -inline fn int32(array: *const [4]u8) u32 { - return @as(u32, @bitCast(array.*)); -} - -inline fn intShift(comptime T: type, x: anytype) T { - switch (@import("builtin").cpu.arch.endian()) { - .little => return @as(T, @truncate(x >> (@bitSizeOf(@TypeOf(x)) - @bitSizeOf(T)))), - .big => return @as(T, @truncate(x)), - } -} - -/// A buffered (and peekable) Connection. -const MockBufferedConnection = struct { - pub const buffer_size = 0x2000; - - conn: std.io.FixedBufferStream([]const u8), - buf: [buffer_size]u8 = undefined, - start: u16 = 0, - end: u16 = 0, - - pub fn fill(conn: *MockBufferedConnection) ReadError!void { - if (conn.end != conn.start) return; - - const nread = try conn.conn.read(conn.buf[0..]); - if (nread == 0) return error.EndOfStream; - conn.start = 0; - conn.end = @as(u16, @truncate(nread)); - } - - pub fn peek(conn: *MockBufferedConnection) []const u8 { - return conn.buf[conn.start..conn.end]; - } - - pub fn drop(conn: *MockBufferedConnection, num: u16) void { - conn.start += num; - } - - pub fn readAtLeast(conn: *MockBufferedConnection, buffer: []u8, len: usize) ReadError!usize { - var out_index: u16 = 0; - while (out_index < len) { - const available = conn.end - conn.start; - const left = buffer.len - out_index; - - if (available > 0) { - const can_read = @as(u16, @truncate(@min(available, left))); - - @memcpy(buffer[out_index..][0..can_read], conn.buf[conn.start..][0..can_read]); - out_index += can_read; - conn.start += can_read; - - continue; - } - - if (left > conn.buf.len) { - // skip the buffer if the output is large enough - return conn.conn.read(buffer[out_index..]); - } - - try conn.fill(); - } - - return out_index; - } - - pub fn read(conn: *MockBufferedConnection, buffer: []u8) ReadError!usize { - return conn.readAtLeast(buffer, 1); - } - - pub const ReadError = std.io.FixedBufferStream([]const u8).ReadError || error{EndOfStream}; - pub const Reader = std.io.Reader(*MockBufferedConnection, ReadError, read); - - pub fn reader(conn: *MockBufferedConnection) Reader { - return Reader{ .context = conn }; - } - - pub fn writeAll(conn: *MockBufferedConnection, buffer: []const u8) WriteError!void { - return conn.conn.writeAll(buffer); - } - - pub fn write(conn: *MockBufferedConnection, buffer: []const u8) WriteError!usize { - return conn.conn.write(buffer); - } - - pub const WriteError = std.io.FixedBufferStream([]const u8).WriteError; - pub const Writer = std.io.Writer(*MockBufferedConnection, WriteError, write); - - pub fn writer(conn: *MockBufferedConnection) Writer { - return Writer{ .context = conn }; - } -}; - -test "HeadersParser.read length" { - // mock BufferedConnection for read - var headers_buf: [256]u8 = undefined; - - var r = HeadersParser.init(&headers_buf); - const data = "GET / HTTP/1.1\r\nHost: localhost\r\nContent-Length: 5\r\n\r\nHello"; - - var conn: MockBufferedConnection = .{ - .conn = std.io.fixedBufferStream(data), - }; - - while (true) { // read headers - try conn.fill(); - - const nchecked = try r.checkCompleteHead(conn.peek()); - conn.drop(@intCast(nchecked)); - - if (r.state.isContent()) break; - } - - var buf: [8]u8 = undefined; - - r.next_chunk_length = 5; - const len = try r.read(&conn, &buf, false); - try std.testing.expectEqual(@as(usize, 5), len); - try std.testing.expectEqualStrings("Hello", buf[0..len]); - - try std.testing.expectEqualStrings("GET / HTTP/1.1\r\nHost: localhost\r\nContent-Length: 5\r\n\r\n", r.get()); -} - -test "HeadersParser.read chunked" { - // mock BufferedConnection for read - - var headers_buf: [256]u8 = undefined; - var r = HeadersParser.init(&headers_buf); - const data = "GET / HTTP/1.1\r\nHost: localhost\r\n\r\n2\r\nHe\r\n2\r\nll\r\n1\r\no\r\n0\r\n\r\n"; - - var conn: MockBufferedConnection = .{ - .conn = std.io.fixedBufferStream(data), - }; - - while (true) { // read headers - try conn.fill(); - - const nchecked = try r.checkCompleteHead(conn.peek()); - conn.drop(@intCast(nchecked)); - - if (r.state.isContent()) break; - } - var buf: [8]u8 = undefined; - - r.state = .chunk_head_size; - const len = try r.read(&conn, &buf, false); - try std.testing.expectEqual(@as(usize, 5), len); - try std.testing.expectEqualStrings("Hello", buf[0..len]); - - try std.testing.expectEqualStrings("GET / HTTP/1.1\r\nHost: localhost\r\n\r\n", r.get()); -} - -test "HeadersParser.read chunked trailer" { - // mock BufferedConnection for read - - var headers_buf: [256]u8 = undefined; - var r = HeadersParser.init(&headers_buf); - const data = "GET / HTTP/1.1\r\nHost: localhost\r\n\r\n2\r\nHe\r\n2\r\nll\r\n1\r\no\r\n0\r\nContent-Type: text/plain\r\n\r\n"; - - var conn: MockBufferedConnection = .{ - .conn = std.io.fixedBufferStream(data), - }; - - while (true) { // read headers - try conn.fill(); - - const nchecked = try r.checkCompleteHead(conn.peek()); - conn.drop(@intCast(nchecked)); - - if (r.state.isContent()) break; - } - var buf: [8]u8 = undefined; - - r.state = .chunk_head_size; - const len = try r.read(&conn, &buf, false); - try std.testing.expectEqual(@as(usize, 5), len); - try std.testing.expectEqualStrings("Hello", buf[0..len]); - - while (true) { // read headers - try conn.fill(); - - const nchecked = try r.checkCompleteHead(conn.peek()); - conn.drop(@intCast(nchecked)); - - if (r.state.isContent()) break; - } - - try std.testing.expectEqualStrings("GET / HTTP/1.1\r\nHost: localhost\r\n\r\nContent-Type: text/plain\r\n\r\n", r.get()); -} diff --git a/src/http/async/std/net.zig b/src/http/async/std/net.zig deleted file mode 100644 index 86863deb..00000000 --- a/src/http/async/std/net.zig +++ /dev/null @@ -1,2050 +0,0 @@ -//! Cross-platform networking abstractions. - -const std = @import("std"); -const builtin = @import("builtin"); -const assert = std.debug.assert; -const net = @This(); -const mem = std.mem; -const posix = std.posix; -const fs = std.fs; -const io = std.io; -const native_endian = builtin.target.cpu.arch.endian(); -const native_os = builtin.os.tag; -const windows = std.os.windows; - -const Ctx = @import("http/Client.zig").Ctx; -const Cbk = @import("http/Client.zig").Cbk; - -// Windows 10 added support for unix sockets in build 17063, redstone 4 is the -// first release to support them. -pub const has_unix_sockets = switch (native_os) { - .windows => builtin.os.version_range.windows.isAtLeast(.win10_rs4) orelse false, - else => true, -}; - -pub const IPParseError = error{ - Overflow, - InvalidEnd, - InvalidCharacter, - Incomplete, -}; - -pub const IPv4ParseError = IPParseError || error{NonCanonical}; - -pub const IPv6ParseError = IPParseError || error{InvalidIpv4Mapping}; -pub const IPv6InterfaceError = posix.SocketError || posix.IoCtl_SIOCGIFINDEX_Error || error{NameTooLong}; -pub const IPv6ResolveError = IPv6ParseError || IPv6InterfaceError; - -pub const Address = extern union { - any: posix.sockaddr, - in: Ip4Address, - in6: Ip6Address, - un: if (has_unix_sockets) posix.sockaddr.un else void, - - /// Parse the given IP address string into an Address value. - /// It is recommended to use `resolveIp` instead, to handle - /// IPv6 link-local unix addresses. - pub fn parseIp(name: []const u8, port: u16) !Address { - if (parseIp4(name, port)) |ip4| return ip4 else |err| switch (err) { - error.Overflow, - error.InvalidEnd, - error.InvalidCharacter, - error.Incomplete, - error.NonCanonical, - => {}, - } - - if (parseIp6(name, port)) |ip6| return ip6 else |err| switch (err) { - error.Overflow, - error.InvalidEnd, - error.InvalidCharacter, - error.Incomplete, - error.InvalidIpv4Mapping, - => {}, - } - - return error.InvalidIPAddressFormat; - } - - pub fn resolveIp(name: []const u8, port: u16) !Address { - if (parseIp4(name, port)) |ip4| return ip4 else |err| switch (err) { - error.Overflow, - error.InvalidEnd, - error.InvalidCharacter, - error.Incomplete, - error.NonCanonical, - => {}, - } - - if (resolveIp6(name, port)) |ip6| return ip6 else |err| switch (err) { - error.Overflow, - error.InvalidEnd, - error.InvalidCharacter, - error.Incomplete, - error.InvalidIpv4Mapping, - => {}, - else => return err, - } - - return error.InvalidIPAddressFormat; - } - - pub fn parseExpectingFamily(name: []const u8, family: posix.sa_family_t, port: u16) !Address { - switch (family) { - posix.AF.INET => return parseIp4(name, port), - posix.AF.INET6 => return parseIp6(name, port), - posix.AF.UNSPEC => return parseIp(name, port), - else => unreachable, - } - } - - pub fn parseIp6(buf: []const u8, port: u16) IPv6ParseError!Address { - return .{ .in6 = try Ip6Address.parse(buf, port) }; - } - - pub fn resolveIp6(buf: []const u8, port: u16) IPv6ResolveError!Address { - return .{ .in6 = try Ip6Address.resolve(buf, port) }; - } - - pub fn parseIp4(buf: []const u8, port: u16) IPv4ParseError!Address { - return .{ .in = try Ip4Address.parse(buf, port) }; - } - - pub fn initIp4(addr: [4]u8, port: u16) Address { - return .{ .in = Ip4Address.init(addr, port) }; - } - - pub fn initIp6(addr: [16]u8, port: u16, flowinfo: u32, scope_id: u32) Address { - return .{ .in6 = Ip6Address.init(addr, port, flowinfo, scope_id) }; - } - - pub fn initUnix(path: []const u8) !Address { - var sock_addr = posix.sockaddr.un{ - .family = posix.AF.UNIX, - .path = undefined, - }; - - // Add 1 to ensure a terminating 0 is present in the path array for maximum portability. - if (path.len + 1 > sock_addr.path.len) return error.NameTooLong; - - @memset(&sock_addr.path, 0); - @memcpy(sock_addr.path[0..path.len], path); - - return .{ .un = sock_addr }; - } - - /// Returns the port in native endian. - /// Asserts that the address is ip4 or ip6. - pub fn getPort(self: Address) u16 { - return switch (self.any.family) { - posix.AF.INET => self.in.getPort(), - posix.AF.INET6 => self.in6.getPort(), - else => unreachable, - }; - } - - /// `port` is native-endian. - /// Asserts that the address is ip4 or ip6. - pub fn setPort(self: *Address, port: u16) void { - switch (self.any.family) { - posix.AF.INET => self.in.setPort(port), - posix.AF.INET6 => self.in6.setPort(port), - else => unreachable, - } - } - - /// Asserts that `addr` is an IP address. - /// This function will read past the end of the pointer, with a size depending - /// on the address family. - pub fn initPosix(addr: *align(4) const posix.sockaddr) Address { - switch (addr.family) { - posix.AF.INET => return Address{ .in = Ip4Address{ .sa = @as(*const posix.sockaddr.in, @ptrCast(addr)).* } }, - posix.AF.INET6 => return Address{ .in6 = Ip6Address{ .sa = @as(*const posix.sockaddr.in6, @ptrCast(addr)).* } }, - else => unreachable, - } - } - - pub fn format( - self: Address, - comptime fmt: []const u8, - options: std.fmt.FormatOptions, - out_stream: anytype, - ) !void { - if (fmt.len != 0) std.fmt.invalidFmtError(fmt, self); - switch (self.any.family) { - posix.AF.INET => try self.in.format(fmt, options, out_stream), - posix.AF.INET6 => try self.in6.format(fmt, options, out_stream), - posix.AF.UNIX => { - if (!has_unix_sockets) { - unreachable; - } - - try std.fmt.format(out_stream, "{s}", .{std.mem.sliceTo(&self.un.path, 0)}); - }, - else => unreachable, - } - } - - pub fn eql(a: Address, b: Address) bool { - const a_bytes = @as([*]const u8, @ptrCast(&a.any))[0..a.getOsSockLen()]; - const b_bytes = @as([*]const u8, @ptrCast(&b.any))[0..b.getOsSockLen()]; - return mem.eql(u8, a_bytes, b_bytes); - } - - pub fn getOsSockLen(self: Address) posix.socklen_t { - switch (self.any.family) { - posix.AF.INET => return self.in.getOsSockLen(), - posix.AF.INET6 => return self.in6.getOsSockLen(), - posix.AF.UNIX => { - if (!has_unix_sockets) { - unreachable; - } - - // Using the full length of the structure here is more portable than returning - // the number of bytes actually used by the currently stored path. - // This also is correct regardless if we are passing a socket address to the kernel - // (e.g. in bind, connect, sendto) since we ensure the path is 0 terminated in - // initUnix() or if we are receiving a socket address from the kernel and must - // provide the full buffer size (e.g. getsockname, getpeername, recvfrom, accept). - // - // To access the path, std.mem.sliceTo(&address.un.path, 0) should be used. - return @as(posix.socklen_t, @intCast(@sizeOf(posix.sockaddr.un))); - }, - - else => unreachable, - } - } - - pub const ListenError = posix.SocketError || posix.BindError || posix.ListenError || - posix.SetSockOptError || posix.GetSockNameError; - - pub const ListenOptions = struct { - /// How many connections the kernel will accept on the application's behalf. - /// If more than this many connections pool in the kernel, clients will start - /// seeing "Connection refused". - kernel_backlog: u31 = 128, - /// Sets SO_REUSEADDR and SO_REUSEPORT on POSIX. - /// Sets SO_REUSEADDR on Windows, which is roughly equivalent. - reuse_address: bool = false, - /// Deprecated. Does the same thing as reuse_address. - reuse_port: bool = false, - force_nonblocking: bool = false, - }; - - /// The returned `Server` has an open `stream`. - pub fn listen(address: Address, options: ListenOptions) ListenError!Server { - const nonblock: u32 = if (options.force_nonblocking) posix.SOCK.NONBLOCK else 0; - const sock_flags = posix.SOCK.STREAM | posix.SOCK.CLOEXEC | nonblock; - const proto: u32 = if (address.any.family == posix.AF.UNIX) 0 else posix.IPPROTO.TCP; - - const sockfd = try posix.socket(address.any.family, sock_flags, proto); - var s: Server = .{ - .listen_address = undefined, - .stream = .{ .handle = sockfd }, - }; - errdefer s.stream.close(); - - if (options.reuse_address or options.reuse_port) { - try posix.setsockopt( - sockfd, - posix.SOL.SOCKET, - posix.SO.REUSEADDR, - &mem.toBytes(@as(c_int, 1)), - ); - switch (native_os) { - .windows => {}, - else => try posix.setsockopt( - sockfd, - posix.SOL.SOCKET, - posix.SO.REUSEPORT, - &mem.toBytes(@as(c_int, 1)), - ), - } - } - - var socklen = address.getOsSockLen(); - try posix.bind(sockfd, &address.any, socklen); - try posix.listen(sockfd, options.kernel_backlog); - try posix.getsockname(sockfd, &s.listen_address.any, &socklen); - return s; - } -}; - -pub const Ip4Address = extern struct { - sa: posix.sockaddr.in, - - pub fn parse(buf: []const u8, port: u16) IPv4ParseError!Ip4Address { - var result: Ip4Address = .{ - .sa = .{ - .port = mem.nativeToBig(u16, port), - .addr = undefined, - }, - }; - const out_ptr = mem.asBytes(&result.sa.addr); - - var x: u8 = 0; - var index: u8 = 0; - var saw_any_digits = false; - var has_zero_prefix = false; - for (buf) |c| { - if (c == '.') { - if (!saw_any_digits) { - return error.InvalidCharacter; - } - if (index == 3) { - return error.InvalidEnd; - } - out_ptr[index] = x; - index += 1; - x = 0; - saw_any_digits = false; - has_zero_prefix = false; - } else if (c >= '0' and c <= '9') { - if (c == '0' and !saw_any_digits) { - has_zero_prefix = true; - } else if (has_zero_prefix) { - return error.NonCanonical; - } - saw_any_digits = true; - x = try std.math.mul(u8, x, 10); - x = try std.math.add(u8, x, c - '0'); - } else { - return error.InvalidCharacter; - } - } - if (index == 3 and saw_any_digits) { - out_ptr[index] = x; - return result; - } - - return error.Incomplete; - } - - pub fn resolveIp(name: []const u8, port: u16) !Ip4Address { - if (parse(name, port)) |ip4| return ip4 else |err| switch (err) { - error.Overflow, - error.InvalidEnd, - error.InvalidCharacter, - error.Incomplete, - error.NonCanonical, - => {}, - } - return error.InvalidIPAddressFormat; - } - - pub fn init(addr: [4]u8, port: u16) Ip4Address { - return Ip4Address{ - .sa = posix.sockaddr.in{ - .port = mem.nativeToBig(u16, port), - .addr = @as(*align(1) const u32, @ptrCast(&addr)).*, - }, - }; - } - - /// Returns the port in native endian. - /// Asserts that the address is ip4 or ip6. - pub fn getPort(self: Ip4Address) u16 { - return mem.bigToNative(u16, self.sa.port); - } - - /// `port` is native-endian. - /// Asserts that the address is ip4 or ip6. - pub fn setPort(self: *Ip4Address, port: u16) void { - self.sa.port = mem.nativeToBig(u16, port); - } - - pub fn format( - self: Ip4Address, - comptime fmt: []const u8, - options: std.fmt.FormatOptions, - out_stream: anytype, - ) !void { - if (fmt.len != 0) std.fmt.invalidFmtError(fmt, self); - _ = options; - const bytes = @as(*const [4]u8, @ptrCast(&self.sa.addr)); - try std.fmt.format(out_stream, "{}.{}.{}.{}:{}", .{ - bytes[0], - bytes[1], - bytes[2], - bytes[3], - self.getPort(), - }); - } - - pub fn getOsSockLen(self: Ip4Address) posix.socklen_t { - _ = self; - return @sizeOf(posix.sockaddr.in); - } -}; - -pub const Ip6Address = extern struct { - sa: posix.sockaddr.in6, - - /// Parse a given IPv6 address string into an Address. - /// Assumes the Scope ID of the address is fully numeric. - /// For non-numeric addresses, see `resolveIp6`. - pub fn parse(buf: []const u8, port: u16) IPv6ParseError!Ip6Address { - var result = Ip6Address{ - .sa = posix.sockaddr.in6{ - .scope_id = 0, - .port = mem.nativeToBig(u16, port), - .flowinfo = 0, - .addr = undefined, - }, - }; - var ip_slice: *[16]u8 = result.sa.addr[0..]; - - var tail: [16]u8 = undefined; - - var x: u16 = 0; - var saw_any_digits = false; - var index: u8 = 0; - var scope_id = false; - var abbrv = false; - for (buf, 0..) |c, i| { - if (scope_id) { - if (c >= '0' and c <= '9') { - const digit = c - '0'; - { - const ov = @mulWithOverflow(result.sa.scope_id, 10); - if (ov[1] != 0) return error.Overflow; - result.sa.scope_id = ov[0]; - } - { - const ov = @addWithOverflow(result.sa.scope_id, digit); - if (ov[1] != 0) return error.Overflow; - result.sa.scope_id = ov[0]; - } - } else { - return error.InvalidCharacter; - } - } else if (c == ':') { - if (!saw_any_digits) { - if (abbrv) return error.InvalidCharacter; // ':::' - if (i != 0) abbrv = true; - @memset(ip_slice[index..], 0); - ip_slice = tail[0..]; - index = 0; - continue; - } - if (index == 14) { - return error.InvalidEnd; - } - ip_slice[index] = @as(u8, @truncate(x >> 8)); - index += 1; - ip_slice[index] = @as(u8, @truncate(x)); - index += 1; - - x = 0; - saw_any_digits = false; - } else if (c == '%') { - if (!saw_any_digits) { - return error.InvalidCharacter; - } - scope_id = true; - saw_any_digits = false; - } else if (c == '.') { - if (!abbrv or ip_slice[0] != 0xff or ip_slice[1] != 0xff) { - // must start with '::ffff:' - return error.InvalidIpv4Mapping; - } - const start_index = mem.lastIndexOfScalar(u8, buf[0..i], ':').? + 1; - const addr = (Ip4Address.parse(buf[start_index..], 0) catch { - return error.InvalidIpv4Mapping; - }).sa.addr; - ip_slice = result.sa.addr[0..]; - ip_slice[10] = 0xff; - ip_slice[11] = 0xff; - - const ptr = mem.sliceAsBytes(@as(*const [1]u32, &addr)[0..]); - - ip_slice[12] = ptr[0]; - ip_slice[13] = ptr[1]; - ip_slice[14] = ptr[2]; - ip_slice[15] = ptr[3]; - return result; - } else { - const digit = try std.fmt.charToDigit(c, 16); - { - const ov = @mulWithOverflow(x, 16); - if (ov[1] != 0) return error.Overflow; - x = ov[0]; - } - { - const ov = @addWithOverflow(x, digit); - if (ov[1] != 0) return error.Overflow; - x = ov[0]; - } - saw_any_digits = true; - } - } - - if (!saw_any_digits and !abbrv) { - return error.Incomplete; - } - if (!abbrv and index < 14) { - return error.Incomplete; - } - - if (index == 14) { - ip_slice[14] = @as(u8, @truncate(x >> 8)); - ip_slice[15] = @as(u8, @truncate(x)); - return result; - } else { - ip_slice[index] = @as(u8, @truncate(x >> 8)); - index += 1; - ip_slice[index] = @as(u8, @truncate(x)); - index += 1; - @memcpy(result.sa.addr[16 - index ..][0..index], ip_slice[0..index]); - return result; - } - } - - pub fn resolve(buf: []const u8, port: u16) IPv6ResolveError!Ip6Address { - // TODO: Unify the implementations of resolveIp6 and parseIp6. - var result = Ip6Address{ - .sa = posix.sockaddr.in6{ - .scope_id = 0, - .port = mem.nativeToBig(u16, port), - .flowinfo = 0, - .addr = undefined, - }, - }; - var ip_slice: *[16]u8 = result.sa.addr[0..]; - - var tail: [16]u8 = undefined; - - var x: u16 = 0; - var saw_any_digits = false; - var index: u8 = 0; - var abbrv = false; - - var scope_id = false; - var scope_id_value: [posix.IFNAMESIZE - 1]u8 = undefined; - var scope_id_index: usize = 0; - - for (buf, 0..) |c, i| { - if (scope_id) { - // Handling of percent-encoding should be for an URI library. - if ((c >= '0' and c <= '9') or - (c >= 'A' and c <= 'Z') or - (c >= 'a' and c <= 'z') or - (c == '-') or (c == '.') or (c == '_') or (c == '~')) - { - if (scope_id_index >= scope_id_value.len) { - return error.Overflow; - } - - scope_id_value[scope_id_index] = c; - scope_id_index += 1; - } else { - return error.InvalidCharacter; - } - } else if (c == ':') { - if (!saw_any_digits) { - if (abbrv) return error.InvalidCharacter; // ':::' - if (i != 0) abbrv = true; - @memset(ip_slice[index..], 0); - ip_slice = tail[0..]; - index = 0; - continue; - } - if (index == 14) { - return error.InvalidEnd; - } - ip_slice[index] = @as(u8, @truncate(x >> 8)); - index += 1; - ip_slice[index] = @as(u8, @truncate(x)); - index += 1; - - x = 0; - saw_any_digits = false; - } else if (c == '%') { - if (!saw_any_digits) { - return error.InvalidCharacter; - } - scope_id = true; - saw_any_digits = false; - } else if (c == '.') { - if (!abbrv or ip_slice[0] != 0xff or ip_slice[1] != 0xff) { - // must start with '::ffff:' - return error.InvalidIpv4Mapping; - } - const start_index = mem.lastIndexOfScalar(u8, buf[0..i], ':').? + 1; - const addr = (Ip4Address.parse(buf[start_index..], 0) catch { - return error.InvalidIpv4Mapping; - }).sa.addr; - ip_slice = result.sa.addr[0..]; - ip_slice[10] = 0xff; - ip_slice[11] = 0xff; - - const ptr = mem.sliceAsBytes(@as(*const [1]u32, &addr)[0..]); - - ip_slice[12] = ptr[0]; - ip_slice[13] = ptr[1]; - ip_slice[14] = ptr[2]; - ip_slice[15] = ptr[3]; - return result; - } else { - const digit = try std.fmt.charToDigit(c, 16); - { - const ov = @mulWithOverflow(x, 16); - if (ov[1] != 0) return error.Overflow; - x = ov[0]; - } - { - const ov = @addWithOverflow(x, digit); - if (ov[1] != 0) return error.Overflow; - x = ov[0]; - } - saw_any_digits = true; - } - } - - if (!saw_any_digits and !abbrv) { - return error.Incomplete; - } - - if (scope_id and scope_id_index == 0) { - return error.Incomplete; - } - - var resolved_scope_id: u32 = 0; - if (scope_id_index > 0) { - const scope_id_str = scope_id_value[0..scope_id_index]; - resolved_scope_id = std.fmt.parseInt(u32, scope_id_str, 10) catch |err| blk: { - if (err != error.InvalidCharacter) return err; - break :blk try if_nametoindex(scope_id_str); - }; - } - - result.sa.scope_id = resolved_scope_id; - - if (index == 14) { - ip_slice[14] = @as(u8, @truncate(x >> 8)); - ip_slice[15] = @as(u8, @truncate(x)); - return result; - } else { - ip_slice[index] = @as(u8, @truncate(x >> 8)); - index += 1; - ip_slice[index] = @as(u8, @truncate(x)); - index += 1; - @memcpy(result.sa.addr[16 - index ..][0..index], ip_slice[0..index]); - return result; - } - } - - pub fn init(addr: [16]u8, port: u16, flowinfo: u32, scope_id: u32) Ip6Address { - return Ip6Address{ - .sa = posix.sockaddr.in6{ - .addr = addr, - .port = mem.nativeToBig(u16, port), - .flowinfo = flowinfo, - .scope_id = scope_id, - }, - }; - } - - /// Returns the port in native endian. - /// Asserts that the address is ip4 or ip6. - pub fn getPort(self: Ip6Address) u16 { - return mem.bigToNative(u16, self.sa.port); - } - - /// `port` is native-endian. - /// Asserts that the address is ip4 or ip6. - pub fn setPort(self: *Ip6Address, port: u16) void { - self.sa.port = mem.nativeToBig(u16, port); - } - - pub fn format( - self: Ip6Address, - comptime fmt: []const u8, - options: std.fmt.FormatOptions, - out_stream: anytype, - ) !void { - if (fmt.len != 0) std.fmt.invalidFmtError(fmt, self); - _ = options; - const port = mem.bigToNative(u16, self.sa.port); - if (mem.eql(u8, self.sa.addr[0..12], &[_]u8{ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0xff, 0xff })) { - try std.fmt.format(out_stream, "[::ffff:{}.{}.{}.{}]:{}", .{ - self.sa.addr[12], - self.sa.addr[13], - self.sa.addr[14], - self.sa.addr[15], - port, - }); - return; - } - const big_endian_parts = @as(*align(1) const [8]u16, @ptrCast(&self.sa.addr)); - const native_endian_parts = switch (native_endian) { - .big => big_endian_parts.*, - .little => blk: { - var buf: [8]u16 = undefined; - for (big_endian_parts, 0..) |part, i| { - buf[i] = mem.bigToNative(u16, part); - } - break :blk buf; - }, - }; - try out_stream.writeAll("["); - var i: usize = 0; - var abbrv = false; - while (i < native_endian_parts.len) : (i += 1) { - if (native_endian_parts[i] == 0) { - if (!abbrv) { - try out_stream.writeAll(if (i == 0) "::" else ":"); - abbrv = true; - } - continue; - } - try std.fmt.format(out_stream, "{x}", .{native_endian_parts[i]}); - if (i != native_endian_parts.len - 1) { - try out_stream.writeAll(":"); - } - } - try std.fmt.format(out_stream, "]:{}", .{port}); - } - - pub fn getOsSockLen(self: Ip6Address) posix.socklen_t { - _ = self; - return @sizeOf(posix.sockaddr.in6); - } -}; - -pub fn connectUnixSocket(path: []const u8) !Stream { - const opt_non_block = 0; - const sockfd = try posix.socket( - posix.AF.UNIX, - posix.SOCK.STREAM | posix.SOCK.CLOEXEC | opt_non_block, - 0, - ); - errdefer Stream.close(.{ .handle = sockfd }); - - var addr = try std.net.Address.initUnix(path); - try posix.connect(sockfd, &addr.any, addr.getOsSockLen()); - - return .{ .handle = sockfd }; -} - -fn if_nametoindex(name: []const u8) IPv6InterfaceError!u32 { - if (native_os == .linux) { - var ifr: posix.ifreq = undefined; - const sockfd = try posix.socket(posix.AF.UNIX, posix.SOCK.DGRAM | posix.SOCK.CLOEXEC, 0); - defer Stream.close(.{ .handle = sockfd }); - - @memcpy(ifr.ifrn.name[0..name.len], name); - ifr.ifrn.name[name.len] = 0; - - // TODO investigate if this needs to be integrated with evented I/O. - try posix.ioctl_SIOCGIFINDEX(sockfd, &ifr); - - return @bitCast(ifr.ifru.ivalue); - } - - if (native_os.isDarwin()) { - if (name.len >= posix.IFNAMESIZE) - return error.NameTooLong; - - var if_name: [posix.IFNAMESIZE:0]u8 = undefined; - @memcpy(if_name[0..name.len], name); - if_name[name.len] = 0; - const if_slice = if_name[0..name.len :0]; - const index = std.c.if_nametoindex(if_slice); - if (index == 0) - return error.InterfaceNotFound; - return @as(u32, @bitCast(index)); - } - - @compileError("std.net.if_nametoindex unimplemented for this OS"); -} - -pub const AddressList = struct { - arena: std.heap.ArenaAllocator, - addrs: []Address, - canon_name: ?[]u8, - - pub fn deinit(self: *AddressList) void { - // Here we copy the arena allocator into stack memory, because - // otherwise it would destroy itself while it was still working. - var arena = self.arena; - arena.deinit(); - // self is destroyed - } -}; - -pub const TcpConnectToHostError = GetAddressListError || TcpConnectToAddressError; - -/// All memory allocated with `allocator` will be freed before this function returns. -pub fn tcpConnectToHost(allocator: mem.Allocator, name: []const u8, port: u16) TcpConnectToHostError!Stream { - const list = try getAddressList(allocator, name, port); - defer list.deinit(); - - if (list.addrs.len == 0) return error.UnknownHostName; - - for (list.addrs) |addr| { - return tcpConnectToAddress(addr) catch |err| switch (err) { - error.ConnectionRefused => { - continue; - }, - else => return err, - }; - } - return posix.ConnectError.ConnectionRefused; -} - -pub const TcpConnectToAddressError = posix.SocketError || posix.ConnectError; - -pub fn tcpConnectToAddress(address: Address) TcpConnectToAddressError!Stream { - const nonblock = 0; - const sock_flags = posix.SOCK.STREAM | nonblock | - (if (native_os == .windows) 0 else posix.SOCK.CLOEXEC); - const sockfd = try posix.socket(address.any.family, sock_flags, posix.IPPROTO.TCP); - errdefer Stream.close(.{ .handle = sockfd }); - - try posix.connect(sockfd, &address.any, address.getOsSockLen()); - - return Stream{ .handle = sockfd }; -} - -const GetAddressListError = std.mem.Allocator.Error || std.fs.File.OpenError || std.fs.File.ReadError || posix.SocketError || posix.BindError || posix.SetSockOptError || error{ - // TODO: break this up into error sets from the various underlying functions - - TemporaryNameServerFailure, - NameServerFailure, - AddressFamilyNotSupported, - UnknownHostName, - ServiceUnavailable, - Unexpected, - - HostLacksNetworkAddresses, - - InvalidCharacter, - InvalidEnd, - NonCanonical, - Overflow, - Incomplete, - InvalidIpv4Mapping, - InvalidIPAddressFormat, - - InterfaceNotFound, - FileSystem, -}; - -/// Call `AddressList.deinit` on the result. -pub fn getAddressList(allocator: mem.Allocator, name: []const u8, port: u16) GetAddressListError!*AddressList { - const result = blk: { - var arena = std.heap.ArenaAllocator.init(allocator); - errdefer arena.deinit(); - - const result = try arena.allocator().create(AddressList); - result.* = AddressList{ - .arena = arena, - .addrs = undefined, - .canon_name = null, - }; - break :blk result; - }; - const arena = result.arena.allocator(); - errdefer result.deinit(); - - if (native_os == .windows) { - const name_c = try allocator.dupeZ(u8, name); - defer allocator.free(name_c); - - const port_c = try std.fmt.allocPrintZ(allocator, "{}", .{port}); - defer allocator.free(port_c); - - const ws2_32 = windows.ws2_32; - const hints = posix.addrinfo{ - .flags = ws2_32.AI.NUMERICSERV, - .family = posix.AF.UNSPEC, - .socktype = posix.SOCK.STREAM, - .protocol = posix.IPPROTO.TCP, - .canonname = null, - .addr = null, - .addrlen = 0, - .next = null, - }; - var res: ?*posix.addrinfo = null; - var first = true; - while (true) { - const rc = ws2_32.getaddrinfo(name_c.ptr, port_c.ptr, &hints, &res); - switch (@as(windows.ws2_32.WinsockError, @enumFromInt(@as(u16, @intCast(rc))))) { - @as(windows.ws2_32.WinsockError, @enumFromInt(0)) => break, - .WSATRY_AGAIN => return error.TemporaryNameServerFailure, - .WSANO_RECOVERY => return error.NameServerFailure, - .WSAEAFNOSUPPORT => return error.AddressFamilyNotSupported, - .WSA_NOT_ENOUGH_MEMORY => return error.OutOfMemory, - .WSAHOST_NOT_FOUND => return error.UnknownHostName, - .WSATYPE_NOT_FOUND => return error.ServiceUnavailable, - .WSAEINVAL => unreachable, - .WSAESOCKTNOSUPPORT => unreachable, - .WSANOTINITIALISED => { - if (!first) return error.Unexpected; - first = false; - try windows.callWSAStartup(); - continue; - }, - else => |err| return windows.unexpectedWSAError(err), - } - } - defer ws2_32.freeaddrinfo(res); - - const addr_count = blk: { - var count: usize = 0; - var it = res; - while (it) |info| : (it = info.next) { - if (info.addr != null) { - count += 1; - } - } - break :blk count; - }; - result.addrs = try arena.alloc(Address, addr_count); - - var it = res; - var i: usize = 0; - while (it) |info| : (it = info.next) { - const addr = info.addr orelse continue; - result.addrs[i] = Address.initPosix(@alignCast(addr)); - - if (info.canonname) |n| { - if (result.canon_name == null) { - result.canon_name = try arena.dupe(u8, mem.sliceTo(n, 0)); - } - } - i += 1; - } - - return result; - } - - if (builtin.link_libc) { - const name_c = try allocator.dupeZ(u8, name); - defer allocator.free(name_c); - - const port_c = try std.fmt.allocPrintZ(allocator, "{}", .{port}); - defer allocator.free(port_c); - - const sys = if (native_os == .windows) windows.ws2_32 else posix.system; - const hints = posix.addrinfo{ - .flags = sys.AI.NUMERICSERV, - .family = posix.AF.UNSPEC, - .socktype = posix.SOCK.STREAM, - .protocol = posix.IPPROTO.TCP, - .canonname = null, - .addr = null, - .addrlen = 0, - .next = null, - }; - var res: ?*posix.addrinfo = null; - switch (sys.getaddrinfo(name_c.ptr, port_c.ptr, &hints, &res)) { - @as(sys.EAI, @enumFromInt(0)) => {}, - .ADDRFAMILY => return error.HostLacksNetworkAddresses, - .AGAIN => return error.TemporaryNameServerFailure, - .BADFLAGS => unreachable, // Invalid hints - .FAIL => return error.NameServerFailure, - .FAMILY => return error.AddressFamilyNotSupported, - .MEMORY => return error.OutOfMemory, - .NODATA => return error.HostLacksNetworkAddresses, - .NONAME => return error.UnknownHostName, - .SERVICE => return error.ServiceUnavailable, - .SOCKTYPE => unreachable, // Invalid socket type requested in hints - .SYSTEM => switch (posix.errno(-1)) { - else => |e| return posix.unexpectedErrno(e), - }, - else => unreachable, - } - defer if (res) |some| sys.freeaddrinfo(some); - - const addr_count = blk: { - var count: usize = 0; - var it = res; - while (it) |info| : (it = info.next) { - if (info.addr != null) { - count += 1; - } - } - break :blk count; - }; - result.addrs = try arena.alloc(Address, addr_count); - - var it = res; - var i: usize = 0; - while (it) |info| : (it = info.next) { - const addr = info.addr orelse continue; - result.addrs[i] = Address.initPosix(@alignCast(addr)); - - if (info.canonname) |n| { - if (result.canon_name == null) { - result.canon_name = try arena.dupe(u8, mem.sliceTo(n, 0)); - } - } - i += 1; - } - - return result; - } - - if (native_os == .linux) { - const flags = std.c.AI.NUMERICSERV; - const family = posix.AF.UNSPEC; - var lookup_addrs = std.ArrayList(LookupAddr).init(allocator); - defer lookup_addrs.deinit(); - - var canon = std.ArrayList(u8).init(arena); - defer canon.deinit(); - - try linuxLookupName(&lookup_addrs, &canon, name, family, flags, port); - - result.addrs = try arena.alloc(Address, lookup_addrs.items.len); - if (canon.items.len != 0) { - result.canon_name = try canon.toOwnedSlice(); - } - - for (lookup_addrs.items, 0..) |lookup_addr, i| { - result.addrs[i] = lookup_addr.addr; - assert(result.addrs[i].getPort() == port); - } - - return result; - } - @compileError("std.net.getAddressList unimplemented for this OS"); -} - -const LookupAddr = struct { - addr: Address, - sortkey: i32 = 0, -}; - -const DAS_USABLE = 0x40000000; -const DAS_MATCHINGSCOPE = 0x20000000; -const DAS_MATCHINGLABEL = 0x10000000; -const DAS_PREC_SHIFT = 20; -const DAS_SCOPE_SHIFT = 16; -const DAS_PREFIX_SHIFT = 8; -const DAS_ORDER_SHIFT = 0; - -fn linuxLookupName( - addrs: *std.ArrayList(LookupAddr), - canon: *std.ArrayList(u8), - opt_name: ?[]const u8, - family: posix.sa_family_t, - flags: u32, - port: u16, -) !void { - if (opt_name) |name| { - // reject empty name and check len so it fits into temp bufs - canon.items.len = 0; - try canon.appendSlice(name); - if (Address.parseExpectingFamily(name, family, port)) |addr| { - try addrs.append(LookupAddr{ .addr = addr }); - } else |name_err| if ((flags & std.c.AI.NUMERICHOST) != 0) { - return name_err; - } else { - try linuxLookupNameFromHosts(addrs, canon, name, family, port); - if (addrs.items.len == 0) { - // RFC 6761 Section 6.3.3 - // Name resolution APIs and libraries SHOULD recognize localhost - // names as special and SHOULD always return the IP loopback address - // for address queries and negative responses for all other query - // types. - - // Check for equal to "localhost(.)" or ends in ".localhost(.)" - const localhost = if (name[name.len - 1] == '.') "localhost." else "localhost"; - if (mem.endsWith(u8, name, localhost) and (name.len == localhost.len or name[name.len - localhost.len] == '.')) { - try addrs.append(LookupAddr{ .addr = .{ .in = Ip4Address.parse("127.0.0.1", port) catch unreachable } }); - try addrs.append(LookupAddr{ .addr = .{ .in6 = Ip6Address.parse("::1", port) catch unreachable } }); - return; - } - - try linuxLookupNameFromDnsSearch(addrs, canon, name, family, port); - } - } - } else { - try canon.resize(0); - try linuxLookupNameFromNull(addrs, family, flags, port); - } - if (addrs.items.len == 0) return error.UnknownHostName; - - // No further processing is needed if there are fewer than 2 - // results or if there are only IPv4 results. - if (addrs.items.len == 1 or family == posix.AF.INET) return; - const all_ip4 = for (addrs.items) |addr| { - if (addr.addr.any.family != posix.AF.INET) break false; - } else true; - if (all_ip4) return; - - // The following implements a subset of RFC 3484/6724 destination - // address selection by generating a single 31-bit sort key for - // each address. Rules 3, 4, and 7 are omitted for having - // excessive runtime and code size cost and dubious benefit. - // So far the label/precedence table cannot be customized. - // This implementation is ported from musl libc. - // A more idiomatic "ziggy" implementation would be welcome. - for (addrs.items, 0..) |*addr, i| { - var key: i32 = 0; - var sa6: posix.sockaddr.in6 = undefined; - @memset(@as([*]u8, @ptrCast(&sa6))[0..@sizeOf(posix.sockaddr.in6)], 0); - var da6 = posix.sockaddr.in6{ - .family = posix.AF.INET6, - .scope_id = addr.addr.in6.sa.scope_id, - .port = 65535, - .flowinfo = 0, - .addr = [1]u8{0} ** 16, - }; - var sa4: posix.sockaddr.in = undefined; - @memset(@as([*]u8, @ptrCast(&sa4))[0..@sizeOf(posix.sockaddr.in)], 0); - var da4 = posix.sockaddr.in{ - .family = posix.AF.INET, - .port = 65535, - .addr = 0, - .zero = [1]u8{0} ** 8, - }; - var sa: *align(4) posix.sockaddr = undefined; - var da: *align(4) posix.sockaddr = undefined; - var salen: posix.socklen_t = undefined; - var dalen: posix.socklen_t = undefined; - if (addr.addr.any.family == posix.AF.INET6) { - da6.addr = addr.addr.in6.sa.addr; - da = @ptrCast(&da6); - dalen = @sizeOf(posix.sockaddr.in6); - sa = @ptrCast(&sa6); - salen = @sizeOf(posix.sockaddr.in6); - } else { - sa6.addr[0..12].* = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xff\xff".*; - da6.addr[0..12].* = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xff\xff".*; - mem.writeInt(u32, da6.addr[12..], addr.addr.in.sa.addr, native_endian); - da4.addr = addr.addr.in.sa.addr; - da = @ptrCast(&da4); - dalen = @sizeOf(posix.sockaddr.in); - sa = @ptrCast(&sa4); - salen = @sizeOf(posix.sockaddr.in); - } - const dpolicy = policyOf(da6.addr); - const dscope: i32 = scopeOf(da6.addr); - const dlabel = dpolicy.label; - const dprec: i32 = dpolicy.prec; - const MAXADDRS = 3; - var prefixlen: i32 = 0; - const sock_flags = posix.SOCK.DGRAM | posix.SOCK.CLOEXEC; - if (posix.socket(addr.addr.any.family, sock_flags, posix.IPPROTO.UDP)) |fd| syscalls: { - defer Stream.close(.{ .handle = fd }); - posix.connect(fd, da, dalen) catch break :syscalls; - key |= DAS_USABLE; - posix.getsockname(fd, sa, &salen) catch break :syscalls; - if (addr.addr.any.family == posix.AF.INET) { - mem.writeInt(u32, sa6.addr[12..16], sa4.addr, native_endian); - } - if (dscope == @as(i32, scopeOf(sa6.addr))) key |= DAS_MATCHINGSCOPE; - if (dlabel == labelOf(sa6.addr)) key |= DAS_MATCHINGLABEL; - prefixlen = prefixMatch(sa6.addr, da6.addr); - } else |_| {} - key |= dprec << DAS_PREC_SHIFT; - key |= (15 - dscope) << DAS_SCOPE_SHIFT; - key |= prefixlen << DAS_PREFIX_SHIFT; - key |= (MAXADDRS - @as(i32, @intCast(i))) << DAS_ORDER_SHIFT; - addr.sortkey = key; - } - mem.sort(LookupAddr, addrs.items, {}, addrCmpLessThan); -} - -const Policy = struct { - addr: [16]u8, - len: u8, - mask: u8, - prec: u8, - label: u8, -}; - -const defined_policies = [_]Policy{ - Policy{ - .addr = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01".*, - .len = 15, - .mask = 0xff, - .prec = 50, - .label = 0, - }, - Policy{ - .addr = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xff\xff\x00\x00\x00\x00".*, - .len = 11, - .mask = 0xff, - .prec = 35, - .label = 4, - }, - Policy{ - .addr = "\x20\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00".*, - .len = 1, - .mask = 0xff, - .prec = 30, - .label = 2, - }, - Policy{ - .addr = "\x20\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00".*, - .len = 3, - .mask = 0xff, - .prec = 5, - .label = 5, - }, - Policy{ - .addr = "\xfc\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00".*, - .len = 0, - .mask = 0xfe, - .prec = 3, - .label = 13, - }, - // These are deprecated and/or returned to the address - // pool, so despite the RFC, treating them as special - // is probably wrong. - // { "", 11, 0xff, 1, 3 }, - // { "\xfe\xc0", 1, 0xc0, 1, 11 }, - // { "\x3f\xfe", 1, 0xff, 1, 12 }, - // Last rule must match all addresses to stop loop. - Policy{ - .addr = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00".*, - .len = 0, - .mask = 0, - .prec = 40, - .label = 1, - }, -}; - -fn policyOf(a: [16]u8) *const Policy { - for (&defined_policies) |*policy| { - if (!mem.eql(u8, a[0..policy.len], policy.addr[0..policy.len])) continue; - if ((a[policy.len] & policy.mask) != policy.addr[policy.len]) continue; - return policy; - } - unreachable; -} - -fn scopeOf(a: [16]u8) u8 { - if (IN6_IS_ADDR_MULTICAST(a)) return a[1] & 15; - if (IN6_IS_ADDR_LINKLOCAL(a)) return 2; - if (IN6_IS_ADDR_LOOPBACK(a)) return 2; - if (IN6_IS_ADDR_SITELOCAL(a)) return 5; - return 14; -} - -fn prefixMatch(s: [16]u8, d: [16]u8) u8 { - // TODO: This FIXME inherited from porting from musl libc. - // I don't want this to go into zig std lib 1.0.0. - - // FIXME: The common prefix length should be limited to no greater - // than the nominal length of the prefix portion of the source - // address. However the definition of the source prefix length is - // not clear and thus this limiting is not yet implemented. - var i: u8 = 0; - while (i < 128 and ((s[i / 8] ^ d[i / 8]) & (@as(u8, 128) >> @as(u3, @intCast(i % 8)))) == 0) : (i += 1) {} - return i; -} - -fn labelOf(a: [16]u8) u8 { - return policyOf(a).label; -} - -fn IN6_IS_ADDR_MULTICAST(a: [16]u8) bool { - return a[0] == 0xff; -} - -fn IN6_IS_ADDR_LINKLOCAL(a: [16]u8) bool { - return a[0] == 0xfe and (a[1] & 0xc0) == 0x80; -} - -fn IN6_IS_ADDR_LOOPBACK(a: [16]u8) bool { - return a[0] == 0 and a[1] == 0 and - a[2] == 0 and - a[12] == 0 and a[13] == 0 and - a[14] == 0 and a[15] == 1; -} - -fn IN6_IS_ADDR_SITELOCAL(a: [16]u8) bool { - return a[0] == 0xfe and (a[1] & 0xc0) == 0xc0; -} - -// Parameters `b` and `a` swapped to make this descending. -fn addrCmpLessThan(context: void, b: LookupAddr, a: LookupAddr) bool { - _ = context; - return a.sortkey < b.sortkey; -} - -fn linuxLookupNameFromNull( - addrs: *std.ArrayList(LookupAddr), - family: posix.sa_family_t, - flags: u32, - port: u16, -) !void { - if ((flags & std.c.AI.PASSIVE) != 0) { - if (family != posix.AF.INET6) { - (try addrs.addOne()).* = LookupAddr{ - .addr = Address.initIp4([1]u8{0} ** 4, port), - }; - } - if (family != posix.AF.INET) { - (try addrs.addOne()).* = LookupAddr{ - .addr = Address.initIp6([1]u8{0} ** 16, port, 0, 0), - }; - } - } else { - if (family != posix.AF.INET6) { - (try addrs.addOne()).* = LookupAddr{ - .addr = Address.initIp4([4]u8{ 127, 0, 0, 1 }, port), - }; - } - if (family != posix.AF.INET) { - (try addrs.addOne()).* = LookupAddr{ - .addr = Address.initIp6(([1]u8{0} ** 15) ++ [1]u8{1}, port, 0, 0), - }; - } - } -} - -fn linuxLookupNameFromHosts( - addrs: *std.ArrayList(LookupAddr), - canon: *std.ArrayList(u8), - name: []const u8, - family: posix.sa_family_t, - port: u16, -) !void { - const file = fs.openFileAbsoluteZ("/etc/hosts", .{}) catch |err| switch (err) { - error.FileNotFound, - error.NotDir, - error.AccessDenied, - => return, - else => |e| return e, - }; - defer file.close(); - - var buffered_reader = std.io.bufferedReader(file.reader()); - const reader = buffered_reader.reader(); - var line_buf: [512]u8 = undefined; - while (reader.readUntilDelimiterOrEof(&line_buf, '\n') catch |err| switch (err) { - error.StreamTooLong => blk: { - // Skip to the delimiter in the reader, to fix parsing - try reader.skipUntilDelimiterOrEof('\n'); - // Use the truncated line. A truncated comment or hostname will be handled correctly. - break :blk &line_buf; - }, - else => |e| return e, - }) |line| { - var split_it = mem.splitScalar(u8, line, '#'); - const no_comment_line = split_it.first(); - - var line_it = mem.tokenizeAny(u8, no_comment_line, " \t"); - const ip_text = line_it.next() orelse continue; - var first_name_text: ?[]const u8 = null; - while (line_it.next()) |name_text| { - if (first_name_text == null) first_name_text = name_text; - if (mem.eql(u8, name_text, name)) { - break; - } - } else continue; - - const addr = Address.parseExpectingFamily(ip_text, family, port) catch |err| switch (err) { - error.Overflow, - error.InvalidEnd, - error.InvalidCharacter, - error.Incomplete, - error.InvalidIPAddressFormat, - error.InvalidIpv4Mapping, - error.NonCanonical, - => continue, - }; - try addrs.append(LookupAddr{ .addr = addr }); - - // first name is canonical name - const name_text = first_name_text.?; - if (isValidHostName(name_text)) { - canon.items.len = 0; - try canon.appendSlice(name_text); - } - } -} - -pub fn isValidHostName(hostname: []const u8) bool { - if (hostname.len >= 254) return false; - if (!std.unicode.utf8ValidateSlice(hostname)) return false; - for (hostname) |byte| { - if (!std.ascii.isASCII(byte) or byte == '.' or byte == '-' or std.ascii.isAlphanumeric(byte)) { - continue; - } - return false; - } - return true; -} - -fn linuxLookupNameFromDnsSearch( - addrs: *std.ArrayList(LookupAddr), - canon: *std.ArrayList(u8), - name: []const u8, - family: posix.sa_family_t, - port: u16, -) !void { - var rc: ResolvConf = undefined; - try getResolvConf(addrs.allocator, &rc); - defer rc.deinit(); - - // Count dots, suppress search when >=ndots or name ends in - // a dot, which is an explicit request for global scope. - var dots: usize = 0; - for (name) |byte| { - if (byte == '.') dots += 1; - } - - const search = if (dots >= rc.ndots or mem.endsWith(u8, name, ".")) - "" - else - rc.search.items; - - var canon_name = name; - - // Strip final dot for canon, fail if multiple trailing dots. - if (mem.endsWith(u8, canon_name, ".")) canon_name.len -= 1; - if (mem.endsWith(u8, canon_name, ".")) return error.UnknownHostName; - - // Name with search domain appended is setup in canon[]. This both - // provides the desired default canonical name (if the requested - // name is not a CNAME record) and serves as a buffer for passing - // the full requested name to name_from_dns. - try canon.resize(canon_name.len); - @memcpy(canon.items, canon_name); - try canon.append('.'); - - var tok_it = mem.tokenizeAny(u8, search, " \t"); - while (tok_it.next()) |tok| { - canon.shrinkRetainingCapacity(canon_name.len + 1); - try canon.appendSlice(tok); - try linuxLookupNameFromDns(addrs, canon, canon.items, family, rc, port); - if (addrs.items.len != 0) return; - } - - canon.shrinkRetainingCapacity(canon_name.len); - return linuxLookupNameFromDns(addrs, canon, name, family, rc, port); -} - -const dpc_ctx = struct { - addrs: *std.ArrayList(LookupAddr), - canon: *std.ArrayList(u8), - port: u16, -}; - -fn linuxLookupNameFromDns( - addrs: *std.ArrayList(LookupAddr), - canon: *std.ArrayList(u8), - name: []const u8, - family: posix.sa_family_t, - rc: ResolvConf, - port: u16, -) !void { - const ctx = dpc_ctx{ - .addrs = addrs, - .canon = canon, - .port = port, - }; - const AfRr = struct { - af: posix.sa_family_t, - rr: u8, - }; - const afrrs = [_]AfRr{ - AfRr{ .af = posix.AF.INET6, .rr = posix.RR.A }, - AfRr{ .af = posix.AF.INET, .rr = posix.RR.AAAA }, - }; - var qbuf: [2][280]u8 = undefined; - var abuf: [2][512]u8 = undefined; - var qp: [2][]const u8 = undefined; - const apbuf = [2][]u8{ &abuf[0], &abuf[1] }; - var nq: usize = 0; - - for (afrrs) |afrr| { - if (family != afrr.af) { - const len = posix.res_mkquery(0, name, 1, afrr.rr, &[_]u8{}, null, &qbuf[nq]); - qp[nq] = qbuf[nq][0..len]; - nq += 1; - } - } - - var ap = [2][]u8{ apbuf[0], apbuf[1] }; - ap[0].len = 0; - ap[1].len = 0; - - try resMSendRc(qp[0..nq], ap[0..nq], apbuf[0..nq], rc); - - var i: usize = 0; - while (i < nq) : (i += 1) { - dnsParse(ap[i], ctx, dnsParseCallback) catch {}; - } - - if (addrs.items.len != 0) return; - if (ap[0].len < 4 or (ap[0][3] & 15) == 2) return error.TemporaryNameServerFailure; - if ((ap[0][3] & 15) == 0) return error.UnknownHostName; - if ((ap[0][3] & 15) == 3) return; - return error.NameServerFailure; -} - -const ResolvConf = struct { - attempts: u32, - ndots: u32, - timeout: u32, - search: std.ArrayList(u8), - ns: std.ArrayList(LookupAddr), - - fn deinit(rc: *ResolvConf) void { - rc.ns.deinit(); - rc.search.deinit(); - rc.* = undefined; - } -}; - -/// Ignores lines longer than 512 bytes. -/// TODO: https://github.com/ziglang/zig/issues/2765 and https://github.com/ziglang/zig/issues/2761 -fn getResolvConf(allocator: mem.Allocator, rc: *ResolvConf) !void { - rc.* = ResolvConf{ - .ns = std.ArrayList(LookupAddr).init(allocator), - .search = std.ArrayList(u8).init(allocator), - .ndots = 1, - .timeout = 5, - .attempts = 2, - }; - errdefer rc.deinit(); - - const file = fs.openFileAbsoluteZ("/etc/resolv.conf", .{}) catch |err| switch (err) { - error.FileNotFound, - error.NotDir, - error.AccessDenied, - => return linuxLookupNameFromNumericUnspec(&rc.ns, "127.0.0.1", 53), - else => |e| return e, - }; - defer file.close(); - - var buf_reader = std.io.bufferedReader(file.reader()); - const stream = buf_reader.reader(); - var line_buf: [512]u8 = undefined; - while (stream.readUntilDelimiterOrEof(&line_buf, '\n') catch |err| switch (err) { - error.StreamTooLong => blk: { - // Skip to the delimiter in the stream, to fix parsing - try stream.skipUntilDelimiterOrEof('\n'); - // Give an empty line to the while loop, which will be skipped. - break :blk line_buf[0..0]; - }, - else => |e| return e, - }) |line| { - const no_comment_line = no_comment_line: { - var split = mem.splitScalar(u8, line, '#'); - break :no_comment_line split.first(); - }; - var line_it = mem.tokenizeAny(u8, no_comment_line, " \t"); - - const token = line_it.next() orelse continue; - if (mem.eql(u8, token, "options")) { - while (line_it.next()) |sub_tok| { - var colon_it = mem.splitScalar(u8, sub_tok, ':'); - const name = colon_it.first(); - const value_txt = colon_it.next() orelse continue; - const value = std.fmt.parseInt(u8, value_txt, 10) catch |err| switch (err) { - // TODO https://github.com/ziglang/zig/issues/11812 - error.Overflow => @as(u8, 255), - error.InvalidCharacter => continue, - }; - if (mem.eql(u8, name, "ndots")) { - rc.ndots = @min(value, 15); - } else if (mem.eql(u8, name, "attempts")) { - rc.attempts = @min(value, 10); - } else if (mem.eql(u8, name, "timeout")) { - rc.timeout = @min(value, 60); - } - } - } else if (mem.eql(u8, token, "nameserver")) { - const ip_txt = line_it.next() orelse continue; - try linuxLookupNameFromNumericUnspec(&rc.ns, ip_txt, 53); - } else if (mem.eql(u8, token, "domain") or mem.eql(u8, token, "search")) { - rc.search.items.len = 0; - try rc.search.appendSlice(line_it.rest()); - } - } - - if (rc.ns.items.len == 0) { - return linuxLookupNameFromNumericUnspec(&rc.ns, "127.0.0.1", 53); - } -} - -fn linuxLookupNameFromNumericUnspec( - addrs: *std.ArrayList(LookupAddr), - name: []const u8, - port: u16, -) !void { - const addr = try Address.resolveIp(name, port); - (try addrs.addOne()).* = LookupAddr{ .addr = addr }; -} - -fn resMSendRc( - queries: []const []const u8, - answers: [][]u8, - answer_bufs: []const []u8, - rc: ResolvConf, -) !void { - const timeout = 1000 * rc.timeout; - const attempts = rc.attempts; - - var sl: posix.socklen_t = @sizeOf(posix.sockaddr.in); - var family: posix.sa_family_t = posix.AF.INET; - - var ns_list = std.ArrayList(Address).init(rc.ns.allocator); - defer ns_list.deinit(); - - try ns_list.resize(rc.ns.items.len); - const ns = ns_list.items; - - for (rc.ns.items, 0..) |iplit, i| { - ns[i] = iplit.addr; - assert(ns[i].getPort() == 53); - if (iplit.addr.any.family != posix.AF.INET) { - family = posix.AF.INET6; - } - } - - const flags = posix.SOCK.DGRAM | posix.SOCK.CLOEXEC | posix.SOCK.NONBLOCK; - const fd = posix.socket(family, flags, 0) catch |err| switch (err) { - error.AddressFamilyNotSupported => blk: { - // Handle case where system lacks IPv6 support - if (family == posix.AF.INET6) { - family = posix.AF.INET; - break :blk try posix.socket(posix.AF.INET, flags, 0); - } - return err; - }, - else => |e| return e, - }; - defer Stream.close(.{ .handle = fd }); - - // Past this point, there are no errors. Each individual query will - // yield either no reply (indicated by zero length) or an answer - // packet which is up to the caller to interpret. - - // Convert any IPv4 addresses in a mixed environment to v4-mapped - if (family == posix.AF.INET6) { - try posix.setsockopt( - fd, - posix.SOL.IPV6, - std.os.linux.IPV6.V6ONLY, - &mem.toBytes(@as(c_int, 0)), - ); - for (0..ns.len) |i| { - if (ns[i].any.family != posix.AF.INET) continue; - mem.writeInt(u32, ns[i].in6.sa.addr[12..], ns[i].in.sa.addr, native_endian); - ns[i].in6.sa.addr[0..12].* = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xff\xff".*; - ns[i].any.family = posix.AF.INET6; - ns[i].in6.sa.flowinfo = 0; - ns[i].in6.sa.scope_id = 0; - } - sl = @sizeOf(posix.sockaddr.in6); - } - - // Get local address and open/bind a socket - var sa: Address = undefined; - @memset(@as([*]u8, @ptrCast(&sa))[0..@sizeOf(Address)], 0); - sa.any.family = family; - try posix.bind(fd, &sa.any, sl); - - var pfd = [1]posix.pollfd{posix.pollfd{ - .fd = fd, - .events = posix.POLL.IN, - .revents = undefined, - }}; - const retry_interval = timeout / attempts; - var next: u32 = 0; - var t2: u64 = @bitCast(std.time.milliTimestamp()); - const t0 = t2; - var t1 = t2 - retry_interval; - - var servfail_retry: usize = undefined; - - outer: while (t2 - t0 < timeout) : (t2 = @as(u64, @bitCast(std.time.milliTimestamp()))) { - if (t2 - t1 >= retry_interval) { - // Query all configured nameservers in parallel - var i: usize = 0; - while (i < queries.len) : (i += 1) { - if (answers[i].len == 0) { - var j: usize = 0; - while (j < ns.len) : (j += 1) { - _ = posix.sendto(fd, queries[i], posix.MSG.NOSIGNAL, &ns[j].any, sl) catch undefined; - } - } - } - t1 = t2; - servfail_retry = 2 * queries.len; - } - - // Wait for a response, or until time to retry - const clamped_timeout = @min(@as(u31, std.math.maxInt(u31)), t1 + retry_interval - t2); - const nevents = posix.poll(&pfd, clamped_timeout) catch 0; - if (nevents == 0) continue; - - while (true) { - var sl_copy = sl; - const rlen = posix.recvfrom(fd, answer_bufs[next], 0, &sa.any, &sl_copy) catch break; - - // Ignore non-identifiable packets - if (rlen < 4) continue; - - // Ignore replies from addresses we didn't send to - var j: usize = 0; - while (j < ns.len and !ns[j].eql(sa)) : (j += 1) {} - if (j == ns.len) continue; - - // Find which query this answer goes with, if any - var i: usize = next; - while (i < queries.len and (answer_bufs[next][0] != queries[i][0] or - answer_bufs[next][1] != queries[i][1])) : (i += 1) - {} - - if (i == queries.len) continue; - if (answers[i].len != 0) continue; - - // Only accept positive or negative responses; - // retry immediately on server failure, and ignore - // all other codes such as refusal. - switch (answer_bufs[next][3] & 15) { - 0, 3 => {}, - 2 => if (servfail_retry != 0) { - servfail_retry -= 1; - _ = posix.sendto(fd, queries[i], posix.MSG.NOSIGNAL, &ns[j].any, sl) catch undefined; - }, - else => continue, - } - - // Store answer in the right slot, or update next - // available temp slot if it's already in place. - answers[i].len = rlen; - if (i == next) { - while (next < queries.len and answers[next].len != 0) : (next += 1) {} - } else { - @memcpy(answer_bufs[i][0..rlen], answer_bufs[next][0..rlen]); - } - - if (next == queries.len) break :outer; - } - } -} - -fn dnsParse( - r: []const u8, - ctx: anytype, - comptime callback: anytype, -) !void { - // This implementation is ported from musl libc. - // A more idiomatic "ziggy" implementation would be welcome. - if (r.len < 12) return error.InvalidDnsPacket; - if ((r[3] & 15) != 0) return; - var p = r.ptr + 12; - var qdcount = r[4] * @as(usize, 256) + r[5]; - var ancount = r[6] * @as(usize, 256) + r[7]; - if (qdcount + ancount > 64) return error.InvalidDnsPacket; - while (qdcount != 0) { - qdcount -= 1; - while (@intFromPtr(p) - @intFromPtr(r.ptr) < r.len and p[0] -% 1 < 127) p += 1; - if (p[0] > 193 or (p[0] == 193 and p[1] > 254) or @intFromPtr(p) > @intFromPtr(r.ptr) + r.len - 6) - return error.InvalidDnsPacket; - p += @as(usize, 5) + @intFromBool(p[0] != 0); - } - while (ancount != 0) { - ancount -= 1; - while (@intFromPtr(p) - @intFromPtr(r.ptr) < r.len and p[0] -% 1 < 127) p += 1; - if (p[0] > 193 or (p[0] == 193 and p[1] > 254) or @intFromPtr(p) > @intFromPtr(r.ptr) + r.len - 6) - return error.InvalidDnsPacket; - p += @as(usize, 1) + @intFromBool(p[0] != 0); - const len = p[8] * @as(usize, 256) + p[9]; - if (@intFromPtr(p) + len > @intFromPtr(r.ptr) + r.len) return error.InvalidDnsPacket; - try callback(ctx, p[1], p[10..][0..len], r); - p += 10 + len; - } -} - -fn dnsParseCallback(ctx: dpc_ctx, rr: u8, data: []const u8, packet: []const u8) !void { - switch (rr) { - posix.RR.A => { - if (data.len != 4) return error.InvalidDnsARecord; - const new_addr = try ctx.addrs.addOne(); - new_addr.* = LookupAddr{ - .addr = Address.initIp4(data[0..4].*, ctx.port), - }; - }, - posix.RR.AAAA => { - if (data.len != 16) return error.InvalidDnsAAAARecord; - const new_addr = try ctx.addrs.addOne(); - new_addr.* = LookupAddr{ - .addr = Address.initIp6(data[0..16].*, ctx.port, 0, 0), - }; - }, - posix.RR.CNAME => { - var tmp: [256]u8 = undefined; - // Returns len of compressed name. strlen to get canon name. - _ = try posix.dn_expand(packet, data, &tmp); - const canon_name = mem.sliceTo(&tmp, 0); - if (isValidHostName(canon_name)) { - ctx.canon.items.len = 0; - try ctx.canon.appendSlice(canon_name); - } - }, - else => return, - } -} - -pub const Stream = struct { - /// Underlying platform-defined type which may or may not be - /// interchangeable with a file system file descriptor. - handle: posix.socket_t, - - pub fn close(s: Stream) void { - switch (native_os) { - .windows => windows.closesocket(s.handle) catch unreachable, - else => posix.close(s.handle), - } - } - - pub const ReadError = posix.ReadError; - pub const WriteError = posix.WriteError; - - pub const Reader = io.Reader(Stream, ReadError, read); - pub const Writer = io.Writer(Stream, WriteError, write); - - pub fn reader(self: Stream) Reader { - return .{ .context = self }; - } - - pub fn writer(self: Stream) Writer { - return .{ .context = self }; - } - - pub fn read(self: Stream, buffer: []u8) ReadError!usize { - if (native_os == .windows) { - return windows.ReadFile(self.handle, buffer, null); - } - - return posix.read(self.handle, buffer); - } - - pub fn readv(s: Stream, iovecs: []const posix.iovec) ReadError!usize { - if (native_os == .windows) { - // TODO improve this to use ReadFileScatter - if (iovecs.len == 0) return @as(usize, 0); - const first = iovecs[0]; - return windows.ReadFile(s.handle, first.base[0..first.len], null); - } - - return posix.readv(s.handle, iovecs); - } - - /// Returns the number of bytes read. If the number read is smaller than - /// `buffer.len`, it means the stream reached the end. Reaching the end of - /// a stream is not an error condition. - pub fn readAll(s: Stream, buffer: []u8) ReadError!usize { - return readAtLeast(s, buffer, buffer.len); - } - - /// Returns the number of bytes read, calling the underlying read function - /// the minimal number of times until the buffer has at least `len` bytes - /// filled. If the number read is less than `len` it means the stream - /// reached the end. Reaching the end of the stream is not an error - /// condition. - pub fn readAtLeast(s: Stream, buffer: []u8, len: usize) ReadError!usize { - assert(len <= buffer.len); - var index: usize = 0; - while (index < len) { - const amt = try s.read(buffer[index..]); - if (amt == 0) break; - index += amt; - } - return index; - } - - /// TODO in evented I/O mode, this implementation incorrectly uses the event loop's - /// file system thread instead of non-blocking. It needs to be reworked to properly - /// use non-blocking I/O. - pub fn write(self: Stream, buffer: []const u8) WriteError!usize { - if (native_os == .windows) { - return windows.WriteFile(self.handle, buffer, null); - } - - return posix.write(self.handle, buffer); - } - - pub fn writeAll(self: Stream, bytes: []const u8) WriteError!void { - var index: usize = 0; - while (index < bytes.len) { - index += try self.write(bytes[index..]); - } - } - - /// See https://github.com/ziglang/zig/issues/7699 - /// See equivalent function: `std.fs.File.writev`. - pub fn writev(self: Stream, iovecs: []const posix.iovec_const) WriteError!usize { - return posix.writev(self.handle, iovecs); - } - - /// The `iovecs` parameter is mutable because this function needs to mutate the fields in - /// order to handle partial writes from the underlying OS layer. - /// See https://github.com/ziglang/zig/issues/7699 - /// See equivalent function: `std.fs.File.writevAll`. - pub fn writevAll(self: Stream, iovecs: []posix.iovec_const) WriteError!void { - if (iovecs.len == 0) return; - - var i: usize = 0; - while (true) { - var amt = try self.writev(iovecs[i..]); - while (amt >= iovecs[i].len) { - amt -= iovecs[i].len; - i += 1; - if (i >= iovecs.len) return; - } - iovecs[i].base += amt; - iovecs[i].len -= amt; - } - } - - pub fn async_read( - self: Stream, - buffer: []u8, - ctx: *Ctx, - comptime cbk: Cbk, - ) !void { - return ctx.loop.recv(Ctx, ctx, cbk, self.handle, buffer); - } - - pub fn async_readv( - s: Stream, - iovecs: []const posix.iovec, - ctx: *Ctx, - comptime cbk: Cbk, - ) ReadError!void { - if (iovecs.len == 0) return; - const first_buffer = iovecs[0].base[0..iovecs[0].len]; - return s.async_read(first_buffer, ctx, cbk); - } - - // TODO: why not take a buffer here? - pub fn async_write(self: Stream, buffer: []const u8, ctx: *Ctx, comptime cbk: Cbk) void { - return ctx.loop.send(Ctx, ctx, cbk, self.handle, buffer); - } - - fn onWriteAll(ctx: *Ctx, res: anyerror!void) anyerror!void { - res catch |err| return ctx.pop(err); - if (ctx.len() < ctx.buf().len) { - const new_buf = ctx.buf()[ctx.len()..]; - ctx.setBuf(new_buf); - return ctx.stream().async_write(new_buf, ctx, onWriteAll); - } - ctx.setBuf(null); - return ctx.pop({}); - } - - pub fn async_writeAll(self: Stream, bytes: []const u8, ctx: *Ctx, comptime cbk: Cbk) !void { - ctx.setBuf(bytes); - try ctx.push(cbk); - self.async_write(bytes, ctx, onWriteAll); - } -}; - -pub const Server = struct { - listen_address: Address, - stream: std.net.Stream, - - pub const Connection = struct { - stream: std.net.Stream, - address: Address, - }; - - pub fn deinit(s: *Server) void { - s.stream.close(); - s.* = undefined; - } - - pub const AcceptError = posix.AcceptError; - - /// Blocks until a client connects to the server. The returned `Connection` has - /// an open stream. - pub fn accept(s: *Server) AcceptError!Connection { - var accepted_addr: Address = undefined; - var addr_len: posix.socklen_t = @sizeOf(Address); - const fd = try posix.accept(s.stream.handle, &accepted_addr.any, &addr_len, posix.SOCK.CLOEXEC); - return .{ - .stream = .{ .handle = fd }, - .address = accepted_addr, - }; - } -}; - -test { - _ = @import("net/test.zig"); - _ = Server; - _ = Stream; - _ = Address; -} - -fn onTcpConnectToHost(ctx: *Ctx, res: anyerror!void) anyerror!void { - res catch |e| switch (e) { - error.ConnectionRefused => { - if (ctx.data.addr_current < ctx.data.list.addrs.len) { - // next iteration of addr - ctx.push(onTcpConnectToHost) catch |er| return ctx.pop(er); - ctx.data.addr_current += 1; - return async_tcpConnectToAddress( - ctx.data.list.addrs[ctx.data.addr_current], - ctx, - onTcpConnectToHost, - ); - } - // end of iteration of addr - ctx.data.list.deinit(); - return ctx.pop(e); - }, - else => { - ctx.data.list.deinit(); - return ctx.pop(std.posix.ConnectError.ConnectionRefused); - }, - }; - // success - ctx.data.list.deinit(); - return ctx.pop({}); -} - -pub fn async_tcpConnectToHost( - allocator: mem.Allocator, - name: []const u8, - port: u16, - ctx: *Ctx, - comptime cbk: Cbk, -) !void { - const list = std.net.getAddressList(allocator, name, port) catch |e| return ctx.pop(e); - if (list.addrs.len == 0) return ctx.pop(error.UnknownHostName); - - ctx.push(cbk) catch |e| return ctx.pop(e); - ctx.data.list = list; - ctx.data.addr_current = 0; - return async_tcpConnectToAddress(list.addrs[0], ctx, onTcpConnectToHost); -} - -pub fn async_tcpConnectToAddress(address: std.net.Address, ctx: *Ctx, comptime cbk: Cbk) !void { - const nonblock = 0; - const sock_flags = posix.SOCK.STREAM | nonblock | - (if (native_os == .windows) 0 else posix.SOCK.CLOEXEC); - const sockfd = try posix.socket(address.any.family, sock_flags, posix.IPPROTO.TCP); - - ctx.data.socket = sockfd; - ctx.push(cbk) catch |e| return ctx.pop(e); - - ctx.loop.connect( - Ctx, - ctx, - setStream, - sockfd, - address, - ); -} - -// requires client.data.socket to be set -fn setStream(ctx: *Ctx, res: anyerror!void) anyerror!void { - res catch |e| return ctx.pop(e); - ctx.data.conn.stream = .{ .handle = ctx.data.socket }; - return ctx.pop({}); -} diff --git a/src/http/async/std/net/test.zig b/src/http/async/std/net/test.zig deleted file mode 100644 index 3e316c54..00000000 --- a/src/http/async/std/net/test.zig +++ /dev/null @@ -1,335 +0,0 @@ -const std = @import("std"); -const builtin = @import("builtin"); -const net = std.net; -const mem = std.mem; -const testing = std.testing; - -test "parse and render IP addresses at comptime" { - if (builtin.os.tag == .wasi) return error.SkipZigTest; - comptime { - var ipAddrBuffer: [16]u8 = undefined; - // Parses IPv6 at comptime - const ipv6addr = net.Address.parseIp("::1", 0) catch unreachable; - var ipv6 = std.fmt.bufPrint(ipAddrBuffer[0..], "{}", .{ipv6addr}) catch unreachable; - try std.testing.expect(std.mem.eql(u8, "::1", ipv6[1 .. ipv6.len - 3])); - - // Parses IPv4 at comptime - const ipv4addr = net.Address.parseIp("127.0.0.1", 0) catch unreachable; - var ipv4 = std.fmt.bufPrint(ipAddrBuffer[0..], "{}", .{ipv4addr}) catch unreachable; - try std.testing.expect(std.mem.eql(u8, "127.0.0.1", ipv4[0 .. ipv4.len - 2])); - - // Returns error for invalid IP addresses at comptime - try testing.expectError(error.InvalidIPAddressFormat, net.Address.parseIp("::123.123.123.123", 0)); - try testing.expectError(error.InvalidIPAddressFormat, net.Address.parseIp("127.01.0.1", 0)); - try testing.expectError(error.InvalidIPAddressFormat, net.Address.resolveIp("::123.123.123.123", 0)); - try testing.expectError(error.InvalidIPAddressFormat, net.Address.resolveIp("127.01.0.1", 0)); - } -} - -test "parse and render IPv6 addresses" { - if (builtin.os.tag == .wasi) return error.SkipZigTest; - - var buffer: [100]u8 = undefined; - const ips = [_][]const u8{ - "FF01:0:0:0:0:0:0:FB", - "FF01::Fb", - "::1", - "::", - "1::", - "2001:db8::", - "::1234:5678", - "2001:db8::1234:5678", - "FF01::FB%1234", - "::ffff:123.5.123.5", - }; - const printed = [_][]const u8{ - "ff01::fb", - "ff01::fb", - "::1", - "::", - "1::", - "2001:db8::", - "::1234:5678", - "2001:db8::1234:5678", - "ff01::fb", - "::ffff:123.5.123.5", - }; - for (ips, 0..) |ip, i| { - const addr = net.Address.parseIp6(ip, 0) catch unreachable; - var newIp = std.fmt.bufPrint(buffer[0..], "{}", .{addr}) catch unreachable; - try std.testing.expect(std.mem.eql(u8, printed[i], newIp[1 .. newIp.len - 3])); - - if (builtin.os.tag == .linux) { - const addr_via_resolve = net.Address.resolveIp6(ip, 0) catch unreachable; - var newResolvedIp = std.fmt.bufPrint(buffer[0..], "{}", .{addr_via_resolve}) catch unreachable; - try std.testing.expect(std.mem.eql(u8, printed[i], newResolvedIp[1 .. newResolvedIp.len - 3])); - } - } - - try testing.expectError(error.InvalidCharacter, net.Address.parseIp6(":::", 0)); - try testing.expectError(error.Overflow, net.Address.parseIp6("FF001::FB", 0)); - try testing.expectError(error.InvalidCharacter, net.Address.parseIp6("FF01::Fb:zig", 0)); - try testing.expectError(error.InvalidEnd, net.Address.parseIp6("FF01:0:0:0:0:0:0:FB:", 0)); - try testing.expectError(error.Incomplete, net.Address.parseIp6("FF01:", 0)); - try testing.expectError(error.InvalidIpv4Mapping, net.Address.parseIp6("::123.123.123.123", 0)); - try testing.expectError(error.Incomplete, net.Address.parseIp6("1", 0)); - // TODO Make this test pass on other operating systems. - if (builtin.os.tag == .linux or comptime builtin.os.tag.isDarwin()) { - try testing.expectError(error.Incomplete, net.Address.resolveIp6("ff01::fb%", 0)); - try testing.expectError(error.Overflow, net.Address.resolveIp6("ff01::fb%wlp3s0s0s0s0s0s0s0s0", 0)); - try testing.expectError(error.Overflow, net.Address.resolveIp6("ff01::fb%12345678901234", 0)); - } -} - -test "invalid but parseable IPv6 scope ids" { - if (builtin.os.tag != .linux and comptime !builtin.os.tag.isDarwin()) { - // Currently, resolveIp6 with alphanumerical scope IDs only works on Linux. - // TODO Make this test pass on other operating systems. - return error.SkipZigTest; - } - - try testing.expectError(error.InterfaceNotFound, net.Address.resolveIp6("ff01::fb%123s45678901234", 0)); -} - -test "parse and render IPv4 addresses" { - if (builtin.os.tag == .wasi) return error.SkipZigTest; - - var buffer: [18]u8 = undefined; - for ([_][]const u8{ - "0.0.0.0", - "255.255.255.255", - "1.2.3.4", - "123.255.0.91", - "127.0.0.1", - }) |ip| { - const addr = net.Address.parseIp4(ip, 0) catch unreachable; - var newIp = std.fmt.bufPrint(buffer[0..], "{}", .{addr}) catch unreachable; - try std.testing.expect(std.mem.eql(u8, ip, newIp[0 .. newIp.len - 2])); - } - - try testing.expectError(error.Overflow, net.Address.parseIp4("256.0.0.1", 0)); - try testing.expectError(error.InvalidCharacter, net.Address.parseIp4("x.0.0.1", 0)); - try testing.expectError(error.InvalidEnd, net.Address.parseIp4("127.0.0.1.1", 0)); - try testing.expectError(error.Incomplete, net.Address.parseIp4("127.0.0.", 0)); - try testing.expectError(error.InvalidCharacter, net.Address.parseIp4("100..0.1", 0)); - try testing.expectError(error.NonCanonical, net.Address.parseIp4("127.01.0.1", 0)); -} - -test "parse and render UNIX addresses" { - if (builtin.os.tag == .wasi) return error.SkipZigTest; - if (!net.has_unix_sockets) return error.SkipZigTest; - - var buffer: [14]u8 = undefined; - const addr = net.Address.initUnix("/tmp/testpath") catch unreachable; - const fmt_addr = std.fmt.bufPrint(buffer[0..], "{}", .{addr}) catch unreachable; - try std.testing.expectEqualSlices(u8, "/tmp/testpath", fmt_addr); - - const too_long = [_]u8{'a'} ** 200; - try testing.expectError(error.NameTooLong, net.Address.initUnix(too_long[0..])); -} - -test "resolve DNS" { - if (builtin.os.tag == .wasi) return error.SkipZigTest; - - if (builtin.os.tag == .windows) { - _ = try std.os.windows.WSAStartup(2, 2); - } - defer { - if (builtin.os.tag == .windows) { - std.os.windows.WSACleanup() catch unreachable; - } - } - - // Resolve localhost, this should not fail. - { - const localhost_v4 = try net.Address.parseIp("127.0.0.1", 80); - const localhost_v6 = try net.Address.parseIp("::2", 80); - - const result = try net.getAddressList(testing.allocator, "localhost", 80); - defer result.deinit(); - for (result.addrs) |addr| { - if (addr.eql(localhost_v4) or addr.eql(localhost_v6)) break; - } else @panic("unexpected address for localhost"); - } - - { - // The tests are required to work even when there is no Internet connection, - // so some of these errors we must accept and skip the test. - const result = net.getAddressList(testing.allocator, "example.com", 80) catch |err| switch (err) { - error.UnknownHostName => return error.SkipZigTest, - error.TemporaryNameServerFailure => return error.SkipZigTest, - else => return err, - }; - result.deinit(); - } -} - -test "listen on a port, send bytes, receive bytes" { - if (builtin.single_threaded) return error.SkipZigTest; - if (builtin.os.tag == .wasi) return error.SkipZigTest; - - if (builtin.os.tag == .windows) { - _ = try std.os.windows.WSAStartup(2, 2); - } - defer { - if (builtin.os.tag == .windows) { - std.os.windows.WSACleanup() catch unreachable; - } - } - - // Try only the IPv4 variant as some CI builders have no IPv6 localhost - // configured. - const localhost = try net.Address.parseIp("127.0.0.1", 0); - - var server = try localhost.listen(.{}); - defer server.deinit(); - - const S = struct { - fn clientFn(server_address: net.Address) !void { - const socket = try net.tcpConnectToAddress(server_address); - defer socket.close(); - - _ = try socket.writer().writeAll("Hello world!"); - } - }; - - const t = try std.Thread.spawn(.{}, S.clientFn, .{server.listen_address}); - defer t.join(); - - var client = try server.accept(); - defer client.stream.close(); - var buf: [16]u8 = undefined; - const n = try client.stream.reader().read(&buf); - - try testing.expectEqual(@as(usize, 12), n); - try testing.expectEqualSlices(u8, "Hello world!", buf[0..n]); -} - -test "listen on an in use port" { - if (builtin.os.tag != .linux and comptime !builtin.os.tag.isDarwin()) { - // TODO build abstractions for other operating systems - return error.SkipZigTest; - } - - const localhost = try net.Address.parseIp("127.0.0.1", 0); - - var server1 = try localhost.listen(.{ .reuse_port = true }); - defer server1.deinit(); - - var server2 = try server1.listen_address.listen(.{ .reuse_port = true }); - defer server2.deinit(); -} - -fn testClientToHost(allocator: mem.Allocator, name: []const u8, port: u16) anyerror!void { - if (builtin.os.tag == .wasi) return error.SkipZigTest; - - const connection = try net.tcpConnectToHost(allocator, name, port); - defer connection.close(); - - var buf: [100]u8 = undefined; - const len = try connection.read(&buf); - const msg = buf[0..len]; - try testing.expect(mem.eql(u8, msg, "hello from server\n")); -} - -fn testClient(addr: net.Address) anyerror!void { - if (builtin.os.tag == .wasi) return error.SkipZigTest; - - const socket_file = try net.tcpConnectToAddress(addr); - defer socket_file.close(); - - var buf: [100]u8 = undefined; - const len = try socket_file.read(&buf); - const msg = buf[0..len]; - try testing.expect(mem.eql(u8, msg, "hello from server\n")); -} - -fn testServer(server: *net.Server) anyerror!void { - if (builtin.os.tag == .wasi) return error.SkipZigTest; - - var client = try server.accept(); - - const stream = client.stream.writer(); - try stream.print("hello from server\n", .{}); -} - -test "listen on a unix socket, send bytes, receive bytes" { - if (builtin.single_threaded) return error.SkipZigTest; - if (!net.has_unix_sockets) return error.SkipZigTest; - - if (builtin.os.tag == .windows) { - _ = try std.os.windows.WSAStartup(2, 2); - } - defer { - if (builtin.os.tag == .windows) { - std.os.windows.WSACleanup() catch unreachable; - } - } - - const socket_path = try generateFileName("socket.unix"); - defer testing.allocator.free(socket_path); - - const socket_addr = try net.Address.initUnix(socket_path); - defer std.fs.cwd().deleteFile(socket_path) catch {}; - - var server = try socket_addr.listen(.{}); - defer server.deinit(); - - const S = struct { - fn clientFn(path: []const u8) !void { - const socket = try net.connectUnixSocket(path); - defer socket.close(); - - _ = try socket.writer().writeAll("Hello world!"); - } - }; - - const t = try std.Thread.spawn(.{}, S.clientFn, .{socket_path}); - defer t.join(); - - var client = try server.accept(); - defer client.stream.close(); - var buf: [16]u8 = undefined; - const n = try client.stream.reader().read(&buf); - - try testing.expectEqual(@as(usize, 12), n); - try testing.expectEqualSlices(u8, "Hello world!", buf[0..n]); -} - -fn generateFileName(base_name: []const u8) ![]const u8 { - const random_bytes_count = 12; - const sub_path_len = comptime std.fs.base64_encoder.calcSize(random_bytes_count); - var random_bytes: [12]u8 = undefined; - std.crypto.random.bytes(&random_bytes); - var sub_path: [sub_path_len]u8 = undefined; - _ = std.fs.base64_encoder.encode(&sub_path, &random_bytes); - return std.fmt.allocPrint(testing.allocator, "{s}-{s}", .{ sub_path[0..], base_name }); -} - -test "non-blocking tcp server" { - if (builtin.os.tag == .wasi) return error.SkipZigTest; - if (true) { - // https://github.com/ziglang/zig/issues/18315 - return error.SkipZigTest; - } - - const localhost = try net.Address.parseIp("127.0.0.1", 0); - var server = localhost.listen(.{ .force_nonblocking = true }); - defer server.deinit(); - - const accept_err = server.accept(); - try testing.expectError(error.WouldBlock, accept_err); - - const socket_file = try net.tcpConnectToAddress(server.listen_address); - defer socket_file.close(); - - var client = try server.accept(); - defer client.stream.close(); - const stream = client.stream.writer(); - try stream.print("hello from server\n", .{}); - - var buf: [100]u8 = undefined; - const len = try socket_file.read(&buf); - const msg = buf[0..len]; - try testing.expect(mem.eql(u8, msg, "hello from server\n")); -} diff --git a/src/http/async/tls.zig/PrivateKey.zig b/src/http/async/tls.zig/PrivateKey.zig deleted file mode 100644 index 0e2b944d..00000000 --- a/src/http/async/tls.zig/PrivateKey.zig +++ /dev/null @@ -1,260 +0,0 @@ -const std = @import("std"); -const Allocator = std.mem.Allocator; -const Certificate = std.crypto.Certificate; -const der = Certificate.der; -const rsa = @import("rsa/rsa.zig"); -const base64 = std.base64.standard.decoderWithIgnore(" \t\r\n"); -const proto = @import("protocol.zig"); - -const max_ecdsa_key_len = 66; - -signature_scheme: proto.SignatureScheme, - -key: union { - rsa: rsa.KeyPair, - ecdsa: [max_ecdsa_key_len]u8, -}, - -const PrivateKey = @This(); - -pub fn fromFile(gpa: Allocator, file: std.fs.File) !PrivateKey { - const buf = try file.readToEndAlloc(gpa, 1024 * 1024); - defer gpa.free(buf); - return try parsePem(buf); -} - -pub fn parsePem(buf: []const u8) !PrivateKey { - const key_start, const key_end, const marker_version = try findKey(buf); - const encoded = std.mem.trim(u8, buf[key_start..key_end], " \t\r\n"); - - // required bytes: - // 2412, 1821, 1236 for rsa 4096, 3072, 2048 bits size keys - var decoded: [4096]u8 = undefined; - const n = try base64.decode(&decoded, encoded); - - if (marker_version == 2) { - return try parseEcDer(decoded[0..n]); - } - return try parseDer(decoded[0..n]); -} - -fn findKey(buf: []const u8) !struct { usize, usize, usize } { - const markers = [_]struct { - begin: []const u8, - end: []const u8, - }{ - .{ .begin = "-----BEGIN PRIVATE KEY-----", .end = "-----END PRIVATE KEY-----" }, - .{ .begin = "-----BEGIN EC PRIVATE KEY-----", .end = "-----END EC PRIVATE KEY-----" }, - }; - - for (markers, 1..) |marker, ver| { - const begin_marker_start = std.mem.indexOfPos(u8, buf, 0, marker.begin) orelse continue; - const key_start = begin_marker_start + marker.begin.len; - const key_end = std.mem.indexOfPos(u8, buf, key_start, marker.end) orelse continue; - - return .{ key_start, key_end, ver }; - } - return error.MissingEndMarker; -} - -// ref: https://asn1js.eu/#MIG2AgEAMBAGByqGSM49AgEGBSuBBAAiBIGeMIGbAgEBBDBKFkVJCtU9FR6egz3yNxKBwXd86cFzMYqyGb8hRc1zVvLdw-So_2FBtITp6jzYmFShZANiAAQ-CH3a1R0V6dFlTK8Rs4M4egrpPtdta0osysO0Zl8mkBiDsTlvJNqeAp7L2ItHgFW8k_CfhgQT6iLDacNMhKC4XOV07r_ePD-mmkvqvRmzfOowHUoVRhCKrOTmF_J9Syc -pub fn parseDer(buf: []const u8) !PrivateKey { - const info = try der.Element.parse(buf, 0); - const version = try der.Element.parse(buf, info.slice.start); - - const algo_seq = try der.Element.parse(buf, version.slice.end); - const algo_cat = try der.Element.parse(buf, algo_seq.slice.start); - - const key_str = try der.Element.parse(buf, algo_seq.slice.end); - const key_seq = try der.Element.parse(buf, key_str.slice.start); - const key_int = try der.Element.parse(buf, key_seq.slice.start); - - const category = try Certificate.parseAlgorithmCategory(buf, algo_cat); - switch (category) { - .rsaEncryption => { - const modulus = try der.Element.parse(buf, key_int.slice.end); - const public_exponent = try der.Element.parse(buf, modulus.slice.end); - const private_exponent = try der.Element.parse(buf, public_exponent.slice.end); - - const public_key = try rsa.PublicKey.fromBytes(content(buf, modulus), content(buf, public_exponent)); - const secret_key = try rsa.SecretKey.fromBytes(public_key.modulus, content(buf, private_exponent)); - const key_pair = rsa.KeyPair{ .public = public_key, .secret = secret_key }; - - return .{ - .signature_scheme = switch (key_pair.public.modulus.bits()) { - 4096 => .rsa_pss_rsae_sha512, - 3072 => .rsa_pss_rsae_sha384, - else => .rsa_pss_rsae_sha256, - }, - .key = .{ .rsa = key_pair }, - }; - }, - .X9_62_id_ecPublicKey => { - const key = try der.Element.parse(buf, key_int.slice.end); - const algo_param = try der.Element.parse(buf, algo_cat.slice.end); - const named_curve = try Certificate.parseNamedCurve(buf, algo_param); - return .{ - .signature_scheme = signatureScheme(named_curve), - .key = .{ .ecdsa = ecdsaKey(buf, key) }, - }; - }, - else => unreachable, - } -} - -// References: -// https://asn1js.eu/#MHcCAQEEINJSRKv8kSKEzLHptfAlg-LGh4_pHHlq0XLf30Q9pcztoAoGCCqGSM49AwEHoUQDQgAEJpmLyp8aGCgyMcFIJaIq_-4V1K6nPpeoih3bT2npeplF9eyXj7rm8eW9Ua6VLhq71mqtMC-YLm-IkORBVq1cuA -// https://www.rfc-editor.org/rfc/rfc5915 -pub fn parseEcDer(bytes: []const u8) !PrivateKey { - const pki_msg = try der.Element.parse(bytes, 0); - const version = try der.Element.parse(bytes, pki_msg.slice.start); - const key = try der.Element.parse(bytes, version.slice.end); - const parameters = try der.Element.parse(bytes, key.slice.end); - const curve = try der.Element.parse(bytes, parameters.slice.start); - const named_curve = try Certificate.parseNamedCurve(bytes, curve); - return .{ - .signature_scheme = signatureScheme(named_curve), - .key = .{ .ecdsa = ecdsaKey(bytes, key) }, - }; -} - -fn signatureScheme(named_curve: Certificate.NamedCurve) proto.SignatureScheme { - return switch (named_curve) { - .X9_62_prime256v1 => .ecdsa_secp256r1_sha256, - .secp384r1 => .ecdsa_secp384r1_sha384, - .secp521r1 => .ecdsa_secp521r1_sha512, - }; -} - -fn ecdsaKey(bytes: []const u8, e: der.Element) [max_ecdsa_key_len]u8 { - const data = content(bytes, e); - var ecdsa_key: [max_ecdsa_key_len]u8 = undefined; - @memcpy(ecdsa_key[0..data.len], data); - return ecdsa_key; -} - -fn content(bytes: []const u8, e: der.Element) []const u8 { - return bytes[e.slice.start..e.slice.end]; -} - -const testing = std.testing; -const testu = @import("testu.zig"); - -test "parse ec pem" { - const data = @embedFile("testdata/ec_private_key.pem"); - var pk = try parsePem(data); - const priv_key = &testu.hexToBytes( - \\ 10 35 3d ca 1b 15 1d 06 aa 71 b8 ef f3 19 22 - \\ 43 78 f3 20 98 1e b1 2f 2b 64 7e 71 d0 30 2a - \\ 90 aa e5 eb 99 c3 90 65 3d c1 26 19 be 3f 08 - \\ 20 9b 01 - ); - try testing.expectEqualSlices(u8, priv_key, pk.key.ecdsa[0..priv_key.len]); - try testing.expectEqual(.ecdsa_secp384r1_sha384, pk.signature_scheme); -} - -test "parse ec prime256v1" { - const data = @embedFile("testdata/ec_prime256v1_private_key.pem"); - var pk = try parsePem(data); - const priv_key = &testu.hexToBytes( - \\ d2 52 44 ab fc 91 22 84 cc b1 e9 b5 f0 25 83 - \\ e2 c6 87 8f e9 1c 79 6a d1 72 df df 44 3d a5 - \\ cc ed - ); - try testing.expectEqualSlices(u8, priv_key, pk.key.ecdsa[0..priv_key.len]); - try testing.expectEqual(.ecdsa_secp256r1_sha256, pk.signature_scheme); -} - -test "parse ec secp384r1" { - const data = @embedFile("testdata/ec_secp384r1_private_key.pem"); - var pk = try parsePem(data); - const priv_key = &testu.hexToBytes( - \\ ee 6d 8a 5e 0d d3 b0 c6 4b 32 40 80 e2 3a de - \\ 8b 1e dd e2 92 db 36 1c db 91 ea ba a1 06 0d - \\ 42 2d d9 a9 dc 05 43 29 f1 78 7c f9 08 af c5 - \\ 03 1f 6d - ); - try testing.expectEqualSlices(u8, priv_key, pk.key.ecdsa[0..priv_key.len]); - try testing.expectEqual(.ecdsa_secp384r1_sha384, pk.signature_scheme); -} - -test "parse ec secp521r1" { - const data = @embedFile("testdata/ec_secp521r1_private_key.pem"); - var pk = try parsePem(data); - const priv_key = &testu.hexToBytes( - \\ 01 f0 2f 5a c7 24 18 ea 68 23 8c 2e a1 b4 b8 - \\ dc f2 11 b2 96 b0 ec 87 80 42 bf de ba f4 96 - \\ 83 8f 9b db c6 60 a7 4c d9 60 3a e4 ba 0b df - \\ ae 24 d3 1b c2 6e 82 a0 88 c1 ed 17 20 0d 3a - \\ f1 c5 7e e8 0b 27 - ); - try testing.expectEqualSlices(u8, priv_key, pk.key.ecdsa[0..priv_key.len]); - try testing.expectEqual(.ecdsa_secp521r1_sha512, pk.signature_scheme); -} - -test "parse rsa pem" { - const data = @embedFile("testdata/rsa_private_key.pem"); - const pk = try parsePem(data); - - // expected results from: - // $ openssl pkey -in testdata/rsa_private_key.pem -text -noout - const modulus = &testu.hexToBytes( - \\ 00 de f7 23 e6 75 cc 6f dd d5 6e 0f 8c 09 f8 - \\ 62 e3 60 1b c0 7d 8c d5 04 50 2c 36 e2 3b f7 - \\ 33 9f a1 14 af be cf 1a 0f 4c f5 cb 39 70 0e - \\ 3b 97 d6 21 f7 48 91 79 ca 7c 68 fc ea 62 a1 - \\ 5a 72 4f 78 57 0e cc f2 a3 50 05 f1 4c ca 51 - \\ 73 10 9a 18 8e 71 f5 b4 c7 3e be 4c ef 37 d4 - \\ 84 4b 82 1c ec 08 a3 cc 07 3d 5c 0b e5 85 3f - \\ fe b6 44 77 8f 3c 6a 2f 33 c3 5d f6 f2 29 46 - \\ 04 25 7e 05 d9 f8 3b 2d a4 40 66 9f 0d 6d 1a - \\ fa bc 0a c5 8b 86 43 30 ef 14 20 41 9d b5 cc - \\ 3e 63 b5 48 04 27 c9 5c d3 62 28 5f f5 b6 e4 - \\ 77 49 99 ac 84 4a a6 67 a5 9a 1a 37 c7 60 4c - \\ ba c1 70 cf 57 64 4a 21 ea 05 53 10 ec 94 71 - \\ 4a 43 04 83 00 aa 5a 28 bc f2 8c 58 14 92 d2 - \\ 83 17 f4 7b 29 0f e7 87 a2 47 b2 53 19 12 23 - \\ fb 4b ce 5a f8 a1 84 f9 b1 f3 bf e3 fa 10 f8 - \\ ad af 87 ce 03 0e a0 2c 13 71 57 c4 55 44 48 - \\ 44 cb - ); - const public_exponent = &testu.hexToBytes("01 00 01"); - const private_exponent = &testu.hexToBytes( - \\ 50 3b 80 98 aa a5 11 50 33 40 32 aa 02 e0 75 - \\ bd 3a 55 62 34 0b 9c 8f bb c5 dd 4e 15 a4 03 - \\ d8 9a 5f 56 4a 84 3d ed 69 95 3d 37 03 02 ac - \\ 21 1c 36 06 c4 ff 4c 63 37 d7 93 c3 48 10 a5 - \\ fa 62 6c 7c 6f 60 02 a4 0f e4 c3 8b 0d 76 b7 - \\ c0 2e a3 4d 86 e6 92 d1 eb db 10 d6 38 31 ea - \\ 15 3d d1 e8 81 c7 67 60 e7 8c 9a df 51 ce d0 - \\ 7a 88 32 b9 c1 54 b8 7d 98 fc d4 23 1a 05 0e - \\ f2 ea e1 72 29 28 2a 68 b7 90 18 80 1c 21 d6 - \\ 36 a8 6b 4a 9c dd 14 b8 9f 85 ee 95 0b f4 c6 - \\ 17 02 aa 4d ea 4d f9 39 d7 dd 9d b4 1d d2 f8 - \\ 92 46 0f 18 41 80 f4 ea 27 55 29 f8 37 59 bf - \\ 43 ec a3 eb 19 ba bc 13 06 95 3d 25 4b c9 72 - \\ cf 41 0a 6f aa cb 79 d4 7b fa b1 09 7c e2 2f - \\ 85 51 44 8b c6 97 8e 46 f9 6b ac 08 87 92 ce - \\ af 0b bf 8c bd 27 51 8f 09 e4 d3 f9 04 ac fa - \\ f2 04 70 3e d9 a6 28 17 c2 2d 74 e9 25 40 02 - \\ 49 - ); - - try testing.expectEqual(.rsa_pss_rsae_sha256, pk.signature_scheme); - const kp = pk.key.rsa; - { - var bytes: [modulus.len]u8 = undefined; - try kp.public.modulus.toBytes(&bytes, .big); - try testing.expectEqualSlices(u8, modulus, &bytes); - } - { - var bytes: [private_exponent.len]u8 = undefined; - try kp.public.public_exponent.toBytes(&bytes, .big); - try testing.expectEqualSlices(u8, public_exponent, bytes[bytes.len - public_exponent.len .. bytes.len]); - } - { - var btytes: [private_exponent.len]u8 = undefined; - try kp.secret.private_exponent.toBytes(&btytes, .big); - try testing.expectEqualSlices(u8, private_exponent, &btytes); - } -} diff --git a/src/http/async/tls.zig/cbc/main.zig b/src/http/async/tls.zig/cbc/main.zig deleted file mode 100644 index 25038445..00000000 --- a/src/http/async/tls.zig/cbc/main.zig +++ /dev/null @@ -1,148 +0,0 @@ -// This file is originally copied from: https://github.com/jedisct1/zig-cbc. -// -// It is modified then to have TLS padding insead of PKCS#7 padding. -// Reference: -// https://datatracker.ietf.org/doc/html/rfc5246/#section-6.2.3.2 -// https://crypto.stackexchange.com/questions/98917/on-the-correctness-of-the-padding-example-of-rfc-5246 -// -// If required padding i n bytes -// PKCS#7 padding is (n...n) -// TLS padding is (n-1...n-1) - n times of n-1 value -// -const std = @import("std"); -const aes = std.crypto.core.aes; -const mem = std.mem; -const debug = std.debug; - -/// CBC mode with TLS 1.2 padding -/// -/// Important: the counter mode doesn't provide authenticated encryption: the ciphertext can be trivially modified without this being detected. -/// If you need authenticated encryption, use anything from `std.crypto.aead` instead. -/// If you really need to use CBC mode, make sure to use a MAC to authenticate the ciphertext. -pub fn CBC(comptime BlockCipher: anytype) type { - const EncryptCtx = aes.AesEncryptCtx(BlockCipher); - const DecryptCtx = aes.AesDecryptCtx(BlockCipher); - - return struct { - const Self = @This(); - - enc_ctx: EncryptCtx, - dec_ctx: DecryptCtx, - - /// Initialize the CBC context with the given key. - pub fn init(key: [BlockCipher.key_bits / 8]u8) Self { - const enc_ctx = BlockCipher.initEnc(key); - const dec_ctx = DecryptCtx.initFromEnc(enc_ctx); - - return Self{ .enc_ctx = enc_ctx, .dec_ctx = dec_ctx }; - } - - /// Return the length of the ciphertext given the length of the plaintext. - pub fn paddedLength(length: usize) usize { - return (std.math.divCeil(usize, length + 1, EncryptCtx.block_length) catch unreachable) * EncryptCtx.block_length; - } - - /// Encrypt the given plaintext for the given IV. - /// The destination buffer must be large enough to hold the padded plaintext. - /// Use the `paddedLength()` function to compute the ciphertext size. - /// IV must be secret and unpredictable. - pub fn encrypt(self: Self, dst: []u8, src: []const u8, iv: [EncryptCtx.block_length]u8) void { - // Note: encryption *could* be parallelized, see https://research.kudelskisecurity.com/2022/11/17/some-aes-cbc-encryption-myth-busting/ - const block_length = EncryptCtx.block_length; - const padded_length = paddedLength(src.len); - debug.assert(dst.len == padded_length); // destination buffer must hold the padded plaintext - var cv = iv; - var i: usize = 0; - while (i + block_length <= src.len) : (i += block_length) { - const in = src[i..][0..block_length]; - for (cv[0..], in) |*x, y| x.* ^= y; - self.enc_ctx.encrypt(&cv, &cv); - @memcpy(dst[i..][0..block_length], &cv); - } - // Last block - var in = [_]u8{0} ** block_length; - const padding_length: u8 = @intCast(padded_length - src.len - 1); - @memset(&in, padding_length); - @memcpy(in[0 .. src.len - i], src[i..]); - for (cv[0..], in) |*x, y| x.* ^= y; - self.enc_ctx.encrypt(&cv, &cv); - @memcpy(dst[i..], cv[0 .. dst.len - i]); - } - - /// Decrypt the given ciphertext for the given IV. - /// The destination buffer must be large enough to hold the plaintext. - /// IV must be secret, unpredictable and match the one used for encryption. - pub fn decrypt(self: Self, dst: []u8, src: []const u8, iv: [DecryptCtx.block_length]u8) !void { - const block_length = DecryptCtx.block_length; - if (src.len != dst.len) { - return error.EncodingError; - } - debug.assert(src.len % block_length == 0); - var i: usize = 0; - var cv = iv; - var out: [block_length]u8 = undefined; - // Decryption could be parallelized - while (i + block_length <= dst.len) : (i += block_length) { - const in = src[i..][0..block_length]; - self.dec_ctx.decrypt(&out, in); - for (&out, cv) |*x, y| x.* ^= y; - cv = in.*; - @memcpy(dst[i..][0..block_length], &out); - } - // Last block - We intentionally don't check the padding to mitigate timing attacks - if (i < dst.len) { - const in = src[i..][0..block_length]; - @memset(&out, 0); - self.dec_ctx.decrypt(&out, in); - for (&out, cv) |*x, y| x.* ^= y; - @memcpy(dst[i..], out[0 .. dst.len - i]); - } - } - }; -} - -test "CBC mode" { - const M = CBC(aes.Aes128); - const key = [_]u8{ 0x2b, 0x7e, 0x15, 0x16, 0x28, 0xae, 0xd2, 0xa6, 0xab, 0xf7, 0x15, 0x88, 0x09, 0xcf, 0x4f, 0x3c }; - const iv = [_]u8{ 0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0a, 0x0b, 0x0c, 0x0d, 0x0e, 0x0f }; - const src_ = "This is a test of AES-CBC that goes on longer than a couple blocks. It is a somewhat long test case to type out!"; - const expected = "\xA0\x8C\x09\x7D\xFF\x42\xB6\x65\x4D\x4B\xC6\x90\x90\x39\xDE\x3D\xC7\xCA\xEB\xF6\x9A\x4F\x09\x97\xC9\x32\xAB\x75\x88\xB7\x57\x17"; - var res: [32]u8 = undefined; - - try comptime std.testing.expect(src_.len / M.paddedLength(1) >= 3); // Ensure that we have at least 3 blocks - - const z = M.init(key); - - // Test encryption and decryption with distinct buffers - var h = std.crypto.hash.sha2.Sha256.init(.{}); - inline for (0..src_.len) |len| { - const src = src_[0..len]; - var dst = [_]u8{0} ** M.paddedLength(src.len); - - z.encrypt(&dst, src, iv); - h.update(&dst); - - var decrypted = [_]u8{0} ** dst.len; - try z.decrypt(&decrypted, &dst, iv); - - const padding = decrypted[decrypted.len - 1] + 1; - try std.testing.expectEqualSlices(u8, src, decrypted[0 .. decrypted.len - padding]); - } - h.final(&res); - try std.testing.expectEqualSlices(u8, expected, &res); - - // Test encryption and decryption with the same buffer - h = std.crypto.hash.sha2.Sha256.init(.{}); - inline for (0..src_.len) |len| { - var buf = [_]u8{0} ** M.paddedLength(len); - @memcpy(buf[0..len], src_[0..len]); - z.encrypt(&buf, buf[0..len], iv); - h.update(&buf); - - try z.decrypt(&buf, &buf, iv); - - try std.testing.expectEqualSlices(u8, src_[0..len], buf[0..len]); - } - h.final(&res); - try std.testing.expectEqualSlices(u8, expected, &res); -} diff --git a/src/http/async/tls.zig/cipher.zig b/src/http/async/tls.zig/cipher.zig deleted file mode 100644 index dbf4a07a..00000000 --- a/src/http/async/tls.zig/cipher.zig +++ /dev/null @@ -1,1004 +0,0 @@ -const std = @import("std"); -const crypto = std.crypto; -const hkdfExpandLabel = crypto.tls.hkdfExpandLabel; - -const Sha1 = crypto.hash.Sha1; -const Sha256 = crypto.hash.sha2.Sha256; -const Sha384 = crypto.hash.sha2.Sha384; - -const record = @import("record.zig"); -const Record = record.Record; -const Transcript = @import("transcript.zig").Transcript; -const proto = @import("protocol.zig"); - -// tls 1.2 cbc cipher types -const CbcAes128Sha1 = CbcType(crypto.core.aes.Aes128, Sha1); -const CbcAes128Sha256 = CbcType(crypto.core.aes.Aes128, Sha256); -const CbcAes256Sha256 = CbcType(crypto.core.aes.Aes256, Sha256); -const CbcAes256Sha384 = CbcType(crypto.core.aes.Aes256, Sha384); -// tls 1.2 gcm cipher types -const Aead12Aes128Gcm = Aead12Type(crypto.aead.aes_gcm.Aes128Gcm); -const Aead12Aes256Gcm = Aead12Type(crypto.aead.aes_gcm.Aes256Gcm); -// tls 1.2 chacha cipher type -const Aead12ChaCha = Aead12ChaChaType(crypto.aead.chacha_poly.ChaCha20Poly1305); -// tls 1.3 cipher types -const Aes128GcmSha256 = Aead13Type(crypto.aead.aes_gcm.Aes128Gcm, Sha256); -const Aes256GcmSha384 = Aead13Type(crypto.aead.aes_gcm.Aes256Gcm, Sha384); -const ChaChaSha256 = Aead13Type(crypto.aead.chacha_poly.ChaCha20Poly1305, Sha256); -const Aegis128Sha256 = Aead13Type(crypto.aead.aegis.Aegis128L, Sha256); - -pub const encrypt_overhead_tls_12: comptime_int = @max( - CbcAes128Sha1.encrypt_overhead, - CbcAes128Sha256.encrypt_overhead, - CbcAes256Sha256.encrypt_overhead, - CbcAes256Sha384.encrypt_overhead, - Aead12Aes128Gcm.encrypt_overhead, - Aead12Aes256Gcm.encrypt_overhead, - Aead12ChaCha.encrypt_overhead, -); -pub const encrypt_overhead_tls_13: comptime_int = @max( - Aes128GcmSha256.encrypt_overhead, - Aes256GcmSha384.encrypt_overhead, - ChaChaSha256.encrypt_overhead, - Aegis128Sha256.encrypt_overhead, -); - -// ref (length): https://www.rfc-editor.org/rfc/rfc8446#section-5.1 -pub const max_cleartext_len = 1 << 14; -// ref (length): https://www.rfc-editor.org/rfc/rfc8446#section-5.2 -// The sum of the lengths of the content and the padding, plus one for the inner -// content type, plus any expansion added by the AEAD algorithm. -pub const max_ciphertext_len = max_cleartext_len + 256; -pub const max_ciphertext_record_len = record.header_len + max_ciphertext_len; - -/// Returns type for cipher suite tag. -fn CipherType(comptime tag: CipherSuite) type { - return switch (tag) { - // tls 1.2 cbc - .ECDHE_ECDSA_WITH_AES_128_CBC_SHA, - .ECDHE_RSA_WITH_AES_128_CBC_SHA, - .RSA_WITH_AES_128_CBC_SHA, - => CbcAes128Sha1, - .ECDHE_ECDSA_WITH_AES_128_CBC_SHA256, - .ECDHE_RSA_WITH_AES_128_CBC_SHA256, - .RSA_WITH_AES_128_CBC_SHA256, - => CbcAes128Sha256, - .ECDHE_ECDSA_WITH_AES_256_CBC_SHA384, - .ECDHE_RSA_WITH_AES_256_CBC_SHA384, - => CbcAes256Sha384, - - // tls 1.2 gcm - .ECDHE_ECDSA_WITH_AES_128_GCM_SHA256, - .ECDHE_RSA_WITH_AES_128_GCM_SHA256, - => Aead12Aes128Gcm, - .ECDHE_ECDSA_WITH_AES_256_GCM_SHA384, - .ECDHE_RSA_WITH_AES_256_GCM_SHA384, - => Aead12Aes256Gcm, - - // tls 1.2 chacha - .ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256, - .ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256, - => Aead12ChaCha, - - // tls 1.3 - .AES_128_GCM_SHA256 => Aes128GcmSha256, - .AES_256_GCM_SHA384 => Aes256GcmSha384, - .CHACHA20_POLY1305_SHA256 => ChaChaSha256, - .AEGIS_128L_SHA256 => Aegis128Sha256, - - else => unreachable, - }; -} - -/// Provides initialization and common encrypt/decrypt methods for all supported -/// ciphers. Tls 1.2 has only application cipher, tls 1.3 has separate cipher -/// for handshake and application. -pub const Cipher = union(CipherSuite) { - // tls 1.2 cbc - ECDHE_ECDSA_WITH_AES_128_CBC_SHA: CipherType(.ECDHE_ECDSA_WITH_AES_128_CBC_SHA), - ECDHE_RSA_WITH_AES_128_CBC_SHA: CipherType(.ECDHE_RSA_WITH_AES_128_CBC_SHA), - RSA_WITH_AES_128_CBC_SHA: CipherType(.RSA_WITH_AES_128_CBC_SHA), - - ECDHE_ECDSA_WITH_AES_128_CBC_SHA256: CipherType(.ECDHE_ECDSA_WITH_AES_128_CBC_SHA256), - ECDHE_RSA_WITH_AES_128_CBC_SHA256: CipherType(.ECDHE_RSA_WITH_AES_128_CBC_SHA256), - RSA_WITH_AES_128_CBC_SHA256: CipherType(.RSA_WITH_AES_128_CBC_SHA256), - - ECDHE_ECDSA_WITH_AES_256_CBC_SHA384: CipherType(.ECDHE_ECDSA_WITH_AES_256_CBC_SHA384), - ECDHE_RSA_WITH_AES_256_CBC_SHA384: CipherType(.ECDHE_RSA_WITH_AES_256_CBC_SHA384), - // tls 1.2 gcm - ECDHE_ECDSA_WITH_AES_128_GCM_SHA256: CipherType(.ECDHE_ECDSA_WITH_AES_128_GCM_SHA256), - ECDHE_ECDSA_WITH_AES_256_GCM_SHA384: CipherType(.ECDHE_RSA_WITH_AES_256_GCM_SHA384), - ECDHE_RSA_WITH_AES_128_GCM_SHA256: CipherType(.ECDHE_RSA_WITH_AES_128_GCM_SHA256), - ECDHE_RSA_WITH_AES_256_GCM_SHA384: CipherType(.ECDHE_RSA_WITH_AES_256_GCM_SHA384), - // tls 1.2 chacha - ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256: CipherType(.ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256), - ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256: CipherType(.ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256), - // tls 1.3 - AES_128_GCM_SHA256: CipherType(.AES_128_GCM_SHA256), - AES_256_GCM_SHA384: CipherType(.AES_256_GCM_SHA384), - CHACHA20_POLY1305_SHA256: CipherType(.CHACHA20_POLY1305_SHA256), - AEGIS_128L_SHA256: CipherType(.AEGIS_128L_SHA256), - - // tls 1.2 application cipher - pub fn initTls12(tag: CipherSuite, key_material: []const u8, side: proto.Side) !Cipher { - switch (tag) { - inline .ECDHE_ECDSA_WITH_AES_128_CBC_SHA, - .ECDHE_RSA_WITH_AES_128_CBC_SHA, - .RSA_WITH_AES_128_CBC_SHA, - .ECDHE_ECDSA_WITH_AES_128_CBC_SHA256, - .ECDHE_RSA_WITH_AES_128_CBC_SHA256, - .RSA_WITH_AES_128_CBC_SHA256, - .ECDHE_ECDSA_WITH_AES_256_CBC_SHA384, - .ECDHE_RSA_WITH_AES_256_CBC_SHA384, - .ECDHE_ECDSA_WITH_AES_128_GCM_SHA256, - .ECDHE_ECDSA_WITH_AES_256_GCM_SHA384, - .ECDHE_RSA_WITH_AES_128_GCM_SHA256, - .ECDHE_RSA_WITH_AES_256_GCM_SHA384, - .ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256, - .ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256, - => |comptime_tag| { - return @unionInit(Cipher, @tagName(comptime_tag), CipherType(comptime_tag).init(key_material, side)); - }, - else => return error.TlsIllegalParameter, - } - } - - // tls 1.3 handshake or application cipher - pub fn initTls13(tag: CipherSuite, secret: Transcript.Secret, side: proto.Side) !Cipher { - return switch (tag) { - inline .AES_128_GCM_SHA256, - .AES_256_GCM_SHA384, - .CHACHA20_POLY1305_SHA256, - .AEGIS_128L_SHA256, - => |comptime_tag| { - return @unionInit(Cipher, @tagName(comptime_tag), CipherType(comptime_tag).init(secret, side)); - }, - else => return error.TlsIllegalParameter, - }; - } - - pub fn encrypt( - c: *Cipher, - buf: []u8, - content_type: proto.ContentType, - cleartext: []const u8, - ) ![]const u8 { - return switch (c.*) { - inline else => |*f| try f.encrypt(buf, content_type, cleartext), - }; - } - - pub fn decrypt( - c: *Cipher, - buf: []u8, - rec: Record, - ) !struct { proto.ContentType, []u8 } { - return switch (c.*) { - inline else => |*f| { - const content_type, const cleartext = try f.decrypt(buf, rec); - if (cleartext.len > max_cleartext_len) return error.TlsRecordOverflow; - return .{ content_type, cleartext }; - }, - }; - } - - pub fn encryptSeq(c: Cipher) u64 { - return switch (c) { - inline else => |f| f.encrypt_seq, - }; - } - - pub fn keyUpdateEncrypt(c: *Cipher) !void { - return switch (c.*) { - inline .AES_128_GCM_SHA256, - .AES_256_GCM_SHA384, - .CHACHA20_POLY1305_SHA256, - .AEGIS_128L_SHA256, - => |*f| f.keyUpdateEncrypt(), - // can't happen on tls 1.2 - else => return error.TlsUnexpectedMessage, - }; - } - pub fn keyUpdateDecrypt(c: *Cipher) !void { - return switch (c.*) { - inline .AES_128_GCM_SHA256, - .AES_256_GCM_SHA384, - .CHACHA20_POLY1305_SHA256, - .AEGIS_128L_SHA256, - => |*f| f.keyUpdateDecrypt(), - // can't happen on tls 1.2 - else => return error.TlsUnexpectedMessage, - }; - } -}; - -fn Aead12Type(comptime AeadType: type) type { - return struct { - const explicit_iv_len = 8; - const key_len = AeadType.key_length; - const auth_tag_len = AeadType.tag_length; - const nonce_len = AeadType.nonce_length; - const iv_len = AeadType.nonce_length - explicit_iv_len; - const encrypt_overhead = record.header_len + explicit_iv_len + auth_tag_len; - - encrypt_key: [key_len]u8, - decrypt_key: [key_len]u8, - encrypt_iv: [iv_len]u8, - decrypt_iv: [iv_len]u8, - encrypt_seq: u64 = 0, - decrypt_seq: u64 = 0, - rnd: std.Random = crypto.random, - - const Self = @This(); - - fn init(key_material: []const u8, side: proto.Side) Self { - const client_key = key_material[0..key_len].*; - const server_key = key_material[key_len..][0..key_len].*; - const client_iv = key_material[2 * key_len ..][0..iv_len].*; - const server_iv = key_material[2 * key_len + iv_len ..][0..iv_len].*; - return .{ - .encrypt_key = if (side == .client) client_key else server_key, - .decrypt_key = if (side == .client) server_key else client_key, - .encrypt_iv = if (side == .client) client_iv else server_iv, - .decrypt_iv = if (side == .client) server_iv else client_iv, - }; - } - - /// Returns encrypted tls record in format: - /// ----------------- buf ---------------------- - /// header | explicit_iv | ciphertext | auth_tag - /// - /// tls record header: 5 bytes - /// explicit_iv: 8 bytes - /// ciphertext: same length as cleartext - /// auth_tag: 16 bytes - pub fn encrypt( - self: *Self, - buf: []u8, - content_type: proto.ContentType, - cleartext: []const u8, - ) ![]const u8 { - const record_len = record.header_len + explicit_iv_len + cleartext.len + auth_tag_len; - if (buf.len < record_len) return error.BufferOverflow; - - const header = buf[0..record.header_len]; - const explicit_iv = buf[record.header_len..][0..explicit_iv_len]; - const ciphertext = buf[record.header_len + explicit_iv_len ..][0..cleartext.len]; - const auth_tag = buf[record.header_len + explicit_iv_len + cleartext.len ..][0..auth_tag_len]; - - header.* = record.header(content_type, explicit_iv_len + cleartext.len + auth_tag_len); - self.rnd.bytes(explicit_iv); - const iv = self.encrypt_iv ++ explicit_iv.*; - const ad = additionalData(self.encrypt_seq, content_type, cleartext.len); - AeadType.encrypt(ciphertext, auth_tag, cleartext, &ad, iv, self.encrypt_key); - self.encrypt_seq +%= 1; - - return buf[0..record_len]; - } - - /// Decrypts payload into cleartext. Returns tls record content type and - /// cleartext. - /// Accepts tls record header and payload: - /// header | ----------- payload --------------- - /// header | explicit_iv | ciphertext | auth_tag - pub fn decrypt( - self: *Self, - buf: []u8, - rec: Record, - ) !struct { proto.ContentType, []u8 } { - const overhead = explicit_iv_len + auth_tag_len; - if (rec.payload.len < overhead) return error.TlsDecryptError; - const cleartext_len = rec.payload.len - overhead; - if (buf.len < cleartext_len) return error.BufferOverflow; - - const explicit_iv = rec.payload[0..explicit_iv_len]; - const ciphertext = rec.payload[explicit_iv_len..][0..cleartext_len]; - const auth_tag = rec.payload[explicit_iv_len + cleartext_len ..][0..auth_tag_len]; - - const iv = self.decrypt_iv ++ explicit_iv.*; - const ad = additionalData(self.decrypt_seq, rec.content_type, cleartext_len); - const cleartext = buf[0..cleartext_len]; - AeadType.decrypt(cleartext, ciphertext, auth_tag.*, &ad, iv, self.decrypt_key) catch return error.TlsDecryptError; - self.decrypt_seq +%= 1; - return .{ rec.content_type, cleartext }; - } - }; -} - -fn Aead12ChaChaType(comptime AeadType: type) type { - return struct { - const key_len = AeadType.key_length; - const auth_tag_len = AeadType.tag_length; - const nonce_len = AeadType.nonce_length; - const encrypt_overhead = record.header_len + auth_tag_len; - - encrypt_key: [key_len]u8, - decrypt_key: [key_len]u8, - encrypt_iv: [nonce_len]u8, - decrypt_iv: [nonce_len]u8, - encrypt_seq: u64 = 0, - decrypt_seq: u64 = 0, - - const Self = @This(); - - fn init(key_material: []const u8, side: proto.Side) Self { - const client_key = key_material[0..key_len].*; - const server_key = key_material[key_len..][0..key_len].*; - const client_iv = key_material[2 * key_len ..][0..nonce_len].*; - const server_iv = key_material[2 * key_len + nonce_len ..][0..nonce_len].*; - return .{ - .encrypt_key = if (side == .client) client_key else server_key, - .decrypt_key = if (side == .client) server_key else client_key, - .encrypt_iv = if (side == .client) client_iv else server_iv, - .decrypt_iv = if (side == .client) server_iv else client_iv, - }; - } - - /// Returns encrypted tls record in format: - /// ------------ buf ------------- - /// header | ciphertext | auth_tag - /// - /// tls record header: 5 bytes - /// ciphertext: same length as cleartext - /// auth_tag: 16 bytes - pub fn encrypt( - self: *Self, - buf: []u8, - content_type: proto.ContentType, - cleartext: []const u8, - ) ![]const u8 { - const record_len = record.header_len + cleartext.len + auth_tag_len; - if (buf.len < record_len) return error.BufferOverflow; - - const ciphertext = buf[record.header_len..][0..cleartext.len]; - const auth_tag = buf[record.header_len + ciphertext.len ..][0..auth_tag_len]; - - const ad = additionalData(self.encrypt_seq, content_type, cleartext.len); - const iv = ivWithSeq(nonce_len, self.encrypt_iv, self.encrypt_seq); - AeadType.encrypt(ciphertext, auth_tag, cleartext, &ad, iv, self.encrypt_key); - self.encrypt_seq +%= 1; - - buf[0..record.header_len].* = record.header(content_type, ciphertext.len + auth_tag.len); - return buf[0..record_len]; - } - - /// Decrypts payload into cleartext. Returns tls record content type and - /// cleartext. - /// Accepts tls record header and payload: - /// header | ----- payload ------- - /// header | ciphertext | auth_tag - pub fn decrypt( - self: *Self, - buf: []u8, - rec: Record, - ) !struct { proto.ContentType, []u8 } { - const overhead = auth_tag_len; - if (rec.payload.len < overhead) return error.TlsDecryptError; - const cleartext_len = rec.payload.len - overhead; - if (buf.len < cleartext_len) return error.BufferOverflow; - - const ciphertext = rec.payload[0..cleartext_len]; - const auth_tag = rec.payload[cleartext_len..][0..auth_tag_len]; - const cleartext = buf[0..cleartext_len]; - - const ad = additionalData(self.decrypt_seq, rec.content_type, cleartext_len); - const iv = ivWithSeq(nonce_len, self.decrypt_iv, self.decrypt_seq); - AeadType.decrypt(cleartext, ciphertext, auth_tag.*, &ad, iv, self.decrypt_key) catch return error.TlsDecryptError; - self.decrypt_seq +%= 1; - return .{ rec.content_type, cleartext }; - } - }; -} - -fn Aead13Type(comptime AeadType: type, comptime Hash: type) type { - return struct { - const Hmac = crypto.auth.hmac.Hmac(Hash); - const Hkdf = crypto.kdf.hkdf.Hkdf(Hmac); - - const key_len = AeadType.key_length; - const auth_tag_len = AeadType.tag_length; - const nonce_len = AeadType.nonce_length; - const digest_len = Hash.digest_length; - const encrypt_overhead = record.header_len + 1 + auth_tag_len; - - encrypt_secret: [digest_len]u8, - decrypt_secret: [digest_len]u8, - encrypt_key: [key_len]u8, - decrypt_key: [key_len]u8, - encrypt_iv: [nonce_len]u8, - decrypt_iv: [nonce_len]u8, - encrypt_seq: u64 = 0, - decrypt_seq: u64 = 0, - - const Self = @This(); - - pub fn init(secret: Transcript.Secret, side: proto.Side) Self { - var self = Self{ - .encrypt_secret = if (side == .client) secret.client[0..digest_len].* else secret.server[0..digest_len].*, - .decrypt_secret = if (side == .server) secret.client[0..digest_len].* else secret.server[0..digest_len].*, - .encrypt_key = undefined, - .decrypt_key = undefined, - .encrypt_iv = undefined, - .decrypt_iv = undefined, - }; - self.keyGenerate(); - return self; - } - - fn keyGenerate(self: *Self) void { - self.encrypt_key = hkdfExpandLabel(Hkdf, self.encrypt_secret, "key", "", key_len); - self.decrypt_key = hkdfExpandLabel(Hkdf, self.decrypt_secret, "key", "", key_len); - self.encrypt_iv = hkdfExpandLabel(Hkdf, self.encrypt_secret, "iv", "", nonce_len); - self.decrypt_iv = hkdfExpandLabel(Hkdf, self.decrypt_secret, "iv", "", nonce_len); - } - - pub fn keyUpdateEncrypt(self: *Self) void { - self.encrypt_secret = hkdfExpandLabel(Hkdf, self.encrypt_secret, "traffic upd", "", digest_len); - self.encrypt_seq = 0; - self.keyGenerate(); - } - - pub fn keyUpdateDecrypt(self: *Self) void { - self.decrypt_secret = hkdfExpandLabel(Hkdf, self.decrypt_secret, "traffic upd", "", digest_len); - self.decrypt_seq = 0; - self.keyGenerate(); - } - - /// Returns encrypted tls record in format: - /// ------------ buf ------------- - /// header | ciphertext | auth_tag - /// - /// tls record header: 5 bytes - /// ciphertext: cleartext len + 1 byte content type - /// auth_tag: 16 bytes - pub fn encrypt( - self: *Self, - buf: []u8, - content_type: proto.ContentType, - cleartext: []const u8, - ) ![]const u8 { - const payload_len = cleartext.len + 1 + auth_tag_len; - const record_len = record.header_len + payload_len; - if (buf.len < record_len) return error.BufferOverflow; - - const header = buf[0..record.header_len]; - header.* = record.header(.application_data, payload_len); - - // Skip @memcpy if cleartext is already part of the buf at right position - if (&cleartext[0] != &buf[record.header_len]) { - @memcpy(buf[record.header_len..][0..cleartext.len], cleartext); - } - buf[record.header_len + cleartext.len] = @intFromEnum(content_type); - const ciphertext = buf[record.header_len..][0 .. cleartext.len + 1]; - const auth_tag = buf[record.header_len + ciphertext.len ..][0..auth_tag_len]; - - const iv = ivWithSeq(nonce_len, self.encrypt_iv, self.encrypt_seq); - AeadType.encrypt(ciphertext, auth_tag, ciphertext, header, iv, self.encrypt_key); - self.encrypt_seq +%= 1; - return buf[0..record_len]; - } - - /// Decrypts payload into cleartext. Returns tls record content type and - /// cleartext. - /// Accepts tls record header and payload: - /// header | ------- payload --------- - /// header | ciphertext | auth_tag - /// header | cleartext + ct | auth_tag - /// Ciphertext after decryption contains cleartext and content type (1 byte). - pub fn decrypt( - self: *Self, - buf: []u8, - rec: Record, - ) !struct { proto.ContentType, []u8 } { - const overhead = auth_tag_len + 1; - if (rec.payload.len < overhead) return error.TlsDecryptError; - const ciphertext_len = rec.payload.len - auth_tag_len; - if (buf.len < ciphertext_len) return error.BufferOverflow; - - const ciphertext = rec.payload[0..ciphertext_len]; - const auth_tag = rec.payload[ciphertext_len..][0..auth_tag_len]; - - const iv = ivWithSeq(nonce_len, self.decrypt_iv, self.decrypt_seq); - AeadType.decrypt(buf[0..ciphertext_len], ciphertext, auth_tag.*, rec.header, iv, self.decrypt_key) catch return error.TlsBadRecordMac; - - // Remove zero bytes padding - var content_type_idx: usize = ciphertext_len - 1; - while (buf[content_type_idx] == 0 and content_type_idx > 0) : (content_type_idx -= 1) {} - - const cleartext = buf[0..content_type_idx]; - const content_type: proto.ContentType = @enumFromInt(buf[content_type_idx]); - self.decrypt_seq +%= 1; - return .{ content_type, cleartext }; - } - }; -} - -fn CbcType(comptime BlockCipher: type, comptime HashType: type) type { - const CBC = @import("cbc/main.zig").CBC(BlockCipher); - return struct { - const mac_len = HashType.digest_length; // 20, 32, 48 bytes for sha1, sha256, sha384 - const key_len = BlockCipher.key_bits / 8; // 16, 32 for Aes128, Aes256 - const iv_len = 16; - const encrypt_overhead = record.header_len + iv_len + mac_len + max_padding; - - pub const Hmac = crypto.auth.hmac.Hmac(HashType); - const paddedLength = CBC.paddedLength; - const max_padding = 16; - - encrypt_secret: [mac_len]u8, - decrypt_secret: [mac_len]u8, - encrypt_key: [key_len]u8, - decrypt_key: [key_len]u8, - encrypt_seq: u64 = 0, - decrypt_seq: u64 = 0, - rnd: std.Random = crypto.random, - - const Self = @This(); - - fn init(key_material: []const u8, side: proto.Side) Self { - const client_secret = key_material[0..mac_len].*; - const server_secret = key_material[mac_len..][0..mac_len].*; - const client_key = key_material[2 * mac_len ..][0..key_len].*; - const server_key = key_material[2 * mac_len + key_len ..][0..key_len].*; - return .{ - .encrypt_secret = if (side == .client) client_secret else server_secret, - .decrypt_secret = if (side == .client) server_secret else client_secret, - .encrypt_key = if (side == .client) client_key else server_key, - .decrypt_key = if (side == .client) server_key else client_key, - }; - } - - /// Returns encrypted tls record in format: - /// ----------------- buf ----------------- - /// header | iv | ------ ciphertext ------- - /// header | iv | cleartext | mac | padding - /// - /// tls record header: 5 bytes - /// iv: 16 bytes - /// ciphertext: cleartext length + mac + padding - /// mac: 20, 32 or 48 (sha1, sha256, sha384) - /// padding: 1-16 bytes - /// - /// Max encrypt buf overhead = iv + mac + padding (1-16) - /// aes_128_cbc_sha => 16 + 20 + 16 = 52 - /// aes_128_cbc_sha256 => 16 + 32 + 16 = 64 - /// aes_256_cbc_sha384 => 16 + 48 + 16 = 80 - pub fn encrypt( - self: *Self, - buf: []u8, - content_type: proto.ContentType, - cleartext: []const u8, - ) ![]const u8 { - const max_record_len = record.header_len + iv_len + cleartext.len + mac_len + max_padding; - if (buf.len < max_record_len) return error.BufferOverflow; - const cleartext_idx = record.header_len + iv_len; // position of cleartext in buf - @memcpy(buf[cleartext_idx..][0..cleartext.len], cleartext); - - { // calculate mac from (ad + cleartext) - // ... | ad | cleartext | mac | ... - // | -- mac msg -- | mac | - const ad = additionalData(self.encrypt_seq, content_type, cleartext.len); - const mac_msg = buf[cleartext_idx - ad.len ..][0 .. ad.len + cleartext.len]; - @memcpy(mac_msg[0..ad.len], &ad); - const mac = buf[cleartext_idx + cleartext.len ..][0..mac_len]; - Hmac.create(mac, mac_msg, &self.encrypt_secret); - self.encrypt_seq +%= 1; - } - - // ... | cleartext | mac | ... - // ... | -- plaintext --- ... - // ... | cleartext | mac | padding - // ... | ------ ciphertext ------- - const unpadded_len = cleartext.len + mac_len; - const padded_len = paddedLength(unpadded_len); - const plaintext = buf[cleartext_idx..][0..unpadded_len]; - const ciphertext = buf[cleartext_idx..][0..padded_len]; - - // Add header and iv at the buf start - // header | iv | ... - buf[0..record.header_len].* = record.header(content_type, iv_len + ciphertext.len); - const iv = buf[record.header_len..][0..iv_len]; - self.rnd.bytes(iv); - - // encrypt plaintext into ciphertext - CBC.init(self.encrypt_key).encrypt(ciphertext, plaintext, iv[0..iv_len].*); - - // header | iv | ------ ciphertext ------- - return buf[0 .. record.header_len + iv_len + ciphertext.len]; - } - - /// Decrypts payload into cleartext. Returns tls record content type and - /// cleartext. - pub fn decrypt( - self: *Self, - buf: []u8, - rec: Record, - ) !struct { proto.ContentType, []u8 } { - if (rec.payload.len < iv_len + mac_len + 1) return error.TlsDecryptError; - - // --------- payload ------------ - // iv | ------ ciphertext ------- - // iv | cleartext | mac | padding - const iv = rec.payload[0..iv_len]; - const ciphertext = rec.payload[iv_len..]; - - if (buf.len < ciphertext.len + additional_data_len) return error.BufferOverflow; - // ---------- buf --------------- - // ad | ------ plaintext -------- - // ad | cleartext | mac | padding - const plaintext = buf[additional_data_len..][0..ciphertext.len]; - // decrypt ciphertext -> plaintext - CBC.init(self.decrypt_key).decrypt(plaintext, ciphertext, iv[0..iv_len].*) catch return error.TlsDecryptError; - - // get padding len from last padding byte - const padding_len = plaintext[plaintext.len - 1] + 1; - if (plaintext.len < mac_len + padding_len) return error.TlsDecryptError; - // split plaintext into cleartext and mac - const cleartext_len = plaintext.len - mac_len - padding_len; - const cleartext = plaintext[0..cleartext_len]; - const mac = plaintext[cleartext_len..][0..mac_len]; - - // write ad to the buf - var ad = additionalData(self.decrypt_seq, rec.content_type, cleartext_len); - @memcpy(buf[0..ad.len], &ad); - const mac_msg = buf[0 .. ad.len + cleartext_len]; - self.decrypt_seq +%= 1; - - // calculate expected mac and compare with received mac - var expected_mac: [mac_len]u8 = undefined; - Hmac.create(&expected_mac, mac_msg, &self.decrypt_secret); - if (!std.mem.eql(u8, &expected_mac, mac)) - return error.TlsBadRecordMac; - - return .{ rec.content_type, cleartext }; - } - }; -} - -// xor lower 8 iv bytes with seq -fn ivWithSeq(comptime nonce_len: usize, iv: [nonce_len]u8, seq: u64) [nonce_len]u8 { - var res = iv; - const buf = res[nonce_len - 8 ..]; - const operand = std.mem.readInt(u64, buf, .big); - std.mem.writeInt(u64, buf, operand ^ seq, .big); - return res; -} - -pub const additional_data_len = record.header_len + @sizeOf(u64); - -fn additionalData(seq: u64, content_type: proto.ContentType, payload_len: usize) [additional_data_len]u8 { - const header = record.header(content_type, payload_len); - var seq_buf: [8]u8 = undefined; - std.mem.writeInt(u64, &seq_buf, seq, .big); - return seq_buf ++ header; -} - -// Cipher suites lists. In the order of preference. -// For the preference using grades priority and rules from Go project. -// https://ciphersuite.info/page/faq/ -// https://github.com/golang/go/blob/73186ba00251b3ed8baaab36e4f5278c7681155b/src/crypto/tls/cipher_suites.go#L226 -pub const cipher_suites = struct { - const tls12_secure = if (crypto.core.aes.has_hardware_support) [_]CipherSuite{ - // recommended - .ECDHE_ECDSA_WITH_AES_128_GCM_SHA256, - .ECDHE_ECDSA_WITH_AES_256_GCM_SHA384, - .ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256, - // secure - .ECDHE_RSA_WITH_AES_128_GCM_SHA256, - .ECDHE_RSA_WITH_AES_256_GCM_SHA384, - .ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256, - } else [_]CipherSuite{ - // recommended - .ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256, - .ECDHE_ECDSA_WITH_AES_128_GCM_SHA256, - .ECDHE_ECDSA_WITH_AES_256_GCM_SHA384, - - // secure - .ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256, - .ECDHE_RSA_WITH_AES_128_GCM_SHA256, - .ECDHE_RSA_WITH_AES_256_GCM_SHA384, - }; - const tls12_week = [_]CipherSuite{ - // week - .ECDHE_ECDSA_WITH_AES_128_CBC_SHA256, - .ECDHE_ECDSA_WITH_AES_256_CBC_SHA384, - .ECDHE_ECDSA_WITH_AES_128_CBC_SHA, - .ECDHE_RSA_WITH_AES_128_CBC_SHA256, - .ECDHE_RSA_WITH_AES_256_CBC_SHA384, - .ECDHE_RSA_WITH_AES_128_CBC_SHA, - - .RSA_WITH_AES_128_CBC_SHA256, - .RSA_WITH_AES_128_CBC_SHA, - }; - pub const tls13_ = if (crypto.core.aes.has_hardware_support) [_]CipherSuite{ - .AES_128_GCM_SHA256, - .AES_256_GCM_SHA384, - .CHACHA20_POLY1305_SHA256, - // Excluded because didn't find server which supports it to test - // .AEGIS_128L_SHA256 - } else [_]CipherSuite{ - .CHACHA20_POLY1305_SHA256, - .AES_128_GCM_SHA256, - .AES_256_GCM_SHA384, - }; - - pub const tls13 = &tls13_; - pub const tls12 = &(tls12_secure ++ tls12_week); - pub const secure = &(tls13_ ++ tls12_secure); - pub const all = &(tls13_ ++ tls12_secure ++ tls12_week); - - pub fn includes(list: []const CipherSuite, cs: CipherSuite) bool { - for (list) |s| { - if (cs == s) return true; - } - return false; - } -}; - -// Week, secure, recommended grades are from https://ciphersuite.info/page/faq/ -pub const CipherSuite = enum(u16) { - // tls 1.2 cbc sha1 - ECDHE_ECDSA_WITH_AES_128_CBC_SHA = 0xc009, // week - ECDHE_RSA_WITH_AES_128_CBC_SHA = 0xc013, // week - RSA_WITH_AES_128_CBC_SHA = 0x002F, // week - // tls 1.2 cbc sha256 - ECDHE_ECDSA_WITH_AES_128_CBC_SHA256 = 0xc023, // week - ECDHE_RSA_WITH_AES_128_CBC_SHA256 = 0xc027, // week - RSA_WITH_AES_128_CBC_SHA256 = 0x003c, // week - // tls 1.2 cbc sha384 - ECDHE_ECDSA_WITH_AES_256_CBC_SHA384 = 0xc024, // week - ECDHE_RSA_WITH_AES_256_CBC_SHA384 = 0xc028, // week - // tls 1.2 gcm - ECDHE_ECDSA_WITH_AES_128_GCM_SHA256 = 0xc02b, // recommended - ECDHE_ECDSA_WITH_AES_256_GCM_SHA384 = 0xc02c, // recommended - ECDHE_RSA_WITH_AES_128_GCM_SHA256 = 0xc02f, // secure - ECDHE_RSA_WITH_AES_256_GCM_SHA384 = 0xc030, // secure - // tls 1.2 chacha - ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256 = 0xcca9, // recommended - ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256 = 0xcca8, // secure - // tls 1.3 (all are recommended) - AES_128_GCM_SHA256 = 0x1301, - AES_256_GCM_SHA384 = 0x1302, - CHACHA20_POLY1305_SHA256 = 0x1303, - AEGIS_128L_SHA256 = 0x1307, - // AEGIS_256_SHA512 = 0x1306, - _, - - pub fn validate(cs: CipherSuite) !void { - if (cipher_suites.includes(cipher_suites.tls12, cs)) return; - if (cipher_suites.includes(cipher_suites.tls13, cs)) return; - return error.TlsIllegalParameter; - } - - pub const Versions = enum { - both, - tls_1_3, - tls_1_2, - }; - - // get tls versions from list of cipher suites - pub fn versions(list: []const CipherSuite) !Versions { - var has_12 = false; - var has_13 = false; - for (list) |cs| { - if (cipher_suites.includes(cipher_suites.tls12, cs)) { - has_12 = true; - } else { - if (cipher_suites.includes(cipher_suites.tls13, cs)) has_13 = true; - } - } - if (has_12 and has_13) return .both; - if (has_12) return .tls_1_2; - if (has_13) return .tls_1_3; - return error.TlsIllegalParameter; - } - - pub const KeyExchangeAlgorithm = enum { - ecdhe, - rsa, - }; - - pub fn keyExchange(s: CipherSuite) KeyExchangeAlgorithm { - return switch (s) { - // Random premaster secret, encrypted with publich key from certificate. - // No server key exchange message. - .RSA_WITH_AES_128_CBC_SHA, - .RSA_WITH_AES_128_CBC_SHA256, - => .rsa, - else => .ecdhe, - }; - } - - pub const HashTag = enum { - sha256, - sha384, - sha512, - }; - - pub inline fn hash(cs: CipherSuite) HashTag { - return switch (cs) { - .ECDHE_RSA_WITH_AES_256_CBC_SHA384, - .ECDHE_RSA_WITH_AES_256_GCM_SHA384, - .ECDHE_ECDSA_WITH_AES_256_CBC_SHA384, - .ECDHE_ECDSA_WITH_AES_256_GCM_SHA384, - .AES_256_GCM_SHA384, - => .sha384, - else => .sha256, - }; - } -}; - -const testing = std.testing; -const testu = @import("testu.zig"); - -test "CipherSuite validate" { - { - const cs: CipherSuite = .AES_256_GCM_SHA384; - try cs.validate(); - try testing.expectEqual(cs.hash(), .sha384); - try testing.expectEqual(cs.keyExchange(), .ecdhe); - } - { - const cs: CipherSuite = .AES_128_GCM_SHA256; - try cs.validate(); - try testing.expectEqual(.sha256, cs.hash()); - try testing.expectEqual(.ecdhe, cs.keyExchange()); - } - for (cipher_suites.tls12) |cs| { - try cs.validate(); - _ = cs.hash(); - _ = cs.keyExchange(); - } -} - -test "CipherSuite versions" { - try testing.expectEqual(.tls_1_3, CipherSuite.versions(&[_]CipherSuite{.AES_128_GCM_SHA256})); - try testing.expectEqual(.both, CipherSuite.versions(&[_]CipherSuite{ .AES_128_GCM_SHA256, .ECDHE_ECDSA_WITH_AES_128_CBC_SHA })); - try testing.expectEqual(.tls_1_2, CipherSuite.versions(&[_]CipherSuite{.RSA_WITH_AES_128_CBC_SHA})); -} - -test "gcm 1.2 encrypt overhead" { - inline for ([_]type{ - Aead12Aes128Gcm, - Aead12Aes256Gcm, - }) |T| { - { - const expected_key_len = switch (T) { - Aead12Aes128Gcm => 16, - Aead12Aes256Gcm => 32, - else => unreachable, - }; - try testing.expectEqual(expected_key_len, T.key_len); - try testing.expectEqual(16, T.auth_tag_len); - try testing.expectEqual(12, T.nonce_len); - try testing.expectEqual(4, T.iv_len); - try testing.expectEqual(29, T.encrypt_overhead); - } - } -} - -test "cbc 1.2 encrypt overhead" { - try testing.expectEqual(85, encrypt_overhead_tls_12); - - inline for ([_]type{ - CbcAes128Sha1, - CbcAes128Sha256, - CbcAes256Sha384, - }) |T| { - switch (T) { - CbcAes128Sha1 => { - try testing.expectEqual(20, T.mac_len); - try testing.expectEqual(16, T.key_len); - try testing.expectEqual(57, T.encrypt_overhead); - }, - CbcAes128Sha256 => { - try testing.expectEqual(32, T.mac_len); - try testing.expectEqual(16, T.key_len); - try testing.expectEqual(69, T.encrypt_overhead); - }, - CbcAes256Sha384 => { - try testing.expectEqual(48, T.mac_len); - try testing.expectEqual(32, T.key_len); - try testing.expectEqual(85, T.encrypt_overhead); - }, - else => unreachable, - } - try testing.expectEqual(16, T.paddedLength(1)); // cbc block padding - try testing.expectEqual(16, T.iv_len); - } -} - -test "overhead tls 1.3" { - try testing.expectEqual(22, encrypt_overhead_tls_13); - try expectOverhead(Aes128GcmSha256, 16, 16, 12, 22); - try expectOverhead(Aes256GcmSha384, 32, 16, 12, 22); - try expectOverhead(ChaChaSha256, 32, 16, 12, 22); - try expectOverhead(Aegis128Sha256, 16, 16, 16, 22); - // and tls 1.2 chacha - try expectOverhead(Aead12ChaCha, 32, 16, 12, 21); -} - -fn expectOverhead(T: type, key_len: usize, auth_tag_len: usize, nonce_len: usize, overhead: usize) !void { - try testing.expectEqual(key_len, T.key_len); - try testing.expectEqual(auth_tag_len, T.auth_tag_len); - try testing.expectEqual(nonce_len, T.nonce_len); - try testing.expectEqual(overhead, T.encrypt_overhead); -} - -test "client/server encryption tls 1.3" { - inline for (cipher_suites.tls13) |cs| { - var buf: [256]u8 = undefined; - testu.fill(&buf); - const secret = Transcript.Secret{ - .client = buf[0..128], - .server = buf[128..], - }; - var client_cipher = try Cipher.initTls13(cs, secret, .client); - var server_cipher = try Cipher.initTls13(cs, secret, .server); - try encryptDecrypt(&client_cipher, &server_cipher); - - try client_cipher.keyUpdateEncrypt(); - try server_cipher.keyUpdateDecrypt(); - try encryptDecrypt(&client_cipher, &server_cipher); - - try client_cipher.keyUpdateDecrypt(); - try server_cipher.keyUpdateEncrypt(); - try encryptDecrypt(&client_cipher, &server_cipher); - } -} - -test "client/server encryption tls 1.2" { - inline for (cipher_suites.tls12) |cs| { - var key_material: [256]u8 = undefined; - testu.fill(&key_material); - var client_cipher = try Cipher.initTls12(cs, &key_material, .client); - var server_cipher = try Cipher.initTls12(cs, &key_material, .server); - try encryptDecrypt(&client_cipher, &server_cipher); - } -} - -fn encryptDecrypt(client_cipher: *Cipher, server_cipher: *Cipher) !void { - const cleartext = - \\ Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do - \\ eiusmod tempor incididunt ut labore et dolore magna aliqua. - ; - var buf: [256]u8 = undefined; - - { // client to server - // encrypt - const encrypted = try client_cipher.encrypt(&buf, .application_data, cleartext); - const expected_encrypted_len = switch (client_cipher.*) { - inline else => |f| brk: { - const T = @TypeOf(f); - break :brk switch (T) { - CbcAes128Sha1, - CbcAes128Sha256, - CbcAes256Sha256, - CbcAes256Sha384, - => record.header_len + T.paddedLength(T.iv_len + cleartext.len + T.mac_len), - Aead12Aes128Gcm, - Aead12Aes256Gcm, - Aead12ChaCha, - Aes128GcmSha256, - Aes256GcmSha384, - ChaChaSha256, - Aegis128Sha256, - => cleartext.len + T.encrypt_overhead, - else => unreachable, - }; - }, - }; - try testing.expectEqual(expected_encrypted_len, encrypted.len); - // decrypt - const content_type, const decrypted = try server_cipher.decrypt(&buf, Record.init(encrypted)); - try testing.expectEqualSlices(u8, cleartext, decrypted); - try testing.expectEqual(.application_data, content_type); - } - // server to client - { - const encrypted = try server_cipher.encrypt(&buf, .application_data, cleartext); - const content_type, const decrypted = try client_cipher.decrypt(&buf, Record.init(encrypted)); - try testing.expectEqualSlices(u8, cleartext, decrypted); - try testing.expectEqual(.application_data, content_type); - } -} diff --git a/src/http/async/tls.zig/connection.zig b/src/http/async/tls.zig/connection.zig deleted file mode 100644 index 7a6afcbe..00000000 --- a/src/http/async/tls.zig/connection.zig +++ /dev/null @@ -1,665 +0,0 @@ -const std = @import("std"); -const assert = std.debug.assert; - -const proto = @import("protocol.zig"); -const record = @import("record.zig"); -const cipher = @import("cipher.zig"); -const Cipher = cipher.Cipher; - -const async_io = @import("../std/http/Client.zig"); -const Cbk = async_io.Cbk; -const Ctx = async_io.Ctx; - -pub fn connection(stream: anytype) Connection(@TypeOf(stream)) { - return .{ - .stream = stream, - .rec_rdr = record.reader(stream), - }; -} - -pub fn Connection(comptime Stream: type) type { - return struct { - stream: Stream, // underlying stream - rec_rdr: record.Reader(Stream), - cipher: Cipher = undefined, - - max_encrypt_seq: u64 = std.math.maxInt(u64) - 1, - key_update_requested: bool = false, - - read_buf: []const u8 = "", - received_close_notify: bool = false, - - const Self = @This(); - - /// Encrypts and writes single tls record to the stream. - fn writeRecord(c: *Self, content_type: proto.ContentType, bytes: []const u8) !void { - assert(bytes.len <= cipher.max_cleartext_len); - var write_buf: [cipher.max_ciphertext_record_len]u8 = undefined; - // If key update is requested send key update message and update - // my encryption keys. - if (c.cipher.encryptSeq() >= c.max_encrypt_seq or @atomicLoad(bool, &c.key_update_requested, .monotonic)) { - @atomicStore(bool, &c.key_update_requested, false, .monotonic); - - // If the request_update field is set to "update_requested", - // then the receiver MUST send a KeyUpdate of its own with - // request_update set to "update_not_requested" prior to sending - // its next Application Data record. This mechanism allows - // either side to force an update to the entire connection, but - // causes an implementation which receives multiple KeyUpdates - // while it is silent to respond with a single update. - // - // rfc: https://datatracker.ietf.org/doc/html/rfc8446#autoid-57 - const key_update = &record.handshakeHeader(.key_update, 1) ++ [_]u8{0}; - const rec = try c.cipher.encrypt(&write_buf, .handshake, key_update); - try c.stream.writeAll(rec); - try c.cipher.keyUpdateEncrypt(); - } - const rec = try c.cipher.encrypt(&write_buf, content_type, bytes); - try c.stream.writeAll(rec); - } - - fn writeAlert(c: *Self, err: anyerror) !void { - const cleartext = proto.alertFromError(err); - var buf: [128]u8 = undefined; - const ciphertext = try c.cipher.encrypt(&buf, .alert, &cleartext); - c.stream.writeAll(ciphertext) catch {}; - } - - /// Returns next record of cleartext data. - /// Can be used in iterator like loop without memcpy to another buffer: - /// while (try client.next()) |buf| { ... } - pub fn next(c: *Self) ReadError!?[]const u8 { - const content_type, const data = c.nextRecord() catch |err| { - try c.writeAlert(err); - return err; - } orelse return null; - if (content_type != .application_data) return error.TlsUnexpectedMessage; - return data; - } - - fn nextRecord(c: *Self) ReadError!?struct { proto.ContentType, []const u8 } { - if (c.eof()) return null; - while (true) { - const content_type, const cleartext = try c.rec_rdr.nextDecrypt(&c.cipher) orelse return null; - - switch (content_type) { - .application_data => {}, - .handshake => { - const handshake_type: proto.Handshake = @enumFromInt(cleartext[0]); - switch (handshake_type) { - // skip new session ticket and read next record - .new_session_ticket => continue, - .key_update => { - if (cleartext.len != 5) return error.TlsDecodeError; - // rfc: Upon receiving a KeyUpdate, the receiver MUST - // update its receiving keys. - try c.cipher.keyUpdateDecrypt(); - const key: proto.KeyUpdateRequest = @enumFromInt(cleartext[4]); - switch (key) { - .update_requested => { - @atomicStore(bool, &c.key_update_requested, true, .monotonic); - }, - .update_not_requested => {}, - else => return error.TlsIllegalParameter, - } - // this record is handled read next - continue; - }, - else => {}, - } - }, - .alert => { - if (cleartext.len < 2) return error.TlsUnexpectedMessage; - try proto.Alert.parse(cleartext[0..2].*).toError(); - // server side clean shutdown - c.received_close_notify = true; - return null; - }, - else => return error.TlsUnexpectedMessage, - } - return .{ content_type, cleartext }; - } - } - - pub fn eof(c: *Self) bool { - return c.received_close_notify and c.read_buf.len == 0; - } - - pub fn close(c: *Self) !void { - if (c.received_close_notify) return; - try c.writeRecord(.alert, &proto.Alert.closeNotify()); - } - - // read, write interface - - pub const ReadError = Stream.ReadError || proto.Alert.Error || - error{ - TlsBadVersion, - TlsUnexpectedMessage, - TlsRecordOverflow, - TlsDecryptError, - TlsDecodeError, - TlsBadRecordMac, - TlsIllegalParameter, - BufferOverflow, - }; - pub const WriteError = Stream.WriteError || - error{ - BufferOverflow, - TlsUnexpectedMessage, - }; - - pub const Reader = std.io.Reader(*Self, ReadError, read); - pub const Writer = std.io.Writer(*Self, WriteError, write); - - pub fn reader(c: *Self) Reader { - return .{ .context = c }; - } - - pub fn writer(c: *Self) Writer { - return .{ .context = c }; - } - - /// Encrypts cleartext and writes it to the underlying stream as single - /// tls record. Max single tls record payload length is 1<<14 (16K) - /// bytes. - pub fn write(c: *Self, bytes: []const u8) WriteError!usize { - const n = @min(bytes.len, cipher.max_cleartext_len); - try c.writeRecord(.application_data, bytes[0..n]); - return n; - } - - /// Encrypts cleartext and writes it to the underlying stream. If needed - /// splits cleartext into multiple tls record. - pub fn writeAll(c: *Self, bytes: []const u8) WriteError!void { - var index: usize = 0; - while (index < bytes.len) { - index += try c.write(bytes[index..]); - } - } - - pub fn read(c: *Self, buffer: []u8) ReadError!usize { - if (c.read_buf.len == 0) { - c.read_buf = try c.next() orelse return 0; - } - const n = @min(c.read_buf.len, buffer.len); - @memcpy(buffer[0..n], c.read_buf[0..n]); - c.read_buf = c.read_buf[n..]; - return n; - } - - /// Returns the number of bytes read. If the number read is smaller than - /// `buffer.len`, it means the stream reached the end. - pub fn readAll(c: *Self, buffer: []u8) ReadError!usize { - return c.readAtLeast(buffer, buffer.len); - } - - /// Returns the number of bytes read, calling the underlying read function - /// the minimal number of times until the buffer has at least `len` bytes - /// filled. If the number read is less than `len` it means the stream - /// reached the end. - pub fn readAtLeast(c: *Self, buffer: []u8, len: usize) ReadError!usize { - assert(len <= buffer.len); - var index: usize = 0; - while (index < len) { - const amt = try c.read(buffer[index..]); - if (amt == 0) break; - index += amt; - } - return index; - } - - /// Returns the number of bytes read. If the number read is less than - /// the space provided it means the stream reached the end. - pub fn readv(c: *Self, iovecs: []std.posix.iovec) !usize { - var vp: VecPut = .{ .iovecs = iovecs }; - while (true) { - if (c.read_buf.len == 0) { - c.read_buf = try c.next() orelse break; - } - const n = vp.put(c.read_buf); - const read_buf_len = c.read_buf.len; - c.read_buf = c.read_buf[n..]; - if ((n < read_buf_len) or - (n == read_buf_len and !c.rec_rdr.hasMore())) - break; - } - return vp.total; - } - - fn onWriteAll(ctx: *Ctx, res: anyerror!void) anyerror!void { - res catch |err| return ctx.pop(err); - - if (ctx._tls_write_bytes.len - ctx._tls_write_index > 0) { - const rec = ctx.conn().tls_client.prepareRecord(ctx.stream(), ctx) catch |err| return ctx.pop(err); - return ctx.stream().async_writeAll(rec, ctx, onWriteAll) catch |err| return ctx.pop(err); - } - - return ctx.pop({}); - } - - pub fn async_writeAll(c: *Self, stream: anytype, bytes: []const u8, ctx: *Ctx, comptime cbk: Cbk) !void { - assert(bytes.len <= cipher.max_cleartext_len); - - ctx._tls_write_bytes = bytes; - ctx._tls_write_index = 0; - const rec = try c.prepareRecord(stream, ctx); - - try ctx.push(cbk); - return stream.async_writeAll(rec, ctx, onWriteAll); - } - - fn prepareRecord(c: *Self, stream: anytype, ctx: *Ctx) ![]const u8 { - const len = @min(ctx._tls_write_bytes.len - ctx._tls_write_index, cipher.max_cleartext_len); - - // If key update is requested send key update message and update - // my encryption keys. - if (c.cipher.encryptSeq() >= c.max_encrypt_seq or @atomicLoad(bool, &c.key_update_requested, .monotonic)) { - @atomicStore(bool, &c.key_update_requested, false, .monotonic); - - // If the request_update field is set to "update_requested", - // then the receiver MUST send a KeyUpdate of its own with - // request_update set to "update_not_requested" prior to sending - // its next Application Data record. This mechanism allows - // either side to force an update to the entire connection, but - // causes an implementation which receives multiple KeyUpdates - // while it is silent to respond with a single update. - // - // rfc: https://datatracker.ietf.org/doc/html/rfc8446#autoid-57 - const key_update = &record.handshakeHeader(.key_update, 1) ++ [_]u8{0}; - const rec = try c.cipher.encrypt(&ctx._tls_write_buf, .handshake, key_update); - try stream.writeAll(rec); // TODO async - try c.cipher.keyUpdateEncrypt(); - } - - defer ctx._tls_write_index += len; - return c.cipher.encrypt(&ctx._tls_write_buf, .application_data, ctx._tls_write_bytes[ctx._tls_write_index..len]); - } - - fn onReadv(ctx: *Ctx, res: anyerror!void) anyerror!void { - res catch |err| return ctx.pop(err); - - if (ctx._tls_read_buf == null) { - // end of read - ctx.setLen(ctx._vp.total); - return ctx.pop({}); - } - - while (true) { - const n = ctx._vp.put(ctx._tls_read_buf.?); - const read_buf_len = ctx._tls_read_buf.?.len; - const c = ctx.conn().tls_client; - - if (read_buf_len == 0) { - // read another buffer - return c.async_next(ctx.stream(), ctx, onReadv) catch |err| return ctx.pop(err); - } - - ctx._tls_read_buf = ctx._tls_read_buf.?[n..]; - - if ((n < read_buf_len) or (n == read_buf_len and !c.rec_rdr.hasMore())) { - // end of read - ctx.setLen(ctx._vp.total); - return ctx.pop({}); - } - } - } - - pub fn async_readv(c: *Self, stream: anytype, iovecs: []std.posix.iovec, ctx: *Ctx, comptime cbk: Cbk) !void { - try ctx.push(cbk); - ctx._vp = .{ .iovecs = iovecs }; - - return c.async_next(stream, ctx, onReadv); - } - - fn onNext(ctx: *Ctx, res: anyerror!void) anyerror!void { - res catch |err| { - ctx.conn().tls_client.writeAlert(err) catch |e| std.log.err("onNext: write alert: {any}", .{e}); // TODO async - return ctx.pop(err); - }; - - if (ctx._tls_read_content_type != .application_data) { - return ctx.pop(error.TlsUnexpectedMessage); - } - - return ctx.pop({}); - } - - pub fn async_next(c: *Self, stream: anytype, ctx: *Ctx, comptime cbk: Cbk) !void { - try ctx.push(cbk); - - return c.async_next_decrypt(stream, ctx, onNext); - } - - pub fn onNextDecrypt(ctx: *Ctx, res: anyerror!void) anyerror!void { - res catch |err| return ctx.pop(err); - - const c = ctx.conn().tls_client; - // TOOD not sure if this works in my async case... - if (c.eof()) { - ctx._tls_read_buf = null; - return ctx.pop({}); - } - - const content_type = ctx._tls_read_content_type; - - switch (content_type) { - .application_data => {}, - .handshake => { - const handshake_type: proto.Handshake = @enumFromInt(ctx._tls_read_buf.?[0]); - switch (handshake_type) { - // skip new session ticket and read next record - .new_session_ticket => return c.async_next_record(ctx.stream(), ctx, onNextDecrypt) catch |err| return ctx.pop(err), - .key_update => { - if (ctx._tls_read_buf.?.len != 5) return ctx.pop(error.TlsDecodeError); - // rfc: Upon receiving a KeyUpdate, the receiver MUST - // update its receiving keys. - try c.cipher.keyUpdateDecrypt(); - const key: proto.KeyUpdateRequest = @enumFromInt(ctx._tls_read_buf.?[4]); - switch (key) { - .update_requested => { - @atomicStore(bool, &c.key_update_requested, true, .monotonic); - }, - .update_not_requested => {}, - else => return ctx.pop(error.TlsIllegalParameter), - } - // this record is handled read next - c.async_next_record(ctx.stream(), ctx, onNextDecrypt) catch |err| return ctx.pop(err); - }, - else => {}, - } - }, - .alert => { - if (ctx._tls_read_buf.?.len < 2) return ctx.pop(error.TlsUnexpectedMessage); - try proto.Alert.parse(ctx._tls_read_buf.?[0..2].*).toError(); - // server side clean shutdown - c.received_close_notify = true; - ctx._tls_read_buf = null; - return ctx.pop({}); - }, - else => return ctx.pop(error.TlsUnexpectedMessage), - } - - return ctx.pop({}); - } - - pub fn async_next_decrypt(c: *Self, stream: anytype, ctx: *Ctx, comptime cbk: Cbk) !void { - try ctx.push(cbk); - - return c.async_next_record(stream, ctx, onNextDecrypt) catch |err| return ctx.pop(err); - } - - pub fn onNextRecord(ctx: *Ctx, res: anyerror!void) anyerror!void { - res catch |err| return ctx.pop(err); - - const rec = ctx._tls_read_record orelse { - ctx._tls_read_buf = null; - return ctx.pop({}); - }; - - if (rec.protocol_version != .tls_1_2) return error.TlsBadVersion; - - const c = ctx.conn().tls_client; - const cph = &c.cipher; - - ctx._tls_read_content_type, ctx._tls_read_buf = cph.decrypt( - // Reuse reader buffer for cleartext. `rec.header` and - // `rec.payload`(ciphertext) are also pointing somewhere in - // this buffer. Decrypter is first reading then writing a - // block, cleartext has less length then ciphertext, - // cleartext starts from the beginning of the buffer, so - // ciphertext is always ahead of cleartext. - c.rec_rdr.buffer[0..c.rec_rdr.start], - rec, - ) catch |err| return ctx.pop(err); - - return ctx.pop({}); - } - - pub fn async_next_record(c: *Self, stream: anytype, ctx: *Ctx, comptime cbk: Cbk) !void { - try ctx.push(cbk); - - return c.async_reader_next(stream, ctx, onNextRecord); - } - - pub fn onReaderNext(ctx: *Ctx, res: anyerror!void) anyerror!void { - res catch |err| return ctx.pop(err); - - const c = ctx.conn().tls_client; - - const n = ctx.len(); - if (n == 0) { - ctx._tls_read_record = null; - return ctx.pop({}); - } - c.rec_rdr.end += n; - - return c.readNext(ctx); - } - - pub fn readNext(c: *Self, ctx: *Ctx) anyerror!void { - const buffer = c.rec_rdr.buffer[c.rec_rdr.start..c.rec_rdr.end]; - // If we have 5 bytes header. - if (buffer.len >= record.header_len) { - const record_header = buffer[0..record.header_len]; - const payload_len = std.mem.readInt(u16, record_header[3..5], .big); - if (payload_len > cipher.max_ciphertext_len) - return error.TlsRecordOverflow; - const record_len = record.header_len + payload_len; - // If we have whole record - if (buffer.len >= record_len) { - c.rec_rdr.start += record_len; - ctx._tls_read_record = record.Record.init(buffer[0..record_len]); - return ctx.pop({}); - } - } - { // Move dirty part to the start of the buffer. - const n = c.rec_rdr.end - c.rec_rdr.start; - if (n > 0 and c.rec_rdr.start > 0) { - if (c.rec_rdr.start > n) { - @memcpy(c.rec_rdr.buffer[0..n], c.rec_rdr.buffer[c.rec_rdr.start..][0..n]); - } else { - std.mem.copyForwards(u8, c.rec_rdr.buffer[0..n], c.rec_rdr.buffer[c.rec_rdr.start..][0..n]); - } - } - c.rec_rdr.start = 0; - c.rec_rdr.end = n; - } - // Read more from inner_reader. - return ctx.stream() - .async_read(c.rec_rdr.buffer[c.rec_rdr.end..], ctx, onReaderNext) catch |err| return ctx.pop(err); - } - - pub fn async_reader_next(c: *Self, _: anytype, ctx: *Ctx, comptime cbk: Cbk) !void { - try ctx.push(cbk); - return c.readNext(ctx); - } - }; -} - -const testing = std.testing; -const data12 = @import("testdata/tls12.zig"); -const testu = @import("testu.zig"); - -test "encrypt decrypt" { - var output_buf: [1024]u8 = undefined; - const stream = testu.Stream.init(&(data12.server_pong ** 3), &output_buf); - var conn: Connection(@TypeOf(stream)) = .{ .stream = stream, .rec_rdr = record.reader(stream) }; - conn.cipher = try Cipher.initTls12(.ECDHE_RSA_WITH_AES_128_CBC_SHA, &data12.key_material, .client); - conn.cipher.ECDHE_RSA_WITH_AES_128_CBC_SHA.rnd = testu.random(0); // use fixed rng - - conn.stream.output.reset(); - { // encrypt verify data from example - _ = testu.random(0x40); // sets iv to 40, 41, ... 4f - try conn.writeRecord(.handshake, &data12.client_finished); - try testing.expectEqualSlices(u8, &data12.verify_data_encrypted_msg, conn.stream.output.getWritten()); - } - - conn.stream.output.reset(); - { // encrypt ping - const cleartext = "ping"; - _ = testu.random(0); // sets iv to 00, 01, ... 0f - //conn.encrypt_seq = 1; - - try conn.writeAll(cleartext); - try testing.expectEqualSlices(u8, &data12.encrypted_ping_msg, conn.stream.output.getWritten()); - } - { // decrypt server pong message - conn.cipher.ECDHE_RSA_WITH_AES_128_CBC_SHA.decrypt_seq = 1; - try testing.expectEqualStrings("pong", (try conn.next()).?); - } - { // test reader interface - conn.cipher.ECDHE_RSA_WITH_AES_128_CBC_SHA.decrypt_seq = 1; - var rdr = conn.reader(); - var buffer: [4]u8 = undefined; - const n = try rdr.readAll(&buffer); - try testing.expectEqualStrings("pong", buffer[0..n]); - } - { // test readv interface - conn.cipher.ECDHE_RSA_WITH_AES_128_CBC_SHA.decrypt_seq = 1; - var buffer: [9]u8 = undefined; - var iovecs = [_]std.posix.iovec{ - .{ .base = &buffer, .len = 3 }, - .{ .base = buffer[3..], .len = 3 }, - .{ .base = buffer[6..], .len = 3 }, - }; - const n = try conn.readv(iovecs[0..]); - try testing.expectEqual(4, n); - try testing.expectEqualStrings("pong", buffer[0..n]); - } -} - -// Copied from: https://github.com/ziglang/zig/blob/455899668b620dfda40252501c748c0a983555bd/lib/std/crypto/tls/Client.zig#L1354 -/// Abstraction for sending multiple byte buffers to a slice of iovecs. -pub const VecPut = struct { - iovecs: []const std.posix.iovec, - idx: usize = 0, - off: usize = 0, - total: usize = 0, - - /// Returns the amount actually put which is always equal to bytes.len - /// unless the vectors ran out of space. - pub fn put(vp: *VecPut, bytes: []const u8) usize { - if (vp.idx >= vp.iovecs.len) return 0; - var bytes_i: usize = 0; - while (true) { - const v = vp.iovecs[vp.idx]; - const dest = v.base[vp.off..v.len]; - const src = bytes[bytes_i..][0..@min(dest.len, bytes.len - bytes_i)]; - @memcpy(dest[0..src.len], src); - bytes_i += src.len; - vp.off += src.len; - if (vp.off >= v.len) { - vp.off = 0; - vp.idx += 1; - if (vp.idx >= vp.iovecs.len) { - vp.total += bytes_i; - return bytes_i; - } - } - if (bytes_i >= bytes.len) { - vp.total += bytes_i; - return bytes_i; - } - } - } -}; - -test "client/server connection" { - const BufReaderWriter = struct { - buf: []u8, - wp: usize = 0, - rp: usize = 0, - - const Self = @This(); - - pub fn write(self: *Self, bytes: []const u8) !usize { - if (self.wp == self.buf.len) return error.NoSpaceLeft; - - const n = @min(bytes.len, self.buf.len - self.wp); - @memcpy(self.buf[self.wp..][0..n], bytes[0..n]); - self.wp += n; - return n; - } - - pub fn writeAll(self: *Self, bytes: []const u8) !void { - var n: usize = 0; - while (n < bytes.len) { - n += try self.write(bytes[n..]); - } - } - - pub fn read(self: *Self, bytes: []u8) !usize { - const n = @min(bytes.len, self.wp - self.rp); - if (n == 0) return 0; - @memcpy(bytes[0..n], self.buf[self.rp..][0..n]); - self.rp += n; - if (self.rp == self.wp) { - self.wp = 0; - self.rp = 0; - } - return n; - } - }; - - const TestStream = struct { - inner_stream: *BufReaderWriter, - const Self = @This(); - pub const ReadError = error{}; - pub const WriteError = error{NoSpaceLeft}; - pub fn read(self: *Self, bytes: []u8) !usize { - return try self.inner_stream.read(bytes); - } - pub fn writeAll(self: *Self, bytes: []const u8) !void { - return try self.inner_stream.writeAll(bytes); - } - }; - - const buf_len = 32 * 1024; - const tls_records_in_buf = (std.math.divCeil(comptime_int, buf_len, cipher.max_cleartext_len) catch unreachable); - const overhead: usize = tls_records_in_buf * @import("cipher.zig").encrypt_overhead_tls_13; - var buf: [buf_len + overhead]u8 = undefined; - var inner_stream = BufReaderWriter{ .buf = &buf }; - - const cipher_client, const cipher_server = brk: { - const Transcript = @import("transcript.zig").Transcript; - const CipherSuite = @import("cipher.zig").CipherSuite; - const cipher_suite: CipherSuite = .AES_256_GCM_SHA384; - - var rnd: [128]u8 = undefined; - std.crypto.random.bytes(&rnd); - const secret = Transcript.Secret{ - .client = rnd[0..64], - .server = rnd[64..], - }; - - break :brk .{ - try Cipher.initTls13(cipher_suite, secret, .client), - try Cipher.initTls13(cipher_suite, secret, .server), - }; - }; - - var conn1 = connection(TestStream{ .inner_stream = &inner_stream }); - conn1.cipher = cipher_client; - - var conn2 = connection(TestStream{ .inner_stream = &inner_stream }); - conn2.cipher = cipher_server; - - var prng = std.Random.DefaultPrng.init(0); - const random = prng.random(); - var send_buf: [buf_len]u8 = undefined; - var recv_buf: [buf_len]u8 = undefined; - random.bytes(&send_buf); // fill send buffer with random bytes - - for (0..16) |_| { - const n = buf_len; //random.uintLessThan(usize, buf_len); - - const sent = send_buf[0..n]; - try conn1.writeAll(sent); - const r = try conn2.readAll(&recv_buf); - const received = recv_buf[0..r]; - - try testing.expectEqual(n, r); - try testing.expectEqualSlices(u8, sent, received); - } -} diff --git a/src/http/async/tls.zig/handshake_client.zig b/src/http/async/tls.zig/handshake_client.zig deleted file mode 100644 index e7b48cf6..00000000 --- a/src/http/async/tls.zig/handshake_client.zig +++ /dev/null @@ -1,955 +0,0 @@ -const std = @import("std"); -const assert = std.debug.assert; -const crypto = std.crypto; -const mem = std.mem; -const Certificate = crypto.Certificate; - -const cipher = @import("cipher.zig"); -const Cipher = cipher.Cipher; -const CipherSuite = cipher.CipherSuite; -const cipher_suites = cipher.cipher_suites; -const Transcript = @import("transcript.zig").Transcript; -const record = @import("record.zig"); -const rsa = @import("rsa/rsa.zig"); -const key_log = @import("key_log.zig"); -const PrivateKey = @import("PrivateKey.zig"); -const proto = @import("protocol.zig"); - -const common = @import("handshake_common.zig"); -const dupe = common.dupe; -const CertificateBuilder = common.CertificateBuilder; -const CertificateParser = common.CertificateParser; -const DhKeyPair = common.DhKeyPair; -const CertBundle = common.CertBundle; -const CertKeyPair = common.CertKeyPair; - -pub const Options = struct { - host: []const u8, - /// Set of root certificate authorities that clients use when verifying - /// server certificates. - root_ca: CertBundle, - - /// Controls whether a client verifies the server's certificate chain and - /// host name. - insecure_skip_verify: bool = false, - - /// List of cipher suites to use. - /// To use just tls 1.3 cipher suites: - /// .cipher_suites = &tls.CipherSuite.tls13, - /// To select particular cipher suite: - /// .cipher_suites = &[_]tls.CipherSuite{tls.CipherSuite.CHACHA20_POLY1305_SHA256}, - cipher_suites: []const CipherSuite = cipher_suites.all, - - /// List of named groups to use. - /// To use specific named group: - /// .named_groups = &[_]tls.NamedGroup{.secp384r1}, - named_groups: []const proto.NamedGroup = supported_named_groups, - - /// Client authentication certificates and private key. - auth: ?CertKeyPair = null, - - /// If this structure is provided it will be filled with handshake attributes - /// at the end of the handshake process. - diagnostic: ?*Diagnostic = null, - - /// For logging current connection tls keys, so we can share them with - /// Wireshark and analyze decrypted traffic there. - key_log_callback: ?key_log.Callback = null, - - pub const Diagnostic = struct { - tls_version: proto.Version = @enumFromInt(0), - cipher_suite_tag: CipherSuite = @enumFromInt(0), - named_group: proto.NamedGroup = @enumFromInt(0), - signature_scheme: proto.SignatureScheme = @enumFromInt(0), - client_signature_scheme: proto.SignatureScheme = @enumFromInt(0), - }; -}; - -const supported_named_groups = &[_]proto.NamedGroup{ - .x25519, - .secp256r1, - .secp384r1, - .x25519_kyber768d00, -}; - -/// Handshake parses tls server message and creates client messages. Collects -/// tls attributes: server random, cipher suite and so on. Client messages are -/// created using provided buffer. Provided record reader is used to get tls -/// record when needed. -pub fn Handshake(comptime Stream: type) type { - const RecordReaderT = record.Reader(Stream); - return struct { - client_random: [32]u8, - server_random: [32]u8 = undefined, - master_secret: [48]u8 = undefined, - key_material: [48 * 4]u8 = undefined, // for sha256 32 * 4 is filled, for sha384 48 * 4 - - transcript: Transcript = .{}, - cipher_suite: CipherSuite = @enumFromInt(0), - named_group: ?proto.NamedGroup = null, - dh_kp: DhKeyPair, - rsa_secret: RsaSecret, - tls_version: proto.Version = .tls_1_2, - cipher: Cipher = undefined, - cert: CertificateParser = undefined, - client_certificate_requested: bool = false, - // public key len: x25519 = 32, secp256r1 = 65, secp384r1 = 97, x25519_kyber768d00 = 1120 - server_pub_key_buf: [2048]u8 = undefined, - server_pub_key: []const u8 = undefined, - - rec_rdr: *RecordReaderT, // tls record reader - buffer: []u8, // scratch buffer used in all messages creation - - const HandshakeT = @This(); - - pub fn init(buf: []u8, rec_rdr: *RecordReaderT) HandshakeT { - return .{ - .client_random = undefined, - .dh_kp = undefined, - .rsa_secret = undefined, - //.now_sec = std.time.timestamp(), - .buffer = buf, - .rec_rdr = rec_rdr, - }; - } - - fn initKeys( - h: *HandshakeT, - named_groups: []const proto.NamedGroup, - ) !void { - const init_keys_buf_len = 32 + 46 + DhKeyPair.seed_len; - var buf: [init_keys_buf_len]u8 = undefined; - crypto.random.bytes(&buf); - - h.client_random = buf[0..32].*; - h.rsa_secret = RsaSecret.init(buf[32..][0..46].*); - h.dh_kp = try DhKeyPair.init(buf[32 + 46 ..][0..DhKeyPair.seed_len].*, named_groups); - } - - /// Handshake exchanges messages with server to get agreement about - /// cryptographic parameters. That upgrades existing client-server - /// connection to TLS connection. Returns cipher used in application for - /// encrypted message exchange. - /// - /// Handles TLS 1.2 and TLS 1.3 connections. After initial client hello - /// server chooses in its server hello which TLS version will be used. - /// - /// TLS 1.2 handshake messages exchange: - /// Client Server - /// -------------------------------------------------------------- - /// ClientHello client flight 1 ---> - /// ServerHello - /// Certificate - /// ServerKeyExchange - /// CertificateRequest* - /// <--- server flight 1 ServerHelloDone - /// Certificate* - /// ClientKeyExchange - /// CertificateVerify* - /// ChangeCipherSpec - /// Finished client flight 2 ---> - /// ChangeCipherSpec - /// <--- server flight 2 Finished - /// - /// TLS 1.3 handshake messages exchange: - /// Client Server - /// -------------------------------------------------------------- - /// ClientHello client flight 1 ---> - /// ServerHello - /// {EncryptedExtensions} - /// {CertificateRequest*} - /// {Certificate} - /// {CertificateVerify} - /// <--- server flight 1 {Finished} - /// ChangeCipherSpec - /// {Certificate*} - /// {CertificateVerify*} - /// Finished client flight 2 ---> - /// - /// * - optional - /// {} - encrypted - /// - /// References: - /// https://datatracker.ietf.org/doc/html/rfc5246#section-7.3 - /// https://datatracker.ietf.org/doc/html/rfc8446#section-2 - /// - pub fn handshake(h: *HandshakeT, w: Stream, opt: Options) !Cipher { - defer h.updateDiagnostic(opt); - try h.initKeys(opt.named_groups); - h.cert = .{ - .host = opt.host, - .root_ca = opt.root_ca.bundle, - .skip_verify = opt.insecure_skip_verify, - }; - - try w.writeAll(try h.makeClientHello(opt)); // client flight 1 - try h.readServerFlight1(); // server flight 1 - h.transcript.use(h.cipher_suite.hash()); - - // tls 1.3 specific handshake part - if (h.tls_version == .tls_1_3) { - try h.generateHandshakeCipher(opt.key_log_callback); - try h.readEncryptedServerFlight1(); // server flight 1 - const app_cipher = try h.generateApplicationCipher(opt.key_log_callback); - try w.writeAll(try h.makeClientFlight2Tls13(opt.auth)); // client flight 2 - return app_cipher; - } - - // tls 1.2 specific handshake part - try h.generateCipher(opt.key_log_callback); - try w.writeAll(try h.makeClientFlight2Tls12(opt.auth)); // client flight 2 - try h.readServerFlight2(); // server flight 2 - return h.cipher; - } - - /// Prepare key material and generate cipher for TLS 1.2 - fn generateCipher(h: *HandshakeT, key_log_callback: ?key_log.Callback) !void { - try h.verifyCertificateSignatureTls12(); - try h.generateKeyMaterial(key_log_callback); - h.cipher = try Cipher.initTls12(h.cipher_suite, &h.key_material, .client); - } - - /// Generate TLS 1.2 pre master secret, master secret and key material. - fn generateKeyMaterial(h: *HandshakeT, key_log_callback: ?key_log.Callback) !void { - const pre_master_secret = if (h.named_group) |named_group| - try h.dh_kp.sharedKey(named_group, h.server_pub_key) - else - &h.rsa_secret.secret; - - _ = dupe( - &h.master_secret, - h.transcript.masterSecret(pre_master_secret, h.client_random, h.server_random), - ); - _ = dupe( - &h.key_material, - h.transcript.keyMaterial(&h.master_secret, h.client_random, h.server_random), - ); - if (key_log_callback) |cb| { - cb(key_log.label.client_random, &h.client_random, &h.master_secret); - } - } - - /// TLS 1.3 cipher used during handshake - fn generateHandshakeCipher(h: *HandshakeT, key_log_callback: ?key_log.Callback) !void { - const shared_key = try h.dh_kp.sharedKey(h.named_group.?, h.server_pub_key); - const handshake_secret = h.transcript.handshakeSecret(shared_key); - if (key_log_callback) |cb| { - cb(key_log.label.server_handshake_traffic_secret, &h.client_random, handshake_secret.server); - cb(key_log.label.client_handshake_traffic_secret, &h.client_random, handshake_secret.client); - } - h.cipher = try Cipher.initTls13(h.cipher_suite, handshake_secret, .client); - } - - /// TLS 1.3 application (client) cipher - fn generateApplicationCipher(h: *HandshakeT, key_log_callback: ?key_log.Callback) !Cipher { - const application_secret = h.transcript.applicationSecret(); - if (key_log_callback) |cb| { - cb(key_log.label.server_traffic_secret_0, &h.client_random, application_secret.server); - cb(key_log.label.client_traffic_secret_0, &h.client_random, application_secret.client); - } - return try Cipher.initTls13(h.cipher_suite, application_secret, .client); - } - - fn makeClientHello(h: *HandshakeT, opt: Options) ![]const u8 { - // Buffer will have this parts: - // | header | payload | extensions | - // - // Header will be written last because we need to know length of - // payload and extensions when creating it. Payload has - // extensions length (u16) as last element. - // - var buffer = h.buffer; - const header_len = 9; // tls record header (5 bytes) and handshake header (4 bytes) - const tls_versions = try CipherSuite.versions(opt.cipher_suites); - // Payload writer, preserve header_len bytes for handshake header. - var payload = record.Writer{ .buf = buffer[header_len..] }; - try payload.writeEnum(proto.Version.tls_1_2); - try payload.write(&h.client_random); - try payload.writeByte(0); // no session id - try payload.writeEnumArray(CipherSuite, opt.cipher_suites); - try payload.write(&[_]u8{ 0x01, 0x00 }); // no compression - - // Extensions writer starts after payload and preserves 2 more - // bytes for extension len in payload. - var ext = record.Writer{ .buf = buffer[header_len + payload.pos + 2 ..] }; - try ext.writeExtension(.supported_versions, switch (tls_versions) { - .both => &[_]proto.Version{ .tls_1_3, .tls_1_2 }, - .tls_1_3 => &[_]proto.Version{.tls_1_3}, - .tls_1_2 => &[_]proto.Version{.tls_1_2}, - }); - try ext.writeExtension(.signature_algorithms, common.supported_signature_algorithms); - - try ext.writeExtension(.supported_groups, opt.named_groups); - if (tls_versions != .tls_1_2) { - var keys: [supported_named_groups.len][]const u8 = undefined; - for (opt.named_groups, 0..) |ng, i| { - keys[i] = try h.dh_kp.publicKey(ng); - } - try ext.writeKeyShare(opt.named_groups, keys[0..opt.named_groups.len]); - } - try ext.writeServerName(opt.host); - - // Extensions length at the end of the payload. - try payload.writeInt(@as(u16, @intCast(ext.pos))); - - // Header at the start of the buffer. - const body_len = payload.pos + ext.pos; - buffer[0..header_len].* = record.header(.handshake, 4 + body_len) ++ - record.handshakeHeader(.client_hello, body_len); - - const msg = buffer[0 .. header_len + body_len]; - h.transcript.update(msg[record.header_len..]); - return msg; - } - - /// Process first flight of the messages from the server. - /// Read server hello message. If TLS 1.3 is chosen in server hello - /// return. For TLS 1.2 continue and read certificate, key_exchange - /// eventual certificate request and hello done messages. - fn readServerFlight1(h: *HandshakeT) !void { - var handshake_states: []const proto.Handshake = &.{.server_hello}; - - while (true) { - var d = try h.rec_rdr.nextDecoder(); - try d.expectContentType(.handshake); - - h.transcript.update(d.payload); - - // Multiple handshake messages can be packed in single tls record. - while (!d.eof()) { - const handshake_type = try d.decode(proto.Handshake); - - const length = try d.decode(u24); - if (length > cipher.max_cleartext_len) - return error.TlsUnsupportedFragmentedHandshakeMessage; - - brk: { - for (handshake_states) |state| - if (state == handshake_type) break :brk; - return error.TlsUnexpectedMessage; - } - switch (handshake_type) { - .server_hello => { // server hello, ref: https://datatracker.ietf.org/doc/html/rfc5246#section-7.4.1.3 - try h.parseServerHello(&d, length); - if (h.tls_version == .tls_1_3) { - if (!d.eof()) return error.TlsIllegalParameter; - return; // end of tls 1.3 server flight 1 - } - handshake_states = if (h.cert.skip_verify) - &.{ .certificate, .server_key_exchange, .server_hello_done } - else - &.{.certificate}; - }, - .certificate => { - try h.cert.parseCertificate(&d, h.tls_version); - handshake_states = if (h.cipher_suite.keyExchange() == .rsa) - &.{.server_hello_done} - else - &.{.server_key_exchange}; - }, - .server_key_exchange => { - try h.parseServerKeyExchange(&d); - handshake_states = &.{ .certificate_request, .server_hello_done }; - }, - .certificate_request => { - h.client_certificate_requested = true; - try d.skip(length); - handshake_states = &.{.server_hello_done}; - }, - .server_hello_done => { - if (length != 0) return error.TlsIllegalParameter; - return; - }, - else => return error.TlsUnexpectedMessage, - } - } - } - } - - /// Parse server hello message. - fn parseServerHello(h: *HandshakeT, d: *record.Decoder, length: u24) !void { - if (try d.decode(proto.Version) != proto.Version.tls_1_2) - return error.TlsBadVersion; - h.server_random = try d.array(32); - if (isServerHelloRetryRequest(&h.server_random)) - return error.TlsServerHelloRetryRequest; - - const session_id_len = try d.decode(u8); - if (session_id_len > 32) return error.TlsIllegalParameter; - try d.skip(session_id_len); - - h.cipher_suite = try d.decode(CipherSuite); - try h.cipher_suite.validate(); - try d.skip(1); // skip compression method - - const extensions_present = length > 2 + 32 + 1 + session_id_len + 2 + 1; - if (extensions_present) { - const exs_len = try d.decode(u16); - var l: usize = 0; - while (l < exs_len) { - const typ = try d.decode(proto.Extension); - const len = try d.decode(u16); - defer l += len + 4; - - switch (typ) { - .supported_versions => { - switch (try d.decode(proto.Version)) { - .tls_1_2, .tls_1_3 => |v| h.tls_version = v, - else => return error.TlsIllegalParameter, - } - if (len != 2) return error.TlsIllegalParameter; - }, - .key_share => { - h.named_group = try d.decode(proto.NamedGroup); - h.server_pub_key = dupe(&h.server_pub_key_buf, try d.slice(try d.decode(u16))); - if (len != h.server_pub_key.len + 4) return error.TlsIllegalParameter; - }, - else => { - try d.skip(len); - }, - } - } - } - } - - fn isServerHelloRetryRequest(server_random: []const u8) bool { - // Ref: https://datatracker.ietf.org/doc/html/rfc8446#section-4.1.3 - const hello_retry_request_magic = [32]u8{ - 0xCF, 0x21, 0xAD, 0x74, 0xE5, 0x9A, 0x61, 0x11, 0xBE, 0x1D, 0x8C, 0x02, 0x1E, 0x65, 0xB8, 0x91, - 0xC2, 0xA2, 0x11, 0x16, 0x7A, 0xBB, 0x8C, 0x5E, 0x07, 0x9E, 0x09, 0xE2, 0xC8, 0xA8, 0x33, 0x9C, - }; - return std.mem.eql(u8, server_random, &hello_retry_request_magic); - } - - fn parseServerKeyExchange(h: *HandshakeT, d: *record.Decoder) !void { - const curve_type = try d.decode(proto.Curve); - h.named_group = try d.decode(proto.NamedGroup); - h.server_pub_key = dupe(&h.server_pub_key_buf, try d.slice(try d.decode(u8))); - h.cert.signature_scheme = try d.decode(proto.SignatureScheme); - h.cert.signature = dupe(&h.cert.signature_buf, try d.slice(try d.decode(u16))); - if (curve_type != .named_curve) return error.TlsIllegalParameter; - } - - /// Read encrypted part (after server hello) of the server first flight - /// for TLS 1.3: change cipher spec, eventual certificate request, - /// certificate, certificate verify and handshake finished messages. - fn readEncryptedServerFlight1(h: *HandshakeT) !void { - var cleartext_buf = h.buffer; - var cleartext_buf_head: usize = 0; - var cleartext_buf_tail: usize = 0; - var handshake_states: []const proto.Handshake = &.{.encrypted_extensions}; - - outer: while (true) { - // wrapped record decoder - const rec = (try h.rec_rdr.next() orelse return error.EndOfStream); - if (rec.protocol_version != .tls_1_2) return error.TlsBadVersion; - switch (rec.content_type) { - .change_cipher_spec => {}, - .application_data => { - const content_type, const cleartext = try h.cipher.decrypt( - cleartext_buf[cleartext_buf_tail..], - rec, - ); - cleartext_buf_tail += cleartext.len; - if (cleartext_buf_tail > cleartext_buf.len) return error.TlsRecordOverflow; - - var d = record.Decoder.init(content_type, cleartext_buf[cleartext_buf_head..cleartext_buf_tail]); - try d.expectContentType(.handshake); - while (!d.eof()) { - const start_idx = d.idx; - const handshake_type = try d.decode(proto.Handshake); - const length = try d.decode(u24); - - // std.debug.print("handshake loop: {} {} {} {}\n", .{ handshake_type, length, d.payload.len, d.idx }); - if (length > cipher.max_cleartext_len) - return error.TlsUnsupportedFragmentedHandshakeMessage; - if (length > d.rest().len) - continue :outer; // fragmented handshake into multiple records - - defer { - const handshake_payload = d.payload[start_idx..d.idx]; - h.transcript.update(handshake_payload); - cleartext_buf_head += handshake_payload.len; - } - - brk: { - for (handshake_states) |state| - if (state == handshake_type) break :brk; - return error.TlsUnexpectedMessage; - } - switch (handshake_type) { - .encrypted_extensions => { - try d.skip(length); - handshake_states = if (h.cert.skip_verify) - &.{ .certificate_request, .certificate, .finished } - else - &.{ .certificate_request, .certificate }; - }, - .certificate_request => { - h.client_certificate_requested = true; - try d.skip(length); - handshake_states = if (h.cert.skip_verify) - &.{ .certificate, .finished } - else - &.{.certificate}; - }, - .certificate => { - try h.cert.parseCertificate(&d, h.tls_version); - handshake_states = &.{.certificate_verify}; - }, - .certificate_verify => { - try h.cert.parseCertificateVerify(&d); - try h.cert.verifySignature(h.transcript.serverCertificateVerify()); - handshake_states = &.{.finished}; - }, - .finished => { - const actual = try d.slice(length); - var buf: [Transcript.max_mac_length]u8 = undefined; - const expected = h.transcript.serverFinishedTls13(&buf); - if (!mem.eql(u8, expected, actual)) - return error.TlsDecryptError; - return; - }, - else => return error.TlsUnexpectedMessage, - } - } - cleartext_buf_head = 0; - cleartext_buf_tail = 0; - }, - else => return error.TlsUnexpectedMessage, - } - } - } - - fn verifyCertificateSignatureTls12(h: *HandshakeT) !void { - if (h.cipher_suite.keyExchange() != .ecdhe) return; - const verify_bytes = brk: { - var w = record.Writer{ .buf = h.buffer }; - try w.write(&h.client_random); - try w.write(&h.server_random); - try w.writeEnum(proto.Curve.named_curve); - try w.writeEnum(h.named_group.?); - try w.writeInt(@as(u8, @intCast(h.server_pub_key.len))); - try w.write(h.server_pub_key); - break :brk w.getWritten(); - }; - try h.cert.verifySignature(verify_bytes); - } - - /// Create client key exchange, change cipher spec and handshake - /// finished messages for tls 1.2. - /// If client certificate is requested also adds client certificate and - /// certificate verify messages. - fn makeClientFlight2Tls12(h: *HandshakeT, auth: ?CertKeyPair) ![]const u8 { - var w = record.Writer{ .buf = h.buffer }; - var cert_builder: ?CertificateBuilder = null; - - // Client certificate message - if (h.client_certificate_requested) { - if (auth) |a| { - const cb = h.certificateBuilder(a); - cert_builder = cb; - const client_certificate = try cb.makeCertificate(w.getPayload()); - h.transcript.update(client_certificate); - try w.advanceRecord(.handshake, client_certificate.len); - } else { - const empty_certificate = &record.handshakeHeader(.certificate, 3) ++ [_]u8{ 0, 0, 0 }; - h.transcript.update(empty_certificate); - try w.writeRecord(.handshake, empty_certificate); - } - } - - // Client key exchange message - { - const key_exchange = try h.makeClientKeyExchange(w.getPayload()); - h.transcript.update(key_exchange); - try w.advanceRecord(.handshake, key_exchange.len); - } - - // Client certificate verify message - if (cert_builder) |cb| { - const certificate_verify = try cb.makeCertificateVerify(w.getPayload()); - h.transcript.update(certificate_verify); - try w.advanceRecord(.handshake, certificate_verify.len); - } - - // Client change cipher spec message - try w.writeRecord(.change_cipher_spec, &[_]u8{1}); - - // Client handshake finished message - { - const client_finished = &record.handshakeHeader(.finished, 12) ++ - h.transcript.clientFinishedTls12(&h.master_secret); - h.transcript.update(client_finished); - try h.writeEncrypted(&w, client_finished); - } - - return w.getWritten(); - } - - /// Create client change cipher spec and handshake finished messages for - /// tls 1.3. - /// If the client certificate is requested by the server and client is - /// configured with certificates and private key then client certificate - /// and client certificate verify messages are also created. If the - /// server has requested certificate but the client is not configured - /// empty certificate message is sent, as is required by rfc. - fn makeClientFlight2Tls13(h: *HandshakeT, auth: ?CertKeyPair) ![]const u8 { - var w = record.Writer{ .buf = h.buffer }; - - // Client change cipher spec message - try w.writeRecord(.change_cipher_spec, &[_]u8{1}); - - if (h.client_certificate_requested) { - if (auth) |a| { - const cb = h.certificateBuilder(a); - { - const certificate = try cb.makeCertificate(w.getPayload()); - h.transcript.update(certificate); - try h.writeEncrypted(&w, certificate); - } - { - const certificate_verify = try cb.makeCertificateVerify(w.getPayload()); - h.transcript.update(certificate_verify); - try h.writeEncrypted(&w, certificate_verify); - } - } else { - // Empty certificate message and no certificate verify message - const empty_certificate = &record.handshakeHeader(.certificate, 4) ++ [_]u8{ 0, 0, 0, 0 }; - h.transcript.update(empty_certificate); - try h.writeEncrypted(&w, empty_certificate); - } - } - - // Client handshake finished message - { - const client_finished = try h.makeClientFinishedTls13(w.getPayload()); - h.transcript.update(client_finished); - try h.writeEncrypted(&w, client_finished); - } - - return w.getWritten(); - } - - fn certificateBuilder(h: *HandshakeT, auth: CertKeyPair) CertificateBuilder { - return .{ - .bundle = auth.bundle, - .key = auth.key, - .transcript = &h.transcript, - .tls_version = h.tls_version, - .side = .client, - }; - } - - fn makeClientFinishedTls13(h: *HandshakeT, buf: []u8) ![]const u8 { - var w = record.Writer{ .buf = buf }; - const verify_data = h.transcript.clientFinishedTls13(w.getHandshakePayload()); - try w.advanceHandshake(.finished, verify_data.len); - return w.getWritten(); - } - - fn makeClientKeyExchange(h: *HandshakeT, buf: []u8) ![]const u8 { - var w = record.Writer{ .buf = buf }; - if (h.named_group) |named_group| { - const key = try h.dh_kp.publicKey(named_group); - try w.writeHandshakeHeader(.client_key_exchange, 1 + key.len); - try w.writeInt(@as(u8, @intCast(key.len))); - try w.write(key); - } else { - const key = try h.rsa_secret.encrypted(h.cert.pub_key_algo, h.cert.pub_key); - try w.writeHandshakeHeader(.client_key_exchange, 2 + key.len); - try w.writeInt(@as(u16, @intCast(key.len))); - try w.write(key); - } - return w.getWritten(); - } - - fn readServerFlight2(h: *HandshakeT) !void { - // Read server change cipher spec message. - { - var d = try h.rec_rdr.nextDecoder(); - try d.expectContentType(.change_cipher_spec); - } - // Read encrypted server handshake finished message. Verify that - // content of the server finished message is based on transcript - // hash and master secret. - { - const content_type, const server_finished = - try h.rec_rdr.nextDecrypt(&h.cipher) orelse return error.EndOfStream; - if (content_type != .handshake) - return error.TlsUnexpectedMessage; - const expected = record.handshakeHeader(.finished, 12) ++ h.transcript.serverFinishedTls12(&h.master_secret); - if (!mem.eql(u8, server_finished, &expected)) - return error.TlsBadRecordMac; - } - } - - /// Write encrypted handshake message into `w` - fn writeEncrypted(h: *HandshakeT, w: *record.Writer, cleartext: []const u8) !void { - const ciphertext = try h.cipher.encrypt(w.getFree(), .handshake, cleartext); - w.pos += ciphertext.len; - } - - // Copy handshake parameters to opt.diagnostic - fn updateDiagnostic(h: *HandshakeT, opt: Options) void { - if (opt.diagnostic) |d| { - d.tls_version = h.tls_version; - d.cipher_suite_tag = h.cipher_suite; - d.named_group = h.named_group orelse @as(proto.NamedGroup, @enumFromInt(0x0000)); - d.signature_scheme = h.cert.signature_scheme; - if (opt.auth) |a| - d.client_signature_scheme = a.key.signature_scheme; - } - } - }; -} - -const RsaSecret = struct { - secret: [48]u8, - - fn init(rand: [46]u8) RsaSecret { - return .{ .secret = [_]u8{ 0x03, 0x03 } ++ rand }; - } - - // Pre master secret encrypted with certificate public key. - inline fn encrypted( - self: RsaSecret, - cert_pub_key_algo: Certificate.Parsed.PubKeyAlgo, - cert_pub_key: []const u8, - ) ![]const u8 { - if (cert_pub_key_algo != .rsaEncryption) return error.TlsBadSignatureScheme; - const pk = try rsa.PublicKey.fromDer(cert_pub_key); - var out: [512]u8 = undefined; - return try pk.encryptPkcsv1_5(&self.secret, &out); - } -}; - -const testing = std.testing; -const data12 = @import("testdata/tls12.zig"); -const data13 = @import("testdata/tls13.zig"); -const testu = @import("testu.zig"); - -fn testReader(data: []const u8) record.Reader(std.io.FixedBufferStream([]const u8)) { - return record.reader(std.io.fixedBufferStream(data)); -} -const TestHandshake = Handshake(std.io.FixedBufferStream([]const u8)); - -test "parse tls 1.2 server hello" { - var h = brk: { - var buffer: [1024]u8 = undefined; - var rec_rdr = testReader(&data12.server_hello_responses); - break :brk TestHandshake.init(&buffer, &rec_rdr); - }; - - // Set to known instead of random - h.client_random = data12.client_random; - h.dh_kp.x25519_kp.secret_key = data12.client_secret; - - // Parse server hello, certificate and key exchange messages. - // Read cipher suite, named group, signature scheme, server random certificate public key - // Verify host name, signature - // Calculate key material - h.cert = .{ .host = "example.ulfheim.net", .skip_verify = true, .root_ca = .{} }; - try h.readServerFlight1(); - try testing.expectEqual(.ECDHE_RSA_WITH_AES_128_CBC_SHA, h.cipher_suite); - try testing.expectEqual(.x25519, h.named_group.?); - try testing.expectEqual(.rsa_pkcs1_sha256, h.cert.signature_scheme); - try testing.expectEqualSlices(u8, &data12.server_random, &h.server_random); - try testing.expectEqualSlices(u8, &data12.server_pub_key, h.server_pub_key); - try testing.expectEqualSlices(u8, &data12.signature, h.cert.signature); - try testing.expectEqualSlices(u8, &data12.cert_pub_key, h.cert.pub_key); - - try h.verifyCertificateSignatureTls12(); - try h.generateKeyMaterial(null); - - try testing.expectEqualSlices(u8, &data12.key_material, h.key_material[0..data12.key_material.len]); -} - -test "verify google.com certificate" { - var h = brk: { - var buffer: [1024]u8 = undefined; - var rec_rdr = testReader(@embedFile("testdata/google.com/server_hello")); - break :brk TestHandshake.init(&buffer, &rec_rdr); - }; - h.client_random = @embedFile("testdata/google.com/client_random").*; - - var ca_bundle: Certificate.Bundle = .{}; - try ca_bundle.rescan(testing.allocator); - defer ca_bundle.deinit(testing.allocator); - - h.cert = .{ .host = "google.com", .skip_verify = true, .root_ca = .{}, .now_sec = 1714846451 }; - try h.readServerFlight1(); - try h.verifyCertificateSignatureTls12(); -} - -test "parse tls 1.3 server hello" { - var rec_rdr = testReader(&data13.server_hello); - var d = (try rec_rdr.nextDecoder()); - - const handshake_type = try d.decode(proto.Handshake); - const length = try d.decode(u24); - try testing.expectEqual(0x000076, length); - try testing.expectEqual(.server_hello, handshake_type); - - var h = TestHandshake.init(undefined, undefined); - try h.parseServerHello(&d, length); - - try testing.expectEqual(.AES_256_GCM_SHA384, h.cipher_suite); - try testing.expectEqualSlices(u8, &data13.server_random, &h.server_random); - try testing.expectEqual(.tls_1_3, h.tls_version); - try testing.expectEqual(.x25519, h.named_group); - try testing.expectEqualSlices(u8, &data13.server_pub_key, h.server_pub_key); -} - -test "init tls 1.3 handshake cipher" { - const cipher_suite_tag: CipherSuite = .AES_256_GCM_SHA384; - - var transcript = Transcript{}; - transcript.use(cipher_suite_tag.hash()); - transcript.update(data13.client_hello[record.header_len..]); - transcript.update(data13.server_hello[record.header_len..]); - - var dh_kp = DhKeyPair{ - .x25519_kp = .{ - .public_key = data13.client_public_key, - .secret_key = data13.client_private_key, - }, - }; - const shared_key = try dh_kp.sharedKey(.x25519, &data13.server_pub_key); - try testing.expectEqualSlices(u8, &data13.shared_key, shared_key); - - const cph = try Cipher.initTls13(cipher_suite_tag, transcript.handshakeSecret(shared_key), .client); - - const c = &cph.AES_256_GCM_SHA384; - try testing.expectEqualSlices(u8, &data13.server_handshake_key, &c.decrypt_key); - try testing.expectEqualSlices(u8, &data13.client_handshake_key, &c.encrypt_key); - try testing.expectEqualSlices(u8, &data13.server_handshake_iv, &c.decrypt_iv); - try testing.expectEqualSlices(u8, &data13.client_handshake_iv, &c.encrypt_iv); -} - -fn initExampleHandshake(h: *TestHandshake) !void { - h.cipher_suite = .AES_256_GCM_SHA384; - h.transcript.use(h.cipher_suite.hash()); - h.transcript.update(data13.client_hello[record.header_len..]); - h.transcript.update(data13.server_hello[record.header_len..]); - h.cipher = try Cipher.initTls13(h.cipher_suite, h.transcript.handshakeSecret(&data13.shared_key), .client); - h.tls_version = .tls_1_3; - h.cert.now_sec = 1714846451; - h.server_pub_key = &data13.server_pub_key; -} - -test "tls 1.3 decrypt wrapped record" { - var cph = brk: { - var h = TestHandshake.init(undefined, undefined); - try initExampleHandshake(&h); - break :brk h.cipher; - }; - - var cleartext_buf: [1024]u8 = undefined; - { - const rec = record.Record.init(&data13.server_encrypted_extensions_wrapped); - - const content_type, const cleartext = try cph.decrypt(&cleartext_buf, rec); - try testing.expectEqual(.handshake, content_type); - try testing.expectEqualSlices(u8, &data13.server_encrypted_extensions, cleartext); - } - { - const rec = record.Record.init(&data13.server_certificate_wrapped); - - const content_type, const cleartext = try cph.decrypt(&cleartext_buf, rec); - try testing.expectEqual(.handshake, content_type); - try testing.expectEqualSlices(u8, &data13.server_certificate, cleartext); - } -} - -test "tls 1.3 process server flight" { - var buffer: [1024]u8 = undefined; - var h = brk: { - var rec_rdr = testReader(&data13.server_flight); - break :brk TestHandshake.init(&buffer, &rec_rdr); - }; - - try initExampleHandshake(&h); - h.cert = .{ .host = "example.ulfheim.net", .skip_verify = true, .root_ca = .{} }; - try h.readEncryptedServerFlight1(); - - { // application cipher keys calculation - try testing.expectEqualSlices(u8, &data13.handshake_hash, &h.transcript.sha384.hash.peek()); - - var cph = try Cipher.initTls13(h.cipher_suite, h.transcript.applicationSecret(), .client); - const c = &cph.AES_256_GCM_SHA384; - try testing.expectEqualSlices(u8, &data13.server_application_key, &c.decrypt_key); - try testing.expectEqualSlices(u8, &data13.client_application_key, &c.encrypt_key); - try testing.expectEqualSlices(u8, &data13.server_application_iv, &c.decrypt_iv); - try testing.expectEqualSlices(u8, &data13.client_application_iv, &c.encrypt_iv); - - const encrypted = try cph.encrypt(&buffer, .application_data, "ping"); - try testing.expectEqualSlices(u8, &data13.client_ping_wrapped, encrypted); - } - { // client finished message - var buf: [4 + Transcript.max_mac_length]u8 = undefined; - const client_finished = try h.makeClientFinishedTls13(&buf); - try testing.expectEqualSlices(u8, &data13.client_finished_verify_data, client_finished[4..]); - const encrypted = try h.cipher.encrypt(&buffer, .handshake, client_finished); - try testing.expectEqualSlices(u8, &data13.client_finished_wrapped, encrypted); - } -} - -test "create client hello" { - var h = brk: { - var buffer: [1024]u8 = undefined; - var h = TestHandshake.init(&buffer, undefined); - h.client_random = testu.hexToBytes( - \\ 00 01 02 03 04 05 06 07 08 09 0a 0b 0c 0d 0e 0f 10 11 12 13 14 15 16 17 18 19 1a 1b 1c 1d 1e 1f - ); - break :brk h; - }; - - const actual = try h.makeClientHello(.{ - .host = "google.com", - .root_ca = .{}, - .cipher_suites = &[_]CipherSuite{CipherSuite.ECDHE_ECDSA_WITH_AES_128_GCM_SHA256}, - .named_groups = &[_]proto.NamedGroup{ .x25519, .secp256r1, .secp384r1 }, - }); - - const expected = testu.hexToBytes( - "16 03 03 00 6d " ++ // record header - "01 00 00 69 " ++ // handshake header - "03 03 " ++ // protocol version - "00 01 02 03 04 05 06 07 08 09 0a 0b 0c 0d 0e 0f 10 11 12 13 14 15 16 17 18 19 1a 1b 1c 1d 1e 1f " ++ // client random - "00 " ++ // no session id - "00 02 c0 2b " ++ // cipher suites - "01 00 " ++ // compression methods - "00 3e " ++ // extensions length - "00 2b 00 03 02 03 03 " ++ // supported versions extension - "00 0d 00 14 00 12 04 03 05 03 08 04 08 05 08 06 08 07 02 01 04 01 05 01 " ++ // signature algorithms extension - "00 0a 00 08 00 06 00 1d 00 17 00 18 " ++ // named groups extension - "00 00 00 0f 00 0d 00 00 0a 67 6f 6f 67 6c 65 2e 63 6f 6d ", // server name extension - ); - try testing.expectEqualSlices(u8, &expected, actual); -} - -test "handshake verify server finished message" { - var buffer: [1024]u8 = undefined; - var rec_rdr = testReader(&data12.server_handshake_finished_msgs); - var h = TestHandshake.init(&buffer, &rec_rdr); - - h.cipher_suite = .ECDHE_ECDSA_WITH_AES_128_CBC_SHA; - h.master_secret = data12.master_secret; - - // add handshake messages to the transcript - for (data12.handshake_messages) |msg| { - h.transcript.update(msg[record.header_len..]); - } - - // expect verify data - const client_finished = h.transcript.clientFinishedTls12(&h.master_secret); - try testing.expectEqualSlices(u8, &data12.client_finished, &record.handshakeHeader(.finished, 12) ++ client_finished); - - // init client with prepared key_material - h.cipher = try Cipher.initTls12(.ECDHE_RSA_WITH_AES_128_CBC_SHA, &data12.key_material, .client); - - // check that server verify data matches calculates from hashes of all handshake messages - h.transcript.update(&data12.client_finished); - try h.readServerFlight2(); -} diff --git a/src/http/async/tls.zig/handshake_common.zig b/src/http/async/tls.zig/handshake_common.zig deleted file mode 100644 index 178a3cea..00000000 --- a/src/http/async/tls.zig/handshake_common.zig +++ /dev/null @@ -1,448 +0,0 @@ -const std = @import("std"); -const assert = std.debug.assert; -const mem = std.mem; -const crypto = std.crypto; -const Certificate = crypto.Certificate; - -const Transcript = @import("transcript.zig").Transcript; -const PrivateKey = @import("PrivateKey.zig"); -const record = @import("record.zig"); -const rsa = @import("rsa/rsa.zig"); -const proto = @import("protocol.zig"); - -const X25519 = crypto.dh.X25519; -const EcdsaP256Sha256 = crypto.sign.ecdsa.EcdsaP256Sha256; -const EcdsaP384Sha384 = crypto.sign.ecdsa.EcdsaP384Sha384; -const Kyber768 = crypto.kem.kyber_d00.Kyber768; - -pub const supported_signature_algorithms = &[_]proto.SignatureScheme{ - .ecdsa_secp256r1_sha256, - .ecdsa_secp384r1_sha384, - .rsa_pss_rsae_sha256, - .rsa_pss_rsae_sha384, - .rsa_pss_rsae_sha512, - .ed25519, - .rsa_pkcs1_sha1, - .rsa_pkcs1_sha256, - .rsa_pkcs1_sha384, -}; - -pub const CertKeyPair = struct { - /// A chain of one or more certificates, leaf first. - /// - /// Each X.509 certificate contains the public key of a key pair, extra - /// information (the name of the holder, the name of an issuer of the - /// certificate, validity time spans) and a signature generated using the - /// private key of the issuer of the certificate. - /// - /// All certificates from the bundle are sent to the other side when creating - /// Certificate tls message. - /// - /// Leaf certificate and private key are used to create signature for - /// CertifyVerify tls message. - bundle: Certificate.Bundle, - - /// Private key corresponding to the public key in leaf certificate from the - /// bundle. - key: PrivateKey, - - pub fn load( - allocator: std.mem.Allocator, - dir: std.fs.Dir, - cert_path: []const u8, - key_path: []const u8, - ) !CertKeyPair { - var bundle: Certificate.Bundle = .{}; - try bundle.addCertsFromFilePath(allocator, dir, cert_path); - - const key_file = try dir.openFile(key_path, .{}); - defer key_file.close(); - const key = try PrivateKey.fromFile(allocator, key_file); - - return .{ .bundle = bundle, .key = key }; - } - - pub fn deinit(c: *CertKeyPair, allocator: std.mem.Allocator) void { - c.bundle.deinit(allocator); - } -}; - -pub const CertBundle = struct { - // A chain of one or more certificates. - // - // They are used to verify that certificate chain sent by the other side - // forms valid trust chain. - bundle: Certificate.Bundle = .{}, - - pub fn fromFile(allocator: std.mem.Allocator, dir: std.fs.Dir, path: []const u8) !CertBundle { - var bundle: Certificate.Bundle = .{}; - try bundle.addCertsFromFilePath(allocator, dir, path); - return .{ .bundle = bundle }; - } - - pub fn fromSystem(allocator: std.mem.Allocator) !CertBundle { - var bundle: Certificate.Bundle = .{}; - try bundle.rescan(allocator); - return .{ .bundle = bundle }; - } - - pub fn deinit(cb: *CertBundle, allocator: std.mem.Allocator) void { - cb.bundle.deinit(allocator); - } -}; - -pub const CertificateBuilder = struct { - bundle: Certificate.Bundle, - key: PrivateKey, - transcript: *Transcript, - tls_version: proto.Version = .tls_1_3, - side: proto.Side = .client, - - pub fn makeCertificate(h: CertificateBuilder, buf: []u8) ![]const u8 { - var w = record.Writer{ .buf = buf }; - const certs = h.bundle.bytes.items; - const certs_count = h.bundle.map.size; - - // Differences between tls 1.3 and 1.2 - // TLS 1.3 has request context in header and extensions for each certificate. - // Here we use empty length for each field. - // TLS 1.2 don't have these two fields. - const request_context, const extensions = if (h.tls_version == .tls_1_3) - .{ &[_]u8{0}, &[_]u8{ 0, 0 } } - else - .{ &[_]u8{}, &[_]u8{} }; - const certs_len = certs.len + (3 + extensions.len) * certs_count; - - // Write handshake header - try w.writeHandshakeHeader(.certificate, certs_len + request_context.len + 3); - try w.write(request_context); - try w.writeInt(@as(u24, @intCast(certs_len))); - - // Write each certificate - var index: u32 = 0; - while (index < certs.len) { - const e = try Certificate.der.Element.parse(certs, index); - const cert = certs[index..e.slice.end]; - try w.writeInt(@as(u24, @intCast(cert.len))); // certificate length - try w.write(cert); // certificate - try w.write(extensions); // certificate extensions - index = e.slice.end; - } - return w.getWritten(); - } - - pub fn makeCertificateVerify(h: CertificateBuilder, buf: []u8) ![]const u8 { - var w = record.Writer{ .buf = buf }; - const signature, const signature_scheme = try h.createSignature(); - try w.writeHandshakeHeader(.certificate_verify, signature.len + 4); - try w.writeEnum(signature_scheme); - try w.writeInt(@as(u16, @intCast(signature.len))); - try w.write(signature); - return w.getWritten(); - } - - /// Creates signature for client certificate signature message. - /// Returns signature bytes and signature scheme. - inline fn createSignature(h: CertificateBuilder) !struct { []const u8, proto.SignatureScheme } { - switch (h.key.signature_scheme) { - inline .ecdsa_secp256r1_sha256, - .ecdsa_secp384r1_sha384, - => |comptime_scheme| { - const Ecdsa = SchemeEcdsa(comptime_scheme); - const key = h.key.key.ecdsa; - const key_len = Ecdsa.SecretKey.encoded_length; - if (key.len < key_len) return error.InvalidEncoding; - const secret_key = try Ecdsa.SecretKey.fromBytes(key[0..key_len].*); - const key_pair = try Ecdsa.KeyPair.fromSecretKey(secret_key); - var signer = try key_pair.signer(null); - h.setSignatureVerifyBytes(&signer); - const signature = try signer.finalize(); - var buf: [Ecdsa.Signature.der_encoded_length_max]u8 = undefined; - return .{ signature.toDer(&buf), comptime_scheme }; - }, - inline .rsa_pss_rsae_sha256, - .rsa_pss_rsae_sha384, - .rsa_pss_rsae_sha512, - => |comptime_scheme| { - const Hash = SchemeHash(comptime_scheme); - var signer = try h.key.key.rsa.signerOaep(Hash, null); - h.setSignatureVerifyBytes(&signer); - var buf: [512]u8 = undefined; - const signature = try signer.finalize(&buf); - return .{ signature.bytes, comptime_scheme }; - }, - else => return error.TlsUnknownSignatureScheme, - } - } - - fn setSignatureVerifyBytes(h: CertificateBuilder, signer: anytype) void { - if (h.tls_version == .tls_1_2) { - // tls 1.2 signature uses current transcript hash value. - // ref: https://datatracker.ietf.org/doc/html/rfc5246.html#section-7.4.8 - const Hash = @TypeOf(signer.h); - signer.h = h.transcript.hash(Hash); - } else { - // tls 1.3 signature is computed over concatenation of 64 spaces, - // context, separator and content. - // ref: https://datatracker.ietf.org/doc/html/rfc8446#section-4.4.3 - if (h.side == .server) { - signer.update(h.transcript.serverCertificateVerify()); - } else { - signer.update(h.transcript.clientCertificateVerify()); - } - } - } - - fn SchemeEcdsa(comptime scheme: proto.SignatureScheme) type { - return switch (scheme) { - .ecdsa_secp256r1_sha256 => EcdsaP256Sha256, - .ecdsa_secp384r1_sha384 => EcdsaP384Sha384, - else => unreachable, - }; - } -}; - -pub const CertificateParser = struct { - pub_key_algo: Certificate.Parsed.PubKeyAlgo = undefined, - pub_key_buf: [600]u8 = undefined, - pub_key: []const u8 = undefined, - - signature_scheme: proto.SignatureScheme = @enumFromInt(0), - signature_buf: [1024]u8 = undefined, - signature: []const u8 = undefined, - - root_ca: Certificate.Bundle, - host: []const u8, - skip_verify: bool = false, - now_sec: i64 = 0, - - pub fn parseCertificate(h: *CertificateParser, d: *record.Decoder, tls_version: proto.Version) !void { - if (h.now_sec == 0) { - h.now_sec = std.time.timestamp(); - } - if (tls_version == .tls_1_3) { - const request_context = try d.decode(u8); - if (request_context != 0) return error.TlsIllegalParameter; - } - - var trust_chain_established = false; - var last_cert: ?Certificate.Parsed = null; - const certs_len = try d.decode(u24); - const start_idx = d.idx; - while (d.idx - start_idx < certs_len) { - const cert_len = try d.decode(u24); - // std.debug.print("=> {} {} {} {}\n", .{ certs_len, d.idx, cert_len, d.payload.len }); - const cert = try d.slice(cert_len); - if (tls_version == .tls_1_3) { - // certificate extensions present in tls 1.3 - try d.skip(try d.decode(u16)); - } - if (trust_chain_established) - continue; - - const subject = try (Certificate{ .buffer = cert, .index = 0 }).parse(); - if (last_cert) |pc| { - if (pc.verify(subject, h.now_sec)) { - last_cert = subject; - } else |err| switch (err) { - error.CertificateIssuerMismatch => { - // skip certificate which is not part of the chain - continue; - }, - else => return err, - } - } else { // first certificate - if (!h.skip_verify and h.host.len > 0) { - try subject.verifyHostName(h.host); - } - h.pub_key = dupe(&h.pub_key_buf, subject.pubKey()); - h.pub_key_algo = subject.pub_key_algo; - last_cert = subject; - } - if (!h.skip_verify) { - if (h.root_ca.verify(last_cert.?, h.now_sec)) |_| { - trust_chain_established = true; - } else |err| switch (err) { - error.CertificateIssuerNotFound => {}, - else => return err, - } - } - } - if (!h.skip_verify and !trust_chain_established) { - return error.CertificateIssuerNotFound; - } - } - - pub fn parseCertificateVerify(h: *CertificateParser, d: *record.Decoder) !void { - h.signature_scheme = try d.decode(proto.SignatureScheme); - h.signature = dupe(&h.signature_buf, try d.slice(try d.decode(u16))); - } - - pub fn verifySignature(h: *CertificateParser, verify_bytes: []const u8) !void { - switch (h.signature_scheme) { - inline .ecdsa_secp256r1_sha256, - .ecdsa_secp384r1_sha384, - => |comptime_scheme| { - if (h.pub_key_algo != .X9_62_id_ecPublicKey) return error.TlsBadSignatureScheme; - const cert_named_curve = h.pub_key_algo.X9_62_id_ecPublicKey; - switch (cert_named_curve) { - inline .secp384r1, .X9_62_prime256v1 => |comptime_cert_named_curve| { - const Ecdsa = SchemeEcdsaCert(comptime_scheme, comptime_cert_named_curve); - const key = try Ecdsa.PublicKey.fromSec1(h.pub_key); - const sig = try Ecdsa.Signature.fromDer(h.signature); - try sig.verify(verify_bytes, key); - }, - else => return error.TlsUnknownSignatureScheme, - } - }, - .ed25519 => { - if (h.pub_key_algo != .curveEd25519) return error.TlsBadSignatureScheme; - const Eddsa = crypto.sign.Ed25519; - if (h.signature.len != Eddsa.Signature.encoded_length) return error.InvalidEncoding; - const sig = Eddsa.Signature.fromBytes(h.signature[0..Eddsa.Signature.encoded_length].*); - if (h.pub_key.len != Eddsa.PublicKey.encoded_length) return error.InvalidEncoding; - const key = try Eddsa.PublicKey.fromBytes(h.pub_key[0..Eddsa.PublicKey.encoded_length].*); - try sig.verify(verify_bytes, key); - }, - inline .rsa_pss_rsae_sha256, - .rsa_pss_rsae_sha384, - .rsa_pss_rsae_sha512, - => |comptime_scheme| { - if (h.pub_key_algo != .rsaEncryption) return error.TlsBadSignatureScheme; - const Hash = SchemeHash(comptime_scheme); - const pk = try rsa.PublicKey.fromDer(h.pub_key); - const sig = rsa.Pss(Hash).Signature{ .bytes = h.signature }; - try sig.verify(verify_bytes, pk, null); - }, - inline .rsa_pkcs1_sha1, - .rsa_pkcs1_sha256, - .rsa_pkcs1_sha384, - .rsa_pkcs1_sha512, - => |comptime_scheme| { - if (h.pub_key_algo != .rsaEncryption) return error.TlsBadSignatureScheme; - const Hash = SchemeHash(comptime_scheme); - const pk = try rsa.PublicKey.fromDer(h.pub_key); - const sig = rsa.PKCS1v1_5(Hash).Signature{ .bytes = h.signature }; - try sig.verify(verify_bytes, pk); - }, - else => return error.TlsUnknownSignatureScheme, - } - } - - fn SchemeEcdsaCert(comptime scheme: proto.SignatureScheme, comptime cert_named_curve: Certificate.NamedCurve) type { - const Sha256 = crypto.hash.sha2.Sha256; - const Sha384 = crypto.hash.sha2.Sha384; - const Ecdsa = crypto.sign.ecdsa.Ecdsa; - - return switch (scheme) { - .ecdsa_secp256r1_sha256 => Ecdsa(cert_named_curve.Curve(), Sha256), - .ecdsa_secp384r1_sha384 => Ecdsa(cert_named_curve.Curve(), Sha384), - else => @compileError("bad scheme"), - }; - } -}; - -fn SchemeHash(comptime scheme: proto.SignatureScheme) type { - const Sha256 = crypto.hash.sha2.Sha256; - const Sha384 = crypto.hash.sha2.Sha384; - const Sha512 = crypto.hash.sha2.Sha512; - - return switch (scheme) { - .rsa_pkcs1_sha1 => crypto.hash.Sha1, - .rsa_pss_rsae_sha256, .rsa_pkcs1_sha256 => Sha256, - .rsa_pss_rsae_sha384, .rsa_pkcs1_sha384 => Sha384, - .rsa_pss_rsae_sha512, .rsa_pkcs1_sha512 => Sha512, - else => @compileError("bad scheme"), - }; -} - -pub fn dupe(buf: []u8, data: []const u8) []u8 { - const n = @min(data.len, buf.len); - @memcpy(buf[0..n], data[0..n]); - return buf[0..n]; -} - -pub const DhKeyPair = struct { - x25519_kp: X25519.KeyPair = undefined, - secp256r1_kp: EcdsaP256Sha256.KeyPair = undefined, - secp384r1_kp: EcdsaP384Sha384.KeyPair = undefined, - kyber768_kp: Kyber768.KeyPair = undefined, - - pub const seed_len = 32 + 32 + 48 + 64; - - pub fn init(seed: [seed_len]u8, named_groups: []const proto.NamedGroup) !DhKeyPair { - var kp: DhKeyPair = .{}; - for (named_groups) |ng| - switch (ng) { - .x25519 => kp.x25519_kp = try X25519.KeyPair.create(seed[0..][0..X25519.seed_length].*), - .secp256r1 => kp.secp256r1_kp = try EcdsaP256Sha256.KeyPair.create(seed[32..][0..EcdsaP256Sha256.KeyPair.seed_length].*), - .secp384r1 => kp.secp384r1_kp = try EcdsaP384Sha384.KeyPair.create(seed[32 + 32 ..][0..EcdsaP384Sha384.KeyPair.seed_length].*), - .x25519_kyber768d00 => kp.kyber768_kp = try Kyber768.KeyPair.create(seed[32 + 32 + 48 ..][0..Kyber768.seed_length].*), - else => return error.TlsIllegalParameter, - }; - return kp; - } - - pub inline fn sharedKey(self: DhKeyPair, named_group: proto.NamedGroup, server_pub_key: []const u8) ![]const u8 { - return switch (named_group) { - .x25519 => brk: { - if (server_pub_key.len != X25519.public_length) - return error.TlsIllegalParameter; - break :brk &(try X25519.scalarmult( - self.x25519_kp.secret_key, - server_pub_key[0..X25519.public_length].*, - )); - }, - .secp256r1 => brk: { - const pk = try EcdsaP256Sha256.PublicKey.fromSec1(server_pub_key); - const mul = try pk.p.mulPublic(self.secp256r1_kp.secret_key.bytes, .big); - break :brk &mul.affineCoordinates().x.toBytes(.big); - }, - .secp384r1 => brk: { - const pk = try EcdsaP384Sha384.PublicKey.fromSec1(server_pub_key); - const mul = try pk.p.mulPublic(self.secp384r1_kp.secret_key.bytes, .big); - break :brk &mul.affineCoordinates().x.toBytes(.big); - }, - .x25519_kyber768d00 => brk: { - const xksl = crypto.dh.X25519.public_length; - const hksl = xksl + Kyber768.ciphertext_length; - if (server_pub_key.len != hksl) - return error.TlsIllegalParameter; - - break :brk &((crypto.dh.X25519.scalarmult( - self.x25519_kp.secret_key, - server_pub_key[0..xksl].*, - ) catch return error.TlsDecryptFailure) ++ (self.kyber768_kp.secret_key.decaps( - server_pub_key[xksl..hksl], - ) catch return error.TlsDecryptFailure)); - }, - else => return error.TlsIllegalParameter, - }; - } - - // Returns 32, 65, 97 or 1216 bytes - pub inline fn publicKey(self: DhKeyPair, named_group: proto.NamedGroup) ![]const u8 { - return switch (named_group) { - .x25519 => &self.x25519_kp.public_key, - .secp256r1 => &self.secp256r1_kp.public_key.toUncompressedSec1(), - .secp384r1 => &self.secp384r1_kp.public_key.toUncompressedSec1(), - .x25519_kyber768d00 => &self.x25519_kp.public_key ++ self.kyber768_kp.public_key.toBytes(), - else => return error.TlsIllegalParameter, - }; - } -}; - -const testing = std.testing; -const testu = @import("testu.zig"); - -test "DhKeyPair.x25519" { - var seed: [DhKeyPair.seed_len]u8 = undefined; - testu.fill(&seed); - const server_pub_key = &testu.hexToBytes("3303486548531f08d91e675caf666c2dc924ac16f47a861a7f4d05919d143637"); - const expected = &testu.hexToBytes( - \\ F1 67 FB 4A 49 B2 91 77 08 29 45 A1 F7 08 5A 21 - \\ AF FE 9E 78 C2 03 9B 81 92 40 72 73 74 7A 46 1E - ); - const kp = try DhKeyPair.init(seed, &.{.x25519}); - try testing.expectEqualSlices(u8, expected, try kp.sharedKey(.x25519, server_pub_key)); -} diff --git a/src/http/async/tls.zig/handshake_server.zig b/src/http/async/tls.zig/handshake_server.zig deleted file mode 100644 index c26e8c69..00000000 --- a/src/http/async/tls.zig/handshake_server.zig +++ /dev/null @@ -1,520 +0,0 @@ -const std = @import("std"); -const assert = std.debug.assert; -const crypto = std.crypto; -const mem = std.mem; -const Certificate = crypto.Certificate; - -const cipher = @import("cipher.zig"); -const Cipher = cipher.Cipher; -const CipherSuite = @import("cipher.zig").CipherSuite; -const cipher_suites = @import("cipher.zig").cipher_suites; -const Transcript = @import("transcript.zig").Transcript; -const record = @import("record.zig"); -const PrivateKey = @import("PrivateKey.zig"); -const proto = @import("protocol.zig"); - -const common = @import("handshake_common.zig"); -const dupe = common.dupe; -const CertificateBuilder = common.CertificateBuilder; -const CertificateParser = common.CertificateParser; -const DhKeyPair = common.DhKeyPair; -const CertBundle = common.CertBundle; -const CertKeyPair = common.CertKeyPair; - -pub const Options = struct { - /// Server authentication. If null server will not send Certificate and - /// CertificateVerify message. - auth: ?CertKeyPair, - - /// If not null server will request client certificate. If auth_type is - /// .request empty client certificate message will be accepted. - /// Client certificate will be verified with root_ca certificates. - client_auth: ?ClientAuth = null, -}; - -pub const ClientAuth = struct { - /// Set of root certificate authorities that server use when verifying - /// client certificates. - root_ca: CertBundle, - - auth_type: Type = .require, - - pub const Type = enum { - /// Client certificate will be requested during the handshake, but does - /// not require that the client send any certificates. - request, - /// Client certificate will be requested during the handshake, and client - /// has to send valid certificate. - require, - }; -}; - -pub fn Handshake(comptime Stream: type) type { - const RecordReaderT = record.Reader(Stream); - return struct { - // public key len: x25519 = 32, secp256r1 = 65, secp384r1 = 97 - const max_pub_key_len = 98; - const supported_named_groups = &[_]proto.NamedGroup{ .x25519, .secp256r1, .secp384r1 }; - - server_random: [32]u8 = undefined, - client_random: [32]u8 = undefined, - legacy_session_id_buf: [32]u8 = undefined, - legacy_session_id: []u8 = "", - cipher_suite: CipherSuite = @enumFromInt(0), - signature_scheme: proto.SignatureScheme = @enumFromInt(0), - named_group: proto.NamedGroup = @enumFromInt(0), - client_pub_key_buf: [max_pub_key_len]u8 = undefined, - client_pub_key: []u8 = "", - server_pub_key_buf: [max_pub_key_len]u8 = undefined, - server_pub_key: []u8 = "", - - cipher: Cipher = undefined, - transcript: Transcript = .{}, - rec_rdr: *RecordReaderT, - buffer: []u8, - - const HandshakeT = @This(); - - pub fn init(buf: []u8, rec_rdr: *RecordReaderT) HandshakeT { - return .{ - .rec_rdr = rec_rdr, - .buffer = buf, - }; - } - - fn writeAlert(h: *HandshakeT, stream: Stream, cph: ?*Cipher, err: anyerror) !void { - if (cph) |c| { - const cleartext = proto.alertFromError(err); - const ciphertext = try c.encrypt(h.buffer, .alert, &cleartext); - stream.writeAll(ciphertext) catch {}; - } else { - const alert = record.header(.alert, 2) ++ proto.alertFromError(err); - stream.writeAll(&alert) catch {}; - } - } - - pub fn handshake(h: *HandshakeT, stream: Stream, opt: Options) !Cipher { - crypto.random.bytes(&h.server_random); - if (opt.auth) |a| { - // required signature scheme in client hello - h.signature_scheme = a.key.signature_scheme; - } - - h.readClientHello() catch |err| { - try h.writeAlert(stream, null, err); - return err; - }; - h.transcript.use(h.cipher_suite.hash()); - - const server_flight = brk: { - var w = record.Writer{ .buf = h.buffer }; - - const shared_key = h.sharedKey() catch |err| { - try h.writeAlert(stream, null, err); - return err; - }; - { - const hello = try h.makeServerHello(w.getFree()); - h.transcript.update(hello[record.header_len..]); - w.pos += hello.len; - } - { - const handshake_secret = h.transcript.handshakeSecret(shared_key); - h.cipher = try Cipher.initTls13(h.cipher_suite, handshake_secret, .server); - } - try w.writeRecord(.change_cipher_spec, &[_]u8{1}); - { - const encrypted_extensions = &record.handshakeHeader(.encrypted_extensions, 2) ++ [_]u8{ 0, 0 }; - h.transcript.update(encrypted_extensions); - try h.writeEncrypted(&w, encrypted_extensions); - } - if (opt.client_auth) |_| { - const certificate_request = try makeCertificateRequest(w.getPayload()); - h.transcript.update(certificate_request); - try h.writeEncrypted(&w, certificate_request); - } - if (opt.auth) |a| { - const cm = CertificateBuilder{ - .bundle = a.bundle, - .key = a.key, - .transcript = &h.transcript, - .side = .server, - }; - { - const certificate = try cm.makeCertificate(w.getPayload()); - h.transcript.update(certificate); - try h.writeEncrypted(&w, certificate); - } - { - const certificate_verify = try cm.makeCertificateVerify(w.getPayload()); - h.transcript.update(certificate_verify); - try h.writeEncrypted(&w, certificate_verify); - } - } - { - const finished = try h.makeFinished(w.getPayload()); - h.transcript.update(finished); - try h.writeEncrypted(&w, finished); - } - break :brk w.getWritten(); - }; - try stream.writeAll(server_flight); - - var app_cipher = brk: { - const application_secret = h.transcript.applicationSecret(); - break :brk try Cipher.initTls13(h.cipher_suite, application_secret, .server); - }; - - h.readClientFlight2(opt) catch |err| { - // Alert received from client - if (!mem.startsWith(u8, @errorName(err), "TlsAlert")) { - try h.writeAlert(stream, &app_cipher, err); - } - return err; - }; - return app_cipher; - } - - inline fn sharedKey(h: *HandshakeT) ![]const u8 { - var seed: [DhKeyPair.seed_len]u8 = undefined; - crypto.random.bytes(&seed); - var kp = try DhKeyPair.init(seed, supported_named_groups); - h.server_pub_key = dupe(&h.server_pub_key_buf, try kp.publicKey(h.named_group)); - return try kp.sharedKey(h.named_group, h.client_pub_key); - } - - fn readClientFlight2(h: *HandshakeT, opt: Options) !void { - var cleartext_buf = h.buffer; - var cleartext_buf_head: usize = 0; - var cleartext_buf_tail: usize = 0; - var handshake_state: proto.Handshake = .finished; - var cert: CertificateParser = undefined; - if (opt.client_auth) |client_auth| { - cert = .{ .root_ca = client_auth.root_ca.bundle, .host = "" }; - handshake_state = .certificate; - } - - outer: while (true) { - const rec = (try h.rec_rdr.next() orelse return error.EndOfStream); - if (rec.protocol_version != .tls_1_2 and rec.content_type != .alert) - return error.TlsProtocolVersion; - - switch (rec.content_type) { - .change_cipher_spec => { - if (rec.payload.len != 1) return error.TlsUnexpectedMessage; - }, - .application_data => { - const content_type, const cleartext = try h.cipher.decrypt( - cleartext_buf[cleartext_buf_tail..], - rec, - ); - cleartext_buf_tail += cleartext.len; - if (cleartext_buf_tail > cleartext_buf.len) return error.TlsRecordOverflow; - - var d = record.Decoder.init(content_type, cleartext_buf[cleartext_buf_head..cleartext_buf_tail]); - try d.expectContentType(.handshake); - while (!d.eof()) { - const start_idx = d.idx; - const handshake_type = try d.decode(proto.Handshake); - const length = try d.decode(u24); - - if (length > cipher.max_cleartext_len) - return error.TlsRecordOverflow; - if (length > d.rest().len) - continue :outer; // fragmented handshake into multiple records - - defer { - const handshake_payload = d.payload[start_idx..d.idx]; - h.transcript.update(handshake_payload); - cleartext_buf_head += handshake_payload.len; - } - - if (handshake_state != handshake_type) - return error.TlsUnexpectedMessage; - - switch (handshake_type) { - .certificate => { - if (length == 4) { - // got empty certificate message - if (opt.client_auth.?.auth_type == .require) - return error.TlsCertificateRequired; - try d.skip(length); - handshake_state = .finished; - } else { - try cert.parseCertificate(&d, .tls_1_3); - handshake_state = .certificate_verify; - } - }, - .certificate_verify => { - try cert.parseCertificateVerify(&d); - cert.verifySignature(h.transcript.clientCertificateVerify()) catch |err| return switch (err) { - error.TlsUnknownSignatureScheme => error.TlsIllegalParameter, - else => error.TlsDecryptError, - }; - handshake_state = .finished; - }, - .finished => { - const actual = try d.slice(length); - var buf: [Transcript.max_mac_length]u8 = undefined; - const expected = h.transcript.clientFinishedTls13(&buf); - if (!mem.eql(u8, expected, actual)) - return if (expected.len == actual.len) - error.TlsDecryptError - else - error.TlsDecodeError; - return; - }, - else => return error.TlsUnexpectedMessage, - } - } - cleartext_buf_head = 0; - cleartext_buf_tail = 0; - }, - .alert => { - var d = rec.decoder(); - return d.raiseAlert(); - }, - else => return error.TlsUnexpectedMessage, - } - } - } - - fn makeFinished(h: *HandshakeT, buf: []u8) ![]const u8 { - var w = record.Writer{ .buf = buf }; - const verify_data = h.transcript.serverFinishedTls13(w.getHandshakePayload()); - try w.advanceHandshake(.finished, verify_data.len); - return w.getWritten(); - } - - /// Write encrypted handshake message into `w` - fn writeEncrypted(h: *HandshakeT, w: *record.Writer, cleartext: []const u8) !void { - const ciphertext = try h.cipher.encrypt(w.getFree(), .handshake, cleartext); - w.pos += ciphertext.len; - } - - fn makeServerHello(h: *HandshakeT, buf: []u8) ![]const u8 { - const header_len = 9; // tls record header (5 bytes) and handshake header (4 bytes) - var w = record.Writer{ .buf = buf[header_len..] }; - - try w.writeEnum(proto.Version.tls_1_2); - try w.write(&h.server_random); - { - try w.writeInt(@as(u8, @intCast(h.legacy_session_id.len))); - if (h.legacy_session_id.len > 0) try w.write(h.legacy_session_id); - } - try w.writeEnum(h.cipher_suite); - try w.write(&[_]u8{0}); // compression method - - var e = record.Writer{ .buf = buf[header_len + w.pos + 2 ..] }; - { // supported versions extension - try e.writeEnum(proto.Extension.supported_versions); - try e.writeInt(@as(u16, 2)); - try e.writeEnum(proto.Version.tls_1_3); - } - { // key share extension - const key_len: u16 = @intCast(h.server_pub_key.len); - try e.writeEnum(proto.Extension.key_share); - try e.writeInt(key_len + 4); - try e.writeEnum(h.named_group); - try e.writeInt(key_len); - try e.write(h.server_pub_key); - } - try w.writeInt(@as(u16, @intCast(e.pos))); // extensions length - - const payload_len = w.pos + e.pos; - buf[0..header_len].* = record.header(.handshake, 4 + payload_len) ++ - record.handshakeHeader(.server_hello, payload_len); - - return buf[0 .. header_len + payload_len]; - } - - fn makeCertificateRequest(buf: []u8) ![]const u8 { - // handshake header + context length + extensions length - const header_len = 4 + 1 + 2; - - // First write extensions, leave space for header. - var ext = record.Writer{ .buf = buf[header_len..] }; - try ext.writeExtension(.signature_algorithms, common.supported_signature_algorithms); - - var w = record.Writer{ .buf = buf }; - try w.writeHandshakeHeader(.certificate_request, ext.pos + 3); - try w.writeInt(@as(u8, 0)); // certificate request context length = 0 - try w.writeInt(@as(u16, @intCast(ext.pos))); // extensions length - assert(w.pos == header_len); - w.pos += ext.pos; - - return w.getWritten(); - } - - fn readClientHello(h: *HandshakeT) !void { - var d = try h.rec_rdr.nextDecoder(); - try d.expectContentType(.handshake); - h.transcript.update(d.payload); - - const handshake_type = try d.decode(proto.Handshake); - if (handshake_type != .client_hello) return error.TlsUnexpectedMessage; - _ = try d.decode(u24); // handshake length - if (try d.decode(proto.Version) != .tls_1_2) return error.TlsProtocolVersion; - - h.client_random = try d.array(32); - { // legacy session id - const len = try d.decode(u8); - h.legacy_session_id = dupe(&h.legacy_session_id_buf, try d.slice(len)); - } - { // cipher suites - const end_idx = try d.decode(u16) + d.idx; - - while (d.idx < end_idx) { - const cipher_suite = try d.decode(CipherSuite); - if (cipher_suites.includes(cipher_suites.tls13, cipher_suite) and - @intFromEnum(h.cipher_suite) == 0) - { - h.cipher_suite = cipher_suite; - } - } - if (@intFromEnum(h.cipher_suite) == 0) - return error.TlsHandshakeFailure; - } - try d.skip(2); // compression methods - - var key_share_received = false; - // extensions - const extensions_end_idx = try d.decode(u16) + d.idx; - while (d.idx < extensions_end_idx) { - const extension_type = try d.decode(proto.Extension); - const extension_len = try d.decode(u16); - - switch (extension_type) { - .supported_versions => { - var tls_1_3_supported = false; - const end_idx = try d.decode(u8) + d.idx; - while (d.idx < end_idx) { - if (try d.decode(proto.Version) == proto.Version.tls_1_3) { - tls_1_3_supported = true; - } - } - if (!tls_1_3_supported) return error.TlsProtocolVersion; - }, - .key_share => { - if (extension_len == 0) return error.TlsDecodeError; - key_share_received = true; - var selected_named_group_idx = supported_named_groups.len; - const end_idx = try d.decode(u16) + d.idx; - while (d.idx < end_idx) { - const named_group = try d.decode(proto.NamedGroup); - switch (@intFromEnum(named_group)) { - 0x0001...0x0016, - 0x001a...0x001c, - 0xff01...0xff02, - => return error.TlsIllegalParameter, - else => {}, - } - const client_pub_key = try d.slice(try d.decode(u16)); - for (supported_named_groups, 0..) |supported, idx| { - if (named_group == supported and idx < selected_named_group_idx) { - h.named_group = named_group; - h.client_pub_key = dupe(&h.client_pub_key_buf, client_pub_key); - selected_named_group_idx = idx; - } - } - } - if (@intFromEnum(h.named_group) == 0) - return error.TlsIllegalParameter; - }, - .supported_groups => { - const end_idx = try d.decode(u16) + d.idx; - while (d.idx < end_idx) { - const named_group = try d.decode(proto.NamedGroup); - switch (@intFromEnum(named_group)) { - 0x0001...0x0016, - 0x001a...0x001c, - 0xff01...0xff02, - => return error.TlsIllegalParameter, - else => {}, - } - } - }, - .signature_algorithms => { - if (@intFromEnum(h.signature_scheme) == 0) { - try d.skip(extension_len); - } else { - var found = false; - const list_len = try d.decode(u16); - if (list_len == 0) return error.TlsDecodeError; - const end_idx = list_len + d.idx; - while (d.idx < end_idx) { - const signature_scheme = try d.decode(proto.SignatureScheme); - if (signature_scheme == h.signature_scheme) found = true; - } - if (!found) return error.TlsHandshakeFailure; - } - }, - else => { - try d.skip(extension_len); - }, - } - } - if (!key_share_received) return error.TlsMissingExtension; - if (@intFromEnum(h.named_group) == 0) return error.TlsIllegalParameter; - } - }; -} - -const testing = std.testing; -const data13 = @import("testdata/tls13.zig"); -const testu = @import("testu.zig"); - -fn testReader(data: []const u8) record.Reader(std.io.FixedBufferStream([]const u8)) { - return record.reader(std.io.fixedBufferStream(data)); -} -const TestHandshake = Handshake(std.io.FixedBufferStream([]const u8)); - -test "read client hello" { - var buffer: [1024]u8 = undefined; - var rec_rdr = testReader(&data13.client_hello); - var h = TestHandshake.init(&buffer, &rec_rdr); - h.signature_scheme = .ecdsa_secp521r1_sha512; // this must be supported in signature_algorithms extension - try h.readClientHello(); - - try testing.expectEqual(CipherSuite.AES_256_GCM_SHA384, h.cipher_suite); - try testing.expectEqual(.x25519, h.named_group); - try testing.expectEqualSlices(u8, &data13.client_random, &h.client_random); - try testing.expectEqualSlices(u8, &data13.client_public_key, h.client_pub_key); -} - -test "make server hello" { - var buffer: [128]u8 = undefined; - var h = TestHandshake.init(&buffer, undefined); - h.cipher_suite = .AES_256_GCM_SHA384; - testu.fillFrom(&h.server_random, 0); - testu.fillFrom(&h.server_pub_key_buf, 0x20); - h.named_group = .x25519; - h.server_pub_key = h.server_pub_key_buf[0..32]; - - const actual = try h.makeServerHello(&buffer); - const expected = &testu.hexToBytes( - \\ 16 03 03 00 5a 02 00 00 56 - \\ 03 03 - \\ 00 01 02 03 04 05 06 07 08 09 0a 0b 0c 0d 0e 0f 10 11 12 13 14 15 16 17 18 19 1a 1b 1c 1d 1e 1f - \\ 00 - \\ 13 02 00 - \\ 00 2e 00 2b 00 02 03 04 - \\ 00 33 00 24 00 1d 00 20 - \\ 20 21 22 23 24 25 26 27 28 29 2a 2b 2c 2d 2e 2f 30 31 32 33 34 35 36 37 38 39 3a 3b 3c 3d 3e 3f - ); - try testing.expectEqualSlices(u8, expected, actual); -} - -test "make certificate request" { - var buffer: [32]u8 = undefined; - - const expected = testu.hexToBytes("0d 00 00 1b" ++ // handshake header - "00 00 18" ++ // extension length - "00 0d" ++ // signature algorithms extension - "00 14" ++ // extension length - "00 12" ++ // list length 6 * 2 bytes - "04 03 05 03 08 04 08 05 08 06 08 07 02 01 04 01 05 01" // signature schemes - ); - const actual = try TestHandshake.makeCertificateRequest(&buffer); - try testing.expectEqualSlices(u8, &expected, actual); -} diff --git a/src/http/async/tls.zig/key_log.zig b/src/http/async/tls.zig/key_log.zig deleted file mode 100644 index 2da83f42..00000000 --- a/src/http/async/tls.zig/key_log.zig +++ /dev/null @@ -1,60 +0,0 @@ -//! Exporting tls key so we can share them with Wireshark and analyze decrypted -//! traffic in Wireshark. -//! To configure Wireshark to use exprted keys see curl reference. -//! -//! References: -//! curl: https://everything.curl.dev/usingcurl/tls/sslkeylogfile.html -//! openssl: https://www.openssl.org/docs/manmaster/man3/SSL_CTX_set_keylog_callback.html -//! https://udn.realityripple.com/docs/Mozilla/Projects/NSS/Key_Log_Format - -const std = @import("std"); - -const key_log_file_env = "SSLKEYLOGFILE"; - -pub const label = struct { - // tls 1.3 - pub const client_handshake_traffic_secret: []const u8 = "CLIENT_HANDSHAKE_TRAFFIC_SECRET"; - pub const server_handshake_traffic_secret: []const u8 = "SERVER_HANDSHAKE_TRAFFIC_SECRET"; - pub const client_traffic_secret_0: []const u8 = "CLIENT_TRAFFIC_SECRET_0"; - pub const server_traffic_secret_0: []const u8 = "SERVER_TRAFFIC_SECRET_0"; - // tls 1.2 - pub const client_random: []const u8 = "CLIENT_RANDOM"; -}; - -pub const Callback = *const fn (label: []const u8, client_random: []const u8, secret: []const u8) void; - -/// Writes tls keys to the file pointed by SSLKEYLOGFILE environment variable. -pub fn callback(label_: []const u8, client_random: []const u8, secret: []const u8) void { - if (std.posix.getenv(key_log_file_env)) |file_name| { - fileAppend(file_name, label_, client_random, secret) catch return; - } -} - -pub fn fileAppend(file_name: []const u8, label_: []const u8, client_random: []const u8, secret: []const u8) !void { - var buf: [1024]u8 = undefined; - const line = try formatLine(&buf, label_, client_random, secret); - try fileWrite(file_name, line); -} - -fn fileWrite(file_name: []const u8, line: []const u8) !void { - var file = try std.fs.createFileAbsolute(file_name, .{ .truncate = false }); - defer file.close(); - const stat = try file.stat(); - try file.seekTo(stat.size); - try file.writeAll(line); -} - -pub fn formatLine(buf: []u8, label_: []const u8, client_random: []const u8, secret: []const u8) ![]const u8 { - var fbs = std.io.fixedBufferStream(buf); - const w = fbs.writer(); - try w.print("{s} ", .{label_}); - for (client_random) |b| { - try std.fmt.formatInt(b, 16, .lower, .{ .width = 2, .fill = '0' }, w); - } - try w.writeByte(' '); - for (secret) |b| { - try std.fmt.formatInt(b, 16, .lower, .{ .width = 2, .fill = '0' }, w); - } - try w.writeByte('\n'); - return fbs.getWritten(); -} diff --git a/src/http/async/tls.zig/main.zig b/src/http/async/tls.zig/main.zig deleted file mode 100644 index b974377b..00000000 --- a/src/http/async/tls.zig/main.zig +++ /dev/null @@ -1,51 +0,0 @@ -const std = @import("std"); - -pub const CipherSuite = @import("cipher.zig").CipherSuite; -pub const cipher_suites = @import("cipher.zig").cipher_suites; -pub const PrivateKey = @import("PrivateKey.zig"); -pub const Connection = @import("connection.zig").Connection; -pub const ClientOptions = @import("handshake_client.zig").Options; -pub const ServerOptions = @import("handshake_server.zig").Options; -pub const key_log = @import("key_log.zig"); -pub const proto = @import("protocol.zig"); -pub const NamedGroup = proto.NamedGroup; -pub const Version = proto.Version; -const common = @import("handshake_common.zig"); -pub const CertBundle = common.CertBundle; -pub const CertKeyPair = common.CertKeyPair; - -pub const record = @import("record.zig"); -const connection = @import("connection.zig").connection; -const max_ciphertext_record_len = @import("cipher.zig").max_ciphertext_record_len; -const HandshakeServer = @import("handshake_server.zig").Handshake; -const HandshakeClient = @import("handshake_client.zig").Handshake; - -pub fn client(stream: anytype, opt: ClientOptions) !Connection(@TypeOf(stream)) { - const Stream = @TypeOf(stream); - var conn = connection(stream); - var write_buf: [max_ciphertext_record_len]u8 = undefined; - var h = HandshakeClient(Stream).init(&write_buf, &conn.rec_rdr); - conn.cipher = try h.handshake(conn.stream, opt); - return conn; -} - -pub fn server(stream: anytype, opt: ServerOptions) !Connection(@TypeOf(stream)) { - const Stream = @TypeOf(stream); - var conn = connection(stream); - var write_buf: [max_ciphertext_record_len]u8 = undefined; - var h = HandshakeServer(Stream).init(&write_buf, &conn.rec_rdr); - conn.cipher = try h.handshake(conn.stream, opt); - return conn; -} - -test { - _ = @import("handshake_common.zig"); - _ = @import("handshake_server.zig"); - _ = @import("handshake_client.zig"); - - _ = @import("connection.zig"); - _ = @import("cipher.zig"); - _ = @import("record.zig"); - _ = @import("transcript.zig"); - _ = @import("PrivateKey.zig"); -} diff --git a/src/http/async/tls.zig/protocol.zig b/src/http/async/tls.zig/protocol.zig deleted file mode 100644 index e3bb07ac..00000000 --- a/src/http/async/tls.zig/protocol.zig +++ /dev/null @@ -1,302 +0,0 @@ -pub const Version = enum(u16) { - tls_1_2 = 0x0303, - tls_1_3 = 0x0304, - _, -}; - -pub const ContentType = enum(u8) { - invalid = 0, - change_cipher_spec = 20, - alert = 21, - handshake = 22, - application_data = 23, - _, -}; - -pub const Handshake = enum(u8) { - client_hello = 1, - server_hello = 2, - new_session_ticket = 4, - end_of_early_data = 5, - encrypted_extensions = 8, - certificate = 11, - server_key_exchange = 12, - certificate_request = 13, - server_hello_done = 14, - certificate_verify = 15, - client_key_exchange = 16, - finished = 20, - key_update = 24, - message_hash = 254, - _, -}; - -pub const Curve = enum(u8) { - named_curve = 0x03, - _, -}; - -pub const Extension = enum(u16) { - /// RFC 6066 - server_name = 0, - /// RFC 6066 - max_fragment_length = 1, - /// RFC 6066 - status_request = 5, - /// RFC 8422, 7919 - supported_groups = 10, - /// RFC 8446 - signature_algorithms = 13, - /// RFC 5764 - use_srtp = 14, - /// RFC 6520 - heartbeat = 15, - /// RFC 7301 - application_layer_protocol_negotiation = 16, - /// RFC 6962 - signed_certificate_timestamp = 18, - /// RFC 7250 - client_certificate_type = 19, - /// RFC 7250 - server_certificate_type = 20, - /// RFC 7685 - padding = 21, - /// RFC 8446 - pre_shared_key = 41, - /// RFC 8446 - early_data = 42, - /// RFC 8446 - supported_versions = 43, - /// RFC 8446 - cookie = 44, - /// RFC 8446 - psk_key_exchange_modes = 45, - /// RFC 8446 - certificate_authorities = 47, - /// RFC 8446 - oid_filters = 48, - /// RFC 8446 - post_handshake_auth = 49, - /// RFC 8446 - signature_algorithms_cert = 50, - /// RFC 8446 - key_share = 51, - - _, -}; - -pub fn alertFromError(err: anyerror) [2]u8 { - return [2]u8{ @intFromEnum(Alert.Level.fatal), @intFromEnum(Alert.fromError(err)) }; -} - -pub const Alert = enum(u8) { - pub const Level = enum(u8) { - warning = 1, - fatal = 2, - _, - }; - - pub const Error = error{ - TlsAlertUnexpectedMessage, - TlsAlertBadRecordMac, - TlsAlertRecordOverflow, - TlsAlertHandshakeFailure, - TlsAlertBadCertificate, - TlsAlertUnsupportedCertificate, - TlsAlertCertificateRevoked, - TlsAlertCertificateExpired, - TlsAlertCertificateUnknown, - TlsAlertIllegalParameter, - TlsAlertUnknownCa, - TlsAlertAccessDenied, - TlsAlertDecodeError, - TlsAlertDecryptError, - TlsAlertProtocolVersion, - TlsAlertInsufficientSecurity, - TlsAlertInternalError, - TlsAlertInappropriateFallback, - TlsAlertMissingExtension, - TlsAlertUnsupportedExtension, - TlsAlertUnrecognizedName, - TlsAlertBadCertificateStatusResponse, - TlsAlertUnknownPskIdentity, - TlsAlertCertificateRequired, - TlsAlertNoApplicationProtocol, - TlsAlertUnknown, - }; - - close_notify = 0, - unexpected_message = 10, - bad_record_mac = 20, - record_overflow = 22, - handshake_failure = 40, - bad_certificate = 42, - unsupported_certificate = 43, - certificate_revoked = 44, - certificate_expired = 45, - certificate_unknown = 46, - illegal_parameter = 47, - unknown_ca = 48, - access_denied = 49, - decode_error = 50, - decrypt_error = 51, - protocol_version = 70, - insufficient_security = 71, - internal_error = 80, - inappropriate_fallback = 86, - user_canceled = 90, - missing_extension = 109, - unsupported_extension = 110, - unrecognized_name = 112, - bad_certificate_status_response = 113, - unknown_psk_identity = 115, - certificate_required = 116, - no_application_protocol = 120, - _, - - pub fn toError(alert: Alert) Error!void { - return switch (alert) { - .close_notify => {}, // not an error - .unexpected_message => error.TlsAlertUnexpectedMessage, - .bad_record_mac => error.TlsAlertBadRecordMac, - .record_overflow => error.TlsAlertRecordOverflow, - .handshake_failure => error.TlsAlertHandshakeFailure, - .bad_certificate => error.TlsAlertBadCertificate, - .unsupported_certificate => error.TlsAlertUnsupportedCertificate, - .certificate_revoked => error.TlsAlertCertificateRevoked, - .certificate_expired => error.TlsAlertCertificateExpired, - .certificate_unknown => error.TlsAlertCertificateUnknown, - .illegal_parameter => error.TlsAlertIllegalParameter, - .unknown_ca => error.TlsAlertUnknownCa, - .access_denied => error.TlsAlertAccessDenied, - .decode_error => error.TlsAlertDecodeError, - .decrypt_error => error.TlsAlertDecryptError, - .protocol_version => error.TlsAlertProtocolVersion, - .insufficient_security => error.TlsAlertInsufficientSecurity, - .internal_error => error.TlsAlertInternalError, - .inappropriate_fallback => error.TlsAlertInappropriateFallback, - .user_canceled => {}, // not an error - .missing_extension => error.TlsAlertMissingExtension, - .unsupported_extension => error.TlsAlertUnsupportedExtension, - .unrecognized_name => error.TlsAlertUnrecognizedName, - .bad_certificate_status_response => error.TlsAlertBadCertificateStatusResponse, - .unknown_psk_identity => error.TlsAlertUnknownPskIdentity, - .certificate_required => error.TlsAlertCertificateRequired, - .no_application_protocol => error.TlsAlertNoApplicationProtocol, - _ => error.TlsAlertUnknown, - }; - } - - pub fn fromError(err: anyerror) Alert { - return switch (err) { - error.TlsUnexpectedMessage => .unexpected_message, - error.TlsBadRecordMac => .bad_record_mac, - error.TlsRecordOverflow => .record_overflow, - error.TlsHandshakeFailure => .handshake_failure, - error.TlsBadCertificate => .bad_certificate, - error.TlsUnsupportedCertificate => .unsupported_certificate, - error.TlsCertificateRevoked => .certificate_revoked, - error.TlsCertificateExpired => .certificate_expired, - error.TlsCertificateUnknown => .certificate_unknown, - error.TlsIllegalParameter, - error.IdentityElement, - error.InvalidEncoding, - => .illegal_parameter, - error.TlsUnknownCa => .unknown_ca, - error.TlsAccessDenied => .access_denied, - error.TlsDecodeError => .decode_error, - error.TlsDecryptError => .decrypt_error, - error.TlsProtocolVersion => .protocol_version, - error.TlsInsufficientSecurity => .insufficient_security, - error.TlsInternalError => .internal_error, - error.TlsInappropriateFallback => .inappropriate_fallback, - error.TlsMissingExtension => .missing_extension, - error.TlsUnsupportedExtension => .unsupported_extension, - error.TlsUnrecognizedName => .unrecognized_name, - error.TlsBadCertificateStatusResponse => .bad_certificate_status_response, - error.TlsUnknownPskIdentity => .unknown_psk_identity, - error.TlsCertificateRequired => .certificate_required, - error.TlsNoApplicationProtocol => .no_application_protocol, - else => .internal_error, - }; - } - - pub fn parse(buf: [2]u8) Alert { - const level: Alert.Level = @enumFromInt(buf[0]); - const alert: Alert = @enumFromInt(buf[1]); - _ = level; - return alert; - } - - pub fn closeNotify() [2]u8 { - return [2]u8{ - @intFromEnum(Alert.Level.warning), - @intFromEnum(Alert.close_notify), - }; - } -}; - -pub const SignatureScheme = enum(u16) { - // RSASSA-PKCS1-v1_5 algorithms - rsa_pkcs1_sha256 = 0x0401, - rsa_pkcs1_sha384 = 0x0501, - rsa_pkcs1_sha512 = 0x0601, - - // ECDSA algorithms - ecdsa_secp256r1_sha256 = 0x0403, - ecdsa_secp384r1_sha384 = 0x0503, - ecdsa_secp521r1_sha512 = 0x0603, - - // RSASSA-PSS algorithms with public key OID rsaEncryption - rsa_pss_rsae_sha256 = 0x0804, - rsa_pss_rsae_sha384 = 0x0805, - rsa_pss_rsae_sha512 = 0x0806, - - // EdDSA algorithms - ed25519 = 0x0807, - ed448 = 0x0808, - - // RSASSA-PSS algorithms with public key OID RSASSA-PSS - rsa_pss_pss_sha256 = 0x0809, - rsa_pss_pss_sha384 = 0x080a, - rsa_pss_pss_sha512 = 0x080b, - - // Legacy algorithms - rsa_pkcs1_sha1 = 0x0201, - ecdsa_sha1 = 0x0203, - - _, -}; - -pub const NamedGroup = enum(u16) { - // Elliptic Curve Groups (ECDHE) - secp256r1 = 0x0017, - secp384r1 = 0x0018, - secp521r1 = 0x0019, - x25519 = 0x001D, - x448 = 0x001E, - - // Finite Field Groups (DHE) - ffdhe2048 = 0x0100, - ffdhe3072 = 0x0101, - ffdhe4096 = 0x0102, - ffdhe6144 = 0x0103, - ffdhe8192 = 0x0104, - - // Hybrid post-quantum key agreements - x25519_kyber512d00 = 0xFE30, - x25519_kyber768d00 = 0x6399, - - _, -}; - -pub const KeyUpdateRequest = enum(u8) { - update_not_requested = 0, - update_requested = 1, - _, -}; - -pub const Side = enum { - client, - server, -}; diff --git a/src/http/async/tls.zig/record.zig b/src/http/async/tls.zig/record.zig deleted file mode 100644 index 6c4df328..00000000 --- a/src/http/async/tls.zig/record.zig +++ /dev/null @@ -1,405 +0,0 @@ -const std = @import("std"); -const assert = std.debug.assert; -const mem = std.mem; - -const proto = @import("protocol.zig"); -const cipher = @import("cipher.zig"); -const Cipher = cipher.Cipher; -const record = @import("record.zig"); - -pub const header_len = 5; - -pub fn header(content_type: proto.ContentType, payload_len: usize) [header_len]u8 { - const int2 = std.crypto.tls.int2; - return [1]u8{@intFromEnum(content_type)} ++ - int2(@intFromEnum(proto.Version.tls_1_2)) ++ - int2(@intCast(payload_len)); -} - -pub fn handshakeHeader(handshake_type: proto.Handshake, payload_len: usize) [4]u8 { - const int3 = std.crypto.tls.int3; - return [1]u8{@intFromEnum(handshake_type)} ++ int3(@intCast(payload_len)); -} - -pub fn reader(inner_reader: anytype) Reader(@TypeOf(inner_reader)) { - return .{ .inner_reader = inner_reader }; -} - -pub fn Reader(comptime InnerReader: type) type { - return struct { - inner_reader: InnerReader, - - buffer: [cipher.max_ciphertext_record_len]u8 = undefined, - start: usize = 0, - end: usize = 0, - - const ReaderT = @This(); - - pub fn nextDecoder(r: *ReaderT) !Decoder { - const rec = (try r.next()) orelse return error.EndOfStream; - if (@intFromEnum(rec.protocol_version) != 0x0300 and - @intFromEnum(rec.protocol_version) != 0x0301 and - rec.protocol_version != .tls_1_2) - return error.TlsBadVersion; - return .{ - .content_type = rec.content_type, - .payload = rec.payload, - }; - } - - pub fn contentType(buf: []const u8) proto.ContentType { - return @enumFromInt(buf[0]); - } - - pub fn protocolVersion(buf: []const u8) proto.Version { - return @enumFromInt(mem.readInt(u16, buf[1..3], .big)); - } - - pub fn next(r: *ReaderT) !?Record { - while (true) { - const buffer = r.buffer[r.start..r.end]; - // If we have 5 bytes header. - if (buffer.len >= record.header_len) { - const record_header = buffer[0..record.header_len]; - const payload_len = mem.readInt(u16, record_header[3..5], .big); - if (payload_len > cipher.max_ciphertext_len) - return error.TlsRecordOverflow; - const record_len = record.header_len + payload_len; - // If we have whole record - if (buffer.len >= record_len) { - r.start += record_len; - return Record.init(buffer[0..record_len]); - } - } - { // Move dirty part to the start of the buffer. - const n = r.end - r.start; - if (n > 0 and r.start > 0) { - if (r.start > n) { - @memcpy(r.buffer[0..n], r.buffer[r.start..][0..n]); - } else { - mem.copyForwards(u8, r.buffer[0..n], r.buffer[r.start..][0..n]); - } - } - r.start = 0; - r.end = n; - } - { // Read more from inner_reader. - const n = try r.inner_reader.read(r.buffer[r.end..]); - if (n == 0) return null; - r.end += n; - } - } - } - - pub fn nextDecrypt(r: *ReaderT, cph: *Cipher) !?struct { proto.ContentType, []const u8 } { - const rec = (try r.next()) orelse return null; - if (rec.protocol_version != .tls_1_2) return error.TlsBadVersion; - - return try cph.decrypt( - // Reuse reader buffer for cleartext. `rec.header` and - // `rec.payload`(ciphertext) are also pointing somewhere in - // this buffer. Decrypter is first reading then writing a - // block, cleartext has less length then ciphertext, - // cleartext starts from the beginning of the buffer, so - // ciphertext is always ahead of cleartext. - r.buffer[0..r.start], - rec, - ); - } - - pub fn hasMore(r: *ReaderT) bool { - return r.end > r.start; - } - }; -} - -pub const Record = struct { - content_type: proto.ContentType, - protocol_version: proto.Version = .tls_1_2, - header: []const u8, - payload: []const u8, - - pub fn init(buffer: []const u8) Record { - return .{ - .content_type = @enumFromInt(buffer[0]), - .protocol_version = @enumFromInt(mem.readInt(u16, buffer[1..3], .big)), - .header = buffer[0..record.header_len], - .payload = buffer[record.header_len..], - }; - } - - pub fn decoder(r: @This()) Decoder { - return Decoder.init(r.content_type, @constCast(r.payload)); - } -}; - -pub const Decoder = struct { - content_type: proto.ContentType, - payload: []const u8, - idx: usize = 0, - - pub fn init(content_type: proto.ContentType, payload: []u8) Decoder { - return .{ - .content_type = content_type, - .payload = payload, - }; - } - - pub fn decode(d: *Decoder, comptime T: type) !T { - switch (@typeInfo(T)) { - .Int => |info| switch (info.bits) { - 8 => { - try skip(d, 1); - return d.payload[d.idx - 1]; - }, - 16 => { - try skip(d, 2); - const b0: u16 = d.payload[d.idx - 2]; - const b1: u16 = d.payload[d.idx - 1]; - return (b0 << 8) | b1; - }, - 24 => { - try skip(d, 3); - const b0: u24 = d.payload[d.idx - 3]; - const b1: u24 = d.payload[d.idx - 2]; - const b2: u24 = d.payload[d.idx - 1]; - return (b0 << 16) | (b1 << 8) | b2; - }, - else => @compileError("unsupported int type: " ++ @typeName(T)), - }, - .Enum => |info| { - const int = try d.decode(info.tag_type); - if (info.is_exhaustive) @compileError("exhaustive enum cannot be used"); - return @as(T, @enumFromInt(int)); - }, - else => @compileError("unsupported type: " ++ @typeName(T)), - } - } - - pub fn array(d: *Decoder, comptime len: usize) ![len]u8 { - try d.skip(len); - return d.payload[d.idx - len ..][0..len].*; - } - - pub fn slice(d: *Decoder, len: usize) ![]const u8 { - try d.skip(len); - return d.payload[d.idx - len ..][0..len]; - } - - pub fn skip(d: *Decoder, amt: usize) !void { - if (d.idx + amt > d.payload.len) return error.TlsDecodeError; - d.idx += amt; - } - - pub fn rest(d: Decoder) []const u8 { - return d.payload[d.idx..]; - } - - pub fn eof(d: Decoder) bool { - return d.idx == d.payload.len; - } - - pub fn expectContentType(d: *Decoder, content_type: proto.ContentType) !void { - if (d.content_type == content_type) return; - - switch (d.content_type) { - .alert => try d.raiseAlert(), - else => return error.TlsUnexpectedMessage, - } - } - - pub fn raiseAlert(d: *Decoder) !void { - if (d.payload.len < 2) return error.TlsUnexpectedMessage; - try proto.Alert.parse(try d.array(2)).toError(); - return error.TlsAlertCloseNotify; - } -}; - -const testing = std.testing; -const data12 = @import("testdata/tls12.zig"); -const testu = @import("testu.zig"); -const CipherSuite = @import("cipher.zig").CipherSuite; - -test Reader { - var fbs = std.io.fixedBufferStream(&data12.server_responses); - var rdr = reader(fbs.reader()); - - const expected = [_]struct { - content_type: proto.ContentType, - payload_len: usize, - }{ - .{ .content_type = .handshake, .payload_len = 49 }, - .{ .content_type = .handshake, .payload_len = 815 }, - .{ .content_type = .handshake, .payload_len = 300 }, - .{ .content_type = .handshake, .payload_len = 4 }, - .{ .content_type = .change_cipher_spec, .payload_len = 1 }, - .{ .content_type = .handshake, .payload_len = 64 }, - }; - for (expected) |e| { - const rec = (try rdr.next()).?; - try testing.expectEqual(e.content_type, rec.content_type); - try testing.expectEqual(e.payload_len, rec.payload.len); - try testing.expectEqual(.tls_1_2, rec.protocol_version); - } -} - -test Decoder { - var fbs = std.io.fixedBufferStream(&data12.server_responses); - var rdr = reader(fbs.reader()); - - var d = (try rdr.nextDecoder()); - try testing.expectEqual(.handshake, d.content_type); - - try testing.expectEqual(.server_hello, try d.decode(proto.Handshake)); - try testing.expectEqual(45, try d.decode(u24)); // length - try testing.expectEqual(.tls_1_2, try d.decode(proto.Version)); - try testing.expectEqualStrings( - &testu.hexToBytes("707172737475767778797a7b7c7d7e7f808182838485868788898a8b8c8d8e8f"), - try d.slice(32), - ); // server random - try testing.expectEqual(0, try d.decode(u8)); // session id len - try testing.expectEqual(.ECDHE_RSA_WITH_AES_128_CBC_SHA, try d.decode(CipherSuite)); - try testing.expectEqual(0, try d.decode(u8)); // compression method - try testing.expectEqual(5, try d.decode(u16)); // extension length - try testing.expectEqual(5, d.rest().len); - try d.skip(5); - try testing.expect(d.eof()); -} - -pub const Writer = struct { - buf: []u8, - pos: usize = 0, - - pub fn write(self: *Writer, data: []const u8) !void { - defer self.pos += data.len; - if (self.pos + data.len > self.buf.len) return error.BufferOverflow; - @memcpy(self.buf[self.pos..][0..data.len], data); - } - - pub fn writeByte(self: *Writer, b: u8) !void { - defer self.pos += 1; - if (self.pos == self.buf.len) return error.BufferOverflow; - self.buf[self.pos] = b; - } - - pub fn writeEnum(self: *Writer, value: anytype) !void { - try self.writeInt(@intFromEnum(value)); - } - - pub fn writeInt(self: *Writer, value: anytype) !void { - const IntT = @TypeOf(value); - const bytes = @divExact(@typeInfo(IntT).Int.bits, 8); - const free = self.buf[self.pos..]; - if (free.len < bytes) return error.BufferOverflow; - mem.writeInt(IntT, free[0..bytes], value, .big); - self.pos += bytes; - } - - pub fn writeHandshakeHeader(self: *Writer, handshake_type: proto.Handshake, payload_len: usize) !void { - try self.write(&record.handshakeHeader(handshake_type, payload_len)); - } - - /// Should be used after writing handshake payload in buffer provided by `getHandshakePayload`. - pub fn advanceHandshake(self: *Writer, handshake_type: proto.Handshake, payload_len: usize) !void { - try self.write(&record.handshakeHeader(handshake_type, payload_len)); - self.pos += payload_len; - } - - /// Record payload is already written by using buffer space from `getPayload`. - /// Now when we know payload len we can write record header and advance over payload. - pub fn advanceRecord(self: *Writer, content_type: proto.ContentType, payload_len: usize) !void { - try self.write(&record.header(content_type, payload_len)); - self.pos += payload_len; - } - - pub fn writeRecord(self: *Writer, content_type: proto.ContentType, payload: []const u8) !void { - try self.write(&record.header(content_type, payload.len)); - try self.write(payload); - } - - /// Preserves space for record header and returns buffer free space. - pub fn getPayload(self: *Writer) []u8 { - return self.buf[self.pos + record.header_len ..]; - } - - /// Preserves space for handshake header and returns buffer free space. - pub fn getHandshakePayload(self: *Writer) []u8 { - return self.buf[self.pos + 4 ..]; - } - - pub fn getWritten(self: *Writer) []const u8 { - return self.buf[0..self.pos]; - } - - pub fn getFree(self: *Writer) []u8 { - return self.buf[self.pos..]; - } - - pub fn writeEnumArray(self: *Writer, comptime E: type, tags: []const E) !void { - assert(@sizeOf(E) == 2); - try self.writeInt(@as(u16, @intCast(tags.len * 2))); - for (tags) |t| { - try self.writeEnum(t); - } - } - - pub fn writeExtension( - self: *Writer, - comptime et: proto.Extension, - tags: anytype, - ) !void { - try self.writeEnum(et); - if (et == .supported_versions) { - try self.writeInt(@as(u16, @intCast(tags.len * 2 + 1))); - try self.writeInt(@as(u8, @intCast(tags.len * 2))); - } else { - try self.writeInt(@as(u16, @intCast(tags.len * 2 + 2))); - try self.writeInt(@as(u16, @intCast(tags.len * 2))); - } - for (tags) |t| { - try self.writeEnum(t); - } - } - - pub fn writeKeyShare( - self: *Writer, - named_groups: []const proto.NamedGroup, - keys: []const []const u8, - ) !void { - assert(named_groups.len == keys.len); - try self.writeEnum(proto.Extension.key_share); - var l: usize = 0; - for (keys) |key| { - l += key.len + 4; - } - try self.writeInt(@as(u16, @intCast(l + 2))); - try self.writeInt(@as(u16, @intCast(l))); - for (named_groups, 0..) |ng, i| { - const key = keys[i]; - try self.writeEnum(ng); - try self.writeInt(@as(u16, @intCast(key.len))); - try self.write(key); - } - } - - pub fn writeServerName(self: *Writer, host: []const u8) !void { - const host_len: u16 = @intCast(host.len); - try self.writeEnum(proto.Extension.server_name); - try self.writeInt(host_len + 5); // byte length of extension payload - try self.writeInt(host_len + 3); // server_name_list byte count - try self.writeByte(0); // name type - try self.writeInt(host_len); - try self.write(host); - } -}; - -test "Writer" { - var buf: [16]u8 = undefined; - var w = Writer{ .buf = &buf }; - - try w.write("ab"); - try w.writeEnum(proto.Curve.named_curve); - try w.writeEnum(proto.NamedGroup.x25519); - try w.writeInt(@as(u16, 0x1234)); - try testing.expectEqualSlices(u8, &[_]u8{ 'a', 'b', 0x03, 0x00, 0x1d, 0x12, 0x34 }, w.getWritten()); -} diff --git a/src/http/async/tls.zig/rsa/der.zig b/src/http/async/tls.zig/rsa/der.zig deleted file mode 100644 index 743a65ad..00000000 --- a/src/http/async/tls.zig/rsa/der.zig +++ /dev/null @@ -1,467 +0,0 @@ -//! An encoding of ASN.1. -//! -//! Distinguised Encoding Rules as defined in X.690 and X.691. -//! -//! A version of Basic Encoding Rules (BER) where there is exactly ONE way to -//! represent non-constructed elements. This is useful for cryptographic signatures. -//! -//! Currently an implementation detail of the standard library not fit for public -//! use since it's missing an encoder. - -const std = @import("std"); -const builtin = @import("builtin"); - -pub const Index = usize; -const log = std.log.scoped(.der); - -/// A secure DER parser that: -/// - Does NOT read memory outside `bytes`. -/// - Does NOT return elements with slices outside `bytes`. -/// - Errors on values that do NOT follow DER rules. -/// - Lengths that could be represented in a shorter form. -/// - Booleans that are not 0xff or 0x00. -pub const Parser = struct { - bytes: []const u8, - index: Index = 0, - - pub const Error = Element.Error || error{ - UnexpectedElement, - InvalidIntegerEncoding, - Overflow, - NonCanonical, - }; - - pub fn expectBool(self: *Parser) Error!bool { - const ele = try self.expect(.universal, false, .boolean); - if (ele.slice.len() != 1) return error.InvalidBool; - - return switch (self.view(ele)[0]) { - 0x00 => false, - 0xff => true, - else => error.InvalidBool, - }; - } - - pub fn expectBitstring(self: *Parser) Error!BitString { - const ele = try self.expect(.universal, false, .bitstring); - const bytes = self.view(ele); - const right_padding = bytes[0]; - if (right_padding >= 8) return error.InvalidBitString; - return .{ - .bytes = bytes[1..], - .right_padding = @intCast(right_padding), - }; - } - - // TODO: return high resolution date time type instead of epoch seconds - pub fn expectDateTime(self: *Parser) Error!i64 { - const ele = try self.expect(.universal, false, null); - const bytes = self.view(ele); - switch (ele.identifier.tag) { - .utc_time => { - // Example: "YYMMDD000000Z" - if (bytes.len != 13) - return error.InvalidDateTime; - if (bytes[12] != 'Z') - return error.InvalidDateTime; - - var date: Date = undefined; - date.year = try parseTimeDigits(bytes[0..2], 0, 99); - date.year += if (date.year >= 50) 1900 else 2000; - date.month = try parseTimeDigits(bytes[2..4], 1, 12); - date.day = try parseTimeDigits(bytes[4..6], 1, 31); - const time = try parseTime(bytes[6..12]); - - return date.toEpochSeconds() + time.toSec(); - }, - .generalized_time => { - // Examples: - // "19920622123421Z" - // "19920722132100.3Z" - if (bytes.len < 15) - return error.InvalidDateTime; - - var date: Date = undefined; - date.year = try parseYear4(bytes[0..4]); - date.month = try parseTimeDigits(bytes[4..6], 1, 12); - date.day = try parseTimeDigits(bytes[6..8], 1, 31); - const time = try parseTime(bytes[8..14]); - - return date.toEpochSeconds() + time.toSec(); - }, - else => return error.InvalidDateTime, - } - } - - pub fn expectOid(self: *Parser) Error![]const u8 { - const oid = try self.expect(.universal, false, .object_identifier); - return self.view(oid); - } - - pub fn expectEnum(self: *Parser, comptime Enum: type) Error!Enum { - const oid = try self.expectOid(); - return Enum.oids.get(oid) orelse { - if (builtin.mode == .Debug) { - var buf: [256]u8 = undefined; - var stream = std.io.fixedBufferStream(&buf); - try @import("./oid.zig").decode(oid, stream.writer()); - log.warn("unknown oid {s} for enum {s}\n", .{ stream.getWritten(), @typeName(Enum) }); - } - return error.UnknownObjectId; - }; - } - - pub fn expectInt(self: *Parser, comptime T: type) Error!T { - const ele = try self.expectPrimitive(.integer); - const bytes = self.view(ele); - - const info = @typeInfo(T); - if (info != .Int) @compileError(@typeName(T) ++ " is not an int type"); - const Shift = std.math.Log2Int(u8); - - var result: std.meta.Int(.unsigned, info.Int.bits) = 0; - for (bytes, 0..) |b, index| { - const shifted = @shlWithOverflow(b, @as(Shift, @intCast(index * 8))); - if (shifted[1] == 1) return error.Overflow; - - result |= shifted[0]; - } - - return @bitCast(result); - } - - pub fn expectString(self: *Parser, allowed: std.EnumSet(String.Tag)) Error!String { - const ele = try self.expect(.universal, false, null); - switch (ele.identifier.tag) { - inline .string_utf8, - .string_numeric, - .string_printable, - .string_teletex, - .string_videotex, - .string_ia5, - .string_visible, - .string_universal, - .string_bmp, - => |t| { - const tagname = @tagName(t)["string_".len..]; - const tag = std.meta.stringToEnum(String.Tag, tagname) orelse unreachable; - if (allowed.contains(tag)) { - return String{ .tag = tag, .data = self.view(ele) }; - } - }, - else => {}, - } - return error.UnexpectedElement; - } - - pub fn expectPrimitive(self: *Parser, tag: ?Identifier.Tag) Error!Element { - var elem = try self.expect(.universal, false, tag); - if (tag == .integer and elem.slice.len() > 0) { - if (self.view(elem)[0] == 0) elem.slice.start += 1; - if (elem.slice.len() > 0 and self.view(elem)[0] == 0) return error.InvalidIntegerEncoding; - } - return elem; - } - - /// Remember to call `expectEnd` - pub fn expectSequence(self: *Parser) Error!Element { - return try self.expect(.universal, true, .sequence); - } - - /// Remember to call `expectEnd` - pub fn expectSequenceOf(self: *Parser) Error!Element { - return try self.expect(.universal, true, .sequence_of); - } - - pub fn expectEnd(self: *Parser, val: usize) Error!void { - if (self.index != val) return error.NonCanonical; // either forgot to parse end OR an attacker - } - - pub fn expect( - self: *Parser, - class: ?Identifier.Class, - constructed: ?bool, - tag: ?Identifier.Tag, - ) Error!Element { - if (self.index >= self.bytes.len) return error.EndOfStream; - - const res = try Element.init(self.bytes, self.index); - if (tag) |e| { - if (res.identifier.tag != e) return error.UnexpectedElement; - } - if (constructed) |e| { - if (res.identifier.constructed != e) return error.UnexpectedElement; - } - if (class) |e| { - if (res.identifier.class != e) return error.UnexpectedElement; - } - self.index = if (res.identifier.constructed) res.slice.start else res.slice.end; - return res; - } - - pub fn view(self: Parser, elem: Element) []const u8 { - return elem.slice.view(self.bytes); - } - - pub fn seek(self: *Parser, index: usize) void { - self.index = index; - } - - pub fn eof(self: *Parser) bool { - return self.index == self.bytes.len; - } -}; - -pub const Element = struct { - identifier: Identifier, - slice: Slice, - - pub const Slice = struct { - start: Index, - end: Index, - - pub fn len(self: Slice) Index { - return self.end - self.start; - } - - pub fn view(self: Slice, bytes: []const u8) []const u8 { - return bytes[self.start..self.end]; - } - }; - - pub const Error = error{ InvalidLength, EndOfStream }; - - pub fn init(bytes: []const u8, index: Index) Error!Element { - var stream = std.io.fixedBufferStream(bytes[index..]); - var reader = stream.reader(); - - const identifier = @as(Identifier, @bitCast(try reader.readByte())); - const size_or_len_size = try reader.readByte(); - - var start = index + 2; - // short form between 0-127 - if (size_or_len_size < 128) { - const end = start + size_or_len_size; - if (end > bytes.len) return error.InvalidLength; - - return .{ .identifier = identifier, .slice = .{ .start = start, .end = end } }; - } - - // long form between 0 and std.math.maxInt(u1024) - const len_size: u7 = @truncate(size_or_len_size); - start += len_size; - if (len_size > @sizeOf(Index)) return error.InvalidLength; - const len = try reader.readVarInt(Index, .big, len_size); - if (len < 128) return error.InvalidLength; // should have used short form - - const end = std.math.add(Index, start, len) catch return error.InvalidLength; - if (end > bytes.len) return error.InvalidLength; - - return .{ .identifier = identifier, .slice = .{ .start = start, .end = end } }; - } -}; - -test Element { - const short_form = [_]u8{ 0x30, 0x03, 0x02, 0x01, 0x09 }; - try std.testing.expectEqual(Element{ - .identifier = Identifier{ .tag = .sequence, .constructed = true, .class = .universal }, - .slice = .{ .start = 2, .end = short_form.len }, - }, Element.init(&short_form, 0)); - - const long_form = [_]u8{ 0x30, 129, 129 } ++ [_]u8{0} ** 129; - try std.testing.expectEqual(Element{ - .identifier = Identifier{ .tag = .sequence, .constructed = true, .class = .universal }, - .slice = .{ .start = 3, .end = long_form.len }, - }, Element.init(&long_form, 0)); -} - -test "parser.expectInt" { - const one = [_]u8{ 2, 1, 1 }; - var parser = Parser{ .bytes = &one }; - try std.testing.expectEqual(@as(u8, 1), try parser.expectInt(u8)); -} - -pub const Identifier = packed struct(u8) { - tag: Tag, - constructed: bool, - class: Class, - - pub const Class = enum(u2) { - universal, - application, - context_specific, - private, - }; - - // https://www.oss.com/asn1/resources/asn1-made-simple/asn1-quick-reference/asn1-tags.html - pub const Tag = enum(u5) { - boolean = 1, - integer = 2, - bitstring = 3, - octetstring = 4, - null = 5, - object_identifier = 6, - real = 9, - enumerated = 10, - string_utf8 = 12, - sequence = 16, - sequence_of = 17, - string_numeric = 18, - string_printable = 19, - string_teletex = 20, - string_videotex = 21, - string_ia5 = 22, - utc_time = 23, - generalized_time = 24, - string_visible = 26, - string_universal = 28, - string_bmp = 30, - _, - }; -}; - -pub const BitString = struct { - bytes: []const u8, - right_padding: u3, - - pub fn bitLen(self: BitString) usize { - return self.bytes.len * 8 + self.right_padding; - } -}; - -pub const String = struct { - tag: Tag, - data: []const u8, - - pub const Tag = enum { - /// Blessed. - utf8, - /// us-ascii ([-][0-9][eE][.])* - numeric, - /// us-ascii ([A-Z][a-z][0-9][.?!,][ \t])* - printable, - /// iso-8859-1 with escaping into different character sets. - /// Cursed. - teletex, - /// iso-8859-1 - videotex, - /// us-ascii first 128 characters. - ia5, - /// us-ascii without control characters. - visible, - /// utf-32-be - universal, - /// utf-16-be - bmp, - }; - - pub const all = [_]Tag{ - .utf8, - .numeric, - .printable, - .teletex, - .videotex, - .ia5, - .visible, - .universal, - .bmp, - }; -}; - -const Date = struct { - year: Year, - month: u8, - day: u8, - - const Year = std.time.epoch.Year; - - fn toEpochSeconds(date: Date) i64 { - // Euclidean Affine Transform by Cassio and Neri. - // Shift and correction constants for 1970-01-01. - const s = 82; - const K = 719468 + 146097 * s; - const L = 400 * s; - - const Y_G: u32 = date.year; - const M_G: u32 = date.month; - const D_G: u32 = date.day; - // Map to computational calendar. - const J: u32 = if (M_G <= 2) 1 else 0; - const Y: u32 = Y_G + L - J; - const M: u32 = if (J != 0) M_G + 12 else M_G; - const D: u32 = D_G - 1; - const C: u32 = Y / 100; - - // Rata die. - const y_star: u32 = 1461 * Y / 4 - C + C / 4; - const m_star: u32 = (979 * M - 2919) / 32; - const N: u32 = y_star + m_star + D; - const days: i32 = @intCast(N - K); - - return @as(i64, days) * std.time.epoch.secs_per_day; - } -}; - -const Time = struct { - hour: std.math.IntFittingRange(0, 24), - minute: std.math.IntFittingRange(0, 60), - second: std.math.IntFittingRange(0, 60), - - fn toSec(t: Time) i64 { - var sec: i64 = 0; - sec += @as(i64, t.hour) * 60 * 60; - sec += @as(i64, t.minute) * 60; - sec += t.second; - return sec; - } -}; - -fn parseTimeDigits( - text: *const [2]u8, - min: comptime_int, - max: comptime_int, -) !std.math.IntFittingRange(min, max) { - const result = std.fmt.parseInt(std.math.IntFittingRange(min, max), text, 10) catch - return error.InvalidTime; - if (result < min) return error.InvalidTime; - if (result > max) return error.InvalidTime; - return result; -} - -test parseTimeDigits { - const expectEqual = std.testing.expectEqual; - try expectEqual(@as(u8, 0), try parseTimeDigits("00", 0, 99)); - try expectEqual(@as(u8, 99), try parseTimeDigits("99", 0, 99)); - try expectEqual(@as(u8, 42), try parseTimeDigits("42", 0, 99)); - - const expectError = std.testing.expectError; - try expectError(error.InvalidTime, parseTimeDigits("13", 1, 12)); - try expectError(error.InvalidTime, parseTimeDigits("00", 1, 12)); - try expectError(error.InvalidTime, parseTimeDigits("Di", 0, 99)); -} - -fn parseYear4(text: *const [4]u8) !Date.Year { - const result = std.fmt.parseInt(Date.Year, text, 10) catch return error.InvalidYear; - if (result > 9999) return error.InvalidYear; - return result; -} - -test parseYear4 { - const expectEqual = std.testing.expectEqual; - try expectEqual(@as(Date.Year, 0), try parseYear4("0000")); - try expectEqual(@as(Date.Year, 9999), try parseYear4("9999")); - try expectEqual(@as(Date.Year, 1988), try parseYear4("1988")); - - const expectError = std.testing.expectError; - try expectError(error.InvalidYear, parseYear4("999b")); - try expectError(error.InvalidYear, parseYear4("crap")); - try expectError(error.InvalidYear, parseYear4("r:bQ")); -} - -fn parseTime(bytes: *const [6]u8) !Time { - return .{ - .hour = try parseTimeDigits(bytes[0..2], 0, 23), - .minute = try parseTimeDigits(bytes[2..4], 0, 59), - .second = try parseTimeDigits(bytes[4..6], 0, 59), - }; -} diff --git a/src/http/async/tls.zig/rsa/oid.zig b/src/http/async/tls.zig/rsa/oid.zig deleted file mode 100644 index fd360c3f..00000000 --- a/src/http/async/tls.zig/rsa/oid.zig +++ /dev/null @@ -1,132 +0,0 @@ -//! Developed by ITU-U and ISO/IEC for naming objects. Used in DER. -//! -//! This implementation supports any number of `u32` arcs. - -const Arc = u32; -const encoding_base = 128; - -/// Returns encoded length. -pub fn encodeLen(dot_notation: []const u8) !usize { - var split = std.mem.splitScalar(u8, dot_notation, '.'); - if (split.next() == null) return 0; - if (split.next() == null) return 1; - - var res: usize = 1; - while (split.next()) |s| { - const parsed = try std.fmt.parseUnsigned(Arc, s, 10); - const n_bytes = if (parsed == 0) 0 else std.math.log(Arc, encoding_base, parsed); - - res += n_bytes; - res += 1; - } - - return res; -} - -pub const EncodeError = std.fmt.ParseIntError || error{ - MissingPrefix, - BufferTooSmall, -}; - -pub fn encode(dot_notation: []const u8, buf: []u8) EncodeError![]const u8 { - if (buf.len < try encodeLen(dot_notation)) return error.BufferTooSmall; - - var split = std.mem.splitScalar(u8, dot_notation, '.'); - const first_str = split.next() orelse return error.MissingPrefix; - const second_str = split.next() orelse return error.MissingPrefix; - - const first = try std.fmt.parseInt(u8, first_str, 10); - const second = try std.fmt.parseInt(u8, second_str, 10); - - buf[0] = first * 40 + second; - - var i: usize = 1; - while (split.next()) |s| { - var parsed = try std.fmt.parseUnsigned(Arc, s, 10); - const n_bytes = if (parsed == 0) 0 else std.math.log(Arc, encoding_base, parsed); - - for (0..n_bytes) |j| { - const place = std.math.pow(Arc, encoding_base, n_bytes - @as(Arc, @intCast(j))); - const digit: u8 = @intCast(@divFloor(parsed, place)); - - buf[i] = digit | 0x80; - parsed -= digit * place; - - i += 1; - } - buf[i] = @intCast(parsed); - i += 1; - } - - return buf[0..i]; -} - -pub fn decode(encoded: []const u8, writer: anytype) @TypeOf(writer).Error!void { - const first = @divTrunc(encoded[0], 40); - const second = encoded[0] - first * 40; - try writer.print("{d}.{d}", .{ first, second }); - - var i: usize = 1; - while (i != encoded.len) { - const n_bytes: usize = brk: { - var res: usize = 1; - var j: usize = i; - while (encoded[j] & 0x80 != 0) { - res += 1; - j += 1; - } - break :brk res; - }; - - var n: usize = 0; - for (0..n_bytes) |j| { - const place = std.math.pow(usize, encoding_base, n_bytes - j - 1); - n += place * (encoded[i] & 0b01111111); - i += 1; - } - try writer.print(".{d}", .{n}); - } -} - -pub fn encodeComptime(comptime dot_notation: []const u8) [encodeLen(dot_notation) catch unreachable]u8 { - @setEvalBranchQuota(10_000); - var buf: [encodeLen(dot_notation) catch unreachable]u8 = undefined; - _ = encode(dot_notation, &buf) catch unreachable; - return buf; -} - -const std = @import("std"); - -fn testOid(expected_encoded: []const u8, expected_dot_notation: []const u8) !void { - var buf: [256]u8 = undefined; - const encoded = try encode(expected_dot_notation, &buf); - try std.testing.expectEqualSlices(u8, expected_encoded, encoded); - - var stream = std.io.fixedBufferStream(&buf); - try decode(expected_encoded, stream.writer()); - try std.testing.expectEqualStrings(expected_dot_notation, stream.getWritten()); -} - -test "encode and decode" { - // https://learn.microsoft.com/en-us/windows/win32/seccertenroll/about-object-identifier - try testOid( - &[_]u8{ 0x2b, 0x06, 0x01, 0x04, 0x01, 0x82, 0x37, 0x15, 0x14 }, - "1.3.6.1.4.1.311.21.20", - ); - // https://luca.ntop.org/Teaching/Appunti/asn1.html - try testOid(&[_]u8{ 0x2a, 0x86, 0x48, 0x86, 0xf7, 0x0d }, "1.2.840.113549"); - // https://www.sysadmins.lv/blog-en/how-to-encode-object-identifier-to-an-asn1-der-encoded-string.aspx - try testOid(&[_]u8{ 0x2a, 0x86, 0x8d, 0x20 }, "1.2.100000"); - try testOid( - &[_]u8{ 0x2a, 0x86, 0x48, 0x86, 0xf7, 0x0d, 0x01, 0x01, 0x0b }, - "1.2.840.113549.1.1.11", - ); - try testOid(&[_]u8{ 0x2b, 0x65, 0x70 }, "1.3.101.112"); -} - -test encodeComptime { - try std.testing.expectEqual( - [_]u8{ 0x2b, 0x06, 0x01, 0x04, 0x01, 0x82, 0x37, 0x15, 0x14 }, - encodeComptime("1.3.6.1.4.1.311.21.20"), - ); -} diff --git a/src/http/async/tls.zig/rsa/rsa.zig b/src/http/async/tls.zig/rsa/rsa.zig deleted file mode 100644 index 5e5f42fe..00000000 --- a/src/http/async/tls.zig/rsa/rsa.zig +++ /dev/null @@ -1,880 +0,0 @@ -//! RFC8017: Public Key Cryptography Standards #1 v2.2 (PKCS1) -const std = @import("std"); -const der = @import("der.zig"); -const ff = std.crypto.ff; - -pub const max_modulus_bits = 4096; -const max_modulus_len = max_modulus_bits / 8; - -const Modulus = std.crypto.ff.Modulus(max_modulus_bits); -const Fe = Modulus.Fe; - -pub const ValueError = error{ - Modulus, - Exponent, -}; - -pub const PublicKey = struct { - /// `n` - modulus: Modulus, - /// `e` - public_exponent: Fe, - - pub const FromBytesError = ValueError || ff.OverflowError || ff.FieldElementError || ff.InvalidModulusError || error{InsecureBitCount}; - - pub fn fromBytes(mod: []const u8, exp: []const u8) FromBytesError!PublicKey { - const modulus = try Modulus.fromBytes(mod, .big); - if (modulus.bits() <= 512) return error.InsecureBitCount; - const public_exponent = try Fe.fromBytes(modulus, exp, .big); - - if (std.debug.runtime_safety) { - // > the RSA public exponent e is an integer between 3 and n - 1 satisfying - // > GCD(e,\lambda(n)) = 1, where \lambda(n) = LCM(r_1 - 1, ..., r_u - 1) - const e_v = public_exponent.toPrimitive(u32) catch return error.Exponent; - if (!public_exponent.isOdd()) return error.Exponent; - if (e_v < 3) return error.Exponent; - if (modulus.v.compare(public_exponent.v) == .lt) return error.Exponent; - } - - return .{ .modulus = modulus, .public_exponent = public_exponent }; - } - - pub fn fromDer(bytes: []const u8) (der.Parser.Error || FromBytesError)!PublicKey { - var parser = der.Parser{ .bytes = bytes }; - - const seq = try parser.expectSequence(); - defer parser.seek(seq.slice.end); - - const modulus = try parser.expectPrimitive(.integer); - const pub_exp = try parser.expectPrimitive(.integer); - - try parser.expectEnd(seq.slice.end); - try parser.expectEnd(bytes.len); - - return try fromBytes(parser.view(modulus), parser.view(pub_exp)); - } - - /// Deprecated. - /// - /// Encrypt a short message using RSAES-PKCS1-v1_5. - /// The use of this scheme for encrypting an arbitrary message, as opposed to a - /// randomly generated key, is NOT RECOMMENDED. - pub fn encryptPkcsv1_5(pk: PublicKey, msg: []const u8, out: []u8) ![]const u8 { - // align variable names with spec - const k = byteLen(pk.modulus.bits()); - if (out.len < k) return error.BufferTooSmall; - if (msg.len > k - 11) return error.MessageTooLong; - - // EM = 0x00 || 0x02 || PS || 0x00 || M. - var em = out[0..k]; - em[0] = 0; - em[1] = 2; - - const ps = em[2..][0 .. k - msg.len - 3]; - // Section: 7.2.1 - // PS consists of pseudo-randomly generated nonzero octets. - for (ps) |*v| { - v.* = std.crypto.random.uintLessThan(u8, 0xff) + 1; - } - - em[em.len - msg.len - 1] = 0; - @memcpy(em[em.len - msg.len ..][0..msg.len], msg); - - const m = try Fe.fromBytes(pk.modulus, em, .big); - const e = try pk.modulus.powPublic(m, pk.public_exponent); - try e.toBytes(em, .big); - return em; - } - - /// Encrypt a short message using Optimal Asymmetric Encryption Padding (RSAES-OAEP). - pub fn encryptOaep( - pk: PublicKey, - comptime Hash: type, - msg: []const u8, - label: []const u8, - out: []u8, - ) ![]const u8 { - // align variable names with spec - const k = byteLen(pk.modulus.bits()); - if (out.len < k) return error.BufferTooSmall; - - if (msg.len > k - 2 * Hash.digest_length - 2) return error.MessageTooLong; - - // EM = 0x00 || maskedSeed || maskedDB. - var em = out[0..k]; - em[0] = 0; - const seed = em[1..][0..Hash.digest_length]; - std.crypto.random.bytes(seed); - - // DB = lHash || PS || 0x01 || M. - var db = em[1 + seed.len ..]; - const lHash = labelHash(Hash, label); - @memcpy(db[0..lHash.len], &lHash); - @memset(db[lHash.len .. db.len - msg.len - 2], 0); - db[db.len - msg.len - 1] = 1; - @memcpy(db[db.len - msg.len ..], msg); - - var mgf_buf: [max_modulus_len]u8 = undefined; - - const db_mask = mgf1(Hash, seed, mgf_buf[0..db.len]); - for (db, db_mask) |*v, m| v.* ^= m; - - const seed_mask = mgf1(Hash, db, mgf_buf[0..seed.len]); - for (seed, seed_mask) |*v, m| v.* ^= m; - - const m = try Fe.fromBytes(pk.modulus, em, .big); - const e = try pk.modulus.powPublic(m, pk.public_exponent); - try e.toBytes(em, .big); - return em; - } -}; - -pub fn byteLen(bits: usize) usize { - return std.math.divCeil(usize, bits, 8) catch unreachable; -} - -pub const SecretKey = struct { - /// `d` - private_exponent: Fe, - - pub const FromBytesError = ValueError || ff.OverflowError || ff.FieldElementError; - - pub fn fromBytes(n: Modulus, exp: []const u8) FromBytesError!SecretKey { - const d = try Fe.fromBytes(n, exp, .big); - if (std.debug.runtime_safety) { - // > The RSA private exponent d is a positive integer less than n - // > satisfying e * d == 1 (mod \lambda(n)), - if (!d.isOdd()) return error.Exponent; - if (d.v.compare(n.v) != .lt) return error.Exponent; - } - - return .{ .private_exponent = d }; - } -}; - -pub const KeyPair = struct { - public: PublicKey, - secret: SecretKey, - - pub const FromDerError = PublicKey.FromBytesError || SecretKey.FromBytesError || der.Parser.Error || error{ KeyMismatch, InvalidVersion }; - - pub fn fromDer(bytes: []const u8) FromDerError!KeyPair { - var parser = der.Parser{ .bytes = bytes }; - const seq = try parser.expectSequence(); - const version = try parser.expectInt(u8); - - const mod = try parser.expectPrimitive(.integer); - const pub_exp = try parser.expectPrimitive(.integer); - const sec_exp = try parser.expectPrimitive(.integer); - - const public = try PublicKey.fromBytes(parser.view(mod), parser.view(pub_exp)); - const secret = try SecretKey.fromBytes(public.modulus, parser.view(sec_exp)); - - const prime1 = try parser.expectPrimitive(.integer); - const prime2 = try parser.expectPrimitive(.integer); - const exp1 = try parser.expectPrimitive(.integer); - const exp2 = try parser.expectPrimitive(.integer); - const coeff = try parser.expectPrimitive(.integer); - _ = .{ exp1, exp2, coeff }; - - switch (version) { - 0 => {}, - 1 => { - _ = try parser.expectSequenceOf(); - while (!parser.eof()) { - _ = try parser.expectSequence(); - const ri = try parser.expectPrimitive(.integer); - const di = try parser.expectPrimitive(.integer); - const ti = try parser.expectPrimitive(.integer); - _ = .{ ri, di, ti }; - } - }, - else => return error.InvalidVersion, - } - - try parser.expectEnd(seq.slice.end); - try parser.expectEnd(bytes.len); - - if (std.debug.runtime_safety) { - const p = try Fe.fromBytes(public.modulus, parser.view(prime1), .big); - const q = try Fe.fromBytes(public.modulus, parser.view(prime2), .big); - - // check that n = p * q - const expected_zero = public.modulus.mul(p, q); - if (!expected_zero.isZero()) return error.KeyMismatch; - - // TODO: check that d * e is one mod p-1 and mod q-1. Note d and e were bound - // const de = secret.private_exponent.mul(public.public_exponent); - // const one = public.modulus.one(); - - // if (public.modulus.mul(de, p).compare(one) != .eq) return error.KeyMismatch; - // if (public.modulus.mul(de, q).compare(one) != .eq) return error.KeyMismatch; - } - - return .{ .public = public, .secret = secret }; - } - - /// Deprecated. - pub fn signPkcsv1_5(kp: KeyPair, comptime Hash: type, msg: []const u8, out: []u8) !PKCS1v1_5(Hash).Signature { - var st = try signerPkcsv1_5(kp, Hash); - st.update(msg); - return try st.finalize(out); - } - - /// Deprecated. - pub fn signerPkcsv1_5(kp: KeyPair, comptime Hash: type) !PKCS1v1_5(Hash).Signer { - return PKCS1v1_5(Hash).Signer.init(kp); - } - - /// Deprecated. - pub fn decryptPkcsv1_5(kp: KeyPair, ciphertext: []const u8, out: []u8) ![]const u8 { - const k = byteLen(kp.public.modulus.bits()); - if (out.len < k) return error.BufferTooSmall; - - const em = out[0..k]; - - const m = try Fe.fromBytes(kp.public.modulus, ciphertext, .big); - const e = try kp.public.modulus.pow(m, kp.secret.private_exponent); - try e.toBytes(em, .big); - - // Care shall be taken to ensure that an opponent cannot - // distinguish these error conditions, whether by error - // message or timing. - const msg_start = ct.lastIndexOfScalar(em, 0) orelse em.len; - const ps_len = em.len - msg_start; - if (ct.@"or"(em[0] != 0, ct.@"or"(em[1] != 2, ps_len < 8))) { - return error.Inconsistent; - } - - return em[msg_start + 1 ..]; - } - - pub fn signOaep( - kp: KeyPair, - comptime Hash: type, - msg: []const u8, - salt: ?[]const u8, - out: []u8, - ) !Pss(Hash).Signature { - var st = try signerOaep(kp, Hash, salt); - st.update(msg); - return try st.finalize(out); - } - - /// Salt must outlive returned `PSS.Signer`. - pub fn signerOaep(kp: KeyPair, comptime Hash: type, salt: ?[]const u8) !Pss(Hash).Signer { - return Pss(Hash).Signer.init(kp, salt); - } - - pub fn decryptOaep( - kp: KeyPair, - comptime Hash: type, - ciphertext: []const u8, - label: []const u8, - out: []u8, - ) ![]u8 { - // align variable names with spec - const k = byteLen(kp.public.modulus.bits()); - if (out.len < k) return error.BufferTooSmall; - - const mod = try Fe.fromBytes(kp.public.modulus, ciphertext, .big); - const exp = kp.public.modulus.pow(mod, kp.secret.private_exponent) catch unreachable; - const em = out[0..k]; - try exp.toBytes(em, .big); - - const y = em[0]; - const seed = em[1..][0..Hash.digest_length]; - const db = em[1 + Hash.digest_length ..]; - - var mgf_buf: [max_modulus_len]u8 = undefined; - - const seed_mask = mgf1(Hash, db, mgf_buf[0..seed.len]); - for (seed, seed_mask) |*v, m| v.* ^= m; - - const db_mask = mgf1(Hash, seed, mgf_buf[0..db.len]); - for (db, db_mask) |*v, m| v.* ^= m; - - const expected_hash = labelHash(Hash, label); - const actual_hash = db[0..expected_hash.len]; - - // Care shall be taken to ensure that an opponent cannot - // distinguish these error conditions, whether by error - // message or timing. - const msg_start = ct.indexOfScalarPos(em, expected_hash.len + 1, 1) orelse 0; - if (ct.@"or"(y != 0, ct.@"or"(msg_start == 0, !ct.memEql(&expected_hash, actual_hash)))) { - return error.Inconsistent; - } - - return em[msg_start + 1 ..]; - } - - /// Encrypt short plaintext with secret key. - pub fn encrypt(kp: KeyPair, plaintext: []const u8, out: []u8) !void { - const n = kp.public.modulus; - const k = byteLen(n.bits()); - if (plaintext.len > k) return error.MessageTooLong; - - const msg_as_int = try Fe.fromBytes(n, plaintext, .big); - const enc_as_int = try n.pow(msg_as_int, kp.secret.private_exponent); - try enc_as_int.toBytes(out, .big); - } -}; - -/// Deprecated. -/// -/// Signature Scheme with Appendix v1.5 (RSASSA-PKCS1-v1_5) -/// -/// This standard has been superceded by PSS which is formally proven secure -/// and has fewer footguns. -pub fn PKCS1v1_5(comptime Hash: type) type { - return struct { - const PkcsT = @This(); - pub const Signature = struct { - bytes: []const u8, - - const Self = @This(); - - pub fn verifier(self: Self, public_key: PublicKey) !Verifier { - return Verifier.init(self, public_key); - } - - pub fn verify(self: Self, msg: []const u8, public_key: PublicKey) !void { - var st = Verifier.init(self, public_key); - st.update(msg); - return st.verify(); - } - }; - - pub const Signer = struct { - h: Hash, - key_pair: KeyPair, - - fn init(key_pair: KeyPair) Signer { - return .{ - .h = Hash.init(.{}), - .key_pair = key_pair, - }; - } - - pub fn update(self: *Signer, data: []const u8) void { - self.h.update(data); - } - - pub fn finalize(self: *Signer, out: []u8) !PkcsT.Signature { - const k = byteLen(self.key_pair.public.modulus.bits()); - if (out.len < k) return error.BufferTooSmall; - - var hash: [Hash.digest_length]u8 = undefined; - self.h.final(&hash); - - const em = try emsaEncode(hash, out[0..k]); - try self.key_pair.encrypt(em, em); - return .{ .bytes = em }; - } - }; - - pub const Verifier = struct { - h: Hash, - sig: PkcsT.Signature, - public_key: PublicKey, - - fn init(sig: PkcsT.Signature, public_key: PublicKey) Verifier { - return Verifier{ - .h = Hash.init(.{}), - .sig = sig, - .public_key = public_key, - }; - } - - pub fn update(self: *Verifier, data: []const u8) void { - self.h.update(data); - } - - pub fn verify(self: *Verifier) !void { - const pk = self.public_key; - const s = try Fe.fromBytes(pk.modulus, self.sig.bytes, .big); - const emm = try pk.modulus.powPublic(s, pk.public_exponent); - - var em_buf: [max_modulus_len]u8 = undefined; - const em = em_buf[0..byteLen(pk.modulus.bits())]; - try emm.toBytes(em, .big); - - var hash: [Hash.digest_length]u8 = undefined; - self.h.final(&hash); - - // TODO: compare hash values instead of emsa values - const expected = try emsaEncode(hash, em); - - if (!std.mem.eql(u8, expected, em)) return error.Inconsistent; - } - }; - - /// PKCS Encrypted Message Signature Appendix - fn emsaEncode(hash: [Hash.digest_length]u8, out: []u8) ![]u8 { - const digest_header = comptime digestHeader(); - const tLen = digest_header.len + Hash.digest_length; - const emLen = out.len; - if (emLen < tLen + 11) return error.ModulusTooShort; - if (out.len < emLen) return error.BufferTooSmall; - - var res = out[0..emLen]; - res[0] = 0; - res[1] = 1; - const padding_len = emLen - tLen - 3; - @memset(res[2..][0..padding_len], 0xff); - res[2 + padding_len] = 0; - @memcpy(res[2 + padding_len + 1 ..][0..digest_header.len], digest_header); - @memcpy(res[res.len - hash.len ..], &hash); - - return res; - } - - /// DER encoded header. Sequence of digest algo + digest. - /// TODO: use a DER encoder instead - fn digestHeader() []const u8 { - const sha2 = std.crypto.hash.sha2; - // Section 9.2 Notes 1. - return switch (Hash) { - std.crypto.hash.Sha1 => &hexToBytes( - \\30 21 30 09 06 05 2b 0e 03 02 1a 05 00 04 14 - ), - sha2.Sha224 => &hexToBytes( - \\30 2d 30 0d 06 09 60 86 48 01 65 03 04 02 04 - \\05 00 04 1c - ), - sha2.Sha256 => &hexToBytes( - \\30 31 30 0d 06 09 60 86 48 01 65 03 04 02 01 05 00 - \\04 20 - ), - sha2.Sha384 => &hexToBytes( - \\30 41 30 0d 06 09 60 86 48 01 65 03 04 02 02 05 00 - \\04 30 - ), - sha2.Sha512 => &hexToBytes( - \\30 51 30 0d 06 09 60 86 48 01 65 03 04 02 03 05 00 - \\04 40 - ), - // sha2.Sha512224 => &hexToBytes( - // \\30 2d 30 0d 06 09 60 86 48 01 65 03 04 02 05 - // \\05 00 04 1c - // ), - // sha2.Sha512256 => &hexToBytes( - // \\30 31 30 0d 06 09 60 86 48 01 65 03 04 02 06 - // \\05 00 04 20 - // ), - else => @compileError("unknown Hash " ++ @typeName(Hash)), - }; - } - }; -} - -/// Probabilistic Signature Scheme (RSASSA-PSS) -pub fn Pss(comptime Hash: type) type { - // RFC 4055 S3.1 - const default_salt_len = Hash.digest_length; - return struct { - pub const Signature = struct { - bytes: []const u8, - - const Self = @This(); - - pub fn verifier(self: Self, public_key: PublicKey) !Verifier { - return Verifier.init(self, public_key); - } - - pub fn verify(self: Self, msg: []const u8, public_key: PublicKey, salt_len: ?usize) !void { - var st = Verifier.init(self, public_key, salt_len orelse default_salt_len); - st.update(msg); - return st.verify(); - } - }; - - const PssT = @This(); - - pub const Signer = struct { - h: Hash, - key_pair: KeyPair, - salt: ?[]const u8, - - fn init(key_pair: KeyPair, salt: ?[]const u8) Signer { - return .{ - .h = Hash.init(.{}), - .key_pair = key_pair, - .salt = salt, - }; - } - - pub fn update(self: *Signer, data: []const u8) void { - self.h.update(data); - } - - pub fn finalize(self: *Signer, out: []u8) !PssT.Signature { - var hashed: [Hash.digest_length]u8 = undefined; - self.h.final(&hashed); - - const salt = if (self.salt) |s| s else brk: { - var res: [default_salt_len]u8 = undefined; - std.crypto.random.bytes(&res); - break :brk &res; - }; - - const em_bits = self.key_pair.public.modulus.bits() - 1; - const em = try emsaEncode(hashed, salt, em_bits, out); - try self.key_pair.encrypt(em, em); - return .{ .bytes = em }; - } - }; - - pub const Verifier = struct { - h: Hash, - sig: PssT.Signature, - public_key: PublicKey, - salt_len: usize, - - fn init(sig: PssT.Signature, public_key: PublicKey, salt_len: usize) Verifier { - return Verifier{ - .h = Hash.init(.{}), - .sig = sig, - .public_key = public_key, - .salt_len = salt_len, - }; - } - - pub fn update(self: *Verifier, data: []const u8) void { - self.h.update(data); - } - - pub fn verify(self: *Verifier) !void { - const pk = self.public_key; - const s = try Fe.fromBytes(pk.modulus, self.sig.bytes, .big); - const emm = try pk.modulus.powPublic(s, pk.public_exponent); - - var em_buf: [max_modulus_len]u8 = undefined; - const em_bits = pk.modulus.bits() - 1; - const em_len = std.math.divCeil(usize, em_bits, 8) catch unreachable; - var em = em_buf[0..em_len]; - try emm.toBytes(em, .big); - - if (em.len < Hash.digest_length + self.salt_len + 2) return error.Inconsistent; - if (em[em.len - 1] != 0xbc) return error.Inconsistent; - - const db = em[0 .. em.len - Hash.digest_length - 1]; - if (@clz(db[0]) < em.len * 8 - em_bits) return error.Inconsistent; - - const expected_hash = em[db.len..][0..Hash.digest_length]; - var mgf_buf: [max_modulus_len]u8 = undefined; - const db_mask = mgf1(Hash, expected_hash, mgf_buf[0..db.len]); - for (db, db_mask) |*v, m| v.* ^= m; - - for (1..db.len - self.salt_len - 1) |i| { - if (db[i] != 0) return error.Inconsistent; - } - if (db[db.len - self.salt_len - 1] != 1) return error.Inconsistent; - const salt = db[db.len - self.salt_len ..]; - var mp_buf: [max_modulus_len]u8 = undefined; - var mp = mp_buf[0 .. 8 + Hash.digest_length + self.salt_len]; - @memset(mp[0..8], 0); - self.h.final(mp[8..][0..Hash.digest_length]); - @memcpy(mp[8 + Hash.digest_length ..][0..salt.len], salt); - - var actual_hash: [Hash.digest_length]u8 = undefined; - Hash.hash(mp, &actual_hash, .{}); - - if (!std.mem.eql(u8, expected_hash, &actual_hash)) return error.Inconsistent; - } - }; - - /// PSS Encrypted Message Signature Appendix - fn emsaEncode(msg_hash: [Hash.digest_length]u8, salt: []const u8, em_bits: usize, out: []u8) ![]u8 { - const em_len = std.math.divCeil(usize, em_bits, 8) catch unreachable; - - if (em_len < Hash.digest_length + salt.len + 2) return error.Encoding; - - // EM = maskedDB || H || 0xbc - var em = out[0..em_len]; - em[em.len - 1] = 0xbc; - - var mp_buf: [max_modulus_len]u8 = undefined; - // M' = (0x)00 00 00 00 00 00 00 00 || mHash || salt; - const mp = mp_buf[0 .. 8 + Hash.digest_length + salt.len]; - @memset(mp[0..8], 0); - @memcpy(mp[8..][0..Hash.digest_length], &msg_hash); - @memcpy(mp[8 + Hash.digest_length ..][0..salt.len], salt); - - // H = Hash(M') - const hash = em[em.len - 1 - Hash.digest_length ..][0..Hash.digest_length]; - Hash.hash(mp, hash, .{}); - - // DB = PS || 0x01 || salt - var db = em[0 .. em_len - Hash.digest_length - 1]; - @memset(db[0 .. db.len - salt.len - 1], 0); - db[db.len - salt.len - 1] = 1; - @memcpy(db[db.len - salt.len ..], salt); - - var mgf_buf: [max_modulus_len]u8 = undefined; - const db_mask = mgf1(Hash, hash, mgf_buf[0..db.len]); - for (db, db_mask) |*v, m| v.* ^= m; - - // Set the leftmost 8emLen - emBits bits of the leftmost octet - // in maskedDB to zero. - const shift = std.math.comptimeMod(8 * em_len - em_bits, 8); - const mask = @as(u8, 0xff) >> shift; - db[0] &= mask; - - return em; - } - }; -} - -/// Mask generation function. Currently the only one defined. -fn mgf1(comptime Hash: type, seed: []const u8, out: []u8) []u8 { - var c: [@sizeOf(u32)]u8 = undefined; - var tmp: [Hash.digest_length]u8 = undefined; - - var i: usize = 0; - var counter: u32 = 0; - while (i < out.len) : (counter += 1) { - var hasher = Hash.init(.{}); - hasher.update(seed); - std.mem.writeInt(u32, &c, counter, .big); - hasher.update(&c); - - const left = out.len - i; - if (left >= Hash.digest_length) { - // optimization: write straight to `out` - hasher.final(out[i..][0..Hash.digest_length]); - i += Hash.digest_length; - } else { - hasher.final(&tmp); - @memcpy(out[i..][0..left], tmp[0..left]); - i += left; - } - } - - return out; -} - -test mgf1 { - const Hash = std.crypto.hash.sha2.Sha256; - var out: [Hash.digest_length * 2 + 1]u8 = undefined; - try std.testing.expectEqualSlices( - u8, - &hexToBytes( - \\ed 1b 84 6b b9 26 39 00 c8 17 82 ad 08 eb 17 01 - \\fa 8c 72 21 c6 57 63 77 31 7f 5c e8 09 89 9f - ), - mgf1(Hash, "asdf", out[0 .. Hash.digest_length - 1]), - ); - try std.testing.expectEqualSlices( - u8, - &hexToBytes( - \\ed 1b 84 6b b9 26 39 00 c8 17 82 ad 08 eb 17 01 - \\fa 8c 72 21 c6 57 63 77 31 7f 5c e8 09 89 9f 5a - \\22 F2 80 D5 28 08 F4 93 83 76 00 DE 09 E4 EC 92 - \\4A 2C 7C EF 0D F7 7B BE 8F 7F 12 CB 8F 33 A6 65 - \\AB - ), - mgf1(Hash, "asdf", &out), - ); -} - -/// For OAEP. -inline fn labelHash(comptime Hash: type, label: []const u8) [Hash.digest_length]u8 { - if (label.len == 0) { - // magic constants from NIST - const sha2 = std.crypto.hash.sha2; - switch (Hash) { - std.crypto.hash.Sha1 => return hexToBytes( - \\da39a3ee 5e6b4b0d 3255bfef 95601890 - \\afd80709 - ), - sha2.Sha256 => return hexToBytes( - \\e3b0c442 98fc1c14 9afbf4c8 996fb924 - \\27ae41e4 649b934c a495991b 7852b855 - ), - sha2.Sha384 => return hexToBytes( - \\38b060a7 51ac9638 4cd9327e b1b1e36a - \\21fdb711 14be0743 4c0cc7bf 63f6e1da - \\274edebf e76f65fb d51ad2f1 4898b95b - ), - sha2.Sha512 => return hexToBytes( - \\cf83e135 7eefb8bd f1542850 d66d8007 - \\d620e405 0b5715dc 83f4a921 d36ce9ce - \\47d0d13c 5d85f2b0 ff8318d2 877eec2f - \\63b931bd 47417a81 a538327a f927da3e - ), - // just use the empty hash... - else => {}, - } - } - var res: [Hash.digest_length]u8 = undefined; - Hash.hash(label, &res, .{}); - return res; -} - -const ct = if (std.options.side_channels_mitigations == .none) ct_unprotected else ct_protected; - -const ct_unprotected = struct { - fn lastIndexOfScalar(slice: []const u8, value: u8) ?usize { - return std.mem.lastIndexOfScalar(u8, slice, value); - } - - fn indexOfScalarPos(slice: []const u8, start_index: usize, value: u8) ?usize { - return std.mem.indexOfScalarPos(u8, slice, start_index, value); - } - - fn memEql(a: []const u8, b: []const u8) bool { - return std.mem.eql(u8, a, b); - } - - fn @"and"(a: bool, b: bool) bool { - return a and b; - } - - fn @"or"(a: bool, b: bool) bool { - return a or b; - } -}; - -const ct_protected = struct { - fn lastIndexOfScalar(slice: []const u8, value: u8) ?usize { - var res: ?usize = null; - var i: usize = slice.len; - while (i != 0) { - i -= 1; - if (@intFromBool(res == null) & @intFromBool(slice[i] == value) == 1) res = i; - } - return res; - } - - fn indexOfScalarPos(slice: []const u8, start_index: usize, value: u8) ?usize { - var res: ?usize = null; - for (slice[start_index..], start_index..) |c, j| { - if (c == value) res = j; - } - return res; - } - - fn memEql(a: []const u8, b: []const u8) bool { - var res: u1 = 1; - for (a, b) |a_elem, b_elem| { - res &= @intFromBool(a_elem == b_elem); - } - return res == 1; - } - - fn @"and"(a: bool, b: bool) bool { - return (@intFromBool(a) & @intFromBool(b)) == 1; - } - - fn @"or"(a: bool, b: bool) bool { - return (@intFromBool(a) | @intFromBool(b)) == 1; - } -}; - -test ct { - const c = ct_unprotected; - try std.testing.expectEqual(true, c.@"or"(true, false)); - try std.testing.expectEqual(true, c.@"and"(true, true)); - try std.testing.expectEqual(true, c.memEql("Asdf", "Asdf")); - try std.testing.expectEqual(false, c.memEql("asdf", "Asdf")); - try std.testing.expectEqual(3, c.indexOfScalarPos("asdff", 1, 'f')); - try std.testing.expectEqual(4, c.lastIndexOfScalar("asdff", 'f')); -} - -fn removeNonHex(comptime hex: []const u8) []const u8 { - var res: [hex.len]u8 = undefined; - var i: usize = 0; - for (hex) |c| { - if (std.ascii.isHex(c)) { - res[i] = c; - i += 1; - } - } - return res[0..i]; -} - -/// For readable copy/pasting from hex viewers. -fn hexToBytes(comptime hex: []const u8) [removeNonHex(hex).len / 2]u8 { - const hex2 = comptime removeNonHex(hex); - comptime var res: [hex2.len / 2]u8 = undefined; - _ = comptime std.fmt.hexToBytes(&res, hex2) catch unreachable; - return res; -} - -test hexToBytes { - const hex = - \\e3b0c442 98fc1c14 9afbf4c8 996fb924 - \\27ae41e4 649b934c a495991b 7852b855 - ; - try std.testing.expectEqual( - [_]u8{ - 0xe3, 0xb0, 0xc4, 0x42, 0x98, 0xfc, 0x1c, 0x14, - 0x9a, 0xfb, 0xf4, 0xc8, 0x99, 0x6f, 0xb9, 0x24, - 0x27, 0xae, 0x41, 0xe4, 0x64, 0x9b, 0x93, 0x4c, - 0xa4, 0x95, 0x99, 0x1b, 0x78, 0x52, 0xb8, 0x55, - }, - hexToBytes(hex), - ); -} - -const TestHash = std.crypto.hash.sha2.Sha256; -fn testKeypair() !KeyPair { - const keypair_bytes = @embedFile("testdata/id_rsa.der"); - const kp = try KeyPair.fromDer(keypair_bytes); - try std.testing.expectEqual(2048, kp.public.modulus.bits()); - return kp; -} - -test "rsa PKCS1-v1_5 encrypt and decrypt" { - const kp = try testKeypair(); - - const msg = "rsa PKCS1-v1_5 encrypt and decrypt"; - var out: [max_modulus_len]u8 = undefined; - const enc = try kp.public.encryptPkcsv1_5(msg, &out); - - var out2: [max_modulus_len]u8 = undefined; - const dec = try kp.decryptPkcsv1_5(enc, &out2); - - try std.testing.expectEqualSlices(u8, msg, dec); -} - -test "rsa OAEP encrypt and decrypt" { - const kp = try testKeypair(); - - const msg = "rsa OAEP encrypt and decrypt"; - const label = ""; - var out: [max_modulus_len]u8 = undefined; - const enc = try kp.public.encryptOaep(TestHash, msg, label, &out); - - var out2: [max_modulus_len]u8 = undefined; - const dec = try kp.decryptOaep(TestHash, enc, label, &out2); - - try std.testing.expectEqualSlices(u8, msg, dec); -} - -test "rsa PKCS1-v1_5 signature" { - const kp = try testKeypair(); - - const msg = "rsa PKCS1-v1_5 signature"; - var out: [max_modulus_len]u8 = undefined; - - const signature = try kp.signPkcsv1_5(TestHash, msg, &out); - try signature.verify(msg, kp.public); -} - -test "rsa PSS signature" { - const kp = try testKeypair(); - - const msg = "rsa PSS signature"; - var out: [max_modulus_len]u8 = undefined; - - const salts = [_][]const u8{ "asdf", "" }; - for (salts) |salt| { - const signature = try kp.signOaep(TestHash, msg, salt, &out); - try signature.verify(msg, kp.public, salt.len); - } - - const signature = try kp.signOaep(TestHash, msg, null, &out); // random salt - try signature.verify(msg, kp.public, null); -} diff --git a/src/http/async/tls.zig/rsa/testdata/id_rsa.der b/src/http/async/tls.zig/rsa/testdata/id_rsa.der deleted file mode 100644 index 9e4f1334d16264ca8accea1d9f7212da6a14554a..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 1191 zcmV;Y1X%kpf&`-i0RRGm0RaHSfHnrY=SOP@lmzUjwvhxs_maFB?)!ao*QgBu9(zkV zO6Cvfz;XO@=K@R(y!5@%9XV^da7IcK=}P!L^Wh0uRC~!)`#~+Ec2W`H^W1lAs#7;^ z$~x@6!>YGCG1Y9gQk;O8yvg7w7~%`}_@Fxd7X(nA&Uw9`Iq~Xg>_?X_gAcXJmEM)1 z<^&?u?!HoaRH5g;iiY+^Z4I9ml^RU`K|Im+mAD~hY(e&`u&UtVtGUCd4b^x|AUT0|5X50)hbmVS71J9cG^gdO7qra{M0j{KnM;d)X4|Dh$$5 zpuK)<(|%N=kJIS5RL4bJb#zDd;>Vn{)X3fy;mX=(OjWQ=^8)r|>p{8`>I2BL7G&f4 zTJ9uTTbuWvdjpoOPq1k)g+g{=Rb*1Z;Uaedc>>;gWfTSv__%efoKYgOC)y9G5#U>) zX4nJos0`bTLHEODsw7TAd zZaNEJYbv|;_HfEL^^x7ZFckyx4;#r{Io*=6Z9}3uriSet?;rJ}1Dtq0BZU5Hi zQ2$bCHjw`G*=@EQ{sO+TQM3ZngpvmKZIi)B<^5nTWPodhAR8k0Hx zRF$oWcryh&arHHJi6s7g-`0Dx>>O0zV}G!2mkbOQ-yu{=4O|BV%p8Sr{rQ_^#~t`> zf?9e9bV^_}9}jMUVFxw}5)x&pGh02vD*60E<8!x0uv>GxJM5R&>2$QQ4Tt~%+};}5 z0)c=)%%UKzS~2V}a|gnxjvGVeIuV_4^!R_w=;DyFwofV-Q6uIo%d2X5fjbKeVWjqAX@$;2c6-1ooGbC%q?qgq z&ubOB)zcw4OL0tehSLK07lV!!1Rl;&ya=HJfq*-h@&u~`pTO^q@QyjGbrp)>GZQdH zcz7kX1Vj5y4#LIWXWtXaiHK{S62>>_jGa;(zih-YEkAhHYt}#X2GJeI?Z^wED){4A z(W%+sitXpW`9b2KDu2r)_+LLHh&*l-n*lh1Z!PcJ-#DyiQth7%Jy~5wD5a_T5y=1u F(f{FrO=AE6 diff --git a/src/http/async/tls.zig/testdata/ec_prime256v1_private_key.pem b/src/http/async/tls.zig/testdata/ec_prime256v1_private_key.pem deleted file mode 100644 index 67ebf388..00000000 --- a/src/http/async/tls.zig/testdata/ec_prime256v1_private_key.pem +++ /dev/null @@ -1,5 +0,0 @@ ------BEGIN EC PRIVATE KEY----- -MHcCAQEEINJSRKv8kSKEzLHptfAlg+LGh4/pHHlq0XLf30Q9pcztoAoGCCqGSM49 -AwEHoUQDQgAEJpmLyp8aGCgyMcFIJaIq/+4V1K6nPpeoih3bT2npeplF9eyXj7rm -8eW9Ua6VLhq71mqtMC+YLm+IkORBVq1cuA== ------END EC PRIVATE KEY----- diff --git a/src/http/async/tls.zig/testdata/ec_private_key.pem b/src/http/async/tls.zig/testdata/ec_private_key.pem deleted file mode 100644 index 95048aaa..00000000 --- a/src/http/async/tls.zig/testdata/ec_private_key.pem +++ /dev/null @@ -1,6 +0,0 @@ ------BEGIN PRIVATE KEY----- -MIG2AgEAMBAGByqGSM49AgEGBSuBBAAiBIGeMIGbAgEBBDAQNT3KGxUdBqpxuO/z -GSJDePMgmB6xLytkfnHQMCqQquXrmcOQZT3BJhm+PwggmwGhZANiAATKxBc6kfqA -piA+Z0rIjVwaZaBNGnP4UZ5TqVewQ/dP9/BQCca2SJpsXauGLcUPmK4sKFQxGe6d -fzq9O50lo7qHEOIpwDBdhRp+oqB6sN2hMtCPbp6eyzsUlm3FUyhN9D0= ------END PRIVATE KEY----- diff --git a/src/http/async/tls.zig/testdata/ec_secp384r1_private_key.pem b/src/http/async/tls.zig/testdata/ec_secp384r1_private_key.pem deleted file mode 100644 index 62eac9ee..00000000 --- a/src/http/async/tls.zig/testdata/ec_secp384r1_private_key.pem +++ /dev/null @@ -1,6 +0,0 @@ ------BEGIN EC PRIVATE KEY----- -MIGkAgEBBDDubYpeDdOwxksyQIDiOt6LHt3ikts2HNuR6rqhBg1CLdmp3AVDKfF4 -fPkIr8UDH22gBwYFK4EEACKhZANiAARcVFUVv3bIHS6BEfLt98rtps7XP1y26m2n -v5x/5ecbDH2p7AXBYerJERKFi7ZFE1DSrSAj+KK8otjdEG44ZA2Mtl5AHwDVrKde -RgtavVoreHhLN80jJOun8JnFXQjdNsA= ------END EC PRIVATE KEY----- diff --git a/src/http/async/tls.zig/testdata/ec_secp521r1_private_key.pem b/src/http/async/tls.zig/testdata/ec_secp521r1_private_key.pem deleted file mode 100644 index 5b7f9321..00000000 --- a/src/http/async/tls.zig/testdata/ec_secp521r1_private_key.pem +++ /dev/null @@ -1,7 +0,0 @@ ------BEGIN EC PRIVATE KEY----- -MIHcAgEBBEIB8C9axyQY6mgjjC6htLjc8hGylrDsh4BCv9669JaDj5vbxmCnTNlg -OuS6C9+uJNMbwm6CoIjB7RcgDTrxxX7oCyegBwYFK4EEACOhgYkDgYYABABAT5Q8 -aOj9U0iuJE5tXfKnYTgPuvD6keHZAGJ5veM9uR6jr3BhfGubD6bnlD+cIBQzYWo0 -y/BNMzCRJ55PDCNU5gGLw+vkwhJ1lGF5OS6l2oG5WN3fe6cYo+uJD7+PB3WYNIuX -Ls0oidsEM0Q4WLblQOEP6VLGf4qTcZyhoFWYfkjWiw== ------END EC PRIVATE KEY----- diff --git a/src/http/async/tls.zig/testdata/google.com/client_random b/src/http/async/tls.zig/testdata/google.com/client_random deleted file mode 100644 index e817c906..00000000 --- a/src/http/async/tls.zig/testdata/google.com/client_random +++ /dev/null @@ -1 +0,0 @@ -'”’ßqp0x­0)ì©–Ã~Ì+Œ`‡¬tY4•©D_ \ No newline at end of file diff --git a/src/http/async/tls.zig/testdata/google.com/server_hello b/src/http/async/tls.zig/testdata/google.com/server_hello deleted file mode 100644 index 57a807650723010f00d7b371701d3004a55bc685..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 7158 zcmb_hc|4Tu*S}{W8M0;#V~vsN9%~CJm8FoSlBH$LW-?~R%rGg1A+ogED)ppN+EH4y zsvc5NPy0`0i54npQHi|QJ!Zy`=Y2o#=kxOi_xGH0o$K7^x~}iJ&P@k{fkMy_6pX>p zorU$GtTg;h4}3p%T1cZp<_Yf6$8lfgEC})U3Yvq$XWpA1O~lMwl}*ZPGQWH*Bvy%1 zcYsxO+aM+TMt8uL#aAao5N-ekp}-#qje>MA7=v&eWDo)wEQHkN!y+{=STt^OF$Rr7 zqt!!v6l0IH*N@k1b7kp%(>{%wRkM zl(qD|I2;CxhF{2w;|uV?G+sQDLgV9oeP@%jU=73uqowS%Fc?337M?WQ0XEiEwReOa zNzSkXX^N9Wm>9aiQ^n9e4Av!$hqc7RR8$B=hS)ig!ij-JC^Pw(Pzn%6gi?cmTp2Aw zp`h)NwV|O4*X0FvH-8mgC|76qJvi8zw|Zyd0gKtY$^PnvCEEwKtPuWM;<5nUY;HZT zGQ+>+O?7>X+S=s)EIQuX)%D|^oGlNdPVCdYn30@h@VeReqxO%WgPjWfQ4y~ea9u2n z{V!J@iyl?^<^i^da(93B%G*a>KSflRDPAW~SIII8?8Js{(8gj7B!jz<~h4``@l z0U2yhZk%A41~;^K$s#$x5{(er0P5xh0GODqR2!xO@PD=w$>lAAuA*(U^vW5D#HmuD&mp7 z4BV&^i1JW_9K+@EI4fxs0oc$C!~(q6pF(*%NkGHxBtqT6H?$+O$*csBaap$XrLXy? zN-_rILqh-349mv+snQsnCmv{|*0S}G7fS4q-SLN=ym0)WKMq&GjA5>(No{qdQIAw; zQaN#vEEokxz56748fI^ZbV#n8z7I36vA z#ui9DgEjmY4UQ2nAW6-@*^-Bd&VIN~266T=RB_=#dwcGASTdunl8q z0v_O1d8zVX7R1zwurY;80Ywu@j94_NwiiDm{735I)g;nrCQ_R~ zs9^m@iWCySb%!PzgDe2@U?9sjiwTYlUw-urLEWLTI9wV}mVRk&*rEwCC4UsHp5(eD z^*%JtKsJs@iQ;i&=NVZ_Js~X;xt|SAIur|AWJKXm0kcH{CLTo%nTv>E0Ry`tCMH-X7+g7vCl}kL+e$8*6dvGr*^P7XRyuh#b=E=w zJa^S{&uDP{V}U*cJLCe1fe#yn;N)75EcV@!E@23&|$B}aO!tz+!;b0Q;+cDw|#T;$?sny$yu2% z20Ax>=||(WC`IjC7cHCT-x1s4v8m|fK)t1DavR0Df%2Ed(TeejU|#R?SJD4D=;f%^ zz7pNL8}6JCgu1=>#iW#HRacHqeVWFX%O0+rMmy5);_Sn-92(YTIp+t={KBT)phwE4D^Md z2SY*d6|#q6tWsWMjf?-s>-!znIvV0Dvsd}LNgM|0}I?$GN4q}WRO z`bm+tDu>(h=TG&2WtjM^y^PrP%6Z#6*VCG%+0}cRtS{bn;{UyCP0{_UGdz z>C|uK7~d+i$X#_XulVQ_ZJQ-gH3PP6;(`aS_g!z(2~*x%)&6Glx&_m4p{EX86_(<& z)`hQmk-&I&QD?TNDJH7=$(?IsTLnvCqwn`pA1pswGH*hyvXjv1?*+MgHX1v$*O}u< zGiUPp6Pg1!cJJEPymZC>jt}NRW#;eaA>q+jovp@0DVNtV6XpwE3FM{g3SwBX&ll@ z2XQz8d=N?8OT+LKk^>BW9PN=_8g$ZPguwq=&;Eb4-rqJfy@|{*cKPAru_cL>aqkyR zWZD1rc67-=%)2o2)*U?)-fhF|&`KPuoYm{`c-dl$BdZjM8-!`-$1uFg(*jOjT;cgWy=;Hd+M8x3xcT`rk))9sM zcF%EVqP0|V^lfzuwx9o$-BUcV>ddQ+ftxR#s41$*l(yceWIz&$!&`4u`e0veiPeF% zrF||2C3`}P7k01-SArEmFAl!c$hG35cG~i-cw3pRRjh%n_r&#+gXojGK4ogf&VZw& z{y{4q0f)g>QVd#ggMVwqc_bK?*`d=d{`XFtM-n&Uj)R1pGDy||us=Lyh*SYZ6JTp;2xyef1jH8~d^V7AGx9Zn&lGTY36te$UP}1r z5ho2J0?Eu=dki0;;(B@Kp^$vT$`idu5Ab3uo%8#fW@=!3t39==v@MQ)3WzmpJbr_n z)X=?>Q=4k|v_5lr1}o#-VU_(wdn5C5^P9p?(DiZXg1`Dszfot-Sbf{mXSrv~tcoYw z{b-I*vVr@8J**@jC?$AS$u?7w$e?@woNz*brk3BeVCOF$N@WXsHwl9csZ%EEj=i+x zmSN&lyYMgd4?h$(YGba94Yn|;R>3B0xB5)(4i4(~$T6hbx{u#mHPBP^DrF~1F*ei5 z+TumarA~dj7Vq0`p{aqDg1>hs{#}tCUUtXv#)CT#Q#A6SyBR)9isZYOTbE36oTI6&V27?+`w-;Unur1`f5AQL3n+>M|HPQ ztB;?Y*J)Da;$Dv?(98Ko_6O`)_T`(JZca@I{!52Fjc$pX~_fC5xWKsk>h8Nxg>p9 z7vWxIEp;yzCz8wxW-{1#cDx;*Oi~vgb}UY7jF^&V&mtMZ`h!HRDQ7HEO8DBDG%>QZIQ}MGt&8Xt%&BB*m=Rfa%cTqEQS?Q6Wqd&|i-Y%)w zMQMDpW8Ej^Wt~w6zy2NOrRt~d@XEhJYk#dCYySJbv0IvUQ*A8G$Ly{A$Bnq6x#6F& z8C7?^KlxQxuilVVb)g|t|8d#o@-gLqk-A0Kt~L=5#<29qoqoE`%3v@{>cyE?aS@k# zGEOGeG@@MzQZ`uWYLpPt$@w;}%Ohncqh z!D{8DkEW|uWbMPeBptsa?BJ}K@wjQLnda-uy-)4?)G6A5yq8NK>*QXtiE`+gMt?sh z%c@E1NeFG##Fxu_w4r@a(#}JsXAW27mmfHiX>N#My)pTcNYE8JIZUl}oOiS3RXUfK7SvK3cyE@f~1s5AfV z^uwoI-dvwkvhFGCNsZ^XZAEV$nO>%E)6ZMPdT4V}^%_0ARp{ew{bBbL0|#F4DUnk{ zWcDqK>;q?~s#L8b%(@cP(6_JPOtr`z{C^2GNd-b#V<02^cYnS{waq_!33G9uvYjru zEO-IYI+>@gPvbtx_xmC6Rl=emW5ZwqYagSg8a;cPzv6O@(7QZa^(DuWKhV?T>t8R4 z$lZ3Nc*glt4qxA6fTv#IX1w(%BQd1k&iGK;x7PEe zTDux_pzXi>4mm7$KIED>p*XAVqtC79)*Dy1nKcwVc`~)HCuzRrn=BZ?s zAKfdyOvQRkym~#=tsM6#!lCNX(jEnO?1lob&!rn2M7s>`B2}, ", .{b}); - // if (i % 16 == 0) - // std.debug.print("\n", .{}); - // } - // std.debug.print("}};\n", .{}); - - std.debug.print("const {s} = \"", .{var_name}); - const charset = "0123456789abcdef"; - for (buf) |b| { - const x = charset[b >> 4]; - const y = charset[b & 15]; - std.debug.print("{c}{c} ", .{ x, y }); - } - std.debug.print("\"\n", .{}); -} - -const random_instance = std.Random{ .ptr = undefined, .fillFn = randomFillFn }; -var random_seed: u8 = 0; - -pub fn randomFillFn(_: *anyopaque, buf: []u8) void { - for (buf) |*v| { - v.* = random_seed; - random_seed +%= 1; - } -} - -pub fn random(seed: u8) std.Random { - random_seed = seed; - return random_instance; -} - -// Fill buf with 0,1,..ff,0,... -pub fn fill(buf: []u8) void { - fillFrom(buf, 0); -} - -pub fn fillFrom(buf: []u8, start: u8) void { - var i: u8 = start; - for (buf) |*v| { - v.* = i; - i +%= 1; - } -} - -pub const Stream = struct { - output: std.io.FixedBufferStream([]u8) = undefined, - input: std.io.FixedBufferStream([]const u8) = undefined, - - pub fn init(input: []const u8, output: []u8) Stream { - return .{ - .input = std.io.fixedBufferStream(input), - .output = std.io.fixedBufferStream(output), - }; - } - - pub const ReadError = error{}; - pub const WriteError = error{NoSpaceLeft}; - - pub fn write(self: *Stream, buf: []const u8) !usize { - return try self.output.writer().write(buf); - } - - pub fn writeAll(self: *Stream, buffer: []const u8) !void { - var n: usize = 0; - while (n < buffer.len) { - n += try self.write(buffer[n..]); - } - } - - pub fn read(self: *Stream, buffer: []u8) !usize { - return self.input.read(buffer); - } -}; - -// Copied from: https://github.com/clickingbuttons/zig/blob/f1cea91624fd2deae28bfb2414a4fd9c7e246883/lib/std/crypto/rsa.zig#L791 -/// For readable copy/pasting from hex viewers. -pub fn hexToBytes(comptime hex: []const u8) [removeNonHex(hex).len / 2]u8 { - @setEvalBranchQuota(1000 * 100); - const hex2 = comptime removeNonHex(hex); - comptime var res: [hex2.len / 2]u8 = undefined; - _ = comptime std.fmt.hexToBytes(&res, hex2) catch unreachable; - return res; -} - -fn removeNonHex(comptime hex: []const u8) []const u8 { - @setEvalBranchQuota(1000 * 100); - var res: [hex.len]u8 = undefined; - var i: usize = 0; - for (hex) |c| { - if (std.ascii.isHex(c)) { - res[i] = c; - i += 1; - } - } - return res[0..i]; -} - -test hexToBytes { - const hex = - \\e3b0c442 98fc1c14 9afbf4c8 996fb924 - \\27ae41e4 649b934c a495991b 7852b855 - ; - try std.testing.expectEqual( - [_]u8{ - 0xe3, 0xb0, 0xc4, 0x42, 0x98, 0xfc, 0x1c, 0x14, - 0x9a, 0xfb, 0xf4, 0xc8, 0x99, 0x6f, 0xb9, 0x24, - 0x27, 0xae, 0x41, 0xe4, 0x64, 0x9b, 0x93, 0x4c, - 0xa4, 0x95, 0x99, 0x1b, 0x78, 0x52, 0xb8, 0x55, - }, - hexToBytes(hex), - ); -} diff --git a/src/http/async/tls.zig/transcript.zig b/src/http/async/tls.zig/transcript.zig deleted file mode 100644 index 59c94986..00000000 --- a/src/http/async/tls.zig/transcript.zig +++ /dev/null @@ -1,297 +0,0 @@ -const std = @import("std"); -const crypto = std.crypto; -const tls = crypto.tls; -const hkdfExpandLabel = tls.hkdfExpandLabel; - -const Sha256 = crypto.hash.sha2.Sha256; -const Sha384 = crypto.hash.sha2.Sha384; -const Sha512 = crypto.hash.sha2.Sha512; - -const HashTag = @import("cipher.zig").CipherSuite.HashTag; - -// Transcript holds hash of all handshake message. -// -// Until the server hello is parsed we don't know which hash (sha256, sha384, -// sha512) will be used so we update all of them. Handshake process will set -// `selected` field once cipher suite is known. Other function will use that -// selected hash. We continue to calculate all hashes because client certificate -// message could use different hash than the other part of the handshake. -// Handshake hash is dictated by the server selected cipher. Client certificate -// hash is dictated by the private key used. -// -// Most of the functions are inlined because they are returning pointers. -// -pub const Transcript = struct { - sha256: Type(.sha256) = .{ .hash = Sha256.init(.{}) }, - sha384: Type(.sha384) = .{ .hash = Sha384.init(.{}) }, - sha512: Type(.sha512) = .{ .hash = Sha512.init(.{}) }, - - tag: HashTag = .sha256, - - pub const max_mac_length = Type(.sha512).mac_length; - - // Transcript Type from hash tag - fn Type(h: HashTag) type { - return switch (h) { - .sha256 => TranscriptT(Sha256), - .sha384 => TranscriptT(Sha384), - .sha512 => TranscriptT(Sha512), - }; - } - - /// Set hash to use in all following function calls. - pub fn use(t: *Transcript, tag: HashTag) void { - t.tag = tag; - } - - pub fn update(t: *Transcript, buf: []const u8) void { - t.sha256.hash.update(buf); - t.sha384.hash.update(buf); - t.sha512.hash.update(buf); - } - - // tls 1.2 handshake specific - - pub inline fn masterSecret( - t: *Transcript, - pre_master_secret: []const u8, - client_random: [32]u8, - server_random: [32]u8, - ) []const u8 { - return switch (t.tag) { - inline else => |h| &@field(t, @tagName(h)).masterSecret( - pre_master_secret, - client_random, - server_random, - ), - }; - } - - pub inline fn keyMaterial( - t: *Transcript, - master_secret: []const u8, - client_random: [32]u8, - server_random: [32]u8, - ) []const u8 { - return switch (t.tag) { - inline else => |h| &@field(t, @tagName(h)).keyExpansion( - master_secret, - client_random, - server_random, - ), - }; - } - - pub fn clientFinishedTls12(t: *Transcript, master_secret: []const u8) [12]u8 { - return switch (t.tag) { - inline else => |h| @field(t, @tagName(h)).clientFinishedTls12(master_secret), - }; - } - - pub fn serverFinishedTls12(t: *Transcript, master_secret: []const u8) [12]u8 { - return switch (t.tag) { - inline else => |h| @field(t, @tagName(h)).serverFinishedTls12(master_secret), - }; - } - - // tls 1.3 handshake specific - - pub inline fn serverCertificateVerify(t: *Transcript) []const u8 { - return switch (t.tag) { - inline else => |h| &@field(t, @tagName(h)).serverCertificateVerify(), - }; - } - - pub inline fn clientCertificateVerify(t: *Transcript) []const u8 { - return switch (t.tag) { - inline else => |h| &@field(t, @tagName(h)).clientCertificateVerify(), - }; - } - - pub fn serverFinishedTls13(t: *Transcript, buf: []u8) []const u8 { - return switch (t.tag) { - inline else => |h| @field(t, @tagName(h)).serverFinishedTls13(buf), - }; - } - - pub fn clientFinishedTls13(t: *Transcript, buf: []u8) []const u8 { - return switch (t.tag) { - inline else => |h| @field(t, @tagName(h)).clientFinishedTls13(buf), - }; - } - - pub const Secret = struct { - client: []const u8, - server: []const u8, - }; - - pub inline fn handshakeSecret(t: *Transcript, shared_key: []const u8) Secret { - return switch (t.tag) { - inline else => |h| @field(t, @tagName(h)).handshakeSecret(shared_key), - }; - } - - pub inline fn applicationSecret(t: *Transcript) Secret { - return switch (t.tag) { - inline else => |h| @field(t, @tagName(h)).applicationSecret(), - }; - } - - // other - - pub fn Hkdf(h: HashTag) type { - return Type(h).Hkdf; - } - - /// Copy of the current hash value - pub inline fn hash(t: *Transcript, comptime Hash: type) Hash { - return switch (Hash) { - Sha256 => t.sha256.hash, - Sha384 => t.sha384.hash, - Sha512 => t.sha512.hash, - else => @compileError("unimplemented"), - }; - } -}; - -fn TranscriptT(comptime Hash: type) type { - return struct { - const Hmac = crypto.auth.hmac.Hmac(Hash); - const Hkdf = crypto.kdf.hkdf.Hkdf(Hmac); - const mac_length = Hmac.mac_length; - - hash: Hash, - handshake_secret: [Hmac.mac_length]u8 = undefined, - server_finished_key: [Hmac.key_length]u8 = undefined, - client_finished_key: [Hmac.key_length]u8 = undefined, - - const Self = @This(); - - fn init(transcript: Hash) Self { - return .{ .transcript = transcript }; - } - - fn serverCertificateVerify(c: *Self) [64 + 34 + Hash.digest_length]u8 { - return ([1]u8{0x20} ** 64) ++ - "TLS 1.3, server CertificateVerify\x00".* ++ - c.hash.peek(); - } - - // ref: https://www.rfc-editor.org/rfc/rfc8446#section-4.4.3 - fn clientCertificateVerify(c: *Self) [64 + 34 + Hash.digest_length]u8 { - return ([1]u8{0x20} ** 64) ++ - "TLS 1.3, client CertificateVerify\x00".* ++ - c.hash.peek(); - } - - fn masterSecret( - _: *Self, - pre_master_secret: []const u8, - client_random: [32]u8, - server_random: [32]u8, - ) [mac_length * 2]u8 { - const seed = "master secret" ++ client_random ++ server_random; - - var a1: [mac_length]u8 = undefined; - var a2: [mac_length]u8 = undefined; - Hmac.create(&a1, seed, pre_master_secret); - Hmac.create(&a2, &a1, pre_master_secret); - - var p1: [mac_length]u8 = undefined; - var p2: [mac_length]u8 = undefined; - Hmac.create(&p1, a1 ++ seed, pre_master_secret); - Hmac.create(&p2, a2 ++ seed, pre_master_secret); - - return p1 ++ p2; - } - - fn keyExpansion( - _: *Self, - master_secret: []const u8, - client_random: [32]u8, - server_random: [32]u8, - ) [mac_length * 4]u8 { - const seed = "key expansion" ++ server_random ++ client_random; - - const a0 = seed; - var a1: [mac_length]u8 = undefined; - var a2: [mac_length]u8 = undefined; - var a3: [mac_length]u8 = undefined; - var a4: [mac_length]u8 = undefined; - Hmac.create(&a1, a0, master_secret); - Hmac.create(&a2, &a1, master_secret); - Hmac.create(&a3, &a2, master_secret); - Hmac.create(&a4, &a3, master_secret); - - var key_material: [mac_length * 4]u8 = undefined; - Hmac.create(key_material[0..mac_length], a1 ++ seed, master_secret); - Hmac.create(key_material[mac_length .. mac_length * 2], a2 ++ seed, master_secret); - Hmac.create(key_material[mac_length * 2 .. mac_length * 3], a3 ++ seed, master_secret); - Hmac.create(key_material[mac_length * 3 ..], a4 ++ seed, master_secret); - return key_material; - } - - fn clientFinishedTls12(self: *Self, master_secret: []const u8) [12]u8 { - const seed = "client finished" ++ self.hash.peek(); - var a1: [mac_length]u8 = undefined; - var p1: [mac_length]u8 = undefined; - Hmac.create(&a1, seed, master_secret); - Hmac.create(&p1, a1 ++ seed, master_secret); - return p1[0..12].*; - } - - fn serverFinishedTls12(self: *Self, master_secret: []const u8) [12]u8 { - const seed = "server finished" ++ self.hash.peek(); - var a1: [mac_length]u8 = undefined; - var p1: [mac_length]u8 = undefined; - Hmac.create(&a1, seed, master_secret); - Hmac.create(&p1, a1 ++ seed, master_secret); - return p1[0..12].*; - } - - // tls 1.3 - - inline fn handshakeSecret(self: *Self, shared_key: []const u8) Transcript.Secret { - const hello_hash = self.hash.peek(); - - const zeroes = [1]u8{0} ** Hash.digest_length; - const early_secret = Hkdf.extract(&[1]u8{0}, &zeroes); - const empty_hash = tls.emptyHash(Hash); - const hs_derived_secret = hkdfExpandLabel(Hkdf, early_secret, "derived", &empty_hash, Hash.digest_length); - - self.handshake_secret = Hkdf.extract(&hs_derived_secret, shared_key); - const client_secret = hkdfExpandLabel(Hkdf, self.handshake_secret, "c hs traffic", &hello_hash, Hash.digest_length); - const server_secret = hkdfExpandLabel(Hkdf, self.handshake_secret, "s hs traffic", &hello_hash, Hash.digest_length); - - self.server_finished_key = hkdfExpandLabel(Hkdf, server_secret, "finished", "", Hmac.key_length); - self.client_finished_key = hkdfExpandLabel(Hkdf, client_secret, "finished", "", Hmac.key_length); - - return .{ .client = &client_secret, .server = &server_secret }; - } - - inline fn applicationSecret(self: *Self) Transcript.Secret { - const handshake_hash = self.hash.peek(); - - const empty_hash = tls.emptyHash(Hash); - const zeroes = [1]u8{0} ** Hash.digest_length; - const ap_derived_secret = hkdfExpandLabel(Hkdf, self.handshake_secret, "derived", &empty_hash, Hash.digest_length); - const master_secret = Hkdf.extract(&ap_derived_secret, &zeroes); - - const client_secret = hkdfExpandLabel(Hkdf, master_secret, "c ap traffic", &handshake_hash, Hash.digest_length); - const server_secret = hkdfExpandLabel(Hkdf, master_secret, "s ap traffic", &handshake_hash, Hash.digest_length); - - return .{ .client = &client_secret, .server = &server_secret }; - } - - fn serverFinishedTls13(self: *Self, buf: []u8) []const u8 { - Hmac.create(buf[0..mac_length], &self.hash.peek(), &self.server_finished_key); - return buf[0..mac_length]; - } - - // client finished message with header - fn clientFinishedTls13(self: *Self, buf: []u8) []const u8 { - Hmac.create(buf[0..mac_length], &self.hash.peek(), &self.client_finished_key); - return buf[0..mac_length]; - } - }; -} diff --git a/src/main.zig b/src/main.zig index d7c3143d..b0e5e4f4 100644 --- a/src/main.zig +++ b/src/main.zig @@ -30,6 +30,7 @@ const apiweb = @import("apiweb.zig"); pub const Types = jsruntime.reflect(apiweb.Interfaces); pub const UserContext = apiweb.UserContext; +pub const IO = @import("asyncio").Wrapper(jsruntime.Loop); // Default options const Host = "127.0.0.1"; diff --git a/src/main_shell.zig b/src/main_shell.zig index ac803ae5..eb88ab50 100644 --- a/src/main_shell.zig +++ b/src/main_shell.zig @@ -24,12 +24,13 @@ const parser = @import("netsurf"); const apiweb = @import("apiweb.zig"); const Window = @import("html/window.zig").Window; const storage = @import("storage/storage.zig"); +const Client = @import("asyncio").Client; const html_test = @import("html_test.zig").html; pub const Types = jsruntime.reflect(apiweb.Interfaces); pub const UserContext = apiweb.UserContext; -const Client = @import("http/async/main.zig").Client; +pub const IO = @import("asyncio").Wrapper(jsruntime.Loop); var doc: *parser.DocumentHTML = undefined; diff --git a/src/main_wpt.zig b/src/main_wpt.zig index 49b7ba23..7cf2f077 100644 --- a/src/main_wpt.zig +++ b/src/main_wpt.zig @@ -50,6 +50,7 @@ const Out = enum { pub const Types = jsruntime.reflect(apiweb.Interfaces); pub const GlobalType = apiweb.GlobalType; pub const UserContext = apiweb.UserContext; +pub const IO = @import("asyncio").Wrapper(jsruntime.Loop); // TODO For now the WPT tests run is specific to WPT. // It manually load js framwork libs, and run the first script w/ js content in diff --git a/src/run_tests.zig b/src/run_tests.zig index 8e285840..a9073d4e 100644 --- a/src/run_tests.zig +++ b/src/run_tests.zig @@ -30,7 +30,7 @@ const xhr = @import("xhr/xhr.zig"); const storage = @import("storage/storage.zig"); const url = @import("url/url.zig"); const urlquery = @import("url/query.zig"); -const Client = @import("http/async/main.zig").Client; +const Client = @import("asyncio").Client; const documentTestExecFn = @import("dom/document.zig").testExecFn; const HTMLDocumentTestExecFn = @import("html/document.zig").testExecFn; @@ -59,6 +59,7 @@ const MutationObserverTestExecFn = @import("dom/mutation_observer.zig").testExec pub const Types = jsruntime.reflect(apiweb.Interfaces); pub const UserContext = @import("user_context.zig").UserContext; +pub const IO = @import("asyncio").Wrapper(jsruntime.Loop); var doc: *parser.DocumentHTML = undefined; @@ -298,9 +299,6 @@ test { const msgTest = @import("msg.zig"); std.testing.refAllDecls(msgTest); - std.testing.refAllDecls(@import("http/async/std/http.zig")); - std.testing.refAllDecls(@import("http/async/stack.zig")); - const dumpTest = @import("browser/dump.zig"); std.testing.refAllDecls(dumpTest); diff --git a/src/test_runner.zig b/src/test_runner.zig index 8b138d0b..8358b66c 100644 --- a/src/test_runner.zig +++ b/src/test_runner.zig @@ -22,6 +22,7 @@ const tests = @import("run_tests.zig"); pub const Types = tests.Types; pub const UserContext = tests.UserContext; +pub const IO = tests.IO; pub fn main() !void { try tests.main(); diff --git a/src/user_context.zig b/src/user_context.zig index 644893c8..3bed0108 100644 --- a/src/user_context.zig +++ b/src/user_context.zig @@ -1,6 +1,6 @@ const std = @import("std"); const parser = @import("netsurf"); -const Client = @import("http/async/main.zig").Client; +const Client = @import("asyncio").Client; pub const UserContext = struct { document: *parser.DocumentHTML, diff --git a/src/wpt/run.zig b/src/wpt/run.zig index ec1d3397..a44b12d7 100644 --- a/src/wpt/run.zig +++ b/src/wpt/run.zig @@ -28,10 +28,10 @@ const Loop = jsruntime.Loop; const Env = jsruntime.Env; const Window = @import("../html/window.zig").Window; const storage = @import("../storage/storage.zig"); +const Client = @import("asyncio").Client; const Types = @import("../main_wpt.zig").Types; const UserContext = @import("../main_wpt.zig").UserContext; -const Client = @import("../http/async/main.zig").Client; // runWPT parses the given HTML file, starts a js env and run the first script // tags containing javascript sources. diff --git a/src/xhr/xhr.zig b/src/xhr/xhr.zig index 660e449a..ab936b43 100644 --- a/src/xhr/xhr.zig +++ b/src/xhr/xhr.zig @@ -32,7 +32,7 @@ const XMLHttpRequestEventTarget = @import("event_target.zig").XMLHttpRequestEven const Mime = @import("../browser/mime.zig"); const Loop = jsruntime.Loop; -const Client = @import("../http/async/main.zig").Client; +const Client = @import("asyncio").Client; const parser = @import("netsurf"); @@ -97,7 +97,7 @@ pub const XMLHttpRequest = struct { proto: XMLHttpRequestEventTarget = XMLHttpRequestEventTarget{}, alloc: std.mem.Allocator, cli: *Client, - loop: Client.Loop, + io: Client.IO, priv_state: PrivState = .new, req: ?Client.Request = null, @@ -294,7 +294,7 @@ pub const XMLHttpRequest = struct { .alloc = alloc, .headers = Headers.init(alloc), .response_headers = Headers.init(alloc), - .loop = Client.Loop.init(loop), + .io = Client.IO.init(loop), .method = undefined, .url = null, .uri = undefined, @@ -513,7 +513,7 @@ pub const XMLHttpRequest = struct { self.req = null; } - self.ctx = try Client.Ctx.init(&self.loop, &self.req.?); + self.ctx = try Client.Ctx.init(&self.io, &self.req.?); errdefer { self.ctx.?.deinit(); self.ctx = null; diff --git a/vendor/zig-async-io b/vendor/zig-async-io new file mode 160000 index 00000000..d996742c --- /dev/null +++ b/vendor/zig-async-io @@ -0,0 +1 @@ +Subproject commit d996742c00f518be4f088af69d81912b8df94d58 diff --git a/vendor/zig-js-runtime b/vendor/zig-js-runtime index f434b3cf..d0e006be 160000 --- a/vendor/zig-js-runtime +++ b/vendor/zig-js-runtime @@ -1 +1 @@ -Subproject commit f434b3cfa1938277a6cd2e225974bb8d33d578c2 +Subproject commit d0e006becd9ddac8a4e0ac9890c7b4087e237bd7 From 1a2fd9a584440fc29836225f6543a7d30c3f6820 Mon Sep 17 00:00:00 2001 From: Francis Bouvier Date: Thu, 21 Nov 2024 16:46:09 +0100 Subject: [PATCH 11/11] Update dependencies Signed-off-by: Francis Bouvier --- vendor/zig-async-io | 2 +- vendor/zig-js-runtime | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/vendor/zig-async-io b/vendor/zig-async-io index d996742c..ed7ae07d 160000 --- a/vendor/zig-async-io +++ b/vendor/zig-async-io @@ -1 +1 @@ -Subproject commit d996742c00f518be4f088af69d81912b8df94d58 +Subproject commit ed7ae07d1c39ca073a6eacb741c8b56bc3e57f9f diff --git a/vendor/zig-js-runtime b/vendor/zig-js-runtime index d0e006be..d1152619 160000 --- a/vendor/zig-js-runtime +++ b/vendor/zig-js-runtime @@ -1 +1 @@ -Subproject commit d0e006becd9ddac8a4e0ac9890c7b4087e237bd7 +Subproject commit d11526195cd8f417901533145c42355fe39ff24e