diff --git a/src/mcp/tools.zig b/src/mcp/tools.zig index f4a82570..4a3dbfc0 100644 --- a/src/mcp/tools.zig +++ b/src/mcp/tools.zig @@ -325,7 +325,7 @@ fn handleGoto(server: *Server, arena: std.mem.Allocator, id: std.json.Value, arg } fn handleMarkdown(server: *Server, arena: std.mem.Allocator, id: std.json.Value, arguments: ?std.json.Value) !void { - const args = parseArgsOrDefault(UrlParams, arena, arguments); + const args = try parseArgsOrDefault(UrlParams, arena, arguments, server, id); const page = try ensurePage(server, id, args.url); const content = [_]protocol.TextContent(ToolStreamingText){.{ @@ -335,7 +335,7 @@ fn handleMarkdown(server: *Server, arena: std.mem.Allocator, id: std.json.Value, } fn handleLinks(server: *Server, arena: std.mem.Allocator, id: std.json.Value, arguments: ?std.json.Value) !void { - const args = parseArgsOrDefault(UrlParams, arena, arguments); + const args = try parseArgsOrDefault(UrlParams, arena, arguments, server, id); const page = try ensurePage(server, id, args.url); const content = [_]protocol.TextContent(ToolStreamingText){.{ @@ -350,7 +350,7 @@ fn handleSemanticTree(server: *Server, arena: std.mem.Allocator, id: std.json.Va backendNodeId: ?u32 = null, maxDepth: ?u32 = null, }; - const args = parseArgsOrDefault(TreeParams, arena, arguments); + const args = try parseArgsOrDefault(TreeParams, arena, arguments, server, id); const page = try ensurePage(server, id, args.url); const content = [_]protocol.TextContent(ToolStreamingText){.{ @@ -367,7 +367,7 @@ fn handleSemanticTree(server: *Server, arena: std.mem.Allocator, id: std.json.Va } fn handleInteractiveElements(server: *Server, arena: std.mem.Allocator, id: std.json.Value, arguments: ?std.json.Value) !void { - const args = parseArgsOrDefault(UrlParams, arena, arguments); + const args = try parseArgsOrDefault(UrlParams, arena, arguments, server, id); const page = try ensurePage(server, id, args.url); const elements = lp.interactive.collectInteractiveElements(page.document.asNode(), arena, page) catch |err| { @@ -388,7 +388,7 @@ fn handleInteractiveElements(server: *Server, arena: std.mem.Allocator, id: std. } fn handleStructuredData(server: *Server, arena: std.mem.Allocator, id: std.json.Value, arguments: ?std.json.Value) !void { - const args = parseArgsOrDefault(UrlParams, arena, arguments); + const args = try parseArgsOrDefault(UrlParams, arena, arguments, server, id); const page = try ensurePage(server, id, args.url); const data = lp.structured_data.collectStructuredData(page.document.asNode(), arena, page) catch |err| { @@ -403,7 +403,7 @@ fn handleStructuredData(server: *Server, arena: std.mem.Allocator, id: std.json. } fn handleDetectForms(server: *Server, arena: std.mem.Allocator, id: std.json.Value, arguments: ?std.json.Value) !void { - const args = parseArgsOrDefault(UrlParams, arena, arguments); + const args = try parseArgsOrDefault(UrlParams, arena, arguments, server, id); const page = try ensurePage(server, id, args.url); const forms_data = lp.forms.collectForms(arena, page.document.asNode(), page) catch |err| { @@ -592,13 +592,15 @@ fn ensurePage(server: *Server, id: std.json.Value, url: ?[:0]const u8) !*lp.Page } /// Parses JSON arguments into a given struct type `T`. -/// If the arguments are missing or invalid, it returns a default-initialized `T` (e.g., `.{}`). -/// Use this for tools where all arguments are optional and validation failures should be silently ignored. -fn parseArgsOrDefault(comptime T: type, arena: std.mem.Allocator, arguments: ?std.json.Value) T { - if (arguments) |args_raw| { - return std.json.parseFromValueLeaky(T, arena, args_raw, .{ .ignore_unknown_fields = true }) catch .{}; - } - return .{}; +/// If the arguments are missing, it returns a default-initialized `T` (e.g., `.{}`). +/// If the arguments are present but invalid, it sends an MCP error response and returns `error.InvalidParams`. +/// Use this for tools where all arguments are optional. +fn parseArgsOrDefault(comptime T: type, arena: std.mem.Allocator, arguments: ?std.json.Value, server: *Server, id: std.json.Value) !T { + const args_raw = arguments orelse return .{}; + return std.json.parseFromValueLeaky(T, arena, args_raw, .{ .ignore_unknown_fields = true }) catch { + try server.sendError(id, .InvalidParams, "Invalid arguments"); + return error.InvalidParams; + }; } /// Parses JSON arguments into a given struct type `T`. @@ -606,11 +608,11 @@ fn parseArgsOrDefault(comptime T: type, arena: std.mem.Allocator, arguments: ?st /// and returns an `error.InvalidParams`. /// Use this for tools that require strict validation or mandatory arguments. fn parseArgs(comptime T: type, arena: std.mem.Allocator, arguments: ?std.json.Value, server: *Server, id: std.json.Value, tool_name: []const u8) !T { - if (arguments == null) { + const args_raw = arguments orelse { try server.sendError(id, .InvalidParams, "Missing arguments"); return error.InvalidParams; - } - return std.json.parseFromValueLeaky(T, arena, arguments.?, .{ .ignore_unknown_fields = true }) catch { + }; + return std.json.parseFromValueLeaky(T, arena, args_raw, .{ .ignore_unknown_fields = true }) catch { const msg = std.fmt.allocPrint(arena, "Invalid arguments for {s}", .{tool_name}) catch "Invalid arguments"; try server.sendError(id, .InvalidParams, msg); return error.InvalidParams;