diff --git a/src/Config.zig b/src/Config.zig index 1e5cc9ab..5a4cc58e 100644 --- a/src/Config.zig +++ b/src/Config.zig @@ -28,6 +28,7 @@ pub const RunMode = enum { fetch, serve, version, + mcp, }; pub const CDP_MAX_HTTP_REQUEST_SIZE = 4096; @@ -59,56 +60,56 @@ pub fn deinit(self: *const Config, allocator: Allocator) void { pub fn tlsVerifyHost(self: *const Config) bool { return switch (self.mode) { - inline .serve, .fetch => |opts| opts.common.tls_verify_host, + inline .serve, .fetch, .mcp => |opts| opts.common.tls_verify_host, else => unreachable, }; } pub fn obeyRobots(self: *const Config) bool { return switch (self.mode) { - inline .serve, .fetch => |opts| opts.common.obey_robots, + inline .serve, .fetch, .mcp => |opts| opts.common.obey_robots, else => unreachable, }; } pub fn httpProxy(self: *const Config) ?[:0]const u8 { return switch (self.mode) { - inline .serve, .fetch => |opts| opts.common.http_proxy, + inline .serve, .fetch, .mcp => |opts| opts.common.http_proxy, else => unreachable, }; } pub fn proxyBearerToken(self: *const Config) ?[:0]const u8 { return switch (self.mode) { - inline .serve, .fetch => |opts| opts.common.proxy_bearer_token, + inline .serve, .fetch, .mcp => |opts| opts.common.proxy_bearer_token, .help, .version => null, }; } pub fn httpMaxConcurrent(self: *const Config) u8 { return switch (self.mode) { - inline .serve, .fetch => |opts| opts.common.http_max_concurrent orelse 10, + inline .serve, .fetch, .mcp => |opts| opts.common.http_max_concurrent orelse 10, else => unreachable, }; } pub fn httpMaxHostOpen(self: *const Config) u8 { return switch (self.mode) { - inline .serve, .fetch => |opts| opts.common.http_max_host_open orelse 4, + inline .serve, .fetch, .mcp => |opts| opts.common.http_max_host_open orelse 4, else => unreachable, }; } pub fn httpConnectTimeout(self: *const Config) u31 { return switch (self.mode) { - inline .serve, .fetch => |opts| opts.common.http_connect_timeout orelse 0, + inline .serve, .fetch, .mcp => |opts| 
opts.common.http_connect_timeout orelse 0, else => unreachable, }; } pub fn httpTimeout(self: *const Config) u31 { return switch (self.mode) { - inline .serve, .fetch => |opts| opts.common.http_timeout orelse 5000, + inline .serve, .fetch, .mcp => |opts| opts.common.http_timeout orelse 5000, else => unreachable, }; } @@ -119,35 +120,35 @@ pub fn httpMaxRedirects(_: *const Config) u8 { pub fn httpMaxResponseSize(self: *const Config) ?usize { return switch (self.mode) { - inline .serve, .fetch => |opts| opts.common.http_max_response_size, + inline .serve, .fetch, .mcp => |opts| opts.common.http_max_response_size, else => unreachable, }; } pub fn logLevel(self: *const Config) ?log.Level { return switch (self.mode) { - inline .serve, .fetch => |opts| opts.common.log_level, + inline .serve, .fetch, .mcp => |opts| opts.common.log_level, else => unreachable, }; } pub fn logFormat(self: *const Config) ?log.Format { return switch (self.mode) { - inline .serve, .fetch => |opts| opts.common.log_format, + inline .serve, .fetch, .mcp => |opts| opts.common.log_format, else => unreachable, }; } pub fn logFilterScopes(self: *const Config) ?[]const log.Scope { return switch (self.mode) { - inline .serve, .fetch => |opts| opts.common.log_filter_scopes, + inline .serve, .fetch, .mcp => |opts| opts.common.log_filter_scopes, else => unreachable, }; } pub fn userAgentSuffix(self: *const Config) ?[]const u8 { return switch (self.mode) { - inline .serve, .fetch => |opts| opts.common.user_agent_suffix, + inline .serve, .fetch, .mcp => |opts| opts.common.user_agent_suffix, .help, .version => null, }; } @@ -171,6 +172,7 @@ pub const Mode = union(RunMode) { fetch: Fetch, serve: Serve, version: void, + mcp: Mcp, }; pub const Serve = struct { @@ -182,6 +184,10 @@ pub const Serve = struct { common: Common = .{}, }; +pub const Mcp = struct { + common: Common = .{}, +}; + pub const DumpFormat = enum { html, markdown, @@ -324,7 +330,7 @@ pub fn printUsageAndExit(self: *const Config, success: bool) 
void { const usage = \\usage: {s} command [options] [URL] \\ - \\Command can be either 'fetch', 'serve' or 'help' + \\Command can be either 'fetch', 'serve', 'mcp' or 'help' \\ \\fetch command \\Fetches the specified URL @@ -370,6 +376,12 @@ pub fn printUsageAndExit(self: *const Config, success: bool) void { \\ Maximum pending connections in the accept queue. \\ Defaults to 128. \\ + ++ common_options ++ + \\ + \\mcp command + \\Starts an MCP (Model Context Protocol) server over stdio + \\Example: {s} mcp + \\ ++ common_options ++ \\ \\version command @@ -379,7 +391,7 @@ pub fn printUsageAndExit(self: *const Config, success: bool) void { \\Displays this message \\ ; - std.debug.print(usage, .{ self.exec_name, self.exec_name, self.exec_name, self.exec_name }); + std.debug.print(usage, .{ self.exec_name, self.exec_name, self.exec_name, self.exec_name, self.exec_name }); if (success) { return std.process.cleanExit(); } @@ -414,6 +426,8 @@ pub fn parseArgs(allocator: Allocator) !Config { return init(allocator, exec_name, .{ .help = false }) }, .fetch => .{ .fetch = parseFetchArgs(allocator, &args) catch return init(allocator, exec_name, .{ .help = false }) }, + .mcp => .{ .mcp = parseMcpArgs(allocator, &args) catch + return init(allocator, exec_name, .{ .help = false }) }, .version => .{ .version = {} }, }; return init(allocator, exec_name, mode); @@ -542,6 +556,24 @@ fn parseServeArgs( return serve; } +fn parseMcpArgs( + allocator: Allocator, + args: *std.process.ArgIterator, +) !Mcp { + var mcp: Mcp = .{}; + + while (args.next()) |opt| { + if (try parseCommonArg(allocator, opt, args, &mcp.common)) { + continue; + } + + log.fatal(.mcp, "unknown argument", .{ .mode = "mcp", .arg = opt }); + return error.UnkownOption; + } + + return mcp; +} + fn parseFetchArgs( allocator: Allocator, args: *std.process.ArgIterator, diff --git a/src/browser/js/TryCatch.zig b/src/browser/js/TryCatch.zig index d0f7a7d8..f9909e71 100644 --- a/src/browser/js/TryCatch.zig +++ 
b/src/browser/js/TryCatch.zig @@ -134,4 +134,17 @@ pub const Caught = struct { try writer.write(prefix ++ ".line", self.line); try writer.write(prefix ++ ".caught", self.caught); } + + pub fn jsonStringify(self: Caught, jw: anytype) !void { + try jw.beginObject(); + try jw.objectField("exception"); + try jw.write(self.exception); + try jw.objectField("stack"); + try jw.write(self.stack); + try jw.objectField("line"); + try jw.write(self.line); + try jw.objectField("caught"); + try jw.write(self.caught); + try jw.endObject(); + } }; diff --git a/src/lightpanda.zig b/src/lightpanda.zig index 5a30c5ff..26bc23f0 100644 --- a/src/lightpanda.zig +++ b/src/lightpanda.zig @@ -30,6 +30,7 @@ pub const log = @import("log.zig"); pub const js = @import("browser/js/js.zig"); pub const dump = @import("browser/dump.zig"); pub const markdown = @import("browser/markdown.zig"); +pub const mcp = @import("mcp.zig"); pub const build_config = @import("build_config"); pub const crash_handler = @import("crash_handler.zig"); diff --git a/src/log.zig b/src/log.zig index b1ff926b..8cf712ba 100644 --- a/src/log.zig +++ b/src/log.zig @@ -38,6 +38,7 @@ pub const Scope = enum { not_implemented, telemetry, unknown_prop, + mcp, }; const Opts = struct { diff --git a/src/main.zig b/src/main.zig index a2f16bff..dd6a759a 100644 --- a/src/main.zig +++ b/src/main.zig @@ -131,6 +131,21 @@ fn run(allocator: Allocator, main_arena: Allocator) !void { return err; }; }, + .mcp => { + log.info(.mcp, "starting server", .{}); + + log.opts.format = .logfmt; + + var stdout = std.fs.File.stdout().writer(&.{}); + + var mcp_server: *lp.mcp.Server = try .init(allocator, app, &stdout.interface); + defer mcp_server.deinit(); + + var stdin_buf: [64 * 1024]u8 = undefined; + var stdin = std.fs.File.stdin().reader(&stdin_buf); + + try lp.mcp.router.processRequests(mcp_server, &stdin.interface); + }, else => unreachable, } } diff --git a/src/mcp.zig b/src/mcp.zig new file mode 100644 index 00000000..ca92206c --- /dev/null +++ 
b/src/mcp.zig @@ -0,0 +1,9 @@ +const std = @import("std"); + +pub const protocol = @import("mcp/protocol.zig"); +pub const router = @import("mcp/router.zig"); +pub const Server = @import("mcp/Server.zig"); + +test { + std.testing.refAllDecls(@This()); +} diff --git a/src/mcp/Server.zig b/src/mcp/Server.zig new file mode 100644 index 00000000..d73ac74c --- /dev/null +++ b/src/mcp/Server.zig @@ -0,0 +1,111 @@ +const std = @import("std"); + +const lp = @import("lightpanda"); + +const App = @import("../App.zig"); +const HttpClient = @import("../http/Client.zig"); +const testing = @import("../testing.zig"); +const protocol = @import("protocol.zig"); +const router = @import("router.zig"); + +const Self = @This(); + +allocator: std.mem.Allocator, +app: *App, + +http_client: *HttpClient, +notification: *lp.Notification, +browser: lp.Browser, +session: *lp.Session, +page: *lp.Page, + +writer: *std.io.Writer, +mutex: std.Thread.Mutex = .{}, +aw: std.io.Writer.Allocating, + +pub fn init(allocator: std.mem.Allocator, app: *App, writer: *std.io.Writer) !*Self { + const self = try allocator.create(Self); + errdefer allocator.destroy(self); + + self.allocator = allocator; + self.app = app; + self.writer = writer; + self.aw = .init(allocator); + + self.http_client = try app.http.createClient(allocator); + errdefer self.http_client.deinit(); + + self.notification = try .init(allocator); + errdefer self.notification.deinit(); + + self.browser = try lp.Browser.init(app, .{ .http_client = self.http_client }); + errdefer self.browser.deinit(); + + self.session = try self.browser.newSession(self.notification); + self.page = try self.session.createPage(); + + return self; +} + +pub fn deinit(self: *Self) void { + self.aw.deinit(); + self.browser.deinit(); + self.notification.deinit(); + self.http_client.deinit(); + + self.allocator.destroy(self); +} + +pub fn sendResponse(self: *Self, response: anytype) !void { + self.mutex.lock(); + defer self.mutex.unlock(); + + 
self.aw.clearRetainingCapacity(); + try std.json.Stringify.value(response, .{ .emit_null_optional_fields = false }, &self.aw.writer); + try self.aw.writer.writeByte('\n'); + try self.writer.writeAll(self.aw.writer.buffered()); + try self.writer.flush(); +} + +pub fn sendResult(self: *Self, id: std.json.Value, result: anytype) !void { + const GenericResponse = struct { + jsonrpc: []const u8 = "2.0", + id: std.json.Value, + result: @TypeOf(result), + }; + try self.sendResponse(GenericResponse{ + .id = id, + .result = result, + }); +} + +pub fn sendError(self: *Self, id: std.json.Value, code: protocol.ErrorCode, message: []const u8) !void { + try self.sendResponse(protocol.Response{ + .id = id, + .@"error" = protocol.Error{ + .code = @intFromEnum(code), + .message = message, + }, + }); +} + +test "MCP.Server - Integration: synchronous smoke test" { + defer testing.reset(); + const allocator = testing.allocator; + const app = testing.test_app; + + const input = + \\{"jsonrpc":"2.0","id":1,"method":"initialize","params":{"protocolVersion":"2024-11-05","capabilities":{},"clientInfo":{"name":"test-client","version":"1.0.0"}}} + ; + + var in_reader: std.io.Reader = .fixed(input); + var out_alloc: std.io.Writer.Allocating = .init(testing.arena_allocator); + defer out_alloc.deinit(); + + var server = try Self.init(allocator, app, &out_alloc.writer); + defer server.deinit(); + + try router.processRequests(server, &in_reader); + + try testing.expectJson(.{ .id = 1 }, out_alloc.writer.buffered()); +} diff --git a/src/mcp/protocol.zig b/src/mcp/protocol.zig new file mode 100644 index 00000000..5f5dc7f2 --- /dev/null +++ b/src/mcp/protocol.zig @@ -0,0 +1,304 @@ +const std = @import("std"); + +pub const Request = struct { + jsonrpc: []const u8 = "2.0", + id: ?std.json.Value = null, + method: []const u8, + params: ?std.json.Value = null, +}; + +pub const Response = struct { + jsonrpc: []const u8 = "2.0", + id: std.json.Value, + result: ?std.json.Value = null, + @"error": ?Error = 
null, +}; + +pub const Error = struct { + code: i64, + message: []const u8, + data: ?std.json.Value = null, +}; + +pub const ErrorCode = enum(i64) { + ParseError = -32700, + InvalidRequest = -32600, + MethodNotFound = -32601, + InvalidParams = -32602, + InternalError = -32603, +}; + +pub const Notification = struct { + jsonrpc: []const u8 = "2.0", + method: []const u8, + params: ?std.json.Value = null, +}; + +// Core MCP Types mapping to official specification +pub const InitializeRequest = struct { + jsonrpc: []const u8 = "2.0", + id: std.json.Value, + method: []const u8 = "initialize", + params: InitializeParams, +}; + +pub const InitializeParams = struct { + protocolVersion: []const u8, + capabilities: Capabilities, + clientInfo: Implementation, +}; + +pub const Capabilities = struct { + experimental: ?std.json.Value = null, + roots: ?RootsCapability = null, + sampling: ?SamplingCapability = null, +}; + +pub const RootsCapability = struct { + listChanged: ?bool = null, +}; + +pub const SamplingCapability = struct {}; + +pub const Implementation = struct { + name: []const u8, + version: []const u8, +}; + +pub const InitializeResult = struct { + protocolVersion: []const u8, + capabilities: ServerCapabilities, + serverInfo: Implementation, +}; + +pub const ServerCapabilities = struct { + experimental: ?std.json.Value = null, + logging: ?LoggingCapability = null, + prompts: ?PromptsCapability = null, + resources: ?ResourcesCapability = null, + tools: ?ToolsCapability = null, +}; + +pub const LoggingCapability = struct {}; +pub const PromptsCapability = struct { + listChanged: ?bool = null, +}; +pub const ResourcesCapability = struct { + subscribe: ?bool = null, + listChanged: ?bool = null, +}; +pub const ToolsCapability = struct { + listChanged: ?bool = null, +}; + +pub const Tool = struct { + name: []const u8, + description: ?[]const u8 = null, + inputSchema: []const u8, + + pub fn jsonStringify(self: @This(), jw: anytype) !void { + try jw.beginObject(); + try 
jw.objectField("name"); + try jw.write(self.name); + if (self.description) |d| { + try jw.objectField("description"); + try jw.write(d); + } + try jw.objectField("inputSchema"); + _ = try jw.beginWriteRaw(); + try jw.writer.writeAll(self.inputSchema); + jw.endWriteRaw(); + try jw.endObject(); + } +}; + +pub fn minify(comptime json: []const u8) []const u8 { + return comptime blk: { + var res: []const u8 = ""; + var in_string = false; + var escaped = false; + for (json) |c| { + if (in_string) { + res = res ++ [1]u8{c}; + if (escaped) { + escaped = false; + } else if (c == '\\') { + escaped = true; + } else if (c == '"') { + in_string = false; + } + } else { + switch (c) { + ' ', '\n', '\r', '\t' => continue, + '"' => { + in_string = true; + res = res ++ [1]u8{c}; + }, + else => res = res ++ [1]u8{c}, + } + } + } + break :blk res; + }; +} + +pub const Resource = struct { + uri: []const u8, + name: []const u8, + description: ?[]const u8 = null, + mimeType: ?[]const u8 = null, +}; + +pub fn TextContent(comptime T: type) type { + return struct { + type: []const u8 = "text", + text: T, + }; +} + +pub fn CallToolResult(comptime T: type) type { + return struct { + content: []const TextContent(T), + isError: bool = false, + }; +} + +pub const JsonEscapingWriter = struct { + inner_writer: *std.Io.Writer, + writer: std.Io.Writer, + + pub fn init(inner_writer: *std.Io.Writer) JsonEscapingWriter { + return .{ + .inner_writer = inner_writer, + .writer = .{ + .vtable = &vtable, + .buffer = &.{}, + }, + }; + } + + const vtable = std.Io.Writer.VTable{ + .drain = drain, + }; + + fn drain(w: *std.Io.Writer, data: []const []const u8, splat: usize) std.Io.Writer.Error!usize { + const self: *JsonEscapingWriter = @alignCast(@fieldParentPtr("writer", w)); + var total: usize = 0; + for (data[0 .. 
data.len - 1]) |slice| { + std.json.Stringify.encodeJsonStringChars(slice, .{}, self.inner_writer) catch return error.WriteFailed; + total += slice.len; + } + const pattern = data[data.len - 1]; + for (0..splat) |_| { + std.json.Stringify.encodeJsonStringChars(pattern, .{}, self.inner_writer) catch return error.WriteFailed; + total += pattern.len; + } + return total; + } +}; + +const testing = @import("../testing.zig"); + +test "MCP.protocol - request parsing" { + defer testing.reset(); + const raw_json = + \\{ + \\ "jsonrpc": "2.0", + \\ "id": 1, + \\ "method": "initialize", + \\ "params": { + \\ "protocolVersion": "2024-11-05", + \\ "capabilities": {}, + \\ "clientInfo": { + \\ "name": "test-client", + \\ "version": "1.0.0" + \\ } + \\ } + \\} + ; + + const parsed = try std.json.parseFromSlice(Request, testing.arena_allocator, raw_json, .{ .ignore_unknown_fields = true }); + defer parsed.deinit(); + + const req = parsed.value; + try testing.expectString("2.0", req.jsonrpc); + try testing.expectString("initialize", req.method); + try testing.expect(req.id.? 
== .integer); + try testing.expectEqual(@as(i64, 1), req.id.?.integer); + try testing.expect(req.params != null); + + // Test nested parsing of InitializeParams + const init_params = try std.json.parseFromValue(InitializeParams, testing.arena_allocator, req.params.?, .{ .ignore_unknown_fields = true }); + defer init_params.deinit(); + + try testing.expectString("2024-11-05", init_params.value.protocolVersion); + try testing.expectString("test-client", init_params.value.clientInfo.name); + try testing.expectString("1.0.0", init_params.value.clientInfo.version); +} + +test "MCP.protocol - response formatting" { + defer testing.reset(); + const response = Response{ + .id = .{ .integer = 42 }, + .result = .{ .string = "success" }, + }; + + var aw: std.Io.Writer.Allocating = .init(testing.arena_allocator); + defer aw.deinit(); + try std.json.Stringify.value(response, .{ .emit_null_optional_fields = false }, &aw.writer); + + try testing.expectString("{\"jsonrpc\":\"2.0\",\"id\":42,\"result\":\"success\"}", aw.written()); +} + +test "MCP.protocol - error formatting" { + defer testing.reset(); + const response = Response{ + .id = .{ .string = "abc" }, + .@"error" = .{ + .code = @intFromEnum(ErrorCode.MethodNotFound), + .message = "Method not found", + }, + }; + + var aw: std.Io.Writer.Allocating = .init(testing.arena_allocator); + defer aw.deinit(); + try std.json.Stringify.value(response, .{ .emit_null_optional_fields = false }, &aw.writer); + + try testing.expectString("{\"jsonrpc\":\"2.0\",\"id\":\"abc\",\"error\":{\"code\":-32601,\"message\":\"Method not found\"}}", aw.written()); +} + +test "MCP.protocol - JsonEscapingWriter" { + defer testing.reset(); + var aw: std.Io.Writer.Allocating = .init(testing.arena_allocator); + defer aw.deinit(); + + var escaping_writer = JsonEscapingWriter.init(&aw.writer); + + // test newlines and quotes + try escaping_writer.writer.writeAll("hello\n\"world\""); + + // the writer outputs escaped string chars without surrounding quotes + 
try testing.expectString("hello\\n\\\"world\\\"", aw.written()); +} + +test "MCP.protocol - Tool serialization" { + defer testing.reset(); + const t = Tool{ + .name = "test", + .inputSchema = minify( + \\{ + \\ "type": "object", + \\ "properties": { + \\ "foo": { "type": "string" } + \\ } + \\} + ), + }; + + var aw: std.Io.Writer.Allocating = .init(testing.arena_allocator); + defer aw.deinit(); + + try std.json.Stringify.value(t, .{}, &aw.writer); + + try testing.expectString("{\"name\":\"test\",\"inputSchema\":{\"type\":\"object\",\"properties\":{\"foo\":{\"type\":\"string\"}}}}", aw.written()); +} diff --git a/src/mcp/resources.zig b/src/mcp/resources.zig new file mode 100644 index 00000000..7fd59dc3 --- /dev/null +++ b/src/mcp/resources.zig @@ -0,0 +1,108 @@ +const std = @import("std"); + +const lp = @import("lightpanda"); +const log = lp.log; + +const protocol = @import("protocol.zig"); +const Server = @import("Server.zig"); + +pub const resource_list = [_]protocol.Resource{ + .{ + .uri = "mcp://page/html", + .name = "Page HTML", + .description = "The serialized HTML DOM of the current page", + .mimeType = "text/html", + }, + .{ + .uri = "mcp://page/markdown", + .name = "Page Markdown", + .description = "The token-efficient markdown representation of the current page", + .mimeType = "text/markdown", + }, +}; + +pub fn handleList(server: *Server, req: protocol.Request) !void { + try server.sendResult(req.id.?, .{ .resources = &resource_list }); +} + +const ReadParams = struct { + uri: []const u8, +}; + +const ResourceStreamingResult = struct { + contents: []const struct { + uri: []const u8, + mimeType: []const u8, + text: StreamingText, + }, + + const StreamingText = struct { + server: *Server, + format: enum { html, markdown }, + + pub fn jsonStringify(self: @This(), jw: *std.json.Stringify) !void { + try jw.beginWriteRaw(); + try jw.writer.writeByte('"'); + var escaped = protocol.JsonEscapingWriter.init(jw.writer); + switch (self.format) { + .html => 
lp.dump.root(self.server.page.document, .{}, &escaped.writer, self.server.page) catch |err| { + log.err(.mcp, "html dump failed", .{ .err = err }); + }, + .markdown => lp.markdown.dump(self.server.page.document.asNode(), .{}, &escaped.writer, self.server.page) catch |err| { + log.err(.mcp, "markdown dump failed", .{ .err = err }); + }, + } + try jw.writer.writeByte('"'); + jw.endWriteRaw(); + } + }; +}; + +const ResourceUri = enum { + @"mcp://page/html", + @"mcp://page/markdown", +}; + +const resource_map = std.StaticStringMap(ResourceUri).initComptime(.{ + .{ "mcp://page/html", .@"mcp://page/html" }, + .{ "mcp://page/markdown", .@"mcp://page/markdown" }, +}); + +pub fn handleRead(server: *Server, arena: std.mem.Allocator, req: protocol.Request) !void { + if (req.params == null) { + return server.sendError(req.id.?, .InvalidParams, "Missing params"); + } + + const params = std.json.parseFromValueLeaky(ReadParams, arena, req.params.?, .{ .ignore_unknown_fields = true }) catch { + return server.sendError(req.id.?, .InvalidParams, "Invalid params"); + }; + + const uri = resource_map.get(params.uri) orelse { + return server.sendError(req.id.?, .InvalidRequest, "Resource not found"); + }; + + switch (uri) { + .@"mcp://page/html" => { + const result: ResourceStreamingResult = .{ + .contents = &.{.{ + .uri = params.uri, + .mimeType = "text/html", + .text = .{ .server = server, .format = .html }, + }}, + }; + try server.sendResult(req.id.?, result); + }, + .@"mcp://page/markdown" => { + const result: ResourceStreamingResult = .{ + .contents = &.{.{ + .uri = params.uri, + .mimeType = "text/markdown", + .text = .{ .server = server, .format = .markdown }, + }}, + }; + try server.sendResult(req.id.?, result); + }, + } +} + +const testing = @import("../testing.zig"); diff --git a/src/mcp/router.zig b/src/mcp/router.zig new file mode 100644 index 00000000..14f4a3ab --- /dev/null +++ b/src/mcp/router.zig @@ -0,0 +1,143 @@ +const std = @import("std"); +const lp = 
@import("lightpanda"); +const protocol = @import("protocol.zig"); +const resources = @import("resources.zig"); +const Server = @import("Server.zig"); +const tools = @import("tools.zig"); + +pub fn processRequests(server: *Server, reader: *std.io.Reader) !void { + var arena: std.heap.ArenaAllocator = .init(server.allocator); + defer arena.deinit(); + + while (true) { + _ = arena.reset(.retain_capacity); + const aa = arena.allocator(); + + const buffered_line = reader.takeDelimiter('\n') catch |err| switch (err) { + error.StreamTooLong => { + log.err(.mcp, "Message too long", .{}); + continue; + }, + else => return err, + } orelse break; + + const trimmed = std.mem.trim(u8, buffered_line, " \r\t"); + if (trimmed.len > 0) { + handleMessage(server, aa, trimmed) catch |err| { + log.err(.mcp, "Failed to handle message", .{ .err = err, .msg = trimmed }); + }; + } + } +} + +const log = @import("../log.zig"); + +const Method = enum { + initialize, + @"notifications/initialized", + @"tools/list", + @"tools/call", + @"resources/list", + @"resources/read", +}; + +const method_map = std.StaticStringMap(Method).initComptime(.{ + .{ "initialize", .initialize }, + .{ "notifications/initialized", .@"notifications/initialized" }, + .{ "tools/list", .@"tools/list" }, + .{ "tools/call", .@"tools/call" }, + .{ "resources/list", .@"resources/list" }, + .{ "resources/read", .@"resources/read" }, +}); + +pub fn handleMessage(server: *Server, arena: std.mem.Allocator, msg: []const u8) !void { + const req = std.json.parseFromSliceLeaky(protocol.Request, arena, msg, .{ + .ignore_unknown_fields = true, + }) catch |err| { + log.warn(.mcp, "JSON Parse Error", .{ .err = err, .msg = msg }); + try server.sendError(.null, .ParseError, "Parse error"); + return; + }; + + const method = method_map.get(req.method) orelse { + if (req.id != null) { + try server.sendError(req.id.?, .MethodNotFound, "Method not found"); + } + return; + }; + + switch (method) { + .initialize => try handleInitialize(server, 
req), + .@"notifications/initialized" => {}, + .@"tools/list" => try tools.handleList(server, arena, req), + .@"tools/call" => try tools.handleCall(server, arena, req), + .@"resources/list" => try resources.handleList(server, req), + .@"resources/read" => try resources.handleRead(server, arena, req), + } +} + +fn handleInitialize(server: *Server, req: protocol.Request) !void { + const result = protocol.InitializeResult{ + .protocolVersion = "2025-11-25", + .capabilities = .{ + .resources = .{}, + .tools = .{}, + }, + .serverInfo = .{ + .name = "lightpanda", + .version = "0.1.0", + }, + }; + + try server.sendResult(req.id.?, result); +} + +const testing = @import("../testing.zig"); + +test "MCP.router - handleMessage - synchronous unit tests" { + defer testing.reset(); + const allocator = testing.allocator; + const app = testing.test_app; + + var out_alloc: std.io.Writer.Allocating = .init(testing.arena_allocator); + defer out_alloc.deinit(); + + var server = try Server.init(allocator, app, &out_alloc.writer); + defer server.deinit(); + + const aa = testing.arena_allocator; + + // 1. Valid handshake + try handleMessage(server, aa, + \\{"jsonrpc":"2.0","id":1,"method":"initialize","params":{"protocolVersion":"2024-11-05","capabilities":{},"clientInfo":{"name":"test-client","version":"1.0.0"}}} + ); + try testing.expectJson( + \\{ "id": 1, "result": { "capabilities": { "tools": {} } } } + , out_alloc.writer.buffered()); + out_alloc.writer.end = 0; + + // 2. Tools list + try handleMessage(server, aa, + \\{"jsonrpc":"2.0","id":2,"method":"tools/list"} + ); + try testing.expectJson(.{ .id = 2 }, out_alloc.writer.buffered()); + try testing.expect(std.mem.indexOf(u8, out_alloc.writer.buffered(), "\"name\":\"goto\"") != null); + out_alloc.writer.end = 0; + + // 3. 
Method not found + try handleMessage(server, aa, + \\{"jsonrpc":"2.0","id":3,"method":"unknown_method"} + ); + try testing.expectJson(.{ .id = 3, .@"error" = .{ .code = -32601 } }, out_alloc.writer.buffered()); + out_alloc.writer.end = 0; + + // 4. Parse error + { + const old_filter = log.opts.filter_scopes; + log.opts.filter_scopes = &.{.mcp}; + defer log.opts.filter_scopes = old_filter; + + try handleMessage(server, aa, "invalid json"); + try testing.expectJson("{\"id\": null, \"error\": {\"code\": -32700}}", out_alloc.writer.buffered()); + } +} diff --git a/src/mcp/tools.zig b/src/mcp/tools.zig new file mode 100644 index 00000000..146bd7db --- /dev/null +++ b/src/mcp/tools.zig @@ -0,0 +1,305 @@ +const std = @import("std"); + +const lp = @import("lightpanda"); +const log = lp.log; +const js = lp.js; + +const Element = @import("../browser/webapi/Element.zig"); +const Selector = @import("../browser/webapi/selector/Selector.zig"); +const protocol = @import("protocol.zig"); +const Server = @import("Server.zig"); + +pub const tool_list = [_]protocol.Tool{ + .{ + .name = "goto", + .description = "Navigate to a specified URL and load the page in memory so it can be reused later for info extraction.", + .inputSchema = protocol.minify( + \\{ + \\ "type": "object", + \\ "properties": { + \\ "url": { "type": "string", "description": "The URL to navigate to, must be a valid URL." } + \\ }, + \\ "required": ["url"] + \\} + ), + }, + .{ + .name = "markdown", + .description = "Get the page content in markdown format. If a url is provided, it navigates to that url first.", + .inputSchema = protocol.minify( + \\{ + \\ "type": "object", + \\ "properties": { + \\ "url": { "type": "string", "description": "Optional URL to navigate to before fetching markdown." } + \\ } + \\} + ), + }, + .{ + .name = "links", + .description = "Extract all links in the opened page. 
If a url is provided, it navigates to that url first.",
        .inputSchema = protocol.minify(
            \\{
            \\  "type": "object",
            \\  "properties": {
            \\    "url": { "type": "string", "description": "Optional URL to navigate to before extracting links." }
            \\  }
            \\}
        ),
    },
    .{
        .name = "evaluate",
        .description = "Evaluate JavaScript in the current page context. If a url is provided, it navigates to that url first.",
        .inputSchema = protocol.minify(
            \\{
            \\  "type": "object",
            \\  "properties": {
            \\    "script": { "type": "string" },
            \\    "url": { "type": "string", "description": "Optional URL to navigate to before evaluating." }
            \\  },
            \\  "required": ["script"]
            \\}
        ),
    },
};

/// Handles `tools/list` by replying with the static `tool_list`.
///
/// A request without an id cannot be answered (JSON-RPC notifications get no
/// response), so it is ignored instead of unwrapping a null id.
pub fn handleList(server: *Server, arena: std.mem.Allocator, req: protocol.Request) !void {
    _ = arena;
    const id = req.id orelse return;
    try server.sendResult(id, .{ .tools = &tool_list });
}

/// Arguments of the `goto` / `navigate` tools.
const GotoParams = struct {
    url: [:0]const u8,
};

/// Arguments of the `evaluate` tool; `url` is optional.
const EvaluateParams = struct {
    script: [:0]const u8,
    url: ?[:0]const u8 = null,
};

/// Arguments of the tools whose only input is an optional url
/// (`markdown` and `links`).
const OptionalUrlParams = struct {
    url: ?[:0]const u8 = null,
};

/// Value handed to `server.session.wait` after a navigation has been started.
/// NOTE(review): presumably milliseconds - confirm against `Session.wait`.
const goto_wait_timeout = 5000;

/// Text payload that is produced while the response is being serialized,
/// rather than being built as a string beforehand.
const ToolStreamingText = struct {
    server: *Server,
    action: enum { markdown, links },

    /// Emits the payload as a single JSON string: an opening quote, the content
    /// written through `protocol.JsonEscapingWriter`, then a closing quote.
    ///
    /// A failed markdown dump, link query or href lookup is logged and the
    /// string is still closed. Write errors in the links branch and on the
    /// surrounding quotes are returned to the caller.
    pub fn jsonStringify(self: @This(), jw: *std.json.Stringify) !void {
        try jw.beginWriteRaw();
        try jw.writer.writeByte('"');
        var escaped = protocol.JsonEscapingWriter.init(jw.writer);
        const w = &escaped.writer;
        switch (self.action) {
            .markdown => lp.markdown.dump(self.server.page.document.asNode(), .{}, w, self.server.page) catch |err| {
                log.err(.mcp, "markdown dump failed", .{ .err = err });
            },
            .links => {
                if (Selector.querySelectorAll(self.server.page.document.asNode(), "a[href]", self.server.page)) |list| {
                    defer list.deinit(self.server.page);
                    // One href per line; `first` suppresses the leading separator.
                    var first = true;
                    for (list._nodes) |node| {
                        if (node.is(Element.Html.Anchor)) |anchor| {
                            const href = anchor.getHref(self.server.page) catch |err| {
                                log.err(.mcp, "resolve href failed", .{ .err = err });
                                continue;
                            };

                            if (href.len > 0) {
                                if (!first) try w.writeByte('\n');
                                try w.writeAll(href);
                                first = false;
                            }
                        }
                    }
                } else |err| {
                    log.err(.mcp, "query links failed", .{ .err = err });
                }
            },
        }
        try jw.writer.writeByte('"');
        jw.endWriteRaw();
    }
};

/// Actions a `tools/call` request can be dispatched to.
const ToolAction = enum {
    goto,
    navigate,
    markdown,
    links,
    evaluate,
};

/// Maps the tool name found in a request to its action.
const tool_map = std.StaticStringMap(ToolAction).initComptime(.{
    .{ "goto", .goto },
    .{ "navigate", .navigate },
    .{ "markdown", .markdown },
    .{ "links", .links },
    .{ "evaluate", .evaluate },
});

/// Handles `tools/call`: validates the params, looks the tool up in `tool_map`
/// and runs its handler (`goto` and `navigate` share `handleGoto`).
///
/// Replies with `InvalidParams` when params are missing or malformed and with
/// `MethodNotFound` for an unknown tool name. A request without an id is
/// ignored: every handler needs the id to address its reply.
pub fn handleCall(server: *Server, arena: std.mem.Allocator, req: protocol.Request) !void {
    const id = req.id orelse return;

    const params = req.params orelse {
        return server.sendError(id, .InvalidParams, "Missing params");
    };

    const CallParams = struct {
        name: []const u8,
        arguments: ?std.json.Value = null,
    };

    const call_params = std.json.parseFromValueLeaky(CallParams, arena, params, .{ .ignore_unknown_fields = true }) catch {
        return server.sendError(id, .InvalidParams, "Invalid params");
    };

    const action = tool_map.get(call_params.name) orelse {
        return server.sendError(id, .MethodNotFound, "Tool not found");
    };

    switch (action) {
        .goto, .navigate => try handleGoto(server, arena, id, call_params.arguments),
        .markdown => try handleMarkdown(server, arena, id, call_params.arguments),
        .links => try handleLinks(server, arena, id, call_params.arguments),
        .evaluate => try handleEvaluate(server, arena, id, call_params.arguments),
    }
}

/// Parses `GotoParams`, navigates to the url and replies with a fixed
/// confirmation text.
fn handleGoto(server: *Server, arena: std.mem.Allocator, id: std.json.Value, arguments: ?std.json.Value) !void {
    const args = try parseArguments(GotoParams, arena, arguments, server, id, "goto");
    try performGoto(server, args.url, id);

    const content = [_]protocol.TextContent([]const u8){.{ .text = "Navigated successfully." }};
    try server.sendResult(id, protocol.CallToolResult([]const u8){ .content = &content });
}

/// Optionally navigates to the requested url, then replies with the markdown
/// rendering of the current page.
fn handleMarkdown(server: *Server, arena: std.mem.Allocator, id: std.json.Value, arguments: ?std.json.Value) !void {
    try navigateIfRequested(server, arena, id, arguments, "markdown");
    try sendStreamingText(server, id, .{ .server = server, .action = .markdown });
}

/// Optionally navigates to the requested url, then replies with the hrefs of
/// the `a[href]` elements of the current page.
fn handleLinks(server: *Server, arena: std.mem.Allocator, id: std.json.Value, arguments: ?std.json.Value) !void {
    try navigateIfRequested(server, arena, id, arguments, "links");
    try sendStreamingText(server, id, .{ .server = server, .action = .links });
}

/// Navigates first when the tool arguments carry a `url`.
///
/// Absent arguments mean "use the current page". Arguments that are present
/// but malformed are reported as `InvalidParams` through `parseArguments`
/// rather than being silently ignored, so a bad `url` never yields the content
/// of an unrelated page.
fn navigateIfRequested(server: *Server, arena: std.mem.Allocator, id: std.json.Value, arguments: ?std.json.Value, tool_name: []const u8) !void {
    if (arguments == null) return;

    const args = try parseArguments(OptionalUrlParams, arena, arguments, server, id, tool_name);
    if (args.url) |url| {
        try performGoto(server, url, id);
    }
}

/// Sends a tool result whose single text item is rendered lazily by
/// `ToolStreamingText.jsonStringify`.
fn sendStreamingText(server: *Server, id: std.json.Value, text: ToolStreamingText) !void {
    const content = [_]protocol.TextContent(ToolStreamingText){.{ .text = text }};
    try server.sendResult(id, protocol.CallToolResult(ToolStreamingText){ .content = &content });
}

/// Runs the `evaluate` tool: optionally navigates, compiles and runs the
/// script in the page's JS context and replies with the result as a string.
///
/// When the script fails, the caught exception is formatted and returned as a
/// tool result with `isError = true`.
fn handleEvaluate(server: *Server, arena: std.mem.Allocator, id: std.json.Value, arguments: ?std.json.Value) !void {
    const args = try parseArguments(EvaluateParams, arena, arguments, server, id, "evaluate");

    if (args.url) |url| {
        try performGoto(server, url, id);
    }

    var ls: js.Local.Scope = undefined;
    server.page.js.localScope(&ls);
    defer ls.deinit();

    var try_catch: js.TryCatch = undefined;
    try_catch.init(&ls.local);
    defer try_catch.deinit();

    const js_result = ls.local.compileAndRun(args.script, null) catch |err| {
        const caught = try_catch.caughtOrError(arena, err);
        var aw: std.Io.Writer.Allocating = .init(arena);
        try caught.format(&aw.writer);

        const content = [_]protocol.TextContent([]const u8){.{ .text = aw.written() }};
        return server.sendResult(id, protocol.CallToolResult([]const u8){ .content = &content, .isError = true });
    };

    // Keep the "undefined" fallback, but leave a trace of why the value could
    // not be converted.
    const str_result: []const u8 = js_result.toStringSliceWithAlloc(arena) catch |err| blk: {
        log.err(.mcp, "stringify result failed", .{ .err = err });
        break :blk "undefined";
    };

    const content = [_]protocol.TextContent([]const u8){.{ .text = str_result }};
    try server.sendResult(id, protocol.CallToolResult([]const u8){ .content = &content });
}

/// Parses the tool arguments into `T`, ignoring unknown fields.
///
/// When the arguments are missing or do not match `T`, an `InvalidParams`
/// error is sent to the client and `error.InvalidParams` is returned.
/// Allocations made while parsing come from `arena` and are not freed here.
fn parseArguments(comptime T: type, arena: std.mem.Allocator, arguments: ?std.json.Value, server: *Server, id: std.json.Value, tool_name: []const u8) !T {
    const raw = arguments orelse {
        try server.sendError(id, .InvalidParams, "Missing arguments");
        return error.InvalidParams;
    };
    return std.json.parseFromValueLeaky(T, arena, raw, .{ .ignore_unknown_fields = true }) catch {
        const msg = std.fmt.allocPrint(arena, "Invalid arguments for {s}", .{tool_name}) catch "Invalid arguments";
        try server.sendError(id, .InvalidParams, msg);
        return error.InvalidParams;
    };
}

/// Starts a navigation of `server.page` to `url`, then calls
/// `server.session.wait` with `goto_wait_timeout`; the result of the wait is
/// discarded.
///
/// When the navigation cannot be started, the cause is logged, an
/// `InternalError` is sent to the client and `error.NavigationFailed` is
/// returned.
fn performGoto(server: *Server, url: [:0]const u8, id: std.json.Value) !void {
    _ = server.page.navigate(url, .{
        .reason = .address_bar,
        .kind = .{ .push = null },
    }) catch |err| {
        log.err(.mcp, "navigation failed", .{ .err = err, .url = url });
        try server.sendError(id, .InternalError, "Internal error during navigation");
        return error.NavigationFailed;
    };

    _ = server.session.wait(goto_wait_timeout);
}

const testing = @import("../testing.zig");
const router = @import("router.zig");

test "MCP - evaluate error reporting" {
    defer testing.reset();
    const allocator = testing.allocator;
    const app = testing.test_app;

    var out_alloc: std.Io.Writer.Allocating = .init(testing.arena_allocator);
    defer out_alloc.deinit();

    var server = try Server.init(allocator, app, &out_alloc.writer);
    defer server.deinit();

    const aa = testing.arena_allocator;

    // Call evaluate with a script that throws an error
    const msg =
        \\{
        \\  "jsonrpc": "2.0",
        \\  "id": 1,
        \\  "method": "tools/call",
        \\  "params": {
        \\    "name": "evaluate",
        \\    "arguments": {
        \\      "script": "throw new Error('test error')"
        \\    }
        \\  }
        \\}
    ;

    try router.handleMessage(server, aa, msg);

    try testing.expectJson(
        \\{
        \\  "id": 1,
        \\  "result": {
        \\    "isError": true,
        \\    "content": [
        \\      { "type": "text" }
        \\    ]
        \\  }
        \\}
    , out_alloc.writer.buffered());
}