diff --git a/src/server/mcp/index.ts b/src/server/mcp/index.ts index a039a98..997244c 100644 --- a/src/server/mcp/index.ts +++ b/src/server/mcp/index.ts @@ -17,32 +17,32 @@ import { registerThreadTools, threadToolDefinitions } from "./tools/threads.ts"; type Db = typeof prodDb; -function createMcpServer(db: Db): McpServer { +function createMcpServer(db: Db, userId: number): McpServer { const server = new McpServer({ name: "GearBox", version: "1.0.0" }); // Register item tools - const itemHandlers = registerItemTools(db); + const itemHandlers = registerItemTools(db, userId); for (const def of itemToolDefinitions) { const handler = itemHandlers[def.name as keyof typeof itemHandlers]; server.tool(def.name, def.description, def.inputSchema, handler); } // Register category tools - const categoryHandlers = registerCategoryTools(db); + const categoryHandlers = registerCategoryTools(db, userId); for (const def of categoryToolDefinitions) { const handler = categoryHandlers[def.name as keyof typeof categoryHandlers]; server.tool(def.name, def.description, def.inputSchema, handler); } // Register thread tools - const threadHandlers = registerThreadTools(db); + const threadHandlers = registerThreadTools(db, userId); for (const def of threadToolDefinitions) { const handler = threadHandlers[def.name as keyof typeof threadHandlers]; server.tool(def.name, def.description, def.inputSchema, handler); } // Register setup tools - const setupHandlers = registerSetupTools(db); + const setupHandlers = registerSetupTools(db, userId); for (const def of setupToolDefinitions) { const handler = setupHandlers[def.name as keyof typeof setupHandlers]; server.tool(def.name, def.description, def.inputSchema, handler); @@ -65,7 +65,7 @@ function createMcpServer(db: Db): McpServer { mimeType: "application/json", }, async () => { - const summary = await getCollectionSummary(db); + const summary = await getCollectionSummary(db, userId); return { contents: [ { @@ -81,12 +81,15 @@ function createMcpServer(db: Db): McpServer { return server; } -// Store active transports by session ID -const transports = new Map(); +// Store active transports by session ID (with userId for session reuse) +const transports = new Map< + string, + { transport: WebStandardStreamableHTTPServerTransport; userId: number } +>(); export const mcpRoutes = new Hono(); -// Auth middleware for all MCP requests +// Auth middleware for all MCP requests — resolves userId mcpRoutes.use("/*", async (c, next) => { const db = c.get("db") ?? prodDb; @@ -94,7 +97,9 @@ mcpRoutes.use("/*", async (c, next) => { const authHeader = c.req.header("Authorization"); if (authHeader?.startsWith("Bearer ")) { const token = authHeader.slice(7); - if (await verifyAccessToken(db, token)) { + const result = await verifyAccessToken(db, token); + if (result) { + c.set("userId", result.userId); return next(); } return c.json({ error: "invalid_token" }, 401); @@ -103,8 +108,9 @@ mcpRoutes.use("/*", async (c, next) => { // Try API key const apiKey = c.req.header("X-API-Key"); if (apiKey) { - const valid = await verifyApiKey(db, apiKey); - if (valid) { + const result = await verifyApiKey(db, apiKey); + if (result) { + c.set("userId", result.userId); return next(); } return c.json({ error: "Invalid API key" }, 401); @@ -121,16 +127,17 @@ mcpRoutes.use("/*", async (c, next) => { mcpRoutes.post("/", async (c) => { const db = c.get("db") ?? prodDb; + const userId = c.get("userId") as number; // Check for existing session const sessionId = c.req.header("mcp-session-id"); if (sessionId) { - const transport = transports.get(sessionId); - if (!transport) { + const session = transports.get(sessionId); + if (!session) { return c.json({ error: "Session not found" }, 404); } - const response = await transport.handleRequest(c.req.raw); + const response = await session.transport.handleRequest(c.req.raw); return response; } @@ -138,19 +145,19 @@ mcpRoutes.post("/", async (c) => { const transport = new WebStandardStreamableHTTPServerTransport({ sessionIdGenerator: () => randomUUID(), onsessioninitialized: (newSessionId) => { - transports.set(newSessionId, transport); + transports.set(newSessionId, { transport, userId }); }, }); // Clean up on close transport.onclose = () => { const sid = [...transports.entries()].find( - ([_, t]) => t === transport, + ([_, s]) => s.transport === transport, )?.[0]; if (sid) transports.delete(sid); }; - const server = createMcpServer(db); + const server = createMcpServer(db, userId); await server.connect(transport); const response = await transport.handleRequest(c.req.raw); @@ -163,12 +170,12 @@ mcpRoutes.get("/", async (c) => { return c.json({ error: "Session ID required" }, 400); } - const transport = transports.get(sessionId); - if (!transport) { + const session = transports.get(sessionId); + if (!session) { return c.json({ error: "Session not found" }, 404); } - const response = await transport.handleRequest(c.req.raw); + const response = await session.transport.handleRequest(c.req.raw); return response; }); @@ -178,12 +185,12 @@ mcpRoutes.delete("/", async (c) => { return c.json({ error: "Session ID required" }, 400); } - const transport = transports.get(sessionId); - if (!transport) { + const session = transports.get(sessionId); + if (!session) { return c.json({ error: "Session not found" }, 404); } - await transport.close(); + await session.transport.close(); transports.delete(sessionId); return c.text("", 200); }); diff --git a/src/server/mcp/resources/collection.ts b/src/server/mcp/resources/collection.ts index e5860b4..f441c5f 100644 --- a/src/server/mcp/resources/collection.ts +++ b/src/server/mcp/resources/collection.ts @@ -7,12 +7,12 @@ import { getGlobalTotals } from "../../services/totals.service.ts"; type Db = typeof prodDb; -export async function getCollectionSummary(db: Db) { - const totals = await getGlobalTotals(db); - const categories = await getAllCategories(db); - const items = await getAllItems(db); - const setups = await getAllSetups(db); - const activeThreads = await getAllThreads(db, false); +export async function getCollectionSummary(db: Db, userId: number) { + const totals = await getGlobalTotals(db, userId); + const categories = await getAllCategories(db, userId); + const items = await getAllItems(db, userId); + const setups = await getAllSetups(db, userId); + const activeThreads = await getAllThreads(db, userId, false); // Build items-by-category map const itemsByCategory: Record = {}; diff --git a/src/server/mcp/tools/categories.ts b/src/server/mcp/tools/categories.ts index b2e7394..2a28f6a 100644 --- a/src/server/mcp/tools/categories.ts +++ b/src/server/mcp/tools/categories.ts @@ -37,11 +37,11 @@ export const categoryToolDefinitions = [ }, ]; -export function registerCategoryTools(db: Db) { +export function registerCategoryTools(db: Db, userId: number) { return { list_categories: async (): Promise => { try { - const cats = await getAllCategories(db); + const cats = await getAllCategories(db, userId); return textResult(cats); } catch (err) { return errorResult((err as Error).message); @@ -53,7 +53,7 @@ export function registerCategoryTools(db: Db) { icon?: string; }): Promise => { try { - const cat = await createCategory(db, args); + const cat = await createCategory(db, userId, args); return textResult(cat); } catch (err) { return errorResult((err as Error).message); diff --git a/src/server/mcp/tools/items.ts b/src/server/mcp/tools/items.ts index c65b5e8..95e6aa3 100644 --- a/src/server/mcp/tools/items.ts +++ b/src/server/mcp/tools/items.ts @@ -91,11 +91,11 @@ export const itemToolDefinitions = [ }, ]; -export function registerItemTools(db: Db) { +export function registerItemTools(db: Db, userId: number) { return { list_items: async (args: { categoryId?: number }): Promise => { try { - const items = await getAllItems(db); + const items = await getAllItems(db, userId); if (args.categoryId) { return textResult( items.filter((i) => i.categoryId === args.categoryId), @@ -109,7 +109,7 @@ export function registerItemTools(db: Db) { get_item: async (args: { id: number }): Promise => { try { - const item = await getItemById(db, args.id); + const item = await getItemById(db, userId, args.id); if (!item) return errorResult(`Item ${args.id} not found`); return textResult(item); } catch (err) { @@ -128,7 +128,7 @@ export function registerItemTools(db: Db) { imageSourceUrl?: string; }): Promise => { try { - const item = await createItem(db, args); + const item = await createItem(db, userId, args); return textResult(item); } catch (err) { return errorResult((err as Error).message); @@ -148,7 +148,7 @@ export function registerItemTools(db: Db) { }): Promise => { try { const { id, ...data } = args; - const item = await updateItem(db, id, data); + const item = await updateItem(db, userId, id, data); if (!item) return errorResult(`Item ${id} not found`); return textResult(item); } catch (err) { @@ -158,7 +158,7 @@ export function registerItemTools(db: Db) { delete_item: async (args: { id: number }): Promise => { try { - const item = await deleteItem(db, args.id); + const item = await deleteItem(db, userId, args.id); if (!item) return errorResult(`Item ${args.id} not found`); return textResult({ deleted: true, item }); } catch (err) { diff --git a/src/server/mcp/tools/setups.ts b/src/server/mcp/tools/setups.ts index 1b6581b..b7e42b2 100644 --- a/src/server/mcp/tools/setups.ts +++ b/src/server/mcp/tools/setups.ts @@ -60,11 +60,11 @@ export const setupToolDefinitions = [ }, ]; -export function registerSetupTools(db: Db) { +export function registerSetupTools(db: Db, userId: number) { return { list_setups: async (): Promise => { try { - const setupList = await getAllSetups(db); + const setupList = await getAllSetups(db, userId); return textResult(setupList); } catch (err) { return errorResult((err as Error).message); @@ -73,7 +73,7 @@ export function registerSetupTools(db: Db) { get_setup: async (args: { id: number }): Promise => { try { - const setup = await getSetupWithItems(db, args.id); + const setup = await getSetupWithItems(db, userId, args.id); if (!setup) return errorResult(`Setup ${args.id} not found`); return textResult(setup); } catch (err) { @@ -83,7 +83,7 @@ export function registerSetupTools(db: Db) { create_setup: async (args: { name: string }): Promise => { try { - const setup = await createSetup(db, args); + const setup = await createSetup(db, userId, args); return textResult(setup); } catch (err) { return errorResult((err as Error).message); @@ -98,14 +98,14 @@ export function registerSetupTools(db: Db) { try { let setup = null; if (args.name) { - setup = await updateSetup(db, args.id, { name: args.name }); + setup = await updateSetup(db, userId, args.id, { name: args.name }); if (!setup) return errorResult(`Setup ${args.id} not found`); } if (args.itemIds) { - await syncSetupItems(db, args.id, args.itemIds); + await syncSetupItems(db, userId, args.id, args.itemIds); } // Return updated setup with items - const result = await getSetupWithItems(db, args.id); + const result = await getSetupWithItems(db, userId, args.id); if (!result) return errorResult(`Setup ${args.id} not found`); return textResult(result); } catch (err) { diff --git a/src/server/mcp/tools/threads.ts b/src/server/mcp/tools/threads.ts index ec02984..7392f62 100644 --- a/src/server/mcp/tools/threads.ts +++ b/src/server/mcp/tools/threads.ts @@ -113,13 +113,17 @@ export const threadToolDefinitions = [ }, ]; -export function registerThreadTools(db: Db) { +export function registerThreadTools(db: Db, userId: number) { return { list_threads: async (args: { includeResolved?: boolean; }): Promise => { try { - const threadList = await getAllThreads(db, args.includeResolved ?? false); + const threadList = await getAllThreads( + db, + userId, + args.includeResolved ?? false, + ); return textResult(threadList); } catch (err) { return errorResult((err as Error).message); @@ -128,7 +132,7 @@ export function registerThreadTools(db: Db) { get_thread: async (args: { id: number }): Promise => { try { - const thread = await getThreadWithCandidates(db, args.id); + const thread = await getThreadWithCandidates(db, userId, args.id); if (!thread) return errorResult(`Thread ${args.id} not found`); return textResult(thread); } catch (err) { @@ -141,7 +145,7 @@ export function registerThreadTools(db: Db) { categoryId: number; }): Promise => { try { - const thread = await createThread(db, args); + const thread = await createThread(db, userId, args); return textResult(thread); } catch (err) { return errorResult((err as Error).message); @@ -153,7 +157,12 @@ export function registerThreadTools(db: Db) { candidateId: number; }): Promise => { try { - const result = await resolveThread(db, args.threadId, args.candidateId); + const result = await resolveThread( + db, + userId, + args.threadId, + args.candidateId, + ); if (!result.success) { return errorResult(result.error ?? "Failed to resolve thread"); } @@ -177,7 +186,7 @@ export function registerThreadTools(db: Db) { }): Promise => { try { const { threadId, ...data } = args; - const candidate = await createCandidate(db, threadId, data); + const candidate = await createCandidate(db, userId, threadId, data); return textResult(candidate); } catch (err) { return errorResult((err as Error).message); @@ -200,7 +209,7 @@ export function registerThreadTools(db: Db) { }): Promise => { try { const { id, ...data } = args; - const candidate = await updateCandidate(db, id, data); + const candidate = await updateCandidate(db, userId, id, data); if (!candidate) return errorResult(`Candidate ${id} not found`); return textResult(candidate); } catch (err) { @@ -210,7 +219,7 @@ export function registerThreadTools(db: Db) { remove_candidate: async (args: { id: number }): Promise => { try { - const candidate = await deleteCandidate(db, args.id); + const candidate = await deleteCandidate(db, userId, args.id); if (!candidate) return errorResult(`Candidate ${args.id} not found`); return textResult({ deleted: true, candidate }); } catch (err) {