From deca2bd0c32373abafeb41a718f8ccc331a9cf6c Mon Sep 17 00:00:00 2001 From: marco Date: Fri, 27 Feb 2026 10:23:35 +0100 Subject: [PATCH 01/23] appsec/allowlists: replace tomb refresh with context cancellation Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- pkg/appsec/allowlists/allowlists.go | 14 ++++---------- 1 file changed, 4 insertions(+), 10 deletions(-) diff --git a/pkg/appsec/allowlists/allowlists.go b/pkg/appsec/allowlists/allowlists.go index 6e3286ce2c9..34e70a9b124 100644 --- a/pkg/appsec/allowlists/allowlists.go +++ b/pkg/appsec/allowlists/allowlists.go @@ -8,7 +8,6 @@ import ( "time" log "github.com/sirupsen/logrus" - "gopkg.in/tomb.v2" "github.com/crowdsecurity/crowdsec/pkg/apiclient" ) @@ -33,7 +32,6 @@ type AppsecAllowlist struct { ranges []rangeAllowlist lock sync.RWMutex logger *log.Entry - tomb *tomb.Tomb } func NewAppsecAllowlist(logger *log.Entry) *AppsecAllowlist { @@ -107,6 +105,7 @@ func (a *AppsecAllowlist) FetchAllowlists(ctx context.Context) error { func (a *AppsecAllowlist) updateAllowlists(ctx context.Context) { ticker := time.NewTicker(allowlistRefreshInterval) + defer ticker.Stop() for { select { @@ -114,19 +113,14 @@ func (a *AppsecAllowlist) updateAllowlists(ctx context.Context) { if err := a.FetchAllowlists(ctx); err != nil { a.logger.Errorf("failed to fetch allowlists: %s", err) } - case <-a.tomb.Dying(): - ticker.Stop() + case <-ctx.Done(): return } } } -func (a *AppsecAllowlist) StartRefresh(ctx context.Context, t *tomb.Tomb) { - a.tomb = t - a.tomb.Go(func() error { - a.updateAllowlists(ctx) - return nil - }) +func (a *AppsecAllowlist) StartRefresh(ctx context.Context) { + go a.updateAllowlists(ctx) } func (a *AppsecAllowlist) IsAllowlisted(sourceIP string) (bool, string) { From 9131ee25d75d65b86f904e6d7f2f226f1c33c921 Mon Sep 17 00:00:00 2001 From: marco Date: Fri, 27 Feb 2026 10:23:35 +0100 Subject: [PATCH 02/23] appsec acquisition: cancel allowlist refresh from tomb lifecycle Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- pkg/acquisition/modules/appsec/run.go | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/pkg/acquisition/modules/appsec/run.go b/pkg/acquisition/modules/appsec/run.go index deb96fa714d..dc6ecb9008e 100644 --- a/pkg/acquisition/modules/appsec/run.go +++ b/pkg/acquisition/modules/appsec/run.go @@ -127,7 +127,14 @@ func (w *Source) StreamingAcquisition(ctx context.Context, out chan pipeline.Eve if err != nil { w.logger.Errorf("failed to fetch allowlists for appsec, disabling them: %s", err) } else { - w.appsecAllowlistClient.StartRefresh(ctx, t) + refreshCtx, refreshCancel := context.WithCancel(ctx) + t.Go(func() error { + <-t.Dying() + refreshCancel() + return nil + }) + + w.appsecAllowlistClient.StartRefresh(refreshCtx) } t.Go(func() error { From 58ed4cc6783680bf0f8680d1a709dc27307ab2c3 Mon Sep 17 00:00:00 2001 From: marco Date: Fri, 27 Feb 2026 10:23:35 +0100 Subject: [PATCH 03/23] longpollclient: replace tomb lifecycle and handle closed PAPI stream Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- pkg/apiserver/papi.go | 8 +++- pkg/longpollclient/client.go | 87 +++++++++++++++++++++++------------- 2 files changed, 62 insertions(+), 33 deletions(-) diff --git a/pkg/apiserver/papi.go b/pkg/apiserver/papi.go index f7b512e9a7a..d72b237484e 100644 --- a/pkg/apiserver/papi.go +++ b/pkg/apiserver/papi.go @@ -296,7 +296,13 @@ func (p *Papi) Pull(ctx context.Context) error { papiChan = nil p.Logger.Debug("done stopping PAPI pull") } - case event := <-papiChan: + case event, ok := <-papiChan: + if !ok { + p.Logger.Debug("PAPI event stream closed") + papiChan = nil + continue + } + logger := p.Logger.WithField("request-id", event.RequestId) // update last timestamp in database newTime := time.Now().UTC() diff --git a/pkg/longpollclient/client.go b/pkg/longpollclient/client.go index af4665ff1f0..3b538609558 100644 --- a/pkg/longpollclient/client.go +++ b/pkg/longpollclient/client.go @@ -8,17 +8,19 @@ import ( "io" "net/http" "net/url" + "sync" "time" "github.com/google/uuid" log "github.com/sirupsen/logrus" - "gopkg.in/tomb.v2" "github.com/crowdsecurity/crowdsec/pkg/logging" ) type LongPollClient struct { - t tomb.Tomb + mu sync.Mutex + cancelPoll context.CancelFunc + done chan struct{} c chan Event url url.URL logger *log.Entry @@ -89,8 +91,8 @@ func (c *LongPollClient) poll(ctx context.Context) error { defer resp.Body.Close() - requestId := resp.Header.Get("X-Amzn-Trace-Id") - logger = logger.WithField("request-id", requestId) + requestID := resp.Header.Get("X-Amzn-Trace-Id") + logger = logger.WithField("request-id", requestID) if resp.StatusCode != http.StatusOK { c.logger.Errorf("unexpected status code: %d", resp.StatusCode) if resp.StatusCode == http.StatusPaymentRequired { @@ -109,20 +111,15 @@ func (c *LongPollClient) poll(ctx context.Context) error { for { select { - case <-c.t.Dying(): - logger.Debugf("dying") - close(c.c) - return nil case <-ctx.Done(): - logger.Debugf("context canceled") - close(c.c) + logger.Debug("context canceled") return ctx.Err() default: var pollResp pollResponse err = decoder.Decode(&pollResp) if err != nil { if errors.Is(err, io.EOF) { - logger.Debugf("server closed connection") + logger.Debug("server closed connection") return nil } return fmt.Errorf("error decoding poll response: %v", err) @@ -132,7 +129,7 @@ func (c *LongPollClient) poll(ctx context.Context) error { if pollResp.ErrorMessage != "" { if pollResp.ErrorMessage == timeoutMessage { - logger.Debugf("got timeout message") + logger.Debug("got timeout message") return nil } return fmt.Errorf("longpoll API error message: %s", pollResp.ErrorMessage) @@ -141,8 +138,12 @@ func (c *LongPollClient) poll(ctx context.Context) error { if len(pollResp.Events) > 0 { logger.Debugf("got %d events", len(pollResp.Events)) for _, event := range pollResp.Events { - event.RequestId = requestId - c.c <- event + event.RequestId = requestID + select { + case <-ctx.Done(): + return ctx.Err() + case c.c <- event: + } if event.Timestamp > c.since { c.since = event.Timestamp } @@ -157,27 +158,21 @@ func (c *LongPollClient) poll(ctx context.Context) error { } func (c *LongPollClient) pollEvents(ctx context.Context) error { - initialBackoff := 1 * time.Second maxBackoff := 30 * time.Second currentBackoff := initialBackoff for { select { - case <-c.t.Dying(): - c.logger.Debug("dying") - return nil case <-ctx.Done(): c.logger.Debug("context canceled") - return ctx.Err() + return nil default: c.logger.Debug("Polling PAPI") err := c.poll(ctx) if err != nil { if errors.Is(err, errUnauthorized) { c.logger.Errorf("unauthorized, stopping polling") - c.t.Kill(err) - close(c.c) return err } if errors.Is(err, context.Canceled) { @@ -186,10 +181,8 @@ func (c *LongPollClient) pollEvents(ctx context.Context) error { } c.logger.Errorf("failed to poll: %s, retrying in %s", err, currentBackoff) select { - case <-c.t.Dying(): - return nil case <-ctx.Done(): - return ctx.Err() + return nil case <-time.After(currentBackoff): } @@ -205,17 +198,47 @@ func (c *LongPollClient) pollEvents(ctx context.Context) error { } func (c *LongPollClient) Start(ctx context.Context, since time.Time) chan Event { - c.logger.Infof("starting polling client") - c.c = make(chan Event) + _ = c.Stop() + + c.logger.Info("starting polling client") + + pollCtx, cancel := context.WithCancel(ctx) + out := make(chan Event) + done := make(chan struct{}) + + c.mu.Lock() + c.c = out c.since = since.Unix() * 1000 c.timeout = "45" - c.t = tomb.Tomb{} - c.t.Go(func() error { return c.pollEvents(ctx) }) - return c.c + c.cancelPoll = cancel + c.done = done + c.mu.Unlock() + + go func() { + defer close(done) + defer close(out) + _ = c.pollEvents(pollCtx) + }() + + return out } func (c *LongPollClient) Stop() error { - c.t.Kill(nil) + c.mu.Lock() + cancel := c.cancelPoll + done := c.done + c.cancelPoll = nil + c.done = nil + c.mu.Unlock() + + if cancel != nil { + cancel() + } + + if done != nil { + <-done + } + return nil } @@ -235,7 +258,7 @@ func (c *LongPollClient) PullOnce(ctx context.Context, since time.Time) ([]Event err = decoder.Decode(&pollResp) if err != nil { if errors.Is(err, io.EOF) { - c.logger.Debugf("server closed connection") + c.logger.Debug("server closed connection") break } log.Errorf("error decoding poll response: %v", err) @@ -246,7 +269,7 @@ func (c *LongPollClient) PullOnce(ctx context.Context, since time.Time) ([]Event if pollResp.ErrorMessage != "" { if pollResp.ErrorMessage == timeoutMessage { - c.logger.Debugf("got timeout message") + c.logger.Debug("got timeout message") break } log.Errorf("longpoll API error message: %s", pollResp.ErrorMessage) From 23948c16086f9de42c7b1979c45770beadbc8a54 Mon Sep 17 00:00:00 2001 From: marco Date: Fri, 27 Feb 2026 11:52:53 +0100 Subject: [PATCH 04/23] csplugin: migrate watcher/broker lifecycle to context Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- .../clinotifications/notifications.go | 49 ++++++------ cmd/crowdsec/api.go | 10 ++- pkg/csplugin/broker.go | 76 +++++++++---------- pkg/csplugin/broker_test.go | 32 ++++---- pkg/csplugin/broker_win_test.go | 7 +- pkg/csplugin/watcher.go | 45 ++++++----- pkg/csplugin/watcher_test.go | 38 ++++++---- 7 files changed, 146 insertions(+), 111 deletions(-) diff --git a/cmd/crowdsec-cli/clinotifications/notifications.go b/cmd/crowdsec-cli/clinotifications/notifications.go index d0f2234d856..2bba45722f4 100644 --- a/cmd/crowdsec-cli/clinotifications/notifications.go +++ b/cmd/crowdsec-cli/clinotifications/notifications.go @@ -4,7 +4,6 @@ import ( "context" "encoding/csv" "encoding/json" - "errors" "fmt" "io/fs" "net/url" @@ -19,7 +18,6 @@ import ( "github.com/go-openapi/strfmt" log "github.com/sirupsen/logrus" "github.com/spf13/cobra" - "gopkg.in/tomb.v2" "gopkg.in/yaml.v3" "github.com/crowdsecurity/go-cs-lib/ptr" @@ -275,7 +273,6 @@ func (cli *cliNotifications) notificationConfigFilter(_ *cobra.Command, args []s func (cli *cliNotifications) newTestCmd() *cobra.Command { var ( pluginBroker csplugin.PluginBroker - pluginTomb tomb.Tomb alertOverride string ) @@ -315,11 +312,15 @@ func (cli *cliNotifications) newTestCmd() *cobra.Command { }, }, cfg.ConfigPaths) }, - RunE: func(_ *cobra.Command, _ []string) error { - pluginTomb.Go(func() error { - pluginBroker.Run(&pluginTomb) - return nil - }) + RunE: func(cmd *cobra.Command, _ []string) error { + pluginCtx, cancelPlugin := context.WithCancel(cmd.Context()) + pluginDone := make(chan struct{}) + + go func() { + pluginBroker.Run(pluginCtx) + close(pluginDone) + }() + alert := &models.Alert{ Capacity: ptr.Of(int32(0)), Decisions: []*models.Decision{{ @@ -361,15 +362,15 @@ func (cli *cliNotifications) newTestCmd() *cobra.Command { } // time.Sleep(2 * time.Second) // There's no mechanism to ensure notification has been sent - pluginTomb.Kill(errors.New("terminating")) - _ = pluginTomb.Wait() + cancelPlugin() + <-pluginDone return nil }, } cmd.Flags().StringVarP(&alertOverride, "alert", "a", "", - "JSON string used to override alert fields in the generic alert " + - "(see crowdsec/pkg/models/alert.go in the source tree for the full definition of the object)") + "JSON string used to override alert fields in the generic alert "+ + "(see crowdsec/pkg/models/alert.go in the source tree for the full definition of the object)") return cmd } @@ -401,10 +402,7 @@ cscli notifications reinject -a '{"remediation": true,"scenario":"not return nil }, RunE: func(cmd *cobra.Command, _ []string) error { - var ( - pluginBroker csplugin.PluginBroker - pluginTomb tomb.Tomb - ) + var pluginBroker csplugin.PluginBroker ctx := cmd.Context() cfg := cli.cfg() @@ -442,10 +440,13 @@ cscli notifications reinject -a '{"remediation": true,"scenario":"not return fmt.Errorf("can't initialize plugins: %w", err) } - pluginTomb.Go(func() error { - pluginBroker.Run(&pluginTomb) - return nil - }) + pluginCtx, cancelPlugin := context.WithCancel(ctx) + pluginDone := make(chan struct{}) + + go func() { + pluginBroker.Run(pluginCtx) + close(pluginDone) + }() profiles, err := csprofiles.NewProfile(cfg.API.Server.Profiles) if err != nil { @@ -482,15 +483,15 @@ cscli notifications reinject -a '{"remediation": true,"scenario":"not } } // time.Sleep(2 * time.Second) // There's no mechanism to ensure notification has been sent - pluginTomb.Kill(errors.New("terminating")) - _ = pluginTomb.Wait() + cancelPlugin() + <-pluginDone return nil }, } cmd.Flags().StringVarP(&alertOverride, "alert", "a", "", - "JSON string used to override alert fields in the reinjected alert " + - "(see crowdsec/pkg/models/alert.go in the source tree for the full definition of the object)") + "JSON string used to override alert fields in the reinjected alert "+ + "(see crowdsec/pkg/models/alert.go in the source tree for the full definition of the object)") return cmd } diff --git a/cmd/crowdsec/api.go b/cmd/crowdsec/api.go index 25f5ff03988..47e15af49cd 100644 --- a/cmd/crowdsec/api.go +++ b/cmd/crowdsec/api.go @@ -56,8 +56,16 @@ func serveAPIServer(ctx context.Context, apiServer *apiserver.APIServer) { } }() + pluginCtx, cancelPlugin := context.WithCancel(ctx) + + pluginTomb.Go(func() error { + <-pluginTomb.Dying() + cancelPlugin() + return nil + }) + pluginTomb.Go(func() error { - pluginBroker.Run(&pluginTomb) + pluginBroker.Run(pluginCtx) return nil }) diff --git a/pkg/csplugin/broker.go b/pkg/csplugin/broker.go index 3cb8d791a09..c07ab5200ad 100644 --- a/pkg/csplugin/broker.go +++ b/pkg/csplugin/broker.go @@ -17,7 +17,6 @@ import ( "github.com/google/uuid" plugin "github.com/hashicorp/go-plugin" log "github.com/sirupsen/logrus" - "gopkg.in/tomb.v2" "gopkg.in/yaml.v2" "github.com/crowdsecurity/go-cs-lib/csstring" @@ -40,17 +39,17 @@ const ( // It receives all the events from the main process and stacks them up // It is as well notified by the watcher when it needs to deliver events to plugins (based on time or count threshold) type PluginBroker struct { - PluginChannel chan models.ProfileAlert - alertsByPluginName map[string][]*models.Alert - profileConfigs []*csconfig.ProfileCfg - pluginConfigByName map[string]PluginConfig - pluginMap map[string]plugin.Plugin - notificationPluginByName map[string]protobufs.NotifierServer - watcher PluginWatcher - pluginKillMethods []func() - pluginProcConfig *csconfig.PluginCfg - pluginsTypesToDispatch map[string]struct{} - newBackoff backoffFactory + PluginChannel chan models.ProfileAlert + alertsByPluginName map[string][]*models.Alert + profileConfigs []*csconfig.ProfileCfg + pluginConfigByName map[string]PluginConfig + pluginMap map[string]plugin.Plugin + notificationPluginByName map[string]protobufs.NotifierServer + watcher PluginWatcher + pluginKillMethods []func() + pluginProcConfig *csconfig.PluginCfg + pluginsTypesToDispatch map[string]struct{} + newBackoff backoffFactory } // holder to determine where to dispatch config and how to format messages @@ -127,11 +126,22 @@ func (pb *PluginBroker) Kill() { } } -func (pb *PluginBroker) Run(pluginTomb *tomb.Tomb) { +func (pb *PluginBroker) popPluginAlerts(pluginName string) []*models.Alert { + pluginMutex.Lock() + log.Tracef("going to deliver %d alerts to plugin %s", len(pb.alertsByPluginName[pluginName]), pluginName) + tmpAlerts := pb.alertsByPluginName[pluginName] + pb.alertsByPluginName[pluginName] = make([]*models.Alert, 0) + pluginMutex.Unlock() + + return tmpAlerts +} + +func (pb *PluginBroker) Run(ctx context.Context) { // we get signaled via the channel when notifications need to be delivered to plugin (via the watcher) - ctx := context.TODO() + notifyCtx := context.TODO() - pb.watcher.Start(&tomb.Tomb{}) + watcherCtx, cancelWatcher := context.WithCancel(context.Background()) + watcherDone := pb.watcher.Start(watcherCtx) for { select { @@ -139,47 +149,37 @@ func (pb *PluginBroker) Run(pluginTomb *tomb.Tomb) { pb.addProfileAlert(profileAlert) case pluginName := <-pb.watcher.PluginEvents: - // this can be run in goroutine, but then locks will be needed - pluginMutex.Lock() - log.Tracef("going to deliver %d alerts to plugin %s", len(pb.alertsByPluginName[pluginName]), pluginName) - tmpAlerts := pb.alertsByPluginName[pluginName] - pb.alertsByPluginName[pluginName] = make([]*models.Alert, 0) - pluginMutex.Unlock() - - go func() { + tmpAlerts := pb.popPluginAlerts(pluginName) + + go func(pluginName string, alerts []*models.Alert) { // Chunk alerts to respect group_threshold threshold := pb.pluginConfigByName[pluginName].GroupThreshold if threshold == 0 { threshold = 1 } - for _, chunk := range slicetools.Chunks(tmpAlerts, threshold) { - if err := pb.pushNotificationsToPlugin(ctx, pluginName, chunk); err != nil { + for _, chunk := range slicetools.Chunks(alerts, threshold) { + if err := pb.pushNotificationsToPlugin(notifyCtx, pluginName, chunk); err != nil { log.WithField("plugin:", pluginName).Error(err) } } - }() + }(pluginName, tmpAlerts) - case <-pluginTomb.Dying(): - log.Infof("pluginTomb dying") - pb.watcher.tomb.Kill(errors.New("Terminating")) + case <-ctx.Done(): + log.Infof("plugin context canceled") + cancelWatcher() for { select { - case <-pb.watcher.tomb.Dead(): + case <-watcherDone: log.Info("killing all plugins") pb.Kill() return case pluginName := <-pb.watcher.PluginEvents: - // this can be run in goroutine, but then locks will be needed - pluginMutex.Lock() - log.Tracef("going to deliver %d alerts to plugin %s", len(pb.alertsByPluginName[pluginName]), pluginName) - tmpAlerts := pb.alertsByPluginName[pluginName] - pb.alertsByPluginName[pluginName] = make([]*models.Alert, 0) - pluginMutex.Unlock() - - if err := pb.pushNotificationsToPlugin(ctx, pluginName, tmpAlerts); err != nil { + tmpAlerts := pb.popPluginAlerts(pluginName) + + if err := pb.pushNotificationsToPlugin(notifyCtx, pluginName, tmpAlerts); err != nil { log.WithField("plugin:", pluginName).Error(err) } } diff --git a/pkg/csplugin/broker_test.go b/pkg/csplugin/broker_test.go index 07fb1e11c13..04604f658de 100644 --- a/pkg/csplugin/broker_test.go +++ b/pkg/csplugin/broker_test.go @@ -4,6 +4,7 @@ package csplugin import ( "bytes" + "context" "encoding/json" "io" "os" @@ -13,7 +14,6 @@ import ( log "github.com/sirupsen/logrus" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - "gopkg.in/tomb.v2" "gopkg.in/yaml.v3" "github.com/crowdsecurity/go-cs-lib/cstest" @@ -154,8 +154,9 @@ func (s *PluginSuite) TestBrokerNoThreshold() { pb, err := s.InitBroker(ctx, nil) require.NoError(t, err) - tomb := tomb.Tomb{} - go pb.Run(&tomb) + brokerCtx, cancelBroker := context.WithCancel(ctx) + defer cancelBroker() + go pb.Run(brokerCtx) // send one item, it should be processed right now pb.PluginChannel <- models.ProfileAlert{ProfileID: uint(0), Alert: &models.Alert{}} @@ -206,8 +207,9 @@ func (s *PluginSuite) TestBrokerRunGroupAndTimeThreshold_TimeFirst() { pb, err := s.InitBroker(ctx, nil) require.NoError(t, err) - tomb := tomb.Tomb{} - go pb.Run(&tomb) + brokerCtx, cancelBroker := context.WithCancel(ctx) + defer cancelBroker() + go pb.Run(brokerCtx) // send data pb.PluginChannel <- models.ProfileAlert{ProfileID: uint(0), Alert: &models.Alert{}} @@ -243,8 +245,9 @@ func (s *PluginSuite) TestBrokerRunGroupAndTimeThreshold_CountFirst() { pb, err := s.InitBroker(ctx, nil) require.NoError(t, err) - tomb := tomb.Tomb{} - go pb.Run(&tomb) + brokerCtx, cancelBroker := context.WithCancel(ctx) + defer cancelBroker() + go pb.Run(brokerCtx) // send data pb.PluginChannel <- models.ProfileAlert{ProfileID: uint(0), Alert: &models.Alert{}} @@ -284,8 +287,9 @@ func (s *PluginSuite) TestBrokerRunGroupThreshold() { pb, err := s.InitBroker(ctx, nil) require.NoError(t, err) - tomb := tomb.Tomb{} - go pb.Run(&tomb) + brokerCtx, cancelBroker := context.WithCancel(ctx) + defer cancelBroker() + go pb.Run(brokerCtx) // send data pb.PluginChannel <- models.ProfileAlert{ProfileID: uint(0), Alert: &models.Alert{}} @@ -338,8 +342,9 @@ func (s *PluginSuite) TestBrokerRunTimeThreshold() { pb, err := s.InitBroker(ctx, nil) require.NoError(t, err) - tomb := tomb.Tomb{} - go pb.Run(&tomb) + brokerCtx, cancelBroker := context.WithCancel(ctx) + defer cancelBroker() + go pb.Run(brokerCtx) // send data pb.PluginChannel <- models.ProfileAlert{ProfileID: uint(0), Alert: &models.Alert{}} @@ -369,8 +374,9 @@ func (s *PluginSuite) TestBrokerRunSimple() { pb, err := s.InitBroker(ctx, nil) require.NoError(t, err) - tomb := tomb.Tomb{} - go pb.Run(&tomb) + brokerCtx, cancelBroker := context.WithCancel(ctx) + defer cancelBroker() + go pb.Run(brokerCtx) assert.NoFileExists(t, s.outFile) diff --git a/pkg/csplugin/broker_win_test.go b/pkg/csplugin/broker_win_test.go index 351ca2ac7d4..fc065eeb2b2 100644 --- a/pkg/csplugin/broker_win_test.go +++ b/pkg/csplugin/broker_win_test.go @@ -4,6 +4,7 @@ package csplugin import ( "bytes" + "context" "encoding/json" "io" "os" @@ -12,7 +13,6 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - "gopkg.in/tomb.v2" "github.com/crowdsecurity/go-cs-lib/cstest" @@ -73,8 +73,9 @@ func (s *PluginSuite) TestBrokerRun() { pb, err := s.InitBroker(ctx, nil) require.NoError(t, err) - tomb := tomb.Tomb{} - go pb.Run(&tomb) + brokerCtx, cancelBroker := context.WithCancel(ctx) + defer cancelBroker() + go pb.Run(brokerCtx) assert.NoFileExists(t, s.outFile) defer os.Remove(s.outFile) diff --git a/pkg/csplugin/watcher.go b/pkg/csplugin/watcher.go index 6a5d31909d5..c1ed3db8ac8 100644 --- a/pkg/csplugin/watcher.go +++ b/pkg/csplugin/watcher.go @@ -1,11 +1,11 @@ package csplugin import ( + "context" "sync" "time" log "github.com/sirupsen/logrus" - "gopkg.in/tomb.v2" "github.com/crowdsecurity/crowdsec/pkg/models" ) @@ -50,7 +50,8 @@ type PluginWatcher struct { AlertCountByPluginName alertCounterByPluginName PluginEvents chan string Inserts chan string - tomb *tomb.Tomb + wg sync.WaitGroup + done chan struct{} } var DefaultEmptyTicker = time.Second * 1 @@ -65,23 +66,33 @@ func (pw *PluginWatcher) Init(configs map[string]PluginConfig, alertsByPluginNam } } -func (pw *PluginWatcher) Start(tomb *tomb.Tomb) { - pw.tomb = tomb +func (pw *PluginWatcher) Start(ctx context.Context) <-chan struct{} { + pw.done = make(chan struct{}) + for name := range pw.PluginConfigByName { pname := name - pw.tomb.Go(func() error { - pw.watchPluginTicker(pname) - return nil - }) + pw.wg.Add(1) + go func() { + defer pw.wg.Done() + pw.watchPluginTicker(ctx, pname) + }() } - pw.tomb.Go(func() error { - pw.watchPluginAlertCounts() - return nil - }) + pw.wg.Add(1) + go func() { + defer pw.wg.Done() + pw.watchPluginAlertCounts(ctx) + }() + + go func() { + pw.wg.Wait() + close(pw.done) + }() + + return pw.done } -func (pw *PluginWatcher) watchPluginTicker(pluginName string) { +func (pw *PluginWatcher) watchPluginTicker(ctx context.Context, pluginName string) { cfg := pw.PluginConfigByName[pluginName] interval := cfg.GroupWait threshold := cfg.GroupThreshold @@ -142,15 +153,15 @@ func (pw *PluginWatcher) watchPluginTicker(pluginName string) { lastSend = now pw.AlertCountByPluginName.Set(pluginName, 0) pw.PluginEvents <- pluginName - case <-pw.tomb.Dying(): - // no lock here because we have the broker still listening even in dying state before killing us + case <-ctx.Done(): + // no lock here because the broker is still listening during shutdown. pw.PluginEvents <- pluginName return } } } -func (pw *PluginWatcher) watchPluginAlertCounts() { +func (pw *PluginWatcher) watchPluginAlertCounts(ctx context.Context) { for { select { case pluginName := <-pw.Inserts: @@ -160,7 +171,7 @@ func (pw *PluginWatcher) watchPluginAlertCounts() { curr = 0 } pw.AlertCountByPluginName.Set(pluginName, curr+1) - case <-pw.tomb.Dying(): + case <-ctx.Done(): return } } diff --git a/pkg/csplugin/watcher_test.go b/pkg/csplugin/watcher_test.go index 4a3a1bacea3..6f1efe59dd6 100644 --- a/pkg/csplugin/watcher_test.go +++ b/pkg/csplugin/watcher_test.go @@ -7,19 +7,26 @@ import ( "time" "github.com/stretchr/testify/require" - "gopkg.in/tomb.v2" "github.com/crowdsecurity/go-cs-lib/cstest" "github.com/crowdsecurity/crowdsec/pkg/models" ) -func resetTestTomb(t *testing.T, testTomb *tomb.Tomb, pw *PluginWatcher) { - testTomb.Kill(nil) - <-pw.PluginEvents +func stopWatcher(t *testing.T, cancel context.CancelFunc, pw *PluginWatcher, done <-chan struct{}) { + cancel() - err := testTomb.Wait() - require.NoError(t, err) + select { + case <-pw.PluginEvents: + case <-time.After(time.Second): + require.FailNow(t, "timeout waiting for watcher flush event") + } + + select { + case <-done: + case <-time.After(time.Second): + require.FailNow(t, "timeout waiting for watcher shutdown") + } } func resetWatcherAlertCounter(pw *PluginWatcher) { @@ -54,23 +61,24 @@ func TestPluginWatcherInterval(t *testing.T) { pw := PluginWatcher{} alertsByPluginName := make(map[string][]*models.Alert) - testTomb := tomb.Tomb{} + watcherCtx, cancelWatcher := context.WithCancel(ctx) configs := map[string]PluginConfig{ "testPlugin": { GroupWait: time.Millisecond, }, } pw.Init(configs, alertsByPluginName) - pw.Start(&testTomb) + watcherDone := pw.Start(watcherCtx) ct, cancel := context.WithTimeout(ctx, time.Microsecond) defer cancel() err := listenChannelWithTimeout(ct, pw.PluginEvents) cstest.RequireErrorContains(t, err, "context deadline exceeded") - resetTestTomb(t, &testTomb, &pw) - testTomb = tomb.Tomb{} - pw.Start(&testTomb) + stopWatcher(t, cancelWatcher, &pw, watcherDone) + + watcherCtx, cancelWatcher = context.WithCancel(ctx) + watcherDone = pw.Start(watcherCtx) insertNAlertsToPlugin(&pw, 1, "testPlugin") ct, cancel = context.WithTimeout(ctx, time.Millisecond*5) @@ -78,7 +86,7 @@ func TestPluginWatcherInterval(t *testing.T) { err = listenChannelWithTimeout(ct, pw.PluginEvents) require.NoError(t, err) - resetTestTomb(t, &testTomb, &pw) + stopWatcher(t, cancelWatcher, &pw, watcherDone) // This is to avoid the int complaining } @@ -94,10 +102,10 @@ func TestPluginAlertCountWatcher(t *testing.T) { GroupThreshold: 5, }, } - testTomb := tomb.Tomb{} + watcherCtx, cancelWatcher := context.WithCancel(ctx) pw.Init(configs, alertsByPluginName) - pw.Start(&testTomb) + watcherDone := pw.Start(watcherCtx) // Channel won't contain any events since threshold is not crossed. ct, cancel := context.WithTimeout(ctx, time.Second) @@ -125,5 +133,5 @@ func TestPluginAlertCountWatcher(t *testing.T) { err = listenChannelWithTimeout(ct, pw.PluginEvents) require.NoError(t, err) - resetTestTomb(t, &testTomb, &pw) + stopWatcher(t, cancelWatcher, &pw, watcherDone) } From c29f1bb665a092423ef29924cf15f4d0dc05335c Mon Sep 17 00:00:00 2001 From: marco Date: Fri, 27 Feb 2026 14:11:05 +0100 Subject: [PATCH 05/23] csplugin: fix lint issues and cap shutdown flush Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- .../clinotifications/notifications.go | 16 ++++++++-------- pkg/csplugin/broker.go | 14 +++++++++++--- 2 files changed, 19 insertions(+), 11 deletions(-) diff --git a/cmd/crowdsec-cli/clinotifications/notifications.go b/cmd/crowdsec-cli/clinotifications/notifications.go index 2bba45722f4..d7fa1ef27a7 100644 --- a/cmd/crowdsec-cli/clinotifications/notifications.go +++ b/cmd/crowdsec-cli/clinotifications/notifications.go @@ -320,6 +320,10 @@ func (cli *cliNotifications) newTestCmd() *cobra.Command { pluginBroker.Run(pluginCtx) close(pluginDone) }() + defer func() { + cancelPlugin() + <-pluginDone + }() alert := &models.Alert{ Capacity: ptr.Of(int32(0)), @@ -361,10 +365,6 @@ func (cli *cliNotifications) newTestCmd() *cobra.Command { Alert: alert, } - // time.Sleep(2 * time.Second) // There's no mechanism to ensure notification has been sent - cancelPlugin() - <-pluginDone - return nil }, } @@ -447,6 +447,10 @@ cscli notifications reinject -a '{"remediation": true,"scenario":"not pluginBroker.Run(pluginCtx) close(pluginDone) }() + defer func() { + cancelPlugin() + <-pluginDone + }() profiles, err := csprofiles.NewProfile(cfg.API.Server.Profiles) if err != nil { @@ -482,10 +486,6 @@ cscli notifications reinject -a '{"remediation": true,"scenario":"not break } } - // time.Sleep(2 * time.Second) // There's no mechanism to ensure notification has been sent - cancelPlugin() - <-pluginDone - return nil }, } diff --git a/pkg/csplugin/broker.go b/pkg/csplugin/broker.go index c07ab5200ad..b42a050e9c1 100644 --- a/pkg/csplugin/broker.go +++ b/pkg/csplugin/broker.go @@ -138,9 +138,9 @@ func (pb *PluginBroker) popPluginAlerts(pluginName string) []*models.Alert { func (pb *PluginBroker) Run(ctx context.Context) { // we get signaled via the channel when notifications need to be delivered to plugin (via the watcher) - notifyCtx := context.TODO() + notifyCtx := ctx - watcherCtx, cancelWatcher := context.WithCancel(context.Background()) + watcherCtx, cancelWatcher := context.WithCancel(ctx) watcherDone := pb.watcher.Start(watcherCtx) for { @@ -168,18 +168,26 @@ func (pb *PluginBroker) Run(ctx context.Context) { case <-ctx.Done(): log.Infof("plugin context canceled") cancelWatcher() + flushCtx, cancelFlush := context.WithTimeout(context.WithoutCancel(ctx), 5*time.Second) for { select { case <-watcherDone: + cancelFlush() log.Info("killing all plugins") pb.Kill() + return + case <-flushCtx.Done(): + cancelFlush() + log.Warn("plugin flush timeout reached, stopping pending notifications") + pb.Kill() + return case pluginName := <-pb.watcher.PluginEvents: tmpAlerts := pb.popPluginAlerts(pluginName) - if err := pb.pushNotificationsToPlugin(notifyCtx, pluginName, tmpAlerts); err != nil { + if err := pb.pushNotificationsToPlugin(flushCtx, pluginName, tmpAlerts); err != nil { log.WithField("plugin:", pluginName).Error(err) } } From b6f8b4868c7564546a184a8a4aa2cdc00ca7912c Mon Sep 17 00:00:00 2001 From: marco Date: Fri, 27 Feb 2026 15:10:42 +0100 Subject: [PATCH 06/23] apiserver/papi: migrate sync lifecycle from tomb to context Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- pkg/apiserver/apiserver.go | 14 +---------- pkg/apiserver/papi.go | 49 +++++++++++++++++++++++++++++++++----- 2 files changed, 44 insertions(+), 19 deletions(-) diff --git a/pkg/apiserver/apiserver.go b/pkg/apiserver/apiserver.go index a06c1ba9bde..094acc3eef7 100644 --- a/pkg/apiserver/apiserver.go +++ b/pkg/apiserver/apiserver.go @@ -275,15 +275,6 @@ func (s *APIServer) papiPull(ctx context.Context) error { return nil } -func (s *APIServer) papiSync(ctx context.Context) error { - if err := s.papi.SyncDecisions(ctx); err != nil { - log.Errorf("capi decisions sync: %s", err) - return err - } - - return nil -} - func (s *APIServer) initAPIC(ctx context.Context) { s.apic.pushTomb.Go(func() error { defer trace.ReportPanic() @@ -302,10 +293,7 @@ func (s *APIServer) initAPIC(ctx context.Context) { defer trace.ReportPanic() return s.papiPull(ctx) }) - s.papi.syncTomb.Go(func() error { - defer trace.ReportPanic() - return s.papiSync(ctx) - }) + s.papi.StartSync(ctx) } else { log.Warnf("papi_url is not set in online_api_credentials.yaml, can't synchronize with the console. Run cscli console enable console_management to add it.") } diff --git a/pkg/apiserver/papi.go b/pkg/apiserver/papi.go index d72b237484e..67d0957af8b 100644 --- a/pkg/apiserver/papi.go +++ b/pkg/apiserver/papi.go @@ -66,7 +66,8 @@ type Papi struct { Channels *OperationChannels mu sync.Mutex pullTomb tomb.Tomb - syncTomb tomb.Tomb + syncCancel context.CancelFunc + syncDone chan struct{} SyncInterval time.Duration consoleConfig *csconfig.ConsoleConfig Logger *log.Entry @@ -114,7 +115,6 @@ func NewPAPI(apic *apic, dbClient *database.Client, consoleConfig *csconfig.Cons SyncInterval: SyncInterval, mu: sync.Mutex{}, pullTomb: tomb.Tomb{}, - syncTomb: tomb.Tomb{}, apiClient: apic.apiClient, apic: apic, consoleConfig: consoleConfig, @@ -329,22 +329,59 @@ func (p *Papi) Pull(ctx context.Context) error { } } +func (p *Papi) StartSync(ctx context.Context) { + p.mu.Lock() + if p.syncCancel != nil { + p.mu.Unlock() + return + } + + syncCtx, cancel := context.WithCancel(ctx) + done := make(chan struct{}) + p.syncCancel = cancel + p.syncDone = done + p.mu.Unlock() + + go func() { + defer close(done) + _ = p.SyncDecisions(syncCtx) + }() +} + +func (p *Papi) StopSync() { + p.mu.Lock() + cancel := p.syncCancel + done := p.syncDone + p.syncCancel = nil + p.syncDone = nil + p.mu.Unlock() + + if cancel != nil { + cancel() + } + + if done != nil { + <-done + } +} + func (p *Papi) SyncDecisions(ctx context.Context) error { var cache models.DecisionsDeleteRequest ticker := time.NewTicker(p.SyncInterval) + defer ticker.Stop() p.Logger.Infof("Start decisions sync to CrowdSec Central API (interval: %s)", p.SyncInterval) for { select { - case <-p.syncTomb.Dying(): // if one apic routine is dying, do we kill the others? - p.Logger.Infof("sync decisions tomb is dying, sending cache (%d elements) before exiting", len(cache)) + case <-ctx.Done(): + p.Logger.Infof("sync decisions context canceled, sending cache (%d elements) before exiting", len(cache)) if len(cache) == 0 { return nil } - go p.SendDeletedDecisions(ctx, &cache) + go p.SendDeletedDecisions(context.WithoutCancel(ctx), &cache) return nil case <-ticker.C: @@ -404,7 +441,7 @@ func (p *Papi) SendDeletedDecisions(ctx context.Context, cacheOrig *models.Decis func (p *Papi) Shutdown() { p.Logger.Infof("Shutting down PAPI") - p.syncTomb.Kill(nil) + p.StopSync() select { case p.stopChan <- struct{}{}: // Cancel any HTTP request still in progress default: From a75a26442212e3a16a378271284bc8f78ba31505 Mon Sep 17 00:00:00 2001 From: marco Date: Fri, 27 Feb 2026 15:36:56 +0100 Subject: [PATCH 07/23] apiserver/apic: migrate metrics lifecycle to context Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- pkg/apiserver/apic.go | 62 ++++++++++++++++++++++++++++++++--- pkg/apiserver/apic_metrics.go | 19 ++++------- pkg/apiserver/apic_test.go | 11 +++---- pkg/apiserver/apiserver.go | 14 +------- 4 files changed, 70 insertions(+), 36 deletions(-) diff --git a/pkg/apiserver/apic.go b/pkg/apiserver/apic.go index 23a022901b6..783ab733236 100644 --- a/pkg/apiserver/apic.go +++ b/pkg/apiserver/apic.go @@ -63,7 +63,8 @@ type apic struct { mu sync.Mutex pushTomb tomb.Tomb pullTomb tomb.Tomb - metricsTomb tomb.Tomb + metricsCancel context.CancelFunc + metricsDone chan struct{} startup bool consoleConfig *csconfig.ConsoleConfig isPulling chan bool @@ -195,7 +196,6 @@ func NewAPIC(ctx context.Context, config *csconfig.OnlineApiClientCfg, dbClient startup: true, pullTomb: tomb.Tomb{}, pushTomb: tomb.Tomb{}, - metricsTomb: tomb.Tomb{}, consoleConfig: consoleConfig, pullInterval: pullIntervalDefault, pullIntervalFirst: randomDuration(pullIntervalDefault, pullIntervalDelta), @@ -295,7 +295,7 @@ func (a *apic) Push(ctx context.Context) error { select { case <-a.pushTomb.Dying(): // if one apic routine is dying, do we kill the others? a.pullTomb.Kill(nil) - a.metricsTomb.Kill(nil) + a.StopMetrics() log.Infof("push tomb is dying, sending cache (%d elements) before exiting", len(cache)) if len(cache) == 0 { @@ -1073,7 +1073,7 @@ func (a *apic) Pull(ctx context.Context) error { continue } case <-a.pullTomb.Dying(): // if one apic routine is dying, do we kill the others? - a.metricsTomb.Kill(nil) + a.StopMetrics() a.pushTomb.Kill(nil) return nil @@ -1081,10 +1081,62 @@ func (a *apic) Pull(ctx context.Context) error { } } +func (a *apic) StartMetrics(ctx context.Context, sendUsageMetrics bool) { + a.mu.Lock() + if a.metricsCancel != nil { + a.mu.Unlock() + return + } + + metricsCtx, cancel := context.WithCancel(ctx) + done := make(chan struct{}) + a.metricsCancel = cancel + a.metricsDone = done + a.mu.Unlock() + + go func() { + var wg sync.WaitGroup + + wg.Add(1) + go func() { + defer wg.Done() + a.SendMetrics(metricsCtx, make(chan bool)) + }() + + if sendUsageMetrics { + wg.Add(1) + go func() { + defer wg.Done() + a.SendUsageMetrics(metricsCtx) + }() + } + + wg.Wait() + close(done) + }() +} + +func (a *apic) StopMetrics() { + a.mu.Lock() + cancel := a.metricsCancel + done := a.metricsDone + a.metricsCancel = nil + a.metricsDone = nil + a.mu.Unlock() + + if cancel != nil { + cancel() + } + + if done != nil { + <-done + } +} + func (a *apic) Shutdown() { a.pushTomb.Kill(nil) a.pullTomb.Kill(nil) - a.metricsTomb.Kill(nil) + a.StopMetrics() } func makeAddAndDeleteCounters() (map[string]map[string]int, map[string]map[string]int) { diff --git a/pkg/apiserver/apic_metrics.go b/pkg/apiserver/apic_metrics.go index bc49b0cdc73..53b3e41be3b 100644 --- a/pkg/apiserver/apic_metrics.go +++ b/pkg/apiserver/apic_metrics.go @@ -291,13 +291,16 @@ func (a *apic) SendMetrics(ctx context.Context, stop chan bool) { checkTicker := time.NewTicker(checkInt) metTicker := time.NewTicker(nextMetInt()) + defer checkTicker.Stop() + defer metTicker.Stop() for { select { case <-stop: - checkTicker.Stop() - metTicker.Stop() - + return + case <-ctx.Done(): + a.pullTomb.Kill(nil) + a.pushTomb.Kill(nil) return case <-checkTicker.C: oldIDs := machineIDs @@ -326,13 +329,6 @@ func (a *apic) SendMetrics(ctx context.Context, stop chan bool) { } metTicker.Reset(nextMetInt()) - case <-a.metricsTomb.Dying(): // if one apic routine is dying, do we kill the others? - checkTicker.Stop() - metTicker.Stop() - a.pullTomb.Kill(nil) - a.pushTomb.Kill(nil) - - return } } } @@ -345,8 +341,7 @@ func (a *apic) SendUsageMetrics(ctx context.Context) { for { select { - case <-a.metricsTomb.Dying(): - // The normal metrics routine also kills push/pull tombs, does that make sense ? + case <-ctx.Done(): ticker.Stop() return case <-ticker.C: diff --git a/pkg/apiserver/apic_test.go b/pkg/apiserver/apic_test.go index 3015a2a24ba..9b059983ebd 100644 --- a/pkg/apiserver/apic_test.go +++ b/pkg/apiserver/apic_test.go @@ -57,12 +57,11 @@ func getAPIC(t *testing.T, ctx context.Context) *apic { return &apic{ AlertsAddChan: make(chan []*models.Alert), // DecisionDeleteChan: make(chan []*models.Decision), - dbClient: dbClient, - mu: sync.Mutex{}, - startup: true, - pullTomb: tomb.Tomb{}, - pushTomb: tomb.Tomb{}, - metricsTomb: tomb.Tomb{}, + dbClient: dbClient, + mu: sync.Mutex{}, + startup: true, + pullTomb: tomb.Tomb{}, + pushTomb: tomb.Tomb{}, consoleConfig: &csconfig.ConsoleConfig{ ShareManualDecisions: ptr.Of(false), ShareTaintedScenarios: ptr.Of(false), diff --git a/pkg/apiserver/apiserver.go b/pkg/apiserver/apiserver.go index 094acc3eef7..f79eeeabbea 100644 --- a/pkg/apiserver/apiserver.go +++ b/pkg/apiserver/apiserver.go @@ -302,19 +302,7 @@ func (s *APIServer) initAPIC(ctx context.Context) { } } - s.apic.metricsTomb.Go(func() error { - defer trace.ReportPanic() - s.apic.SendMetrics(ctx, make(chan bool)) - return nil - }) - - if !s.cfg.DisableUsageMetricsExport { - s.apic.metricsTomb.Go(func() error { - defer trace.ReportPanic() - s.apic.SendUsageMetrics(ctx) - return nil - }) - } + s.apic.StartMetrics(ctx, !s.cfg.DisableUsageMetricsExport) } func (s *APIServer) Run(ctx context.Context, apiReady chan bool) error { From b3f1e383f819fe12b04440cfc8bf886c2bc5c835 Mon Sep 17 00:00:00 2001 From: marco Date: Fri, 27 Feb 2026 17:02:48 +0100 Subject: [PATCH 08/23] acquisition: migrate file streaming to context Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- pkg/acquisition/modules/file/file_test.go | 56 ++++++++++++------ pkg/acquisition/modules/file/init.go | 10 ++-- pkg/acquisition/modules/file/run.go | 69 ++++++++++++++++------- 3 files changed, 92 insertions(+), 43 deletions(-) diff --git a/pkg/acquisition/modules/file/file_test.go b/pkg/acquisition/modules/file/file_test.go index faf2d1888cc..a837f76e368 100644 --- a/pkg/acquisition/modules/file/file_test.go +++ b/pkg/acquisition/modules/file/file_test.go @@ -1,10 +1,12 @@ package fileacquisition_test import ( + "context" "fmt" "os" "path/filepath" "runtime" + "strings" "sync/atomic" "testing" "time" @@ -13,7 +15,6 @@ import ( "github.com/sirupsen/logrus/hooks/test" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - "gopkg.in/tomb.v2" "github.com/crowdsecurity/go-cs-lib/cstest" @@ -340,8 +341,9 @@ force_inotify: true`, testPattern), subLogger := logger.WithField("type", fileacquisition.ModuleName) - tomb := tomb.Tomb{} out := make(chan pipeline.Event) + streamCtx, cancelStream := context.WithCancel(ctx) + streamErr := make(chan error, 1) f := fileacquisition.Source{} @@ -378,8 +380,13 @@ force_inotify: true`, testPattern), }() } - err = f.StreamingAcquisition(ctx, out, &tomb) - cstest.RequireErrorContains(t, err, tc.expectedErr) + go func() { + streamErr <- f.Stream(streamCtx, out) + }() + + if tc.expectedLines == 0 { + time.Sleep(200 * time.Millisecond) + } if tc.expectedLines != 0 { // f.IsTailing is path delimiter sensitive @@ -422,11 +429,18 @@ force_inotify: true`, testPattern), } if tc.expectedOutput != "" { - if hook.LastEntry() == nil { + found := false + for _, entry := range hook.AllEntries() { + if strings.Contains(entry.Message, tc.expectedOutput) { + found = true + break + } + } + + if !found { t.Fatalf("expected output %s, but got nothing", tc.expectedOutput) } - assert.Contains(t, hook.LastEntry().Message, tc.expectedOutput) hook.Reset() } @@ -434,7 +448,9 @@ force_inotify: true`, testPattern), tc.teardown() } - tomb.Kill(nil) + cancelStream() + err = <-streamErr + cstest.RequireErrorContains(t, err, tc.expectedErr) }) } } @@ -481,11 +497,14 @@ mode: tail // Create channel for events eventChan := make(chan pipeline.Event) - tomb := tomb.Tomb{} + streamCtx, cancelStream := context.WithCancel(ctx) + defer cancelStream() + streamErr := make(chan error, 1) // Start acquisition - err = f.StreamingAcquisition(ctx, eventChan, &tomb) - require.NoError(t, err) + go func() { + streamErr <- f.Stream(streamCtx, eventChan) + }() // Create a test file testFile := filepath.Join(dir, "test.log") @@ -503,8 +522,8 @@ mode: tail require.False(t, f.IsTailing(ignoredFile), "File should be ignored after polling") // Cleanup - tomb.Kill(nil) - require.NoError(t, tomb.Wait()) + cancelStream() + require.NoError(t, <-streamErr) } func TestFileResurrectionViaPolling(t *testing.T) { @@ -532,10 +551,13 @@ mode: tail require.NoError(t, err) eventChan := make(chan pipeline.Event) - tomb := tomb.Tomb{} + streamCtx, cancelStream := context.WithCancel(ctx) + defer cancelStream() + streamErr := make(chan error, 1) - err = f.StreamingAcquisition(ctx, eventChan, &tomb) - require.NoError(t, err) + go func() { + streamErr <- f.Stream(streamCtx, eventChan) + }() // Wait for initial tail setup time.Sleep(100 * time.Millisecond) @@ -553,6 +575,6 @@ mode: tail require.True(t, isTailed, "File should be resurrected via polling") // Cleanup - tomb.Kill(nil) - require.NoError(t, tomb.Wait()) + cancelStream() + require.NoError(t, <-streamErr) } diff --git a/pkg/acquisition/modules/file/init.go b/pkg/acquisition/modules/file/init.go index 42c54eba77f..2c63e8f48b3 100644 --- a/pkg/acquisition/modules/file/init.go +++ b/pkg/acquisition/modules/file/init.go @@ -7,11 +7,11 @@ import ( var ( // verify interface compliance - _ types.DataSource = (*Source)(nil) - _ types.DSNConfigurer = (*Source)(nil) - _ types.BatchFetcher = (*Source)(nil) - _ types.Tailer = (*Source)(nil) - _ types.MetricsProvider = (*Source)(nil) + _ types.DataSource = (*Source)(nil) + _ types.DSNConfigurer = (*Source)(nil) + _ types.BatchFetcher = (*Source)(nil) + _ types.RestartableStreamer = (*Source)(nil) + _ types.MetricsProvider = (*Source)(nil) ) const ModuleName = "file" diff --git a/pkg/acquisition/modules/file/run.go b/pkg/acquisition/modules/file/run.go index 37434b8739c..76062de7908 100644 --- a/pkg/acquisition/modules/file/run.go +++ b/pkg/acquisition/modules/file/run.go @@ -5,6 +5,7 @@ import ( "cmp" "compress/gzip" "context" + "errors" "fmt" "io" "os" @@ -16,7 +17,6 @@ import ( "github.com/nxadm/tail" "github.com/prometheus/client_golang/prometheus" log "github.com/sirupsen/logrus" - "gopkg.in/tomb.v2" "github.com/crowdsecurity/go-cs-lib/trace" @@ -53,19 +53,46 @@ func (s *Source) OneShot(ctx context.Context, out chan pipeline.Event) error { return nil } -func (s *Source) StreamingAcquisition(_ context.Context, out chan pipeline.Event, t *tomb.Tomb) error { +func sendStreamErr(errCh chan<- error, err error) { + if err == nil { + return + } + + select { + case errCh <- err: + default: + } +} + +func (s *Source) Stream(ctx context.Context, out chan pipeline.Event) error { s.logger.Debug("Starting live acquisition") - t.Go(func() error { - return s.monitorNewFiles(out, t) - }) + + ctx, cancel := context.WithCancel(ctx) + defer cancel() + + errCh := make(chan error, 1) + go func() { + sendStreamErr(errCh, s.monitorNewFiles(ctx, out, errCh)) + }() for _, file := range s.files { - if err := s.setupTailForFile(file, out, true, t); err != nil { + if err := s.setupTailForFile(ctx, file, out, true, errCh); err != nil { s.logger.Errorf("Error setting up tail for %s: %s", file, err) } } - return nil + for { + select { + case <-ctx.Done(): + return nil + case err := <-errCh: + if err == nil || errors.Is(err, context.Canceled) || ctx.Err() != nil { + return nil + } + + return err + } + } } // checkAndTailFile validates and sets up tailing for a given file. It performs the following checks: @@ -77,10 +104,10 @@ func (s *Source) StreamingAcquisition(_ context.Context, out chan pipeline.Event // - filename: The path to the file to check and potentially tail // - logger: A log.Entry for contextual logging // - out: Channel to send file events to -// - t: A tomb.Tomb for graceful shutdown handling +// - ctx: context used for graceful shutdown handling // // Returns an error if any validation fails or if tailing setup fails -func (s *Source) checkAndTailFile(filename string, logger *log.Entry, out chan pipeline.Event, t *tomb.Tomb) error { +func (s *Source) checkAndTailFile(ctx context.Context, filename string, logger *log.Entry, out chan pipeline.Event, errCh chan<- error) error { // Check if it's a directory fi, err := os.Stat(filename) if err != nil { @@ -117,7 +144,7 @@ func (s *Source) checkAndTailFile(filename string, logger *log.Entry, out chan p } // Setup the tail if needed - if err := s.setupTailForFile(filename, out, false, t); err != nil { + if err := s.setupTailForFile(ctx, filename, out, false, errCh); err != nil { logger.Errorf("Error setting up tail for file %s: %s", filename, err) return err } @@ -125,13 +152,13 @@ func (s *Source) checkAndTailFile(filename string, logger *log.Entry, out chan p return nil } -func (s *Source) monitorNewFiles(out chan pipeline.Event, t *tomb.Tomb) error { +func (s *Source) monitorNewFiles(ctx context.Context, out chan pipeline.Event, errCh chan<- error) error { logger := s.logger.WithField("goroutine", "inotify") // Setup polling if enabled var ( tickerChan <-chan time.Time - ticker *time.Ticker + ticker *time.Ticker ) if s.config.DiscoveryPollEnable { @@ -154,7 +181,7 @@ func (s *Source) monitorNewFiles(out chan pipeline.Event, t *tomb.Tomb) error { continue } - _ = s.checkAndTailFile(event.Name, logger, out, t) + _ = s.checkAndTailFile(ctx, event.Name, logger, out, errCh) case <-tickerChan: // Will never trigger if tickerChan is nil // Poll for all configured patterns @@ -166,7 +193,7 @@ func (s *Source) monitorNewFiles(out chan pipeline.Event, t *tomb.Tomb) error { } for _, file := range files { - _ = s.checkAndTailFile(file, logger, out, t) + _ = s.checkAndTailFile(ctx, file, logger, out, errCh) } } @@ -177,7 +204,7 @@ func (s *Source) monitorNewFiles(out chan pipeline.Event, t *tomb.Tomb) error { logger.Errorf("Error while monitoring folder: %s", err) - case <-t.Dying(): + case <-ctx.Done(): err := s.watcher.Close() if err != nil { return fmt.Errorf("could not remove all inotify watches: %w", err) @@ -188,7 +215,7 @@ func (s *Source) monitorNewFiles(out chan pipeline.Event, t *tomb.Tomb) error { } } -func (s *Source) setupTailForFile(file string, out chan pipeline.Event, seekEnd bool, t *tomb.Tomb) error { +func (s *Source) setupTailForFile(ctx context.Context, file string, out chan pipeline.Event, seekEnd bool, errCh chan<- error) error { logger := s.logger.WithField("file", file) if s.isExcluded(file) { @@ -283,21 +310,21 @@ func (s *Source) setupTailForFile(file string, out chan pipeline.Event, seekEnd s.tails[file] = true s.tailMapMutex.Unlock() - t.Go(func() error { + go func() { defer trace.ReportPanic() - return s.tailFile(out, t, tail) - }) + sendStreamErr(errCh, s.tailFile(ctx, out, tail)) + }() return nil } -func (s *Source) tailFile(out chan pipeline.Event, t *tomb.Tomb, tail *tail.Tail) error { +func (s *Source) tailFile(ctx context.Context, out chan pipeline.Event, tail *tail.Tail) error { logger := s.logger.WithField("tail", tail.Filename) logger.Debug("-> start tailing") for { select { - case <-t.Dying(): + case <-ctx.Done(): logger.Info("File datasource stopping") if err := tail.Stop(); err != nil { From f94fad7fa8948e86742356eaeca36c7b17764395 Mon Sep 17 00:00:00 2001 From: marco Date: Fri, 27 Feb 2026 17:15:37 +0100 Subject: [PATCH 09/23] acquisition: migrate http streaming to context Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- pkg/acquisition/modules/http/http_test.go | 121 ++++++++-------------- pkg/acquisition/modules/http/init.go | 6 +- pkg/acquisition/modules/http/run.go | 74 ++++++------- 3 files changed, 87 insertions(+), 114 deletions(-) diff --git a/pkg/acquisition/modules/http/http_test.go b/pkg/acquisition/modules/http/http_test.go index 12f1c959e62..67c6617e39c 100644 --- a/pkg/acquisition/modules/http/http_test.go +++ b/pkg/acquisition/modules/http/http_test.go @@ -20,7 +20,6 @@ import ( log "github.com/sirupsen/logrus" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - "gopkg.in/tomb.v2" "github.com/crowdsecurity/crowdsec/pkg/metrics" "github.com/crowdsecurity/crowdsec/pkg/pipeline" @@ -56,7 +55,7 @@ func TestGetName(t *testing.T) { assert.Equal(t, "http", h.GetName()) } -func SetupAndRunHTTPSource(t *testing.T, h *Source, config []byte, metricLevel metrics.AcquisitionMetricsLevel) (chan pipeline.Event, *prometheus.Registry, *tomb.Tomb) { +func SetupAndRunHTTPSource(t *testing.T, h *Source, config []byte, metricLevel metrics.AcquisitionMetricsLevel) (chan pipeline.Event, *prometheus.Registry, context.CancelFunc, chan error) { ctx := t.Context() subLogger := log.WithFields(log.Fields{ "type": ModuleName, @@ -64,10 +63,12 @@ func SetupAndRunHTTPSource(t *testing.T, h *Source, config []byte, metricLevel m err := h.Configure(ctx, config, subLogger, metricLevel) require.NoError(t, err) - tomb := tomb.Tomb{} + streamCtx, cancel := context.WithCancel(ctx) + streamErr := make(chan error, 1) out := make(chan pipeline.Event) - err = h.StreamingAcquisition(ctx, out, &tomb) - require.NoError(t, err) + go func() { + streamErr <- h.Stream(streamCtx, out) + }() testRegistry := prometheus.NewPedanticRegistry() for _, metric := range h.GetMetrics() { @@ -75,12 +76,22 @@ func SetupAndRunHTTPSource(t *testing.T, h *Source, config []byte, metricLevel m require.NoError(t, err) } - return out, testRegistry, &tomb + return out, testRegistry, cancel, streamErr +} + +func stopHTTPSource(t *testing.T, h *Source, cancel context.CancelFunc, streamErr <-chan error) { + t.Helper() + + if h.Server != nil { + _ = h.Server.Close() + } + cancel() + require.NoError(t, <-streamErr) } func TestStreamingAcquisitionHTTPMethod(t *testing.T) { h := &Source{} - _, _, tomb := SetupAndRunHTTPSource(t, h, []byte(` + _, _, cancel, streamErr := SetupAndRunHTTPSource(t, h, []byte(` source: http listen_addr: 127.0.0.1:8080 path: /test @@ -122,17 +133,14 @@ basic_auth: assert.Equal(t, http.StatusOK, res.StatusCode) closeBody(t, res) - h.Server.Close() - tomb.Kill(nil) - err = tomb.Wait() - require.NoError(t, err) + stopHTTPSource(t, h, cancel, streamErr) } func TestStreamingAcquisitionUnknownPath(t *testing.T) { ctx := t.Context() h := &Source{} - _, _, tomb := SetupAndRunHTTPSource(t, h, []byte(` + _, _, cancel, streamErr := SetupAndRunHTTPSource(t, h, []byte(` source: http listen_addr: 127.0.0.1:8080 path: /test @@ -151,16 +159,13 @@ basic_auth: assert.Equal(t, http.StatusNotFound, res.StatusCode) closeBody(t, res) - h.Server.Close() - tomb.Kill(nil) - err = tomb.Wait() - require.NoError(t, err) + stopHTTPSource(t, h, cancel, streamErr) } func TestStreamingAcquisitionBasicAuth(t *testing.T) { ctx := t.Context() h := &Source{} - _, _, tomb := SetupAndRunHTTPSource(t, h, []byte(` + _, _, cancel, streamErr := SetupAndRunHTTPSource(t, h, []byte(` source: http listen_addr: 127.0.0.1:8080 path: /test @@ -191,16 +196,13 @@ basic_auth: assert.Equal(t, http.StatusUnauthorized, resp.StatusCode) closeBody(t, resp) - h.Server.Close() - tomb.Kill(nil) - err = tomb.Wait() - require.NoError(t, err) + stopHTTPSource(t, h, cancel, streamErr) } func TestStreamingAcquisitionBadHeaders(t *testing.T) { ctx := t.Context() h := &Source{} - _, _, tomb := SetupAndRunHTTPSource(t, h, []byte(` + _, _, cancel, streamErr := SetupAndRunHTTPSource(t, h, []byte(` source: http listen_addr: 127.0.0.1:8080 path: /test @@ -221,16 +223,13 @@ headers: assert.Equal(t, http.StatusUnauthorized, resp.StatusCode) closeBody(t, resp) - h.Server.Close() - tomb.Kill(nil) - err = tomb.Wait() - require.NoError(t, err) + stopHTTPSource(t, h, cancel, streamErr) } func TestStreamingAcquisitionMaxBodySize(t *testing.T) { ctx := t.Context() h := &Source{} - _, _, tomb := SetupAndRunHTTPSource(t, h, []byte(` + _, _, cancel, streamErr := SetupAndRunHTTPSource(t, h, []byte(` source: http listen_addr: 127.0.0.1:8080 path: /test @@ -252,16 +251,13 @@ max_body_size: 5`), 0) assert.Equal(t, http.StatusRequestEntityTooLarge, resp.StatusCode) closeBody(t, resp) - h.Server.Close() - tomb.Kill(nil) - err = tomb.Wait() - require.NoError(t, err) + stopHTTPSource(t, h, cancel, streamErr) } func TestStreamingAcquisitionSuccess(t *testing.T) { ctx := t.Context() h := &Source{} - out, reg, tomb := SetupAndRunHTTPSource(t, h, []byte(` + out, reg, cancel, streamErr := SetupAndRunHTTPSource(t, h, []byte(` source: http listen_addr: 127.0.0.1:8080 path: /test @@ -291,16 +287,13 @@ headers: assertMetrics(t, reg, h.GetMetrics(), 1) - h.Server.Close() - tomb.Kill(nil) - err = tomb.Wait() - require.NoError(t, err) + stopHTTPSource(t, h, cancel, streamErr) } func TestStreamingAcquisitionCustomStatusCodeAndCustomHeaders(t *testing.T) { ctx := t.Context() h := &Source{} - out, reg, tomb := SetupAndRunHTTPSource(t, h, []byte(` + out, reg, cancel, streamErr := SetupAndRunHTTPSource(t, h, []byte(` source: http listen_addr: 127.0.0.1:8080 path: /test @@ -334,10 +327,7 @@ custom_headers: assertMetrics(t, reg, h.GetMetrics(), 1) - h.Server.Close() - tomb.Kill(nil) - err = tomb.Wait() - require.NoError(t, err) + stopHTTPSource(t, h, cancel, streamErr) } func TestAcquistionSocket(t *testing.T) { @@ -346,7 +336,7 @@ func TestAcquistionSocket(t *testing.T) { ctx := t.Context() h := &Source{} - out, reg, tomb := SetupAndRunHTTPSource(t, h, []byte(` + out, reg, cancel, streamErr := SetupAndRunHTTPSource(t, h, []byte(` source: http listen_socket: `+socketFile+` path: /test @@ -382,10 +372,7 @@ headers: assertMetrics(t, reg, h.GetMetrics(), 1) - h.Server.Close() - tomb.Kill(nil) - err = tomb.Wait() - require.NoError(t, err) + stopHTTPSource(t, h, cancel, streamErr) } type slowReader struct { @@ -447,7 +434,7 @@ func assertEvents(out chan pipeline.Event, expected []string, errChan chan error func TestStreamingAcquisitionTimeout(t *testing.T) { ctx := t.Context() h := &Source{} - _, _, tomb := SetupAndRunHTTPSource(t, h, []byte(` + _, _, cancel, streamErr := SetupAndRunHTTPSource(t, h, []byte(` source: http listen_addr: 127.0.0.1:8080 path: /test @@ -476,17 +463,14 @@ timeout: 1s`), 0) assert.Equal(t, http.StatusBadRequest, resp.StatusCode) closeBody(t, resp) - h.Server.Close() - tomb.Kill(nil) - err = tomb.Wait() - require.NoError(t, err) + stopHTTPSource(t, h, cancel, streamErr) } func TestStreamingAcquisitionTLSHTTPRequest(t *testing.T) { ctx := t.Context() h := &Source{} - _, _, tomb := SetupAndRunHTTPSource(t, h, []byte(` + _, _, cancel, streamErr := SetupAndRunHTTPSource(t, h, []byte(` source: http listen_addr: 127.0.0.1:8080 auth_type: mtls @@ -508,16 +492,13 @@ tls: assert.Equal(t, http.StatusBadRequest, resp.StatusCode) closeBody(t, resp) - h.Server.Close() - tomb.Kill(nil) - err = tomb.Wait() - require.NoError(t, err) + stopHTTPSource(t, h, cancel, streamErr) } func TestStreamingAcquisitionTLSWithHeadersAuthSuccess(t *testing.T) { ctx := t.Context() h := &Source{} - out, reg, tomb := SetupAndRunHTTPSource(t, h, []byte(` + out, reg, cancel, streamErr := SetupAndRunHTTPSource(t, h, []byte(` source: http listen_addr: 127.0.0.1:8080 path: /test @@ -567,16 +548,13 @@ tls: assertMetrics(t, reg, h.GetMetrics(), 0) - h.Server.Close() - tomb.Kill(nil) - err = tomb.Wait() - require.NoError(t, err) + stopHTTPSource(t, h, cancel, streamErr) } func TestStreamingAcquisitionMTLS(t *testing.T) { ctx := t.Context() h := &Source{} - out, reg, tomb := SetupAndRunHTTPSource(t, h, []byte(` + out, reg, cancel, streamErr := SetupAndRunHTTPSource(t, h, []byte(` source: http listen_addr: 127.0.0.1:8080 path: /test @@ -628,16 +606,13 @@ tls: assertMetrics(t, reg, h.GetMetrics(), 0) - h.Server.Close() - tomb.Kill(nil) - err = tomb.Wait() - require.NoError(t, err) + stopHTTPSource(t, h, cancel, streamErr) } func TestStreamingAcquisitionGzipData(t *testing.T) { ctx := t.Context() h := &Source{} - out, reg, tomb := SetupAndRunHTTPSource(t, h, []byte(` + out, reg, cancel, streamErr := SetupAndRunHTTPSource(t, h, []byte(` source: http listen_addr: 127.0.0.1:8080 path: /test @@ -684,16 +659,13 @@ headers: assertMetrics(t, reg, h.GetMetrics(), 2) - h.Server.Close() - tomb.Kill(nil) - err = tomb.Wait() - require.NoError(t, err) + stopHTTPSource(t, h, cancel, streamErr) } func TestStreamingAcquisitionNDJson(t *testing.T) { ctx := t.Context() h := &Source{} - out, reg, tomb := SetupAndRunHTTPSource(t, h, []byte(` + out, reg, cancel, streamErr := SetupAndRunHTTPSource(t, h, []byte(` source: http listen_addr: 127.0.0.1:8080 path: /test @@ -726,10 +698,7 @@ headers: assertMetrics(t, reg, h.GetMetrics(), 2) - h.Server.Close() - tomb.Kill(nil) - err = tomb.Wait() - require.NoError(t, err) + stopHTTPSource(t, h, cancel, streamErr) } func assertMetrics(t *testing.T, reg *prometheus.Registry, metrics []prometheus.Collector, expected int) { diff --git a/pkg/acquisition/modules/http/init.go b/pkg/acquisition/modules/http/init.go index 07d19e48007..17191d0b25d 100644 --- a/pkg/acquisition/modules/http/init.go +++ b/pkg/acquisition/modules/http/init.go @@ -7,9 +7,9 @@ import ( var ( // verify interface compliance - _ types.DataSource = (*Source)(nil) - _ types.Tailer = (*Source)(nil) - _ types.MetricsProvider = (*Source)(nil) + _ types.DataSource = (*Source)(nil) + _ types.RestartableStreamer = (*Source)(nil) + _ types.MetricsProvider = (*Source)(nil) ) const ModuleName = "http" diff --git a/pkg/acquisition/modules/http/run.go b/pkg/acquisition/modules/http/run.go index ce548917333..fae8d84f00e 100644 --- a/pkg/acquisition/modules/http/run.go +++ b/pkg/acquisition/modules/http/run.go @@ -15,7 +15,6 @@ import ( "github.com/prometheus/client_golang/prometheus" log "github.com/sirupsen/logrus" - "gopkg.in/tomb.v2" "github.com/crowdsecurity/go-cs-lib/trace" @@ -130,7 +129,7 @@ func (s *Source) processRequest(w http.ResponseWriter, r *http.Request, hc *Conf return nil } -func (s *Source) RunServer(ctx context.Context, out chan pipeline.Event, t *tomb.Tomb) error { +func (s *Source) Stream(ctx context.Context, out chan pipeline.Event) error { mux := http.NewServeMux() mux.HandleFunc(s.Config.Path, func(w http.ResponseWriter, r *http.Request) { if err := authorizeRequest(r, &s.Config); err != nil { @@ -210,12 +209,13 @@ func (s *Source) RunServer(ctx context.Context, out chan pipeline.Event, t *tomb } listenConfig := &net.ListenConfig{} + serverErr := make(chan error, 2) - t.Go(func() error { + go func() { defer trace.ReportPanic() if s.Config.ListenSocket == "" { - return nil + return } s.logger.Infof("creating unix socket on %s", s.Config.ListenSocket) @@ -223,29 +223,38 @@ func (s *Source) RunServer(ctx context.Context, out chan pipeline.Event, t *tomb listener, err := listenConfig.Listen(ctx, "unix", s.Config.ListenSocket) if err != nil { - return csnet.WrapSockErr(err, s.Config.ListenSocket) + select { + case serverErr <- csnet.WrapSockErr(err, s.Config.ListenSocket): + default: + } + + return } if s.Config.TLS != nil { err := s.Server.ServeTLS(listener, s.Config.TLS.ServerCert, s.Config.TLS.ServerKey) if err != nil && err != http.ErrServerClosed { - return fmt.Errorf("https server failed: %w", err) + select { + case serverErr <- fmt.Errorf("https server failed: %w", err): + default: + } } } else { err := s.Server.Serve(listener) if err != nil && err != http.ErrServerClosed { - return fmt.Errorf("http server failed: %w", err) + select { + case serverErr <- fmt.Errorf("http server failed: %w", err): + default: + } } } + }() - return nil - }) - - t.Go(func() error { + go func() { defer trace.ReportPanic() if s.Config.ListenAddr == "" { - return nil + return } if s.Config.TLS != nil { @@ -253,38 +262,33 @@ func (s *Source) RunServer(ctx context.Context, out chan pipeline.Event, t *tomb err := s.Server.ListenAndServeTLS(s.Config.TLS.ServerCert, s.Config.TLS.ServerKey) if err != nil && err != http.ErrServerClosed { - return fmt.Errorf("https server failed: %w", err) + select { + case serverErr <- fmt.Errorf("https server failed: %w", err): + default: + } } } else { s.logger.Infof("start http server on %s", s.Config.ListenAddr) err := s.Server.ListenAndServe() if err != nil && err != http.ErrServerClosed { - return fmt.Errorf("http server failed: %w", err) + select { + case serverErr <- fmt.Errorf("http server failed: %w", err): + default: + } } } + }() - return nil - }) - - <-t.Dying() - - s.logger.Infof("%s datasource stopping", s.GetName()) + select { + case <-ctx.Done(): + s.logger.Infof("%s datasource stopping", s.GetName()) + if err := s.Server.Close(); err != nil && !errors.Is(err, http.ErrServerClosed) { + return fmt.Errorf("while closing %s server: %w", s.GetName(), err) + } - if err := s.Server.Close(); err != nil { - return fmt.Errorf("while closing %s server: %w", s.GetName(), err) + return nil + case err := <-serverErr: + return err } - - return nil -} - -func (s *Source) StreamingAcquisition(ctx context.Context, out chan pipeline.Event, t *tomb.Tomb) error { - s.logger.Debugf("start http server on %s", s.Config.ListenAddr) - - t.Go(func() error { - defer trace.ReportPanic() - return s.RunServer(ctx, out, t) - }) - - return nil } From d93ab4b2a420b3d0e829132861b71e5c7042f505 Mon Sep 17 00:00:00 2001 From: marco Date: Fri, 27 Feb 2026 17:24:48 +0100 Subject: [PATCH 10/23] acquisition: migrate s3 lifecycle to context Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- pkg/acquisition/modules/s3/init.go | 10 +-- pkg/acquisition/modules/s3/run.go | 101 +++++++++++++------------- pkg/acquisition/modules/s3/s3_test.go | 43 ++++++----- pkg/acquisition/modules/s3/source.go | 8 +- 4 files changed, 77 insertions(+), 85 deletions(-) diff --git a/pkg/acquisition/modules/s3/init.go b/pkg/acquisition/modules/s3/init.go index 2e022494981..c38f0ad7f89 100644 --- a/pkg/acquisition/modules/s3/init.go +++ b/pkg/acquisition/modules/s3/init.go @@ -7,11 +7,11 @@ import ( var ( // verify interface compliance - _ types.DataSource = (*Source)(nil) - _ types.DSNConfigurer = (*Source)(nil) - _ types.Fetcher = (*Source)(nil) - _ types.Tailer = (*Source)(nil) - _ types.MetricsProvider = (*Source)(nil) + _ types.DataSource = (*Source)(nil) + _ types.DSNConfigurer = (*Source)(nil) + _ types.BatchFetcher = (*Source)(nil) + _ types.RestartableStreamer = (*Source)(nil) + _ types.MetricsProvider = (*Source)(nil) ) const ModuleName = "s3" diff --git a/pkg/acquisition/modules/s3/run.go b/pkg/acquisition/modules/s3/run.go index 6b2d51b9a5a..9ac06cc8e98 100644 --- a/pkg/acquisition/modules/s3/run.go +++ b/pkg/acquisition/modules/s3/run.go @@ -20,7 +20,7 @@ import ( "github.com/aws/aws-sdk-go-v2/service/sqs" "github.com/prometheus/client_golang/prometheus" log "github.com/sirupsen/logrus" - "gopkg.in/tomb.v2" + "golang.org/x/sync/errgroup" "github.com/crowdsecurity/crowdsec/pkg/metrics" "github.com/crowdsecurity/crowdsec/pkg/pipeline" @@ -67,26 +67,28 @@ const ( SQSFormatSNS = "sns" ) -func (s *Source) readManager() { +func (s *Source) readManager(ctx context.Context, out chan pipeline.Event, readerChan <-chan S3Object) { logger := s.logger.WithField("method", "readManager") for { select { - case <-s.t.Dying(): + case <-ctx.Done(): logger.Infof("Shutting down S3 read manager") - s.cancel() return - case s3Object := <-s.readerChan: + case s3Object, ok := <-readerChan: + if !ok { + return + } logger.Debugf("Reading file %s/%s", s3Object.Bucket, s3Object.Key) - if err := s.readFile(s3Object.Bucket, s3Object.Key); err != nil { + if err := s.readFile(ctx, out, s3Object.Bucket, s3Object.Key); err != nil { logger.Errorf("Error while reading file: %s", err) } } } } -func (s *Source) getBucketContent() ([]s3types.Object, error) { +func (s *Source) getBucketContent(ctx context.Context) ([]s3types.Object, error) { logger := s.logger.WithField("method", "getBucketContent") logger.Debugf("Getting bucket content") @@ -95,7 +97,7 @@ func (s *Source) getBucketContent() ([]s3types.Object, error) { var continuationToken *string for { - out, err := s.s3Client.ListObjectsV2(s.ctx, &s3.ListObjectsV2Input{ + out, err := s.s3Client.ListObjectsV2(ctx, &s3.ListObjectsV2Input{ Bucket: aws.String(s.Config.BucketName), Prefix: aws.String(s.Config.Prefix), ContinuationToken: continuationToken, @@ -120,7 +122,7 @@ func (s *Source) getBucketContent() ([]s3types.Object, error) { return bucketObjects, nil } -func (s *Source) listPoll() error { +func (s *Source) listPoll(ctx context.Context, readerChan chan<- S3Object) error { logger := s.logger.WithField("method", "listPoll") ticker := time.NewTicker(time.Duration(s.Config.PollingInterval) * time.Second) lastObjectDate := time.Now() @@ -129,14 +131,13 @@ func (s *Source) listPoll() error { for { select { - case <-s.t.Dying(): + case <-ctx.Done(): logger.Infof("Shutting down list poller") - s.cancel() return nil case <-ticker.C: newObject := false - bucketObjects, err := s.getBucketContent() + bucketObjects, err := s.getBucketContent(ctx) if err != nil { logger.Errorf("Error while getting bucket content: %s", err) continue @@ -161,9 +162,9 @@ func (s *Source) listPoll() error { } select { - case s.readerChan <- obj: - case <-s.t.Dying(): - logger.Debug("tomb is dying, dropping object send") + case readerChan <- obj: + case <-ctx.Done(): + logger.Debug("context canceled, dropping object send") return nil } } @@ -261,19 +262,18 @@ func (s *Source) extractBucketAndPrefix(message *string) (string, string, error) } } -func (s *Source) sqsPoll() error { +func (s *Source) sqsPoll(ctx context.Context, readerChan chan<- S3Object) error { logger := s.logger.WithField("method", "sqsPoll") for { select { - case <-s.t.Dying(): + case <-ctx.Done(): logger.Infof("Shutting down SQS poller") - s.cancel() return nil default: logger.Trace("Polling SQS queue") - out, err := s.sqsClient.ReceiveMessage(s.ctx, &sqs.ReceiveMessageInput{ + out, err := s.sqsClient.ReceiveMessage(ctx, &sqs.ReceiveMessageInput{ QueueUrl: aws.String(s.Config.SQSName), MaxNumberOfMessages: 10, WaitTimeSeconds: 20, // Probably no need to make it configurable ? @@ -298,7 +298,7 @@ func (s *Source) sqsPoll() error { if err != nil { logger.Errorf("Error while parsing SQS message: %s", err) // Always delete the message to avoid infinite loop - _, err = s.sqsClient.DeleteMessage(s.ctx, + _, err = s.sqsClient.DeleteMessage(ctx, &sqs.DeleteMessageInput{ QueueUrl: aws.String(s.Config.SQSName), ReceiptHandle: message.ReceiptHandle, @@ -314,13 +314,13 @@ func (s *Source) sqsPoll() error { // don't block if readManager has quit select { - case s.readerChan <- S3Object{Key: key, Bucket: bucket}: - case <-s.t.Dying(): - logger.Debug("tomb is dying, dropping object send") + case readerChan <- S3Object{Key: key, Bucket: bucket}: + case <-ctx.Done(): + logger.Debug("context canceled, dropping object send") return nil } - _, err = s.sqsClient.DeleteMessage(s.ctx, + _, err = s.sqsClient.DeleteMessage(ctx, &sqs.DeleteMessageInput{ QueueUrl: aws.String(s.Config.SQSName), ReceiptHandle: message.ReceiptHandle, @@ -335,7 +335,7 @@ func (s *Source) sqsPoll() error { } } -func (s *Source) readFile(bucket string, key string) error { +func (s *Source) readFile(ctx context.Context, out chan pipeline.Event, bucket string, key string) error { // TODO: Handle SSE-C var scanner *bufio.Scanner @@ -345,7 +345,7 @@ func (s *Source) readFile(bucket string, key string) error { "key": key, }) - output, err := s.s3Client.GetObject(s.ctx, &s3.GetObjectInput{ + output, err := s.s3Client.GetObject(ctx, &s3.GetObjectInput{ Bucket: aws.String(bucket), Key: aws.String(key), }) @@ -385,7 +385,7 @@ func (s *Source) readFile(bucket string, key string) error { for scanner.Scan() { select { - case <-s.t.Dying(): + case <-ctx.Done(): s.logger.Infof("Shutting down reader for %s/%s", bucket, key) return nil default: @@ -415,9 +415,9 @@ func (s *Source) readFile(bucket string, key string) error { // don't block in shutdown select { - case s.out <-evt: - case <-s.t.Dying(): - s.logger.Infof("tomb is dying, dropping event for %s/%s", bucket, key) + case out <- evt: + case <-ctx.Done(): + s.logger.Infof("context canceled, dropping event for %s/%s", bucket, key) return nil } } @@ -434,58 +434,57 @@ func (s *Source) readFile(bucket string, key string) error { return nil } -func (s *Source) OneShotAcquisition(ctx context.Context, out chan pipeline.Event, t *tomb.Tomb) error { +func (s *Source) OneShot(ctx context.Context, out chan pipeline.Event) error { s.logger.Infof("starting acquisition of %s/%s/%s", s.Config.BucketName, s.Config.Prefix, s.Config.Key) - s.out = out - s.ctx, s.cancel = context.WithCancel(ctx) s.Config.UseTimeMachine = true - s.t = t if s.Config.Key != "" { - err := s.readFile(s.Config.BucketName, s.Config.Key) + err := s.readFile(ctx, out, s.Config.BucketName, s.Config.Key) if err != nil { return err } } else { // No key, get everything in the bucket based on the prefix - objects, err := s.getBucketContent() + objects, err := s.getBucketContent(ctx) if err != nil { return err } for _, object := range objects { - err := s.readFile(s.Config.BucketName, *object.Key) + err := s.readFile(ctx, out, s.Config.BucketName, *object.Key) if err != nil { return err } } } - t.Kill(nil) - return nil } -func (s *Source) StreamingAcquisition(ctx context.Context, out chan pipeline.Event, t *tomb.Tomb) error { - s.t = t - s.out = out - s.readerChan = make(chan S3Object, 100) // FIXME: does this needs to be buffered? - s.ctx, s.cancel = context.WithCancel(ctx) +func (s *Source) Stream(ctx context.Context, out chan pipeline.Event) error { + readerChan := make(chan S3Object, 100) // FIXME: does this needs to be buffered? s.logger.Infof("starting acquisition of %s/%s", s.Config.BucketName, s.Config.Prefix) - t.Go(func() error { - s.readManager() + + g, gctx := errgroup.WithContext(ctx) + g.Go(func() error { + s.readManager(gctx, out, readerChan) return nil }) if s.Config.PollingMethod == PollMethodSQS { - t.Go(func() error { - return s.sqsPoll() + g.Go(func() error { + return s.sqsPoll(gctx, readerChan) }) } else { - t.Go(func() error { - return s.listPoll() + g.Go(func() error { + return s.listPoll(gctx, readerChan) }) } - return nil + err := g.Wait() + if errors.Is(err, context.Canceled) { + return nil + } + + return err } diff --git a/pkg/acquisition/modules/s3/s3_test.go b/pkg/acquisition/modules/s3/s3_test.go index 8557a8fa37e..e0d3fa463df 100644 --- a/pkg/acquisition/modules/s3/s3_test.go +++ b/pkg/acquisition/modules/s3/s3_test.go @@ -17,7 +17,6 @@ import ( log "github.com/sirupsen/logrus" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - "gopkg.in/tomb.v2" "github.com/crowdsecurity/crowdsec/pkg/metrics" "github.com/crowdsecurity/crowdsec/pkg/pipeline" @@ -187,8 +186,7 @@ func TestDSNAcquis(t *testing.T) { } }() - tmb := tomb.Tomb{} - err = f.OneShotAcquisition(ctx, out, &tmb) + err = f.OneShot(ctx, out) require.NoError(t, err) time.Sleep(2 * time.Second) @@ -246,7 +244,9 @@ prefix: foo/ } out := make(chan pipeline.Event) - tb := tomb.Tomb{} + done := make(chan struct{}) + streamCtx, cancel := context.WithCancel(ctx) + streamErr := make(chan error, 1) go func() { for { @@ -254,21 +254,20 @@ prefix: foo/ case s := <-out: fmt.Printf("got line %s\n", s.Line.Raw) linesRead++ - case <-tb.Dying(): + case <-done: return } } }() - err = f.StreamingAcquisition(ctx, out, &tb) - if err != nil { - t.Fatalf("unexpected error: %s", err.Error()) - } + go func() { + streamErr <- f.Stream(streamCtx, out) + }() time.Sleep(2 * time.Second) - tb.Kill(nil) - err = tb.Wait() - require.NoError(t, err) + cancel() + require.NoError(t, <-streamErr) + close(done) assert.Equal(t, test.expectedCount, linesRead) }) } @@ -342,7 +341,9 @@ sqs_name: test } out := make(chan pipeline.Event) - tb := tomb.Tomb{} + done := make(chan struct{}) + streamCtx, cancel := context.WithCancel(ctx) + streamErr := make(chan error, 1) go func() { for { @@ -350,22 +351,20 @@ sqs_name: test case s := <-out: fmt.Printf("got line %s\n", s.Line.Raw) linesRead++ - case <-tb.Dying(): + case <-done: return } } }() - err = f.StreamingAcquisition(ctx, out, &tb) - if err != nil { - t.Fatalf("unexpected error: %s", err.Error()) - } + go func() { + streamErr <- f.Stream(streamCtx, out) + }() time.Sleep(2 * time.Second) - tb.Kill(nil) - - err = tb.Wait() - require.NoError(t, err) + cancel() + require.NoError(t, <-streamErr) + close(done) assert.Equal(t, test.expectedCount, linesRead) }) } diff --git a/pkg/acquisition/modules/s3/source.go b/pkg/acquisition/modules/s3/source.go index a1784f5d68d..851adc82409 100644 --- a/pkg/acquisition/modules/s3/source.go +++ b/pkg/acquisition/modules/s3/source.go @@ -2,13 +2,12 @@ package s3acquisition import ( "context" + s3Manager "github.com/aws/aws-sdk-go-v2/feature/s3/manager" "github.com/aws/aws-sdk-go-v2/service/sqs" log "github.com/sirupsen/logrus" - "gopkg.in/tomb.v2" "github.com/crowdsecurity/crowdsec/pkg/metrics" - "github.com/crowdsecurity/crowdsec/pkg/pipeline" ) type S3API interface { @@ -27,11 +26,6 @@ type Source struct { logger *log.Entry s3Client S3API sqsClient SQSAPI - readerChan chan S3Object - t *tomb.Tomb - out chan pipeline.Event - ctx context.Context - cancel context.CancelFunc } type S3Object struct { From e16beb6969a339ef1c2f9704e9b66d167019daa2 Mon Sep 17 00:00:00 2001 From: marco Date: Fri, 27 Feb 2026 19:44:02 +0100 Subject: [PATCH 11/23] cmd/crowdsec: migrate lifecycle to context Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- cmd/crowdsec/api.go | 73 +++++++++++++++++++++++++++++----------- cmd/crowdsec/crowdsec.go | 39 +++++++++++---------- cmd/crowdsec/main.go | 13 +++---- cmd/crowdsec/serve.go | 68 +++++++++++++++++++------------------ 4 files changed, 115 insertions(+), 78 deletions(-) diff --git a/cmd/crowdsec/api.go b/cmd/crowdsec/api.go index 47e15af49cd..295f5dda934 100644 --- a/cmd/crowdsec/api.go +++ b/cmd/crowdsec/api.go @@ -3,6 +3,7 @@ package main import ( "context" "fmt" + "net" "time" log "github.com/sirupsen/logrus" @@ -20,6 +21,14 @@ func initAPIServer(ctx context.Context, cConfig *csconfig.Config) (*apiserver.AP log.Info("push and pull to Central API disabled") } + if cConfig.API.Server.ListenURI != "" { + listener, err := net.Listen("tcp", cConfig.API.Server.ListenURI) + if err != nil { + return nil, fmt.Errorf("local API server stopped with error: listening on %s: %w", cConfig.API.Server.ListenURI, err) + } + _ = listener.Close() + } + accessLogger := cConfig.API.Server.NewAccessLogger(cConfig.Common.LogConfig, accessLogFilename) apiServer, err := apiserver.NewServer(ctx, cConfig.API.Server, accessLogger) @@ -42,38 +51,62 @@ func initAPIServer(ctx context.Context, cConfig *csconfig.Config) (*apiserver.AP func serveAPIServer(ctx context.Context, apiServer *apiserver.APIServer) { apiReady := make(chan bool, 1) + runCtx, cancel := context.WithCancel(ctx) + apiCancel = cancel + apiDone = make(chan error, 1) - apiTomb.Go(func() error { + go func() { defer trace.ReportPanic() + pluginCtx, cancelPlugin := context.WithCancel(runCtx) + pluginDone := make(chan struct{}) + + go func() { + defer trace.ReportPanic() + defer close(pluginDone) + pluginBroker.Run(pluginCtx) + }() + + runErr := make(chan error, 1) + go func() { defer trace.ReportPanic() log.Debugf("serving API after %s ms", time.Since(crowdsecT0)) - if err := apiServer.Run(ctx, apiReady); err != nil { - log.Fatal(err) - } + runErr <- apiServer.Run(runCtx, apiReady) }() - pluginCtx, cancelPlugin := context.WithCancel(ctx) - - pluginTomb.Go(func() error { - <-pluginTomb.Dying() + select { + case err := <-runErr: + if err != nil && runCtx.Err() == nil { + log.Fatal(err) + } cancelPlugin() - return nil - }) + <-pluginDone + apiDone <- err + case <-runCtx.Done(): + log.Infof("serve: shutting down api server") - pluginTomb.Go(func() error { - pluginBroker.Run(pluginCtx) - return nil - }) + shutdownCtx, cancelShutdown := context.WithTimeout(context.Background(), 5*time.Second) + err := apiServer.Shutdown(shutdownCtx) + cancelShutdown() - <-apiTomb.Dying() // lock until go routine is dying - pluginTomb.Kill(nil) - log.Infof("serve: shutting down api server") + if runErrValue := <-runErr; runErrValue != nil && err == nil { + err = runErrValue + } - return apiServer.Shutdown(ctx) - }) - <-apiReady + cancelPlugin() + <-pluginDone + apiDone <- err + } + }() + + select { + case <-apiReady: + case err := <-apiDone: + if err != nil { + log.Fatal(err) + } + } } diff --git a/cmd/crowdsec/crowdsec.go b/cmd/crowdsec/crowdsec.go index 968d8dd556c..52ca463a26a 100644 --- a/cmd/crowdsec/crowdsec.go +++ b/cmd/crowdsec/crowdsec.go @@ -185,38 +185,37 @@ func serveCrowdsec( sd *StateDumper, ) { cctx, cancel := context.WithCancel(ctx) + crowdCancel = cancel + crowdDone = make(chan error, 1) var g errgroup.Group bucketStore := leakybucket.NewBucketStore() - crowdsecTomb.Go(func() error { + go func() { defer trace.ReportPanic() + // this logs every time, even at config reload + log.Debugf("running agent after %s ms", time.Since(crowdsecT0)) + agentReady <- true - go func() { - defer trace.ReportPanic() - // this logs every time, even at config reload - log.Debugf("running agent after %s ms", time.Since(crowdsecT0)) - - agentReady <- true - - if err := runCrowdsec(cctx, &g, cConfig, parsers, hub, datasources, sd, bucketStore); err != nil { - log.Fatalf("unable to start crowdsec routines: %s", err) - } - }() + if err := runCrowdsec(cctx, &g, cConfig, parsers, hub, datasources, sd, bucketStore); err != nil { + crowdDone <- fmt.Errorf("unable to start crowdsec routines: %w", err) + return + } /* we should stop in two cases : - - crowdsecTomb has been Killed() : it might be shutdown or reload, so stop + - context has been canceled: it might be shutdown or reload, so stop - acquisTomb is dead, it means that we were in "cat" mode and files are done reading, quit */ - waitOnTomb() + waitOnCrowdsecStop(cctx) log.Debugf("Shutting down crowdsec routines") if err := ShutdownCrowdsecRoutines(cancel, &g, datasources); err != nil { - return fmt.Errorf("unable to shutdown crowdsec routines: %w", err) + crowdDone <- fmt.Errorf("unable to shutdown crowdsec routines: %w", err) + return } - log.Debugf("everything is dead, return crowdsecTomb") + log.Debugf("everything is dead, return crowdsec routine") log.Debugf("sd.DumpDir == %s", sd.DumpDir) if sd.DumpDir != "" { @@ -229,11 +228,11 @@ func serveCrowdsec( os.Exit(0) } - return nil - }) + crowdDone <- nil + }() } -func waitOnTomb() { +func waitOnCrowdsecStop(ctx context.Context) { for { select { case <-acquisTomb.Dead(): @@ -256,7 +255,7 @@ func waitOnTomb() { return - case <-crowdsecTomb.Dying(): + case <-ctx.Done(): log.Infof("Crowdsec engine shutting down") return } diff --git a/cmd/crowdsec/main.go b/cmd/crowdsec/main.go index 9480e6dd08e..9b47c30335f 100644 --- a/cmd/crowdsec/main.go +++ b/cmd/crowdsec/main.go @@ -30,12 +30,13 @@ import ( ) var ( - // tombs for the parser, buckets and outputs. - acquisTomb tomb.Tomb - outputsTomb tomb.Tomb - apiTomb tomb.Tomb - crowdsecTomb tomb.Tomb - pluginTomb tomb.Tomb + // tombs for acquisition and output routines. + acquisTomb tomb.Tomb + outputsTomb tomb.Tomb + apiCancel context.CancelFunc + apiDone chan error + crowdCancel context.CancelFunc + crowdDone chan error flags Flags diff --git a/cmd/crowdsec/serve.go b/cmd/crowdsec/serve.go index e916daed4dd..f52b4033ef7 100644 --- a/cmd/crowdsec/serve.go +++ b/cmd/crowdsec/serve.go @@ -27,12 +27,13 @@ import ( ) func reloadHandler(ctx context.Context, _ os.Signal) (*csconfig.Config, error) { - // re-initialize tombs + // re-initialize routines state acquisTomb = tomb.Tomb{} outputsTomb = tomb.Tomb{} - apiTomb = tomb.Tomb{} - crowdsecTomb = tomb.Tomb{} - pluginTomb = tomb.Tomb{} + apiCancel = nil + apiDone = nil + crowdCancel = nil + crowdDone = nil sd := NewStateDumper(flags.DumpDir) @@ -151,9 +152,6 @@ func ShutdownCrowdsecRoutines(cancel context.CancelFunc, g *errgroup.Group, data log.Debugf("buckets are done") log.Debugf("metrics are done") - // He's dead, Jim. - crowdsecTomb.Kill(nil) - // close the potential geoips reader we have to avoid leaking ressources on reload exprhelpers.GeoIPClose() @@ -161,11 +159,15 @@ func ShutdownCrowdsecRoutines(cancel context.CancelFunc, g *errgroup.Group, data } func shutdownAPI() error { - log.Debugf("shutting down api via Tomb") - apiTomb.Kill(nil) + log.Debugf("shutting down api via context") + if apiCancel != nil { + apiCancel() + } - if err := apiTomb.Wait(); err != nil { - return err + if apiDone != nil { + if err := <-apiDone; err != nil { + return err + } } log.Debugf("done") @@ -174,11 +176,15 @@ func shutdownAPI() error { } func shutdownCrowdsec() error { - log.Debugf("shutting down crowdsec via Tomb") - crowdsecTomb.Kill(nil) + log.Debugf("shutting down crowdsec via context") + if crowdCancel != nil { + crowdCancel() + } - if err := crowdsecTomb.Wait(); err != nil { - return err + if crowdDone != nil { + if err := <-crowdDone; err != nil { + return err + } } log.Debugf("done") @@ -312,9 +318,10 @@ func Serve( ) error { acquisTomb = tomb.Tomb{} outputsTomb = tomb.Tomb{} - apiTomb = tomb.Tomb{} - crowdsecTomb = tomb.Tomb{} - pluginTomb = tomb.Tomb{} + apiCancel = nil + apiDone = nil + crowdCancel = nil + crowdDone = nil if cConfig.API.Server != nil && cConfig.API.Server.DbConfig != nil { dbCfg := cConfig.API.Server.DbConfig @@ -408,25 +415,22 @@ func Serve( return HandleSignals(ctx, cConfig) } - waitChans := make([]<-chan struct{}, 0) - if !cConfig.DisableAgent { - waitChans = append(waitChans, crowdsecTomb.Dead()) + if crowdDone != nil { + if err := <-crowdDone; err != nil { + return fmt.Errorf("crowdsec shutdown: %w", err) + } + } + log.Infof("crowdsec shutdown") } if !cConfig.DisableAPI { - waitChans = append(waitChans, apiTomb.Dead()) - } - - for _, ch := range waitChans { - <-ch - - switch ch { - case apiTomb.Dead(): - log.Infof("api shutdown") - case crowdsecTomb.Dead(): - log.Infof("crowdsec shutdown") + if apiDone != nil { + if err := <-apiDone; err != nil { + return fmt.Errorf("api shutdown: %w", err) + } } + log.Infof("api shutdown") } return nil From 67d59b905334a2bc6841533a61cdc74af8863e46 Mon Sep 17 00:00:00 2001 From: marco Date: Fri, 27 Feb 2026 20:15:58 +0100 Subject: [PATCH 12/23] cmd/crowdsec: satisfy context lint checks Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- cmd/crowdsec/api.go | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/cmd/crowdsec/api.go b/cmd/crowdsec/api.go index 295f5dda934..c27fb98404e 100644 --- a/cmd/crowdsec/api.go +++ b/cmd/crowdsec/api.go @@ -22,7 +22,8 @@ func initAPIServer(ctx context.Context, cConfig *csconfig.Config) (*apiserver.AP } if cConfig.API.Server.ListenURI != "" { - listener, err := net.Listen("tcp", cConfig.API.Server.ListenURI) + listenConfig := &net.ListenConfig{} + listener, err := listenConfig.Listen(ctx, "tcp", cConfig.API.Server.ListenURI) if err != nil { return nil, fmt.Errorf("local API server stopped with error: listening on %s: %w", cConfig.API.Server.ListenURI, err) } @@ -88,7 +89,7 @@ func serveAPIServer(ctx context.Context, apiServer *apiserver.APIServer) { case <-runCtx.Done(): log.Infof("serve: shutting down api server") - shutdownCtx, cancelShutdown := context.WithTimeout(context.Background(), 5*time.Second) + shutdownCtx, cancelShutdown := context.WithTimeout(context.WithoutCancel(runCtx), 5*time.Second) err := apiServer.Shutdown(shutdownCtx) cancelShutdown() From 0a66b781f23e1415631850a2c06673c9ef96d238 Mon Sep 17 00:00:00 2001 From: marco Date: Fri, 27 Feb 2026 20:48:19 +0100 Subject: [PATCH 13/23] acquisition/http: extract Stream request handler Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- pkg/acquisition/modules/http/run.go | 90 +++++++++++++++-------------- 1 file changed, 47 insertions(+), 43 deletions(-) diff --git a/pkg/acquisition/modules/http/run.go b/pkg/acquisition/modules/http/run.go index fae8d84f00e..1a2cee12ea2 100644 --- a/pkg/acquisition/modules/http/run.go +++ b/pkg/acquisition/modules/http/run.go @@ -129,59 +129,63 @@ func (s *Source) processRequest(w http.ResponseWriter, r *http.Request, hc *Conf return nil } -func (s *Source) Stream(ctx context.Context, out chan pipeline.Event) error { - mux := http.NewServeMux() - mux.HandleFunc(s.Config.Path, func(w http.ResponseWriter, r *http.Request) { - if err := authorizeRequest(r, &s.Config); err != nil { - s.logger.Errorf("failed to authorize request from '%s': %s", r.RemoteAddr, err) - http.Error(w, "Unauthorized", http.StatusUnauthorized) +func (s *Source) handleStreamRequest(w http.ResponseWriter, r *http.Request, out chan pipeline.Event) { + if err := authorizeRequest(r, &s.Config); err != nil { + s.logger.Errorf("failed to authorize request from '%s': %s", r.RemoteAddr, err) + http.Error(w, "Unauthorized", http.StatusUnauthorized) - return - } - - switch r.Method { - case http.MethodGet, http.MethodHead: // Return a 200 if the auth was successful - s.logger.Infof("successful %s request from '%s'", r.Method, r.RemoteAddr) - w.WriteHeader(http.StatusOK) + return + } - if _, err := w.Write([]byte("OK")); err != nil { - s.logger.Errorf("failed to write response: %v", err) - } + switch r.Method { + case http.MethodGet, http.MethodHead: // Return a 200 if the auth was successful + s.logger.Infof("successful %s request from '%s'", r.Method, r.RemoteAddr) + w.WriteHeader(http.StatusOK) - return - case http.MethodPost: // POST is handled below - default: - s.logger.Errorf("method not allowed: %s", r.Method) - http.Error(w, "Method Not Allowed", http.StatusMethodNotAllowed) - return + if _, err := w.Write([]byte("OK")); err != nil { + s.logger.Errorf("failed to write response: %v", err) } - if r.RemoteAddr == "@" { - // We check if request came from unix socket and if so we set to loopback - r.RemoteAddr = "127.0.0.1:65535" - } + return + case http.MethodPost: // POST is handled below + default: + s.logger.Errorf("method not allowed: %s", r.Method) + http.Error(w, "Method Not Allowed", http.StatusMethodNotAllowed) + return + } - err := s.processRequest(w, r, &s.Config, out) - if err != nil { - s.logger.Errorf("failed to process request from '%s': %s", r.RemoteAddr, err) - return - } + if r.RemoteAddr == "@" { + // We check if request came from unix socket and if so we set to loopback + r.RemoteAddr = "127.0.0.1:65535" + } - if s.Config.CustomHeaders != nil { - for key, value := range s.Config.CustomHeaders { - w.Header().Set(key, value) - } - } + err := s.processRequest(w, r, &s.Config, out) + if err != nil { + s.logger.Errorf("failed to process request from '%s': %s", r.RemoteAddr, err) + return + } - if s.Config.CustomStatusCode != nil { - w.WriteHeader(*s.Config.CustomStatusCode) - } else { - w.WriteHeader(http.StatusOK) + if s.Config.CustomHeaders != nil { + for key, value := range s.Config.CustomHeaders { + w.Header().Set(key, value) } + } - if _, err := w.Write([]byte("OK")); err != nil { - s.logger.Errorf("failed to write response: %v", err) - } + if s.Config.CustomStatusCode != nil { + w.WriteHeader(*s.Config.CustomStatusCode) + } else { + w.WriteHeader(http.StatusOK) + } + + if _, err := w.Write([]byte("OK")); err != nil { + s.logger.Errorf("failed to write response: %v", err) + } +} + +func (s *Source) Stream(ctx context.Context, out chan pipeline.Event) error { + mux := http.NewServeMux() + mux.HandleFunc(s.Config.Path, func(w http.ResponseWriter, r *http.Request) { + s.handleStreamRequest(w, r, out) }) s.Server = &http.Server{ From 2eeaff86b65a9bf137bc9c8e4b223685c2dcb884 Mon Sep 17 00:00:00 2001 From: marco Date: Fri, 27 Feb 2026 23:24:05 +0100 Subject: [PATCH 14/23] chore: checkpoint apiserver changes Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- pkg/apiserver/apic.go | 57 +++++++++++++++++++++-------------- pkg/apiserver/apic_metrics.go | 2 -- pkg/apiserver/apic_test.go | 4 +-- pkg/apiserver/apiserver.go | 42 ++++++++++---------------- pkg/apiserver/papi.go | 3 -- 5 files changed, 52 insertions(+), 56 deletions(-) diff --git a/pkg/apiserver/apic.go b/pkg/apiserver/apic.go index 783ab733236..297b2b2b51b 100644 --- a/pkg/apiserver/apic.go +++ b/pkg/apiserver/apic.go @@ -18,7 +18,6 @@ import ( "github.com/davecgh/go-spew/spew" "github.com/go-openapi/strfmt" log "github.com/sirupsen/logrus" - "gopkg.in/tomb.v2" "github.com/crowdsecurity/go-cs-lib/ptr" @@ -61,10 +60,10 @@ type apic struct { AlertsAddChan chan []*models.Alert mu sync.Mutex - pushTomb tomb.Tomb - pullTomb tomb.Tomb metricsCancel context.CancelFunc metricsDone chan struct{} + stopChan chan struct{} + stopOnce sync.Once startup bool consoleConfig *csconfig.ConsoleConfig isPulling chan bool @@ -194,8 +193,7 @@ func NewAPIC(ctx context.Context, config *csconfig.OnlineApiClientCfg, dbClient dbClient: dbClient, mu: sync.Mutex{}, startup: true, - pullTomb: tomb.Tomb{}, - pushTomb: tomb.Tomb{}, + stopChan: make(chan struct{}), consoleConfig: consoleConfig, pullInterval: pullIntervalDefault, pullIntervalFirst: randomDuration(pullIntervalDefault, pullIntervalDelta), @@ -288,23 +286,28 @@ func (a *apic) Push(ctx context.Context) error { var cache models.AddSignalsRequest ticker := time.NewTicker(a.pushIntervalFirst) + defer ticker.Stop() log.Infof("Start push to CrowdSec Central API (interval: %s once, then %s)", a.pushIntervalFirst.Round(time.Second), a.pushInterval) - for { - select { - case <-a.pushTomb.Dying(): // if one apic routine is dying, do we kill the others? - a.pullTomb.Kill(nil) - a.StopMetrics() - log.Infof("push tomb is dying, sending cache (%d elements) before exiting", len(cache)) + flushCache := func() error { + log.Infof("stopping push routine, sending cache (%d elements) before exiting", len(cache)) - if len(cache) == 0 { - return nil - } + if len(cache) == 0 { + return nil + } - go a.Send(ctx, &cache) + go a.Send(context.WithoutCancel(ctx), &cache) - return nil + return nil + } + + for { + select { + case <-ctx.Done(): + return flushCache() + case <-a.stopChan: + return flushCache() case <-ticker.C: ticker.Reset(a.pushInterval) @@ -1038,6 +1041,14 @@ func (a *apic) Pull(ctx context.Context) error { toldOnce := false for { + select { + case <-ctx.Done(): + return nil + case <-a.stopChan: + return nil + default: + } + scenario, err := a.FetchScenariosListFromDB(ctx) if err != nil { log.Errorf("unable to fetch scenarios from db: %s", err) @@ -1062,6 +1073,7 @@ func (a *apic) Pull(ctx context.Context) error { log.Infof("Start pull from CrowdSec Central API (interval: %s once, then %s)", a.pullIntervalFirst.Round(time.Second), a.pullInterval) ticker := time.NewTicker(a.pullIntervalFirst) + defer ticker.Stop() for { select { @@ -1072,10 +1084,9 @@ func (a *apic) Pull(ctx context.Context) error { log.Errorf("capi pull top: %s", err) continue } - case <-a.pullTomb.Dying(): // if one apic routine is dying, do we kill the others? - a.StopMetrics() - a.pushTomb.Kill(nil) - + case <-ctx.Done(): + return nil + case <-a.stopChan: return nil } } @@ -1134,8 +1145,10 @@ func (a *apic) StopMetrics() { } func (a *apic) Shutdown() { - a.pushTomb.Kill(nil) - a.pullTomb.Kill(nil) + a.stopOnce.Do(func() { + close(a.stopChan) + }) + a.StopMetrics() } diff --git a/pkg/apiserver/apic_metrics.go b/pkg/apiserver/apic_metrics.go index 53b3e41be3b..1cae300e2d7 100644 --- a/pkg/apiserver/apic_metrics.go +++ b/pkg/apiserver/apic_metrics.go @@ -299,8 +299,6 @@ func (a *apic) SendMetrics(ctx context.Context, stop chan bool) { case <-stop: return case <-ctx.Done(): - a.pullTomb.Kill(nil) - a.pushTomb.Kill(nil) return case <-checkTicker.C: oldIDs := machineIDs diff --git a/pkg/apiserver/apic_test.go b/pkg/apiserver/apic_test.go index 9b059983ebd..8fe11eca459 100644 --- a/pkg/apiserver/apic_test.go +++ b/pkg/apiserver/apic_test.go @@ -19,7 +19,6 @@ import ( "github.com/sirupsen/logrus" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - "gopkg.in/tomb.v2" "github.com/crowdsecurity/go-cs-lib/cstest" "github.com/crowdsecurity/go-cs-lib/ptr" @@ -60,8 +59,7 @@ func getAPIC(t *testing.T, ctx context.Context) *apic { dbClient: dbClient, mu: sync.Mutex{}, startup: true, - pullTomb: tomb.Tomb{}, - pushTomb: tomb.Tomb{}, + stopChan: make(chan struct{}), consoleConfig: &csconfig.ConsoleConfig{ ShareManualDecisions: ptr.Of(false), ShareTaintedScenarios: ptr.Of(false), diff --git a/pkg/apiserver/apiserver.go b/pkg/apiserver/apiserver.go index f79eeeabbea..1c440d3099b 100644 --- a/pkg/apiserver/apiserver.go +++ b/pkg/apiserver/apiserver.go @@ -17,7 +17,6 @@ import ( "github.com/gin-gonic/gin" "github.com/go-co-op/gocron/v2" log "github.com/sirupsen/logrus" - "gopkg.in/tomb.v2" "github.com/crowdsecurity/go-cs-lib/trace" @@ -41,7 +40,6 @@ type APIServer struct { httpServer *http.Server apic *apic papi *Papi - httpServerTomb tomb.Tomb } func isBrokenConnection(maybeError any) bool { @@ -240,7 +238,6 @@ func NewServer(ctx context.Context, config *csconfig.LocalApiServerCfg, accessLo router: router, apic: apiClient, papi: papiClient, - httpServerTomb: tomb.Tomb{}, }, nil } @@ -276,23 +273,23 @@ func (s *APIServer) papiPull(ctx context.Context) error { } func (s *APIServer) initAPIC(ctx context.Context) { - s.apic.pushTomb.Go(func() error { + go func() { defer trace.ReportPanic() - return s.apicPush(ctx) - }) - s.apic.pullTomb.Go(func() error { + _ = s.apicPush(ctx) + }() + go func() { defer trace.ReportPanic() - return s.apicPull(ctx) - }) + _ = s.apicPull(ctx) + }() if s.apic.apiClient.IsEnrolled() { if s.papi != nil { if s.papi.URL != "" { log.Info("Starting PAPI decision receiver") - s.papi.pullTomb.Go(func() error { + go func() { defer trace.ReportPanic() - return s.papiPull(ctx) - }) + _ = s.papiPull(ctx) + }() s.papi.StartSync(ctx) } else { log.Warnf("papi_url is not set in online_api_credentials.yaml, can't synchronize with the console. Run cscli console enable console_management to add it.") @@ -326,11 +323,7 @@ func (s *APIServer) Run(ctx context.Context, apiReady chan bool) error { s.initAPIC(ctx) } - s.httpServerTomb.Go(func() error { - return s.listenAndServeLAPI(ctx, apiReady) - }) - - if err := s.httpServerTomb.Wait(); err != nil { + if err := s.listenAndServeLAPI(ctx, apiReady); err != nil { return fmt.Errorf("local API server stopped with error: %w", err) } @@ -366,7 +359,10 @@ func (s *APIServer) listenAndServeLAPI(ctx context.Context, apiReady chan bool) switch { case errors.Is(err, http.ErrServerClosed): - break + select { + case serverError <- nil: + default: + } case err != nil: serverError <- err } @@ -415,10 +411,10 @@ func (s *APIServer) listenAndServeLAPI(ctx context.Context, apiReady chan bool) select { case err := <-serverError: return err - case <-s.httpServerTomb.Dying(): + case <-ctx.Done(): log.Info("Shutting down API server") - ctx, cancel := context.WithTimeout(ctx, 5*time.Second) + ctx, cancel := context.WithTimeout(context.WithoutCancel(ctx), 5*time.Second) defer cancel() if err := s.httpServer.Shutdown(ctx); err != nil { @@ -471,12 +467,6 @@ func (s *APIServer) Shutdown(ctx context.Context) error { pipe.Close() } - s.httpServerTomb.Kill(nil) - - if err := s.httpServerTomb.Wait(); err != nil { - return fmt.Errorf("while waiting on httpServerTomb: %w", err) - } - return nil } diff --git a/pkg/apiserver/papi.go b/pkg/apiserver/papi.go index 67d0957af8b..113272c5a25 100644 --- a/pkg/apiserver/papi.go +++ b/pkg/apiserver/papi.go @@ -10,7 +10,6 @@ import ( "time" log "github.com/sirupsen/logrus" - "gopkg.in/tomb.v2" "github.com/crowdsecurity/crowdsec/pkg/apiclient" "github.com/crowdsecurity/crowdsec/pkg/csconfig" @@ -65,7 +64,6 @@ type Papi struct { apiClient *apiclient.ApiClient Channels *OperationChannels mu sync.Mutex - pullTomb tomb.Tomb syncCancel context.CancelFunc syncDone chan struct{} SyncInterval time.Duration @@ -114,7 +112,6 @@ func NewPAPI(apic *apic, dbClient *database.Client, consoleConfig *csconfig.Cons Channels: channels, SyncInterval: SyncInterval, mu: sync.Mutex{}, - pullTomb: tomb.Tomb{}, apiClient: apic.apiClient, apic: apic, consoleConfig: consoleConfig, From 660b51f2be828ec8aee34f02c1608a576e1fb233 Mon Sep 17 00:00:00 2001 From: marco Date: Fri, 27 Feb 2026 23:55:30 +0100 Subject: [PATCH 15/23] fix: stabilize reload shutdown logging Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- cmd/crowdsec/serve.go | 2 ++ 1 file changed, 2 insertions(+) diff --git a/cmd/crowdsec/serve.go b/cmd/crowdsec/serve.go index f52b4033ef7..c54d7a4e50f 100644 --- a/cmd/crowdsec/serve.go +++ b/cmd/crowdsec/serve.go @@ -176,6 +176,8 @@ func shutdownAPI() error { } func shutdownCrowdsec() error { + log.Infof("Crowdsec engine shutting down") + log.Debugf("shutting down crowdsec via context") if crowdCancel != nil { crowdCancel() From 525fd109468b728fdd00b966ff8ab22e55d05cd4 Mon Sep 17 00:00:00 2001 From: marco Date: Sat, 28 Feb 2026 18:34:47 +0100 Subject: [PATCH 16/23] fix(test): resolve appsec test failures from wait-for timeout and daemon teardown hang - wait-for: add force-kill timer (SIGKILL after 2s) when matched process ignores SIGTERM, preventing asyncio.wait from blocking indefinitely - wait-for: preserve match status in TimeoutError handler instead of unconditionally returning 241 - crowdsec-daemon: add 5s bounded wait in stop() with SIGKILL escalation, preventing teardown hang when daemon is stuck in shutdown loop - Add regression test for slow-shutdown scenario in 00_wait_for.bats Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- test/bats/00_wait_for.bats | 9 +++++++++ test/bin/wait-for | 20 +++++++++++++++++--- test/lib/init/crowdsec-daemon | 11 ++++++++++- 3 files changed, 36 insertions(+), 4 deletions(-) diff --git a/test/bats/00_wait_for.bats b/test/bats/00_wait_for.bats index b8530602cce..b44bc492c29 100644 --- a/test/bats/00_wait_for.bats +++ b/test/bats/00_wait_for.bats @@ -67,3 +67,12 @@ setup() { 2 EOT } + +@test "run a command with timeout (match, slow shutdown)" { + # even if the process doesn't stop quickly, a matched pattern must not become timeout=241 + rune -128 wait-for --timeout .2 --out "2" sh -c 'trap "" TERM; echo 1; echo 2; sleep 5' + assert_output - <<-EOT + 1 + 2 + EOT +} diff --git a/test/bin/wait-for b/test/bin/wait-for index 798e2a09b34..7bbcc97a2cc 100755 --- a/test/bin/wait-for +++ b/test/bin/wait-for @@ -15,12 +15,12 @@ DEFAULT_TIMEOUT = 30 # TODO: print unmatched patterns -async def terminate_group(p: asyncio.subprocess.Process): +async def terminate_group(p: asyncio.subprocess.Process, sig: signal.Signals = signal.SIGTERM): """ Terminate the process group (shell, crowdsec plugins) """ try: - os.killpg(os.getpgid(p.pid), signal.SIGTERM) + os.killpg(os.getpgid(p.pid), sig) except ProcessLookupError: pass @@ -60,6 +60,15 @@ async def monitor( out.write(line) if pattern and pattern.search(line): await terminate_group(process) + + # Ensure the process exits even if it ignores SIGTERM + async def _force_kill(): + await asyncio.sleep(2) + if process.returncode is None: + await terminate_group(process, signal.SIGKILL) + + asyncio.create_task(_force_kill()) + # this is nasty. # if we timeout, we want to return a different exit code # in case of a match, so that the caller can tell @@ -97,8 +106,13 @@ async def monitor( if status is None: status = process.returncode except asyncio.TimeoutError: + if status is None: + status = 241 await terminate_group(process) - status = 241 + try: + await asyncio.wait_for(process.wait(), timeout=0.2) + except asyncio.TimeoutError: + await terminate_group(process, signal.SIGKILL) # Return the same exit code, stdout and stderr as the spawned process return status or 0 diff --git a/test/lib/init/crowdsec-daemon b/test/lib/init/crowdsec-daemon index ba8e98992db..803187567ed 100755 --- a/test/lib/init/crowdsec-daemon +++ b/test/lib/init/crowdsec-daemon @@ -53,9 +53,18 @@ stop() { if [[ -n "${PGID}" ]]; then kill -- "-${PGID}" - while pgrep -g "${PGID}" >/dev/null; do + # wait up to 5s for graceful shutdown, then escalate to SIGKILL + for _ in $(seq 100); do + pgrep -g "${PGID}" >/dev/null || break sleep .05 done + + if pgrep -g "${PGID}" >/dev/null 2>&1; then + kill -9 -- "-${PGID}" 2>/dev/null || true + while pgrep -g "${PGID}" >/dev/null 2>&1; do + sleep .05 + done + fi fi rm -f -- "${DAEMON_PID}" From b5888b4a598851b396f57c635e5c1d01ae570673 Mon Sep 17 00:00:00 2001 From: marco Date: Sun, 1 Mar 2026 22:07:50 +0100 Subject: [PATCH 17/23] fix(kafka): return ctx.Err() instead of nil on context cancellation Fix nilerr lint: when ctx.Err() != nil, return the context error instead of silently swallowing it. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- pkg/acquisition/modules/kafka/init.go | 6 ++-- pkg/acquisition/modules/kafka/kafka_test.go | 40 +++++++++++---------- pkg/acquisition/modules/kafka/run.go | 39 +++++++------------- 3 files changed, 38 insertions(+), 47 deletions(-) diff --git a/pkg/acquisition/modules/kafka/init.go b/pkg/acquisition/modules/kafka/init.go index d5f75a0a920..bd29572da27 100644 --- a/pkg/acquisition/modules/kafka/init.go +++ b/pkg/acquisition/modules/kafka/init.go @@ -7,9 +7,9 @@ import ( var ( // verify interface compliance - _ types.DataSource = (*Source)(nil) - _ types.Tailer = (*Source)(nil) - _ types.MetricsProvider = (*Source)(nil) + _ types.DataSource = (*Source)(nil) + _ types.RestartableStreamer = (*Source)(nil) + _ types.MetricsProvider = (*Source)(nil) ) const ModuleName = "kafka" diff --git a/pkg/acquisition/modules/kafka/kafka_test.go b/pkg/acquisition/modules/kafka/kafka_test.go index 7072e95daab..7a6233289fe 100644 --- a/pkg/acquisition/modules/kafka/kafka_test.go +++ b/pkg/acquisition/modules/kafka/kafka_test.go @@ -10,7 +10,6 @@ import ( "github.com/segmentio/kafka-go" log "github.com/sirupsen/logrus" "github.com/stretchr/testify/require" - "gopkg.in/tomb.v2" "github.com/crowdsecurity/go-cs-lib/cstest" @@ -68,8 +67,6 @@ func createTopic(topic string, broker string) { func TestStreamingAcquisition(t *testing.T) { cstest.SetAWSTestEnv(t) - ctx := t.Context() - tests := []struct { name string logs []string @@ -101,6 +98,9 @@ func TestStreamingAcquisition(t *testing.T) { for _, ts := range tests { t.Run(ts.name, func(t *testing.T) { + ctx, cancel := context.WithCancel(t.Context()) + defer cancel() + k := Source{} err := k.Configure(ctx, []byte(` @@ -112,11 +112,12 @@ topic: crowdsecplaintext`), subLogger, metrics.AcquisitionMetricsLevelNone) t.Fatalf("could not configure kafka source : %s", err) } - tomb := tomb.Tomb{} - out := make(chan pipeline.Event) - err = k.StreamingAcquisition(ctx, out, &tomb) - cstest.AssertErrorContains(t, err, ts.expectedErr) + + streamErr := make(chan error, 1) + go func() { + streamErr <- k.Stream(ctx, out) + }() actualLines := 0 @@ -132,9 +133,9 @@ topic: crowdsecplaintext`), subLogger, metrics.AcquisitionMetricsLevelNone) } require.Equal(t, ts.expectedLines, actualLines) - tomb.Kill(nil) - err = tomb.Wait() - require.NoError(t, err) + cancel() + err = <-streamErr + cstest.AssertErrorContains(t, err, ts.expectedErr) }) } } @@ -142,8 +143,6 @@ topic: crowdsecplaintext`), subLogger, metrics.AcquisitionMetricsLevelNone) func TestStreamingAcquisitionWithSSL(t *testing.T) { cstest.SetAWSTestEnv(t) - ctx := t.Context() - tests := []struct { name string logs []string @@ -174,6 +173,9 @@ func TestStreamingAcquisitionWithSSL(t *testing.T) { for _, ts := range tests { t.Run(ts.name, func(t *testing.T) { + ctx, cancel := context.WithCancel(t.Context()) + defer cancel() + k := Source{} err := k.Configure(ctx, []byte(` @@ -191,10 +193,12 @@ tls: t.Fatalf("could not configure kafka source : %s", err) } - tomb := tomb.Tomb{} out := make(chan pipeline.Event) - err = k.StreamingAcquisition(ctx, out, &tomb) - cstest.AssertErrorContains(t, err, ts.expectedErr) + + streamErr := make(chan error, 1) + go func() { + streamErr <- k.Stream(ctx, out) + }() actualLines := 0 @@ -210,9 +214,9 @@ tls: } require.Equal(t, ts.expectedLines, actualLines) - tomb.Kill(nil) - err = tomb.Wait() - require.NoError(t, err) + cancel() + err = <-streamErr + cstest.AssertErrorContains(t, err, ts.expectedErr) }) } } diff --git a/pkg/acquisition/modules/kafka/run.go b/pkg/acquisition/modules/kafka/run.go index 8b69605016f..6b5696e8dec 100644 --- a/pkg/acquisition/modules/kafka/run.go +++ b/pkg/acquisition/modules/kafka/run.go @@ -8,9 +8,6 @@ import ( "github.com/prometheus/client_golang/prometheus" "github.com/segmentio/kafka-go" - "gopkg.in/tomb.v2" - - "github.com/crowdsecurity/go-cs-lib/trace" "github.com/crowdsecurity/crowdsec/pkg/metrics" "github.com/crowdsecurity/crowdsec/pkg/pipeline" @@ -33,6 +30,10 @@ func (s *Source) ReadMessage(ctx context.Context, out chan pipeline.Event) error return nil } + if ctx.Err() != nil { + return ctx.Err() + } + s.logger.Errorln(fmt.Errorf("while reading %s message: %w", s.GetName(), err)) continue @@ -60,30 +61,16 @@ func (s *Source) ReadMessage(ctx context.Context, out chan pipeline.Event) error } } -func (s *Source) RunReader(ctx context.Context, out chan pipeline.Event, t *tomb.Tomb) error { - s.logger.Debugf("starting %s datasource reader goroutine with configuration %+v", s.GetName(), s.Config) - t.Go(func() error { - return s.ReadMessage(ctx, out) - }) - - <-t.Dying() - - s.logger.Infof("%s datasource topic %s stopping", s.GetName(), s.Config.Topic) - - if err := s.Reader.Close(); err != nil { - return fmt.Errorf("while closing %s reader on topic '%s': %w", s.GetName(), s.Config.Topic, err) - } - - return nil -} - -func (s *Source) StreamingAcquisition(ctx context.Context, out chan pipeline.Event, t *tomb.Tomb) error { +func (s *Source) Stream(ctx context.Context, out chan pipeline.Event) error { s.logger.Infof("start reader on brokers '%+v' with topic '%s'", s.Config.Brokers, s.Config.Topic) - t.Go(func() error { - defer trace.ReportPanic() - return s.RunReader(ctx, out, t) - }) + defer func() { + s.logger.Infof("%s datasource topic %s stopping", s.GetName(), s.Config.Topic) + + if err := s.Reader.Close(); err != nil { + s.logger.Errorf("while closing %s reader on topic '%s': %s", s.GetName(), s.Config.Topic, err) + } + }() - return nil + return s.ReadMessage(ctx, out) } From e4d3401d0332a58c7c306325467ff17e2ed70365 Mon Sep 17 00:00:00 2001 From: marco Date: Sun, 1 Mar 2026 22:08:01 +0100 Subject: [PATCH 18/23] fix(kubernetesaudit): suppress contextcheck for Shutdown Add nolint:contextcheck for context.Background() in server.Shutdown(). Using a fresh context is correct here since the parent ctx is already cancelled. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- .../modules/kubernetesaudit/init.go | 6 +-- .../modules/kubernetesaudit/k8s_audit_test.go | 45 +++++++++++-------- .../modules/kubernetesaudit/run.go | 35 +++++++-------- 3 files changed, 46 insertions(+), 40 deletions(-) diff --git a/pkg/acquisition/modules/kubernetesaudit/init.go b/pkg/acquisition/modules/kubernetesaudit/init.go index cf2a149c391..bae7fc6ae8c 100644 --- a/pkg/acquisition/modules/kubernetesaudit/init.go +++ b/pkg/acquisition/modules/kubernetesaudit/init.go @@ -7,9 +7,9 @@ import ( var ( // verify interface compliance - _ types.DataSource = (*Source)(nil) - _ types.Tailer = (*Source)(nil) - _ types.MetricsProvider = (*Source)(nil) + _ types.DataSource = (*Source)(nil) + _ types.RestartableStreamer = (*Source)(nil) + _ types.MetricsProvider = (*Source)(nil) ) const ModuleName = "k8s-audit" diff --git a/pkg/acquisition/modules/kubernetesaudit/k8s_audit_test.go b/pkg/acquisition/modules/kubernetesaudit/k8s_audit_test.go index df5b2be035c..eb581e7b91f 100644 --- a/pkg/acquisition/modules/kubernetesaudit/k8s_audit_test.go +++ b/pkg/acquisition/modules/kubernetesaudit/k8s_audit_test.go @@ -1,6 +1,7 @@ package kubernetesauditacquisition import ( + "context" "fmt" "net/http/httptest" "strings" @@ -10,7 +11,6 @@ import ( log "github.com/sirupsen/logrus" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - "gopkg.in/tomb.v2" "github.com/crowdsecurity/go-cs-lib/cstest" @@ -19,7 +19,6 @@ import ( ) func TestInvalidConfig(t *testing.T) { - ctx := t.Context() tests := []struct { name string config string @@ -39,8 +38,10 @@ webhook_path: /k8s-audit`, for _, test := range tests { t.Run(test.name, func(t *testing.T) { + ctx, cancel := context.WithCancel(t.Context()) + defer cancel() + out := make(chan pipeline.Event) - tb := &tomb.Tomb{} f := Source{} @@ -51,19 +52,21 @@ webhook_path: /k8s-audit`, err = f.Configure(ctx, []byte(test.config), subLogger, metrics.AcquisitionMetricsLevelNone) require.NoError(t, err) - err = f.StreamingAcquisition(ctx, out, tb) - require.NoError(t, err) + + streamErr := make(chan error, 1) + go func() { + streamErr <- f.Stream(ctx, out) + }() time.Sleep(1 * time.Second) - tb.Kill(nil) - err = tb.Wait() + cancel() + err = <-streamErr cstest.RequireErrorContains(t, err, test.expectedErr) }) } } func TestHandler(t *testing.T) { - ctx := t.Context() tests := []struct { name string expectedStatusCode int @@ -185,20 +188,24 @@ func TestHandler(t *testing.T) { for idx, test := range tests { t.Run(test.name, func(t *testing.T) { + ctx, cancel := context.WithCancel(t.Context()) + defer cancel() + out := make(chan pipeline.Event) - tb := &tomb.Tomb{} eventCount := 0 - tb.Go(func() error { + doneCounting := make(chan struct{}) + go func() { + defer close(doneCounting) for { select { case <-out: eventCount++ - case <-tb.Dying(): - return nil + case <-ctx.Done(): + return } } - }) + }() f := Source{} @@ -217,19 +224,21 @@ webhook_path: /k8s-audit`, port) req := httptest.NewRequest(test.method, "/k8s-audit", strings.NewReader(test.body)) w := httptest.NewRecorder() - err = f.StreamingAcquisition(ctx, out, tb) - require.NoError(t, err) + streamErr := make(chan error, 1) + go func() { + streamErr <- f.Stream(ctx, out) + }() f.webhookHandler(w, req) res := w.Result() assert.Equal(t, test.expectedStatusCode, res.StatusCode) - // time.Sleep(1 * time.Second) require.NoError(t, err) - tb.Kill(nil) - err = tb.Wait() + cancel() + <-doneCounting + err = <-streamErr require.NoError(t, err) assert.Equal(t, test.eventCount, eventCount) diff --git a/pkg/acquisition/modules/kubernetesaudit/run.go b/pkg/acquisition/modules/kubernetesaudit/run.go index e53b87f6fff..e711b3bf152 100644 --- a/pkg/acquisition/modules/kubernetesaudit/run.go +++ b/pkg/acquisition/modules/kubernetesaudit/run.go @@ -3,48 +3,45 @@ package kubernetesauditacquisition import ( "context" "encoding/json" + "errors" "fmt" "io" "net/http" "strings" "github.com/prometheus/client_golang/prometheus" - "gopkg.in/tomb.v2" "k8s.io/apiserver/pkg/apis/audit" - "github.com/crowdsecurity/go-cs-lib/trace" - "github.com/crowdsecurity/crowdsec/pkg/metrics" "github.com/crowdsecurity/crowdsec/pkg/pipeline" ) -func (s *Source) StreamingAcquisition(ctx context.Context, out chan pipeline.Event, t *tomb.Tomb) error { +func (s *Source) Stream(ctx context.Context, out chan pipeline.Event) error { s.outChan = out - t.Go(func() error { - defer trace.ReportPanic() + s.logger.Infof("Starting k8s-audit server on %s:%d%s", s.config.ListenAddr, s.config.ListenPort, s.config.WebhookPath) - s.logger.Infof("Starting k8s-audit server on %s:%d%s", s.config.ListenAddr, s.config.ListenPort, s.config.WebhookPath) + serverErr := make(chan error, 1) - t.Go(func() error { - err := s.server.ListenAndServe() - if err != nil && err != http.ErrServerClosed { - return fmt.Errorf("k8s-audit server failed: %w", err) - } + go func() { + err := s.server.ListenAndServe() + if err != nil && !errors.Is(err, http.ErrServerClosed) { + serverErr <- fmt.Errorf("k8s-audit server failed: %w", err) + } + }() - return nil - }) - <-t.Dying() + select { + case <-ctx.Done(): s.logger.Infof("Stopping k8s-audit server on %s:%d%s", s.config.ListenAddr, s.config.ListenPort, s.config.WebhookPath) - if err := s.server.Shutdown(ctx); err != nil { + if err := s.server.Shutdown(context.Background()); err != nil { //nolint:contextcheck // shutdown needs a fresh context after parent cancellation s.logger.Errorf("Error shutting down k8s-audit server: %s", err.Error()) } return nil - }) - - return nil + case err := <-serverErr: + return err + } } func (s *Source) webhookHandler(w http.ResponseWriter, r *http.Request) { From 92f75f73085d34853a989ef800035872534e1841 Mon Sep 17 00:00:00 2001 From: marco Date: Sun, 1 Mar 2026 22:08:10 +0100 Subject: [PATCH 19/23] fix(loki): fix containedctx and errorlint issues Remove ctx field from LokiClient struct and the redundant lc.ctx.Done() select cases. The stored context was always a parent of (or identical to) the method parameter context, making it redundant. Also use errors.Is() for context.Canceled comparison in tests. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- pkg/acquisition/modules/loki/init.go | 10 ++-- .../loki/internal/lokiclient/loki_client.go | 26 ++++------ pkg/acquisition/modules/loki/loki_test.go | 51 +++++++++++-------- pkg/acquisition/modules/loki/run.go | 46 +++++++---------- 4 files changed, 60 insertions(+), 73 deletions(-) diff --git a/pkg/acquisition/modules/loki/init.go b/pkg/acquisition/modules/loki/init.go index 2f77d1d9f49..18f94c44d3d 100644 --- a/pkg/acquisition/modules/loki/init.go +++ b/pkg/acquisition/modules/loki/init.go @@ -7,11 +7,11 @@ import ( var ( // verify interface compliance - _ types.DataSource = (*Source)(nil) - _ types.DSNConfigurer = (*Source)(nil) - _ types.Fetcher = (*Source)(nil) - _ types.Tailer = (*Source)(nil) - _ types.MetricsProvider = (*Source)(nil) + _ types.DataSource = (*Source)(nil) + _ types.DSNConfigurer = (*Source)(nil) + _ types.BatchFetcher = (*Source)(nil) + _ types.RestartableStreamer = (*Source)(nil) + _ types.MetricsProvider = (*Source)(nil) ) const ModuleName = "loki" diff --git a/pkg/acquisition/modules/loki/internal/lokiclient/loki_client.go b/pkg/acquisition/modules/loki/internal/lokiclient/loki_client.go index 83486c0c689..cf80988ff72 100644 --- a/pkg/acquisition/modules/loki/internal/lokiclient/loki_client.go +++ b/pkg/acquisition/modules/loki/internal/lokiclient/loki_client.go @@ -14,7 +14,6 @@ import ( "github.com/gorilla/websocket" log "github.com/sirupsen/logrus" - "gopkg.in/tomb.v2" "github.com/crowdsecurity/crowdsec/pkg/apiclient/useragent" "maps" @@ -24,7 +23,6 @@ type LokiClient struct { Logger *log.Entry config Config - t *tomb.Tomb failStart time.Time currentTickerInterval time.Duration requestHeaders map[string]string @@ -69,10 +67,6 @@ func updateURI(uri string, lq LokiQueryRangeResponse, infinite bool) (string, er return u.String(), nil } -func (lc *LokiClient) SetTomb(t *tomb.Tomb) { - lc.t = t -} - func (lc *LokiClient) resetFailStart() { if !lc.failStart.IsZero() { log.Infof("loki is back after %s", time.Since(lc.failStart)) @@ -120,8 +114,6 @@ func (lc *LokiClient) queryRange(ctx context.Context, uri string, c chan *LokiQu select { case <-ctx.Done(): return ctx.Err() - case <-lc.t.Dying(): - return lc.t.Err() case <-ticker.C: resp, err := lc.Get(ctx, uri) if err != nil { @@ -218,9 +210,6 @@ func (lc *LokiClient) Ready(ctx context.Context) error { case <-ctx.Done(): tick.Stop() return ctx.Err() - case <-lc.t.Dying(): - tick.Stop() - return lc.t.Err() case <-tick.C: lc.Logger.Debug("Checking if Loki is ready") resp, err := lc.Get(ctx, url) @@ -273,7 +262,7 @@ func (lc *LokiClient) Tail(ctx context.Context) (chan *LokiResponse, error) { return responseChan, errors.New("error connecting to websocket") } - lc.t.Go(func() error { + go func() { defer conn.Close() for { jsonResponse := &LokiResponse{} @@ -281,12 +270,13 @@ func (lc *LokiClient) Tail(ctx context.Context) (chan *LokiResponse, error) { err = conn.ReadJSON(jsonResponse) if err != nil { lc.Logger.Errorf("Error reading from websocket: %s", err) - return fmt.Errorf("websocket error: %w", err) + close(responseChan) + return } responseChan <- jsonResponse } - }) + }() return responseChan, nil } @@ -305,9 +295,11 @@ func (lc *LokiClient) QueryRange(ctx context.Context, infinite bool) chan *LokiQ lc.Logger.Debugf("Since: %s (%s)", lc.config.Since, time.Now().Add(-lc.config.Since)) lc.Logger.Infof("Connecting to %s", url) - lc.t.Go(func() error { - return lc.queryRange(ctx, url, c, infinite) - }) + go func() { + if err := lc.queryRange(ctx, url, c, infinite); err != nil { + lc.Logger.Errorf("Error querying range: %s", err) + } + }() return c } diff --git a/pkg/acquisition/modules/loki/loki_test.go b/pkg/acquisition/modules/loki/loki_test.go index d79ad477cb7..e2b58c7f212 100644 --- a/pkg/acquisition/modules/loki/loki_test.go +++ b/pkg/acquisition/modules/loki/loki_test.go @@ -4,6 +4,7 @@ import ( "bytes" "context" "encoding/json" + "errors" "fmt" "io" "net/http" @@ -15,7 +16,6 @@ import ( log "github.com/sirupsen/logrus" "github.com/stretchr/testify/assert" - tomb "gopkg.in/tomb.v2" "github.com/crowdsecurity/go-cs-lib/cstest" @@ -242,9 +242,7 @@ since: 1h } }() - lokiTomb := tomb.Tomb{} - - if err := lokiSource.OneShotAcquisition(ctx, out, &lokiTomb); err != nil { + if err := lokiSource.OneShot(ctx, out); err != nil { t.Fatalf("Unexpected error : %s", err) } @@ -307,7 +305,6 @@ query: > }) out := make(chan pipeline.Event) - lokiTomb := tomb.Tomb{} lokiSource := loki.Source{} err := lokiSource.Configure(ctx, []byte(ts.config), subLogger, metrics.AcquisitionMetricsLevelNone) @@ -315,46 +312,56 @@ query: > t.Fatalf("Unexpected error : %s", err) } - err = lokiSource.StreamingAcquisition(ctx, out, &lokiTomb) - cstest.AssertErrorContains(t, err, ts.streamErr) + streamCtx, streamCancel := context.WithCancel(ctx) + defer streamCancel() + + streamErrCh := make(chan error, 1) + go func() { + streamErrCh <- lokiSource.Stream(streamCtx, out) + }() if ts.streamErr != "" { + err = <-streamErrCh + cstest.AssertErrorContains(t, err, ts.streamErr) return } time.Sleep(time.Second * 2) // We need to give time to start reading from the WS - readTomb := tomb.Tomb{} readCtx, cancel := context.WithTimeout(ctx, time.Second*10) count := 0 + readErrCh := make(chan error, 1) - readTomb.Go(func() error { + go func() { defer cancel() for { select { case <-readCtx.Done(): - return readCtx.Err() + readErrCh <- readCtx.Err() + return case evt := <-out: count++ if !strings.HasSuffix(evt.Line.Raw, title) { - return fmt.Errorf("Incorrect suffix : %s", evt.Line.Raw) + readErrCh <- fmt.Errorf("Incorrect suffix : %s", evt.Line.Raw) + return } if count == ts.expectedLines { - return nil + readErrCh <- nil + return } } } - }) + }() err = feedLoki(ctx, subLogger, ts.expectedLines, title) if err != nil { t.Fatalf("Unexpected error : %s", err) } - err = readTomb.Wait() + err = <-readErrCh cancel() @@ -393,12 +400,12 @@ query: > out := make(chan pipeline.Event) - lokiTomb := &tomb.Tomb{} + streamCtx, streamCancel := context.WithCancel(ctx) + streamErrCh := make(chan error, 1) - err = lokiSource.StreamingAcquisition(ctx, out, lokiTomb) - if err != nil { - t.Fatalf("Unexpected error : %s", err) - } + go func() { + streamErrCh <- lokiSource.Stream(streamCtx, out) + }() time.Sleep(time.Second * 2) @@ -407,10 +414,10 @@ query: > t.Fatalf("Unexpected error : %s", err) } - lokiTomb.Kill(nil) + streamCancel() - err = lokiTomb.Wait() - if err != nil { + err = <-streamErrCh + if err != nil && !errors.Is(err, context.Canceled) { t.Fatalf("Unexpected error : %s", err) } } diff --git a/pkg/acquisition/modules/loki/run.go b/pkg/acquisition/modules/loki/run.go index 28bc8b94822..a3adb3ca693 100644 --- a/pkg/acquisition/modules/loki/run.go +++ b/pkg/acquisition/modules/loki/run.go @@ -10,17 +10,14 @@ import ( "fmt" "github.com/prometheus/client_golang/prometheus" - tomb "gopkg.in/tomb.v2" "github.com/crowdsecurity/crowdsec/pkg/acquisition/modules/loki/internal/lokiclient" "github.com/crowdsecurity/crowdsec/pkg/metrics" "github.com/crowdsecurity/crowdsec/pkg/pipeline" ) -// OneShotAcquisition reads a set of file and returns when done -func (l *Source) OneShotAcquisition(ctx context.Context, out chan pipeline.Event, t *tomb.Tomb) error { +func (l *Source) OneShot(ctx context.Context, out chan pipeline.Event) error { l.logger.Debug("Loki one shot acquisition") - l.Client.SetTomb(t) if !l.Config.NoReadyCheck { readyCtx, readyCancel := context.WithTimeout(ctx, l.Config.WaitForReady) @@ -38,7 +35,7 @@ func (l *Source) OneShotAcquisition(ctx context.Context, out chan pipeline.Event for { select { - case <-t.Dying(): + case <-ctx.Done(): l.logger.Debug("Loki one shot acquisition stopped") return nil case resp, ok := <-c: @@ -75,9 +72,7 @@ func (l *Source) readOneEntry(entry lokiclient.Entry, labels map[string]string, out <- evt } -func (l *Source) StreamingAcquisition(ctx context.Context, out chan pipeline.Event, t *tomb.Tomb) error { - l.Client.SetTomb(t) - +func (l *Source) Stream(ctx context.Context, out chan pipeline.Event) error { if !l.Config.NoReadyCheck { readyCtx, readyCancel := context.WithTimeout(ctx, l.Config.WaitForReady) defer readyCancel() @@ -89,30 +84,23 @@ func (l *Source) StreamingAcquisition(ctx context.Context, out chan pipeline.Eve ll := l.logger.WithField("websocket_url", l.lokiWebsocket) - t.Go(func() error { - ctx, cancel := context.WithCancel(ctx) - defer cancel() - - respChan := l.Client.QueryRange(ctx, true) + respChan := l.Client.QueryRange(ctx, true) - for { - select { - case resp, ok := <-respChan: - if !ok { - ll.Warnf("loki channel closed") - return errors.New("loki channel closed") - } + for { + select { + case resp, ok := <-respChan: + if !ok { + ll.Warnf("loki channel closed") + return errors.New("loki channel closed") + } - for _, stream := range resp.Data.Result { - for _, entry := range stream.Entries { - l.readOneEntry(entry, l.Config.Labels, out) - } + for _, stream := range resp.Data.Result { + for _, entry := range stream.Entries { + l.readOneEntry(entry, l.Config.Labels, out) } - case <-t.Dying(): - return nil } + case <-ctx.Done(): + return nil } - }) - - return nil + } } From 298542e4b8255647530d9a4691635809e9cd7be4 Mon Sep 17 00:00:00 2001 From: marco Date: Sun, 1 Mar 2026 22:08:44 +0100 Subject: [PATCH 20/23] fix(victorialogs): fix containedctx, bodyclose, and errcheck issues Remove ctx field from VLClient struct and the redundant lc.ctx.Done() select cases. Close resp.Body in Tail() error path and add defer close in the reading goroutine. Explicitly ignore Stream() error in test. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- pkg/acquisition/modules/victorialogs/init.go | 10 +-- .../internal/vlclient/vl_client.go | 38 ++++++------ pkg/acquisition/modules/victorialogs/run.go | 62 ++++++------------- .../modules/victorialogs/victorialogs_test.go | 58 +++++++++-------- 4 files changed, 73 insertions(+), 95 deletions(-) diff --git a/pkg/acquisition/modules/victorialogs/init.go b/pkg/acquisition/modules/victorialogs/init.go index 5ab9e5eb8dd..08b842f677f 100644 --- a/pkg/acquisition/modules/victorialogs/init.go +++ b/pkg/acquisition/modules/victorialogs/init.go @@ -7,11 +7,11 @@ import ( var ( // verify interface compliance - _ types.DataSource = (*Source)(nil) - _ types.DSNConfigurer = (*Source)(nil) - _ types.Fetcher = (*Source)(nil) - _ types.Tailer = (*Source)(nil) - _ types.MetricsProvider = (*Source)(nil) + _ types.DataSource = (*Source)(nil) + _ types.DSNConfigurer = (*Source)(nil) + _ types.BatchFetcher = (*Source)(nil) + _ types.RestartableStreamer = (*Source)(nil) + _ types.MetricsProvider = (*Source)(nil) ) const ModuleName = "victorialogs" diff --git a/pkg/acquisition/modules/victorialogs/internal/vlclient/vl_client.go b/pkg/acquisition/modules/victorialogs/internal/vlclient/vl_client.go index 81df3f0bd1e..136a030d8a8 100644 --- a/pkg/acquisition/modules/victorialogs/internal/vlclient/vl_client.go +++ b/pkg/acquisition/modules/victorialogs/internal/vlclient/vl_client.go @@ -15,7 +15,6 @@ import ( "time" log "github.com/sirupsen/logrus" - "gopkg.in/tomb.v2" "github.com/crowdsecurity/crowdsec/pkg/apiclient/useragent" "maps" @@ -25,7 +24,6 @@ type VLClient struct { Logger *log.Entry config Config - t *tomb.Tomb failStart time.Time currentTickerInterval time.Duration requestHeaders map[string]string @@ -68,10 +66,6 @@ func updateURI(uri string, newStart time.Time) (string, error) { return u.String(), nil } -func (lc *VLClient) SetTomb(t *tomb.Tomb) { - lc.t = t -} - func (lc *VLClient) shouldRetry() bool { if lc.failStart.IsZero() { lc.Logger.Warningf("VictoriaLogs is not available, will retry for %s", lc.config.FailMaxDuration) @@ -118,8 +112,6 @@ func (lc *VLClient) doQueryRange(ctx context.Context, uri string, c chan *Log, i select { case <-ctx.Done(): return ctx.Err() - case <-lc.t.Dying(): - return lc.t.Err() case <-ticker.C: resp, err := lc.Get(ctx, uri) if err != nil { @@ -268,9 +260,6 @@ func (lc *VLClient) Ready(ctx context.Context) error { case <-ctx.Done(): tick.Stop() return ctx.Err() - case <-lc.t.Dying(): - tick.Stop() - return lc.t.Err() case <-tick.C: lc.Logger.Debug("Checking if VictoriaLogs is ready") @@ -313,10 +302,14 @@ func (lc *VLClient) Tail(ctx context.Context) (chan *Log, error) { ) for { - resp, err = lc.Get(ctx, u) + resp, err = lc.Get(ctx, u) //nolint:bodyclose // body is closed in error paths and via defer in the goroutine lc.Logger.Tracef("Tail request done: %v | %s", resp, err) if err != nil { + if resp != nil { + resp.Body.Close() + } + if errors.Is(err, context.Canceled) { return nil, nil } @@ -343,14 +336,17 @@ func (lc *VLClient) Tail(ctx context.Context) (chan *Log, error) { responseChan := make(chan *Log) - lc.t.Go(func() error { + lc.Logger.Infof("Connecting to %s", u) + + go func() { + defer resp.Body.Close() + _, _, err = lc.readResponse(ctx, resp, responseChan) if err != nil { - return fmt.Errorf("error while reading tail response: %w", err) + lc.Logger.Errorf("error while reading tail response: %s", err) } - - return nil - }) + close(responseChan) + }() return responseChan, nil } @@ -370,9 +366,11 @@ func (lc *VLClient) QueryRange(ctx context.Context, infinite bool) chan *Log { lc.Logger.Debugf("Since: %s (%s)", lc.config.Since, t) lc.Logger.Infof("Connecting to %s", u) - lc.t.Go(func() error { - return lc.doQueryRange(ctx, u, c, infinite) - }) + go func() { + if err := lc.doQueryRange(ctx, u, c, infinite); err != nil { + lc.Logger.Errorf("Error querying range: %s", err) + } + }() return c } diff --git a/pkg/acquisition/modules/victorialogs/run.go b/pkg/acquisition/modules/victorialogs/run.go index 98eb8eebd76..24a1cd83cce 100644 --- a/pkg/acquisition/modules/victorialogs/run.go +++ b/pkg/acquisition/modules/victorialogs/run.go @@ -5,7 +5,6 @@ import ( "fmt" "github.com/prometheus/client_golang/prometheus" - "gopkg.in/tomb.v2" "github.com/crowdsecurity/crowdsec/pkg/acquisition/configuration" "github.com/crowdsecurity/crowdsec/pkg/acquisition/modules/victorialogs/internal/vlclient" @@ -13,10 +12,8 @@ import ( "github.com/crowdsecurity/crowdsec/pkg/pipeline" ) -// OneShotAcquisition reads a set of file and returns when done -func (s *Source) OneShotAcquisition(ctx context.Context, out chan pipeline.Event, t *tomb.Tomb) error { +func (s *Source) OneShot(ctx context.Context, out chan pipeline.Event) error { s.logger.Debug("VictoriaLogs one shot acquisition") - s.Client.SetTomb(t) readyCtx, cancel := context.WithTimeout(ctx, s.Config.WaitForReady) defer cancel() @@ -36,7 +33,7 @@ func (s *Source) OneShotAcquisition(ctx context.Context, out chan pipeline.Event for { select { - case <-t.Dying(): + case <-ctx.Done(): s.logger.Debug("VictoriaLogs one shot acquisition stopped") return nil case resp, ok := <-respChan: @@ -76,9 +73,7 @@ func (s *Source) readOneEntry(entry *vlclient.Log, labels map[string]string, out } } -func (s *Source) StreamingAcquisition(ctx context.Context, out chan pipeline.Event, t *tomb.Tomb) error { - s.Client.SetTomb(t) - +func (s *Source) Stream(ctx context.Context, out chan pipeline.Event) error { readyCtx, cancel := context.WithTimeout(ctx, s.Config.WaitForReady) defer cancel() @@ -87,44 +82,25 @@ func (s *Source) StreamingAcquisition(ctx context.Context, out chan pipeline.Eve return fmt.Errorf("VictoriaLogs is not ready: %w", err) } - lctx, clientCancel := context.WithCancel(ctx) - // Don't defer clientCancel(), the client outlives this function call - - t.Go(func() error { - <-t.Dying() - clientCancel() - - return nil - }) - - t.Go(func() error { - respChan, err := s.getResponseChan(lctx, true) - if err != nil { - clientCancel() - s.logger.Errorf("could not start VictoriaLogs tail: %s", err) - - return fmt.Errorf("while starting VictoriaLogs tail: %w", err) - } - - for { - select { - case resp, ok := <-respChan: - if !ok { - s.logger.Warnf("VictoriaLogs channel closed") - clientCancel() - - return err - } + respChan, err := s.getResponseChan(ctx, true) + if err != nil { + s.logger.Errorf("could not start VictoriaLogs tail: %s", err) + return fmt.Errorf("while starting VictoriaLogs tail: %w", err) + } - s.readOneEntry(resp, s.Config.Labels, out) - case <-t.Dying(): - clientCancel() - return nil + for { + select { + case resp, ok := <-respChan: + if !ok { + s.logger.Warnf("VictoriaLogs channel closed") + return err } - } - }) - return nil + s.readOneEntry(resp, s.Config.Labels, out) + case <-ctx.Done(): + return nil + } + } } func (s *Source) getResponseChan(ctx context.Context, infinite bool) (chan *vlclient.Log, error) { diff --git a/pkg/acquisition/modules/victorialogs/victorialogs_test.go b/pkg/acquisition/modules/victorialogs/victorialogs_test.go index d082986694a..b434ff3e3ba 100644 --- a/pkg/acquisition/modules/victorialogs/victorialogs_test.go +++ b/pkg/acquisition/modules/victorialogs/victorialogs_test.go @@ -17,7 +17,6 @@ import ( log "github.com/sirupsen/logrus" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - "gopkg.in/tomb.v2" "github.com/crowdsecurity/go-cs-lib/cstest" @@ -211,9 +210,7 @@ since: 1h } }() - vlTomb := tomb.Tomb{} - - err = vlSource.OneShotAcquisition(ctx, out, &vlTomb) + err = vlSource.OneShot(ctx, out) if err != nil { t.Fatalf("Unexpected error : %s", err) } @@ -273,7 +270,6 @@ query: > }) out := make(chan pipeline.Event) - vlTomb := tomb.Tomb{} vlSource := victorialogs.Source{} err := vlSource.Configure(ctx, []byte(ts.config), subLogger, metrics.AcquisitionMetricsLevelNone) @@ -281,48 +277,58 @@ query: > t.Fatalf("Unexpected error : %s", err) } - err = vlSource.StreamingAcquisition(ctx, out, &vlTomb) - cstest.AssertErrorContains(t, err, ts.streamErr) + streamCtx, streamCancel := context.WithCancel(ctx) + defer streamCancel() + streamErrCh := make(chan error, 1) + + go func() { + streamErrCh <- vlSource.Stream(streamCtx, out) + }() if ts.streamErr != "" { + err = <-streamErrCh + cstest.AssertErrorContains(t, err, ts.streamErr) return } time.Sleep(time.Second * 2) // We need to give time to start reading from the WS - readTomb := tomb.Tomb{} - readCtx, cancel := context.WithTimeout(ctx, time.Second*10) + readCtx, readCancel := context.WithTimeout(ctx, time.Second*10) count := 0 + readErrCh := make(chan error, 1) - readTomb.Go(func() error { - defer cancel() + go func() { + defer readCancel() for { select { case <-readCtx.Done(): - return readCtx.Err() + readErrCh <- readCtx.Err() + return case evt := <-out: count++ if !strings.HasSuffix(evt.Line.Raw, title) { - return fmt.Errorf("Incorrect suffix : %s", evt.Line.Raw) + readErrCh <- fmt.Errorf("Incorrect suffix : %s", evt.Line.Raw) + return } if count == ts.expectedLines { - return nil + readErrCh <- nil + return } } } - }) + }() err = feedVLogs(ctx, subLogger, ts.expectedLines, title) if err != nil { t.Fatalf("Unexpected error : %s", err) } - err = readTomb.Wait() + err = <-readErrCh - cancel() + readCancel() if err != nil { t.Fatalf("Unexpected error : %s", err) @@ -357,12 +363,13 @@ query: > out := make(chan pipeline.Event, 10) - vlTomb := &tomb.Tomb{} + streamCtx, streamCancel := context.WithCancel(ctx) + streamDone := make(chan struct{}) - err = vlSource.StreamingAcquisition(ctx, out, vlTomb) - if err != nil { - t.Fatalf("Unexpected error : %s", err) - } + go func() { + _ = vlSource.Stream(streamCtx, out) + close(streamDone) + }() time.Sleep(time.Second * 2) @@ -371,10 +378,7 @@ query: > t.Fatalf("Unexpected error : %s", err) } - vlTomb.Kill(nil) + streamCancel() - err = vlTomb.Wait() - if err != nil { - t.Fatalf("Unexpected error : %s", err) - } + <-streamDone } From 8408f939bb144e3b0e97ec8d681001cf88cf7379 Mon Sep 17 00:00:00 2001 From: marco Date: Sun, 1 Mar 2026 22:09:17 +0100 Subject: [PATCH 21/23] refactor(wineventlog): update acquisition interface Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- pkg/acquisition/modules/wineventlog/init.go | 10 +++++----- .../modules/wineventlog/run_windows.go | 15 ++++----------- pkg/acquisition/modules/wineventlog/stub.go | 3 +-- 3 files changed, 10 insertions(+), 18 deletions(-) diff --git a/pkg/acquisition/modules/wineventlog/init.go b/pkg/acquisition/modules/wineventlog/init.go index 9befef73d7d..bda982a210b 100644 --- a/pkg/acquisition/modules/wineventlog/init.go +++ b/pkg/acquisition/modules/wineventlog/init.go @@ -7,11 +7,11 @@ import ( var ( // verify interface compliance - _ types.DataSource = (*Source)(nil) - _ types.DSNConfigurer = (*Source)(nil) - _ types.BatchFetcher = (*Source)(nil) - _ types.Tailer = (*Source)(nil) - _ types.MetricsProvider = (*Source)(nil) + _ types.DataSource = (*Source)(nil) + _ types.DSNConfigurer = (*Source)(nil) + _ types.BatchFetcher = (*Source)(nil) + _ types.RestartableStreamer = (*Source)(nil) + _ types.MetricsProvider = (*Source)(nil) ) const ModuleName = "wineventlog" diff --git a/pkg/acquisition/modules/wineventlog/run_windows.go b/pkg/acquisition/modules/wineventlog/run_windows.go index 10c86869148..85ee9e1a61a 100644 --- a/pkg/acquisition/modules/wineventlog/run_windows.go +++ b/pkg/acquisition/modules/wineventlog/run_windows.go @@ -11,9 +11,6 @@ import ( "github.com/prometheus/client_golang/prometheus" log "github.com/sirupsen/logrus" "golang.org/x/sys/windows" - "gopkg.in/tomb.v2" - - "github.com/crowdsecurity/go-cs-lib/trace" "github.com/crowdsecurity/crowdsec/pkg/metrics" "github.com/crowdsecurity/crowdsec/pkg/pipeline" @@ -63,7 +60,7 @@ func (s *Source) getXMLEvents(config *winlog.SubscribeConfig, publisherCache map return renderedEvents, err } -func (s *Source) getEvents(out chan pipeline.Event, t *tomb.Tomb) error { +func (s *Source) getEvents(ctx context.Context, out chan pipeline.Event) error { subscription, err := winlog.Subscribe(s.evtConfig) if err != nil { s.logger.Errorf("Failed to subscribe to event log: %s", err) @@ -78,7 +75,7 @@ func (s *Source) getEvents(out chan pipeline.Event, t *tomb.Tomb) error { }() for { select { - case <-t.Dying(): + case <-ctx.Done(): s.logger.Infof("wineventlog is dying") return nil default: @@ -170,10 +167,6 @@ OUTER_LOOP: return nil } -func (s *Source) StreamingAcquisition(ctx context.Context, out chan pipeline.Event, t *tomb.Tomb) error { - t.Go(func() error { - defer trace.ReportPanic() - return s.getEvents(out, t) - }) - return nil +func (s *Source) Stream(ctx context.Context, out chan pipeline.Event) error { + return s.getEvents(ctx, out) } diff --git a/pkg/acquisition/modules/wineventlog/stub.go b/pkg/acquisition/modules/wineventlog/stub.go index 697968e0911..788f32346c6 100644 --- a/pkg/acquisition/modules/wineventlog/stub.go +++ b/pkg/acquisition/modules/wineventlog/stub.go @@ -8,7 +8,6 @@ import ( "github.com/prometheus/client_golang/prometheus" log "github.com/sirupsen/logrus" - "gopkg.in/tomb.v2" "github.com/crowdsecurity/crowdsec/pkg/acquisition/configuration" "github.com/crowdsecurity/crowdsec/pkg/metrics" @@ -61,7 +60,7 @@ func (*Source) CanRun() error { return errors.New("windows event log acquisition is only supported on Windows") } -func (*Source) StreamingAcquisition(_ context.Context, _ chan pipeline.Event, _ *tomb.Tomb) error { +func (*Source) Stream(_ context.Context, _ chan pipeline.Event) error { return nil } From 5ddfce0b98759a7720028e5e7144d4801209a9aa Mon Sep 17 00:00:00 2001 From: marco Date: Sun, 1 Mar 2026 23:29:28 +0100 Subject: [PATCH 22/23] fix: handle context cancellation gracefully in kafka and update wineventlog test API - kafka ReadMessage: return nil instead of ctx.Err() on context cancellation, consistent with other acquisition modules - wineventlog test: replace removed StreamingAcquisition with Stream, use context.WithCancel instead of tomb.Tomb Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- pkg/acquisition/modules/kafka/run.go | 2 +- .../wineventlog/wineventlog_windows_test.go | 16 ++++++++-------- 2 files changed, 9 insertions(+), 9 deletions(-) diff --git a/pkg/acquisition/modules/kafka/run.go b/pkg/acquisition/modules/kafka/run.go index 6b5696e8dec..130d103d786 100644 --- a/pkg/acquisition/modules/kafka/run.go +++ b/pkg/acquisition/modules/kafka/run.go @@ -31,7 +31,7 @@ func (s *Source) ReadMessage(ctx context.Context, out chan pipeline.Event) error } if ctx.Err() != nil { - return ctx.Err() + return nil } s.logger.Errorln(fmt.Errorf("while reading %s message: %w", s.GetName(), err)) diff --git a/pkg/acquisition/modules/wineventlog/wineventlog_windows_test.go b/pkg/acquisition/modules/wineventlog/wineventlog_windows_test.go index 7b94228c1a0..d689d16b87e 100644 --- a/pkg/acquisition/modules/wineventlog/wineventlog_windows_test.go +++ b/pkg/acquisition/modules/wineventlog/wineventlog_windows_test.go @@ -3,6 +3,7 @@ package wineventlogacquisition import ( + "context" "testing" "time" @@ -10,7 +11,6 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "golang.org/x/sys/windows/svc/eventlog" - "gopkg.in/tomb.v2" "github.com/crowdsecurity/go-cs-lib/cstest" @@ -141,8 +141,6 @@ event_level: bla`, } func TestLiveAcquisition(t *testing.T) { - ctx := t.Context() - err := exprhelpers.Init(nil) require.NoError(t, err) @@ -198,15 +196,17 @@ event_ids: } for _, test := range tests { - to := &tomb.Tomb{} + ctx, cancel := context.WithCancel(t.Context()) c := make(chan pipeline.Event) f := Source{} err := f.Configure(ctx, []byte(test.config), subLogger, metrics.AcquisitionMetricsLevelNone) require.NoError(t, err) - err = f.StreamingAcquisition(ctx, c, to) - require.NoError(t, err) + streamErr := make(chan error, 1) + go func() { + streamErr <- f.Stream(ctx, c) + }() time.Sleep(time.Second) lines := test.expectedLines @@ -238,8 +238,8 @@ event_ids: } else { assert.Equal(t, test.expectedLines, linesRead) } - to.Kill(nil) - _ = to.Wait() + cancel() + <-streamErr } } From 54c8ff56e2ef3b0f4a94d56c9b57fd7a02810db3 Mon Sep 17 00:00:00 2001 From: marco Date: Mon, 2 Mar 2026 09:31:25 +0100 Subject: [PATCH 23/23] kafka: nolint nilerr --- pkg/acquisition/modules/kafka/run.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pkg/acquisition/modules/kafka/run.go b/pkg/acquisition/modules/kafka/run.go index 130d103d786..5c597dea2a4 100644 --- a/pkg/acquisition/modules/kafka/run.go +++ b/pkg/acquisition/modules/kafka/run.go @@ -31,7 +31,7 @@ func (s *Source) ReadMessage(ctx context.Context, out chan pipeline.Event) error } if ctx.Err() != nil { - return nil + return nil //nolint:nilerr } s.logger.Errorln(fmt.Errorf("while reading %s message: %w", s.GetName(), err))