versia-go/pkg/taskqueue/client.go
2024-08-13 01:18:14 +02:00

289 lines
5.9 KiB
Go

package taskqueue
import (
"context"
"encoding/json"
"errors"
"strings"
"sync"
"time"
"git.devminer.xyz/devminer/unitel"
"github.com/go-logr/logr"
"github.com/google/uuid"
"github.com/nats-io/nats.go"
"github.com/nats-io/nats.go/jetstream"
)
type taskWrapper struct {
Task Task `json:"task"`
EnqueuedAt time.Time `json:"enqueuedAt"`
TraceInfo map[string]string `json:"traceInfo"`
}
func (c *Client) newTaskWrapper(ctx context.Context, task Task) taskWrapper {
traceInfo := make(map[string]string)
c.telemetry.InjectIntoMap(ctx, traceInfo)
return taskWrapper{
Task: task,
EnqueuedAt: time.Now(),
TraceInfo: traceInfo,
}
}
type Task struct {
ID string
Type string
Payload json.RawMessage
}
func NewTask(type_ string, payload any) (Task, error) {
id := uuid.New()
d, err := json.Marshal(payload)
if err != nil {
return Task{}, err
}
return Task{
ID: id.String(),
Type: type_,
Payload: d,
}, nil
}
type Handler func(ctx context.Context, task Task) error
type Client struct {
name string
subject string
handlers map[string][]Handler
nc *nats.Conn
js jetstream.JetStream
s jetstream.Stream
stopCh chan struct{}
closeOnce func()
telemetry *unitel.Telemetry
log logr.Logger
}
func NewClient(ctx context.Context, name string, natsClient *nats.Conn, telemetry *unitel.Telemetry, log logr.Logger) (*Client, error) {
js, err := jetstream.New(natsClient)
if err != nil {
return nil, err
}
s, err := js.CreateStream(ctx, jetstream.StreamConfig{
Name: name,
Subjects: []string{name + ".*"},
MaxConsumers: -1,
MaxMsgs: -1,
Discard: jetstream.DiscardOld,
MaxMsgsPerSubject: -1,
Storage: jetstream.FileStorage,
Compression: jetstream.S2Compression,
AllowDirect: true,
})
if errors.Is(err, nats.ErrStreamNameAlreadyInUse) {
s, err = js.Stream(ctx, name)
if err != nil {
return nil, err
}
} else if err != nil {
return nil, err
}
stopCh := make(chan struct{})
c := &Client{
name: name,
subject: name + ".tasks",
handlers: map[string][]Handler{},
stopCh: stopCh,
closeOnce: sync.OnceFunc(func() {
close(stopCh)
}),
nc: natsClient,
js: js,
s: s,
telemetry: telemetry,
log: log,
}
return c, nil
}
func (c *Client) Close() {
c.closeOnce()
c.nc.Close()
}
func (c *Client) Submit(ctx context.Context, task Task) error {
s := c.telemetry.StartSpan(ctx, "queue.publish", "taskqueue/Client.Submit").
AddAttribute("messaging.destination.name", c.subject)
defer s.End()
ctx = s.Context()
s.AddAttribute("jobID", task.ID)
data, err := json.Marshal(c.newTaskWrapper(ctx, task))
if err != nil {
return err
}
s.AddAttribute("messaging.message.body.size", len(data))
msg, err := c.js.PublishMsg(ctx, &nats.Msg{Subject: c.subject, Data: data})
if err != nil {
return err
}
c.log.V(1).Info("submitted task", "id", task.ID, "type", task.Type, "sequence", msg.Sequence)
s.AddAttribute("messaging.message.id", msg.Sequence)
return nil
}
func (c *Client) RegisterHandler(type_ string, handler Handler) {
c.log.V(2).Info("registering handler", "type", type_)
if _, ok := c.handlers[type_]; !ok {
c.handlers[type_] = []Handler{}
}
c.handlers[type_] = append(c.handlers[type_], handler)
}
func (c *Client) Start(ctx context.Context) error {
c.log.Info("starting")
sub, err := c.js.CreateConsumer(ctx, c.name, jetstream.ConsumerConfig{
// TODO: set name properly
Name: "versia-go",
Durable: "versia-go",
DeliverPolicy: jetstream.DeliverAllPolicy,
ReplayPolicy: jetstream.ReplayInstantPolicy,
AckPolicy: jetstream.AckExplicitPolicy,
FilterSubject: c.subject,
MaxWaiting: 1,
MaxAckPending: 1,
HeadersOnly: false,
MemoryStorage: false,
})
if err != nil {
return err
}
m, err := sub.Messages(jetstream.PullMaxMessages(1))
if err != nil {
return err
}
go func() {
for {
msg, err := m.Next()
if err != nil {
if errors.Is(err, jetstream.ErrMsgIteratorClosed) {
c.log.Info("stopping")
return
}
c.log.Error(err, "failed to get next message")
break
}
if err := c.handleTask(ctx, msg); err != nil {
c.log.Error(err, "failed to handle task")
break
}
}
}()
go func() {
<-c.stopCh
m.Drain()
}()
return nil
}
func (c *Client) handleTask(ctx context.Context, msg jetstream.Msg) error {
msgMeta, err := msg.Metadata()
if err != nil {
return err
}
data := msg.Data()
var w taskWrapper
if err := json.Unmarshal(data, &w); err != nil {
if err := msg.Nak(); err != nil {
c.log.Error(err, "failed to nak message")
}
return err
}
s := c.telemetry.StartSpan(
context.Background(),
"queue.process",
"taskqueue/Client.handleTask",
c.telemetry.ContinueFromMap(w.TraceInfo),
).
AddAttribute("messaging.destination.name", c.subject).
AddAttribute("messaging.message.id", msgMeta.Sequence.Stream).
AddAttribute("messaging.message.retry.count", msgMeta.NumDelivered).
AddAttribute("messaging.message.body.size", len(data)).
AddAttribute("messaging.message.receive.latency", time.Since(w.EnqueuedAt).Milliseconds())
defer s.End()
ctx = s.Context()
handlers, ok := c.handlers[w.Task.Type]
if !ok {
c.log.V(1).Info("no handler for task", "type", w.Task.Type)
return msg.Nak()
}
var errs CombinedError
for _, handler := range handlers {
if err := handler(ctx, w.Task); err != nil {
c.log.Error(err, "handler failed", "type", w.Task.Type)
errs.Errors = append(errs.Errors, err)
}
}
if len(errs.Errors) > 0 {
if err := msg.Nak(); err != nil {
c.log.Error(err, "failed to nak message")
errs.Errors = append(errs.Errors, err)
}
return errs
}
return msg.Ack()
}
type CombinedError struct {
Errors []error
}
func (e CombinedError) Error() string {
sb := strings.Builder{}
sb.WriteRune('[')
for i, err := range e.Errors {
if i > 0 {
sb.WriteRune(',')
}
sb.WriteString(err.Error())
}
sb.WriteRune(']')
return sb.String()
}