refactor(database): ♻️ Simplify Note and User federation logic

This commit is contained in:
Jesse Wierzbinski 2024-12-09 13:36:15 +01:00
parent cbbf49905b
commit a8541bdc44
No known key found for this signature in database
7 changed files with 137 additions and 153 deletions

View file

@ -1,8 +1,8 @@
import { apiRoute, applyConfig, auth, jsonOrForm } from "@/api"; import { apiRoute, applyConfig, auth, jsonOrForm } from "@/api";
import { sanitizedHtmlStrip } from "@/sanitization"; import { sanitizedHtmlStrip } from "@/sanitization";
import { createRoute } from "@hono/zod-openapi"; import { createRoute } from "@hono/zod-openapi";
import { Attachment, Emoji, User, db } from "@versia/kit/db"; import { Attachment, Emoji, User } from "@versia/kit/db";
import { EmojiToUser, RolePermissions, Users } from "@versia/kit/tables"; import { RolePermissions, Users } from "@versia/kit/tables";
import { and, eq, isNull } from "drizzle-orm"; import { and, eq, isNull } from "drizzle-orm";
import ISO6391 from "iso-639-1"; import ISO6391 from "iso-639-1";
import { z } from "zod"; import { z } from "zod";
@ -335,9 +335,11 @@ export default apiRoute((app) =>
await Emoji.parseFromText(sanitizedDisplayName); await Emoji.parseFromText(sanitizedDisplayName);
const noteEmojis = await Emoji.parseFromText(self.note); const noteEmojis = await Emoji.parseFromText(self.note);
self.emojis = [...displaynameEmojis, ...noteEmojis, ...fieldEmojis] const emojis = [
.map((e) => e.data) ...displaynameEmojis,
.filter( ...noteEmojis,
...fieldEmojis,
].filter(
// Deduplicate emojis // Deduplicate emojis
(emoji, index, self) => (emoji, index, self) =>
self.findIndex((e) => e.id === emoji.id) === index, self.findIndex((e) => e.id === emoji.id) === index,
@ -345,26 +347,7 @@ export default apiRoute((app) =>
// Connect emojis, if any // Connect emojis, if any
// Do it before updating user, so that federation takes that into account // Do it before updating user, so that federation takes that into account
for (const emoji of self.emojis) { await user.updateEmojis(emojis);
await db
.delete(EmojiToUser)
.where(
and(
eq(EmojiToUser.emojiId, emoji.id),
eq(EmojiToUser.userId, self.id),
),
)
.execute();
await db
.insert(EmojiToUser)
.values({
emojiId: emoji.id,
userId: self.id,
})
.execute();
}
await user.update({ await user.update({
displayName: self.displayName, displayName: self.displayName,
username: self.username, username: self.username,
@ -379,6 +362,7 @@ export default apiRoute((app) =>
}); });
const output = await User.fromId(self.id); const output = await User.fromId(self.id);
if (!output) { if (!output) {
return context.json({ error: "Couldn't edit user" }, 500); return context.json({ error: "Couldn't edit user" }, 500);
} }

View file

@ -271,13 +271,12 @@ export default apiRoute((app) => {
sensitive, sensitive,
} = context.req.valid("json"); } = context.req.valid("json");
if (media_ids.length > 0) { const foundAttachments =
const foundAttachments = await Attachment.fromIds(media_ids); media_ids.length > 0 ? await Attachment.fromIds(media_ids) : [];
if (foundAttachments.length !== media_ids.length) { if (foundAttachments.length !== media_ids.length) {
return context.json({ error: "Invalid media IDs" }, 422); return context.json({ error: "Invalid media IDs" }, 422);
} }
}
const newNote = await note.updateFromData({ const newNote = await note.updateFromData({
author: user, author: user,
@ -291,7 +290,7 @@ export default apiRoute((app) => {
: undefined, : undefined,
isSensitive: sensitive, isSensitive: sensitive,
spoilerText: spoiler_text, spoilerText: spoiler_text,
mediaAttachments: media_ids, mediaAttachments: foundAttachments,
}); });
return context.json(await newNote.toApi(user), 200); return context.json(await newNote.toApi(user), 200);

View file

@ -168,13 +168,12 @@ export default apiRoute((app) =>
} = context.req.valid("json"); } = context.req.valid("json");
// Check if media attachments are all valid // Check if media attachments are all valid
if (media_ids.length > 0) { const foundAttachments =
const foundAttachments = await Attachment.fromIds(media_ids); media_ids.length > 0 ? await Attachment.fromIds(media_ids) : [];
if (foundAttachments.length !== media_ids.length) { if (foundAttachments.length !== media_ids.length) {
return context.json({ error: "Invalid media IDs" }, 422); return context.json({ error: "Invalid media IDs" }, 422);
} }
}
// Check that in_reply_to_id and quote_id are real posts if provided // Check that in_reply_to_id and quote_id are real posts if provided
if (in_reply_to_id && !(await Note.fromId(in_reply_to_id))) { if (in_reply_to_id && !(await Note.fromId(in_reply_to_id))) {
@ -199,7 +198,7 @@ export default apiRoute((app) =>
visibility, visibility,
isSensitive: sensitive ?? false, isSensitive: sensitive ?? false,
spoilerText: spoiler_text ?? "", spoilerText: spoiler_text ?? "",
mediaAttachments: media_ids, mediaAttachments: foundAttachments,
replyId: in_reply_to_id ?? undefined, replyId: in_reply_to_id ?? undefined,
quoteId: quote_id ?? undefined, quoteId: quote_id ?? undefined,
application: application ?? undefined, application: application ?? undefined,

View file

@ -417,7 +417,7 @@ export class Note extends BaseInterface<typeof Notes, NoteTypeWithRelations> {
throw new Error("Cannot refetch a local note (it is not remote)"); throw new Error("Cannot refetch a local note (it is not remote)");
} }
const updated = await Note.saveFromRemote(this.getUri()); const updated = await Note.fetchFromRemote(this.getUri());
if (!updated) { if (!updated) {
throw new Error("Note not found after update"); throw new Error("Note not found after update");
@ -443,7 +443,7 @@ export class Note extends BaseInterface<typeof Notes, NoteTypeWithRelations> {
uri?: string; uri?: string;
mentions?: User[]; mentions?: User[];
/** List of IDs of database Attachment objects */ /** List of IDs of database Attachment objects */
mediaAttachments?: string[]; mediaAttachments?: Attachment[];
replyId?: string; replyId?: string;
quoteId?: string; quoteId?: string;
application?: Application; application?: Application;
@ -491,15 +491,13 @@ export class Note extends BaseInterface<typeof Notes, NoteTypeWithRelations> {
}); });
// Connect emojis // Connect emojis
await newNote.recalculateDatabaseEmojis(parsedEmojis); await newNote.updateEmojis(parsedEmojis);
// Connect mentions // Connect mentions
await newNote.recalculateDatabaseMentions(parsedMentions); await newNote.updateMentions(parsedMentions);
// Set attachment parents // Set attachment parents
await newNote.recalculateDatabaseAttachments( await newNote.updateAttachments(data.mediaAttachments ?? []);
data.mediaAttachments ?? [],
);
// Send notifications for mentioned local users // Send notifications for mentioned local users
for (const mention of parsedMentions ?? []) { for (const mention of parsedMentions ?? []) {
@ -532,8 +530,7 @@ export class Note extends BaseInterface<typeof Notes, NoteTypeWithRelations> {
emojis?: Emoji[]; emojis?: Emoji[];
uri?: string; uri?: string;
mentions?: User[]; mentions?: User[];
/** List of IDs of database Attachment objects */ mediaAttachments?: Attachment[];
mediaAttachments?: string[];
replyId?: string; replyId?: string;
quoteId?: string; quoteId?: string;
application?: Application; application?: Application;
@ -587,13 +584,13 @@ export class Note extends BaseInterface<typeof Notes, NoteTypeWithRelations> {
}); });
// Connect emojis // Connect emojis
await this.recalculateDatabaseEmojis(parsedEmojis); await this.updateEmojis(parsedEmojis);
// Connect mentions // Connect mentions
await this.recalculateDatabaseMentions(parsedMentions); await this.updateMentions(parsedMentions);
// Set attachment parents // Set attachment parents
await this.recalculateDatabaseAttachments(data.mediaAttachments ?? []); await this.updateAttachments(data.mediaAttachments ?? []);
await this.reload(data.author.id); await this.reload(data.author.id);
@ -606,7 +603,11 @@ export class Note extends BaseInterface<typeof Notes, NoteTypeWithRelations> {
* Deletes all existing emojis associated with this note, then replaces them with the provided emojis. * Deletes all existing emojis associated with this note, then replaces them with the provided emojis.
* @param emojis - The emojis to associate with this note * @param emojis - The emojis to associate with this note
*/ */
public async recalculateDatabaseEmojis(emojis: Emoji[]): Promise<void> { public async updateEmojis(emojis: Emoji[]): Promise<void> {
if (emojis.length === 0) {
return;
}
// Fuse and deduplicate // Fuse and deduplicate
const fusedEmojis = emojis.filter( const fusedEmojis = emojis.filter(
(emoji, index, self) => (emoji, index, self) =>
@ -617,16 +618,12 @@ export class Note extends BaseInterface<typeof Notes, NoteTypeWithRelations> {
await db await db
.delete(EmojiToNote) .delete(EmojiToNote)
.where(eq(EmojiToNote.noteId, this.data.id)); .where(eq(EmojiToNote.noteId, this.data.id));
await db.insert(EmojiToNote).values(
for (const emoji of fusedEmojis) { fusedEmojis.map((emoji) => ({
await db
.insert(EmojiToNote)
.values({
emojiId: emoji.id, emojiId: emoji.id,
noteId: this.data.id, noteId: this.data.id,
}) })),
.execute(); );
}
} }
/** /**
@ -635,21 +632,21 @@ export class Note extends BaseInterface<typeof Notes, NoteTypeWithRelations> {
* Deletes all existing mentions associated with this note, then replaces them with the provided mentions. * Deletes all existing mentions associated with this note, then replaces them with the provided mentions.
* @param mentions - The mentions to associate with this note * @param mentions - The mentions to associate with this note
*/ */
public async recalculateDatabaseMentions(mentions: User[]): Promise<void> { public async updateMentions(mentions: User[]): Promise<void> {
if (mentions.length === 0) {
return;
}
// Connect mentions // Connect mentions
await db await db
.delete(NoteToMentions) .delete(NoteToMentions)
.where(eq(NoteToMentions.noteId, this.data.id)); .where(eq(NoteToMentions.noteId, this.data.id));
await db.insert(NoteToMentions).values(
for (const mention of mentions) { mentions.map((mention) => ({
await db
.insert(NoteToMentions)
.values({
noteId: this.data.id, noteId: this.data.id,
userId: mention.id, userId: mention.id,
}) })),
.execute(); );
}
} }
/** /**
@ -658,25 +655,31 @@ export class Note extends BaseInterface<typeof Notes, NoteTypeWithRelations> {
* Deletes all existing attachments associated with this note, then replaces them with the provided attachments. * Deletes all existing attachments associated with this note, then replaces them with the provided attachments.
* @param mediaAttachments - The IDs of the attachments to associate with this note * @param mediaAttachments - The IDs of the attachments to associate with this note
*/ */
public async recalculateDatabaseAttachments( public async updateAttachments(
mediaAttachments: string[], mediaAttachments: Attachment[],
): Promise<void> { ): Promise<void> {
// Set attachment parents if (mediaAttachments.length === 0) {
return;
}
// Remove old attachments
await db await db
.update(Attachments) .update(Attachments)
.set({ .set({
noteId: null, noteId: null,
}) })
.where(eq(Attachments.noteId, this.data.id)); .where(eq(Attachments.noteId, this.data.id));
if (mediaAttachments.length > 0) {
await db await db
.update(Attachments) .update(Attachments)
.set({ .set({
noteId: this.data.id, noteId: this.data.id,
}) })
.where(inArray(Attachments.id, mediaAttachments)); .where(
} inArray(
Attachments.id,
mediaAttachments.map((i) => i.id),
),
);
} }
/** /**
@ -705,7 +708,7 @@ export class Note extends BaseInterface<typeof Notes, NoteTypeWithRelations> {
return await Note.fromId(uuid[0]); return await Note.fromId(uuid[0]);
} }
return await Note.saveFromRemote(uri); return await Note.fetchFromRemote(uri);
} }
/** /**
@ -713,19 +716,13 @@ export class Note extends BaseInterface<typeof Notes, NoteTypeWithRelations> {
* @param uri - The URI of the note to save * @param uri - The URI of the note to save
* @returns The saved note, or null if the note could not be fetched * @returns The saved note, or null if the note could not be fetched
*/ */
public static async saveFromRemote(uri: string): Promise<Note | null> { public static async fetchFromRemote(uri: string): Promise<Note | null> {
let note: VersiaNote | null = null;
const instance = await Instance.resolve(uri); const instance = await Instance.resolve(uri);
if (!instance) { if (!instance) {
return null; return null;
} }
if (uri) {
if (!URL.canParse(uri)) {
throw new Error(`Invalid URI to parse ${uri}`);
}
const requester = await User.getFederationRequester(); const requester = await User.getFederationRequester();
const { data } = await requester.get(uri, { const { data } = await requester.get(uri, {
@ -733,12 +730,7 @@ export class Note extends BaseInterface<typeof Notes, NoteTypeWithRelations> {
proxy: config.http.proxy.address, proxy: config.http.proxy.address,
}); });
note = await new EntityValidator().Note(data); const note = await new EntityValidator().Note(data);
}
if (!note) {
throw new Error("No note was able to be fetched");
}
const author = await User.resolve(note.author); const author = await User.resolve(note.author);
@ -753,6 +745,7 @@ export class Note extends BaseInterface<typeof Notes, NoteTypeWithRelations> {
* Turns a Versia Note into a database note (saved) * Turns a Versia Note into a database note (saved)
* @param note Versia Note * @param note Versia Note
* @param author Author of the note * @param author Author of the note
* @param instance Instance of the note
* @returns The saved note * @returns The saved note
*/ */
public static async fromVersia( public static async fromVersia(
@ -824,7 +817,7 @@ export class Note extends BaseInterface<typeof Notes, NoteTypeWithRelations> {
.map((mention) => User.resolve(mention)) .map((mention) => User.resolve(mention))
.filter((mention) => mention !== null) as Promise<User>[], .filter((mention) => mention !== null) as Promise<User>[],
), ),
mediaAttachments: attachments.map((a) => a.id), mediaAttachments: attachments,
replyId: note.replies_to replyId: note.replies_to
? (await Note.resolve(note.replies_to))?.data.id ? (await Note.resolve(note.replies_to))?.data.id
: undefined, : undefined,

View file

@ -597,7 +597,7 @@ export class User extends BaseInterface<typeof Users, UserWithRelations> {
); );
} }
const updated = await User.saveFromRemote(this.getUri()); const updated = await User.fetchFromRemote(this.getUri());
if (!updated) { if (!updated) {
throw new Error("Failed to update user from remote"); throw new Error("Failed to update user from remote");
@ -608,7 +608,7 @@ export class User extends BaseInterface<typeof Users, UserWithRelations> {
return this; return this;
} }
public static async saveFromRemote(uri: string): Promise<User | null> { public static async fetchFromRemote(uri: string): Promise<User | null> {
if (!URL.canParse(uri)) { if (!URL.canParse(uri)) {
throw new Error(`Invalid URI: ${uri}`); throw new Error(`Invalid URI: ${uri}`);
} }
@ -644,18 +644,12 @@ export class User extends BaseInterface<typeof Users, UserWithRelations> {
private static async saveFromVersia( private static async saveFromVersia(
uri: string, uri: string,
instance: Instance, instance: Instance,
): Promise<User | null> { ): Promise<User> {
const requester = await User.getFederationRequester(); const requester = await User.getFederationRequester();
const output = await requester const output = await requester.get<Partial<VersiaUser>>(uri, {
.get<Partial<VersiaUser>>(uri, {
// @ts-expect-error Bun extension // @ts-expect-error Bun extension
proxy: config.http.proxy.address, proxy: config.http.proxy.address,
}) });
.catch(() => null);
if (!output) {
return null;
}
const { data: json } = output; const { data: json } = output;
@ -664,32 +658,30 @@ export class User extends BaseInterface<typeof Users, UserWithRelations> {
const user = await User.fromVersia(data, instance); const user = await User.fromVersia(data, instance);
const userEmojis = await searchManager.addUser(user);
data.extensions?.["pub.versia:custom_emojis"]?.emojis ?? [];
const emojis = await Promise.all(
userEmojis.map((emoji) => Emoji.fromVersia(emoji, instance)),
);
if (emojis.length > 0) { return user;
await db.delete(EmojiToUser).where(eq(EmojiToUser.userId, user.id)); }
/**
* Change the emojis linked to this user in database
* @param emojis
* @returns
*/
public async updateEmojis(emojis: Emoji[]): Promise<void> {
if (emojis.length === 0) {
return;
}
await db.delete(EmojiToUser).where(eq(EmojiToUser.userId, this.id));
await db.insert(EmojiToUser).values( await db.insert(EmojiToUser).values(
emojis.map((emoji) => ({ emojis.map((emoji) => ({
emojiId: emoji.id, emojiId: emoji.id,
userId: user.id, userId: this.id,
})), })),
); );
} }
const finalUser = await User.fromId(user.id);
if (!finalUser) {
throw new Error("Failed to save user from remote");
}
await searchManager.addUser(finalUser);
return finalUser;
}
public static async fromVersia( public static async fromVersia(
user: VersiaUser, user: VersiaUser,
instance: Instance, instance: Instance,
@ -729,19 +721,29 @@ export class User extends BaseInterface<typeof Users, UserWithRelations> {
}, },
}; };
// Check if new user already exists const userEmojis =
user.extensions?.["pub.versia:custom_emojis"]?.emojis ?? [];
const emojis = await Promise.all(
userEmojis.map((emoji) => Emoji.fromVersia(emoji, instance)),
);
// Check if new user already exists
const foundUser = await User.fromSql(eq(Users.uri, user.uri)); const foundUser = await User.fromSql(eq(Users.uri, user.uri));
// If it exists, simply update it // If it exists, simply update it
if (foundUser) { if (foundUser) {
await foundUser.update(data); await foundUser.update(data);
await foundUser.updateEmojis(emojis);
return foundUser; return foundUser;
} }
// Else, create a new user // Else, create a new user
return await User.insert(data); const newUser = await User.insert(data);
await newUser.updateEmojis(emojis);
return newUser;
} }
public static async insert( public static async insert(
@ -784,7 +786,7 @@ export class User extends BaseInterface<typeof Users, UserWithRelations> {
getLogger(["federation", "resolvers"]) getLogger(["federation", "resolvers"])
.debug`User not found in database, fetching from remote`; .debug`User not found in database, fetching from remote`;
return await User.saveFromRemote(uri); return await User.fetchFromRemote(uri);
} }
/** /**

View file

@ -2,7 +2,7 @@ import { beforeEach, describe, expect, jest, mock, test } from "bun:test";
import { SignatureValidator } from "@versia/federation"; import { SignatureValidator } from "@versia/federation";
import type { Entity, Note as VersiaNote } from "@versia/federation/types"; import type { Entity, Note as VersiaNote } from "@versia/federation/types";
import { import {
type Instance, Instance,
Note, Note,
Notification, Notification,
Relationship, Relationship,
@ -23,11 +23,12 @@ mock.module("@versia/kit/db", () => ({
}, },
User: { User: {
resolve: jest.fn(), resolve: jest.fn(),
saveFromRemote: jest.fn(), fetchFromRemote: jest.fn(),
sendFollowAccept: jest.fn(), sendFollowAccept: jest.fn(),
}, },
Instance: { Instance: {
fromUser: jest.fn(), fromUser: jest.fn(),
resolve: jest.fn(),
}, },
Note: { Note: {
resolve: jest.fn(), resolve: jest.fn(),
@ -198,9 +199,11 @@ describe("InboxProcessor", () => {
test("successfully processes valid note", async () => { test("successfully processes valid note", async () => {
const mockNote = { author: "test-author" }; const mockNote = { author: "test-author" };
const mockAuthor = { id: "test-id" }; const mockAuthor = { id: "test-id" };
const mockInstance = { id: "test-id" };
User.resolve = jest.fn().mockResolvedValue(mockAuthor); User.resolve = jest.fn().mockResolvedValue(mockAuthor);
Note.fromVersia = jest.fn().mockResolvedValue(true); Note.fromVersia = jest.fn().mockResolvedValue(true);
Instance.resolve = jest.fn().mockResolvedValue(mockInstance);
// biome-ignore lint/complexity/useLiteralKeys: Private variable // biome-ignore lint/complexity/useLiteralKeys: Private variable
processor["body"] = mockNote as VersiaNote; processor["body"] = mockNote as VersiaNote;
@ -208,7 +211,11 @@ describe("InboxProcessor", () => {
const result = await processor["processNote"](); const result = await processor["processNote"]();
expect(User.resolve).toHaveBeenCalledWith("test-author"); expect(User.resolve).toHaveBeenCalledWith("test-author");
expect(Note.fromVersia).toHaveBeenCalledWith(mockNote, mockAuthor); expect(Note.fromVersia).toHaveBeenCalledWith(
mockNote,
mockAuthor,
mockInstance,
);
expect(result).toBeNull(); expect(result).toBeNull();
}); });
@ -364,19 +371,19 @@ describe("InboxProcessor", () => {
}; };
const mockUpdatedUser = { id: "user-id" }; const mockUpdatedUser = { id: "user-id" };
User.saveFromRemote = jest.fn().mockResolvedValue(mockUpdatedUser); User.fetchFromRemote = jest.fn().mockResolvedValue(mockUpdatedUser);
// biome-ignore lint/complexity/useLiteralKeys: Private variable // biome-ignore lint/complexity/useLiteralKeys: Private variable
processor["body"] = mockUser as unknown as Entity; processor["body"] = mockUser as unknown as Entity;
// biome-ignore lint/complexity/useLiteralKeys: Private method // biome-ignore lint/complexity/useLiteralKeys: Private method
const result = await processor["processUserRequest"](); const result = await processor["processUserRequest"]();
expect(User.saveFromRemote).toHaveBeenCalledWith("test-uri"); expect(User.fetchFromRemote).toHaveBeenCalledWith("test-uri");
expect(result).toBeNull(); expect(result).toBeNull();
}); });
test("returns 500 when update fails", async () => { test("returns 500 when update fails", async () => {
User.saveFromRemote = jest.fn().mockResolvedValue(null); User.fetchFromRemote = jest.fn().mockResolvedValue(null);
// biome-ignore lint/complexity/useLiteralKeys: Private method // biome-ignore lint/complexity/useLiteralKeys: Private method
const result = await processor["processUserRequest"](); const result = await processor["processUserRequest"]();

View file

@ -554,7 +554,7 @@ export class InboxProcessor {
private async processUserRequest(): Promise<Response | null> { private async processUserRequest(): Promise<Response | null> {
const user = this.body as unknown as VersiaUser; const user = this.body as unknown as VersiaUser;
// FIXME: Instead of refetching the remote user, we should read the incoming json and update from that // FIXME: Instead of refetching the remote user, we should read the incoming json and update from that
const updatedAccount = await User.saveFromRemote(user.uri); const updatedAccount = await User.fetchFromRemote(user.uri);
if (!updatedAccount) { if (!updatedAccount) {
return Response.json( return Response.json(