Files
browser/src/http/async/tls.zig/handshake_server.zig
2024-11-18 17:39:37 +01:00

521 lines
23 KiB
Zig

const std = @import("std");
const assert = std.debug.assert;
const crypto = std.crypto;
const mem = std.mem;
const Certificate = crypto.Certificate;
const cipher = @import("cipher.zig");
const Cipher = cipher.Cipher;
const CipherSuite = @import("cipher.zig").CipherSuite;
const cipher_suites = @import("cipher.zig").cipher_suites;
const Transcript = @import("transcript.zig").Transcript;
const record = @import("record.zig");
const PrivateKey = @import("PrivateKey.zig");
const proto = @import("protocol.zig");
const common = @import("handshake_common.zig");
const dupe = common.dupe;
const CertificateBuilder = common.CertificateBuilder;
const CertificateParser = common.CertificateParser;
const DhKeyPair = common.DhKeyPair;
const CertBundle = common.CertBundle;
const CertKeyPair = common.CertKeyPair;
/// Configuration for the server side of the handshake.
pub const Options = struct {
    /// Certificate chain and private key the server authenticates with.
    /// When null the server sends neither a Certificate nor a
    /// CertificateVerify message.
    auth: ?CertKeyPair,

    /// Client certificate authentication. When set, the server asks the
    /// client for a certificate and verifies it against `root_ca`. With
    /// `auth_type == .request` an empty client Certificate message is
    /// accepted as well.
    client_auth: ?ClientAuth = null,
};
/// Settings for requesting and verifying a client certificate.
pub const ClientAuth = struct {
    /// How strictly the client certificate is demanded.
    pub const Type = enum {
        /// Ask for a client certificate during the handshake, but accept a
        /// client that does not send one.
        request,
        /// Ask for a client certificate during the handshake; the client
        /// must answer with a valid certificate.
        require,
    };

    /// Root certificate authorities the server uses when verifying client
    /// certificates.
    root_ca: CertBundle,

    /// Defaults to the strict mode.
    auth_type: Type = Type.require,
};
pub fn Handshake(comptime Stream: type) type {
const RecordReaderT = record.Reader(Stream);
return struct {
// public key len: x25519 = 32, secp256r1 = 65, secp384r1 = 97
const max_pub_key_len = 98;
// Key exchange groups the server accepts, in order of preference: when the
// client offers several key shares, readClientHello keeps the one that
// appears earliest in this list.
const supported_named_groups = &[_]proto.NamedGroup{ .x25519, .secp256r1, .secp384r1 };
/// Filled with random bytes at the start of `handshake`; written into the
/// server hello.
server_random: [32]u8 = undefined,
/// Random value read from the client hello.
client_random: [32]u8 = undefined,
/// Backing storage for `legacy_session_id`.
legacy_session_id_buf: [32]u8 = undefined,
/// Session id copied from the client hello and written back in the server
/// hello.
legacy_session_id: []u8 = "",
/// Cipher suite chosen while reading the client hello. The value 0 is used
/// as "nothing chosen yet".
cipher_suite: CipherSuite = @enumFromInt(0),
/// Signature scheme of the server key (set from `Options.auth`). The value 0
/// means no scheme is required from the client's signature_algorithms.
signature_scheme: proto.SignatureScheme = @enumFromInt(0),
/// Key exchange group chosen from the client's key shares. The value 0 is
/// used as "nothing chosen yet".
named_group: proto.NamedGroup = @enumFromInt(0),
/// Backing storage for `client_pub_key`.
client_pub_key_buf: [max_pub_key_len]u8 = undefined,
/// Client key share for `named_group`, copied from the client hello.
client_pub_key: []u8 = "",
/// Backing storage for `server_pub_key`.
server_pub_key_buf: [max_pub_key_len]u8 = undefined,
/// Server public key for `named_group`, set in `sharedKey` and written into
/// the server hello key share.
server_pub_key: []u8 = "",
/// Cipher initialized from the handshake secret; it encrypts the server
/// handshake messages and decrypts the client's second flight.
cipher: Cipher = undefined,
/// Running transcript of the handshake messages.
transcript: Transcript = .{},
/// Source of incoming TLS records.
rec_rdr: *RecordReaderT,
/// Caller-provided scratch space: the server flight and alerts are built in
/// it, and the client's second flight is decrypted into it.
buffer: []u8,
// Shorthand for the struct type returned by `Handshake`.
const HandshakeT = @This();
/// Creates a handshake that uses `buf` as its scratch buffer and reads
/// incoming records from `rec_rdr`. All other fields keep their defaults.
/// Neither argument is copied, so both must outlive the returned value.
pub fn init(buf: []u8, rec_rdr: *RecordReaderT) HandshakeT {
    return HandshakeT{
        .buffer = buf,
        .rec_rdr = rec_rdr,
    };
}
/// Sends the alert that corresponds to `err` to the peer, best effort.
///
/// With `cph` set the alert is encrypted with that cipher (using `h.buffer`
/// as output space); otherwise it is sent as a cleartext alert record.
///
/// Nothing that goes wrong here is reported: the caller is already on an
/// error path and must return the original handshake error, so neither a
/// failed encryption nor a failed stream write may replace it. The error
/// union in the signature is kept so existing `try h.writeAlert(...)` call
/// sites stay valid.
fn writeAlert(h: *HandshakeT, stream: Stream, cph: ?*Cipher, err: anyerror) !void {
    if (cph) |c| {
        const cleartext = proto.alertFromError(err);
        // Give up silently if the alert cannot be encrypted.
        const ciphertext = c.encrypt(h.buffer, .alert, &cleartext) catch return;
        stream.writeAll(ciphertext) catch {};
    } else {
        const alert = record.header(.alert, 2) ++ proto.alertFromError(err);
        stream.writeAll(&alert) catch {};
    }
}
/// Performs the server side of the handshake on `stream` and returns the
/// cipher to use for application data.
///
/// Sequence:
///   1. read the client hello,
///   2. build the server flight in `h.buffer` (server hello,
///      change_cipher_spec, encrypted extensions, optional certificate
///      request, optional certificate + certificate verify, finished) and
///      write it to the stream in one call,
///   3. read the client's second flight.
///
/// When a step fails an alert is sent to the client before the error is
/// returned, except when the failure is itself an alert received from the
/// client.
pub fn handshake(h: *HandshakeT, stream: Stream, opt: Options) !Cipher {
    crypto.random.bytes(&h.server_random);
    // The scheme of our key must show up in the client hello's
    // signature_algorithms; readClientHello checks this.
    if (opt.auth) |auth| h.signature_scheme = auth.key.signature_scheme;

    h.readClientHello() catch |err| {
        try h.writeAlert(stream, null, err);
        return err;
    };
    h.transcript.use(h.cipher_suite.hash());

    const server_flight = flight: {
        var writer = record.Writer{ .buf = h.buffer };

        // Also stores the server public key that the server hello needs.
        const shared_key = h.sharedKey() catch |err| {
            try h.writeAlert(stream, null, err);
            return err;
        };

        // Server hello: only the part after the record header belongs to
        // the transcript.
        const hello = try h.makeServerHello(writer.getFree());
        h.transcript.update(hello[record.header_len..]);
        writer.pos += hello.len;

        // Everything after the server hello is encrypted with this cipher.
        h.cipher = try Cipher.initTls13(
            h.cipher_suite,
            h.transcript.handshakeSecret(shared_key),
            .server,
        );

        try writer.writeRecord(.change_cipher_spec, &[_]u8{1});

        // Encrypted extensions message with an empty extensions list.
        const encrypted_extensions = &record.handshakeHeader(.encrypted_extensions, 2) ++ [_]u8{ 0, 0 };
        h.transcript.update(encrypted_extensions);
        try h.writeEncrypted(&writer, encrypted_extensions);

        if (opt.client_auth != null) {
            const certificate_request = try makeCertificateRequest(writer.getPayload());
            h.transcript.update(certificate_request);
            try h.writeEncrypted(&writer, certificate_request);
        }

        if (opt.auth) |auth| {
            const builder = CertificateBuilder{
                .bundle = auth.bundle,
                .key = auth.key,
                .transcript = &h.transcript,
                .side = .server,
            };

            const certificate = try builder.makeCertificate(writer.getPayload());
            h.transcript.update(certificate);
            try h.writeEncrypted(&writer, certificate);

            const certificate_verify = try builder.makeCertificateVerify(writer.getPayload());
            h.transcript.update(certificate_verify);
            try h.writeEncrypted(&writer, certificate_verify);
        }

        const finished = try h.makeFinished(writer.getPayload());
        h.transcript.update(finished);
        try h.writeEncrypted(&writer, finished);

        break :flight writer.getWritten();
    };
    try stream.writeAll(server_flight);

    // Derived before the client's flight is read; from here on alerts to
    // the client are encrypted with it.
    var app_cipher = app: {
        const application_secret = h.transcript.applicationSecret();
        break :app try Cipher.initTls13(h.cipher_suite, application_secret, .server);
    };

    h.readClientFlight2(opt) catch |err| {
        // An error named TlsAlert... means the client sent us an alert; only
        // other failures are answered with an alert of our own.
        const alert_from_client = mem.startsWith(u8, @errorName(err), "TlsAlert");
        if (!alert_from_client) {
            try h.writeAlert(stream, &app_cipher, err);
        }
        return err;
    };
    return app_cipher;
}
/// Generates a fresh key pair from a random seed, stores the server public
/// key for `h.named_group` in `h.server_pub_key` and returns the secret
/// shared with the client's public key.
// NOTE(review): the returned slice presumably refers to storage inside the
// local key pair, which would be why this function is `inline` - confirm
// before changing the calling convention.
inline fn sharedKey(h: *HandshakeT) ![]const u8 {
    var seed: [DhKeyPair.seed_len]u8 = undefined;
    crypto.random.bytes(&seed);

    var key_pair = try DhKeyPair.init(seed, supported_named_groups);
    const public_key = try key_pair.publicKey(h.named_group);
    h.server_pub_key = dupe(&h.server_pub_key_buf, public_key);

    return try key_pair.sharedKey(h.named_group, h.client_pub_key);
}
/// Reads the client's second flight and returns once a valid Finished
/// message has been received.
///
/// Without client authentication only Finished is expected. With
/// `opt.client_auth` set the expected order is Certificate,
/// CertificateVerify, Finished; an empty Certificate message goes straight
/// to Finished when `auth_type` is `.request` and fails with
/// `error.TlsCertificateRequired` when it is `.require`.
///
/// Every complete handshake message is added to the transcript after it has
/// been processed. A message that continues in a later record is completed
/// by decrypting the following records behind the bytes already held in
/// `h.buffer`.
///
/// Errors returned directly by this function:
///   - `error.EndOfStream`: the stream ended before Finished,
///   - `error.TlsProtocolVersion`: unexpected record version,
///   - `error.TlsUnexpectedMessage`: wrong record type, message order or
///     change_cipher_spec length,
///   - `error.TlsRecordOverflow`: message too large for the buffer,
///   - `error.TlsIllegalParameter` / `error.TlsDecryptError`: client
///     certificate signature rejected,
///   - `error.TlsDecodeError` / `error.TlsDecryptError`: Finished verify
///     data has the wrong length / wrong content.
/// An alert record from the client is returned as the error produced by
/// `raiseAlert`. Decoder, cipher and certificate parser errors propagate.
fn readClientFlight2(h: *HandshakeT, opt: Options) !void {
    // Decrypted handshake bytes live in h.buffer; [head..tail) is the part
    // that has not been consumed as a complete handshake message yet.
    const cleartext_buf = h.buffer;
    var cleartext_buf_head: usize = 0;
    var cleartext_buf_tail: usize = 0;
    // The handshake message type that must come next.
    var handshake_state: proto.Handshake = .finished;
    // Only initialized (and only used) when client authentication is on.
    var cert: CertificateParser = undefined;
    if (opt.client_auth) |client_auth| {
        cert = .{ .root_ca = client_auth.root_ca.bundle, .host = "" };
        handshake_state = .certificate;
    }
    outer: while (true) {
        const rec = (try h.rec_rdr.next() orelse return error.EndOfStream);
        if (rec.protocol_version != .tls_1_2 and rec.content_type != .alert)
            return error.TlsProtocolVersion;
        switch (rec.content_type) {
            .change_cipher_spec => {
                // Accepted and ignored, as long as it is a single byte.
                if (rec.payload.len != 1) return error.TlsUnexpectedMessage;
            },
            .application_data => {
                // Append the decrypted content behind what is still pending.
                const content_type, const cleartext = try h.cipher.decrypt(
                    cleartext_buf[cleartext_buf_tail..],
                    rec,
                );
                cleartext_buf_tail += cleartext.len;
                if (cleartext_buf_tail > cleartext_buf.len) return error.TlsRecordOverflow;
                var d = record.Decoder.init(content_type, cleartext_buf[cleartext_buf_head..cleartext_buf_tail]);
                try d.expectContentType(.handshake);
                while (!d.eof()) {
                    const start_idx = d.idx;
                    const handshake_type = try d.decode(proto.Handshake);
                    const length = try d.decode(u24);
                    if (length > cipher.max_cleartext_len)
                        return error.TlsRecordOverflow;
                    // NOTE(review): only a split message body is handled
                    // here; a 4-byte message header that is itself split
                    // across records presumably fails in decode above.
                    if (length > d.rest().len)
                        continue :outer; // fragmented handshake into multiple records
                    // Registered only for complete messages. It runs when
                    // this iteration is left (including by return), i.e.
                    // after the message was processed, so signature and
                    // Finished checks see the transcript without the
                    // message under inspection.
                    defer {
                        const handshake_payload = d.payload[start_idx..d.idx];
                        h.transcript.update(handshake_payload);
                        cleartext_buf_head += handshake_payload.len;
                    }
                    if (handshake_state != handshake_type)
                        return error.TlsUnexpectedMessage;
                    switch (handshake_type) {
                        .certificate => {
                            if (length == 4) {
                                // got empty certificate message
                                if (opt.client_auth.?.auth_type == .require)
                                    return error.TlsCertificateRequired;
                                try d.skip(length);
                                handshake_state = .finished;
                            } else {
                                try cert.parseCertificate(&d, .tls_1_3);
                                handshake_state = .certificate_verify;
                            }
                        },
                        .certificate_verify => {
                            try cert.parseCertificateVerify(&d);
                            cert.verifySignature(h.transcript.clientCertificateVerify()) catch |err| return switch (err) {
                                error.TlsUnknownSignatureScheme => error.TlsIllegalParameter,
                                else => error.TlsDecryptError,
                            };
                            handshake_state = .finished;
                        },
                        .finished => {
                            const actual = try d.slice(length);
                            var buf: [Transcript.max_mac_length]u8 = undefined;
                            const expected = h.transcript.clientFinishedTls13(&buf);
                            if (expected.len != actual.len) return error.TlsDecodeError;
                            // Compare every byte without an early exit so
                            // the time taken does not depend on how many
                            // leading bytes of the verify data matched.
                            var diff: u8 = 0;
                            for (expected, actual) |e, a| diff |= e ^ a;
                            if (diff != 0) return error.TlsDecryptError;
                            return;
                        },
                        else => return error.TlsUnexpectedMessage,
                    }
                }
                // Everything was consumed; start at the front of the buffer
                // for the next record.
                cleartext_buf_head = 0;
                cleartext_buf_tail = 0;
            },
            .alert => {
                var d = rec.decoder();
                return d.raiseAlert();
            },
            else => return error.TlsUnexpectedMessage,
        }
    }
}
/// Builds the server Finished handshake message in `buf` and returns the
/// written bytes. The verify data comes from the current transcript, so the
/// transcript must already contain every earlier handshake message.
fn makeFinished(h: *HandshakeT, buf: []u8) ![]const u8 {
    var writer = record.Writer{ .buf = buf };
    // The verify data is produced directly in the handshake payload area;
    // advanceHandshake then accounts for it together with the header.
    const mac = h.transcript.serverFinishedTls13(writer.getHandshakePayload());
    try writer.advanceHandshake(.finished, mac.len);
    return writer.getWritten();
}
/// Encrypts the handshake message `cleartext` with the handshake cipher into
/// the free space of `w` and advances `w` past the produced ciphertext.
fn writeEncrypted(h: *HandshakeT, w: *record.Writer, cleartext: []const u8) !void {
    const free_space = w.getFree();
    const encrypted = try h.cipher.encrypt(free_space, .handshake, cleartext);
    w.pos += encrypted.len;
}
/// Builds the complete server hello record (record header, handshake header
/// and body) in `buf` and returns the written bytes.
///
/// The body carries the legacy version, `server_random`, the echoed legacy
/// session id, the chosen cipher suite, a zero compression method and two
/// extensions: supported_versions (TLS 1.3) and key_share (`named_group`
/// with `server_pub_key`).
fn makeServerHello(h: *HandshakeT, buf: []u8) ![]const u8 {
    // tls record header (5 bytes) and handshake header (4 bytes); both are
    // filled in last, once the payload length is known.
    const header_len = 9;

    var body = record.Writer{ .buf = buf[header_len..] };
    try body.writeEnum(proto.Version.tls_1_2);
    try body.write(&h.server_random);

    const session_id_len: u8 = @intCast(h.legacy_session_id.len);
    try body.writeInt(session_id_len);
    if (session_id_len > 0) try body.write(h.legacy_session_id);

    try body.writeEnum(h.cipher_suite);
    try body.write(&[_]u8{0}); // compression method

    // Extensions are written two bytes past the body so that their total
    // length can be put in front of them afterwards.
    const extensions_start = header_len + body.pos + 2;
    var extensions = record.Writer{ .buf = buf[extensions_start..] };

    // supported versions extension
    try extensions.writeEnum(proto.Extension.supported_versions);
    try extensions.writeInt(@as(u16, 2));
    try extensions.writeEnum(proto.Version.tls_1_3);

    // key share extension: group (2 bytes) + key length (2 bytes) + key
    const key_len: u16 = @intCast(h.server_pub_key.len);
    try extensions.writeEnum(proto.Extension.key_share);
    try extensions.writeInt(key_len + 4);
    try extensions.writeEnum(h.named_group);
    try extensions.writeInt(key_len);
    try extensions.write(h.server_pub_key);

    // extensions length, closing the gap left above
    const extensions_len: u16 = @intCast(extensions.pos);
    try body.writeInt(extensions_len);

    const payload_len = body.pos + extensions.pos;
    buf[0..header_len].* = record.header(.handshake, 4 + payload_len) ++
        record.handshakeHeader(.server_hello, payload_len);
    return buf[0 .. header_len + payload_len];
}
/// Builds a certificate request handshake message in `buf` and returns the
/// written bytes. The message has an empty request context and a single
/// signature_algorithms extension listing
/// `common.supported_signature_algorithms`.
fn makeCertificateRequest(buf: []u8) ![]const u8 {
    // handshake header (4) + context length (1) + extensions length (2)
    const header_len = 4 + 1 + 2;

    // The extensions go in first, behind the space reserved for the header,
    // because the header needs their length.
    var extensions = record.Writer{ .buf = buf[header_len..] };
    try extensions.writeExtension(.signature_algorithms, common.supported_signature_algorithms);
    const extensions_len = extensions.pos;

    var message = record.Writer{ .buf = buf };
    // payload = context length byte + extensions length field + extensions
    try message.writeHandshakeHeader(.certificate_request, extensions_len + 3);
    try message.writeInt(@as(u8, 0)); // certificate request context length = 0
    try message.writeInt(@as(u16, @intCast(extensions_len))); // extensions length
    assert(message.pos == header_len);

    // Take over the extensions that are already in place.
    message.pos += extensions_len;
    return message.getWritten();
}
/// Reads the client hello record, adds its payload to the transcript and
/// stores the negotiated parameters on `h`:
///   - `client_random` and `legacy_session_id`,
///   - `cipher_suite`: the first suite in the client's list that is part of
///     `cipher_suites.tls13`,
///   - `named_group` / `client_pub_key`: the offered key share whose group
///     comes earliest in `supported_named_groups`.
///
/// Errors returned directly by this function (decoder errors propagate):
///   - `error.TlsUnexpectedMessage`: the message is not a client hello,
///   - `error.TlsProtocolVersion`: legacy version is not TLS 1.2, or
///     supported_versions does not list TLS 1.3,
///   - `error.TlsHandshakeFailure`: no common cipher suite, or the required
///     signature scheme is not offered,
///   - `error.TlsIllegalParameter`: compression methods are not exactly one
///     zero byte, a rejected named group is listed, or no usable key share,
///   - `error.TlsDecodeError`: empty key_share / signature_algorithms list,
///     or an extension body that does not match its declared length,
///   - `error.TlsMissingExtension`: no key_share, or no signature_algorithms
///     although the server requires a signature scheme.
fn readClientHello(h: *HandshakeT) !void {
    var d = try h.rec_rdr.nextDecoder();
    try d.expectContentType(.handshake);
    h.transcript.update(d.payload);

    const handshake_type = try d.decode(proto.Handshake);
    if (handshake_type != .client_hello) return error.TlsUnexpectedMessage;
    _ = try d.decode(u24); // handshake length
    if (try d.decode(proto.Version) != .tls_1_2) return error.TlsProtocolVersion;
    h.client_random = try d.array(32);
    { // legacy session id
        const len = try d.decode(u8);
        h.legacy_session_id = dupe(&h.legacy_session_id_buf, try d.slice(len));
    }
    { // cipher suites
        const end_idx = try d.decode(u16) + d.idx;
        while (d.idx < end_idx) {
            const cipher_suite = try d.decode(CipherSuite);
            // Keep the first acceptable suite; 0 means none chosen yet.
            if (cipher_suites.includes(cipher_suites.tls13, cipher_suite) and
                @intFromEnum(h.cipher_suite) == 0)
            {
                h.cipher_suite = cipher_suite;
            }
        }
        if (@intFromEnum(h.cipher_suite) == 0)
            return error.TlsHandshakeFailure;
    }
    { // legacy compression methods
        // A TLS 1.3 client hello must carry exactly one method with value
        // zero; anything else is answered with illegal_parameter.
        const len = try d.decode(u8);
        const method = try d.decode(u8);
        if (len != 1 or method != 0) return error.TlsIllegalParameter;
    }

    var key_share_received = false;
    var signature_algorithms_received = false;
    // extensions
    const extensions_end_idx = try d.decode(u16) + d.idx;
    while (d.idx < extensions_end_idx) {
        const extension_type = try d.decode(proto.Extension);
        const extension_len = try d.decode(u16);
        // Where this extension has to end according to its own header.
        const extension_end_idx = d.idx + extension_len;
        switch (extension_type) {
            .supported_versions => {
                var tls_1_3_supported = false;
                const end_idx = try d.decode(u8) + d.idx;
                while (d.idx < end_idx) {
                    if (try d.decode(proto.Version) == proto.Version.tls_1_3) {
                        tls_1_3_supported = true;
                    }
                }
                if (!tls_1_3_supported) return error.TlsProtocolVersion;
            },
            .key_share => {
                if (extension_len == 0) return error.TlsDecodeError;
                key_share_received = true;
                // Index into supported_named_groups of the best share so
                // far; lower means more preferred by the server.
                var selected_named_group_idx = supported_named_groups.len;
                const end_idx = try d.decode(u16) + d.idx;
                while (d.idx < end_idx) {
                    const named_group = try d.decode(proto.NamedGroup);
                    switch (@intFromEnum(named_group)) {
                        0x0001...0x0016,
                        0x001a...0x001c,
                        0xff01...0xff02,
                        => return error.TlsIllegalParameter,
                        else => {},
                    }
                    const client_pub_key = try d.slice(try d.decode(u16));
                    for (supported_named_groups, 0..) |supported, idx| {
                        if (named_group == supported and idx < selected_named_group_idx) {
                            h.named_group = named_group;
                            h.client_pub_key = dupe(&h.client_pub_key_buf, client_pub_key);
                            selected_named_group_idx = idx;
                        }
                    }
                }
                if (@intFromEnum(h.named_group) == 0)
                    return error.TlsIllegalParameter;
            },
            .supported_groups => {
                // Only validated; the group is chosen from key_share.
                const end_idx = try d.decode(u16) + d.idx;
                while (d.idx < end_idx) {
                    const named_group = try d.decode(proto.NamedGroup);
                    switch (@intFromEnum(named_group)) {
                        0x0001...0x0016,
                        0x001a...0x001c,
                        0xff01...0xff02,
                        => return error.TlsIllegalParameter,
                        else => {},
                    }
                }
            },
            .signature_algorithms => {
                signature_algorithms_received = true;
                if (@intFromEnum(h.signature_scheme) == 0) {
                    // No scheme required by the server.
                    try d.skip(extension_len);
                } else {
                    var found = false;
                    const list_len = try d.decode(u16);
                    if (list_len == 0) return error.TlsDecodeError;
                    const end_idx = list_len + d.idx;
                    while (d.idx < end_idx) {
                        const signature_scheme = try d.decode(proto.SignatureScheme);
                        if (signature_scheme == h.signature_scheme) found = true;
                    }
                    if (!found) return error.TlsHandshakeFailure;
                }
            },
            else => {
                try d.skip(extension_len);
            },
        }
        // An inner list length that disagrees with the extension length
        // would leave the decoder in the middle of (or past) the extension
        // and misread everything that follows.
        if (d.idx != extension_end_idx) return error.TlsDecodeError;
    }
    if (!key_share_received) return error.TlsMissingExtension;
    // A server that authenticates with a certificate needs the client's
    // signature_algorithms; leaving the extension out must not bypass the
    // scheme check above.
    if (@intFromEnum(h.signature_scheme) != 0 and !signature_algorithms_received)
        return error.TlsMissingExtension;
    if (@intFromEnum(h.named_group) == 0) return error.TlsIllegalParameter;
}
};
}
const testing = std.testing;
const data13 = @import("testdata/tls13.zig");
const testu = @import("testu.zig");
/// Test helper: wraps `data` in a fixed buffer stream and returns a record
/// reader over it.
fn testReader(data: []const u8) record.Reader(std.io.FixedBufferStream([]const u8)) {
    const stream = std.io.fixedBufferStream(data);
    return record.reader(stream);
}
const TestHandshake = Handshake(std.io.FixedBufferStream([]const u8));
test "read client hello" {
    var scratch: [1024]u8 = undefined;
    var reader = testReader(&data13.client_hello);
    var hs = TestHandshake.init(&scratch, &reader);
    // The signature_algorithms extension of the fixture has to list this
    // scheme, otherwise readClientHello fails.
    hs.signature_scheme = .ecdsa_secp521r1_sha512;

    try hs.readClientHello();

    try testing.expectEqualSlices(u8, &data13.client_random, &hs.client_random);
    try testing.expectEqualSlices(u8, &data13.client_public_key, hs.client_pub_key);
    try testing.expectEqual(.x25519, hs.named_group);
    try testing.expectEqual(CipherSuite.AES_256_GCM_SHA384, hs.cipher_suite);
}
test "make server hello" {
    var scratch: [128]u8 = undefined;
    // makeServerHello does not read records, so no reader is supplied.
    var hs = TestHandshake.init(&scratch, undefined);

    // Fixed inputs; the expected bytes below show the random counting up
    // from 0x00 and the public key counting up from 0x20.
    testu.fillFrom(&hs.server_random, 0);
    testu.fillFrom(&hs.server_pub_key_buf, 0x20);
    hs.server_pub_key = hs.server_pub_key_buf[0..32];
    hs.named_group = .x25519;
    hs.cipher_suite = .AES_256_GCM_SHA384;

    // Layout: record + handshake header, legacy version, server random,
    // empty session id, cipher suite + compression method, extensions
    // length + supported_versions, key_share header, public key.
    const expected = &testu.hexToBytes(
        \\ 16 03 03 00 5a 02 00 00 56
        \\ 03 03
        \\ 00 01 02 03 04 05 06 07 08 09 0a 0b 0c 0d 0e 0f 10 11 12 13 14 15 16 17 18 19 1a 1b 1c 1d 1e 1f
        \\ 00
        \\ 13 02 00
        \\ 00 2e 00 2b 00 02 03 04
        \\ 00 33 00 24 00 1d 00 20
        \\ 20 21 22 23 24 25 26 27 28 29 2a 2b 2c 2d 2e 2f 30 31 32 33 34 35 36 37 38 39 3a 3b 3c 3d 3e 3f
    );
    const actual = try hs.makeServerHello(&scratch);
    try testing.expectEqualSlices(u8, expected, actual);
}
test "make certificate request" {
    var scratch: [32]u8 = undefined;
    const actual = try TestHandshake.makeCertificateRequest(&scratch);

    const expected = testu.hexToBytes("0d 00 00 1b" ++ // handshake header: certificate_request, 27 byte payload
        "00 00 18" ++ // request context length (0), then extensions length (24)
        "00 0d" ++ // signature algorithms extension
        "00 14" ++ // extension length (20)
        "00 12" ++ // list length 9 * 2 bytes
        "04 03 05 03 08 04 08 05 08 06 08 07 02 01 04 01 05 01" // signature schemes
    );
    try testing.expectEqualSlices(u8, &expected, actual);
}