From 9ff9b90f6b4e54383550cf95f01a8290328c15b2 Mon Sep 17 00:00:00 2001 From: Jesse Wierzbinski Date: Tue, 8 Apr 2025 16:59:18 +0200 Subject: [PATCH] refactor(federation): :recycle: Refactor User federation code --- api/api/v1/accounts/index.ts | 3 +- .../v1/accounts/update_credentials/index.ts | 28 +- .../accounts/verify_credentials/index.test.ts | 2 +- classes/database/user.ts | 281 ++++++++---------- classes/search/search-manager.ts | 2 +- cli/user/create.ts | 5 +- drizzle/schema.ts | 4 +- plugins/openid/routes/oauth/callback.ts | 4 +- tests/utils.ts | 3 +- 9 files changed, 152 insertions(+), 180 deletions(-) diff --git a/api/api/v1/accounts/index.ts b/api/api/v1/accounts/index.ts index 31a38585..3eac0054 100644 --- a/api/api/v1/accounts/index.ts +++ b/api/api/v1/accounts/index.ts @@ -353,8 +353,7 @@ export default apiRoute((app) => ); } - await User.fromDataLocal({ - username, + await User.register(username, { password, email, }); diff --git a/api/api/v1/accounts/update_credentials/index.ts b/api/api/v1/accounts/update_credentials/index.ts index 12ac5abd..34ab21c0 100644 --- a/api/api/v1/accounts/update_credentials/index.ts +++ b/api/api/v1/accounts/update_credentials/index.ts @@ -176,6 +176,15 @@ export default apiRoute((app) => } = context.req.valid("json"); const self = user.data; + if (!self.source) { + self.source = { + fields: [], + privacy: "public", + language: "en", + sensitive: false, + note: "", + }; + } const sanitizedDisplayName = await sanitizedHtmlStrip( display_name ?? "", @@ -185,7 +194,7 @@ export default apiRoute((app) => self.displayName = sanitizedDisplayName; } - if (note && self.source) { + if (note) { self.source.note = note; self.note = await contentToHtml( new VersiaEntities.TextContentFormat({ @@ -197,16 +206,13 @@ export default apiRoute((app) => ); } - if (source?.privacy) { - self.source.privacy = source.privacy; - } - - if (source?.sensitive) { - self.source.sensitive = source.sensitive; - } - - if (source?.language) { - self.source.language = source.language; + if (source) { + self.source = { + ...self.source, + privacy: source.privacy ?? self.source.privacy, + sensitive: source.sensitive ?? self.source.sensitive, + language: source.language ?? self.source.language, + }; } if (username) { diff --git a/api/api/v1/accounts/verify_credentials/index.test.ts b/api/api/v1/accounts/verify_credentials/index.test.ts index cadc7498..79b015f4 100644 --- a/api/api/v1/accounts/verify_credentials/index.test.ts +++ b/api/api/v1/accounts/verify_credentials/index.test.ts @@ -28,7 +28,7 @@ describe("/api/v1/accounts/verify_credentials", () => { expect(data.id).toBe(users[0].id); expect(data.username).toBe(users[0].data.username); expect(data.acct).toBe(users[0].data.username); - expect(data.display_name).toBe(users[0].data.displayName); + expect(data.display_name).toBe(users[0].data.displayName ?? ""); expect(data.note).toBe(users[0].data.note); expect(data.url).toBe( new URL( diff --git a/classes/database/user.ts b/classes/database/user.ts index 522d80c4..aacb65d0 100644 --- a/classes/database/user.ts +++ b/classes/database/user.ts @@ -679,120 +679,106 @@ export class User extends BaseInterface { ); } - public static async fromVersia(user: VersiaEntities.User): Promise { - const instance = await Instance.resolve(user.data.uri); - - const data = { - username: user.data.username, - uri: user.data.uri.href, - createdAt: new Date(user.data.created_at).toISOString(), - endpoints: { - dislikes: - user.data.collections["pub.versia:likes/Dislikes"]?.href ?? - undefined, - featured: user.data.collections.featured.href, - likes: - user.data.collections["pub.versia:likes/Likes"]?.href ?? - undefined, - followers: user.data.collections.followers.href, - following: user.data.collections.following.href, - inbox: user.data.inbox.href, - outbox: user.data.collections.outbox.href, - }, - fields: user.data.fields ?? [], - updatedAt: new Date(user.data.created_at).toISOString(), - instanceId: instance.id, - displayName: user.data.display_name ?? "", - note: getBestContentType(user.data.bio).content, - publicKey: user.data.public_key.key, - source: { - language: "en", - note: "", - privacy: "public", - sensitive: false, - fields: [], - } as z.infer, - }; - - const userEmojis = - user.data.extensions?.["pub.versia:custom_emojis"]?.emojis ?? []; - - const emojis = await Promise.all( - userEmojis.map((emoji) => Emoji.fromVersia(emoji, instance)), + /** + * Takes a Versia User representation, and serializes it to the database. + * + * If the user already exists, it will update it. + * @param user + */ + public static async fromVersia( + versiaUser: VersiaEntities.User, + ): Promise { + const { + username, + inbox, + avatar, + header, + display_name, + fields, + collections, + created_at, + bio, + public_key, + uri, + extensions, + } = versiaUser.data; + const instance = await Instance.resolve(versiaUser.data.uri); + const existingUser = await User.fromSql( + eq(Users.uri, versiaUser.data.uri.href), ); - // Check if new user already exists - const foundUser = await User.fromSql(eq(Users.uri, user.data.uri.href)); + const user = + existingUser ?? + (await User.insert({ + username, + id: randomUUIDv7(), + publicKey: public_key.key, + uri: uri.href, + instanceId: instance.id, + })); - // If it exists, simply update it - if (foundUser) { - let avatar: Media | null = null; - let header: Media | null = null; + // Avatars and headers are stored in a separate table, so we need to update them separately + let userAvatar: Media | null = null; + let userHeader: Media | null = null; - if (user.data.avatar) { - if (foundUser.avatar) { - avatar = new Media( - await foundUser.avatar.update({ - content: user.data.avatar, - }), - ); - } else { - avatar = await Media.insert({ - id: randomUUIDv7(), - content: user.data.avatar, - }); - } + if (avatar) { + if (user.avatar) { + userAvatar = new Media( + await user.avatar.update({ + content: avatar, + }), + ); + } else { + userAvatar = await Media.insert({ + id: randomUUIDv7(), + content: avatar, + }); } - - if (user.data.header) { - if (foundUser.header) { - header = new Media( - await foundUser.header.update({ - content: user.data.header, - }), - ); - } else { - header = await Media.insert({ - id: randomUUIDv7(), - content: user.data.header, - }); - } - } - - await foundUser.update({ - ...data, - avatarId: avatar?.id, - headerId: header?.id, - }); - await foundUser.updateEmojis(emojis); - - return foundUser; } - // Else, create a new user - const avatar = user.data.avatar - ? await Media.insert({ - id: randomUUIDv7(), - content: user.data.avatar, - }) - : null; + if (header) { + if (user.header) { + userHeader = new Media( + await user.header.update({ + content: header, + }), + ); + } else { + userHeader = await Media.insert({ + id: randomUUIDv7(), + content: header, + }); + } + } - const header = user.data.header - ? await Media.insert({ - id: randomUUIDv7(), - content: user.data.header, - }) - : null; - - const newUser = await User.insert({ - id: randomUUIDv7(), - ...data, - avatarId: avatar?.id, - headerId: header?.id, + await user.update({ + createdAt: new Date(created_at).toISOString(), + endpoints: { + inbox: inbox.href, + outbox: collections.outbox.href, + followers: collections.followers.href, + following: collections.following.href, + featured: collections.featured.href, + likes: collections["pub.versia:likes/Likes"]?.href, + dislikes: collections["pub.versia:likes/Dislikes"]?.href, + }, + avatarId: userAvatar?.id, + headerId: userHeader?.id, + fields: fields ?? [], + displayName: display_name, + note: getBestContentType(bio).content, }); - await newUser.updateEmojis(emojis); - return newUser; + // Emojis are stored in a separate table, so we need to update them separately + const emojis = await Promise.all( + extensions?.["pub.versia:custom_emojis"]?.emojis.map((e) => + Emoji.fromVersia(e, instance), + ) ?? [], + ); + + await user.updateEmojis(emojis); + + return user; } public static async insert( @@ -879,60 +865,45 @@ export class User extends BaseInterface { }; } - public static async fromDataLocal(data: { - username: string; - display_name?: string; - password: string | undefined; - email: string | undefined; - bio?: string; - avatar?: Media; - header?: Media; - admin?: boolean; - skipPasswordHash?: boolean; - }): Promise { + public static async register( + username: string, + options?: Partial<{ + email: string; + password: string; + avatar: Media; + isAdmin: boolean; + }>, + ): Promise { const keys = await User.generateKeys(); - const newUser = ( - await db - .insert(Users) - .values({ - id: randomUUIDv7(), - username: data.username, - displayName: data.display_name ?? data.username, - password: - data.skipPasswordHash || !data.password - ? data.password - : await bunPassword.hash(data.password), - email: data.email, - note: data.bio ?? "", - avatarId: data.avatar?.id, - headerId: data.header?.id, - isAdmin: data.admin ?? false, - publicKey: keys.public_key, - fields: [], - privateKey: keys.private_key, - updatedAt: new Date().toISOString(), - source: { - language: "en", - note: "", - privacy: "public", - sensitive: false, - fields: [], - } as z.infer, - }) - .returning() - )[0]; - - const finalUser = await User.fromId(newUser.id); - - if (!finalUser) { - throw new Error("Failed to create user"); - } + const user = await User.insert({ + id: randomUUIDv7(), + username: username, + displayName: username, + password: options?.password + ? await bunPassword.hash(options.password) + : null, + email: options?.email, + note: "", + avatarId: options?.avatar?.id, + isAdmin: options?.isAdmin, + publicKey: keys.public_key, + fields: [], + privateKey: keys.private_key, + updatedAt: new Date().toISOString(), + source: { + language: "en", + note: "", + privacy: "public", + sensitive: false, + fields: [], + } as z.infer, + }); // Add to search index - await searchManager.addUser(finalUser); + await searchManager.addUser(user); - return finalUser; + return user; } /** @@ -1093,7 +1064,7 @@ export class User extends BaseInterface { return { id: user.id, username: user.username, - display_name: user.displayName, + display_name: user.displayName || user.username, note: user.note, uri: this.getUri().toString(), url: @@ -1119,7 +1090,7 @@ export class User extends BaseInterface { verified_at: null, })), bot: user.isBot, - source: isOwnAccount ? user.source : undefined, + source: isOwnAccount ? (user.source ?? undefined) : undefined, // TODO: Add static avatar and header avatar_static: this.getAvatarUrl().proxied, header_static: this.getHeaderUrl()?.proxied ?? "", diff --git a/classes/search/search-manager.ts b/classes/search/search-manager.ts index 25689381..db97f006 100644 --- a/classes/search/search-manager.ts +++ b/classes/search/search-manager.ts @@ -155,7 +155,7 @@ export class SonicSearchManager { private static getNthDatabaseAccountBatch( n: number, batchSize = 1000, - ): Promise[]> { + ): Promise[]> { return db.query.Users.findMany({ offset: n * batchSize, limit: batchSize, diff --git a/cli/user/create.ts b/cli/user/create.ts index 4d840b25..6a5fd696 100644 --- a/cli/user/create.ts +++ b/cli/user/create.ts @@ -48,11 +48,10 @@ export const createUserCommand = defineCommand( throw new Error(`User ${chalk.gray(username)} is taken.`); } - const user = await User.fromDataLocal({ + const user = await User.register(username, { email, password, - username, - admin, + isAdmin: admin, }); if (!user) { diff --git a/drizzle/schema.ts b/drizzle/schema.ts index 9a18f7b9..4574e567 100644 --- a/drizzle/schema.ts +++ b/drizzle/schema.ts @@ -556,7 +556,7 @@ export const Users = pgTable( id: id(), uri: uri(), username: text("username").notNull(), - displayName: text("display_name").notNull(), + displayName: text("display_name"), password: text("password"), email: text("email"), note: text("note").default("").notNull(), @@ -578,7 +578,7 @@ export const Users = pgTable( inbox: string; outbox: string; }> | null>(), - source: jsonb("source").notNull().$type>(), + source: jsonb("source").$type>(), avatarId: uuid("avatarId").references(() => Medias.id, { onDelete: "set null", onUpdate: "cascade", diff --git a/plugins/openid/routes/oauth/callback.ts b/plugins/openid/routes/oauth/callback.ts index 2f7c6bff..23207ca0 100644 --- a/plugins/openid/routes/oauth/callback.ts +++ b/plugins/openid/routes/oauth/callback.ts @@ -235,11 +235,9 @@ export default (plugin: PluginType): void => { : null; // Create new user - const user = await User.fromDataLocal({ + const user = await User.register(username, { email: doesEmailExist ? undefined : email, - username, avatar: avatar ?? undefined, - password: undefined, }); // Link account diff --git a/tests/utils.ts b/tests/utils.ts index 508d92d1..3058d501 100644 --- a/tests/utils.ts +++ b/tests/utils.ts @@ -103,8 +103,7 @@ export const getTestUsers = async ( for (let i = 0; i < count; i++) { const password = randomString(32, "hex"); - const user = await User.fromDataLocal({ - username: `test-${randomString(8, "hex")}`, + const user = await User.register(`test-${randomString(8, "hex")}`, { email: `${randomString(16, "hex")}@test.com`, password, });