feat(16-03): wire userId into MCP server and tool registrations

- Update createMcpServer signature to accept (db, userId)
- MCP auth middleware resolves userId from API key and Bearer token
- Store userId alongside transport in session map
- All 4 tool registration functions accept and pass userId
- Collection summary resource passes userId to all service calls
This commit is contained in:
2026-04-05 10:52:43 +02:00
parent e78002208a
commit d4bf4f5c16
6 changed files with 70 additions and 54 deletions

View File

@@ -17,32 +17,32 @@ import { registerThreadTools, threadToolDefinitions } from "./tools/threads.ts";
type Db = typeof prodDb; type Db = typeof prodDb;
function createMcpServer(db: Db): McpServer { function createMcpServer(db: Db, userId: number): McpServer {
const server = new McpServer({ name: "GearBox", version: "1.0.0" }); const server = new McpServer({ name: "GearBox", version: "1.0.0" });
// Register item tools // Register item tools
const itemHandlers = registerItemTools(db); const itemHandlers = registerItemTools(db, userId);
for (const def of itemToolDefinitions) { for (const def of itemToolDefinitions) {
const handler = itemHandlers[def.name as keyof typeof itemHandlers]; const handler = itemHandlers[def.name as keyof typeof itemHandlers];
server.tool(def.name, def.description, def.inputSchema, handler); server.tool(def.name, def.description, def.inputSchema, handler);
} }
// Register category tools // Register category tools
const categoryHandlers = registerCategoryTools(db); const categoryHandlers = registerCategoryTools(db, userId);
for (const def of categoryToolDefinitions) { for (const def of categoryToolDefinitions) {
const handler = categoryHandlers[def.name as keyof typeof categoryHandlers]; const handler = categoryHandlers[def.name as keyof typeof categoryHandlers];
server.tool(def.name, def.description, def.inputSchema, handler); server.tool(def.name, def.description, def.inputSchema, handler);
} }
// Register thread tools // Register thread tools
const threadHandlers = registerThreadTools(db); const threadHandlers = registerThreadTools(db, userId);
for (const def of threadToolDefinitions) { for (const def of threadToolDefinitions) {
const handler = threadHandlers[def.name as keyof typeof threadHandlers]; const handler = threadHandlers[def.name as keyof typeof threadHandlers];
server.tool(def.name, def.description, def.inputSchema, handler); server.tool(def.name, def.description, def.inputSchema, handler);
} }
// Register setup tools // Register setup tools
const setupHandlers = registerSetupTools(db); const setupHandlers = registerSetupTools(db, userId);
for (const def of setupToolDefinitions) { for (const def of setupToolDefinitions) {
const handler = setupHandlers[def.name as keyof typeof setupHandlers]; const handler = setupHandlers[def.name as keyof typeof setupHandlers];
server.tool(def.name, def.description, def.inputSchema, handler); server.tool(def.name, def.description, def.inputSchema, handler);
@@ -65,7 +65,7 @@ function createMcpServer(db: Db): McpServer {
mimeType: "application/json", mimeType: "application/json",
}, },
async () => { async () => {
const summary = await getCollectionSummary(db); const summary = await getCollectionSummary(db, userId);
return { return {
contents: [ contents: [
{ {
@@ -81,12 +81,15 @@ function createMcpServer(db: Db): McpServer {
return server; return server;
} }
// Store active transports by session ID // Store active transports by session ID (with userId for session reuse)
const transports = new Map<string, WebStandardStreamableHTTPServerTransport>(); const transports = new Map<
string,
{ transport: WebStandardStreamableHTTPServerTransport; userId: number }
>();
export const mcpRoutes = new Hono(); export const mcpRoutes = new Hono();
// Auth middleware for all MCP requests // Auth middleware for all MCP requests — resolves userId
mcpRoutes.use("/*", async (c, next) => { mcpRoutes.use("/*", async (c, next) => {
const db = c.get("db") ?? prodDb; const db = c.get("db") ?? prodDb;
@@ -94,7 +97,9 @@ mcpRoutes.use("/*", async (c, next) => {
const authHeader = c.req.header("Authorization"); const authHeader = c.req.header("Authorization");
if (authHeader?.startsWith("Bearer ")) { if (authHeader?.startsWith("Bearer ")) {
const token = authHeader.slice(7); const token = authHeader.slice(7);
if (await verifyAccessToken(db, token)) { const result = await verifyAccessToken(db, token);
if (result) {
c.set("userId", result.userId);
return next(); return next();
} }
return c.json({ error: "invalid_token" }, 401); return c.json({ error: "invalid_token" }, 401);
@@ -103,8 +108,9 @@ mcpRoutes.use("/*", async (c, next) => {
// Try API key // Try API key
const apiKey = c.req.header("X-API-Key"); const apiKey = c.req.header("X-API-Key");
if (apiKey) { if (apiKey) {
const valid = await verifyApiKey(db, apiKey); const result = await verifyApiKey(db, apiKey);
if (valid) { if (result) {
c.set("userId", result.userId);
return next(); return next();
} }
return c.json({ error: "Invalid API key" }, 401); return c.json({ error: "Invalid API key" }, 401);
@@ -121,16 +127,17 @@ mcpRoutes.use("/*", async (c, next) => {
mcpRoutes.post("/", async (c) => { mcpRoutes.post("/", async (c) => {
const db = c.get("db") ?? prodDb; const db = c.get("db") ?? prodDb;
const userId = c.get("userId") as number;
// Check for existing session // Check for existing session
const sessionId = c.req.header("mcp-session-id"); const sessionId = c.req.header("mcp-session-id");
if (sessionId) { if (sessionId) {
const transport = transports.get(sessionId); const session = transports.get(sessionId);
if (!transport) { if (!session) {
return c.json({ error: "Session not found" }, 404); return c.json({ error: "Session not found" }, 404);
} }
const response = await transport.handleRequest(c.req.raw); const response = await session.transport.handleRequest(c.req.raw);
return response; return response;
} }
@@ -138,19 +145,19 @@ mcpRoutes.post("/", async (c) => {
const transport = new WebStandardStreamableHTTPServerTransport({ const transport = new WebStandardStreamableHTTPServerTransport({
sessionIdGenerator: () => randomUUID(), sessionIdGenerator: () => randomUUID(),
onsessioninitialized: (newSessionId) => { onsessioninitialized: (newSessionId) => {
transports.set(newSessionId, transport); transports.set(newSessionId, { transport, userId });
}, },
}); });
// Clean up on close // Clean up on close
transport.onclose = () => { transport.onclose = () => {
const sid = [...transports.entries()].find( const sid = [...transports.entries()].find(
([_, t]) => t === transport, ([_, s]) => s.transport === transport,
)?.[0]; )?.[0];
if (sid) transports.delete(sid); if (sid) transports.delete(sid);
}; };
const server = createMcpServer(db); const server = createMcpServer(db, userId);
await server.connect(transport); await server.connect(transport);
const response = await transport.handleRequest(c.req.raw); const response = await transport.handleRequest(c.req.raw);
@@ -163,12 +170,12 @@ mcpRoutes.get("/", async (c) => {
return c.json({ error: "Session ID required" }, 400); return c.json({ error: "Session ID required" }, 400);
} }
const transport = transports.get(sessionId); const session = transports.get(sessionId);
if (!transport) { if (!session) {
return c.json({ error: "Session not found" }, 404); return c.json({ error: "Session not found" }, 404);
} }
const response = await transport.handleRequest(c.req.raw); const response = await session.transport.handleRequest(c.req.raw);
return response; return response;
}); });
@@ -178,12 +185,12 @@ mcpRoutes.delete("/", async (c) => {
return c.json({ error: "Session ID required" }, 400); return c.json({ error: "Session ID required" }, 400);
} }
const transport = transports.get(sessionId); const session = transports.get(sessionId);
if (!transport) { if (!session) {
return c.json({ error: "Session not found" }, 404); return c.json({ error: "Session not found" }, 404);
} }
await transport.close(); await session.transport.close();
transports.delete(sessionId); transports.delete(sessionId);
return c.text("", 200); return c.text("", 200);
}); });

View File

@@ -7,12 +7,12 @@ import { getGlobalTotals } from "../../services/totals.service.ts";
type Db = typeof prodDb; type Db = typeof prodDb;
export async function getCollectionSummary(db: Db) { export async function getCollectionSummary(db: Db, userId: number) {
const totals = await getGlobalTotals(db); const totals = await getGlobalTotals(db, userId);
const categories = await getAllCategories(db); const categories = await getAllCategories(db, userId);
const items = await getAllItems(db); const items = await getAllItems(db, userId);
const setups = await getAllSetups(db); const setups = await getAllSetups(db, userId);
const activeThreads = await getAllThreads(db, false); const activeThreads = await getAllThreads(db, userId, false);
// Build items-by-category map // Build items-by-category map
const itemsByCategory: Record<string, number> = {}; const itemsByCategory: Record<string, number> = {};

View File

@@ -37,11 +37,11 @@ export const categoryToolDefinitions = [
}, },
]; ];
export function registerCategoryTools(db: Db) { export function registerCategoryTools(db: Db, userId: number) {
return { return {
list_categories: async (): Promise<ToolResult> => { list_categories: async (): Promise<ToolResult> => {
try { try {
const cats = await getAllCategories(db); const cats = await getAllCategories(db, userId);
return textResult(cats); return textResult(cats);
} catch (err) { } catch (err) {
return errorResult((err as Error).message); return errorResult((err as Error).message);
@@ -53,7 +53,7 @@ export function registerCategoryTools(db: Db) {
icon?: string; icon?: string;
}): Promise<ToolResult> => { }): Promise<ToolResult> => {
try { try {
const cat = await createCategory(db, args); const cat = await createCategory(db, userId, args);
return textResult(cat); return textResult(cat);
} catch (err) { } catch (err) {
return errorResult((err as Error).message); return errorResult((err as Error).message);

View File

@@ -91,11 +91,11 @@ export const itemToolDefinitions = [
}, },
]; ];
export function registerItemTools(db: Db) { export function registerItemTools(db: Db, userId: number) {
return { return {
list_items: async (args: { categoryId?: number }): Promise<ToolResult> => { list_items: async (args: { categoryId?: number }): Promise<ToolResult> => {
try { try {
const items = await getAllItems(db); const items = await getAllItems(db, userId);
if (args.categoryId) { if (args.categoryId) {
return textResult( return textResult(
items.filter((i) => i.categoryId === args.categoryId), items.filter((i) => i.categoryId === args.categoryId),
@@ -109,7 +109,7 @@ export function registerItemTools(db: Db) {
get_item: async (args: { id: number }): Promise<ToolResult> => { get_item: async (args: { id: number }): Promise<ToolResult> => {
try { try {
const item = await getItemById(db, args.id); const item = await getItemById(db, userId, args.id);
if (!item) return errorResult(`Item ${args.id} not found`); if (!item) return errorResult(`Item ${args.id} not found`);
return textResult(item); return textResult(item);
} catch (err) { } catch (err) {
@@ -128,7 +128,7 @@ export function registerItemTools(db: Db) {
imageSourceUrl?: string; imageSourceUrl?: string;
}): Promise<ToolResult> => { }): Promise<ToolResult> => {
try { try {
const item = await createItem(db, args); const item = await createItem(db, userId, args);
return textResult(item); return textResult(item);
} catch (err) { } catch (err) {
return errorResult((err as Error).message); return errorResult((err as Error).message);
@@ -148,7 +148,7 @@ export function registerItemTools(db: Db) {
}): Promise<ToolResult> => { }): Promise<ToolResult> => {
try { try {
const { id, ...data } = args; const { id, ...data } = args;
const item = await updateItem(db, id, data); const item = await updateItem(db, userId, id, data);
if (!item) return errorResult(`Item ${id} not found`); if (!item) return errorResult(`Item ${id} not found`);
return textResult(item); return textResult(item);
} catch (err) { } catch (err) {
@@ -158,7 +158,7 @@ export function registerItemTools(db: Db) {
delete_item: async (args: { id: number }): Promise<ToolResult> => { delete_item: async (args: { id: number }): Promise<ToolResult> => {
try { try {
const item = await deleteItem(db, args.id); const item = await deleteItem(db, userId, args.id);
if (!item) return errorResult(`Item ${args.id} not found`); if (!item) return errorResult(`Item ${args.id} not found`);
return textResult({ deleted: true, item }); return textResult({ deleted: true, item });
} catch (err) { } catch (err) {

View File

@@ -60,11 +60,11 @@ export const setupToolDefinitions = [
}, },
]; ];
export function registerSetupTools(db: Db) { export function registerSetupTools(db: Db, userId: number) {
return { return {
list_setups: async (): Promise<ToolResult> => { list_setups: async (): Promise<ToolResult> => {
try { try {
const setupList = await getAllSetups(db); const setupList = await getAllSetups(db, userId);
return textResult(setupList); return textResult(setupList);
} catch (err) { } catch (err) {
return errorResult((err as Error).message); return errorResult((err as Error).message);
@@ -73,7 +73,7 @@ export function registerSetupTools(db: Db) {
get_setup: async (args: { id: number }): Promise<ToolResult> => { get_setup: async (args: { id: number }): Promise<ToolResult> => {
try { try {
const setup = await getSetupWithItems(db, args.id); const setup = await getSetupWithItems(db, userId, args.id);
if (!setup) return errorResult(`Setup ${args.id} not found`); if (!setup) return errorResult(`Setup ${args.id} not found`);
return textResult(setup); return textResult(setup);
} catch (err) { } catch (err) {
@@ -83,7 +83,7 @@ export function registerSetupTools(db: Db) {
create_setup: async (args: { name: string }): Promise<ToolResult> => { create_setup: async (args: { name: string }): Promise<ToolResult> => {
try { try {
const setup = await createSetup(db, args); const setup = await createSetup(db, userId, args);
return textResult(setup); return textResult(setup);
} catch (err) { } catch (err) {
return errorResult((err as Error).message); return errorResult((err as Error).message);
@@ -98,14 +98,14 @@ export function registerSetupTools(db: Db) {
try { try {
let setup = null; let setup = null;
if (args.name) { if (args.name) {
setup = await updateSetup(db, args.id, { name: args.name }); setup = await updateSetup(db, userId, args.id, { name: args.name });
if (!setup) return errorResult(`Setup ${args.id} not found`); if (!setup) return errorResult(`Setup ${args.id} not found`);
} }
if (args.itemIds) { if (args.itemIds) {
await syncSetupItems(db, args.id, args.itemIds); await syncSetupItems(db, userId, args.id, args.itemIds);
} }
// Return updated setup with items // Return updated setup with items
const result = await getSetupWithItems(db, args.id); const result = await getSetupWithItems(db, userId, args.id);
if (!result) return errorResult(`Setup ${args.id} not found`); if (!result) return errorResult(`Setup ${args.id} not found`);
return textResult(result); return textResult(result);
} catch (err) { } catch (err) {

View File

@@ -113,13 +113,17 @@ export const threadToolDefinitions = [
}, },
]; ];
export function registerThreadTools(db: Db) { export function registerThreadTools(db: Db, userId: number) {
return { return {
list_threads: async (args: { list_threads: async (args: {
includeResolved?: boolean; includeResolved?: boolean;
}): Promise<ToolResult> => { }): Promise<ToolResult> => {
try { try {
const threadList = await getAllThreads(db, args.includeResolved ?? false); const threadList = await getAllThreads(
db,
userId,
args.includeResolved ?? false,
);
return textResult(threadList); return textResult(threadList);
} catch (err) { } catch (err) {
return errorResult((err as Error).message); return errorResult((err as Error).message);
@@ -128,7 +132,7 @@ export function registerThreadTools(db: Db) {
get_thread: async (args: { id: number }): Promise<ToolResult> => { get_thread: async (args: { id: number }): Promise<ToolResult> => {
try { try {
const thread = await getThreadWithCandidates(db, args.id); const thread = await getThreadWithCandidates(db, userId, args.id);
if (!thread) return errorResult(`Thread ${args.id} not found`); if (!thread) return errorResult(`Thread ${args.id} not found`);
return textResult(thread); return textResult(thread);
} catch (err) { } catch (err) {
@@ -141,7 +145,7 @@ export function registerThreadTools(db: Db) {
categoryId: number; categoryId: number;
}): Promise<ToolResult> => { }): Promise<ToolResult> => {
try { try {
const thread = await createThread(db, args); const thread = await createThread(db, userId, args);
return textResult(thread); return textResult(thread);
} catch (err) { } catch (err) {
return errorResult((err as Error).message); return errorResult((err as Error).message);
@@ -153,7 +157,12 @@ export function registerThreadTools(db: Db) {
candidateId: number; candidateId: number;
}): Promise<ToolResult> => { }): Promise<ToolResult> => {
try { try {
const result = await resolveThread(db, args.threadId, args.candidateId); const result = await resolveThread(
db,
userId,
args.threadId,
args.candidateId,
);
if (!result.success) { if (!result.success) {
return errorResult(result.error ?? "Failed to resolve thread"); return errorResult(result.error ?? "Failed to resolve thread");
} }
@@ -177,7 +186,7 @@ export function registerThreadTools(db: Db) {
}): Promise<ToolResult> => { }): Promise<ToolResult> => {
try { try {
const { threadId, ...data } = args; const { threadId, ...data } = args;
const candidate = await createCandidate(db, threadId, data); const candidate = await createCandidate(db, userId, threadId, data);
return textResult(candidate); return textResult(candidate);
} catch (err) { } catch (err) {
return errorResult((err as Error).message); return errorResult((err as Error).message);
@@ -200,7 +209,7 @@ export function registerThreadTools(db: Db) {
}): Promise<ToolResult> => { }): Promise<ToolResult> => {
try { try {
const { id, ...data } = args; const { id, ...data } = args;
const candidate = await updateCandidate(db, id, data); const candidate = await updateCandidate(db, userId, id, data);
if (!candidate) return errorResult(`Candidate ${id} not found`); if (!candidate) return errorResult(`Candidate ${id} not found`);
return textResult(candidate); return textResult(candidate);
} catch (err) { } catch (err) {
@@ -210,7 +219,7 @@ export function registerThreadTools(db: Db) {
remove_candidate: async (args: { id: number }): Promise<ToolResult> => { remove_candidate: async (args: { id: number }): Promise<ToolResult> => {
try { try {
const candidate = await deleteCandidate(db, args.id); const candidate = await deleteCandidate(db, userId, args.id);
if (!candidate) return errorResult(`Candidate ${args.id} not found`); if (!candidate) return errorResult(`Candidate ${args.id} not found`);
return textResult({ deleted: true, candidate }); return textResult({ deleted: true, candidate });
} catch (err) { } catch (err) {