From fadf3f609a1ed5d221f696895e347f6c998179d4 Mon Sep 17 00:00:00 2001 From: Pierre Tachoire Date: Thu, 14 Nov 2024 16:07:27 +0100 Subject: [PATCH] http: add full async client --- src/http/async/io.zig | 148 + src/http/async/loop.zig | 75 + src/http/async/main.zig | 4 + src/http/async/stack.zig | 95 + src/http/async/std/http.zig | 318 ++ src/http/async/std/http/Client.zig | 2545 +++++++++++++++++ src/http/async/std/http/Server.zig | 1148 ++++++++ src/http/async/std/http/protocol.zig | 447 +++ src/http/async/std/net.zig | 2050 +++++++++++++ src/http/async/std/net/test.zig | 335 +++ src/http/async/tls.zig/PrivateKey.zig | 260 ++ src/http/async/tls.zig/cbc/main.zig | 148 + src/http/async/tls.zig/cipher.zig | 1004 +++++++ src/http/async/tls.zig/connection.zig | 665 +++++ src/http/async/tls.zig/handshake_client.zig | 955 +++++++ src/http/async/tls.zig/handshake_common.zig | 448 +++ src/http/async/tls.zig/handshake_server.zig | 520 ++++ src/http/async/tls.zig/key_log.zig | 60 + src/http/async/tls.zig/main.zig | 51 + src/http/async/tls.zig/protocol.zig | 302 ++ src/http/async/tls.zig/record.zig | 405 +++ src/http/async/tls.zig/rsa/der.zig | 467 +++ src/http/async/tls.zig/rsa/oid.zig | 132 + src/http/async/tls.zig/rsa/rsa.zig | 880 ++++++ .../async/tls.zig/rsa/testdata/id_rsa.der | Bin 0 -> 1191 bytes .../testdata/ec_prime256v1_private_key.pem | 5 + .../async/tls.zig/testdata/ec_private_key.pem | 6 + .../testdata/ec_secp384r1_private_key.pem | 6 + .../testdata/ec_secp521r1_private_key.pem | 7 + .../tls.zig/testdata/google.com/client_random | 1 + .../tls.zig/testdata/google.com/server_hello | Bin 0 -> 7158 bytes .../tls.zig/testdata/rsa_private_key.pem | 28 + src/http/async/tls.zig/testdata/tls12.zig | 244 ++ src/http/async/tls.zig/testdata/tls13.zig | 64 + src/http/async/tls.zig/testu.zig | 117 + src/http/async/tls.zig/transcript.zig | 297 ++ 36 files changed, 14237 insertions(+) create mode 100644 src/http/async/io.zig create mode 100644 src/http/async/loop.zig create 
mode 100644 src/http/async/main.zig create mode 100644 src/http/async/stack.zig create mode 100644 src/http/async/std/http.zig create mode 100644 src/http/async/std/http/Client.zig create mode 100644 src/http/async/std/http/Server.zig create mode 100644 src/http/async/std/http/protocol.zig create mode 100644 src/http/async/std/net.zig create mode 100644 src/http/async/std/net/test.zig create mode 100644 src/http/async/tls.zig/PrivateKey.zig create mode 100644 src/http/async/tls.zig/cbc/main.zig create mode 100644 src/http/async/tls.zig/cipher.zig create mode 100644 src/http/async/tls.zig/connection.zig create mode 100644 src/http/async/tls.zig/handshake_client.zig create mode 100644 src/http/async/tls.zig/handshake_common.zig create mode 100644 src/http/async/tls.zig/handshake_server.zig create mode 100644 src/http/async/tls.zig/key_log.zig create mode 100644 src/http/async/tls.zig/main.zig create mode 100644 src/http/async/tls.zig/protocol.zig create mode 100644 src/http/async/tls.zig/record.zig create mode 100644 src/http/async/tls.zig/rsa/der.zig create mode 100644 src/http/async/tls.zig/rsa/oid.zig create mode 100644 src/http/async/tls.zig/rsa/rsa.zig create mode 100644 src/http/async/tls.zig/rsa/testdata/id_rsa.der create mode 100644 src/http/async/tls.zig/testdata/ec_prime256v1_private_key.pem create mode 100644 src/http/async/tls.zig/testdata/ec_private_key.pem create mode 100644 src/http/async/tls.zig/testdata/ec_secp384r1_private_key.pem create mode 100644 src/http/async/tls.zig/testdata/ec_secp521r1_private_key.pem create mode 100644 src/http/async/tls.zig/testdata/google.com/client_random create mode 100644 src/http/async/tls.zig/testdata/google.com/server_hello create mode 100644 src/http/async/tls.zig/testdata/rsa_private_key.pem create mode 100644 src/http/async/tls.zig/testdata/tls12.zig create mode 100644 src/http/async/tls.zig/testdata/tls13.zig create mode 100644 src/http/async/tls.zig/testu.zig create mode 100644 
src/http/async/tls.zig/transcript.zig diff --git a/src/http/async/io.zig b/src/http/async/io.zig new file mode 100644 index 00000000..a416c5fc --- /dev/null +++ b/src/http/async/io.zig @@ -0,0 +1,148 @@ +const std = @import("std"); + +pub const IO = @import("jsruntime").IO; + +pub const Blocking = struct { + pub fn connect( + _: *Blocking, + comptime CtxT: type, + ctx: *CtxT, + comptime cbk: fn (ctx: *CtxT, res: anyerror!void) anyerror!void, + socket: std.posix.socket_t, + address: std.net.Address, + ) void { + std.posix.connect(socket, &address.any, address.getOsSockLen()) catch |err| { + std.posix.close(socket); + cbk(ctx, err) catch |e| { + ctx.setErr(e); + }; + }; + cbk(ctx, {}) catch |e| ctx.setErr(e); + } + + pub fn send( + _: *Blocking, + comptime CtxT: type, + ctx: *CtxT, + comptime cbk: fn (ctx: *CtxT, res: anyerror!void) anyerror!void, + socket: std.posix.socket_t, + buf: []const u8, + ) void { + const len = std.posix.write(socket, buf) catch |err| { + cbk(ctx, err) catch |e| { + return ctx.setErr(e); + }; + return ctx.setErr(err); + }; + ctx.setLen(len); + cbk(ctx, {}) catch |e| ctx.setErr(e); + } + + pub fn recv( + _: *Blocking, + comptime CtxT: type, + ctx: *CtxT, + comptime cbk: fn (ctx: *CtxT, res: anyerror!void) anyerror!void, + socket: std.posix.socket_t, + buf: []u8, + ) void { + const len = std.posix.read(socket, buf) catch |err| { + cbk(ctx, err) catch |e| { + return ctx.setErr(e); + }; + return ctx.setErr(err); + }; + ctx.setLen(len); + cbk(ctx, {}) catch |e| ctx.setErr(e); + } +}; + +pub fn SingleThreaded(comptime CtxT: type) type { + return struct { + io: *IO, + completion: IO.Completion, + ctx: *CtxT, + cbk: CbkT, + + count: u32 = 0, + + const CbkT = *const fn (ctx: *CtxT, res: anyerror!void) anyerror!void; + + const Self = @This(); + + pub fn init(io: *IO) Self { + return .{ + .io = io, + .completion = undefined, + .ctx = undefined, + .cbk = undefined, + }; + } + + pub fn connect( + self: *Self, + comptime _: type, + ctx: *CtxT, + comptime 
cbk: CbkT, + socket: std.posix.socket_t, + address: std.net.Address, + ) void { + self.ctx = ctx; + self.cbk = cbk; + self.count += 1; + self.io.connect(*Self, self, Self.connectCbk, &self.completion, socket, address); + } + + fn connectCbk(self: *Self, _: *IO.Completion, result: IO.ConnectError!void) void { + defer self.count -= 1; + _ = result catch |e| return self.ctx.setErr(e); + self.cbk(self.ctx, {}) catch |e| self.ctx.setErr(e); + } + + pub fn send( + self: *Self, + comptime _: type, + ctx: *CtxT, + comptime cbk: CbkT, + socket: std.posix.socket_t, + buf: []const u8, + ) void { + self.ctx = ctx; + self.cbk = cbk; + self.count += 1; + self.io.send(*Self, self, Self.sendCbk, &self.completion, socket, buf); + } + + fn sendCbk(self: *Self, _: *IO.Completion, result: IO.SendError!usize) void { + defer self.count -= 1; + const ln = result catch |e| return self.ctx.setErr(e); + self.ctx.setLen(ln); + self.cbk(self.ctx, {}) catch |e| self.ctx.setErr(e); + } + + pub fn recv( + self: *Self, + comptime _: type, + ctx: *CtxT, + comptime cbk: CbkT, + socket: std.posix.socket_t, + buf: []u8, + ) void { + self.ctx = ctx; + self.cbk = cbk; + self.count += 1; + self.io.recv(*Self, self, Self.receiveCbk, &self.completion, socket, buf); + } + + fn receiveCbk(self: *Self, _: *IO.Completion, result: IO.RecvError!usize) void { + defer self.count -= 1; + const ln = result catch |e| return self.ctx.setErr(e); + self.ctx.setLen(ln); + self.cbk(self.ctx, {}) catch |e| self.ctx.setErr(e); + } + + pub fn isDone(self: *Self) bool { + return self.count == 0; + } + }; +} diff --git a/src/http/async/loop.zig b/src/http/async/loop.zig new file mode 100644 index 00000000..1b18d0f7 --- /dev/null +++ b/src/http/async/loop.zig @@ -0,0 +1,75 @@ +const std = @import("std"); +const Client = @import("std/http/Client.zig"); + +const Stack = @import("stack.zig"); + +const Res = fn (ctx: *Ctx, res: ?anyerror) anyerror!void; + +pub const Blocking = struct { + pub fn connect( + _: *Blocking, + comptime 
ctxT: type, + ctx: *ctxT, + comptime cbk: Res, + socket: std.os.socket_t, + address: std.net.Address, + ) void { + std.os.connect(socket, &address.any, address.getOsSockLen()) catch |err| { + std.os.closeSocket(socket); + _ = cbk(ctx, err); + return; + }; + ctx.socket = socket; + _ = cbk(ctx, null); + } +}; + +const CtxStack = Stack(Res); + +pub const Ctx = struct { + alloc: std.mem.Allocator, + stack: ?*CtxStack = null, + + // TCP ctx + client: *Client = undefined, + addr_current: usize = undefined, + list: *std.net.AddressList = undefined, + socket: std.os.socket_t = undefined, + Stream: std.net.Stream = undefined, + host: []const u8 = undefined, + port: u16 = undefined, + protocol: Client.Connection.Protocol = undefined, + conn: *Client.Connection = undefined, + uri: std.Uri = undefined, + headers: std.http.Headers = undefined, + method: std.http.Method = undefined, + options: Client.RequestOptions = undefined, + request: Client.Request = undefined, + + err: ?anyerror, + + pub fn init(alloc: std.mem.Allocator) Ctx { + return .{ .alloc = alloc }; + } + + pub fn push(self: *Ctx, function: CtxStack.Fn) !void { + if (self.stack) |stack| { + return try stack.push(self.alloc, function); + } + self.stack = try CtxStack.init(self.alloc, function); + } + + pub fn next(self: *Ctx, err: ?anyerror) !void { + if (self.stack) |stack| { + const last = stack.next == null; + const function = stack.pop(self.alloc, stack); + const res = @call(.auto, function, .{ self, err }); + if (last) { + self.stack = null; + self.alloc.destroy(stack); + } + return res; + } + self.err = err; + } +}; diff --git a/src/http/async/main.zig b/src/http/async/main.zig new file mode 100644 index 00000000..ea756e8b --- /dev/null +++ b/src/http/async/main.zig @@ -0,0 +1,4 @@ +const std = @import("std"); + +const stack = @import("stack.zig"); +pub const Client = @import("std/http/Client.zig"); diff --git a/src/http/async/stack.zig b/src/http/async/stack.zig new file mode 100644 index 00000000..d19a0c8f 
--- /dev/null +++ b/src/http/async/stack.zig @@ -0,0 +1,95 @@ +const std = @import("std"); + +pub fn Stack(comptime T: type) type { + return struct { + const Self = @This(); + pub const Fn = *const T; + + next: ?*Self = null, + func: Fn, + + pub fn init(alloc: std.mem.Allocator, comptime func: Fn) !*Self { + const next = try alloc.create(Self); + next.* = .{ .func = func }; + return next; + } + + pub fn push(self: *Self, alloc: std.mem.Allocator, comptime func: Fn) !void { + if (self.next) |next| { + return next.push(alloc, func); + } + self.next = try Self.init(alloc, func); + } + + pub fn pop(self: *Self, alloc: std.mem.Allocator, prev: ?*Self) Fn { + if (self.next) |next| { + return next.pop(alloc, self); + } + defer { + if (prev) |p| { + self.deinit(alloc, p); + } + } + return self.func; + } + + pub fn deinit(self: *Self, alloc: std.mem.Allocator, prev: ?*Self) void { + if (self.next) |next| { + // recursivly deinit + next.deinit(alloc, self); + } + if (prev) |p| { + p.next = null; + } + alloc.destroy(self); + } + }; +} + +fn first() u8 { + return 1; +} + +fn second() u8 { + return 2; +} + +test "stack" { + const alloc = std.testing.allocator; + const TestStack = Stack(fn () u8); + + var stack = TestStack{ .func = first }; + try stack.push(alloc, second); + + const a = stack.pop(alloc, null); + try std.testing.expect(a() == 2); + + const b = stack.pop(alloc, null); + try std.testing.expect(b() == 1); +} + +fn first_op(arg: ?*anyopaque) u8 { + const val = @as(*u8, @ptrCast(arg)); + return val.* + @as(u8, 1); +} + +fn second_op(arg: ?*anyopaque) u8 { + const val = @as(*u8, @ptrCast(arg)); + return val.* + @as(u8, 2); +} + +test "opaque stack" { + const alloc = std.testing.allocator; + const TestStack = Stack(fn (?*anyopaque) u8); + + var stack = TestStack{ .func = first_op }; + try stack.push(alloc, second_op); + + const a = stack.pop(alloc, null); + var x: u8 = 5; + try std.testing.expect(a(@as(*anyopaque, @ptrCast(&x))) == 2 + x); + + const b = stack.pop(alloc, 
null); + var y: u8 = 3; + try std.testing.expect(b(@as(*anyopaque, @ptrCast(&y))) == 1 + y); +} diff --git a/src/http/async/std/http.zig b/src/http/async/std/http.zig new file mode 100644 index 00000000..f027d440 --- /dev/null +++ b/src/http/async/std/http.zig @@ -0,0 +1,318 @@ +pub const Client = @import("http/Client.zig"); +pub const Server = @import("http/Server.zig"); +pub const protocol = @import("http/protocol.zig"); +pub const HeadParser = std.http.HeadParser; +pub const ChunkParser = std.http.ChunkParser; +pub const HeaderIterator = std.http.HeaderIterator; + +pub const Version = enum { + @"HTTP/1.0", + @"HTTP/1.1", +}; + +/// https://developer.mozilla.org/en-US/docs/Web/HTTP/Methods +/// +/// https://datatracker.ietf.org/doc/html/rfc7231#section-4 Initial definition +/// +/// https://datatracker.ietf.org/doc/html/rfc5789#section-2 PATCH +pub const Method = enum(u64) { + GET = parse("GET"), + HEAD = parse("HEAD"), + POST = parse("POST"), + PUT = parse("PUT"), + DELETE = parse("DELETE"), + CONNECT = parse("CONNECT"), + OPTIONS = parse("OPTIONS"), + TRACE = parse("TRACE"), + PATCH = parse("PATCH"), + + _, + + /// Converts `s` into a type that may be used as a `Method` field. + /// Asserts that `s` is 24 or fewer bytes. 
+ pub fn parse(s: []const u8) u64 { + var x: u64 = 0; + const len = @min(s.len, @sizeOf(@TypeOf(x))); + @memcpy(std.mem.asBytes(&x)[0..len], s[0..len]); + return x; + } + + pub fn write(self: Method, w: anytype) !void { + const bytes = std.mem.asBytes(&@intFromEnum(self)); + const str = std.mem.sliceTo(bytes, 0); + try w.writeAll(str); + } + + /// Returns true if a request of this method is allowed to have a body + /// Actual behavior from servers may vary and should still be checked + pub fn requestHasBody(self: Method) bool { + return switch (self) { + .POST, .PUT, .PATCH => true, + .GET, .HEAD, .DELETE, .CONNECT, .OPTIONS, .TRACE => false, + else => true, + }; + } + + /// Returns true if a response to this method is allowed to have a body + /// Actual behavior from clients may vary and should still be checked + pub fn responseHasBody(self: Method) bool { + return switch (self) { + .GET, .POST, .DELETE, .CONNECT, .OPTIONS, .PATCH => true, + .HEAD, .PUT, .TRACE => false, + else => true, + }; + } + + /// An HTTP method is safe if it doesn't alter the state of the server. + /// + /// https://developer.mozilla.org/en-US/docs/Glossary/Safe/HTTP + /// + /// https://datatracker.ietf.org/doc/html/rfc7231#section-4.2.1 + pub fn safe(self: Method) bool { + return switch (self) { + .GET, .HEAD, .OPTIONS, .TRACE => true, + .POST, .PUT, .DELETE, .CONNECT, .PATCH => false, + else => false, + }; + } + + /// An HTTP method is idempotent if an identical request can be made once or several times in a row with the same effect while leaving the server in the same state. 
+ /// + /// https://developer.mozilla.org/en-US/docs/Glossary/Idempotent + /// + /// https://datatracker.ietf.org/doc/html/rfc7231#section-4.2.2 + pub fn idempotent(self: Method) bool { + return switch (self) { + .GET, .HEAD, .PUT, .DELETE, .OPTIONS, .TRACE => true, + .CONNECT, .POST, .PATCH => false, + else => false, + }; + } + + /// A cacheable response is an HTTP response that can be cached, that is stored to be retrieved and used later, saving a new request to the server. + /// + /// https://developer.mozilla.org/en-US/docs/Glossary/cacheable + /// + /// https://datatracker.ietf.org/doc/html/rfc7231#section-4.2.3 + pub fn cacheable(self: Method) bool { + return switch (self) { + .GET, .HEAD => true, + .POST, .PUT, .DELETE, .CONNECT, .OPTIONS, .TRACE, .PATCH => false, + else => false, + }; + } +}; + +/// https://developer.mozilla.org/en-US/docs/Web/HTTP/Status +pub const Status = enum(u10) { + @"continue" = 100, // RFC7231, Section 6.2.1 + switching_protocols = 101, // RFC7231, Section 6.2.2 + processing = 102, // RFC2518 + early_hints = 103, // RFC8297 + + ok = 200, // RFC7231, Section 6.3.1 + created = 201, // RFC7231, Section 6.3.2 + accepted = 202, // RFC7231, Section 6.3.3 + non_authoritative_info = 203, // RFC7231, Section 6.3.4 + no_content = 204, // RFC7231, Section 6.3.5 + reset_content = 205, // RFC7231, Section 6.3.6 + partial_content = 206, // RFC7233, Section 4.1 + multi_status = 207, // RFC4918 + already_reported = 208, // RFC5842 + im_used = 226, // RFC3229 + + multiple_choice = 300, // RFC7231, Section 6.4.1 + moved_permanently = 301, // RFC7231, Section 6.4.2 + found = 302, // RFC7231, Section 6.4.3 + see_other = 303, // RFC7231, Section 6.4.4 + not_modified = 304, // RFC7232, Section 4.1 + use_proxy = 305, // RFC7231, Section 6.4.5 + temporary_redirect = 307, // RFC7231, Section 6.4.7 + permanent_redirect = 308, // RFC7538 + + bad_request = 400, // RFC7231, Section 6.5.1 + unauthorized = 401, // RFC7235, Section 3.1 + payment_required = 402, // 
RFC7231, Section 6.5.2 + forbidden = 403, // RFC7231, Section 6.5.3 + not_found = 404, // RFC7231, Section 6.5.4 + method_not_allowed = 405, // RFC7231, Section 6.5.5 + not_acceptable = 406, // RFC7231, Section 6.5.6 + proxy_auth_required = 407, // RFC7235, Section 3.2 + request_timeout = 408, // RFC7231, Section 6.5.7 + conflict = 409, // RFC7231, Section 6.5.8 + gone = 410, // RFC7231, Section 6.5.9 + length_required = 411, // RFC7231, Section 6.5.10 + precondition_failed = 412, // RFC7232, Section 4.2][RFC8144, Section 3.2 + payload_too_large = 413, // RFC7231, Section 6.5.11 + uri_too_long = 414, // RFC7231, Section 6.5.12 + unsupported_media_type = 415, // RFC7231, Section 6.5.13][RFC7694, Section 3 + range_not_satisfiable = 416, // RFC7233, Section 4.4 + expectation_failed = 417, // RFC7231, Section 6.5.14 + teapot = 418, // RFC 7168, 2.3.3 + misdirected_request = 421, // RFC7540, Section 9.1.2 + unprocessable_entity = 422, // RFC4918 + locked = 423, // RFC4918 + failed_dependency = 424, // RFC4918 + too_early = 425, // RFC8470 + upgrade_required = 426, // RFC7231, Section 6.5.15 + precondition_required = 428, // RFC6585 + too_many_requests = 429, // RFC6585 + request_header_fields_too_large = 431, // RFC6585 + unavailable_for_legal_reasons = 451, // RFC7725 + + internal_server_error = 500, // RFC7231, Section 6.6.1 + not_implemented = 501, // RFC7231, Section 6.6.2 + bad_gateway = 502, // RFC7231, Section 6.6.3 + service_unavailable = 503, // RFC7231, Section 6.6.4 + gateway_timeout = 504, // RFC7231, Section 6.6.5 + http_version_not_supported = 505, // RFC7231, Section 6.6.6 + variant_also_negotiates = 506, // RFC2295 + insufficient_storage = 507, // RFC4918 + loop_detected = 508, // RFC5842 + not_extended = 510, // RFC2774 + network_authentication_required = 511, // RFC6585 + + _, + + pub fn phrase(self: Status) ?[]const u8 { + return switch (self) { + // 1xx statuses + .@"continue" => "Continue", + .switching_protocols => "Switching Protocols", + 
.processing => "Processing", + .early_hints => "Early Hints", + + // 2xx statuses + .ok => "OK", + .created => "Created", + .accepted => "Accepted", + .non_authoritative_info => "Non-Authoritative Information", + .no_content => "No Content", + .reset_content => "Reset Content", + .partial_content => "Partial Content", + .multi_status => "Multi-Status", + .already_reported => "Already Reported", + .im_used => "IM Used", + + // 3xx statuses + .multiple_choice => "Multiple Choice", + .moved_permanently => "Moved Permanently", + .found => "Found", + .see_other => "See Other", + .not_modified => "Not Modified", + .use_proxy => "Use Proxy", + .temporary_redirect => "Temporary Redirect", + .permanent_redirect => "Permanent Redirect", + + // 4xx statuses + .bad_request => "Bad Request", + .unauthorized => "Unauthorized", + .payment_required => "Payment Required", + .forbidden => "Forbidden", + .not_found => "Not Found", + .method_not_allowed => "Method Not Allowed", + .not_acceptable => "Not Acceptable", + .proxy_auth_required => "Proxy Authentication Required", + .request_timeout => "Request Timeout", + .conflict => "Conflict", + .gone => "Gone", + .length_required => "Length Required", + .precondition_failed => "Precondition Failed", + .payload_too_large => "Payload Too Large", + .uri_too_long => "URI Too Long", + .unsupported_media_type => "Unsupported Media Type", + .range_not_satisfiable => "Range Not Satisfiable", + .expectation_failed => "Expectation Failed", + .teapot => "I'm a teapot", + .misdirected_request => "Misdirected Request", + .unprocessable_entity => "Unprocessable Entity", + .locked => "Locked", + .failed_dependency => "Failed Dependency", + .too_early => "Too Early", + .upgrade_required => "Upgrade Required", + .precondition_required => "Precondition Required", + .too_many_requests => "Too Many Requests", + .request_header_fields_too_large => "Request Header Fields Too Large", + .unavailable_for_legal_reasons => "Unavailable For Legal Reasons", + + // 
5xx statuses + .internal_server_error => "Internal Server Error", + .not_implemented => "Not Implemented", + .bad_gateway => "Bad Gateway", + .service_unavailable => "Service Unavailable", + .gateway_timeout => "Gateway Timeout", + .http_version_not_supported => "HTTP Version Not Supported", + .variant_also_negotiates => "Variant Also Negotiates", + .insufficient_storage => "Insufficient Storage", + .loop_detected => "Loop Detected", + .not_extended => "Not Extended", + .network_authentication_required => "Network Authentication Required", + + else => return null, + }; + } + + pub const Class = enum { + informational, + success, + redirect, + client_error, + server_error, + }; + + pub fn class(self: Status) Class { + return switch (@intFromEnum(self)) { + 100...199 => .informational, + 200...299 => .success, + 300...399 => .redirect, + 400...499 => .client_error, + else => .server_error, + }; + } + + test { + try std.testing.expectEqualStrings("OK", Status.ok.phrase().?); + try std.testing.expectEqualStrings("Not Found", Status.not_found.phrase().?); + } + + test { + try std.testing.expectEqual(Status.Class.success, Status.ok.class()); + try std.testing.expectEqual(Status.Class.client_error, Status.not_found.class()); + } +}; + +pub const TransferEncoding = enum { + chunked, + none, + // compression is intentionally omitted here, as std.http.Client stores it as content-encoding +}; + +pub const ContentEncoding = enum { + identity, + compress, + @"x-compress", + deflate, + gzip, + @"x-gzip", + zstd, +}; + +pub const Connection = enum { + keep_alive, + close, +}; + +pub const Header = struct { + name: []const u8, + value: []const u8, +}; + +const builtin = @import("builtin"); +const std = @import("std"); + +test { + _ = Client; + _ = Method; + _ = Server; + _ = Status; +} diff --git a/src/http/async/std/http/Client.zig b/src/http/async/std/http/Client.zig new file mode 100644 index 00000000..f0c37b20 --- /dev/null +++ b/src/http/async/std/http/Client.zig @@ -0,0 
+1,2545 @@ +//! HTTP(S) Client implementation. +//! +//! Connections are opened in a thread-safe manner, but individual Requests are not. +//! +//! TLS support may be disabled via `std.options.http_disable_tls`. + +const std = @import("std"); +const builtin = @import("builtin"); +const testing = std.testing; +const http = std.http; +const mem = std.mem; +const net = @import("../net.zig"); +const Uri = std.Uri; +const Allocator = mem.Allocator; +const assert = std.debug.assert; +const use_vectors = builtin.zig_backend != .stage2_x86_64; + +const Client = @This(); +const proto = @import("protocol.zig"); + +const tls23 = @import("../../tls.zig/main.zig"); +const VecPut = @import("../../tls.zig/connection.zig").VecPut; +const GenericStack = @import("../../stack.zig").Stack; +const async_io = @import("../../io.zig"); +pub const Loop = async_io.SingleThreaded(Ctx); + +const cipher = @import("../../tls.zig/cipher.zig"); + +pub const disable_tls = std.options.http_disable_tls; + +/// Used for all client allocations. Must be thread-safe. +allocator: Allocator, + +ca_bundle: if (disable_tls) void else std.crypto.Certificate.Bundle = if (disable_tls) {} else .{}, +ca_bundle_mutex: std.Thread.Mutex = .{}, + +/// When this is `true`, the next time this client performs an HTTPS request, +/// it will first rescan the system for root certificates. +next_https_rescan_certs: bool = true, + +/// The pool of connections that can be reused (and currently in use). +connection_pool: ConnectionPool = .{}, + +/// If populated, all http traffic travels through this third party. +/// This field cannot be modified while the client has active connections. +/// Pointer to externally-owned memory. +http_proxy: ?*Proxy = null, +/// If populated, all https traffic travels through this third party. +/// This field cannot be modified while the client has active connections. +/// Pointer to externally-owned memory. 
+https_proxy: ?*Proxy = null, + +/// A set of linked lists of connections that can be reused. +pub const ConnectionPool = struct { + mutex: std.Thread.Mutex = .{}, + /// Open connections that are currently in use. + used: Queue = .{}, + /// Open connections that are not currently in use. + free: Queue = .{}, + free_len: usize = 0, + free_size: usize = 32, + + /// The criteria for a connection to be considered a match. + pub const Criteria = struct { + host: []const u8, + port: u16, + protocol: Connection.Protocol, + }; + + const Queue = std.DoublyLinkedList(Connection); + pub const Node = Queue.Node; + + /// Finds and acquires a connection from the connection pool matching the criteria. This function is threadsafe. + /// If no connection is found, null is returned. + pub fn findConnection(pool: *ConnectionPool, criteria: Criteria) ?*Connection { + pool.mutex.lock(); + defer pool.mutex.unlock(); + + var next = pool.free.last; + while (next) |node| : (next = node.prev) { + if (node.data.protocol != criteria.protocol) continue; + if (node.data.port != criteria.port) continue; + + // Domain names are case-insensitive (RFC 5890, Section 2.3.2.4) + if (!std.ascii.eqlIgnoreCase(node.data.host, criteria.host)) continue; + + pool.acquireUnsafe(node); + return &node.data; + } + + return null; + } + + /// Acquires an existing connection from the connection pool. This function is not threadsafe. + pub fn acquireUnsafe(pool: *ConnectionPool, node: *Node) void { + pool.free.remove(node); + pool.free_len -= 1; + + pool.used.append(node); + } + + /// Acquires an existing connection from the connection pool. This function is threadsafe. + pub fn acquire(pool: *ConnectionPool, node: *Node) void { + pool.mutex.lock(); + defer pool.mutex.unlock(); + + return pool.acquireUnsafe(node); + } + + /// Tries to release a connection back to the connection pool. This function is threadsafe. + /// If the connection is marked as closing, it will be closed instead. 
+ /// + /// The allocator must be the owner of all nodes in this pool. + /// The allocator must be the owner of all resources associated with the connection. + pub fn release(pool: *ConnectionPool, allocator: Allocator, connection: *Connection) void { + pool.mutex.lock(); + defer pool.mutex.unlock(); + + const node: *Node = @fieldParentPtr("data", connection); + + pool.used.remove(node); + + if (node.data.closing or pool.free_size == 0) { + node.data.close(allocator); + return allocator.destroy(node); + } + + if (pool.free_len >= pool.free_size) { + const popped = pool.free.popFirst() orelse unreachable; + pool.free_len -= 1; + + popped.data.close(allocator); + allocator.destroy(popped); + } + + if (node.data.proxied) { + pool.free.prepend(node); // proxied connections go to the end of the queue, always try direct connections first + } else { + pool.free.append(node); + } + + pool.free_len += 1; + } + + /// Adds a newly created node to the pool of used connections. This function is threadsafe. + pub fn addUsed(pool: *ConnectionPool, node: *Node) void { + pool.mutex.lock(); + defer pool.mutex.unlock(); + + pool.used.append(node); + } + + /// Resizes the connection pool. This function is threadsafe. + /// + /// If the new size is smaller than the current size, then idle connections will be closed until the pool is the new size. + pub fn resize(pool: *ConnectionPool, allocator: Allocator, new_size: usize) void { + pool.mutex.lock(); + defer pool.mutex.unlock(); + + const next = pool.free.first; + _ = next; + while (pool.free_len > new_size) { + const popped = pool.free.popFirst() orelse unreachable; + pool.free_len -= 1; + + popped.data.close(allocator); + allocator.destroy(popped); + } + + pool.free_size = new_size; + } + + /// Frees the connection pool and closes all connections within. This function is threadsafe. + /// + /// All future operations on the connection pool will deadlock. 
+ pub fn deinit(pool: *ConnectionPool, allocator: Allocator) void { + pool.mutex.lock(); + + var next = pool.free.first; + while (next) |node| { + defer allocator.destroy(node); + next = node.next; + + node.data.close(allocator); + } + + next = pool.used.first; + while (next) |node| { + defer allocator.destroy(node); + next = node.next; + + node.data.close(allocator); + } + + pool.* = undefined; + } +}; + +/// An interface to either a plain or TLS connection. +pub const Connection = struct { + stream: net.Stream, + /// undefined unless protocol is tls. + tls_client: if (!disable_tls) *tls23.Connection(net.Stream) else void, + + /// The protocol that this connection is using. + protocol: Protocol, + + /// The host that this connection is connected to. + host: []u8, + + /// The port that this connection is connected to. + port: u16, + + /// Whether this connection is proxied and is not directly connected. + proxied: bool = false, + + /// Whether this connection is closing when we're done with it. 
+ closing: bool = false, + + read_start: BufferSize = 0, + read_end: BufferSize = 0, + write_end: BufferSize = 0, + read_buf: [buffer_size]u8 = undefined, + write_buf: [buffer_size]u8 = undefined, + + pub const buffer_size = std.crypto.tls.max_ciphertext_record_len; + const BufferSize = std.math.IntFittingRange(0, buffer_size); + + pub const Protocol = enum { plain, tls }; + + pub fn async_readvDirect( + conn: *Connection, + buffers: []std.posix.iovec, + ctx: *Ctx, + comptime cbk: Cbk, + ) !void { + _ = conn; + + if (ctx.conn().protocol == .tls) { + if (disable_tls) unreachable; + + return ctx.conn().tls_client.async_readv(ctx.conn().stream, buffers, ctx, cbk); + } + + return ctx.stream().async_readv(buffers, ctx, cbk); + } + + pub fn readvDirectTls(conn: *Connection, buffers: []std.posix.iovec) ReadError!usize { + return conn.tls_client.readv(buffers) catch |err| { + // https://github.com/ziglang/zig/issues/2473 + if (mem.startsWith(u8, @errorName(err), "TlsAlert")) return error.TlsAlert; + + switch (err) { + error.TlsRecordOverflow, error.TlsBadRecordMac, error.TlsUnexpectedMessage => return error.TlsFailure, + error.ConnectionTimedOut => return error.ConnectionTimedOut, + error.ConnectionResetByPeer, error.BrokenPipe => return error.ConnectionResetByPeer, + else => return error.UnexpectedReadFailure, + } + }; + } + + pub fn readvDirect(conn: *Connection, buffers: []std.posix.iovec) ReadError!usize { + if (conn.protocol == .tls) { + if (disable_tls) unreachable; + + return conn.readvDirectTls(buffers); + } + + return conn.stream.readv(buffers) catch |err| switch (err) { + error.ConnectionTimedOut => return error.ConnectionTimedOut, + error.ConnectionResetByPeer, error.BrokenPipe => return error.ConnectionResetByPeer, + else => return error.UnexpectedReadFailure, + }; + } + + fn onFill(ctx: *Ctx, res: anyerror!void) anyerror!void { + ctx.alloc().free(ctx._iovecs); + res catch |err| return ctx.pop(err); + + // EOF + const nread = ctx.len(); + if (nread == 0) return 
ctx.pop(error.EndOfStream); + + // finished + ctx.conn().read_start = 0; + ctx.conn().read_end = @intCast(nread); + return ctx.pop({}); + } + + pub fn async_fill(conn: *Connection, ctx: *Ctx, comptime cbk: Cbk) !void { + if (conn.read_end != conn.read_start) return; + + ctx._iovecs = try ctx.alloc().alloc(std.posix.iovec, 1); + errdefer ctx.alloc().free(ctx._iovecs); + const iovecs = [1]std.posix.iovec{ + .{ .base = &conn.read_buf, .len = conn.read_buf.len }, + }; + @memcpy(ctx._iovecs, &iovecs); + + try ctx.push(cbk); + return conn.async_readvDirect(ctx._iovecs, ctx, onFill); + } + + /// Refills the read buffer with data from the connection. + pub fn fill(conn: *Connection) ReadError!void { + if (conn.read_end != conn.read_start) return; + + var iovecs = [1]std.posix.iovec{ + .{ .base = &conn.read_buf, .len = conn.read_buf.len }, + }; + const nread = try conn.readvDirect(&iovecs); + if (nread == 0) return error.EndOfStream; + conn.read_start = 0; + conn.read_end = @intCast(nread); + } + + /// Returns the current slice of buffered data. + pub fn peek(conn: *Connection) []const u8 { + return conn.read_buf[conn.read_start..conn.read_end]; + } + + /// Discards the given number of bytes from the read buffer. + pub fn drop(conn: *Connection, num: BufferSize) void { + conn.read_start += num; + } + + /// Reads data from the connection into the given buffer. 
+ pub fn read(conn: *Connection, buffer: []u8) ReadError!usize { + const available_read = conn.read_end - conn.read_start; + const available_buffer = buffer.len; + + if (available_read > available_buffer) { // partially read buffered data + @memcpy(buffer[0..available_buffer], conn.read_buf[conn.read_start..conn.read_end][0..available_buffer]); + conn.read_start += @intCast(available_buffer); + + return available_buffer; + } else if (available_read > 0) { // fully read buffered data + @memcpy(buffer[0..available_read], conn.read_buf[conn.read_start..conn.read_end]); + conn.read_start += available_read; + + return available_read; + } + + var iovecs = [2]std.posix.iovec{ + .{ .base = buffer.ptr, .len = buffer.len }, + .{ .base = &conn.read_buf, .len = conn.read_buf.len }, + }; + const nread = try conn.readvDirect(&iovecs); + + if (nread > buffer.len) { + conn.read_start = 0; + conn.read_end = @intCast(nread - buffer.len); + return buffer.len; + } + + return nread; + } + + pub const ReadError = error{ + TlsFailure, + TlsAlert, + ConnectionTimedOut, + ConnectionResetByPeer, + UnexpectedReadFailure, + EndOfStream, + }; + + pub const Reader = std.io.Reader(*Connection, ReadError, read); + + pub fn reader(conn: *Connection) Reader { + return Reader{ .context = conn }; + } + + pub fn writeAllDirectTls(conn: *Connection, buffer: []const u8) WriteError!void { + return conn.tls_client.writeAll(buffer) catch |err| switch (err) { + error.BrokenPipe, error.ConnectionResetByPeer => return error.ConnectionResetByPeer, + else => return error.UnexpectedWriteFailure, + }; + } + + fn onWriteAllDirect(ctx: *Ctx, res: anyerror!void) !void { + res catch |err| switch (err) { + error.BrokenPipe, + error.ConnectionResetByPeer, + => return ctx.pop(error.ConnectionResetByPeer), + else => return ctx.pop(error.UnexpectedWriteFailure), + }; + return ctx.pop({}); + } + + pub fn async_writeAllDirect( + conn: *Connection, + buffer: []const u8, + ctx: *Ctx, + comptime cbk: Cbk, + ) !void { + try 
ctx.push(cbk); + if (conn.protocol == .tls) { + if (disable_tls) unreachable; + + return conn.tls_client.async_writeAll(conn.stream, buffer, ctx, onWriteAllDirect); + } + + return conn.stream.async_writeAll(buffer, ctx, onWriteAllDirect); + } + + pub fn writeAllDirect(conn: *Connection, buffer: []const u8) WriteError!void { + if (conn.protocol == .tls) { + if (disable_tls) unreachable; + + return conn.writeAllDirectTls(buffer); + } + + return conn.stream.writeAll(buffer) catch |err| switch (err) { + error.BrokenPipe, error.ConnectionResetByPeer => return error.ConnectionResetByPeer, + else => return error.UnexpectedWriteFailure, + }; + } + + /// Writes the given buffer to the connection. + pub fn write(conn: *Connection, buffer: []const u8) WriteError!usize { + if (conn.write_buf.len - conn.write_end < buffer.len) { + try conn.flush(); + + if (buffer.len > conn.write_buf.len) { + try conn.writeAllDirect(buffer); + return buffer.len; + } + } + + @memcpy(conn.write_buf[conn.write_end..][0..buffer.len], buffer); + conn.write_end += @intCast(buffer.len); + + return buffer.len; + } + + /// Returns a buffer to be filled with exactly len bytes to write to the connection. + pub fn allocWriteBuffer(conn: *Connection, len: BufferSize) WriteError![]u8 { + if (conn.write_buf.len - conn.write_end < len) try conn.flush(); + defer conn.write_end += len; + return conn.write_buf[conn.write_end..][0..len]; + } + + fn onFlush(ctx: *Ctx, res: anyerror!void) !void { + res catch |err| return ctx.pop(err); + ctx.conn().write_end = 0; + return ctx.pop({}); + } + + pub fn async_flush(conn: *Connection, ctx: *Ctx, comptime cbk: Cbk) !void { + if (conn.write_end == 0) return error.WriteEmpty; + + try ctx.push(cbk); + try conn.async_writeAllDirect(conn.write_buf[0..conn.write_end], ctx, onFlush); + } + + /// Flushes the write buffer to the connection. 
+ pub fn flush(conn: *Connection) WriteError!void { + if (conn.write_end == 0) return; + + try conn.writeAllDirect(conn.write_buf[0..conn.write_end]); + conn.write_end = 0; + } + + pub const WriteError = error{ + ConnectionResetByPeer, + UnexpectedWriteFailure, + }; + + pub const Writer = std.io.Writer(*Connection, WriteError, write); + + pub fn writer(conn: *Connection) Writer { + return Writer{ .context = conn }; + } + + /// Closes the connection. + pub fn close(conn: *Connection, allocator: Allocator) void { + if (conn.protocol == .tls) { + if (disable_tls) unreachable; + + // try to cleanly close the TLS connection, for any server that cares. + conn.tls_client.close() catch {}; + allocator.destroy(conn.tls_client); + } + + conn.stream.close(); + allocator.free(conn.host); + } +}; + +/// The mode of transport for requests. +pub const RequestTransfer = union(enum) { + content_length: u64, + chunked: void, + none: void, +}; + +/// The decompressor for response messages. +pub const Compression = union(enum) { + pub const DeflateDecompressor = std.compress.zlib.Decompressor(Request.TransferReader); + pub const GzipDecompressor = std.compress.gzip.Decompressor(Request.TransferReader); + // https://github.com/ziglang/zig/issues/18937 + //pub const ZstdDecompressor = std.compress.zstd.DecompressStream(Request.TransferReader, .{}); + + deflate: DeflateDecompressor, + gzip: GzipDecompressor, + // https://github.com/ziglang/zig/issues/18937 + //zstd: ZstdDecompressor, + none: void, +}; + +/// A HTTP response originating from a server. +pub const Response = struct { + version: http.Version, + status: http.Status, + reason: []const u8, + + /// Points into the user-provided `server_header_buffer`. + location: ?[]const u8 = null, + /// Points into the user-provided `server_header_buffer`. + content_type: ?[]const u8 = null, + /// Points into the user-provided `server_header_buffer`. 
+ content_disposition: ?[]const u8 = null, + + keep_alive: bool, + + /// If present, the number of bytes in the response body. + content_length: ?u64 = null, + + /// If present, the transfer encoding of the response body, otherwise none. + transfer_encoding: http.TransferEncoding = .none, + + /// If present, the compression of the response body, otherwise identity (no compression). + transfer_compression: http.ContentEncoding = .identity, + + parser: proto.HeadersParser, + compression: Compression = .none, + + /// Whether the response body should be skipped. Any data read from the + /// response body will be discarded. + skip: bool = false, + + pub const ParseError = error{ + HttpHeadersInvalid, + HttpHeaderContinuationsUnsupported, + HttpTransferEncodingUnsupported, + HttpConnectionHeaderUnsupported, + InvalidContentLength, + CompressionUnsupported, + }; + + pub fn parse(res: *Response, bytes: []const u8) ParseError!void { + var it = mem.splitSequence(u8, bytes, "\r\n"); + + const first_line = it.next().?; + if (first_line.len < 12) { + return error.HttpHeadersInvalid; + } + + const version: http.Version = switch (int64(first_line[0..8])) { + int64("HTTP/1.0") => .@"HTTP/1.0", + int64("HTTP/1.1") => .@"HTTP/1.1", + else => return error.HttpHeadersInvalid, + }; + if (first_line[8] != ' ') return error.HttpHeadersInvalid; + const status: http.Status = @enumFromInt(parseInt3(first_line[9..12])); + const reason = mem.trimLeft(u8, first_line[12..], " "); + + res.version = version; + res.status = status; + res.reason = reason; + res.keep_alive = switch (version) { + .@"HTTP/1.0" => false, + .@"HTTP/1.1" => true, + }; + + while (it.next()) |line| { + if (line.len == 0) return; + switch (line[0]) { + ' ', '\t' => return error.HttpHeaderContinuationsUnsupported, + else => {}, + } + + var line_it = mem.splitScalar(u8, line, ':'); + const header_name = line_it.next().?; + const header_value = mem.trim(u8, line_it.rest(), " \t"); + if (header_name.len == 0) return 
error.HttpHeadersInvalid; + + if (std.ascii.eqlIgnoreCase(header_name, "connection")) { + res.keep_alive = !std.ascii.eqlIgnoreCase(header_value, "close"); + } else if (std.ascii.eqlIgnoreCase(header_name, "content-type")) { + res.content_type = header_value; + } else if (std.ascii.eqlIgnoreCase(header_name, "location")) { + res.location = header_value; + } else if (std.ascii.eqlIgnoreCase(header_name, "content-disposition")) { + res.content_disposition = header_value; + } else if (std.ascii.eqlIgnoreCase(header_name, "transfer-encoding")) { + // Transfer-Encoding: second, first + // Transfer-Encoding: deflate, chunked + var iter = mem.splitBackwardsScalar(u8, header_value, ','); + + const first = iter.first(); + const trimmed_first = mem.trim(u8, first, " "); + + var next: ?[]const u8 = first; + if (std.meta.stringToEnum(http.TransferEncoding, trimmed_first)) |transfer| { + if (res.transfer_encoding != .none) return error.HttpHeadersInvalid; // we already have a transfer encoding + res.transfer_encoding = transfer; + + next = iter.next(); + } + + if (next) |second| { + const trimmed_second = mem.trim(u8, second, " "); + + if (std.meta.stringToEnum(http.ContentEncoding, trimmed_second)) |transfer| { + if (res.transfer_compression != .identity) return error.HttpHeadersInvalid; // double compression is not supported + res.transfer_compression = transfer; + } else { + return error.HttpTransferEncodingUnsupported; + } + } + + if (iter.next()) |_| return error.HttpTransferEncodingUnsupported; + } else if (std.ascii.eqlIgnoreCase(header_name, "content-length")) { + const content_length = std.fmt.parseInt(u64, header_value, 10) catch return error.InvalidContentLength; + + if (res.content_length != null and res.content_length != content_length) return error.HttpHeadersInvalid; + + res.content_length = content_length; + } else if (std.ascii.eqlIgnoreCase(header_name, "content-encoding")) { + if (res.transfer_compression != .identity) return error.HttpHeadersInvalid; + + 
const trimmed = mem.trim(u8, header_value, " "); + + if (std.meta.stringToEnum(http.ContentEncoding, trimmed)) |ce| { + res.transfer_compression = ce; + } else { + return error.HttpTransferEncodingUnsupported; + } + } + } + return error.HttpHeadersInvalid; // missing empty line + } + + test parse { + const response_bytes = "HTTP/1.1 200 OK\r\n" ++ + "LOcation:url\r\n" ++ + "content-tYpe: text/plain\r\n" ++ + "content-disposition:attachment; filename=example.txt \r\n" ++ + "content-Length:10\r\n" ++ + "TRansfer-encoding:\tdeflate, chunked \r\n" ++ + "connectioN:\t keep-alive \r\n\r\n"; + + var header_buffer: [1024]u8 = undefined; + var res = Response{ + .status = undefined, + .reason = undefined, + .version = undefined, + .keep_alive = false, + .parser = proto.HeadersParser.init(&header_buffer), + }; + + @memcpy(header_buffer[0..response_bytes.len], response_bytes); + res.parser.header_bytes_len = response_bytes.len; + + try res.parse(response_bytes); + + try testing.expectEqual(.@"HTTP/1.1", res.version); + try testing.expectEqualStrings("OK", res.reason); + try testing.expectEqual(.ok, res.status); + + try testing.expectEqualStrings("url", res.location.?); + try testing.expectEqualStrings("text/plain", res.content_type.?); + try testing.expectEqualStrings("attachment; filename=example.txt", res.content_disposition.?); + + try testing.expectEqual(true, res.keep_alive); + try testing.expectEqual(10, res.content_length.?); + try testing.expectEqual(.chunked, res.transfer_encoding); + try testing.expectEqual(.deflate, res.transfer_compression); + } + + inline fn int64(array: *const [8]u8) u64 { + return @bitCast(array.*); + } + + fn parseInt3(text: *const [3]u8) u10 { + if (use_vectors) { + const nnn: @Vector(3, u8) = text.*; + const zero: @Vector(3, u8) = .{ '0', '0', '0' }; + const mmm: @Vector(3, u10) = .{ 100, 10, 1 }; + return @reduce(.Add, @as(@Vector(3, u10), nnn -% zero) *% mmm); + } + return std.fmt.parseInt(u10, text, 10) catch unreachable; + } + + test 
parseInt3 { + const expectEqual = testing.expectEqual; + try expectEqual(@as(u10, 0), parseInt3("000")); + try expectEqual(@as(u10, 418), parseInt3("418")); + try expectEqual(@as(u10, 999), parseInt3("999")); + } + + pub fn iterateHeaders(r: Response) http.HeaderIterator { + return http.HeaderIterator.init(r.parser.get()); + } + + test iterateHeaders { + const response_bytes = "HTTP/1.1 200 OK\r\n" ++ + "LOcation:url\r\n" ++ + "content-tYpe: text/plain\r\n" ++ + "content-disposition:attachment; filename=example.txt \r\n" ++ + "content-Length:10\r\n" ++ + "TRansfer-encoding:\tdeflate, chunked \r\n" ++ + "connectioN:\t keep-alive \r\n\r\n"; + + var header_buffer: [1024]u8 = undefined; + var res = Response{ + .status = undefined, + .reason = undefined, + .version = undefined, + .keep_alive = false, + .parser = proto.HeadersParser.init(&header_buffer), + }; + + @memcpy(header_buffer[0..response_bytes.len], response_bytes); + res.parser.header_bytes_len = response_bytes.len; + + var it = res.iterateHeaders(); + { + const header = it.next().?; + try testing.expectEqualStrings("LOcation", header.name); + try testing.expectEqualStrings("url", header.value); + try testing.expect(!it.is_trailer); + } + { + const header = it.next().?; + try testing.expectEqualStrings("content-tYpe", header.name); + try testing.expectEqualStrings("text/plain", header.value); + try testing.expect(!it.is_trailer); + } + { + const header = it.next().?; + try testing.expectEqualStrings("content-disposition", header.name); + try testing.expectEqualStrings("attachment; filename=example.txt", header.value); + try testing.expect(!it.is_trailer); + } + { + const header = it.next().?; + try testing.expectEqualStrings("content-Length", header.name); + try testing.expectEqualStrings("10", header.value); + try testing.expect(!it.is_trailer); + } + { + const header = it.next().?; + try testing.expectEqualStrings("TRansfer-encoding", header.name); + try testing.expectEqualStrings("deflate, chunked", 
header.value); + try testing.expect(!it.is_trailer); + } + { + const header = it.next().?; + try testing.expectEqualStrings("connectioN", header.name); + try testing.expectEqualStrings("keep-alive", header.value); + try testing.expect(!it.is_trailer); + } + try testing.expectEqual(null, it.next()); + } +}; + +/// A HTTP request that has been sent. +/// +/// Order of operations: open -> send[ -> write -> finish] -> wait -> read +pub const Request = struct { + uri: Uri = undefined, + client: *Client, + /// This is null when the connection is released. + connection: ?*Connection = null, + keep_alive: bool = undefined, + + method: http.Method = undefined, + version: http.Version = .@"HTTP/1.1", + transfer_encoding: RequestTransfer = undefined, + redirect_behavior: RedirectBehavior = undefined, + + /// Whether the request should handle a 100-continue response before sending the request body. + handle_continue: bool = undefined, + + /// The response associated with this request. + /// + /// This field is undefined until `wait` is called. + response: Response = undefined, + + /// Standard headers that have default, but overridable, behavior. + headers: Headers = undefined, + + /// These headers are kept including when following a redirect to a + /// different domain. + /// Externally-owned; must outlive the Request. + extra_headers: []const http.Header = undefined, + + /// These headers are stripped when following a redirect to a different + /// domain. + /// Externally-owned; must outlive the Request. 
+ privileged_headers: []const http.Header = undefined, + + pub fn init(client: *Client) Request { + return .{ + .client = client, + }; + } + + pub const Headers = struct { + host: Value = .default, + authorization: Value = .default, + user_agent: Value = .default, + connection: Value = .default, + accept_encoding: Value = .default, + content_type: Value = .default, + + pub const Value = union(enum) { + default, + omit, + override: []const u8, + }; + }; + + /// Any value other than `not_allowed` or `unhandled` means that integer represents + /// how many remaining redirects are allowed. + pub const RedirectBehavior = enum(u16) { + /// The next redirect will cause an error. + not_allowed = 0, + /// Redirects are passed to the client to analyze the redirect response + /// directly. + unhandled = std.math.maxInt(u16), + _, + + pub fn subtractOne(rb: *RedirectBehavior) void { + switch (rb.*) { + .not_allowed => unreachable, + .unhandled => unreachable, + _ => rb.* = @enumFromInt(@intFromEnum(rb.*) - 1), + } + } + + pub fn remaining(rb: RedirectBehavior) u16 { + assert(rb != .unhandled); + return @intFromEnum(rb); + } + }; + + /// Frees all resources associated with the request. + pub fn deinit(req: *Request) void { + if (req.connection) |connection| { + if (!req.response.parser.done) { + // If the response wasn't fully read, then we need to close the connection. 
+                connection.closing = true;
+            }
+            req.client.connection_pool.release(req.client.allocator, connection);
+        }
+        req.* = undefined;
+    }
+
+    // Redirect step 3: the redirected request headers were flushed; resume
+    // the response-wait workflow by refilling the connection buffer and
+    // re-entering header parsing.
+    fn onRedirectSend(ctx: *Ctx, res: anyerror!void) !void {
+        res catch |err| return ctx.pop(err);
+        // go back on check headers
+        ctx.conn().async_fill(ctx, onResponseHeaders) catch |err| return ctx.pop(err);
+    }
+
+    // Redirect step 2: the new connection is established; serialize the
+    // request headers into the write buffer and flush them asynchronously.
+    fn onRedirectConnect(ctx: *Ctx, res: anyerror!void) !void {
+        res catch |err| return ctx.pop(err);
+        // re-send request
+        // NOTE: prepareSend takes no arguments besides the receiver (see its
+        // definition below); the stray `.{}` argument was a compile error.
+        ctx.req.prepareSend() catch |err| return ctx.pop(err);
+        ctx.req.connection.?.async_flush(ctx, onRedirectSend) catch |err| return ctx.pop(err);
+    }
+
+    // async_redirect flow:
+    // connect -> setRequestConnection
+    // -> onRedirectConnect -> async_flush
+    // -> onRedirectSend -> async_fill
+    // -> go back on the wait workflow of the response
+    fn async_redirect(req: *Request, uri: Uri, ctx: *Ctx) !void {
+        try req.prepareRedirect();
+
+        // Re-use the tail of the header buffer to hold the validated URI
+        // components; the parser keeps whatever FixedBufferAllocator consumed.
+        var server_header = std.heap.FixedBufferAllocator.init(req.response.parser.header_bytes_buffer);
+        defer req.response.parser.header_bytes_buffer = server_header.buffer[server_header.end_index..];
+
+        const protocol, const valid_uri = try validateUri(uri, server_header.allocator());
+
+        // Privileged headers (e.g. authorization) survive the redirect only
+        // when the scheme matches and the new host equals, or is a subdomain
+        // of, the previous host.
+        const new_host = valid_uri.host.?.raw;
+        const prev_host = req.uri.host.?.raw;
+        const keep_privileged_headers =
+            std.ascii.eqlIgnoreCase(valid_uri.scheme, req.uri.scheme) and
+            std.ascii.endsWithIgnoreCase(new_host, prev_host) and
+            (new_host.len == prev_host.len or new_host[new_host.len - prev_host.len - 1] == '.');
+        if (!keep_privileged_headers) {
+            // When redirecting to a different domain, strip privileged headers.
+ req.privileged_headers = &.{}; + } + + // create a new connection for the redirected URI + ctx.data.conn = try req.client.allocator.create(Connection); + ctx.data.conn.* = .{ + .stream = undefined, + .tls_client = undefined, + .protocol = undefined, + .host = undefined, + .port = undefined, + }; + req.uri = valid_uri; + return req.client.async_connect(new_host, uriPort(valid_uri, protocol), protocol, ctx, setRequestConnection); + } + + // This function must deallocate all resources associated with the request, + // or keep those which will be used. + // This needs to be kept in sync with deinit and request. + fn redirect(req: *Request, uri: Uri) !void { + try req.prepareRedirect(); + + var server_header = std.heap.FixedBufferAllocator.init(req.response.parser.header_bytes_buffer); + defer req.response.parser.header_bytes_buffer = server_header.buffer[server_header.end_index..]; + + const protocol, const valid_uri = try validateUri(uri, server_header.allocator()); + + const new_host = valid_uri.host.?.raw; + const prev_host = req.uri.host.?.raw; + const keep_privileged_headers = + std.ascii.eqlIgnoreCase(valid_uri.scheme, req.uri.scheme) and + std.ascii.endsWithIgnoreCase(new_host, prev_host) and + (new_host.len == prev_host.len or new_host[new_host.len - prev_host.len - 1] == '.'); + if (!keep_privileged_headers) { + // When redirecting to a different domain, strip privileged headers. + req.privileged_headers = &.{}; + } + + req.connection = try req.client.connect(new_host, uriPort(valid_uri, protocol), protocol); + req.uri = valid_uri; + } + fn prepareRedirect(req: *Request) !void { + assert(req.response.parser.done); + + req.client.connection_pool.release(req.client.allocator, req.connection.?); + req.connection = null; + + if (switch (req.response.status) { + .see_other => true, + .moved_permanently, .found => req.method == .POST, + else => false, + }) { + // A redirect to a GET must change the method and remove the body. 
+            req.method = .GET;
+            req.transfer_encoding = .none;
+            req.headers.content_type = .omit;
+        }
+
+        if (req.transfer_encoding != .none) {
+            // The request body has already been sent. The request is
+            // still in a valid state, but the redirect must be handled
+            // manually.
+            return error.RedirectRequiresResend;
+        }
+
+        req.redirect_behavior.subtractOne();
+        req.response.parser.reset();
+
+        req.response = .{
+            .version = undefined,
+            .status = undefined,
+            .reason = undefined,
+            .keep_alive = undefined,
+            .parser = req.response.parser,
+        };
+    }
+
+    pub const SendError = Connection.WriteError || error{ InvalidContentLength, UnsupportedTransferEncoding };
+
+    /// Async variant of `send`: serializes the request head into the
+    /// connection's write buffer, then flushes it asynchronously, invoking
+    /// `cbk` on completion.
+    pub fn async_send(req: *Request, ctx: *Ctx, comptime cbk: Cbk) !void {
+        try req.prepareSend();
+        try req.connection.?.async_flush(ctx, cbk);
+    }
+
+    /// Send the HTTP request headers to the server.
+    pub fn send(req: *Request) SendError!void {
+        try req.prepareSend();
+        try req.connection.?.flush();
+    }
+
+    /// Serializes the request line and headers into the connection's write
+    /// buffer. Shared by the sync and async send paths; does not flush.
+    fn prepareSend(req: *Request) SendError!void {
+        // A transfer encoding on a bodyless method (GET, HEAD, ...) is a
+        // caller error. (The condition was accidentally duplicated here;
+        // one check suffices.)
+        if (!req.method.requestHasBody() and req.transfer_encoding != .none)
+            return error.UnsupportedTransferEncoding;
+
+        const connection = req.connection.?;
+        const w = connection.writer();
+
+        try req.method.write(w);
+        try w.writeByte(' ');
+
+        if (req.method == .CONNECT) {
+            try req.uri.writeToStream(.{ .authority = true }, w);
+        } else {
+            // Proxied requests use the absolute form of the request target.
+            try req.uri.writeToStream(.{
+                .scheme = connection.proxied,
+                .authentication = connection.proxied,
+                .authority = connection.proxied,
+                .path = true,
+                .query = true,
+            }, w);
+        }
+        try w.writeByte(' ');
+        try w.writeAll(@tagName(req.version));
+        try w.writeAll("\r\n");
+
+        if (try emitOverridableHeader("host: ", req.headers.host, w)) {
+            try w.writeAll("host: ");
+            try req.uri.writeToStream(.{ .authority = true }, w);
+            try w.writeAll("\r\n");
+        }
+
+        if (try emitOverridableHeader("authorization: ", req.headers.authorization, w)) {
+            if (req.uri.user !=
null or req.uri.password != null) { + try w.writeAll("authorization: "); + const authorization = try connection.allocWriteBuffer( + @intCast(basic_authorization.valueLengthFromUri(req.uri)), + ); + assert(basic_authorization.value(req.uri, authorization).len == authorization.len); + try w.writeAll("\r\n"); + } + } + + if (try emitOverridableHeader("user-agent: ", req.headers.user_agent, w)) { + try w.writeAll("user-agent: zig/"); + try w.writeAll(builtin.zig_version_string); + try w.writeAll(" (std.http)\r\n"); + } + + if (try emitOverridableHeader("connection: ", req.headers.connection, w)) { + if (req.keep_alive) { + try w.writeAll("connection: keep-alive\r\n"); + } else { + try w.writeAll("connection: close\r\n"); + } + } + + if (try emitOverridableHeader("accept-encoding: ", req.headers.accept_encoding, w)) { + // https://github.com/ziglang/zig/issues/18937 + //try w.writeAll("accept-encoding: gzip, deflate, zstd\r\n"); + try w.writeAll("accept-encoding: gzip, deflate\r\n"); + } + + switch (req.transfer_encoding) { + .chunked => try w.writeAll("transfer-encoding: chunked\r\n"), + .content_length => |len| try w.print("content-length: {d}\r\n", .{len}), + .none => {}, + } + + if (try emitOverridableHeader("content-type: ", req.headers.content_type, w)) { + // The default is to omit content-type if not provided because + // "application/octet-stream" is redundant. 
+ } + + for (req.extra_headers) |header| { + assert(header.name.len != 0); + + try w.writeAll(header.name); + try w.writeAll(": "); + try w.writeAll(header.value); + try w.writeAll("\r\n"); + } + + if (connection.proxied) proxy: { + const proxy = switch (connection.protocol) { + .plain => req.client.http_proxy, + .tls => req.client.https_proxy, + } orelse break :proxy; + + const authorization = proxy.authorization orelse break :proxy; + try w.writeAll("proxy-authorization: "); + try w.writeAll(authorization); + try w.writeAll("\r\n"); + } + + try w.writeAll("\r\n"); + } + + /// Returns true if the default behavior is required, otherwise handles + /// writing (or not writing) the header. + fn emitOverridableHeader(prefix: []const u8, v: Headers.Value, w: anytype) !bool { + switch (v) { + .default => return true, + .omit => return false, + .override => |x| { + try w.writeAll(prefix); + try w.writeAll(x); + try w.writeAll("\r\n"); + return false; + }, + } + } + + const TransferReadError = Connection.ReadError || proto.HeadersParser.ReadError; + + const TransferReader = std.io.Reader(*Request, TransferReadError, transferRead); + + fn transferReader(req: *Request) TransferReader { + return .{ .context = req }; + } + + fn transferRead(req: *Request, buf: []u8) TransferReadError!usize { + if (req.response.parser.done) return 0; + + var index: usize = 0; + while (index == 0) { + const amt = try req.response.parser.read(req.connection.?, buf[index..], req.response.skip); + if (amt == 0 and req.response.parser.done) break; + index += amt; + } + + return index; + } + + pub const WaitError = RequestError || SendError || TransferReadError || + proto.HeadersParser.CheckCompleteHeadError || Response.ParseError || + error{ // TODO: file zig fmt issue for this bad indentation + TooManyHttpRedirects, + RedirectRequiresResend, + HttpRedirectLocationMissing, + HttpRedirectLocationInvalid, + CompressionInitializationFailed, + CompressionUnsupported, + }; + + pub fn async_wait(_: 
*Request, ctx: *Ctx, comptime cbk: Cbk) !void { + try ctx.push(cbk); + return ctx.conn().async_fill(ctx, onResponseHeaders); + } + + /// Waits for a response from the server and parses any headers that are sent. + /// This function will block until the final response is received. + /// + /// If handling redirects and the request has no payload, then this + /// function will automatically follow redirects. If a request payload is + /// present, then this function will error with + /// error.RedirectRequiresResend. + /// + /// Must be called after `send` and, if any data was written to the request + /// body, then also after `finish`. + pub fn wait(req: *Request) WaitError!void { + while (true) { + // This while loop is for handling redirects, which means the request's + // connection may be different than the previous iteration. However, it + // is still guaranteed to be non-null with each iteration of this loop. + const connection = req.connection.?; + + while (true) { // read headers + try connection.fill(); + + const nchecked = try req.response.parser.checkCompleteHead(connection.peek()); + connection.drop(@intCast(nchecked)); + + if (req.response.parser.state.isContent()) break; + } + + try req.response.parse(req.response.parser.get()); + + if (req.response.status == .@"continue") { + // We're done parsing the continue response; reset to prepare + // for the real response. 
+ req.response.parser.done = true; + req.response.parser.reset(); + + if (req.handle_continue) + continue; + + return; // we're not handling the 100-continue + } + + // we're switching protocols, so this connection is no longer doing http + if (req.method == .CONNECT and req.response.status.class() == .success) { + connection.closing = false; + req.response.parser.done = true; + return; // the connection is not HTTP past this point + } + + connection.closing = !req.response.keep_alive or !req.keep_alive; + + // Any response to a HEAD request and any response with a 1xx + // (Informational), 204 (No Content), or 304 (Not Modified) status + // code is always terminated by the first empty line after the + // header fields, regardless of the header fields present in the + // message. + if (req.method == .HEAD or req.response.status.class() == .informational or + req.response.status == .no_content or req.response.status == .not_modified) + { + req.response.parser.done = true; + return; // The response is empty; no further setup or redirection is necessary. + } + + switch (req.response.transfer_encoding) { + .none => { + if (req.response.content_length) |cl| { + req.response.parser.next_chunk_length = cl; + + if (cl == 0) req.response.parser.done = true; + } else { + // read until the connection is closed + req.response.parser.next_chunk_length = std.math.maxInt(u64); + } + }, + .chunked => { + req.response.parser.next_chunk_length = 0; + req.response.parser.state = .chunk_head_size; + }, + } + + if (req.response.status.class() == .redirect and req.redirect_behavior != .unhandled) { + // skip the body of the redirect response, this will at least + // leave the connection in a known good state. 
+ req.response.skip = true; + assert(try req.transferRead(&.{}) == 0); // we're skipping, no buffer is necessary + + if (req.redirect_behavior == .not_allowed) return error.TooManyHttpRedirects; + + const location = req.response.location orelse + return error.HttpRedirectLocationMissing; + + // This mutates the beginning of header_bytes_buffer and uses that + // for the backing memory of the returned Uri. + try req.redirect(req.uri.resolve_inplace( + location, + &req.response.parser.header_bytes_buffer, + ) catch |err| switch (err) { + error.UnexpectedCharacter, + error.InvalidFormat, + error.InvalidPort, + => return error.HttpRedirectLocationInvalid, + error.NoSpaceLeft => return error.HttpHeadersOversize, + }); + try req.send(); + } else { + req.response.skip = false; + if (!req.response.parser.done) { + switch (req.response.transfer_compression) { + .identity => req.response.compression = .none, + .compress, .@"x-compress" => return error.CompressionUnsupported, + .deflate => req.response.compression = .{ + .deflate = std.compress.zlib.decompressor(req.transferReader()), + }, + .gzip, .@"x-gzip" => req.response.compression = .{ + .gzip = std.compress.gzip.decompressor(req.transferReader()), + }, + // https://github.com/ziglang/zig/issues/18937 + //.zstd => req.response.compression = .{ + // .zstd = std.compress.zstd.decompressStream(req.client.allocator, req.transferReader()), + //}, + .zstd => return error.CompressionUnsupported, + } + } + + break; + } + } + } + + pub const ReadError = TransferReadError || proto.HeadersParser.CheckCompleteHeadError || + error{ DecompressionFailure, InvalidTrailers }; + + pub const Reader = std.io.Reader(*Request, ReadError, read); + + pub fn reader(req: *Request) Reader { + return .{ .context = req }; + } + + /// Reads data from the response body. Must be called after `wait`. 
+ pub fn read(req: *Request, buffer: []u8) ReadError!usize { + const out_index = switch (req.response.compression) { + .deflate => |*deflate| deflate.read(buffer) catch return error.DecompressionFailure, + .gzip => |*gzip| gzip.read(buffer) catch return error.DecompressionFailure, + // https://github.com/ziglang/zig/issues/18937 + //.zstd => |*zstd| zstd.read(buffer) catch return error.DecompressionFailure, + else => try req.transferRead(buffer), + }; + if (out_index > 0) return out_index; + + while (!req.response.parser.state.isContent()) { // read trailing headers + try req.connection.?.fill(); + + const nchecked = try req.response.parser.checkCompleteHead(req.connection.?.peek()); + req.connection.?.drop(@intCast(nchecked)); + } + + return 0; + } + + /// Reads data from the response body. Must be called after `wait`. + pub fn readAll(req: *Request, buffer: []u8) !usize { + var index: usize = 0; + while (index < buffer.len) { + const amt = try read(req, buffer[index..]); + if (amt == 0) break; + index += amt; + } + return index; + } + + pub const WriteError = Connection.WriteError || error{ NotWriteable, MessageTooLong }; + + pub const Writer = std.io.Writer(*Request, WriteError, write); + + pub fn writer(req: *Request) Writer { + return .{ .context = req }; + } + + /// Write `bytes` to the server. The `transfer_encoding` field determines how data will be sent. + /// Must be called after `send` and before `finish`. 
    pub fn write(req: *Request, bytes: []const u8) WriteError!usize {
        switch (req.transfer_encoding) {
            .chunked => {
                // One HTTP/1.1 chunk: "<hex length>\r\n<payload>\r\n". A
                // zero-length write is skipped, since a "0" chunk would
                // terminate the body (that is emitted by `finish`).
                if (bytes.len > 0) {
                    try req.connection.?.writer().print("{x}\r\n", .{bytes.len});
                    try req.connection.?.writer().writeAll(bytes);
                    try req.connection.?.writer().writeAll("\r\n");
                }

                return bytes.len;
            },
            .content_length => |*len| {
                // Writing more than the declared Content-Length is an error.
                if (len.* < bytes.len) return error.MessageTooLong;

                const amt = try req.connection.?.write(bytes);
                len.* -= amt;
                return amt;
            },
            .none => return error.NotWriteable,
        }
    }

    /// Write `bytes` to the server. The `transfer_encoding` field determines how data will be sent.
    /// Must be called after `send` and before `finish`.
    pub fn writeAll(req: *Request, bytes: []const u8) WriteError!void {
        var index: usize = 0;
        while (index < bytes.len) {
            index += try write(req, bytes[index..]);
        }
    }

    /// Async variant of `writeAll`: hands `buf` to the connection, which is
    /// expected to invoke `cbk` when done — confirm against
    /// `Connection.async_writeAllDirect`.
    pub fn async_writeAll(req: *Request, buf: []const u8, ctx: *Ctx, comptime cbk: Cbk) !void {
        try req.connection.?.async_writeAllDirect(buf, ctx, cbk);
    }

    pub const FinishError = WriteError || error{MessageNotCompleted};

    /// Async variant of `finish`: terminates the request body, then flushes
    /// the connection, reporting completion (or the flush error) via `cbk`.
    pub fn async_finish(req: *Request, ctx: *Ctx, comptime cbk: Cbk) !void {
        try req.common_finish();
        req.connection.?.async_flush(ctx, cbk) catch |err| switch (err) {
            // Nothing buffered to flush: complete immediately.
            error.WriteEmpty => return cbk(ctx, {}),
            else => return cbk(ctx, err),
        };
    }

    /// Finish the body of a request. This notifies the server that you have no more data to send.
    /// Must be called after `send`.
    pub fn finish(req: *Request) FinishError!void {
        try req.common_finish();
        try req.connection.?.flush();
    }

    // Shared body-termination step used by both `finish` and `async_finish`:
    // emits the final chunk for chunked encoding, verifies the declared
    // Content-Length was fully written.
    fn common_finish(req: *Request) FinishError!void {
        switch (req.transfer_encoding) {
            .chunked => try req.connection.?.writer().writeAll("0\r\n\r\n"),
            .content_length => |len| if (len != 0) return error.MessageNotCompleted,
            .none => {},
        }
    }

    // Async continuation: parses whatever header bytes are buffered, refilling
    // the connection until the head is complete, then moves on to `onResponse`.
    fn onResponseHeaders(ctx: *Ctx, res: anyerror!void) !void {
        res catch |err| return ctx.pop(err);
        const done = ctx.req.parseResponseHeaders() catch |err| return ctx.pop(err);
        // if read of the headers is not done, continue
        if (!done) return ctx.conn().async_fill(ctx, onResponseHeaders);
        // if read of the headers is done, go read the response
        return onResponse(ctx, {});
    }

    // Feeds buffered bytes to the header parser; returns true once the parser
    // has reached the message body.
    fn parseResponseHeaders(req: *Request) !bool {
        const nchecked = try req.response.parser.checkCompleteHead(req.connection.?.peek());
        req.connection.?.drop(@intCast(nchecked));

        if (req.response.parser.state.isContent()) return true;
        return false;
    }

    // Async continuation: interprets the parsed head; follows a redirect,
    // keeps filling if more data is needed, or pops the user callback.
    fn onResponse(ctx: *Ctx, res: anyerror!void) !void {
        res catch |err| return ctx.pop(err);
        const ret = ctx.req.parseResponse() catch |err| return ctx.pop(err);
        if (ret.redirect_uri) |uri| {
            ctx.req.async_redirect(uri, ctx) catch |err| return ctx.pop(err);
            return;
        }
        // if read of the response is not done, continue
        if (!ret.done) return ctx.conn().async_fill(ctx, onResponse);
        // if read of the response is done, go execute the provided callback
        return ctx.pop({});
    }

    // Result of `parseResponse`: either a redirect target to follow, or a flag
    // saying whether head processing is complete (`done = false` means more
    // bytes must be read, e.g. after a 100 Continue).
    const WaitRedirectsReturn = struct {
        redirect_uri: ?Uri = null,
        done: bool = true,
    };

    // Synchronous core of response-head handling shared with the async path:
    // parses the head, configures body framing and decompression, and detects
    // redirects.
    fn parseResponse(req: *Request) WaitError!WaitRedirectsReturn {
        try req.response.parse(req.response.parser.get());

        if (req.response.status == .@"continue") {
            // We're done parsing the continue response; reset to prepare
            // for the real response.
            req.response.parser.done = true;
            req.response.parser.reset();

            if (req.handle_continue) return .{ .done = false };

            return .{ .done = true };
        }

        // we're switching protocols, so this connection is no longer doing http
        if (req.method == .CONNECT and req.response.status.class() == .success) {
            req.connection.?.closing = false;
            req.response.parser.done = true;
            return .{ .done = true }; // the connection is not HTTP past this point
        }

        req.connection.?.closing = !req.response.keep_alive or !req.keep_alive;

        // Any response to a HEAD request and any response with a 1xx
        // (Informational), 204 (No Content), or 304 (Not Modified) status
        // code is always terminated by the first empty line after the
        // header fields, regardless of the header fields present in the
        // message.
        if (req.method == .HEAD or req.response.status.class() == .informational or
            req.response.status == .no_content or req.response.status == .not_modified)
        {
            req.response.parser.done = true;
            return .{ .done = true }; // The response is empty; no further setup or redirection is necessary.
        }

        switch (req.response.transfer_encoding) {
            .none => {
                if (req.response.content_length) |cl| {
                    req.response.parser.next_chunk_length = cl;

                    if (cl == 0) req.response.parser.done = true;
                } else {
                    // read until the connection is closed
                    req.response.parser.next_chunk_length = std.math.maxInt(u64);
                }
            },
            .chunked => {
                req.response.parser.next_chunk_length = 0;
                req.response.parser.state = .chunk_head_size;
            },
        }

        if (req.response.status.class() == .redirect and req.redirect_behavior != .unhandled) {
            // skip the body of the redirect response, this will at least
            // leave the connection in a known good state.
            req.response.skip = true;
            assert(try req.transferRead(&.{}) == 0); // we're skipping, no buffer is necessary

            if (req.redirect_behavior == .not_allowed) return error.TooManyHttpRedirects;

            const location = req.response.location orelse
                return error.HttpRedirectLocationMissing;

            // This mutates the beginning of header_bytes_buffer and uses that
            // for the backing memory of the returned Uri.
            try req.redirect(req.uri.resolve_inplace(
                location,
                &req.response.parser.header_bytes_buffer,
            ) catch |err| switch (err) {
                error.UnexpectedCharacter,
                error.InvalidFormat,
                error.InvalidPort,
                => return error.HttpRedirectLocationInvalid,
                error.NoSpaceLeft => return error.HttpHeadersOversize,
            });

            return .{ .redirect_uri = req.uri };
        } else {
            req.response.skip = false;
            if (!req.response.parser.done) {
                switch (req.response.transfer_compression) {
                    .identity => req.response.compression = .none,
                    .compress, .@"x-compress" => return error.CompressionUnsupported,
                    .deflate => req.response.compression = .{
                        .deflate = std.compress.zlib.decompressor(req.transferReader()),
                    },
                    .gzip, .@"x-gzip" => req.response.compression = .{
                        .gzip = std.compress.gzip.decompressor(req.transferReader()),
                    },
                    // https://github.com/ziglang/zig/issues/18937
                    //.zstd => req.response.compression = .{
                    //    .zstd = std.compress.zstd.decompressStream(req.client.allocator, req.transferReader()),
                    //},
                    .zstd => return error.CompressionUnsupported,
                }
            }

            return .{ .done = true };
        }
        // NOTE(review): both branches above return, so this statement is dead
        // code — consider removing it.
        return .{ .done = false };
    }
};

pub const Proxy = struct {
    protocol: Connection.Protocol,
    host: []const u8,
    authorization: ?[]const u8,
    port: u16,
    supports_connect: bool,
};

/// Release all associated resources with the client.
///
/// All pending requests must be de-initialized and all active connections released
/// before calling this function.
pub fn deinit(client: *Client) void {
    assert(client.connection_pool.used.first == null); // There are still active requests.

    client.connection_pool.deinit(client.allocator);

    if (!disable_tls)
        client.ca_bundle.deinit(client.allocator);

    client.* = undefined;
}

/// Populates `http_proxy` and `https_proxy` via standard proxy environment variables.
/// Asserts the client has no active connections.
/// Uses `arena` for a few small allocations that must outlive the client, or
/// at least until those fields are set to different values.
pub fn initDefaultProxies(client: *Client, arena: Allocator) !void {
    // Prevent any new connections from being created.
    client.connection_pool.mutex.lock();
    defer client.connection_pool.mutex.unlock();

    assert(client.connection_pool.used.first == null); // There are active requests.

    if (client.http_proxy == null) {
        client.http_proxy = try createProxyFromEnvVar(arena, &.{
            "http_proxy", "HTTP_PROXY", "all_proxy", "ALL_PROXY",
        });
    }

    if (client.https_proxy == null) {
        client.https_proxy = try createProxyFromEnvVar(arena, &.{
            "https_proxy", "HTTPS_PROXY", "all_proxy", "ALL_PROXY",
        });
    }
}

/// Reads the first set environment variable in `env_var_names` and builds a
/// `Proxy` from it. Returns null when no variable is set or the URI scheme is
/// unsupported. All allocations come from `arena` and outlive the returned
/// pointer.
fn createProxyFromEnvVar(arena: Allocator, env_var_names: []const []const u8) !?*Proxy {
    const content = for (env_var_names) |name| {
        break std.process.getEnvVarOwned(arena, name) catch |err| switch (err) {
            error.EnvironmentVariableNotFound => continue,
            else => |e| return e,
        };
    } else return null;

    // A bare "host:port" value is treated as an http:// URI.
    const uri = Uri.parse(content) catch try Uri.parseAfterScheme("http", content);
    const protocol, const valid_uri = validateUri(uri, arena) catch |err| switch (err) {
        error.UnsupportedUriScheme => return null,
        error.UriMissingHost => return error.HttpProxyMissingHost,
        error.OutOfMemory => |e| return e,
    };

    const authorization: ?[]const u8 = if (valid_uri.user != null or valid_uri.password != null) a: {
        const authorization = try arena.alloc(u8, basic_authorization.valueLengthFromUri(valid_uri));
        assert(basic_authorization.value(valid_uri, authorization).len == authorization.len);
        break :a authorization;
    } else null;

    const proxy = try arena.create(Proxy);
    proxy.* = .{
        .protocol = protocol,
        .host = valid_uri.host.?.raw,
        .authorization = authorization,
        .port = uriPort(valid_uri, protocol),
        .supports_connect = true,
    };
    return proxy;
}

/// Helpers for building RFC 7617 `Authorization: Basic ...` header values.
pub const basic_authorization = struct {
    pub const max_user_len = 255;
    pub const max_password_len = 255;
    pub const max_value_len = valueLength(max_user_len, max_password_len);

    const prefix = "Basic ";

    /// Length of the full header value for a "user:password" pair of the
    /// given component lengths (prefix + base64 of "user:password").
    pub fn valueLength(user_len: usize, password_len: usize) usize {
        return prefix.len + std.base64.standard.Encoder.calcSize(user_len + 1 + password_len);
    }

    /// Length of the header value for the credentials embedded in `uri`.
    pub fn valueLengthFromUri(uri: Uri) usize {
        var stream = std.io.countingWriter(std.io.null_writer);
        // BUGFIX: this function returns a plain `usize`, so `try` cannot be
        // used here; the null_writer error set is empty, making the error
        // branch unreachable.
        stream.writer().print("{user}", .{uri.user orelse Uri.Component.empty}) catch unreachable;
        const user_len = stream.bytes_written;
        stream.bytes_written = 0;
        stream.writer().print("{password}", .{uri.password orelse Uri.Component.empty}) catch unreachable;
        const password_len = stream.bytes_written;
        return valueLength(@intCast(user_len), @intCast(password_len));
    }

    /// Writes the header value for the credentials in `uri` into `out` and
    /// returns the written slice. `out` must be at least
    /// `valueLengthFromUri(uri)` bytes.
    pub fn value(uri: Uri, out: []u8) []u8 {
        var buf: [max_user_len + ":".len + max_password_len]u8 = undefined;
        var stream = std.io.fixedBufferStream(&buf);
        stream.writer().print("{user}", .{uri.user orelse Uri.Component.empty}) catch
            unreachable;
        assert(stream.pos <= max_user_len);
        stream.writer().print(":{password}", .{uri.password orelse Uri.Component.empty}) catch
            unreachable;

        @memcpy(out[0..prefix.len], prefix);
        const base64 = std.base64.standard.Encoder.encode(out[prefix.len..], stream.getWritten());
        return out[0 .. prefix.len + base64.len];
    }
};

pub const ConnectTcpError = Allocator.Error || error{
    ConnectionRefused,
    NetworkUnreachable,
    ConnectionTimedOut,
    ConnectionResetByPeer,
    TemporaryNameServerFailure,
    NameServerFailure,
    UnknownHostName,
    HostLacksNetworkAddresses,
    UnexpectedConnectFailure,
    TlsInitializationFailed,
};

// Async continuation run once ctx.data.conn.stream has been connected:
// performs the TLS handshake if needed, copies the temporary connection into
// a pool node, and registers it as used.
// requires ctx.data.stream to be set
fn setConnection(ctx: *Ctx, res: anyerror!void) !void {
    // check stream
    errdefer ctx.data.conn.stream.close();
    res catch |e| {
        // Map connect-phase failures onto the ConnectTcpError set before
        // handing them back to the stacked callback.
        switch (e) {
            error.ConnectionRefused,
            error.NetworkUnreachable,
            error.ConnectionTimedOut,
            error.ConnectionResetByPeer,
            error.TemporaryNameServerFailure,
            error.NameServerFailure,
            error.UnknownHostName,
            error.HostLacksNetworkAddresses,
            => return ctx.pop(e),
            else => return ctx.pop(error.UnexpectedConnectFailure),
        }
    };

    if (ctx.data.conn.protocol == .tls) {
        if (disable_tls) unreachable;

        ctx.data.conn.tls_client = try ctx.alloc().create(tls23.Connection(net.Stream));
        errdefer ctx.alloc().destroy(ctx.data.conn.tls_client);

        // TODO tls23.client does a handshake to pick a cipher.
        ctx.data.conn.tls_client.* = tls23.client(ctx.data.conn.stream, .{
            .host = ctx.data.conn.host,
            .root_ca = .{ .bundle = ctx.req.client.ca_bundle },
        }) catch return error.TlsInitializationFailed;
    }

    // add connection node in pool
    const node = ctx.req.client.allocator.create(ConnectionPool.Node) catch |e| return ctx.pop(e);
    errdefer ctx.req.client.allocator.destroy(node);
    // NOTE: we cannot use the ctx.data.conn pointer as a node's connection
    // data; we need to copy its value and use this reference for the
    // connection.
    node.* = .{
        .data = .{
            .stream = ctx.data.conn.stream,
            .tls_client = ctx.data.conn.tls_client,
            .protocol = ctx.data.conn.protocol,
            .host = ctx.data.conn.host,
            .port = ctx.data.conn.port,
        },
    };
    // remove old pointer, now useless
    const old_conn = ctx.data.conn;
    defer ctx.req.client.allocator.destroy(old_conn);

    ctx.req.client.connection_pool.addUsed(node);
    ctx.data.conn = &node.data;

    return ctx.pop({});
}

/// Connect to `host:port` using the specified protocol. This will reuse a connection if one is already open.
///
/// This function is threadsafe.
pub fn connectTcp(client: *Client, host: []const u8, port: u16, protocol: Connection.Protocol) ConnectTcpError!*Connection {
    // Reuse an existing pooled connection when one matches.
    if (client.connection_pool.findConnection(.{
        .host = host,
        .port = port,
        .protocol = protocol,
    })) |node| return node;

    if (disable_tls and protocol == .tls)
        return error.TlsInitializationFailed;

    const conn = try client.allocator.create(ConnectionPool.Node);
    errdefer client.allocator.destroy(conn);
    conn.* = .{ .data = undefined };

    // Normalize every connect failure into the ConnectTcpError set.
    const stream = net.tcpConnectToHost(client.allocator, host, port) catch |err| switch (err) {
        error.ConnectionRefused => return error.ConnectionRefused,
        error.NetworkUnreachable => return error.NetworkUnreachable,
        error.ConnectionTimedOut => return error.ConnectionTimedOut,
        error.ConnectionResetByPeer => return error.ConnectionResetByPeer,
        error.TemporaryNameServerFailure => return error.TemporaryNameServerFailure,
        error.NameServerFailure => return error.NameServerFailure,
        error.UnknownHostName => return error.UnknownHostName,
        error.HostLacksNetworkAddresses => return error.HostLacksNetworkAddresses,
        else => return error.UnexpectedConnectFailure,
    };
    errdefer stream.close();

    conn.data = .{
        .stream = stream,
        .tls_client = undefined,

        .protocol = protocol,
        .host = try client.allocator.dupe(u8, host),
        .port = port,
    };
    errdefer client.allocator.free(conn.data.host);

    if (protocol == .tls) {
        if (disable_tls) unreachable;

        conn.data.tls_client = try client.allocator.create(tls23.Connection(net.Stream));
        errdefer client.allocator.destroy(conn.data.tls_client);

        // TODO tls23.client does a handshake to pick a cipher.
        conn.data.tls_client.* = tls23.client(stream, .{
            .host = host,
            .root_ca = .{ .bundle = client.ca_bundle },
        }) catch return error.TlsInitializationFailed;
    }

    client.connection_pool.addUsed(conn);

    return &conn.data;
}

// Async variant of `connectTcp`: pushes `cbk` on the continuation stack and
// either reuses a pooled connection immediately or starts an async TCP
// connect that completes in `setConnection`.
pub fn async_connectTcp(
    client: *Client,
    host: []const u8,
    port: u16,
    protocol: Connection.Protocol,
    ctx: *Ctx,
    comptime cbk: Cbk,
) !void {
    try ctx.push(cbk);
    if (ctx.req.client.connection_pool.findConnection(.{
        .host = host,
        .port = port,
        .protocol = protocol,
    })) |conn| {
        ctx.req.connection = conn;
        return ctx.pop({});
    }

    if (disable_tls and protocol == .tls)
        return error.TlsInitializationFailed;

    return net.async_tcpConnectToHost(
        client.allocator,
        host,
        port,
        ctx,
        setConnection,
    );
}

pub const ConnectUnixError = Allocator.Error || std.posix.SocketError || error{NameTooLong} || std.posix.ConnectError;

// Connect to `path` as a unix domain socket. This will reuse a connection if one is already open.
//
// This function is threadsafe.
// pub fn connectUnix(client: *Client, path: []const u8) ConnectUnixError!*Connection {
//     if (client.connection_pool.findConnection(.{
//         .host = path,
//         .port = 0,
//         .protocol = .plain,
//     })) |node|
//         return node;

//     const conn = try client.allocator.create(ConnectionPool.Node);
//     errdefer client.allocator.destroy(conn);
//     conn.* = .{ .data = undefined };

//     const stream = try std.net.connectUnixSocket(path);
//     errdefer stream.close();

//     conn.data = .{
//         .stream = stream,
//         .tls_client = undefined,
//         .protocol = .plain,

//         .host = try client.allocator.dupe(u8, path),
//         .port = 0,
//     };
//     errdefer client.allocator.free(conn.data.host);

//     client.connection_pool.addUsed(conn);

//     return &conn.data;
//}

/// Connect to `tunnel_host:tunnel_port` using the specified proxy with HTTP
/// CONNECT. This will reuse a connection if one is already open.
///
/// This function is threadsafe.
+pub fn connectTunnel( + client: *Client, + proxy: *Proxy, + tunnel_host: []const u8, + tunnel_port: u16, +) !*Connection { + if (!proxy.supports_connect) return error.TunnelNotSupported; + + if (client.connection_pool.findConnection(.{ + .host = tunnel_host, + .port = tunnel_port, + .protocol = proxy.protocol, + })) |node| + return node; + + var maybe_valid = false; + (tunnel: { + const conn = try client.connectTcp(proxy.host, proxy.port, proxy.protocol); + errdefer { + conn.closing = true; + client.connection_pool.release(client.allocator, conn); + } + + var buffer: [8096]u8 = undefined; + var req = client.open(.CONNECT, .{ + .scheme = "http", + .host = .{ .raw = tunnel_host }, + .port = tunnel_port, + }, .{ + .redirect_behavior = .unhandled, + .connection = conn, + .server_header_buffer = &buffer, + }) catch |err| { + std.log.debug("err {}", .{err}); + break :tunnel err; + }; + defer req.deinit(); + + req.send() catch |err| break :tunnel err; + req.wait() catch |err| break :tunnel err; + + if (req.response.status.class() == .server_error) { + maybe_valid = true; + break :tunnel error.ServerError; + } + + if (req.response.status != .ok) break :tunnel error.ConnectionRefused; + + // this connection is now a tunnel, so we can't use it for anything else, it will only be released when the client is de-initialized. 
+ req.connection = null; + + client.allocator.free(conn.host); + conn.host = try client.allocator.dupe(u8, tunnel_host); + errdefer client.allocator.free(conn.host); + + conn.port = tunnel_port; + conn.closing = false; + + return conn; + }) catch { + // something went wrong with the tunnel + proxy.supports_connect = maybe_valid; + return error.TunnelNotSupported; + }; +} + +// Prevents a dependency loop in open() +const ConnectErrorPartial = ConnectTcpError || error{ UnsupportedUriScheme, ConnectionRefused }; +pub const ConnectError = ConnectErrorPartial || RequestError; + +fn onConnectProxy(ctx: *Ctx, res: anyerror!void) anyerror!void { + res catch |e| { + ctx.data.conn.closing = true; + ctx.req.client.connection_pool.release(ctx.req.client.allocator, ctx.data.conn); + return ctx.pop(e); + }; + ctx.data.conn.proxied = true; + return ctx.pop({}); +} + +/// Connect to `host:port` using the specified protocol. This will reuse a +/// connection if one is already open. +/// If a proxy is configured for the client, then the proxy will be used to +/// connect to the host. +/// +/// This function is threadsafe. +pub fn connect( + client: *Client, + host: []const u8, + port: u16, + protocol: Connection.Protocol, +) ConnectError!*Connection { + const proxy = switch (protocol) { + .plain => client.http_proxy, + .tls => client.https_proxy, + } orelse return client.connectTcp(host, port, protocol); + + // Prevent proxying through itself. 
+ if (std.ascii.eqlIgnoreCase(proxy.host, host) and + proxy.port == port and proxy.protocol == protocol) + { + return client.connectTcp(host, port, protocol); + } + + if (proxy.supports_connect) tunnel: { + return connectTunnel(client, proxy, host, port) catch |err| switch (err) { + error.TunnelNotSupported => break :tunnel, + else => |e| return e, + }; + } + + // fall back to using the proxy as a normal http proxy + const conn = try client.connectTcp(proxy.host, proxy.port, proxy.protocol); + errdefer { + conn.closing = true; + client.connection_pool.release(conn); + } + + conn.proxied = true; + return conn; +} + +pub fn async_connect( + client: *Client, + host: []const u8, + port: u16, + protocol: Connection.Protocol, + ctx: *Ctx, + comptime cbk: Cbk, +) !void { + const proxy = switch (protocol) { + .plain => client.http_proxy, + .tls => client.https_proxy, + } orelse return client.async_connectTcp(host, port, protocol, ctx, cbk); + + // Prevent proxying through itself. + if (std.ascii.eqlIgnoreCase(proxy.host, host) and + proxy.port == port and proxy.protocol == protocol) + { + return client.async_connectTcp(host, port, protocol, ctx, cbk); + } + + // TODO: enable async_connectTunnel + // if (proxy.supports_connect) tunnel: { + // return connectTunnel(client, proxy, host, port) catch |err| switch (err) { + // error.TunnelNotSupported => break :tunnel, + // else => |e| return e, + // }; + // } + + // fall back to using the proxy as a normal http proxy + try ctx.push(cbk); + return client.async_connectTcp(proxy.host, proxy.port, proxy.protocol, ctx, onConnectProxy); +} + +pub const RequestError = ConnectTcpError || ConnectErrorPartial || Request.SendError || + std.fmt.ParseIntError || Connection.WriteError || + error{ // TODO: file a zig fmt issue for this bad indentation + UnsupportedUriScheme, + UriMissingHost, + + CertificateBundleLoadFailure, + UnsupportedTransferEncoding, +}; + +pub const RequestOptions = struct { + version: http.Version = .@"HTTP/1.1", + + 
/// Automatically ignore 100 Continue responses. This assumes you don't + /// care, and will have sent the body before you wait for the response. + /// + /// If this is not the case AND you know the server will send a 100 + /// Continue, set this to false and wait for a response before sending the + /// body. If you wait AND the server does not send a 100 Continue before + /// you finish the request, then the request *will* deadlock. + handle_continue: bool = true, + + /// If false, close the connection after the one request. If true, + /// participate in the client connection pool. + keep_alive: bool = true, + + /// This field specifies whether to automatically follow redirects, and if + /// so, how many redirects to follow before returning an error. + /// + /// This will only follow redirects for repeatable requests (ie. with no + /// payload or the server has acknowledged the payload). + redirect_behavior: Request.RedirectBehavior = @enumFromInt(3), + + /// Externally-owned memory used to store the server's entire HTTP header. + /// `error.HttpHeadersOversize` is returned from read() when a + /// client sends too many bytes of HTTP headers. + server_header_buffer: []u8, + + /// Must be an already acquired connection. + connection: ?*Connection = null, + + /// Standard headers that have default, but overridable, behavior. + headers: Request.Headers = .{}, + /// These headers are kept including when following a redirect to a + /// different domain. + /// Externally-owned; must outlive the Request. + extra_headers: []const http.Header = &.{}, + /// These headers are stripped when following a redirect to a different + /// domain. + /// Externally-owned; must outlive the Request. 
+ privileged_headers: []const http.Header = &.{}, +}; + +const protocol_map = std.StaticStringMap(Connection.Protocol).initComptime(.{ + .{ "http", .plain }, + .{ "ws", .plain }, + .{ "https", .tls }, + .{ "wss", .tls }, +}); + +fn validateUri(uri: Uri, arena: Allocator) !struct { Connection.Protocol, Uri } { + const protocol = protocol_map.get(uri.scheme) orelse return error.UnsupportedUriScheme; + var valid_uri = uri; + // The host is always going to be needed as a raw string for hostname resolution anyway. + valid_uri.host = .{ + .raw = try (uri.host orelse return error.UriMissingHost).toRawMaybeAlloc(arena), + }; + return .{ protocol, valid_uri }; +} + +fn uriPort(uri: Uri, protocol: Connection.Protocol) u16 { + return uri.port orelse switch (protocol) { + .plain => 80, + .tls => 443, + }; +} + +pub fn create( + client: *Client, + method: http.Method, + uri: Uri, + options: RequestOptions, +) RequestError!Request { + if (std.debug.runtime_safety) { + for (options.extra_headers) |header| { + assert(header.name.len != 0); + assert(std.mem.indexOfScalar(u8, header.name, ':') == null); + assert(std.mem.indexOfPosLinear(u8, header.name, 0, "\r\n") == null); + assert(std.mem.indexOfPosLinear(u8, header.value, 0, "\r\n") == null); + } + for (options.privileged_headers) |header| { + assert(header.name.len != 0); + assert(std.mem.indexOfPosLinear(u8, header.name, 0, "\r\n") == null); + assert(std.mem.indexOfPosLinear(u8, header.value, 0, "\r\n") == null); + } + } + + var server_header = std.heap.FixedBufferAllocator.init(options.server_header_buffer); + _, const valid_uri = try validateUri(uri, server_header.allocator()); + + var req: Request = .{ + .uri = valid_uri, + .client = client, + .keep_alive = options.keep_alive, + .method = method, + .version = options.version, + .transfer_encoding = .none, + .redirect_behavior = options.redirect_behavior, + .handle_continue = options.handle_continue, + .response = .{ + .version = undefined, + .status = undefined, + .reason = 
undefined, + .keep_alive = undefined, + .parser = proto.HeadersParser.init(server_header.buffer[server_header.end_index..]), + }, + .headers = options.headers, + .extra_headers = options.extra_headers, + .privileged_headers = options.privileged_headers, + }; + errdefer req.deinit(); + + return req; +} + +/// Open a connection to the host specified by `uri` and prepare to send a HTTP request. +/// +/// `uri` must remain alive during the entire request. +/// +/// The caller is responsible for calling `deinit()` on the `Request`. +/// This function is threadsafe. +/// +/// Asserts that "\r\n" does not occur in any header name or value. +pub fn open( + client: *Client, + method: http.Method, + uri: Uri, + options: RequestOptions, +) RequestError!Request { + if (std.debug.runtime_safety) { + for (options.extra_headers) |header| { + assert(header.name.len != 0); + assert(std.mem.indexOfScalar(u8, header.name, ':') == null); + assert(std.mem.indexOfPosLinear(u8, header.name, 0, "\r\n") == null); + assert(std.mem.indexOfPosLinear(u8, header.value, 0, "\r\n") == null); + } + for (options.privileged_headers) |header| { + assert(header.name.len != 0); + assert(std.mem.indexOfPosLinear(u8, header.name, 0, "\r\n") == null); + assert(std.mem.indexOfPosLinear(u8, header.value, 0, "\r\n") == null); + } + } + + var server_header = std.heap.FixedBufferAllocator.init(options.server_header_buffer); + const protocol, const valid_uri = try validateUri(uri, server_header.allocator()); + + if (protocol == .tls and @atomicLoad(bool, &client.next_https_rescan_certs, .acquire)) { + if (disable_tls) unreachable; + + client.ca_bundle_mutex.lock(); + defer client.ca_bundle_mutex.unlock(); + + if (client.next_https_rescan_certs) { + client.ca_bundle.rescan(client.allocator) catch + return error.CertificateBundleLoadFailure; + @atomicStore(bool, &client.next_https_rescan_certs, false, .release); + } + } + + const conn = options.connection orelse + try client.connect(valid_uri.host.?.raw, 
uriPort(valid_uri, protocol), protocol); + + var req: Request = .{ + .uri = valid_uri, + .client = client, + .connection = conn, + .keep_alive = options.keep_alive, + .method = method, + .version = options.version, + .transfer_encoding = .none, + .redirect_behavior = options.redirect_behavior, + .handle_continue = options.handle_continue, + .response = .{ + .version = undefined, + .status = undefined, + .reason = undefined, + .keep_alive = undefined, + .parser = proto.HeadersParser.init(server_header.buffer[server_header.end_index..]), + }, + .headers = options.headers, + .extra_headers = options.extra_headers, + .privileged_headers = options.privileged_headers, + }; + errdefer req.deinit(); + + return req; +} + +pub fn async_open( + client: *Client, + method: http.Method, + uri: Uri, + options: RequestOptions, + ctx: *Ctx, + comptime cbk: Cbk, +) !void { + if (std.debug.runtime_safety) { + for (options.extra_headers) |header| { + assert(header.name.len != 0); + assert(std.mem.indexOfScalar(u8, header.name, ':') == null); + assert(std.mem.indexOfPosLinear(u8, header.name, 0, "\r\n") == null); + assert(std.mem.indexOfPosLinear(u8, header.value, 0, "\r\n") == null); + } + for (options.privileged_headers) |header| { + assert(header.name.len != 0); + assert(std.mem.indexOfPosLinear(u8, header.name, 0, "\r\n") == null); + assert(std.mem.indexOfPosLinear(u8, header.value, 0, "\r\n") == null); + } + } + + var server_header = std.heap.FixedBufferAllocator.init(options.server_header_buffer); + const protocol, const valid_uri = try validateUri(uri, server_header.allocator()); + + if (protocol == .tls and @atomicLoad(bool, &client.next_https_rescan_certs, .acquire)) { + if (disable_tls) unreachable; + + client.ca_bundle_mutex.lock(); + defer client.ca_bundle_mutex.unlock(); + + if (client.next_https_rescan_certs) { + client.ca_bundle.rescan(client.allocator) catch return error.CertificateBundleLoadFailure; + @atomicStore(bool, &client.next_https_rescan_certs, false, 
.release); + } + } + + // add fields to request + ctx.req.uri = valid_uri; + ctx.req.keep_alive = options.keep_alive; + ctx.req.method = method; + ctx.req.transfer_encoding = .none; + ctx.req.redirect_behavior = options.redirect_behavior; + ctx.req.handle_continue = options.handle_continue; + ctx.req.headers = options.headers; + ctx.req.extra_headers = options.extra_headers; + ctx.req.privileged_headers = options.privileged_headers; + ctx.req.response = .{ + .version = undefined, + .status = undefined, + .reason = undefined, + .keep_alive = undefined, + .parser = proto.HeadersParser.init(server_header.buffer[server_header.end_index..]), + }; + + // we already have the connection, + // set it and call directly the callback + if (options.connection) |conn| { + ctx.req.connection = conn; + return cbk(ctx, {}); + } + + // push callback function + try ctx.push(cbk); + + const host = valid_uri.host orelse return error.UriMissingHost; + const port = uriPort(valid_uri, protocol); + + // add fields to connection + ctx.data.conn.protocol = protocol; + ctx.data.conn.host = try client.allocator.dupe(u8, host.raw); + ctx.data.conn.port = port; + + return client.async_connect(host.raw, port, protocol, ctx, setRequestConnection); +} + +pub const FetchOptions = struct { + server_header_buffer: ?[]u8 = null, + redirect_behavior: ?Request.RedirectBehavior = null, + + /// If the server sends a body, it will be appended to this ArrayList. + /// `max_append_size` provides an upper limit for how much they can grow. + response_storage: ResponseStorage = .ignore, + max_append_size: ?usize = null, + + location: Location, + method: ?http.Method = null, + payload: ?[]const u8 = null, + raw_uri: bool = false, + keep_alive: bool = true, + + /// Standard headers that have default, but overridable, behavior. + headers: Request.Headers = .{}, + /// These headers are kept including when following a redirect to a + /// different domain. + /// Externally-owned; must outlive the Request. 
+ extra_headers: []const http.Header = &.{}, + /// These headers are stripped when following a redirect to a different + /// domain. + /// Externally-owned; must outlive the Request. + privileged_headers: []const http.Header = &.{}, + + pub const Location = union(enum) { + url: []const u8, + uri: Uri, + }; + + pub const ResponseStorage = union(enum) { + ignore, + /// Only the existing capacity will be used. + static: *std.ArrayListUnmanaged(u8), + dynamic: *std.ArrayList(u8), + }; +}; + +pub const FetchResult = struct { + status: http.Status, +}; + +// TODO: enable async_fetch +/// Perform a one-shot HTTP request with the provided options. +/// +/// This function is threadsafe. +pub fn fetch(client: *Client, options: FetchOptions) !FetchResult { + const uri = switch (options.location) { + .url => |u| try Uri.parse(u), + .uri => |u| u, + }; + var server_header_buffer: [16 * 1024]u8 = undefined; + + const method: http.Method = options.method orelse + if (options.payload != null) .POST else .GET; + + var req = try open(client, method, uri, .{ + .server_header_buffer = options.server_header_buffer orelse &server_header_buffer, + .redirect_behavior = options.redirect_behavior orelse + if (options.payload == null) @enumFromInt(3) else .unhandled, + .headers = options.headers, + .extra_headers = options.extra_headers, + .privileged_headers = options.privileged_headers, + .keep_alive = options.keep_alive, + }); + defer req.deinit(); + + if (options.payload) |payload| req.transfer_encoding = .{ .content_length = payload.len }; + + try req.send(); + + if (options.payload) |payload| try req.writeAll(payload); + + try req.finish(); + try req.wait(); + + switch (options.response_storage) { + .ignore => { + // Take advantage of request internals to discard the response body + // and make the connection available for another request. + req.response.skip = true; + assert(try req.transferRead(&.{}) == 0); // No buffer is necessary when skipping. 
+ }, + .dynamic => |list| { + const max_append_size = options.max_append_size orelse 2 * 1024 * 1024; + try req.reader().readAllArrayList(list, max_append_size); + }, + .static => |list| { + const buf = b: { + const buf = list.unusedCapacitySlice(); + if (options.max_append_size) |len| { + if (len < buf.len) break :b buf[0..len]; + } + break :b buf; + }; + list.items.len += try req.reader().readAll(buf); + }, + } + + return .{ + .status = req.response.status, + }; +} + +pub const Cbk = fn (ctx: *Ctx, res: anyerror!void) anyerror!void; + +pub const Ctx = struct { + const Stack = GenericStack(Cbk); + + // temporary Data we need to store on the heap + // because of the callback execution model + const Data = struct { + list: *std.net.AddressList = undefined, + addr_current: usize = undefined, + socket: std.posix.socket_t = undefined, + + // TODO: we could remove this field as it is already set in ctx.req + // but we do not know for now what will be the impact to set those directly + // on the request, especially in case of error/cancellation + conn: *Connection, + }; + + req: *Request = undefined, + + userData: *anyopaque = undefined, + + loop: *Loop, + data: Data, + stack: ?*Stack = null, + err: ?anyerror = null, + + _buffer: ?[]const u8 = null, + _len: ?usize = null, + + _iovecs: []std.posix.iovec = undefined, + + // TLS readvAtLeast + // _off_i: usize = 0, + // _vec_i: usize = 0, + // _tls_len: usize = 0, + + // TLS readv + _vp: VecPut = undefined, + // _tls_read_buf contains the next decrypted buffer + _tls_read_buf: ?[]u8 = undefined, + _tls_read_content_type: tls23.proto.ContentType = undefined, + + // _tls_read_record contains the crypted record + _tls_read_record: ?tls23.record.Record = null, + + // TLS writeAll + _tls_write_bytes: []const u8 = undefined, + _tls_write_index: usize = 0, + _tls_write_buf: [cipher.max_ciphertext_record_len]u8 = undefined, + + pub fn init(loop: *Loop, req: *Request) !Ctx { + const connection = try 
req.client.allocator.create(Connection); + connection.* = .{ + .stream = undefined, + .tls_client = undefined, + .protocol = undefined, + .host = undefined, + .port = undefined, + }; + return .{ + .req = req, + .loop = loop, + .data = .{ .conn = connection }, + }; + } + + pub fn setErr(self: *Ctx, err: anyerror) void { + self.err = err; + } + + pub fn push(self: *Ctx, comptime func: Stack.Fn) !void { + if (self.stack) |stack| { + return try stack.push(self.alloc(), func); + } + self.stack = try Stack.init(self.alloc(), func); + } + + pub fn pop(self: *Ctx, res: anyerror!void) !void { + if (self.stack) |stack| { + const func = stack.pop(self.alloc(), null); + const ret = @call(.auto, func, .{ self, res }); + if (stack.next == null) { + self.stack = null; + self.alloc().destroy(stack); + } + return ret; + } + } + + pub fn deinit(self: Ctx) void { + if (self.stack) |stack| { + stack.deinit(self.alloc(), null); + } + } + + // not sure about those + + pub fn len(self: Ctx) usize { + if (self._len == null) unreachable; + return self._len.?; + } + + pub fn setLen(self: *Ctx, nb: ?usize) void { + self._len = nb; + } + + pub fn buf(self: Ctx) []const u8 { + if (self._buffer == null) unreachable; + return self._buffer.?; + } + + pub fn setBuf(self: *Ctx, bytes: ?[]const u8) void { + self._buffer = bytes; + } + + // ctx Request aliases + + pub fn alloc(self: Ctx) std.mem.Allocator { + return self.req.client.allocator; + } + + pub fn conn(self: Ctx) *Connection { + return self.req.connection.?; + } + + pub fn stream(self: Ctx) net.Stream { + return self.conn().stream; + } +}; + +// requires ctx.data.conn to be set +fn setRequestConnection(ctx: *Ctx, res: anyerror!void) anyerror!void { + res catch |e| return ctx.pop(e); + + ctx.req.connection = ctx.data.conn; + return ctx.pop({}); +} + +fn onRequestWait(ctx: *Ctx, res: anyerror!void) !void { + res catch |e| { + std.debug.print("error: {any}\n", .{e}); + return e; + }; + std.log.debug("REQUEST WAITED", .{}); + 
std.log.debug("Status code: {any}", .{ctx.req.response.status}); + const body = try ctx.req.reader().readAllAlloc(ctx.alloc(), 1024 * 1024); + defer ctx.alloc().free(body); + std.log.debug("Body: \n{s}", .{body}); +} + +fn onRequestFinish(ctx: *Ctx, res: anyerror!void) !void { + res catch |err| return err; + std.log.debug("REQUEST FINISHED", .{}); + return ctx.req.async_wait(ctx, onRequestWait); +} + +fn onRequestSend(ctx: *Ctx, res: anyerror!void) !void { + res catch |err| return err; + std.log.debug("REQUEST SENT", .{}); + return ctx.req.async_finish(ctx, onRequestFinish); +} + +pub fn onRequestConnect(ctx: *Ctx, res: anyerror!void) anyerror!void { + res catch |err| return err; + std.log.debug("REQUEST CONNECTED", .{}); + return ctx.req.async_send(ctx, onRequestSend); +} + +test { + const alloc = std.testing.allocator; + + var loop = Loop{}; + + var client = Client{ .allocator = alloc }; + defer client.deinit(); + + var req = Request{ + .client = &client, + }; + defer req.deinit(); + + var ctx = try Ctx.init(&loop, &req); + defer ctx.deinit(); + + var server_header_buffer: [2048]u8 = undefined; + + const url = "http://www.example.com"; + // const url = "http://127.0.0.1:8000/zig"; + try client.async_open( + .GET, + try std.Uri.parse(url), + .{ .server_header_buffer = &server_header_buffer }, + &ctx, + onRequestConnect, + ); +} diff --git a/src/http/async/std/http/Server.zig b/src/http/async/std/http/Server.zig new file mode 100644 index 00000000..38d3f133 --- /dev/null +++ b/src/http/async/std/http/Server.zig @@ -0,0 +1,1148 @@ +//! Blocking HTTP server implementation. +//! Handles a single connection's lifecycle. + +connection: net.Server.Connection, +/// Keeps track of whether the Server is ready to accept a new request on the +/// same connection, and makes invalid API usage cause assertion failures +/// rather than HTTP protocol violations. +state: State, +/// User-provided buffer that must outlive this Server. 
+/// Used to store the client's entire HTTP header. +read_buffer: []u8, +/// Amount of available data inside read_buffer. +read_buffer_len: usize, +/// Index into `read_buffer` of the first byte of the next HTTP request. +next_request_start: usize, + +pub const State = enum { + /// The connection is available to be used for the first time, or reused. + ready, + /// An error occurred in `receiveHead`. + receiving_head, + /// A Request object has been obtained and from there a Response can be + /// opened. + received_head, + /// The client is uploading something to this Server. + receiving_body, + /// The connection is eligible for another HTTP request, however the client + /// and server did not negotiate a persistent connection. + closing, +}; + +/// Initialize an HTTP server that can respond to multiple requests on the same +/// connection. +/// The returned `Server` is ready for `receiveHead` to be called. +pub fn init(connection: net.Server.Connection, read_buffer: []u8) Server { + return .{ + .connection = connection, + .state = .ready, + .read_buffer = read_buffer, + .read_buffer_len = 0, + .next_request_start = 0, + }; +} + +pub const ReceiveHeadError = error{ + /// Client sent too many bytes of HTTP headers. + /// The HTTP specification suggests to respond with a 431 status code + /// before closing the connection. + HttpHeadersOversize, + /// Client sent headers that did not conform to the HTTP protocol. + HttpHeadersInvalid, + /// A low level I/O error occurred trying to read the headers. + HttpHeadersUnreadable, + /// Partial HTTP request was received but the connection was closed before + /// fully receiving the headers. + HttpRequestTruncated, + /// The client sent 0 bytes of headers before closing the stream. + /// In other words, a keep-alive connection was finally closed. + HttpConnectionClosing, +}; + +/// The header bytes reference the read buffer that Server was initialized with +/// and remain alive until the next call to receiveHead. 
+pub fn receiveHead(s: *Server) ReceiveHeadError!Request { + assert(s.state == .ready); + s.state = .received_head; + errdefer s.state = .receiving_head; + + // In case of a reused connection, move the next request's bytes to the + // beginning of the buffer. + if (s.next_request_start > 0) { + if (s.read_buffer_len > s.next_request_start) { + rebase(s, 0); + } else { + s.read_buffer_len = 0; + } + } + + var hp: http.HeadParser = .{}; + + if (s.read_buffer_len > 0) { + const bytes = s.read_buffer[0..s.read_buffer_len]; + const end = hp.feed(bytes); + if (hp.state == .finished) + return finishReceivingHead(s, end); + } + + while (true) { + const buf = s.read_buffer[s.read_buffer_len..]; + if (buf.len == 0) + return error.HttpHeadersOversize; + const read_n = s.connection.stream.read(buf) catch + return error.HttpHeadersUnreadable; + if (read_n == 0) { + if (s.read_buffer_len > 0) { + return error.HttpRequestTruncated; + } else { + return error.HttpConnectionClosing; + } + } + s.read_buffer_len += read_n; + const bytes = buf[0..read_n]; + const end = hp.feed(bytes); + if (hp.state == .finished) + return finishReceivingHead(s, s.read_buffer_len - bytes.len + end); + } +} + +fn finishReceivingHead(s: *Server, head_end: usize) ReceiveHeadError!Request { + return .{ + .server = s, + .head_end = head_end, + .head = Request.Head.parse(s.read_buffer[0..head_end]) catch + return error.HttpHeadersInvalid, + .reader_state = undefined, + }; +} + +pub const Request = struct { + server: *Server, + /// Index into Server's read_buffer. 
+ head_end: usize, + head: Head, + reader_state: union { + remaining_content_length: u64, + chunk_parser: http.ChunkParser, + }, + + pub const Compression = union(enum) { + pub const DeflateDecompressor = std.compress.zlib.Decompressor(std.io.AnyReader); + pub const GzipDecompressor = std.compress.gzip.Decompressor(std.io.AnyReader); + pub const ZstdDecompressor = std.compress.zstd.Decompressor(std.io.AnyReader); + + deflate: DeflateDecompressor, + gzip: GzipDecompressor, + zstd: ZstdDecompressor, + none: void, + }; + + pub const Head = struct { + method: http.Method, + target: []const u8, + version: http.Version, + expect: ?[]const u8, + content_type: ?[]const u8, + content_length: ?u64, + transfer_encoding: http.TransferEncoding, + transfer_compression: http.ContentEncoding, + keep_alive: bool, + compression: Compression, + + pub const ParseError = error{ + UnknownHttpMethod, + HttpHeadersInvalid, + HttpHeaderContinuationsUnsupported, + HttpTransferEncodingUnsupported, + HttpConnectionHeaderUnsupported, + InvalidContentLength, + CompressionUnsupported, + MissingFinalNewline, + }; + + pub fn parse(bytes: []const u8) ParseError!Head { + var it = mem.splitSequence(u8, bytes, "\r\n"); + + const first_line = it.next().?; + if (first_line.len < 10) + return error.HttpHeadersInvalid; + + const method_end = mem.indexOfScalar(u8, first_line, ' ') orelse + return error.HttpHeadersInvalid; + if (method_end > 24) return error.HttpHeadersInvalid; + + const method_str = first_line[0..method_end]; + const method: http.Method = @enumFromInt(http.Method.parse(method_str)); + + const version_start = mem.lastIndexOfScalar(u8, first_line, ' ') orelse + return error.HttpHeadersInvalid; + if (version_start == method_end) return error.HttpHeadersInvalid; + + const version_str = first_line[version_start + 1 ..]; + if (version_str.len != 8) return error.HttpHeadersInvalid; + const version: http.Version = switch (int64(version_str[0..8])) { + int64("HTTP/1.0") => .@"HTTP/1.0", + 
int64("HTTP/1.1") => .@"HTTP/1.1", + else => return error.HttpHeadersInvalid, + }; + + const target = first_line[method_end + 1 .. version_start]; + + var head: Head = .{ + .method = method, + .target = target, + .version = version, + .expect = null, + .content_type = null, + .content_length = null, + .transfer_encoding = .none, + .transfer_compression = .identity, + .keep_alive = switch (version) { + .@"HTTP/1.0" => false, + .@"HTTP/1.1" => true, + }, + .compression = .none, + }; + + while (it.next()) |line| { + if (line.len == 0) return head; + switch (line[0]) { + ' ', '\t' => return error.HttpHeaderContinuationsUnsupported, + else => {}, + } + + var line_it = mem.splitScalar(u8, line, ':'); + const header_name = line_it.next().?; + const header_value = mem.trim(u8, line_it.rest(), " \t"); + if (header_name.len == 0) return error.HttpHeadersInvalid; + + if (std.ascii.eqlIgnoreCase(header_name, "connection")) { + head.keep_alive = !std.ascii.eqlIgnoreCase(header_value, "close"); + } else if (std.ascii.eqlIgnoreCase(header_name, "expect")) { + head.expect = header_value; + } else if (std.ascii.eqlIgnoreCase(header_name, "content-type")) { + head.content_type = header_value; + } else if (std.ascii.eqlIgnoreCase(header_name, "content-length")) { + if (head.content_length != null) return error.HttpHeadersInvalid; + head.content_length = std.fmt.parseInt(u64, header_value, 10) catch + return error.InvalidContentLength; + } else if (std.ascii.eqlIgnoreCase(header_name, "content-encoding")) { + if (head.transfer_compression != .identity) return error.HttpHeadersInvalid; + + const trimmed = mem.trim(u8, header_value, " "); + + if (std.meta.stringToEnum(http.ContentEncoding, trimmed)) |ce| { + head.transfer_compression = ce; + } else { + return error.HttpTransferEncodingUnsupported; + } + } else if (std.ascii.eqlIgnoreCase(header_name, "transfer-encoding")) { + // Transfer-Encoding: second, first + // Transfer-Encoding: deflate, chunked + var iter = 
mem.splitBackwardsScalar(u8, header_value, ','); + + const first = iter.first(); + const trimmed_first = mem.trim(u8, first, " "); + + var next: ?[]const u8 = first; + if (std.meta.stringToEnum(http.TransferEncoding, trimmed_first)) |transfer| { + if (head.transfer_encoding != .none) + return error.HttpHeadersInvalid; // we already have a transfer encoding + head.transfer_encoding = transfer; + + next = iter.next(); + } + + if (next) |second| { + const trimmed_second = mem.trim(u8, second, " "); + + if (std.meta.stringToEnum(http.ContentEncoding, trimmed_second)) |transfer| { + if (head.transfer_compression != .identity) + return error.HttpHeadersInvalid; // double compression is not supported + head.transfer_compression = transfer; + } else { + return error.HttpTransferEncodingUnsupported; + } + } + + if (iter.next()) |_| return error.HttpTransferEncodingUnsupported; + } + } + return error.MissingFinalNewline; + } + + test parse { + const request_bytes = "GET /hi HTTP/1.0\r\n" ++ + "content-tYpe: text/plain\r\n" ++ + "content-Length:10\r\n" ++ + "expeCt: 100-continue \r\n" ++ + "TRansfer-encoding:\tdeflate, chunked \r\n" ++ + "connectioN:\t keep-alive \r\n\r\n"; + + const req = try parse(request_bytes); + + try testing.expectEqual(.GET, req.method); + try testing.expectEqual(.@"HTTP/1.0", req.version); + try testing.expectEqualStrings("/hi", req.target); + + try testing.expectEqualStrings("text/plain", req.content_type.?); + try testing.expectEqualStrings("100-continue", req.expect.?); + + try testing.expectEqual(true, req.keep_alive); + try testing.expectEqual(10, req.content_length.?); + try testing.expectEqual(.chunked, req.transfer_encoding); + try testing.expectEqual(.deflate, req.transfer_compression); + } + + inline fn int64(array: *const [8]u8) u64 { + return @bitCast(array.*); + } + }; + + pub fn iterateHeaders(r: *Request) http.HeaderIterator { + return http.HeaderIterator.init(r.server.read_buffer[0..r.head_end]); + } + + test iterateHeaders { + const 
request_bytes = "GET /hi HTTP/1.0\r\n" ++ + "content-tYpe: text/plain\r\n" ++ + "content-Length:10\r\n" ++ + "expeCt: 100-continue \r\n" ++ + "TRansfer-encoding:\tdeflate, chunked \r\n" ++ + "connectioN:\t keep-alive \r\n\r\n"; + + var read_buffer: [500]u8 = undefined; + @memcpy(read_buffer[0..request_bytes.len], request_bytes); + + var server: Server = .{ + .connection = undefined, + .state = .ready, + .read_buffer = &read_buffer, + .read_buffer_len = request_bytes.len, + .next_request_start = 0, + }; + + var request: Request = .{ + .server = &server, + .head_end = request_bytes.len, + .head = undefined, + .reader_state = undefined, + }; + + var it = request.iterateHeaders(); + { + const header = it.next().?; + try testing.expectEqualStrings("content-tYpe", header.name); + try testing.expectEqualStrings("text/plain", header.value); + try testing.expect(!it.is_trailer); + } + { + const header = it.next().?; + try testing.expectEqualStrings("content-Length", header.name); + try testing.expectEqualStrings("10", header.value); + try testing.expect(!it.is_trailer); + } + { + const header = it.next().?; + try testing.expectEqualStrings("expeCt", header.name); + try testing.expectEqualStrings("100-continue", header.value); + try testing.expect(!it.is_trailer); + } + { + const header = it.next().?; + try testing.expectEqualStrings("TRansfer-encoding", header.name); + try testing.expectEqualStrings("deflate, chunked", header.value); + try testing.expect(!it.is_trailer); + } + { + const header = it.next().?; + try testing.expectEqualStrings("connectioN", header.name); + try testing.expectEqualStrings("keep-alive", header.value); + try testing.expect(!it.is_trailer); + } + try testing.expectEqual(null, it.next()); + } + + pub const RespondOptions = struct { + version: http.Version = .@"HTTP/1.1", + status: http.Status = .ok, + reason: ?[]const u8 = null, + keep_alive: bool = true, + extra_headers: []const http.Header = &.{}, + transfer_encoding: ?http.TransferEncoding = 
null, + }; + + /// Send an entire HTTP response to the client, including headers and body. + /// + /// Automatically handles HEAD requests by omitting the body. + /// + /// Unless `transfer_encoding` is specified, uses the "content-length" + /// header. + /// + /// If the request contains a body and the connection is to be reused, + /// discards the request body, leaving the Server in the `ready` state. If + /// this discarding fails, the connection is marked as not to be reused and + /// no error is surfaced. + /// + /// Asserts status is not `continue`. + /// Asserts there are at most 25 extra_headers. + /// Asserts that "\r\n" does not occur in any header name or value. + pub fn respond( + request: *Request, + content: []const u8, + options: RespondOptions, + ) Response.WriteError!void { + const max_extra_headers = 25; + assert(options.status != .@"continue"); + assert(options.extra_headers.len <= max_extra_headers); + if (std.debug.runtime_safety) { + for (options.extra_headers) |header| { + assert(header.name.len != 0); + assert(std.mem.indexOfScalar(u8, header.name, ':') == null); + assert(std.mem.indexOfPosLinear(u8, header.name, 0, "\r\n") == null); + assert(std.mem.indexOfPosLinear(u8, header.value, 0, "\r\n") == null); + } + } + + const transfer_encoding_none = (options.transfer_encoding orelse .chunked) == .none; + const server_keep_alive = !transfer_encoding_none and options.keep_alive; + const keep_alive = request.discardBody(server_keep_alive); + + const phrase = options.reason orelse options.status.phrase() orelse ""; + + var first_buffer: [500]u8 = undefined; + var h = std.ArrayListUnmanaged(u8).initBuffer(&first_buffer); + if (request.head.expect != null) { + // reader() and hence discardBody() above sets expect to null if it + // is handled. So the fact that it is not null here means unhandled. 
+ h.appendSliceAssumeCapacity("HTTP/1.1 417 Expectation Failed\r\n"); + if (!keep_alive) h.appendSliceAssumeCapacity("connection: close\r\n"); + h.appendSliceAssumeCapacity("content-length: 0\r\n\r\n"); + try request.server.connection.stream.writeAll(h.items); + return; + } + h.fixedWriter().print("{s} {d} {s}\r\n", .{ + @tagName(options.version), @intFromEnum(options.status), phrase, + }) catch unreachable; + + switch (options.version) { + .@"HTTP/1.0" => if (keep_alive) h.appendSliceAssumeCapacity("connection: keep-alive\r\n"), + .@"HTTP/1.1" => if (!keep_alive) h.appendSliceAssumeCapacity("connection: close\r\n"), + } + + if (options.transfer_encoding) |transfer_encoding| switch (transfer_encoding) { + .none => {}, + .chunked => h.appendSliceAssumeCapacity("transfer-encoding: chunked\r\n"), + } else { + h.fixedWriter().print("content-length: {d}\r\n", .{content.len}) catch unreachable; + } + + var chunk_header_buffer: [18]u8 = undefined; + var iovecs: [max_extra_headers * 4 + 3]std.posix.iovec_const = undefined; + var iovecs_len: usize = 0; + + iovecs[iovecs_len] = .{ + .base = h.items.ptr, + .len = h.items.len, + }; + iovecs_len += 1; + + for (options.extra_headers) |header| { + iovecs[iovecs_len] = .{ + .base = header.name.ptr, + .len = header.name.len, + }; + iovecs_len += 1; + + iovecs[iovecs_len] = .{ + .base = ": ", + .len = 2, + }; + iovecs_len += 1; + + if (header.value.len != 0) { + iovecs[iovecs_len] = .{ + .base = header.value.ptr, + .len = header.value.len, + }; + iovecs_len += 1; + } + + iovecs[iovecs_len] = .{ + .base = "\r\n", + .len = 2, + }; + iovecs_len += 1; + } + + iovecs[iovecs_len] = .{ + .base = "\r\n", + .len = 2, + }; + iovecs_len += 1; + + if (request.head.method != .HEAD) { + const is_chunked = (options.transfer_encoding orelse .none) == .chunked; + if (is_chunked) { + if (content.len > 0) { + const chunk_header = std.fmt.bufPrint( + &chunk_header_buffer, + "{x}\r\n", + .{content.len}, + ) catch unreachable; + + iovecs[iovecs_len] = .{ 
+ .base = chunk_header.ptr, + .len = chunk_header.len, + }; + iovecs_len += 1; + + iovecs[iovecs_len] = .{ + .base = content.ptr, + .len = content.len, + }; + iovecs_len += 1; + + iovecs[iovecs_len] = .{ + .base = "\r\n", + .len = 2, + }; + iovecs_len += 1; + } + + iovecs[iovecs_len] = .{ + .base = "0\r\n\r\n", + .len = 5, + }; + iovecs_len += 1; + } else if (content.len > 0) { + iovecs[iovecs_len] = .{ + .base = content.ptr, + .len = content.len, + }; + iovecs_len += 1; + } + } + + try request.server.connection.stream.writevAll(iovecs[0..iovecs_len]); + } + + pub const RespondStreamingOptions = struct { + /// An externally managed slice of memory used to batch bytes before + /// sending. `respondStreaming` asserts this is large enough to store + /// the full HTTP response head. + /// + /// Must outlive the returned Response. + send_buffer: []u8, + /// If provided, the response will use the content-length header; + /// otherwise it will use transfer-encoding: chunked. + content_length: ?u64 = null, + /// Options that are shared with the `respond` method. + respond_options: RespondOptions = .{}, + }; + + /// The header is buffered but not sent until Response.flush is called. + /// + /// If the request contains a body and the connection is to be reused, + /// discards the request body, leaving the Server in the `ready` state. If + /// this discarding fails, the connection is marked as not to be reused and + /// no error is surfaced. + /// + /// HEAD requests are handled transparently by setting a flag on the + /// returned Response to omit the body. However it may be worth noticing + /// that flag and skipping any expensive work that would otherwise need to + /// be done to satisfy the request. + /// + /// Asserts `send_buffer` is large enough to store the entire response header. + /// Asserts status is not `continue`. 
+ pub fn respondStreaming(request: *Request, options: RespondStreamingOptions) Response { + const o = options.respond_options; + assert(o.status != .@"continue"); + const transfer_encoding_none = (o.transfer_encoding orelse .chunked) == .none; + const server_keep_alive = !transfer_encoding_none and o.keep_alive; + const keep_alive = request.discardBody(server_keep_alive); + const phrase = o.reason orelse o.status.phrase() orelse ""; + + var h = std.ArrayListUnmanaged(u8).initBuffer(options.send_buffer); + + const elide_body = if (request.head.expect != null) eb: { + // reader() and hence discardBody() above sets expect to null if it + // is handled. So the fact that it is not null here means unhandled. + h.appendSliceAssumeCapacity("HTTP/1.1 417 Expectation Failed\r\n"); + if (!keep_alive) h.appendSliceAssumeCapacity("connection: close\r\n"); + h.appendSliceAssumeCapacity("content-length: 0\r\n\r\n"); + break :eb true; + } else eb: { + h.fixedWriter().print("{s} {d} {s}\r\n", .{ + @tagName(o.version), @intFromEnum(o.status), phrase, + }) catch unreachable; + + switch (o.version) { + .@"HTTP/1.0" => if (keep_alive) h.appendSliceAssumeCapacity("connection: keep-alive\r\n"), + .@"HTTP/1.1" => if (!keep_alive) h.appendSliceAssumeCapacity("connection: close\r\n"), + } + + if (o.transfer_encoding) |transfer_encoding| switch (transfer_encoding) { + .chunked => h.appendSliceAssumeCapacity("transfer-encoding: chunked\r\n"), + .none => {}, + } else if (options.content_length) |len| { + h.fixedWriter().print("content-length: {d}\r\n", .{len}) catch unreachable; + } else { + h.appendSliceAssumeCapacity("transfer-encoding: chunked\r\n"); + } + + for (o.extra_headers) |header| { + assert(header.name.len != 0); + h.appendSliceAssumeCapacity(header.name); + h.appendSliceAssumeCapacity(": "); + h.appendSliceAssumeCapacity(header.value); + h.appendSliceAssumeCapacity("\r\n"); + } + + h.appendSliceAssumeCapacity("\r\n"); + break :eb request.head.method == .HEAD; + }; + + return .{ + 
.stream = request.server.connection.stream, + .send_buffer = options.send_buffer, + .send_buffer_start = 0, + .send_buffer_end = h.items.len, + .transfer_encoding = if (o.transfer_encoding) |te| switch (te) { + .chunked => .chunked, + .none => .none, + } else if (options.content_length) |len| .{ + .content_length = len, + } else .chunked, + .elide_body = elide_body, + .chunk_len = 0, + }; + } + + pub const ReadError = net.Stream.ReadError || error{ + HttpChunkInvalid, + HttpHeadersOversize, + }; + + fn read_cl(context: *const anyopaque, buffer: []u8) ReadError!usize { + const request: *Request = @constCast(@alignCast(@ptrCast(context))); + const s = request.server; + + const remaining_content_length = &request.reader_state.remaining_content_length; + if (remaining_content_length.* == 0) { + s.state = .ready; + return 0; + } + assert(s.state == .receiving_body); + const available = try fill(s, request.head_end); + const len = @min(remaining_content_length.*, available.len, buffer.len); + @memcpy(buffer[0..len], available[0..len]); + remaining_content_length.* -= len; + s.next_request_start += len; + if (remaining_content_length.* == 0) + s.state = .ready; + return len; + } + + fn fill(s: *Server, head_end: usize) ReadError![]u8 { + const available = s.read_buffer[s.next_request_start..s.read_buffer_len]; + if (available.len > 0) return available; + s.next_request_start = head_end; + s.read_buffer_len = head_end + try s.connection.stream.read(s.read_buffer[head_end..]); + return s.read_buffer[head_end..s.read_buffer_len]; + } + + fn read_chunked(context: *const anyopaque, buffer: []u8) ReadError!usize { + const request: *Request = @constCast(@alignCast(@ptrCast(context))); + const s = request.server; + + const cp = &request.reader_state.chunk_parser; + const head_end = request.head_end; + + // Protect against returning 0 before the end of stream. 
+ var out_end: usize = 0; + while (out_end == 0) { + switch (cp.state) { + .invalid => return 0, + .data => { + assert(s.state == .receiving_body); + const available = try fill(s, head_end); + const len = @min(cp.chunk_len, available.len, buffer.len); + @memcpy(buffer[0..len], available[0..len]); + cp.chunk_len -= len; + if (cp.chunk_len == 0) + cp.state = .data_suffix; + out_end += len; + s.next_request_start += len; + continue; + }, + else => { + assert(s.state == .receiving_body); + const available = try fill(s, head_end); + const n = cp.feed(available); + switch (cp.state) { + .invalid => return error.HttpChunkInvalid, + .data => { + if (cp.chunk_len == 0) { + // The next bytes in the stream are trailers, + // or \r\n to indicate end of chunked body. + // + // This function must append the trailers at + // head_end so that headers and trailers are + // together. + // + // Since returning 0 would indicate end of + // stream, this function must read all the + // trailers before returning. 
+ if (s.next_request_start > head_end) rebase(s, head_end); + var hp: http.HeadParser = .{}; + { + const bytes = s.read_buffer[head_end..s.read_buffer_len]; + const end = hp.feed(bytes); + if (hp.state == .finished) { + cp.state = .invalid; + s.state = .ready; + s.next_request_start = s.read_buffer_len - bytes.len + end; + return out_end; + } + } + while (true) { + const buf = s.read_buffer[s.read_buffer_len..]; + if (buf.len == 0) + return error.HttpHeadersOversize; + const read_n = try s.connection.stream.read(buf); + s.read_buffer_len += read_n; + const bytes = buf[0..read_n]; + const end = hp.feed(bytes); + if (hp.state == .finished) { + cp.state = .invalid; + s.state = .ready; + s.next_request_start = s.read_buffer_len - bytes.len + end; + return out_end; + } + } + } + const data = available[n..]; + const len = @min(cp.chunk_len, data.len, buffer.len); + @memcpy(buffer[0..len], data[0..len]); + cp.chunk_len -= len; + if (cp.chunk_len == 0) + cp.state = .data_suffix; + out_end += len; + s.next_request_start += n + len; + continue; + }, + else => continue, + } + }, + } + } + return out_end; + } + + pub const ReaderError = Response.WriteError || error{ + /// The client sent an expect HTTP header value other than + /// "100-continue". + HttpExpectationFailed, + }; + + /// In the case that the request contains "expect: 100-continue", this + /// function writes the continuation header, which means it can fail with a + /// write error. After sending the continuation header, it sets the + /// request's expect field to `null`. + /// + /// Asserts that this function is only called once. 
+ pub fn reader(request: *Request) ReaderError!std.io.AnyReader { + const s = request.server; + assert(s.state == .received_head); + s.state = .receiving_body; + s.next_request_start = request.head_end; + + if (request.head.expect) |expect| { + if (mem.eql(u8, expect, "100-continue")) { + try request.server.connection.stream.writeAll("HTTP/1.1 100 Continue\r\n\r\n"); + request.head.expect = null; + } else { + return error.HttpExpectationFailed; + } + } + + switch (request.head.transfer_encoding) { + .chunked => { + request.reader_state = .{ .chunk_parser = http.ChunkParser.init }; + return .{ + .readFn = read_chunked, + .context = request, + }; + }, + .none => { + request.reader_state = .{ + .remaining_content_length = request.head.content_length orelse 0, + }; + return .{ + .readFn = read_cl, + .context = request, + }; + }, + } + } + + /// Returns whether the connection should remain persistent. + /// If it would fail, it instead sets the Server state to `receiving_body` + /// and returns false. + fn discardBody(request: *Request, keep_alive: bool) bool { + // Prepare to receive another request on the same connection. + // There are two factors to consider: + // * Any body the client sent must be discarded. + // * The Server's read_buffer may already have some bytes in it from + // whatever came after the head, which may be the next HTTP request + // or the request body. + // If the connection won't be kept alive, then none of this matters + // because the connection will be severed after the response is sent. + const s = request.server; + if (keep_alive and request.head.keep_alive) switch (s.state) { + .received_head => { + const r = request.reader() catch return false; + _ = r.discard() catch return false; + assert(s.state == .ready); + return true; + }, + .receiving_body, .ready => return true, + else => unreachable, + }; + + // Avoid clobbering the state in case a reading stream already exists. 
+ switch (s.state) { + .received_head => s.state = .closing, + else => {}, + } + return false; + } +}; + +pub const Response = struct { + stream: net.Stream, + send_buffer: []u8, + /// Index of the first byte in `send_buffer`. + /// This is 0 unless a short write happens in `write`. + send_buffer_start: usize, + /// Index of the last byte + 1 in `send_buffer`. + send_buffer_end: usize, + /// `null` means transfer-encoding: chunked. + /// As a debugging utility, counts down to zero as bytes are written. + transfer_encoding: TransferEncoding, + elide_body: bool, + /// Indicates how much of the end of the `send_buffer` corresponds to a + /// chunk. This amount of data will be wrapped by an HTTP chunk header. + chunk_len: usize, + + pub const TransferEncoding = union(enum) { + /// End of connection signals the end of the stream. + none, + /// As a debugging utility, counts down to zero as bytes are written. + content_length: u64, + /// Each chunk is wrapped in a header and trailer. + chunked, + }; + + pub const WriteError = net.Stream.WriteError; + + /// When using content-length, asserts that the amount of data sent matches + /// the value sent in the header, then calls `flush`. + /// Otherwise, transfer-encoding: chunked is being used, and it writes the + /// end-of-stream message, then flushes the stream to the system. + /// Respects the value of `elide_body` to omit all data after the headers. + pub fn end(r: *Response) WriteError!void { + switch (r.transfer_encoding) { + .content_length => |len| { + assert(len == 0); // Trips when end() called before all bytes written. + try flush_cl(r); + }, + .none => { + try flush_cl(r); + }, + .chunked => { + try flush_chunked(r, &.{}); + }, + } + r.* = undefined; + } + + pub const EndChunkedOptions = struct { + trailers: []const http.Header = &.{}, + }; + + /// Asserts that the Response is using transfer-encoding: chunked. 
+ /// Writes the end-of-stream message and any optional trailers, then + /// flushes the stream to the system. + /// Respects the value of `elide_body` to omit all data after the headers. + /// Asserts there are at most 25 trailers. + pub fn endChunked(r: *Response, options: EndChunkedOptions) WriteError!void { + assert(r.transfer_encoding == .chunked); + try flush_chunked(r, options.trailers); + r.* = undefined; + } + + /// If using content-length, asserts that writing these bytes to the client + /// would not exceed the content-length value sent in the HTTP header. + /// May return 0, which does not indicate end of stream. The caller decides + /// when the end of stream occurs by calling `end`. + pub fn write(r: *Response, bytes: []const u8) WriteError!usize { + switch (r.transfer_encoding) { + .content_length, .none => return write_cl(r, bytes), + .chunked => return write_chunked(r, bytes), + } + } + + fn write_cl(context: *const anyopaque, bytes: []const u8) WriteError!usize { + const r: *Response = @constCast(@alignCast(@ptrCast(context))); + + var trash: u64 = std.math.maxInt(u64); + const len = switch (r.transfer_encoding) { + .content_length => |*len| len, + else => &trash, + }; + + if (r.elide_body) { + len.* -= bytes.len; + return bytes.len; + } + + if (bytes.len + r.send_buffer_end > r.send_buffer.len) { + const send_buffer_len = r.send_buffer_end - r.send_buffer_start; + var iovecs: [2]std.posix.iovec_const = .{ + .{ + .base = r.send_buffer.ptr + r.send_buffer_start, + .len = send_buffer_len, + }, + .{ + .base = bytes.ptr, + .len = bytes.len, + }, + }; + const n = try r.stream.writev(&iovecs); + + if (n >= send_buffer_len) { + // It was enough to reset the buffer. + r.send_buffer_start = 0; + r.send_buffer_end = 0; + const bytes_n = n - send_buffer_len; + len.* -= bytes_n; + return bytes_n; + } + + // It didn't even make it through the existing buffer, let + // alone the new bytes provided. 
+ r.send_buffer_start += n; + return 0; + } + + // All bytes can be stored in the remaining space of the buffer. + @memcpy(r.send_buffer[r.send_buffer_end..][0..bytes.len], bytes); + r.send_buffer_end += bytes.len; + len.* -= bytes.len; + return bytes.len; + } + + fn write_chunked(context: *const anyopaque, bytes: []const u8) WriteError!usize { + const r: *Response = @constCast(@alignCast(@ptrCast(context))); + assert(r.transfer_encoding == .chunked); + + if (r.elide_body) + return bytes.len; + + if (bytes.len + r.send_buffer_end > r.send_buffer.len) { + const send_buffer_len = r.send_buffer_end - r.send_buffer_start; + const chunk_len = r.chunk_len + bytes.len; + var header_buf: [18]u8 = undefined; + const chunk_header = std.fmt.bufPrint(&header_buf, "{x}\r\n", .{chunk_len}) catch unreachable; + + var iovecs: [5]std.posix.iovec_const = .{ + .{ + .base = r.send_buffer.ptr + r.send_buffer_start, + .len = send_buffer_len - r.chunk_len, + }, + .{ + .base = chunk_header.ptr, + .len = chunk_header.len, + }, + .{ + .base = r.send_buffer.ptr + r.send_buffer_end - r.chunk_len, + .len = r.chunk_len, + }, + .{ + .base = bytes.ptr, + .len = bytes.len, + }, + .{ + .base = "\r\n", + .len = 2, + }, + }; + // TODO make this writev instead of writevAll, which involves + // complicating the logic of this function. + try r.stream.writevAll(&iovecs); + r.send_buffer_start = 0; + r.send_buffer_end = 0; + r.chunk_len = 0; + return bytes.len; + } + + // All bytes can be stored in the remaining space of the buffer. + @memcpy(r.send_buffer[r.send_buffer_end..][0..bytes.len], bytes); + r.send_buffer_end += bytes.len; + r.chunk_len += bytes.len; + return bytes.len; + } + + /// If using content-length, asserts that writing these bytes to the client + /// would not exceed the content-length value sent in the HTTP header. 
+ pub fn writeAll(r: *Response, bytes: []const u8) WriteError!void { + var index: usize = 0; + while (index < bytes.len) { + index += try write(r, bytes[index..]); + } + } + + /// Sends all buffered data to the client. + /// This is redundant after calling `end`. + /// Respects the value of `elide_body` to omit all data after the headers. + pub fn flush(r: *Response) WriteError!void { + switch (r.transfer_encoding) { + .none, .content_length => return flush_cl(r), + .chunked => return flush_chunked(r, null), + } + } + + fn flush_cl(r: *Response) WriteError!void { + try r.stream.writeAll(r.send_buffer[r.send_buffer_start..r.send_buffer_end]); + r.send_buffer_start = 0; + r.send_buffer_end = 0; + } + + fn flush_chunked(r: *Response, end_trailers: ?[]const http.Header) WriteError!void { + const max_trailers = 25; + if (end_trailers) |trailers| assert(trailers.len <= max_trailers); + assert(r.transfer_encoding == .chunked); + + const http_headers = r.send_buffer[r.send_buffer_start .. r.send_buffer_end - r.chunk_len]; + + if (r.elide_body) { + try r.stream.writeAll(http_headers); + r.send_buffer_start = 0; + r.send_buffer_end = 0; + r.chunk_len = 0; + return; + } + + var header_buf: [18]u8 = undefined; + const chunk_header = std.fmt.bufPrint(&header_buf, "{x}\r\n", .{r.chunk_len}) catch unreachable; + + var iovecs: [max_trailers * 4 + 5]std.posix.iovec_const = undefined; + var iovecs_len: usize = 0; + + iovecs[iovecs_len] = .{ + .base = http_headers.ptr, + .len = http_headers.len, + }; + iovecs_len += 1; + + if (r.chunk_len > 0) { + iovecs[iovecs_len] = .{ + .base = chunk_header.ptr, + .len = chunk_header.len, + }; + iovecs_len += 1; + + iovecs[iovecs_len] = .{ + .base = r.send_buffer.ptr + r.send_buffer_end - r.chunk_len, + .len = r.chunk_len, + }; + iovecs_len += 1; + + iovecs[iovecs_len] = .{ + .base = "\r\n", + .len = 2, + }; + iovecs_len += 1; + } + + if (end_trailers) |trailers| { + iovecs[iovecs_len] = .{ + .base = "0\r\n", + .len = 3, + }; + iovecs_len += 1; + 
+ for (trailers) |trailer| { + iovecs[iovecs_len] = .{ + .base = trailer.name.ptr, + .len = trailer.name.len, + }; + iovecs_len += 1; + + iovecs[iovecs_len] = .{ + .base = ": ", + .len = 2, + }; + iovecs_len += 1; + + if (trailer.value.len != 0) { + iovecs[iovecs_len] = .{ + .base = trailer.value.ptr, + .len = trailer.value.len, + }; + iovecs_len += 1; + } + + iovecs[iovecs_len] = .{ + .base = "\r\n", + .len = 2, + }; + iovecs_len += 1; + } + + iovecs[iovecs_len] = .{ + .base = "\r\n", + .len = 2, + }; + iovecs_len += 1; + } + + try r.stream.writevAll(iovecs[0..iovecs_len]); + r.send_buffer_start = 0; + r.send_buffer_end = 0; + r.chunk_len = 0; + } + + pub fn writer(r: *Response) std.io.AnyWriter { + return .{ + .writeFn = switch (r.transfer_encoding) { + .none, .content_length => write_cl, + .chunked => write_chunked, + }, + .context = r, + }; + } +}; + +fn rebase(s: *Server, index: usize) void { + const leftover = s.read_buffer[s.next_request_start..s.read_buffer_len]; + const dest = s.read_buffer[index..][0..leftover.len]; + if (leftover.len <= s.next_request_start - index) { + @memcpy(dest, leftover); + } else { + mem.copyBackwards(u8, dest, leftover); + } + s.read_buffer_len = index + leftover.len; +} + +const std = @import("std"); +const http = std.http; +const mem = std.mem; +const net = std.net; +const Uri = std.Uri; +const assert = std.debug.assert; +const testing = std.testing; + +const Server = @This(); diff --git a/src/http/async/std/http/protocol.zig b/src/http/async/std/http/protocol.zig new file mode 100644 index 00000000..389e1e4f --- /dev/null +++ b/src/http/async/std/http/protocol.zig @@ -0,0 +1,447 @@ +const std = @import("std"); +const builtin = @import("builtin"); +const testing = std.testing; +const mem = std.mem; + +const assert = std.debug.assert; +const use_vectors = builtin.zig_backend != .stage2_x86_64; + +pub const State = enum { + invalid, + + // Begin header and trailer parsing states. 
+ + start, + seen_n, + seen_r, + seen_rn, + seen_rnr, + finished, + + // Begin transfer-encoding: chunked parsing states. + + chunk_head_size, + chunk_head_ext, + chunk_head_r, + chunk_data, + chunk_data_suffix, + chunk_data_suffix_r, + + /// Returns true if the parser is in a content state (ie. not waiting for more headers). + pub fn isContent(self: State) bool { + return switch (self) { + .invalid, .start, .seen_n, .seen_r, .seen_rn, .seen_rnr => false, + .finished, .chunk_head_size, .chunk_head_ext, .chunk_head_r, .chunk_data, .chunk_data_suffix, .chunk_data_suffix_r => true, + }; + } +}; + +pub const HeadersParser = struct { + state: State = .start, + /// A fixed buffer of len `max_header_bytes`. + /// Pointers into this buffer are not stable until after a message is complete. + header_bytes_buffer: []u8, + header_bytes_len: u32, + next_chunk_length: u64, + /// `false`: headers. `true`: trailers. + done: bool, + + /// Initializes the parser with a provided buffer `buf`. + pub fn init(buf: []u8) HeadersParser { + return .{ + .header_bytes_buffer = buf, + .header_bytes_len = 0, + .done = false, + .next_chunk_length = 0, + }; + } + + /// Reinitialize the parser. + /// Asserts the parser is in the "done" state. 
+ pub fn reset(hp: *HeadersParser) void { + assert(hp.done); + hp.* = .{ + .state = .start, + .header_bytes_buffer = hp.header_bytes_buffer, + .header_bytes_len = 0, + .done = false, + .next_chunk_length = 0, + }; + } + + pub fn get(hp: HeadersParser) []u8 { + return hp.header_bytes_buffer[0..hp.header_bytes_len]; + } + + pub fn findHeadersEnd(r: *HeadersParser, bytes: []const u8) u32 { + var hp: std.http.HeadParser = .{ + .state = switch (r.state) { + .start => .start, + .seen_n => .seen_n, + .seen_r => .seen_r, + .seen_rn => .seen_rn, + .seen_rnr => .seen_rnr, + .finished => .finished, + else => unreachable, + }, + }; + const result = hp.feed(bytes); + r.state = switch (hp.state) { + .start => .start, + .seen_n => .seen_n, + .seen_r => .seen_r, + .seen_rn => .seen_rn, + .seen_rnr => .seen_rnr, + .finished => .finished, + }; + return @intCast(result); + } + + pub fn findChunkedLen(r: *HeadersParser, bytes: []const u8) u32 { + var cp: std.http.ChunkParser = .{ + .state = switch (r.state) { + .chunk_head_size => .head_size, + .chunk_head_ext => .head_ext, + .chunk_head_r => .head_r, + .chunk_data => .data, + .chunk_data_suffix => .data_suffix, + .chunk_data_suffix_r => .data_suffix_r, + .invalid => .invalid, + else => unreachable, + }, + .chunk_len = r.next_chunk_length, + }; + const result = cp.feed(bytes); + r.state = switch (cp.state) { + .head_size => .chunk_head_size, + .head_ext => .chunk_head_ext, + .head_r => .chunk_head_r, + .data => .chunk_data, + .data_suffix => .chunk_data_suffix, + .data_suffix_r => .chunk_data_suffix_r, + .invalid => .invalid, + }; + r.next_chunk_length = cp.chunk_len; + return @intCast(result); + } + + /// Returns whether or not the parser has finished parsing a complete + /// message. A message is only complete after the entire body has been read + /// and any trailing headers have been parsed. 
+ pub fn isComplete(r: *HeadersParser) bool { + return r.done and r.state == .finished; + } + + pub const CheckCompleteHeadError = error{HttpHeadersOversize}; + + /// Pushes `in` into the parser. Returns the number of bytes consumed by + /// the header. Any header bytes are appended to `header_bytes_buffer`. + pub fn checkCompleteHead(hp: *HeadersParser, in: []const u8) CheckCompleteHeadError!u32 { + if (hp.state.isContent()) return 0; + + const i = hp.findHeadersEnd(in); + const data = in[0..i]; + if (hp.header_bytes_len + data.len > hp.header_bytes_buffer.len) + return error.HttpHeadersOversize; + + @memcpy(hp.header_bytes_buffer[hp.header_bytes_len..][0..data.len], data); + hp.header_bytes_len += @intCast(data.len); + + return i; + } + + pub const ReadError = error{ + HttpChunkInvalid, + }; + + /// Reads the body of the message into `buffer`. Returns the number of + /// bytes placed in the buffer. + /// + /// If `skip` is true, the buffer will be unused and the body will be skipped. + /// + /// See `std.http.Client.Connection for an example of `conn`. 
+ pub fn read(r: *HeadersParser, conn: anytype, buffer: []u8, skip: bool) !usize { + assert(r.state.isContent()); + if (r.done) return 0; + + var out_index: usize = 0; + while (true) { + switch (r.state) { + .invalid, .start, .seen_n, .seen_r, .seen_rn, .seen_rnr => unreachable, + .finished => { + const data_avail = r.next_chunk_length; + + if (skip) { + try conn.fill(); + + const nread = @min(conn.peek().len, data_avail); + conn.drop(@intCast(nread)); + r.next_chunk_length -= nread; + + if (r.next_chunk_length == 0 or nread == 0) r.done = true; + + return out_index; + } else if (out_index < buffer.len) { + const out_avail = buffer.len - out_index; + + const can_read = @as(usize, @intCast(@min(data_avail, out_avail))); + const nread = try conn.read(buffer[0..can_read]); + r.next_chunk_length -= nread; + + if (r.next_chunk_length == 0 or nread == 0) r.done = true; + + return nread; + } else { + return out_index; + } + }, + .chunk_data_suffix, .chunk_data_suffix_r, .chunk_head_size, .chunk_head_ext, .chunk_head_r => { + try conn.fill(); + + const i = r.findChunkedLen(conn.peek()); + conn.drop(@intCast(i)); + + switch (r.state) { + .invalid => return error.HttpChunkInvalid, + .chunk_data => if (r.next_chunk_length == 0) { + if (std.mem.eql(u8, conn.peek(), "\r\n")) { + r.state = .finished; + conn.drop(2); + } else { + // The trailer section is formatted identically + // to the header section. 
+ r.state = .seen_rn; + } + r.done = true; + + return out_index; + }, + else => return out_index, + } + + continue; + }, + .chunk_data => { + const data_avail = r.next_chunk_length; + const out_avail = buffer.len - out_index; + + if (skip) { + try conn.fill(); + + const nread = @min(conn.peek().len, data_avail); + conn.drop(@intCast(nread)); + r.next_chunk_length -= nread; + } else if (out_avail > 0) { + const can_read: usize = @intCast(@min(data_avail, out_avail)); + const nread = try conn.read(buffer[out_index..][0..can_read]); + r.next_chunk_length -= nread; + out_index += nread; + } + + if (r.next_chunk_length == 0) { + r.state = .chunk_data_suffix; + continue; + } + + return out_index; + }, + } + } + } +}; + +inline fn int16(array: *const [2]u8) u16 { + return @as(u16, @bitCast(array.*)); +} + +inline fn int24(array: *const [3]u8) u24 { + return @as(u24, @bitCast(array.*)); +} + +inline fn int32(array: *const [4]u8) u32 { + return @as(u32, @bitCast(array.*)); +} + +inline fn intShift(comptime T: type, x: anytype) T { + switch (@import("builtin").cpu.arch.endian()) { + .little => return @as(T, @truncate(x >> (@bitSizeOf(@TypeOf(x)) - @bitSizeOf(T)))), + .big => return @as(T, @truncate(x)), + } +} + +/// A buffered (and peekable) Connection. 
+const MockBufferedConnection = struct { // test-only in-memory connection; provides the fill/peek/drop/read interface the HeadersParser tests below drive
+    pub const buffer_size = 0x2000;
+
+    conn: std.io.FixedBufferStream([]const u8), // the canned wire data the tests feed in
+    buf: [buffer_size]u8 = undefined,
+    start: u16 = 0, // index of first unconsumed byte in buf
+    end: u16 = 0, // one past the last valid byte in buf
+
+    pub fn fill(conn: *MockBufferedConnection) ReadError!void { // refills only when the buffer is fully drained; error.EndOfStream once the fixed stream is exhausted
+        if (conn.end != conn.start) return;
+
+        const nread = try conn.conn.read(conn.buf[0..]);
+        if (nread == 0) return error.EndOfStream;
+        conn.start = 0;
+        conn.end = @as(u16, @truncate(nread));
+    }
+
+    pub fn peek(conn: *MockBufferedConnection) []const u8 { // view of the buffered-but-unconsumed bytes; does not advance
+        return conn.buf[conn.start..conn.end];
+    }
+
+    pub fn drop(conn: *MockBufferedConnection, num: u16) void { // NOTE(review): no check that num <= peek().len — callers here always drop what peek() returned
+        conn.start += num;
+    }
+
+    pub fn readAtLeast(conn: *MockBufferedConnection, buffer: []u8, len: usize) ReadError!usize { // NOTE(review): out_index is u16 while len is usize — presumably callers request < 65536 bytes; confirm before reuse
+        var out_index: u16 = 0;
+        while (out_index < len) {
+            const available = conn.end - conn.start;
+            const left = buffer.len - out_index;
+
+            if (available > 0) {
+                const can_read = @as(u16, @truncate(@min(available, left)));
+
+                @memcpy(buffer[out_index..][0..can_read], conn.buf[conn.start..][0..can_read]);
+                out_index += can_read;
+                conn.start += can_read;
+
+                continue;
+            }
+
+            if (left > conn.buf.len) {
+                // skip the buffer if the output is large enough
+                return conn.conn.read(buffer[out_index..]);
+            }
+
+            try conn.fill();
+        }
+
+        return out_index;
+    }
+
+    pub fn read(conn: *MockBufferedConnection, buffer: []u8) ReadError!usize { // blocks (via fill) until at least one byte is available
+        return conn.readAtLeast(buffer, 1);
+    }
+
+    pub const ReadError = std.io.FixedBufferStream([]const u8).ReadError || error{EndOfStream};
+    pub const Reader = std.io.Reader(*MockBufferedConnection, ReadError, read);
+
+    pub fn reader(conn: *MockBufferedConnection) Reader {
+        return Reader{ .context = conn };
+    }
+
+    pub fn writeAll(conn: *MockBufferedConnection, buffer: []const u8) WriteError!void { // writes pass straight through; no write-side buffering
+        return conn.conn.writeAll(buffer);
+    }
+
+    pub fn write(conn: *MockBufferedConnection, buffer: []const u8) WriteError!usize {
+        return conn.conn.write(buffer);
+    }
+
+    pub const WriteError = std.io.FixedBufferStream([]const u8).WriteError;
+    pub const Writer = std.io.Writer(*MockBufferedConnection, WriteError, write);
+
+    pub fn writer(conn: *MockBufferedConnection) Writer {
+        return Writer{ .context = conn };
+    }
+};
+
+test "HeadersParser.read length" { // content-length body: head parsed first, then a 5-byte body read in one call
+    // mock BufferedConnection for read
+    var headers_buf: [256]u8 = undefined;
+
+    var r = HeadersParser.init(&headers_buf);
+    const data = "GET / HTTP/1.1\r\nHost: localhost\r\nContent-Length: 5\r\n\r\nHello";
+
+    var conn: MockBufferedConnection = .{
+        .conn = std.io.fixedBufferStream(data),
+    };
+
+    while (true) { // read headers
+        try conn.fill();
+
+        const nchecked = try r.checkCompleteHead(conn.peek());
+        conn.drop(@intCast(nchecked));
+
+        if (r.state.isContent()) break;
+    }
+
+    var buf: [8]u8 = undefined;
+
+    r.next_chunk_length = 5; // normally set by the caller from the Content-Length header
+    const len = try r.read(&conn, &buf, false);
+    try std.testing.expectEqual(@as(usize, 5), len);
+    try std.testing.expectEqualStrings("Hello", buf[0..len]);
+
+    try std.testing.expectEqualStrings("GET / HTTP/1.1\r\nHost: localhost\r\nContent-Length: 5\r\n\r\n", r.get());
+}
+
+test "HeadersParser.read chunked" { // chunked body split across three chunks reassembles to "Hello"
+    // mock BufferedConnection for read
+
+    var headers_buf: [256]u8 = undefined;
+    var r = HeadersParser.init(&headers_buf);
+    const data = "GET / HTTP/1.1\r\nHost: localhost\r\n\r\n2\r\nHe\r\n2\r\nll\r\n1\r\no\r\n0\r\n\r\n";
+
+    var conn: MockBufferedConnection = .{
+        .conn = std.io.fixedBufferStream(data),
+    };
+
+    while (true) { // read headers
+        try conn.fill();
+
+        const nchecked = try r.checkCompleteHead(conn.peek());
+        conn.drop(@intCast(nchecked));
+
+        if (r.state.isContent()) break;
+    }
+    var buf: [8]u8 = undefined;
+
+    r.state = .chunk_head_size; // normally set by the caller from Transfer-Encoding: chunked
+    const len = try r.read(&conn, &buf, false);
+    try std.testing.expectEqual(@as(usize, 5), len);
+    try std.testing.expectEqualStrings("Hello", buf[0..len]);
+
+    try std.testing.expectEqualStrings("GET / HTTP/1.1\r\nHost: localhost\r\n\r\n", r.get());
+}
+
+test "HeadersParser.read chunked trailer" { // after the 0-chunk, the trailer section is parsed with the same head loop and appended to r.get()
+    // mock BufferedConnection for read
+
+    var headers_buf: [256]u8 = undefined;
+    var r = HeadersParser.init(&headers_buf);
+    const data = "GET / HTTP/1.1\r\nHost: localhost\r\n\r\n2\r\nHe\r\n2\r\nll\r\n1\r\no\r\n0\r\nContent-Type: text/plain\r\n\r\n";
+
+    var conn: MockBufferedConnection = .{
+        .conn = std.io.fixedBufferStream(data),
+    };
+
+    while (true) { // read headers
+        try conn.fill();
+
+        const nchecked = try r.checkCompleteHead(conn.peek());
+        conn.drop(@intCast(nchecked));
+
+        if (r.state.isContent()) break;
+    }
+    var buf: [8]u8 = undefined;
+
+    r.state = .chunk_head_size;
+    const len = try r.read(&conn, &buf, false);
+    try std.testing.expectEqual(@as(usize, 5), len);
+    try std.testing.expectEqualStrings("Hello", buf[0..len]);
+
+    while (true) { // read headers
+        try conn.fill();
+
+        const nchecked = try r.checkCompleteHead(conn.peek());
+        conn.drop(@intCast(nchecked));
+
+        if (r.state.isContent()) break;
+    }
+
+    try std.testing.expectEqualStrings("GET / HTTP/1.1\r\nHost: localhost\r\n\r\nContent-Type: text/plain\r\n\r\n", r.get());
+}
diff --git a/src/http/async/std/net.zig b/src/http/async/std/net.zig
new file mode 100644
index 00000000..86863deb
--- /dev/null
+++ b/src/http/async/std/net.zig
@@ -0,0 +1,2050 @@
+//! Cross-platform networking abstractions.
+
+const std = @import("std");
+const builtin = @import("builtin");
+const assert = std.debug.assert;
+const net = @This();
+const mem = std.mem;
+const posix = std.posix;
+const fs = std.fs;
+const io = std.io;
+const native_endian = builtin.target.cpu.arch.endian();
+const native_os = builtin.os.tag;
+const windows = std.os.windows;
+
+const Ctx = @import("http/Client.zig").Ctx; // async-callback context threaded through this vendored copy
+const Cbk = @import("http/Client.zig").Cbk;
+
+// Windows 10 added support for unix sockets in build 17063, redstone 4 is the
+// first release to support them.
+pub const has_unix_sockets = switch (native_os) {
+    .windows => builtin.os.version_range.windows.isAtLeast(.win10_rs4) orelse false, // null = unknown build -> assume no support
+    else => true,
+};
+
+pub const IPParseError = error{
+    Overflow,
+    InvalidEnd,
+    InvalidCharacter,
+    Incomplete,
+};
+
+pub const IPv4ParseError = IPParseError || error{NonCanonical};
+
+pub const IPv6ParseError = IPParseError || error{InvalidIpv4Mapping};
+pub const IPv6InterfaceError = posix.SocketError || posix.IoCtl_SIOCGIFINDEX_Error || error{NameTooLong};
+pub const IPv6ResolveError = IPv6ParseError || IPv6InterfaceError;
+
+pub const Address = extern union { // untagged union: the active variant is keyed off any.family
+    any: posix.sockaddr,
+    in: Ip4Address,
+    in6: Ip6Address,
+    un: if (has_unix_sockets) posix.sockaddr.un else void, // void on targets without unix-socket support
+
+    /// Parse the given IP address string into an Address value.
+    /// It is recommended to use `resolveIp` instead, to handle
+    /// IPv6 link-local unix addresses.
+    pub fn parseIp(name: []const u8, port: u16) !Address { // tries IPv4 first, then IPv6; all parse errors fall through to InvalidIPAddressFormat
+        if (parseIp4(name, port)) |ip4| return ip4 else |err| switch (err) {
+            error.Overflow,
+            error.InvalidEnd,
+            error.InvalidCharacter,
+            error.Incomplete,
+            error.NonCanonical,
+            => {},
+        }
+
+        if (parseIp6(name, port)) |ip6| return ip6 else |err| switch (err) {
+            error.Overflow,
+            error.InvalidEnd,
+            error.InvalidCharacter,
+            error.Incomplete,
+            error.InvalidIpv4Mapping,
+            => {},
+        }
+
+        return error.InvalidIPAddressFormat;
+    }
+
+    pub fn resolveIp(name: []const u8, port: u16) !Address { // like parseIp, but resolveIp6 can also resolve interface-name scope ids
+        if (parseIp4(name, port)) |ip4| return ip4 else |err| switch (err) {
+            error.Overflow,
+            error.InvalidEnd,
+            error.InvalidCharacter,
+            error.Incomplete,
+            error.NonCanonical,
+            => {},
+        }
+
+        if (resolveIp6(name, port)) |ip6| return ip6 else |err| switch (err) {
+            error.Overflow,
+            error.InvalidEnd,
+            error.InvalidCharacter,
+            error.Incomplete,
+            error.InvalidIpv4Mapping,
+            => {},
+            else => return err, // interface-resolution errors (socket/ioctl) propagate unchanged
+        }
+
+        return error.InvalidIPAddressFormat;
+    }
+
+    pub fn parseExpectingFamily(name: []const u8, family: posix.sa_family_t, port: u16) !Address { // caller asserts family is INET, INET6, or UNSPEC
+        switch (family) {
+            posix.AF.INET => return parseIp4(name, port),
+            posix.AF.INET6 => return parseIp6(name, port),
+            posix.AF.UNSPEC => return parseIp(name, port),
+            else => unreachable,
+        }
+    }
+
+    pub fn parseIp6(buf: []const u8, port: u16) IPv6ParseError!Address {
+        return .{ .in6 = try Ip6Address.parse(buf, port) };
+    }
+
+    pub fn resolveIp6(buf: []const u8, port: u16) IPv6ResolveError!Address {
+        return .{ .in6 = try Ip6Address.resolve(buf, port) };
+    }
+
+    pub fn parseIp4(buf: []const u8, port: u16) IPv4ParseError!Address {
+        return .{ .in = try Ip4Address.parse(buf, port) };
+    }
+
+    pub fn initIp4(addr: [4]u8, port: u16) Address {
+        return .{ .in = Ip4Address.init(addr, port) };
+    }
+
+    pub fn initIp6(addr: [16]u8, port: u16, flowinfo: u32, scope_id: u32) Address {
+        return .{ .in6 = Ip6Address.init(addr, port, flowinfo, scope_id) };
+    }
+
+    pub fn initUnix(path: []const u8) !Address { // path is copied; the sockaddr keeps its own NUL-terminated buffer
+        var sock_addr = posix.sockaddr.un{
+            .family = posix.AF.UNIX,
+            .path = undefined,
+        };
+
+        // Add 1 to ensure a terminating 0 is present in the path array for maximum portability.
+        if (path.len + 1 > sock_addr.path.len) return error.NameTooLong;
+
+        @memset(&sock_addr.path, 0);
+        @memcpy(sock_addr.path[0..path.len], path);
+
+        return .{ .un = sock_addr };
+    }
+
+    /// Returns the port in native endian.
+    /// Asserts that the address is ip4 or ip6.
+    pub fn getPort(self: Address) u16 {
+        return switch (self.any.family) {
+            posix.AF.INET => self.in.getPort(),
+            posix.AF.INET6 => self.in6.getPort(),
+            else => unreachable,
+        };
+    }
+
+    /// `port` is native-endian.
+    /// Asserts that the address is ip4 or ip6.
+    pub fn setPort(self: *Address, port: u16) void {
+        switch (self.any.family) {
+            posix.AF.INET => self.in.setPort(port),
+            posix.AF.INET6 => self.in6.setPort(port),
+            else => unreachable,
+        }
+    }
+
+    /// Asserts that `addr` is an IP address.
+    /// This function will read past the end of the pointer, with a size depending
+    /// on the address family.
+    pub fn initPosix(addr: *align(4) const posix.sockaddr) Address {
+        switch (addr.family) {
+            posix.AF.INET => return Address{ .in = Ip4Address{ .sa = @as(*const posix.sockaddr.in, @ptrCast(addr)).* } },
+            posix.AF.INET6 => return Address{ .in6 = Ip6Address{ .sa = @as(*const posix.sockaddr.in6, @ptrCast(addr)).* } },
+            else => unreachable,
+        }
+    }
+
+    pub fn format(
+        self: Address,
+        comptime fmt: []const u8,
+        options: std.fmt.FormatOptions,
+        out_stream: anytype,
+    ) !void {
+        if (fmt.len != 0) std.fmt.invalidFmtError(fmt, self); // only the empty "{}" specifier is supported
+        switch (self.any.family) {
+            posix.AF.INET => try self.in.format(fmt, options, out_stream),
+            posix.AF.INET6 => try self.in6.format(fmt, options, out_stream),
+            posix.AF.UNIX => {
+                if (!has_unix_sockets) {
+                    unreachable;
+                }
+
+                try std.fmt.format(out_stream, "{s}", .{std.mem.sliceTo(&self.un.path, 0)});
+            },
+            else => unreachable,
+        }
+    }
+
+    pub fn eql(a: Address, b: Address) bool { // byte-compare only the family-specific prefix of the union
+        const a_bytes = @as([*]const u8, @ptrCast(&a.any))[0..a.getOsSockLen()];
+        const b_bytes = @as([*]const u8, @ptrCast(&b.any))[0..b.getOsSockLen()];
+        return mem.eql(u8, a_bytes, b_bytes);
+    }
+
+    pub fn getOsSockLen(self: Address) posix.socklen_t {
+        switch (self.any.family) {
+            posix.AF.INET => return self.in.getOsSockLen(),
+            posix.AF.INET6 => return self.in6.getOsSockLen(),
+            posix.AF.UNIX => {
+                if (!has_unix_sockets) {
+                    unreachable;
+                }
+
+                // Using the full length of the structure here is more portable than returning
+                // the number of bytes actually used by the currently stored path.
+                // This also is correct regardless if we are passing a socket address to the kernel
+                // (e.g. in bind, connect, sendto) since we ensure the path is 0 terminated in
+                // initUnix() or if we are receiving a socket address from the kernel and must
+                // provide the full buffer size (e.g. getsockname, getpeername, recvfrom, accept).
+                //
+                // To access the path, std.mem.sliceTo(&address.un.path, 0) should be used.
+                return @as(posix.socklen_t, @intCast(@sizeOf(posix.sockaddr.un)));
+            },
+
+            else => unreachable,
+        }
+    }
+
+    pub const ListenError = posix.SocketError || posix.BindError || posix.ListenError ||
+        posix.SetSockOptError || posix.GetSockNameError;
+
+    pub const ListenOptions = struct {
+        /// How many connections the kernel will accept on the application's behalf.
+        /// If more than this many connections pool in the kernel, clients will start
+        /// seeing "Connection refused".
+        kernel_backlog: u31 = 128,
+        /// Sets SO_REUSEADDR and SO_REUSEPORT on POSIX.
+        /// Sets SO_REUSEADDR on Windows, which is roughly equivalent.
+        reuse_address: bool = false,
+        /// Deprecated. Does the same thing as reuse_address.
+        reuse_port: bool = false,
+        force_nonblocking: bool = false,
+    };
+
+    /// The returned `Server` has an open `stream`.
+    pub fn listen(address: Address, options: ListenOptions) ListenError!Server {
+        const nonblock: u32 = if (options.force_nonblocking) posix.SOCK.NONBLOCK else 0;
+        const sock_flags = posix.SOCK.STREAM | posix.SOCK.CLOEXEC | nonblock;
+        const proto: u32 = if (address.any.family == posix.AF.UNIX) 0 else posix.IPPROTO.TCP; // unix sockets take protocol 0
+
+        const sockfd = try posix.socket(address.any.family, sock_flags, proto);
+        var s: Server = .{
+            .listen_address = undefined, // filled in by getsockname below
+            .stream = .{ .handle = sockfd },
+        };
+        errdefer s.stream.close(); // don't leak the fd if any setup step below fails
+
+        if (options.reuse_address or options.reuse_port) {
+            try posix.setsockopt(
+                sockfd,
+                posix.SOL.SOCKET,
+                posix.SO.REUSEADDR,
+                &mem.toBytes(@as(c_int, 1)),
+            );
+            switch (native_os) {
+                .windows => {}, // no SO_REUSEPORT on Windows
+                else => try posix.setsockopt(
+                    sockfd,
+                    posix.SOL.SOCKET,
+                    posix.SO.REUSEPORT,
+                    &mem.toBytes(@as(c_int, 1)),
+                ),
+            }
+        }
+
+        var socklen = address.getOsSockLen();
+        try posix.bind(sockfd, &address.any, socklen);
+        try posix.listen(sockfd, options.kernel_backlog);
+        try posix.getsockname(sockfd, &s.listen_address.any, &socklen); // records the actual bound address (e.g. kernel-chosen port 0)
+        return s;
+    }
+};
+
+pub const Ip4Address = extern struct {
+    sa: posix.sockaddr.in,
+
+    pub fn parse(buf: []const u8, port: u16) IPv4ParseError!Ip4Address { // strict dotted-quad parser; rejects leading zeros (NonCanonical)
+        var result: Ip4Address = .{
+            .sa = .{
+                .port = mem.nativeToBig(u16, port),
+                .addr = undefined,
+            },
+        };
+        const out_ptr = mem.asBytes(&result.sa.addr);
+
+        var x: u8 = 0; // current octet accumulator
+        var index: u8 = 0; // octets completed so far
+        var saw_any_digits = false;
+        var has_zero_prefix = false;
+        for (buf) |c| {
+            if (c == '.') {
+                if (!saw_any_digits) {
+                    return error.InvalidCharacter;
+                }
+                if (index == 3) {
+                    return error.InvalidEnd;
+                }
+                out_ptr[index] = x;
+                index += 1;
+                x = 0;
+                saw_any_digits = false;
+                has_zero_prefix = false;
+            } else if (c >= '0' and c <= '9') {
+                if (c == '0' and !saw_any_digits) {
+                    has_zero_prefix = true;
+                } else if (has_zero_prefix) {
+                    return error.NonCanonical;
+                }
+                saw_any_digits = true;
+                x = try std.math.mul(u8, x, 10); // checked arithmetic: octet > 255 -> error.Overflow
+                x = try std.math.add(u8, x, c - '0');
+            } else {
+                return error.InvalidCharacter;
+            }
+        }
+        if (index == 3 and saw_any_digits) {
+            out_ptr[index] = x;
+            return result;
+        }
+
+        return error.Incomplete;
+    }
+
+    pub fn resolveIp(name: []const u8, port: u16) !Ip4Address {
+        if (parse(name, port)) |ip4| return ip4 else |err| switch (err) {
+            error.Overflow,
+            error.InvalidEnd,
+            error.InvalidCharacter,
+            error.Incomplete,
+            error.NonCanonical,
+            => {},
+        }
+        return error.InvalidIPAddressFormat;
+    }
+
+    pub fn init(addr: [4]u8, port: u16) Ip4Address {
+        return Ip4Address{
+            .sa = posix.sockaddr.in{
+                .port = mem.nativeToBig(u16, port),
+                .addr = @as(*align(1) const u32, @ptrCast(&addr)).*, // reinterpret the 4 bytes as the network-order u32
+            },
+        };
+    }
+
+    /// Returns the port in native endian.
+    /// Asserts that the address is ip4 or ip6.
+    pub fn getPort(self: Ip4Address) u16 {
+        return mem.bigToNative(u16, self.sa.port);
+    }
+
+    /// `port` is native-endian.
+    /// Asserts that the address is ip4 or ip6.
+ pub fn setPort(self: *Ip4Address, port: u16) void { + self.sa.port = mem.nativeToBig(u16, port); + } + + pub fn format( + self: Ip4Address, + comptime fmt: []const u8, + options: std.fmt.FormatOptions, + out_stream: anytype, + ) !void { + if (fmt.len != 0) std.fmt.invalidFmtError(fmt, self); + _ = options; + const bytes = @as(*const [4]u8, @ptrCast(&self.sa.addr)); + try std.fmt.format(out_stream, "{}.{}.{}.{}:{}", .{ + bytes[0], + bytes[1], + bytes[2], + bytes[3], + self.getPort(), + }); + } + + pub fn getOsSockLen(self: Ip4Address) posix.socklen_t { + _ = self; + return @sizeOf(posix.sockaddr.in); + } +}; + +pub const Ip6Address = extern struct { + sa: posix.sockaddr.in6, + + /// Parse a given IPv6 address string into an Address. + /// Assumes the Scope ID of the address is fully numeric. + /// For non-numeric addresses, see `resolveIp6`. + pub fn parse(buf: []const u8, port: u16) IPv6ParseError!Ip6Address { + var result = Ip6Address{ + .sa = posix.sockaddr.in6{ + .scope_id = 0, + .port = mem.nativeToBig(u16, port), + .flowinfo = 0, + .addr = undefined, + }, + }; + var ip_slice: *[16]u8 = result.sa.addr[0..]; + + var tail: [16]u8 = undefined; + + var x: u16 = 0; + var saw_any_digits = false; + var index: u8 = 0; + var scope_id = false; + var abbrv = false; + for (buf, 0..) 
|c, i| { + if (scope_id) { + if (c >= '0' and c <= '9') { + const digit = c - '0'; + { + const ov = @mulWithOverflow(result.sa.scope_id, 10); + if (ov[1] != 0) return error.Overflow; + result.sa.scope_id = ov[0]; + } + { + const ov = @addWithOverflow(result.sa.scope_id, digit); + if (ov[1] != 0) return error.Overflow; + result.sa.scope_id = ov[0]; + } + } else { + return error.InvalidCharacter; + } + } else if (c == ':') { + if (!saw_any_digits) { + if (abbrv) return error.InvalidCharacter; // ':::' + if (i != 0) abbrv = true; + @memset(ip_slice[index..], 0); + ip_slice = tail[0..]; + index = 0; + continue; + } + if (index == 14) { + return error.InvalidEnd; + } + ip_slice[index] = @as(u8, @truncate(x >> 8)); + index += 1; + ip_slice[index] = @as(u8, @truncate(x)); + index += 1; + + x = 0; + saw_any_digits = false; + } else if (c == '%') { + if (!saw_any_digits) { + return error.InvalidCharacter; + } + scope_id = true; + saw_any_digits = false; + } else if (c == '.') { + if (!abbrv or ip_slice[0] != 0xff or ip_slice[1] != 0xff) { + // must start with '::ffff:' + return error.InvalidIpv4Mapping; + } + const start_index = mem.lastIndexOfScalar(u8, buf[0..i], ':').? 
+ 1; + const addr = (Ip4Address.parse(buf[start_index..], 0) catch { + return error.InvalidIpv4Mapping; + }).sa.addr; + ip_slice = result.sa.addr[0..]; + ip_slice[10] = 0xff; + ip_slice[11] = 0xff; + + const ptr = mem.sliceAsBytes(@as(*const [1]u32, &addr)[0..]); + + ip_slice[12] = ptr[0]; + ip_slice[13] = ptr[1]; + ip_slice[14] = ptr[2]; + ip_slice[15] = ptr[3]; + return result; + } else { + const digit = try std.fmt.charToDigit(c, 16); + { + const ov = @mulWithOverflow(x, 16); + if (ov[1] != 0) return error.Overflow; + x = ov[0]; + } + { + const ov = @addWithOverflow(x, digit); + if (ov[1] != 0) return error.Overflow; + x = ov[0]; + } + saw_any_digits = true; + } + } + + if (!saw_any_digits and !abbrv) { + return error.Incomplete; + } + if (!abbrv and index < 14) { + return error.Incomplete; + } + + if (index == 14) { + ip_slice[14] = @as(u8, @truncate(x >> 8)); + ip_slice[15] = @as(u8, @truncate(x)); + return result; + } else { + ip_slice[index] = @as(u8, @truncate(x >> 8)); + index += 1; + ip_slice[index] = @as(u8, @truncate(x)); + index += 1; + @memcpy(result.sa.addr[16 - index ..][0..index], ip_slice[0..index]); + return result; + } + } + + pub fn resolve(buf: []const u8, port: u16) IPv6ResolveError!Ip6Address { + // TODO: Unify the implementations of resolveIp6 and parseIp6. + var result = Ip6Address{ + .sa = posix.sockaddr.in6{ + .scope_id = 0, + .port = mem.nativeToBig(u16, port), + .flowinfo = 0, + .addr = undefined, + }, + }; + var ip_slice: *[16]u8 = result.sa.addr[0..]; + + var tail: [16]u8 = undefined; + + var x: u16 = 0; + var saw_any_digits = false; + var index: u8 = 0; + var abbrv = false; + + var scope_id = false; + var scope_id_value: [posix.IFNAMESIZE - 1]u8 = undefined; + var scope_id_index: usize = 0; + + for (buf, 0..) |c, i| { + if (scope_id) { + // Handling of percent-encoding should be for an URI library. 
+ if ((c >= '0' and c <= '9') or + (c >= 'A' and c <= 'Z') or + (c >= 'a' and c <= 'z') or + (c == '-') or (c == '.') or (c == '_') or (c == '~')) + { + if (scope_id_index >= scope_id_value.len) { + return error.Overflow; + } + + scope_id_value[scope_id_index] = c; + scope_id_index += 1; + } else { + return error.InvalidCharacter; + } + } else if (c == ':') { + if (!saw_any_digits) { + if (abbrv) return error.InvalidCharacter; // ':::' + if (i != 0) abbrv = true; + @memset(ip_slice[index..], 0); + ip_slice = tail[0..]; + index = 0; + continue; + } + if (index == 14) { + return error.InvalidEnd; + } + ip_slice[index] = @as(u8, @truncate(x >> 8)); + index += 1; + ip_slice[index] = @as(u8, @truncate(x)); + index += 1; + + x = 0; + saw_any_digits = false; + } else if (c == '%') { + if (!saw_any_digits) { + return error.InvalidCharacter; + } + scope_id = true; + saw_any_digits = false; + } else if (c == '.') { + if (!abbrv or ip_slice[0] != 0xff or ip_slice[1] != 0xff) { + // must start with '::ffff:' + return error.InvalidIpv4Mapping; + } + const start_index = mem.lastIndexOfScalar(u8, buf[0..i], ':').? 
+ 1; + const addr = (Ip4Address.parse(buf[start_index..], 0) catch { + return error.InvalidIpv4Mapping; + }).sa.addr; + ip_slice = result.sa.addr[0..]; + ip_slice[10] = 0xff; + ip_slice[11] = 0xff; + + const ptr = mem.sliceAsBytes(@as(*const [1]u32, &addr)[0..]); + + ip_slice[12] = ptr[0]; + ip_slice[13] = ptr[1]; + ip_slice[14] = ptr[2]; + ip_slice[15] = ptr[3]; + return result; + } else { + const digit = try std.fmt.charToDigit(c, 16); + { + const ov = @mulWithOverflow(x, 16); + if (ov[1] != 0) return error.Overflow; + x = ov[0]; + } + { + const ov = @addWithOverflow(x, digit); + if (ov[1] != 0) return error.Overflow; + x = ov[0]; + } + saw_any_digits = true; + } + } + + if (!saw_any_digits and !abbrv) { + return error.Incomplete; + } + + if (scope_id and scope_id_index == 0) { + return error.Incomplete; + } + + var resolved_scope_id: u32 = 0; + if (scope_id_index > 0) { + const scope_id_str = scope_id_value[0..scope_id_index]; + resolved_scope_id = std.fmt.parseInt(u32, scope_id_str, 10) catch |err| blk: { + if (err != error.InvalidCharacter) return err; + break :blk try if_nametoindex(scope_id_str); + }; + } + + result.sa.scope_id = resolved_scope_id; + + if (index == 14) { + ip_slice[14] = @as(u8, @truncate(x >> 8)); + ip_slice[15] = @as(u8, @truncate(x)); + return result; + } else { + ip_slice[index] = @as(u8, @truncate(x >> 8)); + index += 1; + ip_slice[index] = @as(u8, @truncate(x)); + index += 1; + @memcpy(result.sa.addr[16 - index ..][0..index], ip_slice[0..index]); + return result; + } + } + + pub fn init(addr: [16]u8, port: u16, flowinfo: u32, scope_id: u32) Ip6Address { + return Ip6Address{ + .sa = posix.sockaddr.in6{ + .addr = addr, + .port = mem.nativeToBig(u16, port), + .flowinfo = flowinfo, + .scope_id = scope_id, + }, + }; + } + + /// Returns the port in native endian. + /// Asserts that the address is ip4 or ip6. + pub fn getPort(self: Ip6Address) u16 { + return mem.bigToNative(u16, self.sa.port); + } + + /// `port` is native-endian. 
+ /// Asserts that the address is ip4 or ip6. + pub fn setPort(self: *Ip6Address, port: u16) void { + self.sa.port = mem.nativeToBig(u16, port); + } + + pub fn format( + self: Ip6Address, + comptime fmt: []const u8, + options: std.fmt.FormatOptions, + out_stream: anytype, + ) !void { + if (fmt.len != 0) std.fmt.invalidFmtError(fmt, self); + _ = options; + const port = mem.bigToNative(u16, self.sa.port); + if (mem.eql(u8, self.sa.addr[0..12], &[_]u8{ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0xff, 0xff })) { + try std.fmt.format(out_stream, "[::ffff:{}.{}.{}.{}]:{}", .{ + self.sa.addr[12], + self.sa.addr[13], + self.sa.addr[14], + self.sa.addr[15], + port, + }); + return; + } + const big_endian_parts = @as(*align(1) const [8]u16, @ptrCast(&self.sa.addr)); + const native_endian_parts = switch (native_endian) { + .big => big_endian_parts.*, + .little => blk: { + var buf: [8]u16 = undefined; + for (big_endian_parts, 0..) |part, i| { + buf[i] = mem.bigToNative(u16, part); + } + break :blk buf; + }, + }; + try out_stream.writeAll("["); + var i: usize = 0; + var abbrv = false; + while (i < native_endian_parts.len) : (i += 1) { + if (native_endian_parts[i] == 0) { + if (!abbrv) { + try out_stream.writeAll(if (i == 0) "::" else ":"); + abbrv = true; + } + continue; + } + try std.fmt.format(out_stream, "{x}", .{native_endian_parts[i]}); + if (i != native_endian_parts.len - 1) { + try out_stream.writeAll(":"); + } + } + try std.fmt.format(out_stream, "]:{}", .{port}); + } + + pub fn getOsSockLen(self: Ip6Address) posix.socklen_t { + _ = self; + return @sizeOf(posix.sockaddr.in6); + } +}; + +pub fn connectUnixSocket(path: []const u8) !Stream { + const opt_non_block = 0; + const sockfd = try posix.socket( + posix.AF.UNIX, + posix.SOCK.STREAM | posix.SOCK.CLOEXEC | opt_non_block, + 0, + ); + errdefer Stream.close(.{ .handle = sockfd }); + + var addr = try std.net.Address.initUnix(path); + try posix.connect(sockfd, &addr.any, addr.getOsSockLen()); + + return .{ .handle = sockfd }; +} + +fn 
if_nametoindex(name: []const u8) IPv6InterfaceError!u32 { + if (native_os == .linux) { + var ifr: posix.ifreq = undefined; + const sockfd = try posix.socket(posix.AF.UNIX, posix.SOCK.DGRAM | posix.SOCK.CLOEXEC, 0); + defer Stream.close(.{ .handle = sockfd }); + + @memcpy(ifr.ifrn.name[0..name.len], name); + ifr.ifrn.name[name.len] = 0; + + // TODO investigate if this needs to be integrated with evented I/O. + try posix.ioctl_SIOCGIFINDEX(sockfd, &ifr); + + return @bitCast(ifr.ifru.ivalue); + } + + if (native_os.isDarwin()) { + if (name.len >= posix.IFNAMESIZE) + return error.NameTooLong; + + var if_name: [posix.IFNAMESIZE:0]u8 = undefined; + @memcpy(if_name[0..name.len], name); + if_name[name.len] = 0; + const if_slice = if_name[0..name.len :0]; + const index = std.c.if_nametoindex(if_slice); + if (index == 0) + return error.InterfaceNotFound; + return @as(u32, @bitCast(index)); + } + + @compileError("std.net.if_nametoindex unimplemented for this OS"); +} + +pub const AddressList = struct { + arena: std.heap.ArenaAllocator, + addrs: []Address, + canon_name: ?[]u8, + + pub fn deinit(self: *AddressList) void { + // Here we copy the arena allocator into stack memory, because + // otherwise it would destroy itself while it was still working. + var arena = self.arena; + arena.deinit(); + // self is destroyed + } +}; + +pub const TcpConnectToHostError = GetAddressListError || TcpConnectToAddressError; + +/// All memory allocated with `allocator` will be freed before this function returns. 
pub fn tcpConnectToHost(allocator: mem.Allocator, name: []const u8, port: u16) TcpConnectToHostError!Stream {
    const resolved = try getAddressList(allocator, name, port);
    defer resolved.deinit();

    if (resolved.addrs.len == 0) return error.UnknownHostName;

    // Try each resolved address in order. A refused connection may still
    // succeed on the next candidate; any other error aborts immediately.
    for (resolved.addrs) |candidate| {
        const stream = tcpConnectToAddress(candidate) catch |err| {
            if (err == error.ConnectionRefused) continue;
            return err;
        };
        return stream;
    }
    // Every candidate refused the connection.
    return posix.ConnectError.ConnectionRefused;
}

pub const TcpConnectToAddressError = posix.SocketError || posix.ConnectError;

/// Open a blocking TCP socket for `address`'s family and connect it.
/// On success the caller owns the returned `Stream`.
pub fn tcpConnectToAddress(address: Address) TcpConnectToAddressError!Stream {
    // Blocking socket; CLOEXEC everywhere except Windows, which has no such flag.
    const sock_flags = posix.SOCK.STREAM |
        (if (native_os == .windows) 0 else posix.SOCK.CLOEXEC);
    const fd = try posix.socket(address.any.family, sock_flags, posix.IPPROTO.TCP);
    errdefer Stream.close(.{ .handle = fd });

    try posix.connect(fd, &address.any, address.getOsSockLen());

    return .{ .handle = fd };
}

const GetAddressListError = std.mem.Allocator.Error || std.fs.File.OpenError || std.fs.File.ReadError || posix.SocketError || posix.BindError || posix.SetSockOptError || error{
    // TODO: break this up into error sets from the various underlying functions

    TemporaryNameServerFailure,
    NameServerFailure,
    AddressFamilyNotSupported,
    UnknownHostName,
    ServiceUnavailable,
    Unexpected,

    HostLacksNetworkAddresses,

    InvalidCharacter,
    InvalidEnd,
    NonCanonical,
    Overflow,
    Incomplete,
    InvalidIpv4Mapping,
    InvalidIPAddressFormat,

    InterfaceNotFound,
    FileSystem,
};

/// Call `AddressList.deinit` on the result.
pub fn getAddressList(allocator: mem.Allocator, name: []const u8, port: u16) GetAddressListError!*AddressList {
    // The AddressList and everything it points at live in one arena; the
    // arena itself is stored inside the AddressList and torn down by deinit.
    const result = blk: {
        var arena = std.heap.ArenaAllocator.init(allocator);
        errdefer arena.deinit();

        const result = try arena.allocator().create(AddressList);
        result.* = AddressList{
            .arena = arena,
            .addrs = undefined,
            .canon_name = null,
        };
        break :blk result;
    };
    const arena = result.arena.allocator();
    errdefer result.deinit();

    if (native_os == .windows) {
        // Winsock getaddrinfo wants NUL-terminated name and numeric service.
        const name_c = try allocator.dupeZ(u8, name);
        defer allocator.free(name_c);

        const port_c = try std.fmt.allocPrintZ(allocator, "{}", .{port});
        defer allocator.free(port_c);

        const ws2_32 = windows.ws2_32;
        const hints = posix.addrinfo{
            .flags = ws2_32.AI.NUMERICSERV,
            .family = posix.AF.UNSPEC,
            .socktype = posix.SOCK.STREAM,
            .protocol = posix.IPPROTO.TCP,
            .canonname = null,
            .addr = null,
            .addrlen = 0,
            .next = null,
        };
        var res: ?*posix.addrinfo = null;
        var first = true;
        while (true) {
            const rc = ws2_32.getaddrinfo(name_c.ptr, port_c.ptr, &hints, &res);
            switch (@as(windows.ws2_32.WinsockError, @enumFromInt(@as(u16, @intCast(rc))))) {
                @as(windows.ws2_32.WinsockError, @enumFromInt(0)) => break,
                .WSATRY_AGAIN => return error.TemporaryNameServerFailure,
                .WSANO_RECOVERY => return error.NameServerFailure,
                .WSAEAFNOSUPPORT => return error.AddressFamilyNotSupported,
                .WSA_NOT_ENOUGH_MEMORY => return error.OutOfMemory,
                .WSAHOST_NOT_FOUND => return error.UnknownHostName,
                .WSATYPE_NOT_FOUND => return error.ServiceUnavailable,
                .WSAEINVAL => unreachable,
                .WSAESOCKTNOSUPPORT => unreachable,
                .WSANOTINITIALISED => {
                    // Winsock not started yet: initialize once and retry the call.
                    if (!first) return error.Unexpected;
                    first = false;
                    try windows.callWSAStartup();
                    continue;
                },
                else => |err| return windows.unexpectedWSAError(err),
            }
        }
        defer ws2_32.freeaddrinfo(res);

        // First pass: count usable entries so the arena slice is exact.
        const addr_count = blk: {
            var count: usize = 0;
            var it = res;
            while (it) |info| : (it = info.next) {
                if (info.addr != null) {
                    count += 1;
                }
            }
            break :blk count;
        };
        result.addrs = try arena.alloc(Address, addr_count);

        // Second pass: copy addresses; the first canonname seen wins.
        var it = res;
        var i: usize = 0;
        while (it) |info| : (it = info.next) {
            const addr = info.addr orelse continue;
            result.addrs[i] = Address.initPosix(@alignCast(addr));

            if (info.canonname) |n| {
                if (result.canon_name == null) {
                    result.canon_name = try arena.dupe(u8, mem.sliceTo(n, 0));
                }
            }
            i += 1;
        }

        return result;
    }

    if (builtin.link_libc) {
        const name_c = try allocator.dupeZ(u8, name);
        defer allocator.free(name_c);

        const port_c = try std.fmt.allocPrintZ(allocator, "{}", .{port});
        defer allocator.free(port_c);

        const sys = if (native_os == .windows) windows.ws2_32 else posix.system;
        const hints = posix.addrinfo{
            .flags = sys.AI.NUMERICSERV,
            .family = posix.AF.UNSPEC,
            .socktype = posix.SOCK.STREAM,
            .protocol = posix.IPPROTO.TCP,
            .canonname = null,
            .addr = null,
            .addrlen = 0,
            .next = null,
        };
        var res: ?*posix.addrinfo = null;
        switch (sys.getaddrinfo(name_c.ptr, port_c.ptr, &hints, &res)) {
            @as(sys.EAI, @enumFromInt(0)) => {},
            .ADDRFAMILY => return error.HostLacksNetworkAddresses,
            .AGAIN => return error.TemporaryNameServerFailure,
            .BADFLAGS => unreachable, // Invalid hints
            .FAIL => return error.NameServerFailure,
            .FAMILY => return error.AddressFamilyNotSupported,
            .MEMORY => return error.OutOfMemory,
            .NODATA => return error.HostLacksNetworkAddresses,
            .NONAME => return error.UnknownHostName,
            .SERVICE => return error.ServiceUnavailable,
            .SOCKTYPE => unreachable, // Invalid socket type requested in hints
            // NOTE(review): errno(-1) reads the thread's current errno value;
            // no specific codes are distinguished here, all map to Unexpected.
            .SYSTEM => switch (posix.errno(-1)) {
                else => |e| return posix.unexpectedErrno(e),
            },
            else => unreachable,
        }
        defer if (res) |some| sys.freeaddrinfo(some);

        // Count, then copy, exactly as in the Windows branch above.
        const addr_count = blk: {
            var count: usize = 0;
            var it = res;
            while (it) |info| : (it = info.next) {
                if (info.addr != null) {
                    count += 1;
                }
            }
            break :blk count;
        };
        result.addrs = try arena.alloc(Address, addr_count);

        var it = res;
        var i: usize = 0;
        while (it) |info| : (it = info.next) {
            const addr = info.addr orelse continue;
            result.addrs[i] = Address.initPosix(@alignCast(addr));

            if (info.canonname) |n| {
                if (result.canon_name == null) {
                    result.canon_name = try arena.dupe(u8, mem.sliceTo(n, 0));
                }
            }
            i += 1;
        }

        return result;
    }

    if (native_os == .linux) {
        // No libc: use the built-in resolver (hosts file, then DNS).
        const flags = std.c.AI.NUMERICSERV;
        const family = posix.AF.UNSPEC;
        var lookup_addrs = std.ArrayList(LookupAddr).init(allocator);
        defer lookup_addrs.deinit();

        var canon = std.ArrayList(u8).init(arena);
        defer canon.deinit();

        try linuxLookupName(&lookup_addrs, &canon, name, family, flags, port);

        result.addrs = try arena.alloc(Address, lookup_addrs.items.len);
        if (canon.items.len != 0) {
            result.canon_name = try canon.toOwnedSlice();
        }

        for (lookup_addrs.items, 0..) |lookup_addr, i| {
            result.addrs[i] = lookup_addr.addr;
            assert(result.addrs[i].getPort() == port);
        }

        return result;
    }
    @compileError("std.net.getAddressList unimplemented for this OS");
}

/// One lookup result plus the RFC 6724 sort key assigned by linuxLookupName.
const LookupAddr = struct {
    addr: Address,
    sortkey: i32 = 0,
};

// Bit fields of the 31-bit destination-address-selection sort key
// (ported from musl; see the RFC 3484/6724 comment in linuxLookupName).
const DAS_USABLE = 0x40000000;
const DAS_MATCHINGSCOPE = 0x20000000;
const DAS_MATCHINGLABEL = 0x10000000;
const DAS_PREC_SHIFT = 20;
const DAS_SCOPE_SHIFT = 16;
const DAS_PREFIX_SHIFT = 8;
const DAS_ORDER_SHIFT = 0;

/// Resolve `opt_name` (or the null/passive addresses when null) into `addrs`,
/// recording the canonical name in `canon`, then sort the results by the
/// RFC 3484/6724 destination-address-selection rules.
fn linuxLookupName(
    addrs: *std.ArrayList(LookupAddr),
    canon: *std.ArrayList(u8),
    opt_name: ?[]const u8,
    family: posix.sa_family_t,
    flags: u32,
    port: u16,
) !void {
    if (opt_name) |name| {
        // reject empty name and check len so it fits into temp bufs
        canon.items.len = 0;
        try canon.appendSlice(name);
        // A numeric literal short-circuits all file/DNS lookups.
        if (Address.parseExpectingFamily(name, family, port)) |addr| {
            try addrs.append(LookupAddr{ .addr = addr });
        } else |name_err| if ((flags & std.c.AI.NUMERICHOST) != 0) {
            return name_err;
        } else {
            try linuxLookupNameFromHosts(addrs, canon, name, family, port);
            if (addrs.items.len == 0) {
                // RFC 6761 Section 6.3.3
                // Name resolution APIs and libraries SHOULD recognize localhost
                // names as special and SHOULD always return the IP loopback address
                // for address queries and negative responses for all other query
                // types.

                // Check for equal to "localhost(.)" or ends in ".localhost(.)"
                const localhost = if (name[name.len - 1] == '.') "localhost." else "localhost";
                if (mem.endsWith(u8, name, localhost) and (name.len == localhost.len or name[name.len - localhost.len] == '.')) {
                    try addrs.append(LookupAddr{ .addr = .{ .in = Ip4Address.parse("127.0.0.1", port) catch unreachable } });
                    try addrs.append(LookupAddr{ .addr = .{ .in6 = Ip6Address.parse("::1", port) catch unreachable } });
                    return;
                }

                try linuxLookupNameFromDnsSearch(addrs, canon, name, family, port);
            }
        }
    } else {
        try canon.resize(0);
        try linuxLookupNameFromNull(addrs, family, flags, port);
    }
    if (addrs.items.len == 0) return error.UnknownHostName;

    // No further processing is needed if there are fewer than 2
    // results or if there are only IPv4 results.
    if (addrs.items.len == 1 or family == posix.AF.INET) return;
    const all_ip4 = for (addrs.items) |addr| {
        if (addr.addr.any.family != posix.AF.INET) break false;
    } else true;
    if (all_ip4) return;

    // The following implements a subset of RFC 3484/6724 destination
    // address selection by generating a single 31-bit sort key for
    // each address. Rules 3, 4, and 7 are omitted for having
    // excessive runtime and code size cost and dubious benefit.
    // So far the label/precedence table cannot be customized.
    // This implementation is ported from musl libc.
    // A more idiomatic "ziggy" implementation would be welcome.
    for (addrs.items, 0..) |*addr, i| {
        var key: i32 = 0;
        var sa6: posix.sockaddr.in6 = undefined;
        @memset(@as([*]u8, @ptrCast(&sa6))[0..@sizeOf(posix.sockaddr.in6)], 0);
        var da6 = posix.sockaddr.in6{
            .family = posix.AF.INET6,
            .scope_id = addr.addr.in6.sa.scope_id,
            .port = 65535,
            .flowinfo = 0,
            .addr = [1]u8{0} ** 16,
        };
        var sa4: posix.sockaddr.in = undefined;
        @memset(@as([*]u8, @ptrCast(&sa4))[0..@sizeOf(posix.sockaddr.in)], 0);
        var da4 = posix.sockaddr.in{
            .family = posix.AF.INET,
            .port = 65535,
            .addr = 0,
            .zero = [1]u8{0} ** 8,
        };
        var sa: *align(4) posix.sockaddr = undefined;
        var da: *align(4) posix.sockaddr = undefined;
        var salen: posix.socklen_t = undefined;
        var dalen: posix.socklen_t = undefined;
        if (addr.addr.any.family == posix.AF.INET6) {
            da6.addr = addr.addr.in6.sa.addr;
            da = @ptrCast(&da6);
            dalen = @sizeOf(posix.sockaddr.in6);
            sa = @ptrCast(&sa6);
            salen = @sizeOf(posix.sockaddr.in6);
        } else {
            // IPv4 destination: evaluate the policy table on its
            // v4-mapped (::ffff:a.b.c.d) form.
            sa6.addr[0..12].* = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xff\xff".*;
            da6.addr[0..12].* = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xff\xff".*;
            mem.writeInt(u32, da6.addr[12..], addr.addr.in.sa.addr, native_endian);
            da4.addr = addr.addr.in.sa.addr;
            da = @ptrCast(&da4);
            dalen = @sizeOf(posix.sockaddr.in);
            sa = @ptrCast(&sa4);
            salen = @sizeOf(posix.sockaddr.in);
        }
        const dpolicy = policyOf(da6.addr);
        const dscope: i32 = scopeOf(da6.addr);
        const dlabel = dpolicy.label;
        const dprec: i32 = dpolicy.prec;
        const MAXADDRS = 3;
        var prefixlen: i32 = 0;
        const sock_flags = posix.SOCK.DGRAM | posix.SOCK.CLOEXEC;
        // Probe with a connected UDP socket (no packets are sent) to learn
        // which source address the kernel would pick for this destination.
        if (posix.socket(addr.addr.any.family, sock_flags, posix.IPPROTO.UDP)) |fd| syscalls: {
            defer Stream.close(.{ .handle = fd });
            posix.connect(fd, da, dalen) catch break :syscalls;
            key |= DAS_USABLE;
            posix.getsockname(fd, sa, &salen) catch break :syscalls;
            if (addr.addr.any.family == posix.AF.INET) {
                mem.writeInt(u32, sa6.addr[12..16], sa4.addr, native_endian);
            }
            if (dscope == @as(i32, scopeOf(sa6.addr))) key |= DAS_MATCHINGSCOPE;
            if (dlabel == labelOf(sa6.addr)) key |= DAS_MATCHINGLABEL;
            prefixlen = prefixMatch(sa6.addr, da6.addr);
        } else |_| {}
        key |= dprec << DAS_PREC_SHIFT;
        key |= (15 - dscope) << DAS_SCOPE_SHIFT;
        key |= prefixlen << DAS_PREFIX_SHIFT;
        key |= (MAXADDRS - @as(i32, @intCast(i))) << DAS_ORDER_SHIFT;
        addr.sortkey = key;
    }
    mem.sort(LookupAddr, addrs.items, {}, addrCmpLessThan);
}

/// One row of the RFC 6724-style label/precedence table: `addr` is the
/// prefix, `len` the number of whole bytes compared, `mask` the bitmask
/// applied to the byte after them.
const Policy = struct {
    addr: [16]u8,
    len: u8,
    mask: u8,
    prec: u8,
    label: u8,
};

const defined_policies = [_]Policy{
    Policy{
        .addr = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01".*,
        .len = 15,
        .mask = 0xff,
        .prec = 50,
        .label = 0,
    },
    Policy{
        .addr = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xff\xff\x00\x00\x00\x00".*,
        .len = 11,
        .mask = 0xff,
        .prec = 35,
        .label = 4,
    },
    Policy{
        .addr = "\x20\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00".*,
        .len = 1,
        .mask = 0xff,
        .prec = 30,
        .label = 2,
    },
    Policy{
        .addr = "\x20\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00".*,
        .len = 3,
        .mask = 0xff,
        .prec = 5,
        .label = 5,
    },
    Policy{
        .addr = "\xfc\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00".*,
        .len = 0,
        .mask = 0xfe,
        .prec = 3,
        .label = 13,
    },
    // These are deprecated and/or returned to the address
    // pool, so despite the RFC, treating them as special
    // is probably wrong.
    // { "", 11, 0xff, 1, 3 },
    // { "\xfe\xc0", 1, 0xc0, 1, 11 },
    // { "\x3f\xfe", 1, 0xff, 1, 12 },
    // Last rule must match all addresses to stop loop.
    Policy{
        .addr = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00".*,
        .len = 0,
        .mask = 0,
        .prec = 40,
        .label = 1,
    },
};

/// Return the first policy-table row matching address `a`.
/// The final catch-all row guarantees a match.
fn policyOf(a: [16]u8) *const Policy {
    for (&defined_policies) |*policy| {
        if (!mem.eql(u8, a[0..policy.len], policy.addr[0..policy.len])) continue;
        if ((a[policy.len] & policy.mask) != policy.addr[policy.len]) continue;
        return policy;
    }
    unreachable;
}

/// Scope value of `a` for the sort key (multicast uses its own scope
/// nibble; link-local and loopback are 2, site-local 5, global 14).
fn scopeOf(a: [16]u8) u8 {
    if (IN6_IS_ADDR_MULTICAST(a)) return a[1] & 15;
    if (IN6_IS_ADDR_LINKLOCAL(a)) return 2;
    if (IN6_IS_ADDR_LOOPBACK(a)) return 2;
    if (IN6_IS_ADDR_SITELOCAL(a)) return 5;
    return 14;
}

/// Length in bits of the common prefix of `s` and `d` (0-128).
fn prefixMatch(s: [16]u8, d: [16]u8) u8 {
    // TODO: This FIXME inherited from porting from musl libc.
    // I don't want this to go into zig std lib 1.0.0.

    // FIXME: The common prefix length should be limited to no greater
    // than the nominal length of the prefix portion of the source
    // address. However the definition of the source prefix length is
    // not clear and thus this limiting is not yet implemented.
    var i: u8 = 0;
    while (i < 128 and ((s[i / 8] ^ d[i / 8]) & (@as(u8, 128) >> @as(u3, @intCast(i % 8)))) == 0) : (i += 1) {}
    return i;
}

/// Label of `a` from the policy table.
fn labelOf(a: [16]u8) u8 {
    return policyOf(a).label;
}

fn IN6_IS_ADDR_MULTICAST(a: [16]u8) bool {
    return a[0] == 0xff;
}

fn IN6_IS_ADDR_LINKLOCAL(a: [16]u8) bool {
    return a[0] == 0xfe and (a[1] & 0xc0) == 0x80;
}

fn IN6_IS_ADDR_LOOPBACK(a: [16]u8) bool {
    return a[0] == 0 and a[1] == 0 and
        a[2] == 0 and
        a[12] == 0 and a[13] == 0 and
        a[14] == 0 and a[15] == 1;
}

fn IN6_IS_ADDR_SITELOCAL(a: [16]u8) bool {
    return a[0] == 0xfe and (a[1] & 0xc0) == 0xc0;
}

// Parameters `b` and `a` swapped to make this descending.
+fn addrCmpLessThan(context: void, b: LookupAddr, a: LookupAddr) bool { + _ = context; + return a.sortkey < b.sortkey; +} + +fn linuxLookupNameFromNull( + addrs: *std.ArrayList(LookupAddr), + family: posix.sa_family_t, + flags: u32, + port: u16, +) !void { + if ((flags & std.c.AI.PASSIVE) != 0) { + if (family != posix.AF.INET6) { + (try addrs.addOne()).* = LookupAddr{ + .addr = Address.initIp4([1]u8{0} ** 4, port), + }; + } + if (family != posix.AF.INET) { + (try addrs.addOne()).* = LookupAddr{ + .addr = Address.initIp6([1]u8{0} ** 16, port, 0, 0), + }; + } + } else { + if (family != posix.AF.INET6) { + (try addrs.addOne()).* = LookupAddr{ + .addr = Address.initIp4([4]u8{ 127, 0, 0, 1 }, port), + }; + } + if (family != posix.AF.INET) { + (try addrs.addOne()).* = LookupAddr{ + .addr = Address.initIp6(([1]u8{0} ** 15) ++ [1]u8{1}, port, 0, 0), + }; + } + } +} + +fn linuxLookupNameFromHosts( + addrs: *std.ArrayList(LookupAddr), + canon: *std.ArrayList(u8), + name: []const u8, + family: posix.sa_family_t, + port: u16, +) !void { + const file = fs.openFileAbsoluteZ("/etc/hosts", .{}) catch |err| switch (err) { + error.FileNotFound, + error.NotDir, + error.AccessDenied, + => return, + else => |e| return e, + }; + defer file.close(); + + var buffered_reader = std.io.bufferedReader(file.reader()); + const reader = buffered_reader.reader(); + var line_buf: [512]u8 = undefined; + while (reader.readUntilDelimiterOrEof(&line_buf, '\n') catch |err| switch (err) { + error.StreamTooLong => blk: { + // Skip to the delimiter in the reader, to fix parsing + try reader.skipUntilDelimiterOrEof('\n'); + // Use the truncated line. A truncated comment or hostname will be handled correctly. 
+ break :blk &line_buf; + }, + else => |e| return e, + }) |line| { + var split_it = mem.splitScalar(u8, line, '#'); + const no_comment_line = split_it.first(); + + var line_it = mem.tokenizeAny(u8, no_comment_line, " \t"); + const ip_text = line_it.next() orelse continue; + var first_name_text: ?[]const u8 = null; + while (line_it.next()) |name_text| { + if (first_name_text == null) first_name_text = name_text; + if (mem.eql(u8, name_text, name)) { + break; + } + } else continue; + + const addr = Address.parseExpectingFamily(ip_text, family, port) catch |err| switch (err) { + error.Overflow, + error.InvalidEnd, + error.InvalidCharacter, + error.Incomplete, + error.InvalidIPAddressFormat, + error.InvalidIpv4Mapping, + error.NonCanonical, + => continue, + }; + try addrs.append(LookupAddr{ .addr = addr }); + + // first name is canonical name + const name_text = first_name_text.?; + if (isValidHostName(name_text)) { + canon.items.len = 0; + try canon.appendSlice(name_text); + } + } +} + +pub fn isValidHostName(hostname: []const u8) bool { + if (hostname.len >= 254) return false; + if (!std.unicode.utf8ValidateSlice(hostname)) return false; + for (hostname) |byte| { + if (!std.ascii.isASCII(byte) or byte == '.' or byte == '-' or std.ascii.isAlphanumeric(byte)) { + continue; + } + return false; + } + return true; +} + +fn linuxLookupNameFromDnsSearch( + addrs: *std.ArrayList(LookupAddr), + canon: *std.ArrayList(u8), + name: []const u8, + family: posix.sa_family_t, + port: u16, +) !void { + var rc: ResolvConf = undefined; + try getResolvConf(addrs.allocator, &rc); + defer rc.deinit(); + + // Count dots, suppress search when >=ndots or name ends in + // a dot, which is an explicit request for global scope. 
+ var dots: usize = 0; + for (name) |byte| { + if (byte == '.') dots += 1; + } + + const search = if (dots >= rc.ndots or mem.endsWith(u8, name, ".")) + "" + else + rc.search.items; + + var canon_name = name; + + // Strip final dot for canon, fail if multiple trailing dots. + if (mem.endsWith(u8, canon_name, ".")) canon_name.len -= 1; + if (mem.endsWith(u8, canon_name, ".")) return error.UnknownHostName; + + // Name with search domain appended is setup in canon[]. This both + // provides the desired default canonical name (if the requested + // name is not a CNAME record) and serves as a buffer for passing + // the full requested name to name_from_dns. + try canon.resize(canon_name.len); + @memcpy(canon.items, canon_name); + try canon.append('.'); + + var tok_it = mem.tokenizeAny(u8, search, " \t"); + while (tok_it.next()) |tok| { + canon.shrinkRetainingCapacity(canon_name.len + 1); + try canon.appendSlice(tok); + try linuxLookupNameFromDns(addrs, canon, canon.items, family, rc, port); + if (addrs.items.len != 0) return; + } + + canon.shrinkRetainingCapacity(canon_name.len); + return linuxLookupNameFromDns(addrs, canon, name, family, rc, port); +} + +const dpc_ctx = struct { + addrs: *std.ArrayList(LookupAddr), + canon: *std.ArrayList(u8), + port: u16, +}; + +fn linuxLookupNameFromDns( + addrs: *std.ArrayList(LookupAddr), + canon: *std.ArrayList(u8), + name: []const u8, + family: posix.sa_family_t, + rc: ResolvConf, + port: u16, +) !void { + const ctx = dpc_ctx{ + .addrs = addrs, + .canon = canon, + .port = port, + }; + const AfRr = struct { + af: posix.sa_family_t, + rr: u8, + }; + const afrrs = [_]AfRr{ + AfRr{ .af = posix.AF.INET6, .rr = posix.RR.A }, + AfRr{ .af = posix.AF.INET, .rr = posix.RR.AAAA }, + }; + var qbuf: [2][280]u8 = undefined; + var abuf: [2][512]u8 = undefined; + var qp: [2][]const u8 = undefined; + const apbuf = [2][]u8{ &abuf[0], &abuf[1] }; + var nq: usize = 0; + + for (afrrs) |afrr| { + if (family != afrr.af) { + const len = 
posix.res_mkquery(0, name, 1, afrr.rr, &[_]u8{}, null, &qbuf[nq]); + qp[nq] = qbuf[nq][0..len]; + nq += 1; + } + } + + var ap = [2][]u8{ apbuf[0], apbuf[1] }; + ap[0].len = 0; + ap[1].len = 0; + + try resMSendRc(qp[0..nq], ap[0..nq], apbuf[0..nq], rc); + + var i: usize = 0; + while (i < nq) : (i += 1) { + dnsParse(ap[i], ctx, dnsParseCallback) catch {}; + } + + if (addrs.items.len != 0) return; + if (ap[0].len < 4 or (ap[0][3] & 15) == 2) return error.TemporaryNameServerFailure; + if ((ap[0][3] & 15) == 0) return error.UnknownHostName; + if ((ap[0][3] & 15) == 3) return; + return error.NameServerFailure; +} + +const ResolvConf = struct { + attempts: u32, + ndots: u32, + timeout: u32, + search: std.ArrayList(u8), + ns: std.ArrayList(LookupAddr), + + fn deinit(rc: *ResolvConf) void { + rc.ns.deinit(); + rc.search.deinit(); + rc.* = undefined; + } +}; + +/// Ignores lines longer than 512 bytes. +/// TODO: https://github.com/ziglang/zig/issues/2765 and https://github.com/ziglang/zig/issues/2761 +fn getResolvConf(allocator: mem.Allocator, rc: *ResolvConf) !void { + rc.* = ResolvConf{ + .ns = std.ArrayList(LookupAddr).init(allocator), + .search = std.ArrayList(u8).init(allocator), + .ndots = 1, + .timeout = 5, + .attempts = 2, + }; + errdefer rc.deinit(); + + const file = fs.openFileAbsoluteZ("/etc/resolv.conf", .{}) catch |err| switch (err) { + error.FileNotFound, + error.NotDir, + error.AccessDenied, + => return linuxLookupNameFromNumericUnspec(&rc.ns, "127.0.0.1", 53), + else => |e| return e, + }; + defer file.close(); + + var buf_reader = std.io.bufferedReader(file.reader()); + const stream = buf_reader.reader(); + var line_buf: [512]u8 = undefined; + while (stream.readUntilDelimiterOrEof(&line_buf, '\n') catch |err| switch (err) { + error.StreamTooLong => blk: { + // Skip to the delimiter in the stream, to fix parsing + try stream.skipUntilDelimiterOrEof('\n'); + // Give an empty line to the while loop, which will be skipped. 
+ break :blk line_buf[0..0]; + }, + else => |e| return e, + }) |line| { + const no_comment_line = no_comment_line: { + var split = mem.splitScalar(u8, line, '#'); + break :no_comment_line split.first(); + }; + var line_it = mem.tokenizeAny(u8, no_comment_line, " \t"); + + const token = line_it.next() orelse continue; + if (mem.eql(u8, token, "options")) { + while (line_it.next()) |sub_tok| { + var colon_it = mem.splitScalar(u8, sub_tok, ':'); + const name = colon_it.first(); + const value_txt = colon_it.next() orelse continue; + const value = std.fmt.parseInt(u8, value_txt, 10) catch |err| switch (err) { + // TODO https://github.com/ziglang/zig/issues/11812 + error.Overflow => @as(u8, 255), + error.InvalidCharacter => continue, + }; + if (mem.eql(u8, name, "ndots")) { + rc.ndots = @min(value, 15); + } else if (mem.eql(u8, name, "attempts")) { + rc.attempts = @min(value, 10); + } else if (mem.eql(u8, name, "timeout")) { + rc.timeout = @min(value, 60); + } + } + } else if (mem.eql(u8, token, "nameserver")) { + const ip_txt = line_it.next() orelse continue; + try linuxLookupNameFromNumericUnspec(&rc.ns, ip_txt, 53); + } else if (mem.eql(u8, token, "domain") or mem.eql(u8, token, "search")) { + rc.search.items.len = 0; + try rc.search.appendSlice(line_it.rest()); + } + } + + if (rc.ns.items.len == 0) { + return linuxLookupNameFromNumericUnspec(&rc.ns, "127.0.0.1", 53); + } +} + +fn linuxLookupNameFromNumericUnspec( + addrs: *std.ArrayList(LookupAddr), + name: []const u8, + port: u16, +) !void { + const addr = try Address.resolveIp(name, port); + (try addrs.addOne()).* = LookupAddr{ .addr = addr }; +} + +fn resMSendRc( + queries: []const []const u8, + answers: [][]u8, + answer_bufs: []const []u8, + rc: ResolvConf, +) !void { + const timeout = 1000 * rc.timeout; + const attempts = rc.attempts; + + var sl: posix.socklen_t = @sizeOf(posix.sockaddr.in); + var family: posix.sa_family_t = posix.AF.INET; + + var ns_list = std.ArrayList(Address).init(rc.ns.allocator); + defer 
ns_list.deinit(); + + try ns_list.resize(rc.ns.items.len); + const ns = ns_list.items; + + for (rc.ns.items, 0..) |iplit, i| { + ns[i] = iplit.addr; + assert(ns[i].getPort() == 53); + if (iplit.addr.any.family != posix.AF.INET) { + family = posix.AF.INET6; + } + } + + const flags = posix.SOCK.DGRAM | posix.SOCK.CLOEXEC | posix.SOCK.NONBLOCK; + const fd = posix.socket(family, flags, 0) catch |err| switch (err) { + error.AddressFamilyNotSupported => blk: { + // Handle case where system lacks IPv6 support + if (family == posix.AF.INET6) { + family = posix.AF.INET; + break :blk try posix.socket(posix.AF.INET, flags, 0); + } + return err; + }, + else => |e| return e, + }; + defer Stream.close(.{ .handle = fd }); + + // Past this point, there are no errors. Each individual query will + // yield either no reply (indicated by zero length) or an answer + // packet which is up to the caller to interpret. + + // Convert any IPv4 addresses in a mixed environment to v4-mapped + if (family == posix.AF.INET6) { + try posix.setsockopt( + fd, + posix.SOL.IPV6, + std.os.linux.IPV6.V6ONLY, + &mem.toBytes(@as(c_int, 0)), + ); + for (0..ns.len) |i| { + if (ns[i].any.family != posix.AF.INET) continue; + mem.writeInt(u32, ns[i].in6.sa.addr[12..], ns[i].in.sa.addr, native_endian); + ns[i].in6.sa.addr[0..12].* = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xff\xff".*; + ns[i].any.family = posix.AF.INET6; + ns[i].in6.sa.flowinfo = 0; + ns[i].in6.sa.scope_id = 0; + } + sl = @sizeOf(posix.sockaddr.in6); + } + + // Get local address and open/bind a socket + var sa: Address = undefined; + @memset(@as([*]u8, @ptrCast(&sa))[0..@sizeOf(Address)], 0); + sa.any.family = family; + try posix.bind(fd, &sa.any, sl); + + var pfd = [1]posix.pollfd{posix.pollfd{ + .fd = fd, + .events = posix.POLL.IN, + .revents = undefined, + }}; + const retry_interval = timeout / attempts; + var next: u32 = 0; + var t2: u64 = @bitCast(std.time.milliTimestamp()); + const t0 = t2; + var t1 = t2 - retry_interval; + + var 
servfail_retry: usize = undefined; + + outer: while (t2 - t0 < timeout) : (t2 = @as(u64, @bitCast(std.time.milliTimestamp()))) { + if (t2 - t1 >= retry_interval) { + // Query all configured nameservers in parallel + var i: usize = 0; + while (i < queries.len) : (i += 1) { + if (answers[i].len == 0) { + var j: usize = 0; + while (j < ns.len) : (j += 1) { + _ = posix.sendto(fd, queries[i], posix.MSG.NOSIGNAL, &ns[j].any, sl) catch undefined; + } + } + } + t1 = t2; + servfail_retry = 2 * queries.len; + } + + // Wait for a response, or until time to retry + const clamped_timeout = @min(@as(u31, std.math.maxInt(u31)), t1 + retry_interval - t2); + const nevents = posix.poll(&pfd, clamped_timeout) catch 0; + if (nevents == 0) continue; + + while (true) { + var sl_copy = sl; + const rlen = posix.recvfrom(fd, answer_bufs[next], 0, &sa.any, &sl_copy) catch break; + + // Ignore non-identifiable packets + if (rlen < 4) continue; + + // Ignore replies from addresses we didn't send to + var j: usize = 0; + while (j < ns.len and !ns[j].eql(sa)) : (j += 1) {} + if (j == ns.len) continue; + + // Find which query this answer goes with, if any + var i: usize = next; + while (i < queries.len and (answer_bufs[next][0] != queries[i][0] or + answer_bufs[next][1] != queries[i][1])) : (i += 1) + {} + + if (i == queries.len) continue; + if (answers[i].len != 0) continue; + + // Only accept positive or negative responses; + // retry immediately on server failure, and ignore + // all other codes such as refusal. + switch (answer_bufs[next][3] & 15) { + 0, 3 => {}, + 2 => if (servfail_retry != 0) { + servfail_retry -= 1; + _ = posix.sendto(fd, queries[i], posix.MSG.NOSIGNAL, &ns[j].any, sl) catch undefined; + }, + else => continue, + } + + // Store answer in the right slot, or update next + // available temp slot if it's already in place. 
+ answers[i].len = rlen; + if (i == next) { + while (next < queries.len and answers[next].len != 0) : (next += 1) {} + } else { + @memcpy(answer_bufs[i][0..rlen], answer_bufs[next][0..rlen]); + } + + if (next == queries.len) break :outer; + } + } +} + +fn dnsParse( + r: []const u8, + ctx: anytype, + comptime callback: anytype, +) !void { + // This implementation is ported from musl libc. + // A more idiomatic "ziggy" implementation would be welcome. + if (r.len < 12) return error.InvalidDnsPacket; + if ((r[3] & 15) != 0) return; + var p = r.ptr + 12; + var qdcount = r[4] * @as(usize, 256) + r[5]; + var ancount = r[6] * @as(usize, 256) + r[7]; + if (qdcount + ancount > 64) return error.InvalidDnsPacket; + while (qdcount != 0) { + qdcount -= 1; + while (@intFromPtr(p) - @intFromPtr(r.ptr) < r.len and p[0] -% 1 < 127) p += 1; + if (p[0] > 193 or (p[0] == 193 and p[1] > 254) or @intFromPtr(p) > @intFromPtr(r.ptr) + r.len - 6) + return error.InvalidDnsPacket; + p += @as(usize, 5) + @intFromBool(p[0] != 0); + } + while (ancount != 0) { + ancount -= 1; + while (@intFromPtr(p) - @intFromPtr(r.ptr) < r.len and p[0] -% 1 < 127) p += 1; + if (p[0] > 193 or (p[0] == 193 and p[1] > 254) or @intFromPtr(p) > @intFromPtr(r.ptr) + r.len - 6) + return error.InvalidDnsPacket; + p += @as(usize, 1) + @intFromBool(p[0] != 0); + const len = p[8] * @as(usize, 256) + p[9]; + if (@intFromPtr(p) + len > @intFromPtr(r.ptr) + r.len) return error.InvalidDnsPacket; + try callback(ctx, p[1], p[10..][0..len], r); + p += 10 + len; + } +} + +fn dnsParseCallback(ctx: dpc_ctx, rr: u8, data: []const u8, packet: []const u8) !void { + switch (rr) { + posix.RR.A => { + if (data.len != 4) return error.InvalidDnsARecord; + const new_addr = try ctx.addrs.addOne(); + new_addr.* = LookupAddr{ + .addr = Address.initIp4(data[0..4].*, ctx.port), + }; + }, + posix.RR.AAAA => { + if (data.len != 16) return error.InvalidDnsAAAARecord; + const new_addr = try ctx.addrs.addOne(); + new_addr.* = LookupAddr{ + .addr = 
Address.initIp6(data[0..16].*, ctx.port, 0, 0), + }; + }, + posix.RR.CNAME => { + var tmp: [256]u8 = undefined; + // Returns len of compressed name. strlen to get canon name. + _ = try posix.dn_expand(packet, data, &tmp); + const canon_name = mem.sliceTo(&tmp, 0); + if (isValidHostName(canon_name)) { + ctx.canon.items.len = 0; + try ctx.canon.appendSlice(canon_name); + } + }, + else => return, + } +} + +pub const Stream = struct { + /// Underlying platform-defined type which may or may not be + /// interchangeable with a file system file descriptor. + handle: posix.socket_t, + + pub fn close(s: Stream) void { + switch (native_os) { + .windows => windows.closesocket(s.handle) catch unreachable, + else => posix.close(s.handle), + } + } + + pub const ReadError = posix.ReadError; + pub const WriteError = posix.WriteError; + + pub const Reader = io.Reader(Stream, ReadError, read); + pub const Writer = io.Writer(Stream, WriteError, write); + + pub fn reader(self: Stream) Reader { + return .{ .context = self }; + } + + pub fn writer(self: Stream) Writer { + return .{ .context = self }; + } + + pub fn read(self: Stream, buffer: []u8) ReadError!usize { + if (native_os == .windows) { + return windows.ReadFile(self.handle, buffer, null); + } + + return posix.read(self.handle, buffer); + } + + pub fn readv(s: Stream, iovecs: []const posix.iovec) ReadError!usize { + if (native_os == .windows) { + // TODO improve this to use ReadFileScatter + if (iovecs.len == 0) return @as(usize, 0); + const first = iovecs[0]; + return windows.ReadFile(s.handle, first.base[0..first.len], null); + } + + return posix.readv(s.handle, iovecs); + } + + /// Returns the number of bytes read. If the number read is smaller than + /// `buffer.len`, it means the stream reached the end. Reaching the end of + /// a stream is not an error condition. 
+ pub fn readAll(s: Stream, buffer: []u8) ReadError!usize { + return readAtLeast(s, buffer, buffer.len); + } + + /// Returns the number of bytes read, calling the underlying read function + /// the minimal number of times until the buffer has at least `len` bytes + /// filled. If the number read is less than `len` it means the stream + /// reached the end. Reaching the end of the stream is not an error + /// condition. + pub fn readAtLeast(s: Stream, buffer: []u8, len: usize) ReadError!usize { + assert(len <= buffer.len); + var index: usize = 0; + while (index < len) { + const amt = try s.read(buffer[index..]); + if (amt == 0) break; + index += amt; + } + return index; + } + + /// TODO in evented I/O mode, this implementation incorrectly uses the event loop's + /// file system thread instead of non-blocking. It needs to be reworked to properly + /// use non-blocking I/O. + pub fn write(self: Stream, buffer: []const u8) WriteError!usize { + if (native_os == .windows) { + return windows.WriteFile(self.handle, buffer, null); + } + + return posix.write(self.handle, buffer); + } + + pub fn writeAll(self: Stream, bytes: []const u8) WriteError!void { + var index: usize = 0; + while (index < bytes.len) { + index += try self.write(bytes[index..]); + } + } + + /// See https://github.com/ziglang/zig/issues/7699 + /// See equivalent function: `std.fs.File.writev`. + pub fn writev(self: Stream, iovecs: []const posix.iovec_const) WriteError!usize { + return posix.writev(self.handle, iovecs); + } + + /// The `iovecs` parameter is mutable because this function needs to mutate the fields in + /// order to handle partial writes from the underlying OS layer. + /// See https://github.com/ziglang/zig/issues/7699 + /// See equivalent function: `std.fs.File.writevAll`. 
+ pub fn writevAll(self: Stream, iovecs: []posix.iovec_const) WriteError!void { + if (iovecs.len == 0) return; + + var i: usize = 0; + while (true) { + var amt = try self.writev(iovecs[i..]); + while (amt >= iovecs[i].len) { + amt -= iovecs[i].len; + i += 1; + if (i >= iovecs.len) return; + } + iovecs[i].base += amt; + iovecs[i].len -= amt; + } + } + + pub fn async_read( + self: Stream, + buffer: []u8, + ctx: *Ctx, + comptime cbk: Cbk, + ) !void { + return ctx.loop.recv(Ctx, ctx, cbk, self.handle, buffer); + } + + pub fn async_readv( + s: Stream, + iovecs: []const posix.iovec, + ctx: *Ctx, + comptime cbk: Cbk, + ) ReadError!void { + if (iovecs.len == 0) return; + const first_buffer = iovecs[0].base[0..iovecs[0].len]; + return s.async_read(first_buffer, ctx, cbk); + } + + // TODO: why not take a buffer here? + pub fn async_write(self: Stream, buffer: []const u8, ctx: *Ctx, comptime cbk: Cbk) void { + return ctx.loop.send(Ctx, ctx, cbk, self.handle, buffer); + } + + fn onWriteAll(ctx: *Ctx, res: anyerror!void) anyerror!void { + res catch |err| return ctx.pop(err); + if (ctx.len() < ctx.buf().len) { + const new_buf = ctx.buf()[ctx.len()..]; + ctx.setBuf(new_buf); + return ctx.stream().async_write(new_buf, ctx, onWriteAll); + } + ctx.setBuf(null); + return ctx.pop({}); + } + + pub fn async_writeAll(self: Stream, bytes: []const u8, ctx: *Ctx, comptime cbk: Cbk) !void { + ctx.setBuf(bytes); + try ctx.push(cbk); + self.async_write(bytes, ctx, onWriteAll); + } +}; + +pub const Server = struct { + listen_address: Address, + stream: std.net.Stream, + + pub const Connection = struct { + stream: std.net.Stream, + address: Address, + }; + + pub fn deinit(s: *Server) void { + s.stream.close(); + s.* = undefined; + } + + pub const AcceptError = posix.AcceptError; + + /// Blocks until a client connects to the server. The returned `Connection` has + /// an open stream. 
+    pub fn accept(s: *Server) AcceptError!Connection {
+        var accepted_addr: Address = undefined;
+        var addr_len: posix.socklen_t = @sizeOf(Address);
+        const fd = try posix.accept(s.stream.handle, &accepted_addr.any, &addr_len, posix.SOCK.CLOEXEC);
+        return .{
+            .stream = .{ .handle = fd },
+            .address = accepted_addr,
+        };
+    }
+};
+
+test {
+    _ = @import("net/test.zig");
+    _ = Server;
+    _ = Stream;
+    _ = Address;
+}
+
+fn onTcpConnectToHost(ctx: *Ctx, res: anyerror!void) anyerror!void {
+    res catch |e| switch (e) {
+        error.ConnectionRefused => {
+            if (ctx.data.addr_current + 1 < ctx.data.list.addrs.len) {
+                // next iteration of addr: addr_current is the index just tried
+                ctx.push(onTcpConnectToHost) catch |er| return ctx.pop(er);
+                ctx.data.addr_current += 1;
+                return async_tcpConnectToAddress(
+                    ctx.data.list.addrs[ctx.data.addr_current],
+                    ctx,
+                    onTcpConnectToHost,
+                );
+            }
+            // end of iteration of addr
+            ctx.data.list.deinit();
+            return ctx.pop(e);
+        },
+        else => {
+            ctx.data.list.deinit();
+            return ctx.pop(std.posix.ConnectError.ConnectionRefused);
+        },
+    };
+    // success
+    ctx.data.list.deinit();
+    return ctx.pop({});
+}
+
+pub fn async_tcpConnectToHost(
+    allocator: mem.Allocator,
+    name: []const u8,
+    port: u16,
+    ctx: *Ctx,
+    comptime cbk: Cbk,
+) !void {
+    const list = std.net.getAddressList(allocator, name, port) catch |e| return ctx.pop(e);
+    if (list.addrs.len == 0) return ctx.pop(error.UnknownHostName);
+
+    ctx.push(cbk) catch |e| return ctx.pop(e);
+    ctx.data.list = list;
+    ctx.data.addr_current = 0;
+    return async_tcpConnectToAddress(list.addrs[0], ctx, onTcpConnectToHost);
+}
+
+pub fn async_tcpConnectToAddress(address: std.net.Address, ctx: *Ctx, comptime cbk: Cbk) !void {
+    const nonblock = 0;
+    const sock_flags = posix.SOCK.STREAM | nonblock |
+        (if (native_os == .windows) 0 else posix.SOCK.CLOEXEC);
+    const sockfd = try posix.socket(address.any.family, sock_flags, posix.IPPROTO.TCP);
+
+    ctx.data.socket = sockfd;
+    ctx.push(cbk) catch |e| return ctx.pop(e);
+
+    ctx.loop.connect(
+        Ctx,
+        
ctx, + setStream, + sockfd, + address, + ); +} + +// requires client.data.socket to be set +fn setStream(ctx: *Ctx, res: anyerror!void) anyerror!void { + res catch |e| return ctx.pop(e); + ctx.data.conn.stream = .{ .handle = ctx.data.socket }; + return ctx.pop({}); +} diff --git a/src/http/async/std/net/test.zig b/src/http/async/std/net/test.zig new file mode 100644 index 00000000..3e316c54 --- /dev/null +++ b/src/http/async/std/net/test.zig @@ -0,0 +1,335 @@ +const std = @import("std"); +const builtin = @import("builtin"); +const net = std.net; +const mem = std.mem; +const testing = std.testing; + +test "parse and render IP addresses at comptime" { + if (builtin.os.tag == .wasi) return error.SkipZigTest; + comptime { + var ipAddrBuffer: [16]u8 = undefined; + // Parses IPv6 at comptime + const ipv6addr = net.Address.parseIp("::1", 0) catch unreachable; + var ipv6 = std.fmt.bufPrint(ipAddrBuffer[0..], "{}", .{ipv6addr}) catch unreachable; + try std.testing.expect(std.mem.eql(u8, "::1", ipv6[1 .. ipv6.len - 3])); + + // Parses IPv4 at comptime + const ipv4addr = net.Address.parseIp("127.0.0.1", 0) catch unreachable; + var ipv4 = std.fmt.bufPrint(ipAddrBuffer[0..], "{}", .{ipv4addr}) catch unreachable; + try std.testing.expect(std.mem.eql(u8, "127.0.0.1", ipv4[0 .. 
ipv4.len - 2])); + + // Returns error for invalid IP addresses at comptime + try testing.expectError(error.InvalidIPAddressFormat, net.Address.parseIp("::123.123.123.123", 0)); + try testing.expectError(error.InvalidIPAddressFormat, net.Address.parseIp("127.01.0.1", 0)); + try testing.expectError(error.InvalidIPAddressFormat, net.Address.resolveIp("::123.123.123.123", 0)); + try testing.expectError(error.InvalidIPAddressFormat, net.Address.resolveIp("127.01.0.1", 0)); + } +} + +test "parse and render IPv6 addresses" { + if (builtin.os.tag == .wasi) return error.SkipZigTest; + + var buffer: [100]u8 = undefined; + const ips = [_][]const u8{ + "FF01:0:0:0:0:0:0:FB", + "FF01::Fb", + "::1", + "::", + "1::", + "2001:db8::", + "::1234:5678", + "2001:db8::1234:5678", + "FF01::FB%1234", + "::ffff:123.5.123.5", + }; + const printed = [_][]const u8{ + "ff01::fb", + "ff01::fb", + "::1", + "::", + "1::", + "2001:db8::", + "::1234:5678", + "2001:db8::1234:5678", + "ff01::fb", + "::ffff:123.5.123.5", + }; + for (ips, 0..) |ip, i| { + const addr = net.Address.parseIp6(ip, 0) catch unreachable; + var newIp = std.fmt.bufPrint(buffer[0..], "{}", .{addr}) catch unreachable; + try std.testing.expect(std.mem.eql(u8, printed[i], newIp[1 .. newIp.len - 3])); + + if (builtin.os.tag == .linux) { + const addr_via_resolve = net.Address.resolveIp6(ip, 0) catch unreachable; + var newResolvedIp = std.fmt.bufPrint(buffer[0..], "{}", .{addr_via_resolve}) catch unreachable; + try std.testing.expect(std.mem.eql(u8, printed[i], newResolvedIp[1 .. 
newResolvedIp.len - 3])); + } + } + + try testing.expectError(error.InvalidCharacter, net.Address.parseIp6(":::", 0)); + try testing.expectError(error.Overflow, net.Address.parseIp6("FF001::FB", 0)); + try testing.expectError(error.InvalidCharacter, net.Address.parseIp6("FF01::Fb:zig", 0)); + try testing.expectError(error.InvalidEnd, net.Address.parseIp6("FF01:0:0:0:0:0:0:FB:", 0)); + try testing.expectError(error.Incomplete, net.Address.parseIp6("FF01:", 0)); + try testing.expectError(error.InvalidIpv4Mapping, net.Address.parseIp6("::123.123.123.123", 0)); + try testing.expectError(error.Incomplete, net.Address.parseIp6("1", 0)); + // TODO Make this test pass on other operating systems. + if (builtin.os.tag == .linux or comptime builtin.os.tag.isDarwin()) { + try testing.expectError(error.Incomplete, net.Address.resolveIp6("ff01::fb%", 0)); + try testing.expectError(error.Overflow, net.Address.resolveIp6("ff01::fb%wlp3s0s0s0s0s0s0s0s0", 0)); + try testing.expectError(error.Overflow, net.Address.resolveIp6("ff01::fb%12345678901234", 0)); + } +} + +test "invalid but parseable IPv6 scope ids" { + if (builtin.os.tag != .linux and comptime !builtin.os.tag.isDarwin()) { + // Currently, resolveIp6 with alphanumerical scope IDs only works on Linux. + // TODO Make this test pass on other operating systems. + return error.SkipZigTest; + } + + try testing.expectError(error.InterfaceNotFound, net.Address.resolveIp6("ff01::fb%123s45678901234", 0)); +} + +test "parse and render IPv4 addresses" { + if (builtin.os.tag == .wasi) return error.SkipZigTest; + + var buffer: [18]u8 = undefined; + for ([_][]const u8{ + "0.0.0.0", + "255.255.255.255", + "1.2.3.4", + "123.255.0.91", + "127.0.0.1", + }) |ip| { + const addr = net.Address.parseIp4(ip, 0) catch unreachable; + var newIp = std.fmt.bufPrint(buffer[0..], "{}", .{addr}) catch unreachable; + try std.testing.expect(std.mem.eql(u8, ip, newIp[0 .. 
newIp.len - 2])); + } + + try testing.expectError(error.Overflow, net.Address.parseIp4("256.0.0.1", 0)); + try testing.expectError(error.InvalidCharacter, net.Address.parseIp4("x.0.0.1", 0)); + try testing.expectError(error.InvalidEnd, net.Address.parseIp4("127.0.0.1.1", 0)); + try testing.expectError(error.Incomplete, net.Address.parseIp4("127.0.0.", 0)); + try testing.expectError(error.InvalidCharacter, net.Address.parseIp4("100..0.1", 0)); + try testing.expectError(error.NonCanonical, net.Address.parseIp4("127.01.0.1", 0)); +} + +test "parse and render UNIX addresses" { + if (builtin.os.tag == .wasi) return error.SkipZigTest; + if (!net.has_unix_sockets) return error.SkipZigTest; + + var buffer: [14]u8 = undefined; + const addr = net.Address.initUnix("/tmp/testpath") catch unreachable; + const fmt_addr = std.fmt.bufPrint(buffer[0..], "{}", .{addr}) catch unreachable; + try std.testing.expectEqualSlices(u8, "/tmp/testpath", fmt_addr); + + const too_long = [_]u8{'a'} ** 200; + try testing.expectError(error.NameTooLong, net.Address.initUnix(too_long[0..])); +} + +test "resolve DNS" { + if (builtin.os.tag == .wasi) return error.SkipZigTest; + + if (builtin.os.tag == .windows) { + _ = try std.os.windows.WSAStartup(2, 2); + } + defer { + if (builtin.os.tag == .windows) { + std.os.windows.WSACleanup() catch unreachable; + } + } + + // Resolve localhost, this should not fail. + { + const localhost_v4 = try net.Address.parseIp("127.0.0.1", 80); + const localhost_v6 = try net.Address.parseIp("::2", 80); + + const result = try net.getAddressList(testing.allocator, "localhost", 80); + defer result.deinit(); + for (result.addrs) |addr| { + if (addr.eql(localhost_v4) or addr.eql(localhost_v6)) break; + } else @panic("unexpected address for localhost"); + } + + { + // The tests are required to work even when there is no Internet connection, + // so some of these errors we must accept and skip the test. 
+ const result = net.getAddressList(testing.allocator, "example.com", 80) catch |err| switch (err) { + error.UnknownHostName => return error.SkipZigTest, + error.TemporaryNameServerFailure => return error.SkipZigTest, + else => return err, + }; + result.deinit(); + } +} + +test "listen on a port, send bytes, receive bytes" { + if (builtin.single_threaded) return error.SkipZigTest; + if (builtin.os.tag == .wasi) return error.SkipZigTest; + + if (builtin.os.tag == .windows) { + _ = try std.os.windows.WSAStartup(2, 2); + } + defer { + if (builtin.os.tag == .windows) { + std.os.windows.WSACleanup() catch unreachable; + } + } + + // Try only the IPv4 variant as some CI builders have no IPv6 localhost + // configured. + const localhost = try net.Address.parseIp("127.0.0.1", 0); + + var server = try localhost.listen(.{}); + defer server.deinit(); + + const S = struct { + fn clientFn(server_address: net.Address) !void { + const socket = try net.tcpConnectToAddress(server_address); + defer socket.close(); + + _ = try socket.writer().writeAll("Hello world!"); + } + }; + + const t = try std.Thread.spawn(.{}, S.clientFn, .{server.listen_address}); + defer t.join(); + + var client = try server.accept(); + defer client.stream.close(); + var buf: [16]u8 = undefined; + const n = try client.stream.reader().read(&buf); + + try testing.expectEqual(@as(usize, 12), n); + try testing.expectEqualSlices(u8, "Hello world!", buf[0..n]); +} + +test "listen on an in use port" { + if (builtin.os.tag != .linux and comptime !builtin.os.tag.isDarwin()) { + // TODO build abstractions for other operating systems + return error.SkipZigTest; + } + + const localhost = try net.Address.parseIp("127.0.0.1", 0); + + var server1 = try localhost.listen(.{ .reuse_port = true }); + defer server1.deinit(); + + var server2 = try server1.listen_address.listen(.{ .reuse_port = true }); + defer server2.deinit(); +} + +fn testClientToHost(allocator: mem.Allocator, name: []const u8, port: u16) anyerror!void { + if 
(builtin.os.tag == .wasi) return error.SkipZigTest; + + const connection = try net.tcpConnectToHost(allocator, name, port); + defer connection.close(); + + var buf: [100]u8 = undefined; + const len = try connection.read(&buf); + const msg = buf[0..len]; + try testing.expect(mem.eql(u8, msg, "hello from server\n")); +} + +fn testClient(addr: net.Address) anyerror!void { + if (builtin.os.tag == .wasi) return error.SkipZigTest; + + const socket_file = try net.tcpConnectToAddress(addr); + defer socket_file.close(); + + var buf: [100]u8 = undefined; + const len = try socket_file.read(&buf); + const msg = buf[0..len]; + try testing.expect(mem.eql(u8, msg, "hello from server\n")); +} + +fn testServer(server: *net.Server) anyerror!void { + if (builtin.os.tag == .wasi) return error.SkipZigTest; + + var client = try server.accept(); + + const stream = client.stream.writer(); + try stream.print("hello from server\n", .{}); +} + +test "listen on a unix socket, send bytes, receive bytes" { + if (builtin.single_threaded) return error.SkipZigTest; + if (!net.has_unix_sockets) return error.SkipZigTest; + + if (builtin.os.tag == .windows) { + _ = try std.os.windows.WSAStartup(2, 2); + } + defer { + if (builtin.os.tag == .windows) { + std.os.windows.WSACleanup() catch unreachable; + } + } + + const socket_path = try generateFileName("socket.unix"); + defer testing.allocator.free(socket_path); + + const socket_addr = try net.Address.initUnix(socket_path); + defer std.fs.cwd().deleteFile(socket_path) catch {}; + + var server = try socket_addr.listen(.{}); + defer server.deinit(); + + const S = struct { + fn clientFn(path: []const u8) !void { + const socket = try net.connectUnixSocket(path); + defer socket.close(); + + _ = try socket.writer().writeAll("Hello world!"); + } + }; + + const t = try std.Thread.spawn(.{}, S.clientFn, .{socket_path}); + defer t.join(); + + var client = try server.accept(); + defer client.stream.close(); + var buf: [16]u8 = undefined; + const n = try 
client.stream.reader().read(&buf);
+
+    try testing.expectEqual(@as(usize, 12), n);
+    try testing.expectEqualSlices(u8, "Hello world!", buf[0..n]);
+}
+
+fn generateFileName(base_name: []const u8) ![]const u8 {
+    const random_bytes_count = 12;
+    const sub_path_len = comptime std.fs.base64_encoder.calcSize(random_bytes_count);
+    var random_bytes: [12]u8 = undefined;
+    std.crypto.random.bytes(&random_bytes);
+    var sub_path: [sub_path_len]u8 = undefined;
+    _ = std.fs.base64_encoder.encode(&sub_path, &random_bytes);
+    return std.fmt.allocPrint(testing.allocator, "{s}-{s}", .{ sub_path[0..], base_name });
+}
+
+test "non-blocking tcp server" {
+    if (builtin.os.tag == .wasi) return error.SkipZigTest;
+    if (true) {
+        // https://github.com/ziglang/zig/issues/18315
+        return error.SkipZigTest;
+    }
+
+    const localhost = try net.Address.parseIp("127.0.0.1", 0);
+    var server = try localhost.listen(.{ .force_nonblocking = true });
+    defer server.deinit();
+
+    const accept_err = server.accept();
+    try testing.expectError(error.WouldBlock, accept_err);
+
+    const socket_file = try net.tcpConnectToAddress(server.listen_address);
+    defer socket_file.close();
+
+    var client = try server.accept();
+    defer client.stream.close();
+    const stream = client.stream.writer();
+    try stream.print("hello from server\n", .{});
+
+    var buf: [100]u8 = undefined;
+    const len = try socket_file.read(&buf);
+    const msg = buf[0..len];
+    try testing.expect(mem.eql(u8, msg, "hello from server\n"));
+}
diff --git a/src/http/async/tls.zig/PrivateKey.zig b/src/http/async/tls.zig/PrivateKey.zig
new file mode 100644
index 00000000..0e2b944d
--- /dev/null
+++ b/src/http/async/tls.zig/PrivateKey.zig
@@ -0,0 +1,260 @@
+const std = @import("std");
+const Allocator = std.mem.Allocator;
+const Certificate = std.crypto.Certificate;
+const der = Certificate.der;
+const rsa = @import("rsa/rsa.zig");
+const base64 = std.base64.standard.decoderWithIgnore(" \t\r\n");
+const proto = @import("protocol.zig");
+
+const 
max_ecdsa_key_len = 66; + +signature_scheme: proto.SignatureScheme, + +key: union { + rsa: rsa.KeyPair, + ecdsa: [max_ecdsa_key_len]u8, +}, + +const PrivateKey = @This(); + +pub fn fromFile(gpa: Allocator, file: std.fs.File) !PrivateKey { + const buf = try file.readToEndAlloc(gpa, 1024 * 1024); + defer gpa.free(buf); + return try parsePem(buf); +} + +pub fn parsePem(buf: []const u8) !PrivateKey { + const key_start, const key_end, const marker_version = try findKey(buf); + const encoded = std.mem.trim(u8, buf[key_start..key_end], " \t\r\n"); + + // required bytes: + // 2412, 1821, 1236 for rsa 4096, 3072, 2048 bits size keys + var decoded: [4096]u8 = undefined; + const n = try base64.decode(&decoded, encoded); + + if (marker_version == 2) { + return try parseEcDer(decoded[0..n]); + } + return try parseDer(decoded[0..n]); +} + +fn findKey(buf: []const u8) !struct { usize, usize, usize } { + const markers = [_]struct { + begin: []const u8, + end: []const u8, + }{ + .{ .begin = "-----BEGIN PRIVATE KEY-----", .end = "-----END PRIVATE KEY-----" }, + .{ .begin = "-----BEGIN EC PRIVATE KEY-----", .end = "-----END EC PRIVATE KEY-----" }, + }; + + for (markers, 1..) 
|marker, ver| { + const begin_marker_start = std.mem.indexOfPos(u8, buf, 0, marker.begin) orelse continue; + const key_start = begin_marker_start + marker.begin.len; + const key_end = std.mem.indexOfPos(u8, buf, key_start, marker.end) orelse continue; + + return .{ key_start, key_end, ver }; + } + return error.MissingEndMarker; +} + +// ref: https://asn1js.eu/#MIG2AgEAMBAGByqGSM49AgEGBSuBBAAiBIGeMIGbAgEBBDBKFkVJCtU9FR6egz3yNxKBwXd86cFzMYqyGb8hRc1zVvLdw-So_2FBtITp6jzYmFShZANiAAQ-CH3a1R0V6dFlTK8Rs4M4egrpPtdta0osysO0Zl8mkBiDsTlvJNqeAp7L2ItHgFW8k_CfhgQT6iLDacNMhKC4XOV07r_ePD-mmkvqvRmzfOowHUoVRhCKrOTmF_J9Syc +pub fn parseDer(buf: []const u8) !PrivateKey { + const info = try der.Element.parse(buf, 0); + const version = try der.Element.parse(buf, info.slice.start); + + const algo_seq = try der.Element.parse(buf, version.slice.end); + const algo_cat = try der.Element.parse(buf, algo_seq.slice.start); + + const key_str = try der.Element.parse(buf, algo_seq.slice.end); + const key_seq = try der.Element.parse(buf, key_str.slice.start); + const key_int = try der.Element.parse(buf, key_seq.slice.start); + + const category = try Certificate.parseAlgorithmCategory(buf, algo_cat); + switch (category) { + .rsaEncryption => { + const modulus = try der.Element.parse(buf, key_int.slice.end); + const public_exponent = try der.Element.parse(buf, modulus.slice.end); + const private_exponent = try der.Element.parse(buf, public_exponent.slice.end); + + const public_key = try rsa.PublicKey.fromBytes(content(buf, modulus), content(buf, public_exponent)); + const secret_key = try rsa.SecretKey.fromBytes(public_key.modulus, content(buf, private_exponent)); + const key_pair = rsa.KeyPair{ .public = public_key, .secret = secret_key }; + + return .{ + .signature_scheme = switch (key_pair.public.modulus.bits()) { + 4096 => .rsa_pss_rsae_sha512, + 3072 => .rsa_pss_rsae_sha384, + else => .rsa_pss_rsae_sha256, + }, + .key = .{ .rsa = key_pair }, + }; + }, + .X9_62_id_ecPublicKey => { + const key = 
try der.Element.parse(buf, key_int.slice.end); + const algo_param = try der.Element.parse(buf, algo_cat.slice.end); + const named_curve = try Certificate.parseNamedCurve(buf, algo_param); + return .{ + .signature_scheme = signatureScheme(named_curve), + .key = .{ .ecdsa = ecdsaKey(buf, key) }, + }; + }, + else => unreachable, + } +} + +// References: +// https://asn1js.eu/#MHcCAQEEINJSRKv8kSKEzLHptfAlg-LGh4_pHHlq0XLf30Q9pcztoAoGCCqGSM49AwEHoUQDQgAEJpmLyp8aGCgyMcFIJaIq_-4V1K6nPpeoih3bT2npeplF9eyXj7rm8eW9Ua6VLhq71mqtMC-YLm-IkORBVq1cuA +// https://www.rfc-editor.org/rfc/rfc5915 +pub fn parseEcDer(bytes: []const u8) !PrivateKey { + const pki_msg = try der.Element.parse(bytes, 0); + const version = try der.Element.parse(bytes, pki_msg.slice.start); + const key = try der.Element.parse(bytes, version.slice.end); + const parameters = try der.Element.parse(bytes, key.slice.end); + const curve = try der.Element.parse(bytes, parameters.slice.start); + const named_curve = try Certificate.parseNamedCurve(bytes, curve); + return .{ + .signature_scheme = signatureScheme(named_curve), + .key = .{ .ecdsa = ecdsaKey(bytes, key) }, + }; +} + +fn signatureScheme(named_curve: Certificate.NamedCurve) proto.SignatureScheme { + return switch (named_curve) { + .X9_62_prime256v1 => .ecdsa_secp256r1_sha256, + .secp384r1 => .ecdsa_secp384r1_sha384, + .secp521r1 => .ecdsa_secp521r1_sha512, + }; +} + +fn ecdsaKey(bytes: []const u8, e: der.Element) [max_ecdsa_key_len]u8 { + const data = content(bytes, e); + var ecdsa_key: [max_ecdsa_key_len]u8 = undefined; + @memcpy(ecdsa_key[0..data.len], data); + return ecdsa_key; +} + +fn content(bytes: []const u8, e: der.Element) []const u8 { + return bytes[e.slice.start..e.slice.end]; +} + +const testing = std.testing; +const testu = @import("testu.zig"); + +test "parse ec pem" { + const data = @embedFile("testdata/ec_private_key.pem"); + var pk = try parsePem(data); + const priv_key = &testu.hexToBytes( + \\ 10 35 3d ca 1b 15 1d 06 aa 71 b8 ef f3 19 22 + 
\\ 43 78 f3 20 98 1e b1 2f 2b 64 7e 71 d0 30 2a + \\ 90 aa e5 eb 99 c3 90 65 3d c1 26 19 be 3f 08 + \\ 20 9b 01 + ); + try testing.expectEqualSlices(u8, priv_key, pk.key.ecdsa[0..priv_key.len]); + try testing.expectEqual(.ecdsa_secp384r1_sha384, pk.signature_scheme); +} + +test "parse ec prime256v1" { + const data = @embedFile("testdata/ec_prime256v1_private_key.pem"); + var pk = try parsePem(data); + const priv_key = &testu.hexToBytes( + \\ d2 52 44 ab fc 91 22 84 cc b1 e9 b5 f0 25 83 + \\ e2 c6 87 8f e9 1c 79 6a d1 72 df df 44 3d a5 + \\ cc ed + ); + try testing.expectEqualSlices(u8, priv_key, pk.key.ecdsa[0..priv_key.len]); + try testing.expectEqual(.ecdsa_secp256r1_sha256, pk.signature_scheme); +} + +test "parse ec secp384r1" { + const data = @embedFile("testdata/ec_secp384r1_private_key.pem"); + var pk = try parsePem(data); + const priv_key = &testu.hexToBytes( + \\ ee 6d 8a 5e 0d d3 b0 c6 4b 32 40 80 e2 3a de + \\ 8b 1e dd e2 92 db 36 1c db 91 ea ba a1 06 0d + \\ 42 2d d9 a9 dc 05 43 29 f1 78 7c f9 08 af c5 + \\ 03 1f 6d + ); + try testing.expectEqualSlices(u8, priv_key, pk.key.ecdsa[0..priv_key.len]); + try testing.expectEqual(.ecdsa_secp384r1_sha384, pk.signature_scheme); +} + +test "parse ec secp521r1" { + const data = @embedFile("testdata/ec_secp521r1_private_key.pem"); + var pk = try parsePem(data); + const priv_key = &testu.hexToBytes( + \\ 01 f0 2f 5a c7 24 18 ea 68 23 8c 2e a1 b4 b8 + \\ dc f2 11 b2 96 b0 ec 87 80 42 bf de ba f4 96 + \\ 83 8f 9b db c6 60 a7 4c d9 60 3a e4 ba 0b df + \\ ae 24 d3 1b c2 6e 82 a0 88 c1 ed 17 20 0d 3a + \\ f1 c5 7e e8 0b 27 + ); + try testing.expectEqualSlices(u8, priv_key, pk.key.ecdsa[0..priv_key.len]); + try testing.expectEqual(.ecdsa_secp521r1_sha512, pk.signature_scheme); +} + +test "parse rsa pem" { + const data = @embedFile("testdata/rsa_private_key.pem"); + const pk = try parsePem(data); + + // expected results from: + // $ openssl pkey -in testdata/rsa_private_key.pem -text -noout + const modulus = 
&testu.hexToBytes( + \\ 00 de f7 23 e6 75 cc 6f dd d5 6e 0f 8c 09 f8 + \\ 62 e3 60 1b c0 7d 8c d5 04 50 2c 36 e2 3b f7 + \\ 33 9f a1 14 af be cf 1a 0f 4c f5 cb 39 70 0e + \\ 3b 97 d6 21 f7 48 91 79 ca 7c 68 fc ea 62 a1 + \\ 5a 72 4f 78 57 0e cc f2 a3 50 05 f1 4c ca 51 + \\ 73 10 9a 18 8e 71 f5 b4 c7 3e be 4c ef 37 d4 + \\ 84 4b 82 1c ec 08 a3 cc 07 3d 5c 0b e5 85 3f + \\ fe b6 44 77 8f 3c 6a 2f 33 c3 5d f6 f2 29 46 + \\ 04 25 7e 05 d9 f8 3b 2d a4 40 66 9f 0d 6d 1a + \\ fa bc 0a c5 8b 86 43 30 ef 14 20 41 9d b5 cc + \\ 3e 63 b5 48 04 27 c9 5c d3 62 28 5f f5 b6 e4 + \\ 77 49 99 ac 84 4a a6 67 a5 9a 1a 37 c7 60 4c + \\ ba c1 70 cf 57 64 4a 21 ea 05 53 10 ec 94 71 + \\ 4a 43 04 83 00 aa 5a 28 bc f2 8c 58 14 92 d2 + \\ 83 17 f4 7b 29 0f e7 87 a2 47 b2 53 19 12 23 + \\ fb 4b ce 5a f8 a1 84 f9 b1 f3 bf e3 fa 10 f8 + \\ ad af 87 ce 03 0e a0 2c 13 71 57 c4 55 44 48 + \\ 44 cb + ); + const public_exponent = &testu.hexToBytes("01 00 01"); + const private_exponent = &testu.hexToBytes( + \\ 50 3b 80 98 aa a5 11 50 33 40 32 aa 02 e0 75 + \\ bd 3a 55 62 34 0b 9c 8f bb c5 dd 4e 15 a4 03 + \\ d8 9a 5f 56 4a 84 3d ed 69 95 3d 37 03 02 ac + \\ 21 1c 36 06 c4 ff 4c 63 37 d7 93 c3 48 10 a5 + \\ fa 62 6c 7c 6f 60 02 a4 0f e4 c3 8b 0d 76 b7 + \\ c0 2e a3 4d 86 e6 92 d1 eb db 10 d6 38 31 ea + \\ 15 3d d1 e8 81 c7 67 60 e7 8c 9a df 51 ce d0 + \\ 7a 88 32 b9 c1 54 b8 7d 98 fc d4 23 1a 05 0e + \\ f2 ea e1 72 29 28 2a 68 b7 90 18 80 1c 21 d6 + \\ 36 a8 6b 4a 9c dd 14 b8 9f 85 ee 95 0b f4 c6 + \\ 17 02 aa 4d ea 4d f9 39 d7 dd 9d b4 1d d2 f8 + \\ 92 46 0f 18 41 80 f4 ea 27 55 29 f8 37 59 bf + \\ 43 ec a3 eb 19 ba bc 13 06 95 3d 25 4b c9 72 + \\ cf 41 0a 6f aa cb 79 d4 7b fa b1 09 7c e2 2f + \\ 85 51 44 8b c6 97 8e 46 f9 6b ac 08 87 92 ce + \\ af 0b bf 8c bd 27 51 8f 09 e4 d3 f9 04 ac fa + \\ f2 04 70 3e d9 a6 28 17 c2 2d 74 e9 25 40 02 + \\ 49 + ); + + try testing.expectEqual(.rsa_pss_rsae_sha256, pk.signature_scheme); + const kp = pk.key.rsa; + { + var bytes: [modulus.len]u8 = undefined; + try 
kp.public.modulus.toBytes(&bytes, .big); + try testing.expectEqualSlices(u8, modulus, &bytes); + } + { + var bytes: [private_exponent.len]u8 = undefined; + try kp.public.public_exponent.toBytes(&bytes, .big); + try testing.expectEqualSlices(u8, public_exponent, bytes[bytes.len - public_exponent.len .. bytes.len]); + } + { + var btytes: [private_exponent.len]u8 = undefined; + try kp.secret.private_exponent.toBytes(&btytes, .big); + try testing.expectEqualSlices(u8, private_exponent, &btytes); + } +} diff --git a/src/http/async/tls.zig/cbc/main.zig b/src/http/async/tls.zig/cbc/main.zig new file mode 100644 index 00000000..25038445 --- /dev/null +++ b/src/http/async/tls.zig/cbc/main.zig @@ -0,0 +1,148 @@ +// This file is originally copied from: https://github.com/jedisct1/zig-cbc. +// +// It is modified then to have TLS padding insead of PKCS#7 padding. +// Reference: +// https://datatracker.ietf.org/doc/html/rfc5246/#section-6.2.3.2 +// https://crypto.stackexchange.com/questions/98917/on-the-correctness-of-the-padding-example-of-rfc-5246 +// +// If required padding i n bytes +// PKCS#7 padding is (n...n) +// TLS padding is (n-1...n-1) - n times of n-1 value +// +const std = @import("std"); +const aes = std.crypto.core.aes; +const mem = std.mem; +const debug = std.debug; + +/// CBC mode with TLS 1.2 padding +/// +/// Important: the counter mode doesn't provide authenticated encryption: the ciphertext can be trivially modified without this being detected. +/// If you need authenticated encryption, use anything from `std.crypto.aead` instead. +/// If you really need to use CBC mode, make sure to use a MAC to authenticate the ciphertext. +pub fn CBC(comptime BlockCipher: anytype) type { + const EncryptCtx = aes.AesEncryptCtx(BlockCipher); + const DecryptCtx = aes.AesDecryptCtx(BlockCipher); + + return struct { + const Self = @This(); + + enc_ctx: EncryptCtx, + dec_ctx: DecryptCtx, + + /// Initialize the CBC context with the given key. 
        pub fn init(key: [BlockCipher.key_bits / 8]u8) Self {
            const enc_ctx = BlockCipher.initEnc(key);
            // The decryption key schedule is derived from the encryption one,
            // avoiding a second full key expansion.
            const dec_ctx = DecryptCtx.initFromEnc(enc_ctx);

            return Self{ .enc_ctx = enc_ctx, .dec_ctx = dec_ctx };
        }

        /// Return the length of the ciphertext given the length of the plaintext.
        /// Padding always adds at least one byte (hence `length + 1`), so the
        /// result is the next multiple of the block length strictly greater
        /// than `length`.
        pub fn paddedLength(length: usize) usize {
            return (std.math.divCeil(usize, length + 1, EncryptCtx.block_length) catch unreachable) * EncryptCtx.block_length;
        }

        /// Encrypt the given plaintext for the given IV.
        /// The destination buffer must be large enough to hold the padded plaintext.
        /// Use the `paddedLength()` function to compute the ciphertext size.
        /// IV must be secret and unpredictable.
        pub fn encrypt(self: Self, dst: []u8, src: []const u8, iv: [EncryptCtx.block_length]u8) void {
            // Note: encryption *could* be parallelized, see https://research.kudelskisecurity.com/2022/11/17/some-aes-cbc-encryption-myth-busting/
            const block_length = EncryptCtx.block_length;
            const padded_length = paddedLength(src.len);
            debug.assert(dst.len == padded_length); // destination buffer must hold the padded plaintext
            var cv = iv; // chaining value: the IV for the first block, the previous ciphertext block afterwards
            var i: usize = 0;
            // Full blocks: XOR the plaintext with the chaining value, encrypt,
            // and the ciphertext block becomes the next chaining value.
            while (i + block_length <= src.len) : (i += block_length) {
                const in = src[i..][0..block_length];
                for (cv[0..], in) |*x, y| x.* ^= y;
                self.enc_ctx.encrypt(&cv, &cv);
                @memcpy(dst[i..][0..block_length], &cv);
            }
            // Last block: TLS-style padding (RFC 5246 6.2.3.2) fills the tail
            // with the value (pad_len - 1) repeated, unlike PKCS#7 which would
            // repeat pad_len itself. The buffer is pre-filled with the padding
            // byte, then the remaining plaintext is copied over its head.
            var in = [_]u8{0} ** block_length;
            const padding_length: u8 = @intCast(padded_length - src.len - 1);
            @memset(&in, padding_length);
            @memcpy(in[0 .. src.len - i], src[i..]);
            for (cv[0..], in) |*x, y| x.* ^= y;
            self.enc_ctx.encrypt(&cv, &cv);
            @memcpy(dst[i..], cv[0 .. dst.len - i]);
        }

        /// Decrypt the given ciphertext for the given IV.
        /// The destination buffer must be large enough to hold the plaintext.
        /// IV must be secret, unpredictable and match the one used for encryption.
+ pub fn decrypt(self: Self, dst: []u8, src: []const u8, iv: [DecryptCtx.block_length]u8) !void { + const block_length = DecryptCtx.block_length; + if (src.len != dst.len) { + return error.EncodingError; + } + debug.assert(src.len % block_length == 0); + var i: usize = 0; + var cv = iv; + var out: [block_length]u8 = undefined; + // Decryption could be parallelized + while (i + block_length <= dst.len) : (i += block_length) { + const in = src[i..][0..block_length]; + self.dec_ctx.decrypt(&out, in); + for (&out, cv) |*x, y| x.* ^= y; + cv = in.*; + @memcpy(dst[i..][0..block_length], &out); + } + // Last block - We intentionally don't check the padding to mitigate timing attacks + if (i < dst.len) { + const in = src[i..][0..block_length]; + @memset(&out, 0); + self.dec_ctx.decrypt(&out, in); + for (&out, cv) |*x, y| x.* ^= y; + @memcpy(dst[i..], out[0 .. dst.len - i]); + } + } + }; +} + +test "CBC mode" { + const M = CBC(aes.Aes128); + const key = [_]u8{ 0x2b, 0x7e, 0x15, 0x16, 0x28, 0xae, 0xd2, 0xa6, 0xab, 0xf7, 0x15, 0x88, 0x09, 0xcf, 0x4f, 0x3c }; + const iv = [_]u8{ 0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0a, 0x0b, 0x0c, 0x0d, 0x0e, 0x0f }; + const src_ = "This is a test of AES-CBC that goes on longer than a couple blocks. 
It is a somewhat long test case to type out!"; + const expected = "\xA0\x8C\x09\x7D\xFF\x42\xB6\x65\x4D\x4B\xC6\x90\x90\x39\xDE\x3D\xC7\xCA\xEB\xF6\x9A\x4F\x09\x97\xC9\x32\xAB\x75\x88\xB7\x57\x17"; + var res: [32]u8 = undefined; + + try comptime std.testing.expect(src_.len / M.paddedLength(1) >= 3); // Ensure that we have at least 3 blocks + + const z = M.init(key); + + // Test encryption and decryption with distinct buffers + var h = std.crypto.hash.sha2.Sha256.init(.{}); + inline for (0..src_.len) |len| { + const src = src_[0..len]; + var dst = [_]u8{0} ** M.paddedLength(src.len); + + z.encrypt(&dst, src, iv); + h.update(&dst); + + var decrypted = [_]u8{0} ** dst.len; + try z.decrypt(&decrypted, &dst, iv); + + const padding = decrypted[decrypted.len - 1] + 1; + try std.testing.expectEqualSlices(u8, src, decrypted[0 .. decrypted.len - padding]); + } + h.final(&res); + try std.testing.expectEqualSlices(u8, expected, &res); + + // Test encryption and decryption with the same buffer + h = std.crypto.hash.sha2.Sha256.init(.{}); + inline for (0..src_.len) |len| { + var buf = [_]u8{0} ** M.paddedLength(len); + @memcpy(buf[0..len], src_[0..len]); + z.encrypt(&buf, buf[0..len], iv); + h.update(&buf); + + try z.decrypt(&buf, &buf, iv); + + try std.testing.expectEqualSlices(u8, src_[0..len], buf[0..len]); + } + h.final(&res); + try std.testing.expectEqualSlices(u8, expected, &res); +} diff --git a/src/http/async/tls.zig/cipher.zig b/src/http/async/tls.zig/cipher.zig new file mode 100644 index 00000000..dbf4a07a --- /dev/null +++ b/src/http/async/tls.zig/cipher.zig @@ -0,0 +1,1004 @@ +const std = @import("std"); +const crypto = std.crypto; +const hkdfExpandLabel = crypto.tls.hkdfExpandLabel; + +const Sha1 = crypto.hash.Sha1; +const Sha256 = crypto.hash.sha2.Sha256; +const Sha384 = crypto.hash.sha2.Sha384; + +const record = @import("record.zig"); +const Record = record.Record; +const Transcript = @import("transcript.zig").Transcript; +const proto = @import("protocol.zig"); + 
// tls 1.2 cbc cipher types
const CbcAes128Sha1 = CbcType(crypto.core.aes.Aes128, Sha1);
const CbcAes128Sha256 = CbcType(crypto.core.aes.Aes128, Sha256);
const CbcAes256Sha256 = CbcType(crypto.core.aes.Aes256, Sha256);
const CbcAes256Sha384 = CbcType(crypto.core.aes.Aes256, Sha384);
// tls 1.2 gcm cipher types
const Aead12Aes128Gcm = Aead12Type(crypto.aead.aes_gcm.Aes128Gcm);
const Aead12Aes256Gcm = Aead12Type(crypto.aead.aes_gcm.Aes256Gcm);
// tls 1.2 chacha cipher type
const Aead12ChaCha = Aead12ChaChaType(crypto.aead.chacha_poly.ChaCha20Poly1305);
// tls 1.3 cipher types
const Aes128GcmSha256 = Aead13Type(crypto.aead.aes_gcm.Aes128Gcm, Sha256);
const Aes256GcmSha384 = Aead13Type(crypto.aead.aes_gcm.Aes256Gcm, Sha384);
const ChaChaSha256 = Aead13Type(crypto.aead.chacha_poly.ChaCha20Poly1305, Sha256);
const Aegis128Sha256 = Aead13Type(crypto.aead.aegis.Aegis128L, Sha256);

/// Worst-case number of bytes a tls 1.2 cipher adds on top of the cleartext:
/// the maximum of each suite's `encrypt_overhead` (record header plus, per
/// suite, explicit IV / mac / padding / auth tag).
pub const encrypt_overhead_tls_12: comptime_int = @max(
    CbcAes128Sha1.encrypt_overhead,
    CbcAes128Sha256.encrypt_overhead,
    CbcAes256Sha256.encrypt_overhead,
    CbcAes256Sha384.encrypt_overhead,
    Aead12Aes128Gcm.encrypt_overhead,
    Aead12Aes256Gcm.encrypt_overhead,
    Aead12ChaCha.encrypt_overhead,
);
/// Worst-case encrypt overhead across all supported tls 1.3 ciphers
/// (record header + inner content type byte + auth tag).
pub const encrypt_overhead_tls_13: comptime_int = @max(
    Aes128GcmSha256.encrypt_overhead,
    Aes256GcmSha384.encrypt_overhead,
    ChaChaSha256.encrypt_overhead,
    Aegis128Sha256.encrypt_overhead,
);

// ref (length): https://www.rfc-editor.org/rfc/rfc8446#section-5.1
pub const max_cleartext_len = 1 << 14;
// ref (length): https://www.rfc-editor.org/rfc/rfc8446#section-5.2
// The sum of the lengths of the content and the padding, plus one for the inner
// content type, plus any expansion added by the AEAD algorithm.
pub const max_ciphertext_len = max_cleartext_len + 256;
// Largest possible encrypted tls record on the wire: 5-byte record header
// followed by a maximum-size ciphertext payload.
pub const max_ciphertext_record_len = record.header_len + max_ciphertext_len;

/// Returns type for cipher suite tag.
fn CipherType(comptime tag: CipherSuite) type {
    return switch (tag) {
        // tls 1.2 cbc
        .ECDHE_ECDSA_WITH_AES_128_CBC_SHA,
        .ECDHE_RSA_WITH_AES_128_CBC_SHA,
        .RSA_WITH_AES_128_CBC_SHA,
        => CbcAes128Sha1,
        .ECDHE_ECDSA_WITH_AES_128_CBC_SHA256,
        .ECDHE_RSA_WITH_AES_128_CBC_SHA256,
        .RSA_WITH_AES_128_CBC_SHA256,
        => CbcAes128Sha256,
        .ECDHE_ECDSA_WITH_AES_256_CBC_SHA384,
        .ECDHE_RSA_WITH_AES_256_CBC_SHA384,
        => CbcAes256Sha384,

        // tls 1.2 gcm
        .ECDHE_ECDSA_WITH_AES_128_GCM_SHA256,
        .ECDHE_RSA_WITH_AES_128_GCM_SHA256,
        => Aead12Aes128Gcm,
        .ECDHE_ECDSA_WITH_AES_256_GCM_SHA384,
        .ECDHE_RSA_WITH_AES_256_GCM_SHA384,
        => Aead12Aes256Gcm,

        // tls 1.2 chacha
        .ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256,
        .ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256,
        => Aead12ChaCha,

        // tls 1.3
        .AES_128_GCM_SHA256 => Aes128GcmSha256,
        .AES_256_GCM_SHA384 => Aes256GcmSha384,
        .CHACHA20_POLY1305_SHA256 => ChaChaSha256,
        .AEGIS_128L_SHA256 => Aegis128Sha256,

        // Unsupported suites: `Cipher.initTls12`/`initTls13` reject them with
        // error.TlsIllegalParameter before ever instantiating a cipher type,
        // and the `Cipher` union only names tags from the arms above, so this
        // arm is never evaluated.
        else => unreachable,
    };
}

/// Provides initialization and common encrypt/decrypt methods for all supported
/// ciphers. Tls 1.2 has only application cipher, tls 1.3 has separate cipher
/// for handshake and application.
+pub const Cipher = union(CipherSuite) { + // tls 1.2 cbc + ECDHE_ECDSA_WITH_AES_128_CBC_SHA: CipherType(.ECDHE_ECDSA_WITH_AES_128_CBC_SHA), + ECDHE_RSA_WITH_AES_128_CBC_SHA: CipherType(.ECDHE_RSA_WITH_AES_128_CBC_SHA), + RSA_WITH_AES_128_CBC_SHA: CipherType(.RSA_WITH_AES_128_CBC_SHA), + + ECDHE_ECDSA_WITH_AES_128_CBC_SHA256: CipherType(.ECDHE_ECDSA_WITH_AES_128_CBC_SHA256), + ECDHE_RSA_WITH_AES_128_CBC_SHA256: CipherType(.ECDHE_RSA_WITH_AES_128_CBC_SHA256), + RSA_WITH_AES_128_CBC_SHA256: CipherType(.RSA_WITH_AES_128_CBC_SHA256), + + ECDHE_ECDSA_WITH_AES_256_CBC_SHA384: CipherType(.ECDHE_ECDSA_WITH_AES_256_CBC_SHA384), + ECDHE_RSA_WITH_AES_256_CBC_SHA384: CipherType(.ECDHE_RSA_WITH_AES_256_CBC_SHA384), + // tls 1.2 gcm + ECDHE_ECDSA_WITH_AES_128_GCM_SHA256: CipherType(.ECDHE_ECDSA_WITH_AES_128_GCM_SHA256), + ECDHE_ECDSA_WITH_AES_256_GCM_SHA384: CipherType(.ECDHE_RSA_WITH_AES_256_GCM_SHA384), + ECDHE_RSA_WITH_AES_128_GCM_SHA256: CipherType(.ECDHE_RSA_WITH_AES_128_GCM_SHA256), + ECDHE_RSA_WITH_AES_256_GCM_SHA384: CipherType(.ECDHE_RSA_WITH_AES_256_GCM_SHA384), + // tls 1.2 chacha + ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256: CipherType(.ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256), + ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256: CipherType(.ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256), + // tls 1.3 + AES_128_GCM_SHA256: CipherType(.AES_128_GCM_SHA256), + AES_256_GCM_SHA384: CipherType(.AES_256_GCM_SHA384), + CHACHA20_POLY1305_SHA256: CipherType(.CHACHA20_POLY1305_SHA256), + AEGIS_128L_SHA256: CipherType(.AEGIS_128L_SHA256), + + // tls 1.2 application cipher + pub fn initTls12(tag: CipherSuite, key_material: []const u8, side: proto.Side) !Cipher { + switch (tag) { + inline .ECDHE_ECDSA_WITH_AES_128_CBC_SHA, + .ECDHE_RSA_WITH_AES_128_CBC_SHA, + .RSA_WITH_AES_128_CBC_SHA, + .ECDHE_ECDSA_WITH_AES_128_CBC_SHA256, + .ECDHE_RSA_WITH_AES_128_CBC_SHA256, + .RSA_WITH_AES_128_CBC_SHA256, + .ECDHE_ECDSA_WITH_AES_256_CBC_SHA384, + .ECDHE_RSA_WITH_AES_256_CBC_SHA384, + 
.ECDHE_ECDSA_WITH_AES_128_GCM_SHA256, + .ECDHE_ECDSA_WITH_AES_256_GCM_SHA384, + .ECDHE_RSA_WITH_AES_128_GCM_SHA256, + .ECDHE_RSA_WITH_AES_256_GCM_SHA384, + .ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256, + .ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256, + => |comptime_tag| { + return @unionInit(Cipher, @tagName(comptime_tag), CipherType(comptime_tag).init(key_material, side)); + }, + else => return error.TlsIllegalParameter, + } + } + + // tls 1.3 handshake or application cipher + pub fn initTls13(tag: CipherSuite, secret: Transcript.Secret, side: proto.Side) !Cipher { + return switch (tag) { + inline .AES_128_GCM_SHA256, + .AES_256_GCM_SHA384, + .CHACHA20_POLY1305_SHA256, + .AEGIS_128L_SHA256, + => |comptime_tag| { + return @unionInit(Cipher, @tagName(comptime_tag), CipherType(comptime_tag).init(secret, side)); + }, + else => return error.TlsIllegalParameter, + }; + } + + pub fn encrypt( + c: *Cipher, + buf: []u8, + content_type: proto.ContentType, + cleartext: []const u8, + ) ![]const u8 { + return switch (c.*) { + inline else => |*f| try f.encrypt(buf, content_type, cleartext), + }; + } + + pub fn decrypt( + c: *Cipher, + buf: []u8, + rec: Record, + ) !struct { proto.ContentType, []u8 } { + return switch (c.*) { + inline else => |*f| { + const content_type, const cleartext = try f.decrypt(buf, rec); + if (cleartext.len > max_cleartext_len) return error.TlsRecordOverflow; + return .{ content_type, cleartext }; + }, + }; + } + + pub fn encryptSeq(c: Cipher) u64 { + return switch (c) { + inline else => |f| f.encrypt_seq, + }; + } + + pub fn keyUpdateEncrypt(c: *Cipher) !void { + return switch (c.*) { + inline .AES_128_GCM_SHA256, + .AES_256_GCM_SHA384, + .CHACHA20_POLY1305_SHA256, + .AEGIS_128L_SHA256, + => |*f| f.keyUpdateEncrypt(), + // can't happen on tls 1.2 + else => return error.TlsUnexpectedMessage, + }; + } + pub fn keyUpdateDecrypt(c: *Cipher) !void { + return switch (c.*) { + inline .AES_128_GCM_SHA256, + .AES_256_GCM_SHA384, + .CHACHA20_POLY1305_SHA256, + 
.AEGIS_128L_SHA256, + => |*f| f.keyUpdateDecrypt(), + // can't happen on tls 1.2 + else => return error.TlsUnexpectedMessage, + }; + } +}; + +fn Aead12Type(comptime AeadType: type) type { + return struct { + const explicit_iv_len = 8; + const key_len = AeadType.key_length; + const auth_tag_len = AeadType.tag_length; + const nonce_len = AeadType.nonce_length; + const iv_len = AeadType.nonce_length - explicit_iv_len; + const encrypt_overhead = record.header_len + explicit_iv_len + auth_tag_len; + + encrypt_key: [key_len]u8, + decrypt_key: [key_len]u8, + encrypt_iv: [iv_len]u8, + decrypt_iv: [iv_len]u8, + encrypt_seq: u64 = 0, + decrypt_seq: u64 = 0, + rnd: std.Random = crypto.random, + + const Self = @This(); + + fn init(key_material: []const u8, side: proto.Side) Self { + const client_key = key_material[0..key_len].*; + const server_key = key_material[key_len..][0..key_len].*; + const client_iv = key_material[2 * key_len ..][0..iv_len].*; + const server_iv = key_material[2 * key_len + iv_len ..][0..iv_len].*; + return .{ + .encrypt_key = if (side == .client) client_key else server_key, + .decrypt_key = if (side == .client) server_key else client_key, + .encrypt_iv = if (side == .client) client_iv else server_iv, + .decrypt_iv = if (side == .client) server_iv else client_iv, + }; + } + + /// Returns encrypted tls record in format: + /// ----------------- buf ---------------------- + /// header | explicit_iv | ciphertext | auth_tag + /// + /// tls record header: 5 bytes + /// explicit_iv: 8 bytes + /// ciphertext: same length as cleartext + /// auth_tag: 16 bytes + pub fn encrypt( + self: *Self, + buf: []u8, + content_type: proto.ContentType, + cleartext: []const u8, + ) ![]const u8 { + const record_len = record.header_len + explicit_iv_len + cleartext.len + auth_tag_len; + if (buf.len < record_len) return error.BufferOverflow; + + const header = buf[0..record.header_len]; + const explicit_iv = buf[record.header_len..][0..explicit_iv_len]; + const ciphertext = 
buf[record.header_len + explicit_iv_len ..][0..cleartext.len]; + const auth_tag = buf[record.header_len + explicit_iv_len + cleartext.len ..][0..auth_tag_len]; + + header.* = record.header(content_type, explicit_iv_len + cleartext.len + auth_tag_len); + self.rnd.bytes(explicit_iv); + const iv = self.encrypt_iv ++ explicit_iv.*; + const ad = additionalData(self.encrypt_seq, content_type, cleartext.len); + AeadType.encrypt(ciphertext, auth_tag, cleartext, &ad, iv, self.encrypt_key); + self.encrypt_seq +%= 1; + + return buf[0..record_len]; + } + + /// Decrypts payload into cleartext. Returns tls record content type and + /// cleartext. + /// Accepts tls record header and payload: + /// header | ----------- payload --------------- + /// header | explicit_iv | ciphertext | auth_tag + pub fn decrypt( + self: *Self, + buf: []u8, + rec: Record, + ) !struct { proto.ContentType, []u8 } { + const overhead = explicit_iv_len + auth_tag_len; + if (rec.payload.len < overhead) return error.TlsDecryptError; + const cleartext_len = rec.payload.len - overhead; + if (buf.len < cleartext_len) return error.BufferOverflow; + + const explicit_iv = rec.payload[0..explicit_iv_len]; + const ciphertext = rec.payload[explicit_iv_len..][0..cleartext_len]; + const auth_tag = rec.payload[explicit_iv_len + cleartext_len ..][0..auth_tag_len]; + + const iv = self.decrypt_iv ++ explicit_iv.*; + const ad = additionalData(self.decrypt_seq, rec.content_type, cleartext_len); + const cleartext = buf[0..cleartext_len]; + AeadType.decrypt(cleartext, ciphertext, auth_tag.*, &ad, iv, self.decrypt_key) catch return error.TlsDecryptError; + self.decrypt_seq +%= 1; + return .{ rec.content_type, cleartext }; + } + }; +} + +fn Aead12ChaChaType(comptime AeadType: type) type { + return struct { + const key_len = AeadType.key_length; + const auth_tag_len = AeadType.tag_length; + const nonce_len = AeadType.nonce_length; + const encrypt_overhead = record.header_len + auth_tag_len; + + encrypt_key: [key_len]u8, + 
decrypt_key: [key_len]u8, + encrypt_iv: [nonce_len]u8, + decrypt_iv: [nonce_len]u8, + encrypt_seq: u64 = 0, + decrypt_seq: u64 = 0, + + const Self = @This(); + + fn init(key_material: []const u8, side: proto.Side) Self { + const client_key = key_material[0..key_len].*; + const server_key = key_material[key_len..][0..key_len].*; + const client_iv = key_material[2 * key_len ..][0..nonce_len].*; + const server_iv = key_material[2 * key_len + nonce_len ..][0..nonce_len].*; + return .{ + .encrypt_key = if (side == .client) client_key else server_key, + .decrypt_key = if (side == .client) server_key else client_key, + .encrypt_iv = if (side == .client) client_iv else server_iv, + .decrypt_iv = if (side == .client) server_iv else client_iv, + }; + } + + /// Returns encrypted tls record in format: + /// ------------ buf ------------- + /// header | ciphertext | auth_tag + /// + /// tls record header: 5 bytes + /// ciphertext: same length as cleartext + /// auth_tag: 16 bytes + pub fn encrypt( + self: *Self, + buf: []u8, + content_type: proto.ContentType, + cleartext: []const u8, + ) ![]const u8 { + const record_len = record.header_len + cleartext.len + auth_tag_len; + if (buf.len < record_len) return error.BufferOverflow; + + const ciphertext = buf[record.header_len..][0..cleartext.len]; + const auth_tag = buf[record.header_len + ciphertext.len ..][0..auth_tag_len]; + + const ad = additionalData(self.encrypt_seq, content_type, cleartext.len); + const iv = ivWithSeq(nonce_len, self.encrypt_iv, self.encrypt_seq); + AeadType.encrypt(ciphertext, auth_tag, cleartext, &ad, iv, self.encrypt_key); + self.encrypt_seq +%= 1; + + buf[0..record.header_len].* = record.header(content_type, ciphertext.len + auth_tag.len); + return buf[0..record_len]; + } + + /// Decrypts payload into cleartext. Returns tls record content type and + /// cleartext. 
+ /// Accepts tls record header and payload: + /// header | ----- payload ------- + /// header | ciphertext | auth_tag + pub fn decrypt( + self: *Self, + buf: []u8, + rec: Record, + ) !struct { proto.ContentType, []u8 } { + const overhead = auth_tag_len; + if (rec.payload.len < overhead) return error.TlsDecryptError; + const cleartext_len = rec.payload.len - overhead; + if (buf.len < cleartext_len) return error.BufferOverflow; + + const ciphertext = rec.payload[0..cleartext_len]; + const auth_tag = rec.payload[cleartext_len..][0..auth_tag_len]; + const cleartext = buf[0..cleartext_len]; + + const ad = additionalData(self.decrypt_seq, rec.content_type, cleartext_len); + const iv = ivWithSeq(nonce_len, self.decrypt_iv, self.decrypt_seq); + AeadType.decrypt(cleartext, ciphertext, auth_tag.*, &ad, iv, self.decrypt_key) catch return error.TlsDecryptError; + self.decrypt_seq +%= 1; + return .{ rec.content_type, cleartext }; + } + }; +} + +fn Aead13Type(comptime AeadType: type, comptime Hash: type) type { + return struct { + const Hmac = crypto.auth.hmac.Hmac(Hash); + const Hkdf = crypto.kdf.hkdf.Hkdf(Hmac); + + const key_len = AeadType.key_length; + const auth_tag_len = AeadType.tag_length; + const nonce_len = AeadType.nonce_length; + const digest_len = Hash.digest_length; + const encrypt_overhead = record.header_len + 1 + auth_tag_len; + + encrypt_secret: [digest_len]u8, + decrypt_secret: [digest_len]u8, + encrypt_key: [key_len]u8, + decrypt_key: [key_len]u8, + encrypt_iv: [nonce_len]u8, + decrypt_iv: [nonce_len]u8, + encrypt_seq: u64 = 0, + decrypt_seq: u64 = 0, + + const Self = @This(); + + pub fn init(secret: Transcript.Secret, side: proto.Side) Self { + var self = Self{ + .encrypt_secret = if (side == .client) secret.client[0..digest_len].* else secret.server[0..digest_len].*, + .decrypt_secret = if (side == .server) secret.client[0..digest_len].* else secret.server[0..digest_len].*, + .encrypt_key = undefined, + .decrypt_key = undefined, + .encrypt_iv = undefined, 
+                .decrypt_iv = undefined,
+            };
+            self.keyGenerate();
+            return self;
+        }
+
+        /// Derives keys and ivs for both directions from the current traffic
+        /// secrets (rfc 8446, section 7.3).
+        fn keyGenerate(self: *Self) void {
+            self.encrypt_key = hkdfExpandLabel(Hkdf, self.encrypt_secret, "key", "", key_len);
+            self.decrypt_key = hkdfExpandLabel(Hkdf, self.decrypt_secret, "key", "", key_len);
+            self.encrypt_iv = hkdfExpandLabel(Hkdf, self.encrypt_secret, "iv", "", nonce_len);
+            self.decrypt_iv = hkdfExpandLabel(Hkdf, self.decrypt_secret, "iv", "", nonce_len);
+        }
+
+        /// Rotates the sending traffic secret, regenerates keys and resets
+        /// the send sequence number (rfc 8446, section 7.2).
+        pub fn keyUpdateEncrypt(self: *Self) void {
+            self.encrypt_secret = hkdfExpandLabel(Hkdf, self.encrypt_secret, "traffic upd", "", digest_len);
+            self.encrypt_seq = 0;
+            self.keyGenerate();
+        }
+
+        /// Rotates the receiving traffic secret, regenerates keys and resets
+        /// the receive sequence number (rfc 8446, section 7.2).
+        pub fn keyUpdateDecrypt(self: *Self) void {
+            self.decrypt_secret = hkdfExpandLabel(Hkdf, self.decrypt_secret, "traffic upd", "", digest_len);
+            self.decrypt_seq = 0;
+            self.keyGenerate();
+        }
+
+        /// Returns encrypted tls record in format:
+        ///   ------------ buf -------------
+        ///   header | ciphertext | auth_tag
+        ///
+        /// tls record header: 5 bytes
+        /// ciphertext: cleartext len + 1 byte content type
+        /// auth_tag: 16 bytes
+        pub fn encrypt(
+            self: *Self,
+            buf: []u8,
+            content_type: proto.ContentType,
+            cleartext: []const u8,
+        ) ![]const u8 {
+            const payload_len = cleartext.len + 1 + auth_tag_len;
+            const record_len = record.header_len + payload_len;
+            if (buf.len < record_len) return error.BufferOverflow;
+
+            const header = buf[0..record.header_len];
+            header.* = record.header(.application_data, payload_len);
+
+            // Skip @memcpy if cleartext is already part of the buf at right
+            // position. The length guard keeps the `cleartext[0]` index valid
+            // when an empty record (zero length cleartext, legal in tls 1.3)
+            // is encrypted.
+            if (cleartext.len > 0 and &cleartext[0] != &buf[record.header_len]) {
+                @memcpy(buf[record.header_len..][0..cleartext.len], cleartext);
+            }
+            buf[record.header_len + cleartext.len] = @intFromEnum(content_type);
+            const ciphertext = buf[record.header_len..][0 .. cleartext.len + 1];
+            const auth_tag = buf[record.header_len + ciphertext.len ..][0..auth_tag_len];
+
+            const iv = ivWithSeq(nonce_len, self.encrypt_iv, self.encrypt_seq);
+            AeadType.encrypt(ciphertext, auth_tag, ciphertext, header, iv, self.encrypt_key);
+            self.encrypt_seq +%= 1;
+            return buf[0..record_len];
+        }
+
+        /// Decrypts payload into cleartext. Returns tls record content type and
+        /// cleartext.
+        /// Accepts tls record header and payload:
+        ///   header | ------- payload ---------
+        ///   header | ciphertext | auth_tag
+        ///   header | cleartext + ct | auth_tag
+        /// Ciphertext after decryption contains cleartext and content type (1 byte).
+        pub fn decrypt(
+            self: *Self,
+            buf: []u8,
+            rec: Record,
+        ) !struct { proto.ContentType, []u8 } {
+            const overhead = auth_tag_len + 1;
+            if (rec.payload.len < overhead) return error.TlsDecryptError;
+            const ciphertext_len = rec.payload.len - auth_tag_len;
+            if (buf.len < ciphertext_len) return error.BufferOverflow;
+
+            const ciphertext = rec.payload[0..ciphertext_len];
+            const auth_tag = rec.payload[ciphertext_len..][0..auth_tag_len];
+
+            const iv = ivWithSeq(nonce_len, self.decrypt_iv, self.decrypt_seq);
+            AeadType.decrypt(buf[0..ciphertext_len], ciphertext, auth_tag.*, rec.header, iv, self.decrypt_key) catch return error.TlsBadRecordMac;
+
+            // Remove zero bytes padding; the last non zero byte is the inner
+            // content type (rfc 8446, section 5.4).
+            var content_type_idx: usize = ciphertext_len - 1;
+            while (buf[content_type_idx] == 0 and content_type_idx > 0) : (content_type_idx -= 1) {}
+            // A record whose plaintext is all padding has no content type and
+            // must be rejected (rfc 8446, section 5.4).
+            if (buf[content_type_idx] == 0) return error.TlsDecryptError;
+
+            const cleartext = buf[0..content_type_idx];
+            const content_type: proto.ContentType = @enumFromInt(buf[content_type_idx]);
+            self.decrypt_seq +%= 1;
+            return .{ content_type, cleartext };
+        }
+    };
+}
+
+/// Generic tls 1.2 CBC cipher (aes cbc with hmac), parameterized by block
+/// cipher and mac hash.
+fn CbcType(comptime BlockCipher: type, comptime HashType: type) type {
+    const CBC = @import("cbc/main.zig").CBC(BlockCipher);
+    return struct {
+        const mac_len = HashType.digest_length; // 20, 32, 48 bytes for sha1, sha256, sha384
+        const key_len = BlockCipher.key_bits / 8; // 16, 32 for Aes128, Aes256
+        const iv_len = 16;
+        const encrypt_overhead = record.header_len + iv_len + mac_len + max_padding;
+
+        pub const Hmac = crypto.auth.hmac.Hmac(HashType);
+        const paddedLength = CBC.paddedLength;
+        const max_padding = 16;
+
+        encrypt_secret: [mac_len]u8,
+        decrypt_secret: [mac_len]u8,
+        encrypt_key: [key_len]u8,
+        decrypt_key: [key_len]u8,
+        encrypt_seq: u64 = 0,
+        decrypt_seq: u64 = 0,
+        rnd: std.Random = crypto.random, // explicit iv source; overridable in tests
+
+        const Self = @This();
+
+        /// Splits tls 1.2 key material (client/server mac secrets followed by
+        /// client/server keys) according to which side we are.
+        fn init(key_material: []const u8, side: proto.Side) Self {
+            const client_secret = key_material[0..mac_len].*;
+            const server_secret = key_material[mac_len..][0..mac_len].*;
+            const client_key = key_material[2 * mac_len ..][0..key_len].*;
+            const server_key = key_material[2 * mac_len + key_len ..][0..key_len].*;
+            return .{
+                .encrypt_secret = if (side == .client) client_secret else server_secret,
+                .decrypt_secret = if (side == .client) server_secret else client_secret,
+                .encrypt_key = if (side == .client) client_key else server_key,
+                .decrypt_key = if (side == .client) server_key else client_key,
+            };
+        }
+
+        /// Returns encrypted tls record in format:
+        ///   ----------------- buf -----------------
+        ///   header | iv | ------ ciphertext -------
+        ///   header | iv | cleartext | mac | padding
+        ///
+        /// tls record header: 5 bytes
+        /// iv: 16 bytes
+        /// ciphertext: cleartext length + mac + padding
+        /// mac: 20, 32 or 48 (sha1, sha256, sha384)
+        /// padding: 1-16 bytes
+        ///
+        /// Max encrypt buf overhead = iv + mac + padding (1-16)
+        ///   aes_128_cbc_sha    => 16 + 20 + 16 = 52
+        ///   aes_128_cbc_sha256 => 16 + 32 + 16 = 64
+        ///   aes_256_cbc_sha384 => 16 + 48 + 16 = 80
+        pub fn encrypt(
+            self: *Self,
+            buf: []u8,
+            content_type: proto.ContentType,
+            cleartext: []const u8,
+        ) ![]const u8 {
+            const max_record_len = record.header_len + iv_len + cleartext.len + mac_len + max_padding;
+            if (buf.len < max_record_len) return error.BufferOverflow;
+            const cleartext_idx = record.header_len + iv_len; // position of cleartext in buf
+            @memcpy(buf[cleartext_idx..][0..cleartext.len], cleartext);
+
+            { // calculate mac from (ad + cleartext)
+                // ... | ad | cleartext | mac | ...
+                //     | -- mac msg --  | mac |
+                const ad = additionalData(self.encrypt_seq, content_type, cleartext.len);
+                const mac_msg = buf[cleartext_idx - ad.len ..][0 .. ad.len + cleartext.len];
+                @memcpy(mac_msg[0..ad.len], &ad);
+                const mac = buf[cleartext_idx + cleartext.len ..][0..mac_len];
+                Hmac.create(mac, mac_msg, &self.encrypt_secret);
+                self.encrypt_seq +%= 1;
+            }
+
+            // ... | cleartext | mac | ...
+            // ... | -- plaintext --- ...
+            // ... | cleartext | mac | padding
+            // ... | ------ ciphertext -------
+            const unpadded_len = cleartext.len + mac_len;
+            const padded_len = paddedLength(unpadded_len);
+            const plaintext = buf[cleartext_idx..][0..unpadded_len];
+            const ciphertext = buf[cleartext_idx..][0..padded_len];
+
+            // Add header and iv at the buf start
+            // header | iv | ...
+            buf[0..record.header_len].* = record.header(content_type, iv_len + ciphertext.len);
+            const iv = buf[record.header_len..][0..iv_len];
+            self.rnd.bytes(iv);
+
+            // encrypt plaintext into ciphertext
+            CBC.init(self.encrypt_key).encrypt(ciphertext, plaintext, iv[0..iv_len].*);
+
+            // header | iv | ------ ciphertext -------
+            return buf[0 .. record.header_len + iv_len + ciphertext.len];
+        }
+
+        /// Decrypts payload into cleartext. Returns tls record content type and
+        /// cleartext.
+        pub fn decrypt(
+            self: *Self,
+            buf: []u8,
+            rec: Record,
+        ) !struct { proto.ContentType, []u8 } {
+            if (rec.payload.len < iv_len + mac_len + 1) return error.TlsDecryptError;
+
+            // --------- payload ------------
+            // iv | ------ ciphertext -------
+            // iv | cleartext | mac | padding
+            const iv = rec.payload[0..iv_len];
+            const ciphertext = rec.payload[iv_len..];
+
+            if (buf.len < ciphertext.len + additional_data_len) return error.BufferOverflow;
+            // ---------- buf ---------------
+            // ad | ------ plaintext --------
+            // ad | cleartext | mac | padding
+            const plaintext = buf[additional_data_len..][0..ciphertext.len];
+            // decrypt ciphertext -> plaintext
+            CBC.init(self.decrypt_key).decrypt(plaintext, ciphertext, iv[0..iv_len].*) catch return error.TlsDecryptError;
+
+            // Get padding len from the last padding byte. Widen to usize
+            // before adding 1: a 0xff padding byte (padding_length may be up
+            // to 255 per rfc 5246, section 6.2.3.2) would otherwise overflow
+            // u8 and trap on attacker controlled input.
+            const padding_len = @as(usize, plaintext[plaintext.len - 1]) + 1;
+            if (plaintext.len < mac_len + padding_len) return error.TlsDecryptError;
+            // split plaintext into cleartext and mac
+            const cleartext_len = plaintext.len - mac_len - padding_len;
+            const cleartext = plaintext[0..cleartext_len];
+            const mac = plaintext[cleartext_len..][0..mac_len];
+
+            // write ad to the buf
+            var ad = additionalData(self.decrypt_seq, rec.content_type, cleartext_len);
+            @memcpy(buf[0..ad.len], &ad);
+            const mac_msg = buf[0 .. ad.len + cleartext_len];
+            self.decrypt_seq +%= 1;
+
+            // calculate expected mac and compare with received mac
+            // NOTE(review): std.mem.eql is not constant time; consider a
+            // timing safe compare for the mac check — verify against the
+            // targeted zig std.crypto api.
+            var expected_mac: [mac_len]u8 = undefined;
+            Hmac.create(&expected_mac, mac_msg, &self.decrypt_secret);
+            if (!std.mem.eql(u8, &expected_mac, mac))
+                return error.TlsBadRecordMac;
+
+            return .{ rec.content_type, cleartext };
+        }
+    };
+}
+
+// xor lower 8 iv bytes with seq (rfc 8446, section 5.3 nonce construction)
+fn ivWithSeq(comptime nonce_len: usize, iv: [nonce_len]u8, seq: u64) [nonce_len]u8 {
+    var res = iv;
+    const buf = res[nonce_len - 8 ..];
+    const operand = std.mem.readInt(u64, buf, .big);
+    std.mem.writeInt(u64, buf, operand ^ seq, .big);
+    return res;
+}
+
+pub const additional_data_len = record.header_len + @sizeOf(u64);
+
+// Additional authenticated data: 8 byte big endian sequence number followed
+// by the 5 byte record header.
+fn additionalData(seq: u64, content_type: proto.ContentType, payload_len: usize) [additional_data_len]u8 {
+    const header = record.header(content_type, payload_len);
+    var seq_buf: [8]u8 = undefined;
+    std.mem.writeInt(u64, &seq_buf, seq, .big);
+    return seq_buf ++ header;
+}
+
+// Cipher suites lists. In the order of preference.
+// For the preference using grades priority and rules from Go project.
+// https://ciphersuite.info/page/faq/
+// https://github.com/golang/go/blob/73186ba00251b3ed8baaab36e4f5278c7681155b/src/crypto/tls/cipher_suites.go#L226
+pub const cipher_suites = struct {
+    const tls12_secure = if (crypto.core.aes.has_hardware_support) [_]CipherSuite{
+        // recommended
+        .ECDHE_ECDSA_WITH_AES_128_GCM_SHA256,
+        .ECDHE_ECDSA_WITH_AES_256_GCM_SHA384,
+        .ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256,
+        // secure
+        .ECDHE_RSA_WITH_AES_128_GCM_SHA256,
+        .ECDHE_RSA_WITH_AES_256_GCM_SHA384,
+        .ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256,
+    } else [_]CipherSuite{
+        // recommended
+        .ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256,
+        .ECDHE_ECDSA_WITH_AES_128_GCM_SHA256,
+        .ECDHE_ECDSA_WITH_AES_256_GCM_SHA384,
+
+        // secure
+        .ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256,
+        .ECDHE_RSA_WITH_AES_128_GCM_SHA256,
+        .ECDHE_RSA_WITH_AES_256_GCM_SHA384,
+    };
+    const tls12_weak = [_]CipherSuite{
+        // weak
+        .ECDHE_ECDSA_WITH_AES_128_CBC_SHA256,
+        .ECDHE_ECDSA_WITH_AES_256_CBC_SHA384,
+        .ECDHE_ECDSA_WITH_AES_128_CBC_SHA,
+        .ECDHE_RSA_WITH_AES_128_CBC_SHA256,
+        .ECDHE_RSA_WITH_AES_256_CBC_SHA384,
+        .ECDHE_RSA_WITH_AES_128_CBC_SHA,
+
+        .RSA_WITH_AES_128_CBC_SHA256,
+        .RSA_WITH_AES_128_CBC_SHA,
+    };
+    pub const tls13_ = if (crypto.core.aes.has_hardware_support) [_]CipherSuite{
+        .AES_128_GCM_SHA256,
+        .AES_256_GCM_SHA384,
+        .CHACHA20_POLY1305_SHA256,
+        // Excluded because didn't find server which supports it to test
+        // .AEGIS_128L_SHA256
+    } else [_]CipherSuite{
+        .CHACHA20_POLY1305_SHA256,
+        .AES_128_GCM_SHA256,
+        .AES_256_GCM_SHA384,
+    };
+
+    pub const tls13 = &tls13_;
+    pub const tls12 = &(tls12_secure ++ tls12_weak);
+    pub const secure = &(tls13_ ++ tls12_secure);
+    pub const all = &(tls13_ ++ tls12_secure ++ tls12_weak);
+
+    /// Returns true if cs is present in list.
+    pub fn includes(list: []const CipherSuite, cs: CipherSuite) bool {
+        for (list) |s| {
+            if (cs == s) return true;
+        }
+        return false;
+    }
+};
+
+// Weak, secure, recommended grades are from https://ciphersuite.info/page/faq/
+pub const CipherSuite = enum(u16)
{
+    // tls 1.2 cbc sha1
+    ECDHE_ECDSA_WITH_AES_128_CBC_SHA = 0xc009, // weak
+    ECDHE_RSA_WITH_AES_128_CBC_SHA = 0xc013, // weak
+    RSA_WITH_AES_128_CBC_SHA = 0x002F, // weak
+    // tls 1.2 cbc sha256
+    ECDHE_ECDSA_WITH_AES_128_CBC_SHA256 = 0xc023, // weak
+    ECDHE_RSA_WITH_AES_128_CBC_SHA256 = 0xc027, // weak
+    RSA_WITH_AES_128_CBC_SHA256 = 0x003c, // weak
+    // tls 1.2 cbc sha384
+    ECDHE_ECDSA_WITH_AES_256_CBC_SHA384 = 0xc024, // weak
+    ECDHE_RSA_WITH_AES_256_CBC_SHA384 = 0xc028, // weak
+    // tls 1.2 gcm
+    ECDHE_ECDSA_WITH_AES_128_GCM_SHA256 = 0xc02b, // recommended
+    ECDHE_ECDSA_WITH_AES_256_GCM_SHA384 = 0xc02c, // recommended
+    ECDHE_RSA_WITH_AES_128_GCM_SHA256 = 0xc02f, // secure
+    ECDHE_RSA_WITH_AES_256_GCM_SHA384 = 0xc030, // secure
+    // tls 1.2 chacha
+    ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256 = 0xcca9, // recommended
+    ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256 = 0xcca8, // secure
+    // tls 1.3 (all are recommended)
+    AES_128_GCM_SHA256 = 0x1301,
+    AES_256_GCM_SHA384 = 0x1302,
+    CHACHA20_POLY1305_SHA256 = 0x1303,
+    AEGIS_128L_SHA256 = 0x1307,
+    // AEGIS_256_SHA512 = 0x1306,
+    // non-exhaustive: unknown values arrive from the wire and are rejected by validate()
+    _,
+
+    /// Returns error.TlsIllegalParameter unless the suite is one this
+    /// library implements (present in the tls12 or tls13 lists).
+    pub fn validate(cs: CipherSuite) !void {
+        if (cipher_suites.includes(cipher_suites.tls12, cs)) return;
+        if (cipher_suites.includes(cipher_suites.tls13, cs)) return;
+        return error.TlsIllegalParameter;
+    }
+
+    pub const Versions = enum {
+        both,
+        tls_1_3,
+        tls_1_2,
+    };
+
+    // get tls versions from list of cipher suites
+    pub fn versions(list: []const CipherSuite) !Versions {
+        var has_12 = false;
+        var has_13 = false;
+        for (list) |cs| {
+            if (cipher_suites.includes(cipher_suites.tls12, cs)) {
+                has_12 = true;
+            } else {
+                if (cipher_suites.includes(cipher_suites.tls13, cs)) has_13 = true;
+            }
+        }
+        if (has_12 and has_13) return .both;
+        if (has_12) return .tls_1_2;
+        if (has_13) return .tls_1_3;
+        return error.TlsIllegalParameter;
+    }
+
+    pub const KeyExchangeAlgorithm = enum {
+        ecdhe,
+        rsa,
+    };
+
+    /// Key exchange algorithm used by the suite.
+    pub fn keyExchange(s: CipherSuite) KeyExchangeAlgorithm {
+        return switch (s) {
+            // Random premaster secret, encrypted with public key from certificate.
+            // No server key exchange message.
+            .RSA_WITH_AES_128_CBC_SHA,
+            .RSA_WITH_AES_128_CBC_SHA256,
+            => .rsa,
+            else => .ecdhe,
+        };
+    }
+
+    pub const HashTag = enum {
+        sha256,
+        sha384,
+        sha512,
+    };
+
+    /// Transcript/prf hash used by the suite; sha384 for the listed 256-bit
+    /// aes suites, sha256 for everything else.
+    pub inline fn hash(cs: CipherSuite) HashTag {
+        return switch (cs) {
+            .ECDHE_RSA_WITH_AES_256_CBC_SHA384,
+            .ECDHE_RSA_WITH_AES_256_GCM_SHA384,
+            .ECDHE_ECDSA_WITH_AES_256_CBC_SHA384,
+            .ECDHE_ECDSA_WITH_AES_256_GCM_SHA384,
+            .AES_256_GCM_SHA384,
+            => .sha384,
+            else => .sha256,
+        };
+    }
+};
+
+const testing = std.testing;
+const testu = @import("testu.zig");
+
+test "CipherSuite validate" {
+    {
+        const cs: CipherSuite = .AES_256_GCM_SHA384;
+        try cs.validate();
+        try testing.expectEqual(cs.hash(), .sha384);
+        try testing.expectEqual(cs.keyExchange(), .ecdhe);
+    }
+    {
+        const cs: CipherSuite = .AES_128_GCM_SHA256;
+        try cs.validate();
+        try testing.expectEqual(.sha256, cs.hash());
+        try testing.expectEqual(.ecdhe, cs.keyExchange());
+    }
+    for (cipher_suites.tls12) |cs| {
+        try cs.validate();
+        _ = cs.hash();
+        _ = cs.keyExchange();
+    }
+}
+
+test "CipherSuite versions" {
+    try testing.expectEqual(.tls_1_3, CipherSuite.versions(&[_]CipherSuite{.AES_128_GCM_SHA256}));
+    try testing.expectEqual(.both, CipherSuite.versions(&[_]CipherSuite{ .AES_128_GCM_SHA256, .ECDHE_ECDSA_WITH_AES_128_CBC_SHA }));
+    try testing.expectEqual(.tls_1_2, CipherSuite.versions(&[_]CipherSuite{.RSA_WITH_AES_128_CBC_SHA}));
+}
+
+test "gcm 1.2 encrypt overhead" {
+    inline for ([_]type{
+        Aead12Aes128Gcm,
+        Aead12Aes256Gcm,
+    }) |T| {
+        {
+            const expected_key_len = switch (T) {
+                Aead12Aes128Gcm => 16,
+                Aead12Aes256Gcm => 32,
+                else => unreachable,
+            };
+            try testing.expectEqual(expected_key_len, T.key_len);
+            try testing.expectEqual(16, T.auth_tag_len);
+            try testing.expectEqual(12, T.nonce_len);
+            try testing.expectEqual(4, T.iv_len);
+            try testing.expectEqual(29, T.encrypt_overhead);
+        }
+    }
+}
+
+test "cbc 1.2
encrypt overhead" { + try testing.expectEqual(85, encrypt_overhead_tls_12); + + inline for ([_]type{ + CbcAes128Sha1, + CbcAes128Sha256, + CbcAes256Sha384, + }) |T| { + switch (T) { + CbcAes128Sha1 => { + try testing.expectEqual(20, T.mac_len); + try testing.expectEqual(16, T.key_len); + try testing.expectEqual(57, T.encrypt_overhead); + }, + CbcAes128Sha256 => { + try testing.expectEqual(32, T.mac_len); + try testing.expectEqual(16, T.key_len); + try testing.expectEqual(69, T.encrypt_overhead); + }, + CbcAes256Sha384 => { + try testing.expectEqual(48, T.mac_len); + try testing.expectEqual(32, T.key_len); + try testing.expectEqual(85, T.encrypt_overhead); + }, + else => unreachable, + } + try testing.expectEqual(16, T.paddedLength(1)); // cbc block padding + try testing.expectEqual(16, T.iv_len); + } +} + +test "overhead tls 1.3" { + try testing.expectEqual(22, encrypt_overhead_tls_13); + try expectOverhead(Aes128GcmSha256, 16, 16, 12, 22); + try expectOverhead(Aes256GcmSha384, 32, 16, 12, 22); + try expectOverhead(ChaChaSha256, 32, 16, 12, 22); + try expectOverhead(Aegis128Sha256, 16, 16, 16, 22); + // and tls 1.2 chacha + try expectOverhead(Aead12ChaCha, 32, 16, 12, 21); +} + +fn expectOverhead(T: type, key_len: usize, auth_tag_len: usize, nonce_len: usize, overhead: usize) !void { + try testing.expectEqual(key_len, T.key_len); + try testing.expectEqual(auth_tag_len, T.auth_tag_len); + try testing.expectEqual(nonce_len, T.nonce_len); + try testing.expectEqual(overhead, T.encrypt_overhead); +} + +test "client/server encryption tls 1.3" { + inline for (cipher_suites.tls13) |cs| { + var buf: [256]u8 = undefined; + testu.fill(&buf); + const secret = Transcript.Secret{ + .client = buf[0..128], + .server = buf[128..], + }; + var client_cipher = try Cipher.initTls13(cs, secret, .client); + var server_cipher = try Cipher.initTls13(cs, secret, .server); + try encryptDecrypt(&client_cipher, &server_cipher); + + try client_cipher.keyUpdateEncrypt(); + try 
server_cipher.keyUpdateDecrypt(); + try encryptDecrypt(&client_cipher, &server_cipher); + + try client_cipher.keyUpdateDecrypt(); + try server_cipher.keyUpdateEncrypt(); + try encryptDecrypt(&client_cipher, &server_cipher); + } +} + +test "client/server encryption tls 1.2" { + inline for (cipher_suites.tls12) |cs| { + var key_material: [256]u8 = undefined; + testu.fill(&key_material); + var client_cipher = try Cipher.initTls12(cs, &key_material, .client); + var server_cipher = try Cipher.initTls12(cs, &key_material, .server); + try encryptDecrypt(&client_cipher, &server_cipher); + } +} + +fn encryptDecrypt(client_cipher: *Cipher, server_cipher: *Cipher) !void { + const cleartext = + \\ Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do + \\ eiusmod tempor incididunt ut labore et dolore magna aliqua. + ; + var buf: [256]u8 = undefined; + + { // client to server + // encrypt + const encrypted = try client_cipher.encrypt(&buf, .application_data, cleartext); + const expected_encrypted_len = switch (client_cipher.*) { + inline else => |f| brk: { + const T = @TypeOf(f); + break :brk switch (T) { + CbcAes128Sha1, + CbcAes128Sha256, + CbcAes256Sha256, + CbcAes256Sha384, + => record.header_len + T.paddedLength(T.iv_len + cleartext.len + T.mac_len), + Aead12Aes128Gcm, + Aead12Aes256Gcm, + Aead12ChaCha, + Aes128GcmSha256, + Aes256GcmSha384, + ChaChaSha256, + Aegis128Sha256, + => cleartext.len + T.encrypt_overhead, + else => unreachable, + }; + }, + }; + try testing.expectEqual(expected_encrypted_len, encrypted.len); + // decrypt + const content_type, const decrypted = try server_cipher.decrypt(&buf, Record.init(encrypted)); + try testing.expectEqualSlices(u8, cleartext, decrypted); + try testing.expectEqual(.application_data, content_type); + } + // server to client + { + const encrypted = try server_cipher.encrypt(&buf, .application_data, cleartext); + const content_type, const decrypted = try client_cipher.decrypt(&buf, Record.init(encrypted)); + try 
testing.expectEqualSlices(u8, cleartext, decrypted); + try testing.expectEqual(.application_data, content_type); + } +} diff --git a/src/http/async/tls.zig/connection.zig b/src/http/async/tls.zig/connection.zig new file mode 100644 index 00000000..9ccd9c53 --- /dev/null +++ b/src/http/async/tls.zig/connection.zig @@ -0,0 +1,665 @@ +const std = @import("std"); +const assert = std.debug.assert; + +const proto = @import("protocol.zig"); +const record = @import("record.zig"); +const cipher = @import("cipher.zig"); +const Cipher = cipher.Cipher; + +const async_io = @import("../std/http/Client.zig"); +const Cbk = async_io.Cbk; +const Ctx = async_io.Ctx; + +pub fn connection(stream: anytype) Connection(@TypeOf(stream)) { + return .{ + .stream = stream, + .rec_rdr = record.reader(stream), + }; +} + +pub fn Connection(comptime Stream: type) type { + return struct { + stream: Stream, // underlying stream + rec_rdr: record.Reader(Stream), + cipher: Cipher = undefined, + + max_encrypt_seq: u64 = std.math.maxInt(u64) - 1, + key_update_requested: bool = false, + + read_buf: []const u8 = "", + received_close_notify: bool = false, + + const Self = @This(); + + /// Encrypts and writes single tls record to the stream. + fn writeRecord(c: *Self, content_type: proto.ContentType, bytes: []const u8) !void { + assert(bytes.len <= cipher.max_cleartext_len); + var write_buf: [cipher.max_ciphertext_record_len]u8 = undefined; + // If key update is requested send key update message and update + // my encryption keys. + if (c.cipher.encryptSeq() >= c.max_encrypt_seq or @atomicLoad(bool, &c.key_update_requested, .monotonic)) { + @atomicStore(bool, &c.key_update_requested, false, .monotonic); + + // If the request_update field is set to "update_requested", + // then the receiver MUST send a KeyUpdate of its own with + // request_update set to "update_not_requested" prior to sending + // its next Application Data record. 
This mechanism allows + // either side to force an update to the entire connection, but + // causes an implementation which receives multiple KeyUpdates + // while it is silent to respond with a single update. + // + // rfc: https://datatracker.ietf.org/doc/html/rfc8446#autoid-57 + const key_update = &record.handshakeHeader(.key_update, 1) ++ [_]u8{0}; + const rec = try c.cipher.encrypt(&write_buf, .handshake, key_update); + try c.stream.writeAll(rec); + try c.cipher.keyUpdateEncrypt(); + } + const rec = try c.cipher.encrypt(&write_buf, content_type, bytes); + try c.stream.writeAll(rec); + } + + fn writeAlert(c: *Self, err: anyerror) !void { + const cleartext = proto.alertFromError(err); + var buf: [128]u8 = undefined; + const ciphertext = try c.cipher.encrypt(&buf, .alert, &cleartext); + c.stream.writeAll(ciphertext) catch {}; + } + + /// Returns next record of cleartext data. + /// Can be used in iterator like loop without memcpy to another buffer: + /// while (try client.next()) |buf| { ... } + pub fn next(c: *Self) ReadError!?[]const u8 { + const content_type, const data = c.nextRecord() catch |err| { + try c.writeAlert(err); + return err; + } orelse return null; + if (content_type != .application_data) return error.TlsUnexpectedMessage; + return data; + } + + fn nextRecord(c: *Self) ReadError!?struct { proto.ContentType, []const u8 } { + if (c.eof()) return null; + while (true) { + const content_type, const cleartext = try c.rec_rdr.nextDecrypt(&c.cipher) orelse return null; + + switch (content_type) { + .application_data => {}, + .handshake => { + const handshake_type: proto.Handshake = @enumFromInt(cleartext[0]); + switch (handshake_type) { + // skip new session ticket and read next record + .new_session_ticket => continue, + .key_update => { + if (cleartext.len != 5) return error.TlsDecodeError; + // rfc: Upon receiving a KeyUpdate, the receiver MUST + // update its receiving keys. 
+ try c.cipher.keyUpdateDecrypt(); + const key: proto.KeyUpdateRequest = @enumFromInt(cleartext[4]); + switch (key) { + .update_requested => { + @atomicStore(bool, &c.key_update_requested, true, .monotonic); + }, + .update_not_requested => {}, + else => return error.TlsIllegalParameter, + } + // this record is handled read next + continue; + }, + else => {}, + } + }, + .alert => { + if (cleartext.len < 2) return error.TlsUnexpectedMessage; + try proto.Alert.parse(cleartext[0..2].*).toError(); + // server side clean shutdown + c.received_close_notify = true; + return null; + }, + else => return error.TlsUnexpectedMessage, + } + return .{ content_type, cleartext }; + } + } + + pub fn eof(c: *Self) bool { + return c.received_close_notify and c.read_buf.len == 0; + } + + pub fn close(c: *Self) !void { + if (c.received_close_notify) return; + try c.writeRecord(.alert, &proto.Alert.closeNotify()); + } + + // read, write interface + + pub const ReadError = Stream.ReadError || proto.Alert.Error || + error{ + TlsBadVersion, + TlsUnexpectedMessage, + TlsRecordOverflow, + TlsDecryptError, + TlsDecodeError, + TlsBadRecordMac, + TlsIllegalParameter, + BufferOverflow, + }; + pub const WriteError = Stream.WriteError || + error{ + BufferOverflow, + TlsUnexpectedMessage, + }; + + pub const Reader = std.io.Reader(*Self, ReadError, read); + pub const Writer = std.io.Writer(*Self, WriteError, write); + + pub fn reader(c: *Self) Reader { + return .{ .context = c }; + } + + pub fn writer(c: *Self) Writer { + return .{ .context = c }; + } + + /// Encrypts cleartext and writes it to the underlying stream as single + /// tls record. Max single tls record payload length is 1<<14 (16K) + /// bytes. + pub fn write(c: *Self, bytes: []const u8) WriteError!usize { + const n = @min(bytes.len, cipher.max_cleartext_len); + try c.writeRecord(.application_data, bytes[0..n]); + return n; + } + + /// Encrypts cleartext and writes it to the underlying stream. 
If needed + /// splits cleartext into multiple tls record. + pub fn writeAll(c: *Self, bytes: []const u8) WriteError!void { + var index: usize = 0; + while (index < bytes.len) { + index += try c.write(bytes[index..]); + } + } + + pub fn read(c: *Self, buffer: []u8) ReadError!usize { + if (c.read_buf.len == 0) { + c.read_buf = try c.next() orelse return 0; + } + const n = @min(c.read_buf.len, buffer.len); + @memcpy(buffer[0..n], c.read_buf[0..n]); + c.read_buf = c.read_buf[n..]; + return n; + } + + /// Returns the number of bytes read. If the number read is smaller than + /// `buffer.len`, it means the stream reached the end. + pub fn readAll(c: *Self, buffer: []u8) ReadError!usize { + return c.readAtLeast(buffer, buffer.len); + } + + /// Returns the number of bytes read, calling the underlying read function + /// the minimal number of times until the buffer has at least `len` bytes + /// filled. If the number read is less than `len` it means the stream + /// reached the end. + pub fn readAtLeast(c: *Self, buffer: []u8, len: usize) ReadError!usize { + assert(len <= buffer.len); + var index: usize = 0; + while (index < len) { + const amt = try c.read(buffer[index..]); + if (amt == 0) break; + index += amt; + } + return index; + } + + /// Returns the number of bytes read. If the number read is less than + /// the space provided it means the stream reached the end. 
+        pub fn readv(c: *Self, iovecs: []std.posix.iovec) !usize {
+            var vp: VecPut = .{ .iovecs = iovecs };
+            while (true) {
+                if (c.read_buf.len == 0) {
+                    c.read_buf = try c.next() orelse break;
+                }
+                const n = vp.put(c.read_buf);
+                const read_buf_len = c.read_buf.len;
+                c.read_buf = c.read_buf[n..];
+                if ((n < read_buf_len) or
+                    (n == read_buf_len and !c.rec_rdr.hasMore()))
+                    break;
+            }
+            return vp.total;
+        }
+
+        /// Async callback: after one record is written, encrypts and writes
+        /// the next chunk of ctx._tls_write_bytes until all bytes are sent.
+        fn onWriteAll(ctx: *Ctx, res: anyerror!void) anyerror!void {
+            res catch |err| return ctx.pop(err);
+
+            if (ctx._tls_write_bytes.len - ctx._tls_write_index > 0) {
+                const rec = ctx.conn().tls_client.prepareRecord(ctx.stream(), ctx) catch |err| return ctx.pop(err);
+                return ctx.stream().async_writeAll(rec, ctx, onWriteAll) catch |err| return ctx.pop(err);
+            }
+
+            return ctx.pop({});
+        }
+
+        /// Encrypts `bytes` and writes them to the stream as one or more tls
+        /// records, invoking cbk when everything is sent.
+        pub fn async_writeAll(c: *Self, stream: anytype, bytes: []const u8, ctx: *Ctx, comptime cbk: Cbk) !void {
+            // No upper bound on bytes.len: prepareRecord slices the cleartext
+            // into chunks of at most cipher.max_cleartext_len and onWriteAll
+            // loops until every chunk is written. (An assert limiting bytes
+            // to a single record would contradict that chunking loop.)
+            ctx._tls_write_bytes = bytes;
+            ctx._tls_write_index = 0;
+            const rec = try c.prepareRecord(stream, ctx);
+
+            try ctx.push(cbk);
+            return stream.async_writeAll(rec, ctx, onWriteAll);
+        }
+
+        /// Encrypts the next chunk of ctx._tls_write_bytes into
+        /// ctx._tls_write_buf and advances the write index.
+        fn prepareRecord(c: *Self, stream: anytype, ctx: *Ctx) ![]const u8 {
+            const len = @min(ctx._tls_write_bytes.len - ctx._tls_write_index, cipher.max_cleartext_len);
+
+            // If key update is requested send key update message and update
+            // my encryption keys.
+            if (c.cipher.encryptSeq() >= c.max_encrypt_seq or @atomicLoad(bool, &c.key_update_requested, .monotonic)) {
+                @atomicStore(bool, &c.key_update_requested, false, .monotonic);
+
+                // If the request_update field is set to "update_requested",
+                // then the receiver MUST send a KeyUpdate of its own with
+                // request_update set to "update_not_requested" prior to sending
+                // its next Application Data record. This mechanism allows
+                // either side to force an update to the entire connection, but
+                // causes an implementation which receives multiple KeyUpdates
+                // while it is silent to respond with a single update.
+                //
+                // rfc: https://datatracker.ietf.org/doc/html/rfc8446#autoid-57
+                const key_update = &record.handshakeHeader(.key_update, 1) ++ [_]u8{0};
+                const rec = try c.cipher.encrypt(&ctx._tls_write_buf, .handshake, key_update);
+                try stream.writeAll(rec); // TODO async
+                try c.cipher.keyUpdateEncrypt();
+            }
+
+            defer ctx._tls_write_index += len;
+            // Take `len` bytes starting at the write index. `[index..][0..len]`
+            // is required here: `[index..len]` would treat len as an absolute
+            // end position and break (or panic) for every chunk after the
+            // first one.
+            return c.cipher.encrypt(&ctx._tls_write_buf, .application_data, ctx._tls_write_bytes[ctx._tls_write_index..][0..len]);
+        }
+
+        /// Async callback: copies decrypted record data into the caller's
+        /// iovecs, reading more records as needed.
+        fn onReadv(ctx: *Ctx, res: anyerror!void) anyerror!void {
+            res catch |err| return ctx.pop(err);
+
+            if (ctx._tls_read_buf == null) {
+                // end of read
+                ctx.setLen(ctx._vp.total);
+                return ctx.pop({});
+            }
+
+            while (true) {
+                const n = ctx._vp.put(ctx._tls_read_buf.?);
+                const read_buf_len = ctx._tls_read_buf.?.len;
+                const c = ctx.conn().tls_client;
+
+                if (read_buf_len == 0) {
+                    // Current record exhausted; start reading the next one and
+                    // return: this callback is re-entered when it is ready.
+                    // Falling through here would keep looping on stale state
+                    // while the async read is in flight.
+                    return c.async_next(ctx.stream(), ctx, onReadv) catch |err| return ctx.pop(err);
+                }
+
+                ctx._tls_read_buf = ctx._tls_read_buf.?[n..];
+
+                if ((n < read_buf_len) or (n == read_buf_len and !c.rec_rdr.hasMore())) {
+                    // end of read
+                    ctx.setLen(ctx._vp.total);
+                    return ctx.pop({});
+                }
+            }
+        }
+
+        /// Reads decrypted data into iovecs, invoking cbk when the iovecs are
+        /// full or the stream is exhausted.
+        pub fn async_readv(c: *Self, stream: anytype, iovecs: []std.posix.iovec, ctx: *Ctx, comptime cbk: Cbk) !void {
+            try ctx.push(cbk);
+            ctx._vp = .{ .iovecs = iovecs };
+
+            return c.async_next(stream, ctx, onReadv);
+        }
+
+        /// Async callback: rejects any decrypted record that is not
+        /// application data; on error sends a tls alert before failing.
+        fn onNext(ctx: *Ctx, res: anyerror!void) anyerror!void {
+            res catch |err| {
+                ctx.conn().tls_client.writeAlert(err) catch |e| std.log.err("onNext: write alert: {any}", .{e}); // TODO async
+                return ctx.pop(err);
+            };
+
+            if (ctx._tls_read_content_type != .application_data) {
+                return ctx.pop(error.TlsUnexpectedMessage);
+            }
+
+            return ctx.pop({});
+        }
+
+        pub fn async_next(c: *Self, stream:
anytype, ctx: *Ctx, comptime cbk: Cbk) !void { + try ctx.push(cbk); + + return c.async_next_decrypt(stream, ctx, onNext); + } + + pub fn onNextDecrypt(ctx: *Ctx, res: anyerror!void) anyerror!void { + res catch |err| return ctx.pop(err); + + const c = ctx.conn().tls_client; + // TOOD not sure if this works in my async case... + if (c.eof()) { + ctx._tls_read_buf = null; + return ctx.pop({}); + } + + const content_type = ctx._tls_read_content_type; + + switch (content_type) { + .application_data => {}, + .handshake => { + const handshake_type: proto.Handshake = @enumFromInt(ctx._tls_read_buf.?[0]); + switch (handshake_type) { + // skip new session ticket and read next record + .new_session_ticket => return c.async_next_record(ctx.stream(), ctx, onNextDecrypt) catch |err| return ctx.pop(err), + .key_update => { + if (ctx._tls_read_buf.?.len != 5) return ctx.pop(error.TlsDecodeError); + // rfc: Upon receiving a KeyUpdate, the receiver MUST + // update its receiving keys. + try c.cipher.keyUpdateDecrypt(); + const key: proto.KeyUpdateRequest = @enumFromInt(ctx._tls_read_buf.?[4]); + switch (key) { + .update_requested => { + @atomicStore(bool, &c.key_update_requested, true, .monotonic); + }, + .update_not_requested => {}, + else => return ctx.pop(error.TlsIllegalParameter), + } + // this record is handled read next + c.async_next_record(ctx.stream(), ctx, onNextDecrypt) catch |err| return ctx.pop(err); + }, + else => {}, + } + }, + .alert => { + if (ctx._tls_read_buf.?.len < 2) return ctx.pop(error.TlsUnexpectedMessage); + try proto.Alert.parse(ctx._tls_read_buf.?[0..2].*).toError(); + // server side clean shutdown + c.received_close_notify = true; + ctx._tls_read_buf = null; + return ctx.pop({}); + }, + else => return ctx.pop(error.TlsUnexpectedMessage), + } + + return ctx.pop({}); + } + + pub fn async_next_decrypt(c: *Self, stream: anytype, ctx: *Ctx, comptime cbk: Cbk) !void { + try ctx.push(cbk); + + return c.async_next_record(stream, ctx, onNextDecrypt) catch |err| 
return ctx.pop(err); + } + + pub fn onNextRecord(ctx: *Ctx, res: anyerror!void) anyerror!void { + res catch |err| return ctx.pop(err); + + const rec = ctx._tls_read_record orelse { + ctx._tls_read_buf = null; + return ctx.pop({}); + }; + + if (rec.protocol_version != .tls_1_2) return error.TlsBadVersion; + + const c = ctx.conn().tls_client; + const cph = &c.cipher; + + ctx._tls_read_content_type, ctx._tls_read_buf = cph.decrypt( + // Reuse reader buffer for cleartext. `rec.header` and + // `rec.payload`(ciphertext) are also pointing somewhere in + // this buffer. Decrypter is first reading then writing a + // block, cleartext has less length then ciphertext, + // cleartext starts from the beginning of the buffer, so + // ciphertext is always ahead of cleartext. + c.rec_rdr.buffer[0..c.rec_rdr.start], + rec, + ) catch |err| return ctx.pop(err); + + return ctx.pop({}); + } + + pub fn async_next_record(c: *Self, stream: anytype, ctx: *Ctx, comptime cbk: Cbk) !void { + try ctx.push(cbk); + + return c.async_reader_next(stream, ctx, onNextRecord); + } + + pub fn onReaderNext(ctx: *Ctx, res: anyerror!void) anyerror!void { + res catch |err| return ctx.pop(err); + + const c = ctx.conn().tls_client; + + const n = ctx.len(); + if (n == 0) { + ctx._tls_read_record = null; + return ctx.pop({}); + } + c.rec_rdr.end += n; + + return c.readNext(ctx); + } + + pub fn readNext(c: *Self, ctx: *Ctx) anyerror!void { + const buffer = c.rec_rdr.buffer[c.rec_rdr.start..c.rec_rdr.end]; + // If we have 5 bytes header. 
+ if (buffer.len >= record.header_len) { + const record_header = buffer[0..record.header_len]; + const payload_len = std.mem.readInt(u16, record_header[3..5], .big); + if (payload_len > cipher.max_ciphertext_len) + return error.TlsRecordOverflow; + const record_len = record.header_len + payload_len; + // If we have whole record + if (buffer.len >= record_len) { + c.rec_rdr.start += record_len; + ctx._tls_read_record = record.Record.init(buffer[0..record_len]); + return ctx.pop({}); + } + } + { // Move dirty part to the start of the buffer. + const n = c.rec_rdr.end - c.rec_rdr.start; + if (n > 0 and c.rec_rdr.start > 0) { + if (c.rec_rdr.start > n) { + @memcpy(c.rec_rdr.buffer[0..n], c.rec_rdr.buffer[c.rec_rdr.start..][0..n]); + } else { + std.mem.copyForwards(u8, c.rec_rdr.buffer[0..n], c.rec_rdr.buffer[c.rec_rdr.start..][0..n]); + } + } + c.rec_rdr.start = 0; + c.rec_rdr.end = n; + } + // Read more from inner_reader. + return ctx.stream() + .async_read(c.rec_rdr.buffer[c.rec_rdr.end..], ctx, onReaderNext) catch |err| return ctx.pop(err); + } + + pub fn async_reader_next(c: *Self, _: anytype, ctx: *Ctx, comptime cbk: Cbk) !void { + try ctx.push(cbk); + return c.readNext(ctx); + } + }; +} + +const testing = std.testing; +const data12 = @import("testdata/tls12.zig"); +const testu = @import("testu.zig"); + +test "encrypt decrypt" { + var output_buf: [1024]u8 = undefined; + const stream = testu.Stream.init(&(data12.server_pong ** 3), &output_buf); + var conn: Connection(@TypeOf(stream)) = .{ .stream = stream, .rec_rdr = record.reader(stream) }; + conn.cipher = try Cipher.initTls12(.ECDHE_RSA_WITH_AES_128_CBC_SHA, &data12.key_material, .client); + conn.cipher.ECDHE_RSA_WITH_AES_128_CBC_SHA.rnd = testu.random(0); // use fixed rng + + conn.stream.output.reset(); + { // encrypt verify data from example + _ = testu.random(0x40); // sets iv to 40, 41, ... 
4f + try conn.writeRecord(.handshake, &data12.client_finished); + try testing.expectEqualSlices(u8, &data12.verify_data_encrypted_msg, conn.stream.output.getWritten()); + } + + conn.stream.output.reset(); + { // encrypt ping + const cleartext = "ping"; + _ = testu.random(0); // sets iv to 00, 01, ... 0f + //conn.encrypt_seq = 1; + + try conn.writeAll(cleartext); + try testing.expectEqualSlices(u8, &data12.encrypted_ping_msg, conn.stream.output.getWritten()); + } + { // decrypt server pong message + conn.cipher.ECDHE_RSA_WITH_AES_128_CBC_SHA.decrypt_seq = 1; + try testing.expectEqualStrings("pong", (try conn.next()).?); + } + { // test reader interface + conn.cipher.ECDHE_RSA_WITH_AES_128_CBC_SHA.decrypt_seq = 1; + var rdr = conn.reader(); + var buffer: [4]u8 = undefined; + const n = try rdr.readAll(&buffer); + try testing.expectEqualStrings("pong", buffer[0..n]); + } + { // test readv interface + conn.cipher.ECDHE_RSA_WITH_AES_128_CBC_SHA.decrypt_seq = 1; + var buffer: [9]u8 = undefined; + var iovecs = [_]std.posix.iovec{ + .{ .base = &buffer, .len = 3 }, + .{ .base = buffer[3..], .len = 3 }, + .{ .base = buffer[6..], .len = 3 }, + }; + const n = try conn.readv(iovecs[0..]); + try testing.expectEqual(4, n); + try testing.expectEqualStrings("pong", buffer[0..n]); + } +} + +// Copied from: https://github.com/ziglang/zig/blob/455899668b620dfda40252501c748c0a983555bd/lib/std/crypto/tls/Client.zig#L1354 +/// Abstraction for sending multiple byte buffers to a slice of iovecs. +pub const VecPut = struct { + iovecs: []const std.posix.iovec, + idx: usize = 0, + off: usize = 0, + total: usize = 0, + + /// Returns the amount actually put which is always equal to bytes.len + /// unless the vectors ran out of space. 
+ pub fn put(vp: *VecPut, bytes: []const u8) usize { + if (vp.idx >= vp.iovecs.len) return 0; + var bytes_i: usize = 0; + while (true) { + const v = vp.iovecs[vp.idx]; + const dest = v.base[vp.off..v.len]; + const src = bytes[bytes_i..][0..@min(dest.len, bytes.len - bytes_i)]; + @memcpy(dest[0..src.len], src); + bytes_i += src.len; + vp.off += src.len; + if (vp.off >= v.len) { + vp.off = 0; + vp.idx += 1; + if (vp.idx >= vp.iovecs.len) { + vp.total += bytes_i; + return bytes_i; + } + } + if (bytes_i >= bytes.len) { + vp.total += bytes_i; + return bytes_i; + } + } + } +}; + +test "client/server connection" { + const BufReaderWriter = struct { + buf: []u8, + wp: usize = 0, + rp: usize = 0, + + const Self = @This(); + + pub fn write(self: *Self, bytes: []const u8) !usize { + if (self.wp == self.buf.len) return error.NoSpaceLeft; + + const n = @min(bytes.len, self.buf.len - self.wp); + @memcpy(self.buf[self.wp..][0..n], bytes[0..n]); + self.wp += n; + return n; + } + + pub fn writeAll(self: *Self, bytes: []const u8) !void { + var n: usize = 0; + while (n < bytes.len) { + n += try self.write(bytes[n..]); + } + } + + pub fn read(self: *Self, bytes: []u8) !usize { + const n = @min(bytes.len, self.wp - self.rp); + if (n == 0) return 0; + @memcpy(bytes[0..n], self.buf[self.rp..][0..n]); + self.rp += n; + if (self.rp == self.wp) { + self.wp = 0; + self.rp = 0; + } + return n; + } + }; + + const TestStream = struct { + inner_stream: *BufReaderWriter, + const Self = @This(); + pub const ReadError = error{}; + pub const WriteError = error{NoSpaceLeft}; + pub fn read(self: *Self, bytes: []u8) !usize { + return try self.inner_stream.read(bytes); + } + pub fn writeAll(self: *Self, bytes: []const u8) !void { + return try self.inner_stream.writeAll(bytes); + } + }; + + const buf_len = 32 * 1024; + const tls_records_in_buf = (std.math.divCeil(comptime_int, buf_len, cipher.max_cleartext_len) catch unreachable); + const overhead: usize = tls_records_in_buf * 
@import("cipher.zig").encrypt_overhead_tls_13; + var buf: [buf_len + overhead]u8 = undefined; + var inner_stream = BufReaderWriter{ .buf = &buf }; + + const cipher_client, const cipher_server = brk: { + const Transcript = @import("transcript.zig").Transcript; + const CipherSuite = @import("cipher.zig").CipherSuite; + const cipher_suite: CipherSuite = .AES_256_GCM_SHA384; + + var rnd: [128]u8 = undefined; + std.crypto.random.bytes(&rnd); + const secret = Transcript.Secret{ + .client = rnd[0..64], + .server = rnd[64..], + }; + + break :brk .{ + try Cipher.initTls13(cipher_suite, secret, .client), + try Cipher.initTls13(cipher_suite, secret, .server), + }; + }; + + var conn1 = connection(TestStream{ .inner_stream = &inner_stream }); + conn1.cipher = cipher_client; + + var conn2 = connection(TestStream{ .inner_stream = &inner_stream }); + conn2.cipher = cipher_server; + + var prng = std.Random.DefaultPrng.init(0); + const random = prng.random(); + var send_buf: [buf_len]u8 = undefined; + var recv_buf: [buf_len]u8 = undefined; + random.bytes(&send_buf); // fill send buffer with random bytes + + for (0..16) |_| { + const n = buf_len; //random.uintLessThan(usize, buf_len); + + const sent = send_buf[0..n]; + try conn1.writeAll(sent); + const r = try conn2.readAll(&recv_buf); + const received = recv_buf[0..r]; + + try testing.expectEqual(n, r); + try testing.expectEqualSlices(u8, sent, received); + } +} diff --git a/src/http/async/tls.zig/handshake_client.zig b/src/http/async/tls.zig/handshake_client.zig new file mode 100644 index 00000000..e7b48cf6 --- /dev/null +++ b/src/http/async/tls.zig/handshake_client.zig @@ -0,0 +1,955 @@ +const std = @import("std"); +const assert = std.debug.assert; +const crypto = std.crypto; +const mem = std.mem; +const Certificate = crypto.Certificate; + +const cipher = @import("cipher.zig"); +const Cipher = cipher.Cipher; +const CipherSuite = cipher.CipherSuite; +const cipher_suites = cipher.cipher_suites; +const Transcript = 
@import("transcript.zig").Transcript; +const record = @import("record.zig"); +const rsa = @import("rsa/rsa.zig"); +const key_log = @import("key_log.zig"); +const PrivateKey = @import("PrivateKey.zig"); +const proto = @import("protocol.zig"); + +const common = @import("handshake_common.zig"); +const dupe = common.dupe; +const CertificateBuilder = common.CertificateBuilder; +const CertificateParser = common.CertificateParser; +const DhKeyPair = common.DhKeyPair; +const CertBundle = common.CertBundle; +const CertKeyPair = common.CertKeyPair; + +pub const Options = struct { + host: []const u8, + /// Set of root certificate authorities that clients use when verifying + /// server certificates. + root_ca: CertBundle, + + /// Controls whether a client verifies the server's certificate chain and + /// host name. + insecure_skip_verify: bool = false, + + /// List of cipher suites to use. + /// To use just tls 1.3 cipher suites: + /// .cipher_suites = &tls.CipherSuite.tls13, + /// To select particular cipher suite: + /// .cipher_suites = &[_]tls.CipherSuite{tls.CipherSuite.CHACHA20_POLY1305_SHA256}, + cipher_suites: []const CipherSuite = cipher_suites.all, + + /// List of named groups to use. + /// To use specific named group: + /// .named_groups = &[_]tls.NamedGroup{.secp384r1}, + named_groups: []const proto.NamedGroup = supported_named_groups, + + /// Client authentication certificates and private key. + auth: ?CertKeyPair = null, + + /// If this structure is provided it will be filled with handshake attributes + /// at the end of the handshake process. + diagnostic: ?*Diagnostic = null, + + /// For logging current connection tls keys, so we can share them with + /// Wireshark and analyze decrypted traffic there. 
+ key_log_callback: ?key_log.Callback = null, + + pub const Diagnostic = struct { + tls_version: proto.Version = @enumFromInt(0), + cipher_suite_tag: CipherSuite = @enumFromInt(0), + named_group: proto.NamedGroup = @enumFromInt(0), + signature_scheme: proto.SignatureScheme = @enumFromInt(0), + client_signature_scheme: proto.SignatureScheme = @enumFromInt(0), + }; +}; + +const supported_named_groups = &[_]proto.NamedGroup{ + .x25519, + .secp256r1, + .secp384r1, + .x25519_kyber768d00, +}; + +/// Handshake parses tls server message and creates client messages. Collects +/// tls attributes: server random, cipher suite and so on. Client messages are +/// created using provided buffer. Provided record reader is used to get tls +/// record when needed. +pub fn Handshake(comptime Stream: type) type { + const RecordReaderT = record.Reader(Stream); + return struct { + client_random: [32]u8, + server_random: [32]u8 = undefined, + master_secret: [48]u8 = undefined, + key_material: [48 * 4]u8 = undefined, // for sha256 32 * 4 is filled, for sha384 48 * 4 + + transcript: Transcript = .{}, + cipher_suite: CipherSuite = @enumFromInt(0), + named_group: ?proto.NamedGroup = null, + dh_kp: DhKeyPair, + rsa_secret: RsaSecret, + tls_version: proto.Version = .tls_1_2, + cipher: Cipher = undefined, + cert: CertificateParser = undefined, + client_certificate_requested: bool = false, + // public key len: x25519 = 32, secp256r1 = 65, secp384r1 = 97, x25519_kyber768d00 = 1120 + server_pub_key_buf: [2048]u8 = undefined, + server_pub_key: []const u8 = undefined, + + rec_rdr: *RecordReaderT, // tls record reader + buffer: []u8, // scratch buffer used in all messages creation + + const HandshakeT = @This(); + + pub fn init(buf: []u8, rec_rdr: *RecordReaderT) HandshakeT { + return .{ + .client_random = undefined, + .dh_kp = undefined, + .rsa_secret = undefined, + //.now_sec = std.time.timestamp(), + .buffer = buf, + .rec_rdr = rec_rdr, + }; + } + + fn initKeys( + h: *HandshakeT, + named_groups: 
[]const proto.NamedGroup, + ) !void { + const init_keys_buf_len = 32 + 46 + DhKeyPair.seed_len; + var buf: [init_keys_buf_len]u8 = undefined; + crypto.random.bytes(&buf); + + h.client_random = buf[0..32].*; + h.rsa_secret = RsaSecret.init(buf[32..][0..46].*); + h.dh_kp = try DhKeyPair.init(buf[32 + 46 ..][0..DhKeyPair.seed_len].*, named_groups); + } + + /// Handshake exchanges messages with server to get agreement about + /// cryptographic parameters. That upgrades existing client-server + /// connection to TLS connection. Returns cipher used in application for + /// encrypted message exchange. + /// + /// Handles TLS 1.2 and TLS 1.3 connections. After initial client hello + /// server chooses in its server hello which TLS version will be used. + /// + /// TLS 1.2 handshake messages exchange: + /// Client Server + /// -------------------------------------------------------------- + /// ClientHello client flight 1 ---> + /// ServerHello + /// Certificate + /// ServerKeyExchange + /// CertificateRequest* + /// <--- server flight 1 ServerHelloDone + /// Certificate* + /// ClientKeyExchange + /// CertificateVerify* + /// ChangeCipherSpec + /// Finished client flight 2 ---> + /// ChangeCipherSpec + /// <--- server flight 2 Finished + /// + /// TLS 1.3 handshake messages exchange: + /// Client Server + /// -------------------------------------------------------------- + /// ClientHello client flight 1 ---> + /// ServerHello + /// {EncryptedExtensions} + /// {CertificateRequest*} + /// {Certificate} + /// {CertificateVerify} + /// <--- server flight 1 {Finished} + /// ChangeCipherSpec + /// {Certificate*} + /// {CertificateVerify*} + /// Finished client flight 2 ---> + /// + /// * - optional + /// {} - encrypted + /// + /// References: + /// https://datatracker.ietf.org/doc/html/rfc5246#section-7.3 + /// https://datatracker.ietf.org/doc/html/rfc8446#section-2 + /// + pub fn handshake(h: *HandshakeT, w: Stream, opt: Options) !Cipher { + defer h.updateDiagnostic(opt); + try 
h.initKeys(opt.named_groups); + h.cert = .{ + .host = opt.host, + .root_ca = opt.root_ca.bundle, + .skip_verify = opt.insecure_skip_verify, + }; + + try w.writeAll(try h.makeClientHello(opt)); // client flight 1 + try h.readServerFlight1(); // server flight 1 + h.transcript.use(h.cipher_suite.hash()); + + // tls 1.3 specific handshake part + if (h.tls_version == .tls_1_3) { + try h.generateHandshakeCipher(opt.key_log_callback); + try h.readEncryptedServerFlight1(); // server flight 1 + const app_cipher = try h.generateApplicationCipher(opt.key_log_callback); + try w.writeAll(try h.makeClientFlight2Tls13(opt.auth)); // client flight 2 + return app_cipher; + } + + // tls 1.2 specific handshake part + try h.generateCipher(opt.key_log_callback); + try w.writeAll(try h.makeClientFlight2Tls12(opt.auth)); // client flight 2 + try h.readServerFlight2(); // server flight 2 + return h.cipher; + } + + /// Prepare key material and generate cipher for TLS 1.2 + fn generateCipher(h: *HandshakeT, key_log_callback: ?key_log.Callback) !void { + try h.verifyCertificateSignatureTls12(); + try h.generateKeyMaterial(key_log_callback); + h.cipher = try Cipher.initTls12(h.cipher_suite, &h.key_material, .client); + } + + /// Generate TLS 1.2 pre master secret, master secret and key material. 
// Derive the TLS 1.2 master secret and key material from the negotiated
// key-exchange result, and optionally emit the NSS key-log line.
fn generateKeyMaterial(h: *HandshakeT, key_log_callback: ?key_log.Callback) !void {
    // Pre-master secret source depends on the negotiated key exchange:
    // ECDHE (named_group set) -> DH shared key; otherwise plain RSA key
    // exchange -> the random rsa_secret generated in initKeys.
    const pre_master_secret = if (h.named_group) |named_group|
        try h.dh_kp.sharedKey(named_group, h.server_pub_key)
    else
        &h.rsa_secret.secret;

    // dupe copies into the fixed-size handshake buffers; the returned
    // slice is intentionally discarded (h.master_secret/h.key_material
    // hold the data).
    _ = dupe(
        &h.master_secret,
        h.transcript.masterSecret(pre_master_secret, h.client_random, h.server_random),
    );
    _ = dupe(
        &h.key_material,
        h.transcript.keyMaterial(&h.master_secret, h.client_random, h.server_random),
    );
    // SSLKEYLOGFILE-style logging (e.g. for Wireshark decryption).
    if (key_log_callback) |cb| {
        cb(key_log.label.client_random, &h.client_random, &h.master_secret);
    }
}

/// TLS 1.3 cipher used during handshake.
/// Computes the handshake secrets from the (EC)DHE shared key, logs them
/// when a key-log callback is configured, and installs the handshake
/// cipher on `h.cipher`. Requires `h.named_group` to be set (TLS 1.3
/// always negotiates a key-share group).
fn generateHandshakeCipher(h: *HandshakeT, key_log_callback: ?key_log.Callback) !void {
    const shared_key = try h.dh_kp.sharedKey(h.named_group.?, h.server_pub_key);
    const handshake_secret = h.transcript.handshakeSecret(shared_key);
    if (key_log_callback) |cb| {
        cb(key_log.label.server_handshake_traffic_secret, &h.client_random, handshake_secret.server);
        cb(key_log.label.client_handshake_traffic_secret, &h.client_random, handshake_secret.client);
    }
    h.cipher = try Cipher.initTls13(h.cipher_suite, handshake_secret, .client);
}

/// TLS 1.3 application (client) cipher.
/// Derives the application traffic secrets from the current transcript
/// state and returns the resulting cipher; unlike the handshake cipher
/// it is returned to the caller instead of stored on `h`.
/// NOTE(review): must be called only after the server Finished message
/// has been hashed into the transcript — presumably guaranteed by the
/// call order in handshake(); confirm if refactoring.
fn generateApplicationCipher(h: *HandshakeT, key_log_callback: ?key_log.Callback) !Cipher {
    const application_secret = h.transcript.applicationSecret();
    if (key_log_callback) |cb| {
        cb(key_log.label.server_traffic_secret_0, &h.client_random, application_secret.server);
        cb(key_log.label.client_traffic_secret_0, &h.client_random, application_secret.client);
    }
    return try Cipher.initTls13(h.cipher_suite, application_secret, .client);
}

fn makeClientHello(h: *HandshakeT, opt: Options) ![]const u8 {
    // The buffer is laid out in these parts:
    // | header | payload | extensions |
    //
    // The header is written last because we need to know the length of
    // the payload and extensions when creating it. The payload has the
    // extensions length (u16) as its last element.
+ // + var buffer = h.buffer; + const header_len = 9; // tls record header (5 bytes) and handshake header (4 bytes) + const tls_versions = try CipherSuite.versions(opt.cipher_suites); + // Payload writer, preserve header_len bytes for handshake header. + var payload = record.Writer{ .buf = buffer[header_len..] }; + try payload.writeEnum(proto.Version.tls_1_2); + try payload.write(&h.client_random); + try payload.writeByte(0); // no session id + try payload.writeEnumArray(CipherSuite, opt.cipher_suites); + try payload.write(&[_]u8{ 0x01, 0x00 }); // no compression + + // Extensions writer starts after payload and preserves 2 more + // bytes for extension len in payload. + var ext = record.Writer{ .buf = buffer[header_len + payload.pos + 2 ..] }; + try ext.writeExtension(.supported_versions, switch (tls_versions) { + .both => &[_]proto.Version{ .tls_1_3, .tls_1_2 }, + .tls_1_3 => &[_]proto.Version{.tls_1_3}, + .tls_1_2 => &[_]proto.Version{.tls_1_2}, + }); + try ext.writeExtension(.signature_algorithms, common.supported_signature_algorithms); + + try ext.writeExtension(.supported_groups, opt.named_groups); + if (tls_versions != .tls_1_2) { + var keys: [supported_named_groups.len][]const u8 = undefined; + for (opt.named_groups, 0..) |ng, i| { + keys[i] = try h.dh_kp.publicKey(ng); + } + try ext.writeKeyShare(opt.named_groups, keys[0..opt.named_groups.len]); + } + try ext.writeServerName(opt.host); + + // Extensions length at the end of the payload. + try payload.writeInt(@as(u16, @intCast(ext.pos))); + + // Header at the start of the buffer. + const body_len = payload.pos + ext.pos; + buffer[0..header_len].* = record.header(.handshake, 4 + body_len) ++ + record.handshakeHeader(.client_hello, body_len); + + const msg = buffer[0 .. header_len + body_len]; + h.transcript.update(msg[record.header_len..]); + return msg; + } + + /// Process first flight of the messages from the server. + /// Read server hello message. If TLS 1.3 is chosen in server hello + /// return. 
For TLS 1.2 continue and read certificate, key_exchange
/// eventual certificate request and hello done messages.
fn readServerFlight1(h: *HandshakeT) !void {
    // Allowed next message types; updated after each message to enforce
    // the valid TLS handshake message ordering.
    var handshake_states: []const proto.Handshake = &.{.server_hello};

    while (true) {
        var d = try h.rec_rdr.nextDecoder();
        try d.expectContentType(.handshake);

        // The whole plaintext record is hashed into the transcript up
        // front (all messages in it are handshake messages).
        h.transcript.update(d.payload);

        // Multiple handshake messages can be packed in a single tls record.
        while (!d.eof()) {
            const handshake_type = try d.decode(proto.Handshake);

            const length = try d.decode(u24);
            // A handshake message fragmented across records is not
            // supported by this reader.
            if (length > cipher.max_cleartext_len)
                return error.TlsUnsupportedFragmentedHandshakeMessage;

            // Reject any message type not allowed in the current state.
            brk: {
                for (handshake_states) |state|
                    if (state == handshake_type) break :brk;
                return error.TlsUnexpectedMessage;
            }
            switch (handshake_type) {
                .server_hello => { // server hello, ref: https://datatracker.ietf.org/doc/html/rfc5246#section-7.4.1.3
                    try h.parseServerHello(&d, length);
                    if (h.tls_version == .tls_1_3) {
                        // TLS 1.3 packs nothing after server_hello in
                        // this record; the rest of the flight is encrypted.
                        if (!d.eof()) return error.TlsIllegalParameter;
                        return; // end of tls 1.3 server flight 1
                    }
                    // When certificate verification is skipped the server
                    // may omit messages, so allow later states too.
                    handshake_states = if (h.cert.skip_verify)
                        &.{ .certificate, .server_key_exchange, .server_hello_done }
                    else
                        &.{.certificate};
                },
                .certificate => {
                    try h.cert.parseCertificate(&d, h.tls_version);
                    // Plain RSA key exchange has no server_key_exchange.
                    handshake_states = if (h.cipher_suite.keyExchange() == .rsa)
                        &.{.server_hello_done}
                    else
                        &.{.server_key_exchange};
                },
                .server_key_exchange => {
                    try h.parseServerKeyExchange(&d);
                    handshake_states = &.{ .certificate_request, .server_hello_done };
                },
                .certificate_request => {
                    // Only the fact that a client certificate was requested
                    // is recorded; the request body is skipped.
                    h.client_certificate_requested = true;
                    try d.skip(length);
                    handshake_states = &.{.server_hello_done};
                },
                .server_hello_done => {
                    if (length != 0) return error.TlsIllegalParameter;
                    return;
                },
                else => return error.TlsUnexpectedMessage,
            }
        }
    }
}

/// Parse server hello message.
// Parses the ServerHello body (`length` bytes already framed by the
// caller) and records: server random, session id (skipped), cipher
// suite, TLS version and — for TLS 1.3 — the server key share.
fn parseServerHello(h: *HandshakeT, d: *record.Decoder, length: u24) !void {
    // Legacy version field is always 0x0303 (TLS 1.2); the real version
    // is carried in the supported_versions extension below.
    if (try d.decode(proto.Version) != proto.Version.tls_1_2)
        return error.TlsBadVersion;
    h.server_random = try d.array(32);
    // HelloRetryRequest is signalled via a magic server random value and
    // is not supported by this client.
    if (isServerHelloRetryRequest(&h.server_random))
        return error.TlsServerHelloRetryRequest;

    const session_id_len = try d.decode(u8);
    if (session_id_len > 32) return error.TlsIllegalParameter;
    try d.skip(session_id_len);

    h.cipher_suite = try d.decode(CipherSuite);
    try h.cipher_suite.validate();
    try d.skip(1); // skip compression method

    // Extensions block exists only if the message is longer than the
    // fixed fields parsed above.
    const extensions_present = length > 2 + 32 + 1 + session_id_len + 2 + 1;
    if (extensions_present) {
        const exs_len = try d.decode(u16);
        var l: usize = 0;
        while (l < exs_len) {
            const typ = try d.decode(proto.Extension);
            const len = try d.decode(u16);
            // 4 = extension type (u16) + extension length (u16).
            defer l += len + 4;

            switch (typ) {
                .supported_versions => {
                    switch (try d.decode(proto.Version)) {
                        .tls_1_2, .tls_1_3 => |v| h.tls_version = v,
                        else => return error.TlsIllegalParameter,
                    }
                    // NOTE(review): length is validated after the version
                    // bytes are consumed; on mismatch an error is returned
                    // either way, but the decoder position is then unused.
                    if (len != 2) return error.TlsIllegalParameter;
                },
                .key_share => {
                    h.named_group = try d.decode(proto.NamedGroup);
                    // Copy the server public key into the handshake-owned
                    // buffer; decoder memory is reused per record.
                    h.server_pub_key = dupe(&h.server_pub_key_buf, try d.slice(try d.decode(u16)));
                    // 4 = named group (u16) + key length (u16).
                    if (len != h.server_pub_key.len + 4) return error.TlsIllegalParameter;
                },
                else => {
                    try d.skip(len);
                },
            }
        }
    }
}

// Returns true if the 32-byte server random equals the HelloRetryRequest
// sentinel value defined by TLS 1.3.
fn isServerHelloRetryRequest(server_random: []const u8) bool {
    // Ref: https://datatracker.ietf.org/doc/html/rfc8446#section-4.1.3
    const hello_retry_request_magic = [32]u8{
        0xCF, 0x21, 0xAD, 0x74, 0xE5, 0x9A, 0x61, 0x11, 0xBE, 0x1D, 0x8C, 0x02, 0x1E, 0x65, 0xB8, 0x91,
        0xC2, 0xA2, 0x11, 0x16, 0x7A, 0xBB, 0x8C, 0x5E, 0x07, 0x9E, 0x09, 0xE2, 0xC8, 0xA8, 0x33, 0x9C,
    };
    return std.mem.eql(u8, server_random, &hello_retry_request_magic);
}

// Parses the TLS 1.2 ServerKeyExchange: curve type, named group, server
// public key, and the server's signature over the key-exchange params
// (signature verification happens later, in verifyCertificateSignatureTls12).
fn parseServerKeyExchange(h: *HandshakeT, d: *record.Decoder) !void {
    const curve_type = try d.decode(proto.Curve);
    h.named_group = try d.decode(proto.NamedGroup);
h.server_pub_key = dupe(&h.server_pub_key_buf, try d.slice(try d.decode(u8))); + h.cert.signature_scheme = try d.decode(proto.SignatureScheme); + h.cert.signature = dupe(&h.cert.signature_buf, try d.slice(try d.decode(u16))); + if (curve_type != .named_curve) return error.TlsIllegalParameter; + } + + /// Read encrypted part (after server hello) of the server first flight + /// for TLS 1.3: change cipher spec, eventual certificate request, + /// certificate, certificate verify and handshake finished messages. + fn readEncryptedServerFlight1(h: *HandshakeT) !void { + var cleartext_buf = h.buffer; + var cleartext_buf_head: usize = 0; + var cleartext_buf_tail: usize = 0; + var handshake_states: []const proto.Handshake = &.{.encrypted_extensions}; + + outer: while (true) { + // wrapped record decoder + const rec = (try h.rec_rdr.next() orelse return error.EndOfStream); + if (rec.protocol_version != .tls_1_2) return error.TlsBadVersion; + switch (rec.content_type) { + .change_cipher_spec => {}, + .application_data => { + const content_type, const cleartext = try h.cipher.decrypt( + cleartext_buf[cleartext_buf_tail..], + rec, + ); + cleartext_buf_tail += cleartext.len; + if (cleartext_buf_tail > cleartext_buf.len) return error.TlsRecordOverflow; + + var d = record.Decoder.init(content_type, cleartext_buf[cleartext_buf_head..cleartext_buf_tail]); + try d.expectContentType(.handshake); + while (!d.eof()) { + const start_idx = d.idx; + const handshake_type = try d.decode(proto.Handshake); + const length = try d.decode(u24); + + // std.debug.print("handshake loop: {} {} {} {}\n", .{ handshake_type, length, d.payload.len, d.idx }); + if (length > cipher.max_cleartext_len) + return error.TlsUnsupportedFragmentedHandshakeMessage; + if (length > d.rest().len) + continue :outer; // fragmented handshake into multiple records + + defer { + const handshake_payload = d.payload[start_idx..d.idx]; + h.transcript.update(handshake_payload); + cleartext_buf_head += handshake_payload.len; + 
} + + brk: { + for (handshake_states) |state| + if (state == handshake_type) break :brk; + return error.TlsUnexpectedMessage; + } + switch (handshake_type) { + .encrypted_extensions => { + try d.skip(length); + handshake_states = if (h.cert.skip_verify) + &.{ .certificate_request, .certificate, .finished } + else + &.{ .certificate_request, .certificate }; + }, + .certificate_request => { + h.client_certificate_requested = true; + try d.skip(length); + handshake_states = if (h.cert.skip_verify) + &.{ .certificate, .finished } + else + &.{.certificate}; + }, + .certificate => { + try h.cert.parseCertificate(&d, h.tls_version); + handshake_states = &.{.certificate_verify}; + }, + .certificate_verify => { + try h.cert.parseCertificateVerify(&d); + try h.cert.verifySignature(h.transcript.serverCertificateVerify()); + handshake_states = &.{.finished}; + }, + .finished => { + const actual = try d.slice(length); + var buf: [Transcript.max_mac_length]u8 = undefined; + const expected = h.transcript.serverFinishedTls13(&buf); + if (!mem.eql(u8, expected, actual)) + return error.TlsDecryptError; + return; + }, + else => return error.TlsUnexpectedMessage, + } + } + cleartext_buf_head = 0; + cleartext_buf_tail = 0; + }, + else => return error.TlsUnexpectedMessage, + } + } + } + + fn verifyCertificateSignatureTls12(h: *HandshakeT) !void { + if (h.cipher_suite.keyExchange() != .ecdhe) return; + const verify_bytes = brk: { + var w = record.Writer{ .buf = h.buffer }; + try w.write(&h.client_random); + try w.write(&h.server_random); + try w.writeEnum(proto.Curve.named_curve); + try w.writeEnum(h.named_group.?); + try w.writeInt(@as(u8, @intCast(h.server_pub_key.len))); + try w.write(h.server_pub_key); + break :brk w.getWritten(); + }; + try h.cert.verifySignature(verify_bytes); + } + + /// Create client key exchange, change cipher spec and handshake + /// finished messages for tls 1.2. 
+ /// If client certificate is requested also adds client certificate and + /// certificate verify messages. + fn makeClientFlight2Tls12(h: *HandshakeT, auth: ?CertKeyPair) ![]const u8 { + var w = record.Writer{ .buf = h.buffer }; + var cert_builder: ?CertificateBuilder = null; + + // Client certificate message + if (h.client_certificate_requested) { + if (auth) |a| { + const cb = h.certificateBuilder(a); + cert_builder = cb; + const client_certificate = try cb.makeCertificate(w.getPayload()); + h.transcript.update(client_certificate); + try w.advanceRecord(.handshake, client_certificate.len); + } else { + const empty_certificate = &record.handshakeHeader(.certificate, 3) ++ [_]u8{ 0, 0, 0 }; + h.transcript.update(empty_certificate); + try w.writeRecord(.handshake, empty_certificate); + } + } + + // Client key exchange message + { + const key_exchange = try h.makeClientKeyExchange(w.getPayload()); + h.transcript.update(key_exchange); + try w.advanceRecord(.handshake, key_exchange.len); + } + + // Client certificate verify message + if (cert_builder) |cb| { + const certificate_verify = try cb.makeCertificateVerify(w.getPayload()); + h.transcript.update(certificate_verify); + try w.advanceRecord(.handshake, certificate_verify.len); + } + + // Client change cipher spec message + try w.writeRecord(.change_cipher_spec, &[_]u8{1}); + + // Client handshake finished message + { + const client_finished = &record.handshakeHeader(.finished, 12) ++ + h.transcript.clientFinishedTls12(&h.master_secret); + h.transcript.update(client_finished); + try h.writeEncrypted(&w, client_finished); + } + + return w.getWritten(); + } + + /// Create client change cipher spec and handshake finished messages for + /// tls 1.3. + /// If the client certificate is requested by the server and client is + /// configured with certificates and private key then client certificate + /// and client certificate verify messages are also created. 
If the + /// server has requested certificate but the client is not configured + /// empty certificate message is sent, as is required by rfc. + fn makeClientFlight2Tls13(h: *HandshakeT, auth: ?CertKeyPair) ![]const u8 { + var w = record.Writer{ .buf = h.buffer }; + + // Client change cipher spec message + try w.writeRecord(.change_cipher_spec, &[_]u8{1}); + + if (h.client_certificate_requested) { + if (auth) |a| { + const cb = h.certificateBuilder(a); + { + const certificate = try cb.makeCertificate(w.getPayload()); + h.transcript.update(certificate); + try h.writeEncrypted(&w, certificate); + } + { + const certificate_verify = try cb.makeCertificateVerify(w.getPayload()); + h.transcript.update(certificate_verify); + try h.writeEncrypted(&w, certificate_verify); + } + } else { + // Empty certificate message and no certificate verify message + const empty_certificate = &record.handshakeHeader(.certificate, 4) ++ [_]u8{ 0, 0, 0, 0 }; + h.transcript.update(empty_certificate); + try h.writeEncrypted(&w, empty_certificate); + } + } + + // Client handshake finished message + { + const client_finished = try h.makeClientFinishedTls13(w.getPayload()); + h.transcript.update(client_finished); + try h.writeEncrypted(&w, client_finished); + } + + return w.getWritten(); + } + + fn certificateBuilder(h: *HandshakeT, auth: CertKeyPair) CertificateBuilder { + return .{ + .bundle = auth.bundle, + .key = auth.key, + .transcript = &h.transcript, + .tls_version = h.tls_version, + .side = .client, + }; + } + + fn makeClientFinishedTls13(h: *HandshakeT, buf: []u8) ![]const u8 { + var w = record.Writer{ .buf = buf }; + const verify_data = h.transcript.clientFinishedTls13(w.getHandshakePayload()); + try w.advanceHandshake(.finished, verify_data.len); + return w.getWritten(); + } + + fn makeClientKeyExchange(h: *HandshakeT, buf: []u8) ![]const u8 { + var w = record.Writer{ .buf = buf }; + if (h.named_group) |named_group| { + const key = try h.dh_kp.publicKey(named_group); + try 
w.writeHandshakeHeader(.client_key_exchange, 1 + key.len);
        // ECDHE: opaque point <1..2^8-1>, length prefix is one byte.
        try w.writeInt(@as(u8, @intCast(key.len)));
        try w.write(key);
    } else {
        // Plain RSA key exchange: pre-master secret encrypted with the
        // server certificate's public key, two-byte length prefix.
        const key = try h.rsa_secret.encrypted(h.cert.pub_key_algo, h.cert.pub_key);
        try w.writeHandshakeHeader(.client_key_exchange, 2 + key.len);
        try w.writeInt(@as(u16, @intCast(key.len)));
        try w.write(key);
    }
    return w.getWritten();
}

// Reads and validates the TLS 1.2 server flight 2:
// ChangeCipherSpec followed by the encrypted Finished message.
fn readServerFlight2(h: *HandshakeT) !void {
    // Read server change cipher spec message.
    {
        var d = try h.rec_rdr.nextDecoder();
        try d.expectContentType(.change_cipher_spec);
    }
    // Read encrypted server handshake finished message. Verify that the
    // content of the server finished message is based on the transcript
    // hash and master secret.
    {
        const content_type, const server_finished =
            try h.rec_rdr.nextDecrypt(&h.cipher) orelse return error.EndOfStream;
        if (content_type != .handshake)
            return error.TlsUnexpectedMessage;
        // Expected = handshake header (finished, 12 bytes of verify data)
        // ++ PRF output over the transcript.
        const expected = record.handshakeHeader(.finished, 12) ++ h.transcript.serverFinishedTls12(&h.master_secret);
        if (!mem.eql(u8, server_finished, &expected))
            return error.TlsBadRecordMac;
    }
}

/// Write encrypted handshake message into `w`.
/// Encrypts `cleartext` with the current handshake cipher into the
/// writer's free space and advances the writer position past it.
fn writeEncrypted(h: *HandshakeT, w: *record.Writer, cleartext: []const u8) !void {
    const ciphertext = try h.cipher.encrypt(w.getFree(), .handshake, cleartext);
    w.pos += ciphertext.len;
}

// Copy handshake parameters to opt.diagnostic (no-op when the caller did
// not provide a Diagnostic). Called via defer in handshake(), so it runs
// on both success and failure paths.
fn updateDiagnostic(h: *HandshakeT, opt: Options) void {
    if (opt.diagnostic) |d| {
        d.tls_version = h.tls_version;
        d.cipher_suite_tag = h.cipher_suite;
        // 0x0000 marks "no named group" (plain RSA key exchange).
        d.named_group = h.named_group orelse @as(proto.NamedGroup, @enumFromInt(0x0000));
        d.signature_scheme = h.cert.signature_scheme;
        if (opt.auth) |a|
            d.client_signature_scheme = a.key.signature_scheme;
    }
}
};
}

// Pre-master secret for the plain RSA key exchange (TLS 1.2): two
// version bytes {0x03, 0x03} followed by 46 random bytes.
const RsaSecret = struct {
    secret: [48]u8,

    fn init(rand: [46]u8) RsaSecret {
        return .{ .secret = [_]u8{ 0x03, 0x03 } ++ rand };
    }

    // Pre master secret encrypted with certificate
public key. + inline fn encrypted( + self: RsaSecret, + cert_pub_key_algo: Certificate.Parsed.PubKeyAlgo, + cert_pub_key: []const u8, + ) ![]const u8 { + if (cert_pub_key_algo != .rsaEncryption) return error.TlsBadSignatureScheme; + const pk = try rsa.PublicKey.fromDer(cert_pub_key); + var out: [512]u8 = undefined; + return try pk.encryptPkcsv1_5(&self.secret, &out); + } +}; + +const testing = std.testing; +const data12 = @import("testdata/tls12.zig"); +const data13 = @import("testdata/tls13.zig"); +const testu = @import("testu.zig"); + +fn testReader(data: []const u8) record.Reader(std.io.FixedBufferStream([]const u8)) { + return record.reader(std.io.fixedBufferStream(data)); +} +const TestHandshake = Handshake(std.io.FixedBufferStream([]const u8)); + +test "parse tls 1.2 server hello" { + var h = brk: { + var buffer: [1024]u8 = undefined; + var rec_rdr = testReader(&data12.server_hello_responses); + break :brk TestHandshake.init(&buffer, &rec_rdr); + }; + + // Set to known instead of random + h.client_random = data12.client_random; + h.dh_kp.x25519_kp.secret_key = data12.client_secret; + + // Parse server hello, certificate and key exchange messages. 
+ // Read cipher suite, named group, signature scheme, server random certificate public key + // Verify host name, signature + // Calculate key material + h.cert = .{ .host = "example.ulfheim.net", .skip_verify = true, .root_ca = .{} }; + try h.readServerFlight1(); + try testing.expectEqual(.ECDHE_RSA_WITH_AES_128_CBC_SHA, h.cipher_suite); + try testing.expectEqual(.x25519, h.named_group.?); + try testing.expectEqual(.rsa_pkcs1_sha256, h.cert.signature_scheme); + try testing.expectEqualSlices(u8, &data12.server_random, &h.server_random); + try testing.expectEqualSlices(u8, &data12.server_pub_key, h.server_pub_key); + try testing.expectEqualSlices(u8, &data12.signature, h.cert.signature); + try testing.expectEqualSlices(u8, &data12.cert_pub_key, h.cert.pub_key); + + try h.verifyCertificateSignatureTls12(); + try h.generateKeyMaterial(null); + + try testing.expectEqualSlices(u8, &data12.key_material, h.key_material[0..data12.key_material.len]); +} + +test "verify google.com certificate" { + var h = brk: { + var buffer: [1024]u8 = undefined; + var rec_rdr = testReader(@embedFile("testdata/google.com/server_hello")); + break :brk TestHandshake.init(&buffer, &rec_rdr); + }; + h.client_random = @embedFile("testdata/google.com/client_random").*; + + var ca_bundle: Certificate.Bundle = .{}; + try ca_bundle.rescan(testing.allocator); + defer ca_bundle.deinit(testing.allocator); + + h.cert = .{ .host = "google.com", .skip_verify = true, .root_ca = .{}, .now_sec = 1714846451 }; + try h.readServerFlight1(); + try h.verifyCertificateSignatureTls12(); +} + +test "parse tls 1.3 server hello" { + var rec_rdr = testReader(&data13.server_hello); + var d = (try rec_rdr.nextDecoder()); + + const handshake_type = try d.decode(proto.Handshake); + const length = try d.decode(u24); + try testing.expectEqual(0x000076, length); + try testing.expectEqual(.server_hello, handshake_type); + + var h = TestHandshake.init(undefined, undefined); + try h.parseServerHello(&d, length); + + try 
testing.expectEqual(.AES_256_GCM_SHA384, h.cipher_suite); + try testing.expectEqualSlices(u8, &data13.server_random, &h.server_random); + try testing.expectEqual(.tls_1_3, h.tls_version); + try testing.expectEqual(.x25519, h.named_group); + try testing.expectEqualSlices(u8, &data13.server_pub_key, h.server_pub_key); +} + +test "init tls 1.3 handshake cipher" { + const cipher_suite_tag: CipherSuite = .AES_256_GCM_SHA384; + + var transcript = Transcript{}; + transcript.use(cipher_suite_tag.hash()); + transcript.update(data13.client_hello[record.header_len..]); + transcript.update(data13.server_hello[record.header_len..]); + + var dh_kp = DhKeyPair{ + .x25519_kp = .{ + .public_key = data13.client_public_key, + .secret_key = data13.client_private_key, + }, + }; + const shared_key = try dh_kp.sharedKey(.x25519, &data13.server_pub_key); + try testing.expectEqualSlices(u8, &data13.shared_key, shared_key); + + const cph = try Cipher.initTls13(cipher_suite_tag, transcript.handshakeSecret(shared_key), .client); + + const c = &cph.AES_256_GCM_SHA384; + try testing.expectEqualSlices(u8, &data13.server_handshake_key, &c.decrypt_key); + try testing.expectEqualSlices(u8, &data13.client_handshake_key, &c.encrypt_key); + try testing.expectEqualSlices(u8, &data13.server_handshake_iv, &c.decrypt_iv); + try testing.expectEqualSlices(u8, &data13.client_handshake_iv, &c.encrypt_iv); +} + +fn initExampleHandshake(h: *TestHandshake) !void { + h.cipher_suite = .AES_256_GCM_SHA384; + h.transcript.use(h.cipher_suite.hash()); + h.transcript.update(data13.client_hello[record.header_len..]); + h.transcript.update(data13.server_hello[record.header_len..]); + h.cipher = try Cipher.initTls13(h.cipher_suite, h.transcript.handshakeSecret(&data13.shared_key), .client); + h.tls_version = .tls_1_3; + h.cert.now_sec = 1714846451; + h.server_pub_key = &data13.server_pub_key; +} + +test "tls 1.3 decrypt wrapped record" { + var cph = brk: { + var h = TestHandshake.init(undefined, undefined); + try 
initExampleHandshake(&h); + break :brk h.cipher; + }; + + var cleartext_buf: [1024]u8 = undefined; + { + const rec = record.Record.init(&data13.server_encrypted_extensions_wrapped); + + const content_type, const cleartext = try cph.decrypt(&cleartext_buf, rec); + try testing.expectEqual(.handshake, content_type); + try testing.expectEqualSlices(u8, &data13.server_encrypted_extensions, cleartext); + } + { + const rec = record.Record.init(&data13.server_certificate_wrapped); + + const content_type, const cleartext = try cph.decrypt(&cleartext_buf, rec); + try testing.expectEqual(.handshake, content_type); + try testing.expectEqualSlices(u8, &data13.server_certificate, cleartext); + } +} + +test "tls 1.3 process server flight" { + var buffer: [1024]u8 = undefined; + var h = brk: { + var rec_rdr = testReader(&data13.server_flight); + break :brk TestHandshake.init(&buffer, &rec_rdr); + }; + + try initExampleHandshake(&h); + h.cert = .{ .host = "example.ulfheim.net", .skip_verify = true, .root_ca = .{} }; + try h.readEncryptedServerFlight1(); + + { // application cipher keys calculation + try testing.expectEqualSlices(u8, &data13.handshake_hash, &h.transcript.sha384.hash.peek()); + + var cph = try Cipher.initTls13(h.cipher_suite, h.transcript.applicationSecret(), .client); + const c = &cph.AES_256_GCM_SHA384; + try testing.expectEqualSlices(u8, &data13.server_application_key, &c.decrypt_key); + try testing.expectEqualSlices(u8, &data13.client_application_key, &c.encrypt_key); + try testing.expectEqualSlices(u8, &data13.server_application_iv, &c.decrypt_iv); + try testing.expectEqualSlices(u8, &data13.client_application_iv, &c.encrypt_iv); + + const encrypted = try cph.encrypt(&buffer, .application_data, "ping"); + try testing.expectEqualSlices(u8, &data13.client_ping_wrapped, encrypted); + } + { // client finished message + var buf: [4 + Transcript.max_mac_length]u8 = undefined; + const client_finished = try h.makeClientFinishedTls13(&buf); + try 
testing.expectEqualSlices(u8, &data13.client_finished_verify_data, client_finished[4..]); + const encrypted = try h.cipher.encrypt(&buffer, .handshake, client_finished); + try testing.expectEqualSlices(u8, &data13.client_finished_wrapped, encrypted); + } +} + +test "create client hello" { + var h = brk: { + var buffer: [1024]u8 = undefined; + var h = TestHandshake.init(&buffer, undefined); + h.client_random = testu.hexToBytes( + \\ 00 01 02 03 04 05 06 07 08 09 0a 0b 0c 0d 0e 0f 10 11 12 13 14 15 16 17 18 19 1a 1b 1c 1d 1e 1f + ); + break :brk h; + }; + + const actual = try h.makeClientHello(.{ + .host = "google.com", + .root_ca = .{}, + .cipher_suites = &[_]CipherSuite{CipherSuite.ECDHE_ECDSA_WITH_AES_128_GCM_SHA256}, + .named_groups = &[_]proto.NamedGroup{ .x25519, .secp256r1, .secp384r1 }, + }); + + const expected = testu.hexToBytes( + "16 03 03 00 6d " ++ // record header + "01 00 00 69 " ++ // handshake header + "03 03 " ++ // protocol version + "00 01 02 03 04 05 06 07 08 09 0a 0b 0c 0d 0e 0f 10 11 12 13 14 15 16 17 18 19 1a 1b 1c 1d 1e 1f " ++ // client random + "00 " ++ // no session id + "00 02 c0 2b " ++ // cipher suites + "01 00 " ++ // compression methods + "00 3e " ++ // extensions length + "00 2b 00 03 02 03 03 " ++ // supported versions extension + "00 0d 00 14 00 12 04 03 05 03 08 04 08 05 08 06 08 07 02 01 04 01 05 01 " ++ // signature algorithms extension + "00 0a 00 08 00 06 00 1d 00 17 00 18 " ++ // named groups extension + "00 00 00 0f 00 0d 00 00 0a 67 6f 6f 67 6c 65 2e 63 6f 6d ", // server name extension + ); + try testing.expectEqualSlices(u8, &expected, actual); +} + +test "handshake verify server finished message" { + var buffer: [1024]u8 = undefined; + var rec_rdr = testReader(&data12.server_handshake_finished_msgs); + var h = TestHandshake.init(&buffer, &rec_rdr); + + h.cipher_suite = .ECDHE_ECDSA_WITH_AES_128_CBC_SHA; + h.master_secret = data12.master_secret; + + // add handshake messages to the transcript + for 
(data12.handshake_messages) |msg| { + h.transcript.update(msg[record.header_len..]); + } + + // expect verify data + const client_finished = h.transcript.clientFinishedTls12(&h.master_secret); + try testing.expectEqualSlices(u8, &data12.client_finished, &record.handshakeHeader(.finished, 12) ++ client_finished); + + // init client with prepared key_material + h.cipher = try Cipher.initTls12(.ECDHE_RSA_WITH_AES_128_CBC_SHA, &data12.key_material, .client); + + // check that server verify data matches the value calculated from hashes of all handshake messages + h.transcript.update(&data12.client_finished); + try h.readServerFlight2(); +} diff --git a/src/http/async/tls.zig/handshake_common.zig b/src/http/async/tls.zig/handshake_common.zig new file mode 100644 index 00000000..178a3cea --- /dev/null +++ b/src/http/async/tls.zig/handshake_common.zig @@ -0,0 +1,448 @@ +const std = @import("std"); +const assert = std.debug.assert; +const mem = std.mem; +const crypto = std.crypto; +const Certificate = crypto.Certificate; + +const Transcript = @import("transcript.zig").Transcript; +const PrivateKey = @import("PrivateKey.zig"); +const record = @import("record.zig"); +const rsa = @import("rsa/rsa.zig"); +const proto = @import("protocol.zig"); + +const X25519 = crypto.dh.X25519; +const EcdsaP256Sha256 = crypto.sign.ecdsa.EcdsaP256Sha256; +const EcdsaP384Sha384 = crypto.sign.ecdsa.EcdsaP384Sha384; +const Kyber768 = crypto.kem.kyber_d00.Kyber768; + +pub const supported_signature_algorithms = &[_]proto.SignatureScheme{ + .ecdsa_secp256r1_sha256, + .ecdsa_secp384r1_sha384, + .rsa_pss_rsae_sha256, + .rsa_pss_rsae_sha384, + .rsa_pss_rsae_sha512, + .ed25519, + .rsa_pkcs1_sha1, + .rsa_pkcs1_sha256, + .rsa_pkcs1_sha384, +}; + +pub const CertKeyPair = struct { + /// A chain of one or more certificates, leaf first.
+ /// + /// Each X.509 certificate contains the public key of a key pair, extra + /// information (the name of the holder, the name of an issuer of the + /// certificate, validity time spans) and a signature generated using the + /// private key of the issuer of the certificate. + /// + /// All certificates from the bundle are sent to the other side when creating + /// Certificate tls message. + /// + /// Leaf certificate and private key are used to create signature for + /// CertificateVerify tls message. + bundle: Certificate.Bundle, + + /// Private key corresponding to the public key in leaf certificate from the + /// bundle. + key: PrivateKey, + + pub fn load( + allocator: std.mem.Allocator, + dir: std.fs.Dir, + cert_path: []const u8, + key_path: []const u8, + ) !CertKeyPair { + var bundle: Certificate.Bundle = .{}; + try bundle.addCertsFromFilePath(allocator, dir, cert_path); + + const key_file = try dir.openFile(key_path, .{}); + defer key_file.close(); + const key = try PrivateKey.fromFile(allocator, key_file); + + return .{ .bundle = bundle, .key = key }; + } + + pub fn deinit(c: *CertKeyPair, allocator: std.mem.Allocator) void { + c.bundle.deinit(allocator); + } +}; + +pub const CertBundle = struct { + // A chain of one or more certificates. + // + // They are used to verify that certificate chain sent by the other side + // forms valid trust chain.
+ bundle: Certificate.Bundle = .{}, + + pub fn fromFile(allocator: std.mem.Allocator, dir: std.fs.Dir, path: []const u8) !CertBundle { + var bundle: Certificate.Bundle = .{}; + try bundle.addCertsFromFilePath(allocator, dir, path); + return .{ .bundle = bundle }; + } + + pub fn fromSystem(allocator: std.mem.Allocator) !CertBundle { + var bundle: Certificate.Bundle = .{}; + try bundle.rescan(allocator); + return .{ .bundle = bundle }; + } + + pub fn deinit(cb: *CertBundle, allocator: std.mem.Allocator) void { + cb.bundle.deinit(allocator); + } +}; + +pub const CertificateBuilder = struct { + bundle: Certificate.Bundle, + key: PrivateKey, + transcript: *Transcript, + tls_version: proto.Version = .tls_1_3, + side: proto.Side = .client, + + pub fn makeCertificate(h: CertificateBuilder, buf: []u8) ![]const u8 { + var w = record.Writer{ .buf = buf }; + const certs = h.bundle.bytes.items; + const certs_count = h.bundle.map.size; + + // Differences between tls 1.3 and 1.2 + // TLS 1.3 has request context in header and extensions for each certificate. + // Here we use empty length for each field. + // TLS 1.2 doesn't have these two fields.
+ const request_context, const extensions = if (h.tls_version == .tls_1_3) + .{ &[_]u8{0}, &[_]u8{ 0, 0 } } + else + .{ &[_]u8{}, &[_]u8{} }; + const certs_len = certs.len + (3 + extensions.len) * certs_count; + + // Write handshake header + try w.writeHandshakeHeader(.certificate, certs_len + request_context.len + 3); + try w.write(request_context); + try w.writeInt(@as(u24, @intCast(certs_len))); + + // Write each certificate + var index: u32 = 0; + while (index < certs.len) { + const e = try Certificate.der.Element.parse(certs, index); + const cert = certs[index..e.slice.end]; + try w.writeInt(@as(u24, @intCast(cert.len))); // certificate length + try w.write(cert); // certificate + try w.write(extensions); // certificate extensions + index = e.slice.end; + } + return w.getWritten(); + } + + pub fn makeCertificateVerify(h: CertificateBuilder, buf: []u8) ![]const u8 { + var w = record.Writer{ .buf = buf }; + const signature, const signature_scheme = try h.createSignature(); + try w.writeHandshakeHeader(.certificate_verify, signature.len + 4); + try w.writeEnum(signature_scheme); + try w.writeInt(@as(u16, @intCast(signature.len))); + try w.write(signature); + return w.getWritten(); + } + + /// Creates signature for client certificate signature message. + /// Returns signature bytes and signature scheme. 
+ inline fn createSignature(h: CertificateBuilder) !struct { []const u8, proto.SignatureScheme } { + switch (h.key.signature_scheme) { + inline .ecdsa_secp256r1_sha256, + .ecdsa_secp384r1_sha384, + => |comptime_scheme| { + const Ecdsa = SchemeEcdsa(comptime_scheme); + const key = h.key.key.ecdsa; + const key_len = Ecdsa.SecretKey.encoded_length; + if (key.len < key_len) return error.InvalidEncoding; + const secret_key = try Ecdsa.SecretKey.fromBytes(key[0..key_len].*); + const key_pair = try Ecdsa.KeyPair.fromSecretKey(secret_key); + var signer = try key_pair.signer(null); + h.setSignatureVerifyBytes(&signer); + const signature = try signer.finalize(); + var buf: [Ecdsa.Signature.der_encoded_length_max]u8 = undefined; + return .{ signature.toDer(&buf), comptime_scheme }; + }, + inline .rsa_pss_rsae_sha256, + .rsa_pss_rsae_sha384, + .rsa_pss_rsae_sha512, + => |comptime_scheme| { + const Hash = SchemeHash(comptime_scheme); + var signer = try h.key.key.rsa.signerOaep(Hash, null); + h.setSignatureVerifyBytes(&signer); + var buf: [512]u8 = undefined; + const signature = try signer.finalize(&buf); + return .{ signature.bytes, comptime_scheme }; + }, + else => return error.TlsUnknownSignatureScheme, + } + } + + fn setSignatureVerifyBytes(h: CertificateBuilder, signer: anytype) void { + if (h.tls_version == .tls_1_2) { + // tls 1.2 signature uses current transcript hash value. + // ref: https://datatracker.ietf.org/doc/html/rfc5246.html#section-7.4.8 + const Hash = @TypeOf(signer.h); + signer.h = h.transcript.hash(Hash); + } else { + // tls 1.3 signature is computed over concatenation of 64 spaces, + // context, separator and content. 
+ // ref: https://datatracker.ietf.org/doc/html/rfc8446#section-4.4.3 + if (h.side == .server) { + signer.update(h.transcript.serverCertificateVerify()); + } else { + signer.update(h.transcript.clientCertificateVerify()); + } + } + } + + fn SchemeEcdsa(comptime scheme: proto.SignatureScheme) type { + return switch (scheme) { + .ecdsa_secp256r1_sha256 => EcdsaP256Sha256, + .ecdsa_secp384r1_sha384 => EcdsaP384Sha384, + else => unreachable, + }; + } +}; + +pub const CertificateParser = struct { + pub_key_algo: Certificate.Parsed.PubKeyAlgo = undefined, + pub_key_buf: [600]u8 = undefined, + pub_key: []const u8 = undefined, + + signature_scheme: proto.SignatureScheme = @enumFromInt(0), + signature_buf: [1024]u8 = undefined, + signature: []const u8 = undefined, + + root_ca: Certificate.Bundle, + host: []const u8, + skip_verify: bool = false, + now_sec: i64 = 0, + + pub fn parseCertificate(h: *CertificateParser, d: *record.Decoder, tls_version: proto.Version) !void { + if (h.now_sec == 0) { + h.now_sec = std.time.timestamp(); + } + if (tls_version == .tls_1_3) { + const request_context = try d.decode(u8); + if (request_context != 0) return error.TlsIllegalParameter; + } + + var trust_chain_established = false; + var last_cert: ?Certificate.Parsed = null; + const certs_len = try d.decode(u24); + const start_idx = d.idx; + while (d.idx - start_idx < certs_len) { + const cert_len = try d.decode(u24); + // std.debug.print("=> {} {} {} {}\n", .{ certs_len, d.idx, cert_len, d.payload.len }); + const cert = try d.slice(cert_len); + if (tls_version == .tls_1_3) { + // certificate extensions present in tls 1.3 + try d.skip(try d.decode(u16)); + } + if (trust_chain_established) + continue; + + const subject = try (Certificate{ .buffer = cert, .index = 0 }).parse(); + if (last_cert) |pc| { + if (pc.verify(subject, h.now_sec)) { + last_cert = subject; + } else |err| switch (err) { + error.CertificateIssuerMismatch => { + // skip certificate which is not part of the chain + continue; 
+ }, + else => return err, + } + } else { // first certificate + if (!h.skip_verify and h.host.len > 0) { + try subject.verifyHostName(h.host); + } + h.pub_key = dupe(&h.pub_key_buf, subject.pubKey()); + h.pub_key_algo = subject.pub_key_algo; + last_cert = subject; + } + if (!h.skip_verify) { + if (h.root_ca.verify(last_cert.?, h.now_sec)) |_| { + trust_chain_established = true; + } else |err| switch (err) { + error.CertificateIssuerNotFound => {}, + else => return err, + } + } + } + if (!h.skip_verify and !trust_chain_established) { + return error.CertificateIssuerNotFound; + } + } + + pub fn parseCertificateVerify(h: *CertificateParser, d: *record.Decoder) !void { + h.signature_scheme = try d.decode(proto.SignatureScheme); + h.signature = dupe(&h.signature_buf, try d.slice(try d.decode(u16))); + } + + pub fn verifySignature(h: *CertificateParser, verify_bytes: []const u8) !void { + switch (h.signature_scheme) { + inline .ecdsa_secp256r1_sha256, + .ecdsa_secp384r1_sha384, + => |comptime_scheme| { + if (h.pub_key_algo != .X9_62_id_ecPublicKey) return error.TlsBadSignatureScheme; + const cert_named_curve = h.pub_key_algo.X9_62_id_ecPublicKey; + switch (cert_named_curve) { + inline .secp384r1, .X9_62_prime256v1 => |comptime_cert_named_curve| { + const Ecdsa = SchemeEcdsaCert(comptime_scheme, comptime_cert_named_curve); + const key = try Ecdsa.PublicKey.fromSec1(h.pub_key); + const sig = try Ecdsa.Signature.fromDer(h.signature); + try sig.verify(verify_bytes, key); + }, + else => return error.TlsUnknownSignatureScheme, + } + }, + .ed25519 => { + if (h.pub_key_algo != .curveEd25519) return error.TlsBadSignatureScheme; + const Eddsa = crypto.sign.Ed25519; + if (h.signature.len != Eddsa.Signature.encoded_length) return error.InvalidEncoding; + const sig = Eddsa.Signature.fromBytes(h.signature[0..Eddsa.Signature.encoded_length].*); + if (h.pub_key.len != Eddsa.PublicKey.encoded_length) return error.InvalidEncoding; + const key = try 
Eddsa.PublicKey.fromBytes(h.pub_key[0..Eddsa.PublicKey.encoded_length].*); + try sig.verify(verify_bytes, key); + }, + inline .rsa_pss_rsae_sha256, + .rsa_pss_rsae_sha384, + .rsa_pss_rsae_sha512, + => |comptime_scheme| { + if (h.pub_key_algo != .rsaEncryption) return error.TlsBadSignatureScheme; + const Hash = SchemeHash(comptime_scheme); + const pk = try rsa.PublicKey.fromDer(h.pub_key); + const sig = rsa.Pss(Hash).Signature{ .bytes = h.signature }; + try sig.verify(verify_bytes, pk, null); + }, + inline .rsa_pkcs1_sha1, + .rsa_pkcs1_sha256, + .rsa_pkcs1_sha384, + .rsa_pkcs1_sha512, + => |comptime_scheme| { + if (h.pub_key_algo != .rsaEncryption) return error.TlsBadSignatureScheme; + const Hash = SchemeHash(comptime_scheme); + const pk = try rsa.PublicKey.fromDer(h.pub_key); + const sig = rsa.PKCS1v1_5(Hash).Signature{ .bytes = h.signature }; + try sig.verify(verify_bytes, pk); + }, + else => return error.TlsUnknownSignatureScheme, + } + } + + fn SchemeEcdsaCert(comptime scheme: proto.SignatureScheme, comptime cert_named_curve: Certificate.NamedCurve) type { + const Sha256 = crypto.hash.sha2.Sha256; + const Sha384 = crypto.hash.sha2.Sha384; + const Ecdsa = crypto.sign.ecdsa.Ecdsa; + + return switch (scheme) { + .ecdsa_secp256r1_sha256 => Ecdsa(cert_named_curve.Curve(), Sha256), + .ecdsa_secp384r1_sha384 => Ecdsa(cert_named_curve.Curve(), Sha384), + else => @compileError("bad scheme"), + }; + } +}; + +fn SchemeHash(comptime scheme: proto.SignatureScheme) type { + const Sha256 = crypto.hash.sha2.Sha256; + const Sha384 = crypto.hash.sha2.Sha384; + const Sha512 = crypto.hash.sha2.Sha512; + + return switch (scheme) { + .rsa_pkcs1_sha1 => crypto.hash.Sha1, + .rsa_pss_rsae_sha256, .rsa_pkcs1_sha256 => Sha256, + .rsa_pss_rsae_sha384, .rsa_pkcs1_sha384 => Sha384, + .rsa_pss_rsae_sha512, .rsa_pkcs1_sha512 => Sha512, + else => @compileError("bad scheme"), + }; +} + +pub fn dupe(buf: []u8, data: []const u8) []u8 { + const n = @min(data.len, buf.len); + @memcpy(buf[0..n], 
data[0..n]); + return buf[0..n]; +} + +pub const DhKeyPair = struct { + x25519_kp: X25519.KeyPair = undefined, + secp256r1_kp: EcdsaP256Sha256.KeyPair = undefined, + secp384r1_kp: EcdsaP384Sha384.KeyPair = undefined, + kyber768_kp: Kyber768.KeyPair = undefined, + + pub const seed_len = 32 + 32 + 48 + 64; + + pub fn init(seed: [seed_len]u8, named_groups: []const proto.NamedGroup) !DhKeyPair { + var kp: DhKeyPair = .{}; + for (named_groups) |ng| + switch (ng) { + .x25519 => kp.x25519_kp = try X25519.KeyPair.create(seed[0..][0..X25519.seed_length].*), + .secp256r1 => kp.secp256r1_kp = try EcdsaP256Sha256.KeyPair.create(seed[32..][0..EcdsaP256Sha256.KeyPair.seed_length].*), + .secp384r1 => kp.secp384r1_kp = try EcdsaP384Sha384.KeyPair.create(seed[32 + 32 ..][0..EcdsaP384Sha384.KeyPair.seed_length].*), + .x25519_kyber768d00 => kp.kyber768_kp = try Kyber768.KeyPair.create(seed[32 + 32 + 48 ..][0..Kyber768.seed_length].*), + else => return error.TlsIllegalParameter, + }; + return kp; + } + + pub inline fn sharedKey(self: DhKeyPair, named_group: proto.NamedGroup, server_pub_key: []const u8) ![]const u8 { + return switch (named_group) { + .x25519 => brk: { + if (server_pub_key.len != X25519.public_length) + return error.TlsIllegalParameter; + break :brk &(try X25519.scalarmult( + self.x25519_kp.secret_key, + server_pub_key[0..X25519.public_length].*, + )); + }, + .secp256r1 => brk: { + const pk = try EcdsaP256Sha256.PublicKey.fromSec1(server_pub_key); + const mul = try pk.p.mulPublic(self.secp256r1_kp.secret_key.bytes, .big); + break :brk &mul.affineCoordinates().x.toBytes(.big); + }, + .secp384r1 => brk: { + const pk = try EcdsaP384Sha384.PublicKey.fromSec1(server_pub_key); + const mul = try pk.p.mulPublic(self.secp384r1_kp.secret_key.bytes, .big); + break :brk &mul.affineCoordinates().x.toBytes(.big); + }, + .x25519_kyber768d00 => brk: { + const xksl = crypto.dh.X25519.public_length; + const hksl = xksl + Kyber768.ciphertext_length; + if (server_pub_key.len != hksl) + 
return error.TlsIllegalParameter; + + break :brk &((crypto.dh.X25519.scalarmult( + self.x25519_kp.secret_key, + server_pub_key[0..xksl].*, + ) catch return error.TlsDecryptFailure) ++ (self.kyber768_kp.secret_key.decaps( + server_pub_key[xksl..hksl], + ) catch return error.TlsDecryptFailure)); + }, + else => return error.TlsIllegalParameter, + }; + } + + // Returns 32, 65, 97 or 1216 bytes + pub inline fn publicKey(self: DhKeyPair, named_group: proto.NamedGroup) ![]const u8 { + return switch (named_group) { + .x25519 => &self.x25519_kp.public_key, + .secp256r1 => &self.secp256r1_kp.public_key.toUncompressedSec1(), + .secp384r1 => &self.secp384r1_kp.public_key.toUncompressedSec1(), + .x25519_kyber768d00 => &self.x25519_kp.public_key ++ self.kyber768_kp.public_key.toBytes(), + else => return error.TlsIllegalParameter, + }; + } +}; + +const testing = std.testing; +const testu = @import("testu.zig"); + +test "DhKeyPair.x25519" { + var seed: [DhKeyPair.seed_len]u8 = undefined; + testu.fill(&seed); + const server_pub_key = &testu.hexToBytes("3303486548531f08d91e675caf666c2dc924ac16f47a861a7f4d05919d143637"); + const expected = &testu.hexToBytes( + \\ F1 67 FB 4A 49 B2 91 77 08 29 45 A1 F7 08 5A 21 + \\ AF FE 9E 78 C2 03 9B 81 92 40 72 73 74 7A 46 1E + ); + const kp = try DhKeyPair.init(seed, &.{.x25519}); + try testing.expectEqualSlices(u8, expected, try kp.sharedKey(.x25519, server_pub_key)); +} diff --git a/src/http/async/tls.zig/handshake_server.zig b/src/http/async/tls.zig/handshake_server.zig new file mode 100644 index 00000000..c26e8c69 --- /dev/null +++ b/src/http/async/tls.zig/handshake_server.zig @@ -0,0 +1,520 @@ +const std = @import("std"); +const assert = std.debug.assert; +const crypto = std.crypto; +const mem = std.mem; +const Certificate = crypto.Certificate; + +const cipher = @import("cipher.zig"); +const Cipher = cipher.Cipher; +const CipherSuite = @import("cipher.zig").CipherSuite; +const cipher_suites = @import("cipher.zig").cipher_suites; +const 
Transcript = @import("transcript.zig").Transcript; +const record = @import("record.zig"); +const PrivateKey = @import("PrivateKey.zig"); +const proto = @import("protocol.zig"); + +const common = @import("handshake_common.zig"); +const dupe = common.dupe; +const CertificateBuilder = common.CertificateBuilder; +const CertificateParser = common.CertificateParser; +const DhKeyPair = common.DhKeyPair; +const CertBundle = common.CertBundle; +const CertKeyPair = common.CertKeyPair; + +pub const Options = struct { + /// Server authentication. If null server will not send Certificate and + /// CertificateVerify message. + auth: ?CertKeyPair, + + /// If not null server will request client certificate. If auth_type is + /// .request empty client certificate message will be accepted. + /// Client certificate will be verified with root_ca certificates. + client_auth: ?ClientAuth = null, +}; + +pub const ClientAuth = struct { + /// Set of root certificate authorities that server use when verifying + /// client certificates. + root_ca: CertBundle, + + auth_type: Type = .require, + + pub const Type = enum { + /// Client certificate will be requested during the handshake, but does + /// not require that the client send any certificates. + request, + /// Client certificate will be requested during the handshake, and client + /// has to send valid certificate. 
+ require, + }; +}; + +pub fn Handshake(comptime Stream: type) type { + const RecordReaderT = record.Reader(Stream); + return struct { + // public key len: x25519 = 32, secp256r1 = 65, secp384r1 = 97 + const max_pub_key_len = 98; + const supported_named_groups = &[_]proto.NamedGroup{ .x25519, .secp256r1, .secp384r1 }; + + server_random: [32]u8 = undefined, + client_random: [32]u8 = undefined, + legacy_session_id_buf: [32]u8 = undefined, + legacy_session_id: []u8 = "", + cipher_suite: CipherSuite = @enumFromInt(0), + signature_scheme: proto.SignatureScheme = @enumFromInt(0), + named_group: proto.NamedGroup = @enumFromInt(0), + client_pub_key_buf: [max_pub_key_len]u8 = undefined, + client_pub_key: []u8 = "", + server_pub_key_buf: [max_pub_key_len]u8 = undefined, + server_pub_key: []u8 = "", + + cipher: Cipher = undefined, + transcript: Transcript = .{}, + rec_rdr: *RecordReaderT, + buffer: []u8, + + const HandshakeT = @This(); + + pub fn init(buf: []u8, rec_rdr: *RecordReaderT) HandshakeT { + return .{ + .rec_rdr = rec_rdr, + .buffer = buf, + }; + } + + fn writeAlert(h: *HandshakeT, stream: Stream, cph: ?*Cipher, err: anyerror) !void { + if (cph) |c| { + const cleartext = proto.alertFromError(err); + const ciphertext = try c.encrypt(h.buffer, .alert, &cleartext); + stream.writeAll(ciphertext) catch {}; + } else { + const alert = record.header(.alert, 2) ++ proto.alertFromError(err); + stream.writeAll(&alert) catch {}; + } + } + + pub fn handshake(h: *HandshakeT, stream: Stream, opt: Options) !Cipher { + crypto.random.bytes(&h.server_random); + if (opt.auth) |a| { + // required signature scheme in client hello + h.signature_scheme = a.key.signature_scheme; + } + + h.readClientHello() catch |err| { + try h.writeAlert(stream, null, err); + return err; + }; + h.transcript.use(h.cipher_suite.hash()); + + const server_flight = brk: { + var w = record.Writer{ .buf = h.buffer }; + + const shared_key = h.sharedKey() catch |err| { + try h.writeAlert(stream, null, err); + 
return err; + }; + { + const hello = try h.makeServerHello(w.getFree()); + h.transcript.update(hello[record.header_len..]); + w.pos += hello.len; + } + { + const handshake_secret = h.transcript.handshakeSecret(shared_key); + h.cipher = try Cipher.initTls13(h.cipher_suite, handshake_secret, .server); + } + try w.writeRecord(.change_cipher_spec, &[_]u8{1}); + { + const encrypted_extensions = &record.handshakeHeader(.encrypted_extensions, 2) ++ [_]u8{ 0, 0 }; + h.transcript.update(encrypted_extensions); + try h.writeEncrypted(&w, encrypted_extensions); + } + if (opt.client_auth) |_| { + const certificate_request = try makeCertificateRequest(w.getPayload()); + h.transcript.update(certificate_request); + try h.writeEncrypted(&w, certificate_request); + } + if (opt.auth) |a| { + const cm = CertificateBuilder{ + .bundle = a.bundle, + .key = a.key, + .transcript = &h.transcript, + .side = .server, + }; + { + const certificate = try cm.makeCertificate(w.getPayload()); + h.transcript.update(certificate); + try h.writeEncrypted(&w, certificate); + } + { + const certificate_verify = try cm.makeCertificateVerify(w.getPayload()); + h.transcript.update(certificate_verify); + try h.writeEncrypted(&w, certificate_verify); + } + } + { + const finished = try h.makeFinished(w.getPayload()); + h.transcript.update(finished); + try h.writeEncrypted(&w, finished); + } + break :brk w.getWritten(); + }; + try stream.writeAll(server_flight); + + var app_cipher = brk: { + const application_secret = h.transcript.applicationSecret(); + break :brk try Cipher.initTls13(h.cipher_suite, application_secret, .server); + }; + + h.readClientFlight2(opt) catch |err| { + // Alert received from client + if (!mem.startsWith(u8, @errorName(err), "TlsAlert")) { + try h.writeAlert(stream, &app_cipher, err); + } + return err; + }; + return app_cipher; + } + + inline fn sharedKey(h: *HandshakeT) ![]const u8 { + var seed: [DhKeyPair.seed_len]u8 = undefined; + crypto.random.bytes(&seed); + var kp = try 
DhKeyPair.init(seed, supported_named_groups); + h.server_pub_key = dupe(&h.server_pub_key_buf, try kp.publicKey(h.named_group)); + return try kp.sharedKey(h.named_group, h.client_pub_key); + } + + fn readClientFlight2(h: *HandshakeT, opt: Options) !void { + var cleartext_buf = h.buffer; + var cleartext_buf_head: usize = 0; + var cleartext_buf_tail: usize = 0; + var handshake_state: proto.Handshake = .finished; + var cert: CertificateParser = undefined; + if (opt.client_auth) |client_auth| { + cert = .{ .root_ca = client_auth.root_ca.bundle, .host = "" }; + handshake_state = .certificate; + } + + outer: while (true) { + const rec = (try h.rec_rdr.next() orelse return error.EndOfStream); + if (rec.protocol_version != .tls_1_2 and rec.content_type != .alert) + return error.TlsProtocolVersion; + + switch (rec.content_type) { + .change_cipher_spec => { + if (rec.payload.len != 1) return error.TlsUnexpectedMessage; + }, + .application_data => { + const content_type, const cleartext = try h.cipher.decrypt( + cleartext_buf[cleartext_buf_tail..], + rec, + ); + cleartext_buf_tail += cleartext.len; + if (cleartext_buf_tail > cleartext_buf.len) return error.TlsRecordOverflow; + + var d = record.Decoder.init(content_type, cleartext_buf[cleartext_buf_head..cleartext_buf_tail]); + try d.expectContentType(.handshake); + while (!d.eof()) { + const start_idx = d.idx; + const handshake_type = try d.decode(proto.Handshake); + const length = try d.decode(u24); + + if (length > cipher.max_cleartext_len) + return error.TlsRecordOverflow; + if (length > d.rest().len) + continue :outer; // fragmented handshake into multiple records + + defer { + const handshake_payload = d.payload[start_idx..d.idx]; + h.transcript.update(handshake_payload); + cleartext_buf_head += handshake_payload.len; + } + + if (handshake_state != handshake_type) + return error.TlsUnexpectedMessage; + + switch (handshake_type) { + .certificate => { + if (length == 4) { + // got empty certificate message + if 
(opt.client_auth.?.auth_type == .require) + return error.TlsCertificateRequired; + try d.skip(length); + handshake_state = .finished; + } else { + try cert.parseCertificate(&d, .tls_1_3); + handshake_state = .certificate_verify; + } + }, + .certificate_verify => { + try cert.parseCertificateVerify(&d); + cert.verifySignature(h.transcript.clientCertificateVerify()) catch |err| return switch (err) { + error.TlsUnknownSignatureScheme => error.TlsIllegalParameter, + else => error.TlsDecryptError, + }; + handshake_state = .finished; + }, + .finished => { + const actual = try d.slice(length); + var buf: [Transcript.max_mac_length]u8 = undefined; + const expected = h.transcript.clientFinishedTls13(&buf); + if (!mem.eql(u8, expected, actual)) + return if (expected.len == actual.len) + error.TlsDecryptError + else + error.TlsDecodeError; + return; + }, + else => return error.TlsUnexpectedMessage, + } + } + cleartext_buf_head = 0; + cleartext_buf_tail = 0; + }, + .alert => { + var d = rec.decoder(); + return d.raiseAlert(); + }, + else => return error.TlsUnexpectedMessage, + } + } + } + + fn makeFinished(h: *HandshakeT, buf: []u8) ![]const u8 { + var w = record.Writer{ .buf = buf }; + const verify_data = h.transcript.serverFinishedTls13(w.getHandshakePayload()); + try w.advanceHandshake(.finished, verify_data.len); + return w.getWritten(); + } + + /// Write encrypted handshake message into `w` + fn writeEncrypted(h: *HandshakeT, w: *record.Writer, cleartext: []const u8) !void { + const ciphertext = try h.cipher.encrypt(w.getFree(), .handshake, cleartext); + w.pos += ciphertext.len; + } + + fn makeServerHello(h: *HandshakeT, buf: []u8) ![]const u8 { + const header_len = 9; // tls record header (5 bytes) and handshake header (4 bytes) + var w = record.Writer{ .buf = buf[header_len..] 
}; + + try w.writeEnum(proto.Version.tls_1_2); + try w.write(&h.server_random); + { + try w.writeInt(@as(u8, @intCast(h.legacy_session_id.len))); + if (h.legacy_session_id.len > 0) try w.write(h.legacy_session_id); + } + try w.writeEnum(h.cipher_suite); + try w.write(&[_]u8{0}); // compression method + + var e = record.Writer{ .buf = buf[header_len + w.pos + 2 ..] }; + { // supported versions extension + try e.writeEnum(proto.Extension.supported_versions); + try e.writeInt(@as(u16, 2)); + try e.writeEnum(proto.Version.tls_1_3); + } + { // key share extension + const key_len: u16 = @intCast(h.server_pub_key.len); + try e.writeEnum(proto.Extension.key_share); + try e.writeInt(key_len + 4); + try e.writeEnum(h.named_group); + try e.writeInt(key_len); + try e.write(h.server_pub_key); + } + try w.writeInt(@as(u16, @intCast(e.pos))); // extensions length + + const payload_len = w.pos + e.pos; + buf[0..header_len].* = record.header(.handshake, 4 + payload_len) ++ + record.handshakeHeader(.server_hello, payload_len); + + return buf[0 .. header_len + payload_len]; + } + + fn makeCertificateRequest(buf: []u8) ![]const u8 { + // handshake header + context length + extensions length + const header_len = 4 + 1 + 2; + + // First write extensions, leave space for header. + var ext = record.Writer{ .buf = buf[header_len..] 
}; + try ext.writeExtension(.signature_algorithms, common.supported_signature_algorithms); + + var w = record.Writer{ .buf = buf }; + try w.writeHandshakeHeader(.certificate_request, ext.pos + 3); + try w.writeInt(@as(u8, 0)); // certificate request context length = 0 + try w.writeInt(@as(u16, @intCast(ext.pos))); // extensions length + assert(w.pos == header_len); + w.pos += ext.pos; + + return w.getWritten(); + } + + fn readClientHello(h: *HandshakeT) !void { + var d = try h.rec_rdr.nextDecoder(); + try d.expectContentType(.handshake); + h.transcript.update(d.payload); + + const handshake_type = try d.decode(proto.Handshake); + if (handshake_type != .client_hello) return error.TlsUnexpectedMessage; + _ = try d.decode(u24); // handshake length + if (try d.decode(proto.Version) != .tls_1_2) return error.TlsProtocolVersion; + + h.client_random = try d.array(32); + { // legacy session id + const len = try d.decode(u8); + h.legacy_session_id = dupe(&h.legacy_session_id_buf, try d.slice(len)); + } + { // cipher suites + const end_idx = try d.decode(u16) + d.idx; + + while (d.idx < end_idx) { + const cipher_suite = try d.decode(CipherSuite); + if (cipher_suites.includes(cipher_suites.tls13, cipher_suite) and + @intFromEnum(h.cipher_suite) == 0) + { + h.cipher_suite = cipher_suite; + } + } + if (@intFromEnum(h.cipher_suite) == 0) + return error.TlsHandshakeFailure; + } + try d.skip(2); // compression methods + + var key_share_received = false; + // extensions + const extensions_end_idx = try d.decode(u16) + d.idx; + while (d.idx < extensions_end_idx) { + const extension_type = try d.decode(proto.Extension); + const extension_len = try d.decode(u16); + + switch (extension_type) { + .supported_versions => { + var tls_1_3_supported = false; + const end_idx = try d.decode(u8) + d.idx; + while (d.idx < end_idx) { + if (try d.decode(proto.Version) == proto.Version.tls_1_3) { + tls_1_3_supported = true; + } + } + if (!tls_1_3_supported) return error.TlsProtocolVersion; + }, + 
.key_share => { + if (extension_len == 0) return error.TlsDecodeError; + key_share_received = true; + var selected_named_group_idx = supported_named_groups.len; + const end_idx = try d.decode(u16) + d.idx; + while (d.idx < end_idx) { + const named_group = try d.decode(proto.NamedGroup); + switch (@intFromEnum(named_group)) { + 0x0001...0x0016, + 0x001a...0x001c, + 0xff01...0xff02, + => return error.TlsIllegalParameter, + else => {}, + } + const client_pub_key = try d.slice(try d.decode(u16)); + for (supported_named_groups, 0..) |supported, idx| { + if (named_group == supported and idx < selected_named_group_idx) { + h.named_group = named_group; + h.client_pub_key = dupe(&h.client_pub_key_buf, client_pub_key); + selected_named_group_idx = idx; + } + } + } + if (@intFromEnum(h.named_group) == 0) + return error.TlsIllegalParameter; + }, + .supported_groups => { + const end_idx = try d.decode(u16) + d.idx; + while (d.idx < end_idx) { + const named_group = try d.decode(proto.NamedGroup); + switch (@intFromEnum(named_group)) { + 0x0001...0x0016, + 0x001a...0x001c, + 0xff01...0xff02, + => return error.TlsIllegalParameter, + else => {}, + } + } + }, + .signature_algorithms => { + if (@intFromEnum(h.signature_scheme) == 0) { + try d.skip(extension_len); + } else { + var found = false; + const list_len = try d.decode(u16); + if (list_len == 0) return error.TlsDecodeError; + const end_idx = list_len + d.idx; + while (d.idx < end_idx) { + const signature_scheme = try d.decode(proto.SignatureScheme); + if (signature_scheme == h.signature_scheme) found = true; + } + if (!found) return error.TlsHandshakeFailure; + } + }, + else => { + try d.skip(extension_len); + }, + } + } + if (!key_share_received) return error.TlsMissingExtension; + if (@intFromEnum(h.named_group) == 0) return error.TlsIllegalParameter; + } + }; +} + +const testing = std.testing; +const data13 = @import("testdata/tls13.zig"); +const testu = @import("testu.zig"); + +fn testReader(data: []const u8) 
record.Reader(std.io.FixedBufferStream([]const u8)) { + return record.reader(std.io.fixedBufferStream(data)); +} +const TestHandshake = Handshake(std.io.FixedBufferStream([]const u8)); + +test "read client hello" { + var buffer: [1024]u8 = undefined; + var rec_rdr = testReader(&data13.client_hello); + var h = TestHandshake.init(&buffer, &rec_rdr); + h.signature_scheme = .ecdsa_secp521r1_sha512; // this must be supported in signature_algorithms extension + try h.readClientHello(); + + try testing.expectEqual(CipherSuite.AES_256_GCM_SHA384, h.cipher_suite); + try testing.expectEqual(.x25519, h.named_group); + try testing.expectEqualSlices(u8, &data13.client_random, &h.client_random); + try testing.expectEqualSlices(u8, &data13.client_public_key, h.client_pub_key); +} + +test "make server hello" { + var buffer: [128]u8 = undefined; + var h = TestHandshake.init(&buffer, undefined); + h.cipher_suite = .AES_256_GCM_SHA384; + testu.fillFrom(&h.server_random, 0); + testu.fillFrom(&h.server_pub_key_buf, 0x20); + h.named_group = .x25519; + h.server_pub_key = h.server_pub_key_buf[0..32]; + + const actual = try h.makeServerHello(&buffer); + const expected = &testu.hexToBytes( + \\ 16 03 03 00 5a 02 00 00 56 + \\ 03 03 + \\ 00 01 02 03 04 05 06 07 08 09 0a 0b 0c 0d 0e 0f 10 11 12 13 14 15 16 17 18 19 1a 1b 1c 1d 1e 1f + \\ 00 + \\ 13 02 00 + \\ 00 2e 00 2b 00 02 03 04 + \\ 00 33 00 24 00 1d 00 20 + \\ 20 21 22 23 24 25 26 27 28 29 2a 2b 2c 2d 2e 2f 30 31 32 33 34 35 36 37 38 39 3a 3b 3c 3d 3e 3f + ); + try testing.expectEqualSlices(u8, expected, actual); +} + +test "make certificate request" { + var buffer: [32]u8 = undefined; + + const expected = testu.hexToBytes("0d 00 00 1b" ++ // handshake header + "00 00 18" ++ // extension length + "00 0d" ++ // signature algorithms extension + "00 14" ++ // extension length + "00 12" ++ // list length 6 * 2 bytes + "04 03 05 03 08 04 08 05 08 06 08 07 02 01 04 01 05 01" // signature schemes + ); + const actual = try 
TestHandshake.makeCertificateRequest(&buffer); + try testing.expectEqualSlices(u8, &expected, actual); +} diff --git a/src/http/async/tls.zig/key_log.zig b/src/http/async/tls.zig/key_log.zig new file mode 100644 index 00000000..2da83f42 --- /dev/null +++ b/src/http/async/tls.zig/key_log.zig @@ -0,0 +1,60 @@ +//! Exporting tls keys so we can share them with Wireshark and analyze decrypted +//! traffic in Wireshark. +//! To configure Wireshark to use exported keys see curl reference. +//! +//! References: +//! curl: https://everything.curl.dev/usingcurl/tls/sslkeylogfile.html +//! openssl: https://www.openssl.org/docs/manmaster/man3/SSL_CTX_set_keylog_callback.html +//! https://udn.realityripple.com/docs/Mozilla/Projects/NSS/Key_Log_Format + +const std = @import("std"); + +const key_log_file_env = "SSLKEYLOGFILE"; + +pub const label = struct { + // tls 1.3 + pub const client_handshake_traffic_secret: []const u8 = "CLIENT_HANDSHAKE_TRAFFIC_SECRET"; + pub const server_handshake_traffic_secret: []const u8 = "SERVER_HANDSHAKE_TRAFFIC_SECRET"; + pub const client_traffic_secret_0: []const u8 = "CLIENT_TRAFFIC_SECRET_0"; + pub const server_traffic_secret_0: []const u8 = "SERVER_TRAFFIC_SECRET_0"; + // tls 1.2 + pub const client_random: []const u8 = "CLIENT_RANDOM"; +}; + +pub const Callback = *const fn (label: []const u8, client_random: []const u8, secret: []const u8) void; + +/// Writes tls keys to the file pointed by SSLKEYLOGFILE environment variable.
+pub fn callback(label_: []const u8, client_random: []const u8, secret: []const u8) void { + if (std.posix.getenv(key_log_file_env)) |file_name| { + fileAppend(file_name, label_, client_random, secret) catch return; + } +} + +pub fn fileAppend(file_name: []const u8, label_: []const u8, client_random: []const u8, secret: []const u8) !void { + var buf: [1024]u8 = undefined; + const line = try formatLine(&buf, label_, client_random, secret); + try fileWrite(file_name, line); +} + +fn fileWrite(file_name: []const u8, line: []const u8) !void { + var file = try std.fs.createFileAbsolute(file_name, .{ .truncate = false }); + defer file.close(); + const stat = try file.stat(); + try file.seekTo(stat.size); + try file.writeAll(line); +} + +pub fn formatLine(buf: []u8, label_: []const u8, client_random: []const u8, secret: []const u8) ![]const u8 { + var fbs = std.io.fixedBufferStream(buf); + const w = fbs.writer(); + try w.print("{s} ", .{label_}); + for (client_random) |b| { + try std.fmt.formatInt(b, 16, .lower, .{ .width = 2, .fill = '0' }, w); + } + try w.writeByte(' '); + for (secret) |b| { + try std.fmt.formatInt(b, 16, .lower, .{ .width = 2, .fill = '0' }, w); + } + try w.writeByte('\n'); + return fbs.getWritten(); +} diff --git a/src/http/async/tls.zig/main.zig b/src/http/async/tls.zig/main.zig new file mode 100644 index 00000000..b974377b --- /dev/null +++ b/src/http/async/tls.zig/main.zig @@ -0,0 +1,51 @@ +const std = @import("std"); + +pub const CipherSuite = @import("cipher.zig").CipherSuite; +pub const cipher_suites = @import("cipher.zig").cipher_suites; +pub const PrivateKey = @import("PrivateKey.zig"); +pub const Connection = @import("connection.zig").Connection; +pub const ClientOptions = @import("handshake_client.zig").Options; +pub const ServerOptions = @import("handshake_server.zig").Options; +pub const key_log = @import("key_log.zig"); +pub const proto = @import("protocol.zig"); +pub const NamedGroup = proto.NamedGroup; +pub const Version = 
proto.Version; +const common = @import("handshake_common.zig"); +pub const CertBundle = common.CertBundle; +pub const CertKeyPair = common.CertKeyPair; + +pub const record = @import("record.zig"); +const connection = @import("connection.zig").connection; +const max_ciphertext_record_len = @import("cipher.zig").max_ciphertext_record_len; +const HandshakeServer = @import("handshake_server.zig").Handshake; +const HandshakeClient = @import("handshake_client.zig").Handshake; + +pub fn client(stream: anytype, opt: ClientOptions) !Connection(@TypeOf(stream)) { + const Stream = @TypeOf(stream); + var conn = connection(stream); + var write_buf: [max_ciphertext_record_len]u8 = undefined; + var h = HandshakeClient(Stream).init(&write_buf, &conn.rec_rdr); + conn.cipher = try h.handshake(conn.stream, opt); + return conn; +} + +pub fn server(stream: anytype, opt: ServerOptions) !Connection(@TypeOf(stream)) { + const Stream = @TypeOf(stream); + var conn = connection(stream); + var write_buf: [max_ciphertext_record_len]u8 = undefined; + var h = HandshakeServer(Stream).init(&write_buf, &conn.rec_rdr); + conn.cipher = try h.handshake(conn.stream, opt); + return conn; +} + +test { + _ = @import("handshake_common.zig"); + _ = @import("handshake_server.zig"); + _ = @import("handshake_client.zig"); + + _ = @import("connection.zig"); + _ = @import("cipher.zig"); + _ = @import("record.zig"); + _ = @import("transcript.zig"); + _ = @import("PrivateKey.zig"); +} diff --git a/src/http/async/tls.zig/protocol.zig b/src/http/async/tls.zig/protocol.zig new file mode 100644 index 00000000..e3bb07ac --- /dev/null +++ b/src/http/async/tls.zig/protocol.zig @@ -0,0 +1,302 @@ +pub const Version = enum(u16) { + tls_1_2 = 0x0303, + tls_1_3 = 0x0304, + _, +}; + +pub const ContentType = enum(u8) { + invalid = 0, + change_cipher_spec = 20, + alert = 21, + handshake = 22, + application_data = 23, + _, +}; + +pub const Handshake = enum(u8) { + client_hello = 1, + server_hello = 2, + new_session_ticket = 4, + 
end_of_early_data = 5, + encrypted_extensions = 8, + certificate = 11, + server_key_exchange = 12, + certificate_request = 13, + server_hello_done = 14, + certificate_verify = 15, + client_key_exchange = 16, + finished = 20, + key_update = 24, + message_hash = 254, + _, +}; + +pub const Curve = enum(u8) { + named_curve = 0x03, + _, +}; + +pub const Extension = enum(u16) { + /// RFC 6066 + server_name = 0, + /// RFC 6066 + max_fragment_length = 1, + /// RFC 6066 + status_request = 5, + /// RFC 8422, 7919 + supported_groups = 10, + /// RFC 8446 + signature_algorithms = 13, + /// RFC 5764 + use_srtp = 14, + /// RFC 6520 + heartbeat = 15, + /// RFC 7301 + application_layer_protocol_negotiation = 16, + /// RFC 6962 + signed_certificate_timestamp = 18, + /// RFC 7250 + client_certificate_type = 19, + /// RFC 7250 + server_certificate_type = 20, + /// RFC 7685 + padding = 21, + /// RFC 8446 + pre_shared_key = 41, + /// RFC 8446 + early_data = 42, + /// RFC 8446 + supported_versions = 43, + /// RFC 8446 + cookie = 44, + /// RFC 8446 + psk_key_exchange_modes = 45, + /// RFC 8446 + certificate_authorities = 47, + /// RFC 8446 + oid_filters = 48, + /// RFC 8446 + post_handshake_auth = 49, + /// RFC 8446 + signature_algorithms_cert = 50, + /// RFC 8446 + key_share = 51, + + _, +}; + +pub fn alertFromError(err: anyerror) [2]u8 { + return [2]u8{ @intFromEnum(Alert.Level.fatal), @intFromEnum(Alert.fromError(err)) }; +} + +pub const Alert = enum(u8) { + pub const Level = enum(u8) { + warning = 1, + fatal = 2, + _, + }; + + pub const Error = error{ + TlsAlertUnexpectedMessage, + TlsAlertBadRecordMac, + TlsAlertRecordOverflow, + TlsAlertHandshakeFailure, + TlsAlertBadCertificate, + TlsAlertUnsupportedCertificate, + TlsAlertCertificateRevoked, + TlsAlertCertificateExpired, + TlsAlertCertificateUnknown, + TlsAlertIllegalParameter, + TlsAlertUnknownCa, + TlsAlertAccessDenied, + TlsAlertDecodeError, + TlsAlertDecryptError, + TlsAlertProtocolVersion, + TlsAlertInsufficientSecurity, + 
TlsAlertInternalError, + TlsAlertInappropriateFallback, + TlsAlertMissingExtension, + TlsAlertUnsupportedExtension, + TlsAlertUnrecognizedName, + TlsAlertBadCertificateStatusResponse, + TlsAlertUnknownPskIdentity, + TlsAlertCertificateRequired, + TlsAlertNoApplicationProtocol, + TlsAlertUnknown, + }; + + close_notify = 0, + unexpected_message = 10, + bad_record_mac = 20, + record_overflow = 22, + handshake_failure = 40, + bad_certificate = 42, + unsupported_certificate = 43, + certificate_revoked = 44, + certificate_expired = 45, + certificate_unknown = 46, + illegal_parameter = 47, + unknown_ca = 48, + access_denied = 49, + decode_error = 50, + decrypt_error = 51, + protocol_version = 70, + insufficient_security = 71, + internal_error = 80, + inappropriate_fallback = 86, + user_canceled = 90, + missing_extension = 109, + unsupported_extension = 110, + unrecognized_name = 112, + bad_certificate_status_response = 113, + unknown_psk_identity = 115, + certificate_required = 116, + no_application_protocol = 120, + _, + + pub fn toError(alert: Alert) Error!void { + return switch (alert) { + .close_notify => {}, // not an error + .unexpected_message => error.TlsAlertUnexpectedMessage, + .bad_record_mac => error.TlsAlertBadRecordMac, + .record_overflow => error.TlsAlertRecordOverflow, + .handshake_failure => error.TlsAlertHandshakeFailure, + .bad_certificate => error.TlsAlertBadCertificate, + .unsupported_certificate => error.TlsAlertUnsupportedCertificate, + .certificate_revoked => error.TlsAlertCertificateRevoked, + .certificate_expired => error.TlsAlertCertificateExpired, + .certificate_unknown => error.TlsAlertCertificateUnknown, + .illegal_parameter => error.TlsAlertIllegalParameter, + .unknown_ca => error.TlsAlertUnknownCa, + .access_denied => error.TlsAlertAccessDenied, + .decode_error => error.TlsAlertDecodeError, + .decrypt_error => error.TlsAlertDecryptError, + .protocol_version => error.TlsAlertProtocolVersion, + .insufficient_security => 
error.TlsAlertInsufficientSecurity, + .internal_error => error.TlsAlertInternalError, + .inappropriate_fallback => error.TlsAlertInappropriateFallback, + .user_canceled => {}, // not an error + .missing_extension => error.TlsAlertMissingExtension, + .unsupported_extension => error.TlsAlertUnsupportedExtension, + .unrecognized_name => error.TlsAlertUnrecognizedName, + .bad_certificate_status_response => error.TlsAlertBadCertificateStatusResponse, + .unknown_psk_identity => error.TlsAlertUnknownPskIdentity, + .certificate_required => error.TlsAlertCertificateRequired, + .no_application_protocol => error.TlsAlertNoApplicationProtocol, + _ => error.TlsAlertUnknown, + }; + } + + pub fn fromError(err: anyerror) Alert { + return switch (err) { + error.TlsUnexpectedMessage => .unexpected_message, + error.TlsBadRecordMac => .bad_record_mac, + error.TlsRecordOverflow => .record_overflow, + error.TlsHandshakeFailure => .handshake_failure, + error.TlsBadCertificate => .bad_certificate, + error.TlsUnsupportedCertificate => .unsupported_certificate, + error.TlsCertificateRevoked => .certificate_revoked, + error.TlsCertificateExpired => .certificate_expired, + error.TlsCertificateUnknown => .certificate_unknown, + error.TlsIllegalParameter, + error.IdentityElement, + error.InvalidEncoding, + => .illegal_parameter, + error.TlsUnknownCa => .unknown_ca, + error.TlsAccessDenied => .access_denied, + error.TlsDecodeError => .decode_error, + error.TlsDecryptError => .decrypt_error, + error.TlsProtocolVersion => .protocol_version, + error.TlsInsufficientSecurity => .insufficient_security, + error.TlsInternalError => .internal_error, + error.TlsInappropriateFallback => .inappropriate_fallback, + error.TlsMissingExtension => .missing_extension, + error.TlsUnsupportedExtension => .unsupported_extension, + error.TlsUnrecognizedName => .unrecognized_name, + error.TlsBadCertificateStatusResponse => .bad_certificate_status_response, + error.TlsUnknownPskIdentity => .unknown_psk_identity, + 
error.TlsCertificateRequired => .certificate_required, + error.TlsNoApplicationProtocol => .no_application_protocol, + else => .internal_error, + }; + } + + pub fn parse(buf: [2]u8) Alert { + const level: Alert.Level = @enumFromInt(buf[0]); + const alert: Alert = @enumFromInt(buf[1]); + _ = level; + return alert; + } + + pub fn closeNotify() [2]u8 { + return [2]u8{ + @intFromEnum(Alert.Level.warning), + @intFromEnum(Alert.close_notify), + }; + } +}; + +pub const SignatureScheme = enum(u16) { + // RSASSA-PKCS1-v1_5 algorithms + rsa_pkcs1_sha256 = 0x0401, + rsa_pkcs1_sha384 = 0x0501, + rsa_pkcs1_sha512 = 0x0601, + + // ECDSA algorithms + ecdsa_secp256r1_sha256 = 0x0403, + ecdsa_secp384r1_sha384 = 0x0503, + ecdsa_secp521r1_sha512 = 0x0603, + + // RSASSA-PSS algorithms with public key OID rsaEncryption + rsa_pss_rsae_sha256 = 0x0804, + rsa_pss_rsae_sha384 = 0x0805, + rsa_pss_rsae_sha512 = 0x0806, + + // EdDSA algorithms + ed25519 = 0x0807, + ed448 = 0x0808, + + // RSASSA-PSS algorithms with public key OID RSASSA-PSS + rsa_pss_pss_sha256 = 0x0809, + rsa_pss_pss_sha384 = 0x080a, + rsa_pss_pss_sha512 = 0x080b, + + // Legacy algorithms + rsa_pkcs1_sha1 = 0x0201, + ecdsa_sha1 = 0x0203, + + _, +}; + +pub const NamedGroup = enum(u16) { + // Elliptic Curve Groups (ECDHE) + secp256r1 = 0x0017, + secp384r1 = 0x0018, + secp521r1 = 0x0019, + x25519 = 0x001D, + x448 = 0x001E, + + // Finite Field Groups (DHE) + ffdhe2048 = 0x0100, + ffdhe3072 = 0x0101, + ffdhe4096 = 0x0102, + ffdhe6144 = 0x0103, + ffdhe8192 = 0x0104, + + // Hybrid post-quantum key agreements + x25519_kyber512d00 = 0xFE30, + x25519_kyber768d00 = 0x6399, + + _, +}; + +pub const KeyUpdateRequest = enum(u8) { + update_not_requested = 0, + update_requested = 1, + _, +}; + +pub const Side = enum { + client, + server, +}; diff --git a/src/http/async/tls.zig/record.zig b/src/http/async/tls.zig/record.zig new file mode 100644 index 00000000..6c4df328 --- /dev/null +++ b/src/http/async/tls.zig/record.zig @@ -0,0 +1,405 @@ 
+const std = @import("std"); +const assert = std.debug.assert; +const mem = std.mem; + +const proto = @import("protocol.zig"); +const cipher = @import("cipher.zig"); +const Cipher = cipher.Cipher; +const record = @import("record.zig"); + +pub const header_len = 5; + +pub fn header(content_type: proto.ContentType, payload_len: usize) [header_len]u8 { + const int2 = std.crypto.tls.int2; + return [1]u8{@intFromEnum(content_type)} ++ + int2(@intFromEnum(proto.Version.tls_1_2)) ++ + int2(@intCast(payload_len)); +} + +pub fn handshakeHeader(handshake_type: proto.Handshake, payload_len: usize) [4]u8 { + const int3 = std.crypto.tls.int3; + return [1]u8{@intFromEnum(handshake_type)} ++ int3(@intCast(payload_len)); +} + +pub fn reader(inner_reader: anytype) Reader(@TypeOf(inner_reader)) { + return .{ .inner_reader = inner_reader }; +} + +pub fn Reader(comptime InnerReader: type) type { + return struct { + inner_reader: InnerReader, + + buffer: [cipher.max_ciphertext_record_len]u8 = undefined, + start: usize = 0, + end: usize = 0, + + const ReaderT = @This(); + + pub fn nextDecoder(r: *ReaderT) !Decoder { + const rec = (try r.next()) orelse return error.EndOfStream; + if (@intFromEnum(rec.protocol_version) != 0x0300 and + @intFromEnum(rec.protocol_version) != 0x0301 and + rec.protocol_version != .tls_1_2) + return error.TlsBadVersion; + return .{ + .content_type = rec.content_type, + .payload = rec.payload, + }; + } + + pub fn contentType(buf: []const u8) proto.ContentType { + return @enumFromInt(buf[0]); + } + + pub fn protocolVersion(buf: []const u8) proto.Version { + return @enumFromInt(mem.readInt(u16, buf[1..3], .big)); + } + + pub fn next(r: *ReaderT) !?Record { + while (true) { + const buffer = r.buffer[r.start..r.end]; + // If we have 5 bytes header. 
+ if (buffer.len >= record.header_len) { + const record_header = buffer[0..record.header_len]; + const payload_len = mem.readInt(u16, record_header[3..5], .big); + if (payload_len > cipher.max_ciphertext_len) + return error.TlsRecordOverflow; + const record_len = record.header_len + payload_len; + // If we have whole record + if (buffer.len >= record_len) { + r.start += record_len; + return Record.init(buffer[0..record_len]); + } + } + { // Move dirty part to the start of the buffer. + const n = r.end - r.start; + if (n > 0 and r.start > 0) { + if (r.start > n) { + @memcpy(r.buffer[0..n], r.buffer[r.start..][0..n]); + } else { + mem.copyForwards(u8, r.buffer[0..n], r.buffer[r.start..][0..n]); + } + } + r.start = 0; + r.end = n; + } + { // Read more from inner_reader. + const n = try r.inner_reader.read(r.buffer[r.end..]); + if (n == 0) return null; + r.end += n; + } + } + } + + pub fn nextDecrypt(r: *ReaderT, cph: *Cipher) !?struct { proto.ContentType, []const u8 } { + const rec = (try r.next()) orelse return null; + if (rec.protocol_version != .tls_1_2) return error.TlsBadVersion; + + return try cph.decrypt( + // Reuse reader buffer for cleartext. `rec.header` and + // `rec.payload`(ciphertext) are also pointing somewhere in + // this buffer. Decrypter is first reading then writing a + // block, cleartext has less length then ciphertext, + // cleartext starts from the beginning of the buffer, so + // ciphertext is always ahead of cleartext. 
+ r.buffer[0..r.start], + rec, + ); + } + + pub fn hasMore(r: *ReaderT) bool { + return r.end > r.start; + } + }; +} + +pub const Record = struct { + content_type: proto.ContentType, + protocol_version: proto.Version = .tls_1_2, + header: []const u8, + payload: []const u8, + + pub fn init(buffer: []const u8) Record { + return .{ + .content_type = @enumFromInt(buffer[0]), + .protocol_version = @enumFromInt(mem.readInt(u16, buffer[1..3], .big)), + .header = buffer[0..record.header_len], + .payload = buffer[record.header_len..], + }; + } + + pub fn decoder(r: @This()) Decoder { + return Decoder.init(r.content_type, @constCast(r.payload)); + } +}; + +pub const Decoder = struct { + content_type: proto.ContentType, + payload: []const u8, + idx: usize = 0, + + pub fn init(content_type: proto.ContentType, payload: []u8) Decoder { + return .{ + .content_type = content_type, + .payload = payload, + }; + } + + pub fn decode(d: *Decoder, comptime T: type) !T { + switch (@typeInfo(T)) { + .Int => |info| switch (info.bits) { + 8 => { + try skip(d, 1); + return d.payload[d.idx - 1]; + }, + 16 => { + try skip(d, 2); + const b0: u16 = d.payload[d.idx - 2]; + const b1: u16 = d.payload[d.idx - 1]; + return (b0 << 8) | b1; + }, + 24 => { + try skip(d, 3); + const b0: u24 = d.payload[d.idx - 3]; + const b1: u24 = d.payload[d.idx - 2]; + const b2: u24 = d.payload[d.idx - 1]; + return (b0 << 16) | (b1 << 8) | b2; + }, + else => @compileError("unsupported int type: " ++ @typeName(T)), + }, + .Enum => |info| { + const int = try d.decode(info.tag_type); + if (info.is_exhaustive) @compileError("exhaustive enum cannot be used"); + return @as(T, @enumFromInt(int)); + }, + else => @compileError("unsupported type: " ++ @typeName(T)), + } + } + + pub fn array(d: *Decoder, comptime len: usize) ![len]u8 { + try d.skip(len); + return d.payload[d.idx - len ..][0..len].*; + } + + pub fn slice(d: *Decoder, len: usize) ![]const u8 { + try d.skip(len); + return d.payload[d.idx - len ..][0..len]; + } + + 
pub fn skip(d: *Decoder, amt: usize) !void { + if (d.idx + amt > d.payload.len) return error.TlsDecodeError; + d.idx += amt; + } + + pub fn rest(d: Decoder) []const u8 { + return d.payload[d.idx..]; + } + + pub fn eof(d: Decoder) bool { + return d.idx == d.payload.len; + } + + pub fn expectContentType(d: *Decoder, content_type: proto.ContentType) !void { + if (d.content_type == content_type) return; + + switch (d.content_type) { + .alert => try d.raiseAlert(), + else => return error.TlsUnexpectedMessage, + } + } + + pub fn raiseAlert(d: *Decoder) !void { + if (d.payload.len < 2) return error.TlsUnexpectedMessage; + try proto.Alert.parse(try d.array(2)).toError(); + return error.TlsAlertCloseNotify; + } +}; + +const testing = std.testing; +const data12 = @import("testdata/tls12.zig"); +const testu = @import("testu.zig"); +const CipherSuite = @import("cipher.zig").CipherSuite; + +test Reader { + var fbs = std.io.fixedBufferStream(&data12.server_responses); + var rdr = reader(fbs.reader()); + + const expected = [_]struct { + content_type: proto.ContentType, + payload_len: usize, + }{ + .{ .content_type = .handshake, .payload_len = 49 }, + .{ .content_type = .handshake, .payload_len = 815 }, + .{ .content_type = .handshake, .payload_len = 300 }, + .{ .content_type = .handshake, .payload_len = 4 }, + .{ .content_type = .change_cipher_spec, .payload_len = 1 }, + .{ .content_type = .handshake, .payload_len = 64 }, + }; + for (expected) |e| { + const rec = (try rdr.next()).?; + try testing.expectEqual(e.content_type, rec.content_type); + try testing.expectEqual(e.payload_len, rec.payload.len); + try testing.expectEqual(.tls_1_2, rec.protocol_version); + } +} + +test Decoder { + var fbs = std.io.fixedBufferStream(&data12.server_responses); + var rdr = reader(fbs.reader()); + + var d = (try rdr.nextDecoder()); + try testing.expectEqual(.handshake, d.content_type); + + try testing.expectEqual(.server_hello, try d.decode(proto.Handshake)); + try testing.expectEqual(45, try 
d.decode(u24)); // length + try testing.expectEqual(.tls_1_2, try d.decode(proto.Version)); + try testing.expectEqualStrings( + &testu.hexToBytes("707172737475767778797a7b7c7d7e7f808182838485868788898a8b8c8d8e8f"), + try d.slice(32), + ); // server random + try testing.expectEqual(0, try d.decode(u8)); // session id len + try testing.expectEqual(.ECDHE_RSA_WITH_AES_128_CBC_SHA, try d.decode(CipherSuite)); + try testing.expectEqual(0, try d.decode(u8)); // compression method + try testing.expectEqual(5, try d.decode(u16)); // extension length + try testing.expectEqual(5, d.rest().len); + try d.skip(5); + try testing.expect(d.eof()); +} + +pub const Writer = struct { + buf: []u8, + pos: usize = 0, + + pub fn write(self: *Writer, data: []const u8) !void { + defer self.pos += data.len; + if (self.pos + data.len > self.buf.len) return error.BufferOverflow; + @memcpy(self.buf[self.pos..][0..data.len], data); + } + + pub fn writeByte(self: *Writer, b: u8) !void { + defer self.pos += 1; + if (self.pos == self.buf.len) return error.BufferOverflow; + self.buf[self.pos] = b; + } + + pub fn writeEnum(self: *Writer, value: anytype) !void { + try self.writeInt(@intFromEnum(value)); + } + + pub fn writeInt(self: *Writer, value: anytype) !void { + const IntT = @TypeOf(value); + const bytes = @divExact(@typeInfo(IntT).Int.bits, 8); + const free = self.buf[self.pos..]; + if (free.len < bytes) return error.BufferOverflow; + mem.writeInt(IntT, free[0..bytes], value, .big); + self.pos += bytes; + } + + pub fn writeHandshakeHeader(self: *Writer, handshake_type: proto.Handshake, payload_len: usize) !void { + try self.write(&record.handshakeHeader(handshake_type, payload_len)); + } + + /// Should be used after writing handshake payload in buffer provided by `getHandshakePayload`. 
+ pub fn advanceHandshake(self: *Writer, handshake_type: proto.Handshake, payload_len: usize) !void { + try self.write(&record.handshakeHeader(handshake_type, payload_len)); + self.pos += payload_len; + } + + /// Record payload is already written by using buffer space from `getPayload`. + /// Now when we know payload len we can write record header and advance over payload. + pub fn advanceRecord(self: *Writer, content_type: proto.ContentType, payload_len: usize) !void { + try self.write(&record.header(content_type, payload_len)); + self.pos += payload_len; + } + + pub fn writeRecord(self: *Writer, content_type: proto.ContentType, payload: []const u8) !void { + try self.write(&record.header(content_type, payload.len)); + try self.write(payload); + } + + /// Preserves space for record header and returns buffer free space. + pub fn getPayload(self: *Writer) []u8 { + return self.buf[self.pos + record.header_len ..]; + } + + /// Preserves space for handshake header and returns buffer free space. 
+ pub fn getHandshakePayload(self: *Writer) []u8 { + return self.buf[self.pos + 4 ..]; + } + + pub fn getWritten(self: *Writer) []const u8 { + return self.buf[0..self.pos]; + } + + pub fn getFree(self: *Writer) []u8 { + return self.buf[self.pos..]; + } + + pub fn writeEnumArray(self: *Writer, comptime E: type, tags: []const E) !void { + assert(@sizeOf(E) == 2); + try self.writeInt(@as(u16, @intCast(tags.len * 2))); + for (tags) |t| { + try self.writeEnum(t); + } + } + + pub fn writeExtension( + self: *Writer, + comptime et: proto.Extension, + tags: anytype, + ) !void { + try self.writeEnum(et); + if (et == .supported_versions) { + try self.writeInt(@as(u16, @intCast(tags.len * 2 + 1))); + try self.writeInt(@as(u8, @intCast(tags.len * 2))); + } else { + try self.writeInt(@as(u16, @intCast(tags.len * 2 + 2))); + try self.writeInt(@as(u16, @intCast(tags.len * 2))); + } + for (tags) |t| { + try self.writeEnum(t); + } + } + + pub fn writeKeyShare( + self: *Writer, + named_groups: []const proto.NamedGroup, + keys: []const []const u8, + ) !void { + assert(named_groups.len == keys.len); + try self.writeEnum(proto.Extension.key_share); + var l: usize = 0; + for (keys) |key| { + l += key.len + 4; + } + try self.writeInt(@as(u16, @intCast(l + 2))); + try self.writeInt(@as(u16, @intCast(l))); + for (named_groups, 0..) 
|ng, i| { + const key = keys[i]; + try self.writeEnum(ng); + try self.writeInt(@as(u16, @intCast(key.len))); + try self.write(key); + } + } + + pub fn writeServerName(self: *Writer, host: []const u8) !void { + const host_len: u16 = @intCast(host.len); + try self.writeEnum(proto.Extension.server_name); + try self.writeInt(host_len + 5); // byte length of extension payload + try self.writeInt(host_len + 3); // server_name_list byte count + try self.writeByte(0); // name type + try self.writeInt(host_len); + try self.write(host); + } +}; + +test "Writer" { + var buf: [16]u8 = undefined; + var w = Writer{ .buf = &buf }; + + try w.write("ab"); + try w.writeEnum(proto.Curve.named_curve); + try w.writeEnum(proto.NamedGroup.x25519); + try w.writeInt(@as(u16, 0x1234)); + try testing.expectEqualSlices(u8, &[_]u8{ 'a', 'b', 0x03, 0x00, 0x1d, 0x12, 0x34 }, w.getWritten()); +} diff --git a/src/http/async/tls.zig/rsa/der.zig b/src/http/async/tls.zig/rsa/der.zig new file mode 100644 index 00000000..743a65ad --- /dev/null +++ b/src/http/async/tls.zig/rsa/der.zig @@ -0,0 +1,467 @@ +//! An encoding of ASN.1. +//! +//! Distinguised Encoding Rules as defined in X.690 and X.691. +//! +//! A version of Basic Encoding Rules (BER) where there is exactly ONE way to +//! represent non-constructed elements. This is useful for cryptographic signatures. +//! +//! Currently an implementation detail of the standard library not fit for public +//! use since it's missing an encoder. + +const std = @import("std"); +const builtin = @import("builtin"); + +pub const Index = usize; +const log = std.log.scoped(.der); + +/// A secure DER parser that: +/// - Does NOT read memory outside `bytes`. +/// - Does NOT return elements with slices outside `bytes`. +/// - Errors on values that do NOT follow DER rules. +/// - Lengths that could be represented in a shorter form. +/// - Booleans that are not 0xff or 0x00. 
/// Streaming parser for DER-encoded (ITU-T X.690) data.
/// `index` always points at the start of the next TLV element.
pub const Parser = struct {
    bytes: []const u8,
    index: Index = 0,

    /// Every error the `expect*` helpers below may return. The previous
    /// set omitted InvalidBool, InvalidBitString, InvalidDateTime,
    /// InvalidTime, InvalidYear and UnknownObjectId even though the
    /// bodies return them; they are included so the declared return
    /// types actually cover the implementations.
    pub const Error = Element.Error || error{
        UnexpectedElement,
        InvalidIntegerEncoding,
        InvalidBool,
        InvalidBitString,
        InvalidDateTime,
        InvalidTime,
        InvalidYear,
        UnknownObjectId,
        Overflow,
        NonCanonical,
    };

    /// BOOLEAN: exactly one content octet; DER requires 0x00 or 0xff.
    pub fn expectBool(self: *Parser) Error!bool {
        const ele = try self.expect(.universal, false, .boolean);
        if (ele.slice.len() != 1) return error.InvalidBool;

        return switch (self.view(ele)[0]) {
            0x00 => false,
            0xff => true,
            else => error.InvalidBool,
        };
    }

    /// BIT STRING: first content octet is the count of unused trailing bits (0-7).
    pub fn expectBitstring(self: *Parser) Error!BitString {
        const ele = try self.expect(.universal, false, .bitstring);
        const bytes = self.view(ele);
        const right_padding = bytes[0];
        if (right_padding >= 8) return error.InvalidBitString;
        return .{
            .bytes = bytes[1..],
            .right_padding = @intCast(right_padding),
        };
    }

    // TODO: return high resolution date time type instead of epoch seconds
    pub fn expectDateTime(self: *Parser) Error!i64 {
        const ele = try self.expect(.universal, false, null);
        const bytes = self.view(ele);
        switch (ele.identifier.tag) {
            .utc_time => {
                // Example: "YYMMDD000000Z"
                if (bytes.len != 13)
                    return error.InvalidDateTime;
                if (bytes[12] != 'Z')
                    return error.InvalidDateTime;

                var date: Date = undefined;
                date.year = try parseTimeDigits(bytes[0..2], 0, 99);
                // Two-digit years: 50-99 map to 19xx, 00-49 to 20xx.
                date.year += if (date.year >= 50) 1900 else 2000;
                date.month = try parseTimeDigits(bytes[2..4], 1, 12);
                date.day = try parseTimeDigits(bytes[4..6], 1, 31);
                const time = try parseTime(bytes[6..12]);

                return date.toEpochSeconds() + time.toSec();
            },
            .generalized_time => {
                // Examples:
                // "19920622123421Z"
                // "19920722132100.3Z"
                if (bytes.len < 15)
                    return error.InvalidDateTime;

                var date: Date = undefined;
                date.year = try parseYear4(bytes[0..4]);
                date.month = try parseTimeDigits(bytes[4..6], 1, 12);
                date.day = try parseTimeDigits(bytes[6..8], 1, 31);
                const time = try parseTime(bytes[8..14]);

                return date.toEpochSeconds() + time.toSec();
            },
            else => return error.InvalidDateTime,
        }
    }

    /// OBJECT IDENTIFIER: returns the raw encoded arcs (see oid.zig to decode).
    pub fn expectOid(self: *Parser) Error![]const u8 {
        const oid = try self.expect(.universal, false, .object_identifier);
        return self.view(oid);
    }

    /// Map an OBJECT IDENTIFIER onto `Enum` via its `oids` table.
    pub fn expectEnum(self: *Parser, comptime Enum: type) Error!Enum {
        const oid = try self.expectOid();
        return Enum.oids.get(oid) orelse {
            if (builtin.mode == .Debug) {
                var buf: [256]u8 = undefined;
                var stream = std.io.fixedBufferStream(&buf);
                // Best-effort decode for the log line only; a NoSpaceLeft
                // from the fixed buffer must not leak into the parser's
                // error set (the previous `try` did exactly that).
                @import("./oid.zig").decode(oid, stream.writer()) catch {};
                log.warn("unknown oid {s} for enum {s}\n", .{ stream.getWritten(), @typeName(Enum) });
            }
            return error.UnknownObjectId;
        };
    }

    /// INTEGER parsed big-endian, as DER mandates.
    ///
    /// The previous implementation shifted each byte by `index * 8` cast
    /// into `Log2Int(u8)` (u3): any encoding longer than one byte tripped
    /// the @intCast safety check, and the accumulation order treated the
    /// bytes as little-endian. Only single-byte integers ever worked.
    pub fn expectInt(self: *Parser, comptime T: type) Error!T {
        const ele = try self.expectPrimitive(.integer);
        const bytes = self.view(ele);

        const info = @typeInfo(T);
        if (info != .Int) @compileError(@typeName(T) ++ " is not an int type");

        // expectPrimitive has already stripped a single leading zero octet;
        // anything still wider than T cannot be represented.
        if (bytes.len > (info.Int.bits + 7) / 8) return error.Overflow;

        var result: std.meta.Int(.unsigned, info.Int.bits) = 0;
        for (bytes) |b| {
            // Most-significant byte first. std.math.shl is well-defined
            // even when the shift amount equals the bit width.
            result = std.math.shl(@TypeOf(result), result, 8) | b;
        }

        return @bitCast(result);
    }

    /// Accept any of the string types listed in `allowed`.
    pub fn expectString(self: *Parser, allowed: std.EnumSet(String.Tag)) Error!String {
        const ele = try self.expect(.universal, false, null);
        switch (ele.identifier.tag) {
            inline .string_utf8,
            .string_numeric,
            .string_printable,
            .string_teletex,
            .string_videotex,
            .string_ia5,
            .string_visible,
            .string_universal,
            .string_bmp,
            => |t| {
                // Map Identifier.Tag "string_xxx" onto String.Tag "xxx".
                const tagname = @tagName(t)["string_".len..];
                const tag = std.meta.stringToEnum(String.Tag, tagname) orelse unreachable;
                if (allowed.contains(tag)) {
                    return String{ .tag = tag, .data = self.view(ele) };
                }
            },
            else => {},
        }
        return error.UnexpectedElement;
    }

    /// Primitive element; for INTEGERs, strips one canonical leading zero
    /// and rejects non-minimal encodings.
    pub fn expectPrimitive(self: *Parser, tag: ?Identifier.Tag) Error!Element {
        var elem = try self.expect(.universal, false, tag);
        if (tag == .integer and elem.slice.len() > 0) {
            if (self.view(elem)[0] == 0) elem.slice.start += 1;
            if (elem.slice.len() > 0 and self.view(elem)[0] == 0) return error.InvalidIntegerEncoding;
        }
        return elem;
    }

    /// Remember to call `expectEnd`
    pub fn expectSequence(self: *Parser) Error!Element {
        return try self.expect(.universal, true, .sequence);
    }

    /// Remember to call `expectEnd`
    pub fn expectSequenceOf(self: *Parser) Error!Element {
        return try self.expect(.universal, true, .sequence_of);
    }

    /// Assert the parser consumed exactly up to `val`.
    pub fn expectEnd(self: *Parser, val: usize) Error!void {
        if (self.index != val) return error.NonCanonical; // either forgot to parse end OR an attacker
    }

    /// Read the next element, optionally constraining class/constructed/tag.
    /// Constructed elements leave `index` at their first child; primitive
    /// elements skip their content entirely.
    pub fn expect(
        self: *Parser,
        class: ?Identifier.Class,
        constructed: ?bool,
        tag: ?Identifier.Tag,
    ) Error!Element {
        if (self.index >= self.bytes.len) return error.EndOfStream;

        const res = try Element.init(self.bytes, self.index);
        if (tag) |e| {
            if (res.identifier.tag != e) return error.UnexpectedElement;
        }
        if (constructed) |e| {
            if (res.identifier.constructed != e) return error.UnexpectedElement;
        }
        if (class) |e| {
            if (res.identifier.class != e) return error.UnexpectedElement;
        }
        self.index = if (res.identifier.constructed) res.slice.start else res.slice.end;
        return res;
    }

    /// Content octets of `elem` within the parsed buffer.
    pub fn view(self: Parser, elem: Element) []const u8 {
        return elem.slice.view(self.bytes);
    }

    /// Reposition the parser (used to jump to the end of a sequence).
    pub fn seek(self: *Parser, index: usize) void {
        self.index = index;
    }

    /// True when every byte has been consumed.
    pub fn eof(self: *Parser) bool {
        return self.index == self.bytes.len;
    }
};
= std.io.fixedBufferStream(bytes[index..]); + var reader = stream.reader(); + + const identifier = @as(Identifier, @bitCast(try reader.readByte())); + const size_or_len_size = try reader.readByte(); + + var start = index + 2; + // short form between 0-127 + if (size_or_len_size < 128) { + const end = start + size_or_len_size; + if (end > bytes.len) return error.InvalidLength; + + return .{ .identifier = identifier, .slice = .{ .start = start, .end = end } }; + } + + // long form between 0 and std.math.maxInt(u1024) + const len_size: u7 = @truncate(size_or_len_size); + start += len_size; + if (len_size > @sizeOf(Index)) return error.InvalidLength; + const len = try reader.readVarInt(Index, .big, len_size); + if (len < 128) return error.InvalidLength; // should have used short form + + const end = std.math.add(Index, start, len) catch return error.InvalidLength; + if (end > bytes.len) return error.InvalidLength; + + return .{ .identifier = identifier, .slice = .{ .start = start, .end = end } }; + } +}; + +test Element { + const short_form = [_]u8{ 0x30, 0x03, 0x02, 0x01, 0x09 }; + try std.testing.expectEqual(Element{ + .identifier = Identifier{ .tag = .sequence, .constructed = true, .class = .universal }, + .slice = .{ .start = 2, .end = short_form.len }, + }, Element.init(&short_form, 0)); + + const long_form = [_]u8{ 0x30, 129, 129 } ++ [_]u8{0} ** 129; + try std.testing.expectEqual(Element{ + .identifier = Identifier{ .tag = .sequence, .constructed = true, .class = .universal }, + .slice = .{ .start = 3, .end = long_form.len }, + }, Element.init(&long_form, 0)); +} + +test "parser.expectInt" { + const one = [_]u8{ 2, 1, 1 }; + var parser = Parser{ .bytes = &one }; + try std.testing.expectEqual(@as(u8, 1), try parser.expectInt(u8)); +} + +pub const Identifier = packed struct(u8) { + tag: Tag, + constructed: bool, + class: Class, + + pub const Class = enum(u2) { + universal, + application, + context_specific, + private, + }; + + // 
https://www.oss.com/asn1/resources/asn1-made-simple/asn1-quick-reference/asn1-tags.html + pub const Tag = enum(u5) { + boolean = 1, + integer = 2, + bitstring = 3, + octetstring = 4, + null = 5, + object_identifier = 6, + real = 9, + enumerated = 10, + string_utf8 = 12, + sequence = 16, + sequence_of = 17, + string_numeric = 18, + string_printable = 19, + string_teletex = 20, + string_videotex = 21, + string_ia5 = 22, + utc_time = 23, + generalized_time = 24, + string_visible = 26, + string_universal = 28, + string_bmp = 30, + _, + }; +}; + +pub const BitString = struct { + bytes: []const u8, + right_padding: u3, + + pub fn bitLen(self: BitString) usize { + return self.bytes.len * 8 + self.right_padding; + } +}; + +pub const String = struct { + tag: Tag, + data: []const u8, + + pub const Tag = enum { + /// Blessed. + utf8, + /// us-ascii ([-][0-9][eE][.])* + numeric, + /// us-ascii ([A-Z][a-z][0-9][.?!,][ \t])* + printable, + /// iso-8859-1 with escaping into different character sets. + /// Cursed. + teletex, + /// iso-8859-1 + videotex, + /// us-ascii first 128 characters. + ia5, + /// us-ascii without control characters. + visible, + /// utf-32-be + universal, + /// utf-16-be + bmp, + }; + + pub const all = [_]Tag{ + .utf8, + .numeric, + .printable, + .teletex, + .videotex, + .ia5, + .visible, + .universal, + .bmp, + }; +}; + +const Date = struct { + year: Year, + month: u8, + day: u8, + + const Year = std.time.epoch.Year; + + fn toEpochSeconds(date: Date) i64 { + // Euclidean Affine Transform by Cassio and Neri. + // Shift and correction constants for 1970-01-01. + const s = 82; + const K = 719468 + 146097 * s; + const L = 400 * s; + + const Y_G: u32 = date.year; + const M_G: u32 = date.month; + const D_G: u32 = date.day; + // Map to computational calendar. + const J: u32 = if (M_G <= 2) 1 else 0; + const Y: u32 = Y_G + L - J; + const M: u32 = if (J != 0) M_G + 12 else M_G; + const D: u32 = D_G - 1; + const C: u32 = Y / 100; + + // Rata die. 
/// Parse exactly two ASCII digits and require the value to lie in
/// [min, max]. Returns error.InvalidTime on non-digits or out-of-range.
fn parseTimeDigits(
    text: *const [2]u8,
    min: comptime_int,
    max: comptime_int,
) !std.math.IntFittingRange(min, max) {
    const T = std.math.IntFittingRange(min, max);
    const value = std.fmt.parseInt(T, text, 10) catch return error.InvalidTime;
    if (value < min or value > max) return error.InvalidTime;
    return value;
}
/// Returns the number of octets the DER encoding of `dot_notation` will
/// occupy. The first two arcs always collapse into a single octet; every
/// further arc takes 1 + floor(log128(arc)) base-128 octets.
pub fn encodeLen(dot_notation: []const u8) !usize {
    var arcs = std.mem.splitScalar(u8, dot_notation, '.');
    if (arcs.next() == null) return 0;
    if (arcs.next() == null) return 1;

    var total: usize = 1;
    while (arcs.next()) |arc_str| {
        const arc = try std.fmt.parseUnsigned(u32, arc_str, 10);
        const continuation: usize = if (arc == 0) 0 else std.math.log(u32, 128, arc);
        total += continuation + 1;
    }
    return total;
}
/// Write the dot notation of a DER-encoded OID to `writer`.
/// Assumes `encoded` is non-empty and well-formed (no bounds recovery).
pub fn decode(encoded: []const u8, writer: anytype) @TypeOf(writer).Error!void {
    // The first octet packs the first two arcs as 40 * arc0 + arc1.
    const head = encoded[0];
    try writer.print("{d}.{d}", .{ head / 40, head % 40 });

    var pos: usize = 1;
    while (pos < encoded.len) {
        var arc: usize = 0;
        // Base-128: a set high bit means another octet follows.
        while (encoded[pos] & 0x80 != 0) : (pos += 1) {
            arc = (arc << 7) | (encoded[pos] & 0x7f);
        }
        arc = (arc << 7) | encoded[pos];
        pos += 1;
        try writer.print(".{d}", .{arc});
    }
}
0x48, 0x86, 0xf7, 0x0d }, "1.2.840.113549"); + // https://www.sysadmins.lv/blog-en/how-to-encode-object-identifier-to-an-asn1-der-encoded-string.aspx + try testOid(&[_]u8{ 0x2a, 0x86, 0x8d, 0x20 }, "1.2.100000"); + try testOid( + &[_]u8{ 0x2a, 0x86, 0x48, 0x86, 0xf7, 0x0d, 0x01, 0x01, 0x0b }, + "1.2.840.113549.1.1.11", + ); + try testOid(&[_]u8{ 0x2b, 0x65, 0x70 }, "1.3.101.112"); +} + +test encodeComptime { + try std.testing.expectEqual( + [_]u8{ 0x2b, 0x06, 0x01, 0x04, 0x01, 0x82, 0x37, 0x15, 0x14 }, + encodeComptime("1.3.6.1.4.1.311.21.20"), + ); +} diff --git a/src/http/async/tls.zig/rsa/rsa.zig b/src/http/async/tls.zig/rsa/rsa.zig new file mode 100644 index 00000000..5e5f42fe --- /dev/null +++ b/src/http/async/tls.zig/rsa/rsa.zig @@ -0,0 +1,880 @@ +//! RFC8017: Public Key Cryptography Standards #1 v2.2 (PKCS1) +const std = @import("std"); +const der = @import("der.zig"); +const ff = std.crypto.ff; + +pub const max_modulus_bits = 4096; +const max_modulus_len = max_modulus_bits / 8; + +const Modulus = std.crypto.ff.Modulus(max_modulus_bits); +const Fe = Modulus.Fe; + +pub const ValueError = error{ + Modulus, + Exponent, +}; + +pub const PublicKey = struct { + /// `n` + modulus: Modulus, + /// `e` + public_exponent: Fe, + + pub const FromBytesError = ValueError || ff.OverflowError || ff.FieldElementError || ff.InvalidModulusError || error{InsecureBitCount}; + + pub fn fromBytes(mod: []const u8, exp: []const u8) FromBytesError!PublicKey { + const modulus = try Modulus.fromBytes(mod, .big); + if (modulus.bits() <= 512) return error.InsecureBitCount; + const public_exponent = try Fe.fromBytes(modulus, exp, .big); + + if (std.debug.runtime_safety) { + // > the RSA public exponent e is an integer between 3 and n - 1 satisfying + // > GCD(e,\lambda(n)) = 1, where \lambda(n) = LCM(r_1 - 1, ..., r_u - 1) + const e_v = public_exponent.toPrimitive(u32) catch return error.Exponent; + if (!public_exponent.isOdd()) return error.Exponent; + if (e_v < 3) return error.Exponent; 
    /// Deprecated.
    ///
    /// Encrypt a short message using RSAES-PKCS1-v1_5 (RFC 8017, section 7.2.1).
    /// `out` must hold at least byteLen(modulus) bytes; the ciphertext is
    /// written into and returned as `out[0..k]`.
    /// The use of this scheme for encrypting an arbitrary message, as opposed to a
    /// randomly generated key, is NOT RECOMMENDED.
    pub fn encryptPkcsv1_5(pk: PublicKey, msg: []const u8, out: []u8) ![]const u8 {
        // align variable names with spec
        const k = byteLen(pk.modulus.bits());
        if (out.len < k) return error.BufferTooSmall;
        // 11 = 0x00 0x02 prefix + minimum 8 octets of PS + 0x00 separator.
        if (msg.len > k - 11) return error.MessageTooLong;

        // EM = 0x00 || 0x02 || PS || 0x00 || M.
        var em = out[0..k];
        em[0] = 0;
        em[1] = 2;

        const ps = em[2..][0 .. k - msg.len - 3];
        // Section: 7.2.1
        // PS consists of pseudo-randomly generated nonzero octets.
        // uintLessThan(u8, 0xff) yields 0..254, so +1 keeps every octet in 1..255.
        for (ps) |*v| {
            v.* = std.crypto.random.uintLessThan(u8, 0xff) + 1;
        }

        // 0x00 separator between PS and the message, then the message itself.
        em[em.len - msg.len - 1] = 0;
        @memcpy(em[em.len - msg.len ..][0..msg.len], msg);

        // c = m^e mod n, computed in place over EM (leading 0x00 keeps m < n).
        const m = try Fe.fromBytes(pk.modulus, em, .big);
        const e = try pk.modulus.powPublic(m, pk.public_exponent);
        try e.toBytes(em, .big);
        return em;
    }
+ pub fn encryptOaep( + pk: PublicKey, + comptime Hash: type, + msg: []const u8, + label: []const u8, + out: []u8, + ) ![]const u8 { + // align variable names with spec + const k = byteLen(pk.modulus.bits()); + if (out.len < k) return error.BufferTooSmall; + + if (msg.len > k - 2 * Hash.digest_length - 2) return error.MessageTooLong; + + // EM = 0x00 || maskedSeed || maskedDB. + var em = out[0..k]; + em[0] = 0; + const seed = em[1..][0..Hash.digest_length]; + std.crypto.random.bytes(seed); + + // DB = lHash || PS || 0x01 || M. + var db = em[1 + seed.len ..]; + const lHash = labelHash(Hash, label); + @memcpy(db[0..lHash.len], &lHash); + @memset(db[lHash.len .. db.len - msg.len - 2], 0); + db[db.len - msg.len - 1] = 1; + @memcpy(db[db.len - msg.len ..], msg); + + var mgf_buf: [max_modulus_len]u8 = undefined; + + const db_mask = mgf1(Hash, seed, mgf_buf[0..db.len]); + for (db, db_mask) |*v, m| v.* ^= m; + + const seed_mask = mgf1(Hash, db, mgf_buf[0..seed.len]); + for (seed, seed_mask) |*v, m| v.* ^= m; + + const m = try Fe.fromBytes(pk.modulus, em, .big); + const e = try pk.modulus.powPublic(m, pk.public_exponent); + try e.toBytes(em, .big); + return em; + } +}; + +pub fn byteLen(bits: usize) usize { + return std.math.divCeil(usize, bits, 8) catch unreachable; +} + +pub const SecretKey = struct { + /// `d` + private_exponent: Fe, + + pub const FromBytesError = ValueError || ff.OverflowError || ff.FieldElementError; + + pub fn fromBytes(n: Modulus, exp: []const u8) FromBytesError!SecretKey { + const d = try Fe.fromBytes(n, exp, .big); + if (std.debug.runtime_safety) { + // > The RSA private exponent d is a positive integer less than n + // > satisfying e * d == 1 (mod \lambda(n)), + if (!d.isOdd()) return error.Exponent; + if (d.v.compare(n.v) != .lt) return error.Exponent; + } + + return .{ .private_exponent = d }; + } +}; + +pub const KeyPair = struct { + public: PublicKey, + secret: SecretKey, + + pub const FromDerError = PublicKey.FromBytesError || 
    /// Parse an RSAPrivateKey structure (RFC 8017, appendix A.1.2) from DER.
    /// Supports version 0 (two-prime) and version 1 (multi-prime) encodings;
    /// the CRT parameters are parsed to advance the parser but not retained.
    /// In safe builds, additionally checks that n == p * q.
    pub fn fromDer(bytes: []const u8) FromDerError!KeyPair {
        var parser = der.Parser{ .bytes = bytes };
        const seq = try parser.expectSequence();
        const version = try parser.expectInt(u8);

        const mod = try parser.expectPrimitive(.integer);
        const pub_exp = try parser.expectPrimitive(.integer);
        const sec_exp = try parser.expectPrimitive(.integer);

        const public = try PublicKey.fromBytes(parser.view(mod), parser.view(pub_exp));
        const secret = try SecretKey.fromBytes(public.modulus, parser.view(sec_exp));

        // prime1/prime2 feed the consistency check below; the CRT exponents
        // and coefficient are only consumed, not used.
        const prime1 = try parser.expectPrimitive(.integer);
        const prime2 = try parser.expectPrimitive(.integer);
        const exp1 = try parser.expectPrimitive(.integer);
        const exp2 = try parser.expectPrimitive(.integer);
        const coeff = try parser.expectPrimitive(.integer);
        _ = .{ exp1, exp2, coeff };

        switch (version) {
            0 => {},
            1 => {
                // otherPrimeInfos: SEQUENCE OF OtherPrimeInfo
                // { prime ri, exponent di, coefficient ti } — consumed, ignored.
                _ = try parser.expectSequenceOf();
                while (!parser.eof()) {
                    _ = try parser.expectSequence();
                    const ri = try parser.expectPrimitive(.integer);
                    const di = try parser.expectPrimitive(.integer);
                    const ti = try parser.expectPrimitive(.integer);
                    _ = .{ ri, di, ti };
                }
            },
            else => return error.InvalidVersion,
        }

        try parser.expectEnd(seq.slice.end);
        try parser.expectEnd(bytes.len);

        if (std.debug.runtime_safety) {
            const p = try Fe.fromBytes(public.modulus, parser.view(prime1), .big);
            const q = try Fe.fromBytes(public.modulus, parser.view(prime2), .big);

            // check that n = p * q
            // (mul is performed mod n, so p * q mod n must be exactly zero.)
            const expected_zero = public.modulus.mul(p, q);
            if (!expected_zero.isZero()) return error.KeyMismatch;

            // TODO: check that d * e is one mod p-1 and mod q-1. Note d and e were bound
            // const de = secret.private_exponent.mul(public.public_exponent);
            // const one = public.modulus.one();

            // if (public.modulus.mul(de, p).compare(one) != .eq) return error.KeyMismatch;
            // if (public.modulus.mul(de, q).compare(one) != .eq) return error.KeyMismatch;
        }

        return .{ .public = public, .secret = secret };
    }
    /// Decrypt an RSAES-OAEP ciphertext (RFC 8017, section 7.1.2).
    /// `label` must match the label used at encryption time. `out` must hold
    /// at least byteLen(modulus) bytes; returns the recovered message as a
    /// slice of `out`.
    pub fn decryptOaep(
        kp: KeyPair,
        comptime Hash: type,
        ciphertext: []const u8,
        label: []const u8,
        out: []u8,
    ) ![]u8 {
        // align variable names with spec
        const k = byteLen(kp.public.modulus.bits());
        if (out.len < k) return error.BufferTooSmall;

        // m = c^d mod n, written into out[0..k].
        // NOTE(review): `catch unreachable` assumes pow cannot fail for any
        // in-range ciphertext — confirm against ff.Modulus.pow's error set.
        const mod = try Fe.fromBytes(kp.public.modulus, ciphertext, .big);
        const exp = kp.public.modulus.pow(mod, kp.secret.private_exponent) catch unreachable;
        const em = out[0..k];
        try exp.toBytes(em, .big);

        // EM = Y || maskedSeed || maskedDB.
        const y = em[0];
        const seed = em[1..][0..Hash.digest_length];
        const db = em[1 + Hash.digest_length ..];

        var mgf_buf: [max_modulus_len]u8 = undefined;

        // Unmask: seed = maskedSeed xor MGF1(maskedDB), DB = maskedDB xor MGF1(seed).
        const seed_mask = mgf1(Hash, db, mgf_buf[0..seed.len]);
        for (seed, seed_mask) |*v, m| v.* ^= m;

        const db_mask = mgf1(Hash, seed, mgf_buf[0..db.len]);
        for (db, db_mask) |*v, m| v.* ^= m;

        const expected_hash = labelHash(Hash, label);
        const actual_hash = db[0..expected_hash.len];

        // Care shall be taken to ensure that an opponent cannot
        // distinguish these error conditions, whether by error
        // message or timing.
        // NOTE(review): this scan starts at em index Hash.digest_length + 1,
        // which is the start of DB — i.e. inside the lHash region. A 0x01
        // octet occurring within lHash would be taken as the PS/message
        // separator. RFC 8017 locates the separator after lHash; confirm
        // whether the start index should be 2 * Hash.digest_length + 1.
        const msg_start = ct.indexOfScalarPos(em, expected_hash.len + 1, 1) orelse 0;
        if (ct.@"or"(y != 0, ct.@"or"(msg_start == 0, !ct.memEql(&expected_hash, actual_hash)))) {
            return error.Inconsistent;
        }

        return em[msg_start + 1 ..];
    }
+/// +/// Signature Scheme with Appendix v1.5 (RSASSA-PKCS1-v1_5) +/// +/// This standard has been superceded by PSS which is formally proven secure +/// and has fewer footguns. +pub fn PKCS1v1_5(comptime Hash: type) type { + return struct { + const PkcsT = @This(); + pub const Signature = struct { + bytes: []const u8, + + const Self = @This(); + + pub fn verifier(self: Self, public_key: PublicKey) !Verifier { + return Verifier.init(self, public_key); + } + + pub fn verify(self: Self, msg: []const u8, public_key: PublicKey) !void { + var st = Verifier.init(self, public_key); + st.update(msg); + return st.verify(); + } + }; + + pub const Signer = struct { + h: Hash, + key_pair: KeyPair, + + fn init(key_pair: KeyPair) Signer { + return .{ + .h = Hash.init(.{}), + .key_pair = key_pair, + }; + } + + pub fn update(self: *Signer, data: []const u8) void { + self.h.update(data); + } + + pub fn finalize(self: *Signer, out: []u8) !PkcsT.Signature { + const k = byteLen(self.key_pair.public.modulus.bits()); + if (out.len < k) return error.BufferTooSmall; + + var hash: [Hash.digest_length]u8 = undefined; + self.h.final(&hash); + + const em = try emsaEncode(hash, out[0..k]); + try self.key_pair.encrypt(em, em); + return .{ .bytes = em }; + } + }; + + pub const Verifier = struct { + h: Hash, + sig: PkcsT.Signature, + public_key: PublicKey, + + fn init(sig: PkcsT.Signature, public_key: PublicKey) Verifier { + return Verifier{ + .h = Hash.init(.{}), + .sig = sig, + .public_key = public_key, + }; + } + + pub fn update(self: *Verifier, data: []const u8) void { + self.h.update(data); + } + + pub fn verify(self: *Verifier) !void { + const pk = self.public_key; + const s = try Fe.fromBytes(pk.modulus, self.sig.bytes, .big); + const emm = try pk.modulus.powPublic(s, pk.public_exponent); + + var em_buf: [max_modulus_len]u8 = undefined; + const em = em_buf[0..byteLen(pk.modulus.bits())]; + try emm.toBytes(em, .big); + + var hash: [Hash.digest_length]u8 = undefined; + self.h.final(&hash); + 
+ // TODO: compare hash values instead of emsa values + const expected = try emsaEncode(hash, em); + + if (!std.mem.eql(u8, expected, em)) return error.Inconsistent; + } + }; + + /// PKCS Encrypted Message Signature Appendix + fn emsaEncode(hash: [Hash.digest_length]u8, out: []u8) ![]u8 { + const digest_header = comptime digestHeader(); + const tLen = digest_header.len + Hash.digest_length; + const emLen = out.len; + if (emLen < tLen + 11) return error.ModulusTooShort; + if (out.len < emLen) return error.BufferTooSmall; + + var res = out[0..emLen]; + res[0] = 0; + res[1] = 1; + const padding_len = emLen - tLen - 3; + @memset(res[2..][0..padding_len], 0xff); + res[2 + padding_len] = 0; + @memcpy(res[2 + padding_len + 1 ..][0..digest_header.len], digest_header); + @memcpy(res[res.len - hash.len ..], &hash); + + return res; + } + + /// DER encoded header. Sequence of digest algo + digest. + /// TODO: use a DER encoder instead + fn digestHeader() []const u8 { + const sha2 = std.crypto.hash.sha2; + // Section 9.2 Notes 1. 
+ return switch (Hash) { + std.crypto.hash.Sha1 => &hexToBytes( + \\30 21 30 09 06 05 2b 0e 03 02 1a 05 00 04 14 + ), + sha2.Sha224 => &hexToBytes( + \\30 2d 30 0d 06 09 60 86 48 01 65 03 04 02 04 + \\05 00 04 1c + ), + sha2.Sha256 => &hexToBytes( + \\30 31 30 0d 06 09 60 86 48 01 65 03 04 02 01 05 00 + \\04 20 + ), + sha2.Sha384 => &hexToBytes( + \\30 41 30 0d 06 09 60 86 48 01 65 03 04 02 02 05 00 + \\04 30 + ), + sha2.Sha512 => &hexToBytes( + \\30 51 30 0d 06 09 60 86 48 01 65 03 04 02 03 05 00 + \\04 40 + ), + // sha2.Sha512224 => &hexToBytes( + // \\30 2d 30 0d 06 09 60 86 48 01 65 03 04 02 05 + // \\05 00 04 1c + // ), + // sha2.Sha512256 => &hexToBytes( + // \\30 31 30 0d 06 09 60 86 48 01 65 03 04 02 06 + // \\05 00 04 20 + // ), + else => @compileError("unknown Hash " ++ @typeName(Hash)), + }; + } + }; +} + +/// Probabilistic Signature Scheme (RSASSA-PSS) +pub fn Pss(comptime Hash: type) type { + // RFC 4055 S3.1 + const default_salt_len = Hash.digest_length; + return struct { + pub const Signature = struct { + bytes: []const u8, + + const Self = @This(); + + pub fn verifier(self: Self, public_key: PublicKey) !Verifier { + return Verifier.init(self, public_key); + } + + pub fn verify(self: Self, msg: []const u8, public_key: PublicKey, salt_len: ?usize) !void { + var st = Verifier.init(self, public_key, salt_len orelse default_salt_len); + st.update(msg); + return st.verify(); + } + }; + + const PssT = @This(); + + pub const Signer = struct { + h: Hash, + key_pair: KeyPair, + salt: ?[]const u8, + + fn init(key_pair: KeyPair, salt: ?[]const u8) Signer { + return .{ + .h = Hash.init(.{}), + .key_pair = key_pair, + .salt = salt, + }; + } + + pub fn update(self: *Signer, data: []const u8) void { + self.h.update(data); + } + + pub fn finalize(self: *Signer, out: []u8) !PssT.Signature { + var hashed: [Hash.digest_length]u8 = undefined; + self.h.final(&hashed); + + const salt = if (self.salt) |s| s else brk: { + var res: [default_salt_len]u8 = undefined; + 
std.crypto.random.bytes(&res); + break :brk &res; + }; + + const em_bits = self.key_pair.public.modulus.bits() - 1; + const em = try emsaEncode(hashed, salt, em_bits, out); + try self.key_pair.encrypt(em, em); + return .{ .bytes = em }; + } + }; + + pub const Verifier = struct { + h: Hash, + sig: PssT.Signature, + public_key: PublicKey, + salt_len: usize, + + fn init(sig: PssT.Signature, public_key: PublicKey, salt_len: usize) Verifier { + return Verifier{ + .h = Hash.init(.{}), + .sig = sig, + .public_key = public_key, + .salt_len = salt_len, + }; + } + + pub fn update(self: *Verifier, data: []const u8) void { + self.h.update(data); + } + + pub fn verify(self: *Verifier) !void { + const pk = self.public_key; + const s = try Fe.fromBytes(pk.modulus, self.sig.bytes, .big); + const emm = try pk.modulus.powPublic(s, pk.public_exponent); + + var em_buf: [max_modulus_len]u8 = undefined; + const em_bits = pk.modulus.bits() - 1; + const em_len = std.math.divCeil(usize, em_bits, 8) catch unreachable; + var em = em_buf[0..em_len]; + try emm.toBytes(em, .big); + + if (em.len < Hash.digest_length + self.salt_len + 2) return error.Inconsistent; + if (em[em.len - 1] != 0xbc) return error.Inconsistent; + + const db = em[0 .. em.len - Hash.digest_length - 1]; + if (@clz(db[0]) < em.len * 8 - em_bits) return error.Inconsistent; + + const expected_hash = em[db.len..][0..Hash.digest_length]; + var mgf_buf: [max_modulus_len]u8 = undefined; + const db_mask = mgf1(Hash, expected_hash, mgf_buf[0..db.len]); + for (db, db_mask) |*v, m| v.* ^= m; + + for (1..db.len - self.salt_len - 1) |i| { + if (db[i] != 0) return error.Inconsistent; + } + if (db[db.len - self.salt_len - 1] != 1) return error.Inconsistent; + const salt = db[db.len - self.salt_len ..]; + var mp_buf: [max_modulus_len]u8 = undefined; + var mp = mp_buf[0 .. 
8 + Hash.digest_length + self.salt_len]; + @memset(mp[0..8], 0); + self.h.final(mp[8..][0..Hash.digest_length]); + @memcpy(mp[8 + Hash.digest_length ..][0..salt.len], salt); + + var actual_hash: [Hash.digest_length]u8 = undefined; + Hash.hash(mp, &actual_hash, .{}); + + if (!std.mem.eql(u8, expected_hash, &actual_hash)) return error.Inconsistent; + } + }; + + /// PSS Encrypted Message Signature Appendix + fn emsaEncode(msg_hash: [Hash.digest_length]u8, salt: []const u8, em_bits: usize, out: []u8) ![]u8 { + const em_len = std.math.divCeil(usize, em_bits, 8) catch unreachable; + + if (em_len < Hash.digest_length + salt.len + 2) return error.Encoding; + + // EM = maskedDB || H || 0xbc + var em = out[0..em_len]; + em[em.len - 1] = 0xbc; + + var mp_buf: [max_modulus_len]u8 = undefined; + // M' = (0x)00 00 00 00 00 00 00 00 || mHash || salt; + const mp = mp_buf[0 .. 8 + Hash.digest_length + salt.len]; + @memset(mp[0..8], 0); + @memcpy(mp[8..][0..Hash.digest_length], &msg_hash); + @memcpy(mp[8 + Hash.digest_length ..][0..salt.len], salt); + + // H = Hash(M') + const hash = em[em.len - 1 - Hash.digest_length ..][0..Hash.digest_length]; + Hash.hash(mp, hash, .{}); + + // DB = PS || 0x01 || salt + var db = em[0 .. em_len - Hash.digest_length - 1]; + @memset(db[0 .. db.len - salt.len - 1], 0); + db[db.len - salt.len - 1] = 1; + @memcpy(db[db.len - salt.len ..], salt); + + var mgf_buf: [max_modulus_len]u8 = undefined; + const db_mask = mgf1(Hash, hash, mgf_buf[0..db.len]); + for (db, db_mask) |*v, m| v.* ^= m; + + // Set the leftmost 8emLen - emBits bits of the leftmost octet + // in maskedDB to zero. + const shift = std.math.comptimeMod(8 * em_len - em_bits, 8); + const mask = @as(u8, 0xff) >> shift; + db[0] &= mask; + + return em; + } + }; +} + +/// Mask generation function. Currently the only one defined. 
/// MGF1 mask generation function (RFC 8017, appendix B.2.1) — currently
/// the only MGF defined. Fills `out` with Hash(seed || counter_be32) for
/// counter = 0, 1, ... and returns `out`.
fn mgf1(comptime Hash: type, seed: []const u8, out: []u8) []u8 {
    var counter_bytes: [4]u8 = undefined;
    var digest: [Hash.digest_length]u8 = undefined;

    var offset: usize = 0;
    var counter: u32 = 0;
    while (offset < out.len) : (counter += 1) {
        std.mem.writeInt(u32, &counter_bytes, counter, .big);

        var hasher = Hash.init(.{});
        hasher.update(seed);
        hasher.update(&counter_bytes);
        hasher.final(&digest);

        // Copy a full digest, or whatever tail of `out` remains.
        const take = @min(Hash.digest_length, out.len - offset);
        @memcpy(out[offset..][0..take], digest[0..take]);
        offset += take;
    }

    return out;
}
+inline fn labelHash(comptime Hash: type, label: []const u8) [Hash.digest_length]u8 { + if (label.len == 0) { + // magic constants from NIST + const sha2 = std.crypto.hash.sha2; + switch (Hash) { + std.crypto.hash.Sha1 => return hexToBytes( + \\da39a3ee 5e6b4b0d 3255bfef 95601890 + \\afd80709 + ), + sha2.Sha256 => return hexToBytes( + \\e3b0c442 98fc1c14 9afbf4c8 996fb924 + \\27ae41e4 649b934c a495991b 7852b855 + ), + sha2.Sha384 => return hexToBytes( + \\38b060a7 51ac9638 4cd9327e b1b1e36a + \\21fdb711 14be0743 4c0cc7bf 63f6e1da + \\274edebf e76f65fb d51ad2f1 4898b95b + ), + sha2.Sha512 => return hexToBytes( + \\cf83e135 7eefb8bd f1542850 d66d8007 + \\d620e405 0b5715dc 83f4a921 d36ce9ce + \\47d0d13c 5d85f2b0 ff8318d2 877eec2f + \\63b931bd 47417a81 a538327a f927da3e + ), + // just use the empty hash... + else => {}, + } + } + var res: [Hash.digest_length]u8 = undefined; + Hash.hash(label, &res, .{}); + return res; +} + +const ct = if (std.options.side_channels_mitigations == .none) ct_unprotected else ct_protected; + +const ct_unprotected = struct { + fn lastIndexOfScalar(slice: []const u8, value: u8) ?usize { + return std.mem.lastIndexOfScalar(u8, slice, value); + } + + fn indexOfScalarPos(slice: []const u8, start_index: usize, value: u8) ?usize { + return std.mem.indexOfScalarPos(u8, slice, start_index, value); + } + + fn memEql(a: []const u8, b: []const u8) bool { + return std.mem.eql(u8, a, b); + } + + fn @"and"(a: bool, b: bool) bool { + return a and b; + } + + fn @"or"(a: bool, b: bool) bool { + return a or b; + } +}; + +const ct_protected = struct { + fn lastIndexOfScalar(slice: []const u8, value: u8) ?usize { + var res: ?usize = null; + var i: usize = slice.len; + while (i != 0) { + i -= 1; + if (@intFromBool(res == null) & @intFromBool(slice[i] == value) == 1) res = i; + } + return res; + } + + fn indexOfScalarPos(slice: []const u8, start_index: usize, value: u8) ?usize { + var res: ?usize = null; + for (slice[start_index..], start_index..) 
|c, j| { + if (c == value) res = j; + } + return res; + } + + fn memEql(a: []const u8, b: []const u8) bool { + var res: u1 = 1; + for (a, b) |a_elem, b_elem| { + res &= @intFromBool(a_elem == b_elem); + } + return res == 1; + } + + fn @"and"(a: bool, b: bool) bool { + return (@intFromBool(a) & @intFromBool(b)) == 1; + } + + fn @"or"(a: bool, b: bool) bool { + return (@intFromBool(a) | @intFromBool(b)) == 1; + } +}; + +test ct { + const c = ct_unprotected; + try std.testing.expectEqual(true, c.@"or"(true, false)); + try std.testing.expectEqual(true, c.@"and"(true, true)); + try std.testing.expectEqual(true, c.memEql("Asdf", "Asdf")); + try std.testing.expectEqual(false, c.memEql("asdf", "Asdf")); + try std.testing.expectEqual(3, c.indexOfScalarPos("asdff", 1, 'f')); + try std.testing.expectEqual(4, c.lastIndexOfScalar("asdff", 'f')); +} + +fn removeNonHex(comptime hex: []const u8) []const u8 { + var res: [hex.len]u8 = undefined; + var i: usize = 0; + for (hex) |c| { + if (std.ascii.isHex(c)) { + res[i] = c; + i += 1; + } + } + return res[0..i]; +} + +/// For readable copy/pasting from hex viewers. 
+fn hexToBytes(comptime hex: []const u8) [removeNonHex(hex).len / 2]u8 { + const hex2 = comptime removeNonHex(hex); + comptime var res: [hex2.len / 2]u8 = undefined; + _ = comptime std.fmt.hexToBytes(&res, hex2) catch unreachable; + return res; +} + +test hexToBytes { + const hex = + \\e3b0c442 98fc1c14 9afbf4c8 996fb924 + \\27ae41e4 649b934c a495991b 7852b855 + ; + try std.testing.expectEqual( + [_]u8{ + 0xe3, 0xb0, 0xc4, 0x42, 0x98, 0xfc, 0x1c, 0x14, + 0x9a, 0xfb, 0xf4, 0xc8, 0x99, 0x6f, 0xb9, 0x24, + 0x27, 0xae, 0x41, 0xe4, 0x64, 0x9b, 0x93, 0x4c, + 0xa4, 0x95, 0x99, 0x1b, 0x78, 0x52, 0xb8, 0x55, + }, + hexToBytes(hex), + ); +} + +const TestHash = std.crypto.hash.sha2.Sha256; +fn testKeypair() !KeyPair { + const keypair_bytes = @embedFile("testdata/id_rsa.der"); + const kp = try KeyPair.fromDer(keypair_bytes); + try std.testing.expectEqual(2048, kp.public.modulus.bits()); + return kp; +} + +test "rsa PKCS1-v1_5 encrypt and decrypt" { + const kp = try testKeypair(); + + const msg = "rsa PKCS1-v1_5 encrypt and decrypt"; + var out: [max_modulus_len]u8 = undefined; + const enc = try kp.public.encryptPkcsv1_5(msg, &out); + + var out2: [max_modulus_len]u8 = undefined; + const dec = try kp.decryptPkcsv1_5(enc, &out2); + + try std.testing.expectEqualSlices(u8, msg, dec); +} + +test "rsa OAEP encrypt and decrypt" { + const kp = try testKeypair(); + + const msg = "rsa OAEP encrypt and decrypt"; + const label = ""; + var out: [max_modulus_len]u8 = undefined; + const enc = try kp.public.encryptOaep(TestHash, msg, label, &out); + + var out2: [max_modulus_len]u8 = undefined; + const dec = try kp.decryptOaep(TestHash, enc, label, &out2); + + try std.testing.expectEqualSlices(u8, msg, dec); +} + +test "rsa PKCS1-v1_5 signature" { + const kp = try testKeypair(); + + const msg = "rsa PKCS1-v1_5 signature"; + var out: [max_modulus_len]u8 = undefined; + + const signature = try kp.signPkcsv1_5(TestHash, msg, &out); + try signature.verify(msg, kp.public); +} + +test "rsa PSS 
signature" { + const kp = try testKeypair(); + + const msg = "rsa PSS signature"; + var out: [max_modulus_len]u8 = undefined; + + const salts = [_][]const u8{ "asdf", "" }; + for (salts) |salt| { + const signature = try kp.signOaep(TestHash, msg, salt, &out); + try signature.verify(msg, kp.public, salt.len); + } + + const signature = try kp.signOaep(TestHash, msg, null, &out); // random salt + try signature.verify(msg, kp.public, null); +} diff --git a/src/http/async/tls.zig/rsa/testdata/id_rsa.der b/src/http/async/tls.zig/rsa/testdata/id_rsa.der new file mode 100644 index 0000000000000000000000000000000000000000..9e4f1334d16264ca8accea1d9f7212da6a14554a GIT binary patch literal 1191 zcmV;Y1X%kpf&`-i0RRGm0RaHSfHnrY=SOP@lmzUjwvhxs_maFB?)!ao*QgBu9(zkV zO6Cvfz;XO@=K@R(y!5@%9XV^da7IcK=}P!L^Wh0uRC~!)`#~+Ec2W`H^W1lAs#7;^ z$~x@6!>YGCG1Y9gQk;O8yvg7w7~%`}_@Fxd7X(nA&Uw9`Iq~Xg>_?X_gAcXJmEM)1 z<^&?u?!HoaRH5g;iiY+^Z4I9ml^RU`K|Im+mAD~hY(e&`u&UtVtGUCd4b^x|AUT0|5X50)hbmVS71J9cG^gdO7qra{M0j{KnM;d)X4|Dh$$5 zpuK)<(|%N=kJIS5RL4bJb#zDd;>Vn{)X3fy;mX=(OjWQ=^8)r|>p{8`>I2BL7G&f4 zTJ9uTTbuWvdjpoOPq1k)g+g{=Rb*1Z;Uaedc>>;gWfTSv__%efoKYgOC)y9G5#U>) zX4nJos0`bTLHEODsw7TAd zZaNEJYbv|;_HfEL^^x7ZFckyx4;#r{Io*=6Z9}3uriSet?;rJ}1Dtq0BZU5Hi zQ2$bCHjw`G*=@EQ{sO+TQM3ZngpvmKZIi)B<^5nTWPodhAR8k0Hx zRF$oWcryh&arHHJi6s7g-`0Dx>>O0zV}G!2mkbOQ-yu{=4O|BV%p8Sr{rQ_^#~t`> zf?9e9bV^_}9}jMUVFxw}5)x&pGh02vD*60E<8!x0uv>GxJM5R&>2$QQ4Tt~%+};}5 z0)c=)%%UKzS~2V}a|gnxjvGVeIuV_4^!R_w=;DyFwofV-Q6uIo%d2X5fjbKeVWjqAX@$;2c6-1ooGbC%q?qgq z&ubOB)zcw4OL0tehSLK07lV!!1Rl;&ya=HJfq*-h@&u~`pTO^q@QyjGbrp)>GZQdH zcz7kX1Vj5y4#LIWXWtXaiHK{S62>>_jGa;(zih-YEkAhHYt}#X2GJeI?Z^wED){4A z(W%+sitXpW`9b2KDu2r)_+LLHh&*l-n*lh1Z!PcJ-#DyiQth7%Jy~5wD5a_T5y=1u F(f{FrO=AE6 literal 0 HcmV?d00001 diff --git a/src/http/async/tls.zig/testdata/ec_prime256v1_private_key.pem b/src/http/async/tls.zig/testdata/ec_prime256v1_private_key.pem new file mode 100644 index 00000000..67ebf388 --- /dev/null +++ 
b/src/http/async/tls.zig/testdata/ec_prime256v1_private_key.pem @@ -0,0 +1,5 @@ +-----BEGIN EC PRIVATE KEY----- +MHcCAQEEINJSRKv8kSKEzLHptfAlg+LGh4/pHHlq0XLf30Q9pcztoAoGCCqGSM49 +AwEHoUQDQgAEJpmLyp8aGCgyMcFIJaIq/+4V1K6nPpeoih3bT2npeplF9eyXj7rm +8eW9Ua6VLhq71mqtMC+YLm+IkORBVq1cuA== +-----END EC PRIVATE KEY----- diff --git a/src/http/async/tls.zig/testdata/ec_private_key.pem b/src/http/async/tls.zig/testdata/ec_private_key.pem new file mode 100644 index 00000000..95048aaa --- /dev/null +++ b/src/http/async/tls.zig/testdata/ec_private_key.pem @@ -0,0 +1,6 @@ +-----BEGIN PRIVATE KEY----- +MIG2AgEAMBAGByqGSM49AgEGBSuBBAAiBIGeMIGbAgEBBDAQNT3KGxUdBqpxuO/z +GSJDePMgmB6xLytkfnHQMCqQquXrmcOQZT3BJhm+PwggmwGhZANiAATKxBc6kfqA +piA+Z0rIjVwaZaBNGnP4UZ5TqVewQ/dP9/BQCca2SJpsXauGLcUPmK4sKFQxGe6d +fzq9O50lo7qHEOIpwDBdhRp+oqB6sN2hMtCPbp6eyzsUlm3FUyhN9D0= +-----END PRIVATE KEY----- diff --git a/src/http/async/tls.zig/testdata/ec_secp384r1_private_key.pem b/src/http/async/tls.zig/testdata/ec_secp384r1_private_key.pem new file mode 100644 index 00000000..62eac9ee --- /dev/null +++ b/src/http/async/tls.zig/testdata/ec_secp384r1_private_key.pem @@ -0,0 +1,6 @@ +-----BEGIN EC PRIVATE KEY----- +MIGkAgEBBDDubYpeDdOwxksyQIDiOt6LHt3ikts2HNuR6rqhBg1CLdmp3AVDKfF4 +fPkIr8UDH22gBwYFK4EEACKhZANiAARcVFUVv3bIHS6BEfLt98rtps7XP1y26m2n +v5x/5ecbDH2p7AXBYerJERKFi7ZFE1DSrSAj+KK8otjdEG44ZA2Mtl5AHwDVrKde +RgtavVoreHhLN80jJOun8JnFXQjdNsA= +-----END EC PRIVATE KEY----- diff --git a/src/http/async/tls.zig/testdata/ec_secp521r1_private_key.pem b/src/http/async/tls.zig/testdata/ec_secp521r1_private_key.pem new file mode 100644 index 00000000..5b7f9321 --- /dev/null +++ b/src/http/async/tls.zig/testdata/ec_secp521r1_private_key.pem @@ -0,0 +1,7 @@ +-----BEGIN EC PRIVATE KEY----- +MIHcAgEBBEIB8C9axyQY6mgjjC6htLjc8hGylrDsh4BCv9669JaDj5vbxmCnTNlg +OuS6C9+uJNMbwm6CoIjB7RcgDTrxxX7oCyegBwYFK4EEACOhgYkDgYYABABAT5Q8 +aOj9U0iuJE5tXfKnYTgPuvD6keHZAGJ5veM9uR6jr3BhfGubD6bnlD+cIBQzYWo0 
+y/BNMzCRJ55PDCNU5gGLw+vkwhJ1lGF5OS6l2oG5WN3fe6cYo+uJD7+PB3WYNIuX +Ls0oidsEM0Q4WLblQOEP6VLGf4qTcZyhoFWYfkjWiw== +-----END EC PRIVATE KEY----- diff --git a/src/http/async/tls.zig/testdata/google.com/client_random b/src/http/async/tls.zig/testdata/google.com/client_random new file mode 100644 index 00000000..e817c906 --- /dev/null +++ b/src/http/async/tls.zig/testdata/google.com/client_random @@ -0,0 +1 @@ +'”’ßqp0x­0)ì©–Ã~Ì+Œ`‡¬tY4•©D_ \ No newline at end of file diff --git a/src/http/async/tls.zig/testdata/google.com/server_hello b/src/http/async/tls.zig/testdata/google.com/server_hello new file mode 100644 index 0000000000000000000000000000000000000000..57a807650723010f00d7b371701d3004a55bc685 GIT binary patch literal 7158 zcmb_hc|4Tu*S}{W8M0;#V~vsN9%~CJm8FoSlBH$LW-?~R%rGg1A+ogED)ppN+EH4y zsvc5NPy0`0i54npQHi|QJ!Zy`=Y2o#=kxOi_xGH0o$K7^x~}iJ&P@k{fkMy_6pX>p zorU$GtTg;h4}3p%T1cZp<_Yf6$8lfgEC})U3Yvq$XWpA1O~lMwl}*ZPGQWH*Bvy%1 zcYsxO+aM+TMt8uL#aAao5N-ekp}-#qje>MA7=v&eWDo)wEQHkN!y+{=STt^OF$Rr7 zqt!!v6l0IH*N@k1b7kp%(>{%wRkM zl(qD|I2;CxhF{2w;|uV?G+sQDLgV9oeP@%jU=73uqowS%Fc?337M?WQ0XEiEwReOa zNzSkXX^N9Wm>9aiQ^n9e4Av!$hqc7RR8$B=hS)ig!ij-JC^Pw(Pzn%6gi?cmTp2Aw zp`h)NwV|O4*X0FvH-8mgC|76qJvi8zw|Zyd0gKtY$^PnvCEEwKtPuWM;<5nUY;HZT zGQ+>+O?7>X+S=s)EIQuX)%D|^oGlNdPVCdYn30@h@VeReqxO%WgPjWfQ4y~ea9u2n z{V!J@iyl?^<^i^da(93B%G*a>KSflRDPAW~SIII8?8Js{(8gj7B!jz<~h4``@l z0U2yhZk%A41~;^K$s#$x5{(er0P5xh0GODqR2!xO@PD=w$>lAAuA*(U^vW5D#HmuD&mp7 z4BV&^i1JW_9K+@EI4fxs0oc$C!~(q6pF(*%NkGHxBtqT6H?$+O$*csBaap$XrLXy? 
zN-_rILqh-349mv+snQsnCmv{|*0S}G7fS4q-SLN=ym0)WKMq&GjA5>(No{qdQIAw; zQaN#vEEokxz56748fI^ZbV#n8z7I36vA z#ui9DgEjmY4UQ2nAW6-@*^-Bd&VIN~266T=RB_=#dwcGASTdunl8q z0v_O1d8zVX7R1zwurY;80Ywu@j94_NwiiDm{735I)g;nrCQ_R~ zs9^m@iWCySb%!PzgDe2@U?9sjiwTYlUw-urLEWLTI9wV}mVRk&*rEwCC4UsHp5(eD z^*%JtKsJs@iQ;i&=NVZ_Js~X;xt|SAIur|AWJKXm0kcH{CLTo%nTv>E0Ry`tCMH-X7+g7vCl}kL+e$8*6dvGr*^P7XRyuh#b=E=w zJa^S{&uDP{V}U*cJLCe1fe#yn;N)75EcV@!E@23&|$B}aO!tz+!;b0Q;+cDw|#T;$?sny$yu2% z20Ax>=||(WC`IjC7cHCT-x1s4v8m|fK)t1DavR0Df%2Ed(TeejU|#R?SJD4D=;f%^ zz7pNL8}6JCgu1=>#iW#HRacHqeVWFX%O0+rMmy5);_Sn-92(YTIp+t={KBT)phwE4D^Md z2SY*d6|#q6tWsWMjf?-s>-!znIvV0Dvsd}LNgM|0}I?$GN4q}WRO z`bm+tDu>(h=TG&2WtjM^y^PrP%6Z#6*VCG%+0}cRtS{bn;{UyCP0{_UGdz z>C|uK7~d+i$X#_XulVQ_ZJQ-gH3PP6;(`aS_g!z(2~*x%)&6Glx&_m4p{EX86_(<& z)`hQmk-&I&QD?TNDJH7=$(?IsTLnvCqwn`pA1pswGH*hyvXjv1?*+MgHX1v$*O}u< zGiUPp6Pg1!cJJEPymZC>jt}NRW#;eaA>q+jovp@0DVNtV6XpwE3FM{g3SwBX&ll@ z2XQz8d=N?8OT+LKk^>BW9PN=_8g$ZPguwq=&;Eb4-rqJfy@|{*cKPAru_cL>aqkyR zWZD1rc67-=%)2o2)*U?)-fhF|&`KPuoYm{`c-dl$BdZjM8-!`-$1uFg(*jOjT;cgWy=;Hd+M8x3xcT`rk))9sM zcF%EVqP0|V^lfzuwx9o$-BUcV>ddQ+ftxR#s41$*l(yceWIz&$!&`4u`e0veiPeF% zrF||2C3`}P7k01-SArEmFAl!c$hG35cG~i-cw3pRRjh%n_r&#+gXojGK4ogf&VZw& z{y{4q0f)g>QVd#ggMVwqc_bK?*`d=d{`XFtM-n&Uj)R1pGDy||us=Lyh*SYZ6JTp;2xyef1jH8~d^V7AGx9Zn&lGTY36te$UP}1r z5ho2J0?Eu=dki0;;(B@Kp^$vT$`idu5Ab3uo%8#fW@=!3t39==v@MQ)3WzmpJbr_n z)X=?>Q=4k|v_5lr1}o#-VU_(wdn5C5^P9p?(DiZXg1`Dszfot-Sbf{mXSrv~tcoYw z{b-I*vVr@8J**@jC?$AS$u?7w$e?@woNz*brk3BeVCOF$N@WXsHwl9csZ%EEj=i+x zmSN&lyYMgd4?h$(YGba94Yn|;R>3B0xB5)(4i4(~$T6hbx{u#mHPBP^DrF~1F*ei5 z+TumarA~dj7Vq0`p{aqDg1>hs{#}tCUUtXv#)CT#Q#A6SyBR)9isZYOTbE36oTI6&V27?+`w-;Unur1`f5AQL3n+>M|HPQ ztB;?Y*J)Da;$Dv?(98Ko_6O`)_T`(JZca@I{!52Fjc$pX~_fC5xWKsk>h8Nxg>p9 z7vWxIEp;yzCz8wxW-{1#cDx;*Oi~vgb}UY7jF^&V&mtMZ`h!HRDQ7HEO8DBDG%>QZIQ}MGt&8Xt%&BB*m=Rfa%cTqEQS?Q6Wqd&|i-Y%)w zMQMDpW8Ej^Wt~w6zy2NOrRt~d@XEhJYk#dCYySJbv0IvUQ*A8G$Ly{A$Bnq6x#6F& z8C7?^KlxQxuilVVb)g|t|8d#o@-gLqk-A0Kt~L=5#<29qoqoE`%3v@{>cyE?aS@k# 
zGEOGeG@@MzQZ`uWYLpPt$@w;}%Ohncqh z!D{8DkEW|uWbMPeBptsa?BJ}K@wjQLnda-uy-)4?)G6A5yq8NK>*QXtiE`+gMt?sh z%c@E1NeFG##Fxu_w4r@a(#}JsXAW27mmfHiX>N#My)pTcNYE8JIZUl}oOiS3RXUfK7SvK3cyE@f~1s5AfV z^uwoI-dvwkvhFGCNsZ^XZAEV$nO>%E)6ZMPdT4V}^%_0ARp{ew{bBbL0|#F4DUnk{ zWcDqK>;q?~s#L8b%(@cP(6_JPOtr`z{C^2GNd-b#V<02^cYnS{waq_!33G9uvYjru zEO-IYI+>@gPvbtx_xmC6Rl=emW5ZwqYagSg8a;cPzv6O@(7QZa^(DuWKhV?T>t8R4 z$lZ3Nc*glt4qxA6fTv#IX1w(%BQd1k&iGK;x7PEe zTDux_pzXi>4mm7$KIED>p*XAVqtC79)*Dy1nKcwVc`~)HCuzRrn=BZ?s zAKfdyOvQRkym~#=tsM6#!lCNX(jEnO?1lob&!rn2M7s>`B2}, ", .{b}); + // if (i % 16 == 0) + // std.debug.print("\n", .{}); + // } + // std.debug.print("}};\n", .{}); + + std.debug.print("const {s} = \"", .{var_name}); + const charset = "0123456789abcdef"; + for (buf) |b| { + const x = charset[b >> 4]; + const y = charset[b & 15]; + std.debug.print("{c}{c} ", .{ x, y }); + } + std.debug.print("\"\n", .{}); +} + +const random_instance = std.Random{ .ptr = undefined, .fillFn = randomFillFn }; +var random_seed: u8 = 0; + +pub fn randomFillFn(_: *anyopaque, buf: []u8) void { + for (buf) |*v| { + v.* = random_seed; + random_seed +%= 1; + } +} + +pub fn random(seed: u8) std.Random { + random_seed = seed; + return random_instance; +} + +// Fill buf with 0,1,..ff,0,... 
+pub fn fill(buf: []u8) void { + fillFrom(buf, 0); +} + +pub fn fillFrom(buf: []u8, start: u8) void { + var i: u8 = start; + for (buf) |*v| { + v.* = i; + i +%= 1; + } +} + +pub const Stream = struct { + output: std.io.FixedBufferStream([]u8) = undefined, + input: std.io.FixedBufferStream([]const u8) = undefined, + + pub fn init(input: []const u8, output: []u8) Stream { + return .{ + .input = std.io.fixedBufferStream(input), + .output = std.io.fixedBufferStream(output), + }; + } + + pub const ReadError = error{}; + pub const WriteError = error{NoSpaceLeft}; + + pub fn write(self: *Stream, buf: []const u8) !usize { + return try self.output.writer().write(buf); + } + + pub fn writeAll(self: *Stream, buffer: []const u8) !void { + var n: usize = 0; + while (n < buffer.len) { + n += try self.write(buffer[n..]); + } + } + + pub fn read(self: *Stream, buffer: []u8) !usize { + return self.input.read(buffer); + } +}; + +// Copied from: https://github.com/clickingbuttons/zig/blob/f1cea91624fd2deae28bfb2414a4fd9c7e246883/lib/std/crypto/rsa.zig#L791 +/// For readable copy/pasting from hex viewers. 
+pub fn hexToBytes(comptime hex: []const u8) [removeNonHex(hex).len / 2]u8 { + @setEvalBranchQuota(1000 * 100); + const hex2 = comptime removeNonHex(hex); + comptime var res: [hex2.len / 2]u8 = undefined; + _ = comptime std.fmt.hexToBytes(&res, hex2) catch unreachable; + return res; +} + +fn removeNonHex(comptime hex: []const u8) []const u8 { + @setEvalBranchQuota(1000 * 100); + var res: [hex.len]u8 = undefined; + var i: usize = 0; + for (hex) |c| { + if (std.ascii.isHex(c)) { + res[i] = c; + i += 1; + } + } + return res[0..i]; +} + +test hexToBytes { + const hex = + \\e3b0c442 98fc1c14 9afbf4c8 996fb924 + \\27ae41e4 649b934c a495991b 7852b855 + ; + try std.testing.expectEqual( + [_]u8{ + 0xe3, 0xb0, 0xc4, 0x42, 0x98, 0xfc, 0x1c, 0x14, + 0x9a, 0xfb, 0xf4, 0xc8, 0x99, 0x6f, 0xb9, 0x24, + 0x27, 0xae, 0x41, 0xe4, 0x64, 0x9b, 0x93, 0x4c, + 0xa4, 0x95, 0x99, 0x1b, 0x78, 0x52, 0xb8, 0x55, + }, + hexToBytes(hex), + ); +} diff --git a/src/http/async/tls.zig/transcript.zig b/src/http/async/tls.zig/transcript.zig new file mode 100644 index 00000000..59c94986 --- /dev/null +++ b/src/http/async/tls.zig/transcript.zig @@ -0,0 +1,297 @@ +const std = @import("std"); +const crypto = std.crypto; +const tls = crypto.tls; +const hkdfExpandLabel = tls.hkdfExpandLabel; + +const Sha256 = crypto.hash.sha2.Sha256; +const Sha384 = crypto.hash.sha2.Sha384; +const Sha512 = crypto.hash.sha2.Sha512; + +const HashTag = @import("cipher.zig").CipherSuite.HashTag; + +// Transcript holds hash of all handshake message. +// +// Until the server hello is parsed we don't know which hash (sha256, sha384, +// sha512) will be used so we update all of them. Handshake process will set +// `selected` field once cipher suite is known. Other function will use that +// selected hash. We continue to calculate all hashes because client certificate +// message could use different hash than the other part of the handshake. +// Handshake hash is dictated by the server selected cipher. 
Client certificate +// hash is dictated by the private key used. +// +// Most of the functions are inlined because they are returning pointers. +// +pub const Transcript = struct { + sha256: Type(.sha256) = .{ .hash = Sha256.init(.{}) }, + sha384: Type(.sha384) = .{ .hash = Sha384.init(.{}) }, + sha512: Type(.sha512) = .{ .hash = Sha512.init(.{}) }, + + tag: HashTag = .sha256, + + pub const max_mac_length = Type(.sha512).mac_length; + + // Transcript Type from hash tag + fn Type(h: HashTag) type { + return switch (h) { + .sha256 => TranscriptT(Sha256), + .sha384 => TranscriptT(Sha384), + .sha512 => TranscriptT(Sha512), + }; + } + + /// Set hash to use in all following function calls. + pub fn use(t: *Transcript, tag: HashTag) void { + t.tag = tag; + } + + pub fn update(t: *Transcript, buf: []const u8) void { + t.sha256.hash.update(buf); + t.sha384.hash.update(buf); + t.sha512.hash.update(buf); + } + + // tls 1.2 handshake specific + + pub inline fn masterSecret( + t: *Transcript, + pre_master_secret: []const u8, + client_random: [32]u8, + server_random: [32]u8, + ) []const u8 { + return switch (t.tag) { + inline else => |h| &@field(t, @tagName(h)).masterSecret( + pre_master_secret, + client_random, + server_random, + ), + }; + } + + pub inline fn keyMaterial( + t: *Transcript, + master_secret: []const u8, + client_random: [32]u8, + server_random: [32]u8, + ) []const u8 { + return switch (t.tag) { + inline else => |h| &@field(t, @tagName(h)).keyExpansion( + master_secret, + client_random, + server_random, + ), + }; + } + + pub fn clientFinishedTls12(t: *Transcript, master_secret: []const u8) [12]u8 { + return switch (t.tag) { + inline else => |h| @field(t, @tagName(h)).clientFinishedTls12(master_secret), + }; + } + + pub fn serverFinishedTls12(t: *Transcript, master_secret: []const u8) [12]u8 { + return switch (t.tag) { + inline else => |h| @field(t, @tagName(h)).serverFinishedTls12(master_secret), + }; + } + + // tls 1.3 handshake specific + + pub inline fn 
serverCertificateVerify(t: *Transcript) []const u8 { + return switch (t.tag) { + inline else => |h| &@field(t, @tagName(h)).serverCertificateVerify(), + }; + } + + pub inline fn clientCertificateVerify(t: *Transcript) []const u8 { + return switch (t.tag) { + inline else => |h| &@field(t, @tagName(h)).clientCertificateVerify(), + }; + } + + pub fn serverFinishedTls13(t: *Transcript, buf: []u8) []const u8 { + return switch (t.tag) { + inline else => |h| @field(t, @tagName(h)).serverFinishedTls13(buf), + }; + } + + pub fn clientFinishedTls13(t: *Transcript, buf: []u8) []const u8 { + return switch (t.tag) { + inline else => |h| @field(t, @tagName(h)).clientFinishedTls13(buf), + }; + } + + pub const Secret = struct { + client: []const u8, + server: []const u8, + }; + + pub inline fn handshakeSecret(t: *Transcript, shared_key: []const u8) Secret { + return switch (t.tag) { + inline else => |h| @field(t, @tagName(h)).handshakeSecret(shared_key), + }; + } + + pub inline fn applicationSecret(t: *Transcript) Secret { + return switch (t.tag) { + inline else => |h| @field(t, @tagName(h)).applicationSecret(), + }; + } + + // other + + pub fn Hkdf(h: HashTag) type { + return Type(h).Hkdf; + } + + /// Copy of the current hash value + pub inline fn hash(t: *Transcript, comptime Hash: type) Hash { + return switch (Hash) { + Sha256 => t.sha256.hash, + Sha384 => t.sha384.hash, + Sha512 => t.sha512.hash, + else => @compileError("unimplemented"), + }; + } +}; + +fn TranscriptT(comptime Hash: type) type { + return struct { + const Hmac = crypto.auth.hmac.Hmac(Hash); + const Hkdf = crypto.kdf.hkdf.Hkdf(Hmac); + const mac_length = Hmac.mac_length; + + hash: Hash, + handshake_secret: [Hmac.mac_length]u8 = undefined, + server_finished_key: [Hmac.key_length]u8 = undefined, + client_finished_key: [Hmac.key_length]u8 = undefined, + + const Self = @This(); + + fn init(transcript: Hash) Self { + return .{ .transcript = transcript }; + } + + fn serverCertificateVerify(c: *Self) [64 + 34 + 
Hash.digest_length]u8 { + return ([1]u8{0x20} ** 64) ++ + "TLS 1.3, server CertificateVerify\x00".* ++ + c.hash.peek(); + } + + // ref: https://www.rfc-editor.org/rfc/rfc8446#section-4.4.3 + fn clientCertificateVerify(c: *Self) [64 + 34 + Hash.digest_length]u8 { + return ([1]u8{0x20} ** 64) ++ + "TLS 1.3, client CertificateVerify\x00".* ++ + c.hash.peek(); + } + + fn masterSecret( + _: *Self, + pre_master_secret: []const u8, + client_random: [32]u8, + server_random: [32]u8, + ) [mac_length * 2]u8 { + const seed = "master secret" ++ client_random ++ server_random; + + var a1: [mac_length]u8 = undefined; + var a2: [mac_length]u8 = undefined; + Hmac.create(&a1, seed, pre_master_secret); + Hmac.create(&a2, &a1, pre_master_secret); + + var p1: [mac_length]u8 = undefined; + var p2: [mac_length]u8 = undefined; + Hmac.create(&p1, a1 ++ seed, pre_master_secret); + Hmac.create(&p2, a2 ++ seed, pre_master_secret); + + return p1 ++ p2; + } + + fn keyExpansion( + _: *Self, + master_secret: []const u8, + client_random: [32]u8, + server_random: [32]u8, + ) [mac_length * 4]u8 { + const seed = "key expansion" ++ server_random ++ client_random; + + const a0 = seed; + var a1: [mac_length]u8 = undefined; + var a2: [mac_length]u8 = undefined; + var a3: [mac_length]u8 = undefined; + var a4: [mac_length]u8 = undefined; + Hmac.create(&a1, a0, master_secret); + Hmac.create(&a2, &a1, master_secret); + Hmac.create(&a3, &a2, master_secret); + Hmac.create(&a4, &a3, master_secret); + + var key_material: [mac_length * 4]u8 = undefined; + Hmac.create(key_material[0..mac_length], a1 ++ seed, master_secret); + Hmac.create(key_material[mac_length .. mac_length * 2], a2 ++ seed, master_secret); + Hmac.create(key_material[mac_length * 2 .. 
mac_length * 3], a3 ++ seed, master_secret); + Hmac.create(key_material[mac_length * 3 ..], a4 ++ seed, master_secret); + return key_material; + } + + fn clientFinishedTls12(self: *Self, master_secret: []const u8) [12]u8 { + const seed = "client finished" ++ self.hash.peek(); + var a1: [mac_length]u8 = undefined; + var p1: [mac_length]u8 = undefined; + Hmac.create(&a1, seed, master_secret); + Hmac.create(&p1, a1 ++ seed, master_secret); + return p1[0..12].*; + } + + fn serverFinishedTls12(self: *Self, master_secret: []const u8) [12]u8 { + const seed = "server finished" ++ self.hash.peek(); + var a1: [mac_length]u8 = undefined; + var p1: [mac_length]u8 = undefined; + Hmac.create(&a1, seed, master_secret); + Hmac.create(&p1, a1 ++ seed, master_secret); + return p1[0..12].*; + } + + // tls 1.3 + + inline fn handshakeSecret(self: *Self, shared_key: []const u8) Transcript.Secret { + const hello_hash = self.hash.peek(); + + const zeroes = [1]u8{0} ** Hash.digest_length; + const early_secret = Hkdf.extract(&[1]u8{0}, &zeroes); + const empty_hash = tls.emptyHash(Hash); + const hs_derived_secret = hkdfExpandLabel(Hkdf, early_secret, "derived", &empty_hash, Hash.digest_length); + + self.handshake_secret = Hkdf.extract(&hs_derived_secret, shared_key); + const client_secret = hkdfExpandLabel(Hkdf, self.handshake_secret, "c hs traffic", &hello_hash, Hash.digest_length); + const server_secret = hkdfExpandLabel(Hkdf, self.handshake_secret, "s hs traffic", &hello_hash, Hash.digest_length); + + self.server_finished_key = hkdfExpandLabel(Hkdf, server_secret, "finished", "", Hmac.key_length); + self.client_finished_key = hkdfExpandLabel(Hkdf, client_secret, "finished", "", Hmac.key_length); + + return .{ .client = &client_secret, .server = &server_secret }; + } + + inline fn applicationSecret(self: *Self) Transcript.Secret { + const handshake_hash = self.hash.peek(); + + const empty_hash = tls.emptyHash(Hash); + const zeroes = [1]u8{0} ** Hash.digest_length; + const ap_derived_secret 
= hkdfExpandLabel(Hkdf, self.handshake_secret, "derived", &empty_hash, Hash.digest_length); + const master_secret = Hkdf.extract(&ap_derived_secret, &zeroes); + + const client_secret = hkdfExpandLabel(Hkdf, master_secret, "c ap traffic", &handshake_hash, Hash.digest_length); + const server_secret = hkdfExpandLabel(Hkdf, master_secret, "s ap traffic", &handshake_hash, Hash.digest_length); + + return .{ .client = &client_secret, .server = &server_secret }; + } + + fn serverFinishedTls13(self: *Self, buf: []u8) []const u8 { + Hmac.create(buf[0..mac_length], &self.hash.peek(), &self.server_finished_key); + return buf[0..mac_length]; + } + + // client finished message with header + fn clientFinishedTls13(self: *Self, buf: []u8) []const u8 { + Hmac.create(buf[0..mac_length], &self.hash.peek(), &self.client_finished_key); + return buf[0..mac_length]; + } + }; +}