diff --git a/plugins/inputs/nats_consumer/nats_consumer.go b/plugins/inputs/nats_consumer/nats_consumer.go index 2ebb706fee48f..dcde50008a011 100644 --- a/plugins/inputs/nats_consumer/nats_consumer.go +++ b/plugins/inputs/nats_consumer/nats_consumer.go @@ -5,6 +5,7 @@ import ( "context" _ "embed" "fmt" + "slices" "strings" "sync" @@ -48,12 +49,15 @@ type NatsConsumer struct { parser telegraf.Parser // channel for all incoming NATS messages - in chan *nats.Msg + in chan *nats.Msg + undelivered map[telegraf.TrackingID]*nats.Msg // channel for all NATS read errors errs chan error acc telegraf.TrackingAccumulator + sem semaphore wg sync.WaitGroup cancel context.CancelFunc + sync.Mutex } type ( @@ -82,7 +86,9 @@ func (n *NatsConsumer) SetParser(parser telegraf.Parser) { // Start the nats consumer. Caller must call *NatsConsumer.Stop() to clean up. func (n *NatsConsumer) Start(acc telegraf.Accumulator) error { + n.sem = make(semaphore, n.MaxUndeliveredMessages) n.acc = acc.WithTracking(n.MaxUndeliveredMessages) + n.undelivered = make(map[telegraf.TrackingID]*nats.Msg, n.MaxUndeliveredMessages) options := []nats.Option{ nats.MaxReconnects(-1), @@ -125,7 +131,7 @@ func (n *NatsConsumer) Start(acc telegraf.Accumulator) error { // Setup message and error channels n.errs = make(chan error) - n.in = make(chan *nats.Msg, 1000) + n.in = make(chan *nats.Msg, n.PendingMessageLimit) for _, subj := range n.Subjects { sub, err := n.conn.QueueSubscribe(subj, n.QueueGroup, func(m *nats.Msg) { n.in <- m @@ -145,7 +151,9 @@ func (n *NatsConsumer) Start(acc telegraf.Accumulator) error { if len(n.JsSubjects) > 0 { var connErr error - var subOptions []nats.SubOpt + subOptions := []nats.SubOpt{ + nats.ManualAck(), + } if n.JsStream != "" { subOptions = append(subOptions, nats.BindStream(n.JsStream)) } @@ -178,6 +186,13 @@ func (n *NatsConsumer) Start(acc telegraf.Accumulator) error { ctx, cancel := context.WithCancel(context.Background()) n.cancel = cancel + // Start goroutine to handle delivery notifications from accumulator. + n.wg.Add(1) + go func() { + defer n.wg.Done() + n.waitForDelivery(ctx) + }() + // Start the message reader n.wg.Add(1) go func() { @@ -212,43 +227,88 @@ func (n *NatsConsumer) natsErrHandler(c *nats.Conn, s *nats.Subscription, e erro // receiver() reads all incoming messages from NATS, and parses them into // telegraf metrics. func (n *NatsConsumer) receiver(ctx context.Context) { - sem := make(semaphore, n.MaxUndeliveredMessages) - for { + // Acquire a semaphore to block consumption if the number of undelivered messages + // reached it's limit + select { + case <-ctx.Done(): + return + case n.sem <- empty{}: + } + + // Consume messages and errors select { case <-ctx.Done(): return - case <-n.acc.Delivered(): - <-sem case err := <-n.errs: n.Log.Error(err) - case sem <- empty{}: - select { - case <-ctx.Done(): - return - case err := <-n.errs: - <-sem - n.Log.Error(err) - case <-n.acc.Delivered(): - <-sem - <-sem - case msg := <-n.in: - metrics, err := n.parser.Parse(msg.Data) - if err != nil { - n.Log.Errorf("Subject: %s, error: %s", msg.Subject, err.Error()) - <-sem - continue + case msg := <-n.in: + jetstreamMsg := slices.Contains(n.jsSubs, msg.Sub) + + if jetstreamMsg { + if err := msg.InProgress(); err != nil { + n.Log.Warnf("Failed to mark JetStream message as in progress on subject %s: %v", msg.Subject, err) } - if len(metrics) == 0 { - once.Do(func() { - n.Log.Debug(internal.NoMetricsCreatedMsg) - }) + } + + // Parse the metric and add it to the accumulator + metrics, err := n.parser.Parse(msg.Data) + if err != nil { + n.acc.AddError(fmt.Errorf("failed to handle message on subject %s: %w", msg.Subject, err)) + } + if len(metrics) == 0 { + once.Do(func() { + n.Log.Debug(internal.NoMetricsCreatedMsg) + }) + <-n.sem + if jetstreamMsg { + if err := msg.Ack(); err != nil { + n.acc.AddError(fmt.Errorf("failed to acknowledge JetStream message on subject %s: %w", msg.Subject, err)) + } } + } else { for _, m := range metrics { m.AddTag("subject", msg.Subject) } - n.acc.AddTrackingMetricGroup(metrics) + id := n.acc.AddTrackingMetricGroup(metrics) + + // Make sure we manually acknowledge the messages later on delivery to Telegraf output(s) + if jetstreamMsg { + n.Lock() + n.undelivered[id] = msg + n.Unlock() + } + } + } + } +} + +func (n *NatsConsumer) waitForDelivery(ctx context.Context) { + for { + select { + case <-ctx.Done(): + return + case track := <-n.acc.Delivered(): + // Get the tracked metric if any. Please remember, only Jetstream messages support a manual ACK + n.Lock() + msg, ok := n.undelivered[track.ID()] + delete(n.undelivered, track.ID()) + n.Unlock() + + if !ok { + <-n.sem + continue + } + if track.Delivered() { + if err := msg.Ack(); err != nil { + n.Log.Errorf("Failed to acknowledge JetStream message on subject %s: %v", msg.Subject, err) + } + } else { + if err := msg.Term(); err != nil { + n.Log.Errorf("Failed to terminate JetStream message on subject %s: %v", msg.Subject, err) + } } + <-n.sem } } } diff --git a/plugins/inputs/nats_consumer/nats_consumer_test.go b/plugins/inputs/nats_consumer/nats_consumer_test.go index 44f523f962965..996594b17805f 100644 --- a/plugins/inputs/nats_consumer/nats_consumer_test.go +++ b/plugins/inputs/nats_consumer/nats_consumer_test.go @@ -179,6 +179,10 @@ func TestIntegrationSendReceive(t *testing.T) { actual := acc.GetTelegrafMetrics() testutil.RequireMetricsEqual(t, tt.expected, actual, testutil.IgnoreTime(), testutil.SortMetrics()) + + plugin.Lock() + defer plugin.Unlock() + require.Empty(t, plugin.undelivered) }) } } @@ -214,6 +218,7 @@ func TestJetStreamIntegrationSendReceive(t *testing.T) { require.NoError(t, err) // Setup the plugin for JetStream + log := testutil.CaptureLogger{} plugin := &NatsConsumer{ Servers: []string{addr}, JsSubjects: []string{subject}, @@ -222,7 +227,7 @@ func TestJetStreamIntegrationSendReceive(t *testing.T) { PendingBytesLimit: nats.DefaultSubPendingBytesLimit, PendingMessageLimit: nats.DefaultSubPendingMsgsLimit, MaxUndeliveredMessages: defaultMaxUndeliveredMessages, - Log: testutil.Logger{}, + Log: &log, } parser := &influx.Parser{} @@ -258,6 +263,23 @@ func TestJetStreamIntegrationSendReceive(t *testing.T) { ), } testutil.RequireMetricsEqual(t, expected, actual, testutil.IgnoreTime(), testutil.SortMetrics()) + + // Acknowledge the message and check undelivered tracking + log.Clear() + plugin.Lock() + require.Len(t, plugin.undelivered, 1) + plugin.Unlock() + for _, m := range actual { + m.Accept() + } + + require.Eventually(t, func() bool { + plugin.Lock() + defer plugin.Unlock() + return len(plugin.undelivered) == 0 + }, time.Second, 100*time.Millisecond, "undelivered messages not cleared") + + require.Empty(t, log.Messages(), "no warnings or errors should be logged") } func TestJetStreamIntegrationSourcedStreamNotFound(t *testing.T) {