diff --git a/index.ts b/index.ts index ad8651e0..938af1f5 100644 --- a/index.ts +++ b/index.ts @@ -11,6 +11,7 @@ import { bait } from "~middlewares/bait"; import { boundaryCheck } from "~middlewares/boundary-check"; import { ipBans } from "~middlewares/ip-bans"; import { logger } from "~middlewares/logger"; +import { urlCheck } from "~middlewares/url-check"; import { Note } from "~packages/database-interface/note"; import { handleGlitchRequest } from "~packages/glitch-server/main"; import { routes } from "~routes"; @@ -117,6 +118,7 @@ app.use(agentBans); app.use(bait); app.use(logger); app.use(boundaryCheck); +app.use(urlCheck); // Inject own filesystem router for (const [route, path] of Object.entries(routes)) { diff --git a/middlewares/url-check.ts b/middlewares/url-check.ts new file mode 100644 index 00000000..dc42512e --- /dev/null +++ b/middlewares/url-check.ts @@ -0,0 +1,17 @@ +import { errorResponse } from "@response"; +import { createMiddleware } from "hono/factory"; +import { config } from "~packages/config-manager"; + +export const urlCheck = createMiddleware(async (context, next) => { + // Check that request URL matches base_url + const baseUrl = new URL(config.http.base_url); + + if (new URL(context.req.url).origin !== baseUrl.origin) { + return errorResponse( + `Request URL ${context.req.url} does not match base URL ${baseUrl.origin}`, + 400, + ); + } + + await next(); +}); diff --git a/server/api/users/:uuid/inbox/index.ts b/server/api/users/:uuid/inbox/index.ts index afb212f4..68a73fe7 100644 --- a/server/api/users/:uuid/inbox/index.ts +++ b/server/api/users/:uuid/inbox/index.ts @@ -1,7 +1,11 @@ import { applyConfig, handleZodError } from "@api"; import { zValidator } from "@hono/zod-validator"; import { dualLogger } from "@loggers"; -import { EntityValidator, SignatureValidator } from "@lysand-org/federation"; +import { + EntityValidator, + type HttpVerb, + SignatureValidator, +} from "@lysand-org/federation"; import { errorResponse, jsonResponse, response } from "@response"; import type { SocketAddress } from "bun"; import { eq } from "drizzle-orm"; @@ -53,6 +57,8 @@ export default (app: Hono) => async (context) => { const { uuid } = context.req.valid("param"); const { signature, date } = context.req.valid("header"); + const body: typeof EntityValidator.$Entity = + await context.req.valid("json"); const user = await User.fromId(uuid); @@ -105,7 +111,13 @@ export default (app: Hono) => ); const isValid = await validator - .validate(context.req.raw) + .validate( + signature, + new Date(Date.parse(date)), + context.req.method as HttpVerb, + new URL(context.req.url), + await context.req.text(), + ) .catch((e) => { dualLogger.logError( LogLevel.ERROR, @@ -121,8 +133,6 @@ export default (app: Hono) => } const validator = new EntityValidator(); - const body: typeof EntityValidator.$Entity = - await context.req.valid("json"); try { // Add sent data to database diff --git a/tests/api.test.ts b/tests/api.test.ts index cd992e03..09a11815 100644 --- a/tests/api.test.ts +++ b/tests/api.test.ts @@ -35,4 +35,29 @@ describe("API Tests", () => { expect(data.error).toBeString(); expect(data.error).toContain("https://stackoverflow.com"); }); + + test("try sending a request with a different origin", async () => { + if (new URL(config.http.base_url).protocol === "http:") { + return; + } + + const response = await sendTestRequest( + new Request( + new URL( + "/api/v1/instance", + base_url.replace("https://", "http://"), + ), + { + method: "GET", + headers: { + Authorization: `Bearer ${tokens[0].accessToken}`, + }, + }, + ), + ); + + expect(response.status).toBe(400); + const data = await response.json(); + expect(data.error).toContain("does not match base URL"); + }); });