diff --git a/src/mcp/protocol.zig b/src/mcp/protocol.zig index c195560f..b9385ddb 100644 --- a/src/mcp/protocol.zig +++ b/src/mcp/protocol.zig @@ -96,6 +96,40 @@ pub const Resource = struct { mimeType: ?[]const u8 = null, }; +pub const JsonEscapingWriter = struct { + inner_writer: *std.Io.Writer, + writer: std.Io.Writer, + + pub fn init(inner_writer: *std.Io.Writer) JsonEscapingWriter { + return .{ + .inner_writer = inner_writer, + .writer = .{ + .vtable = &vtable, + .buffer = &.{}, + }, + }; + } + + const vtable = std.Io.Writer.VTable{ + .drain = drain, + }; + + fn drain(w: *std.Io.Writer, data: []const []const u8, splat: usize) std.Io.Writer.Error!usize { + const self: *JsonEscapingWriter = @alignCast(@fieldParentPtr("writer", w)); + var total: usize = 0; + for (data[0 .. data.len - 1]) |slice| { + std.json.Stringify.encodeJsonStringChars(slice, .{}, self.inner_writer) catch return error.WriteFailed; + total += slice.len; + } + const pattern = data[data.len - 1]; + for (0..splat) |_| { + std.json.Stringify.encodeJsonStringChars(pattern, .{}, self.inner_writer) catch return error.WriteFailed; + total += pattern.len; + } + return total; + } +}; + const testing = @import("../testing.zig"); test "protocol request parsing" { diff --git a/src/mcp/resources.zig b/src/mcp/resources.zig index e9553167..d5f16770 100644 --- a/src/mcp/resources.zig +++ b/src/mcp/resources.zig @@ -34,6 +34,41 @@ const ReadParams = struct { uri: []const u8, }; +const ResourceStreamingResult = struct { + contents: []const struct { + uri: []const u8, + mimeType: []const u8, + text: StreamingText, + }, + + const StreamingText = struct { + server: *Server, + uri: []const u8, + format: enum { html, markdown }, + + pub fn jsonStringify(self: @This(), jw: *std.json.Stringify) !void { + try jw.beginObject(); + try jw.objectField("uri"); + try jw.write(self.uri); + try jw.objectField("mimeType"); + try jw.write(if (self.format == .html) "text/html" else "text/markdown"); + try jw.objectField("text"); + + try jw.beginWriteRaw(); + try jw.writer.writeByte('"'); + var escaped = protocol.JsonEscapingWriter.init(jw.writer); + switch (self.format) { + .html => try lp.dump.root(self.server.page.document, .{}, &escaped.writer, self.server.page), + .markdown => try lp.markdown.dump(self.server.page.document.asNode(), .{}, &escaped.writer, self.server.page), + } + try jw.writer.writeByte('"'); + jw.endWriteRaw(); + + try jw.endObject(); + } + }; +}; + pub fn handleRead(server: *Server, arena: std.mem.Allocator, req: protocol.Request) !void { if (req.params == null) { return sendError(server, req.id.?, -32602, "Missing params"); @@ -44,37 +79,23 @@ pub fn handleRead(server: *Server, arena: std.mem.Allocator, req: protocol.Reque }; if (std.mem.eql(u8, params.uri, "mcp://page/html")) { - var aw = std.Io.Writer.Allocating.init(arena); - lp.dump.root(server.page.document, .{}, &aw.writer, server.page) catch { - return sendError(server, req.id.?, -32603, "Internal error reading HTML"); + const result = ResourceStreamingResult{ + .contents = &.{.{ + .uri = params.uri, + .mimeType = "text/html", + .text = .{ .server = server, .uri = params.uri, .format = .html }, + }}, }; - - const contents = [_]struct { - uri: []const u8, - mimeType: []const u8, - text: []const u8, - }{.{ - .uri = params.uri, - .mimeType = "text/html", - .text = aw.written(), - }}; - try sendResult(server, req.id.?, .{ .contents = &contents }); + try sendResult(server, req.id.?, result); } else if (std.mem.eql(u8, params.uri, "mcp://page/markdown")) { - var aw = std.Io.Writer.Allocating.init(arena); - lp.markdown.dump(server.page.document.asNode(), .{}, &aw.writer, server.page) catch { - return sendError(server, req.id.?, -32603, "Internal error reading Markdown"); + const result = ResourceStreamingResult{ + .contents = &.{.{ + .uri = params.uri, + .mimeType = "text/markdown", + .text = .{ .server = server, .uri = params.uri, .format = .markdown }, + }}, }; - - const contents = [_]struct { - uri: []const u8, - mimeType: []const u8, - text: []const u8, - }{.{ - .uri = params.uri, - .mimeType = "text/markdown", - .text = aw.written(), - }}; - try sendResult(server, req.id.?, .{ .contents = &contents }); + try sendResult(server, req.id.?, result); } else { return sendError(server, req.id.?, -32602, "Resource not found"); } diff --git a/src/mcp/tools.zig b/src/mcp/tools.zig index 5f1e3cf5..80253a6a 100644 --- a/src/mcp/tools.zig +++ b/src/mcp/tools.zig @@ -116,6 +116,39 @@ const OverParams = struct { result: []const u8, }; +const ToolStreamingText = struct { + server: *Server, + action: enum { markdown, links }, + + pub fn jsonStringify(self: @This(), jw: *std.json.Stringify) !void { + try jw.beginWriteRaw(); + try jw.writer.writeByte('"'); + var escaped = protocol.JsonEscapingWriter.init(jw.writer); + const w = &escaped.writer; + switch (self.action) { + .markdown => try lp.markdown.dump(self.server.page.document.asNode(), .{}, w, self.server.page), + .links => { + const list = Selector.querySelectorAll(self.server.page.document.asNode(), "a[href]", self.server.page) catch |err| { + log.err(.mcp, "Error querying links: {s}", .{@errorName(err)}); + return; + }; + var first = true; + for (list._nodes) |node| { + if (node.is(Element)) |el| { + if (el.getAttributeSafe(String.wrap("href"))) |href| { + if (!first) try w.writeByte('\n'); + try w.writeAll(href); + first = false; + } + } + } + }, + } + try jw.writer.writeByte('"'); + jw.endWriteRaw(); + } +}; + pub fn handleCall(server: *Server, arena: std.mem.Allocator, req: protocol.Request) !void { if (req.params == null) { return sendError(server, req.id.?, -32602, "Missing params"); @@ -183,13 +216,16 @@ pub fn handleCall(server: *Server, arena: std.mem.Allocator, req: protocol.Reque } } else |_| {} } - var aw = std.Io.Writer.Allocating.init(arena); - lp.markdown.dump(server.page.document.asNode(), .{}, &aw.writer, server.page) catch { - return sendError(server, req.id.?, -32603, "Internal error parsing markdown"); - }; - const content = [_]struct { type: []const u8, text: []const u8 }{.{ .type = "text", .text = aw.written() }}; - try sendResult(server, req.id.?, .{ .content = &content }); + const result = struct { + content: []const struct { type: []const u8, text: ToolStreamingText }, + }{ + .content = &.{.{ + .type = "text", + .text = .{ .server = server, .action = .markdown }, + }}, + }; + try sendResult(server, req.id.?, result); } else if (std.mem.eql(u8, call_params.name, "links")) { const LinksParams = struct { url: ?[]const u8 = null, @@ -203,24 +239,16 @@ pub fn handleCall(server: *Server, arena: std.mem.Allocator, req: protocol.Reque } } else |_| {} } - const list = Selector.querySelectorAll(server.page.document.asNode(), "a[href]", server.page) catch { - return sendError(server, req.id.?, -32603, "Internal error querying selector"); + + const result = struct { + content: []const struct { type: []const u8, text: ToolStreamingText }, + }{ + .content = &.{.{ + .type = "text", + .text = .{ .server = server, .action = .links }, + }}, }; - - var aw = std.Io.Writer.Allocating.init(arena); - var first = true; - for (list._nodes) |node| { - if (node.is(Element)) |el| { - if (el.getAttributeSafe(String.wrap("href"))) |href| { - if (!first) aw.writer.writeByte('\n') catch continue; - aw.writer.writeAll(href) catch continue; - first = false; - } - } - } - - const content = [_]struct { type: []const u8, text: []const u8 }{.{ .type = "text", .text = aw.written() }}; - try sendResult(server, req.id.?, .{ .content = &content }); + try sendResult(server, req.id.?, result); } else if (std.mem.eql(u8, call_params.name, "evaluate")) { if (call_params.arguments == null) { return sendError(server, req.id.?, -32602, "Missing arguments for evaluate");