diff --git a/drizzle-pg/0002_wakeful_vermin.sql b/drizzle-pg/0002_wakeful_vermin.sql index 0a1358c..178da80 100644 --- a/drizzle-pg/0002_wakeful_vermin.sql +++ b/drizzle-pg/0002_wakeful_vermin.sql @@ -15,6 +15,9 @@ ALTER TABLE "items" ADD COLUMN "global_item_id" integer;--> statement-breakpoint ALTER TABLE "items" ADD COLUMN "purchase_price_cents" integer;--> statement-breakpoint ALTER TABLE "items" ADD COLUMN "brand" text;--> statement-breakpoint ALTER TABLE "thread_candidates" ADD COLUMN "global_item_id" integer;--> statement-breakpoint +ALTER TABLE "oauth_codes" ADD COLUMN "user_id" integer NOT NULL DEFAULT 0;--> statement-breakpoint +ALTER TABLE "oauth_codes" ALTER COLUMN "user_id" DROP DEFAULT;--> statement-breakpoint +ALTER TABLE "oauth_codes" ADD CONSTRAINT "oauth_codes_user_id_users_id_fk" FOREIGN KEY ("user_id") REFERENCES "public"."users"("id") ON DELETE no action ON UPDATE no action;--> statement-breakpoint UPDATE "items" SET "global_item_id" = ( SELECT "global_item_id" FROM "item_global_links" WHERE "item_global_links"."item_id" = "items"."id" diff --git a/src/db/schema.ts b/src/db/schema.ts index cf06baf..ab0ae74 100644 --- a/src/db/schema.ts +++ b/src/db/schema.ts @@ -57,6 +57,7 @@ export const items = pgTable("items", { quantity: integer("quantity").notNull().default(1), globalItemId: integer("global_item_id").references(() => globalItems.id), purchasePriceCents: integer("purchase_price_cents"), + brand: text("brand"), createdAt: timestamp("created_at").defaultNow().notNull(), updatedAt: timestamp("updated_at").defaultNow().notNull(), }); @@ -210,6 +211,9 @@ export const oauthCodes = pgTable("oauth_codes", { id: serial("id").primaryKey(), code: text("code").notNull().unique(), clientId: text("client_id").notNull(), + userId: integer("user_id") + .notNull() + .references(() => users.id), codeChallenge: text("code_challenge").notNull(), codeChallengeMethod: text("code_challenge_method").notNull().default("S256"), redirectUri: text("redirect_uri").notNull(), diff --git a/src/server/routes/oauth.ts b/src/server/routes/oauth.ts index 35cef4a..7186530 100644 --- a/src/server/routes/oauth.ts +++ b/src/server/routes/oauth.ts @@ -202,9 +202,14 @@ oauthRoutes.post("/authorize", async (c) => { return c.json({ error: "redirect_uri not allowed" }, 400); } + // Get or create user from OIDC session + const { getOrCreateUser } = await import("../services/auth.service"); + const user = await getOrCreateUser(db, auth.sub); + const { code } = await createAuthorizationCode( db, clientId, + user.id, codeChallenge, codeChallengeMethod, redirectUri, diff --git a/src/server/services/oauth.service.ts b/src/server/services/oauth.service.ts index 74a36a1..0e96a19 100644 --- a/src/server/services/oauth.service.ts +++ b/src/server/services/oauth.service.ts @@ -36,6 +36,7 @@ export async function getClient(db: Db = prodDb, clientId: string) { export async function createAuthorizationCode( db: Db = prodDb, clientId: string, + userId: number, codeChallenge: string, codeChallengeMethod: string, redirectUri: string, @@ -46,6 +47,7 @@ export async function createAuthorizationCode( await db.insert(oauthCodes).values({ code, clientId, + userId, codeChallenge, codeChallengeMethod, redirectUri, @@ -61,7 +63,6 @@ export async function exchangeCode( codeVerifier: string, clientId: string, redirectUri: string, - userId: number, ): Promise<{ accessToken: string; refreshToken: string; @@ -88,7 +89,7 @@ export async function exchangeCode( // Mark code as used await db.update(oauthCodes).set({ used: 1 }).where(eq(oauthCodes.code, code)); - return generateTokens(db, clientId, userId); + return generateTokens(db, clientId, record.userId); } // ── Token Management ───────────────────────────────────────────────── @@ -144,7 +145,6 @@ export async function refreshAccessToken( db: Db = prodDb, refreshToken: string, clientId: string, - userId: number, ): Promise<{ accessToken: string; refreshToken: string; @@ -168,7 +168,7 @@ export async function refreshAccessToken( // Delete old token pair await db.delete(oauthTokens).where(eq(oauthTokens.id, record.id)); - return generateTokens(db, clientId, userId); + return generateTokens(db, clientId, record.userId); } // ── Cleanup ────────────────────────────────────────────────────────── diff --git a/tests/mcp/tools.test.ts b/tests/mcp/tools.test.ts index b34c62d..4d21756 100644 --- a/tests/mcp/tools.test.ts +++ b/tests/mcp/tools.test.ts @@ -213,7 +213,7 @@ describe("MCP Collection Summary Resource", () => { test("returns overview with correct counts", async () => { const { db, userId } = await createTestDb(); - const summary = getCollectionSummary(db, userId); + const summary = await getCollectionSummary(db, userId); expect(summary.overview).toBeDefined(); expect(summary.overview.totalItems).toBe(0); expect(summary.overview.categoryCount).toBe(1); // Uncategorized @@ -242,7 +242,7 @@ describe("MCP Collection Summary Resource", () => { categoryId: 1, }); - const summary = getCollectionSummary(db, userId); + const summary = await getCollectionSummary(db, userId); expect(summary.overview.totalItems).toBe(2); expect(summary.overview.totalWeightGrams).toBe(2000); expect(summary.overview.activeThreadCount).toBe(1); @@ -255,7 +255,7 @@ describe("MCP Collection Summary Resource", () => { describe("MCP Cross-User Isolation", () => { test("user 2 cannot see user 1's items via MCP tools", async () => { const { db, userId } = await createTestDb(); - const userId2 = createSecondTestUser(db); + const userId2 = await createSecondTestUser(db); const user1Tools = registerItemTools(db, userId); const user2Tools = registerItemTools(db, userId2); @@ -286,7 +286,7 @@ describe("MCP Cross-User Isolation", () => { test("user 2 cannot access user 1's item by ID", async () => { const { db, userId } = await createTestDb(); - const userId2 = createSecondTestUser(db); + const userId2 = await createSecondTestUser(db); const user1Tools = registerItemTools(db, userId); const user2Tools = registerItemTools(db, userId2); @@ -306,7 +306,7 @@ describe("MCP Cross-User Isolation", () => { test("user 2 cannot see user 1's threads via MCP tools", async () => { const { db, userId } = await createTestDb(); - const userId2 = createSecondTestUser(db); + const userId2 = await createSecondTestUser(db); const user1Tools = registerThreadTools(db, userId); const user2Tools = registerThreadTools(db, userId2); @@ -330,7 +330,7 @@ describe("MCP Cross-User Isolation", () => { test("collection summary is scoped to user", async () => { const { db, userId } = await createTestDb(); - const userId2 = createSecondTestUser(db); + const userId2 = await createSecondTestUser(db); const user1Tools = registerItemTools(db, userId); await user1Tools.create_item({ @@ -339,8 +339,8 @@ describe("MCP Cross-User Isolation", () => { weightGrams: 500, }); - const user1Summary = getCollectionSummary(db, userId); - const user2Summary = getCollectionSummary(db, userId2); + const user1Summary = await getCollectionSummary(db, userId); + const user2Summary = await getCollectionSummary(db, userId2); expect(user1Summary.overview.totalItems).toBe(1); expect(user2Summary.overview.totalItems).toBe(0); diff --git a/tests/middleware/auth.test.ts b/tests/middleware/auth.test.ts index 6426fd7..01df328 100644 --- a/tests/middleware/auth.test.ts +++ b/tests/middleware/auth.test.ts @@ -21,10 +21,11 @@ mock.module("../../src/server/services/oauth.service", () => ({ // Import middleware AFTER mocks are set up const { requireAuth } = await import("../../src/server/middleware/auth"); -let db: Awaited>; +let db: any; +let userId: number; beforeEach(async () => { - db = await createTestDb(); + ({ db, userId } = await createTestDb()); mockGetAuth.mockReset(); mockGetAuth.mockReturnValue(null); mockVerifyAccessToken.mockReset(); @@ -64,7 +65,7 @@ describe("auth middleware", () => { test("allows POST with valid API key", async () => { const app = createApp(); - const key = await createApiKey(db, "test"); + const key = await createApiKey(db, userId, "test"); const res = await app.request("/items", { method: "POST", headers: { "X-API-Key": key.rawKey }, @@ -102,7 +103,7 @@ describe("auth middleware", () => { }); expect(res.status).toBe(401); const body = await res.json(); - expect(body.error).toBe("invalid_token"); + expect(body.error).toBe("Invalid or expired token"); }); test("allows POST with valid OIDC session", async () => { @@ -114,7 +115,7 @@ describe("auth middleware", () => { test("API key takes priority over OIDC session", async () => { const app = createApp(); - const key = await createApiKey(db, "test"); + const key = await createApiKey(db, userId, "test"); mockGetAuth.mockReturnValue({ sub: "user-123" }); const res = await app.request("/items", { method: "POST",