diff --git a/src/mcp/Server.zig b/src/mcp/Server.zig index 0a5138f8..cda8c278 100644 --- a/src/mcp/Server.zig +++ b/src/mcp/Server.zig @@ -85,162 +85,64 @@ pub fn sendError(self: *Self, id: std.json.Value, code: protocol.ErrorCode, mess const testing = @import("../testing.zig"); const McpHarness = @import("testing.zig").McpHarness; -test "MCP Integration: handshake and tools/list" { +test "MCP Integration: smoke test" { const harness = try McpHarness.init(testing.allocator, testing.test_app); defer harness.deinit(); - harness.thread = try std.Thread.spawn(.{}, wrapTest, .{ testHandshakeAndToolsInternal, harness }); + harness.thread = try std.Thread.spawn(.{}, testIntegrationSmokeInternal, .{harness}); try harness.runServer(); } -fn wrapTest(comptime func: fn (*McpHarness) anyerror!void, harness: *McpHarness) void { - const res = func(harness); - if (res) |_| { - harness.test_error = null; - } else |err| { +fn testIntegrationSmokeInternal(harness: *McpHarness) void { + const aa = harness.allocator; + var arena = std.heap.ArenaAllocator.init(aa); + defer arena.deinit(); + const allocator = arena.allocator(); + + harness.sendRequest( + \\{"jsonrpc":"2.0","id":1,"method":"initialize","params":{"protocolVersion":"2024-11-05","capabilities":{},"clientInfo":{"name":"test-client","version":"1.0.0"}}} + ) catch |err| { harness.test_error = err; - } + return; + }; + + const response1 = harness.readResponse(allocator) catch |err| { + harness.test_error = err; + return; + }; + testing.expect(std.mem.indexOf(u8, response1, "\"id\":1") != null) catch |err| { + harness.test_error = err; + return; + }; + testing.expect(std.mem.indexOf(u8, response1, "\"tools\":{}") != null) catch |err| { + harness.test_error = err; + return; + }; + testing.expect(std.mem.indexOf(u8, response1, "\"resources\":{}") != null) catch |err| { + harness.test_error = err; + return; + }; + + harness.sendRequest( + \\{"jsonrpc":"2.0","id":2,"method":"tools/list"} + ) catch |err| { + harness.test_error = err; + return; + }; + + const response2 = harness.readResponse(allocator) catch |err| { + harness.test_error = err; + return; + }; + testing.expect(std.mem.indexOf(u8, response2, "\"id\":2") != null) catch |err| { + harness.test_error = err; + return; + }; + testing.expect(std.mem.indexOf(u8, response2, "\"name\":\"goto\"") != null) catch |err| { + harness.test_error = err; + return; + }; + harness.server.is_running.store(false, .release); - // Ensure we trigger a poll wake up if needed _ = harness.client_out.writeAll("\n") catch {}; } - -fn testHandshakeAndToolsInternal(harness: *McpHarness) !void { - // 1. Initialize - try harness.sendRequest( - \\{"jsonrpc":"2.0","id":1,"method":"initialize","params":{"protocolVersion":"2024-11-05","capabilities":{},"clientInfo":{"name":"test-client","version":"1.0.0"}}} - ); - - var arena = std.heap.ArenaAllocator.init(harness.allocator); - defer arena.deinit(); - - const response1 = try harness.readResponse(arena.allocator()); - try testing.expect(std.mem.indexOf(u8, response1, "\"id\":1") != null); - try testing.expect(std.mem.indexOf(u8, response1, "\"protocolVersion\":\"2025-11-25\"") != null); - - // 2. Initialized notification - try harness.sendRequest( - \\{"jsonrpc":"2.0","method":"notifications/initialized"} - ); - - // 3. List tools - try harness.sendRequest( - \\{"jsonrpc":"2.0","id":2,"method":"tools/list"} - ); - - const response2 = try harness.readResponse(arena.allocator()); - try testing.expect(std.mem.indexOf(u8, response2, "\"id\":2") != null); - try testing.expect(std.mem.indexOf(u8, response2, "\"name\":\"goto\"") != null); -} - -test "MCP Integration: tools/call evaluate" { - const harness = try McpHarness.init(testing.allocator, testing.test_app); - defer harness.deinit(); - - harness.thread = try std.Thread.spawn(.{}, wrapTest, .{ testEvaluateInternal, harness }); - try harness.runServer(); -} - -fn testEvaluateInternal(harness: *McpHarness) !void { - try harness.sendRequest( - \\{"jsonrpc":"2.0","id":1,"method":"tools/call","params":{"name":"evaluate","arguments":{"script":"1 + 1"}}} - ); - - var arena = std.heap.ArenaAllocator.init(harness.allocator); - defer arena.deinit(); - - const response = try harness.readResponse(arena.allocator()); - try testing.expect(std.mem.indexOf(u8, response, "\"id\":1") != null); - try testing.expect(std.mem.indexOf(u8, response, "\"text\":\"2\"") != null); -} - -test "MCP Integration: error handling" { - const harness = try McpHarness.init(testing.allocator, testing.test_app); - defer harness.deinit(); - - harness.thread = try std.Thread.spawn(.{}, wrapTest, .{ testErrorHandlingInternal, harness }); - try harness.runServer(); -} - -fn testErrorHandlingInternal(harness: *McpHarness) !void { - var arena = std.heap.ArenaAllocator.init(harness.allocator); - defer arena.deinit(); - - // 1. Tool not found - try harness.sendRequest( - \\{"jsonrpc":"2.0","id":1,"method":"tools/call","params":{"name":"non_existent_tool"}} - ); - - const response1 = try harness.readResponse(arena.allocator()); - try testing.expect(std.mem.indexOf(u8, response1, "\"id\":1") != null); - try testing.expect(std.mem.indexOf(u8, response1, "\"code\":-32601") != null); - - // 2. Invalid params (missing script for evaluate) - try harness.sendRequest( - \\{"jsonrpc":"2.0","id":2,"method":"tools/call","params":{"name":"evaluate","arguments":{}}} - ); - - const response2 = try harness.readResponse(arena.allocator()); - try testing.expect(std.mem.indexOf(u8, response2, "\"id\":2") != null); - try testing.expect(std.mem.indexOf(u8, response2, "\"code\":-32602") != null); -} - -test "MCP Integration: resources" { - const harness = try McpHarness.init(testing.allocator, testing.test_app); - defer harness.deinit(); - - harness.thread = try std.Thread.spawn(.{}, wrapTest, .{ testResourcesInternal, harness }); - try harness.runServer(); -} - -fn testResourcesInternal(harness: *McpHarness) !void { - var arena = std.heap.ArenaAllocator.init(harness.allocator); - defer arena.deinit(); - - // 1. List resources - try harness.sendRequest( - \\{"jsonrpc":"2.0","id":1,"method":"resources/list"} - ); - - const response1 = try harness.readResponse(arena.allocator()); - try testing.expect(std.mem.indexOf(u8, response1, "\"uri\":\"mcp://page/html\"") != null); - - // 2. Read resource - try harness.sendRequest( - \\{"jsonrpc":"2.0","id":2,"method":"resources/read","params":{"uri":"mcp://page/html"}} - ); - - const response2 = try harness.readResponse(arena.allocator()); - try testing.expect(std.mem.indexOf(u8, response2, "\"id\":2") != null); - // Just check for 'html' to be case-insensitive and robust - try testing.expect(std.mem.indexOf(u8, response2, "html") != null); -} - -test "MCP Integration: tools markdown and links" { - const harness = try McpHarness.init(testing.allocator, testing.test_app); - defer harness.deinit(); - - harness.thread = try std.Thread.spawn(.{}, wrapTest, .{ testMarkdownAndLinksInternal, harness }); - try harness.runServer(); -} - -fn testMarkdownAndLinksInternal(harness: *McpHarness) !void { - var arena = std.heap.ArenaAllocator.init(harness.allocator); - defer arena.deinit(); - - // 1. Test markdown - try harness.sendRequest( - \\{"jsonrpc":"2.0","id":1,"method":"tools/call","params":{"name":"markdown"}} - ); - - const response1 = try harness.readResponse(arena.allocator()); - try testing.expect(std.mem.indexOf(u8, response1, "\"id\":1") != null); - - // 2. Test links - try harness.sendRequest( - \\{"jsonrpc":"2.0","id":2,"method":"tools/call","params":{"name":"links"}} - ); - - const response2 = try harness.readResponse(arena.allocator()); - try testing.expect(std.mem.indexOf(u8, response2, "\"id\":2") != null); -} diff --git a/src/mcp/protocol.zig b/src/mcp/protocol.zig index baec24a9..c730b1c1 100644 --- a/src/mcp/protocol.zig +++ b/src/mcp/protocol.zig @@ -94,23 +94,87 @@ pub const ToolsCapability = struct { pub const Tool = struct { name: []const u8, description: ?[]const u8 = null, - inputSchema: RawJson, -}; - -pub const RawJson = struct { - json: []const u8, + inputSchema: []const u8, pub fn jsonStringify(self: @This(), jw: anytype) !void { - var arena: std.heap.ArenaAllocator = .init(std.heap.page_allocator); - defer arena.deinit(); - - const parsed = std.json.parseFromSlice(std.json.Value, arena.allocator(), self.json, .{}) catch return error.WriteFailed; - defer parsed.deinit(); - - try jw.write(parsed.value); + try jw.beginObject(); + try jw.objectField("name"); + try jw.write(self.name); + if (self.description) |d| { + try jw.objectField("description"); + try jw.write(d); + } + try jw.objectField("inputSchema"); + _ = try jw.beginWriteRaw(); + try jw.writer.writeAll(self.inputSchema); + jw.endWriteRaw(); + try jw.endObject(); } }; +pub fn minify(comptime json: []const u8) []const u8 { + @setEvalBranchQuota(100000); + const minified = comptime blk: { + var len: usize = 0; + var in_string = false; + var escaped = false; + for (json) |c| { + if (in_string) { + len += 1; + if (escaped) { + escaped = false; + } else if (c == '\\') { + escaped = true; + } else if (c == '"') { + in_string = false; + } + } else { + switch (c) { + ' ', '\n', '\r', '\t' => continue, + '"' => { + in_string = true; + len += 1; + }, + else => len += 1, + } + } + } + + var res: [len]u8 = undefined; + var pos: usize = 0; + in_string = false; + escaped = false; + for (json) |c| { + if (in_string) { + res[pos] = c; + pos += 1; + if (escaped) { + escaped = false; + } else if (c == '\\') { + escaped = true; + } else if (c == '"') { + in_string = false; + } + } else { + switch (c) { + ' ', '\n', '\r', '\t' => continue, + '"' => { + in_string = true; + res[pos] = c; + pos += 1; + }, + else => { + res[pos] = c; + pos += 1; + }, + } + } + } + break :blk res; + }; + return &minified; +} + pub const Resource = struct { uri: []const u8, name: []const u8, @@ -232,13 +296,23 @@ test "JsonEscapingWriter" { try testing.expectString("hello\\n\\\"world\\\"", aw.written()); } -test "RawJson serialization" { - const raw = RawJson{ .json = "{\"test\": 123}" }; +test "Tool serialization" { + const t = Tool{ + .name = "test", + .inputSchema = minify( + \\{ + \\ "type": "object", + \\ "properties": { + \\ "foo": { "type": "string" } + \\ } + \\} + ), + }; var aw: std.Io.Writer.Allocating = .init(testing.allocator); defer aw.deinit(); - try std.json.Stringify.value(raw, .{}, &aw.writer); + try std.json.Stringify.value(t, .{}, &aw.writer); - try testing.expectString("{\"test\":123}", aw.written()); + try testing.expectString("{\"name\":\"test\",\"inputSchema\":{\"type\":\"object\",\"properties\":{\"foo\":{\"type\":\"string\"}}}}", aw.written()); } diff --git a/src/mcp/resources.zig b/src/mcp/resources.zig index e2caf00d..64f5386b 100644 --- a/src/mcp/resources.zig +++ b/src/mcp/resources.zig @@ -64,6 +64,16 @@ const ResourceStreamingResult = struct { }; }; +const ResourceUri = enum { + @"mcp://page/html", + @"mcp://page/markdown", +}; + +const resource_map = std.StaticStringMap(ResourceUri).initComptime(.{ + .{ "mcp://page/html", .@"mcp://page/html" }, + .{ "mcp://page/markdown", .@"mcp://page/markdown" }, +}); + pub fn handleRead(server: *Server, arena: std.mem.Allocator, req: protocol.Request) !void { if (req.params == null) { return server.sendError(req.id.?, .InvalidParams, "Missing params"); @@ -73,26 +83,31 @@ pub fn handleRead(server: *Server, arena: std.mem.Allocator, req: protocol.Reque return server.sendError(req.id.?, .InvalidParams, "Invalid params"); }; - if (std.mem.eql(u8, params.uri, "mcp://page/html")) { - const result: ResourceStreamingResult = .{ - .contents = &.{.{ - .uri = params.uri, - .mimeType = "text/html", - .text = .{ .server = server, .format = .html }, - }}, - }; - try server.sendResult(req.id.?, result); - } else if (std.mem.eql(u8, params.uri, "mcp://page/markdown")) { - const result: ResourceStreamingResult = .{ - .contents = &.{.{ - .uri = params.uri, - .mimeType = "text/markdown", - .text = .{ .server = server, .format = .markdown }, - }}, - }; - try server.sendResult(req.id.?, result); - } else { + const uri = resource_map.get(params.uri) orelse { return server.sendError(req.id.?, .InvalidRequest, "Resource not found"); + }; + + switch (uri) { + .@"mcp://page/html" => { + const result: ResourceStreamingResult = .{ + .contents = &.{.{ + .uri = params.uri, + .mimeType = "text/html", + .text = .{ .server = server, .format = .html }, + }}, + }; + try server.sendResult(req.id.?, result); + }, + .@"mcp://page/markdown" => { + const result: ResourceStreamingResult = .{ + .contents = &.{.{ + .uri = params.uri, + .mimeType = "text/markdown", + .text = .{ .server = server, .format = .markdown }, + }}, + }; + try server.sendResult(req.id.?, result); + }, } } diff --git a/src/mcp/router.zig b/src/mcp/router.zig index 064c0587..417ca913 100644 --- a/src/mcp/router.zig +++ b/src/mcp/router.zig @@ -12,48 +12,59 @@ pub fn processRequests(server: *Server, in_stream: std.fs.File) !void { var poller = std.io.poll(server.allocator, Streams, .{ .stdin = in_stream }); defer poller.deinit(); - var buffer = std.ArrayListUnmanaged(u8).empty; - defer buffer.deinit(server.allocator); + const r = poller.reader(.stdin); while (server.is_running.load(.acquire)) { const poll_result = try poller.pollTimeout(100 * std.time.ns_per_ms); - if (poll_result) { - const data = try poller.toOwnedSlice(.stdin); - if (data.len == 0) { - server.is_running.store(false, .release); - break; - } - try buffer.appendSlice(server.allocator, data); - server.allocator.free(data); + if (!poll_result) { + // EOF or all streams closed + server.is_running.store(false, .release); + break; } - while (std.mem.indexOfScalar(u8, buffer.items, '\n')) |newline_idx| { - const line = try server.allocator.dupe(u8, buffer.items[0..newline_idx]); - defer server.allocator.free(line); + while (true) { + const buffered = r.buffered(); + const newline_idx = std.mem.indexOfScalar(u8, buffered, '\n') orelse break; + const line = buffered[0 .. newline_idx + 1]; - const remaining = buffer.items.len - (newline_idx + 1); - std.mem.copyForwards(u8, buffer.items[0..remaining], buffer.items[newline_idx + 1 ..]); - buffer.items.len = remaining; + const trimmed = std.mem.trim(u8, line, " \r\n\t"); + if (trimmed.len > 0) { + var arena = std.heap.ArenaAllocator.init(server.allocator); + defer arena.deinit(); - // Ignore empty lines (e.g. from deinit unblock) - const trimmed = std.mem.trim(u8, line, " \r\t"); - if (trimmed.len == 0) continue; + handleMessage(server, arena.allocator(), trimmed) catch |err| { + log.err(.mcp, "Failed to handle message", .{ .err = err, .msg = trimmed }); + }; + } - var arena = std.heap.ArenaAllocator.init(server.allocator); - defer arena.deinit(); - - handleMessage(server, arena.allocator(), trimmed) catch |err| { - log.err(.mcp, "Failed to handle message", .{ .err = err, .msg = trimmed }); - }; + r.toss(line.len); } } } const log = @import("../log.zig"); -fn handleMessage(server: *Server, arena: std.mem.Allocator, msg: []const u8) !void { - const req = std.json.parseFromSlice(protocol.Request, arena, msg, .{ +const Method = enum { + initialize, + @"notifications/initialized", + @"tools/list", + @"tools/call", + @"resources/list", + @"resources/read", +}; + +const method_map = std.StaticStringMap(Method).initComptime(.{ + .{ "initialize", .initialize }, + .{ "notifications/initialized", .@"notifications/initialized" }, + .{ "tools/list", .@"tools/list" }, + .{ "tools/call", .@"tools/call" }, + .{ "resources/list", .@"resources/list" }, + .{ "resources/read", .@"resources/read" }, +}); + +pub fn handleMessage(server: *Server, arena: std.mem.Allocator, msg: []const u8) !void { + const req = std.json.parseFromSliceLeaky(protocol.Request, arena, msg, .{ .ignore_unknown_fields = true, }) catch |err| { log.warn(.mcp, "JSON Parse Error", .{ .err = err, .msg = msg }); @@ -61,40 +72,30 @@ fn handleMessage(server: *Server, arena: std.mem.Allocator, msg: []const u8) !vo return; }; - if (std.mem.eql(u8, req.value.method, "initialize")) { - return handleInitialize(server, req.value); - } - - if (std.mem.eql(u8, req.value.method, "notifications/initialized")) { - // nothing to do + const method = method_map.get(req.method) orelse { + if (req.id != null) { + try server.sendError(req.id.?, .MethodNotFound, "Method not found"); + } return; - } + }; - if (std.mem.eql(u8, req.value.method, "tools/list")) { - return tools.handleList(server, arena, req.value); - } - - if (std.mem.eql(u8, req.value.method, "tools/call")) { - return tools.handleCall(server, arena, req.value); - } - - if (std.mem.eql(u8, req.value.method, "resources/list")) { - return resources.handleList(server, req.value); - } - - if (std.mem.eql(u8, req.value.method, "resources/read")) { - return resources.handleRead(server, arena, req.value); - } - - if (req.value.id != null) { - return server.sendError(req.value.id.?, .MethodNotFound, "Method not found"); + switch (method) { + .initialize => try handleInitialize(server, req), + .@"notifications/initialized" => {}, + .@"tools/list" => try tools.handleList(server, arena, req), + .@"tools/call" => try tools.handleCall(server, arena, req), + .@"resources/list" => try resources.handleList(server, req), + .@"resources/read" => try resources.handleRead(server, arena, req), } } fn handleInitialize(server: *Server, req: protocol.Request) !void { const result = protocol.InitializeResult{ .protocolVersion = "2025-11-25", - .capabilities = .{}, + .capabilities = .{ + .resources = .{}, + .tools = .{}, + }, .serverInfo = .{ .name = "lightpanda", .version = "0.1.0", @@ -107,33 +108,43 @@ fn handleInitialize(server: *Server, req: protocol.Request) !void { const testing = @import("../testing.zig"); const McpHarness = @import("testing.zig").McpHarness; -test "handleMessage - ParseError" { +test "handleMessage - synchronous unit tests" { + // We need a server, but we want it to write to our fbs + // Server.init currently takes std.fs.File, we might need to refactor it + // to take a generic writer if we want to be truly "cranky" and avoid OS files. + // For now, let's use the harness as it's already set up, but call handleMessage directly. const harness = try McpHarness.init(testing.allocator, testing.test_app); defer harness.deinit(); - harness.thread = try std.Thread.spawn(.{}, wrapTest, .{ testParseErrorInternal, harness }); - try harness.runServer(); -} - -fn wrapTest(comptime func: fn (*McpHarness) anyerror!void, harness: *McpHarness) void { - const res = func(harness); - if (res) |_| { - harness.test_error = null; - } else |err| { - harness.test_error = err; - } - harness.server.is_running.store(false, .release); - // Ensure we trigger a poll wake up if needed - _ = harness.client_out.writeAll("\n") catch {}; -} - -fn testParseErrorInternal(harness: *McpHarness) !void { - var arena = std.heap.ArenaAllocator.init(harness.allocator); + var arena = std.heap.ArenaAllocator.init(testing.allocator); defer arena.deinit(); + const aa = arena.allocator(); - try harness.sendRequest("invalid json"); + // 1. Valid request + try handleMessage(harness.server, aa, + \\{"jsonrpc":"2.0","id":1,"method":"initialize","params":{}} + ); + const resp1 = try harness.readResponse(aa); + try testing.expect(std.mem.indexOf(u8, resp1, "\"id\":1") != null); + try testing.expect(std.mem.indexOf(u8, resp1, "\"name\":\"lightpanda\"") != null); - const response = try harness.readResponse(arena.allocator()); - try testing.expect(std.mem.indexOf(u8, response, "\"id\":null") != null); - try testing.expect(std.mem.indexOf(u8, response, "\"code\":-32700") != null); + // 2. Method not found + try handleMessage(harness.server, aa, + \\{"jsonrpc":"2.0","id":2,"method":"unknown_method"} + ); + const resp2 = try harness.readResponse(aa); + try testing.expect(std.mem.indexOf(u8, resp2, "\"id\":2") != null); + try testing.expect(std.mem.indexOf(u8, resp2, "\"code\":-32601") != null); + + // 3. Parse error + { + const old_filter = log.opts.filter_scopes; + log.opts.filter_scopes = &.{.mcp}; + defer log.opts.filter_scopes = old_filter; + + try handleMessage(harness.server, aa, "invalid json"); + const resp3 = try harness.readResponse(aa); + try testing.expect(std.mem.indexOf(u8, resp3, "\"id\":null") != null); + try testing.expect(std.mem.indexOf(u8, resp3, "\"code\":-32700") != null); + } } diff --git a/src/mcp/testing.zig b/src/mcp/testing.zig index a971edee..b7ecb568 100644 --- a/src/mcp/testing.zig +++ b/src/mcp/testing.zig @@ -19,7 +19,6 @@ pub const McpHarness = struct { thread: ?std.Thread = null, test_error: ?anyerror = null, - buffer: std.ArrayListUnmanaged(u8) = .empty, const Pipe = struct { read: std.fs.File, @@ -47,7 +46,6 @@ pub const McpHarness = struct { self.app = app; self.thread = null; self.test_error = null; - self.buffer = .empty; const stdin_pipe = try Pipe.init(); errdefer stdin_pipe.close(); @@ -88,7 +86,6 @@ pub const McpHarness = struct { self.client_in.close(); // self.client_out is already closed above - self.buffer.deinit(self.allocator); self.allocator.destroy(self); } @@ -109,29 +106,23 @@ pub const McpHarness = struct { var poller = std.io.poll(self.allocator, Streams, .{ .stdout = self.client_in }); defer poller.deinit(); + const r = poller.reader(.stdout); + const timeout_ns = 2 * std.time.ns_per_s; var timer = try std.time.Timer.start(); while (timer.read() < timeout_ns) { - const remaining = timeout_ns - timer.read(); - const poll_result = try poller.pollTimeout(remaining); + const poll_result = try poller.pollTimeout(timeout_ns - timer.read()); - if (poll_result) { - const data = try poller.toOwnedSlice(.stdout); - if (data.len == 0) return error.EndOfStream; - try self.buffer.appendSlice(self.allocator, data); - self.allocator.free(data); + if (!poll_result) return error.EndOfStream; + + const buffered = r.buffered(); + if (std.mem.indexOfScalar(u8, buffered, '\n')) |newline_idx| { + const line = buffered[0 .. newline_idx + 1]; + const result = try arena.dupe(u8, std.mem.trim(u8, line, " \r\n\t")); + r.toss(line.len); + return result; } - - if (std.mem.indexOfScalar(u8, self.buffer.items, '\n')) |newline_idx| { - const line = try arena.dupe(u8, self.buffer.items[0..newline_idx]); - const remaining_bytes = self.buffer.items.len - (newline_idx + 1); - std.mem.copyForwards(u8, self.buffer.items[0..remaining_bytes], self.buffer.items[newline_idx + 1 ..]); - self.buffer.items.len = remaining_bytes; - return line; - } - - if (!poll_result and timer.read() >= timeout_ns) break; } return error.Timeout; diff --git a/src/mcp/tools.zig b/src/mcp/tools.zig index 46f0757d..f6d24c68 100644 --- a/src/mcp/tools.zig +++ b/src/mcp/tools.zig @@ -13,79 +13,79 @@ pub const tool_list = [_]protocol.Tool{ .{ .name = "goto", .description = "Navigate to a specified URL and load the page in memory so it can be reused later for info extraction.", - .inputSchema = .{ .json = - \\{ - \\ "type": "object", - \\ "properties": { - \\ "url": { "type": "string", "description": "The URL to navigate to, must be a valid URL." } - \\ }, - \\ "required": ["url"] - \\} - }, + .inputSchema = protocol.minify( + \\{ + \\ "type": "object", + \\ "properties": { + \\ "url": { "type": "string", "description": "The URL to navigate to, must be a valid URL." } + \\ }, + \\ "required": ["url"] + \\} + ), }, .{ .name = "search", .description = "Use a search engine to look for specific words, terms, sentences. The search page will then be loaded in memory.", - .inputSchema = .{ .json = - \\{ - \\ "type": "object", - \\ "properties": { - \\ "text": { "type": "string", "description": "The text to search for, must be a valid search query." } - \\ }, - \\ "required": ["text"] - \\} - }, + .inputSchema = protocol.minify( + \\{ + \\ "type": "object", + \\ "properties": { + \\ "text": { "type": "string", "description": "The text to search for, must be a valid search query." } + \\ }, + \\ "required": ["text"] + \\} + ), }, .{ .name = "markdown", .description = "Get the page content in markdown format. If a url is provided, it navigates to that url first.", - .inputSchema = .{ .json = - \\{ - \\ "type": "object", - \\ "properties": { - \\ "url": { "type": "string", "description": "Optional URL to navigate to before fetching markdown." } - \\ } - \\} - }, + .inputSchema = protocol.minify( + \\{ + \\ "type": "object", + \\ "properties": { + \\ "url": { "type": "string", "description": "Optional URL to navigate to before fetching markdown." } + \\ } + \\} + ), }, .{ .name = "links", .description = "Extract all links in the opened page. If a url is provided, it navigates to that url first.", - .inputSchema = .{ .json = - \\{ - \\ "type": "object", - \\ "properties": { - \\ "url": { "type": "string", "description": "Optional URL to navigate to before extracting links." } - \\ } - \\} - }, + .inputSchema = protocol.minify( + \\{ + \\ "type": "object", + \\ "properties": { + \\ "url": { "type": "string", "description": "Optional URL to navigate to before extracting links." } + \\ } + \\} + ), }, .{ .name = "evaluate", .description = "Evaluate JavaScript in the current page context. If a url is provided, it navigates to that url first.", - .inputSchema = .{ .json = - \\{ - \\ "type": "object", - \\ "properties": { - \\ "script": { "type": "string" }, - \\ "url": { "type": "string", "description": "Optional URL to navigate to before evaluating." } - \\ }, - \\ "required": ["script"] - \\} - }, + .inputSchema = protocol.minify( + \\{ + \\ "type": "object", + \\ "properties": { + \\ "script": { "type": "string" }, + \\ "url": { "type": "string", "description": "Optional URL to navigate to before evaluating." } + \\ }, + \\ "required": ["script"] + \\} + ), }, .{ .name = "over", .description = "Used to indicate that the task is over and give the final answer if there is any. This is the last tool to be called in a task.", - .inputSchema = .{ .json = - \\{ - \\ "type": "object", - \\ "properties": { - \\ "result": { "type": "string", "description": "The final result of the task." } - \\ }, - \\ "required": ["result"] - \\} - }, + .inputSchema = protocol.minify( + \\{ + \\ "type": "object", + \\ "properties": { + \\ "result": { "type": "string", "description": "The final result of the task." } + \\ }, + \\ "required": ["result"] + \\} + ), }, }; @@ -158,6 +158,26 @@ const ToolStreamingText = struct { } }; +const ToolAction = enum { + goto, + navigate, + search, + markdown, + links, + evaluate, + over, +}; + +const tool_map = std.StaticStringMap(ToolAction).initComptime(.{ + .{ "goto", .goto }, + .{ "navigate", .navigate }, + .{ "search", .search }, + .{ "markdown", .markdown }, + .{ "links", .links }, + .{ "evaluate", .evaluate }, + .{ "over", .over }, +}); + pub fn handleCall(server: *Server, arena: std.mem.Allocator, req: protocol.Request) !void { if (req.params == null) { return server.sendError(req.id.?, .InvalidParams, "Missing params"); @@ -169,26 +189,20 @@ pub fn handleCall(server: *Server, arena: std.mem.Allocator, req: protocol.Reque }; const call_params = std.json.parseFromValueLeaky(CallParams, arena, req.params.?, .{ .ignore_unknown_fields = true }) catch { - var aw: std.Io.Writer.Allocating = .init(arena); - std.json.Stringify.value(req.params.?, .{}, &aw.writer) catch {}; - const msg = std.fmt.allocPrint(arena, "Invalid params: {s}", .{aw.written()}) catch "Invalid params"; - return server.sendError(req.id.?, .InvalidParams, msg); + return server.sendError(req.id.?, .InvalidParams, "Invalid params"); }; - if (std.mem.eql(u8, call_params.name, "goto") or std.mem.eql(u8, call_params.name, "navigate")) { - try handleGoto(server, arena, req.id.?, call_params.arguments); - } else if (std.mem.eql(u8, call_params.name, "search")) { - try handleSearch(server, arena, req.id.?, call_params.arguments); - } else if (std.mem.eql(u8, call_params.name, "markdown")) { - try handleMarkdown(server, arena, req.id.?, call_params.arguments); - } else if (std.mem.eql(u8, call_params.name, "links")) { - try handleLinks(server, arena, req.id.?, call_params.arguments); - } else if (std.mem.eql(u8, call_params.name, "evaluate")) { - try handleEvaluate(server, arena, req.id.?, call_params.arguments); - } else if (std.mem.eql(u8, call_params.name, "over")) { - try handleOver(server, arena, req.id.?, call_params.arguments); - } else { + const action = tool_map.get(call_params.name) orelse { return server.sendError(req.id.?, .MethodNotFound, "Tool not found"); + }; + + switch (action) { + .goto, .navigate => try handleGoto(server, arena, req.id.?, call_params.arguments), + .search => try handleSearch(server, arena, req.id.?, call_params.arguments), + .markdown => try handleMarkdown(server, arena, req.id.?, call_params.arguments), + .links => try handleLinks(server, arena, req.id.?, call_params.arguments), + .evaluate => try handleEvaluate(server, arena, req.id.?, call_params.arguments), + .over => try handleOver(server, arena, req.id.?, call_params.arguments), } }