// Mirror of https://github.com/lightpanda-io/browser.git
// (synced 2025-10-29 15:13:28 +00:00; 521 lines, 23 KiB, Zig).
const std = @import("std");
|
|
const assert = std.debug.assert;
|
|
const crypto = std.crypto;
|
|
const mem = std.mem;
|
|
const Certificate = crypto.Certificate;
|
|
|
|
const cipher = @import("cipher.zig");
|
|
const Cipher = cipher.Cipher;
|
|
const CipherSuite = @import("cipher.zig").CipherSuite;
|
|
const cipher_suites = @import("cipher.zig").cipher_suites;
|
|
const Transcript = @import("transcript.zig").Transcript;
|
|
const record = @import("record.zig");
|
|
const PrivateKey = @import("PrivateKey.zig");
|
|
const proto = @import("protocol.zig");
|
|
|
|
const common = @import("handshake_common.zig");
|
|
const dupe = common.dupe;
|
|
const CertificateBuilder = common.CertificateBuilder;
|
|
const CertificateParser = common.CertificateParser;
|
|
const DhKeyPair = common.DhKeyPair;
|
|
const CertBundle = common.CertBundle;
|
|
const CertKeyPair = common.CertKeyPair;
|
|
|
|
/// Configuration for the server side of the handshake.
pub const Options = struct {
    /// Certificate and key the server authenticates with. When null, the
    /// server sends neither a Certificate nor a CertificateVerify message.
    auth: ?CertKeyPair,

    /// When set, the server requests a client certificate and verifies it
    /// against the `root_ca` certificates. With `auth_type == .request` an
    /// empty client Certificate message is accepted.
    client_auth: ?ClientAuth = null,
};
|
|
|
|
/// Controls how the server requests and checks a client certificate.
pub const ClientAuth = struct {
    /// Root certificate authorities the server uses to verify client
    /// certificates.
    root_ca: CertBundle,

    /// Whether the client must present a certificate. Defaults to `.require`.
    auth_type: Type = .require,

    pub const Type = enum {
        /// A client certificate is requested during the handshake, but the
        /// client is allowed to answer without sending one.
        request,
        /// A client certificate is requested during the handshake and the
        /// client has to send a valid one.
        require,
    };
};
|
|
|
|
pub fn Handshake(comptime Stream: type) type {
|
|
const RecordReaderT = record.Reader(Stream);
|
|
return struct {
|
|
// public key len: x25519 = 32, secp256r1 = 65, secp384r1 = 97
|
|
const max_pub_key_len = 98;
|
|
const supported_named_groups = &[_]proto.NamedGroup{ .x25519, .secp256r1, .secp384r1 };
|
|
|
|
server_random: [32]u8 = undefined,
|
|
client_random: [32]u8 = undefined,
|
|
legacy_session_id_buf: [32]u8 = undefined,
|
|
legacy_session_id: []u8 = "",
|
|
cipher_suite: CipherSuite = @enumFromInt(0),
|
|
signature_scheme: proto.SignatureScheme = @enumFromInt(0),
|
|
named_group: proto.NamedGroup = @enumFromInt(0),
|
|
client_pub_key_buf: [max_pub_key_len]u8 = undefined,
|
|
client_pub_key: []u8 = "",
|
|
server_pub_key_buf: [max_pub_key_len]u8 = undefined,
|
|
server_pub_key: []u8 = "",
|
|
|
|
cipher: Cipher = undefined,
|
|
transcript: Transcript = .{},
|
|
rec_rdr: *RecordReaderT,
|
|
buffer: []u8,
|
|
|
|
const HandshakeT = @This();
|
|
|
|
/// Creates handshake state that reads records through `rec_rdr` and uses
/// `buf` as its working buffer (server flight assembly, alert encryption and
/// decrypted client handshake messages). Both are borrowed, not copied.
pub fn init(buf: []u8, rec_rdr: *RecordReaderT) HandshakeT {
    return HandshakeT{
        .buffer = buf,
        .rec_rdr = rec_rdr,
    };
}
|
|
|
|
/// Sends the alert that corresponds to `err` to the peer.
///
/// With a cipher the alert is encrypted into `h.buffer` first; without one it
/// is sent as a plaintext alert record. Failures of the stream write itself
/// are deliberately ignored; only an encryption failure is returned.
fn writeAlert(h: *HandshakeT, stream: Stream, cph: ?*Cipher, err: anyerror) !void {
    const alert_body = proto.alertFromError(err);

    const c = cph orelse {
        // No keys yet: plaintext alert record with a 2-byte payload.
        const plain_record = record.header(.alert, 2) ++ alert_body;
        stream.writeAll(&plain_record) catch {};
        return;
    };

    const encrypted_record = try c.encrypt(h.buffer, .alert, &alert_body);
    stream.writeAll(encrypted_record) catch {};
}
|
|
|
|
/// Runs the server side of the TLS 1.3 handshake over `stream`.
///
/// Steps, in order:
///   1. read the ClientHello,
///   2. build the whole server flight in `h.buffer` (ServerHello,
///      ChangeCipherSpec, EncryptedExtensions, CertificateRequest when
///      `opt.client_auth` is set, Certificate + CertificateVerify when
///      `opt.auth` is set, Finished) and send it with a single `writeAll`,
///   3. read the client's second flight up to its Finished message.
///
/// Returns the application traffic cipher on success.
///
/// When reading the ClientHello, deriving the shared key or reading the
/// client's second flight fails, an alert derived from the error is sent on a
/// best-effort basis and the *original* error is returned. No alert is sent
/// back when the error is itself an alert received from the client (error
/// names starting with "TlsAlert").
pub fn handshake(h: *HandshakeT, stream: Stream, opt: Options) !Cipher {
    crypto.random.bytes(&h.server_random);
    if (opt.auth) |a| {
        // The ClientHello must offer this scheme in signature_algorithms.
        h.signature_scheme = a.key.signature_scheme;
    }

    h.readClientHello() catch |err| {
        // Best effort: never let an alert failure hide the real cause.
        h.writeAlert(stream, null, err) catch {};
        return err;
    };
    h.transcript.use(h.cipher_suite.hash());

    const server_flight = brk: {
        var w = record.Writer{ .buf = h.buffer };

        const shared_key = h.sharedKey() catch |err| {
            h.writeAlert(stream, null, err) catch {};
            return err;
        };
        {
            // ServerHello goes out in plaintext; the transcript covers the
            // handshake message without the record header.
            const hello = try h.makeServerHello(w.getFree());
            h.transcript.update(hello[record.header_len..]);
            w.pos += hello.len;
        }
        {
            // Everything after ServerHello is protected with handshake keys.
            const handshake_secret = h.transcript.handshakeSecret(shared_key);
            h.cipher = try Cipher.initTls13(h.cipher_suite, handshake_secret, .server);
        }
        try w.writeRecord(.change_cipher_spec, &[_]u8{1});
        {
            // EncryptedExtensions with an empty extension list.
            const encrypted_extensions = &record.handshakeHeader(.encrypted_extensions, 2) ++ [_]u8{ 0, 0 };
            h.transcript.update(encrypted_extensions);
            try h.writeEncrypted(&w, encrypted_extensions);
        }
        if (opt.client_auth) |_| {
            const certificate_request = try makeCertificateRequest(w.getPayload());
            h.transcript.update(certificate_request);
            try h.writeEncrypted(&w, certificate_request);
        }
        if (opt.auth) |a| {
            const cm = CertificateBuilder{
                .bundle = a.bundle,
                .key = a.key,
                .transcript = &h.transcript,
                .side = .server,
            };
            {
                const certificate = try cm.makeCertificate(w.getPayload());
                h.transcript.update(certificate);
                try h.writeEncrypted(&w, certificate);
            }
            {
                const certificate_verify = try cm.makeCertificateVerify(w.getPayload());
                h.transcript.update(certificate_verify);
                try h.writeEncrypted(&w, certificate_verify);
            }
        }
        {
            const finished = try h.makeFinished(w.getPayload());
            h.transcript.update(finished);
            try h.writeEncrypted(&w, finished);
        }
        break :brk w.getWritten();
    };
    try stream.writeAll(server_flight);

    // Derived after the server Finished has been added to the transcript.
    var app_cipher = brk: {
        const application_secret = h.transcript.applicationSecret();
        break :brk try Cipher.initTls13(h.cipher_suite, application_secret, .server);
    };

    h.readClientFlight2(opt) catch |err| {
        // Errors named "TlsAlert..." are alerts received from the client;
        // those are not answered with another alert.
        if (!mem.startsWith(u8, @errorName(err), "TlsAlert")) {
            // Best effort: if encrypting the alert fails, still report the
            // handshake error rather than the alert error.
            h.writeAlert(stream, &app_cipher, err) catch {};
        }
        return err;
    };
    return app_cipher;
}
|
|
|
|
/// Generates a fresh key pair from a random seed, stores the server public
/// key for `h.named_group` in `h.server_pub_key`, and returns the secret
/// shared with the client's public key.
///
/// NOTE(review): presumably `inline` so that the returned slice, which may
/// refer to storage local to this function, stays valid in the caller —
/// confirm against `DhKeyPair.sharedKey` before removing `inline`.
inline fn sharedKey(h: *HandshakeT) ![]const u8 {
    var random_seed: [DhKeyPair.seed_len]u8 = undefined;
    crypto.random.bytes(&random_seed);

    var key_pair = try DhKeyPair.init(random_seed, supported_named_groups);

    // Keep a copy of our public key; it is sent in the ServerHello key share.
    h.server_pub_key = dupe(&h.server_pub_key_buf, try key_pair.publicKey(h.named_group));

    return try key_pair.sharedKey(h.named_group, h.client_pub_key);
}
|
|
|
|
/// Reads the client's second flight until its Finished message is verified.
///
/// Expected handshake messages: Certificate (only when `opt.client_auth` is
/// set), CertificateVerify (only after a non-empty Certificate) and Finished.
/// Records are handled as follows:
///   - change_cipher_spec: must have a 1-byte payload, otherwise ignored,
///   - application_data: decrypted with the handshake cipher into `h.buffer`
///     and parsed as handshake messages,
///   - alert: converted into an error via the record decoder,
///   - anything else: `error.TlsUnexpectedMessage`.
///
/// A handshake message may be split across records; undecoded bytes stay in
/// the buffer and the next record is appended behind them. This also covers
/// a record boundary that falls inside the 4-byte handshake header.
///
/// Every completely received handshake message is added to the transcript
/// after it has been processed, so verification of CertificateVerify and
/// Finished uses the transcript up to, but excluding, that message.
fn readClientFlight2(h: *HandshakeT, opt: Options) !void {
    // Handshake message header: type (1 byte) + length (3 bytes).
    const handshake_header_len = 4;

    const cleartext_buf = h.buffer;
    // [head..tail) is decrypted handshake data that is not yet consumed.
    var cleartext_buf_head: usize = 0;
    var cleartext_buf_tail: usize = 0;
    var handshake_state: proto.Handshake = .finished;
    var cert: CertificateParser = undefined;
    if (opt.client_auth) |client_auth| {
        cert = .{ .root_ca = client_auth.root_ca.bundle, .host = "" };
        handshake_state = .certificate;
    }

    outer: while (true) {
        const rec = (try h.rec_rdr.next() orelse return error.EndOfStream);
        if (rec.protocol_version != .tls_1_2 and rec.content_type != .alert)
            return error.TlsProtocolVersion;

        switch (rec.content_type) {
            .change_cipher_spec => {
                if (rec.payload.len != 1) return error.TlsUnexpectedMessage;
            },
            .application_data => {
                const content_type, const cleartext = try h.cipher.decrypt(
                    cleartext_buf[cleartext_buf_tail..],
                    rec,
                );
                cleartext_buf_tail += cleartext.len;
                if (cleartext_buf_tail > cleartext_buf.len) return error.TlsRecordOverflow;

                var d = record.Decoder.init(content_type, cleartext_buf[cleartext_buf_head..cleartext_buf_tail]);
                try d.expectContentType(.handshake);
                while (!d.eof()) {
                    // The record ended inside the handshake header; wait for
                    // the next record instead of failing to decode it.
                    if (d.rest().len < handshake_header_len)
                        continue :outer;

                    const start_idx = d.idx;
                    const handshake_type = try d.decode(proto.Handshake);
                    const length = try d.decode(u24);

                    if (length > cipher.max_cleartext_len)
                        return error.TlsRecordOverflow;
                    if (length > d.rest().len)
                        continue :outer; // fragmented handshake into multiple records

                    // Runs when this iteration ends (also on error/return):
                    // the message joins the transcript only after it has been
                    // processed and is then dropped from the pending range.
                    defer {
                        const handshake_payload = d.payload[start_idx..d.idx];
                        h.transcript.update(handshake_payload);
                        cleartext_buf_head += handshake_payload.len;
                    }

                    if (handshake_state != handshake_type)
                        return error.TlsUnexpectedMessage;

                    switch (handshake_type) {
                        .certificate => {
                            if (length == 4) {
                                // got empty certificate message
                                if (opt.client_auth.?.auth_type == .require)
                                    return error.TlsCertificateRequired;
                                try d.skip(length);
                                handshake_state = .finished;
                            } else {
                                try cert.parseCertificate(&d, .tls_1_3);
                                handshake_state = .certificate_verify;
                            }
                        },
                        .certificate_verify => {
                            try cert.parseCertificateVerify(&d);
                            cert.verifySignature(h.transcript.clientCertificateVerify()) catch |err| return switch (err) {
                                error.TlsUnknownSignatureScheme => error.TlsIllegalParameter,
                                else => error.TlsDecryptError,
                            };
                            handshake_state = .finished;
                        },
                        .finished => {
                            const actual = try d.slice(length);
                            var buf: [Transcript.max_mac_length]u8 = undefined;
                            const expected = h.transcript.clientFinishedTls13(&buf);
                            if (!mem.eql(u8, expected, actual))
                                return if (expected.len == actual.len)
                                    error.TlsDecryptError
                                else
                                    error.TlsDecodeError;
                            return;
                        },
                        else => return error.TlsUnexpectedMessage,
                    }
                }
                // Everything decoded; start the next record at the buffer start.
                cleartext_buf_head = 0;
                cleartext_buf_tail = 0;
            },
            .alert => {
                var d = rec.decoder();
                return d.raiseAlert();
            },
            else => return error.TlsUnexpectedMessage,
        }
    }
}
|
|
|
|
/// Builds the server Finished handshake message in `buf` and returns the
/// written bytes. The verify data comes from the current transcript.
fn makeFinished(h: *HandshakeT, buf: []u8) ![]const u8 {
    var msg = record.Writer{ .buf = buf };

    // Verify data is produced directly in the writer's handshake payload
    // area; the handshake header is filled in afterwards.
    const mac = h.transcript.serverFinishedTls13(msg.getHandshakePayload());
    try msg.advanceHandshake(.finished, mac.len);

    return msg.getWritten();
}
|
|
|
|
/// Encrypts `cleartext` as a handshake record with the handshake cipher into
/// the free space of `w` and advances `w` past the produced ciphertext.
fn writeEncrypted(h: *HandshakeT, w: *record.Writer, cleartext: []const u8) !void {
    const free_space = w.getFree();
    const encrypted = try h.cipher.encrypt(free_space, .handshake, cleartext);
    w.pos += encrypted.len;
}
|
|
|
|
/// Writes a complete ServerHello (record header, handshake header and body)
/// into `buf` and returns the written bytes.
///
/// The body echoes the client's legacy session id, announces the selected
/// cipher suite and carries two extensions: supported_versions (TLS 1.3) and
/// key_share (selected named group with the server public key).
fn makeServerHello(h: *HandshakeT, buf: []u8) ![]const u8 {
    // TLS record header (5 bytes) followed by the handshake header (4 bytes).
    const header_len = 9;

    // Fixed part of the body, written right behind the space left for headers.
    var body = record.Writer{ .buf = buf[header_len..] };
    try body.writeEnum(proto.Version.tls_1_2);
    try body.write(&h.server_random);

    const session_id = h.legacy_session_id;
    try body.writeInt(@as(u8, @intCast(session_id.len)));
    if (session_id.len > 0) try body.write(session_id);

    try body.writeEnum(h.cipher_suite);
    try body.write(&[_]u8{0}); // compression method

    // Extensions are written first, 2 bytes further on; those 2 bytes receive
    // the extensions length once it is known.
    const extensions_start = header_len + body.pos + 2;
    var exts = record.Writer{ .buf = buf[extensions_start..] };

    // supported versions extension
    try exts.writeEnum(proto.Extension.supported_versions);
    try exts.writeInt(@as(u16, 2));
    try exts.writeEnum(proto.Version.tls_1_3);

    // key share extension: named group (2) + key length (2) + key
    const key_len: u16 = @intCast(h.server_pub_key.len);
    try exts.writeEnum(proto.Extension.key_share);
    try exts.writeInt(key_len + 4);
    try exts.writeEnum(h.named_group);
    try exts.writeInt(key_len);
    try exts.write(h.server_pub_key);

    try body.writeInt(@as(u16, @intCast(exts.pos))); // extensions length

    // Both headers can only be written now that the payload size is known.
    const payload_len = body.pos + exts.pos;
    buf[0..header_len].* = record.header(.handshake, 4 + payload_len) ++
        record.handshakeHeader(.server_hello, payload_len);

    return buf[0 .. header_len + payload_len];
}
|
|
|
|
/// Builds a CertificateRequest handshake message in `buf` and returns the
/// written bytes. The request has an empty context and a single
/// signature_algorithms extension listing
/// `common.supported_signature_algorithms`.
fn makeCertificateRequest(buf: []u8) ![]const u8 {
    const handshake_header_len = 4;
    const context_len_size = 1;
    const extensions_len_size = 2;
    const prefix_len = handshake_header_len + context_len_size + extensions_len_size;

    // The extension goes in first, behind the space reserved for the prefix,
    // because the prefix needs the extension's size.
    var extensions = record.Writer{ .buf = buf[prefix_len..] };
    try extensions.writeExtension(.signature_algorithms, common.supported_signature_algorithms);

    var msg = record.Writer{ .buf = buf };
    try msg.writeHandshakeHeader(.certificate_request, extensions.pos + context_len_size + extensions_len_size);
    try msg.writeInt(@as(u8, 0)); // certificate request context length = 0
    try msg.writeInt(@as(u16, @intCast(extensions.pos))); // extensions length
    assert(msg.pos == prefix_len);

    // The extension bytes are already in place; just account for them.
    msg.pos += extensions.pos;

    return msg.getWritten();
}
|
|
|
|
/// Reads and validates the ClientHello record and stores the negotiated
/// parameters in `h`: client random, legacy session id, cipher suite, named
/// group and the client's public key for that group.
///
/// Selection rules:
///   - cipher suite: the first TLS 1.3 suite offered by the client,
///   - named group: among the client's key shares, the group that comes
///     first in `supported_named_groups`,
///   - signature scheme: when `h.signature_scheme` is set, the client's
///     signature_algorithms extension (if sent) must contain it.
///
/// Errors:
///   - `TlsUnexpectedMessage`: the first message is not a ClientHello,
///   - `TlsProtocolVersion`: wrong legacy version, or TLS 1.3 is not offered
///     through a supported_versions extension,
///   - `TlsDecodeError`: session id longer than 32 bytes, empty key_share
///     extension, or empty signature scheme list,
///   - `TlsIllegalParameter`: compression methods other than the single null
///     method, a forbidden named group, or no usable key share,
///   - `TlsHandshakeFailure`: no common cipher suite or signature scheme,
///   - `TlsMissingExtension`: no key_share extension.
fn readClientHello(h: *HandshakeT) !void {
    var d = try h.rec_rdr.nextDecoder();
    try d.expectContentType(.handshake);
    h.transcript.update(d.payload);

    const handshake_type = try d.decode(proto.Handshake);
    if (handshake_type != .client_hello) return error.TlsUnexpectedMessage;
    _ = try d.decode(u24); // handshake length
    if (try d.decode(proto.Version) != .tls_1_2) return error.TlsProtocolVersion;

    h.client_random = try d.array(32);
    { // legacy session id
        const len = try d.decode(u8);
        // The field is defined as at most 32 bytes; anything longer cannot be
        // stored and echoed back unchanged.
        if (len > h.legacy_session_id_buf.len) return error.TlsDecodeError;
        h.legacy_session_id = dupe(&h.legacy_session_id_buf, try d.slice(len));
    }
    { // cipher suites
        const end_idx = try d.decode(u16) + d.idx;

        while (d.idx < end_idx) {
            const cipher_suite = try d.decode(CipherSuite);
            // Keep the first supported suite, but still decode the rest.
            if (cipher_suites.includes(cipher_suites.tls13, cipher_suite) and
                @intFromEnum(h.cipher_suite) == 0)
            {
                h.cipher_suite = cipher_suite;
            }
        }
        if (@intFromEnum(h.cipher_suite) == 0)
            return error.TlsHandshakeFailure;
    }
    { // compression methods
        // A TLS 1.3 ClientHello carries exactly one method: null (0).
        // Validating instead of skipping two bytes keeps the extension
        // parsing below aligned.
        const len = try d.decode(u8);
        if (len != 1) return error.TlsIllegalParameter;
        if (try d.decode(u8) != 0) return error.TlsIllegalParameter;
    }

    var supported_versions_received = false;
    var key_share_received = false;
    // extensions
    const extensions_end_idx = try d.decode(u16) + d.idx;
    while (d.idx < extensions_end_idx) {
        const extension_type = try d.decode(proto.Extension);
        const extension_len = try d.decode(u16);

        switch (extension_type) {
            .supported_versions => {
                var tls_1_3_supported = false;
                const end_idx = try d.decode(u8) + d.idx;
                while (d.idx < end_idx) {
                    if (try d.decode(proto.Version) == proto.Version.tls_1_3) {
                        tls_1_3_supported = true;
                    }
                }
                if (!tls_1_3_supported) return error.TlsProtocolVersion;
                supported_versions_received = true;
            },
            .key_share => {
                if (extension_len == 0) return error.TlsDecodeError;
                key_share_received = true;
                // Index into supported_named_groups of the best group so far;
                // lower index means higher server preference.
                var selected_named_group_idx = supported_named_groups.len;
                const end_idx = try d.decode(u16) + d.idx;
                while (d.idx < end_idx) {
                    const named_group = try d.decode(proto.NamedGroup);
                    switch (@intFromEnum(named_group)) {
                        0x0001...0x0016,
                        0x001a...0x001c,
                        0xff01...0xff02,
                        => return error.TlsIllegalParameter,
                        else => {},
                    }
                    const client_pub_key = try d.slice(try d.decode(u16));
                    for (supported_named_groups, 0..) |supported, idx| {
                        if (named_group == supported and idx < selected_named_group_idx) {
                            h.named_group = named_group;
                            h.client_pub_key = dupe(&h.client_pub_key_buf, client_pub_key);
                            selected_named_group_idx = idx;
                        }
                    }
                }
                if (@intFromEnum(h.named_group) == 0)
                    return error.TlsIllegalParameter;
            },
            .supported_groups => {
                // Only checked for forbidden values; selection is driven by
                // the key shares.
                const end_idx = try d.decode(u16) + d.idx;
                while (d.idx < end_idx) {
                    const named_group = try d.decode(proto.NamedGroup);
                    switch (@intFromEnum(named_group)) {
                        0x0001...0x0016,
                        0x001a...0x001c,
                        0xff01...0xff02,
                        => return error.TlsIllegalParameter,
                        else => {},
                    }
                }
            },
            .signature_algorithms => {
                if (@intFromEnum(h.signature_scheme) == 0) {
                    // No scheme required (value 0 = not set).
                    try d.skip(extension_len);
                } else {
                    var found = false;
                    const list_len = try d.decode(u16);
                    if (list_len == 0) return error.TlsDecodeError;
                    const end_idx = list_len + d.idx;
                    while (d.idx < end_idx) {
                        const signature_scheme = try d.decode(proto.SignatureScheme);
                        if (signature_scheme == h.signature_scheme) found = true;
                    }
                    if (!found) return error.TlsHandshakeFailure;
                }
            },
            else => {
                try d.skip(extension_len);
            },
        }
    }
    // Without supported_versions the client is not offering TLS 1.3, which is
    // the only version this handshake implements.
    if (!supported_versions_received) return error.TlsProtocolVersion;
    if (!key_share_received) return error.TlsMissingExtension;
    if (@intFromEnum(h.named_group) == 0) return error.TlsIllegalParameter;
}
|
|
};
|
|
}
|
|
|
|
const testing = std.testing;
|
|
const data13 = @import("testdata/tls13.zig");
|
|
const testu = @import("testu.zig");
|
|
|
|
/// Test helper: record reader over an in-memory byte slice.
fn testReader(data: []const u8) record.Reader(std.io.FixedBufferStream([]const u8)) {
    const stream = std.io.fixedBufferStream(data);
    return record.reader(stream);
}

/// Handshake instantiated over an in-memory stream, used by the tests below.
const TestHandshake = Handshake(std.io.FixedBufferStream([]const u8));
|
|
|
|
test "read client hello" {
    var scratch: [1024]u8 = undefined;
    var reader = testReader(&data13.client_hello);
    var hs = TestHandshake.init(&scratch, &reader);

    // The client hello must offer this scheme in its signature_algorithms
    // extension, otherwise parsing fails.
    hs.signature_scheme = .ecdsa_secp521r1_sha512;
    try hs.readClientHello();

    // Negotiated parameters picked up from the hello.
    try testing.expectEqual(CipherSuite.AES_256_GCM_SHA384, hs.cipher_suite);
    try testing.expectEqual(.x25519, hs.named_group);
    try testing.expectEqualSlices(u8, &data13.client_random, &hs.client_random);
    try testing.expectEqualSlices(u8, &data13.client_public_key, hs.client_pub_key);
}
|
|
|
|
test "make server hello" {
    var buffer: [128]u8 = undefined;
    var hs = TestHandshake.init(&buffer, undefined);

    // Deterministic inputs: random = 00..1f, public key = 20..3f.
    hs.cipher_suite = .AES_256_GCM_SHA384;
    testu.fillFrom(&hs.server_random, 0);
    testu.fillFrom(&hs.server_pub_key_buf, 0x20);
    hs.named_group = .x25519;
    hs.server_pub_key = hs.server_pub_key_buf[0..32];

    const actual = try hs.makeServerHello(&buffer);

    // Rows: record + handshake header, legacy version, server random,
    // session id length, cipher suite + compression, extensions length +
    // supported_versions, key_share header, server public key.
    const expected = &testu.hexToBytes(
        \\ 16 03 03 00 5a 02 00 00 56
        \\ 03 03
        \\ 00 01 02 03 04 05 06 07 08 09 0a 0b 0c 0d 0e 0f 10 11 12 13 14 15 16 17 18 19 1a 1b 1c 1d 1e 1f
        \\ 00
        \\ 13 02 00
        \\ 00 2e 00 2b 00 02 03 04
        \\ 00 33 00 24 00 1d 00 20
        \\ 20 21 22 23 24 25 26 27 28 29 2a 2b 2c 2d 2e 2f 30 31 32 33 34 35 36 37 38 39 3a 3b 3c 3d 3e 3f
    );
    try testing.expectEqualSlices(u8, expected, actual);
}
|
|
|
|
test "make certificate request" {
    const expected = testu.hexToBytes("0d 00 00 1b" ++ // handshake header: certificate_request, 27 bytes
        "00 00 18" ++ // context length (0) + extensions length (24)
        "00 0d" ++ // signature algorithms extension
        "00 14" ++ // extension length
        "00 12" ++ // list length: 9 schemes * 2 bytes
        "04 03 05 03 08 04 08 05 08 06 08 07 02 01 04 01 05 01" // signature schemes
    );

    var scratch: [32]u8 = undefined;
    const actual = try TestHandshake.makeCertificateRequest(&scratch);

    try testing.expectEqualSlices(u8, &expected, actual);
}
|