const std = @import("std"); const assert = std.debug.assert; const mem = std.mem; const proto = @import("protocol.zig"); const cipher = @import("cipher.zig"); const Cipher = cipher.Cipher; const record = @import("record.zig"); pub const header_len = 5; pub fn header(content_type: proto.ContentType, payload_len: usize) [header_len]u8 { const int2 = std.crypto.tls.int2; return [1]u8{@intFromEnum(content_type)} ++ int2(@intFromEnum(proto.Version.tls_1_2)) ++ int2(@intCast(payload_len)); } pub fn handshakeHeader(handshake_type: proto.Handshake, payload_len: usize) [4]u8 { const int3 = std.crypto.tls.int3; return [1]u8{@intFromEnum(handshake_type)} ++ int3(@intCast(payload_len)); } pub fn reader(inner_reader: anytype) Reader(@TypeOf(inner_reader)) { return .{ .inner_reader = inner_reader }; } pub fn Reader(comptime InnerReader: type) type { return struct { inner_reader: InnerReader, buffer: [cipher.max_ciphertext_record_len]u8 = undefined, start: usize = 0, end: usize = 0, const ReaderT = @This(); pub fn nextDecoder(r: *ReaderT) !Decoder { const rec = (try r.next()) orelse return error.EndOfStream; if (@intFromEnum(rec.protocol_version) != 0x0300 and @intFromEnum(rec.protocol_version) != 0x0301 and rec.protocol_version != .tls_1_2) return error.TlsBadVersion; return .{ .content_type = rec.content_type, .payload = rec.payload, }; } pub fn contentType(buf: []const u8) proto.ContentType { return @enumFromInt(buf[0]); } pub fn protocolVersion(buf: []const u8) proto.Version { return @enumFromInt(mem.readInt(u16, buf[1..3], .big)); } pub fn next(r: *ReaderT) !?Record { while (true) { const buffer = r.buffer[r.start..r.end]; // If we have 5 bytes header. if (buffer.len >= record.header_len) { const record_header = buffer[0..record.header_len]; const payload_len = mem.readInt(u16, record_header[3..5], .big); if (payload_len > cipher.max_ciphertext_len) return error.TlsRecordOverflow; const record_len = record.header_len + payload_len; // If we have whole record if (buffer.len >= record_len) { r.start += record_len; return Record.init(buffer[0..record_len]); } } { // Move dirty part to the start of the buffer. const n = r.end - r.start; if (n > 0 and r.start > 0) { if (r.start > n) { @memcpy(r.buffer[0..n], r.buffer[r.start..][0..n]); } else { mem.copyForwards(u8, r.buffer[0..n], r.buffer[r.start..][0..n]); } } r.start = 0; r.end = n; } { // Read more from inner_reader. const n = try r.inner_reader.read(r.buffer[r.end..]); if (n == 0) return null; r.end += n; } } } pub fn nextDecrypt(r: *ReaderT, cph: *Cipher) !?struct { proto.ContentType, []const u8 } { const rec = (try r.next()) orelse return null; if (rec.protocol_version != .tls_1_2) return error.TlsBadVersion; return try cph.decrypt( // Reuse reader buffer for cleartext. `rec.header` and // `rec.payload`(ciphertext) are also pointing somewhere in // this buffer. Decrypter is first reading then writing a // block, cleartext has less length then ciphertext, // cleartext starts from the beginning of the buffer, so // ciphertext is always ahead of cleartext. r.buffer[0..r.start], rec, ); } pub fn hasMore(r: *ReaderT) bool { return r.end > r.start; } }; } pub const Record = struct { content_type: proto.ContentType, protocol_version: proto.Version = .tls_1_2, header: []const u8, payload: []const u8, pub fn init(buffer: []const u8) Record { return .{ .content_type = @enumFromInt(buffer[0]), .protocol_version = @enumFromInt(mem.readInt(u16, buffer[1..3], .big)), .header = buffer[0..record.header_len], .payload = buffer[record.header_len..], }; } pub fn decoder(r: @This()) Decoder { return Decoder.init(r.content_type, @constCast(r.payload)); } }; pub const Decoder = struct { content_type: proto.ContentType, payload: []const u8, idx: usize = 0, pub fn init(content_type: proto.ContentType, payload: []u8) Decoder { return .{ .content_type = content_type, .payload = payload, }; } pub fn decode(d: *Decoder, comptime T: type) !T { switch (@typeInfo(T)) { .Int => |info| switch (info.bits) { 8 => { try skip(d, 1); return d.payload[d.idx - 1]; }, 16 => { try skip(d, 2); const b0: u16 = d.payload[d.idx - 2]; const b1: u16 = d.payload[d.idx - 1]; return (b0 << 8) | b1; }, 24 => { try skip(d, 3); const b0: u24 = d.payload[d.idx - 3]; const b1: u24 = d.payload[d.idx - 2]; const b2: u24 = d.payload[d.idx - 1]; return (b0 << 16) | (b1 << 8) | b2; }, else => @compileError("unsupported int type: " ++ @typeName(T)), }, .Enum => |info| { const int = try d.decode(info.tag_type); if (info.is_exhaustive) @compileError("exhaustive enum cannot be used"); return @as(T, @enumFromInt(int)); }, else => @compileError("unsupported type: " ++ @typeName(T)), } } pub fn array(d: *Decoder, comptime len: usize) ![len]u8 { try d.skip(len); return d.payload[d.idx - len ..][0..len].*; } pub fn slice(d: *Decoder, len: usize) ![]const u8 { try d.skip(len); return d.payload[d.idx - len ..][0..len]; } pub fn skip(d: *Decoder, amt: usize) !void { if (d.idx + amt > d.payload.len) return error.TlsDecodeError; d.idx += amt; } pub fn rest(d: Decoder) []const u8 { return d.payload[d.idx..]; } pub fn eof(d: Decoder) bool { return d.idx == d.payload.len; } pub fn expectContentType(d: *Decoder, content_type: proto.ContentType) !void { if (d.content_type == content_type) return; switch (d.content_type) { .alert => try d.raiseAlert(), else => return error.TlsUnexpectedMessage, } } pub fn raiseAlert(d: *Decoder) !void { if (d.payload.len < 2) return error.TlsUnexpectedMessage; try proto.Alert.parse(try d.array(2)).toError(); return error.TlsAlertCloseNotify; } }; const testing = std.testing; const data12 = @import("testdata/tls12.zig"); const testu = @import("testu.zig"); const CipherSuite = @import("cipher.zig").CipherSuite; test Reader { var fbs = std.io.fixedBufferStream(&data12.server_responses); var rdr = reader(fbs.reader()); const expected = [_]struct { content_type: proto.ContentType, payload_len: usize, }{ .{ .content_type = .handshake, .payload_len = 49 }, .{ .content_type = .handshake, .payload_len = 815 }, .{ .content_type = .handshake, .payload_len = 300 }, .{ .content_type = .handshake, .payload_len = 4 }, .{ .content_type = .change_cipher_spec, .payload_len = 1 }, .{ .content_type = .handshake, .payload_len = 64 }, }; for (expected) |e| { const rec = (try rdr.next()).?; try testing.expectEqual(e.content_type, rec.content_type); try testing.expectEqual(e.payload_len, rec.payload.len); try testing.expectEqual(.tls_1_2, rec.protocol_version); } } test Decoder { var fbs = std.io.fixedBufferStream(&data12.server_responses); var rdr = reader(fbs.reader()); var d = (try rdr.nextDecoder()); try testing.expectEqual(.handshake, d.content_type); try testing.expectEqual(.server_hello, try d.decode(proto.Handshake)); try testing.expectEqual(45, try d.decode(u24)); // length try testing.expectEqual(.tls_1_2, try d.decode(proto.Version)); try testing.expectEqualStrings( &testu.hexToBytes("707172737475767778797a7b7c7d7e7f808182838485868788898a8b8c8d8e8f"), try d.slice(32), ); // server random try testing.expectEqual(0, try d.decode(u8)); // session id len try testing.expectEqual(.ECDHE_RSA_WITH_AES_128_CBC_SHA, try d.decode(CipherSuite)); try testing.expectEqual(0, try d.decode(u8)); // compression method try testing.expectEqual(5, try d.decode(u16)); // extension length try testing.expectEqual(5, d.rest().len); try d.skip(5); try testing.expect(d.eof()); } pub const Writer = struct { buf: []u8, pos: usize = 0, pub fn write(self: *Writer, data: []const u8) !void { defer self.pos += data.len; if (self.pos + data.len > self.buf.len) return error.BufferOverflow; @memcpy(self.buf[self.pos..][0..data.len], data); } pub fn writeByte(self: *Writer, b: u8) !void { defer self.pos += 1; if (self.pos == self.buf.len) return error.BufferOverflow; self.buf[self.pos] = b; } pub fn writeEnum(self: *Writer, value: anytype) !void { try self.writeInt(@intFromEnum(value)); } pub fn writeInt(self: *Writer, value: anytype) !void { const IntT = @TypeOf(value); const bytes = @divExact(@typeInfo(IntT).Int.bits, 8); const free = self.buf[self.pos..]; if (free.len < bytes) return error.BufferOverflow; mem.writeInt(IntT, free[0..bytes], value, .big); self.pos += bytes; } pub fn writeHandshakeHeader(self: *Writer, handshake_type: proto.Handshake, payload_len: usize) !void { try self.write(&record.handshakeHeader(handshake_type, payload_len)); } /// Should be used after writing handshake payload in buffer provided by `getHandshakePayload`. pub fn advanceHandshake(self: *Writer, handshake_type: proto.Handshake, payload_len: usize) !void { try self.write(&record.handshakeHeader(handshake_type, payload_len)); self.pos += payload_len; } /// Record payload is already written by using buffer space from `getPayload`. /// Now when we know payload len we can write record header and advance over payload. pub fn advanceRecord(self: *Writer, content_type: proto.ContentType, payload_len: usize) !void { try self.write(&record.header(content_type, payload_len)); self.pos += payload_len; } pub fn writeRecord(self: *Writer, content_type: proto.ContentType, payload: []const u8) !void { try self.write(&record.header(content_type, payload.len)); try self.write(payload); } /// Preserves space for record header and returns buffer free space. pub fn getPayload(self: *Writer) []u8 { return self.buf[self.pos + record.header_len ..]; } /// Preserves space for handshake header and returns buffer free space. pub fn getHandshakePayload(self: *Writer) []u8 { return self.buf[self.pos + 4 ..]; } pub fn getWritten(self: *Writer) []const u8 { return self.buf[0..self.pos]; } pub fn getFree(self: *Writer) []u8 { return self.buf[self.pos..]; } pub fn writeEnumArray(self: *Writer, comptime E: type, tags: []const E) !void { assert(@sizeOf(E) == 2); try self.writeInt(@as(u16, @intCast(tags.len * 2))); for (tags) |t| { try self.writeEnum(t); } } pub fn writeExtension( self: *Writer, comptime et: proto.Extension, tags: anytype, ) !void { try self.writeEnum(et); if (et == .supported_versions) { try self.writeInt(@as(u16, @intCast(tags.len * 2 + 1))); try self.writeInt(@as(u8, @intCast(tags.len * 2))); } else { try self.writeInt(@as(u16, @intCast(tags.len * 2 + 2))); try self.writeInt(@as(u16, @intCast(tags.len * 2))); } for (tags) |t| { try self.writeEnum(t); } } pub fn writeKeyShare( self: *Writer, named_groups: []const proto.NamedGroup, keys: []const []const u8, ) !void { assert(named_groups.len == keys.len); try self.writeEnum(proto.Extension.key_share); var l: usize = 0; for (keys) |key| { l += key.len + 4; } try self.writeInt(@as(u16, @intCast(l + 2))); try self.writeInt(@as(u16, @intCast(l))); for (named_groups, 0..) |ng, i| { const key = keys[i]; try self.writeEnum(ng); try self.writeInt(@as(u16, @intCast(key.len))); try self.write(key); } } pub fn writeServerName(self: *Writer, host: []const u8) !void { const host_len: u16 = @intCast(host.len); try self.writeEnum(proto.Extension.server_name); try self.writeInt(host_len + 5); // byte length of extension payload try self.writeInt(host_len + 3); // server_name_list byte count try self.writeByte(0); // name type try self.writeInt(host_len); try self.write(host); } }; test "Writer" { var buf: [16]u8 = undefined; var w = Writer{ .buf = &buf }; try w.write("ab"); try w.writeEnum(proto.Curve.named_curve); try w.writeEnum(proto.NamedGroup.x25519); try w.writeInt(@as(u16, 0x1234)); try testing.expectEqualSlices(u8, &[_]u8{ 'a', 'b', 0x03, 0x00, 0x1d, 0x12, 0x34 }, w.getWritten()); }