diff --git a/src/main.ts b/src/main.ts index 2f34804..0a444bc 100644 --- a/src/main.ts +++ b/src/main.ts @@ -1,12 +1,13 @@ -import { StdioServerTransport } from "@modelcontextprotocol/sdk/server/stdio.js"; -import pino from "pino"; -import express from "express"; import { StreamableHTTPServerTransport } from "@modelcontextprotocol/sdk/server/streamableHttp.js"; import { isInitializeRequest } from "@modelcontextprotocol/sdk/types.js"; -import crypto from "crypto"; +import express from "express"; +import pino from "pino"; +import { z } from "zod"; import { outlineMcpFactory } from "./outline"; +const sessionIdSchema = z.string().uuid(); + async function main() { const logger = pino({ level: "debug", @@ -24,12 +25,40 @@ async function main() { }, "Received MCP request" ); - const sessionId = req.headers["mcp-session-id"] as string | undefined; + const sessionIdHeader = req.headers["mcp-session-id"] as string | undefined; let transport: StreamableHTTPServerTransport; - if (sessionId && transports[sessionId]) { - transport = transports[sessionId]; - } else if (!sessionId && isInitializeRequest(req.body)) { + if (sessionIdHeader) { + const safeSessionId = sessionIdSchema.safeParse(sessionIdHeader); + if (!safeSessionId.success) { + logger.error("Invalid session ID format"); + res.status(400).json({ + jsonrpc: "2.0", + error: { + code: -32000, + message: "Invalid session ID format", + }, + id: null, + }); + return; + } + + const sessionId = safeSessionId.data; + if (transports[sessionId]) { + transport = transports[sessionId]; + } else { + logger.error("Session not found"); + res.status(400).json({ + jsonrpc: "2.0", + error: { + code: -32000, + message: "Session not found", + }, + id: null, + }); + return; + } + } else if (isInitializeRequest(req.body)) { transport = new StreamableHTTPServerTransport({ sessionIdGenerator: () => crypto.randomUUID(), onsessioninitialized: (sessionId) => { @@ -42,7 +71,9 @@ async function main() { delete transports[transport.sessionId]; } }; - const outlineMcpServer = outlineMcpFactory(logger); + const outlineMcpServer = outlineMcpFactory( + logger.child({ sessionId: transport.sessionId }) + ); await outlineMcpServer.connect(transport); } else {