refactor(api): ♻️ Move /api/v1/sso to OpenID plugin

This commit is contained in:
Jesse Wierzbinski 2024-09-24 14:42:39 +02:00
parent c7ec678a3e
commit 96d1805925
No known key found for this signature in database
12 changed files with 710 additions and 249 deletions

View file

@ -1,238 +0,0 @@
import { apiRoute, applyConfig, auth, jsonOrForm } from "@/api";
import { oauthRedirectUri } from "@/constants";
import { randomString } from "@/math";
import { proxyUrl } from "@/response";
import { createRoute } from "@hono/zod-openapi";
import {
calculatePKCECodeChallenge,
discoveryRequest,
generateRandomCodeVerifier,
processDiscoveryResponse,
} from "oauth4webapi";
import { z } from "zod";
import { db } from "~/drizzle/db";
import {
Applications,
OpenIdLoginFlows,
RolePermissions,
} from "~/drizzle/schema";
import { config } from "~/packages/config-manager";
import { ErrorSchema } from "~/types/api";
export const meta = applyConfig({
allowedMethods: ["GET", "POST"],
auth: {
required: true,
},
ratelimits: {
duration: 60,
max: 20,
},
route: "/api/v1/sso",
permissions: {
required: [RolePermissions.OAuth],
},
});
export const schemas = {
json: z.object({
issuer: z.string(),
}),
};
const routeGet = createRoute({
method: "get",
path: "/api/v1/sso",
summary: "Get linked accounts",
middleware: [auth(meta.auth)],
responses: {
200: {
description: "Linked accounts",
content: {
"application/json": {
schema: z.array(
z.object({
id: z.string(),
name: z.string(),
icon: z.string().optional(),
}),
),
},
},
},
401: {
description: "Unauthorized",
content: {
"application/json": {
schema: ErrorSchema,
},
},
},
},
});
const routePost = createRoute({
method: "post",
path: "/api/v1/sso",
summary: "Link account",
middleware: [auth(meta.auth), jsonOrForm()],
request: {
body: {
content: {
"application/json": {
schema: schemas.json,
},
"multipart/form-data": {
schema: schemas.json,
},
"application/x-www-form-urlencoded": {
schema: schemas.json,
},
},
},
},
responses: {
200: {
description: "Link URL",
content: {
"application/json": {
schema: z.object({
link: z.string(),
}),
},
},
},
401: {
description: "Unauthorized",
content: {
"application/json": {
schema: ErrorSchema,
},
},
},
404: {
description: "Issuer not found",
content: {
"application/json": {
schema: ErrorSchema,
},
},
},
},
});
export default apiRoute((app) => {
app.openapi(routeGet, async (context) => {
// const form = context.req.valid("json");
const { user } = context.get("auth");
if (!user) {
return context.json({ error: "Unauthorized" }, 401);
}
// Get all linked accounts
const accounts = await db.query.OpenIdAccounts.findMany({
where: (User, { eq }) => eq(User.userId, user.id),
});
return context.json(
accounts
.map((account) => {
const issuer = config.oidc.providers.find(
(provider) => provider.id === account.issuerId,
);
if (!issuer) {
return null;
}
return {
id: issuer.id,
name: issuer.name,
icon: proxyUrl(issuer.icon) || undefined,
};
})
.filter(Boolean) as {
id: string;
name: string;
icon: string | undefined;
}[],
200,
);
});
app.openapi(routePost, async (context) => {
const { issuer: issuerId } = context.req.valid("json");
const { user } = context.get("auth");
if (!user) {
return context.json({ error: "Unauthorized" }, 401);
}
const issuer = config.oidc.providers.find(
(provider) => provider.id === issuerId,
);
if (!issuer) {
return context.json({ error: `Issuer ${issuerId} not found` }, 404);
}
const issuerUrl = new URL(issuer.url);
const authServer = await discoveryRequest(issuerUrl, {
algorithm: "oidc",
}).then((res) => processDiscoveryResponse(issuerUrl, res));
const codeVerifier = generateRandomCodeVerifier();
const application = (
await db
.insert(Applications)
.values({
clientId: user.id + randomString(32, "base64"),
name: "Versia",
redirectUri: `${oauthRedirectUri(issuerId)}`,
scopes: "openid profile email",
secret: "",
})
.returning()
)[0];
// Store into database
const newFlow = (
await db
.insert(OpenIdLoginFlows)
.values({
codeVerifier,
issuerId,
applicationId: application.id,
})
.returning()
)[0];
const codeChallenge = await calculatePKCECodeChallenge(codeVerifier);
return context.json(
{
link: `${authServer.authorization_endpoint}?${new URLSearchParams(
{
client_id: issuer.client_id,
redirect_uri: `${oauthRedirectUri(
issuerId,
)}?${new URLSearchParams({
flow: newFlow.id,
link: "true",
user_id: user.id,
})}`,
response_type: "code",
scope: "openid profile email",
// PKCE
code_challenge_method: "S256",
code_challenge: codeChallenge,
},
).toString()}`,
},
200,
);
});
});

3
app.ts
View file

@ -21,8 +21,10 @@ import { ipBans } from "./middlewares/ip-bans";
import { logger } from "./middlewares/logger"; import { logger } from "./middlewares/logger";
import { routes } from "./routes"; import { routes } from "./routes";
import type { ApiRouteExports, HonoEnv } from "./types/api"; import type { ApiRouteExports, HonoEnv } from "./types/api";
import { configureLoggers } from "@/loggers";
export const appFactory = async () => { export const appFactory = async () => {
await configureLoggers();
const serverLogger = getLogger("server"); const serverLogger = getLogger("server");
const app = new OpenAPIHono<HonoEnv>({ const app = new OpenAPIHono<HonoEnv>({
@ -137,6 +139,7 @@ export const appFactory = async () => {
await Bun.sleep(Number.POSITIVE_INFINITY); await Bun.sleep(Number.POSITIVE_INFINITY);
process.exit(); process.exit();
} }
// biome-ignore lint/complexity/useLiteralKeys: AddToApp is a private method // biome-ignore lint/complexity/useLiteralKeys: AddToApp is a private method
await data.plugin["_addToApp"](app); await data.plugin["_addToApp"](app);
} }

View file

@ -38,6 +38,9 @@ buildSpinner.text = "Transforming";
// Copy Drizzle migrations to dist // Copy Drizzle migrations to dist
await $`cp -r drizzle dist/drizzle`; await $`cp -r drizzle dist/drizzle`;
// Copy plugin manifests
await $`cp plugins/openid/manifest.json dist/plugins/openid/manifest.json`;
// Copy Sharp to dist // Copy Sharp to dist
await $`mkdir -p dist/node_modules/@img`; await $`mkdir -p dist/node_modules/@img`;
await $`cp -r node_modules/@img/sharp-libvips-linux-* dist/node_modules/@img`; await $`cp -r node_modules/@img/sharp-libvips-linux-* dist/node_modules/@img`;
@ -46,6 +49,7 @@ await $`cp -r node_modules/@img/sharp-linux-* dist/node_modules/@img`;
// Copy unzipit and uzip-module to dist // Copy unzipit and uzip-module to dist
await $`cp -r node_modules/unzipit dist/node_modules/unzipit`; await $`cp -r node_modules/unzipit dist/node_modules/unzipit`;
await $`cp -r node_modules/uzip-module dist/node_modules/uzip-module`; await $`cp -r node_modules/uzip-module dist/node_modules/uzip-module`;
// Copy acorn to dist // Copy acorn to dist
await $`cp -r node_modules/acorn dist/node_modules/acorn`; await $`cp -r node_modules/acorn dist/node_modules/acorn`;

View file

@ -194,10 +194,10 @@ describe("PluginLoader", () => {
success: true, success: true,
data: manifestContent, data: manifestContent,
}); });
mock.module("/some/path/plugin1/index.ts", () => ({ mock.module("/some/path/plugin1/index", () => ({
default: mockPlugin, default: mockPlugin,
})); }));
mock.module("/some/path/plugin2/index.ts", () => ({ mock.module("/some/path/plugin2/index", () => ({
default: mockPlugin, default: mockPlugin,
})); }));

View file

@ -34,13 +34,13 @@ export class PluginLoader {
} }
/** /**
* Check if a directory has an entrypoint file (index.ts). * Check if a directory has an entrypoint file (index.{ts,js}).
* @param {string} dir - The directory to search. * @param {string} dir - The directory to search.
* @returns {Promise<boolean>} - True if the entrypoint file is found, otherwise false. * @returns {Promise<boolean>} - True if the entrypoint file is found, otherwise false.
*/ */
private async hasEntrypoint(dir: string): Promise<boolean> { private async hasEntrypoint(dir: string): Promise<boolean> {
const files = await readdir(dir); const files = await readdir(dir);
return files.includes("index.ts"); return files.includes("index.ts") || files.includes("index.js");
} }
/** /**
@ -74,7 +74,7 @@ export class PluginLoader {
} }
/** /**
* Find all direct subdirectories with a valid manifest file and entrypoint (index.ts). * Find all direct subdirectories with a valid manifest file and entrypoint (index.{ts,js}).
* @param {string} dir - The directory to search. * @param {string} dir - The directory to search.
* @returns {Promise<string[]>} - An array of plugin directories. * @returns {Promise<string[]>} - An array of plugin directories.
*/ */
@ -166,7 +166,7 @@ export class PluginLoader {
const manifest = await this.parseManifest(dir, plugin); const manifest = await this.parseManifest(dir, plugin);
const pluginInstance = await this.loadPlugin( const pluginInstance = await this.loadPlugin(
dir, dir,
`${plugin}/index.ts`, `${plugin}/index`,
); );
return { manifest, plugin: pluginInstance }; return { manifest, plugin: pluginInstance };

View file

@ -408,6 +408,41 @@ export class User extends BaseInterface<typeof Users, UserWithRelations> {
return this.update(this.data); return this.update(this.data);
} }
public async getLinkedOidcAccounts(): Promise<
{
id: string;
name: string;
url: string;
icon?: string | undefined;
server_id: string;
}[]
> {
// Get all linked accounts
const accounts = await db.query.OpenIdAccounts.findMany({
where: (User, { eq }) => eq(User.userId, this.id),
});
return accounts
.map((account) => {
const issuer = config.oidc.providers.find(
(provider) => provider.id === account.issuerId,
);
if (!issuer) {
return null;
}
return {
id: issuer.id,
name: issuer.name,
url: issuer.url,
icon: proxyUrl(issuer.icon) || undefined,
server_id: account.serverId,
};
})
.filter((x) => x !== null);
}
async updateFromRemote(): Promise<User> { async updateFromRemote(): Promise<User> {
if (!this.isRemote()) { if (!this.isRemote()) {
throw new Error( throw new Error(

View file

@ -1,6 +1,7 @@
import { Hooks, Plugin, PluginConfigManager } from "@versia/kit"; import { Hooks, Plugin, PluginConfigManager } from "@versia/kit";
import { z } from "zod"; import { z } from "zod";
import authorizeRoute from "./routes/authorize"; import authorizeRoute from "./routes/authorize";
import ssoRoute from "./routes/sso";
const configManager = new PluginConfigManager( const configManager = new PluginConfigManager(
z.object({ z.object({
@ -66,6 +67,7 @@ plugin.registerHandler(Hooks.Response, (req) => {
return req; return req;
}); });
authorizeRoute(plugin); authorizeRoute(plugin);
ssoRoute(plugin);
export type PluginType = typeof plugin; export type PluginType = typeof plugin;
export default plugin; export default plugin;

View file

@ -0,0 +1,404 @@
import { afterAll, beforeAll, describe, expect, test } from "bun:test";
import { randomString } from "@/math";
import { eq } from "drizzle-orm";
import { SignJWT } from "jose";
import { db } from "~/drizzle/db";
import { Applications, RolePermissions } from "~/drizzle/schema";
import { config } from "~/packages/config-manager";
import { fakeRequest, getTestUsers } from "~/tests/utils";
const { deleteUsers, tokens, users } = await getTestUsers(1);
const clientId = "test-client-id";
const redirectUri = "https://example.com/callback";
const scope = "openid profile email";
const secret = "test-secret";
const privateKey = await crypto.subtle.importKey(
"pkcs8",
Buffer.from(config.plugins?.["@versia/openid"].keys.private, "base64"),
"Ed25519",
false,
["sign"],
);
beforeAll(async () => {
await db.insert(Applications).values({
clientId: clientId,
redirectUri: redirectUri,
scopes: scope,
name: "Test Application",
secret,
});
});
afterAll(async () => {
await deleteUsers();
await db.delete(Applications).where(eq(Applications.clientId, clientId));
});
describe("/oauth/authorize", () => {
test("should authorize and redirect with valid inputs", async () => {
const jwt = await new SignJWT({
sub: users[0].id,
iss: new URL(config.http.base_url).origin,
aud: clientId,
exp: Math.floor(Date.now() / 1000) + 60 * 60,
iat: Math.floor(Date.now() / 1000),
nbf: Math.floor(Date.now() / 1000),
})
.setProtectedHeader({ alg: "EdDSA" })
.sign(privateKey);
const response = await fakeRequest("/oauth/authorize", {
method: "POST",
headers: {
Authorization: `Bearer ${tokens[0]?.accessToken}`,
"Content-Type": "application/json",
Cookie: `jwt=${jwt}`,
},
body: JSON.stringify({
client_id: clientId,
redirect_uri: redirectUri,
response_type: "code",
scope,
state: "test-state",
code_challenge: randomString(43),
code_challenge_method: "S256",
}),
});
expect(response.status).toBe(302);
const location = new URL(
response.headers.get("Location") ?? "",
config.http.base_url,
);
const params = new URLSearchParams(location.search);
expect(location.origin + location.pathname).toBe(redirectUri);
expect(params.get("code")).toBeTruthy();
expect(params.get("state")).toBe("test-state");
});
test("should return error for invalid JWT", async () => {
const response = await fakeRequest("/oauth/authorize", {
method: "POST",
headers: {
Authorization: `Bearer ${tokens[0]?.accessToken}`,
"Content-Type": "application/json",
Cookie: "jwt=invalid-jwt",
},
body: JSON.stringify({
client_id: clientId,
redirect_uri: redirectUri,
response_type: "code",
scope,
state: "test-state",
code_challenge: randomString(43),
code_challenge_method: "S256",
}),
});
expect(response.status).toBe(302);
const location = new URL(
response.headers.get("Location") ?? "",
config.http.base_url,
);
const params = new URLSearchParams(location.search);
expect(params.get("error")).toBe("invalid_request");
expect(params.get("error_description")).toBe(
"Invalid JWT, could not verify",
);
});
test("should return error for missing required fields in JWT", async () => {
const jwt = await new SignJWT({
sub: users[0].id,
iss: new URL(config.http.base_url).origin,
aud: clientId,
})
.setProtectedHeader({ alg: "EdDSA" })
.sign(privateKey);
const response = await fakeRequest("/oauth/authorize", {
method: "POST",
headers: {
Authorization: `Bearer ${tokens[0]?.accessToken}`,
"Content-Type": "application/json",
Cookie: `jwt=${jwt}`,
},
body: JSON.stringify({
client_id: clientId,
redirect_uri: redirectUri,
response_type: "code",
scope,
state: "test-state",
code_challenge: randomString(43),
code_challenge_method: "S256",
}),
});
expect(response.status).toBe(302);
const location = new URL(
response.headers.get("Location") ?? "",
config.http.base_url,
);
const params = new URLSearchParams(location.search);
expect(params.get("error")).toBe("invalid_request");
expect(params.get("error_description")).toBe(
"Invalid JWT, missing required fields (aud, sub, exp)",
);
});
test("should return error for user not found", async () => {
const jwt = await new SignJWT({
sub: "non-existent-user",
aud: clientId,
exp: Math.floor(Date.now() / 1000) + 60 * 60,
iss: new URL(config.http.base_url).origin,
iat: Math.floor(Date.now() / 1000),
nbf: Math.floor(Date.now() / 1000),
})
.setProtectedHeader({ alg: "EdDSA" })
.sign(privateKey);
const response = await fakeRequest("/oauth/authorize", {
method: "POST",
headers: {
Authorization: `Bearer ${tokens[0]?.accessToken}`,
"Content-Type": "application/json",
Cookie: `jwt=${jwt}`,
},
body: JSON.stringify({
client_id: clientId,
redirect_uri: redirectUri,
response_type: "code",
scope,
state: "test-state",
code_challenge: randomString(43),
code_challenge_method: "S256",
}),
});
expect(response.status).toBe(302);
const location = new URL(
response.headers.get("Location") ?? "",
config.http.base_url,
);
const params = new URLSearchParams(location.search);
expect(params.get("error")).toBe("invalid_request");
expect(params.get("error_description")).toBe(
"Invalid JWT, sub is not a valid user ID",
);
const jwt2 = await new SignJWT({
sub: "23e42862-d5df-49a8-95b5-52d8c6a11aea",
aud: clientId,
exp: Math.floor(Date.now() / 1000) + 60 * 60,
iss: new URL(config.http.base_url).origin,
iat: Math.floor(Date.now() / 1000),
nbf: Math.floor(Date.now() / 1000),
})
.setProtectedHeader({ alg: "EdDSA" })
.sign(privateKey);
const response2 = await fakeRequest("/oauth/authorize", {
method: "POST",
headers: {
Authorization: `Bearer ${tokens[0]?.accessToken}`,
"Content-Type": "application/json",
Cookie: `jwt=${jwt2}`,
},
body: JSON.stringify({
client_id: clientId,
redirect_uri: redirectUri,
response_type: "code",
scope,
state: "test-state",
code_challenge: randomString(43),
code_challenge_method: "S256",
}),
});
expect(response2.status).toBe(302);
const location2 = new URL(
response2.headers.get("Location") ?? "",
config.http.base_url,
);
const params2 = new URLSearchParams(location2.search);
expect(params2.get("error")).toBe("invalid_request");
expect(params2.get("error_description")).toBe(
"Invalid JWT, could not find associated user",
);
});
test("should return error for user missing required permissions", async () => {
const oldPermissions = config.permissions.default;
config.permissions.default = [];
const jwt = await new SignJWT({
sub: users[0].id,
iss: new URL(config.http.base_url).origin,
aud: clientId,
exp: Math.floor(Date.now() / 1000) + 60 * 60,
iat: Math.floor(Date.now() / 1000),
nbf: Math.floor(Date.now() / 1000),
})
.setProtectedHeader({ alg: "EdDSA" })
.sign(privateKey);
const response = await fakeRequest("/oauth/authorize", {
method: "POST",
headers: {
Authorization: `Bearer ${tokens[0]?.accessToken}`,
"Content-Type": "application/json",
Cookie: `jwt=${jwt}`,
},
body: JSON.stringify({
client_id: clientId,
redirect_uri: redirectUri,
response_type: "code",
scope,
state: "test-state",
code_challenge: randomString(43),
code_challenge_method: "S256",
}),
});
expect(response.status).toBe(302);
const location = new URL(
response.headers.get("Location") ?? "",
config.http.base_url,
);
const params = new URLSearchParams(location.search);
expect(params.get("error")).toBe("invalid_request");
expect(params.get("error_description")).toBe(
`User is missing the required permission ${RolePermissions.OAuth}`,
);
config.permissions.default = oldPermissions;
});
test("should return error for invalid client_id", async () => {
const jwt = await new SignJWT({
sub: users[0].id,
aud: "invalid-client-id",
iss: new URL(config.http.base_url).origin,
exp: Math.floor(Date.now() / 1000) + 60 * 60,
iat: Math.floor(Date.now() / 1000),
nbf: Math.floor(Date.now() / 1000),
})
.setProtectedHeader({ alg: "EdDSA" })
.sign(privateKey);
const response = await fakeRequest("/oauth/authorize", {
method: "POST",
headers: {
Authorization: `Bearer ${tokens[0]?.accessToken}`,
"Content-Type": "application/json",
Cookie: `jwt=${jwt}`,
},
body: JSON.stringify({
client_id: "invalid-client-id",
redirect_uri: redirectUri,
response_type: "code",
scope,
state: "test-state",
code_challenge: randomString(43),
code_challenge_method: "S256",
}),
});
expect(response.status).toBe(302);
const location = new URL(
response.headers.get("Location") ?? "",
config.http.base_url,
);
const params = new URLSearchParams(location.search);
expect(params.get("error")).toBe("invalid_request");
expect(params.get("error_description")).toBe(
"Invalid client_id: no associated application found",
);
});
test("should return error for invalid redirect_uri", async () => {
const jwt = await new SignJWT({
sub: users[0].id,
iss: new URL(config.http.base_url).origin,
aud: clientId,
exp: Math.floor(Date.now() / 1000) + 60 * 60,
iat: Math.floor(Date.now() / 1000),
nbf: Math.floor(Date.now() / 1000),
})
.setProtectedHeader({ alg: "EdDSA" })
.sign(privateKey);
const response = await fakeRequest("/oauth/authorize", {
method: "POST",
headers: {
Authorization: `Bearer ${tokens[0]?.accessToken}`,
"Content-Type": "application/json",
Cookie: `jwt=${jwt}`,
},
body: JSON.stringify({
client_id: clientId,
redirect_uri: "https://invalid.com/callback",
response_type: "code",
scope,
state: "test-state",
code_challenge: randomString(43),
code_challenge_method: "S256",
}),
});
expect(response.status).toBe(302);
const location = new URL(
response.headers.get("Location") ?? "",
config.http.base_url,
);
const params = new URLSearchParams(location.search);
expect(params.get("error")).toBe("invalid_request");
expect(params.get("error_description")).toBe(
"Invalid redirect_uri: does not match application's redirect_uri",
);
});
test("should return error for invalid scope", async () => {
const jwt = await new SignJWT({
sub: users[0].id,
iss: new URL(config.http.base_url).origin,
aud: clientId,
exp: Math.floor(Date.now() / 1000) + 60 * 60,
iat: Math.floor(Date.now() / 1000),
nbf: Math.floor(Date.now() / 1000),
})
.setProtectedHeader({ alg: "EdDSA" })
.sign(privateKey);
const response = await fakeRequest("/oauth/authorize", {
method: "POST",
headers: {
Authorization: `Bearer ${tokens[0]?.accessToken}`,
"Content-Type": "application/json",
Cookie: `jwt=${jwt}`,
},
body: JSON.stringify({
client_id: clientId,
redirect_uri: redirectUri,
response_type: "code",
scope: "invalid-scope",
state: "test-state",
code_challenge: randomString(43),
code_challenge_method: "S256",
}),
});
expect(response.status).toBe(302);
const location = new URL(
response.headers.get("Location") ?? "",
config.http.base_url,
);
const params = new URLSearchParams(location.search);
expect(params.get("error")).toBe("invalid_scope");
expect(params.get("error_description")).toBe(
"Invalid scope: not a subset of the application's scopes",
);
});
});

View file

@ -159,6 +159,18 @@ export default (plugin: PluginType) =>
); );
} }
if (!z.string().uuid().safeParse(sub).success) {
errorSearchParams.append("error", "invalid_request");
errorSearchParams.append(
"error_description",
"Invalid JWT, sub is not a valid user ID",
);
return context.redirect(
`${context.get("config").frontend.routes.login}?${errorSearchParams.toString()}`,
);
}
const user = await User.fromId(sub); const user = await User.fromId(sub);
if (!user) { if (!user) {

View file

@ -1,6 +1,5 @@
import { afterAll, describe, expect, test } from "bun:test"; import { afterAll, describe, expect, test } from "bun:test";
import { fakeRequest, getTestUsers } from "~/tests/utils"; import { fakeRequest, getTestUsers } from "~/tests/utils";
import { meta } from "./index";
const { deleteUsers, tokens } = await getTestUsers(1); const { deleteUsers, tokens } = await getTestUsers(1);
@ -8,10 +7,9 @@ afterAll(async () => {
await deleteUsers(); await deleteUsers();
}); });
// /api/v1/sso describe("/api/v1/sso", () => {
describe(meta.route, () => {
test("should return empty list", async () => { test("should return empty list", async () => {
const response = await fakeRequest(meta.route, { const response = await fakeRequest("/api/v1/sso", {
method: "GET", method: "GET",
headers: { headers: {
Authorization: `Bearer ${tokens[0]?.accessToken}`, Authorization: `Bearer ${tokens[0]?.accessToken}`,
@ -23,7 +21,7 @@ describe(meta.route, () => {
}); });
test("should return an error if provider doesn't exist", async () => { test("should return an error if provider doesn't exist", async () => {
const response = await fakeRequest(meta.route, { const response = await fakeRequest("/api/v1/sso", {
method: "POST", method: "POST",
headers: { headers: {
Authorization: `Bearer ${tokens[0]?.accessToken}`, Authorization: `Bearer ${tokens[0]?.accessToken}`,

View file

@ -0,0 +1,223 @@
import { auth } from "@/api";
import {
calculatePKCECodeChallenge,
generateRandomCodeVerifier,
} from "oauth4webapi";
import { z } from "zod";
import { db } from "~/drizzle/db";
import {
Applications,
OpenIdLoginFlows,
RolePermissions,
} from "~/drizzle/schema";
import { ErrorSchema } from "~/types/api";
import type { PluginType } from "../..";
import { oauthDiscoveryRequest, oauthRedirectUri } from "../../utils";
export default (plugin: PluginType) => {
plugin.registerRoute("/api/v1/sso", (app) => {
app.openapi(
{
method: "get",
path: "/api/v1/sso",
summary: "Get linked accounts",
middleware: [
auth(
{
required: true,
},
{
required: [RolePermissions.OAuth],
},
),
plugin.middleware,
],
responses: {
200: {
description: "Linked accounts",
content: {
"application/json": {
schema: z.array(
z.object({
id: z.string(),
name: z.string(),
icon: z.string().optional(),
}),
),
},
},
},
401: {
description: "Unauthorized",
content: {
"application/json": {
schema: ErrorSchema,
},
},
},
},
},
async (context) => {
const { user } = context.get("auth");
if (!user) {
return context.json(
{
error: "Unauthorized",
},
401,
);
}
const linkedAccounts = await user.getLinkedOidcAccounts();
return context.json(
linkedAccounts.map((account) => ({
id: account.id,
name: account.name,
icon: account.icon,
})),
200,
);
},
);
app.openapi(
{
method: "post",
path: "/api/v1/sso",
summary: "Link account",
middleware: [
auth(
{
required: true,
},
{
required: [RolePermissions.OAuth],
},
),
],
request: {
body: {
content: {
"application/json": {
schema: z.object({
issuer: z.string(),
}),
},
},
},
},
responses: {
302: {
description: "Redirect to OpenID provider",
},
401: {
description: "Unauthorized",
content: {
"application/json": {
schema: ErrorSchema,
},
},
},
404: {
description: "Issuer not found",
content: {
"application/json": {
schema: ErrorSchema,
},
},
},
},
},
async (context) => {
const { user } = context.get("auth");
if (!user) {
return context.json(
{
error: "Unauthorized",
},
401,
);
}
const { issuer: issuerId } = context.req.valid("json");
const issuer = context
.get("pluginConfig")
.providers.find((provider) => provider.id === issuerId);
if (!issuer) {
return context.json(
{
error: `Issuer with ID ${issuerId} not found in instance's OpenID configuration`,
},
404,
);
}
const authServer = await oauthDiscoveryRequest(issuer.url);
const codeVerifier = generateRandomCodeVerifier();
const redirectUri = oauthRedirectUri(
issuerId,
context.get("config").http.base_url,
);
const application = (
await db
.insert(Applications)
.values({
clientId:
user.id +
Buffer.from(
crypto.getRandomValues(new Uint8Array(32)),
).toString("base64"),
name: "Versia",
redirectUri,
scopes: "openid profile email",
secret: "",
})
.returning()
)[0];
// Store into database
const newFlow = (
await db
.insert(OpenIdLoginFlows)
.values({
codeVerifier,
issuerId,
applicationId: application.id,
})
.returning()
)[0];
const codeChallenge =
await calculatePKCECodeChallenge(codeVerifier);
return context.redirect(
`${authServer.authorization_endpoint}?${new URLSearchParams(
{
client_id: issuer.client_id,
redirect_uri: `${redirectUri}?${new URLSearchParams(
{
flow: newFlow.id,
link: "true",
user_id: user.id,
},
)}`,
response_type: "code",
scope: "openid profile email",
// PKCE
code_challenge_method: "S256",
code_challenge: codeChallenge,
},
).toString()}`,
);
},
);
});
};

18
plugins/openid/utils.ts Normal file
View file

@ -0,0 +1,18 @@
import {
type AuthorizationServer,
discoveryRequest,
processDiscoveryResponse,
} from "oauth4webapi";
export const oauthDiscoveryRequest = (
issuerUrl: string | URL,
): Promise<AuthorizationServer> => {
const issuerUrlurl = new URL(issuerUrl);
return discoveryRequest(issuerUrlurl, {
algorithm: "oidc",
}).then((res) => processDiscoveryResponse(issuerUrlurl, res));
};
export const oauthRedirectUri = (baseUrl: string, issuer: string) =>
new URL(`/oauth/sso/${issuer}/callback`, baseUrl).toString();