refactor(api): 🎨 Finish Hono refactor

This commit is contained in:
Jesse Wierzbinski 2024-05-06 08:19:42 +00:00
parent 826a260e90
commit 959dd27ad6
No known key found for this signature in database
20 changed files with 309 additions and 316 deletions

BIN
bun.lockb

Binary file not shown.

View file

@ -1,13 +1,19 @@
import { dualLogger } from "@loggers"; import { dualLogger } from "@loggers";
import { connectMeili } from "@meilisearch"; import { connectMeili } from "@meilisearch";
import { errorResponse } from "@response";
import { config } from "config-manager"; import { config } from "config-manager";
import { Hono } from "hono"; import { Hono } from "hono";
import { LogLevel, LogManager, type MultiLogManager } from "log-manager"; import { LogLevel, LogManager, type MultiLogManager } from "log-manager";
import { setupDatabase } from "~drizzle/db"; import { setupDatabase } from "~drizzle/db";
import { agentBans } from "~middlewares/agent-bans";
import { bait } from "~middlewares/bait";
import { ipBans } from "~middlewares/ip-bans";
import { logger } from "~middlewares/logger";
import { Note } from "~packages/database-interface/note"; import { Note } from "~packages/database-interface/note";
import { handleGlitchRequest } from "~packages/glitch-server/main";
import type { APIRouteExports } from "~packages/server-handler"; import type { APIRouteExports } from "~packages/server-handler";
import { routes } from "~routes"; import { routes } from "~routes";
import { createServer } from "~server2"; import { createServer } from "~server";
const timeAtStart = performance.now(); const timeAtStart = performance.now();
@ -101,6 +107,11 @@ if (isEntry) {
const app = new Hono(); const app = new Hono();
app.use(ipBans);
app.use(agentBans);
app.use(bait);
app.use(logger);
// Inject own filesystem router // Inject own filesystem router
for (const [route, path] of Object.entries(routes)) { for (const [route, path] of Object.entries(routes)) {
// use app.get(path, handler) to add routes // use app.get(path, handler) to add routes
@ -113,6 +124,49 @@ for (const [route, path] of Object.entries(routes)) {
route.default(app); route.default(app);
} }
app.all("*", async (context) => {
if (config.frontend.glitch.enabled) {
const glitch = await handleGlitchRequest(context.req.raw, dualLogger);
if (glitch) {
return glitch;
}
}
const base_url_with_http = config.http.base_url.replace(
"https://",
"http://",
);
const replacedUrl = context.req.url
.replace(config.http.base_url, config.frontend.url)
.replace(base_url_with_http, config.frontend.url);
const proxy = await fetch(replacedUrl, {
headers: {
// Include for SSR
"X-Forwarded-Host": `${config.http.bind}:${config.http.bind_port}`,
"Accept-Encoding": "identity",
},
}).catch(async (e) => {
await dualLogger.logError(LogLevel.ERROR, "Server.Proxy", e as Error);
await dualLogger.log(
LogLevel.ERROR,
"Server.Proxy",
`The Frontend is not running or the route is not found: ${replacedUrl}`,
);
return null;
});
proxy?.headers.set("Cache-Control", "max-age=31536000");
if (!proxy || proxy.status === 404) {
return errorResponse("Route not found on proxy or API route", 404);
}
return proxy;
});
createServer(config, app); createServer(config, app);
await dualServerLogger.log( await dualServerLogger.log(

16
middlewares/agent-bans.ts Normal file
View file

@ -0,0 +1,16 @@
import { errorResponse } from "@response";
import { createMiddleware } from "hono/factory";
import { config } from "~packages/config-manager";
export const agentBans = createMiddleware(async (context, next) => {
// Check for banned user agents (regex)
const ua = context.req.header("user-agent") ?? "";
for (const agent of config.http.banned_user_agents) {
if (new RegExp(agent).test(ua)) {
return errorResponse("Forbidden", 403);
}
}
await next();
});

73
middlewares/bait.ts Normal file
View file

@ -0,0 +1,73 @@
import { logger } from "@loggers";
import { errorResponse, response } from "@response";
import type { SocketAddress } from "bun";
import { createMiddleware } from "hono/factory";
import { matches } from "ip-matching";
import { config } from "~packages/config-manager";
import { LogLevel } from "~packages/log-manager";
export const bait = createMiddleware(async (context, next) => {
const request_ip = context.env?.ip as SocketAddress | undefined | null;
if (config.http.bait.enabled) {
// Check for bait IPs
if (request_ip?.address) {
for (const ip of config.http.bait.bait_ips) {
try {
if (matches(ip, request_ip.address)) {
const file = Bun.file(
config.http.bait.send_file || "./beemovie.txt",
);
if (await file.exists()) {
return response(file);
}
await logger.log(
LogLevel.ERROR,
"Server.Bait",
`Bait file not found: ${config.http.bait.send_file}`,
);
}
} catch (e) {
logger.log(
LogLevel.ERROR,
"Server.IPCheck",
`Error while parsing bait IP "${ip}" `,
);
logger.logError(
LogLevel.ERROR,
"Server.IPCheck",
e as Error,
);
return errorResponse(
`A server error occured: ${(e as Error).message}`,
500,
);
}
}
}
// Check for bait user agents (regex)
const ua = context.req.header("user-agent") ?? "";
for (const agent of config.http.bait.bait_user_agents) {
if (new RegExp(agent).test(ua)) {
const file = Bun.file(
config.http.bait.send_file || "./beemovie.txt",
);
if (await file.exists()) {
return response(file);
}
await logger.log(
LogLevel.ERROR,
"Server.Bait",
`Bait file not found: ${config.http.bait.send_file}`,
);
}
}
}
await next();
});

40
middlewares/ip-bans.ts Normal file
View file

@ -0,0 +1,40 @@
import { logger } from "@loggers";
import { errorResponse } from "@response";
import type { SocketAddress } from "bun";
import { createMiddleware } from "hono/factory";
import { matches } from "ip-matching";
import { config } from "~packages/config-manager";
import { LogLevel } from "~packages/log-manager";
export const ipBans = createMiddleware(async (context, next) => {
// Check for banned IPs
const request_ip = context.env?.ip as SocketAddress | undefined | null;
if (!request_ip?.address) {
await next();
return;
}
for (const ip of config.http.banned_ips) {
try {
if (matches(ip, request_ip?.address)) {
return errorResponse("Forbidden", 403);
}
} catch (e) {
logger.log(
LogLevel.ERROR,
"Server.IPCheck",
`Error while parsing banned IP "${ip}" `,
);
logger.logError(LogLevel.ERROR, "Server.IPCheck", e as Error);
return errorResponse(
`A server error occured: ${(e as Error).message}`,
500,
);
}
}
await next();
});

18
middlewares/logger.ts Normal file
View file

@ -0,0 +1,18 @@
import { dualLogger } from "@loggers";
import type { SocketAddress } from "bun";
import { createMiddleware } from "hono/factory";
import { config } from "~packages/config-manager";
export const logger = createMiddleware(async (context, next) => {
const request_ip = context.env?.ip as SocketAddress | undefined | null;
if (config.logging.log_requests) {
await dualLogger.logRequest(
context.req.raw,
config.logging.log_ip ? request_ip?.address : undefined,
config.logging.log_requests_verbose,
);
}
await next();
});

207
server.ts
View file

@ -1,19 +1,7 @@
import { dualLogger } from "@loggers";
import { clientResponse, errorResponse, response } from "@response";
import type { MatchedRoute } from "bun";
import type { Config } from "config-manager"; import type { Config } from "config-manager";
import { matches } from "ip-matching"; import type { Hono } from "hono";
import type { LogManager, MultiLogManager } from "log-manager";
import { LogLevel } from "log-manager";
import { processRoute } from "server-handler";
import { handleGlitchRequest } from "~packages/glitch-server/main";
import { matchRoute } from "~routes";
export const createServer = ( export const createServer = (config: Config, app: Hono) =>
config: Config,
logger: LogManager | MultiLogManager,
isProd: boolean,
) =>
Bun.serve({ Bun.serve({
port: config.http.bind_port, port: config.http.bind_port,
tls: config.http.tls.enabled tls: config.http.tls.enabled
@ -27,194 +15,7 @@ export const createServer = (
} }
: undefined, : undefined,
hostname: config.http.bind || "0.0.0.0", // defaults to "0.0.0.0" hostname: config.http.bind || "0.0.0.0", // defaults to "0.0.0.0"
async fetch(req) { fetch(req, server) {
// Check for banned IPs return app.fetch(req, { ip: server.requestIP(req) });
const request_ip = this.requestIP(req)?.address ?? "";
for (const ip of config.http.banned_ips) {
try {
if (matches(ip, request_ip)) {
return errorResponse("Forbidden", 403);
}
} catch (e) {
logger.log(
LogLevel.ERROR,
"Server.IPCheck",
`Error while parsing banned IP "${ip}" `,
);
logger.logError(
LogLevel.ERROR,
"Server.IPCheck",
e as Error,
);
return errorResponse(
`A server error occured: ${(e as Error).message}`,
500,
);
}
}
// Check for banned user agents (regex)
const ua = req.headers.get("User-Agent") ?? "";
for (const agent of config.http.banned_user_agents) {
if (new RegExp(agent).test(ua)) {
return errorResponse("Forbidden", 403);
}
}
if (config.http.bait.enabled) {
// Check for bait IPs
for (const ip of config.http.bait.bait_ips) {
try {
if (matches(ip, request_ip)) {
const file = Bun.file(
config.http.bait.send_file || "./beemovie.txt",
);
if (await file.exists()) {
return response(file);
}
await logger.log(
LogLevel.ERROR,
"Server.Bait",
`Bait file not found: ${config.http.bait.send_file}`,
);
}
} catch (e) {
logger.log(
LogLevel.ERROR,
"Server.IPCheck",
`Error while parsing bait IP "${ip}" `,
);
logger.logError(
LogLevel.ERROR,
"Server.IPCheck",
e as Error,
);
return errorResponse(
`A server error occured: ${(e as Error).message}`,
500,
);
}
}
// Check for bait user agents (regex)
for (const agent of config.http.bait.bait_user_agents) {
if (new RegExp(agent).test(ua)) {
const file = Bun.file(
config.http.bait.send_file || "./beemovie.txt",
);
if (await file.exists()) {
return response(file);
}
await logger.log(
LogLevel.ERROR,
"Server.Bait",
`Bait file not found: ${config.http.bait.send_file}`,
);
}
}
}
if (config.logging.log_requests) {
await logger.logRequest(
req.clone(),
config.logging.log_ip ? request_ip : undefined,
config.logging.log_requests_verbose,
);
}
const routePaths = [
"/api",
"/media",
"/nodeinfo",
"/.well-known",
"/users",
"/objects",
"/oauth/token",
"/oauth/providers",
];
// Check if URL starts with routePath
if (
routePaths.some((path) =>
new URL(req.url).pathname.startsWith(path),
) ||
(new URL(req.url).pathname.startsWith("/oauth/authorize") &&
req.method === "POST")
) {
// If route is .well-known, remove dot because the filesystem router can't handle dots for some reason
const matchedRoute = matchRoute(
new Request(req.url.replace(".well-known", "well-known"), {
method: req.method,
}),
);
if (
matchedRoute?.filePath &&
matchedRoute.name !== "/[...404]" &&
!(
new URL(req.url).pathname.startsWith(
"/oauth/authorize",
) && req.method === "GET"
)
) {
return await processRoute(matchedRoute, req, logger);
}
}
if (config.frontend.glitch.enabled) {
if (!new URL(req.url).pathname.startsWith("/oauth")) {
const glitch = await handleGlitchRequest(req, dualLogger);
if (glitch) {
return glitch;
}
}
}
const base_url_with_http = config.http.base_url.replace(
"https://",
"http://",
);
const replacedUrl = req.url
.replace(config.http.base_url, config.frontend.url)
.replace(base_url_with_http, config.frontend.url);
const proxy = await fetch(replacedUrl, {
headers: {
// Include for SSR
"X-Forwarded-Host": `${config.http.bind}:${config.http.bind_port}`,
"Accept-Encoding": "identity",
},
}).catch(async (e) => {
await logger.logError(
LogLevel.ERROR,
"Server.Proxy",
e as Error,
);
await logger.log(
LogLevel.ERROR,
"Server.Proxy",
`The Frontend is not running or the route is not found: ${replacedUrl}`,
);
return null;
});
proxy?.headers.set("Cache-Control", "max-age=31536000");
if (!proxy || proxy.status === 404) {
return errorResponse(
"Route not found on proxy or API route",
404,
);
}
return proxy;
}, },
}); });

View file

@ -105,10 +105,10 @@ describe(meta.route, () => {
headers: { headers: {
Authorization: `Bearer ${tokens[1].accessToken}`, Authorization: `Bearer ${tokens[1].accessToken}`,
}, },
body: getFormData({ body: new URLSearchParams({
status: "Reply", status: "Reply",
in_reply_to_id: timeline[0].id, in_reply_to_id: timeline[0].id,
federate: false, federate: "false",
}), }),
}), }),
); );

View file

@ -1,4 +1,4 @@
import { applyConfig, auth, handleZodError, idValidator } from "@api"; import { applyConfig, auth, handleZodError, idValidator, qsQuery } from "@api";
import { zValidator } from "@hono/zod-validator"; import { zValidator } from "@hono/zod-validator";
import { errorResponse, jsonResponse } from "@response"; import { errorResponse, jsonResponse } from "@response";
import { inArray } from "drizzle-orm"; import { inArray } from "drizzle-orm";
@ -23,7 +23,7 @@ export const meta = applyConfig({
export const schemas = { export const schemas = {
query: z.object({ query: z.object({
"id[]": z.array(z.string().uuid()).min(1).max(10), id: z.array(z.string().uuid()).min(1).max(10).or(z.string().uuid()),
}), }),
}; };
@ -31,11 +31,12 @@ export default (app: Hono) =>
app.on( app.on(
meta.allowedMethods, meta.allowedMethods,
meta.route, meta.route,
qsQuery(),
zValidator("query", schemas.query, handleZodError), zValidator("query", schemas.query, handleZodError),
auth(meta.auth), auth(meta.auth),
async (context) => { async (context) => {
const { user: self } = context.req.valid("header"); const { user: self } = context.req.valid("header");
const { "id[]": ids } = context.req.valid("query"); const { id: ids } = context.req.valid("query");
if (!self) return errorResponse("Unauthorized", 401); if (!self) return errorResponse("Unauthorized", 401);
@ -46,7 +47,10 @@ export default (app: Hono) =>
}, },
where: (relationship, { inArray, and, eq }) => where: (relationship, { inArray, and, eq }) =>
and( and(
inArray(relationship.subjectId, ids), inArray(
relationship.subjectId,
Array.isArray(ids) ? ids : [ids],
),
eq(relationship.following, true), eq(relationship.following, true),
), ),
}); });

View file

@ -1,4 +1,4 @@
import { applyConfig, auth, handleZodError, idValidator } from "@api"; import { applyConfig, auth, handleZodError, idValidator, qsQuery } from "@api";
import { zValidator } from "@hono/zod-validator"; import { zValidator } from "@hono/zod-validator";
import { errorResponse, jsonResponse } from "@response"; import { errorResponse, jsonResponse } from "@response";
import type { Hono } from "hono"; import type { Hono } from "hono";
@ -25,7 +25,7 @@ export const meta = applyConfig({
export const schemas = { export const schemas = {
query: z.object({ query: z.object({
"id[]": z.array(z.string().uuid()).min(1).max(10), id: z.array(z.string().uuid()).min(1).max(10).or(z.string().uuid()),
}), }),
}; };
@ -33,11 +33,14 @@ export default (app: Hono) =>
app.on( app.on(
meta.allowedMethods, meta.allowedMethods,
meta.route, meta.route,
qsQuery(),
zValidator("query", schemas.query, handleZodError), zValidator("query", schemas.query, handleZodError),
auth(meta.auth), auth(meta.auth),
async (context) => { async (context) => {
const { user: self } = context.req.valid("header"); const { user: self } = context.req.valid("header");
const { "id[]": ids } = context.req.valid("query"); const { id } = context.req.valid("query");
const ids = Array.isArray(id) ? id : [id];
if (!self) return errorResponse("Unauthorized", 401); if (!self) return errorResponse("Unauthorized", 401);

View file

@ -83,10 +83,10 @@ beforeAll(async () => {
headers: { headers: {
Authorization: `Bearer ${tokens[1].accessToken}`, Authorization: `Bearer ${tokens[1].accessToken}`,
}, },
body: getFormData({ body: new URLSearchParams({
status: `@${users[0].getUser().username} test mention`, status: `@${users[0].getUser().username} test mention`,
visibility: "direct", visibility: "direct",
federate: false, federate: "false",
}), }),
}), }),
); );

View file

@ -27,7 +27,7 @@ describe(meta.route, () => {
const response = await sendTestRequest( const response = await sendTestRequest(
new Request(new URL(meta.route, config.http.base_url), { new Request(new URL(meta.route, config.http.base_url), {
method: "POST", method: "POST",
body: new FormData(), body: new URLSearchParams(),
}), }),
); );
@ -41,7 +41,7 @@ describe(meta.route, () => {
headers: { headers: {
Authorization: `Bearer ${tokens[0].accessToken}`, Authorization: `Bearer ${tokens[0].accessToken}`,
}, },
body: new FormData(), body: new URLSearchParams(),
}), }),
); );
@ -55,9 +55,9 @@ describe(meta.route, () => {
headers: { headers: {
Authorization: `Bearer ${tokens[0].accessToken}`, Authorization: `Bearer ${tokens[0].accessToken}`,
}, },
body: getFormData({ body: new URLSearchParams({
status: "a".repeat(config.validation.max_note_size + 1), status: "a".repeat(config.validation.max_note_size + 1),
federate: false, federate: "false",
}), }),
}), }),
); );
@ -72,10 +72,10 @@ describe(meta.route, () => {
headers: { headers: {
Authorization: `Bearer ${tokens[0].accessToken}`, Authorization: `Bearer ${tokens[0].accessToken}`,
}, },
body: getFormData({ body: new URLSearchParams({
status: "Hello, world!", status: "Hello, world!",
visibility: "invalid", visibility: "invalid",
federate: false, federate: "false",
}), }),
}), }),
); );
@ -90,10 +90,10 @@ describe(meta.route, () => {
headers: { headers: {
Authorization: `Bearer ${tokens[0].accessToken}`, Authorization: `Bearer ${tokens[0].accessToken}`,
}, },
body: getFormData({ body: new URLSearchParams({
status: "Hello, world!", status: "Hello, world!",
scheduled_at: "invalid", scheduled_at: "invalid",
federate: false, federate: "false",
}), }),
}), }),
); );
@ -108,10 +108,10 @@ describe(meta.route, () => {
headers: { headers: {
Authorization: `Bearer ${tokens[0].accessToken}`, Authorization: `Bearer ${tokens[0].accessToken}`,
}, },
body: getFormData({ body: new URLSearchParams({
status: "Hello, world!", status: "Hello, world!",
in_reply_to_id: "invalid", in_reply_to_id: "invalid",
federate: false, federate: "false",
}), }),
}), }),
); );
@ -126,10 +126,10 @@ describe(meta.route, () => {
headers: { headers: {
Authorization: `Bearer ${tokens[0].accessToken}`, Authorization: `Bearer ${tokens[0].accessToken}`,
}, },
body: getFormData({ body: new URLSearchParams({
status: "Hello, world!", status: "Hello, world!",
quote_id: "invalid", quote_id: "invalid",
federate: false, federate: "false",
}), }),
}), }),
); );
@ -144,10 +144,10 @@ describe(meta.route, () => {
headers: { headers: {
Authorization: `Bearer ${tokens[0].accessToken}`, Authorization: `Bearer ${tokens[0].accessToken}`,
}, },
body: getFormData({ body: new URLSearchParams({
status: "Hello, world!", status: "Hello, world!",
"media_ids[]": "invalid", "media_ids[]": "invalid",
federate: false, federate: "false",
}), }),
}), }),
); );
@ -162,9 +162,9 @@ describe(meta.route, () => {
headers: { headers: {
Authorization: `Bearer ${tokens[0].accessToken}`, Authorization: `Bearer ${tokens[0].accessToken}`,
}, },
body: getFormData({ body: new URLSearchParams({
status: "Hello, world!", status: "Hello, world!",
federate: false, federate: "false",
}), }),
}), }),
); );
@ -184,10 +184,10 @@ describe(meta.route, () => {
headers: { headers: {
Authorization: `Bearer ${tokens[0].accessToken}`, Authorization: `Bearer ${tokens[0].accessToken}`,
}, },
body: getFormData({ body: new URLSearchParams({
status: "Hello, world!", status: "Hello, world!",
visibility: "unlisted", visibility: "unlisted",
federate: false, federate: "false",
}), }),
}), }),
); );
@ -208,9 +208,9 @@ describe(meta.route, () => {
headers: { headers: {
Authorization: `Bearer ${tokens[0].accessToken}`, Authorization: `Bearer ${tokens[0].accessToken}`,
}, },
body: getFormData({ body: new URLSearchParams({
status: "Hello, world!", status: "Hello, world!",
federate: false, federate: "false",
}), }),
}), }),
); );
@ -223,10 +223,10 @@ describe(meta.route, () => {
headers: { headers: {
Authorization: `Bearer ${tokens[0].accessToken}`, Authorization: `Bearer ${tokens[0].accessToken}`,
}, },
body: getFormData({ body: new URLSearchParams({
status: "Hello, world again!", status: "Hello, world again!",
in_reply_to_id: object.id, in_reply_to_id: object.id,
federate: false, federate: "false",
}), }),
}), }),
); );
@ -247,9 +247,9 @@ describe(meta.route, () => {
headers: { headers: {
Authorization: `Bearer ${tokens[0].accessToken}`, Authorization: `Bearer ${tokens[0].accessToken}`,
}, },
body: getFormData({ body: new URLSearchParams({
status: "Hello, world!", status: "Hello, world!",
federate: false, federate: "false",
}), }),
}), }),
); );
@ -262,10 +262,10 @@ describe(meta.route, () => {
headers: { headers: {
Authorization: `Bearer ${tokens[0].accessToken}`, Authorization: `Bearer ${tokens[0].accessToken}`,
}, },
body: getFormData({ body: new URLSearchParams({
status: "Hello, world again!", status: "Hello, world again!",
quote_id: object.id, quote_id: object.id,
federate: false, federate: "false",
}), }),
}), }),
); );
@ -290,9 +290,9 @@ describe(meta.route, () => {
headers: { headers: {
Authorization: `Bearer ${tokens[0].accessToken}`, Authorization: `Bearer ${tokens[0].accessToken}`,
}, },
body: getFormData({ body: new URLSearchParams({
status: `Hello, @${users[1].getUser().username}!`, status: `Hello, @${users[1].getUser().username}!`,
federate: false, federate: "false",
}), }),
}), }),
); );
@ -319,11 +319,11 @@ describe(meta.route, () => {
headers: { headers: {
Authorization: `Bearer ${tokens[0].accessToken}`, Authorization: `Bearer ${tokens[0].accessToken}`,
}, },
body: getFormData({ body: new URLSearchParams({
status: `Hello, @${users[1].getUser().username}@${ status: `Hello, @${users[1].getUser().username}@${
new URL(config.http.base_url).host new URL(config.http.base_url).host
}!`, }!`,
federate: false, federate: "false",
}), }),
}), }),
); );
@ -352,9 +352,9 @@ describe(meta.route, () => {
headers: { headers: {
Authorization: `Bearer ${tokens[0].accessToken}`, Authorization: `Bearer ${tokens[0].accessToken}`,
}, },
body: getFormData({ body: new URLSearchParams({
status: "Hi! <script>alert('Hello, world!');</script>", status: "Hi! <script>alert('Hello, world!');</script>",
federate: false, federate: "false",
}), }),
}), }),
); );
@ -378,11 +378,11 @@ describe(meta.route, () => {
headers: { headers: {
Authorization: `Bearer ${tokens[0].accessToken}`, Authorization: `Bearer ${tokens[0].accessToken}`,
}, },
body: getFormData({ body: new URLSearchParams({
status: "Hello, world!", status: "Hello, world!",
spoiler_text: spoiler_text:
"uwu <script>alert('Hello, world!');</script>", "uwu <script>alert('Hello, world!');</script>",
federate: false, federate: "false",
}), }),
}), }),
); );

View file

@ -1,4 +1,4 @@
import { applyConfig, auth, handleZodError } from "@api"; import { applyConfig, auth, handleZodError, qs } from "@api";
import { zValidator } from "@hono/zod-validator"; import { zValidator } from "@hono/zod-validator";
import { errorResponse, jsonResponse } from "@response"; import { errorResponse, jsonResponse } from "@response";
import { config } from "config-manager"; import { config } from "config-manager";
@ -30,7 +30,7 @@ export const schemas = {
.optional(), .optional(),
// TODO: Add regex to validate // TODO: Add regex to validate
content_type: z.string().optional().default("text/plain"), content_type: z.string().optional().default("text/plain"),
"media_ids[]": z media_ids: z
.array(z.string().uuid()) .array(z.string().uuid())
.max(config.validation.max_media_attachments) .max(config.validation.max_media_attachments)
.optional(), .optional(),
@ -83,6 +83,7 @@ export default (app: Hono) =>
app.on( app.on(
meta.allowedMethods, meta.allowedMethods,
meta.route, meta.route,
qs(),
zValidator("form", schemas.form, handleZodError), zValidator("form", schemas.form, handleZodError),
auth(meta.auth), auth(meta.auth),
async (context) => { async (context) => {
@ -92,7 +93,7 @@ export default (app: Hono) =>
const { const {
status, status,
"media_ids[]": media_ids, media_ids,
"poll[options]": options, "poll[options]": options,
in_reply_to_id, in_reply_to_id,
quote_id, quote_id,

View file

@ -29,13 +29,11 @@ export const schemas = {
.enum(["none", "login", "consent", "select_account"]) .enum(["none", "login", "consent", "select_account"])
.optional() .optional()
.default("none"), .default("none"),
max_age: z max_age: z.coerce
.number() .number()
.int() .int()
.optional() .optional()
.default(60 * 60 * 24 * 7), .default(60 * 60 * 24 * 7),
}),
body: z.object({
scope: z.string().optional(), scope: z.string().optional(),
redirect_uri: z.string().url().optional(), redirect_uri: z.string().url().optional(),
response_type: z.enum([ response_type: z.enum([
@ -77,7 +75,6 @@ export default (app: Hono) =>
meta.allowedMethods, meta.allowedMethods,
meta.route, meta.route,
zValidator("query", schemas.query, handleZodError), zValidator("query", schemas.query, handleZodError),
zValidator("json", schemas.body, handleZodError),
async (context) => { async (context) => {
const { const {
scope, scope,
@ -87,8 +84,8 @@ export default (app: Hono) =>
state, state,
code_challenge, code_challenge,
code_challenge_method, code_challenge_method,
} = context.req.valid("json"); } = context.req.valid("query");
const body = context.req.valid("json"); const body = context.req.valid("query");
const cookie = context.req.header("Cookie"); const cookie = context.req.header("Cookie");

View file

@ -20,7 +20,7 @@ export const meta = applyConfig({
}); });
export const schemas = { export const schemas = {
json: z.object({ form: z.object({
code: z.string().optional(), code: z.string().optional(),
code_verifier: z.string().optional(), code_verifier: z.string().optional(),
grant_type: z.enum([ grant_type: z.enum([
@ -63,10 +63,10 @@ export default (app: Hono) =>
app.on( app.on(
meta.allowedMethods, meta.allowedMethods,
meta.route, meta.route,
zValidator("json", schemas.json, handleZodError), zValidator("form", schemas.form, handleZodError),
async (context) => { async (context) => {
const { grant_type, code, redirect_uri, client_id, client_secret } = const { grant_type, code, redirect_uri, client_id, client_secret } =
context.req.valid("json"); context.req.valid("form");
switch (grant_type) { switch (grant_type) {
case "authorization_code": { case "authorization_code": {

View file

@ -1,21 +0,0 @@
import type { Config } from "config-manager";
import type { Hono } from "hono";
export const createServer = (config: Config, app: Hono) =>
Bun.serve({
port: config.http.bind_port,
tls: config.http.tls.enabled
? {
key: Bun.file(config.http.tls.key),
cert: Bun.file(config.http.tls.cert),
passphrase: config.http.tls.passphrase,
ca: config.http.tls.ca
? Bun.file(config.http.tls.ca)
: undefined,
}
: undefined,
hostname: config.http.bind || "0.0.0.0", // defaults to "0.0.0.0"
fetch(req, server) {
return app.fetch(req, { ip: server.requestIP(req) });
},
});

View file

@ -16,6 +16,12 @@ afterAll(async () => {
await deleteUsers(); await deleteUsers();
}); });
const getFormData = (object: Record<string, string | number | boolean>) =>
Object.keys(object).reduce((formData, key) => {
formData.append(key, String(object[key]));
return formData;
}, new FormData());
describe("API Tests", () => { describe("API Tests", () => {
describe("PATCH /api/v1/accounts/update_credentials", () => { describe("PATCH /api/v1/accounts/update_credentials", () => {
test("should update the authenticated user's display name", async () => { test("should update the authenticated user's display name", async () => {
@ -29,9 +35,8 @@ describe("API Tests", () => {
method: "PATCH", method: "PATCH",
headers: { headers: {
Authorization: `Bearer ${token.accessToken}`, Authorization: `Bearer ${token.accessToken}`,
"Content-Type": "application/json",
}, },
body: JSON.stringify({ body: getFormData({
display_name: "New Display Name", display_name: "New Display Name",
}), }),
}, },

View file

@ -1,7 +1,6 @@
import { afterAll, describe, expect, test } from "bun:test"; import { afterAll, describe, expect, test } from "bun:test";
import { config } from "config-manager"; import { config } from "config-manager";
import { getTestUsers, sendTestRequest, wrapRelativeUrl } from "~tests/utils"; import { getTestUsers, sendTestRequest, wrapRelativeUrl } from "~tests/utils";
import type { Account as APIAccount } from "~types/mastodon/account";
import type { AsyncAttachment as APIAsyncAttachment } from "~types/mastodon/async_attachment"; import type { AsyncAttachment as APIAsyncAttachment } from "~types/mastodon/async_attachment";
import type { Context as APIContext } from "~types/mastodon/context"; import type { Context as APIContext } from "~types/mastodon/context";
import type { Status as APIStatus } from "~types/mastodon/status"; import type { Status as APIStatus } from "~types/mastodon/status";
@ -60,13 +59,12 @@ describe("API Tests", () => {
method: "POST", method: "POST",
headers: { headers: {
Authorization: `Bearer ${token.accessToken}`, Authorization: `Bearer ${token.accessToken}`,
"Content-Type": "application/json",
}, },
body: JSON.stringify({ body: new URLSearchParams({
status: "Hello, world!", status: "Hello, world!",
visibility: "public", visibility: "public",
media_ids: [media1?.id], "media_ids[]": media1?.id ?? "",
federate: false, federate: "false",
}), }),
}, },
), ),
@ -108,13 +106,12 @@ describe("API Tests", () => {
method: "POST", method: "POST",
headers: { headers: {
Authorization: `Bearer ${token.accessToken}`, Authorization: `Bearer ${token.accessToken}`,
"Content-Type": "application/json",
}, },
body: JSON.stringify({ body: new URLSearchParams({
status: "This is a reply!", status: "This is a reply!",
visibility: "public", visibility: "public",
in_reply_to_id: status?.id, in_reply_to_id: status?.id ?? "",
federate: false, federate: "false",
}), }),
}, },
), ),

View file

@ -1,4 +1,5 @@
import { afterAll, describe, expect, test } from "bun:test"; import { afterAll, describe, expect, test } from "bun:test";
import { config } from "~packages/config-manager";
import type { Application as APIApplication } from "~types/mastodon/application"; import type { Application as APIApplication } from "~types/mastodon/application";
import type { Token as APIToken } from "~types/mastodon/token"; import type { Token as APIToken } from "~types/mastodon/token";
import { import {
@ -8,7 +9,7 @@ import {
wrapRelativeUrl, wrapRelativeUrl,
} from "./utils"; } from "./utils";
const base_url = "http://lysand.localhost:8080"; //config.http.base_url; const base_url = config.http.base_url;
let client_id: string; let client_id: string;
let client_secret: string; let client_secret: string;
@ -19,8 +20,8 @@ const { users, passwords, deleteUsers } = await getTestUsers(1);
afterAll(async () => { afterAll(async () => {
await deleteUsers(); await deleteUsers();
await deleteOldTestUsers();
}); });
describe("POST /api/v1/apps/", () => { describe("POST /api/v1/apps/", () => {
test("should create an application", async () => { test("should create an application", async () => {
const formData = new FormData(); const formData = new FormData();
@ -31,7 +32,7 @@ describe("POST /api/v1/apps/", () => {
formData.append("scopes", "read write"); formData.append("scopes", "read write");
const response = await sendTestRequest( const response = await sendTestRequest(
new Request(wrapRelativeUrl("/api/v1/apps/", base_url), { new Request(new URL("/api/v1/apps", config.http.base_url), {
method: "POST", method: "POST",
body: formData, body: formData,
}), }),
@ -66,8 +67,8 @@ describe("POST /api/auth/login/", () => {
const response = await sendTestRequest( const response = await sendTestRequest(
new Request( new Request(
wrapRelativeUrl( new URL(
`/api/auth/login/?client_id=${client_id}&redirect_uri=https://example.com&response_type=code&scope=read+write`, `/api/auth/login?client_id=${client_id}&redirect_uri=https://example.com&response_type=code&scope=read+write`,
base_url, base_url,
), ),
{ {
@ -77,8 +78,6 @@ describe("POST /api/auth/login/", () => {
), ),
); );
console.log(await response.text());
expect(response.status).toBe(302); expect(response.status).toBe(302);
expect(response.headers.get("location")).toBeDefined(); expect(response.headers.get("location")).toBeDefined();
const locationHeader = new URL( const locationHeader = new URL(
@ -102,24 +101,28 @@ describe("POST /api/auth/login/", () => {
}); });
}); });
describe("POST /oauth/authorize/", () => { describe("GET /oauth/authorize/", () => {
test("should get a code", async () => { test("should get a code", async () => {
const response = await sendTestRequest( const response = await sendTestRequest(
new Request(wrapRelativeUrl("/oauth/authorize", base_url), { new Request(
method: "POST", new URL(
headers: { `/oauth/authorize?${new URLSearchParams({
Cookie: `jwt=${jwt}`, client_id,
"Content-Type": "application/x-www-form-urlencoded", client_secret,
redirect_uri: "https://example.com",
response_type: "code",
scope: "read write",
max_age: "604800",
})}`,
base_url,
),
{
method: "POST",
headers: {
Cookie: `jwt=${jwt}`,
},
}, },
body: new URLSearchParams({ ),
client_id,
client_secret,
redirect_uri: "https://example.com",
response_type: "code",
scope: "read write",
max_age: "604800",
}),
}),
); );
expect(response.status).toBe(302); expect(response.status).toBe(302);
@ -138,7 +141,7 @@ describe("POST /oauth/authorize/", () => {
describe("POST /oauth/token/", () => { describe("POST /oauth/token/", () => {
test("should get an access token", async () => { test("should get an access token", async () => {
const response = await sendTestRequest( const response = await sendTestRequest(
new Request(wrapRelativeUrl("/oauth/token/", base_url), { new Request(wrapRelativeUrl("/oauth/token", base_url), {
method: "POST", method: "POST",
headers: { headers: {
Authorization: `Bearer ${jwt}`, Authorization: `Bearer ${jwt}`,

View file

@ -147,6 +147,8 @@ export const qsQuery = () => {
// @ts-ignore Very bad hack // @ts-ignore Very bad hack
context.req.query = () => parsed; context.req.query = () => parsed;
// @ts-ignore I'm so sorry for this
context.req.queries = () => parsed;
await next(); await next();
}); });
}; };