diff --git a/api/api/v1/sso/index.ts b/api/api/v1/sso/index.ts deleted file mode 100644 index cc5c013f..00000000 --- a/api/api/v1/sso/index.ts +++ /dev/null @@ -1,238 +0,0 @@ -import { apiRoute, applyConfig, auth, jsonOrForm } from "@/api"; -import { oauthRedirectUri } from "@/constants"; -import { randomString } from "@/math"; -import { proxyUrl } from "@/response"; -import { createRoute } from "@hono/zod-openapi"; -import { - calculatePKCECodeChallenge, - discoveryRequest, - generateRandomCodeVerifier, - processDiscoveryResponse, -} from "oauth4webapi"; -import { z } from "zod"; -import { db } from "~/drizzle/db"; -import { - Applications, - OpenIdLoginFlows, - RolePermissions, -} from "~/drizzle/schema"; -import { config } from "~/packages/config-manager"; -import { ErrorSchema } from "~/types/api"; - -export const meta = applyConfig({ - allowedMethods: ["GET", "POST"], - auth: { - required: true, - }, - ratelimits: { - duration: 60, - max: 20, - }, - route: "/api/v1/sso", - permissions: { - required: [RolePermissions.OAuth], - }, -}); - -export const schemas = { - json: z.object({ - issuer: z.string(), - }), -}; - -const routeGet = createRoute({ - method: "get", - path: "/api/v1/sso", - summary: "Get linked accounts", - middleware: [auth(meta.auth)], - responses: { - 200: { - description: "Linked accounts", - content: { - "application/json": { - schema: z.array( - z.object({ - id: z.string(), - name: z.string(), - icon: z.string().optional(), - }), - ), - }, - }, - }, - 401: { - description: "Unauthorized", - content: { - "application/json": { - schema: ErrorSchema, - }, - }, - }, - }, -}); - -const routePost = createRoute({ - method: "post", - path: "/api/v1/sso", - summary: "Link account", - middleware: [auth(meta.auth), jsonOrForm()], - request: { - body: { - content: { - "application/json": { - schema: schemas.json, - }, - "multipart/form-data": { - schema: schemas.json, - }, - "application/x-www-form-urlencoded": { - schema: schemas.json, - }, - }, - }, - }, - responses: { - 200: { - description: "Link URL", - content: { - "application/json": { - schema: z.object({ - link: z.string(), - }), - }, - }, - }, - 401: { - description: "Unauthorized", - content: { - "application/json": { - schema: ErrorSchema, - }, - }, - }, - 404: { - description: "Issuer not found", - content: { - "application/json": { - schema: ErrorSchema, - }, - }, - }, - }, -}); - -export default apiRoute((app) => { - app.openapi(routeGet, async (context) => { - // const form = context.req.valid("json"); - const { user } = context.get("auth"); - - if (!user) { - return context.json({ error: "Unauthorized" }, 401); - } - - // Get all linked accounts - const accounts = await db.query.OpenIdAccounts.findMany({ - where: (User, { eq }) => eq(User.userId, user.id), - }); - - return context.json( - accounts - .map((account) => { - const issuer = config.oidc.providers.find( - (provider) => provider.id === account.issuerId, - ); - - if (!issuer) { - return null; - } - - return { - id: issuer.id, - name: issuer.name, - icon: proxyUrl(issuer.icon) || undefined, - }; - }) - .filter(Boolean) as { - id: string; - name: string; - icon: string | undefined; - }[], - 200, - ); - }); - - app.openapi(routePost, async (context) => { - const { issuer: issuerId } = context.req.valid("json"); - const { user } = context.get("auth"); - - if (!user) { - return context.json({ error: "Unauthorized" }, 401); - } - - const issuer = config.oidc.providers.find( - (provider) => provider.id === issuerId, - ); - - if (!issuer) { - return context.json({ error: `Issuer ${issuerId} not found` }, 404); - } - - const issuerUrl = new URL(issuer.url); - - const authServer = await discoveryRequest(issuerUrl, { - algorithm: "oidc", - }).then((res) => processDiscoveryResponse(issuerUrl, res)); - - const codeVerifier = generateRandomCodeVerifier(); - - const application = ( - await db - .insert(Applications) - .values({ - clientId: user.id + randomString(32, "base64"), - name: "Versia", - redirectUri: `${oauthRedirectUri(issuerId)}`, - scopes: "openid profile email", - secret: "", - }) - .returning() - )[0]; - - // Store into database - const newFlow = ( - await db - .insert(OpenIdLoginFlows) - .values({ - codeVerifier, - issuerId, - applicationId: application.id, - }) - .returning() - )[0]; - - const codeChallenge = await calculatePKCECodeChallenge(codeVerifier); - - return context.json( - { - link: `${authServer.authorization_endpoint}?${new URLSearchParams( - { - client_id: issuer.client_id, - redirect_uri: `${oauthRedirectUri( - issuerId, - )}?${new URLSearchParams({ - flow: newFlow.id, - link: "true", - user_id: user.id, - })}`, - response_type: "code", - scope: "openid profile email", - // PKCE - code_challenge_method: "S256", - code_challenge: codeChallenge, - }, - ).toString()}`, - }, - 200, - ); - }); -}); diff --git a/app.ts b/app.ts index c037aeee..8ef2c381 100644 --- a/app.ts +++ b/app.ts @@ -21,8 +21,10 @@ import { ipBans } from "./middlewares/ip-bans"; import { logger } from "./middlewares/logger"; import { routes } from "./routes"; import type { ApiRouteExports, HonoEnv } from "./types/api"; +import { configureLoggers } from "@/loggers"; export const appFactory = async () => { + await configureLoggers(); const serverLogger = getLogger("server"); const app = new OpenAPIHono({ @@ -137,6 +139,7 @@ export const appFactory = async () => { await Bun.sleep(Number.POSITIVE_INFINITY); process.exit(); } + // biome-ignore lint/complexity/useLiteralKeys: AddToApp is a private method await data.plugin["_addToApp"](app); } diff --git a/build.ts b/build.ts index d911bf12..24accba5 100644 --- a/build.ts +++ b/build.ts @@ -38,6 +38,9 @@ buildSpinner.text = "Transforming"; // Copy Drizzle migrations to dist await $`cp -r drizzle dist/drizzle`; +// Copy plugin manifests +await $`cp plugins/openid/manifest.json dist/plugins/openid/manifest.json`; + // Copy Sharp to dist await $`mkdir -p dist/node_modules/@img`; await $`cp -r node_modules/@img/sharp-libvips-linux-* dist/node_modules/@img`; @@ -46,6 +49,7 @@ await $`cp -r node_modules/@img/sharp-linux-* dist/node_modules/@img`; // Copy unzipit and uzip-module to dist await $`cp -r node_modules/unzipit dist/node_modules/unzipit`; await $`cp -r node_modules/uzip-module dist/node_modules/uzip-module`; + // Copy acorn to dist await $`cp -r node_modules/acorn dist/node_modules/acorn`; diff --git a/classes/plugin/loader.test.ts b/classes/plugin/loader.test.ts index 1d3ddbec..08a599fc 100644 --- a/classes/plugin/loader.test.ts +++ b/classes/plugin/loader.test.ts @@ -194,10 +194,10 @@ describe("PluginLoader", () => { success: true, data: manifestContent, }); - mock.module("/some/path/plugin1/index.ts", () => ({ + mock.module("/some/path/plugin1/index", () => ({ default: mockPlugin, })); - mock.module("/some/path/plugin2/index.ts", () => ({ + mock.module("/some/path/plugin2/index", () => ({ default: mockPlugin, })); diff --git a/classes/plugin/loader.ts b/classes/plugin/loader.ts index 548805ef..928bffb8 100644 --- a/classes/plugin/loader.ts +++ b/classes/plugin/loader.ts @@ -34,13 +34,13 @@ export class PluginLoader { } /** - * Check if a directory has an entrypoint file (index.ts). + * Check if a directory has an entrypoint file (index.{ts,js}). * @param {string} dir - The directory to search. * @returns {Promise} - True if the entrypoint file is found, otherwise false. */ private async hasEntrypoint(dir: string): Promise { const files = await readdir(dir); - return files.includes("index.ts"); + return files.includes("index.ts") || files.includes("index.js"); } /** @@ -74,7 +74,7 @@ export class PluginLoader { } /** - * Find all direct subdirectories with a valid manifest file and entrypoint (index.ts). + * Find all direct subdirectories with a valid manifest file and entrypoint (index.{ts,js}). * @param {string} dir - The directory to search. * @returns {Promise} - An array of plugin directories. */ @@ -166,7 +166,7 @@ export class PluginLoader { const manifest = await this.parseManifest(dir, plugin); const pluginInstance = await this.loadPlugin( dir, - `${plugin}/index.ts`, + `${plugin}/index`, ); return { manifest, plugin: pluginInstance }; diff --git a/packages/database-interface/user.ts b/packages/database-interface/user.ts index a8d96f3d..baac85f0 100644 --- a/packages/database-interface/user.ts +++ b/packages/database-interface/user.ts @@ -408,6 +408,41 @@ export class User extends BaseInterface { return this.update(this.data); } + public async getLinkedOidcAccounts(): Promise< + { + id: string; + name: string; + url: string; + icon?: string | undefined; + server_id: string; + }[] + > { + // Get all linked accounts + const accounts = await db.query.OpenIdAccounts.findMany({ + where: (User, { eq }) => eq(User.userId, this.id), + }); + + return accounts + .map((account) => { + const issuer = config.oidc.providers.find( + (provider) => provider.id === account.issuerId, + ); + + if (!issuer) { + return null; + } + + return { + id: issuer.id, + name: issuer.name, + url: issuer.url, + icon: proxyUrl(issuer.icon) || undefined, + server_id: account.serverId, + }; + }) + .filter((x) => x !== null); + } + async updateFromRemote(): Promise { if (!this.isRemote()) { throw new Error( diff --git a/plugins/openid/index.ts b/plugins/openid/index.ts index 3e8292ca..6780f35b 100644 --- a/plugins/openid/index.ts +++ b/plugins/openid/index.ts @@ -1,6 +1,7 @@ import { Hooks, Plugin, PluginConfigManager } from "@versia/kit"; import { z } from "zod"; import authorizeRoute from "./routes/authorize"; +import ssoRoute from "./routes/sso"; const configManager = new PluginConfigManager( z.object({ @@ -66,6 +67,7 @@ plugin.registerHandler(Hooks.Response, (req) => { return req; }); authorizeRoute(plugin); +ssoRoute(plugin); export type PluginType = typeof plugin; export default plugin; diff --git a/plugins/openid/routes/authorize.test.ts b/plugins/openid/routes/authorize.test.ts new file mode 100644 index 00000000..5f45da85 --- /dev/null +++ b/plugins/openid/routes/authorize.test.ts @@ -0,0 +1,404 @@ +import { afterAll, beforeAll, describe, expect, test } from "bun:test"; +import { randomString } from "@/math"; +import { eq } from "drizzle-orm"; +import { SignJWT } from "jose"; +import { db } from "~/drizzle/db"; +import { Applications, RolePermissions } from "~/drizzle/schema"; +import { config } from "~/packages/config-manager"; +import { fakeRequest, getTestUsers } from "~/tests/utils"; + +const { deleteUsers, tokens, users } = await getTestUsers(1); +const clientId = "test-client-id"; +const redirectUri = "https://example.com/callback"; +const scope = "openid profile email"; +const secret = "test-secret"; +const privateKey = await crypto.subtle.importKey( + "pkcs8", + Buffer.from(config.plugins?.["@versia/openid"].keys.private, "base64"), + "Ed25519", + false, + ["sign"], +); + +beforeAll(async () => { + await db.insert(Applications).values({ + clientId: clientId, + redirectUri: redirectUri, + scopes: scope, + name: "Test Application", + secret, + }); +}); + +afterAll(async () => { + await deleteUsers(); + await db.delete(Applications).where(eq(Applications.clientId, clientId)); +}); + +describe("/oauth/authorize", () => { + test("should authorize and redirect with valid inputs", async () => { + const jwt = await new SignJWT({ + sub: users[0].id, + iss: new URL(config.http.base_url).origin, + aud: clientId, + exp: Math.floor(Date.now() / 1000) + 60 * 60, + iat: Math.floor(Date.now() / 1000), + nbf: Math.floor(Date.now() / 1000), + }) + .setProtectedHeader({ alg: "EdDSA" }) + .sign(privateKey); + + const response = await fakeRequest("/oauth/authorize", { + method: "POST", + headers: { + Authorization: `Bearer ${tokens[0]?.accessToken}`, + "Content-Type": "application/json", + Cookie: `jwt=${jwt}`, + }, + body: JSON.stringify({ + client_id: clientId, + redirect_uri: redirectUri, + response_type: "code", + scope, + state: "test-state", + code_challenge: randomString(43), + code_challenge_method: "S256", + }), + }); + + expect(response.status).toBe(302); + const location = new URL( + response.headers.get("Location") ?? "", + config.http.base_url, + ); + const params = new URLSearchParams(location.search); + expect(location.origin + location.pathname).toBe(redirectUri); + expect(params.get("code")).toBeTruthy(); + expect(params.get("state")).toBe("test-state"); + }); + + test("should return error for invalid JWT", async () => { + const response = await fakeRequest("/oauth/authorize", { + method: "POST", + headers: { + Authorization: `Bearer ${tokens[0]?.accessToken}`, + "Content-Type": "application/json", + Cookie: "jwt=invalid-jwt", + }, + body: JSON.stringify({ + client_id: clientId, + redirect_uri: redirectUri, + response_type: "code", + scope, + state: "test-state", + code_challenge: randomString(43), + code_challenge_method: "S256", + }), + }); + + expect(response.status).toBe(302); + const location = new URL( + response.headers.get("Location") ?? "", + config.http.base_url, + ); + const params = new URLSearchParams(location.search); + expect(params.get("error")).toBe("invalid_request"); + expect(params.get("error_description")).toBe( + "Invalid JWT, could not verify", + ); + }); + + test("should return error for missing required fields in JWT", async () => { + const jwt = await new SignJWT({ + sub: users[0].id, + iss: new URL(config.http.base_url).origin, + aud: clientId, + }) + .setProtectedHeader({ alg: "EdDSA" }) + .sign(privateKey); + + const response = await fakeRequest("/oauth/authorize", { + method: "POST", + headers: { + Authorization: `Bearer ${tokens[0]?.accessToken}`, + "Content-Type": "application/json", + Cookie: `jwt=${jwt}`, + }, + body: JSON.stringify({ + client_id: clientId, + redirect_uri: redirectUri, + response_type: "code", + scope, + state: "test-state", + code_challenge: randomString(43), + code_challenge_method: "S256", + }), + }); + + expect(response.status).toBe(302); + const location = new URL( + response.headers.get("Location") ?? "", + config.http.base_url, + ); + const params = new URLSearchParams(location.search); + expect(params.get("error")).toBe("invalid_request"); + expect(params.get("error_description")).toBe( + "Invalid JWT, missing required fields (aud, sub, exp)", + ); + }); + + test("should return error for user not found", async () => { + const jwt = await new SignJWT({ + sub: "non-existent-user", + aud: clientId, + exp: Math.floor(Date.now() / 1000) + 60 * 60, + iss: new URL(config.http.base_url).origin, + iat: Math.floor(Date.now() / 1000), + nbf: Math.floor(Date.now() / 1000), + }) + .setProtectedHeader({ alg: "EdDSA" }) + .sign(privateKey); + + const response = await fakeRequest("/oauth/authorize", { + method: "POST", + headers: { + Authorization: `Bearer ${tokens[0]?.accessToken}`, + "Content-Type": "application/json", + Cookie: `jwt=${jwt}`, + }, + body: JSON.stringify({ + client_id: clientId, + redirect_uri: redirectUri, + response_type: "code", + scope, + state: "test-state", + code_challenge: randomString(43), + code_challenge_method: "S256", + }), + }); + + expect(response.status).toBe(302); + const location = new URL( + response.headers.get("Location") ?? "", + config.http.base_url, + ); + const params = new URLSearchParams(location.search); + expect(params.get("error")).toBe("invalid_request"); + expect(params.get("error_description")).toBe( + "Invalid JWT, sub is not a valid user ID", + ); + + const jwt2 = await new SignJWT({ + sub: "23e42862-d5df-49a8-95b5-52d8c6a11aea", + aud: clientId, + exp: Math.floor(Date.now() / 1000) + 60 * 60, + iss: new URL(config.http.base_url).origin, + iat: Math.floor(Date.now() / 1000), + nbf: Math.floor(Date.now() / 1000), + }) + .setProtectedHeader({ alg: "EdDSA" }) + .sign(privateKey); + + const response2 = await fakeRequest("/oauth/authorize", { + method: "POST", + headers: { + Authorization: `Bearer ${tokens[0]?.accessToken}`, + "Content-Type": "application/json", + Cookie: `jwt=${jwt2}`, + }, + body: JSON.stringify({ + client_id: clientId, + redirect_uri: redirectUri, + response_type: "code", + scope, + state: "test-state", + code_challenge: randomString(43), + code_challenge_method: "S256", + }), + }); + + expect(response2.status).toBe(302); + const location2 = new URL( + response2.headers.get("Location") ?? "", + config.http.base_url, + ); + const params2 = new URLSearchParams(location2.search); + expect(params2.get("error")).toBe("invalid_request"); + expect(params2.get("error_description")).toBe( + "Invalid JWT, could not find associated user", + ); + }); + + test("should return error for user missing required permissions", async () => { + const oldPermissions = config.permissions.default; + config.permissions.default = []; + + const jwt = await new SignJWT({ + sub: users[0].id, + iss: new URL(config.http.base_url).origin, + aud: clientId, + exp: Math.floor(Date.now() / 1000) + 60 * 60, + iat: Math.floor(Date.now() / 1000), + nbf: Math.floor(Date.now() / 1000), + }) + .setProtectedHeader({ alg: "EdDSA" }) + .sign(privateKey); + + const response = await fakeRequest("/oauth/authorize", { + method: "POST", + headers: { + Authorization: `Bearer ${tokens[0]?.accessToken}`, + "Content-Type": "application/json", + Cookie: `jwt=${jwt}`, + }, + body: JSON.stringify({ + client_id: clientId, + redirect_uri: redirectUri, + response_type: "code", + scope, + state: "test-state", + code_challenge: randomString(43), + code_challenge_method: "S256", + }), + }); + + expect(response.status).toBe(302); + const location = new URL( + response.headers.get("Location") ?? "", + config.http.base_url, + ); + const params = new URLSearchParams(location.search); + expect(params.get("error")).toBe("invalid_request"); + expect(params.get("error_description")).toBe( + `User is missing the required permission ${RolePermissions.OAuth}`, + ); + + config.permissions.default = oldPermissions; + }); + + test("should return error for invalid client_id", async () => { + const jwt = await new SignJWT({ + sub: users[0].id, + aud: "invalid-client-id", + iss: new URL(config.http.base_url).origin, + exp: Math.floor(Date.now() / 1000) + 60 * 60, + iat: Math.floor(Date.now() / 1000), + nbf: Math.floor(Date.now() / 1000), + }) + .setProtectedHeader({ alg: "EdDSA" }) + .sign(privateKey); + + const response = await fakeRequest("/oauth/authorize", { + method: "POST", + headers: { + Authorization: `Bearer ${tokens[0]?.accessToken}`, + "Content-Type": "application/json", + Cookie: `jwt=${jwt}`, + }, + body: JSON.stringify({ + client_id: "invalid-client-id", + redirect_uri: redirectUri, + response_type: "code", + scope, + state: "test-state", + code_challenge: randomString(43), + code_challenge_method: "S256", + }), + }); + + expect(response.status).toBe(302); + const location = new URL( + response.headers.get("Location") ?? "", + config.http.base_url, + ); + const params = new URLSearchParams(location.search); + expect(params.get("error")).toBe("invalid_request"); + expect(params.get("error_description")).toBe( + "Invalid client_id: no associated application found", + ); + }); + + test("should return error for invalid redirect_uri", async () => { + const jwt = await new SignJWT({ + sub: users[0].id, + iss: new URL(config.http.base_url).origin, + aud: clientId, + exp: Math.floor(Date.now() / 1000) + 60 * 60, + iat: Math.floor(Date.now() / 1000), + nbf: Math.floor(Date.now() / 1000), + }) + .setProtectedHeader({ alg: "EdDSA" }) + .sign(privateKey); + + const response = await fakeRequest("/oauth/authorize", { + method: "POST", + headers: { + Authorization: `Bearer ${tokens[0]?.accessToken}`, + "Content-Type": "application/json", + Cookie: `jwt=${jwt}`, + }, + body: JSON.stringify({ + client_id: clientId, + redirect_uri: "https://invalid.com/callback", + response_type: "code", + scope, + state: "test-state", + code_challenge: randomString(43), + code_challenge_method: "S256", + }), + }); + + expect(response.status).toBe(302); + const location = new URL( + response.headers.get("Location") ?? "", + config.http.base_url, + ); + const params = new URLSearchParams(location.search); + expect(params.get("error")).toBe("invalid_request"); + expect(params.get("error_description")).toBe( + "Invalid redirect_uri: does not match application's redirect_uri", + ); + }); + + test("should return error for invalid scope", async () => { + const jwt = await new SignJWT({ + sub: users[0].id, + iss: new URL(config.http.base_url).origin, + aud: clientId, + exp: Math.floor(Date.now() / 1000) + 60 * 60, + iat: Math.floor(Date.now() / 1000), + nbf: Math.floor(Date.now() / 1000), + }) + .setProtectedHeader({ alg: "EdDSA" }) + .sign(privateKey); + + const response = await fakeRequest("/oauth/authorize", { + method: "POST", + headers: { + Authorization: `Bearer ${tokens[0]?.accessToken}`, + "Content-Type": "application/json", + Cookie: `jwt=${jwt}`, + }, + body: JSON.stringify({ + client_id: clientId, + redirect_uri: redirectUri, + response_type: "code", + scope: "invalid-scope", + state: "test-state", + code_challenge: randomString(43), + code_challenge_method: "S256", + }), + }); + + expect(response.status).toBe(302); + const location = new URL( + response.headers.get("Location") ?? "", + config.http.base_url, + ); + const params = new URLSearchParams(location.search); + expect(params.get("error")).toBe("invalid_scope"); + expect(params.get("error_description")).toBe( + "Invalid scope: not a subset of the application's scopes", + ); + }); +}); diff --git a/plugins/openid/routes/authorize.ts b/plugins/openid/routes/authorize.ts index a934ad66..78fc66a7 100644 --- a/plugins/openid/routes/authorize.ts +++ b/plugins/openid/routes/authorize.ts @@ -159,6 +159,18 @@ export default (plugin: PluginType) => ); } + if (!z.string().uuid().safeParse(sub).success) { + errorSearchParams.append("error", "invalid_request"); + errorSearchParams.append( + "error_description", + "Invalid JWT, sub is not a valid user ID", + ); + + return context.redirect( + `${context.get("config").frontend.routes.login}?${errorSearchParams.toString()}`, + ); + } + const user = await User.fromId(sub); if (!user) { diff --git a/api/api/v1/sso/index.test.ts b/plugins/openid/routes/sso/index.test.ts similarity index 86% rename from api/api/v1/sso/index.test.ts rename to plugins/openid/routes/sso/index.test.ts index b0d03a74..3451478f 100644 --- a/api/api/v1/sso/index.test.ts +++ b/plugins/openid/routes/sso/index.test.ts @@ -1,6 +1,5 @@ import { afterAll, describe, expect, test } from "bun:test"; import { fakeRequest, getTestUsers } from "~/tests/utils"; -import { meta } from "./index"; const { deleteUsers, tokens } = await getTestUsers(1); @@ -8,10 +7,9 @@ afterAll(async () => { await deleteUsers(); }); -// /api/v1/sso -describe(meta.route, () => { +describe("/api/v1/sso", () => { test("should return empty list", async () => { - const response = await fakeRequest(meta.route, { + const response = await fakeRequest("/api/v1/sso", { method: "GET", headers: { Authorization: `Bearer ${tokens[0]?.accessToken}`, @@ -23,7 +21,7 @@ describe(meta.route, () => { }); test("should return an error if provider doesn't exist", async () => { - const response = await fakeRequest(meta.route, { + const response = await fakeRequest("/api/v1/sso", { method: "POST", headers: { Authorization: `Bearer ${tokens[0]?.accessToken}`, diff --git a/plugins/openid/routes/sso/index.ts b/plugins/openid/routes/sso/index.ts new file mode 100644 index 00000000..0c765a33 --- /dev/null +++ b/plugins/openid/routes/sso/index.ts @@ -0,0 +1,223 @@ +import { auth } from "@/api"; +import { + calculatePKCECodeChallenge, + generateRandomCodeVerifier, +} from "oauth4webapi"; +import { z } from "zod"; +import { db } from "~/drizzle/db"; +import { + Applications, + OpenIdLoginFlows, + RolePermissions, +} from "~/drizzle/schema"; +import { ErrorSchema } from "~/types/api"; +import type { PluginType } from "../.."; +import { oauthDiscoveryRequest, oauthRedirectUri } from "../../utils"; + +export default (plugin: PluginType) => { + plugin.registerRoute("/api/v1/sso", (app) => { + app.openapi( + { + method: "get", + path: "/api/v1/sso", + summary: "Get linked accounts", + middleware: [ + auth( + { + required: true, + }, + { + required: [RolePermissions.OAuth], + }, + ), + plugin.middleware, + ], + responses: { + 200: { + description: "Linked accounts", + content: { + "application/json": { + schema: z.array( + z.object({ + id: z.string(), + name: z.string(), + icon: z.string().optional(), + }), + ), + }, + }, + }, + 401: { + description: "Unauthorized", + content: { + "application/json": { + schema: ErrorSchema, + }, + }, + }, + }, + }, + async (context) => { + const { user } = context.get("auth"); + + if (!user) { + return context.json( + { + error: "Unauthorized", + }, + 401, + ); + } + + const linkedAccounts = await user.getLinkedOidcAccounts(); + + return context.json( + linkedAccounts.map((account) => ({ + id: account.id, + name: account.name, + icon: account.icon, + })), + 200, + ); + }, + ); + + app.openapi( + { + method: "post", + path: "/api/v1/sso", + summary: "Link account", + middleware: [ + auth( + { + required: true, + }, + { + required: [RolePermissions.OAuth], + }, + ), + ], + request: { + body: { + content: { + "application/json": { + schema: z.object({ + issuer: z.string(), + }), + }, + }, + }, + }, + responses: { + 302: { + description: "Redirect to OpenID provider", + }, + 401: { + description: "Unauthorized", + content: { + "application/json": { + schema: ErrorSchema, + }, + }, + }, + 404: { + description: "Issuer not found", + content: { + "application/json": { + schema: ErrorSchema, + }, + }, + }, + }, + }, + async (context) => { + const { user } = context.get("auth"); + + if (!user) { + return context.json( + { + error: "Unauthorized", + }, + 401, + ); + } + + const { issuer: issuerId } = context.req.valid("json"); + + const issuer = context + .get("pluginConfig") + .providers.find((provider) => provider.id === issuerId); + + if (!issuer) { + return context.json( + { + error: `Issuer with ID ${issuerId} not found in instance's OpenID configuration`, + }, + 404, + ); + } + + const authServer = await oauthDiscoveryRequest(issuer.url); + + const codeVerifier = generateRandomCodeVerifier(); + + const redirectUri = oauthRedirectUri( + issuerId, + context.get("config").http.base_url, + ); + + const application = ( + await db + .insert(Applications) + .values({ + clientId: + user.id + + Buffer.from( + crypto.getRandomValues(new Uint8Array(32)), + ).toString("base64"), + name: "Versia", + redirectUri, + scopes: "openid profile email", + secret: "", + }) + .returning() + )[0]; + + // Store into database + const newFlow = ( + await db + .insert(OpenIdLoginFlows) + .values({ + codeVerifier, + issuerId, + applicationId: application.id, + }) + .returning() + )[0]; + + const codeChallenge = + await calculatePKCECodeChallenge(codeVerifier); + + return context.redirect( + `${authServer.authorization_endpoint}?${new URLSearchParams( + { + client_id: issuer.client_id, + redirect_uri: `${redirectUri}?${new URLSearchParams( + { + flow: newFlow.id, + link: "true", + user_id: user.id, + }, + )}`, + response_type: "code", + scope: "openid profile email", + // PKCE + code_challenge_method: "S256", + code_challenge: codeChallenge, + }, + ).toString()}`, + ); + }, + ); + }); +}; diff --git a/plugins/openid/utils.ts b/plugins/openid/utils.ts new file mode 100644 index 00000000..2c6d2583 --- /dev/null +++ b/plugins/openid/utils.ts @@ -0,0 +1,18 @@ +import { + type AuthorizationServer, + discoveryRequest, + processDiscoveryResponse, +} from "oauth4webapi"; + +export const oauthDiscoveryRequest = ( + issuerUrl: string | URL, +): Promise => { + const issuerUrlurl = new URL(issuerUrl); + + return discoveryRequest(issuerUrlurl, { + algorithm: "oidc", + }).then((res) => processDiscoveryResponse(issuerUrlurl, res)); +}; + +export const oauthRedirectUri = (baseUrl: string, issuer: string) => + new URL(`/oauth/sso/${issuer}/callback`, baseUrl).toString();