async: use zig-async-io

Signed-off-by: Francis Bouvier <francis@lightpanda.io>
This commit is contained in:
Francis Bouvier
2024-11-21 12:27:00 +01:00
parent 70752027f1
commit de286dd78e
48 changed files with 24 additions and 14125 deletions

3
.gitmodules vendored
View File

@@ -25,3 +25,6 @@
[submodule "vendor/tls.zig"]
path = vendor/tls.zig
url = git@github.com:ianic/tls.zig.git
[submodule "vendor/zig-async-io"]
path = vendor/zig-async-io
url = git@github.com:lightpanda-io/zig-async-io.git

View File

@@ -159,6 +159,11 @@ fn common(
netsurf.addImport("jsruntime", jsruntimemod);
step.root_module.addImport("netsurf", netsurf);
const asyncio = b.addModule("asyncio", .{
.root_source_file = b.path("vendor/zig-async-io/src/lib.zig"),
});
step.root_module.addImport("asyncio", asyncio);
const tlsmod = b.addModule("tls", .{
    .root_source_file = b.path("vendor/tls.zig/src/main.zig"),
});

View File

@@ -40,7 +40,7 @@ const storage = @import("../storage/storage.zig");
const FetchResult = @import("../http/Client.zig").Client.FetchResult;
const UserContext = @import("../user_context.zig").UserContext;
const HttpClient = @import("asyncio").Client;
const log = std.log.scoped(.browser);

View File

@@ -1,132 +0,0 @@
const std = @import("std");
const Ctx = @import("std/http/Client.zig").Ctx;
const Loop = @import("jsruntime").Loop;
const NetworkImpl = Loop.Network(SingleThreaded);
/// Blocking I/O backend: every operation completes synchronously and the
/// completion callback `cbk` is invoked before the function returns.
/// On failure the error is delivered to `cbk` AND recorded on the context
/// via `ctx.setErr` (matching `send`/`recv` behavior).
pub const Blocking = struct {
    /// Connect `socket` to `address`, then invoke `cbk`.
    /// On failure the socket is closed and only the error path runs.
    pub fn connect(
        _: *Blocking,
        comptime CtxT: type,
        ctx: *CtxT,
        comptime cbk: fn (ctx: *CtxT, res: anyerror!void) anyerror!void,
        socket: std.posix.socket_t,
        address: std.net.Address,
    ) void {
        std.posix.connect(socket, &address.any, address.getOsSockLen()) catch |err| {
            std.posix.close(socket);
            cbk(ctx, err) catch |e| {
                return ctx.setErr(e);
            };
            // Bug fix: previously control fell through here and invoked the
            // success callback below even though the connect failed (and the
            // socket was already closed). Mirror send/recv: record the error
            // and stop.
            return ctx.setErr(err);
        };
        cbk(ctx, {}) catch |e| ctx.setErr(e);
    }

    /// Write `buf` to `socket`; stores the number of bytes written on the
    /// context via `setLen`, then invokes `cbk`.
    pub fn send(
        _: *Blocking,
        comptime CtxT: type,
        ctx: *CtxT,
        comptime cbk: fn (ctx: *CtxT, res: anyerror!void) anyerror!void,
        socket: std.posix.socket_t,
        buf: []const u8,
    ) void {
        const len = std.posix.write(socket, buf) catch |err| {
            cbk(ctx, err) catch |e| {
                return ctx.setErr(e);
            };
            return ctx.setErr(err);
        };
        ctx.setLen(len);
        cbk(ctx, {}) catch |e| ctx.setErr(e);
    }

    /// Read from `socket` into `buf`; stores the number of bytes read on the
    /// context via `setLen`, then invokes `cbk`.
    pub fn recv(
        _: *Blocking,
        comptime CtxT: type,
        ctx: *CtxT,
        comptime cbk: fn (ctx: *CtxT, res: anyerror!void) anyerror!void,
        socket: std.posix.socket_t,
        buf: []u8,
    ) void {
        const len = std.posix.read(socket, buf) catch |err| {
            cbk(ctx, err) catch |e| {
                return ctx.setErr(e);
            };
            return ctx.setErr(err);
        };
        ctx.setLen(len);
        cbk(ctx, {}) catch |e| ctx.setErr(e);
    }
};
/// Event-loop-backed I/O backend. Each call stashes the caller's context and
/// callback on `self`, then hands the operation to the loop's NetworkImpl;
/// the loop later calls back into `onConnect`/`onSend`/`onReceive`.
/// NOTE(review): only one in-flight operation is supported at a time, since
/// `cbk`/`ctx` are single fields — confirm callers never overlap operations.
pub const SingleThreaded = struct {
    impl: NetworkImpl,
    cbk: Cbk,
    ctx: *Ctx,

    const Cbk = *const fn (ctx: *Ctx, res: anyerror!void) anyerror!void;

    /// Create a backend bound to `loop`. `cbk`/`ctx` stay undefined until
    /// the first operation is issued.
    pub fn init(loop: *Loop) SingleThreaded {
        return .{
            .impl = NetworkImpl.init(loop),
            .cbk = undefined,
            .ctx = undefined,
        };
    }

    /// Begin an async connect; completion arrives via `onConnect`.
    pub fn connect(
        self: *SingleThreaded,
        comptime _: type,
        ctx: *Ctx,
        comptime cbk: Cbk,
        socket: std.posix.socket_t,
        address: std.net.Address,
    ) void {
        self.ctx = ctx;
        self.cbk = cbk;
        self.impl.connect(self, socket, address);
    }

    /// Loop completion for `connect`.
    pub fn onConnect(self: *SingleThreaded, err: ?anyerror) void {
        if (err) |e| {
            return self.ctx.setErr(e);
        }
        self.cbk(self.ctx, {}) catch |e| self.ctx.setErr(e);
    }

    /// Begin an async send; completion arrives via `onSend`.
    pub fn send(
        self: *SingleThreaded,
        comptime _: type,
        ctx: *Ctx,
        comptime cbk: Cbk,
        socket: std.posix.socket_t,
        buf: []const u8,
    ) void {
        self.ctx = ctx;
        self.cbk = cbk;
        self.impl.send(self, socket, buf);
    }

    /// Loop completion for `send`; records the byte count before the callback.
    pub fn onSend(self: *SingleThreaded, ln: usize, err: ?anyerror) void {
        if (err) |e| {
            return self.ctx.setErr(e);
        }
        self.ctx.setLen(ln);
        self.cbk(self.ctx, {}) catch |e| self.ctx.setErr(e);
    }

    /// Begin an async receive; completion arrives via `onReceive`.
    pub fn recv(
        self: *SingleThreaded,
        comptime _: type,
        ctx: *Ctx,
        comptime cbk: Cbk,
        socket: std.posix.socket_t,
        buf: []u8,
    ) void {
        self.ctx = ctx;
        self.cbk = cbk;
        self.impl.receive(self, socket, buf);
    }

    /// Loop completion for `recv`; records the byte count before the callback.
    pub fn onReceive(self: *SingleThreaded, ln: usize, err: ?anyerror) void {
        if (err) |e| {
            return self.ctx.setErr(e);
        }
        self.ctx.setLen(ln);
        self.cbk(self.ctx, {}) catch |e| self.ctx.setErr(e);
    }
};

View File

@@ -1,3 +0,0 @@
const std = @import("std");
pub const Client = @import("std/http/Client.zig");

View File

@@ -1,95 +0,0 @@
const std = @import("std");
/// A singly-linked stack of function pointers with LIFO `pop` semantics.
/// The root node is owned by the caller (it may live on the stack); every
/// node appended via `push` is heap-allocated and freed when popped.
pub fn Stack(comptime T: type) type {
    return struct {
        const Self = @This();
        pub const Fn = *const T;

        next: ?*Self = null,
        func: Fn,

        /// Allocate a fresh node holding `func`. Caller owns the node.
        pub fn init(alloc: std.mem.Allocator, comptime func: Fn) !*Self {
            const node = try alloc.create(Self);
            node.* = .{ .func = func };
            return node;
        }

        /// Append `func` after the last node in the chain.
        pub fn push(self: *Self, alloc: std.mem.Allocator, comptime func: Fn) !void {
            var tail = self;
            while (tail.next) |n| tail = n;
            tail.next = try Self.init(alloc, func);
        }

        /// Remove and return the most recently pushed function.
        /// The popped node is freed, except when it is the root itself
        /// (i.e. the chain has a single node and `prev` is null).
        pub fn pop(self: *Self, alloc: std.mem.Allocator, prev: ?*Self) Fn {
            var before = prev;
            var tail = self;
            while (tail.next) |n| {
                before = tail;
                tail = n;
            }
            const func = tail.func;
            if (before) |p| tail.deinit(alloc, p);
            return func;
        }

        /// Free this node and every node after it, detaching from `prev`.
        pub fn deinit(self: *Self, alloc: std.mem.Allocator, prev: ?*Self) void {
            if (prev) |p| p.next = null;
            var node: ?*Self = self;
            while (node) |n| {
                const following = n.next;
                alloc.destroy(n);
                node = following;
            }
        }
    };
}
// Test fixtures: plain functions pushed onto the stack.
fn first() u8 {
    return 1;
}
fn second() u8 {
    return 2;
}

// Verifies LIFO ordering: the last pushed function pops first.
test "stack" {
    const alloc = std.testing.allocator;
    const TestStack = Stack(fn () u8);
    var stack = TestStack{ .func = first };
    try stack.push(alloc, second);
    const a = stack.pop(alloc, null);
    try std.testing.expect(a() == 2);
    const b = stack.pop(alloc, null);
    try std.testing.expect(b() == 1);
}

// Fixtures taking an opaque argument; the pointer is re-typed back to *u8.
fn first_op(arg: ?*anyopaque) u8 {
    const val = @as(*u8, @ptrCast(arg));
    return val.* + @as(u8, 1);
}
fn second_op(arg: ?*anyopaque) u8 {
    const val = @as(*u8, @ptrCast(arg));
    return val.* + @as(u8, 2);
}

// Same LIFO check but with functions receiving an opaque context pointer.
test "opaque stack" {
    const alloc = std.testing.allocator;
    const TestStack = Stack(fn (?*anyopaque) u8);
    var stack = TestStack{ .func = first_op };
    try stack.push(alloc, second_op);
    const a = stack.pop(alloc, null);
    var x: u8 = 5;
    try std.testing.expect(a(@as(*anyopaque, @ptrCast(&x))) == 2 + x);
    const b = stack.pop(alloc, null);
    var y: u8 = 3;
    try std.testing.expect(b(@as(*anyopaque, @ptrCast(&y))) == 1 + y);
}

View File

@@ -1,318 +0,0 @@
// Re-exports: vendored HTTP client/server plus std's head/chunk parsers.
pub const Client = @import("http/Client.zig");
pub const Server = @import("http/Server.zig");
pub const protocol = @import("http/protocol.zig");
pub const HeadParser = std.http.HeadParser;
pub const ChunkParser = std.http.ChunkParser;
pub const HeaderIterator = std.http.HeaderIterator;

/// Supported HTTP protocol versions.
pub const Version = enum {
    @"HTTP/1.0",
    @"HTTP/1.1",
};

/// https://developer.mozilla.org/en-US/docs/Web/HTTP/Methods
///
/// https://datatracker.ietf.org/doc/html/rfc7231#section-4 Initial definition
///
/// https://datatracker.ietf.org/doc/html/rfc5789#section-2 PATCH
pub const Method = enum(u64) {
    GET = parse("GET"),
    HEAD = parse("HEAD"),
    POST = parse("POST"),
    PUT = parse("PUT"),
    DELETE = parse("DELETE"),
    CONNECT = parse("CONNECT"),
    OPTIONS = parse("OPTIONS"),
    TRACE = parse("TRACE"),
    PATCH = parse("PATCH"),
    // Non-exhaustive: custom methods are represented by their packed bytes.
    _,

    /// Converts `s` into a type that may be used as a `Method` field.
    /// Asserts that `s` is 24 or fewer bytes.
    /// NOTE(review): no assert is actually performed here; names longer than
    /// 8 bytes (@sizeOf(u64)) are silently truncated — confirm intent.
    pub fn parse(s: []const u8) u64 {
        // Pack up to 8 bytes of the method name into the integer tag value.
        var x: u64 = 0;
        const len = @min(s.len, @sizeOf(@TypeOf(x)));
        @memcpy(std.mem.asBytes(&x)[0..len], s[0..len]);
        return x;
    }

    /// Write the method name (the packed bytes, up to the first 0) to `w`.
    pub fn write(self: Method, w: anytype) !void {
        const bytes = std.mem.asBytes(&@intFromEnum(self));
        const str = std.mem.sliceTo(bytes, 0);
        try w.writeAll(str);
    }

    /// Returns true if a request of this method is allowed to have a body
    /// Actual behavior from servers may vary and should still be checked
    pub fn requestHasBody(self: Method) bool {
        return switch (self) {
            .POST, .PUT, .PATCH => true,
            .GET, .HEAD, .DELETE, .CONNECT, .OPTIONS, .TRACE => false,
            // Unknown/custom methods are assumed to allow a body.
            else => true,
        };
    }

    /// Returns true if a response to this method is allowed to have a body
    /// Actual behavior from clients may vary and should still be checked
    pub fn responseHasBody(self: Method) bool {
        return switch (self) {
            .GET, .POST, .DELETE, .CONNECT, .OPTIONS, .PATCH => true,
            .HEAD, .PUT, .TRACE => false,
            else => true,
        };
    }

    /// An HTTP method is safe if it doesn't alter the state of the server.
    ///
    /// https://developer.mozilla.org/en-US/docs/Glossary/Safe/HTTP
    ///
    /// https://datatracker.ietf.org/doc/html/rfc7231#section-4.2.1
    pub fn safe(self: Method) bool {
        return switch (self) {
            .GET, .HEAD, .OPTIONS, .TRACE => true,
            .POST, .PUT, .DELETE, .CONNECT, .PATCH => false,
            else => false,
        };
    }

    /// An HTTP method is idempotent if an identical request can be made once or several times in a row with the same effect while leaving the server in the same state.
    ///
    /// https://developer.mozilla.org/en-US/docs/Glossary/Idempotent
    ///
    /// https://datatracker.ietf.org/doc/html/rfc7231#section-4.2.2
    pub fn idempotent(self: Method) bool {
        return switch (self) {
            .GET, .HEAD, .PUT, .DELETE, .OPTIONS, .TRACE => true,
            .CONNECT, .POST, .PATCH => false,
            else => false,
        };
    }

    /// A cacheable response is an HTTP response that can be cached, that is stored to be retrieved and used later, saving a new request to the server.
    ///
    /// https://developer.mozilla.org/en-US/docs/Glossary/cacheable
    ///
    /// https://datatracker.ietf.org/doc/html/rfc7231#section-4.2.3
    pub fn cacheable(self: Method) bool {
        return switch (self) {
            .GET, .HEAD => true,
            .POST, .PUT, .DELETE, .CONNECT, .OPTIONS, .TRACE, .PATCH => false,
            else => false,
        };
    }
};
/// https://developer.mozilla.org/en-US/docs/Web/HTTP/Status
pub const Status = enum(u10) {
    // 1xx informational
    @"continue" = 100, // RFC7231, Section 6.2.1
    switching_protocols = 101, // RFC7231, Section 6.2.2
    processing = 102, // RFC2518
    early_hints = 103, // RFC8297

    // 2xx success
    ok = 200, // RFC7231, Section 6.3.1
    created = 201, // RFC7231, Section 6.3.2
    accepted = 202, // RFC7231, Section 6.3.3
    non_authoritative_info = 203, // RFC7231, Section 6.3.4
    no_content = 204, // RFC7231, Section 6.3.5
    reset_content = 205, // RFC7231, Section 6.3.6
    partial_content = 206, // RFC7233, Section 4.1
    multi_status = 207, // RFC4918
    already_reported = 208, // RFC5842
    im_used = 226, // RFC3229

    // 3xx redirection
    multiple_choice = 300, // RFC7231, Section 6.4.1
    moved_permanently = 301, // RFC7231, Section 6.4.2
    found = 302, // RFC7231, Section 6.4.3
    see_other = 303, // RFC7231, Section 6.4.4
    not_modified = 304, // RFC7232, Section 4.1
    use_proxy = 305, // RFC7231, Section 6.4.5
    temporary_redirect = 307, // RFC7231, Section 6.4.7
    permanent_redirect = 308, // RFC7538

    // 4xx client errors
    bad_request = 400, // RFC7231, Section 6.5.1
    unauthorized = 401, // RFC7235, Section 3.1
    payment_required = 402, // RFC7231, Section 6.5.2
    forbidden = 403, // RFC7231, Section 6.5.3
    not_found = 404, // RFC7231, Section 6.5.4
    method_not_allowed = 405, // RFC7231, Section 6.5.5
    not_acceptable = 406, // RFC7231, Section 6.5.6
    proxy_auth_required = 407, // RFC7235, Section 3.2
    request_timeout = 408, // RFC7231, Section 6.5.7
    conflict = 409, // RFC7231, Section 6.5.8
    gone = 410, // RFC7231, Section 6.5.9
    length_required = 411, // RFC7231, Section 6.5.10
    precondition_failed = 412, // RFC7232, Section 4.2][RFC8144, Section 3.2
    payload_too_large = 413, // RFC7231, Section 6.5.11
    uri_too_long = 414, // RFC7231, Section 6.5.12
    unsupported_media_type = 415, // RFC7231, Section 6.5.13][RFC7694, Section 3
    range_not_satisfiable = 416, // RFC7233, Section 4.4
    expectation_failed = 417, // RFC7231, Section 6.5.14
    teapot = 418, // RFC 7168, 2.3.3
    misdirected_request = 421, // RFC7540, Section 9.1.2
    unprocessable_entity = 422, // RFC4918
    locked = 423, // RFC4918
    failed_dependency = 424, // RFC4918
    too_early = 425, // RFC8470
    upgrade_required = 426, // RFC7231, Section 6.5.15
    precondition_required = 428, // RFC6585
    too_many_requests = 429, // RFC6585
    request_header_fields_too_large = 431, // RFC6585
    unavailable_for_legal_reasons = 451, // RFC7725

    // 5xx server errors
    internal_server_error = 500, // RFC7231, Section 6.6.1
    not_implemented = 501, // RFC7231, Section 6.6.2
    bad_gateway = 502, // RFC7231, Section 6.6.3
    service_unavailable = 503, // RFC7231, Section 6.6.4
    gateway_timeout = 504, // RFC7231, Section 6.6.5
    http_version_not_supported = 505, // RFC7231, Section 6.6.6
    variant_also_negotiates = 506, // RFC2295
    insufficient_storage = 507, // RFC4918
    loop_detected = 508, // RFC5842
    not_extended = 510, // RFC2774
    network_authentication_required = 511, // RFC6585

    // Non-exhaustive: any other u10 value is a valid (unknown) status code.
    _,

    /// Returns the canonical reason phrase for a known status, else null.
    pub fn phrase(self: Status) ?[]const u8 {
        return switch (self) {
            // 1xx statuses
            .@"continue" => "Continue",
            .switching_protocols => "Switching Protocols",
            .processing => "Processing",
            .early_hints => "Early Hints",
            // 2xx statuses
            .ok => "OK",
            .created => "Created",
            .accepted => "Accepted",
            .non_authoritative_info => "Non-Authoritative Information",
            .no_content => "No Content",
            .reset_content => "Reset Content",
            .partial_content => "Partial Content",
            .multi_status => "Multi-Status",
            .already_reported => "Already Reported",
            .im_used => "IM Used",
            // 3xx statuses
            .multiple_choice => "Multiple Choice",
            .moved_permanently => "Moved Permanently",
            .found => "Found",
            .see_other => "See Other",
            .not_modified => "Not Modified",
            .use_proxy => "Use Proxy",
            .temporary_redirect => "Temporary Redirect",
            .permanent_redirect => "Permanent Redirect",
            // 4xx statuses
            .bad_request => "Bad Request",
            .unauthorized => "Unauthorized",
            .payment_required => "Payment Required",
            .forbidden => "Forbidden",
            .not_found => "Not Found",
            .method_not_allowed => "Method Not Allowed",
            .not_acceptable => "Not Acceptable",
            .proxy_auth_required => "Proxy Authentication Required",
            .request_timeout => "Request Timeout",
            .conflict => "Conflict",
            .gone => "Gone",
            .length_required => "Length Required",
            .precondition_failed => "Precondition Failed",
            .payload_too_large => "Payload Too Large",
            .uri_too_long => "URI Too Long",
            .unsupported_media_type => "Unsupported Media Type",
            .range_not_satisfiable => "Range Not Satisfiable",
            .expectation_failed => "Expectation Failed",
            .teapot => "I'm a teapot",
            .misdirected_request => "Misdirected Request",
            .unprocessable_entity => "Unprocessable Entity",
            .locked => "Locked",
            .failed_dependency => "Failed Dependency",
            .too_early => "Too Early",
            .upgrade_required => "Upgrade Required",
            .precondition_required => "Precondition Required",
            .too_many_requests => "Too Many Requests",
            .request_header_fields_too_large => "Request Header Fields Too Large",
            .unavailable_for_legal_reasons => "Unavailable For Legal Reasons",
            // 5xx statuses
            .internal_server_error => "Internal Server Error",
            .not_implemented => "Not Implemented",
            .bad_gateway => "Bad Gateway",
            .service_unavailable => "Service Unavailable",
            .gateway_timeout => "Gateway Timeout",
            .http_version_not_supported => "HTTP Version Not Supported",
            .variant_also_negotiates => "Variant Also Negotiates",
            .insufficient_storage => "Insufficient Storage",
            .loop_detected => "Loop Detected",
            .not_extended => "Not Extended",
            .network_authentication_required => "Network Authentication Required",
            else => return null,
        };
    }

    /// Coarse status classification by hundreds digit.
    pub const Class = enum {
        informational,
        success,
        redirect,
        client_error,
        server_error,
    };

    /// Map the numeric status value to its class.
    /// NOTE(review): values below 100 also fall into `.server_error` via the
    /// `else` arm — confirm that is acceptable for unknown codes.
    pub fn class(self: Status) Class {
        return switch (@intFromEnum(self)) {
            100...199 => .informational,
            200...299 => .success,
            300...399 => .redirect,
            400...499 => .client_error,
            else => .server_error,
        };
    }

    test {
        try std.testing.expectEqualStrings("OK", Status.ok.phrase().?);
        try std.testing.expectEqualStrings("Not Found", Status.not_found.phrase().?);
    }

    test {
        try std.testing.expectEqual(Status.Class.success, Status.ok.class());
        try std.testing.expectEqual(Status.Class.client_error, Status.not_found.class());
    }
};
/// Body framing of an HTTP message.
pub const TransferEncoding = enum {
    chunked,
    none,
    // compression is intentionally omitted here, as std.http.Client stores it as content-encoding
};

/// Content-Encoding / Accept-Encoding values.
pub const ContentEncoding = enum {
    identity,
    compress,
    @"x-compress",
    deflate,
    gzip,
    @"x-gzip",
    zstd,
};

/// Connection header disposition.
pub const Connection = enum {
    keep_alive,
    close,
};

/// A single HTTP header name/value pair. Slices are borrowed, not owned.
pub const Header = struct {
    name: []const u8,
    value: []const u8,
};

const builtin = @import("builtin");
const std = @import("std");

// Reference nested declarations so their tests are included in `zig test`.
test {
    _ = Client;
    _ = Method;
    _ = Server;
    _ = Status;
}

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@@ -1,447 +0,0 @@
const std = @import("std");
const builtin = @import("builtin");
const testing = std.testing;
const mem = std.mem;
const assert = std.debug.assert;
// NOTE(review): `use_vectors` is not referenced anywhere in this fragment;
// presumably it gated a SIMD path elsewhere — confirm before removing.
const use_vectors = builtin.zig_backend != .stage2_x86_64;

/// Parser state machine: head/trailer states plus chunked-body states.
pub const State = enum {
    invalid,

    // Begin header and trailer parsing states.
    start,
    seen_n,
    seen_r,
    seen_rn,
    seen_rnr,
    finished,

    // Begin transfer-encoding: chunked parsing states.
    chunk_head_size,
    chunk_head_ext,
    chunk_head_r,
    chunk_data,
    chunk_data_suffix,
    chunk_data_suffix_r,

    /// Returns true if the parser is in a content state (ie. not waiting for more headers).
    pub fn isContent(self: State) bool {
        return switch (self) {
            .invalid, .start, .seen_n, .seen_r, .seen_rn, .seen_rnr => false,
            .finished, .chunk_head_size, .chunk_head_ext, .chunk_head_r, .chunk_data, .chunk_data_suffix, .chunk_data_suffix_r => true,
        };
    }
};
/// Incremental HTTP head/trailer parser with chunked-body support.
/// Wraps std.http.HeadParser / std.http.ChunkParser, accumulating the raw
/// head bytes into a caller-supplied fixed buffer.
pub const HeadersParser = struct {
    state: State = .start,
    /// A fixed buffer of len `max_header_bytes`.
    /// Pointers into this buffer are not stable until after a message is complete.
    header_bytes_buffer: []u8,
    // Number of valid bytes currently stored in `header_bytes_buffer`.
    header_bytes_len: u32,
    // Remaining bytes of the current chunk (or of the whole body when
    // `.finished` with a content-length framing).
    next_chunk_length: u64,
    /// `false`: headers. `true`: trailers.
    done: bool,

    /// Initializes the parser with a provided buffer `buf`.
    pub fn init(buf: []u8) HeadersParser {
        return .{
            .header_bytes_buffer = buf,
            .header_bytes_len = 0,
            .done = false,
            .next_chunk_length = 0,
        };
    }

    /// Reinitialize the parser.
    /// Asserts the parser is in the "done" state.
    pub fn reset(hp: *HeadersParser) void {
        assert(hp.done);
        hp.* = .{
            .state = .start,
            .header_bytes_buffer = hp.header_bytes_buffer,
            .header_bytes_len = 0,
            .done = false,
            .next_chunk_length = 0,
        };
    }

    /// Returns the head bytes accumulated so far (a view into the buffer).
    pub fn get(hp: HeadersParser) []u8 {
        return hp.header_bytes_buffer[0..hp.header_bytes_len];
    }

    /// Feed `bytes` to std.http.HeadParser, translating between our State
    /// and its state enum. Returns the number of bytes consumed.
    pub fn findHeadersEnd(r: *HeadersParser, bytes: []const u8) u32 {
        var hp: std.http.HeadParser = .{
            .state = switch (r.state) {
                .start => .start,
                .seen_n => .seen_n,
                .seen_r => .seen_r,
                .seen_rn => .seen_rn,
                .seen_rnr => .seen_rnr,
                .finished => .finished,
                // Chunk states are content states; callers must not be here.
                else => unreachable,
            },
        };
        const result = hp.feed(bytes);
        r.state = switch (hp.state) {
            .start => .start,
            .seen_n => .seen_n,
            .seen_r => .seen_r,
            .seen_rn => .seen_rn,
            .seen_rnr => .seen_rnr,
            .finished => .finished,
        };
        return @intCast(result);
    }

    /// Feed `bytes` to std.http.ChunkParser, translating state and tracking
    /// the remaining chunk length. Returns the number of bytes consumed.
    pub fn findChunkedLen(r: *HeadersParser, bytes: []const u8) u32 {
        var cp: std.http.ChunkParser = .{
            .state = switch (r.state) {
                .chunk_head_size => .head_size,
                .chunk_head_ext => .head_ext,
                .chunk_head_r => .head_r,
                .chunk_data => .data,
                .chunk_data_suffix => .data_suffix,
                .chunk_data_suffix_r => .data_suffix_r,
                .invalid => .invalid,
                else => unreachable,
            },
            .chunk_len = r.next_chunk_length,
        };
        const result = cp.feed(bytes);
        r.state = switch (cp.state) {
            .head_size => .chunk_head_size,
            .head_ext => .chunk_head_ext,
            .head_r => .chunk_head_r,
            .data => .chunk_data,
            .data_suffix => .chunk_data_suffix,
            .data_suffix_r => .chunk_data_suffix_r,
            .invalid => .invalid,
        };
        r.next_chunk_length = cp.chunk_len;
        return @intCast(result);
    }

    /// Returns whether or not the parser has finished parsing a complete
    /// message. A message is only complete after the entire body has been read
    /// and any trailing headers have been parsed.
    pub fn isComplete(r: *HeadersParser) bool {
        return r.done and r.state == .finished;
    }

    pub const CheckCompleteHeadError = error{HttpHeadersOversize};

    /// Pushes `in` into the parser. Returns the number of bytes consumed by
    /// the header. Any header bytes are appended to `header_bytes_buffer`.
    pub fn checkCompleteHead(hp: *HeadersParser, in: []const u8) CheckCompleteHeadError!u32 {
        if (hp.state.isContent()) return 0;

        const i = hp.findHeadersEnd(in);
        const data = in[0..i];
        if (hp.header_bytes_len + data.len > hp.header_bytes_buffer.len)
            return error.HttpHeadersOversize;

        @memcpy(hp.header_bytes_buffer[hp.header_bytes_len..][0..data.len], data);
        hp.header_bytes_len += @intCast(data.len);

        return i;
    }

    pub const ReadError = error{
        HttpChunkInvalid,
    };

    /// Reads the body of the message into `buffer`. Returns the number of
    /// bytes placed in the buffer.
    ///
    /// If `skip` is true, the buffer will be unused and the body will be skipped.
    ///
    /// See `std.http.Client.Connection for an example of `conn`.
    pub fn read(r: *HeadersParser, conn: anytype, buffer: []u8, skip: bool) !usize {
        assert(r.state.isContent());
        if (r.done) return 0;

        var out_index: usize = 0;
        while (true) {
            switch (r.state) {
                .invalid, .start, .seen_n, .seen_r, .seen_rn, .seen_rnr => unreachable,
                // Content-Length framing: drain `next_chunk_length` bytes.
                .finished => {
                    const data_avail = r.next_chunk_length;
                    if (skip) {
                        try conn.fill();
                        const nread = @min(conn.peek().len, data_avail);
                        conn.drop(@intCast(nread));
                        r.next_chunk_length -= nread;
                        if (r.next_chunk_length == 0 or nread == 0) r.done = true;
                        return out_index;
                    } else if (out_index < buffer.len) {
                        const out_avail = buffer.len - out_index;
                        const can_read = @as(usize, @intCast(@min(data_avail, out_avail)));
                        const nread = try conn.read(buffer[0..can_read]);
                        r.next_chunk_length -= nread;
                        if (r.next_chunk_length == 0 or nread == 0) r.done = true;
                        return nread;
                    } else {
                        return out_index;
                    }
                },
                // Between chunks: parse the chunk size line / suffix.
                .chunk_data_suffix, .chunk_data_suffix_r, .chunk_head_size, .chunk_head_ext, .chunk_head_r => {
                    try conn.fill();
                    const i = r.findChunkedLen(conn.peek());
                    conn.drop(@intCast(i));
                    switch (r.state) {
                        .invalid => return error.HttpChunkInvalid,
                        // A zero-length chunk terminates the body; "\r\n"
                        // right after means no trailers, otherwise trailer
                        // headers follow (parsed like a regular head).
                        .chunk_data => if (r.next_chunk_length == 0) {
                            if (std.mem.eql(u8, conn.peek(), "\r\n")) {
                                r.state = .finished;
                                conn.drop(2);
                            } else {
                                // The trailer section is formatted identically
                                // to the header section.
                                r.state = .seen_rn;
                            }
                            r.done = true;
                            return out_index;
                        },
                        else => return out_index,
                    }
                    continue;
                },
                // Inside a chunk: copy (or skip) up to the chunk boundary.
                .chunk_data => {
                    const data_avail = r.next_chunk_length;
                    const out_avail = buffer.len - out_index;
                    if (skip) {
                        try conn.fill();
                        const nread = @min(conn.peek().len, data_avail);
                        conn.drop(@intCast(nread));
                        r.next_chunk_length -= nread;
                    } else if (out_avail > 0) {
                        const can_read: usize = @intCast(@min(data_avail, out_avail));
                        const nread = try conn.read(buffer[out_index..][0..can_read]);
                        r.next_chunk_length -= nread;
                        out_index += nread;
                    }
                    if (r.next_chunk_length == 0) {
                        r.state = .chunk_data_suffix;
                        continue;
                    }
                    return out_index;
                },
            }
        }
    }
};
/// Reinterpret 2 bytes as a native-endian u16.
inline fn int16(array: *const [2]u8) u16 {
    return @bitCast(array.*);
}

/// Reinterpret 3 bytes as a native-endian u24.
inline fn int24(array: *const [3]u8) u24 {
    return @bitCast(array.*);
}

/// Reinterpret 4 bytes as a native-endian u32.
inline fn int32(array: *const [4]u8) u32 {
    return @bitCast(array.*);
}

/// Extract the @bitSizeOf(T) bits of `x` that hold its first bytes in
/// memory order: the top bits on little-endian, the low bits on big-endian.
inline fn intShift(comptime T: type, x: anytype) T {
    return switch (@import("builtin").cpu.arch.endian()) {
        .little => @truncate(x >> (@bitSizeOf(@TypeOf(x)) - @bitSizeOf(T))),
        .big => @truncate(x),
    };
}
/// A buffered (and peekable) Connection.
/// Test double backed by a FixedBufferStream; mimics the fill/peek/drop/read
/// interface that `HeadersParser.read` expects of a connection.
const MockBufferedConnection = struct {
    pub const buffer_size = 0x2000;

    conn: std.io.FixedBufferStream([]const u8),
    buf: [buffer_size]u8 = undefined,
    start: u16 = 0,
    end: u16 = 0,

    /// Refill the internal buffer, but only once it is fully drained.
    pub fn fill(conn: *MockBufferedConnection) ReadError!void {
        if (conn.end != conn.start) return;

        const nread = try conn.conn.read(conn.buf[0..]);
        if (nread == 0) return error.EndOfStream;
        conn.start = 0;
        conn.end = @as(u16, @truncate(nread));
    }

    /// View the buffered-but-unconsumed bytes without consuming them.
    pub fn peek(conn: *MockBufferedConnection) []const u8 {
        return conn.buf[conn.start..conn.end];
    }

    /// Consume `num` buffered bytes (no bounds check; caller must not
    /// drop more than `peek().len`).
    pub fn drop(conn: *MockBufferedConnection, num: u16) void {
        conn.start += num;
    }

    /// Read into `buffer` until at least `len` bytes have been produced
    /// (or the stream ends with an error).
    pub fn readAtLeast(conn: *MockBufferedConnection, buffer: []u8, len: usize) ReadError!usize {
        var out_index: u16 = 0;
        while (out_index < len) {
            const available = conn.end - conn.start;
            const left = buffer.len - out_index;

            if (available > 0) {
                const can_read = @as(u16, @truncate(@min(available, left)));

                @memcpy(buffer[out_index..][0..can_read], conn.buf[conn.start..][0..can_read]);
                out_index += can_read;
                conn.start += can_read;

                continue;
            }

            if (left > conn.buf.len) {
                // skip the buffer if the output is large enough
                return conn.conn.read(buffer[out_index..]);
            }

            try conn.fill();
        }

        return out_index;
    }

    pub fn read(conn: *MockBufferedConnection, buffer: []u8) ReadError!usize {
        return conn.readAtLeast(buffer, 1);
    }

    pub const ReadError = std.io.FixedBufferStream([]const u8).ReadError || error{EndOfStream};
    pub const Reader = std.io.Reader(*MockBufferedConnection, ReadError, read);

    pub fn reader(conn: *MockBufferedConnection) Reader {
        return Reader{ .context = conn };
    }

    pub fn writeAll(conn: *MockBufferedConnection, buffer: []const u8) WriteError!void {
        return conn.conn.writeAll(buffer);
    }

    pub fn write(conn: *MockBufferedConnection, buffer: []const u8) WriteError!usize {
        return conn.conn.write(buffer);
    }

    pub const WriteError = std.io.FixedBufferStream([]const u8).WriteError;
    pub const Writer = std.io.Writer(*MockBufferedConnection, WriteError, write);

    pub fn writer(conn: *MockBufferedConnection) Writer {
        return Writer{ .context = conn };
    }
};
// Body framed by Content-Length: head is accumulated, then exactly 5 body
// bytes are read.
test "HeadersParser.read length" {
    // mock BufferedConnection for read
    var headers_buf: [256]u8 = undefined;
    var r = HeadersParser.init(&headers_buf);
    const data = "GET / HTTP/1.1\r\nHost: localhost\r\nContent-Length: 5\r\n\r\nHello";

    var conn: MockBufferedConnection = .{
        .conn = std.io.fixedBufferStream(data),
    };

    while (true) { // read headers
        try conn.fill();

        const nchecked = try r.checkCompleteHead(conn.peek());
        conn.drop(@intCast(nchecked));

        if (r.state.isContent()) break;
    }

    var buf: [8]u8 = undefined;

    r.next_chunk_length = 5;
    const len = try r.read(&conn, &buf, false);
    try std.testing.expectEqual(@as(usize, 5), len);
    try std.testing.expectEqualStrings("Hello", buf[0..len]);

    try std.testing.expectEqualStrings("GET / HTTP/1.1\r\nHost: localhost\r\nContent-Length: 5\r\n\r\n", r.get());
}

// Chunked body with no trailers: chunks "He" + "ll" + "o" reassemble.
test "HeadersParser.read chunked" {
    // mock BufferedConnection for read

    var headers_buf: [256]u8 = undefined;
    var r = HeadersParser.init(&headers_buf);
    const data = "GET / HTTP/1.1\r\nHost: localhost\r\n\r\n2\r\nHe\r\n2\r\nll\r\n1\r\no\r\n0\r\n\r\n";

    var conn: MockBufferedConnection = .{
        .conn = std.io.fixedBufferStream(data),
    };

    while (true) { // read headers
        try conn.fill();

        const nchecked = try r.checkCompleteHead(conn.peek());
        conn.drop(@intCast(nchecked));

        if (r.state.isContent()) break;
    }
    var buf: [8]u8 = undefined;

    r.state = .chunk_head_size;
    const len = try r.read(&conn, &buf, false);
    try std.testing.expectEqual(@as(usize, 5), len);
    try std.testing.expectEqualStrings("Hello", buf[0..len]);

    try std.testing.expectEqualStrings("GET / HTTP/1.1\r\nHost: localhost\r\n\r\n", r.get());
}

// Chunked body followed by a trailer section; the trailer is parsed with the
// same head machinery and lands in the same header buffer.
test "HeadersParser.read chunked trailer" {
    // mock BufferedConnection for read

    var headers_buf: [256]u8 = undefined;
    var r = HeadersParser.init(&headers_buf);
    const data = "GET / HTTP/1.1\r\nHost: localhost\r\n\r\n2\r\nHe\r\n2\r\nll\r\n1\r\no\r\n0\r\nContent-Type: text/plain\r\n\r\n";

    var conn: MockBufferedConnection = .{
        .conn = std.io.fixedBufferStream(data),
    };

    while (true) { // read headers
        try conn.fill();

        const nchecked = try r.checkCompleteHead(conn.peek());
        conn.drop(@intCast(nchecked));

        if (r.state.isContent()) break;
    }
    var buf: [8]u8 = undefined;

    r.state = .chunk_head_size;
    const len = try r.read(&conn, &buf, false);
    try std.testing.expectEqual(@as(usize, 5), len);
    try std.testing.expectEqualStrings("Hello", buf[0..len]);

    while (true) { // read headers
        try conn.fill();

        const nchecked = try r.checkCompleteHead(conn.peek());
        conn.drop(@intCast(nchecked));

        if (r.state.isContent()) break;
    }

    try std.testing.expectEqualStrings("GET / HTTP/1.1\r\nHost: localhost\r\n\r\nContent-Type: text/plain\r\n\r\n", r.get());
}

File diff suppressed because it is too large Load Diff

View File

@@ -1,335 +0,0 @@
const std = @import("std");
const builtin = @import("builtin");
const net = std.net;
const mem = std.mem;
const testing = std.testing;
test "parse and render IP addresses at comptime" {
if (builtin.os.tag == .wasi) return error.SkipZigTest;
comptime {
var ipAddrBuffer: [16]u8 = undefined;
// Parses IPv6 at comptime
const ipv6addr = net.Address.parseIp("::1", 0) catch unreachable;
var ipv6 = std.fmt.bufPrint(ipAddrBuffer[0..], "{}", .{ipv6addr}) catch unreachable;
try std.testing.expect(std.mem.eql(u8, "::1", ipv6[1 .. ipv6.len - 3]));
// Parses IPv4 at comptime
const ipv4addr = net.Address.parseIp("127.0.0.1", 0) catch unreachable;
var ipv4 = std.fmt.bufPrint(ipAddrBuffer[0..], "{}", .{ipv4addr}) catch unreachable;
try std.testing.expect(std.mem.eql(u8, "127.0.0.1", ipv4[0 .. ipv4.len - 2]));
// Returns error for invalid IP addresses at comptime
try testing.expectError(error.InvalidIPAddressFormat, net.Address.parseIp("::123.123.123.123", 0));
try testing.expectError(error.InvalidIPAddressFormat, net.Address.parseIp("127.01.0.1", 0));
try testing.expectError(error.InvalidIPAddressFormat, net.Address.resolveIp("::123.123.123.123", 0));
try testing.expectError(error.InvalidIPAddressFormat, net.Address.resolveIp("127.01.0.1", 0));
}
}
test "parse and render IPv6 addresses" {
if (builtin.os.tag == .wasi) return error.SkipZigTest;
var buffer: [100]u8 = undefined;
const ips = [_][]const u8{
"FF01:0:0:0:0:0:0:FB",
"FF01::Fb",
"::1",
"::",
"1::",
"2001:db8::",
"::1234:5678",
"2001:db8::1234:5678",
"FF01::FB%1234",
"::ffff:123.5.123.5",
};
const printed = [_][]const u8{
"ff01::fb",
"ff01::fb",
"::1",
"::",
"1::",
"2001:db8::",
"::1234:5678",
"2001:db8::1234:5678",
"ff01::fb",
"::ffff:123.5.123.5",
};
for (ips, 0..) |ip, i| {
const addr = net.Address.parseIp6(ip, 0) catch unreachable;
var newIp = std.fmt.bufPrint(buffer[0..], "{}", .{addr}) catch unreachable;
try std.testing.expect(std.mem.eql(u8, printed[i], newIp[1 .. newIp.len - 3]));
if (builtin.os.tag == .linux) {
const addr_via_resolve = net.Address.resolveIp6(ip, 0) catch unreachable;
var newResolvedIp = std.fmt.bufPrint(buffer[0..], "{}", .{addr_via_resolve}) catch unreachable;
try std.testing.expect(std.mem.eql(u8, printed[i], newResolvedIp[1 .. newResolvedIp.len - 3]));
}
}
try testing.expectError(error.InvalidCharacter, net.Address.parseIp6(":::", 0));
try testing.expectError(error.Overflow, net.Address.parseIp6("FF001::FB", 0));
try testing.expectError(error.InvalidCharacter, net.Address.parseIp6("FF01::Fb:zig", 0));
try testing.expectError(error.InvalidEnd, net.Address.parseIp6("FF01:0:0:0:0:0:0:FB:", 0));
try testing.expectError(error.Incomplete, net.Address.parseIp6("FF01:", 0));
try testing.expectError(error.InvalidIpv4Mapping, net.Address.parseIp6("::123.123.123.123", 0));
try testing.expectError(error.Incomplete, net.Address.parseIp6("1", 0));
// TODO Make this test pass on other operating systems.
if (builtin.os.tag == .linux or comptime builtin.os.tag.isDarwin()) {
try testing.expectError(error.Incomplete, net.Address.resolveIp6("ff01::fb%", 0));
try testing.expectError(error.Overflow, net.Address.resolveIp6("ff01::fb%wlp3s0s0s0s0s0s0s0s0", 0));
try testing.expectError(error.Overflow, net.Address.resolveIp6("ff01::fb%12345678901234", 0));
}
}
test "invalid but parseable IPv6 scope ids" {
if (builtin.os.tag != .linux and comptime !builtin.os.tag.isDarwin()) {
// Currently, resolveIp6 with alphanumerical scope IDs only works on Linux.
// TODO Make this test pass on other operating systems.
return error.SkipZigTest;
}
try testing.expectError(error.InterfaceNotFound, net.Address.resolveIp6("ff01::fb%123s45678901234", 0));
}
test "parse and render IPv4 addresses" {
    if (builtin.os.tag == .wasi) return error.SkipZigTest;
    // Round-trip a set of valid dotted-quad addresses through parse + format.
    const cases = [_][]const u8{
        "0.0.0.0",
        "255.255.255.255",
        "1.2.3.4",
        "123.255.0.91",
        "127.0.0.1",
    };
    var buffer: [18]u8 = undefined;
    for (cases) |ip| {
        const addr = net.Address.parseIp4(ip, 0) catch unreachable;
        const rendered = std.fmt.bufPrint(buffer[0..], "{}", .{addr}) catch unreachable;
        // The formatter appends ":0" (the port); strip it before comparing.
        try std.testing.expectEqualSlices(u8, ip, rendered[0 .. rendered.len - 2]);
    }
    // Each malformed input must be rejected with its precise error.
    try testing.expectError(error.Overflow, net.Address.parseIp4("256.0.0.1", 0));
    try testing.expectError(error.InvalidCharacter, net.Address.parseIp4("x.0.0.1", 0));
    try testing.expectError(error.InvalidEnd, net.Address.parseIp4("127.0.0.1.1", 0));
    try testing.expectError(error.Incomplete, net.Address.parseIp4("127.0.0.", 0));
    try testing.expectError(error.InvalidCharacter, net.Address.parseIp4("100..0.1", 0));
    try testing.expectError(error.NonCanonical, net.Address.parseIp4("127.01.0.1", 0));
}
test "parse and render UNIX addresses" {
    if (builtin.os.tag == .wasi) return error.SkipZigTest;
    if (!net.has_unix_sockets) return error.SkipZigTest;
    // A valid path round-trips through initUnix and the formatter.
    var buffer: [14]u8 = undefined;
    const addr = net.Address.initUnix("/tmp/testpath") catch unreachable;
    const rendered = std.fmt.bufPrint(buffer[0..], "{}", .{addr}) catch unreachable;
    try std.testing.expectEqualSlices(u8, "/tmp/testpath", rendered);
    // Paths that do not fit in sockaddr_un must be rejected.
    const too_long = [_]u8{'a'} ** 200;
    try testing.expectError(error.NameTooLong, net.Address.initUnix(too_long[0..]));
}
test "resolve DNS" {
    if (builtin.os.tag == .wasi) return error.SkipZigTest;
    if (builtin.os.tag == .windows) {
        _ = try std.os.windows.WSAStartup(2, 2);
    }
    defer if (builtin.os.tag == .windows) {
        std.os.windows.WSACleanup() catch unreachable;
    };
    {
        // Resolving localhost must always succeed and yield a loopback address.
        const localhost_v4 = try net.Address.parseIp("127.0.0.1", 80);
        const localhost_v6 = try net.Address.parseIp("::2", 80);
        const list = try net.getAddressList(testing.allocator, "localhost", 80);
        defer list.deinit();
        const found = for (list.addrs) |addr| {
            if (addr.eql(localhost_v4) or addr.eql(localhost_v6)) break true;
        } else false;
        if (!found) @panic("unexpected address for localhost");
    }
    {
        // The tests are required to work even with no Internet connection, so
        // name-server failures skip the test instead of failing it.
        const list = net.getAddressList(testing.allocator, "example.com", 80) catch |err| switch (err) {
            error.UnknownHostName => return error.SkipZigTest,
            error.TemporaryNameServerFailure => return error.SkipZigTest,
            else => return err,
        };
        list.deinit();
    }
}
test "listen on a port, send bytes, receive bytes" {
    if (builtin.single_threaded) return error.SkipZigTest;
    if (builtin.os.tag == .wasi) return error.SkipZigTest;
    if (builtin.os.tag == .windows) {
        _ = try std.os.windows.WSAStartup(2, 2);
    }
    defer if (builtin.os.tag == .windows) {
        std.os.windows.WSACleanup() catch unreachable;
    };
    // Try only the IPv4 variant as some CI builders have no IPv6 localhost
    // configured.
    const loopback = try net.Address.parseIp("127.0.0.1", 0);
    var server = try loopback.listen(.{});
    defer server.deinit();
    const Client = struct {
        fn run(server_address: net.Address) !void {
            const socket = try net.tcpConnectToAddress(server_address);
            defer socket.close();
            _ = try socket.writer().writeAll("Hello world!");
        }
    };
    const client_thread = try std.Thread.spawn(.{}, Client.run, .{server.listen_address});
    defer client_thread.join();
    var conn = try server.accept();
    defer conn.stream.close();
    var buf: [16]u8 = undefined;
    const n = try conn.stream.reader().read(&buf);
    try testing.expectEqual(@as(usize, 12), n);
    try testing.expectEqualSlices(u8, "Hello world!", buf[0..n]);
}
test "listen on an in use port" {
    // SO_REUSEPORT handling is only wired up here for Linux and Darwin.
    const supported = builtin.os.tag == .linux or comptime builtin.os.tag.isDarwin();
    if (!supported) return error.SkipZigTest;
    const loopback = try net.Address.parseIp("127.0.0.1", 0);
    var first = try loopback.listen(.{ .reuse_port = true });
    defer first.deinit();
    // Binding a second listener to the same port must succeed with reuse_port.
    var second = try first.listen_address.listen(.{ .reuse_port = true });
    defer second.deinit();
}
// Connects by host name, reads one server greeting, and verifies it.
fn testClientToHost(allocator: mem.Allocator, name: []const u8, port: u16) anyerror!void {
    if (builtin.os.tag == .wasi) return error.SkipZigTest;
    const connection = try net.tcpConnectToHost(allocator, name, port);
    defer connection.close();
    var buf: [100]u8 = undefined;
    const n = try connection.read(&buf);
    try testing.expect(mem.eql(u8, buf[0..n], "hello from server\n"));
}
// Connects to the given address, reads one server greeting, and verifies it.
fn testClient(addr: net.Address) anyerror!void {
    if (builtin.os.tag == .wasi) return error.SkipZigTest;
    const socket_file = try net.tcpConnectToAddress(addr);
    defer socket_file.close();
    var buf: [100]u8 = undefined;
    const n = try socket_file.read(&buf);
    try testing.expect(mem.eql(u8, buf[0..n], "hello from server\n"));
}
// Accepts a single client and sends one greeting line.
fn testServer(server: *net.Server) anyerror!void {
    if (builtin.os.tag == .wasi) return error.SkipZigTest;
    var conn = try server.accept();
    try conn.stream.writer().print("hello from server\n", .{});
}
test "listen on a unix socket, send bytes, receive bytes" {
    if (builtin.single_threaded) return error.SkipZigTest;
    if (!net.has_unix_sockets) return error.SkipZigTest;
    if (builtin.os.tag == .windows) {
        _ = try std.os.windows.WSAStartup(2, 2);
    }
    defer if (builtin.os.tag == .windows) {
        std.os.windows.WSACleanup() catch unreachable;
    };
    // Bind to a randomized path so concurrent test runs cannot collide.
    const socket_path = try generateFileName("socket.unix");
    defer testing.allocator.free(socket_path);
    const socket_addr = try net.Address.initUnix(socket_path);
    defer std.fs.cwd().deleteFile(socket_path) catch {};
    var server = try socket_addr.listen(.{});
    defer server.deinit();
    const Client = struct {
        fn run(path: []const u8) !void {
            const socket = try net.connectUnixSocket(path);
            defer socket.close();
            _ = try socket.writer().writeAll("Hello world!");
        }
    };
    const client_thread = try std.Thread.spawn(.{}, Client.run, .{socket_path});
    defer client_thread.join();
    var conn = try server.accept();
    defer conn.stream.close();
    var buf: [16]u8 = undefined;
    const n = try conn.stream.reader().read(&buf);
    try testing.expectEqual(@as(usize, 12), n);
    try testing.expectEqualSlices(u8, "Hello world!", buf[0..n]);
}
/// Returns "<random-base64>-<base_name>", allocated with the testing
/// allocator; caller owns the returned slice. Gives each test a unique
/// file/socket name so parallel runs do not collide.
fn generateFileName(base_name: []const u8) ![]const u8 {
    const random_bytes_count = 12;
    const sub_path_len = comptime std.fs.base64_encoder.calcSize(random_bytes_count);
    // Fix: the buffer length was hard-coded as [12]u8; derive it from
    // random_bytes_count so the two constants cannot drift apart.
    var random_bytes: [random_bytes_count]u8 = undefined;
    std.crypto.random.bytes(&random_bytes);
    var sub_path: [sub_path_len]u8 = undefined;
    _ = std.fs.base64_encoder.encode(&sub_path, &random_bytes);
    return std.fmt.allocPrint(testing.allocator, "{s}-{s}", .{ sub_path[0..], base_name });
}
test "non-blocking tcp server" {
    if (builtin.os.tag == .wasi) return error.SkipZigTest;
    if (true) {
        // https://github.com/ziglang/zig/issues/18315
        return error.SkipZigTest;
    }
    const localhost = try net.Address.parseIp("127.0.0.1", 0);
    // Fix: listen() returns an error union and must be unwrapped with `try`;
    // without it `server` is an error union and the deinit/accept calls
    // below cannot compile.
    var server = try localhost.listen(.{ .force_nonblocking = true });
    defer server.deinit();
    // With no pending client, a non-blocking accept must fail immediately.
    const accept_err = server.accept();
    try testing.expectError(error.WouldBlock, accept_err);
    const socket_file = try net.tcpConnectToAddress(server.listen_address);
    defer socket_file.close();
    var client = try server.accept();
    defer client.stream.close();
    const stream = client.stream.writer();
    try stream.print("hello from server\n", .{});
    var buf: [100]u8 = undefined;
    const len = try socket_file.read(&buf);
    const msg = buf[0..len];
    try testing.expect(mem.eql(u8, msg, "hello from server\n"));
}

View File

@@ -1,260 +0,0 @@
const std = @import("std");
const Allocator = std.mem.Allocator;
const Certificate = std.crypto.Certificate;
const der = Certificate.der;
const rsa = @import("rsa/rsa.zig");
const base64 = std.base64.standard.decoderWithIgnore(" \t\r\n");
const proto = @import("protocol.zig");
// Largest ECDSA private scalar stored here; secp521r1 keys need 66 bytes
// (see the secp521r1 test below).
const max_ecdsa_key_len = 66;
// TLS signature scheme this key produces, derived during parsing from the
// key type (and, for RSA, the modulus size — see parseDer).
signature_scheme: proto.SignatureScheme,
// Key material: a full RSA key pair, or the raw ECDSA private scalar bytes
// in a fixed-size buffer (the actual length is implied by the curve encoded
// in signature_scheme; trailing bytes are undefined — see ecdsaKey).
key: union {
rsa: rsa.KeyPair,
ecdsa: [max_ecdsa_key_len]u8,
},
const PrivateKey = @This();
/// Reads the whole file (up to 1 MiB) and parses it as a PEM private key.
/// The temporary read buffer is freed before returning.
pub fn fromFile(gpa: Allocator, file: std.fs.File) !PrivateKey {
    const pem = try file.readToEndAlloc(gpa, 1024 * 1024);
    defer gpa.free(pem);
    return try parsePem(pem);
}
/// Parses a PEM-encoded private key ("PRIVATE KEY" or "EC PRIVATE KEY"
/// block). Errors if no recognized marker pair is found or the base64
/// payload is malformed.
pub fn parsePem(buf: []const u8) !PrivateKey {
    const key_start, const key_end, const marker_version = try findKey(buf);
    const encoded = std.mem.trim(u8, buf[key_start..key_end], " \t\r\n");
    // Stack buffer sized for the largest expected key: rsa 4096/3072/2048-bit
    // keys require 2412/1821/1236 bytes respectively.
    var decoded: [4096]u8 = undefined;
    const n = try base64.decode(&decoded, encoded);
    const der_bytes = decoded[0..n];
    // Marker version 2 is the SEC1 "EC PRIVATE KEY" form; version 1 is PKCS#8.
    if (marker_version == 2) return try parseEcDer(der_bytes);
    return try parseDer(der_bytes);
}
/// Locates the base64 payload between a recognized BEGIN/END marker pair.
/// Returns .{ payload_start, payload_end, marker_version } where version 1
/// is PKCS#8 ("PRIVATE KEY") and version 2 is SEC1 ("EC PRIVATE KEY").
fn findKey(buf: []const u8) !struct { usize, usize, usize } {
    const Marker = struct { begin: []const u8, end: []const u8 };
    const markers = [_]Marker{
        .{ .begin = "-----BEGIN PRIVATE KEY-----", .end = "-----END PRIVATE KEY-----" },
        .{ .begin = "-----BEGIN EC PRIVATE KEY-----", .end = "-----END EC PRIVATE KEY-----" },
    };
    for (markers, 1..) |marker, version| {
        const begin_at = std.mem.indexOfPos(u8, buf, 0, marker.begin) orelse continue;
        const payload_start = begin_at + marker.begin.len;
        const payload_end = std.mem.indexOfPos(u8, buf, payload_start, marker.end) orelse continue;
        return .{ payload_start, payload_end, version };
    }
    return error.MissingEndMarker;
}
// ref: https://asn1js.eu/#MIG2AgEAMBAGByqGSM49AgEGBSuBBAAiBIGeMIGbAgEBBDBKFkVJCtU9FR6egz3yNxKBwXd86cFzMYqyGb8hRc1zVvLdw-So_2FBtITp6jzYmFShZANiAAQ-CH3a1R0V6dFlTK8Rs4M4egrpPtdta0osysO0Zl8mkBiDsTlvJNqeAp7L2ItHgFW8k_CfhgQT6iLDacNMhKC4XOV07r_ePD-mmkvqvRmzfOowHUoVRhCKrOTmF_J9Syc
/// Parses a PKCS#8 PrivateKeyInfo DER structure, extracting either an RSA
/// key pair or raw ECDSA key bytes, and selects the signature scheme from
/// the algorithm identifier (for RSA, also from the modulus size).
pub fn parseDer(buf: []const u8) !PrivateKey {
    // PrivateKeyInfo ::= SEQUENCE { version, algorithm, privateKey }.
    // Each Element.parse resumes where the previous element's slice points.
    const info = try der.Element.parse(buf, 0);
    const version = try der.Element.parse(buf, info.slice.start);
    const algo_seq = try der.Element.parse(buf, version.slice.end);
    const algo_cat = try der.Element.parse(buf, algo_seq.slice.start);
    // privateKey is an OCTET STRING wrapping a key-type-specific SEQUENCE
    // whose first element is an INTEGER version.
    const key_str = try der.Element.parse(buf, algo_seq.slice.end);
    const key_seq = try der.Element.parse(buf, key_str.slice.start);
    const key_int = try der.Element.parse(buf, key_seq.slice.start);
    const category = try Certificate.parseAlgorithmCategory(buf, algo_cat);
    switch (category) {
        .rsaEncryption => {
            // RSAPrivateKey: modulus, publicExponent, privateExponent, ...
            const modulus = try der.Element.parse(buf, key_int.slice.end);
            const public_exponent = try der.Element.parse(buf, modulus.slice.end);
            const private_exponent = try der.Element.parse(buf, public_exponent.slice.end);
            const public_key = try rsa.PublicKey.fromBytes(content(buf, modulus), content(buf, public_exponent));
            const secret_key = try rsa.SecretKey.fromBytes(public_key.modulus, content(buf, private_exponent));
            const key_pair = rsa.KeyPair{ .public = public_key, .secret = secret_key };
            return .{
                // Larger moduli pair with stronger hashes; everything else
                // (including 2048-bit) falls back to sha256.
                .signature_scheme = switch (key_pair.public.modulus.bits()) {
                    4096 => .rsa_pss_rsae_sha512,
                    3072 => .rsa_pss_rsae_sha384,
                    else => .rsa_pss_rsae_sha256,
                },
                .key = .{ .rsa = key_pair },
            };
        },
        .X9_62_id_ecPublicKey => {
            // The EC key bytes follow the inner version INTEGER; the curve
            // OID lives in the algorithm parameters.
            const key = try der.Element.parse(buf, key_int.slice.end);
            const algo_param = try der.Element.parse(buf, algo_cat.slice.end);
            const named_curve = try Certificate.parseNamedCurve(buf, algo_param);
            return .{
                .signature_scheme = signatureScheme(named_curve),
                .key = .{ .ecdsa = ecdsaKey(buf, key) },
            };
        },
        // NOTE(review): assumes parseAlgorithmCategory cannot yield any other
        // category for a private key — confirm; a bad input reaching here is UB
        // in release builds.
        else => unreachable,
    }
}
/// Parses a SEC1 ECPrivateKey structure (RFC 5915):
/// SEQUENCE { version INTEGER, privateKey OCTET STRING, [0] parameters, ... }
/// References:
/// https://www.rfc-editor.org/rfc/rfc5915
pub fn parseEcDer(bytes: []const u8) !PrivateKey {
    // Walk the DER elements in declaration order; each parse resumes where
    // the previous element's slice points.
    const sequence = try der.Element.parse(bytes, 0);
    const version = try der.Element.parse(bytes, sequence.slice.start);
    const private_key = try der.Element.parse(bytes, version.slice.end);
    const parameters = try der.Element.parse(bytes, private_key.slice.end);
    const curve_oid = try der.Element.parse(bytes, parameters.slice.start);
    const named_curve = try Certificate.parseNamedCurve(bytes, curve_oid);
    return .{
        .signature_scheme = signatureScheme(named_curve),
        .key = .{ .ecdsa = ecdsaKey(bytes, private_key) },
    };
}
/// Maps a named curve to the TLS signature scheme used with keys on it.
fn signatureScheme(named_curve: Certificate.NamedCurve) proto.SignatureScheme {
    switch (named_curve) {
        .X9_62_prime256v1 => return .ecdsa_secp256r1_sha256,
        .secp384r1 => return .ecdsa_secp384r1_sha384,
        .secp521r1 => return .ecdsa_secp521r1_sha512,
    }
}
/// Copies the element's content bytes into a fixed-size key buffer.
/// Bytes past the actual key length are left undefined; callers slice by
/// the curve's key size.
fn ecdsaKey(bytes: []const u8, e: der.Element) [max_ecdsa_key_len]u8 {
    var out: [max_ecdsa_key_len]u8 = undefined;
    const src = content(bytes, e);
    @memcpy(out[0..src.len], src);
    return out;
}
/// Returns the raw bytes covered by a DER element.
fn content(bytes: []const u8, e: der.Element) []const u8 {
    const s = e.slice;
    return bytes[s.start..s.end];
}
const testing = std.testing;
const testu = @import("testu.zig");
// Parses the bundled "EC PRIVATE KEY" PEM fixture and checks both the raw
// private scalar and the derived signature scheme (secp384r1 per the
// expected scheme below).
test "parse ec pem" {
    const data = @embedFile("testdata/ec_private_key.pem");
    var pk = try parsePem(data);
    // Expected bytes presumably extracted from the fixture (e.g. via
    // `openssl pkey -text`) — NOTE(review): confirm if the fixture changes.
    const priv_key = &testu.hexToBytes(
        \\ 10 35 3d ca 1b 15 1d 06 aa 71 b8 ef f3 19 22
        \\ 43 78 f3 20 98 1e b1 2f 2b 64 7e 71 d0 30 2a
        \\ 90 aa e5 eb 99 c3 90 65 3d c1 26 19 be 3f 08
        \\ 20 9b 01
    );
    try testing.expectEqualSlices(u8, priv_key, pk.key.ecdsa[0..priv_key.len]);
    try testing.expectEqual(.ecdsa_secp384r1_sha384, pk.signature_scheme);
}
// Same as "parse ec pem" but for a prime256v1 (secp256r1) fixture.
test "parse ec prime256v1" {
    const data = @embedFile("testdata/ec_prime256v1_private_key.pem");
    var pk = try parsePem(data);
    const priv_key = &testu.hexToBytes(
        \\ d2 52 44 ab fc 91 22 84 cc b1 e9 b5 f0 25 83
        \\ e2 c6 87 8f e9 1c 79 6a d1 72 df df 44 3d a5
        \\ cc ed
    );
    try testing.expectEqualSlices(u8, priv_key, pk.key.ecdsa[0..priv_key.len]);
    try testing.expectEqual(.ecdsa_secp256r1_sha256, pk.signature_scheme);
}
// Same as "parse ec pem" but for a secp384r1 fixture.
test "parse ec secp384r1" {
    const data = @embedFile("testdata/ec_secp384r1_private_key.pem");
    var pk = try parsePem(data);
    const priv_key = &testu.hexToBytes(
        \\ ee 6d 8a 5e 0d d3 b0 c6 4b 32 40 80 e2 3a de
        \\ 8b 1e dd e2 92 db 36 1c db 91 ea ba a1 06 0d
        \\ 42 2d d9 a9 dc 05 43 29 f1 78 7c f9 08 af c5
        \\ 03 1f 6d
    );
    try testing.expectEqualSlices(u8, priv_key, pk.key.ecdsa[0..priv_key.len]);
    try testing.expectEqual(.ecdsa_secp384r1_sha384, pk.signature_scheme);
}
// Same as "parse ec pem" but for a secp521r1 fixture; its 66-byte scalar is
// what sizes max_ecdsa_key_len.
test "parse ec secp521r1" {
    const data = @embedFile("testdata/ec_secp521r1_private_key.pem");
    var pk = try parsePem(data);
    const priv_key = &testu.hexToBytes(
        \\ 01 f0 2f 5a c7 24 18 ea 68 23 8c 2e a1 b4 b8
        \\ dc f2 11 b2 96 b0 ec 87 80 42 bf de ba f4 96
        \\ 83 8f 9b db c6 60 a7 4c d9 60 3a e4 ba 0b df
        \\ ae 24 d3 1b c2 6e 82 a0 88 c1 ed 17 20 0d 3a
        \\ f1 c5 7e e8 0b 27
    );
    try testing.expectEqualSlices(u8, priv_key, pk.key.ecdsa[0..priv_key.len]);
    try testing.expectEqual(.ecdsa_secp521r1_sha512, pk.signature_scheme);
}
// Parses the bundled RSA PEM fixture and checks modulus, public exponent,
// and private exponent against values printed by openssl.
test "parse rsa pem" {
    const data = @embedFile("testdata/rsa_private_key.pem");
    const pk = try parsePem(data);
    // expected results from:
    // $ openssl pkey -in testdata/rsa_private_key.pem -text -noout
    const modulus = &testu.hexToBytes(
        \\ 00 de f7 23 e6 75 cc 6f dd d5 6e 0f 8c 09 f8
        \\ 62 e3 60 1b c0 7d 8c d5 04 50 2c 36 e2 3b f7
        \\ 33 9f a1 14 af be cf 1a 0f 4c f5 cb 39 70 0e
        \\ 3b 97 d6 21 f7 48 91 79 ca 7c 68 fc ea 62 a1
        \\ 5a 72 4f 78 57 0e cc f2 a3 50 05 f1 4c ca 51
        \\ 73 10 9a 18 8e 71 f5 b4 c7 3e be 4c ef 37 d4
        \\ 84 4b 82 1c ec 08 a3 cc 07 3d 5c 0b e5 85 3f
        \\ fe b6 44 77 8f 3c 6a 2f 33 c3 5d f6 f2 29 46
        \\ 04 25 7e 05 d9 f8 3b 2d a4 40 66 9f 0d 6d 1a
        \\ fa bc 0a c5 8b 86 43 30 ef 14 20 41 9d b5 cc
        \\ 3e 63 b5 48 04 27 c9 5c d3 62 28 5f f5 b6 e4
        \\ 77 49 99 ac 84 4a a6 67 a5 9a 1a 37 c7 60 4c
        \\ ba c1 70 cf 57 64 4a 21 ea 05 53 10 ec 94 71
        \\ 4a 43 04 83 00 aa 5a 28 bc f2 8c 58 14 92 d2
        \\ 83 17 f4 7b 29 0f e7 87 a2 47 b2 53 19 12 23
        \\ fb 4b ce 5a f8 a1 84 f9 b1 f3 bf e3 fa 10 f8
        \\ ad af 87 ce 03 0e a0 2c 13 71 57 c4 55 44 48
        \\ 44 cb
    );
    const public_exponent = &testu.hexToBytes("01 00 01");
    const private_exponent = &testu.hexToBytes(
        \\ 50 3b 80 98 aa a5 11 50 33 40 32 aa 02 e0 75
        \\ bd 3a 55 62 34 0b 9c 8f bb c5 dd 4e 15 a4 03
        \\ d8 9a 5f 56 4a 84 3d ed 69 95 3d 37 03 02 ac
        \\ 21 1c 36 06 c4 ff 4c 63 37 d7 93 c3 48 10 a5
        \\ fa 62 6c 7c 6f 60 02 a4 0f e4 c3 8b 0d 76 b7
        \\ c0 2e a3 4d 86 e6 92 d1 eb db 10 d6 38 31 ea
        \\ 15 3d d1 e8 81 c7 67 60 e7 8c 9a df 51 ce d0
        \\ 7a 88 32 b9 c1 54 b8 7d 98 fc d4 23 1a 05 0e
        \\ f2 ea e1 72 29 28 2a 68 b7 90 18 80 1c 21 d6
        \\ 36 a8 6b 4a 9c dd 14 b8 9f 85 ee 95 0b f4 c6
        \\ 17 02 aa 4d ea 4d f9 39 d7 dd 9d b4 1d d2 f8
        \\ 92 46 0f 18 41 80 f4 ea 27 55 29 f8 37 59 bf
        \\ 43 ec a3 eb 19 ba bc 13 06 95 3d 25 4b c9 72
        \\ cf 41 0a 6f aa cb 79 d4 7b fa b1 09 7c e2 2f
        \\ 85 51 44 8b c6 97 8e 46 f9 6b ac 08 87 92 ce
        \\ af 0b bf 8c bd 27 51 8f 09 e4 d3 f9 04 ac fa
        \\ f2 04 70 3e d9 a6 28 17 c2 2d 74 e9 25 40 02
        \\ 49
    );
    // A 2048-bit modulus maps to the sha256 PSS scheme (see parseDer).
    try testing.expectEqual(.rsa_pss_rsae_sha256, pk.signature_scheme);
    const kp = pk.key.rsa;
    {
        var bytes: [modulus.len]u8 = undefined;
        try kp.public.modulus.toBytes(&bytes, .big);
        try testing.expectEqualSlices(u8, modulus, &bytes);
    }
    {
        // toBytes left-pads to the buffer size, so only the trailing bytes
        // are compared against the short public exponent.
        var bytes: [private_exponent.len]u8 = undefined;
        try kp.public.public_exponent.toBytes(&bytes, .big);
        try testing.expectEqualSlices(u8, public_exponent, bytes[bytes.len - public_exponent.len .. bytes.len]);
    }
    {
        var btytes: [private_exponent.len]u8 = undefined;
        try kp.secret.private_exponent.toBytes(&btytes, .big);
        try testing.expectEqualSlices(u8, private_exponent, &btytes);
    }
}

View File

@@ -1,148 +0,0 @@
// This file is originally copied from: https://github.com/jedisct1/zig-cbc.
//
// It is modified then to have TLS padding insead of PKCS#7 padding.
// Reference:
// https://datatracker.ietf.org/doc/html/rfc5246/#section-6.2.3.2
// https://crypto.stackexchange.com/questions/98917/on-the-correctness-of-the-padding-example-of-rfc-5246
//
// If required padding i n bytes
// PKCS#7 padding is (n...n)
// TLS padding is (n-1...n-1) - n times of n-1 value
//
const std = @import("std");
const aes = std.crypto.core.aes;
const mem = std.mem;
const debug = std.debug;
/// CBC mode with TLS 1.2 padding
///
/// Important: the counter mode doesn't provide authenticated encryption: the ciphertext can be trivially modified without this being detected.
/// If you need authenticated encryption, use anything from `std.crypto.aead` instead.
/// If you really need to use CBC mode, make sure to use a MAC to authenticate the ciphertext.
pub fn CBC(comptime BlockCipher: anytype) type {
    const EncryptCtx = aes.AesEncryptCtx(BlockCipher);
    const DecryptCtx = aes.AesDecryptCtx(BlockCipher);
    return struct {
        const Self = @This();
        enc_ctx: EncryptCtx,
        dec_ctx: DecryptCtx,
        /// Initialize the CBC context with the given key.
        pub fn init(key: [BlockCipher.key_bits / 8]u8) Self {
            const enc_ctx = BlockCipher.initEnc(key);
            // The decrypt schedule is derived from the encrypt schedule.
            const dec_ctx = DecryptCtx.initFromEnc(enc_ctx);
            return Self{ .enc_ctx = enc_ctx, .dec_ctx = dec_ctx };
        }
        /// Return the length of the ciphertext given the length of the plaintext.
        /// `length + 1` guarantees at least one padding byte, per the TLS
        /// padding scheme described at the top of this file.
        pub fn paddedLength(length: usize) usize {
            return (std.math.divCeil(usize, length + 1, EncryptCtx.block_length) catch unreachable) * EncryptCtx.block_length;
        }
        /// Encrypt the given plaintext for the given IV.
        /// The destination buffer must be large enough to hold the padded plaintext.
        /// Use the `paddedLength()` function to compute the ciphertext size.
        /// IV must be secret and unpredictable.
        pub fn encrypt(self: Self, dst: []u8, src: []const u8, iv: [EncryptCtx.block_length]u8) void {
            // Note: encryption *could* be parallelized, see https://research.kudelskisecurity.com/2022/11/17/some-aes-cbc-encryption-myth-busting/
            const block_length = EncryptCtx.block_length;
            const padded_length = paddedLength(src.len);
            debug.assert(dst.len == padded_length); // destination buffer must hold the padded plaintext
            var cv = iv;
            var i: usize = 0;
            // Full blocks: XOR with the chaining value, encrypt, emit.
            while (i + block_length <= src.len) : (i += block_length) {
                const in = src[i..][0..block_length];
                for (cv[0..], in) |*x, y| x.* ^= y;
                self.enc_ctx.encrypt(&cv, &cv);
                @memcpy(dst[i..][0..block_length], &cv);
            }
            // Last block
            // TLS padding: if n bytes of padding are needed, each pad byte
            // holds the value n-1 (see the header comment of this file).
            var in = [_]u8{0} ** block_length;
            const padding_length: u8 = @intCast(padded_length - src.len - 1);
            @memset(&in, padding_length);
            @memcpy(in[0 .. src.len - i], src[i..]);
            for (cv[0..], in) |*x, y| x.* ^= y;
            self.enc_ctx.encrypt(&cv, &cv);
            @memcpy(dst[i..], cv[0 .. dst.len - i]);
        }
        /// Decrypt the given ciphertext for the given IV.
        /// The destination buffer must be large enough to hold the plaintext.
        /// IV must be secret, unpredictable and match the one used for encryption.
        pub fn decrypt(self: Self, dst: []u8, src: []const u8, iv: [DecryptCtx.block_length]u8) !void {
            const block_length = DecryptCtx.block_length;
            if (src.len != dst.len) {
                return error.EncodingError;
            }
            debug.assert(src.len % block_length == 0);
            var i: usize = 0;
            var cv = iv;
            var out: [block_length]u8 = undefined;
            // Decryption could be parallelized
            while (i + block_length <= dst.len) : (i += block_length) {
                const in = src[i..][0..block_length];
                self.dec_ctx.decrypt(&out, in);
                for (&out, cv) |*x, y| x.* ^= y;
                cv = in.*;
                @memcpy(dst[i..][0..block_length], &out);
            }
            // Last block - We intentionally don't check the padding to mitigate timing attacks
            // NOTE(review): since src.len is asserted to be a block multiple,
            // the loop above already consumes every block and this branch
            // looks unreachable — confirm intent.
            if (i < dst.len) {
                const in = src[i..][0..block_length];
                @memset(&out, 0);
                self.dec_ctx.decrypt(&out, in);
                for (&out, cv) |*x, y| x.* ^= y;
                @memcpy(dst[i..], out[0 .. dst.len - i]);
            }
        }
    };
}
// Round-trips every prefix length of a multi-block message through
// encrypt/decrypt (distinct buffers, then in-place) and pins the hash of
// all produced ciphertexts against a known value.
test "CBC mode" {
    const M = CBC(aes.Aes128);
    const key = [_]u8{ 0x2b, 0x7e, 0x15, 0x16, 0x28, 0xae, 0xd2, 0xa6, 0xab, 0xf7, 0x15, 0x88, 0x09, 0xcf, 0x4f, 0x3c };
    const iv = [_]u8{ 0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0a, 0x0b, 0x0c, 0x0d, 0x0e, 0x0f };
    const src_ = "This is a test of AES-CBC that goes on longer than a couple blocks. It is a somewhat long test case to type out!";
    const expected = "\xA0\x8C\x09\x7D\xFF\x42\xB6\x65\x4D\x4B\xC6\x90\x90\x39\xDE\x3D\xC7\xCA\xEB\xF6\x9A\x4F\x09\x97\xC9\x32\xAB\x75\x88\xB7\x57\x17";
    var res: [32]u8 = undefined;
    try comptime std.testing.expect(src_.len / M.paddedLength(1) >= 3); // Ensure that we have at least 3 blocks
    const z = M.init(key);
    // Test encryption and decryption with distinct buffers
    var h = std.crypto.hash.sha2.Sha256.init(.{});
    inline for (0..src_.len) |len| {
        const src = src_[0..len];
        var dst = [_]u8{0} ** M.paddedLength(src.len);
        z.encrypt(&dst, src, iv);
        h.update(&dst);
        var decrypted = [_]u8{0} ** dst.len;
        try z.decrypt(&decrypted, &dst, iv);
        // TLS padding stores n-1 in each of the n pad bytes, so the pad
        // length is last_byte + 1.
        const padding = decrypted[decrypted.len - 1] + 1;
        try std.testing.expectEqualSlices(u8, src, decrypted[0 .. decrypted.len - padding]);
    }
    h.final(&res);
    try std.testing.expectEqualSlices(u8, expected, &res);
    // Test encryption and decryption with the same buffer
    h = std.crypto.hash.sha2.Sha256.init(.{});
    inline for (0..src_.len) |len| {
        var buf = [_]u8{0} ** M.paddedLength(len);
        @memcpy(buf[0..len], src_[0..len]);
        z.encrypt(&buf, buf[0..len], iv);
        h.update(&buf);
        try z.decrypt(&buf, &buf, iv);
        try std.testing.expectEqualSlices(u8, src_[0..len], buf[0..len]);
    }
    h.final(&res);
    try std.testing.expectEqualSlices(u8, expected, &res);
}

File diff suppressed because it is too large Load Diff

View File

@@ -1,665 +0,0 @@
const std = @import("std");
const assert = std.debug.assert;
const proto = @import("protocol.zig");
const record = @import("record.zig");
const cipher = @import("cipher.zig");
const Cipher = cipher.Cipher;
const async_io = @import("../std/http/Client.zig");
const Cbk = async_io.Cbk;
const Ctx = async_io.Ctx;
/// Convenience constructor: wraps `stream` in a Connection together with a
/// record reader over the same stream.
pub fn connection(stream: anytype) Connection(@TypeOf(stream)) {
    return .{
        .rec_rdr = record.reader(stream),
        .stream = stream,
    };
}
pub fn Connection(comptime Stream: type) type {
return struct {
stream: Stream, // underlying stream
rec_rdr: record.Reader(Stream),
cipher: Cipher = undefined,
max_encrypt_seq: u64 = std.math.maxInt(u64) - 1,
key_update_requested: bool = false,
read_buf: []const u8 = "",
received_close_notify: bool = false,
const Self = @This();
/// Encrypts and writes single tls record to the stream.
/// May first emit a key_update record (and rotate the send keys) when the
/// encrypt sequence limit is reached or the peer requested an update.
fn writeRecord(c: *Self, content_type: proto.ContentType, bytes: []const u8) !void {
    assert(bytes.len <= cipher.max_cleartext_len);
    // Scratch buffer large enough for any single ciphertext record.
    var write_buf: [cipher.max_ciphertext_record_len]u8 = undefined;
    // If key update is requested send key update message and update
    // my encryption keys.
    if (c.cipher.encryptSeq() >= c.max_encrypt_seq or @atomicLoad(bool, &c.key_update_requested, .monotonic)) {
        @atomicStore(bool, &c.key_update_requested, false, .monotonic);
        // If the request_update field is set to "update_requested",
        // then the receiver MUST send a KeyUpdate of its own with
        // request_update set to "update_not_requested" prior to sending
        // its next Application Data record. This mechanism allows
        // either side to force an update to the entire connection, but
        // causes an implementation which receives multiple KeyUpdates
        // while it is silent to respond with a single update.
        //
        // rfc: https://datatracker.ietf.org/doc/html/rfc8446#autoid-57
        const key_update = &record.handshakeHeader(.key_update, 1) ++ [_]u8{0};
        const rec = try c.cipher.encrypt(&write_buf, .handshake, key_update);
        try c.stream.writeAll(rec);
        try c.cipher.keyUpdateEncrypt();
    }
    const rec = try c.cipher.encrypt(&write_buf, content_type, bytes);
    try c.stream.writeAll(rec);
}
/// Best effort: encrypts the alert derived from `err` and writes it to the
/// stream. Write failures are swallowed — we are already on an error path.
fn writeAlert(c: *Self, err: anyerror) !void {
    var buf: [128]u8 = undefined;
    const cleartext = proto.alertFromError(err);
    const ciphertext = try c.cipher.encrypt(&buf, .alert, &cleartext);
    c.stream.writeAll(ciphertext) catch {};
}
/// Returns next record of cleartext data.
/// Can be used in iterator like loop without memcpy to another buffer:
/// while (try client.next()) |buf| { ... }
pub fn next(c: *Self) ReadError!?[]const u8 {
    const rec = c.nextRecord() catch |err| {
        // Tell the peer why we are giving up before propagating the error.
        try c.writeAlert(err);
        return err;
    } orelse return null;
    const content_type, const data = rec;
    if (content_type != .application_data) return error.TlsUnexpectedMessage;
    return data;
}
/// Reads and decrypts records until one carrying application data (or a
/// close) arrives. Handles key_update and new_session_ticket handshake
/// records internally; returns null at end of stream or clean shutdown.
fn nextRecord(c: *Self) ReadError!?struct { proto.ContentType, []const u8 } {
    if (c.eof()) return null;
    while (true) {
        const content_type, const cleartext = try c.rec_rdr.nextDecrypt(&c.cipher) orelse return null;
        switch (content_type) {
            .application_data => {},
            .handshake => {
                const handshake_type: proto.Handshake = @enumFromInt(cleartext[0]);
                switch (handshake_type) {
                    // skip new session ticket and read next record
                    .new_session_ticket => continue,
                    .key_update => {
                        // 4-byte handshake header + 1-byte KeyUpdateRequest.
                        if (cleartext.len != 5) return error.TlsDecodeError;
                        // rfc: Upon receiving a KeyUpdate, the receiver MUST
                        // update its receiving keys.
                        try c.cipher.keyUpdateDecrypt();
                        const key: proto.KeyUpdateRequest = @enumFromInt(cleartext[4]);
                        switch (key) {
                            .update_requested => {
                                // Remember to rotate our send keys on the
                                // next write (see writeRecord).
                                @atomicStore(bool, &c.key_update_requested, true, .monotonic);
                            },
                            .update_not_requested => {},
                            else => return error.TlsIllegalParameter,
                        }
                        // this record is handled read next
                        continue;
                    },
                    else => {},
                }
            },
            .alert => {
                if (cleartext.len < 2) return error.TlsUnexpectedMessage;
                // Fatal alerts surface as errors; close_notify falls through.
                try proto.Alert.parse(cleartext[0..2].*).toError();
                // server side clean shutdown
                c.received_close_notify = true;
                return null;
            },
            else => return error.TlsUnexpectedMessage,
        }
        return .{ content_type, cleartext };
    }
}
/// True once the peer sent close_notify and all buffered cleartext is drained.
pub fn eof(c: *Self) bool {
    if (!c.received_close_notify) return false;
    return c.read_buf.len == 0;
}
/// Sends a close_notify alert unless the peer already closed the connection.
pub fn close(c: *Self) !void {
    if (c.received_close_notify) return;
    try c.writeRecord(.alert, &proto.Alert.closeNotify());
}
// read, write interface
// Errors surfaced while reading/decrypting records, combining the transport's
// read errors, decoded TLS alerts, and local protocol/decrypt failures.
pub const ReadError = Stream.ReadError || proto.Alert.Error ||
error{
TlsBadVersion,
TlsUnexpectedMessage,
TlsRecordOverflow,
TlsDecryptError,
TlsDecodeError,
TlsBadRecordMac,
TlsIllegalParameter,
BufferOverflow,
};
// Errors surfaced while encrypting/writing records.
pub const WriteError = Stream.WriteError ||
error{
BufferOverflow,
TlsUnexpectedMessage,
};
// std.io adapters so a Connection can be used anywhere a generic
// reader/writer is expected.
pub const Reader = std.io.Reader(*Self, ReadError, read);
pub const Writer = std.io.Writer(*Self, WriteError, write);
pub fn reader(c: *Self) Reader {
return .{ .context = c };
}
pub fn writer(c: *Self) Writer {
return .{ .context = c };
}
/// Encrypts cleartext and writes it to the underlying stream as single
/// tls record. Max single tls record payload length is 1<<14 (16K)
/// bytes. Returns how many cleartext bytes were consumed.
pub fn write(c: *Self, bytes: []const u8) WriteError!usize {
    const chunk = @min(bytes.len, cipher.max_cleartext_len);
    try c.writeRecord(.application_data, bytes[0..chunk]);
    return chunk;
}
/// Encrypts cleartext and writes it to the underlying stream. If needed
/// splits cleartext into multiple tls record.
pub fn writeAll(c: *Self, bytes: []const u8) WriteError!void {
    var remaining = bytes;
    while (remaining.len > 0) {
        const consumed = try c.write(remaining);
        remaining = remaining[consumed..];
    }
}
/// Copies buffered cleartext into `buffer`, pulling the next decrypted
/// record when the internal buffer is empty. Returns 0 at end of stream.
pub fn read(c: *Self, buffer: []u8) ReadError!usize {
    if (c.read_buf.len == 0) {
        c.read_buf = try c.next() orelse return 0;
    }
    const count = @min(buffer.len, c.read_buf.len);
    @memcpy(buffer[0..count], c.read_buf[0..count]);
    c.read_buf = c.read_buf[count..];
    return count;
}
/// Returns the number of bytes read. If the number read is smaller than
/// `buffer.len`, it means the stream reached the end.
pub fn readAll(c: *Self, buffer: []u8) ReadError!usize {
    return c.readAtLeast(buffer, buffer.len);
}
/// Returns the number of bytes read, calling the underlying read function
/// the minimal number of times until the buffer has at least `len` bytes
/// filled. If the number read is less than `len` it means the stream
/// reached the end.
pub fn readAtLeast(c: *Self, buffer: []u8, len: usize) ReadError!usize {
    assert(len <= buffer.len);
    var filled: usize = 0;
    while (filled < len) {
        const n = try c.read(buffer[filled..]);
        if (n == 0) break; // end of stream
        filled += n;
    }
    return filled;
}
/// Returns the number of bytes read. If the number read is less than
/// the space provided it means the stream reached the end.
pub fn readv(c: *Self, iovecs: []std.posix.iovec) !usize {
    var vp: VecPut = .{ .iovecs = iovecs };
    while (true) {
        if (c.read_buf.len == 0) {
            c.read_buf = try c.next() orelse break;
        }
        const n = vp.put(c.read_buf);
        const read_buf_len = c.read_buf.len;
        c.read_buf = c.read_buf[n..];
        // Stop when the iovecs are full (n < read_buf_len) or when the
        // whole buffer fit and no further record is already buffered —
        // avoids blocking on the network for data the caller didn't ask for.
        if ((n < read_buf_len) or
            (n == read_buf_len and !c.rec_rdr.hasMore()))
            break;
    }
    return vp.total;
}
// Completion callback for async_writeAll: while cleartext remains, encrypt
// the next chunk and schedule another async write; otherwise unwind the
// callback stack with success. Errors are propagated via ctx.pop.
fn onWriteAll(ctx: *Ctx, res: anyerror!void) anyerror!void {
    res catch |err| return ctx.pop(err);
    if (ctx._tls_write_bytes.len - ctx._tls_write_index > 0) {
        const rec = ctx.conn().tls_client.prepareRecord(ctx.stream(), ctx) catch |err| return ctx.pop(err);
        return ctx.stream().async_writeAll(rec, ctx, onWriteAll) catch |err| return ctx.pop(err);
    }
    return ctx.pop({});
}
/// Asynchronously encrypts `bytes` and writes the record(s) to `stream`,
/// invoking `cbk` on completion via the ctx callback stack.
/// NOTE(review): the assert restricts input to a single record's worth of
/// cleartext, yet onWriteAll contains multi-record continuation logic that
/// can then never trigger — confirm whether the assert or the loop is the
/// intended contract.
pub fn async_writeAll(c: *Self, stream: anytype, bytes: []const u8, ctx: *Ctx, comptime cbk: Cbk) !void {
    assert(bytes.len <= cipher.max_cleartext_len);
    ctx._tls_write_bytes = bytes;
    ctx._tls_write_index = 0;
    const rec = try c.prepareRecord(stream, ctx);
    try ctx.push(cbk);
    return stream.async_writeAll(rec, ctx, onWriteAll);
}
/// Encrypts the next chunk of ctx._tls_write_bytes into ctx._tls_write_buf
/// and returns the ciphertext record. Advances ctx._tls_write_index by the
/// number of cleartext bytes consumed (on success).
fn prepareRecord(c: *Self, stream: anytype, ctx: *Ctx) ![]const u8 {
    const len = @min(ctx._tls_write_bytes.len - ctx._tls_write_index, cipher.max_cleartext_len);
    // If key update is requested send key update message and update
    // my encryption keys.
    if (c.cipher.encryptSeq() >= c.max_encrypt_seq or @atomicLoad(bool, &c.key_update_requested, .monotonic)) {
        @atomicStore(bool, &c.key_update_requested, false, .monotonic);
        // If the request_update field is set to "update_requested",
        // then the receiver MUST send a KeyUpdate of its own with
        // request_update set to "update_not_requested" prior to sending
        // its next Application Data record.
        //
        // rfc: https://datatracker.ietf.org/doc/html/rfc8446#autoid-57
        const key_update = &record.handshakeHeader(.key_update, 1) ++ [_]u8{0};
        const rec = try c.cipher.encrypt(&ctx._tls_write_buf, .handshake, key_update);
        try stream.writeAll(rec); // TODO async
        try c.cipher.keyUpdateEncrypt();
    }
    defer ctx._tls_write_index += len;
    // Fix: slice must start at the current index and span `len` bytes.
    // The previous `[index..len]` treated `len` as an absolute end offset,
    // which is wrong for every record after the first (start > end panics).
    return c.cipher.encrypt(
        &ctx._tls_write_buf,
        .application_data,
        ctx._tls_write_bytes[ctx._tls_write_index..][0..len],
    );
}
// Completion callback for async_readv: drain the decrypted record buffer
// into the caller's iovecs, fetching further records as needed, and unwind
// the callback stack once the iovecs are full or the stream ends.
fn onReadv(ctx: *Ctx, res: anyerror!void) anyerror!void {
    res catch |err| return ctx.pop(err);
    if (ctx._tls_read_buf == null) {
        // end of read
        ctx.setLen(ctx._vp.total);
        return ctx.pop({});
    }
    while (true) {
        const n = ctx._vp.put(ctx._tls_read_buf.?);
        const read_buf_len = ctx._tls_read_buf.?.len;
        const c = ctx.conn().tls_client;
        if (read_buf_len == 0) {
            // read another buffer
            return c.async_next(ctx.stream(), ctx, onReadv) catch |err| return ctx.pop(err);
        }
        ctx._tls_read_buf = ctx._tls_read_buf.?[n..];
        // Same stop condition as the synchronous readv: iovecs full, or the
        // buffer fit exactly and no further record is already buffered.
        if ((n < read_buf_len) or (n == read_buf_len and !c.rec_rdr.hasMore())) {
            // end of read
            ctx.setLen(ctx._vp.total);
            return ctx.pop({});
        }
    }
}
/// Asynchronously fills `iovecs` with decrypted cleartext, invoking `cbk`
/// (with the byte count recorded via ctx.setLen) on completion.
pub fn async_readv(c: *Self, stream: anytype, iovecs: []std.posix.iovec, ctx: *Ctx, comptime cbk: Cbk) !void {
    try ctx.push(cbk);
    ctx._vp = .{ .iovecs = iovecs };
    return c.async_next(stream, ctx, onReadv);
}
/// Completion callback for `async_next`: on failure writes a tls alert for
/// the error (best effort — a failure to write is only logged) before
/// unwinding; on success rejects any record whose content type is not
/// application data.
fn onNext(ctx: *Ctx, res: anyerror!void) anyerror!void {
    res catch |err| {
        ctx.conn().tls_client.writeAlert(err) catch |e| std.log.err("onNext: write alert: {any}", .{e}); // TODO async
        return ctx.pop(err);
    };
    if (ctx._tls_read_content_type != .application_data) {
        return ctx.pop(error.TlsUnexpectedMessage);
    }
    return ctx.pop({});
}
/// Push `cbk` and deliver the next application-data cleartext chunk through
/// `ctx._tls_read_buf`. Non-application content types are rejected in
/// `onNext`.
pub fn async_next(c: *Self, stream: anytype, ctx: *Ctx, comptime cbk: Cbk) !void {
    try ctx.push(cbk);
    return c.async_next_decrypt(stream, ctx, onNext);
}
/// Completion callback for `async_next_decrypt`: the decrypted record is in
/// `ctx._tls_read_buf` with its content type in `ctx._tls_read_content_type`.
/// Application data is passed on to the pushed continuation; handshake
/// records (session ticket, key update) and alerts are consumed here, and
/// consuming a record triggers reading the next one.
pub fn onNextDecrypt(ctx: *Ctx, res: anyerror!void) anyerror!void {
    res catch |err| return ctx.pop(err);
    const c = ctx.conn().tls_client;
    // TODO not sure if this works in my async case...
    if (c.eof()) {
        ctx._tls_read_buf = null;
        return ctx.pop({});
    }
    const content_type = ctx._tls_read_content_type;
    switch (content_type) {
        .application_data => {},
        .handshake => {
            const handshake_type: proto.Handshake = @enumFromInt(ctx._tls_read_buf.?[0]);
            switch (handshake_type) {
                // skip new session ticket and read next record
                .new_session_ticket => return c.async_next_record(ctx.stream(), ctx, onNextDecrypt) catch |err| return ctx.pop(err),
                .key_update => {
                    if (ctx._tls_read_buf.?.len != 5) return ctx.pop(error.TlsDecodeError);
                    // rfc: Upon receiving a KeyUpdate, the receiver MUST
                    // update its receiving keys.
                    // Route the error through the continuation stack (a raw
                    // `try` would propagate without popping the pushed
                    // callback).
                    c.cipher.keyUpdateDecrypt() catch |err| return ctx.pop(err);
                    const key: proto.KeyUpdateRequest = @enumFromInt(ctx._tls_read_buf.?[4]);
                    switch (key) {
                        .update_requested => {
                            @atomicStore(bool, &c.key_update_requested, true, .monotonic);
                        },
                        .update_not_requested => {},
                        else => return ctx.pop(error.TlsIllegalParameter),
                    }
                    // This record is handled; read the next one. The `return`
                    // is required (mirroring the .new_session_ticket branch):
                    // without it we would both schedule the next read and fall
                    // through to the final `ctx.pop({})`, resuming the
                    // continuation twice.
                    return c.async_next_record(ctx.stream(), ctx, onNextDecrypt) catch |err| return ctx.pop(err);
                },
                else => {},
            }
        },
        .alert => {
            if (ctx._tls_read_buf.?.len < 2) return ctx.pop(error.TlsUnexpectedMessage);
            // Fatal alerts are reported through the continuation stack as
            // well, instead of a raw `try` that would strand the pushed
            // callback.
            proto.Alert.parse(ctx._tls_read_buf.?[0..2].*).toError() catch |err| return ctx.pop(err);
            // server side clean shutdown
            c.received_close_notify = true;
            ctx._tls_read_buf = null;
            return ctx.pop({});
        },
        else => return ctx.pop(error.TlsUnexpectedMessage),
    }
    return ctx.pop({});
}
/// Push `cbk` and read + decrypt records until one is ready for the caller;
/// non-application records are handled transparently in `onNextDecrypt`.
pub fn async_next_decrypt(c: *Self, stream: anytype, ctx: *Ctx, comptime cbk: Cbk) !void {
    try ctx.push(cbk);
    return c.async_next_record(stream, ctx, onNextDecrypt) catch |err| return ctx.pop(err);
}
/// Completion callback for `async_reader_next`: the raw record (or null on
/// end of stream) is in `ctx._tls_read_record`. Decrypts it into the record
/// reader's buffer, leaving cleartext in `ctx._tls_read_buf` and its content
/// type in `ctx._tls_read_content_type`, then resumes the continuation.
pub fn onNextRecord(ctx: *Ctx, res: anyerror!void) anyerror!void {
    res catch |err| return ctx.pop(err);
    const rec = ctx._tls_read_record orelse {
        ctx._tls_read_buf = null;
        return ctx.pop({});
    };
    // Report the version error through the continuation stack; a raw
    // `return error.TlsBadVersion` (as before) would leave the callback
    // pushed by our caller stranded on the stack.
    if (rec.protocol_version != .tls_1_2) return ctx.pop(error.TlsBadVersion);
    const c = ctx.conn().tls_client;
    const cph = &c.cipher;
    ctx._tls_read_content_type, ctx._tls_read_buf = cph.decrypt(
        // Reuse reader buffer for cleartext. `rec.header` and
        // `rec.payload`(ciphertext) are also pointing somewhere in
        // this buffer. Decrypter is first reading then writing a
        // block; cleartext is shorter than ciphertext and starts
        // from the beginning of the buffer, so ciphertext is always
        // ahead of cleartext.
        c.rec_rdr.buffer[0..c.rec_rdr.start],
        rec,
    ) catch |err| return ctx.pop(err);
    return ctx.pop({});
}
/// Push `cbk` and read the next raw tls record, decrypting it in
/// `onNextRecord`; the result lands in `ctx._tls_read_buf` /
/// `ctx._tls_read_content_type`.
pub fn async_next_record(c: *Self, stream: anytype, ctx: *Ctx, comptime cbk: Cbk) !void {
    try ctx.push(cbk);
    return c.async_reader_next(stream, ctx, onNextRecord);
}
/// Completion callback for the async stream read issued by `readNext`:
/// `ctx.len()` bytes were appended to the record reader's buffer. Zero
/// bytes is treated as end of stream.
pub fn onReaderNext(ctx: *Ctx, res: anyerror!void) anyerror!void {
    res catch |err| return ctx.pop(err);
    const c = ctx.conn().tls_client;
    const n = ctx.len();
    if (n == 0) {
        // End of stream: signal "no record" to the continuation.
        ctx._tls_read_record = null;
        return ctx.pop({});
    }
    c.rec_rdr.end += n;
    // Retry slicing a complete record out of the grown buffer.
    return c.readNext(ctx);
}
/// Try to slice the next complete tls record out of the record reader's
/// buffer. If a whole record is already buffered it is stored in
/// `ctx._tls_read_record` and the continuation resumes immediately;
/// otherwise the pending bytes are compacted to the start of the buffer and
/// an async read from the stream is issued (completing in `onReaderNext`).
pub fn readNext(c: *Self, ctx: *Ctx) anyerror!void {
    const buffer = c.rec_rdr.buffer[c.rec_rdr.start..c.rec_rdr.end];
    // If we have 5 bytes header.
    if (buffer.len >= record.header_len) {
        const record_header = buffer[0..record.header_len];
        const payload_len = std.mem.readInt(u16, record_header[3..5], .big);
        // Report the overflow through the continuation stack; a raw
        // `return error.TlsRecordOverflow` (as before) would strand the
        // callback pushed by our caller.
        if (payload_len > cipher.max_ciphertext_len)
            return ctx.pop(error.TlsRecordOverflow);
        const record_len = record.header_len + payload_len;
        // If we have whole record
        if (buffer.len >= record_len) {
            c.rec_rdr.start += record_len;
            ctx._tls_read_record = record.Record.init(buffer[0..record_len]);
            return ctx.pop({});
        }
    }
    { // Move dirty part to the start of the buffer.
        const n = c.rec_rdr.end - c.rec_rdr.start;
        if (n > 0 and c.rec_rdr.start > 0) {
            if (c.rec_rdr.start > n) {
                // Regions do not overlap; plain memcpy is safe.
                @memcpy(c.rec_rdr.buffer[0..n], c.rec_rdr.buffer[c.rec_rdr.start..][0..n]);
            } else {
                // Overlapping regions: copy front-to-back.
                std.mem.copyForwards(u8, c.rec_rdr.buffer[0..n], c.rec_rdr.buffer[c.rec_rdr.start..][0..n]);
            }
        }
        c.rec_rdr.start = 0;
        c.rec_rdr.end = n;
    }
    // Read more from inner_reader.
    return ctx.stream()
        .async_read(c.rec_rdr.buffer[c.rec_rdr.end..], ctx, onReaderNext) catch |err| return ctx.pop(err);
}
/// Push `cbk` and deliver the next raw tls record via
/// `ctx._tls_read_record`. The stream argument is unused: `readNext`
/// obtains the stream from `ctx` when it needs to read more bytes.
pub fn async_reader_next(c: *Self, _: anytype, ctx: *Ctx, comptime cbk: Cbk) !void {
    try ctx.push(cbk);
    return c.readNext(ctx);
}
};
}
const testing = std.testing;
const data12 = @import("testdata/tls12.zig");
const testu = @import("testu.zig");
// Round-trips TLS 1.2 CBC records against fixtures from testdata/tls12.zig,
// using a fixed rng so encrypted output is byte-exact, then exercises the
// next/reader/readv read interfaces against the same "pong" record.
test "encrypt decrypt" {
    var output_buf: [1024]u8 = undefined;
    // Input: three copies of the server "pong" record; output captured in output_buf.
    const stream = testu.Stream.init(&(data12.server_pong ** 3), &output_buf);
    var conn: Connection(@TypeOf(stream)) = .{ .stream = stream, .rec_rdr = record.reader(stream) };
    conn.cipher = try Cipher.initTls12(.ECDHE_RSA_WITH_AES_128_CBC_SHA, &data12.key_material, .client);
    conn.cipher.ECDHE_RSA_WITH_AES_128_CBC_SHA.rnd = testu.random(0); // use fixed rng

    conn.stream.output.reset();
    { // encrypt verify data from example
        _ = testu.random(0x40); // sets iv to 40, 41, ... 4f
        try conn.writeRecord(.handshake, &data12.client_finished);
        try testing.expectEqualSlices(u8, &data12.verify_data_encrypted_msg, conn.stream.output.getWritten());
    }

    conn.stream.output.reset();
    { // encrypt ping
        const cleartext = "ping";
        _ = testu.random(0); // sets iv to 00, 01, ... 0f
        //conn.encrypt_seq = 1;
        try conn.writeAll(cleartext);
        try testing.expectEqualSlices(u8, &data12.encrypted_ping_msg, conn.stream.output.getWritten());
    }
    { // decrypt server pong message
        conn.cipher.ECDHE_RSA_WITH_AES_128_CBC_SHA.decrypt_seq = 1;
        try testing.expectEqualStrings("pong", (try conn.next()).?);
    }
    { // test reader interface
        conn.cipher.ECDHE_RSA_WITH_AES_128_CBC_SHA.decrypt_seq = 1;
        var rdr = conn.reader();
        var buffer: [4]u8 = undefined;
        const n = try rdr.readAll(&buffer);
        try testing.expectEqualStrings("pong", buffer[0..n]);
    }
    { // test readv interface
        conn.cipher.ECDHE_RSA_WITH_AES_128_CBC_SHA.decrypt_seq = 1;
        var buffer: [9]u8 = undefined;
        var iovecs = [_]std.posix.iovec{
            .{ .base = &buffer, .len = 3 },
            .{ .base = buffer[3..], .len = 3 },
            .{ .base = buffer[6..], .len = 3 },
        };
        // "pong" is 4 bytes: fills the first iovec and one byte of the second.
        const n = try conn.readv(iovecs[0..]);
        try testing.expectEqual(4, n);
        try testing.expectEqualStrings("pong", buffer[0..n]);
    }
}
// Copied from: https://github.com/ziglang/zig/blob/455899668b620dfda40252501c748c0a983555bd/lib/std/crypto/tls/Client.zig#L1354
/// Scatter writer: copies successive byte buffers into a fixed slice of
/// iovecs, tracking the current vector (`idx`), the offset inside it
/// (`off`) and the total number of bytes delivered (`total`).
pub const VecPut = struct {
    iovecs: []const std.posix.iovec,
    idx: usize = 0,
    off: usize = 0,
    total: usize = 0,

    /// Copies `bytes` into the remaining iovec space. Returns the number of
    /// bytes actually copied, which equals `bytes.len` unless the vectors
    /// run out of room.
    pub fn put(vp: *VecPut, bytes: []const u8) usize {
        if (vp.idx >= vp.iovecs.len) return 0;
        var written: usize = 0;
        while (true) {
            const vec = vp.iovecs[vp.idx];
            const room = vec.len - vp.off;
            const n = @min(room, bytes.len - written);
            @memcpy(vec.base[vp.off..][0..n], bytes[written..][0..n]);
            written += n;
            vp.off += n;
            // Advance to the next vector once this one is full.
            if (vp.off >= vec.len) {
                vp.off = 0;
                vp.idx += 1;
                if (vp.idx >= vp.iovecs.len) break;
            }
            if (written >= bytes.len) break;
        }
        vp.total += written;
        return written;
    }
};
// Two connections share an in-memory pipe: conn1 encrypts with the client
// TLS 1.3 cipher, conn2 decrypts with the matching server cipher; the random
// payload must round-trip unchanged.
test "client/server connection" {
    // Minimal single-buffer pipe: writes append at wp, reads consume at rp;
    // positions reset once the buffer is drained.
    const BufReaderWriter = struct {
        buf: []u8,
        wp: usize = 0, // write position
        rp: usize = 0, // read position

        const Self = @This();

        pub fn write(self: *Self, bytes: []const u8) !usize {
            if (self.wp == self.buf.len) return error.NoSpaceLeft;

            const n = @min(bytes.len, self.buf.len - self.wp);
            @memcpy(self.buf[self.wp..][0..n], bytes[0..n]);
            self.wp += n;
            return n;
        }

        pub fn writeAll(self: *Self, bytes: []const u8) !void {
            var n: usize = 0;
            while (n < bytes.len) {
                n += try self.write(bytes[n..]);
            }
        }

        pub fn read(self: *Self, bytes: []u8) !usize {
            const n = @min(bytes.len, self.wp - self.rp);
            if (n == 0) return 0;
            @memcpy(bytes[0..n], self.buf[self.rp..][0..n]);
            self.rp += n;
            if (self.rp == self.wp) {
                // Buffer drained: rewind so the next write starts fresh.
                self.wp = 0;
                self.rp = 0;
            }
            return n;
        }
    };

    // Stream adapter over the pipe, with the error sets Connection expects.
    const TestStream = struct {
        inner_stream: *BufReaderWriter,
        const Self = @This();
        pub const ReadError = error{};
        pub const WriteError = error{NoSpaceLeft};
        pub fn read(self: *Self, bytes: []u8) !usize {
            return try self.inner_stream.read(bytes);
        }
        pub fn writeAll(self: *Self, bytes: []const u8) !void {
            return try self.inner_stream.writeAll(bytes);
        }
    };

    const buf_len = 32 * 1024;
    // Pipe must hold the cleartext plus per-record tls 1.3 encryption overhead.
    const tls_records_in_buf = (std.math.divCeil(comptime_int, buf_len, cipher.max_cleartext_len) catch unreachable);
    const overhead: usize = tls_records_in_buf * @import("cipher.zig").encrypt_overhead_tls_13;
    var buf: [buf_len + overhead]u8 = undefined;
    var inner_stream = BufReaderWriter{ .buf = &buf };

    // Build a matching client/server cipher pair from one shared secret.
    const cipher_client, const cipher_server = brk: {
        const Transcript = @import("transcript.zig").Transcript;
        const CipherSuite = @import("cipher.zig").CipherSuite;
        const cipher_suite: CipherSuite = .AES_256_GCM_SHA384;

        var rnd: [128]u8 = undefined;
        std.crypto.random.bytes(&rnd);
        const secret = Transcript.Secret{
            .client = rnd[0..64],
            .server = rnd[64..],
        };

        break :brk .{
            try Cipher.initTls13(cipher_suite, secret, .client),
            try Cipher.initTls13(cipher_suite, secret, .server),
        };
    };

    var conn1 = connection(TestStream{ .inner_stream = &inner_stream });
    conn1.cipher = cipher_client;
    var conn2 = connection(TestStream{ .inner_stream = &inner_stream });
    conn2.cipher = cipher_server;

    var prng = std.Random.DefaultPrng.init(0);
    const random = prng.random();
    var send_buf: [buf_len]u8 = undefined;
    var recv_buf: [buf_len]u8 = undefined;
    random.bytes(&send_buf); // fill send buffer with random bytes

    for (0..16) |_| {
        const n = buf_len; //random.uintLessThan(usize, buf_len);
        const sent = send_buf[0..n];
        try conn1.writeAll(sent);
        const r = try conn2.readAll(&recv_buf);
        const received = recv_buf[0..r];
        try testing.expectEqual(n, r);
        try testing.expectEqualSlices(u8, sent, received);
    }
}

View File

@@ -1,955 +0,0 @@
const std = @import("std");
const assert = std.debug.assert;
const crypto = std.crypto;
const mem = std.mem;
const Certificate = crypto.Certificate;
const cipher = @import("cipher.zig");
const Cipher = cipher.Cipher;
const CipherSuite = cipher.CipherSuite;
const cipher_suites = cipher.cipher_suites;
const Transcript = @import("transcript.zig").Transcript;
const record = @import("record.zig");
const rsa = @import("rsa/rsa.zig");
const key_log = @import("key_log.zig");
const PrivateKey = @import("PrivateKey.zig");
const proto = @import("protocol.zig");
const common = @import("handshake_common.zig");
const dupe = common.dupe;
const CertificateBuilder = common.CertificateBuilder;
const CertificateParser = common.CertificateParser;
const DhKeyPair = common.DhKeyPair;
const CertBundle = common.CertBundle;
const CertKeyPair = common.CertKeyPair;
pub const Options = struct {
    /// Server host name: sent in the server_name (SNI) extension and used
    /// when verifying the server certificate.
    host: []const u8,
    /// Set of root certificate authorities that clients use when verifying
    /// server certificates.
    root_ca: CertBundle,
    /// Controls whether a client verifies the server's certificate chain and
    /// host name.
    insecure_skip_verify: bool = false,
    /// List of cipher suites to use.
    /// To use just tls 1.3 cipher suites:
    ///   .cipher_suites = &tls.CipherSuite.tls13,
    /// To select particular cipher suite:
    ///   .cipher_suites = &[_]tls.CipherSuite{tls.CipherSuite.CHACHA20_POLY1305_SHA256},
    cipher_suites: []const CipherSuite = cipher_suites.all,
    /// List of named groups to use.
    /// To use specific named group:
    ///   .named_groups = &[_]tls.NamedGroup{.secp384r1},
    named_groups: []const proto.NamedGroup = supported_named_groups,
    /// Client authentication certificates and private key.
    auth: ?CertKeyPair = null,
    /// If this structure is provided it will be filled with handshake attributes
    /// at the end of the handshake process.
    diagnostic: ?*Diagnostic = null,
    /// For logging current connection tls keys, so we can share them with
    /// Wireshark and analyze decrypted traffic there.
    key_log_callback: ?key_log.Callback = null,

    /// Negotiated handshake parameters. All fields start zeroed
    /// (`@enumFromInt(0)`) and are overwritten when the handshake finishes;
    /// `client_signature_scheme` is set only when client auth was configured.
    pub const Diagnostic = struct {
        tls_version: proto.Version = @enumFromInt(0),
        cipher_suite_tag: CipherSuite = @enumFromInt(0),
        named_group: proto.NamedGroup = @enumFromInt(0),
        signature_scheme: proto.SignatureScheme = @enumFromInt(0),
        client_signature_scheme: proto.SignatureScheme = @enumFromInt(0),
    };
};
/// Named groups this client can offer; the default for `Options.named_groups`.
const supported_named_groups = &[_]proto.NamedGroup{
    .x25519,
    .secp256r1,
    .secp384r1,
    .x25519_kyber768d00,
};
/// Handshake parses tls server message and creates client messages. Collects
/// tls attributes: server random, cipher suite and so on. Client messages are
/// created using provided buffer. Provided record reader is used to get tls
/// record when needed.
pub fn Handshake(comptime Stream: type) type {
const RecordReaderT = record.Reader(Stream);
return struct {
client_random: [32]u8,
server_random: [32]u8 = undefined,
master_secret: [48]u8 = undefined,
key_material: [48 * 4]u8 = undefined, // for sha256 32 * 4 is filled, for sha384 48 * 4
transcript: Transcript = .{},
cipher_suite: CipherSuite = @enumFromInt(0),
named_group: ?proto.NamedGroup = null,
dh_kp: DhKeyPair,
rsa_secret: RsaSecret,
tls_version: proto.Version = .tls_1_2,
cipher: Cipher = undefined,
cert: CertificateParser = undefined,
client_certificate_requested: bool = false,
// public key len: x25519 = 32, secp256r1 = 65, secp384r1 = 97, x25519_kyber768d00 = 1120
server_pub_key_buf: [2048]u8 = undefined,
server_pub_key: []const u8 = undefined,
rec_rdr: *RecordReaderT, // tls record reader
buffer: []u8, // scratch buffer used in all messages creation
const HandshakeT = @This();
pub fn init(buf: []u8, rec_rdr: *RecordReaderT) HandshakeT {
return .{
.client_random = undefined,
.dh_kp = undefined,
.rsa_secret = undefined,
//.now_sec = std.time.timestamp(),
.buffer = buf,
.rec_rdr = rec_rdr,
};
}
fn initKeys(
h: *HandshakeT,
named_groups: []const proto.NamedGroup,
) !void {
const init_keys_buf_len = 32 + 46 + DhKeyPair.seed_len;
var buf: [init_keys_buf_len]u8 = undefined;
crypto.random.bytes(&buf);
h.client_random = buf[0..32].*;
h.rsa_secret = RsaSecret.init(buf[32..][0..46].*);
h.dh_kp = try DhKeyPair.init(buf[32 + 46 ..][0..DhKeyPair.seed_len].*, named_groups);
}
/// Handshake exchanges messages with server to get agreement about
/// cryptographic parameters. That upgrades existing client-server
/// connection to TLS connection. Returns cipher used in application for
/// encrypted message exchange.
///
/// Handles TLS 1.2 and TLS 1.3 connections. After initial client hello
/// server chooses in its server hello which TLS version will be used.
///
/// TLS 1.2 handshake messages exchange:
/// Client Server
/// --------------------------------------------------------------
/// ClientHello client flight 1 --->
/// ServerHello
/// Certificate
/// ServerKeyExchange
/// CertificateRequest*
/// <--- server flight 1 ServerHelloDone
/// Certificate*
/// ClientKeyExchange
/// CertificateVerify*
/// ChangeCipherSpec
/// Finished client flight 2 --->
/// ChangeCipherSpec
/// <--- server flight 2 Finished
///
/// TLS 1.3 handshake messages exchange:
/// Client Server
/// --------------------------------------------------------------
/// ClientHello client flight 1 --->
/// ServerHello
/// {EncryptedExtensions}
/// {CertificateRequest*}
/// {Certificate}
/// {CertificateVerify}
/// <--- server flight 1 {Finished}
/// ChangeCipherSpec
/// {Certificate*}
/// {CertificateVerify*}
/// Finished client flight 2 --->
///
/// * - optional
/// {} - encrypted
///
/// References:
/// https://datatracker.ietf.org/doc/html/rfc5246#section-7.3
/// https://datatracker.ietf.org/doc/html/rfc8446#section-2
///
pub fn handshake(h: *HandshakeT, w: Stream, opt: Options) !Cipher {
defer h.updateDiagnostic(opt);
try h.initKeys(opt.named_groups);
h.cert = .{
.host = opt.host,
.root_ca = opt.root_ca.bundle,
.skip_verify = opt.insecure_skip_verify,
};
try w.writeAll(try h.makeClientHello(opt)); // client flight 1
try h.readServerFlight1(); // server flight 1
h.transcript.use(h.cipher_suite.hash());
// tls 1.3 specific handshake part
if (h.tls_version == .tls_1_3) {
try h.generateHandshakeCipher(opt.key_log_callback);
try h.readEncryptedServerFlight1(); // server flight 1
const app_cipher = try h.generateApplicationCipher(opt.key_log_callback);
try w.writeAll(try h.makeClientFlight2Tls13(opt.auth)); // client flight 2
return app_cipher;
}
// tls 1.2 specific handshake part
try h.generateCipher(opt.key_log_callback);
try w.writeAll(try h.makeClientFlight2Tls12(opt.auth)); // client flight 2
try h.readServerFlight2(); // server flight 2
return h.cipher;
}
/// Prepare key material and generate cipher for TLS 1.2
fn generateCipher(h: *HandshakeT, key_log_callback: ?key_log.Callback) !void {
try h.verifyCertificateSignatureTls12();
try h.generateKeyMaterial(key_log_callback);
h.cipher = try Cipher.initTls12(h.cipher_suite, &h.key_material, .client);
}
/// Generate TLS 1.2 pre master secret, master secret and key material.
fn generateKeyMaterial(h: *HandshakeT, key_log_callback: ?key_log.Callback) !void {
const pre_master_secret = if (h.named_group) |named_group|
try h.dh_kp.sharedKey(named_group, h.server_pub_key)
else
&h.rsa_secret.secret;
_ = dupe(
&h.master_secret,
h.transcript.masterSecret(pre_master_secret, h.client_random, h.server_random),
);
_ = dupe(
&h.key_material,
h.transcript.keyMaterial(&h.master_secret, h.client_random, h.server_random),
);
if (key_log_callback) |cb| {
cb(key_log.label.client_random, &h.client_random, &h.master_secret);
}
}
/// TLS 1.3 cipher used during handshake
fn generateHandshakeCipher(h: *HandshakeT, key_log_callback: ?key_log.Callback) !void {
const shared_key = try h.dh_kp.sharedKey(h.named_group.?, h.server_pub_key);
const handshake_secret = h.transcript.handshakeSecret(shared_key);
if (key_log_callback) |cb| {
cb(key_log.label.server_handshake_traffic_secret, &h.client_random, handshake_secret.server);
cb(key_log.label.client_handshake_traffic_secret, &h.client_random, handshake_secret.client);
}
h.cipher = try Cipher.initTls13(h.cipher_suite, handshake_secret, .client);
}
/// TLS 1.3 application (client) cipher
fn generateApplicationCipher(h: *HandshakeT, key_log_callback: ?key_log.Callback) !Cipher {
const application_secret = h.transcript.applicationSecret();
if (key_log_callback) |cb| {
cb(key_log.label.server_traffic_secret_0, &h.client_random, application_secret.server);
cb(key_log.label.client_traffic_secret_0, &h.client_random, application_secret.client);
}
return try Cipher.initTls13(h.cipher_suite, application_secret, .client);
}
fn makeClientHello(h: *HandshakeT, opt: Options) ![]const u8 {
// Buffer will have this parts:
// | header | payload | extensions |
//
// Header will be written last because we need to know length of
// payload and extensions when creating it. Payload has
// extensions length (u16) as last element.
//
var buffer = h.buffer;
const header_len = 9; // tls record header (5 bytes) and handshake header (4 bytes)
const tls_versions = try CipherSuite.versions(opt.cipher_suites);
// Payload writer, preserve header_len bytes for handshake header.
var payload = record.Writer{ .buf = buffer[header_len..] };
try payload.writeEnum(proto.Version.tls_1_2);
try payload.write(&h.client_random);
try payload.writeByte(0); // no session id
try payload.writeEnumArray(CipherSuite, opt.cipher_suites);
try payload.write(&[_]u8{ 0x01, 0x00 }); // no compression
// Extensions writer starts after payload and preserves 2 more
// bytes for extension len in payload.
var ext = record.Writer{ .buf = buffer[header_len + payload.pos + 2 ..] };
try ext.writeExtension(.supported_versions, switch (tls_versions) {
.both => &[_]proto.Version{ .tls_1_3, .tls_1_2 },
.tls_1_3 => &[_]proto.Version{.tls_1_3},
.tls_1_2 => &[_]proto.Version{.tls_1_2},
});
try ext.writeExtension(.signature_algorithms, common.supported_signature_algorithms);
try ext.writeExtension(.supported_groups, opt.named_groups);
if (tls_versions != .tls_1_2) {
var keys: [supported_named_groups.len][]const u8 = undefined;
for (opt.named_groups, 0..) |ng, i| {
keys[i] = try h.dh_kp.publicKey(ng);
}
try ext.writeKeyShare(opt.named_groups, keys[0..opt.named_groups.len]);
}
try ext.writeServerName(opt.host);
// Extensions length at the end of the payload.
try payload.writeInt(@as(u16, @intCast(ext.pos)));
// Header at the start of the buffer.
const body_len = payload.pos + ext.pos;
buffer[0..header_len].* = record.header(.handshake, 4 + body_len) ++
record.handshakeHeader(.client_hello, body_len);
const msg = buffer[0 .. header_len + body_len];
h.transcript.update(msg[record.header_len..]);
return msg;
}
/// Process first flight of the messages from the server.
/// Read server hello message. If TLS 1.3 is chosen in server hello
/// return. For TLS 1.2 continue and read certificate, key_exchange
/// eventual certificate request and hello done messages.
fn readServerFlight1(h: *HandshakeT) !void {
var handshake_states: []const proto.Handshake = &.{.server_hello};
while (true) {
var d = try h.rec_rdr.nextDecoder();
try d.expectContentType(.handshake);
h.transcript.update(d.payload);
// Multiple handshake messages can be packed in single tls record.
while (!d.eof()) {
const handshake_type = try d.decode(proto.Handshake);
const length = try d.decode(u24);
if (length > cipher.max_cleartext_len)
return error.TlsUnsupportedFragmentedHandshakeMessage;
brk: {
for (handshake_states) |state|
if (state == handshake_type) break :brk;
return error.TlsUnexpectedMessage;
}
switch (handshake_type) {
.server_hello => { // server hello, ref: https://datatracker.ietf.org/doc/html/rfc5246#section-7.4.1.3
try h.parseServerHello(&d, length);
if (h.tls_version == .tls_1_3) {
if (!d.eof()) return error.TlsIllegalParameter;
return; // end of tls 1.3 server flight 1
}
handshake_states = if (h.cert.skip_verify)
&.{ .certificate, .server_key_exchange, .server_hello_done }
else
&.{.certificate};
},
.certificate => {
try h.cert.parseCertificate(&d, h.tls_version);
handshake_states = if (h.cipher_suite.keyExchange() == .rsa)
&.{.server_hello_done}
else
&.{.server_key_exchange};
},
.server_key_exchange => {
try h.parseServerKeyExchange(&d);
handshake_states = &.{ .certificate_request, .server_hello_done };
},
.certificate_request => {
h.client_certificate_requested = true;
try d.skip(length);
handshake_states = &.{.server_hello_done};
},
.server_hello_done => {
if (length != 0) return error.TlsIllegalParameter;
return;
},
else => return error.TlsUnexpectedMessage,
}
}
}
}
/// Parse server hello message.
fn parseServerHello(h: *HandshakeT, d: *record.Decoder, length: u24) !void {
if (try d.decode(proto.Version) != proto.Version.tls_1_2)
return error.TlsBadVersion;
h.server_random = try d.array(32);
if (isServerHelloRetryRequest(&h.server_random))
return error.TlsServerHelloRetryRequest;
const session_id_len = try d.decode(u8);
if (session_id_len > 32) return error.TlsIllegalParameter;
try d.skip(session_id_len);
h.cipher_suite = try d.decode(CipherSuite);
try h.cipher_suite.validate();
try d.skip(1); // skip compression method
const extensions_present = length > 2 + 32 + 1 + session_id_len + 2 + 1;
if (extensions_present) {
const exs_len = try d.decode(u16);
var l: usize = 0;
while (l < exs_len) {
const typ = try d.decode(proto.Extension);
const len = try d.decode(u16);
defer l += len + 4;
switch (typ) {
.supported_versions => {
switch (try d.decode(proto.Version)) {
.tls_1_2, .tls_1_3 => |v| h.tls_version = v,
else => return error.TlsIllegalParameter,
}
if (len != 2) return error.TlsIllegalParameter;
},
.key_share => {
h.named_group = try d.decode(proto.NamedGroup);
h.server_pub_key = dupe(&h.server_pub_key_buf, try d.slice(try d.decode(u16)));
if (len != h.server_pub_key.len + 4) return error.TlsIllegalParameter;
},
else => {
try d.skip(len);
},
}
}
}
}
fn isServerHelloRetryRequest(server_random: []const u8) bool {
// Ref: https://datatracker.ietf.org/doc/html/rfc8446#section-4.1.3
const hello_retry_request_magic = [32]u8{
0xCF, 0x21, 0xAD, 0x74, 0xE5, 0x9A, 0x61, 0x11, 0xBE, 0x1D, 0x8C, 0x02, 0x1E, 0x65, 0xB8, 0x91,
0xC2, 0xA2, 0x11, 0x16, 0x7A, 0xBB, 0x8C, 0x5E, 0x07, 0x9E, 0x09, 0xE2, 0xC8, 0xA8, 0x33, 0x9C,
};
return std.mem.eql(u8, server_random, &hello_retry_request_magic);
}
fn parseServerKeyExchange(h: *HandshakeT, d: *record.Decoder) !void {
const curve_type = try d.decode(proto.Curve);
h.named_group = try d.decode(proto.NamedGroup);
h.server_pub_key = dupe(&h.server_pub_key_buf, try d.slice(try d.decode(u8)));
h.cert.signature_scheme = try d.decode(proto.SignatureScheme);
h.cert.signature = dupe(&h.cert.signature_buf, try d.slice(try d.decode(u16)));
if (curve_type != .named_curve) return error.TlsIllegalParameter;
}
/// Read encrypted part (after server hello) of the server first flight
/// for TLS 1.3: change cipher spec, eventual certificate request,
/// certificate, certificate verify and handshake finished messages.
fn readEncryptedServerFlight1(h: *HandshakeT) !void {
var cleartext_buf = h.buffer;
var cleartext_buf_head: usize = 0;
var cleartext_buf_tail: usize = 0;
var handshake_states: []const proto.Handshake = &.{.encrypted_extensions};
outer: while (true) {
// wrapped record decoder
const rec = (try h.rec_rdr.next() orelse return error.EndOfStream);
if (rec.protocol_version != .tls_1_2) return error.TlsBadVersion;
switch (rec.content_type) {
.change_cipher_spec => {},
.application_data => {
const content_type, const cleartext = try h.cipher.decrypt(
cleartext_buf[cleartext_buf_tail..],
rec,
);
cleartext_buf_tail += cleartext.len;
if (cleartext_buf_tail > cleartext_buf.len) return error.TlsRecordOverflow;
var d = record.Decoder.init(content_type, cleartext_buf[cleartext_buf_head..cleartext_buf_tail]);
try d.expectContentType(.handshake);
while (!d.eof()) {
const start_idx = d.idx;
const handshake_type = try d.decode(proto.Handshake);
const length = try d.decode(u24);
// std.debug.print("handshake loop: {} {} {} {}\n", .{ handshake_type, length, d.payload.len, d.idx });
if (length > cipher.max_cleartext_len)
return error.TlsUnsupportedFragmentedHandshakeMessage;
if (length > d.rest().len)
continue :outer; // fragmented handshake into multiple records
defer {
const handshake_payload = d.payload[start_idx..d.idx];
h.transcript.update(handshake_payload);
cleartext_buf_head += handshake_payload.len;
}
brk: {
for (handshake_states) |state|
if (state == handshake_type) break :brk;
return error.TlsUnexpectedMessage;
}
switch (handshake_type) {
.encrypted_extensions => {
try d.skip(length);
handshake_states = if (h.cert.skip_verify)
&.{ .certificate_request, .certificate, .finished }
else
&.{ .certificate_request, .certificate };
},
.certificate_request => {
h.client_certificate_requested = true;
try d.skip(length);
handshake_states = if (h.cert.skip_verify)
&.{ .certificate, .finished }
else
&.{.certificate};
},
.certificate => {
try h.cert.parseCertificate(&d, h.tls_version);
handshake_states = &.{.certificate_verify};
},
.certificate_verify => {
try h.cert.parseCertificateVerify(&d);
try h.cert.verifySignature(h.transcript.serverCertificateVerify());
handshake_states = &.{.finished};
},
.finished => {
const actual = try d.slice(length);
var buf: [Transcript.max_mac_length]u8 = undefined;
const expected = h.transcript.serverFinishedTls13(&buf);
if (!mem.eql(u8, expected, actual))
return error.TlsDecryptError;
return;
},
else => return error.TlsUnexpectedMessage,
}
}
cleartext_buf_head = 0;
cleartext_buf_tail = 0;
},
else => return error.TlsUnexpectedMessage,
}
}
}
fn verifyCertificateSignatureTls12(h: *HandshakeT) !void {
if (h.cipher_suite.keyExchange() != .ecdhe) return;
const verify_bytes = brk: {
var w = record.Writer{ .buf = h.buffer };
try w.write(&h.client_random);
try w.write(&h.server_random);
try w.writeEnum(proto.Curve.named_curve);
try w.writeEnum(h.named_group.?);
try w.writeInt(@as(u8, @intCast(h.server_pub_key.len)));
try w.write(h.server_pub_key);
break :brk w.getWritten();
};
try h.cert.verifySignature(verify_bytes);
}
/// Create client key exchange, change cipher spec and handshake
/// finished messages for tls 1.2.
/// If client certificate is requested also adds client certificate and
/// certificate verify messages.
fn makeClientFlight2Tls12(h: *HandshakeT, auth: ?CertKeyPair) ![]const u8 {
var w = record.Writer{ .buf = h.buffer };
var cert_builder: ?CertificateBuilder = null;
// Client certificate message
if (h.client_certificate_requested) {
if (auth) |a| {
const cb = h.certificateBuilder(a);
cert_builder = cb;
const client_certificate = try cb.makeCertificate(w.getPayload());
h.transcript.update(client_certificate);
try w.advanceRecord(.handshake, client_certificate.len);
} else {
const empty_certificate = &record.handshakeHeader(.certificate, 3) ++ [_]u8{ 0, 0, 0 };
h.transcript.update(empty_certificate);
try w.writeRecord(.handshake, empty_certificate);
}
}
// Client key exchange message
{
const key_exchange = try h.makeClientKeyExchange(w.getPayload());
h.transcript.update(key_exchange);
try w.advanceRecord(.handshake, key_exchange.len);
}
// Client certificate verify message
if (cert_builder) |cb| {
const certificate_verify = try cb.makeCertificateVerify(w.getPayload());
h.transcript.update(certificate_verify);
try w.advanceRecord(.handshake, certificate_verify.len);
}
// Client change cipher spec message
try w.writeRecord(.change_cipher_spec, &[_]u8{1});
// Client handshake finished message
{
const client_finished = &record.handshakeHeader(.finished, 12) ++
h.transcript.clientFinishedTls12(&h.master_secret);
h.transcript.update(client_finished);
try h.writeEncrypted(&w, client_finished);
}
return w.getWritten();
}
/// Create client change cipher spec and handshake finished messages for
/// tls 1.3.
/// If the client certificate is requested by the server and client is
/// configured with certificates and private key then client certificate
/// and client certificate verify messages are also created. If the
/// server has requested certificate but the client is not configured
/// empty certificate message is sent, as is required by rfc.
fn makeClientFlight2Tls13(h: *HandshakeT, auth: ?CertKeyPair) ![]const u8 {
var w = record.Writer{ .buf = h.buffer };
// Client change cipher spec message
try w.writeRecord(.change_cipher_spec, &[_]u8{1});
if (h.client_certificate_requested) {
if (auth) |a| {
const cb = h.certificateBuilder(a);
{
const certificate = try cb.makeCertificate(w.getPayload());
h.transcript.update(certificate);
try h.writeEncrypted(&w, certificate);
}
{
const certificate_verify = try cb.makeCertificateVerify(w.getPayload());
h.transcript.update(certificate_verify);
try h.writeEncrypted(&w, certificate_verify);
}
} else {
// Empty certificate message and no certificate verify message
const empty_certificate = &record.handshakeHeader(.certificate, 4) ++ [_]u8{ 0, 0, 0, 0 };
h.transcript.update(empty_certificate);
try h.writeEncrypted(&w, empty_certificate);
}
}
// Client handshake finished message
{
const client_finished = try h.makeClientFinishedTls13(w.getPayload());
h.transcript.update(client_finished);
try h.writeEncrypted(&w, client_finished);
}
return w.getWritten();
}
fn certificateBuilder(h: *HandshakeT, auth: CertKeyPair) CertificateBuilder {
return .{
.bundle = auth.bundle,
.key = auth.key,
.transcript = &h.transcript,
.tls_version = h.tls_version,
.side = .client,
};
}
fn makeClientFinishedTls13(h: *HandshakeT, buf: []u8) ![]const u8 {
var w = record.Writer{ .buf = buf };
const verify_data = h.transcript.clientFinishedTls13(w.getHandshakePayload());
try w.advanceHandshake(.finished, verify_data.len);
return w.getWritten();
}
/// Builds the tls 1.2 client key exchange message into `buf`.
/// With a negotiated named group the ephemeral public key is sent
/// (u8 length prefix); otherwise the rsa encrypted pre master secret
/// is sent (u16 length prefix).
fn makeClientKeyExchange(h: *HandshakeT, buf: []u8) ![]const u8 {
    var w = record.Writer{ .buf = buf };
    if (h.named_group) |named_group| {
        const key = try h.dh_kp.publicKey(named_group);
        try w.writeHandshakeHeader(.client_key_exchange, 1 + key.len);
        try w.writeInt(@as(u8, @intCast(key.len)));
        try w.write(key);
    } else {
        const key = try h.rsa_secret.encrypted(h.cert.pub_key_algo, h.cert.pub_key);
        try w.writeHandshakeHeader(.client_key_exchange, 2 + key.len);
        try w.writeInt(@as(u16, @intCast(key.len)));
        try w.write(key);
    }
    return w.getWritten();
}
/// Reads tls 1.2 server flight 2: the change cipher spec record followed
/// by the encrypted server finished message. Fails with TlsBadRecordMac
/// if the finished verify data does not match our transcript.
fn readServerFlight2(h: *HandshakeT) !void {
    // Read server change cipher spec message.
    {
        var d = try h.rec_rdr.nextDecoder();
        try d.expectContentType(.change_cipher_spec);
    }
    // Read encrypted server handshake finished message. Verify that
    // content of the server finished message is based on transcript
    // hash and master secret.
    {
        const content_type, const server_finished =
            try h.rec_rdr.nextDecrypt(&h.cipher) orelse return error.EndOfStream;
        if (content_type != .handshake)
            return error.TlsUnexpectedMessage;
        // Expected message: finished header plus 12 bytes of tls 1.2 verify data.
        const expected = record.handshakeHeader(.finished, 12) ++ h.transcript.serverFinishedTls12(&h.master_secret);
        if (!mem.eql(u8, server_finished, &expected))
            return error.TlsBadRecordMac;
    }
}
/// Encrypts `cleartext` as a handshake record with the current handshake
/// cipher and appends the ciphertext to writer `w`.
fn writeEncrypted(h: *HandshakeT, w: *record.Writer, cleartext: []const u8) !void {
    const sealed = try h.cipher.encrypt(w.getFree(), .handshake, cleartext);
    w.pos += sealed.len;
}
// Copy negotiated handshake parameters into opt.diagnostic, when one is
// provided; a no-op otherwise.
fn updateDiagnostic(h: *HandshakeT, opt: Options) void {
    const d = opt.diagnostic orelse return;
    d.tls_version = h.tls_version;
    d.cipher_suite_tag = h.cipher_suite;
    d.named_group = h.named_group orelse @as(proto.NamedGroup, @enumFromInt(0x0000));
    d.signature_scheme = h.cert.signature_scheme;
    if (opt.auth) |a|
        d.client_signature_scheme = a.key.signature_scheme;
}
};
}
const RsaSecret = struct {
    // 48 byte pre master secret: two protocol version bytes (0x03 0x03)
    // followed by 46 random bytes.
    secret: [48]u8,

    fn init(rand: [46]u8) RsaSecret {
        return .{ .secret = [_]u8{ 0x03, 0x03 } ++ rand };
    }

    // Pre master secret encrypted with certificate public key.
    // NOTE(review): declared inline and returns a slice into the local
    // `out`; inlining appears to be what keeps `out` alive in the caller's
    // frame — confirm before changing to a non-inline fn.
    inline fn encrypted(
        self: RsaSecret,
        cert_pub_key_algo: Certificate.Parsed.PubKeyAlgo,
        cert_pub_key: []const u8,
    ) ![]const u8 {
        // Only rsa certificates can carry the encrypted pre master secret.
        if (cert_pub_key_algo != .rsaEncryption) return error.TlsBadSignatureScheme;
        const pk = try rsa.PublicKey.fromDer(cert_pub_key);
        var out: [512]u8 = undefined;
        return try pk.encryptPkcsv1_5(&self.secret, &out);
    }
};
const testing = std.testing;
const data12 = @import("testdata/tls12.zig");
const data13 = @import("testdata/tls13.zig");
const testu = @import("testu.zig");
// Wraps raw test bytes in a record reader backed by a fixed buffer stream.
fn testReader(data: []const u8) record.Reader(std.io.FixedBufferStream([]const u8)) {
    const stream = std.io.fixedBufferStream(data);
    return record.reader(stream);
}
const TestHandshake = Handshake(std.io.FixedBufferStream([]const u8));
test "parse tls 1.2 server hello" {
    // Keep buffer and reader in the test's own scope: the handshake stores
    // pointers to both, so they must outlive `h`. The previous version
    // declared them inside a labeled block and let the pointers escape it,
    // leaving dangling pointers to expired block-locals.
    var buffer: [1024]u8 = undefined;
    var rec_rdr = testReader(&data12.server_hello_responses);
    var h = TestHandshake.init(&buffer, &rec_rdr);
    // Set to known instead of random
    h.client_random = data12.client_random;
    h.dh_kp.x25519_kp.secret_key = data12.client_secret;
    // Parse server hello, certificate and key exchange messages.
    // Read cipher suite, named group, signature scheme, server random certificate public key
    // Verify host name, signature
    // Calculate key material
    h.cert = .{ .host = "example.ulfheim.net", .skip_verify = true, .root_ca = .{} };
    try h.readServerFlight1();
    try testing.expectEqual(.ECDHE_RSA_WITH_AES_128_CBC_SHA, h.cipher_suite);
    try testing.expectEqual(.x25519, h.named_group.?);
    try testing.expectEqual(.rsa_pkcs1_sha256, h.cert.signature_scheme);
    try testing.expectEqualSlices(u8, &data12.server_random, &h.server_random);
    try testing.expectEqualSlices(u8, &data12.server_pub_key, h.server_pub_key);
    try testing.expectEqualSlices(u8, &data12.signature, h.cert.signature);
    try testing.expectEqualSlices(u8, &data12.cert_pub_key, h.cert.pub_key);
    try h.verifyCertificateSignatureTls12();
    try h.generateKeyMaterial(null);
    // Derived key material must match the reference capture.
    try testing.expectEqualSlices(u8, &data12.key_material, h.key_material[0..data12.key_material.len]);
}
test "verify google.com certificate" {
    // buffer and rec_rdr must outlive `h`, which keeps pointers to them.
    // The previous version declared them inside a labeled block, so the
    // pointers escaped the block and dangled.
    var buffer: [1024]u8 = undefined;
    var rec_rdr = testReader(@embedFile("testdata/google.com/server_hello"));
    var h = TestHandshake.init(&buffer, &rec_rdr);
    h.client_random = @embedFile("testdata/google.com/client_random").*;

    var ca_bundle: Certificate.Bundle = .{};
    try ca_bundle.rescan(testing.allocator);
    defer ca_bundle.deinit(testing.allocator);

    // Fixed timestamp so the captured certificate chain is still "valid".
    h.cert = .{ .host = "google.com", .skip_verify = true, .root_ca = .{}, .now_sec = 1714846451 };
    try h.readServerFlight1();
    try h.verifyCertificateSignatureTls12();
}
// Decodes a captured tls 1.3 server hello and checks every negotiated field.
test "parse tls 1.3 server hello" {
    var rec_rdr = testReader(&data13.server_hello);
    var d = (try rec_rdr.nextDecoder());

    // Handshake message header: type then u24 payload length.
    const handshake_type = try d.decode(proto.Handshake);
    const length = try d.decode(u24);
    try testing.expectEqual(0x000076, length);
    try testing.expectEqual(.server_hello, handshake_type);

    // Reader/buffer are unused by parseServerHello, hence undefined here.
    var h = TestHandshake.init(undefined, undefined);
    try h.parseServerHello(&d, length);

    try testing.expectEqual(.AES_256_GCM_SHA384, h.cipher_suite);
    try testing.expectEqualSlices(u8, &data13.server_random, &h.server_random);
    try testing.expectEqual(.tls_1_3, h.tls_version);
    try testing.expectEqual(.x25519, h.named_group);
    try testing.expectEqualSlices(u8, &data13.server_pub_key, h.server_pub_key);
}
// Derives the tls 1.3 handshake cipher from captured hello messages and
// fixed x25519 keys, and checks keys/ivs against the reference values.
test "init tls 1.3 handshake cipher" {
    const cipher_suite_tag: CipherSuite = .AES_256_GCM_SHA384;

    // Transcript over the client and server hello payloads (record headers excluded).
    var transcript = Transcript{};
    transcript.use(cipher_suite_tag.hash());
    transcript.update(data13.client_hello[record.header_len..]);
    transcript.update(data13.server_hello[record.header_len..]);

    var dh_kp = DhKeyPair{
        .x25519_kp = .{
            .public_key = data13.client_public_key,
            .secret_key = data13.client_private_key,
        },
    };
    const shared_key = try dh_kp.sharedKey(.x25519, &data13.server_pub_key);
    try testing.expectEqualSlices(u8, &data13.shared_key, shared_key);

    const cph = try Cipher.initTls13(cipher_suite_tag, transcript.handshakeSecret(shared_key), .client);
    const c = &cph.AES_256_GCM_SHA384;
    // Client side: decrypts with server keys, encrypts with client keys.
    try testing.expectEqualSlices(u8, &data13.server_handshake_key, &c.decrypt_key);
    try testing.expectEqualSlices(u8, &data13.client_handshake_key, &c.encrypt_key);
    try testing.expectEqualSlices(u8, &data13.server_handshake_iv, &c.decrypt_iv);
    try testing.expectEqualSlices(u8, &data13.client_handshake_iv, &c.encrypt_iv);
}
// Puts a TestHandshake into the state it would have after processing the
// example.ulfheim.net tls 1.3 hello messages, using fixed test keys from
// testdata/tls13.zig.
fn initExampleHandshake(h: *TestHandshake) !void {
    h.cipher_suite = .AES_256_GCM_SHA384;
    h.transcript.use(h.cipher_suite.hash());
    h.transcript.update(data13.client_hello[record.header_len..]);
    h.transcript.update(data13.server_hello[record.header_len..]);
    h.cipher = try Cipher.initTls13(h.cipher_suite, h.transcript.handshakeSecret(&data13.shared_key), .client);
    h.tls_version = .tls_1_3;
    // Fixed timestamp matching the captured certificate's validity window.
    h.cert.now_sec = 1714846451;
    h.server_pub_key = &data13.server_pub_key;
}
// Decrypts two captured tls 1.3 wrapped handshake records and compares the
// cleartext with the known plaintext messages.
test "tls 1.3 decrypt wrapped record" {
    // The cipher is copied out of the block by value; no pointers escape.
    var cph = brk: {
        var h = TestHandshake.init(undefined, undefined);
        try initExampleHandshake(&h);
        break :brk h.cipher;
    };
    var cleartext_buf: [1024]u8 = undefined;
    {
        const rec = record.Record.init(&data13.server_encrypted_extensions_wrapped);

        const content_type, const cleartext = try cph.decrypt(&cleartext_buf, rec);
        try testing.expectEqual(.handshake, content_type);
        try testing.expectEqualSlices(u8, &data13.server_encrypted_extensions, cleartext);
    }
    {
        const rec = record.Record.init(&data13.server_certificate_wrapped);

        const content_type, const cleartext = try cph.decrypt(&cleartext_buf, rec);
        try testing.expectEqual(.handshake, content_type);
        try testing.expectEqualSlices(u8, &data13.server_certificate, cleartext);
    }
}
test "tls 1.3 process server flight" {
    var buffer: [1024]u8 = undefined;
    // rec_rdr must outlive `h`, which keeps a pointer to it. The previous
    // version declared it inside a labeled block, so &rec_rdr escaped the
    // block and dangled while readEncryptedServerFlight1 still used it.
    var rec_rdr = testReader(&data13.server_flight);
    var h = TestHandshake.init(&buffer, &rec_rdr);
    try initExampleHandshake(&h);
    h.cert = .{ .host = "example.ulfheim.net", .skip_verify = true, .root_ca = .{} };
    try h.readEncryptedServerFlight1();

    { // application cipher keys calculation
        try testing.expectEqualSlices(u8, &data13.handshake_hash, &h.transcript.sha384.hash.peek());

        var cph = try Cipher.initTls13(h.cipher_suite, h.transcript.applicationSecret(), .client);
        const c = &cph.AES_256_GCM_SHA384;
        try testing.expectEqualSlices(u8, &data13.server_application_key, &c.decrypt_key);
        try testing.expectEqualSlices(u8, &data13.client_application_key, &c.encrypt_key);
        try testing.expectEqualSlices(u8, &data13.server_application_iv, &c.decrypt_iv);
        try testing.expectEqualSlices(u8, &data13.client_application_iv, &c.encrypt_iv);

        const encrypted = try cph.encrypt(&buffer, .application_data, "ping");
        try testing.expectEqualSlices(u8, &data13.client_ping_wrapped, encrypted);
    }
    { // client finished message
        var buf: [4 + Transcript.max_mac_length]u8 = undefined;
        const client_finished = try h.makeClientFinishedTls13(&buf);
        try testing.expectEqualSlices(u8, &data13.client_finished_verify_data, client_finished[4..]);

        const encrypted = try h.cipher.encrypt(&buffer, .handshake, client_finished);
        try testing.expectEqualSlices(u8, &data13.client_finished_wrapped, encrypted);
    }
}
test "create client hello" {
    // buffer must outlive `h` (it becomes h.buffer). The previous version
    // declared it inside a labeled block, so h.buffer pointed at an expired
    // block-local when makeClientHello wrote into it.
    var buffer: [1024]u8 = undefined;
    var h = TestHandshake.init(&buffer, undefined);
    // Known client random instead of a random one, so the output is stable.
    h.client_random = testu.hexToBytes(
        \\ 00 01 02 03 04 05 06 07 08 09 0a 0b 0c 0d 0e 0f 10 11 12 13 14 15 16 17 18 19 1a 1b 1c 1d 1e 1f
    );

    const actual = try h.makeClientHello(.{
        .host = "google.com",
        .root_ca = .{},
        .cipher_suites = &[_]CipherSuite{CipherSuite.ECDHE_ECDSA_WITH_AES_128_GCM_SHA256},
        .named_groups = &[_]proto.NamedGroup{ .x25519, .secp256r1, .secp384r1 },
    });

    const expected = testu.hexToBytes(
        "16 03 03 00 6d " ++ // record header
            "01 00 00 69 " ++ // handshake header
            "03 03 " ++ // protocol version
            "00 01 02 03 04 05 06 07 08 09 0a 0b 0c 0d 0e 0f 10 11 12 13 14 15 16 17 18 19 1a 1b 1c 1d 1e 1f " ++ // client random
            "00 " ++ // no session id
            "00 02 c0 2b " ++ // cipher suites
            "01 00 " ++ // compression methods
            "00 3e " ++ // extensions length
            "00 2b 00 03 02 03 03 " ++ // supported versions extension
            "00 0d 00 14 00 12 04 03 05 03 08 04 08 05 08 06 08 07 02 01 04 01 05 01 " ++ // signature algorithms extension
            "00 0a 00 08 00 06 00 1d 00 17 00 18 " ++ // named groups extension
            "00 00 00 0f 00 0d 00 00 0a 67 6f 6f 67 6c 65 2e 63 6f 6d ", // server name extension
    );
    try testing.expectEqualSlices(u8, &expected, actual);
}
// Rebuilds the tls 1.2 transcript from captured handshake messages, checks
// the client finished verify data, then verifies the server finished message.
test "handshake verify server finished message" {
    var buffer: [1024]u8 = undefined;
    var rec_rdr = testReader(&data12.server_handshake_finished_msgs);
    var h = TestHandshake.init(&buffer, &rec_rdr);
    h.cipher_suite = .ECDHE_ECDSA_WITH_AES_128_CBC_SHA;
    h.master_secret = data12.master_secret;

    // add handshake messages to the transcript
    for (data12.handshake_messages) |msg| {
        h.transcript.update(msg[record.header_len..]);
    }

    // expect verify data
    const client_finished = h.transcript.clientFinishedTls12(&h.master_secret);
    try testing.expectEqualSlices(u8, &data12.client_finished, &record.handshakeHeader(.finished, 12) ++ client_finished);

    // init client with prepared key_material
    h.cipher = try Cipher.initTls12(.ECDHE_RSA_WITH_AES_128_CBC_SHA, &data12.key_material, .client);

    // check that server verify data matches calculates from hashes of all handshake messages
    h.transcript.update(&data12.client_finished);
    try h.readServerFlight2();
}

View File

@@ -1,448 +0,0 @@
const std = @import("std");
const assert = std.debug.assert;
const mem = std.mem;
const crypto = std.crypto;
const Certificate = crypto.Certificate;
const Transcript = @import("transcript.zig").Transcript;
const PrivateKey = @import("PrivateKey.zig");
const record = @import("record.zig");
const rsa = @import("rsa/rsa.zig");
const proto = @import("protocol.zig");
const X25519 = crypto.dh.X25519;
const EcdsaP256Sha256 = crypto.sign.ecdsa.EcdsaP256Sha256;
const EcdsaP384Sha384 = crypto.sign.ecdsa.EcdsaP384Sha384;
const Kyber768 = crypto.kem.kyber_d00.Kyber768;
/// Signature schemes this library supports for certificate verification.
pub const supported_signature_algorithms = &[_]proto.SignatureScheme{
    .ecdsa_secp256r1_sha256,
    .ecdsa_secp384r1_sha384,
    .rsa_pss_rsae_sha256,
    .rsa_pss_rsae_sha384,
    .rsa_pss_rsae_sha512,
    .ed25519,
    .rsa_pkcs1_sha1,
    .rsa_pkcs1_sha256,
    .rsa_pkcs1_sha384,
};
pub const CertKeyPair = struct {
    /// A chain of one or more certificates, leaf first.
    ///
    /// Each X.509 certificate contains the public key of a key pair, extra
    /// information (the name of the holder, the name of an issuer of the
    /// certificate, validity time spans) and a signature generated using the
    /// private key of the issuer of the certificate.
    ///
    /// All certificates from the bundle are sent to the other side when creating
    /// Certificate tls message.
    ///
    /// Leaf certificate and private key are used to create signature for
    /// CertifyVerify tls message.
    bundle: Certificate.Bundle,

    /// Private key corresponding to the public key in leaf certificate from the
    /// bundle.
    key: PrivateKey,

    /// Loads the certificate chain from `cert_path` and the private key from
    /// `key_path`, both relative to `dir`. Caller owns the result and must
    /// release it with `deinit`.
    pub fn load(
        allocator: std.mem.Allocator,
        dir: std.fs.Dir,
        cert_path: []const u8,
        key_path: []const u8,
    ) !CertKeyPair {
        var bundle: Certificate.Bundle = .{};
        try bundle.addCertsFromFilePath(allocator, dir, cert_path);

        const key_file = try dir.openFile(key_path, .{});
        defer key_file.close();
        const key = try PrivateKey.fromFile(allocator, key_file);

        return .{ .bundle = bundle, .key = key };
    }

    /// Frees the certificate bundle. `allocator` must match the one given to `load`.
    pub fn deinit(c: *CertKeyPair, allocator: std.mem.Allocator) void {
        c.bundle.deinit(allocator);
    }
};
pub const CertBundle = struct {
    // A chain of one or more certificates.
    //
    // They are used to verify that certificate chain sent by the other side
    // forms valid trust chain.
    bundle: Certificate.Bundle = .{},

    /// Reads root certificates from the file at `path` relative to `dir`.
    /// Caller owns the result; release with `deinit`.
    pub fn fromFile(allocator: std.mem.Allocator, dir: std.fs.Dir, path: []const u8) !CertBundle {
        var cb: CertBundle = .{};
        try cb.bundle.addCertsFromFilePath(allocator, dir, path);
        return cb;
    }

    /// Scans the system trust store for root certificates.
    /// Caller owns the result; release with `deinit`.
    pub fn fromSystem(allocator: std.mem.Allocator) !CertBundle {
        var cb: CertBundle = .{};
        try cb.bundle.rescan(allocator);
        return cb;
    }

    /// Frees the bundle. `allocator` must match the one used to build it.
    pub fn deinit(cb: *CertBundle, allocator: std.mem.Allocator) void {
        cb.bundle.deinit(allocator);
    }
};
/// Builds the Certificate and CertificateVerify handshake messages from a
/// configured certificate bundle and private key. Used by both sides
/// (`side` selects the tls 1.3 signature context).
pub const CertificateBuilder = struct {
    bundle: Certificate.Bundle,
    key: PrivateKey,
    // Transcript of handshake messages so far; signatures are bound to it.
    transcript: *Transcript,
    tls_version: proto.Version = .tls_1_3,
    side: proto.Side = .client,

    /// Builds the Certificate handshake message into `buf`, containing all
    /// certificates from the bundle. Returns the written slice.
    pub fn makeCertificate(h: CertificateBuilder, buf: []u8) ![]const u8 {
        var w = record.Writer{ .buf = buf };
        const certs = h.bundle.bytes.items;
        const certs_count = h.bundle.map.size;

        // Differences between tls 1.3 and 1.2
        // TLS 1.3 has request context in header and extensions for each certificate.
        // Here we use empty length for each field.
        // TLS 1.2 don't have these two fields.
        const request_context, const extensions = if (h.tls_version == .tls_1_3)
            .{ &[_]u8{0}, &[_]u8{ 0, 0 } }
        else
            .{ &[_]u8{}, &[_]u8{} };
        // Each certificate carries a 3 byte (u24) length prefix plus its extensions.
        const certs_len = certs.len + (3 + extensions.len) * certs_count;

        // Write handshake header
        try w.writeHandshakeHeader(.certificate, certs_len + request_context.len + 3);
        try w.write(request_context);
        try w.writeInt(@as(u24, @intCast(certs_len)));

        // Write each certificate
        var index: u32 = 0;
        while (index < certs.len) {
            // Certificates are concatenated der elements; parse to find each end.
            const e = try Certificate.der.Element.parse(certs, index);
            const cert = certs[index..e.slice.end];
            try w.writeInt(@as(u24, @intCast(cert.len))); // certificate length
            try w.write(cert); // certificate
            try w.write(extensions); // certificate extensions
            index = e.slice.end;
        }
        return w.getWritten();
    }

    /// Builds the CertificateVerify handshake message into `buf`:
    /// signature scheme, u16 signature length and the signature bytes.
    pub fn makeCertificateVerify(h: CertificateBuilder, buf: []u8) ![]const u8 {
        var w = record.Writer{ .buf = buf };
        const signature, const signature_scheme = try h.createSignature();
        try w.writeHandshakeHeader(.certificate_verify, signature.len + 4);
        try w.writeEnum(signature_scheme);
        try w.writeInt(@as(u16, @intCast(signature.len)));
        try w.write(signature);
        return w.getWritten();
    }

    /// Creates signature for client certificate signature message.
    /// Returns signature bytes and signature scheme.
    /// NOTE(review): inline and returns slices into local buffers — inlining
    /// appears to be what keeps them valid in the caller's frame.
    inline fn createSignature(h: CertificateBuilder) !struct { []const u8, proto.SignatureScheme } {
        switch (h.key.signature_scheme) {
            inline .ecdsa_secp256r1_sha256,
            .ecdsa_secp384r1_sha384,
            => |comptime_scheme| {
                const Ecdsa = SchemeEcdsa(comptime_scheme);
                const key = h.key.key.ecdsa;
                const key_len = Ecdsa.SecretKey.encoded_length;
                if (key.len < key_len) return error.InvalidEncoding;
                const secret_key = try Ecdsa.SecretKey.fromBytes(key[0..key_len].*);
                const key_pair = try Ecdsa.KeyPair.fromSecretKey(secret_key);
                var signer = try key_pair.signer(null);
                h.setSignatureVerifyBytes(&signer);
                const signature = try signer.finalize();
                var buf: [Ecdsa.Signature.der_encoded_length_max]u8 = undefined;
                return .{ signature.toDer(&buf), comptime_scheme };
            },
            inline .rsa_pss_rsae_sha256,
            .rsa_pss_rsae_sha384,
            .rsa_pss_rsae_sha512,
            => |comptime_scheme| {
                const Hash = SchemeHash(comptime_scheme);
                var signer = try h.key.key.rsa.signerOaep(Hash, null);
                h.setSignatureVerifyBytes(&signer);
                var buf: [512]u8 = undefined;
                const signature = try signer.finalize(&buf);
                return .{ signature.bytes, comptime_scheme };
            },
            else => return error.TlsUnknownSignatureScheme,
        }
    }

    /// Feeds the bytes to be signed into `signer`; tls 1.2 and tls 1.3
    /// derive them from the transcript differently.
    fn setSignatureVerifyBytes(h: CertificateBuilder, signer: anytype) void {
        if (h.tls_version == .tls_1_2) {
            // tls 1.2 signature uses current transcript hash value.
            // ref: https://datatracker.ietf.org/doc/html/rfc5246.html#section-7.4.8
            const Hash = @TypeOf(signer.h);
            signer.h = h.transcript.hash(Hash);
        } else {
            // tls 1.3 signature is computed over concatenation of 64 spaces,
            // context, separator and content.
            // ref: https://datatracker.ietf.org/doc/html/rfc8446#section-4.4.3
            if (h.side == .server) {
                signer.update(h.transcript.serverCertificateVerify());
            } else {
                signer.update(h.transcript.clientCertificateVerify());
            }
        }
    }

    /// Maps an ecdsa signature scheme to the matching std ecdsa type.
    fn SchemeEcdsa(comptime scheme: proto.SignatureScheme) type {
        return switch (scheme) {
            .ecdsa_secp256r1_sha256 => EcdsaP256Sha256,
            .ecdsa_secp384r1_sha384 => EcdsaP384Sha384,
            else => unreachable,
        };
    }
};
/// Parses Certificate and CertificateVerify messages received from the
/// peer, verifies the trust chain against `root_ca` and the signature
/// against the leaf's public key.
pub const CertificateParser = struct {
    pub_key_algo: Certificate.Parsed.PubKeyAlgo = undefined,
    // Leaf certificate public key, copied into the fixed buffer below.
    pub_key_buf: [600]u8 = undefined,
    pub_key: []const u8 = undefined,

    signature_scheme: proto.SignatureScheme = @enumFromInt(0),
    // Peer's certificate verify signature, copied into the fixed buffer below.
    signature_buf: [1024]u8 = undefined,
    signature: []const u8 = undefined,

    // Trusted root certificates used to establish the chain.
    root_ca: Certificate.Bundle,
    // Expected host name; checked against the leaf certificate when non-empty.
    host: []const u8,
    // When true, skips both host name and trust chain verification.
    skip_verify: bool = false,
    // Timestamp for certificate validity checks; 0 means "use current time".
    now_sec: i64 = 0,

    /// Parses the Certificate handshake message from decoder `d`.
    /// Extracts the leaf public key and, unless `skip_verify`, requires a
    /// chain from the leaf to one of the `root_ca` certificates.
    pub fn parseCertificate(h: *CertificateParser, d: *record.Decoder, tls_version: proto.Version) !void {
        if (h.now_sec == 0) {
            h.now_sec = std.time.timestamp();
        }
        if (tls_version == .tls_1_3) {
            // tls 1.3 adds a request context byte; must be zero in server flight.
            const request_context = try d.decode(u8);
            if (request_context != 0) return error.TlsIllegalParameter;
        }

        var trust_chain_established = false;
        var last_cert: ?Certificate.Parsed = null;
        const certs_len = try d.decode(u24);
        const start_idx = d.idx;
        while (d.idx - start_idx < certs_len) {
            const cert_len = try d.decode(u24);
            // std.debug.print("=> {} {} {} {}\n", .{ certs_len, d.idx, cert_len, d.payload.len });
            const cert = try d.slice(cert_len);
            if (tls_version == .tls_1_3) {
                // certificate extensions present in tls 1.3
                try d.skip(try d.decode(u16));
            }
            if (trust_chain_established)
                // Chain already trusted; keep consuming remaining certificates.
                continue;

            const subject = try (Certificate{ .buffer = cert, .index = 0 }).parse();
            if (last_cert) |pc| {
                if (pc.verify(subject, h.now_sec)) {
                    last_cert = subject;
                } else |err| switch (err) {
                    error.CertificateIssuerMismatch => {
                        // skip certificate which is not part of the chain
                        continue;
                    },
                    else => return err,
                }
            } else { // first certificate
                if (!h.skip_verify and h.host.len > 0) {
                    try subject.verifyHostName(h.host);
                }
                h.pub_key = dupe(&h.pub_key_buf, subject.pubKey());
                h.pub_key_algo = subject.pub_key_algo;
                last_cert = subject;
            }
            if (!h.skip_verify) {
                if (h.root_ca.verify(last_cert.?, h.now_sec)) |_| {
                    trust_chain_established = true;
                } else |err| switch (err) {
                    // Not signed by a root yet; try the next certificate in the chain.
                    error.CertificateIssuerNotFound => {},
                    else => return err,
                }
            }
        }
        if (!h.skip_verify and !trust_chain_established) {
            return error.CertificateIssuerNotFound;
        }
    }

    /// Parses the CertificateVerify message: signature scheme and signature.
    pub fn parseCertificateVerify(h: *CertificateParser, d: *record.Decoder) !void {
        h.signature_scheme = try d.decode(proto.SignatureScheme);
        h.signature = dupe(&h.signature_buf, try d.slice(try d.decode(u16)));
    }

    /// Verifies the stored signature over `verify_bytes` using the leaf
    /// certificate public key and the parsed signature scheme.
    pub fn verifySignature(h: *CertificateParser, verify_bytes: []const u8) !void {
        switch (h.signature_scheme) {
            inline .ecdsa_secp256r1_sha256,
            .ecdsa_secp384r1_sha384,
            => |comptime_scheme| {
                // Signature scheme must match the certificate key type.
                if (h.pub_key_algo != .X9_62_id_ecPublicKey) return error.TlsBadSignatureScheme;
                const cert_named_curve = h.pub_key_algo.X9_62_id_ecPublicKey;
                switch (cert_named_curve) {
                    inline .secp384r1, .X9_62_prime256v1 => |comptime_cert_named_curve| {
                        const Ecdsa = SchemeEcdsaCert(comptime_scheme, comptime_cert_named_curve);
                        const key = try Ecdsa.PublicKey.fromSec1(h.pub_key);
                        const sig = try Ecdsa.Signature.fromDer(h.signature);
                        try sig.verify(verify_bytes, key);
                    },
                    else => return error.TlsUnknownSignatureScheme,
                }
            },
            .ed25519 => {
                if (h.pub_key_algo != .curveEd25519) return error.TlsBadSignatureScheme;
                const Eddsa = crypto.sign.Ed25519;
                if (h.signature.len != Eddsa.Signature.encoded_length) return error.InvalidEncoding;
                const sig = Eddsa.Signature.fromBytes(h.signature[0..Eddsa.Signature.encoded_length].*);
                if (h.pub_key.len != Eddsa.PublicKey.encoded_length) return error.InvalidEncoding;
                const key = try Eddsa.PublicKey.fromBytes(h.pub_key[0..Eddsa.PublicKey.encoded_length].*);
                try sig.verify(verify_bytes, key);
            },
            inline .rsa_pss_rsae_sha256,
            .rsa_pss_rsae_sha384,
            .rsa_pss_rsae_sha512,
            => |comptime_scheme| {
                if (h.pub_key_algo != .rsaEncryption) return error.TlsBadSignatureScheme;
                const Hash = SchemeHash(comptime_scheme);
                const pk = try rsa.PublicKey.fromDer(h.pub_key);
                const sig = rsa.Pss(Hash).Signature{ .bytes = h.signature };
                try sig.verify(verify_bytes, pk, null);
            },
            inline .rsa_pkcs1_sha1,
            .rsa_pkcs1_sha256,
            .rsa_pkcs1_sha384,
            .rsa_pkcs1_sha512,
            => |comptime_scheme| {
                if (h.pub_key_algo != .rsaEncryption) return error.TlsBadSignatureScheme;
                const Hash = SchemeHash(comptime_scheme);
                const pk = try rsa.PublicKey.fromDer(h.pub_key);
                const sig = rsa.PKCS1v1_5(Hash).Signature{ .bytes = h.signature };
                try sig.verify(verify_bytes, pk);
            },
            else => return error.TlsUnknownSignatureScheme,
        }
    }

    /// Maps a signature scheme plus certificate curve to the std ecdsa type.
    fn SchemeEcdsaCert(comptime scheme: proto.SignatureScheme, comptime cert_named_curve: Certificate.NamedCurve) type {
        const Sha256 = crypto.hash.sha2.Sha256;
        const Sha384 = crypto.hash.sha2.Sha384;
        const Ecdsa = crypto.sign.ecdsa.Ecdsa;
        return switch (scheme) {
            .ecdsa_secp256r1_sha256 => Ecdsa(cert_named_curve.Curve(), Sha256),
            .ecdsa_secp384r1_sha384 => Ecdsa(cert_named_curve.Curve(), Sha384),
            else => @compileError("bad scheme"),
        };
    }
};
/// Maps an rsa signature scheme to its hash function type.
fn SchemeHash(comptime scheme: proto.SignatureScheme) type {
    const sha2 = crypto.hash.sha2;
    return switch (scheme) {
        .rsa_pkcs1_sha1 => crypto.hash.Sha1,
        .rsa_pss_rsae_sha256, .rsa_pkcs1_sha256 => sha2.Sha256,
        .rsa_pss_rsae_sha384, .rsa_pkcs1_sha384 => sha2.Sha384,
        .rsa_pss_rsae_sha512, .rsa_pkcs1_sha512 => sha2.Sha512,
        else => @compileError("bad scheme"),
    };
}
/// Copies `data` into `buf`, truncating to whichever is shorter, and
/// returns the filled prefix of `buf`.
pub fn dupe(buf: []u8, data: []const u8) []u8 {
    const count = @min(buf.len, data.len);
    std.mem.copyForwards(u8, buf[0..count], data[0..count]);
    return buf[0..count];
}
/// Ephemeral key pairs for all supported key exchange groups. Only the
/// fields for the groups listed in `init` are actually generated; the
/// rest stay undefined.
pub const DhKeyPair = struct {
    x25519_kp: X25519.KeyPair = undefined,
    secp256r1_kp: EcdsaP256Sha256.KeyPair = undefined,
    secp384r1_kp: EcdsaP384Sha384.KeyPair = undefined,
    kyber768_kp: Kyber768.KeyPair = undefined,

    // Seed partitioned across the four key types: 32 + 32 + 48 + 64 bytes.
    pub const seed_len = 32 + 32 + 48 + 64;

    /// Generates key pairs for each group in `named_groups`, deterministic
    /// in `seed`. Fails with TlsIllegalParameter on an unsupported group.
    pub fn init(seed: [seed_len]u8, named_groups: []const proto.NamedGroup) !DhKeyPair {
        var kp: DhKeyPair = .{};
        for (named_groups) |ng|
            switch (ng) {
                .x25519 => kp.x25519_kp = try X25519.KeyPair.create(seed[0..][0..X25519.seed_length].*),
                .secp256r1 => kp.secp256r1_kp = try EcdsaP256Sha256.KeyPair.create(seed[32..][0..EcdsaP256Sha256.KeyPair.seed_length].*),
                .secp384r1 => kp.secp384r1_kp = try EcdsaP384Sha384.KeyPair.create(seed[32 + 32 ..][0..EcdsaP384Sha384.KeyPair.seed_length].*),
                .x25519_kyber768d00 => kp.kyber768_kp = try Kyber768.KeyPair.create(seed[32 + 32 + 48 ..][0..Kyber768.seed_length].*),
                else => return error.TlsIllegalParameter,
            };
        return kp;
    }

    /// Computes the shared secret with the peer's public key for the given
    /// group.
    /// NOTE(review): inline and returns slices of stack temporaries —
    /// inlining appears to be what keeps them valid in the caller's frame.
    pub inline fn sharedKey(self: DhKeyPair, named_group: proto.NamedGroup, server_pub_key: []const u8) ![]const u8 {
        return switch (named_group) {
            .x25519 => brk: {
                if (server_pub_key.len != X25519.public_length)
                    return error.TlsIllegalParameter;
                break :brk &(try X25519.scalarmult(
                    self.x25519_kp.secret_key,
                    server_pub_key[0..X25519.public_length].*,
                ));
            },
            .secp256r1 => brk: {
                const pk = try EcdsaP256Sha256.PublicKey.fromSec1(server_pub_key);
                const mul = try pk.p.mulPublic(self.secp256r1_kp.secret_key.bytes, .big);
                // Shared secret is the x coordinate of the product point.
                break :brk &mul.affineCoordinates().x.toBytes(.big);
            },
            .secp384r1 => brk: {
                const pk = try EcdsaP384Sha384.PublicKey.fromSec1(server_pub_key);
                const mul = try pk.p.mulPublic(self.secp384r1_kp.secret_key.bytes, .big);
                break :brk &mul.affineCoordinates().x.toBytes(.big);
            },
            .x25519_kyber768d00 => brk: {
                // Hybrid: x25519 key followed by the kyber768 ciphertext.
                const xksl = crypto.dh.X25519.public_length;
                const hksl = xksl + Kyber768.ciphertext_length;
                if (server_pub_key.len != hksl)
                    return error.TlsIllegalParameter;
                break :brk &((crypto.dh.X25519.scalarmult(
                    self.x25519_kp.secret_key,
                    server_pub_key[0..xksl].*,
                ) catch return error.TlsDecryptFailure) ++ (self.kyber768_kp.secret_key.decaps(
                    server_pub_key[xksl..hksl],
                ) catch return error.TlsDecryptFailure));
            },
            else => return error.TlsIllegalParameter,
        };
    }

    // Returns 32, 65, 97 or 1216 bytes
    pub inline fn publicKey(self: DhKeyPair, named_group: proto.NamedGroup) ![]const u8 {
        return switch (named_group) {
            .x25519 => &self.x25519_kp.public_key,
            .secp256r1 => &self.secp256r1_kp.public_key.toUncompressedSec1(),
            .secp384r1 => &self.secp384r1_kp.public_key.toUncompressedSec1(),
            .x25519_kyber768d00 => &self.x25519_kp.public_key ++ self.kyber768_kp.public_key.toBytes(),
            else => return error.TlsIllegalParameter,
        };
    }
};
const testing = std.testing;
const testu = @import("testu.zig");
// x25519 shared key derivation against a known server key and expected secret.
test "DhKeyPair.x25519" {
    var seed: [DhKeyPair.seed_len]u8 = undefined;
    testu.fill(&seed);
    const server_pub_key = &testu.hexToBytes("3303486548531f08d91e675caf666c2dc924ac16f47a861a7f4d05919d143637");
    const expected = &testu.hexToBytes(
        \\ F1 67 FB 4A 49 B2 91 77 08 29 45 A1 F7 08 5A 21
        \\ AF FE 9E 78 C2 03 9B 81 92 40 72 73 74 7A 46 1E
    );
    const kp = try DhKeyPair.init(seed, &.{.x25519});
    try testing.expectEqualSlices(u8, expected, try kp.sharedKey(.x25519, server_pub_key));
}

View File

@@ -1,520 +0,0 @@
const std = @import("std");
const assert = std.debug.assert;
const crypto = std.crypto;
const mem = std.mem;
const Certificate = crypto.Certificate;
const cipher = @import("cipher.zig");
const Cipher = cipher.Cipher;
const CipherSuite = @import("cipher.zig").CipherSuite;
const cipher_suites = @import("cipher.zig").cipher_suites;
const Transcript = @import("transcript.zig").Transcript;
const record = @import("record.zig");
const PrivateKey = @import("PrivateKey.zig");
const proto = @import("protocol.zig");
const common = @import("handshake_common.zig");
const dupe = common.dupe;
const CertificateBuilder = common.CertificateBuilder;
const CertificateParser = common.CertificateParser;
const DhKeyPair = common.DhKeyPair;
const CertBundle = common.CertBundle;
const CertKeyPair = common.CertKeyPair;
/// Server side handshake configuration.
pub const Options = struct {
    /// Server authentication. If null server will not send Certificate and
    /// CertificateVerify message.
    auth: ?CertKeyPair,

    /// If not null server will request client certificate. If auth_type is
    /// .request empty client certificate message will be accepted.
    /// Client certificate will be verified with root_ca certificates.
    client_auth: ?ClientAuth = null,
};
/// Client certificate authentication policy used by the server.
pub const ClientAuth = struct {
    /// Set of root certificate authorities that server use when verifying
    /// client certificates.
    root_ca: CertBundle,

    auth_type: Type = .require,

    pub const Type = enum {
        /// Client certificate will be requested during the handshake, but does
        /// not require that the client send any certificates.
        request,
        /// Client certificate will be requested during the handshake, and client
        /// has to send valid certificate.
        require,
    };
};
pub fn Handshake(comptime Stream: type) type {
const RecordReaderT = record.Reader(Stream);
return struct {
// public key len: x25519 = 32, secp256r1 = 65, secp384r1 = 97
const max_pub_key_len = 98;
const supported_named_groups = &[_]proto.NamedGroup{ .x25519, .secp256r1, .secp384r1 };

// Random values exchanged in the hello messages.
server_random: [32]u8 = undefined,
client_random: [32]u8 = undefined,
legacy_session_id_buf: [32]u8 = undefined,
// Slice into legacy_session_id_buf; presumably the session id received
// from the client — confirm in readClientHello.
legacy_session_id: []u8 = "",
// Negotiated parameters; zero-valued until the client hello is parsed.
cipher_suite: CipherSuite = @enumFromInt(0),
signature_scheme: proto.SignatureScheme = @enumFromInt(0),
named_group: proto.NamedGroup = @enumFromInt(0),
// Key exchange public keys, copied into the fixed buffers above them.
client_pub_key_buf: [max_pub_key_len]u8 = undefined,
client_pub_key: []u8 = "",
server_pub_key_buf: [max_pub_key_len]u8 = undefined,
server_pub_key: []u8 = "",
// Handshake cipher, set once the handshake secret is derived.
cipher: Cipher = undefined,
// Running hash over all handshake messages.
transcript: Transcript = .{},
rec_rdr: *RecordReaderT,
// Scratch buffer for building outgoing messages.
buffer: []u8,

const HandshakeT = @This();
/// Prepares a handshake over the given scratch buffer and record reader.
pub fn init(buf: []u8, rec_rdr: *RecordReaderT) HandshakeT {
    return .{ .buffer = buf, .rec_rdr = rec_rdr };
}
/// Sends a tls alert derived from `err` to the peer. Pass a cipher to
/// send it encrypted, null to send it in cleartext (before keys exist).
/// Write failures are ignored: the connection is being torn down anyway.
fn writeAlert(h: *HandshakeT, stream: Stream, cph: ?*Cipher, err: anyerror) !void {
    if (cph) |c| {
        const cleartext = proto.alertFromError(err);
        const ciphertext = try c.encrypt(h.buffer, .alert, &cleartext);
        stream.writeAll(ciphertext) catch {};
    } else {
        const alert = record.header(.alert, 2) ++ proto.alertFromError(err);
        stream.writeAll(&alert) catch {};
    }
}
/// Runs the tls 1.3 server side handshake on `stream`: reads the client
/// hello, writes the whole server flight (hello, change cipher spec,
/// encrypted extensions, optional certificate request, certificate and
/// certificate verify, finished), then reads the client's second flight.
/// On failure an alert is sent to the client before returning the error.
/// Returns the application data cipher.
pub fn handshake(h: *HandshakeT, stream: Stream, opt: Options) !Cipher {
    crypto.random.bytes(&h.server_random);
    if (opt.auth) |a| {
        // required signature scheme in client hello
        h.signature_scheme = a.key.signature_scheme;
    }

    h.readClientHello() catch |err| {
        try h.writeAlert(stream, null, err);
        return err;
    };
    h.transcript.use(h.cipher_suite.hash());

    // Build the complete server flight in h.buffer. Every handshake
    // message is added to the transcript before encryption; message
    // order must not change.
    const server_flight = brk: {
        var w = record.Writer{ .buf = h.buffer };

        const shared_key = h.sharedKey() catch |err| {
            try h.writeAlert(stream, null, err);
            return err;
        };
        {
            const hello = try h.makeServerHello(w.getFree());
            h.transcript.update(hello[record.header_len..]);
            w.pos += hello.len;
        }
        {
            // Handshake cipher derived from transcript-so-far and shared key.
            const handshake_secret = h.transcript.handshakeSecret(shared_key);
            h.cipher = try Cipher.initTls13(h.cipher_suite, handshake_secret, .server);
        }
        try w.writeRecord(.change_cipher_spec, &[_]u8{1});
        {
            // Encrypted extensions message with zero extensions.
            const encrypted_extensions = &record.handshakeHeader(.encrypted_extensions, 2) ++ [_]u8{ 0, 0 };
            h.transcript.update(encrypted_extensions);
            try h.writeEncrypted(&w, encrypted_extensions);
        }
        if (opt.client_auth) |_| {
            const certificate_request = try makeCertificateRequest(w.getPayload());
            h.transcript.update(certificate_request);
            try h.writeEncrypted(&w, certificate_request);
        }
        if (opt.auth) |a| {
            const cm = CertificateBuilder{
                .bundle = a.bundle,
                .key = a.key,
                .transcript = &h.transcript,
                .side = .server,
            };
            {
                const certificate = try cm.makeCertificate(w.getPayload());
                h.transcript.update(certificate);
                try h.writeEncrypted(&w, certificate);
            }
            {
                const certificate_verify = try cm.makeCertificateVerify(w.getPayload());
                h.transcript.update(certificate_verify);
                try h.writeEncrypted(&w, certificate_verify);
            }
        }
        {
            const finished = try h.makeFinished(w.getPayload());
            h.transcript.update(finished);
            try h.writeEncrypted(&w, finished);
        }
        break :brk w.getWritten();
    };
    try stream.writeAll(server_flight);

    // Application cipher derived from transcript including server finished.
    var app_cipher = brk: {
        const application_secret = h.transcript.applicationSecret();
        break :brk try Cipher.initTls13(h.cipher_suite, application_secret, .server);
    };

    h.readClientFlight2(opt) catch |err| {
        // Alert received from client
        if (!mem.startsWith(u8, @errorName(err), "TlsAlert")) {
            try h.writeAlert(stream, &app_cipher, err);
        }
        return err;
    };
    return app_cipher;
}
/// Generates ephemeral key pairs, stores our public key for the
/// negotiated named group (sent to the client in server hello) and
/// returns the shared secret computed with the client's public key.
/// NOTE(review): inline and kp.sharedKey returns slices of stack data —
/// inlining appears to be what keeps the result valid in the caller.
inline fn sharedKey(h: *HandshakeT) ![]const u8 {
    var seed: [DhKeyPair.seed_len]u8 = undefined;
    crypto.random.bytes(&seed);
    var kp = try DhKeyPair.init(seed, supported_named_groups);
    h.server_pub_key = dupe(&h.server_pub_key_buf, try kp.publicKey(h.named_group));
    return try kp.sharedKey(h.named_group, h.client_pub_key);
}
/// Reads the client's second handshake flight: optional certificate and
/// certificate_verify messages (only when client authentication is
/// configured) followed by the client finished message. Returns normally
/// once a valid finished message has been verified against the transcript.
fn readClientFlight2(h: *HandshakeT, opt: Options) !void {
    var cleartext_buf = h.buffer;
    // Head/tail delimit the decrypted-but-not-yet-consumed region; a
    // handshake message may be fragmented across records, so decrypted
    // data accumulates until a whole message is available.
    var cleartext_buf_head: usize = 0;
    var cleartext_buf_tail: usize = 0;
    // Without client auth the only expected message is finished.
    var handshake_state: proto.Handshake = .finished;
    var cert: CertificateParser = undefined;
    if (opt.client_auth) |client_auth| {
        cert = .{ .root_ca = client_auth.root_ca.bundle, .host = "" };
        handshake_state = .certificate;
    }
    outer: while (true) {
        const rec = (try h.rec_rdr.next() orelse return error.EndOfStream);
        if (rec.protocol_version != .tls_1_2 and rec.content_type != .alert)
            return error.TlsProtocolVersion;
        switch (rec.content_type) {
            .change_cipher_spec => {
                // Ignored for compatibility; must carry exactly one byte.
                if (rec.payload.len != 1) return error.TlsUnexpectedMessage;
            },
            .application_data => {
                // Post-server-hello handshake messages arrive encrypted
                // inside application_data records; decrypt into our buffer.
                const content_type, const cleartext = try h.cipher.decrypt(
                    cleartext_buf[cleartext_buf_tail..],
                    rec,
                );
                cleartext_buf_tail += cleartext.len;
                if (cleartext_buf_tail > cleartext_buf.len) return error.TlsRecordOverflow;
                var d = record.Decoder.init(content_type, cleartext_buf[cleartext_buf_head..cleartext_buf_tail]);
                try d.expectContentType(.handshake);
                while (!d.eof()) {
                    const start_idx = d.idx;
                    const handshake_type = try d.decode(proto.Handshake);
                    const length = try d.decode(u24);
                    if (length > cipher.max_cleartext_len)
                        return error.TlsRecordOverflow;
                    if (length > d.rest().len)
                        continue :outer; // fragmented handshake into multiple records
                    // After the message is processed below, feed it to the
                    // transcript and mark its bytes as consumed.
                    defer {
                        const handshake_payload = d.payload[start_idx..d.idx];
                        h.transcript.update(handshake_payload);
                        cleartext_buf_head += handshake_payload.len;
                    }
                    if (handshake_state != handshake_type)
                        return error.TlsUnexpectedMessage;
                    switch (handshake_type) {
                        .certificate => {
                            if (length == 4) {
                                // got empty certificate message
                                if (opt.client_auth.?.auth_type == .require)
                                    return error.TlsCertificateRequired;
                                try d.skip(length);
                                handshake_state = .finished;
                            } else {
                                try cert.parseCertificate(&d, .tls_1_3);
                                handshake_state = .certificate_verify;
                            }
                        },
                        .certificate_verify => {
                            try cert.parseCertificateVerify(&d);
                            cert.verifySignature(h.transcript.clientCertificateVerify()) catch |err| return switch (err) {
                                error.TlsUnknownSignatureScheme => error.TlsIllegalParameter,
                                else => error.TlsDecryptError,
                            };
                            handshake_state = .finished;
                        },
                        .finished => {
                            const actual = try d.slice(length);
                            var buf: [Transcript.max_mac_length]u8 = undefined;
                            const expected = h.transcript.clientFinishedTls13(&buf);
                            // Mismatched length means a malformed message;
                            // same length but different bytes means bad MAC.
                            if (!mem.eql(u8, expected, actual))
                                return if (expected.len == actual.len)
                                    error.TlsDecryptError
                                else
                                    error.TlsDecodeError;
                            return;
                        },
                        else => return error.TlsUnexpectedMessage,
                    }
                }
                cleartext_buf_head = 0;
                cleartext_buf_tail = 0;
            },
            .alert => {
                var d = rec.decoder();
                return d.raiseAlert();
            },
            else => return error.TlsUnexpectedMessage,
        }
    }
}
/// Builds the server finished handshake message (verify data over the
/// current transcript) into `buf` and returns the written bytes.
fn makeFinished(h: *HandshakeT, buf: []u8) ![]const u8 {
    var w = record.Writer{ .buf = buf };
    // The verify data is written directly into the space reserved for the
    // handshake payload; the header is prepended once its length is known.
    const verify_data = h.transcript.serverFinishedTls13(w.getHandshakePayload());
    try w.advanceHandshake(.finished, verify_data.len);
    return w.getWritten();
}
/// Write encrypted handshake message into `w`.
/// Encrypts `cleartext` with the current handshake cipher into the writer's
/// free space and advances the writer past the produced ciphertext.
fn writeEncrypted(h: *HandshakeT, w: *record.Writer, cleartext: []const u8) !void {
    const ciphertext = try h.cipher.encrypt(w.getFree(), .handshake, cleartext);
    w.pos += ciphertext.len;
}
/// Builds the server hello handshake message into `buf` and returns the
/// written bytes, record and handshake headers included.
fn makeServerHello(h: *HandshakeT, buf: []u8) ![]const u8 {
    const header_len = 9; // tls record header (5 bytes) and handshake header (4 bytes)
    var w = record.Writer{ .buf = buf[header_len..] };
    try w.writeEnum(proto.Version.tls_1_2);
    try w.write(&h.server_random);
    { // legacy session id, echoed back from the client hello
        try w.writeInt(@as(u8, @intCast(h.legacy_session_id.len)));
        if (h.legacy_session_id.len > 0) try w.write(h.legacy_session_id);
    }
    try w.writeEnum(h.cipher_suite);
    try w.write(&[_]u8{0}); // compression method
    // Extensions go through a second writer placed after the fixed part
    // plus 2 bytes reserved for the extensions length written below.
    var e = record.Writer{ .buf = buf[header_len + w.pos + 2 ..] };
    { // supported versions extension
        try e.writeEnum(proto.Extension.supported_versions);
        try e.writeInt(@as(u16, 2));
        try e.writeEnum(proto.Version.tls_1_3);
    }
    { // key share extension
        const key_len: u16 = @intCast(h.server_pub_key.len);
        try e.writeEnum(proto.Extension.key_share);
        try e.writeInt(key_len + 4);
        try e.writeEnum(h.named_group);
        try e.writeInt(key_len);
        try e.write(h.server_pub_key);
    }
    try w.writeInt(@as(u16, @intCast(e.pos))); // extensions length
    const payload_len = w.pos + e.pos;
    // Headers are written last, once the total payload length is known.
    buf[0..header_len].* = record.header(.handshake, 4 + payload_len) ++
        record.handshakeHeader(.server_hello, payload_len);
    return buf[0 .. header_len + payload_len];
}
/// Builds a tls 1.3 certificate_request handshake message into `buf`:
/// an empty request context plus the signature_algorithms extension.
fn makeCertificateRequest(buf: []u8) ![]const u8 {
    // handshake header + context length + extensions length
    const header_len = 4 + 1 + 2;
    // First write extensions, leave space for header.
    var ext = record.Writer{ .buf = buf[header_len..] };
    try ext.writeExtension(.signature_algorithms, common.supported_signature_algorithms);
    var w = record.Writer{ .buf = buf };
    try w.writeHandshakeHeader(.certificate_request, ext.pos + 3);
    try w.writeInt(@as(u8, 0)); // certificate request context length = 0
    try w.writeInt(@as(u16, @intCast(ext.pos))); // extensions length
    assert(w.pos == header_len);
    w.pos += ext.pos;
    return w.getWritten();
}
/// Parses the client hello message: records the client random and legacy
/// session id, selects the first client-offered tls 1.3 cipher suite we
/// support and our highest-priority named group that has a client key
/// share, and validates the extensions this server cares about.
fn readClientHello(h: *HandshakeT) !void {
    var d = try h.rec_rdr.nextDecoder();
    try d.expectContentType(.handshake);
    h.transcript.update(d.payload);
    const handshake_type = try d.decode(proto.Handshake);
    if (handshake_type != .client_hello) return error.TlsUnexpectedMessage;
    _ = try d.decode(u24); // handshake length
    if (try d.decode(proto.Version) != .tls_1_2) return error.TlsProtocolVersion;
    h.client_random = try d.array(32);
    { // legacy session id
        const len = try d.decode(u8);
        h.legacy_session_id = dupe(&h.legacy_session_id_buf, try d.slice(len));
    }
    { // cipher suites
        // Pick the first offered suite that is in our tls 1.3 list;
        // a zero cipher_suite value means "not yet selected".
        const end_idx = try d.decode(u16) + d.idx;
        while (d.idx < end_idx) {
            const cipher_suite = try d.decode(CipherSuite);
            if (cipher_suites.includes(cipher_suites.tls13, cipher_suite) and
                @intFromEnum(h.cipher_suite) == 0)
            {
                h.cipher_suite = cipher_suite;
            }
        }
        if (@intFromEnum(h.cipher_suite) == 0)
            return error.TlsHandshakeFailure;
    }
    try d.skip(2); // compression methods
    var key_share_received = false;
    // extensions
    const extensions_end_idx = try d.decode(u16) + d.idx;
    while (d.idx < extensions_end_idx) {
        const extension_type = try d.decode(proto.Extension);
        const extension_len = try d.decode(u16);
        switch (extension_type) {
            .supported_versions => {
                // The client must advertise tls 1.3 support.
                var tls_1_3_supported = false;
                const end_idx = try d.decode(u8) + d.idx;
                while (d.idx < end_idx) {
                    if (try d.decode(proto.Version) == proto.Version.tls_1_3) {
                        tls_1_3_supported = true;
                    }
                }
                if (!tls_1_3_supported) return error.TlsProtocolVersion;
            },
            .key_share => {
                if (extension_len == 0) return error.TlsDecodeError;
                key_share_received = true;
                // Choose the key share for the group that appears earliest
                // in supported_named_groups (our preference order).
                var selected_named_group_idx = supported_named_groups.len;
                const end_idx = try d.decode(u16) + d.idx;
                while (d.idx < end_idx) {
                    const named_group = try d.decode(proto.NamedGroup);
                    // Reject deprecated/obsolete group code points.
                    switch (@intFromEnum(named_group)) {
                        0x0001...0x0016,
                        0x001a...0x001c,
                        0xff01...0xff02,
                        => return error.TlsIllegalParameter,
                        else => {},
                    }
                    const client_pub_key = try d.slice(try d.decode(u16));
                    for (supported_named_groups, 0..) |supported, idx| {
                        if (named_group == supported and idx < selected_named_group_idx) {
                            h.named_group = named_group;
                            h.client_pub_key = dupe(&h.client_pub_key_buf, client_pub_key);
                            selected_named_group_idx = idx;
                        }
                    }
                }
                if (@intFromEnum(h.named_group) == 0)
                    return error.TlsIllegalParameter;
            },
            .supported_groups => {
                const end_idx = try d.decode(u16) + d.idx;
                while (d.idx < end_idx) {
                    const named_group = try d.decode(proto.NamedGroup);
                    switch (@intFromEnum(named_group)) {
                        0x0001...0x0016,
                        0x001a...0x001c,
                        0xff01...0xff02,
                        => return error.TlsIllegalParameter,
                        else => {},
                    }
                }
            },
            .signature_algorithms => {
                // When a signature scheme is configured (non-zero), require
                // the client to list it; otherwise skip the extension.
                if (@intFromEnum(h.signature_scheme) == 0) {
                    try d.skip(extension_len);
                } else {
                    var found = false;
                    const list_len = try d.decode(u16);
                    if (list_len == 0) return error.TlsDecodeError;
                    const end_idx = list_len + d.idx;
                    while (d.idx < end_idx) {
                        const signature_scheme = try d.decode(proto.SignatureScheme);
                        if (signature_scheme == h.signature_scheme) found = true;
                    }
                    if (!found) return error.TlsHandshakeFailure;
                }
            },
            else => {
                try d.skip(extension_len);
            },
        }
    }
    if (!key_share_received) return error.TlsMissingExtension;
    if (@intFromEnum(h.named_group) == 0) return error.TlsIllegalParameter;
}
};
}
const testing = std.testing;
const data13 = @import("testdata/tls13.zig");
const testu = @import("testu.zig");
/// Builds a record reader over an in-memory byte buffer for tests.
fn testReader(data: []const u8) record.Reader(std.io.FixedBufferStream([]const u8)) {
    return record.reader(std.io.fixedBufferStream(data));
}
// Handshake specialized to an in-memory stream for tests.
const TestHandshake = Handshake(std.io.FixedBufferStream([]const u8));
test "read client hello" {
    var buffer: [1024]u8 = undefined;
    var rec_rdr = testReader(&data13.client_hello);
    var h = TestHandshake.init(&buffer, &rec_rdr);
    h.signature_scheme = .ecdsa_secp521r1_sha512; // this must be supported in signature_algorithms extension
    try h.readClientHello();
    // Selections parsed out of the recorded client hello fixture.
    try testing.expectEqual(CipherSuite.AES_256_GCM_SHA384, h.cipher_suite);
    try testing.expectEqual(.x25519, h.named_group);
    try testing.expectEqualSlices(u8, &data13.client_random, &h.client_random);
    try testing.expectEqualSlices(u8, &data13.client_public_key, h.client_pub_key);
}
test "make server hello" {
    var buffer: [128]u8 = undefined;
    var h = TestHandshake.init(&buffer, undefined);
    h.cipher_suite = .AES_256_GCM_SHA384;
    // Deterministic random/key material so the output is reproducible.
    testu.fillFrom(&h.server_random, 0);
    testu.fillFrom(&h.server_pub_key_buf, 0x20);
    h.named_group = .x25519;
    h.server_pub_key = h.server_pub_key_buf[0..32];
    const actual = try h.makeServerHello(&buffer);
    // Expected bytes: headers, server random, session id len, cipher suite,
    // compression, extensions length, supported_versions, key_share.
    const expected = &testu.hexToBytes(
        \\ 16 03 03 00 5a 02 00 00 56
        \\ 03 03
        \\ 00 01 02 03 04 05 06 07 08 09 0a 0b 0c 0d 0e 0f 10 11 12 13 14 15 16 17 18 19 1a 1b 1c 1d 1e 1f
        \\ 00
        \\ 13 02 00
        \\ 00 2e 00 2b 00 02 03 04
        \\ 00 33 00 24 00 1d 00 20
        \\ 20 21 22 23 24 25 26 27 28 29 2a 2b 2c 2d 2e 2f 30 31 32 33 34 35 36 37 38 39 3a 3b 3c 3d 3e 3f
    );
    try testing.expectEqualSlices(u8, expected, actual);
}
test "make certificate request" {
    var buffer: [32]u8 = undefined;
    // Expected encoding of an empty-context certificate_request carrying
    // only the signature_algorithms extension.
    const expected = testu.hexToBytes("0d 00 00 1b" ++ // handshake header
        "00 00 18" ++ // extension length
        "00 0d" ++ // signature algorithms extension
        "00 14" ++ // extension length
        "00 12" ++ // list length 6 * 2 bytes
        "04 03 05 03 08 04 08 05 08 06 08 07 02 01 04 01 05 01" // signature schemes
    );
    const actual = try TestHandshake.makeCertificateRequest(&buffer);
    try testing.expectEqualSlices(u8, &expected, actual);
}

View File

@@ -1,60 +0,0 @@
//! Exporting tls key so we can share them with Wireshark and analyze decrypted
//! traffic in Wireshark.
//! To configure Wireshark to use exported keys see curl reference.
//!
//! References:
//! curl: https://everything.curl.dev/usingcurl/tls/sslkeylogfile.html
//! openssl: https://www.openssl.org/docs/manmaster/man3/SSL_CTX_set_keylog_callback.html
//! https://udn.realityripple.com/docs/Mozilla/Projects/NSS/Key_Log_Format
const std = @import("std");
const key_log_file_env = "SSLKEYLOGFILE";
/// Labels used by the NSS key log file format; each logged secret line
/// starts with one of these strings.
pub const label = struct {
    // tls 1.3
    pub const client_handshake_traffic_secret: []const u8 = "CLIENT_HANDSHAKE_TRAFFIC_SECRET";
    pub const server_handshake_traffic_secret: []const u8 = "SERVER_HANDSHAKE_TRAFFIC_SECRET";
    pub const client_traffic_secret_0: []const u8 = "CLIENT_TRAFFIC_SECRET_0";
    pub const server_traffic_secret_0: []const u8 = "SERVER_TRAFFIC_SECRET_0";
    // tls 1.2
    pub const client_random: []const u8 = "CLIENT_RANDOM";
};
pub const Callback = *const fn (label: []const u8, client_random: []const u8, secret: []const u8) void;
/// Writes tls keys to the file pointed to by the SSLKEYLOGFILE environment
/// variable. Silently does nothing when the variable is unset or the file
/// cannot be written.
pub fn callback(label_: []const u8, client_random: []const u8, secret: []const u8) void {
    const file_name = std.posix.getenv(key_log_file_env) orelse return;
    fileAppend(file_name, label_, client_random, secret) catch return;
}
/// Formats one key log line and appends it to the file at `file_name`.
pub fn fileAppend(file_name: []const u8, label_: []const u8, client_random: []const u8, secret: []const u8) !void {
    var line_buf: [1024]u8 = undefined;
    const line = try formatLine(&line_buf, label_, client_random, secret);
    try fileWrite(file_name, line);
}
/// Opens (or creates) the file at the absolute path `file_name` and
/// appends `line` at its end without truncating existing content.
fn fileWrite(file_name: []const u8, line: []const u8) !void {
    var file = try std.fs.createFileAbsolute(file_name, .{ .truncate = false });
    defer file.close();
    // Position at the current end so the line is appended.
    try file.seekFromEnd(0);
    try file.writeAll(line);
}
/// Formats a single key log line into `buf`:
/// `<label> <hex client_random> <hex secret>\n` (lowercase hex, two digits
/// per byte). Returns the slice of `buf` holding the formatted line.
pub fn formatLine(buf: []u8, label_: []const u8, client_random: []const u8, secret: []const u8) ![]const u8 {
    var stream = std.io.fixedBufferStream(buf);
    const writer = stream.writer();
    try writer.print("{s} ", .{label_});
    for (client_random) |byte| try writer.print("{x:0>2}", .{byte});
    try writer.writeByte(' ');
    for (secret) |byte| try writer.print("{x:0>2}", .{byte});
    try writer.writeByte('\n');
    return stream.getWritten();
}

View File

@@ -1,51 +0,0 @@
const std = @import("std");
pub const CipherSuite = @import("cipher.zig").CipherSuite;
pub const cipher_suites = @import("cipher.zig").cipher_suites;
pub const PrivateKey = @import("PrivateKey.zig");
pub const Connection = @import("connection.zig").Connection;
pub const ClientOptions = @import("handshake_client.zig").Options;
pub const ServerOptions = @import("handshake_server.zig").Options;
pub const key_log = @import("key_log.zig");
pub const proto = @import("protocol.zig");
pub const NamedGroup = proto.NamedGroup;
pub const Version = proto.Version;
const common = @import("handshake_common.zig");
pub const CertBundle = common.CertBundle;
pub const CertKeyPair = common.CertKeyPair;
pub const record = @import("record.zig");
const connection = @import("connection.zig").connection;
const max_ciphertext_record_len = @import("cipher.zig").max_ciphertext_record_len;
const HandshakeServer = @import("handshake_server.zig").Handshake;
const HandshakeClient = @import("handshake_client.zig").Handshake;
/// Runs a TLS client handshake over `stream` and returns an established
/// connection ready for encrypted reads and writes.
pub fn client(stream: anytype, opt: ClientOptions) !Connection(@TypeOf(stream)) {
    var conn = connection(stream);
    var handshake_buf: [max_ciphertext_record_len]u8 = undefined;
    var hs = HandshakeClient(@TypeOf(stream)).init(&handshake_buf, &conn.rec_rdr);
    conn.cipher = try hs.handshake(conn.stream, opt);
    return conn;
}
/// Runs a TLS server handshake over `stream` and returns an established
/// connection ready for encrypted reads and writes.
pub fn server(stream: anytype, opt: ServerOptions) !Connection(@TypeOf(stream)) {
    var conn = connection(stream);
    var handshake_buf: [max_ciphertext_record_len]u8 = undefined;
    var hs = HandshakeServer(@TypeOf(stream)).init(&handshake_buf, &conn.rec_rdr);
    conn.cipher = try hs.handshake(conn.stream, opt);
    return conn;
}
test {
    // Reference all submodules so their test blocks run with this one.
    _ = @import("handshake_common.zig");
    _ = @import("handshake_server.zig");
    _ = @import("handshake_client.zig");
    _ = @import("connection.zig");
    _ = @import("cipher.zig");
    _ = @import("record.zig");
    _ = @import("transcript.zig");
    _ = @import("PrivateKey.zig");
}

View File

@@ -1,302 +0,0 @@
/// TLS protocol versions as they appear on the wire (2 byte codes).
pub const Version = enum(u16) {
    tls_1_2 = 0x0303,
    tls_1_3 = 0x0304,
    _,
};
/// TLS record layer content types (first byte of every record header).
pub const ContentType = enum(u8) {
    invalid = 0,
    change_cipher_spec = 20,
    alert = 21,
    handshake = 22,
    application_data = 23,
    _,
};
/// Handshake message types (first byte of a handshake header); covers both
/// tls 1.2 and tls 1.3 message types.
pub const Handshake = enum(u8) {
    client_hello = 1,
    server_hello = 2,
    new_session_ticket = 4,
    end_of_early_data = 5,
    encrypted_extensions = 8,
    certificate = 11,
    server_key_exchange = 12,
    certificate_request = 13,
    server_hello_done = 14,
    certificate_verify = 15,
    client_key_exchange = 16,
    finished = 20,
    key_update = 24,
    message_hash = 254,
    _,
};
/// Elliptic curve type indicator; only the named_curve form (0x03) is
/// represented here.
pub const Curve = enum(u8) {
    named_curve = 0x03,
    _,
};
/// TLS extension types carried in hello and other handshake messages; the
/// RFC defining each extension is noted inline.
pub const Extension = enum(u16) {
    /// RFC 6066
    server_name = 0,
    /// RFC 6066
    max_fragment_length = 1,
    /// RFC 6066
    status_request = 5,
    /// RFC 8422, 7919
    supported_groups = 10,
    /// RFC 8446
    signature_algorithms = 13,
    /// RFC 5764
    use_srtp = 14,
    /// RFC 6520
    heartbeat = 15,
    /// RFC 7301
    application_layer_protocol_negotiation = 16,
    /// RFC 6962
    signed_certificate_timestamp = 18,
    /// RFC 7250
    client_certificate_type = 19,
    /// RFC 7250
    server_certificate_type = 20,
    /// RFC 7685
    padding = 21,
    /// RFC 8446
    pre_shared_key = 41,
    /// RFC 8446
    early_data = 42,
    /// RFC 8446
    supported_versions = 43,
    /// RFC 8446
    cookie = 44,
    /// RFC 8446
    psk_key_exchange_modes = 45,
    /// RFC 8446
    certificate_authorities = 47,
    /// RFC 8446
    oid_filters = 48,
    /// RFC 8446
    post_handshake_auth = 49,
    /// RFC 8446
    signature_algorithms_cert = 50,
    /// RFC 8446
    key_share = 51,
    _,
};
/// Encodes `err` as a two byte fatal alert payload: level then description.
pub fn alertFromError(err: anyerror) [2]u8 {
    const level = @intFromEnum(Alert.Level.fatal);
    const description = @intFromEnum(Alert.fromError(err));
    return .{ level, description };
}
/// TLS alert descriptions plus conversions between alert codes and Zig
/// errors for both sending and receiving alerts.
pub const Alert = enum(u8) {
    /// Alert severity byte preceding the description on the wire.
    pub const Level = enum(u8) {
        warning = 1,
        fatal = 2,
        _,
    };

    /// Errors produced when a received alert is converted via `toError`.
    pub const Error = error{
        TlsAlertUnexpectedMessage,
        TlsAlertBadRecordMac,
        TlsAlertRecordOverflow,
        TlsAlertHandshakeFailure,
        TlsAlertBadCertificate,
        TlsAlertUnsupportedCertificate,
        TlsAlertCertificateRevoked,
        TlsAlertCertificateExpired,
        TlsAlertCertificateUnknown,
        TlsAlertIllegalParameter,
        TlsAlertUnknownCa,
        TlsAlertAccessDenied,
        TlsAlertDecodeError,
        TlsAlertDecryptError,
        TlsAlertProtocolVersion,
        TlsAlertInsufficientSecurity,
        TlsAlertInternalError,
        TlsAlertInappropriateFallback,
        TlsAlertMissingExtension,
        TlsAlertUnsupportedExtension,
        TlsAlertUnrecognizedName,
        TlsAlertBadCertificateStatusResponse,
        TlsAlertUnknownPskIdentity,
        TlsAlertCertificateRequired,
        TlsAlertNoApplicationProtocol,
        TlsAlertUnknown,
    };

    close_notify = 0,
    unexpected_message = 10,
    bad_record_mac = 20,
    record_overflow = 22,
    handshake_failure = 40,
    bad_certificate = 42,
    unsupported_certificate = 43,
    certificate_revoked = 44,
    certificate_expired = 45,
    certificate_unknown = 46,
    illegal_parameter = 47,
    unknown_ca = 48,
    access_denied = 49,
    decode_error = 50,
    decrypt_error = 51,
    protocol_version = 70,
    insufficient_security = 71,
    internal_error = 80,
    inappropriate_fallback = 86,
    user_canceled = 90,
    missing_extension = 109,
    unsupported_extension = 110,
    unrecognized_name = 112,
    bad_certificate_status_response = 113,
    unknown_psk_identity = 115,
    certificate_required = 116,
    no_application_protocol = 120,
    _,

    /// Converts a received alert into an error; close_notify and
    /// user_canceled are not treated as errors.
    pub fn toError(alert: Alert) Error!void {
        return switch (alert) {
            .close_notify => {}, // not an error
            .unexpected_message => error.TlsAlertUnexpectedMessage,
            .bad_record_mac => error.TlsAlertBadRecordMac,
            .record_overflow => error.TlsAlertRecordOverflow,
            .handshake_failure => error.TlsAlertHandshakeFailure,
            .bad_certificate => error.TlsAlertBadCertificate,
            .unsupported_certificate => error.TlsAlertUnsupportedCertificate,
            .certificate_revoked => error.TlsAlertCertificateRevoked,
            .certificate_expired => error.TlsAlertCertificateExpired,
            .certificate_unknown => error.TlsAlertCertificateUnknown,
            .illegal_parameter => error.TlsAlertIllegalParameter,
            .unknown_ca => error.TlsAlertUnknownCa,
            .access_denied => error.TlsAlertAccessDenied,
            .decode_error => error.TlsAlertDecodeError,
            .decrypt_error => error.TlsAlertDecryptError,
            .protocol_version => error.TlsAlertProtocolVersion,
            .insufficient_security => error.TlsAlertInsufficientSecurity,
            .internal_error => error.TlsAlertInternalError,
            .inappropriate_fallback => error.TlsAlertInappropriateFallback,
            .user_canceled => {}, // not an error
            .missing_extension => error.TlsAlertMissingExtension,
            .unsupported_extension => error.TlsAlertUnsupportedExtension,
            .unrecognized_name => error.TlsAlertUnrecognizedName,
            .bad_certificate_status_response => error.TlsAlertBadCertificateStatusResponse,
            .unknown_psk_identity => error.TlsAlertUnknownPskIdentity,
            .certificate_required => error.TlsAlertCertificateRequired,
            .no_application_protocol => error.TlsAlertNoApplicationProtocol,
            _ => error.TlsAlertUnknown,
        };
    }

    /// Maps a local error to the alert description to send to the peer;
    /// unrecognized errors become internal_error.
    pub fn fromError(err: anyerror) Alert {
        return switch (err) {
            error.TlsUnexpectedMessage => .unexpected_message,
            error.TlsBadRecordMac => .bad_record_mac,
            error.TlsRecordOverflow => .record_overflow,
            error.TlsHandshakeFailure => .handshake_failure,
            error.TlsBadCertificate => .bad_certificate,
            error.TlsUnsupportedCertificate => .unsupported_certificate,
            error.TlsCertificateRevoked => .certificate_revoked,
            error.TlsCertificateExpired => .certificate_expired,
            error.TlsCertificateUnknown => .certificate_unknown,
            error.TlsIllegalParameter,
            error.IdentityElement,
            error.InvalidEncoding,
            => .illegal_parameter,
            error.TlsUnknownCa => .unknown_ca,
            error.TlsAccessDenied => .access_denied,
            error.TlsDecodeError => .decode_error,
            error.TlsDecryptError => .decrypt_error,
            error.TlsProtocolVersion => .protocol_version,
            error.TlsInsufficientSecurity => .insufficient_security,
            error.TlsInternalError => .internal_error,
            error.TlsInappropriateFallback => .inappropriate_fallback,
            error.TlsMissingExtension => .missing_extension,
            error.TlsUnsupportedExtension => .unsupported_extension,
            error.TlsUnrecognizedName => .unrecognized_name,
            error.TlsBadCertificateStatusResponse => .bad_certificate_status_response,
            error.TlsUnknownPskIdentity => .unknown_psk_identity,
            error.TlsCertificateRequired => .certificate_required,
            error.TlsNoApplicationProtocol => .no_application_protocol,
            else => .internal_error,
        };
    }

    /// Parses a two byte alert payload; the level byte is ignored.
    pub fn parse(buf: [2]u8) Alert {
        const level: Alert.Level = @enumFromInt(buf[0]);
        const alert: Alert = @enumFromInt(buf[1]);
        _ = level;
        return alert;
    }

    /// Returns the warning-level close_notify alert payload.
    pub fn closeNotify() [2]u8 {
        return [2]u8{
            @intFromEnum(Alert.Level.warning),
            @intFromEnum(Alert.close_notify),
        };
    }
};
/// Signature algorithms negotiated via the signature_algorithms extension
/// and used in certificate_verify messages.
pub const SignatureScheme = enum(u16) {
    // RSASSA-PKCS1-v1_5 algorithms
    rsa_pkcs1_sha256 = 0x0401,
    rsa_pkcs1_sha384 = 0x0501,
    rsa_pkcs1_sha512 = 0x0601,
    // ECDSA algorithms
    ecdsa_secp256r1_sha256 = 0x0403,
    ecdsa_secp384r1_sha384 = 0x0503,
    ecdsa_secp521r1_sha512 = 0x0603,
    // RSASSA-PSS algorithms with public key OID rsaEncryption
    rsa_pss_rsae_sha256 = 0x0804,
    rsa_pss_rsae_sha384 = 0x0805,
    rsa_pss_rsae_sha512 = 0x0806,
    // EdDSA algorithms
    ed25519 = 0x0807,
    ed448 = 0x0808,
    // RSASSA-PSS algorithms with public key OID RSASSA-PSS
    rsa_pss_pss_sha256 = 0x0809,
    rsa_pss_pss_sha384 = 0x080a,
    rsa_pss_pss_sha512 = 0x080b,
    // Legacy algorithms
    rsa_pkcs1_sha1 = 0x0201,
    ecdsa_sha1 = 0x0203,
    _,
};
/// Key exchange groups negotiated via the supported_groups and key_share
/// extensions.
pub const NamedGroup = enum(u16) {
    // Elliptic Curve Groups (ECDHE)
    secp256r1 = 0x0017,
    secp384r1 = 0x0018,
    secp521r1 = 0x0019,
    x25519 = 0x001D,
    x448 = 0x001E,
    // Finite Field Groups (DHE)
    ffdhe2048 = 0x0100,
    ffdhe3072 = 0x0101,
    ffdhe4096 = 0x0102,
    ffdhe6144 = 0x0103,
    ffdhe8192 = 0x0104,
    // Hybrid post-quantum key agreements
    x25519_kyber512d00 = 0xFE30,
    x25519_kyber768d00 = 0x6399,
    _,
};
/// Payload of the key_update handshake message: whether the peer is asked
/// to update its own keys in response.
pub const KeyUpdateRequest = enum(u8) {
    update_not_requested = 0,
    update_requested = 1,
    _,
};
/// Connection role: which side of the handshake we are.
pub const Side = enum {
    client,
    server,
};

View File

@@ -1,405 +0,0 @@
const std = @import("std");
const assert = std.debug.assert;
const mem = std.mem;
const proto = @import("protocol.zig");
const cipher = @import("cipher.zig");
const Cipher = cipher.Cipher;
const record = @import("record.zig");
// Record header size: 1 byte content type + 2 byte version + 2 byte length.
pub const header_len = 5;

/// Builds a plaintext record header. The version field always carries the
/// tls 1.2 legacy value.
pub fn header(content_type: proto.ContentType, payload_len: usize) [header_len]u8 {
    const int2 = std.crypto.tls.int2;
    return [1]u8{@intFromEnum(content_type)} ++
        int2(@intFromEnum(proto.Version.tls_1_2)) ++
        int2(@intCast(payload_len));
}
/// Builds a handshake message header: 1 byte type + 3 byte payload length.
pub fn handshakeHeader(handshake_type: proto.Handshake, payload_len: usize) [4]u8 {
    const int3 = std.crypto.tls.int3;
    return [1]u8{@intFromEnum(handshake_type)} ++ int3(@intCast(payload_len));
}
/// Convenience constructor for `Reader` with the inner reader type inferred.
pub fn reader(inner_reader: anytype) Reader(@TypeOf(inner_reader)) {
    return .{ .inner_reader = inner_reader };
}
/// Buffered reader that splits an inner byte stream into TLS records.
/// Holds at most one maximum-size ciphertext record in its internal buffer.
pub fn Reader(comptime InnerReader: type) type {
    return struct {
        inner_reader: InnerReader,

        // `start..end` delimits buffered bytes not yet returned as records.
        buffer: [cipher.max_ciphertext_record_len]u8 = undefined,
        start: usize = 0,
        end: usize = 0,

        const ReaderT = @This();

        /// Returns a decoder over the next record's payload. Accepts record
        /// versions 0x0300, 0x0301 and tls 1.2; anything else is rejected.
        pub fn nextDecoder(r: *ReaderT) !Decoder {
            const rec = (try r.next()) orelse return error.EndOfStream;
            if (@intFromEnum(rec.protocol_version) != 0x0300 and
                @intFromEnum(rec.protocol_version) != 0x0301 and
                rec.protocol_version != .tls_1_2)
                return error.TlsBadVersion;
            return .{
                .content_type = rec.content_type,
                .payload = rec.payload,
            };
        }

        /// Reads the content type byte out of a raw record header.
        pub fn contentType(buf: []const u8) proto.ContentType {
            return @enumFromInt(buf[0]);
        }

        /// Reads the protocol version out of a raw record header.
        pub fn protocolVersion(buf: []const u8) proto.Version {
            return @enumFromInt(mem.readInt(u16, buf[1..3], .big));
        }

        /// Returns the next complete record, or null on clean end of stream.
        /// The returned record points into the internal buffer and is only
        /// valid until the next call.
        pub fn next(r: *ReaderT) !?Record {
            while (true) {
                const buffer = r.buffer[r.start..r.end];
                // If we have 5 bytes header.
                if (buffer.len >= record.header_len) {
                    const record_header = buffer[0..record.header_len];
                    const payload_len = mem.readInt(u16, record_header[3..5], .big);
                    if (payload_len > cipher.max_ciphertext_len)
                        return error.TlsRecordOverflow;
                    const record_len = record.header_len + payload_len;
                    // If we have whole record
                    if (buffer.len >= record_len) {
                        r.start += record_len;
                        return Record.init(buffer[0..record_len]);
                    }
                }
                { // Move dirty part to the start of the buffer.
                    const n = r.end - r.start;
                    if (n > 0 and r.start > 0) {
                        // Regions overlap only when start <= n.
                        if (r.start > n) {
                            @memcpy(r.buffer[0..n], r.buffer[r.start..][0..n]);
                        } else {
                            mem.copyForwards(u8, r.buffer[0..n], r.buffer[r.start..][0..n]);
                        }
                    }
                    r.start = 0;
                    r.end = n;
                }
                { // Read more from inner_reader.
                    const n = try r.inner_reader.read(r.buffer[r.end..]);
                    if (n == 0) return null;
                    r.end += n;
                }
            }
        }

        /// Reads the next record and decrypts it with `cph`; returns the
        /// inner content type and the cleartext, or null on end of stream.
        pub fn nextDecrypt(r: *ReaderT, cph: *Cipher) !?struct { proto.ContentType, []const u8 } {
            const rec = (try r.next()) orelse return null;
            if (rec.protocol_version != .tls_1_2) return error.TlsBadVersion;
            return try cph.decrypt(
                // Reuse reader buffer for cleartext. `rec.header` and
                // `rec.payload`(ciphertext) are also pointing somewhere in
                // this buffer. Decrypter is first reading then writing a
                // block, cleartext has less length then ciphertext,
                // cleartext starts from the beginning of the buffer, so
                // ciphertext is always ahead of cleartext.
                r.buffer[0..r.start],
                rec,
            );
        }

        /// True when buffered bytes remain that were not yet returned.
        pub fn hasMore(r: *ReaderT) bool {
            return r.end > r.start;
        }
    };
}
/// One raw TLS record: parsed header fields plus slices over the header
/// and payload bytes. Slices point into the Reader's internal buffer and
/// are only valid until the next record is read.
pub const Record = struct {
    content_type: proto.ContentType,
    protocol_version: proto.Version = .tls_1_2,
    header: []const u8,
    payload: []const u8,

    /// Splits `buffer` (exactly one whole record) into fields.
    pub fn init(buffer: []const u8) Record {
        return .{
            .content_type = @enumFromInt(buffer[0]),
            .protocol_version = @enumFromInt(mem.readInt(u16, buffer[1..3], .big)),
            .header = buffer[0..record.header_len],
            .payload = buffer[record.header_len..],
        };
    }

    /// Returns a decoder over this record's payload.
    pub fn decoder(r: @This()) Decoder {
        return Decoder.init(r.content_type, @constCast(r.payload));
    }
};
/// Sequential decoder over a record payload; `idx` tracks the read
/// position and all reads are bounds checked against the payload length.
pub const Decoder = struct {
    content_type: proto.ContentType,
    payload: []const u8,
    idx: usize = 0,

    pub fn init(content_type: proto.ContentType, payload: []u8) Decoder {
        return .{
            .content_type = content_type,
            .payload = payload,
        };
    }

    /// Reads a big endian u8/u16/u24 integer or a non-exhaustive enum with
    /// such a tag type.
    pub fn decode(d: *Decoder, comptime T: type) !T {
        switch (@typeInfo(T)) {
            .Int => |info| switch (info.bits) {
                8 => {
                    try skip(d, 1);
                    return d.payload[d.idx - 1];
                },
                16 => {
                    try skip(d, 2);
                    const b0: u16 = d.payload[d.idx - 2];
                    const b1: u16 = d.payload[d.idx - 1];
                    return (b0 << 8) | b1;
                },
                24 => {
                    try skip(d, 3);
                    const b0: u24 = d.payload[d.idx - 3];
                    const b1: u24 = d.payload[d.idx - 2];
                    const b2: u24 = d.payload[d.idx - 1];
                    return (b0 << 16) | (b1 << 8) | b2;
                },
                else => @compileError("unsupported int type: " ++ @typeName(T)),
            },
            .Enum => |info| {
                const int = try d.decode(info.tag_type);
                // Exhaustive enums would trap on unknown wire values.
                if (info.is_exhaustive) @compileError("exhaustive enum cannot be used");
                return @as(T, @enumFromInt(int));
            },
            else => @compileError("unsupported type: " ++ @typeName(T)),
        }
    }

    /// Reads `len` bytes by value (copied fixed-size array).
    pub fn array(d: *Decoder, comptime len: usize) ![len]u8 {
        try d.skip(len);
        return d.payload[d.idx - len ..][0..len].*;
    }

    /// Reads `len` bytes as a slice into the payload (no copy).
    pub fn slice(d: *Decoder, len: usize) ![]const u8 {
        try d.skip(len);
        return d.payload[d.idx - len ..][0..len];
    }

    /// Advances the read position; errors when that would pass the end.
    pub fn skip(d: *Decoder, amt: usize) !void {
        if (d.idx + amt > d.payload.len) return error.TlsDecodeError;
        d.idx += amt;
    }

    /// Unread remainder of the payload.
    pub fn rest(d: Decoder) []const u8 {
        return d.payload[d.idx..];
    }

    /// True when the whole payload has been consumed.
    pub fn eof(d: Decoder) bool {
        return d.idx == d.payload.len;
    }

    /// Errors unless the record has the expected content type; a received
    /// alert is converted into its corresponding error.
    pub fn expectContentType(d: *Decoder, content_type: proto.ContentType) !void {
        if (d.content_type == content_type) return;
        switch (d.content_type) {
            .alert => try d.raiseAlert(),
            else => return error.TlsUnexpectedMessage,
        }
    }

    /// Parses an alert payload and returns it as an error; close_notify
    /// maps to error.TlsAlertCloseNotify.
    pub fn raiseAlert(d: *Decoder) !void {
        if (d.payload.len < 2) return error.TlsUnexpectedMessage;
        try proto.Alert.parse(try d.array(2)).toError();
        return error.TlsAlertCloseNotify;
    }
};
const testing = std.testing;
const data12 = @import("testdata/tls12.zig");
const testu = @import("testu.zig");
const CipherSuite = @import("cipher.zig").CipherSuite;
test Reader {
    var fbs = std.io.fixedBufferStream(&data12.server_responses);
    var rdr = reader(fbs.reader());
    // One entry per record in the recorded tls 1.2 server responses.
    const expected = [_]struct {
        content_type: proto.ContentType,
        payload_len: usize,
    }{
        .{ .content_type = .handshake, .payload_len = 49 },
        .{ .content_type = .handshake, .payload_len = 815 },
        .{ .content_type = .handshake, .payload_len = 300 },
        .{ .content_type = .handshake, .payload_len = 4 },
        .{ .content_type = .change_cipher_spec, .payload_len = 1 },
        .{ .content_type = .handshake, .payload_len = 64 },
    };
    for (expected) |e| {
        const rec = (try rdr.next()).?;
        try testing.expectEqual(e.content_type, rec.content_type);
        try testing.expectEqual(e.payload_len, rec.payload.len);
        try testing.expectEqual(.tls_1_2, rec.protocol_version);
    }
}
test Decoder {
    var fbs = std.io.fixedBufferStream(&data12.server_responses);
    var rdr = reader(fbs.reader());
    // Walk the fields of the recorded server hello message.
    var d = (try rdr.nextDecoder());
    try testing.expectEqual(.handshake, d.content_type);
    try testing.expectEqual(.server_hello, try d.decode(proto.Handshake));
    try testing.expectEqual(45, try d.decode(u24)); // length
    try testing.expectEqual(.tls_1_2, try d.decode(proto.Version));
    try testing.expectEqualStrings(
        &testu.hexToBytes("707172737475767778797a7b7c7d7e7f808182838485868788898a8b8c8d8e8f"),
        try d.slice(32),
    ); // server random
    try testing.expectEqual(0, try d.decode(u8)); // session id len
    try testing.expectEqual(.ECDHE_RSA_WITH_AES_128_CBC_SHA, try d.decode(CipherSuite));
    try testing.expectEqual(0, try d.decode(u8)); // compression method
    try testing.expectEqual(5, try d.decode(u16)); // extension length
    try testing.expectEqual(5, d.rest().len);
    try d.skip(5);
    try testing.expect(d.eof());
}
pub const Writer = struct {
buf: []u8,
pos: usize = 0,
pub fn write(self: *Writer, data: []const u8) !void {
defer self.pos += data.len;
if (self.pos + data.len > self.buf.len) return error.BufferOverflow;
@memcpy(self.buf[self.pos..][0..data.len], data);
}
pub fn writeByte(self: *Writer, b: u8) !void {
defer self.pos += 1;
if (self.pos == self.buf.len) return error.BufferOverflow;
self.buf[self.pos] = b;
}
pub fn writeEnum(self: *Writer, value: anytype) !void {
try self.writeInt(@intFromEnum(value));
}
pub fn writeInt(self: *Writer, value: anytype) !void {
const IntT = @TypeOf(value);
const bytes = @divExact(@typeInfo(IntT).Int.bits, 8);
const free = self.buf[self.pos..];
if (free.len < bytes) return error.BufferOverflow;
mem.writeInt(IntT, free[0..bytes], value, .big);
self.pos += bytes;
}
pub fn writeHandshakeHeader(self: *Writer, handshake_type: proto.Handshake, payload_len: usize) !void {
try self.write(&record.handshakeHeader(handshake_type, payload_len));
}
/// Should be used after writing handshake payload in buffer provided by `getHandshakePayload`.
pub fn advanceHandshake(self: *Writer, handshake_type: proto.Handshake, payload_len: usize) !void {
try self.write(&record.handshakeHeader(handshake_type, payload_len));
self.pos += payload_len;
}
/// Record payload is already written by using buffer space from `getPayload`.
/// Now when we know payload len we can write record header and advance over payload.
pub fn advanceRecord(self: *Writer, content_type: proto.ContentType, payload_len: usize) !void {
try self.write(&record.header(content_type, payload_len));
self.pos += payload_len;
}
pub fn writeRecord(self: *Writer, content_type: proto.ContentType, payload: []const u8) !void {
try self.write(&record.header(content_type, payload.len));
try self.write(payload);
}
/// Preserves space for record header and returns buffer free space.
pub fn getPayload(self: *Writer) []u8 {
return self.buf[self.pos + record.header_len ..];
}
/// Preserves space for handshake header and returns buffer free space.
pub fn getHandshakePayload(self: *Writer) []u8 {
return self.buf[self.pos + 4 ..];
}
pub fn getWritten(self: *Writer) []const u8 {
return self.buf[0..self.pos];
}
pub fn getFree(self: *Writer) []u8 {
return self.buf[self.pos..];
}
pub fn writeEnumArray(self: *Writer, comptime E: type, tags: []const E) !void {
assert(@sizeOf(E) == 2);
try self.writeInt(@as(u16, @intCast(tags.len * 2)));
for (tags) |t| {
try self.writeEnum(t);
}
}
pub fn writeExtension(
self: *Writer,
comptime et: proto.Extension,
tags: anytype,
) !void {
try self.writeEnum(et);
if (et == .supported_versions) {
try self.writeInt(@as(u16, @intCast(tags.len * 2 + 1)));
try self.writeInt(@as(u8, @intCast(tags.len * 2)));
} else {
try self.writeInt(@as(u16, @intCast(tags.len * 2 + 2)));
try self.writeInt(@as(u16, @intCast(tags.len * 2)));
}
for (tags) |t| {
try self.writeEnum(t);
}
}
/// Writes the key_share extension: one entry per named group, each entry
/// being the group (2 bytes), the key length (2 bytes) and the key bytes.
pub fn writeKeyShare(
    self: *Writer,
    named_groups: []const proto.NamedGroup,
    keys: []const []const u8,
) !void {
    assert(named_groups.len == keys.len);
    try self.writeEnum(proto.Extension.key_share);
    // Each entry costs 4 header bytes (group + key length) plus the key.
    var payload_len: usize = 0;
    for (keys) |key| payload_len += 4 + key.len;
    // Extension byte count, then client_shares list byte count.
    try self.writeInt(@as(u16, @intCast(payload_len + 2)));
    try self.writeInt(@as(u16, @intCast(payload_len)));
    for (named_groups, keys) |group, key| {
        try self.writeEnum(group);
        try self.writeInt(@as(u16, @intCast(key.len)));
        try self.write(key);
    }
}
/// Writes the server_name (SNI) extension containing a single host_name entry.
pub fn writeServerName(self: *Writer, host: []const u8) !void {
    const host_len: u16 = @intCast(host.len);
    try self.writeEnum(proto.Extension.server_name);
    try self.writeInt(host_len + 5); // byte length of extension payload
    try self.writeInt(host_len + 3); // server_name_list byte count
    try self.writeByte(0); // name type
    try self.writeInt(host_len);
    try self.write(host);
}
};
test "Writer" {
    var buf: [16]u8 = undefined;
    var w = Writer{ .buf = &buf };
    try w.write("ab");
    try w.writeEnum(proto.Curve.named_curve);
    try w.writeEnum(proto.NamedGroup.x25519);
    try w.writeInt(@as(u16, 0x1234));
    // "ab", 1-byte curve tag, 2-byte named group, big-endian u16.
    try testing.expectEqualSlices(u8, &[_]u8{ 'a', 'b', 0x03, 0x00, 0x1d, 0x12, 0x34 }, w.getWritten());
}

View File

@@ -1,467 +0,0 @@
//! An encoding of ASN.1.
//!
//! Distinguised Encoding Rules as defined in X.690 and X.691.
//!
//! A version of Basic Encoding Rules (BER) where there is exactly ONE way to
//! represent non-constructed elements. This is useful for cryptographic signatures.
//!
//! Currently an implementation detail of the standard library not fit for public
//! use since it's missing an encoder.
const std = @import("std");
const builtin = @import("builtin");
pub const Index = usize;
const log = std.log.scoped(.der);
/// A secure DER parser that:
/// - Does NOT read memory outside `bytes`.
/// - Does NOT return elements with slices outside `bytes`.
/// - Errors on values that do NOT follow DER rules.
/// - Lengths that could be represented in a shorter form.
/// - Booleans that are not 0xff or 0x00.
pub const Parser = struct {
    bytes: []const u8,
    index: Index = 0,

    pub const Error = Element.Error || error{
        UnexpectedElement,
        InvalidIntegerEncoding,
        Overflow,
        NonCanonical,
        // The members below were returned by methods of this parser but were
        // missing from the declared set, which made the explicit `Error!T`
        // return types uncompilable.
        InvalidBool,
        InvalidBitString,
        InvalidDateTime,
        InvalidTime,
        InvalidYear,
        UnknownObjectId,
        NoSpaceLeft,
    };

    /// Reads a DER BOOLEAN; only 0x00 and 0xff payloads are canonical.
    pub fn expectBool(self: *Parser) Error!bool {
        const ele = try self.expect(.universal, false, .boolean);
        if (ele.slice.len() != 1) return error.InvalidBool;
        return switch (self.view(ele)[0]) {
            0x00 => false,
            0xff => true,
            else => error.InvalidBool,
        };
    }

    /// Reads a DER BIT STRING. The first payload byte gives the number of
    /// unused (padding) bits in the final byte and must be < 8.
    /// NOTE(review): an empty payload would index out of bounds - confirm
    /// callers only feed non-empty BIT STRINGs.
    pub fn expectBitstring(self: *Parser) Error!BitString {
        const ele = try self.expect(.universal, false, .bitstring);
        const bytes = self.view(ele);
        const right_padding = bytes[0];
        if (right_padding >= 8) return error.InvalidBitString;
        return .{
            .bytes = bytes[1..],
            .right_padding = @intCast(right_padding),
        };
    }

    // TODO: return high resolution date time type instead of epoch seconds
    /// Reads a UTCTime ("YYMMDDhhmmssZ") or GeneralizedTime
    /// ("YYYYMMDDhhmmss...Z") and returns epoch seconds.
    pub fn expectDateTime(self: *Parser) Error!i64 {
        const ele = try self.expect(.universal, false, null);
        const bytes = self.view(ele);
        switch (ele.identifier.tag) {
            .utc_time => {
                // Example: "YYMMDD000000Z"
                if (bytes.len != 13)
                    return error.InvalidDateTime;
                if (bytes[12] != 'Z')
                    return error.InvalidDateTime;
                var date: Date = undefined;
                date.year = try parseTimeDigits(bytes[0..2], 0, 99);
                // Two-digit years: 50-99 map to 19xx, 00-49 to 20xx (RFC 5280).
                date.year += if (date.year >= 50) 1900 else 2000;
                date.month = try parseTimeDigits(bytes[2..4], 1, 12);
                date.day = try parseTimeDigits(bytes[4..6], 1, 31);
                const time = try parseTime(bytes[6..12]);
                return date.toEpochSeconds() + time.toSec();
            },
            .generalized_time => {
                // Examples:
                // "19920622123421Z"
                // "19920722132100.3Z"
                if (bytes.len < 15)
                    return error.InvalidDateTime;
                var date: Date = undefined;
                date.year = try parseYear4(bytes[0..4]);
                date.month = try parseTimeDigits(bytes[4..6], 1, 12);
                date.day = try parseTimeDigits(bytes[6..8], 1, 31);
                const time = try parseTime(bytes[8..14]);
                return date.toEpochSeconds() + time.toSec();
            },
            else => return error.InvalidDateTime,
        }
    }

    /// Reads an OBJECT IDENTIFIER and returns its raw encoded bytes.
    pub fn expectOid(self: *Parser) Error![]const u8 {
        const oid = try self.expect(.universal, false, .object_identifier);
        return self.view(oid);
    }

    /// Reads an OID and maps it through `Enum.oids`; logs unknown OIDs in
    /// Debug builds.
    pub fn expectEnum(self: *Parser, comptime Enum: type) Error!Enum {
        const oid = try self.expectOid();
        return Enum.oids.get(oid) orelse {
            if (builtin.mode == .Debug) {
                var buf: [256]u8 = undefined;
                var stream = std.io.fixedBufferStream(&buf);
                try @import("./oid.zig").decode(oid, stream.writer());
                log.warn("unknown oid {s} for enum {s}\n", .{ stream.getWritten(), @typeName(Enum) });
            }
            return error.UnknownObjectId;
        };
    }

    /// Reads a DER INTEGER into `T`.
    /// DER integers are big-endian: the first content byte is the most
    /// significant. Accumulation uses a type at least one byte wider than
    /// needed so the 8-bit shift is always representable and overflow is
    /// detected for every width of `T` (the previous version accumulated in
    /// the wrong byte order and trapped on any integer longer than one byte).
    /// NOTE(review): negative INTEGERs are not sign-extended for wider `T`;
    /// behavior preserved from the original `@bitCast` semantics.
    pub fn expectInt(self: *Parser, comptime T: type) Error!T {
        const ele = try self.expectPrimitive(.integer);
        const bytes = self.view(ele);
        const info = @typeInfo(T);
        if (info != .Int) @compileError(@typeName(T) ++ " is not an int type");
        const U = std.meta.Int(.unsigned, info.Int.bits);
        const W = std.meta.Int(.unsigned, @max(info.Int.bits, 9));
        var acc: W = 0;
        for (bytes) |b| {
            const shifted = @shlWithOverflow(acc, 8);
            if (shifted[1] == 1) return error.Overflow;
            acc = shifted[0] | b;
        }
        if (acc > std.math.maxInt(U)) return error.Overflow;
        return @bitCast(@as(U, @intCast(acc)));
    }

    /// Reads one of the allowed ASN.1 string types.
    pub fn expectString(self: *Parser, allowed: std.EnumSet(String.Tag)) Error!String {
        const ele = try self.expect(.universal, false, null);
        switch (ele.identifier.tag) {
            inline .string_utf8,
            .string_numeric,
            .string_printable,
            .string_teletex,
            .string_videotex,
            .string_ia5,
            .string_visible,
            .string_universal,
            .string_bmp,
            => |t| {
                // Strip the "string_" prefix to map onto String.Tag.
                const tagname = @tagName(t)["string_".len..];
                const tag = std.meta.stringToEnum(String.Tag, tagname) orelse unreachable;
                if (allowed.contains(tag)) {
                    return String{ .tag = tag, .data = self.view(ele) };
                }
            },
            else => {},
        }
        return error.UnexpectedElement;
    }

    /// Reads a primitive element. For INTEGERs, strips the leading zero byte
    /// (positive-sign padding) and rejects non-minimal encodings.
    pub fn expectPrimitive(self: *Parser, tag: ?Identifier.Tag) Error!Element {
        var elem = try self.expect(.universal, false, tag);
        if (tag == .integer and elem.slice.len() > 0) {
            if (self.view(elem)[0] == 0) elem.slice.start += 1;
            if (elem.slice.len() > 0 and self.view(elem)[0] == 0) return error.InvalidIntegerEncoding;
        }
        return elem;
    }

    /// Remember to call `expectEnd`
    pub fn expectSequence(self: *Parser) Error!Element {
        return try self.expect(.universal, true, .sequence);
    }

    /// Remember to call `expectEnd`
    pub fn expectSequenceOf(self: *Parser) Error!Element {
        return try self.expect(.universal, true, .sequence_of);
    }

    /// Asserts the cursor consumed exactly up to `val`.
    pub fn expectEnd(self: *Parser, val: usize) Error!void {
        if (self.index != val) return error.NonCanonical; // either forgot to parse end OR an attacker
    }

    /// Reads the next element, checking class/constructed/tag when provided.
    /// For constructed elements the cursor descends into the content; for
    /// primitives it skips past it.
    pub fn expect(
        self: *Parser,
        class: ?Identifier.Class,
        constructed: ?bool,
        tag: ?Identifier.Tag,
    ) Error!Element {
        if (self.index >= self.bytes.len) return error.EndOfStream;
        const res = try Element.init(self.bytes, self.index);
        if (tag) |e| {
            if (res.identifier.tag != e) return error.UnexpectedElement;
        }
        if (constructed) |e| {
            if (res.identifier.constructed != e) return error.UnexpectedElement;
        }
        if (class) |e| {
            if (res.identifier.class != e) return error.UnexpectedElement;
        }
        self.index = if (res.identifier.constructed) res.slice.start else res.slice.end;
        return res;
    }

    /// Returns the content bytes of `elem`.
    pub fn view(self: Parser, elem: Element) []const u8 {
        return elem.slice.view(self.bytes);
    }

    /// Moves the cursor to an absolute offset.
    pub fn seek(self: *Parser, index: usize) void {
        self.index = index;
    }

    /// True when the whole input has been consumed.
    pub fn eof(self: *Parser) bool {
        return self.index == self.bytes.len;
    }
};
/// A decoded DER element: its identifier octet plus the byte range of its
/// content within the input buffer.
pub const Element = struct {
    identifier: Identifier,
    slice: Slice,

    /// Half-open byte range [start, end) into the parsed buffer.
    pub const Slice = struct {
        start: Index,
        end: Index,
        pub fn len(self: Slice) Index {
            return self.end - self.start;
        }
        pub fn view(self: Slice, bytes: []const u8) []const u8 {
            return bytes[self.start..self.end];
        }
    };

    pub const Error = error{ InvalidLength, EndOfStream };

    /// Decodes the identifier and length octets at `index`, enforcing DER
    /// canonical lengths: short form for values < 128, minimal long form
    /// otherwise (a long-form length below 128 is rejected).
    pub fn init(bytes: []const u8, index: Index) Error!Element {
        var stream = std.io.fixedBufferStream(bytes[index..]);
        var reader = stream.reader();
        const identifier = @as(Identifier, @bitCast(try reader.readByte()));
        const size_or_len_size = try reader.readByte();
        var start = index + 2;
        // short form between 0-127
        if (size_or_len_size < 128) {
            const end = start + size_or_len_size;
            if (end > bytes.len) return error.InvalidLength;
            return .{ .identifier = identifier, .slice = .{ .start = start, .end = end } };
        }
        // long form between 0 and std.math.maxInt(u1024)
        const len_size: u7 = @truncate(size_or_len_size);
        start += len_size;
        if (len_size > @sizeOf(Index)) return error.InvalidLength;
        const len = try reader.readVarInt(Index, .big, len_size);
        if (len < 128) return error.InvalidLength; // should have used short form
        // Guard against wrap-around on adversarial lengths.
        const end = std.math.add(Index, start, len) catch return error.InvalidLength;
        if (end > bytes.len) return error.InvalidLength;
        return .{ .identifier = identifier, .slice = .{ .start = start, .end = end } };
    }
};
test Element {
    // Short form: 0x30 SEQUENCE, length 3, content at [2, 5).
    const short_form = [_]u8{ 0x30, 0x03, 0x02, 0x01, 0x09 };
    try std.testing.expectEqual(Element{
        .identifier = Identifier{ .tag = .sequence, .constructed = true, .class = .universal },
        .slice = .{ .start = 2, .end = short_form.len },
    }, Element.init(&short_form, 0));
    // Long form: 0x81 marks one length byte, value 129.
    const long_form = [_]u8{ 0x30, 129, 129 } ++ [_]u8{0} ** 129;
    try std.testing.expectEqual(Element{
        .identifier = Identifier{ .tag = .sequence, .constructed = true, .class = .universal },
        .slice = .{ .start = 3, .end = long_form.len },
    }, Element.init(&long_form, 0));
}
test "parser.expectInt" {
    // INTEGER tag (2), length 1, value 1.
    const one = [_]u8{ 2, 1, 1 };
    var parser = Parser{ .bytes = &one };
    try std.testing.expectEqual(@as(u8, 1), try parser.expectInt(u8));
}
/// The DER identifier octet: low 5 bits tag, 1 bit constructed flag,
/// top 2 bits class.
pub const Identifier = packed struct(u8) {
    tag: Tag,
    constructed: bool,
    class: Class,
    pub const Class = enum(u2) {
        universal,
        application,
        context_specific,
        private,
    };
    // https://www.oss.com/asn1/resources/asn1-made-simple/asn1-quick-reference/asn1-tags.html
    pub const Tag = enum(u5) {
        boolean = 1,
        integer = 2,
        bitstring = 3,
        octetstring = 4,
        null = 5,
        object_identifier = 6,
        real = 9,
        enumerated = 10,
        string_utf8 = 12,
        sequence = 16,
        sequence_of = 17,
        string_numeric = 18,
        string_printable = 19,
        string_teletex = 20,
        string_videotex = 21,
        string_ia5 = 22,
        utc_time = 23,
        generalized_time = 24,
        string_visible = 26,
        string_universal = 28,
        string_bmp = 30,
        // Non-exhaustive: unknown tags are representable, not rejected here.
        _,
    };
};
/// A DER BIT STRING: content bytes plus the count of unused (padding) bits
/// in the final byte, as parsed by `Parser.expectBitstring`.
pub const BitString = struct {
    bytes: []const u8,
    right_padding: u3,

    /// Number of significant bits. The `right_padding` bits at the end of
    /// the last byte are unused, so they are subtracted (the previous
    /// version added them, overcounting by up to 14 bits).
    pub fn bitLen(self: BitString) usize {
        return self.bytes.len * 8 - self.right_padding;
    }
};
/// An ASN.1 string value together with which string type carried it.
pub const String = struct {
    tag: Tag,
    data: []const u8,
    pub const Tag = enum {
        /// Blessed.
        utf8,
        /// us-ascii ([-][0-9][eE][.])*
        numeric,
        /// us-ascii ([A-Z][a-z][0-9][.?!,][ \t])*
        printable,
        /// iso-8859-1 with escaping into different character sets.
        /// Cursed.
        teletex,
        /// iso-8859-1
        videotex,
        /// us-ascii first 128 characters.
        ia5,
        /// us-ascii without control characters.
        visible,
        /// utf-32-be
        universal,
        /// utf-16-be
        bmp,
    };
    /// Every string tag; useful for `Parser.expectString` allow-sets.
    pub const all = [_]Tag{
        .utf8,
        .numeric,
        .printable,
        .teletex,
        .videotex,
        .ia5,
        .visible,
        .universal,
        .bmp,
    };
};
/// Calendar date used while decoding DER UTCTime/GeneralizedTime values.
const Date = struct {
    year: Year,
    month: u8,
    day: u8,
    const Year = std.time.epoch.Year;

    /// Converts the date to seconds since the Unix epoch (midnight UTC).
    /// Uses the Euclidean Affine Transform date algorithm; constants below
    /// are taken as-is from that construction - do not adjust independently.
    fn toEpochSeconds(date: Date) i64 {
        // Euclidean Affine Transform by Cassio and Neri.
        // Shift and correction constants for 1970-01-01.
        const s = 82;
        const K = 719468 + 146097 * s;
        const L = 400 * s;
        const Y_G: u32 = date.year;
        const M_G: u32 = date.month;
        const D_G: u32 = date.day;
        // Map to computational calendar.
        const J: u32 = if (M_G <= 2) 1 else 0;
        const Y: u32 = Y_G + L - J;
        const M: u32 = if (J != 0) M_G + 12 else M_G;
        const D: u32 = D_G - 1;
        const C: u32 = Y / 100;
        // Rata die.
        const y_star: u32 = 1461 * Y / 4 - C + C / 4;
        const m_star: u32 = (979 * M - 2919) / 32;
        const N: u32 = y_star + m_star + D;
        const days: i32 = @intCast(N - K);
        return @as(i64, days) * std.time.epoch.secs_per_day;
    }
};
/// Time of day parsed from a DER time string.
const Time = struct {
    hour: std.math.IntFittingRange(0, 24),
    minute: std.math.IntFittingRange(0, 60),
    second: std.math.IntFittingRange(0, 60),

    /// Seconds since midnight for this time of day.
    fn toSec(t: Time) i64 {
        const from_hours = @as(i64, t.hour) * 60 * 60;
        const from_minutes = @as(i64, t.minute) * 60;
        return from_hours + from_minutes + t.second;
    }
};
/// Parses exactly two ASCII digits and validates the value lies in
/// [min, max]; any parse failure or out-of-range value is InvalidTime.
fn parseTimeDigits(
    text: *const [2]u8,
    min: comptime_int,
    max: comptime_int,
) !std.math.IntFittingRange(min, max) {
    const T = std.math.IntFittingRange(min, max);
    const value = std.fmt.parseInt(T, text, 10) catch return error.InvalidTime;
    if (value < min or value > max) return error.InvalidTime;
    return value;
}
test parseTimeDigits {
    const expectEqual = std.testing.expectEqual;
    try expectEqual(@as(u8, 0), try parseTimeDigits("00", 0, 99));
    try expectEqual(@as(u8, 99), try parseTimeDigits("99", 0, 99));
    try expectEqual(@as(u8, 42), try parseTimeDigits("42", 0, 99));
    const expectError = std.testing.expectError;
    // Out of range high, out of range low, and non-digit input.
    try expectError(error.InvalidTime, parseTimeDigits("13", 1, 12));
    try expectError(error.InvalidTime, parseTimeDigits("00", 1, 12));
    try expectError(error.InvalidTime, parseTimeDigits("Di", 0, 99));
}
/// Parses a four-digit year; anything non-numeric or above 9999 is InvalidYear.
fn parseYear4(text: *const [4]u8) !Date.Year {
    const year = std.fmt.parseInt(Date.Year, text, 10) catch return error.InvalidYear;
    return if (year > 9999) error.InvalidYear else year;
}
test parseYear4 {
    const expectEqual = std.testing.expectEqual;
    try expectEqual(@as(Date.Year, 0), try parseYear4("0000"));
    try expectEqual(@as(Date.Year, 9999), try parseYear4("9999"));
    try expectEqual(@as(Date.Year, 1988), try parseYear4("1988"));
    const expectError = std.testing.expectError;
    // Non-digit characters in various positions.
    try expectError(error.InvalidYear, parseYear4("999b"));
    try expectError(error.InvalidYear, parseYear4("crap"));
    try expectError(error.InvalidYear, parseYear4("r:bQ"));
}
/// Parses an "hhmmss" substring of a DER time value into a Time.
fn parseTime(bytes: *const [6]u8) !Time {
    return .{
        .hour = try parseTimeDigits(bytes[0..2], 0, 23),
        .minute = try parseTimeDigits(bytes[2..4], 0, 59),
        .second = try parseTimeDigits(bytes[4..6], 0, 59),
    };
}

View File

@@ -1,132 +0,0 @@
//! Developed by ITU-U and ISO/IEC for naming objects. Used in DER.
//!
//! This implementation supports any number of `u32` arcs.
// A single OID component ("arc"); this implementation caps each arc at u32.
const Arc = u32;
// Base of the 7-bit continuation encoding DER uses for arcs after the first pair.
const encoding_base = 128;
/// Returns encoded length.
/// The first two arcs share one byte; every later arc takes
/// floor(log128(arc)) continuation bytes plus one final byte.
pub fn encodeLen(dot_notation: []const u8) !usize {
    var split = std.mem.splitScalar(u8, dot_notation, '.');
    // NOTE(review): splitScalar always yields at least one (possibly empty)
    // item, so this first branch looks unreachable - confirm.
    if (split.next() == null) return 0;
    if (split.next() == null) return 1;
    var res: usize = 1;
    while (split.next()) |s| {
        const parsed = try std.fmt.parseUnsigned(Arc, s, 10);
        const n_bytes = if (parsed == 0) 0 else std.math.log(Arc, encoding_base, parsed);
        res += n_bytes;
        res += 1;
    }
    return res;
}
/// Errors from `encode`: arc parse failures, fewer than two arcs in the
/// dot notation, or an undersized output buffer.
pub const EncodeError = std.fmt.ParseIntError || error{
    MissingPrefix,
    BufferTooSmall,
};
/// Encodes dot notation (e.g. "1.2.840.113549") into DER OID bytes in `buf`.
/// Returns the written prefix of `buf`.
pub fn encode(dot_notation: []const u8, buf: []u8) EncodeError![]const u8 {
    if (buf.len < try encodeLen(dot_notation)) return error.BufferTooSmall;
    var split = std.mem.splitScalar(u8, dot_notation, '.');
    const first_str = split.next() orelse return error.MissingPrefix;
    const second_str = split.next() orelse return error.MissingPrefix;
    const first = try std.fmt.parseInt(u8, first_str, 10);
    const second = try std.fmt.parseInt(u8, second_str, 10);
    // X.690: the first two arcs are packed into one byte as first*40 + second.
    buf[0] = first * 40 + second;
    var i: usize = 1;
    while (split.next()) |s| {
        var parsed = try std.fmt.parseUnsigned(Arc, s, 10);
        const n_bytes = if (parsed == 0) 0 else std.math.log(Arc, encoding_base, parsed);
        // Emit base-128 digits most-significant first, continuation bit (0x80)
        // set on every byte except the final one written after the loop.
        for (0..n_bytes) |j| {
            const place = std.math.pow(Arc, encoding_base, n_bytes - @as(Arc, @intCast(j)));
            const digit: u8 = @intCast(@divFloor(parsed, place));
            buf[i] = digit | 0x80;
            parsed -= digit * place;
            i += 1;
        }
        buf[i] = @intCast(parsed);
        i += 1;
    }
    return buf[0..i];
}
/// Decodes DER OID bytes into dot notation written to `writer`.
pub fn decode(encoded: []const u8, writer: anytype) @TypeOf(writer).Error!void {
    // First byte packs the first two arcs as first*40 + second.
    // NOTE(review): this split mis-handles first-arc 2 with second >= 40
    // (values >= 120 in the first byte) - confirm expected inputs.
    const first = @divTrunc(encoded[0], 40);
    const second = encoded[0] - first * 40;
    try writer.print("{d}.{d}", .{ first, second });
    var i: usize = 1;
    while (i != encoded.len) {
        // An arc spans all continuation bytes (high bit set) plus one final byte.
        const n_bytes: usize = brk: {
            var res: usize = 1;
            var j: usize = i;
            while (encoded[j] & 0x80 != 0) {
                res += 1;
                j += 1;
            }
            break :brk res;
        };
        // Fold the 7-bit groups, most significant first, into the arc value.
        var n: usize = 0;
        for (0..n_bytes) |j| {
            const place = std.math.pow(usize, encoding_base, n_bytes - j - 1);
            n += place * (encoded[i] & 0b01111111);
            i += 1;
        }
        try writer.print(".{d}", .{n});
    }
}
/// Comptime variant of `encode`; returns the exact-size byte array by value.
pub fn encodeComptime(comptime dot_notation: []const u8) [encodeLen(dot_notation) catch unreachable]u8 {
    @setEvalBranchQuota(10_000);
    var buf: [encodeLen(dot_notation) catch unreachable]u8 = undefined;
    _ = encode(dot_notation, &buf) catch unreachable;
    return buf;
}
const std = @import("std");
/// Test helper: asserts `expected_dot_notation` encodes to `expected_encoded`
/// and that decoding the bytes round-trips back to the dot notation.
fn testOid(expected_encoded: []const u8, expected_dot_notation: []const u8) !void {
    var buf: [256]u8 = undefined;
    const encoded = try encode(expected_dot_notation, &buf);
    try std.testing.expectEqualSlices(u8, expected_encoded, encoded);
    var stream = std.io.fixedBufferStream(&buf);
    try decode(expected_encoded, stream.writer());
    try std.testing.expectEqualStrings(expected_dot_notation, stream.getWritten());
}
test "encode and decode" {
    // https://learn.microsoft.com/en-us/windows/win32/seccertenroll/about-object-identifier
    try testOid(
        &[_]u8{ 0x2b, 0x06, 0x01, 0x04, 0x01, 0x82, 0x37, 0x15, 0x14 },
        "1.3.6.1.4.1.311.21.20",
    );
    // https://luca.ntop.org/Teaching/Appunti/asn1.html
    try testOid(&[_]u8{ 0x2a, 0x86, 0x48, 0x86, 0xf7, 0x0d }, "1.2.840.113549");
    // https://www.sysadmins.lv/blog-en/how-to-encode-object-identifier-to-an-asn1-der-encoded-string.aspx
    try testOid(&[_]u8{ 0x2a, 0x86, 0x8d, 0x20 }, "1.2.100000");
    // rsaEncryption with SHA-256 signature algorithm OID.
    try testOid(
        &[_]u8{ 0x2a, 0x86, 0x48, 0x86, 0xf7, 0x0d, 0x01, 0x01, 0x0b },
        "1.2.840.113549.1.1.11",
    );
    // ed25519 OID.
    try testOid(&[_]u8{ 0x2b, 0x65, 0x70 }, "1.3.101.112");
}
test encodeComptime {
    // Must match the runtime `encode` result for the same input.
    try std.testing.expectEqual(
        [_]u8{ 0x2b, 0x06, 0x01, 0x04, 0x01, 0x82, 0x37, 0x15, 0x14 },
        encodeComptime("1.3.6.1.4.1.311.21.20"),
    );
}

View File

@@ -1,880 +0,0 @@
//! RFC8017: Public Key Cryptography Standards #1 v2.2 (PKCS1)
const std = @import("std");
const der = @import("der.zig");
const ff = std.crypto.ff;
/// Maximum supported RSA modulus size in bits; larger keys are rejected by
/// the fixed-width finite-field types below.
pub const max_modulus_bits = 4096;
const max_modulus_len = max_modulus_bits / 8;
const Modulus = std.crypto.ff.Modulus(max_modulus_bits);
const Fe = Modulus.Fe;
/// Validation failures for RSA key parameters.
pub const ValueError = error{
    Modulus,
    Exponent,
};
/// RSA public key: modulus `n` and public exponent `e` (RFC 8017).
pub const PublicKey = struct {
    /// `n`
    modulus: Modulus,
    /// `e`
    public_exponent: Fe,

    pub const FromBytesError = ValueError || ff.OverflowError || ff.FieldElementError || ff.InvalidModulusError || error{InsecureBitCount};

    /// Builds a key from big-endian modulus and exponent bytes.
    /// Rejects moduli of 512 bits or fewer as insecure.
    pub fn fromBytes(mod: []const u8, exp: []const u8) FromBytesError!PublicKey {
        const modulus = try Modulus.fromBytes(mod, .big);
        if (modulus.bits() <= 512) return error.InsecureBitCount;
        const public_exponent = try Fe.fromBytes(modulus, exp, .big);
        if (std.debug.runtime_safety) {
            // > the RSA public exponent e is an integer between 3 and n - 1 satisfying
            // > GCD(e,\lambda(n)) = 1, where \lambda(n) = LCM(r_1 - 1, ..., r_u - 1)
            const e_v = public_exponent.toPrimitive(u32) catch return error.Exponent;
            if (!public_exponent.isOdd()) return error.Exponent;
            if (e_v < 3) return error.Exponent;
            if (modulus.v.compare(public_exponent.v) == .lt) return error.Exponent;
        }
        return .{ .modulus = modulus, .public_exponent = public_exponent };
    }

    /// Parses a DER RSAPublicKey:
    /// SEQUENCE { modulus INTEGER, publicExponent INTEGER }.
    pub fn fromDer(bytes: []const u8) (der.Parser.Error || FromBytesError)!PublicKey {
        var parser = der.Parser{ .bytes = bytes };
        const seq = try parser.expectSequence();
        defer parser.seek(seq.slice.end);
        const modulus = try parser.expectPrimitive(.integer);
        const pub_exp = try parser.expectPrimitive(.integer);
        try parser.expectEnd(seq.slice.end);
        try parser.expectEnd(bytes.len);
        return try fromBytes(parser.view(modulus), parser.view(pub_exp));
    }

    /// Deprecated.
    ///
    /// Encrypt a short message using RSAES-PKCS1-v1_5.
    /// The use of this scheme for encrypting an arbitrary message, as opposed to a
    /// randomly generated key, is NOT RECOMMENDED.
    /// Writes the ciphertext into `out` and returns it; requires
    /// `out.len >= byteLen(modulus bits)`.
    pub fn encryptPkcsv1_5(pk: PublicKey, msg: []const u8, out: []u8) ![]const u8 {
        // align variable names with spec
        const k = byteLen(pk.modulus.bits());
        if (out.len < k) return error.BufferTooSmall;
        if (msg.len > k - 11) return error.MessageTooLong;
        // EM = 0x00 || 0x02 || PS || 0x00 || M.
        var em = out[0..k];
        em[0] = 0;
        em[1] = 2;
        const ps = em[2..][0 .. k - msg.len - 3];
        // Section: 7.2.1
        // PS consists of pseudo-randomly generated nonzero octets.
        for (ps) |*v| {
            // uintLessThan(u8, 0xff) yields 0..254; +1 keeps every byte nonzero.
            v.* = std.crypto.random.uintLessThan(u8, 0xff) + 1;
        }
        em[em.len - msg.len - 1] = 0;
        @memcpy(em[em.len - msg.len ..][0..msg.len], msg);
        // c = m^e mod n, written back over em.
        const m = try Fe.fromBytes(pk.modulus, em, .big);
        const e = try pk.modulus.powPublic(m, pk.public_exponent);
        try e.toBytes(em, .big);
        return em;
    }

    /// Encrypt a short message using Optimal Asymmetric Encryption Padding (RSAES-OAEP).
    pub fn encryptOaep(
        pk: PublicKey,
        comptime Hash: type,
        msg: []const u8,
        label: []const u8,
        out: []u8,
    ) ![]const u8 {
        // align variable names with spec
        const k = byteLen(pk.modulus.bits());
        if (out.len < k) return error.BufferTooSmall;
        if (msg.len > k - 2 * Hash.digest_length - 2) return error.MessageTooLong;
        // EM = 0x00 || maskedSeed || maskedDB.
        var em = out[0..k];
        em[0] = 0;
        const seed = em[1..][0..Hash.digest_length];
        std.crypto.random.bytes(seed);
        // DB = lHash || PS || 0x01 || M.
        var db = em[1 + seed.len ..];
        const lHash = labelHash(Hash, label);
        @memcpy(db[0..lHash.len], &lHash);
        @memset(db[lHash.len .. db.len - msg.len - 2], 0);
        db[db.len - msg.len - 1] = 1;
        @memcpy(db[db.len - msg.len ..], msg);
        // Mask DB with MGF1(seed), then mask the seed with MGF1(masked DB).
        var mgf_buf: [max_modulus_len]u8 = undefined;
        const db_mask = mgf1(Hash, seed, mgf_buf[0..db.len]);
        for (db, db_mask) |*v, m| v.* ^= m;
        const seed_mask = mgf1(Hash, db, mgf_buf[0..seed.len]);
        for (seed, seed_mask) |*v, m| v.* ^= m;
        // c = m^e mod n, written back over em.
        const m = try Fe.fromBytes(pk.modulus, em, .big);
        const e = try pk.modulus.powPublic(m, pk.public_exponent);
        try e.toBytes(em, .big);
        return em;
    }
};
/// Number of whole bytes needed to hold `bits` bits (rounded up).
pub fn byteLen(bits: usize) usize {
    // Computed without `bits + 7`, which could wrap near maxInt(usize).
    if (bits == 0) return 0;
    return (bits - 1) / 8 + 1;
}
/// RSA secret key: the private exponent `d`.
pub const SecretKey = struct {
    /// `d`
    private_exponent: Fe,

    pub const FromBytesError = ValueError || ff.OverflowError || ff.FieldElementError;

    /// Builds the secret key from big-endian exponent bytes; `n` is the
    /// public modulus the exponent must be reduced under.
    pub fn fromBytes(n: Modulus, exp: []const u8) FromBytesError!SecretKey {
        const d = try Fe.fromBytes(n, exp, .big);
        if (std.debug.runtime_safety) {
            // > The RSA private exponent d is a positive integer less than n
            // > satisfying e * d == 1 (mod \lambda(n)),
            if (!d.isOdd()) return error.Exponent;
            if (d.v.compare(n.v) != .lt) return error.Exponent;
        }
        return .{ .private_exponent = d };
    }
};
/// An RSA public/secret key pair with PKCS#1 v1.5, OAEP and PSS operations.
pub const KeyPair = struct {
    public: PublicKey,
    secret: SecretKey,

    pub const FromDerError = PublicKey.FromBytesError || SecretKey.FromBytesError || der.Parser.Error || error{ KeyMismatch, InvalidVersion };

    /// Parses a DER RSAPrivateKey (RFC 8017 A.1.2): version, n, e, d, then
    /// the CRT parameters (parsed but unused), and for version 1 the
    /// otherPrimeInfos sequence.
    pub fn fromDer(bytes: []const u8) FromDerError!KeyPair {
        var parser = der.Parser{ .bytes = bytes };
        const seq = try parser.expectSequence();
        const version = try parser.expectInt(u8);
        const mod = try parser.expectPrimitive(.integer);
        const pub_exp = try parser.expectPrimitive(.integer);
        const sec_exp = try parser.expectPrimitive(.integer);
        const public = try PublicKey.fromBytes(parser.view(mod), parser.view(pub_exp));
        const secret = try SecretKey.fromBytes(public.modulus, parser.view(sec_exp));
        const prime1 = try parser.expectPrimitive(.integer);
        const prime2 = try parser.expectPrimitive(.integer);
        const exp1 = try parser.expectPrimitive(.integer);
        const exp2 = try parser.expectPrimitive(.integer);
        const coeff = try parser.expectPrimitive(.integer);
        // CRT acceleration parameters are not used by this implementation.
        _ = .{ exp1, exp2, coeff };
        switch (version) {
            0 => {},
            1 => {
                // multi-prime RSA: skip each OtherPrimeInfo triple.
                _ = try parser.expectSequenceOf();
                while (!parser.eof()) {
                    _ = try parser.expectSequence();
                    const ri = try parser.expectPrimitive(.integer);
                    const di = try parser.expectPrimitive(.integer);
                    const ti = try parser.expectPrimitive(.integer);
                    _ = .{ ri, di, ti };
                }
            },
            else => return error.InvalidVersion,
        }
        try parser.expectEnd(seq.slice.end);
        try parser.expectEnd(bytes.len);
        if (std.debug.runtime_safety) {
            const p = try Fe.fromBytes(public.modulus, parser.view(prime1), .big);
            const q = try Fe.fromBytes(public.modulus, parser.view(prime2), .big);
            // check that n = p * q
            const expected_zero = public.modulus.mul(p, q);
            if (!expected_zero.isZero()) return error.KeyMismatch;
            // TODO: check that d * e is one mod p-1 and mod q-1. Note d and e were bound
            // const de = secret.private_exponent.mul(public.public_exponent);
            // const one = public.modulus.one();
            // if (public.modulus.mul(de, p).compare(one) != .eq) return error.KeyMismatch;
            // if (public.modulus.mul(de, q).compare(one) != .eq) return error.KeyMismatch;
        }
        return .{ .public = public, .secret = secret };
    }

    /// Deprecated.
    /// One-shot RSASSA-PKCS1-v1_5 signature over `msg`.
    pub fn signPkcsv1_5(kp: KeyPair, comptime Hash: type, msg: []const u8, out: []u8) !PKCS1v1_5(Hash).Signature {
        var st = try signerPkcsv1_5(kp, Hash);
        st.update(msg);
        return try st.finalize(out);
    }

    /// Deprecated.
    /// Streaming RSASSA-PKCS1-v1_5 signer.
    pub fn signerPkcsv1_5(kp: KeyPair, comptime Hash: type) !PKCS1v1_5(Hash).Signer {
        return PKCS1v1_5(Hash).Signer.init(kp);
    }

    /// Deprecated.
    /// Decrypts an RSAES-PKCS1-v1_5 ciphertext into `out`.
    pub fn decryptPkcsv1_5(kp: KeyPair, ciphertext: []const u8, out: []u8) ![]const u8 {
        const k = byteLen(kp.public.modulus.bits());
        if (out.len < k) return error.BufferTooSmall;
        const em = out[0..k];
        // m = c^d mod n
        const m = try Fe.fromBytes(kp.public.modulus, ciphertext, .big);
        const e = try kp.public.modulus.pow(m, kp.secret.private_exponent);
        try e.toBytes(em, .big);
        // Care shall be taken to ensure that an opponent cannot
        // distinguish these error conditions, whether by error
        // message or timing.
        const msg_start = ct.lastIndexOfScalar(em, 0) orelse em.len;
        const ps_len = em.len - msg_start;
        if (ct.@"or"(em[0] != 0, ct.@"or"(em[1] != 2, ps_len < 8))) {
            return error.Inconsistent;
        }
        return em[msg_start + 1 ..];
    }

    /// One-shot RSASSA-PSS signature over `msg`.
    pub fn signOaep(
        kp: KeyPair,
        comptime Hash: type,
        msg: []const u8,
        salt: ?[]const u8,
        out: []u8,
    ) !Pss(Hash).Signature {
        var st = try signerOaep(kp, Hash, salt);
        st.update(msg);
        return try st.finalize(out);
    }

    /// Salt must outlive returned `PSS.Signer`.
    pub fn signerOaep(kp: KeyPair, comptime Hash: type, salt: ?[]const u8) !Pss(Hash).Signer {
        return Pss(Hash).Signer.init(kp, salt);
    }

    /// Decrypts an RSAES-OAEP ciphertext (must have been encrypted with the
    /// same `Hash` and `label`).
    pub fn decryptOaep(
        kp: KeyPair,
        comptime Hash: type,
        ciphertext: []const u8,
        label: []const u8,
        out: []u8,
    ) ![]u8 {
        // align variable names with spec
        const k = byteLen(kp.public.modulus.bits());
        if (out.len < k) return error.BufferTooSmall;
        // m = c^d mod n
        const mod = try Fe.fromBytes(kp.public.modulus, ciphertext, .big);
        const exp = kp.public.modulus.pow(mod, kp.secret.private_exponent) catch unreachable;
        const em = out[0..k];
        try exp.toBytes(em, .big);
        // EM = Y || maskedSeed || maskedDB; unmask in reverse of encryptOaep.
        const y = em[0];
        const seed = em[1..][0..Hash.digest_length];
        const db = em[1 + Hash.digest_length ..];
        var mgf_buf: [max_modulus_len]u8 = undefined;
        const seed_mask = mgf1(Hash, db, mgf_buf[0..seed.len]);
        for (seed, seed_mask) |*v, m| v.* ^= m;
        const db_mask = mgf1(Hash, seed, mgf_buf[0..db.len]);
        for (db, db_mask) |*v, m| v.* ^= m;
        const expected_hash = labelHash(Hash, label);
        const actual_hash = db[0..expected_hash.len];
        // Care shall be taken to ensure that an opponent cannot
        // distinguish these error conditions, whether by error
        // message or timing.
        const msg_start = ct.indexOfScalarPos(em, expected_hash.len + 1, 1) orelse 0;
        if (ct.@"or"(y != 0, ct.@"or"(msg_start == 0, !ct.memEql(&expected_hash, actual_hash)))) {
            return error.Inconsistent;
        }
        return em[msg_start + 1 ..];
    }

    /// Encrypt short plaintext with secret key.
    /// Raw RSA (m^d mod n) with no padding; used internally by the signers.
    pub fn encrypt(kp: KeyPair, plaintext: []const u8, out: []u8) !void {
        const n = kp.public.modulus;
        const k = byteLen(n.bits());
        if (plaintext.len > k) return error.MessageTooLong;
        const msg_as_int = try Fe.fromBytes(n, plaintext, .big);
        const enc_as_int = try n.pow(msg_as_int, kp.secret.private_exponent);
        try enc_as_int.toBytes(out, .big);
    }
};
/// Deprecated.
///
/// Signature Scheme with Appendix v1.5 (RSASSA-PKCS1-v1_5)
///
/// This standard has been superseded by PSS which is formally proven secure
/// and has fewer footguns.
pub fn PKCS1v1_5(comptime Hash: type) type {
    return struct {
        const PkcsT = @This();

        /// An encoded signature; `bytes` must stay alive while verifying.
        pub const Signature = struct {
            bytes: []const u8,
            const Self = @This();
            pub fn verifier(self: Self, public_key: PublicKey) !Verifier {
                return Verifier.init(self, public_key);
            }
            /// One-shot verification of `msg` against this signature.
            pub fn verify(self: Self, msg: []const u8, public_key: PublicKey) !void {
                var st = Verifier.init(self, public_key);
                st.update(msg);
                return st.verify();
            }
        };

        /// Streaming signer: feed data with `update`, then `finalize`.
        pub const Signer = struct {
            h: Hash,
            key_pair: KeyPair,
            fn init(key_pair: KeyPair) Signer {
                return .{
                    .h = Hash.init(.{}),
                    .key_pair = key_pair,
                };
            }
            pub fn update(self: *Signer, data: []const u8) void {
                self.h.update(data);
            }
            /// Encodes the digest with EMSA-PKCS1-v1_5 into `out` and signs it
            /// in place. `out` must hold at least the modulus byte length.
            pub fn finalize(self: *Signer, out: []u8) !PkcsT.Signature {
                const k = byteLen(self.key_pair.public.modulus.bits());
                if (out.len < k) return error.BufferTooSmall;
                var hash: [Hash.digest_length]u8 = undefined;
                self.h.final(&hash);
                const em = try emsaEncode(hash, out[0..k]);
                try self.key_pair.encrypt(em, em);
                return .{ .bytes = em };
            }
        };

        /// Streaming verifier: feed data with `update`, then `verify`.
        pub const Verifier = struct {
            h: Hash,
            sig: PkcsT.Signature,
            public_key: PublicKey,
            fn init(sig: PkcsT.Signature, public_key: PublicKey) Verifier {
                return Verifier{
                    .h = Hash.init(.{}),
                    .sig = sig,
                    .public_key = public_key,
                };
            }
            pub fn update(self: *Verifier, data: []const u8) void {
                self.h.update(data);
            }
            /// Recovers EM = s^e mod n and compares it to the expected
            /// EMSA-PKCS1-v1_5 encoding of the accumulated digest.
            pub fn verify(self: *Verifier) !void {
                const pk = self.public_key;
                const s = try Fe.fromBytes(pk.modulus, self.sig.bytes, .big);
                const emm = try pk.modulus.powPublic(s, pk.public_exponent);
                var em_buf: [max_modulus_len]u8 = undefined;
                const em = em_buf[0..byteLen(pk.modulus.bits())];
                try emm.toBytes(em, .big);
                var hash: [Hash.digest_length]u8 = undefined;
                self.h.final(&hash);
                // TODO: compare hash values instead of emsa values
                const expected = try emsaEncode(hash, em);
                if (!std.mem.eql(u8, expected, em)) return error.Inconsistent;
            }
        };

        /// PKCS Encrypted Message Signature Appendix
        /// EM = 0x00 || 0x01 || PS (0xff padding) || 0x00 || T, where T is the
        /// DER digest header followed by the hash.
        fn emsaEncode(hash: [Hash.digest_length]u8, out: []u8) ![]u8 {
            const digest_header = comptime digestHeader();
            const tLen = digest_header.len + Hash.digest_length;
            const emLen = out.len;
            if (emLen < tLen + 11) return error.ModulusTooShort;
            // (A previous `out.len < emLen` check here was dead code: emLen is
            // defined as out.len, so it could never fire.)
            var res = out[0..emLen];
            res[0] = 0;
            res[1] = 1;
            const padding_len = emLen - tLen - 3;
            @memset(res[2..][0..padding_len], 0xff);
            res[2 + padding_len] = 0;
            @memcpy(res[2 + padding_len + 1 ..][0..digest_header.len], digest_header);
            @memcpy(res[res.len - hash.len ..], &hash);
            return res;
        }

        /// DER encoded header. Sequence of digest algo + digest.
        /// TODO: use a DER encoder instead
        fn digestHeader() []const u8 {
            const sha2 = std.crypto.hash.sha2;
            // Section 9.2 Notes 1.
            return switch (Hash) {
                std.crypto.hash.Sha1 => &hexToBytes(
                    \\30 21 30 09 06 05 2b 0e 03 02 1a 05 00 04 14
                ),
                sha2.Sha224 => &hexToBytes(
                    \\30 2d 30 0d 06 09 60 86 48 01 65 03 04 02 04
                    \\05 00 04 1c
                ),
                sha2.Sha256 => &hexToBytes(
                    \\30 31 30 0d 06 09 60 86 48 01 65 03 04 02 01 05 00
                    \\04 20
                ),
                sha2.Sha384 => &hexToBytes(
                    \\30 41 30 0d 06 09 60 86 48 01 65 03 04 02 02 05 00
                    \\04 30
                ),
                sha2.Sha512 => &hexToBytes(
                    \\30 51 30 0d 06 09 60 86 48 01 65 03 04 02 03 05 00
                    \\04 40
                ),
                // sha2.Sha512224 => &hexToBytes(
                //     \\30 2d 30 0d 06 09 60 86 48 01 65 03 04 02 05
                //     \\05 00 04 1c
                // ),
                // sha2.Sha512256 => &hexToBytes(
                //     \\30 31 30 0d 06 09 60 86 48 01 65 03 04 02 06
                //     \\05 00 04 20
                // ),
                else => @compileError("unknown Hash " ++ @typeName(Hash)),
            };
        }
    };
}
/// Probabilistic Signature Scheme (RSASSA-PSS, RFC 8017 S8.1).
/// `Hash` selects the digest used for message hashing and for MGF1.
pub fn Pss(comptime Hash: type) type {
    // RFC 4055 S3.1: the default salt length equals the digest length.
    const default_salt_len = Hash.digest_length;
    return struct {
        pub const Signature = struct {
            bytes: []const u8,
            const Self = @This();
            pub fn verifier(self: Self, public_key: PublicKey) !Verifier {
                // BUG FIX: Verifier.init takes three arguments; the salt
                // length was previously omitted, which is a compile error
                // as soon as this function is instantiated. Use the
                // RFC 4055 default.
                return Verifier.init(self, public_key, default_salt_len);
            }
            /// One-shot verification of `msg`; `salt_len` defaults to the
            /// digest length when null.
            pub fn verify(self: Self, msg: []const u8, public_key: PublicKey, salt_len: ?usize) !void {
                var st = Verifier.init(self, public_key, salt_len orelse default_salt_len);
                st.update(msg);
                return st.verify();
            }
        };
        const PssT = @This();
        pub const Signer = struct {
            h: Hash,
            key_pair: KeyPair,
            // When null, a fresh random salt of `default_salt_len` bytes is
            // generated in `finalize`.
            salt: ?[]const u8,
            fn init(key_pair: KeyPair, salt: ?[]const u8) Signer {
                return .{
                    .h = Hash.init(.{}),
                    .key_pair = key_pair,
                    .salt = salt,
                };
            }
            pub fn update(self: *Signer, data: []const u8) void {
                self.h.update(data);
            }
            /// Finish hashing and produce the signature. The returned
            /// signature's `bytes` alias `out`; the caller owns that buffer.
            pub fn finalize(self: *Signer, out: []u8) !PssT.Signature {
                var hashed: [Hash.digest_length]u8 = undefined;
                self.h.final(&hashed);
                // BUG FIX: keep the random-salt buffer in function scope.
                // The previous version declared it inside the labeled block
                // and returned `&res`, leaving `salt` pointing at memory
                // whose scope had ended.
                var salt_buf: [default_salt_len]u8 = undefined;
                const salt = self.salt orelse brk: {
                    std.crypto.random.bytes(&salt_buf);
                    break :brk &salt_buf;
                };
                const em_bits = self.key_pair.public.modulus.bits() - 1;
                const em = try emsaEncode(hashed, salt, em_bits, out);
                // RSA signature primitive: in-place exponentiation with the
                // private key.
                try self.key_pair.encrypt(em, em);
                return .{ .bytes = em };
            }
        };
        pub const Verifier = struct {
            h: Hash,
            sig: PssT.Signature,
            public_key: PublicKey,
            salt_len: usize,
            fn init(sig: PssT.Signature, public_key: PublicKey, salt_len: usize) Verifier {
                return Verifier{
                    .h = Hash.init(.{}),
                    .sig = sig,
                    .public_key = public_key,
                    .salt_len = salt_len,
                };
            }
            pub fn update(self: *Verifier, data: []const u8) void {
                self.h.update(data);
            }
            /// EMSA-PSS-VERIFY (RFC 8017 S9.1.2) against the accumulated hash.
            pub fn verify(self: *Verifier) !void {
                const pk = self.public_key;
                // RSAVP1: recover the encoded message EM = s^e mod n.
                const s = try Fe.fromBytes(pk.modulus, self.sig.bytes, .big);
                const emm = try pk.modulus.powPublic(s, pk.public_exponent);
                var em_buf: [max_modulus_len]u8 = undefined;
                const em_bits = pk.modulus.bits() - 1;
                const em_len = std.math.divCeil(usize, em_bits, 8) catch unreachable;
                const em = em_buf[0..em_len];
                try emm.toBytes(em, .big);
                if (em.len < Hash.digest_length + self.salt_len + 2) return error.Inconsistent;
                // Trailer byte must be 0xbc.
                if (em[em.len - 1] != 0xbc) return error.Inconsistent;
                const db = em[0 .. em.len - Hash.digest_length - 1];
                // The leftmost 8*emLen - emBits bits of maskedDB must be zero.
                if (@clz(db[0]) < em.len * 8 - em_bits) return error.Inconsistent;
                const expected_hash = em[db.len..][0..Hash.digest_length];
                // Unmask DB with MGF1(H).
                var mgf_buf: [max_modulus_len]u8 = undefined;
                const db_mask = mgf1(Hash, expected_hash, mgf_buf[0..db.len]);
                for (db, db_mask) |*v, m| v.* ^= m;
                // DB must be PS (all zero) || 0x01 || salt. db[0] is skipped
                // because its top bits were just cleared by the mask check.
                for (1..db.len - self.salt_len - 1) |i| {
                    if (db[i] != 0) return error.Inconsistent;
                }
                if (db[db.len - self.salt_len - 1] != 1) return error.Inconsistent;
                const salt = db[db.len - self.salt_len ..];
                // M' = eight zero bytes || mHash || salt; H must equal Hash(M').
                var mp_buf: [max_modulus_len]u8 = undefined;
                const mp = mp_buf[0 .. 8 + Hash.digest_length + self.salt_len];
                @memset(mp[0..8], 0);
                self.h.final(mp[8..][0..Hash.digest_length]);
                @memcpy(mp[8 + Hash.digest_length ..][0..salt.len], salt);
                var actual_hash: [Hash.digest_length]u8 = undefined;
                Hash.hash(mp, &actual_hash, .{});
                if (!std.mem.eql(u8, expected_hash, &actual_hash)) return error.Inconsistent;
            }
        };
        /// PSS Encrypted Message Signature Appendix
        /// (EMSA-PSS-ENCODE, RFC 8017 S9.1.1). Builds
        /// EM = maskedDB || H || 0xbc into `out` and returns the used slice.
        fn emsaEncode(msg_hash: [Hash.digest_length]u8, salt: []const u8, em_bits: usize, out: []u8) ![]u8 {
            const em_len = std.math.divCeil(usize, em_bits, 8) catch unreachable;
            if (em_len < Hash.digest_length + salt.len + 2) return error.Encoding;
            // EM = maskedDB || H || 0xbc
            const em = out[0..em_len];
            em[em.len - 1] = 0xbc;
            var mp_buf: [max_modulus_len]u8 = undefined;
            // M' = (0x)00 00 00 00 00 00 00 00 || mHash || salt;
            const mp = mp_buf[0 .. 8 + Hash.digest_length + salt.len];
            @memset(mp[0..8], 0);
            @memcpy(mp[8..][0..Hash.digest_length], &msg_hash);
            @memcpy(mp[8 + Hash.digest_length ..][0..salt.len], salt);
            // H = Hash(M')
            const hash = em[em.len - 1 - Hash.digest_length ..][0..Hash.digest_length];
            Hash.hash(mp, hash, .{});
            // DB = PS || 0x01 || salt
            const db = em[0 .. em_len - Hash.digest_length - 1];
            @memset(db[0 .. db.len - salt.len - 1], 0);
            db[db.len - salt.len - 1] = 1;
            @memcpy(db[db.len - salt.len ..], salt);
            var mgf_buf: [max_modulus_len]u8 = undefined;
            const db_mask = mgf1(Hash, hash, mgf_buf[0..db.len]);
            for (db, db_mask) |*v, m| v.* ^= m;
            // Set the leftmost 8emLen - emBits bits of the leftmost octet
            // in maskedDB to zero.
            const shift = std.math.comptimeMod(8 * em_len - em_bits, 8);
            const mask = @as(u8, 0xff) >> shift;
            db[0] &= mask;
            return em;
        }
    };
}
/// Mask generation function MGF1 (RFC 8017 appendix B.2.1). Currently the
/// only one defined. Fills `out` with Hash(seed || counter) blocks, counter
/// encoded big-endian, truncating the final block; returns `out`.
fn mgf1(comptime Hash: type, seed: []const u8, out: []u8) []u8 {
    var counter_bytes: [@sizeOf(u32)]u8 = undefined;
    var block: [Hash.digest_length]u8 = undefined;
    var written: usize = 0;
    var counter: u32 = 0;
    while (written < out.len) : (counter += 1) {
        var hasher = Hash.init(.{});
        hasher.update(seed);
        std.mem.writeInt(u32, &counter_bytes, counter, .big);
        hasher.update(&counter_bytes);
        const remaining = out.len - written;
        if (remaining >= Hash.digest_length) {
            // Full block fits: hash straight into the output.
            hasher.final(out[written..][0..Hash.digest_length]);
            written += Hash.digest_length;
        } else {
            // Final partial block: hash to a scratch buffer, copy the prefix.
            hasher.final(&block);
            @memcpy(out[written..][0..remaining], block[0..remaining]);
            written += remaining;
        }
    }
    return out;
}
test mgf1 {
    const Hash = std.crypto.hash.sha2.Sha256;
    // Buffer sized for the multi-block case (two digests plus one byte).
    var out: [Hash.digest_length * 2 + 1]u8 = undefined;
    // Output shorter than one digest: the single block is truncated.
    try std.testing.expectEqualSlices(
        u8,
        &hexToBytes(
            \\ed 1b 84 6b b9 26 39 00 c8 17 82 ad 08 eb 17 01
            \\fa 8c 72 21 c6 57 63 77 31 7f 5c e8 09 89 9f
        ),
        mgf1(Hash, "asdf", out[0 .. Hash.digest_length - 1]),
    );
    // Output spanning two full blocks plus a truncated third; the longer
    // mask starts with the shorter one (MGF1 prefix property).
    try std.testing.expectEqualSlices(
        u8,
        &hexToBytes(
            \\ed 1b 84 6b b9 26 39 00 c8 17 82 ad 08 eb 17 01
            \\fa 8c 72 21 c6 57 63 77 31 7f 5c e8 09 89 9f 5a
            \\22 F2 80 D5 28 08 F4 93 83 76 00 DE 09 E4 EC 92
            \\4A 2C 7C EF 0D F7 7B BE 8F 7F 12 CB 8F 33 A6 65
            \\AB
        ),
        mgf1(Hash, "asdf", &out),
    );
}
/// For OAEP. Returns Hash(label). For the common empty-label case the
/// well-known empty-input digest is returned as a precomputed comptime
/// constant, avoiding a runtime hash of nothing.
inline fn labelHash(comptime Hash: type, label: []const u8) [Hash.digest_length]u8 {
    if (label.len == 0) {
        // magic constants from NIST: digests of the empty string for each
        // supported hash.
        const sha2 = std.crypto.hash.sha2;
        switch (Hash) {
            std.crypto.hash.Sha1 => return hexToBytes(
                \\da39a3ee 5e6b4b0d 3255bfef 95601890
                \\afd80709
            ),
            sha2.Sha256 => return hexToBytes(
                \\e3b0c442 98fc1c14 9afbf4c8 996fb924
                \\27ae41e4 649b934c a495991b 7852b855
            ),
            sha2.Sha384 => return hexToBytes(
                \\38b060a7 51ac9638 4cd9327e b1b1e36a
                \\21fdb711 14be0743 4c0cc7bf 63f6e1da
                \\274edebf e76f65fb d51ad2f1 4898b95b
            ),
            sha2.Sha512 => return hexToBytes(
                \\cf83e135 7eefb8bd f1542850 d66d8007
                \\d620e405 0b5715dc 83f4a921 d36ce9ce
                \\47d0d13c 5d85f2b0 ff8318d2 877eec2f
                \\63b931bd 47417a81 a538327a f927da3e
            ),
            // just use the empty hash... (fall through to the generic path
            // below for hashes without a precomputed constant)
            else => {},
        }
    }
    var res: [Hash.digest_length]u8 = undefined;
    Hash.hash(label, &res, .{});
    return res;
}
// Select the constant-time helper set: plain std.mem implementations when
// side-channel mitigations are disabled, branch-profile-stable ones otherwise.
const ct = if (std.options.side_channels_mitigations == .none) ct_unprotected else ct_protected;
/// Straightforward (non-constant-time) helper implementations, used when
/// side-channel mitigations are disabled. These simply delegate to std.mem
/// or to the language's boolean operators.
const ct_unprotected = struct {
    fn lastIndexOfScalar(haystack: []const u8, needle: u8) ?usize {
        return std.mem.lastIndexOfScalar(u8, haystack, needle);
    }
    fn indexOfScalarPos(haystack: []const u8, start: usize, needle: u8) ?usize {
        return std.mem.indexOfScalarPos(u8, haystack, start, needle);
    }
    fn memEql(lhs: []const u8, rhs: []const u8) bool {
        return std.mem.eql(u8, lhs, rhs);
    }
    fn @"and"(lhs: bool, rhs: bool) bool {
        return lhs and rhs;
    }
    fn @"or"(lhs: bool, rhs: bool) bool {
        return lhs or rhs;
    }
};
/// Timing-mitigated helper implementations, used when side-channel
/// mitigations are enabled. Each function traverses its entire input
/// regardless of where (or whether) a match occurs, so the memory access
/// pattern does not depend on secret data. Semantics match ct_unprotected.
const ct_protected = struct {
    fn lastIndexOfScalar(slice: []const u8, value: u8) ?usize {
        var res: ?usize = null;
        var i: usize = slice.len;
        while (i != 0) {
            i -= 1;
            // Scanning backwards, record only the first match seen —
            // i.e. the last occurrence overall.
            if (@intFromBool(res == null) & @intFromBool(slice[i] == value) == 1) res = i;
        }
        return res;
    }
    fn indexOfScalarPos(slice: []const u8, start_index: usize, value: u8) ?usize {
        // BUG FIX: the previous forward scan overwrote `res` on every match
        // and so returned the LAST occurrence, while ct_unprotected (via
        // std.mem.indexOfScalarPos) returns the FIRST. Scan backwards and
        // overwrite on match: the final value is the earliest index, and the
        // whole range is always traversed.
        var res: ?usize = null;
        var i: usize = slice.len;
        while (i > start_index) {
            i -= 1;
            if (slice[i] == value) res = i;
        }
        return res;
    }
    fn memEql(a: []const u8, b: []const u8) bool {
        // Accumulate equality over all byte pairs; no early exit on
        // mismatch. Assumes a.len == b.len (enforced by the for loop).
        var res: u1 = 1;
        for (a, b) |a_elem, b_elem| {
            res &= @intFromBool(a_elem == b_elem);
        }
        return res == 1;
    }
    fn @"and"(a: bool, b: bool) bool {
        return (@intFromBool(a) & @intFromBool(b)) == 1;
    }
    fn @"or"(a: bool, b: bool) bool {
        return (@intFromBool(a) | @intFromBool(b)) == 1;
    }
};
test ct {
    // NOTE(review): only the unprotected implementations are exercised here;
    // ct_protected is assumed to behave identically — confirm with a
    // dedicated test.
    const c = ct_unprotected;
    try std.testing.expectEqual(true, c.@"or"(true, false));
    try std.testing.expectEqual(true, c.@"and"(true, true));
    try std.testing.expectEqual(true, c.memEql("Asdf", "Asdf"));
    try std.testing.expectEqual(false, c.memEql("asdf", "Asdf"));
    try std.testing.expectEqual(3, c.indexOfScalarPos("asdff", 1, 'f'));
    try std.testing.expectEqual(4, c.lastIndexOfScalar("asdff", 'f'));
}
/// Strips every non-hex character (whitespace, punctuation) from `hex`.
/// Comptime helper for `hexToBytes`; the result slices a function-local
/// buffer, so it is only meaningful when evaluated at comptime.
fn removeNonHex(comptime hex: []const u8) []const u8 {
    // Large embedded hex blobs blow the default comptime branch quota;
    // raise it to match the identical helper in the test utilities.
    @setEvalBranchQuota(1000 * 100);
    var res: [hex.len]u8 = undefined;
    var i: usize = 0;
    for (hex) |c| {
        if (std.ascii.isHex(c)) {
            res[i] = c;
            i += 1;
        }
    }
    return res[0..i];
}
/// For readable copy/pasting from hex viewers: converts a hex string
/// (non-hex separator characters ignored) into a comptime byte array.
fn hexToBytes(comptime hex: []const u8) [removeNonHex(hex).len / 2]u8 {
    // Keep the comptime branch quota in line with removeNonHex so long
    // test vectors can be embedded (matches the test-utility copy).
    @setEvalBranchQuota(1000 * 100);
    const hex2 = comptime removeNonHex(hex);
    comptime var res: [hex2.len / 2]u8 = undefined;
    _ = comptime std.fmt.hexToBytes(&res, hex2) catch unreachable;
    return res;
}
test hexToBytes {
    // SHA-256 of the empty string, formatted as a readable hex dump.
    const hex =
        \\e3b0c442 98fc1c14 9afbf4c8 996fb924
        \\27ae41e4 649b934c a495991b 7852b855
    ;
    try std.testing.expectEqual(
        [_]u8{
            0xe3, 0xb0, 0xc4, 0x42, 0x98, 0xfc, 0x1c, 0x14,
            0x9a, 0xfb, 0xf4, 0xc8, 0x99, 0x6f, 0xb9, 0x24,
            0x27, 0xae, 0x41, 0xe4, 0x64, 0x9b, 0x93, 0x4c,
            0xa4, 0x95, 0x99, 0x1b, 0x78, 0x52, 0xb8, 0x55,
        },
        hexToBytes(hex),
    );
}
// Digest used by the round-trip tests below.
const TestHash = std.crypto.hash.sha2.Sha256;

/// Loads the RSA test keypair embedded from testdata and sanity-checks
/// that it is a 2048-bit key.
fn testKeypair() !KeyPair {
    const keypair_bytes = @embedFile("testdata/id_rsa.der");
    const kp = try KeyPair.fromDer(keypair_bytes);
    try std.testing.expectEqual(2048, kp.public.modulus.bits());
    return kp;
}
test "rsa PKCS1-v1_5 encrypt and decrypt" {
    // Round-trip: encrypt with the public key, decrypt with the private key.
    const kp = try testKeypair();
    const msg = "rsa PKCS1-v1_5 encrypt and decrypt";
    var out: [max_modulus_len]u8 = undefined;
    const enc = try kp.public.encryptPkcsv1_5(msg, &out);
    var out2: [max_modulus_len]u8 = undefined;
    const dec = try kp.decryptPkcsv1_5(enc, &out2);
    try std.testing.expectEqualSlices(u8, msg, dec);
}
test "rsa OAEP encrypt and decrypt" {
    // Round-trip with an empty label (the common case; see labelHash).
    const kp = try testKeypair();
    const msg = "rsa OAEP encrypt and decrypt";
    const label = "";
    var out: [max_modulus_len]u8 = undefined;
    const enc = try kp.public.encryptOaep(TestHash, msg, label, &out);
    var out2: [max_modulus_len]u8 = undefined;
    const dec = try kp.decryptOaep(TestHash, enc, label, &out2);
    try std.testing.expectEqualSlices(u8, msg, dec);
}
test "rsa PKCS1-v1_5 signature" {
    // Sign with the private key, then verify against the public key.
    const kp = try testKeypair();
    const msg = "rsa PKCS1-v1_5 signature";
    var out: [max_modulus_len]u8 = undefined;
    const signature = try kp.signPkcsv1_5(TestHash, msg, &out);
    try signature.verify(msg, kp.public);
}
test "rsa PSS signature" {
    // NOTE(review): this PSS test calls `signOaep` with a salt argument —
    // presumably that is the PSS signing entry point despite the OAEP-like
    // name; confirm against the KeyPair definition.
    const kp = try testKeypair();
    const msg = "rsa PSS signature";
    var out: [max_modulus_len]u8 = undefined;
    // Both an explicit salt and the empty salt must round-trip.
    const salts = [_][]const u8{ "asdf", "" };
    for (salts) |salt| {
        const signature = try kp.signOaep(TestHash, msg, salt, &out);
        try signature.verify(msg, kp.public, salt.len);
    }
    const signature = try kp.signOaep(TestHash, msg, null, &out); // random salt
    try signature.verify(msg, kp.public, null);
}

Binary file not shown.

View File

@@ -1,5 +0,0 @@
-----BEGIN EC PRIVATE KEY-----
MHcCAQEEINJSRKv8kSKEzLHptfAlg+LGh4/pHHlq0XLf30Q9pcztoAoGCCqGSM49
AwEHoUQDQgAEJpmLyp8aGCgyMcFIJaIq/+4V1K6nPpeoih3bT2npeplF9eyXj7rm
8eW9Ua6VLhq71mqtMC+YLm+IkORBVq1cuA==
-----END EC PRIVATE KEY-----

View File

@@ -1,6 +0,0 @@
-----BEGIN PRIVATE KEY-----
MIG2AgEAMBAGByqGSM49AgEGBSuBBAAiBIGeMIGbAgEBBDAQNT3KGxUdBqpxuO/z
GSJDePMgmB6xLytkfnHQMCqQquXrmcOQZT3BJhm+PwggmwGhZANiAATKxBc6kfqA
piA+Z0rIjVwaZaBNGnP4UZ5TqVewQ/dP9/BQCca2SJpsXauGLcUPmK4sKFQxGe6d
fzq9O50lo7qHEOIpwDBdhRp+oqB6sN2hMtCPbp6eyzsUlm3FUyhN9D0=
-----END PRIVATE KEY-----

View File

@@ -1,6 +0,0 @@
-----BEGIN EC PRIVATE KEY-----
MIGkAgEBBDDubYpeDdOwxksyQIDiOt6LHt3ikts2HNuR6rqhBg1CLdmp3AVDKfF4
fPkIr8UDH22gBwYFK4EEACKhZANiAARcVFUVv3bIHS6BEfLt98rtps7XP1y26m2n
v5x/5ecbDH2p7AXBYerJERKFi7ZFE1DSrSAj+KK8otjdEG44ZA2Mtl5AHwDVrKde
RgtavVoreHhLN80jJOun8JnFXQjdNsA=
-----END EC PRIVATE KEY-----

View File

@@ -1,7 +0,0 @@
-----BEGIN EC PRIVATE KEY-----
MIHcAgEBBEIB8C9axyQY6mgjjC6htLjc8hGylrDsh4BCv9669JaDj5vbxmCnTNlg
OuS6C9+uJNMbwm6CoIjB7RcgDTrxxX7oCyegBwYFK4EEACOhgYkDgYYABABAT5Q8
aOj9U0iuJE5tXfKnYTgPuvD6keHZAGJ5veM9uR6jr3BhfGubD6bnlD+cIBQzYWo0
y/BNMzCRJ55PDCNU5gGLw+vkwhJ1lGF5OS6l2oG5WN3fe6cYo+uJD7+PB3WYNIuX
Ls0oidsEM0Q4WLblQOEP6VLGf4qTcZyhoFWYfkjWiw==
-----END EC PRIVATE KEY-----

View File

@@ -1 +0,0 @@
'<27><><EFBFBD>qp0x<30>0)쩖<15>~<7E>+<2B>`<60><>tY4<59><34>D_

View File

@@ -1,28 +0,0 @@
-----BEGIN PRIVATE KEY-----
MIIEvQIBADANBgkqhkiG9w0BAQEFAASCBKcwggSjAgEAAoIBAQDe9yPmdcxv3dVu
D4wJ+GLjYBvAfYzVBFAsNuI79zOfoRSvvs8aD0z1yzlwDjuX1iH3SJF5ynxo/Opi
oVpyT3hXDszyo1AF8UzKUXMQmhiOcfW0xz6+TO831IRLghzsCKPMBz1cC+WFP/62
RHePPGovM8Nd9vIpRgQlfgXZ+DstpEBmnw1tGvq8CsWLhkMw7xQgQZ21zD5jtUgE
J8lc02IoX/W25HdJmayESqZnpZoaN8dgTLrBcM9XZEoh6gVTEOyUcUpDBIMAqloo
vPKMWBSS0oMX9HspD+eHokeyUxkSI/tLzlr4oYT5sfO/4/oQ+K2vh84DDqAsE3FX
xFVESETLAgMBAAECggEAUDuAmKqlEVAzQDKqAuB1vTpVYjQLnI+7xd1OFaQD2Jpf
VkqEPe1plT03AwKsIRw2BsT/TGM315PDSBCl+mJsfG9gAqQP5MOLDXa3wC6jTYbm
ktHr2xDWODHqFT3R6IHHZ2DnjJrfUc7QeogyucFUuH2Y/NQjGgUO8urhcikoKmi3
kBiAHCHWNqhrSpzdFLifhe6VC/TGFwKqTepN+TnX3Z20HdL4kkYPGEGA9OonVSn4
N1m/Q+yj6xm6vBMGlT0lS8lyz0EKb6rLedR7+rEJfOIvhVFEi8aXjkb5a6wIh5LO
rwu/jL0nUY8J5NP5BKz68gRwPtmmKBfCLXTpJUACSQKBgQDmPHEZkBC5wl8plx16
hrwwSdJuQy0b6BYZO06gpYBOIIENULijKwZzoMYaL8zivGT3KIluEelA7+NXnCuk
NUx7LieeZ+ChIUuRLvT02H9lH11d1Va2PmBgRmUKgul26YyaxeIy3UzjPbbgUFJv
t970IRfgS8qGD9KuhdlovZlCzwKBgQD36mo4BxgO1xmr4Qq0WSgQi2QBMAP9lpE4
Lc59UP5qvNrGXLGPsirdzz6VSeMGrrxGDyof+fGG9d0Wt0+8OMRysSuVua+SRiJ4
ugoaCzLbsq6pzDWPXf/wzVevjKTIGh4ZXk6Qa7IHqyEmvOnvxdDsL3iZXgPcQoIF
HybqHU9NRQKBgC6tnGSJX8q5jJ+bAp//xxGnNeGi/vdEc46EBqntQ/kS//caIYT7
SSCSPPe8Lzbc6T9u2YYWXYsL17TAddyh7bKfpeqottMUNAToV0N4zUNMO5q1kRH7
zYBXZU7fQcQZD6elbPnRAjCkJ3qM7lm2Fp66QuP3mcTaWmWFv5FLt1HjAoGANVaF
y9Aa6PZ2W3hraSnVaNnUhjziXujKDaAtUODgG+7N0ueWfCgE+PvhpxTid0mY0Cnr
Ej4gLL0w9/YwfXppKZPcoLX2hC36tKayDbBjHMlwsq9wxoueyRwkxWwo97RGzYZw
uLmy79ttonv6iM+yh14fQD/t7LGSb6+oG656pVECgYEA0oya1vG0WL3K8ip8io4c
ovB2K1Uf7EyFzxJHJt6QpmXlPDKkwc6JzpKGJdCi09Pz49U63HodxahtB831rbAY
EduOUQ5scTKf66qA9/kEyClnwl14ZCds7/mu9ioZ7D0VNmWPFsYHaGKAUxsq97nb
xw9Y4zAdgbDcl1bzN9XCDKs=
-----END PRIVATE KEY-----

View File

@@ -1,244 +0,0 @@
/// Messages from The Illustrated TLS 1.2 Connection
/// https://tls12.xargs.org/
const hexToBytes = @import("../testu.zig").hexToBytes;
pub const client_hello = hexToBytes(
\\ 16 03 01 00 a5 01 00 00 a1 03 03 00 01 02 03 04
\\ 05 06 07 08 09 0a 0b 0c 0d 0e 0f 10 11 12 13 14
\\ 15 16 17 18 19 1a 1b 1c 1d 1e 1f 00 00 20 cc a8
\\ cc a9 c0 2f c0 30 c0 2b c0 2c c0 13 c0 09 c0 14
\\ c0 0a 00 9c 00 9d 00 2f 00 35 c0 12 00 0a 01 00
\\ 00 58 00 00 00 18 00 16 00 00 13 65 78 61 6d 70
\\ 6c 65 2e 75 6c 66 68 65 69 6d 2e 6e 65 74 00 05
\\ 00 05 01 00 00 00 00 00 0a 00 0a 00 08 00 1d 00
\\ 17 00 18 00 19 00 0b 00 02 01 00 00 0d 00 12 00
\\ 10 04 01 04 03 05 01 05 03 06 01 06 03 02 01 02
\\ 03 ff 01 00 01 00 00 12 00 00
);
pub const server_hello = hexToBytes(
\\ 16 03 03 00 31 02 00 00 2d 03 03 70 71 72 73 74
\\ 75 76 77 78 79 7a 7b 7c 7d 7e 7f 80 81 82 83 84
\\ 85 86 87 88 89 8a 8b 8c 8d 8e 8f 00 c0 13 00 00
\\ 05 ff 01 00 01 00
);
pub const server_certificate = hexToBytes(
\\ 16 03 03 03 2f 0b 00 03 2b 00 03 28 00 03 25 30
\\ 82 03 21 30 82 02 09 a0 03 02 01 02 02 08 15 5a
\\ 92 ad c2 04 8f 90 30 0d 06 09 2a 86 48 86 f7 0d
\\ 01 01 0b 05 00 30 22 31 0b 30 09 06 03 55 04 06
\\ 13 02 55 53 31 13 30 11 06 03 55 04 0a 13 0a 45
\\ 78 61 6d 70 6c 65 20 43 41 30 1e 17 0d 31 38 31
\\ 30 30 35 30 31 33 38 31 37 5a 17 0d 31 39 31 30
\\ 30 35 30 31 33 38 31 37 5a 30 2b 31 0b 30 09 06
\\ 03 55 04 06 13 02 55 53 31 1c 30 1a 06 03 55 04
\\ 03 13 13 65 78 61 6d 70 6c 65 2e 75 6c 66 68 65
\\ 69 6d 2e 6e 65 74 30 82 01 22 30 0d 06 09 2a 86
\\ 48 86 f7 0d 01 01 01 05 00 03 82 01 0f 00 30 82
\\ 01 0a 02 82 01 01 00 c4 80 36 06 ba e7 47 6b 08
\\ 94 04 ec a7 b6 91 04 3f f7 92 bc 19 ee fb 7d 74
\\ d7 a8 0d 00 1e 7b 4b 3a 4a e6 0f e8 c0 71 fc 73
\\ e7 02 4c 0d bc f4 bd d1 1d 39 6b ba 70 46 4a 13
\\ e9 4a f8 3d f3 e1 09 59 54 7b c9 55 fb 41 2d a3
\\ 76 52 11 e1 f3 dc 77 6c aa 53 37 6e ca 3a ec be
\\ c3 aa b7 3b 31 d5 6c b6 52 9c 80 98 bc c9 e0 28
\\ 18 e2 0b f7 f8 a0 3a fd 17 04 50 9e ce 79 bd 9f
\\ 39 f1 ea 69 ec 47 97 2e 83 0f b5 ca 95 de 95 a1
\\ e6 04 22 d5 ee be 52 79 54 a1 e7 bf 8a 86 f6 46
\\ 6d 0d 9f 16 95 1a 4c f7 a0 46 92 59 5c 13 52 f2
\\ 54 9e 5a fb 4e bf d7 7a 37 95 01 44 e4 c0 26 87
\\ 4c 65 3e 40 7d 7d 23 07 44 01 f4 84 ff d0 8f 7a
\\ 1f a0 52 10 d1 f4 f0 d5 ce 79 70 29 32 e2 ca be
\\ 70 1f df ad 6b 4b b7 11 01 f4 4b ad 66 6a 11 13
\\ 0f e2 ee 82 9e 4d 02 9d c9 1c dd 67 16 db b9 06
\\ 18 86 ed c1 ba 94 21 02 03 01 00 01 a3 52 30 50
\\ 30 0e 06 03 55 1d 0f 01 01 ff 04 04 03 02 05 a0
\\ 30 1d 06 03 55 1d 25 04 16 30 14 06 08 2b 06 01
\\ 05 05 07 03 02 06 08 2b 06 01 05 05 07 03 01 30
\\ 1f 06 03 55 1d 23 04 18 30 16 80 14 89 4f de 5b
\\ cc 69 e2 52 cf 3e a3 00 df b1 97 b8 1d e1 c1 46
\\ 30 0d 06 09 2a 86 48 86 f7 0d 01 01 0b 05 00 03
\\ 82 01 01 00 59 16 45 a6 9a 2e 37 79 e4 f6 dd 27
\\ 1a ba 1c 0b fd 6c d7 55 99 b5 e7 c3 6e 53 3e ff
\\ 36 59 08 43 24 c9 e7 a5 04 07 9d 39 e0 d4 29 87
\\ ff e3 eb dd 09 c1 cf 1d 91 44 55 87 0b 57 1d d1
\\ 9b df 1d 24 f8 bb 9a 11 fe 80 fd 59 2b a0 39 8c
\\ de 11 e2 65 1e 61 8c e5 98 fa 96 e5 37 2e ef 3d
\\ 24 8a fd e1 74 63 eb bf ab b8 e4 d1 ab 50 2a 54
\\ ec 00 64 e9 2f 78 19 66 0d 3f 27 cf 20 9e 66 7f
\\ ce 5a e2 e4 ac 99 c7 c9 38 18 f8 b2 51 07 22 df
\\ ed 97 f3 2e 3e 93 49 d4 c6 6c 9e a6 39 6d 74 44
\\ 62 a0 6b 42 c6 d5 ba 68 8e ac 3a 01 7b dd fc 8e
\\ 2c fc ad 27 cb 69 d3 cc dc a2 80 41 44 65 d3 ae
\\ 34 8c e0 f3 4a b2 fb 9c 61 83 71 31 2b 19 10 41
\\ 64 1c 23 7f 11 a5 d6 5c 84 4f 04 04 84 99 38 71
\\ 2b 95 9e d6 85 bc 5c 5d d6 45 ed 19 90 94 73 40
\\ 29 26 dc b4 0e 34 69 a1 59 41 e8 e2 cc a8 4b b6
\\ 08 46 36 a0
);
pub const server_key_exchange = hexToBytes(
\\ 16 03 03 01 2c 0c 00 01 28 03 00 1d 20 9f d7 ad
\\ 6d cf f4 29 8d d3 f9 6d 5b 1b 2a f9 10 a0 53 5b
\\ 14 88 d7 f8 fa bb 34 9a 98 28 80 b6 15 04 01 01
\\ 00 04 02 b6 61 f7 c1 91 ee 59 be 45 37 66 39 bd
\\ c3 d4 bb 81 e1 15 ca 73 c8 34 8b 52 5b 0d 23 38
\\ aa 14 46 67 ed 94 31 02 14 12 cd 9b 84 4c ba 29
\\ 93 4a aa cc e8 73 41 4e c1 1c b0 2e 27 2d 0a d8
\\ 1f 76 7d 33 07 67 21 f1 3b f3 60 20 cf 0b 1f d0
\\ ec b0 78 de 11 28 be ba 09 49 eb ec e1 a1 f9 6e
\\ 20 9d c3 6e 4f ff d3 6b 67 3a 7d dc 15 97 ad 44
\\ 08 e4 85 c4 ad b2 c8 73 84 12 49 37 25 23 80 9e
\\ 43 12 d0 c7 b3 52 2e f9 83 ca c1 e0 39 35 ff 13
\\ a8 e9 6b a6 81 a6 2e 40 d3 e7 0a 7f f3 58 66 d3
\\ d9 99 3f 9e 26 a6 34 c8 1b 4e 71 38 0f cd d6 f4
\\ e8 35 f7 5a 64 09 c7 dc 2c 07 41 0e 6f 87 85 8c
\\ 7b 94 c0 1c 2e 32 f2 91 76 9e ac ca 71 64 3b 8b
\\ 98 a9 63 df 0a 32 9b ea 4e d6 39 7e 8c d0 1a 11
\\ 0a b3 61 ac 5b ad 1c cd 84 0a 6c 8a 6e aa 00 1a
\\ 9d 7d 87 dc 33 18 64 35 71 22 6c 4d d2 c2 ac 41
\\ fb
);
pub const server_hello_done = hexToBytes("16 03 03 00 04 0e 00 00 00 ");
pub const server_change_cipher_spec = hexToBytes("14 03 03 00 01 01 ");
pub const server_handshake_finished = hexToBytes(
\\ 16 03 03 00 40 51 52 53 54 55 56 57 58 59 5a 5b
\\ 5c 5d 5e 5f 60 18 e0 75 31 7b 10 03 15 f6 08 1f
\\ cb f3 13 78 1a ac 73 ef e1 9f e2 5b a1 af 59 c2
\\ 0b e9 4f c0 1b da 2d 68 00 29 8b 73 a7 e8 49 d7
\\ 4b d4 94 cf 7d
);
pub const client_key_exchange_for_transcript = hexToBytes(
\\ 16 03 03 00 25 10 00 00 21 20 35 80 72 d6 36 58
\\ 80 d1 ae ea 32 9a df 91 21 38 38 51 ed 21 a2 8e
\\ 3b 75 e9 65 d0 d2 cd 16 62 54
);
pub const server_hello_responses = server_hello ++ server_certificate ++ server_key_exchange ++ server_hello_done;
pub const server_responses = server_hello_responses ++ server_change_cipher_spec ++ server_handshake_finished;
pub const server_handshake_finished_msgs = server_change_cipher_spec ++ server_handshake_finished;
pub const master_secret = hexToBytes(
\\ 91 6a bf 9d a5 59 73 e1 36 14 ae 0a 3f 5d 3f 37
\\ b0 23 ba 12 9a ee 02 cc 91 34 33 81 27 cd 70 49
\\ 78 1c 8e 19 fc 1e b2 a7 38 7a c0 6a e2 37 34 4c
);
pub const client_key_exchange = hexToBytes(
\\ 16 03 03 00 25 10 00 00 21 20 00 01 02 03 04 05
\\ 06 07 08 09 0a 0b 0c 0d 0e 0f 10 11 12 13 14 15
\\ 16 17 18 19 1a 1b 1c 1d 1e 1f
);
pub const client_change_cyper_spec = hexToBytes("14 03 03 00 01 01 ");
pub const client_handshake_finished = hexToBytes(
\\ 16 03 03 00 40 20 21 22 23 24 25 26 27 28 29 2a
\\ 2b 2c 2d 2e 2f a9 ac f5 5a f3 7a 90 17 63 ff 91
\\ 68 9a b7 ee a0 d4 0c 1c ca 62 44 ef f3 0b a3 6d
\\ d0 df 86 3f 7d e3 98 d3 1a cc 37 6a e6 7a 00 6d
\\ 8c 08 bc 8a 5a
);
pub const handshake_messages = [_][]const u8{
&client_hello,
&server_hello,
&server_certificate,
&server_key_exchange,
&server_hello_done,
&client_key_exchange_for_transcript,
};
pub const client_finished = hexToBytes("14 00 00 0c cf 91 96 26 f1 36 0c 53 6a aa d7 3a ");
// with iv 40 " ++ 41 ... 4f
// client_sequence = 0
pub const verify_data_encrypted_msg = hexToBytes(
\\ 16 03 03 00 40 40 41 42 43 44 45 46 47 48 49 4a
\\ 4b 4c 4d 4e 4f 22 7b c9 ba 81 ef 30 f2 a8 a7 8f
\\ f1 df 50 84 4d 58 04 b7 ee b2 e2 14 c3 2b 68 92
\\ ac a3 db 7b 78 07 7f dd 90 06 7c 51 6b ac b3 ba
\\ 90 de df 72 0f
);
// with iv 00 " ++ 01 ... 1f
// client_sequence = 1
pub const encrypted_ping_msg = hexToBytes(
\\ 17 03 03 00 30 00 01 02 03 04 05 06 07 08 09 0a
\\ 0b 0c 0d 0e 0f 6c 42 1c 71 c4 2b 18 3b fa 06 19
\\ 5d 13 3d 0a 09 d0 0f c7 cb 4e 0f 5d 1c da 59 d1
\\ 47 ec 79 0c 99
);
pub const key_material = hexToBytes(
\\ 1b 7d 11 7c 7d 5f 69 0b c2 63 ca e8 ef 60 af 0f
\\ 18 78 ac c2 2a d8 bd d8 c6 01 a6 17 12 6f 63 54
\\ 0e b2 09 06 f7 81 fa d2 f6 56 d0 37 b1 73 ef 3e
\\ 11 16 9f 27 23 1a 84 b6 75 2a 18 e7 a9 fc b7 cb
\\ cd d8 f9 8d d8 f7 69 eb a0 d2 55 0c 92 38 ee bf
\\ ef 5c 32 25 1a bb 67 d6 43 45 28 db 49 37 d5 40
\\ d3 93 13 5e 06 a1 1b b8 0e 45 ea eb e3 2c ac 72
\\ 75 74 38 fb b3 df 64 5c bd a4 06 7c df a0 f8 48
);
pub const server_pong = hexToBytes(
\\ 17 03 03 00 30 61 62 63 64 65 66 67 68 69 6a 6b
\\ 6c 6d 6e 6f 70 97 83 48 8a f5 fa 20 bf 7a 2e f6
\\ 9d eb b5 34 db 9f b0 7a 8c 27 21 de e5 40 9f 77
\\ af 0c 3d de 56
);
pub const client_random = hexToBytes(
\\ 00 01 02 03 04 05 06 07 08 09 0a 0b 0c 0d 0e 0f 10 11 12 13 14 15 16 17 18 19 1a 1b 1c 1d 1e 1f
);
pub const server_random = hexToBytes(
\\ 70 71 72 73 74 75 76 77 78 79 7a 7b 7c 7d 7e 7f 80 81 82 83 84 85 86 87 88 89 8a 8b 8c 8d 8e 8f
);
pub const client_secret = hexToBytes(
\\ 20 21 22 23 24 25 26 27 28 29 2a 2b 2c 2d 2e 2f 30 31 32 33 34 35 36 37 38 39 3a 3b 3c 3d 3e 3f
);
pub const server_pub_key = hexToBytes(
\\ 9f d7 ad 6d cf f4 29 8d d3 f9 6d 5b 1b 2a f9 10 a0 53 5b 14 88 d7 f8 fa bb 34 9a 98 28 80 b6 15
);
pub const signature = hexToBytes(
\\ 04 02 b6 61 f7 c1 91 ee 59 be 45 37 66 39 bd c3
\\ d4 bb 81 e1 15 ca 73 c8 34 8b 52 5b 0d 23 38 aa
\\ 14 46 67 ed 94 31 02 14 12 cd 9b 84 4c ba 29 93
\\ 4a aa cc e8 73 41 4e c1 1c b0 2e 27 2d 0a d8 1f
\\ 76 7d 33 07 67 21 f1 3b f3 60 20 cf 0b 1f d0 ec
\\ b0 78 de 11 28 be ba 09 49 eb ec e1 a1 f9 6e 20
\\ 9d c3 6e 4f ff d3 6b 67 3a 7d dc 15 97 ad 44 08
\\ e4 85 c4 ad b2 c8 73 84 12 49 37 25 23 80 9e 43
\\ 12 d0 c7 b3 52 2e f9 83 ca c1 e0 39 35 ff 13 a8
\\ e9 6b a6 81 a6 2e 40 d3 e7 0a 7f f3 58 66 d3 d9
\\ 99 3f 9e 26 a6 34 c8 1b 4e 71 38 0f cd d6 f4 e8
\\ 35 f7 5a 64 09 c7 dc 2c 07 41 0e 6f 87 85 8c 7b
\\ 94 c0 1c 2e 32 f2 91 76 9e ac ca 71 64 3b 8b 98
\\ a9 63 df 0a 32 9b ea 4e d6 39 7e 8c d0 1a 11 0a
\\ b3 61 ac 5b ad 1c cd 84 0a 6c 8a 6e aa 00 1a 9d
\\ 7d 87 dc 33 18 64 35 71 22 6c 4d d2 c2 ac 41 fb
);
pub const cert_pub_key = hexToBytes(
\\ 30 82 01 0a 02 82 01 01 00 c4 80 36 06 ba e7 47
\\ 6b 08 94 04 ec a7 b6 91 04 3f f7 92 bc 19 ee fb
\\ 7d 74 d7 a8 0d 00 1e 7b 4b 3a 4a e6 0f e8 c0 71
\\ fc 73 e7 02 4c 0d bc f4 bd d1 1d 39 6b ba 70 46
\\ 4a 13 e9 4a f8 3d f3 e1 09 59 54 7b c9 55 fb 41
\\ 2d a3 76 52 11 e1 f3 dc 77 6c aa 53 37 6e ca 3a
\\ ec be c3 aa b7 3b 31 d5 6c b6 52 9c 80 98 bc c9
\\ e0 28 18 e2 0b f7 f8 a0 3a fd 17 04 50 9e ce 79
\\ bd 9f 39 f1 ea 69 ec 47 97 2e 83 0f b5 ca 95 de
\\ 95 a1 e6 04 22 d5 ee be 52 79 54 a1 e7 bf 8a 86
\\ f6 46 6d 0d 9f 16 95 1a 4c f7 a0 46 92 59 5c 13
\\ 52 f2 54 9e 5a fb 4e bf d7 7a 37 95 01 44 e4 c0
\\ 26 87 4c 65 3e 40 7d 7d 23 07 44 01 f4 84 ff d0
\\ 8f 7a 1f a0 52 10 d1 f4 f0 d5 ce 79 70 29 32 e2
\\ ca be 70 1f df ad 6b 4b b7 11 01 f4 4b ad 66 6a
\\ 11 13 0f e2 ee 82 9e 4d 02 9d c9 1c dd 67 16 db
\\ b9 06 18 86 ed c1 ba 94 21 02 03 01 00 01
);

View File

@@ -1,64 +0,0 @@
const hexToBytes = @import("../testu.zig").hexToBytes;
pub const client_hello =
hexToBytes("16030100f8010000f40303000102030405060708090a0b0c0d0e0f101112131415161718191a1b1c1d1e1f20e0e1e2e3e4e5e6e7e8e9eaebecedeeeff0f1f2f3f4f5f6f7f8f9fafbfcfdfeff000813021303130100ff010000a30000001800160000136578616d706c652e756c666865696d2e6e6574000b000403000102000a00160014001d0017001e0019001801000101010201030104002300000016000000170000000d001e001c040305030603080708080809080a080b080408050806040105010601002b0003020304002d00020101003300260024001d0020358072d6365880d1aeea329adf9121383851ed21a28e3b75e965d0d2cd166254");
pub const server_hello =
hexToBytes("160303007a") ++ // record header
hexToBytes("020000760303") ++ // handshake header, server version
hexToBytes("707172737475767778797a7b7c7d7e7f808182838485868788898a8b8c8d8e8f") ++ // server_random
hexToBytes("20e0e1e2e3e4e5e6e7e8e9eaebecedeeeff0f1f2f3f4f5f6f7f8f9fafbfcfdfeff") ++ // session id
hexToBytes("130200") ++ // cipher suite, compression method
hexToBytes("002e002b00020304") ++ // extensions, supported version
hexToBytes("00330024001d00209fd7ad6dcff4298dd3f96d5b1b2af910a0535b1488d7f8fabb349a982880b615"); // extension key share
pub const client_random = hexToBytes(
\\ 00 01 02 03 04 05 06 07 08 09 0a 0b 0c 0d 0e 0f 10 11 12 13 14 15 16 17 18 19 1a 1b 1c 1d 1e 1f
);
pub const server_random =
hexToBytes("707172737475767778797a7b7c7d7e7f808182838485868788898a8b8c8d8e8f");
pub const server_pub_key =
hexToBytes("9fd7ad6dcff4298dd3f96d5b1b2af910a0535b1488d7f8fabb349a982880b615");
pub const client_private_key =
hexToBytes("202122232425262728292a2b2c2d2e2f303132333435363738393a3b3c3d3e3f");
pub const client_public_key =
hexToBytes("358072d6365880d1aeea329adf9121383851ed21a28e3b75e965d0d2cd166254");
pub const shared_key = hexToBytes("df4a291baa1eb7cfa6934b29b474baad2697e29f1f920dcc77c8a0a088447624");
pub const server_handshake_key = hexToBytes("9f13575ce3f8cfc1df64a77ceaffe89700b492ad31b4fab01c4792be1b266b7f");
pub const server_handshake_iv = hexToBytes("9563bc8b590f671f488d2da3");
pub const client_handshake_key = hexToBytes("1135b4826a9a70257e5a391ad93093dfd7c4214812f493b3e3daae1eb2b1ac69");
pub const client_handshake_iv = hexToBytes("4256d2e0e88babdd05eb2f27");
pub const server_application_key = hexToBytes("01f78623f17e3edcc09e944027ba3218d57c8e0db93cd3ac419309274700ac27");
pub const server_application_iv = hexToBytes("196a750b0c5049c0cc51a541");
pub const client_application_key = hexToBytes("de2f4c7672723a692319873e5c227606691a32d1c59d8b9f51dbb9352e9ca9cc");
pub const client_application_iv = hexToBytes("bb007956f474b25de902432f");
pub const server_encrypted_extensions_wrapped =
hexToBytes("17030300176be02f9da7c2dc9ddef56f2468b90adfa25101ab0344ae");
pub const server_encrypted_extensions =
hexToBytes("080000020000");
pub const server_certificate_wrapped =
hexToBytes("1703030343baf00a9be50f3f2307e726edcbdacbe4b18616449d46c6207af6e9953ee5d2411ba65d31feaf4f78764f2d693987186cc01329c187a5e4608e8d27b318e98dd94769f7739ce6768392caca8dcc597d77ec0d1272233785f6e69d6f43effa8e7905edfdc4037eee5933e990a7972f206913a31e8d04931366d3d8bcd6a4a4d647dd4bd80b0ff863ce3554833d744cf0e0b9c07cae726dd23f9953df1f1ce3aceb3b7230871e92310cfb2b098486f43538f8e82d8404e5c6c25f66a62ebe3c5f26232640e20a769175ef83483cd81e6cb16e78dfad4c1b714b04b45f6ac8d1065ad18c13451c9055c47da300f93536ea56f531986d6492775393c4ccb095467092a0ec0b43ed7a0687cb470ce350917b0ac30c6e5c24725a78c45f9f5f29b6626867f6f79ce054273547b36df030bd24af10d632dba54fc4e890bd0586928c0206ca2e28e44e227a2d5063195935df38da8936092eef01e84cad2e49d62e470a6c7745f625ec39e4fc23329c79d1172876807c36d736ba42bb69b004ff55f93850dc33c1f98abb92858324c76ff1eb085db3c1fc50f74ec04442e622973ea70743418794c388140bb492d6294a0540e5a59cfae60ba0f14899fca71333315ea083a68e1d7c1e4cdc2f56bcd6119681a4adbc1bbf42afd806c3cbd42a076f545dee4e118d0b396754be2b042a685dd4727e89c0386a94d3cd6ecb9820e9d49afeed66c47e6fc243eabebbcb0b02453877f5ac5dbfbdf8db1052a3c994b224cd9aaaf56b026bb9efa2e01302b36401ab6494e7018d6e5b573bd38bcef023b1fc92946bbca0209ca5fa926b4970b1009103645cb1fcfe552311ff730558984370038fd2cce2a91fc74d6f3e3ea9f843eed356f6f82d35d03bc24b81b58ceb1a43ec9437e6f1e50eb6f555e321fd67c8332eb1b832aa8d795a27d479c6e27d5a61034683891903f66421d094e1b00a9a138d861e6f78a20ad3e1580054d2e305253c713a02fe1e28deee7336246f6ae34331806b46b47b833c39b9d31cd300c2a6ed831399776d07f570eaf0059a2c68a5f3ae16b617404af7b7231a4d942758fc020b3f23ee8c15e36044cfd67cd640993b16207597fbf385ea7a4d99e8d456ff83d41f7b8b4f069b028a2a63a919a70e3a10e3084158faa5bafa30186c6b2f238eb530c73e");
pub const server_certificate =
hexToBytes("0b00032e0000032a0003253082032130820209a0030201020208155a92adc2048f90300d06092a864886f70d01010b05003022310b300906035504061302555331133011060355040a130a4578616d706c65204341301e170d3138313030353031333831375a170d3139313030353031333831375a302b310b3009060355040613025553311c301a060355040313136578616d706c652e756c666865696d2e6e657430820122300d06092a864886f70d01010105000382010f003082010a0282010100c4803606bae7476b089404eca7b691043ff792bc19eefb7d74d7a80d001e7b4b3a4ae60fe8c071fc73e7024c0dbcf4bdd11d396bba70464a13e94af83df3e10959547bc955fb412da3765211e1f3dc776caa53376eca3aecbec3aab73b31d56cb6529c8098bcc9e02818e20bf7f8a03afd1704509ece79bd9f39f1ea69ec47972e830fb5ca95de95a1e60422d5eebe527954a1e7bf8a86f6466d0d9f16951a4cf7a04692595c1352f2549e5afb4ebfd77a37950144e4c026874c653e407d7d23074401f484ffd08f7a1fa05210d1f4f0d5ce79702932e2cabe701fdfad6b4bb71101f44bad666a11130fe2ee829e4d029dc91cdd6716dbb9061886edc1ba94210203010001a3523050300e0603551d0f0101ff0404030205a0301d0603551d250416301406082b0601050507030206082b06010505070301301f0603551d23041830168014894fde5bcc69e252cf3ea300dfb197b81de1c146300d06092a864886f70d01010b05000382010100591645a69a2e3779e4f6dd271aba1c0bfd6cd75599b5e7c36e533eff3659084324c9e7a504079d39e0d42987ffe3ebdd09c1cf1d914455870b571dd19bdf1d24f8bb9a11fe80fd592ba0398cde11e2651e618ce598fa96e5372eef3d248afde17463ebbfabb8e4d1ab502a54ec0064e92f7819660d3f27cf209e667fce5ae2e4ac99c7c93818f8b2510722dfed97f32e3e9349d4c66c9ea6396d744462a06b42c6d5ba688eac3a017bddfc8e2cfcad27cb69d3ccdca280414465d3ae348ce0f34ab2fb9c618371312b191041641c237f11a5d65c844f0404849938712b959ed685bc5c5dd645ed19909473402926dcb40e3469a15941e8e2cca84bb6084636a00000");
pub const server_certificate_verify_wrapped = hexToBytes("170303011973719fce07ec2f6d3bba0292a0d40b2770c06a271799a53314f6f77fc95c5fe7b9a4329fd9548c670ebeea2f2d5c351dd9356ef2dcd52eb137bd3a676522f8cd0fb7560789ad7b0e3caba2e37e6b4199c6793b3346ed46cf740a9fa1fec414dc715c415c60e575703ce6a34b70b5191aa6a61a18faff216c687ad8d17e12a7e99915a611bfc1a2befc15e6e94d784642e682fd17382a348c301056b940c9847200408bec56c81ea3d7217ab8e85a88715395899c90587f72e8ddd74b26d8edc1c7c837d9f2ebbc260962219038b05654a63a0b12999b4a8306a3ddcc0e17c53ba8f9c80363f7841354d291b4ace0c0f330c0fcd5aa9deef969ae8ab2d98da88ebb6ea80a3a11f00ea296a3232367ff075e1c66dd9cbedc4713");
pub const server_finished_wrapped = hexToBytes("17030300451061de27e51c2c9f342911806f282b710c10632ca5006755880dbf7006002d0e84fed9adf27a43b5192303e4df5c285d58e3c76224078440c0742374744aecf28cf3182fd0");
pub const handshake_hash = hexToBytes("fa6800169a6baac19159524fa7b9721b41be3c9db6f3f93fa5ff7e3db3ece204d2b456c51046e40ec5312c55a86126f5");
pub const client_finished_verify_data = hexToBytes("bff56a671b6c659d0a7c5dd18428f58bdd38b184a3ce342d9fde95cbd5056f7da7918ee320eab7a93abd8f1c02454d27");
pub const client_finished_wrapped = hexToBytes("17030300459ff9b063175177322a46dd9896f3c3bb820ab51743ebc25fdadd53454b73deb54cc7248d411a18bccf657a960824e9a19364837c350a69a88d4bf635c85eb874aebc9dfde8");
pub const client_ping_wrapped = hexToBytes("1703030015828139cb7b73aaabf5b82fbf9a2961bcde10038a32");
pub const server_flight =
hexToBytes("140303000101") ++
server_encrypted_extensions_wrapped ++
server_certificate_wrapped ++
server_certificate_verify_wrapped ++
server_finished_wrapped;

View File

@@ -1,117 +0,0 @@
const std = @import("std");
/// Debug helper: prints `buf` to stderr as a Zig string assignment of
/// space-separated lowercase hex byte pairs, e.g. `const name = "de ad "`.
pub fn bufPrint(var_name: []const u8, buf: []const u8) void {
    // Alternative array-literal formatting, kept for convenience:
    // std.debug.print("\nconst {s} = [_]u8{{\n", .{var_name});
    // for (buf, 1..) |b, i| {
    //     std.debug.print("0x{x:0>2}, ", .{b});
    //     if (i % 16 == 0)
    //         std.debug.print("\n", .{});
    // }
    // std.debug.print("}};\n", .{});
    std.debug.print("const {s} = \"", .{var_name});
    const charset = "0123456789abcdef";
    for (buf) |b| {
        // Manual nibble-to-hex conversion (high nibble, then low nibble).
        const x = charset[b >> 4];
        const y = charset[b & 15];
        std.debug.print("{c}{c} ", .{ x, y });
    }
    std.debug.print("\"\n", .{});
}
// Deterministic std.Random test double: its fill function emits a simple
// incrementing (wrapping) byte sequence, making handshake tests reproducible.
const random_instance = std.Random{ .ptr = undefined, .fillFn = randomFillFn };
// Next byte the fake generator will emit; shared mutable state, so tests
// using this are not safe to run concurrently.
var random_seed: u8 = 0;

pub fn randomFillFn(_: *anyopaque, buf: []u8) void {
    for (buf) |*v| {
        v.* = random_seed;
        random_seed +%= 1;
    }
}

/// Returns the deterministic generator, restarting its sequence at `seed`.
pub fn random(seed: u8) std.Random {
    random_seed = seed;
    return random_instance;
}
// Fill buf with 0,1,..ff,0,...
pub fn fill(buf: []u8) void {
    fillFrom(buf, 0);
}

// Fill buf with start, start+1, ..., wrapping after 0xff.
pub fn fillFrom(buf: []u8, start: u8) void {
    var next: u8 = start;
    for (buf) |*byte| {
        byte.* = next;
        next +%= 1;
    }
}
/// In-memory bidirectional stream test double: reads are served from a fixed
/// `input` buffer, writes accumulate into a caller-provided `output` buffer.
pub const Stream = struct {
    output: std.io.FixedBufferStream([]u8) = undefined,
    input: std.io.FixedBufferStream([]const u8) = undefined,

    pub fn init(input: []const u8, output: []u8) Stream {
        return .{
            .input = std.io.fixedBufferStream(input),
            .output = std.io.fixedBufferStream(output),
        };
    }

    // Error sets expected by generic reader/writer adapters.
    pub const ReadError = error{};
    pub const WriteError = error{NoSpaceLeft};

    /// Writes as much of `buf` as fits; returns the number of bytes written.
    pub fn write(self: *Stream, buf: []const u8) !usize {
        return try self.output.writer().write(buf);
    }

    /// Writes all of `buffer`, looping until done; errors when the output
    /// buffer runs out of space.
    pub fn writeAll(self: *Stream, buffer: []const u8) !void {
        var n: usize = 0;
        while (n < buffer.len) {
            n += try self.write(buffer[n..]);
        }
    }

    /// Reads up to `buffer.len` bytes from the canned input.
    pub fn read(self: *Stream, buffer: []u8) !usize {
        return self.input.read(buffer);
    }
};
// Copied from: https://github.com/clickingbuttons/zig/blob/f1cea91624fd2deae28bfb2414a4fd9c7e246883/lib/std/crypto/rsa.zig#L791
/// For readable copy/pasting from hex viewers: decodes a comptime hex string,
/// ignoring any non-hex characters (spaces, newlines, punctuation).
pub fn hexToBytes(comptime hex: []const u8) [removeNonHex(hex).len / 2]u8 {
    @setEvalBranchQuota(1000 * 100);
    const cleaned = comptime removeNonHex(hex);
    comptime var out: [cleaned.len / 2]u8 = undefined;
    _ = comptime std.fmt.hexToBytes(&out, cleaned) catch unreachable;
    return out;
}

/// Strip every character that is not a hex digit; comptime only.
fn removeNonHex(comptime hex: []const u8) []const u8 {
    @setEvalBranchQuota(1000 * 100);
    var out: [hex.len]u8 = undefined;
    var count: usize = 0;
    for (hex) |ch| {
        if (!std.ascii.isHex(ch)) continue;
        out[count] = ch;
        count += 1;
    }
    return out[0..count];
}
test hexToBytes {
    // SHA-256 of the empty string, formatted the way a hex viewer prints it.
    const hex =
        \\e3b0c442 98fc1c14 9afbf4c8 996fb924
        \\27ae41e4 649b934c a495991b 7852b855
    ;
    const expected = [_]u8{
        0xe3, 0xb0, 0xc4, 0x42, 0x98, 0xfc, 0x1c, 0x14,
        0x9a, 0xfb, 0xf4, 0xc8, 0x99, 0x6f, 0xb9, 0x24,
        0x27, 0xae, 0x41, 0xe4, 0x64, 0x9b, 0x93, 0x4c,
        0xa4, 0x95, 0x99, 0x1b, 0x78, 0x52, 0xb8, 0x55,
    };
    const got = hexToBytes(hex);
    try std.testing.expectEqualSlices(u8, &expected, &got);
}

View File

@@ -1,297 +0,0 @@
const std = @import("std");
const crypto = std.crypto;
const tls = crypto.tls;
const hkdfExpandLabel = tls.hkdfExpandLabel;
const Sha256 = crypto.hash.sha2.Sha256;
const Sha384 = crypto.hash.sha2.Sha384;
const Sha512 = crypto.hash.sha2.Sha512;
const HashTag = @import("cipher.zig").CipherSuite.HashTag;
// Transcript holds the hash of all handshake messages.
//
// Until the server hello is parsed we don't know which hash (sha256, sha384,
// sha512) will be used so we update all of them. The handshake process sets
// the `tag` field (via `use`) once the cipher suite is known; the other
// functions then dispatch on that selected hash. We continue to calculate all
// hashes because the client certificate message could use a different hash
// than the other part of the handshake: the handshake hash is dictated by the
// server selected cipher, the client certificate hash by the private key used.
//
// Most of the functions are inlined because they are returning pointers to
// stack-local results.
//
pub const Transcript = struct {
    // One running transcript per supported hash; all three are kept updated.
    sha256: Type(.sha256) = .{ .hash = Sha256.init(.{}) },
    sha384: Type(.sha384) = .{ .hash = Sha384.init(.{}) },
    sha512: Type(.sha512) = .{ .hash = Sha512.init(.{}) },
    // Currently selected hash; defaults to sha256 until `use` is called.
    tag: HashTag = .sha256,
    pub const max_mac_length = Type(.sha512).mac_length;
    // Transcript Type from hash tag
    fn Type(h: HashTag) type {
        return switch (h) {
            .sha256 => TranscriptT(Sha256),
            .sha384 => TranscriptT(Sha384),
            .sha512 => TranscriptT(Sha512),
        };
    }
    /// Set hash to use in all following function calls.
    pub fn use(t: *Transcript, tag: HashTag) void {
        t.tag = tag;
    }
    /// Feed a handshake message into all three running hashes.
    pub fn update(t: *Transcript, buf: []const u8) void {
        t.sha256.hash.update(buf);
        t.sha384.hash.update(buf);
        t.sha512.hash.update(buf);
    }
    // tls 1.2 handshake specific
    /// Derive the tls 1.2 master secret using the selected hash.
    /// Inline: the returned slice refers to the call's temporary result.
    pub inline fn masterSecret(
        t: *Transcript,
        pre_master_secret: []const u8,
        client_random: [32]u8,
        server_random: [32]u8,
    ) []const u8 {
        return switch (t.tag) {
            inline else => |h| &@field(t, @tagName(h)).masterSecret(
                pre_master_secret,
                client_random,
                server_random,
            ),
        };
    }
    /// Derive the tls 1.2 key expansion material using the selected hash.
    pub inline fn keyMaterial(
        t: *Transcript,
        master_secret: []const u8,
        client_random: [32]u8,
        server_random: [32]u8,
    ) []const u8 {
        return switch (t.tag) {
            inline else => |h| &@field(t, @tagName(h)).keyExpansion(
                master_secret,
                client_random,
                server_random,
            ),
        };
    }
    /// 12-byte tls 1.2 client Finished verify_data over the current transcript.
    pub fn clientFinishedTls12(t: *Transcript, master_secret: []const u8) [12]u8 {
        return switch (t.tag) {
            inline else => |h| @field(t, @tagName(h)).clientFinishedTls12(master_secret),
        };
    }
    /// 12-byte tls 1.2 server Finished verify_data over the current transcript.
    pub fn serverFinishedTls12(t: *Transcript, master_secret: []const u8) [12]u8 {
        return switch (t.tag) {
            inline else => |h| @field(t, @tagName(h)).serverFinishedTls12(master_secret),
        };
    }
    // tls 1.3 handshake specific
    /// Content to be signed/verified for the server CertificateVerify message.
    pub inline fn serverCertificateVerify(t: *Transcript) []const u8 {
        return switch (t.tag) {
            inline else => |h| &@field(t, @tagName(h)).serverCertificateVerify(),
        };
    }
    /// Content to be signed for the client CertificateVerify message.
    pub inline fn clientCertificateVerify(t: *Transcript) []const u8 {
        return switch (t.tag) {
            inline else => |h| &@field(t, @tagName(h)).clientCertificateVerify(),
        };
    }
    /// tls 1.3 server Finished verify_data, written into `buf`.
    pub fn serverFinishedTls13(t: *Transcript, buf: []u8) []const u8 {
        return switch (t.tag) {
            inline else => |h| @field(t, @tagName(h)).serverFinishedTls13(buf),
        };
    }
    /// tls 1.3 client Finished verify_data, written into `buf`.
    pub fn clientFinishedTls13(t: *Transcript, buf: []u8) []const u8 {
        return switch (t.tag) {
            inline else => |h| @field(t, @tagName(h)).clientFinishedTls13(buf),
        };
    }
    // Pair of client/server traffic secrets; slices point at the selected
    // transcript's buffers (see the inline note in the header comment).
    pub const Secret = struct {
        client: []const u8,
        server: []const u8,
    };
    /// tls 1.3 handshake traffic secrets derived from the (EC)DHE shared key.
    pub inline fn handshakeSecret(t: *Transcript, shared_key: []const u8) Secret {
        return switch (t.tag) {
            inline else => |h| @field(t, @tagName(h)).handshakeSecret(shared_key),
        };
    }
    /// tls 1.3 application traffic secrets; requires handshakeSecret first.
    pub inline fn applicationSecret(t: *Transcript) Secret {
        return switch (t.tag) {
            inline else => |h| @field(t, @tagName(h)).applicationSecret(),
        };
    }
    // other
    /// Hkdf type for the given hash tag.
    pub fn Hkdf(h: HashTag) type {
        return Type(h).Hkdf;
    }
    /// Copy of the current hash value
    pub inline fn hash(t: *Transcript, comptime Hash: type) Hash {
        return switch (Hash) {
            Sha256 => t.sha256.hash,
            Sha384 => t.sha384.hash,
            Sha512 => t.sha512.hash,
            else => @compileError("unimplemented"),
        };
    }
};
/// Per-hash transcript implementation; instantiated by Transcript for
/// Sha256/Sha384/Sha512. Holds the running handshake hash plus the tls 1.3
/// secrets derived from it.
fn TranscriptT(comptime Hash: type) type {
    return struct {
        const Hmac = crypto.auth.hmac.Hmac(Hash);
        const Hkdf = crypto.kdf.hkdf.Hkdf(Hmac);
        const mac_length = Hmac.mac_length;

        // Running hash over all handshake messages seen so far.
        hash: Hash,
        // tls 1.3 state: filled by handshakeSecret, reused by applicationSecret
        // and the *FinishedTls13 helpers.
        handshake_secret: [Hmac.mac_length]u8 = undefined,
        server_finished_key: [Hmac.key_length]u8 = undefined,
        client_finished_key: [Hmac.key_length]u8 = undefined,

        const Self = @This();

        fn init(transcript: Hash) Self {
            // Bug fix: the field is named `hash`, not `transcript`; the old
            // `.{ .transcript = transcript }` literal could not compile once
            // `init` was referenced.
            return .{ .hash = transcript };
        }

        // ref: https://www.rfc-editor.org/rfc/rfc8446#section-4.4.3
        // 64 pad bytes (0x20) ++ context string ++ current transcript hash.
        fn serverCertificateVerify(c: *Self) [64 + 34 + Hash.digest_length]u8 {
            return ([1]u8{0x20} ** 64) ++
                "TLS 1.3, server CertificateVerify\x00".* ++
                c.hash.peek();
        }

        // ref: https://www.rfc-editor.org/rfc/rfc8446#section-4.4.3
        fn clientCertificateVerify(c: *Self) [64 + 34 + Hash.digest_length]u8 {
            return ([1]u8{0x20} ** 64) ++
                "TLS 1.3, client CertificateVerify\x00".* ++
                c.hash.peek();
        }

        // tls 1.2 PRF ("master secret" label): first two P_hash iterations,
        // yielding 2 * mac_length bytes.
        fn masterSecret(
            _: *Self,
            pre_master_secret: []const u8,
            client_random: [32]u8,
            server_random: [32]u8,
        ) [mac_length * 2]u8 {
            const seed = "master secret" ++ client_random ++ server_random;
            var a1: [mac_length]u8 = undefined;
            var a2: [mac_length]u8 = undefined;
            Hmac.create(&a1, seed, pre_master_secret);
            Hmac.create(&a2, &a1, pre_master_secret);
            var p1: [mac_length]u8 = undefined;
            var p2: [mac_length]u8 = undefined;
            Hmac.create(&p1, a1 ++ seed, pre_master_secret);
            Hmac.create(&p2, a2 ++ seed, pre_master_secret);
            return p1 ++ p2;
        }

        // tls 1.2 PRF ("key expansion" label): four P_hash iterations. Note
        // the seed order here is server_random then client_random.
        fn keyExpansion(
            _: *Self,
            master_secret: []const u8,
            client_random: [32]u8,
            server_random: [32]u8,
        ) [mac_length * 4]u8 {
            const seed = "key expansion" ++ server_random ++ client_random;
            const a0 = seed;
            var a1: [mac_length]u8 = undefined;
            var a2: [mac_length]u8 = undefined;
            var a3: [mac_length]u8 = undefined;
            var a4: [mac_length]u8 = undefined;
            Hmac.create(&a1, a0, master_secret);
            Hmac.create(&a2, &a1, master_secret);
            Hmac.create(&a3, &a2, master_secret);
            Hmac.create(&a4, &a3, master_secret);
            var key_material: [mac_length * 4]u8 = undefined;
            Hmac.create(key_material[0..mac_length], a1 ++ seed, master_secret);
            Hmac.create(key_material[mac_length .. mac_length * 2], a2 ++ seed, master_secret);
            Hmac.create(key_material[mac_length * 2 .. mac_length * 3], a3 ++ seed, master_secret);
            Hmac.create(key_material[mac_length * 3 ..], a4 ++ seed, master_secret);
            return key_material;
        }

        // tls 1.2 Finished verify_data: 12 bytes of PRF over the current
        // transcript hash with the "client finished" label.
        fn clientFinishedTls12(self: *Self, master_secret: []const u8) [12]u8 {
            const seed = "client finished" ++ self.hash.peek();
            var a1: [mac_length]u8 = undefined;
            var p1: [mac_length]u8 = undefined;
            Hmac.create(&a1, seed, master_secret);
            Hmac.create(&p1, a1 ++ seed, master_secret);
            return p1[0..12].*;
        }

        // Same as clientFinishedTls12 with the "server finished" label.
        fn serverFinishedTls12(self: *Self, master_secret: []const u8) [12]u8 {
            const seed = "server finished" ++ self.hash.peek();
            var a1: [mac_length]u8 = undefined;
            var p1: [mac_length]u8 = undefined;
            Hmac.create(&a1, seed, master_secret);
            Hmac.create(&p1, a1 ++ seed, master_secret);
            return p1[0..12].*;
        }

        // tls 1.3 key schedule (rfc 8446, section 7.1): derive the handshake
        // traffic secrets and finished keys from the (EC)DHE shared key.
        // Inline because the returned Secret slices point at locals; inlining
        // keeps those buffers in the caller's frame.
        inline fn handshakeSecret(self: *Self, shared_key: []const u8) Transcript.Secret {
            const hello_hash = self.hash.peek();
            const zeroes = [1]u8{0} ** Hash.digest_length;
            const early_secret = Hkdf.extract(&[1]u8{0}, &zeroes);
            const empty_hash = tls.emptyHash(Hash);
            const hs_derived_secret = hkdfExpandLabel(Hkdf, early_secret, "derived", &empty_hash, Hash.digest_length);
            self.handshake_secret = Hkdf.extract(&hs_derived_secret, shared_key);
            const client_secret = hkdfExpandLabel(Hkdf, self.handshake_secret, "c hs traffic", &hello_hash, Hash.digest_length);
            const server_secret = hkdfExpandLabel(Hkdf, self.handshake_secret, "s hs traffic", &hello_hash, Hash.digest_length);
            self.server_finished_key = hkdfExpandLabel(Hkdf, server_secret, "finished", "", Hmac.key_length);
            self.client_finished_key = hkdfExpandLabel(Hkdf, client_secret, "finished", "", Hmac.key_length);
            return .{ .client = &client_secret, .server = &server_secret };
        }

        // tls 1.3 application traffic secrets; must be called after
        // handshakeSecret has populated self.handshake_secret.
        inline fn applicationSecret(self: *Self) Transcript.Secret {
            const handshake_hash = self.hash.peek();
            const empty_hash = tls.emptyHash(Hash);
            const zeroes = [1]u8{0} ** Hash.digest_length;
            const ap_derived_secret = hkdfExpandLabel(Hkdf, self.handshake_secret, "derived", &empty_hash, Hash.digest_length);
            const master_secret = Hkdf.extract(&ap_derived_secret, &zeroes);
            const client_secret = hkdfExpandLabel(Hkdf, master_secret, "c ap traffic", &handshake_hash, Hash.digest_length);
            const server_secret = hkdfExpandLabel(Hkdf, master_secret, "s ap traffic", &handshake_hash, Hash.digest_length);
            return .{ .client = &client_secret, .server = &server_secret };
        }

        // tls 1.3 server Finished verify_data: HMAC of the transcript hash
        // under the server finished key, written into buf.
        fn serverFinishedTls13(self: *Self, buf: []u8) []const u8 {
            Hmac.create(buf[0..mac_length], &self.hash.peek(), &self.server_finished_key);
            return buf[0..mac_length];
        }

        // Same for the client side, using the client finished key.
        fn clientFinishedTls13(self: *Self, buf: []u8) []const u8 {
            Hmac.create(buf[0..mac_length], &self.hash.peek(), &self.client_finished_key);
            return buf[0..mac_length];
        }
    };
}

View File

@@ -30,6 +30,7 @@ const apiweb = @import("apiweb.zig");
pub const Types = jsruntime.reflect(apiweb.Interfaces); pub const Types = jsruntime.reflect(apiweb.Interfaces);
pub const UserContext = apiweb.UserContext; pub const UserContext = apiweb.UserContext;
pub const IO = @import("asyncio").Wrapper(jsruntime.Loop);
// Default options // Default options
const Host = "127.0.0.1"; const Host = "127.0.0.1";

View File

@@ -24,12 +24,13 @@ const parser = @import("netsurf");
const apiweb = @import("apiweb.zig"); const apiweb = @import("apiweb.zig");
const Window = @import("html/window.zig").Window; const Window = @import("html/window.zig").Window;
const storage = @import("storage/storage.zig"); const storage = @import("storage/storage.zig");
const Client = @import("asyncio").Client;
const html_test = @import("html_test.zig").html; const html_test = @import("html_test.zig").html;
pub const Types = jsruntime.reflect(apiweb.Interfaces); pub const Types = jsruntime.reflect(apiweb.Interfaces);
pub const UserContext = apiweb.UserContext; pub const UserContext = apiweb.UserContext;
const Client = @import("http/async/main.zig").Client; pub const IO = @import("asyncio").Wrapper(jsruntime.Loop);
var doc: *parser.DocumentHTML = undefined; var doc: *parser.DocumentHTML = undefined;

View File

@@ -50,6 +50,7 @@ const Out = enum {
pub const Types = jsruntime.reflect(apiweb.Interfaces); pub const Types = jsruntime.reflect(apiweb.Interfaces);
pub const GlobalType = apiweb.GlobalType; pub const GlobalType = apiweb.GlobalType;
pub const UserContext = apiweb.UserContext; pub const UserContext = apiweb.UserContext;
pub const IO = @import("asyncio").Wrapper(jsruntime.Loop);
// TODO For now the WPT tests run is specific to WPT. // TODO For now the WPT tests run is specific to WPT.
// It manually load js framwork libs, and run the first script w/ js content in // It manually load js framwork libs, and run the first script w/ js content in

View File

@@ -30,7 +30,7 @@ const xhr = @import("xhr/xhr.zig");
const storage = @import("storage/storage.zig"); const storage = @import("storage/storage.zig");
const url = @import("url/url.zig"); const url = @import("url/url.zig");
const urlquery = @import("url/query.zig"); const urlquery = @import("url/query.zig");
const Client = @import("http/async/main.zig").Client; const Client = @import("asyncio").Client;
const documentTestExecFn = @import("dom/document.zig").testExecFn; const documentTestExecFn = @import("dom/document.zig").testExecFn;
const HTMLDocumentTestExecFn = @import("html/document.zig").testExecFn; const HTMLDocumentTestExecFn = @import("html/document.zig").testExecFn;
@@ -59,6 +59,7 @@ const MutationObserverTestExecFn = @import("dom/mutation_observer.zig").testExec
pub const Types = jsruntime.reflect(apiweb.Interfaces); pub const Types = jsruntime.reflect(apiweb.Interfaces);
pub const UserContext = @import("user_context.zig").UserContext; pub const UserContext = @import("user_context.zig").UserContext;
pub const IO = @import("asyncio").Wrapper(jsruntime.Loop);
var doc: *parser.DocumentHTML = undefined; var doc: *parser.DocumentHTML = undefined;
@@ -298,9 +299,6 @@ test {
const msgTest = @import("msg.zig"); const msgTest = @import("msg.zig");
std.testing.refAllDecls(msgTest); std.testing.refAllDecls(msgTest);
std.testing.refAllDecls(@import("http/async/std/http.zig"));
std.testing.refAllDecls(@import("http/async/stack.zig"));
const dumpTest = @import("browser/dump.zig"); const dumpTest = @import("browser/dump.zig");
std.testing.refAllDecls(dumpTest); std.testing.refAllDecls(dumpTest);

View File

@@ -22,6 +22,7 @@ const tests = @import("run_tests.zig");
pub const Types = tests.Types; pub const Types = tests.Types;
pub const UserContext = tests.UserContext; pub const UserContext = tests.UserContext;
pub const IO = tests.IO;
pub fn main() !void { pub fn main() !void {
try tests.main(); try tests.main();

View File

@@ -1,6 +1,6 @@
const std = @import("std"); const std = @import("std");
const parser = @import("netsurf"); const parser = @import("netsurf");
const Client = @import("http/async/main.zig").Client; const Client = @import("asyncio").Client;
pub const UserContext = struct { pub const UserContext = struct {
document: *parser.DocumentHTML, document: *parser.DocumentHTML,

View File

@@ -28,10 +28,10 @@ const Loop = jsruntime.Loop;
const Env = jsruntime.Env; const Env = jsruntime.Env;
const Window = @import("../html/window.zig").Window; const Window = @import("../html/window.zig").Window;
const storage = @import("../storage/storage.zig"); const storage = @import("../storage/storage.zig");
const Client = @import("asyncio").Client;
const Types = @import("../main_wpt.zig").Types; const Types = @import("../main_wpt.zig").Types;
const UserContext = @import("../main_wpt.zig").UserContext; const UserContext = @import("../main_wpt.zig").UserContext;
const Client = @import("../http/async/main.zig").Client;
// runWPT parses the given HTML file, starts a js env and run the first script // runWPT parses the given HTML file, starts a js env and run the first script
// tags containing javascript sources. // tags containing javascript sources.

View File

@@ -32,7 +32,7 @@ const XMLHttpRequestEventTarget = @import("event_target.zig").XMLHttpRequestEven
const Mime = @import("../browser/mime.zig"); const Mime = @import("../browser/mime.zig");
const Loop = jsruntime.Loop; const Loop = jsruntime.Loop;
const Client = @import("../http/async/main.zig").Client; const Client = @import("asyncio").Client;
const parser = @import("netsurf"); const parser = @import("netsurf");
@@ -97,7 +97,7 @@ pub const XMLHttpRequest = struct {
proto: XMLHttpRequestEventTarget = XMLHttpRequestEventTarget{}, proto: XMLHttpRequestEventTarget = XMLHttpRequestEventTarget{},
alloc: std.mem.Allocator, alloc: std.mem.Allocator,
cli: *Client, cli: *Client,
loop: Client.Loop, io: Client.IO,
priv_state: PrivState = .new, priv_state: PrivState = .new,
req: ?Client.Request = null, req: ?Client.Request = null,
@@ -294,7 +294,7 @@ pub const XMLHttpRequest = struct {
.alloc = alloc, .alloc = alloc,
.headers = Headers.init(alloc), .headers = Headers.init(alloc),
.response_headers = Headers.init(alloc), .response_headers = Headers.init(alloc),
.loop = Client.Loop.init(loop), .io = Client.IO.init(loop),
.method = undefined, .method = undefined,
.url = null, .url = null,
.uri = undefined, .uri = undefined,
@@ -513,7 +513,7 @@ pub const XMLHttpRequest = struct {
self.req = null; self.req = null;
} }
self.ctx = try Client.Ctx.init(&self.loop, &self.req.?); self.ctx = try Client.Ctx.init(&self.io, &self.req.?);
errdefer { errdefer {
self.ctx.?.deinit(); self.ctx.?.deinit();
self.ctx = null; self.ctx = null;

1
vendor/zig-async-io vendored Submodule

Submodule vendor/zig-async-io added at d996742c00