From 31171b5fc772ef811f44d5a1f304db177bdf854f Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Sun, 6 Jul 2025 02:35:02 +0000 Subject: [PATCH] Add poll database schema and basic implementation Co-authored-by: CPlusPatch <42910258+CPlusPatch@users.noreply.github.com> --- api/api/v1/statuses/index.ts | 42 ++++- classes/database/note.ts | 5 +- classes/database/poll.ts | 284 ++++++++++++++++++++++++++++++ classes/functions/status.ts | 28 +++ drizzle/schema.ts | 98 +++++++++++ packages/plugin-kit/exports/db.ts | 1 + 6 files changed, 455 insertions(+), 3 deletions(-) create mode 100644 classes/database/poll.ts diff --git a/api/api/v1/statuses/index.ts b/api/api/v1/statuses/index.ts index 7bcc0dd5..5bb208a7 100644 --- a/api/api/v1/statuses/index.ts +++ b/api/api/v1/statuses/index.ts @@ -6,7 +6,7 @@ import { StatusSource as StatusSourceSchema, zBoolean, } from "@versia/client/schemas"; -import { Emoji, Media, Note } from "@versia/kit/db"; +import { Emoji, Media, Note, Poll } from "@versia/kit/db"; import { randomUUIDv7 } from "bun"; import { describeRoute } from "hono-openapi"; import { resolver, validator } from "hono-openapi/zod"; @@ -164,6 +164,10 @@ export default apiRoute((app) => visibility, content_type, local_only, + "poll[options]": pollOptions, + "poll[expires_in]": pollExpiresIn, + "poll[multiple]": pollMultiple, + "poll[hide_totals]": pollHideTotals, } = context.req.valid("json"); // Check if media attachments are all valid @@ -177,6 +181,27 @@ export default apiRoute((app) => ); } + // Validate poll parameters + if (pollOptions && pollOptions.length > 0) { + if (media_ids.length > 0) { + throw new ApiError(422, "Cannot attach poll to media"); + } + + if (!pollExpiresIn) { + throw new ApiError( + 422, + "poll[expires_in] must be provided when creating a poll" + ); + } + + if (pollOptions.length < 2) { + throw new ApiError( + 422, + "Poll must have at least 2 options" + ); + } + } + const reply = in_reply_to_id ? await Note.fromId(in_reply_to_id) : null; @@ -248,6 +273,21 @@ export default apiRoute((app) => await newNote.updateMentions(parsedMentions); await newNote.updateAttachments(foundAttachments); + // Create poll if poll options are provided + if (pollOptions && pollOptions.length > 0 && pollExpiresIn) { + const expiresAt = new Date(Date.now() + pollExpiresIn * 1000).toISOString(); + + await Poll.insert({ + id: randomUUIDv7(), + noteId: newNote.data.id, + expiresAt, + multiple: pollMultiple ?? false, + hideTotals: pollHideTotals ?? false, + votesCount: 0, + votersCount: 0, + }, pollOptions); + } + await newNote.reload(); if (!local_only) { diff --git a/classes/database/note.ts b/classes/database/note.ts index 25fa185b..0ebce24a 100644 --- a/classes/database/note.ts +++ b/classes/database/note.ts @@ -28,6 +28,7 @@ import { mergeAndDeduplicate } from "@/lib.ts"; import { sanitizedHtmlStrip } from "@/sanitization"; import { contentToHtml, findManyNotes } from "~/classes/functions/status"; import { config } from "~/config.ts"; +import { Poll } from "./poll.ts"; import * as VersiaEntities from "~/packages/sdk/entities/index.ts"; import type { NonTextContentFormatSchema } from "~/packages/sdk/schemas/contentformat.ts"; import { DeliveryJobType, deliveryQueue } from "../queues/delivery.ts"; @@ -55,6 +56,7 @@ type NoteTypeWithRelations = NoteType & { muted: boolean; liked: boolean; reactions: Omit[]; + poll: typeof Poll.$type | null; }; export type NoteTypeWithoutRecursiveRelations = Omit< @@ -691,8 +693,7 @@ export class Note extends BaseInterface { language: null, muted: data.muted, pinned: data.pinned, - // TODO: Add polls - poll: null, + poll: data.poll ? data.poll.toApi(userFetching) : null, reblog: data.reblog ? await new Note(data.reblog as NoteTypeWithRelations).toApi( userFetching, diff --git a/classes/database/poll.ts b/classes/database/poll.ts new file mode 100644 index 00000000..b2cbc172 --- /dev/null +++ b/classes/database/poll.ts @@ -0,0 +1,284 @@ +import { db } from "@versia/kit/db"; +import { + Notes, + PollOptions, + Polls, + PollVotes, + type Users, +} from "@versia/kit/tables"; +import { + and, + eq, + type InferInsertModel, + type InferSelectModel, + inArray, +} from "drizzle-orm"; +import type { z } from "zod"; +import type { Poll as PollSchema } from "@versia/client/schemas"; +import { BaseInterface } from "./base.ts"; + +/** + * Type definition for Poll with all relations + */ +type PollTypeWithRelations = InferSelectModel & { + options: (InferSelectModel & { + votes: InferSelectModel[]; + })[]; + votes: InferSelectModel[]; +}; + +/** + * Database class for managing polls + */ +export class Poll extends BaseInterface { + public static $type: PollTypeWithRelations; + + /** + * Reload the poll data from the database + */ + public async reload(): Promise { + const reloaded = await Poll.fromId(this.data.id); + + if (!reloaded) { + throw new Error("Failed to reload poll"); + } + + this.data = reloaded.data; + } + + /** + * Get a poll by ID + * @param id - The poll ID + * @returns The poll instance or null if not found + */ + public static async fromId(id: string | null): Promise { + if (!id) { + return null; + } + + return await Poll.fromSql(eq(Polls.id, id)); + } + + /** + * Get a poll by note ID + * @param noteId - The note ID + * @returns The poll instance or null if not found + */ + public static async fromNoteId(noteId: string): Promise { + return await Poll.fromSql(eq(Polls.noteId, noteId)); + } + + /** + * Get multiple polls by IDs + * @param ids - Array of poll IDs + * @returns Array of poll instances + */ + public static async fromIds(ids: string[]): Promise { + return await Poll.manyFromSql(inArray(Polls.id, ids)); + } + + /** + * Execute SQL query to get a single poll with relations + * @param sql - SQL condition + * @returns Poll instance or null + */ + protected static async fromSql(sql: any): Promise { + const result = await db + .select() + .from(Polls) + .leftJoin(PollOptions, eq(Polls.id, PollOptions.pollId)) + .leftJoin(PollVotes, eq(PollOptions.id, PollVotes.optionId)) + .where(sql); + + if (result.length === 0) { + return null; + } + + // Group the results to build the poll object with options + const pollData = result[0].Polls; + const optionsMap = new Map(); + const votesData: InferSelectModel[] = []; + + for (const row of result) { + if (row.PollOptions) { + if (!optionsMap.has(row.PollOptions.id)) { + optionsMap.set(row.PollOptions.id, { + ...row.PollOptions, + votes: [], + }); + } + + if (row.PollVotes) { + optionsMap.get(row.PollOptions.id)!.votes.push(row.PollVotes); + votesData.push(row.PollVotes); + } + } + } + + const options = Array.from(optionsMap.values()).sort((a, b) => a.index - b.index); + + const pollWithRelations: PollTypeWithRelations = { + ...pollData, + options, + votes: votesData, + }; + + return new Poll(pollWithRelations); + } + + /** + * Execute SQL query to get multiple polls with relations + * @param sql - SQL condition + * @returns Array of poll instances + */ + protected static async manyFromSql(sql: any): Promise { + const result = await db + .select() + .from(Polls) + .leftJoin(PollOptions, eq(Polls.id, PollOptions.pollId)) + .leftJoin(PollVotes, eq(PollOptions.id, PollVotes.optionId)) + .where(sql); + + if (result.length === 0) { + return []; + } + + // Group by poll ID + const pollsMap = new Map(); + + for (const row of result) { + const pollId = row.Polls.id; + + if (!pollsMap.has(pollId)) { + pollsMap.set(pollId, { + ...row.Polls, + options: new Map(), + votes: [], + }); + } + + const poll = pollsMap.get(pollId); + + if (row.PollOptions) { + if (!poll.options.has(row.PollOptions.id)) { + poll.options.set(row.PollOptions.id, { + ...row.PollOptions, + votes: [], + }); + } + + if (row.PollVotes) { + poll.options.get(row.PollOptions.id)!.votes.push(row.PollVotes); + poll.votes.push(row.PollVotes); + } + } + } + + return Array.from(pollsMap.values()).map((pollData) => { + const options = Array.from(pollData.options.values()).sort( + (a, b) => a.index - b.index, + ); + + return new Poll({ + ...pollData, + options, + votes: pollData.votes, + }); + }); + } + + /** + * Insert a new poll into the database + * @param pollData - Poll data to insert + * @param options - Poll options to insert + * @returns The inserted poll instance + */ + public static async insert( + pollData: InferInsertModel, + options: string[], + ): Promise { + return await db.transaction(async (tx) => { + // Insert the poll + const insertedPoll = (await tx.insert(Polls).values(pollData).returning())[0]; + + // Insert poll options + const optionInserts = options.map((title, index) => ({ + id: crypto.randomUUID(), + pollId: insertedPoll.id, + title, + index, + votesCount: 0, + })); + + await tx.insert(PollOptions).values(optionInserts); + + // Return the poll with relations + const poll = await Poll.fromId(insertedPoll.id); + if (!poll) { + throw new Error("Failed to retrieve inserted poll"); + } + + return poll; + }); + } + + /** + * Check if the poll has expired + * @returns True if the poll has expired + */ + public isExpired(): boolean { + if (!this.data.expiresAt) { + return false; + } + + return new Date(this.data.expiresAt) < new Date(); + } + + /** + * Check if a user has voted in this poll + * @param userId - The user ID to check + * @returns True if the user has voted + */ + public hasUserVoted(userId: string): boolean { + return this.data.votes.some((vote) => vote.userId === userId); + } + + /** + * Get the vote options for a specific user + * @param userId - The user ID + * @returns Array of option indices the user voted for + */ + public getUserVotes(userId: string): number[] { + const userVotes = this.data.votes.filter((vote) => vote.userId === userId); + return userVotes.map((vote) => { + const option = this.data.options.find((opt) => opt.id === vote.optionId); + return option?.index ?? -1; + }).filter((index) => index !== -1); + } + + /** + * Convert poll to Mastodon API format + * @param userFetching - The user fetching the poll (to check if they voted) + * @returns Poll in Mastodon API format + */ + public toApi(userFetching?: { id: string } | null): z.infer { + const voted = userFetching ? this.hasUserVoted(userFetching.id) : undefined; + const ownVotes = userFetching ? this.getUserVotes(userFetching.id) : undefined; + + return { + id: this.data.id, + expires_at: this.data.expiresAt, + expired: this.isExpired(), + multiple: this.data.multiple, + votes_count: this.data.votesCount, + voters_count: this.data.votersCount, + options: this.data.options.map((option) => ({ + title: option.title, + votes_count: this.data.hideTotals && !this.isExpired() ? null : option.votesCount, + })), + emojis: [], // TODO: Parse emojis from poll options + voted, + own_votes: ownVotes, + }; + } +} \ No newline at end of file diff --git a/classes/functions/status.ts b/classes/functions/status.ts index 55988fd3..6a53d2d7 100644 --- a/classes/functions/status.ts +++ b/classes/functions/status.ts @@ -115,6 +115,16 @@ export const findManyNotes = async ( ...userRelations, }, }, + poll: { + with: { + options: { + with: { + votes: true, + }, + }, + votes: true, + }, + }, }, extras: { pinned: userId @@ -141,6 +151,16 @@ export const findManyNotes = async ( }, reply: true, quote: true, + poll: { + with: { + options: { + with: { + votes: true, + }, + }, + votes: true, + }, + }, }, extras: { pinned: userId @@ -176,6 +196,10 @@ export const findManyNotes = async ( })), attachments: post.attachments.map((attachment) => attachment.media), emojis: (post.emojis ?? []).map((emoji) => emoji.emoji), + poll: post.poll ? { + ...post.poll, + options: post.poll.options.sort((a, b) => a.index - b.index), + } : null, reblog: post.reblog && { ...post.reblog, author: transformOutputToUserWithRelations(post.reblog.author), @@ -187,6 +211,10 @@ export const findManyNotes = async ( (attachment) => attachment.media, ), emojis: (post.reblog.emojis ?? []).map((emoji) => emoji.emoji), + poll: post.reblog.poll ? { + ...post.reblog.poll, + options: post.reblog.poll.options.sort((a, b) => a.index - b.index), + } : null, pinned: Boolean(post.reblog.pinned), reblogged: Boolean(post.reblog.reblogged), muted: Boolean(post.reblog.muted), diff --git a/drizzle/schema.ts b/drizzle/schema.ts index b174f359..f6825f27 100644 --- a/drizzle/schema.ts +++ b/drizzle/schema.ts @@ -512,6 +512,10 @@ export const NotesRelations = relations(Notes, ({ many, one }) => ({ reactions: many(Reactions, { relationName: "NoteToReactions", }), + poll: one(Polls, { + fields: [Notes.id], + references: [Polls.noteId], + }), })); export const Instances = pgTable("Instances", { @@ -947,3 +951,97 @@ export const MediasToNotesRelations = relations(MediasToNotes, ({ one }) => ({ relationName: "AttachmentToNote", }), })); + +export const Polls = pgTable("Polls", { + id: id(), + noteId: uuid("noteId") + .notNull() + .references(() => Notes.id, { + onDelete: "cascade", + onUpdate: "cascade", + }) + .unique(), + expiresAt: timestamp("expires_at", { precision: 3, mode: "string" }), + multiple: boolean("multiple").notNull().default(false), + hideTotals: boolean("hide_totals").notNull().default(false), + votesCount: integer("votes_count").notNull().default(0), + votersCount: integer("voters_count").notNull().default(0), + createdAt: createdAt(), + updatedAt: updatedAt(), +}); + +export const PollOptions = pgTable("PollOptions", { + id: id(), + pollId: uuid("pollId") + .notNull() + .references(() => Polls.id, { + onDelete: "cascade", + onUpdate: "cascade", + }), + title: text("title").notNull(), + index: integer("index").notNull(), + votesCount: integer("votes_count").notNull().default(0), +}); + +export const PollVotes = pgTable( + "PollVotes", + { + id: id(), + pollId: uuid("pollId") + .notNull() + .references(() => Polls.id, { + onDelete: "cascade", + onUpdate: "cascade", + }), + optionId: uuid("optionId") + .notNull() + .references(() => PollOptions.id, { + onDelete: "cascade", + onUpdate: "cascade", + }), + userId: uuid("userId") + .notNull() + .references(() => Users.id, { + onDelete: "cascade", + onUpdate: "cascade", + }), + createdAt: createdAt(), + }, + (table) => [ + uniqueIndex().on(table.pollId, table.userId, table.optionId), + index().on(table.pollId), + index().on(table.userId), + ], +); + +export const PollsRelations = relations(Polls, ({ one, many }) => ({ + note: one(Notes, { + fields: [Polls.noteId], + references: [Notes.id], + }), + options: many(PollOptions), + votes: many(PollVotes), +})); + +export const PollOptionsRelations = relations(PollOptions, ({ one, many }) => ({ + poll: one(Polls, { + fields: [PollOptions.pollId], + references: [Polls.id], + }), + votes: many(PollVotes), +})); + +export const PollVotesRelations = relations(PollVotes, ({ one }) => ({ + poll: one(Polls, { + fields: [PollVotes.pollId], + references: [Polls.id], + }), + option: one(PollOptions, { + fields: [PollVotes.optionId], + references: [PollOptions.id], + }), + user: one(Users, { + fields: [PollVotes.userId], + references: [Users.id], + }), +})); diff --git a/packages/plugin-kit/exports/db.ts b/packages/plugin-kit/exports/db.ts index 7754d60a..2d0dad2a 100644 --- a/packages/plugin-kit/exports/db.ts +++ b/packages/plugin-kit/exports/db.ts @@ -5,6 +5,7 @@ export { Like } from "~/classes/database/like.ts"; export { Media } from "~/classes/database/media"; export { Note } from "~/classes/database/note.ts"; export { Notification } from "~/classes/database/notification.ts"; +export { Poll } from "~/classes/database/poll.ts"; export { PushSubscription } from "~/classes/database/pushsubscription.ts"; export { Reaction } from "~/classes/database/reaction.ts"; export { Relationship } from "~/classes/database/relationship.ts";