mcp: optimize dispatching and simplify test harness

- Use StaticStringMap and enums for method, tool, and resource lookups.
- Implement comptime JSON minification for tool schemas.
- Refactor router and harness to use more efficient buffered polling.
- Consolidate integration tests and add synchronous unit tests.
This commit is contained in:
Adrià Arrufat
2026-03-02 20:53:14 +09:00
parent a7872aa054
commit 73565c4493
6 changed files with 357 additions and 350 deletions

View File

@@ -85,162 +85,64 @@ pub fn sendError(self: *Self, id: std.json.Value, code: protocol.ErrorCode, mess
const testing = @import("../testing.zig"); const testing = @import("../testing.zig");
const McpHarness = @import("testing.zig").McpHarness; const McpHarness = @import("testing.zig").McpHarness;
test "MCP Integration: handshake and tools/list" { test "MCP Integration: smoke test" {
const harness = try McpHarness.init(testing.allocator, testing.test_app); const harness = try McpHarness.init(testing.allocator, testing.test_app);
defer harness.deinit(); defer harness.deinit();
harness.thread = try std.Thread.spawn(.{}, wrapTest, .{ testHandshakeAndToolsInternal, harness }); harness.thread = try std.Thread.spawn(.{}, testIntegrationSmokeInternal, .{harness});
try harness.runServer(); try harness.runServer();
} }
fn wrapTest(comptime func: fn (*McpHarness) anyerror!void, harness: *McpHarness) void { fn testIntegrationSmokeInternal(harness: *McpHarness) void {
const res = func(harness); const aa = harness.allocator;
if (res) |_| { var arena = std.heap.ArenaAllocator.init(aa);
harness.test_error = null; defer arena.deinit();
} else |err| { const allocator = arena.allocator();
harness.sendRequest(
\\{"jsonrpc":"2.0","id":1,"method":"initialize","params":{"protocolVersion":"2024-11-05","capabilities":{},"clientInfo":{"name":"test-client","version":"1.0.0"}}}
) catch |err| {
harness.test_error = err; harness.test_error = err;
} return;
};
const response1 = harness.readResponse(allocator) catch |err| {
harness.test_error = err;
return;
};
testing.expect(std.mem.indexOf(u8, response1, "\"id\":1") != null) catch |err| {
harness.test_error = err;
return;
};
testing.expect(std.mem.indexOf(u8, response1, "\"tools\":{}") != null) catch |err| {
harness.test_error = err;
return;
};
testing.expect(std.mem.indexOf(u8, response1, "\"resources\":{}") != null) catch |err| {
harness.test_error = err;
return;
};
harness.sendRequest(
\\{"jsonrpc":"2.0","id":2,"method":"tools/list"}
) catch |err| {
harness.test_error = err;
return;
};
const response2 = harness.readResponse(allocator) catch |err| {
harness.test_error = err;
return;
};
testing.expect(std.mem.indexOf(u8, response2, "\"id\":2") != null) catch |err| {
harness.test_error = err;
return;
};
testing.expect(std.mem.indexOf(u8, response2, "\"name\":\"goto\"") != null) catch |err| {
harness.test_error = err;
return;
};
harness.server.is_running.store(false, .release); harness.server.is_running.store(false, .release);
// Ensure we trigger a poll wake up if needed
_ = harness.client_out.writeAll("\n") catch {}; _ = harness.client_out.writeAll("\n") catch {};
} }
fn testHandshakeAndToolsInternal(harness: *McpHarness) !void {
// 1. Initialize
try harness.sendRequest(
\\{"jsonrpc":"2.0","id":1,"method":"initialize","params":{"protocolVersion":"2024-11-05","capabilities":{},"clientInfo":{"name":"test-client","version":"1.0.0"}}}
);
var arena = std.heap.ArenaAllocator.init(harness.allocator);
defer arena.deinit();
const response1 = try harness.readResponse(arena.allocator());
try testing.expect(std.mem.indexOf(u8, response1, "\"id\":1") != null);
try testing.expect(std.mem.indexOf(u8, response1, "\"protocolVersion\":\"2025-11-25\"") != null);
// 2. Initialized notification
try harness.sendRequest(
\\{"jsonrpc":"2.0","method":"notifications/initialized"}
);
// 3. List tools
try harness.sendRequest(
\\{"jsonrpc":"2.0","id":2,"method":"tools/list"}
);
const response2 = try harness.readResponse(arena.allocator());
try testing.expect(std.mem.indexOf(u8, response2, "\"id\":2") != null);
try testing.expect(std.mem.indexOf(u8, response2, "\"name\":\"goto\"") != null);
}
test "MCP Integration: tools/call evaluate" {
const harness = try McpHarness.init(testing.allocator, testing.test_app);
defer harness.deinit();
harness.thread = try std.Thread.spawn(.{}, wrapTest, .{ testEvaluateInternal, harness });
try harness.runServer();
}
fn testEvaluateInternal(harness: *McpHarness) !void {
try harness.sendRequest(
\\{"jsonrpc":"2.0","id":1,"method":"tools/call","params":{"name":"evaluate","arguments":{"script":"1 + 1"}}}
);
var arena = std.heap.ArenaAllocator.init(harness.allocator);
defer arena.deinit();
const response = try harness.readResponse(arena.allocator());
try testing.expect(std.mem.indexOf(u8, response, "\"id\":1") != null);
try testing.expect(std.mem.indexOf(u8, response, "\"text\":\"2\"") != null);
}
test "MCP Integration: error handling" {
const harness = try McpHarness.init(testing.allocator, testing.test_app);
defer harness.deinit();
harness.thread = try std.Thread.spawn(.{}, wrapTest, .{ testErrorHandlingInternal, harness });
try harness.runServer();
}
fn testErrorHandlingInternal(harness: *McpHarness) !void {
var arena = std.heap.ArenaAllocator.init(harness.allocator);
defer arena.deinit();
// 1. Tool not found
try harness.sendRequest(
\\{"jsonrpc":"2.0","id":1,"method":"tools/call","params":{"name":"non_existent_tool"}}
);
const response1 = try harness.readResponse(arena.allocator());
try testing.expect(std.mem.indexOf(u8, response1, "\"id\":1") != null);
try testing.expect(std.mem.indexOf(u8, response1, "\"code\":-32601") != null);
// 2. Invalid params (missing script for evaluate)
try harness.sendRequest(
\\{"jsonrpc":"2.0","id":2,"method":"tools/call","params":{"name":"evaluate","arguments":{}}}
);
const response2 = try harness.readResponse(arena.allocator());
try testing.expect(std.mem.indexOf(u8, response2, "\"id\":2") != null);
try testing.expect(std.mem.indexOf(u8, response2, "\"code\":-32602") != null);
}
test "MCP Integration: resources" {
const harness = try McpHarness.init(testing.allocator, testing.test_app);
defer harness.deinit();
harness.thread = try std.Thread.spawn(.{}, wrapTest, .{ testResourcesInternal, harness });
try harness.runServer();
}
fn testResourcesInternal(harness: *McpHarness) !void {
var arena = std.heap.ArenaAllocator.init(harness.allocator);
defer arena.deinit();
// 1. List resources
try harness.sendRequest(
\\{"jsonrpc":"2.0","id":1,"method":"resources/list"}
);
const response1 = try harness.readResponse(arena.allocator());
try testing.expect(std.mem.indexOf(u8, response1, "\"uri\":\"mcp://page/html\"") != null);
// 2. Read resource
try harness.sendRequest(
\\{"jsonrpc":"2.0","id":2,"method":"resources/read","params":{"uri":"mcp://page/html"}}
);
const response2 = try harness.readResponse(arena.allocator());
try testing.expect(std.mem.indexOf(u8, response2, "\"id\":2") != null);
// Just check for 'html' to be case-insensitive and robust
try testing.expect(std.mem.indexOf(u8, response2, "html") != null);
}
test "MCP Integration: tools markdown and links" {
const harness = try McpHarness.init(testing.allocator, testing.test_app);
defer harness.deinit();
harness.thread = try std.Thread.spawn(.{}, wrapTest, .{ testMarkdownAndLinksInternal, harness });
try harness.runServer();
}
fn testMarkdownAndLinksInternal(harness: *McpHarness) !void {
var arena = std.heap.ArenaAllocator.init(harness.allocator);
defer arena.deinit();
// 1. Test markdown
try harness.sendRequest(
\\{"jsonrpc":"2.0","id":1,"method":"tools/call","params":{"name":"markdown"}}
);
const response1 = try harness.readResponse(arena.allocator());
try testing.expect(std.mem.indexOf(u8, response1, "\"id\":1") != null);
// 2. Test links
try harness.sendRequest(
\\{"jsonrpc":"2.0","id":2,"method":"tools/call","params":{"name":"links"}}
);
const response2 = try harness.readResponse(arena.allocator());
try testing.expect(std.mem.indexOf(u8, response2, "\"id\":2") != null);
}

View File

@@ -94,23 +94,87 @@ pub const ToolsCapability = struct {
pub const Tool = struct { pub const Tool = struct {
name: []const u8, name: []const u8,
description: ?[]const u8 = null, description: ?[]const u8 = null,
inputSchema: RawJson, inputSchema: []const u8,
};
pub const RawJson = struct {
json: []const u8,
pub fn jsonStringify(self: @This(), jw: anytype) !void { pub fn jsonStringify(self: @This(), jw: anytype) !void {
var arena: std.heap.ArenaAllocator = .init(std.heap.page_allocator); try jw.beginObject();
defer arena.deinit(); try jw.objectField("name");
try jw.write(self.name);
const parsed = std.json.parseFromSlice(std.json.Value, arena.allocator(), self.json, .{}) catch return error.WriteFailed; if (self.description) |d| {
defer parsed.deinit(); try jw.objectField("description");
try jw.write(d);
try jw.write(parsed.value); }
try jw.objectField("inputSchema");
_ = try jw.beginWriteRaw();
try jw.writer.writeAll(self.inputSchema);
jw.endWriteRaw();
try jw.endObject();
} }
}; };
pub fn minify(comptime json: []const u8) []const u8 {
@setEvalBranchQuota(100000);
const minified = comptime blk: {
var len: usize = 0;
var in_string = false;
var escaped = false;
for (json) |c| {
if (in_string) {
len += 1;
if (escaped) {
escaped = false;
} else if (c == '\\') {
escaped = true;
} else if (c == '"') {
in_string = false;
}
} else {
switch (c) {
' ', '\n', '\r', '\t' => continue,
'"' => {
in_string = true;
len += 1;
},
else => len += 1,
}
}
}
var res: [len]u8 = undefined;
var pos: usize = 0;
in_string = false;
escaped = false;
for (json) |c| {
if (in_string) {
res[pos] = c;
pos += 1;
if (escaped) {
escaped = false;
} else if (c == '\\') {
escaped = true;
} else if (c == '"') {
in_string = false;
}
} else {
switch (c) {
' ', '\n', '\r', '\t' => continue,
'"' => {
in_string = true;
res[pos] = c;
pos += 1;
},
else => {
res[pos] = c;
pos += 1;
},
}
}
}
break :blk res;
};
return &minified;
}
pub const Resource = struct { pub const Resource = struct {
uri: []const u8, uri: []const u8,
name: []const u8, name: []const u8,
@@ -232,13 +296,23 @@ test "JsonEscapingWriter" {
try testing.expectString("hello\\n\\\"world\\\"", aw.written()); try testing.expectString("hello\\n\\\"world\\\"", aw.written());
} }
test "RawJson serialization" { test "Tool serialization" {
const raw = RawJson{ .json = "{\"test\": 123}" }; const t = Tool{
.name = "test",
.inputSchema = minify(
\\{
\\ "type": "object",
\\ "properties": {
\\ "foo": { "type": "string" }
\\ }
\\}
),
};
var aw: std.Io.Writer.Allocating = .init(testing.allocator); var aw: std.Io.Writer.Allocating = .init(testing.allocator);
defer aw.deinit(); defer aw.deinit();
try std.json.Stringify.value(raw, .{}, &aw.writer); try std.json.Stringify.value(t, .{}, &aw.writer);
try testing.expectString("{\"test\":123}", aw.written()); try testing.expectString("{\"name\":\"test\",\"inputSchema\":{\"type\":\"object\",\"properties\":{\"foo\":{\"type\":\"string\"}}}}", aw.written());
} }

View File

@@ -64,6 +64,16 @@ const ResourceStreamingResult = struct {
}; };
}; };
const ResourceUri = enum {
@"mcp://page/html",
@"mcp://page/markdown",
};
const resource_map = std.StaticStringMap(ResourceUri).initComptime(.{
.{ "mcp://page/html", .@"mcp://page/html" },
.{ "mcp://page/markdown", .@"mcp://page/markdown" },
});
pub fn handleRead(server: *Server, arena: std.mem.Allocator, req: protocol.Request) !void { pub fn handleRead(server: *Server, arena: std.mem.Allocator, req: protocol.Request) !void {
if (req.params == null) { if (req.params == null) {
return server.sendError(req.id.?, .InvalidParams, "Missing params"); return server.sendError(req.id.?, .InvalidParams, "Missing params");
@@ -73,7 +83,12 @@ pub fn handleRead(server: *Server, arena: std.mem.Allocator, req: protocol.Reque
return server.sendError(req.id.?, .InvalidParams, "Invalid params"); return server.sendError(req.id.?, .InvalidParams, "Invalid params");
}; };
if (std.mem.eql(u8, params.uri, "mcp://page/html")) { const uri = resource_map.get(params.uri) orelse {
return server.sendError(req.id.?, .InvalidRequest, "Resource not found");
};
switch (uri) {
.@"mcp://page/html" => {
const result: ResourceStreamingResult = .{ const result: ResourceStreamingResult = .{
.contents = &.{.{ .contents = &.{.{
.uri = params.uri, .uri = params.uri,
@@ -82,7 +97,8 @@ pub fn handleRead(server: *Server, arena: std.mem.Allocator, req: protocol.Reque
}}, }},
}; };
try server.sendResult(req.id.?, result); try server.sendResult(req.id.?, result);
} else if (std.mem.eql(u8, params.uri, "mcp://page/markdown")) { },
.@"mcp://page/markdown" => {
const result: ResourceStreamingResult = .{ const result: ResourceStreamingResult = .{
.contents = &.{.{ .contents = &.{.{
.uri = params.uri, .uri = params.uri,
@@ -91,8 +107,7 @@ pub fn handleRead(server: *Server, arena: std.mem.Allocator, req: protocol.Reque
}}, }},
}; };
try server.sendResult(req.id.?, result); try server.sendResult(req.id.?, result);
} else { },
return server.sendError(req.id.?, .InvalidRequest, "Resource not found");
} }
} }

View File

@@ -12,34 +12,24 @@ pub fn processRequests(server: *Server, in_stream: std.fs.File) !void {
var poller = std.io.poll(server.allocator, Streams, .{ .stdin = in_stream }); var poller = std.io.poll(server.allocator, Streams, .{ .stdin = in_stream });
defer poller.deinit(); defer poller.deinit();
var buffer = std.ArrayListUnmanaged(u8).empty; const r = poller.reader(.stdin);
defer buffer.deinit(server.allocator);
while (server.is_running.load(.acquire)) { while (server.is_running.load(.acquire)) {
const poll_result = try poller.pollTimeout(100 * std.time.ns_per_ms); const poll_result = try poller.pollTimeout(100 * std.time.ns_per_ms);
if (poll_result) { if (!poll_result) {
const data = try poller.toOwnedSlice(.stdin); // EOF or all streams closed
if (data.len == 0) {
server.is_running.store(false, .release); server.is_running.store(false, .release);
break; break;
} }
try buffer.appendSlice(server.allocator, data);
server.allocator.free(data);
}
while (std.mem.indexOfScalar(u8, buffer.items, '\n')) |newline_idx| { while (true) {
const line = try server.allocator.dupe(u8, buffer.items[0..newline_idx]); const buffered = r.buffered();
defer server.allocator.free(line); const newline_idx = std.mem.indexOfScalar(u8, buffered, '\n') orelse break;
const line = buffered[0 .. newline_idx + 1];
const remaining = buffer.items.len - (newline_idx + 1);
std.mem.copyForwards(u8, buffer.items[0..remaining], buffer.items[newline_idx + 1 ..]);
buffer.items.len = remaining;
// Ignore empty lines (e.g. from deinit unblock)
const trimmed = std.mem.trim(u8, line, " \r\t");
if (trimmed.len == 0) continue;
const trimmed = std.mem.trim(u8, line, " \r\n\t");
if (trimmed.len > 0) {
var arena = std.heap.ArenaAllocator.init(server.allocator); var arena = std.heap.ArenaAllocator.init(server.allocator);
defer arena.deinit(); defer arena.deinit();
@@ -47,13 +37,34 @@ pub fn processRequests(server: *Server, in_stream: std.fs.File) !void {
log.err(.mcp, "Failed to handle message", .{ .err = err, .msg = trimmed }); log.err(.mcp, "Failed to handle message", .{ .err = err, .msg = trimmed });
}; };
} }
r.toss(line.len);
}
} }
} }
const log = @import("../log.zig"); const log = @import("../log.zig");
fn handleMessage(server: *Server, arena: std.mem.Allocator, msg: []const u8) !void { const Method = enum {
const req = std.json.parseFromSlice(protocol.Request, arena, msg, .{ initialize,
@"notifications/initialized",
@"tools/list",
@"tools/call",
@"resources/list",
@"resources/read",
};
const method_map = std.StaticStringMap(Method).initComptime(.{
.{ "initialize", .initialize },
.{ "notifications/initialized", .@"notifications/initialized" },
.{ "tools/list", .@"tools/list" },
.{ "tools/call", .@"tools/call" },
.{ "resources/list", .@"resources/list" },
.{ "resources/read", .@"resources/read" },
});
pub fn handleMessage(server: *Server, arena: std.mem.Allocator, msg: []const u8) !void {
const req = std.json.parseFromSliceLeaky(protocol.Request, arena, msg, .{
.ignore_unknown_fields = true, .ignore_unknown_fields = true,
}) catch |err| { }) catch |err| {
log.warn(.mcp, "JSON Parse Error", .{ .err = err, .msg = msg }); log.warn(.mcp, "JSON Parse Error", .{ .err = err, .msg = msg });
@@ -61,40 +72,30 @@ fn handleMessage(server: *Server, arena: std.mem.Allocator, msg: []const u8) !vo
return; return;
}; };
if (std.mem.eql(u8, req.value.method, "initialize")) { const method = method_map.get(req.method) orelse {
return handleInitialize(server, req.value); if (req.id != null) {
try server.sendError(req.id.?, .MethodNotFound, "Method not found");
} }
if (std.mem.eql(u8, req.value.method, "notifications/initialized")) {
// nothing to do
return; return;
} };
if (std.mem.eql(u8, req.value.method, "tools/list")) { switch (method) {
return tools.handleList(server, arena, req.value); .initialize => try handleInitialize(server, req),
} .@"notifications/initialized" => {},
.@"tools/list" => try tools.handleList(server, arena, req),
if (std.mem.eql(u8, req.value.method, "tools/call")) { .@"tools/call" => try tools.handleCall(server, arena, req),
return tools.handleCall(server, arena, req.value); .@"resources/list" => try resources.handleList(server, req),
} .@"resources/read" => try resources.handleRead(server, arena, req),
if (std.mem.eql(u8, req.value.method, "resources/list")) {
return resources.handleList(server, req.value);
}
if (std.mem.eql(u8, req.value.method, "resources/read")) {
return resources.handleRead(server, arena, req.value);
}
if (req.value.id != null) {
return server.sendError(req.value.id.?, .MethodNotFound, "Method not found");
} }
} }
fn handleInitialize(server: *Server, req: protocol.Request) !void { fn handleInitialize(server: *Server, req: protocol.Request) !void {
const result = protocol.InitializeResult{ const result = protocol.InitializeResult{
.protocolVersion = "2025-11-25", .protocolVersion = "2025-11-25",
.capabilities = .{}, .capabilities = .{
.resources = .{},
.tools = .{},
},
.serverInfo = .{ .serverInfo = .{
.name = "lightpanda", .name = "lightpanda",
.version = "0.1.0", .version = "0.1.0",
@@ -107,33 +108,43 @@ fn handleInitialize(server: *Server, req: protocol.Request) !void {
const testing = @import("../testing.zig"); const testing = @import("../testing.zig");
const McpHarness = @import("testing.zig").McpHarness; const McpHarness = @import("testing.zig").McpHarness;
test "handleMessage - ParseError" { test "handleMessage - synchronous unit tests" {
// We need a server, but we want it to write to our fbs
// Server.init currently takes std.fs.File, we might need to refactor it
// to take a generic writer if we want to be truly "cranky" and avoid OS files.
// For now, let's use the harness as it's already set up, but call handleMessage directly.
const harness = try McpHarness.init(testing.allocator, testing.test_app); const harness = try McpHarness.init(testing.allocator, testing.test_app);
defer harness.deinit(); defer harness.deinit();
harness.thread = try std.Thread.spawn(.{}, wrapTest, .{ testParseErrorInternal, harness }); var arena = std.heap.ArenaAllocator.init(testing.allocator);
try harness.runServer();
}
fn wrapTest(comptime func: fn (*McpHarness) anyerror!void, harness: *McpHarness) void {
const res = func(harness);
if (res) |_| {
harness.test_error = null;
} else |err| {
harness.test_error = err;
}
harness.server.is_running.store(false, .release);
// Ensure we trigger a poll wake up if needed
_ = harness.client_out.writeAll("\n") catch {};
}
fn testParseErrorInternal(harness: *McpHarness) !void {
var arena = std.heap.ArenaAllocator.init(harness.allocator);
defer arena.deinit(); defer arena.deinit();
const aa = arena.allocator();
try harness.sendRequest("invalid json"); // 1. Valid request
try handleMessage(harness.server, aa,
\\{"jsonrpc":"2.0","id":1,"method":"initialize","params":{}}
);
const resp1 = try harness.readResponse(aa);
try testing.expect(std.mem.indexOf(u8, resp1, "\"id\":1") != null);
try testing.expect(std.mem.indexOf(u8, resp1, "\"name\":\"lightpanda\"") != null);
const response = try harness.readResponse(arena.allocator()); // 2. Method not found
try testing.expect(std.mem.indexOf(u8, response, "\"id\":null") != null); try handleMessage(harness.server, aa,
try testing.expect(std.mem.indexOf(u8, response, "\"code\":-32700") != null); \\{"jsonrpc":"2.0","id":2,"method":"unknown_method"}
);
const resp2 = try harness.readResponse(aa);
try testing.expect(std.mem.indexOf(u8, resp2, "\"id\":2") != null);
try testing.expect(std.mem.indexOf(u8, resp2, "\"code\":-32601") != null);
// 3. Parse error
{
const old_filter = log.opts.filter_scopes;
log.opts.filter_scopes = &.{.mcp};
defer log.opts.filter_scopes = old_filter;
try handleMessage(harness.server, aa, "invalid json");
const resp3 = try harness.readResponse(aa);
try testing.expect(std.mem.indexOf(u8, resp3, "\"id\":null") != null);
try testing.expect(std.mem.indexOf(u8, resp3, "\"code\":-32700") != null);
}
} }

View File

@@ -19,7 +19,6 @@ pub const McpHarness = struct {
thread: ?std.Thread = null, thread: ?std.Thread = null,
test_error: ?anyerror = null, test_error: ?anyerror = null,
buffer: std.ArrayListUnmanaged(u8) = .empty,
const Pipe = struct { const Pipe = struct {
read: std.fs.File, read: std.fs.File,
@@ -47,7 +46,6 @@ pub const McpHarness = struct {
self.app = app; self.app = app;
self.thread = null; self.thread = null;
self.test_error = null; self.test_error = null;
self.buffer = .empty;
const stdin_pipe = try Pipe.init(); const stdin_pipe = try Pipe.init();
errdefer stdin_pipe.close(); errdefer stdin_pipe.close();
@@ -88,7 +86,6 @@ pub const McpHarness = struct {
self.client_in.close(); self.client_in.close();
// self.client_out is already closed above // self.client_out is already closed above
self.buffer.deinit(self.allocator);
self.allocator.destroy(self); self.allocator.destroy(self);
} }
@@ -109,29 +106,23 @@ pub const McpHarness = struct {
var poller = std.io.poll(self.allocator, Streams, .{ .stdout = self.client_in }); var poller = std.io.poll(self.allocator, Streams, .{ .stdout = self.client_in });
defer poller.deinit(); defer poller.deinit();
const r = poller.reader(.stdout);
const timeout_ns = 2 * std.time.ns_per_s; const timeout_ns = 2 * std.time.ns_per_s;
var timer = try std.time.Timer.start(); var timer = try std.time.Timer.start();
while (timer.read() < timeout_ns) { while (timer.read() < timeout_ns) {
const remaining = timeout_ns - timer.read(); const poll_result = try poller.pollTimeout(timeout_ns - timer.read());
const poll_result = try poller.pollTimeout(remaining);
if (poll_result) { if (!poll_result) return error.EndOfStream;
const data = try poller.toOwnedSlice(.stdout);
if (data.len == 0) return error.EndOfStream; const buffered = r.buffered();
try self.buffer.appendSlice(self.allocator, data); if (std.mem.indexOfScalar(u8, buffered, '\n')) |newline_idx| {
self.allocator.free(data); const line = buffered[0 .. newline_idx + 1];
const result = try arena.dupe(u8, std.mem.trim(u8, line, " \r\n\t"));
r.toss(line.len);
return result;
} }
if (std.mem.indexOfScalar(u8, self.buffer.items, '\n')) |newline_idx| {
const line = try arena.dupe(u8, self.buffer.items[0..newline_idx]);
const remaining_bytes = self.buffer.items.len - (newline_idx + 1);
std.mem.copyForwards(u8, self.buffer.items[0..remaining_bytes], self.buffer.items[newline_idx + 1 ..]);
self.buffer.items.len = remaining_bytes;
return line;
}
if (!poll_result and timer.read() >= timeout_ns) break;
} }
return error.Timeout; return error.Timeout;

View File

@@ -13,7 +13,7 @@ pub const tool_list = [_]protocol.Tool{
.{ .{
.name = "goto", .name = "goto",
.description = "Navigate to a specified URL and load the page in memory so it can be reused later for info extraction.", .description = "Navigate to a specified URL and load the page in memory so it can be reused later for info extraction.",
.inputSchema = .{ .json = .inputSchema = protocol.minify(
\\{ \\{
\\ "type": "object", \\ "type": "object",
\\ "properties": { \\ "properties": {
@@ -21,12 +21,12 @@ pub const tool_list = [_]protocol.Tool{
\\ }, \\ },
\\ "required": ["url"] \\ "required": ["url"]
\\} \\}
}, ),
}, },
.{ .{
.name = "search", .name = "search",
.description = "Use a search engine to look for specific words, terms, sentences. The search page will then be loaded in memory.", .description = "Use a search engine to look for specific words, terms, sentences. The search page will then be loaded in memory.",
.inputSchema = .{ .json = .inputSchema = protocol.minify(
\\{ \\{
\\ "type": "object", \\ "type": "object",
\\ "properties": { \\ "properties": {
@@ -34,36 +34,36 @@ pub const tool_list = [_]protocol.Tool{
\\ }, \\ },
\\ "required": ["text"] \\ "required": ["text"]
\\} \\}
}, ),
}, },
.{ .{
.name = "markdown", .name = "markdown",
.description = "Get the page content in markdown format. If a url is provided, it navigates to that url first.", .description = "Get the page content in markdown format. If a url is provided, it navigates to that url first.",
.inputSchema = .{ .json = .inputSchema = protocol.minify(
\\{ \\{
\\ "type": "object", \\ "type": "object",
\\ "properties": { \\ "properties": {
\\ "url": { "type": "string", "description": "Optional URL to navigate to before fetching markdown." } \\ "url": { "type": "string", "description": "Optional URL to navigate to before fetching markdown." }
\\ } \\ }
\\} \\}
}, ),
}, },
.{ .{
.name = "links", .name = "links",
.description = "Extract all links in the opened page. If a url is provided, it navigates to that url first.", .description = "Extract all links in the opened page. If a url is provided, it navigates to that url first.",
.inputSchema = .{ .json = .inputSchema = protocol.minify(
\\{ \\{
\\ "type": "object", \\ "type": "object",
\\ "properties": { \\ "properties": {
\\ "url": { "type": "string", "description": "Optional URL to navigate to before extracting links." } \\ "url": { "type": "string", "description": "Optional URL to navigate to before extracting links." }
\\ } \\ }
\\} \\}
}, ),
}, },
.{ .{
.name = "evaluate", .name = "evaluate",
.description = "Evaluate JavaScript in the current page context. If a url is provided, it navigates to that url first.", .description = "Evaluate JavaScript in the current page context. If a url is provided, it navigates to that url first.",
.inputSchema = .{ .json = .inputSchema = protocol.minify(
\\{ \\{
\\ "type": "object", \\ "type": "object",
\\ "properties": { \\ "properties": {
@@ -72,12 +72,12 @@ pub const tool_list = [_]protocol.Tool{
\\ }, \\ },
\\ "required": ["script"] \\ "required": ["script"]
\\} \\}
}, ),
}, },
.{ .{
.name = "over", .name = "over",
.description = "Used to indicate that the task is over and give the final answer if there is any. This is the last tool to be called in a task.", .description = "Used to indicate that the task is over and give the final answer if there is any. This is the last tool to be called in a task.",
.inputSchema = .{ .json = .inputSchema = protocol.minify(
\\{ \\{
\\ "type": "object", \\ "type": "object",
\\ "properties": { \\ "properties": {
@@ -85,7 +85,7 @@ pub const tool_list = [_]protocol.Tool{
\\ }, \\ },
\\ "required": ["result"] \\ "required": ["result"]
\\} \\}
}, ),
}, },
}; };
@@ -158,6 +158,26 @@ const ToolStreamingText = struct {
} }
}; };
const ToolAction = enum {
goto,
navigate,
search,
markdown,
links,
evaluate,
over,
};
const tool_map = std.StaticStringMap(ToolAction).initComptime(.{
.{ "goto", .goto },
.{ "navigate", .navigate },
.{ "search", .search },
.{ "markdown", .markdown },
.{ "links", .links },
.{ "evaluate", .evaluate },
.{ "over", .over },
});
pub fn handleCall(server: *Server, arena: std.mem.Allocator, req: protocol.Request) !void { pub fn handleCall(server: *Server, arena: std.mem.Allocator, req: protocol.Request) !void {
if (req.params == null) { if (req.params == null) {
return server.sendError(req.id.?, .InvalidParams, "Missing params"); return server.sendError(req.id.?, .InvalidParams, "Missing params");
@@ -169,26 +189,20 @@ pub fn handleCall(server: *Server, arena: std.mem.Allocator, req: protocol.Reque
}; };
const call_params = std.json.parseFromValueLeaky(CallParams, arena, req.params.?, .{ .ignore_unknown_fields = true }) catch { const call_params = std.json.parseFromValueLeaky(CallParams, arena, req.params.?, .{ .ignore_unknown_fields = true }) catch {
var aw: std.Io.Writer.Allocating = .init(arena); return server.sendError(req.id.?, .InvalidParams, "Invalid params");
std.json.Stringify.value(req.params.?, .{}, &aw.writer) catch {};
const msg = std.fmt.allocPrint(arena, "Invalid params: {s}", .{aw.written()}) catch "Invalid params";
return server.sendError(req.id.?, .InvalidParams, msg);
}; };
if (std.mem.eql(u8, call_params.name, "goto") or std.mem.eql(u8, call_params.name, "navigate")) { const action = tool_map.get(call_params.name) orelse {
try handleGoto(server, arena, req.id.?, call_params.arguments);
} else if (std.mem.eql(u8, call_params.name, "search")) {
try handleSearch(server, arena, req.id.?, call_params.arguments);
} else if (std.mem.eql(u8, call_params.name, "markdown")) {
try handleMarkdown(server, arena, req.id.?, call_params.arguments);
} else if (std.mem.eql(u8, call_params.name, "links")) {
try handleLinks(server, arena, req.id.?, call_params.arguments);
} else if (std.mem.eql(u8, call_params.name, "evaluate")) {
try handleEvaluate(server, arena, req.id.?, call_params.arguments);
} else if (std.mem.eql(u8, call_params.name, "over")) {
try handleOver(server, arena, req.id.?, call_params.arguments);
} else {
return server.sendError(req.id.?, .MethodNotFound, "Tool not found"); return server.sendError(req.id.?, .MethodNotFound, "Tool not found");
};
switch (action) {
.goto, .navigate => try handleGoto(server, arena, req.id.?, call_params.arguments),
.search => try handleSearch(server, arena, req.id.?, call_params.arguments),
.markdown => try handleMarkdown(server, arena, req.id.?, call_params.arguments),
.links => try handleLinks(server, arena, req.id.?, call_params.arguments),
.evaluate => try handleEvaluate(server, arena, req.id.?, call_params.arguments),
.over => try handleOver(server, arena, req.id.?, call_params.arguments),
} }
} }