Mirror of https://github.com/versia-pub/versia-go.git (synced 2025-12-06 06:28:18 +01:00)

Commit 8fa44e5f72 (parent 02ad720439)

refactor: workers
@@ -1,13 +1,24 @@
|
|||
package config
|
||||
|
||||
import (
|
||||
"git.devminer.xyz/devminer/unitel"
|
||||
"github.com/joho/godotenv"
|
||||
"github.com/rs/zerolog/log"
|
||||
"net/url"
|
||||
"os"
|
||||
"regexp"
|
||||
"slices"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"git.devminer.xyz/devminer/unitel"
|
||||
"github.com/joho/godotenv"
|
||||
"github.com/rs/zerolog/log"
|
||||
)
|
||||
|
||||
type Mode string
|
||||
|
||||
const (
|
||||
ModeCombined Mode = "combined"
|
||||
ModeWeb Mode = "web"
|
||||
ModeConsumer Mode = "consumer"
|
||||
)
|
||||
|
||||
type Config struct {
|
||||
|
|
@@ -25,6 +36,9 @@ type Config struct {
|
|||
NATSURI string
|
||||
NATSStreamName string
|
||||
|
||||
Mode Mode
|
||||
Consumers []string
|
||||
|
||||
DatabaseURI string
|
||||
|
||||
Telemetry unitel.Opts
|
||||
|
|
@@ -62,6 +76,13 @@ func Load() {
|
|||
Msg("Both VERSIA_TLS_KEY and VERSIA_TLS_CERT have to be set if you want to use in-process TLS termination.")
|
||||
}
|
||||
|
||||
mode := getEnvStrOneOf("VERSIA_MODE", ModeCombined, ModeCombined, ModeWeb, ModeConsumer)
|
||||
|
||||
var consumers []string
|
||||
if raw := optionalEnvStr("VERSIA_TQ_CUSTOMERS"); raw != nil {
|
||||
consumers = strings.Split(*raw, ",")
|
||||
}
|
||||
|
||||
C = Config{
|
||||
Port: getEnvInt("VERSIA_PORT", 80),
|
||||
TLSCert: tlsCert,
|
||||
|
|
@@ -76,13 +97,15 @@ func Load() {
|
|||
|
||||
NATSURI: os.Getenv("NATS_URI"),
|
||||
NATSStreamName: getEnvStr("NATS_STREAM_NAME", "versia-go"),
|
||||
DatabaseURI: os.Getenv("DATABASE_URI"),
|
||||
|
||||
Mode: mode,
|
||||
Consumers: consumers,
|
||||
|
||||
DatabaseURI: os.Getenv("DATABASE_URI"),
|
||||
|
||||
ForwardTracesTo: forwardTracesTo,
|
||||
Telemetry: unitel.ParseOpts("versia-go"),
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func optionalEnvStr(key string) *string {
|
||||
|
|
@@ -93,6 +116,18 @@ func optionalEnvStr(key string) *string {
|
|||
return &value
|
||||
}
|
||||
|
||||
func getEnvBool(key string, default_ bool) bool {
|
||||
if value, ok := os.LookupEnv(key); ok {
|
||||
b, err := strconv.ParseBool(value)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return b
|
||||
}
|
||||
|
||||
return default_
|
||||
}
|
||||
|
||||
func getEnvStr(key, default_ string) string {
|
||||
if value, ok := os.LookupEnv(key); ok {
|
||||
return value
|
||||
|
|
@@ -113,3 +148,22 @@ func getEnvInt(key string, default_ int) int {
|
|||
|
||||
return default_
|
||||
}
|
||||
|
||||
func getEnvStrOneOf[T ~string](key string, default_ T, enum ...T) T {
|
||||
if value, ok := os.LookupEnv(key); ok {
|
||||
if !slices.Contains(enum, T(value)) {
|
||||
sb := strings.Builder{}
|
||||
sb.WriteString(key)
|
||||
sb.WriteString(" can only be one of ")
|
||||
for _, v := range enum {
|
||||
sb.WriteString(string(v))
|
||||
}
|
||||
|
||||
panic(sb.String())
|
||||
}
|
||||
|
||||
return T(value)
|
||||
}
|
||||
|
||||
return default_
|
||||
}
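A minimal usage sketch, not part of the commit, based on the helpers above: getEnvStrOneOf returns the default when VERSIA_MODE is unset and panics when the value is outside the allowed set, and the consumer list is read as a comma-separated string (the variable is spelled VERSIA_TQ_CUSTOMERS in this commit).

// Sketch only: mirrors the mode selection done in Load() above.
func loadWorkerConfig() (Mode, []string) {
	// One of "combined" (default), "web" or "consumer"; anything else panics at startup.
	mode := getEnvStrOneOf("VERSIA_MODE", ModeCombined, ModeCombined, ModeWeb, ModeConsumer)

	// Names of the task handlers to run when mode == "consumer".
	var consumers []string
	if raw := optionalEnvStr("VERSIA_TQ_CUSTOMERS"); raw != nil {
		consumers = strings.Split(*raw, ",")
	}

	return mode, consumers
}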
|
||||
|
|
|
|||
|
|
@@ -89,6 +89,10 @@ func (i *ManagerImpl) Atomic(ctx context.Context, fn func(ctx context.Context, t
|
|||
return tx.Finish()
|
||||
}
|
||||
|
||||
func (i *ManagerImpl) Ping() error {
|
||||
return i.db.Ping()
|
||||
}
|
||||
|
||||
func (i *ManagerImpl) Users() repository.UserRepository {
|
||||
return i.users
|
||||
}
|
||||
|
|
|
|||
|
|
@@ -51,6 +51,7 @@ type InstanceMetadataRepository interface {
|
|||
|
||||
type Manager interface {
|
||||
Atomic(ctx context.Context, fn func(ctx context.Context, tx Manager) error) error
|
||||
Ping() error
|
||||
|
||||
Users() UserRepository
|
||||
Notes() NoteRepository
|
||||
|
|
|
|||
|
|
@@ -2,6 +2,7 @@ package service
|
|||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/gofiber/fiber/v2"
|
||||
"github.com/versia-pub/versia-go/internal/repository"
|
||||
"github.com/versia-pub/versia-go/pkg/versia"
|
||||
|
|
@@ -57,7 +58,7 @@ type InstanceMetadataService interface {
|
|||
}
|
||||
|
||||
type TaskService interface {
|
||||
ScheduleTask(ctx context.Context, type_ string, data any) error
|
||||
ScheduleNoteTask(ctx context.Context, type_ string, data any) error
|
||||
}
|
||||
|
||||
type RequestSigner interface {
|
||||
|
|
|
|||
|
|
@@ -2,17 +2,18 @@ package svc_impls
|
|||
|
||||
import (
|
||||
"context"
|
||||
"slices"
|
||||
|
||||
"github.com/versia-pub/versia-go/internal/repository"
|
||||
"github.com/versia-pub/versia-go/internal/service"
|
||||
task_dtos "github.com/versia-pub/versia-go/internal/task/dtos"
|
||||
"github.com/versia-pub/versia-go/pkg/versia"
|
||||
"slices"
|
||||
|
||||
"git.devminer.xyz/devminer/unitel"
|
||||
"github.com/go-logr/logr"
|
||||
"github.com/google/uuid"
|
||||
"github.com/versia-pub/versia-go/internal/api_schema"
|
||||
"github.com/versia-pub/versia-go/internal/entity"
|
||||
"github.com/versia-pub/versia-go/internal/tasks"
|
||||
)
|
||||
|
||||
var _ service.NoteService = (*NoteServiceImpl)(nil)
|
||||
|
|
@@ -69,7 +70,7 @@ func (i NoteServiceImpl) CreateNote(ctx context.Context, req api_schema.CreateNo
|
|||
return err
|
||||
}
|
||||
|
||||
if err := i.taskService.ScheduleTask(ctx, tasks.FederateNote, tasks.FederateNoteData{NoteID: n.ID}); err != nil {
|
||||
if err := i.taskService.ScheduleNoteTask(ctx, task_dtos.FederateNote, task_dtos.FederateNoteData{NoteID: n.ID}); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@@ -2,7 +2,9 @@ package svc_impls
|
|||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/versia-pub/versia-go/internal/service"
|
||||
"github.com/versia-pub/versia-go/internal/task"
|
||||
|
||||
"git.devminer.xyz/devminer/unitel"
|
||||
"github.com/go-logr/logr"
|
||||
|
|
@@ -12,22 +14,22 @@ import (
|
|||
var _ service.TaskService = (*TaskServiceImpl)(nil)
|
||||
|
||||
type TaskServiceImpl struct {
|
||||
client *taskqueue.Client
|
||||
manager task.Manager
|
||||
|
||||
telemetry *unitel.Telemetry
|
||||
log logr.Logger
|
||||
}
|
||||
|
||||
func NewTaskServiceImpl(client *taskqueue.Client, telemetry *unitel.Telemetry, log logr.Logger) *TaskServiceImpl {
|
||||
func NewTaskServiceImpl(manager task.Manager, telemetry *unitel.Telemetry, log logr.Logger) *TaskServiceImpl {
|
||||
return &TaskServiceImpl{
|
||||
client: client,
|
||||
manager: manager,
|
||||
|
||||
telemetry: telemetry,
|
||||
log: log,
|
||||
}
|
||||
}
|
||||
|
||||
func (i TaskServiceImpl) ScheduleTask(ctx context.Context, type_ string, data any) error {
|
||||
func (i TaskServiceImpl) ScheduleNoteTask(ctx context.Context, type_ string, data any) error {
|
||||
s := i.telemetry.StartSpan(ctx, "function", "svc_impls/TaskServiceImpl.ScheduleTask")
|
||||
defer s.End()
|
||||
ctx = s.Context()
|
||||
|
|
@@ -38,7 +40,7 @@ func (i TaskServiceImpl) ScheduleTask(ctx context.Context, type_ string, data an
|
|||
return err
|
||||
}
|
||||
|
||||
if err := i.client.Submit(ctx, t); err != nil {
|
||||
if err := i.manager.Notes().Submit(ctx, t); err != nil {
|
||||
i.log.Error(err, "Failed to schedule task", "type", type_, "taskID", t.ID)
|
||||
return err
|
||||
}
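A short usage sketch of the renamed method, matching the call site changed in NoteServiceImpl.CreateNote above; ctx, n and taskService are assumed to be in scope.

// Sketch: scheduling note federation through the manager-backed task service.
if err := taskService.ScheduleNoteTask(ctx, task_dtos.FederateNote, task_dtos.FederateNoteData{NoteID: n.ID}); err != nil {
	return err
}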
|
||||
|
|
|
|||
internal/task/dtos/note_dtos.go (new file, 11 lines)
@@ -0,0 +1,11 @@
package task_dtos

import "github.com/google/uuid"

const (
	FederateNote = "federate_note"
)

type FederateNoteData struct {
	NoteID uuid.UUID `json:"noteID"`
}
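A hedged sketch of how this DTO travels as a task payload, assuming taskqueue.NewTask JSON-encodes its payload argument (utils.ParseTask, added further down, decodes it back on the consumer side); noteID is a placeholder value.

// Sketch only: producer side of a federate_note task.
t, err := taskqueue.NewTask(task_dtos.FederateNote, task_dtos.FederateNoteData{NoteID: noteID})
if err != nil {
	return err
}
// t.Payload now carries {"noteID":"..."}; the consumer decodes it with utils.ParseTask
// before invoking the typed handler (see NoteHandler.Register below).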
internal/task/handler.go (new file, 20 lines)
@@ -0,0 +1,20 @@
package task

import (
	"context"

	"github.com/versia-pub/versia-go/pkg/taskqueue"
)

type Manager interface {
	Notes() NoteHandler
}

type Handler interface {
	Register(*taskqueue.Set)
	Submit(context.Context, taskqueue.Task) error
}

type NoteHandler interface {
	Submit(context.Context, taskqueue.Task) error
}
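These interfaces are the seam for adding further worker types; a purely hypothetical sketch of a second handler (FollowHandler, its task type and FederateFollow are illustrative names, not part of this commit):

// Hypothetical example, not in this commit.
type FollowHandler struct {
	set *taskqueue.Set
}

// Register satisfies Handler: keep the set and bind task types to typed handlers.
func (h *FollowHandler) Register(s *taskqueue.Set) {
	h.set = s
	s.RegisterHandler("federate_follow", utils.ParseTask(h.FederateFollow))
}

// Submit satisfies Handler (and a would-be Follows() accessor on Manager).
func (h *FollowHandler) Submit(ctx context.Context, t taskqueue.Task) error {
	return h.set.Submit(ctx, t)
}

func (h *FollowHandler) FederateFollow(ctx context.Context, data struct{ FollowID string }) error {
	return nil // actual federation logic would live here
}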
internal/task/task_impls/base.go (new file, 11 lines)
@@ -0,0 +1,11 @@
package task_impls

import "git.devminer.xyz/devminer/unitel"

type baseHandler struct {
	telemetry *unitel.Telemetry
}

func newBaseHandler() *baseHandler {
	return &baseHandler{}
}
internal/task/task_impls/manager.go (new file, 29 lines)
@@ -0,0 +1,29 @@
package task_impls

import (
	"git.devminer.xyz/devminer/unitel"
	"github.com/go-logr/logr"
	"github.com/versia-pub/versia-go/internal/task"
)

var _ task.Manager = (*Manager)(nil)

type Manager struct {
	notes *NoteHandler

	telemetry *unitel.Telemetry
	log       logr.Logger
}

func NewManager(notes *NoteHandler, telemetry *unitel.Telemetry, log logr.Logger) *Manager {
	return &Manager{
		notes: notes,

		telemetry: telemetry,
		log:       log,
	}
}

func (m *Manager) Notes() task.NoteHandler {
	return m.notes
}
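For orientation, the wiring that the main.go hunk further down performs with these types (logger construction abbreviated to placeholder variables):

// Sketch of the wiring in main(), per the main.go changes below.
notes := task_impls.NewNoteHandler(federationService, repos, tel, noteHandlerLog)
notesSet := registerTaskHandler(rootCtx, "notes", tq, notes) // creates the "notes" task set and registers the handler

taskManager := task_impls.NewManager(notes, tel, taskManagerLog)
taskService := svc_impls.NewTaskServiceImpl(taskManager, tel, taskServiceLog)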
internal/task/task_impls/note_handler.go (new file, 97 lines)
|
|
@@ -0,0 +1,97 @@
|
|||
package task_impls
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/versia-pub/versia-go/internal/entity"
|
||||
"github.com/versia-pub/versia-go/internal/repository"
|
||||
"github.com/versia-pub/versia-go/internal/service"
|
||||
"github.com/versia-pub/versia-go/internal/task"
|
||||
task_dtos "github.com/versia-pub/versia-go/internal/task/dtos"
|
||||
"github.com/versia-pub/versia-go/internal/utils"
|
||||
|
||||
"git.devminer.xyz/devminer/unitel"
|
||||
"github.com/go-logr/logr"
|
||||
"github.com/versia-pub/versia-go/pkg/taskqueue"
|
||||
)
|
||||
|
||||
var _ task.Handler = (*NoteHandler)(nil)
|
||||
|
||||
type NoteHandler struct {
|
||||
federationService service.FederationService
|
||||
|
||||
repositories repository.Manager
|
||||
|
||||
telemetry *unitel.Telemetry
|
||||
log logr.Logger
|
||||
set *taskqueue.Set
|
||||
}
|
||||
|
||||
func NewNoteHandler(federationService service.FederationService, repositories repository.Manager, telemetry *unitel.Telemetry, log logr.Logger) *NoteHandler {
|
||||
return &NoteHandler{
|
||||
federationService: federationService,
|
||||
|
||||
repositories: repositories,
|
||||
|
||||
telemetry: telemetry,
|
||||
log: log,
|
||||
}
|
||||
}
|
||||
|
||||
func (t *NoteHandler) Start(ctx context.Context) error {
|
||||
consumer := t.set.Consumer("note-handler")
|
||||
|
||||
return consumer.Start(ctx)
|
||||
}
|
||||
|
||||
func (t *NoteHandler) Register(s *taskqueue.Set) {
|
||||
t.set = s
|
||||
s.RegisterHandler(task_dtos.FederateNote, utils.ParseTask(t.FederateNote))
|
||||
}
|
||||
|
||||
func (t *NoteHandler) Submit(ctx context.Context, task taskqueue.Task) error {
|
||||
s := t.telemetry.StartSpan(ctx, "function", "task_impls/NoteHandler.Submit")
|
||||
defer s.End()
|
||||
ctx = s.Context()
|
||||
|
||||
return t.set.Submit(ctx, task)
|
||||
}
|
||||
|
||||
func (t *NoteHandler) FederateNote(ctx context.Context, data task_dtos.FederateNoteData) error {
|
||||
s := t.telemetry.StartSpan(ctx, "function", "task_impls/NoteHandler.FederateNote")
|
||||
defer s.End()
|
||||
ctx = s.Context()
|
||||
|
||||
var n *entity.Note
|
||||
if err := t.repositories.Atomic(ctx, func(ctx context.Context, tx repository.Manager) error {
|
||||
var err error
|
||||
n, err = tx.Notes().GetByID(ctx, data.NoteID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if n == nil {
|
||||
t.log.V(-1).Info("Could not find note", "id", data.NoteID)
|
||||
return nil
|
||||
}
|
||||
|
||||
for _, uu := range n.Mentions {
|
||||
if !uu.IsRemote {
|
||||
t.log.V(2).Info("User is not remote", "user", uu.ID)
|
||||
continue
|
||||
}
|
||||
|
||||
res, err := t.federationService.SendToInbox(ctx, n.Author, &uu, n.ToVersia())
|
||||
if err != nil {
|
||||
t.log.Error(err, "Failed to send note to remote user", "res", res, "user", uu.ID)
|
||||
} else {
|
||||
t.log.V(2).Info("Sent note to remote user", "res", res, "user", uu.ID)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
|
@@ -1,11 +0,0 @@
|
|||
package tasks
|
||||
|
||||
import "context"
|
||||
|
||||
type FederateFollowData struct {
|
||||
FollowID string `json:"followID"`
|
||||
}
|
||||
|
||||
func (t *Handler) FederateFollow(ctx context.Context, data FederateNoteData) error {
|
||||
return nil
|
||||
}
|
||||
|
|
@@ -1,52 +0,0 @@
|
|||
package tasks
|
||||
|
||||
import (
|
||||
"context"
|
||||
"github.com/versia-pub/versia-go/internal/repository"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/versia-pub/versia-go/internal/entity"
|
||||
)
|
||||
|
||||
type FederateNoteData struct {
|
||||
NoteID uuid.UUID `json:"noteID"`
|
||||
}
|
||||
|
||||
func (t *Handler) FederateNote(ctx context.Context, data FederateNoteData) error {
|
||||
s := t.telemetry.StartSpan(ctx, "function", "tasks/Handler.FederateNote")
|
||||
defer s.End()
|
||||
ctx = s.Context()
|
||||
|
||||
var n *entity.Note
|
||||
if err := t.repositories.Atomic(ctx, func(ctx context.Context, tx repository.Manager) error {
|
||||
var err error
|
||||
n, err = tx.Notes().GetByID(ctx, data.NoteID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if n == nil {
|
||||
t.log.V(-1).Info("Could not find note", "id", data.NoteID)
|
||||
return nil
|
||||
}
|
||||
|
||||
for _, uu := range n.Mentions {
|
||||
if !uu.IsRemote {
|
||||
t.log.V(2).Info("User is not remote", "user", uu.ID)
|
||||
continue
|
||||
}
|
||||
|
||||
res, err := t.federationService.SendToInbox(ctx, n.Author, &uu, n.ToVersia())
|
||||
if err != nil {
|
||||
t.log.Error(err, "Failed to send note to remote user", "res", res, "user", uu.ID)
|
||||
} else {
|
||||
t.log.V(2).Info("Sent note to remote user", "res", res, "user", uu.ID)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
|
@@ -1,53 +0,0 @@
|
|||
package tasks
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"github.com/versia-pub/versia-go/internal/repository"
|
||||
"github.com/versia-pub/versia-go/internal/service"
|
||||
|
||||
"git.devminer.xyz/devminer/unitel"
|
||||
"github.com/go-logr/logr"
|
||||
"github.com/versia-pub/versia-go/pkg/taskqueue"
|
||||
)
|
||||
|
||||
const (
|
||||
FederateNote = "federate_note"
|
||||
FederateFollow = "federate_follow"
|
||||
)
|
||||
|
||||
type Handler struct {
|
||||
federationService service.FederationService
|
||||
|
||||
repositories repository.Manager
|
||||
|
||||
telemetry *unitel.Telemetry
|
||||
log logr.Logger
|
||||
}
|
||||
|
||||
func NewHandler(federationService service.FederationService, repositories repository.Manager, telemetry *unitel.Telemetry, log logr.Logger) *Handler {
|
||||
return &Handler{
|
||||
federationService: federationService,
|
||||
|
||||
repositories: repositories,
|
||||
|
||||
telemetry: telemetry,
|
||||
log: log,
|
||||
}
|
||||
}
|
||||
|
||||
func (t *Handler) Register(tq *taskqueue.Client) {
|
||||
tq.RegisterHandler(FederateNote, parse(t.FederateNote))
|
||||
tq.RegisterHandler(FederateFollow, parse(t.FederateFollow))
|
||||
}
|
||||
|
||||
func parse[T any](handler func(context.Context, T) error) func(context.Context, taskqueue.Task) error {
|
||||
return func(ctx context.Context, task taskqueue.Task) error {
|
||||
var data T
|
||||
if err := json.Unmarshal(task.Payload, &data); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return handler(ctx, data)
|
||||
}
|
||||
}
|
||||
internal/utils/tasks.go (new file, 19 lines)
@@ -0,0 +1,19 @@
package utils

import (
	"context"
	"encoding/json"

	"github.com/versia-pub/versia-go/pkg/taskqueue"
)

func ParseTask[T any](handler func(context.Context, T) error) func(context.Context, taskqueue.Task) error {
	return func(ctx context.Context, task taskqueue.Task) error {
		var data T
		if err := json.Unmarshal(task.Payload, &data); err != nil {
			return err
		}

		return handler(ctx, data)
	}
}
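ParseTask adapts a typed handler to the generic taskqueue handler signature; a short sketch of registering one on a set, the same pattern NoteHandler.Register uses above:

// Sketch: the JSON payload is decoded into FederateNoteData before the handler runs.
set.RegisterHandler(task_dtos.FederateNote, utils.ParseTask(func(ctx context.Context, data task_dtos.FederateNoteData) error {
	// data.NoteID is already populated here.
	return nil
}))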
main.go (234 changed lines)
|
|
@@ -7,42 +7,35 @@ import (
|
|||
"crypto/tls"
|
||||
"database/sql"
|
||||
"database/sql/driver"
|
||||
"fmt"
|
||||
"git.devminer.xyz/devminer/unitel/unitelhttp"
|
||||
"git.devminer.xyz/devminer/unitel/unitelsql"
|
||||
"github.com/versia-pub/versia-go/ent/instancemetadata"
|
||||
"github.com/versia-pub/versia-go/internal/api_schema"
|
||||
"github.com/versia-pub/versia-go/internal/handlers/follow_handler"
|
||||
"github.com/versia-pub/versia-go/internal/handlers/meta_handler"
|
||||
"github.com/versia-pub/versia-go/internal/handlers/note_handler"
|
||||
"github.com/versia-pub/versia-go/internal/repository"
|
||||
"github.com/versia-pub/versia-go/internal/repository/repo_impls"
|
||||
"github.com/versia-pub/versia-go/internal/service/svc_impls"
|
||||
"github.com/versia-pub/versia-go/internal/validators/val_impls"
|
||||
"net/http"
|
||||
"os"
|
||||
"os/signal"
|
||||
"slices"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"entgo.io/ent/dialect"
|
||||
entsql "entgo.io/ent/dialect/sql"
|
||||
"git.devminer.xyz/devminer/unitel"
|
||||
"git.devminer.xyz/devminer/unitel/unitelhttp"
|
||||
"git.devminer.xyz/devminer/unitel/unitelsql"
|
||||
"github.com/go-logr/logr"
|
||||
"github.com/go-logr/zerologr"
|
||||
"github.com/gofiber/fiber/v2"
|
||||
"github.com/gofiber/fiber/v2/middleware/cors"
|
||||
pgx "github.com/jackc/pgx/v5/stdlib"
|
||||
"github.com/nats-io/nats.go"
|
||||
"github.com/rs/zerolog"
|
||||
"github.com/rs/zerolog/log"
|
||||
"github.com/versia-pub/versia-go/config"
|
||||
"github.com/versia-pub/versia-go/ent"
|
||||
"github.com/versia-pub/versia-go/ent/instancemetadata"
|
||||
"github.com/versia-pub/versia-go/internal/database"
|
||||
"github.com/versia-pub/versia-go/internal/handlers/user_handler"
|
||||
"github.com/versia-pub/versia-go/internal/tasks"
|
||||
"github.com/versia-pub/versia-go/internal/repository"
|
||||
"github.com/versia-pub/versia-go/internal/repository/repo_impls"
|
||||
"github.com/versia-pub/versia-go/internal/service/svc_impls"
|
||||
"github.com/versia-pub/versia-go/internal/task"
|
||||
"github.com/versia-pub/versia-go/internal/task/task_impls"
|
||||
"github.com/versia-pub/versia-go/internal/utils"
|
||||
"github.com/versia-pub/versia-go/internal/validators/val_impls"
|
||||
"github.com/versia-pub/versia-go/pkg/taskqueue"
|
||||
"modernc.org/sqlite"
|
||||
)
|
||||
|
|
@@ -52,11 +45,9 @@ func init() {
|
|||
log.Logger = log.Output(zerolog.ConsoleWriter{Out: os.Stderr})
|
||||
}
|
||||
|
||||
func shouldPropagate(r *http.Request) bool {
|
||||
return config.C.ForwardTracesTo.Match([]byte(r.URL.String()))
|
||||
}
|
||||
|
||||
func main() {
|
||||
rootCtx, cancelRoot := context.WithCancel(context.Background())
|
||||
|
||||
zerolog.SetGlobalLevel(zerolog.TraceLevel)
|
||||
zerologr.NameFieldName = "logger"
|
||||
zerologr.NameSeparator = "/"
|
||||
|
|
@@ -98,24 +89,27 @@ func main() {
|
|||
}
|
||||
|
||||
log.Debug().Msg("Starting taskqueue client")
|
||||
tq, err := taskqueue.NewClient(context.Background(), config.C.NATSStreamName, nc, tel, zerologr.New(&log.Logger).WithName("taskqueue-client"))
|
||||
tq, err := taskqueue.NewClient(config.C.NATSStreamName, nc, tel, zerologr.New(&log.Logger).WithName("taskqueue-client"))
|
||||
if err != nil {
|
||||
log.Fatal().Err(err).Msg("Failed to create taskqueue client")
|
||||
}
|
||||
defer tq.Close()
|
||||
|
||||
log.Debug().Msg("Running schema migration")
|
||||
if err := migrateDB(db, zerologr.New(&log.Logger).WithName("migrate-db"), tel); err != nil {
|
||||
log.Fatal().Err(err).Msg("Failed to run schema migration")
|
||||
}
|
||||
|
||||
log.Debug().Msg("Initializing instance")
|
||||
if err := initInstance(db, tel); err != nil {
|
||||
log.Fatal().Err(err).Msg("Failed to initialize instance")
|
||||
}
|
||||
|
||||
// Stateless services
|
||||
|
||||
requestSigner := svc_impls.NewRequestSignerImpl(tel, zerologr.New(&log.Logger).WithName("request-signer"))
|
||||
federationService := svc_impls.NewFederationServiceImpl(httpClient, tel, zerologr.New(&log.Logger).WithName("federation-service"))
|
||||
taskService := svc_impls.NewTaskServiceImpl(tq, tel, zerologr.New(&log.Logger).WithName("task-service"))
|
||||
|
||||
// Manager
|
||||
// Repositories
|
||||
|
||||
repos := repo_impls.NewManagerImpl(
|
||||
db, tel, zerologr.New(&log.Logger).WithName("repositories"),
|
||||
|
|
@@ -134,103 +128,50 @@ func main() {
|
|||
bodyValidator := val_impls.NewBodyValidator(zerologr.New(&log.Logger).WithName("validation-service"))
|
||||
requestValidator := val_impls.NewRequestValidator(repos, tel, zerologr.New(&log.Logger).WithName("request-validator"))
|
||||
|
||||
// Task handlers
|
||||
|
||||
notes := task_impls.NewNoteHandler(federationService, repos, tel, zerologr.New(&log.Logger).WithName("task-note-handler"))
|
||||
notesSet := registerTaskHandler(rootCtx, "notes", tq, notes)
|
||||
|
||||
taskManager := task_impls.NewManager(notes, tel, zerologr.New(&log.Logger).WithName("task-manager"))
|
||||
|
||||
// Services
|
||||
|
||||
taskService := svc_impls.NewTaskServiceImpl(taskManager, tel, zerologr.New(&log.Logger).WithName("task-service"))
|
||||
userService := svc_impls.NewUserServiceImpl(repos, federationService, tel, zerologr.New(&log.Logger).WithName("user-service"))
|
||||
noteService := svc_impls.NewNoteServiceImpl(federationService, taskService, repos, tel, zerologr.New(&log.Logger).WithName("note-service"))
|
||||
followService := svc_impls.NewFollowServiceImpl(federationService, repos, tel, zerologr.New(&log.Logger).WithName("follow-service"))
|
||||
inboxService := svc_impls.NewInboxService(federationService, repos, tel, zerologr.New(&log.Logger).WithName("inbox-service"))
|
||||
instanceMetadataService := svc_impls.NewInstanceMetadataServiceImpl(federationService, repos, tel, zerologr.New(&log.Logger).WithName("instance-metadata-service"))
|
||||
|
||||
// Handlers
|
||||
wg := sync.WaitGroup{}
|
||||
|
||||
userHandler := user_handler.New(federationService, requestSigner, userService, inboxService, bodyValidator, requestValidator, zerologr.New(&log.Logger).WithName("user-handler"))
|
||||
noteHandler := note_handler.New(noteService, bodyValidator, requestSigner, zerologr.New(&log.Logger).WithName("notes-handler"))
|
||||
followHandler := follow_handler.New(followService, federationService, zerologr.New(&log.Logger).WithName("follow-handler"))
|
||||
metaHandler := meta_handler.New(instanceMetadataService, zerologr.New(&log.Logger).WithName("meta-handler"))
|
||||
if config.C.Mode == config.ModeWeb || config.C.Mode == config.ModeCombined {
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
|
||||
// Initialization
|
||||
|
||||
if err := initServerActor(db, tel); err != nil {
|
||||
log.Fatal().Err(err).Msg("Failed to initialize server actor")
|
||||
if err := server(
|
||||
rootCtx,
|
||||
tel,
|
||||
db,
|
||||
nc,
|
||||
federationService,
|
||||
requestSigner,
|
||||
bodyValidator,
|
||||
requestValidator,
|
||||
userService,
|
||||
noteService,
|
||||
followService,
|
||||
instanceMetadataService,
|
||||
inboxService,
|
||||
); err != nil {
|
||||
log.Fatal().Err(err).Msg("Failed to start server")
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
web := fiber.New(fiber.Config{
|
||||
ProxyHeader: "X-Forwarded-For",
|
||||
ErrorHandler: fiberErrorHandler,
|
||||
DisableStartupMessage: true,
|
||||
AppName: "versia-go",
|
||||
EnablePrintRoutes: true,
|
||||
})
|
||||
|
||||
web.Use(cors.New(cors.Config{
|
||||
AllowOriginsFunc: func(origin string) bool {
|
||||
return true
|
||||
},
|
||||
AllowMethods: "GET,POST,PUT,DELETE,PATCH",
|
||||
AllowHeaders: "Origin, Content-Type, Accept, Authorization, b3, traceparent, sentry-trace, baggage",
|
||||
AllowCredentials: true,
|
||||
ExposeHeaders: "",
|
||||
MaxAge: 0,
|
||||
}))
|
||||
|
||||
web.Use(unitelhttp.FiberMiddleware(tel, unitelhttp.FiberMiddlewareConfig{
|
||||
Repanic: false,
|
||||
WaitForDelivery: false,
|
||||
Timeout: 5 * time.Second,
|
||||
// host for incoming requests
|
||||
TraceRequestHeaders: []string{"origin", "x-nonce", "x-signature", "x-signed-by", "sentry-trace", "sentry-baggage"},
|
||||
// origin for outgoing requests
|
||||
TraceResponseHeaders: []string{"host", "x-nonce", "x-signature", "x-signed-by", "sentry-trace", "sentry-baggage"},
|
||||
IgnoredRoutes: []string{"/api/health"},
|
||||
Logger: zerologr.New(&log.Logger).WithName("http-server"),
|
||||
TracePropagator: shouldPropagate,
|
||||
}))
|
||||
web.Use(unitelhttp.RequestLogger(zerologr.New(&log.Logger).WithName("http-server"), true, true))
|
||||
|
||||
log.Debug().Msg("Registering handlers")
|
||||
|
||||
web.Get("/api/health", healthCheck(db, nc))
|
||||
|
||||
userHandler.Register(web.Group("/"))
|
||||
noteHandler.Register(web.Group("/"))
|
||||
followHandler.Register(web.Group("/"))
|
||||
metaHandler.Register(web.Group("/"))
|
||||
|
||||
wg := sync.WaitGroup{}
|
||||
wg.Add(2)
|
||||
|
||||
// TODO: Run these in separate processes, if wanted
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
|
||||
log.Debug().Msg("Starting taskqueue consumer")
|
||||
|
||||
tasks.NewHandler(federationService, repos, tel, zerologr.New(&log.Logger).WithName("task-handler")).
|
||||
Register(tq)
|
||||
|
||||
if err := tq.StartConsumer(context.Background(), "consumer"); err != nil {
|
||||
log.Fatal().Err(err).Msg("failed to start taskqueue client")
|
||||
}
|
||||
}()
|
||||
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
|
||||
log.Debug().Msg("Starting server")
|
||||
|
||||
addr := fmt.Sprintf(":%d", config.C.Port)
|
||||
|
||||
var err error
|
||||
if config.C.TLSKey != nil {
|
||||
err = web.ListenTLS(addr, *config.C.TLSCert, *config.C.TLSKey)
|
||||
} else {
|
||||
err = web.Listen(addr)
|
||||
}
|
||||
if err != nil {
|
||||
log.Fatal().Err(err).Msg("Failed to start server")
|
||||
}
|
||||
}()
|
||||
maybeRunTaskHandler(rootCtx, "notes", notesSet, &wg)
|
||||
|
||||
signalCh := make(chan os.Signal, 1)
|
||||
signal.Notify(signalCh, os.Interrupt)
|
||||
|
|
@@ -238,10 +179,7 @@ func main() {
|
|||
|
||||
log.Info().Msg("Shutting down")
|
||||
|
||||
tq.Close()
|
||||
if err := web.Shutdown(); err != nil {
|
||||
log.Error().Err(err).Msg("Failed to shutdown server")
|
||||
}
|
||||
cancelRoot()
|
||||
|
||||
wg.Wait()
|
||||
}
|
||||
|
|
@@ -293,8 +231,8 @@ func migrateDB(db *ent.Client, log logr.Logger, telemetry *unitel.Telemetry) err
|
|||
return nil
|
||||
}
|
||||
|
||||
func initServerActor(db *ent.Client, telemetry *unitel.Telemetry) error {
|
||||
s := telemetry.StartSpan(context.Background(), "function", "main.initServerActor")
|
||||
func initInstance(db *ent.Client, telemetry *unitel.Telemetry) error {
|
||||
s := telemetry.StartSpan(context.Background(), "function", "main.initInstance")
|
||||
defer s.End()
|
||||
ctx := s.Context()
|
||||
|
||||
|
|
@@ -354,27 +292,47 @@ func initServerActor(db *ent.Client, telemetry *unitel.Telemetry) error {
|
|||
return tx.Finish()
|
||||
}
|
||||
|
||||
func healthCheck(db *ent.Client, nc *nats.Conn) fiber.Handler {
|
||||
return func(c *fiber.Ctx) error {
|
||||
dbWorking := true
|
||||
if err := db.Ping(); err != nil {
|
||||
log.Error().Err(err).Msg("Database healthcheck failed")
|
||||
dbWorking = false
|
||||
}
|
||||
|
||||
natsWorking := true
|
||||
if status := nc.Status(); status != nats.CONNECTED {
|
||||
log.Error().Str("status", status.String()).Msg("NATS healthcheck failed")
|
||||
natsWorking = false
|
||||
}
|
||||
|
||||
if dbWorking && natsWorking {
|
||||
return c.SendString("lookin' good")
|
||||
}
|
||||
|
||||
return api_schema.ErrInternalServerError(map[string]any{
|
||||
"database": dbWorking,
|
||||
"nats": natsWorking,
|
||||
})
|
||||
func registerTaskHandler[T task.Handler](ctx context.Context, name string, tq *taskqueue.Client, handler T) *taskqueue.Set {
|
||||
s, err := tq.Set(ctx, name)
|
||||
if err != nil {
|
||||
log.Fatal().Err(err).Str("handler", name).Msg("Could not create taskset for task handler")
|
||||
}
|
||||
|
||||
handler.Register(s)
|
||||
|
||||
return s
|
||||
}
|
||||
|
||||
func maybeRunTaskHandler(ctx context.Context, name string, set *taskqueue.Set, wg *sync.WaitGroup) {
|
||||
l := log.With().Str("handler", name).Logger()
|
||||
|
||||
if config.C.Mode == config.ModeWeb {
|
||||
l.Warn().Strs("requested", config.C.Consumers).Msg("Not starting task handler, as this process is running in web mode")
|
||||
return
|
||||
}
|
||||
|
||||
if config.C.Mode == config.ModeConsumer && !slices.Contains(config.C.Consumers, name) {
|
||||
l.Warn().Strs("requested", config.C.Consumers).Msg("Not starting task handler, as it wasn't requested")
|
||||
return
|
||||
}
|
||||
|
||||
wg.Add(1)
|
||||
|
||||
c := set.Consumer(name)
|
||||
if err := c.Start(ctx); err != nil {
|
||||
l.Fatal().Err(err).Msg("Could not start task handler")
|
||||
}
|
||||
|
||||
l.Info().Msg("Started task handler")
|
||||
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
|
||||
<-ctx.Done()
|
||||
l.Debug().Msg("Got signal to stop task handler")
|
||||
|
||||
c.Close()
|
||||
|
||||
l.Info().Msg("Stopped task handler")
|
||||
}()
|
||||
}
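registerTaskHandler and maybeRunTaskHandler together decide which workers a given process runs; a hypothetical sketch of adding a second worker set (NewFollowHandler, the "follows" set and followLog are illustrative and do not exist in this commit):

// Hypothetical second worker, illustrative only.
follows := task_impls.NewFollowHandler(federationService, repos, tel, followLog)
followsSet := registerTaskHandler(rootCtx, "follows", tq, follows)
// ...
maybeRunTaskHandler(rootCtx, "follows", followsSet, &wg)
// Run only this worker with VERSIA_MODE=consumer and "follows" in the configured consumer list.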
|
||||
|
|
|
|||
|
|
@@ -4,8 +4,7 @@ import (
|
|||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"strings"
|
||||
"sync"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"git.devminer.xyz/devminer/unitel"
|
||||
|
|
@@ -53,12 +52,9 @@ func NewTask(type_ string, payload any) (Task, error) {
|
|||
}, nil
|
||||
}
|
||||
|
||||
type Handler func(ctx context.Context, task Task) error
|
||||
|
||||
type Client struct {
|
||||
name string
|
||||
subject string
|
||||
handlers map[string][]Handler
|
||||
name string
|
||||
subject string
|
||||
|
||||
nc *nats.Conn
|
||||
js jetstream.JetStream
|
||||
|
|
@@ -71,15 +67,27 @@ type Client struct {
|
|||
log logr.Logger
|
||||
}
|
||||
|
||||
func NewClient(ctx context.Context, streamName string, natsClient *nats.Conn, telemetry *unitel.Telemetry, log logr.Logger) (*Client, error) {
|
||||
func NewClient(streamName string, natsClient *nats.Conn, telemetry *unitel.Telemetry, log logr.Logger) (*Client, error) {
|
||||
js, err := jetstream.New(natsClient)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
s, err := js.CreateStream(ctx, jetstream.StreamConfig{
|
||||
return &Client{
|
||||
name: streamName,
|
||||
|
||||
js: js,
|
||||
|
||||
telemetry: telemetry,
|
||||
log: log,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (c *Client) Set(ctx context.Context, name string) (*Set, error) {
|
||||
streamName := fmt.Sprintf("%s_%s", c.name, name)
|
||||
|
||||
s, err := c.js.CreateStream(ctx, jetstream.StreamConfig{
|
||||
Name: streamName,
|
||||
Subjects: []string{streamName + ".*"},
|
||||
MaxConsumers: -1,
|
||||
MaxMsgs: -1,
|
||||
Discard: jetstream.DiscardOld,
|
||||
|
|
@@ -89,7 +97,7 @@ func NewClient(ctx context.Context, streamName string, natsClient *nats.Conn, te
|
|||
AllowDirect: true,
|
||||
})
|
||||
if errors.Is(err, nats.ErrStreamNameAlreadyInUse) {
|
||||
s, err = js.Stream(ctx, streamName)
|
||||
s, err = c.js.Stream(ctx, streamName)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
|
@@ -97,190 +105,13 @@ func NewClient(ctx context.Context, streamName string, natsClient *nats.Conn, te
|
|||
return nil, err
|
||||
}
|
||||
|
||||
stopCh := make(chan struct{})
|
||||
return &Set{
|
||||
handlers: make(map[string][]TaskHandler),
|
||||
|
||||
c := &Client{
|
||||
name: streamName,
|
||||
subject: streamName + ".tasks",
|
||||
|
||||
handlers: map[string][]Handler{},
|
||||
|
||||
stopCh: stopCh,
|
||||
closeOnce: sync.OnceFunc(func() {
|
||||
close(stopCh)
|
||||
}),
|
||||
|
||||
nc: natsClient,
|
||||
js: js,
|
||||
s: s,
|
||||
|
||||
telemetry: telemetry,
|
||||
log: log,
|
||||
}
|
||||
|
||||
return c, nil
|
||||
}
|
||||
|
||||
func (c *Client) Close() {
|
||||
c.closeOnce()
|
||||
c.nc.Close()
|
||||
}
|
||||
|
||||
func (c *Client) Submit(ctx context.Context, task Task) error {
|
||||
s := c.telemetry.StartSpan(ctx, "queue.publish", "taskqueue/Client.Submit").
|
||||
AddAttribute("messaging.destination.name", c.subject)
|
||||
defer s.End()
|
||||
ctx = s.Context()
|
||||
|
||||
s.AddAttribute("jobID", task.ID)
|
||||
|
||||
data, err := json.Marshal(c.newTaskWrapper(ctx, task))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
s.AddAttribute("messaging.message.body.size", len(data))
|
||||
|
||||
msg, err := c.js.PublishMsg(ctx, &nats.Msg{Subject: c.subject, Data: data})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
c.log.V(2).Info("Submitted task", "id", task.ID, "type", task.Type, "sequence", msg.Sequence)
|
||||
|
||||
s.AddAttribute("messaging.message.id", msg.Sequence)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Client) RegisterHandler(type_ string, handler Handler) {
|
||||
c.log.V(2).Info("Registering handler", "type", type_)
|
||||
|
||||
if _, ok := c.handlers[type_]; !ok {
|
||||
c.handlers[type_] = []Handler{}
|
||||
}
|
||||
c.handlers[type_] = append(c.handlers[type_], handler)
|
||||
}
|
||||
|
||||
func (c *Client) StartConsumer(ctx context.Context, consumerGroup string) error {
|
||||
c.log.Info("Starting consumer")
|
||||
|
||||
sub, err := c.js.CreateConsumer(ctx, c.name, jetstream.ConsumerConfig{
|
||||
Durable: consumerGroup,
|
||||
DeliverPolicy: jetstream.DeliverAllPolicy,
|
||||
ReplayPolicy: jetstream.ReplayInstantPolicy,
|
||||
AckPolicy: jetstream.AckExplicitPolicy,
|
||||
FilterSubject: c.subject,
|
||||
MaxWaiting: 1,
|
||||
MaxAckPending: 1,
|
||||
HeadersOnly: false,
|
||||
MemoryStorage: false,
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
m, err := sub.Messages(jetstream.PullMaxMessages(1))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
go func() {
|
||||
for {
|
||||
msg, err := m.Next()
|
||||
if err != nil {
|
||||
if errors.Is(err, jetstream.ErrMsgIteratorClosed) {
|
||||
c.log.Info("Stopping")
|
||||
return
|
||||
}
|
||||
|
||||
c.log.Error(err, "Failed to get next message")
|
||||
break
|
||||
}
|
||||
|
||||
if err := c.handleTask(ctx, msg); err != nil {
|
||||
c.log.Error(err, "Failed to handle task")
|
||||
break
|
||||
}
|
||||
}
|
||||
}()
|
||||
go func() {
|
||||
<-c.stopCh
|
||||
m.Drain()
|
||||
}()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Client) handleTask(ctx context.Context, msg jetstream.Msg) error {
|
||||
msgMeta, err := msg.Metadata()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
data := msg.Data()
|
||||
|
||||
var w taskWrapper
|
||||
if err := json.Unmarshal(data, &w); err != nil {
|
||||
if err := msg.Nak(); err != nil {
|
||||
c.log.Error(err, "Failed to nak message")
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
s := c.telemetry.StartSpan(
|
||||
context.Background(),
|
||||
"queue.process",
|
||||
"taskqueue/Client.handleTask",
|
||||
c.telemetry.ContinueFromMap(w.TraceInfo),
|
||||
).
|
||||
AddAttribute("messaging.destination.name", c.subject).
|
||||
AddAttribute("messaging.message.id", msgMeta.Sequence.Stream).
|
||||
AddAttribute("messaging.message.retry.count", msgMeta.NumDelivered).
|
||||
AddAttribute("messaging.message.body.size", len(data)).
|
||||
AddAttribute("messaging.message.receive.latency", time.Since(w.EnqueuedAt).Milliseconds())
|
||||
defer s.End()
|
||||
ctx = s.Context()
|
||||
|
||||
handlers, ok := c.handlers[w.Task.Type]
|
||||
if !ok {
|
||||
c.log.V(2).Info("No handler for task", "type", w.Task.Type)
|
||||
return msg.Nak()
|
||||
}
|
||||
|
||||
var errs CombinedError
|
||||
for _, handler := range handlers {
|
||||
if err := handler(ctx, w.Task); err != nil {
|
||||
c.log.Error(err, "Handler failed", "type", w.Task.Type)
|
||||
errs.Errors = append(errs.Errors, err)
|
||||
}
|
||||
}
|
||||
|
||||
if len(errs.Errors) > 0 {
|
||||
if err := msg.Nak(); err != nil {
|
||||
c.log.Error(err, "Failed to nak message")
|
||||
errs.Errors = append(errs.Errors, err)
|
||||
}
|
||||
|
||||
return errs
|
||||
}
|
||||
|
||||
return msg.Ack()
|
||||
}
|
||||
|
||||
type CombinedError struct {
|
||||
Errors []error
|
||||
}
|
||||
|
||||
func (e CombinedError) Error() string {
|
||||
sb := strings.Builder{}
|
||||
sb.WriteRune('[')
|
||||
for i, err := range e.Errors {
|
||||
if i > 0 {
|
||||
sb.WriteRune(',')
|
||||
}
|
||||
sb.WriteString(err.Error())
|
||||
}
|
||||
sb.WriteRune(']')
|
||||
return sb.String()
|
||||
streamName: streamName,
|
||||
c: c,
|
||||
s: s,
|
||||
log: c.log.WithName(fmt.Sprintf("taskset(%s)", name)),
|
||||
telemetry: c.telemetry,
|
||||
}, nil
|
||||
}
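NewClient no longer creates a stream itself; streams now belong to per-handler task sets created with Client.Set. A hedged end-to-end sketch of the new flow (handleFederateNote and someTask are placeholders; Close is assumed to remain on Client, as main.go still defers it):

// Sketch only: producer/consumer flow with the Client/Set split.
tq, err := taskqueue.NewClient("versia-go", natsConn, telemetry, logger)
if err != nil {
	return err
}
defer tq.Close()

set, err := tq.Set(ctx, "notes") // creates or reuses the "versia-go_notes" stream
if err != nil {
	return err
}

set.RegisterHandler("federate_note", handleFederateNote)
if err := set.Consumer("consumer").Start(ctx); err != nil {
	return err
}
if err := set.Submit(ctx, someTask); err != nil { // someTask built with taskqueue.NewTask
	return err
}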
|
||||
|
|
|
|||
pkg/taskqueue/errors.go (new file, 20 lines)
@@ -0,0 +1,20 @@
package taskqueue

import "strings"

type CombinedError struct {
	Errors []error
}

func (e CombinedError) Error() string {
	sb := strings.Builder{}
	sb.WriteRune('[')
	for i, err := range e.Errors {
		if i > 0 {
			sb.WriteRune(',')
		}
		sb.WriteString(err.Error())
	}
	sb.WriteRune(']')
	return sb.String()
}
pkg/taskqueue/taskset.go (new file, 210 lines)
|
|
@@ -0,0 +1,210 @@
|
|||
package taskqueue
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"git.devminer.xyz/devminer/unitel"
|
||||
"github.com/go-logr/logr"
|
||||
"github.com/nats-io/nats.go"
|
||||
"github.com/nats-io/nats.go/jetstream"
|
||||
)
|
||||
|
||||
type TaskHandler = func(ctx context.Context, task Task) error
|
||||
|
||||
type Set struct {
|
||||
handlers map[string][]TaskHandler
|
||||
|
||||
streamName string
|
||||
c *Client
|
||||
s jetstream.Stream
|
||||
log logr.Logger
|
||||
telemetry *unitel.Telemetry
|
||||
}
|
||||
|
||||
func (t *Set) RegisterHandler(type_ string, handler TaskHandler) {
|
||||
t.log.V(2).Info("Registering handler", "type", type_)
|
||||
|
||||
if _, ok := t.handlers[type_]; !ok {
|
||||
t.handlers[type_] = []TaskHandler{}
|
||||
}
|
||||
t.handlers[type_] = append(t.handlers[type_], handler)
|
||||
}
|
||||
|
||||
func (t *Set) Submit(ctx context.Context, task Task) error {
|
||||
s := t.telemetry.StartSpan(ctx, "queue.publish", "taskqueue/TaskSet.Submit").
|
||||
AddAttribute("messaging.destination.name", t.streamName)
|
||||
defer s.End()
|
||||
ctx = s.Context()
|
||||
|
||||
s.AddAttribute("jobID", task.ID)
|
||||
|
||||
data, err := json.Marshal(t.c.newTaskWrapper(ctx, task))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
s.AddAttribute("messaging.message.body.size", len(data))
|
||||
|
||||
// TODO: Refactor
|
||||
msg, err := t.c.js.PublishMsg(ctx, &nats.Msg{Subject: t.streamName, Data: data})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
t.log.V(2).Info("Submitted task", "id", task.ID, "type", task.Type, "sequence", msg.Sequence)
|
||||
|
||||
s.AddAttribute("messaging.message.id", msg.Sequence)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (t *Set) Consumer(name string) *Consumer {
|
||||
stopCh := make(chan struct{})
|
||||
stopOnce := sync.OnceFunc(func() {
|
||||
close(stopCh)
|
||||
})
|
||||
|
||||
return &Consumer{
|
||||
stopCh: stopCh,
|
||||
stopOnce: stopOnce,
|
||||
|
||||
name: name,
|
||||
streamName: t.streamName,
|
||||
telemetry: t.telemetry,
|
||||
log: t.log.WithName(fmt.Sprintf("consumer(%s)", name)),
|
||||
t: t,
|
||||
}
|
||||
}
|
||||
|
||||
type Consumer struct {
|
||||
stopCh chan struct{}
|
||||
stopOnce func()
|
||||
|
||||
name string
|
||||
streamName string
|
||||
telemetry *unitel.Telemetry
|
||||
log logr.Logger
|
||||
t *Set
|
||||
}
|
||||
|
||||
func (c *Consumer) Close() {
|
||||
c.stopOnce()
|
||||
}
|
||||
|
||||
func (c *Consumer) Start(ctx context.Context) error {
|
||||
c.log.Info("Starting consumer")
|
||||
|
||||
sub, err := c.t.c.js.CreateConsumer(ctx, c.streamName, jetstream.ConsumerConfig{
|
||||
Durable: c.name,
|
||||
DeliverPolicy: jetstream.DeliverAllPolicy,
|
||||
ReplayPolicy: jetstream.ReplayInstantPolicy,
|
||||
AckPolicy: jetstream.AckExplicitPolicy,
|
||||
MaxWaiting: 1,
|
||||
MaxAckPending: 1,
|
||||
HeadersOnly: false,
|
||||
MemoryStorage: false,
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
m, err := sub.Messages(jetstream.PullMaxMessages(1))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
go c.handleMessages(m)
|
||||
|
||||
go func() {
|
||||
<-ctx.Done()
|
||||
c.Close()
|
||||
}()
|
||||
|
||||
go func() {
|
||||
<-c.stopCh
|
||||
m.Drain()
|
||||
}()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Consumer) handleMessages(m jetstream.MessagesContext) {
|
||||
for {
|
||||
msg, err := m.Next()
|
||||
if err != nil {
|
||||
if errors.Is(err, jetstream.ErrMsgIteratorClosed) {
|
||||
c.log.Info("Stopping")
|
||||
return
|
||||
}
|
||||
|
||||
c.log.Error(err, "Failed to get next message")
|
||||
break
|
||||
}
|
||||
|
||||
if err := c.handleTask(msg); err != nil {
|
||||
c.log.Error(err, "Failed to handle task")
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Consumer) handleTask(msg jetstream.Msg) error {
|
||||
msgMeta, err := msg.Metadata()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
data := msg.Data()
|
||||
|
||||
var w taskWrapper
|
||||
if err := json.Unmarshal(data, &w); err != nil {
|
||||
if err := msg.Nak(); err != nil {
|
||||
c.log.Error(err, "Failed to nak message")
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
s := c.telemetry.StartSpan(
|
||||
context.Background(),
|
||||
"queue.process",
|
||||
"taskqueue/Consumer.handleTask",
|
||||
c.telemetry.ContinueFromMap(w.TraceInfo),
|
||||
).
|
||||
AddAttribute("messaging.destination.name", msg.Subject()).
|
||||
AddAttribute("messaging.message.id", msgMeta.Sequence.Stream).
|
||||
AddAttribute("messaging.message.retry.count", msgMeta.NumDelivered).
|
||||
AddAttribute("messaging.message.body.size", len(data)).
|
||||
AddAttribute("messaging.message.receive.latency", time.Since(w.EnqueuedAt).Milliseconds())
|
||||
defer s.End()
|
||||
ctx := s.Context()
|
||||
|
||||
handlers, ok := c.t.handlers[w.Task.Type]
|
||||
if !ok {
|
||||
c.log.V(2).Info("No handler for task", "type", w.Task.Type)
|
||||
return msg.Nak()
|
||||
}
|
||||
|
||||
var errs CombinedError
|
||||
for _, handler := range handlers {
|
||||
if err := handler(ctx, w.Task); err != nil {
|
||||
c.log.Error(err, "Handler failed", "type", w.Task.Type)
|
||||
errs.Errors = append(errs.Errors, err)
|
||||
}
|
||||
}
|
||||
|
||||
if len(errs.Errors) > 0 {
|
||||
if err := msg.Nak(); err != nil {
|
||||
c.log.Error(err, "Failed to nak message")
|
||||
errs.Errors = append(errs.Errors, err)
|
||||
}
|
||||
|
||||
return errs
|
||||
}
|
||||
|
||||
return msg.Ack()
|
||||
}
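A consumer stops either when the context passed to Start is cancelled or when Close is called; a short lifecycle sketch assuming set is a *taskqueue.Set:

// Sketch: Start is non-blocking; the message loop runs in background goroutines.
c := set.Consumer("note-handler")
if err := c.Start(ctx); err != nil {
	return err
}

// On shutdown, cancelling ctx is enough; Close() closes the stop channel once
// and drains the message iterator.
c.Close()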
|
||||
server.go (new file, 146 lines)
|
|
@@ -0,0 +1,146 @@
|
|||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"git.devminer.xyz/devminer/unitel"
|
||||
"git.devminer.xyz/devminer/unitel/unitelhttp"
|
||||
"github.com/versia-pub/versia-go/internal/api_schema"
|
||||
"github.com/versia-pub/versia-go/internal/handlers/follow_handler"
|
||||
"github.com/versia-pub/versia-go/internal/handlers/meta_handler"
|
||||
"github.com/versia-pub/versia-go/internal/handlers/note_handler"
|
||||
"github.com/versia-pub/versia-go/internal/service"
|
||||
"github.com/versia-pub/versia-go/internal/validators"
|
||||
"net/http"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/go-logr/zerologr"
|
||||
"github.com/gofiber/fiber/v2"
|
||||
"github.com/gofiber/fiber/v2/middleware/cors"
|
||||
"github.com/nats-io/nats.go"
|
||||
"github.com/rs/zerolog/log"
|
||||
"github.com/versia-pub/versia-go/config"
|
||||
"github.com/versia-pub/versia-go/ent"
|
||||
"github.com/versia-pub/versia-go/internal/handlers/user_handler"
|
||||
)
|
||||
|
||||
func shouldPropagate(r *http.Request) bool {
|
||||
return config.C.ForwardTracesTo.Match([]byte(r.URL.String()))
|
||||
}
|
||||
|
||||
func server(
|
||||
ctx context.Context,
|
||||
telemetry *unitel.Telemetry,
|
||||
database *ent.Client,
|
||||
natsConn *nats.Conn,
|
||||
federationService service.FederationService,
|
||||
requestSigner service.RequestSigner,
|
||||
bodyValidator validators.BodyValidator,
|
||||
requestValidator validators.RequestValidator,
|
||||
userService service.UserService,
|
||||
noteService service.NoteService,
|
||||
followService service.FollowService,
|
||||
instanceMetadataService service.InstanceMetadataService,
|
||||
inboxService service.InboxService,
|
||||
) error {
|
||||
// Handlers
|
||||
|
||||
userHandler := user_handler.New(federationService, requestSigner, userService, inboxService, bodyValidator, requestValidator, zerologr.New(&log.Logger).WithName("user-handler"))
|
||||
noteHandler := note_handler.New(noteService, bodyValidator, requestSigner, zerologr.New(&log.Logger).WithName("notes-handler"))
|
||||
followHandler := follow_handler.New(followService, federationService, zerologr.New(&log.Logger).WithName("follow-handler"))
|
||||
metaHandler := meta_handler.New(instanceMetadataService, zerologr.New(&log.Logger).WithName("meta-handler"))
|
||||
|
||||
// Initialization
|
||||
|
||||
web := fiber.New(fiber.Config{
|
||||
ProxyHeader: "X-Forwarded-For",
|
||||
ErrorHandler: fiberErrorHandler,
|
||||
DisableStartupMessage: true,
|
||||
AppName: "versia-go",
|
||||
EnablePrintRoutes: true,
|
||||
})
|
||||
|
||||
web.Use(cors.New(cors.Config{
|
||||
AllowOriginsFunc: func(origin string) bool {
|
||||
return true
|
||||
},
|
||||
AllowMethods: "GET,POST,PUT,DELETE,PATCH",
|
||||
AllowHeaders: "Origin, Content-Type, Accept, Authorization, b3, traceparent, sentry-trace, baggage",
|
||||
AllowCredentials: true,
|
||||
ExposeHeaders: "",
|
||||
MaxAge: 0,
|
||||
}))
|
||||
|
||||
web.Use(unitelhttp.FiberMiddleware(telemetry, unitelhttp.FiberMiddlewareConfig{
|
||||
Repanic: false,
|
||||
WaitForDelivery: false,
|
||||
Timeout: 5 * time.Second,
|
||||
// host for incoming requests
|
||||
TraceRequestHeaders: []string{"origin", "x-nonce", "x-signature", "x-signed-by", "sentry-trace", "sentry-baggage"},
|
||||
// origin for outgoing requests
|
||||
TraceResponseHeaders: []string{"host", "x-nonce", "x-signature", "x-signed-by", "sentry-trace", "sentry-baggage"},
|
||||
IgnoredRoutes: []string{"/api/health"},
|
||||
Logger: zerologr.New(&log.Logger).WithName("http-server"),
|
||||
TracePropagator: shouldPropagate,
|
||||
}))
|
||||
web.Use(unitelhttp.RequestLogger(zerologr.New(&log.Logger).WithName("http-server"), true, true))
|
||||
|
||||
log.Debug().Msg("Registering handlers")
|
||||
|
||||
web.Get("/api/health", healthCheck(database, natsConn))
|
||||
|
||||
userHandler.Register(web.Group("/"))
|
||||
noteHandler.Register(web.Group("/"))
|
||||
followHandler.Register(web.Group("/"))
|
||||
metaHandler.Register(web.Group("/"))
|
||||
|
||||
wg := sync.WaitGroup{}
|
||||
wg.Add(2)
|
||||
|
||||
addr := fmt.Sprintf(":%d", config.C.Port)
|
||||
|
||||
log.Info().Str("addr", addr).Msg("Starting server")
|
||||
|
||||
go func() {
|
||||
<-ctx.Done()
|
||||
|
||||
if err := web.Shutdown(); err != nil {
|
||||
log.Error().Err(err).Msg("Failed to shutdown server")
|
||||
}
|
||||
}()
|
||||
|
||||
var err error
|
||||
if config.C.TLSKey != nil {
|
||||
err = web.ListenTLS(addr, *config.C.TLSCert, *config.C.TLSKey)
|
||||
} else {
|
||||
err = web.Listen(addr)
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
func healthCheck(db *ent.Client, nc *nats.Conn) fiber.Handler {
|
||||
return func(c *fiber.Ctx) error {
|
||||
dbWorking := true
|
||||
if err := db.Ping(); err != nil {
|
||||
log.Error().Err(err).Msg("Database healthcheck failed")
|
||||
dbWorking = false
|
||||
}
|
||||
|
||||
natsWorking := true
|
||||
if status := nc.Status(); status != nats.CONNECTED {
|
||||
log.Error().Str("status", status.String()).Msg("NATS healthcheck failed")
|
||||
natsWorking = false
|
||||
}
|
||||
|
||||
if dbWorking && natsWorking {
|
||||
return c.SendString("lookin' good")
|
||||
}
|
||||
|
||||
return api_schema.ErrInternalServerError(map[string]any{
|
||||
"database": dbWorking,
|
||||
"nats": natsWorking,
|
||||
})
|
||||
}
|
||||
}
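A hedged sketch of probing the health endpoint registered above; scheme, host and port depend on VERSIA_PORT and the TLS settings, and the failure payload comes from api_schema.ErrInternalServerError:

// Sketch only, not part of the commit.
resp, err := http.Get("http://localhost:80/api/health")
if err != nil {
	return err
}
defer resp.Body.Close()

body, _ := io.ReadAll(resp.Body)
// Prints "lookin' good" while both the database and NATS are reachable.
fmt.Println(resp.StatusCode, string(body))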
|
||||