From 34d2fc15035e6af710960fa60fa42e38c5afd314 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adri=C3=A0=20Arrufat?= Date: Wed, 25 Feb 2026 23:14:06 +0900 Subject: [PATCH] mcp: support notifications and improve error handling Make Request id optional for JSON-RPC notifications and handle the initialized event. Improve thread safety, logging, and error paths. --- src/mcp/Server.zig | 6 ++- src/mcp/protocol.zig | 2 +- src/mcp/protocol_tests.zig | 4 +- src/mcp/resources.zig | 20 ++++++---- src/mcp/router.zig | 14 +++++-- src/mcp/tools.zig | 80 ++++++++++++++++++++++---------------- 6 files changed, 76 insertions(+), 50 deletions(-) diff --git a/src/mcp/Server.zig b/src/mcp/Server.zig index 200dc8a5..8ddbdaca 100644 --- a/src/mcp/Server.zig +++ b/src/mcp/Server.zig @@ -75,7 +75,9 @@ pub const McpServer = struct { pub fn stop(self: *Self) void { self.is_running.store(false, .seq_cst); + self.queue_mutex.lock(); self.queue_condition.signal(); + self.queue_mutex.unlock(); } fn ioWorker(self: *Self) void { @@ -93,7 +95,7 @@ pub const McpServer = struct { self.queue_mutex.lock(); self.message_queue.append(self.allocator, msg) catch |err| { - std.debug.print("MCP Error: Failed to queue message: {}\n", .{err}); + lp.log.err(.app, "MCP Error: Failed to queue message", .{ .err = err }); self.allocator.free(msg); }; self.queue_mutex.unlock(); @@ -103,7 +105,7 @@ pub const McpServer = struct { self.stop(); break; } - std.debug.print("MCP IO Error: {}\n", .{err}); + lp.log.err(.app, "MCP IO Error", .{ .err = err }); std.Thread.sleep(100 * std.time.ns_per_ms); } } diff --git a/src/mcp/protocol.zig b/src/mcp/protocol.zig index 7a759710..88e281e3 100644 --- a/src/mcp/protocol.zig +++ b/src/mcp/protocol.zig @@ -2,7 +2,7 @@ const std = @import("std"); pub const Request = struct { jsonrpc: []const u8 = "2.0", - id: std.json.Value, + id: ?std.json.Value = null, method: []const u8, params: ?std.json.Value = null, }; diff --git a/src/mcp/protocol_tests.zig b/src/mcp/protocol_tests.zig index 64dd2d1c..0a5fca0a 100644 --- a/src/mcp/protocol_tests.zig +++ b/src/mcp/protocol_tests.zig @@ -25,8 +25,8 @@ test "protocol request parsing" { const req = parsed.value; try testing.expectEqualStrings("2.0", req.jsonrpc); try testing.expectEqualStrings("initialize", req.method); - try testing.expect(req.id == .integer); - try testing.expectEqual(@as(i64, 1), req.id.integer); + try testing.expect(req.id.? == .integer); + try testing.expectEqual(@as(i64, 1), req.id.?.integer); try testing.expect(req.params != null); // Test nested parsing of InitializeParams diff --git a/src/mcp/resources.zig b/src/mcp/resources.zig index 781475ac..d5ff2d71 100644 --- a/src/mcp/resources.zig +++ b/src/mcp/resources.zig @@ -25,7 +25,7 @@ pub fn handleList(server: *McpServer, req: protocol.Request) !void { .resources = &resources, }; - try sendResult(server, req.id, result); + try sendResult(server, req.id.?, result); } const ReadParams = struct { @@ -34,16 +34,18 @@ const ReadParams = struct { pub fn handleRead(server: *McpServer, arena: std.mem.Allocator, req: protocol.Request) !void { if (req.params == null) { - return sendError(server, req.id, -32602, "Missing params"); + return sendError(server, req.id.?, -32602, "Missing params"); } const params = std.json.parseFromValueLeaky(ReadParams, arena, req.params.?, .{}) catch { - return sendError(server, req.id, -32602, "Invalid params"); + return sendError(server, req.id.?, -32602, "Invalid params"); }; if (std.mem.eql(u8, params.uri, "mcp://page/html")) { var aw = std.Io.Writer.Allocating.init(arena); - try lp.dump.root(server.page.window._document, .{}, &aw.writer, server.page); + lp.dump.root(server.page.document.asNode(), .{}, &aw.writer, server.page) catch { + return sendError(server, req.id.?, -32603, "Internal error reading HTML"); + }; const contents = [_]struct { uri: []const u8, @@ -54,10 +56,12 @@ pub fn handleRead(server: *McpServer, arena: std.mem.Allocator, req: protocol.Re .mimeType = "text/html", .text = aw.written(), }}; - try sendResult(server, req.id, .{ .contents = &contents }); + try sendResult(server, req.id.?, .{ .contents = &contents }); } else if (std.mem.eql(u8, params.uri, "mcp://page/markdown")) { var aw = std.Io.Writer.Allocating.init(arena); - try lp.markdown.dump(server.page.window._document.asNode(), .{}, &aw.writer, server.page); + lp.markdown.dump(server.page.document.asNode(), .{}, &aw.writer, server.page) catch { + return sendError(server, req.id.?, -32603, "Internal error reading Markdown"); + }; const contents = [_]struct { uri: []const u8, @@ -68,9 +72,9 @@ pub fn handleRead(server: *McpServer, arena: std.mem.Allocator, req: protocol.Re .mimeType = "text/markdown", .text = aw.written(), }}; - try sendResult(server, req.id, .{ .contents = &contents }); + try sendResult(server, req.id.?, .{ .contents = &contents }); } else { - return sendError(server, req.id, -32602, "Resource not found"); + return sendError(server, req.id.?, -32602, "Resource not found"); } } diff --git a/src/mcp/router.zig b/src/mcp/router.zig index 3b6f3fe8..afe496b1 100644 --- a/src/mcp/router.zig +++ b/src/mcp/router.zig @@ -31,6 +31,14 @@ fn handleMessage(server: *McpServer, arena: std.mem.Allocator, msg: []const u8) return; }; + if (parsed.id == null) { + // It's a notification + if (std.mem.eql(u8, parsed.method, "notifications/initialized")) { + log.info(.app, "MCP Client Initialized", .{}); + } + return; + } + if (std.mem.eql(u8, parsed.method, "initialize")) { try handleInitialize(server, parsed); } else if (std.mem.eql(u8, parsed.method, "resources/list")) { @@ -38,12 +46,12 @@ fn handleMessage(server: *McpServer, arena: std.mem.Allocator, msg: []const u8) } else if (std.mem.eql(u8, parsed.method, "resources/read")) { try resources.handleRead(server, arena, parsed); } else if (std.mem.eql(u8, parsed.method, "tools/list")) { - try tools.handleList(server, parsed); + try tools.handleList(server, arena, parsed); } else if (std.mem.eql(u8, parsed.method, "tools/call")) { try tools.handleCall(server, arena, parsed); } else { try server.sendResponse(protocol.Response{ - .id = parsed.id, + .id = parsed.id.?, .@"error" = protocol.Error{ .code = -32601, .message = "Method not found", @@ -78,5 +86,5 @@ fn handleInitialize(server: *McpServer, req: protocol.Request) !void { }, }; - try sendResponseGeneric(server, req.id, result); + try sendResponseGeneric(server, req.id.?, result); } diff --git a/src/mcp/tools.zig b/src/mcp/tools.zig index f0605d39..e4a314de 100644 --- a/src/mcp/tools.zig +++ b/src/mcp/tools.zig @@ -10,12 +10,12 @@ const Element = @import("../browser/webapi/Element.zig"); const Selector = @import("../browser/webapi/selector/Selector.zig"); const String = @import("../string.zig").String; -pub fn handleList(server: *McpServer, req: protocol.Request) !void { +pub fn handleList(server: *McpServer, arena: std.mem.Allocator, req: protocol.Request) !void { const tools = [_]protocol.Tool{ .{ .name = "goto", .description = "Navigate to a specified URL and load the page in memory so it can be reused later for info extraction.", - .inputSchema = std.json.parseFromSliceLeaky(std.json.Value, server.allocator, + .inputSchema = std.json.parseFromSliceLeaky(std.json.Value, arena, \\{ \\ "type": "object", \\ "properties": { @@ -28,7 +28,7 @@ pub fn handleList(server: *McpServer, req: protocol.Request) !void { .{ .name = "search", .description = "Use a search engine to look for specific words, terms, sentences. The search page will then be loaded in memory.", - .inputSchema = std.json.parseFromSliceLeaky(std.json.Value, server.allocator, + .inputSchema = std.json.parseFromSliceLeaky(std.json.Value, arena, \\{ \\ "type": "object", \\ "properties": { @@ -41,17 +41,17 @@ pub fn handleList(server: *McpServer, req: protocol.Request) !void { .{ .name = "markdown", .description = "Get the page content in markdown format.", - .inputSchema = std.json.parseFromSliceLeaky(std.json.Value, server.allocator, "{\"type\":\"object\",\"properties\":{}}", .{}) catch unreachable, + .inputSchema = std.json.parseFromSliceLeaky(std.json.Value, arena, "{\"type\":\"object\",\"properties\":{}}", .{}) catch unreachable, }, .{ .name = "links", .description = "Extract all links in the opened page", - .inputSchema = std.json.parseFromSliceLeaky(std.json.Value, server.allocator, "{\"type\":\"object\",\"properties\":{}}", .{}) catch unreachable, + .inputSchema = std.json.parseFromSliceLeaky(std.json.Value, arena, "{\"type\":\"object\",\"properties\":{}}", .{}) catch unreachable, }, .{ .name = "evaluate", .description = "Evaluate JavaScript in the current page context", - .inputSchema = std.json.parseFromSliceLeaky(std.json.Value, server.allocator, + .inputSchema = std.json.parseFromSliceLeaky(std.json.Value, arena, \\{ \\ "type": "object", \\ "properties": { @@ -64,7 +64,7 @@ pub fn handleList(server: *McpServer, req: protocol.Request) !void { .{ .name = "over", .description = "Used to indicate that the task is over and give the final answer if there is any. This is the last tool to be called in a task.", - .inputSchema = std.json.parseFromSliceLeaky(std.json.Value, server.allocator, + .inputSchema = std.json.parseFromSliceLeaky(std.json.Value, arena, \\{ \\ "type": "object", \\ "properties": { @@ -82,7 +82,7 @@ pub fn handleList(server: *McpServer, req: protocol.Request) !void { .tools = &tools, }; - try sendResult(server, req.id, result); + try sendResult(server, req.id.?, result); } const GotoParams = struct { @@ -103,7 +103,7 @@ const OverParams = struct { pub fn handleCall(server: *McpServer, arena: std.mem.Allocator, req: protocol.Request) !void { if (req.params == null) { - return sendError(server, req.id, -32602, "Missing params"); + return sendError(server, req.id.?, -32602, "Missing params"); } const CallParams = struct { @@ -112,67 +112,79 @@ pub fn handleCall(server: *McpServer, arena: std.mem.Allocator, req: protocol.Re }; const call_params = std.json.parseFromValueLeaky(CallParams, arena, req.params.?, .{}) catch { - return sendError(server, req.id, -32602, "Invalid params"); + return sendError(server, req.id.?, -32602, "Invalid params"); }; if (std.mem.eql(u8, call_params.name, "goto") or std.mem.eql(u8, call_params.name, "navigate")) { if (call_params.arguments == null) { - return sendError(server, req.id, -32602, "Missing arguments for goto"); + return sendError(server, req.id.?, -32602, "Missing arguments for goto"); } const args = std.json.parseFromValueLeaky(GotoParams, arena, call_params.arguments.?, .{}) catch { - return sendError(server, req.id, -32602, "Invalid arguments for goto"); + return sendError(server, req.id.?, -32602, "Invalid arguments for goto"); }; - try performGoto(server, arena, args.url); + performGoto(server, arena, args.url) catch { + return sendError(server, req.id.?, -32603, "Internal error during navigation"); + }; const content = [_]struct { type: []const u8, text: []const u8 }{.{ .type = "text", .text = "Navigated successfully." }}; - try sendResult(server, req.id, .{ .content = &content }); + try sendResult(server, req.id.?, .{ .content = &content }); } else if (std.mem.eql(u8, call_params.name, "search")) { if (call_params.arguments == null) { - return sendError(server, req.id, -32602, "Missing arguments for search"); + return sendError(server, req.id.?, -32602, "Missing arguments for search"); } const args = std.json.parseFromValueLeaky(SearchParams, arena, call_params.arguments.?, .{}) catch { - return sendError(server, req.id, -32602, "Invalid arguments for search"); + return sendError(server, req.id.?, -32602, "Invalid arguments for search"); }; const component: std.Uri.Component = .{ .raw = args.text }; var url_aw = std.Io.Writer.Allocating.init(arena); - try component.formatQuery(&url_aw.writer); - const url = try std.fmt.allocPrint(arena, "https://duckduckgo.com/?q={s}", .{url_aw.written()}); + component.formatQuery(&url_aw.writer) catch { + return sendError(server, req.id.?, -32603, "Internal error formatting query"); + }; + const url = std.fmt.allocPrint(arena, "https://duckduckgo.com/?q={s}", .{url_aw.written()}) catch { + return sendError(server, req.id.?, -32603, "Internal error formatting URL"); + }; - try performGoto(server, arena, url); + performGoto(server, arena, url) catch { + return sendError(server, req.id.?, -32603, "Internal error during search navigation"); + }; const content = [_]struct { type: []const u8, text: []const u8 }{.{ .type = "text", .text = "Search performed successfully." }}; - try sendResult(server, req.id, .{ .content = &content }); + try sendResult(server, req.id.?, .{ .content = &content }); } else if (std.mem.eql(u8, call_params.name, "markdown")) { var aw = std.Io.Writer.Allocating.init(arena); - try lp.markdown.dump(server.page.document.asNode(), .{}, &aw.writer, server.page); + lp.markdown.dump(server.page.document.asNode(), .{}, &aw.writer, server.page) catch { + return sendError(server, req.id.?, -32603, "Internal error parsing markdown"); + }; const content = [_]struct { type: []const u8, text: []const u8 }{.{ .type = "text", .text = aw.written() }}; - try sendResult(server, req.id, .{ .content = &content }); + try sendResult(server, req.id.?, .{ .content = &content }); } else if (std.mem.eql(u8, call_params.name, "links")) { - const list = try Selector.querySelectorAll(server.page.document.asNode(), "a[href]", server.page); + const list = Selector.querySelectorAll(server.page.document.asNode(), "a[href]", server.page) catch { + return sendError(server, req.id.?, -32603, "Internal error querying selector"); + }; var aw = std.Io.Writer.Allocating.init(arena); var first = true; for (list._nodes) |node| { if (node.is(Element)) |el| { if (el.getAttributeSafe(String.wrap("href"))) |href| { - if (!first) try aw.writer.writeByte('\n'); - try aw.writer.writeAll(href); + if (!first) aw.writer.writeByte('\n') catch continue; + aw.writer.writeAll(href) catch continue; first = false; } } } const content = [_]struct { type: []const u8, text: []const u8 }{.{ .type = "text", .text = aw.written() }}; - try sendResult(server, req.id, .{ .content = &content }); + try sendResult(server, req.id.?, .{ .content = &content }); } else if (std.mem.eql(u8, call_params.name, "evaluate")) { if (call_params.arguments == null) { - return sendError(server, req.id, -32602, "Missing arguments for evaluate"); + return sendError(server, req.id.?, -32602, "Missing arguments for evaluate"); } const args = std.json.parseFromValueLeaky(EvaluateParams, arena, call_params.arguments.?, .{}) catch { - return sendError(server, req.id, -32602, "Invalid arguments for evaluate"); + return sendError(server, req.id.?, -32602, "Invalid arguments for evaluate"); }; var ls: js.Local.Scope = undefined; @@ -181,25 +193,25 @@ pub fn handleCall(server: *McpServer, arena: std.mem.Allocator, req: protocol.Re const js_result = ls.local.compileAndRun(args.script, null) catch { const content = [_]struct { type: []const u8, text: []const u8 }{.{ .type = "text", .text = "Script evaluation failed." }}; - return sendResult(server, req.id, .{ .content = &content, .isError = true }); + return sendResult(server, req.id.?, .{ .content = &content, .isError = true }); }; const str_result = js_result.toStringSliceWithAlloc(arena) catch "undefined"; const content = [_]struct { type: []const u8, text: []const u8 }{.{ .type = "text", .text = str_result }}; - try sendResult(server, req.id, .{ .content = &content }); + try sendResult(server, req.id.?, .{ .content = &content }); } else if (std.mem.eql(u8, call_params.name, "over")) { if (call_params.arguments == null) { - return sendError(server, req.id, -32602, "Missing arguments for over"); + return sendError(server, req.id.?, -32602, "Missing arguments for over"); } const args = std.json.parseFromValueLeaky(OverParams, arena, call_params.arguments.?, .{}) catch { - return sendError(server, req.id, -32602, "Invalid arguments for over"); + return sendError(server, req.id.?, -32602, "Invalid arguments for over"); }; const content = [_]struct { type: []const u8, text: []const u8 }{.{ .type = "text", .text = args.result }}; - try sendResult(server, req.id, .{ .content = &content }); + try sendResult(server, req.id.?, .{ .content = &content }); } else { - return sendError(server, req.id, -32601, "Tool not found"); + return sendError(server, req.id.?, -32601, "Tool not found"); } }