diff --git a/src/browser/xhr/xhr.zig b/src/browser/xhr/xhr.zig index b974fdad..3d532930 100644 --- a/src/browser/xhr/xhr.zig +++ b/src/browser/xhr/xhr.zig @@ -813,7 +813,7 @@ test "Browser.XHR.XMLHttpRequest" { .{ "req.status", "200" }, .{ "req.statusText", "OK" }, .{ "req.getResponseHeader('Content-Type')", "text/html; charset=utf-8" }, - .{ "req.getAllResponseHeaders().length", "61" }, + .{ "req.getAllResponseHeaders().length", "80" }, .{ "req.responseText.length", "100" }, .{ "req.response.length == req.responseText.length", "true" }, .{ "req.responseXML instanceof Document", "true" }, diff --git a/src/http/client.zig b/src/http/client.zig index 5ad181c8..e6bf9cf3 100644 --- a/src/http/client.zig +++ b/src/http/client.zig @@ -48,9 +48,12 @@ pub const Client = struct { state_pool: StatePool, root_ca: tls.config.CertBundle, tls_verify_host: bool = true, + idle_connections: IdleConnections, + connection_pool: std.heap.MemoryPool(Connection), const Opts = struct { tls_verify_host: bool = true, + max_idle_connection: usize = 10, }; pub fn init(allocator: Allocator, max_concurrent: usize, opts: Opts) !Client { @@ -60,11 +63,16 @@ pub const Client = struct { const state_pool = try StatePool.init(allocator, max_concurrent); errdefer state_pool.deinit(allocator); + const idle_connections = IdleConnections.init(allocator, opts.max_idle_connection); + errdefer idle_connections.deinit(); + return .{ .root_ca = root_ca, .allocator = allocator, .state_pool = state_pool, + .idle_connections = idle_connections, .tls_verify_host = opts.tls_verify_host, + .connection_pool = std.heap.MemoryPool(Connection).init(allocator), }; } @@ -74,6 +82,8 @@ pub const Client = struct { self.root_ca.deinit(allocator); } self.state_pool.deinit(allocator); + self.idle_connections.deinit(); + self.connection_pool.deinit(); } pub fn request(self: *Client, method: Request.Method, uri: *const Uri) !Request { @@ -88,6 +98,47 @@ pub const Client = struct { } }; +// We assume most connections are going to end up in the IdleConnnection pool, +// so this always end up in on the heap (as a *Connection) using the client's +// connection_pool MemoryPool. +// You'll notice that we have both this "Connection", and that both the SyncHandler +// and the AsyncHandler have a "Conn". The "Conn" are a specialized version +// of this "Connection". The SyncHandler.Conn provides a synchronous API over +// the socket/tls. The AsyncHandler.Conn provides an asynchronous API over these. +// +// The Request and IdleConnections are the only ones that deal directly with this +// "Connection" - and the variable name is "connection". +// +// The Sync/Async handlers deal only with their respective "Conn" - and the +// variable name is "conn". +const Connection = struct { + port: u16, + blocking: bool, + tls: ?TLSClient, + host: []const u8, + socket: posix.socket_t, + + const TLSClient = union(enum) { + blocking: tls.Connection(std.net.Stream), + nonblocking: tls.nb.Client(), + + fn close(self: *TLSClient) void { + switch (self.*) { + .blocking => |*tls_client| tls_client.close() catch {}, + .nonblocking => |*tls_client| tls_client.deinit(), + } + } + }; + + fn deinit(self: *Connection, allocator: Allocator) void { + allocator.free(self.host); + if (self.tls) |*tls_client| { + tls_client.close(); + } + posix.close(self.socket); + } +}; + // Represents a request. Can be used to make either a synchronous or an // asynchronous request. When a synchronous request is made, `request.deinit()` // should be called once the response is no longer needed. @@ -95,8 +146,6 @@ pub const Client = struct { // (but request.deinit() should still be called to discard the request // before the `sendAsync` is called). pub const Request = struct { - // Whether or not TLS is being used. - secure: bool, // The HTTP Method to use method: Method, @@ -118,11 +167,28 @@ pub const Request = struct { // List of request headers headers: std.ArrayListUnmanaged(std.http.Header), + // whether or not we expect this connection to be secure + _secure: bool, + + // whether or not we should keep the underlying socket open and and usable + // for other requests + _keepalive: bool, + + _port: u16, + + _host: []const u8, + + // whether or not the socket comes from the connection pool. If it does, + // and we get an error sending the header, we might retry on a new connection + // because it's possible the other closed the connection, and that's no + // reason to fail the request. + _connection_from_keepalive: bool, + // Used to limit the # of redirects we'll follow _redirect_count: u16, - // The underlying socket - _socket: ?posix.socket_t, + // The actual connection, including the socket and, optionally, a TLS client + _connection: ?*Connection, // Pooled buffers and arena _state: *State, @@ -149,38 +215,100 @@ pub const Request = struct { pub fn format(self: Method, comptime _: []const u8, _: std.fmt.FormatOptions, writer: anytype) !void { return writer.writeAll(@tagName(self)); } + + fn safeToRetry(self: Method) bool { + return self == .GET or self == .HEAD or self == .OPTIONS; + } }; fn init(client: *Client, state: *State, method: Method, uri: *const Uri) !Request { - if (uri.host == null) { - return error.UriMissingHost; - } - + const secure, const host, const port = try decomposeURL(uri); return .{ - .secure = true, .uri = uri, - .method = method, .body = null, .headers = .{}, + .method = method, .arena = state.arena.allocator(), - ._socket = null, + ._secure = secure, + ._host = host, + ._port = port, ._state = state, ._client = client, + ._connection = null, + ._keepalive = false, ._redirect_count = 0, ._has_host_header = false, + ._connection_from_keepalive = false, ._tls_verify_host = client.tls_verify_host, }; } pub fn deinit(self: *Request) void { - if (self._socket) |socket| { - posix.close(socket); - self._socket = null; - } + self.releaseConnection(); _ = self._state.reset(); self._client.state_pool.release(self._state); } + fn decomposeURL(uri: *const Uri) !struct { bool, []const u8, u16 } { + if (uri.host == null) { + return error.UriMissingHost; + } + + var secure: bool = undefined; + + const scheme = uri.scheme; + if (std.ascii.eqlIgnoreCase(scheme, "https")) { + secure = true; + } else if (std.ascii.eqlIgnoreCase(scheme, "http")) { + secure = false; + } else { + return error.UnsupportedUriScheme; + } + + const host = uri.host.?.percent_encoded; + const port: u16 = uri.port orelse if (secure) 443 else 80; + + return .{ secure, host, port }; + } + + // Called in deinit, but also called when we're redirecting to another page + fn releaseConnection(self: *Request) void { + const connection = self._connection orelse return; + self._connection = null; + + if (self._keepalive == false) { + self.destroyConnection(connection); + return; + } + + self._client.idle_connections.put(connection) catch |err| { + self.destroyConnection(connection); + log.err("failed to release connection to pool: {}", .{err}); + }; + } + + fn createConnection(self: *Request, socket: posix.socket_t, blocking: bool) !*Connection { + const client = self._client; + const connection = try client.connection_pool.create(); + errdefer client.connection_pool.destroy(connection); + + connection.* = .{ + .socket = socket, + .tls = null, + .port = self._port, + .blocking = blocking, + .host = try client.allocator.dupe(u8, self._host), + }; + + return connection; + } + + fn destroyConnection(self: *Request, connection: *Connection) void { + const client = self._client; + connection.deinit(client.allocator); + client.connection_pool.destroy(connection); + } + const AddHeaderOpts = struct { dupe_name: bool = false, dupe_value: bool = false, @@ -216,19 +344,57 @@ pub const Request = struct { } try self.prepareInitialSend(); - return self.doSendSync(); + return self.doSendSync(true); } // Called internally, follows a redirect. fn redirectSync(self: *Request, redirect: Reader.Redirect) anyerror!Response { try self.prepareToRedirect(redirect); - return self.doSendSync(); + return self.doSendSync(true); } - fn doSendSync(self: *Request) anyerror!Response { - const socket, const address = try self.createSocket(true); + fn doSendSync(self: *Request, use_pool: bool) anyerror!Response { + if (use_pool) { + if (self.findExistingConnection(true)) |connection| { + self._connection = connection; + self._connection_from_keepalive = true; + } + } + + if (self._connection == null) { + const socket, const address = try self.createSocket(true); + + posix.connect(socket, &address.any, address.getOsSockLen()) catch |err| { + posix.close(socket); + return err; + }; + + const connection = self.createConnection(socket, true) catch |err| { + posix.close(socket); + return err; + }; + + errdefer self.destroyConnection(connection); + + if (self._secure) { + connection.tls = .{ + .blocking = try tls.client(std.net.Stream{ .handle = socket }, .{ + .host = connection.host, + .root_ca = self._client.root_ca, + .insecure_skip_verify = self._tls_verify_host == false, + // .key_log_callback = tls.config.key_log.callback, + }), + }; + } + + self._connection = connection; + self._connection_from_keepalive = false; + } + + errdefer self.destroyConnection(self._connection.?); + var handler = SyncHandler{ .request = self }; - return handler.send(socket, address) catch |err| { + return handler.send() catch |err| { log.warn("HTTP error: {any} ({any} {any} {d})", .{ err, self.method, self.uri, self._redirect_count }); return err; }; @@ -243,49 +409,91 @@ pub const Request = struct { self._tls_verify_host = override; } try self.prepareInitialSend(); - return self.doSendAsync(loop, handler); + return self.doSendAsync(loop, handler, true); } pub fn redirectAsync(self: *Request, redirect: Reader.Redirect, loop: anytype, handler: anytype) !void { try self.prepareToRedirect(redirect); - return self.doSendAsync(loop, handler); + return self.doSendAsync(loop, handler, true); } - fn doSendAsync(self: *Request, loop: anytype, handler: anytype) !void { - const socket, const address = try self.createSocket(false); + fn doSendAsync(self: *Request, loop: anytype, handler: anytype, use_pool: bool) !void { + if (use_pool) { + if (self.findExistingConnection(false)) |connection| { + self._connection = connection; + self._connection_from_keepalive = true; + } + } + + var address: std.net.Address = undefined; + if (self._connection == null) { + const socket, address = try self.createSocket(false); + errdefer posix.close(socket); + + // It seems wrong to set self._connection here. While we have a + // connection, it isn't yet connected. PLUS, if this is a secure + // connection, we also don't have a handshake. + // But, request._connection only ever gets released to the idle pool + // when request._keepalive == true. And this can only be true _after_ + // we've processed the request - at which point, we'd obviously be + // connected + handshake. + self._connection = try self.createConnection(socket, false); + self._connection_from_keepalive = false; + } + + const connection = self._connection.?; + errdefer self.destroyConnection(connection); + const AsyncHandlerT = AsyncHandler(@TypeOf(handler), @TypeOf(loop)); const async_handler = try self.arena.create(AsyncHandlerT); + const state = self._state; async_handler.* = .{ .loop = loop, - .socket = socket, .request = self, .handler = handler, - .read_buf = self._state.read_buf, - .write_buf = self._state.write_buf, - .reader = Reader.init(self._state), - .connection = .{ .handler = async_handler, .protocol = .{ .plain = {} } }, + .read_buf = state.read_buf, + .write_buf = state.write_buf, + .reader = self.newReader(), + .socket = connection.socket, + .conn = .{ .handler = async_handler, .protocol = .{ .plain = {} } }, }; - if (self.secure) { - async_handler.connection.protocol = .{ - .secure = .{ - .tls_client = try tls.nb.Client().init(self.arena, .{ - .host = self.host(), - .root_ca = self._client.root_ca, - .insecure_skip_verify = self._tls_verify_host == false, - // .key_log_callback = tls.config.key_log.callback - }), - }, + if (self._secure) { + connection.tls = .{ + .nonblocking = try tls.nb.Client().init(self._client.allocator, .{ + .host = connection.host, + .root_ca = self._client.root_ca, + .insecure_skip_verify = self._tls_verify_host == false, + .key_log_callback = tls.config.key_log.callback, + }), + }; + + async_handler.conn.protocol = .{ + .secure = &connection.tls.?.nonblocking, }; } - try loop.connect(AsyncHandlerT, async_handler, &async_handler.read_completion, AsyncHandlerT.connected, socket, address); + if (self._connection_from_keepalive) { + // we're already connected + return async_handler.conn.connected(); + } + + return loop.connect( + AsyncHandlerT, + async_handler, + &async_handler.read_completion, + AsyncHandlerT.connected, + connection.socket, + address, + ); + } + + fn newReader(self: *Request) Reader { + return Reader.init(self._state, &self._keepalive); } // Does additional setup of the request for the firsts (i.e. non-redirect) call. fn prepareInitialSend(self: *Request) !void { - try self.verifyUri(); - const arena = self.arena; if (self.body) |body| { const cl = try std.fmt.allocPrint(arena, "{d}", .{body.len}); @@ -293,7 +501,7 @@ pub const Request = struct { } if (!self._has_host_header) { - try self.headers.append(arena, .{ .name = "Host", .value = self.host() }); + try self.headers.append(arena, .{ .name = "Host", .value = self._host }); } try self.headers.append(arena, .{ .name = "User-Agent", .value = "Lightpanda/1.0" }); @@ -301,8 +509,7 @@ pub const Request = struct { // Sets up the request for redirecting. fn prepareToRedirect(self: *Request, redirect: Reader.Redirect) !void { - posix.close(self._socket.?); - self._socket = null; + self.releaseConnection(); // CANNOT reset the arena (╥﹏╥) // We need it for self.uri (which we're about to use to resolve @@ -312,15 +519,19 @@ pub const Request = struct { if (redirect_count == 10) { return error.TooManyRedirects; } - self._redirect_count = redirect_count + 1; - var buf = try self.arena.alloc(u8, 1024); + var buf = try self.arena.alloc(u8, 2048); - const previous_host = self.host(); + const previous_host = self._host; self.redirect_uri = try self.uri.resolve_inplace(redirect.location, &buf); self.uri = &self.redirect_uri.?; - try self.verifyUri(); + const secure, const host, const port = try decomposeURL(self.uri); + self._host = host; + self._port = port; + self._secure = secure; + self._keepalive = false; + self._redirect_count = redirect_count + 1; if (redirect.use_get) { // Some redirect status codes _require_ that we switch the method @@ -342,37 +553,35 @@ pub const Request = struct { } } - const new_host = self.host(); - if (std.mem.eql(u8, previous_host, new_host) == false) { + if (std.mem.eql(u8, previous_host, host) == false) { for (self.headers.items) |*hdr| { if (std.mem.eql(u8, hdr.name, "Host")) { - hdr.value = new_host; + hdr.value = host; break; } } } } - // extracted because we re-verify this on redirect - fn verifyUri(self: *Request) !void { - const scheme = self.uri.scheme; - if (std.ascii.eqlIgnoreCase(scheme, "https")) { - self.secure = true; - return; - } - if (std.ascii.eqlIgnoreCase(scheme, "http")) { - self.secure = false; - return; + fn findExistingConnection(self: *Request, blocking: bool) ?*Connection { + // This is being overly cautious, but it's a bit risky to re-use + // connections for other methods. It isn't so much re-using the + // connection that's the issue, it's dealing with a write error + // when trying to send the request and deciding whether or not we + // should retry the request. + if (self.method.safeToRetry() == false) { + return null; } - return error.UnsupportedUriScheme; + if (self.body != null) { + return null; + } + + return self._client.idle_connections.get(self._secure, self._host, self._port, blocking); } fn createSocket(self: *Request, blocking: bool) !struct { posix.socket_t, std.net.Address } { - const host_ = self.host(); - const port: u16 = self.uri.port orelse if (self.secure) 443 else 80; - - const addresses = try std.net.getAddressList(self.arena, host_, port); + const addresses = try std.net.getAddressList(self.arena, self._host, self._port); if (addresses.addrs.len == 0) { return error.UnknownHostName; } @@ -387,8 +596,6 @@ pub const Request = struct { if (@hasDecl(posix.TCP, "NODELAY")) { try posix.setsockopt(socket, posix.IPPROTO.TCP, posix.TCP.NODELAY, &std.mem.toBytes(@as(c_int, 1))); } - - self._socket = socket; return .{ socket, address }; } @@ -407,15 +614,9 @@ pub const Request = struct { try writer.writeAll(header.value); try writer.writeAll("\r\n"); } - // TODO: remove this once we have a connection pool - try writer.writeAll("Connection: Close\r\n"); try writer.writeAll("\r\n"); return buf[0..fbs.pos]; } - - fn host(self: *const Request) []const u8 { - return self.uri.host.?.percent_encoded; - } }; // Handles asynchronous requests @@ -459,8 +660,9 @@ fn AsyncHandler(comptime H: type, comptime L: type) type { // Used to help us know if we're writing the header or the body; state: SendState = .handshake, - // Abstraction over TLS and plain text socket - connection: Connection, + // Abstraction over TLS and plain text socket, this is a version of + // the request._connection (which is a *Connection) that is async-specific. + conn: Conn, // This will be != null when we're supposed to redirect AND we've // drained the response body. We need this as a field, because we'll @@ -484,13 +686,12 @@ fn AsyncHandler(comptime H: type, comptime L: type) type { }; fn deinit(self: *Self) void { - self.connection.deinit(); self.request.deinit(); } fn connected(self: *Self, _: *IO.Completion, result: IO.ConnectError!void) void { result catch |err| return self.handleError("Connection failed", err); - self.connection.connected() catch |err| { + self.conn.connected() catch |err| { self.handleError("connected handler error", err); }; } @@ -553,7 +754,7 @@ fn AsyncHandler(comptime H: type, comptime L: type) type { return; } - self.connection.sent() catch |err| { + self.conn.sent() catch |err| { self.handleError("send handling", err); }; } @@ -586,10 +787,13 @@ fn AsyncHandler(comptime H: type, comptime L: type) type { return self.handleError("Read error", err); }; if (n == 0) { + if (self.maybeRetryRequest()) { + return; + } return self.handleError("Connection closed", error.ConnectionResetByPeer); } - const status = self.connection.received(self.read_buf[0 .. self.read_pos + n]) catch |err| { + const status = self.conn.received(self.read_buf[0 .. self.read_pos + n]) catch |err| { self.handleError("data processing", err); return; }; @@ -607,13 +811,43 @@ fn AsyncHandler(comptime H: type, comptime L: type) type { return; }; // redirectAsync has given up any claim to the request, - // including the socket. We just need to clean up our - // tls_client. - self.connection.deinit(); + // including the socket. }, } } + // If our socket came from the connection pool, it's possible that we're + // failing because it's since timed out. If + fn maybeRetryRequest(self: *Self) bool { + const request = self.request; + + // We only retry if the connection came from the keepalive pool + // We only use a keepalive connection for specific methods and if + // there's no body. + if (request._connection_from_keepalive == false) { + return false; + } + + // Because of the `self.state == .body` check above, it should be + // impossible to be here and have this be true. This is an important + // check, because we're about to release a connection that we know + // is bad, and we don't want it to go back into the pool. + std.debug.assert(request._keepalive == false); + request.releaseConnection(); + + request.doSendAsync(self.loop, self.handler, false) catch |conn_err| { + // You probably think it's weird that we fallthrough to the: + // return true; + // The caller will take the `true` and just exit. This is what + // we want in this error case, because the next line handles + // the error. We rather emit an "connection error" at this point + // than whatever error we had using the pooled connection. + self.handleError("connection error", conn_err); + }; + + return true; + } + fn processData(self: *Self, d: []u8) ProcessStatus { const reader = &self.reader; @@ -634,8 +868,8 @@ fn AsyncHandler(comptime H: type, comptime L: type) type { // `would_be_first` should be thought of as `is_first` because // we now have a complete header for the first time. if (reader.redirect()) |redirect| { - // We don't redirect until we've drained the body (because, - // if we ever add keepalive, we'll re-use the connection). + // We don't redirect until we've drained the body (to be + // able to re-use the connection for keepalive). // Calling `reader.redirect()` over and over again might not // be the most efficient (it's a very simple function though), // but for a redirect response, chances are we slurped up @@ -674,74 +908,53 @@ fn AsyncHandler(comptime H: type, comptime L: type) type { fn handleError(self: *Self, comptime msg: []const u8, err: anyerror) void { log.err(msg ++ ": {any} ({any} {any})", .{ err, self.request.method, self.request.uri }); self.handler.onHttpResponse(err) catch {}; - self.deinit(); + // just to be safe + self.request._keepalive = false; + self.request.deinit(); } - const Connection = struct { + const Conn = struct { handler: *Self, protocol: Protocol, const Protocol = union(enum) { plain: void, - secure: Secure, - - const Secure = struct { - tls_client: tls.nb.Client(), - state: SecureState = .handshake, - - const SecureState = enum { - handshake, - header, - body, - }; - }; + secure: *tls.nb.Client(), }; - fn deinit(self: *Connection) void { - switch (self.protocol) { - .plain => {}, - .secure => |*secure| secure.tls_client.deinit(), - } - } - - fn connected(self: *Connection) !void { + fn connected(self: *Conn) !void { const handler = self.handler; switch (self.protocol) { .plain => { - // queue everything up - handler.state = .body; + handler.state = .header; const header = try handler.request.buildHeader(); handler.send(header); - if (handler.request.body) |body| { - handler.send(body); - } - handler.receive(); }, - .secure => |*secure| { + .secure => |tls_client| { + std.debug.assert(handler.state == .handshake); // initiate the handshake - _, const i = try secure.tls_client.handshake(handler.read_buf[0..0], handler.write_buf); + _, const i = try tls_client.handshake(handler.read_buf[0..0], handler.write_buf); handler.send(handler.write_buf[0..i]); handler.receive(); }, } } - fn received(self: *Connection, data: []u8) !ProcessStatus { + fn received(self: *Conn, data: []u8) !ProcessStatus { const handler = self.handler; switch (self.protocol) { .plain => return handler.processData(data), - .secure => |*secure| { + .secure => |tls_client| { var used: usize = 0; var closed = false; var cleartext_pos: usize = 0; var status = ProcessStatus.need_more; - var tls_client = &secure.tls_client; if (tls_client.isConnected()) { used, cleartext_pos, closed = try tls_client.decrypt(data); } else { - std.debug.assert(secure.state == .handshake); + std.debug.assert(handler.state == .handshake); // process handshake data used, const i = try tls_client.handshake(data, handler.write_buf); if (i > 0) { @@ -751,7 +964,7 @@ fn AsyncHandler(comptime H: type, comptime L: type) type { // no unused data handler.read_pos = 0; std.debug.assert(used == data.len); - try self.sendSecureHeader(secure); + try self.sendSecureHeader(tls_client); return .wait; } } @@ -804,34 +1017,44 @@ fn AsyncHandler(comptime H: type, comptime L: type) type { } } - fn sent(self: *Connection) !void { + fn sent(self: *Conn) !void { + const handler = self.handler; switch (self.protocol) { - .plain => {}, - .secure => |*secure| { - if (secure.tls_client.isConnected() == false) { - std.debug.assert(secure.state == .handshake); + .plain => switch (handler.state) { + .handshake => unreachable, + .header => { + handler.state = .body; + if (handler.request.body) |body| { + handler.send(body); + } + handler.receive(); + }, + .body => {}, + }, + .secure => |tls_client| { + if (tls_client.isConnected() == false) { + std.debug.assert(handler.state == .handshake); // still handshaking, nothing to do return; } - switch (secure.state) { - .handshake => return self.sendSecureHeader(secure), + switch (handler.state) { + .handshake => return self.sendSecureHeader(tls_client), .header => { - secure.state = .body; - const handler = self.handler; + handler.state = .body; const body = handler.request.body orelse { // We've sent the header, and there's no body // start receiving the response handler.receive(); return; }; - const used, const i = try secure.tls_client.encrypt(body, handler.write_buf); + const used, const i = try tls_client.encrypt(body, handler.write_buf); std.debug.assert(body.len == used); handler.send(handler.write_buf[0..i]); }, .body => { // We've sent the body, start receiving the // response - self.handler.receive(); + handler.receive(); }, } }, @@ -843,11 +1066,12 @@ fn AsyncHandler(comptime H: type, comptime L: type) type { // as soon as we've written our handshake, we consider the connection // "connected". TLS 1.2 requires a extra round trip, and thus is // only connected after we receive response from the server. - fn sendSecureHeader(self: Connection, secure: *Protocol.Secure) !void { - secure.state = .header; + fn sendSecureHeader(self: *Conn, tls_client: *tls.nb.Client()) !void { const handler = self.handler; + handler.state = .header; + const header = try handler.request.buildHeader(); - const used, const i = try secure.tls_client.encrypt(header, handler.write_buf); + const used, const i = try tls_client.encrypt(header, handler.write_buf); std.debug.assert(header.len == used); handler.send(handler.write_buf[0..i]); } @@ -859,44 +1083,38 @@ fn AsyncHandler(comptime H: type, comptime L: type) type { const SyncHandler = struct { request: *Request, - // The Request owns the socket, we shouldn't close it in here. - fn send(self: *SyncHandler, socket: posix.socket_t, address: std.net.Address) !Response { + fn send(self: *SyncHandler) !Response { var request = self.request; - try posix.connect(socket, &address.any, address.getOsSockLen()); - var connection: Connection = undefined; - if (request.secure) { - connection = .{ - .tls = try tls.client(std.net.Stream{ .handle = socket }, .{ - .host = request.host(), - .root_ca = request._client.root_ca, - .insecure_skip_verify = request._tls_verify_host == false, - // .key_log_callback = tls.config.key_log.callback, - }), - }; - } else { - connection = .{ .plain = socket }; - } + // Take the request._connection (a *Connection), and turn it into + // something specific to our SyncHandler, a Conn. + var conn: Conn = blk: { + const c = request._connection.?; + if (c.tls) |*tls_client| { + break :blk .{ .tls = &tls_client.blocking }; + } + break :blk .{ .plain = c.socket }; + }; const header = try request.buildHeader(); - try connection.sendRequest(header, request.body); + try conn.sendRequest(header, request.body); - const state = request._state; - - var buf = state.read_buf; - var reader = Reader.init(state); + var reader = request.newReader(); + var read_buf = request._state.read_buf; while (true) { - const n = try connection.read(buf); - const result = try reader.process(buf[0..n]); + const n = conn.read(read_buf) catch |err| { + return self.maybeRetryOrErr(err); + }; + const result = try reader.process(read_buf[0..n]); if (reader.header_done == false) { continue; } if (reader.redirect()) |redirect| { if (result.done == false) { - try self.drain(&reader, &connection, result.unprocessed); + try self.drain(&reader, &conn, result.unprocessed); } return request.redirectSync(redirect); } @@ -917,9 +1135,9 @@ const SyncHandler = struct { .over = "", .inner = &reader, .done = result.done, - .buffer = state.read_buf, + .buffer = read_buf, .data = result.unprocessed, - .connection = connection, + .conn = conn, }; var body: std.ArrayListUnmanaged(u8) = .{}; var decompressor = std.compress.gzip.decompressor(compress_reader.reader()); @@ -932,26 +1150,54 @@ const SyncHandler = struct { ._peek_buf = body.items, ._peek_len = body.items.len, ._buf = undefined, + ._conn = undefined, ._reader = undefined, - ._connection = undefined, }; } return .{ - ._buf = buf, + ._conn = conn, + ._buf = read_buf, ._request = request, ._reader = reader, ._done = result.done, - ._connection = connection, ._data = result.unprocessed, ._peek_len = 0, - ._peek_buf = state.peek_buf, + ._peek_buf = request._state.peek_buf, .header = reader.response, }; } } - fn drain(self: SyncHandler, reader: *Reader, connection: *Connection, unprocessed: ?[]u8) !void { + fn maybeRetryOrErr(self: *SyncHandler, err: anyerror) !Response { + var request = self.request; + + // we'll only retry if the connection came from the idle pool, because + // these connections might have been closed while idling, so an error + // isn't exactly surprising. + if (request._connection_from_keepalive == false) { + return err; + } + + if (err != error.ConnectionResetByPeer) { + return err; + } + + // this should be our default, and this function should never have been + // called at a point where this could have been set to true. This is + // important because we're about to release a bad connection, and + // we don't want it to go back into the idle pool. + std.debug.assert(request._keepalive == false); + request.releaseConnection(); + + // Don't change this false to true. It ensures that we get a new + // connection. This prevents an endless loop because, if this new + // connection also fails, connection_from_keepalive will be false, and our + // above guard clause will abort the retry. + return request.doSendSync(false); + } + + fn drain(self: SyncHandler, reader: *Reader, conn: *Conn, unprocessed: ?[]u8) !void { if (unprocessed) |data| { const result = try reader.process(data); if (result.done) { @@ -961,7 +1207,7 @@ const SyncHandler = struct { var buf = self.request._state.read_buf; while (true) { - const n = try connection.read(buf); + const n = try conn.read(buf); const result = try reader.process(buf[0..n]); if (result.done) { return; @@ -969,16 +1215,16 @@ const SyncHandler = struct { } } - const Connection = union(enum) { - tls: tls.Connection(std.net.Stream), + const Conn = union(enum) { + tls: *tls.Connection(std.net.Stream), plain: posix.socket_t, - fn sendRequest(self: *Connection, header: []const u8, body: ?[]const u8) !void { + fn sendRequest(self: *Conn, header: []const u8, body: ?[]const u8) !void { switch (self.*) { - .tls => |*tls_conn| { - try tls_conn.writeAll(header); + .tls => |tls_client| { + try tls_client.writeAll(header); if (body) |b| { - try tls_conn.writeAll(b); + try tls_client.writeAll(b); } }, .plain => |socket| { @@ -994,9 +1240,9 @@ const SyncHandler = struct { } } - fn read(self: *Connection, buf: []u8) !usize { + fn read(self: *Conn, buf: []u8) !usize { const n = switch (self.*) { - .tls => |*tls_conn| try tls_conn.read(buf), + .tls => |tls_client| try tls_client.read(buf), .plain => |socket| try posix.read(socket, buf), }; if (n == 0) { @@ -1044,9 +1290,9 @@ const SyncHandler = struct { // the entire body. const CompressedReader = struct { done: bool, + conn: Conn, buffer: []u8, inner: *Reader, - connection: Connection, // Represents data directly from the socket. It hasn't been processed // by the body reader. It could, for example, have chunk information in it. @@ -1097,7 +1343,7 @@ const SyncHandler = struct { return 0; } - const n = try self.connection.read(self.buffer); + const n = try self.conn.read(self.buffer); self.data = self.buffer[0..n]; } } @@ -1116,6 +1362,9 @@ const SyncHandler = struct { // Used for reading the response (both the header and the body) const Reader = struct { + // ref request.keepalive + keepalive: *bool, + // always references state.header_buf header_buf: []u8, @@ -1131,12 +1380,13 @@ const Reader = struct { header_done: bool, - fn init(state: *State) Reader { + fn init(state: *State, keepalive: *bool) Reader { return .{ .pos = 0, .response = .{}, .body_reader = null, .header_done = false, + .keepalive = keepalive, .header_buf = state.header_buf, .arena = state.arena.allocator(), }; @@ -1160,9 +1410,9 @@ const Reader = struct { if (ok == false) { // There's something that our body reader didn't like. It wants // us to emit whatever data we have, but it isn't safe to keep - // the connection alive.s + // the connection alive. std.debug.assert(result.done == true); - self.response.keepalive = false; + self.keepalive.* = false; } return result; } @@ -1240,7 +1490,7 @@ const Reader = struct { // We think we're done reading the body, but we still have data // We'll return what we have as-is, but close the connection // because we don't know what state it's in. - self.response.keepalive = false; + self.keepalive.* = false; } else { result.unprocessed = unprocessed; } @@ -1251,8 +1501,14 @@ const Reader = struct { // We're done parsing the header, and we need to (maybe) setup the BodyReader fn prepareForBody(self: *Reader) !Result { self.header_done = true; - const response = &self.response; + + if (response.get("connection")) |connection| { + if (std.ascii.eqlIgnoreCase(connection, "close")) { + self.keepalive.* = false; + } + } + if (response.get("transfer-encoding")) |te| { if (std.ascii.indexOfIgnoreCase(te, "chunked") != null) { self.body_reader = .{ .chunked = .{ @@ -1302,7 +1558,7 @@ const Reader = struct { } const protocol = data[0..9]; if (std.mem.eql(u8, protocol, "HTTP/1.1 ")) { - self.response.keepalive = true; + self.keepalive.* = true; } else if (std.mem.eql(u8, protocol, "HTTP/1.0 ") == false) { return error.InvalidStatusLine; } @@ -1558,7 +1814,6 @@ const Reader = struct { pub const ResponseHeader = struct { status: u16 = 0, - keepalive: bool = false, headers: std.ArrayListUnmanaged(Header) = .{}, // Stored header has already been lower-cased @@ -1619,9 +1874,10 @@ const HeaderIterator = struct { pub const Progress = struct { first: bool, - // whether or not more data should be expected + // whether or not more data is expected done: bool, - // A piece of data from the body + + // part of the body data: ?[]const u8, header: ResponseHeader, @@ -1631,7 +1887,7 @@ pub const Progress = struct { pub const Response = struct { _reader: Reader, _request: *Request, - _connection: SyncHandler.Connection, + _conn: SyncHandler.Conn, // the buffer to read the peeked data into _peek_buf: []u8, @@ -1678,7 +1934,7 @@ pub const Response = struct { return null; } - const n = try self._connection.read(buf); + const n = try self._conn.read(buf); self._data = buf[0..n]; } } @@ -1858,6 +2114,90 @@ const StatePool = struct { } }; +// Ideally, a connection could be reused as long as the host:port matches. +// But we're also having to match based on blocking and nonblocking and TLS +// and not TLS. It isn't the most efficient. For non-TLS, we could definitely +// always re-use the connection (just toggle the socket's blocking status), but +// for TLS, we'd need to see if the two different TLS objects (blocking and non +// blocking) can be converted from each other. +const IdleConnections = struct { + max: usize, + idle: List, + count: usize, + mutex: Thread.Mutex, + allocator: Allocator, + node_pool: std.heap.MemoryPool(Node), + + const List = std.DoublyLinkedList(*Connection); + const Node = List.Node; + + fn init(allocator: Allocator, max: usize) IdleConnections { + return .{ + .max = max, + .count = 0, + .idle = .{}, + .mutex = .{}, + .allocator = allocator, + .node_pool = std.heap.MemoryPool(Node).init(allocator), + }; + } + + fn deinit(self: *IdleConnections) void { + const allocator = self.allocator; + + self.mutex.lock(); + defer self.mutex.unlock(); + var node = self.idle.first; + while (node) |n| { + const next = n.next; + n.data.deinit(allocator); + node = next; + } + self.node_pool.deinit(); + } + + fn get(self: *IdleConnections, secure: bool, host: []const u8, port: u16, blocking: bool) ?*Connection { + self.mutex.lock(); + defer self.mutex.unlock(); + + var node = self.idle.first; + while (node) |n| { + const connection = n.data; + if (std.ascii.eqlIgnoreCase(connection.host, host) and connection.port == port and connection.blocking == blocking and ((connection.tls == null) == !secure)) { + self.count -= 1; + self.idle.remove(n); + self.node_pool.destroy(n); + return connection; + } + node = n.next; + } + return null; + } + + fn put(self: *IdleConnections, connection: *Connection) !void { + self.mutex.lock(); + defer self.mutex.unlock(); + + var node: *Node = undefined; + if (self.count == self.max) { + const oldest = self.idle.popFirst() orelse { + std.debug.assert(self.max == 0); + connection.deinit(self.allocator); + return; + }; + oldest.data.deinit(self.allocator); + // re-use the node + node = oldest; + } else { + self.count += 1; + node = try self.node_pool.create(); + } + + node.data = connection; + self.idle.append(node); + } +}; + const testing = @import("../testing.zig"); test "HttpClient Reader: fuzz" { var state = try State.init(testing.allocator, 1024, 1024, 100); @@ -1883,7 +2223,6 @@ test "HttpClient Reader: fuzz" { res.reset(); try testReader(&state, &res, "HTTP/1.1 200 \r\n\r\n"); try testing.expectEqual(200, res.status); - try testing.expectEqual(true, res.keepalive); try testing.expectEqual(0, res.body.items.len); try testing.expectEqual(0, res.headers.items.len); } @@ -1892,7 +2231,6 @@ test "HttpClient Reader: fuzz" { res.reset(); try testReader(&state, &res, "HTTP/1.0 404 \r\nError: Not-Found\r\n\r\n"); try testing.expectEqual(404, res.status); - try testing.expectEqual(false, res.keepalive); try testing.expectEqual(0, res.body.items.len); try res.assertHeaders(&.{ "error", "Not-Found" }); } @@ -1901,7 +2239,6 @@ test "HttpClient Reader: fuzz" { res.reset(); try testReader(&state, &res, "HTTP/1.1 200 \r\nSet-Cookie: a32;max-age=60\r\nContent-Length: 12\r\n\r\nOver 9000!!!"); try testing.expectEqual(200, res.status); - try testing.expectEqual(true, res.keepalive); try testing.expectEqual("Over 9000!!!", res.body.items); try res.assertHeaders(&.{ "set-cookie", "a32;max-age=60", "content-length", "12" }); } @@ -1910,7 +2247,6 @@ test "HttpClient Reader: fuzz" { res.reset(); try testReader(&state, &res, "HTTP/1.1 200 \r\nTransFEr-ENcoding: chunked \r\n\r\n0\r\n\r\n"); try testing.expectEqual(200, res.status); - try testing.expectEqual(true, res.keepalive); try testing.expectEqual("", res.body.items); try res.assertHeaders(&.{ "transfer-encoding", "chunked" }); } @@ -1919,7 +2255,6 @@ test "HttpClient Reader: fuzz" { res.reset(); try testReader(&state, &res, "HTTP/1.1 200 \r\nTransFEr-ENcoding: chunked \r\n\r\n0\r\n\r\n"); try testing.expectEqual(200, res.status); - try testing.expectEqual(true, res.keepalive); try testing.expectEqual("", res.body.items); try res.assertHeaders(&.{ "transfer-encoding", "chunked" }); } @@ -1928,7 +2263,6 @@ test "HttpClient Reader: fuzz" { res.reset(); try testReader(&state, &res, "HTTP/1.1 200 \r\nTransFEr-ENcoding: chunked \r\n\r\nE\r\nHello World!!!\r\n2eE;opts\r\n" ++ ("abc" ** 250) ++ "\r\n0\r\n\r\n"); try testing.expectEqual(200, res.status); - try testing.expectEqual(true, res.keepalive); try testing.expectEqual("Hello World!!!" ++ ("abc" ** 250), res.body.items); try res.assertHeaders(&.{ "transfer-encoding", "chunked" }); } @@ -1941,7 +2275,6 @@ test "HttpClient Reader: fuzz" { res.reset(); try testReader(&state, &res, "HTTP/1.1 200 OK\r\n Content-Length : 610000 \r\nOther: 13391AbC93\r\n\r\n" ++ body); try testing.expectEqual(200, res.status); - try testing.expectEqual(true, res.keepalive); try testing.expectEqual(body, res.body.items); try res.assertHeaders(&.{ "content-length", "610000", "other", "13391AbC93" }); } @@ -1967,6 +2300,8 @@ test "HttpClient: sync connect error" { const uri = try Uri.parse("HTTP://127.0.0.1:9920"); var req = try client.request(.GET, &uri); + defer req.deinit(); + try testing.expectError(error.ConnectionRefused, req.sendSync(.{})); } @@ -1977,6 +2312,8 @@ test "HttpClient: sync no body" { const uri = try Uri.parse("http://127.0.0.1:9582/http_client/simple"); var req = try client.request(.GET, &uri); + defer req.deinit(); + var res = try req.sendSync(.{}); if (i == 0) { @@ -1991,18 +2328,21 @@ test "HttpClient: sync no body" { } test "HttpClient: sync tls no body" { - for (0..5) |_| { + for (0..1) |_| { var client = try testClient(); defer client.deinit(); const uri = try Uri.parse("https://127.0.0.1:9581/http_client/simple"); var req = try client.request(.GET, &uri); + defer req.deinit(); + var res = try req.sendSync(.{ .tls_verify_host = false }); try testing.expectEqual(null, try res.next()); try testing.expectEqual(200, res.header.status); - try testing.expectEqual(1, res.header.count()); + try testing.expectEqual(2, res.header.count()); try testing.expectEqual("0", res.header.get("content-length")); + try testing.expectEqual("Close", res.header.get("connection")); } } @@ -2013,6 +2353,8 @@ test "HttpClient: sync with body" { const uri = try Uri.parse("http://127.0.0.1:9582/http_client/echo"); var req = try client.request(.GET, &uri); + defer req.deinit(); + var res = try req.sendSync(.{}); if (i == 0) { @@ -2020,11 +2362,10 @@ test "HttpClient: sync with body" { } try testing.expectEqual("over 9000!", try res.next()); try testing.expectEqual(201, res.header.status); - try testing.expectEqual(5, res.header.count()); - try testing.expectEqual("close", res.header.get("connection")); + try testing.expectEqual(4, res.header.count()); + try testing.expectEqual("Close", res.header.get("connection")); try testing.expectEqual("10", res.header.get("content-length")); try testing.expectEqual("127.0.0.1", res.header.get("_host")); - try testing.expectEqual("Close", res.header.get("_connection")); try testing.expectEqual("Lightpanda/1.0", res.header.get("_user-agent")); } } @@ -2036,6 +2377,8 @@ test "HttpClient: sync with gzip body" { const uri = try Uri.parse("http://127.0.0.1:9582/http_client/gzip"); var req = try client.request(.GET, &uri); + defer req.deinit(); + var res = try req.sendSync(.{}); if (i == 0) { @@ -2051,13 +2394,15 @@ test "HttpClient: sync tls with body" { defer arr.deinit(testing.allocator); try arr.ensureTotalCapacity(testing.allocator, 20); + var client = try testClient(); + defer client.deinit(); for (0..5) |_| { defer arr.clearRetainingCapacity(); - var client = try testClient(); - defer client.deinit(); const uri = try Uri.parse("https://127.0.0.1:9581/http_client/body"); var req = try client.request(.GET, &uri); + defer req.deinit(); + var res = try req.sendSync(.{ .tls_verify_host = false }); while (try res.next()) |data| { @@ -2065,9 +2410,10 @@ test "HttpClient: sync tls with body" { } try testing.expectEqual("1234567890abcdefhijk", arr.items); try testing.expectEqual(201, res.header.status); - try testing.expectEqual(2, res.header.count()); + try testing.expectEqual(3, res.header.count()); try testing.expectEqual("20", res.header.get("content-length")); try testing.expectEqual("HEaDer", res.header.get("another")); + try testing.expectEqual("Close", res.header.get("connection")); } } @@ -2083,6 +2429,8 @@ test "HttpClient: sync redirect from TLS to Plaintext" { const uri = try Uri.parse("https://127.0.0.1:9581/http_client/redirect/insecure"); var req = try client.request(.GET, &uri); + defer req.deinit(); + var res = try req.sendSync(.{ .tls_verify_host = false }); while (try res.next()) |data| { @@ -2090,11 +2438,10 @@ test "HttpClient: sync redirect from TLS to Plaintext" { } try testing.expectEqual(201, res.header.status); try testing.expectEqual("over 9000!", arr.items); - try testing.expectEqual(5, res.header.count()); - try testing.expectEqual("close", res.header.get("connection")); + try testing.expectEqual(4, res.header.count()); + try testing.expectEqual("Close", res.header.get("connection")); try testing.expectEqual("10", res.header.get("content-length")); try testing.expectEqual("127.0.0.1", res.header.get("_host")); - try testing.expectEqual("Close", res.header.get("_connection")); try testing.expectEqual("Lightpanda/1.0", res.header.get("_user-agent")); } } @@ -2111,6 +2458,7 @@ test "HttpClient: sync redirect plaintext to TLS" { const uri = try Uri.parse("http://127.0.0.1:9582/http_client/redirect/secure"); var req = try client.request(.GET, &uri); + defer req.deinit(); var res = try req.sendSync(.{ .tls_verify_host = false }); while (try res.next()) |data| { @@ -2118,9 +2466,10 @@ test "HttpClient: sync redirect plaintext to TLS" { } try testing.expectEqual(201, res.header.status); try testing.expectEqual("1234567890abcdefhijk", arr.items); - try testing.expectEqual(2, res.header.count()); + try testing.expectEqual(3, res.header.count()); try testing.expectEqual("20", res.header.get("content-length")); try testing.expectEqual("HEaDer", res.header.get("another")); + try testing.expectEqual("Close", res.header.get("connection")); } } @@ -2130,15 +2479,15 @@ test "HttpClient: sync GET redirect" { const uri = try Uri.parse("http://127.0.0.1:9582/http_client/redirect"); var req = try client.request(.GET, &uri); + defer req.deinit(); var res = try req.sendSync(.{ .tls_verify_host = false }); try testing.expectEqual("over 9000!", try res.next()); try testing.expectEqual(201, res.header.status); - try testing.expectEqual(5, res.header.count()); - try testing.expectEqual("close", res.header.get("connection")); + try testing.expectEqual(4, res.header.count()); + try testing.expectEqual("Close", res.header.get("connection")); try testing.expectEqual("10", res.header.get("content-length")); try testing.expectEqual("127.0.0.1", res.header.get("_host")); - try testing.expectEqual("Close", res.header.get("_connection")); try testing.expectEqual("Lightpanda/1.0", res.header.get("_user-agent")); } @@ -2187,7 +2536,7 @@ test "HttpClient: async no body" { const res = handler.response; try testing.expectEqual("", res.body.items); try testing.expectEqual(200, res.status); - try res.assertHeaders(&.{ "connection", "close", "content-length", "0" }); + try res.assertHeaders(&.{ "content-length", "0", "connection", "close" }); } test "HttpClient: async with body" { @@ -2206,11 +2555,10 @@ test "HttpClient: async with body" { try testing.expectEqual("over 9000!", res.body.items); try testing.expectEqual(201, res.status); try res.assertHeaders(&.{ - "connection", "close", "content-length", "10", "_host", "127.0.0.1", "_user-agent", "Lightpanda/1.0", - "_connection", "Close", + "connection", "Close", }); } @@ -2228,7 +2576,7 @@ test "HttpClient: async redirect" { // Called twice on purpose. The initial GET resutls in the # of pending // events to reach 0. This causes our `run_for_ns` to return. But we then // start to requeue events (from the redirected request), so we need the - //loop to process those also. + // loop to process those also. try handler.loop.io.run_for_ns(std.time.ns_per_ms); try handler.waitUntilDone(); @@ -2236,19 +2584,17 @@ test "HttpClient: async redirect" { try testing.expectEqual("over 9000!", res.body.items); try testing.expectEqual(201, res.status); try res.assertHeaders(&.{ - "connection", "close", "content-length", "10", "_host", "127.0.0.1", "_user-agent", "Lightpanda/1.0", - "_connection", "Close", + "connection", "Close", }); } test "HttpClient: async tls no body" { + var client = try testClient(); + defer client.deinit(); for (0..5) |_| { - var client = try testClient(); - defer client.deinit(); - var handler = try CaptureHandler.init(); defer handler.deinit(); @@ -2260,11 +2606,16 @@ test "HttpClient: async tls no body" { const res = handler.response; try testing.expectEqual("", res.body.items); try testing.expectEqual(200, res.status); - try res.assertHeaders(&.{ "content-length", "0" }); + try res.assertHeaders(&.{ + "content-length", + "0", + "connection", + "Close", + }); } } -test "HttpClient: async tls with body" { +test "HttpClient: async tls with body x" { for (0..5) |_| { var client = try testClient(); defer client.deinit(); @@ -2280,17 +2631,16 @@ test "HttpClient: async tls with body" { const res = handler.response; try testing.expectEqual("1234567890abcdefhijk", res.body.items); try testing.expectEqual(201, res.status); - try res.assertHeaders(&.{ "content-length", "20", "another", "HEaDer" }); + try res.assertHeaders(&.{ + "content-length", "20", + "connection", "Close", + "another", "HEaDer", + }); } } test "HttpClient: async redirect from TLS to Plaintext" { - var arr: std.ArrayListUnmanaged(u8) = .{}; - defer arr.deinit(testing.allocator); - try arr.ensureTotalCapacity(testing.allocator, 20); - - for (0..5) |_| { - defer arr.clearRetainingCapacity(); + for (0..1) |_| { var client = try testClient(); defer client.deinit(); @@ -2305,19 +2655,20 @@ test "HttpClient: async redirect from TLS to Plaintext" { const res = handler.response; try testing.expectEqual(201, res.status); try testing.expectEqual("over 9000!", res.body.items); - try res.assertHeaders(&.{ "connection", "close", "content-length", "10", "_host", "127.0.0.1", "_user-agent", "Lightpanda/1.0", "_connection", "Close" }); + try res.assertHeaders(&.{ + "content-length", "10", + "_host", "127.0.0.1", + "_user-agent", "Lightpanda/1.0", + "connection", "Close", + }); } } test "HttpClient: async redirect plaintext to TLS" { - var arr: std.ArrayListUnmanaged(u8) = .{}; - defer arr.deinit(testing.allocator); - try arr.ensureTotalCapacity(testing.allocator, 20); - for (0..5) |_| { - defer arr.clearRetainingCapacity(); var client = try testClient(); defer client.deinit(); + var handler = try CaptureHandler.init(); defer handler.deinit(); @@ -2329,7 +2680,7 @@ test "HttpClient: async redirect plaintext to TLS" { const res = handler.response; try testing.expectEqual(201, res.status); try testing.expectEqual("1234567890abcdefhijk", res.body.items); - try res.assertHeaders(&.{ "content-length", "20", "another", "HEaDer" }); + try res.assertHeaders(&.{ "content-length", "20", "connection", "Close", "another", "HEaDer" }); } } @@ -2383,7 +2734,6 @@ test "HttpClient: HeaderIterator" { const TestResponse = struct { status: u16, - keepalive: ?bool, arena: std.heap.ArenaAllocator, body: std.ArrayListUnmanaged(u8), headers: std.ArrayListUnmanaged(Header), @@ -2391,7 +2741,6 @@ const TestResponse = struct { fn init() TestResponse { return .{ .status = 0, - .keepalive = null, .body = .{}, .headers = .{}, .arena = ArenaAllocator.init(testing.allocator), @@ -2405,7 +2754,6 @@ const TestResponse = struct { fn reset(self: *TestResponse) void { _ = self.arena.reset(.{ .retain_capacity = {} }); self.status = 0; - self.keepalive = null; self.body = .{}; self.headers = .{}; } @@ -2467,7 +2815,6 @@ const CaptureHandler = struct { .value = try allocator.dupe(u8, header.value), }); } - self.response.keepalive = progress.header.keepalive; self.reset.set(); } } @@ -2485,7 +2832,8 @@ const CaptureHandler = struct { fn testReader(state: *State, res: *TestResponse, data: []const u8) !void { var status: u16 = 0; - var r = Reader.init(state); + var keepalive = false; + var r = Reader.init(state, &keepalive); // dupe it so that we have a mutable copy const owned = try testing.allocator.dupe(u8, data); @@ -2515,7 +2863,6 @@ fn testReader(state: *State, res: *TestResponse, data: []const u8) !void { if (result.done) { res.status = status; res.headers = r.response.headers; - res.keepalive = r.response.keepalive; return; } to_process = result.unprocessed orelse break; diff --git a/src/main.zig b/src/main.zig index c6114d7a..c6c163d5 100644 --- a/src/main.zig +++ b/src/main.zig @@ -470,51 +470,57 @@ fn serveHTTP(address: std.net.Address) !void { defer conn.stream.close(); var http_server = std.http.Server.init(conn, &read_buffer); - while (http_server.state == .ready) { - var request = http_server.receiveHead() catch |err| switch (err) { - error.HttpConnectionClosing => continue :ACCEPT, - else => { - std.debug.print("Test HTTP Server error: {}\n", .{err}); - return err; + var request = http_server.receiveHead() catch |err| switch (err) { + error.HttpConnectionClosing => continue :ACCEPT, + else => { + std.debug.print("Test HTTP Server error: {}\n", .{err}); + return err; + }, + }; + + const path = request.head.target; + if (std.mem.eql(u8, path, "/loader")) { + try request.respond("Hello!", .{ + .extra_headers = &.{.{ .name = "Connection", .value = "close" }}, + }); + } else if (std.mem.eql(u8, path, "/http_client/simple")) { + try request.respond("", .{ + .extra_headers = &.{.{ .name = "Connection", .value = "close" }}, + }); + } else if (std.mem.eql(u8, path, "/http_client/redirect")) { + try request.respond("", .{ + .status = .moved_permanently, + .extra_headers = &.{ + .{ .name = "Connection", .value = "close" }, + .{ .name = "LOCATION", .value = "../http_client/echo" }, }, - }; + }); + } else if (std.mem.eql(u8, path, "/http_client/redirect/secure")) { + try request.respond("", .{ + .status = .moved_permanently, + .extra_headers = &.{ .{ .name = "Connection", .value = "close" }, .{ .name = "LOCATION", .value = "https://127.0.0.1:9581/http_client/body" } }, + }); + } else if (std.mem.eql(u8, path, "/http_client/gzip")) { + const body = &.{ 0x1f, 0x8b, 0x08, 0x08, 0x01, 0xc6, 0x19, 0x68, 0x00, 0x03, 0x74, 0x65, 0x73, 0x74, 0x2e, 0x68, 0x74, 0x6d, 0x6c, 0x00, 0x73, 0x54, 0xc8, 0x4b, 0x2d, 0x57, 0x48, 0x2a, 0xca, 0x2f, 0x2f, 0x4e, 0x2d, 0x52, 0x48, 0x2a, 0xcd, 0xcc, 0x29, 0x51, 0x48, 0xcb, 0x2f, 0x52, 0xc8, 0x4d, 0x4c, 0xce, 0xc8, 0xcc, 0x4b, 0x2d, 0xe6, 0x02, 0x00, 0xe7, 0xc3, 0x4b, 0x27, 0x21, 0x00, 0x00, 0x00 }; + try request.respond(body, .{ + .extra_headers = &.{ .{ .name = "Connection", .value = "close" }, .{ .name = "Content-Encoding", .value = "gzip" } }, + }); + } else if (std.mem.eql(u8, path, "/http_client/echo")) { + var headers: std.ArrayListUnmanaged(std.http.Header) = .{}; - const path = request.head.target; - if (std.mem.eql(u8, path, "/loader")) { - try request.respond("Hello!", .{}); - } else if (std.mem.eql(u8, path, "/http_client/simple")) { - try request.respond("", .{}); - } else if (std.mem.eql(u8, path, "/http_client/redirect")) { - try request.respond("", .{ - .status = .moved_permanently, - .extra_headers = &.{.{ .name = "LOCATION", .value = "../http_client/echo" }}, - }); - } else if (std.mem.eql(u8, path, "/http_client/redirect/secure")) { - try request.respond("", .{ - .status = .moved_permanently, - .extra_headers = &.{.{ .name = "LOCATION", .value = "https://127.0.0.1:9581/http_client/body" }}, - }); - } else if (std.mem.eql(u8, path, "/http_client/gzip")) { - const body = &.{ 0x1f, 0x8b, 0x08, 0x08, 0x01, 0xc6, 0x19, 0x68, 0x00, 0x03, 0x74, 0x65, 0x73, 0x74, 0x2e, 0x68, 0x74, 0x6d, 0x6c, 0x00, 0x73, 0x54, 0xc8, 0x4b, 0x2d, 0x57, 0x48, 0x2a, 0xca, 0x2f, 0x2f, 0x4e, 0x2d, 0x52, 0x48, 0x2a, 0xcd, 0xcc, 0x29, 0x51, 0x48, 0xcb, 0x2f, 0x52, 0xc8, 0x4d, 0x4c, 0xce, 0xc8, 0xcc, 0x4b, 0x2d, 0xe6, 0x02, 0x00, 0xe7, 0xc3, 0x4b, 0x27, 0x21, 0x00, 0x00, 0x00 }; - try request.respond(body, .{ - .extra_headers = &.{.{ .name = "Content-Encoding", .value = "gzip" }}, - }); - } else if (std.mem.eql(u8, path, "/http_client/echo")) { - var headers: std.ArrayListUnmanaged(std.http.Header) = .{}; - - var it = request.iterateHeaders(); - while (it.next()) |hdr| { - try headers.append(aa, .{ - .name = try std.fmt.allocPrint(aa, "_{s}", .{hdr.name}), - .value = hdr.value, - }); - } - - try request.respond("over 9000!", .{ - .status = .created, - .extra_headers = headers.items, + var it = request.iterateHeaders(); + while (it.next()) |hdr| { + try headers.append(aa, .{ + .name = try std.fmt.allocPrint(aa, "_{s}", .{hdr.name}), + .value = hdr.value, }); } + try headers.append(aa, .{ .name = "Connection", .value = "Close" }); + + try request.respond("over 9000!", .{ + .status = .created, + .extra_headers = headers.items, + }); } } } @@ -528,9 +534,6 @@ fn serveHTTPS(address: std.net.Address) !void { var listener = try address.listen(.{ .reuse_address = true }); defer listener.deinit(); - var arena = std.heap.ArenaAllocator.init(std.testing.allocator); - defer arena.deinit(); - test_wg.finish(); var seed: u64 = undefined; @@ -540,9 +543,6 @@ fn serveHTTPS(address: std.net.Address) !void { var read_buffer: [1024]u8 = undefined; while (true) { - // defer _ = arena.reset(.{ .retain_with_limit = 1024 }); - // const aa = arena.allocator(); - const stream = blk: { const conn = try listener.accept(); break :blk conn.stream; @@ -570,17 +570,17 @@ fn serveHTTPS(address: std.net.Address) !void { var response: []const u8 = undefined; if (std.mem.eql(u8, path, "/http_client/simple")) { fragment = true; - response = "HTTP/1.1 200 \r\nContent-Length: 0\r\n\r\n"; + response = "HTTP/1.1 200 \r\nContent-Length: 0\r\nConnection: Close\r\n\r\n"; } else if (std.mem.eql(u8, path, "/http_client/body")) { fragment = true; - response = "HTTP/1.1 201 CREATED\r\nContent-Length: 20\r\n Another : HEaDer \r\n\r\n1234567890abcdefhijk"; + response = "HTTP/1.1 201 CREATED\r\nContent-Length: 20\r\nConnection: Close\r\n Another : HEaDer \r\n\r\n1234567890abcdefhijk"; } else if (std.mem.eql(u8, path, "/http_client/redirect/insecure")) { fragment = true; - response = "HTTP/1.1 307 GOTO\r\nLocation: http://127.0.0.1:9582/http_client/redirect\r\n\r\n"; + response = "HTTP/1.1 307 GOTO\r\nLocation: http://127.0.0.1:9582/http_client/redirect\r\nConnection: Close\r\n\r\n"; } else if (std.mem.eql(u8, path, "/xhr")) { - response = "HTTP/1.1 200 OK\r\nContent-Type: text/html; charset=utf-8\r\nContent-Length: 100\r\n\r\n" ++ ("1234567890" ** 10); + response = "HTTP/1.1 200 OK\r\nContent-Type: text/html; charset=utf-8\r\nContent-Length: 100\r\nConnection: Close\r\n\r\n" ++ ("1234567890" ** 10); } else if (std.mem.eql(u8, path, "/xhr/json")) { - response = "HTTP/1.1 200 OK\r\nContent-Type: application/json\r\nContent-Length: 18\r\n\r\n{\"over\":\"9000!!!\"}"; + response = "HTTP/1.1 200 OK\r\nContent-Type: application/json\r\nContent-Length: 18\r\nConnection: Close\r\n\r\n{\"over\":\"9000!!!\"}"; } else { // should not have an unknown path unreachable;