refactor(database): ♻️ Simplify Note and User federation logic

This commit is contained in:
Jesse Wierzbinski 2024-12-09 13:36:15 +01:00
parent cbbf49905b
commit a8541bdc44
No known key found for this signature in database
7 changed files with 137 additions and 153 deletions

View file

@ -1,8 +1,8 @@
import { apiRoute, applyConfig, auth, jsonOrForm } from "@/api";
import { sanitizedHtmlStrip } from "@/sanitization";
import { createRoute } from "@hono/zod-openapi";
import { Attachment, Emoji, User, db } from "@versia/kit/db";
import { EmojiToUser, RolePermissions, Users } from "@versia/kit/tables";
import { Attachment, Emoji, User } from "@versia/kit/db";
import { RolePermissions, Users } from "@versia/kit/tables";
import { and, eq, isNull } from "drizzle-orm";
import ISO6391 from "iso-639-1";
import { z } from "zod";
@ -335,9 +335,11 @@ export default apiRoute((app) =>
await Emoji.parseFromText(sanitizedDisplayName);
const noteEmojis = await Emoji.parseFromText(self.note);
self.emojis = [...displaynameEmojis, ...noteEmojis, ...fieldEmojis]
.map((e) => e.data)
.filter(
const emojis = [
...displaynameEmojis,
...noteEmojis,
...fieldEmojis,
].filter(
// Deduplicate emojis
(emoji, index, self) =>
self.findIndex((e) => e.id === emoji.id) === index,
@ -345,26 +347,7 @@ export default apiRoute((app) =>
// Connect emojis, if any
// Do it before updating user, so that federation takes that into account
for (const emoji of self.emojis) {
await db
.delete(EmojiToUser)
.where(
and(
eq(EmojiToUser.emojiId, emoji.id),
eq(EmojiToUser.userId, self.id),
),
)
.execute();
await db
.insert(EmojiToUser)
.values({
emojiId: emoji.id,
userId: self.id,
})
.execute();
}
await user.updateEmojis(emojis);
await user.update({
displayName: self.displayName,
username: self.username,
@ -379,6 +362,7 @@ export default apiRoute((app) =>
});
const output = await User.fromId(self.id);
if (!output) {
return context.json({ error: "Couldn't edit user" }, 500);
}

View file

@ -271,13 +271,12 @@ export default apiRoute((app) => {
sensitive,
} = context.req.valid("json");
if (media_ids.length > 0) {
const foundAttachments = await Attachment.fromIds(media_ids);
const foundAttachments =
media_ids.length > 0 ? await Attachment.fromIds(media_ids) : [];
if (foundAttachments.length !== media_ids.length) {
return context.json({ error: "Invalid media IDs" }, 422);
}
}
const newNote = await note.updateFromData({
author: user,
@ -291,7 +290,7 @@ export default apiRoute((app) => {
: undefined,
isSensitive: sensitive,
spoilerText: spoiler_text,
mediaAttachments: media_ids,
mediaAttachments: foundAttachments,
});
return context.json(await newNote.toApi(user), 200);

View file

@ -168,13 +168,12 @@ export default apiRoute((app) =>
} = context.req.valid("json");
// Check if media attachments are all valid
if (media_ids.length > 0) {
const foundAttachments = await Attachment.fromIds(media_ids);
const foundAttachments =
media_ids.length > 0 ? await Attachment.fromIds(media_ids) : [];
if (foundAttachments.length !== media_ids.length) {
return context.json({ error: "Invalid media IDs" }, 422);
}
}
// Check that in_reply_to_id and quote_id are real posts if provided
if (in_reply_to_id && !(await Note.fromId(in_reply_to_id))) {
@ -199,7 +198,7 @@ export default apiRoute((app) =>
visibility,
isSensitive: sensitive ?? false,
spoilerText: spoiler_text ?? "",
mediaAttachments: media_ids,
mediaAttachments: foundAttachments,
replyId: in_reply_to_id ?? undefined,
quoteId: quote_id ?? undefined,
application: application ?? undefined,

View file

@ -417,7 +417,7 @@ export class Note extends BaseInterface<typeof Notes, NoteTypeWithRelations> {
throw new Error("Cannot refetch a local note (it is not remote)");
}
const updated = await Note.saveFromRemote(this.getUri());
const updated = await Note.fetchFromRemote(this.getUri());
if (!updated) {
throw new Error("Note not found after update");
@ -443,7 +443,7 @@ export class Note extends BaseInterface<typeof Notes, NoteTypeWithRelations> {
uri?: string;
mentions?: User[];
/** List of IDs of database Attachment objects */
mediaAttachments?: string[];
mediaAttachments?: Attachment[];
replyId?: string;
quoteId?: string;
application?: Application;
@ -491,15 +491,13 @@ export class Note extends BaseInterface<typeof Notes, NoteTypeWithRelations> {
});
// Connect emojis
await newNote.recalculateDatabaseEmojis(parsedEmojis);
await newNote.updateEmojis(parsedEmojis);
// Connect mentions
await newNote.recalculateDatabaseMentions(parsedMentions);
await newNote.updateMentions(parsedMentions);
// Set attachment parents
await newNote.recalculateDatabaseAttachments(
data.mediaAttachments ?? [],
);
await newNote.updateAttachments(data.mediaAttachments ?? []);
// Send notifications for mentioned local users
for (const mention of parsedMentions ?? []) {
@ -532,8 +530,7 @@ export class Note extends BaseInterface<typeof Notes, NoteTypeWithRelations> {
emojis?: Emoji[];
uri?: string;
mentions?: User[];
/** List of IDs of database Attachment objects */
mediaAttachments?: string[];
mediaAttachments?: Attachment[];
replyId?: string;
quoteId?: string;
application?: Application;
@ -587,13 +584,13 @@ export class Note extends BaseInterface<typeof Notes, NoteTypeWithRelations> {
});
// Connect emojis
await this.recalculateDatabaseEmojis(parsedEmojis);
await this.updateEmojis(parsedEmojis);
// Connect mentions
await this.recalculateDatabaseMentions(parsedMentions);
await this.updateMentions(parsedMentions);
// Set attachment parents
await this.recalculateDatabaseAttachments(data.mediaAttachments ?? []);
await this.updateAttachments(data.mediaAttachments ?? []);
await this.reload(data.author.id);
@ -606,7 +603,11 @@ export class Note extends BaseInterface<typeof Notes, NoteTypeWithRelations> {
* Deletes all existing emojis associated with this note, then replaces them with the provided emojis.
* @param emojis - The emojis to associate with this note
*/
public async recalculateDatabaseEmojis(emojis: Emoji[]): Promise<void> {
public async updateEmojis(emojis: Emoji[]): Promise<void> {
if (emojis.length === 0) {
return;
}
// Fuse and deduplicate
const fusedEmojis = emojis.filter(
(emoji, index, self) =>
@ -617,16 +618,12 @@ export class Note extends BaseInterface<typeof Notes, NoteTypeWithRelations> {
await db
.delete(EmojiToNote)
.where(eq(EmojiToNote.noteId, this.data.id));
for (const emoji of fusedEmojis) {
await db
.insert(EmojiToNote)
.values({
await db.insert(EmojiToNote).values(
fusedEmojis.map((emoji) => ({
emojiId: emoji.id,
noteId: this.data.id,
})
.execute();
}
})),
);
}
/**
@ -635,21 +632,21 @@ export class Note extends BaseInterface<typeof Notes, NoteTypeWithRelations> {
* Deletes all existing mentions associated with this note, then replaces them with the provided mentions.
* @param mentions - The mentions to associate with this note
*/
public async recalculateDatabaseMentions(mentions: User[]): Promise<void> {
public async updateMentions(mentions: User[]): Promise<void> {
if (mentions.length === 0) {
return;
}
// Connect mentions
await db
.delete(NoteToMentions)
.where(eq(NoteToMentions.noteId, this.data.id));
for (const mention of mentions) {
await db
.insert(NoteToMentions)
.values({
await db.insert(NoteToMentions).values(
mentions.map((mention) => ({
noteId: this.data.id,
userId: mention.id,
})
.execute();
}
})),
);
}
/**
@ -658,25 +655,31 @@ export class Note extends BaseInterface<typeof Notes, NoteTypeWithRelations> {
* Deletes all existing attachments associated with this note, then replaces them with the provided attachments.
* @param mediaAttachments - The IDs of the attachments to associate with this note
*/
public async recalculateDatabaseAttachments(
mediaAttachments: string[],
public async updateAttachments(
mediaAttachments: Attachment[],
): Promise<void> {
// Set attachment parents
if (mediaAttachments.length === 0) {
return;
}
// Remove old attachments
await db
.update(Attachments)
.set({
noteId: null,
})
.where(eq(Attachments.noteId, this.data.id));
if (mediaAttachments.length > 0) {
await db
.update(Attachments)
.set({
noteId: this.data.id,
})
.where(inArray(Attachments.id, mediaAttachments));
}
.where(
inArray(
Attachments.id,
mediaAttachments.map((i) => i.id),
),
);
}
/**
@ -705,7 +708,7 @@ export class Note extends BaseInterface<typeof Notes, NoteTypeWithRelations> {
return await Note.fromId(uuid[0]);
}
return await Note.saveFromRemote(uri);
return await Note.fetchFromRemote(uri);
}
/**
@ -713,19 +716,13 @@ export class Note extends BaseInterface<typeof Notes, NoteTypeWithRelations> {
* @param uri - The URI of the note to save
* @returns The saved note, or null if the note could not be fetched
*/
public static async saveFromRemote(uri: string): Promise<Note | null> {
let note: VersiaNote | null = null;
public static async fetchFromRemote(uri: string): Promise<Note | null> {
const instance = await Instance.resolve(uri);
if (!instance) {
return null;
}
if (uri) {
if (!URL.canParse(uri)) {
throw new Error(`Invalid URI to parse ${uri}`);
}
const requester = await User.getFederationRequester();
const { data } = await requester.get(uri, {
@ -733,12 +730,7 @@ export class Note extends BaseInterface<typeof Notes, NoteTypeWithRelations> {
proxy: config.http.proxy.address,
});
note = await new EntityValidator().Note(data);
}
if (!note) {
throw new Error("No note was able to be fetched");
}
const note = await new EntityValidator().Note(data);
const author = await User.resolve(note.author);
@ -753,6 +745,7 @@ export class Note extends BaseInterface<typeof Notes, NoteTypeWithRelations> {
* Turns a Versia Note into a database note (saved)
* @param note Versia Note
* @param author Author of the note
* @param instance Instance of the note
* @returns The saved note
*/
public static async fromVersia(
@ -824,7 +817,7 @@ export class Note extends BaseInterface<typeof Notes, NoteTypeWithRelations> {
.map((mention) => User.resolve(mention))
.filter((mention) => mention !== null) as Promise<User>[],
),
mediaAttachments: attachments.map((a) => a.id),
mediaAttachments: attachments,
replyId: note.replies_to
? (await Note.resolve(note.replies_to))?.data.id
: undefined,

View file

@ -597,7 +597,7 @@ export class User extends BaseInterface<typeof Users, UserWithRelations> {
);
}
const updated = await User.saveFromRemote(this.getUri());
const updated = await User.fetchFromRemote(this.getUri());
if (!updated) {
throw new Error("Failed to update user from remote");
@ -608,7 +608,7 @@ export class User extends BaseInterface<typeof Users, UserWithRelations> {
return this;
}
public static async saveFromRemote(uri: string): Promise<User | null> {
public static async fetchFromRemote(uri: string): Promise<User | null> {
if (!URL.canParse(uri)) {
throw new Error(`Invalid URI: ${uri}`);
}
@ -644,18 +644,12 @@ export class User extends BaseInterface<typeof Users, UserWithRelations> {
private static async saveFromVersia(
uri: string,
instance: Instance,
): Promise<User | null> {
): Promise<User> {
const requester = await User.getFederationRequester();
const output = await requester
.get<Partial<VersiaUser>>(uri, {
const output = await requester.get<Partial<VersiaUser>>(uri, {
// @ts-expect-error Bun extension
proxy: config.http.proxy.address,
})
.catch(() => null);
if (!output) {
return null;
}
});
const { data: json } = output;
@ -664,32 +658,30 @@ export class User extends BaseInterface<typeof Users, UserWithRelations> {
const user = await User.fromVersia(data, instance);
const userEmojis =
data.extensions?.["pub.versia:custom_emojis"]?.emojis ?? [];
const emojis = await Promise.all(
userEmojis.map((emoji) => Emoji.fromVersia(emoji, instance)),
);
await searchManager.addUser(user);
if (emojis.length > 0) {
await db.delete(EmojiToUser).where(eq(EmojiToUser.userId, user.id));
return user;
}
/**
* Change the emojis linked to this user in database
* @param emojis
* @returns
*/
public async updateEmojis(emojis: Emoji[]): Promise<void> {
if (emojis.length === 0) {
return;
}
await db.delete(EmojiToUser).where(eq(EmojiToUser.userId, this.id));
await db.insert(EmojiToUser).values(
emojis.map((emoji) => ({
emojiId: emoji.id,
userId: user.id,
userId: this.id,
})),
);
}
const finalUser = await User.fromId(user.id);
if (!finalUser) {
throw new Error("Failed to save user from remote");
}
await searchManager.addUser(finalUser);
return finalUser;
}
public static async fromVersia(
user: VersiaUser,
instance: Instance,
@ -729,19 +721,29 @@ export class User extends BaseInterface<typeof Users, UserWithRelations> {
},
};
// Check if new user already exists
const userEmojis =
user.extensions?.["pub.versia:custom_emojis"]?.emojis ?? [];
const emojis = await Promise.all(
userEmojis.map((emoji) => Emoji.fromVersia(emoji, instance)),
);
// Check if new user already exists
const foundUser = await User.fromSql(eq(Users.uri, user.uri));
// If it exists, simply update it
if (foundUser) {
await foundUser.update(data);
await foundUser.updateEmojis(emojis);
return foundUser;
}
// Else, create a new user
return await User.insert(data);
const newUser = await User.insert(data);
await newUser.updateEmojis(emojis);
return newUser;
}
public static async insert(
@ -784,7 +786,7 @@ export class User extends BaseInterface<typeof Users, UserWithRelations> {
getLogger(["federation", "resolvers"])
.debug`User not found in database, fetching from remote`;
return await User.saveFromRemote(uri);
return await User.fetchFromRemote(uri);
}
/**

View file

@ -2,7 +2,7 @@ import { beforeEach, describe, expect, jest, mock, test } from "bun:test";
import { SignatureValidator } from "@versia/federation";
import type { Entity, Note as VersiaNote } from "@versia/federation/types";
import {
type Instance,
Instance,
Note,
Notification,
Relationship,
@ -23,11 +23,12 @@ mock.module("@versia/kit/db", () => ({
},
User: {
resolve: jest.fn(),
saveFromRemote: jest.fn(),
fetchFromRemote: jest.fn(),
sendFollowAccept: jest.fn(),
},
Instance: {
fromUser: jest.fn(),
resolve: jest.fn(),
},
Note: {
resolve: jest.fn(),
@ -198,9 +199,11 @@ describe("InboxProcessor", () => {
test("successfully processes valid note", async () => {
const mockNote = { author: "test-author" };
const mockAuthor = { id: "test-id" };
const mockInstance = { id: "test-id" };
User.resolve = jest.fn().mockResolvedValue(mockAuthor);
Note.fromVersia = jest.fn().mockResolvedValue(true);
Instance.resolve = jest.fn().mockResolvedValue(mockInstance);
// biome-ignore lint/complexity/useLiteralKeys: Private variable
processor["body"] = mockNote as VersiaNote;
@ -208,7 +211,11 @@ describe("InboxProcessor", () => {
const result = await processor["processNote"]();
expect(User.resolve).toHaveBeenCalledWith("test-author");
expect(Note.fromVersia).toHaveBeenCalledWith(mockNote, mockAuthor);
expect(Note.fromVersia).toHaveBeenCalledWith(
mockNote,
mockAuthor,
mockInstance,
);
expect(result).toBeNull();
});
@ -364,19 +371,19 @@ describe("InboxProcessor", () => {
};
const mockUpdatedUser = { id: "user-id" };
User.saveFromRemote = jest.fn().mockResolvedValue(mockUpdatedUser);
User.fetchFromRemote = jest.fn().mockResolvedValue(mockUpdatedUser);
// biome-ignore lint/complexity/useLiteralKeys: Private variable
processor["body"] = mockUser as unknown as Entity;
// biome-ignore lint/complexity/useLiteralKeys: Private method
const result = await processor["processUserRequest"]();
expect(User.saveFromRemote).toHaveBeenCalledWith("test-uri");
expect(User.fetchFromRemote).toHaveBeenCalledWith("test-uri");
expect(result).toBeNull();
});
test("returns 500 when update fails", async () => {
User.saveFromRemote = jest.fn().mockResolvedValue(null);
User.fetchFromRemote = jest.fn().mockResolvedValue(null);
// biome-ignore lint/complexity/useLiteralKeys: Private method
const result = await processor["processUserRequest"]();

View file

@ -554,7 +554,7 @@ export class InboxProcessor {
private async processUserRequest(): Promise<Response | null> {
const user = this.body as unknown as VersiaUser;
// FIXME: Instead of refetching the remote user, we should read the incoming json and update from that
const updatedAccount = await User.saveFromRemote(user.uri);
const updatedAccount = await User.fetchFromRemote(user.uri);
if (!updatedAccount) {
return Response.json(