From a8541bdc4425e4fffbc08caa53b8fb8a39ae98c4 Mon Sep 17 00:00:00 2001 From: Jesse Wierzbinski Date: Mon, 9 Dec 2024 13:36:15 +0100 Subject: [PATCH] refactor(database): :recycle: Simplify Note and User federation logic --- .../v1/accounts/update_credentials/index.ts | 42 ++---- api/api/v1/statuses/:id/index.ts | 11 +- api/api/v1/statuses/index.ts | 11 +- classes/database/note.ts | 125 +++++++++--------- classes/database/user.ts | 80 +++++------ classes/inbox/processor.test.ts | 19 ++- classes/inbox/processor.ts | 2 +- 7 files changed, 137 insertions(+), 153 deletions(-) diff --git a/api/api/v1/accounts/update_credentials/index.ts b/api/api/v1/accounts/update_credentials/index.ts index 57c9165a..2c3fa4ba 100644 --- a/api/api/v1/accounts/update_credentials/index.ts +++ b/api/api/v1/accounts/update_credentials/index.ts @@ -1,8 +1,8 @@ import { apiRoute, applyConfig, auth, jsonOrForm } from "@/api"; import { sanitizedHtmlStrip } from "@/sanitization"; import { createRoute } from "@hono/zod-openapi"; -import { Attachment, Emoji, User, db } from "@versia/kit/db"; -import { EmojiToUser, RolePermissions, Users } from "@versia/kit/tables"; +import { Attachment, Emoji, User } from "@versia/kit/db"; +import { RolePermissions, Users } from "@versia/kit/tables"; import { and, eq, isNull } from "drizzle-orm"; import ISO6391 from "iso-639-1"; import { z } from "zod"; @@ -335,36 +335,19 @@ export default apiRoute((app) => await Emoji.parseFromText(sanitizedDisplayName); const noteEmojis = await Emoji.parseFromText(self.note); - self.emojis = [...displaynameEmojis, ...noteEmojis, ...fieldEmojis] - .map((e) => e.data) - .filter( - // Deduplicate emojis - (emoji, index, self) => - self.findIndex((e) => e.id === emoji.id) === index, - ); + const emojis = [ + ...displaynameEmojis, + ...noteEmojis, + ...fieldEmojis, + ].filter( + // Deduplicate emojis + (emoji, index, self) => + self.findIndex((e) => e.id === emoji.id) === index, + ); // Connect emojis, if any // Do it before updating user, so that federation takes that into account - for (const emoji of self.emojis) { - await db - .delete(EmojiToUser) - .where( - and( - eq(EmojiToUser.emojiId, emoji.id), - eq(EmojiToUser.userId, self.id), - ), - ) - .execute(); - - await db - .insert(EmojiToUser) - .values({ - emojiId: emoji.id, - userId: self.id, - }) - .execute(); - } - + await user.updateEmojis(emojis); await user.update({ displayName: self.displayName, username: self.username, @@ -379,6 +362,7 @@ export default apiRoute((app) => }); const output = await User.fromId(self.id); + if (!output) { return context.json({ error: "Couldn't edit user" }, 500); } diff --git a/api/api/v1/statuses/:id/index.ts b/api/api/v1/statuses/:id/index.ts index d5b1dd2a..622ee4ea 100644 --- a/api/api/v1/statuses/:id/index.ts +++ b/api/api/v1/statuses/:id/index.ts @@ -271,12 +271,11 @@ export default apiRoute((app) => { sensitive, } = context.req.valid("json"); - if (media_ids.length > 0) { - const foundAttachments = await Attachment.fromIds(media_ids); + const foundAttachments = + media_ids.length > 0 ? await Attachment.fromIds(media_ids) : []; - if (foundAttachments.length !== media_ids.length) { - return context.json({ error: "Invalid media IDs" }, 422); - } + if (foundAttachments.length !== media_ids.length) { + return context.json({ error: "Invalid media IDs" }, 422); } const newNote = await note.updateFromData({ @@ -291,7 +290,7 @@ export default apiRoute((app) => { : undefined, isSensitive: sensitive, spoilerText: spoiler_text, - mediaAttachments: media_ids, + mediaAttachments: foundAttachments, }); return context.json(await newNote.toApi(user), 200); diff --git a/api/api/v1/statuses/index.ts b/api/api/v1/statuses/index.ts index 776b228b..55ea51da 100644 --- a/api/api/v1/statuses/index.ts +++ b/api/api/v1/statuses/index.ts @@ -168,12 +168,11 @@ export default apiRoute((app) => } = context.req.valid("json"); // Check if media attachments are all valid - if (media_ids.length > 0) { - const foundAttachments = await Attachment.fromIds(media_ids); + const foundAttachments = + media_ids.length > 0 ? await Attachment.fromIds(media_ids) : []; - if (foundAttachments.length !== media_ids.length) { - return context.json({ error: "Invalid media IDs" }, 422); - } + if (foundAttachments.length !== media_ids.length) { + return context.json({ error: "Invalid media IDs" }, 422); } // Check that in_reply_to_id and quote_id are real posts if provided @@ -199,7 +198,7 @@ export default apiRoute((app) => visibility, isSensitive: sensitive ?? false, spoilerText: spoiler_text ?? "", - mediaAttachments: media_ids, + mediaAttachments: foundAttachments, replyId: in_reply_to_id ?? undefined, quoteId: quote_id ?? undefined, application: application ?? undefined, diff --git a/classes/database/note.ts b/classes/database/note.ts index 538ff196..9778471e 100644 --- a/classes/database/note.ts +++ b/classes/database/note.ts @@ -417,7 +417,7 @@ export class Note extends BaseInterface { throw new Error("Cannot refetch a local note (it is not remote)"); } - const updated = await Note.saveFromRemote(this.getUri()); + const updated = await Note.fetchFromRemote(this.getUri()); if (!updated) { throw new Error("Note not found after update"); @@ -443,7 +443,7 @@ export class Note extends BaseInterface { uri?: string; mentions?: User[]; /** List of IDs of database Attachment objects */ - mediaAttachments?: string[]; + mediaAttachments?: Attachment[]; replyId?: string; quoteId?: string; application?: Application; @@ -491,15 +491,13 @@ export class Note extends BaseInterface { }); // Connect emojis - await newNote.recalculateDatabaseEmojis(parsedEmojis); + await newNote.updateEmojis(parsedEmojis); // Connect mentions - await newNote.recalculateDatabaseMentions(parsedMentions); + await newNote.updateMentions(parsedMentions); // Set attachment parents - await newNote.recalculateDatabaseAttachments( - data.mediaAttachments ?? [], - ); + await newNote.updateAttachments(data.mediaAttachments ?? []); // Send notifications for mentioned local users for (const mention of parsedMentions ?? []) { @@ -532,8 +530,7 @@ export class Note extends BaseInterface { emojis?: Emoji[]; uri?: string; mentions?: User[]; - /** List of IDs of database Attachment objects */ - mediaAttachments?: string[]; + mediaAttachments?: Attachment[]; replyId?: string; quoteId?: string; application?: Application; @@ -587,13 +584,13 @@ export class Note extends BaseInterface { }); // Connect emojis - await this.recalculateDatabaseEmojis(parsedEmojis); + await this.updateEmojis(parsedEmojis); // Connect mentions - await this.recalculateDatabaseMentions(parsedMentions); + await this.updateMentions(parsedMentions); // Set attachment parents - await this.recalculateDatabaseAttachments(data.mediaAttachments ?? []); + await this.updateAttachments(data.mediaAttachments ?? []); await this.reload(data.author.id); @@ -606,7 +603,11 @@ export class Note extends BaseInterface { * Deletes all existing emojis associated with this note, then replaces them with the provided emojis. * @param emojis - The emojis to associate with this note */ - public async recalculateDatabaseEmojis(emojis: Emoji[]): Promise { + public async updateEmojis(emojis: Emoji[]): Promise { + if (emojis.length === 0) { + return; + } + // Fuse and deduplicate const fusedEmojis = emojis.filter( (emoji, index, self) => @@ -617,16 +618,12 @@ export class Note extends BaseInterface { await db .delete(EmojiToNote) .where(eq(EmojiToNote.noteId, this.data.id)); - - for (const emoji of fusedEmojis) { - await db - .insert(EmojiToNote) - .values({ - emojiId: emoji.id, - noteId: this.data.id, - }) - .execute(); - } + await db.insert(EmojiToNote).values( + fusedEmojis.map((emoji) => ({ + emojiId: emoji.id, + noteId: this.data.id, + })), + ); } /** @@ -635,21 +632,21 @@ export class Note extends BaseInterface { * Deletes all existing mentions associated with this note, then replaces them with the provided mentions. * @param mentions - The mentions to associate with this note */ - public async recalculateDatabaseMentions(mentions: User[]): Promise { + public async updateMentions(mentions: User[]): Promise { + if (mentions.length === 0) { + return; + } + // Connect mentions await db .delete(NoteToMentions) .where(eq(NoteToMentions.noteId, this.data.id)); - - for (const mention of mentions) { - await db - .insert(NoteToMentions) - .values({ - noteId: this.data.id, - userId: mention.id, - }) - .execute(); - } + await db.insert(NoteToMentions).values( + mentions.map((mention) => ({ + noteId: this.data.id, + userId: mention.id, + })), + ); } /** @@ -658,25 +655,31 @@ export class Note extends BaseInterface { * Deletes all existing attachments associated with this note, then replaces them with the provided attachments. * @param mediaAttachments - The IDs of the attachments to associate with this note */ - public async recalculateDatabaseAttachments( - mediaAttachments: string[], + public async updateAttachments( + mediaAttachments: Attachment[], ): Promise { - // Set attachment parents + if (mediaAttachments.length === 0) { + return; + } + + // Remove old attachments await db .update(Attachments) .set({ noteId: null, }) .where(eq(Attachments.noteId, this.data.id)); - - if (mediaAttachments.length > 0) { - await db - .update(Attachments) - .set({ - noteId: this.data.id, - }) - .where(inArray(Attachments.id, mediaAttachments)); - } + await db + .update(Attachments) + .set({ + noteId: this.data.id, + }) + .where( + inArray( + Attachments.id, + mediaAttachments.map((i) => i.id), + ), + ); } /** @@ -705,7 +708,7 @@ export class Note extends BaseInterface { return await Note.fromId(uuid[0]); } - return await Note.saveFromRemote(uri); + return await Note.fetchFromRemote(uri); } /** @@ -713,32 +716,21 @@ export class Note extends BaseInterface { * @param uri - The URI of the note to save * @returns The saved note, or null if the note could not be fetched */ - public static async saveFromRemote(uri: string): Promise { - let note: VersiaNote | null = null; + public static async fetchFromRemote(uri: string): Promise { const instance = await Instance.resolve(uri); if (!instance) { return null; } - if (uri) { - if (!URL.canParse(uri)) { - throw new Error(`Invalid URI to parse ${uri}`); - } + const requester = await User.getFederationRequester(); - const requester = await User.getFederationRequester(); + const { data } = await requester.get(uri, { + // @ts-expect-error Bun extension + proxy: config.http.proxy.address, + }); - const { data } = await requester.get(uri, { - // @ts-expect-error Bun extension - proxy: config.http.proxy.address, - }); - - note = await new EntityValidator().Note(data); - } - - if (!note) { - throw new Error("No note was able to be fetched"); - } + const note = await new EntityValidator().Note(data); const author = await User.resolve(note.author); @@ -753,6 +745,7 @@ export class Note extends BaseInterface { * Turns a Versia Note into a database note (saved) * @param note Versia Note * @param author Author of the note + * @param instance Instance of the note * @returns The saved note */ public static async fromVersia( @@ -824,7 +817,7 @@ export class Note extends BaseInterface { .map((mention) => User.resolve(mention)) .filter((mention) => mention !== null) as Promise[], ), - mediaAttachments: attachments.map((a) => a.id), + mediaAttachments: attachments, replyId: note.replies_to ? (await Note.resolve(note.replies_to))?.data.id : undefined, diff --git a/classes/database/user.ts b/classes/database/user.ts index e14646db..0226f235 100644 --- a/classes/database/user.ts +++ b/classes/database/user.ts @@ -597,7 +597,7 @@ export class User extends BaseInterface { ); } - const updated = await User.saveFromRemote(this.getUri()); + const updated = await User.fetchFromRemote(this.getUri()); if (!updated) { throw new Error("Failed to update user from remote"); @@ -608,7 +608,7 @@ export class User extends BaseInterface { return this; } - public static async saveFromRemote(uri: string): Promise { + public static async fetchFromRemote(uri: string): Promise { if (!URL.canParse(uri)) { throw new Error(`Invalid URI: ${uri}`); } @@ -644,18 +644,12 @@ export class User extends BaseInterface { private static async saveFromVersia( uri: string, instance: Instance, - ): Promise { + ): Promise { const requester = await User.getFederationRequester(); - const output = await requester - .get>(uri, { - // @ts-expect-error Bun extension - proxy: config.http.proxy.address, - }) - .catch(() => null); - - if (!output) { - return null; - } + const output = await requester.get>(uri, { + // @ts-expect-error Bun extension + proxy: config.http.proxy.address, + }); const { data: json } = output; @@ -664,30 +658,28 @@ export class User extends BaseInterface { const user = await User.fromVersia(data, instance); - const userEmojis = - data.extensions?.["pub.versia:custom_emojis"]?.emojis ?? []; - const emojis = await Promise.all( - userEmojis.map((emoji) => Emoji.fromVersia(emoji, instance)), + await searchManager.addUser(user); + + return user; + } + + /** + * Change the emojis linked to this user in database + * @param emojis + * @returns + */ + public async updateEmojis(emojis: Emoji[]): Promise { + if (emojis.length === 0) { + return; + } + + await db.delete(EmojiToUser).where(eq(EmojiToUser.userId, this.id)); + await db.insert(EmojiToUser).values( + emojis.map((emoji) => ({ + emojiId: emoji.id, + userId: this.id, + })), ); - - if (emojis.length > 0) { - await db.delete(EmojiToUser).where(eq(EmojiToUser.userId, user.id)); - await db.insert(EmojiToUser).values( - emojis.map((emoji) => ({ - emojiId: emoji.id, - userId: user.id, - })), - ); - } - - const finalUser = await User.fromId(user.id); - if (!finalUser) { - throw new Error("Failed to save user from remote"); - } - - await searchManager.addUser(finalUser); - - return finalUser; } public static async fromVersia( @@ -729,19 +721,29 @@ export class User extends BaseInterface { }, }; - // Check if new user already exists + const userEmojis = + user.extensions?.["pub.versia:custom_emojis"]?.emojis ?? []; + const emojis = await Promise.all( + userEmojis.map((emoji) => Emoji.fromVersia(emoji, instance)), + ); + + // Check if new user already exists const foundUser = await User.fromSql(eq(Users.uri, user.uri)); // If it exists, simply update it if (foundUser) { await foundUser.update(data); + await foundUser.updateEmojis(emojis); return foundUser; } // Else, create a new user - return await User.insert(data); + const newUser = await User.insert(data); + await newUser.updateEmojis(emojis); + + return newUser; } public static async insert( @@ -784,7 +786,7 @@ export class User extends BaseInterface { getLogger(["federation", "resolvers"]) .debug`User not found in database, fetching from remote`; - return await User.saveFromRemote(uri); + return await User.fetchFromRemote(uri); } /** diff --git a/classes/inbox/processor.test.ts b/classes/inbox/processor.test.ts index a8d2593e..2b56be2c 100644 --- a/classes/inbox/processor.test.ts +++ b/classes/inbox/processor.test.ts @@ -2,7 +2,7 @@ import { beforeEach, describe, expect, jest, mock, test } from "bun:test"; import { SignatureValidator } from "@versia/federation"; import type { Entity, Note as VersiaNote } from "@versia/federation/types"; import { - type Instance, + Instance, Note, Notification, Relationship, @@ -23,11 +23,12 @@ mock.module("@versia/kit/db", () => ({ }, User: { resolve: jest.fn(), - saveFromRemote: jest.fn(), + fetchFromRemote: jest.fn(), sendFollowAccept: jest.fn(), }, Instance: { fromUser: jest.fn(), + resolve: jest.fn(), }, Note: { resolve: jest.fn(), @@ -198,9 +199,11 @@ describe("InboxProcessor", () => { test("successfully processes valid note", async () => { const mockNote = { author: "test-author" }; const mockAuthor = { id: "test-id" }; + const mockInstance = { id: "test-id" }; User.resolve = jest.fn().mockResolvedValue(mockAuthor); Note.fromVersia = jest.fn().mockResolvedValue(true); + Instance.resolve = jest.fn().mockResolvedValue(mockInstance); // biome-ignore lint/complexity/useLiteralKeys: Private variable processor["body"] = mockNote as VersiaNote; @@ -208,7 +211,11 @@ describe("InboxProcessor", () => { const result = await processor["processNote"](); expect(User.resolve).toHaveBeenCalledWith("test-author"); - expect(Note.fromVersia).toHaveBeenCalledWith(mockNote, mockAuthor); + expect(Note.fromVersia).toHaveBeenCalledWith( + mockNote, + mockAuthor, + mockInstance, + ); expect(result).toBeNull(); }); @@ -364,19 +371,19 @@ describe("InboxProcessor", () => { }; const mockUpdatedUser = { id: "user-id" }; - User.saveFromRemote = jest.fn().mockResolvedValue(mockUpdatedUser); + User.fetchFromRemote = jest.fn().mockResolvedValue(mockUpdatedUser); // biome-ignore lint/complexity/useLiteralKeys: Private variable processor["body"] = mockUser as unknown as Entity; // biome-ignore lint/complexity/useLiteralKeys: Private method const result = await processor["processUserRequest"](); - expect(User.saveFromRemote).toHaveBeenCalledWith("test-uri"); + expect(User.fetchFromRemote).toHaveBeenCalledWith("test-uri"); expect(result).toBeNull(); }); test("returns 500 when update fails", async () => { - User.saveFromRemote = jest.fn().mockResolvedValue(null); + User.fetchFromRemote = jest.fn().mockResolvedValue(null); // biome-ignore lint/complexity/useLiteralKeys: Private method const result = await processor["processUserRequest"](); diff --git a/classes/inbox/processor.ts b/classes/inbox/processor.ts index 6697c4aa..eb570eb4 100644 --- a/classes/inbox/processor.ts +++ b/classes/inbox/processor.ts @@ -554,7 +554,7 @@ export class InboxProcessor { private async processUserRequest(): Promise { const user = this.body as unknown as VersiaUser; // FIXME: Instead of refetching the remote user, we should read the incoming json and update from that - const updatedAccount = await User.saveFromRemote(user.uri); + const updatedAccount = await User.fetchFromRemote(user.uri); if (!updatedAccount) { return Response.json(