diff --git a/src/server/index.ts b/src/server/index.ts index c601229..f56e284 100644 --- a/src/server/index.ts +++ b/src/server/index.ts @@ -54,13 +54,13 @@ app.use("/api/*", async (c, next) => { return next(); }); -// Auth middleware for write operations (POST/PUT/PATCH/DELETE) on non-auth routes +// Auth middleware for all data routes (userId must be available for per-user scoping) app.use("/api/*", async (c, next) => { // Skip auth routes — they handle their own auth if (c.req.path.startsWith("/api/auth")) return next(); - // Skip GET requests — read is public - if (c.req.method === "GET") return next(); - // All other methods require auth + // Skip health check + if (c.req.path === "/api/health") return next(); + // All methods require auth for userId resolution return requireAuth(c, next); }); diff --git a/src/server/middleware/auth.ts b/src/server/middleware/auth.ts index b9e36de..ca1c0e1 100644 --- a/src/server/middleware/auth.ts +++ b/src/server/middleware/auth.ts @@ -1,37 +1,46 @@ import type { Context, Next } from "hono"; -import { getCookie } from "hono/cookie"; -import { - getSession, - getUserCount, - refreshSession, - verifyApiKey, -} from "../services/auth.service"; +import { getOrCreateUser, verifyApiKey } from "../services/auth.service"; +import { getOrCreateUncategorized } from "../services/category.service"; +import { verifyAccessToken } from "../services/oauth.service"; export async function requireAuth(c: Context, next: Next) { const db = c.get("db"); - // Check if any users exist at all - if (getUserCount(db) === 0) { - return c.json({ error: "setup_required" }, 403); - } - // Check API key first const apiKey = c.req.header("X-API-Key"); if (apiKey) { - const valid = await verifyApiKey(db, apiKey); - if (valid) return next(); + const result = await verifyApiKey(db, apiKey); + if (result) { + c.set("userId", result.userId); + return next(); + } return c.json({ error: "Invalid API key" }, 401); } - // Check session cookie - const sessionId = getCookie(c, "gearbox_session"); - if (sessionId) { - const session = getSession(db, sessionId); - if (session) { - // Refresh session expiry on use - refreshSession(db, sessionId); + // Check OAuth Bearer token + const authHeader = c.req.header("Authorization"); + if (authHeader?.startsWith("Bearer ")) { + const token = authHeader.slice(7); + const result = await verifyAccessToken(db, token); + if (result) { + c.set("userId", result.userId); return next(); } + return c.json({ error: "Invalid or expired token" }, 401); + } + + // Check OIDC session (browser users via Logto) + try { + const { getAuth } = await import("@hono/oidc-auth"); + const auth = await getAuth(c); + if (auth?.sub) { + const user = await getOrCreateUser(db, auth.sub); + await getOrCreateUncategorized(db, user.id); + c.set("userId", user.id); + return next(); + } + } catch { + // OIDC not configured or session invalid — fall through } return c.json({ error: "Authentication required" }, 401); diff --git a/src/server/services/auth.service.ts b/src/server/services/auth.service.ts index 1ed003e..5a800e6 100644 --- a/src/server/services/auth.service.ts +++ b/src/server/services/auth.service.ts @@ -1,123 +1,42 @@ import { randomBytes } from "node:crypto"; -import { count, eq } from "drizzle-orm"; +import { and, eq } from "drizzle-orm"; import { db as prodDb } from "../../db/index.ts"; -import { apiKeys, sessions, users } from "../../db/schema.ts"; +import { apiKeys, users } from "../../db/schema.ts"; type Db = typeof prodDb; // ── User Management ────────────────────────────────────────────────── -export async function createUser( - db: Db = prodDb, - username: string, - password: string, -) { - const passwordHash = await Bun.password.hash(password); - return db.insert(users).values({ username, passwordHash }).returning().get(); -} - -export async function verifyPassword( - db: Db = prodDb, - username: string, - password: string, -) { - const user = db - .select() - .from(users) - .where(eq(users.username, username)) - .get(); - - if (!user) return null; - - const valid = await Bun.password.verify(password, user.passwordHash); - return valid ? user : null; -} - -export function getUserCount(db: Db = prodDb): number { - const result = db.select({ value: count() }).from(users).get(); - return result?.value ?? 0; -} - -export async function changePassword( - db: Db = prodDb, - username: string, - currentPassword: string, - newPassword: string, -): Promise { - const user = await verifyPassword(db, username, currentPassword); - if (!user) return false; - - const newHash = await Bun.password.hash(newPassword); - db.update(users) - .set({ passwordHash: newHash }) - .where(eq(users.id, user.id)) - .run(); - - return true; -} - -// ── Session Management ─────────────────────────────────────────────── - -export function createSession( - db: Db = prodDb, - userId: number, - expiryDays = 30, -) { - const id = randomBytes(32).toString("hex"); - const expiresAt = new Date(Date.now() + expiryDays * 24 * 60 * 60 * 1000); - - return db - .insert(sessions) - .values({ id, userId, expiresAt }) - .returning() - .get(); -} - -export function getSession(db: Db = prodDb, sessionId: string) { - const session = db - .select() - .from(sessions) - .where(eq(sessions.id, sessionId)) - .get(); - - if (!session) return null; - - if (session.expiresAt < new Date()) { - db.delete(sessions).where(eq(sessions.id, sessionId)).run(); - return null; - } - - return session; -} - -export function deleteSession(db: Db = prodDb, sessionId: string) { - db.delete(sessions).where(eq(sessions.id, sessionId)).run(); -} - -export function refreshSession( - db: Db = prodDb, - sessionId: string, - expiryDays = 30, -) { - const expiresAt = new Date(Date.now() + expiryDays * 24 * 60 * 60 * 1000); - db.update(sessions) - .set({ expiresAt }) - .where(eq(sessions.id, sessionId)) - .run(); +export async function getOrCreateUser( + db: Db, + logtoSub: string, +): Promise<{ id: number }> { + const [user] = await db + .insert(users) + .values({ logtoSub }) + .onConflictDoUpdate({ + target: users.logtoSub, + set: { logtoSub }, + }) + .returning({ id: users.id }); + return user; } // ── API Key Management ─────────────────────────────────────────────── -export async function createApiKey(db: Db = prodDb, name: string) { +export async function createApiKey( + db: Db = prodDb, + name: string, + userId: number, +) { const rawKey = randomBytes(32).toString("hex"); const keyHash = await Bun.password.hash(rawKey); const keyPrefix = rawKey.slice(0, 8); - const record = db + const [record] = await db .insert(apiKeys) - .values({ name, keyHash, keyPrefix }) - .returning() - .get(); + .values({ name, keyHash, keyPrefix, userId }) + .returning(); return { ...record, rawKey }; } @@ -125,23 +44,22 @@ export async function createApiKey(db: Db = prodDb, name: string) { export async function verifyApiKey( db: Db = prodDb, rawKey: string, -): Promise { +): Promise<{ userId: number } | null> { const prefix = rawKey.slice(0, 8); - const candidates = db + const candidates = await db .select() .from(apiKeys) - .where(eq(apiKeys.keyPrefix, prefix)) - .all(); + .where(eq(apiKeys.keyPrefix, prefix)); for (const candidate of candidates) { const valid = await Bun.password.verify(rawKey, candidate.keyHash); - if (valid) return true; + if (valid) return { userId: candidate.userId }; } - return false; + return null; } -export function listApiKeys(db: Db = prodDb) { +export async function listApiKeys(db: Db = prodDb, userId: number) { return db .select({ id: apiKeys.id, @@ -150,9 +68,15 @@ export function listApiKeys(db: Db = prodDb) { createdAt: apiKeys.createdAt, }) .from(apiKeys) - .all(); + .where(eq(apiKeys.userId, userId)); } -export function deleteApiKey(db: Db = prodDb, id: number) { - db.delete(apiKeys).where(eq(apiKeys.id, id)).run(); +export async function deleteApiKey( + db: Db = prodDb, + id: number, + userId: number, +) { + await db + .delete(apiKeys) + .where(and(eq(apiKeys.id, id), eq(apiKeys.userId, userId))); } diff --git a/src/server/services/category.service.ts b/src/server/services/category.service.ts index 3b35396..39d0430 100644 --- a/src/server/services/category.service.ts +++ b/src/server/services/category.service.ts @@ -1,9 +1,25 @@ -import { asc, eq } from "drizzle-orm"; +import { and, asc, eq } from "drizzle-orm"; import { db as prodDb } from "../../db/index.ts"; import { categories, items } from "../../db/schema.ts"; type Db = typeof prodDb; +export async function getOrCreateUncategorized( + db: Db, + userId: number, +): Promise { + const [existing] = await db + .select({ id: categories.id }) + .from(categories) + .where(and(eq(categories.userId, userId), eq(categories.name, "Uncategorized"))); + if (existing) return existing.id; + const [created] = await db + .insert(categories) + .values({ name: "Uncategorized", icon: "package", userId }) + .returning({ id: categories.id }); + return created.id; +} + export function getAllCategories(db: Db = prodDb) { return db.select().from(categories).orderBy(asc(categories.name)).all(); } diff --git a/src/server/services/oauth.service.ts b/src/server/services/oauth.service.ts index 10037e5..32b4703 100644 --- a/src/server/services/oauth.service.ts +++ b/src/server/services/oauth.service.ts @@ -7,53 +7,50 @@ type Db = typeof prodDb; // ── Client Registration ────────────────────────────────────────────── -export function registerClient( +export async function registerClient( db: Db = prodDb, clientName: string, redirectUris: string[], -): { clientId: string } { +): Promise<{ clientId: string }> { const clientId = randomUUID(); const redirectUrisJson = JSON.stringify(redirectUris); - db.insert(oauthClients) - .values({ clientId, clientName, redirectUris: redirectUrisJson }) - .run(); + await db + .insert(oauthClients) + .values({ clientId, clientName, redirectUris: redirectUrisJson }); return { clientId }; } -export function getClient(db: Db = prodDb, clientId: string) { - return ( - db - .select() - .from(oauthClients) - .where(eq(oauthClients.clientId, clientId)) - .get() ?? null - ); +export async function getClient(db: Db = prodDb, clientId: string) { + const [record] = await db + .select() + .from(oauthClients) + .where(eq(oauthClients.clientId, clientId)); + + return record ?? null; } // ── Authorization Code ─────────────────────────────────────────────── -export function createAuthorizationCode( +export async function createAuthorizationCode( db: Db = prodDb, clientId: string, codeChallenge: string, codeChallengeMethod: string, redirectUri: string, -): { code: string } { +): Promise<{ code: string }> { const code = randomBytes(32).toString("hex"); const expiresAt = new Date(Date.now() + 10 * 60 * 1000); // 10 minutes - db.insert(oauthCodes) - .values({ - code, - clientId, - codeChallenge, - codeChallengeMethod, - redirectUri, - expiresAt, - }) - .run(); + await db.insert(oauthCodes).values({ + code, + clientId, + codeChallenge, + codeChallengeMethod, + redirectUri, + expiresAt, + }); return { code }; } @@ -64,16 +61,16 @@ export async function exchangeCode( codeVerifier: string, clientId: string, redirectUri: string, + userId: number, ): Promise<{ accessToken: string; refreshToken: string; expiresIn: number; } | null> { - const record = db + const [record] = await db .select() .from(oauthCodes) - .where(eq(oauthCodes.code, code)) - .get(); + .where(eq(oauthCodes.code, code)); if (!record) return null; if (record.used !== 0) return null; @@ -89,17 +86,21 @@ export async function exchangeCode( if (computedChallenge !== record.codeChallenge) return null; // Mark code as used - db.update(oauthCodes).set({ used: 1 }).where(eq(oauthCodes.code, code)).run(); + await db + .update(oauthCodes) + .set({ used: 1 }) + .where(eq(oauthCodes.code, code)); - return generateTokens(db, clientId); + return generateTokens(db, clientId, userId); } // ── Token Management ───────────────────────────────────────────────── -function generateTokens( +async function generateTokens( db: Db, clientId: string, -): { accessToken: string; refreshToken: string; expiresIn: number } { + userId: number, +): Promise<{ accessToken: string; refreshToken: string; expiresIn: number }> { const accessToken = randomBytes(32).toString("hex"); const refreshToken = randomBytes(32).toString("hex"); @@ -113,15 +114,14 @@ function generateTokens( const expiresAt = new Date(Date.now() + 3600 * 1000); // 1 hour const refreshExpiresAt = new Date(Date.now() + 30 * 24 * 60 * 60 * 1000); // 30 days - db.insert(oauthTokens) - .values({ - accessTokenHash, - refreshTokenHash, - clientId, - expiresAt, - refreshExpiresAt, - }) - .run(); + await db.insert(oauthTokens).values({ + accessTokenHash, + refreshTokenHash, + clientId, + userId, + expiresAt, + refreshExpiresAt, + }); return { accessToken, refreshToken, expiresIn: 3600 }; } @@ -129,25 +129,25 @@ function generateTokens( export async function verifyAccessToken( db: Db = prodDb, token: string, -): Promise { +): Promise<{ userId: number } | null> { const tokenHash = createHash("sha256").update(token).digest("hex"); - const record = db + const [record] = await db .select() .from(oauthTokens) - .where(eq(oauthTokens.accessTokenHash, tokenHash)) - .get(); + .where(eq(oauthTokens.accessTokenHash, tokenHash)); - if (!record) return false; - if (record.expiresAt < new Date()) return false; + if (!record) return null; + if (record.expiresAt < new Date()) return null; - return true; + return { userId: record.userId }; } export async function refreshAccessToken( db: Db = prodDb, refreshToken: string, clientId: string, + userId: number, ): Promise<{ accessToken: string; refreshToken: string; @@ -155,7 +155,7 @@ export async function refreshAccessToken( } | null> { const tokenHash = createHash("sha256").update(refreshToken).digest("hex"); - const record = db + const [record] = await db .select() .from(oauthTokens) .where( @@ -163,22 +163,21 @@ export async function refreshAccessToken( eq(oauthTokens.refreshTokenHash, tokenHash), eq(oauthTokens.clientId, clientId), ), - ) - .get(); + ); if (!record) return null; if (record.refreshExpiresAt < new Date()) return null; // Delete old token pair - db.delete(oauthTokens).where(eq(oauthTokens.id, record.id)).run(); + await db.delete(oauthTokens).where(eq(oauthTokens.id, record.id)); - return generateTokens(db, clientId); + return generateTokens(db, clientId, userId); } // ── Cleanup ────────────────────────────────────────────────────────── -export function cleanExpiredOAuthData(db: Db = prodDb): void { +export async function cleanExpiredOAuthData(db: Db = prodDb): Promise { const now = new Date(); - db.delete(oauthCodes).where(lt(oauthCodes.expiresAt, now)).run(); - db.delete(oauthTokens).where(lt(oauthTokens.expiresAt, now)).run(); + await db.delete(oauthCodes).where(lt(oauthCodes.expiresAt, now)); + await db.delete(oauthTokens).where(lt(oauthTokens.expiresAt, now)); }