refactor: workers

This commit is contained in:
DevMiner 2024-08-29 00:46:32 +02:00
parent 02ad720439
commit 8fa44e5f72
20 changed files with 764 additions and 465 deletions

View file

@ -1,13 +1,24 @@
package config
import (
"git.devminer.xyz/devminer/unitel"
"github.com/joho/godotenv"
"github.com/rs/zerolog/log"
"net/url"
"os"
"regexp"
"slices"
"strconv"
"strings"
"git.devminer.xyz/devminer/unitel"
"github.com/joho/godotenv"
"github.com/rs/zerolog/log"
)
type Mode string
const (
ModeCombined Mode = "combined"
ModeWeb Mode = "web"
ModeConsumer Mode = "consumer"
)
type Config struct {
@ -25,6 +36,9 @@ type Config struct {
NATSURI string
NATSStreamName string
Mode Mode
Consumers []string
DatabaseURI string
Telemetry unitel.Opts
@ -62,6 +76,13 @@ func Load() {
Msg("Both VERSIA_TLS_KEY and VERSIA_TLS_CERT have to be set if you want to use in-process TLS termination.")
}
mode := getEnvStrOneOf("VERSIA_MODE", ModeCombined, ModeCombined, ModeWeb, ModeConsumer)
var consumers []string
if raw := optionalEnvStr("VERSIA_TQ_CUSTOMERS"); raw != nil {
consumers = strings.Split(*raw, ",")
}
C = Config{
Port: getEnvInt("VERSIA_PORT", 80),
TLSCert: tlsCert,
@ -76,13 +97,15 @@ func Load() {
NATSURI: os.Getenv("NATS_URI"),
NATSStreamName: getEnvStr("NATS_STREAM_NAME", "versia-go"),
DatabaseURI: os.Getenv("DATABASE_URI"),
Mode: mode,
Consumers: consumers,
DatabaseURI: os.Getenv("DATABASE_URI"),
ForwardTracesTo: forwardTracesTo,
Telemetry: unitel.ParseOpts("versia-go"),
}
return
}
func optionalEnvStr(key string) *string {
@ -93,6 +116,18 @@ func optionalEnvStr(key string) *string {
return &value
}
func getEnvBool(key string, default_ bool) bool {
if value, ok := os.LookupEnv(key); ok {
b, err := strconv.ParseBool(value)
if err != nil {
panic(err)
}
return b
}
return default_
}
func getEnvStr(key, default_ string) string {
if value, ok := os.LookupEnv(key); ok {
return value
@ -113,3 +148,22 @@ func getEnvInt(key string, default_ int) int {
return default_
}
func getEnvStrOneOf[T ~string](key string, default_ T, enum ...T) T {
if value, ok := os.LookupEnv(key); ok {
if !slices.Contains(enum, T(value)) {
sb := strings.Builder{}
sb.WriteString(key)
sb.WriteString(" can only be one of ")
for _, v := range enum {
sb.WriteString(string(v))
}
panic(sb.String())
}
return T(value)
}
return default_
}

View file

@ -89,6 +89,10 @@ func (i *ManagerImpl) Atomic(ctx context.Context, fn func(ctx context.Context, t
return tx.Finish()
}
func (i *ManagerImpl) Ping() error {
return i.db.Ping()
}
func (i *ManagerImpl) Users() repository.UserRepository {
return i.users
}

View file

@ -51,6 +51,7 @@ type InstanceMetadataRepository interface {
type Manager interface {
Atomic(ctx context.Context, fn func(ctx context.Context, tx Manager) error) error
Ping() error
Users() UserRepository
Notes() NoteRepository

View file

@ -2,6 +2,7 @@ package service
import (
"context"
"github.com/gofiber/fiber/v2"
"github.com/versia-pub/versia-go/internal/repository"
"github.com/versia-pub/versia-go/pkg/versia"
@ -57,7 +58,7 @@ type InstanceMetadataService interface {
}
type TaskService interface {
ScheduleTask(ctx context.Context, type_ string, data any) error
ScheduleNoteTask(ctx context.Context, type_ string, data any) error
}
type RequestSigner interface {

View file

@ -2,17 +2,18 @@ package svc_impls
import (
"context"
"slices"
"github.com/versia-pub/versia-go/internal/repository"
"github.com/versia-pub/versia-go/internal/service"
task_dtos "github.com/versia-pub/versia-go/internal/task/dtos"
"github.com/versia-pub/versia-go/pkg/versia"
"slices"
"git.devminer.xyz/devminer/unitel"
"github.com/go-logr/logr"
"github.com/google/uuid"
"github.com/versia-pub/versia-go/internal/api_schema"
"github.com/versia-pub/versia-go/internal/entity"
"github.com/versia-pub/versia-go/internal/tasks"
)
var _ service.NoteService = (*NoteServiceImpl)(nil)
@ -69,7 +70,7 @@ func (i NoteServiceImpl) CreateNote(ctx context.Context, req api_schema.CreateNo
return err
}
if err := i.taskService.ScheduleTask(ctx, tasks.FederateNote, tasks.FederateNoteData{NoteID: n.ID}); err != nil {
if err := i.taskService.ScheduleNoteTask(ctx, task_dtos.FederateNote, task_dtos.FederateNoteData{NoteID: n.ID}); err != nil {
return err
}

View file

@ -2,7 +2,9 @@ package svc_impls
import (
"context"
"github.com/versia-pub/versia-go/internal/service"
"github.com/versia-pub/versia-go/internal/task"
"git.devminer.xyz/devminer/unitel"
"github.com/go-logr/logr"
@ -12,22 +14,22 @@ import (
var _ service.TaskService = (*TaskServiceImpl)(nil)
type TaskServiceImpl struct {
client *taskqueue.Client
manager task.Manager
telemetry *unitel.Telemetry
log logr.Logger
}
func NewTaskServiceImpl(client *taskqueue.Client, telemetry *unitel.Telemetry, log logr.Logger) *TaskServiceImpl {
func NewTaskServiceImpl(manager task.Manager, telemetry *unitel.Telemetry, log logr.Logger) *TaskServiceImpl {
return &TaskServiceImpl{
client: client,
manager: manager,
telemetry: telemetry,
log: log,
}
}
func (i TaskServiceImpl) ScheduleTask(ctx context.Context, type_ string, data any) error {
func (i TaskServiceImpl) ScheduleNoteTask(ctx context.Context, type_ string, data any) error {
s := i.telemetry.StartSpan(ctx, "function", "svc_impls/TaskServiceImpl.ScheduleTask")
defer s.End()
ctx = s.Context()
@ -38,7 +40,7 @@ func (i TaskServiceImpl) ScheduleTask(ctx context.Context, type_ string, data an
return err
}
if err := i.client.Submit(ctx, t); err != nil {
if err := i.manager.Notes().Submit(ctx, t); err != nil {
i.log.Error(err, "Failed to schedule task", "type", type_, "taskID", t.ID)
return err
}

View file

@ -0,0 +1,11 @@
package task_dtos
import "github.com/google/uuid"
const (
FederateNote = "federate_note"
)
type FederateNoteData struct {
NoteID uuid.UUID `json:"noteID"`
}

20
internal/task/handler.go Normal file
View file

@ -0,0 +1,20 @@
package task
import (
"context"
"github.com/versia-pub/versia-go/pkg/taskqueue"
)
type Manager interface {
Notes() NoteHandler
}
type Handler interface {
Register(*taskqueue.Set)
Submit(context.Context, taskqueue.Task) error
}
type NoteHandler interface {
Submit(context.Context, taskqueue.Task) error
}

View file

@ -0,0 +1,11 @@
package task_impls
import "git.devminer.xyz/devminer/unitel"
type baseHandler struct {
telemetry *unitel.Telemetry
}
func newBaseHandler() *baseHandler {
return &baseHandler{}
}

View file

@ -0,0 +1,29 @@
package task_impls
import (
"git.devminer.xyz/devminer/unitel"
"github.com/go-logr/logr"
"github.com/versia-pub/versia-go/internal/task"
)
var _ task.Manager = (*Manager)(nil)
type Manager struct {
notes *NoteHandler
telemetry *unitel.Telemetry
log logr.Logger
}
func NewManager(notes *NoteHandler, telemetry *unitel.Telemetry, log logr.Logger) *Manager {
return &Manager{
notes: notes,
telemetry: telemetry,
log: log,
}
}
func (m *Manager) Notes() task.NoteHandler {
return m.notes
}

View file

@ -0,0 +1,97 @@
package task_impls
import (
"context"
"github.com/versia-pub/versia-go/internal/entity"
"github.com/versia-pub/versia-go/internal/repository"
"github.com/versia-pub/versia-go/internal/service"
"github.com/versia-pub/versia-go/internal/task"
task_dtos "github.com/versia-pub/versia-go/internal/task/dtos"
"github.com/versia-pub/versia-go/internal/utils"
"git.devminer.xyz/devminer/unitel"
"github.com/go-logr/logr"
"github.com/versia-pub/versia-go/pkg/taskqueue"
)
var _ task.Handler = (*NoteHandler)(nil)
type NoteHandler struct {
federationService service.FederationService
repositories repository.Manager
telemetry *unitel.Telemetry
log logr.Logger
set *taskqueue.Set
}
func NewNoteHandler(federationService service.FederationService, repositories repository.Manager, telemetry *unitel.Telemetry, log logr.Logger) *NoteHandler {
return &NoteHandler{
federationService: federationService,
repositories: repositories,
telemetry: telemetry,
log: log,
}
}
func (t *NoteHandler) Start(ctx context.Context) error {
consumer := t.set.Consumer("note-handler")
return consumer.Start(ctx)
}
func (t *NoteHandler) Register(s *taskqueue.Set) {
t.set = s
s.RegisterHandler(task_dtos.FederateNote, utils.ParseTask(t.FederateNote))
}
func (t *NoteHandler) Submit(ctx context.Context, task taskqueue.Task) error {
s := t.telemetry.StartSpan(ctx, "function", "task_impls/NoteHandler.Submit")
defer s.End()
ctx = s.Context()
return t.set.Submit(ctx, task)
}
func (t *NoteHandler) FederateNote(ctx context.Context, data task_dtos.FederateNoteData) error {
s := t.telemetry.StartSpan(ctx, "function", "task_impls/NoteHandler.FederateNote")
defer s.End()
ctx = s.Context()
var n *entity.Note
if err := t.repositories.Atomic(ctx, func(ctx context.Context, tx repository.Manager) error {
var err error
n, err = tx.Notes().GetByID(ctx, data.NoteID)
if err != nil {
return err
}
if n == nil {
t.log.V(-1).Info("Could not find note", "id", data.NoteID)
return nil
}
for _, uu := range n.Mentions {
if !uu.IsRemote {
t.log.V(2).Info("User is not remote", "user", uu.ID)
continue
}
res, err := t.federationService.SendToInbox(ctx, n.Author, &uu, n.ToVersia())
if err != nil {
t.log.Error(err, "Failed to send note to remote user", "res", res, "user", uu.ID)
} else {
t.log.V(2).Info("Sent note to remote user", "res", res, "user", uu.ID)
}
}
return nil
}); err != nil {
return err
}
return nil
}

View file

@ -1,11 +0,0 @@
package tasks
import "context"
type FederateFollowData struct {
FollowID string `json:"followID"`
}
func (t *Handler) FederateFollow(ctx context.Context, data FederateNoteData) error {
return nil
}

View file

@ -1,52 +0,0 @@
package tasks
import (
"context"
"github.com/versia-pub/versia-go/internal/repository"
"github.com/google/uuid"
"github.com/versia-pub/versia-go/internal/entity"
)
type FederateNoteData struct {
NoteID uuid.UUID `json:"noteID"`
}
func (t *Handler) FederateNote(ctx context.Context, data FederateNoteData) error {
s := t.telemetry.StartSpan(ctx, "function", "tasks/Handler.FederateNote")
defer s.End()
ctx = s.Context()
var n *entity.Note
if err := t.repositories.Atomic(ctx, func(ctx context.Context, tx repository.Manager) error {
var err error
n, err = tx.Notes().GetByID(ctx, data.NoteID)
if err != nil {
return err
}
if n == nil {
t.log.V(-1).Info("Could not find note", "id", data.NoteID)
return nil
}
for _, uu := range n.Mentions {
if !uu.IsRemote {
t.log.V(2).Info("User is not remote", "user", uu.ID)
continue
}
res, err := t.federationService.SendToInbox(ctx, n.Author, &uu, n.ToVersia())
if err != nil {
t.log.Error(err, "Failed to send note to remote user", "res", res, "user", uu.ID)
} else {
t.log.V(2).Info("Sent note to remote user", "res", res, "user", uu.ID)
}
}
return nil
}); err != nil {
return err
}
return nil
}

View file

@ -1,53 +0,0 @@
package tasks
import (
"context"
"encoding/json"
"github.com/versia-pub/versia-go/internal/repository"
"github.com/versia-pub/versia-go/internal/service"
"git.devminer.xyz/devminer/unitel"
"github.com/go-logr/logr"
"github.com/versia-pub/versia-go/pkg/taskqueue"
)
const (
FederateNote = "federate_note"
FederateFollow = "federate_follow"
)
type Handler struct {
federationService service.FederationService
repositories repository.Manager
telemetry *unitel.Telemetry
log logr.Logger
}
func NewHandler(federationService service.FederationService, repositories repository.Manager, telemetry *unitel.Telemetry, log logr.Logger) *Handler {
return &Handler{
federationService: federationService,
repositories: repositories,
telemetry: telemetry,
log: log,
}
}
func (t *Handler) Register(tq *taskqueue.Client) {
tq.RegisterHandler(FederateNote, parse(t.FederateNote))
tq.RegisterHandler(FederateFollow, parse(t.FederateFollow))
}
func parse[T any](handler func(context.Context, T) error) func(context.Context, taskqueue.Task) error {
return func(ctx context.Context, task taskqueue.Task) error {
var data T
if err := json.Unmarshal(task.Payload, &data); err != nil {
return err
}
return handler(ctx, data)
}
}

19
internal/utils/tasks.go Normal file
View file

@ -0,0 +1,19 @@
package utils
import (
"context"
"encoding/json"
"github.com/versia-pub/versia-go/pkg/taskqueue"
)
func ParseTask[T any](handler func(context.Context, T) error) func(context.Context, taskqueue.Task) error {
return func(ctx context.Context, task taskqueue.Task) error {
var data T
if err := json.Unmarshal(task.Payload, &data); err != nil {
return err
}
return handler(ctx, data)
}
}

234
main.go
View file

@ -7,42 +7,35 @@ import (
"crypto/tls"
"database/sql"
"database/sql/driver"
"fmt"
"git.devminer.xyz/devminer/unitel/unitelhttp"
"git.devminer.xyz/devminer/unitel/unitelsql"
"github.com/versia-pub/versia-go/ent/instancemetadata"
"github.com/versia-pub/versia-go/internal/api_schema"
"github.com/versia-pub/versia-go/internal/handlers/follow_handler"
"github.com/versia-pub/versia-go/internal/handlers/meta_handler"
"github.com/versia-pub/versia-go/internal/handlers/note_handler"
"github.com/versia-pub/versia-go/internal/repository"
"github.com/versia-pub/versia-go/internal/repository/repo_impls"
"github.com/versia-pub/versia-go/internal/service/svc_impls"
"github.com/versia-pub/versia-go/internal/validators/val_impls"
"net/http"
"os"
"os/signal"
"slices"
"strings"
"sync"
"time"
"entgo.io/ent/dialect"
entsql "entgo.io/ent/dialect/sql"
"git.devminer.xyz/devminer/unitel"
"git.devminer.xyz/devminer/unitel/unitelhttp"
"git.devminer.xyz/devminer/unitel/unitelsql"
"github.com/go-logr/logr"
"github.com/go-logr/zerologr"
"github.com/gofiber/fiber/v2"
"github.com/gofiber/fiber/v2/middleware/cors"
pgx "github.com/jackc/pgx/v5/stdlib"
"github.com/nats-io/nats.go"
"github.com/rs/zerolog"
"github.com/rs/zerolog/log"
"github.com/versia-pub/versia-go/config"
"github.com/versia-pub/versia-go/ent"
"github.com/versia-pub/versia-go/ent/instancemetadata"
"github.com/versia-pub/versia-go/internal/database"
"github.com/versia-pub/versia-go/internal/handlers/user_handler"
"github.com/versia-pub/versia-go/internal/tasks"
"github.com/versia-pub/versia-go/internal/repository"
"github.com/versia-pub/versia-go/internal/repository/repo_impls"
"github.com/versia-pub/versia-go/internal/service/svc_impls"
"github.com/versia-pub/versia-go/internal/task"
"github.com/versia-pub/versia-go/internal/task/task_impls"
"github.com/versia-pub/versia-go/internal/utils"
"github.com/versia-pub/versia-go/internal/validators/val_impls"
"github.com/versia-pub/versia-go/pkg/taskqueue"
"modernc.org/sqlite"
)
@ -52,11 +45,9 @@ func init() {
log.Logger = log.Output(zerolog.ConsoleWriter{Out: os.Stderr})
}
func shouldPropagate(r *http.Request) bool {
return config.C.ForwardTracesTo.Match([]byte(r.URL.String()))
}
func main() {
rootCtx, cancelRoot := context.WithCancel(context.Background())
zerolog.SetGlobalLevel(zerolog.TraceLevel)
zerologr.NameFieldName = "logger"
zerologr.NameSeparator = "/"
@ -98,24 +89,27 @@ func main() {
}
log.Debug().Msg("Starting taskqueue client")
tq, err := taskqueue.NewClient(context.Background(), config.C.NATSStreamName, nc, tel, zerologr.New(&log.Logger).WithName("taskqueue-client"))
tq, err := taskqueue.NewClient(config.C.NATSStreamName, nc, tel, zerologr.New(&log.Logger).WithName("taskqueue-client"))
if err != nil {
log.Fatal().Err(err).Msg("Failed to create taskqueue client")
}
defer tq.Close()
log.Debug().Msg("Running schema migration")
if err := migrateDB(db, zerologr.New(&log.Logger).WithName("migrate-db"), tel); err != nil {
log.Fatal().Err(err).Msg("Failed to run schema migration")
}
log.Debug().Msg("Initializing instance")
if err := initInstance(db, tel); err != nil {
log.Fatal().Err(err).Msg("Failed to initialize instance")
}
// Stateless services
requestSigner := svc_impls.NewRequestSignerImpl(tel, zerologr.New(&log.Logger).WithName("request-signer"))
federationService := svc_impls.NewFederationServiceImpl(httpClient, tel, zerologr.New(&log.Logger).WithName("federation-service"))
taskService := svc_impls.NewTaskServiceImpl(tq, tel, zerologr.New(&log.Logger).WithName("task-service"))
// Manager
// Repositories
repos := repo_impls.NewManagerImpl(
db, tel, zerologr.New(&log.Logger).WithName("repositories"),
@ -134,103 +128,50 @@ func main() {
bodyValidator := val_impls.NewBodyValidator(zerologr.New(&log.Logger).WithName("validation-service"))
requestValidator := val_impls.NewRequestValidator(repos, tel, zerologr.New(&log.Logger).WithName("request-validator"))
// Task handlers
notes := task_impls.NewNoteHandler(federationService, repos, tel, zerologr.New(&log.Logger).WithName("task-note-handler"))
notesSet := registerTaskHandler(rootCtx, "notes", tq, notes)
taskManager := task_impls.NewManager(notes, tel, zerologr.New(&log.Logger).WithName("task-manager"))
// Services
taskService := svc_impls.NewTaskServiceImpl(taskManager, tel, zerologr.New(&log.Logger).WithName("task-service"))
userService := svc_impls.NewUserServiceImpl(repos, federationService, tel, zerologr.New(&log.Logger).WithName("user-service"))
noteService := svc_impls.NewNoteServiceImpl(federationService, taskService, repos, tel, zerologr.New(&log.Logger).WithName("note-service"))
followService := svc_impls.NewFollowServiceImpl(federationService, repos, tel, zerologr.New(&log.Logger).WithName("follow-service"))
inboxService := svc_impls.NewInboxService(federationService, repos, tel, zerologr.New(&log.Logger).WithName("inbox-service"))
instanceMetadataService := svc_impls.NewInstanceMetadataServiceImpl(federationService, repos, tel, zerologr.New(&log.Logger).WithName("instance-metadata-service"))
// Handlers
wg := sync.WaitGroup{}
userHandler := user_handler.New(federationService, requestSigner, userService, inboxService, bodyValidator, requestValidator, zerologr.New(&log.Logger).WithName("user-handler"))
noteHandler := note_handler.New(noteService, bodyValidator, requestSigner, zerologr.New(&log.Logger).WithName("notes-handler"))
followHandler := follow_handler.New(followService, federationService, zerologr.New(&log.Logger).WithName("follow-handler"))
metaHandler := meta_handler.New(instanceMetadataService, zerologr.New(&log.Logger).WithName("meta-handler"))
if config.C.Mode == config.ModeWeb || config.C.Mode == config.ModeCombined {
wg.Add(1)
go func() {
defer wg.Done()
// Initialization
if err := initServerActor(db, tel); err != nil {
log.Fatal().Err(err).Msg("Failed to initialize server actor")
if err := server(
rootCtx,
tel,
db,
nc,
federationService,
requestSigner,
bodyValidator,
requestValidator,
userService,
noteService,
followService,
instanceMetadataService,
inboxService,
); err != nil {
log.Fatal().Err(err).Msg("Failed to start server")
}
}()
}
web := fiber.New(fiber.Config{
ProxyHeader: "X-Forwarded-For",
ErrorHandler: fiberErrorHandler,
DisableStartupMessage: true,
AppName: "versia-go",
EnablePrintRoutes: true,
})
web.Use(cors.New(cors.Config{
AllowOriginsFunc: func(origin string) bool {
return true
},
AllowMethods: "GET,POST,PUT,DELETE,PATCH",
AllowHeaders: "Origin, Content-Type, Accept, Authorization, b3, traceparent, sentry-trace, baggage",
AllowCredentials: true,
ExposeHeaders: "",
MaxAge: 0,
}))
web.Use(unitelhttp.FiberMiddleware(tel, unitelhttp.FiberMiddlewareConfig{
Repanic: false,
WaitForDelivery: false,
Timeout: 5 * time.Second,
// host for incoming requests
TraceRequestHeaders: []string{"origin", "x-nonce", "x-signature", "x-signed-by", "sentry-trace", "sentry-baggage"},
// origin for outgoing requests
TraceResponseHeaders: []string{"host", "x-nonce", "x-signature", "x-signed-by", "sentry-trace", "sentry-baggage"},
IgnoredRoutes: []string{"/api/health"},
Logger: zerologr.New(&log.Logger).WithName("http-server"),
TracePropagator: shouldPropagate,
}))
web.Use(unitelhttp.RequestLogger(zerologr.New(&log.Logger).WithName("http-server"), true, true))
log.Debug().Msg("Registering handlers")
web.Get("/api/health", healthCheck(db, nc))
userHandler.Register(web.Group("/"))
noteHandler.Register(web.Group("/"))
followHandler.Register(web.Group("/"))
metaHandler.Register(web.Group("/"))
wg := sync.WaitGroup{}
wg.Add(2)
// TODO: Run these in separate processes, if wanted
go func() {
defer wg.Done()
log.Debug().Msg("Starting taskqueue consumer")
tasks.NewHandler(federationService, repos, tel, zerologr.New(&log.Logger).WithName("task-handler")).
Register(tq)
if err := tq.StartConsumer(context.Background(), "consumer"); err != nil {
log.Fatal().Err(err).Msg("failed to start taskqueue client")
}
}()
go func() {
defer wg.Done()
log.Debug().Msg("Starting server")
addr := fmt.Sprintf(":%d", config.C.Port)
var err error
if config.C.TLSKey != nil {
err = web.ListenTLS(addr, *config.C.TLSCert, *config.C.TLSKey)
} else {
err = web.Listen(addr)
}
if err != nil {
log.Fatal().Err(err).Msg("Failed to start server")
}
}()
maybeRunTaskHandler(rootCtx, "notes", notesSet, &wg)
signalCh := make(chan os.Signal, 1)
signal.Notify(signalCh, os.Interrupt)
@ -238,10 +179,7 @@ func main() {
log.Info().Msg("Shutting down")
tq.Close()
if err := web.Shutdown(); err != nil {
log.Error().Err(err).Msg("Failed to shutdown server")
}
cancelRoot()
wg.Wait()
}
@ -293,8 +231,8 @@ func migrateDB(db *ent.Client, log logr.Logger, telemetry *unitel.Telemetry) err
return nil
}
func initServerActor(db *ent.Client, telemetry *unitel.Telemetry) error {
s := telemetry.StartSpan(context.Background(), "function", "main.initServerActor")
func initInstance(db *ent.Client, telemetry *unitel.Telemetry) error {
s := telemetry.StartSpan(context.Background(), "function", "main.initInstance")
defer s.End()
ctx := s.Context()
@ -354,27 +292,47 @@ func initServerActor(db *ent.Client, telemetry *unitel.Telemetry) error {
return tx.Finish()
}
func healthCheck(db *ent.Client, nc *nats.Conn) fiber.Handler {
return func(c *fiber.Ctx) error {
dbWorking := true
if err := db.Ping(); err != nil {
log.Error().Err(err).Msg("Database healthcheck failed")
dbWorking = false
}
natsWorking := true
if status := nc.Status(); status != nats.CONNECTED {
log.Error().Str("status", status.String()).Msg("NATS healthcheck failed")
natsWorking = false
}
if dbWorking && natsWorking {
return c.SendString("lookin' good")
}
return api_schema.ErrInternalServerError(map[string]any{
"database": dbWorking,
"nats": natsWorking,
})
func registerTaskHandler[T task.Handler](ctx context.Context, name string, tq *taskqueue.Client, handler T) *taskqueue.Set {
s, err := tq.Set(ctx, name)
if err != nil {
log.Fatal().Err(err).Str("handler", name).Msg("Could not create taskset for task handler")
}
handler.Register(s)
return s
}
func maybeRunTaskHandler(ctx context.Context, name string, set *taskqueue.Set, wg *sync.WaitGroup) {
l := log.With().Str("handler", name).Logger()
if config.C.Mode == config.ModeWeb {
l.Warn().Strs("requested", config.C.Consumers).Msg("Not starting task handler, as this process is running in web mode")
return
}
if config.C.Mode == config.ModeConsumer && !slices.Contains(config.C.Consumers, name) {
l.Warn().Strs("requested", config.C.Consumers).Msg("Not starting task handler, as it wasn't requested")
return
}
wg.Add(1)
c := set.Consumer(name)
if err := c.Start(ctx); err != nil {
l.Fatal().Err(err).Msg("Could not start task handler")
}
l.Info().Msg("Started task handler")
go func() {
defer wg.Done()
<-ctx.Done()
l.Debug().Msg("Got signal to stop task handler")
c.Close()
l.Info().Msg("Stopped task handler")
}()
}

View file

@ -4,8 +4,7 @@ import (
"context"
"encoding/json"
"errors"
"strings"
"sync"
"fmt"
"time"
"git.devminer.xyz/devminer/unitel"
@ -53,12 +52,9 @@ func NewTask(type_ string, payload any) (Task, error) {
}, nil
}
type Handler func(ctx context.Context, task Task) error
type Client struct {
name string
subject string
handlers map[string][]Handler
name string
subject string
nc *nats.Conn
js jetstream.JetStream
@ -71,15 +67,27 @@ type Client struct {
log logr.Logger
}
func NewClient(ctx context.Context, streamName string, natsClient *nats.Conn, telemetry *unitel.Telemetry, log logr.Logger) (*Client, error) {
func NewClient(streamName string, natsClient *nats.Conn, telemetry *unitel.Telemetry, log logr.Logger) (*Client, error) {
js, err := jetstream.New(natsClient)
if err != nil {
return nil, err
}
s, err := js.CreateStream(ctx, jetstream.StreamConfig{
return &Client{
name: streamName,
js: js,
telemetry: telemetry,
log: log,
}, nil
}
func (c *Client) Set(ctx context.Context, name string) (*Set, error) {
streamName := fmt.Sprintf("%s_%s", c.name, name)
s, err := c.js.CreateStream(ctx, jetstream.StreamConfig{
Name: streamName,
Subjects: []string{streamName + ".*"},
MaxConsumers: -1,
MaxMsgs: -1,
Discard: jetstream.DiscardOld,
@ -89,7 +97,7 @@ func NewClient(ctx context.Context, streamName string, natsClient *nats.Conn, te
AllowDirect: true,
})
if errors.Is(err, nats.ErrStreamNameAlreadyInUse) {
s, err = js.Stream(ctx, streamName)
s, err = c.js.Stream(ctx, streamName)
if err != nil {
return nil, err
}
@ -97,190 +105,13 @@ func NewClient(ctx context.Context, streamName string, natsClient *nats.Conn, te
return nil, err
}
stopCh := make(chan struct{})
return &Set{
handlers: make(map[string][]TaskHandler),
c := &Client{
name: streamName,
subject: streamName + ".tasks",
handlers: map[string][]Handler{},
stopCh: stopCh,
closeOnce: sync.OnceFunc(func() {
close(stopCh)
}),
nc: natsClient,
js: js,
s: s,
telemetry: telemetry,
log: log,
}
return c, nil
}
func (c *Client) Close() {
c.closeOnce()
c.nc.Close()
}
func (c *Client) Submit(ctx context.Context, task Task) error {
s := c.telemetry.StartSpan(ctx, "queue.publish", "taskqueue/Client.Submit").
AddAttribute("messaging.destination.name", c.subject)
defer s.End()
ctx = s.Context()
s.AddAttribute("jobID", task.ID)
data, err := json.Marshal(c.newTaskWrapper(ctx, task))
if err != nil {
return err
}
s.AddAttribute("messaging.message.body.size", len(data))
msg, err := c.js.PublishMsg(ctx, &nats.Msg{Subject: c.subject, Data: data})
if err != nil {
return err
}
c.log.V(2).Info("Submitted task", "id", task.ID, "type", task.Type, "sequence", msg.Sequence)
s.AddAttribute("messaging.message.id", msg.Sequence)
return nil
}
func (c *Client) RegisterHandler(type_ string, handler Handler) {
c.log.V(2).Info("Registering handler", "type", type_)
if _, ok := c.handlers[type_]; !ok {
c.handlers[type_] = []Handler{}
}
c.handlers[type_] = append(c.handlers[type_], handler)
}
func (c *Client) StartConsumer(ctx context.Context, consumerGroup string) error {
c.log.Info("Starting consumer")
sub, err := c.js.CreateConsumer(ctx, c.name, jetstream.ConsumerConfig{
Durable: consumerGroup,
DeliverPolicy: jetstream.DeliverAllPolicy,
ReplayPolicy: jetstream.ReplayInstantPolicy,
AckPolicy: jetstream.AckExplicitPolicy,
FilterSubject: c.subject,
MaxWaiting: 1,
MaxAckPending: 1,
HeadersOnly: false,
MemoryStorage: false,
})
if err != nil {
return err
}
m, err := sub.Messages(jetstream.PullMaxMessages(1))
if err != nil {
return err
}
go func() {
for {
msg, err := m.Next()
if err != nil {
if errors.Is(err, jetstream.ErrMsgIteratorClosed) {
c.log.Info("Stopping")
return
}
c.log.Error(err, "Failed to get next message")
break
}
if err := c.handleTask(ctx, msg); err != nil {
c.log.Error(err, "Failed to handle task")
break
}
}
}()
go func() {
<-c.stopCh
m.Drain()
}()
return nil
}
func (c *Client) handleTask(ctx context.Context, msg jetstream.Msg) error {
msgMeta, err := msg.Metadata()
if err != nil {
return err
}
data := msg.Data()
var w taskWrapper
if err := json.Unmarshal(data, &w); err != nil {
if err := msg.Nak(); err != nil {
c.log.Error(err, "Failed to nak message")
}
return err
}
s := c.telemetry.StartSpan(
context.Background(),
"queue.process",
"taskqueue/Client.handleTask",
c.telemetry.ContinueFromMap(w.TraceInfo),
).
AddAttribute("messaging.destination.name", c.subject).
AddAttribute("messaging.message.id", msgMeta.Sequence.Stream).
AddAttribute("messaging.message.retry.count", msgMeta.NumDelivered).
AddAttribute("messaging.message.body.size", len(data)).
AddAttribute("messaging.message.receive.latency", time.Since(w.EnqueuedAt).Milliseconds())
defer s.End()
ctx = s.Context()
handlers, ok := c.handlers[w.Task.Type]
if !ok {
c.log.V(2).Info("No handler for task", "type", w.Task.Type)
return msg.Nak()
}
var errs CombinedError
for _, handler := range handlers {
if err := handler(ctx, w.Task); err != nil {
c.log.Error(err, "Handler failed", "type", w.Task.Type)
errs.Errors = append(errs.Errors, err)
}
}
if len(errs.Errors) > 0 {
if err := msg.Nak(); err != nil {
c.log.Error(err, "Failed to nak message")
errs.Errors = append(errs.Errors, err)
}
return errs
}
return msg.Ack()
}
type CombinedError struct {
Errors []error
}
func (e CombinedError) Error() string {
sb := strings.Builder{}
sb.WriteRune('[')
for i, err := range e.Errors {
if i > 0 {
sb.WriteRune(',')
}
sb.WriteString(err.Error())
}
sb.WriteRune(']')
return sb.String()
streamName: streamName,
c: c,
s: s,
log: c.log.WithName(fmt.Sprintf("taskset(%s)", name)),
telemetry: c.telemetry,
}, nil
}

20
pkg/taskqueue/errors.go Normal file
View file

@ -0,0 +1,20 @@
package taskqueue
import "strings"
type CombinedError struct {
Errors []error
}
func (e CombinedError) Error() string {
sb := strings.Builder{}
sb.WriteRune('[')
for i, err := range e.Errors {
if i > 0 {
sb.WriteRune(',')
}
sb.WriteString(err.Error())
}
sb.WriteRune(']')
return sb.String()
}

210
pkg/taskqueue/taskset.go Normal file
View file

@ -0,0 +1,210 @@
package taskqueue
import (
"context"
"encoding/json"
"errors"
"fmt"
"sync"
"time"
"git.devminer.xyz/devminer/unitel"
"github.com/go-logr/logr"
"github.com/nats-io/nats.go"
"github.com/nats-io/nats.go/jetstream"
)
type TaskHandler = func(ctx context.Context, task Task) error
type Set struct {
handlers map[string][]TaskHandler
streamName string
c *Client
s jetstream.Stream
log logr.Logger
telemetry *unitel.Telemetry
}
func (t *Set) RegisterHandler(type_ string, handler TaskHandler) {
t.log.V(2).Info("Registering handler", "type", type_)
if _, ok := t.handlers[type_]; !ok {
t.handlers[type_] = []TaskHandler{}
}
t.handlers[type_] = append(t.handlers[type_], handler)
}
func (t *Set) Submit(ctx context.Context, task Task) error {
s := t.telemetry.StartSpan(ctx, "queue.publish", "taskqueue/TaskSet.Submit").
AddAttribute("messaging.destination.name", t.streamName)
defer s.End()
ctx = s.Context()
s.AddAttribute("jobID", task.ID)
data, err := json.Marshal(t.c.newTaskWrapper(ctx, task))
if err != nil {
return err
}
s.AddAttribute("messaging.message.body.size", len(data))
// TODO: Refactor
msg, err := t.c.js.PublishMsg(ctx, &nats.Msg{Subject: t.streamName, Data: data})
if err != nil {
return err
}
t.log.V(2).Info("Submitted task", "id", task.ID, "type", task.Type, "sequence", msg.Sequence)
s.AddAttribute("messaging.message.id", msg.Sequence)
return nil
}
func (t *Set) Consumer(name string) *Consumer {
stopCh := make(chan struct{})
stopOnce := sync.OnceFunc(func() {
close(stopCh)
})
return &Consumer{
stopCh: stopCh,
stopOnce: stopOnce,
name: name,
streamName: t.streamName,
telemetry: t.telemetry,
log: t.log.WithName(fmt.Sprintf("consumer(%s)", name)),
t: t,
}
}
type Consumer struct {
stopCh chan struct{}
stopOnce func()
name string
streamName string
telemetry *unitel.Telemetry
log logr.Logger
t *Set
}
func (c *Consumer) Close() {
c.stopOnce()
}
func (c *Consumer) Start(ctx context.Context) error {
c.log.Info("Starting consumer")
sub, err := c.t.c.js.CreateConsumer(ctx, c.streamName, jetstream.ConsumerConfig{
Durable: c.name,
DeliverPolicy: jetstream.DeliverAllPolicy,
ReplayPolicy: jetstream.ReplayInstantPolicy,
AckPolicy: jetstream.AckExplicitPolicy,
MaxWaiting: 1,
MaxAckPending: 1,
HeadersOnly: false,
MemoryStorage: false,
})
if err != nil {
return err
}
m, err := sub.Messages(jetstream.PullMaxMessages(1))
if err != nil {
return err
}
go c.handleMessages(m)
go func() {
<-ctx.Done()
c.Close()
}()
go func() {
<-c.stopCh
m.Drain()
}()
return nil
}
func (c *Consumer) handleMessages(m jetstream.MessagesContext) {
for {
msg, err := m.Next()
if err != nil {
if errors.Is(err, jetstream.ErrMsgIteratorClosed) {
c.log.Info("Stopping")
return
}
c.log.Error(err, "Failed to get next message")
break
}
if err := c.handleTask(msg); err != nil {
c.log.Error(err, "Failed to handle task")
break
}
}
}
func (c *Consumer) handleTask(msg jetstream.Msg) error {
msgMeta, err := msg.Metadata()
if err != nil {
return err
}
data := msg.Data()
var w taskWrapper
if err := json.Unmarshal(data, &w); err != nil {
if err := msg.Nak(); err != nil {
c.log.Error(err, "Failed to nak message")
}
return err
}
s := c.telemetry.StartSpan(
context.Background(),
"queue.process",
"taskqueue/Consumer.handleTask",
c.telemetry.ContinueFromMap(w.TraceInfo),
).
AddAttribute("messaging.destination.name", msg.Subject()).
AddAttribute("messaging.message.id", msgMeta.Sequence.Stream).
AddAttribute("messaging.message.retry.count", msgMeta.NumDelivered).
AddAttribute("messaging.message.body.size", len(data)).
AddAttribute("messaging.message.receive.latency", time.Since(w.EnqueuedAt).Milliseconds())
defer s.End()
ctx := s.Context()
handlers, ok := c.t.handlers[w.Task.Type]
if !ok {
c.log.V(2).Info("No handler for task", "type", w.Task.Type)
return msg.Nak()
}
var errs CombinedError
for _, handler := range handlers {
if err := handler(ctx, w.Task); err != nil {
c.log.Error(err, "Handler failed", "type", w.Task.Type)
errs.Errors = append(errs.Errors, err)
}
}
if len(errs.Errors) > 0 {
if err := msg.Nak(); err != nil {
c.log.Error(err, "Failed to nak message")
errs.Errors = append(errs.Errors, err)
}
return errs
}
return msg.Ack()
}

146
server.go Normal file
View file

@ -0,0 +1,146 @@
package main
import (
"context"
"fmt"
"git.devminer.xyz/devminer/unitel"
"git.devminer.xyz/devminer/unitel/unitelhttp"
"github.com/versia-pub/versia-go/internal/api_schema"
"github.com/versia-pub/versia-go/internal/handlers/follow_handler"
"github.com/versia-pub/versia-go/internal/handlers/meta_handler"
"github.com/versia-pub/versia-go/internal/handlers/note_handler"
"github.com/versia-pub/versia-go/internal/service"
"github.com/versia-pub/versia-go/internal/validators"
"net/http"
"sync"
"time"
"github.com/go-logr/zerologr"
"github.com/gofiber/fiber/v2"
"github.com/gofiber/fiber/v2/middleware/cors"
"github.com/nats-io/nats.go"
"github.com/rs/zerolog/log"
"github.com/versia-pub/versia-go/config"
"github.com/versia-pub/versia-go/ent"
"github.com/versia-pub/versia-go/internal/handlers/user_handler"
)
func shouldPropagate(r *http.Request) bool {
return config.C.ForwardTracesTo.Match([]byte(r.URL.String()))
}
func server(
ctx context.Context,
telemetry *unitel.Telemetry,
database *ent.Client,
natsConn *nats.Conn,
federationService service.FederationService,
requestSigner service.RequestSigner,
bodyValidator validators.BodyValidator,
requestValidator validators.RequestValidator,
userService service.UserService,
noteService service.NoteService,
followService service.FollowService,
instanceMetadataService service.InstanceMetadataService,
inboxService service.InboxService,
) error {
// Handlers
userHandler := user_handler.New(federationService, requestSigner, userService, inboxService, bodyValidator, requestValidator, zerologr.New(&log.Logger).WithName("user-handler"))
noteHandler := note_handler.New(noteService, bodyValidator, requestSigner, zerologr.New(&log.Logger).WithName("notes-handler"))
followHandler := follow_handler.New(followService, federationService, zerologr.New(&log.Logger).WithName("follow-handler"))
metaHandler := meta_handler.New(instanceMetadataService, zerologr.New(&log.Logger).WithName("meta-handler"))
// Initialization
web := fiber.New(fiber.Config{
ProxyHeader: "X-Forwarded-For",
ErrorHandler: fiberErrorHandler,
DisableStartupMessage: true,
AppName: "versia-go",
EnablePrintRoutes: true,
})
web.Use(cors.New(cors.Config{
AllowOriginsFunc: func(origin string) bool {
return true
},
AllowMethods: "GET,POST,PUT,DELETE,PATCH",
AllowHeaders: "Origin, Content-Type, Accept, Authorization, b3, traceparent, sentry-trace, baggage",
AllowCredentials: true,
ExposeHeaders: "",
MaxAge: 0,
}))
web.Use(unitelhttp.FiberMiddleware(telemetry, unitelhttp.FiberMiddlewareConfig{
Repanic: false,
WaitForDelivery: false,
Timeout: 5 * time.Second,
// host for incoming requests
TraceRequestHeaders: []string{"origin", "x-nonce", "x-signature", "x-signed-by", "sentry-trace", "sentry-baggage"},
// origin for outgoing requests
TraceResponseHeaders: []string{"host", "x-nonce", "x-signature", "x-signed-by", "sentry-trace", "sentry-baggage"},
IgnoredRoutes: []string{"/api/health"},
Logger: zerologr.New(&log.Logger).WithName("http-server"),
TracePropagator: shouldPropagate,
}))
web.Use(unitelhttp.RequestLogger(zerologr.New(&log.Logger).WithName("http-server"), true, true))
log.Debug().Msg("Registering handlers")
web.Get("/api/health", healthCheck(database, natsConn))
userHandler.Register(web.Group("/"))
noteHandler.Register(web.Group("/"))
followHandler.Register(web.Group("/"))
metaHandler.Register(web.Group("/"))
wg := sync.WaitGroup{}
wg.Add(2)
addr := fmt.Sprintf(":%d", config.C.Port)
log.Info().Str("addr", addr).Msg("Starting server")
go func() {
<-ctx.Done()
if err := web.Shutdown(); err != nil {
log.Error().Err(err).Msg("Failed to shutdown server")
}
}()
var err error
if config.C.TLSKey != nil {
err = web.ListenTLS(addr, *config.C.TLSCert, *config.C.TLSKey)
} else {
err = web.Listen(addr)
}
return err
}
func healthCheck(db *ent.Client, nc *nats.Conn) fiber.Handler {
return func(c *fiber.Ctx) error {
dbWorking := true
if err := db.Ping(); err != nil {
log.Error().Err(err).Msg("Database healthcheck failed")
dbWorking = false
}
natsWorking := true
if status := nc.Status(); status != nats.CONNECTED {
log.Error().Str("status", status.String()).Msg("NATS healthcheck failed")
natsWorking = false
}
if dbWorking && natsWorking {
return c.SendString("lookin' good")
}
return api_schema.ErrInternalServerError(map[string]any{
"database": dbWorking,
"nats": natsWorking,
})
}
}